# Task 3: BERT Inference (Kaggle)

Attach these Kaggle inputs before running:
- Dataset: `thesis-data-task3-test-gold-500`
- Notebook Output: `task3-bert-training` (latest version)

The notebook loads the trained model, runs QA-based span extraction across 500 test files, reports Exact Match/F1, shows sample predictions, and saves `task3_bert_predictions.csv`.

In [1]:
# ============================================================
# Cell 1: Setup
# ============================================================
import json
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from tqdm import tqdm
import pandas as pd

print("‚úÖ Libraries imported")

‚úÖ Libraries imported


In [2]:
# ============================================================
# Cell 2: Load Model
# ============================================================
# Path to saved model from training notebook output
model_path = '/kaggle/input/task3-bert-training/models/task3_bert_final'

print(f"üì¶ Loading model from: {model_path}")

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForQuestionAnswering.from_pretrained(model_path)

# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

print(f"‚úÖ Model loaded on {device}")
print(f"üìä Model parameters: {model.num_parameters():,}")

üì¶ Loading model from: /kaggle/input/task3-bert-training/models/task3_bert_final


2026-01-18 09:26:52.926353: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768728413.249060      18 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768728413.345280      18 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768728414.143810      18 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768728414.143860      18 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768728414.143863      18 computation_placer.cc:177] computation placer alr

‚úÖ Model loaded on cpu
üìä Model parameters: 108,893,186


In [3]:
# ============================================================
# Cell 3: Load Test Data
# ============================================================
test_path = Path('/kaggle/input/thesis-data-task3-test-gold-500/test_gold_500')
test_files = sorted(test_path.glob("*.in"))

print(f"üìÇ Test path: {test_path}")
print(f"üìä Found {len(test_files)} test files")

üìÇ Test path: /kaggle/input/thesis-data-task3-test-gold-500/test_gold_500
üìä Found 500 test files


In [4]:
# ============================================================
# Cell 4: Inference Function
# ============================================================
def extract_citation_span(text, citation, model, tokenizer, device, max_length=512):
    """
    Extract span for a given citation using Question Answering approach
    """
    # Create question
    question = f"What does citation {citation} support?"

    # Tokenize
    inputs = tokenizer(
        question,
        text,
        max_length=max_length,
        truncation='only_second',
        return_tensors='pt',
        padding=True
    )

    # Move to device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Get predictions
    with torch.no_grad():
        outputs = model(**inputs)

    # Get start and end positions
    start_logits = outputs.start_logits
    end_logits = outputs.end_logits

    start_idx = torch.argmax(start_logits, dim=1).item()
    end_idx = torch.argmax(end_logits, dim=1).item()

    # Decode answer
    if start_idx <= end_idx and start_idx > 0:
        answer_tokens = inputs['input_ids'][0][start_idx:end_idx+1]
        predicted_span = tokenizer.decode(answer_tokens, skip_special_tokens=True)
    else:
        predicted_span = ""

    # Get confidence scores
    start_prob = torch.softmax(start_logits, dim=1)[0][start_idx].item()
    end_prob = torch.softmax(end_logits, dim=1)[0][end_idx].item()
    confidence = (start_prob + end_prob) / 2

    return {
        'predicted_span': predicted_span,
        'start_idx': start_idx,
        'end_idx': end_idx,
        'confidence': confidence
    }

print("‚úÖ Inference function defined")

‚úÖ Inference function defined


In [5]:
# ============================================================
# Cell 5: Run Inference on Test Set
# ============================================================
results = []
errors = []

print("üöÄ Starting inference...")
print("=" * 60)

for test_file in tqdm(test_files, desc="Processing files"):
    try:
        # Load input file
        with open(test_file) as f:
            in_data = json.load(f)

        # Load label file (for comparison)
        label_file = test_file.with_suffix('.label')
        with open(label_file) as f:
            label_data = json.load(f)

        text = in_data['text']
        citation_spans = label_data.get('citation_spans', [])

        # Process each citation
        for span_info in citation_spans:
            citation_id = span_info['citation_id']
            gold_span = span_info['span_text']

            # Run inference
            prediction = extract_citation_span(
                text=text,
                citation=citation_id,
                model=model,
                tokenizer=tokenizer,
                device=device
            )

            results.append({
                'file': test_file.stem,
                'citation_id': citation_id,
                'gold_span': gold_span,
                'predicted_span': prediction['predicted_span'],
                'confidence': prediction['confidence'],
                'start_idx': prediction['start_idx'],
                'end_idx': prediction['end_idx']
            })

    except Exception as e:
        errors.append({
            'file': test_file.stem,
            'error': str(e)
        })

print(f"‚úÖ Inference complete!")
print(f"üìä Processed: {len(results)} predictions")
print(f"‚ùå Errors: {len(errors)}")

üöÄ Starting inference...


Processing files: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 500/500 [10:28<00:00,  1.26s/it]

‚úÖ Inference complete!
üìä Processed: 1272 predictions
‚ùå Errors: 0





In [6]:
# ============================================================
# Cell 6: Calculate Metrics
# ============================================================
def calculate_exact_match(gold, pred):
    """Exact match: predicted == gold (after normalization)"""
    return gold.strip().lower() == pred.strip().lower()


