In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import random
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [3]:
from utils import io
from utils import plot
from utils import metric
from model import train_evaluate

from model import xlmr_xnli_model
from model import xlmr_xnli_dataset

from transformers import XLMRobertaTokenizer, XLMRobertaModel

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
seed = 144
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

## XNLI

In [20]:
batch_size = 32

### load data

In [7]:
"""
train_file = "data/train/train.tsv"
data = io.load_xnli_dataset(train_file)

train_file = "data/train/extended_train.csv"
"""

l_code = "sw"
train_file = "data/train/extended_train_en_hi_sw_zh_es.csv"
data = io.load_xnli_dataset_csv(train_file)

In [8]:
data = data[data.language == l_code].reset_index(drop=True)

In [10]:
data.head()

Unnamed: 0,gold_label,premise,hypothesis,language
0,contradiction,Wanaweza pia kuwa wazuri baada ya kufunzwa.,Wanakuwa wagumu na baridi baada ya kukamilisha...,sw
1,entailment,Hauna heshima bwana kama vile nishaona.,Mwanaume anatenda vitu kishamba.,sw
2,contradiction,"Mimi, kwa upande mwingine, huwa na tamaa kama ...",Ni afadhali nile pea 100 kuliko tufaha.,sw
3,contradiction,"Hakika, ni vizuri, miendo zikawa za kasi na ka...",Ilionekana kukawia milele.,sw
4,entailment,Alitafuta faraja katika mstari katika ukurasa ...,Kulikuwa na lugha ya kigeni katika kurasa mbel...,sw


### dataset information

In [11]:
languages = [l_code]

lang_code_map = {x:i for i, x in enumerate(data.language.cat.categories)}
lang_codes = {lang_code_map[lang]: lang for lang in languages}

In [12]:
dataset_info = {
    'language': data.language.cat.categories.values,
    'gold_labels': data.gold_label.cat.categories.values
}

### train-test split

In [13]:
train_data, valid_data, test_data = io.split_dataset(data, lang_codes=lang_codes)

In [14]:
train_data

Unnamed: 0,gold_label,premise,hypothesis,language
0,contradiction,Wanaweza pia kuwa wazuri baada ya kufunzwa.,Wanakuwa wagumu na baridi baada ya kukamilisha...,sw
1,entailment,Hauna heshima bwana kama vile nishaona.,Mwanaume anatenda vitu kishamba.,sw
2,contradiction,"Mimi, kwa upande mwingine, huwa na tamaa kama ...",Ni afadhali nile pea 100 kuliko tufaha.,sw
3,contradiction,"Hakika, ni vizuri, miendo zikawa za kasi na ka...",Ilionekana kukawia milele.,sw
4,entailment,Alitafuta faraja katika mstari katika ukurasa ...,Kulikuwa na lugha ya kigeni katika kurasa mbel...,sw
...,...,...,...,...
71379,neutral,"Je, una timu yako favorite ya basket?","Je, una timu yako favorite ya mpira wa miguu a...",sw
71380,entailment,Sisi tunajua kwamba makampuni ni kuitwa kufany...,Makampuni yanayohitajika yanajulikana na sisi.,sw
71381,contradiction,Sehemu ya kulinda tovuti inaruhusu wanasheria ...,Mchakato wa mwanasheria wa tovuti hufanya ni v...,sw
71382,contradiction,"Tina Brown na mume wake, Harold Evans, rais wa...",Tina Brown na mume wake ambaye hakuwa rais wa ...,sw


### save test

In [15]:
test_dir = 'data/test/'
test_input_file = f'{test_dir}/test_input'
test_output_file = f'{test_dir}/test_output'

io.save_xnli_test_dataset(test_input_file, test_output_file, test_data)

### dataloader

In [15]:
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")

In [16]:
train_dataset = xlmr_xnli_dataset.XLMRXNLIDataset(train_data, tokenizer, torch.device('cpu'))
valid_dataset = xlmr_xnli_dataset.XLMRXNLIDataset(valid_data, tokenizer, torch.device('cpu'))

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.


In [17]:
train_dataloader = DataLoader(
        train_dataset, 
        batch_size=batch_size,
        drop_last=False,
        num_workers=0,
        shuffle=True,
        collate_fn=tokenizer.pad
    )

valid_dataloader = DataLoader(
        valid_dataset, 
        batch_size=batch_size,
        drop_last=False,
        num_workers=0,
        shuffle=True,
        collate_fn=tokenizer.pad
    )

In [18]:
batch = next(iter(train_dataloader))

### model

