# Task 3: Citation Span Extraction - BERT + Special Tokens Inference

**Model:** bert-base-uncased (Question Answering) + Citation Special Tokens

**Task:** Extract the text span that each citation supports

**Dataset:** thesis-data-task3-test-gold-500

---

In [None]:
# Setup
import json
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from tqdm import tqdm
import pandas as pd

print("‚úÖ Libraries imported")

In [None]:
# Load Model - WITH SPECIAL TOKENS
# Path to saved model from training notebook output
import os
from pathlib import Path

# Name of the final model directory written by the training notebook. It is preferred
# when several subdirectories exist, because os.listdir order is arbitrary and
# "the first one" would otherwise differ between runs.
EXPECTED_MODEL_NAME = 'task3_bert_special_tokens_final'


def _print_tree(root_path, max_depth=3, max_files=5):
    """Print the directory tree under `root_path`.

    Directories up to `max_depth` levels below `root_path` are shown, with at most
    `max_files` file names per directory.
    """
    for root, dirs, files in os.walk(root_path):
        level = root.replace(root_path, '').count(os.sep)
        indent = ' ' * 2 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 2 * (level + 1)
        for file in files[:max_files]:
            print(f"{subindent}{file}")
        if len(files) > max_files:
            print(f"{subindent}... and {len(files)-max_files} more files")
        # Prune in place rather than `break`. A break would stop the whole walk at
        # the first deep directory and silently skip every remaining sibling branch.
        if level >= max_depth:
            dirs[:] = []
        else:
            dirs.sort()  # deterministic listing order


def _pick_model_subdir(parent):
    """Return the model directory inside `parent`, or None if there is none.

    The expected model directory name is preferred. Otherwise the alphabetically
    first subdirectory is used.
    """
    if not os.path.isdir(parent):
        return None
    subdirs = sorted(d for d in os.listdir(parent) if os.path.isdir(os.path.join(parent, d)))
    print(f"Found subdirectories in {os.path.basename(parent)}/: {subdirs}")
    if not subdirs:
        return None
    chosen = EXPECTED_MODEL_NAME if EXPECTED_MODEL_NAME in subdirs else subdirs[0]
    return os.path.join(parent, chosen)


# First, let's explore the dataset structure
base_path = '/kaggle/input/task3-bert-training-specialtokens'
print(f"üìÇ Exploring {base_path}...")

if os.path.exists(base_path):
    _print_tree(base_path)
else:
    print(f"‚ùå Base path not found: {base_path}")

print("\n" + "="*60)

# Now find the model directory. working/models takes precedence over models/.
model_dir = None
for candidate in (os.path.join(base_path, 'working', 'models'),
                  os.path.join(base_path, 'models')):
    model_dir = _pick_model_subdir(candidate)
    if model_dir is not None:
        print(f"‚úÖ Using model directory: {model_dir}")
        break

if model_dir is None:
    raise FileNotFoundError(f"‚ùå Could not find model directory in {base_path}")

print(f"\nüì¶ Loading model from: {model_dir}")

# Load tokenizer (includes special tokens)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForQuestionAnswering.from_pretrained(model_dir)

print(f"‚úÖ Tokenizer vocab size: {len(tokenizer)}")

# Verify special tokens
test_text = "This research [CITATION_1] shows that [CITATION_2] improves performance."
test_tokens = tokenizer.tokenize(test_text)
print(f"\nüìã Test tokenization:")
print(f"Text: {test_text}")
print(f"Tokens: {test_tokens}")

# Each citation marker should come back as a single token. If it is split into
# word pieces, the tokenizer saved with the model did not register the special tokens.
split_markers = [m for m in ('[CITATION_1]', '[CITATION_2]') if len(tokenizer.tokenize(m)) != 1]
if split_markers:
    print(f"WARNING: citation markers are split into sub-tokens: {split_markers}")

