# Cross-Modality Prediction (Gene Expression → ATAC): MLP Models vs. a Linear Baseline

In [1]:
# Load IPython's autoreload extension so edits to the self_supervision package are picked up
# without restarting the kernel (the reload mode is set with %autoreload in a later cell).
%load_ext autoreload

In [2]:
import numpy as np
import scanpy as sc
from sklearn.linear_model import LinearRegression, Ridge
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error
import pandas as pd
import os
import torch
from torch.utils.data import DataLoader, Dataset
import lightning.pytorch as pl
from self_supervision.paths import MULTIMODAL_FOLDER, TRAINING_FOLDER, RESULTS_FOLDER

In [3]:
%autoreload 2
from self_supervision.models.lightning_modules.multiomics_autoencoder import MultiomicsMultiAutoencoder
from self_supervision.data.datamodules import ATACDataloader

[rank: 0] Global seed set to 0


### Load and Prepare Data

In [4]:
# Pre-filtered NeurIPS multiome dataset: gene expression in .X, ATAC features in .obsm['atac'].
adata_path = os.path.join(MULTIMODAL_FOLDER, "NeurIPS_tfidf_filtered_hvg_adata.h5ad")
adata = sc.read_h5ad(adata_path)
adata

AnnData object with n_obs × n_vars = 69249 × 2000
    obs: 'GEX_pct_counts_mt', 'GEX_n_counts', 'GEX_n_genes', 'GEX_size_factors', 'GEX_phase', 'ATAC_nCount_peaks', 'ATAC_atac_fragments', 'ATAC_reads_in_peaks_frac', 'ATAC_blacklist_fraction', 'ATAC_nucleosome_signal', 'cell_type', 'batch', 'ATAC_pseudotime_order', 'GEX_pseudotime_order', 'Samplename', 'Site', 'DonorNumber', 'Modality', 'VendorLot', 'DonorID', 'DonorAge', 'DonorBMI', 'DonorBloodType', 'DonorRace', 'Ethnicity', 'DonorGender', 'QCMeds', 'DonorSmoker', 'split'
    var: 'feature_types', 'gene_id', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    obsm: 'atac'

In [5]:
# Keep only the held-out cells according to the precomputed 'split' annotation.
test_adata = adata[adata.obs['split'] == 'test']
# Fail here rather than with a ZeroDivisionError / empty-shape error during evaluation.
assert test_adata.n_obs > 0, "adata.obs['split'] contains no 'test' cells"

In [6]:
# Gene-expression input as a dense float32 tensor. `.X` may be stored sparse or dense, so only
# densify when needed; `toarray()` yields an ndarray directly (no np.matrix round-trip), and
# `torch.as_tensor` avoids the extra copy that `torch.tensor(np.array(...))` makes.
gex_test = test_adata.X
gex_test = gex_test.toarray() if hasattr(gex_test, "toarray") else np.asarray(gex_test)
test_mRNA = torch.as_tensor(gex_test, dtype=torch.float32)
del gex_test  # don't leave a second large global in the kernel

# Prediction target; `.toarray()` is called on it during evaluation, so presumably sparse.
test_atac = test_adata.obsm['atac']
test_batches = test_adata.obs['batch']

In [7]:
# The three test inputs are handed to ATACDataloader together, so they must have the same number of rows.
assert test_mRNA.shape[0] == test_atac.shape[0] == len(test_batches), "test modalities are not row-aligned"
test_mRNA.shape

torch.Size([6896, 2000])

In [8]:
# One batch label per test cell.
np.shape(test_batches)

(6896,)

In [9]:
# (n_test_cells, n_atac_features) of the prediction target.
np.shape(test_atac)

(6896, 116490)

In [10]:
# The gene-expression input should be a torch tensor when it is passed to ATACDataloader.
test_mRNA.__class__

torch.Tensor

In [11]:
# Wrap the test inputs for inference. shuffle=False keeps batches in dataset order, so the
# concatenated predictions should line up row-for-row with test_atac (assuming ATACDataloader
# yields rows in index order).
test_dataset = ATACDataloader(test_atac, test_mRNA, test_batches)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=512)

### Inference with MLP

In [12]:
# Root of the trained multiomics models (used by the os.path.join checkpoint entries below).
MODEL_DIR = os.path.join(
    TRAINING_FOLDER,
    "final_models",
    "multiomics",
)

In [198]:
# MLP trained with supervised learning

# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_new_run0", "default", "version_2", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_new_run1", "default", "version_2", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_new_run2", "default", "version_2", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_new_run3", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_new_run4", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")

# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_rev_new_run0", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_rev_new_run1", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_rev_new_run2", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_rev_new_run3", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_rev_new_run4", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")

# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_tfidf_run0", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_tfidf_run1", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_tfidf_run2", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_tfidf_run3", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "No_SSL_tfidf_run4", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")

