In [12]:
from pathlib import Path
import numpy as np
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchmetrics import AveragePrecision
import lightning as L
from lightning.pytorch.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    TQDMProgressBar,
)
from transformers import AutoConfig, AutoTokenizer, AutoModel, DataCollatorWithPadding

import gc

import torch
import torch.nn.functional as F

from torch.utils.data import Dataset
import polars as pl
import numpy as np

In [None]:
# Label columns; their order is the order of the model's output logits.
PROTEIN_NAMES = ["BRD4", "HSA", "sEH"]
# Only used for the test parquet in the inference cell; train/valid/all_data files are
# read relative to the working directory.
data_dir = Path("/tokenized-chemberta")
model_name = "ChemBERTa-77M-MTR"  # loaded from the "DeepChem/" hub namespace
batch_size = 1024  # DataLoader batch size

trainer_params = {
    "max_epochs": 3,
    "enable_progress_bar": True,
    "accelerator": "auto",
    "precision": "16-mixed",
    "gradient_clip_val": None,
    "accumulate_grad_batches": 6,
    "devices": [0, 1, 2, 3],
}

seed = 42
# Actually apply the seed (it was previously declared but never used), so data shuffling,
# dropout and head initialisation are reproducible across Restart & Run All.
L.seed_everything(seed, workers=True)

In [None]:
import sys

class CustomDataset(Dataset):
    """Labelled tokenized molecules for training / validation.

    Args:
        df: frame with ``input_ids`` and ``attention_mask`` columns (one token
            sequence per row) plus every column named in ``label_cols``.
        label_cols: label column names. Defaults to the notebook-level
            ``PROTEIN_NAMES`` (the original behaviour).

    Each item is a dict with an int32 ``input_ids`` tensor, a bool
    ``attention_mask`` tensor and a bool ``labels`` tensor (one entry per label column).
    """

    def __init__(self, df, label_cols=None):
        if label_cols is None:
            label_cols = PROTEIN_NAMES
        self.input_ids = np.array(df['input_ids'])
        self.attention_masks = np.array(df['attention_mask'])
        # list(): column selection by list works for both polars and pandas frames.
        self.labels = np.array(df[list(label_cols)])

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return {
            "input_ids": torch.tensor(self.input_ids[index], dtype=torch.int32),
            "attention_mask": torch.tensor(self.attention_masks[index], dtype=torch.bool),
            "labels": torch.tensor(self.labels[index], dtype=torch.bool),
        }


In [None]:
train_tokenize = pl.read_parquet("train_tokenized_77M-MTR_replaced_dy.parquet")
train_dataset = CustomDataset(train_tokenize)

valid_tokenize = pl.read_parquet("valid_tokenized_77M-MTR_replaced_dy.parquet")
valid_dataset = CustomDataset(valid_tokenize)

# Stratification target for the CV cell: per-row sum of the label columns (0-3 if the
# labels are 0/1). This must be built BEFORE the source frames are deleted — deleting them
# first made this concat raise NameError.
# NOTE(review): the CV cell reads "all_data.parquet" and indexes it with fold indices derived
# from this concat, so that file must hold the same rows in the same order — confirm.
all_data = pl.concat([train_tokenize, valid_tokenize])
len_all_data = len(all_data)
all_data_y = all_data[PROTEIN_NAMES].sum_horizontal()

del train_tokenize, valid_tokenize
gc.collect()

In [5]:
class LMModel3(nn.Module):
    """Pretrained encoder + masked mean pooling + small MLP head producing multi-label logits.

    Args:
        model_name: checkpoint name, loaded as ``"DeepChem/" + model_name``.
        num_labels: number of output logits (one per protein target).
        intermediate_size: width of the hidden layer between pooling and classifier.
        pos_weight: positive-class weight for ``BCEWithLogitsLoss``; a single value
            broadcast to every label.
    """

    def __init__(self, model_name, num_labels=3, intermediate_size=64, pos_weight=100.0):
        super().__init__()
        repo = "DeepChem/" + model_name
        self.config = AutoConfig.from_pretrained(repo, num_labels=num_labels)
        self.lm = AutoModel.from_pretrained(repo, add_pooling_layer=False)
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)

        self.intermediate = nn.Linear(self.config.hidden_size, intermediate_size)
        self.intermediate_activation = nn.ReLU()

        self.classifier = nn.Linear(intermediate_size, self.config.num_labels)
        # pos_weight is registered as a buffer of the loss module, so it follows .to(device).
        self.loss_fn = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor([float(pos_weight)]), reduction="mean"
        )

    @staticmethod
    def _mean_pool(last_hidden_state, attention_mask):
        """Average token embeddings over the positions where ``attention_mask`` is non-zero.

        The denominator is clamped so a fully-masked row yields zeros instead of NaN.
        """
        mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        summed = torch.sum(last_hidden_state * mask, 1)
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def forward(self, batch):
        """Return ``{"logits": ...}`` for a batch dict with ``input_ids`` / ``attention_mask``."""
        last_hidden_state = self.lm(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
        ).last_hidden_state

        output = self._mean_pool(last_hidden_state, batch["attention_mask"])
        output = self.dropout(output)
        output = self.intermediate(output)
        output = self.intermediate_activation(output)

        logits = self.classifier(output)
        return {
            "logits": logits,
        }

    def calculate_loss(self, batch):
        """Run ``forward`` and add the BCE loss against ``batch["labels"]`` under key ``"loss"``."""
        output = self.forward(batch)
        output["loss"] = self.loss_fn(output["logits"], batch["labels"].float())
        return output

