In [1]:
from functools import partial

import datasets as hds
import pandas as pd
import rdkit.Chem as Chem
import torch
from astartes import train_val_test_split
from rdkit.Chem.Descriptors import CalcMolDescriptors
from rdkit.rdBase import BlockLogs
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    AutoModel,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)

from data import PairDS, pair_collate
from model import DeltaModel
from utils import parallelize, standardize

import wandb
# Open the Weights & Biases run that the training loop's wandb.log calls report into.
wandb.init(name="init_v6", project="delta")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrahul-e-dev[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
@parallelize(8)
def standardize_df(df):
    """Return a copy of ``df`` with an ``inchi`` column computed from ``smiles``.

    ``standardize`` is applied element-wise to the ``smiles`` column and the result
    is stored as ``inchi``. ``DataFrame.assign`` builds a new frame instead of
    writing into the one passed in. The input may be a slice/chunk produced by
    ``parallelize`` (NOTE(review): assumed from the decorator, confirm), and
    writing into a slice would mutate shared data or raise SettingWithCopyWarning.
    """
    return df.assign(inchi=df["smiles"].map(standardize))

In [3]:
def load_kasa_regression(path="./KasA_SMM_regression.csv", random_state=42):
    """Load the KasA SMM regression data and scaffold-split it into train/val/test.

    Args:
        path: CSV file with ``SMILES`` and ``Average Average Z Score`` columns.
            Defaults to the file name that used to be hard-coded here.
        random_state: seed forwarded to astartes' scaffold sampler (default 42,
            the previously hard-coded value).

    Returns:
        ``datasets.DatasetDict`` with ``"train"``, ``"val"`` and ``"test"`` splits,
        each holding the columns ``smiles`` and ``score``.
    """
    df = pd.read_csv(path)[["SMILES", "Average Average Z Score"]].rename(
        columns={"SMILES": "smiles", "Average Average Z Score": "score"}
    )

    # With return_indices=True the per-split row indices are the last three
    # return values, regardless of what the sampler returns before them.
    *_, train_idx, val_idx, test_idx = train_val_test_split(
        X=df["smiles"].to_numpy(),
        y=df["score"].to_numpy(),
        sampler="scaffold",
        random_state=random_state,
        return_indices=True,
    )

    def _to_dataset(idx):
        # Select the split's rows and reset the index so the Dataset does not
        # inherit the shuffled pandas index.
        return hds.Dataset.from_pandas(df.iloc[idx].reset_index(drop=True))

    return hds.DatasetDict(
        {
            "train": _to_dataset(train_idx),
            "val": _to_dataset(val_idx),
            "test": _to_dataset(test_idx),
        }
    )

In [4]:
def tokenize(entry, tokenizer):
    """Tokenize the ``smiles`` field of a single dataset row.

    Args:
        entry: mapping with a ``smiles`` key; other keys are neither read nor modified.
        tokenizer: HuggingFace-style tokenizer, called with the SMILES string.

    Returns:
        Whatever ``tokenizer`` returns for the string, requested with attention
        and special-token masks and without truncation.
    """
    # Read the key directly: the old copy-then-pop only served to fetch this value.
    return tokenizer(
        entry["smiles"],
        truncation=False,
        return_attention_mask=True,
        return_special_tokens_mask=True,
    )

In [5]:
def train_step(model, batch, optimizer, scheduler, max_grad_norm=2.0):
    """Run one optimisation step on ``batch`` and return the model's loss dict.

    Args:
        model: module called as ``model(batch)``; must return a dict with a
            ``'loss'`` tensor that is backpropagated.
        batch: dict of tensors; every value is moved to the device of the
            model's first parameter.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per call after ``optimizer.step()``.
        max_grad_norm: total gradient-norm clipping threshold (default 2.0, the
            value previously hard-coded).

    Returns:
        The dict returned by ``model``. Tensor values are detached, so callers
        holding on to them do not keep the autograd graph alive.
    """
    model.train()
    optimizer.zero_grad()
    device = next(model.parameters()).device
    batch = {k: v.to(device) for k, v in batch.items()}

    losses = model(batch)
    losses["loss"].backward()

    # Clip before stepping so a single batch with huge gradients cannot blow up the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    scheduler.step()
    return {k: v.detach() if torch.is_tensor(v) else v for k, v in losses.items()}

@torch.no_grad()
def val_step(model, batch):
    """Evaluate ``model`` on one ``batch`` with gradient tracking disabled.

    Puts the model in eval mode, moves every tensor in ``batch`` onto the device
    of the model's first parameter, and returns ``model(batch)`` as-is.
    """
    model.eval()
    target = next(iter(model.parameters())).device
    moved = dict((name, tensor.to(target)) for name, tensor in batch.items())
    return model(moved)

In [6]:
# Load the scaffold-split data, tokenize the SMILES, and build the padding collator.
ds = load_kasa_regression()
tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
tok_func = partial(tokenize, tokenizer=tokenizer)
ds = ds.map(tok_func, num_proc=8).remove_columns(['smiles'])

# Keep the checkpoint's own pad token when it has one; fall back to EOS only if it
# is missing. The backbone is a RobertaModel (see the load warning in the next cell),
# and RoBERTa embeddings treat the config's pad id specially.
# NOTE(review): confirm tokenizer.pad_token_id matches model.config.pad_token_id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
padding_collator = DataCollatorWithPadding(tokenizer)

c:\Users\rahul\mambaforge\envs\bio\lib\site-packages\astartes\samplers\extrapolation\scaffold.py:44: NoMatchingScaffold: No matching scaffold was found for the 3 molecules corresponding to indices {23234, 36709, 11879}


Map (num_proc=8):   0%|          | 0/39777 [00:00<?, ? examples/s]

Map (num_proc=8):   0%|          | 0/4972 [00:00<?, ? examples/s]

Map (num_proc=8):   0%|          | 0/4972 [00:00<?, ? examples/s]

In [7]:
# Seed torch so the newly initialised pooler weights and the shuffled DataLoader
# order are reproducible (42 matches the split's random_state).
torch.manual_seed(42)

model = DeltaModel(AutoModel.from_pretrained('DeepChem/ChemBERTa-77M-MTR')).cuda()

pair_collate_fn = partial(pair_collate, base_collator=padding_collator)
train_dl = DataLoader(PairDS(ds['train']), batch_size=1024, shuffle=True, collate_fn=pair_collate_fn)
val_dl = DataLoader(PairDS(ds['val']), batch_size=1024, shuffle=False, collate_fn=pair_collate_fn)

optimizer = AdamW(model.parameters(), lr=2e-4)
n_epochs = 20
n_steps = len(train_dl) * n_epochs
warmup_steps = int(n_steps * 0.01)
# num_training_steps is the TOTAL number of steps, warmup included. Passing 99% of
# it made the cosine reach zero early and then climb back up for the remaining steps.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=n_steps,
)

