# BiLSTM-CRF for CoNLL-2003 NER: Hyperparameter Search

In [None]:
import inspect
import itertools
import logging
import os
import time

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from sklearn.metrics import accuracy_score, classification_report, f1_score
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer

from data.ner_dataset import NERDataset
from models.lstm_crf import BiLSTMCRF

In [None]:
# Route INFO-and-above records to training.log; filemode="w" truncates the
# file, so each run starts with a fresh log.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filename="training.log",
    filemode="w",
)
logger = logging.getLogger()  # the root logger configured just above

## Load Data

In [None]:
# CoNLL-2003 via Hugging Face `datasets`; the "train" and "validation" splits are used below.
dataset = load_dataset(path="conll2003")

In [None]:
# NOTE(review): an uncased tokenizer discards capitalization, which is usually a
# strong NER cue — consider a cased vocabulary if this is not deliberate.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Tag names ordered by their integer class id, plus lookups in both directions.
unique_labels = dataset["train"].features["ner_tags"].feature.names
id2label = dict(enumerate(unique_labels))
label2id = {name: idx for idx, name in id2label.items()}

In [None]:
# Wrap the raw splits for the DataLoaders; the training code unpacks each batch
# as (input_ids, attention_mask, labels) — presumably what NERDataset yields.
train_dataset, val_dataset = (
    NERDataset(dataset[split_name], tokenizer) for split_name in ("train", "validation")
)

## Model

In [None]:
# Architecture sizes.
EMBED_DIM = 128
HIDDEN_DIM = 256

# Sizes and ids derived from the tokenizer vocabulary and the tag set.
VOCAB_SIZE = tokenizer.vocab_size
PAD_IDX = tokenizer.pad_token_id
NUM_LABELS = len(unique_labels)

In [None]:
# Prefer the GPU when one is visible; `device` is also read by the training code.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the hyperparameter sweep builds its own model and optimizer per
# configuration, so this pair appears unused and only occupies device memory.
model = BiLSTMCRF(VOCAB_SIZE, EMBED_DIM, HIDDEN_DIM, NUM_LABELS, PAD_IDX)
optimizer = optim.Adam(params=model.parameters(), lr=5e-4)
model.to(device)

## Train & Validate

