In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm, trange
import torch.optim as optim
from torchvision import models
import numpy as np
from collections import deque
import random
import time
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader, Subset
import torchvision.models.segmentation as segmentation
from collections import defaultdict

### ForgettingGatedSplitModel

In [None]:
import os
import sys
import cv2
import pickle
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics import YOLO
from tqdm.notebook import tqdm, trange
from typing import Tuple, Optional
from collections import deque

sys.path.append('D:/projects/Depth-Anything-V2')
from depth_anything_v2.dpt import DepthAnythingV2

#### backbone

In [None]:
class YOLO11FeatureExtractor(nn.Module):
    """Frozen feature extractor built from the first nine layers (0-8) of a YOLO11-seg model.

    The parameters are frozen (``requires_grad = False``) and the wrapped layers are
    kept permanently in eval mode, so calling ``.train()`` on a parent model does not
    switch normalisation/dropout layers inside the backbone back to training
    behaviour (which would silently update running statistics of a "frozen" network).

    Args:
        model_path: Path to the YOLO checkpoint passed to ``ultralytics.YOLO``.
    """
    def __init__(self, model_path='../models/yolo11n-seg.pt'):
        super().__init__()
        yolo = YOLO(model_path)

        # Keep layers 0-8 only (stop before SPPF to preserve spatial granularity).
        self.backbone = nn.Sequential(*list(yolo.model.model[:9]))
        del yolo

        # Freeze parameters
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Frozen weights are not enough: layers with running statistics must also
        # stay in eval mode.
        self.backbone.eval()

    def train(self, mode: bool = True):
        """Set the module's training flag but always keep the frozen backbone in eval mode."""
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Run ``x`` through the truncated backbone.

        Per the original author's note the output is [B, 256, H/32, W/32] for the
        yolo11n checkpoint (not verifiable from this file alone).
        """
        return self.backbone(x)

class DepthAnythingFeatureExtractor(nn.Module):
    """Frozen feature extractor using the ``pretrained`` encoder of DepthAnythingV2 (vits).

    Args:
        model_path: Path to the DepthAnythingV2 state dict loaded onto the CPU.
    """

    # Side length (in pixels) of one patch; the spatial output grid is
    # (H // PATCH_SIZE, W // PATCH_SIZE).
    PATCH_SIZE = 14

    def __init__(self, model_path='../models/depth_anything_v2_vits.pth'):
        super().__init__()
        model_configs = {
            'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        }
        model = DepthAnythingV2(**model_configs['vits'])
        model.load_state_dict(torch.load(model_path, map_location='cpu'))

        # Only the encoder is kept; the depth head is discarded.
        self.backbone = model.pretrained
        del model

        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return patch tokens reshaped to a spatial map.

        Args:
            x: Image batch [B, C, H, W].

        Returns:
            Tensor [B, D, H // PATCH_SIZE, W // PATCH_SIZE], where D is the token
            embedding width produced by the backbone (384 for the vits config,
            according to the original code).
        """
        B, _, H, W = x.shape
        patch_h, patch_w = H // self.PATCH_SIZE, W // self.PATCH_SIZE
        num_patches = patch_h * patch_w

        with torch.no_grad():
            # NOTE(review): the tokens go straight from patch_embed into the blocks;
            # no class token or positional embedding is added in this method.
            # Confirm this is intended for the encoder being used.
            tokens = self.backbone.patch_embed(x)

            # Pass through transformer blocks
            for block in self.backbone.blocks:
                tokens = block(tokens)

            # Apply final norm
            tokens = self.backbone.norm(tokens)

            # Keep only the trailing patch tokens. Any extra leading tokens (e.g. a
            # CLS token) are dropped; when there are none this is a no-op.
            patch_tokens = tokens[:, -num_patches:, :]

            # [B, N, D] -> [B, D, patch_h, patch_w]; the channel count is taken from
            # the tokens instead of being hard-coded.
            embed_dim = patch_tokens.shape[-1]
            spatial_features = patch_tokens.transpose(1, 2).reshape(B, embed_dim, patch_h, patch_w)

            return spatial_features


