In [6]:
%load_ext autoreload
%autoreload 2

import argparse
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import yaml
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

import wandb
from src import BertClassifier, train_utils, utils
from src import datasets as data_utils

# Seed the RNGs of the libraries this notebook imports so that repeated
# "Restart & Run All" executions are comparable.
# NOTE(review): helpers inside `src` may draw from other RNGs (e.g. Python's
# `random`) or use non-deterministic kernels — confirm if exact reproducibility
# is required.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Values a reader is most likely to tune, kept in one place.
NUM_EPOCHS = 10
# NOTE(review): -1 presumably means "use every training example" (the recorded
# run tokenised 67349 rows) — confirm against data_utils.create_train_sst2.
NUM_TRAINING_EXAMPLES = -1

device = utils.get_device()

# Load the model parameters from YAML; the keyword arguments are forwarded to
# utils.load_config (presumably as overrides of the file's values — confirm).
config = utils.load_config(
    "model_params/bert_classifier.yaml",
    epochs=NUM_EPOCHS,
    num_training_examples=NUM_TRAINING_EXAMPLES,
)

# Build the SST-2 train and test datasets with the tokenizer name and maximum
# sequence length taken from the config.
train_dataset = data_utils.create_train_sst2(
    device,
    num_samples=config["num_training_examples"],
    tokenizer_name=config["bert_model_name"],
    max_seq_len=config["max_sequence_length"],
)

test_dataset = data_utils.create_test_sst2(
    device,
    tokenizer_name=config["bert_model_name"],
    max_seq_len=config["max_sequence_length"],
)
# Iterate the test set one example at a time, in its original order.
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=1)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67349/67349 [00:05<00:00, 13344.79it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 872/872 [00:00<00:00, 10499.25it/s]


In [9]:
# Long-running cell: the recorded run shows roughly five minutes per progress bar.
# The helper's return value is captured first and then unpacked into the
# trained model, a results DataFrame, and the test loss / accuracy.
training_outputs = train_utils.train_bert_model(train_dataset, test_dataset, config)
full_model, fdf, full_test_loss, full_test_acc = training_outputs

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669839999910133, max=1.0…

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4210/4210 [05:44<00:00, 12.22batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4210/4210 [05:41<00:00, 12.33batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4210/4210 [05:12<00:00, 13.46batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4210/4210 [05:32<00:00, 12.68batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4210/4210 [05:37<00:00, 12.46batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4210/4210 [05:07<00:00, 13.69batch/s]
100%|█████████████████████████████████████████████████████

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▂▃▃▄▅▆▆▇█
train/accuracy,▁▆▆▆▆▆▇▇██
train/batch_loss,▅▃▅▅▃▃▁▃▁▅▄▃▂▅▁▄▅▃▂▃▆▄▂█▂▂▂▅▃▄▅▆▆▃▃▂▅▆▃▆
train/loss,█▄▄▄▄▃▂▂▁▁

0,1
epoch,10.0
test/accuracy,83.83028
test/loss,0.37148
train/accuracy,84.89578
train/batch_loss,0.24677
train/loss,0.34957


In [14]:
# Display the two test metrics unpacked from train_utils.train_bert_model as a tuple.
full_test_loss, full_test_acc

(0.3714797575896015, 83.8302752293578)

In [10]:
# Persist only the state dict of the model's `classifier` attribute.
classifier_weights = full_model.classifier.state_dict()
weights_path = "bert-classifier-10epoch-fulldata.pt"
torch.save(classifier_weights, weights_path)

In [11]:
# Write the results frame to CSV without the index column.
loss_csv_path = "bert-classifier-10epoch-fulldata-loss.csv"
fdf.to_csv(loss_csv_path, index=False)

In [12]:
# Render the results frame (second value returned by train_utils.train_bert_model);
# the recorded output shows test_guid, logits, pred, label and loss columns.
fdf

Unnamed: 0,test_guid,logits,pred,label,loss
0,0,"[-3.4585288, 3.077545]",1,1,0.001449
1,1,"[1.5943244, -1.4615642]",0,0,0.046006
2,2,"[-2.634699, 2.254184]",1,1,0.007502
3,3,"[-1.7075384, 1.5407968]",1,1,0.038104
4,4,"[0.9781762, -1.2377704]",0,0,0.103504
...,...,...,...,...,...
867,867,"[-0.11757666, 0.004934795]",1,0,0.756278
868,868,"[-0.5531204, 0.6147107]",1,1,0.270821
869,869,"[-0.3901366, -0.062573045]",1,0,0.870282
870,870,"[0.020624734, -0.25613683]",0,0,0.564311


In [15]:
# Compute the 25th, 50th and 75th percentiles of the loss column in a single
# call, then unpack them into the threshold names used by the cells below.
loss_quartiles = fdf["loss"].quantile([0.25, 0.5, 0.75])
firstq_loss, median_loss, thirdq_loss = loss_quartiles.to_numpy()

In [16]:
# First two rows whose loss lies in [first quartile, median).
in_second_quartile = fdf["loss"].ge(firstq_loss) & fdf["loss"].lt(median_loss)
fdf.loc[in_second_quartile].head(2)

Unnamed: 0,test_guid,logits,pred,label,loss
1,1,"[1.5943244, -1.4615642]",0,0,0.046006
3,3,"[-1.7075384, 1.5407968]",1,1,0.038104


In [17]:
# First two rows whose loss lies in [median, third quartile).
in_third_quartile = fdf["loss"].ge(median_loss) & fdf["loss"].lt(thirdq_loss)
fdf.loc[in_third_quartile].head(2)

Unnamed: 0,test_guid,logits,pred,label,loss
4,4,"[0.9781762, -1.2377704]",0,0,0.103504
12,12,"[0.6405805, -0.8864901]",0,0,0.196529


In [18]:
# First two rows whose loss is at or above the third quartile.
in_top_quartile = fdf["loss"].ge(thirdq_loss)
fdf.loc[in_top_quartile].head(2)

Unnamed: 0,test_guid,logits,pred,label,loss
11,11,"[0.010494724, -0.08012711]",0,0,0.648862
13,13,"[-0.012740873, -0.54817855]",0,1,0.996283
