# Seq-to-Seq Approach

a - loading libs

In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, recall_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import nltk
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import DataLoader
from sklearn.preprocessing import LabelEncoder
import re

In [None]:
!pip install --upgrade nltk



In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
device

device(type='cuda')

b - loadidng parsBERT for tokenization and embedding

In [None]:
tokenizer = AutoTokenizer.from_pretrained("HooshvareLab/bert-base-parsbert-uncased")
bert_model = AutoModel.from_pretrained("HooshvareLab/bert-base-parsbert-uncased").to(device)



c - load daata

In [None]:
class SimpleSpaceTokenizer:
    def __init__(self):
        self.token2id = {}
        self.id2token = {}
        self.vocab_size = 0

    def fit_on_texts(self, texts):
        unique_tokens = set()
        for text in texts:
            tokens = text.split(" ")
            unique_tokens.update(tokens)

        self.token2id = {token: idx for idx, token in enumerate(unique_tokens, start=1)}
        self.id2token = {idx: token for token, idx in self.token2id.items()}
        self.vocab_size = len(self.token2id) + 1  # Adding 1 for padding token

    def tokenize(self, texts, max_length=48):
        tokenized_texts = []
        for text in texts:
            tokens = text.split(" ")
            token_ids = [self.token2id.get(token, 0) for token in tokens][:max_length]
            padding_length = max_length - len(token_ids)
            token_ids += [0] * padding_length
            tokenized_texts.append(token_ids)
        return torch.tensor(tokenized_texts)

    def decode(self, token_ids):
        return " ".join([self.id2token.get(token_id, "") for token_id in token_ids if token_id != 0])

In [None]:
train_data = pd.read_csv(f'Poem Meter Dataset/train_samples.csv')

poem_text = train_data['poem_text']
metre = train_data['metre'].astype(str)

inputs = tokenizer(poem_text.tolist(), padding=True, truncation=True, return_tensors="pt", max_length=14)
input_ids = inputs['input_ids'].squeeze().to(device)
attention_mask = inputs['attention_mask'].squeeze().to(device)

label_tokenizer = SimpleSpaceTokenizer()
label_tokenizer.fit_on_texts(metre.tolist())
labels = label_tokenizer.tokenize(metre.tolist(), max_length=6).to(device)

train_loader = DataLoader(torch.utils.data.TensorDataset(input_ids, attention_mask, labels), batch_size=512, shuffle=True)

In [None]:
val_data = pd.read_csv(f'Poem Meter Dataset/validation_samples.csv')

val_poem_text = val_data['poem_text']
val_metre = val_data['metre'].astype(str)

val_inputs = tokenizer(val_poem_text.tolist(), padding=True, truncation=True, return_tensors="pt", max_length=14)
val_input_ids = val_inputs['input_ids'].squeeze().to(device)
val_attention_mask = val_inputs['attention_mask'].squeeze().to(device)
val_labels = label_tokenizer.tokenize(val_metre.tolist(), max_length=6).to(device)

val_loader = DataLoader(torch.utils.data.TensorDataset(val_input_ids, val_attention_mask, val_labels), batch_size=512, shuffle=True)

In [None]:
test_data = pd.read_csv(f'Poem Meter Dataset/test_samples.csv')

test_poem_text = test_data['poem_text']

test_inputs = tokenizer(test_poem_text.tolist(), padding=True, truncation=True, return_tensors="pt", max_length=14)

test_input_ids = test_inputs['input_ids'].squeeze().to(device)
test_attention_mask = test_inputs['attention_mask'].squeeze().to(device)

test_loader = DataLoader(torch.utils.data.TensorDataset(test_input_ids, test_attention_mask), batch_size=512, shuffle=True)

d - Model Architecture

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Encoder(nn.Module):  # Bi-LSTM
    def __init__(self, bert_model, hidden_size, num_layers):
        super(Encoder, self).__init__()
        self.bert = bert_model
        self.bi_lstm = nn.LSTM(bert_model.config.hidden_size, hidden_size, num_layers, batch_first=True, bidirectional=True)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            bert_embedding_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]
        outputs, (hidden, cell) = self.bi_lstm(bert_embedding_outputs)
        return outputs, hidden, cell

