In [1]:
import math
from tqdm import tqdm
import scvi
import torch
import numpy as np
import pytorch_lightning as pl
import os

import scanpy as sc

from pytorch_lightning.callbacks import TQDMProgressBar
from torchcfm.conditional_flow_matching import *
import scanpy as sc
from datamodules import  SCFMDataset, cfm_collate, StratifiedBatchSampler, ot_collate
from torch.utils.data import RandomSampler
from sc_etl_utils import *
from arch import *
import json

import scvi
import torch
import numpy as np
import pytorch_lightning as pl
import os

import scanpy as sc

from torchcfm.conditional_flow_matching import *
import scanpy as sc
import hashlib
from llm import MAE

In [2]:
# load some data
adata = sc.read_h5ad('/orcd/archive/abugoot/001/Projects/dlesman/datasets/satija_IFNB_HVG_and_perturbed_genes_raw.h5ad')

In [3]:
gene_map = {k: i for i, k in enumerate(adata.var.index)}
gene_map = gene_map | {'NT': max(gene_map.values()) + 1}
gene_unmap = {gene_map[k]: k for k in gene_map}
perts = adata.obs.gene.unique().map(gene_map)
adata.obs['pert_type'] = adata.obs.gene.map(gene_map)
pert_ids = np.array(adata.obs['pert_type'])
pert_mat = np.arange(pert_ids.max() + 1)[:, None]

In [4]:
cell_col = 'cell_type'
pert_col = 'pert_type'
control_pert, holdout_cells, holdout_perts = gene_map['NT'], ['HT29'], [gene_map['USP18']]

In [5]:
from sc_etl_utils import *
control_idx, pert_idx, eval_idx, eval_cell_idx, eval_pert_idx = get_train_eval_idxs(
    adata, control_pert, holdout_cells, holdout_perts, cell_col=cell_col, pert_col=pert_col
)

_, _, cell_types = get_identity_features(
    adata, cell_col=cell_col, pert_col=pert_col, cell_type_features=False
)

adata.obsm["standard"] = adata.X
X = adata.obsm["standard"]
X = X.toarray()
X = X / X.sum(axis=1)[:, None]
X = np.log(X * 10_000. + 1)

control_train, pert_train, pert_ids_train, control_cell_types, pert_cell_types, control_eval, pert_eval, pert_ids_eval = get_train_eval(
    X, pert_ids, cell_types, control_idx, pert_idx, eval_idx, eval_cell_idx, eval_pert_idx
)


Controls: 14582, Perturbations: 313080,  Eval: 880


In [6]:
train = np.vstack([control_train, pert_train])

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# Create a custom Dataset class
class NumpyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        return torch.from_numpy(self.data[index]).float()

    def __len__(self):
        return len(self.data)

# Create an instance of the dataset
dataset = NumpyDataset(train)

In [7]:
batch_size = 512
dset = SCFMDataset(
    control_train, pert_train, 
    pert_ids_train, pert_mat, 
    control_cell_types, pert_cell_types,
    batch_size=batch_size, size=X.shape[0]
)
ns = np.array([[t.shape[0] for t in ts] for ts in dset.target])
dl = torch.utils.data.DataLoader(
    dset, collate_fn=ot_collate, 
    batch_sampler=StratifiedBatchSampler(
        ns=ns, batch_size=512
    )
)

Strata probs [0.         0.0001086  0.00030344 0.00039606 0.00040565 0.00041523
 0.00042162 0.00043439 0.00045356 0.00045995 0.00049189 0.00052063
 0.00054938 0.00056535 0.00061007 0.00062284 0.00062604 0.00062604
 0.00063562 0.00068673 0.00075061 0.00076977 0.00077297 0.00077935
 0.00079852 0.0008081  0.00084323 0.0008624  0.00087837 0.00088156
 0.00088795 0.00088795 0.00089753 0.00089753 0.0009167  0.00092628
 0.00093906 0.00094864 0.00099016 0.00099016 0.00100294 0.00104127
 0.00104446 0.00105085 0.00109237 0.00110834 0.00110834 0.00111154
 0.00111473 0.00111473 0.0011307  0.0011339  0.00113709 0.00116584
 0.00117222 0.00117222 0.00121375 0.00121694 0.00122333 0.00122652
 0.00124569 0.00124888 0.00124888 0.00126485 0.00128082 0.00129999
 0.00130318 0.00130318 0.00131596 0.00131596 0.00132554 0.00133193
 0.00133512 0.0013447  0.00135109 0.00137026 0.00137984 0.00137984
 0.00139262 0.0014022  0.00140539 0.00140539 0.00141178 0.00141497
 0.00141817 0.00143733 0.00146288 0.00146608 0.00

In [8]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


class GeneEmbedding(torch.nn.Module):
    def __init__(self, input_dim, emb_dim=128):
        super().__init__()
        self.pos = torch.nn.Parameter(torch.zeros(1, input_dim, emb_dim))
        nn.init.normal_(self.pos)
        
        
class PertEmbedder(torch.nn.Module):
    def __init__(self, gene_embedding):
        super().__init__()
        _, input_dim, emb_dim = gene_embedding.pos.shape
        self.gene_embedding = gene_embedding
        self.pert_token = torch.nn.Parameter(torch.zeros(emb_dim))
        nn.init.normal_(self.pert_token)
        
    def forward(self, pert_index, pert_expression):
        pert_pos = self.gene_embedding.pos[:, pert_index][0]

        pert_embed_and_expr = torch.cat(
            (
                pert_pos + self.pert_token, 
                pert_expression.unsqueeze(-1)
            ), dim=-1
        )
        return pert_embed_and_expr.unsqueeze(1)
    
class CellEncoder(torch.nn.Module):
    def __init__(self, input_dim, latent_dim, hidden_dim, dropout_rate=0.1):
        super().__init__()
        self.input_dim = input_dim  
        self.latent_dim = latent_dim  # Dimension of the latent space
        self.dropout_rate = dropout_rate  
        

        # Encoder network definitions
        self.encoder_fc1 = nn.Linear(input_dim, hidden_dim, bias=False) 
        self.encoder_bn1 = nn.BatchNorm1d(hidden_dim)  
        self.encoder_fc2 = nn.Linear(hidden_dim, hidden_dim, bias=False) 
        self.encoder_bn2 = nn.BatchNorm1d(hidden_dim) 
        self.fc_mu = nn.Linear(hidden_dim, latent_dim, bias=True) 
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim, bias=True) 

    # Encoder function
    def forward(self, x):
        h = F.leaky_relu(self.encoder_bn1(self.encoder_fc1(x)))  
        h = F.dropout(h, p=self.dropout_rate, training=self.training)  
        h = F.leaky_relu(self.encoder_bn2(self.encoder_fc2(h)))  
        return nn.ELU()(self.fc_mu(h))# , self.fc_logvar(h)  


class ExprPred(torch.nn.Module):
    def __init__(self,
                 gene_embedding,
                 ff_dim=128,
                 ) -> None:
        super().__init__()
        
        _, _, emb_dim = gene_embedding.pos.shape
        self.gene_embedding = gene_embedding

        self.pred_expr = torch.nn.Sequential(
            torch.nn.Linear(emb_dim + emb_dim, ff_dim),
            torch.nn.SELU(),
            torch.nn.Linear(ff_dim, ff_dim),
            torch.nn.SELU(),
            torch.nn.Linear(ff_dim, ff_dim),
            torch.nn.SELU(),
            torch.nn.Linear(ff_dim, ff_dim),
            torch.nn.SELU(),
            torch.nn.Linear(ff_dim, ff_dim),
            torch.nn.SELU(),
            torch.nn.Linear(ff_dim, 1),
            torch.nn.ELU()
        ) 
        
        self.pred_bin = torch.nn.Sequential(
            torch.nn.Linear(emb_dim + emb_dim, ff_dim),
            torch.nn.SELU(),
            torch.nn.Linear(ff_dim, ff_dim),
            torch.nn.SELU(),
            torch.nn.Linear(ff_dim, ff_dim),
            torch.nn.SELU(),
            torch.nn.Linear(ff_dim, ff_dim),
            torch.nn.SELU(),
            torch.nn.Linear(ff_dim, ff_dim),
            torch.nn.SELU(),
            torch.nn.Linear(ff_dim, 1),
            torch.nn.Sigmoid()
        ) 

    def forward(self, cell_embedding, pred_idx):
        embed_and_cell_embed = torch.cat(
            (
            torch.tile(cell_embedding.unsqueeze(1), (1, idx.shape[0], 1)),
            torch.tile(self.gene_embedding.pos[:, pred_idx], (cell_embedding.shape[0], 1, 1))
            ), dim=-1
        )
        pred_expr = self.pred_expr(embed_and_cell_embed) + 1
        pred_bin = self.pred_bin(embed_and_cell_embed)
        
        return pred_expr.squeeze(), pred_bin.squeeze()

class CMLP(pl.LightningModule):
    def __init__(self, feat_dim, cond_dim, out_dim=None, w1=128, w2=128, n_combo_layer=4, n_cond_layer=3, time_varying=False):
        super().__init__()
        self.time_varying = time_varying
        if out_dim is None:
            out_dim = feat_dim
        self.combo_net = torch.nn.Sequential(
            torch.nn.Linear(feat_dim + (1 if time_varying else 0) + cond_dim, w1), torch.nn.SELU(),
            *([torch.nn.Linear(w1, w1), torch.nn.SELU()] * n_combo_layer),
            torch.nn.Linear(w1, out_dim)
        )
        self.cond = None
        
    def forward(self, x, cond=None):
        if cond is None:
            cond = self.cond
        # cond = self.cond_net(cond)
        return self.combo_net(torch.cat([x, cond], dim=-1))

    
class BernoulliSampleLayer(nn.Module):
    def __init__(self):
        super(BernoulliSampleLayer, self).__init__()

    def forward(self, probs):
        sample = torch.bernoulli(probs)
        return sample + probs - probs.detach()    

    
