In [7]:
import os
import sys

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages

sys.path.append("/mnt/lareaulab/reliscu/code")

from parse_gtf import *

In [8]:
# At SC level:

## Are CT DE genes also DE at exon level? 
    # Plot distribution of counts at gene level and exon(s) level per cell type
    # Question being asked: are reads at the exon level capturing cell type-specific activity?

In [9]:
# At pseudobulk level:

## 1. Are CT DE genes correlated with most enriched module eigengene?  
    # Question being asked: is the module eigengene representing what it should?

## 2. Are CT DE gene exons correlated with most enriched module eigengene?
    # Question being asked: do cell type-specific exons track with cell type abundance?
    # Restrict to exons from genes that fulfill the first question, i.e. they track with ME.

## 3. Are CT DE gene exons correlated with their parent gene?
    # Question being asked: does exon-level expression track with the expression of its parent gene?
    # Restrict to exons from genes that fulfill the first question, i.e. they track with ME.
    # Also restrict to exons that are DE at the single-cell level; there's no reason they shouldn't be correlated with gene expression in that case. 
    #       Otherwise, if the gene is also expressed in other cell types, it could be specific to a different cell type, or just non-specific.

In [10]:
def annotate_df_exons(exon_df, mapping_df):
    """Prepend gene annotations from `mapping_df` to an exon-indexed table.

    The gene ID of each exon is the part of its index label before the first
    "_" (e.g. "GENEID_rest" -> "GENEID"). Rows are matched to `mapping_df` on
    that `gene_id`; exons with no match get NaN annotations.

    Args:
        exon_df: DataFrame indexed by exon ID. It is NOT modified (the previous
            version added 'gene_id'/'exon_id' columns to the caller's frame).
        mapping_df: DataFrame with a 'gene_id' column plus annotation columns
            (e.g. 'gene_name'). If a gene_id occurs more than once, only its
            first row is used, so exon rows are never duplicated.

    Returns:
        DataFrame indexed by exon ID (same rows, same order as `exon_df`) whose
        columns are `mapping_df`'s non-key columns followed by `exon_df`'s columns.
    """
    # Work on a copy with the join key and the exon ID as columns.
    keyed = exon_df.assign(
        gene_id=exon_df.index.str.split("_").str[0],
        exon_id=exon_df.index.values,
    )
    # A duplicated gene_id in the mapping would otherwise multiply exon rows
    # and leave a non-unique exon index.
    gene_map = mapping_df.drop_duplicates(subset="gene_id")
    # how="right" keeps every exon, in exon_df's order.
    annotated = pd.merge(gene_map, keyed, on="gene_id", how="right")
    return annotated.set_index("exon_id").rename_axis(None).drop(columns=["gene_id"])

def normalize_counts(counts_df, quantile_cutoff=None, scale=1e4):
    """Scale each column (sample/cell) to a fixed total, after optional row filtering.

    Rows are features and columns are samples/cells. When `quantile_cutoff`
    is given, only rows whose NaN-aware mean is at or above that quantile of
    all row means are kept. Column totals are then computed over the kept
    rows only.

    Args:
        counts_df: features x samples DataFrame of counts (dense or sparse).
        quantile_cutoff: quantile in [0, 1] of the row means used as a minimum
            threshold, or None to keep every row.
        scale: target total per column after normalization (default 1e4).

    Returns:
        DataFrame with the kept rows and the original columns, holding
        `count / column_total * scale`. Columns whose total is 0 come out as
        all zeros rather than NaN.
    """
    arr = counts_df.to_numpy()
    index = counts_df.index

    if quantile_cutoff is not None and arr.shape[0] > 0:
        row_mean = np.nanmean(arr, axis=1)
        # nanquantile: a single all-NaN row must not turn the threshold into
        # NaN and silently drop every row.
        keep = row_mean >= np.nanquantile(row_mean, quantile_cutoff)
        arr = arr[keep, :]
        index = index[keep]

    # NaN-aware, consistent with the nanmean used for filtering.
    totals = np.nansum(arr, axis=0)
    # A zero total would give 0/0 = NaN; dividing by 1 leaves those columns at 0.
    totals = np.where(totals != 0, totals, 1)
    return pd.DataFrame(arr / totals * scale, index=index, columns=counts_df.columns)

In [11]:
# Target cell subclass whose DE genes are examined, and the flag that is
# embedded in the pseudobulk output directory name ("unique{unique}").
w_ctype, unique = "L5_IT", True

