In [2]:
# ======================================
# BiLSTM POS Tagger for Arabic (PyTorch)
# ======================================

!pip install -q torch

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import numpy as np

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)

# One file per data split, all parsed by read_conll.
TRAIN_PATH, DEV_PATH, TEST_PATH = "train.conll", "dev.conll", "test.conll"

# Reserved vocabulary entries: padding / unknown on the word side,
# padding on the tag side.
PAD_WORD, UNK_WORD, PAD_TAG = "<PAD>", "<UNK>", "<PAD_TAG>"


# ---------------------------
# 1) Read .conll data
# ---------------------------
def read_conll(path):
    """Parse a two-column, tab-separated token/tag file into sentences.

    The file is read as UTF-8. Every line is stripped of surrounding
    whitespace first; a line that is empty after stripping closes the
    current sentence. A non-empty line that does not split into exactly
    two tab-separated fields is ignored and does not close the sentence.

    Args:
        path: Location of the file to read.

    Returns:
        A list of ``(tokens, tags)`` pairs, one per non-empty sentence,
        where both members are equally long lists of strings.
    """
    corpus = []
    cur_words, cur_tags = [], []

    def flush():
        # Store the sentence collected so far (if any) and start a new one.
        nonlocal cur_words, cur_tags
        if cur_words:
            corpus.append((cur_words, cur_tags))
            cur_words, cur_tags = [], []

    with open(path, encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped:
                fields = stripped.split("\t")
                if len(fields) == 2:
                    cur_words.append(fields[0])
                    cur_tags.append(fields[1])
            else:
                flush()

    # The file may end without a trailing blank line.
    flush()
    return corpus

# Load the three splits; each is a list of (tokens, tags) pairs as
# returned by read_conll.
train_sents = read_conll(TRAIN_PATH)
dev_sents   = read_conll(DEV_PATH)
test_sents  = read_conll(TEST_PATH)

# Sanity check: number of sentences found in each split.
print("#train sentences:", len(train_sents))
print("#dev sentences:", len(dev_sents))
print("#test sentences:", len(test_sents))


# ---------------------------
# 2) Build vocabularies
# ---------------------------
# Word counts and the tag inventory are taken from the training split only.
word_freq = Counter(w for sent_words, _ in train_sents for w in sent_words)
tag_set = {t for _, sent_tags in train_sents for t in sent_tags}

# Frequency cutoff for the word vocabulary. With 1, every training word
# gets its own id, so no training token is ever mapped to UNK_WORD;
# raise it to 2 to send singletons to UNK_WORD instead.
min_freq = 1
words = [word for word, count in word_freq.items() if count >= min_freq]

# Ids 0 and 1 are reserved; real words follow in first-seen order.
word2id = {PAD_WORD: 0, UNK_WORD: 1}
for word in words:
    word2id[word] = len(word2id)

id2word = {idx: word for word, idx in word2id.items()}

# Id 0 is the padding tag; real tags follow in alphabetical order.
tag2id = {PAD_TAG: 0}
for tag in sorted(tag_set):
    tag2id[tag] = len(tag2id)
id2tag = {idx: tag for tag, idx in tag2id.items()}

vocab_size, num_tags = len(word2id), len(tag2id)
pad_word_id, unk_word_id = word2id[PAD_WORD], word2id[UNK_WORD]
pad_tag_id = tag2id[PAD_TAG]

print("Vocab size:", vocab_size)
print("Num tags:", num_tags)
print("Tags:", tag2id)


# ---------------------------
# 3) Dataset & DataLoader
# ---------------------------
class PosDataset(Dataset):
    """Sentences converted once, up front, to ``(word_ids, tag_ids)`` pairs.

    Args:
        sentences: Iterable of ``(tokens, tags)`` pairs of string lists.
        word2id: Mapping from word to integer id. Words missing from it
            fall back to the module-level ``unk_word_id``.
        tag2id: Mapping from tag to integer id. A tag missing from it
            raises ``KeyError``.
    """

    def __init__(self, sentences, word2id, tag2id):
        self.data = [
            (
                [word2id.get(tok, unk_word_id) for tok in sent_tokens],
                [tag2id[label] for label in sent_tags],
            )
            for sent_tokens, sent_tags in sentences
        ]

    def __len__(self):
        """Return the number of stored sentences."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(word_ids, tag_ids)`` pair stored at ``idx``."""
        return self.data[idx]

def collate_fn(batch):
    """Right-pad a list of ``(word_ids, tag_ids)`` pairs into batch tensors.

    Every sequence is padded to the longest word sequence in the batch,
    using the module-level ``pad_word_id`` for words and ``pad_tag_id``
    for tags.

    Args:
        batch: Non-empty list of ``(word_ids, tag_ids)`` pairs of int lists.

    Returns:
        A ``(words, tags, lengths)`` tuple of ``torch.long`` tensors:
        ``words`` and ``tags`` are ``(batch, max_len)``, and ``lengths``
        holds the unpadded length of each word sequence.
    """
    seq_lens = [len(word_ids) for word_ids, _ in batch]
    width = max(seq_lens)

    padded_words = [
        word_ids + [pad_word_id] * (width - len(word_ids))
        for word_ids, _ in batch
    ]
    padded_tags = [
        tag_ids + [pad_tag_id] * (width - len(tag_ids))
        for _, tag_ids in batch
    ]

    return (
        torch.tensor(padded_words, dtype=torch.long),
        torch.tensor(padded_tags, dtype=torch.long),
        torch.tensor(seq_lens, dtype=torch.long),
    )

# Number of sentences per mini-batch, shared by all three loaders.
batch_size = 32

# Convert every split to id sequences using the vocabularies built from
# the training split.
train_dataset = PosDataset(train_sents, word2id, tag2id)
dev_dataset   = PosDataset(dev_sents, word2id, tag2id)
test_dataset  = PosDataset(test_sents, word2id, tag2id)

# Only the training loader shuffles; dev and test keep dataset order.
# collate_fn pads each batch to the length of its own longest sentence.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          collate_fn=collate_fn)
dev_loader   = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False,
                          collate_fn=collate_fn)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                          collate_fn=collate_fn)


