Reviewers asked us to compare our method to PCA, ICA, SVD, and VAEs with sparsity priors.
I can start with PCA, ICA, and SVD.
The plan is to run each of these on the data and compare the redundancy, recovery, and coverage for the simulation,
and the GO annotation for the real data.

In [1]:
# Go two directory levels up so that the `src` package and the relative
# output paths used further down resolve.
# The move is guarded: re-running this cell must not climb two more levels.
# NOTE(review): assumes the target directory is the one that contains `src/`.

import os
import sys

if not os.path.isdir('src'):
    os.chdir('../..')

# Simulation

In [2]:
import torch
import pandas as pd
import numpy as np
import random
import gc

from src.functions.sae_analysis_sim3 import *

In [3]:
# Pick the compute device: GPU index 2 when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:2'
else:
    device = 'cpu'

In [4]:
complexity = 'high'
n_samples = 100000
data_dir = '/home/vschuste/data/simulation/'
n_seeds = 10        # per-seed files rs0 .. rs9 are read
subsample_step = 3  # keep every 3rd training sample to make the analysis faster

# File suffixes, in the order of the variables they are unpacked into below
# (rna_counts, x0, x1, x2, ct, co).
file_suffixes = ['y', 'x0', 'x1', 'x2', 'ct', 'co']

# Collect the per-seed tensors in lists and concatenate once per variable;
# concatenating inside the loop would re-copy the growing tensor on every seed.
loaded = {suffix: [] for suffix in file_suffixes}
for seed in range(n_seeds):
    for suffix in file_suffixes:
        path = data_dir + 'large_{}-complexity_rs{}_{}.pt'.format(complexity, seed, suffix)
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
        loaded[suffix].append(torch.load(path, weights_only=False))

# limit to the training data (the first int(n_samples*0.9) rows),
# then thin it out by taking every `subsample_step`-th sample
n_samples_train = int(n_samples*0.9)
rna_counts, x0, x1, x2, ct, co = (
    torch.cat(loaded[suffix], dim=0)[:n_samples_train][::subsample_step]
    for suffix in file_suffixes
)
del loaded  # drop the references to the individual per-seed tensors

print("Data loaded.")
print(f"Running on a subset with {rna_counts.shape[0]} samples.")

Data loaded.
Running on a subset with 30000 samples.


### analysis function

In [None]:
def pearsonr(a,b):
    """Pearson correlation between every column of ``a`` and the vector ``b``.

    Args:
        a: 2-D tensor, expected shape (n_samples, n_features).
        b: 1-D tensor, expected shape (n_samples,).

    Returns:
        1-D tensor of length n_features whose i-th entry is corr(a[:, i], b).
        A zero-variance column of ``a`` (or a constant ``b``) gives NaN,
        because the denominator is zero.
    """
    a_centered = a - a.mean(dim=0)
    b_centered = (b - b.mean()).unsqueeze(-1)  # (n_samples, 1) so it broadcasts over columns
    cov = torch.mean(a_centered * b_centered, dim=0)
    # The covariance above is normalised by n, so the standard deviations must
    # be the population (biased) ones as well. Mixing it with the default
    # Bessel-corrected std would scale every coefficient by (n-1)/n.
    std_a = a.std(dim=0, unbiased=False)
    std_b = b.std(unbiased=False)
    return cov / (std_a * std_b)

def get_correlations_with_data(activations, unique_activs, comparison_data, device='cpu'):
    """Correlate every latent feature with every column of ``comparison_data``.

    Args:
        activations: tensor of shape (n_samples, n_features). Column i is
            assumed to belong to ``unique_activs[i]`` (columns are sliced by
            position) - TODO confirm against callers that pass a subset.
        unique_activs: sequence whose length sets the number of rows of the result.
        comparison_data: tensor of shape (n_samples, n_attributes).
        device: torch device on which the correlations are computed.

    Returns:
        numpy array of shape (len(unique_activs), n_attributes) with the
        Pearson correlation of each feature/attribute pair.
    """
    n_features = len(unique_activs)
    n_attributes = comparison_data.shape[1]
    correlations_p = torch.zeros((n_features, n_attributes))

    with torch.no_grad():
        # Process the features in smaller batches to avoid memory issues
        batch_size = 5000  # Adjust based on your GPU memory
        # The batches run over the latent features (rows of the result).
        # Bounding them by the number of attributes would leave most rows at
        # zero whenever there are fewer attributes than features.
        for start_idx in range(0, n_features, batch_size):
            end_idx = min(start_idx + batch_size, n_features)
            # Move the feature batch to the device once per batch, not per attribute
            activ_batch = activations[:, start_idx:end_idx].to(device)

            # Calculate correlations for the batch, one attribute at a time
            for j in tqdm.tqdm(range(0, n_attributes)):
                correlations_p[start_idx:end_idx, j] = pearsonr(activ_batch, comparison_data[:, j].to(device)).cpu()

            del activ_batch
            gc.collect()
            torch.cuda.empty_cache()

    return correlations_p.numpy()