In [None]:
class LBModelModule(L.LightningModule):
    """Lightning wrapper around ``LMModel3``: train/val steps, AP metrics, optimiser, dataloaders.

    Args:
        model_name: encoder name, forwarded to ``LMModel3``.
        batch_size: batch size of the train/val DataLoaders. Defaults to 1024 (the
            notebook's configured value) so ``load_from_checkpoint(model_name=...)`` can
            rebuild the module without it; prediction uses an external DataLoader.
        train_dataset, valid_dataset: datasets for the dataloader hooks. When left as
            None, the notebook-level ``train_dataset`` / ``valid_dataset`` globals are used
            (the original behaviour).
    """

    def __init__(self, model_name, batch_size=1024, train_dataset=None, valid_dataset=None):
        super().__init__()
        self.model = LMModel3(model_name)
        # Binary task fed the full (N, n_labels) tensors — presumably AP pooled across all
        # label columns; confirm against torchmetrics' input-flattening rules.
        self.map = AveragePrecision(task="binary")
        # nn.ModuleList (not a plain list) registers the metrics as child modules so they
        # move to the training device together with this LightningModule.
        self.map_per_class = nn.ModuleList(
            [AveragePrecision(task="binary") for _ in range(len(PROTEIN_NAMES))]
        )
        self.batch_size = batch_size
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset

    def forward(self, batch):
        return self.model(batch)

    def calculate_loss(self, batch, batch_idx):
        return self.model.calculate_loss(batch)

    def training_step(self, batch, batch_idx):
        ret = self.calculate_loss(batch, batch_idx)
        self.log("train_loss", ret["loss"], on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return ret["loss"]

    def validation_step(self, batch, batch_idx):
        ret = self.calculate_loss(batch, batch_idx)
        self.log("val_loss", ret["loss"], on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        # Compute probabilities once and reuse them for the pooled and per-protein metrics.
        probs = torch.sigmoid(ret["logits"])
        targets = batch["labels"].long()
        self.map.update(probs, targets)
        for i, metric in enumerate(self.map_per_class):
            metric.update(probs[:, i], targets[:, i])

    def on_validation_epoch_end(self):
        self.log("val_map", self.map.compute(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.map.reset()
        for protein_name, metric in zip(PROTEIN_NAMES, self.map_per_class):
            self.log(f"val_map_{protein_name}", metric.compute(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
            metric.reset()

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        logits = self.forward(batch)["logits"]
        return torch.sigmoid(logits)

    def train_dataloader(self):
        dataset = self.train_dataset if self.train_dataset is not None else train_dataset
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=0,
                          collate_fn=DataCollatorWithPadding(tokenizer))

    def val_dataloader(self):
        dataset = self.valid_dataset if self.valid_dataset is not None else valid_dataset
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0,
                          collate_fn=DataCollatorWithPadding(tokenizer))

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
        # `verbose` dropped: it is deprecated for LR schedulers since PyTorch 2.2.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=2)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1
            }
        }

In [None]:
# Tokenizer matching the encoder checkpoint; the dataloader collators read this global.
tokenizer = AutoTokenizer.from_pretrained(f"DeepChem/{model_name}")

In [None]:
from sklearn.model_selection import StratifiedKFold

n_splits = 5
kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
FOLD = [0, 1, 2, 3, 4]  # 0-based indices of the folds to train in this run
# Size of the random validation subset scored each epoch (kept small for speed).
N_VALID_SUBSET = 10_000

# Release the datasets built in the loading cell; every fold builds its own.
train_dataset = None
valid_dataset = None

for fold, (train_idx, val_idx) in enumerate(kf.split(np.zeros(len_all_data), all_data_y)):
    print(f'Fold {fold + 1}/{n_splits}')

    if fold not in FOLD:
        continue

    # NOTE(review): fold indices come from the train+valid concat in the loading cell;
    # all_data.parquet must hold the same rows in the same order — confirm.
    all_data = pl.read_parquet("all_data.parquet")

    # StratifiedKFold yields val_idx in ascending order, so taking its first N entries would
    # only cover the beginning of the file. Draw a seeded random subset instead.
    fold_rng = np.random.default_rng(seed + fold)
    valid_rows = np.sort(
        fold_rng.choice(val_idx, size=min(N_VALID_SUBSET, len(val_idx)), replace=False)
    )

    train = all_data[train_idx]
    valid = all_data[valid_rows]

    del all_data
    gc.collect()

    train_dataset = CustomDataset(train)
    valid_dataset = CustomDataset(valid)

    del train, valid

    # The module's dataloader hooks read the module-level train_dataset / valid_dataset above.
    modelmodule = LBModelModule(model_name, batch_size)

    EXP_NAME = f'5fold_chemberta_model3_e3_fold{fold + 1}'

    # NOTE(review): checkpoints are named by val_map, while the inference cell loads
    # "*_epoch{N}.ckpt" files — presumably renamed by hand.
    checkpoint_callback = ModelCheckpoint(
        filename=f"model_{model_name}_fold{fold + 1}_{{val_map:.4f}}",
        save_weights_only=True,
        monitor="val_map",
        mode="max",
        dirpath=f"chemberta_v3_e3_5fold/fold{fold+1}",
        save_top_k=5,
        verbose=True,
    )

    # NOTE(review): with max_epochs=3 in trainer_params, patience=5 never stops training early.
    early_stop_callback = EarlyStopping(monitor="val_map", mode="max", patience=5)

    progress_bar_callback = TQDMProgressBar(refresh_rate=1)

    callbacks = [checkpoint_callback, early_stop_callback, progress_bar_callback]

    trainer = L.Trainer(callbacks=callbacks, **trainer_params)

    trainer.fit(modelmodule)

    # Free this fold's data before loading the next fold.
    del train_dataset, valid_dataset
    gc.collect()

# Inference

In [9]:
class CustomTestDataset(Dataset):
    """Unlabelled tokenized molecules used for inference.

    ``df`` must expose ``input_ids`` and ``attention_mask`` columns. Every item is a
    dict holding an int32 ``input_ids`` tensor and a bool ``attention_mask`` tensor.
    """

    def __init__(self, df):
        ids_column = df['input_ids']
        mask_column = df['attention_mask']
        self.input_ids = np.array(ids_column)
        self.attention_masks = np.array(mask_column)

    def __len__(self):
        return len(self.attention_masks)

    def __getitem__(self, index):
        token_ids = torch.tensor(self.input_ids[index], dtype=torch.int32)
        mask = torch.tensor(self.attention_masks[index], dtype=torch.bool)
        return dict(input_ids=token_ids, attention_mask=mask)

In [None]:
# Predict on a single device: under multi-device DDP, trainer.predict returns only the local
# rank's shard, which would not line up with the rows of test_tokenize. Build the callbacks
# here instead of reusing the training cell's `callbacks` (hidden state).
predict_trainer_params = {**trainer_params, "devices": 1}
trainer = L.Trainer(callbacks=[TQDMProgressBar(refresh_rate=1)], **predict_trainer_params)

# NOTE(review): training runs max_epochs=3 and names checkpoints by val_map; these
# "*_epoch{N}.ckpt" files are presumably hand-renamed copies, and epochs 4-5 must exist too.
EPOCHS = [1, 2, 3, 4, 5]
tokenizer = AutoTokenizer.from_pretrained("DeepChem/" + model_name)
test_tokenize = pl.read_parquet(Path(data_dir, "test_tokenized_77M_replace_dy.parquet"))
test_dataset = CustomTestDataset(test_tokenize)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    num_workers=1,
    collate_fn=DataCollatorWithPadding(tokenizer),
)

for EPOCH in EPOCHS:
    fold_predictions = []
    for FOLD in range(1, 6):
        model_path = f"model_{model_name}_fold{FOLD}_epoch{EPOCH}.ckpt"
        print(model_path)
        # batch_size is passed explicitly: LBModelModule does not save hyperparameters, so
        # every __init__ argument it requires must be supplied here.
        modelmodule = LBModelModule.load_from_checkpoint(
            checkpoint_path=model_path,
            model_name=model_name,
            batch_size=batch_size,
        )

        predictions = trainer.predict(modelmodule, test_dataloader)
        # .float(): outputs under "16-mixed" precision may be float16.
        fold_predictions.append(torch.cat(predictions).float().numpy())
    avg_predictions = sum(fold_predictions) / len(fold_predictions)

    # One copy of the test rows per protein, each carrying that protein's averaged probability.
    pred_df = pl.concat(
        [
            test_tokenize.with_columns(
                pl.lit(protein_name).alias("protein_name"),
                pl.Series("binds", avg_predictions[:, i]),
            )
            for i, protein_name in enumerate(PROTEIN_NAMES)
        ]
    )

    submit_df = (
        pl.read_parquet("/kaggle/input/leash-BELKA/test.parquet", columns=["id", "molecule_smiles", "protein_name"])
        .join(pred_df, on=["id", "protein_name"], how="left")
        .select(["id", "binds"])
        .sort("id")
    )

    # maintain_order keeps the id sort above; plain group_by returns groups in arbitrary order.
    submit_df.group_by("id", maintain_order=True).mean().write_csv(f"chemberta_5fold_{EPOCH}.csv")
    print(submit_df)