In [8]:
import torch
import json
from transformers import GPT2TokenizerFast
import torch.nn.functional as F

def _top_prediction(logits, decoder: dict):
    """Return ``(label, confidence)`` for the most probable class.

    Args:
        logits: Classification logits; only the first row (index 0 after
            softmax over the last dimension) is used.
        decoder: Mapping from class index to label string.

    Returns:
        Tuple of the decoded label and its softmax probability as a float.
    """
    probs = F.softmax(logits, dim=-1)[0]
    confidence, index = torch.max(probs, dim=-1)
    return decoder[index.item()], confidence.item()


def _extract_entities(input_text: str, ner_labels: list, ner_confidences: list,
                      offset_mapping: list) -> list:
    """Merge per-token BIO tags into entity spans.

    The entity surface form is sliced out of ``input_text`` using the
    tokenizer's character offsets, so the result does not depend on the
    tokenizer's sub-word markers.

    Args:
        input_text: The original, untokenized text.
        ner_labels: One tag string per token (``"B-X"``, ``"I-X"``, ``"O"``).
        ner_confidences: One probability per token, aligned with ``ner_labels``.
        offset_mapping: One ``(start, end)`` character span per token,
            aligned with ``ner_labels``.

    Returns:
        List of dicts with keys ``"entity"`` (text), ``"label"`` (tag without
        the ``B-``/``I-`` prefix) and ``"confidence"`` (mean token probability).
    """
    entities = []
    current_label = None
    span_start = span_end = 0
    span_confidences = []

    def close_span():
        # Emit the open entity (if any) and reset the accumulator.
        nonlocal current_label, span_confidences
        if current_label is not None:
            entities.append({
                "entity": input_text[span_start:span_end].strip(),
                "label": current_label,
                "confidence": sum(span_confidences) / len(span_confidences),
            })
        current_label = None
        span_confidences = []

    for label, conf, (start, end) in zip(ner_labels, ner_confidences, offset_mapping):
        if label.startswith("B-"):
            close_span()
            current_label = label[2:]
            span_start, span_end = start, end
            span_confidences = [conf]
        elif label.startswith("I-") and current_label == label[2:]:
            # Continuation of the open entity: extend its character span.
            span_end = max(span_end, end)
            span_confidences.append(conf)
        else:
            # "O", or an "I-" tag that does not continue the open entity.
            # Either way the open span ends here; an orphan "I-" tag does not
            # start a new entity.
            close_span()

    close_span()
    return entities


def run_inference(model_path: str, tokenizer_path: str, label_encoders_path: str, input_text: str,
                  max_length: int = 128, debug: bool = True) -> dict:
    """Run intent, category and NER prediction for a single text.

    Loads a pickled model, a GPT-2 fast tokenizer and the JSON label encoders
    from disk, tokenizes ``input_text`` and decodes the three model heads.

    Args:
        model_path: Path to a full model saved with ``torch.save``.
        tokenizer_path: Directory/identifier accepted by
            ``GPT2TokenizerFast.from_pretrained``.
        label_encoders_path: JSON file containing the ``"intent_encoder"``,
            ``"category_encoder"`` and ``"ner_label_encoder"`` mappings
            (label -> index).
        input_text: Text to classify and tag.
        max_length: Length the input is padded/truncated to.
        debug: When true, print the per-token NER tags and probabilities.

    Returns:
        Dict with ``"intent"`` and ``"category"`` (each
        ``{"label", "confidence"}``) and ``"ner"`` (list of
        ``{"entity", "label", "confidence"}`` dicts).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # SECURITY: weights_only=False unpickles arbitrary Python objects and can
    # execute code embedded in the file. Only load checkpoints you trust.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval()

    tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_path)
    if tokenizer.pad_token is None:
        # Stock GPT-2 tokenizers define no pad token, which makes
        # padding="max_length" fail; padded positions are masked out anyway.
        tokenizer.pad_token = tokenizer.eos_token

    # The encoders map label -> index; invert them for decoding.
    with open(label_encoders_path, "r", encoding="utf-8") as f:
        label_encoders = json.load(f)
    intent_decoder = {v: k for k, v in label_encoders["intent_encoder"].items()}
    category_decoder = {v: k for k, v in label_encoders["category_encoder"].items()}
    ner_decoder = {v: k for k, v in label_encoders["ner_label_encoder"].items()}

    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_offsets_mapping=True,
    )
    # Offsets are only needed on the host, so they never go to the device.
    offset_mapping = encoded["offset_mapping"][0].tolist()
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    intent_label, intent_confidence = _top_prediction(outputs["intent_logits"], intent_decoder)
    category_label, category_confidence = _top_prediction(outputs["category_logits"], category_decoder)

    # Per-token NER head: take the first (only) example in the batch.
    ner_probs = F.softmax(outputs["ner_logits"][0], dim=-1)
    token_confidences, token_label_ids = ner_probs.max(dim=-1)

    # Keep only positions the attention mask marks as real tokens. Filtering
    # by the mask (rather than slicing a prefix) is correct for either
    # padding side.
    all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    kept = [
        (token, ner_decoder[label_id], conf, span)
        for is_real, token, label_id, conf, span in zip(
            attention_mask[0].bool().tolist(),
            all_tokens,
            token_label_ids.tolist(),
            token_confidences.tolist(),
            offset_mapping,
        )
        if is_real
    ]
    tokens = [row[0] for row in kept]
    ner_labels = [row[1] for row in kept]
    ner_confidences = [row[2] for row in kept]
    offset_mapping = [row[3] for row in kept]

    # Debug raw predictions
    if debug:
        print(f"\nText: {input_text}")
        print("Token | Predicted Tag | Probability")
        for token, label, conf in zip(tokens, ner_labels, ner_confidences):
            print(f"{token:<15} | {label:<15} | {conf:.4f}")

    entities = _extract_entities(input_text, ner_labels, ner_confidences, offset_mapping)

    return {
        "intent": {"label": intent_label, "confidence": intent_confidence},
        "category": {"label": category_label, "confidence": category_confidence},
        "ner": entities,
    }

# Example usage
if __name__ == "__main__":
    # Artifacts written by the baseline "test_2" run.
    label_encoders_path = "../results/baseline/test_2/label_encoders.json"
    tokenizer_path = "../results/baseline/test_2/tokenizer"
    model_path = "../results/baseline/test_2/model/full_model.pt"

    # Sample query; the text is passed to the model exactly as written.
    input_text = "I want to cnacel order ord-2134"

    results = run_inference(
        model_path,
        tokenizer_path,
        label_encoders_path,
        input_text,
    )

    # Human-readable summary of the three prediction heads.
    print("\nInference Results:")
    print(f"Intent: {results['intent']['label']} (Confidence: {results['intent']['confidence']:.4f})")
    print(f"Category: {results['category']['label']} (Confidence: {results['category']['confidence']:.4f})")
    print("NER:")
    for entity in results['ner']:
        print(f"  Entity: {entity['entity']} | Label: {entity['label']} | Confidence: {entity['confidence']:.4f}")


Text: I want to cnacel order ord-2134
Token | Predicted Tag | Probability
I               | O               | 0.9995
Ġwant           | O               | 0.9998
Ġto             | O               | 0.9999
Ġc              | O               | 0.9986
n               | O               | 0.9986
ac              | O               | 0.9994
el              | O               | 0.9989
Ġorder          | O               | 0.9998
Ġord            | O               | 0.9996
-               | O               | 0.9955
2               | O               | 0.9986
134             | O               | 0.9972

Inference Results:
Intent: delivery_options (Confidence: 0.1098)
Category: order (Confidence: 0.3472)
NER:
