# WeightMapper Complete Example

This notebook walks through the `WeightMapper` class and its features for mapping weights between model architectures.

## Overview

The `WeightMapper` helps transfer learned weights from one model architecture to another, even when layer names have changed. This is useful for:

- **Model refactoring**: Renaming layers while preserving learned weights
- **Architecture migrations**: Moving between different model structures
- **Transfer learning**: Adapting pre-trained models to new architectures

## Key Features

1. **Group-based parameter mapping**: Parameters are organized into groups (e.g., conv layers with weight and bias)
2. **Hierarchical structure extraction**: Parent-child relationships in model architecture are preserved
3. **Batch normalization support**: Handles BN layers with buffers (running_mean, running_var)
4. **Multiple mapping strategies**: Conservative, shape-only, and best-match
5. **State dict mapping**: Work with checkpoints when original model code is unavailable
6. **Confidence scores**: Evaluate mapping quality with scoring metrics


## Setup and Imports


In [None]:
import json
from pathlib import Path

import torch
from torch import Tensor, nn

from lit_wsl.models.checkpoint import load_checkpoint_as_dict
from lit_wsl.models.weight_mapper import WeightMapper

---

# Part 1: Basic Weight Mapping

Let's start by defining two similar models with different naming conventions and demonstrate basic mapping functionality.


In [None]:
class OldModel(nn.Module):
    """Original model architecture."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.backbone(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [None]:
class NewModel(nn.Module):
    """New model architecture with renamed layers."""

    def __init__(self) -> None:
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.head = nn.Sequential(
            nn.Linear(128 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)
        x = self.head(x)
        return x

In [None]:
# Create models
old_model = OldModel()
new_model = NewModel()

# Create mapper
mapper = WeightMapper(old_model, new_model)

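# Generate mapping; suggest_mapping() returns a MappingResult (see Part 9),
# so get_mapping() extracts the plain source -> target dict
mapping = mapper.suggest_mapping(threshold=0.5).get_mapping()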

# Print analysis
print("Weight Mapping Analysis:")
print("=" * 80)
mapper.print_analysis()

In [None]:
# Display the generated mapping
print("\nGenerated Mapping Dictionary:")
print("=" * 80)
for source, target in mapping.items():
    print(f"{source} -> {target}")

---

# Part 2: Group-Based Parameter Mapping

Parameters are organized into groups where all related parameters (e.g., weight and bias) from the same module are mapped together.
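
To make the grouping idea concrete, here is a minimal sketch that groups dotted parameter names by their owning module. The `group_parameters` helper is hypothetical and only illustrates the concept; it is not `WeightMapper`'s actual implementation.

```python
from collections import defaultdict


def group_parameters(names: list[str]) -> dict[str, list[str]]:
    """Group dotted parameter names by their owning module path."""
    groups: dict[str, list[str]] = defaultdict(list)
    for name in names:
        # "bn1.running_mean" -> module path "bn1", parameter type "running_mean"
        module_path, _, param_type = name.rpartition(".")
        groups[module_path].append(param_type)
    return {path: sorted(types) for path, types in groups.items()}


names = [
    "conv1.weight",
    "conv1.bias",
    "bn1.weight",
    "bn1.bias",
    "bn1.running_mean",
    "bn1.running_var",
]
groups = group_parameters(names)
for path, types in sorted(groups.items()):
    print(f"{path}: {types}")
```

Matching whole groups rather than individual tensors means a convolution's `weight` and `bias` cannot end up assigned to two different target modules.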


In [None]:
class SourceModel(nn.Module):
    """Source model with a specific structure."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        x = x.mean(dim=[2, 3])  # Global average pooling
        x = self.fc(x)
        return x

In [None]:
class TargetModel(nn.Module):
    """Target model with renamed layers."""

    def __init__(self):
        super().__init__()
        # Renamed layers
        self.encoder_conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.encoder_norm1 = nn.BatchNorm2d(64)
        self.encoder_conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.encoder_norm2 = nn.BatchNorm2d(128)
        self.classifier = nn.Linear(128, 10)

    def forward(self, x):
        x = self.encoder_conv1(x)
        x = self.encoder_norm1(x)
        x = torch.relu(x)
        x = self.encoder_conv2(x)
        x = self.encoder_norm2(x)
        x = torch.relu(x)
        x = x.mean(dim=[2, 3])
        x = self.classifier(x)
        return x

In [None]:
print("=" * 80)
print("Group-Based Parameter Mapping Demonstration")
print("=" * 80)

# Create models
source = SourceModel()
target = TargetModel()

# Create mapper
mapper_groups = WeightMapper(source, target)

print(f"\nSource model has {len(mapper_groups.source_params)} parameters")
print(f"Target model has {len(mapper_groups.target_params)} parameters")

In [None]:
# Show source parameter groups
print(f"\nSource model has {len(mapper_groups.source_groups)} parameter groups:")
for path, group in sorted(mapper_groups.source_groups.items()):
    param_types = sorted(group.param_types)
    print(f"  {path:30} -> {param_types}")

In [None]:
# Show target parameter groups
print(f"\nTarget model has {len(mapper_groups.target_groups)} parameter groups:")
for path, group in sorted(mapper_groups.target_groups.items()):
    param_types = sorted(group.param_types)
    print(f"  {path:30} -> {param_types}")

In [None]:
# Suggest mapping and extract the plain source -> target dict
mapping_groups = mapper_groups.suggest_mapping(threshold=0.5).get_mapping()

print(f"\nMapping results: {len(mapping_groups)} parameters mapped")

In [None]:
# Show group mappings
if mapper_groups._group_mapping and mapper_groups._group_scores:  # type: ignore[attr-defined]
    print(f"\nGroup mappings ({len(mapper_groups._group_mapping)} groups):")  # type: ignore[attr-defined]
    for source_path, target_path in sorted(mapper_groups._group_mapping.items()):  # type: ignore[attr-defined]
        score = mapper_groups._group_scores[source_path]  # type: ignore[attr-defined]
        source_group = mapper_groups.source_groups[source_path]
        param_types = sorted(source_group.param_types)
        print(f"  {source_path:30} -> {target_path:30} (score: {score:.3f})")
        print(f"    Parameters in group: {param_types}")

        # Show individual parameter mappings for this group
        for param_type in param_types:
            source_param = source_group.params[param_type]
            target_param = mapper_groups.target_groups[target_path].params[param_type]
            print(f"      {source_param.name:45} -> {target_param.name}")

In [None]:
# Verify that all parameters in a group are mapped together
print("\n" + "=" * 80)
print("Verification: All parameters in a group are mapped together")
print("=" * 80)

if mapper_groups._group_mapping:  # type: ignore[attr-defined]
    for source_path, target_path in sorted(mapper_groups._group_mapping.items()):  # type: ignore[attr-defined]
        source_group = mapper_groups.source_groups[source_path]
        target_group = mapper_groups.target_groups[target_path]

        # Check that all param types in source are in target
        source_types = set(source_group.param_types)
        target_types = set(target_group.param_types)

        if source_types == target_types:
            print(f"✓ Group {source_path:30} -> {target_path:30}")
            print(f"  All {len(source_types)} parameters mapped together: {sorted(source_types)}")
        else:
            print(f"✗ ERROR: Group {source_path} has mismatched parameters!")

---

# Part 3: Hierarchical Structure Extraction

The WeightMapper extracts and uses hierarchical structure from models to improve parameter mapping, preserving parent-child relationships in the model architecture.
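
As a rough illustration of what a module hierarchy looks like, the sketch below builds a nested tree from dotted module paths. `build_tree` and `render` are hypothetical helpers, and the paths assume `nn.Sequential` children are named by their index; the real `WeightMapper` hierarchy objects carry more information.

```python
def build_tree(paths: list[str]) -> dict:
    """Build a nested dict with one level per dotted path component."""
    root: dict = {}
    for path in paths:
        node = root
        for part in path.split("."):
            # Descend into the child, creating it on first visit
            node = node.setdefault(part, {})
    return root


def render(node: dict, indent: int = 0) -> list[str]:
    """Return one indented line per tree node, children sorted by name."""
    lines: list[str] = []
    for name in sorted(node):
        lines.append("  " * indent + name)
        lines.extend(render(node[name], indent + 1))
    return lines


paths = ["backbone.0.0", "backbone.0.1", "backbone.1.0", "backbone.1.1", "head"]
tree = build_tree(paths)
print("\n".join(render(tree)))
```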


In [None]:
class DeepSourceModel(nn.Module):
    """Source model with deep nested structure."""

    def __init__(self):
        super().__init__()
        # Deep nested structure
        self.backbone = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
            ),
            nn.Sequential(
                nn.Conv2d(32, 64, 3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
            ),
        )
        self.head = nn.Linear(64, 10)

    def forward(self, x):
        x = self.backbone(x)
        x = x.mean(dim=[2, 3])
        x = self.head(x)
        return x

In [None]:
class DeepTargetModel(nn.Module):
    """Target model with renamed but similar structure."""

    def __init__(self):
        super().__init__()
        # Similar structure, different names
        self.encoder = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1),
                nn.BatchNorm2d(32),
                nn.ReLU(),
            ),
            nn.Sequential(
                nn.Conv2d(32, 64, 3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
            ),
        )
        self.classifier = nn.Linear(64, 10)

    def forward(self, x):
        x = self.encoder(x)
        x = x.mean(dim=[2, 3])
        x = self.classifier(x)
        return x

In [None]:
def visualize_hierarchy(node, indent=0, max_depth=3):
    """Visualize the module hierarchy tree."""
    if indent > max_depth:
        return

    prefix = "  " * indent
    param_info = ""
    if node.parameter_group:
        param_types = sorted(node.parameter_group.param_types)
        param_info = f" [{', '.join(param_types)}]"

    if node.full_path:
        print(f"{prefix}└─ {node.name}{param_info}")
    else:
        print(f"{prefix}<root>")

    for child in node.children.values():
        visualize_hierarchy(child, indent + 1, max_depth)

In [None]:
print("=" * 80)
print("Hierarchical Structure Extraction Demonstration")
print("=" * 80)

# Create models
deep_source = DeepSourceModel()
deep_target = DeepTargetModel()

# Create mapper
mapper_hier = WeightMapper(deep_source, deep_target)

print(f"\nSource model has {len(mapper_hier.source_params)} parameters")
print(f"Target model has {len(mapper_hier.target_params)} parameters")

In [None]:
print("\nSource model hierarchy:")
visualize_hierarchy(mapper_hier.source_hierarchy)

In [None]:
print("\nTarget model hierarchy:")
visualize_hierarchy(mapper_hier.target_hierarchy)

In [None]:
# Show parameter groups organized by hierarchy
print("\nSource parameter groups (organized by hierarchy):")
print(f"Total groups: {len(mapper_hier.source_groups)}")
for path in sorted(mapper_hier.source_groups.keys(), key=lambda x: (x.count("."), x)):
    group = mapper_hier.source_groups[path]
    indent = "  " * path.count(".")
    param_types = sorted(group.param_types)
    print(f"{indent}{path}: {param_types}")

In [None]:
# Perform mapping and extract the plain source -> target dict
mapping_hier = mapper_hier.suggest_mapping(threshold=0.5).get_mapping()

print("\nMapping results:")
print(f"  Total parameters mapped: {len(mapping_hier)}")
print(f"  Coverage: {len(mapping_hier) / len(mapper_hier.source_params) * 100:.1f}%")

In [None]:
# Show how hierarchical context improves matching
if mapper_hier._hierarchy_context and mapper_hier._group_scores:  # type: ignore[attr-defined]
    print("\nHierarchical Context Impact:")
    print(f"{'Source Path':<35} -> {'Target Path':<35} {'Score':>6} {'Context':>8}")
    print("-" * 88)
    for source_path in sorted(mapper_hier._group_mapping.keys(), key=lambda x: (x.count("."), x)):  # type: ignore[attr-defined]
        target_path = mapper_hier._group_mapping[source_path]  # type: ignore[attr-defined]
        score = mapper_hier._group_scores[source_path]  # type: ignore[attr-defined]
        context = mapper_hier._hierarchy_context[source_path]  # type: ignore[attr-defined]
        print(f"{source_path:<35} -> {target_path:<35} {score:>6.3f} {context:>8.3f}")

In [None]:
# Verify parent-child relationships
print("\n" + "=" * 80)
print("Parent-Child Relationship Verification")
print("=" * 80)
print("\nVerifying that child modules are mapped consistently with their parents:")

for source_path, target_path in sorted((mapper_hier._group_mapping or {}).items()):  # type: ignore[attr-defined]
    source_parts = source_path.split(".")
    target_parts = target_path.split(".")

    if len(source_parts) > 1:
        # Check parent mapping
        source_parent = ".".join(source_parts[:-1])
        target_parent = ".".join(target_parts[:-1])

        if source_parent in mapper_hier._group_mapping:  # type: ignore[attr-defined]
            mapped_target_parent = mapper_hier._group_mapping[source_parent]  # type: ignore[attr-defined]
            if mapped_target_parent == target_parent:
                print(f"✓ {source_path:<30} -> {target_path:<30} (parent: {source_parent})")
            else:
                print(f"⚠ {source_path:<30} -> {target_path:<30} (parent mismatch!)")

---

# Part 4: Batch Normalization Support

This part demonstrates that group-based mapping works correctly with batch normalization layers: weight, bias, running_mean, and running_var are all mapped together.


In [None]:
class ModelWithBNBuffers(nn.Module):
    """Model with batch normalization that has running_mean and running_var."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.bn = nn.BatchNorm2d(64)
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = torch.relu(x)
        x = x.mean(dim=[2, 3])
        x = self.fc(x)
        return x

In [None]:
# Create and initialize models
source_bn = ModelWithBNBuffers()
target_bn = ModelWithBNBuffers()

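# Fill learnable parameters with distinct constants so the two models differ
# (BatchNorm buffers such as running_mean keep their default values)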
with torch.no_grad():
    for p in source_bn.parameters():
        p.fill_(1.0)
    for p in target_bn.parameters():
        p.fill_(0.0)

mapper_bn = WeightMapper(source_bn, target_bn)

# Check that parameters are grouped correctly
print("\nParameter Groups (with BatchNorm buffers):")
for path, group in sorted(mapper_bn.source_groups.items()):
    print(f"  {path}: {sorted(group.param_types)}")

In [None]:
# Get mapping as a plain source -> target dict
mapping_bn = mapper_bn.suggest_mapping(threshold=0.5).get_mapping()

print(f"\nTotal parameters mapped: {len(mapping_bn)}")
print(f"Total groups: {len(mapper_bn._group_mapping) if mapper_bn._group_mapping else 0}")  # type: ignore[attr-defined]

In [None]:
# Verify that for each module, all its parameters are mapped together
print("\nGroup Mappings (verifying BatchNorm buffers):")
if mapper_bn._group_mapping:  # type: ignore[attr-defined]
    for source_path, target_path in sorted(mapper_bn._group_mapping.items()):  # type: ignore[attr-defined]
        source_group = mapper_bn.source_groups[source_path]
        target_group = mapper_bn.target_groups[target_path]

        print(f"\n  {source_path} -> {target_path}")
        print(f"    Source params: {sorted(source_group.param_types)}")
        print(f"    Target params: {sorted(target_group.param_types)}")

        # Verify all param types match
        if source_group.param_types != target_group.param_types:
            raise ValueError(f"Mismatch in param types for {source_path}")

        # Show individual mappings
        for param_type in sorted(source_group.param_types):
            source_param = source_group.params[param_type]
            target_param = target_group.params[param_type]
            print(f"      {source_param.name} -> {target_param.name}")

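if mapper_bn._group_mapping:  # type: ignore[attr-defined]
    print("\n✓ All BatchNorm parameters are grouped and mapped together correctly!")
else:
    print("\n⚠ No group mappings were produced; nothing to verify.")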

---

# Part 5: Multiple Mapping Strategies

The WeightMapper supports different strategies for different use cases.
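
The sketch below shows one way such strategies could differ. It is an assumption for illustration only: the `match_score` function, the 50/50 blend, and the use of `difflib.SequenceMatcher` are not taken from `WeightMapper`, whose real scoring is richer (see Part 9).

```python
from difflib import SequenceMatcher

Param = tuple[str, tuple[int, ...]]  # (dotted name, shape)


def match_score(source: Param, target: Param, strategy: str = "best_match") -> float:
    """Hypothetical score in [0, 1]; not WeightMapper's real formula."""
    (src_name, src_shape), (tgt_name, tgt_shape) = source, target
    if src_shape != tgt_shape:
        return 0.0  # tensors with different shapes cannot be copied over
    if strategy == "shape_only":
        return 1.0  # names are ignored entirely
    # Name-aware strategies: blend shape agreement (1.0 here) with name similarity
    name_score = SequenceMatcher(None, src_name, tgt_name).ratio()
    return 0.5 + 0.5 * name_score


source = ("backbone.0.weight", (64, 3, 3, 3))
target = ("feature_extractor.0.weight", (64, 3, 3, 3))

# In this sketch, "conservative" differs from "best_match" only by its threshold
for strategy, threshold in [("conservative", 0.7), ("shape_only", 0.5), ("best_match", 0.5)]:
    score = match_score(source, target, strategy)
    print(f"{strategy:<12} score={score:.3f} accepted={score >= threshold}")
```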


In [None]:
# Use the models from Part 1
test_old = OldModel()
test_new = NewModel()

print("=" * 80)
print("Comparing Different Mapping Strategies")
print("=" * 80)

In [None]:
# Strategy 1: Conservative (higher threshold, more confident matches)
mapper_conservative = WeightMapper(test_old, test_new)
mapping_conservative = mapper_conservative.suggest_mapping(threshold=0.7, strategy="conservative").get_mapping()

print(f"\n1. Conservative strategy found {len(mapping_conservative)} matches")
print("   (Uses higher threshold for more confident matches)")

In [None]:
# Strategy 2: Shape-only (ignores names, only considers shapes)
mapper_shape = WeightMapper(test_old, test_new)
mapping_shape = mapper_shape.suggest_mapping(strategy="shape_only").get_mapping()

print(f"\n2. Shape-only strategy found {len(mapping_shape)} matches")
print("   (Considers only parameter shapes, ignoring names)")
print("\nFirst 5 mappings:")
for source, target in list(mapping_shape.items())[:5]:
    source_shape = mapper_shape.source_params[source].shape
    print(f"   {source} ({source_shape}) -> {target}")

In [None]:
# Strategy 3: Best match (default, balanced approach)
mapper_best = WeightMapper(test_old, test_new)
mapping_best = mapper_best.suggest_mapping(threshold=0.5, strategy="best_match").get_mapping()

print(f"\n3. Best-match strategy found {len(mapping_best)} matches")
print("   (Balanced approach considering both names and shapes)")

---

# Part 6: Confidence Scores and Analysis

View mappings sorted by confidence scores to identify the most reliable matches.


In [None]:
mapper_scores = WeightMapper(test_old, test_new)
mapper_scores.suggest_mapping()

# Get mappings with scores
mappings_with_scores = mapper_scores.get_mapping_with_scores()

# Sort by score
mappings_with_scores.sort(key=lambda x: x[2], reverse=True)

print("Top 10 mappings by confidence score:")
print(f"{'Source':<40} {'Target':<40} {'Score':>8}")
print("-" * 90)
for source, target, score in mappings_with_scores[:10]:
    print(f"{source:<40} {target:<40} {score:>7.3f}")

In [None]:
# Print detailed analysis
print("\n")
mapper_scores.print_analysis(top_n=20, show_unmatched=True)

---

# Part 7: Export Mapping Report

Export the mapping analysis to a JSON file for documentation and future reference.
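
For orientation, here is a minimal sketch of a report with the three keys that the cell in this part reads back (`mapping`, `unmapped_source`, `unmapped_target`). The `build_report` helper is hypothetical, and the file written by `export_mapping_report` may contain additional fields.

```python
import json


def build_report(mapping: dict[str, str], source_keys: list[str], target_keys: list[str]) -> dict:
    """Assemble a JSON-serializable summary of a mapping (assumed schema)."""
    mapped_targets = set(mapping.values())
    return {
        "mapping": mapping,
        "unmapped_source": [k for k in source_keys if k not in mapping],
        "unmapped_target": [k for k in target_keys if k not in mapped_targets],
    }


report_sketch = build_report(
    mapping={"backbone.0.weight": "feature_extractor.0.weight"},
    source_keys=["backbone.0.weight", "backbone.0.bias"],
    target_keys=["feature_extractor.0.weight", "feature_extractor.0.bias"],
)
print(json.dumps(report_sketch, indent=2))
```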


In [None]:
mapper_export = WeightMapper(test_old, test_new)
mapper_export.suggest_mapping()

# Export report
output_path = "/tmp/weight_mapping_report.json"  # noqa: S108
mapper_export.export_mapping_report(output_path)

print(f"Report exported successfully to {output_path}")

# Display the report content
with Path(output_path).open() as f:
    report = json.load(f)
    print("\nReport summary:")
    print(f"  Total mappings: {len(report.get('mapping', {}))}")
    print(f"  Unmapped source params: {len(report.get('unmapped_source', []))}")
    print(f"  Unmapped target params: {len(report.get('unmapped_target', []))}")

---

# Part 8: State Dict Mapping (Checkpoint Loading)

This section demonstrates how to use `WeightMapper` when you only have a checkpoint file (state dict) without access to the original model code.
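
One pitfall when loading remapped weights: `load_state_dict(strict=False)` tolerates missing and unexpected keys but still raises on a shape mismatch. The sketch below checks shapes before any renaming. It uses plain shape tuples in place of tensors so it runs without a checkpoint, and `remap_with_shape_check` is a hypothetical helper, not part of `lit_wsl`.

```python
Shape = tuple[int, ...]


def remap_with_shape_check(
    old_shapes: dict[str, Shape],
    new_shapes: dict[str, Shape],
    mapping: dict[str, str],
) -> tuple[dict[str, str], list[str]]:
    """Split a key mapping into loadable pairs and old keys to skip."""
    loadable: dict[str, str] = {}
    skipped: list[str] = []
    for old_key, new_key in mapping.items():
        # Keep the pair only if both keys exist and their shapes agree
        if old_key in old_shapes and old_shapes[old_key] == new_shapes.get(new_key):
            loadable[old_key] = new_key
        else:
            skipped.append(old_key)
    return loadable, skipped


old_shapes = {"conv_layers.0.weight": (32, 3, 3, 3), "fc_layers.2.weight": (10, 128)}
new_shapes = {"features.0.weight": (32, 3, 3, 3), "classifier.2.weight": (5, 128)}
key_mapping = {
    "conv_layers.0.weight": "features.0.weight",
    "fc_layers.2.weight": "classifier.2.weight",
}

loadable, skipped = remap_with_shape_check(old_shapes, new_shapes, key_mapping)
print("loadable:", loadable)
print("skipped: ", skipped)
```

With real tensors, the same comparison can be made on `tensor.shape` for each state dict entry.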


In [None]:
class OldModelV2(nn.Module):
    """Original model architecture."""

    def __init__(self) -> None:
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(64 * 32 * 32, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )


class NewModelV2(nn.Module):
    """Refactored model architecture."""

    def __init__(self) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 32 * 32, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

In [None]:
# Step 1: Create and save old model checkpoint
print("Creating old model and saving checkpoint...")
old_model_v2 = OldModelV2()
checkpoint_path = "/tmp/old_model_checkpoint.pth"  # noqa: S108

# Save as a typical PyTorch checkpoint
torch.save(
    {
        "state_dict": old_model_v2.state_dict(),
        "epoch": 42,
        "optimizer_state": {},  # Would normally have optimizer state
    },
    checkpoint_path,
)
print(f"✓ Saved checkpoint to {checkpoint_path}")

In [None]:
# Step 2: Load checkpoint without original model code
print("Loading checkpoint (without old model code)...")
checkpoint = load_checkpoint_as_dict(checkpoint_path)
old_weights = checkpoint["state_dict"]

print(f"✓ Loaded {len(old_weights)} parameters from checkpoint")
print(f"  Sample keys: {list(old_weights.keys())[:3]}...")

In [None]:
# Step 3: Create new model and mapper
print("Creating new model architecture...")
new_model_v2 = NewModelV2()
print(f"✓ New model has {len(list(new_model_v2.parameters()))} parameters")
print(f"  Sample keys: {list(new_model_v2.state_dict().keys())[:3]}...")

In [None]:
# Create mapper from state dict
print("Creating WeightMapper from state dict...")
mapper_from_dict = WeightMapper.from_state_dict(old_weights, new_model_v2)
print("✓ Mapper created successfully")

In [None]:
# Step 4: Generate and analyze mapping
print("Generating weight mapping...")
mapping_dict = mapper_from_dict.suggest_mapping(strategy="best_match", threshold=0.5).get_mapping()
print(f"✓ Found {len(mapping_dict)} parameter mappings")

# Show analysis
print("\nMapping analysis:")
print("-" * 80)
mapper_from_dict.print_analysis()

In [None]:
# Step 5: Apply mapping and load weights
print("Applying mapping to create new checkpoint...")
new_weights = {}
for old_key, new_key in mapping_dict.items():
    new_weights[new_key] = old_weights[old_key]

# Load into new model
missing, unexpected = new_model_v2.load_state_dict(new_weights, strict=False)
print(f"✓ Loaded {len(new_weights)} weights into new model")

if missing:
    print(f"  Missing keys: {len(missing)}")
    print(f"  Examples: {list(missing)[:3]}")
if unexpected:
    print(f"  Unexpected keys: {len(unexpected)}")

print("\n" + "=" * 80)
print("✓ Successfully mapped weights from old checkpoint to new model!")
print("=" * 80)

---

# Part 9: New Dataclass-Based API

This section demonstrates the refactored WeightMapper API that provides:

1. **Dataclass-based return types** for structured access to results
2. **Full scoring transparency** with breakdown of all components
3. **Ability to explore all compatible group mappings**
4. **Soft hierarchical constraints** instead of hard rejections
5. **Enhanced filtering and analysis** capabilities

The new API makes it easier to understand and debug weight mappings with comprehensive access to scoring details and match information.
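
To illustrate the difference between a hard rejection and a soft constraint, here is a small sketch. Both functions and the `penalty` value of 0.5 are assumptions chosen for illustration; they do not reproduce `WeightMapper`'s internal scoring.

```python
def hard_constraint(score: float, parent_consistent: bool) -> float:
    """Reject any candidate whose parent mapping disagrees."""
    return score if parent_consistent else 0.0


def soft_constraint(score: float, parent_consistent: bool, penalty: float = 0.5) -> float:
    """Keep the candidate but scale its score down (penalty is an assumed value)."""
    return score if parent_consistent else score * penalty


# (target path, raw score, whether the parent mapping is consistent)
candidates = [("encoder.0", 0.9, False), ("classifier", 0.6, True)]
for path, score, consistent in candidates:
    hard = hard_constraint(score, consistent)
    soft = soft_constraint(score, consistent)
    print(f"{path:<12} hard={hard:.2f} soft={soft:.2f}")
```

Under the soft variant, a candidate with an inconsistent parent stays in the ranking with a reduced score instead of being dropped, so it remains visible when exploring alternatives.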


## Define Example Models

We'll use models with reorganized structure to demonstrate the new API features.


In [None]:
class OriginalModel(nn.Module):
    """Original model with typical architecture."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.head = nn.Linear(64, 10)

    def forward(self, x):
        x = self.backbone(x)
        x = x.mean(dim=[2, 3])
        x = self.head(x)
        return x

In [None]:
class ReorganizedModel(nn.Module):
    """Reorganized model where components are renamed/moved."""

    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 64, 3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(64, 10)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = x.mean(dim=[2, 3])
        x = self.classifier(x)
        return x

In [None]:
# Create models
source_api = OriginalModel()
target_api = ReorganizedModel()

# Create mapper
mapper_api = WeightMapper(source_api, target_api)

print("Models created successfully!")
print(f"Source parameters: {len(list(source_api.parameters()))}")
print(f"Target parameters: {len(list(target_api.parameters()))}")

## 9.1 MappingResult Dataclass

The `suggest_mapping()` method now returns a `MappingResult` dataclass instead of a plain dictionary. This provides structured access to the mapping and additional metadata.


In [None]:
print("=" * 80)
print("1. NEW DATACLASS-BASED API")
print("=" * 80)

result = mapper_api.suggest_mapping(threshold=0.5)

# Access mapping through convenience methods
mapping_api = result.get_mapping()
print(f"\nMatched {len(mapping_api)} parameters")
print(f"Coverage: {result.coverage * 100:.1f}%")

# Show first few mappings
print("\nFirst 5 mappings:")
for source, target in list(mapping_api.items())[:5]:
    print(f"  {source} → {target}")

## 9.2 Scoring Transparency

Access detailed score breakdowns for each mapping to understand why parameters were matched.


In [None]:
print("=" * 80)
print("2. SCORING TRANSPARENCY")
print("=" * 80)

# Get mapping with scores
mappings_with_scores_api = result.get_mapping_with_scores()

for source_name, (target_name, final_score) in list(mappings_with_scores_api.items())[:3]:
    print(f"\n{source_name} → {target_name}")
    print(f"  Final score: {final_score:.3f}")

    # Get detailed breakdown
    breakdown = mapper_api.get_score_breakdown(source_name)
    print(f"  Shape score: {breakdown.shape_score:.3f}")
    print(f"  Name score:  {breakdown.name_score:.3f}")
    print(f"  Hierarchy:   {breakdown.hierarchy_score:.3f}")

    # Access even more detail if needed
    if breakdown.depth_score is not None:
        print(f"    - Depth:   {breakdown.depth_score:.3f}")
    if breakdown.path_score is not None:
        print(f"    - Path:    {breakdown.path_score:.3f}")
    if breakdown.order_score is not None:
        print(f"    - Order:   {breakdown.order_score:.3f}")

## 9.3 Parameter Match Details

Access full information about each parameter match, including match type and any transformations.


In [None]:
print("=" * 80)
print("3. PARAMETER MATCH DETAILS")
print("=" * 80)

param_name = list(mapping_api.keys())[0]
match_details = mapper_api.get_parameter_details(param_name)

print(f"\nParameter: {match_details.source_name}")
print(f"  Matched: {match_details.matched}")
print(f"  Target: {match_details.target_name}")
print(f"  Match type: {match_details.match_type}")
print(f"  Score: {match_details.final_score:.3f}")
if match_details.transformation:
    print(f"  Transformation: {match_details.transformation.type}")

## 9.4 Compatible Group Exploration

Explore all compatible target groups for a given source group to see alternative mappings and their scores.


In [None]:
print("=" * 80)
print("4. COMPATIBLE GROUP EXPLORATION")
print("=" * 80)

# Get all compatible groups for a specific source path
print(f"\nAvailable source groups: {list(mapper_api.source_groups.keys())[:5]}")

# Use an actual group path
if mapper_api.source_groups:
    first_group = list(mapper_api.source_groups.keys())[0]
    compatible = mapper_api.get_compatible_groups(
        source_path=first_group,
        threshold=0.0,  # Show all candidates
        max_candidates=5,
    )

    print(f"\nCompatible target groups for '{first_group}':")
    if first_group in compatible:
        for candidate in compatible[first_group][:3]:
            print(f"  {candidate.target_path}: {candidate.combined_score:.3f}")
            print(f"    Structure: {candidate.structure_score:.3f}")
            print(f"    Matched types: {candidate.param_types_matched}")

## 9.5 Unmatched Parameter Analysis

Analyze which parameters couldn't be matched and understand why.


In [None]:
print("=" * 80)
print("5. UNMATCHED PARAMETERS")
print("=" * 80)

# Get unmatched through mapper (returns dict with 'source' and 'target' keys)
unmatched = mapper_api.get_unmatched()

if unmatched.get("source"):
    print(f"\nUnmatched source parameters: {len(unmatched['source'])}")
    for param in unmatched["source"][:3]:
        details = mapper_api.get_parameter_details(param)
        print(f"  {param}: {details.unmatch_reason}")
else:
    print("\nAll source parameters matched successfully!")

if unmatched.get("target"):
    print(f"\nUnmatched target parameters: {len(unmatched['target'])}")
    for param in unmatched["target"][:3]:
        print(f"  {param}")
else:
    print("All target parameters matched successfully!")

## 9.6 Confidence Filtering

Filter mappings by confidence score to identify high-quality matches or those that need review.
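
The filtering idea itself is simple to sketch on plain `(source, target, score)` triples. The `split_by_confidence` helper and the sample scores below are hypothetical; the notebook's `filter_by_score` and `get_low_confidence_matches` methods operate on `MappingResult` objects instead.

```python
Match = tuple[str, str, float]  # (source, target, score)


def split_by_confidence(
    matches: list[Match], high: float = 0.8, low: float = 0.6
) -> tuple[list[Match], list[Match]]:
    """Return (confident, needs_review) using score >= high and score < low."""
    confident = [m for m in matches if m[2] >= high]
    needs_review = [m for m in matches if m[2] < low]
    return confident, needs_review


matches: list[Match] = [
    ("backbone.0.weight", "feature_extractor.0.weight", 0.92),
    ("backbone.1.weight", "feature_extractor.1.weight", 0.70),
    ("head.weight", "classifier.weight", 0.55),
]
confident, needs_review = split_by_confidence(matches)
print(f"confident: {[m[0] for m in confident]}")
print(f"needs review: {[m[0] for m in needs_review]}")
```

Scores between the two thresholds fall into neither list, which leaves a middle band for matches that are acceptable but not flagged.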


In [None]:
print("=" * 80)
print("6. CONFIDENCE FILTERING")
print("=" * 80)

# Get only high-confidence matches
high_confidence = result.filter_by_score(min_score=0.8)
print(f"\nHigh confidence matches (score >= 0.8): {len(high_confidence.matched_params)}")

# Get low-confidence matches that might need review
low_confidence = result.get_low_confidence_matches(threshold=0.6)
print(f"Low confidence matches (score < 0.6): {len(low_confidence)}")

if low_confidence:
    print("\nLow confidence matches (may need review):")
    for source, target, score in low_confidence[:3]:
        print(f"  {source} → {target} (score: {score:.3f})")

print("\n" + "=" * 80)
print("✓ New API demonstration complete!")
print("=" * 80)

### Key Takeaways from the New API

The dataclass-based API provides:

1. **Structured Results**: `MappingResult` dataclass with typed fields and convenience methods
2. **Score Breakdown**: Detailed scoring components (shape, name, hierarchy, depth, path, order)
3. **Parameter Details**: Full match information including match type and transformations
4. **Group Exploration**: View all compatible group candidates with scores
5. **Enhanced Filtering**: Filter by confidence score to identify matches needing review
6. **Better Debugging**: Access to unmatch reasons and detailed scoring information

This makes it much easier to understand, debug, and validate weight mappings between models.


---

# Cleanup


In [None]:
# Clean up temporary files
checkpoint_file = Path(checkpoint_path)
if checkpoint_file.exists():
    checkpoint_file.unlink()
    print(f"Cleaned up {checkpoint_path}")

report_file = Path(output_path)
if report_file.exists():
    report_file.unlink()
    print(f"Cleaned up {output_path}")

---

# Summary

This notebook demonstrated all key features of the `WeightMapper` class:

## 1. **Basic Weight Mapping**

- Map weights between models with different naming conventions
- Automatic parameter matching based on structure similarity

## 2. **Group-Based Parameter Mapping**

- Parameters organized into groups (e.g., conv weight + bias)
- All related parameters mapped together as a cohesive unit

## 3. **Hierarchical Structure Extraction**

- Parent-child relationships in model architecture are preserved
- Hierarchical context improves matching accuracy

## 4. **Batch Normalization Support**

- Handles BN layers with all their components (weight, bias, running_mean, running_var)
- All buffers and parameters mapped together correctly

## 5. **Multiple Mapping Strategies**

- **Conservative**: Higher threshold for more confident matches
- **Shape-only**: Considers only parameter shapes, ignoring names
- **Best-match**: Balanced approach (default)

## 6. **Confidence Scores**

- Quality metrics for evaluating mapping reliability
- Detailed analysis with scoring information

## 7. **Export Capabilities**

- Save mapping reports as JSON for documentation
- Review and validate mappings before applying

## 8. **State Dict Mapping**

- Work with checkpoints when original model code is unavailable
- Essential for model refactoring and architecture migrations

## 9. **New Dataclass-Based API**

- Structured `MappingResult` dataclass with typed fields and convenience methods
- Detailed score breakdowns (shape, name, hierarchy, depth, path, order)
- Full parameter match details including match type and transformations
- Compatible group exploration to view alternative mappings
- Enhanced confidence filtering and unmatch reason analysis
- Better debugging with comprehensive scoring information

The `WeightMapper` class transfers weights between model architectures whose layer names or structure differ, making it easier to refactor code while preserving learned parameters.