# ---------------------------
# 4) BiLSTM model
# ---------------------------
class BiLSTMTagger(nn.Module):
    """Embedding -> bidirectional LSTM -> per-token linear layer over tags.

    Args:
        vocab_size: Number of rows in the embedding table.
        tagset_size: Number of output scores per token.
        emb_dim: Word-embedding size.
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, tagset_size, emb_dim=128, hidden_dim=256, num_layers=1):
        super().__init__()
        # The padding row is the module-level pad_word_id.
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_word_id)
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Forward and backward states are concatenated, hence 2 * hidden_dim.
        self.fc = nn.Linear(2 * hidden_dim, tagset_size)

    def forward(self, x, lengths):
        """Score every tag for every token.

        Args:
            x: ``(batch, seq_len)`` tensor of word ids.
            lengths: Tensor with the unpadded length of each row of ``x``;
                it is moved to the CPU before packing.

        Returns:
            ``(batch, longest_length, tagset_size)`` tensor of logits,
            where ``longest_length`` is the largest value in ``lengths``.
        """
        rnn_utils = nn.utils.rnn

        embedded = self.embedding(x)  # (batch, seq_len, emb_dim)

        # Pack so the LSTM skips padded positions; rows need not be sorted.
        packed_in = rnn_utils.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_states, _ = self.lstm(packed_in)
        states, _ = rnn_utils.pad_packed_sequence(packed_states, batch_first=True)

        return self.fc(states)

# Tagger with the class defaults (emb_dim=128, hidden_dim=256, one layer),
# moved to DEVICE.
model = BiLSTMTagger(vocab_size, num_tags).to(DEVICE)
# Target positions equal to pad_tag_id are ignored by the loss.
criterion = nn.CrossEntropyLoss(ignore_index=pad_tag_id)
# Adam over all model parameters with learning rate 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


# ---------------------------
# 5) Training & eval functions
# ---------------------------
def train_epoch(model, loader):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Uses the module-level ``DEVICE``, ``criterion`` and ``optimizer``.

    Args:
        model: Module called as ``model(words, lengths)`` that returns
            ``(batch, seq_len, num_tags)`` logits.
        loader: Iterable of ``(words, tags, lengths)`` tensor batches. It
            does not need to support ``len()``.

    Returns:
        The average of the per-batch loss values as a float, or ``0.0``
        when ``loader`` yields no batches (mirroring ``eval_accuracy``,
        which also returns ``0.0`` when there is nothing to score).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for words, tags, lengths in loader:
        words = words.to(DEVICE)
        tags = tags.to(DEVICE)
        lengths = lengths.to(DEVICE)

        optimizer.zero_grad()
        logits = model(words, lengths)  # (batch, seq_len, num_tags)

        # Flatten to one row per token; the class dimension is read from
        # the logits themselves rather than from a global.
        loss = criterion(
            logits.view(-1, logits.size(-1)),
            tags.view(-1)
        )
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Batches are counted here so an empty loader cannot divide by zero.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

def eval_accuracy(model, loader):
    """Return token-level accuracy of ``model`` over ``loader``.

    Uses the module-level ``DEVICE`` and ``pad_tag_id``. Positions whose
    gold tag is ``pad_tag_id`` are excluded from both the numerator and
    the denominator; a prediction of ``pad_tag_id`` on a real token
    simply counts as wrong.

    Args:
        model: Module called as ``model(words, lengths)`` that returns
            ``(batch, seq_len, num_tags)`` logits.
        loader: Iterable of ``(words, tags, lengths)`` tensor batches.

    Returns:
        ``correct / total`` as a float, or ``0.0`` when no non-padding
        token was seen.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for words, tags, lengths in loader:
            words = words.to(DEVICE)
            tags = tags.to(DEVICE)
            lengths = lengths.to(DEVICE)

            preds = model(words, lengths).argmax(dim=-1)  # (batch, seq_len)

            # Count on the tensors directly instead of looping over every
            # token in Python; the mask keeps only real (non-padding) tokens.
            is_real = tags != pad_tag_id
            correct += int(((preds == tags) & is_real).sum().item())
            total += int(is_real.sum().item())
    return correct / total if total > 0 else 0.0


