In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputRegressor
import pickle
import pandas as pd
import scanpy as sc

In [2]:
# Print the basal-media composition (cell counts per medium) of every sample's metadata table.
for sample in ['iGABA_post', 'iGABA_pre', 'iGlut_post', 'iGlut_pre']:
    meta = pd.read_csv("scanpy/" + sample + "_dr_clustered_raw_merged_meta.tsv", sep="\t", index_col=0)
    # Prefix CycA with "M_" so it follows the same naming as the other morphogen columns (M_XAV, M_CHIR, ...).
    # rename() replaces the old copy-then-drop(inplace=True) sequence.
    meta = meta.rename(columns={'CycA': 'M_CycA'})

    # Experimental condition = anterior-posterior label + dorso-ventral label.
    meta['condition'] = meta['AP_axis'] + "_" + meta['DV_axis']
    print(sample)
    print(meta['Basal_media'].value_counts())

iGABA_post
mTeSR    85756
Name: Basal_media, dtype: int64
iGABA_pre
N2B27_2Si        80751
N2B27_SB_CHIR    40278
NIM              18682
Name: Basal_media, dtype: int64
iGlut_post
mTeSR    184431
Name: Basal_media, dtype: int64
iGlut_pre
N2B27_2Si        114780
NIM              108104
N2B27_SB_CHIR     82028
Name: Basal_media, dtype: int64


In [4]:
from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import spearmanr
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import os
import random
random.seed(42)  # stdlib RNG; the pandas/sklearn steps below also pass random_state=42 explicitly

# Shuffled, seeded 5-fold splitter so the folds are reproducible.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# A condition is kept only if it has strictly more cells than this.
MIN_CELLS_PER_CONDITION = 250

# One record of fold-averaged scores per (sample, basal medium).
results = []

