In [16]:
import jax
import flax
import pandas as pd
from pathlib import Path
from jax import numpy as jnp
from jax import random
from flax import linen as nn
import optax
from flax.training import train_state
from sklearn.preprocessing import StandardScaler



# Load Data

In [2]:
de_train_path = "../data/de_train.parquet";
de_train = pd.read_parquet(de_train_path)
supp_compound_path = "../data/compounds.tsv";
supp_compound = pd.read_csv(supp_compound_path, sep='\t');
finger_features_path = "../data/finger_features.csv";
finger_features = pd.read_csv(finger_features_path);

In [3]:
de_train.head()

Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,ZUP1,ZW10,ZWILCH,ZWINT,ZXDA,ZXDB,ZXDC,ZYG11B,ZYX,ZZEF1
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.10472,-0.077524,-1.625596,-0.144545,0.143555,...,-0.227781,-0.010752,-0.023881,0.674536,-0.453068,0.005164,-0.094959,0.034127,0.221377,0.368755
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.88438,0.371834,-0.081677,-0.498266,...,-0.494985,-0.303419,0.304955,-0.333905,-0.315516,-0.369626,-0.095079,0.70478,1.096702,-0.869887
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,...,-0.119422,-0.033608,-0.153123,0.183597,-0.555678,-1.494789,-0.21355,0.415768,0.078439,-0.259365
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,...,0.451679,0.704643,0.015468,-0.103868,0.865027,0.189114,0.2247,-0.048233,0.216139,-0.085024
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,...,0.758474,0.510762,0.607401,-0.123059,0.214366,0.487838,-0.819775,0.112365,-0.122193,0.676629


In [9]:
finger_features.head()

Unnamed: 0,sm_name,V1,V2,V3,V4,V5,V6,V7,V8,V9,...,V1015,V1016,V1017,V1018,V1019,V1020,V1021,V1022,V1023,V1024
0,Clotrimazole,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,1,0,0,0,0
1,Mometasone Furoate,0,0,0,0,1,0,0,0,0,...,0,0,0,0,1,0,0,1,0,0
2,Idelalisib,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,Vandetanib,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,Bosutinib,0,0,0,1,1,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0


In [4]:
supp_compound.head()

