# Task 3: Citation Span Extraction - Inference

**Model:** QA model trained with positions (BERT/RoBERTa/SciBERT)

**Task:** Extract the text span that each citation supports

**Metrics:** F1 Score + Exact Match

---

In [32]:
# Print the installed versions of the two Hugging Face libraries used below.
import transformers
import datasets

print(f"‚úÖ transformers: {transformers.__version__}")
print(f"‚úÖ datasets: {datasets.__version__}")

‚úÖ transformers: 4.57.1
‚úÖ datasets: 4.4.2


In [33]:
# Configuration: input locations (model, test set) and output targets
# (per-file predictions directory, evaluation summary JSON).
MODEL_PATH = '/kaggle/input/task3-bert-training-withpositions/models/task3_bert_with_positions_final'
TEST_DIR = '/kaggle/input/thesis-data-task3-with-positions-test-gold-500/test_gold_500'
OUTPUT_DIR = '/kaggle/working/predictions'
EVAL_OUTPUT = '/kaggle/working/evaluation_results.json'

# One print call, one line per path.
print(
    f"üìÇ Model: {MODEL_PATH}",
    f"üìÇ Test data: {TEST_DIR}",
    f"üìÇ Output: {OUTPUT_DIR}",
    sep="\n",
)

üìÇ Model: /kaggle/input/task3-bert-training-withpositions/models/task3_bert_with_positions_final
üìÇ Test data: /kaggle/input/thesis-data-task3-with-positions-test-gold-500/test_gold_500
üìÇ Output: /kaggle/working/predictions


In [34]:
# Load the QA model and its tokenizer from MODEL_PATH into a pipeline.
import torch
from transformers import pipeline

# Device index handed to the pipeline: 0 when CUDA is available, -1 otherwise.
if torch.cuda.is_available():
    device = 0
else:
    device = -1
print(f"Device: {'GPU' if device == 0 else 'CPU'}")

qa_pipeline = pipeline(
    task='question-answering',
    model=MODEL_PATH,
    tokenizer=MODEL_PATH,
    device=device,
)

print("‚úÖ Model loaded successfully")

Device set to use cuda:0


Device: GPU
‚úÖ Model loaded successfully


In [35]:
# Inference function
def extract_citation_span(text: str, citation_id: str):
    """Ask the QA pipeline which part of ``text`` a citation supports.

    Args:
        text: Context passage handed to the pipeline.
        citation_id: Identifier interpolated into the question.

    Returns:
        A dict with the keys ``span_text``, ``score``, ``start`` and ``end``
        copied from the pipeline result. If the pipeline call raises, the
        error is printed and a placeholder is returned instead: empty
        ``span_text``, ``score`` 0.0 and both offsets set to -1.
    """
    question = f"What does citation {citation_id} support?"

    try:
        answer = qa_pipeline(
            question=question,
            context=text,
            max_seq_len=512,
            handle_impossible_answer=False,
        )
        extracted = {
            'span_text': answer['answer'],
            'score': answer['score'],
            'start': answer['start'],
            'end': answer['end'],
        }
    except Exception as e:
        # Best-effort: one failing citation must not abort the whole run.
        print(f"‚ö†Ô∏è  Error: {e}")
        extracted = {
            'span_text': '',
            'score': 0.0,
            'start': -1,
            'end': -1,
        }

    return extracted


print("‚úÖ Inference function defined")

‚úÖ Inference function defined


In [None]:
# Run inference over every *.label file in TEST_DIR and write one prediction
# file (same name, same structure) per input file into OUTPUT_DIR.
import json
from pathlib import Path
from tqdm import tqdm

test_path = Path(TEST_DIR)
output_path = Path(OUTPUT_DIR)
output_path.mkdir(parents=True, exist_ok=True)

label_files = sorted(test_path.glob("*.label"))
print(f"üìä Found {len(label_files)} files")
print("=" * 60)

# 'successful' and 'failed' count citations only. Whole-file failures are
# tracked separately in 'failed_files' so the citation percentages printed
# below are computed over a consistent denominator.
stats = {
    'total_files': 0,
    'total_citations': 0,
    'successful': 0,
    'failed': 0,
    'failed_files': 0,
}

