In [1]:
# library

import os
import re
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD


### load dataset & EDA

In [2]:
# Load the three competition tables and take a first look at them.

de_train = pd.read_parquet("/home/aiuser/taeuk/open-problems-single-cell-perturbations/de_train.parquet")
id_map = pd.read_csv("/home/aiuser/taeuk/open-problems-single-cell-perturbations/id_map.csv")
submission = pd.read_csv("/home/aiuser/taeuk/open-problems-single-cell-perturbations/sample_submission.csv")

# Shown in order: a preview of the training table, then the sm_name count per
# cell type in the training table and in id_map (the rows to be predicted).
display(
    de_train.head(),
    de_train[de_train.columns[:2]].groupby("cell_type").count(),
    id_map[id_map.columns[1:]].groupby("cell_type").count(),
)

Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.10472,-0.077524,-1.625596,-0.144545,0.143555,...,-0.227781,-0.010752,-0.023881,0.674536,-0.453068,0.005164,-0.094959,0.034127,0.221377,0.368755
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.88438,0.371834,-0.081677,-0.498266,...,-0.494985,-0.303419,0.304955,-0.333905,-0.315516,-0.369626,-0.095079,0.70478,1.096702,-0.869887
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,...,-0.119422,-0.033608,-0.153123,0.183597,-0.555678,-1.494789,-0.21355,0.415768,0.078439,-0.259365
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,...,0.451679,0.704643,0.015468,-0.103868,0.865027,0.189114,0.2247,-0.048233,0.216139,-0.085024
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,...,0.758474,0.510762,0.607401,-0.123059,0.214366,0.487838,-0.819775,0.112365,-0.122193,0.676629


Unnamed: 0_level_0,sm_name
cell_type,Unnamed: 1_level_1
B cells,17
Myeloid cells,17
NK cells,146
T cells CD4+,146
T cells CD8+,142
T regulatory cells,146


Unnamed: 0_level_0,sm_name
cell_type,Unnamed: 1_level_1
B cells,128
Myeloid cells,127


In [3]:
# Gene-by-gene correlation over all training rows.
# Columns from position 5 onward are the per-gene values; the first five are metadata.

corr_gene = de_train[de_train.columns[5:]].corr()

In [4]:
# Display the gene-by-gene correlation matrix computed in the previous cell.
corr_gene

Unnamed: 0,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,A4GALT,AAAS,AACS,AAGAB,AAK1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
A1BG,1.000000,0.714458,0.457526,0.522942,0.778408,0.837132,0.439837,0.705893,0.192259,-0.070669,...,0.194641,0.652388,0.439657,0.755152,0.805117,0.792222,0.515260,0.282706,0.064979,0.149878
A1BG-AS1,0.714458,1.000000,0.301795,0.369418,0.577066,0.712535,0.412077,0.638482,0.245015,-0.019140,...,0.177971,0.591506,0.302379,0.668320,0.630534,0.697719,0.380695,0.323991,0.085938,0.072710
A2M,0.457526,0.301795,1.000000,0.614617,0.407135,0.350203,0.187354,0.505190,0.011625,-0.193726,...,-0.045483,0.456399,0.617574,0.486742,0.505622,0.536887,0.599092,0.234177,0.054594,0.214293
A2M-AS1,0.522942,0.369418,0.614617,1.000000,0.625173,0.347183,0.252877,0.449189,-0.119271,-0.009788,...,-0.148671,0.404312,0.329363,0.475214,0.583882,0.547979,0.469065,0.100440,0.235550,0.208293
A2MP1,0.778408,0.577066,0.407135,0.625173,1.000000,0.741961,0.355945,0.567186,0.146833,0.037726,...,0.074257,0.531156,0.267254,0.655978,0.768620,0.667382,0.398134,0.126451,0.130264,0.141848
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ZXDB,0.792222,0.697719,0.536887,0.547979,0.667382,0.750323,0.417411,0.730304,0.185898,-0.180804,...,0.185169,0.669619,0.532330,0.756865,0.832125,1.000000,0.624481,0.274782,-0.003387,0.134351
ZXDC,0.515260,0.380695,0.599092,0.469065,0.398134,0.432041,0.219351,0.557733,0.100784,-0.185686,...,0.103798,0.524886,0.634150,0.554500,0.585897,0.624481,1.000000,0.243037,-0.115429,0.211058
ZYG11B,0.282706,0.323991,0.234177,0.100440,0.126451,0.225936,0.164784,0.329503,0.217202,0.193395,...,0.163312,0.276594,0.194563,0.239123,0.176127,0.274782,0.243037,1.000000,0.044252,0.063161
ZYX,0.064979,0.085938,0.054594,0.235550,0.130264,0.001245,0.083897,-0.030927,-0.060895,0.230573,...,-0.164996,0.051608,-0.192064,0.076097,-0.009516,-0.003387,-0.115429,0.044252,1.000000,-0.011278


In [258]:
# Compounds that were measured in B cells or Myeloid cells (the two sparsely
# sampled cell types), and the training rows restricted to those compounds.
c17 = de_train.loc[de_train.cell_type.isin(["B cells", "Myeloid cells"]), "sm_name"].values
c17 = list(set(c17))
# Vectorised membership test instead of a Python-level row-wise apply.
de_split = de_train.loc[de_train.sm_name.isin(c17), :]
de_split = de_split.drop(["sm_lincs_id","SMILES","control"], axis=1)

gene_names = list(de_split.columns[2:])
# Long-format copy, kept because later cells may still refer to it.
de_melt = de_split.melt(id_vars=["sm_name", "cell_type"], value_vars=(gene_names))

