In [1]:
# -----------------------------
# Step 1: Install and Upgrade Necessary Packages
# -----------------------------
!pip install --upgrade torchcrf seqeval datasets
%pip install spacy
%pip install spacy_conll
import random
import spacy
from spacy_conll import ConllFormatter

Collecting torchcrf
  Downloading TorchCRF-1.1.0-py3-none-any.whl.metadata (2.3 kB)
Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.

In [2]:
# -----------------------------
# Step 1: Install and Upgrade Necessary Packages
# -----------------------------
!pip install --upgrade torchcrf seqeval datasets

# -----------------------------
# Step 2: Import Libraries and Set Device
# -----------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from TorchCRF import CRF  # Corrected import statement
from datasets import load_dataset
from sklearn.metrics import classification_report
from typing import List
import numpy as np
from seqeval.metrics import classification_report as seq_classification_report
from torch.utils.data import TensorDataset, DataLoader
from collections import defaultdict
import pickle
from google.colab import files

# Set device to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# -----------------------------
# Step 3: Define the Custom LSTM and LSTM-CRF Model
# -----------------------------
class CustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(CustomLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Input gate parameters
        self.W_ii = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.W_hi = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_i = nn.Parameter(torch.Tensor(hidden_size))

        # Forget gate parameters
        self.W_if = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.W_hf = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_f = nn.Parameter(torch.Tensor(hidden_size))

        # Cell gate parameters
        self.W_ig = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.W_hg = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_g = nn.Parameter(torch.Tensor(hidden_size))

        # Output gate parameters
        self.W_io = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.W_ho = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        self.b_o = nn.Parameter(torch.Tensor(hidden_size))

        self.init_weights()

    def init_weights(self):
        # Initialize weight matrices with Xavier Uniform
        nn.init.xavier_uniform_(self.W_ii)
        nn.init.xavier_uniform_(self.W_hi)
        nn.init.xavier_uniform_(self.W_if)
        nn.init.xavier_uniform_(self.W_hf)
        nn.init.xavier_uniform_(self.W_ig)
        nn.init.xavier_uniform_(self.W_hg)
        nn.init.xavier_uniform_(self.W_io)
        nn.init.xavier_uniform_(self.W_ho)

        # Initialize biases with zeros
        nn.init.zeros_(self.b_i)
        nn.init.zeros_(self.b_f)
        nn.init.zeros_(self.b_g)
        nn.init.zeros_(self.b_o)

    def forward(self, input_seq, h_0=None, c_0=None):
        """
        input_seq: Tensor of shape (batch_size, seq_length, input_size)
        h_0: Initial hidden state (batch_size, hidden_size)
        c_0: Initial cell state (batch_size, hidden_size)
        Returns:
            h_seq: Tensor containing hidden states for all time steps (batch_size, seq_length, hidden_size)
            (h_n, c_n): Final hidden and cell states
        """
        batch_size, seq_length, _ = input_seq.size()
        if h_0 is None:
            h_t = torch.zeros(batch_size, self.hidden_size, device=input_seq.device)
        else:
            h_t = h_0
        if c_0 is None:
            c_t = torch.zeros(batch_size, self.hidden_size, device=input_seq.device)
        else:
            c_t = c_0

        h_seq = []

        for t in range(seq_length):
            x_t = input_seq[:, t, :]  # (batch_size, input_size)

            i_t = torch.sigmoid(x_t @ self.W_ii.T + h_t @ self.W_hi.T + self.b_i)
            f_t = torch.sigmoid(x_t @ self.W_if.T + h_t @ self.W_hf.T + self.b_f)
            g_t = torch.tanh(x_t @ self.W_ig.T + h_t @ self.W_hg.T + self.b_g)
            o_t = torch.sigmoid(x_t @ self.W_io.T + h_t @ self.W_ho.T + self.b_o)

            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)

            h_seq.append(h_t.unsqueeze(1))  # (batch_size, 1, hidden_size)

        h_seq = torch.cat(h_seq, dim=1)  # (batch_size, seq_length, hidden_size)
        return h_seq, (h_t, c_t)

class BidirectionalCustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(BidirectionalCustomLSTM, self).__init__()
        self.forward_lstm = CustomLSTM(input_size, hidden_size)
        self.backward_lstm = CustomLSTM(input_size, hidden_size)
        self.hidden_size = hidden_size

    def forward(self, input_seq, h_0=None, c_0=None):
        # Forward direction
        forward_out, (h_f, c_f) = self.forward_lstm(input_seq, h_0, c_0)

        # Backward direction
        reversed_input = torch.flip(input_seq, [1])  # Reverse the sequence
        backward_out, (h_b, c_b) = self.backward_lstm(reversed_input, h_0, c_0)
        backward_out = torch.flip(backward_out, [1])  # Re-reverse to original order

        # Concatenate forward and backward outputs
        h_seq = torch.cat([forward_out, backward_out], dim=2)  # (batch_size, seq_length, 2*hidden_size)

        # Final hidden and cell states
        h_n = torch.cat([h_f, h_b], dim=1)  # (batch_size, 2*hidden_size)
        c_n = torch.cat([c_f, c_b], dim=1)  # (batch_size, 2*hidden_size)

        return h_seq, (h_n, c_n)

class LSTM_CRF(nn.Module):
    def __init__(self, vocab_size, tagset_size, embedding_dim, hidden_dim, padding_idx, dropout=0.5):
        super(LSTM_CRF, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = BidirectionalCustomLSTM(embedding_dim, hidden_dim // 2)  # hidden_dim // 2 for each direction
        self.dropout = nn.Dropout(dropout)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        self.crf = CRF(tagset_size)

    def forward(self, sentences, tags, mask):
        embeds = self.embedding(sentences)  # (batch_size, seq_length, embedding_dim)
        embeds = self.dropout(embeds)
        lstm_out, _ = self.lstm(embeds)  # (batch_size, seq_length, hidden_dim)

        lstm_out = lstm_out.transpose(0, 1) if lstm_out.shape[0] == sentences.shape[1] else lstm_out

        lstm_out = self.dropout(lstm_out)
        emissions = self.hidden2tag(lstm_out)  # (batch_size, seq_length, tagset_size)

        # Transpose to (seq_length, batch_size, tagset_size)
        emissions = emissions.transpose(0, 1)
        tags = tags.transpose(0, 1)
        mask = mask.transpose(0, 1)

        loss = -self.crf(emissions, tags, mask=mask)
        loss = torch.mean(loss)
        return loss


    def predict(self, sentences, mask):
        self.eval()
        predictions = []
        embeds = self.embedding(sentences)
        embeds = self.dropout(embeds)
        lstm_out, _ = self.lstm(embeds)

        lstm_out = lstm_out.transpose(0, 1) if lstm_out.shape[0] == sentences.shape[1] else lstm_out

        lstm_out = self.dropout(lstm_out)
        emissions = self.hidden2tag(lstm_out)

        # Transpose to (seq_length, batch_size, tagset_size)

        # Alp - commenting this out to debug step 7
        # emissions = emissions.transpose(0, 1)
        # mask = mask.transpose(0, 1)

        predictions = self.crf.viterbi_decode(emissions, mask=mask)  # List[List[int]]
        return predictions  # List[List[int]]

Using device: cuda


In [3]:
# -----------------------------
# Step 4: Load and Preprocess the CoNLL-2003 Dataset
# -----------------------------
# Load the CoNLL-2003 dataset
dataset = load_dataset('conll2003')

# Inspect the dataset
print(dataset)

# Example of the dataset
print(dataset['train'][0])

# Extract all unique words and tags
words = set()
tags = set()

for split in ['train', 'validation', 'test']:
    for sentence in dataset[split]:
        for word in sentence['tokens']:
            words.add(word.lower())  # Lowercasing for normalization
        for tag in sentence['ner_tags']:
            tags.add(tag)

# Create word2idx dictionary
word2idx = {"<PAD>": 0, "<UNK>": 1}
for word in sorted(words):
    word2idx[word] = len(word2idx)

# Mapping from tag indices to tag names
tag_names = dataset['train'].features['ner_tags'].feature.names
tag2idx = {"<PAD>": 0}
for tag in tag_names:
    tag2idx[tag] = len(tag2idx)

idx2tag = {v: k for k, v in tag2idx.items()}

print(f"Number of unique words: {len(word2idx)}")
print(f"Number of unique tags: {len(tag2idx)}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/12.3k [00:00<?, ?B/s]

conll2003.py:   0%|          | 0.00/9.57k [00:00<?, ?B/s]

The repository for conll2003 contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/conll2003.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


Downloading data:   0%|          | 0.00/983k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/14041 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3250 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3453 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 14041
    })
    validation: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3250
    })
    test: Dataset({
        features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],
        num_rows: 3453
    })
})
{'id': '0', 'tokens': ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7], 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0], 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}
Number of unique words: 26871
Number of unique tags: 10