Some weights of the model checkpoint at DeepChem/ChemBERTa-77M-MTR were not used when initializing RobertaModel: ['regression.out_proj.weight', 'norm_mean', 'regression.out_proj.bias', 'regression.dense.weight', 'regression.dense.bias', 'norm_std']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions an

In [9]:
from tqdm.auto import tqdm
from collections import defaultdict

# Train for exactly n_epochs, the horizon the LR scheduler was built for, and log the
# mean per-batch losses to W&B. Validate every VAL_EVERY epochs and after the last one.
VAL_EVERY = 2

for e in range(n_epochs):
    train_loss_accumulator = defaultdict(float)
    for batch in tqdm(train_dl, total=len(train_dl), desc=f"epoch {e} train"):
        train_losses = train_step(model, batch, optimizer, scheduler)
        for k, v in train_losses.items():
            train_loss_accumulator[k] += v.item()

    train_loss_accumulator = {
        f'train/{k}': round(v / len(train_dl), 4)
        for k, v in train_loss_accumulator.items()
    }
    wandb.log(train_loss_accumulator | {'train/epoch': e})

    if e % VAL_EVERY == 0 or e == n_epochs - 1:
        val_loss_accumulator = defaultdict(float)
        for batch in tqdm(val_dl, total=len(val_dl), desc=f"epoch {e} val"):
            val_losses = val_step(model, batch)
            # Accumulate the validation losses (this previously summed train_losses by mistake).
            for k, v in val_losses.items():
                val_loss_accumulator[k] += v.item()

        val_loss_accumulator = {
            f'val/{k}': round(v / len(val_dl), 4)
            for k, v in val_loss_accumulator.items()
        }
        wandb.log(val_loss_accumulator | {'val/epoch': e})

  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/78 [00:00<?, ?it/s]

  0%|          | 0/622 [00:00<?, ?it/s]