In [None]:
# ====================================================
# Configuration
# ====================================================
class CFG:
    """Run switches and hyper-parameters shared by every cell of this notebook."""

    # ---- run control ----
    debug = False                     # True -> shortened run (see override below the class)
    train = True                      # the final cell only trains when this is set
    seed = 42                         # random_state of the KFold split
    print_freq = 1000                 # progress line every N batches in train_fn / valid_fn
    num_workers = 4                   # DataLoader worker processes

    # ---- cross-validation ----
    n_fold = 5
    trn_fold = [0, 1, 2, 3, 4]        # folds that are actually trained
    label = 'LN_IC50'                 # name of the regression target column

    # ---- model ----
    hidden_size = 256                 # width of each branch's hidden layer

    # ---- optimisation ----
    epochs = 10
    batch_size = 16
    gradient_accumulation_steps = 1
    max_grad_norm = 1000              # threshold handed to clip_grad_norm_
    apex = True                       # enables torch.cuda.amp autocast and GradScaler
    encoder_lr = 2e-5
    decoder_lr = 2e-5
    min_lr = 1e-6                     # not read by any cell shown here
    eps = 1e-6                        # AdamW eps
    betas = (0.9, 0.999)              # AdamW betas
    weight_decay = 0.01

    # ---- learning-rate schedule ----
    scheduler = 'cosine'              # ['linear', 'cosine']
    batch_scheduler = True            # step the scheduler after each optimizer step
    num_cycles = 0.5
    num_warmup_steps = 0


# Debug mode: two epochs on the first fold only.
if CFG.debug:
    CFG.epochs, CFG.trn_fold = 2, [0]

In [None]:
# ====================================================
# Library
# ====================================================
import os
import gc
import math
import time
import random
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold
from sklearn.metrics import mean_squared_error

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem.Fingerprints import FingerprintMols

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory that receives the training log. exist_ok=True makes the call
# idempotent and avoids the check-then-create race of a separate exists() test.
OUTPUT_DIR = './'
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# ====================================================
# Utils
# ====================================================
def get_logger(filename=OUTPUT_DIR+'train'):
    """Return a logger that writes bare messages to the console and to a file.

    Args:
        filename: Path of the log file without extension; ".log" is appended.

    Returns:
        The module-level logger (``getLogger(__name__)``) at INFO level with
        exactly one stream handler and one file handler attached.
    """
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)

    # getLogger returns the same object on every call, so re-running this cell
    # would keep stacking handlers and emit each message several times.
    # Drop (and close) whatever is attached before adding the fresh pair.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = get_logger()

def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Covers Python's ``random``, the hash seed environment variable, NumPy and
    PyTorch (CPU and all CUDA devices), and asks cuDNN for deterministic
    kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU, not just the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different algorithms from run to run,
    # so keep it off when reproducibility is the goal.
    torch.backends.cudnn.benchmark = False

# Seed from the shared config so the RNG state and the KFold split
# (which also uses CFG.seed) cannot drift apart.
seed_everything(seed=CFG.seed)

In [None]:
# ====================================================
# Data Loading
# ====================================================
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects, so this
# file must come from a trusted source. `.item()` unwraps the 0-d object array;
# the result is used below through `.values()`.
overlap_target = np.load("../data/overlap_target.npy", allow_pickle=True).item()

# Expression columns used as the target-side features of the model.
use_tcols = list(overlap_target.values())
CFG.use_tcols = use_tcols

# "Unnamed: 0" is the column that holds the identifiers; it is renamed to
# DepMap_ID so it can be joined with the GDSC table.
ccle_expression = pd.read_csv("../data/CCLE_expression.csv", usecols=["Unnamed: 0"] + use_tcols)
ccle_expression = ccle_expression.rename(columns={"Unnamed: 0": "DepMap_ID"})
ccle_depmap_list = ccle_expression.DepMap_ID.values