In [None]:
class Attention(nn.Module):
    def __init__(self, encoder_hidden_size, decoder_hidden_size, method='general'):
        super(Attention, self).__init__()
        self.method = method
        self.encoder_hidden_size = encoder_hidden_size * 2  # Bi-directional
        self.decoder_hidden_size = decoder_hidden_size

        if self.method == 'general':
            self.attn = nn.Linear(self.encoder_hidden_size, decoder_hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.encoder_hidden_size + decoder_hidden_size, decoder_hidden_size)
            self.v = nn.Parameter(torch.rand(decoder_hidden_size))

    def forward(self, hidden, encoder_outputs):
        if self.method == 'general':
            hidden = hidden[-1].unsqueeze(1)
            logits = torch.bmm(self.attn(encoder_outputs), hidden.transpose(1, 2)).squeeze(2)
        elif self.method == 'concat':
            hidden = hidden[-1].expand(encoder_outputs.shape[0], -1, -1)
            logits = torch.sum(self.v * torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), 2))), dim=2)
        return F.softmax(logits, dim=1)

In [None]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, encoder_hidden_size, decoder_hidden_size, num_layers, attention_method='general'):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, decoder_hidden_size, num_layers, batch_first=True)
        self.w = nn.Linear(decoder_hidden_size + encoder_hidden_size * 2, decoder_hidden_size)
        self.attention = Attention(encoder_hidden_size, decoder_hidden_size, attention_method)
        self.fc = nn.Linear(decoder_hidden_size, vocab_size)

    def forward(self, inputs, hidden, cell, encoder_outputs):
        inputs = inputs.unsqueeze(1)
        embedding = self.embedding(inputs)
        outputs, (hidden, cell) = self.lstm(embedding, (hidden, cell))
        attention_weights = self.attention(hidden, encoder_outputs)
        context_vec = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        concat_input = torch.cat((outputs, context_vec), dim=2)
        cats = torch.tanh(self.w(concat_input))
        pred = self.fc(cats.squeeze(1))
        return pred, hidden, cell

e - training  and evaluation

In [None]:
hidden_size = 256
num_layers_encoder = 2
num_layers_decoder = 4
output_dim = label_tokenizer.vocab_size
embed_dim = hidden_size

In [None]:
encoder = Encoder(bert_model, hidden_size, num_layers_encoder).to(device)
decoder = Decoder(output_dim, embed_dim, hidden_size, hidden_size, num_layers_decoder).to(device)

encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.001, weight_decay=0.01)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=0.001, weight_decay=0.01)

criterion = nn.CrossEntropyLoss()

