In [1]:
!pip install conllu



In [4]:
import requests
from conllu import parse

# Source: https://universaldependencies.org/u/pos/index.html
ud_pos_tags = {
    "ADJ": "adjective",
    "ADP": "adposition",
    "ADV": "adverb",
    "AUX": "auxiliary",
    "CCONJ": "coordinating conjunction",
    "DET": "determiner",
    "INTJ": "interjection",
    "NOUN": "noun",
    "NUM": "numeral",
    "PART": "particle",
    "PRON": "pronoun",
    "PROPN": "proper noun",
    "PUNCT": "punctuation",
    "SCONJ": "subordinating conjunction",
    # "SYM": "symbol", # not used in PerDT https://universaldependencies.org/treebanks/fa_perdt/index.html
    "VERB": "verb",
    "X": "other"
}

training_set_url = "https://github.com/UniversalDependencies/UD_Persian-PerDT/blob/d728a98d73bd2e52b693c582f249e162a76b6b8f/fa_perdt-ud-train.conllu"
test_set_url = "https://github.com/UniversalDependencies/UD_Persian-PerDT/blob/d728a98d73bd2e52b693c582f249e162a76b6b8f/fa_perdt-ud-test.conllu"
dev_set_url = "https://github.com/UniversalDependencies/UD_Persian-PerDT/blob/d728a98d73bd2e52b693c582f249e162a76b6b8f/fa_perdt-ud-dev.conllu"

def get_raw_dataset(url):

    dataset_url = url.replace("github.com", "raw.githubusercontent.com").replace("/blob/", "/")
    response = requests.get(dataset_url)
    response.encoding = 'utf-8'
    data = response.text

    parsed_sentences = parse(data)

    sentences = []
    pos_tags = []

    for sentence in parsed_sentences:
        tokens = [token["form"] for token in sentence if token["form"] is not None]
        tags = [token["upostag"] for token in sentence if token["upostag"] is not None]
        sentences.append(tokens)
        pos_tags.append(tags)

    return sentences, pos_tags

raw_train_sentences, raw_train_pos_tags = get_raw_dataset(training_set_url)
raw_test_sentences, raw_test_pos_tags = get_raw_dataset(test_set_url)
raw_dev_sentences, raw_dev_pos_tags = get_raw_dataset(dev_set_url)

# Create label mappings
unique_labels = sorted(set(tag for tags in raw_train_pos_tags for tag in tags))
label2id = {label: i for i, label in enumerate(unique_labels)}
id2label = {i: label for label, i in label2id.items()}

print("Number of unique POS tags:", len(unique_labels))
print("Label to ID mapping:", label2id)
print("ID to label mapping:", id2label)

print("number of training sentences:", len(raw_train_sentences))
print("number of dev sentences:", len(raw_dev_sentences))
print("number of test sentences:", len(raw_test_sentences))



Number of unique POS tags: 17
Label to ID mapping: {'ADJ': 0, 'ADP': 1, 'ADV': 2, 'AUX': 3, 'CCONJ': 4, 'DET': 5, 'INTJ': 6, 'NOUN': 7, 'NUM': 8, 'PART': 9, 'PRON': 10, 'PROPN': 11, 'PUNCT': 12, 'SCONJ': 13, 'VERB': 14, 'X': 15, '_': 16}
ID to label mapping: {0: 'ADJ', 1: 'ADP', 2: 'ADV', 3: 'AUX', 4: 'CCONJ', 5: 'DET', 6: 'INTJ', 7: 'NOUN', 8: 'NUM', 9: 'PART', 10: 'PRON', 11: 'PROPN', 12: 'PUNCT', 13: 'SCONJ', 14: 'VERB', 15: 'X', 16: '_'}
number of training sentences: 26196
number of dev sentences: 1456
number of test sentences: 1455


In [3]:
def align_labels_with_tokens(words, labels, tokenizer):
    tokens = []
    label_ids = []

    for word, label in zip(words, labels):
        word_tokens = tokenizer.tokenize(word)
        tokens.extend(word_tokens)
        # Assign the label to first subtoken, and -100 (ignored) to others
        label_ids.extend([label2id[label]] + [-100] * (len(word_tokens) - 1))

    return tokenizer.convert_tokens_to_ids(tokens), label_ids


In [4]:
from sklearn.metrics import classification_report, f1_score, accuracy_score

def compute_metrics(p):
    preds, labels = p
    preds = preds.argmax(-1)

    true_preds = []
    true_labels = []

    for pred, label in zip(preds, labels):
        for p_i, l_i in zip(pred, label):
            if l_i != -100:
                true_preds.append(id2label[p_i])
                true_labels.append(id2label[l_i])

    print("\nTag-level classification report:")
    print(classification_report(true_labels, true_preds, labels=list(label2id.keys()), digits=3))

    return {
        "accuracy": accuracy_score(true_labels, true_preds),
        "f1_macro": f1_score(true_labels, true_preds, average="macro"),
        "f1_micro": f1_score(true_labels, true_preds, average="micro"),
    }


