In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import re
from pathlib import Path

# Load the per-model field recall results and the notebook's info file.
# The encoding is pinned to UTF-8 (the JSON interchange encoding) so the cell
# reads the files identically regardless of the platform's locale default.
with open('../field_recall_results.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

with open('../info.json', 'r', encoding='utf-8') as f:
    info_data = json.load(f)

# Nested mapping iterated later as model -> {field -> score}.
# A missing 'field_recall_by_model' key raises KeyError on purpose: nothing
# downstream can run without it.
field_recall_data = data['field_recall_by_model']
# A missing 'excluded_fields' key simply means that no field is excluded.
excluded_fields = info_data.get('excluded_fields', [])

def strip_punctuation(text):
    """Return ``text`` reduced to its ASCII letters and digits.

    Every other character is removed: punctuation (underscores included),
    whitespace and non-ASCII characters. Letter case is left untouched, so
    comparisons made on the result remain case-sensitive.
    """
    non_alphanumeric = re.compile(r'[^a-zA-Z0-9]')
    return non_alphanumeric.sub('', text)

# Normalised spelling of every excluded field; later cells test candidate
# field names against this list after normalising them the same way.
excluded_fields_clean = list(map(strip_punctuation, excluded_fields))

# Short report of what was loaded.
load_report = (
    f"Loaded data for {len(field_recall_data)} models",
    f"Models: {list(field_recall_data.keys())}",
    f"Excluded {len(excluded_fields)} fields",
)
print("\n".join(load_report))

In [None]:
# Flatten the nested model -> field -> score mapping into one record per
# (model, field) pair, dropping every field whose normalised name appears in
# the exclusion list.
heatmap_data = [
    {'model': model, 'field': field, 'score': score}
    for model, fields in field_recall_data.items()
    for field, score in fields.items()
    if strip_punctuation(field) not in excluded_fields_clean
]

# Reshape the records into a field x model matrix of scores.
df = pd.DataFrame(heatmap_data)
heatmap_data_pivot = df.pivot(index='field', columns='model', values='score')

# Order the rows by their mean score across models, highest first.
field_avg_scores = heatmap_data_pivot.mean(axis=1)
fields_best_first = field_avg_scores.sort_values(ascending=False).index
heatmap_data_sorted = heatmap_data_pivot.loc[fields_best_first]

n_fields, n_models = heatmap_data_sorted.shape
print(f"Overall heatmap shape: {heatmap_data_sorted.shape}")
print(f"Fields included: {n_fields}")
print(f"Models included: {n_models}")

# Draw the field x model recall heatmap for all PMIDs.
# Everything is set through the explicit Axes objects (`ax`, `ax2`) rather than
# the pyplot "current axes" state, so each call visibly targets one axis.
plt.style.use('default')
fig, ax = plt.subplots(figsize=(14, 16))

# Orange-to-green colormap: the list runs from low-score to high-score colours
# and is blended into `n_bins` steps.
colors = ['#CC4125', '#E55100', '#FF6F00', '#FF8F00', '#FFB300', '#FFCC02',
          '#C5E1A5', '#A5D6A7', '#81C784', '#66BB6A', '#4CAF50', '#388E3C']
n_bins = 256
cmap = sns.blend_palette(colors, n_colors=n_bins, as_cmap=True)

# vmin/vmax pin the colour scale to [0, 1] instead of the observed data range.
ax = sns.heatmap(
    heatmap_data_sorted,
    annot=True,
    fmt='.2f',
    cmap=cmap,
    vmin=0,
    vmax=1,
    cbar_kws={
        'label': 'Field Recall Score (All PMIDs)',
        'shrink': 0.8,
        'aspect': 30,
        'pad': 0.02,
    },
    annot_kws={'size': 9, 'weight': 'bold', 'color': 'white'},
    linewidths=0.5,
    linecolor='white',
    square=False,
    xticklabels=True,
    yticklabels=True,
    ax=ax,
)

# Title and axis labels
ax.set_title(
    'Field Recall Scores by Model (All PMIDs)\n(Ordered by Average Score: High to Low)\n(Orange = Low Score, Green = High Score)',
    fontsize=18, fontweight='bold', pad=30, color='#2C3E50')
ax.set_xlabel('Models', fontsize=14, fontweight='bold', color='#34495E')
ax.set_ylabel('Fields', fontsize=14, fontweight='bold', color='#34495E')

# Tick label styling
plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=11, color='#2C3E50')
plt.setp(ax.get_yticklabels(), rotation=0, fontsize=10, color='#2C3E50')