In [None]:
def evaluate(model, loader, label_list, device=None):
    """Compute the mean loss and token-level metrics of ``model`` on ``loader``.

    Args:
        model: Called as ``model(input_ids, attention_mask, labels)`` to obtain a
            loss tensor (or ``None``) and as ``model(input_ids, attention_mask)``
            to obtain one sequence of predicted label ids per example. The
            predictions are assumed to be position-aligned with the input
            tokens (TODO confirm against ``BiLSTMCRF``).
        loader: Iterable of ``(input_ids, attention_mask, labels)`` batches.
        label_list: Sequence mapping a label id to its name.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        dict with ``"loss"`` (mean per-batch loss, ``None`` for an empty
        loader), ``"accuracy"``, ``"f1"`` (support-weighted) and
        ``"classification_report"`` (one row per entry of ``label_list``).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    true_labels, predictions = [], []
    total_val_loss = 0.0

    with torch.no_grad():
        for input_ids, attention_mask, labels in loader:
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            # Guard against the model returning no loss.
            loss = model(input_ids, attention_mask, labels)
            if loss is not None:
                total_val_loss += loss.item()

            preds = model(input_ids, attention_mask)

            # Score a token only when it is attended and carries a real gold
            # label. Predictions are filtered by exactly the same positions as
            # the gold labels; filtering them independently would shift the
            # predicted sequence whenever a -100 label sits under an attended
            # position.
            for pred, gold, mask in zip(preds, labels.cpu().numpy(), attention_mask.cpu().numpy()):
                for pred_id, gold_id, attended in zip(pred, gold, mask):
                    if attended == 1 and gold_id != -100:
                        true_labels.append(label_list[gold_id])
                        predictions.append(label_list[pred_id])

    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions, average="weighted")
    # The labels are already names, so list them via ``labels=`` (fixed order,
    # every class present). ``target_names`` would be paired with the *sorted*
    # set of observed labels and mislabel the report rows.
    classification_rep = classification_report(
        true_labels, predictions, labels=list(label_list), zero_division=1
    )

    return {
        "loss": total_val_loss / len(loader) if len(loader) > 0 else None,
        "accuracy": accuracy,
        "f1": f1,
        "classification_report": classification_rep,
    }

In [None]:
def _token_accuracy(model, loader, label_list, device):
    """Return the fraction of scored tokens whose decoded label equals the gold label.

    A token is scored when its attention-mask entry is 1 and its gold label is
    not -100; predictions are filtered by the same positions so they stay
    aligned with the gold labels. Returns 0 when no token is scored.
    """
    correct, total = 0, 0
    with torch.no_grad():
        for input_ids, attention_mask, labels in loader:
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
            preds = model(input_ids, attention_mask)
            for pred, gold, mask in zip(preds, labels.cpu().numpy(), attention_mask.cpu().numpy()):
                for pred_id, gold_id, attended in zip(pred, gold, mask):
                    if attended == 1 and gold_id != -100:
                        total += 1
                        correct += label_list[pred_id] == label_list[gold_id]
    return correct / total if total > 0 else 0


def train_and_validate(
    model,
    train_loader,
    val_loader,
    optimizer,
    label_list,
    epochs,
    device,
    print_every=25,
    lr=None,
    dropout=None,
    patience=10,
    min_delta=0.001,
):
    """Train ``model`` with periodic evaluation and early stopping.

    Every epoch runs one optimisation pass over ``train_loader``. Every
    ``print_every`` epochs (and on the final epoch) the model is evaluated:
    token accuracy on the training and validation sets, and validation loss
    via ``evaluate``. Results are logged and written to
    ``results/losses_lr_{lr}_dropout_{dropout}.csv`` (one row per epoch;
    ``val_loss`` is filled only on evaluated epochs) and
    ``results/accuracies_lr_{lr}_dropout_{dropout}.csv`` (one row per
    evaluation, keyed by the real epoch number).

    Args:
        model, optimizer: Model to train and its optimizer.
        train_loader, val_loader: Iterables of
            ``(input_ids, attention_mask, labels)`` batches.
        label_list: Sequence mapping a label id to its name.
        epochs: Maximum number of training epochs (must be >= 1).
        device: Device each batch is moved to.
        print_every: Evaluation/logging interval in epochs.
        lr, dropout: Used only to name the output CSV files.
        patience: Consecutive *evaluation rounds* (not epochs) without a
            validation-loss improvement larger than ``min_delta`` before
            stopping, so the effective patience in epochs is
            ``patience * print_every``.
        min_delta: Minimum decrease in validation loss counted as improvement.

    Returns:
        ``(lowest validation loss, highest validation accuracy)`` over all
        evaluation rounds.

    Raises:
        ValueError: If ``epochs`` is less than 1.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    train_losses = []
    eval_epochs, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    best_val_loss = float("inf")
    evals_without_improvement = 0

    os.makedirs("results", exist_ok=True)

    interval_start_time = time.time()
    last_eval_epoch = 0
    for epoch in range(epochs):
        # Training pass
        model.train()
        total_train_loss = 0.0
        for input_ids, attention_mask, labels in train_loader:
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = model(input_ids, attention_mask, labels)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Evaluate only every `print_every` epochs and on the final epoch.
        if (epoch + 1) % print_every != 0 and epoch != epochs - 1:
            continue

        model.eval()
        train_accuracy = _token_accuracy(model, train_loader, label_list, device)
        avg_val_loss = evaluate(model, val_loader, label_list)["loss"]
        val_accuracy = _token_accuracy(model, val_loader, label_list, device)

        # Remember which epoch these numbers belong to so the CSVs line up.
        eval_epochs.append(epoch + 1)
        train_accuracies.append(train_accuracy)
        val_losses.append(avg_val_loss)
        val_accuracies.append(val_accuracy)

        # The final interval may be shorter than print_every, so log the real count.
        elapsed_time = time.time() - interval_start_time
        epochs_in_interval = epoch + 1 - last_eval_epoch
        logger.info(f"Epoch {epoch + 1}/{epochs}")
        logger.info(f"Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}")
        logger.info(f"Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")
        logger.info(f"Time elapsed for last {epochs_in_interval} epoch(s): {elapsed_time:.2f} seconds")
        interval_start_time = time.time()
        last_eval_epoch = epoch + 1

        # Early stopping, counted in evaluation rounds.
        if avg_val_loss < best_val_loss - min_delta:
            best_val_loss = avg_val_loss
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                logger.info("Early stopping triggered.")
                break

    # Losses: one row per trained epoch; val_loss only where it was measured.
    all_epochs = list(range(1, len(train_losses) + 1))
    val_loss_by_epoch = dict(zip(eval_epochs, val_losses))
    losses_df = pd.DataFrame(
        {
            "epoch": all_epochs,
            "train_loss": train_losses,
            "val_loss": [val_loss_by_epoch.get(e) for e in all_epochs],
        }
    )
    losses_path = os.path.join("results", f"losses_lr_{lr}_dropout_{dropout}.csv")
    losses_df.to_csv(losses_path, index=False)

    # Accuracies: one row per evaluation, labelled with the actual epoch.
    accuracies_df = pd.DataFrame(
        {
            "epoch": eval_epochs,
            "train_accuracy": train_accuracies,
            "val_accuracy": val_accuracies,
        }
    )
    accuracies_path = os.path.join("results", f"accuracies_lr_{lr}_dropout_{dropout}.csv")
    accuracies_df.to_csv(accuracies_path, index=False)

    return min(val_losses), max(val_accuracies)

