# Quick Prediction Test

A simple notebook for running the trained transformer classifier on custom text.

In [None]:
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd().parent))

import torch
from transformers import DistilBertModel, DistilBertTokenizerFast
import transformer.config as config
from transformer.model import TransformerClassifier, DebiasedTransformerClassifier

# Load model
# map_location='cpu': nothing in this notebook is moved to a GPU, and forcing CPU
# also lets checkpoints that were saved from a GPU run load on a CUDA-less machine.
ensemble_info = torch.load(
    config.CACHE_DIR / 'transformer_ensemble_info.pt',
    map_location='cpu',
    weights_only=True,
)
# int(): keeps the fold index a plain integer in case it was stored as a tensor
# element; otherwise the checkpoint file name below would be formatted as
# 'tensor(3)' instead of '3'.
# NOTE(review): assumes the last entry of 'top_indices' is the best fold — confirm
# against the code that writes transformer_ensemble_info.pt.
best_fold = int(ensemble_info['top_indices'][-1])
checkpoint = torch.load(
    config.CACHE_DIR / f'transformer_model_fold_{best_fold}.pt',
    map_location='cpu',
    weights_only=True,
)

bert = DistilBertModel.from_pretrained(config.TRANSFORMER_MODEL)
tokenizer = DistilBertTokenizerFast.from_pretrained(config.TRANSFORMER_MODEL)

# Load correct model type
# Checkpoints without a 'debiased' flag are treated as the standard classifier.
is_debiased = checkpoint.get('debiased', False)
print(f'Model type: {"Debiased" if is_debiased else "Standard"}')

# Constructor arguments shared by both classifier variants.
# NOTE(review): the sizes are hard-coded; they presumably mirror the training
# configuration (768 looks like the encoder's hidden size) — confirm.
model_kwargs = dict(
    input_dim=768,
    hidden_dim=256,
    dropout=0.5,
    num_extra_features=checkpoint.get('num_extra_features', 0),
)
if is_debiased:
    model = DebiasedTransformerClassifier(
        **model_kwargs,
        gradient_reversal_lambda=config.GRADIENT_REVERSAL_LAMBDA,
    )
else:
    model = TransformerClassifier(**model_kwargs)
model.load_state_dict(checkpoint['model_state_dict'])
# Switch both networks to evaluation mode before running predictions.
model.eval()
bert.eval()

print(f'Loaded fold {best_fold} (F1={ensemble_info["fold_scores"][best_fold]:.4f})')

In [None]:
def predict(text):
    """Score one text with the loaded encoder and classifier.

    The text is tokenized (truncated to at most 512 tokens), passed through
    ``bert`` with gradient tracking disabled, and the hidden state at the
    first token position is fed to ``model``.

    Args:
        text: The text to classify. ``.item()`` below requires the model
            output to hold exactly one value, so pass a single text.

    Returns:
        A ``(prob, label)`` tuple: ``prob`` is the scalar returned by
        ``model`` and ``label`` is ``'HYPERPARTISAN'`` when ``prob`` is
        above 0.5, otherwise ``'MAINSTREAM'``.
    """
    tokens = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=512,
        padding=True,
    )
    with torch.no_grad():
        hidden_states = bert(**tokens).last_hidden_state
        # Position 0 of every sequence is used as the sequence representation.
        first_token = hidden_states[:, 0]
        prob = model(first_token).item()
    if prob > 0.5:
        label = 'HYPERPARTISAN'
    else:
        label = 'MAINSTREAM'
    return prob, label

In [None]:
# Paste the article text to classify between the triple quotes
article = """The corrupt politicians are destroying our country with their radical socialist agenda."""

# Score the article, then report the score and the resulting label on separate lines
prob, label = predict(article)
print(
    f'P(hyperpartisan): {prob:.3f}',
    f'Prediction: {label}',
    sep='\n',
)

In [None]:
# More examples
examples = [
    "Trump is destroying America with his fascist policies.",
    "The president signed the bill into law on Tuesday.",
    "Democrats are evil communists trying to take our freedom.",
    "Congress passed bipartisan legislation on infrastructure spending.",
]

# Longest prefix of each text echoed next to its prediction.
PREVIEW_CHARS = 60

for text in examples:
    prob, label = predict(text)
    # Append the ellipsis only when the text was actually cut off.
    if len(text) > PREVIEW_CHARS:
        preview = f'{text[:PREVIEW_CHARS]}...'
    else:
        preview = text
    # Width 13 fits 'HYPERPARTISAN', the longer of the two labels, so the
    # score column lines up across rows.
    print(f'{label:13} ({prob:.2f}): {preview}')