In [5]:

from transformers import AutoTokenizer
from sklearn.model_selection import train_test_split

model_ckpt = "shekar-ai/albert-base-v2-persian-zwnj-naab-mlm"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt, use_fast=True)

print("Tokenizer model max length:", tokenizer.model_max_length)

def tokenize_and_align_labels(sentences, pos_tags, tokenizer):
    """
    Tokenizes the input sentences and aligns the POS tags with the tokens.
    """
    encoded_inputs = []
    encoded_labels = []

    for words, labels in zip(sentences, pos_tags):
        input_ids, label_ids = align_labels_with_tokens(words, labels, tokenizer)
        encoded_inputs.append(input_ids)
        encoded_labels.append(label_ids)

    return encoded_inputs, encoded_labels

encoded_train_inputs, encoded_train_labels = tokenize_and_align_labels(
    raw_train_sentences, raw_train_pos_tags, tokenizer
)

encoded_test_inputs, encoded_test_labels = tokenize_and_align_labels(
    raw_test_sentences, raw_test_pos_tags, tokenizer
)

encoded_dev_inputs, encoded_dev_labels = tokenize_and_align_labels(
    raw_dev_sentences, raw_dev_pos_tags, tokenizer
)


Tokenizer model max length: 512


In [6]:
import torch
from torch.utils.data import Dataset

class POSDataset(Dataset):
    def __init__(self, input_ids, labels, tokenizer, max_len=128):
        self.input_ids = input_ids
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        input_id = self.input_ids[idx][:self.max_len]
        label = self.labels[idx][:self.max_len]
        attention_mask = [1] * len(input_id)

        # Padding
        pad_len = self.max_len - len(input_id)
        input_id += [self.tokenizer.pad_token_id] * pad_len
        label += [-100] * pad_len
        attention_mask += [0] * pad_len

        return {
            "input_ids": torch.tensor(input_id),
            "attention_mask": torch.tensor(attention_mask),
            "labels": torch.tensor(label),
        }
    
train_dataset = POSDataset(encoded_train_inputs, encoded_train_labels, tokenizer, max_len=tokenizer.model_max_length)
test_dataset = POSDataset(encoded_test_inputs, encoded_test_labels, tokenizer, max_len=tokenizer.model_max_length)
dev_dataset = POSDataset(encoded_dev_inputs, encoded_dev_labels, tokenizer, max_len=tokenizer.model_max_length)

In [None]:
from transformers import TrainingArguments, Trainer, DataCollatorForTokenClassification
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained(
    model_ckpt,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id
)

training_args = TrainingArguments(
    output_dir="./pos-albert",
    eval_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    logging_dir="./logs",
    save_strategy="epoch",
    save_total_limit=1,
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    greater_is_better=True,
    push_to_hub=True,
    report_to="tensorboard",
    hub_model_id="shekar-ai/albert-base-v2-persian-pos-tagger",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer),
    compute_metrics=compute_metrics,
)

trainer.train()
trainer.push_to_hub(commit_message="Training complete", blocking=True)
metrics = trainer.evaluate()
print(metrics)


Some weights of AlbertForTokenClassification were not initialized from the model checkpoint at shekar-ai/albert-base-v2-persian-zwnj-naab-mlm and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro,F1 Micro
1,0.1121,0.102692,0.969913,0.87841,0.969913
2,0.0726,0.081732,0.975316,0.902546,0.975316
3,0.0563,0.080625,0.976258,0.902575,0.976258
4,0.0396,0.088837,0.976135,0.906743,0.976135
5,0.028,0.105485,0.976462,0.905975,0.976462
6,0.0124,0.122969,0.976585,0.908196,0.976585
7,0.0078,0.138499,0.976299,0.906944,0.976299
8,0.0042,0.154378,0.975848,0.903952,0.975848
9,0.0017,0.163003,0.976258,0.902693,0.976258



