In [7]:
from pathlib import Path
import importlib
import os
import numpy as np
from anndata import read_h5ad
#from scipy.sparse import dok_matrix
import json
from operator import itemgetter

In [8]:
# Sorted so datasets are always processed in the same order; the raw order of
# Path.glob depends on the filesystem.
files = sorted(Path("./datasets").glob("*.h5ad"))

def check_mem_usage():
    """Print the resident set size (RSS) of this Python process in MB (MiB).

    ``psutil`` is imported lazily so that only callers of this helper need it.
    """
    import os
    import psutil

    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_mb = rss_bytes / (1024 * 1024)
    print(f"Current memory usage: {rss_mb:.2f} MB")

def force_gc():
    """Run a full garbage collection and report memory before and after.

    Prints the process memory (via ``check_mem_usage``), the number of
    unreachable objects returned by ``gc.collect()``, then the memory again.
    """
    import gc

    check_mem_usage()
    n_collected = gc.collect()
    print(f"collecting {n_collected} objects")
    check_mem_usage()

for f in files:
    dataset_id = f.stem
    print(f"Processing dataset {dataset_id}")
    adata = read_h5ad(f)

    # filter out cells that originate from other datasets
    adata = adata[adata.obs["is_primary_data"] == True]

    metadata_markers = adata.obs[["disease", "tissue"]].drop_duplicates()

    # Per-dataset JSON of results. If it already exists it is reloaded, so
    # entries from earlier runs are kept and updated rather than discarded.
    analysis_data = {}
    analysis_data_path = os.path.join("output", "analysis", dataset_id + ".json")
    if os.path.exists(analysis_data_path):
        with open(analysis_data_path, "r") as in_f:
            analysis_data = json.load(in_f)
    else:
        os.makedirs(os.path.dirname(analysis_data_path), exist_ok=True)

    # assume that the imputation methods were all run correctly and the subset
    # results are present; one "sample" = one (disease, tissue) combination
    n_samples = len(metadata_markers)
    for i, (disease, tissue) in enumerate(metadata_markers.itertuples(index=False), start=1):
        print(f"Analyzing sample {i}/{n_samples}: {tissue} ({disease})")

        mask = (adata.obs["disease"] == disease) & (adata.obs["tissue"] == tissue)
        adata_sample = adata[mask]
        # Only the shape is needed, so read it from the AnnData view instead of
        # building a dense copy of X (which can take several GB per sample).
        sample_shape = adata_sample.shape
        print(f"- Sample matrix has shape {sample_shape}")
        # `sample_analysis` refers to data about the entire sample
        sample_analysis = analysis_data.setdefault(f"{disease}, {tissue}", {})
        sample_analysis["shape"] = sample_shape
        print("adata_sample.obs: ", adata_sample.obs)
        print("adata_sample.var[feature_name]: ", adata_sample.var["feature_name"])
        if "Celltype" in adata_sample.obs.columns:
            print("adata_sample.obs[Celltype]: ", adata_sample.obs["Celltype"])
        else:
            print("adata_sample.obs[Celltype]: <missing column>")

        # prematurely dump the json after testing every sample for easy monitoring
        with open(analysis_data_path, "w") as out_f:
            json.dump(analysis_data, out_f, indent=4)

        force_gc()
        print("")


Processing dataset 63ff2c52-cb63-44f0-bac3-d0b33373e312
Analyzing sample 1/9: lamina propria of mucosa of colon (Crohn disease)
- Sample matrix has shape (8076, 27289)
adata_sample.obs:                             biosample_id  n_genes  n_counts  Type donor_id  \
cell_id                                                                     
N105446_L-ATTGTTCCAAACGTGG    N105446_L   2205.0   76317.0  NonI   105446   
N105446_L-TCGACGGGTGAGACCA    N105446_L   2088.0   67801.0  NonI   105446   
N105446_L-AGTAACCGTTAAGGGC    N105446_L   1862.0   58097.0  NonI   105446   
N105446_L-GCAGGCTTCGCTAAAC    N105446_L   5472.0   43418.0  NonI   105446   
N105446_L-ATCTTCATCTGAGAGG    N105446_L   5650.0   42871.0  NonI   105446   
...                                 ...      ...       ...   ...      ...   
N130084_L-GTGGTTACAGTTCCAA    N130084_L    305.0     488.0  NonI   130084   
N130084_L-TATCCTATCGTTCATT    N130084_L    294.0     490.0  NonI   130084   
N130084_L-TCATCCGGTATGATCC    N130084_L    