class DualBackboneFeatureExtractor(nn.Module):
    """Combines YOLO11 and DepthAnything feature extractors into one feature map.

    Args:
        model_paths: Mapping with the keys ``'YOLO11n_seg'`` and
            ``'DepthAnythingV2_small'`` giving the checkpoint paths of the two
            backbones.
    """
    def __init__(self, model_paths):
        super().__init__()
        # Individual feature extractors
        self.yolo_extractor = YOLO11FeatureExtractor(model_paths['YOLO11n_seg'])
        self.depth_extractor = DepthAnythingFeatureExtractor(model_paths['DepthAnythingV2_small'])

        # 1x1 projection layers (channel count preserved: 256 and 384)
        self.yolo_proj = nn.Conv2d(256, 256, 1)
        self.depth_proj = nn.Conv2d(384, 384, 1)

    def forward(self, x):
        """Extract, project and concatenate both feature maps.

        Args:
            x: Indexable pair ``(yolo_input, depth_input)``; ``x[0]`` is fed to the
                YOLO extractor and ``x[1]`` to the depth extractor.

        Returns:
            Tensor [B, 256 + 384, h, w] where (h, w) is the spatial size of the
            YOLO feature map.
        """
        # Extract features from both backbones
        yolo_features = self.yolo_extractor(x[0])
        depth_features = self.depth_extractor(x[1])

        # Apply projections
        yolo_features = self.yolo_proj(yolo_features)
        depth_features = self.depth_proj(depth_features)

        # Resample the depth features onto the YOLO grid. Using the actual YOLO
        # spatial size (instead of a fixed 7x7) keeps the concatenation valid for
        # any input resolution; for 224x224 inputs this is still 7x7.
        depth_features = F.adaptive_avg_pool2d(depth_features, yolo_features.shape[-2:])

        # Concatenate along the channel axis
        combined_features = torch.cat([yolo_features, depth_features], dim=1)

        return combined_features

#### model components

