<a href="https://colab.research.google.com/github/tilaboy/nlp_transformer_tutorial/blob/main/learning_notes/ch4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install datasets --quiet
!pip install transformers --quiet
!pip install tensorflow --quiet
!pip install pandas --quiet
!pip install numpy --quiet
!pip install seqeval --quiet
!pip install torch --quiet
!pip install sklearn --quiet
!pip install matplotlib --quiet

In [2]:
from datasets import load_dataset, DatasetDict
from datasets import get_dataset_config_names, concatenate_datasets
from collections import defaultdict, Counter
import pandas as pd
import numpy as np
import torch.nn as nn
from torch.nn.functional import cross_entropy
import torch
from transformers import AutoTokenizer, AutoConfig, TrainingArguments
from transformers import XLMRobertaConfig, Trainer
from transformers import DataCollatorForTokenClassification
from transformers.modeling_outputs import TokenClassifierOutput 
from transformers.models.roberta.modeling_roberta import RobertaModel 
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
from seqeval.metrics import classification_report, f1_score
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
import matplotlib.pyplot as plt

In [3]:
# Fetch the names of every configuration of the "xtreme" dataset.
xtreme_subsets = get_dataset_config_names("xtreme")
print(f"XTREME has {len(xtreme_subsets)} configurations") 
# Keep only the configurations whose name starts with "PAN" (the PAN-X subsets).
panx_subsets = [s for s in xtreme_subsets if s.startswith("PAN")]
print(f'nr of languages in PAN dataset: {len(panx_subsets)}')
# The last two characters of each configuration name are its language code.
print([set_name[-2:] for set_name in panx_subsets])

XTREME has 183 configurations
nr of languages in PAN dataset: 40
['af', 'ar', 'bg', 'bn', 'de', 'el', 'en', 'es', 'et', 'eu', 'fa', 'fi', 'fr', 'he', 'hi', 'hu', 'id', 'it', 'ja', 'jv', 'ka', 'kk', 'ko', 'ml', 'mr', 'ms', 'my', 'nl', 'pt', 'ru', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'ur', 'vi', 'yo', 'zh']


In [4]:
# Languages of the corpus and the fraction of each split to keep per language
# (the two lists are paired positionally by the zip below).
langs = ["de", "fr", "it", "en"] 
fracs = [0.629, 0.229, 0.084, 0.059] 
# Return a DatasetDict if a key doesn't exist 
panx_ch = defaultdict(DatasetDict) 
for lang, frac in zip(langs, fracs): 
    # Load monolingual corpus 
    ds = load_dataset("xtreme", name=f"PAN-X.{lang}") 
    # Shuffle and downsample each split according to spoken proportion 
    for split in ds:
        # The extra "/ 4" shrinks every split to a quarter of its proportional
        # size (e.g. 3145 instead of 12580 German training rows).
        # NOTE(review): presumably to shorten training time -- confirm.
        nr_to_select = int(frac * ds[split].num_rows / 4)
        print(f'{lang}-{split}: {nr_to_select} out of {ds[split].num_rows}')
        # Fixed seed keeps the selected subset identical across runs.
        panx_ch[lang][split] = ( ds[split].shuffle(seed=0).select(range(nr_to_select)))