In [9]:
# Load the JSON mapping of dataset -> its cell-type labels
# (presumably a list per dataset; any iterable of labels works).
with open("unique_celltypes_per_dataset.json", "r") as f:
    unique_cell_types_per_dataset = json.load(f)

# Merge every dataset's labels into a single set of distinct cell types.
unique_cell_types = set()
for dataset in unique_cell_types_per_dataset:
    cell_types = unique_cell_types_per_dataset[dataset]
    unique_cell_types.update(cell_types)

In [10]:
# Sorted so the listing is identical on every run: the iteration order of a
# set of strings depends on per-process hash randomisation (PYTHONHASHSEED).
for item in sorted(unique_cell_types):
    print(item)

Stem cells OLFM4
Paneth cells
Goblet cells SPINK4
Goblet cells MUC2 TFF1
Tuft cells
Epithelial Cycling cells
Stem cells OLFM4 PCNA
Stem cells OLFM4 LGR5
Enterocytes BEST4
Enterocytes TMIGD1 MEP1A
Enteroendocrine cells
Enterocytes CA1 CA2 CA4-
Goblet cells MUC2 TFF1-


In [11]:
# test = {
#     "Enterocytes BEST4": {
#         BEST4, OTOP2, CA7, GUCA2A, GUCA2B, SPIB, CFTR # with region-specific additions like CFTR in the small intestine and OTOP2 in the colon.
#     },
#     "Goblet cells MUC2 TFF1":
#         {
#             MUC2, TFF1, TFF3, FCGBP, AGR2, SPDEF
#         }
    
#     "Tuft cells": {
#     POU2F3, DCLK1
#     }

#     "Goblet cells SPINK4": {
#         MUC2, SPINK4
#     }
#     "Enterocytes TMIGD1 MEP1A": {
#     CA1, CA2, TMIGD1, MEP1A
#     }
#     "Enterocytes CA1 CA2 CA4-": {
#         CA1, CA2
#     }
#     Goblet cells MUC2 TFF1-: MUC2
#     Epithelial Cycling cells: LGR5, OLFM4, MKI67
#     Enteroendocrine cells: CHGA, GCG, GIP, CCK
#     Stem cells OLFM4: OLFM4, LGR5
#     Paneth cells: LYZ, DEFA5
#         "Stem cells OLFM4 LGR5": {
#     OLFM4, LGR5, and Ascl2
#     }
        
#         Stem cells OLFM4 PCNA {
#         OLFM4 and PCNA
#         LGR5, Ascl2, SOX9, and TERT}
# }

In [None]:
# Reference marker gene symbols for each cell-type label; the keys are matched
# exactly against the values of the `Celltype` annotation further below.
marker_genes = {
    "Enterocytes BEST4": ["BEST4", "OTOP2", "CA7", "GUCA2A", "GUCA2B", "SPIB", "CFTR"],
    "Goblet cells MUC2 TFF1": ["MUC2", "TFF1", "TFF3", "FCGBP", "AGR2", "SPDEF"],
    "Tuft cells": ["POU2F3", "DCLK1"],
    "Goblet cells SPINK4": ["MUC2", "SPINK4"],
    "Enterocytes TMIGD1 MEP1A": ["CA1", "CA2", "TMIGD1", "MEP1A"],
    "Enterocytes CA1 CA2 CA4-": ["CA1", "CA2"],
    "Goblet cells MUC2 TFF1-": ["MUC2"],
    "Epithelial Cycling cells": ["LGR5", "OLFM4", "MKI67"],
    "Enteroendocrine cells": ["CHGA", "GCG", "GIP", "CCK"],
    "Stem cells OLFM4": ["OLFM4", "LGR5"],
    "Stem cells OLFM4 LGR5": ["OLFM4", "LGR5", "ASCL2"],
    "Stem cells OLFM4 PCNA": ["OLFM4", "PCNA", "LGR5", "ASCL2", "SOX9", "TERT"],
    "Paneth cells": ["LYZ", "DEFA5"],
}


import os
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from anndata import read_h5ad