def get_number_of_redundant_features(activations, threshold=0.95, device='cpu'):
    """Count features that are highly correlated with at least one other feature.

    Args:
        activations: tensor of shape (n_samples, n_features).
        threshold: a feature pair counts as redundant when its (signed)
            Pearson correlation is strictly greater than this value.
        device: torch device on which the correlations are computed.

    Returns:
        Number of distinct features that take part in at least one
        above-threshold pair with a different feature.
    """
    n_features = activations.shape[1]
    redundant_set = set()
    with torch.no_grad():
        # Process in smaller batches to avoid memory issues
        batch_size = 5000  # Adjust based on your GPU memory
        for j in tqdm.tqdm(range(0, n_features)):
            reference = activations[:, j].to(device)  # moved once per feature
            corr_temp = torch.zeros(n_features)
            for start_idx in range(0, n_features, batch_size):
                end_idx = min(start_idx + batch_size, n_features)
                corr_temp[start_idx:end_idx] = pearsonr(activations[:,start_idx:end_idx].to(device), reference).cpu()
            # A feature always correlates perfectly with itself; mask that
            # entry so only pairs of different features are counted.
            corr_temp[j] = float('-inf')
            # Must happen for every j: updating only after the loop would
            # look at the correlations of the last feature alone.
            redundant_set.update(int(i) for i in np.where(corr_temp.numpy() > threshold)[0])
        # One cleanup after the sweep; a full gc pass per batch dominates the runtime.
        gc.collect()
        torch.cuda.empty_cache()
    n_redundant = len(redundant_set)
    return n_redundant

def analyze_dimreduction_methods(latent, comparison_data, redundant=False, device='cpu'):
    """Summarise how a latent representation relates to ``comparison_data``.

    Args:
        latent: tensor of shape (n_samples, n_latent_features).
        comparison_data: tensor of shape (n_samples, n_attributes).
        redundant: when True, also count redundant latent features
            (correlation threshold 0.95); otherwise that entry is None.
        device: torch device forwarded to the correlation helpers.

    Returns:
        Tuple ``(n_redundant, n_per_attribute, highest_corrs)``. The last two
        come from ``get_n_features_per_attribute`` and
        ``get_highest_corr_per_attribute``, which are defined outside this
        section of the notebook.
    """
    n_redundant = None
    if redundant:
        n_redundant = get_number_of_redundant_features(latent, threshold=0.95, device=device)

    # Every latent column is analysed, so the feature ids are simply 0..n-1.
    feature_ids = np.arange(latent.shape[1])
    corrs = get_correlations_with_data(latent, feature_ids, comparison_data, device=device)
    return n_redundant, get_n_features_per_attribute(corrs), get_highest_corr_per_attribute(corrs)

## PCA

In [5]:
# Fit a PCA with one component per input feature, then project the data.
print("Running PCA...")
from sklearn.decomposition import PCA
pca = PCA(n_components=rna_counts.shape[1]).fit(rna_counts)
embed = torch.tensor(pca.transform(rna_counts))

Running PCA...


In [58]:
def torch_corrcoef(x, epsilon=1e-6):
    """Correlation matrix between the rows of ``x``.

    Args:
        x: 2-D tensor; each row is one variable, each column one observation.
        epsilon: added to every row norm so constant rows do not divide by zero.

    Returns:
        Square tensor of shape (x.shape[0], x.shape[0]).
    """
    # Centre each row on its own mean.
    centered = x - x.mean(dim=1).unsqueeze(1)
    # Unnormalised covariance; the 1/(n-1) factor cancels in the ratio below.
    gram = centered @ centered.T
    norms = gram.diagonal().sqrt() + epsilon
    return gram / (norms[None, :] * norms[:, None])