In [21]:
model = XLMRobertaModel.from_pretrained("xlm-roberta-base")

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaModel: ['lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [22]:
model_params = {'model': model, 
                'device': device,
                'lstm_params':{
                    'input_size': model.config.hidden_size,
                    'hidden_size':model.config.hidden_size//2,
                    'num_layers': 2,
                    'batch_first':True,
                    'bidirectional':True,
                    'device':device
                },
                'attention_params':{
                    'embed_dim' : model.config.hidden_size, 
                    'num_heads': 3, 
                    'batch_first': True, 
                    'device':device
                },
                'dropout_params':{
                    'xlmr_drop':0.5,
                    'lstm_drop':0.5,
                    'attn_drop':0.5,
                    'mlp_drop':0.5
                },
                'layers': [768, 3]}

xnli_model = xlmr_xnli_model.XLMRXLNIModel(**model_params)

In [23]:
output = xnli_model(batch.to(device))

### training

In [24]:
metric_params = {
    'accuracy': metric.accuracy, 
    'macro_f1': metric.macro_f1, 
    'average_f1': metric.average_f1,
}

In [25]:
train_params_base = {
    'num_epochs': 10,
    'step_size': 3,
    'gamma': 0.1,
    'lr': 1e-3,
    'betas': (0.9, 0.999),
    'lrs': [1e-5, 1e-3, 1e-3],
    'lang_codes': lang_codes,
    'weight_decay': 0,
    'save_dir':'experiments/LinearHead/',
    'save_tag':'',
    'verbose': True,
    'restore_file': None, #last, best
    'tensorboard_dir': 'runs/LinearHead/',
    'device': device
}

In [26]:
train_params = io.setup_training(train_params_base, model_params, dataset_info)

In [27]:
train_params

{'num_epochs': 10,
 'step_size': 3,
 'gamma': 0.1,
 'lr': 0.001,
 'betas': (0.9, 0.999),
 'lrs': [1e-05, 0.001, 0.001],
 'lang_codes': {9: 'sw'},
 'weight_decay': 0,
 'save_dir': 'experiments/LinearHead//R_027/',
 'save_tag': '_027',
 'verbose': True,
 'restore_file': None,
 'tensorboard_dir': 'runs/LinearHead//R_027',
 'device': device(type='cuda')}

In [28]:
xnli_model.freeze_layer()

In [41]:
xnli_model.unfreeze_layer()

In [29]:
summary = train_evaluate.train_and_evaluate(xnli_model, train_dataloader, valid_dataloader, 
                                            metric_params, train_params, continue_training=False)

  0%|          | 0/10 [00:00<?, ?it/s]

INFO: Epoch 1/10


  0%|          | 0/2231 [00:00<?, ?it/s]

  0%|          | 0/478 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
1,0.469887,1.025577,0.467356,0.472417,599.282241,0.546818,0.931818,0.531985,0.56165,38.732266


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
accuracy,0.472417,0.56165


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
macro_f1,0.467356,0.531985


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
average_f1,0.469887,0.546818


INFO: Epoch 2/10


  0%|          | 0/2231 [00:00<?, ?it/s]

  0%|          | 0/478 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
2,0.603871,0.883694,0.603224,0.604519,595.991592,0.627713,0.828083,0.624477,0.630949,38.431696


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
accuracy,0.604519,0.630949


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
macro_f1,0.603224,0.624477


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
average_f1,0.603871,0.627713


INFO: Epoch 3/10


  0%|          | 0/2231 [00:00<?, ?it/s]

  0%|          | 0/478 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
3,0.644581,0.814253,0.643878,0.645285,601.012101,0.640712,0.802171,0.640667,0.640756,38.547373


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
accuracy,0.645285,0.640756


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
macro_f1,0.643878,0.640667


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
average_f1,0.644581,0.640712


INFO: Epoch 4/10


  0%|          | 0/2231 [00:00<?, ?it/s]

  0%|          | 0/478 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
4,0.691483,0.723863,0.690891,0.692074,599.517705,0.6554,0.790691,0.654157,0.656642,38.932231


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
accuracy,0.692074,0.656642


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
macro_f1,0.690891,0.654157


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
average_f1,0.691483,0.6554


INFO: Epoch 5/10


  0%|          | 0/2231 [00:00<?, ?it/s]

  0%|          | 0/478 [00:00<?, ?it/s]

Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
5,0.700065,0.703002,0.699426,0.700703,601.437708,0.651104,0.791314,0.649161,0.653047,38.29248


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
accuracy,0.700703,0.653047


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
macro_f1,0.699426,0.649161


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
average_f1,0.700065,0.651104


INFO: Epoch 6/10


  0%|          | 0/2231 [00:00<?, ?it/s]

  0%|          | 0/478 [00:00<?, ?it/s]

Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
6,0.704372,0.689474,0.703783,0.704962,595.776793,0.644224,0.804459,0.640958,0.64749,38.636896


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
accuracy,0.704962,0.64749


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
macro_f1,0.703783,0.640958


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
average_f1,0.704372,0.644224


INFO: Epoch 7/10


  0%|          | 0/2231 [00:00<?, ?it/s]

  0%|          | 0/478 [00:00<?, ?it/s]

Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
7,0.713035,0.674005,0.712422,0.713647,599.19173,0.64576,0.806674,0.642788,0.648732,39.292889


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
accuracy,0.713647,0.648732


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
macro_f1,0.712422,0.642788


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
average_f1,0.713035,0.64576


INFO: Epoch 8/10


  0%|          | 0/2231 [00:00<?, ?it/s]

  0%|          | 0/478 [00:00<?, ?it/s]

Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
8,0.71228,0.673893,0.711655,0.712905,599.568101,0.646059,0.809084,0.643189,0.648928,38.700843


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
accuracy,0.712905,0.648928


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
macro_f1,0.711655,0.643189


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,sw,sw
average_f1,0.71228,0.646059


INFO: Epoch 9/10


  0%|          | 0/2231 [00:00<?, ?it/s]

KeyboardInterrupt: 