In [None]:
import random
def train_epoch(encoder, decoder, dataloader, val_loader, encoder_optimizer, decoder_optimizer, criterion, teacher_forcing_ratio = 0.95):
    encoder.train()
    decoder.train()
    epoch_loss = 0
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 3

    for i, batch in enumerate(dataloader):
        input_ids, attention_mask, labels = [x.to(device) for x in batch]

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, hidden, cell = encoder(input_ids, attention_mask)

        decoder_input = torch.zeros(labels.size(0), dtype=torch.long, device=device)  # Start token
        loss = 0

        for t in range(0, labels.size(1)):
            output, hidden, cell = decoder(decoder_input, hidden, cell, encoder_outputs)
            loss += criterion(output, labels[:, t])
            # teacher forcing ....
            if random.random() < teacher_forcing_ratio:
                decoder_input = labels[:, t]
            else:
                decoder_input = output.argmax(1)

        # teacher forcing scheduler ...
        teacher_forcing_ratio = max(0.03, teacher_forcing_ratio - 0.001)

        loss.backward()
        encoder_optimizer.step()
        decoder_optimizer.step()

        epoch_loss += loss.item() / labels.size(1)
        print(f'batch {i}/{len(dataloader)} ,loss: {loss:.4f}')

        # Early stopping and validation check
        if i == (len(dataloader) - 1):
            total_val_loss = 0
            num_val_batches = 0

            encoder.eval()
            decoder.eval()
            with torch.no_grad():
                for val_batch in val_loader:
                    input_ids, attention_mask, labels = [x.to(device) for x in val_batch]
                    encoder_outputs, hidden, cell = encoder(input_ids, attention_mask)
                    decoder_input = torch.zeros(labels.size(0), dtype=torch.long, device=device)
                    val_loss = 0

                    for t in range(0, labels.size(1)):
                        output, hidden, cell = decoder(decoder_input, hidden, cell, encoder_outputs)
                        val_loss += criterion(output, labels[:, t])
                        # decoder_input = labels[:, t]
                        decoder_input = output.argmax(1)

                    total_val_loss += val_loss.item() / labels.size(1)
                    num_val_batches += 1

            avg_val_loss = total_val_loss / num_val_batches if num_val_batches > 0 else float('inf')
            print(f'Batch {i + 1}, Validation Loss: {avg_val_loss:.4f}')

            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                patience_counter = 0
                torch.save({'encoder': encoder.state_dict(), 'decoder': decoder.state_dict()}, 'best_model.pth')
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f'Early stopping triggered during training at batch {i + 1} in epoch.')
                return epoch_loss / len(dataloader)

            encoder.train()
            decoder.train()

    return epoch_loss / len(dataloader)

Train and eval

In [None]:
n_epochs = 20
for epoch in range(n_epochs):
    train_loss = train_epoch(encoder, decoder, train_loader, val_loader, encoder_optimizer, decoder_optimizer, criterion)
    print(f'Epoch {epoch+1}, Training Loss: {train_loss:.4f}')

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
batch 431/1464 ,loss: 3.8973
batch 432/1464 ,loss: 4.0466
batch 433/1464 ,loss: 4.4828
batch 434/1464 ,loss: 3.9284
batch 435/1464 ,loss: 4.0522
batch 433/1464 ,loss: 4.4828
batch 434/1464 ,loss: 3.9284
batch 435/1464 ,loss: 4.0522
batch 436/1464 ,loss: 4.1892
batch 437/1464 ,loss: 3.8183
batch 438/1464 ,loss: 4.0886
batch 436/1464 ,loss: 4.1892
batch 437/1464 ,loss: 3.8183
batch 438/1464 ,loss: 4.0886
batch 439/1464 ,loss: 4.2227
batch 440/1464 ,loss: 4.1994
batch 441/1464 ,loss: 3.9060
batch 439/1464 ,loss: 4.2227
batch 440/1464 ,loss: 4.1994
batch 441/1464 ,loss: 3.9060
batch 442/1464 ,loss: 4.1929
batch 443/1464 ,loss: 4.1243
batch 444/1464 ,loss: 3.9342
batch 442/1464 ,loss: 4.1929
batch 443/1464 ,loss: 4.1243
batch 444/1464 ,loss: 3.9342
batch 445/1464 ,loss: 4.0377
batch 446/1464 ,loss: 3.9911
batch 447/1464 ,loss: 3.9671
batch 445/1464 ,loss: 4.0377
batch 446/1464 ,loss: 3.9911
batch 447/1464 ,loss: 3.9671
batch 4

In [57]:
def evaluation(encoder, decoder, dataloader):
    preds = []
    true_labels = []

    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        for j, val_batch in enumerate(dataloader):
            print(f'batch: {j} / {len(dataloader)}')
            input_ids, attention_mask, labels = [x.to(device) for x in val_batch]
            encoder_outputs, hidden, cell = encoder(input_ids, attention_mask)

            batch_size = input_ids.size(0)
            seq_length = labels.size(1)

            decoder_input = torch.zeros(batch_size, dtype=torch.long, device=device)
            batch_preds = [[] for _ in range(batch_size)]

            hidden = hidden.contiguous()
            cell = cell.contiguous()

            for t in range(0, seq_length):
                output, hidden, cell = decoder(decoder_input, hidden, cell, encoder_outputs)
                decoder_input = output.argmax(1)

                for i in range(batch_size):
                    batch_preds[i].append(decoder_input[i].item())

            preds.extend(batch_preds)
            true_labels.extend(labels[:, 0:].tolist())
    return preds, true_labels