for label_file in tqdm(label_files):
    try:
        # Read file. Explicit encoding so decoding does not depend on the
        # platform locale (assumes the .label files are UTF-8 JSON - TODO confirm).
        with open(label_file, encoding='utf-8') as f:
            label_data = json.load(f)

        text = label_data.get('text', '')
        if not text:
            # No context to run the model on: this is a file-level failure.
            stats['failed_files'] += 1
            continue

        # Get citations
        citation_ids = list(label_data.get('correct_citation', {}).keys())

        # Extract spans - same format as the original .label file
        citation_spans = []
        file_successful = 0
        for citation_id in citation_ids:
            result = extract_citation_span(text, citation_id)

            citation_spans.append({
                'citation_id': citation_id,
                'span_text': result['span_text'],
                's_span': result['start'],
                'e_span': result['end']
            })

            # extract_citation_span reports score 0.0 when the model call failed.
            if result['score'] > 0:
                file_successful += 1

        # Save predictions - same structure as the .label file
        output_data = {
            'doc_id': label_data.get('doc_id', label_file.stem),
            'text': text,
            'correct_citation': label_data.get('correct_citation', {}),
            'citation_spans': citation_spans,  # Same field name as in .label
            'bib_entries': label_data.get('bib_entries', {}),  # Keep bib_entries unchanged
            'generator': 'qa_model_inference'  # Mark as a model prediction
        }

        output_file = output_path / label_file.name
        # ensure_ascii=False writes non-ASCII characters verbatim, so the
        # output encoding must be fixed explicitly.
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, ensure_ascii=False)

        # Commit the per-file counts only after the prediction file was
        # written, so an exception part-way through a file cannot leave
        # citation counts for a file that has no output.
        stats['total_files'] += 1
        stats['total_citations'] += len(citation_spans)
        stats['successful'] += file_successful
        stats['failed'] += len(citation_spans) - file_successful

    except Exception as e:
        print(f"\n‚ùå Error processing {label_file.name}: {e}")
        stats['failed_files'] += 1

print("\n" + "=" * 60)
print("üìä INFERENCE RESULTS")
print("=" * 60)
print(f"Files processed: {stats['total_files']}")
print(f"Files failed: {stats['failed_files']}")
print(f"Total citations: {stats['total_citations']}")
print(f"‚úÖ Successful: {stats['successful']} ({stats['successful']/max(stats['total_citations'],1)*100:.1f}%)")
print(f"‚ùå Failed: {stats['failed']} ({stats['failed']/max(stats['total_citations'],1)*100:.1f}%)")
print("=" * 60)

In [37]:
# Evaluation - Calculate F1 and Exact Match
import numpy as np

def calculate_f1_em(pred_start: int, pred_end: int,
                    true_start: int, true_end: int) -> tuple[float, int]:
    """Score one predicted span against its gold span.

    Spans are compared as half-open offset ranges ``[start, end)``; the
    overlap is the length of their intersection.

    Args:
        pred_start: Predicted start offset, or -1 when no prediction exists.
        pred_end: Predicted end offset, or -1 when no prediction exists.
        true_start: Gold start offset.
        true_end: Gold end offset.

    Returns:
        ``(f1, exact_match)`` where ``exact_match`` is 1 when both offsets
        are identical and 0 otherwise, and ``f1`` is the harmonic mean of
        overlap/pred_length and overlap/true_length (0.0 without overlap).
    """
    # Exact Match is decided on the raw offsets, before any clamping.
    exact_match = 1 if (pred_start == true_start and pred_end == true_end) else 0

    # -1 marks a missing prediction: it never earns F1 credit.
    if pred_start == -1 or pred_end == -1:
        return 0.0, exact_match

    # Identical spans are a perfect prediction. Without this shortcut a
    # zero-length span that matches the gold span exactly would report
    # EM = 1 but F1 = 0.0, because its overlap length is 0.
    if exact_match:
        return 1.0, exact_match

    # An inverted prediction is treated as an empty span at pred_start.
    if pred_end < pred_start:
        pred_end = pred_start

    # Calculate overlap
    overlap_start = max(pred_start, true_start)
    overlap_end = min(pred_end, true_end)
    overlap = max(0, overlap_end - overlap_start)

    if overlap == 0:
        return 0.0, exact_match

    pred_length = pred_end - pred_start
    true_length = true_end - true_start

    precision = overlap / pred_length if pred_length > 0 else 0
    recall = overlap / true_length if true_length > 0 else 0

    if precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)

    return f1, exact_match

print("‚úÖ Evaluation function defined")

‚úÖ Evaluation function defined


In [None]:
# Evaluate all predictions: compare every prediction file in OUTPUT_DIR with
# the gold file of the same name in TEST_DIR and aggregate F1 / Exact Match.
prediction_files = sorted(output_path.glob("*.label"))

