# Cross-Modality Prediction with Linear Regression

In [1]:
import numpy as np
import scanpy as sc
from sklearn.linear_model import LinearRegression
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error
import pandas as pd
import os
from self_supervision.paths import MULTIMODAL_FOLDER

### Load and Prepare Data

In [2]:
# Read the AnnData file stored under MULTIMODAL_FOLDER and display its summary.
adata_path = os.path.join(MULTIMODAL_FOLDER, "NeurIPS_filtered_hvg_adata.h5ad")
adata = sc.read_h5ad(adata_path)
adata


AnnData object with n_obs × n_vars = 90261 × 2000
    obs: 'GEX_n_genes_by_counts', 'GEX_pct_counts_mt', 'GEX_size_factors', 'GEX_phase', 'ADT_n_antibodies_by_counts', 'ADT_total_counts', 'ADT_iso_count', 'cell_type', 'batch', 'ADT_pseudotime_order', 'GEX_pseudotime_order', 'Samplename', 'Site', 'DonorNumber', 'Modality', 'VendorLot', 'DonorID', 'DonorAge', 'DonorBMI', 'DonorBloodType', 'DonorRace', 'Ethnicity', 'DonorGender', 'QCMeds', 'DonorSmoker', 'is_train', 'split'
    var: 'feature_types', 'gene_id', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'dataset_id', 'genome', 'hvg', 'log1p', 'organism'
    obsm: 'ADT_X_pca', 'ADT_X_umap', 'ADT_isotype_controls', 'GEX_X_pca', 'GEX_X_umap', 'protein_counts'
    layers: 'counts'

In [3]:
# Partition the cells by the value of the `split` column in `adata.obs`:
# 'train' cells are used for fitting, 'ood_test' cells for evaluation.
split_labels = adata.obs['split']
train_adata = adata[split_labels == 'train']
test_adata = adata[split_labels == 'ood_test']

In [4]:
# Model inputs: the `.X` matrix of each split, converted from sparse storage
# to a plain dense ndarray.
train_mRNA = train_adata.X.toarray()
test_mRNA = test_adata.X.toarray()

# Model targets: log1p-transformed entries of obsm['protein_counts'].
train_protein = np.log1p(np.asarray(train_adata.obsm['protein_counts']))
test_protein = np.log1p(np.asarray(test_adata.obsm['protein_counts']))

### Fit Linear Model

In [5]:
# Fit a single multi-output linear regression that maps the mRNA matrix to the
# log1p protein matrix, then predict on both the training and the test split.
model = LinearRegression()
model.fit(train_mRNA, train_protein)

train_predictions = model.predict(train_mRNA)
test_predictions = model.predict(test_mRNA)

# scikit-learn metrics take (y_true, y_pred). Plain MSE is symmetric, so the
# numbers are unchanged, but passing the ground truth first keeps this cell
# consistent with the evaluation cell and safe if the metric is ever swapped.
train_mse = mean_squared_error(train_protein, train_predictions)
test_mse = mean_squared_error(test_protein, test_predictions)

print("Train MSE:", train_mse)
print("Test MSE:", test_mse)

Train MSE: 0.5309232
Test MSE: 1.1918472


In [6]:
# Wrap the test-split predictions in a DataFrame (default integer row and
# column labels) so they can be inspected as a table.
df = pd.DataFrame(data=test_predictions)
# Optional export of the predictions, left disabled:
# df.to_csv('LinReg_test_protein.csv', index=False)
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,124,125,126,127,128,129,130,131,132,133
0,0.567372,1.580514,2.797127,1.114277,2.077988,3.797059,3.923937,1.960537,2.983202,3.436579,...,2.050917,3.697425,1.423468,2.836527,0.176854,2.354324,2.994437,2.397672,2.322180,2.490471
1,0.574022,1.949335,2.599400,0.582670,1.757538,2.650935,2.597198,2.936141,2.608345,2.501446,...,1.608597,1.562154,1.576531,2.820076,0.347608,2.032809,2.206467,1.045456,1.816128,1.899944
2,1.641404,2.240649,2.985437,1.859532,2.228678,4.279867,4.326148,1.781528,3.024471,3.024823,...,2.078953,3.863859,1.956591,2.815760,2.968354,2.785178,2.960572,3.912214,2.785528,3.317249
3,2.119231,1.554606,2.718613,1.247801,1.756236,4.192995,3.908851,1.788978,2.722189,3.238468,...,1.798182,4.763229,2.477818,2.487440,1.947106,2.682157,2.044715,3.382274,3.528497,2.248569
4,0.377468,2.022115,2.925131,0.498923,1.909719,3.312759,2.936779,2.726048,2.846499,2.973350,...,1.601198,1.998087,1.856366,2.739463,0.281502,2.287681,3.201321,1.983115,2.269296,2.356172
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10460,0.372634,1.800690,2.557297,1.105458,1.978532,2.319202,0.777478,1.432728,3.115477,1.271966,...,1.950803,0.953255,1.554047,2.826198,0.355819,2.148901,3.182093,1.613770,3.417982,1.901903
10461,2.281189,2.008333,3.031861,2.114973,2.552917,4.305439,4.336975,1.898395,3.217922,2.802584,...,2.141768,4.326425,2.411513,3.003736,3.702658,2.858924,2.861776,3.906376,3.501036,3.304327
10462,0.353658,1.165260,2.663476,0.706642,1.494424,3.632160,3.642930,1.574918,2.681316,3.447630,...,1.497875,3.387916,1.117874,2.407871,0.178105,2.209592,2.590366,2.383527,1.845070,1.545758
10463,0.269225,1.675319,2.810692,0.878852,1.809433,3.959491,3.933648,1.582298,3.130629,3.595952,...,1.508742,3.050074,1.195946,2.833106,0.262645,2.389876,2.848648,2.249928,2.161940,2.384729