In [4]:
# -----------------------------
# Step 5: Encode the Dataset
# -----------------------------
# Parameters
MAX_LEN = 50  # Maximum sentence length
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
BATCH_SIZE = 32
EPOCHS = 5

# Encoding functions
def encode_sentences(sentences: List[List[str]], word2idx: dict, max_len: int) -> torch.Tensor:
    encoded = []
    for sentence in sentences:
        encoded_sentence = [word2idx.get(word.lower(), word2idx["<UNK>"]) for word in sentence]
        # Padding
        if len(encoded_sentence) < max_len:
            encoded_sentence += [word2idx["<PAD>"]] * (max_len - len(encoded_sentence))
        else:
            encoded_sentence = encoded_sentence[:max_len]
        encoded.append(encoded_sentence)
    return torch.tensor(encoded, dtype=torch.long)

def encode_labels(labels: List[List[int]], tag2idx: dict, max_len: int) -> torch.Tensor:
    encoded = []
    for label_seq in labels:
        encoded_label = [label + 1 for label in label_seq]  # +1 to account for <PAD> tag at 0
        if len(encoded_label) < max_len:
            encoded_label += [tag2idx["<PAD>"]] * (max_len - len(encoded_label))
        else:
            encoded_label = encoded_label[:max_len]
        encoded.append(encoded_label)
    return torch.tensor(encoded, dtype=torch.long)

# Prepare training data
train_sentences = [example['tokens'] for example in dataset['train']]
train_labels = [example['ner_tags'] for example in dataset['train']]

X_train = encode_sentences(train_sentences, word2idx, MAX_LEN)
y_train = encode_labels(train_labels, tag2idx, MAX_LEN)

# Prepare validation data
val_sentences = [example['tokens'] for example in dataset['validation']]
val_labels = [example['ner_tags'] for example in dataset['validation']]

X_val = encode_sentences(val_sentences, word2idx, MAX_LEN)
y_val = encode_labels(val_labels, tag2idx, MAX_LEN)

# Prepare test data (optional, for later evaluation)
test_sentences = [example['tokens'] for example in dataset['test']]
test_labels = [example['ner_tags'] for example in dataset['test']]

X_test = encode_sentences(test_sentences, word2idx, MAX_LEN)
y_test = encode_labels(test_labels, tag2idx, MAX_LEN)

print(f"Training samples: {X_train.shape[0]}")
print(f"Validation samples: {X_val.shape[0]}")
print(f"Test samples: {X_test.shape[0]}")

Training samples: 14041
Validation samples: 3250
Test samples: 3453


In [5]:
# -----------------------------
# Step 6: Create DataLoaders
# -----------------------------
# Create TensorDatasets
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# -----------------------------
# Step 7: Train the LSTM-CRF Model
# -----------------------------
# Initialize Model
VOCAB_SIZE = len(word2idx)
TAGSET_SIZE = len(tag2idx)
PAD_IDX = word2idx["<PAD>"]

model = LSTM_CRF(VOCAB_SIZE, TAGSET_SIZE, EMBEDDING_DIM, HIDDEN_DIM, padding_idx=PAD_IDX)
model.to(device)