for sample in ['iGlut_post']:
    print(sample)

    # Load metadata and expression data.
    meta = pd.read_csv("scanpy/" + sample + "_dr_clustered_raw_merged_meta.tsv", sep="\t", index_col=0)
    meta['M_CycA'] = meta['CycA']
    adata = sc.read_h5ad("scanpy/" + sample + "_dr_clustered_raw_merged.h5ad")
    adata.obs['BC'] = adata.obs.index

    # Create the output directory once, before any figure, pickle or TSV is written to it.
    out_dir = f"figures/multiregressor_validate/{sample}"
    os.makedirs(out_dir, exist_ok=True)

    # sorted() instead of set(): set order of strings changes between interpreter runs
    # (hash randomisation), which made the order of `results` non-reproducible.
    for BM in sorted(meta['Basal_media'].dropna().unique()):
        print(f"Basal Media: {BM}")
        meta_bm = meta.loc[meta['Basal_media'] == BM].copy()
        meta_bm['condition'] = meta_bm['AP_axis'] + "_" + meta_bm['DV_axis']

        # Only keep conditions with more than MIN_CELLS_PER_CONDITION cells.
        condition_counts = meta_bm['condition'].value_counts()
        conditions_min250 = condition_counts[condition_counts > MIN_CELLS_PER_CONDITION].index
        if len(conditions_min250) == 0:
            # Without this guard min() below is NaN and groupby().sample() raises.
            print(f"  skipped: no condition with more than {MIN_CELLS_PER_CONDITION} cells")
            continue
        meta_bm = meta_bm[meta_bm['condition'].isin(conditions_min250)]

        # Balance the design: down-sample every condition to the size of the smallest one.
        min_cells_condition = meta_bm['condition'].value_counts().min()
        meta_bm_sampled = meta_bm.groupby('condition').sample(n=min_cells_condition, random_state=42)

        # Expression matrix for the sampled cells, rows in the same order as meta_bm_sampled.
        # NOTE(review): from_spmatrix assumes adata.X is a scipy sparse matrix — confirm.
        adata_subset = adata[meta_bm_sampled.index, :].copy()
        dgem = pd.DataFrame.sparse.from_spmatrix(adata_subset.X, index=adata_subset.obs.index, columns=adata_subset.var_names)
        # X (dgem) and y (from meta_bm_sampled) are paired by position below, so the row order must match.
        assert dgem.index.equals(meta_bm_sampled.index), "expression rows not aligned with metadata rows"

        # Keep only morphogens that are actually used (non-zero total) in this basal medium.
        morphogens = ['M_' + x for x in ['XAV', 'CHIR', 'RA', 'FGF8', 'BMP4', 'SHH', 'CycA']]
        morph_sum = meta_bm_sampled[morphogens].sum()
        morphogens = list(morph_sum[morph_sum > 0].index)
        if not morphogens:
            print("  skipped: no morphogen with non-zero concentration")
            continue

        # log10(x + 1), then per-morphogen min-max scaling to [0, 1].
        y = np.log10(meta_bm_sampled[morphogens] + 1).fillna(0)
        y_range = y.max() - y.min()
        # A morphogen with a single concentration has zero range; dividing by it produced NaN
        # targets that RandomForest cannot fit. Such columns are mapped to 0 instead.
        y = (y - y.min()) / y_range.where(y_range > 0, 1)

        # Cross-validation
        X = dgem
        y_values = y.values
        pred_cols = [morph + '_pred' for morph in morphogens]

        fold_results = []
        for fold_idx, (train_index, test_index) in enumerate(kf.split(X)):
            X_train, X_test = X.iloc[train_index], X.iloc[test_index]
            y_train, y_test = y_values[train_index], y_values[test_index]

            # One random forest per morphogen.
            # NOTE(review): 4 outer jobs x 20 forest jobs = up to 80 workers; tune to the machine.
            regr = RandomForestRegressor(random_state=42, n_jobs=20, n_estimators=100, max_features='sqrt', max_depth=10)
            regr_multi = MultiOutputRegressor(regr, n_jobs=4).fit(X_train, y_train)

            y_pred = regr_multi.predict(X_test)

            # Spearman correlation per morphogen, then averaged with R² and MSE for the fold.
            spearman_corrs = [spearmanr(y_test[:, i], y_pred[:, i])[0] for i in range(len(morphogens))]
            fold_score = {
                'r2_score': r2_score(y_test, y_pred, multioutput='uniform_average'),
                'mse': mean_squared_error(y_test, y_pred),
                'avg_spearman': np.mean(spearman_corrs)
            }
            fold_results.append(fold_score)

            # Real vs predicted table for the test cells of this fold.
            y_pred_df = pd.DataFrame(y_pred, index=X_test.index, columns=pred_cols)
            y_test_df = pd.DataFrame(y_test, index=X_test.index, columns=morphogens)
            comb = pd.concat([y_test_df, y_pred_df], axis=1)

            for morph in morphogens:
                fig, ax = plt.subplots(figsize=(10, 5))
                sns.boxplot(data=comb, x=morph, y=morph + '_pred', ax=ax)
                ax.set_ylabel('Predicted concentration (A.U.)', fontsize=20)
                ax.set_xlabel('Real concentration (A.U.)', fontsize=20)
                ax.set_title(f"{morph} - Fold {fold_idx}", fontsize=22)
                ax.tick_params(axis='both', labelsize=18)
                fig.savefig(f"{out_dir}/subsampled_multi_v1_BM_{morph}_{BM}_test_fold_{fold_idx}.png", dpi=350, bbox_inches='tight', pad_inches=0)
                plt.close(fig)  # free the figure; many are created in this loop

            # Save the fitted model of this fold.
            filename = f"{out_dir}/{BM}_fold_{fold_idx}.p"
            with open(filename, 'wb') as filehandler:
                pickle.dump(regr_multi, filehandler)

            # Save test predictions.
            comb.to_csv(f"{out_dir}/subsampled_multi_v1_BM_test_{BM}_fold_{fold_idx}.tsv", sep='\t')

            # Predict and save training predictions (used to compare train vs test fit).
            y_train_pred = regr_multi.predict(X_train)
            y_train_pred_df = pd.DataFrame(y_train_pred, index=X_train.index, columns=pred_cols)
            comb_train = pd.concat([pd.DataFrame(y_train, index=X_train.index, columns=morphogens), y_train_pred_df], axis=1)
            comb_train.to_csv(f"{out_dir}/subsampled_multi_v1_BM_train_{BM}_fold_{fold_idx}.tsv", sep='\t')

        # Average the fold scores for this basal medium.
        results.append({
            'sample': sample,
            'basal_media': BM,
            'avg_r2_score': np.mean([res['r2_score'] for res in fold_results]),
            'avg_mse': np.mean([res['mse'] for res in fold_results]),
            'avg_spearman': np.mean([res['avg_spearman'] for res in fold_results])
        })