# ---------------------------
# 6) Train loop
# ---------------------------
# Number of full passes over the training loader.
num_epochs = 5

# After each epoch, report the mean training loss and the token accuracy
# on the dev split.
for epoch in range(1, num_epochs + 1):
    train_loss = train_epoch(model, train_loader)
    dev_acc = eval_accuracy(model, dev_loader)
    print(f"Epoch {epoch}: train_loss={train_loss:.4f}, dev_acc={dev_acc:.4f}")

# Final test accuracy, measured once with the weights from the last epoch
# (dev accuracy is only reported, not used to pick a checkpoint).
test_acc = eval_accuracy(model, test_loader)
print("Test accuracy (BiLSTM):", test_acc)


# ---------------------------
# 7) Predict on one sentence
# ---------------------------
def predict_sentence_bilstm(tokens):
    """Tag one tokenised sentence with the module-level ``model``.

    Uses the module-level ``model``, ``word2id``, ``unk_word_id``,
    ``id2tag`` and ``DEVICE``. Words missing from ``word2id`` are mapped
    to ``unk_word_id``.

    Args:
        tokens: List of word strings.

    Returns:
        A list of ``(token, predicted_tag)`` pairs in input order, or an
        empty list when ``tokens`` is empty.
    """
    # A zero-length sequence cannot be packed for the LSTM, so there is
    # nothing to run the model on.
    if not tokens:
        return []

    model.eval()
    w_ids = [word2id.get(w, unk_word_id) for w in tokens]
    x = torch.tensor([w_ids], dtype=torch.long).to(DEVICE)
    length = torch.tensor([len(w_ids)], dtype=torch.long).to(DEVICE)

    with torch.no_grad():
        logits = model(x, length)
        # Plain Python ints, so id2tag is indexed with the same key type
        # it was built with.
        pred_ids = logits.argmax(dim=-1)[0].tolist()

    return [(tok, id2tag[tag_id]) for tok, tag_id in zip(tokens, pred_ids)]

# Smoke test on a short, pre-tokenised Arabic sentence; prints the list of
# (token, predicted tag) pairs.
example = ["سوريا", "تستقبل", "وفدا", "رسميا", "."]
print("Example prediction:", predict_sentence_bilstm(example))


Using device: cpu
#train sentences: 6075
#dev sentences: 909
#test sentences: 680
Vocab size: 21917
Num tags: 18
Tags: {'<PAD_TAG>': 0, 'ADJ': 1, 'ADP': 2, 'ADV': 3, 'AUX': 4, 'CCONJ': 5, 'DET': 6, 'INTJ': 7, 'NOUN': 8, 'NUM': 9, 'PART': 10, 'PRON': 11, 'PROPN': 12, 'PUNCT': 13, 'SCONJ': 14, 'SYM': 15, 'VERB': 16, 'X': 17}
Epoch 1: train_loss=0.9562, dev_acc=0.8173
Epoch 2: train_loss=0.4176, dev_acc=0.8729
Epoch 3: train_loss=0.2734, dev_acc=0.8942
Epoch 4: train_loss=0.1894, dev_acc=0.9033
Epoch 5: train_loss=0.1304, dev_acc=0.9115
Test accuracy (BiLSTM): 0.9172091706764789
Example prediction: [('سوريا', 'X'), ('تستقبل', 'VERB'), ('وفدا', 'NOUN'), ('رسميا', 'ADJ'), ('.', 'PUNCT')]