In [12]:
# Parse the GENCODE GTF, keep only "gene" records, and expand their attribute
# column into one column per attribute. Version suffixes (".N") are stripped
# from gene_id — presumably to match the gene-ID prefix of the exon names used
# by annotate_df_exons (verify).
gtf_file = "/mnt/lareaulab/reliscu/data/GENCODE/GRCm39/gencode.vM35.annotation.gtf"
gtf = gtf_parse(gtf_file)

is_gene_record = gtf['feature'].isin(["gene"])
gtf_subset = gtf.loc[is_gene_record]

attrs = gtf_subset['attribute'].apply(extract_attributes)
attrs_df = attrs.apply(pd.Series)

gtf_parsed = pd.concat(
    [gtf_subset.drop(columns=["attribute"]), attrs_df],
    axis=1,
)
gtf_parsed['gene_id'] = gtf_parsed['gene_id'].str.split(".", n=1).str[0]

## Prep single-cell data

In [13]:
# Single-cell inputs: `sdata` provides exon PSI (X) and exon counts
# (layers['exon_counts']); `adata` provides gene counts in .raw.
sj_h5ad_path = "data/tasic_2018_ALM_STAR_SJ_counts_annotated_PSI.hd5"
gene_h5ad_path = "data/tasic_2018_ALM_STAR_model/tasic_2018_ALM_STAR_gene_counts_scVI.h5ad"

sdata = ad.read_h5ad(sj_h5ad_path)
adata = ad.read_h5ad(gene_h5ad_path)

In [14]:
# Features x cells tables. AnnData stores cells x features, hence the .T:
# after transposing, the index (rows) holds features and the columns hold cells.
gene_counts_df = pd.DataFrame.sparse.from_spmatrix(
    adata.raw.X.T,
    index=adata.raw.var_names,   # one per row (genes)
    columns=adata.obs_names,     # one per column (cells)
)


def _as_dense(mat):
    """Return `mat` as a dense ndarray, densifying scipy sparse matrices."""
    return mat.toarray() if hasattr(mat, "toarray") else np.asarray(mat)


# pd.DataFrame cannot wrap a scipy sparse matrix directly, so densify first.
exon_PSI_df = pd.DataFrame(
    _as_dense(sdata.X).T,
    index=sdata.var_names,       # exons
    columns=sdata.obs_names,     # cells
)
exon_counts_df = pd.DataFrame(
    _as_dense(sdata.layers['exon_counts']).T,
    index=sdata.var_names,       # exons
    columns=sdata.obs_names,     # cells
)

In [15]:
# Gene ID -> gene symbol lookup used to annotate the exon tables.
mapping_df = gtf_parsed.loc[:, ['gene_id', 'gene_name']]

In [16]:
# Drop features whose mean falls below these quantiles of all feature means,
# then library-size normalize each cell (see normalize_counts).
EXON_QUANTILE_CUTOFF = 0.2
GENE_QUANTILE_CUTOFF = 0.5

exon_expr_df = normalize_counts(exon_counts_df, quantile_cutoff=EXON_QUANTILE_CUTOFF)
gene_expr_df = normalize_counts(gene_counts_df, quantile_cutoff=GENE_QUANTILE_CUTOFF)

In [17]:
# Attach gene symbols to each single-cell exon table (rows keyed by exon ID).
exon_counts_anno_df, exon_expr_anno_df, exon_PSI_anno_df = (
    annotate_df_exons(frame, mapping_df)
    for frame in (exon_counts_df, exon_expr_df, exon_PSI_df)
)

In [18]:
# Exon -> gene symbol table: the first annotated column, renamed 'gene_name'.
gene_exon_df = pd.DataFrame({'gene_name': exon_counts_anno_df.iloc[:, 0]})

## Single-cell

In [19]:
# Single-cell setup: DE genes of the target subclass (column 'Gene'), every
# cell subclass present in sdata, and the output directory for SC figures.
DE_genes = pd.read_csv("pairwise_DE_genes/{}.csv".format(w_ctype))
ctypes = np.unique(sdata.obs['cell_subclass'])

outdir = "diagnostics/{}/SC".format(w_ctype)
os.makedirs(outdir, exist_ok=True)

#### Q: Are CT DE genes also DE at exon level? 

In [20]:
# Keep the exons whose parent gene is one of the cell type's DE genes.
de_gene_list = DE_genes['Gene'].tolist()
exon_expr_DE_df = exon_expr_anno_df[exon_expr_anno_df['gene_name'].isin(de_gene_list)]

In [21]:
# One PDF per DE gene. Page 1 shows the gene-level distribution per cell
# subclass; each following page shows one of the gene's exons.
genes = exon_expr_DE_df['gene_name'].unique()