# Token ids at or beyond the embedding table size would fail in the forward pass.
num_embeddings = model.get_input_embeddings().num_embeddings
if num_embeddings < len(tokenizer):
    print(f"WARNING: tokenizer has {len(tokenizer)} tokens but the model has only "
          f"{num_embeddings} input embeddings")

# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

print(f"\n‚úÖ Model loaded on {device}")
print(f"üìä Model parameters: {model.num_parameters():,}")

In [None]:
# Load Test Data
# The gold split holds one `<name>.in` JSON input per example. The matching
# `<name>.label` file is opened later, during inference.
dataset_root = Path('/kaggle/input/thesis-data-task3-test-gold-500')
test_path = dataset_root / 'test_gold_500'
test_files = sorted(test_path.glob("*.in"))

print(f"üìÇ Test path: {test_path}")
print(f"üìä Found {len(test_files)} test files")

In [None]:
# Inference Function
def _context_token_mask(encoding, tokenizer):
    """Return one bool per token: True where the token belongs to the context (second) sequence."""
    input_ids = encoding['input_ids'][0].tolist()
    if getattr(tokenizer, 'is_fast', False):
        # Fast tokenizers tag tokens with their sequence: None (special), 0 (question), 1 (context)
        return [seq_id == 1 for seq_id in encoding.sequence_ids(0)]
    if 'token_type_ids' in encoding:
        # Slow-tokenizer fallback: BERT gives the second segment token_type_id 1,
        # including its closing [SEP], which is excluded here.
        type_ids = encoding['token_type_ids'][0].tolist()
        return [t == 1 and tok != tokenizer.sep_token_id for t, tok in zip(type_ids, input_ids)]
    # Last resort: same rule the original used, i.e. any position except 0 ([CLS]).
    return [i > 0 for i in range(len(input_ids))]


def extract_citation_span(text, citation, model, tokenizer, device, max_length=512,
                          max_answer_length=None):
    """
    Extract the span of `text` that a given citation supports (extractive QA).

    The citation is turned into a question ("What does citation X support?") and
    encoded together with `text`. The QA model's start/end logits are then searched
    jointly for the best span lying entirely inside `text`.

    Args:
        text: Context passage. Only what fits into `max_length` tokens together with
            the question is seen. Anything past that is truncated away and cannot
            be predicted.
        citation: Citation identifier inserted into the question.
        model: QA model whose output exposes `start_logits` / `end_logits`.
        tokenizer: Tokenizer matching `model`. With a fast tokenizer the span is
            sliced verbatim from `text` (original casing/spacing). Otherwise it is
            decoded from the selected token ids.
        device: torch device the model lives on.
        max_length: Maximum token length of the (question, text) pair.
        max_answer_length: Optional cap on the span length in tokens (None = no cap).

    Returns:
        dict with:
          'predicted_span': the span text, or "" when the [CLS] (position 0)
              "no answer" score beats every valid span;
          'start_idx' / 'end_idx': token indices into the encoded pair (0 / 0 for
              a no-answer prediction);
          'confidence': mean of the start and end softmax probabilities at the
              chosen positions.
    """
    # Create question
    question = f"What does citation {citation} support?"

    # Character offsets let the span be returned exactly as written in `text`,
    # instead of re-decoding lowercased word pieces (which EM/F1 would penalise).
    use_offsets = bool(getattr(tokenizer, 'is_fast', False))

    encoding = tokenizer(
        question,
        text,
        max_length=max_length,
        truncation='only_second',
        return_tensors='pt',
        padding=True,
        return_offsets_mapping=use_offsets,
    )

    # offset_mapping is not a model input, so keep it out of the forward pass
    model_inputs = {k: v.to(device) for k, v in encoding.items() if k != 'offset_mapping'}

    with torch.no_grad():
        outputs = model(**model_inputs)

    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]
    seq_len = start_logits.shape[0]

    # Choose (start, end) jointly: both in the context and end >= start. Independent
    # argmaxes could point into the question or give end < start (an empty answer).
    context = torch.tensor(_context_token_mask(encoding, tokenizer),
                           dtype=torch.bool, device=start_logits.device)
    positions = torch.arange(seq_len, device=start_logits.device)
    end_minus_start = positions[None, :] - positions[:, None]  # [start, end] -> end - start
    valid = context[:, None] & context[None, :] & (end_minus_start >= 0)
    if max_answer_length is not None:
        valid &= end_minus_start < max_answer_length
    pair_scores = (start_logits[:, None] + end_logits[None, :]).masked_fill(~valid, float('-inf'))

    # Position 0 ([CLS]) is the "no answer" slot, mirroring the original start_idx > 0 rule
    null_score = (start_logits[0] + end_logits[0]).item()
    if bool(valid.any()):
        start_idx, end_idx = divmod(int(torch.argmax(pair_scores)), seq_len)
        best_score = pair_scores[start_idx, end_idx].item()
    else:
        start_idx = end_idx = 0
        best_score = float('-inf')

    if best_score > null_score:
        if use_offsets:
            offsets = encoding['offset_mapping'][0]
            char_start = int(offsets[start_idx][0])
            char_end = int(offsets[end_idx][1])
            predicted_span = text[char_start:char_end]
        else:
            answer_tokens = encoding['input_ids'][0][start_idx:end_idx + 1]
            predicted_span = tokenizer.decode(answer_tokens, skip_special_tokens=True)
    else:
        start_idx = end_idx = 0
        predicted_span = ""

    # Get confidence scores
    start_prob = torch.softmax(start_logits, dim=0)[start_idx].item()
    end_prob = torch.softmax(end_logits, dim=0)[end_idx].item()
    confidence = (start_prob + end_prob) / 2

    return {
        'predicted_span': predicted_span,
        'start_idx': start_idx,
        'end_idx': end_idx,
        'confidence': confidence
    }