Unnamed: 0.1,Unnamed: 0,pert_id,cmap_name,target,moa,canonical_smiles,inchi_key,compound_aliases,synth_id,SMILES,sm_name
0,0,BRD-K20001883,RVX-208,,,COc1cc(OC)c2c(c1)[nH]c(nc2=O)-c1cc(C)c(OCCO)c(...,NETXMUIMUZJUTB-UHFFFAOYSA-N,apabetalone,RVX-208,,RVX-208
1,1,BRD-K20001883,RVX-208,,,COc1cc(OC)c2c(c1)[nH]c(nc2=O)-c1cc(C)c(OCCO)c(...,NETXMUIMUZJUTB-UHFFFAOYSA-N,apabetalone,RVX-208,,RVX-208
2,2,BRD-K20001883,RVX-208,,,COc1cc(OC)c2c(c1)[nH]c(nc2=O)-c1cc(C)c(OCCO)c(...,NETXMUIMUZJUTB-UHFFFAOYSA-N,apabetalone,RVX-208,,RVX-208
3,3,BRD-K20001883,RVX-208,,,COc1cc(OC)c2c(c1)[nH]c(nc2=O)-c1cc(C)c(OCCO)c(...,NETXMUIMUZJUTB-UHFFFAOYSA-N,apabetalone,RVX-208,,RVX-208
4,4,BRD-A04553218,BRD-A04553218,,,CN(C)CCC(c1ccc(Cl)cc1)c1ccccn1,SOYKEARSMXGVTM-UHFFFAOYSA-N,chlorphenamine,BRD-A04553218,CN(C)CCC(c1ccc(Cl)cc1)c1ccccn1,Chlorpheniramine


In [7]:
supp_compound["moa"].unique()

array([nan, 'KIT inhibitor', 'FLT3 inhibitor', 'PDGFR inhibitor',
       'VEGFR inhibitor', 'Proteasome inhibitor',
       'RET tyrosine kinase inhibitor',
       'Dihydrofolate reductase inhibitor', 'TRPV agonist',
       'ALK inhibitor', 'BCL inhibitor', 'BTK inhibitor', 'CDK inhibitor',
       'IKK inhibitor', 'JAK inhibitor', 'MEK inhibitor', 'PKC inhibitor',
       'PLK inhibitor', 'RAF inhibitor', 'Src inhibitor',
       'EGFR inhibitor', 'FGFR inhibitor', 'HDAC inhibitor',
       'MCL1 inhibitor', 'MTOR inhibitor', 'PI3K inhibitor',
       'SERT inhibitor', 'C-Met inhibitor', 'ErbB2 inhibitor',
       'HMGCR inhibitor', 'IGF-1 inhibitor', 'Ephrin inhibitor',
       'Imidazoline ligand', 'Peptidase inhibitor',
       'AXL kinase inhibitor', 'Abl kinase inhibitor',
       'Adrenergic inhibitor', 'Cell cycle inhibitor',
       'MAP kinase inhibitor', 'Macrophage inhibitor',
       'Microtubule inhibitor', 'Angiogenesis inhibitor',
       'Aurora kinase inhibitor', 'Topoisomerase in

In [8]:
moa = pd.get_dummies(supp_compound['moa'])
sm_moa = supp_compound[['sm_name']].join(moa).drop_duplicates('sm_name')
sm_moa.head()

Unnamed: 0,sm_name,ALK inhibitor,AXL kinase inhibitor,Abl kinase inhibitor,Adenosine receptor antagonist,Adenylyl cyclase activator,Adrenergic inhibitor,Aldehyde dehydrogenase inhibitor,Androgen receptor antagonist,Angiogenesis inhibitor,...,Serotonin reuptake inhibitor,Src inhibitor,Sterol demethylase inhibitor,TRPV agonist,Topoisomerase inhibitor,Tricyclic antidepressant,Tumor necrosis factor production inhibitor,Tyrosine kinase inhibitor,VEGFR inhibitor,Vitamin D receptor agonist
0,RVX-208,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,Chlorpheniramine,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
8,RG7112,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
12,Sunitinib,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
140,MLN 2238,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


# Preprocessing

In [10]:
def add_features(df):
    one_hot = df.join(pd.get_dummies(df['cell_type']), how='left')
    add_chem  = one_hot.merge(finger_features, on='sm_name', how='left')
    one_hot_moa = add_chem.join(pd.get_dummies(supp_compound['moa']), how='left')
    return one_hot_moa



In [12]:
de_train_feats = add_features(de_train)
de_train_feats.head()

Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,Serotonin reuptake inhibitor,Src inhibitor,Sterol demethylase inhibitor,TRPV agonist,Topoisomerase inhibitor,Tricyclic antidepressant,Tumor necrosis factor production inhibitor,Tyrosine kinase inhibitor,VEGFR inhibitor,Vitamin D receptor agonist
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.10472,-0.077524,-1.625596,-0.144545,0.143555,...,0,0,0,0,0,0,0,0,0,0
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.88438,0.371834,-0.081677,-0.498266,...,0,0,0,0,0,0,0,0,0,0
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,...,0,0,0,0,0,0,0,0,0,0
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,...,0,0,0,0,0,0,0,0,0,0
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,...,0,0,0,0,0,0,0,0,0,0


In [13]:
de_train_feats['cell_type'].value_counts()

NK cells              146
T cells CD4+          146
T regulatory cells    146
T cells CD8+          142
B cells                17
Myeloid cells          17
Name: cell_type, dtype: int64

# Extracting and Scaling Training Data

In [14]:
train_cells = ['NK cells', 'T cells CD4+', 'T regulatory cells', 'T cells CD8+']

train = de_train_feats[de_train_feats['cell_type'].isin(train_cells)]
X_pre = train[train.columns[5:]].values

print(f"Training data shape: {X_pre.shape}")

Training data shape: (580, 19319)


In [15]:
train.head()

Unnamed: 0,cell_type,sm_name,sm_lincs_id,SMILES,control,A1BG,A1BG-AS1,A2M,A2M-AS1,A2MP1,...,Serotonin reuptake inhibitor,Src inhibitor,Sterol demethylase inhibitor,TRPV agonist,Topoisomerase inhibitor,Tricyclic antidepressant,Tumor necrosis factor production inhibitor,Tyrosine kinase inhibitor,VEGFR inhibitor,Vitamin D receptor agonist
0,NK cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.10472,-0.077524,-1.625596,-0.144545,0.143555,...,0,0,0,0,0,0,0,0,0,0
1,T cells CD4+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.915953,-0.88438,0.371834,-0.081677,-0.498266,...,0,0,0,0,0,0,0,0,0,0
2,T cells CD8+,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,-0.387721,-0.305378,0.567777,0.303895,-0.022653,...,0,0,0,0,0,0,0,0,0,0
3,T regulatory cells,Clotrimazole,LSM-5341,Clc1ccccc1C(c1ccccc1)(c1ccccc1)n1ccnc1,False,0.232893,0.129029,0.336897,0.486946,0.767661,...,0,0,0,0,0,0,0,0,0,0
4,NK cells,Mometasone Furoate,LSM-3349,C[C@@H]1C[C@H]2[C@@H]3CCC4=CC(=O)C=C[C@]4(C)[C...,False,4.290652,-0.063864,-0.017443,-0.541154,0.570982,...,0,0,0,0,0,0,0,0,0,0


In [17]:
ss = StandardScaler()
ss.fit(X_pre)
X = ss.transform(X_pre)

In [18]:
X

array([[-0.13015146, -0.26051891, -1.13082615, ...,  0.        ,
        -0.24164884,  0.        ],
       [ 0.40567429, -0.94840241,  0.10550155, ...,  0.        ,
        -0.24164884,  0.        ],
       [-0.45541285, -0.45477539,  0.22678218, ...,  0.        ,
        -0.24164884,  0.        ],
       ...,
       [ 0.02404226, -0.28742064,  0.04258263, ...,  0.        ,
        -0.24164884,  0.        ],
       [-0.13251774, -0.54373542, -0.49991861, ...,  0.        ,
        -0.24164884,  0.        ],
       [-0.69940154, -0.12118302, -0.57650432, ...,  0.        ,
        -0.24164884,  0.        ]])

# Creating AutoEncoder

In [19]:
class Encoder(nn.Module):
    c_hid : int
    latent_dim : int
    training: bool

    @nn.compact
    def __call__(self, x):
        x = nn.Dropout(rate=0.10, deterministic=not self.training)(x)
        x = nn.Dense(features=2*self.c_hid)(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=0.25, deterministic=not self.training)(x)
        x = nn.Dense(features=self.c_hid)(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=0.25, deterministic=not self.training)(x)
        x = nn.Dense(features=self.c_hid)(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=0.25, deterministic=not self.training)(x)
        x = nn.Dense(features=self.latent_dim)(x)
        return x
    
    
class Decoder(nn.Module):
    c_out : int
    c_hid : int
    latent_dim : int
    training: bool

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.c_hid)(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=0.25, deterministic=not self.training)(x)
        x = nn.Dense(features=2*self.c_hid)(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=0.25, deterministic=not self.training)(x)
        x = nn.Dense(features=2*self.c_hid)(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=0.25, deterministic=not self.training)(x)
        x = nn.Dense(features=self.c_out)(x)
        x = nn.tanh(x)
        return x

    
class AutoEncoder(nn.Module):
    c_hid: int
    latent_dim : int
    input_dim: int
    training: bool

    def setup(self):
        self.encoder = Encoder(c_hid=self.c_hid,
                               latent_dim=self.latent_dim,
                               training=self.training)
        self.decoder = Decoder(c_hid=self.c_hid,
                               latent_dim=self.latent_dim,
                               c_out=self.input_dim, training=self.training)

    def __call__(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

In [20]:
# ------ model & training hyperparams ------
LATENT_DIM = 512
HIDDEN_BASE_DIM = 1024
INPUT_DIM = X.shape[1]  # 19319
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
EPOCHS=200

# ------------------------------------------
rng = random.PRNGKey(0)
main_key, params_key, dropout_key = jax.random.split(key=rng, num=3)
# ------------------------------------------


# ----------- initialize model -------------
ae = AutoEncoder(
    input_dim=INPUT_DIM,
    c_hid=HIDDEN_BASE_DIM,
    latent_dim=LATENT_DIM,
    training=False
)
variables = ae.init(params_key, jnp.ones([BATCH_SIZE, INPUT_DIM]))
state = train_state.TrainState.create(
        apply_fn = ae.apply,
        tx=optax.adam(LEARNING_RATE),
        params=variables['params']
)


# ------------ fns to drive training -------
@jax.jit
def mse(params, x_batched, y_batched):
    def squared_error(x, y):
        pred = ae.apply({'params': params}, x, rngs={'dropout': dropout_key})
        return jnp.inner(y - pred, y - pred) / 2.0
    return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched), axis=0)

@jax.jit
def train_step(
    state: train_state.TrainState, batch: jnp.ndarray
):

    def loss_fn(params):
        # logits = state.apply_fn({'params': params}, batch)
        loss = mse(params, batch, batch)
        return loss

    gradient_fn = jax.value_and_grad(loss_fn)
    loss, grads = gradient_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss