# Task 3: BERT Inference (Kaggle)

Attach these Kaggle inputs before running:
- Dataset: `thesis-data-task3-test-gold-500`
- Notebook Output: `task3-bert-training` (latest version)

The notebook loads the trained model, runs QA-based span extraction across 500 test files, reports Exact Match/F1, shows sample predictions, and saves `task3_bert_predictions.csv`.

In [None]:
# ============================================================
# Cell 1: Setup
# ============================================================
import json
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from tqdm import tqdm
import pandas as pd

# Status line: reaching this point means every import in this cell succeeded.
print("✅ Libraries imported")

In [None]:
# ============================================================
# Cell 2: Load Model
# ============================================================
# Location of the fine-tuned checkpoint exported by the training notebook.
model_path = '/kaggle/input/task3-bert-training/models/task3_bert_final'

print(f"📦 Loading model from: {model_path}")

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Tokenizer and QA model are both restored from the same checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForQuestionAnswering.from_pretrained(model_path)

# Place the weights on the chosen device and switch the module to eval mode.
model = model.to(device)
model.eval()

print(f"✅ Model loaded on {device}")
print(f"📊 Model parameters: {model.num_parameters():,}")

In [None]:
# ============================================================
# Cell 3: Load Test Data
# ============================================================
# Directory of the attached Kaggle dataset that holds the test inputs.
test_path = Path('/kaggle/input/thesis-data-task3-test-gold-500') / 'test_gold_500'

# Collect every *.in file, then sort in place so the processing order is stable.
test_files = list(test_path.glob("*.in"))
test_files.sort()

print(f"📂 Test path: {test_path}")
print(f"📊 Found {len(test_files)} test files")

In [None]:
# ============================================================
# Cell 4: Inference Function
# ============================================================
def _best_valid_span(start_logits, end_logits):
    """Return the (start, end) token pair with start > 0, start <= end and the highest summed logit."""
    seq_len = start_logits.size(0)
    positions = torch.arange(seq_len, device=start_logits.device)
    starts = positions.unsqueeze(1)  # row index = candidate start
    ends = positions.unsqueeze(0)    # column index = candidate end

    # Score every (start, end) combination, then rule out inverted spans and
    # spans that begin at position 0 (treated as "no answer" by the caller).
    pair_scores = start_logits.unsqueeze(1) + end_logits.unsqueeze(0)
    allowed = (ends >= starts) & (starts > 0)
    pair_scores = pair_scores.masked_fill(~allowed, float('-inf'))

    best_flat = torch.argmax(pair_scores).item()
    return divmod(best_flat, seq_len)


def extract_citation_span(text, citation, model, tokenizer, device, max_length=512):
    """
    Extract the span supported by a citation using a Question Answering approach.

    The citation is interpolated into a fixed question, paired with ``text``
    as the context, and passed through the QA model. The most likely start
    and end token positions are picked independently; if that pair is
    inverted (end before start), the highest-scoring valid pair is used
    instead of discarding the prediction.

    Args:
        text: Context passage to search. Only the context is truncated
            (``truncation='only_second'``) to fit ``max_length``, so a span
            located past the cut-off cannot be returned.
        citation: Citation identifier inserted into the question.
        model: QA model whose output exposes ``start_logits`` and
            ``end_logits``.
        tokenizer: Tokenizer used to encode the pair and decode the answer.
        device: Device the encoded tensors are moved to before the forward pass.
        max_length: Maximum number of tokens in the encoded question + context.

    Returns:
        dict with ``predicted_span`` (decoded text; ``""`` when the model's
        most likely start is position 0, presumably the [CLS] token),
        ``start_idx`` / ``end_idx`` (token positions in the encoded input)
        and ``confidence`` (mean of the start and end softmax probabilities
        at the chosen positions).
    """
    question = f"What does citation {citation} support?"

    encoded = tokenizer(
        question,
        text,
        max_length=max_length,
        truncation='only_second',
        return_tensors='pt',
        padding=True
    )

    # Move every tensor to the same device as the model.
    inputs = {k: v.to(device) for k, v in encoded.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # A single example is encoded, so row 0 carries the only set of logits.
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    start_idx = torch.argmax(start_logits).item()
    end_idx = torch.argmax(end_logits).item()

    # Independent argmax can yield end < start. Rather than returning an empty
    # answer, fall back to the best jointly-valid span. A start at position 0
    # is left untouched so it still maps to an empty prediction.
    if start_idx > 0 and end_idx < start_idx:
        start_idx, end_idx = _best_valid_span(start_logits, end_logits)

    if 0 < start_idx <= end_idx:
        answer_tokens = inputs['input_ids'][0][start_idx:end_idx + 1]
        predicted_span = tokenizer.decode(answer_tokens, skip_special_tokens=True)
    else:
        predicted_span = ""

    # Confidence: average probability mass on the chosen start and end tokens.
    start_prob = torch.softmax(start_logits, dim=0)[start_idx].item()
    end_prob = torch.softmax(end_logits, dim=0)[end_idx].item()
    confidence = (start_prob + end_prob) / 2

    return {
        'predicted_span': predicted_span,
        'start_idx': start_idx,
        'end_idx': end_idx,
        'confidence': confidence
    }

print("✅ Inference function defined")

In [None]:
# ============================================================
# Cell 5: Run Inference on Test Set
# ============================================================
# One row per (file, citation) prediction, and one entry per file that failed.
results = []
errors = []

print("🚀 Starting inference...")
print("=" * 60)

for test_file in tqdm(test_files, desc="Processing files"):
    try:
        # Load input file (explicit UTF-8 so decoding does not depend on the locale)
        with open(test_file, encoding='utf-8') as f:
            in_data = json.load(f)

        # Load label file (for comparison); it shares the stem of the .in file
        label_file = test_file.with_suffix('.label')
        with open(label_file, encoding='utf-8') as f:
            label_data = json.load(f)

        text = in_data['text']
        citation_spans = label_data.get('citation_spans', [])

        # Process each citation listed in the gold labels
        for span_info in citation_spans:
            citation_id = span_info['citation_id']
            gold_span = span_info['span_text']

            # Run inference
            prediction = extract_citation_span(
                text=text,
                citation=citation_id,
                model=model,
                tokenizer=tokenizer,
                device=device
            )

            results.append({
                'file': test_file.stem,
                'citation_id': citation_id,
                'gold_span': gold_span,
                'predicted_span': prediction['predicted_span'],
                'confidence': prediction['confidence'],
                'start_idx': prediction['start_idx'],
                'end_idx': prediction['end_idx']
            })

    except Exception as e:
        # Best-effort run: record the failure and continue with the next file.
        # Predictions appended for this file before the failure are kept.
        errors.append({
            'file': test_file.stem,
            'error': str(e)
        })

# The newline must be an escape sequence; a raw line break inside the literal
# is a syntax error.
print("\n✅ Inference complete!")
print(f"📊 Processed: {len(results)} predictions")
print(f"❌ Errors: {len(errors)}")

In [None]:
# ============================================================
# Cell 6: Calculate Metrics
# ============================================================
def calculate_exact_match(gold, pred):
    """Return True when gold and pred are equal, ignoring case and surrounding whitespace."""
    normalized_gold, normalized_pred = (s.strip().lower() for s in (gold, pred))
    return normalized_gold == normalized_pred


def calculate_f1(gold, pred):
    """
    Token-overlap F1 between a gold span and a predicted span.

    Both strings are stripped, lower-cased and split on whitespace. Overlap
    is counted as a multiset intersection, so a repeated token only matches
    as many times as it occurs in both strings. This keeps precision and
    recall consistent with the token counts they are divided by, and gives
    identical strings an F1 of 1.0 even when they contain repeated tokens.

    Args:
        gold: Reference span text.
        pred: Predicted span text.

    Returns:
        F1 in [0.0, 1.0]. If either side has no tokens, the score is 1.0
        when both are empty and 0.0 otherwise, which agrees with
        calculate_exact_match on empty input.
    """
    from collections import Counter  # local import keeps the cell self-contained

    gold_tokens = gold.strip().lower().split()
    pred_tokens = pred.strip().lower().split()

    if not gold_tokens or not pred_tokens:
        return float(gold_tokens == pred_tokens)

    # Counter & Counter keeps the minimum count of each shared token.
    overlap = sum((Counter(gold_tokens) & Counter(pred_tokens)).values())

    if overlap == 0:
        return 0.0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)

    f1 = 2 * (precision * recall) / (precision + recall)
    return f1


