In [None]:
import random
from dataclasses import asdict
from functools import partial

import wandb
from torch.optim import AdamW
from transformers import (
    AutoModel,
    AutoTokenizer,
    DataCollatorWithPadding,
    get_constant_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
)

from config import TrainArgs
from data import (
    load_kasa_regression,
    multiple_contrastive_collate,
    tokenize,
)
from model import MultiContrastiveModel
from train import MultiContrastiveTrainer, initialize_train_dataloader
from utils import set_seeds

# Experiment tracking identifiers (change these to start a new W&B run).
WANDB_PROJECT = "delta"
WANDB_RUN_NAME = "init_v10"

# Default TrainArgs, overriding only the dataloader worker count
# (0 presumably means loading happens in the notebook process — confirm in config.py).
args = TrainArgs(num_dataloader_workers=0)
run_config = asdict(args)
wandb.init(project=WANDB_PROJECT, name=WANDB_RUN_NAME, config=run_config)

# Seed via the project helper before any random sampling in later cells.
set_seeds(args.random_seed)

In [None]:
# Load the KASA regression splits (expects "train" and "test" keys, as used below).
ds = load_kasa_regression(args)
n_train = len(ds["train"])

# Undersample the training split *without replacement*: drawing indices with
# random.randint in a loop duplicated some rows and dropped others.
# NOTE(review): relies on set_seeds() seeding the stdlib `random` module — confirm in utils.py.
undersample_ratio = args.train_undersample_ratio
if not 0.0 <= undersample_ratio <= 1.0:
    raise ValueError(
        f"train_undersample_ratio must be in [0, 1], got {undersample_ratio}"
    )
n_keep = int(undersample_ratio * n_train)
rand_idxs = random.sample(range(n_train), k=n_keep)

ds["train"] = ds["train"].select(rand_idxs)

tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
# Fall back to EOS as the pad token only when the tokenizer lacks one, and do it
# before .map() so tokenization and collation see the same tokenizer state.
# NOTE(review): the original always overwrote pad_token with eos_token; if that was
# deliberate for this checkpoint, restore the unconditional assignment.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

tok_func = partial(tokenize, tokenizer=tokenizer)
ds = ds.map(tok_func, num_proc=8).remove_columns(["smiles", "inchi"])

# Dynamic per-batch padding using the tokenizer's pad token.
padding_collator = DataCollatorWithPadding(tokenizer)

In [None]:
# Wrap the pretrained ChemBERTa encoder in the contrastive model and move it to the GPU.
model = MultiContrastiveModel(
    AutoModel.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
).cuda()

train_dl = initialize_train_dataloader(ds["train"], model, padding_collator, args)

optimizer = AdamW(model.parameters(), lr=args.lr)

# Schedule length assumes scheduler.step() is called once per training batch
# (TODO confirm in train.py, e.g. if gradient accumulation is used).
n_epochs = args.n_epochs
n_steps = len(train_dl) * n_epochs
warmup_ratio = args.warmup_ratio
n_warmup_steps = int(n_steps * warmup_ratio)
# num_training_steps is the TOTAL number of steps, warmup included. Passing the
# post-warmup count made the cosine reach zero early and then rise again for the
# final part of training.
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=n_warmup_steps, num_training_steps=n_steps
)

trainer = MultiContrastiveTrainer(model, optimizer, scheduler, train_dl, ds["test"], args)

In [None]:
# Launch training with the model, optimizer, scheduler, train dataloader and test split assembled above.
trainer.train()