# Imports

In [1]:
from GENIE3 import *
import sys, os
sys.path.append(os.getcwd())
sys.path.append('/scratch/ab9738/dfdl_imputation/')
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import re
from scipy import stats
import SERGIO.SERGIO.sergio as sergio
from sklearn.metrics import roc_auc_score
from copy import deepcopy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

# Load Data

In [2]:
ds1_clean = np.load('../SERGIO/imputation_data/DS1/DS6_clean_iter_0.npy')

In [3]:
ds1_expr = np.load('../SERGIO/imputation_data/DS1/DS6_expr_iter_0.npy')

# Load Simulation

In [4]:
def parse_dataset_name(folder_name):
    pattern1 = r'De-noised_(\d+)G_(\d+)T_(\d+)cPerT_dynamics_(\d+)_DS(\d+)'
    pattern2 = r'De-noised_(\d+)G_(\d+)T_(\d+)cPerT_(\d+)_DS(\d+)'
    match_p1 = re.match(pattern1, folder_name)
    match_p2 = re.match(pattern2, folder_name)
    if match_p1:
        return {
            'number_genes': int(match_p1.group(1)),
            'number_bins': int(match_p1.group(2)),
            'number_sc': int(match_p1.group(3)),
            'dynamics': int(match_p1.group(4)),
            'dataset_id': int(match_p1.group(5)),
            "pattern": "De-noised_{number_genes}G_{number_bins}T_{number_sc}cPerT_dynamics_{dynamics}_DS{dataset_id}"
        }
    if match_p2:
        return {
            'number_genes': int(match_p2.group(1)),
            'number_bins': int(match_p2.group(2)),
            'number_sc': int(match_p2.group(3)),
            'dynamics': int(match_p2.group(4)),
            'dataset_id': int(match_p2.group(5)),
            "pattern": "De-noised_{number_genes}G_{number_bins}T_{number_sc}cPerT_{dynamics}_DS{dataset_id}"
        }
    return

def get_datasets():
    datasets = []
    for folder_name in os.listdir('../SERGIO/data_sets'):
        dataset_info = parse_dataset_name(folder_name)
        if dataset_info:
            datasets.append(dataset_info)
    return sorted(datasets, key=lambda x: x['dataset_id'])

In [5]:
data_info = get_datasets()[0]

In [6]:
sim = sergio.sergio(
        number_genes=data_info["number_genes"],
        number_bins=data_info["number_bins"], 
        number_sc=data_info["number_sc"],
        noise_params=1,
        decays=0.8, 
        sampling_state=15,
        noise_type='dpd'
    )

# Ground Truth

In [7]:
target_file = '../SERGIO/data_sets/De-noised_100G_9T_300cPerT_4_DS1/Interaction_cID_4.txt'

In [8]:
gt = np.zeros((100,100))
f = open(target_file,'r')
Lines = f.readlines()
f.close()
for j in range(len(Lines)):
    line = Lines[j]
    line_list = line.split(',')
    target_index = int(float(line_list[0]))
    num_regs = int(float(line_list[1]))
    for i in range(num_regs):
        try:
            reg_index = int(float(line_list[i+2]))
            gt[reg_index,target_index] = 1 
        except:
            continue

# Metrics Evaluation

Clean Dataset

In [15]:
VIM_clean = GENIE3(np.transpose(ds1_clean), nthreads=80, ntrees=100, regulators='all', gene_names=[str(s) for s in range(np.transpose(ds1_clean).shape[1])])

Tree method: RF
K: sqrt
Number of trees: 100


running jobs on 80 threads
Elapsed time: 59.96 seconds


In [17]:
roc_auc_score(gt.flatten(),VIM_clean.flatten())

0.6898146600908079

Add outlier noise

In [9]:
expr_O = sim.outlier_effect(ds1_expr, outlier_prob = 0.01, mean = 5, scale = 1)

In [10]:
ds1_O = np.concatenate(expr_O, axis=1)

