In [1]:
# Re-import edited modules automatically before each cell executes, so changes
# to the project's utils/ and model/ packages take effect without a kernel restart.
%reload_ext autoreload
%autoreload 2

In [2]:
import torch
import random
import numpy as np
import pandas as pd
from torch import nn
from tqdm.notebook import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from transformers.adapters.composition import Stack

In [3]:
from utils import io
from utils import plot
from utils import metric
from model import train_evaluate

from model import xlmr_xnli_model
from model import xlmr_xnli_dataset

from transformers import AutoTokenizer
from transformers import XLMRobertaTokenizer, XLMRobertaModel
from transformers import TrainingArguments, AdapterTrainer

In [4]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [5]:
# One seed shared by every random number generator used in this notebook.
seed = 144
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

# Ask cuDNN for deterministic kernels so GPU runs are repeatable.
torch.backends.cudnn.deterministic = True

## Load data

In [31]:
# Extended XNLI training CSV; per its file name it covers en, hi, sw, zh and es.
train_file = "data/train/extended_train_en_hi_sw_zh_es.csv"
# Loaded once in full; the per-language subsets below are sliced from this frame.
complete_data = io.load_xnli_dataset_csv(train_file)

In [32]:
# English rows only, re-indexed from zero.
data = (
    complete_data
    .loc[complete_data["language"] == "en"]
    .reset_index(drop=True)
)

In [33]:
# Preview the English subset (pandas truncates the display to the first and last rows).
data

Unnamed: 0,gold_label,premise,hypothesis,language
0,neutral,"At ground level, the asymmetrical cathedral is...",It's hard to find a dramatic view of the cathe...,en
1,contradiction,Hanuman is a beneficent deity predating classi...,Hanuman declared that all the lemurs here need...,en
2,contradiction,All other spending as well as federal revenue ...,None of the federal spending is assumed to grow,en
3,neutral,uh-huh that's interesting well it sounds as th...,That information about graduation rates is int...,en
4,neutral,Some kind of instant recognition on his father...,Did his father recognize him?,en
...,...,...,...,...
100973,neutral,Evidence of such low sensitivity in an importa...,Multiple screens tailored to subgroups will pr...,en
100974,contradiction,actually i just put a uh little fence around m...,I have never had groundhogs in my yard.,en
100975,contradiction,If you don't already have a clear idea of what...,THe palace doesn't' reflect him at all.,en
100976,neutral,but i grew up in LA i work out here,I moved to LA when I was 5.,en


In [13]:
# Languages used for this run. The other languages named in the training file
# ('zh', 'es', 'hi', 'sw') can be listed here instead.
languages = ['en']

# Language name -> integer code, following the order of the categorical's categories.
lang_code_map = {name: code for code, name in enumerate(data.language.cat.categories)}
# Integer code -> language name, restricted to the languages selected above.
lang_codes = {lang_code_map[name]: name for name in languages}

In [14]:
# Category vocabularies of the language and label columns; passed to
# io.setup_training together with the hyper-parameters.
dataset_info = {
    'language': data['language'].cat.categories.values,
    'gold_labels': data['gold_label'].cat.categories.values,
}

In [15]:
# Split the English frame into train / validation / test parts, passing the
# selected language codes to the project's splitter.
train_data, valid_data, test_data = io.split_dataset(data, lang_codes=lang_codes)

## Dataloader

In [16]:
# Number of premise/hypothesis pairs per batch, used by every DataLoader below.
batch_size = 32

In [17]:
# Tokenizer for the same "xlm-roberta-base" checkpoint that the model is loaded from.
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

In [18]:
# Build both splits on the notebook-wide `device` instead of a hard-coded
# 'cuda', so this cell also runs on machines without a GPU.
train_dataset = xlmr_xnli_dataset.XLMRXNLIDataset(train_data, tokenizer, device)
valid_dataset = xlmr_xnli_dataset.XLMRXNLIDataset(valid_data, tokenizer, device)

In [19]:
# Settings shared by both loaders; tokenizer.pad is the collate function that
# pads the items of a batch into rectangular tensors.
loader_kwargs = dict(
    batch_size=batch_size,
    drop_last=False,
    num_workers=0,
    collate_fn=tokenizer.pad,
)

# Training batches are reshuffled on every pass over the data.
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# Validation only measures the model, so iterate in a fixed order: shuffling
# adds nothing here and needlessly consumes random state.
valid_dataloader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)

In [20]:
# Grab a single training batch and place it on `device` for interactive inspection.
batch = next(iter(train_dataloader)).to(device)

## Model

In [21]:
from transformers import AutoConfig, AutoAdapterModel
from transformers import TrainingArguments, AdapterTrainer
from transformers import AdapterConfig

# Pretrained XLM-R base encoder, loaded through the adapter-enabled model class.
config = AutoConfig.from_pretrained(
    "xlm-roberta-base",
)
model = AutoAdapterModel.from_pretrained(
    "xlm-roberta-base",
    config=config
)