### new ###
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_tfidf_run0/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_tfidf_run1/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_tfidf_run2/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_tfidf_run3/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_tfidf_run4/default/version_0/checkpoints/best_checkpoint_val.ckpt"

### new + big ###
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_new_big_test_tfidf_run0/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_new_big_tfidf_run1/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_new2_big_tfidf_run2/default/version_1/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_new2_big_tfidf_run3/default/version_1/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_new2_big_tfidf_run4/default/version_1/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_new_big_test_tfidf_run5/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_new_big_tfidf_run1/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_new_big_tfidf_run2/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_new_big_tfidf_run3/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/No_SSL_new_big_tfidf_run4/default/version_0/checkpoints/best_checkpoint_val.ckpt"

# MLP trained with self-supervised learning on auxiliary scTab data

# ckpt_path = os.path.join(MODEL_DIR, "SSL_atac_multiomics_20M_random_MAE_tfidftfidf_run0", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "SSL_atac_multiomics_20M_random_MAE_tfidftfidf_run1", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "SSL_atac_multiomics_20M_random_MAE_tfidftfidf_run2", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "SSL_atac_multiomics_20M_random_MAE_tfidftfidf_run3", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "SSL_atac_multiomics_20M_random_MAE_tfidftfidf_run4", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")


### new ###
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_20M_random_MAE_tfidfnew_tfidf_run0/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_20M_random_MAE_tfidfnew_tfidf_run1/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_20M_random_MAE_tfidfnew_tfidf_run2/default/version_0/checkpoints/best_checkpoint_val.ckpt"

### new + big ###
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_20M_random_MAE_tfidfnew_big_tfidf_run0/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_20M_random_MAE_tfidfnew_big_tfidf_run1/default/version_0/checkpoints/best_checkpoint_val.ckpt"

# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_20M_random_MAE_tfidfnew_big_tfidf_run2/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_20M_random_MAE_tfidfnew_big_tfidf_run5/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_20M_random_MAE_tfidfnew2_big_tfidf_run2/default/version_1/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_20M_random_MAE_tfidfnew2_big_tfidf_run3/default/version_1/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_20M_random_MAE_tfidfnew2_big_tfidf_run4/default/version_1/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_20M_random_MAE_tfidftfidf_big_run3/default/version_0/checkpoints/best_checkpoint_val.ckpt"

# Active checkpoint. The root directory is factored out so the notebook can be pointed at
# another copy of the trained models via the MULTIOMICS_CKPT_ROOT environment variable; the
# default is the same absolute cluster path used by the "new" entries above.
CKPT_ROOT = os.environ.get(
    "MULTIOMICS_CKPT_ROOT",
    "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics",
)
ckpt_path = os.path.join(
    CKPT_ROOT,
    "SSL_atac_multiomics_20M_random_MAE_tfidftfidf_big_run4",
    "default",
    "version_0",
    "checkpoints",
    "best_checkpoint_val.ckpt",
)


# MLP trained with self-supervised learning on not-auxiliary NeurIPS data

# ckpt_path = os.path.join(MODEL_DIR, "SSL_atac_multiomics_NeurIPS_random_MAE_tfidftfidf_run0", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "SSL_atac_multiomics_NeurIPS_random_MAE_tfidftfidf_run1", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "SSL_atac_multiomics_NeurIPS_random_MAE_tfidftfidf_run2", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "SSL_atac_multiomics_NeurIPS_random_MAE_tfidftfidf_run3", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")
# ckpt_path = os.path.join(MODEL_DIR, "SSL_atac_multiomics_NeurIPS_random_MAE_tfidftfidf_run4", "default", "version_0", "checkpoints", "best_checkpoint_val.ckpt")

### new ###
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew_tfidf_run0/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew_tfidf_run1/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew_tfidf_run2/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew_tfidf_run3/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew_tfidf_run4/default/version_0/checkpoints/best_checkpoint_val.ckpt"

### new + big ###
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew_big_tfidf_run0/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew2_big_tfidf_run4/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew_big_tfidf_run0/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew_big_tfidf_run1/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew_big_tfidf_run2/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew_big_tfidf_run3/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew_big_tfidf_run4/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew_big_tfidf_run5/default/version_0/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew2_big_tfidf_run2/default/version_1/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew2_big_tfidf_run3/default/version_1/checkpoints/best_checkpoint_val.ckpt"
# ckpt_path = "/lustre/groups/ml01/workspace/till.richter/trained_models/final_models/multiomics/SSL_atac_multiomics_NeurIPS_random_MAE_tfidfnew2_big_tfidf_run4/default/version_0/checkpoints/best_checkpoint_val.ckpt"