In [27]:
VIM_O = GENIE3(np.transpose(ds1_O), nthreads=80, ntrees=100, regulators='all', gene_names=[str(s) for s in range(np.transpose(ds1_O).shape[1])])

Tree method: RF
K: sqrt
Number of trees: 100


running jobs on 80 threads
Elapsed time: 60.18 seconds


In [28]:
roc_auc_score(gt.flatten(),VIM_O.flatten())

0.7012030543049436

Add library noise on top

In [11]:
libFactor, expr_O_L = sim.lib_size_effect(expr_O, mean = 4.5, scale = 0.7)

In [12]:
ds1_O_L = np.concatenate(expr_O_L, axis=1)

In [31]:
VIM_O_L = GENIE3(np.transpose(ds1_O_L), nthreads=80, ntrees=100, regulators='all', gene_names=[str(s) for s in range(np.transpose(ds1_O_L).shape[1])])

Tree method: RF
K: sqrt
Number of trees: 100


running jobs on 80 threads
Elapsed time: 61.82 seconds


In [32]:
roc_auc_score(gt.flatten(),VIM_O_L.flatten())

0.5561637535230657

Add dropouts on top

In [13]:
binary_ind = sim.dropout_indicator(expr_O_L, shape = 8, percentile = 45)
expr_O_L_D = np.multiply(binary_ind, expr_O_L)

In [14]:
ds1_O_L_D = np.concatenate(expr_O_L_D, axis=1)

In [35]:
VIM_O_L_D = GENIE3(np.transpose(ds1_O_L_D), nthreads=80, ntrees=100, regulators='all', gene_names=[str(s) for s in range(np.transpose(ds1_O_L_D).shape[1])])

Tree method: RF
K: sqrt
Number of trees: 100


running jobs on 80 threads
Elapsed time: 42.22 seconds


In [37]:
roc_auc_score(gt.flatten(),VIM_O_L_D.flatten())

0.5184400955504734

Convert to UMI Counts

In [15]:
expr_O_L_D_C = sim.convert_to_UMIcounts(expr_O_L_D)

In [16]:
ds1_O_L_D_C = np.concatenate(expr_O_L_D_C, axis=1)

In [40]:
VIM_O_L_D_C = GENIE3(np.transpose(ds1_O_L_D_C), nthreads=80, ntrees=100, regulators='all', gene_names=[str(s) for s in range(np.transpose(ds1_O_L_D_C).shape[1])])

Tree method: RF
K: sqrt
Number of trees: 100


running jobs on 80 threads
Elapsed time: 12.80 seconds


In [41]:
roc_auc_score(gt.flatten(),VIM_O_L_D_C.flatten())

0.43300764371959344

# Denoising

In [14]:
ds1_noisy = ds1_O_L_D_C

Normal Imputation Followed by Normalization

In [158]:
def zero_impute(ds1):
    ds1[ds1 == 0] = np.nan
    for i in range(9):
        ds1_cell_type = ds1[:,i*300:(i+1)*300]
        mean_array = np.nanmean(ds1_cell_type, axis=1)
        var_array = np.nanvar(ds1_cell_type, axis=1)
        for j in range(100):
            np.nan_to_num(ds1_cell_type[j,:],copy=False,nan=np.random.normal(loc=mean_array[j],scale=np.sqrt(var_array[j])))
    ds1[ds1<0] = 0.0
    np.nan_to_num(ds1,copy=False)
    return(ds1)

In [161]:
ds1_imputed = zero_impute(ds1_noisy.astype(float32))

  mean_array = np.nanmean(ds1_cell_type, axis=1)
  var_array = np.nanvar(ds1_cell_type, axis=1)


In [162]:
lib_depth_matrix = np.tile(np.sum(ds1_imputed, axis=0),(100,1))

In [163]:
ds1_normalized = ds1_imputed/lib_depth_matrix