# Repeat the model names along the top edge using a twin x-axis that shares
# the limits and tick positions of the heatmap axis.
ax2 = ax.twiny()
ax2.set_xlim(ax.get_xlim())
ax2.set_xticks(ax.get_xticks())
ax2.set_xticklabels(heatmap_data_sorted.columns, rotation=45, ha='left',
                    fontsize=11, color='#2C3E50', weight='bold')
# `labelpad` is the spacing argument of set_xlabel; `pad` is not a Text
# property and makes matplotlib raise before the figure is shown.
ax2.set_xlabel('Models', fontsize=14, fontweight='bold', color='#34495E', labelpad=15)

# Remove spines for a cleaner look
for axis in (ax, ax2):
    for spine in axis.spines.values():
        spine.set_visible(False)

# Subtle off-white background
fig.patch.set_facecolor('#FAFAFA')
ax.set_facecolor('#FAFAFA')

fig.tight_layout()
plt.show()

In [None]:
# Overall summary statistics
print("=== FIELD RECALL SUMMARY (ALL PMIDs) ===\n")

# Mean score per model across the plotted fields, best model first.
model_avg_scores = heatmap_data_sorted.mean(axis=0).sort_values(ascending=False)
print("Average scores by model:")
for model, score in model_avg_scores.items():
    print(f"  {model}: {score:.3f}")

print(f"\nBest performing model: {model_avg_scores.index[0]} ({model_avg_scores.iloc[0]:.3f})")
print(f"Worst performing model: {model_avg_scores.index[-1]} ({model_avg_scores.iloc[-1]:.3f})")

# Field difficulty analysis. head()/tail() are only "easiest"/"hardest" because
# the rows of heatmap_data_sorted are expected to already be ordered by
# descending mean score by the cell that built the matrix.
field_avg_scores = heatmap_data_sorted.mean(axis=1)
print("\nEasiest fields (top 5):")
for field, score in field_avg_scores.head().items():
    print(f"  {field}: {score:.3f}")

print("\nHardest fields (bottom 5):")
for field, score in field_avg_scores.tail().items():
    print(f"  {field}: {score:.3f}")

# Zero score fields.
# Flatten the matrix into a (field, model)-indexed Series and keep the cells
# that are exactly 0. Filtering explicitly (NaN == 0 is False) avoids relying
# on DataFrame.stack() to drop NaN cells, which is deprecated legacy behaviour
# that the newer stack implementation no longer provides.
all_cells = heatmap_data_sorted.T.unstack()
zero_scores = all_cells[all_cells == 0]
print(f"\nFields with zero scores: {len(zero_scores)} instances")
if len(zero_scores) > 0:
    print("Zero score fields:")
    for (field, model), _ in zero_scores.items():
        print(f"  {model}: {field}")

In [None]:
# Load SCR PMIDs from experiment.json and collect the raw per-field
# comparisons that belong to those PMIDs.
with open('../../../experiment.json', 'r', encoding='utf-8') as f:
    experiment_data = json.load(f)

scr_pmids = set(experiment_data.get('scr_pmids', []))
print(f"Found {len(scr_pmids)} SCR PMIDs to filter by")

# Membership is tested on the string form of each PMID so that a PMID stored
# as a number in one file and as text in the other still matches.
# NOTE(review): the actual PMID types in both files are not visible from this
# notebook; this is defensive and changes nothing when the types already agree.
scr_pmid_keys = {str(scr_pmid) for scr_pmid in scr_pmids}

# Load field results to get raw comparison data.
# Sorted so the processing order (and the row order of the resulting frame)
# does not depend on the file system's listing order.
field_results_dir = Path('../field_results')
field_files = sorted(field_results_dir.glob('*.json'))

# Filter data to only include SCR PMIDs
scr_filtered_data = []
total_comparisons = 0
scr_comparisons = 0