print("‚úÖ Inference function defined")

In [None]:
# Run Inference on Test Set
# For every gold citation in every test file, predict its span and store the gold span
# next to it for scoring. A file that raises is recorded in `errors` and skipped, so
# its citations may be missing from `results` and therefore from the metrics below.
results = []
errors = []

print("üöÄ Starting inference...")
print("="*60)

for test_file in tqdm(test_files, desc="Processing files"):
    try:
        # Load input file (explicit UTF-8: do not depend on the platform default encoding)
        with open(test_file, encoding='utf-8') as f:
            in_data = json.load(f)

        # Load label file (for comparison)
        label_file = test_file.with_suffix('.label')
        with open(label_file, encoding='utf-8') as f:
            label_data = json.load(f)

        text = in_data['text']
        citation_spans = label_data.get('citation_spans', [])

        # Process each citation
        for span_info in citation_spans:
            citation_id = span_info['citation_id']
            gold_span = span_info['span_text']

            # Run inference
            prediction = extract_citation_span(
                text=text,
                citation=citation_id,
                model=model,
                tokenizer=tokenizer,
                device=device
            )

            results.append({
                'file': test_file.stem,
                'citation_id': citation_id,
                'gold_span': gold_span,
                'predicted_span': prediction['predicted_span'],
                'confidence': prediction['confidence'],
                'start_idx': prediction['start_idx'],
                'end_idx': prediction['end_idx']
            })

    except Exception as e:
        # Best-effort run: keep going, but record the exception type as well as its
        # message (a bare KeyError message is just the missing key).
        errors.append({
            'file': test_file.stem,
            'error': f"{type(e).__name__}: {e}",
        })

print(f"\n‚úÖ Inference complete!")
print(f"üìä Processed: {len(results)} predictions")
print(f"‚ùå Errors: {len(errors)}")

In [None]:
# Calculate Metrics
from collections import Counter


def _normalize_tokens(s):
    """Lowercase `s` and split it into whitespace-separated tokens."""
    return s.lower().split()


def calculate_exact_match(gold, pred):
    """Exact match: predicted == gold after lowercasing and collapsing whitespace."""
    return _normalize_tokens(gold) == _normalize_tokens(pred)


