In [1]:
import os
import json
import numpy as np
import spacy
import onnxruntime
import ipywidgets as widgets
from IPython.display import display

In [2]:
# ----------------------
# Load Vocabulary
# ----------------------
def load_vocab(model_name):
    """Load the vocabulary lookup tables stored alongside a model.

    Reads ``word2idx.json``, ``char2idx.json`` and ``idx2tag.json`` from the
    directory ``./models/<model_name>``.

    Args:
        model_name: Name of the sub-directory of ``./models`` that holds the
            three JSON files.

    Returns:
        A tuple ``(word2idx, char2idx, idx2tag)``. The first two are returned
        exactly as parsed from JSON; ``idx2tag`` has its keys converted from
        JSON strings to ``int``.

    Raises:
        FileNotFoundError: If any of the three files does not exist.
        ValueError: If a file is not valid JSON or a key of ``idx2tag.json``
            cannot be converted to ``int``.
    """
    model_path = os.path.join("./models", model_name)

    def _read_json(filename):
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale encoding, which may not be able to decode non-ASCII entries.
        with open(os.path.join(model_path, filename), 'r', encoding='utf-8') as json_file:
            return json.load(json_file)

    word2idx = _read_json('word2idx.json')
    char2idx = _read_json('char2idx.json')
    idx2tag = _read_json('idx2tag.json')
    # JSON object keys are always strings; convert them to int keys.
    return word2idx, char2idx, {int(k): v for k, v in idx2tag.items()}

# ----------------------
# Load ONNX Model
# ----------------------
def load_model(model_name):
    """Open an ONNX Runtime session for a model's ``model-q.onnx`` file.

    Args:
        model_name: Name of the sub-directory of ``./models`` that contains
            ``model-q.onnx``.

    Returns:
        The ``onnxruntime.InferenceSession`` created from that file.
    """
    onnx_file = os.path.join("./models", model_name, "model-q.onnx")
    session = onnxruntime.InferenceSession(onnx_file)
    return session

# ----------------------
# Tokenizer
# ----------------------
def load_spacy():
    """Load the spaCy pipeline from its fixed location under ``./data``.

    Returns:
        The object returned by ``spacy.load`` for the directory
        ``./data/en_core_web_sm/en_core_web_sm-3.8.0``.

    Raises:
        FileNotFoundError: If that directory does not exist.
    """
    model_path = os.path.join("./data", "en_core_web_sm", "en_core_web_sm-3.8.0")
    # Fail early with a descriptive message when the model directory is absent.
    if not os.path.isdir(model_path):
        raise FileNotFoundError(f"SpaCy model not found at {model_path}. Please ensure it is correctly placed.")
    return spacy.load(model_path)
# Load the pipeline once when this cell runs; tokenizer() reads this module-level name.
spacy_en = load_spacy()

def tokenizer(sentence):
    """Split a sentence into token strings using the module-level ``spacy_en``.

    Args:
        sentence: Text passed to ``spacy_en``.

    Returns:
        A list with the ``.text`` of every token yielded by ``spacy_en(sentence)``,
        in order.
    """
    pieces = []
    for tok in spacy_en(sentence):
        pieces.append(tok.text)
    return pieces

# ----------------------
# Predict Function
# ----------------------
def _fit_to_length(items, length, filler):
    """Return a new list of exactly ``length`` entries: ``items`` cut off or right-padded with ``filler``."""
    clipped = list(items[:length])
    return clipped + [filler] * (length - len(clipped))