# Pretrained language adapters, one per language in the training file, all
# loaded with the same "pfeiffer" configuration.
lang_adapter_config = AdapterConfig.load("pfeiffer", reduction_factor=2)
for adapter_id in ("en/wiki@ukp", "zh/wiki@ukp", "hi/wiki@ukp", "es/wiki@ukp", "sw/wiki@ukp"):
    model.load_adapter(adapter_id, config=lang_adapter_config)

# Fresh task adapter that is trained on XNLI further down.
model.add_adapter("xnli")

# Move to the target device only after every adapter has been attached, so the
# adapter weights are transferred together with the base model.
model = model.to(device)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaAdapterModel: ['lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.bias']
- This IS expected if you are initializing XLMRobertaAdapterModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaAdapterModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaAdapterModel were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['roberta.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for prediction

In [22]:
# Select only the "xnli" task adapter for training (per the adapter-transformers
# docs this freezes the remaining weights — TODO confirm for the installed version).
model.train_adapter(["xnli"])
# Set afterwards on purpose: stack the English language adapter with the task adapter.
model.active_adapters = Stack("en", "xnli")

In [23]:
# Constructor arguments for the XNLI wrapper model created in the next cell.
model_params = {
    'model': model,
    'device': device,
    # Dropout settings; presumably applied to the encoder output and the head — confirm in the model class.
    'dropout_params': {'xlmr_drop': 0.5, 'mlp_drop': 0.5},
    # Presumably the head's layer sizes: 768 inputs -> 3 outputs — TODO confirm.
    'layers': [768, 3],
}


In [24]:
# Build the project's XNLI model from the parameters assembled above.
xnli_model = xlmr_xnli_model.XLMRXNLIAdaptorModel(**model_params)

## Training

In [25]:
# Metric name -> scoring function from utils.metric; handed to both the
# training loop and the standalone evaluation call.
metric_params = {
    'accuracy': metric.accuracy, 
    'macro_f1': metric.macro_f1, 
    'average_f1': metric.average_f1,
}

In [26]:
# Base hyper-parameters; io.setup_training turns them into the final train_params.
train_params_base = {
    # Schedule
    'num_epochs': 5,
    'step_size': 3,
    'gamma': 0.1,
    # Optimiser
    'lr': 1e-3,
    'betas': (0.9, 0.999),
    'lrs': [1e-4, 1e-3],
    'lang_codes': lang_codes,
    'weight_decay': 0,
    # Output locations.
    # NOTE(review): the trailing slashes end up as a doubled separator
    # ('LinearHead//R_029') in the resolved paths shown by train_params.
    'save_dir': 'experiments/LinearHead/',
    'save_tag': '',
    'verbose': True,
    # Other values noted by the author: 'last', 'best'.
    'restore_file': None,
    'tensorboard_dir': 'runs/LinearHead/',
    'device': device,
}

In [27]:
# Finalise the training configuration from the base parameters, model parameters
# and dataset info; the displayed result carries run-specific save/tensorboard paths.
train_params = io.setup_training(train_params_base, model_params, dataset_info)

In [28]:
# Show the resolved configuration for this run.
train_params

{'num_epochs': 5,
 'step_size': 3,
 'gamma': 0.1,
 'lr': 0.001,
 'betas': (0.9, 0.999),
 'lrs': [0.0001, 0.001],
 'lang_codes': {4: 'en'},
 'weight_decay': 0,
 'save_dir': 'experiments/LinearHead//R_029/',
 'save_tag': '_029',
 'verbose': True,
 'restore_file': None,
 'tensorboard_dir': 'runs/LinearHead//R_029',
 'device': device(type='cuda')}

In [29]:
# Train on the English loaders, validating after each epoch
# (continue_training=False).
summary = train_evaluate.train_and_evaluate(
    xnli_model,
    train_dataloader,
    valid_dataloader,
    metric_params,
    train_params,
    continue_training=False,
)

  0%|          | 0/5 [00:00<?, ?it/s]

INFO: Epoch 1/5


  0%|          | 0/2209 [00:00<?, ?it/s]

  0%|          | 0/474 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
1,0.594135,0.855172,0.592024,0.596245,311.975201,0.747428,0.61743,0.747529,0.747326,34.259178


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
accuracy,0.596245,0.747326


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
macro_f1,0.592024,0.747529


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
average_f1,0.594135,0.747428


INFO: Epoch 2/5


  0%|          | 0/2209 [00:00<?, ?it/s]

  0%|          | 0/474 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
2,0.744532,0.628622,0.743564,0.745501,309.471278,0.772942,0.577549,0.771423,0.774462,34.253382


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
accuracy,0.745501,0.774462


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
macro_f1,0.743564,0.771423


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
average_f1,0.744532,0.772942


INFO: Epoch 3/5


  0%|          | 0/2209 [00:00<?, ?it/s]

  0%|          | 0/474 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
3,0.769471,0.575731,0.768555,0.770387,310.459919,0.780963,0.565777,0.778749,0.783177,34.442538


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
accuracy,0.770387,0.783177


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
macro_f1,0.768555,0.778749


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
average_f1,0.769471,0.780963


INFO: Epoch 4/5


  0%|          | 0/2209 [00:00<?, ?it/s]

  0%|          | 0/474 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
4,0.792847,0.520961,0.791992,0.793702,310.437948,0.791279,0.53513,0.790336,0.792222,34.670506


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
accuracy,0.793702,0.792222


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
macro_f1,0.791992,0.790336


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
average_f1,0.792847,0.791279


INFO: Epoch 5/5


  0%|          | 0/2209 [00:00<?, ?it/s]

  0%|          | 0/474 [00:00<?, ?it/s]

INFO: - Found new best accuracy.


Unnamed: 0_level_0,Training,Training,Training,Training,Training,Validation,Validation,Validation,Validation,Validation
Unnamed: 0_level_1,Average f1,Loss,Macro f1,Micro f1,Time (secs),Average f1,Loss,Macro f1,Micro f1,Time (secs)
5,0.79699,0.511851,0.796133,0.797847,309.921146,0.792855,0.540744,0.792036,0.793675,34.486546


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
accuracy,0.797847,0.793675


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
macro_f1,0.796133,0.792036


Unnamed: 0_level_0,Training,Validation
Unnamed: 0_level_1,en,en
average_f1,0.79699,0.792855


INFO: - Total training time : 1724.3776422850788 secs


## Evaluation

In [34]:
# Swahili rows only, re-indexed from zero.
data_sw = (
    complete_data
    .loc[complete_data["language"] == "sw"]
    .reset_index(drop=True)
)

In [35]:
# Preview the Swahili subset (pandas truncates the display to the first and last rows).
data_sw

Unnamed: 0,gold_label,premise,hypothesis,language
0,contradiction,Wanaweza pia kuwa wazuri baada ya kufunzwa.,Wanakuwa wagumu na baridi baada ya kukamilisha...,sw
1,entailment,Hauna heshima bwana kama vile nishaona.,Mwanaume anatenda vitu kishamba.,sw
2,contradiction,"Mimi, kwa upande mwingine, huwa na tamaa kama ...",Ni afadhali nile pea 100 kuliko tufaha.,sw
3,contradiction,"Hakika, ni vizuri, miendo zikawa za kasi na ka...",Ilionekana kukawia milele.,sw
4,entailment,Alitafuta faraja katika mstari katika ukurasa ...,Kulikuwa na lugha ya kigeni katika kurasa mbel...,sw
...,...,...,...,...
101973,neutral,Maonyesho ya hisia ya chini katika subgroup mu...,Maonyesho kadhaa yanayohusiana na subgroups it...,sw
101974,contradiction,kwa kweli mimi tu kuweka uh kidogo fence karib...,Sikuwahi kuwa na mbegu katika bahari yangu.,sw
101975,contradiction,Ikiwa wewe tayari hawana ufahamu wazi wa aina ...,Siku hiyo haina kuonyesha kwa hakika.,sw
101976,neutral,Nimeishi huko Texas kando na nataka kuishi huk...,Nimesafiri kwa gari kwa Zaidi ya saa mbili had...,sw


In [37]:
# Replace the English language adapter with the Swahili one, keeping the trained
# "xnli" task adapter stacked on it. `xnli_model.xlmr` is presumably the adapter
# model passed in as model_params['model'] — confirm in the model class.
xnli_model.xlmr.active_adapters = Stack("sw", "xnli")

In [39]:
# Evaluation set for the language switch: the complete Swahili subset.
# NOTE: this rebinds valid_dataset / valid_dataloader, replacing the English
# validation split created earlier in the notebook.
valid_dataset = xlmr_xnli_dataset.XLMRXNLIDataset(data_sw, tokenizer, torch.device('cpu'))

# Evaluation needs no shuffling; a fixed order keeps the pass repeatable and
# leaves the random state untouched.
valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        drop_last=False,
        num_workers=0,
        shuffle=False,
        collate_fn=tokenizer.pad
    )

In [51]:
# Evaluate on Swahili only, reusing the language -> code mapping built for training.
valid_lang = ['sw']
valid_lang_codes = {lang_code_map[name]: name for name in valid_lang}

In [None]:
# NLLLoss consumes log-probabilities — assumes xnli_model emits log-softmax
# outputs; TODO confirm against the model class.
criterion = torch.nn.NLLLoss()
# Score the model on the Swahili loader with the metrics used during training.
valid_metrics = train_evaluate.evaluate(xnli_model, valid_dataloader, criterion, metric_params, 
                                        valid_lang_codes, device)

  0%|          | 0/3187 [00:00<?, ?it/s]

In [49]:
# NOTE(review): the saved output below is keyed by 'en' and shows 0.0 / nan even
# though valid_lang_codes selects 'sw'; it appears to come from an earlier,
# out-of-order execution. Re-run the evaluation cells before relying on it.
valid_metrics

{'loss': 1.0444867271575802,
 'accuracy': {'en': 0.0, 'all': 0.0},
 'macro_f1': {'en': nan, 'all': nan},
 'average_f1': {'en': nan, 'all': nan}}