def calculate_f1(gold, pred):
    """F1 score based on token overlap"""
    gold_tokens = gold.strip().lower().split()
    pred_tokens = pred.strip().lower().split()

    if len(gold_tokens) == 0 or len(pred_tokens) == 0:
        return 0.0

    common = set(gold_tokens) & set(pred_tokens)

    if len(common) == 0:
        return 0.0

    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(gold_tokens)

    f1 = 2 * (precision * recall) / (precision + recall)
    return f1


# Calculate metrics
exact_matches = 0
total_f1 = 0

for result in results:
    if calculate_exact_match(result['gold_span'], result['predicted_span']):
        exact_matches += 1

    total_f1 += calculate_f1(result['gold_span'], result['predicted_span'])

exact_match_score = exact_matches / len(results) if results else 0
avg_f1_score = total_f1 / len(results) if results else 0

print("=" * 60)
print("üìä EVALUATION METRICS")
print("=" * 60)
print(f"Total predictions: {len(results)}")
print(f"Exact Match: {exact_match_score:.4f} ({exact_matches}/{len(results)})")
print(f"F1 Score: {avg_f1_score:.4f}")
print(f"Average Confidence: {sum(r['confidence'] for r in results)/len(results):.4f}")
print("=" * 60)

üìä EVALUATION METRICS
Total predictions: 1272
Exact Match: 0.0094 (12/1272)
F1 Score: 0.2596
Average Confidence: 0.6060


In [7]:
# ============================================================
# Cell 7: Show Examples (TEXT -> GOLD -> PRED)
# ============================================================
import json
import textwrap
from pathlib import Path

TEST_DIR = Path("/kaggle/input/thesis-data-task3-test-gold-500/test_gold_500")
N = 10
WIDTH = 120
WINDOW = 260  # chars before/after citation marker

def clean(s):
    return " ".join((s or "").split())

def wrap(s, indent="  "):
    s = clean(s)
    if not s:
        return indent + "(empty)"
    return textwrap.fill(s, width=WIDTH, initial_indent=indent, subsequent_indent=indent)

def context(text, marker):
    text = text or ""
    idx = text.find(marker)
    if idx == -1:
        return f"(marker {marker} not found)"
    start = max(0, idx - WINDOW)
    end = min(len(text), idx + len(marker) + WINDOW)
    snip = text[start:end].replace(marker, f"<<{marker}>>", 1)
    return snip

print("\n===== SAMPLE (TEXT -> GOLD -> PRED) =====\n")

for i, r in enumerate(results[:N], 1):
    doc_id = r["file"]
    cid = r["citation_id"]

    label = json.load(open(TEST_DIR / f"{doc_id}.label"))
    text = label.get("text", "")
    gold = next((x["span_text"] for x in label.get("citation_spans", []) if x["citation_id"] == cid), "")
    pred = r.get("predicted_span", "")
    conf = r.get("confidence", 0.0)

    print("=" * 110)
    print(f"[{i}] doc={doc_id}  citation={cid}  conf={conf:.4f}")

    print("\nTEXT:")
    print(wrap(context(text, cid), indent="  "))

    print("\nGOLD:")
    print(wrap(gold, indent="  "))

    print("\nPRED:")
    print(wrap(pred, indent="  "))

    print()



===== SAMPLE (TEXT -> GOLD -> PRED) =====

[1] doc=10050  citation=[CITATION_1]  conf=0.7373

TEXT:
  c. Studies of vertebrate hearts suggest a role for Kind2 in cardiac development and function, however these studies
  are limited by the embryonic lethality of Kind2 knock-out in the mouse model, and the lack of tissue specific Kind2
  silencing in the fish model <<[CITATION_1]>> [CITATION_2] . The aberrant phenotype caused by silencing orthologs of
  Kind2 in the cardiomyocytes of an invertebrate demonstrates that the protein's role in cardiac development has been
  evolutionarily conserved and also reiterates the validity of using Droso

GOLD:
  Studies of vertebrate hearts suggest a role for Kind2 in cardiac development and function, however these studies are
  limited by the embryonic lethality of Kind2 knock-out in the mouse model, and the lack of tissue specific Kind2
  silencing in the fish model .

PRED:
  . when drosophila cardiomyocytes fail to couple together to form a card

In [8]:
# ============================================================
# Cell 8: Save Results
# ============================================================
# Convert to DataFrame
df = pd.DataFrame(results)

# Save to CSV
output_file = 'task3_bert_predictions.csv'
df.to_csv(output_file, index=False)

print(f"‚úÖ Results saved to: {output_file}")
print(f"üìä Total rows: {len(df)}")

# Show errors if any
if errors:
    print(f"‚ö†Ô∏è Errors encountered: {len(errors)}")
    for error in errors[:5]:
        print(f"  - {error['file']}: {error['error']}")

print("" + "="*60)
print("‚úÖ INFERENCE COMPLETE!")
print("="*60)

‚úÖ Results saved to: task3_bert_predictions.csv
üìä Total rows: 1272
‚úÖ INFERENCE COMPLETE!
