In [1]:
from functools import partial

import torch

from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    AutoModel,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)

from data import PairDS, pair_collate, load_kasa_regression, standardize_df, tokenize
from model import DeltaModel
from utils import parallelize, standardize

import wandb
# Open the W&B run that the later wandb.watch / wandb.log calls report to.
wandb.init(name="init_v6", project="delta")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrahul-e-dev[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
def train_step(model, batch, optimizer, scheduler, max_grad_norm=1.0):
    """Run one optimisation step on a single batch.

    Moves every tensor in ``batch`` to the device of the model's first
    parameter and computes ``model(batch)``. It then back-propagates
    ``losses['loss']``, optionally clips the global gradient norm, and steps
    both the optimizer and the LR scheduler (the scheduler is stepped once
    per batch).

    Args:
        model: module whose forward takes the batch dict and returns a dict
            of tensors containing at least a ``'loss'`` entry.
        batch: mapping of name -> tensor.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per call.
        max_grad_norm: max total gradient norm for ``clip_grad_norm_``.
            Defaults to 1.0, the value previously hard-coded. ``None``
            disables clipping.

    Returns:
        The loss dict, with every tensor detached from the autograd graph.
        Callers that keep the result (e.g. across the epoch loop) therefore do
        not keep the backward graph alive.
    """
    model.train()
    optimizer.zero_grad()
    device = next(model.parameters()).device
    batch = {k: v.to(device) for k, v in batch.items()}

    losses = model(batch)
    losses['loss'].backward()

    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    scheduler.step()

    # Detach so the returned values are plain numbers-in-tensors; backward has
    # already been run, so the graph is no longer needed.
    return {k: v.detach() if torch.is_tensor(v) else v for k, v in losses.items()}

@torch.no_grad()
def val_step(model, batch):
    """Evaluate ``model`` on one batch with gradient tracking disabled.

    Switches the model to eval mode and copies each entry of ``batch`` onto
    the device of the model's first parameter. Returns whatever
    ``model(batch)`` returns (the training loop treats it as a dict of loss
    tensors).
    """
    model.eval()
    target_device = next(model.parameters()).device

    on_device = {}
    for name, tensor in batch.items():
        on_device[name] = tensor.to(target_device)

    return model(on_device)

In [3]:
ds = load_kasa_regression()
tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
tok_func = partial(tokenize, tokenizer=tokenizer)
ds = ds.map(tok_func, num_proc=8).remove_columns(['smiles'])

# DataCollatorWithPadding needs a pad token. Only fall back to EOS when the
# tokenizer does not already define one. Overwriting an existing pad token
# would pad with a different id than the checkpoint presumably expects
# (TODO confirm against the model config's pad_token_id).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
padding_collator = DataCollatorWithPadding(tokenizer)

c:\Users\rahul\mambaforge\envs\bio\lib\site-packages\astartes\samplers\extrapolation\scaffold.py:44: NoMatchingScaffold: No matching scaffold was found for the 3 molecules corresponding to indices {23234, 36709, 11879}


Map (num_proc=8):   0%|          | 0/39777 [00:00<?, ? examples/s]

Map (num_proc=8):   0%|          | 0/4972 [00:00<?, ? examples/s]

Map (num_proc=8):   0%|          | 0/4972 [00:00<?, ? examples/s]

In [4]:
# Seed torch's default RNG so any randomly initialised layers and the
# shuffled training order are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

BATCH_SIZE = 128
LR = 2e-4
WARMUP_FRAC = 0.01  # fraction of total optimizer steps spent in linear warmup

model = DeltaModel(AutoModel.from_pretrained('DeepChem/ChemBERTa-77M-MTR')).cuda()

# Both loaders use the same collate function.
pair_collate_fn = partial(pair_collate, base_collator=padding_collator)
train_dl = DataLoader(
    PairDS(ds['train']),
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=pair_collate_fn,
)
val_dl = DataLoader(
    PairDS(ds['val']),
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=pair_collate_fn,
)

optimizer = AdamW(model.parameters(), lr=LR)
n_epochs = 4
# train_step calls scheduler.step() once per batch, so the schedule length is
# batches-per-epoch * epochs.
n_steps = len(train_dl) * n_epochs
n_warmup = int(n_steps * WARMUP_FRAC)
# num_training_steps is the TOTAL number of steps (warmup included). Passing
# n_steps*(1-0.01) made the cosine reach zero before training ended, after
# which the learning rate climbed back up.
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=n_warmup, num_training_steps=n_steps
)
# Alternative: get_constant_schedule_with_warmup(optimizer, n_warmup)

wandb.watch(model, log='gradients', log_freq=10)

In [5]:
from tqdm.auto import tqdm
from collections import defaultdict

VAL_EVERY = 4  # validate every VAL_EVERY epochs, and always after the final epoch

# Train for exactly n_epochs: the LR scheduler was sized for
# len(train_dl) * n_epochs steps, so an extra epoch would run past its end.
for e in tqdm(range(n_epochs), desc='epoch'):
    # --- training ---
    train_sums = defaultdict(float)
    for batch in tqdm(train_dl, total=len(train_dl), desc='train', leave=False):
        train_losses = train_step(model, batch, optimizer, scheduler)
        for k, v in train_losses.items():
            train_sums[k] += v.item()

    # Mean of per-batch losses over the epoch.
    train_metrics = {
        f'train/{k}': round(v / len(train_dl), 4)
        for k, v in train_sums.items()
    }
    wandb.log(train_metrics | {'train/epoch': e})

    # --- validation ---
    if e % VAL_EVERY == 0 or e == n_epochs - 1:
        val_sums = defaultdict(float)
        for batch in val_dl:
            val_losses = val_step(model, batch)
            # Accumulate the *validation* losses. Iterating train_losses here
            # would just re-log the last training batch as val/*.
            for k, v in val_losses.items():
                val_sums[k] += v.item()

        val_metrics = {
            f'val/{k}': round(v / len(val_dl), 4)
            for k, v in val_sums.items()
        }
        wandb.log(val_metrics | {'val/epoch': e})