# Results
for result in results:
    print(f"Sample: {result['sample']}, Basal Media: {result['basal_media']}, Avg R²: {result['avg_r2_score']:.4f}, Avg MSE: {result['avg_mse']:.4f}, Avg Spearman: {result['avg_spearman']:.4f}")


iGlut_post
Basal Media: mTeSR
Sample: iGlut_post, Basal Media: mTeSR, Avg R²: 0.3423, Avg MSE: 0.0820, Avg Spearman: 0.6082


In [6]:
# Colour lists per morphogen column. Each list starts with white ('w') followed by hex colours.
# NOTE(review): presumably one colour per concentration level — confirm against the plotting code.
color_dict = {
    'M_SHH': ['w', "#C2D9F7", "#98C1F0", "#4782DD", "#1D52A1"],
    'M_RA': ['w', "#aadce0", "#72bcd5", "#528fad", "#376795"],
    'M_BMP4': ['w', "#ffe6b7", "#ffd353", "#ffb242"],
    'M_XAV': ['w', "#f9b4c9", "#d8527c", "#9a133d"],
    'M_CHIR': ['w', "#dec5da", "#b695bc", "#90719f", "#574571"],
    'M_FGF8': ['w', '#ffbbff', '#ee7ae9', '#b452cd', '#8b008b'],
}


In [30]:
# Re-load the per-fold prediction tables written by the cross-validation cell and stack them
# into one test frame (comb_all) and one train frame (comb_train_all), tagged with fold_idx.
for sample in ['iGlut_post']:
    print(sample)
    for BM in ['mTeSR']:
        print(BM)
        # Collect the parts and concatenate once. This replaces growing the frames inside the loop,
        # which also relied on the fold list starting at 0 to (re)initialise them.
        test_parts = []
        train_parts = []
        for fold_idx in [0, 1, 2, 3, 4]:
            file_path = f"figures/multiregressor_validate/{sample}/subsampled_multi_v1_BM_test_{BM}_fold_{fold_idx}.tsv"
            print(file_path)
            comb = pd.read_csv(file_path, sep="\t", index_col=0)
            comb['fold_idx'] = fold_idx
            test_parts.append(comb)

            file_path = f"figures/multiregressor_validate/{sample}/subsampled_multi_v1_BM_train_{BM}_fold_{fold_idx}.tsv"
            print(file_path)
            comb_train = pd.read_csv(file_path, sep="\t", index_col=0)
            comb_train['fold_idx'] = fold_idx
            train_parts.append(comb_train)

        comb_all = pd.concat(test_parts)
        comb_train_all = pd.concat(train_parts)