for field_file in field_files:
    # JSON files in the results directory that are not per-field results
    if field_file.stem in ['README', 'visualisation']:
        continue

    # Best effort: an unreadable or malformed file is reported and skipped.
    # Only I/O and decoding/parsing errors are caught (JSONDecodeError and
    # UnicodeDecodeError are both ValueError subclasses), so interrupts and
    # programming errors are no longer silently swallowed.
    try:
        with open(field_file, 'r', encoding='utf-8') as f:
            field_data = json.load(f)
    except (OSError, ValueError) as exc:
        print(f"Could not read {field_file}: {exc}")
        continue

    # Fall back to a dotted path derived from the file name when the file does
    # not state its own field path.
    field_path = field_data.get('field_path', field_file.stem.replace('_', '.'))

    # The exclusion test depends only on the field, so compute it once per file
    # instead of once per comparison.
    field_is_excluded = strip_punctuation(field_path) in excluded_fields_clean

    # Filter comparisons to only SCR PMIDs
    for comparison in field_data.get('comparisons', []):
        total_comparisons += 1
        pmid = comparison.get('pmid', '')

        if str(pmid) not in scr_pmid_keys:
            continue

        # Counted before the exclusion check, so this total includes
        # comparisons on excluded fields.
        scr_comparisons += 1
        if field_is_excluded:
            continue

        ground_truth = comparison.get('ground_truth')
        model_output = comparison.get('model_output')
        scr_filtered_data.append({
            'model': comparison['model_name'],
            'field': field_path,
            'pmid': pmid,
            'ground_truth': comparison.get('ground_truth', ''),
            'model_output': comparison.get('model_output', ''),
            # Exact equality of the raw values; two missing values compare equal.
            'match': ground_truth == model_output
        })

print(f"Total comparisons: {total_comparisons}")
print(f"SCR PMID comparisons: {scr_comparisons}")
print(f"SCR comparisons after filtering excluded fields: {len(scr_filtered_data)}")

# Convert the filtered comparisons to a DataFrame (one row per comparison)
scr_df = pd.DataFrame(scr_filtered_data)

if len(scr_df) > 0:
    # Recall per (model, field) on SCR PMIDs = fraction of comparisons whose
    # 'match' flag is True. A single groupby replaces a models x fields loop
    # that re-scanned the whole frame with a boolean mask for every pair;
    # pairs with no comparisons simply produce no row, as before.
    scr_recall_df = (
        scr_df.groupby(['model', 'field'], as_index=False)['match']
        .mean()
        .rename(columns={'match': 'score'})
    )
    scr_heatmap_data = scr_recall_df.pivot(index='field', columns='model', values='score')

    # Sort fields by their average score across models, highest first
    scr_field_avg_scores = scr_heatmap_data.mean(axis=1)
    scr_heatmap_data_sorted = scr_heatmap_data.loc[scr_field_avg_scores.sort_values(ascending=False).index]

    print(f"SCR-filtered heatmap shape: {scr_heatmap_data_sorted.shape}")
else:
    # Leave well-defined placeholders so the following cells can detect the
    # "no data" case instead of failing on an undefined name.
    print("No SCR PMID data found after filtering!")
    scr_df = pd.DataFrame()
    scr_heatmap_data_sorted = None

In [None]:
# Draw the field x model recall heatmap restricted to SCR PMIDs.
# Skipped with a message when the filtering cell produced no data.
if len(scr_df) > 0 and scr_heatmap_data_sorted is not None:
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(14, 16))

    # Same orange-to-green colormap as the all-PMID heatmap: the list runs from
    # low-score to high-score colours and is blended into `n_bins` steps.
    colors = ['#CC4125', '#E55100', '#FF6F00', '#FF8F00', '#FFB300', '#FFCC02',
              '#C5E1A5', '#A5D6A7', '#81C784', '#66BB6A', '#4CAF50', '#388E3C']
    n_bins = 256
    cmap = sns.blend_palette(colors, n_colors=n_bins, as_cmap=True)

    # vmin/vmax pin the colour scale to [0, 1] instead of the observed range.
    ax = sns.heatmap(
        scr_heatmap_data_sorted,
        annot=True,
        fmt='.2f',
        cmap=cmap,
        vmin=0,
        vmax=1,
        cbar_kws={
            'label': 'Field Recall Score (SCR PMIDs Only)',
            'shrink': 0.8,
            'aspect': 30,
            'pad': 0.02,
        },
        annot_kws={'size': 9, 'weight': 'bold', 'color': 'white'},
        linewidths=0.5,
        linecolor='white',
        square=False,
        xticklabels=True,
        yticklabels=True,
        ax=ax,
    )

    # Title and axis labels (set on the Axes object, not via pyplot state)
    ax.set_title(
        f'Field Recall Scores by Model (SCR PMIDs Only)\n({len(scr_pmids)} SCR Papers, {len(scr_df)} Comparisons)\n(Orange = Low Score, Green = High Score)',
        fontsize=18, fontweight='bold', pad=30, color='#2C3E50')
    ax.set_xlabel('Models', fontsize=14, fontweight='bold', color='#34495E')
    ax.set_ylabel('Fields', fontsize=14, fontweight='bold', color='#34495E')

    # Tick label styling
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=11, color='#2C3E50')
    plt.setp(ax.get_yticklabels(), rotation=0, fontsize=10, color='#2C3E50')

    # Repeat the model names along the top edge using a twin x-axis that
    # shares the limits and tick positions of the heatmap axis.
    ax2 = ax.twiny()
    ax2.set_xlim(ax.get_xlim())
    ax2.set_xticks(ax.get_xticks())
    ax2.set_xticklabels(scr_heatmap_data_sorted.columns, rotation=45, ha='left',
                        fontsize=11, color='#2C3E50', weight='bold')
    # `labelpad` is the spacing argument of set_xlabel; `pad` is not a Text
    # property and makes matplotlib raise before the figure is shown.
    ax2.set_xlabel('Models', fontsize=14, fontweight='bold', color='#34495E', labelpad=15)

    # Remove spines for a cleaner look
    for axis in (ax, ax2):
        for spine in axis.spines.values():
            spine.set_visible(False)

    # Subtle off-white background
    fig.patch.set_facecolor('#FAFAFA')
    ax.set_facecolor('#FAFAFA')

    fig.tight_layout()
    plt.show()
