In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import torch
import random
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Seed every random number generator this notebook relies on so that a
# "Restart & Run All" repeats the same shuffling and weight initialisation.
seed = 144
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Seed all CUDA devices explicitly; this call is silently ignored without CUDA.
torch.cuda.manual_seed_all(seed)
# Request deterministic cuDNN kernels and turn off the cuDNN auto-tuner, which
# may otherwise pick different convolution algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Tokenizer

In [None]:
from transformers import XLMRobertaTokenizer, XLMRobertaModel

In [None]:
# Load the tokenizer that belongs to the "xlm-roberta-base" checkpoint; the
# same checkpoint name is used further down when the encoder is loaded.
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

In [7]:
# Tokenize a mixed-case sample sentence to see how the tokenizer splits it
# into sub-word pieces.
sample_sentence = 'Hello WORLD how ARE yoU?'
tokens = tokenizer.tokenize(sample_sentence)
tokens

['▁Hello', '▁', 'WORLD', '▁how', '▁', 'ARE', '▁yo', 'U', '?']

In [10]:
# Convert the sub-word tokens produced above into their vocabulary ids.
indexes = tokenizer.convert_tokens_to_ids(tokens)
indexes

[35378, 6, 99972, 3642, 6, 23711, 3005, 1062, 32]

In [11]:
# Special tokens exposed by the tokenizer, in the order:
# classification/start token, separator/end token, padding token, unknown token.
init_token, eos_token, pad_token, unk_token = (
    tokenizer.cls_token,
    tokenizer.sep_token,
    tokenizer.pad_token,
    tokenizer.unk_token,
)

print(init_token, eos_token, pad_token, unk_token)

<s> </s> <pad> <unk>


In [12]:
# Look up the vocabulary id of each special token, one lookup per token and in
# the same order as the token variables.
init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx = map(
    tokenizer.convert_tokens_to_ids,
    (init_token, eos_token, pad_token, unk_token),
)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

0 2 1 3


In [13]:
# Same ids as the explicit token lookup, read directly from the tokenizer's
# `*_token_id` attributes (the printed values match: 0 2 1 3).
init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx = (
    tokenizer.cls_token_id,
    tokenizer.sep_token_id,
    tokenizer.pad_token_id,
    tokenizer.unk_token_id,
)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

0 2 1 3


In [14]:
# Maximum input length (in tokens) that the loaded tokenizer reports.
# Reading `model_max_length` from the tokenizer instance avoids repeating the
# checkpoint name and does not rely on the static `max_model_input_sizes`
# table, which newer transformers releases deprecate.
max_input_length = tokenizer.model_max_length
print(max_input_length)

512


## Load data

In [None]:
from utils import io

In [None]:
# Location of the training file, relative to the notebook's working directory.
train_file = "data/train/train.tsv"
# `io` is the project's `utils.io` module, not the standard-library `io`.
# The helper returns the table previewed in the next cell plus a `vocab`
# object that is later handed to `io.setup_training`.
data, vocab = io.load_xnli_dataset(train_file)

In [9]:
# Preview the first rows of the loaded data (label, premise, hypothesis, language).
data.head()

Unnamed: 0,gold_label,premise,hypothesis,language
0,neutral,"At ground level, the asymmetrical cathedral is...",It's hard to find a dramatic view of the cathe...,en
1,contradiction,Hanuman is a beneficent deity predating classi...,Hanuman declared that all the lemurs here need...,en
2,contradiction,All other spending as well as federal revenue ...,None of the federal spending is assumed to grow,en
3,neutral,uh-huh that's interesting well it sounds as th...,That information about graduation rates is int...,en
4,neutral,Some kind of instant recognition on his father...,Did his father recognize him?,en


In [10]:
# Partition the loaded data into training, validation and test sets; the split
# sizes are decided inside `io.split_dataset`.
# NOTE(review): the first training rows shown below are identical to the first
# rows of `data`, so the split may not shuffle before partitioning — confirm.
train_data, valid_data, test_data = io.split_dataset(data)

In [12]:
# Preview the first rows of the training partition.
train_data.head()