def get_number_of_redundant_features(activations, threshold=0.95, device='cpu'):
    """Count features that are highly correlated with at least one other feature.

    Note: a function with this name is also defined earlier in the notebook;
    once this cell has run, this later definition is the one in effect.

    Args:
        activations: tensor of shape (n_samples, n_features).
        threshold: a feature pair counts as redundant when its (signed)
            Pearson correlation is strictly greater than this value.
        device: torch device on which the correlations are computed.

    Returns:
        Number of distinct features that take part in at least one
        above-threshold pair with a different feature.
    """
    n_features = activations.shape[1]
    redundant_set = set()
    with torch.no_grad():
        # Process in smaller batches to avoid memory issues
        batch_size = 5000  # Adjust based on your GPU memory
        for j in tqdm.tqdm(range(0, n_features)):
            reference = activations[:, j].to(device)  # moved once per feature
            corr_temp = torch.zeros(n_features)
            for start_idx in range(0, n_features, batch_size):
                end_idx = min(start_idx + batch_size, n_features)
                corr_temp[start_idx:end_idx] = pearsonr(activations[:,start_idx:end_idx].to(device), reference).cpu()
            # A feature always correlates perfectly with itself; mask that
            # entry so only pairs of different features are counted.
            corr_temp[j] = float('-inf')
            # Must happen for every j: updating only after the loop would
            # look at the correlations of the last feature alone.
            redundant_set.update(int(i) for i in np.where(corr_temp.numpy() > threshold)[0])
        # One cleanup after the sweep; a full gc pass per batch dominates the runtime.
        gc.collect()
        torch.cuda.empty_cache()
    n_redundant = len(redundant_set)
    return n_redundant

# Count the redundant columns of x0 (correlation threshold 0.95).
# NOTE(review): presumably a sanity check of the function on the simulated x0 variables - confirm.
n_redundant = get_number_of_redundant_features(x0, threshold=0.95, device=device)

100%|██████████| 100/100 [00:07<00:00, 12.55it/s]


In [38]:
# Free unreachable Python objects, then release PyTorch's unused cached GPU memory.
gc.collect()
torch.cuda.empty_cache()

In [10]:
# Re-select the compute device: GPU index 1 when CUDA is available, otherwise the CPU.
# This replaces the device chosen earlier in the notebook.
if torch.cuda.is_available():
    device = 'cuda:1'
else:
    device = 'cpu'

In [None]:
# Correlate the current embedding (`embed`) with each group of simulation
# variables. One loop replaces six copies of the same three statements.
# ct and co get a trailing axis so they are passed as single-column matrices.
comparison_sets = [
    ('y', rna_counts),
    ('x0', x0),
    ('x1', x1),
    ('x2', x2),
    ('ct', ct.float().unsqueeze(1)),
    ('co', co.float().unsqueeze(1)),
]

embed_metrics = {}
for label, data in comparison_sets:
    print("Running " + label)
    embed_metrics[label] = analyze_dimreduction_methods(embed, data, redundant=False, device=device)
    # free host and GPU memory before the next comparison
    gc.collect()
    torch.cuda.empty_cache()

# Keep the per-variable names that the following cells read.
y_metrics = embed_metrics['y']
x0_metrics = embed_metrics['x0']
x1_metrics = embed_metrics['x1']
x2_metrics = embed_metrics['x2']
ct_metrics = embed_metrics['ct']
co_metrics = embed_metrics['co']

Running y
Running x0


100%|██████████| 100/100 [00:02<00:00, 46.31it/s]


Running x1


 49%|████▉     | 49/100 [00:00<00:00, 51.79it/s]


KeyboardInterrupt: 

In [56]:
# Inspect the first entry of the second metric (n_per_attribute) of the ct comparison.
ct_metrics[1][0]

0

In [None]:
# save metrics
# One summary row per compared variable, collected as a list of records.
# pd.DataFrame rejects a dict that holds only scalars ("If using all scalar
# values, you must pass an index"), so the rows cannot be built one dict at a time.
def _summary_row(variable, metrics, single_attribute=False):
    """Build the summary dict for one (n_redundant, n_per_attribute, highest_corrs) tuple."""
    n_per_attribute, highest_corrs = metrics[1], metrics[2]
    if single_attribute:
        # Only one attribute column: mean and max are both that single value.
        n_mean = n_max = n_per_attribute[0]
        corr_mean = corr_max = highest_corrs[0]
    else:
        n_mean, n_max = n_per_attribute.mean(), n_per_attribute.max()
        corr_mean, corr_max = highest_corrs.mean(), highest_corrs.max()
    return {
        # Redundancy is computed from the embedding alone, so the value of the
        # y run is reused for every row.
        'n_redundant': y_metrics[0],
        'n_per_attribute (mean)': n_mean,
        'n_per_attribute (max)': n_max,
        'highest_corrs (mean)': corr_mean,
        'highest_corrs (max)': corr_max,
        'variable': variable,
    }