In [None]:
class CatastrophicForgettingAdapter(nn.Module):
    """Standalone adapter that measures forgetting and gates spatial features.

    The adapter head predicts the main task from globally pooled features. The
    change of its first linear layer between two snapshots is turned into one
    score per input channel, and a small gating network maps these scores to a
    per-channel mask in (0, 1) that splits the features into two branches.

    Args:
        num_channels: Number of channels C of the incoming feature map.
        num_outputs: Size of the adapter's task prediction.
    """
    def __init__(self, num_channels=640, num_outputs=9):
        super().__init__()
        self.num_channels = num_channels

        # Forgetting measurement storage (plain attributes, not buffers, so they
        # are not part of state_dict()).
        self.previous_adapter_state = None     # dict name -> tensor snapshot, or None
        self.current_forgetting_scores = None  # tensor [num_channels], or None

        # Adapter that predicts main task (for forgetting measurement)
        self.adapter = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # Global average pooling
            nn.Flatten(),
            nn.Linear(num_channels, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_outputs)  # Predicts main task
        )

        # Learnable gating network (operates on channel-wise forgetting scores)
        self.gating = nn.Sequential(
            nn.Linear(num_channels, num_channels // 4),
            nn.ReLU(),
            nn.Linear(num_channels // 4, num_channels),
            nn.Sigmoid()
        )

    def forward(self, backbone_features, use_gating=True):
        """Split the features into a "forgetting" and a "not forgetting" branch.

        Args:
            backbone_features: [B, C, H, W] spatial features.
            use_gating: Apply the channel mask. The mask is only applied when
                forgetting scores have already been computed.

        Returns:
            forgetting_features: [B, C, H, W] features scaled by the mask.
            not_forgetting_features: [B, C, H, W] features scaled by (1 - mask).
            adapter_predictions: [B, num_outputs] task predictions.

            Without gating both feature outputs are the unmodified input tensor.
        """
        batch_size = backbone_features.shape[0]

        # Adapter makes predictions (for forgetting measurement)
        adapter_predictions = self.adapter(backbone_features)

        if use_gating and self.current_forgetting_scores is not None:
            # Same score vector for every sample in the batch: [B, C]
            forgetting_scores_batch = self.current_forgetting_scores.unsqueeze(0).expand(batch_size, -1)
            gating_mask = self.gating(forgetting_scores_batch)  # [B, C]

            # Reshape for spatial broadcasting: [B, C, 1, 1]
            gating_mask = gating_mask.unsqueeze(2).unsqueeze(3)

            # Route features channel-wise
            forgetting_features = backbone_features * gating_mask
            not_forgetting_features = backbone_features * (1 - gating_mask)
        else:
            # No gating - both branches get all features
            forgetting_features = backbone_features
            not_forgetting_features = backbone_features

        return forgetting_features, not_forgetting_features, adapter_predictions

    def store_adapter_state(self):
        """Snapshot the adapter parameters as the reference for the next score computation."""
        self.previous_adapter_state = {
            name: param.clone().detach()
            for name, param in self.adapter.named_parameters()
        }

    def compute_forgetting_scores(self):
        """Compute channel-wise forgetting scores by comparing adapter states.

        The score of channel c is the L2 norm of the change of column c in the
        first 2-D adapter weight whose input width equals ``num_channels``,
        divided by the largest such norm (so the maximum is 1). Scores are all
        zero when no snapshot exists, no matching weight is found, or nothing
        changed.

        The result is stored in ``current_forgetting_scores`` as a tensor that is
        detached from autograd: the scores are statistics, not part of the
        training graph. Keeping them attached would make every later forward pass
        back-propagate into (and re-use) the graph created here.
        """
        device = next(self.parameters()).device
        scores = torch.zeros(self.num_channels, device=device)

        if self.previous_adapter_state is not None:
            with torch.no_grad():
                for name, current_param in self.adapter.named_parameters():
                    previous_param = self.previous_adapter_state.get(name)
                    if previous_param is None:
                        continue

                    maps_from_channels = (
                        'weight' in name
                        and current_param.dim() == 2
                        and current_param.shape[1] == self.num_channels
                    )
                    if not maps_from_channels:
                        continue

                    # Change per input channel (column-wise norm): [num_channels]
                    scores = torch.norm(
                        current_param - previous_param.to(current_param.device), dim=0
                    ).to(device)
                    break  # Use only the first layer that maps from channels

                # Normalize to [0, 1]
                peak = scores.max()
                if peak > 0:
                    scores = scores / peak

        self.current_forgetting_scores = scores.detach()

class ConvBranch(nn.Module):
    """Three 3x3 conv stages followed by global average pooling.

    Each stage is Conv2d -> BatchNorm2d -> ReLU; the first two stages additionally
    end with Dropout2d(0.1). The result is pooled to 1x1 and flattened, giving a
    [B, output_channels] vector per sample.
    """
    def __init__(self, input_channels=640, hidden_channels=256, output_channels=128):
        super().__init__()

        # (in_channels, out_channels, dropout probability or None) per stage
        stage_specs = (
            (input_channels, hidden_channels, 0.1),
            (hidden_channels, hidden_channels, 0.1),
            (hidden_channels, output_channels, None),
        )

        modules = []
        for in_ch, out_ch, drop_p in stage_specs:
            modules.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))
            modules.append(nn.BatchNorm2d(out_ch))
            modules.append(nn.ReLU(inplace=True))
            if drop_p is not None:
                modules.append(nn.Dropout2d(drop_p))

        # Collapse the spatial dimensions so the output has a fixed size.
        modules.append(nn.AdaptiveAvgPool2d(1))
        modules.append(nn.Flatten())

        self.conv_layers = nn.Sequential(*modules)

    def forward(self, x):
        """Map [B, input_channels, H, W] to [B, output_channels]."""
        pooled = self.conv_layers(x)
        return pooled
    
    
class FeatureFusion(nn.Module):
    """Blend two feature vectors with a learned, per-sample softmax weighting.

    A small MLP looks at both vectors and produces two weights that sum to one;
    the output is the correspondingly weighted sum of the inputs.
    """
    def __init__(self, feature_dim=256):
        super().__init__()

        scorer_layers = [
            nn.Linear(feature_dim * 2, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 2),
            nn.Softmax(dim=1),
        ]
        self.attention = nn.Sequential(*scorer_layers)

    def forward(self, high_features, low_features):
        """Fuse two [B, feature_dim] tensors into one [B, feature_dim] tensor."""
        # Score both branches jointly: [B, 2 * feature_dim] -> [B, 2]
        branch_weights = self.attention(torch.cat((high_features, low_features), dim=1))

        # One [B, 1] weight column per branch, broadcast over the feature axis
        high_weight, low_weight = branch_weights.split(1, dim=1)

        return high_features * high_weight + low_features * low_weight