Tag-level classification report:
              precision    recall  f1-score   support

         ADJ      0.943     0.925     0.934      1652
         ADP      0.993     0.991     0.992      3407
         ADV      0.843     0.883     0.863       384
         AUX      0.994     0.987     0.991       899
       CCONJ      0.998     0.999     0.999      1026
         DET      0.950     0.978     0.964       491
        INTJ      0.643     0.667     0.655        27
        NOUN      0.974     0.960     0.967      8219
         NUM      0.965     0.945     0.955       293
        PART      0.692     0.964     0.806        28
        PRON      0.976     0.988     0.982      1126
       PROPN      0.828     0.905     0.865      1111
       PUNCT      0.995     0.997     0.996      2141
       SCONJ      0.973     0.979     0.976       632
        VERB      0.996     0.994     0.995      2696
           X      0.000     0.000     0.000         1
           _      0.993     0.997     0.995    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Tag-level classification report:
              precision    recall  f1-score   support

         ADJ      0.936     0.942     0.939      1652
         ADP      0.994     0.995     0.994      3407
         ADV      0.907     0.914     0.911       384
         AUX      0.994     0.990     0.992       899
       CCONJ      0.996     0.999     0.998      1026
         DET      0.968     0.978     0.973       491
        INTJ      0.846     0.815     0.830        27
        NOUN      0.972     0.970     0.971      8219
         NUM      0.953     0.966     0.959       293
        PART      0.929     0.929     0.929        28
        PRON      0.993     0.987     0.990      1126
       PROPN      0.886     0.891     0.888      1111
       PUNCT      0.996     0.997     0.996      2141
       SCONJ      0.989     0.978     0.983       632
        VERB      0.996     0.995     0.996      2696
           X      0.000     0.000     0.000         1
           _      0.997     0.993     0.995    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Tag-level classification report:
              precision    recall  f1-score   support

         ADJ      0.941     0.936     0.938      1652
         ADP      0.992     0.994     0.993      3407
         ADV      0.928     0.904     0.916       384
         AUX      0.993     0.992     0.993       899
       CCONJ      0.996     0.998     0.997      1026
         DET      0.966     0.980     0.973       491
        INTJ      0.840     0.778     0.808        27
        NOUN      0.974     0.972     0.973      8219
         NUM      0.986     0.962     0.974       293
        PART      1.000     0.857     0.923        28
        PRON      0.987     0.988     0.987      1126
       PROPN      0.879     0.905     0.892      1111
       PUNCT      0.999     0.998     0.998      2141
       SCONJ      0.989     0.981     0.985       632
        VERB      0.995     0.998     0.996      2696
           X      0.000     0.000     0.000         1
           _      1.000     0.997     0.998    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Tag-level classification report:
              precision    recall  f1-score   support

         ADJ      0.938     0.940     0.939      1652
         ADP      0.991     0.992     0.992      3407
         ADV      0.926     0.914     0.920       384
         AUX      0.997     0.990     0.993       899
       CCONJ      0.997     0.998     0.998      1026
         DET      0.974     0.978     0.976       491
        INTJ      0.821     0.852     0.836        27
        NOUN      0.976     0.969     0.973      8219
         NUM      0.973     0.976     0.974       293
        PART      1.000     0.929     0.963        28
        PRON      0.992     0.989     0.991      1126
       PROPN      0.858     0.912     0.884      1111
       PUNCT      1.000     0.998     0.999      2141
       SCONJ      0.990     0.981     0.986       632
        VERB      0.995     0.998     0.996      2696
           X      0.000     0.000     0.000         1
           _      1.000     0.990     0.995    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Tag-level classification report:
              precision    recall  f1-score   support

         ADJ      0.935     0.941     0.938      1652
         ADP      0.992     0.993     0.993      3407
         ADV      0.904     0.930     0.917       384
         AUX      0.998     0.992     0.995       899
       CCONJ      0.995     0.997     0.996      1026
         DET      0.972     0.986     0.979       491
        INTJ      0.815     0.815     0.815        27
        NOUN      0.976     0.969     0.972      8219
         NUM      0.970     0.983     0.976       293
        PART      1.000     0.929     0.963        28
        PRON      0.993     0.993     0.993      1126
       PROPN      0.874     0.902     0.888      1111
       PUNCT      0.999     0.998     0.998      2141
       SCONJ      0.992     0.983     0.987       632
        VERB      0.996     0.997     0.997      2696
           X      0.000     0.000     0.000         1
           _      1.000     0.990     0.995    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Tag-level classification report:
              precision    recall  f1-score   support

         ADJ      0.943     0.935     0.939      1652
         ADP      0.994     0.992     0.993      3407
         ADV      0.915     0.924     0.920       384
         AUX      0.998     0.991     0.994       899
       CCONJ      0.996     0.997     0.997      1026
         DET      0.972     0.982     0.977       491
        INTJ      0.852     0.852     0.852        27
        NOUN      0.974     0.971     0.972      8219
         NUM      0.970     0.980     0.975       293
        PART      1.000     0.929     0.963        28
        PRON      0.992     0.991     0.992      1126
       PROPN      0.873     0.907     0.890      1111
       PUNCT      0.999     0.999     0.999      2141
       SCONJ      0.990     0.979     0.985       632
        VERB      0.996     0.998     0.997      2696
           X      0.000     0.000     0.000         1
           _      1.000     0.993     0.997    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Tag-level classification report:
              precision    recall  f1-score   support

         ADJ      0.939     0.939     0.939      1652
         ADP      0.992     0.992     0.992      3407
         ADV      0.917     0.922     0.919       384
         AUX      0.999     0.992     0.996       899
       CCONJ      0.997     0.998     0.998      1026
         DET      0.968     0.976     0.972       491
        INTJ      0.821     0.852     0.836        27
        NOUN      0.976     0.969     0.972      8219
         NUM      0.976     0.980     0.978       293
        PART      1.000     0.929     0.963        28
        PRON      0.991     0.988     0.990      1126
       PROPN      0.865     0.913     0.888      1111
       PUNCT      0.999     0.999     0.999      2141
       SCONJ      0.990     0.978     0.984       632
        VERB      0.996     0.999     0.997      2696
           X      0.000     0.000     0.000         1
           _      0.997     0.993     0.995    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Tag-level classification report:
              precision    recall  f1-score   support

         ADJ      0.934     0.937     0.935      1652
         ADP      0.993     0.991     0.992      3407
         ADV      0.916     0.911     0.914       384
         AUX      0.999     0.990     0.994       899
       CCONJ      0.995     0.997     0.996      1026
         DET      0.970     0.980     0.975       491
        INTJ      0.786     0.815     0.800        27
        NOUN      0.974     0.970     0.972      8219
         NUM      0.963     0.973     0.968       293
        PART      1.000     0.929     0.963        28
        PRON      0.994     0.990     0.992      1126
       PROPN      0.871     0.906     0.888      1111
       PUNCT      0.999     1.000     0.999      2141
       SCONJ      0.990     0.979     0.985       632
        VERB      0.996     0.997     0.996      2696
           X      0.000     0.000     0.000         1
           _      1.000     0.993     0.997    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Tag-level classification report:
              precision    recall  f1-score   support

         ADJ      0.939     0.934     0.937      1652
         ADP      0.992     0.992     0.992      3407
         ADV      0.911     0.930     0.920       384
         AUX      0.999     0.990     0.994       899
       CCONJ      0.996     0.997     0.997      1026
         DET      0.970     0.976     0.973       491
        INTJ      0.750     0.778     0.764        27
        NOUN      0.976     0.969     0.972      8219
         NUM      0.966     0.976     0.971       293
        PART      1.000     0.929     0.963        28
        PRON      0.992     0.992     0.992      1126
       PROPN      0.868     0.913     0.890      1111
       PUNCT      0.998     1.000     0.999      2141
       SCONJ      0.992     0.983     0.987       632
        VERB      0.996     0.999     0.997      2696
           X      0.000     0.000     0.000         1
           _      1.000     0.997     0.998    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [12]:
from shekar import WordTokenizer
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
model = AutoModelForTokenClassification.from_pretrained("shekar-ai/albert-base-v2-persian-pos-tagger")
tokenizer = AutoTokenizer.from_pretrained("shekar-ai/albert-base-v2-persian-pos-tagger")

label2id = model.config.label2id
id2label = model.config.id2label

word_tokenizer = WordTokenizer()
test_sentence = list(word_tokenizer("این یک متن نمونه است."))
print("Test sentence:", test_sentence)
model.eval()

# Tokenize with word-level alignment
tokens = []
word_ids = []
for word in test_sentence:
    word_tokens = tokenizer.tokenize(word)
    tokens.extend(word_tokens)
    word_ids.extend([word] * len(word_tokens))  # repeat word for each sub-token

# Convert to IDs
input_ids = tokenizer.convert_tokens_to_ids(tokens)
attention_mask = [1] * len(input_ids)

# Pad if necessary
max_len = tokenizer.model_max_length
pad_len = max_len - len(input_ids)
input_ids += [tokenizer.pad_token_id] * pad_len
attention_mask += [0] * pad_len

input_tensor = {
    "input_ids": torch.tensor([input_ids]),
    "attention_mask": torch.tensor([attention_mask]),
}

with torch.no_grad():
    outputs = model(**input_tensor)
    predictions = outputs.logits.argmax(dim=-1).squeeze().tolist()


# Remove padding and subword duplicates
final_preds = []
prev_word = None
for token, word, pred_id in zip(tokens, word_ids, predictions):
    if word != prev_word:
        final_preds.append(id2label[pred_id])
        prev_word = word

# Display results
for word, tag in zip(test_sentence, final_preds):
    print(f"{word}\t{tag}")


Test sentence: ['این', 'یک', 'متن', 'نمونه', 'است', '.']
این	PRON
یک	NUM
متن	NOUN
نمونه	ADJ
است	AUX
.	PUNCT