Unnamed: 0,gold_label,premise,hypothesis,language
0,neutral,"At ground level, the asymmetrical cathedral is...",It's hard to find a dramatic view of the cathe...,en
1,contradiction,Hanuman is a beneficent deity predating classi...,Hanuman declared that all the lemurs here need...,en
2,contradiction,All other spending as well as federal revenue ...,None of the federal spending is assumed to grow,en
3,neutral,uh-huh that's interesting well it sounds as th...,That information about graduation rates is int...,en
4,neutral,Some kind of instant recognition on his father...,Did his father recognize him?,en


In [23]:
# Preview the first rows of the validation partition.
valid_data.head()

Unnamed: 0,gold_label,premise,hypothesis,language
0,entailment,yeah see i come from a Catholic church backgro...,I was raised with the Catholic church.,en
1,neutral,The best case of the agency is best (everyone ...,The best case of the agency is when everyone i...,en
2,neutral,Walk right in.,Do not hesitate to get in.,en
3,entailment,You're sure to slip up sooner or later.,You'll make a mistake at some point.,en
4,neutral,"We requested comments from BEA, OMB, and sever...",Comments from professors were important,en


In [24]:
# Preview the first rows of the test partition.
test_data.head()

Unnamed: 0,gold_label,premise,hypothesis,language
0,neutral,Equal thrills can be had atop the Stratosphere...,You can have a lot of thrills on top of the St...,en
1,contradiction,Table 6 shows various Cpk values and the defec...,Table 6 does not show any Cpk values and assoc...,en
2,contradiction,"As in centuries past, people go on mass pilgri...",The cherry blossoms are not popular in Japan.,en
3,neutral,"For example, the American Institute of Certifi...",The AICPA standard promotes consistency throug...,en
4,neutral,An article says Republicans gave Al Gore a pot...,An article in the NYT says Republicans gave Al...,en


In [11]:
# Report the (rows, columns) shape of each partition, one line per split.
split_reports = (
    f'Training data   : {train_data.shape}',
    f'Validation data : {valid_data.shape}',
    f'Testing data    : {test_data.shape}',
)
for report_line in split_reports:
    print(report_line)

Training data   : (80484, 4)
Validation data : (17246, 4)
Testing data    : (17248, 4)


## Dataloader

In [12]:
from model import xlmr_xnli_model
from model import xlmr_xnli_dataset

In [13]:
# Wrap the training and validation partitions in the project's dataset class,
# passing the same tokenizer and device to both.
train_dataset, valid_dataset = (
    xlmr_xnli_dataset.XLMRXNLIDataset(partition, tokenizer, device)
    for partition in (train_data, valid_data)
)

In [14]:
# Number of examples per mini-batch; used by both the training and the
# validation data loader.
batch_size = 32

In [15]:
# Training loader: batches are reshuffled at the start of every epoch.
train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        drop_last=False,
        num_workers=0,
        shuffle=True,
        #collate_fn=tokenizer.pad
    )

# Validation loader: evaluation gains nothing from shuffling, so the examples
# are visited in a fixed order. This keeps the validation batches identical
# from epoch to epoch and from run to run.
valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        drop_last=False,
        num_workers=0,
        shuffle=False,
        #collate_fn=tokenizer.pad
    )

In [16]:
# Pull a single batch from the training loader as a quick smoke test of the
# dataset/loader pipeline; `batch` is not used by any later cell shown here.
batch = next(iter(train_dataloader))

## Model