class MAE(torch.nn.Module):
    def __init__(self,
                 input_dim,
                 ff_dim=128,
                 emb_dim=128,
                 encoder_layer=6,
                 ) -> None:
        super().__init__()
        
        self.gene_embedding = GeneEmbedding(input_dim=input_dim, emb_dim=emb_dim)
        self.pert_embedding = PertEmbedder(self.gene_embedding)
        self.encoder = CellEncoder(
            input_dim, emb_dim, hidden_dim=ff_dim
        )
        self.recon = ExprPred(self.gene_embedding)
        self.sparse_sampler = BernoulliSampleLayer()
        
        self.flow = CMLP(feat_dim=emb_dim, cond_dim=emb_dim, time_varying=True, w1=ff_dim)
        self.cfm_sampler = ExactOptimalTransportConditionalFlowMatcher(sigma=0.1)

    def forward(self, expr):
        cell_embedding = self.encoder(expr)
        # cell_embedding = features.mean(axis=1)
        return cell_embedding #, pred_expr.squeeze(), pred_bin.squeeze()
    
    def sparsify(self, pred_expr, pred_bin):
        sparsity = self.sparse_sampler(pred_bin)
        pred_expr *= sparsity
        return pred_expr
    
    def ae_loss(self, batch_emb, batch, gene_ids, lambd=0.5, return_recon=False):
        batch_bin = (batch > 0).float()
        batch_recon, batch_bin_recon = self.recon(batch_emb, gene_ids)
        recon_loss = torch.mean(batch_bin * (batch_recon - batch[:, gene_ids])**2) # / minibatch_size
        bin_loss = F.binary_cross_entropy(batch_bin_recon, batch_bin[:, gene_ids])
        loss = lambd * recon_loss + (1 - lambd) * bin_loss
        if return_recon:
            return loss, batch_recon, batch_bin_recon
        return loss
    
    def flow_loss(self, source_emb, target_emb, cond):
        t, xt, ut = self.cfm_sampler.sample_location_and_conditional_flow(
            source_emb, target_emb
        )

        inp = torch.cat([xt, t[:, None]], dim=-1)
        vt = model.flow(inp, cond)
        return torch.nn.functional.mse_loss(vt, ut) 


In [9]:
model = MAE(
    X.shape[1], 
    emb_dim=256, 
    encoder_layer=4,
    ff_dim=256
)
# model = torch.load(f"llm/v9")
device = 'cuda'
# device = 'cpu'
model = model.to(device)

In [10]:
base_learning_rate = 2e-4
weight_decay=0.0
total_epoch = 1000
warmup_epoch = 5
save_dir = "ae/v"
optim = torch.optim.Adam(model.parameters(), lr=base_learning_rate)
lr_func = lambda epoch: min((epoch + 1) / (warmup_epoch + 1e-5), 0.5 * (math.cos(epoch / total_epoch * math.pi) + 1))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func, verbose=True)

Adjusting learning rate of group 0 to 4.0000e-05.


In [11]:
use_sparsity_loss = False
use_mask_task = True
use_active_weights = False
lr_step = 32
minibatch_size = 128

step_count = 0
optim.zero_grad()
pert_task = 0
idx = torch.arange(train.shape[1]).to(device)

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=minibatch_size, shuffle=True)

for e in range(100):
    model.train()
    losses = {'loss': [], 'control': [], 'pert': [], 'flow': []}
    for batch in (pbar := tqdm(iter(dataloader))):
        step_count += 1
        batch = batch.to(device)
        idx = torch.arange(batch.shape[1]).to(device)
        batch_emb = model(batch)
        loss, batch_recon, batch_bin_recon = model.ae_loss(batch_emb, batch, idx, return_recon=True)
        loss.backward()
        optim.step()
        optim.zero_grad()
        with torch.no_grad():
            batch_recon = model.sparsify(batch_recon, batch_bin_recon)
            recon_loss = torch.mean((batch_recon - batch)**2)
        losses['loss'].append(recon_loss.item())
        if step_count % lr_step == 0:
            lr_scheduler.step()
        pbar.set_description(
            f"loss: {np.array(losses['loss'])[-lr_step:].mean():.3f}"
        )
    
    avg_loss = sum(losses['loss']) / len(losses['loss'])
    torch.save(model, f"{save_dir}{e}")
    # writer.add_scalar('mae_loss', avg_loss, global_step=e)
    print(f'In epoch {e}, average traning loss is {avg_loss}.')

loss: 1.671:   1%|██▌                                                                                                                                                                                    | 35/2560 [00:02<02:44, 15.31it/s]

Adjusting learning rate of group 0 to 8.0000e-05.


loss: 2.296:   3%|████▊                                                                                                                                                                                  | 67/2560 [00:04<02:42, 15.33it/s]

Adjusting learning rate of group 0 to 1.2000e-04.


loss: 2.204:   4%|███████                                                                                                                                                                                | 99/2560 [00:06<02:40, 15.32it/s]

Adjusting learning rate of group 0 to 1.6000e-04.


loss: 2.178:   5%|█████████▎                                                                                                                                                                            | 131/2560 [00:08<02:38, 15.28it/s]

Adjusting learning rate of group 0 to 1.9999e-04.


loss: 2.104:   6%|███████████▌                                                                                                                                                                          | 163/2560 [00:10<02:37, 15.26it/s]

Adjusting learning rate of group 0 to 1.9999e-04.


loss: 2.034:   8%|█████████████▊                                                                                                                                                                        | 195/2560 [00:13<02:35, 15.20it/s]

Adjusting learning rate of group 0 to 1.9998e-04.


loss: 1.965:   9%|████████████████▏                                                                                                                                                                     | 227/2560 [00:15<02:33, 15.22it/s]

Adjusting learning rate of group 0 to 1.9998e-04.


loss: 1.892:  10%|██████████████████▍                                                                                                                                                                   | 259/2560 [00:17<02:31, 15.20it/s]

Adjusting learning rate of group 0 to 1.9997e-04.


loss: 1.828:  11%|████████████████████▋                                                                                                                                                                 | 291/2560 [00:19<02:29, 15.18it/s]

Adjusting learning rate of group 0 to 1.9996e-04.


loss: 1.783:  13%|██████████████████████▉                                                                                                                                                               | 323/2560 [00:21<02:27, 15.16it/s]

Adjusting learning rate of group 0 to 1.9995e-04.


loss: 1.710:  14%|█████████████████████████▏                                                                                                                                                            | 355/2560 [00:23<02:25, 15.12it/s]

Adjusting learning rate of group 0 to 1.9994e-04.


loss: 1.633:  15%|███████████████████████████▌                                                                                                                                                          | 387/2560 [00:25<02:23, 15.11it/s]

Adjusting learning rate of group 0 to 1.9993e-04.


loss: 1.522:  16%|█████████████████████████████▊                                                                                                                                                        | 419/2560 [00:27<02:21, 15.10it/s]

Adjusting learning rate of group 0 to 1.9992e-04.


loss: 1.414:  18%|████████████████████████████████                                                                                                                                                      | 451/2560 [00:29<02:19, 15.11it/s]

Adjusting learning rate of group 0 to 1.9990e-04.


loss: 1.311:  19%|██████████████████████████████████▎                                                                                                                                                   | 483/2560 [00:32<02:17, 15.07it/s]

Adjusting learning rate of group 0 to 1.9989e-04.


loss: 1.243:  20%|████████████████████████████████████▌                                                                                                                                                 | 515/2560 [00:34<02:15, 15.04it/s]

Adjusting learning rate of group 0 to 1.9987e-04.


loss: 1.201:  21%|██████████████████████████████████████▉                                                                                                                                               | 547/2560 [00:36<02:13, 15.05it/s]

Adjusting learning rate of group 0 to 1.9986e-04.


loss: 1.156:  23%|█████████████████████████████████████████▏                                                                                                                                            | 579/2560 [00:38<02:11, 15.05it/s]

Adjusting learning rate of group 0 to 1.9984e-04.


loss: 1.150:  24%|███████████████████████████████████████████▍                                                                                                                                          | 611/2560 [00:40<02:09, 15.03it/s]

Adjusting learning rate of group 0 to 1.9982e-04.


loss: 1.116:  25%|█████████████████████████████████████████████▋                                                                                                                                        | 643/2560 [00:42<02:07, 15.02it/s]

Adjusting learning rate of group 0 to 1.9980e-04.


loss: 1.115:  26%|███████████████████████████████████████████████▊                                                                                                                                      | 673/2560 [00:44<02:05, 15.02it/s]

Adjusting learning rate of group 0 to 1.9978e-04.


loss: 1.109:  28%|██████████████████████████████████████████████████▎                                                                                                                                   | 707/2560 [00:46<02:03, 15.00it/s]

Adjusting learning rate of group 0 to 1.9976e-04.


loss: 1.100:  29%|████████████████████████████████████████████████████▌                                                                                                                                 | 739/2560 [00:49<02:01, 14.99it/s]

Adjusting learning rate of group 0 to 1.9974e-04.


loss: 1.093:  30%|██████████████████████████████████████████████████████▋                                                                                                                               | 769/2560 [00:51<01:59, 14.98it/s]

Adjusting learning rate of group 0 to 1.9972e-04.


loss: 1.085:  31%|████████████████████████████████████████████████████████▉                                                                                                                             | 801/2560 [00:53<01:57, 14.95it/s]

Adjusting learning rate of group 0 to 1.9969e-04.


loss: 1.077:  33%|███████████████████████████████████████████████████████████▏                                                                                                                          | 833/2560 [00:55<01:55, 14.90it/s]

Adjusting learning rate of group 0 to 1.9967e-04.


loss: 1.078:  34%|█████████████████████████████████████████████████████████████▍                                                                                                                        | 865/2560 [00:57<01:53, 14.92it/s]

Adjusting learning rate of group 0 to 1.9964e-04.


loss: 1.080:  35%|███████████████████████████████████████████████████████████████▊                                                                                                                      | 897/2560 [00:59<01:51, 14.90it/s]

Adjusting learning rate of group 0 to 1.9961e-04.


loss: 1.075:  36%|██████████████████████████████████████████████████████████████████                                                                                                                    | 929/2560 [01:01<01:49, 14.89it/s]