# Keep only GDSC rows whose DepMap_ID has expression data, then attach it.
gdsc = pd.read_csv("../processed_data/gdsc_overlap.csv")
gdsc = gdsc[gdsc.DepMap_ID.isin(ccle_depmap_list)]
gdsc = gdsc.merge(ccle_expression, on='DepMap_ID', how='left')
# A clean 0..n-1 index is required by the label-based fold assignment later on.
gdsc = gdsc.reset_index(drop=True)

print(f"ccle_expression.shape:{ccle_expression.shape}")
display(ccle_expression.head())

print(f"gdsc.shape: {gdsc.shape}")
display(gdsc.head())

In [None]:
# ====================================================
# CV split
# ====================================================

# Alternative kept for reference: group-wise split so that a drug never
# appears in both the training and the validation part of a fold.
# Fold = GroupKFold(n_splits=CFG.n_fold)
# groups = gdsc["DRUG_NAME"].values
# for n, (train_index, val_index) in enumerate(Fold.split(gdsc, gdsc[CFG.label], groups)):
#     gdsc.loc[val_index, 'fold'] = int(n)

# Plain shuffled K-fold over the rows of gdsc. The fold ids are collected in
# an integer array first and written to the frame in a single assignment;
# gdsc carries a 0..n-1 index, so positions and labels coincide.
Fold = KFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
fold_of_row = np.zeros(len(gdsc), dtype=int)
for fold_id, (_, valid_positions) in enumerate(Fold.split(gdsc)):
    fold_of_row[valid_positions] = fold_id

gdsc['fold'] = fold_of_row
display(gdsc.groupby('fold').size())

In [None]:
# ====================================================
# Helper functions
# ====================================================
class AverageMeter:
    """Running weighted mean that also remembers the most recent value.

    Attributes:
        val: The value passed to the latest ``update`` call.
        sum: Weighted sum of all values seen since the last ``reset``.
        count: Total weight seen since the last ``reset``.
        avg: ``sum / count`` as of the latest ``update``.
    """

    _FIELDS = ("val", "avg", "sum", "count")

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        for field in self._FIELDS:
            setattr(self, field, 0)

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def asMinutes(s):
    """Format a duration given in seconds as ``'<minutes>m <seconds>s'``.

    Both parts are truncated to integers by the ``%d`` conversion.
    """
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)

def timeSince(since, percent):
    """Describe elapsed and estimated remaining time of a running loop.

    Args:
        since: ``time.time()`` timestamp taken when the loop started.
        percent: Completed fraction of the loop; the elapsed time is divided
            by it to extrapolate the total duration.

    Returns:
        A string of the form ``'<elapsed> (remain <remaining>)'`` with both
        durations formatted by ``asMinutes``.
    """
    elapsed = time.time() - since
    projected_total = elapsed / (percent)
    remaining = projected_total - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [None]:
# ====================================================
# Dataset
# ====================================================
class TrainDataset(Dataset):
    """Dataset yielding ``((drug_fingerprint, target_features), label)`` samples.

    Args:
        cfg: Config object providing ``use_tcols`` (feature column names) and
            ``label`` (target column name).
        df: DataFrame with a ``SMILES`` column plus the columns named in ``cfg``.
    """

    def __init__(self, cfg, df):
        self.cfg = cfg
        # SMILES strings are kept raw; fingerprints are computed per item.
        self.drug_embedding = df["SMILES"].values
        self.target_embedding = df[cfg.use_tcols].values.astype(np.float32)
        # NOTE(review): float16 labels limit precision to roughly three
        # significant digits; left as-is because code outside this class may
        # rely on the label dtype.
        self.label = df[cfg.label].values.astype(np.float16)

    def __len__(self):
        return len(self.drug_embedding)

    def __getitem__(self, item):
        drug_embedding = self.smiles2morgan(self.drug_embedding[item])
        target_embedding = self.target_embedding[item]
        label = self.label[item]
        return (drug_embedding, target_embedding), label

    def smiles2morgan(self, s, radius=2, nBits=512):
        """Convert a SMILES string into a Morgan bit-vector fingerprint.

        Args:
            s: SMILES string of the molecule.
            radius: Morgan fingerprint radius.
            nBits: Length of the fingerprint bit vector.

        Returns:
            A float32 NumPy array. When the fingerprint cannot be computed
            (unparsable or non-string SMILES, RDKit failure) a zero vector of
            length ``nBits`` is returned and a message is printed.
        """
        try:
            mol = Chem.MolFromSmiles(s)
            if mol is None:
                # Make the failure explicit instead of relying on the
                # fingerprint call to choke on None.
                raise ValueError("SMILES could not be parsed")
            features_vec = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
            # Placeholder array; presumably ConvertToNumpyArray resizes it to
            # the fingerprint length - confirm against the RDKit docs.
            features = np.zeros((1,))
            DataStructs.ConvertToNumpyArray(features_vec, features)
        except Exception:
            # `except Exception` lets KeyboardInterrupt / SystemExit through,
            # and the f-string also works when `s` is not a str (e.g. NaN),
            # where string concatenation would raise inside the handler.
            print(f'rdkit not found this smiles for morgan: {s} convert to all 0 features')
            features = np.zeros((nBits, ))
        return features.astype(np.float32)

