In [21]:
import os
import json
from nltk.tree import Tree
from nltk.tree.prettyprinter import TreePrettyPrinter
import numpy as np
import pandas as pd
import torch
from torch import nn
import ipywidgets as widgets
from IPython.display import display

In [22]:
# ----------------------
# Load Vocabulary
# ----------------------
def load_vocab(model_name):
    """Load the word and tag lookup tables stored next to a model.

    Reads ``word2idx.json`` and ``idx2pos.json`` from ``models/<model_name>/``.

    Args:
        model_name: Name of the sub-directory of ``models`` holding the files.

    Returns:
        A ``(word2idx, idx2pos)`` tuple. ``word2idx`` is returned exactly as
        parsed from JSON; ``idx2pos`` has its keys converted from the JSON
        string form to ``int``.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If either file is not valid JSON.
        ValueError: If a key of ``idx2pos.json`` is not an integer string.
    """
    model_dir = os.path.join("models", model_name)

    # Decode explicitly as UTF-8: without it, open() falls back to the
    # platform locale and non-ASCII vocabulary entries can be mis-decoded.
    with open(os.path.join(model_dir, "word2idx.json"), 'r', encoding="utf-8") as json_file:
        word2idx = json.load(json_file)
    with open(os.path.join(model_dir, "idx2pos.json"), 'r', encoding="utf-8") as json_file:
        raw_idx2pos = json.load(json_file)

    # JSON object keys are always strings; the tag lookup is keyed by int.
    idx2pos = {int(k): v for k, v in raw_idx2pos.items()}
    return word2idx, idx2pos

# ----------------------
# Load Model
# ----------------------
def load_model(model_name, vocab):
    """Rebuild the transformer tagger and restore its saved weights on CPU.

    Reads ``embedding-matrix.pth`` and ``model-state.pt`` from
    ``models/<model_name>/``. The architecture hyperparameters are fixed in
    this function, so they must match the checkpoint being loaded.

    Args:
        model_name: Name of the sub-directory of ``models`` holding the files.
        vocab: Sized collection; only ``len(vocab)`` is used, as the
            embedding's vocabulary size.

    Returns:
        The ``nn.Module`` with the checkpoint loaded. ``eval()`` is not
        called here, so the module is returned in its default mode.
    """

    class SPositionalEncoding(nn.Module):
        """Adds a fixed sinusoidal position table to its input."""

        def __init__(self, embed_size, max_len=5000):
            super().__init__()
            positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            frequencies = torch.exp(
                torch.arange(0, embed_size, 2).float() * (-np.log(10000.0) / embed_size)
            )
            table = torch.zeros(max_len, embed_size)
            table[:, 0::2] = torch.sin(positions * frequencies)
            # With an odd width there is one fewer cosine column than sine
            # column, so the last frequency is dropped for the cosine half.
            cos_frequencies = frequencies[:-1] if embed_size % 2 == 1 else frequencies
            table[:, 1::2] = torch.cos(positions * cos_frequencies)
            # Stored as a buffer under the name 'pe' (part of the state dict).
            self.register_buffer('pe', table.unsqueeze(0))

        def forward(self, x):
            # Slice the table to the sequence length found on dim 1 of x.
            return x + self.pe[:, :x.size(1), :]

    class STransformerE(nn.Module):
        """Embedding -> positional encoding -> transformer encoder -> linear."""

        def __init__(self, vocab_size, embed_size, num_heads, hidden_dim, num_layers, output_dim, padding_idx, embedding_matrix, dropout=0.1):
            super().__init__()
            # Attribute names below double as state-dict key prefixes.
            self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=padding_idx)
            self.embedding.weight = nn.Parameter(embedding_matrix)
            self.embedding.weight.requires_grad = True
            self.pos_encoder = SPositionalEncoding(embed_size)
            layer = nn.TransformerEncoderLayer(
                d_model=embed_size,
                nhead=num_heads,
                dim_feedforward=hidden_dim,
                dropout=dropout,
                batch_first=True,
            )
            self.transformer_encoder = nn.TransformerEncoder(
                layer,
                num_layers=num_layers,
                norm=nn.LayerNorm(embed_size),
            )
            self.dropout = nn.Dropout(dropout)
            self.fc = nn.Linear(embed_size, output_dim)

        def forward(self, x):
            # Positions holding the padding index are masked out of attention.
            pad_mask = (x == self.embedding.padding_idx)
            hidden = self.dropout(self.pos_encoder(self.embedding(x)))
            hidden = self.transformer_encoder(hidden, src_key_padding_mask=pad_mask)
            return self.fc(self.dropout(hidden))

    model_dir = os.path.join("models", model_name)
    cpu = torch.device('cpu')

    pretrained_embeddings = torch.load(
        os.path.join(model_dir, "embedding-matrix.pth"),
        weights_only=True,
        map_location=cpu,
    )
    hyperparameters = dict(
        vocab_size=len(vocab),
        embed_size=200,
        num_heads=8,
        hidden_dim=512,
        num_layers=2,
        output_dim=47,
        padding_idx=0,
        dropout=0.16426146772147993,
    )
    model = STransformerE(embedding_matrix=pretrained_embeddings, **hyperparameters)

    saved_state = torch.load(
        os.path.join(model_dir, "model-state.pt"),
        weights_only=True,
        map_location=cpu,
    )
    model.load_state_dict(saved_state)
    return model