Adjusting learning rate of group 0 to 1.9959e-04.


loss: 1.064:  38%|████████████████████████████████████████████████████████████████████▎                                                                                                                 | 961/2560 [01:04<01:47, 14.89it/s]

Adjusting learning rate of group 0 to 1.9956e-04.


loss: 1.058:  39%|██████████████████████████████████████████████████████████████████████▌                                                                                                               | 993/2560 [01:06<01:45, 14.88it/s]

Adjusting learning rate of group 0 to 1.9953e-04.


loss: 1.054:  40%|████████████████████████████████████████████████████████████████████████▍                                                                                                            | 1025/2560 [01:08<01:43, 14.89it/s]

Adjusting learning rate of group 0 to 1.9950e-04.


loss: 1.070:  41%|██████████████████████████████████████████████████████████████████████████▋                                                                                                          | 1057/2560 [01:10<01:41, 14.86it/s]

Adjusting learning rate of group 0 to 1.9946e-04.


loss: 1.056:  43%|████████████████████████████████████████████████████████████████████████████▉                                                                                                        | 1089/2560 [01:12<01:38, 14.87it/s]

Adjusting learning rate of group 0 to 1.9943e-04.


loss: 1.049:  44%|███████████████████████████████████████████████████████████████████████████████▎                                                                                                     | 1121/2560 [01:14<01:36, 14.88it/s]

Adjusting learning rate of group 0 to 1.9940e-04.


loss: 1.057:  45%|█████████████████████████████████████████████████████████████████████████████████▌                                                                                                   | 1153/2560 [01:16<01:34, 14.85it/s]

Adjusting learning rate of group 0 to 1.9936e-04.


loss: 1.055:  46%|███████████████████████████████████████████████████████████████████████████████████▊                                                                                                 | 1185/2560 [01:19<01:32, 14.87it/s]

Adjusting learning rate of group 0 to 1.9933e-04.


loss: 1.055:  48%|██████████████████████████████████████████████████████████████████████████████████████                                                                                               | 1217/2560 [01:21<01:30, 14.86it/s]

Adjusting learning rate of group 0 to 1.9929e-04.


loss: 1.047:  49%|████████████████████████████████████████████████████████████████████████████████████████▎                                                                                            | 1249/2560 [01:23<01:28, 14.86it/s]

Adjusting learning rate of group 0 to 1.9925e-04.


loss: 1.055:  50%|██████████████████████████████████████████████████████████████████████████████████████████▌                                                                                          | 1281/2560 [01:25<01:26, 14.86it/s]

Adjusting learning rate of group 0 to 1.9921e-04.


loss: 1.056:  51%|████████████████████████████████████████████████████████████████████████████████████████████▊                                                                                        | 1313/2560 [01:27<01:23, 14.86it/s]

Adjusting learning rate of group 0 to 1.9917e-04.


loss: 1.063:  53%|███████████████████████████████████████████████████████████████████████████████████████████████                                                                                      | 1345/2560 [01:29<01:21, 14.85it/s]

Adjusting learning rate of group 0 to 1.9913e-04.


loss: 1.053:  54%|█████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                                   | 1377/2560 [01:32<01:19, 14.84it/s]

Adjusting learning rate of group 0 to 1.9909e-04.


loss: 1.050:  55%|███████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                                 | 1409/2560 [01:34<01:17, 14.85it/s]

Adjusting learning rate of group 0 to 1.9905e-04.


loss: 1.048:  56%|█████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                               | 1441/2560 [01:36<01:15, 14.85it/s]

Adjusting learning rate of group 0 to 1.9900e-04.


loss: 1.050:  58%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                            | 1473/2560 [01:38<01:13, 14.83it/s]

Adjusting learning rate of group 0 to 1.9896e-04.


loss: 1.039:  59%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                          | 1505/2560 [01:40<01:11, 14.83it/s]

Adjusting learning rate of group 0 to 1.9891e-04.


loss: 1.049:  60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                        | 1537/2560 [01:42<01:08, 14.85it/s]

Adjusting learning rate of group 0 to 1.9887e-04.


loss: 1.036:  61%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                      | 1569/2560 [01:44<01:06, 14.82it/s]

Adjusting learning rate of group 0 to 1.9882e-04.


loss: 1.043:  63%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                   | 1601/2560 [01:47<01:04, 14.81it/s]

Adjusting learning rate of group 0 to 1.9877e-04.


loss: 1.040:  64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                 | 1633/2560 [01:49<01:02, 14.79it/s]

Adjusting learning rate of group 0 to 1.9872e-04.


loss: 1.039:  65%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                               | 1665/2560 [01:51<01:00, 14.81it/s]

Adjusting learning rate of group 0 to 1.9867e-04.


loss: 1.037:  66%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                             | 1697/2560 [01:53<00:58, 14.82it/s]

Adjusting learning rate of group 0 to 1.9862e-04.


loss: 1.036:  68%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                          | 1729/2560 [01:55<00:56, 14.80it/s]

Adjusting learning rate of group 0 to 1.9856e-04.


loss: 1.033:  69%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                        | 1761/2560 [01:57<00:53, 14.81it/s]

Adjusting learning rate of group 0 to 1.9851e-04.


loss: 1.041:  70%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                      | 1793/2560 [02:00<00:51, 14.78it/s]

Adjusting learning rate of group 0 to 1.9846e-04.


loss: 1.036:  71%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 1825/2560 [02:02<00:49, 14.80it/s]

Adjusting learning rate of group 0 to 1.9840e-04.


loss: 1.038:  73%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                 | 1857/2560 [02:04<00:47, 14.81it/s]

Adjusting learning rate of group 0 to 1.9834e-04.


loss: 1.036:  74%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                               | 1889/2560 [02:06<00:45, 14.81it/s]

Adjusting learning rate of group 0 to 1.9829e-04.


loss: 1.033:  75%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                             | 1921/2560 [02:08<00:43, 14.79it/s]

Adjusting learning rate of group 0 to 1.9823e-04.


loss: 1.029:  76%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                           | 1953/2560 [02:10<00:40, 14.81it/s]

Adjusting learning rate of group 0 to 1.9817e-04.


loss: 1.025:  78%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                        | 1985/2560 [02:13<00:38, 14.82it/s]

Adjusting learning rate of group 0 to 1.9811e-04.


loss: 1.027:  79%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                      | 2017/2560 [02:15<00:36, 14.82it/s]

Adjusting learning rate of group 0 to 1.9805e-04.


loss: 1.030:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                    | 2049/2560 [02:17<00:34, 14.80it/s]

Adjusting learning rate of group 0 to 1.9799e-04.


loss: 1.027:  81%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                 | 2081/2560 [02:19<00:32, 14.82it/s]

Adjusting learning rate of group 0 to 1.9792e-04.


loss: 1.022:  83%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                               | 2113/2560 [02:21<00:30, 14.79it/s]

Adjusting learning rate of group 0 to 1.9786e-04.


loss: 1.032:  84%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                             | 2145/2560 [02:23<00:28, 14.79it/s]

Adjusting learning rate of group 0 to 1.9779e-04.


loss: 1.022:  85%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                           | 2177/2560 [02:26<00:25, 14.80it/s]

Adjusting learning rate of group 0 to 1.9773e-04.


loss: 1.033:  86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                        | 2209/2560 [02:28<00:23, 14.80it/s]

Adjusting learning rate of group 0 to 1.9766e-04.


loss: 1.020:  88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                      | 2241/2560 [02:30<00:21, 14.80it/s]

Adjusting learning rate of group 0 to 1.9759e-04.


loss: 1.025:  89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                    | 2273/2560 [02:32<00:19, 14.78it/s]

Adjusting learning rate of group 0 to 1.9752e-04.


loss: 1.017:  90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                  | 2305/2560 [02:34<00:17, 14.81it/s]

Adjusting learning rate of group 0 to 1.9745e-04.


loss: 1.023:  91%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏               | 2337/2560 [02:36<00:15, 14.79it/s]

Adjusting learning rate of group 0 to 1.9738e-04.


loss: 1.014:  93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍             | 2369/2560 [02:39<00:12, 14.79it/s]

Adjusting learning rate of group 0 to 1.9731e-04.


loss: 1.012:  94%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊           | 2401/2560 [02:41<00:10, 14.77it/s]

Adjusting learning rate of group 0 to 1.9724e-04.


loss: 1.022:  95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████         | 2433/2560 [02:43<00:08, 14.79it/s]

Adjusting learning rate of group 0 to 1.9716e-04.


loss: 1.022:  96%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎      | 2465/2560 [02:45<00:06, 14.78it/s]

Adjusting learning rate of group 0 to 1.9709e-04.


loss: 1.014:  98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌    | 2497/2560 [02:47<00:04, 14.80it/s]

Adjusting learning rate of group 0 to 1.9701e-04.


loss: 1.009:  99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊  | 2529/2560 [02:49<00:02, 14.78it/s]

Adjusting learning rate of group 0 to 1.9694e-04.


loss: 1.025: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2560/2560 [02:51<00:00, 14.89it/s]


Adjusting learning rate of group 0 to 1.9686e-04.
In epoch 0, average traning loss is 1.2022667035693302.


loss: 1.014:   1%|██▍                                                                                                                                                                                    | 34/2560 [00:02<02:50, 14.79it/s]

Adjusting learning rate of group 0 to 1.9678e-04.


loss: 1.007:   3%|████▋                                                                                                                                                                                  | 66/2560 [00:04<02:48, 14.81it/s]

Adjusting learning rate of group 0 to 1.9670e-04.


loss: 1.014:   4%|███████                                                                                                                                                                                | 98/2560 [00:06<02:46, 14.81it/s]

Adjusting learning rate of group 0 to 1.9662e-04.


loss: 1.013:   5%|█████████▏                                                                                                                                                                            | 130/2560 [00:08<02:44, 14.82it/s]

Adjusting learning rate of group 0 to 1.9654e-04.


loss: 1.014:   6%|███████████▌                                                                                                                                                                          | 162/2560 [00:10<02:42, 14.80it/s]