df_pca = pd.DataFrame([
    _summary_row('y', y_metrics),
    _summary_row('x0', x0_metrics),
    _summary_row('x1', x1_metrics),
    _summary_row('x2', x2_metrics),
    _summary_row('ct', ct_metrics, single_attribute=True),
    _summary_row('co', co_metrics, single_attribute=True),
])
df_pca.to_csv('03_results/reports/files/sim2L_pca_metrics.csv', index=False)

## ICA

In [None]:
# Fit FastICA with one component per input feature (seed fixed to 0).
print("Running ICA...")
from sklearn.decomposition import FastICA
ica = FastICA(n_components=rna_counts.shape[1], random_state=0).fit(rna_counts)
# The projection (ica.transform) is computed in the following cell.

Running ICA...




In [None]:
# Project the data onto the fitted ICA components; this overwrites the previous `embed`.
embed = torch.tensor(ica.transform(rna_counts))

## SVD

In [None]:
# perform SVD
# TruncatedSVD uses a randomized solver by default, so the seed is fixed
# (0, as in the ICA cell) to make the embedding reproducible across runs.
print("Running SVD...")
from sklearn.decomposition import TruncatedSVD
svd = TruncatedSVD(n_components=rna_counts.shape[1], random_state=0)
svd.fit(rna_counts)
# this overwrites the previous `embed`
embed = torch.tensor(svd.transform(rna_counts))

Running SVD...


In [None]:
# Share of the variance explained by each fitted SVD component.
svd.explained_variance_ratio_

## Sparse VAE

In [5]:
# NOTE(review): PriorVAE (and train_vae, used in the next cell) presumably come
# from this star import - confirm.
from src.models.sparse_vae import *

# Create a VAE with Laplace prior
input_dim = rna_counts.shape[1]
scaling_factor = 1.0
# latent size is scaling_factor x 150
latent_dim = int(scaling_factor * 150)
svae = PriorVAE(
    input_dim=input_dim,
    # hidden width = half the absolute difference between input and latent size
    # NOTE(review): if the midpoint between the two sizes was intended, this
    # should be (input_dim + latent_dim) / 2 - confirm.
    hidden_dim=int(abs(input_dim - latent_dim) / 2),
    latent_dim=latent_dim,
    prior_type='laplace'  # Options: 'gaussian', 'laplace', 'cauchy'
)
# Mini-batch loader over rna_counts for training
from torch.utils.data import DataLoader, TensorDataset
# Assumes rna_counts is a PyTorch tensor of shape (n_samples, n_features)
# Training batch size; note that `batch_size` is reassigned in the next cell.
batch_size = 128
data_loader = DataLoader(
    TensorDataset(rna_counts),
    batch_size=batch_size,
    shuffle=True
)

# Move the model to the selected device before building the optimizer
svae.to(device)
# Set up optimizer
optimizer = torch.optim.Adam(svae.parameters(), lr=1e-4)

In [6]:
# Fit the sparse VAE on the mini-batch loader.
svae = train_vae(svae, optimizer, data_loader, epochs=100, device=device)

# Switch to evaluation mode and encode the whole dataset chunk by chunk,
# so that only one chunk at a time has to sit on the device.
# NOTE(review): assumes svae.encode returns a single tensor - confirm.
svae.eval()
batch_size = 5000
chunk_starts = range(0, rna_counts.shape[0], batch_size)
with torch.no_grad():
    latent_representations = [
        svae.encode(rna_counts[start:start + batch_size].to(device)).cpu()
        for start in chunk_starts
    ]
embed = torch.cat(latent_representations, dim=0)

Training VAE...


  1%|          | 1/100 [00:03<05:59,  3.63s/it]

Epoch 1/100, Loss: 969.3763, Recon: 203.6758, KL: 7657.0054


  2%|▏         | 2/100 [00:07<05:41,  3.49s/it]

Epoch 2/100, Loss: 945.6853, Recon: 156.3649, KL: 7893.2042


  3%|▎         | 3/100 [00:10<05:34,  3.44s/it]