# gene vs cell type: for every gene, correlate each pair of cell types across compounds.
# Sizes are derived from the data rather than hard-coded.
n_cell_types = de_split.cell_type.nunique()
tri_idx = [(i,j) for i in range(n_cell_types) for j in range(i+1, n_cell_types)]
corrs_by_cell = pd.DataFrame(np.zeros((len(gene_names), len(tri_idx))))

for g, gene_name in enumerate(tqdm(gene_names)):

    # Pivot the wide table directly for this gene; filtering the melted frame
    # here would rescan every (row, gene) pair on each iteration.
    genes_by_cell = de_split.pivot(index="sm_name", columns="cell_type", values=gene_name)
    if g == 0:
        # Name each output column after the two cell types it compares.
        corrs_by_cell.columns = [genes_by_cell.columns[i]+"_"+genes_by_cell.columns[j] for i, j in tri_idx]
    corr = genes_by_cell.corr()
    # Upper triangle (excluding the diagonal) in the same order as tri_idx.
    corrs_by_cell.iloc[g, :] = [corr.iloc[pair] for pair in tri_idx]
corrs_by_cell

  0%|          | 0/18211 [00:00<?, ?it/s]

Unnamed: 0,B cells_Myeloid cells,B cells_NK cells,B cells_T cells CD4+,B cells_T cells CD8+,B cells_T regulatory cells,Myeloid cells_NK cells,Myeloid cells_T cells CD4+,Myeloid cells_T cells CD8+,Myeloid cells_T regulatory cells,NK cells_T cells CD4+,NK cells_T cells CD8+,NK cells_T regulatory cells,T cells CD4+_T cells CD8+,T cells CD4+_T regulatory cells,T cells CD8+_T regulatory cells
0,0.905482,0.895722,0.375210,-0.323287,0.176634,0.880579,0.290179,-0.391285,0.316655,0.484152,-0.153985,0.076878,-0.460513,0.218283,-0.623057
1,0.409351,0.639225,0.123594,-0.296232,0.090733,0.633108,0.643295,-0.510739,0.847070,0.619108,-0.405808,0.327202,-0.100659,0.634006,-0.333909
2,0.127967,0.226809,0.429107,-0.596813,-0.057069,0.787378,0.393811,0.079111,-0.667496,0.652628,0.056746,-0.932599,0.138729,-0.544332,-0.158529
3,0.824447,-0.045439,-0.003867,-0.846540,-0.074458,-0.000292,-0.039849,-0.643197,0.033730,0.953095,-0.125184,-0.917929,-0.104112,-0.925651,0.206291
4,0.846774,0.908087,0.450314,-0.877904,-0.033933,0.775706,0.174370,-0.917120,0.339516,0.430031,-0.653138,-0.070585,0.519603,-0.281798,-0.372194
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18206,0.907377,0.855056,0.612551,0.333848,-0.288227,0.809065,0.524023,0.216929,-0.116145,0.551853,0.341138,-0.040655,0.588474,-0.403063,-0.584464
18207,0.338583,0.572691,0.771528,0.086038,-0.110477,0.150111,0.350980,-0.123649,-0.860691,0.378338,0.327283,0.134020,0.252471,-0.117507,0.231576
18208,0.007184,0.587867,0.770966,0.640051,0.777028,0.043483,0.383415,0.489519,0.263016,0.643875,0.725290,0.593436,0.719267,0.885651,0.726981
18209,-0.007667,-0.372145,-0.449387,-0.086326,-0.009817,-0.121042,0.172314,0.374107,0.823531,0.167669,-0.009806,-0.228858,0.190652,0.005990,0.220059


In [49]:
# Reload the cell-type-pair correlations from the CSV cache, then report for each
# pair the share of rows (genes) whose correlation is above 0.5.
corrs_by_cell = pd.read_csv("/home/aiuser/taeuk/corrs_by_cell.csv")
corrs_by_cell.gt(0.5).sum(axis=0) / len(corrs_by_cell)

B cells_Myeloid cells               0.642743
B cells_NK cells                    0.757180
B cells_T cells CD4+                0.530174
B cells_T cells CD8+                0.130416
B cells_T regulatory cells          0.123003
Myeloid cells_NK cells              0.625117
Myeloid cells_T cells CD4+          0.237384
Myeloid cells_T cells CD8+          0.105596
Myeloid cells_T regulatory cells    0.136401
NK cells_T cells CD4+               0.696008
NK cells_T cells CD8+               0.119598
NK cells_T regulatory cells         0.110263
T cells CD4+_T cells CD8+           0.356158
T cells CD4+_T regulatory cells     0.182252
T cells CD8+_T regulatory cells     0.094009
dtype: float64

In [42]:
# Correlation between the cell-type-pair columns of corrs_by_cell.
# It is computed here so the cell also works on a fresh top-to-bottom run; before,
# this cell only displayed a `cor_tmp` that is first assigned in a later cell.
cor_tmp = corrs_by_cell.corr()
cor_tmp

