In [1]:
import numpy as np
import torch
import scanpy as sc
from data import PertDataloader, Network


name2path = {
    'GNN_Disentangle-L2': 'GNN_Disentangle_GAT_string_20.0_64_2_l2_Norman2019_gene_emb_pert_emb_constant_sim_gnn',
    'GNN_Disentangle_Sim': 'GNN_Disentangle_GAT_string_20.0_64_2_l3_Norman2019_gamma2_gene_emb_pert_emb_constant_sim_gnn',
    'GNN_Disentangle_Sim_No_Gene': 'GNN_Disentangle_sim_gnn', 
    'No-Perturb': 'No_Perturb_GAT_string_20.0_64_2_l3_Norman2019_gamma2_gene_emb_pert_emb_constant_sim_gnn'
}


name = 'GNN_Disentangle_Sim_No_Gene'

model_name = name2path[name]
args = np.load('./saved_args/'+model_name+'.npy', allow_pickle = True).item()
args['device'] = 'cuda:3'

In [2]:
#import wandb
#wandb.init(project='pert_gnn_simulation', entity='kexinhuang', name=name)

In [4]:
if args['network_name'] == 'string':
    args['network_path'] = '/dfs/project/perturb-gnn/graphs/STRING_full_9606.csv'

if args['dataset'] == 'Norman2019':
    data_path = '/dfs/project/perturb-gnn/datasets/Norman2019/Norman2019_hvg+perts_more_de.h5ad'

adata = sc.read_h5ad(data_path)
if 'gene_symbols' not in adata.var.columns.values:
    adata.var['gene_symbols'] = adata.var['gene_name']
gene_list = [f for f in adata.var.gene_symbols.values]
# Set up message passing network
network = Network(fname=args['network_path'], gene_list=args['gene_list'],
                  percentile=args['top_edge_percent'])

# Pertrubation dataloader
pertdl = PertDataloader(adata, network.G, network.weights, args)

There are 101013 edges in the PPI.
Creating pyg object for each cell in the data...
Local copy of pyg dataset is detected. Loading...
Loading splits...
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:9
combo_seen1:52
combo_seen2:18
unseen_single:37
Creating dataloaders....
Dataloaders created...


In [5]:
loader = pertdl.loaders['train_loader']

In [9]:
for data in loader:
    break

In [8]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [11]:
num_graphs = len(data.batch.unique())

In [10]:
emb = nn.Embedding(5045, 64, max_norm=True)

In [13]:
emb_o = emb(torch.LongTensor(list(range(5045))).repeat(num_graphs, ))

In [14]:
emb_o.shape

torch.Size([161440, 64])

In [15]:
base_emb = emb_o.reshape(num_graphs, 5045, -1)

In [17]:
base_emb.shape

torch.Size([32, 5045, 64])

In [24]:
class MLP(nn.Module):
    def __init__(self, num_features, expansion_factor, dropout):
        super().__init__()
        num_hidden = num_features * expansion_factor
        self.fc1 = nn.Linear(num_features, num_hidden)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(num_hidden, num_features)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout1(F.gelu(self.fc1(x)))
        x = self.dropout2(self.fc2(x))
        return x


class TokenMixer(nn.Module):
    def __init__(self, num_features, num_patches, expansion_factor, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(num_features)
        self.mlp = MLP(num_patches, expansion_factor, dropout)

    def forward(self, x):
        # x.shape == (batch_size, num_patches, num_features)
        residual = x
        x = self.norm(x)
        x = x.transpose(1, 2)
        # x.shape == (batch_size, num_features, num_patches)
        x = self.mlp(x)
        x = x.transpose(1, 2)
        # x.shape == (batch_size, num_patches, num_features)
        out = x + residual
        return out


class ChannelMixer(nn.Module):
    def __init__(self, num_features, num_patches, expansion_factor, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(num_features)
        self.mlp = MLP(num_features, expansion_factor, dropout)

    def forward(self, x):
        # x.shape == (batch_size, num_patches, num_features)
        residual = x
        x = self.norm(x)
        x = self.mlp(x)
        # x.shape == (batch_size, num_patches, num_features)
        out = x + residual
        return out


class MixerLayer(nn.Module):
    def __init__(self, num_features, num_patches, expansion_factor, dropout):
        super().__init__()
        self.token_mixer = TokenMixer(
            num_features, num_patches, expansion_factor, dropout
        )
        self.channel_mixer = ChannelMixer(
            num_features, num_patches, expansion_factor, dropout
        )

    def forward(self, x):
        # x.shape == (batch_size, num_patches, num_features)
        x = self.token_mixer(x)
        x = self.channel_mixer(x)
        # x.shape == (batch_size, num_patches, num_features)
        return x
    
mixer = MixerLayer(64, 5045, 2, 0.1)

In [27]:
mixer = nn.Sequential(
                    *[
                        MixerLayer(64, 5045, 2, 0.3)
                        for _ in range(2)
                    ]
                )

In [28]:
mixer(base_emb)

tensor([[[ 0.1904, -0.0113, -0.5770,  ..., -1.1207,  1.1071,  0.9854],
         [ 1.0404, -1.1618, -1.1913,  ..., -0.0767,  0.3916, -0.4452],
         [-0.1410,  0.1107,  0.9528,  ...,  0.3103, -0.3393, -0.4750],
         ...,
         [-0.1382,  0.8997,  0.1446,  ...,  0.2115,  0.6487, -0.1571],
         [ 1.0052, -0.3672,  0.1188,  ..., -0.0523,  0.0688,  0.3462],
         [-0.0667,  0.4412,  0.1082,  ..., -0.3944,  0.1929, -1.0405]],

        [[-0.5534,  0.3848, -0.9856,  ..., -0.1377,  0.5450,  0.7492],
         [ 0.1708, -0.1524, -0.2508,  ...,  0.0147, -0.6033, -0.2640],
         [-0.4169,  0.2331,  0.4029,  ...,  0.1368, -1.0101, -0.5444],
         ...,
         [ 0.6307,  1.0567,  0.2277,  ..., -0.0283, -0.0182, -0.1679],
         [ 0.3622,  0.3869,  0.0556,  ...,  0.6102, -0.8638, -1.0897],
         [ 0.8843,  0.0194,  0.4571,  ...,  0.0995, -0.0102, -0.0671]],

        [[ 0.9770, -0.2882,  0.1499,  ..., -1.2680,  0.6785,  0.2482],
         [ 0.3956, -0.2058,  0.1097,  ..., -0