#### model

In [None]:
class CatastrophicForgettingDisentanglementModel(nn.Module):
    """Dual-backbone model with a forgetting-gated two-branch head.

    Pipeline: dual backbone -> forgetting adapter (channel gating + auxiliary task
    prediction) -> two convolutional branches -> attention fusion -> MLP head.

    Args:
        backbone_output_channels: Channel count of the combined backbone features.
        branch_hidden_channels: Hidden width of both convolutional branches.
        branch_output_channels: Output width of both branches (and fusion width).
        num_outputs: Size of the final prediction and of the adapter prediction.
        model_paths: Optional mapping that overrides the default checkpoint
            locations. Recognised keys: ``'YOLO11n_seg'`` and
            ``'DepthAnythingV2_small'``; keys that are not given keep their default.

    Raises:
        FileNotFoundError: If any checkpoint file does not exist.
    """

    # Checkpoint locations used when the caller does not override them.
    DEFAULT_MODEL_PATHS = {
        'YOLO11n_seg': '../models/yolo11n-seg.pt',
        'DepthAnythingV2_small': '../models/depth_anything_v2_vits.pth'
    }

    def __init__(self,
                 backbone_output_channels: int = 640,
                 branch_hidden_channels: int = 256,
                 branch_output_channels: int = 128,
                 num_outputs: int = 9,
                 model_paths: Optional[dict] = None):
        super().__init__()

        print("🚀 Initializing Spatial Catastrophic Forgetting Model...")

        # Merge caller overrides over the defaults, then fail early (before any
        # model is loaded) if a checkpoint is missing.
        model_paths = {**self.DEFAULT_MODEL_PATHS, **(model_paths or {})}
        missing = [name for name, path in model_paths.items() if not os.path.exists(path)]
        if missing:
            raise FileNotFoundError(f"Missing model files for: {', '.join(missing)}")

        # Dual backbone producing spatial features [B, C, h, w]
        self.backbone = DualBackboneFeatureExtractor(model_paths=model_paths)

        # Standalone forgetting adapter
        self.forgetting_adapter = CatastrophicForgettingAdapter(
            num_channels=backbone_output_channels,
            num_outputs=num_outputs
        )

        # Convolutional branches for spatial processing
        self.branch_forgetting = ConvBranch(
            backbone_output_channels,
            branch_hidden_channels,
            branch_output_channels
        )
        self.branch_not_forgetting = ConvBranch(
            backbone_output_channels,
            branch_hidden_channels,
            branch_output_channels
        )

        self.fusion = FeatureFusion(branch_output_channels)

        # Final head operating on the fused branch output
        self.head =  nn.Sequential(
            nn.Linear(branch_output_channels, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),

            nn.Linear(256, num_outputs)
        )

        print("✅ Spatial model initialized successfully!")

    def forward(self, x, use_gating: bool = True):
        """Run the full pipeline.

        Args:
            x: Input handed unchanged to the dual backbone (a pair of image
                batches, one per backbone).
            use_gating: Forwarded to the forgetting adapter.

        Returns:
            Dict with the keys ``'output'`` (final prediction),
            ``'adapter_output'`` (adapter prediction), ``'backbone_features'``,
            ``'forgetting_features'`` and ``'not_forgetting_features'``.
        """
        # Spatial features from the backbone: [B, C, h, w]
        backbone_features = self.backbone(x)

        # Process through forgetting adapter (gating + task prediction)
        forgetting_features, not_forgetting_features, adapter_predictions = \
            self.forgetting_adapter(backbone_features, use_gating)

        # Each branch reduces its feature map to a [B, branch_output_channels] vector
        forgetting_output = self.branch_forgetting(forgetting_features)
        not_forgetting_output = self.branch_not_forgetting(not_forgetting_features)

        # Combine branch outputs for final prediction
        fused_features = self.fusion(forgetting_output, not_forgetting_output)
        final_output = self.head(fused_features)

        return {
            'output': final_output,
            'adapter_output': adapter_predictions,
            'backbone_features': backbone_features,
            'forgetting_features': forgetting_features,
            'not_forgetting_features': not_forgetting_features
        }

    def store_adapter_state(self):
        """Delegate to forgetting adapter"""
        self.forgetting_adapter.store_adapter_state()

    def compute_forgetting_scores(self):
        """Delegate to forgetting adapter"""
        self.forgetting_adapter.compute_forgetting_scores()

    @property
    def current_forgetting_scores(self):
        """Access forgetting scores from adapter"""
        return self.forgetting_adapter.current_forgetting_scores


