In [13]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
import json
import string

MODEL_DIR = "snips_ner_model_full"

# Load the tokenizer and token-classification model stored in MODEL_DIR.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForTokenClassification.from_pretrained(MODEL_DIR)
model.eval()  # inference only: put the module in evaluation mode

# JSON object keys are always strings, so convert them back to ints.
# JSON text is UTF-8; read it explicitly as such instead of relying on the
# platform-dependent default encoding of open().
with open("ner_id2label.json", encoding="utf-8") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}

In [14]:
# Clean text (match training behavior)

def clean_inference_text(text: str) -> str:
    """Return ``text`` with every ASCII punctuation character removed.

    The characters dropped are exactly those listed in ``string.punctuation``;
    all other characters (letters, digits, whitespace, non-ASCII symbols) are
    kept in their original order.
    """
    punctuation = set(string.punctuation)
    return "".join(ch for ch in text if ch not in punctuation)

In [15]:
# Merge WordPieces into words

def wordpiece_to_words(tokens, tags):
    """Merge WordPiece fragments back into whole words.

    A token starting with ``"##"`` is glued (without the ``"##"`` marker) onto
    the word currently being assembled; any other token starts a new word.
    Each merged word keeps the tag of the token that started it. If the very
    first token is a ``"##"`` continuation, the resulting word's tag is ``None``.

    Args:
        tokens: iterable of WordPiece token strings.
        tags: iterable of tags, paired positionally with ``tokens``.

    Returns:
        A ``(words, word_tags)`` tuple of two equally long lists.
    """
    merged_words, merged_tags = [], []
    pieces = []          # fragments of the word currently being assembled
    leading_tag = None   # tag carried by the fragment that opened that word

    for piece, piece_tag in zip(tokens, tags):
        if not piece.startswith("##"):
            # A fresh word begins here: emit the one assembled so far, if any.
            finished = "".join(pieces)
            if finished:
                merged_words.append(finished)
                merged_tags.append(leading_tag)
            pieces = [piece]
            leading_tag = piece_tag
        else:
            # Continuation fragment: strip the "##" marker and glue it on.
            pieces.append(piece[2:])

    # Emit whatever is still pending after the last token.
    finished = "".join(pieces)
    if finished:
        merged_words.append(finished)
        merged_tags.append(leading_tag)

    return merged_words, merged_tags

In [16]:
# Decode BIO tags on word-level tokens
def decode_bio(words, tags):
    """Group word-level BIO tags into slot spans.

    Tags are interpreted as follows:

    * ``"B-<slot>"`` starts a new entity for ``<slot>``.
    * ``"I-<slot>"`` extends the open entity when it has the same slot;
      otherwise (malformed sequence) it is treated like ``"B-<slot>"``.
    * Anything else -- ``"O"``, ``None``, an empty string, a label without a
      hyphen (e.g. ``"PAD"``) or with an unrecognised prefix -- is treated as
      "outside" and closes the open entity. Such tags never raise.

    Args:
        words: iterable of word strings.
        tags: iterable of tag strings (or ``None``), paired positionally with
            ``words``.

    Returns:
        A list of ``{"slot": <slot>, "value": <words joined by spaces>}``
        dicts, in order of appearance.
    """
    results = []
    buffer = []
    current_slot = None

    for w, tag in zip(words, tags):
        # partition() never raises on a hyphen-less tag (unlike unpacking the
        # result of split); `sep` is empty when there is no hyphen at all.
        prefix, sep, slot = (tag or "").partition("-")
        is_begin = bool(sep) and prefix == "B"
        is_inside = bool(sep) and prefix == "I"

        if is_inside and slot == current_slot:
            # Continuation of the entity that is currently open.
            buffer.append(w)
            continue

        # Every other case ends the open entity (if any) at this word.
        if buffer and current_slot:
            results.append({
                "slot": current_slot,
                "value": " ".join(buffer)
            })

        if is_begin or is_inside:
            # A B-tag, or a stray I-tag that is promoted to a B-tag.
            buffer = [w]
            current_slot = slot
        else:
            # Outside any entity.
            buffer = []
            current_slot = None

    # Flush the entity still open at the end of the sequence.
    if buffer and current_slot:
        results.append({
            "slot": current_slot,
            "value": " ".join(buffer)
        })

    return results

In [17]:
# Main prediction function

def predict_slots(text):
    """Run slot tagging on ``text`` and return the decoded slots.

    The text is cleaned with ``clean_inference_text``, tokenized (truncated to
    at most 64 tokens), and passed through ``model``. The highest-scoring
    label per token is looked up in ``id2label``; ``[CLS]``, ``[SEP]`` and
    ``[PAD]`` tokens are dropped, the remaining WordPieces are merged into
    words, and the word-level tags are decoded into spans.

    Args:
        text: the raw query string.

    Returns:
        A ``(words, word_tags, spans)`` tuple as produced by
        ``wordpiece_to_words`` and ``decode_bio``.
    """
    encoding = tokenizer(
        clean_inference_text(text),
        return_tensors="pt",
        truncation=True,
        max_length=64,
    )

    with torch.no_grad():
        logits = model(**encoding).logits
    # Only the first (and only) sequence in the batch is decoded.
    label_ids = torch.argmax(logits, dim=-1)[0].tolist()

    pieces = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])
    piece_tags = [id2label[label_id] for label_id in label_ids]

    # Drop CLS, SEP and PAD positions, keeping tokens and tags aligned.
    kept = [
        (piece, piece_tag)
        for piece, piece_tag in zip(pieces, piece_tags)
        if piece not in ("[CLS]", "[SEP]") and not piece.startswith("[PAD]")
    ]
    kept_pieces = [piece for piece, _ in kept]
    kept_tags = [piece_tag for _, piece_tag in kept]

    # 1. Convert WordPieces -> words
    words, word_tags = wordpiece_to_words(kept_pieces, kept_tags)

    # 2. Decode BIO tags at word level
    spans = decode_bio(words, word_tags)

    return words, word_tags, spans

In [18]:
# Run inference

# Prompt for a query, run the full pipeline, and show each stage of the result.
text = input("Enter query: ")
words, tags, spans = predict_slots(text)

print()
print("WORDS:", words)
print("TAGS :", tags)
print()
print("SPANS:", spans)


WORDS: ['id', 'like', 'to', 'have', 'this', 'track', 'onto', 'my', 'classical', 'relaxations', 'playlist']
TAGS : ['O', 'O', 'O', 'O', 'O', 'B-music_item', 'O', 'B-playlist_owner', 'B-playlist', 'I-playlist', 'O']

SPANS: [{'slot': 'music_item', 'value': 'track'}, {'slot': 'playlist_owner', 'value': 'my'}, {'slot': 'playlist', 'value': 'classical relaxations'}]