In [199]:
# Fine-tuning configuration of the multiomics MAE. These hyperparameters presumably match the
# ones used to train the selected checkpoint (TODO confirm).
model = MultiomicsMultiAutoencoder(
    mode='fine_tuning',
    model='MAE',
    dropout=0.11642113240634665,
    learning_rate=0.00011197711341004587,
    weight_decay=0.0010851761758488817,
    # Quarter of 4096 for fine-tuning: the model predicts ~116k features instead of 2k.
    batch_size=4096 // 4,
)

In [200]:
# Load the trained weights on CPU. strict=False tolerates extra keys in the checkpoint, but
# missing keys would leave parts of the network at their random initialisation and make every
# metric below meaningless, so fail loudly on those instead of ignoring them silently.
state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict']
load_result = model.load_state_dict(state_dict, strict=False)
del state_dict  # weights have been copied into `model`; drop the duplicate

if load_result.missing_keys:
    raise RuntimeError(
        f"Checkpoint {ckpt_path} lacks {len(load_result.missing_keys)} model parameters, "
        f"e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"Ignoring {len(load_result.unexpected_keys)} unexpected checkpoint keys, "
        f"e.g. {load_result.unexpected_keys[:5]}"
    )
load_result

<All keys matched successfully>

In [201]:
# Single-device CPU trainer, used here only to run inference via trainer.predict.
trainer = pl.Trainer(devices=1, accelerator="cpu")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [202]:
# Run the model over the (unshuffled) test loader; returns one output per batch.
predictions = trainer.predict(model=model, dataloaders=test_dataloader)

Predicting: |                                                                                                 …

In [203]:
# Stack element [1] of every batch output into one (n_test_cells, n_atac_features) tensor.
# NOTE(review): [1] presumably is the predicted ATAC profile from predict_step — confirm.
all_preds = torch.cat([batch_output[1] for batch_output in predictions], dim=0)
all_preds.shape

torch.Size([6896, 116490])

### Evaluation Metric from NeurIPS Challenge

The score is the Pearson correlation between each cell's observed and predicted profile, averaged over all cells.

In [204]:
def _as_row_indexable(y):
    """Return ``y`` in a form supporting ``.shape`` and cheap ``y[i]`` row access."""
    if isinstance(y, pd.DataFrame):
        return y.to_numpy()
    if hasattr(y, "tocsr"):  # scipy sparse matrix/array: CSR slices rows cheaply
        return y.tocsr()
    return np.asarray(y)  # ndarrays, np.matrix, CPU torch tensors, nested lists


def _dense_row(y, i):
    """Return row ``i`` of ``y`` as a flat float64 ndarray (densifying sparse rows)."""
    row = y[i]
    if hasattr(row, "toarray"):
        row = row.toarray()
    return np.asarray(row, dtype=np.float64).ravel()