# Imputation methods whose outputs are compared below.
methods = ["MAGIC", "SAUCIE"]


def expand_magic_matrix(y, reduced_matrix, min_cells=5):
    """Re-insert the gene columns that were filtered out before running MAGIC.

    The MAGIC input is assumed to have kept only genes that are non-zero in at
    least ``min_cells`` cells of ``y``; ``reduced_matrix`` holds the imputed
    values of those kept genes in their original column order. Genes that were
    dropped get their original (un-imputed) values back.

    Parameters
    ----------
    y : array-like, shape (n_cells, n_genes)
        Dense original expression matrix MAGIC was run on.
    reduced_matrix : array-like, shape (n_cells, n_kept) or (n_cells, n_genes)
        MAGIC output. If it already has the full shape of ``y`` it is assumed
        that no genes were filtered and it is returned unchanged (as float).
    min_cells : int, default 5
        Detection threshold that was used to keep a gene for MAGIC.

    Returns
    -------
    np.ndarray of float64 with shape ``y.shape``.

    Raises
    ------
    ValueError
        If ``reduced_matrix`` has neither the reduced nor the full shape.
    """
    y = np.asarray(y)
    reduced_matrix = np.asarray(reduced_matrix)
    n_cells, n_genes = y.shape

    if reduced_matrix.shape == (n_cells, n_genes):
        # Nothing was filtered out: the imputed matrix is already full width.
        return reduced_matrix.astype(float)

    kept = (y != 0).sum(axis=0) >= min_cells
    n_kept = int(kept.sum())
    if reduced_matrix.shape != (n_cells, n_kept):
        raise ValueError(
            f"MAGIC output has shape {reduced_matrix.shape}; expected ({n_cells}, {n_kept}) "
            f"(genes detected in >= {min_cells} cells) or the full shape ({n_cells}, {n_genes})"
        )

    # float64 copy of y: dropped genes keep their original values, kept genes
    # are overwritten with the imputed ones in a single vectorised assignment.
    expanded = y.astype(float)
    expanded[:, kept] = reduced_matrix
    return expanded

data_dir = Path("./datasets")
# Sorted for a reproducible processing order (raw glob order is filesystem-dependent).
files = sorted(data_dir.glob("*.h5ad"))

# All co-expression bar charts are written here.
output_dir = Path("output/barplots")
output_dir.mkdir(exist_ok=True, parents=True)

import re  # only used to sanitise figure file names

TOP_N = 10      # genes shown per bar chart
MIN_CELLS = 5   # minimum cells of a type, and minimum Y+ cells, to analyse


def _to_dense(X):
    """Return ``X`` as a dense ndarray, whether it is scipy-sparse or already dense."""
    return X.toarray() if hasattr(X, "toarray") else np.asarray(X)


def _safe_filename(text):
    """Replace characters that are not allowed in a file name (path separators etc.) with '_'."""
    return re.sub(r'[\\/:*?"<>|]', "_", str(text))


def _fraction(counts, total):
    """Element-wise ``counts / total``; all zeros when ``total`` is 0."""
    counts = np.asarray(counts, dtype=float)
    return counts / total if total > 0 else np.zeros_like(counts)


def _coexpression_fractions(X_before, X_after, y_idx, top_n=TOP_N):
    """Co-expression of every gene with the reference gene in column ``y_idx``.

    ``X_before`` and ``X_after`` hold the same cells (rows) and genes (columns)
    before and after imputation.

    Returns ``(top_indices, frac_before, frac_after, frac_zero2nonzero)``:
    ``top_indices`` are the ``top_n`` genes detected most often in cells with
    Y > 0 after imputation (descending; ties keep gene order). For those genes:
    - ``frac_before``: fraction of Y+ cells (before) also > 0 before imputation,
    - ``frac_after``: fraction of Y+ cells (after) also > 0 after imputation,
    - ``frac_zero2nonzero``: fraction of cells with Y == 0 before but Y > 0
      after in which the gene is > 0 after imputation.
    """
    y_before = X_before[:, y_idx]
    y_after = X_after[:, y_idx]

    ypos_after = y_after > 0
    frac_all = _fraction((X_after[ypos_after] > 0).sum(axis=0), int(ypos_after.sum()))
    # stable sort on negated fractions => descending with deterministic tie order
    top_indices = np.argsort(-frac_all, kind="stable")[:top_n]
    frac_after = frac_all[top_indices]

    ypos_before = y_before > 0
    frac_before = _fraction(
        (X_before[ypos_before][:, top_indices] > 0).sum(axis=0), int(ypos_before.sum())
    )

    zero2nonzero = (y_before == 0) & ypos_after
    frac_zero2nonzero = _fraction(
        (X_after[zero2nonzero][:, top_indices] > 0).sum(axis=0), int(zero2nonzero.sum())
    )
    return top_indices, frac_before, frac_after, frac_zero2nonzero