Adjusting learning rate of group 0 to 1.9646e-04.


loss: 1.011:   8%|█████████████▊                                                                                                                                                                        | 194/2560 [00:13<02:39, 14.79it/s]

Adjusting learning rate of group 0 to 1.9637e-04.


loss: 1.007:   9%|████████████████                                                                                                                                                                      | 226/2560 [00:15<02:37, 14.81it/s]

Adjusting learning rate of group 0 to 1.9629e-04.


loss: 1.008:  10%|██████████████████▎                                                                                                                                                                   | 258/2560 [00:17<02:35, 14.80it/s]

Adjusting learning rate of group 0 to 1.9620e-04.


loss: 1.012:  11%|████████████████████▌                                                                                                                                                                 | 290/2560 [00:19<02:33, 14.82it/s]

Adjusting learning rate of group 0 to 1.9612e-04.


loss: 1.001:  13%|██████████████████████▉                                                                                                                                                               | 322/2560 [00:21<02:31, 14.81it/s]

Adjusting learning rate of group 0 to 1.9603e-04.


loss: 1.003:  14%|█████████████████████████▏                                                                                                                                                            | 354/2560 [00:23<02:29, 14.79it/s]

Adjusting learning rate of group 0 to 1.9594e-04.


loss: 1.004:  15%|███████████████████████████▍                                                                                                                                                          | 386/2560 [00:26<02:27, 14.79it/s]

Adjusting learning rate of group 0 to 1.9585e-04.


loss: 1.006:  16%|█████████████████████████████▋                                                                                                                                                        | 418/2560 [00:28<02:24, 14.79it/s]

Adjusting learning rate of group 0 to 1.9576e-04.


loss: 1.001:  18%|███████████████████████████████▉                                                                                                                                                      | 450/2560 [00:30<02:22, 14.81it/s]

Adjusting learning rate of group 0 to 1.9567e-04.


loss: 1.000:  19%|██████████████████████████████████▎                                                                                                                                                   | 482/2560 [00:32<02:20, 14.81it/s]

Adjusting learning rate of group 0 to 1.9558e-04.


loss: 1.000:  20%|████████████████████████████████████▌                                                                                                                                                 | 514/2560 [00:34<02:18, 14.79it/s]

Adjusting learning rate of group 0 to 1.9549e-04.


loss: 1.007:  21%|██████████████████████████████████████▊                                                                                                                                               | 546/2560 [00:36<02:16, 14.78it/s]

Adjusting learning rate of group 0 to 1.9539e-04.


loss: 1.007:  23%|█████████████████████████████████████████                                                                                                                                             | 578/2560 [00:39<02:14, 14.77it/s]

Adjusting learning rate of group 0 to 1.9530e-04.


loss: 0.999:  24%|███████████████████████████████████████████▎                                                                                                                                          | 610/2560 [00:41<02:11, 14.81it/s]

Adjusting learning rate of group 0 to 1.9520e-04.


loss: 1.003:  25%|█████████████████████████████████████████████▋                                                                                                                                        | 642/2560 [00:43<02:09, 14.79it/s]

Adjusting learning rate of group 0 to 1.9511e-04.


loss: 1.005:  26%|███████████████████████████████████████████████▉                                                                                                                                      | 674/2560 [00:45<02:07, 14.76it/s]

Adjusting learning rate of group 0 to 1.9501e-04.


loss: 0.992:  28%|██████████████████████████████████████████████████▏                                                                                                                                   | 706/2560 [00:47<02:05, 14.80it/s]

Adjusting learning rate of group 0 to 1.9491e-04.


loss: 1.001:  29%|████████████████████████████████████████████████████▍                                                                                                                                 | 738/2560 [00:49<02:03, 14.78it/s]

Adjusting learning rate of group 0 to 1.9481e-04.


loss: 1.000:  30%|██████████████████████████████████████████████████████▋                                                                                                                               | 770/2560 [00:52<02:01, 14.78it/s]

Adjusting learning rate of group 0 to 1.9471e-04.


loss: 0.997:  31%|█████████████████████████████████████████████████████████                                                                                                                             | 802/2560 [00:54<01:58, 14.80it/s]

Adjusting learning rate of group 0 to 1.9461e-04.


loss: 0.990:  33%|███████████████████████████████████████████████████████████▎                                                                                                                          | 834/2560 [00:56<01:56, 14.79it/s]

Adjusting learning rate of group 0 to 1.9451e-04.


loss: 0.999:  34%|█████████████████████████████████████████████████████████████▌                                                                                                                        | 866/2560 [00:58<01:54, 14.80it/s]

Adjusting learning rate of group 0 to 1.9440e-04.


loss: 1.002:  35%|███████████████████████████████████████████████████████████████▊                                                                                                                      | 898/2560 [01:00<01:52, 14.78it/s]

Adjusting learning rate of group 0 to 1.9430e-04.


loss: 1.006:  36%|██████████████████████████████████████████████████████████████████                                                                                                                    | 930/2560 [01:02<01:50, 14.78it/s]

Adjusting learning rate of group 0 to 1.9419e-04.


loss: 0.994:  38%|████████████████████████████████████████████████████████████████████▍                                                                                                                 | 962/2560 [01:05<01:47, 14.80it/s]

Adjusting learning rate of group 0 to 1.9409e-04.


loss: 0.993:  39%|██████████████████████████████████████████████████████████████████████▋                                                                                                               | 994/2560 [01:07<01:46, 14.77it/s]

Adjusting learning rate of group 0 to 1.9398e-04.


loss: 0.985:  40%|████████████████████████████████████████████████████████████████████████▌                                                                                                            | 1026/2560 [01:09<01:43, 14.76it/s]

Adjusting learning rate of group 0 to 1.9387e-04.


loss: 0.992:  41%|██████████████████████████████████████████████████████████████████████████▊                                                                                                          | 1058/2560 [01:11<01:41, 14.80it/s]

Adjusting learning rate of group 0 to 1.9376e-04.


loss: 0.982:  43%|█████████████████████████████████████████████████████████████████████████████                                                                                                        | 1090/2560 [01:13<01:39, 14.78it/s]

Adjusting learning rate of group 0 to 1.9365e-04.


loss: 0.988:  44%|███████████████████████████████████████████████████████████████████████████████▎                                                                                                     | 1122/2560 [01:15<01:37, 14.80it/s]

Adjusting learning rate of group 0 to 1.9354e-04.


loss: 0.992:  45%|█████████████████████████████████████████████████████████████████████████████████▌                                                                                                   | 1154/2560 [01:18<01:35, 14.80it/s]

Adjusting learning rate of group 0 to 1.9343e-04.


loss: 0.990:  46%|███████████████████████████████████████████████████████████████████████████████████▊                                                                                                 | 1186/2560 [01:20<01:33, 14.76it/s]

Adjusting learning rate of group 0 to 1.9332e-04.


loss: 0.993:  48%|██████████████████████████████████████████████████████████████████████████████████████                                                                                               | 1218/2560 [01:22<01:30, 14.77it/s]

Adjusting learning rate of group 0 to 1.9321e-04.


loss: 0.993:  49%|████████████████████████████████████████████████████████████████████████████████████████▍                                                                                            | 1250/2560 [01:24<01:28, 14.77it/s]

Adjusting learning rate of group 0 to 1.9309e-04.


loss: 0.982:  50%|██████████████████████████████████████████████████████████████████████████████████████████▋                                                                                          | 1282/2560 [01:26<01:26, 14.78it/s]

Adjusting learning rate of group 0 to 1.9298e-04.


loss: 0.989:  51%|████████████████████████████████████████████████████████████████████████████████████████████▉                                                                                        | 1314/2560 [01:28<01:24, 14.77it/s]

Adjusting learning rate of group 0 to 1.9286e-04.


loss: 0.979:  53%|███████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                     | 1346/2560 [01:30<01:22, 14.77it/s]

Adjusting learning rate of group 0 to 1.9274e-04.


loss: 0.990:  54%|█████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                   | 1378/2560 [01:33<01:19, 14.80it/s]

Adjusting learning rate of group 0 to 1.9263e-04.


loss: 0.980:  55%|███████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                 | 1410/2560 [01:35<01:17, 14.77it/s]

Adjusting learning rate of group 0 to 1.9251e-04.


loss: 0.994:  56%|█████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                               | 1442/2560 [01:37<01:15, 14.79it/s]

Adjusting learning rate of group 0 to 1.9239e-04.


loss: 0.976:  58%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                            | 1474/2560 [01:39<01:13, 14.77it/s]

Adjusting learning rate of group 0 to 1.9227e-04.


loss: 0.985:  59%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                                          | 1506/2560 [01:41<01:11, 14.79it/s]

Adjusting learning rate of group 0 to 1.9215e-04.


loss: 0.972:  60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                        | 1538/2560 [01:43<01:09, 14.79it/s]

Adjusting learning rate of group 0 to 1.9202e-04.


loss: 0.971:  61%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                      | 1570/2560 [01:46<01:06, 14.79it/s]

Adjusting learning rate of group 0 to 1.9190e-04.


loss: 0.983:  63%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                   | 1602/2560 [01:48<01:04, 14.77it/s]

Adjusting learning rate of group 0 to 1.9178e-04.


loss: 0.993:  64%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                                 | 1634/2560 [01:50<01:02, 14.78it/s]

Adjusting learning rate of group 0 to 1.9165e-04.


loss: 0.984:  65%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                               | 1666/2560 [01:52<01:00, 14.77it/s]

Adjusting learning rate of group 0 to 1.9152e-04.


loss: 0.981:  66%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                             | 1698/2560 [01:54<00:58, 14.78it/s]

Adjusting learning rate of group 0 to 1.9140e-04.


loss: 0.976:  68%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                          | 1730/2560 [01:56<00:56, 14.80it/s]

Adjusting learning rate of group 0 to 1.9127e-04.


loss: 0.983:  69%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                        | 1762/2560 [01:59<00:53, 14.80it/s]

Adjusting learning rate of group 0 to 1.9114e-04.


loss: 0.974:  70%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                      | 1794/2560 [02:01<00:51, 14.78it/s]

