# PYG 1028

### Train best 3 models using cell embedding (normalized_counts) engineered from multiome data

### Increase dimension of cell embedding

### Use both gene expression and peak data

In [1]:
import torch
torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device(type='cuda')

In [2]:
import pandas as pd
import numpy as np
import tqdm
import torch.nn.functional as F
import copy

from transformers import RobertaTokenizerFast, RobertaForMaskedLM, DataCollatorWithPadding

from torch_geometric.data import Data, HeteroData
import torch_geometric.transforms as T

from sklearn.metrics import mean_squared_error
import copy

import os

## Data

Number of (cell_type, sm_name) pairs: 614

Number of genes expressions to predict: 18211

Control molecules: `Belinostat` and `Dabrafenib`.

In [3]:
df = pd.read_parquet('../data/de_train.parquet')

print(df.shape)
df.head(10)

(614, 18216)


Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.10472,-0.077524,-1.625596,-0.144545,0.143555,...,-0.227781,-0.010752,-0.023881,0.674536,-0.453068,0.005164,-0.094959,0.034127,0.221377,0.368755
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.88438,0.371834,-0.081677,-0.498266,...,-0.494985,-0.303419,0.304955,-0.333905,-0.315516,-0.369626,-0.095079,0.70478,1.096702,-0.869887
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,...,-0.119422,-0.033608,-0.153123,0.183597,-0.555678,-1.494789,-0.21355,0.415768,0.078439,-0.259365
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,...,0.451679,0.704643,0.015468,-0.103868,0.865027,0.189114,0.2247,-0.048233,0.216139,-0.085024
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,...,0.758474,0.510762,0.607401,-0.123059,0.214366,0.487838,-0.819775,0.112365,-0.122193,0.676629
5,T cells CD4+,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,0.618061,0.180148,0.590015,0.035658,0.034297,...,-1.002997,0.177514,0.591768,-0.4124,-0.011633,-0.044739,0.213627,0.186406,-1.459477,1.164084
6,T cells CD8+,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,-0.148131,0.091336,-0.097212,1.225601,0.045787,...,-0.228645,0.091248,-0.581765,0.405682,-0.034414,0.296333,0.170313,0.332022,-0.532363,0.134475
7,T regulatory cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,0.561473,0.149415,3.22492,3.517419,1.361175,...,-0.156127,0.766118,0.702836,0.949482,0.757482,1.163063,1.532419,-0.399292,-2.412165,0.478977
8,B cells,Idelalisib,LSM-1205,CC[C@H](Nc1ncnc2[nH]cnc12)c1nc2cccc(F)c2c(=O)n...,False,0.394173,-0.153824,0.178232,0.566241,0.391377,...,-1.052302,-1.176587,-1.220291,-0.278944,-0.095066,1.10179,0.061803,1.406335,-0.264996,-0.119743
9,Myeloid cells,Idelalisib,LSM-1205,CC[C@H](Nc1ncnc2[nH]cnc12)c1nc2cccc(F)c2c(=O)n...,False,0.025146,0.316388,1.366885,1.301593,2.317619,...,-0.902546,-1.445523,0.794772,0.428973,0.605834,0.271988,0.492231,0.354721,1.471559,-0.259483


In [4]:
# node index for cell types
c2index = dict()
cur = 0
for x in df['cell_type'].unique():
    assert x not in c2index
    c2index[x] = cur
    cur += 1
print(c2index)
del cur

{'NK cells': 0, 'T cells CD4+': 1, 'T cells CD8+': 2, 'T regulatory cells': 3, 'B cells': 4, 'Myeloid cells': 5}


In [5]:
df['sm_name'].nunique()

146

In [6]:
# Assign the list of unique small molecules
sm_list = df['sm_name'].unique().tolist()
len(sm_list)

146

In [7]:
# small molecule embedding
tokenizer = RobertaTokenizerFast.from_pretrained("entropy/roberta_zinc_480m", max_len=128)
model = RobertaForMaskedLM.from_pretrained('entropy/roberta_zinc_480m')
collator = DataCollatorWithPadding(tokenizer, padding=True, return_tensors='pt')

# Extract SMILES embeddings
sm2emb = dict()
sm2index = dict()
cur = 0

for i in range(df.shape[0]):
    sm_name = df['sm_name'].iloc[i]
    smiles = df['SMILES'].iloc[i]



    if sm_name not in sm2emb:

        if True: # if sm_name not in ['Belinostat', 'Dabrafenib']:
            sm2index[sm_name] = cur
            cur += 1
        
        inputs = collator(tokenizer([smiles]))
        outputs = model(**inputs, output_hidden_states=True)
        full_embeddings = outputs[1][-1].detach()
        mask = inputs['attention_mask']
        node_embedding = ((full_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1))
        node_embedding = torch.squeeze(node_embedding)
        sm2emb[sm_name] = node_embedding

del cur

print(sm2index)
print(len(sm2emb))

You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'Clotrimazole': 0, 'Mometasone Furoate': 1, 'Idelalisib': 2, 'Vandetanib': 3, 'Bosutinib': 4, 'Ceritinib': 5, 'Lamivudine': 6, 'Crizotinib': 7, 'Cabozantinib': 8, 'Flutamide': 9, 'Dasatinib': 10, 'Selumetinib': 11, 'Trametinib': 12, 'ABT-199 (GDC-0199)': 13, 'Oxybenzone': 14, 'Vorinostat': 15, 'Raloxifene': 16, 'Linagliptin': 17, 'Lapatinib': 18, 'Canertinib': 19, 'Disulfiram': 20, 'Vardenafil': 21, 'Palbociclib': 22, 'Ricolinostat': 23, 'Dabrafenib': 24, 'Proscillaridin A;Proscillaridin-A': 25, 'IN1451': 26, 'Ixabepilone': 27, 'CEP-18770 (Delanzomib)': 28, 'RG7112': 29, 'MK-5108': 30, 'Resminostat': 31, 'IMD-0354': 32, 'Alvocidib': 33, 'LY2090314': 34, 'Methotrexate': 35, 'LDN 193189': 36, 'Tacalcitol': 37, 'Colchicine': 38, 'R428': 39, 'TL_HRAS26': 40, 'BMS-387032': 41, 'CGP 60474': 42, 'TIE2 Kinase Inhibitor': 43, 'PD-0325901': 44, 'Isoniazid': 45, 'GSK-1070916': 46, 'Masitinib': 47, 'Saracatinib': 48, 'CC-401': 49, 'Decitabine': 50, 'Ketoconazole': 51, 'HYDROXYUREA': 52, 'BAY 61-3

## Cell embedding

In [8]:
# cell embedding
c2emb_df = pd.read_csv('./cell_emb_normcnt_peak.csv')
c2emb_df.head()

Unnamed: 0,cell_type,AAMP,AASS,ABCC11,ABCC2,ABR,ABRAXAS2,AC002429.2,AC003102.1,AC007569.1,...,MYO3B-AS1,AF127936.1,RHOJ,AL021918.3,SYCE3,AC091906.1,OR7A17,CDC45,GABRA6,FST
0,B cells,0.378421,0.292951,0.026902,0.091517,1.973495,0.502242,0.048453,0.132143,0.449569,...,0.004518,0.001416,0.002811,0.009294,0.001431,0.001694,0.0,0.004596,0.007821,0.001665
1,T cells CD4+,0.406463,0.16178,0.029153,0.088855,1.264111,0.458575,0.010122,0.176167,0.034765,...,0.005329,0.0,0.004449,0.003967,0.004804,0.00737,0.002894,0.00177,0.003694,0.008668
2,NK cells,0.326737,0.150077,0.017284,0.054711,1.255197,0.392022,0.021979,0.116615,0.036065,...,0.00535,0.0,0.006698,0.008305,0.009473,0.001378,0.002723,0.001313,0.00124,0.00126
3,Myeloid cells,0.675277,0.083278,0.027169,0.082512,3.729312,0.833589,0.006327,0.116312,0.059618,...,0.001634,0.009832,0.003661,0.003481,0.006446,0.000629,0.002726,0.002356,0.002234,0.00168
4,T regulatory cells,0.353287,0.098467,0.0,0.038467,0.942583,0.411683,0.0,0.025606,0.060882,...,0.012295,0.0,0.0,0.0,0.0,0.0,0.013335,0.0,0.0,0.0


## Heterogenous Graph

#### Cell nodes (c_nodes)

There are 6 nodes for each unique cell types.

#### Small molecule nodes (sm_nodes)

There are 146 unique nodes for each small molecules.

SMILES embedding are used as node embeddings (node_emb_dim = 768).

In [9]:
# initialize graph
data = HeteroData()

# number of unique cell nodes and sm nodes
data['cell'].node_id = torch.arange(6)
data['sm'].node_id = torch.arange(146) # 144 or 146

## initialize cell node emb
cell_node_x = []

for cell_type, cell_index in c2index.items():
    #cell_node_x.append([cell_index])  # turn of cell embedding
    cell_node_x.append(c2emb_df[c2emb_df['cell_type']==cell_type].values[0][1:].tolist())
    

cell_node_x = torch.from_numpy(np.array(cell_node_x)).float()
data['cell'].x = cell_node_x

#print(data['cell'].x
assert data['cell'].x.shape == (6, 158205)

## initialize sm node emb
data['sm'].x = torch.from_numpy(np.array([sm2emb[y].tolist() for y in sm_list])).float()
assert data['sm'].x.shape == (146, 768)

## Edges

In [10]:
def add_edges(data):
    
    edge_index_c2sm = []

    for cell_index in range(6):
        for sm_index in range(146):
            edge_index_c2sm.append([cell_index, sm_index])

    print("Adding {} edges".format(len(edge_index_c2sm)))
    data['cell', 'gene_exp', 'sm'].edge_index = torch.tensor(edge_index_c2sm).t().contiguous().long()    
        
    # convert into bidirecitonal
    data = T.ToUndirected()(data)

    # data split (all for training)
    transform = T.RandomLinkSplit(
        num_val=0.0,
        num_test=0.0,
        is_undirected=True,
        disjoint_train_ratio=0.0,  # was 0.3
        neg_sampling_ratio=0.0,  # was 2.0
        add_negative_train_samples=False,
        edge_types=('cell', 'gene_exp', 'sm'),
        rev_edge_types=('sm', 'rev_gene_exp', 'cell'), 
    )
    
    train_data, _, _ = transform(data)
    
    return train_data

In [11]:
data = add_edges(data)

print(data)

Adding 876 edges
HeteroData(
  cell={
    node_id=[6],
    x=[6, 158205],
  },
  sm={
    node_id=[146],
    x=[146, 768],
  },
  (cell, gene_exp, sm)={
    edge_index=[2, 876],
    edge_label=[876],
    edge_label_index=[2, 876],
  },
  (sm, rev_gene_exp, cell)={ edge_index=[2, 876] }
)


## Train/Validation/Test Splits

In [12]:
df.head()

Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.10472,-0.077524,-1.625596,-0.144545,0.143555,...,-0.227781,-0.010752,-0.023881,0.674536,-0.453068,0.005164,-0.094959,0.034127,0.221377,0.368755
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.88438,0.371834,-0.081677,-0.498266,...,-0.494985,-0.303419,0.304955,-0.333905,-0.315516,-0.369626,-0.095079,0.70478,1.096702,-0.869887
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,...,-0.119422,-0.033608,-0.153123,0.183597,-0.555678,-1.494789,-0.21355,0.415768,0.078439,-0.259365
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,...,0.451679,0.704643,0.015468,-0.103868,0.865027,0.189114,0.2247,-0.048233,0.216139,-0.085024
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,...,0.758474,0.510762,0.607401,-0.123059,0.214366,0.487838,-0.819775,0.112365,-0.122193,0.676629


In [13]:
from sklearn.model_selection import train_test_split

train_df, valid_df = train_test_split(df, test_size=0.2, random_state=14, shuffle=True, stratify=df['cell_type'])


# all edge indeces and gene expressions for training
train_edge_indeces = set()
train_edge2feat = dict()

for i in range(train_df.shape[0]):
    cell_idx = c2index[train_df.iloc[i]['cell_type']]
    sm_idx = sm2index[train_df.iloc[i]['sm_name']]

    train_edge_indeces.add((cell_idx, sm_idx))
    
    train_edge2feat[(cell_idx, sm_idx)] = train_df.iloc[i, 5:].values.tolist()
    assert len(train_edge2feat[(cell_idx, sm_idx)]) == 18211


# all edge indeces and gene expressions for validation
valid_edge_indeces = set()
valid_edge2feat = dict()

for i in range(valid_df.shape[0]):
    cell_idx = c2index[valid_df.iloc[i]['cell_type']]
    sm_idx = sm2index[valid_df.iloc[i]['sm_name']]

    valid_edge_indeces.add((cell_idx, sm_idx))
    
    valid_edge2feat[(cell_idx, sm_idx)] = valid_df.iloc[i, 5:].values.tolist()
    assert len(valid_edge2feat[(cell_idx, sm_idx)]) == 18211

In [14]:
print(train_df.shape)
train_df.head()