#### testing

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Initialize your model
# cf_model = CatastrophicForgettingDisentanglementModel(
#     backbone_output_channels=640,
#     branch_hidden_channels=256,
#     branch_output_channels=128,
#     num_outputs=9
# ).to(device)

# # Set model to evaluation mode (important for testing)
# cf_model.eval()

# # Create dummy input tensor matching your expected input size
# # Your model expects: [batch_size, 3, height, width]
# # Using 224x224 as we discussed for memory efficiency
# batch_size = 2
# dummy1 = torch.randn(batch_size, 3, 224, 224).to(device)
# dummy2 = torch.randn(batch_size, 3, 224, 224).to(device)
# dummy_input = (dummy1, dummy2)

# print(f"Input shape: {dummy_input[0].shape}")

# # Test forward pass without gating (first domain)
# with torch.no_grad():
#     outputs_no_gating = cf_model(dummy_input, use_gating=False)

# print("✅ Forward pass without gating successful!")
# print(f"Output shape: {outputs_no_gating['output'].shape}")
# print(f"Adapter output shape: {outputs_no_gating['adapter_output'].shape}")
# print(f"Backbone features shape: {outputs_no_gating['backbone_features'].shape}")

# # Test forward pass with gating (after first domain)
# # First, simulate having forgetting scores
# cf_model.forgetting_adapter.current_forgetting_scores = torch.randn(640).to(device)

# with torch.no_grad():
#     outputs_with_gating = cf_model(dummy_input, use_gating=True)

# print("✅ Forward pass with gating successful!")
# print(f"Forgetting features shape: {outputs_with_gating['forgetting_features'].shape}")
# print(f"Not forgetting features shape: {outputs_with_gating['not_forgetting_features'].shape}")

# # Test forgetting measurement functions
# cf_model.store_adapter_state()
# print("✅ Adapter state stored successfully!")

# cf_model.compute_forgetting_scores()
# print("✅ Forgetting scores computed successfully!")
# print(f"Forgetting scores shape: {cf_model.current_forgetting_scores.shape}")

# outputs = cf_model(dummy_input, use_gating=True)
# print(outputs['output'].device)

#### helper function