# The cell barcodes of each subclass do not depend on the gene: look them up once.
cells_by_ctype = {
    ctype: sdata.obs.index[sdata.obs['cell_subclass'] == ctype] for ctype in ctypes
}


def _save_boxplot_page(pdf, data, labels, ylabel, title):
    """Draw one box per entry of `data` (labelled by `labels`) and append the figure to `pdf`."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.boxplot(data, showfliers=False)
    # Label ticks explicitly instead of boxplot(labels=...), which Matplotlib
    # 3.9 renamed to tick_labels; this works on every version.
    ax.set_xticks(range(1, len(labels) + 1))
    ax.set_xticklabels(labels, rotation=90)
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Cell subclass")
    ax.set_title(title)
    fig.tight_layout()
    pdf.savefig(fig)
    plt.close(fig)


for gene in genes:
    if gene not in gene_expr_df.index:
        # gene_expr_df was quantile-filtered by normalize_counts, so a DE gene
        # can be missing from it; skip instead of raising KeyError mid-loop.
        print("Skipped (not in gene_expr_df):", gene)
        continue

    # exons belonging to this gene
    w_exons = exon_expr_DE_df.index[exon_expr_DE_df['gene_name'] == gene]
    pdf_path = f"{outdir}/{gene}_gene_and_exon_distribution_by_cell_type.pdf"

    with PdfPages(pdf_path) as pdf:
        gene_data = [np.ravel(gene_expr_df.loc[gene, cells_by_ctype[c]]) for c in ctypes]
        _save_boxplot_page(pdf, gene_data, ctypes,
                           "Gene-level normalized counts (per 10k)",
                           f"{gene}: gene-level counts by cell type")

        for exon in w_exons:
            exon_data = [exon_expr_anno_df.loc[exon, cells_by_ctype[c]].to_numpy() for c in ctypes]
            _save_boxplot_page(pdf, exon_data, ctypes,
                               f"Normalized counts for exon {exon} (per 10k)",
                               f"{gene}: exon {exon} counts by cell type")

    print("Saved:", pdf_path)

Saved: diagnostics/L5_IT/SC/Msc_gene_and_exon_distribution_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/2810408I11Rik_gene_and_exon_distribution_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Lemd1_gene_and_exon_distribution_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Cnih3_gene_and_exon_distribution_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/4930447M23Rik_gene_and_exon_distribution_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/2600014E21Rik_gene_and_exon_distribution_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Lrrc4c_gene_and_exon_distribution_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Postn_gene_and_exon_distribution_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Agbl4_gene_and_exon_distribution_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Lat2_gene_and_exon_distribution_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Stx1a_gene_and_exon_distribution_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Dkkl1_gene_and_exon_distribution_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Mfsd13b_gene_a

#### A: Yes

#### Q: Are cell types correlated with DE genes at gene and exon-level?

In [None]:
# Recompute the DE-gene exon subset (same definition as in the previous
# section) so this section can be run on its own.
exon_expr_DE_df = exon_expr_anno_df.loc[
    lambda frame: frame['gene_name'].isin(DE_genes['Gene'].tolist())
]

In [None]:
# For each DE gene, correlate gene- and exon-level expression with a 0/1
# indicator of membership in each cell subclass, across all cells of sdata.
cell_order = sdata.obs_names
# The indicators depend only on the cell type, not the gene: build them once.
indicator_by_ctype = {
    ctype: (sdata.obs['cell_subclass'] == ctype).astype(int).loc[cell_order]
    for ctype in ctypes
}

for gene in exon_expr_DE_df['gene_name'].unique():
    if gene not in gene_expr_df.index:
        # Removed by the quantile filter in normalize_counts; skip instead of KeyError.
        print("Skipped (not in gene_expr_df):", gene)
        continue

    w_exons = exon_expr_DE_df.loc[exon_expr_DE_df['gene_name'] == gene].index
    exon_expr_mat = exon_expr_anno_df.loc[w_exons, cell_order]
    # The gene vector does not depend on the cell type: fetch it once per gene.
    gene_expr_vec = gene_expr_df.loc[gene, cell_order].to_numpy()

    # gene-level correlation: one value per cell type
    gene_corr_series = pd.Series(
        [np.corrcoef(indicator_by_ctype[c].to_numpy(), gene_expr_vec)[0, 1] for c in ctypes],
        index=ctypes,
    )
    # exon-level correlations: one row per cell type, one column per exon
    exon_corr_by_ctype_df = pd.DataFrame(
        [exon_expr_mat.T.corrwith(indicator_by_ctype[c], axis=0) for c in ctypes],
        index=ctypes,
        columns=w_exons,
    )

    # Page 1 is the gene-level barplot; one page per exon follows.
    pages = [(gene_corr_series.values,
              "Correlation (gene vs cell-type indicator)",
              f"{gene}: gene-level correlation by cell type")]
    pages += [(exon_corr_by_ctype_df[exon].values,
               f"Correlation (exon {exon} vs cell-type indicator)",
               f"{gene}: exon {exon} correlation by cell type")
              for exon in w_exons]

    pdf_path = f"{outdir}/{gene}_corr_by_cell_type.pdf"
    with PdfPages(pdf_path) as pdf:
        for values, ylabel, title in pages:
            fig, ax = plt.subplots(figsize=(10, 4))
            ax.bar(range(len(ctypes)), values)
            ax.set_xticks(range(len(ctypes)))
            ax.set_xticklabels(ctypes, rotation=90)
            ax.set_ylabel(ylabel)
            ax.set_xlabel("Cell subclass")
            ax.set_title(title)
            fig.tight_layout()
            pdf.savefig(fig)
            plt.close(fig)

    print("Saved:", pdf_path)

Saved: diagnostics/L5_IT/SC/Msc_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/2810408I11Rik_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Lemd1_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Cnih3_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/4930447M23Rik_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/2600014E21Rik_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Lrrc4c_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Postn_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Agbl4_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Lat2_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Stx1a_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Dkkl1_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Mfsd13b_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Cdh8_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Cibar2_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Vill_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/SC/Tpd52l1_corr_by_cell_type.pdf
Saved: diagnostics/L5_I

#### A: Yes

## Prep pseudobulk data

In [22]:
# pseudobulk_str = "20pcntCells_30pcntVar_200samples" 
# pseudobulk_data = f"SyntheticDataset1_{pseudobulk_str}"
# psi_data = f"{pseudobulk_data}_SJ_pseudobulk_min_observed0.05_minPsi0.05"
# merge_param = "0.96"

# bulk_exon_PSI_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_{psi_data}_PSI.csv", index_col=0)
# bulk_exon_counts_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_{psi_data}_counts.csv", index_col=0)

# bulk_gene_counts_df = pd.read_csv(f"data/SyntheticDatasets/{pseudobulk_data}_10-04-30.csv", index_col=0)
# top_qval_mods_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_donor_cell_type_pseudobulk_pairwise_DE_genes_dream_{pseudobulk_str}_log2_pseudobulk_mergeParam0.96_PosBC_top_Qval_modules.csv")

# # exon_corr_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_{psi_data}_PSI_exon_corr.csv", index_col=0)

In [23]:
# pseudobulk_str = "20pcntCells_35SD_200samples" 
# pseudobulk_data = f"SyntheticDataset1_{pseudobulk_str}"
# psi_data = f"{pseudobulk_data}_SJ_pseudobulk_min_observed0.05_minPsi0.05"
# merge_param = "0.9"

# bulk_exon_PSI_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_{psi_data}_PSI.csv", index_col=0)
# bulk_exon_counts_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_{psi_data}_counts.csv", index_col=0)

# bulk_gene_counts_df = pd.read_csv(f"data/SyntheticDatasets/{pseudobulk_data}_12-54-23.csv", index_col=0)
# top_qval_mods_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_donor_cell_type_pseudobulk_pairwise_DE_genes_dream_{pseudobulk_str}_log2_pseudobulk_PosBC_top_Qval_modules.csv")

# # exon_corr_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_{psi_data}_PSI_exon_corr.csv", index_col=0)

In [24]:
# pseudobulk_str = "25pcntCells_50SD_200samples" 
# pseudobulk_data = f"SyntheticDataset1_{pseudobulk_str}"
# psi_data = f"{pseudobulk_data}_SJ_pseudobulk_min_observed0.05"
# merge_param = "0.9"

# bulk_exon_PSI_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_{psi_data}_PSI.csv", index_col=0)
# bulk_exon_counts_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_{psi_data}_counts.csv", index_col=0)

# bulk_gene_counts_df = pd.read_csv(f"data/SyntheticDatasets/{pseudobulk_data}_02-50-14.csv", index_col=0)
# top_qval_mods_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_donor_cell_type_pseudobulk_pairwise_DE_genes_dream_{pseudobulk_str}_log2_pseudobulk_PosBC_top_Qval_modules.csv")

# # exon_corr_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_{psi_data}_PSI_exon_corr.csv", index_col=0)

In [29]:
# Active pseudobulk configuration: dataset identifiers, then the exon PSI /
# counts, gene counts and per-cell-type top module tables for that dataset.
pseudobulk_str = "25pcntCells_100SD_200samples"
pseudobulk_data = f"SyntheticDataset1_{pseudobulk_str}"
psi_data = f"{pseudobulk_data}_SJ_pseudobulk_min_observed0.05"
merge_param = "0.9"

exon_prefix = f"data/tasic_2018_ALM_STAR_{psi_data}"
bulk_exon_PSI_df = pd.read_csv(f"{exon_prefix}_PSI.csv", index_col=0)
bulk_exon_counts_df = pd.read_csv(f"{exon_prefix}_counts.csv", index_col=0)

bulk_gene_counts_df = pd.read_csv(f"data/SyntheticDatasets/{pseudobulk_data}_11-18-00.csv", index_col=0)
modules_csv = (
    "data/tasic_2018_ALM_STAR_donor_cell_type_pseudobulk_pairwise_DE_genes_dream_"
    f"{pseudobulk_str}_log2_pseudobulk_PosBC_top_Qval_modules.csv"
)
top_qval_mods_df = pd.read_csv(modules_csv)

# exon_corr_df = pd.read_csv(f"data/tasic_2018_ALM_STAR_{psi_data}_PSI_exon_corr.csv", index_col=0)

In [30]:
# Same filtering + library-size normalization as the single-cell data, with
# pseudobulk-specific quantile cutoffs.
BULK_GENE_QUANTILE_CUTOFF = 0.4
BULK_EXON_QUANTILE_CUTOFF = 0.2

bulk_gene_expr_df = normalize_counts(bulk_gene_counts_df, quantile_cutoff=BULK_GENE_QUANTILE_CUTOFF)
bulk_exon_expr_df = normalize_counts(bulk_exon_counts_df, quantile_cutoff=BULK_EXON_QUANTILE_CUTOFF)

In [31]:
# Attach gene symbols to each pseudobulk exon table (rows keyed by exon ID).
bulk_exon_PSI_anno_df, bulk_exon_counts_anno_df, bulk_exon_expr_anno_df = (
    annotate_df_exons(frame, mapping_df)
    for frame in (bulk_exon_PSI_df, bulk_exon_counts_df, bulk_exon_expr_df)
)

## Pseudobulk

In [32]:
# Pseudobulk setup: cell types that have a top module, the target subclass's
# DE genes, and an output directory keyed by dataset and merge parameter.
ctypes = np.unique(top_qval_mods_df['Cell_type'])

DE_genes = pd.read_csv("pairwise_DE_genes/{}.csv".format(w_ctype))
outdir = "diagnostics/{}/{}/{}/unique{}".format(w_ctype, pseudobulk_str, merge_param, unique)
os.makedirs(outdir, exist_ok=True)

### Q: Are CT DE genes correlated with most enriched module eigengene?

In [33]:
# Restrict the bulk exon table to genes that are DE for w_ctype AND have
# detected exons in the single-cell exon table (exon_expr_DE_df).
sc_DE_gene_list = exon_expr_DE_df['gene_name'].tolist()
bulk_exon_expr_DE_df = bulk_exon_expr_anno_df[bulk_exon_expr_anno_df['gene_name'].isin(sc_DE_gene_list)]

In [34]:
# Load each cell type's top module eigengene once. It does not depend on the
# gene, so re-reading the ME CSV for every gene x cell type was repeated I/O.
me_by_ctype = {}
for ctype in ctypes:
    # row for this cell type (adjust if the key column is named differently)
    row = top_qval_mods_df.loc[top_qval_mods_df['Cell_type'] == ctype].iloc[0]
    mod_eig = pd.read_csv(row['ME_path']).set_index("Sample")[row['Module']]
    # Keep samples present in the ME AND in both expression matrices, so the
    # .loc lookups below cannot raise KeyError.
    common_samples = (bulk_gene_expr_df.columns
                      .intersection(bulk_exon_expr_anno_df.columns)
                      .intersection(mod_eig.index))
    me_by_ctype[ctype] = mod_eig.loc[common_samples]

for gene in bulk_exon_expr_DE_df['gene_name'].unique():
    w_exons = bulk_exon_expr_DE_df.loc[bulk_exon_expr_DE_df['gene_name'] == gene].index
    # bulk_gene_expr_df was quantile-filtered; a missing gene gets NaN
    # gene-level correlations instead of a KeyError.
    has_gene = gene in bulk_gene_expr_df.index

    exon_ME_corr_by_ctype = []   # list of Series (one per cell type)
    gene_ME_corr_by_ctype = []   # list of floats
    for ctype in ctypes:
        me = me_by_ctype[ctype]
        # exon_expr.T: samples x exons; corrwith gives one correlation per exon
        exon_expr = bulk_exon_expr_anno_df.loc[w_exons, me.index]
        exon_ME_corr_by_ctype.append(exon_expr.T.corrwith(me, axis=0))
        gene_ME_corr_by_ctype.append(
            bulk_gene_expr_df.loc[gene, me.index].corr(me) if has_gene else np.nan
        )

    gene_corr_series = pd.Series(gene_ME_corr_by_ctype, index=ctypes)
    exon_corr_by_ctype_df = pd.DataFrame(exon_ME_corr_by_ctype, index=ctypes, columns=w_exons)

    # Page 1 is the gene-level barplot; one page per exon follows.
    pages = [(gene_corr_series.values,
              "Correlation (gene expression vs ME)",
              f"{gene}: gene–ME correlation by cell type")]
    pages += [(exon_corr_by_ctype_df[exon].values,
               f"Correlation (exon {exon} expression vs ME)",
               f"{gene}: exon {exon} ME correlation by cell type")
              for exon in w_exons]

    pdf_path = f"{outdir}/{gene}_bulk_ME_corr_by_cell_type.pdf"
    with PdfPages(pdf_path) as pdf:
        for values, ylabel, title in pages:
            fig, ax = plt.subplots(figsize=(10, 4))
            ax.bar(range(len(ctypes)), values)
            ax.set_xticks(range(len(ctypes)))
            ax.set_xticklabels(ctypes, rotation=90)
            ax.set_ylabel(ylabel)
            ax.set_xlabel("Cell subclass")
            ax.set_title(title)
            fig.tight_layout()
            pdf.savefig(fig)
            plt.close(fig)

    print("Saved:", pdf_path)


Saved: diagnostics/L5_IT/25pcntCells_100SD_200samples/0.9/uniqueTrue/Msc_bulk_ME_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/25pcntCells_100SD_200samples/0.9/uniqueTrue/2810408I11Rik_bulk_ME_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/25pcntCells_100SD_200samples/0.9/uniqueTrue/Lemd1_bulk_ME_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/25pcntCells_100SD_200samples/0.9/uniqueTrue/Cnih3_bulk_ME_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/25pcntCells_100SD_200samples/0.9/uniqueTrue/4930447M23Rik_bulk_ME_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/25pcntCells_100SD_200samples/0.9/uniqueTrue/2600014E21Rik_bulk_ME_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/25pcntCells_100SD_200samples/0.9/uniqueTrue/Lrrc4c_bulk_ME_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/25pcntCells_100SD_200samples/0.9/uniqueTrue/Postn_bulk_ME_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/25pcntCells_100SD_200samples/0.9/uniqueTrue/Agbl4_bulk_ME_corr_by_cell_type.pdf
Saved: diagnostics/L5_IT/25pcntC

### A: Mostly yes

## Pseudobulk vs. single-cell metrics

#### Which pseudobulk datasets look most similar to single-cell?

In [35]:
# Module-eigengene cache keyed by (CSV path, module column). get_bulk_stats is
# called once per gene, so without it every ME CSV is re-read for every gene.
# Re-run this cell to clear it if the ME files change on disk.
_ME_CACHE = {}


def _load_module_eigengene(me_path, module):
    """Return column `module` of the ME CSV at `me_path`, indexed by its 'Sample' column (cached)."""
    key = (me_path, module)
    if key not in _ME_CACHE:
        _ME_CACHE[key] = pd.read_csv(me_path).set_index("Sample")[module]
    return _ME_CACHE[key]


def get_bulk_stats(ctypes, w_exons, top_qval_mods_df, gene, bulk_gene_df=None, bulk_exon_df=None):
    """For one gene, correlate pseudobulk expression with each cell type's top module eigengene.

    For every cell type, the first matching row of `top_qval_mods_df`
    (column 'Cell_type') provides 'ME_path' (a CSV with a 'Sample' column) and
    'Module' (the eigengene column to use). Correlations are Pearson, over
    samples present in the ME and in both expression matrices.

    Args:
        ctypes: cell-type labels to evaluate (one output row each).
        w_exons: exon IDs (rows of the exon table) belonging to `gene`.
        top_qval_mods_df: table described above.
        gene: row label in the gene table.
        bulk_gene_df: genes x samples expression; defaults to the notebook
            global `bulk_gene_expr_df`.
        bulk_exon_df: exons x samples expression; defaults to the notebook
            global `bulk_exon_expr_anno_df`.

    Returns:
        gene_ME_series: pd.Series (index=ctypes); NaN where `gene` is absent
            from the gene table (e.g. removed by quantile filtering).
        exon_ME_df:     pd.DataFrame (index=ctypes, columns=w_exons)
    """
    if bulk_gene_df is None:
        bulk_gene_df = bulk_gene_expr_df
    if bulk_exon_df is None:
        bulk_exon_df = bulk_exon_expr_anno_df
    has_gene = gene in bulk_gene_df.index

    gene_ME_corr_by_ctype = []
    exon_ME_corr_by_ctype = []

    for ctype in ctypes:
        row = top_qval_mods_df.loc[top_qval_mods_df['Cell_type'] == ctype].iloc[0]
        mod_eig = _load_module_eigengene(row['ME_path'], row['Module'])

        # Samples in the ME and in BOTH matrices, so .loc cannot raise KeyError.
        common_samples = (bulk_gene_df.columns
                          .intersection(bulk_exon_df.columns)
                          .intersection(mod_eig.index))
        me = mod_eig.loc[common_samples]

        # exon-level ME correlations (one value per exon)
        exon_expr = bulk_exon_df.loc[w_exons, common_samples]
        exon_ME_corr_by_ctype.append(exon_expr.T.corrwith(me, axis=0))

        # gene-level ME correlation
        gene_ME_corr_by_ctype.append(
            bulk_gene_df.loc[gene, common_samples].corr(me) if has_gene else np.nan
        )

    gene_ME_series = pd.Series(gene_ME_corr_by_ctype, index=ctypes)
    exon_ME_df = pd.DataFrame(exon_ME_corr_by_ctype, index=ctypes, columns=w_exons)
    return gene_ME_series, exon_ME_df


def get_SC_stats(ctypes, w_exons, sdata, gene, gene_df=None, exon_df=None):
    """For one gene, correlate single-cell expression with cell-type membership.

    For every cell type, builds a 0/1 indicator over `sdata.obs_names`
    (1 where `sdata.obs['cell_subclass'] == ctype`). It then computes the
    Pearson correlation of that indicator with the gene's expression and with
    each exon's expression.

    Args:
        ctypes: cell-type labels to evaluate (one output row each).
        w_exons: exon IDs (rows of the exon table) belonging to `gene`.
        sdata: object exposing `obs` (with a 'cell_subclass' column) and `obs_names`.
        gene: row label in the gene table.
        gene_df: genes x cells expression; defaults to the notebook global `gene_expr_df`.
        exon_df: exons x cells expression; defaults to the notebook global `exon_expr_anno_df`.

    Returns:
        gene_SC_series: pd.Series (index=ctypes); NaN where `gene` is absent
            from the gene table (e.g. removed by quantile filtering).
        exon_SC_df:     pd.DataFrame (index=ctypes, columns=w_exons)
    """
    if gene_df is None:
        gene_df = gene_expr_df
    if exon_df is None:
        exon_df = exon_expr_anno_df

    cells = sdata.obs_names
    exon_expr_mat = exon_df.loc[w_exons, cells]
    # The gene vector does not depend on the cell type: fetch it once.
    gene_expr_vec = gene_df.loc[gene, cells].to_numpy() if gene in gene_df.index else None

    gene_corr_by_ctype = []
    exon_corr_by_ctype = []

    for ctype in ctypes:
        indicator_vec = (sdata.obs['cell_subclass'] == ctype).astype(int).loc[cells]

        # gene-level SC correlation
        if gene_expr_vec is None:
            gene_corr_by_ctype.append(np.nan)
        else:
            gene_corr_by_ctype.append(
                np.corrcoef(indicator_vec.to_numpy(), gene_expr_vec)[0, 1]
            )

        # exon-level SC correlations vs indicator (one value per exon)
        exon_corr_by_ctype.append(exon_expr_mat.T.corrwith(indicator_vec, axis=0))

    gene_SC_series = pd.Series(gene_corr_by_ctype, index=ctypes)
    exon_SC_df = pd.DataFrame(exon_corr_by_ctype, index=ctypes, columns=w_exons)
    return gene_SC_series, exon_SC_df


In [36]:
# DE-gene exon subsets at both levels. The bulk subset is filtered on the genes
# of the *current* exon_expr_DE_df, which the following statement recomputes.
sc_DE_gene_names = exon_expr_DE_df['gene_name'].tolist()
bulk_exon_expr_DE_df = bulk_exon_expr_anno_df.loc[bulk_exon_expr_anno_df['gene_name'].isin(sc_DE_gene_names)]
exon_expr_DE_df = exon_expr_anno_df.loc[exon_expr_anno_df['gene_name'].isin(DE_genes['Gene'].tolist())]

# Genes and exon IDs present in both the pseudobulk and single-cell subsets.
bulk_gene_names = bulk_exon_expr_DE_df['gene_name'].unique()
sc_gene_names = exon_expr_DE_df['gene_name'].unique()
shared_genes = np.intersect1d(bulk_gene_names, sc_gene_names)
shared_exons = bulk_exon_expr_DE_df.index.intersection(exon_expr_DE_df.index)

In [37]:
# Harmonize subclass labels: every run of "/", "-" or whitespace becomes "_"
# (presumably to match top_qval_mods_df['Cell_type'] — verify). This
# overwrites sdata.obs['cell_subclass'] in place.
clean_subclass = sdata.obs['cell_subclass'].astype(str).str.replace(r'[\/\-\s]+', '_', regex=True)
sdata.obs['cell_subclass'] = clean_subclass

# Cell types present in both the single-cell labels and the pseudobulk modules.
ctypes_SC = np.unique(sdata.obs['cell_subclass'])
ctypes_bulk = np.unique(top_qval_mods_df['Cell_type'])
ctypes = np.intersect1d(ctypes_SC, ctypes_bulk)

In [38]:
# Output directory for the pseudobulk vs single-cell comparison.
outdir = "diagnostics/{}/{}/{}/unique{}".format(w_ctype, pseudobulk_str, merge_param, unique)

In [39]:
pdf_path = f"{outdir}/bulk_vs_SC_corrs.pdf"
os.makedirs(os.path.dirname(pdf_path), exist_ok=True)


def _finite_corr(a, b):
    """Pearson r of `a` vs `b` over positions where both are finite.

    Returns NaN if fewer than 2 such positions remain or either side has zero
    variance. Dropping non-finite pairs keeps one undefined per-cell-type (or
    per-exon) correlation from turning a whole gene's agreement into NaN.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    keep = np.isfinite(a) & np.isfinite(b)
    if keep.sum() < 2:
        return np.nan
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.corrcoef(a[keep], b[keep])[0, 1]