In [None]:
val_preds, val_true_labels = evaluation(encoder, decoder, val_loader)

batch: 0 / 84
batch: 1 / 84
batch: 0 / 84
batch: 1 / 84
batch: 2 / 84
batch: 3 / 84
batch: 2 / 84
batch: 3 / 84
batch: 4 / 84
batch: 5 / 84
batch: 4 / 84
batch: 5 / 84
batch: 6 / 84
batch: 7 / 84
batch: 6 / 84
batch: 7 / 84
batch: 8 / 84
batch: 9 / 84
batch: 8 / 84
batch: 9 / 84
batch: 10 / 84
batch: 11 / 84
batch: 10 / 84
batch: 11 / 84
batch: 12 / 84
batch: 13 / 84
batch: 12 / 84
batch: 13 / 84
batch: 14 / 84
batch: 15 / 84
batch: 14 / 84
batch: 15 / 84
batch: 16 / 84
batch: 17 / 84
batch: 16 / 84
batch: 17 / 84
batch: 18 / 84
batch: 19 / 84
batch: 18 / 84
batch: 19 / 84
batch: 20 / 84
batch: 20 / 84
batch: 21 / 84
batch: 22 / 84
batch: 21 / 84
batch: 22 / 84
batch: 23 / 84
batch: 24 / 84
batch: 23 / 84
batch: 24 / 84
batch: 25 / 84
batch: 26 / 84
batch: 25 / 84
batch: 26 / 84
batch: 27 / 84
batch: 28 / 84
batch: 27 / 84
batch: 28 / 84
batch: 29 / 84
batch: 30 / 84
batch: 29 / 84
batch: 30 / 84
batch: 31 / 84
batch: 32 / 84
batch: 31 / 84
batch: 32 / 84
batch: 33 / 84
batch: 34 / 84


In [None]:
val_preds = np.array(val_preds)
val_true_labels = np.array(val_true_labels)

In [None]:
val_pred_decoded = [label_tokenizer.decode(pred) for pred in val_preds]
val_true_labels_decoded = [label_tokenizer.decode(label) for label in val_true_labels]

In [None]:
val_true_labels[0:10]

array([[ 7,  4, 11,  0,  0,  0],
       [ 3,  3,  3,  1,  0,  0],
       [10, 10,  3,  0,  0,  0],
       [ 4,  7,  4, 11,  0,  0],
       [10, 10,  3,  0,  0,  0],
       [10, 10,  3,  0,  0,  0],
       [10, 10,  3,  0,  0,  0],
       [ 7,  4, 11,  0,  0,  0],
       [13, 13, 13, 13,  0,  0],
       [10, 10,  3,  0,  0,  0]])

array([[ 7,  4, 11,  0,  0,  0],
       [ 3,  3,  3,  1,  0,  0],
       [10, 10,  3,  0,  0,  0],
       [ 4,  7,  4, 11,  0,  0],
       [10, 10,  3,  0,  0,  0],
       [10, 10,  3,  0,  0,  0],
       [10, 10,  3,  0,  0,  0],
       [ 7,  4, 11,  0,  0,  0],
       [13, 13, 13, 13,  0,  0],
       [10, 10,  3,  0,  0,  0]])

Results

In [None]:
for i in range(0,20):
    print(f'val_pred: {val_pred_decoded[i]}')
    print(f'val_true_label: {val_true_labels_decoded[i]}')
    print('--------------------------------------------')