def _save_coexpression_barplot(gene_labels, frac_before, frac_after, frac_zero2nonzero, title, out_path):
    """Save a grouped bar chart (Before / After / Y=0→non0, one group per gene) to ``out_path``."""
    x = np.arange(len(gene_labels))
    width = 0.25
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(x - width, frac_before, width, label='Before')
    ax.bar(x, frac_after, width, label='After')
    ax.bar(x + width, frac_zero2nonzero, width, label='Y=0→non0')
    ax.set_xticks(x)
    ax.set_xticklabels(gene_labels, rotation=90)
    ax.set_ylabel("Fraction of Y+ cells also expressing gene")
    ax.set_title(title)
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=120)
    # One figure per (sample, cell type, gene, method): close it right away.
    plt.close(fig)


for f in files:
    print(f"Processing dataset {f.stem}")
    adata = read_h5ad(f)

    # Keep primary cells only. .copy() turns the view into a real AnnData so the
    # var_names assignment below does not implicitly copy a view.
    adata = adata[adata.obs["is_primary_data"] == True].copy()

    # Use gene symbols from "feature_name" (trailing ".<digits>" version suffix
    # stripped) as var_names so marker genes can be looked up by symbol.
    if "feature_name" in adata.var.columns:
        adata.var_names = adata.var["feature_name"].astype(str).str.replace(r"\.\d+$", "", regex=True)
        # Stripping suffixes can create duplicate symbols, for which get_loc()
        # returns a mask/slice instead of a single integer column.
        adata.var_names_make_unique()
    else:
        print("Warning: 'feature_name' not found in adata.var.columns. Gene symbol matching may fail.")
    var_names = adata.var_names

    # Sample-level splits by "disease" and "tissue"
    metadata_markers = adata.obs[["disease", "tissue"]].drop_duplicates()

    for disease, tissue in metadata_markers.itertuples(index=False):
        mask = (adata.obs["disease"] == disease) & (adata.obs["tissue"] == tissue)
        adata_sample = adata[mask]
        X_original = _to_dense(adata_sample.X)  # (n_cells_in_sample, n_genes), before imputation
        obs_celltype = adata_sample.obs["Celltype"].to_numpy()
        print(f"  Analyzing subset: {tissue} ({disease}), shape={X_original.shape}")

        # Cell types / reference genes to analyse in this sample; this does not
        # depend on the imputation method, so it is computed once.
        analyses = []  # (ctype_key, row mask, original sub-matrix, [(ref_gene, column)])
        for ctype_key, ref_genes in marker_genes.items():
            ct_mask = obs_celltype == ctype_key
            n_cells = int(ct_mask.sum())
            if n_cells < MIN_CELLS:
                print(f"    [Skipping: too few cells of type {ctype_key} ({n_cells} cells)]")
                continue
            print(f"    Celltype={ctype_key}, #cells={n_cells}")
            genes = []
            for ref_gene in ref_genes:
                if ref_gene not in var_names:
                    print(f"      [Skipping: {ref_gene} not found in cell type {ctype_key}]")
                    continue
                genes.append((ref_gene, var_names.get_loc(ref_gene)))
            if genes:
                analyses.append((ctype_key, ct_mask, X_original[ct_mask], genes))

        # The imputed matrix on disk covers the whole (disease, tissue) sample,
        # so load it once per method and slice the cell-type rows out of it.
        for method_name in methods:
            imputed_path = Path("output") / method_name / f"{f.stem}" / disease / (tissue + ".npy")
            if not imputed_path.exists():
                print(f"      [Skipping: no imputed data for {method_name}]")
                continue

            X_imputed = np.load(imputed_path)
            if method_name == "MAGIC":
                try:
                    X_imputed = expand_magic_matrix(X_original, X_imputed)
                except (ValueError, IndexError) as exc:
                    print(f"      [Warning: could not expand MAGIC output: {exc}]")
                    continue
            if X_imputed.shape != X_original.shape:
                print(f"      [Warning: shape mismatch for {method_name}]")
                continue

            for ctype_key, ct_mask, X_ct_original, genes in analyses:
                X_ct_imputed = X_imputed[ct_mask]
                for ref_gene, y_idx in genes:
                    # skip if few or none express Y after imputation
                    n_ypos = int((X_ct_imputed[:, y_idx] > 0).sum())
                    if n_ypos < MIN_CELLS:
                        print(f"      [Skipping: too few Y+ cells ({n_ypos}) for {ref_gene}]")
                        continue

                    top_idx, frac_before, frac_after, frac_z2nz = _coexpression_fractions(
                        X_ct_original, X_ct_imputed, y_idx
                    )
                    title = (
                        f"{f.stem}\n{disease}, {tissue}\nCelltype={ctype_key}, Ref={ref_gene}, Method={method_name}"
                    )
                    plot_out = output_dir / _safe_filename(
                        f"{f.stem}_{disease}_{tissue}_{ctype_key}_{ref_gene}_{method_name}.png"
                    )
                    print(f"      Saving plot to {plot_out}")
                    _save_coexpression_barplot(
                        var_names[top_idx], frac_before, frac_after, frac_z2nz, title, plot_out
                    )
            del X_imputed  # release before loading the next method's matrix

        print()  # blank line for readability