In [None]:
# Search space: every (learning rate, dropout) pair, learning rate varying slowest.
learning_rates = [1e-3, 1e-4]
dropouts = [0.1, 0.2]
hyperparams = [(rate, drop) for rate in learning_rates for drop in dropouts]

# One summary dict per configuration is appended here by the sweep.
results = []

In [None]:
# Grid search: train one fresh model per (learning rate, dropout) pair and
# record its best validation loss and accuracy.

# The loaders do not depend on the hyperparameters, so build them once
# (a shuffling DataLoader reshuffles on every pass, so reuse is safe).
train_loader = DataLoader(train_dataset, batch_size=4000, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4000)

# Forward `dropout` only if the model constructor accepts it; otherwise the
# dropout axis of the grid cannot influence training, so say so explicitly.
model_accepts_dropout = "dropout" in inspect.signature(BiLSTMCRF).parameters
if not model_accepts_dropout:
    logger.warning(
        "BiLSTMCRF does not accept a 'dropout' argument; "
        "dropout values in the grid will have no effect on training."
    )

for lr, dropout in hyperparams:
    logger.info(f"Training with Learning Rate: {lr}, Dropout: {dropout}")

    # Fresh model and optimizer per configuration. Sizes reuse the Model-section
    # constants so both cells describe the same architecture.
    model_kwargs = {
        "vocab_size": VOCAB_SIZE,
        "embed_dim": EMBED_DIM,
        "hidden_dim": HIDDEN_DIM,
        "num_labels": NUM_LABELS,
        "pad_idx": PAD_IDX,
    }
    if model_accepts_dropout:
        model_kwargs["dropout"] = dropout
    model = BiLSTMCRF(**model_kwargs).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Train and validate
    best_val_loss, best_val_accuracy = train_and_validate(
        model,
        train_loader,
        val_loader,
        optimizer,
        unique_labels,
        epochs=1000,
        device=device,
        print_every=50,
        lr=lr,
        dropout=dropout,
    )

    # Log and store results
    logger.info(f"Best Validation Loss for LR: {lr}, Dropout: {dropout}: {best_val_loss:.4f}")
    logger.info(f"Best Validation Accuracy for LR: {lr}, Dropout: {dropout}: {best_val_accuracy:.4f}")
    results.append(
        {
            "learning_rate": lr,
            "dropout": dropout,
            "best_val_loss": best_val_loss,
            "best_val_accuracy": best_val_accuracy,
        }
    )

# Save hyperparameter results; make sure the directory exists even if the grid was empty.
os.makedirs("results", exist_ok=True)
results_df = pd.DataFrame(results)
results_df.to_csv(os.path.join("results", "hyperparameter_results.csv"), index=False)

logger.info("Hyperparameter tuning complete. Results saved.")