# ----------------------
# Predict Function
# ----------------------
# Human-readable names for the tags described in the result legend.
# Built once at definition time instead of on every prediction call.
_POS_DESCRIPTIONS = {
    "CC": "Coordinating Conjunction",
    "CD": "Cardinal Number",
    "DT": "Determiner",
    "EX": "Existential 'There'",
    "FW": "Foreign Word",
    "IN": "Preposition or Subordinating Conjunction",
    "JJ": "Adjective",
    "JJR": "Adjective (Comparative)",
    "JJS": "Adjective (Superlative)",
    "LS": "List Item Marker",
    "MD": "Modal",
    "NN": "Noun (Singular or Mass)",
    "NNS": "Noun (Plural)",
    "NNP": "Proper Noun (Singular)",
    "NNPS": "Proper Noun (Plural)",
    "PDT": "Predeterminer",
    "POS": "Possessive Ending",
    "PRP": "Personal Pronoun",
    "PRP$": "Possessive Pronoun",
    "RB": "Adverb",
    "RBR": "Adverb (Comparative)",
    "RBS": "Adverb (Superlative)",
    "RP": "Particle",
    "SYM": "Symbol",
    "TO": "to",
    "UH": "Interjection",
    "VB": "Verb (Base Form)",
    "VBD": "Verb (Past Tense)",
    "VBG": "Verb (Gerund or Present Participle)",
    "VBN": "Verb (Past Participle)",
    "VBP": "Verb (Non-3rd-Person Singular Present)",
    "VBZ": "Verb (3rd Person Singular Present)",
    "WDT": "Wh-Determiner",
    "WP": "Wh-Pronoun",
    "WP$": "Possessive Wh-Pronoun",
    "WRB": "Wh-Adverb"
}


def predict_pos_tag(model, word2idx, idx2pos, sequence):
    """Tag a sentence and return it as a flat tree plus a tag legend.

    Args:
        model: Module that is switched to eval mode and then called with a
            ``(1, n_words)`` long tensor of word indices; the argmax over the
            last dimension of its output is taken as the tag index per word.
        word2idx: Mapping from lower-cased word to index. Must contain the
            key ``'<UNK>'``, which is used for words not in the mapping.
        idx2pos: Mapping from integer tag index to tag string.
        sequence: Either a string (split on whitespace) or a list of words.

    Returns:
        ``(tree, description)``. ``tree`` is a ``Tree`` labelled ``'S'`` with
        one ``Tree(tag, [word])`` child per word, keeping the original word
        casing. ``description`` maps each predicted tag, in order of first
        appearance, to its description, or ``"Unknown POS tag"`` when the tag
        has no entry. For input with no words, the tree has no children and
        the description is empty.

    Raises:
        ValueError: If ``sequence`` is neither a string nor a list.
        KeyError: If ``'<UNK>'`` is missing from ``word2idx`` or a predicted
            index is missing from ``idx2pos``.
    """
    model.eval()

    if isinstance(sequence, str):
        words = sequence.split()
    elif isinstance(sequence, list):
        words = sequence
    else:
        raise ValueError("Input sequence must be a string or list of words")

    # Nothing to tag: skip the forward pass rather than feeding the model a
    # zero-length sequence.
    if not words:
        return Tree('S', []), {}

    unk_index = word2idx['<UNK>']
    word_indices = [word2idx.get(word.lower(), unk_index) for word in words]
    input_tensor = torch.tensor([word_indices], dtype=torch.long)

    with torch.no_grad():
        logits = model(input_tensor)
        predictions = torch.argmax(logits, dim=-1)

    # .tolist() yields plain Python ints, so the idx2pos lookup does not rely
    # on numpy scalars hashing like ints.
    predicted_pos_indices = predictions[0][:len(word_indices)].tolist()
    predicted_pos_tags = [idx2pos[idx] for idx in predicted_pos_indices]

    tree = Tree('S', [Tree(pos, [word]) for word, pos in zip(words, predicted_pos_tags)])

    # dict.fromkeys de-duplicates while preserving first-seen order.
    description = {
        pos: _POS_DESCRIPTIONS.get(pos, "Unknown POS tag")
        for pos in dict.fromkeys(predicted_pos_tags)
    }
    return tree, description

In [23]:
# ----------------------
# User Interface
# ----------------------
def create_pos_tagging_interface(model_name="s_transformer-e"):
    """Load a tagger and display a small widget form for trying it out.

    The vocabulary and model are loaded once, when this function is called.
    The displayed form has a sentence input, a "Tag" button, and a disabled
    text area that shows the pretty-printed tree followed by one
    ``TAG: description`` line per predicted tag.

    Args:
        model_name: Name of the model passed to ``load_vocab`` and
            ``load_model``.
    """
    word2idx, idx2pos = load_vocab(model_name)
    model = load_model(model_name, word2idx)

    header = widgets.Label(value="POS Tagging")
    sentence_box = widgets.Textarea(
        description="Sentence:",
        placeholder="e.g. The quick brown fox jumps over the lazy dog.",
    )
    result_box = widgets.Textarea(
        value="Result:",
        layout=widgets.Layout(height='200px', width='500px'),
        disabled=True,
    )
    run_button = widgets.Button(description="Tag")

    def _handle_click(_button):
        sentence = sentence_box.value
        # Guard: whitespace-only input never reaches the model.
        if not sentence.strip():
            result_box.value = "Please enter some text."
            return
        tree, description = predict_pos_tag(model, word2idx, idx2pos, sentence)
        rendered_tree = TreePrettyPrinter(tree).text()
        legend = "\n".join(f"{tag}: {meaning}" for tag, meaning in description.items())
        result_box.value = f"Result:\n{rendered_tree}\n{legend}"

    run_button.on_click(_handle_click)

    display(widgets.VBox([header, sentence_box, run_button, result_box]))

In [24]:
# Load the "s_transformer-e" model files and display the tagging form below.
create_pos_tagging_interface("s_transformer-e")

VBox(children=(Label(value='POS Tagging'), Textarea(value='', description='Sentence:', placeholder='e.g. The q…