In [1]:
import math
import pickle
import warnings
import torch
import matplotlib.pyplot as plt
from bayes_opt import BayesianOptimization
from utils import *
from model import TranSiGen
from dataset import TranSiGenDataset

In [2]:
warnings.filterwarnings('ignore')

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Carregamento dos dados

Os dados são carregados a partir de um arquivo `.h5`:

In [5]:
data_path = '../data/LINCS2020/data_example/processed_data_id.h5'

In [6]:
data = load_from_HDF(data_path)

In [7]:
cell_count = len(set(data['cid']))
feat_type = 'KPGT'
batch_size = 64
learning_rate = 1e-3
beta = 0.1
dropout = 0.1
weight_decay = 1e-5
n_folds = 5
random_seed = 364039
split_type = 'smiles_split'
features_dim = 2304
features_embed_dim = [400]
n_latent = 100
# init_mode = 'pretrain_shRNA'
init_mode = 'random'
n_epochs = 300
molecule_path = '../data/LINCS2020/idx2smi.pickle'

In [8]:
local_out = './results/trained_models_{}_cell_{}/{}/feature_{}_init_{}/'.format(cell_count, split_type, random_seed, feat_type, init_mode)

In [9]:
with open(molecule_path, 'rb') as f:
    idx2smi = pickle.load(f)

## Divisão de treino e teste

In [10]:
pair, pairv, pairt = split_data(data, n_folds=n_folds, split_type=split_type, rnds=random_seed)

In [11]:
train = TranSiGenDataset(
    LINCS_index=pair['LINCS_index'],
    mol_feature_type=feat_type,
    mol_id=pair['canonical_smiles'],
    cid=pair['cid']
)

valid = TranSiGenDataset(
    LINCS_index=pairv['LINCS_index'],
    mol_feature_type=feat_type,
    mol_id=pairv['canonical_smiles'],
    cid=pairv['cid']
)

test = TranSiGenDataset(
    LINCS_index=pairt['LINCS_index'],
    mol_feature_type=feat_type,
    mol_id=pairt['canonical_smiles'],
    cid=pairt['cid']
)

train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=4, worker_init_fn=seed_worker)
valid_loader = torch.utils.data.DataLoader(dataset=valid, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=4, worker_init_fn=seed_worker)
test_loader = torch.utils.data.DataLoader(dataset=test, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=4, worker_init_fn=seed_worker)

## Criação do Modelo

In [13]:
pbounds = {
    'n_latent': (10, 600), # dimensão do espaço latente Z1, Z2
    'dropout': (0, 0.9),   # fator de dropout aplicado às redes de enc e dec
    'beta': (0, 1),        # hiperparâmetro beta que multiplica o termo KL na loss do VAE
    'features_embed_dim': (100, 1000) # dimensão do embedding da molécula dentro da rede
}

In [21]:
def train_model(n_latent, dropout, beta, features_embed_dim, is_best=False):
    path_model = f'./results/nl={round(n_latent)}-dp={dropout:.4f}-bt={beta:.4f}-fd={round(features_embed_dim)}--'
    if is_best:
        path_model = './results/'
    model = TranSiGen(
        n_genes=978,
        n_latent=round(n_latent),
        n_en_hidden=[1200],
        n_de_hidden=[800],
        features_dim=features_dim,
        features_embed_dim=[round(features_embed_dim)],
        init_w=True,
        beta=beta,
        device=dev,
        dropout=dropout,
        path_model=path_model,
        random_seed=random_seed
    ).to(dev)
    
    epoch_hist, best_epoch = model.train_model(
        train_loader=train_loader,
        test_loader=valid_loader,
        n_epochs=400,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        save_model=True,
        verbose=True
    )
    
    return -epoch_hist['valid_loss'][best_epoch]

In [16]:
optimizer = BayesianOptimization(
    f=train_model,
    pbounds=pbounds,
    random_state=random_seed,
)

In [17]:
optimizer.maximize(init_points=0, n_iter=20)

[Epoch 50] | loss: 2191081.331, mse_x1_rec: 1967.900, mse_x2_rec: 2071258.592, mse_pert: 43807.536, kld_x1: 275.930, kld_x2: 144986.681, kld_pert: 158.804| valid_loss: 12577014.222, valid_mse_x1_rec: 1867.470, valid_mse_x2_rec: 12421649.778, valid_mse_pert: 46403.326, valid_kld_x1: 326.970, valid_kld_x2: 209825.722, valid_kld_pert: 168.979|
[Epoch 60] | loss: 2830848.356, mse_x1_rec: 3732.533, mse_x2_rec: 2666516.421, mse_pert: 45938.900, kld_x1: 280.417, kld_x2: 224715.978, kld_pert: 185.941| valid_loss: 1981456.444, valid_mse_x1_rec: 1985.725, valid_mse_x2_rec: 1836876.889, valid_mse_pert: 47753.903, valid_kld_x1: 393.437, valid_kld_x2: 185677.028, valid_kld_pert: 185.780|
[Epoch 70] | loss: 496207.618, mse_x1_rec: 1880.119, mse_x2_rec: 430209.499, mse_pert: 43433.997, kld_x1: 282.737, kld_x2: 40202.179, kld_pert: 136.486| valid_loss: 308384.167, valid_mse_x1_rec: 1804.393, valid_mse_x2_rec: 242801.833, valid_mse_pert: 47722.896, valid_kld_x1: 327.172, valid_kld_x2: 31067.556, valid_

