In [None]:
from sklearn.model_selection import train_test_split

def dataset_train_test_split(dataset, test_size=0.2, seed=42):
    """Randomly split a dataset into train and test `Dataset` objects.

    All examples are materialised into a Python list, split with
    scikit-learn's ``train_test_split`` and each part is rebuilt with
    ``Dataset.from_list``.

    Args:
        dataset: Iterable of examples; each example is passed through to
            ``Dataset.from_list``.
        test_size: Forwarded to ``train_test_split`` as ``test_size``.
        seed: Forwarded to ``train_test_split`` as ``random_state`` so the
            split is repeatable.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    all_examples = list(dataset)
    splits = train_test_split(all_examples, test_size=test_size, random_state=seed)
    return tuple(Dataset.from_list(part) for part in splits)

# 80/20 split with a fixed seed, then tokenize both splits for XLM-R.
train_dataset, test_dataset = dataset_train_test_split(dataset, test_size=0.2, seed=42)

encoded_train, encoded_test = (
    split.map(tokenize_and_align_labels) for split in (train_dataset, test_dataset)
)

In [None]:
# Fine-tune XLM-RoBERTa (base) for token classification on the training split,
# evaluating on the held-out split after every epoch.
from transformers import set_seed

xlmr_args = TrainingArguments(
    output_dir="./models/xlmr",
    evaluation_strategy="epoch",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs/xlmr",
    logging_steps=10,
    save_strategy="no"
)

# TrainingArguments.seed is only applied once the Trainer starts, i.e. after the
# model already exists. The token-classification head created by
# from_pretrained is randomly initialised, so seed *before* loading the model
# to make the run reproducible.
set_seed(xlmr_args.seed)

xlmr_model = AutoModelForTokenClassification.from_pretrained(
    "xlm-roberta-base",
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id
)

xlmr_trainer = Trainer(
    model=xlmr_model,
    args=xlmr_args,
    train_dataset=encoded_train,
    eval_dataset=encoded_test,
    tokenizer=tokenizer,
    # Pads input_ids and labels per batch (examples have different lengths).
    data_collator=DataCollatorForTokenClassification(tokenizer)
)

xlmr_trainer.train()

In [None]:
import numpy as np
from seqeval.metrics import classification_report as seqeval_classification_report

# Run the fine-tuned XLM-R model over the test split and take the
# highest-scoring label id at every token position.
xlmr_preds = xlmr_trainer.predict(encoded_test)
xlmr_pred_labels, xlmr_true_labels = (
    np.argmax(xlmr_preds.predictions, axis=-1),
    xlmr_preds.label_ids,
)

def decode_labels(preds, true, id2label, ignore_index=-100):
    """Turn label-id sequences into aligned tag-string sequences for seqeval.

    Args:
        preds: Iterable of predicted label-id sequences (one row per example,
            e.g. the argmax of the model logits).
        true: Iterable of gold label-id sequences with the same layout as
            ``preds``.
        id2label: Mapping from integer label id to tag string.
        ignore_index: Gold value that marks positions to skip. Defaults to
            -100, the value the label-alignment code assigns to positions
            that carry no label.

    Returns:
        ``(pred_tags, true_tags)``: two lists of lists of tag strings. For each
        example both keep exactly the positions whose gold id is not
        ``ignore_index``, so predicted and gold sequences have equal length
        and line up position by position (as seqeval requires).
    """
    pred_tags, true_tags = [], []
    for pred_seq, true_seq in zip(preds, true):
        # Filter on the *gold* label: argmax predictions are never -100, so
        # filtering predictions by their own value would keep special-token
        # and padding positions and misalign them with the gold tags.
        kept = [(p, t) for p, t in zip(pred_seq, true_seq) if t != ignore_index]
        # int() so numpy/tensor scalars look up plain-int dictionary keys.
        pred_tags.append([id2label[int(p)] for p, _ in kept])
        true_tags.append([id2label[int(t)] for _, t in kept])
    return pred_tags, true_tags

# Convert ids to tag strings and print seqeval's entity-level report.
xlmr_pred_tags, xlmr_true_tags = decode_labels(
    xlmr_pred_labels, xlmr_true_labels, id2label
)

xlmr_report = seqeval_classification_report(xlmr_true_tags, xlmr_pred_tags)
print("XLM-Roberta Evaluation:")
print(xlmr_report)

In [None]:
# The mBERT model needs its own tokenizer; the XLM-R `tokenizer` cannot be reused.
mbert_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-multilingual-cased"
)

def tokenize_and_align_labels_mbert(example, tokenizer=None, label_map=None, label_all_tokens=False):
    """Tokenize one pre-split example and align its word-level NER tags to sub-words.

    Args:
        example: Mapping with ``"tokens"`` (list of words) and ``"ner_tags"``
            (one tag string per word).
        tokenizer: Tokenizer to use; defaults to the global ``mbert_tokenizer``.
        label_map: Tag-string -> id mapping; defaults to the global ``label2id``.
        label_all_tokens: If False (default), only the first sub-word of each
            word gets the word's label id and continuation sub-words get -100.
            If True, every sub-word gets the word's label id.

    Returns:
        The tokenizer output with an added ``"labels"`` list, one entry per
        token; positions without a word (``word_ids()`` is None) get -100.
    """
    tok = mbert_tokenizer if tokenizer is None else tokenizer
    tag_to_id = label2id if label_map is None else label_map

    tokenized = tok(example['tokens'], truncation=True, is_split_into_words=True)
    labels = []
    previous_word_idx = None
    for word_idx in tokenized.word_ids():
        if word_idx is None:
            # Special tokens / padding: ignored by the loss and by decode_labels.
            labels.append(-100)
        elif word_idx != previous_word_idx or label_all_tokens:
            labels.append(tag_to_id[example["ner_tags"][word_idx]])
        else:
            # Continuation sub-word: repeating the word's tag (e.g. "B-LOC")
            # here would split one entity into several during training and
            # seqeval scoring, so mask it out.
            labels.append(-100)
        previous_word_idx = word_idx
    tokenized["labels"] = labels
    return tokenized

# Encode both splits with the mBERT tokenizer (a different tokenizer from the
# XLM-R cells, so the XLM-R encodings cannot be reused), then fine-tune mBERT.
from transformers import set_seed

encoded_train_mbert = train_dataset.map(tokenize_and_align_labels_mbert)
encoded_test_mbert = test_dataset.map(tokenize_and_align_labels_mbert)

mbert_args = TrainingArguments(
    output_dir="./models/mbert",
    evaluation_strategy="epoch",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs/mbert",
    logging_steps=10,
    save_strategy="no"
)

# TrainingArguments.seed is only applied once the Trainer starts, i.e. after the
# model already exists. The token-classification head created by
# from_pretrained is randomly initialised, so seed *before* loading the model
# to make the run reproducible.
set_seed(mbert_args.seed)

mbert_model = AutoModelForTokenClassification.from_pretrained(
    "bert-base-multilingual-cased",
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id
)

mbert_trainer = Trainer(
    model=mbert_model,
    args=mbert_args,
    train_dataset=encoded_train_mbert,
    eval_dataset=encoded_test_mbert,
    tokenizer=mbert_tokenizer,
    # Pads input_ids and labels per batch (examples have different lengths).
    data_collator=DataCollatorForTokenClassification(mbert_tokenizer)
)

mbert_trainer.train()

In [None]:
# Predict on the mBERT-encoded test split, take the arg-max label id per
# token, convert ids to tag strings and print seqeval's entity-level report.
mbert_preds = mbert_trainer.predict(encoded_test_mbert)
mbert_true_labels = mbert_preds.label_ids
mbert_pred_labels = np.argmax(mbert_preds.predictions, axis=-1)

mbert_pred_tags, mbert_true_tags = decode_labels(
    mbert_pred_labels, mbert_true_labels, id2label
)

mbert_report = seqeval_classification_report(mbert_true_tags, mbert_pred_tags)
print("mBERT Evaluation:")
print(mbert_report)

In [None]:
import shap

# Use a small sample for explanation
# Explain the fine-tuned XLM-R model on a small sample of test sentences.
#
# shap.Explainer calls its model on batches of (masked) strings, which a raw
# AutoModelForTokenClassification cannot accept. Wrap the model in a function
# that maps strings to one score vector per sentence: each label's probability
# averaged over the sentence's tokens.
import torch

# Use a small sample for explanation
test_sample = encoded_test.select(range(5))
# skip_special_tokens: the tokenizer re-adds special tokens itself, so keeping
# them would inject literal "<s>"/"</s>" text into the explained sentences.
sample_texts = [
    tokenizer.decode(ids, skip_special_tokens=True) for ids in test_sample['input_ids']
]

def xlmr_mean_label_probs(texts):
    """Return an (n_texts, num_labels) array of token-averaged label probabilities."""
    xlmr_model.eval()
    enc = tokenizer(
        list(texts), padding=True, truncation=True, return_tensors="pt"
    ).to(xlmr_model.device)
    with torch.no_grad():
        probs = torch.softmax(xlmr_model(**enc).logits, dim=-1)
    # Exclude padding positions from the average.
    mask = enc["attention_mask"].unsqueeze(-1).to(probs.dtype)
    return ((probs * mask).sum(dim=1) / mask.sum(dim=1)).cpu().numpy()

explainer = shap.Explainer(
    xlmr_mean_label_probs,
    masker=shap.maskers.Text(tokenizer),
    output_names=[xlmr_model.config.id2label[i] for i in range(xlmr_model.config.num_labels)],
)
shap_values = explainer(sample_texts)
shap.plots.text(shap_values)

# For LIME, see lime.lime_text.LimeTextExplainer (not shown here for brevity)