Processing dataset 63ff2c52-cb63-44f0-bac3-d0b33373e312
  Analyzing subset: lamina propria of mucosa of colon (Crohn disease), shape=(8076, 27289)
    Celltype=Enterocytes BEST4, #cells=130
    Celltype=Goblet cells MUC2 TFF1, #cells=185
      [Skipping: FCGBP not found in cell type Goblet cells MUC2 TFF1]
    Celltype=Tuft cells, #cells=76
    Celltype=Goblet cells SPINK4, #cells=29
    Celltype=Enterocytes TMIGD1 MEP1A, #cells=1503
    Celltype=Enterocytes CA1 CA2 CA4-, #cells=2555


KeyboardInterrupt: 

In [13]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from anndata import read_h5ad

# ----------------- MARKER GENES DICTIONARY -----------------
# Cell-type label -> reference marker gene symbols. Symbols are written as one
# whitespace-separated string per cell type and split into ordered lists.
marker_genes = dict(
    (cell_type, symbols.split())
    for cell_type, symbols in [
        ("Enterocytes BEST4", "BEST4 OTOP2 CA7 GUCA2A GUCA2B SPIB CFTR"),
        ("Goblet cells MUC2 TFF1", "MUC2 TFF1 TFF3 FCGBP AGR2 SPDEF"),
        ("Tuft cells", "POU2F3 DCLK1"),
        ("Goblet cells SPINK4", "MUC2 SPINK4"),
        ("Enterocytes TMIGD1 MEP1A", "CA1 CA2 TMIGD1 MEP1A"),
        ("Enterocytes CA1 CA2 CA4-", "CA1 CA2"),
        ("Goblet cells MUC2 TFF1-", "MUC2"),
        ("Epithelial Cycling cells", "LGR5 OLFM4 MKI67"),
        ("Enteroendocrine cells", "CHGA GCG GIP CCK"),
        ("Stem cells OLFM4", "OLFM4 LGR5"),
        ("Stem cells OLFM4 LGR5", "OLFM4 LGR5 ASCL2"),
        ("Stem cells OLFM4 PCNA", "OLFM4 PCNA LGR5 ASCL2 SOX9 TERT"),
        ("Paneth cells", "LYZ DEFA5"),
    ]
)

# Imputation methods to compare; each one's outputs live under output/<method>/...
methods = ["MAGIC", "SAUCIE"]


