In [1]:
import gc
import math
import os

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import Dataset, DataLoader

from transformers import (RobertaTokenizerFast, RobertaForMaskedLM, DataCollatorWithPadding,
AutoModelForSequenceClassification, AutoConfig)

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import CSVLogger

from sklearn.metrics import average_precision_score

In [9]:
# --- Model -------------------------------------------------------------------
# Hugging Face hub id of the pretrained backbone, and the Lightning checkpoint
# from which training is resumed.
MODEL_NAME = "DeepChem/ChemBERTa-77M-MTR"
CHECKPOINT_PATH = "../chemberta_all_train_v2/fold_0_epoch_2_step_70615_0.0112.ckpt"

# Positive-class weights handed to BCEWithLogitsLoss (one entry per target column).
POS_WEIGHT = torch.tensor([1, 1, 1])

# --- Optimisation ------------------------------------------------------------
SEED = 42
EPOCHS = 3
LR = 1e-4
BATCH_SIZE = 1024
ACCUMULATE_GRAD_BATCHES = 4
NUM_WORKERS = 4

# --- Output locations --------------------------------------------------------
# Checkpoints / metrics go to OUT_DIR, submission files to ../submits.
OUT_DIR = "../chemberta_all_train_v2_continue"
os.makedirs(OUT_DIR, exist_ok = True)
os.makedirs("../submits", exist_ok = True)

In [3]:
def compute_ap(df):
    """
    Average precision per (split, protein).

    Parameters
    ----------
    df : pd.DataFrame
        Wide frame with the label columns ``BRD4``, ``HSA``, ``sEH``, the
        matching score columns ``BRD4_pred``, ``HSA_pred``, ``sEH_pred`` and
        the columns ``label_id`` and ``split``.

    Returns
    -------
    pd.DataFrame
        One row per (split, protein) with columns ``split``, ``protein``, ``AP``.
    """
    id_cols = ["label_id", "split"]
    targets = ["BRD4", "HSA", "sEH"]
    preds = [f"{t}_pred" for t in targets]

    y_true = df.loc[:, targets + id_cols].melt(
        id_vars = id_cols,
        var_name = "protein",
        value_name = "y_true"
    )
    y_pred = df.loc[:, preds + id_cols].melt(
        id_vars = id_cols,
        var_name = "_",
        value_name = "y_pred"
    )

    # Both melts walk the columns in the same order (BRD4, HSA, sEH) over the
    # same rows, so the two long frames line up positionally. Only the score
    # column is taken from the second one, which avoids a duplicated
    # `label_id` column in the combined frame.
    long_df = pd.concat([y_true, y_pred[["y_pred"]]], axis = 1)

    # Selecting the value columns explicitly keeps the grouping keys out of the
    # frame handed to the lambda (pandas >= 2.2 deprecates applying over them).
    ap = (
        long_df
        .groupby(["split", "protein"])[["y_true", "y_pred"]]
        .apply(lambda g: average_precision_score(g.y_true, g.y_pred))
        .reset_index(name = "AP")
    )

    return ap

In [6]:
# Wide training table: one row per molecule with the three binary targets.
df = pd.read_parquet("data/train_no_test_wide.parquet")
# Positional row id 0..n-1. Built directly with numpy: going through
# `df.reset_index().index` materialises a full copy of the frame just to
# obtain the same values.
df["id"] = np.arange(len(df), dtype = np.int64)
df["split"] = 0

df

