# PocketSage Model Evaluation - Reproducible Inference Notebook

This notebook provides a complete, reproducible evaluation pipeline for the PocketSage receipt categorization model.

## Contents
1. Setup and Configuration
2. Data Loading and Normalization
3. Model Evaluation
4. Confusion Matrix Analysis
5. Per-Category Metrics
6. Inference Examples
7. Results Summary


## 1. Setup and Configuration


In [None]:
import os
import sys
import json
import yaml
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datetime import datetime
from dotenv import load_dotenv
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report
)

# Add the project root and its source folders to the import path.
# The project root is taken to be the parent of the current working directory.
project_root = Path().resolve().parent
for import_dir in (
    project_root,
    project_root / 'api-endpoints',
    project_root / 'evaluation',
):
    sys.path.append(str(import_dir))

# Load environment variables
load_dotenv()

# Project-local modules; these resolve only after the sys.path additions above
from normalization import ReceiptNormalizer
from gemini_retraining import GeminiReceiptTrainer, load_receipts_from_firestore

# Shared styling for every figure produced in this notebook
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 10,
})

print("Setup complete!")


In [None]:
# Load the YAML configuration that drives the rest of the notebook
config_path = project_root / 'evaluation' / 'fine_tuning_config.yaml'

# Decode explicitly as UTF-8 so the parsed text does not depend on the
# platform's default locale encoding.
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# yaml.safe_load returns None for an empty document; fail here with a clear
# message instead of an opaque subscript error further down.
if not isinstance(config, dict):
    raise ValueError(
        f"Expected a YAML mapping in {config_path}, got {type(config).__name__}"
    )

print("Configuration loaded:")
print(f"  Model: {config['model']['base_model']}")
print(f"  Categories: {len(config['categories']['categories'])}")
print(f"  Normalization: {config['normalization']['enabled']}")


## 2. Data Loading and Normalization


In [None]:
# Initialize normalizer and trainer
normalizer = ReceiptNormalizer()
trainer = GeminiReceiptTrainer(model_name=config['model']['base_model'])

# Load receipt data
print("Loading receipt data...")
all_receipts = load_receipts_from_firestore(limit=100)
print(f"Loaded {len(all_receipts)} receipts")

# Prepare training data
training_examples = trainer.prepare_training_data(all_receipts)
print(f"Prepared {len(training_examples)} training examples")

# If insufficient data, use mock examples
if len(training_examples) < 2:
    print("Using mock examples for demonstration...")
    training_examples = [
        {
            'input': 'Vendor: Grocery Store\nTotal: 150.50\nItems:\n  - Milk: 50\n  - Bread: 30',
            'output': 'groceries',
            'metadata': {}
        },
        {
            'input': 'Vendor: Restaurant\nTotal: 500\nItems:\n  - Pizza: 300\n  - Drinks: 200',
            'output': 'dining',
            'metadata': {}
        },
        {
            'input': 'Vendor: Gas Station\nTotal: 2000\nItems:\n  - Fuel: 2000',
            'output': 'transportation',
            'metadata': {}
        },
        {
            'input': 'Vendor: Electricity Board\nTotal: 1500\nItems:\n  - Electricity Bill: 1500',
            'output': 'utilities',
            'metadata': {}
        },
        {
            'input': 'Vendor: Hotel\nTotal: 5000\nItems:\n  - Room: 5000',
            'output': 'travel',
            'metadata': {}
        },
    ]

# Split into training and test sets.
# At this point there are always at least two examples (real or mock), so the
# split index can be clamped to keep BOTH sets non-empty and disjoint. This
# avoids evaluating on examples that are also in the training set, and avoids
# an empty training set when train_split is very small.
train_split = config['training']['train_split']
n_examples = len(training_examples)
requested_idx = int(n_examples * train_split)
split_idx = min(max(requested_idx, 1), n_examples - 1)
if split_idx != requested_idx:
    print(
        f"Warning: train_split={train_split} gives an empty train or test set "
        f"for {n_examples} examples; using split index {split_idx} instead."
    )

train_set = training_examples[:split_idx]
test_set = training_examples[split_idx:]

# Invariant check: nothing lost, nothing shared between the two sets
assert len(train_set) >= 1 and len(test_set) >= 1
assert len(train_set) + len(test_set) == n_examples

print(f"\nTraining set: {len(train_set)} examples")
print(f"Test set: {len(test_set)} examples")


## 3. Model Evaluation


In [None]:
# Evaluate on the test split, passing the training split with use_few_shot=True
banner = "=" * 80
print("Starting model evaluation...")
print(banner)