Adjusting learning rate of group 0 to 1.9101e-04.


loss: 0.980:  71%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                    | 1826/2560 [02:03<00:49, 14.78it/s]

Adjusting learning rate of group 0 to 1.9088e-04.


loss: 0.986:  73%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                 | 1858/2560 [02:05<00:47, 14.76it/s]

Adjusting learning rate of group 0 to 1.9075e-04.


loss: 0.979:  74%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                               | 1890/2560 [02:07<00:45, 14.77it/s]

Adjusting learning rate of group 0 to 1.9062e-04.


loss: 0.976:  75%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                             | 1922/2560 [02:09<00:43, 14.77it/s]

Adjusting learning rate of group 0 to 1.9048e-04.


loss: 0.982:  76%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                          | 1954/2560 [02:12<00:40, 14.79it/s]

Adjusting learning rate of group 0 to 1.9035e-04.


loss: 0.977:  78%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                        | 1986/2560 [02:14<00:38, 14.78it/s]

Adjusting learning rate of group 0 to 1.9021e-04.


loss: 0.967:  79%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                      | 2018/2560 [02:16<00:36, 14.78it/s]

Adjusting learning rate of group 0 to 1.9008e-04.


loss: 0.971:  80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                    | 2050/2560 [02:18<00:34, 14.78it/s]

Adjusting learning rate of group 0 to 1.8994e-04.


loss: 0.978:  81%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                 | 2082/2560 [02:20<00:32, 14.80it/s]

Adjusting learning rate of group 0 to 1.8980e-04.


loss: 0.981:  83%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                               | 2114/2560 [02:22<00:30, 14.78it/s]

Adjusting learning rate of group 0 to 1.8966e-04.


loss: 0.972:  84%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                             | 2146/2560 [02:25<00:28, 14.78it/s]

Adjusting learning rate of group 0 to 1.8952e-04.


loss: 0.977:  85%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                           | 2178/2560 [02:27<00:25, 14.77it/s]

Adjusting learning rate of group 0 to 1.8938e-04.


loss: 0.973:  86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                        | 2210/2560 [02:29<00:23, 14.76it/s]

Adjusting learning rate of group 0 to 1.8924e-04.


loss: 0.974:  88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                      | 2242/2560 [02:31<00:21, 14.76it/s]

Adjusting learning rate of group 0 to 1.8910e-04.


loss: 0.965:  89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                    | 2274/2560 [02:33<00:19, 14.77it/s]

Adjusting learning rate of group 0 to 1.8896e-04.


loss: 0.957:  90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                  | 2306/2560 [02:35<00:17, 14.80it/s]

Adjusting learning rate of group 0 to 1.8881e-04.


loss: 0.966:  91%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎               | 2338/2560 [02:38<00:15, 14.77it/s]

Adjusting learning rate of group 0 to 1.8867e-04.


loss: 0.965:  93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌             | 2370/2560 [02:40<00:12, 14.78it/s]

Adjusting learning rate of group 0 to 1.8852e-04.


loss: 0.964:  94%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊           | 2402/2560 [02:42<00:10, 14.77it/s]

Adjusting learning rate of group 0 to 1.8838e-04.


loss: 0.966:  95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████         | 2434/2560 [02:44<00:08, 14.80it/s]

Adjusting learning rate of group 0 to 1.8823e-04.


loss: 0.963:  96%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎      | 2466/2560 [02:46<00:06, 14.77it/s]

Adjusting learning rate of group 0 to 1.8808e-04.


loss: 0.967:  98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌    | 2498/2560 [02:48<00:04, 14.78it/s]

Adjusting learning rate of group 0 to 1.8793e-04.


loss: 0.974:  99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉  | 2530/2560 [02:51<00:02, 14.76it/s]

Adjusting learning rate of group 0 to 1.8778e-04.


loss: 0.965: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2560/2560 [02:53<00:00, 14.78it/s]


Adjusting learning rate of group 0 to 1.8763e-04.
In epoch 1, average traning loss is 0.9880959489615634.


loss: 0.966:   1%|██▍                                                                                                                                                                                    | 34/2560 [00:02<02:50, 14.81it/s]

Adjusting learning rate of group 0 to 1.8748e-04.


loss: 0.971:   3%|████▋                                                                                                                                                                                  | 66/2560 [00:04<02:48, 14.79it/s]

Adjusting learning rate of group 0 to 1.8733e-04.


loss: 0.960:   4%|███████                                                                                                                                                                                | 98/2560 [00:06<02:46, 14.80it/s]

Adjusting learning rate of group 0 to 1.8717e-04.


loss: 0.969:   5%|█████████▏                                                                                                                                                                            | 130/2560 [00:08<02:44, 14.81it/s]

Adjusting learning rate of group 0 to 1.8702e-04.


loss: 0.958:   6%|███████████▌                                                                                                                                                                          | 162/2560 [00:10<02:42, 14.80it/s]

Adjusting learning rate of group 0 to 1.8686e-04.


loss: 0.964:   8%|█████████████▊                                                                                                                                                                        | 194/2560 [00:13<02:40, 14.79it/s]

Adjusting learning rate of group 0 to 1.8671e-04.


loss: 0.965:   9%|████████████████                                                                                                                                                                      | 226/2560 [00:15<02:37, 14.79it/s]

Adjusting learning rate of group 0 to 1.8655e-04.


loss: 0.959:  10%|██████████████████▎                                                                                                                                                                   | 258/2560 [00:17<02:35, 14.80it/s]

Adjusting learning rate of group 0 to 1.8639e-04.


loss: 0.969:  11%|████████████████████▌                                                                                                                                                                 | 290/2560 [00:19<02:33, 14.80it/s]

Adjusting learning rate of group 0 to 1.8623e-04.


loss: 0.958:  13%|██████████████████████▉                                                                                                                                                               | 322/2560 [00:21<02:31, 14.81it/s]

Adjusting learning rate of group 0 to 1.8607e-04.


loss: 0.954:  14%|█████████████████████████▏                                                                                                                                                            | 354/2560 [00:23<02:29, 14.80it/s]

Adjusting learning rate of group 0 to 1.8591e-04.


loss: 0.963:  15%|███████████████████████████▍                                                                                                                                                          | 386/2560 [00:26<02:26, 14.81it/s]

Adjusting learning rate of group 0 to 1.8575e-04.


loss: 0.965:  16%|█████████████████████████████▋                                                                                                                                                        | 418/2560 [00:28<02:24, 14.79it/s]

Adjusting learning rate of group 0 to 1.8559e-04.


loss: 0.961:  18%|███████████████████████████████▉                                                                                                                                                      | 450/2560 [00:30<02:22, 14.79it/s]

Adjusting learning rate of group 0 to 1.8543e-04.


loss: 0.962:  19%|██████████████████████████████████▎                                                                                                                                                   | 482/2560 [00:32<02:20, 14.80it/s]

Adjusting learning rate of group 0 to 1.8526e-04.


loss: 0.967:  20%|████████████████████████████████████▌                                                                                                                                                 | 514/2560 [00:34<02:18, 14.81it/s]

Adjusting learning rate of group 0 to 1.8510e-04.


loss: 0.963:  21%|██████████████████████████████████████▊                                                                                                                                               | 546/2560 [00:36<02:16, 14.79it/s]

Adjusting learning rate of group 0 to 1.8493e-04.


loss: 0.957:  23%|█████████████████████████████████████████                                                                                                                                             | 578/2560 [00:39<02:14, 14.77it/s]

Adjusting learning rate of group 0 to 1.8477e-04.


loss: 0.955:  24%|███████████████████████████████████████████▎                                                                                                                                          | 610/2560 [00:41<02:11, 14.77it/s]

Adjusting learning rate of group 0 to 1.8460e-04.


loss: 0.954:  25%|█████████████████████████████████████████████▋                                                                                                                                        | 642/2560 [00:43<02:09, 14.80it/s]

Adjusting learning rate of group 0 to 1.8443e-04.


loss: 0.957:  26%|███████████████████████████████████████████████▉                                                                                                                                      | 674/2560 [00:45<02:07, 14.80it/s]

Adjusting learning rate of group 0 to 1.8426e-04.


loss: 0.956:  28%|██████████████████████████████████████████████████▏                                                                                                                                   | 706/2560 [00:47<02:05, 14.81it/s]

Adjusting learning rate of group 0 to 1.8409e-04.


loss: 0.949:  29%|████████████████████████████████████████████████████▍                                                                                                                                 | 738/2560 [00:49<02:03, 14.81it/s]

Adjusting learning rate of group 0 to 1.8392e-04.


loss: 0.961:  30%|██████████████████████████████████████████████████████▋                                                                                                                               | 770/2560 [00:52<02:00, 14.80it/s]

Adjusting learning rate of group 0 to 1.8375e-04.


loss: 0.966:  31%|█████████████████████████████████████████████████████████                                                                                                                             | 802/2560 [00:54<01:58, 14.78it/s]

Adjusting learning rate of group 0 to 1.8358e-04.


loss: 0.959:  33%|███████████████████████████████████████████████████████████▎                                                                                                                          | 834/2560 [00:56<01:56, 14.77it/s]

Adjusting learning rate of group 0 to 1.8341e-04.


loss: 0.954:  34%|█████████████████████████████████████████████████████████████▌                                                                                                                        | 866/2560 [00:58<01:54, 14.78it/s]

Adjusting learning rate of group 0 to 1.8323e-04.


loss: 0.950:  35%|███████████████████████████████████████████████████████████████▊                                                                                                                      | 898/2560 [01:00<01:52, 14.76it/s]

Adjusting learning rate of group 0 to 1.8306e-04.


loss: 0.952:  36%|██████████████████████████████████████████████████████████████████                                                                                                                    | 930/2560 [01:02<01:50, 14.79it/s]

Adjusting learning rate of group 0 to 1.8288e-04.


loss: 0.960:  38%|████████████████████████████████████████████████████████████████████▍                                                                                                                 | 962/2560 [01:05<01:47, 14.82it/s]

Adjusting learning rate of group 0 to 1.8271e-04.