Reusing dataset xtreme (/root/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-895dfb6c4273b2e5.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-441068b3ea7cbef3.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-1ec023bbd4560444.arrow


de-train: 3145 out of 20000
de-validation: 1572 out of 10000
de-test: 1572 out of 10000


Reusing dataset xtreme (/root/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-c99e2d963e99c3bc.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-95d424970390df95.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.fr/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-c8f8d90ec625373b.arrow


fr-train: 1145 out of 20000
fr-validation: 572 out of 10000
fr-test: 572 out of 10000


Reusing dataset xtreme (/root/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-127e932015d0d753.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-f7bdfc2c46f67b82.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.it/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-affab8fcf3f5bef3.arrow


it-train: 420 out of 20000
it-validation: 210 out of 10000
it-test: 210 out of 10000


Reusing dataset xtreme (/root/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-739f170d5471f1cb.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-3e10f730fe826b06.arrow
Loading cached shuffled indices for dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.en/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-5314845d936312dc.arrow


en-train: 295 out of 20000
en-validation: 147 out of 10000
en-test: 147 out of 10000


In [5]:
for attr, attr_value in panx_ch["de"]["train"][0].items():
    print(attr, attr_value)

for attr, attr_value in panx_ch["de"]["train"].features.items():
    print(attr, attr_value)

tokens ['2.000', 'Einwohnern', 'an', 'der', 'Danziger', 'Bucht', 'in', 'der', 'polnischen', 'Woiwodschaft', 'Pommern', '.']
ner_tags [0, 0, 0, 0, 5, 6, 0, 0, 5, 5, 6, 0]
langs ['de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de']
tokens Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)
ner_tags Sequence(feature=ClassLabel(num_classes=7, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'], id=None), length=-1, id=None)
langs Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)


In [6]:
tags = panx_ch["de"]["train"].features["ner_tags"].feature
print(tags)
def create_tag_names(batch):
    """Return a `ner_tags_str` column with the string name of every tag id.

    `batch` is the mapping handed over by `Dataset.map`; its `ner_tags` entry
    is iterated and each id is converted with the module-level `tags` feature.
    """
    tag_names = []
    for tag_id in batch['ner_tags']:
        tag_names.append(tags.int2str(tag_id))
    return {'ner_tags_str': tag_names}
  
panx_de = panx_ch["de"].map(create_tag_names)
de_example = panx_de["train"][0]
print(de_example)


Loading cached processed dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-f006d4fd6dbd3aa9.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-45042c779ba3dce9.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-aee8454953a922ee.arrow


ClassLabel(num_classes=7, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'], id=None)
{'tokens': ['2.000', 'Einwohnern', 'an', 'der', 'Danziger', 'Bucht', 'in', 'der', 'polnischen', 'Woiwodschaft', 'Pommern', '.'], 'ner_tags': [0, 0, 0, 0, 5, 6, 0, 0, 5, 5, 6, 0], 'langs': ['de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de', 'de'], 'ner_tags_str': ['O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'O', 'O', 'B-LOC', 'B-LOC', 'I-LOC', 'O']}


In [7]:
split2freqs = defaultdict(Counter) 
for split, dataset in panx_de.items(): 
    for row in dataset["ner_tags_str"]: 
        for tag in row: 
            if tag.startswith("B"): 
                tag_type = tag.split("-")[1] 
                split2freqs[split][tag_type] += 1

print(split2freqs)


defaultdict(<class 'collections.Counter'>, {'train': Counter({'PER': 1546, 'LOC': 1523, 'ORG': 1319}), 'validation': Counter({'LOC': 764, 'PER': 755, 'ORG': 653}), 'test': Counter({'LOC': 833, 'PER': 768, 'ORG': 640})})


In [8]:
pd.DataFrame.from_dict(split2freqs, orient="index")

Unnamed: 0,LOC,ORG,PER
train,1523,1319,1546
validation,764,653,755
test,833,640,768


In [9]:
bert_model_name = "bert-base-cased" 
xlmr_model_name = "xlm-roberta-base" 
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name) 
xlmr_tokenizer = AutoTokenizer.from_pretrained(xlmr_model_name)


In [10]:
text = "Jack Sparrow loves New York!" 
bert_tokens = bert_tokenizer(text).tokens() 
xlmr_tokens = xlmr_tokenizer(text).tokens()
print(bert_tokens)
print(xlmr_tokens)
#"".join(xlmr_tokens).replace(u"\u2581", " ")
"".join(xlmr_tokens).replace("▁", " ")

['[CLS]', 'Jack', 'Spa', '##rrow', 'loves', 'New', 'York', '!', '[SEP]']
['<s>', '▁Jack', '▁Spar', 'row', '▁love', 's', '▁New', '▁York', '!', '</s>']


'<s> Jack Sparrow loves New York!</s>'

In [11]:
class XLMRobertaForTokenClassification(RobertaPreTrainedModel):
    """Token-classification model: a RoBERTa encoder body plus a linear head.

    The head applies dropout and a single linear layer to every token
    representation, producing `config.num_labels` logits per token.
    """
    # Configuration class associated with this model.
    config_class = XLMRobertaConfig
    def __init__(self, config): 
        """Build the encoder body and the classification head from `config`."""
        super().__init__(config)
        self.num_labels = config.num_labels
        # Load model body 
        self.roberta = RobertaModel(config, add_pooling_layer = False)
        # Set up token classification head 
        self.dropout = nn.Dropout(config.hidden_dropout_prob) 
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Load and initialize weights 
        self.init_weights()

    def forward(self, input_ids = None, attention_mask = None, token_type_ids = None, labels = None, **kwargs): 
        """Compute per-token logits and, when `labels` is given, the loss.

        Args:
            input_ids, attention_mask, token_type_ids: forwarded to the encoder.
            labels: optional label ids; flattened together with the logits
                before being passed to the cross-entropy loss.
            **kwargs: forwarded unchanged to the encoder.

        Returns:
            TokenClassifierOutput with `loss` (None when no labels are given),
            `logits`, and the encoder's `hidden_states` and `attentions`.
        """
        #Use model body to get encoder representations 
        outputs = self.roberta(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, **kwargs)
        # Apply classifier to encoder representation 
        # (outputs[0] is presumably the per-token hidden states -- the logits
        # printed later in the notebook have shape [batch, tokens, num_labels])
        sequence_output = self.dropout(outputs[0]) 
        logits = self.classifier(sequence_output)
        # Calculate losses 
        loss = None
        if labels is not None: 
            # NOTE(review): positions labelled -100 are expected to be skipped
            # by the loss's default ignore_index -- confirm against torch docs.
            loss_fct = nn.CrossEntropyLoss() 
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        # Return model output object
        return TokenClassifierOutput(loss = loss, logits = logits, hidden_states = outputs.hidden_states, attentions = outputs.attentions)


In [12]:
index2tag = {idx: tag for idx, tag in enumerate(tags.names)} 
tag2index = {tag: idx for idx, tag in enumerate(tags.names)}
xlmr_config = AutoConfig.from_pretrained(xlmr_model_name, num_labels=tags.num_classes, id2label=index2tag, label2id=tag2index)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
xlmr_model = XLMRobertaForTokenClassification.from_pretrained(xlmr_model_name, config=xlmr_config).to(device)

Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaForTokenClassification: ['lm_head.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['roberta.embeddings.position_

In [13]:
input_ids = xlmr_tokenizer.encode(text, return_tensors="pt")
pd.DataFrame([xlmr_tokens, input_ids[0].numpy()], index=["Tokens", "Input IDs"])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
Tokens,<s>,▁Jack,▁Spar,row,▁love,s,▁New,▁York,!,</s>
Input IDs,0,21763,37456,15555,5161,7,2356,5753,38,2


In [14]:
outputs = xlmr_model(input_ids.to(device)).logits
predictions = torch.argmax(outputs, dim=-1)
print(f"Number of tokens in sequence: {len(xlmr_tokens)}") 
print(f"Shape of outputs: {outputs.shape}") 
print(f"Shape of predictions: {predictions.shape}") 

Number of tokens in sequence: 10
Shape of outputs: torch.Size([1, 10, 7])
Shape of predictions: torch.Size([1, 10])


In [15]:
def tag_text(text, tags, model, tokenizer):
    """Tag every sub-word token of `text` with the model's predicted label.

    Args:
        text: Raw input string.
        tags: Object exposing `names`, a sequence mapping class index to tag.
        model: Token classifier; `model(input_ids)[0]` must be the logits.
        tokenizer: Tokenizer used for both the displayed tokens and the input
            IDs (previously the IDs always came from the global
            `xlmr_tokenizer`, so the two could disagree).

    Returns:
        DataFrame with a "Tokens" row and a "Tags" row, one column per token.
    """
    # Encode once with the tokenizer that was passed in, so tokens and IDs
    # are guaranteed to line up.
    encoding = tokenizer(text, return_tensors="pt")
    tokens = encoding.tokens()
    input_ids = encoding.input_ids
    # Run the input on whichever device the model's weights live on instead
    # of relying on a module-level `device` variable.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_ids = input_ids.to(first_param.device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # Logits over the possible classes for every token
        outputs = model(input_ids)[0]
    # Take argmax to get most likely class per token
    predictions = torch.argmax(outputs, dim=2)
    # Convert to DataFrame
    preds = [tags.names[p] for p in predictions[0].cpu().numpy()]
    return pd.DataFrame([tokens, preds], index=["Tokens", "Tags"])


In [16]:
de_example = panx_de['train'][0]
example_tokens, example_tags = de_example["tokens"], de_example["ner_tags"]
tokenized_input = xlmr_tokenizer(example_tokens, is_split_into_words=True)
# print(tokenized_input) => {'input_ids': [], 'attention_mask': []}
tokens = xlmr_tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
word_ids = tokenized_input.word_ids()  
pd.DataFrame([tokens, word_ids], index=["Tokens", "WordID"])

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,15,16,17,18,19,20,21,22,23,24
Tokens,<s>,▁2.000,▁Einwohner,n,▁an,▁der,▁Dan,zi,ger,▁Buch,...,▁Wo,i,wod,schaft,▁Po,mmer,n,▁,.,</s>
WordID,,0,1,1,2,3,4,4,4,5,...,9,9,9,9,10,10,10,11,11,


In [17]:
# Walk over the word id of every sub-word token and give a real label only to
# the first token of each word; everything else gets -100.
previous_word_idx = None 
label_ids = [] 
for word_idx in word_ids: 
    # None = token without a word id (e.g. <s>, </s>); equal to the previous
    # id = continuation piece of the same word.
    if word_idx is None or word_idx == previous_word_idx: 
        label_ids.append(-100) 
    elif word_idx != previous_word_idx: 
        label_ids.append(example_tags[word_idx]) 
    previous_word_idx = word_idx
# Show -100 as "IGN" so the table is readable.
labels = [index2tag[l] if l != -100 else "IGN" for l in label_ids] 
index = ["Tokens", "Word IDs", "Label IDs", "Labels"] 
pd.DataFrame([tokens, word_ids, label_ids, labels], index=index)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,15,16,17,18,19,20,21,22,23,24
Tokens,<s>,▁2.000,▁Einwohner,n,▁an,▁der,▁Dan,zi,ger,▁Buch,...,▁Wo,i,wod,schaft,▁Po,mmer,n,▁,.,</s>
Word IDs,,0,1,1,2,3,4,4,4,5,...,9,9,9,9,10,10,10,11,11,
Label IDs,-100,0,0,-100,0,0,5,-100,-100,6,...,5,-100,-100,-100,6,-100,-100,0,-100,-100
Labels,IGN,O,O,IGN,O,O,B-LOC,IGN,IGN,I-LOC,...,B-LOC,IGN,IGN,IGN,I-LOC,IGN,IGN,O,IGN,IGN


In [18]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split sentences and align NER labels to tokens.

    Only the first sub-word token of each word keeps the word's tag id; tokens
    without a word id and continuation pieces are labelled -100.

    Args:
        examples: Batch mapping with "tokens" (lists of words) and "ner_tags"
            (lists of tag ids, one per word).

    Returns:
        The tokenizer output with an added "labels" entry (one list per row).
    """
    encoded = xlmr_tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
    aligned = []
    for row, word_tags in enumerate(examples["ner_tags"]):
        row_labels = []
        last_word = None
        for word in encoded.word_ids(batch_index=row):
            # A token is labelled only when it opens a new word.
            opens_new_word = word is not None and word != last_word
            row_labels.append(word_tags[word] if opens_new_word else -100)
            last_word = word
        aligned.append(row_labels)
    encoded["labels"] = aligned
    return encoded

def encode_panx_dataset(corpus):
    """Apply `tokenize_and_align_labels` to `corpus` in batches.

    The raw 'langs', 'ner_tags' and 'tokens' columns are removed from the
    result.
    """
    raw_columns = ['langs', 'ner_tags', 'tokens']
    encoded_corpus = corpus.map(
        tokenize_and_align_labels,
        batched=True,
        remove_columns=raw_columns,
    )
    return encoded_corpus

In [19]:
panx_de_encoded = encode_panx_dataset(panx_ch["de"])

Loading cached processed dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-9592d6ed574a990c.arrow


  0%|          | 0/2 [00:00<?, ?ba/s]

Loading cached processed dataset at /root/.cache/huggingface/datasets/xtreme/PAN-X.de/1.0.0/349258adc25bb45e47de193222f95e68a44f7a7ab53c4283b3f007208a11bf7e/cache-6586ab7b5c000571.arrow


In [20]:
y_true = [["O", "O", "O", "B-MISC", "I-MISC", "I-MISC", "O"], ["B-PER", "I-PER", "O"]] 
y_pred = [["O", "O", "B-MISC", "I-MISC", "I-MISC", "I-MISC", "O"], ["B-PER", "I-PER", "O"]] 
print(classification_report(y_true, y_pred))

def align_predictions(predictions, label_ids, id2tag=None):
    """Convert raw model outputs and label ids into lists of tag strings.

    Positions whose label id is -100 (ignored tokens) are dropped from both
    the predictions and the labels.

    Args:
        predictions: Array of logits, shape (batch, sequence, classes).
        label_ids: Array of gold label ids, shape (batch, sequence).
        id2tag: Optional mapping from class index to tag string; defaults to
            the module-level `index2tag`.

    Returns:
        `(preds_list, labels_list)`: two lists with exactly one inner list of
        tag strings per example.
    """
    if id2tag is None:
        id2tag = index2tag
    preds = np.argmax(predictions, axis=2)
    batch_size, seq_len = preds.shape
    labels_list, preds_list = [], []
    for batch_idx in range(batch_size):
        example_labels, example_preds = [], []
        for seq_idx in range(seq_len):
            # Ignore label IDs = -100
            if label_ids[batch_idx, seq_idx] != -100:
                example_labels.append(id2tag[label_ids[batch_idx][seq_idx]])
                example_preds.append(id2tag[preds[batch_idx][seq_idx]])
        # Append once per example. Doing this inside the token loop added the
        # same lists seq_len times and inflated the evaluation data.
        labels_list.append(example_labels)
        preds_list.append(example_preds)
    return preds_list, labels_list

def compute_metrics(eval_pred):
    """Return `{"f1": ...}` for a Trainer evaluation result.

    `eval_pred` must expose `predictions` and `label_ids`; both are converted
    to tag-string lists by `align_predictions` before scoring.
    """
    aligned_preds, aligned_labels = align_predictions(
        eval_pred.predictions, eval_pred.label_ids
    )
    score = f1_score(aligned_labels, aligned_preds)
    return {"f1": score}

              precision    recall  f1-score   support

        MISC       0.00      0.00      0.00         1
         PER       1.00      1.00      1.00         1

   micro avg       0.50      0.50      0.50         2
   macro avg       0.50      0.50      0.50         2
weighted avg       0.50      0.50      0.50         2



In [21]:
# Hyper-parameters for fine-tuning on the German corpus.
num_epochs = 3 
batch_size = 24 
# Number of full batches in one pass over the training split, i.e. log roughly
# once per epoch.
logging_steps = len(panx_de_encoded["train"]) // batch_size 
model_name = f"{xlmr_model_name}-finetuned-panx-de" 
training_args = TrainingArguments(output_dir=model_name, 
                                  log_level="error", 
                                  num_train_epochs=num_epochs, 
                                  per_device_train_batch_size=batch_size, 
                                  per_device_eval_batch_size=batch_size, 
                                  evaluation_strategy="epoch", 
                                  # NOTE(review): very large value, presumably
                                  # to avoid intermediate checkpoints -- confirm.
                                  save_steps=1e6, 
                                  weight_decay=0.01, 
                                  disable_tqdm=False, 
                                  logging_steps=logging_steps, 
                                  push_to_hub=False)

In [None]:
def model_init():
    """Create a new `XLMRobertaForTokenClassification` from the pretrained
    checkpoint and `xlmr_config`, moved to `device`."""
    fresh_model = XLMRobertaForTokenClassification.from_pretrained(
        xlmr_model_name, config=xlmr_config
    )
    return fresh_model.to(device)

trainer = Trainer(model_init=model_init, 
                  args=training_args, 
                  data_collator=DataCollatorForTokenClassification(xlmr_tokenizer), 
                  compute_metrics=compute_metrics, 
                  train_dataset=panx_de_encoded["train"], 
                  eval_dataset=panx_de_encoded["validation"], 
                  tokenizer=xlmr_tokenizer)

trainer.train()
print("Training completed!")

Epoch,Training Loss,Validation Loss


In [None]:
print('tags to predict', tags)
# Only the last expression of a cell is rendered automatically, so each result
# is displayed explicitly; otherwise the German table is silently discarded.
text_de = "Jeff Dean ist ein Informatiker bei Google in Kalifornien"
display(tag_text(text_de, tags, trainer.model, xlmr_tokenizer))
text_fr = "Jeff Dean est informaticien chez Google en Californie"
display(tag_text(text_fr, tags, trainer.model, xlmr_tokenizer))

In [None]:
data_collator = DataCollatorForTokenClassification(xlmr_tokenizer)

def forward_pass_with_label(batch):
    """Run `trainer.model` on a batch and return per-token loss and prediction.

    Intended for `Dataset.map(..., batched=True)`.

    Args:
        batch: Dict of equal-length lists (one entry per example), including
            "input_ids", "attention_mask" and "labels".

    Returns:
        Dict with "loss" and "predicted_label", both numpy arrays of shape
        (examples in batch, padded sequence length).
    """
    # Convert dict of lists to list of dicts suitable for data collator
    features = [dict(zip(batch, t)) for t in zip(*batch.values())]
    # Pad inputs and labels and put all tensors on device
    batch = data_collator(features)
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)
    with torch.no_grad():
        # Pass data through model; the mask is passed by keyword so it cannot
        # be bound to the wrong positional parameter.
        output = trainer.model(input_ids, attention_mask=attention_mask)
        # logit.size: [batch_size, sequence_length, classes]
        # Predict class with largest logit value on classes axis
        predicted_label = torch.argmax(output.logits, axis=-1).cpu().numpy()
        # Read the class count from the logits instead of hard-coding 7, so
        # the function keeps working with a different tag set.
        num_classes = output.logits.size(-1)
        # Calculate loss per token after flattening batch dimension with view
        loss = cross_entropy(output.logits.view(-1, num_classes), labels.view(-1), reduction="none")
    # Unflatten batch dimension and convert to numpy array
    loss = loss.view(len(input_ids), -1).cpu().numpy()
    return {"loss": loss, "predicted_label": predicted_label}

In [None]:
valid_set = panx_de_encoded["validation"] 
valid_set = valid_set.map(forward_pass_with_label, batched=True, batch_size=32) 

In [None]:
# Give the ignore label id (-100) a readable name so the lookups below work
# for every position. Note this mutates the shared index2tag dict.
index2tag[-100] = "IGN" 
df = valid_set.to_pandas()
# Turn ids back into token strings / tag names for inspection.
df["input_tokens"] = df["input_ids"].apply( lambda x: xlmr_tokenizer.convert_ids_to_tokens(x)) 
df["predicted_label"] = df["predicted_label"].apply( lambda x: [index2tag[i] for i in x]) 
df["labels"] = df["labels"].apply( lambda x: [index2tag[i] for i in x]) 
# loss and predicted_label come from padded batches; cut them back to the
# length of each example's input_ids.
df['loss'] = df.apply( lambda x: x['loss'][:len(x['input_ids'])], axis=1) 
df['predicted_label'] = df.apply( lambda x: x['predicted_label'][:len(x['input_ids'])], axis=1) 
df.head(5)

In [None]:
# One row per token instead of one row per sentence.
df_tokens = df.apply(pd.Series.explode)
# Drop ignored positions. The explicit copy makes the column assignment below
# write to an independent frame rather than raising SettingWithCopyWarning.
df_tokens = df_tokens.query("labels != 'IGN'").copy()
df_tokens["loss"] = df_tokens["loss"].astype(float).round(2)
df_tokens.head(20)


In [None]:
( 
df_tokens.groupby("input_tokens")[["loss"]] 
.agg(["count", "mean", "sum"]) 
.droplevel(level=0, axis=1) 
# Get rid of multi-level columns 
.sort_values(by="sum", ascending=False) 
#.reset_index() 
.round(2) 
.head(10) 
.T 
)

In [None]:
( 
df_tokens.groupby("labels")[["loss"]] 
.agg(["count", "mean", "sum"]) 
.droplevel(level=0, axis=1) 
.sort_values(by="mean", ascending=False) 
#.reset_index() 
.round(2) 
.T 
)

In [None]:
def plot_confusion_matrix(y_preds, y_true, labels):
    """Plot a confusion matrix normalized over the true labels.

    Args:
        y_preds: Predicted tag for every token.
        y_true: Gold tag for every token.
        labels: Tag names in the order they should appear on both axes.
    """
    # Pass `labels` to confusion_matrix as well. Without it the matrix rows and
    # columns follow the sorted order of the observed labels, while the tick
    # labels follow `labels`, so the axes would be mislabeled.
    cm = confusion_matrix(y_true, y_preds, labels=labels, normalize="true")
    fig, ax = plt.subplots(figsize=(6, 6))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    disp.plot(cmap="Blues", values_format=".2f", ax=ax, colorbar=False)
    plt.title("Normalized confusion matrix")
    plt.show()

In [None]:
# Signature is (y_preds, y_true, labels): predictions first, gold labels second.
plot_confusion_matrix(df_tokens["predicted_label"], df_tokens["labels"], tags.names)

In [None]:
# error analysis
def get_samples(df):
    """Yield one per-token inspection table for every row of `df`.

    Each row must hold equal-length sequences in the 'attention_mask',
    'labels', 'predicted_label', 'input_tokens' and 'loss' columns. The first
    and last positions (the <s> and </s> special tokens) are left out.

    Yields:
        A transposed DataFrame with rows 'label', 'prediction', 'tokens' and
        'losses' (rounded to 3 decimals), one column per remaining token.
    """
    for _, row in df.iterrows():
        # The previous check `i == len(...)` could never be true, so the
        # trailing special token was kept; the last valid index is len - 1.
        last_position = len(row['attention_mask']) - 1
        positions = range(1, last_position)
        yield pd.DataFrame({
            'label': [row['labels'][i] for i in positions],
            'prediction': [row['predicted_label'][i] for i in positions],
            'tokens': [row['input_tokens'][i] for i in positions],
            'losses': [round(row['loss'][i], 3) for i in positions],
        }).T

In [None]:
# Rank sentences by their summed token loss without adding a column to `df`
# itself, so `df` keeps only per-token sequence columns and cells that explode
# it can still be re-run.
df_tmp = (
    df.assign(total_loss=df['loss'].apply(sum))
      .sort_values(by='total_loss', ascending=False)
      .head(5)
)

for sample in get_samples(df_tmp):
    display(sample)

In [None]:
df_tmp = df.loc[df["input_tokens"].apply(lambda x: u"\u2581(" in x)].head(5) 
for sample in get_samples(df_tmp): 
    display(sample)

In [None]:
def get_f1_score(trainer, dataset):
    """Run `trainer.predict` on `dataset` and return its "test_f1" metric."""
    prediction_output = trainer.predict(dataset)
    return prediction_output.metrics["test_f1"]

In [None]:
f1_scores = defaultdict(dict) 
f1_scores["de"]["de"] = get_f1_score(trainer, panx_de_encoded["test"]) 
print(f"F1-score of [de] model on [de] dataset: {f1_scores['de']['de']:.3f}")

In [None]:
def evaluate_lang_performance(lang, trainer):
    """Encode the `lang` corpus from `panx_ch` and return `trainer`'s F1 score
    on its test split."""
    encoded_corpus = encode_panx_dataset(panx_ch[lang])
    test_split = encoded_corpus["test"]
    return get_f1_score(trainer, test_split)
f1_scores = defaultdict(dict)
for lang in ['de', 'fr', 'it', 'en']:
    f1_scores["de"][lang] = evaluate_lang_performance(lang, trainer) 
    print(f"F1-score of [de] model on [{lang}] dataset: {f1_scores['de'][lang]:.3f}") 

In [None]:
def train_on_subset(dataset, num_samples):
    """Fine-tune a fresh model on the first `num_samples` shuffled training rows.

    Args:
        dataset: Encoded corpus with "train", "validation" and "test" splits.
        num_samples: Requested training-set size; capped at the number of rows
            actually available in the training split.

    Returns:
        One-row DataFrame with "num_samples" (rows really used) and "f1_score"
        (F1 on the test split).

    Side effects:
        Overwrites `training_args.logging_steps` on the shared arguments object.
    """
    # Asking `select` for more rows than exist raises, so cap the request.
    num_samples = min(num_samples, len(dataset["train"]))
    train_ds = dataset["train"].shuffle(seed=42).select(range(num_samples))
    valid_ds = dataset["validation"]
    test_ds = dataset["test"]
    # Keep at least 1: fewer rows than one batch would give logging_steps == 0.
    training_args.logging_steps = max(1, len(train_ds) // batch_size)
    trainer = Trainer(model_init=model_init,
                      args=training_args,
                      data_collator=data_collator,
                      compute_metrics=compute_metrics,
                      train_dataset=train_ds,
                      eval_dataset=valid_ds,
                      tokenizer=xlmr_tokenizer)
    trainer.train()
    if training_args.push_to_hub:
        trainer.push_to_hub(commit_message="Training completed!")
    # Named test_f1 so the imported seqeval `f1_score` function is not shadowed.
    test_f1 = get_f1_score(trainer, test_ds)
    return pd.DataFrame.from_dict({"num_samples": [len(train_ds)], "f1_score": [test_f1]})

In [None]:
panx_fr_encoded = encode_panx_dataset(panx_ch["fr"])
panx_it_encoded = encode_panx_dataset(panx_ch["it"])
panx_en_encoded = encode_panx_dataset(panx_ch["en"])

In [None]:
# Only request subset sizes the down-sampled French training split can supply;
# asking for more rows than exist makes the row selection fail.
fr_train_size = len(panx_fr_encoded["train"])
sample_sizes = [n for n in [200, 500, 1000, 1500, 2000] if n <= fr_train_size]
# DataFrame.append was removed in pandas 2.0: collect the one-row frames and
# concatenate them once.
metrics_df = pd.concat(
    [train_on_subset(panx_fr_encoded, num_samples) for num_samples in sample_sizes],
    ignore_index=True,
)
fig, ax = plt.subplots()
# Horizontal reference line: zero-shot score of the German model on French.
ax.axhline(f1_scores["de"]["fr"], ls="--", color="r")
metrics_df.set_index("num_samples").plot(ax=ax)
plt.legend(["Zero-shot from de", "Fine-tuned on fr"], loc="lower right")
plt.ylim((0, 1))
plt.xlabel("Number of Training Samples")
plt.ylabel("F1 Score")
plt.show()


In [None]:
def concatenate_splits(corpora):
    """Merge several corpora split by split into one shuffled `DatasetDict`.

    The split names are taken from the first corpus; every corpus must provide
    each of those splits. Each merged split is shuffled with seed 42.
    """
    multi_corpus = DatasetDict()
    split_names = corpora[0].keys()
    for split in split_names:
        parts = [corpus[split] for corpus in corpora]
        merged = concatenate_datasets(parts)
        multi_corpus[split] = merged.shuffle(seed=42)
    return multi_corpus

In [None]:
panx_all_encoded = concatenate_splits([panx_de_encoded, panx_fr_encoded])
training_args.logging_steps = len(panx_all_encoded["train"]) // batch_size 
training_args.push_to_hub = False 
training_args.output_dir = "xlm-roberta-base-finetuned-panx-de-fr" 
trainer = Trainer(model_init=model_init, 
                  args=training_args, 
                  data_collator=data_collator, 
                  compute_metrics=compute_metrics, 
                  tokenizer=xlmr_tokenizer, 
                  train_dataset=panx_all_encoded["train"], 
                  eval_dataset=panx_all_encoded["validation"])
trainer.train()

In [None]:
for lang in ['de', 'fr', 'it', 'en']:
    f1_scores["de_fr"][lang] = evaluate_lang_performance(lang, trainer) 
    print(f"F1-score of [de_fr] model on [{lang}] dataset: {f1_scores['de_fr'][lang]:.3f}")

In [None]:
corpora = [panx_de_encoded] 
# Exclude German from iteration 
for lang in langs[1:]: 
    training_args.output_dir = f"xlm-roberta-base-finetuned-panx-{lang}" 
    # Fine-tune on monolingual corpus 
    ds_encoded = encode_panx_dataset(panx_ch[lang]) 
    metrics = train_on_subset(ds_encoded, ds_encoded["train"].num_rows) 
    # Collect F1-scores in common dict 
    f1_scores[lang][lang] = metrics["f1_score"][0]
    # Add monolingual corpus to list of corpora to concatenate 
    corpora.append(ds_encoded)

In [None]:
corpora_encoded = concatenate_splits(corpora)

In [None]:
training_args.logging_steps = len(corpora_encoded["train"]) // batch_size 
training_args.output_dir = "xlm-roberta-base-finetuned-panx-all"
trainer = Trainer(model_init=model_init, args=training_args, data_collator=data_collator, compute_metrics=compute_metrics, tokenizer=xlmr_tokenizer, train_dataset=corpora_encoded["train"], eval_dataset=corpora_encoded["validation"]) 
trainer.train() 

In [None]:
# `corpora` was built in the same order as `langs`, so the positions line up.
for idx, lang in enumerate(langs):
    f1_scores["all"][lang] = get_f1_score(trainer, corpora[idx]["test"])
scores_data = {"de": f1_scores["de"], "each": {lang: f1_scores[lang][lang] for lang in langs}, "all": f1_scores["all"]}
f1_scores_df = pd.DataFrame(scores_data).T.round(4)
# Rename without inplace=True and end the cell with the frame itself:
# an in-place call returns None, so the summary table was never displayed.
f1_scores_df = f1_scores_df.rename_axis(index="Fine-tune on", columns="Evaluated on")
f1_scores_df