metrics = trainer.evaluate_model(
    training_data=train_set, test_data=test_set, use_few_shot=True
)

print("\nEvaluation complete!")

# Headline numbers, printed both as a fraction and as a percentage
overall = metrics['overall']
print("\n" + banner)
print("OVERALL METRICS")
print(banner)
for label, key in (
    ("Accuracy:  ", 'accuracy'),
    ("Precision: ", 'precision'),
    ("Recall:    ", 'recall'),
    ("F1 Score:  ", 'f1_score'),
):
    score = overall[key]
    print(f"{label}{score:.4f} ({score*100:.2f}%)")

# Check the overall F1 against the configured macro-F1 target
target_f1 = config['evaluation']['target_metrics']['macro_f1_score']
print(f"\nTarget F1 Score: {target_f1:.4f}")
if overall['f1_score'] >= target_f1:
    verdict = '✓ YES'
else:
    verdict = '✗ NO'
print(f"Target Met: {verdict}")


## 4. Confusion Matrix Analysis


In [None]:
# Generate confusion matrix visualization
cm = np.array(metrics['confusion_matrix'])
categories = trainer.categories

# Row-normalized confusion matrix (each row divided by its own total).
# Rows whose total is zero are left as all zeros rather than dividing by zero,
# which would emit a RuntimeWarning and produce NaNs.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm.astype('float'),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

# Category names truncated to 10 characters so the tick labels stay readable;
# computed once and shared by both heatmaps.
tick_labels = [cat[:10] for cat in categories]

# Create visualization: normalized view on the left, raw counts on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

# Normalized heatmap
sns.heatmap(
    cm_normalized,
    annot=True,
    fmt='.2%',
    cmap='Blues',
    xticklabels=tick_labels,
    yticklabels=tick_labels,
    ax=ax1,
    cbar_kws={'label': 'Percentage'}
)
ax1.set_xlabel('Predicted Category', fontsize=12, fontweight='bold')
ax1.set_ylabel('True Category', fontsize=12, fontweight='bold')
ax1.set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold', pad=20)

# Raw counts heatmap
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=tick_labels,
    yticklabels=tick_labels,
    ax=ax2,
    cbar_kws={'label': 'Count'}
)
ax2.set_xlabel('Predicted Category', fontsize=12, fontweight='bold')
ax2.set_ylabel('True Category', fontsize=12, fontweight='bold')
ax2.set_title('Confusion Matrix (Raw Counts)', fontsize=14, fontweight='bold', pad=20)

plt.tight_layout()
plt.show()


## 5. Per-Category Metrics


In [None]:
# Tabulate precision / recall / F1 for every category, one row per category
per_category = metrics['per_category']
metric_rows = []
for cat_name, cat_scores in per_category.items():
    metric_rows.append({
        'Category': cat_name,
        'Precision': cat_scores['precision'],
        'Recall': cat_scores['recall'],
        'F1 Score': cat_scores['f1_score'],
    })
df_metrics = pd.DataFrame(metric_rows)

print("Per-Category Metrics:")
print("=" * 80)
print(df_metrics.to_string(index=False))

# Flag the categories whose F1 reaches the configured per-category minimum
min_f1 = config['evaluation']['target_metrics']['per_category_f1_min']
df_metrics['Meets Target'] = df_metrics['F1 Score'] >= min_f1
n_meeting = df_metrics['Meets Target'].sum()
print(f"\nCategories meeting F1 ≥ {min_f1}: {n_meeting}/{len(df_metrics)}")

# Grouped bar chart: three bars (precision, recall, F1) per category
categories_list = df_metrics['Category'].values
precision_vals = df_metrics['Precision'].values
recall_vals = df_metrics['Recall'].values
f1_vals = df_metrics['F1 Score'].values

x = np.arange(len(categories_list))
width = 0.25

fig, ax = plt.subplots(figsize=(14, 8))

# (x positions, bar heights, legend label, fill colour) for each metric
bar_specs = (
    (x - width, precision_vals, 'Precision', '#3498db'),
    (x, recall_vals, 'Recall', '#2ecc71'),
    (x + width, f1_vals, 'F1 Score', '#e74c3c'),
)
bar_groups = [
    ax.bar(positions, heights, width, label=legend_label,
           color=fill, alpha=0.8, edgecolor='black')
    for positions, heights, legend_label, fill in bar_specs
]
bars1, bars2, bars3 = bar_groups

# Write the value above each bar; bars at or below 0.01 are left unlabeled
for bar in (b for group in bar_groups for b in group):
    height = bar.get_height()
    if not (height > 0.01):
        continue
    ax.text(
        bar.get_x() + bar.get_width()/2.,
        height,
        f'{height:.2f}',
        ha='center', va='bottom', fontsize=9,
    )

