In [None]:
!pip install evaluate



In [None]:
from typing import List
import numpy as np
import torch
import evaluate
from sklearn.model_selection import train_test_split
import nltk

nltk.download('treebank')

[nltk_data] Downloading package treebank to /root/nltk_data...
[nltk_data]   Unzipping corpora/treebank.zip.


True

In [None]:
# Load treebank dataset
tagged_sentences = nltk.corpus.treebank.tagged_sents ()
print("Number of samples:", len(tagged_sentences))

# Save sentences and tags
sentences, sentence_tags = [], []
for tagged_sentence in tagged_sentences:
    sentence = [word.lower() for word, _ in tagged_sentence]
    tags = [tag for _, tag in tagged_sentence]
    sentences.append(sentence)
    sentence_tags.append(tags)

Number of samples: 3914


In [None]:
def create_label2id(sentence_tags: List[List[str]]) -> dict:
    all_labels = set(tag for tags in sentence_tags for tag in tags)
    label2id = {label: idx for idx, label in enumerate(sorted(all_labels))}
    return label2id

label2id = create_label2id(sentence_tags)

# In ra label2id
print(label2id)


{'#': 0, '$': 1, "''": 2, ',': 3, '-LRB-': 4, '-NONE-': 5, '-RRB-': 6, '.': 7, ':': 8, 'CC': 9, 'CD': 10, 'DT': 11, 'EX': 12, 'FW': 13, 'IN': 14, 'JJ': 15, 'JJR': 16, 'JJS': 17, 'LS': 18, 'MD': 19, 'NN': 20, 'NNP': 21, 'NNPS': 22, 'NNS': 23, 'PDT': 24, 'POS': 25, 'PRP': 26, 'PRP$': 27, 'RB': 28, 'RBR': 29, 'RBS': 30, 'RP': 31, 'SYM': 32, 'TO': 33, 'UH': 34, 'VB': 35, 'VBD': 36, 'VBG': 37, 'VBN': 38, 'VBP': 39, 'VBZ': 40, 'WDT': 41, 'WP': 42, 'WP$': 43, 'WRB': 44, '``': 45}


In [None]:
train_sentences, test_sentences, train_tags, test_tags = train_test_split(
    sentences, sentence_tags, test_size=0.3
)

valid_sentences, test_sentences, valid_tags, test_tags = train_test_split(
    test_sentences, test_tags, test_size=0.5
)

In [None]:
# Tokenization
from transformers import AutoTokenizer
from torch.utils.data import Dataset

model_name = "QCRI/bert-base-multilingual-cased-pos-english"
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True
)

MAX_LEN = 256

class PosTaggingDataset(Dataset):
    def __init__(self, sentences: List[List[str]], tags: List[List[str]], tokenizer, label2id, max_len=MAX_LEN):
        super().__init__()
        self.sentences = sentences
        self.tags = tags
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.label2id = label2id

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx: int):
        input_token = self.sentences[idx]
        label_token = self.tags[idx]

        input_token = self.tokenizer.convert_tokens_to_ids(input_token)
        attention_mask = [1] * len(input_token)
        labels = [self.label2id[token] for token in label_token]

        return {
            "input_ids": self.pad_and_truncate(input_token, pad_id=self.tokenizer.pad_token_id),
            "labels": self.pad_and_truncate(labels, pad_id=-100),  # Assuming "O" is for outside tokens
            "attention_mask": self.pad_and_truncate(attention_mask, pad_id=0)
        }

    def pad_and_truncate(self, inputs: List[int], pad_id: int):
        if len(inputs) < self.max_len:
            padded_inputs = inputs + [pad_id] * (self.max_len - len(inputs))
        else:
            padded_inputs = inputs[:self.max_len]
        return torch.as_tensor(padded_inputs)

In [None]:
train_dataset = PosTaggingDataset(train_sentences, train_tags, tokenizer, label2id)  # Ensure this class is defined
val_dataset = PosTaggingDataset(valid_sentences, valid_tags, tokenizer, label2id)
test_dataset = PosTaggingDataset(test_sentences, test_tags, tokenizer, label2id)

In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification

model_name = "QCRI/bert-base-multilingual-cased-pos-english"
model = AutoModelForTokenClassification.from_pretrained(model_name)

Some weights of the model checkpoint at QCRI/bert-base-multilingual-cased-pos-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
accuracy = evaluate.load("accuracy")
ignore_label = len(label2id)  # Ensure label2id is defined earlier

# Define the compute_metrics function
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    mask = labels != ignore_label
    predictions = np.argmax(predictions, axis=-1)
    return accuracy.compute(predictions=predictions[mask], references=labels[mask])

In [None]:
from transformers import TrainingArguments , Trainer

# Set up training arguments
training_args = TrainingArguments(
    output_dir="out_dir",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.501574,0.088287
2,No log,0.376496,0.091833
3,0.764000,0.34743,0.092725
4,0.764000,0.330706,0.093038
5,0.764000,0.326093,0.093331
6,0.287200,0.319185,0.093577
7,0.287200,0.316724,0.09369
8,0.287200,0.317158,0.09367
9,0.226700,0.316924,0.093604
10,0.226700,0.316347,0.093663


TrainOutput(global_step=1720, training_loss=0.3979897144228913, metrics={'train_runtime': 1482.1532, 'train_samples_per_second': 18.48, 'train_steps_per_second': 1.16, 'total_flos': 3579882599208960.0, 'train_loss': 0.3979897144228913, 'epoch': 10.0})