def correlation_score(y_true, y_pred):
    """Scores the predictions according to the competition rules.

    For every sample (row) the Pearson correlation coefficient between the
    observed and the predicted feature vector is computed; the mean over all
    samples is returned.

    It is assumed that no row is constant: a constant row has an undefined
    correlation, ``np.corrcoef`` returns NaN for it (with a RuntimeWarning) and
    the NaN propagates into the result.

    Parameters
    ----------
    y_true, y_pred : array-like of shape (n_samples, n_features)
        pandas DataFrames, numpy arrays, CPU torch tensors or scipy sparse
        matrices. Sparse inputs are densified one row at a time, so the full
        dense matrix never has to be materialised.

    Returns
    -------
    float
        Average of each sample's Pearson correlation coefficient.

    Raises
    ------
    ValueError
        If the shapes differ, the inputs are not 2-D, or there are no samples.
    """
    y_true = _as_row_indexable(y_true)
    y_pred = _as_row_indexable(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError("Shapes are different.")
    if len(y_true.shape) != 2:
        raise ValueError(
            f"Expected 2-D inputs of shape (n_samples, n_features), got shape {y_true.shape}."
        )
    n_samples = y_true.shape[0]
    if n_samples == 0:
        raise ValueError("Cannot score an empty set of samples.")

    corrsum = 0.0
    for i in range(n_samples):
        corrsum += np.corrcoef(_dense_row(y_true, i), _dense_row(y_pred, i))[1, 0]
    return corrsum / n_samples

### Evaluation

In [205]:
# Test MSE and NeurIPS-style Pearson correlation of the MLP predictions.
# Densify the ground truth and convert the predictions ONCE and reuse them for both metrics:
# each dense copy has n_cells x n_peaks entries (~8e8 here), so converting per metric doubles
# peak memory. `.numpy()` shares memory with the tensor instead of copying like np.array().
atac_true = test_atac.toarray() if hasattr(test_atac, "toarray") else np.asarray(test_atac)
atac_pred = all_preds.detach().cpu().numpy()
assert atac_true.shape == atac_pred.shape, f"shape mismatch: {atac_true.shape} vs {atac_pred.shape}"

test_mse = mean_squared_error(atac_true, atac_pred)
test_corr = correlation_score(atac_true, atac_pred)

print("MLP Test MSE:", test_mse)
print("MLP Test Correlation:", test_corr)

MLP Test MSE: 0.43552955382885333
MLP Test Correlation: 0.06276293699490376


In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Test-set Pearson correlations (metric: correlation_score) per model, one value per run.
# NOTE(review): these values are hard-coded from separate evaluation runs and are NOT produced
# by the cells above (the checkpoint selected in this notebook printed ~0.063). Keep them in
# sync with the evaluation outputs by hand.
# NOTE(review): the 1st and 3rd 'Supervised' values agree to ~1e-11, which suggests the same
# checkpoint was evaluated twice — confirm before using the figure.
# NOTE(review): 'Linear' has a single value, so its "box" collapses to a line.
data = {
    'Model': ['Linear'] * 1 + ['Supervised'] * 3 + ['SSL - scTab'] * 3 + ['SSL - NeurIPS'] * 3,
    'Pearson Correlation': [
        0.16219452299615636,  # Linear
        0.1923985705574413, 0.19130795020056338, 0.19239857055063997,  # Supervised
        0.20528594586381807, 0.20170950377163432, 0.2028211667987072,  # SSL - scTab
        0.19100245776051722, 0.19069321282681703, 0.19116319342038793   # SSL - NeurIPS
    ]
}

# Create a DataFrame
mean_results_df = pd.DataFrame(data)

# Single source of truth for plot labels: raw model name -> display label. Insertion order is
# also the left-to-right box order. Both SSL labels now share the same format (the scTab label
# previously had a stray colon).
MODEL_LABELS = {
    "Linear": "Linear",
    "Supervised": "Supervised",
    "SSL - NeurIPS": "Self-Supervised\nRandom Mask\nNeurIPS Dataset",
    "SSL - scTab": "Self-Supervised\nRandom Mask\nscTab Dataset",
}


def rename_model(model_name):
    """Return the multi-line plot label for a raw model name.

    Names without an entry in ``MODEL_LABELS`` are returned unchanged.
    """
    return MODEL_LABELS.get(model_name, model_name)


mean_results_df['Model Type'] = mean_results_df['Model'].apply(rename_model)

# Font properties (sizes in points)
FONT_SIZE = 5
font = {'family': 'sans-serif', 'size': FONT_SIZE}
fontdict = {'family': 'sans-serif', 'fontsize': FONT_SIZE}
tick_font = {'fontsize': FONT_SIZE, 'fontname': 'sans-serif'}

# Colorblind-friendly palette on a white grid
sns.set_theme(style="whitegrid")
sns.set_palette("colorblind")
palette_colors = sns.color_palette("colorblind")

color_supervised = palette_colors[0]
color_ssl = palette_colors[1]
color_zeroshot = palette_colors[2]
color_baseline = palette_colors[3]
color_else1 = palette_colors[5]
color_else2 = palette_colors[6]
color_else3 = palette_colors[7]

# Box order and colours are derived from MODEL_LABELS so they cannot drift out of sync with the
# labels produced by rename_model.
model_order = list(MODEL_LABELS.values())
model_colors = {
    MODEL_LABELS["Linear"]: color_baseline,
    MODEL_LABELS["Supervised"]: color_supervised,
    MODEL_LABELS["SSL - NeurIPS"]: color_ssl,
    MODEL_LABELS["SSL - scTab"]: color_ssl,
}

# Box plot for Pearson Correlation
fig, ax1 = plt.subplots(figsize=(3.5, 2.5))
sns.boxplot(
    x='Model Type', y='Pearson Correlation', data=mean_results_df,
    linewidth=0.5, order=model_order, palette=model_colors, ax=ax1,
)
ax1.set_xlabel('Model Type', fontdict=font)
ax1.set_ylabel('Pearson Correlation', fontdict=font)
ax1.set_title('Cross-Modality Prediction', fontdict=font)

# Shrink tick labels in place (they keep the theme's sans-serif font). Re-setting them with
# set_{x,y}ticklabels(get_*ticklabels()) freezes the label text, and matplotlib warns about
# using a FixedFormatter without a FixedLocator.
ax1.tick_params(axis='both', labelsize=FONT_SIZE)

# Save as SVG; os.path.join also works if RESULTS_FOLDER is a pathlib.Path.
figure_dir = os.path.join(RESULTS_FOLDER, "multiomics")
os.makedirs(figure_dir, exist_ok=True)
fig.savefig(os.path.join(figure_dir, "atac_pearson_corr.svg"), bbox_inches='tight')

plt.show()
