In [144]:
#first load dataset
#custom dataset simulating a lang distribution from switzerland
import transformers
import datasets 
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings('ignore')
device = torch.device("cuda")
import numpy as np
import pandas as pd

In [145]:
#xtreme dataset, from xtreme load PAN-X.{lang} ex PAN-X.de
from datasets import load_dataset
ds = load_dataset('xtreme',name='PAN-X.de')

In [146]:
ds

DatasetDict({
    train: Dataset({
        features: ['tokens', 'ner_tags', 'langs'],
        num_rows: 20000
    })
    validation: Dataset({
        features: ['tokens', 'ner_tags', 'langs'],
        num_rows: 10000
    })
    test: Dataset({
        features: ['tokens', 'ner_tags', 'langs'],
        num_rows: 10000
    })
})

In [147]:
ds['train'].shuffle().select([i for i in range(int(0.9*ds['train'].num_rows))])

Dataset({
    features: ['tokens', 'ner_tags', 'langs'],
    num_rows: 18000
})

### simulated dataset

In [148]:
#de - 62.9, #fr - 22.9, #it -8.4, #en-5.9

In [149]:
#create a defaultdict where value is DatasetDict
from collections import defaultdict
from datasets import DatasetDict
panx_ch = defaultdict(DatasetDict)
fracs = [0.629,0.229,0.084,0.059]
langs = ['de','fr','it','en']
for lang,frac in zip(langs,fracs):
    ds = load_dataset('xtreme',name = f"PAN-X.{lang}")
    for split in ds:
        panx_ch[lang][split] =  (ds[split].shuffle().select([_ for _ in range(int(round(frac * ds[split].num_rows)))]))

In [150]:
panx_ch
# de has 6290 val
# fe has 2290 val
# it has 840 val
# en has 590 val 