loss: 0.957:  39%|██████████████████████████████████████████████████████████████████████▋                                                                                                               | 994/2560 [01:07<01:45, 14.79it/s]

Adjusting learning rate of group 0 to 1.8253e-04.


loss: 0.960:  40%|████████████████████████████████████████████████████████████████████████▌                                                                                                            | 1026/2560 [01:09<01:43, 14.76it/s]

Adjusting learning rate of group 0 to 1.8235e-04.


loss: 0.957:  41%|██████████████████████████████████████████████████████████████████████████▊                                                                                                          | 1058/2560 [01:11<01:41, 14.81it/s]

Adjusting learning rate of group 0 to 1.8217e-04.


loss: 0.952:  43%|█████████████████████████████████████████████████████████████████████████████                                                                                                        | 1090/2560 [01:13<01:39, 14.80it/s]

Adjusting learning rate of group 0 to 1.8200e-04.


loss: 0.951:  44%|███████████████████████████████████████████████████████████████████████████████▎                                                                                                     | 1122/2560 [01:15<01:37, 14.71it/s]

Adjusting learning rate of group 0 to 1.8181e-04.


loss: 0.956:  45%|█████████████████████████████████████████████████████████████████████████████████▌                                                                                                   | 1154/2560 [01:18<01:35, 14.78it/s]

Adjusting learning rate of group 0 to 1.8163e-04.


loss: 0.953:  46%|███████████████████████████████████████████████████████████████████████████████████▊                                                                                                 | 1186/2560 [01:20<01:32, 14.81it/s]

Adjusting learning rate of group 0 to 1.8145e-04.


loss: 0.945:  48%|██████████████████████████████████████████████████████████████████████████████████████                                                                                               | 1218/2560 [01:22<01:30, 14.80it/s]

Adjusting learning rate of group 0 to 1.8127e-04.


loss: 0.952:  49%|████████████████████████████████████████████████████████████████████████████████████████▍                                                                                            | 1250/2560 [01:24<01:28, 14.79it/s]

Adjusting learning rate of group 0 to 1.8109e-04.


loss: 0.948:  50%|██████████████████████████████████████████████████████████████████████████████████████████▋                                                                                          | 1282/2560 [01:26<01:26, 14.78it/s]

Adjusting learning rate of group 0 to 1.8090e-04.


loss: 0.953:  51%|████████████████████████████████████████████████████████████████████████████████████████████▉                                                                                        | 1314/2560 [01:28<01:24, 14.80it/s]

Adjusting learning rate of group 0 to 1.8072e-04.


loss: 0.951:  53%|███████████████████████████████████████████████████████████████████████████████████████████████▏                                                                                     | 1346/2560 [01:30<01:22, 14.78it/s]

Adjusting learning rate of group 0 to 1.8053e-04.


loss: 0.950:  54%|█████████████████████████████████████████████████████████████████████████████████████████████████▎                                                                                   | 1377/2560 [01:33<01:20, 14.78it/s]

Adjusting learning rate of group 0 to 1.8034e-04.





KeyboardInterrupt: 

In [None]:
del pert_index
del ctrl 
del pert
del idx
del ctrl_emb
del ctrl_recon
del ctrl_bin_recon
del pert_emb
del pert_recon
del pert_bin_recon

In [12]:
optim.zero_grad()
import gc
torch.cuda.empty_cache()
gc.collect()

498

In [13]:
use_sparsity_loss = False
use_mask_task = True
use_active_weights = False
lr_step = 32
minibatch_size = 172

