# Evaluating the Multi-Modal Neural Network

This notebook shows how to load a trained model from a checkpoint, evaluate it on the validation dataset, and visualize the results.

## Import Libraries

In [None]:
import sys
sys.path.append('..')

import torch
import yaml
from pathlib import Path
from src.training.trainer import Trainer
from src.utils.config import load_config

# Quick environment report: confirms the imports worked and shows the
# installed PyTorch version and whether a CUDA device is usable.
cuda_available = torch.cuda.is_available()
print("Libraries imported successfully")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_available}")

## Load Model

In [None]:
# Configuration file used both by load_config and by the Trainer below
config_path = '../configs/default.yaml'
config = load_config(config_path)

# Checkpoint to evaluate (update this to your actual checkpoint)
checkpoint_path = '../outputs/checkpoints/best.pt'

# Guard first: without a checkpoint file there is nothing to restore, so only
# a hint is printed. In that case this cell does not assign `trainer` or
# `model`, and the evaluation cell (which calls `model.to(device)`) cannot run.
checkpoint_missing = not Path(checkpoint_path).exists()
if checkpoint_missing:
    print(f"Checkpoint not found at: {checkpoint_path}")
    print("Please train a model first or update the checkpoint path")
else:
    # Restore through the Trainer, then take its model and switch it to eval mode.
    trainer = Trainer(config_path=config_path, resume_from=checkpoint_path)
    model = trainer.model
    model.eval()
    print("Model loaded from checkpoint and set to evaluation mode")

## Load Evaluation Data

In [None]:
# Build the validation dataset and its loader from the loaded config
from src.data.dataset import create_dataset_from_config, create_dataloader

# Only the second returned dataset is used for evaluation; the first
# (presumably the training split - confirm against create_dataset_from_config)
# is discarded.
_, val_dataset = create_dataset_from_config(config)

# Batch size is read from config['data']['batch_size'], defaulting to 32.
eval_batch_size = config.get('data', {}).get('batch_size', 32)

val_loader = create_dataloader(
    val_dataset,
    batch_size=eval_batch_size,
    shuffle=False,  # keep sample order fixed during evaluation
    num_workers=0,  # Set to 0 for notebooks
)
print(f"Evaluation dataset loaded with {len(val_dataset)} samples")

## Run Evaluation

In [None]:
# Run evaluation: collect the predicted class and the ground-truth label for
# every sample produced by `val_loader`.
all_preds = []
all_labels = []
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _first_present(batch, *keys):
    """Return the value of the first key in `keys` that `batch` holds with a non-None value.

    Alternate key names ('image'/'images', 'label'/'labels') must be resolved
    with an explicit `is not None` test. Using `a or b` on tensors calls
    `bool(tensor)`, which raises for multi-element tensors and treats a
    single-element zero tensor as missing.

    Args:
        batch: Mapping of field name to value, as yielded by the data loader.
        *keys: Candidate key names, in priority order.

    Returns:
        The first non-None value found, or None if no key matches.
    """
    for key in keys:
        value = batch.get(key)
        if value is not None:
            return value
    return None


print(f"Running evaluation on {device}...")
model.to(device)

with torch.no_grad():
    for batch_idx, batch in enumerate(val_loader):
        # Move tensor fields to the target device; leave other values untouched
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                 for k, v in batch.items()}

        # Forward pass (any input the batch does not provide is passed as None)
        outputs = model(
            images=_first_present(batch, 'image', 'images'),
            input_ids=batch.get('input_ids'),
            attention_mask=batch.get('attention_mask')
        )

        logits = outputs['logits']
        preds = torch.argmax(logits, dim=-1)

        labels = _first_present(batch, 'label', 'labels')
        if labels is None:
            # Fail with a clear message instead of an AttributeError on None below
            raise KeyError("Batch contains neither 'label' nor 'labels'")

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        if (batch_idx + 1) % 10 == 0:
            print(f"Processed {batch_idx + 1}/{len(val_loader)} batches")

# Compute summary metrics over all collected labels and predictions
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Precision / recall / F1 use 'weighted' averaging; zero_division=0 reports 0
# for undefined scores. The fourth returned value is not used.
weighted_scores = precision_recall_fscore_support(
    all_labels, all_preds, average='weighted', zero_division=0
)
precision, recall, f1, _ = weighted_scores
accuracy = accuracy_score(all_labels, all_preds)

print("\nEvaluation Results:")
print(f"  Accuracy:  {accuracy:.4f}")
print(f"  Precision: {precision:.4f}")
print(f"  Recall:    {recall:.4f}")
print(f"  F1 Score:  {f1:.4f}")

## Visualize Results

Visualize the confusion matrix and sample predictions.

In [None]:
# Visualize results
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Class ids that appear in the labels OR the predictions, in sorted order.
# The same list fixes the matrix axis order and the tick labels, so the ticks
# stay correct when class ids are non-contiguous or a class occurs only in
# the predictions.
class_ids = np.union1d(all_labels, all_preds)
num_classes = len(class_ids)

# Create confusion matrix (rows: true labels, columns: predicted labels)
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_ids,
            yticklabels=class_ids)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.show()

# Print a random handful of predictions (at most 10, never more than exist),
# drawn without replacement so no sample is shown twice.
print("\nSample Predictions:")
print("-" * 60)
num_shown = min(10, len(all_labels))
sample_indices = np.random.choice(len(all_labels), num_shown, replace=False)
for idx in sample_indices:
    true_label, pred_label = all_labels[idx], all_preds[idx]
    # A check mark flags a correct prediction, a cross an incorrect one
    if true_label == pred_label:
        status = "✓"
    else:
        status = "✗"
    print(f"{status} Sample {idx}: True={true_label}, Predicted={pred_label}")