# Dashed horizontal line at the per-category F1 target
target = config['evaluation']['target_metrics']['per_category_f1_min']
ax.axhline(y=target, color='red', linestyle='--', linewidth=2,
           label=f'Target F1: {target}', alpha=0.7)

axis_label_font = {'fontsize': 12, 'fontweight': 'bold'}
ax.set_xlabel('Category', **axis_label_font)
ax.set_ylabel('Score', **axis_label_font)
ax.set_title('Per-Category Performance Metrics', fontsize=14, fontweight='bold', pad=20)
ax.set_xticks(x)
# Category names are truncated to 12 characters for the tick labels
ax.set_xticklabels([name[:12] for name in categories_list], rotation=45, ha='right')
ax.set_ylim([0, 1.1])
ax.legend()
ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()


## 6. Inference Examples


In [None]:
# Run the few-shot predictor on a few hand-written receipts
print("Testing inference on sample receipts...")
print("=" * 80)


def _make_receipt(vendor, total, items):
    """Build a receipt dict in the nested parsedData -> raw layout.

    `items` is an iterable of (name, price) pairs; each pair becomes a
    {'name': ..., 'price': ...} dict in the receipt's item list.
    """
    return {
        'parsedData': {
            'raw': {
                'vendor': vendor,
                'total': total,
                'items': [
                    {'name': item_name, 'price': item_price}
                    for item_name, item_price in items
                ],
            }
        }
    }


sample_receipts = [
    _make_receipt('DMART', 450.50, [
        ('MILK 2%', 50),
        ('BRD LOAF', 30),
        ('EGGS', 70.50),
        ('RICE 5KG', 300),
    ]),
    _make_receipt('SWIGGY', 350, [
        ('PIZZA', 200),
        ('COKE', 50),
        ('DELIVERY', 100),
    ]),
    _make_receipt('PETROL PUMP', 2000, [
        ('PETROL', 2000),
    ]),
]

# Predict first, then report vendor, total and the predicted category
for example_no, receipt in enumerate(sample_receipts, start=1):
    predicted = trainer.train_with_few_shot(train_set, receipt)
    raw_fields = receipt['parsedData']['raw']
    print(f"\nExample {example_no}:")
    print(f"  Vendor: {raw_fields['vendor']}")
    print(f"  Total: ₹{raw_fields['total']}")
    print(f"  Predicted Category: {predicted}")


## 7. Results Summary


In [None]:
# Final report: run metadata, overall scores, target checks, then persist to disk
banner = "=" * 80
print(banner)
print("EVALUATION SUMMARY")
print(banner)

# Run metadata as reported in the metrics dict
for label, key in (
    ("\nModel", 'model_name'),
    ("Timestamp", 'timestamp'),
    ("Test Set Size", 'test_size'),
    ("Training Examples", 'training_size'),
):
    print(f"{label}: {metrics[key]}")

print("\nOverall Performance:")
for label, key in (
    ("Accuracy:  ", 'accuracy'),
    ("Precision: ", 'precision'),
    ("Recall:    ", 'recall'),
    ("F1 Score:  ", 'f1_score'),
):
    score = overall[key]
    print(f"  {label}{score:.4f} ({score*100:.2f}%)")

# Overall F1 against the configured macro-F1 target
target_f1 = config['evaluation']['target_metrics']['macro_f1_score']
target_met = overall['f1_score'] >= target_f1
print(f"\nTarget F1 Score: {target_f1:.4f}")
if target_met:
    print("Target Met: ✓ YES")
else:
    print("Target Met: ✗ NO")

# Number of categories whose F1 reaches the configured per-category minimum
min_f1 = config['evaluation']['target_metrics']['per_category_f1_min']
categories_meeting_target = len([
    name for name, scores in per_category.items()
    if scores['f1_score'] >= min_f1
])
total_categories = len(per_category)
print(f"\nCategories meeting F1 ≥ {min_f1}: {categories_meeting_target}/{total_categories}")

print("\n" + banner)
print("Evaluation complete!")
print(banner)

# Save the metrics dict as JSON under evaluation/results, in a timestamped file
output_dir = project_root / 'evaluation' / 'results'
output_dir.mkdir(parents=True, exist_ok=True)

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
results_file = output_dir / f"pocketsage_evaluation_{timestamp}.json"

with results_file.open('w') as f:
    json.dump(metrics, f, indent=2)

print(f"\nResults saved to: {results_file}")