iGlut_post
mTeSR
figures/multiregressor_validate/iGlut_post/subsampled_multi_v1_BM_test_mTeSR_fold_0.tsv
figures/multiregressor_validate/iGlut_post/subsampled_multi_v1_BM_train_mTeSR_fold_0.tsv
figures/multiregressor_validate/iGlut_post/subsampled_multi_v1_BM_test_mTeSR_fold_1.tsv
figures/multiregressor_validate/iGlut_post/subsampled_multi_v1_BM_train_mTeSR_fold_1.tsv
figures/multiregressor_validate/iGlut_post/subsampled_multi_v1_BM_test_mTeSR_fold_2.tsv
figures/multiregressor_validate/iGlut_post/subsampled_multi_v1_BM_train_mTeSR_fold_2.tsv
figures/multiregressor_validate/iGlut_post/subsampled_multi_v1_BM_test_mTeSR_fold_3.tsv
figures/multiregressor_validate/iGlut_post/subsampled_multi_v1_BM_train_mTeSR_fold_3.tsv
figures/multiregressor_validate/iGlut_post/subsampled_multi_v1_BM_test_mTeSR_fold_4.tsv
figures/multiregressor_validate/iGlut_post/subsampled_multi_v1_BM_train_mTeSR_fold_4.tsv


In [48]:
import pandas as pd
from scipy.stats import spearmanr
from sklearn.metrics import mean_squared_error

# Per-fold Spearman correlation and MSE between real and predicted morphogen values,
# for both the stacked test table (comb_all) and the stacked train table (comb_train_all).

# Morphogen columns to score; each needs a matching "<col>_pred" column.
columns = ['M_XAV', 'M_CHIR', 'M_RA', 'M_FGF8', 'M_BMP4', 'M_SHH']


def compute_fold_metrics(frame, metric_columns):
    """Score real-vs-predicted columns separately for every fold.

    Parameters
    ----------
    frame : pandas.DataFrame
        Stacked prediction table with a ``fold_idx`` column. For every name in
        ``metric_columns`` it holds ``<name>`` (real) and ``<name>_pred``.
    metric_columns : list of str
        Names of the real-value columns to score.

    Returns
    -------
    list of dict
        One record per fold: ``fold_idx`` followed by ``<col>_spearman`` and
        ``<col>_mse`` for each column, in ``metric_columns`` order.
    """
    records = []
    for fold_idx, group in frame.groupby('fold_idx'):
        record = {'fold_idx': fold_idx}
        for col in metric_columns:
            col_pred = col + '_pred'
            spearman_corr, _ = spearmanr(group[col], group[col_pred])
            record[f'{col}_spearman'] = spearman_corr
            record[f'{col}_mse'] = mean_squared_error(group[col], group[col_pred])
        records.append(record)
    return records


df = comb_all
results = compute_fold_metrics(df, columns)
results_df_test = pd.DataFrame(results)

# Train-set metrics are computed the same way here, so the later cells that use
# results_df_train run top-to-bottom on a fresh kernel.
results_df_train = pd.DataFrame(compute_fold_metrics(comb_train_all, columns))

In [51]:
# Per-fold metrics on the training predictions: one row per fold, Spearman and MSE per morphogen.
results_df_train

Unnamed: 0,fold_idx,M_XAV_spearman,M_XAV_mse,M_CHIR_spearman,M_CHIR_mse,M_RA_spearman,M_RA_mse,M_FGF8_spearman,M_FGF8_mse,M_BMP4_spearman,M_BMP4_mse,M_SHH_spearman,M_SHH_mse
0,0,0.358927,0.031468,0.789688,0.124228,0.734472,0.03611,0.599767,0.101874,0.844572,0.051871,0.757717,0.086336
1,1,0.359188,0.031016,0.793681,0.123785,0.737793,0.036663,0.59781,0.103042,0.844956,0.05135,0.762145,0.086002
2,2,0.357212,0.031531,0.792309,0.12503,0.738091,0.036332,0.59805,0.101935,0.845873,0.051162,0.755958,0.086329
3,3,0.355672,0.031919,0.792486,0.124245,0.739178,0.03681,0.595519,0.101713,0.845572,0.050067,0.764336,0.085756
4,4,0.35737,0.031389,0.791565,0.124605,0.738667,0.036138,0.597851,0.101454,0.845146,0.0506,0.761069,0.085955


