In [1]:
import sys
import torch
import pandas as pd
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.utils.data.sampler import SubsetRandomSampler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import BertTokenizer, BertModel
from transformers import RobertaTokenizer, RobertaModel
from transformers import DataCollatorForLanguageModeling

In [3]:
import sys

# Put the parent directory on the import path so the `src` package used by the
# following cell can be found (assumes the kernel's working directory is one
# level below the project root -- TODO confirm).
# Guarded so that re-running this cell does not keep appending duplicates.
if ".." not in sys.path:
    sys.path.append("..")

In [4]:
from src.model import RecoBERT
from src.data import TrainCollator, RecoDataset
from src.train import train

In [5]:
# --- Optimizer settings (consumed by the Adam constructor further down) ---
lr = 0.0001      # learning rate
l2_reg = 0.0     # passed to Adam as weight_decay
beta1 = 0.9      # first element of Adam's betas tuple
beta2 = 0.999    # second element of Adam's betas tuple

# --- Training / data-loading settings ---
epochs = 50      # passed to train() as epochs
batch_size = 32  # DataLoader batch_size
workers = 8      # DataLoader num_workers

# --- Reproducibility ---
# Seed PyTorch's global RNG so the random samplers (and any other torch-based
# randomness) start from a known state on every Restart & Run All.
# NOTE(review): randomness drawn from `random`/NumPy inside src.data, if any,
# is not covered by this call -- confirm and seed there as well if needed.
seed = 42
torch.manual_seed(seed);  # trailing ';' hides the Generator repr in the notebook

In [6]:
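# Alternative backbone (currently disabled): SlovakBERT (https://arxiv.org/abs/2109.15254)
# To use it, uncomment BOTH lines so that tokenizer and encoder come from the same checkpoint.
# tokenizer = RobertaTokenizer.from_pretrained('gerulata/slovakbert')
# bert = RobertaModel.from_pretrained('gerulata/slovakbert')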

In [7]:
# BERT (https://arxiv.org/pdf/1810.04805.pdf)
# Tokenizer and encoder are loaded from the same pretrained checkpoint name.
pretrained_name = "bert-base-cased"
tokenizer = BertTokenizer.from_pretrained(pretrained_name)
bert = BertModel.from_pretrained(pretrained_name)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
# Hugging Face language-modeling collator built on the tokenizer, then wrapped
# by the project's TrainCollator; `collate_fn` is what the DataLoaders receive.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
collate_fn = TrainCollator(tokenizer, collator)

In [9]:
# Read the winemag CSV (first column is used as the DataFrame index) and wrap
# it in the project's RecoDataset.
wines_path = "../data/winemag-data-130k-v3.csv"
wines = pd.read_csv(wines_path, index_col=0)
dataset = RecoDataset(df=wines, swap_prob=0.5)

In [10]:
# Sequential 80/20 split over dataset positions: the first 80% are used for
# training, the remaining 20% for validation.
n_examples = len(dataset)
idxs = list(range(n_examples))
split = int(n_examples * 0.8)
train_idxs = idxs[:split]
val_idxs = idxs[split:]

# Each sampler draws (in random order) only from its own index subset.
train_sampler = SubsetRandomSampler(train_idxs)
val_sampler = SubsetRandomSampler(val_idxs)

# Both loaders read from the same dataset object; only the sampler differs.
loader_kwargs = dict(
    batch_size=batch_size,
    collate_fn=collate_fn,
    num_workers=workers,
)
train_loader = DataLoader(dataset, sampler=train_sampler, **loader_kwargs)
val_loader = DataLoader(dataset, sampler=val_sampler, **loader_kwargs)

In [11]:
# Build RecoBERT from the pretrained encoder and the tokenizer's vocabulary size.
vocab_size = tokenizer.vocab_size
model = RecoBERT(bert, vocab_size)

In [12]:
# Preferred GPU ordinal. The original run targeted "cuda:2", which only exists
# on hosts with at least three visible GPUs.
gpu_index = 2

if not torch.cuda.is_available():
    device = "cpu"
elif gpu_index < torch.cuda.device_count():
    device = f"cuda:{gpu_index}"
else:
    # CUDA is present but the preferred ordinal is not: fall back to the first
    # GPU rather than failing later inside model.to(device).
    device = "cuda:0"

print(f"Using {device} device")
model = model.to(device)

Using cuda:2 device


In [13]:
# Adam over all model parameters, configured from the hyperparameter cell.
optim = Adam(
    model.parameters(),
    lr=lr,
    betas=(beta1, beta2),
    weight_decay=l2_reg,
)

In [14]:
# Collect the arguments for the project's train() routine in one place, then
# launch training and rebind `model` to whatever train() returns.
# NOTE(review): `checkpoint` looks like an output path and `early_stop` like a
# patience value -- confirm against src.train.
train_kwargs = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optim=optim,
    epochs=epochs,
    device=device,
    checkpoint="./checkpoint",
    early_stop=5,
)
model = train(**train_kwargs)

11:08:26 - Epoch 000: Train Loss = 2.9264
11:09:38 - Epoch 000: Val Loss = 1.9746
11:21:41 - Epoch 001: Train Loss = 1.8581
11:22:53 - Epoch 001: Val Loss = 1.6117
11:34:58 - Epoch 002: Train Loss = 1.6205
11:36:10 - Epoch 002: Val Loss = 1.4515
11:48:14 - Epoch 003: Train Loss = 1.4803
11:49:26 - Epoch 003: Val Loss = 1.3901
12:01:29 - Epoch 004: Train Loss = 1.4041
12:02:41 - Epoch 004: Val Loss = 1.2861
12:14:45 - Epoch 005: Train Loss = 1.3390
12:15:57 - Epoch 005: Val Loss = 1.2436
12:28:01 - Epoch 006: Train Loss = 1.2914
12:29:14 - Epoch 006: Val Loss = 1.2027
12:41:18 - Epoch 007: Train Loss = 1.2588
12:42:30 - Epoch 007: Val Loss = 1.1759
12:54:35 - Epoch 008: Train Loss = 1.2266
12:55:47 - Epoch 008: Val Loss = 1.1610
13:07:51 - Epoch 009: Train Loss = 1.1932
13:09:03 - Epoch 009: Val Loss = 1.1230
13:21:09 - Epoch 010: Train Loss = 1.1806
13:22:22 - Epoch 010: Val Loss = 1.1095
13:34:30 - Epoch 011: Train Loss = 1.1593
13:35:42 - Epoch 011: Val Loss = 1.0880
13:47:49 - Epoch

KeyboardInterrupt: 