val_pred: فاعلاتن فاعلاتن فاعلن
val_true_label: فعلاتن مفاعلن فعلن
--------------------------------------------
val_pred: فعولن فعولن فعولن فعل
val_true_label: فعولن فعولن فعولن فعل
--------------------------------------------
val_pred: مفاعیلن مفاعیلن فعولن
val_true_label: مفاعیلن مفاعیلن فعولن
--------------------------------------------
val_pred: فعلاتن فعلاتن فعلاتن فعلن
val_true_label: مفاعلن فعلاتن مفاعلن فعلن
--------------------------------------------
val_pred: مفعول مفاعیل مفاعیل
val_true_label: مفاعیلن مفاعیلن فعولن
--------------------------------------------
val_pred: مفاعیلن مفاعیلن فعولن
val_true_label: مفاعیلن مفاعیلن فعولن
--------------------------------------------
val_pred: فعولن فعولن فعولن فعل
val_true_label: مفاعیلن مفاعیلن فعولن
--------------------------------------------
val_pred: فعلاتن مفاعلن فعلن
val_true_label: فعلاتن مفاعلن فعلن
--------------------------------------------
val_pred: فاعلاتن فاعلاتن فاعلاتن فاعلن
val_true_label: مستفعلن مستفعلن مستفعلن مست

Metrics

In [None]:
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
import numpy as np

val_preds = np.array(val_preds)
val_true_labels = np.array(val_true_labels)

val_preds_flat = val_preds.ravel()
val_true_labels_flat = val_true_labels.ravel()

accuracy = accuracy_score(val_true_labels_flat, val_preds_flat)
f1 = f1_score(val_true_labels_flat, val_preds_flat, average='macro', zero_division=1)
recall = recall_score(val_true_labels_flat, val_preds_flat, average='macro', zero_division=1)
precision = precision_score(val_true_labels_flat, val_preds_flat, average='macro', zero_division=1)

print(f"Accuracy: {accuracy:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Recall: {recall:.4f}")
print(f"Precision: {precision:.4f}")

Accuracy: 0.7590
F1 Score: 0.4108
Recall: 0.4101
Precision: 0.6879
Accuracy: 0.7590
F1 Score: 0.4108
Recall: 0.4101
Precision: 0.6879