### Evaluate Performance

In [7]:
# Per-cell agreement between observed and predicted (log1p) protein levels on
# the test split: Pearson correlation and mean squared error across proteins.
# `observed` is used as a DataFrame below (its index/columns are read), so
# obsm['protein_counts'] is expected to be a DataFrame.
observed = np.log1p(test_adata.obsm['protein_counts'])

# Build a fresh frame carrying the observed labels rather than aliasing `df`
# and overwriting its index, which would silently relabel the predictions
# frame created in the previous cell.
imputed = pd.DataFrame(df.to_numpy(), index=observed.index, columns=observed.columns)

# Row-wise metrics computed in one vectorised pass instead of calling
# pearsonr / mean_squared_error once per cell.
obs_vals = observed.to_numpy(dtype=np.float64)
imp_vals = imputed.to_numpy(dtype=np.float64)

obs_centered = obs_vals - obs_vals.mean(axis=1, keepdims=True)
imp_centered = imp_vals - imp_vals.mean(axis=1, keepdims=True)
covariance = (obs_centered * imp_centered).sum(axis=1)
scale = np.sqrt((obs_centered ** 2).sum(axis=1) * (imp_centered ** 2).sum(axis=1))

# A row that is constant in either frame has zero variance; its correlation is
# undefined and is reported as NaN (as pearsonr does) without a noisy warning.
with np.errstate(divide='ignore', invalid='ignore'):
    corr = np.clip(covariance / scale, -1.0, 1.0)
mse = ((obs_vals - imp_vals) ** 2).mean(axis=1)

# Per-cell values are rounded to three decimals before being stored, so the
# summary means printed below are means of the rounded values.
held_vs_denoised_percells = pd.DataFrame(
    {
        "Observed (log)": observed.mean(axis=1),
        "Imputed (log)": imputed.mean(axis=1),
        "corr": np.round(corr, 3),
        "mse": np.round(mse, 3),
    },
    index=observed.index,
)
print(f"mean of mse: {np.mean(held_vs_denoised_percells['mse'])}")
print(f"mean of corr: {np.mean(held_vs_denoised_percells['corr'])}")
held_vs_denoised_percells

mean of mse: 1.1918474435806274
mean of corr: 0.8085559483994267


Unnamed: 0,Observed (log),Imputed (log),corr,mse
TGTTGAGGTTTACGTG-1-s2d1,0.919579,1.960781,0.765,1.492
CTACATTTCGCAGATT-1-s2d1,1.063311,1.626454,0.826,0.627
ACGTACAAGTAGAATC-1-s2d1,1.319220,2.354523,0.813,1.463
TGGGAAGAGTCTGCGC-1-s2d1,1.512837,1.982302,0.865,0.572
ATCCGTCAGCAGTCTT-1-s2d1,1.027916,1.915029,0.783,1.215
...,...,...,...,...
TCATTCAGTCACCGAC-1-s2d1,1.171844,1.552774,0.867,0.437
GGAGAACAGCCAAGCA-1-s2d1,1.593776,2.508789,0.839,1.230
TGCGACGGTAGAATGT-1-s2d1,0.912626,1.735089,0.796,1.097
TCCTGCAGTCATGACT-1-s2d1,0.933807,1.904605,0.815,1.338