all_f1_scores = []
all_exact_matches = []
file_results = []

for pred_file in tqdm(prediction_files, desc="Evaluating"):
    # Explicit encoding so decoding does not depend on the platform locale.
    with open(pred_file, encoding='utf-8') as f:
        data = json.load(f)

    # Read ground truth from the original test data
    gt_file = test_path / pred_file.name
    with open(gt_file, encoding='utf-8') as f:
        gt_data = json.load(f)

    # Files without gold spans cannot be scored.
    if 'citation_spans' not in gt_data:
        continue

    # Ground truth spans
    ground_truth = {
        span['citation_id']: span
        for span in gt_data['citation_spans']
    }

    # Predicted spans
    predictions = {
        span['citation_id']: span
        for span in data['citation_spans']
    }

    file_f1_scores = []
    file_exact_matches = []

    for citation_id, gt_span in ground_truth.items():
        true_start = gt_span.get('s_span', -1)
        true_end = gt_span.get('e_span', -1)

        # Check the gold span first: a citation without gold offsets is
        # excluded from the metrics whether or not a prediction exists.
        # (Checking for the prediction first would score it as 0 when the
        # prediction is missing but skip it when one is present.)
        if true_start == -1 or true_end == -1:
            continue

        # A labelled citation with no prediction counts as a miss.
        if citation_id not in predictions:
            file_f1_scores.append(0.0)
            file_exact_matches.append(0)
            continue

        pred = predictions[citation_id]
        pred_start = pred.get('s_span', -1)
        pred_end = pred.get('e_span', -1)

        f1, em = calculate_f1_em(pred_start, pred_end, true_start, true_end)

        file_f1_scores.append(f1)
        file_exact_matches.append(em)

    all_f1_scores.extend(file_f1_scores)
    all_exact_matches.extend(file_exact_matches)

    # float(...) stores plain Python floats rather than NumPy scalars in the
    # structure that is serialized to JSON below.
    file_results.append({
        'file': pred_file.name,
        'num_citations': len(file_f1_scores),
        'avg_f1': float(np.mean(file_f1_scores)) if file_f1_scores else 0,
        'avg_em': float(np.mean(file_exact_matches)) if file_exact_matches else 0
    })

# Overall metrics (micro-average over all scored citations)
overall_metrics = {
    'total_files': len(file_results),
    'total_citations': len(all_f1_scores),
    'f1_score': float(np.mean(all_f1_scores)) if all_f1_scores else 0,
    'exact_match': float(np.mean(all_exact_matches)) if all_exact_matches else 0,
    'file_results': file_results
}

print("\n" + "=" * 60)
print("üìä EVALUATION RESULTS")
print("=" * 60)
print(f"Files evaluated: {overall_metrics['total_files']}")
print(f"Total citations: {overall_metrics['total_citations']}")
print(f"F1 Score: {overall_metrics['f1_score']:.4f} ({overall_metrics['f1_score']*100:.2f}%)")
print(f"Exact Match: {overall_metrics['exact_match']:.4f} ({overall_metrics['exact_match']*100:.2f}%)")
print("=" * 60)

# Save evaluation results (ensure_ascii=False needs a fixed output encoding)
with open(EVAL_OUTPUT, 'w', encoding='utf-8') as f:
    json.dump(overall_metrics, f, indent=2, ensure_ascii=False)

print(f"\n‚úÖ Evaluation results saved to: {EVAL_OUTPUT}")

In [None]:
# Sample prediction: display the first prediction file (sorted by name).
sample_file = sorted(output_path.glob("*.label"))[0]
sample = json.loads(sample_file.read_text())

# Header: document id, the first 200 characters of the text, gold citations.
print(f"üìã Sample: {sample['doc_id']}")
print(f"\nText: {sample['text'][:200]}...")
print(f"\nCorrect Citations: {sample['correct_citation']}")

# One entry per predicted span; span_text is cut to 100 characters.
print(f"\nPredicted Citation Spans:")
for span in sample['citation_spans']:
    print(f"\n{span['citation_id']}:")
    print(f"  span_text: {span['span_text'][:100]}...")
    print(f"  s_span: {span['s_span']}")
    print(f"  e_span: {span['e_span']}")

# First 500 characters of the pretty-printed prediction file.
sample_json = json.dumps(sample, indent=2, ensure_ascii=False)
print(f"\n\nüìÑ Full structure (same as .label file):")
print(sample_json[:500] + "...")