# Initialize Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training Loop
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    for batch_X, batch_y in train_loader:
        batch_X = batch_X.to(device)
        batch_y = batch_y.to(device)

        # Create mask (batch_size, max_len)
        mask = (batch_X != PAD_IDX).to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Compute loss
        loss = model(batch_X, batch_y, mask)
        epoch_loss += loss.item()

        # Backward pass
        loss.backward()
        optimizer.step()

    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {avg_loss:.4f}")

Epoch 1/5, Loss: 2.3481
Epoch 2/5, Loss: -32.0897
Epoch 3/5, Loss: -63.1210
Epoch 4/5, Loss: -95.7824


In [None]:
def evaluate_and_show_results_with_metrics(model, X_test, y_test, test_sentences, idx2tag, device, PAD_IDX=0):
    model.eval()
    all_preds = []
    all_true = []

    # Move data to the specified device
    X_test, y_test = X_test.to(device), y_test.to(device)
    mask = (X_test != PAD_IDX)

    with torch.no_grad():
        # Get model predictions
        predictions = model.predict(X_test, mask=mask)  # Should return List[List[int]]

    for i, (pred_indices, true_labels) in enumerate(zip(predictions, y_test)):
        tokens = test_sentences[i]
        true_labels = true_labels.cpu().numpy()
        pred_indices = np.array(pred_indices)

        print('')

        # Ensure pred_indices and true_labels are the same length before masking
        if len(pred_indices) != len(true_labels[true_labels != PAD_IDX]):
            print(f"Sequence {i} has mismatched lengths before masking: pred_indices ({len(pred_indices)}), true_labels ({len(true_labels[true_labels != PAD_IDX])})")
            # Handle mismatch if necessary

        # Use the mask to filter out padding tokens
        valid_positions = (true_labels != PAD_IDX)
        true_labels_filtered = true_labels[valid_positions]
        pred_indices_filtered = pred_indices

        # Map indices to tags
        true_labels_list = [idx2tag[label] for label in true_labels_filtered]
        pred_tags_list = [idx2tag[idx] for idx in pred_indices_filtered]

        # Handle the "O-O" special case
        true_labels_list = ["O" if tag == "O-O" else tag for tag in true_labels_list]

        # Collect true and predicted labels for metrics
        all_true.append(true_labels_list)
        all_preds.append(pred_tags_list)

        # Print tokens with true labels and predicted labels
        print(f"\nSequence {i}:")
        print(f"{'Token':15} {'True Label':15} {'Predicted Label'}")
        print('-' * 45)
        for token, true_label, pred_label in zip(tokens, true_labels_list, pred_tags_list):
            print(f"{token:15} {true_label:15} {pred_label}")

    # Print classification report
    print("\nClassification Report:")
    print(seq_classification_report(all_true, all_preds, zero_division=0))


evaluate_and_show_results_with_metrics(model, X_test, y_test, test_sentences, idx2tag, device)

In [None]:
# -----------------------------
# Step 8: Save the Model and Dictionaries
# -----------------------------
# Save the model state dictionary
torch.save(model.state_dict(), 'custom_lstm_crf_model.pth')
print("Model saved successfully!")

# Save word2idx and tag2idx dictionaries using pickle
with open('word2idx.pkl', 'wb') as f:
    pickle.dump(word2idx, f)
print("word2idx saved successfully!")

with open('tag2idx.pkl', 'wb') as f:
    pickle.dump(tag2idx, f)
print("tag2idx saved successfully!")

# Optional: Download the files to your local machine
# Uncomment the lines below if you wish to download the files immediately
# files.download('custom_lstm_crf_model.pth')
# files.download('word2idx.pkl')
# files.download('tag2idx.pkl')

In [None]:

# -----------------------------
# Step 9: Upload Your Text File to Colab
# -----------------------------
from google.colab import files
import nltk
from nltk.tokenize import sent_tokenize

# Upload the file
uploaded = files.upload()

In [None]:
# -----------------------------
# Step 10: Load and Preprocess Your Text File
# -----------------------------

