# PYG 1109

### Train best 3 models using cell embedding (normalized_counts) engineered from multiome data

### Increase dimension of cell embedding

### Put more penalty on B cells and M cells

In [1]:
import torch
torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device(type='cuda')

In [2]:
import pandas as pd
import numpy as np
import tqdm
import torch.nn.functional as F
import copy

from transformers import RobertaTokenizerFast, RobertaForMaskedLM, DataCollatorWithPadding

from torch_geometric.data import Data, HeteroData
import torch_geometric.transforms as T

from sklearn.metrics import mean_squared_error
import copy

import os

## Data

Number of (cell_type, sm_name) pairs: 614

Number of genes expressions to predict: 18211

Control molecules: `Belinostat` and `Dabrafenib`.

In [3]:
df = pd.read_parquet('../data/de_train.parquet')

print(df.shape)
df.head(10)

(614, 18216)


Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.10472,-0.077524,-1.625596,-0.144545,0.143555,...,-0.227781,-0.010752,-0.023881,0.674536,-0.453068,0.005164,-0.094959,0.034127,0.221377,0.368755
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.88438,0.371834,-0.081677,-0.498266,...,-0.494985,-0.303419,0.304955,-0.333905,-0.315516,-0.369626,-0.095079,0.70478,1.096702,-0.869887
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,...,-0.119422,-0.033608,-0.153123,0.183597,-0.555678,-1.494789,-0.21355,0.415768,0.078439,-0.259365
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,...,0.451679,0.704643,0.015468,-0.103868,0.865027,0.189114,0.2247,-0.048233,0.216139,-0.085024
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,...,0.758474,0.510762,0.607401,-0.123059,0.214366,0.487838,-0.819775,0.112365,-0.122193,0.676629
5,T cells CD4+,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,0.618061,0.180148,0.590015,0.035658,0.034297,...,-1.002997,0.177514,0.591768,-0.4124,-0.011633,-0.044739,0.213627,0.186406,-1.459477,1.164084
6,T cells CD8+,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,-0.148131,0.091336,-0.097212,1.225601,0.045787,...,-0.228645,0.091248,-0.581765,0.405682,-0.034414,0.296333,0.170313,0.332022,-0.532363,0.134475
7,T regulatory cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,0.561473,0.149415,3.22492,3.517419,1.361175,...,-0.156127,0.766118,0.702836,0.949482,0.757482,1.163063,1.532419,-0.399292,-2.412165,0.478977
8,B cells,Idelalisib,LSM-1205,CC[C@H](Nc1ncnc2[nH]cnc12)c1nc2cccc(F)c2c(=O)n...,False,0.394173,-0.153824,0.178232,0.566241,0.391377,...,-1.052302,-1.176587,-1.220291,-0.278944,-0.095066,1.10179,0.061803,1.406335,-0.264996,-0.119743
9,Myeloid cells,Idelalisib,LSM-1205,CC[C@H](Nc1ncnc2[nH]cnc12)c1nc2cccc(F)c2c(=O)n...,False,0.025146,0.316388,1.366885,1.301593,2.317619,...,-0.902546,-1.445523,0.794772,0.428973,0.605834,0.271988,0.492231,0.354721,1.471559,-0.259483


In [4]:
# node index for cell types
c2index = dict()
cur = 0
for x in df['cell_type'].unique():
    assert x not in c2index
    c2index[x] = cur
    cur += 1
print(c2index)
del cur

{'NK cells': 0, 'T cells CD4+': 1, 'T cells CD8+': 2, 'T regulatory cells': 3, 'B cells': 4, 'Myeloid cells': 5}


In [5]:
df['sm_name'].nunique()

146

In [6]:
# Assign the list of unique small molecules
sm_list = df['sm_name'].unique().tolist()
len(sm_list)

146

In [7]:
# small molecule embedding
tokenizer = RobertaTokenizerFast.from_pretrained("entropy/roberta_zinc_480m", max_len=128)
model = RobertaForMaskedLM.from_pretrained('entropy/roberta_zinc_480m')
collator = DataCollatorWithPadding(tokenizer, padding=True, return_tensors='pt')

# Extract SMILES embeddings
sm2emb = dict()
sm2index = dict()
cur = 0

for i in range(df.shape[0]):
    sm_name = df['sm_name'].iloc[i]
    smiles = df['SMILES'].iloc[i]



    if sm_name not in sm2emb:

        if True: # if sm_name not in ['Belinostat', 'Dabrafenib']:
            sm2index[sm_name] = cur
            cur += 1
        
        inputs = collator(tokenizer([smiles]))
        outputs = model(**inputs, output_hidden_states=True)
        full_embeddings = outputs[1][-1].detach()
        mask = inputs['attention_mask']
        node_embedding = ((full_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1))
        node_embedding = torch.squeeze(node_embedding)
        sm2emb[sm_name] = node_embedding

del cur

print(sm2index)
print(len(sm2emb))

You're using a RobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'Clotrimazole': 0, 'Mometasone Furoate': 1, 'Idelalisib': 2, 'Vandetanib': 3, 'Bosutinib': 4, 'Ceritinib': 5, 'Lamivudine': 6, 'Crizotinib': 7, 'Cabozantinib': 8, 'Flutamide': 9, 'Dasatinib': 10, 'Selumetinib': 11, 'Trametinib': 12, 'ABT-199 (GDC-0199)': 13, 'Oxybenzone': 14, 'Vorinostat': 15, 'Raloxifene': 16, 'Linagliptin': 17, 'Lapatinib': 18, 'Canertinib': 19, 'Disulfiram': 20, 'Vardenafil': 21, 'Palbociclib': 22, 'Ricolinostat': 23, 'Dabrafenib': 24, 'Proscillaridin A;Proscillaridin-A': 25, 'IN1451': 26, 'Ixabepilone': 27, 'CEP-18770 (Delanzomib)': 28, 'RG7112': 29, 'MK-5108': 30, 'Resminostat': 31, 'IMD-0354': 32, 'Alvocidib': 33, 'LY2090314': 34, 'Methotrexate': 35, 'LDN 193189': 36, 'Tacalcitol': 37, 'Colchicine': 38, 'R428': 39, 'TL_HRAS26': 40, 'BMS-387032': 41, 'CGP 60474': 42, 'TIE2 Kinase Inhibitor': 43, 'PD-0325901': 44, 'Isoniazid': 45, 'GSK-1070916': 46, 'Masitinib': 47, 'Saracatinib': 48, 'CC-401': 49, 'Decitabine': 50, 'Ketoconazole': 51, 'HYDROXYUREA': 52, 'BAY 61-3

## Cell embedding

In [8]:
# cell embedding
c2emb_df = pd.read_csv('./cell_emb_normcnt.csv')
c2emb_df.head()

Unnamed: 0,cell_type,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,B cells,0.286709,0.369192,0.174579,0.04098,0.0,0.004556,0.160233,0.189103,0.508853,...,0.289054,0.218039,0.146898,0.010962,0.020037,0.055597,1.41121,0.731649,0.443267,1.138623
1,T cells CD4+,0.345561,0.276483,0.845896,0.397363,0.0,0.004368,0.207783,0.155854,0.671557,...,0.301208,0.231177,0.164825,0.013874,0.028687,0.059835,1.596086,0.810385,0.741314,1.573785
2,NK cells,0.198593,0.193077,2.525363,1.3214,0.0,0.006829,0.171567,0.138596,0.623483,...,0.264486,0.229794,0.147945,0.006202,0.010658,0.04588,1.375194,0.478963,0.794888,1.440577
3,Myeloid cells,0.426668,0.309088,0.842713,0.05328,0.0,0.057028,0.291705,0.120023,0.786532,...,0.578372,0.530244,0.217056,0.023516,0.014606,0.034963,1.783012,1.38516,1.725515,2.837625
4,T regulatory cells,0.244087,0.15209,0.146131,0.053442,0.0,0.0,0.135021,0.064394,0.65629,...,0.242616,0.191995,0.124759,0.048676,0.023823,0.053844,1.144881,0.64771,0.464262,1.315307


## Heterogenous Graph

#### Cell nodes (c_nodes)

There are 6 nodes for each unique cell types.

#### Small molecule nodes (sm_nodes)

There are 146 unique nodes for each small molecules.

SMILES embedding are used as node embeddings (node_emb_dim = 768).

In [9]:
# initialize graph
data = HeteroData()

# number of unique cell nodes and sm nodes
data['cell'].node_id = torch.arange(6)
data['sm'].node_id = torch.arange(146) # 144 or 146

## initialize cell node emb
cell_node_x = []

for cell_type, cell_index in c2index.items():
    #cell_node_x.append([cell_index])  # turn of cell embedding
    cell_node_x.append(c2emb_df[c2emb_df['cell_type']==cell_type].values[0][1:].tolist())
    

cell_node_x = torch.from_numpy(np.array(cell_node_x)).float()
data['cell'].x = cell_node_x

#print(data['cell'].x
assert data['cell'].x.shape == (6, 18211)

## initialize sm node emb
data['sm'].x = torch.from_numpy(np.array([sm2emb[y].tolist() for y in sm_list])).float()
assert data['sm'].x.shape == (146, 768)

## Edges

