In [11]:
# Standard library
import os
import sys

# Core PyTorch
import torch
import torch.nn as nn

# TorchText for NLP
from torchtext.data.utils import get_tokenizer

# Make the project root (one level up from 'notebooks') importable.
# The membership check keeps sys.path free of duplicate entries when this
# cell is re-run in the same kernel.
PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Custom project modules (resolved through the path entry added above)
from model import TransformerClassifier
from utils import evaluate
import config

print("✅ All necessary libraries and modules imported successfully.")

✅ All necessary libraries and modules imported successfully.


In [12]:
# --- Part 1: Load Saved Artifacts ---

# Class index -> human-readable label, shared by the cells below
CATEGORY_MAP = dict(enumerate(("World", "Sports", "Business", "Sci/Tech")))

# Tokenizer used to split raw text into tokens
tokenizer = get_tokenizer("basic_english")

# Restore the vocabulary object stored at config.VOCAB_SAVE_PATH
vocab = torch.load(config.VOCAB_SAVE_PATH)
# Index of the padding token; not needed for single-sentence prediction
PAD_IDX = vocab["<pad>"]

print(f"✅ Vocab loaded. Size: {len(vocab)} words.")
print(f"✅ Tokenizer defined. Running on device: {config.DEVICE}")

✅ Vocab loaded. Size: 95812 words.
✅ Tokenizer defined. Running on device: cuda


In [13]:
# --- Part 2: Load Trained Model ---

# The architecture is rebuilt from the hyperparameters stored in `config`; these
# must be the same values that were used for training so that the saved weights
# line up with the freshly created layers.
VOCAB_SIZE = len(vocab)

# Instantiate the model architecture on the target device
inference_model = TransformerClassifier(
    vocab_size=VOCAB_SIZE,
    d_model=config.D_MODEL,
    num_heads=config.NUM_HEADS,
    num_layers=config.NUM_LAYERS,
    d_ff=config.D_FF,
    num_classes=config.NUM_CLASSES
).to(config.DEVICE)

# Load the saved weights from the best model file.
# map_location remaps the checkpoint's tensors onto config.DEVICE, so a checkpoint
# written on a GPU can still be loaded when config.DEVICE is the CPU.
inference_model.load_state_dict(
    torch.load(config.MODEL_SAVE_PATH, map_location=config.DEVICE)
)

# Switch to evaluation mode before running inference
inference_model.eval()

print(f"✅ Model architecture created and trained weights loaded from {config.MODEL_SAVE_PATH}.")

✅ Model architecture created and trained weights loaded from ../models/transformer_news_classifier_best.pth.


In [14]:

def predict(text, model, vocab, tokenizer, device, category_map=None):
    """
    Classify a single raw text string and return its category name.

    Args:
        text (str): The input news article or sentence.
        model (nn.Module): The trained classifier. It is called as
            ``model(token_ids)`` with a LongTensor of shape (1, num_tokens) and
            must return logits whose class dimension is dim 1.
        vocab: Token-to-index lookup, indexed as ``vocab[token]``.
        tokenizer (callable): Splits ``text`` into a sequence of tokens.
        device: Device the model lives on (e.g. 'cuda' or 'cpu'); the input
            tensor is moved there before the forward pass.
        category_map (dict, optional): Mapping from predicted class index to
            label. Defaults to the notebook-level ``CATEGORY_MAP``.

    Returns:
        str: The predicted category name.

    Raises:
        ValueError: If ``tokenizer(text)`` yields no tokens (empty or
            whitespace-only input), since there is nothing to classify.
    """
    # Resolve the label map lazily so callers may inject their own mapping
    if category_map is None:
        category_map = CATEGORY_MAP

    # 1. Tokenize first so that empty input is rejected before touching the model
    tokens = list(tokenizer(text))
    if not tokens:
        raise ValueError("Cannot classify empty text: the tokenizer produced no tokens.")

    # 2. Make sure the model is in evaluation mode
    model.eval()

    # 3. Numericalize, add a batch dimension of 1, and move to the model's device
    token_ids = torch.tensor([vocab[token] for token in tokens], dtype=torch.long)
    token_ids = token_ids.unsqueeze(0).to(device)

    # 4. Forward pass; no gradient tracking is needed for inference
    with torch.no_grad():
        logits = model(token_ids)

    # 5. The predicted class is the index of the largest logit
    predicted_index = torch.argmax(logits, dim=1).item()

    # 6. Convert the index back to a human-readable category name
    return category_map[predicted_index]

In [15]:
# --- INFERENCE TEST ---

# Sample headlines to classify. The trailing comment on each entry is the
# category the headline is intended to represent; compare it with the
# printed prediction.
SAMPLE_HEADLINES = [
    "The US economy is a puzzle but the pieces aren't fitting together",  # Business
    "Premier League: Chelsea held by Crystal Palace & Forest beat Brentford - reaction",  # Sports
    "Will AI make language dubbing easy for film and TV?",  # Sci/Tech
    "Putin agreed to security guarantees for Ukraine being part of potential peace deal, US envoy says",  # World
]

# Predictions are collected in the same order as SAMPLE_HEADLINES
sample_predictions = []
for headline in SAMPLE_HEADLINES:
    category = predict(headline, inference_model, vocab, tokenizer, config.DEVICE)
    sample_predictions.append(category)
    print(f"Article: '{headline}'")
    print(f"Predicted Category: {category}\n")

Article: 'The US economy is a puzzle but the pieces aren't fitting together'
Predicted Category: Sports

Article: 'Premier League: Chelsea held by Crystal Palace & Forest beat Brentford - reaction'
Predicted Category: World

Article: 'Will AI make language dubbing easy for film and TV?'
Predicted Category: Sci/Tech

Article: 'Putin agreed to security guarantees for Ukraine being part of potential peace deal, US envoy says'
Predicted Category: World