import random
import spacy
from spacy_conll import ConllFormatter
nlp = spacy.load("en_core_web_sm")

file_path = "combined_summaries.txt"

with open(file_path, "r", encoding="utf-8") as file:
    text = file.read()

# Process the text using spaCy
doc = nlp(text)

# Prepare data in a structured format
conll_data = []
ner_tag_set = set()  # Collect all unique NER tags for feature mapping

for i, sent in enumerate(doc.sents):
    tokens, pos_tags, chunk_tags, ner_tags = [], [], [], []
    for token in sent:
        tokens.append(token.text)
        pos_tags.append(token.pos_)
        chunk_tags.append(token.dep_)
        ner_tag = f"{token.ent_iob_}-{token.ent_type_ if token.ent_iob_ != 'O' else 'O'}"
        if ner_tag == "O-O":  # Handle invalid tag
          ner_tag = "O"
        ner_tags.append(ner_tag)
        ner_tag_set.add(ner_tag)  # Add to the set of unique tags
    conll_data.append({
        "id": i,
        "tokens": tokens,
        "pos_tags": pos_tags,
        "chunk_tags": chunk_tags,
        "ner_tags": ner_tags
    })


# Shuffle and split the data into 80% train, 10% validation, 10% test
random.shuffle(conll_data)
test_data = conll_data

# Define the Dataset class
class Dataset:
    def __init__(self, split_data, ner_tag_names):
        self.data = split_data
        self.features = {
            "ner_tags": {
                "feature": {
                    "names": ner_tag_names
                }
            }
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def __repr__(self):
        return f"Dataset({{\n    features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n    num_rows: {len(self.data)}\n}})"

# Create datasets
ner_tag_names = sorted(ner_tag_set)  # Sorted list of unique NER tag names
dataset = {
    "test": Dataset(test_data, ner_tag_names)
}
# Extract unique words and tags
words = set()
tags = set()

for split in ['test']:
    for sentence in dataset[split]:
        for word in sentence['tokens']:
            words.add(word.lower())  # Lowercasing for normalization
        for tag in sentence['ner_tags']:
            tags.add(tag)

# Create word2idx and tag2idx dictionaries
word2idx = {"<PAD>": 0, "<UNK>": 1}
for word in sorted(words):
    word2idx[word] = len(word2idx)

# Mapping from tag indices to tag names
tag_names = dataset['test'].features['ner_tags']['feature']['names']
tag2idx = {"<PAD>": 0}
for idx, tag in enumerate(tag_names, start=1):
    tag2idx[tag] = len(tag2idx)

idx2tag = {v: k for k, v in tag2idx.items()}

# Encoding functions
def encode_sentences(sentences: List[List[str]], word2idx: dict, max_len: int) -> torch.Tensor:
    encoded = []
    for sentence in sentences:
        encoded_sentence = [word2idx.get(word.lower(), word2idx["<UNK>"]) for word in sentence]

        if len(encoded_sentence) < max_len:
            encoded_sentence += [word2idx["<PAD>"]] * (max_len - len(encoded_sentence))
        else:
            encoded_sentence = encoded_sentence[:max_len]
        encoded.append(encoded_sentence)
    return torch.tensor(encoded, dtype=torch.long)

def encode_labels(labels: List[List[str]], tag2idx: dict, max_len: int) -> torch.Tensor:
    encoded = []
    for label_seq in labels:
        encoded_label = [tag2idx.get(label, tag2idx["<PAD>"]) for label in label_seq]  # Map NER tags to indices
        if len(encoded_label) < max_len:
            encoded_label += [tag2idx["<PAD>"]] * (max_len - len(encoded_label))  # Padding
        else:
            encoded_label = encoded_label[:max_len]
        encoded.append(encoded_label)
    return torch.tensor(encoded, dtype=torch.long)

test_sentences = [example['tokens'] for example in dataset['test']]
test_labels = [example['ner_tags'] for example in dataset['test']]
X_test = encode_sentences(test_sentences, word2idx, MAX_LEN)
y_test = encode_labels(test_labels, tag2idx, MAX_LEN)

In [None]:
# -----------------------------
# Step 11: Perform NER on Your Text File
# -----------------------------
# Load word2idx and tag2idx dictionaries
with open('word2idx.pkl', 'rb') as f:
    word2idx_loaded = pickle.load(f)

with open('tag2idx.pkl', 'rb') as f:
    tag2idx_loaded = pickle.load(f)

idx2tag_loaded = {v: k for k, v in tag2idx_loaded.items()}

# Initialize the model
model = LSTM_CRF(VOCAB_SIZE, TAGSET_SIZE, EMBEDDING_DIM, HIDDEN_DIM, padding_idx=PAD_IDX)
model.to(device)

# Load the saved model state
model_path = 'custom_lstm_crf_model.pth'  # Ensure this path is correct
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
print("Model loaded successfully!")

# Create DataLoader for test data
def create_dataloader_custom(encoded_sentences: torch.Tensor, encoded_labels: torch.Tensor ,batch_size: int=32) -> DataLoader:
    """
    Create a DataLoader for the encoded sentences.

    Args:
        encoded_sentences (torch.Tensor): Tensor of encoded sentences.
        batch_size (int, optional): Batch size. Defaults to 32.

    Returns:
        DataLoader: DataLoader object.
    """
    # Since we are only doing inference, labels are not needed. Use dummy labels.
    test_dataset = TensorDataset(X_test, y_test)

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return test_loader

test_loader_custom = create_dataloader_custom(X_test, y_test, batch_size=BATCH_SIZE)

# -----------------------------
# Step 12: Display the NER Results
# -----------------------------

# Perform NER
from typing import List, Tuple
def perform_ner_with_labels(model: nn.Module, dataloader: DataLoader, idx2tag: dict, device: torch.device) -> Tuple[List[List[str]], List[List[str]]]:
    """
    Perform NER and retrieve true labels.

    Args:
        model (nn.Module): Trained LSTM_CRF model.
        dataloader (DataLoader): DataLoader with sentences and true labels.
        idx2tag (dict): Mapping from tag indices to tag names.
        device (torch.device): Device to perform computation on.

    Returns:
        Tuple[List[List[str]], List[List[str]]]: Predicted tags and true tags.
    """
    all_preds = []
    all_true_labels = []

    with torch.no_grad():
        for batch_X, batch_y in dataloader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            mask = (batch_X != PAD_IDX).to(device)

            # Get predictions
            preds = model.predict(batch_X, mask)

            for pred, true_label in zip(preds, batch_y):
                all_preds.append(pred)
                all_true_labels.append(true_label.tolist())

    # Convert indices to tag names
    all_preds_tags = [[idx2tag.get(idx, "O") for idx in sent] for sent in all_preds]
    all_true_tags = [[idx2tag.get(idx, "O") for idx in sent] for sent in all_true_labels]

    # Print classification report
    print("\nClassification Report:")
    print(seq_classification_report(all_true_labels, all_preds, zero_division=0))

    return all_preds_tags, all_true_tags

predicted_tags_custom, true_tags_custom = perform_ner_with_labels(model, test_loader_custom, idx2tag_loaded, device)
print("NER prediction and true label extraction completed!")

# Display NER results with true labels
def display_ner_results_with_labels(sentences: List[List[str]], predicted_tags: List[List[str]], true_tags: List[List[str]]):
    """
    Display tokens with their predicted and true NER tags.

    Args:
        sentences (List[List[str]]): Original tokenized sentences.
        predicted_tags (List[List[str]]): Predicted NER tags for each token.
        true_tags (List[List[str]]): True NER tags for each token.
    """
    for i, (sentence, pred_tags, true_tags) in enumerate(zip(sentences, predicted_tags, true_tags)):
        print(f"Sentence {i+1}:")
        for token, pred_tag, true_tag in zip(sentence[:len(pred_tags)], pred_tags, true_tags):
            print(f"{token:15}\tPredicted: {pred_tag:10}\tTrue: {true_tag}")
        print("\n")

# Display results
display_ner_results_with_labels(test_sentences, predicted_tags_custom, true_tags_custom)
