In [None]:
import glob
import os
from concurrent.futures import ProcessPoolExecutor, as_completed

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scanpy as sq
import yaml
from tqdm import tqdm

In [None]:
# Parse the shared YAML configuration into DATASET_INFO and display it.
with open("../config.yaml", "r") as config_file:
    DATASET_INFO = yaml.safe_load(config_file)
DATASET_INFO

In [None]:
# Function to read a file and return the number of spots
def get_n_spots(file):
    """Open one h5ad file and report its size and stored bulk expression.

    Parameters
    ----------
    file : str
        Path to an ``.h5ad`` file, passed to ``sc.read_h5ad``.

    Returns
    -------
    n : int
        Number of observations (rows) of the AnnData object.
    observed_bulk : pandas.Series or list
        The ``bulk_norm_tpm_unstranded`` column of ``adata.var`` when that
        column exists, otherwise an empty list.
    """
    # NOTE(review): the original guard tested for this column but then read
    # ``bulk_tpm_unstranded``; the guarded name is used for both here so the
    # lookup cannot fail after the check passes — confirm it is the intended one.
    bulk_column = "bulk_norm_tpm_unstranded"

    # Opened read-only in backed mode, so the file handle must be released
    # explicitly once the needed values have been pulled out.
    adata = sc.read_h5ad(file, backed="r")
    try:
        if bulk_column in adata.var.columns:
            # Copy so the returned Series does not depend on the AnnData object.
            observed_bulk = adata.var[bulk_column].copy()
        else:
            observed_bulk = []
        n = adata.n_obs
    finally:
        adata.file.close()

    return n, observed_bulk

In [None]:
import torch
# Worker count used for the process pool, taken from torch's reported thread count.
# NOTE(review): torch appears to be imported only to obtain this number — confirm
# whether a lighter source such as os.cpu_count() would be acceptable.
num_workers = torch.get_num_threads()
num_workers

In [None]:
# Dataset identifiers are the keys of the DATASET_NAME mapping from the config.
# Note that .keys() returns a live view of that mapping, not a list.
datasets = DATASET_INFO["DATASET_NAME"].keys()
datasets

In [None]:
# Count the per-sample h5ad files and the total number of spots for every dataset.
summary_rows = []
for dataset in datasets:
    is_tcga = "TCGA" in dataset
    if is_tcga:
        # TCGA identifiers end in "_<slide_type>"; that suffix selects both the
        # metadata table and the benchmark output directory.
        slide_type = dataset.split("_")[-1]
        dataset_name = dataset.replace(f"_{slide_type}", "")
        print(dataset_name, slide_type)
        metadata = pd.read_csv(f"../{dataset_name}/data/metadata_{slide_type}.csv")
        metadata = metadata[~metadata.image_path.duplicated()]
        metadata = metadata.set_index("id_pair")
        files = glob.glob(f'../{dataset_name}/out_benchmark_{slide_type}/data/h5ad/*.h5ad')
    else:
        files = glob.glob(f'../{dataset}/out_benchmark/data/h5ad/*.h5ad')

    # Exclude files whose *name* contains "all". Matching on the file name only
    # keeps a directory that happens to contain "all" from discarding everything,
    # and doing it first keeps such files away from the int() parse below.
    files = [f for f in files if "all" not in os.path.basename(f)]

    if is_tcga:
        # Keep only files whose numeric stem is an id_pair present in the
        # de-duplicated metadata. A plain list filter also copes with an
        # empty glob result, which a NumPy boolean mask does not.
        files = [
            f for f in files
            if int(os.path.basename(f).split(".")[0]) in metadata.index
        ]

    observed_bulk = []
    n_spots_per_sample = []

    # Using ProcessPoolExecutor to parallelize the processing
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Submit all file processing tasks
        futures = [executor.submit(get_n_spots, f) for f in files]

        # Collect results as they complete (completion order, not file order)
        for future in tqdm(as_completed(futures), total=len(futures)):
            n, bulk = future.result()
            n_spots_per_sample.append(n)
            observed_bulk.append(bulk)

    # Spot counts are plain integers, so an ordinary sum keeps the total integral
    # even when a dataset has no files.
    summary_rows.append([dataset, len(files), sum(n_spots_per_sample)])

data = pd.DataFrame(summary_rows, columns=["dataset", 'samples', 'spots'])
data

In [None]:
# Label every dataset by technology: identifiers containing "TCGA" are Digital ST,
# everything else is Visium.
data["data_type"] = [
    "Digital ST" if "TCGA" in dataset_id else "Visium,\n10x Genomics"
    for dataset_id in data["dataset"]
]
data

In [None]:
# Total number of spots across the Digital ST datasets.
data.loc[data["data_type"] == "Digital ST", "spots"].sum()

In [None]:
# Total number of samples across the Digital ST datasets.
data.loc[data["data_type"] == "Digital ST", "samples"].sum()

In [None]:
# Display the DATASET_NAME mapping (dataset identifier -> name string) from the config.
DATASET_INFO['DATASET_NAME']

In [None]:
plt.rcParams.update({'font.size': 25})

# Tick labels: the configured name with any trailing " (n..." part removed and a
# line break inserted before " (F". Stored under its own name so the `datasets`
# identifiers defined earlier in the notebook are not overwritten.
dataset_labels = [
    DATASET_INFO['DATASET_NAME'][d].split(" (n")[0].replace(' (F', '\n(F')
    for d in data.dataset.values
]
samples = data.samples.values
spots = data.spots.values
data_type = data.data_type.values

# Define colors for different data types
color_map = {
    'Visium,\n10x Genomics': 'lightblue',
    'Digital ST': 'gray',
}

# Sort by data_type first (Visium, 10x Genomics first), then by descending number of spots
sorted_indices = sorted(
    range(len(spots)),
    key=lambda i: (data_type[i] != 'Visium,\n10x Genomics', -spots[i]),
)
sorted_labels = [dataset_labels[i] for i in sorted_indices]
sorted_samples = [samples[i] for i in sorted_indices]
# Raw spot counts; no unit conversion is applied because the y-axis is logarithmic.
sorted_spots = [spots[i] for i in sorted_indices]
sorted_data_type = [data_type[i] for i in sorted_indices]

# Assign colors based on data type
colors = [color_map[dt] for dt in sorted_data_type]

# Create a wide figure (18 x 9 inches) so the dataset labels have room
fig, ax = plt.subplots(figsize=(18, 9))

# Create bar chart with colors based on data type
bars = ax.bar(sorted_labels, sorted_spots, color=colors)

# Annotate bars with the number of samples
for bar, sample in zip(bars, sorted_samples):
    yval = bar.get_height()
    ax.text(bar.get_x() + bar.get_width() / 2, yval, f'(n={sample})',
            ha='center', va='bottom', color='blue')

# Add axis labels
ax.set_xlabel('Dataset', fontsize=30)
ax.set_ylabel('Spots', fontsize=30)
# Rotate x-axis labels by 30 degrees
plt.xticks(rotation=30)
# Logarithmic y-axis so datasets of very different sizes stay readable
ax.set_yscale('log')

# Create custom legend
legend_patches = [mpatches.Patch(color=color, label=label) for label, color in color_map.items()]
ax.legend(handles=legend_patches, title="")

# Make sure the output directory exists; savefig does not create it.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/Figure6A-dataset_size.png', dpi=300, bbox_inches = 'tight')
# Show the plot
plt.show()