Unnamed: 0,B cells_Myeloid cells,B cells_NK cells,B cells_T cells CD4+,B cells_T cells CD8+,B cells_T regulatory cells,Myeloid cells_NK cells,Myeloid cells_T cells CD4+,Myeloid cells_T cells CD8+,Myeloid cells_T regulatory cells,NK cells_T cells CD4+,NK cells_T cells CD8+,NK cells_T regulatory cells,T cells CD4+_T cells CD8+,T cells CD4+_T regulatory cells,T cells CD8+_T regulatory cells
B cells_Myeloid cells,1.0,0.524587,0.257967,-0.305271,-0.003383,0.764165,0.504755,-0.264118,0.196902,0.323811,-0.286464,0.098533,0.028638,-0.105866,-0.286765
B cells_NK cells,0.524587,1.0,0.48592,-0.295177,0.032539,0.529955,0.273522,-0.235804,0.099691,0.528724,-0.277077,0.135274,0.082636,-0.110616,-0.32521
B cells_T cells CD4+,0.257967,0.48592,1.0,0.213072,0.111415,0.238064,0.470442,0.139485,0.030419,0.59259,0.142526,0.086743,0.115109,0.013477,-0.03604
B cells_T cells CD8+,-0.305271,-0.295177,0.213072,1.0,0.144632,-0.270287,0.124025,0.730407,-0.035704,-0.056312,0.83554,-0.008861,0.299639,0.202169,0.453212
B cells_T regulatory cells,-0.003383,0.032539,0.111415,0.144632,1.0,-0.030283,0.064378,0.161099,0.435813,-0.022013,0.205524,0.603689,0.073933,0.684534,0.449856
Myeloid cells_NK cells,0.764165,0.529955,0.238064,-0.270287,-0.030283,1.0,0.5481,-0.287354,0.234885,0.356107,-0.303146,0.120137,0.021469,-0.102429,-0.29045
Myeloid cells_T cells CD4+,0.504755,0.273522,0.470442,0.124025,0.064378,0.5481,1.0,0.266511,0.124392,0.410198,0.105782,0.085542,0.032043,0.050292,0.002264
Myeloid cells_T cells CD8+,-0.264118,-0.235804,0.139485,0.730407,0.161099,-0.287354,0.266511,1.0,-0.117236,-0.053406,0.723741,-0.011434,0.183286,0.201104,0.428463
Myeloid cells_T regulatory cells,0.196902,0.099691,0.030419,-0.035704,0.435813,0.234885,0.124392,-0.117236,1.0,0.030248,-0.020386,0.422313,0.07922,0.386657,0.153169
NK cells_T cells CD4+,0.323811,0.528724,0.59259,-0.056312,-0.022013,0.356107,0.410198,-0.053406,0.030248,1.0,-0.038046,0.134511,0.016859,-0.0502,-0.203137


In [33]:
# For every cell-type pair, list the `nums` other pairs whose per-gene
# correlations track it most closely, as (pair name, correlation) tuples.
cor_tmp = corrs_by_cell.corr()
nums = 5

top_rows = []
for idx in cor_tmp.index:
    ranked = cor_tmp.loc[idx, :].sort_values(ascending=False)
    # Position 0 is skipped: it is the pair's correlation with itself.
    top_rows.append(list(ranked.iloc[1:nums+1].items()))

# Build the table in one go with object cells. Writing tuples row by row into a
# float64 frame triggers pandas' incompatible-dtype FutureWarning.
df = pd.DataFrame(top_rows, index=cor_tmp.index)
df

  df.loc[idx, :] = pd.Series(row).T
  df.loc[idx, :] = pd.Series(row).T
  df.loc[idx, :] = pd.Series(row).T
  df.loc[idx, :] = pd.Series(row).T
  df.loc[idx, :] = pd.Series(row).T


Unnamed: 0,0,1,2,3,4
B cells_Myeloid cells,"(Myeloid cells_NK cells, 0.7641649292983431)","(B cells_NK cells, 0.5245870587714191)","(Myeloid cells_T cells CD4+, 0.504754789375354)","(NK cells_T cells CD4+, 0.3238112701500538)","(B cells_T cells CD4+, 0.2579668162256001)"
B cells_NK cells,"(Myeloid cells_NK cells, 0.5299546145149721)","(NK cells_T cells CD4+, 0.5287239993031712)","(B cells_Myeloid cells, 0.5245870587714191)","(B cells_T cells CD4+, 0.4859196557271453)","(Myeloid cells_T cells CD4+, 0.273521711662071)"
B cells_T cells CD4+,"(NK cells_T cells CD4+, 0.5925895471024951)","(B cells_NK cells, 0.4859196557271453)","(Myeloid cells_T cells CD4+, 0.4704419810789211)","(B cells_Myeloid cells, 0.2579668162256001)","(Myeloid cells_NK cells, 0.23806435729119393)"
B cells_T cells CD8+,"(NK cells_T cells CD8+, 0.8355401075410462)","(Myeloid cells_T cells CD8+, 0.7304070000663468)","(T cells CD8+_T regulatory cells, 0.4532115639...","(T cells CD4+_T cells CD8+, 0.2996386092776588)","(B cells_T cells CD4+, 0.2130718482138083)"
B cells_T regulatory cells,"(T cells CD4+_T regulatory cells, 0.6845340741...","(NK cells_T regulatory cells, 0.6036886689426403)","(T cells CD8+_T regulatory cells, 0.4498560864...","(Myeloid cells_T regulatory cells, 0.435813165...","(NK cells_T cells CD8+, 0.20552436096891816)"
Myeloid cells_NK cells,"(B cells_Myeloid cells, 0.7641649292983431)","(Myeloid cells_T cells CD4+, 0.5481003918594523)","(B cells_NK cells, 0.5299546145149721)","(NK cells_T cells CD4+, 0.3561068037727002)","(B cells_T cells CD4+, 0.23806435729119393)"
Myeloid cells_T cells CD4+,"(Myeloid cells_NK cells, 0.5481003918594523)","(B cells_Myeloid cells, 0.504754789375354)","(B cells_T cells CD4+, 0.4704419810789211)","(NK cells_T cells CD4+, 0.4101983987719259)","(B cells_NK cells, 0.273521711662071)"
Myeloid cells_T cells CD8+,"(B cells_T cells CD8+, 0.7304070000663468)","(NK cells_T cells CD8+, 0.7237410701473934)","(T cells CD8+_T regulatory cells, 0.4284625970...","(Myeloid cells_T cells CD4+, 0.2665112903587434)","(T cells CD4+_T regulatory cells, 0.2011043963..."
Myeloid cells_T regulatory cells,"(B cells_T regulatory cells, 0.435813165072921)","(NK cells_T regulatory cells, 0.42231298432111...","(T cells CD4+_T regulatory cells, 0.3866571522...","(Myeloid cells_NK cells, 0.23488518268381509)","(B cells_Myeloid cells, 0.1969015850970744)"
NK cells_T cells CD4+,"(B cells_T cells CD4+, 0.5925895471024951)","(B cells_NK cells, 0.5287239993031712)","(Myeloid cells_T cells CD4+, 0.4101983987719259)","(Myeloid cells_NK cells, 0.3561068037727002)","(B cells_Myeloid cells, 0.3238112701500538)"


