In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import os
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm.notebook import tqdm

In [3]:
from utils import io
from utils import plot
from utils import metric
from model import train_evaluate

from model import xlmr_xnli_model
from model import xlmr_xnli_dataset

from transformers import XLMRobertaTokenizer, XLMRobertaModel

In [4]:
# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
# Seed every RNG the notebook relies on (Python, NumPy, torch) so data
# shuffling, dropout and head initialisation repeat across runs.
seed = 144
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# cuDNN: request deterministic kernels AND switch benchmark (autotuned kernel
# selection, which can differ between runs) off explicitly, instead of relying
# on the default -- a live kernel may already have had it turned on.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

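## XNLI

Train an `xlm-roberta-base` encoder with a BiLSTM + multi-head-attention classification head on the extended multilingual training set (`extended_train_en_hi_sw_zh_es.csv`), reporting per-language metrics for zh, es, hi and sw.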

In [9]:
# Examples per step, shared by the training and validation dataloaders.
batch_size: int = 16

### load data

In [7]:
"""
train_file = "data/train/train.tsv"
data = io.load_xnli_dataset(train_file)

train_file = "data/train/extended_train.csv"
"""

train_file = "data/train/extended_train_en_hi_sw_zh_es.csv"
data = io.load_xnli_dataset_csv(train_file)

In [10]:
# Fail fast if the loader's schema changes: later cells call `.cat` on
# `language` and `gold_label`, and premise/hypothesis are presumably read by
# XLMRXNLIDataset.
assert {'gold_label', 'premise', 'hypothesis', 'language'} <= set(data.columns), \
    f'unexpected columns: {list(data.columns)}'
assert isinstance(data.language.dtype, pd.CategoricalDtype) \
    and isinstance(data.gold_label.dtype, pd.CategoricalDtype), \
    'language and gold_label must be categorical (later cells use .cat)'
data

Unnamed: 0,gold_label,premise,hypothesis,language
0,neutral,"At ground level, the asymmetrical cathedral is...",It's hard to find a dramatic view of the cathe...,en
1,contradiction,Hanuman is a beneficent deity predating classi...,Hanuman declared that all the lemurs here need...,en
2,contradiction,All other spending as well as federal revenue ...,None of the federal spending is assumed to grow,en
3,neutral,uh-huh that's interesting well it sounds as th...,That information about graduation rates is int...,en
4,neutral,Some kind of instant recognition on his father...,Did his father recognize him?,en
...,...,...,...,...
518885,neutral,एक महत्वपूर्ण उपसमूह में इस तरह की कम संवेदनशी...,उपसमूहों के लिए अनुकूलित कई स्क्रीन अधिक व्याप...,hi
518886,contradiction,वास्तव में मैं सिर्फ मेरे खेत के चारों ओर एक छ...,मैंने अपने जंगल में कभी जमीन नहीं रखी है।,hi
518887,contradiction,यदि आपके पास पहले से ही स्पष्ट विचार नहीं है क...,यह महल उसे बिल्कुल भी नहीं दर्शाता है।,hi
518888,neutral,लेकिन मैं एलए में बड़ा हुआ मैं यहां काम कर रहा...,मैं 5 साल की उम्र में एलए में चला गया।,hi


### dataset information

In [11]:
# Target languages handed to the splitter and reported separately in the
# per-language metric tables.
languages = ['zh', 'es', 'hi', 'sw']

# language -> integer category code, then the inverse (code -> language)
# restricted to the target languages, kept in `languages` order.
lang_code_map = {category: code for code, category in enumerate(data.language.cat.categories)}
lang_codes = dict((lang_code_map[lang], lang) for lang in languages)

In [12]:
# Category vocabularies of the loaded data (languages and NLI labels).
dataset_info = dict(
    language=data.language.cat.categories.values,
    gold_labels=data.gold_label.cat.categories.values,
)

### train-test split

In [13]:
# Train / validation / test split; how `lang_codes` influences the split is
# defined in utils.io.
(train_data,
 valid_data,
 test_data) = io.split_dataset(data, lang_codes=lang_codes)

In [14]:
# Preview the training split (pandas truncates the display).
train_data

Unnamed: 0,gold_label,premise,hypothesis,language
0,neutral,那里不过是一片沙漠，跑道上有灌木。,沿着河床有风滚草摇曳。,zh
1,entailment,他们在范围内，Ogle尖叫道。,Ogle说他们可以到达。,zh
2,entailment,迈克尔·桑托来自纽约布法罗市消防公司，他们就是那个发明并生产了高氧气调节器的公司，在那之前他...,Santo住在纽约，并从事高氧调节器相关的工作。,zh
3,neutral,所以当他们告诉她她必须和这个男人回家时，她说，和他一起回家?,他们告诉她，她必须和那个男人睡觉。,zh
4,entailment,这是一项光荣的服务。,这是一项杰出和卓越的服务。,zh
...,...,...,...,...
396480,entailment,Họ đã hỏi một vài câu hỏi và tôi trả lời họ và...,Họ bảo tôi lấy túi của mình.,vi
396481,neutral,phải một số nhóm lợi ích đặc biệt,Nhóm này quan tâm tới các vấn đề về môi trường.,vi
396482,entailment,bạn có bạn có cắm trại hoang dã không,Bạn có tham dự trại về hoang dã không?,vi
396483,contradiction,Rõ ràng trẻ em ngày nay dành quá nhiều giờ ở p...,"TV chưa được phát minh, đó là lý do tại sao hầ...",vi


### Save the test split

In [15]:
# Persist the held-out test split as separate input and gold-output files.
# os.path.join copes with the trailing slash in test_dir, giving
# 'data/test/test_input' rather than 'data/test//test_input'.
test_dir = 'data/test/'
test_input_file = os.path.join(test_dir, 'test_input')
test_output_file = os.path.join(test_dir, 'test_output')

io.save_xnli_test_dataset(test_input_file, test_output_file, test_data)

### dataloader

In [16]:
# Tokenizer for the same `xlm-roberta-base` checkpoint as the encoder.
tokenizer = XLMRobertaTokenizer.from_pretrained(
    "xlm-roberta-base",
)

In [17]:
# Datasets for the train and validation splits, both built on the CPU
# (batches are moved to `device` later).
train_dataset, valid_dataset = (
    xlmr_xnli_dataset.XLMRXNLIDataset(split, tokenizer, torch.device('cpu'))
    for split in (train_data, valid_data)
)

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pai

In [18]:
# Settings shared by both loaders; `tokenizer.pad` is the collate function
# that pads each list of encoded examples into one batch.
loader_kwargs = dict(
    batch_size=batch_size,
    drop_last=False,
    num_workers=0,
    collate_fn=tokenizer.pad,
)

train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# Validation metrics are aggregates, so shuffling buys nothing. A shuffled
# loader without its own generator also draws from the global torch RNG on
# every pass, which would shift the training shuffle of later epochs.
valid_dataloader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

In [19]:
# Take one batch from the training loader for a quick forward-pass check of the model.
batch = next(iter(train_dataloader))

### model

In [20]:
# Pretrained XLM-R encoder without the LM head (the checkpoint's lm_head.*
# weights are reported as unused when it loads).
model = XLMRobertaModel.from_pretrained(
    "xlm-roberta-base",
)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaModel: ['lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [21]:
# Configuration of the XLM-R encoder + BiLSTM + multi-head-attention classifier.
model_params = {
    'model': model,
    'device': device,
    # Bidirectional with hidden_size // 2 per direction, so (assuming these are
    # forwarded to nn.LSTM) the concatenated output matches the encoder width.
    'lstm_params': {
        'input_size': model.config.hidden_size,
        'hidden_size': model.config.hidden_size // 2,
        'num_layers': 2,
        'batch_first': True,
        'bidirectional': True,
        'device': device,
    },
    # embed_dim must be divisible by num_heads (768 / 3 = 256 for xlm-roberta-base).
    'attention_params': {
        'embed_dim': model.config.hidden_size,
        'num_heads': 3,
        'batch_first': True,
        'device': device,
    },
    'dropout_params': {
        'xlmr_drop': 0.5,
        'lstm_drop': 0.5,
        'attn_drop': 0.5,
        'mlp_drop': 0.5,
    },
    # Presumably the MLP head's widths (input -> output). Derived from the
    # encoder width and the label vocabulary instead of the hard-coded
    # [768, 3], so a different checkpoint or label set stays consistent.
    'layers': [model.config.hidden_size, len(dataset_info['gold_labels'])],
}

xnli_model = xlmr_xnli_model.XLMRXLNIModel(**model_params)

In [21]:
# One forward pass on a single batch as a wiring / shape check.
# torch.no_grad() stops the global `output` from keeping the autograd graph
# (and its activations on `device`) alive for the rest of the session.
with torch.no_grad():
    output = xnli_model(batch.to(device))

### training

In [22]:
# Metric name -> metric function, reported overall and per language.
metric_params = dict(
    accuracy=metric.accuracy,
    macro_f1=metric.macro_f1,
    average_f1=metric.average_f1,
)

In [23]:
# Base training configuration; io.setup_training adds the numbered run dirs/tag.
train_params_base = {
    'num_epochs': 10,
    # Presumably a step LR schedule: multiply the lr by `gamma` every `step_size` epochs.
    'step_size': 3,
    'gamma': 0.1,
    'lr': 1e-3,
    'betas': (0.9, 0.999),
    # Per-parameter-group learning rates (group order is defined by
    # train_evaluate -- TODO confirm). These are the values the recorded run
    # actually used; setting them here avoids patching train_params['lrs']
    # after io.setup_training has already built the configuration.
    'lrs': [1e-4, 1e-3, 1e-3],
    'lang_codes': lang_codes,
    'weight_decay': 0,
    # NOTE(review): the head configured in model_params is LSTM + attention,
    # not linear; consider renaming these experiment directories.
    'save_dir': 'experiments/LinearHead/',
    'save_tag': '',
    'verbose': True,
    'restore_file': None,  # None, 'last' or 'best'
    'tensorboard_dir': 'runs/LinearHead/',
    'device': device,
}

In [24]:
# Resolve the run configuration. Judging by its displayed output, this adds a
# numbered run directory and tag (e.g. R_024) under save_dir / tensorboard_dir.
train_params = io.setup_training(
    train_params_base,
    model_params,
    dataset_info,
)

In [33]:
# Display the resolved training configuration.
# NOTE(review): this cell ran (In [33]) after the lrs override (In [32]), so the
# displayed 'lrs' already includes that override rather than train_params_base.
train_params

{'num_epochs': 10,
 'step_size': 3,
 'gamma': 0.1,
 'lr': 0.001,
 'betas': (0.9, 0.999),
 'lrs': [0.0001, 0.001, 0.001],
 'lang_codes': {14: 'zh', 5: 'es', 7: 'hi', 9: 'sw'},
 'weight_decay': 0,
 'save_dir': 'experiments/LinearHead//R_024/',
 'save_tag': '_024',
 'verbose': True,
 'restore_file': None,
 'tensorboard_dir': 'runs/LinearHead//R_024',
 'device': device(type='cuda')}

In [26]:
# Presumably freezes the pretrained encoder's parameters -- TODO confirm in xlmr_xnli_model.
# NOTE(review): the unfreeze_layer() cell runs before training (In [34] < In [35]),
# so this freeze has no effect on the recorded run; keep only the intended one.
xnli_model.freeze_layer()

In [32]:
# NOTE(review): mutates the configuration after io.setup_training built it; anything
# setup_training may have saved or derived from 'lrs' will not reflect this override.
train_params['lrs'] = [1e-4, 1e-3, 1e-3]

In [34]:
# Make the layers frozen by freeze_layer() trainable again before training
# (presumably the pretrained encoder -- TODO confirm).
xnli_model.unfreeze_layer()

In [35]:
# Train for train_params['num_epochs'] epochs, validating after each epoch.
# NOTE(review): in the recorded run, validation accuracy stays at 0.3537 and
# macro-F1 at 0.1742 from epoch 1 to 10. That equals (2*0.3537/1.3537)/3, i.e.
# the model predicts a single class; revisit the learning rates / freezing.
# NOTE(review): the per-step progress bars exceed Jupyter's IOPub message rate
# (see the warnings in the output).
summary = train_evaluate.train_and_evaluate(
    xnli_model,
    train_dataloader,
    valid_dataloader,
    metric_params,
    train_params,
    continue_training=True,
)

  0%|          | 0/10 [00:00<?, ?it/s]

INFO: Epoch 1/10


  0%|          | 0/24781 [00:00<?, ?it/s]

  0%|          | 0/3824 [00:00<?, ?it/s]

Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
1,0.298334,1.091402,0.24197,0.354699,3550.69421,0.263937,1.096912,0.174185,0.353689,145.752092


Unnamed: 0_level_0,Training,Training,Training,Training,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,es,hi,sw,zh,es,hi,sw,zh
accuracy,0.357629,0.353749,0.351507,0.355911,0.353687,0.353687,0.353687,0.353695


Unnamed: 0_level_0,Training,Training,Training,Training,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,es,hi,sw,zh,es,hi,sw,zh
macro_f1,0.244819,0.240857,0.239586,0.242614,0.174185,0.174185,0.174185,0.174188


Unnamed: 0_level_0,Training,Training,Training,Training,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,es,hi,sw,zh,es,hi,sw,zh
average_f1,0.301224,0.297303,0.295547,0.299262,0.263936,0.263936,0.263936,0.263941


INFO: Epoch 2/10


  0%|          | 0/24781 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
4,0.261903,1.097504,0.174308,0.349498,3567.222253,0.263937,1.096966,0.174185,0.353689,146.697897


Unnamed: 0_level_0,Training,Training,Training,Training,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,es,hi,sw,zh,es,hi,sw,zh
accuracy,0.349462,0.349574,0.34928,0.349674,0.353687,0.353687,0.353687,0.353695


Unnamed: 0_level_0,Training,Training,Training,Training,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,es,hi,sw,zh,es,hi,sw,zh
macro_f1,0.174401,0.174306,0.174026,0.174499,0.174185,0.174185,0.174185,0.174188


Unnamed: 0_level_0,Training,Training,Training,Training,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,es,hi,sw,zh,es,hi,sw,zh
average_f1,0.261931,0.26194,0.261653,0.262087,0.263936,0.263936,0.263936,0.263941


INFO: Epoch 5/10


  0%|          | 0/24781 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

  0%|          | 0/3824 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

  0%|          | 0/3824 [00:00<?, ?it/s]

Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
10,0.261122,1.097486,0.17268,0.349564,3553.859221,0.263937,1.096913,0.174185,0.353689,145.79101


Unnamed: 0_level_0,Training,Training,Training,Training,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,es,hi,sw,zh,es,hi,sw,zh
accuracy,0.34956,0.34956,0.34956,0.349576,0.353687,0.353687,0.353687,0.353695


Unnamed: 0_level_0,Training,Training,Training,Training,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,es,hi,sw,zh,es,hi,sw,zh
macro_f1,0.172679,0.17268,0.172679,0.172684,0.174185,0.174185,0.174185,0.174188


Unnamed: 0_level_0,Training,Training,Training,Training,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,es,hi,sw,zh,es,hi,sw,zh
average_f1,0.261119,0.26112,0.261119,0.26113,0.263936,0.263936,0.263936,0.263941


INFO: - Total training time : 37032.05680214055 secs