plotted_genes = []    # genes actually scored, in plotting order
gene_agreement = []
exon_agreement = []

for gene in shared_genes:
    exons_this_gene = bulk_exon_expr_DE_df.index[bulk_exon_expr_DE_df['gene_name'] == gene]
    w_exons = shared_exons.intersection(exons_this_gene)
    if len(w_exons) == 0:
        # skip genes with no shared exons
        continue

    gene_SC_series, exon_SC_df = get_SC_stats(ctypes, w_exons, sdata, gene)
    gene_bulk_series, exon_bulk_df = get_bulk_stats(ctypes, w_exons, top_qval_mods_df, gene)

    # align on ctypes / exons just in case
    common_ctypes = gene_SC_series.index.intersection(gene_bulk_series.index)
    common_exons = exon_SC_df.columns.intersection(exon_bulk_df.columns)

    plotted_genes.append(gene)
    # SC pattern vs bulk pattern across cell types for this gene
    gene_agreement.append(_finite_corr(gene_SC_series.loc[common_ctypes].values,
                                       gene_bulk_series.loc[common_ctypes].values))
    # exon-level agreement: flatten the (ctypes x exons) matrices
    exon_agreement.append(_finite_corr(
        exon_SC_df.loc[common_ctypes, common_exons].values.ravel(),
        exon_bulk_df.loc[common_ctypes, common_exons].values.ravel(),
    ))

pages = [
    (gene_agreement, "Corr(SC gene-corrs, bulk gene-ME-corrs)",
     "Gene-level SC vs bulk ME correlation pattern by gene"),
    (exon_agreement, "Corr(SC exon-corrs, bulk exon-ME-corrs)",
     "Exon-level SC vs bulk ME correlation pattern by gene"),
]

with PdfPages(pdf_path) as pdf:
    for agreement, ylabel, title in pages:
        values = np.asarray(agreement, dtype=float)
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.bar(range(len(values)), values)
        ax.set_xticks(range(len(values)))
        # Label with the genes actually scored: slicing shared_genes would
        # shift every label after a skipped gene.
        ax.set_xticklabels(plotted_genes, rotation=90)
        ax.set_ylabel(ylabel)
        ax.set_xlabel("Gene")
        ax.set_title(title)
        # horizontal line at mean correlation (only if any value is finite)
        finite = np.isfinite(values)
        if finite.any():
            ax.axhline(values[finite].mean(), color="red", linestyle="--", linewidth=1)
        fig.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)

print("Saved:", pdf_path)

Saved: diagnostics/L5_IT/25pcntCells_100SD_200samples/0.9/uniqueTrue/bulk_vs_SC_corrs.pdf