In [28]:
# The `nums` pairs most correlated with "B cells_Myeloid cells", skipping the pair itself.
list(cor_tmp.loc["B cells_Myeloid cells", :].sort_values(ascending=False)[1:nums+1].items())

[('Myeloid cells_NK cells', 0.7641649292983431),
 ('B cells_NK cells', 0.5245870587714191),
 ('Myeloid cells_T cells CD4+', 0.504754789375354),
 ('NK cells_T cells CD4+', 0.3238112701500538),
 ('B cells_T cells CD4+', 0.2579668162256001)]

In [259]:
# gene vs compound: for every gene, correlate each pair of compounds across cell types.
# Sizes are derived from the data rather than hard-coded.
n_compounds = de_split.sm_name.nunique()
tri_idx = [(i,j) for i in range(n_compounds) for j in range(i+1, n_compounds)]
corrs_by_name = pd.DataFrame(np.zeros((len(gene_names), len(tri_idx))))

for g, gene_name in enumerate(tqdm(gene_names)):

    # Pivot the wide table directly for this gene; filtering the melted frame
    # here would rescan every (row, gene) pair on each iteration.
    genes_by_name = de_split.pivot(index="cell_type", columns="sm_name", values=gene_name)
    if g == 0:
        # Name each output column after the two compounds it compares.
        corrs_by_name.columns = [genes_by_name.columns[i]+"_"+genes_by_name.columns[j] for i, j in tri_idx]
    corr = genes_by_name.corr()
    # Upper triangle (excluding the diagonal) in the same order as tri_idx.
    corrs_by_name.iloc[g, :] = [corr.iloc[pair] for pair in tri_idx]
corrs_by_name

  0%|          | 0/18211 [00:00<?, ?it/s]

Unnamed: 0,Alvocidib_Belinostat,Alvocidib_CHIR-99021,Alvocidib_Crizotinib,Alvocidib_Dabrafenib,Alvocidib_Dactolisib,Alvocidib_Foretinib,Alvocidib_Idelalisib,Alvocidib_LDN 193189,Alvocidib_Linagliptin,Alvocidib_MLN 2238,...,Oprozomib (ONX 0912)_Palbociclib,Oprozomib (ONX 0912)_Penfluridol,Oprozomib (ONX 0912)_Porcn Inhibitor III,Oprozomib (ONX 0912)_R428,Palbociclib_Penfluridol,Palbociclib_Porcn Inhibitor III,Palbociclib_R428,Penfluridol_Porcn Inhibitor III,Penfluridol_R428,Porcn Inhibitor III_R428
0,-0.678366,-0.357550,-0.279944,-0.469193,0.593291,-0.438795,0.837843,0.953592,-0.591115,0.670386,...,0.038837,0.363404,0.578687,-0.315594,-0.607147,0.158275,-0.349787,0.203467,0.222975,0.018866
1,-0.380463,0.595359,-0.110443,-0.280384,0.200448,-0.519325,-0.551719,0.370149,-0.783819,0.587951,...,0.293632,-0.415901,-0.234204,0.479206,-0.256939,-0.144587,0.873800,0.684758,0.087478,-0.090474
2,0.325494,-0.421461,-0.547177,0.629261,-0.843277,0.000112,-0.473704,-0.054034,-0.112018,0.807183,...,0.712041,-0.925474,0.197049,-0.727386,-0.423427,-0.263640,-0.565783,-0.125504,-0.068351,0.044498
3,0.251181,-0.195370,0.584273,0.272200,-0.306319,-0.119838,0.239116,-0.267157,-0.243321,0.951603,...,-0.248629,0.395340,-0.446829,-0.643552,-0.505334,0.619055,0.049557,-0.985010,-0.397491,0.327272
4,-0.317867,0.546507,0.507989,-0.071687,0.418325,0.450859,0.295435,-0.070157,0.756840,0.882432,...,0.684890,-0.750850,-0.795588,0.229752,-0.662148,0.135100,0.327093,0.200748,0.144519,-0.462601
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18206,-0.489792,-0.298316,-0.365792,-0.408647,0.224385,-0.608529,0.254372,-0.563167,0.573913,0.219601,...,0.645735,0.468762,-0.777653,-0.586279,0.305629,-0.578100,0.158129,0.104055,0.001302,0.492585
18207,-0.072636,-0.280887,0.673323,0.016196,-0.556377,-0.157123,-0.586318,0.116283,-0.402915,0.401807,...,0.536880,-0.180770,0.033063,0.763301,-0.079653,-0.401693,0.322178,-0.515578,-0.677379,0.569663
18208,0.362539,0.342146,0.730185,0.121317,-0.263068,-0.408662,-0.450978,-0.752649,-0.223407,0.574724,...,-0.106701,-0.603292,0.824076,0.697401,0.174108,0.171609,0.567341,0.005721,0.180799,0.715125
18209,0.236713,-0.286155,0.425419,-0.678034,0.576773,0.739436,0.333119,0.204466,-0.519603,0.481797,...,-0.085639,0.303560,-0.345273,-0.284089,-0.068681,-0.460870,0.254299,-0.733501,0.299807,0.022749