def expand_magic_matrix(original_matrix, magic_imputed, min_cells=5):
    """Re-insert the gene columns that were filtered out before running MAGIC.

    The MAGIC input is assumed to have kept only genes that are non-zero in at
    least ``min_cells`` cells of ``original_matrix``; ``magic_imputed`` holds the
    imputed values of those kept genes in their original column order. Genes
    that were dropped get their original (un-imputed) values back. If
    ``magic_imputed`` already has the full shape, nothing was filtered and it
    is returned as-is (as float).

    Parameters
    ----------
    original_matrix : array-like, shape (n_cells, n_genes)
        Dense expression matrix MAGIC was run on.
    magic_imputed : array-like, shape (n_cells, n_kept) or (n_cells, n_genes)
        MAGIC output.
    min_cells : int, default 5
        Detection threshold that was used to keep a gene for MAGIC.

    Returns
    -------
    np.ndarray of float64 with shape ``original_matrix.shape``.

    Raises
    ------
    ValueError
        If ``magic_imputed`` has neither the reduced nor the full shape.
    """
    original_matrix = np.asarray(original_matrix)
    magic_imputed = np.asarray(magic_imputed)
    n_cells, n_genes = original_matrix.shape

    if magic_imputed.shape == (n_cells, n_genes):
        # Nothing was filtered out: the imputed matrix is already full width.
        return magic_imputed.astype(float)

    kept = (original_matrix != 0).sum(axis=0) >= min_cells
    n_kept = int(kept.sum())
    if magic_imputed.shape != (n_cells, n_kept):
        raise ValueError(
            f"MAGIC output has shape {magic_imputed.shape}; expected ({n_cells}, {n_kept}) "
            f"(genes detected in >= {min_cells} cells) or the full shape ({n_cells}, {n_genes})"
        )

    # Always build a float result: with an integer count matrix, np.zeros_like
    # would have silently truncated the imputed values. Dropped genes keep their
    # original counts; kept genes are overwritten with the imputed values.
    expanded = original_matrix.astype(float)
    expanded[:, kept] = magic_imputed
    return expanded


# Build one long table with a row per (dataset, disease, tissue, cell_type, gene)
# and one column per imputation method. Each method column holds the fraction of
# that cell type's cells in which the gene is > 0 after imputation (NaN when the
# gene or the method's imputed matrix is unavailable).

MIN_CELLS = 5  # minimum cells per (disease, tissue) subset and per cell type

rows_list = []

data_dir = Path("./datasets")
# Sorted for a reproducible processing order (raw glob order is filesystem-dependent).
files = sorted(data_dir.glob("*.h5ad"))

# Every distinct marker symbol; imputed matrices are reduced to just these columns.
all_marker_genes = sorted({g for genes in marker_genes.values() for g in genes})

