# Model Evaluation and Analysis

This notebook evaluates the fine-tuned model: it computes an anchor–positive similarity matrix, ranks the positives for each anchor, reports accuracy and Mean Reciprocal Rank (MRR), and analyzes hard negatives.


In [None]:
# Import functions from the scripts directory
from src.data.loaders import load_toy_dataset
from src.models.embedding_pipeline import load_embeddinggemma_model, embed_texts
from src.models.lora_setup import setup_lora_model
from src.evaluation.similarity_metrics import (
    compute_similarity_matrix,
    rank_positives_for_anchors,
    compute_accuracy,
    compute_mean_reciprocal_rank,
    analyze_hard_negatives
)
import pandas as pd
import torch


## Load Fine-Tuned Model

Load the model with LoRA adapters. If you've trained a model, load it here. Otherwise, we'll use a freshly configured model for demonstration.


In [None]:
# Load model with LoRA (for demonstration, using base model with LoRA setup).
# NOTE: the adapters configured below have not been trained in this notebook, so
# the results that follow describe the demonstration model, not a fine-tuned one.
# In practice, you would load a trained model:
# from transformers import AutoModel
# from peft import PeftModel
# base_model = AutoModel.from_pretrained('google/embeddinggemma-300m')
# model = PeftModel.from_pretrained(base_model, 'outputs/models/embeddinggemma_lora_YYYYMMDD_HHMMSS')

tokenizer, base_model = load_embeddinggemma_model()
model = setup_lora_model(base_model, r=16, lora_alpha=32, lora_dropout=0.1)

# Put the model in evaluation mode so dropout (lora_dropout=0.1 above) is disabled
# and repeated runs produce the same embeddings.
# Assumes setup_lora_model returns a torch.nn.Module (e.g. a PEFT model) -- confirm.
model.eval()

print("Model loaded for evaluation")


## Load Test Data

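We'll evaluate on the same toy dataset that is used for training to see how well the model distinguishes pairs. This notebook does not load a separate held-out split, so treat the metrics below as a sanity check rather than a measure of generalization.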


In [None]:
# Fetch the toy anchor/positive pairs that the evaluation runs on
train_data = load_toy_dataset()

# Split the pairs into two parallel lists: entry k of each list comes from pair k
anchor_texts = [pair["anchor"] for pair in train_data]
positive_texts = [pair["positive"] for pair in train_data]

print(f"Evaluating on {len(anchor_texts)} anchor-positive pairs")


## Compute Embeddings

Generate embeddings for all anchors and positives.


In [None]:
# Compute embeddings for evaluation.
# Gradients are not needed here, so disable autograd: this avoids building a
# computation graph and keeps the resulting tensors free of grad history, which
# matters later when they are converted to NumPy for display.
# (Harmless if embed_texts already disables gradients internally.)
with torch.no_grad():
    anchor_embeddings = embed_texts(anchor_texts, model, tokenizer)
    positive_embeddings = embed_texts(positive_texts, model, tokenizer)

print(f"Anchor embeddings shape: {anchor_embeddings.shape}")
print(f"Positive embeddings shape: {positive_embeddings.shape}")


## Compute Similarity Matrix

See how similar each anchor is to each positive.


In [None]:
# Compute similarity matrix (one row per anchor, one column per positive)
sim_matrix = compute_similarity_matrix(anchor_embeddings, positive_embeddings)

# Display as DataFrame.
# A bare .numpy() raises if the tensor still tracks gradients, lives on an
# accelerator, or uses bfloat16, so detach it, move it to the CPU and cast to
# float32 first. For a plain float32 CPU tensor this changes nothing.
sim_df = pd.DataFrame(
    sim_matrix.detach().cpu().float().numpy(),
    index=[f"Anchor {i+1}" for i in range(len(anchor_texts))],
    columns=[f"Positive {i+1}" for i in range(len(positive_texts))]
)

print("Similarity Matrix (Anchor × Positive):")
print(sim_df.round(3))


## Ranking Analysis

For each anchor, see how the positives are ranked by similarity.


In [None]:
# Ask the helper for the ordering of positives per anchor, plus the scores
rankings, similarities = rank_positives_for_anchors(anchor_embeddings, positive_embeddings)

# NOTE(review): the lookup below indexes `similarities` by the positive's original
# index, i.e. it assumes an unsorted anchor-by-positive score matrix -- confirm
# against rank_positives_for_anchors.

# Print one ranked list per anchor; the positive sharing the anchor's index is
# treated as the correct match and flagged with a check mark
for i, anchor in enumerate(anchor_texts):
    print(f"\nAnchor {i+1}: '{anchor}'")
    print("Rankings (most similar first):")
    for rank_pos, pos_idx in enumerate(rankings[i]):
        pos_idx_val = pos_idx.item()
        sim = similarities[i][pos_idx_val].item()
        if pos_idx_val == i:
            marker = "✓"
        else:
            marker = " "
        print(f"  {marker} Rank {rank_pos+1}: Positive {pos_idx_val+1} (sim={sim:.3f}) - '{positive_texts[pos_idx_val]}'")


## Compute Metrics

Calculate accuracy and Mean Reciprocal Rank (MRR).


In [None]:
# Top-1 accuracy: share of anchors whose correct positive is ranked first
accuracy = compute_accuracy(anchor_embeddings, positive_embeddings)

# Mean Reciprocal Rank over all anchors
mrr = compute_mean_reciprocal_rank(anchor_embeddings, positive_embeddings)

# Emit the report in one call; sep="\n" puts each entry on its own line
print(
    "Evaluation Metrics:",
    "-" * 40,
    f"Accuracy (correct positive at rank 1): {accuracy:.2%}",
    f"Mean Reciprocal Rank (MRR): {mrr:.3f}",
    sep="\n",
)


## Hard Negative Analysis

Identify hard negatives: anchor–positive pairs that do not belong together but still receive a high similarity score.


In [None]:
# Collect mismatched anchor/positive pairs whose similarity is still high
hard_negatives = analyze_hard_negatives(
    anchor_texts,
    positive_texts,
    anchor_embeddings,
    positive_embeddings,
    threshold=0.3,  # pairs with similarity > 0.3 count as hard negatives
)

if not hard_negatives:
    print("No hard negatives found (all incorrect pairs have low similarity)")
else:
    print(f"Found {len(hard_negatives)} hard negative pairs:")
    # Only the first five entries are listed, numbered from 1
    for i, hn in enumerate(hard_negatives[:5], start=1):
        print(
            f"\n{i}. Similarity: {hn['similarity']:.3f}",
            f"   Anchor: '{hn['anchor']}'",
            f"   Incorrect Positive: '{hn['incorrect_positive']}'",
            sep="\n",
        )