In [18]:
optimizer.max

{'target': np.float64(-2686.9559461805557),
 'params': {'beta': np.float64(0.6542363348175738),
  'dropout': np.float64(0.08043982573185449),
  'features_embed_dim': np.float64(173.03576615467415),
  'n_latent': np.float64(79.53801414895078)}}

## Treinamento do Modelo

In [24]:
model = TranSiGen(
    n_genes=978,
    n_latent=80,
    n_en_hidden=[1200],
    n_de_hidden=[800],
    features_dim=features_dim,
    features_embed_dim=[173],
    init_w=True,
    beta=0.6542363348175738,
    device=dev,
    dropout=0.08043982573185449,
    path_model='',
    random_seed=random_seed
).to(dev)

In [22]:
train_model(**optimizer.max['params'], is_best=True)

[Epoch 0] | loss: nan, mse_x1_rec: 12369726605.785, mse_x2_rec: 5125201.427, mse_pert: nan, kld_x1: 5086023506.708, kld_x2: 2361606.735, kld_pert: inf| valid_loss: nan, valid_mse_x1_rec: 5487286499.556, valid_mse_x2_rec: 7469121.778, valid_mse_pert: nan, valid_kld_x1: 4047532487.111, valid_kld_x2: 2936835.778, valid_kld_pert: inf|
[Epoch 10] | loss: 1546883.388, mse_x1_rec: 56555.530, mse_x2_rec: 715337.479, mse_pert: 24970.572, kld_x1: 2390.581, kld_x2: 10877.541, kld_pert: 1133136.900| valid_loss: 1783892.444, valid_mse_x1_rec: 88785.174, valid_mse_x2_rec: 633000.556, valid_mse_pert: 9738.804, valid_kld_x1: 3093.194, valid_kld_x2: 11583.564, valid_kld_pert: 1593867.333|
[Epoch 20] | loss: 554135.116, mse_x1_rec: 186737.786, mse_x2_rec: 193460.092, mse_pert: 95054.077, kld_x1: 2647.781, kld_x2: 2541.327, kld_pert: 115383.779| valid_loss: 719811.722, valid_mse_x1_rec: 231110.806, valid_mse_x2_rec: 295983.611, valid_mse_pert: 103029.715, valid_kld_x1: 2734.417, valid_kld_x2: 3307.986, v

-2444.4021267361113

In [28]:
model = torch.load('results/best_model.pt')

## Avaliação do Modelo no Conjunto Teste

In [29]:
_, _, test_metrics_dict_ls = model.test_model(loader=test_loader, metrics_func=['pearson'])

for name, rec_dict_value in zip(['test'], [test_metrics_dict_ls]):
    df_rec = pd.DataFrame.from_dict(rec_dict_value)
    smi_ls = []
    for smi_id in df_rec['cp_id']:
        smi_ls.append(idx2smi[smi_id])
    df_rec['canonical_smiles'] = smi_ls

In [30]:
df_rec

Unnamed: 0,x1_rec_pearson,x2_rec_pearson,x2_pred_pearson,DEG_rec_pearson,DEG_pred_pearson,cp_id,cid,sig,canonical_smiles
0,0.944225,0.926119,0.922537,0.199172,0.166481,6,A375,A375_6,Brc1c[nH]c2nc(SCc3ccccc3C#N)nc2c1
1,0.844968,0.853367,0.849935,0.163265,0.171045,6,HT29,HT29_6,Brc1c[nH]c2nc(SCc3ccccc3C#N)nc2c1
2,0.942965,0.918368,0.897666,0.295777,0.269045,9,A549,A549_9,Brc1ccc(CSc2nnc(c3ccccn3)n2Cc4ccco4)cc1
3,0.977116,0.955755,0.947049,0.571797,0.529164,13,ASC,ASC_13,Brc1ccc(NC(=O)N2NC(=O)[C@H]([C@@H]2c2ccccc2)c2...
4,0.972269,0.95442,0.945741,0.421421,0.303794,9,PC3,PC3_9,Brc1ccc(CSc2nnc(c3ccccn3)n2Cc4ccco4)cc1
5,0.986818,0.971507,0.966324,0.396427,0.387523,13,HA1E,HA1E_13,Brc1ccc(NC(=O)N2NC(=O)[C@H]([C@@H]2c2ccccc2)c2...
6,0.957665,0.950882,0.941815,0.342316,0.271578,6,HA1E,HA1E_6,Brc1c[nH]c2nc(SCc3ccccc3C#N)nc2c1
7,0.942791,0.963365,0.952109,0.578217,0.48659,13,A549,A549_13,Brc1ccc(NC(=O)N2NC(=O)[C@H]([C@@H]2c2ccccc2)c2...
8,0.911938,0.94271,0.932689,0.604693,0.549577,9,VCAP,VCAP_9,Brc1ccc(CSc2nnc(c3ccccn3)n2Cc4ccco4)cc1
9,0.942084,0.927439,0.921694,0.175447,0.146525,6,A549,A549_6,Brc1c[nH]c2nc(SCc3ccccc3C#N)nc2c1