# Calculate metrics: accumulate exact-match hits and per-prediction F1.
exact_matches = 0
total_f1 = 0

for result in results:
    if calculate_exact_match(result['gold_span'], result['predicted_span']):
        exact_matches += 1

    total_f1 += calculate_f1(result['gold_span'], result['predicted_span'])

# Every average is guarded so an empty run reports zeros instead of raising
# ZeroDivisionError.
exact_match_score = exact_matches / len(results) if results else 0
avg_f1_score = total_f1 / len(results) if results else 0
avg_confidence = sum(r['confidence'] for r in results) / len(results) if results else 0

print("=" * 60)
print("📊 EVALUATION METRICS")
print("=" * 60)
print(f"Total predictions: {len(results)}")
print(f"Exact Match: {exact_match_score:.4f} ({exact_matches}/{len(results)})")
print(f"F1 Score: {avg_f1_score:.4f}")
print(f"Average Confidence: {avg_confidence:.4f}")
print("=" * 60)

In [None]:
# ============================================================
# Cell 7: Show Examples
# ============================================================
# Newlines are written as escape sequences; a raw line break inside a
# single-quoted literal is a syntax error.
print("\n📋 SAMPLE PREDICTIONS:\n")

# Show first 10 predictions, each preceded by a separator rule.
for i, result in enumerate(results[:10]):
    print(f"\n{'='*60}")
    print(f"Example {i+1}:")
    print(f"File: {result['file']}")
    print(f"Citation: {result['citation_id']}")
    # Spans are clipped to 100 characters to keep the listing readable.
    print(f"Gold Span: {result['gold_span'][:100]}...")
    print(f"Predicted: {result['predicted_span'][:100]}...")
    print(f"Confidence: {result['confidence']:.4f}")
    match = "✅ MATCH" if calculate_exact_match(result['gold_span'], result['predicted_span']) else "❌ NO MATCH"
    print(f"Result: {match}")

In [None]:
# ============================================================
# Cell 8: Save Results
# ============================================================
# Convert the per-citation prediction rows to a DataFrame.
df = pd.DataFrame(results)

# Save to CSV in the notebook's working directory (no index column).
output_file = 'task3_bert_predictions.csv'
df.to_csv(output_file, index=False)

# Newlines are written as escape sequences; a raw line break inside a
# single-quoted literal is a syntax error.
print(f"\n✅ Results saved to: {output_file}")
print(f"📊 Total rows: {len(df)}")

# Show errors if any (only the first five, to keep the output short).
if errors:
    print(f"\n⚠️ Errors encountered: {len(errors)}")
    for error in errors[:5]:
        print(f"  - {error['file']}: {error['error']}")

print("\n" + "="*60)
print("✅ INFERENCE COMPLETE!")
print("="*60)