(491, 18216)


Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
373,NK cells,MLN 2238,LSM-4944,CC(C)C[C@H](NC(=O)CNC(=O)c1cc(Cl)ccc1Cl)B(O)O,False,9.497701,5.478697,0.129686,-2.328127,5.732667,...,1.047863,3.069138,3.280055,5.284819,5.440957,5.702465,2.486294,1.484651,-1.385248,0.063061
391,T cells CD8+,I-BET151,LSM-6335,COc1cc2c(cc1-c1c(C)noc1C)ncc1[nH]c(=O)n([C@H](...,False,-0.01225,0.174611,-0.099221,-0.134166,0.143319,...,-0.301498,-0.534493,-0.06671,0.647332,0.153865,0.478001,0.884914,0.558577,-1.453534,1.611454
568,T cells CD8+,AMD-070 (hydrochloride),LSM-45591,NCCCCN(Cc1nc2ccccc2[nH]1)[C@H]1CCCc2cccnc21,False,0.909501,-0.011053,0.622651,-0.315069,0.142047,...,0.042186,0.001675,0.263338,0.077565,-0.028096,-0.006699,-0.175771,-0.451758,0.163957,-0.320038
15,T cells CD4+,Vandetanib,LSM-1199,COc1cc2c(Nc3ccc(Br)cc3F)ncnc2cc1OCC1CCN(C)CC1,False,0.106623,0.006949,0.424472,-0.124439,1.123565,...,-0.053448,0.072833,-0.42144,0.067708,1.21589,0.226569,0.084715,-0.668755,0.798338,-0.060044
190,T cells CD4+,PD-0325901,LSM-1101,O=C(NOC[C@H](O)CO)c1ccc(F)c(F)c1Nc1ccc(I)cc1F,False,-0.237122,-0.694187,0.39587,0.210085,-0.081793,...,-1.153235,-0.131176,0.49708,0.604248,0.313291,-0.730704,1.744382,1.603138,-0.143792,-0.194503


In [15]:
print(valid_df.shape)
valid_df.head()

(123, 18216)


Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
342,T regulatory cells,Tivantinib,LSM-1131,O=C1NC(=O)[C@@H](c2cn3c4c(cccc24)CCC3)[C@@H]1c...,False,0.193286,0.665367,-0.142059,0.047231,-0.062032,...,0.086186,-0.119378,0.071182,-0.69475,0.225487,0.352343,-0.127136,0.499296,0.90434,0.266353
244,T cells CD4+,Prednisolone,LSM-24954,C[C@]12C=CC(=O)C=C1CC[C@@H]1[C@@H]2[C@@H](O)C[...,False,0.134272,-0.051087,0.516691,-0.39171,-0.411336,...,-0.842898,-0.443645,1.415605,-0.82326,-0.084823,0.622433,0.845683,1.280901,-2.258758,0.635675
343,NK cells,CEP-37440,LSM-45849,CNC(=O)c1ccccc1Nc1nc(Nc2ccc3c(c2OC)CCC[C@H](N2...,False,0.791306,0.738089,-0.599179,-0.372779,-0.930951,...,-0.02937,-0.345004,0.096796,-0.287745,-0.267073,2.065996,0.24381,0.488007,-0.436535,0.361004
313,B cells,Foretinib,LSM-1158,COc1cc2c(Oc3ccc(NC(=O)C4(C(=O)Nc5ccc(F)cc5)CC4...,False,-0.176871,0.320933,0.532543,0.570455,0.725136,...,-0.432899,0.01753,-0.110185,0.207273,-0.184022,0.858617,0.271281,-0.31117,-0.230995,-1.423259
243,NK cells,Prednisolone,LSM-24954,C[C@]12C=CC(=O)C=C1CC[C@@H]1[C@@H]2[C@@H](O)C[...,False,0.05465,0.401233,0.258088,-0.396746,0.355001,...,0.093196,0.051984,-0.114204,-0.170931,-0.600998,0.242447,0.009419,0.249872,-1.248088,-0.01852


## Training routines

In [16]:
data

HeteroData(
  cell={
    node_id=[6],
    x=[6, 158205],
  },
  sm={
    node_id=[146],
    x=[146, 768],
  },
  (cell, gene_exp, sm)={
    edge_index=[2, 876],
    edge_label=[876],
    edge_label_index=[2, 876],
  },
  (sm, rev_gene_exp, cell)={ edge_index=[2, 876] }
)

In [17]:
print(data['cell', 'gene_exp', 'sm'].edge_label_index.shape)
data['cell', 'gene_exp', 'sm'].edge_label_index

torch.Size([2, 876])


tensor([[  4,   1,   1,  ...,   5,   3,   0],
        [ 60,   2, 118,  ...,  95,  48,   2]])

In [18]:
train_mask = torch.zeros_like(data['cell', 'gene_exp', 'sm'].edge_label_index, dtype=torch.bool)
valid_mask = torch.zeros_like(data['cell', 'gene_exp', 'sm'].edge_label_index, dtype=torch.bool)

for i in range(data['cell', 'gene_exp', 'sm'].edge_label_index.shape[1]):
    if (data['cell', 'gene_exp', 'sm'].edge_label_index[0, i].item(), data['cell', 'gene_exp', 'sm'].edge_label_index[1, i].item()) in train_edge_indeces:
        train_mask[:, i] = True
    if (data['cell', 'gene_exp', 'sm'].edge_label_index[0, i].item(), data['cell', 'gene_exp', 'sm'].edge_label_index[1, i].item()) in valid_edge_indeces:
        valid_mask[:, i] = True

print(train_mask.sum()/2)
print(train_mask)

print(valid_mask.sum()/2)
print(valid_mask)

tensor(491.)
tensor([[False,  True, False,  ..., False,  True,  True],
        [False,  True, False,  ..., False,  True,  True]])
tensor(123.)
tensor([[False, False,  True,  ..., False, False, False],
        [False, False,  True,  ..., False, False, False]])


In [19]:
def train(data,
          train_edge_indeces,
          valid_edge_indeces,
          train_edge2feat,
          valid_edge2feat,
          train_mask,
          valid_mask,
          cell_emb_dim=1,
          sm_emb_dim=4,
          lr=0.01,
          hidden_channels=32,
          p_dropout=0.1,
          num_epochs=100,
          print_every=1):
    
    # clone data
    data = data.clone()
    
    # show device information
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create model (data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout)
    model = Model(data=data,
                  cell_emb_dim=cell_emb_dim,
                  sm_emb_dim=sm_emb_dim,
                  hidden_channels=hidden_channels,
                  p_dropout=p_dropout).to(device)
    
    #print("Number of trainable parameters: {}".format(sum(param.numel() for param in model.parameters() if param.requires_grad)))
    
    model = model.to(device)
    
    # print hypoparameters and model settings
    if print_every > 0:
        print("Device: {}, lr = {}, hidden_channels = {}, num_epochs = {}".format(device, lr, hidden_channels, num_epochs))
        print(model)

    # prepare data
    data = data.to(device)
    
    train_mask = train_mask.to(device)
    valid_mask = valid_mask.to(device)

    #n_train = len(train_edge_indeces)
    #n_valid = len(valid_edge_indeces)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # model with best performance
    best_epoch = 0
    best_valid_loss = float('inf')
    best_model = copy.deepcopy(model)
    
    for epoch in tqdm.tqdm(range(1, num_epochs+1)):
                
        optimizer.zero_grad()
        
        """
        Train
        """
        model.train()
        pred = model(data)  # (876, 18211)

        # ground truth
        target = []
        for i in range(data['cell', 'gene_exp', 'sm'].edge_label_index.shape[1]):
            cell_idx = data['cell', 'gene_exp', 'sm'].edge_label_index[0, i].item()
            sm_idx = data['cell', 'gene_exp', 'sm'].edge_label_index[1, i].item()
            if (cell_idx, sm_idx) in train_edge_indeces:
                target.append([train_edge2feat[(cell_idx, sm_idx)]])    
        target = torch.from_numpy(np.array(target)).float().squeeze().to(device)  # (491, 18211)
        
        total_train_loss = F.mse_loss(pred[train_mask[0]], target)

        total_train_loss.backward()
        optimizer.step()
        

        """
        Validation
        """
        model.eval()
        with torch.no_grad():
            
            # ground truth
            valid_target = []
            for i in range(data['cell', 'gene_exp', 'sm'].edge_label_index.shape[1]):
                cell_idx = data['cell', 'gene_exp', 'sm'].edge_label_index[0, i].item()
                sm_idx = data['cell', 'gene_exp', 'sm'].edge_label_index[1, i].item()
                if (cell_idx, sm_idx) in valid_edge_indeces:
                    valid_target.append([valid_edge2feat[(cell_idx, sm_idx)]])    
            valid_target = torch.from_numpy(np.array(valid_target)).float().squeeze()  # (123, 18211)
          
            total_valid_loss = mean_squared_error(valid_target.numpy(), pred[valid_mask[0]].detach().cpu().numpy())
               
            # check if model is better
            if total_valid_loss < best_valid_loss:
                best_epoch = epoch
                best_valid_loss = total_valid_loss
                best_model = copy.deepcopy(model)
                best_status = '<<<'
            else:
                best_status = ''

        if print_every > 0:
            if (epoch-1) % print_every == 0:
                print("Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f} {}".format(epoch, total_train_loss, total_valid_loss, best_status))
            elif epoch == num_epochs:
                print("Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f} {}".format(epoch, total_train_loss, total_valid_loss, best_status))

    if print_every > 0:
        print("\nBest epoch: {}, val_loss: {:.4f}".format(best_epoch, best_valid_loss))
    else:
        print("Cell_emb: {}, sm_emb: {}, lr = {}, hidden_channels = {}  >>>  ".format(cell_emb_dim, sm_emb_dim, lr, hidden_channels) +
              "Best epoch: {}/{}, val_loss: {:.4f}".format(best_epoch, num_epochs, best_valid_loss))
        
    print("---------------------------------")
    return best_model, best_valid_loss
        

In [20]:
def final_submission(data, model_name='test_model', model=None):

    assert model is not None
    
    data = data.clone()
    
    ## load test indeces
    id_map = pd.read_csv('../data/id_map.csv')
    assert id_map.shape[0] == 255
    #print("ID map loaded")
    #print(id_map.shape)

    ## load template
    res_df = pd.read_csv('../data/sample_submission.csv')
    assert res_df.shape[0] == 255
    #print("Submission template loaded")
    #print(res_df.shape)

    ## device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    ## load model
    model = copy.deepcopy(model).to(device)
    model.eval()
    
    data = data.to(device)
    
    ## Inference
    with torch.no_grad():
        pred = model(data).detach().cpu().numpy()  # (876, 18211)
        
    for i in tqdm.tqdm(range(id_map.shape[0])):
        cell_idx = c2index[id_map['cell_type'].iloc[i]]
        sm_idx = sm2index[id_map['sm_name'].iloc[i]]

        all_edge_indeces = data['cell', 'gene_exp', 'sm'].edge_label_index
        #print(all_edge_indeces)
        
        for j in range(all_edge_indeces.shape[1]):
            if all_edge_indeces[0, j].item() == cell_idx and all_edge_indeces[1, j].item() == sm_idx:
                #print("{} + {} | {:.4f}".format(id_map['cell_type'].iloc[i],
                #                                     id_map['sm_name'].iloc[i],
                #                                     pred[j, :].sum()))
                
                #print("{}".format(pred[j, :3]))
                res_df.iloc[i, 1:] = pred[j, :]
                break
                
    res_df.to_csv('../submissions/' + model_name + '.csv', index=False)
    print("File saved.")
    print(res_df.shape)

    # sanity check
    tmp = pd.read_csv('../submissions/' + model_name + '.csv')
    assert tmp.shape == res_df.shape

    del tmp


### Experiment 1 -- model_1022_exp3_29

In [21]:
from torch_geometric.nn import SAGEConv, to_hetero
from torch.nn import Linear, Dropout, Embedding
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """
    Compute node embedding
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()       
        self.conv1 = SAGEConv(-1, hidden_channels, normalize=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, normalize=True)
        self.dropout1 = Dropout(p=p_dropout)
        
    def forward(self, x, edge_index):
        x = self.dropout1(F.elu(self.conv1(x, edge_index)))        
        x = self.conv2(x, edge_index)
        return x

class MLP(torch.nn.Module):
    """
    Model that predicts the gene expression given node features at the two ends
    
    1020: Change inner layer from (1024, 1024) units to (256, 256) units 
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()

        self.dropout1 = Dropout(p=p_dropout)
        self.dropout2 = Dropout(p=p_dropout)
        
        self.layer1 = Linear(2*hidden_channels, 256)
        self.layer2 = Linear(256, 1024)
        self.layer3 = Linear(1024, 18211)
    
    def forward(self, x_cell, x_sm, edge_label_index):
        
        # Retrieve embeddings from the two nodes
        feat_cell = x_cell[edge_label_index[0]]  # dim = hidden_channels
        feat_sm = x_sm[edge_label_index[1]]  #  dim = hidden_channels

        # calculate the gene expression
        y = torch.cat((feat_cell, feat_sm), dim=1)  # dim = 2*hidden_channels
        
        y = F.relu(self.layer1(y))  # dim = 256
        y = self.dropout1(y)
        
        y = F.relu(self.layer2(y))  # dim = 1024
        y = self.dropout2(y)
        
        y = self.layer3(y)  # dim = 18211
        
        return y


class Model(torch.nn.Module):
    def __init__(self, data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout):
        super().__init__()

        # cell emb for each cell type
        self.cell_emb = Embedding(data["cell"].num_nodes, cell_emb_dim)
        
        """
        1027 - change from hidden_channels to 64
        """
        # cell emb leveraging control gene expression
        self.cell_lin = Linear(158205, 64)

        # sm emb for each sm type
        if sm_emb_dim > 0:
            self.sm_emb = Embedding(data["sm"].num_nodes, sm_emb_dim)
        else:
            self.sm_emb = None
        # sm emb leveraging SMILES embedding
        self.sm_lin = Linear(768, hidden_channels)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels=hidden_channels, p_dropout=p_dropout)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        
        self.MLP = MLP(hidden_channels=hidden_channels, p_dropout=p_dropout)
        
                
    def forward(self, data):
        # calculate node embedding
        #x_cell = self.cell_emb(data["cell"].node_id)
        x_cell = torch.cat((self.cell_emb(data["cell"].node_id), self.cell_lin(data["cell"].x)), dim=1)

        if self.sm_emb:
            x_sm = torch.cat((self.sm_emb(data["sm"].node_id), self.sm_lin(data["sm"].x)), dim=1)
        else:
            x_sm = self.sm_lin(data["sm"].x)
            
        
        x_dict = {
          "cell": x_cell,
          "sm": x_sm,
        } 

        # compute final node embedding
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        # Use node embedding to predict gene expression
        pred = self.MLP(
            x_dict["cell"],
            x_dict["sm"],
            data["cell", "gene_exp", "sm"].edge_label_index,
        )
        return pred

In [22]:
cur_model = None
cur_valid_loss = float('inf')

cur_sub_index = 0

for _ in range(20):

    cur_sub_index += 1

    tmp_model, tmp_valid_loss = train(data,
                                      train_edge_indeces,
                                      valid_edge_indeces,
                                      train_edge2feat,
                                      valid_edge2feat,
                                      train_mask,
                                      valid_mask,
                                      cell_emb_dim=2,
                                      sm_emb_dim=32,
                                      lr=5e-4,
                                      hidden_channels=8,
                                      p_dropout=0.2,
                                      num_epochs=2000,
                                      print_every=-1)
    
    if tmp_valid_loss >= 5.0:
        continue
    
    """
    Save all models
    """
    cur_valid_loss = tmp_valid_loss
    cur_model = tmp_model
    cur_model_name = 'model_1028_exp1_' + str(cur_sub_index)
    print(cur_model_name)
    print(cur_model)
    final_submission(data=data, model_name=cur_model_name, model=cur_model)
    torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,
100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:25<00:00,  2.31it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1719/2000, val_loss: 4.9447
---------------------------------
model_1028_exp1_1
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_features

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:01<00:00,  4.12it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:14<00:00,  2.19it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1350/2000, val_loss: 5.6959
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:16<00:00,  2.18it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1673/2000, val_loss: 5.2964
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:16<00:00,  2.18it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1911/2000, val_loss: 5.0469
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:14<00:00,  2.19it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1570/2000, val_loss: 5.4692
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:21<00:00,  2.17it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1664/2000, val_loss: 4.7027
---------------------------------
model_1028_exp1_6
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_features

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:03<00:00,  4.02it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:17<00:00,  2.18it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1776/2000, val_loss: 5.2636
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:10<00:00,  2.20it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1529/2000, val_loss: 5.6438
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:13<00:00,  2.19it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1712/2000, val_loss: 5.3616
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:11<00:00,  2.19it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1945/2000, val_loss: 4.9682
---------------------------------
model_1028_exp1_10
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_feature

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:03<00:00,  4.00it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:13<00:00,  2.19it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1561/2000, val_loss: 5.3714
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:07<00:00,  2.20it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1832/2000, val_loss: 4.7360
---------------------------------
model_1028_exp1_12
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_feature

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:03<00:00,  4.02it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:02<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1927/2000, val_loss: 4.8150
---------------------------------
model_1028_exp1_13
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_feature

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:02<00:00,  4.06it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:06<00:00,  2.21it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1936/2000, val_loss: 4.4772
---------------------------------
model_1028_exp1_14
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_feature

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:02<00:00,  4.06it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:06<00:00,  2.21it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1842/2000, val_loss: 4.8585
---------------------------------
model_1028_exp1_15
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_feature

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:01<00:00,  4.12it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:16<00:00,  2.18it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1237/2000, val_loss: 5.6481
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:27<00:00,  2.16it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1873/2000, val_loss: 5.2409
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:21<00:00,  2.17it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1933/2000, val_loss: 5.2653
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:19<00:00,  2.18it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1711/2000, val_loss: 4.7735
---------------------------------
model_1028_exp1_19
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_feature

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:02<00:00,  4.10it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:23<00:00,  2.17it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1972/2000, val_loss: 6.0093
---------------------------------





### Experiment 2 -- model_1023_exp5_4

In [23]:
from torch_geometric.nn import SAGEConv, to_hetero
from torch.nn import Linear, Dropout, Embedding
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """
    Compute node embedding
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()       
        self.conv1 = SAGEConv(-1, hidden_channels, normalize=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, normalize=True)
        self.dropout1 = Dropout(p=p_dropout)
        
    def forward(self, x, edge_index):
        x = self.dropout1(F.elu(self.conv1(x, edge_index)))        
        x = self.conv2(x, edge_index)
        return x

class MLP(torch.nn.Module):
    """
    Model that predicts the gene expression given node features at the two ends
    
    1020: Change inner layer from (1024, 1024) units to (256, 256) units 
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()

        self.dropout1 = Dropout(p=p_dropout)
        self.dropout2 = Dropout(p=p_dropout)
        
        self.layer1 = Linear(2*hidden_channels, 256)
        self.layer2 = Linear(256, 1024)
        self.layer3 = Linear(1024, 18211)
    
    def forward(self, x_cell, x_sm, edge_label_index):
        
        # Retrieve embeddings from the two nodes
        feat_cell = x_cell[edge_label_index[0]]  # dim = hidden_channels
        feat_sm = x_sm[edge_label_index[1]]  #  dim = hidden_channels

        # calculate the gene expression
        y = torch.cat((feat_cell, feat_sm), dim=1)  # dim = 2*hidden_channels
        
        y = F.relu(self.layer1(y))  # dim = 256
        y = self.dropout1(y)
        
        y = F.relu(self.layer2(y))  # dim = 1024
        y = self.dropout2(y)
        
        y = self.layer3(y)  # dim = 18211
        
        return y


class Model(torch.nn.Module):
    def __init__(self, data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout):
        super().__init__()

        # cell emb for each cell type
        self.cell_emb = Embedding(data["cell"].num_nodes, cell_emb_dim)
        
        """
        1027 - change from hidden_channels to 64
        """
        # cell emb leveraging control gene expression
        self.cell_lin = Linear(158205, 64)

        # sm emb for each sm type
        if sm_emb_dim > 0:
            self.sm_emb = Embedding(data["sm"].num_nodes, sm_emb_dim)
        else:
            self.sm_emb = None
        # sm emb leveraging SMILES embedding
        self.sm_lin = Linear(768, hidden_channels)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels=hidden_channels, p_dropout=p_dropout)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        
        self.MLP = MLP(hidden_channels=hidden_channels, p_dropout=p_dropout)
        
                
    def forward(self, data):
        # calculate node embedding
        #x_cell = self.cell_emb(data["cell"].node_id)
        x_cell = torch.cat((self.cell_emb(data["cell"].node_id), self.cell_lin(data["cell"].x)), dim=1)
        
        
        if self.sm_emb:
            x_sm = torch.cat((self.sm_emb(data["sm"].node_id), self.sm_lin(data["sm"].x)), dim=1)
        else:
            x_sm = self.sm_lin(data["sm"].x)
            
        
        x_dict = {
          "cell": x_cell,
          "sm": x_sm,
        } 

        # compute final node embedding
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        # Use node embedding to predict gene expression
        pred = self.MLP(
            x_dict["cell"],
            x_dict["sm"],
            data["cell", "gene_exp", "sm"].edge_label_index,
        )
        return pred

In [24]:
cur_model = None
cur_valid_loss = float('inf')

cur_sub_index = 0

for _ in range(10):

    cur_sub_index += 1

    tmp_model, tmp_valid_loss = train(data,
                                      train_edge_indeces,
                                      valid_edge_indeces,
                                      train_edge2feat,
                                      valid_edge2feat,
                                      train_mask,
                                      valid_mask,
                                      cell_emb_dim=2,
                                      sm_emb_dim=32,
                                      lr=5e-4,
                                      hidden_channels=16,
                                      p_dropout=0.2,
                                      num_epochs=2000,
                                      print_every=-1)

    if tmp_valid_loss >= 5.0:
        continue
    
    """
    Save all models
    """
    cur_valid_loss = tmp_valid_loss
    cur_model = tmp_model
    cur_model_name = 'model_1028_exp2_' + str(cur_sub_index)
    print(cur_model_name)
    print(cur_model)
    final_submission(data=data, model_name=cur_model_name, model=cur_model)
    torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:05<00:00,  2.21it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 888/2000, val_loss: 5.3643
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:04<00:00,  2.21it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 939/2000, val_loss: 4.6904
---------------------------------
model_1028_exp2_2
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:02<00:00,  4.06it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:12<00:00,  2.19it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1955/2000, val_loss: 4.7159
---------------------------------
model_1028_exp2_3
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:03<00:00,  4.01it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:07<00:00,  2.21it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1161/2000, val_loss: 5.3731
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:14<00:00,  2.19it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1113/2000, val_loss: 4.5782
---------------------------------
model_1028_exp2_5
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:03<00:00,  4.03it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:30<00:00,  2.15it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1276/2000, val_loss: 5.2302
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:21<00:00,  2.17it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1739/2000, val_loss: 4.8652
---------------------------------
model_1028_exp2_7
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:01<00:00,  4.11it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:12<00:00,  2.19it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1078/2000, val_loss: 4.7972
---------------------------------
model_1028_exp2_8
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:03<00:00,  4.00it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:01<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1841/2000, val_loss: 4.4979
---------------------------------
model_1028_exp2_9
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:03<00:00,  4.03it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:56<00:00,  2.23it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1991/2000, val_loss: 4.6483
---------------------------------
model_1028_exp2_10
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:03<00:00,  4.04it/s]


File saved.
(255, 18212)


### Experiment 3 -- model_1020_exp3

In [25]:
from torch_geometric.nn import SAGEConv, to_hetero
from torch.nn import Linear, Dropout, Embedding
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """
    Compute node embedding
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()       
        self.conv1 = SAGEConv(-1, hidden_channels, normalize=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, normalize=True)
        self.conv3 = SAGEConv(hidden_channels, hidden_channels, normalize=True)
        self.dropout1 = Dropout(p=p_dropout)
        self.dropout2 = Dropout(p=p_dropout)
        
    def forward(self, x, edge_index):
        x = self.dropout1(F.elu(self.conv1(x, edge_index)))        
        x = self.dropout2(F.elu(self.conv2(x, edge_index)))
        x = self.conv3(x, edge_index)
        return x

class MLP(torch.nn.Module):
    """
    Model that predicts the gene expression given node features at the two ends
    
    1020: Change inner layer from (1024, 1024) units to (256, 256) units 
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()

        self.dropout1 = Dropout(p=p_dropout)
        self.dropout2 = Dropout(p=p_dropout)
        
        self.layer1 = Linear(2*hidden_channels, 256)
        self.layer2 = Linear(256, 256)
        self.layer3 = Linear(256, 18211)
    
    def forward(self, x_cell, x_sm, edge_label_index):
        
        # Retrieve embeddings from the two nodes
        feat_cell = x_cell[edge_label_index[0]]  # dim = hidden_channels
        feat_sm = x_sm[edge_label_index[1]]  #  dim = hidden_channels

        # calculate the gene expression
        y = torch.cat((feat_cell, feat_sm), dim=1)  # dim = 2*hidden_channels
        
        y = F.relu(self.layer1(y))  # dim = 256
        y = self.dropout1(y)
        
        y = F.relu(self.layer2(y))  # dim = 256
        y = self.dropout2(y)
        
        y = self.layer3(y)  # dim = 18211
        
        return y


class Model(torch.nn.Module):
    def __init__(self, data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout):
        super().__init__()

        # cell emb for each cell type
        self.cell_emb = Embedding(data["cell"].num_nodes, cell_emb_dim)

        
        """
        1027 - change from hidden_channels to 64
        """
        # cell emb leveraging control gene expression
        self.cell_lin = Linear(158205, 64)
        

        # sm emb for each sm type
        if sm_emb_dim > 0:
            self.sm_emb = Embedding(data["sm"].num_nodes, sm_emb_dim)
        else:
            self.sm_emb = None
        # sm emb leveraging SMILES embedding
        self.sm_lin = Linear(768, hidden_channels)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels=hidden_channels, p_dropout=p_dropout)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        
        self.MLP = MLP(hidden_channels=hidden_channels, p_dropout=p_dropout)
        
                
    def forward(self, data):
        # calculate node embedding
        #x_cell = self.cell_emb(data["cell"].node_id)
        x_cell = torch.cat((self.cell_emb(data["cell"].node_id), self.cell_lin(data["cell"].x)), dim=1)

        if self.sm_emb:
            x_sm = torch.cat((self.sm_emb(data["sm"].node_id), self.sm_lin(data["sm"].x)), dim=1)
        else:
            x_sm = self.sm_lin(data["sm"].x)
            
        
        x_dict = {
          "cell": x_cell,
          "sm": x_sm,
        } 

        # compute final node embedding
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        # Use node embedding to predict gene expression
        pred = self.MLP(
            x_dict["cell"],
            x_dict["sm"],
            data["cell", "gene_exp", "sm"].edge_label_index,
        )
        return pred

In [26]:
cur_model = None
cur_valid_loss = float('inf')

cur_sub_index = 0

for _ in range(20):

    cur_sub_index += 1

    tmp_model, tmp_valid_loss = train(data,
                                      train_edge_indeces,
                                      valid_edge_indeces,
                                      train_edge2feat,
                                      valid_edge2feat,
                                      train_mask,
                                      valid_mask,
                                      cell_emb_dim=2,
                                      sm_emb_dim=32,
                                      lr=5e-4,
                                      hidden_channels=8,
                                      p_dropout=0.2,
                                      num_epochs=2000,
                                      print_every=-1)
    
    if tmp_valid_loss >= 5.0:
        continue
    
    """
    Save all models
    """
    cur_valid_loss = tmp_valid_loss
    cur_model = tmp_model
    cur_model_name = 'model_1028_exp3_' + str(cur_sub_index)
    print(cur_model_name)
    print(cur_model)
    final_submission(data=data, model_name=cur_model_name, model=cur_model)
    torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:05<00:00,  2.21it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 629/2000, val_loss: 5.9420
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:15<00:00,  2.34it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 944/2000, val_loss: 6.2729
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [13:50<00:00,  2.41it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 498/2000, val_loss: 6.4474
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:22<00:00,  2.32it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 763/2000, val_loss: 5.9063
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:30<00:00,  2.30it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1636/2000, val_loss: 6.4632
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:28<00:00,  2.30it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 869/2000, val_loss: 6.3737
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:25<00:00,  2.31it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 604/2000, val_loss: 6.3499
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:26<00:00,  2.31it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 668/2000, val_loss: 6.0288
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:44<00:00,  2.26it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 588/2000, val_loss: 6.1188
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:43<00:00,  2.26it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 721/2000, val_loss: 6.0472
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:49<00:00,  2.25it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 744/2000, val_loss: 6.6446
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:48<00:00,  2.25it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 912/2000, val_loss: 6.5027
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:50<00:00,  2.24it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 855/2000, val_loss: 6.4116
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:45<00:00,  2.26it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1831/2000, val_loss: 6.5011
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:48<00:00,  2.25it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1348/2000, val_loss: 5.8858
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:47<00:00,  2.25it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 849/2000, val_loss: 6.1120
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:50<00:00,  2.25it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 801/2000, val_loss: 6.3791
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:48<00:00,  2.25it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 780/2000, val_loss: 6.3831
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:52<00:00,  2.24it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 677/2000, val_loss: 6.1685
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:46<00:00,  2.26it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 674/2000, val_loss: 6.4222
---------------------------------





### Experiment 4 -- model_1022_exp3_29

Change dropout to 0.1 from experiment 1

In [27]:
from torch_geometric.nn import SAGEConv, to_hetero
from torch.nn import Linear, Dropout, Embedding
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """
    Compute node embedding
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()       
        self.conv1 = SAGEConv(-1, hidden_channels, normalize=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, normalize=True)
        self.dropout1 = Dropout(p=p_dropout)
        
    def forward(self, x, edge_index):
        x = self.dropout1(F.elu(self.conv1(x, edge_index)))        
        x = self.conv2(x, edge_index)
        return x

class MLP(torch.nn.Module):
    """
    Model that predicts the gene expression given node features at the two ends
    
    1020: Change inner layer from (1024, 1024) units to (256, 256) units 
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()

        self.dropout1 = Dropout(p=p_dropout)
        self.dropout2 = Dropout(p=p_dropout)
        
        self.layer1 = Linear(2*hidden_channels, 256)
        self.layer2 = Linear(256, 1024)
        self.layer3 = Linear(1024, 18211)
    
    def forward(self, x_cell, x_sm, edge_label_index):
        
        # Retrieve embeddings from the two nodes
        feat_cell = x_cell[edge_label_index[0]]  # dim = hidden_channels
        feat_sm = x_sm[edge_label_index[1]]  #  dim = hidden_channels

        # calculate the gene expression
        y = torch.cat((feat_cell, feat_sm), dim=1)  # dim = 2*hidden_channels
        
        y = F.relu(self.layer1(y))  # dim = 256
        y = self.dropout1(y)
        
        y = F.relu(self.layer2(y))  # dim = 1024
        y = self.dropout2(y)
        
        y = self.layer3(y)  # dim = 18211
        
        return y


class Model(torch.nn.Module):
    def __init__(self, data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout):
        super().__init__()

        # cell emb for each cell type
        self.cell_emb = Embedding(data["cell"].num_nodes, cell_emb_dim)
        
        """
        1027 - change from hidden_channels to 64
        """
        # cell emb leveraging control gene expression
        self.cell_lin = Linear(158205, 64)

        # sm emb for each sm type
        if sm_emb_dim > 0:
            self.sm_emb = Embedding(data["sm"].num_nodes, sm_emb_dim)
        else:
            self.sm_emb = None
        # sm emb leveraging SMILES embedding
        self.sm_lin = Linear(768, hidden_channels)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels=hidden_channels, p_dropout=p_dropout)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        
        self.MLP = MLP(hidden_channels=hidden_channels, p_dropout=p_dropout)
        
                
    def forward(self, data):
        # calculate node embedding
        #x_cell = self.cell_emb(data["cell"].node_id)
        x_cell = torch.cat((self.cell_emb(data["cell"].node_id), self.cell_lin(data["cell"].x)), dim=1)

        if self.sm_emb:
            x_sm = torch.cat((self.sm_emb(data["sm"].node_id), self.sm_lin(data["sm"].x)), dim=1)
        else:
            x_sm = self.sm_lin(data["sm"].x)
            
        
        x_dict = {
          "cell": x_cell,
          "sm": x_sm,
        } 

        # compute final node embedding
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        # Use node embedding to predict gene expression
        pred = self.MLP(
            x_dict["cell"],
            x_dict["sm"],
            data["cell", "gene_exp", "sm"].edge_label_index,
        )
        return pred

In [28]:
cur_model = None
cur_valid_loss = float('inf')

cur_sub_index = 0

for _ in range(20):

    cur_sub_index += 1

    tmp_model, tmp_valid_loss = train(data,
                                      train_edge_indeces,
                                      valid_edge_indeces,
                                      train_edge2feat,
                                      valid_edge2feat,
                                      train_mask,
                                      valid_mask,
                                      cell_emb_dim=2,
                                      sm_emb_dim=32,
                                      lr=5e-4,
                                      hidden_channels=8,
                                      p_dropout=0.1,
                                      num_epochs=2000,
                                      print_every=-1)
    
    if tmp_valid_loss >= 5.0:
        continue
    
    """
    Save all models
    """
    cur_valid_loss = tmp_valid_loss
    cur_model = tmp_model
    cur_model_name = 'model_1028_exp4_' + str(cur_sub_index)
    print(cur_model_name)
    print(cur_model)
    final_submission(data=data, model_name=cur_model_name, model=cur_model)
    torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:16<00:00,  2.33it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1193/2000, val_loss: 5.2763
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:13<00:00,  2.34it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1968/2000, val_loss: 5.2195
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:17<00:00,  2.33it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1486/2000, val_loss: 5.0494
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:15<00:00,  2.34it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1515/2000, val_loss: 5.1751
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:12<00:00,  2.35it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1685/2000, val_loss: 5.2330
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:19<00:00,  2.33it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1429/2000, val_loss: 4.8843
---------------------------------
model_1028_exp4_6
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.1, inplace=False)
      (sm): Dropout(p=0.1, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_features

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:02<00:00,  4.09it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:22<00:00,  2.32it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1944/2000, val_loss: 5.1289
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:22<00:00,  2.32it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1076/2000, val_loss: 5.1560
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:29<00:00,  2.30it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1586/2000, val_loss: 5.2080
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:30<00:00,  2.30it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1761/2000, val_loss: 4.9489
---------------------------------
model_1028_exp4_10
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.1, inplace=False)
      (sm): Dropout(p=0.1, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_feature

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:01<00:00,  4.18it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:25<00:00,  2.31it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1777/2000, val_loss: 5.3123
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:21<00:00,  2.32it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1131/2000, val_loss: 4.7959
---------------------------------
model_1028_exp4_12
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.1, inplace=False)
      (sm): Dropout(p=0.1, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_feature

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:03<00:00,  4.03it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:32<00:00,  2.29it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1191/2000, val_loss: 5.2924
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:28<00:00,  2.30it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1998/2000, val_loss: 4.6771
---------------------------------
model_1028_exp4_14
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.1, inplace=False)
      (sm): Dropout(p=0.1, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_feature

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:01<00:00,  4.13it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:22<00:00,  2.32it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 991/2000, val_loss: 5.4500
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:22<00:00,  2.32it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1976/2000, val_loss: 5.3703
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:25<00:00,  2.31it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1528/2000, val_loss: 4.8200
---------------------------------
model_1028_exp4_17
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.1, inplace=False)
      (sm): Dropout(p=0.1, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_feature

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:01<00:00,  4.17it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:25<00:00,  2.31it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1357/2000, val_loss: 4.8715
---------------------------------
model_1028_exp4_18
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.1, inplace=False)
      (sm): Dropout(p=0.1, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_feature

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:01<00:00,  4.15it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:26<00:00,  2.31it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1969/2000, val_loss: 5.1329
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:23<00:00,  2.32it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1557/2000, val_loss: 4.6986
---------------------------------
model_1028_exp4_20
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.1, inplace=False)
      (sm): Dropout(p=0.1, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (layer1): Linear(in_features=16, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_feature

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:01<00:00,  4.12it/s]


File saved.
(255, 18212)


### Experiment 5

In [29]:
from torch_geometric.nn import SAGEConv, to_hetero
from torch.nn import Linear, Dropout, Embedding
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """
    Compute node embedding
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()       
        self.conv1 = SAGEConv(-1, hidden_channels, normalize=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, normalize=True)
        self.dropout1 = Dropout(p=p_dropout)
        
    def forward(self, x, edge_index):
        x = self.dropout1(F.elu(self.conv1(x, edge_index)))        
        x = self.conv2(x, edge_index)
        return x

class MLP(torch.nn.Module):
    """
    Model that predicts the gene expression given node features at the two ends
    
    1020: Change inner layer from (1024, 1024) units to (256, 256) units 
    1029: Use a single 2048 dense layer
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()

        self.dropout1 = Dropout(p=p_dropout)
        
        self.layer1 = Linear(2*hidden_channels, 2048)
        self.layer2 = Linear(2048, 18211)
    
    def forward(self, x_cell, x_sm, edge_label_index):
        
        # Retrieve embeddings from the two nodes
        feat_cell = x_cell[edge_label_index[0]]  # dim = hidden_channels
        feat_sm = x_sm[edge_label_index[1]]  #  dim = hidden_channels

        # calculate the gene expression
        y = torch.cat((feat_cell, feat_sm), dim=1)  # dim = 2*hidden_channels
        
        y = F.relu(self.layer1(y))  # dim = 2048
        y = self.dropout1(y)
        
        y = self.layer2(y)  # dim = 18211
        
        return y


class Model(torch.nn.Module):
    def __init__(self, data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout):
        super().__init__()

        # cell emb for each cell type
        self.cell_emb = Embedding(data["cell"].num_nodes, cell_emb_dim)
        
        """
        1027 - change from hidden_channels to 64
        """
        # cell emb leveraging control gene expression
        self.cell_lin = Linear(158205, 64)

        # sm emb for each sm type
        if sm_emb_dim > 0:
            self.sm_emb = Embedding(data["sm"].num_nodes, sm_emb_dim)
        else:
            self.sm_emb = None
        # sm emb leveraging SMILES embedding
        self.sm_lin = Linear(768, hidden_channels)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels=hidden_channels, p_dropout=p_dropout)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        
        self.MLP = MLP(hidden_channels=hidden_channels, p_dropout=p_dropout)
        
                
    def forward(self, data):
        # calculate node embedding
        #x_cell = self.cell_emb(data["cell"].node_id)
        x_cell = torch.cat((self.cell_emb(data["cell"].node_id), self.cell_lin(data["cell"].x)), dim=1)

        if self.sm_emb:
            x_sm = torch.cat((self.sm_emb(data["sm"].node_id), self.sm_lin(data["sm"].x)), dim=1)
        else:
            x_sm = self.sm_lin(data["sm"].x)
            
        
        x_dict = {
          "cell": x_cell,
          "sm": x_sm,
        } 

        # compute final node embedding
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        # Use node embedding to predict gene expression
        pred = self.MLP(
            x_dict["cell"],
            x_dict["sm"],
            data["cell", "gene_exp", "sm"].edge_label_index,
        )
        return pred

In [30]:
cur_model = None
cur_valid_loss = float('inf')

cur_sub_index = 0

for _ in range(20):

    cur_sub_index += 1

    tmp_model, tmp_valid_loss = train(data,
                                      train_edge_indeces,
                                      valid_edge_indeces,
                                      train_edge2feat,
                                      valid_edge2feat,
                                      train_mask,
                                      valid_mask,
                                      cell_emb_dim=2,
                                      sm_emb_dim=32,
                                      lr=5e-4,
                                      hidden_channels=8,
                                      p_dropout=0.1,
                                      num_epochs=2000,
                                      print_every=-1)
    
    if tmp_valid_loss >= 5.0:
        continue
    
    """
    Save all models
    """
    cur_valid_loss = tmp_valid_loss
    cur_model = tmp_model
    cur_model_name = 'model_1028_exp5_' + str(cur_sub_index)
    print(cur_model_name)
    print(cur_model)
    final_submission(data=data, model_name=cur_model_name, model=cur_model)
    torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:02<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1614/2000, val_loss: 6.1035
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:59<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1934/2000, val_loss: 5.3862
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:02<00:00,  2.21it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1796/2000, val_loss: 5.1239
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:59<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1916/2000, val_loss: 5.4370
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:02<00:00,  2.21it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1940/2000, val_loss: 5.4158
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:02<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1287/2000, val_loss: 5.7643
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:00<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1996/2000, val_loss: 5.1126
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:01<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1815/2000, val_loss: 5.4179
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:01<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1999/2000, val_loss: 5.2043
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:58<00:00,  2.23it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1307/2000, val_loss: 5.1913
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:59<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1789/2000, val_loss: 4.7409
---------------------------------
model_1028_exp5_11
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=158205, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=8, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 8, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.1, inplace=False)
      (sm): Dropout(p=0.1, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(8, 8, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(8, 8, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.1, inplace=False)
    (layer1): Linear(in_features=16, out_features=2048, bias=True)
    (layer2): Linear(in_features=2048, out_features=18211, bias=True)
  )
)


100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:01<00:00,  4.12it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:00<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1654/2000, val_loss: 5.5162
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:00<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1551/2000, val_loss: 5.6127
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:03<00:00,  2.21it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1634/2000, val_loss: 5.6513
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:02<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1093/2000, val_loss: 5.9193
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:01<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1673/2000, val_loss: 6.0254
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:02<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1977/2000, val_loss: 5.3910
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:02<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1564/2000, val_loss: 5.1927
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:01<00:00,  2.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1893/2000, val_loss: 5.6357
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:13<00:00,  2.19it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1476/2000, val_loss: 5.3265
---------------------------------





### Experiment 6

dropout from 0.1 to 0.2

In [31]:
cur_model = None
cur_valid_loss = float('inf')

cur_sub_index = 0

for _ in range(20):

    cur_sub_index += 1

    tmp_model, tmp_valid_loss = train(data,
                                      train_edge_indeces,
                                      valid_edge_indeces,
                                      train_edge2feat,
                                      valid_edge2feat,
                                      train_mask,
                                      valid_mask,
                                      cell_emb_dim=2,
                                      sm_emb_dim=32,
                                      lr=5e-4,
                                      hidden_channels=8,
                                      p_dropout=0.2,
                                      num_epochs=2000,
                                      print_every=-1)
    
    if tmp_valid_loss >= 5.0:
        continue
    
    """
    Save all models
    """
    cur_valid_loss = tmp_valid_loss
    cur_model = tmp_model
    cur_model_name = 'model_1028_exp6_' + str(cur_sub_index)
    print(cur_model_name)
    print(cur_model)
    final_submission(data=data, model_name=cur_model_name, model=cur_model)
    torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:11<00:00,  2.19it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1340/2000, val_loss: 5.4348
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:51<00:00,  2.24it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1637/2000, val_loss: 5.8212
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [14:31<00:00,  2.30it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1717/2000, val_loss: 5.2926
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:36<00:00,  2.14it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1251/2000, val_loss: 5.6421
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:52<00:00,  2.10it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1757/2000, val_loss: 5.6326
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:55<00:00,  2.09it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1823/2000, val_loss: 6.0930
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:54<00:00,  2.09it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1225/2000, val_loss: 5.5031
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:53<00:00,  2.10it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1604/2000, val_loss: 5.3857
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:52<00:00,  2.10it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1599/2000, val_loss: 5.6482
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:46<00:00,  2.11it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 716/2000, val_loss: 6.3894
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:43<00:00,  2.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 589/2000, val_loss: 6.4321
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:47<00:00,  2.11it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1373/2000, val_loss: 5.5676
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:43<00:00,  2.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1581/2000, val_loss: 5.5209
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:41<00:00,  2.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1849/2000, val_loss: 5.6324
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:42<00:00,  2.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1724/2000, val_loss: 6.4558
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:37<00:00,  2.13it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1967/2000, val_loss: 5.3454
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:43<00:00,  2.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1249/2000, val_loss: 5.5275
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:39<00:00,  2.13it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1803/2000, val_loss: 5.1799
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:39<00:00,  2.13it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1908/2000, val_loss: 5.7750
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [15:38<00:00,  2.13it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 8  >>>  Best epoch: 1512/2000, val_loss: 5.3052
---------------------------------