In [10]:
def add_edges(data):
    
    edge_index_c2sm = []

    for cell_index in range(6):
        for sm_index in range(146):
            edge_index_c2sm.append([cell_index, sm_index])

    print("Adding {} edges".format(len(edge_index_c2sm)))
    data['cell', 'gene_exp', 'sm'].edge_index = torch.tensor(edge_index_c2sm).t().contiguous().long()    
        
    # convert into bidirecitonal
    data = T.ToUndirected()(data)

    # data split (all for training)
    transform = T.RandomLinkSplit(
        num_val=0.0,
        num_test=0.0,
        is_undirected=True,
        disjoint_train_ratio=0.0,  # was 0.3
        neg_sampling_ratio=0.0,  # was 2.0
        add_negative_train_samples=False,
        edge_types=('cell', 'gene_exp', 'sm'),
        rev_edge_types=('sm', 'rev_gene_exp', 'cell'), 
    )
    
    train_data, _, _ = transform(data)
    
    return train_data

In [11]:
data = add_edges(data)

print(data)

Adding 876 edges
HeteroData(
  cell={
    node_id=[6],
    x=[6, 18211],
  },
  sm={
    node_id=[146],
    x=[146, 768],
  },
  (cell, gene_exp, sm)={
    edge_index=[2, 876],
    edge_label=[876],
    edge_label_index=[2, 876],
  },
  (sm, rev_gene_exp, cell)={ edge_index=[2, 876] }
)


## Train/Validation/Test Splits

In [12]:
df.head()

Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.10472,-0.077524,-1.625596,-0.144545,0.143555,...,-0.227781,-0.010752,-0.023881,0.674536,-0.453068,0.005164,-0.094959,0.034127,0.221377,0.368755
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.88438,0.371834,-0.081677,-0.498266,...,-0.494985,-0.303419,0.304955,-0.333905,-0.315516,-0.369626,-0.095079,0.70478,1.096702,-0.869887
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,...,-0.119422,-0.033608,-0.153123,0.183597,-0.555678,-1.494789,-0.21355,0.415768,0.078439,-0.259365
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,...,0.451679,0.704643,0.015468,-0.103868,0.865027,0.189114,0.2247,-0.048233,0.216139,-0.085024
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,...,0.758474,0.510762,0.607401,-0.123059,0.214366,0.487838,-0.819775,0.112365,-0.122193,0.676629


In [13]:
from sklearn.model_selection import train_test_split

train_df, valid_df = train_test_split(df, test_size=0.2, random_state=14, shuffle=True, stratify=df['cell_type'])


# all edge indeces and gene expressions for training
train_edge_indeces = set()
train_edge2feat = dict()

for i in range(train_df.shape[0]):
    cell_idx = c2index[train_df.iloc[i]['cell_type']]
    sm_idx = sm2index[train_df.iloc[i]['sm_name']]

    train_edge_indeces.add((cell_idx, sm_idx))
    
    train_edge2feat[(cell_idx, sm_idx)] = train_df.iloc[i, 5:].values.tolist()
    assert len(train_edge2feat[(cell_idx, sm_idx)]) == 18211


# all edge indeces and gene expressions for validation
valid_edge_indeces = set()
valid_edge2feat = dict()

for i in range(valid_df.shape[0]):
    cell_idx = c2index[valid_df.iloc[i]['cell_type']]
    sm_idx = sm2index[valid_df.iloc[i]['sm_name']]

    valid_edge_indeces.add((cell_idx, sm_idx))
    
    valid_edge2feat[(cell_idx, sm_idx)] = valid_df.iloc[i, 5:].values.tolist()
    assert len(valid_edge2feat[(cell_idx, sm_idx)]) == 18211

In [14]:
print(train_df.shape)
train_df.head()