In [164]:
VIM_normalized = GENIE3(np.transpose(ds1_normalized), nthreads=80, ntrees=100, regulators='all',\
                        gene_names=[str(s) for s in range(np.transpose(ds1_normalized).shape[1])])

Tree method: RF
K: sqrt
Number of trees: 100


running jobs on 80 threads
Elapsed time: 61.83 seconds


In [165]:
roc_auc_score(gt.flatten(),VIM_normalized.flatten())

0.54083016237533

In [166]:
# percentage recovery
(0.541-0.433)/(0.690-0.433)

0.4202334630350197

Dataset Substitution

In [179]:
ds1_noisy = ds1_O_L_D_C

In [180]:
def substitute_dataset(ds1):
    ds1[ds1 == 0] = np.nan
    for i in range(9):
        ds1_cell_type = ds1[:,i*300:(i+1)*300]
        mean_array = np.nanmean(ds1_cell_type, axis=1)
        var_array = np.nanvar(ds1_cell_type, axis=1)
        for j in range(100):
            ds1_cell_type[j,:] = np.random.normal(loc=mean_array[j],scale=np.sqrt(var_array[j]),size=300)
    ds1[ds1<0] = 0.0
    np.nan_to_num(ds1,copy=False)
    return(ds1)

In [181]:
ds1_substitute = substitute_dataset(ds1_noisy.astype(float32))

  mean_array = np.nanmean(ds1_cell_type, axis=1)
  var_array = np.nanvar(ds1_cell_type, axis=1)


In [182]:
VIM_substitute = GENIE3(np.transpose(ds1_substitute), nthreads=80, ntrees=100, regulators='all',\
                        gene_names=[str(s) for s in range(np.transpose(ds1_substitute).shape[1])])

Tree method: RF
K: sqrt
Number of trees: 100


running jobs on 80 threads
Elapsed time: 72.90 seconds


In [183]:
roc_auc_score(gt.flatten(),VIM_substitute.flatten())

0.5452106200436374

Model based dataset substitution

In [125]:
ds1_noisy = ds1_O_L_D_C

In [126]:
device = 'cpu'

In [127]:
class CNNMultiCTNet(nn.Module):
    def __init__(self):
        super(CNNMultiCTNet, self).__init__()
        
        self.num_ct = 9
        
        # Separate 1D convolution layers for each ct
        self.conv_layers = nn.ModuleList([nn.Conv1d(in_channels=2, out_channels=2, kernel_size=1) for _ in range(self.num_ct)])
        
        # Separate fully connected layers for mu_{ct,g} and sigma_{ct,g} for each ct
        self.fc_mu_layers = nn.ModuleList([nn.Linear(100, 100) for _ in range(self.num_ct)])
        self.fc_sigma_layers = nn.ModuleList([nn.Linear(100, 100) for _ in range(self.num_ct)])

    def forward(self, x):
        # Input shape: (batch_size, 2, 9, 100)
        batch_size = x.size(0)
        
        # Prepare lists to store results for each ct
        mu_g_hat_list = []
        sigma_g_hat_list = []
        
        for ct in range(self.num_ct):
            # Extract the mu_{ct,g} and sigma_{ct,g} for this ct (shape: (batch_size, 2, 100))
            x_ct = x[:, :, ct, :]  # (batch_size, 2, 100)

            # Pass through the convolution layer specific to this ct
            x_ct = self.conv_layers[ct](x_ct)  # (batch_size, 2, 100)
            
            # Separate the mu_{ct,g} and sigma_{ct,g}
            mu_ct_g = x_ct[:, 0, :]  # (batch_size, 100)
            sigma_ct_g = x_ct[:, 1, :]  # (batch_size, 100)

            # Pass through fully connected layers specific to this ct
            mu_ct_g_hat = self.fc_mu_layers[ct](mu_ct_g)  # (batch_size, 100)
            sigma_ct_g_hat = self.fc_sigma_layers[ct](sigma_ct_g)  # (batch_size, 100)
            
            # Apply activation (e.g., ReLU) if necessary
            mu_ct_g_hat = F.relu(mu_ct_g_hat)
            sigma_ct_g_hat = F.relu(sigma_ct_g_hat)
            
            # Sum the mu_ct_g_hat to normalize it
            sum_mu_ct_g_hat = torch.sum(mu_ct_g_hat, dim=1, keepdim=True)  # (batch_size, 1)
            
            # Scale the mu_ct_g_hat and sigma_ct_g_hat
            mu_ct_g_hat = mu_ct_g_hat / sum_mu_ct_g_hat  # Normalize mu_ct_g_hat so the sum is 1
            sigma_ct_g_hat = sigma_ct_g_hat / sum_mu_ct_g_hat  # Scale sigma_ct_g_hat with the same factor
            
            # Collect the results for each ct
            mu_g_hat_list.append(mu_ct_g_hat.unsqueeze(1))  # (batch_size, 1, 100)
            sigma_g_hat_list.append(sigma_ct_g_hat.unsqueeze(1))  # (batch_size, 1, 100)
        
        # Stack the results along the ct dimension
        mu_g_hat = torch.cat(mu_g_hat_list, dim=1)  # (batch_size, 9, 100)
        sigma_g_hat = torch.cat(sigma_g_hat_list, dim=1)  # (batch_size, 9, 100)
        
        # Combine mu and sigma along the 2nd dimension (channels)
        output = torch.stack([mu_g_hat, sigma_g_hat], dim=1)  # (batch_size, 2, 9, 100)
        
        return output