defaultdict(datasets.dataset_dict.DatasetDict,
            {'de': DatasetDict({
                 train: Dataset({
                     features: ['tokens', 'ner_tags', 'langs'],
                     num_rows: 12580
                 })
                 validation: Dataset({
                     features: ['tokens', 'ner_tags', 'langs'],
                     num_rows: 6290
                 })
                 test: Dataset({
                     features: ['tokens', 'ner_tags', 'langs'],
                     num_rows: 6290
                 })
             }),
             'fr': DatasetDict({
                 train: Dataset({
                     features: ['tokens', 'ner_tags', 'langs'],
                     num_rows: 4580
                 })
                 validation: Dataset({
                     features: ['tokens', 'ner_tags', 'langs'],
                     num_rows: 2290
                 })
                 test: Dataset({
                     features: ['tokens', 'ner_tags', 'la

### create dataset of primary lang de

### save string format of tags

In [151]:
tags = panx_ch['de']['train'].features['ner_tags'].feature

In [152]:
tags

ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC'], id=None)

In [153]:
def ner_2_str(batch):
    return {'ner_tags_str':[tags.int2str(idx) for idx in batch['ner_tags']]}

In [154]:
panx_de = panx_ch['de'].map(ner_2_str)

In [155]:
panx_de['train'][0]

{'tokens': ['Nicolás', 'Gómez', 'Dávila', '.'],
 'ner_tags': [1, 2, 2, 0],
 'langs': ['de', 'de', 'de', 'de'],
 'ner_tags_str': ['B-PER', 'I-PER', 'I-PER', 'O']}

In [156]:
#check frequncy of tag by split
#for imbalance
from collections import Counter
split_freq = defaultdict(Counter)
for split, dataset in panx_de.items():
    for row in dataset['ner_tags_str']:
        for tag in row:
            if tag.startswith('B'):
                split_freq[split][tag] += 1


In [157]:
# we have almost equal number of tags in all split
pd.DataFrame(split_freq)

Unnamed: 0,train,validation,test
B-PER,5907,2906,2956
B-ORG,5392,2598,2617
B-LOC,6185,3151,3122


In [158]:
#load tokenizer
from transformers import AutoTokenizer
xlmr_path = 'xlm-roberta-base'
xlmr_tokenizer = AutoTokenizer.from_pretrained(xlmr_path)

### Build custom Model for NER classificaiton

In [159]:
#we need 3 things 
#1.Base class to build a custom model upon <name>PreTrainedModel
#2 Model body <name>Model
#3 classifier Head
#head
from transformers.modeling_outputs import TokenClassifierOutput
#base class
from transformers import RobertaPreTrainedModel
#body
from transformers import RobertaModel
#config
from transformers import XLMRobertaConfig
#custom class inherit from base
class XLMRobertaForTokenClassification(RobertaPreTrainedModel):
    configclass = XLMRobertaConfig
    def __init__(self,config):
        super().__init__(config)
        self.roberta = RobertaModel(config,add_pooling_layer=False)
        self.num_labels = config.num_labels
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size,config.num_labels)
        #call init_weights
        self.init_weights()
    def forward(self,input_ids=None,attention_mask=None,token_type_ids=None,labels=None,**kwargs):
        output = self.roberta(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,**kwargs)
        seq_out = self.dropout(output[0])
        logits = self.classifier(seq_out)
        loss=None
        if labels is not None:
            loss_func = nn.CrossEntropyLoss()
            loss = loss_func(logits.view(-1,self.num_labels),labels.view(-1))
        return TokenClassifierOutput(loss,logits=logits,hidden_states=output.hidden_states,attentions=output.attentions,)

In [160]:
#load config with custom tag2idx and idx2tag
tag2index = {tag:idx for idx,tag in enumerate(tags.names)}
index2tag = {idx:tag for idx,tag in enumerate(tags.names)}

In [161]:
tag2index

{'O': 0,
 'B-PER': 1,
 'I-PER': 2,
 'B-ORG': 3,
 'I-ORG': 4,
 'B-LOC': 5,
 'I-LOC': 6}

In [162]:
index2tag

{0: 'O',
 1: 'B-PER',
 2: 'I-PER',
 3: 'B-ORG',
 4: 'I-ORG',
 5: 'B-LOC',
 6: 'I-LOC'}

In [163]:
from transformers import AutoConfig
xlm_config = AutoConfig.from_pretrained(xlmr_path,num_labels=tags.num_classes,id2label=index2tag,label2id=tag2index,)

### finally load the custom model with config and model name

In [164]:
#from pretrained method is provided by base class
xlm_model = (XLMRobertaForTokenClassification.from_pretrained(xlmr_path,config=xlm_config).to(device))

In [165]:
xlm_model

XLMRobertaForTokenClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(250002, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
        

In [166]:
input_ids = xlmr_tokenizer.encode("jack sparrow loves new york!",return_tensors="pt")

In [167]:
input_ids

tensor([[     0, 121477,  27148,  15555,   5161,      7,   3525,  70662,     92,
             38,      2]])

In [168]:
outputs = xlm_model(input_ids.to(device)).logits

In [169]:
outputs.shape

torch.Size([1, 11, 7])

In [170]:
predictions = torch.argmax(outputs,dim=-1)

In [171]:
pred_tags = [tags.names[p] for p in predictions[0].cpu().numpy()]

In [172]:
predictions

tensor([[5, 2, 2, 2, 2, 2, 2, 2, 2, 2, 5]], device='cuda:0')

In [173]:
pred_tags

['B-LOC',
 'I-PER',
 'I-PER',
 'I-PER',
 'I-PER',
 'I-PER',
 'I-PER',
 'I-PER',
 'I-PER',
 'I-PER',
 'B-LOC']

In [174]:
panx_de['train']

Dataset({
    features: ['tokens', 'ner_tags', 'langs', 'ner_tags_str'],
    num_rows: 12580
})

### helper functions

In [197]:
#tag text for predctions
def tag_text(text,tags,model,tokenizer):
    tokens = tokenizer(text).tokens()
    input_ids = xlmr_tokenizer(text,return_tensors="pt").input_ids.to(device)
    out = model(input_ids)[0]
    predictions = torch.argmax(out,dim=2)
    pred = [tags.names[p] for p in predictions[0].cpu().numpy()]
    return pd.DataFrame([tokens,pred],index=['Tokens','Preds'])

In [176]:
#tag only subwords
#-100 is ignored by nn.Crossentropy
def tokenize_align_label(batch):
    tokenized_inputs = xlmr_tokenizer(batch['tokens'],is_split_into_words=True,truncation=True)
    labels = []
    for idx,label in enumerate(batch['ner_tags']):
        word_ids = tokenized_inputs.word_ids(batch_index=idx)
        previous_word = None
        label_ids = []
        for word_id in word_ids:
            if word_id is not None and word_id != previous_word:
                label_ids.append(label[word_id])
            elif word_id == previous_word:
                label_ids.append(-100)
        labels.append(label_ids)
    tokenized_inputs['labels'] = labels
    return tokenized_inputs

In [177]:
#align predictions
def align_predictions(predictions,labels):
    preds = np.argmax(predictions,axis=2)
    batch_size,seq_len = preds.shape
    label_list,pred_list = [],[]
    for batch_idx in range(batch_size):
        example_label,example_pred = [],[]
        for seq_idx in range(seq_len):
            if labels[batch_idx,seq_idx] != -100:
                example_label.append(index2tag[labels[batch_idx][seq_idx]])
                example_pred.append(index2tag[preds[batch_idx][seq_idx]])
        label_list.append(example_label)
        pred_list.append(example_pred)
    return label_list,pred_list

In [178]:
#tokenize dataset
panx_de_encoded = panx_de.map(tokenize_align_label,batched=True,remove_columns=['ner_tags','langs','tokens'])

Map:   0%|          | 0/6290 [00:00<?, ? examples/s]

### fine tuning process

In [179]:
from seqeval.metrics import classification_report,f1_score

In [180]:
#data collator
from transformers import DataCollatorForTokenClassification
data_collator = DataCollatorForTokenClassification(xlmr_tokenizer)

In [181]:
#trainign args
from transformers import TrainingArguments
model_name = 'XLM_Roberta_NER_task'
batch_size = 8
weight_decay = 0.01
logging_steps = len(panx_de_encoded['train'])//batch_size
training_args = TrainingArguments(output_dir=model_name,
                                 per_device_eval_batch_size=batch_size,
                                 per_device_train_batch_size=batch_size,
                                 logging_steps=logging_steps,
                                 log_level="error",
                                 eval_strategy='epoch',
                                 weight_decay=weight_decay,
                                 disable_tqdm=False,
                                 num_train_epochs=3,
                                 save_steps=1e6)

In [182]:
def compute_metrics(pred):
    y_true,y_pred = align_predictions(pred.predictions,pred.label_ids)
    f1 = f1_score(y_true,y_pred)
    return {'f1':f1}

In [183]:
from transformers import Trainer
trainer = Trainer(model=xlm_model,
                 args = training_args,data_collator=data_collator,
                 compute_metrics=compute_metrics,
                 train_dataset=panx_de_encoded['train'],
                 eval_dataset=panx_de_encoded['validation'],
                 tokenizer=xlmr_tokenizer)

In [184]:
trainer.train()

Epoch,Training Loss,Validation Loss,F1
1,0.3333,0.231996,0.803615
2,0.1795,0.194158,0.837066
3,0.1063,0.215021,0.844686


TrainOutput(global_step=4719, training_loss=0.20629638485242613, metrics={'train_runtime': 463.0355, 'train_samples_per_second': 81.506, 'train_steps_per_second': 10.191, 'total_flos': 669933925633104.0, 'train_loss': 0.20629638485242613, 'epoch': 3.0})

In [199]:
text = "Apple Company eröffnet eine Niederlassung in Hongkong"
tag_text(text,tags,trainer.model,xlmr_tokenizer)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
Tokens,<s>,▁Apple,▁Company,▁eröffnet,▁eine,▁Nieder,lassung,▁in,▁Hongkong,</s>
Preds,O,B-ORG,I-ORG,O,O,O,O,O,B-LOC,O


[1;34mwandb[0m: 🚀 View run [33mXLM_Roberta_NER_task[0m at: [34mhttps://wandb.ai/adi-joshi2018-vit/huggingface/runs/aszqqret[0m
[1;34mwandb[0m: Find logs at: [1;35mwandb/run-20241213_125527-aszqqret/logs[0m