def predict_ner(ort_session, sentence, word2idx, char2idx, idx2tag,
                max_len=50, max_char_len=10):
    """Tag every token of ``sentence`` with the label predicted by the model.

    The sentence is tokenized with ``tokenizer``, encoded into fixed-size
    word-id and character-id arrays, and fed to ``ort_session`` under the
    input names ``"word_ids"`` and ``"char_ids"``.

    Args:
        ort_session: Object exposing ``run(output_names, inputs)``. Its first
            output is expected to be indexable as ``(1, max_len, num_tags)``,
            since the code takes ``argmax`` over axis 2 and squeezes axis 0.
        sentence: Raw text to tag.
        word2idx: Mapping from lower-cased word to id; must contain
            ``'<PAD>'`` and ``'<UNK>'``.
        char2idx: Mapping from character to id; must contain ``'<PAD>'`` and
            ``'<UNK>'``.
        idx2tag: Mapping from integer tag id to tag name.
        max_len: Number of token positions sent to the model. Longer
            sentences are truncated, shorter ones padded. Defaults to 50.
        max_char_len: Number of character positions per token. Defaults to 10.

    Returns:
        A tuple ``(tokens, pred_tags)`` of two lists of equal length. Tokens
        beyond ``max_len`` are not tagged and are therefore not returned.

    Raises:
        KeyError: If a vocabulary lacks ``'<PAD>'``/``'<UNK>'`` or a predicted
            id is missing from ``idx2tag``.
    """
    # Only the first max_len tokens reach the model; truncating here keeps the
    # returned tokens aligned one-to-one with the returned tags.
    tokens = tokenizer(sentence)[:max_len]

    word_pad, word_unk = word2idx['<PAD>'], word2idx['<UNK>']
    char_pad, char_unk = char2idx['<PAD>'], char2idx['<UNK>']

    # Words are looked up lower-cased; characters keep their original case.
    word_ids = _fit_to_length(
        [word2idx.get(token.lower(), word_unk) for token in tokens],
        max_len,
        word_pad,
    )
    char_rows = [
        _fit_to_length([char2idx.get(c, char_unk) for c in token], max_char_len, char_pad)
        for token in tokens
    ]
    # Padding token positions get a row made entirely of the character pad id.
    char_ids = _fit_to_length(char_rows, max_len, [char_pad] * max_char_len)

    # The outer list adds a leading dimension of size 1 to both arrays.
    inputs = {
        "word_ids": np.array([word_ids], dtype=np.int64),
        "char_ids": np.array([char_ids], dtype=np.int64),
    }
    emissions = ort_session.run(None, inputs)[0]
    preds = np.argmax(emissions, axis=2).squeeze(0)

    # Ignore predictions made for padding positions.
    pred_tags = [idx2tag[int(preds[i])] for i in range(len(tokens))]
    return tokens, pred_tags

In [3]:
# ----------------------
# User Interface
# ----------------------
def create_ner_interface(model_name="bilstm-w-a"):
    """Build and display a small widget UI for tagging a sentence.

    Loads the vocabularies and the model for ``model_name``, then displays a
    vertical box containing a title label, an input text area, an "Infer"
    button and a disabled output text area. Clicking the button runs
    ``predict_ner`` on the input and writes one ``token: tag`` line per token
    into the output area.

    Args:
        model_name: Name passed to ``load_vocab`` and ``load_model``.
    """
    word2idx, char2idx, idx2tag = load_vocab(model_name)
    session = load_model(model_name)

    header = widgets.Label(value="Named-Entity Recognition (NER)")
    sentence_box = widgets.Textarea(
        description="Sentence:",
        placeholder="e.g. U.N. official Ekeus heads for Baghdad.",
    )
    result_box = widgets.Textarea(
        value="Result:",
        layout=widgets.Layout(height='150px'),
        disabled=True,
    )
    infer_button = widgets.Button(description="Infer")

    def on_infer_clicked(b):
        sentence = sentence_box.value
        # Whitespace-only input is rejected without calling the model.
        if not sentence.strip():
            result_box.value = "Please enter some text for analysis."
            return
        tokens, tags = predict_ner(session, sentence, word2idx, char2idx, idx2tag)
        lines = [f"{token}: {tag}" for token, tag in zip(tokens, tags)]
        result_box.value = "Result:\n" + "\n".join(lines)

    infer_button.on_click(on_infer_clicked)

    display(widgets.VBox([header, sentence_box, infer_button, result_box]))

In [4]:
# Build and display the NER widget UI, loading files from ./models/bilstm-w-a.
create_ner_interface("bilstm-w-a")

VBox(children=(Label(value='Named-Entity Recognition (NER)'), Textarea(value='', description='Sentence:', plac…