# WeightMapper Interactive Demo

This notebook demonstrates the capabilities of the `WeightMapper` class for mapping weights between different model architectures.

## Overview

The `WeightMapper` helps transfer learned weights from one model architecture to another, even when layer names have changed. This is useful for:

- Model refactoring
- Architecture migrations
- Transfer learning with renamed layers


In [None]:
# Import required libraries
import torch
from torch import Tensor, nn

from lit_wsl.models.weight_mapper import WeightMapper

## Part 1: Basic Weight Mapping Between Models

Let's start by defining two models with identical layer structure but different submodule names (`backbone`/`classifier` vs. `feature_extractor`/`head`).


In [None]:
# Source model: a small conv backbone followed by an MLP classifier.
class OldModel(nn.Module):
    """Original model architecture.

    Submodules are named ``backbone`` and ``classifier``. Both 3x3 convolutions
    use padding=1 with the default stride of 1, so the spatial size is
    preserved; the classifier's ``128 * 8 * 8`` input therefore implies inputs
    of shape (N, 3, 8, 8). The output has shape (N, 10).
    """

    def __init__(self) -> None:
        super().__init__()
        in_ch, mid_ch, out_ch = 3, 64, 128
        self.backbone = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(out_ch * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Extract conv features, flatten them per sample, and classify."""
        features = self.backbone(x)
        flat = features.view(features.size(0), -1)
        return self.classifier(flat)

In [None]:
# Target model: same layer layout as OldModel, but with renamed submodules.
class NewModel(nn.Module):
    """New model architecture with renamed layers.

    ``feature_extractor`` mirrors ``OldModel.backbone`` and ``head`` mirrors
    ``OldModel.classifier``; every layer has the same shape. Inputs are
    expected to be (N, 3, 8, 8), because the padded 3x3 convolutions keep the
    spatial size and the head takes ``128 * 8 * 8`` features.
    """

    def __init__(self) -> None:
        super().__init__()
        in_ch, mid_ch, out_ch = 3, 64, 128
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.head = nn.Sequential(
            nn.Linear(out_ch * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Extract conv features, flatten them per sample, and apply the head."""
        features = self.feature_extractor(x)
        flat = features.view(features.size(0), -1)
        return self.head(flat)

### Test 1: Basic Weight Mapping


In [None]:
# Instantiate the source (old) and target (new) architectures
old_model = OldModel()
new_model = NewModel()

# Build a mapper over the two models and ask it for a suggested mapping
mapper = WeightMapper(old_model, new_model)
mapping = mapper.suggest_mapping(threshold=0.5)

# Report what the mapper found
banner = "=" * 80
print("Weight Mapping Analysis:")
print(banner)
mapper.print_analysis()

In [None]:
# List every suggested pair as "source -> target"
print("\nGenerated Mapping Dictionary:")
print("=" * 80)
for src_name, dst_name in mapping.items():
    print(f"{src_name} -> {dst_name}")

### Test 2: Conservative Mapping Strategy

This test combines the `conservative` strategy with a higher threshold (0.7 instead of the 0.5 used above), so only more confident matches are kept.


In [None]:
mapper_conservative = WeightMapper(old_model, new_model)
# Raise the threshold (0.7 vs. 0.5 above) and select the conservative strategy
mapping_conservative = mapper_conservative.suggest_mapping(
    threshold=0.7,
    strategy="conservative",
)

n_matches = len(mapping_conservative)
print(f"Conservative mapping found {n_matches} matches")
print("Mapping:")
for src_name, dst_name in mapping_conservative.items():
    print(f"  {src_name} -> {dst_name}")

### Test 3: Shape-Only Mapping Strategy

This strategy only considers parameter shapes, ignoring names.


In [None]:
mapper_shape = WeightMapper(old_model, new_model)
# Pair parameters by shape alone; names are not considered by this strategy
mapping_shape = mapper_shape.suggest_mapping(strategy="shape_only")

print(f"Shape-only mapping found {len(mapping_shape)} matches")
print("\nFirst 5 mappings:")
preview = list(mapping_shape.items())[:5]
for src_name, dst_name in preview:
    src_shape = mapper_shape.source_params[src_name].shape
    print(f"{src_name} ({src_shape}) -> {dst_name}")

### Test 4: Mapping with Confidence Scores

View mappings sorted by confidence scores to identify the most reliable matches.


In [None]:
mapper_scores = WeightMapper(old_model, new_model)
mapper_scores.suggest_mapping()

# Each entry is unpacked below as (source, target, score).
# Sort a copy with sorted() instead of calling .sort() on the returned object:
# it avoids mutating a list the mapper may still hold internally (not verified
# from here), and it also works if the method returns a tuple or an iterator.
mappings_with_scores = sorted(
    mapper_scores.get_mapping_with_scores(),
    key=lambda entry: entry[2],
    reverse=True,
)

print("Top 10 mappings by confidence score:")
print(f"{'Source':<40} {'Target':<40} {'Score':>8}")
print("-" * 90)  # 40 + 1 + 40 + 1 + 8 columns
for source, target, score in mappings_with_scores[:10]:
    print(f"{source:<40} {target:<40} {score:>7.3f}")

### Test 5: Export Mapping Report

Export the mapping analysis to a JSON file for future reference.


In [None]:
import json
import os
import tempfile

mapper_export = WeightMapper(old_model, new_model)
mapper_export.suggest_mapping()

# Export report. Use the platform temp directory rather than a hard-coded
# "/tmp", which does not exist on Windows.
output_path = os.path.join(tempfile.gettempdir(), "weight_mapping_report.json")
mapper_export.export_mapping_report(output_path)

print(f"Report exported successfully to {output_path}")

# Read the report back; the file only needs to stay open while parsing.
with open(output_path, encoding="utf-8") as f:
    report = json.load(f)

# Sections are read with .get() so a section absent from the report counts as 0.
print("\nReport summary:")
print(f"  Total mappings: {len(report.get('mapping', {}))}")
print(f"  Unmapped source params: {len(report.get('unmapped_source', []))}")
print(f"  Unmapped target params: {len(report.get('unmapped_target', []))}")

## Part 2: Mapping from State Dict

This section demonstrates how to use `WeightMapper` when you only have a checkpoint file (state dict) without access to the original model code.


In [None]:
# Define models for state dict demo
class OldModelV2(nn.Module):
    """Original model architecture.

    Submodules are ``conv_layers`` and ``fc_layers``. The 3x3 convolutions use
    padding=1 with stride 1, so spatial size is preserved; the first linear
    layer's ``64 * 32 * 32`` input size therefore implies inputs of shape
    (N, 3, 32, 32). The output has shape (N, 10).
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(64 * 32 * 32, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run the conv stack, flatten per sample, then the fully connected stack.

        Added so the model can actually be run, e.g. to compare its outputs with
        the target model after the weights have been mapped.
        """
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        return self.fc_layers(x)


class NewModelV2(nn.Module):
    """Refactored model architecture.

    Same layer layout as ``OldModelV2``, with ``conv_layers`` renamed to
    ``features`` and ``fc_layers`` renamed to ``classifier``. Inputs are
    expected to be (N, 3, 32, 32), because the padded 3x3 convolutions keep
    the spatial size and the first linear layer takes ``64 * 32 * 32``
    features.
    """

    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 32 * 32, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run the conv features, flatten per sample, then the classifier.

        Added so the model can be run after loading the mapped weights.
        """
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

### Step 1: Create and Save Old Model Checkpoint


In [None]:
import os
import tempfile

# Simulate having an old checkpoint
print("Creating old model and saving checkpoint...")
old_model_v2 = OldModelV2()
# Platform temp directory instead of a hard-coded "/tmp" (absent on Windows)
checkpoint_path = os.path.join(tempfile.gettempdir(), "old_model_checkpoint.pth")

# Save as a typical PyTorch checkpoint: weights nested under "state_dict"
# alongside training metadata.
torch.save(
    {
        "state_dict": old_model_v2.state_dict(),
        "epoch": 42,
        "optimizer_state": {},  # Would normally have optimizer state
    },
    checkpoint_path,
)
print(f"✓ Saved checkpoint to {checkpoint_path}")

### Step 2: Load Checkpoint Without Original Model Code


In [None]:
# Now pretend we don't have access to OldModelV2 class anymore
from lit_wsl.models.checkpoint import load_checkpoint_as_dict

print("Loading checkpoint (without old model code)...")
checkpoint = load_checkpoint_as_dict(checkpoint_path)
# The weights were stored under the "state_dict" key when the checkpoint was saved
old_weights = checkpoint["state_dict"]

sample_keys = list(old_weights)[:3]
print(f"✓ Loaded {len(old_weights)} parameters from checkpoint")
print(f"  Sample keys: {sample_keys}...")

### Step 3: Create New Model and Mapper


In [None]:
# Create new model
print("Creating new model architecture...")
new_model_v2 = NewModelV2()
new_state = new_model_v2.state_dict()
# Count state-dict entries (one per tensor), the same unit used when counting
# the old checkpoint above. Report the total number of scalar values
# separately, so "parameters" is not confused with parameter tensors.
n_values = sum(tensor.numel() for tensor in new_state.values())
print(f"✓ New model has {len(new_state)} parameter tensors ({n_values:,} values)")
print(f"  Sample keys: {list(new_state)[:3]}...")

In [None]:
# Build a mapper whose source side is the raw state dict loaded above,
# rather than an instantiated source model
print("Creating WeightMapper from state dict...")
mapper_from_dict = WeightMapper.from_state_dict(
    old_weights,
    new_model_v2,
)
print("✓ Mapper created successfully")

### Step 4: Generate and Analyze Mapping


In [None]:
# Ask for a best-match mapping, using the same 0.5 threshold as Part 1
print("Generating weight mapping...")
mapping_dict = mapper_from_dict.suggest_mapping(threshold=0.5, strategy="best_match")
print(f"✓ Found {len(mapping_dict)} parameter mappings")

# Detailed breakdown from the mapper itself
rule = "-" * 80
print("\nMapping analysis:")
print(rule)
mapper_from_dict.print_analysis()

### Step 5: Apply Mapping and Load Weights


In [None]:
# Apply mapping to checkpoint: store each old tensor under its new key name.
print("Applying mapping to create new checkpoint...")
new_weights = {new_key: old_weights[old_key] for old_key, new_key in mapping_dict.items()}

# strict=False reports missing/unexpected keys instead of raising; note that
# a tensor whose shape does not match its target still raises a RuntimeError.
missing, unexpected = new_model_v2.load_state_dict(new_weights, strict=False)
print(f"✓ Loaded {len(new_weights)} weights into new model")

if missing:
    print(f"  Missing keys: {len(missing)}")
    print(f"  Examples: {list(missing)[:3]}")
if unexpected:
    print(f"  Unexpected keys: {len(unexpected)}")
    print(f"  Examples: {list(unexpected)[:3]}")

# Only claim success when every key lined up; otherwise the new model still
# holds some randomly initialised weights.
print("\n" + "=" * 80)
if missing or unexpected:
    print("⚠ Mapping incomplete: some weights were not transferred (see above).")
else:
    print("✓ Successfully mapped weights from old checkpoint to new model!")
print("=" * 80)

### Cleanup


In [None]:
# Clean up temporary files
import os

# Reuse the paths the earlier cells wrote to rather than re-typing them, so
# cleanup cannot drift out of sync with where the files were actually saved.
for tmp_path in (checkpoint_path, output_path):
    try:
        os.remove(tmp_path)
    except FileNotFoundError:
        # Already gone (or the producing cell was not run); nothing to do.
        continue
    print(f"Cleaned up {tmp_path}")

## Summary

This notebook demonstrated:

1. **Basic weight mapping** between two models with different naming conventions
2. **Different mapping strategies**: conservative, shape-only, and best-match
3. **Confidence scores** for evaluating mapping quality
4. **Export capabilities** for saving mapping reports
5. **State dict mapping** for working with checkpoints when original model code is unavailable

The `WeightMapper` class provides flexible and powerful tools for transferring weights between different model architectures, making it easier to refactor code while preserving learned parameters.