In [None]:
def evaluate_model_cf(model, dataloader, criterion, device):
    """Compute the sample-weighted mean loss of ``model`` over ``dataloader``.

    The model is switched to eval mode (and left there) and evaluated with
    ``use_gating=True`` under ``torch.no_grad()``.

    Args:
        model: Callable as ``model((yolo_images, depth_images), use_gating=True)``
            returning a dict with an ``'output'`` entry.
        dataloader: Iterable of 4-tuples ``(yolo_images, depth_images, labels, _)``;
            the fourth element is ignored.
        criterion: Loss taking ``(prediction, labels)`` and returning a scalar
            tensor (treated as the batch mean when weighting by batch size).
        device: Device the tensors are moved to.

    Returns:
        Mean loss as a float, or ``nan`` when the dataloader yields no samples
        (instead of raising ``ZeroDivisionError``).
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for yolo_images, depth_images, labels, _ in dataloader:
            yolo_images = yolo_images.to(device, dtype=torch.float32)
            depth_images = depth_images.to(device, dtype=torch.float32)
            inputs = (yolo_images, depth_images)
            labels = labels.to(device, dtype=torch.float32)
            outputs = model(inputs, use_gating=True)  # Use gating during evaluation
            loss = criterion(outputs['output'], labels)

            # Weight the batch loss by its size so a smaller final batch does not
            # skew the average.
            batch_size = yolo_images.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        # Nothing was evaluated; there is no meaningful loss to report.
        return float('nan')
    return total_loss / total_samples

def cross_domain_validation_cf(model, domain_dataloaders, criterion, device):
    """Evaluate ``model`` on the ``'val'`` loader of every domain.

    Returns a dict mapping each key of ``domain_dataloaders`` to the loss that
    ``evaluate_model_cf`` reports for that domain's validation loader.
    """
    return {
        domain_name: evaluate_model_cf(model, split_loaders['val'], criterion, device)
        for domain_name, split_loaders in domain_dataloaders.items()
    }

def average_metrics(metrics_list):
    """Average a list of metric dicts key by key.

    The keys are taken from the first dict only; for each key the mean is computed
    over the dicts that contain it. An empty (or falsy) input yields ``{}``.
    """
    if not metrics_list:
        return {}

    reference_keys = metrics_list[0].keys()
    return {
        key: float(np.mean([entry[key] for entry in metrics_list if key in entry]))
        for key in reference_keys
    }

#### training

In [None]:
def compute_gating_loss(forgetting_scores: torch.Tensor,
                       gating_mask: torch.Tensor,
                       lambda_balance: float = 1.0) -> torch.Tensor:
    """Loss that rewards the gate for separating high- from low-score channels.

    Two mask-weighted means of the scores are formed: one weighted by the mask,
    one by its complement. The loss is the negated difference of these means plus
    ``lambda_balance`` times the distance of the average mask value from 0.5.
    """
    eps = 1e-8
    complement_mask = 1 - gating_mask

    # Mask-weighted mean score of each side of the split
    routed_mean = (forgetting_scores * gating_mask).sum() / (gating_mask.sum() + eps)
    retained_mean = (forgetting_scores * complement_mask).sum() / (complement_mask.sum() + eps)

    separation_term = -(routed_mean - retained_mean)
    balance_term = torch.abs(gating_mask.mean() - 0.5)

    return separation_term + lambda_balance * balance_term

def catastrophic_forgetting_batch(model, batch, device, loss_params=None, **kwargs):
    """Run one training batch: forward pass, backward pass of the task losses, gating loss.

    The weighted sum ``main * main_loss + adapter * adapter_loss`` is
    back-propagated here. Including the adapter term is what trains the adapter
    head; without it the adapter receives no gradient at all and the forgetting
    scores derived from its weight changes stay at zero. The gating loss is only
    computed and returned; the caller back-propagates it.

    Args:
        model: Model returning a dict with ``'output'``, ``'adapter_output'`` and
            ``'backbone_features'``; must expose ``current_forgetting_scores`` and
            ``forgetting_adapter.gating``.
        batch: 4-tuple ``(yolo_images, depth_images, labels, domain_labels)``.
        device: Device the tensors are moved to.
        loss_params: Loss weights. The keys ``'main'`` (default 1.0) and
            ``'adapter'`` (default 0.5) are used; ``None`` selects the defaults.
        **kwargs: ``mse_criterion`` (required) and ``domain_idx`` (default 0).
            Gating is only active for ``domain_idx > 0``.

    Returns:
        ``(main_loss, gating_loss, metrics)`` where ``main_loss`` is the unweighted
        main task loss, ``gating_loss`` is a zero tensor when gating is inactive,
        and ``metrics`` holds the three loss values as floats.
    """
    if loss_params is None:
        loss_params = {'main': 1.0, 'adapter': 0.5, 'gating': 0.1}
    main_weight = loss_params.get('main', 1.0)
    adapter_weight = loss_params.get('adapter', 0.5)
    # NOTE(review): loss_params['gating'] is not applied here; the gating loss is
    # optimised separately by the caller. Confirm whether it should be scaled.

    yolo_images, depth_images, labels, domain_labels = batch
    # Same float32 cast as evaluate_model_cf, so training and evaluation see
    # identical dtypes.
    yolo_images = yolo_images.to(device, dtype=torch.float32)
    depth_images = depth_images.to(device, dtype=torch.float32)
    inputs = (yolo_images, depth_images)
    labels = labels.to(device, dtype=torch.float32)

    mse_criterion = kwargs['mse_criterion']
    domain_idx = kwargs.get('domain_idx', 0)

    # Forward pass
    outputs = model(inputs, use_gating=(domain_idx > 0))

    # Main task loss (from final head)
    main_loss = mse_criterion(outputs['output'], labels)

    # Adapter loss (trains the adapter whose weight drift measures forgetting)
    adapter_loss = mse_criterion(outputs['adapter_output'], labels)

    # Back-propagate both task losses together. retain_graph is kept from the
    # original code so that anything shared with the later gating backward pass
    # is not freed here.
    training_loss = main_weight * main_loss + adapter_weight * adapter_loss
    training_loss.backward(retain_graph=True)

    gating_loss = torch.tensor(0.0, device=device)
    forgetting_scores = model.current_forgetting_scores
    if domain_idx > 0 and forgetting_scores is not None:
        # The scores are fixed statistics: detach them so the gating loss only
        # produces gradients for the gating network.
        forgetting_scores = forgetting_scores.detach().to(device)
        batch_size = outputs['backbone_features'].size(0)
        forgetting_scores_batch = forgetting_scores.unsqueeze(0).expand(batch_size, -1)

        gating_mask = model.forgetting_adapter.gating(forgetting_scores_batch)

        gating_loss = compute_gating_loss(
            forgetting_scores,
            gating_mask.mean(dim=0)
        )

    metrics = {
        'main_loss': main_loss.item(),
        'adapter_loss': adapter_loss.item(),
        'gating_loss': gating_loss.item()
    }

    return main_loss, gating_loss, metrics

def catastrophic_forgetting_train_loop(
    model, domains, domain_dataloaders, buffer, optimizer, gating_optimizer, device,
    batch_fn, batch_kwargs, loss_params, num_epochs=5, exp_name="cf_exp",
    gradient_clipping=False, restart=None
):
    """Train the model sequentially on ``domains`` with replay and forgetting tracking.

    For every domain: snapshot the adapter (from the second domain on), train for
    ``num_epochs`` on the domain's replay-augmented loader, validate after each
    epoch, and after the last epoch run cross-domain validation, compute the
    forgetting scores and write a checkpoint. The history is pickled after every
    epoch and the replay buffer is updated once the domain is finished.

    Args:
        model: Model exposing ``store_adapter_state``, ``compute_forgetting_scores``
            and ``current_forgetting_scores``.
        domains: Ordered sequence (list or array) of domain identifiers.
        domain_dataloaders: ``{domain: {'train': loader, 'val': loader}}``.
        buffer: Replay buffer providing ``get_loader_with_replay``,
            ``update_buffer`` and ``get_domain_distribution``.
        optimizer: Optimizer stepped after every batch.
        gating_optimizer: Optimizer stepped on the gating loss.
        device: Device handed to ``batch_fn`` and the evaluation helpers.
        batch_fn: Called as ``batch_fn(model, batch, device, loss_params, **kwargs)``
            and expected to run the main backward pass itself; returns
            ``(main_loss, gating_loss, metrics)``.
        batch_kwargs: Extra keyword arguments for ``batch_fn``; must contain
            ``'mse_criterion'``, which is also used for validation.
        loss_params: Passed through to ``batch_fn``.
        num_epochs: Epochs per domain.
        exp_name: Prefix of checkpoint and history file names.
        gradient_clipping: Clip the gradient norm to 1.0 before the optimizer step.
        restart: Optional dict with ``'global_step'``, ``'history'`` and
            ``'domain'`` to resume from that domain. Earlier domains are only
            replayed into the buffer; model weights are not restored here.

    Returns:
        The history dict.
    """
    checkpoint_dir = "../checkpoints"
    # Make sure checkpoint/history writes cannot fail on a missing directory.
    os.makedirs(checkpoint_dir, exist_ok=True)

    start_domain_idx = 0
    global_step = 0
    history = {
        'train_epoch_loss': [],
        'val_epoch_loss': [],
        'train_epoch_metrics': [],
        'cross_domain_val': [],
        'grad_norms': [],
        'forgetting_scores_history': [],
        'gating_losses': []
    }

    if restart:
        global_step = restart['global_step']
        history = restart['history']
        # list(...).index works for plain lists as well as NumPy arrays
        # (comparing a list with `==` inside np.where does not).
        start_domain_idx = list(domains).index(restart['domain'])
        for current_domain in domains[:start_domain_idx]:
            buffer.update_buffer(current_domain, domain_dataloaders[current_domain]['train'].dataset)
        print(f"Restarting from domain {restart['domain']} index {start_domain_idx}")
        print(f"Buffer: {buffer.get_domain_distribution()}")

    for domain_idx, current_domain in enumerate(tqdm(domains[start_domain_idx:], desc="Total training"), start=start_domain_idx):
        print(f"\n=== Training on Domain {domain_idx}: {current_domain} ===")

        # Store adapter state before training (for forgetting measurement)
        if domain_idx > 0:
            model.store_adapter_state()

        train_loader = buffer.get_loader_with_replay(current_domain, domain_dataloaders[current_domain]['train'])

        for epoch in trange(num_epochs, desc=f"Current domain {current_domain}"):
            model.train()
            epoch_loss = 0.0
            samples = 0
            batch_metrics_list = []

            for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Current epoch {epoch}", leave=False)):
                # Main model training (batch_fn performs the backward pass)
                optimizer.zero_grad()

                batch_kwargs_with_domain = {**batch_kwargs, 'current_domain': current_domain, 'domain_idx': domain_idx}

                main_loss, gating_loss, metrics = batch_fn(model, batch, device, loss_params, **batch_kwargs_with_domain)

                if gradient_clipping:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                # Gating training (same batch, separate optimizer). The gating loss
                # can legitimately be negative, so "was a loss actually computed"
                # is detected via requires_grad rather than via its sign; the
                # zero placeholder returned when gating is inactive has no graph.
                # NOTE(review): this backward pass runs after optimizer.step(); if
                # `optimizer` also updates the gating parameters, autograd may
                # reject it because those weights were modified in place. Confirm
                # the two optimizers own disjoint parameters.
                if domain_idx > 0 and gating_loss.requires_grad:
                    gating_optimizer.zero_grad()
                    gating_loss.backward()
                    gating_optimizer.step()

                batch_size = batch[0].size(0)
                epoch_loss += main_loss.item() * batch_size
                samples += batch_size
                global_step += 1
                batch_metrics_list.append(metrics)

            # An empty loader yields NaN instead of a ZeroDivisionError.
            avg_epoch_loss = epoch_loss / samples if samples else float('nan')
            history['train_epoch_loss'].append(avg_epoch_loss)

            avg_metrics = average_metrics(batch_metrics_list)
            history['train_epoch_metrics'].append(avg_metrics)

            grad_norms = collect_gradients(model)
            history['grad_norms'].append(grad_norms)

            # Validation on current domain
            val_loss = evaluate_model_cf(model, domain_dataloaders[current_domain]['val'], batch_kwargs['mse_criterion'], device)
            history['val_epoch_loss'].append(val_loss)

            # End-of-domain work (last epoch only)
            if epoch == num_epochs-1:
                cross_val = cross_domain_validation_cf(model, domain_dataloaders, batch_kwargs['mse_criterion'], device)
                history['cross_domain_val'].append(cross_val)

                # Compute forgetting scores after training on domain (for next domain)
                model.compute_forgetting_scores()
                scores = model.current_forgetting_scores
                # Store a detached copy so the pickled history never drags an
                # autograd graph along.
                history['forgetting_scores_history'].append(
                    scores.detach().clone() if scores is not None else None
                )
                if scores is not None:
                    print(f"Forgetting scores computed. Mean: {scores.mean():.4f}, "
                        f"Std: {scores.std():.4f}")

                # Save checkpoint
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'gating_optimizer_state_dict': gating_optimizer.state_dict(),
                    'history': history,
                    'forgetting_scores': model.current_forgetting_scores,
                }, f"{checkpoint_dir}/{exp_name}_domain{current_domain}_epoch{epoch}_step{global_step}.pt")

            # Persist the history after every epoch
            with open(f"{checkpoint_dir}/{exp_name}_history.pkl", "wb") as f:
                pickle.dump(history, f)

        buffer.update_buffer(current_domain, domain_dataloaders[current_domain]['train'].dataset)
        print(f"Domain {domain_idx} completed. Buffer: {buffer.get_domain_distribution()}")

    return history