In [None]:
# ====================================================
# Model
# ====================================================
class CustomModel(nn.Module):
    """Two-branch regressor over a drug fingerprint and target features.

    Each input goes through its own linear layer with ReLU; the two hidden
    vectors are concatenated and mapped to a single output value.

    Args:
        cfg: Config object providing ``hidden_size`` and ``use_tcols`` (whose
            length defines the width of the target input).
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # 512 matches the default fingerprint length (nBits) of
        # TrainDataset.smiles2morgan.
        self.drug_fc = nn.Linear(512, cfg.hidden_size)
        self.target_fc = nn.Linear(len(cfg.use_tcols), cfg.hidden_size)
        self.fc = nn.Linear(cfg.hidden_size*2, 1)
        self.init_layers()

    def init_layers(self):
        """Apply the custom initialisation to every linear layer."""
        for layer in (self.drug_fc, self.target_fc, self.fc):
            self._init_weights(layer)

    def _init_weights(self, m):
        """Xavier-uniform weights and a constant 0.01 bias for linear layers."""
        if isinstance(m, nn.Linear):
            # In-place variants: the non-underscore xavier_uniform is deprecated.
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0.01)

    def forward(self, inputs):
        """Compute the prediction for a batch.

        Args:
            inputs: Pair ``(drug_embedding, target_embedding)`` of 2-D tensors
                with the batch on the first dimension.

        Returns:
            Tensor with one column holding the prediction of each sample.
        """
        drug_embedding, target_embedding = inputs
        drug_features = F.relu(self.drug_fc(drug_embedding))
        target_features = F.relu(self.target_fc(target_embedding))
        combined_features = torch.cat((drug_features, target_features), dim=1)
        output = self.fc(combined_features)
        return output

In [None]:
def train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device):
    """Train ``model`` for one epoch and return the average training loss.

    Args:
        fold: Index of the current fold (not used inside the function).
        train_loader: Iterable yielding ``(inputs, labels)`` batches where
            ``inputs`` is a sequence of tensors.
        model: Module called as ``model(inputs)``.
        criterion: Element-wise loss (its output is averaged here).
        optimizer: Optimizer stepped every ``CFG.gradient_accumulation_steps``
            batches.
        epoch: Zero-based epoch index, used for the progress line only.
        scheduler: LR scheduler; stepped after each optimizer step when
            ``CFG.batch_scheduler`` is set.
        device: Device the batch tensors are moved to.

    Returns:
        Sample-weighted mean of the per-batch losses over the epoch.
    """
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.apex)
    losses = AverageMeter()
    start = time.time()
    accumulation_steps = CFG.gradient_accumulation_steps
    # Only known once the first optimizer step has happened.
    grad_norm = float('nan')
    for step, (inputs, labels) in enumerate(train_loader):
        # Build a new list instead of assigning into the collated batch.
        inputs = [tensor.to(device) for tensor in inputs]
        labels = labels.to(device)
        batch_size = labels.size(0)
        with torch.cuda.amp.autocast(enabled=CFG.apex):
            y_preds = model(inputs)
        # Evaluate the loss in float32: predictions may be half precision
        # under autocast and the labels may arrive in a reduced dtype.
        loss = criterion(y_preds.view(-1, 1).float(), labels.view(-1, 1).float())
        loss = loss.reshape([-1]).mean()
        # Record the true batch loss exactly once, before it is rescaled
        # for gradient accumulation.
        losses.update(loss.item(), batch_size)
        if accumulation_steps > 1:
            loss = loss / accumulation_steps
        scaler.scale(loss).backward()
        if (step + 1) % accumulation_steps == 0:
            # Gradients carry the GradScaler factor until unscaled; clipping
            # must see their real magnitude, and only the fully accumulated
            # gradient should be clipped.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if CFG.batch_scheduler:
                scheduler.step()
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'Grad: {grad_norm:.4f}  '
                  'LR: {lr:.8f}  '
                  .format(epoch+1, step, len(train_loader),
                          remain=timeSince(start, float(step+1)/len(train_loader)),
                          loss=losses,
                          grad_norm=grad_norm,
                          # get_last_lr() is the supported way to read the
                          # current rate outside of scheduler.step().
                          lr=scheduler.get_last_lr()[0]))
    return losses.avg

def valid_fn(valid_loader, model, criterion, device):
    """Evaluate ``model`` on the validation loader.

    Args:
        valid_loader: Iterable yielding ``(inputs, labels)`` batches where
            ``inputs`` is a sequence of tensors.
        model: Module called as ``model(inputs)``.
        criterion: Element-wise loss (its output is averaged here).
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, predictions)``: the sample-weighted mean loss and
        the model outputs of all batches concatenated into one NumPy array.
    """
    losses = AverageMeter()
    model.eval()
    preds = []
    start = time.time()
    for step, (inputs, labels) in enumerate(valid_loader):
        # Build a new list instead of assigning into the collated batch.
        inputs = [tensor.to(device) for tensor in inputs]
        labels = labels.to(device)
        batch_size = labels.size(0)
        with torch.no_grad():
            y_preds = model(inputs)
            # Float32 on both sides keeps the loss independent of the dtype
            # the labels arrive in.
            loss = criterion(y_preds.view(-1, 1).float(), labels.view(-1, 1).float())
        loss = loss.reshape([-1]).mean()
        # One update per batch; gradient accumulation has no meaning during
        # evaluation, so the loss is recorded unscaled.
        losses.update(loss.item(), batch_size)
        preds.append(y_preds.to('cpu').numpy())
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            print('EVAL: [{0}/{1}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  .format(step, len(valid_loader),
                          loss=losses,
                          remain=timeSince(start, float(step+1)/len(valid_loader))))
    predictions = np.concatenate(preds)
    return losses.avg, predictions

In [None]:
# ====================================================
# train loop
# ====================================================
def train_loop(folds, fold):
    """Train on every fold except ``fold`` and predict the held-out rows.

    Args:
        folds: DataFrame containing a ``fold`` column plus everything
            ``TrainDataset`` needs.
        fold: Fold id used for validation.

    Returns:
        The validation rows (fresh index) with an added ``oof`` column that
        holds the predictions made after the last epoch.
    """

    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    train_folds = folds[folds['fold'] != fold].reset_index(drop=True)
    valid_folds = folds[folds['fold'] == fold].reset_index(drop=True)

    train_dataset = TrainDataset(CFG, train_folds)
    valid_dataset = TrainDataset(CFG, valid_folds)

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    # ====================================================
    # model & optimizer
    # ====================================================
    model = CustomModel(CFG)
    model.to(device)

    def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
        """Split parameters into a weight-decay group and a no-decay group.

        Both groups use ``encoder_lr``; ``decoder_lr`` is accepted but unused.
        """
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        named_params = list(model.named_parameters())
        decayed = [p for n, p in named_params if not any(nd in n for nd in no_decay)]
        undecayed = [p for n, p in named_params if any(nd in n for nd in no_decay)]
        optimizer_parameters = [
            {'params': decayed, 'lr': encoder_lr, 'weight_decay': weight_decay},
            {'params': undecayed, 'lr': encoder_lr, 'weight_decay': 0.0}
        ]
        return optimizer_parameters

    optimizer_parameters = get_optimizer_params(model,
                                                encoder_lr=CFG.encoder_lr,
                                                decoder_lr=CFG.decoder_lr,
                                                weight_decay=CFG.weight_decay)
    optimizer = AdamW(optimizer_parameters, lr=CFG.encoder_lr, eps=CFG.eps, betas=CFG.betas)

    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(cfg, optimizer, num_train_steps):
        """Build the warmup scheduler selected by ``cfg.scheduler``."""
        if cfg.scheduler=='linear':
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
            )
        elif cfg.scheduler=='cosine':
            scheduler = get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps, num_cycles=cfg.num_cycles
            )
        else:
            # Fail with a clear message rather than an UnboundLocalError.
            raise ValueError(f"Unknown scheduler '{cfg.scheduler}', expected 'linear' or 'cosine'")
        return scheduler

    # Number of optimizer steps that will really happen: the loader drops the
    # last incomplete batch and the optimizer only steps once per
    # gradient_accumulation_steps batches.
    num_train_steps = len(train_loader) // CFG.gradient_accumulation_steps * CFG.epochs
    scheduler = get_scheduler(CFG, optimizer, num_train_steps)

    # ====================================================
    # loop
    # ====================================================
    # Element-wise loss; train_fn / valid_fn average it themselves.
    criterion = nn.MSELoss(reduction="none")

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        avg_loss = train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        avg_val_loss, predictions = valid_fn(valid_loader, model, criterion, device)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')

    # Predictions of the final epoch; flattened so a plain 1-D column is
    # stored regardless of the (N, 1) shape produced by the model.
    valid_folds['oof'] = predictions.reshape(-1)

    torch.cuda.empty_cache()
    gc.collect()

    return valid_folds

In [None]:
if __name__ == '__main__':

    def get_result(oof_df):
        """Log the mean squared error between labels and out-of-fold predictions."""
        labels = oof_df[CFG.label].values
        preds = oof_df['oof'].values
        mse_loss = mean_squared_error(labels, preds)
        LOGGER.info(f'MSE loss: {mse_loss:<.4f}')

    if CFG.train:
        # Collect the per-fold frames and concatenate once at the end.
        fold_frames = []
        for fold in range(CFG.n_fold):
            if fold not in CFG.trn_fold:
                continue
            _oof_df = train_loop(gdsc, fold)
            fold_frames.append(_oof_df)
            LOGGER.info(f"========== fold: {fold} result ==========")
            get_result(_oof_df)
        LOGGER.info("========== CV ==========")
        if fold_frames:
            oof_df = pd.concat(fold_frames).reset_index(drop=True)
            get_result(oof_df)
        else:
            # No fold selected: there is nothing to score, and get_result
            # would fail on the missing columns of an empty frame.
            oof_df = pd.DataFrame()
            LOGGER.info("No fold from CFG.trn_fold was trained; CV score skipped")

    LOGGER.info("========== All finished ==========")