In [9]:
# Load the compound-pair correlation table from a CSV cache and display it.
# NOTE(review): this rebinds the `corrs_by_name` built by the loop in the cell above, and no
# cell visible in this notebook writes the CSV — confirm how the file was produced.
corrs_by_name = pd.read_csv("/home/aiuser/taeuk/corrs_by_name.csv")
corrs_by_name

Unnamed: 0,Alvocidib_Belinostat,Alvocidib_CHIR-99021,Alvocidib_Crizotinib,Alvocidib_Dabrafenib,Alvocidib_Dactolisib,Alvocidib_Foretinib,Alvocidib_Idelalisib,Alvocidib_LDN 193189,Alvocidib_Linagliptin,Alvocidib_MLN 2238,...,Oprozomib (ONX 0912)_Palbociclib,Oprozomib (ONX 0912)_Penfluridol,Oprozomib (ONX 0912)_Porcn Inhibitor III,Oprozomib (ONX 0912)_R428,Palbociclib_Penfluridol,Palbociclib_Porcn Inhibitor III,Palbociclib_R428,Penfluridol_Porcn Inhibitor III,Penfluridol_R428,Porcn Inhibitor III_R428
0,-0.678366,-0.357550,-0.279944,-0.469193,0.593291,-0.438795,0.837843,0.953592,-0.591115,0.670386,...,0.038837,0.363404,0.578687,-0.315594,-0.607147,0.158275,-0.349787,0.203467,0.222975,0.018866
1,-0.380463,0.595359,-0.110443,-0.280384,0.200448,-0.519325,-0.551719,0.370149,-0.783819,0.587951,...,0.293632,-0.415901,-0.234204,0.479206,-0.256939,-0.144587,0.873800,0.684758,0.087478,-0.090474
2,0.325494,-0.421461,-0.547177,0.629261,-0.843277,0.000112,-0.473704,-0.054034,-0.112018,0.807183,...,0.712041,-0.925474,0.197049,-0.727386,-0.423427,-0.263640,-0.565783,-0.125504,-0.068351,0.044498
3,0.251181,-0.195370,0.584273,0.272200,-0.306319,-0.119838,0.239116,-0.267157,-0.243321,0.951603,...,-0.248629,0.395340,-0.446829,-0.643552,-0.505334,0.619055,0.049557,-0.985010,-0.397491,0.327272
4,-0.317867,0.546507,0.507989,-0.071687,0.418325,0.450859,0.295435,-0.070157,0.756840,0.882432,...,0.684890,-0.750850,-0.795588,0.229752,-0.662148,0.135100,0.327093,0.200748,0.144519,-0.462601
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18206,-0.489792,-0.298316,-0.365792,-0.408647,0.224385,-0.608529,0.254372,-0.563167,0.573913,0.219601,...,0.645735,0.468762,-0.777653,-0.586279,0.305629,-0.578100,0.158129,0.104055,0.001302,0.492585
18207,-0.072636,-0.280887,0.673323,0.016196,-0.556377,-0.157123,-0.586318,0.116283,-0.402915,0.401807,...,0.536880,-0.180770,0.033063,0.763301,-0.079653,-0.401693,0.322178,-0.515578,-0.677379,0.569663
18208,0.362539,0.342146,0.730185,0.121317,-0.263068,-0.408662,-0.450978,-0.752649,-0.223407,0.574724,...,-0.106701,-0.603292,0.824076,0.697401,0.174108,0.171609,0.567341,0.005721,0.180799,0.715125
18209,0.236713,-0.286155,0.425419,-0.678034,0.576773,0.739436,0.333119,0.204466,-0.519603,0.481797,...,-0.085639,0.303560,-0.345273,-0.284089,-0.068681,-0.460870,0.254299,-0.733501,0.299807,0.022749


In [7]:
# Per-gene group means: one table averaged within each cell type, one within each
# compound. The metadata columns that are not the grouping key are dropped first.
mean_cell_type = (
    de_train.drop(columns=de_train.columns[1:5])
    .groupby("cell_type").mean().sort_index().reset_index()
)
mean_sm_name = (
    de_train.drop(columns=de_train.columns[[0, 2, 3, 4]])
    .groupby("sm_name").mean().sort_index().reset_index()
)

# For 10 randomly chosen genes, show how strongly each row's value correlates with
# its cell-type mean and with its compound mean (gene columns start at position 5).
for g in np.random.choice(18211, 10, replace=False) + 5:

    gene = de_train.columns[g]
    df = de_train[["cell_type", "sm_name", gene]]
    df = df.merge(mean_cell_type[["cell_type", gene]], how="left", on="cell_type", suffixes=(None,"_by_cell"))
    df = df.merge(mean_sm_name[["sm_name", gene]], how="left", on="sm_name", suffixes=(None,"_by_name"))
    # Keep only the row for the raw gene value against the two mean columns.
    display(df.iloc[:,2:].corr().iloc[[0], 1:])

Unnamed: 0,PSRC1_by_cell,PSRC1_by_name
PSRC1,0.243834,0.688808


Unnamed: 0,TRBV19_by_cell,TRBV19_by_name
TRBV19,0.41422,0.636908


Unnamed: 0,CD1A_by_cell,CD1A_by_name
CD1A,0.167433,0.624594


Unnamed: 0,ZNF597_by_cell,ZNF597_by_name
ZNF597,0.217853,0.552565


Unnamed: 0,TRBV30_by_cell,TRBV30_by_name
TRBV30,0.335738,0.67398


Unnamed: 0,RBM48_by_cell,RBM48_by_name
RBM48,0.170014,0.593458


Unnamed: 0,PPP6C_by_cell,PPP6C_by_name
PPP6C,0.192989,0.575495


Unnamed: 0,KLRC4-KLRK1_by_cell,KLRC4-KLRK1_by_name
KLRC4-KLRK1,0.368933,0.700871


Unnamed: 0,AC135803.1_by_cell,AC135803.1_by_name
AC135803.1,0.311674,0.718576


