In [1]:
# this automatically reloads the libraries so you can update them dynamically
%load_ext autoreload
%autoreload 2

from models.clip import CLIPModel
from utils.dataset import make_train_valid_dfs, build_loaders, prepare_data
from utils import config
from utils.train import train_epoch, valid_epoch

import torch
from transformers import DistilBertTokenizer
import tqdm.auto as tqdm

In [2]:
# Project data-preparation step imported from utils.dataset. It runs before
# make_train_valid_dfs() in the next cell. NOTE(review): presumably that
# function depends on whatever this step produces (e.g. files on disk) — confirm.
prepare_data()

In [None]:
# Reproducibility: fix torch's RNG before anything stochastic happens here
# (DataLoader shuffling, model weight initialisation). A `seed` attribute on
# `config` is honoured if the project defines one; otherwise 42 is used.
# NOTE(review): if dataset splitting or augmentations draw from `random` /
# `numpy`, those generators need seeding as well.
SEED = getattr(config, "seed", 42)
torch.manual_seed(SEED)

# Data: split into train/validation frames and wrap each in a loader.
train_df, valid_df = make_train_valid_dfs()
tokenizer = DistilBertTokenizer.from_pretrained(config.text_tokenizer)
train_loader = build_loaders(train_df, tokenizer, mode="train")
valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

# Model and optimisation.
model = CLIPModel().to(config.device)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=config.lr, weight_decay=config.weight_decay
)
# ReduceLROnPlateau only lowers the LR when .step(metric) is called with the
# monitored value (mode="min": lower is better, e.g. validation loss).
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=config.patience, factor=config.factor
)
# Passed to train_epoch; presumably selects per-"batch" vs per-"epoch"
# scheduler stepping — confirm against utils.train.
step = "epoch"

In [None]:
# Train for config.epochs epochs, checkpointing the weights with the lowest
# validation loss to "best.pt" and letting the plateau scheduler react to it.
best_loss = float('inf')
for epoch in range(config.epochs):
    print(f"Epoch: {epoch + 1}")
    model.train()
    train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
    model.eval()
    with torch.no_grad():
        valid_loss = valid_epoch(model, valid_loader)

    # Reduce the validation loss to a plain float so it can be compared,
    # printed and fed to the scheduler. NOTE(review): getattr covers the case
    # where valid_epoch returns a meter object with an `.avg` attribute
    # instead of a number — confirm against utils.train.
    valid_value = float(getattr(valid_loss, "avg", valid_loss))
    print(f"valid loss: {valid_value:.4f}")

    if valid_value < best_loss:
        best_loss = valid_value
        torch.save(model.state_dict(), "best.pt")
        print("Saved Best Model!")

    # ReduceLROnPlateau never changes the LR unless step() receives the metric
    # it monitors, so step it once per epoch with the validation loss.
    # Assumes train_epoch does not also step it when step == "epoch" — confirm.
    lr_scheduler.step(valid_value)