In [None]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.4.3-py3-none-any.whl.metadata (19 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.8-py3-none-any.whl.metadata (5.2 kB)
Collecting torchmetrics
  Downloading torchmetrics-1.4.3-py3-none-any.whl.metadata (19 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.11.8-py3-none-any.whl.metadata (5.2 kB)
Downloading torchmetrics-1.4.3-py3-none-any.whl (869 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m869.5/869.5 kB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.8-py3-none-any.whl (26 kB)
Downloading torchmetrics-1.4.3-py3-none-any.whl (869 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m869.5/869.5 kB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.11.8-py3-none-any.whl (26 kB)
Installing collected packages: lightning-utilities, to

In [None]:
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

nltk.download('punkt_tab')

from torchmetrics.text import BLEUScore, ROUGEScore

bleu = BLEUScore()
rouge = ROUGEScore()

# NLTK not good for persian tokenization ...
val_pred_str = [' '.join(map(str, pred)) for pred in val_preds]
val_true_str = [' '.join(map(str, true)) for true in val_true_labels]

print(f'BLEU Score: {bleu(val_pred_str, [[true] for true in val_true_str])}')
print(f'ROUGE Score: {rouge(val_pred_str, val_true_str)}')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


BLEU Score: 0.6196612119674683
ROUGE Score: {'rouge1_fmeasure': tensor(0.7640), 'rouge1_precision': tensor(0.7640), 'rouge1_recall': tensor(0.7640), 'rouge2_fmeasure': tensor(0.6680), 'rouge2_precision': tensor(0.6680), 'rouge2_recall': tensor(0.6680), 'rougeL_fmeasure': tensor(0.7639), 'rougeL_precision': tensor(0.7639), 'rougeL_recall': tensor(0.7639), 'rougeLsum_fmeasure': tensor(0.7639), 'rougeLsum_precision': tensor(0.7639), 'rougeLsum_recall': tensor(0.7639)}


Inference and saving the results

In [None]:
def predict(encoder, decoder, dataloader):
    encoder.eval()
    decoder.eval()

    predicted_metres = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask = batch
            encoder_outputs, hidden, cell = encoder(input_ids, attention_mask)
            decoder_input = torch.zeros(input_ids.size(0), dtype=torch.long).to(device)

            batch_predictions = []
            for t in range(14):
                output, hidden, cell = decoder(decoder_input, hidden, cell, encoder_outputs)
                decoder_input = output.argmax(1)
                batch_predictions.append(decoder_input)

            batch_predictions = torch.stack(batch_predictions, dim=1).cpu().numpy()
            predicted_metres.extend(batch_predictions)

    return predicted_metres

In [None]:
test_predictions = predict(encoder, decoder, test_loader)

In [None]:
test_predictions[0:5]

[array([3, 3, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 array([ 8,  8, 14,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 array([ 8,  8, 14,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 array([ 7,  4, 11,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]),
 array([10, 10,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])]

In [None]:
test_prediction_decoded = [label_tokenizer.decode(pred) for pred in test_predictions]

In [None]:
test_prediction_decoded[0:5]

['فعولن فعولن فعولن فعل',
 'فاعلاتن فاعلاتن فاعلن',
 'فاعلاتن فاعلاتن فاعلن',
 'فعلاتن مفاعلن فعلن',
 'مفاعیلن مفاعیلن فعولن']

In [None]:
test_data['predicted_metre'] = test_prediction_decoded
test_data.to_csv('test_samples_seq_to_seq_results.csv', index=False)

Using Beam Search

In [58]:
def beam_search_eval(encoder, decoder, dataloader, beam_width=3, length_penalty_alpha=0.7):
    preds = []
    true_labels = []

    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        for j, val_batch in enumerate(dataloader):
            print(f'batch: {j} / {len(dataloader)}')
            input_ids, attention_mask, labels = [x.to(device) for x in val_batch]
            encoder_outputs, hidden, cell = encoder(input_ids, attention_mask)

            batch_size = input_ids.size(0)
            seq_length = labels.size(1)

            hidden = hidden.contiguous()
            cell = cell.contiguous()

            batch_preds = [[] for _ in range(batch_size)]

            # Beam Search for each sequence in the batch
            for i in range(batch_size):
                beams = [(torch.zeros(1, dtype=torch.long, device=device), 0.0, hidden[:, i:i+1, :].contiguous(), cell[:, i:i+1, :].contiguous(), 1)]
                completed_sequences = []

                for t in range(seq_length):
                    new_beams = []
                    for seq, score, hidden_i, cell_i, length in beams:
                        output, hidden_i, cell_i = decoder(seq[-1:], hidden_i, cell_i, encoder_outputs[i:i+1])
                        topk_logits, topk_indices = torch.topk(output, beam_width, dim=-1)

                        # log-softmax
                        topk_log_probs = F.log_softmax(topk_logits, dim=-1)

                        for k in range(beam_width):
                            new_seq = torch.cat([seq, topk_indices[:, k]], dim=-1)
                            new_score = score + topk_log_probs[0, k].item()
                            normalized_score = new_score / ((5 + length + 1) / 6) ** length_penalty_alpha
                            new_beams.append((new_seq, normalized_score, hidden_i, cell_i, length + 1))

                    # Keep the top `beam_width` sequences
                    beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]

                    completed_sequences.extend([b for b in beams if b[0][-1].item() == 0])
                    beams = [b for b in beams if b[0][-1].item() != 0]

                    if len(beams) == 0:
                        break

                if len(completed_sequences) == 0:
                    completed_sequences = beams

                best_seq = max(completed_sequences, key=lambda x: x[1])[0]
                batch_preds[i] = best_seq[1:].tolist()  

            preds.extend(batch_preds)
            true_labels.extend(labels[:, 0:].tolist())

    return preds, true_labels

In [59]:
beam_val_preds, beam_val_true_labels = beam_search_eval(encoder, decoder, val_loader)

batch: 0 / 84
batch: 1 / 84
batch: 2 / 84
batch: 3 / 84
batch: 4 / 84
batch: 5 / 84
batch: 6 / 84
batch: 7 / 84
batch: 8 / 84
batch: 9 / 84
batch: 10 / 84
batch: 11 / 84
batch: 12 / 84
batch: 13 / 84
batch: 14 / 84
batch: 15 / 84
batch: 16 / 84
batch: 17 / 84
batch: 18 / 84
batch: 19 / 84
batch: 20 / 84
batch: 21 / 84
batch: 22 / 84
batch: 23 / 84
batch: 24 / 84
batch: 25 / 84
batch: 26 / 84
batch: 27 / 84
batch: 28 / 84
batch: 29 / 84
batch: 30 / 84
batch: 31 / 84
batch: 32 / 84
batch: 33 / 84
batch: 34 / 84
batch: 35 / 84
batch: 36 / 84
batch: 37 / 84
batch: 38 / 84
batch: 39 / 84
batch: 40 / 84
batch: 41 / 84
batch: 42 / 84
batch: 43 / 84
batch: 44 / 84
batch: 45 / 84
batch: 46 / 84
batch: 47 / 84
batch: 48 / 84
batch: 49 / 84
batch: 50 / 84
batch: 51 / 84
batch: 52 / 84
batch: 53 / 84
batch: 54 / 84
batch: 55 / 84
batch: 56 / 84
batch: 57 / 84
batch: 58 / 84
batch: 59 / 84
batch: 60 / 84
batch: 61 / 84
batch: 62 / 84
batch: 63 / 84
batch: 64 / 84
batch: 65 / 84
batch: 66 / 84
batch

In [60]:
from itertools import zip_longest
max_length = max(max(len(pred) for pred in beam_val_preds), max(len(label) for label in beam_val_true_labels))
beam_val_preds_padded = [pred + [0] * (max_length - len(pred)) for pred in beam_val_preds]
beam_val_true_labels_padded = [label + [0] * (max_length - len(label)) for label in beam_val_true_labels]

beam_val_preds_np = np.array(beam_val_preds_padded)
beam_val_true_labels_np = np.array(beam_val_true_labels_padded)

beam_val_preds_decoded = [label_tokenizer.decode(pred) for pred in beam_val_preds]
beam_val_true_labels_decoded = [label_tokenizer.decode(label) for label in beam_val_true_labels]

In [61]:
for i in range(0,20):
    print(f'beam search val_pred: {beam_val_preds_decoded[i]}')
    print(f'beam search val_true_label: {beam_val_true_labels_decoded[i]}')
    print('--------------------------------------------')

beam search val_pred: فعلاتن مفاعلن فعولن
beam search val_true_label: فعلاتن مفاعلن فعلن
--------------------------------------------
beam search val_pred: مفاعیلن مفاعیلن فعولن
beam search val_true_label: مفاعیلن مفاعیلن فعولن
--------------------------------------------
beam search val_pred: فعلاتن مفاعلن فعلن
beam search val_true_label: فاعلاتن فاعلاتن فاعلن
--------------------------------------------
beam search val_pred: مفعول مفاعلن فعولن
beam search val_true_label: فعلاتن مفاعلن فعلن
--------------------------------------------
beam search val_pred: فاعلاتن فاعلاتن فاعلن
beam search val_true_label: فاعلاتن فاعلاتن فاعلن
--------------------------------------------
beam search val_pred: فعولن فعولن فعولن فعل
beam search val_true_label: فعلاتن مفاعلن فعلن
--------------------------------------------
beam search val_pred: مفعول فعلاتن مفاعیل فعلن
beam search val_true_label: مفعول فاعلات مفاعیل فاعلن
--------------------------------------------
beam search val_pred: فعولن فعولن فعو

In [62]:
beam_val_preds = np.array(beam_val_preds_np)
beam_val_true_labels = np.array(beam_val_true_labels_np)

beam_val_preds_flat = beam_val_preds.ravel()
beam_val_true_labels_flat = beam_val_true_labels.ravel()

accuracy = accuracy_score(beam_val_true_labels_flat, beam_val_preds_flat)
f1 = f1_score(beam_val_true_labels_flat, beam_val_preds_flat, average='macro', zero_division=1)
recall = recall_score(beam_val_true_labels_flat, beam_val_preds_flat, average='macro', zero_division=1)
precision = precision_score(beam_val_true_labels_flat, beam_val_preds_flat, average='macro', zero_division=1)

print(f'Beam search eval: ')
print(f"Accuracy: {accuracy:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"Recall: {recall:.4f}")
print(f"Precision: {precision:.4f}")


Beam search eval: 
Accuracy: 0.7582
F1 Score: 0.4108
Recall: 0.4125
Precision: 0.6864


In [63]:
bleu = BLEUScore()
rouge = ROUGEScore()

beam_val_pred_str = [' '.join(map(str, pred)) for pred in beam_val_preds]
beam_val_true_str = [' '.join(map(str, true)) for true in beam_val_true_labels]

print(f'BLEU Score: {bleu(beam_val_pred_str, [[true] for true in beam_val_true_str])}')
print(f'ROUGE Score: {rouge(beam_val_pred_str, beam_val_true_str)}')

BLEU Score: 0.6199867129325867
ROUGE Score: {'rouge1_fmeasure': tensor(0.7631), 'rouge1_precision': tensor(0.7631), 'rouge1_recall': tensor(0.7631), 'rouge2_fmeasure': tensor(0.6667), 'rouge2_precision': tensor(0.6667), 'rouge2_recall': tensor(0.6667), 'rougeL_fmeasure': tensor(0.7630), 'rougeL_precision': tensor(0.7630), 'rougeL_recall': tensor(0.7630), 'rougeLsum_fmeasure': tensor(0.7630), 'rougeLsum_precision': tensor(0.7630), 'rougeLsum_recall': tensor(0.7630)}


In [64]:
def beam_prediction(encoder, decoder, dataloader, beam_width=3, length_penalty_alpha=0.7):
    encoder.eval()
    decoder.eval()

    predicted_metres = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids, attention_mask = [x.to(device) for x in batch]
            encoder_outputs, hidden, cell = encoder(input_ids, attention_mask)

            batch_size = input_ids.size(0)
            hidden = hidden.contiguous()
            cell = cell.contiguous()

            batch_predictions = []

            # Beam Search for each sequence in the batch
            for i in range(batch_size):
                beams = [(torch.zeros(1, dtype=torch.long, device=device), 0.0, hidden[:, i:i+1, :].contiguous(), cell[:, i:i+1, :].contiguous(), 1)]
                completed_sequences = []

                for t in range(14):
                    new_beams = []
                    for seq, score, hidden_i, cell_i, length in beams:
                        output, hidden_i, cell_i = decoder(seq[-1:], hidden_i, cell_i, encoder_outputs[i:i+1])
                        topk_logits, topk_indices = torch.topk(output, beam_width, dim=-1)

                        # log softmax
                        topk_log_probs = F.log_softmax(topk_logits, dim=-1)

                        for k in range(beam_width):
                            new_seq = torch.cat([seq, topk_indices[:, k]], dim=-1)
                            new_score = score + topk_log_probs[0, k].item()

                            # Apply length normalization to the score
                            normalized_score = new_score / ((5 + length + 1) / 6) ** length_penalty_alpha
                            new_beams.append((new_seq, normalized_score, hidden_i, cell_i, length + 1))

                    # Keep the top `beam_width` sequences
                    beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_width]

                    completed_sequences.extend([b for b in beams if b[0][-1].item() == 0])
                    beams = [b for b in beams if b[0][-1].item() != 0]

                    if len(beams) == 0:
                        break

                if len(completed_sequences) == 0:
                    completed_sequences = beams

                best_seq = max(completed_sequences, key=lambda x: x[1])[0]
                batch_predictions.append(best_seq[1:].tolist())  

            predicted_metres.extend(batch_predictions)

    return predicted_metres

In [65]:
beam_test_predictions = beam_prediction(encoder, decoder, test_loader)
beam_test_prediction_decoded = [label_tokenizer.decode(pred) for pred in beam_test_predictions]


test_data['predicted_metre'] = beam_test_prediction_decoded
test_data.to_csv('test_samples_seq_to_seq_results_beam_search.csv', index=False)