Unnamed: 0,HIVEP1_by_cell,HIVEP1_by_name
HIVEP1,0.126706,0.637452


### Dataset & Modelling

In [3]:
# feature select
# Features used: cell_type, sm_name, group mean values, and SMILES.
# The plan is to keep the feature set minimal and aligned with id_map, but for now
# everything is derived from de_train.
de_train = pd.read_parquet("/home/aiuser/taeuk/open-problems-single-cell-perturbations/de_train.parquet")
id_map = pd.read_csv("/home/aiuser/taeuk/open-problems-single-cell-perturbations/id_map.csv")
submission = pd.read_csv("/home/aiuser/taeuk/open-problems-single-cell-perturbations/sample_submission.csv")

# Sorted vocabularies of compounds and cell types, for id_map and for de_train.
sm_name_id_map = sorted(id_map.sm_name.unique())
cell_type_id_map = sorted(id_map.cell_type.unique())

sm_name_de_train = sorted(de_train.sm_name.unique())
cell_type_de_train = sorted(de_train.cell_type.unique())

# cell type, compound dictionary: label -> integer code, in sorted-label order

cell_type_dict = {label: code for code, label in enumerate(cell_type_de_train)}
sm_name_dict = {label: code for code, label in enumerate(sm_name_de_train)}
print(len(cell_type_dict), len(sm_name_dict))

# mean value: per-gene averages within each cell type / each compound.
# The rows come out sorted by group label, which matches the integer codes above.

mean_cell_type = (
    de_train.drop(columns=de_train.columns[1:5])
    .groupby("cell_type").mean().sort_index().reset_index()
)
mean_sm_name = (
    de_train.drop(columns=de_train.columns[[0, 2, 3, 4]])
    .groupby("sm_name").mean().sort_index().reset_index()
)

# SMILES preprocessing

def smile_preprocessing(smile):
    """Split a SMILES string into its distinct fragment tokens.

    Parentheses are treated as separators and every digit 1-9 is rewritten to
    "1", so fragments that differ only in those digits collapse to one token.

    Args:
        smile: SMILES string.

    Returns:
        Sorted list of the distinct whitespace-separated tokens. Sorting makes
        the result deterministic; a bare ``list(set(...))`` follows string-hash
        order, which changes between interpreter sessions.
    """
    smile = re.sub("[()]", " ", smile)
    smile = re.sub("[1-9]", "1", smile)
    return sorted(set(smile.split()))

# SMILES decomposition: one multi-hot row per compound over the fragment vocabulary.
smiles = de_train[["sm_name", "SMILES"]].drop_duplicates().reset_index(drop=True)
compounds = []
for smile in smiles.SMILES:
    compounds += smile_preprocessing(smile)
# Sorted vocabulary, so the token -> column mapping is the same in every kernel
# session. Plain set order depends on string hashing and is not reproducible.
compounds = sorted(set(compounds))
compounds_dict = {token: col for col, token in enumerate(compounds)}

# Columns 0-1 are sm_name and SMILES; the indicator columns start at position 2.
smiles = smiles.join(pd.DataFrame(np.zeros((smiles.shape[0], len(compounds)), dtype=np.int32)))
for i in range(smiles.shape[0]):
    coms = smile_preprocessing(smiles["SMILES"][i])
    for com in coms:
        smiles.iloc[i, 2 + compounds_dict[com]] = 1

compound = smiles.set_index("sm_name").drop("SMILES", axis=1)
compound.head()

6 146


Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,419,420,421,422,423,424,425,426,427,428
sm_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Clotrimazole,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Mometasone Furoate,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Idelalisib,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Vandetanib,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
Bosutinib,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [4]:
# build dataloader

class DEset(Dataset):
    """Single-gene dataset over (cell type, compound) rows.

    With a ``dataset`` frame, inputs are its first two columns and targets are the
    columns from position 5 onward. With ``dataset=None``, inputs are taken from the
    module-level ``id_map`` (all columns but the first) and there are no targets.

    Each item holds the integer codes of the cell type and compound, plus the
    selected gene's mean for that cell type and for that compound.
    """

    def __init__(self, dataset, cell_type_dict, sm_name_dict,
                 mean_cell_type, mean_sm_name, compound, gene):
        super().__init__()
        labelled = dataset is not None
        self.x = dataset.iloc[:, :2] if labelled else id_map.iloc[:, 1:]
        self.y = dataset.iloc[:, 5:] if labelled else None
        # Lookup tables: label -> integer code.
        self.cell_type_dict = cell_type_dict
        self.sm_name_dict = sm_name_dict
        # Group-mean tables, indexed positionally by the integer codes.
        self.mean_cell_type = mean_cell_type
        self.mean_sm_name = mean_sm_name
        # Stored but not read by __getitem__ at the moment.
        self.compound = compound
        # Position of the target gene among the gene columns.
        self.gene = gene

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        cell, name = self.x.iloc[idx]
        cell_code = self.cell_type_dict[cell]
        name_code = self.sm_name_dict[name]
        # The mean tables are read one column to the right of the gene position.
        gene_col = self.gene + 1
        codes = torch.tensor([cell_code, name_code], dtype=torch.int64)
        means = torch.tensor(
            [self.mean_cell_type.iloc[cell_code, gene_col],
             self.mean_sm_name.iloc[name_code, gene_col]],
            dtype=torch.float32,
        )
        if self.y is None:
            # Inference rows: no target available.
            return codes, means
        target = torch.tensor([self.y.iloc[idx, self.gene]], dtype=torch.float32)
        return codes, means, target
    
# Smoke test of the dataset/dataloader on gene 0: pull a single batch without shuffling.
# NOTE(review): 614 equals the sum of the per-cell-type row counts displayed in the EDA
# output, so this one batch presumably holds the whole training table — confirm.
batch_size = 614
gene = 0
full_dataset = DEset(de_train, cell_type_dict, sm_name_dict,
                     mean_cell_type, mean_sm_name, compound, gene)
