In [None]:
from functools import partial
import random

import torch

from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    AutoModel,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)

from data import (
    ContrastiveTripletDS,
    load_kasa_regression,
    standardize_df,
    tokenize,
    batch_of_dict_collate,
)
from model import TripletDeltaModel
from collections import defaultdict
from tqdm.auto import tqdm

import wandb

# Start the Weights & Biases run that step_epoch (wandb.log) and the
# wandb.watch call in the model-setup cell report into.
# NOTE(review): the run name is hard-coded; change it per experiment or
# separate runs will share the same display name.
wandb.init(project="delta", name="init_v8")

In [None]:
def train_step(model, batch, optimizer, scheduler):
    """Run one optimisation step on a single batch.

    Switches ``model`` to training mode, moves every tensor in ``batch`` to the
    device of the model's first parameter, back-propagates ``losses["loss"]``,
    clips the global gradient norm to 1.0, then steps both the optimizer and
    the LR scheduler (one scheduler step per batch).

    Returns:
        The dict of losses produced by ``model(batch)``.
    """
    model.train()
    optimizer.zero_grad()

    target_device = next(model.parameters()).device
    on_device = {name: tensor.to(target_device) for name, tensor in batch.items()}

    loss_dict = model(on_device)
    loss_dict["loss"].backward()

    # Guard against exploding gradients before the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    return loss_dict


@torch.no_grad()
def val_step(model, batch):
    """Evaluate ``model`` on one batch without tracking gradients.

    Puts the model in eval mode (and leaves it there), moves each tensor in
    ``batch`` to the device of the model's first parameter, and returns the
    loss dict produced by ``model(batch)``.
    """
    model.eval()
    target_device = next(model.parameters()).device
    on_device = dict((key, tensor.to(target_device)) for key, tensor in batch.items())
    return model(on_device)


@torch.no_grad()
def embed_split(ds_split, model, padding_collator, batch_size=1024):
    """Embed every example of ``ds_split`` with ``model.embed_batch``.

    The model is put in eval mode first. Without this, the first call (on a
    freshly built model, whose ``nn.Module`` wrapper defaults to training mode)
    would run any dropout / batch-norm layers in training mode, making the
    embeddings noisy. Training code re-enables training mode itself
    (``train_step`` calls ``model.train()``), so the mode is not restored here,
    matching ``val_step``.

    Args:
        ds_split: dataset of tokenized examples; iterated in order (no shuffle).
        model: module exposing ``embed_batch(batch)``.
        padding_collator: collate function that pads a list of examples into a
            dict of tensors.
        batch_size: number of examples per forward pass.

    Returns:
        Tensor of per-example embeddings concatenated along dim 0 (row order
        matches ``ds_split``), placed on the model's device.
    """
    model.eval()
    device = next(model.parameters()).device
    loader = DataLoader(
        ds_split, shuffle=False, batch_size=batch_size, collate_fn=padding_collator
    )

    # Chunks are collected on CPU and moved back once at the end
    # (presumably to limit GPU memory while batches are in flight).
    chunks = []
    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        chunks.append(model.embed_batch(batch).cpu())

    return torch.cat(chunks, dim=0).to(device)


@torch.no_grad()
def get_actual_diffs_for_split(ds_split):
    """Return the pairwise matrix of ground-truth score differences.

    ``out[i, j] = score[i] - score[j]`` where ``score`` is the ``"score"``
    column of ``ds_split``. Returned as a NumPy array.

    NOTE(review): confirm this sign convention matches
    ``model.get_delta(from_embedding=<row i>, to_embedding=<row j>)`` as used
    in ``get_predicted_diffs_for_split``.
    """
    scores = torch.tensor(ds_split["score"])
    # Broadcast a column view against a row view to get every pair at once.
    column, row = scores[..., None], scores[None]
    return (column - row).cpu().numpy()


@torch.no_grad()
def get_predicted_diffs_for_split(ds_split, model, padding_collator):
    """Predict the model's delta for every ordered pair of examples in a split.

    Embeds the whole split once, then for each row ``i`` asks
    ``model.get_delta`` for the delta from example ``i`` to every example
    ``j`` in a single batched call.

    The model is put in eval mode so that ``get_delta`` (and the embedding
    pass) do not run dropout / batch-norm in training mode. This matters on
    the very first call, made on a freshly constructed model.

    Returns:
        CPU float32 tensor ``pred_delta`` of shape (N, N) with
        ``pred_delta[i, j] = get_delta(from=emb[i], to=emb[j])``. Memory is
        4 * N**2 bytes, so large splits are expensive.

    NOTE(review): this returns a torch tensor while
    ``get_actual_diffs_for_split`` returns a NumPy array; confirm
    ``ContrastiveTripletDS`` accepts both.
    """
    model.eval()
    embeddings = embed_split(ds_split, model, padding_collator)

    n_items = embeddings.shape[0]
    pred_delta = torch.zeros((n_items, n_items))
    for row_idx in tqdm(range(n_items)):
        # Repeat row i's embedding N times so one call covers the whole row.
        pred_delta_row = model.get_delta(
            from_embedding=embeddings[[row_idx]].expand_as(embeddings),
            to_embedding=embeddings,
        )
        pred_delta[row_idx, :] = pred_delta_row.reshape(-1)

    return pred_delta