In [17]:
# Load the pretrained "xlm-roberta-base" encoder. The warning printed below
# lists the checkpoint's `lm_head.*` weights that XLMRobertaModel does not use.
model = XLMRobertaModel.from_pretrained("xlm-roberta-base")

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaModel: ['lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [18]:
# Keyword arguments for the project's classifier wrapper: the pretrained
# encoder, the target device, and a list of layer sizes.
# NOTE(review): the final size 3 presumably matches the three gold labels
# (neutral / contradiction / entailment) — confirm against the model class.
model_params = dict(
    model=model,
    device=device,
    layers=[1024, 512, 256, 3],
)

In [19]:
# Build the classifier from the parameters collected in `model_params`.
# NOTE(review): the class is spelled `XLMRXLNIModel` ("XLNI") while the dataset
# class is `XLMRXNLIDataset` ("XNLI") — likely a typo in the project module,
# but the name must match its definition there.
xnli_model = xlmr_xnli_model.XLMRXLNIModel(**model_params)

## Training

In [20]:
from utils import plot
from utils import metric
from model import train_evaluate

In [21]:
# Metric name -> metric function, as handed to the training loop.
metric_params = dict(
    accuracy=metric.accuracy,
    macro_f1=metric.macro_f1,
    average_f1=metric.average_f1,
)

In [22]:
# Base training configuration; `io.setup_training` derives the per-run
# settings from it.
# NOTE(review): 'lr' and 'lrs' are both present — presumably a single rate and
# a per-group pair; confirm how the training code uses each.
# NOTE(review): the trailing '/' on the two directories shows up as '//' in
# the derived paths displayed after `io.setup_training` runs.
train_params_base = dict(
    num_epochs=10,
    step_size=3,
    gamma=0.1,
    lr=1e-3,
    betas=(0.9, 0.999),
    lrs=[1e-5, 1e-3],
    weight_decay=0,
    save_dir='experiments/LinearHead/',
    save_tag='',
    verbose=True,
    restore_file=None,  # last, best
    tensorboard_dir='runs/LinearHead/',
)

In [23]:
# Derive the settings for this run from the base configuration. In the result
# displayed below, `save_dir`, `save_tag` and `tensorboard_dir` carry a run
# suffix (here `R_003` / `_003`); the other entries are unchanged.
train_params = io.setup_training(train_params_base, model_params, vocab)

In [24]:
# Display the resolved configuration that will be used for training.
train_params

{'num_epochs': 10,
 'step_size': 3,
 'gamma': 0.1,
 'lr': 0.001,
 'betas': (0.9, 0.999),
 'lrs': [1e-05, 0.001],
 'weight_decay': 0,
 'save_dir': 'experiments/LinearHead//R_003/',
 'save_tag': '_003',
 'verbose': True,
 'restore_file': None,
 'tensorboard_dir': 'runs/LinearHead//R_003'}

In [25]:
# NOTE(review): presumably freezes some or all of the pretrained encoder
# parameters before training — confirm in model/xlmr_xnli_model.py which
# parameters are affected.
xnli_model.freeze_layer()

In [26]:
# Train from scratch (`continue_training=False`) and evaluate after each epoch.
# Per the log below, a training epoch takes about 1650 s and a validation pass
# about 100 s, so the full 10-epoch run is long.
# NOTE(review): the run floods the notebook with messages ("IOPub message rate
# exceeded" below) and the output for epochs 5-6 is missing — consider reducing
# progress/log output or relying on the TensorBoard logs instead.
summary = train_evaluate.train_and_evaluate(xnli_model, train_dataloader, valid_dataloader, 
                                            metric_params, train_params, continue_training=False)

  0%|          | 0/10 [00:00<?, ?it/s]

INFO: Epoch 1/10


  0%|          | 0/2516 [00:00<?, ?it/s]

  0%|          | 0/539 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
1,0.411671,1.029472,0.39343,0.429912,1652.94424,0.67943,0.750433,0.679108,0.679752,101.441507


INFO: Epoch 2/10


  0%|          | 0/2516 [00:00<?, ?it/s]

  0%|          | 0/539 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
2,0.715598,0.704089,0.714668,0.716528,1644.716308,0.780608,0.554938,0.779876,0.781341,101.581599


INFO: Epoch 3/10


  0%|          | 0/2516 [00:00<?, ?it/s]

  0%|          | 0/539 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
3,0.783559,0.571363,0.782564,0.784553,1646.02914,0.797664,0.526884,0.797288,0.79804,101.767623


INFO: Epoch 4/10


  0%|          | 0/2516 [00:00<?, ?it/s]

  0%|          | 0/539 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
4,0.832899,0.453013,0.832116,0.833681,1646.447467,0.799917,0.51358,0.79988,0.799954,101.931465


INFO: Epoch 5/10


  0%|          | 0/2516 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



  0%|          | 0/539 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
7,0.856396,0.393999,0.85574,0.857052,1648.41574,0.806249,0.506121,0.80605,0.806448,101.975174


INFO: Epoch 8/10


  0%|          | 0/2516 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

  0%|          | 0/539 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