for f in files:
    dataset_id = f.stem
    print(f"Processing dataset {dataset_id}")
    adata = read_h5ad(f)

    # Filter out non-primary data if the annotation exists. .copy() materialises
    # the view so the var_names assignment below does not implicitly copy it.
    if "is_primary_data" in adata.obs.columns:
        adata = adata[adata.obs["is_primary_data"] == True].copy()

    # Use gene symbols (trailing ".<digits>" version suffix stripped) as
    # var_names, made unique so get_loc() always returns one integer column.
    if "feature_name" in adata.var.columns:
        adata.var_names = adata.var["feature_name"].astype(str).str.replace(r"\.\d+$", "", regex=True)
        adata.var_names_make_unique()
    else:
        print("Warning: 'feature_name' not found in adata.var.columns. Gene symbol matching may fail.")

    if "Celltype" not in adata.obs.columns:
        # Without cell-type labels no marker can be scored for this dataset.
        print(f"Warning: no 'Celltype' column in {dataset_id}; skipping.")
        continue

    var_names = adata.var_names
    marker_cols = {g: var_names.get_loc(g) for g in all_marker_genes if g in var_names}

    has_subsets = {"disease", "tissue"}.issubset(adata.obs.columns)
    if has_subsets:
        disease_tissue_df = adata.obs[["disease", "tissue"]].drop_duplicates()
    else:
        # No disease/tissue annotation: treat the whole dataset as one subset.
        disease_tissue_df = pd.DataFrame([("Unknown", "Unknown")], columns=["disease", "tissue"])

    for disease, tissue in disease_tissue_df.itertuples(index=False):
        if has_subsets:
            mask = ((adata.obs["disease"] == disease) & (adata.obs["tissue"] == tissue)).to_numpy()
        else:
            mask = np.ones(adata.n_obs, dtype=bool)
        adata_sample = adata[mask]
        if adata_sample.n_obs < MIN_CELLS:
            continue

        X = adata_sample.X
        X_original = X.toarray() if hasattr(X, "toarray") else np.asarray(X)  # (n_cells, n_genes)
        obs_celltype = adata_sample.obs["Celltype"].to_numpy()

        # The imputed matrices on disk cover the whole (disease, tissue) subset:
        # load each once, expand MAGIC against the full subset, and keep only the
        # marker columns. Maps method -> {gene: values over the subset's cells},
        # or None when that method's output is missing or unusable.
        imputed_marker_cols = {}
        for m in methods:
            imputed_marker_cols[m] = None
            imputed_path = Path("output") / m / dataset_id / disease / (tissue + ".npy")
            if not imputed_path.exists():
                continue

            X_imputed = np.load(imputed_path)
            if m == "MAGIC":
                try:
                    X_imputed = expand_magic_matrix(X_original, X_imputed)
                except (ValueError, IndexError) as exc:
                    print(f"[Warning: could not expand MAGIC output ({exc}), skipping.]")
                    continue
            if X_imputed.shape != X_original.shape:
                print(f"[Warning: shape mismatch for {m}, skipping.]")
                continue

            imputed_marker_cols[m] = {g: X_imputed[:, j].copy() for g, j in marker_cols.items()}
            del X_imputed  # free the full matrix before loading the next method

        for ctype_key, gene_list in marker_genes.items():
            ct_mask = obs_celltype == ctype_key
            if ct_mask.sum() < MIN_CELLS:
                continue

            for gene in gene_list:
                row_dict = {
                    "dataset": dataset_id,
                    "disease": disease,
                    "tissue": tissue,
                    "cell_type": ctype_key,
                    "gene": gene,
                }
                for m in methods:
                    cols = imputed_marker_cols[m]
                    if cols is None or gene not in cols:
                        # gene absent from the data, or no usable imputed matrix
                        row_dict[m] = np.nan
                    else:
                        # fraction of this cell type's cells expressing the gene after imputation
                        row_dict[m] = float(np.mean(cols[gene][ct_mask] > 0))
                rows_list.append(row_dict)

# ------------------- Build the final DataFrame ---------------------
# One row per record in `rows_list`: dataset, disease, tissue, cell_type, gene,
# followed by one expressed-fraction column per imputation method.
df_recovery = pd.DataFrame(data=rows_list)

# Persist the table next to the notebook.
df_recovery.to_csv(path_or_buf="marker_gene_recovery.csv", index=False)
print("Saved marker gene recovery fractions to 'marker_gene_recovery.csv'.")


# ------------------- OPTIONAL: Create Heatmaps ---------------------
# One heatmap per dataset: rows = (cell_type, gene), columns = imputation
# methods, values = fraction of cells expressing the gene after imputation,
# averaged (NaNs ignored) over all (disease, tissue) subsets of that dataset.
if df_recovery.empty:
    print("No marker gene recovery rows were collected; skipping heatmaps.")
else:
    for ds in df_recovery["dataset"].unique():
        df_ds = df_recovery[df_recovery["dataset"] == ds]

        # pivot_table groups by (cell_type, gene) and averages each method
        # column; rows with no value for any method are then dropped explicitly.
        df_pivot = (
            df_ds.pivot_table(index=["cell_type", "gene"], values=methods, aggfunc="mean")
            .dropna(how="all")
        )
        if df_pivot.empty:
            print(f"No recovery values available for {ds}; skipping heatmap.")
            continue

        fig, ax = plt.subplots(figsize=(10, 0.5 * len(df_pivot) + 3))
        sns.heatmap(
            df_pivot,
            annot=True,
            fmt=".2f",
            cmap="Blues",
            vmin=0,
            vmax=1,  # values are fractions: fixed scale keeps datasets comparable
            cbar_kws={"label": "Fraction Expressed (After Imputation)"},
            ax=ax,
        )
        ax.set_title(f"Marker Recovery: {ds} (avg across all disease/tissue)")
        fig.tight_layout()
        fig.savefig(f"marker_recovery_{ds}.png")
        plt.close(fig)

print("Done creating optional heatmaps. See marker_recovery_<dataset>.png files.")


Processing dataset 63ff2c52-cb63-44f0-bac3-d0b33373e312


: 

: 

: 