# 18. Transfer Learning

Strategies for adapting the model to new machines, operations, or domains.

## Contents
1. [Setup](#1-setup)
2. [Feature Extraction](#2-feature-extraction)
3. [Fine-Tuning Strategies](#3-fine-tuning-strategies)
4. [Few-Shot Learning](#4-few-shot-learning)
5. [Domain Adaptation](#5-domain-adaptation)
6. [Multi-Task Transfer](#6-multi-task-transfer)

---

## 1. Setup

In [None]:
import sys
from pathlib import Path

project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / 'src'))

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Optional
import json
import copy
from tqdm.notebook import tqdm

# Report interpreter and framework versions so results can be reproduced.
print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")

# Choose the compute device in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

# Seed the PyTorch and NumPy random generators used throughout the notebook.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Shared matplotlib style for every figure below.
plt.style.use('seaborn-v0_8-whitegrid')

In [None]:
# Load pretrained model
from miracle.model.backbone import MMDTAELSTMBackbone
from miracle.model.multihead_lm import MultiHeadGCodeLM

VOCAB_PATH = project_root / 'data' / 'gcode_vocab_v2.json'
CHECKPOINT_PATH = project_root / 'outputs' / 'final_model' / 'checkpoint_best.pt'

# Read the vocabulary with an explicit encoding so the result does not depend
# on the platform's default locale.
with open(VOCAB_PATH, encoding='utf-8') as f:
    vocab = json.load(f)

if CHECKPOINT_PATH.exists():
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=False)
    config = checkpoint.get('config', {})
else:
    # No checkpoint on disk: fall back to default hyperparameters and leave
    # the models randomly initialized.
    checkpoint = None
    config = {'hidden_dim': 256, 'num_layers': 4, 'num_heads': 8, 'dropout': 0.1}

# Source model (pretrained)
backbone = MMDTAELSTMBackbone(
    continuous_dim=155,
    categorical_dims=[10, 10, 50, 50],
    d_model=config.get('hidden_dim', 256),
    num_layers=config.get('num_layers', 4),
    num_heads=config.get('num_heads', 8),
    dropout=config.get('dropout', 0.1)
).to(device)

lm = MultiHeadGCodeLM(
    d_model=config.get('hidden_dim', 256),
    vocab_sizes=vocab.get('head_vocab_sizes', {'type': 10, 'command': 50, 'param_type': 30, 'param_value': 100})
).to(device)

# Report what actually happened: weights are only restored when a checkpoint
# was found, so the "loaded" message must not be printed unconditionally.
if checkpoint is not None:
    backbone.load_state_dict(checkpoint['backbone_state_dict'])
    lm.load_state_dict(checkpoint['lm_state_dict'])
    print("Pretrained models loaded")
else:
    print(f"Checkpoint not found at {CHECKPOINT_PATH}; using randomly initialized models")

In [None]:
# Create synthetic source and target domain data
def create_domain_data(n_samples, domain_shift=0.0):
    """Generate a random synthetic batch, optionally perturbed to mimic a shifted domain.

    Args:
        n_samples: Number of sequences to generate.
        domain_shift: When positive, the continuous features are scaled by
            ``1 + domain_shift`` and Gaussian noise with standard deviation
            ``domain_shift`` is added. Zero or negative values leave the
            features untouched.

    Returns:
        A ``(continuous, categorical)`` tuple: a float tensor of shape
        ``(n_samples, 64, 155)`` and an integer tensor of shape
        ``(n_samples, 64, 4)`` with values drawn from ``[0, 10)``.
    """
    seq_len, n_continuous, n_categorical = 64, 155, 4

    continuous = torch.randn(n_samples, seq_len, n_continuous)
    if domain_shift > 0:
        # Domain-specific transformation: rescale, then add noise.
        noise = torch.randn_like(continuous)
        continuous = (1 + domain_shift) * continuous + domain_shift * noise

    categorical = torch.randint(0, 10, (n_samples, seq_len, n_categorical))
    return continuous, categorical

# Source domain (original): 100 sequences with no shift applied
source_cont, source_cat = create_domain_data(100, domain_shift=0.0)

# Target domain (shifted): only 20 sequences, generated with domain_shift=0.3
target_cont, target_cat = create_domain_data(20, domain_shift=0.3)

# Show the shapes of the continuous feature tensors for both domains
print(f"Source domain: {source_cont.shape}")
print(f"Target domain: {target_cont.shape}")

## 2. Feature Extraction

Use the pretrained backbone as a fixed (frozen) feature extractor.

In [None]:
class FeatureExtractor(nn.Module):
    """Wrap a pretrained backbone so it can be used as a feature extractor.

    Args:
        backbone: Module called as ``backbone(continuous, categorical)``.
        freeze: If True (default), gradients are disabled for every backbone
            parameter and the backbone is kept in eval mode, so stochastic
            layers such as dropout cannot perturb the extracted features.
    """

    def __init__(self, backbone, freeze=True):
        super().__init__()
        self.backbone = backbone
        self.frozen = freeze

        if freeze:
            for param in self.backbone.parameters():
                param.requires_grad = False
            # A frozen extractor must also be deterministic.
            self.backbone.eval()

    def train(self, mode=True):
        """Set the training mode, keeping a frozen backbone in eval mode."""
        super().train(mode)
        if self.frozen:
            # super().train() switches every child; undo that for the backbone.
            self.backbone.eval()
        return self

    def forward(self, continuous, categorical):
        """Return the backbone output for the given inputs."""
        return self.backbone(continuous, categorical)

    def extract_features(self, continuous, categorical, layer=-1):
        """Return the backbone output for the given inputs.

        ``layer`` is accepted for interface compatibility but is currently
        ignored: the backbone's output is always returned.
        """
        return self.backbone(continuous, categorical)


class TransferHead(nn.Module):
    """Two-layer MLP classifier to be trained on target-domain features.

    Args:
        d_model: Size of the incoming feature vectors.
        num_classes: Number of output logits.
    """

    def __init__(self, d_model, num_classes):
        super().__init__()
        hidden_dim = d_model // 2
        layers = [
            nn.Linear(d_model, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, features):
        """Map features with last dimension ``d_model`` to class logits."""
        logits = self.classifier(features)
        return logits


# Create feature extractor
feature_extractor = FeatureExtractor(backbone, freeze=True)
new_head = TransferHead(config.get('hidden_dim', 256), num_classes=20).to(device)

# Inference only: switch to eval mode so dropout does not add noise to the
# extracted features.
feature_extractor.eval()

# Extract features for the first 10 sequences of each domain
with torch.no_grad():
    source_features = feature_extractor(source_cont[:10].to(device), source_cat[:10].to(device))
    target_features = feature_extractor(target_cont[:10].to(device), target_cat[:10].to(device))

print(f"Source features: {source_features.shape}")
print(f"Target features: {target_features.shape}")

In [None]:
# Visualize feature distributions
from sklearn.decomposition import PCA

# Collapse every leading dimension so each row is one feature vector.
# reshape() also works on non-contiguous tensors, where view() would raise.
source_flat = source_features.reshape(-1, source_features.shape[-1]).cpu().numpy()
target_flat = target_features.reshape(-1, target_features.shape[-1]).cpu().numpy()

# PCA on at most 500 rows per domain, fitted jointly so both share one projection
pca = PCA(n_components=2)
n_source = len(source_flat[:500])
combined = np.vstack([source_flat[:500], target_flat[:500]])
combined_pca = pca.fit_transform(combined)

# Split at the actual number of source rows; a fixed 500 would mislabel
# points whenever the source domain contributes fewer rows.
source_pca = combined_pca[:n_source]
target_pca = combined_pca[n_source:]

# Visualize
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(source_pca[:, 0], source_pca[:, 1], alpha=0.5, label='Source Domain', s=10)
ax.scatter(target_pca[:, 0], target_pca[:, 1], alpha=0.5, label='Target Domain', s=10)
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('Feature Space: Source vs Target Domain')
ax.legend()

plt.tight_layout()
# Make sure the output directory exists before saving.
(project_root / 'reports').mkdir(parents=True, exist_ok=True)
plt.savefig(project_root / 'reports' / 'domain_features.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Fine-Tuning Strategies

Different approaches to fine-tuning pretrained models.

In [None]:
class LayerWiseLR:
    """Build optimizer parameter groups whose learning rate shrinks toward the input side."""

    @staticmethod
    def get_param_groups(model, base_lr, decay_factor=0.9):
        """Return one optimizer group per parameter tensor with a decayed learning rate.

        Each entry of ``model.named_parameters()`` is treated as its own
        "layer". The last entry receives ``base_lr``; every step toward the
        first entry multiplies the rate by ``decay_factor`` once more.

        Args:
            model: Module whose parameters are grouped.
            base_lr: Learning rate of the last parameter tensor.
            decay_factor: Multiplicative decay per step toward the first tensor.

        Returns:
            A list of dicts with keys ``'params'``, ``'lr'`` and ``'name'``,
            in ``named_parameters()`` order.
        """
        named = list(model.named_parameters())
        last_index = len(named) - 1

        return [
            {
                'params': [tensor],
                'lr': base_lr * (decay_factor ** (last_index - position)),
                'name': param_name,
            }
            for position, (param_name, tensor) in enumerate(named)
        ]


class GradualUnfreezing:
    """Freeze a model up front, then unfreeze its trailing parameters on a schedule.

    Args:
        model: Module whose parameters are all frozen at construction.
        unfreeze_schedule: Mapping from epoch (an int or int-like string) to
            the number of trailing parameter tensors, in
            ``model.parameters()`` order, that should be trainable from that
            epoch onward.
    """

    def __init__(self, model, unfreeze_schedule):
        self.model = model
        self.schedule = unfreeze_schedule
        self.current_epoch = 0

        # Initially freeze all
        for param in model.parameters():
            param.requires_grad = False

    def step(self, epoch):
        """Unfreeze parameters according to the schedule entries reached by ``epoch``.

        Parameters are only ever unfrozen here, never re-frozen. A count of
        zero (or less) unfreezes nothing, and counts larger than the number
        of parameter tensors unfreeze the whole model.
        """
        self.current_epoch = epoch

        params = list(self.model.parameters())

        # Only the largest count among the reached entries matters, because
        # each entry unfreezes a suffix of the same parameter list.
        reached = [
            count
            for start_epoch, count in self.schedule.items()
            if epoch >= int(start_epoch)
        ]
        n_unfreeze = min(max(reached, default=0), len(params))

        # Guard against a zero count: params[-0:] is the whole list and would
        # unfreeze everything.
        if n_unfreeze <= 0:
            return

        for param in params[-n_unfreeze:]:
            param.requires_grad = True


class DiscriminativeFineTuning:
    """Create an AdamW optimizer whose backbone learning rates shrink toward the input.

    Args:
        backbone: Module whose parameter tensors each get their own, reduced
            learning rate.
        lm: Module whose parameters all train at ``base_lr``.
        base_lr: Learning rate of the ``lm`` parameters.
        ratio: Divisor applied once per backbone parameter tensor, counted
            from the end of the backbone.
    """

    def __init__(self, backbone, lm, base_lr=1e-4, ratio=2.6):
        self.backbone = backbone
        self.lm = lm
        self.base_lr = base_lr
        self.ratio = ratio

    def get_optimizer(self):
        """Return AdamW with one group per backbone tensor plus one group for the LM head.

        The backbone tensor at position ``i`` of ``n`` receives
        ``base_lr / ratio ** (n - i)``, so the first tensor has the smallest
        rate and even the last backbone tensor trains below ``base_lr``.
        """
        backbone_tensors = list(self.backbone.parameters())
        depth = len(backbone_tensors)

        # One group per backbone tensor, smallest learning rate first.
        groups = [
            {
                'params': [tensor],
                'lr': self.base_lr / (self.ratio ** (depth - position)),
            }
            for position, tensor in enumerate(backbone_tensors)
        ]

        # The LM head trains at the undiscounted base rate.
        groups.append({'params': list(self.lm.parameters()), 'lr': self.base_lr})

        return torch.optim.AdamW(groups)


# Demonstrate fine-tuning setup
fine_tuner = DiscriminativeFineTuning(backbone, lm, base_lr=1e-4)
optimizer = fine_tuner.get_optimizer()

# The optimizer has one group per backbone parameter tensor plus one LM group.
# param_groups[0] is the first backbone tensor (smallest LR) and
# param_groups[-1] is the LM head group (base LR).
print(f"Number of parameter groups: {len(optimizer.param_groups)}")
print(f"LR range: {optimizer.param_groups[0]['lr']:.2e} to {optimizer.param_groups[-1]['lr']:.2e}")

## 4. Few-Shot Learning

Learn from limited target domain examples.

In [None]:
class PrototypicalNetwork:
    """Nearest-prototype few-shot classifier on top of a feature extractor.

    Args:
        feature_extractor: Callable invoked as
            ``feature_extractor(continuous, categorical)``; its output is
            mean-pooled over dimension 1 to obtain one embedding per sample.
    """

    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor
        self.prototypes = {}

    def _embed(self, data):
        """Run the extractor without gradients and mean-pool over dimension 1."""
        continuous, categorical = data
        with torch.no_grad():
            features = self.feature_extractor(continuous, categorical)
        return features.mean(dim=1)

    def compute_prototypes(self, support_data, support_labels):
        """Compute one prototype per class from the support set.

        Prototypes from any earlier call are discarded, so classes that are
        absent from the current support set cannot be predicted.

        Args:
            support_data: ``(continuous, categorical)`` tuple for the extractor.
            support_labels: 1-D tensor with one class label per support sample.

        Returns:
            Dict mapping each label (as a Python scalar) to its prototype,
            the mean pooled embedding of that class.
        """
        pooled = self._embed(support_data)
        # Keep the labels on the same device as the embeddings they index.
        labels = torch.as_tensor(support_labels, device=pooled.device)

        prototypes = {}
        for label in torch.unique(labels):
            prototypes[label.item()] = pooled[labels == label].mean(dim=0)

        # Replace rather than update, so stale classes do not linger.
        self.prototypes = prototypes
        return self.prototypes

    def predict(self, query_data):
        """Assign each query sample the label of its nearest prototype.

        Args:
            query_data: ``(continuous, categorical)`` tuple for the extractor.

        Returns:
            1-D CPU tensor with one predicted label per query sample.

        Raises:
            RuntimeError: If no prototypes have been computed yet.
        """
        if not self.prototypes:
            raise RuntimeError("No prototypes available; call compute_prototypes() first")

        pooled = self._embed(query_data)

        labels = list(self.prototypes)
        proto_matrix = torch.stack([self.prototypes[label] for label in labels])
        proto_matrix = proto_matrix.to(pooled.device)

        # One pairwise Euclidean distance matrix: rows are queries, columns prototypes.
        nearest = torch.cdist(pooled, proto_matrix).argmin(dim=1)

        return torch.tensor(labels)[nearest.cpu()]


# Demo few-shot learning
proto_net = PrototypicalNetwork(feature_extractor)

# Create synthetic support set (5-shot, 4-way)
n_way = 4
n_shot = 5

# The first n_way * n_shot target sequences form the support set.
support_cont = target_cont[:n_way * n_shot].to(device)
support_cat = target_cat[:n_way * n_shot].to(device)
# Consecutive runs of n_shot samples share a class: 0,0,0,0,0,1,1,...
support_labels = torch.arange(n_way).repeat_interleave(n_shot)

# Compute prototypes
prototypes = proto_net.compute_prototypes((support_cont, support_cat), support_labels)

print(f"Computed {len(prototypes)} prototypes")
print(f"Prototype shape: {next(iter(prototypes.values())).shape}")

In [None]:
# Visualize prototypes
proto_features = torch.stack(list(prototypes.values())).cpu().numpy()

# PCA on prototypes (one row per class)
pca = PCA(n_components=2)
proto_pca = pca.fit_transform(proto_features)

fig, ax = plt.subplots(figsize=(8, 6))
# prototypes and proto_pca share the same row order, so zip pairs each
# class label with its projected coordinates.
for label, coords in zip(prototypes.keys(), proto_pca):
    ax.scatter(coords[0], coords[1], s=200, marker='*', label=f'Class {label}')
    ax.annotate(f'Class {label}', (coords[0], coords[1]), textcoords='offset points', xytext=(5, 5))

ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('Class Prototypes in Feature Space')

plt.tight_layout()
# Make sure the output directory exists before saving.
(project_root / 'reports').mkdir(parents=True, exist_ok=True)
plt.savefig(project_root / 'reports' / 'few_shot_prototypes.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Domain Adaptation

Adapt to the target domain without using target-domain labels.

In [None]:
class DomainAdversarialNetwork(nn.Module):
    """Backbone plus a source-vs-target domain classifier behind a gradient-reversal layer (DANN).

    Args:
        backbone: Module called as ``backbone(continuous, categorical)``.
        d_model: Size of the backbone's feature vectors.
    """

    def __init__(self, backbone, d_model):
        super().__init__()
        self.backbone = backbone

        # Binary domain classifier: source vs target.
        half_dim = d_model // 2
        self.domain_classifier = nn.Sequential(
            nn.Linear(d_model, half_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(half_dim, 2),
        )

    def forward(self, continuous, categorical, alpha=1.0):
        """Return ``(features, domain_output)``.

        ``features`` is the raw backbone output. ``domain_output`` holds the
        two domain logits computed from the features mean-pooled over
        dimension 1; ``alpha`` scales the reversed gradient flowing back from
        the domain classifier.
        """
        features = self.backbone(continuous, categorical)

        # Average over dimension 1, then reverse gradients on the way back.
        pooled = features.mean(dim=1)
        domain_output = self.domain_classifier(GradientReversal.apply(pooled, alpha))

        return features, domain_output


class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by ``-alpha`` in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        """Remember ``alpha`` for the backward pass and return a view of ``x``."""
        ctx.alpha = alpha
        return x.view(x.size())

    @staticmethod
    def backward(ctx, grad_output):
        """Return the negated, ``alpha``-scaled gradient for ``x`` and ``None`` for ``alpha``."""
        reversed_grad = grad_output.neg() * ctx.alpha
        return reversed_grad, None


class MMDLoss(nn.Module):
    """Maximum Mean Discrepancy (biased estimate) between two feature batches.

    Args:
        kernel: ``'rbf'`` (default) for a Gaussian kernel or ``'linear'`` for
            a dot-product kernel.
        sigma: Bandwidth of the Gaussian kernel; unused for ``'linear'``.

    Raises:
        ValueError: If ``kernel`` is not one of the supported names.
    """

    _SUPPORTED_KERNELS = ('rbf', 'linear')

    def __init__(self, kernel='rbf', sigma=1.0):
        super().__init__()
        self._check_kernel(kernel)
        self.kernel = kernel
        self.sigma = sigma

    @classmethod
    def _check_kernel(cls, kernel):
        """Reject kernel names that would otherwise be silently ignored."""
        if kernel not in cls._SUPPORTED_KERNELS:
            raise ValueError(
                f"Unsupported kernel {kernel!r}; expected one of {cls._SUPPORTED_KERNELS}"
            )

    def gaussian_kernel(self, x, y):
        """Return the Gaussian kernel matrices ``(k_xx, k_yy, k_xy)`` for 2-D ``x`` and ``y``."""
        xx = torch.cdist(x, x)
        yy = torch.cdist(y, y)
        xy = torch.cdist(x, y)

        k_xx = torch.exp(-xx ** 2 / (2 * self.sigma ** 2))
        k_yy = torch.exp(-yy ** 2 / (2 * self.sigma ** 2))
        k_xy = torch.exp(-xy ** 2 / (2 * self.sigma ** 2))

        return k_xx, k_yy, k_xy

    def linear_kernel(self, x, y):
        """Return the dot-product kernel matrices ``(k_xx, k_yy, k_xy)`` for 2-D ``x`` and ``y``."""
        return x @ x.T, y @ y.T, x @ y.T

    def forward(self, source_features, target_features):
        """Compute the MMD between two batches of feature vectors.

        Args:
            source_features: 2-D tensor with one row per source sample.
            target_features: 2-D tensor with one row per target sample.

        Returns:
            Scalar tensor: mean(k_xx) + mean(k_yy) - 2 * mean(k_xy).
        """
        # Re-check here in case the attribute was changed after construction.
        self._check_kernel(self.kernel)
        if self.kernel == 'linear':
            k_xx, k_yy, k_xy = self.linear_kernel(source_features, target_features)
        else:
            k_xx, k_yy, k_xy = self.gaussian_kernel(source_features, target_features)

        m = source_features.size(0)
        n = target_features.size(0)

        mmd = (k_xx.sum() / (m * m) +
               k_yy.sum() / (n * n) -
               2 * k_xy.sum() / (m * n))

        return mmd


# Compute MMD between domains
mmd_loss = MMDLoss(sigma=1.0)

# Average the previously extracted features over dimension 1 so that each
# sample is represented by a single vector.
source_pooled = source_features.mean(dim=1)
target_pooled = target_features.mean(dim=1)

# The loss returns a scalar tensor; .item() converts it for printing.
mmd = mmd_loss(source_pooled, target_pooled)
print(f"MMD between source and target: {mmd.item():.4f}")

## 6. Multi-Task Transfer

Transfer knowledge across related tasks.

In [None]:
class MultiTaskTransfer(nn.Module):
    """Shared backbone with a residual adapter and a prediction head per task.

    Args:
        backbone: Module called as ``backbone(continuous, categorical)``.
        d_model: Size of the backbone's feature vectors.
        task_configs: Mapping from task name to its number of output classes.
    """

    def __init__(self, backbone, d_model, task_configs):
        super().__init__()
        self.backbone = backbone

        # One two-layer classification head per task.
        self.task_heads = nn.ModuleDict()
        for task, n_classes in task_configs.items():
            self.task_heads[task] = nn.Sequential(
                nn.Linear(d_model, d_model // 2),
                nn.ReLU(),
                nn.Linear(d_model // 2, n_classes),
            )

        # One bottleneck adapter (d_model -> 64 -> d_model) per task.
        self.adapters = nn.ModuleDict()
        for task in task_configs:
            self.adapters[task] = nn.Sequential(
                nn.Linear(d_model, 64),
                nn.ReLU(),
                nn.Linear(64, d_model),
            )

    def forward(self, continuous, categorical, task_name):
        """Return the logits of ``task_name`` for the given inputs.

        The shared backbone features are adjusted by the task's adapter via a
        residual connection and then passed to the task's head.
        """
        shared = self.backbone(continuous, categorical)
        residual = self.adapters[task_name](shared)
        return self.task_heads[task_name](shared + residual)


class AdapterModule(nn.Module):
    """Residual bottleneck adapter with a learnable output scale.

    Args:
        d_model: Size of the input and output feature vectors.
        bottleneck_dim: Width of the hidden bottleneck layer.
    """

    def __init__(self, d_model, bottleneck_dim=64):
        super().__init__()
        self.down_proj = nn.Linear(d_model, bottleneck_dim)
        self.activation = nn.ReLU()
        self.up_proj = nn.Linear(bottleneck_dim, d_model)
        # Scalar gate on the adapter branch, initialised to 1.
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Return ``x`` plus the scaled down-project / ReLU / up-project branch."""
        bottleneck = self.activation(self.down_proj(x))
        delta = self.up_proj(bottleneck)
        return x + self.scale * delta


# Demo multi-task setup: task name -> number of output classes
task_configs = {
    'machine_id': 5,      # Identify machine
    'operation_type': 10, # Classify operation
    'anomaly': 2,         # Binary anomaly detection
}

multi_task = MultiTaskTransfer(
    backbone,
    config.get('hidden_dim', 256),
    task_configs,
).to(device)

# Run the same five target sequences through every task head and report the
# shape of each task's output.
with torch.no_grad():
    for task in task_configs:
        output = multi_task(target_cont[:5].to(device), target_cat[:5].to(device), task)
        print(f"{task}: {output.shape}")

In [None]:
# Visualize transfer learning configurations
# NOTE: every number plotted below is a hard-coded illustrative value, not a
# measurement produced by this notebook.
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Strategy comparison
strategies = ['Feature\nExtraction', 'Full\nFine-tuning', 'Gradual\nUnfreezing', 'Discriminative\nLR']
expected_perf = [0.7, 0.85, 0.88, 0.9]
training_cost = [0.1, 1.0, 0.6, 0.8]

x = np.arange(len(strategies))
width = 0.35

axes[0, 0].bar(x - width/2, expected_perf, width, label='Expected Accuracy', color='steelblue')
axes[0, 0].bar(x + width/2, training_cost, width, label='Training Cost', color='coral')
axes[0, 0].set_ylabel('Score')
axes[0, 0].set_title('Fine-tuning Strategy Comparison')
axes[0, 0].set_xticks(x)
axes[0, 0].set_xticklabels(strategies)
axes[0, 0].legend()

# Few-shot performance
shots = [1, 2, 5, 10, 20]
perf = [0.4, 0.55, 0.72, 0.82, 0.88]
axes[0, 1].plot(shots, perf, 'o-', linewidth=2, markersize=8)
axes[0, 1].set_xlabel('Number of Shots (K)')
axes[0, 1].set_ylabel('Accuracy')
axes[0, 1].set_title('Few-Shot Learning Performance')
axes[0, 1].set_xscale('log')

# Domain adaptation effect
epochs = list(range(1, 21))
source_acc = [0.9] * 20
target_before = [0.5 + 0.01 * e for e in epochs]
target_after = [0.5 + 0.02 * e for e in epochs]

axes[1, 0].plot(epochs, source_acc, 'b-', label='Source Domain')
axes[1, 0].plot(epochs, target_before, 'r--', label='Target (no adaptation)')
axes[1, 0].plot(epochs, target_after, 'g-', label='Target (with DANN)')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Accuracy')
axes[1, 0].set_title('Domain Adaptation Effect')
axes[1, 0].legend()

# Multi-task benefits
tasks = ['Machine\nID', 'Operation\nType', 'Anomaly']
single_task = [0.75, 0.80, 0.70]
multi_task_perf = [0.82, 0.85, 0.78]

# x is rebuilt because this panel has a different number of bars.
x = np.arange(len(tasks))
axes[1, 1].bar(x - width/2, single_task, width, label='Single-task', color='coral')
axes[1, 1].bar(x + width/2, multi_task_perf, width, label='Multi-task', color='forestgreen')
axes[1, 1].set_ylabel('Accuracy')
axes[1, 1].set_title('Multi-Task Transfer Benefits')
axes[1, 1].set_xticks(x)
axes[1, 1].set_xticklabels(tasks)
axes[1, 1].legend()

plt.tight_layout()
# Make sure the output directory exists before saving.
(project_root / 'reports').mkdir(parents=True, exist_ok=True)
plt.savefig(project_root / 'reports' / 'transfer_learning_summary.png', dpi=150, bbox_inches='tight')
plt.show()

---

## Summary

This notebook covers transfer learning strategies:

1. **Feature Extraction**: Use frozen backbone as feature extractor
2. **Fine-Tuning**: Discriminative LR, gradual unfreezing
3. **Few-Shot**: Prototypical networks for limited data
4. **Domain Adaptation**: DANN, MMD for unsupervised adaptation
5. **Multi-Task**: Shared backbone with task-specific adapters

---

**Navigation:**
← [Previous: 17_uncertainty_quantification](17_uncertainty_quantification.ipynb) |
[Next: 19_streaming_inference](19_streaming_inference.ipynb) →