dataloader = DataLoader(full_dataset, batch_size=batch_size, shuffle=False)
x, m, y = next(iter(dataloader))
    
               

In [11]:
# model define
class Basic_Encoder(nn.Module):
    """Three-stage MLP that reduces an embedding to one value per sample.

    The width is halved at each stage (dim_model -> /2 -> /4 -> /8). Every stage is
    Linear -> BatchNorm1d -> ReLU, and the first two are followed by Dropout(0.2).
    The final AdaptiveAvgPool1d(1) averages the remaining features, so a
    (batch, dim_model) input yields a (batch, 1) output.
    """

    def __init__(self, dim_model):
        super().__init__()
        widths = [dim_model, dim_model//2, dim_model//4, dim_model//8]
        layers = []
        for stage, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.extend([
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
            ])
            # Dropout follows every stage except the last.
            if stage < len(widths) - 2:
                layers.append(nn.Dropout(0.2))
        layers.append(nn.AdaptiveAvgPool1d(1))
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.encoder(x)
    
class Model(nn.Module):
    """Single-gene regressor built from cell-type and compound embeddings.

    Each integer code is embedded, reduced to one value by its own
    ``Basic_Encoder``, and concatenated with the two precomputed group means.
    A final linear layer maps those four features to the prediction.

    Args:
        dim_model: embedding width, also the input width of both encoders.
        num_cell_types: size of the cell-type vocabulary (default 6).
        num_sm_names: size of the compound vocabulary (default 146).
    """

    def __init__(self, dim_model, num_cell_types=6, num_sm_names=146):
        super().__init__()
        # The vocabulary sizes default to the previously hard-coded values.
        self.cell_emb = nn.Embedding(num_cell_types, dim_model)
        self.name_emb = nn.Embedding(num_sm_names, dim_model)
        self.cell_enc = Basic_Encoder(dim_model)
        self.name_enc = Basic_Encoder(dim_model)
        # 4 inputs = 1 (cell encoder) + 1 (compound encoder) + 2 (group means).
        self.linear = nn.Linear(4, 1)

    def forward(self, xs, x_means):
        """Predict one value per row.

        Args:
            xs: integer tensor whose column 0 holds cell-type codes and whose
                column 1 holds compound codes.
            x_means: float tensor concatenated with the two encoder outputs along
                dim 1; it must be 2 wide to match the 4-input linear layer.

        Returns:
            Output of the final linear layer, one value per row.
        """
        cell = self.cell_emb(xs[:,0])
        name = self.name_emb(xs[:,1])
        cell_h = self.cell_enc(cell)
        name_h = self.name_enc(name)
        x = torch.cat([cell_h, name_h, x_means], dim=1)
        x = self.linear(x)
        return x
    

        

In [12]:
# training

def MRRMSE(pred, y):
    """Mean row-wise root-mean-squared error between two 2-D tensors.

    Both tensors are detached and moved to CPU. The squared error is averaged
    along axis 1, square-rooted per row, and the row values are then averaged.
    """
    pred_np, y_np = (t.detach().cpu().numpy() for t in (pred, y))
    per_row_mse = np.square(y_np - pred_np).mean(axis=1)
    return np.sqrt(per_row_mse).mean()

def train_model(device, train_rate, batch_size, gene, dim_model, learning_rate, num_epochs, verbose):
    """Train a fresh ``Model`` to predict one gene column of ``de_train``.

    A random ``train_rate`` fraction of the rows is used for training. When
    ``train_rate < 1`` the remaining rows form a validation set that is scored
    after every epoch; otherwise validation is skipped entirely.

    Args:
        device: torch device the model and batches are moved to.
        train_rate: fraction of ``de_train`` rows used for training.
        batch_size: batch size of both loaders.
        gene: position of the target gene among the gene columns.
        dim_model: embedding width passed to ``Model``.
        learning_rate: Adam learning rate.
        num_epochs: number of passes over the training rows.
        verbose: when true, print the losses and MRRMSE after every epoch.

    Returns:
        ``(model, best_mrrmse)``: the trained model and the lowest per-epoch
        validation MRRMSE. Without a validation split ``best_mrrmse`` keeps its
        initial value of 10.
    """
    n_rows = de_train.shape[0]
    train_idx = np.random.choice(n_rows, int(n_rows*train_rate), replace=False)
    train_loader = DataLoader(DEset(de_train.iloc[train_idx, :], cell_type_dict,
                                    sm_name_dict, mean_cell_type, mean_sm_name, compound, gene),
                              batch_size=batch_size, shuffle=True)

    # valid_loader stays None when every row is used for training. All validation
    # work, including the verbose print, is keyed on it, so train_rate >= 1 no
    # longer hits an undefined name.
    valid_loader = None
    if train_rate < 1.:
        valid_idx = list(set(np.arange(n_rows)) - set(train_idx))
        valid_loader = DataLoader(DEset(de_train.iloc[valid_idx, :], cell_type_dict,
                                        sm_name_dict, mean_cell_type, mean_sm_name, compound, gene),
                                  batch_size=batch_size, shuffle=True)

    model = Model(dim_model).to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    criterion = nn.SmoothL1Loss()
    best_mrrmse = 10
    for epoch in tqdm(range(1, num_epochs+1)):
        train_loss = 0.
        valid_loss = 0.
        train_mrrmse = 0.
        valid_mrrmse = 0.

        model.train()
        for x, m, y in train_loader:
            x, m, y = x.to(device), m.to(device), y.to(device)
            optimizer.zero_grad()
            pred = model(x,m)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_mrrmse += MRRMSE(pred, y)
        # Metrics are averaged over batches, not over rows.
        train_loss /= len(train_loader)
        train_mrrmse /= len(train_loader)

        if valid_loader is not None:
            model.eval()
            with torch.no_grad():
                for x, m, y in valid_loader:
                    x, m, y = x.to(device), m.to(device), y.to(device)
                    pred = model(x,m)
                    loss = criterion(pred, y)

                    valid_loss += loss.item()
                    valid_mrrmse += MRRMSE(pred, y)
            valid_loss /= len(valid_loader)
            valid_mrrmse /= len(valid_loader)
            if best_mrrmse > valid_mrrmse:
                best_mrrmse = valid_mrrmse

        # print loss per epoch
        if verbose:
            if valid_loader is not None:
                print("[Epoch : %2d] [Loss : %.4f / %.4f] [MRRMSE : %.3f / %.3f]"%(
                    epoch, train_loss, valid_loss, train_mrrmse, valid_mrrmse))
            else:
                print("[Epoch : %2d] [Loss : %.4f] [MRRMSE : %.3f]"%(
                    epoch, train_loss, train_mrrmse))

    return model, best_mrrmse

def infer_model(device, model, gene):
    """Predict one gene for every row of ``id_map``.

    Args:
        device: torch device the batches are moved to.
        model: trained model, called as ``model(x, m)``.
        gene: position of the gene among the gene columns.

    Returns:
        Tensor with the predictions for all rows, in ``id_map`` order.
    """
    test_loader = DataLoader(DEset(None, cell_type_dict, sm_name_dict,
                                   mean_cell_type, mean_sm_name, compound, gene),
                             batch_size=255, shuffle=False)
    model.eval()
    batch_preds = []
    with torch.no_grad():
        for x, m in test_loader:
            x, m = x.to(device), m.to(device)
            batch_preds.append(model(x, m))
    # Concatenate every batch. Returning only the last `pred` silently drops rows
    # as soon as id_map needs more than one batch.
    return torch.cat(batch_preds, dim=0)

In [13]:
# Pin the process to the second GPU by PCI bus order; fall back to CPU without CUDA.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="1")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyper-parameters for a 70/30 train/validation run on gene 0.
train_rate = 0.7
batch_size = 128
gene = 0
dim_model = 32
learning_rate = 0.025
num_epochs = 100

# verbose=True prints the per-epoch train / validation metrics.
model, _ = train_model(
    device, train_rate, batch_size, gene, dim_model, learning_rate, num_epochs, True
)

  0%|          | 0/100 [00:00<?, ?it/s]

[Epoch :  1] [Loss : 0.3611 / 0.4131] [MRRMSE : 0.651 / 0.704]
[Epoch :  2] [Loss : 0.4145 / 0.3383] [MRRMSE : 0.705 / 0.620]
[Epoch :  3] [Loss : 0.3591 / 0.3294] [MRRMSE : 0.643 / 0.622]
[Epoch :  4] [Loss : 0.3264 / 0.3472] [MRRMSE : 0.607 / 0.643]
[Epoch :  5] [Loss : 0.3611 / 0.3555] [MRRMSE : 0.648 / 0.653]
[Epoch :  6] [Loss : 0.3379 / 0.3757] [MRRMSE : 0.629 / 0.665]
[Epoch :  7] [Loss : 0.3317 / 0.3901] [MRRMSE : 0.621 / 0.690]
[Epoch :  8] [Loss : 0.3340 / 0.3320] [MRRMSE : 0.625 / 0.635]
[Epoch :  9] [Loss : 0.3086 / 0.3465] [MRRMSE : 0.589 / 0.654]
[Epoch : 10] [Loss : 0.3145 / 0.3856] [MRRMSE : 0.599 / 0.702]
[Epoch : 11] [Loss : 0.3645 / 0.3971] [MRRMSE : 0.654 / 0.713]
[Epoch : 12] [Loss : 0.3363 / 0.4113] [MRRMSE : 0.620 / 0.725]
[Epoch : 13] [Loss : 0.3141 / 0.4255] [MRRMSE : 0.597 / 0.748]
[Epoch : 14] [Loss : 0.3256 / 0.4920] [MRRMSE : 0.612 / 0.812]
[Epoch : 15] [Loss : 0.3084 / 0.3808] [MRRMSE : 0.592 / 0.697]
[Epoch : 16] [Loss : 0.4022 / 0.4007] [MRRMSE : 0.696 /

In [None]:
# Pin the process to the second GPU by PCI bus order; fall back to CPU without CUDA.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Final run: train on every row (no validation split), one model per gene.
train_rate = 1
batch_size = 128
gene = 0
dim_model = 32
learning_rate = 0.025
num_epochs = 100

# Gene columns start at position 5 of de_train.
n_genes = de_train.shape[1] - 5

# Collect the per-gene prediction columns and concatenate once at the end,
# instead of growing a tensor with torch.cat on every iteration.
gene_preds = []
for gene in range(n_genes):
    model, _ = train_model(device, train_rate, batch_size, gene, dim_model, learning_rate, num_epochs, False)
    gene_preds.append(infer_model(device, model, gene).detach().cpu())
preds = torch.cat(gene_preds, dim=1)

data = preds.numpy()
df = pd.DataFrame(data)
# The row index becomes the first column, ahead of the gene columns.
df = df.reset_index(drop=False)
# Take as many header names as there are columns. The former `[:4]` slice only
# fits a 3-gene run and raises a length mismatch for the full gene set.
# NOTE(review): assumes the submission gene columns are in de_train's order — confirm.
df.columns = submission.columns[:df.shape[1]]
df.to_csv("/home/aiuser/taeuk/scp_B%d_D%d_lr%.3f_E%d.csv"%(batch_size, dim_model, learning_rate, num_epochs),
          header=True, index=False)


In [1]:
# [1] 1567006