In [None]:
def update_dataloaders(
    train_split, val_split, model, padding_collator, batch_size=256, num_workers=0
):
    """Re-mine triplets with the current model and rebuild both dataloaders.

    For each split, the actual pairwise score differences and the model's
    predicted pairwise deltas are computed and handed to
    ``ContrastiveTripletDS``. Both are N x N in the split size, so this is
    expensive for large splits.

    Args:
        train_split: split used for the (shuffled) training loader.
        val_split: split used for the (unshuffled) validation loader.
        model: model used to predict the pairwise deltas.
        padding_collator: collator that pads individual examples; wrapped by
            ``batch_of_dict_collate`` for the triplet batches.
        batch_size: batch size for both loaders (default keeps the old 256).
        num_workers: DataLoader worker processes; 0 (the DataLoader default)
            loads in the main process as before.

    Returns:
        ``(train_dl, val_dl)``.
    """
    collate_fn = partial(batch_of_dict_collate, base_collator=padding_collator)

    def _make_loader(split, shuffle):
        # Build a triplet dataset from fresh actual/predicted diffs and wrap it.
        triplet_ds = ContrastiveTripletDS(
            split,
            get_actual_diffs_for_split(split),
            get_predicted_diffs_for_split(split, model, padding_collator),
        )
        return DataLoader(
            triplet_ds,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
            num_workers=num_workers,
        )

    train_dl = _make_loader(train_split, shuffle=True)
    val_dl = _make_loader(val_split, shuffle=False)
    return train_dl, val_dl

In [None]:
def _average_losses(step_fn, dataloader, prefix, progress=False):
    """Apply ``step_fn`` to every batch; return per-key mean losses for logging.

    Keys are prefixed ``f"{prefix}/"`` and values rounded to 4 decimals.
    Returns an empty dict when ``dataloader`` yields no batches.
    """
    totals = defaultdict(float)
    batches = tqdm(dataloader, total=len(dataloader)) if progress else dataloader
    for batch in batches:
        for name, value in step_fn(batch).items():
            totals[name] += value.item()
    return {
        f"{prefix}/{name}": round(total / len(dataloader), 4)
        for name, total in totals.items()
    }


def step_epoch(e, model, train_dl, val_dl, optimizer, scheduler, val_every=1):
    """Train for one epoch, then (every ``val_every`` epochs) validate.

    Mean per-batch losses are logged to wandb under ``train/...`` and
    ``val/...`` together with the epoch index ``e``.

    Args:
        e: epoch index (logged, and used for the validation cadence).
        model, optimizer, scheduler: passed through to ``train_step``.
        train_dl, val_dl: training and validation dataloaders.
        val_every: validate when ``e % val_every == 0``. The default of 1
            validates every epoch (the previous behavior); 0 disables it.
    """
    train_metrics = _average_losses(
        lambda batch: train_step(model, batch, optimizer, scheduler),
        train_dl,
        "train",
        progress=True,
    )
    wandb.log(train_metrics | {"train/epoch": e})

    if val_every and e % val_every == 0:
        val_metrics = _average_losses(
            lambda batch: val_step(model, batch), val_dl, "val"
        )
        wandb.log(val_metrics | {"val/epoch": e})

In [None]:
# `random` drives the subsample below; torch's RNG drives DataLoader shuffling
# and any random weight initialisation, so seed both for reproducible runs.
SEED = 0
TRAIN_FRACTION = 0.4  # fraction of the train split kept for this run
random.seed(SEED)
torch.manual_seed(SEED)

ds = load_kasa_regression()
n_train = len(ds["train"])
# Sample WITHOUT replacement: the previous random.randint draw could pick the
# same row several times, silently duplicating training examples.
rand_idxs = random.sample(range(n_train), int(TRAIN_FRACTION * n_train))
ds["train"] = ds["train"].select(rand_idxs)

tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
# Fall back to EOS for padding only if the tokenizer has no pad token of its
# own. Overriding an existing pad token changes the ids the collator pads with.
# NOTE(review): confirm the backbone config's pad_token_id matches.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tok_func = partial(tokenize, tokenizer=tokenizer)
ds = ds.map(tok_func, num_proc=8).remove_columns(["smiles", "inchi"])

padding_collator = DataCollatorWithPadding(tokenizer)

In [None]:
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TripletDeltaModel(AutoModel.from_pretrained("DeepChem/ChemBERTa-77M-MTR")).to(device)
train_dl, val_dl = update_dataloaders(ds["train"], ds["test"], model, padding_collator)

optimizer = AdamW(model.parameters(), lr=2e-5)
n_epochs = 10
# train_step steps the scheduler once per batch, so this is the total step
# count. The training loop must run exactly n_epochs epochs to match it.
n_steps = len(train_dl) * n_epochs
warmup_ratio = 0.01
# num_training_steps is the TOTAL number of steps, warmup included. The old
# value n_steps * (1 - warmup_ratio) ended the cosine before training did,
# after which the cosine multiplier starts rising again. Warmup must be an
# integer number of steps.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(n_steps * warmup_ratio),
    num_training_steps=n_steps,
)
# scheduler = get_constant_schedule_with_warmup(optimizer, int(n_steps * warmup_ratio))

wandb.watch(model, log="gradients", log_freq=20)

In [None]:
# Train for exactly n_epochs epochs so the LR schedule (sized for
# len(train_dl) * n_epochs steps) ends where training ends.
for e in tqdm(range(n_epochs)):
    step_epoch(e, model, train_dl, val_dl, optimizer, scheduler)

    # Re-mine triplets with the updated model before the next epoch. The train
    # loader must be rebuilt from the TRAIN split; it was previously rebuilt
    # from ds["val"]. Skip after the final epoch: the loaders would go unused.
    if e < n_epochs - 1:
        train_dl, val_dl = update_dataloaders(
            ds["train"], ds["test"], model, padding_collator
        )