def calculate_f1(gold, pred):
    """F1 score based on token overlap (SQuAD-style).

    Tokens are the lowercased whitespace-separated words. The overlap is a multiset
    intersection, so a word shared k times contributes k matches. If either side has
    no tokens, the score is 1.0 when both are empty and 0.0 otherwise, which keeps it
    consistent with calculate_exact_match.
    """
    gold_tokens = _normalize_tokens(gold)
    pred_tokens = _normalize_tokens(pred)

    if not gold_tokens or not pred_tokens:
        return float(gold_tokens == pred_tokens)

    # A plain set intersection would count repeated words once and under-score
    # spans such as "the ... the ..." even when they match exactly.
    num_common = sum((Counter(gold_tokens) & Counter(pred_tokens)).values())
    if num_common == 0:
        return 0.0

    precision = num_common / len(pred_tokens)
    recall = num_common / len(gold_tokens)
    return 2 * (precision * recall) / (precision + recall)

# Calculate metrics over every (file, citation) prediction in `results`.
# Citations from files recorded in `errors` may be absent and are then not counted.
n_predictions = len(results)

exact_matches = sum(
    1 for r in results if calculate_exact_match(r['gold_span'], r['predicted_span'])
)
total_f1 = sum(calculate_f1(r['gold_span'], r['predicted_span']) for r in results)
total_confidence = sum(r['confidence'] for r in results)

# Guard every average: an empty `results` (e.g. no test files found) must not
# raise ZeroDivisionError.
exact_match_score = exact_matches / n_predictions if n_predictions else 0
avg_f1_score = total_f1 / n_predictions if n_predictions else 0
avg_confidence = total_confidence / n_predictions if n_predictions else 0

print("="*60)
print("üìä EVALUATION METRICS - SPECIAL TOKENS MODEL")
print("="*60)
print(f"Total predictions: {len(results)}")
print(f"Exact Match: {exact_match_score:.4f} ({exact_matches}/{len(results)})")
print(f"F1 Score: {avg_f1_score:.4f}")
print(f"Average Confidence: {avg_confidence:.4f}")
print("="*60)

In [None]:
# Show Examples - FULL TEXT (no truncation)
print("\nüìã SAMPLE PREDICTIONS:\n")

rule = '=' * 80

# Print the first 10 predictions untruncated, each with its exact-match verdict and F1
for example_no, row in enumerate(results[:10], start=1):
    gold_text = row['gold_span']
    pred_text = row['predicted_span']

    print(f"\n{rule}")
    print(f"Example {example_no}:")
    print(f"File: {row['file']}")
    print(f"Citation: {row['citation_id']}")
    print(f"\nüìå Gold Span (FULL):")
    print(f"{gold_text}")
    print(f"\nüîÆ Predicted Span (FULL):")
    print(f"{pred_text}")
    print(f"\nConfidence: {row['confidence']:.4f}")

    if calculate_exact_match(gold_text, pred_text):
        verdict = "‚úÖ EXACT MATCH"
    else:
        verdict = "‚ùå NO MATCH"
    f1_value = calculate_f1(gold_text, pred_text)
    print(f"Result: {verdict} (F1: {f1_value:.4f})")
    print(rule)

In [None]:
# Save Results
output_file = 'task3_bert_special_tokens_predictions.csv'

# One CSV row per (file, citation) prediction
df = pd.DataFrame(results)
df.to_csv(output_file, index=False)

print(f"\n‚úÖ Results saved to: {output_file}")
print(f"üìä Total rows: {len(df)}")

# Report up to five of the files that failed during inference
if len(errors) > 0:
    print(f"\n‚ö†Ô∏è Errors encountered: {len(errors)}")
    for err in errors[:5]:
        print(f"  - {err['file']}: {err['error']}")

banner = "="*60
print("\n" + banner)
print("‚úÖ INFERENCE COMPLETE - SPECIAL TOKENS MODEL!")
print(banner)