(491, 18216)


Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
373,NK cells,MLN 2238,LSM-4944,CC(C)C[C@H](NC(=O)CNC(=O)c1cc(Cl)ccc1Cl)B(O)O,False,9.497701,5.478697,0.129686,-2.328127,5.732667,...,1.047863,3.069138,3.280055,5.284819,5.440957,5.702465,2.486294,1.484651,-1.385248,0.063061
391,T cells CD8+,I-BET151,LSM-6335,COc1cc2c(cc1-c1c(C)noc1C)ncc1[nH]c(=O)n([C@H](...,False,-0.01225,0.174611,-0.099221,-0.134166,0.143319,...,-0.301498,-0.534493,-0.06671,0.647332,0.153865,0.478001,0.884914,0.558577,-1.453534,1.611454
568,T cells CD8+,AMD-070 (hydrochloride),LSM-45591,NCCCCN(Cc1nc2ccccc2[nH]1)[C@H]1CCCc2cccnc21,False,0.909501,-0.011053,0.622651,-0.315069,0.142047,...,0.042186,0.001675,0.263338,0.077565,-0.028096,-0.006699,-0.175771,-0.451758,0.163957,-0.320038
15,T cells CD4+,Vandetanib,LSM-1199,COc1cc2c(Nc3ccc(Br)cc3F)ncnc2cc1OCC1CCN(C)CC1,False,0.106623,0.006949,0.424472,-0.124439,1.123565,...,-0.053448,0.072833,-0.42144,0.067708,1.21589,0.226569,0.084715,-0.668755,0.798338,-0.060044
190,T cells CD4+,PD-0325901,LSM-1101,O=C(NOC[C@H](O)CO)c1ccc(F)c(F)c1Nc1ccc(I)cc1F,False,-0.237122,-0.694187,0.39587,0.210085,-0.081793,...,-1.153235,-0.131176,0.49708,0.604248,0.313291,-0.730704,1.744382,1.603138,-0.143792,-0.194503


In [15]:
print(valid_df.shape)
valid_df.head()

(123, 18216)


Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
342,T regulatory cells,Tivantinib,LSM-1131,O=C1NC(=O)[C@@H](c2cn3c4c(cccc24)CCC3)[C@@H]1c...,False,0.193286,0.665367,-0.142059,0.047231,-0.062032,...,0.086186,-0.119378,0.071182,-0.69475,0.225487,0.352343,-0.127136,0.499296,0.90434,0.266353
244,T cells CD4+,Prednisolone,LSM-24954,C[C@]12C=CC(=O)C=C1CC[C@@H]1[C@@H]2[C@@H](O)C[...,False,0.134272,-0.051087,0.516691,-0.39171,-0.411336,...,-0.842898,-0.443645,1.415605,-0.82326,-0.084823,0.622433,0.845683,1.280901,-2.258758,0.635675
343,NK cells,CEP-37440,LSM-45849,CNC(=O)c1ccccc1Nc1nc(Nc2ccc3c(c2OC)CCC[C@H](N2...,False,0.791306,0.738089,-0.599179,-0.372779,-0.930951,...,-0.02937,-0.345004,0.096796,-0.287745,-0.267073,2.065996,0.24381,0.488007,-0.436535,0.361004
313,B cells,Foretinib,LSM-1158,COc1cc2c(Oc3ccc(NC(=O)C4(C(=O)Nc5ccc(F)cc5)CC4...,False,-0.176871,0.320933,0.532543,0.570455,0.725136,...,-0.432899,0.01753,-0.110185,0.207273,-0.184022,0.858617,0.271281,-0.31117,-0.230995,-1.423259
243,NK cells,Prednisolone,LSM-24954,C[C@]12C=CC(=O)C=C1CC[C@@H]1[C@@H]2[C@@H](O)C[...,False,0.05465,0.401233,0.258088,-0.396746,0.355001,...,0.093196,0.051984,-0.114204,-0.170931,-0.600998,0.242447,0.009419,0.249872,-1.248088,-0.01852


In [16]:
train_df[train_df['cell_type'].isin(['B cells', 'Myeloid cells'])].shape

(27, 18216)

In [17]:
valid_df[valid_df['cell_type'].isin(['B cells', 'Myeloid cells'])].shape

(7, 18216)

## Training routines

In [18]:
data

HeteroData(
  cell={
    node_id=[6],
    x=[6, 18211],
  },
  sm={
    node_id=[146],
    x=[146, 768],
  },
  (cell, gene_exp, sm)={
    edge_index=[2, 876],
    edge_label=[876],
    edge_label_index=[2, 876],
  },
  (sm, rev_gene_exp, cell)={ edge_index=[2, 876] }
)

In [19]:
print(data['cell', 'gene_exp', 'sm'].edge_label_index.shape)
data['cell', 'gene_exp', 'sm'].edge_label_index

torch.Size([2, 876])


tensor([[  4,   4,   0,  ...,   1,   2,   1],
        [132, 129, 118,  ...,  85, 133,  22]])

In [20]:
train_mask = torch.zeros_like(data['cell', 'gene_exp', 'sm'].edge_label_index, dtype=torch.bool)
valid_mask = torch.zeros_like(data['cell', 'gene_exp', 'sm'].edge_label_index, dtype=torch.bool)

for i in range(data['cell', 'gene_exp', 'sm'].edge_label_index.shape[1]):
    if (data['cell', 'gene_exp', 'sm'].edge_label_index[0, i].item(), data['cell', 'gene_exp', 'sm'].edge_label_index[1, i].item()) in train_edge_indeces:
        train_mask[:, i] = True
    if (data['cell', 'gene_exp', 'sm'].edge_label_index[0, i].item(), data['cell', 'gene_exp', 'sm'].edge_label_index[1, i].item()) in valid_edge_indeces:
        valid_mask[:, i] = True

print(train_mask.sum()/2)
print(train_mask)

print(valid_mask.sum()/2)
print(valid_mask)

tensor(491.)
tensor([[ True, False,  True,  ...,  True,  True,  True],
        [ True, False,  True,  ...,  True,  True,  True]])
tensor(123.)
tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])


In [21]:
BM_mask = torch.zeros_like(data['cell', 'gene_exp', 'sm'].edge_label_index, dtype=torch.bool)

for i in range(data['cell', 'gene_exp', 'sm'].edge_label_index.shape[1]):
    if data['cell', 'gene_exp', 'sm'].edge_label_index[0, i].item() in [c2index['B cells'], c2index['Myeloid cells']]:
        BM_mask[:, i] = True
        
print(BM_mask.sum()/2)
print(BM_mask)

tensor(292.)
tensor([[ True,  True, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False]])


In [22]:
def train(data,
          train_edge_indeces,
          valid_edge_indeces,
          train_edge2feat,
          valid_edge2feat,
          train_mask,
          valid_mask,
          BM_mask,
          alpha=1.0,
          cell_emb_dim=1,
          sm_emb_dim=4,
          lr=0.01,
          hidden_channels=32,
          p_dropout=0.1,
          num_epochs=100,
          print_every=1):
    
    # clone data
    data = data.clone()
    
    # show device information
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create model (data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout)
    model = Model(data=data,
                  cell_emb_dim=cell_emb_dim,
                  sm_emb_dim=sm_emb_dim,
                  hidden_channels=hidden_channels,
                  p_dropout=p_dropout).to(device)
    
    model = model.to(device)
    
    # print hypoparameters and model settings
    if print_every > 0:
        print("Device: {}, lr = {}, hidden_channels = {}, num_epochs = {}".format(device, lr, hidden_channels, num_epochs))
        print(model)

    # prepare data
    data = data.to(device)
    
    train_BM_mask = np.logical_and(train_mask, BM_mask).to(device).bool()
    valid_BM_mask = np.logical_and(valid_mask, BM_mask).to(device).bool()
    
    train_mask = train_mask.to(device)
    valid_mask = valid_mask.to(device)
    BM_mask = BM_mask.to(device)
    
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # model with best performance
    best_epoch = 0
    best_valid_loss = float('inf')
    best_model = copy.deepcopy(model)
    
    # ground truths for training and validation
    with torch.no_grad():
        target = []
        traget_BM = []
        
        valid_target = []
        valid_target_BM = []
        
        for i in range(data['cell', 'gene_exp', 'sm'].edge_label_index.shape[1]):
            cell_idx = data['cell', 'gene_exp', 'sm'].edge_label_index[0, i].item()
            sm_idx = data['cell', 'gene_exp', 'sm'].edge_label_index[1, i].item()

            # training
            if (cell_idx, sm_idx) in train_edge_indeces:
                target.append([train_edge2feat[(cell_idx, sm_idx)]])
                
                # training + kaggle targets
                if cell_idx in [c2index['B cells'], c2index['Myeloid cells']]:
                    traget_BM.append([train_edge2feat[(cell_idx, sm_idx)]])
            
            # validation
            elif (cell_idx, sm_idx) in valid_edge_indeces:
                valid_target.append([valid_edge2feat[(cell_idx, sm_idx)]])
                
                # validation + kaggle targets
                if cell_idx in [c2index['B cells'], c2index['Myeloid cells']]:
                    valid_target_BM.append([valid_edge2feat[(cell_idx, sm_idx)]])    
                    

        target = torch.from_numpy(np.array(target)).float().squeeze().to(device)  # (491, 18211)
        target_BM = torch.from_numpy(np.array(traget_BM)).float().squeeze().to(device)  # (27, 18211)

        valid_target = torch.from_numpy(np.array(valid_target)).float().squeeze()  # (123, 18211)
        valid_target_BM = torch.from_numpy(np.array(valid_target_BM)).float().squeeze()  # (7, 18211)
        
        
        N_train = target.shape[0]
        N_train_BM = target_BM.shape[0]
        
        N_valid = valid_target.shape[0]
        N_valid_BM = valid_target_BM.shape[0]
    
    
    for epoch in tqdm.tqdm(range(1, num_epochs+1)):
                
        optimizer.zero_grad()
        
        """
        Train
        """
        model.train()
        pred = model(data)  # (876, 18211)
               
        # 1107 edited
        total_train_loss = (alpha-1.0)*(N_train_BM/N_train)*F.mse_loss(pred[train_BM_mask[0]], target_BM) + F.mse_loss(pred[train_mask[0]], target)

        total_train_loss.backward()
        optimizer.step()
        
        """
        Validation
        """
        model.eval()
        with torch.no_grad():
            
            total_valid_loss = (alpha-1.0)*(N_valid_BM/N_valid)*mean_squared_error(valid_target_BM.numpy(), pred[valid_BM_mask[0]].detach().cpu().numpy()) + \
                               mean_squared_error(valid_target.numpy(), pred[valid_mask[0]].detach().cpu().numpy())
               
            # check if model is better
            if total_valid_loss < best_valid_loss:
                best_epoch = epoch
                best_valid_loss = total_valid_loss
                best_model = copy.deepcopy(model)
                best_status = '<<<'
            else:
                best_status = ''

        if print_every > 0:
            if (epoch-1) % print_every == 0:
                print("Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f} {}".format(epoch, total_train_loss, total_valid_loss, best_status))
            elif epoch == num_epochs:
                print("Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f} {}".format(epoch, total_train_loss, total_valid_loss, best_status))

    if print_every > 0:
        print("\nBest epoch: {}, val_loss: {:.4f}".format(best_epoch, best_valid_loss))
    else:
        print("Cell_emb: {}, sm_emb: {}, lr = {}, hidden_channels = {}  >>>  ".format(cell_emb_dim, sm_emb_dim, lr, hidden_channels) +
              "Best epoch: {}/{}, val_loss: {:.4f}".format(best_epoch, num_epochs, best_valid_loss))
        
    print("---------------------------------")
    return best_model, best_valid_loss

In [23]:
def final_submission(data, model_name='test_model', model=None):

    assert model is not None
    
    data = data.clone()
    
    ## load test indeces
    id_map = pd.read_csv('../data/id_map.csv')
    assert id_map.shape[0] == 255
    #print("ID map loaded")
    #print(id_map.shape)

    ## load template
    res_df = pd.read_csv('../data/sample_submission.csv')
    assert res_df.shape[0] == 255
    #print("Submission template loaded")
    #print(res_df.shape)

    ## device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    ## load model
    model = copy.deepcopy(model).to(device)
    model.eval()
    
    data = data.to(device)
    
    ## Inference
    with torch.no_grad():
        pred = model(data).detach().cpu().numpy()  # (876, 18211)
        
    for i in tqdm.tqdm(range(id_map.shape[0])):
        cell_idx = c2index[id_map['cell_type'].iloc[i]]
        sm_idx = sm2index[id_map['sm_name'].iloc[i]]

        all_edge_indeces = data['cell', 'gene_exp', 'sm'].edge_label_index
        #print(all_edge_indeces)
        
        for j in range(all_edge_indeces.shape[1]):
            if all_edge_indeces[0, j].item() == cell_idx and all_edge_indeces[1, j].item() == sm_idx:
                #print("{} + {} | {:.4f}".format(id_map['cell_type'].iloc[i],
                #                                     id_map['sm_name'].iloc[i],
                #                                     pred[j, :].sum()))
                
                #print("{}".format(pred[j, :3]))
                res_df.iloc[i, 1:] = pred[j, :]
                break
                
    res_df.to_csv('../submissions/' + model_name + '.csv', index=False)
    print("File saved.")
    print(res_df.shape)

    # sanity check
    tmp = pd.read_csv('../submissions/' + model_name + '.csv')
    assert tmp.shape == res_df.shape

    del tmp


### Experiment 1

### Remove linear for sm embedding

In [28]:
from torch_geometric.nn import SAGEConv, to_hetero
from torch.nn import Linear, Dropout, Embedding
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """
    Compute node embedding
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()       
        self.conv1 = SAGEConv(-1, hidden_channels, normalize=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, normalize=True)
        self.dropout1 = Dropout(p=p_dropout)
        
    def forward(self, x, edge_index):
        x = self.dropout1(F.elu(self.conv1(x, edge_index)))        
        x = self.conv2(x, edge_index)
        return x

class MLP(torch.nn.Module):
    """
    Model that predicts the gene expression given node features at the two ends
    
    1020: Change inner layer from (1024, 1024) units to (256, 256) units 
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()

        self.dropout1 = Dropout(p=p_dropout)
        self.dropout2 = Dropout(p=p_dropout)
        
        self.layer1 = Linear(2*hidden_channels, 256)
        self.layer2 = Linear(256, 1024)
        self.layer3 = Linear(1024, 18211)
    
    def forward(self, x_cell, x_sm, edge_label_index):
        
        # Retrieve embeddings from the two nodes
        feat_cell = x_cell[edge_label_index[0]]  # dim = hidden_channels
        feat_sm = x_sm[edge_label_index[1]]  #  dim = hidden_channels

        # calculate the gene expression
        y = torch.cat((feat_cell, feat_sm), dim=1)  # dim = 2*hidden_channels
        
        y = F.relu(self.layer1(y))  # dim = 256
        y = self.dropout1(y)
        
        y = F.relu(self.layer2(y))  # dim = 1024
        y = self.dropout2(y)
        
        y = self.layer3(y)  # dim = 18211
        
        return y


class Model(torch.nn.Module):
    def __init__(self, data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout):
        super().__init__()

        # cell emb for each cell type
        self.cell_emb = Embedding(data["cell"].num_nodes, cell_emb_dim)
        
        """
        1027 - change from hidden_channels to 64
        """
        # cell emb leveraging control gene expression
        #self.cell_lin = Linear(18211, hidden_channels)
        self.cell_lin = Linear(18211, 64)

        # sm emb for each sm type
        if sm_emb_dim > 0:
            self.sm_emb = Embedding(data["sm"].num_nodes, sm_emb_dim)
        else:
            self.sm_emb = None
        # sm emb leveraging SMILES embedding
        self.sm_lin = Linear(768, hidden_channels)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels=hidden_channels, p_dropout=p_dropout)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        
        self.MLP = MLP(hidden_channels=hidden_channels, p_dropout=p_dropout)
        
                
    def forward(self, data):
        # calculate node embedding
        #x_cell = self.cell_emb(data["cell"].node_id)
        x_cell = torch.cat((self.cell_emb(data["cell"].node_id), self.cell_lin(data["cell"].x)), dim=1)
        
        
        if self.sm_emb:
            #x_sm = torch.cat((self.sm_emb(data["sm"].node_id), self.sm_lin(data["sm"].x)), dim=1)
            x_sm = torch.cat((self.sm_emb(data["sm"].node_id), data["sm"].x), dim=1)
        else:
            #x_sm = self.sm_lin(data["sm"].x)
            x_sm = data["sm"].x
            
        
        x_dict = {
          "cell": x_cell,
          "sm": x_sm,
        } 

        # compute final node embedding
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        # Use node embedding to predict gene expression
        pred = self.MLP(
            x_dict["cell"],
            x_dict["sm"],
            data["cell", "gene_exp", "sm"].edge_label_index,
        )
        return pred

In [25]:
cur_model = None
cur_valid_loss = float('inf')

cur_sub_index = 0

for _ in range(20):

    cur_sub_index += 1

    tmp_model, tmp_valid_loss = train(data,
                                      train_edge_indeces,
                                      valid_edge_indeces,
                                      train_edge2feat,
                                      valid_edge2feat,
                                      train_mask,
                                      valid_mask,
                                      BM_mask,
                                      alpha=1.0,
                                      cell_emb_dim=2,
                                      sm_emb_dim=32,
                                      lr=5e-4,
                                      hidden_channels=16,
                                      p_dropout=0.2,
                                      num_epochs=2000,
                                      print_every=-1)
    continue
    
    #if tmp_valid_loss >= 5.0:
    #    continue
    #
    #"""
    #Save all models
    #"""
    #cur_valid_loss = tmp_valid_loss
    #cur_model = tmp_model
    #cur_model_name = 'model_1108_exp1_' + str(cur_sub_index)
    #print(cur_model_name)
    #print(cur_model)
    #final_submission(data=data, model_name=cur_model_name, model=cur_model)
    #torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,
100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.38it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1565/2000, val_loss: 4.8852
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.83it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1645/2000, val_loss: 5.3835
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.75it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1951/2000, val_loss: 4.3765
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.18it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1440/2000, val_loss: 4.8766
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.78it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1563/2000, val_loss: 4.2816
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.48it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1963/2000, val_loss: 4.8542
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.59it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1744/2000, val_loss: 3.8263
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.63it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1499/2000, val_loss: 4.3608
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.74it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1791/2000, val_loss: 5.9578
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.95it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1447/2000, val_loss: 4.5252
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.69it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1299/2000, val_loss: 4.5934
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.47it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1422/2000, val_loss: 4.5658
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.77it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1623/2000, val_loss: 4.7644
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 37.02it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1828/2000, val_loss: 4.9671
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.33it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1910/2000, val_loss: 4.3054
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.69it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1886/2000, val_loss: 5.1543
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.97it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1872/2000, val_loss: 4.8996
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:51<00:00, 38.57it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1832/2000, val_loss: 5.4667
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.32it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1643/2000, val_loss: 4.9693
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:51<00:00, 38.58it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1777/2000, val_loss: 4.7588
---------------------------------





In [34]:
cur_model = None
cur_valid_loss = float('inf')

while not cur_model:
    
    cur_sub_index = 0
    
    for _ in range(20):

        cur_sub_index += 1

        tmp_model, tmp_valid_loss = train(data,
                                          train_edge_indeces,
                                          valid_edge_indeces,
                                          train_edge2feat,
                                          valid_edge2feat,
                                          train_mask,
                                          valid_mask,
                                          BM_mask,
                                          alpha=1.0,
                                          cell_emb_dim=2,
                                          sm_emb_dim=32,
                                          lr=5e-4,
                                          hidden_channels=16,
                                          p_dropout=0.2,
                                          num_epochs=2000,
                                          print_every=-1)

        if tmp_valid_loss >= 4.5:
            continue

        """
        Save all models
        """
        cur_valid_loss = tmp_valid_loss
        cur_model = tmp_model
        cur_model_name = 'model_1109_exp1_' + str(cur_sub_index)
        print(cur_model_name)
        print(cur_model)
        final_submission(data=data, model_name=cur_model_name, model=cur_model)
        torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.62it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1883/2000, val_loss: 4.4932
---------------------------------
model_1109_exp1_1
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:06<00:00,  3.85it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.49it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1690/2000, val_loss: 4.9816
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.00it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1524/2000, val_loss: 4.3397
---------------------------------
model_1109_exp1_3
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:05<00:00,  3.90it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.78it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1841/2000, val_loss: 4.7325
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.76it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1426/2000, val_loss: 5.0451
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.56it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1874/2000, val_loss: 4.7461
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.58it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1745/2000, val_loss: 4.5548
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.80it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1704/2000, val_loss: 4.8080
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.05it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1515/2000, val_loss: 4.1930
---------------------------------
model_1109_exp1_9
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:04<00:00,  3.95it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.21it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1889/2000, val_loss: 4.6687
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.68it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1273/2000, val_loss: 4.1099
---------------------------------
model_1109_exp1_11
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:04<00:00,  3.96it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.83it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1355/2000, val_loss: 5.1304
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.41it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1849/2000, val_loss: 4.3245
---------------------------------
model_1109_exp1_13
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:04<00:00,  3.95it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.53it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1814/2000, val_loss: 4.8510
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.65it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1525/2000, val_loss: 5.6233
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.05it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1278/2000, val_loss: 4.4652
---------------------------------
model_1109_exp1_16
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:05<00:00,  3.87it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.35it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1974/2000, val_loss: 5.0168
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.90it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1680/2000, val_loss: 4.6003
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.96it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1588/2000, val_loss: 4.5086
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.13it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1900/2000, val_loss: 5.0885
---------------------------------





### Experiment 2

### Remove linear transform of cell embedding

In [39]:
from torch_geometric.nn import SAGEConv, to_hetero
from torch.nn import Linear, Dropout, Embedding
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """
    Compute node embedding
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()       
        self.conv1 = SAGEConv(-1, hidden_channels, normalize=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, normalize=True)
        self.dropout1 = Dropout(p=p_dropout)
        
    def forward(self, x, edge_index):
        x = self.dropout1(F.elu(self.conv1(x, edge_index)))        
        x = self.conv2(x, edge_index)
        return x

class MLP(torch.nn.Module):
    """
    Model that predicts the gene expression given node features at the two ends
    
    1020: Change inner layer from (1024, 1024) units to (256, 256) units 
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()

        self.dropout1 = Dropout(p=p_dropout)
        self.dropout2 = Dropout(p=p_dropout)
        
        self.layer1 = Linear(2*hidden_channels, 256)
        self.layer2 = Linear(256, 1024)
        self.layer3 = Linear(1024, 18211)
    
    def forward(self, x_cell, x_sm, edge_label_index):
        
        # Retrieve embeddings from the two nodes
        feat_cell = x_cell[edge_label_index[0]]  # dim = hidden_channels
        feat_sm = x_sm[edge_label_index[1]]  #  dim = hidden_channels

        # calculate the gene expression
        y = torch.cat((feat_cell, feat_sm), dim=1)  # dim = 2*hidden_channels
        
        y = F.relu(self.layer1(y))  # dim = 256
        y = self.dropout1(y)
        
        y = F.relu(self.layer2(y))  # dim = 1024
        y = self.dropout2(y)
        
        y = self.layer3(y)  # dim = 18211
        
        return y


class Model(torch.nn.Module):
    def __init__(self, data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout):
        super().__init__()

        # cell emb for each cell type
        self.cell_emb = Embedding(data["cell"].num_nodes, cell_emb_dim)
        
        """
        1027 - change from hidden_channels to 64
        """
        # cell emb leveraging control gene expression
        #self.cell_lin = Linear(18211, hidden_channels)
        self.cell_lin = Linear(18211, 64)

        # sm emb for each sm type
        if sm_emb_dim > 0:
            self.sm_emb = Embedding(data["sm"].num_nodes, sm_emb_dim)
        else:
            self.sm_emb = None
        # sm emb leveraging SMILES embedding
        self.sm_lin = Linear(768, hidden_channels)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels=hidden_channels, p_dropout=p_dropout)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        
        self.MLP = MLP(hidden_channels=hidden_channels, p_dropout=p_dropout)
        
                
    def forward(self, data):
        # calculate node embedding
        #x_cell = self.cell_emb(data["cell"].node_id)
        #x_cell = torch.cat((self.cell_emb(data["cell"].node_id), self.cell_lin(data["cell"].x)), dim=1)
        x_cell = torch.cat((self.cell_emb(data["cell"].node_id), data["cell"].x), dim=1)
        
        if self.sm_emb:
            #x_sm = torch.cat((self.sm_emb(data["sm"].node_id), self.sm_lin(data["sm"].x)), dim=1)
            x_sm = torch.cat((self.sm_emb(data["sm"].node_id), data["sm"].x), dim=1)
        else:
            #x_sm = self.sm_lin(data["sm"].x)
            x_sm = data["sm"].x
            
        
        x_dict = {
          "cell": x_cell,
          "sm": x_sm,
        } 

        # compute final node embedding
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        # Use node embedding to predict gene expression
        pred = self.MLP(
            x_dict["cell"],
            x_dict["sm"],
            data["cell", "gene_exp", "sm"].edge_label_index,
        )
        return pred

In [27]:
cur_model = None
cur_valid_loss = float('inf')

cur_sub_index = 0

for _ in range(20):

    cur_sub_index += 1

    tmp_model, tmp_valid_loss = train(data,
                                      train_edge_indeces,
                                      valid_edge_indeces,
                                      train_edge2feat,
                                      valid_edge2feat,
                                      train_mask,
                                      valid_mask,
                                      BM_mask,
                                      alpha=1.0,
                                      cell_emb_dim=2,
                                      sm_emb_dim=32,
                                      lr=5e-4,
                                      hidden_channels=16,
                                      p_dropout=0.2,
                                      num_epochs=2000,
                                      print_every=-1)
    continue
    
    #if tmp_valid_loss >= 5.0:
    #    continue
    #
    #"""
    #Save all models
    #"""
    #cur_valid_loss = tmp_valid_loss
    #cur_model = tmp_model
    #cur_model_name = 'model_1108_exp1_' + str(cur_sub_index)
    #print(cur_model_name)
    #print(cur_model)
    #final_submission(data=data, model_name=cur_model_name, model=cur_model)
    #torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.04it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1616/2000, val_loss: 4.3214
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.09it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1323/2000, val_loss: 5.2259
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.94it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1540/2000, val_loss: 5.8120
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.98it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1975/2000, val_loss: 4.4732
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.94it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1651/2000, val_loss: 5.8281
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 37.00it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1849/2000, val_loss: 4.6403
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1781/2000, val_loss: 4.4507
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.92it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1652/2000, val_loss: 4.2838
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.20it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1786/2000, val_loss: 4.7674
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 37.01it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 707/2000, val_loss: 5.4111
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1795/2000, val_loss: 4.6785
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.96it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1820/2000, val_loss: 4.9857
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.04it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1968/2000, val_loss: 5.7131
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.98it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1939/2000, val_loss: 4.7721
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.95it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1422/2000, val_loss: 5.0398
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.13it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1964/2000, val_loss: 5.4091
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.02it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1668/2000, val_loss: 4.0776
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.31it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1654/2000, val_loss: 4.4586
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.38it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1540/2000, val_loss: 5.3650
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.06it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1845/2000, val_loss: 4.8364
---------------------------------





In [41]:
cur_model = None
cur_valid_loss = float('inf')

while not cur_model:
    
    cur_sub_index = 0
    
    for _ in range(10):

        cur_sub_index += 1

        tmp_model, tmp_valid_loss = train(data,
                                          train_edge_indeces,
                                          valid_edge_indeces,
                                          train_edge2feat,
                                          valid_edge2feat,
                                          train_mask,
                                          valid_mask,
                                          BM_mask,
                                          alpha=1.0,
                                          cell_emb_dim=2,
                                          sm_emb_dim=32,
                                          lr=5e-4,
                                          hidden_channels=16,
                                          p_dropout=0.2,
                                          num_epochs=2000,
                                          print_every=-1)

        if tmp_valid_loss >= 4.2:
            continue

        """
        Save all models
        """
        cur_valid_loss = tmp_valid_loss
        cur_model = tmp_model
        cur_model_name = 'model_1109_exp2_' + str(cur_sub_index)
        print(cur_model_name)
        print(cur_model)
        final_submission(data=data, model_name=cur_model_name, model=cur_model)
        torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.01it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1803/2000, val_loss: 4.5691
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.06it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1710/2000, val_loss: 4.5364
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.47it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1998/2000, val_loss: 5.1655
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.59it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1659/2000, val_loss: 4.7066
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.74it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1820/2000, val_loss: 4.9576
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.27it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1103/2000, val_loss: 5.5636
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.24it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1572/2000, val_loss: 4.8629
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1933/2000, val_loss: 5.3124
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.41it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 907/2000, val_loss: 4.8154
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.17it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1876/2000, val_loss: 4.8574
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.54it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1867/2000, val_loss: 5.5034
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.17it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1969/2000, val_loss: 4.8766
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.14it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1988/2000, val_loss: 4.1385
---------------------------------
model_1109_exp2_3
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:05<00:00,  3.87it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.58it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1709/2000, val_loss: 5.2028
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.16it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1934/2000, val_loss: 5.1600
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.37it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1539/2000, val_loss: 4.2549
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1604/2000, val_loss: 4.9117
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.21it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1943/2000, val_loss: 4.2646
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.28it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1423/2000, val_loss: 4.5942
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:55<00:00, 36.10it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1814/2000, val_loss: 5.0985
---------------------------------





### Experiment 3 - Train original 0.594 model with more penalty on B and M cells

In [42]:
from torch_geometric.nn import SAGEConv, to_hetero
from torch.nn import Linear, Dropout, Embedding
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """
    Compute node embedding
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()       
        self.conv1 = SAGEConv(-1, hidden_channels, normalize=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, normalize=True)
        self.dropout1 = Dropout(p=p_dropout)
        
    def forward(self, x, edge_index):
        x = self.dropout1(F.elu(self.conv1(x, edge_index)))        
        x = self.conv2(x, edge_index)
        return x

class MLP(torch.nn.Module):
    """
    Model that predicts the gene expression given node features at the two ends
    
    1020: Change inner layer from (1024, 1024) units to (256, 256) units 
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()

        self.dropout1 = Dropout(p=p_dropout)
        self.dropout2 = Dropout(p=p_dropout)
        
        self.layer1 = Linear(2*hidden_channels, 256)
        self.layer2 = Linear(256, 1024)
        self.layer3 = Linear(1024, 18211)
    
    def forward(self, x_cell, x_sm, edge_label_index):
        
        # Retrieve embeddings from the two nodes
        feat_cell = x_cell[edge_label_index[0]]  # dim = hidden_channels
        feat_sm = x_sm[edge_label_index[1]]  #  dim = hidden_channels

        # calculate the gene expression
        y = torch.cat((feat_cell, feat_sm), dim=1)  # dim = 2*hidden_channels
        
        y = F.relu(self.layer1(y))  # dim = 256
        y = self.dropout1(y)
        
        y = F.relu(self.layer2(y))  # dim = 1024
        y = self.dropout2(y)
        
        y = self.layer3(y)  # dim = 18211
        
        return y


class Model(torch.nn.Module):
    def __init__(self, data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout):
        super().__init__()

        # cell emb for each cell type
        self.cell_emb = Embedding(data["cell"].num_nodes, cell_emb_dim)
        
        """
        1027 - change from hidden_channels to 64
        """
        # cell emb leveraging control gene expression
        #self.cell_lin = Linear(18211, hidden_channels)
        self.cell_lin = Linear(18211, 64)

        # sm emb for each sm type
        if sm_emb_dim > 0:
            self.sm_emb = Embedding(data["sm"].num_nodes, sm_emb_dim)
        else:
            self.sm_emb = None
        # sm emb leveraging SMILES embedding
        self.sm_lin = Linear(768, hidden_channels)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels=hidden_channels, p_dropout=p_dropout)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        
        self.MLP = MLP(hidden_channels=hidden_channels, p_dropout=p_dropout)
        
                
    def forward(self, data):
        # calculate node embedding
        #x_cell = self.cell_emb(data["cell"].node_id)
        x_cell = torch.cat((self.cell_emb(data["cell"].node_id), self.cell_lin(data["cell"].x)), dim=1)
        
        
        if self.sm_emb:
            x_sm = torch.cat((self.sm_emb(data["sm"].node_id), self.sm_lin(data["sm"].x)), dim=1)
        else:
            x_sm = self.sm_lin(data["sm"].x)
            
        
        x_dict = {
          "cell": x_cell,
          "sm": x_sm,
        } 

        # compute final node embedding
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        # Use node embedding to predict gene expression
        pred = self.MLP(
            x_dict["cell"],
            x_dict["sm"],
            data["cell", "gene_exp", "sm"].edge_label_index,
        )
        return pred

In [43]:
for _ in range(20):

    tmp_model, tmp_valid_loss = train(data,
                                      train_edge_indeces,
                                      valid_edge_indeces,
                                      train_edge2feat,
                                      valid_edge2feat,
                                      train_mask,
                                      valid_mask,
                                      BM_mask,
                                      alpha=2.0,
                                      cell_emb_dim=2,
                                      sm_emb_dim=32,
                                      lr=5e-4,
                                      hidden_channels=16,
                                      p_dropout=0.2,
                                      num_epochs=2000,
                                      print_every=-1)

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.58it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1544/2000, val_loss: 6.7359
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.35it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1821/2000, val_loss: 6.7074
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.78it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1476/2000, val_loss: 6.5126
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.19it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1169/2000, val_loss: 6.1828
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.40it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1962/2000, val_loss: 7.2022
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.11it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1579/2000, val_loss: 5.0741
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.33it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1461/2000, val_loss: 6.5177
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.28it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 932/2000, val_loss: 6.7868
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.69it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1030/2000, val_loss: 6.1327
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.48it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1995/2000, val_loss: 6.6351
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.35it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1063/2000, val_loss: 6.5582
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.42it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1064/2000, val_loss: 6.7036
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 37.00it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1250/2000, val_loss: 6.2669
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.51it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 885/2000, val_loss: 6.2008
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1475/2000, val_loss: 6.4772
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.71it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 887/2000, val_loss: 6.5740
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.17it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1294/2000, val_loss: 6.1899
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.46it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 936/2000, val_loss: 6.3893
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.39it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1819/2000, val_loss: 5.7901
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.17it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1561/2000, val_loss: 6.8156
---------------------------------





In [44]:
cur_model = None
cur_valid_loss = float('inf')

while not cur_model:
    
    cur_sub_index = 0
    
    for _ in range(10):

        cur_sub_index += 1

        tmp_model, tmp_valid_loss = train(data,
                                          train_edge_indeces,
                                          valid_edge_indeces,
                                          train_edge2feat,
                                          valid_edge2feat,
                                          train_mask,
                                          valid_mask,
                                          BM_mask,
                                          alpha=2.0,
                                          cell_emb_dim=2,
                                          sm_emb_dim=32,
                                          lr=5e-4,
                                          hidden_channels=16,
                                          p_dropout=0.2,
                                          num_epochs=2000,
                                          print_every=-1)

        if tmp_valid_loss >= 6.0:
            continue

        """
        Save all models
        """
        cur_valid_loss = tmp_valid_loss
        cur_model = tmp_model
        cur_model_name = 'model_1109_exp3_' + str(cur_sub_index)
        print(cur_model_name)
        print(cur_model)
        final_submission(data=data, model_name=cur_model_name, model=cur_model)
        torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.73it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1151/2000, val_loss: 6.3532
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.68it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1081/2000, val_loss: 6.3344
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.80it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 853/2000, val_loss: 6.2779
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.42it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1850/2000, val_loss: 6.2446
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.66it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1033/2000, val_loss: 6.0120
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.15it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 916/2000, val_loss: 6.4251
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.91it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 928/2000, val_loss: 6.8127
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:54<00:00, 36.78it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1062/2000, val_loss: 7.0809
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1095/2000, val_loss: 7.1483
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.53it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1751/2000, val_loss: 6.3693
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.97it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1626/2000, val_loss: 6.3620
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.47it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1136/2000, val_loss: 6.5522
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.11it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1167/2000, val_loss: 6.0058
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.77it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1174/2000, val_loss: 6.2694
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.84it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1604/2000, val_loss: 5.9093
---------------------------------
model_1109_exp3_5
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:05<00:00,  3.88it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.77it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1567/2000, val_loss: 6.2419
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.38it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1028/2000, val_loss: 6.2289
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.84it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1250/2000, val_loss: 6.0420
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.86it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1865/2000, val_loss: 6.1508
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.38it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1205/2000, val_loss: 6.0962
---------------------------------





### Experiment 4

In [45]:
from torch_geometric.nn import SAGEConv, to_hetero
from torch.nn import Linear, Dropout, Embedding
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """
    Compute node embedding
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()       
        self.conv1 = SAGEConv(-1, hidden_channels, normalize=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, normalize=True)
        self.dropout1 = Dropout(p=p_dropout)
        
    def forward(self, x, edge_index):
        x = self.dropout1(F.elu(self.conv1(x, edge_index)))        
        x = self.conv2(x, edge_index)
        return x

class MLP(torch.nn.Module):
    """
    Model that predicts the gene expression given node features at the two ends
    
    1020: Change inner layer from (1024, 1024) units to (256, 256) units 
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()

        self.dropout1 = Dropout(p=p_dropout)
        self.dropout2 = Dropout(p=p_dropout)
        
        self.layer1 = Linear(2*hidden_channels, 256)
        self.layer2 = Linear(256, 1024)
        self.layer3 = Linear(1024, 18211)
    
    def forward(self, x_cell, x_sm, edge_label_index):
        
        # Retrieve embeddings from the two nodes
        feat_cell = x_cell[edge_label_index[0]]  # dim = hidden_channels
        feat_sm = x_sm[edge_label_index[1]]  #  dim = hidden_channels

        # calculate the gene expression
        y = torch.cat((feat_cell, feat_sm), dim=1)  # dim = 2*hidden_channels
        
        y = F.relu(self.layer1(y))  # dim = 256
        y = self.dropout1(y)
        
        y = F.relu(self.layer2(y))  # dim = 1024
        y = self.dropout2(y)
        
        y = self.layer3(y)  # dim = 18211
        
        return y


class Model(torch.nn.Module):
    def __init__(self, data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout):
        super().__init__()

        # cell emb for each cell type
        self.cell_emb = Embedding(data["cell"].num_nodes, cell_emb_dim)
        
        """
        1027 - change from hidden_channels to 64
        1109 - change from 64 to 16
        """
        # cell emb leveraging control gene expression
        #self.cell_lin = Linear(18211, hidden_channels)
        self.cell_lin = Linear(18211, 16)

        # sm emb for each sm type
        if sm_emb_dim > 0:
            self.sm_emb = Embedding(data["sm"].num_nodes, sm_emb_dim)
        else:
            self.sm_emb = None
        # sm emb leveraging SMILES embedding
        self.sm_lin = Linear(768, hidden_channels)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels=hidden_channels, p_dropout=p_dropout)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        
        self.MLP = MLP(hidden_channels=hidden_channels, p_dropout=p_dropout)
        
                
    def forward(self, data):
        # calculate node embedding
        #x_cell = self.cell_emb(data["cell"].node_id)
        x_cell = torch.cat((self.cell_emb(data["cell"].node_id), self.cell_lin(data["cell"].x)), dim=1)
        
        
        if self.sm_emb:
            x_sm = torch.cat((self.sm_emb(data["sm"].node_id), self.sm_lin(data["sm"].x)), dim=1)
        else:
            x_sm = self.sm_lin(data["sm"].x)
            
        
        x_dict = {
          "cell": x_cell,
          "sm": x_sm,
        } 

        # compute final node embedding
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        # Use node embedding to predict gene expression
        pred = self.MLP(
            x_dict["cell"],
            x_dict["sm"],
            data["cell", "gene_exp", "sm"].edge_label_index,
        )
        return pred

In [46]:
cur_model = None
cur_valid_loss = float('inf')

while not cur_model:
    
    cur_sub_index = 0
    
    for _ in range(10):

        cur_sub_index += 1

        tmp_model, tmp_valid_loss = train(data,
                                          train_edge_indeces,
                                          valid_edge_indeces,
                                          train_edge2feat,
                                          valid_edge2feat,
                                          train_mask,
                                          valid_mask,
                                          BM_mask,
                                          alpha=1.0,
                                          cell_emb_dim=2,
                                          sm_emb_dim=32,
                                          lr=5e-4,
                                          hidden_channels=16,
                                          p_dropout=0.2,
                                          num_epochs=2000,
                                          print_every=-1)

        if tmp_valid_loss >= 4.5:
            continue

        """
        Save all models
        """
        cur_valid_loss = tmp_valid_loss
        cur_model = tmp_model
        cur_model_name = 'model_1109_exp4_' + str(cur_sub_index)
        print(cur_model_name)
        print(cur_model)
        final_submission(data=data, model_name=cur_model_name, model=cur_model)
        torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.47it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1508/2000, val_loss: 5.1483
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.96it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1847/2000, val_loss: 4.9446
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.13it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1255/2000, val_loss: 5.6857
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1230/2000, val_loss: 4.8977
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.39it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1510/2000, val_loss: 4.8706
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.14it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1954/2000, val_loss: 5.6161
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1120/2000, val_loss: 4.8119
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.30it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1912/2000, val_loss: 4.6099
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.92it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1936/2000, val_loss: 5.1243
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:53<00:00, 37.66it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1330/2000, val_loss: 4.7525
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.94it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1582/2000, val_loss: 4.4598
---------------------------------
model_1109_exp4_1
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=16, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:06<00:00,  3.82it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 37.87it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1292/2000, val_loss: 4.4957
---------------------------------
model_1109_exp4_2
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=16, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:04<00:00,  3.93it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.05it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1555/2000, val_loss: 4.3574
---------------------------------
model_1109_exp4_3
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=16, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:03<00:00,  4.00it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.03it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1591/2000, val_loss: 4.7448
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:51<00:00, 38.66it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1873/2000, val_loss: 4.5506
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.27it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1745/2000, val_loss: 5.5584
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.34it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1239/2000, val_loss: 5.3972
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:51<00:00, 38.55it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1438/2000, val_loss: 4.7318
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.21it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1361/2000, val_loss: 4.7301
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:52<00:00, 38.46it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1716/2000, val_loss: 6.0380
---------------------------------





### Experiment 5

In [49]:
from torch_geometric.nn import SAGEConv, to_hetero
from torch.nn import Linear, Dropout, Embedding
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """
    Compute node embedding
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()       
        self.conv1 = SAGEConv(-1, hidden_channels, normalize=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, normalize=True)
        self.dropout1 = Dropout(p=p_dropout)
        
    def forward(self, x, edge_index):
        x = self.dropout1(F.elu(self.conv1(x, edge_index)))        
        x = self.conv2(x, edge_index)
        return x

class MLP(torch.nn.Module):
    """
    Model that predicts the gene expression given node features at the two ends
    
    1020: Change inner layer from (1024, 1024) units to (256, 256) units 
    1020: Change inner layer from (256, 1024) units to (512, 2048) units 
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()

        self.dropout1 = Dropout(p=p_dropout)
        self.dropout2 = Dropout(p=p_dropout)
        
        self.layer1 = Linear(2*hidden_channels, 512)
        self.layer2 = Linear(512, 2048)
        self.layer3 = Linear(2048, 18211)
    
    def forward(self, x_cell, x_sm, edge_label_index):
        
        # Retrieve embeddings from the two nodes
        feat_cell = x_cell[edge_label_index[0]]  # dim = hidden_channels
        feat_sm = x_sm[edge_label_index[1]]  #  dim = hidden_channels

        # calculate the gene expression
        y = torch.cat((feat_cell, feat_sm), dim=1)  # dim = 2*hidden_channels
        
        y = F.relu(self.layer1(y))  # dim = 256
        y = self.dropout1(y)
        
        y = F.relu(self.layer2(y))  # dim = 1024
        y = self.dropout2(y)
        
        y = self.layer3(y)  # dim = 18211
        
        return y


class Model(torch.nn.Module):
    def __init__(self, data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout):
        super().__init__()

        # cell emb for each cell type
        self.cell_emb = Embedding(data["cell"].num_nodes, cell_emb_dim)
        
        """
        1027 - change from hidden_channels to 64
        """
        # cell emb leveraging control gene expression
        #self.cell_lin = Linear(18211, hidden_channels)
        self.cell_lin = Linear(18211, 64)

        # sm emb for each sm type
        if sm_emb_dim > 0:
            self.sm_emb = Embedding(data["sm"].num_nodes, sm_emb_dim)
        else:
            self.sm_emb = None
        # sm emb leveraging SMILES embedding
        self.sm_lin = Linear(768, hidden_channels)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels=hidden_channels, p_dropout=p_dropout)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        
        self.MLP = MLP(hidden_channels=hidden_channels, p_dropout=p_dropout)
        
                
    def forward(self, data):
        # calculate node embedding
        #x_cell = self.cell_emb(data["cell"].node_id)
        x_cell = torch.cat((self.cell_emb(data["cell"].node_id), self.cell_lin(data["cell"].x)), dim=1)
        
        
        if self.sm_emb:
            x_sm = torch.cat((self.sm_emb(data["sm"].node_id), self.sm_lin(data["sm"].x)), dim=1)
        else:
            x_sm = self.sm_lin(data["sm"].x)
            
        
        x_dict = {
          "cell": x_cell,
          "sm": x_sm,
        } 

        # compute final node embedding
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        # Use node embedding to predict gene expression
        pred = self.MLP(
            x_dict["cell"],
            x_dict["sm"],
            data["cell", "gene_exp", "sm"].edge_label_index,
        )
        return pred

In [50]:
cur_model = None
cur_valid_loss = float('inf')

while not cur_model:
    
    cur_sub_index = 0
    
    for _ in range(10):

        cur_sub_index += 1

        tmp_model, tmp_valid_loss = train(data,
                                          train_edge_indeces,
                                          valid_edge_indeces,
                                          train_edge2feat,
                                          valid_edge2feat,
                                          train_mask,
                                          valid_mask,
                                          BM_mask,
                                          alpha=1.0,
                                          cell_emb_dim=2,
                                          sm_emb_dim=32,
                                          lr=5e-4,
                                          hidden_channels=16,
                                          p_dropout=0.2,
                                          num_epochs=2000,
                                          print_every=-1)

        if tmp_valid_loss >= 4.5:
            continue

        """
        Save all models
        """
        cur_valid_loss = tmp_valid_loss
        cur_model = tmp_model
        cur_model_name = 'model_1109_exp5_' + str(cur_sub_index)
        print(cur_model_name)
        print(cur_model)
        final_submission(data=data, model_name=cur_model_name, model=cur_model)
        torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [01:26<00:00, 23.11it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1882/2000, val_loss: 5.2758
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [01:24<00:00, 23.69it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1491/2000, val_loss: 5.8024
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [01:24<00:00, 23.74it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1803/2000, val_loss: 4.1337
---------------------------------
model_1109_exp5_3
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=512, bias=True)
    (layer2): Linear(in_features=512, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:03<00:00,  4.00it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [01:24<00:00, 23.61it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1561/2000, val_loss: 4.6954
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [01:24<00:00, 23.73it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 676/2000, val_loss: 5.5468
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [01:26<00:00, 23.17it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1315/2000, val_loss: 4.4636
---------------------------------
model_1109_exp5_6
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=512, bias=True)
    (layer2): Linear(in_features=512, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:04<00:00,  3.98it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [01:24<00:00, 23.58it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1927/2000, val_loss: 4.3343
---------------------------------
model_1109_exp5_7
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=512, bias=True)
    (layer2): Linear(in_features=512, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:04<00:00,  3.95it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [01:25<00:00, 23.34it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 957/2000, val_loss: 5.6597
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [01:25<00:00, 23.34it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1765/2000, val_loss: 4.8355
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [01:25<00:00, 23.29it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1834/2000, val_loss: 5.1686
---------------------------------





### Experiment 6 -- Validation error only considers B and G cells

### Model backbone is from 0.594

In [59]:
def train(data,
          train_edge_indeces,
          valid_edge_indeces,
          train_edge2feat,
          valid_edge2feat,
          train_mask,
          valid_mask,
          BM_mask,
          alpha=1.0,
          cell_emb_dim=1,
          sm_emb_dim=4,
          lr=0.01,
          hidden_channels=32,
          p_dropout=0.1,
          num_epochs=100,
          print_every=1):
    
    # clone data
    data = data.clone()
    
    # show device information
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create model (data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout)
    model = Model(data=data,
                  cell_emb_dim=cell_emb_dim,
                  sm_emb_dim=sm_emb_dim,
                  hidden_channels=hidden_channels,
                  p_dropout=p_dropout).to(device)
    
    model = model.to(device)
    
    # print hypoparameters and model settings
    if print_every > 0:
        print("Device: {}, lr = {}, hidden_channels = {}, num_epochs = {}".format(device, lr, hidden_channels, num_epochs))
        print(model)

    # prepare data
    data = data.to(device)
    
    train_BM_mask = np.logical_and(train_mask, BM_mask).to(device).bool()
    valid_BM_mask = np.logical_and(valid_mask, BM_mask).to(device).bool()
    
    train_mask = train_mask.to(device)
    valid_mask = valid_mask.to(device)
    BM_mask = BM_mask.to(device)
    
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # model with best performance
    best_epoch = 0
    best_valid_loss = float('inf')
    best_model = copy.deepcopy(model)
    
    # ground truths for training and validation
    with torch.no_grad():
        target = []
        traget_BM = []
        
        valid_target = []
        valid_target_BM = []
        
        for i in range(data['cell', 'gene_exp', 'sm'].edge_label_index.shape[1]):
            cell_idx = data['cell', 'gene_exp', 'sm'].edge_label_index[0, i].item()
            sm_idx = data['cell', 'gene_exp', 'sm'].edge_label_index[1, i].item()

            # training
            if (cell_idx, sm_idx) in train_edge_indeces:
                target.append([train_edge2feat[(cell_idx, sm_idx)]])
                
                # training + kaggle targets
                if cell_idx in [c2index['B cells'], c2index['Myeloid cells']]:
                    traget_BM.append([train_edge2feat[(cell_idx, sm_idx)]])
            
            # validation
            elif (cell_idx, sm_idx) in valid_edge_indeces:
                valid_target.append([valid_edge2feat[(cell_idx, sm_idx)]])
                
                # validation + kaggle targets
                if cell_idx in [c2index['B cells'], c2index['Myeloid cells']]:
                    valid_target_BM.append([valid_edge2feat[(cell_idx, sm_idx)]])    
                    

        target = torch.from_numpy(np.array(target)).float().squeeze().to(device)  # (491, 18211)
        target_BM = torch.from_numpy(np.array(traget_BM)).float().squeeze().to(device)  # (27, 18211)

        valid_target = torch.from_numpy(np.array(valid_target)).float().squeeze()  # (123, 18211)
        valid_target_BM = torch.from_numpy(np.array(valid_target_BM)).float().squeeze()  # (7, 18211)
        
        
        N_train = target.shape[0]
        N_train_BM = target_BM.shape[0]
        
        N_valid = valid_target.shape[0]
        N_valid_BM = valid_target_BM.shape[0]
    
    
    for epoch in tqdm.tqdm(range(1, num_epochs+1)):
                
        optimizer.zero_grad()
        
        """
        Train
        """
        model.train()
        pred = model(data)  # (876, 18211)
               
        # 1107 edited
        total_train_loss = (alpha-1.0)*(N_train_BM/N_train)*F.mse_loss(pred[train_BM_mask[0]], target_BM) + F.mse_loss(pred[train_mask[0]], target)

        total_train_loss.backward()
        optimizer.step()
        
        """
        Validation
        """
        model.eval()
        with torch.no_grad():
            
            total_valid_loss = mean_squared_error(valid_target_BM.numpy(), pred[valid_BM_mask[0]].detach().cpu().numpy())
               
            # check if model is better
            if total_valid_loss < best_valid_loss:
                best_epoch = epoch
                best_valid_loss = total_valid_loss
                best_model = copy.deepcopy(model)
                best_status = '<<<'
            else:
                best_status = ''

        if print_every > 0:
            if (epoch-1) % print_every == 0:
                print("Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f} {}".format(epoch, total_train_loss, total_valid_loss, best_status))
            elif epoch == num_epochs:
                print("Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f} {}".format(epoch, total_train_loss, total_valid_loss, best_status))

    if print_every > 0:
        print("\nBest epoch: {}, val_loss: {:.4f}".format(best_epoch, best_valid_loss))
    else:
        print("Cell_emb: {}, sm_emb: {}, lr = {}, hidden_channels = {}  >>>  ".format(cell_emb_dim, sm_emb_dim, lr, hidden_channels) +
              "Best epoch: {}/{}, val_loss: {:.4f}".format(best_epoch, num_epochs, best_valid_loss))
        
    print("---------------------------------")
    return best_model, best_valid_loss

In [60]:
from torch_geometric.nn import SAGEConv, to_hetero
from torch.nn import Linear, Dropout, Embedding
import torch.nn.functional as F

class GNN(torch.nn.Module):
    """
    Compute node embedding
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()       
        self.conv1 = SAGEConv(-1, hidden_channels, normalize=True)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels, normalize=True)
        self.dropout1 = Dropout(p=p_dropout)
        
    def forward(self, x, edge_index):
        x = self.dropout1(F.elu(self.conv1(x, edge_index)))        
        x = self.conv2(x, edge_index)
        return x

class MLP(torch.nn.Module):
    """
    Model that predicts the gene expression given node features at the two ends
    
    1020: Change inner layer from (1024, 1024) units to (256, 256) units 
    """
    def __init__(self, hidden_channels, p_dropout=0.2):
        super().__init__()

        self.dropout1 = Dropout(p=p_dropout)
        self.dropout2 = Dropout(p=p_dropout)
        
        self.layer1 = Linear(2*hidden_channels, 256)
        self.layer2 = Linear(256, 1024)
        self.layer3 = Linear(1024, 18211)
    
    def forward(self, x_cell, x_sm, edge_label_index):
        
        # Retrieve embeddings from the two nodes
        feat_cell = x_cell[edge_label_index[0]]  # dim = hidden_channels
        feat_sm = x_sm[edge_label_index[1]]  #  dim = hidden_channels

        # calculate the gene expression
        y = torch.cat((feat_cell, feat_sm), dim=1)  # dim = 2*hidden_channels
        
        y = F.relu(self.layer1(y))  # dim = 256
        y = self.dropout1(y)
        
        y = F.relu(self.layer2(y))  # dim = 1024
        y = self.dropout2(y)
        
        y = self.layer3(y)  # dim = 18211
        
        return y


class Model(torch.nn.Module):
    def __init__(self, data, cell_emb_dim, sm_emb_dim, hidden_channels, p_dropout):
        super().__init__()

        # cell emb for each cell type
        self.cell_emb = Embedding(data["cell"].num_nodes, cell_emb_dim)
        
        """
        1027 - change from hidden_channels to 64
        """
        # cell emb leveraging control gene expression
        #self.cell_lin = Linear(18211, hidden_channels)
        self.cell_lin = Linear(18211, 64)

        # sm emb for each sm type
        if sm_emb_dim > 0:
            self.sm_emb = Embedding(data["sm"].num_nodes, sm_emb_dim)
        else:
            self.sm_emb = None
        # sm emb leveraging SMILES embedding
        self.sm_lin = Linear(768, hidden_channels)
        
        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels=hidden_channels, p_dropout=p_dropout)
        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        
        self.MLP = MLP(hidden_channels=hidden_channels, p_dropout=p_dropout)
        
                
    def forward(self, data):
        # calculate node embedding
        #x_cell = self.cell_emb(data["cell"].node_id)
        x_cell = torch.cat((self.cell_emb(data["cell"].node_id), self.cell_lin(data["cell"].x)), dim=1)
        
        
        if self.sm_emb:
            x_sm = torch.cat((self.sm_emb(data["sm"].node_id), self.sm_lin(data["sm"].x)), dim=1)
        else:
            x_sm = self.sm_lin(data["sm"].x)
            
        
        x_dict = {
          "cell": x_cell,
          "sm": x_sm,
        } 

        # compute final node embedding
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        # Use node embedding to predict gene expression
        pred = self.MLP(
            x_dict["cell"],
            x_dict["sm"],
            data["cell", "gene_exp", "sm"].edge_label_index,
        )
        return pred

In [61]:
for _ in range(20):

    tmp_model, tmp_valid_loss = train(data,
                                      train_edge_indeces,
                                      valid_edge_indeces,
                                      train_edge2feat,
                                      valid_edge2feat,
                                      train_mask,
                                      valid_mask,
                                      BM_mask,
                                      alpha=1.0,
                                      cell_emb_dim=2,
                                      sm_emb_dim=32,
                                      lr=5e-4,
                                      hidden_channels=16,
                                      p_dropout=0.2,
                                      num_epochs=2000,
                                      print_every=-1)

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.40it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 749/2000, val_loss: 9.1092
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:44<00:00, 44.52it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 986/2000, val_loss: 8.2229
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:44<00:00, 45.00it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 760/2000, val_loss: 8.2096
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 43.94it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 819/2000, val_loss: 8.4633
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:44<00:00, 44.68it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 711/2000, val_loss: 8.3839
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.17it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1542/2000, val_loss: 7.1488
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.13it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 789/2000, val_loss: 7.5599
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:44<00:00, 44.49it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 668/2000, val_loss: 7.8817
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.33it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 747/2000, val_loss: 7.9478
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.12it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 957/2000, val_loss: 7.8973
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:44<00:00, 44.63it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1005/2000, val_loss: 7.8590
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.22it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 731/2000, val_loss: 9.0770
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.00it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 535/2000, val_loss: 8.1703
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 43.60it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1035/2000, val_loss: 7.9565
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.16it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 679/2000, val_loss: 7.6298
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:44<00:00, 44.61it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1482/2000, val_loss: 8.1742
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.30it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1312/2000, val_loss: 7.8740
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:44<00:00, 44.47it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1117/2000, val_loss: 7.0806
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:44<00:00, 44.59it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 951/2000, val_loss: 8.9297
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.22it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1061/2000, val_loss: 9.0092
---------------------------------





In [62]:
cur_model = None
cur_valid_loss = float('inf')

while not cur_model:
    
    cur_sub_index = 0
    
    for _ in range(10):

        cur_sub_index += 1

        tmp_model, tmp_valid_loss = train(data,
                                          train_edge_indeces,
                                          valid_edge_indeces,
                                          train_edge2feat,
                                          valid_edge2feat,
                                          train_mask,
                                          valid_mask,
                                          BM_mask,
                                          alpha=1.0,
                                          cell_emb_dim=2,
                                          sm_emb_dim=32,
                                          lr=5e-4,
                                          hidden_channels=16,
                                          p_dropout=0.2,
                                          num_epochs=2000,
                                          print_every=-1)

        if tmp_valid_loss >= 7.5:
            continue

        """
        Save all models
        """
        cur_valid_loss = tmp_valid_loss
        cur_model = tmp_model
        cur_model_name = 'model_1109_exp6_' + str(cur_sub_index)
        print(cur_model_name)
        print(cur_model)
        final_submission(data=data, model_name=cur_model_name, model=cur_model)
        torch.save(cur_model.state_dict(), os.path.join('../pyg_models/', cur_model_name+'.pth'))

100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.37it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1352/2000, val_loss: 7.7714
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:44<00:00, 44.60it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1064/2000, val_loss: 7.8516
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.34it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 751/2000, val_loss: 8.1462
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.25it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 859/2000, val_loss: 7.7518
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:46<00:00, 42.77it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1246/2000, val_loss: 7.4393
---------------------------------
model_1109_exp6_5
Model(
  (cell_emb): Embedding(6, 2)
  (cell_lin): Linear(in_features=18211, out_features=64, bias=True)
  (sm_emb): Embedding(146, 32)
  (sm_lin): Linear(in_features=768, out_features=16, bias=True)
  (gnn): GraphModule(
    (conv1): Module(
      (cell__gene_exp__sm): SAGEConv(-1, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(-1, 16, aggr=mean)
    )
    (dropout1): Module(
      (cell): Dropout(p=0.2, inplace=False)
      (sm): Dropout(p=0.2, inplace=False)
    )
    (conv2): Module(
      (cell__gene_exp__sm): SAGEConv(16, 16, aggr=mean)
      (sm__rev_gene_exp__cell): SAGEConv(16, 16, aggr=mean)
    )
  )
  (MLP): MLP(
    (dropout1): Dropout(p=0.2, inplace=False)
    (dropout2): Dropout(p=0.2, inplace=False)
    (layer1): Linear(in_features=32, out_features=256, bias=True)
    (layer2): Linear(in_features=256, out_f

100%|█████████████████████████████████████████████████████████████████████| 255/255 [01:05<00:00,  3.89it/s]


File saved.
(255, 18212)


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:44<00:00, 44.48it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 518/2000, val_loss: 8.6442
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 44.11it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1224/2000, val_loss: 7.8396
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:46<00:00, 43.15it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1055/2000, val_loss: 8.0958
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 43.68it/s]


Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 1976/2000, val_loss: 8.5091
---------------------------------


100%|███████████████████████████████████████████████████████████████████| 2000/2000 [00:45<00:00, 43.91it/s]

Cell_emb: 2, sm_emb: 32, lr = 0.0005, hidden_channels = 16  >>>  Best epoch: 660/2000, val_loss: 8.8226
---------------------------------