In [52]:
# Tag every metric column with its split so the two tables can be merged side by side.
results_df_train = results_df_train.add_suffix("_train")
results_df_test = results_df_test.add_suffix("_test")

In [54]:
# Place train and test metrics of the same fold (same row index) next to each other.
results_df_merged = results_df_train.join(results_df_test, how="inner")

In [57]:
# fold_idx is duplicated by the merge; keep only the test copy.
results_df_merged = results_df_merged.drop(columns='fold_idx_train')

In [60]:
# Column order for the merged metric table: fold index first, then for each morphogen
# (alphabetical) its MSE and Spearman values, each given as test followed by train.
sorted_cols = ['fold_idx_test'] + [
    f"{morph}_{metric}_{split}"
    for morph in ('M_BMP4', 'M_CHIR', 'M_FGF8', 'M_RA', 'M_SHH', 'M_XAV')
    for metric in ('mse', 'spearman')
    for split in ('test', 'train')
]

In [62]:
# Reorder (and restrict) the merged table to the column order defined in sorted_cols.
results_df_merged = results_df_merged.loc[:, sorted_cols]

In [63]:
# Write the merged per-fold train/test metric table as a tab-separated file.
results_df_merged.to_csv("figures/multiregressor_validate/training_overview.tsv",sep="\t")

In [67]:
import pandas as pd

# Reshape the wide per-fold metric table into one row per (morphogen, fold), with
# mse/spearman for train and test as separate columns.
# Parse the morphogen, the metric kind and the split out of each wide column name.
melted_df = (
    results_df_merged
    .melt(id_vars=['fold_idx_test'], var_name='metric', value_name='value')
    .assign(
        column=lambda d: d['metric'].str.extract(r'(M_[A-Z0-9]+)', expand=False),
        metric_type=lambda d: d['metric'].str.extract(r'(mse|spearman)', expand=False),
        train_test=lambda d: d['metric'].str.extract(r'(train|test)', expand=False),
    )
)

reshaped_df = melted_df.pivot_table(
    index=['column', 'fold_idx_test'],
    columns=['metric_type', 'train_test'],
    values='value',
).reset_index()

# Flatten the two-level column index to "<metric_type>_<train_test>". The former index
# columns become "column_" and "fold_idx_test_" (strip() only removes whitespace, not
# the trailing underscore); "column_" is kept as-is for downstream code.
reshaped_df.columns = [f"{first}_{second}".strip() for first, second in reshaped_df.columns]
reshaped_df = reshaped_df.rename(columns={'fold_idx_test_': 'fold_idx'})

print(reshaped_df)


   column_  fold_idx  mse_test  mse_train  spearman_test  spearman_train
0   M_BMP4         0  0.059409   0.051871       0.838077        0.844572
1   M_BMP4         1  0.057434   0.051350       0.838974        0.844956
2   M_BMP4         2  0.057382   0.051162       0.837578        0.845873
3   M_BMP4         3  0.058064   0.050067       0.839274        0.845572
4   M_BMP4         4  0.057939   0.050600       0.838972        0.845146
5   M_CHIR         0  0.143628   0.124228       0.729066        0.789688
6   M_CHIR         1  0.142210   0.123785       0.722553        0.793681
7   M_CHIR         2  0.143160   0.125030       0.722977        0.792309
8   M_CHIR         3  0.145720   0.124245       0.729131        0.792486
9   M_CHIR         4  0.143170   0.124605       0.732817        0.791565
10  M_FGF8         0  0.117045   0.101874       0.439499        0.599767
11  M_FGF8         1  0.113922   0.103042       0.455752        0.597810
12  M_FGF8         2  0.115712   0.101935       0.4