Epoch 3/100, Loss: 941.7326, Recon: 152.6639, KL: 7890.6872


  4%|▍         | 4/100 [00:13<05:28,  3.42s/it]

Epoch 4/100, Loss: 939.3570, Recon: 151.2392, KL: 7881.1775


  5%|▌         | 5/100 [00:17<05:24,  3.41s/it]

Epoch 5/100, Loss: 938.2517, Recon: 149.4475, KL: 7888.0420


  6%|▌         | 6/100 [00:20<05:20,  3.41s/it]

Epoch 6/100, Loss: 937.1631, Recon: 147.2787, KL: 7898.8445


  7%|▋         | 7/100 [00:24<05:16,  3.41s/it]

Epoch 7/100, Loss: 931.3068, Recon: 145.0820, KL: 7862.2480


  8%|▊         | 8/100 [00:27<05:13,  3.41s/it]

Epoch 8/100, Loss: 936.6593, Recon: 147.6109, KL: 7890.4841


  9%|▉         | 9/100 [00:30<05:10,  3.41s/it]

Epoch 9/100, Loss: 933.7231, Recon: 143.8287, KL: 7898.9435


 10%|█         | 10/100 [00:34<05:07,  3.41s/it]

Epoch 10/100, Loss: 931.7667, Recon: 143.7105, KL: 7880.5620


 11%|█         | 11/100 [00:37<05:04,  3.42s/it]

Epoch 11/100, Loss: 931.7539, Recon: 143.4995, KL: 7882.5438


 12%|█▏        | 12/100 [00:41<05:01,  3.42s/it]

Epoch 12/100, Loss: 932.3440, Recon: 143.2713, KL: 7890.7275


 13%|█▎        | 13/100 [00:44<04:58,  3.43s/it]

Epoch 13/100, Loss: 929.1254, Recon: 142.6935, KL: 7864.3188


 14%|█▍        | 14/100 [00:47<04:55,  3.43s/it]

Epoch 14/100, Loss: 931.4654, Recon: 143.5247, KL: 7879.4076


 15%|█▌        | 15/100 [00:51<04:52,  3.44s/it]

Epoch 15/100, Loss: 929.9656, Recon: 142.8737, KL: 7870.9194


 16%|█▌        | 16/100 [00:54<04:49,  3.44s/it]

Epoch 16/100, Loss: 930.8343, Recon: 142.6945, KL: 7881.3972


 17%|█▋        | 17/100 [00:58<04:46,  3.45s/it]

Epoch 17/100, Loss: 931.1044, Recon: 142.4506, KL: 7886.5376


 18%|█▊        | 18/100 [01:01<04:43,  3.46s/it]

Epoch 18/100, Loss: 933.0772, Recon: 143.0218, KL: 7900.5540


 19%|█▉        | 19/100 [01:05<04:40,  3.46s/it]

Epoch 19/100, Loss: 929.2501, Recon: 142.2766, KL: 7869.7347


 20%|██        | 20/100 [01:08<04:37,  3.47s/it]

Epoch 20/100, Loss: 930.9763, Recon: 143.0353, KL: 7879.4106


 21%|██        | 21/100 [01:12<04:34,  3.48s/it]

Epoch 21/100, Loss: 930.6695, Recon: 143.2789, KL: 7873.9056


 22%|██▏       | 22/100 [01:15<04:31,  3.48s/it]

Epoch 22/100, Loss: 930.6161, Recon: 144.0991, KL: 7865.1697


 23%|██▎       | 23/100 [01:19<04:28,  3.49s/it]

Epoch 23/100, Loss: 929.1857, Recon: 142.2489, KL: 7869.3674


 24%|██▍       | 24/100 [01:22<04:25,  3.50s/it]

Epoch 24/100, Loss: 932.6307, Recon: 142.5204, KL: 7901.1035


 25%|██▌       | 25/100 [01:26<04:22,  3.50s/it]

Epoch 25/100, Loss: 930.6281, Recon: 142.0378, KL: 7885.9029


 26%|██▌       | 26/100 [01:29<04:19,  3.50s/it]

Epoch 26/100, Loss: 929.5031, Recon: 142.4815, KL: 7870.2167


 27%|██▋       | 27/100 [01:33<04:15,  3.51s/it]

Epoch 27/100, Loss: 930.6729, Recon: 142.1355, KL: 7885.3740


 27%|██▋       | 27/100 [01:34<04:15,  3.50s/it]


KeyboardInterrupt: 

# Single-cell