Unnamed: 0,molecule_smiles,BRD4,HSA,sEH,id,split
0,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCCN2C(=O)NC(C)(...,0,0,0,0,0
1,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCc2ncc(C)o2)nc(N...,0,0,0,1,0
2,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCc2cnc(Cl)s2)nc(...,0,0,0,2,0
3,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCCOc2cccnc2)nc(N...,0,0,0,3,0
4,C#CC[C@@H](CC(=O)N[Dy])Nc1nc(NCc2cnc(Cl)s2)nc(...,0,0,0,4,0
...,...,...,...,...,...,...
96415605,Cc1ccc(S(C)(=O)=O)cc1Nc1nc(Nc2cnc(Cl)cc2Cl)nc(...,0,0,0,96415605,0
96415606,Cc1ccc(S(C)(=O)=O)cc1Nc1nc(Nc2nc(Cl)c3[nH]cnc3...,0,0,0,96415606,0
96415607,Cc1ccc(S(C)(=O)=O)cc1Nc1nc(Nc2nc(OCc3ccccc3)c3...,0,0,0,96415607,0
96415608,Cc1ccc(S(C)(=O)=O)cc1Nc1nc(Nc2nc3c(Br)cccc3s2)...,0,0,0,96415608,0


In [10]:
# Validation table: a fixed random sample of 200k rows, drawn with the
# notebook-wide SEED (42) so the subset is the same on every run.
df_val = pd.read_parquet("data/test_ensemble_wide.parquet").sample(n = 200_000, random_state = SEED)
# Positional row id 0..n-1 (the sampled frame keeps its original, shuffled index).
df_val["id"] = np.arange(len(df_val), dtype = np.int64)
df_val["split"] = 1

df_val

Unnamed: 0,molecule_smiles,BRD4,HSA,sEH,id,split
1828401,CN(C)C(CNc1nc(NCc2cn(C)c(=O)[nH]c2=O)nc(Nc2ccc...,0,0,0,0,1
1200071,CCOC(=O)c1nonc1Nc1nc(Nc2cc(Cl)ccc2C(=O)N[Dy])n...,0,0,0,1,1
194849,COc1ccc([C@H](Nc2nc(Nc3nc4c(s3)CCCC4)nc(Nc3cc(...,0,0,0,2,1
1629054,[N-]=[N+]=NCCC[C@H](Nc1nc(NCCCN2C(=O)CCC2=O)nc...,0,0,0,3,1
191144,O=C(CCNc1nc(Nc2cc(F)c(F)cc2C(=O)N[Dy])nc(Nc2nc...,0,0,0,4,1
...,...,...,...,...,...,...
863442,CC(CCNc1nc(Nc2ccc(F)c([N+](=O)[O-])c2)nc(Nc2cc...,0,0,0,199995,1
1456308,Cn1nccc1[C@@H]1OCC[C@H]1CNc1nc(Nc2ncc(F)cn2)nc...,0,0,0,199996,1
1068734,Cc1ccc(S(C)(=O)=O)cc1Nc1nc(Nc2ccnc(C(=O)N[Dy])...,0,0,0,199997,1
484203,N#Cc1cc(Nc2nc(NCc3cc(=O)nc[nH]3)nc(N[C@H](CC(=...,0,0,0,199998,1


In [11]:
class SmilesDataset(Dataset):
    """
    Map-style dataset that tokenizes one SMILES string per row.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``molecule_smiles``. Unless ``mode == "test"`` it must
        also contain the target columns plus ``id`` and ``split``.
    target_cols : sequence of str
        Names of the label columns returned as the target vector.
    mode : str
        ``"test"`` returns only the tokenized inputs; any other value returns
        ``(inputs, targets, id, split)``.
    max_length : int
        Token length every sample is truncated / padded to.
    """

    def __init__(
        self,
        df,
        target_cols = ("BRD4", "HSA", "sEH"),
        mode = "train",
        max_length = 140
    ):
        self.df = df
        # Stored as a list: a tuple used as a Series key would be read as a
        # single (multi-level) label instead of a selection of columns.
        self.target_cols = list(target_cols)
        self.mode = mode
        self.max_length = max_length
        self.tokenizer = RobertaTokenizerFast.from_pretrained(
            MODEL_NAME, model_max_length = max_length
        )
        self.collator = DataCollatorWithPadding(
            self.tokenizer,
            padding = "max_length",
            max_length = max_length,
            return_tensors = "pt"
        )

    def __len__(self):
        """
        Length of dataset.
        """
        return len(self.df)

    def __getitem__(self, index):
        """
        Get one item: the tokenized inputs in test mode, otherwise the tuple
        ``(inputs, targets, id, split)``.
        """
        return self.__data_generation(index)

    def __data_generation(self, index):
        """
        Tokenize row `index` and, outside test mode, attach its labels.
        """
        row = self.df.iloc[index]

        # Truncate to max_length so every sample has the same length after
        # padding; an over-long sequence would otherwise produce a tensor that
        # cannot be stacked with the rest of the batch.
        encoded = self.tokenizer(
            row["molecule_smiles"],
            truncation = True,
            max_length = self.max_length
        )
        X = self.collator(encoded)

        if self.mode == "test":
            # id / split are not needed (and not required to exist) for inference.
            return X

        y = row[self.target_cols].to_numpy().astype(np.float32)
        y = torch.tensor(y, dtype = torch.float32)
        return X, y, row["id"], row["split"]

In [12]:
class EMATracker:
    """
    Exponential moving average of a stream of values.

    The first value passed to `update` seeds the average; every later value is
    blended in with weight `alpha` (the running average keeps `1 - alpha`).
    `value` is None until the first update.
    """

    def __init__(self, alpha: float = 0.05):
        super().__init__()
        self._value = None
        self.alpha = alpha

    @property
    def value(self):
        """Current average, or None if nothing has been recorded yet."""
        return self._value

    def update(self, new_value):
        """Fold `new_value` into the running average."""
        if self._value is not None:
            keep = 1 - self.alpha
            self._value = new_value * self.alpha + self._value * keep
            return
        # First observation: nothing to blend with yet.
        self._value = new_value

class SmilesRoberta(L.LightningModule):
    """
    Lightning wrapper around a pretrained sequence-classification model with
    three output logits, trained with BCE-with-logits.

    Parameters
    ----------
    lr : float
        Learning rate for AdamW (also the peak rate of the optional OneCycleLR).
    epochs : int
        Number of training epochs; only used to size the LR schedule.
    steps_per_epoch : int
        Number of training batches per epoch; only used to size the LR schedule.
    lr_scheduler : bool
        If True, attach a per-step OneCycleLR schedule.
    fold : int
        Fold index, used only in the name of the validation dump file.
    """

    def __init__(self, lr, epochs, steps_per_epoch, lr_scheduler = False, fold = 0):
        super().__init__()
        self.lr = lr
        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.lr_scheduler = lr_scheduler
        self.fold = fold
        self.base_model = AutoModelForSequenceClassification.from_pretrained(
            MODEL_NAME,
            num_labels = 3
        )
        self.model_config = AutoConfig.from_pretrained(MODEL_NAME)
        self.loss = nn.BCEWithLogitsLoss(pos_weight = POS_WEIGHT)
        self.train_loss_tracker = EMATracker(alpha = 0.02)
        # Per-batch validation results, collected in validation_step and
        # consumed (then cleared) in on_validation_end.
        self.validation_step_outputs = []

    def forward(self, x):
        """Run the base model on a dict of tokenized inputs and return its output object."""
        out = self.base_model(**x, output_hidden_states = True)
        return out

    def training_step(self, batch, batch_idx):
        """Compute and log the BCE loss for one training batch."""
        x, y, label_ids, split = batch
        out = self.forward(x)
        loss = self.loss(out["logits"], y)
        self.log("train_loss", loss, on_step = True, on_epoch = True, prog_bar = True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and stash logits/targets for the AP computation."""
        x, y, label_ids, split = batch
        out = self.forward(x)
        val_loss = self.loss(out["logits"], y)
        self.log("val_loss", val_loss, on_step = False, on_epoch = True, logger = True, prog_bar = True)
        self.validation_step_outputs.append(
            {"val_loss": val_loss, "predictions": out["logits"], "targets": y, "label_ids": label_ids, "split": split}
        )
        return {"val_loss": val_loss}

    def on_validation_end(self):
        """
        Turn the collected validation batches into per-protein average
        precision, append it to `{OUT_DIR}/metrics.csv` and dump the
        predictions to a CSV named after fold / epoch / step / mean AP.
        """
        outputs = self.validation_step_outputs
        self.validation_step_outputs = []

        # Nothing to score, or only the few batches of the pre-training sanity
        # check: do not write metrics that would look like a real validation.
        if not outputs or self.trainer.sanity_checking:
            return

        output_val = torch.cat([x['predictions'] for x in outputs], dim = 0)
        # Cast to float32 before leaving torch: numpy has no bfloat16 dtype.
        output_val = torch.sigmoid(output_val).to(torch.float32).cpu().detach().numpy()
        target_val = torch.cat([x['targets'] for x in outputs], dim = 0).to(torch.float32).cpu().detach().numpy()
        # Ids stay in their integer dtype; float32 cannot represent integers
        # above 2**24 exactly.
        label_ids = torch.cat([x['label_ids'] for x in outputs], dim = 0).cpu().detach().numpy()

        TARGETS = ["BRD4", "HSA", "sEH"]
        PREDS = [f"{t}_pred" for t in TARGETS]
        val_df = pd.DataFrame(target_val, columns = list(TARGETS))
        pred_df = pd.DataFrame(output_val, columns = list(PREDS))
        val_df = pd.concat([val_df, pred_df], axis = 1)
        val_df["label_id"] = label_ids
        val_df["split"] = "val"

        ap = compute_ap(val_df)
        ap["epoch"] = self.current_epoch
        ap["step"] = self.global_step
        # Append to the running metrics file; write the header only once.
        ap.to_csv(f"{OUT_DIR}/metrics.csv", index = False, mode = "a",
                  header = not os.path.exists(f"{OUT_DIR}/metrics.csv"))

        mean_ap = ap.AP.mean()

        print(f"validation MAP: {mean_ap}")

        val_df.to_csv(f"{OUT_DIR}/val_df_fold_{self.fold}_epoch_{self.current_epoch}_step_{self.global_step}_{mean_ap:.4f}.csv",
                      index = False)

    def predict_step(self, batch, batch_idx, dataloader_idx = 0):
        """
        Return sigmoid probabilities for one batch.

        Accepts either the tokenized inputs alone or a tuple/list whose first
        element is the tokenized inputs (the layout used for train/val batches).
        """
        x = batch[0] if isinstance(batch, (list, tuple)) else batch
        # The model returns an output object; the scores live under "logits".
        return torch.sigmoid(self(x)["logits"])

    def configure_optimizers(self):
        """Build AdamW and, if requested, a per-step OneCycleLR schedule."""
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr = self.lr,
            betas = (0.9, 0.999),
            eps = 1e-08,
            weight_decay = 0.01
        )
        if not self.lr_scheduler:
            return optimizer

        # With gradient accumulation the optimizer steps once per window and
        # once more for a trailing partial window at the end of each epoch, so
        # the step count is a per-epoch ceiling. Flooring the overall ratio can
        # undercount and make OneCycleLR raise once it is stepped too often.
        optimizer_steps_per_epoch = math.ceil(self.steps_per_epoch / ACCUMULATE_GRAD_BATCHES)
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                self.lr,
                total_steps = optimizer_steps_per_epoch * self.epochs
            ),
            "interval": "step",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

In [None]:
# Seed python / numpy / torch (and the dataloader workers) before anything
# random happens: the training loader shuffles and the trainer is asked to be
# deterministic, which is only reproducible with a fixed seed.
L.seed_everything(SEED, workers = True)

train_ds = SmilesDataset(df)
train_loader = DataLoader(
    train_ds, 
    shuffle = True, 
    batch_size = BATCH_SIZE, 
    num_workers = NUM_WORKERS
)

valid_ds = SmilesDataset(df_val)
valid_loader = DataLoader(
    valid_ds, 
    shuffle = False, 
    batch_size = BATCH_SIZE, 
    num_workers = 1
)
        
logger = CSVLogger(save_dir = OUT_DIR, prefix = f"fold_{0}")

# save_top_k = -1 keeps a checkpoint from every validation run.
checkpoint_callback = ModelCheckpoint(
    dirpath = OUT_DIR,
    monitor = "val_loss",
    save_top_k = -1,
    save_last = False,
    save_weights_only = True,
    filename = f"fold_{0}" + "_epoch_{epoch}_step_{step}_{val_loss:.4f}",
    save_on_train_epoch_end = False,
    verbose = True,
    auto_insert_metric_name = False,
    mode = "min"
)

# NOTE(review): this callback is built but not passed to the Trainer below,
# so early stopping is currently inactive — confirm whether that is intended.
early_stop_callback = EarlyStopping(
    monitor = "val_loss", 
    min_delta = 0.00, 
    patience = 10, 
    verbose = False, 
    mode = "min"
)

# The constructor arguments are passed explicitly when restoring the module
# from the checkpoint.
model = SmilesRoberta.load_from_checkpoint(
    CHECKPOINT_PATH,
    lr = LR, 
    epochs = EPOCHS,
    steps_per_epoch = len(train_loader),
    lr_scheduler = False
)

# val_check_interval = 1/25 runs validation 25 times per training epoch.
trainer = L.Trainer(
    max_epochs = EPOCHS, 
    deterministic = True, 
    accumulate_grad_batches = ACCUMULATE_GRAD_BATCHES,
    logger = logger,
    val_check_interval = 1/25,
    log_every_n_steps = 100,
    callbacks = [checkpoint_callback],
    devices = 1, 
    precision = "bf16-mixed"
)
        
trainer.fit(
    model = model, 
    train_dataloaders = train_loader,
    val_dataloaders = valid_loader
)

# Submit

In [None]:
def make_submit(checkpoint, batch_size = 128, device = None):
    """
    Predict binding probabilities for ``../data/test.csv`` and build the
    submission frame.

    Parameters
    ----------
    checkpoint : str
        Path to a SmilesRoberta Lightning checkpoint.
    batch_size : int
        Inference batch size.
    device : str or torch.device, optional
        Device to run on; defaults to "cuda" when available, else "cpu".

    Returns
    -------
    pd.DataFrame
        Columns ``id`` and ``binds``, one row per row of the test file, in the
        order of the test file.
    """
    # Progress bar only; fall back to a plain iterator if tqdm is unavailable.
    try:
        from tqdm.auto import tqdm
    except ImportError:
        def tqdm(iterable, **kwargs):
            return iterable

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = SmilesRoberta.load_from_checkpoint(
        checkpoint,
        map_location = device,
        lr = LR, 
        epochs = EPOCHS,
        steps_per_epoch = 1,
        lr_scheduler = False
    )
    # Model and batches are placed on the same, explicitly chosen device.
    model = model.to(device).eval()

    df_test_source = pd.read_csv("../data/test.csv")

    # Each molecule is scored once, whatever number of test rows refer to it.
    df_test = df_test_source[["molecule_smiles"]].drop_duplicates().reset_index(drop = True)
    df_test["id"] = np.arange(len(df_test), dtype = np.int64)
    df_test["split"] = 0

    dataset = SmilesDataset(df_test, mode = "test")
    test_loader = DataLoader(dataset, batch_size = batch_size, shuffle = False)

    targets = ["BRD4", "HSA", "sEH"]

    preds = []
    with torch.inference_mode():
        for test_batch in tqdm(test_loader):
            test_batch = test_batch.to(device)
            out = model(test_batch)
            preds.append(torch.sigmoid(out["logits"]).cpu().numpy())

    preds = np.concatenate(preds)

    # One row per molecule with the three scores, then one row per
    # (molecule, protein) so it can be joined onto the test rows.
    wide = pd.concat(
        [df_test[["molecule_smiles"]], pd.DataFrame(preds, columns = targets)],
        axis = 1
    )
    long = wide.melt(
        id_vars = ["molecule_smiles"], 
        value_vars = targets, 
        value_name = "binds", 
        var_name = "protein_name"
    )

    # Left join: keeps exactly the rows of the test file. An outer join would
    # also add (molecule, protein) pairs absent from the test file, whose
    # missing ids turn the whole id column into floats.
    submit = pd.merge(
        df_test_source, 
        long, 
        how = "left",
        on = ["molecule_smiles", "protein_name"]
    )[["id", "binds"]].copy()

    # Rows without a prediction get 0. Only `binds` is filled; the id column
    # is left untouched.
    submit["binds"] = submit["binds"].fillna(0)

    submit = submit.drop_duplicates()

    return submit

In [None]:
# Checkpoint used for the submission (written to OUT_DIR by the training run).
CHECKPOINT = "../chemberta_all_train_v2_continue/fold_0_epoch_2_step_70615_0.0109.ckpt"

submit = make_submit(CHECKPOINT)

# The path has no placeholders, so a plain string literal is used.
submit.to_csv("../submits/chemberta_all_train_v2_continue.csv", index = False)