else:
    print("Cannot create SCR-filtered visualization - no data available")

In [None]:
# SCR-filtered summary statistics, followed by a per-model comparison with the
# all-PMID results. Skipped with a message when there is no SCR data.
if len(scr_df) > 0 and scr_heatmap_data_sorted is not None:
    print("=== SCR-FILTERED FIELD RECALL SUMMARY ===\n")

    # Mean score per model across fields on SCR PMIDs, best model first
    scr_model_avg_scores = scr_heatmap_data_sorted.mean(axis=0).sort_values(ascending=False)
    print("Average scores by model (SCR PMIDs only):")
    for model, score in scr_model_avg_scores.items():
        print(f"  {model}: {score:.3f}")

    print(f"\nBest performing model on SCR PMIDs: {scr_model_avg_scores.index[0]} ({scr_model_avg_scores.iloc[0]:.3f})")
    print(f"Worst performing model on SCR PMIDs: {scr_model_avg_scores.index[-1]} ({scr_model_avg_scores.iloc[-1]:.3f})")

    # Field difficulty analysis for SCR PMIDs. head()/tail() are only
    # "easiest"/"hardest" because the rows are expected to already be ordered
    # by descending mean score by the cell that built the matrix.
    scr_field_avg_scores = scr_heatmap_data_sorted.mean(axis=1)
    print("\nEasiest fields on SCR PMIDs (top 5):")
    for field, score in scr_field_avg_scores.head().items():
        print(f"  {field}: {score:.3f}")

    print("\nHardest fields on SCR PMIDs (bottom 5):")
    for field, score in scr_field_avg_scores.tail().items():
        print(f"  {field}: {score:.3f}")

    # Compare with overall performance
    print("\n=== COMPARISON: ALL vs SCR PMIDs ===")
    print("Model performance comparison (All PMIDs vs SCR PMIDs only):")
    overall_model_scores = heatmap_data_sorted.mean(axis=0)

    for model in overall_model_scores.index:
        overall_score = overall_model_scores[model]
        scr_score = scr_model_avg_scores.get(model)
        # A model with no SCR comparisons has no SCR score. Reporting it as
        # 0.000 would look like a collapse in performance, so say so instead.
        if scr_score is None or pd.isna(scr_score):
            print(f"  {model}: {overall_score:.3f} → n/a (no SCR comparisons)")
            continue
        difference = scr_score - overall_score
        direction = "↑" if difference > 0 else "↓" if difference < 0 else "="
        print(f"  {model}: {overall_score:.3f} → {scr_score:.3f} ({difference:+.3f} {direction})")

    # Zero score fields on SCR PMIDs.
    # Flatten the matrix into a (field, model)-indexed Series and keep the
    # cells that are exactly 0. Filtering explicitly (NaN == 0 is False) avoids
    # relying on DataFrame.stack() to drop NaN cells, which is deprecated
    # legacy behaviour that the newer stack implementation no longer provides.
    scr_all_cells = scr_heatmap_data_sorted.T.unstack()
    scr_zero_scores = scr_all_cells[scr_all_cells == 0]
    print(f"\nFields with zero scores on SCR PMIDs: {len(scr_zero_scores)} instances")
    if len(scr_zero_scores) > 0:
        print("Zero score fields (SCR PMIDs):")
        for (field, model), _ in scr_zero_scores.items():
            print(f"  {model}: {field}")
else:
    print("No SCR PMID data available for summary statistics")