step_count = 0
optim.zero_grad()
pert_task = 0
for e in range(20):
    model.train()
    losses = {'loss': [], 'control': [], 'pert': [], 'flow': []}
    for (bcontrol, bpert, bpert_index) in (pbar := tqdm(iter(dl))):
        bcontrol, bpert, bpert_index = bcontrol.squeeze(), bpert.squeeze(), bpert_index.reshape(-1, 1)# , # bpert_expr.squeeze()
        curr_batch_size = bcontrol.shape[0]
        for i in range(curr_batch_size // minibatch_size):
            ctrl = bcontrol[(i * minibatch_size):((i + 1) * minibatch_size)]
            if ctrl.shape[0] == 0:
                continue
            pert = bpert[(i * minibatch_size):((i + 1) * minibatch_size)]
            pert_index = bpert_index[(i * minibatch_size):((i + 1) * minibatch_size)]
            pert_index = pert_index.squeeze()
            
            ctrl = ctrl.float().to(device)
            pert = pert.float().to(device)
            
            idx = torch.arange(ctrl.shape[1]).to(device)
            
            step_count += 1
            
            ctrl_emb = model(ctrl)
            ctrl_loss, ctrl_recon, ctrl_bin_recon = model.ae_loss(ctrl_emb, ctrl, idx, return_recon=True)
            
            pert_emb = model(pert)
            pert_loss, pert_recon, pert_bin_recon = model.ae_loss(pert_emb, pert, idx, return_recon=True)
            
            cond = model.gene_embedding.pos[:, pert_index][0]
            flow_loss = model.flow_loss(ctrl_emb, pert_emb, cond)
            
            loss = ctrl_loss + pert_loss + flow_loss
            loss.backward()
            optim.step()
            optim.zero_grad()
            losses['loss'].append(loss.item())
            losses['control'].append(ctrl_loss.item())
            losses['pert'].append(pert_loss.item())
            losses['flow'].append(flow_loss.item())
            if step_count % lr_step == 0:
                lr_scheduler.step()
            pbar.set_description(
                f"loss: {np.array(losses['loss'])[-lr_step:].mean():.3f}, tv: {np.array(losses['control'])[-lr_step:].mean():.3f}, ptv: {np.array(losses['pert'])[-lr_step:].mean():.3f}, flow: {np.array(losses['flow'])[-lr_step:].mean():.3f}"
            )
    
    avg_loss = sum(losses['control']) / len(losses['control'])
    torch.save(model, f"{save_dir}{e}")
    # writer.add_scalar('mae_loss', avg_loss, global_step=e)
    print(f'In epoch {e}, average traning loss is {avg_loss}.')

loss: 0.414, tv: 0.166, ptv: 0.170, flow: 0.078:   3%|███▉                                                                                                                                                | 16/611 [00:07<04:17,  2.31it/s]

Adjusting learning rate of group 0 to 1.8016e-04.


loss: 0.332, tv: 0.154, ptv: 0.159, flow: 0.019:   5%|███████▉                                                                                                                                            | 33/611 [00:14<04:06,  2.35it/s]

Adjusting learning rate of group 0 to 1.7997e-04.


loss: 0.336, tv: 0.159, ptv: 0.164, flow: 0.013:   8%|███████████▊                                                                                                                                        | 49/611 [00:21<04:05,  2.29it/s]

Adjusting learning rate of group 0 to 1.7978e-04.


loss: 0.299, tv: 0.143, ptv: 0.145, flow: 0.010:  11%|███████████████▋                                                                                                                                    | 65/611 [00:28<04:00,  2.27it/s]

Adjusting learning rate of group 0 to 1.7959e-04.


loss: 0.283, tv: 0.137, ptv: 0.138, flow: 0.008:  13%|███████████████████▌                                                                                                                                | 81/611 [00:35<03:52,  2.28it/s]

Adjusting learning rate of group 0 to 1.7940e-04.


loss: 0.311, tv: 0.150, ptv: 0.153, flow: 0.008:  16%|███████████████████████▍                                                                                                                            | 97/611 [00:42<03:45,  2.28it/s]

Adjusting learning rate of group 0 to 1.7921e-04.


loss: 0.290, tv: 0.140, ptv: 0.143, flow: 0.007:  19%|███████████████████████████▍                                                                                                                       | 114/611 [00:49<03:35,  2.31it/s]

Adjusting learning rate of group 0 to 1.7902e-04.


loss: 0.292, tv: 0.142, ptv: 0.144, flow: 0.006:  19%|████████████████████████████▏                                                                                                                      | 117/611 [00:50<03:34,  2.30it/s]


KeyboardInterrupt: 

In [19]:
import math
import torch
import numpy as np
from torchdyn.core import NeuralODE
from datamodules import torch_wrapper

def compute_conditional_flow(
    model, control, pert_ids, pert_mat, batch_size=100, num_steps=400, n_batches=1e8, true_bin=None
):
    node = NeuralODE(
        torch_wrapper(model.flow).to(device), solver="dopri5", sensitivity="adjoint"
    )
    n_samples = min(control.shape[0], pert_ids.shape[0])
    n_batches = min(math.ceil(n_samples / batch_size), n_batches)
    preds = np.zeros((min(n_batches * batch_size, n_samples), control.shape[1]))
    with torch.no_grad():
        for i in range(n_batches):
            control_batch = control[batch_size*i:min(batch_size*(i+1), n_samples)]
            pert_batch = pert_mat[pert_ids][batch_size*i:min(batch_size*(i+1), n_samples)]# [:control_batch.shape[0]]
            model.flow.cond = pert_batch.to(device)
            inp = control_batch.float()
            inp = inp.to(device)
            
            idx = torch.arange(control_eval.shape[1]).to(device)
            cell_embedding = model(inp)
            
            outp = node.trajectory(
                cell_embedding,
                t_span=torch.linspace(0, 1, num_steps)
            )
            outp = outp[-1, :, :]
            outp, outb = model.recon(cell_embedding, idx)
            if true_bin:
                outb = true_bin.to(device)
            outp = model.sparsify(outp, outb)
            outp = outp.cpu()
            preds[batch_size*i:batch_size*(i+1), :] = outp.squeeze()
            
    return preds

In [17]:
for cell_type, pert_type in zip(holdout_cells, holdout_perts):
    break

cell_type_names = adata.obs[cell_col]
pert_type_names = adata.obs[pert_col]
control_eval =  torch.tensor(X[(cell_type_names == cell_type) & (pert_type_names == gene_map['NT'])]).float()
true_pert = torch.tensor(X[(pert_type_names == pert_type) & (cell_type_names == cell_type)]).float()

In [20]:
pred = compute_conditional_flow(
    model,
    control_eval, 
    torch.tensor(np.repeat(pert_type, control_eval.shape[0])), 
    model.gene_embedding.pos[0]
)  

In [21]:
true_pert.mean(axis=0)

tensor([0.0490, 0.0000, 0.0000,  ..., 0.0000, 0.0484, 0.0105])

In [22]:
pred.mean(axis=0)

array([0.16726097, 0.00103138, 0.00203618, ..., 0.        , 0.04829858,
       0.10163789])

In [23]:
from os import listdir
from scipy.sparse import issparse
import anndata
import scanpy as sc
import numpy as np
import pandas as pd

from scipy.stats import pearsonr

import logging

logger = logging.getLogger(__name__)

def r2_mse_filename(pert, cell):
    return f'r2_and_mse_{pert}_{cell}.json'

def c_r_filename(pert, cell):
    return f'c_r_results_{pert}_{cell}.json'

def DEGs_overlap_filename(pert, cell):   
    return f'DEGs_overlaps_{pert}_{cell}.json'


def get_DEG_with_direction(gene, score):
    if score > 0:
        return(f'{gene}+')
    else:
        return(f'{gene}-')
        
def to_dense(X):
    if issparse(X):
        return X.toarray()
    else:
        return np.asarray(X)

def get_DEGs(control_adata, target_adata):
    temp_concat = anndata.concat([control_adata, target_adata], label = 'batch')
    sc.tl.rank_genes_groups(
        temp_concat, 'batch', method='wilcoxon', 
        groups = ['1'], ref = '0', rankby_abs = True, tie_correct=True
    )

    rankings = temp_concat.uns['rank_genes_groups']
    result_df = pd.DataFrame({'scores': rankings['scores']['1'],
                     'pvals_adj': rankings['pvals_adj']['1'],
                     'lfc': rankings['logfoldchanges']['1']},
                    index = rankings['names']['1'])
    return result_df

def get_eval(ctrl_adata, true_adata, pred_adata, DEGs, DEG_vals, pval_threshold, lfc_threshold):
        
    results_dict =  {}
    
    logger.debug(f"Computing R, R2, and MSE metrics")

    ctrl_mean = to_dense(ctrl_adata.X).mean(axis = 0)

    true_mean = to_dense(true_adata.X).mean(axis = 0)
    true_var = to_dense(true_adata.X).var(axis = 0)
    
    pred_mean = to_dense(pred_adata.X).mean(axis = 0)
    pred_var = to_dense(pred_adata.X).var(axis = 0)
    
    true_corr_mtx = np.corrcoef(to_dense(true_adata.X), rowvar=False).flatten()
    true_cov_mtx = np.cov(to_dense(true_adata.X), rowvar=False).flatten()
        
    pred_corr_mtx = np.corrcoef(to_dense(pred_adata.X), rowvar=False).flatten()
    pred_cov_mtx = np.cov(to_dense(pred_adata.X), rowvar=False).flatten()

    true_sub_diff = true_mean - ctrl_mean
    pred_sub_diff = pred_mean - ctrl_mean

    true_diff = true_mean/ctrl_mean
    pred_diff = pred_mean/ctrl_mean

    true_diff_mask = (np.isnan(true_diff) | np.isinf(true_diff))
    pred_diff_mask = (np.isnan(pred_diff) | np.isinf(pred_diff))
    
    common_mask = true_diff_mask | pred_diff_mask
    true_fold_diff = np.ma.array(true_diff, mask=common_mask).compressed()
    pred_fold_diff = np.ma.array(pred_diff, mask=common_mask).compressed()

    results_dict['all_genes_mean_sub_diff_R'] = pearsonr(true_sub_diff, pred_sub_diff)[0]
    results_dict['all_genes_mean_sub_diff_R2'] = pearsonr(true_sub_diff, pred_sub_diff)[0]**2
    results_dict['all_genes_mean_sub_diff_MSE'] = (np.square(true_sub_diff - pred_sub_diff)).mean(axis=0)

    results_dict['all_genes_mean_fold_diff_R'] = pearsonr(true_fold_diff, pred_fold_diff)[0]
    results_dict['all_genes_mean_fold_diff_R2'] = pearsonr(true_fold_diff, pred_fold_diff)[0]**2
    results_dict['all_genes_mean_fold_diff_MSE'] = (np.square(true_fold_diff - pred_fold_diff)).mean(axis=0)
    
    results_dict['all_genes_mean_R'] = pearsonr(true_mean, pred_mean)[0]
    results_dict['all_genes_mean_R2'] = pearsonr(true_mean, pred_mean)[0]**2
    results_dict['all_genes_mean_MSE'] = (np.square(true_mean - pred_mean)).mean(axis=0)

    results_dict['all_genes_var_R'] = pearsonr(true_var, pred_var)[0]
    results_dict['all_genes_var_R2'] = pearsonr(true_var, pred_var)[0]**2
    results_dict['all_genes_var_MSE'] = (np.square(true_var - pred_var)).mean(axis=0)
   
    corr_nas = np.logical_or(np.isnan(true_corr_mtx), np.isnan(pred_corr_mtx))
    cov_nas = np.logical_or(np.isnan(true_cov_mtx), np.isnan(pred_cov_mtx))

    results_dict['all_genes_corr_mtx_R'] = pearsonr(true_corr_mtx[~corr_nas], pred_corr_mtx[~corr_nas])[0]
    results_dict['all_genes_corr_mtx_R2'] = pearsonr(true_corr_mtx[~corr_nas], pred_corr_mtx[~corr_nas])[0]**2
    results_dict['all_genes_corr_mtx_MSE'] = (np.square(true_corr_mtx[~corr_nas] - pred_corr_mtx[~corr_nas])).mean(axis=0)

    results_dict['all_genes_cov_mtx_R'] = pearsonr(true_cov_mtx[~cov_nas], pred_cov_mtx[~cov_nas])[0]
    results_dict['all_genes_cov_mtx_R2'] = pearsonr(true_cov_mtx[~cov_nas], pred_cov_mtx[~cov_nas])[0]**2
    results_dict['all_genes_cov_mtx_MSE'] = (np.square(true_cov_mtx[~cov_nas] - pred_cov_mtx[~cov_nas])).mean(axis=0)

    if lfc_threshold:   
        significant_DEGs = DEGs[(DEGs['pvals_adj'] < pval_threshold) & (abs(DEGs) > lfc_threshold)]
    else:
        significant_DEGs = DEGs[DEGs['pvals_adj'] < pval_threshold]
    num_DEGs = len(significant_DEGs)
    DEG_vals.insert(0, num_DEGs)


    logger.debug(f"Significant DEGs {significant_DEGs}")
    
    for val in DEG_vals:

        logger.debug(f"Computing R, R2, and MSE metrics for top {val} DEGs")

        #If val == 1 we can't
        if ((val > num_DEGs) or (val == 0) or (val == 1)):
            results_dict[f'Top_{val}_DEGs_sub_diff_mean_R'] = None
            results_dict[f'Top_{val}_DEGs_sub_diff_mean_R2'] = None
            results_dict[f'Top_{val}_DEGs_sub_diff_mean_MSE'] = None
            
            results_dict[f'Top_{val}_DEGs_fold_diff_mean_R'] = None
            results_dict[f'Top_{val}_DEGs_fold_diff_mean_R2'] = None
            results_dict[f'Top_{val}_DEGs_fold_diff_mean_MSE'] = None
            
            results_dict[f'Top_{val}_DEGs_mean_R'] = None
            results_dict[f'Top_{val}_DEGs_mean_R2'] = None
            results_dict[f'Top_{val}_DEGs_mean_MSE'] = None

            results_dict[f'Top_{val}_DEGs_var_R'] = None
            results_dict[f'Top_{val}_DEGs_var_R2'] = None
            results_dict[f'Top_{val}_DEGs_var_MSE'] = None
            
            results_dict[f'Top_{val}_DEGs_corr_mtx_R'] = None
            results_dict[f'Top_{val}_DEGs_corr_mtx_R2'] = None
            results_dict[f'Top_{val}_DEGs_corr_mtx_MSE'] = None
            
            results_dict[f'Top_{val}_DEGs_cov_mtx_R'] = None
            results_dict[f'Top_{val}_DEGs_cov_mtx_R2'] = None
            results_dict[f'Top_{val}_DEGs_cov_mtx_MSE'] = None
        
        else:
            top_DEGs = significant_DEGs.iloc[0:val].index

            logger.debug(f"Top DEGs: {top_DEGs}")


            #Reshape --> If there is a single gene, the shape is (1,) and we need to reshape it to (1,1)

            ctrl_mean = to_dense(ctrl_adata[:,top_DEGs].X).mean(axis = 0)
            
            true_mean = to_dense(true_adata[:,top_DEGs].X).mean(axis = 0)

            logger.debug(f"Shape ctrl_adata with top DEGs: {ctrl_adata[:,top_DEGs].X.shape}, shape true_adata with top DEGs: {true_adata[:,top_DEGs].X.shape}")


            true_var = to_dense(true_adata[:,top_DEGs].X).var(axis = 0)
            true_corr_mtx = np.corrcoef(to_dense(true_adata[:,top_DEGs].X), rowvar=False).flatten()
            true_cov_mtx = np.cov(to_dense(true_adata[:,top_DEGs].X), rowvar=False).flatten()

            pred_mean = to_dense(pred_adata[:,top_DEGs].X).mean(axis = 0)
            logger.debug(f"Shape of true_mean shape: {true_mean.shape}, ctrl_mean shape: {ctrl_mean.shape}, pred_mean shape: {pred_mean.shape}")

            pred_var = to_dense(pred_adata[:,top_DEGs].X).var(axis = 0)
            pred_corr_mtx = np.corrcoef(to_dense(pred_adata[:,top_DEGs].X), rowvar=False).flatten()
            pred_cov_mtx = np.cov(to_dense(pred_adata[:,top_DEGs].X), rowvar=False).flatten()

            logger.debug(f"Shape of true_var shape: {true_var.shape}, pred_var shape: {pred_var.shape}")

            true_sub_diff = true_mean - ctrl_mean
            pred_sub_diff = pred_mean - ctrl_mean
        
            true_diff = true_mean/ctrl_mean
            pred_diff = pred_mean/ctrl_mean
        
            true_diff_mask = (np.isnan(true_diff) | np.isinf(true_diff))
            pred_diff_mask = (np.isnan(pred_diff) | np.isinf(pred_diff))
            
            common_mask = true_diff_mask | pred_diff_mask
            true_fold_diff = np.ma.array(true_diff, mask=common_mask).compressed()
            pred_fold_diff = np.ma.array(pred_diff, mask=common_mask).compressed()

            results_dict[f'Top_{val}_DEGs_sub_diff_R'] = pearsonr(true_sub_diff, pred_sub_diff)[0]
            results_dict[f'Top_{val}_DEGs_sub_diff_R2'] = pearsonr(true_sub_diff, pred_sub_diff)[0]**2
            results_dict[f'Top_{val}_DEGs_sub_diff_MSE'] = (np.square(true_sub_diff - pred_sub_diff)).mean(axis=0)
        
            results_dict[f'Top_{val}_DEGs_fold_diff_R'] = pearsonr(true_fold_diff, pred_fold_diff)[0]
            results_dict[f'Top_{val}_DEGs_fold_diff_R2'] = pearsonr(true_fold_diff, pred_fold_diff)[0]**2
            results_dict[f'Top_{val}_DEGs_fold_diff_MSE'] = (np.square(true_fold_diff - pred_fold_diff)).mean(axis=0)
    
            results_dict[f'Top_{val}_DEGs_mean_R'] = pearsonr(true_mean, pred_mean)[0]
            results_dict[f'Top_{val}_DEGs_mean_R2'] = pearsonr(true_mean, pred_mean)[0]**2
            results_dict[f'Top_{val}_DEGs_mean_MSE'] = (np.square(true_mean - pred_mean)).mean(axis=0)

            results_dict[f'Top_{val}_DEGs_var_R'] = pearsonr(true_var, pred_var)[0]
            results_dict[f'Top_{val}_DEGs_var_R2'] = pearsonr(true_var, pred_var)[0]**2
            results_dict[f'Top_{val}_DEGs_var_MSE'] = (np.square(true_var - pred_var)).mean(axis=0)
            
            corr_nas = np.logical_or(np.isnan(true_corr_mtx), np.isnan(pred_corr_mtx))
            cov_nas = np.logical_or(np.isnan(true_cov_mtx), np.isnan(pred_cov_mtx))

            results_dict[f'Top_{val}_DEGs_corr_mtx_R'] = pearsonr(true_corr_mtx[~corr_nas], pred_corr_mtx[~corr_nas])[0]
            results_dict[f'Top_{val}_DEGs_corr_mtx_R2'] = pearsonr(true_corr_mtx[~corr_nas], pred_corr_mtx[~corr_nas])[0]**2
            results_dict[f'Top_{val}_DEGs_corr_mtx_MSE'] = (np.square(true_corr_mtx[~corr_nas] - pred_corr_mtx[~corr_nas])).mean(axis=0)

            results_dict[f'Top_{val}_DEGs_cov_mtx_R'] = pearsonr(true_cov_mtx[~cov_nas], pred_cov_mtx[~cov_nas])[0]
            results_dict[f'Top_{val}_DEGs_cov_mtx_R2'] = pearsonr(true_cov_mtx[~cov_nas], pred_cov_mtx[~cov_nas])[0]**2
            results_dict[f'Top_{val}_DEGs_cov_mtx_MSE'] = (np.square(true_cov_mtx[~cov_nas] - pred_cov_mtx[~cov_nas])).mean(axis=0)

    return results_dict

def get_DEG_Coverage_Recall(true_DEGs, pred_DEGs, p_cutoff):
    sig_true_DEGs = true_DEGs[true_DEGs['pvals_adj'] < p_cutoff]
    true_DEGs_with_direction = [get_DEG_with_direction(gene,score) for gene, score in zip(sig_true_DEGs.index, sig_true_DEGs['scores'])]
    sig_pred_DEGs = pred_DEGs[pred_DEGs['pvals_adj'] < p_cutoff]
    pred_DEGs_with_direction = [get_DEG_with_direction(gene,score) for gene, score in zip(sig_pred_DEGs.index, sig_pred_DEGs['scores'])]
    num_true_DEGs = len(true_DEGs_with_direction)
    num_pred_DEGs = len(pred_DEGs_with_direction)
    num_overlapping_DEGs = len(set(true_DEGs_with_direction).intersection(set(pred_DEGs_with_direction)))
    if num_true_DEGs > 0: 
        COVERAGE = num_overlapping_DEGs/num_true_DEGs
    else:
        COVERAGE = None
    if num_pred_DEGs > 0:
        RECALL = num_overlapping_DEGs/num_pred_DEGs
    else:
        RECALL = None
    return COVERAGE, RECALL

def get_DEGs_overlaps(true_DEGs, pred_DEGs, DEG_vals, pval_threshold, lfc_threshold):
    if lfc_threshold:
        significant_true_DEGs = true_DEGs[(true_DEGs['pvals_adj'] < pval_threshold) & (abs(true_DEGs['lfc']) > lfc_threshold)]
        significant_pred_DEGs = pred_DEGs[(pred_DEGs['pvals_adj'] < pval_threshold) & (abs(pred_DEGs['lfc']) > lfc_threshold)]
    else:
        significant_true_DEGs = true_DEGs[true_DEGs['pvals_adj'] < pval_threshold]
        significant_pred_DEGs = pred_DEGs[pred_DEGs['pvals_adj'] < pval_threshold]

    true_DEGs_for_comparison = [get_DEG_with_direction(gene,score) for gene, score in zip(significant_true_DEGs.index, significant_true_DEGs['scores'])]   
    pred_DEGs_for_comparison = [get_DEG_with_direction(gene,score) for gene, score in zip(significant_pred_DEGs.index, significant_pred_DEGs['scores'])]
    
    logger.debug(f"Computing DEG overlaps, # of significant DEGs in true data: {len(true_DEGs_for_comparison)}, # of significant DEGs in pred data: {len(pred_DEGs_for_comparison)}")
    num_DEGs = len(significant_true_DEGs)
    DEG_vals.insert(0, num_DEGs)
    
    results = {}
    for val in DEG_vals:
        if val > num_DEGs:
            results[f'Overlap_in_top_{val}_DEGs'] = None
        else:
            results[f'Overlap_in_top_{val}_DEGs'] = len(set(true_DEGs_for_comparison[0:val]).intersection(set(pred_DEGs_for_comparison[0:val])))

    intersection = len(set(true_DEGs_for_comparison).intersection(set(pred_DEGs_for_comparison)))
    union = len(set(true_DEGs_for_comparison).union(set(pred_DEGs_for_comparison)))
    if union > 0:
        results['Jaccard'] = intersection/union
    else:
        results['Jaccard'] = None
    
    return results

In [24]:
import copy
adata_ = copy.deepcopy(adata)
adata_.X = X

In [25]:
adata_ctrl = adata_[(cell_type_names == cell_type) & (pert_type_names == gene_map['NT'])].copy()
adata_pert = adata_[(pert_type_names == pert_type) & (cell_type_names == cell_type)].copy()
adata_pert_pred = adata_ctrl.copy()
adata_pert_pred.X = pred

In [28]:
degs = get_DEGs(adata_ctrl, adata_pert)
degs_pred = get_DEGs(adata_ctrl, adata_pert_pred)

pdefault=0.05

significant_degs = degs[degs['pvals_adj'] < pdefault]
significant_degs_pred = degs_pred[degs_pred['pvals_adj'] < pdefault]

  scores[group_index, :] = (
  utils.warn_names_duplicates("obs")
  scores[group_index, :] = (


In [31]:
res = get_eval(adata_ctrl, adata_pert, adata_pert_pred, degs, [10, 25, 50, 100], pdefault, None)
print(res)

  c /= stddev[:, None]
  c /= stddev[None, :]
  true_diff = true_mean/ctrl_mean
  true_diff = true_mean/ctrl_mean
  pred_diff = pred_mean/ctrl_mean
  pred_diff = pred_mean/ctrl_mean


{'all_genes_mean_sub_diff_R': 0.05029209858854033, 'all_genes_mean_sub_diff_R2': 0.0025292951804394602, 'all_genes_mean_sub_diff_MSE': 0.36348496993582685, 'all_genes_mean_fold_diff_R': 0.16582671303629845, 'all_genes_mean_fold_diff_R2': 0.027498498756422875, 'all_genes_mean_fold_diff_MSE': 322.7634392561425, 'all_genes_mean_R': 0.6568438458438476, 'all_genes_mean_R2': 0.4314438378229362, 'all_genes_mean_MSE': 0.36348496993582685, 'all_genes_var_R': 0.6011030204631375, 'all_genes_var_R2': 0.36132484120990704, 'all_genes_var_MSE': 1.0966558748824586, 'all_genes_corr_mtx_R': 0.4622095705945211, 'all_genes_corr_mtx_R2': 0.21363768714917158, 'all_genes_corr_mtx_MSE': 0.0017465257027203752, 'all_genes_cov_mtx_R': 0.40612418943794526, 'all_genes_cov_mtx_R2': 0.16493685724662804, 'all_genes_cov_mtx_MSE': 0.0019094857114673165, 'Top_169_DEGs_sub_diff_R': 0.40730274960602525, 'Top_169_DEGs_sub_diff_R2': 0.1658955298366285, 'Top_169_DEGs_sub_diff_MSE': 1.0310371538274536, 'Top_169_DEGs_fold_diff

In [30]:
len(significant_degs.index[:10].intersection(significant_degs_pred.index[:10])), \
len(significant_degs.index[:25].intersection(significant_degs_pred.index[:25])), \
len(significant_degs.index[:50].intersection(significant_degs_pred.index[:50])), \
len(significant_degs.index[:100].intersection(significant_degs_pred.index[:100])), \
len(significant_degs.index.intersection(significant_degs_pred.index))

(0, 2, 6, 18, 159)