In [128]:
model = CNNMultiCTNet().to(device)
model.load_state_dict(torch.load('./model_ds1.pth'))

<All keys matched successfully>

In [129]:
x_means, x_stds = np.zeros((9,100)),np.zeros((9,100))

In [130]:
def find_x(ds1, x_means, x_stds):
    ds1[ds1 == 0] = np.nan
    for i in range(9):
        ds1_cell_type = ds1[:,i*300:(i+1)*300]
        x_means[i,:] = np.nanmean(ds1_cell_type, axis=1)
        x_stds[i,:] = np.nanstd(ds1_cell_type, axis=1)
    x = np.stack([x_means, x_stds], axis=0)
    x = np.expand_dims(x, axis=0)
    x = np.nan_to_num(x)
    x = torch.tensor(x,dtype=torch.float32).to(device)
    return(x)

In [131]:
x = find_x(ds1_noisy.astype(float32),x_means,x_stds)

In [132]:
y = model(x)

In [133]:
y = y.detach().cpu().numpy()

In [134]:
y.shape

(1, 2, 9, 100)

In [135]:
y_means = y[0,0,:,:]
y_stds = y[0,1,:,:]

In [136]:
def simulate_dataset(ds1,y_means,y_stds):
    ds = np.zeros_like(ds1)
    for i in range(9):
        ds_cell_type = ds[:,i*300:(i+1)*300]
        mean_array = y_means[i,:]
        std_array = y_stds[i,:]
        for j in range(100):
            ds_cell_type[j,:] = np.random.normal(loc=mean_array[j],scale=std_array[j],size=300)
    ds[ds<0] = 0.0
    np.nan_to_num(ds,copy=False)
    return(ds)

In [137]:
ds1_simulated = simulate_dataset(ds1_noisy.astype(float32), y_means, y_stds)

In [138]:
VIM_simulated = GENIE3(np.transpose(ds1_simulated), nthreads=80, ntrees=100, regulators='all',\
                        gene_names=[str(s) for s in range(np.transpose(ds1_simulated).shape[1])])

Tree method: RF
K: sqrt
Number of trees: 100


running jobs on 80 threads
Elapsed time: 36.27 seconds


In [139]:
roc_auc_score(gt.flatten(),VIM_simulated.flatten())

0.5381243047366234

In [94]:
# percentage recovery
(0.555-0.433)/(0.690-0.433)

0.474708171206226