In [1]:
import sys
import torch
import pandas as pd
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.utils.data.sampler import SubsetRandomSampler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import BertTokenizer, BertModel
from transformers import DataCollatorForWholeWordMask

In [3]:
import sys
# Put the parent directory on the import path so the `src.*` imports used by
# this notebook can be resolved. The membership check keeps the cell
# idempotent: re-running it no longer piles up duplicate ".." entries.
if ".." not in sys.path:
    sys.path.append("..")

In [5]:
from src.model import RecoBERT
from src.data import CollatorWrapper, RecoDataset
from src.train import train

In [None]:
# Optimiser settings (passed to Adam further down).
lr = 1e-4
l2_reg = 0.0
beta1, beta2 = 0.9, 0.999

# Training-loop and DataLoader settings.
epochs = 100
batch_size = 16
workers = 8

In [None]:
# Tokenizer and encoder are loaded from the same pretrained checkpoint name.
pretrained_name = "bert-base-cased"
tokenizer = BertTokenizer.from_pretrained(pretrained_name)
bert = BertModel.from_pretrained(pretrained_name)

# Whole-word-mask collator, wrapped by the project's CollatorWrapper; the
# wrapper is what the DataLoaders receive as `collate_fn`.
collator = DataCollatorForWholeWordMask(tokenizer)
wrapper = CollatorWrapper(tokenizer, collator)

In [None]:
# Read the wine CSV, using its first column as the DataFrame index, and hand
# the resulting frame to the project's dataset class.
wines_csv = "../data/winemag-data-130k-v2.csv"
wines = pd.read_csv(wines_csv, index_col=0)
dataset = RecoDataset(wines)

In [None]:
# Contiguous 80/20 split over dataset positions: the first 80% of indices go
# to training, the remainder to validation.
n_samples = len(dataset)
idxs = list(range(n_samples))
split = int(n_samples * 0.8)
train_idxs = idxs[:split]
val_idxs = idxs[split:]

# Each sampler draws (in random order) only from its own index subset.
train_sampler = SubsetRandomSampler(train_idxs)
val_sampler = SubsetRandomSampler(val_idxs)

# Both loaders read from the same dataset object and differ only in sampler.
loader_kwargs = dict(batch_size=batch_size, collate_fn=wrapper, num_workers=workers)
train_loader = DataLoader(dataset, sampler=train_sampler, **loader_kwargs)
val_loader = DataLoader(dataset, sampler=val_sampler, **loader_kwargs)

In [24]:
# Instantiate the project's RecoBERT with the pretrained BERT encoder and the
# tokenizer's vocabulary size.
vocab_size = tokenizer.vocab_size
model = RecoBERT(bert, vocab_size)

In [None]:
# Pick the compute device. GPU index 2 is preferred, but only when the machine
# actually exposes that many GPUs; previously "cuda:2" was selected whenever
# *any* GPU was present, which fails with an invalid device ordinal on hosts
# with fewer than three GPUs. Fall back to the default CUDA device, then CPU.
preferred_gpu = 2
if torch.cuda.device_count() > preferred_gpu:
    device = f"cuda:{preferred_gpu}"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")
model = model.to(device)

In [None]:
# Adam over all model parameters; `l2_reg` is supplied as Adam's weight decay.
adam_betas = (beta1, beta2)
optim = Adam(
    model.parameters(),
    lr=lr,
    betas=adam_betas,
    weight_decay=l2_reg,
)

In [None]:
# Launch the project's training routine and rebind `model` to whatever it
# returns. The checkpoint location and early-stop value are named here so they
# are easy to spot and tweak.
checkpoint_path = "./checkpoint"
early_stop_value = 20

model = train(
    model=model,
    optim=optim,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    epochs=epochs,
    checkpoint=checkpoint_path,
    early_stop=early_stop_value,
)