## Data preparation

In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import random
import pickle
import datetime
# Single seed shared by Python's `random`, NumPy and PyTorch.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # Also seed every visible CUDA device.
    torch.cuda.manual_seed_all(SEED)

# DATASET_DIR = (Path("..") / ".." / "datasets").resolve()
# DATASETS = ["OFFICE-MANNERSDB", "MANNERSDBPlus"]
# LABEL_COLS = [
#     "Vaccum Cleaning", "Mopping the Floor", "Carry Warm Food",
#     "Carry Cold Food", "Carry Drinks", "Carry Small Objects",
#     "Carry Large Objects", "Cleaning", "Starting a conversation"
# ]

import sys
sys.path.append('..')

from data_processing.data_processing import ImageLabelDataset,DualImageDataset,create_dataloaders,create_crossvalidation_loaders

In [None]:
# Load the pickled table used by the cells below.
# NOTE(review): read_pickle unpickles arbitrary objects; only load files from a trusted source.
df = pd.read_pickle("../data/pepper_data.pkl")
def get_dataloader(df, batch_sizes=(32, 64, 64), resize_img_to=(224,224), return_splits=False, double_img=False, transforms=None, num_workers=0):
    """
    Build one group of dataloaders for every distinct value in ``df['domain']``.

    The rows of each domain are passed to ``create_dataloaders`` together with
    the options received here and the module-level ``SEED``.

    Image sizes used with the different models:
        yolo_depthanything model: 224,224
        otherwise: 512, 288
        LGR: 128,128
        MobileNetv2: 224, 224

    Returns:
        ``{domain: loaders}``. When ``return_splits`` is true, a tuple of that
        dict and a set holding the ``'test'`` split entries of all domains.
    """
    loaders_by_domain = {}
    pooled_test_idx = set()

    for domain_name in df['domain'].unique():
        rows_of_domain = df[df['domain'] == domain_name]
        # Splits are always requested so the test indices can be collected.
        domain_loaders, domain_splits = create_dataloaders(
            rows_of_domain,
            batch_sizes=batch_sizes,
            resize_img_to=resize_img_to,
            seed=SEED,
            return_splits=True,
            double_img=double_img,
            transforms=transforms,
            num_workers=num_workers,
        )
        loaders_by_domain[domain_name] = domain_loaders
        pooled_test_idx |= set(domain_splits['test'])

    if return_splits:
        return loaders_by_domain, pooled_test_idx
    return loaders_by_domain

## CL Models

### base

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm, trange
import torch.optim as optim
from torchvision import models
import numpy as np
from collections import deque
import random
import time
import torch.nn.functional as F
from torch.utils.data import ConcatDataset, DataLoader, Subset
import torchvision.models.segmentation as segmentation
from collections import defaultdict

#### baselines

In [None]:
class SocialContinualModel(nn.Module):
    """
    ResNet50 initialised from a Places365 checkpoint, followed by a shared
    fully connected layer and one sigmoid-activated head per task.

    The checkpoint file ``resnet50_places365.pth.tar`` is read relative to the
    current working directory.

    Args:
        num_tasks: number of task heads; ``forward`` returns one column per head.
    """
    def __init__(self, num_tasks=9):
        super().__init__()
        # Initialize the ResNet50 architecture with Places365 configuration
        self.backbone = models.resnet50(num_classes=365)

        # Get the Places365 weights. map_location='cpu' lets a checkpoint whose
        # tensors were saved on a GPU be restored on a machine without CUDA;
        # load_state_dict copies the values into the model's own parameters.
        places365_weights = torch.load(
            'resnet50_places365.pth.tar',
            map_location='cpu',
            weights_only=True,
        )
        # Fix the key naming left over from the torch saving convention
        # (the 'module.' substring is stripped from every key).
        state_dict = places365_weights['state_dict']
        state_dict = {k.replace('module.', ''): v
                      for k, v in state_dict.items()}

        # Load weights
        self.backbone.load_state_dict(state_dict)

        # Remove classification head
        self.backbone.fc = nn.Identity()

        # Freeze every backbone parameter whose name does not contain 'layer4'
        for name, param in self.backbone.named_parameters():
            if 'layer4' not in name:
                param.requires_grad_(False)

        #TODO human mask average, std, quadrants, human in relation to robot std

        # Shared layers #TODO deeper shared space?
        self.shared_fc = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Task-specific heads with more expensive but finegrained GELU
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(1024, 512),
                nn.GELU(),
                nn.Linear(512, 1),
                nn.Sigmoid()
            ) for _ in range(num_tasks)
        ])

    def forward(self, x):
        """Return the head outputs concatenated along dim 1 (one column per task)."""
        features = self.backbone(x)
        shared = self.shared_fc(features)
        outputs = [head(shared) for head in self.heads]
        return torch.cat(outputs, dim=1)


In [None]:
import torch.nn as nn
from torchvision import models

class LGRBaseline(nn.Module):
    """
    MobileNetV2 feature extractor (IMAGENET1K_V1 weights) followed by a small
    fully connected regression head.

    Reference:
    @misc{churamani_feature_2024,
        title = {Feature Aggregation with Latent Generative Replay for Federated Continual Learning of Socially Appropriate Robot Behaviours},
        url = {http://arxiv.org/abs/2405.15773},
        doi = {10.48550/arXiv.2405.15773},
        number = {{arXiv}:2405.15773},
        publisher = {{arXiv}},
        author = {Churamani, Nikhil and Checker, Saksham and Chiang, Hao-Tien Lewis and Gunes, Hatice},
        urldate = {2025-01-30},
        date = {2024-03-16},
    }
    """
    def __init__(self, num_classes=9):
        super().__init__()

        # Only the convolutional feature stack of MobileNetV2 is kept.
        pretrained = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        self.backbone = pretrained.features

        # Collapse the spatial grid to one 1280-dimensional vector per image.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

        # Regression head: batch norm -> 1280->32 -> 32->num_classes.
        # (An intermediate 128-unit layer, fc2_bn/fc3, is currently disabled.)
        self.fc1_bn = nn.BatchNorm1d(1280)
        self.fc2 = nn.Linear(1280, 32)
        self.fc4 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Return ``{'output': scores}`` with one score per class for each input."""
        pooled = self.flatten(self.pool(self.backbone(x)))
        hidden = self.fc2(self.fc1_bn(pooled))
        return {'output': self.fc4(hidden)}


#### buffers

In [None]:
import torch
import random
import pandas as pd

class ReservoirBuffer:
    """
    Fixed-capacity store of ``(input, label, domain)`` samples filled by
    reservoir sampling.

    Inputs and labels live in two pre-allocated float32 tensors on ``device``;
    domain tags are kept in a plain list of the same length.
    """
    def __init__(self, capacity=1000, replay_ratio=0.2, input_shape=(3, 384, 216), label_shape=(9,), device=torch.device('cpu')):
        self.capacity = capacity
        self.replay_ratio = replay_ratio
        self.size = 0          # slots currently occupied
        self.num_seen = 0      # samples offered to add() so far
        self.domains = [None] * capacity
        self.inputs = torch.empty((capacity, *input_shape), dtype=torch.float32, device=device)
        self.labels = torch.empty((capacity, *label_shape), dtype=torch.float32, device=device)

    def _pick_slot(self):
        """Return the slot for the sample just counted, or None to discard it."""
        if self.size < self.capacity:
            # Still filling up: take the next free slot.
            slot = self.size
            self.size += 1
            return slot
        # Buffer full: draw uniformly over everything seen so far and keep
        # the sample only when the draw lands inside the buffer.
        candidate = random.randint(0, self.num_seen - 1)
        return candidate if candidate < self.capacity else None

    def add(self, new_samples):
        """Offer each ``(input, label, domain)`` entry of ``new_samples`` to the buffer."""
        for entry in new_samples:
            self.num_seen += 1
            slot = self._pick_slot()
            if slot is None:
                continue
            self.inputs[slot].copy_(entry[0])
            self.labels[slot].copy_(entry[1])
            self.domains[slot] = entry[2]

    def sample(self, batch_size):
        """Draw ``batch_size`` stored samples with replacement; ``[]`` when empty."""
        if not self.size:
            return []
        picks = torch.randint(0, self.size, (batch_size,))
        batch = []
        for pick in picks:
            batch.append((self.inputs[pick], self.labels[pick], self.domains[pick]))
        return batch

    def __len__(self):
        return self.size

    def get_domain_distribution(self):
        """Return a pandas Series counting stored samples per domain tag."""
        occupied = self.domains[:self.size]
        return pd.Series(occupied).value_counts()


In [None]:
# from collections import Counter
# from itertools import product

# df = pd.read_pickle("../data/pepper_data.pkl")
# domain_dataloaders = get_dataloader(df, batch_sizes=(32, 64, 64), resize_img_to=(1,1), return_splits=False, double_img=False)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# domains = df['domain'].unique()

# versions = list(product([1000, 500, 200, 100], ['', 'downsample_buffer', 'upsample_current']))

# for buffer_size, balanced in tqdm(versions):
#     print(f"{buffer_size=} {balanced}\n")

#     buffer = NaiveRehearsalBuffer(buffer_size, balanced)
#     for domain_idx, current_domain in enumerate(tqdm(domains)):
#         train_loader = buffer.get_loader_with_replay(current_domain, domain_dataloaders[current_domain]['train'])
        
#         counter = Counter()
#         for _,_,labels_batch in tqdm(train_loader, leave=False):  # batch of (images, labels, domain_labels)
#             counter.update(labels_batch)
        
#         ratio = {key: counter.get(key, 0) for key in domains}
#         print(f"{current_domain}: {ratio}")

#         buffer.update_buffer(current_domain, domain_dataloaders[current_domain]['train'].dataset)


In [None]:
class NaiveRehearsalBuffer:
    """
    Per-domain rehearsal buffer with an equal sample quota for every domain.

    ``self.buffer`` maps a domain to a ``Subset`` of that domain's training
    dataset. Each call to ``update_buffer`` recomputes the quota as
    ``buffer_size // number_of_domains`` and shrinks every stored subset to it.

    @inproceedings{Hsu18_EvalCL,
        title={Re-evaluating Continual Learning Scenarios: A Categorization and Case for Strong Baselines},
        author={Yen-Chang Hsu and Yen-Cheng Liu and Anita Ramasamy and Zsolt Kira},
        booktitle={NeurIPS Continual learning Workshop },
        year={2018},
        url={https://arxiv.org/abs/1810.12488}
    }

    Args:
        buffer_size: total number of samples kept across all domains.
        balancing: when truthy, ``get_loader_with_replay`` equalises the amount
            of replay and current data by resampling with replacement.
    """
    def __init__(self, buffer_size=1000, balancing=True):
        self.buffer_size = buffer_size
        self.buffer = {}
        self.balancing=balancing

    def update_buffer(self, current_domain, current_dataset):
        """Store ``current_dataset`` under ``current_domain`` and shrink every domain to the new quota."""
        # Add/overwrite current domain (initially every sample is a candidate)
        self.buffer[current_domain] = Subset(current_dataset, list(range(len(current_dataset))))

        # Recalculate quota - even for each domain
        buffer_quota_per_domain = self.buffer_size // len(self.buffer)

        # Reduce all domains (including current). The survivors are drawn from
        # the indices that are ALREADY stored, so an old domain can only lose
        # samples; it never regains data that had been evicted earlier.
        for domain, stored in list(self.buffer.items()):
            keep = min(buffer_quota_per_domain, len(stored))
            chosen = torch.randperm(len(stored))[:keep].tolist()
            kept_indices = [int(stored.indices[i]) for i in chosen]
            self.buffer[domain] = Subset(stored.dataset, kept_indices)

    def get_loader_with_replay(self, current_domain, current_loader):
        """Return a loader mixing ``current_loader``'s data with replay from the other domains."""
        if self.balancing:
            return self.balanced_combine_training_and_replay_loader(current_domain, current_loader)
        else:
            return self.combine_training_and_replay_loader(current_domain, current_loader)

    def _replay_datasets(self, current_domain):
        """Non-empty stored subsets of every domain except ``current_domain``."""
        return [
            dataset
            for domain, dataset in self.buffer.items()
            if domain != current_domain and len(dataset) > 0
        ]

    @staticmethod
    def _loader_like(template_loader, dataset):
        """Shuffling DataLoader over ``dataset`` reusing the template's batch settings."""
        return DataLoader(
            dataset,
            batch_size=template_loader.batch_size,
            shuffle=True,
            num_workers=template_loader.num_workers,
            pin_memory=template_loader.pin_memory,
            drop_last=template_loader.drop_last
        )

    def balanced_combine_training_and_replay_loader(self, current_domain, current_loader):
        """
        Combine current and replay data at (roughly) a 1:1 ratio.

        Returns ``current_loader`` itself when there is nothing to replay.
        """
        current_dataset = current_loader.dataset
        replay_datasets = self._replay_datasets(current_domain)

        if not replay_datasets:
            return current_loader

        current_size = len(current_dataset)
        total_replay_size = sum(len(d) for d in replay_datasets)

        # Case 1: Small buffer - upsample the buffer - domain stratified upsampling with replacement
        if total_replay_size < current_size:
            samples_per_domain = current_size // len(replay_datasets)
            replay_dataset = ConcatDataset([
                Subset(domain_data, torch.randint(0, len(domain_data), (samples_per_domain,)).tolist())
                for domain_data in replay_datasets
            ])

        # Case 2: Large buffer - upsample the training data - upsampling with replacement
        elif total_replay_size > current_size:
            if current_size > 0:  # nothing to draw from an empty current dataset
                indices = torch.randint(0, current_size, (total_replay_size,)).tolist()
                current_dataset = Subset(current_dataset, indices)
            replay_dataset = ConcatDataset(replay_datasets)

        # Case 3: Replay buffer = training dataset
        # (Downsampling the buffer is obsolete - just set the buffer size smaller.)
        else:
            replay_dataset = ConcatDataset(replay_datasets)

        combined_dataset = ConcatDataset([replay_dataset, current_dataset])
        return self._loader_like(current_loader, combined_dataset)

    def combine_training_and_replay_loader(self, current_domain, current_loader):
        """
        Concatenate current data with the replay subsets without resampling.

        The replay subsets are repeated ``len(current) // len(replay)`` times
        (at least once), which enforces a ~1:1 ratio when current >= buffer.
        """
        current_dataset = current_loader.dataset
        replay_datasets = self._replay_datasets(current_domain)

        total_replay = sum(len(dataset) for dataset in replay_datasets)
        if total_replay > 0:
            repeats = max(len(current_dataset) // total_replay, 1)
            replay_datasets = replay_datasets * repeats

        combined_dataset = ConcatDataset(replay_datasets + [current_dataset])
        return self._loader_like(current_loader, combined_dataset)

    def get_domain_distribution(self):
        """Returns {domain: num_samples} without needing Storage"""
        return {domain: len(subset) for domain, subset in self.buffer.items()}



In [None]:
class NonstratifiedNaiveRehearsalBuffer:
    """
    Rehearsal buffer that pools samples from all datasets without per-domain quotas.

    ``self.buffer`` is a flat list of ``(dataset_object, idx)`` pairs. When it
    grows beyond ``buffer_size`` a uniformly random subset of that size is kept.
    """
    def __init__(self, buffer_size=1000):
        self.buffer_size = buffer_size
        self.buffer = []  # List of (dataset_object, idx) tuples

    def update_buffer(self, current_dataset):
        """Add every sample of ``current_dataset`` and trim the buffer to ``buffer_size``."""
        # Add all new samples as (dataset_object, idx) pairs
        self.buffer.extend((current_dataset, idx) for idx in range(len(current_dataset)))

        # If buffer is too large, randomly keep only buffer_size samples
        if len(self.buffer) > self.buffer_size:
            keep = torch.randperm(len(self.buffer))[:self.buffer_size].tolist()
            self.buffer = [self.buffer[i] for i in keep]

    def get_loader_with_replay(self, current_loader):
        """
        Return a shuffling loader over the current data plus the buffered samples.

        The smaller of the two sides is upsampled with replacement to the size
        of the larger one. ``current_loader`` itself is returned while the
        buffer is empty.
        """
        current_dataset = current_loader.dataset

        if not self.buffer:
            return current_loader

        # Group buffer samples by dataset for efficient Subset creation.
        # Keyed on object identity so datasets need not be hashable.
        grouped = {}
        for dataset, idx in self.buffer:
            grouped.setdefault(id(dataset), (dataset, []))[1].append(idx)
        replay_dataset = ConcatDataset([Subset(ds, idxs) for ds, idxs in grouped.values()])

        current_size = len(current_dataset)
        total_replay_size = len(replay_dataset)

        # Upsample with replacement if needed
        if total_replay_size < current_size:
            idxs = torch.randint(0, total_replay_size, (current_size,)).tolist()
            replay_dataset = Subset(replay_dataset, idxs)
        elif total_replay_size > current_size and current_size > 0:
            idxs = torch.randint(0, current_size, (total_replay_size,)).tolist()
            current_dataset = Subset(current_dataset, idxs)

        combined_dataset = ConcatDataset([replay_dataset, current_dataset])
        return DataLoader(
            combined_dataset,
            batch_size=current_loader.batch_size,
            shuffle=True,
            num_workers=current_loader.num_workers,
            pin_memory=current_loader.pin_memory,
            drop_last=current_loader.drop_last
        )

    def __len__(self):
        return len(self.buffer)

    def get_domain_distribution(self):
        """Return ``{"buffer": n}``; domains are not tracked by this buffer."""
        return {"buffer": len(self.buffer)}


#### dualbranch

In [None]:
class GradientReversalFunction(torch.autograd.Function):
    """Identity on the forward pass; scales incoming gradients by ``-alpha`` on the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Remember the scaling factor for the backward pass.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        flipped = ctx.alpha * grad_output.neg()
        # ``alpha`` is not differentiated, hence the None.
        return flipped, None

class GradientReversal(nn.Module):
    """Module wrapper around ``GradientReversalFunction`` with a fixed factor of 1.0."""

    def forward(self, x):
        reversed_x = GradientReversalFunction.apply(x, 1.0)
        return reversed_x


class DualBranchNet(nn.Module):
    """
    MobileNetV2 feature extractor feeding two parallel branches ("invariant"
    and "specific"), a domain classifier on top of each branch, and a
    regression head over the concatenated branch features.

    Args:
        num_outputs: size of the regression output.
        num_domains: number of classes of both domain classifiers.
        weights_init: orthogonal weights / zero biases for the custom Linear layers.
        weights_norm: wrap the branch Linear layers in weight normalisation.
        layer_norm: insert ``nn.LayerNorm`` in the branches and domain
            classifiers; ``nn.Identity`` is used otherwise.
        detach_base: feed the branches a detached copy of the backbone
            features and append the (non-detached) backbone features to the
            head input.
        freeze_base: ``'full'`` freezes the whole backbone, ``'partial'``
            freezes everything except the backbone's last child module; any
            other value leaves the backbone trainable.

    NOTE: earlier revisions always inserted LayerNorm because a nested helper
    shadowed the ``layer_norm`` flag. Checkpoints saved by that code contain
    LayerNorm parameters; construct with ``layer_norm=True`` to load them.
    """
    def __init__(self, num_outputs=9, num_domains=6, weights_init=False, weights_norm=False, layer_norm=False, detach_base=True, freeze_base=''):
        super().__init__()

        self.detach_base = detach_base

        # Capture the flag under its own name. A nested helper that is itself
        # called ``layer_norm`` would rebind the parameter to the function
        # object, which is always truthy.
        use_layer_norm = bool(layer_norm)

        def linear(in_f, out_f):
            layer = nn.Linear(in_f, out_f)
            return torch.nn.utils.parametrizations.weight_norm(layer, dim=0) if weights_norm else layer

        def make_norm(dim):
            return nn.LayerNorm(dim) if use_layer_norm else nn.Identity()

        self.backbone = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1).features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.feature_dim = 1280

        if freeze_base in ('full', 'partial'):
            for param in self.backbone.parameters():
                param.requires_grad = False
        if freeze_base == 'partial':
            # Re-enable training for the last child module of the backbone only.
            last_layer = list(self.backbone.children())[-1]
            for param in last_layer.parameters():
                param.requires_grad = True

        # NOTE(review): the reversal layer sits inside the feature branch, so
        # every gradient reaching the first linear layer (task loss included)
        # is negated - confirm this is intended. DualBranchNet_deep instead
        # places it in front of the domain classifier.
        self.invariant = nn.Sequential(
            linear(self.feature_dim, 256),
            GradientReversal(),
            make_norm(256),
            nn.ReLU(),
            linear(256, 256)
        )

        self.invariant_domain_classifier = nn.Sequential(
            make_norm(256),
            nn.Linear(256, num_domains)
        )

        self.specific = nn.Sequential(
            linear(self.feature_dim, 256),
            make_norm(256),
            nn.ReLU(),
            linear(256, 256)
        )

        self.specific_domain_classifier = nn.Sequential(
            make_norm(256),
            nn.Linear(256, num_domains)
        )

        # 2 x 256 branch features, plus the raw backbone features when the
        # branches only see a detached copy of them.
        head_in = 512 + self.feature_dim if self.detach_base else 512
        self.head = nn.Sequential(
            nn.Linear(head_in, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_outputs)
        )

        if weights_init:
            self._init_weights()

    def _init_weights(self):
        """Orthogonal weights and zero biases for every nn.Linear outside the backbone."""
        custom_modules = [
            'invariant',
            'invariant_domain_classifier',
            'specific',
            'specific_domain_classifier',
            'head'
        ]
        for name in custom_modules:
            module = getattr(self, name, None)
            if module is None:
                continue
            for m in module.modules():
                if isinstance(m, nn.Linear):
                    nn.init.orthogonal_(m.weight)
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """
        Returns a dict with the regression scores (``'output'``), both domain
        predictions (``'invariant_domain'``, ``'specific_domain'``) and both
        branch feature tensors (``'invariant_feats'``, ``'specific_feats'``).
        """
        base = self.backbone(x)
        base = self.pool(base).view(x.size(0), -1)

        if self.detach_base:
            # Branch losses do not reach the backbone; only the head does,
            # through the raw ``base`` features appended below.
            invariant_feats = self.invariant(base.detach())
            specific_feats = self.specific(base.detach())
            combined = torch.cat([invariant_feats, specific_feats, base], dim=1)
        else:
            invariant_feats = self.invariant(base)
            specific_feats = self.specific(base)
            combined = torch.cat([invariant_feats, specific_feats], dim=1)

        invariant_domain_pred = self.invariant_domain_classifier(invariant_feats)
        specific_domain_pred = self.specific_domain_classifier(specific_feats)

        scores = self.head(combined)

        return {
            'output': scores,
            'invariant_domain': invariant_domain_pred,
            'specific_domain': specific_domain_pred,
            'invariant_feats': invariant_feats,
            'specific_feats': specific_feats
        }


In [None]:
class DualBranchNet_deep(DualBranchNet):
    """
    Thicker invariant/specific networks to learn features instead of relying
    on the backbone.

    Both branches become three-layer MLPs (feature_dim -> 512 -> 256 -> 256)
    and the gradient reversal layer moves in front of the invariant domain
    classifier. All constructor arguments are forwarded to ``DualBranchNet``.
    """
    def __init__(self, num_outputs=9, num_domains=6, weights_init=True, weights_norm=False, layer_norm=False, detach_base=True, freeze_base=''):
        super().__init__(num_outputs, num_domains, weights_init, weights_norm, layer_norm, detach_base, freeze_base)

        def linear(in_f, out_f):
            layer = nn.Linear(in_f, out_f)
            return torch.nn.utils.parametrizations.weight_norm(layer, dim=0) if weights_norm else layer

        self.invariant = nn.Sequential(
            linear(self.feature_dim, 512),
            nn.ReLU(),
            linear(512, 256),
            nn.ReLU(),
            linear(256, 256)
        )

        self.specific = nn.Sequential(
            linear(self.feature_dim, 512),
            nn.ReLU(),
            linear(512, 256),
            nn.ReLU(),
            linear(256, 256)
        )

        self.invariant_domain_classifier = nn.Sequential(
            GradientReversal(),
            nn.Linear(256, num_domains)
        )

        # The parent constructor ran its initialisation before the modules
        # above replaced the originals, so the requested init has to be
        # applied again to cover the new layers.
        if weights_init:
            self._init_weights()

In [None]:
class DualBranchNet_binary(DualBranchNet_deep):
    """
    Binary classifier for the invariant branch.

    The invariant domain classifier outputs a single logit instead of
    ``num_domains`` logits; the specific domain classifier is inherited.

    Args:
        explicit_grl: when true the classifier holds no reversal layer and
            ``forward`` applies ``GradientReversalFunction`` itself; otherwise
            a ``GradientReversal`` module is the classifier's first layer.
            Either way the gradient is reversed exactly once.
        (remaining arguments are forwarded to ``DualBranchNet_deep``)
    """
    def __init__(self, num_outputs=9, num_domains=6, weights_init=True, weights_norm=False, layer_norm=False, detach_base=False, freeze_base='', explicit_grl=False):
        super().__init__(num_outputs, num_domains, weights_init, weights_norm, layer_norm, detach_base, freeze_base)

        self.explicit_grl = explicit_grl

        def gradient_layer():
            return nn.Identity() if explicit_grl else GradientReversal()

        self.invariant_domain_classifier = nn.Sequential(
            gradient_layer(),
            nn.Linear(256, 1)
        )

        # The parent constructors initialised their layers before the
        # classifier above replaced the inherited one; apply the requested
        # init again so the new layer is covered as well.
        if weights_init:
            self._init_weights()

    def forward(self, x):
        """Same outputs as the parent ``forward``; ``'invariant_domain'`` holds one logit per sample."""
        base = self.backbone(x)
        base = self.pool(base).view(x.size(0), -1)

        if self.detach_base:
            invariant_feats = self.invariant(base.detach())
            specific_feats = self.specific(base.detach())
            combined = torch.cat([invariant_feats, specific_feats, base], dim=1)
        else:
            invariant_feats = self.invariant(base)
            specific_feats = self.specific(base)
            combined = torch.cat([invariant_feats, specific_feats], dim=1)

        if self.explicit_grl:
            invariant_domain_pred = self.invariant_domain_classifier(GradientReversalFunction.apply(invariant_feats, 1.0))
        else:
            invariant_domain_pred = self.invariant_domain_classifier(invariant_feats)

        specific_domain_pred = self.specific_domain_classifier(specific_feats)

        scores = self.head(combined)

        return {
            'output': scores,
            'invariant_domain': invariant_domain_pred,
            'specific_domain': specific_domain_pred,
            'invariant_feats': invariant_feats,
            'specific_feats': specific_feats
        }


In [None]:
class DualBranchCNNNet(nn.Module):
    """
    Configurable two-branch CNN: a shared backbone, a "room" branch and a
    "social" branch, a domain classifier per branch and a regression head.

    Args:
        num_outputs: size of the regression output.
        num_domains: number of classes of the room domain classifier (the
            social domain classifier always outputs a single logit).
        backbone_type: 'none' | '2conv' | '3conv' | 'pretrained'.
        branch_type: 'linear' | 'simple' | 'special' | 'adversarial' |
            'adapted_adversarial'.
        end_type: 'simple' | '3linear' | 'adversarial'.
        batch_norm: use ``nn.BatchNorm2d`` after the convolutions
            (``nn.Identity`` otherwise).
        detach_base: feed the branches a detached copy of the backbone output
            and give the head an extra projection of the backbone output.
            Supported for every ``end_type``.

    Raises:
        ValueError: for an unknown backbone, branch or end type.

    NOTE(review): the flattened sizes of the 'adversarial' and
    'adapted_adversarial' branches are hard-coded; they appear to assume
    inputs resized to 512,288 - confirm before using another resolution.
    """
    def __init__(self, num_outputs=9, num_domains=6, backbone_type='3conv', branch_type='special', end_type='simple', batch_norm=False, detach_base=False):
        super().__init__()
        self.backbone_type = backbone_type
        self.branch_type = branch_type
        self.end_type = end_type
        self.detach_base = detach_base

        def BatchNorm(in_channels):
            return nn.BatchNorm2d(in_channels) if batch_norm else nn.Identity()

        def conv_block(in_ch, out_ch, kernel, **conv_kwargs):
            # Conv -> (BatchNorm | Identity) -> ReLU, returned as a list so it
            # can be spliced into an nn.Sequential.
            return [nn.Conv2d(in_ch, out_ch, kernel, **conv_kwargs), BatchNorm(out_ch), nn.ReLU()]

        def mlp(sizes):
            # Linear layers of the given widths with a ReLU between them
            # (no activation after the last layer).
            layers = []
            for in_f, out_f in zip(sizes[:-1], sizes[1:]):
                layers += [nn.Linear(in_f, out_f), nn.ReLU()]
            return nn.Sequential(*layers[:-1])

        #Input resized to 512,288

        # ---- backbone ----
        if self.backbone_type == 'none':
            self.backbone = nn.Identity()
            self.backbone_channels = 3

        elif self.backbone_type == '2conv':
            self.backbone = nn.Sequential(
                *conv_block(3, 32, 3, padding=1),
                nn.MaxPool2d(2),
                *conv_block(32, 64, 3, padding=1),
                nn.MaxPool2d(2)
            )
            self.backbone_channels = 64

        elif self.backbone_type == '3conv':
            self.backbone = nn.Sequential(
                *conv_block(3, 32, 3, padding=1),
                nn.MaxPool2d(2),
                *conv_block(32, 64, 3, padding=1),
                nn.MaxPool2d(2),
                *conv_block(64, 128, 3, padding=1),
                nn.MaxPool2d(2)
            )
            self.backbone_channels = 128

        elif self.backbone_type == 'pretrained':
            self.backbone = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1).features
            # The pretrained backbone is frozen unless detach_base is set.
            if not self.detach_base:
                for param in self.backbone.parameters():
                    param.requires_grad = False

            self.backbone_channels = 1280

        else:
            raise ValueError("backbone must be 'none' | '2conv' | '3conv' | 'pretrained'")

        # ---- branches ----
        if self.branch_type == 'linear':
            def pooled_mlp():
                return nn.Sequential(
                    nn.AdaptiveAvgPool2d((1, 1)),
                    nn.Flatten(),
                    nn.Linear(self.backbone_channels, 512),
                    nn.ReLU(),
                    nn.Linear(512, 256),
                    nn.ReLU(),
                    nn.Linear(256, 256)
                )
            self.social_branch = pooled_mlp()
            self.room_branch = pooled_mlp()
            self.branch_channels = 256

        elif self.branch_type == 'simple':
            self.room_branch = nn.Sequential(
                *conv_block(self.backbone_channels, 128, 3, padding=1),
                nn.AdaptiveAvgPool2d((4,4)),
                nn.Flatten()
            )
            self.social_branch = nn.Sequential(
                *conv_block(self.backbone_channels, 128, 3, dilation=2, padding=2),
                *conv_block(128, 128, 3, padding=1),
                nn.AdaptiveAvgPool2d((4,4)),
                nn.Flatten()
            )
            self.branch_channels = 128*4*4

        elif self.branch_type == 'special':
            self.room_branch = nn.Sequential(
                CoordConv(self.backbone_channels, 128, 3),
                BatchNorm(128),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((4,4)),
                nn.Flatten()
            )
            self.social_branch = nn.Sequential(
                *conv_block(self.backbone_channels, 128, 3, dilation=2, padding=2),
                SpatialMultiheadAttention(128, 4),
                *conv_block(128, 128, 3, padding=1),
                nn.AdaptiveAvgPool2d((4,4)),
                nn.Flatten()
            )
            self.branch_channels = 128*4*4

        elif self.branch_type == 'adapted_adversarial':
            def adapted_branch():
                return nn.Sequential(
                    *conv_block(self.backbone_channels, 256, 5, padding=2),
                    nn.MaxPool2d(2),
                    *conv_block(256, 512, 5, padding=2),
                    nn.MaxPool2d(2),
                    nn.Flatten()
                )
            self.social_branch = adapted_branch()
            self.room_branch = adapted_branch()
            # Flattened branch output size for each backbone type
            adversarial_channels = {
                'none': 512 * 128 * 72,       # 4,718,592
                '2conv': 512 * 32 * 18,       # 294,912
                '3conv': 512 * 16 * 9,        # 73,728
                'pretrained': 512 * 4 * 2     # 4,096
            }
            self.branch_channels = adversarial_channels[self.backbone_type]

        elif self.branch_type == 'adversarial':
            def adversarial_branch():
                return nn.Sequential(
                    *conv_block(self.backbone_channels, 96, 5, padding=2),
                    nn.MaxPool2d(2),
                    *conv_block(96, 144, 3, padding=1),
                    nn.MaxPool2d(2),
                    *conv_block(144, 256, 5, padding=2),
                    nn.MaxPool2d(2),
                    nn.Flatten()
                )
            self.social_branch = adversarial_branch()
            self.room_branch = adversarial_branch()
            # Flattened branch output size for each backbone type
            adversarial_channels = {
                'none': 256 * 64 * 36,        # 589,824
                '2conv': 256 * 16 * 9,        # 36,864
                '3conv': 256 * 8 * 4,         # 8,192
                'pretrained': 256 * 2 * 1     # 512
            }
            self.branch_channels = adversarial_channels[self.backbone_type]

        else:
            raise ValueError("branch must be 'linear' | 'simple' | 'special' | 'adversarial' | 'adapted_adversarial'")

        # ---- domain classifiers ----
        if self.end_type == 'simple':
            self.room_domain_cls = nn.Linear(self.branch_channels, num_domains)
            self.social_domain_cls = nn.Linear(self.branch_channels, 1)
            head_hidden = [256]
        elif self.end_type == 'adversarial':
            self.room_domain_cls = mlp([self.branch_channels, 1024, 1024, num_domains])
            self.social_domain_cls = mlp([self.branch_channels, 1024, 1024, 1])
            head_hidden = [512]
        elif self.end_type == '3linear':
            self.room_domain_cls = mlp([self.branch_channels, 512, 256, num_domains])
            self.social_domain_cls = mlp([self.branch_channels, 512, 256, 1])
            head_hidden = [512, 256]
        else:
            raise ValueError("end must be 'simple' | '3linear' | 'adversarial'")

        # ---- head ----
        total_channels = self.branch_channels*2
        if self.detach_base:
            # forward() feeds a projection of the backbone output to the head
            # whenever detach_base is set, so the projection (and the wider
            # head input) must exist for every end type, not only '3linear'.
            self.backbone_proj = nn.Sequential(
                nn.AdaptiveAvgPool2d((4,4)),
                nn.Flatten(),
                nn.Linear(self.backbone_channels*4*4, self.branch_channels),
                nn.ReLU(),
            )
            total_channels += self.branch_channels
        self.head = mlp([total_channels, *head_hidden, num_outputs])

    def forward(self, x, alpha=1.0):
        """
        Args:
            x: input batch passed to the backbone.
            alpha: gradient reversal factor applied in front of the social
                domain classifier.

        Returns:
            dict with ``'output'`` (regression scores), ``'invariant_domain'``
            (social domain logit), ``'specific_domain'`` (room domain logits),
            ``'invariant_feats'`` (social features) and ``'specific_feats'``
            (room features).
        """
        base = self.backbone(x)

        if self.detach_base:
            # Branches see a detached copy; the backbone output reaches the
            # head only through the projection.
            room_feat = self.room_branch(base.detach())
            social_feat = self.social_branch(base.detach())
            proj_base = self.backbone_proj(base)
            scores = self.head(torch.cat([room_feat, social_feat, proj_base], 1))
        else:
            room_feat = self.room_branch(base)
            social_feat = self.social_branch(base)
            scores = self.head(torch.cat([room_feat, social_feat], 1))

        room_domain_cls = self.room_domain_cls(room_feat)
        social_domain_cls = self.social_domain_cls(GradientReversalFunction.apply(social_feat, alpha))

        return {
            'output': scores,
            'invariant_domain': social_domain_cls,
            'specific_domain': room_domain_cls,
            'invariant_feats': social_feat,
            'specific_feats': room_feat
        }
        

In [None]:
class DualBranchNet_DANN(nn.Module):
    """Single-branch DANN model: feature extractor, task head and domain classifier.

    The feature extractor (``social_branch``) is either a MobileNetV2 feature
    stack or a small three-convolution network. The task head and the domain
    classifier both read the flattened features; the domain classifier sees
    them through ``GradientReversalFunction``.

    Args:
        num_outputs: Number of values predicted by the task head.
        backbone_type: ``'3conv'`` or ``'mobilenet'``.

    Raises:
        ValueError: If ``backbone_type`` is not one of the supported values.
    """

    def __init__(self, num_outputs=9, backbone_type='3conv'):
        super().__init__()
        self.backbone_type = backbone_type

        # Input RGB of size 512,288 (per the original author's note).
        if backbone_type == 'mobilenet':
            mobilenet = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
            self.social_branch = nn.Sequential(
                mobilenet.features,
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten()
            )
            self.backbone_channels = 1280
        elif backbone_type == '3conv':
            # (in_channels, out_channels, kernel_size); padding keeps the
            # spatial size, each MaxPool2d halves it.
            stages = []
            for c_in, c_out, kernel in ((3, 96, 5), (96, 144, 3), (144, 256, 5)):
                stages.append(nn.Conv2d(c_in, c_out, kernel, padding=kernel // 2))
                stages.append(nn.ReLU())
                stages.append(nn.MaxPool2d(2))
            stages.append(nn.AdaptiveAvgPool2d(output_size=(4, 4)))
            stages.append(nn.Flatten())
            self.social_branch = nn.Sequential(*stages)
            self.backbone_channels = 256*4*4
        else:
            raise ValueError("backbone_type must be '3conv' | 'mobilenet'")

        feature_dim = self.backbone_channels

        # Binary domain classifier (single logit).
        self.social_domain_cls = nn.Sequential(
            nn.Linear(feature_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1)
        )

        # Task head.
        self.head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, num_outputs)
        )

    def forward(self, x, alpha=1.0, is_first_domain=False):
        """Return task predictions, domain logits and the extracted features.

        Args:
            x: Input image batch.
            alpha: Scale passed to ``GradientReversalFunction``.
            is_first_domain: When true the domain classifier is skipped and
                ``invariant_domain`` is ``None``.

        Returns:
            dict with keys ``output``, ``invariant_domain`` and
            ``invariant_feats``.
        """
        social_feat = self.social_branch(x)
        scores = self.head(social_feat)

        if is_first_domain:
            domain_logits = None
        else:
            reversed_feat = GradientReversalFunction.apply(social_feat, alpha)
            domain_logits = self.social_domain_cls(reversed_feat)

        return {
            'output': scores,
            'invariant_domain': domain_logits,
            'invariant_feats': social_feat
        }


In [None]:
class CoordConv(nn.Module):
    """Convolution that sees two extra channels holding normalized x/y positions.

    Args:
        in_channels: Channels of the incoming feature map (without the two
            coordinate channels).
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        # +2 input channels for the appended x and y coordinate maps.
        self.conv = nn.Conv2d(in_channels + 2,
                              out_channels,
                              kernel_size,
                              padding=kernel_size//2)

    def forward(self, x):
        """Append coordinate channels (range [-1, 1]) to ``x`` and convolve.

        Args:
            x: Tensor of shape ``[batch, in_channels, h, w]``.

        Returns:
            Output of the convolution applied to ``[x, x_coord, y_coord]``
            concatenated along the channel dimension.
        """
        batch, _, h, w = x.shape

        # Build the 1-D ramps directly on the input's device and cast them to
        # its dtype, so no host->device copy happens per call and the
        # concatenation never promotes a half/bfloat16 input to float32.
        xs = torch.linspace(-1, 1, w, device=x.device).to(x.dtype)
        ys = torch.linspace(-1, 1, h, device=x.device).to(x.dtype)

        # expand() creates broadcast views instead of materializing copies.
        x_coord = xs.view(1, 1, 1, w).expand(batch, 1, h, w)
        y_coord = ys.view(1, 1, h, 1).expand(batch, 1, h, w)

        return self.conv(torch.cat([x, x_coord, y_coord], dim=1))


In [None]:
class SpatialMultiheadAttention(nn.Module):
    """Multi-head self-attention over an 8x8 pooled version of a feature map.

    The input is average-pooled to 8x8, the 64 positions attend to each other,
    and the result is bilinearly resized back to the input's spatial size.

    Args:
        embed_dim: Channel count of the feature map (attention embedding size).
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((8, 8))
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        """Return the attended feature map with the same shape as ``x``."""
        batch, channels, height, width = x.shape

        # [B, C, 8, 8] -> [B, 64, C]: one token per pooled spatial position.
        tokens = self.pool(x).flatten(2).transpose(1, 2)
        attended = self.mha(tokens, tokens, tokens)[0]

        # [B, 64, C] -> [B, C, 8, 8], then back to the original resolution.
        grid = attended.transpose(1, 2).view(batch, channels, 8, 8)
        return F.interpolate(grid, size=(height, width), mode='bilinear')

In [None]:
class DualBranchNet_minimal(nn.Module):
    """Minimal backbone -> branch -> head network with selectable components.

    Args:
        num_outputs: Number of values predicted by the head.
        backbone_type: ``'absurd'`` (one conv layer), ``'3conv'`` or
            ``'pretrained'`` (frozen MobileNetV2 features).
        branch_type: ``'simple'``, ``'special'`` (adds spatial attention),
            ``'1convo'`` or ``'absurd'``.
        head_type: ``'3layer'``, ``'2layer256'`` or ``'absurd'``.
        batch_norm: Insert ``BatchNorm2d`` after convolutions when true;
            otherwise ``nn.Identity`` is used so layer indices stay the same.

    Raises:
        ValueError: If any of the three type strings is not supported.
    """

    _BACKBONE_TYPES = ('absurd', '3conv', 'pretrained')
    _BRANCH_TYPES = ('simple', 'special', '1convo', 'absurd')
    _HEAD_TYPES = ('3layer', '2layer256', 'absurd')

    def __init__(self, num_outputs=9, backbone_type='3conv', branch_type='simple', head_type='3layer', batch_norm=False):
        super().__init__()
        self.backbone_type = backbone_type
        self.branch_type = branch_type
        self.head_type = head_type

        # Fail early (before any weights are built or downloaded) instead of
        # leaving a half-constructed module that crashes later in forward().
        for label, value, allowed in (
            ('backbone_type', backbone_type, self._BACKBONE_TYPES),
            ('branch_type', branch_type, self._BRANCH_TYPES),
            ('head_type', head_type, self._HEAD_TYPES),
        ):
            if value not in allowed:
                options = ' | '.join(f"'{name}'" for name in allowed)
                raise ValueError(f"{label} must be {options}")

        def norm(channels):
            return nn.BatchNorm2d(channels) if batch_norm else nn.Identity()

        # Input resized to 512,288 (per the original author's note).
        if backbone_type == 'absurd':
            self.backbone = nn.Sequential(
                nn.Conv2d(3, 16, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2)
            )
            # Previously missing, which made every non-'absurd' branch fail
            # with AttributeError when combined with this backbone.
            self.backbone_channels = 16

        elif backbone_type == '3conv':
            self.backbone = nn.Sequential(
                nn.Conv2d(3, 32, 3, padding=1),
                norm(32),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(32, 64, 3, padding=1),
                norm(64),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(64, 128, 3, padding=1),
                norm(128),
                nn.ReLU(),
                nn.MaxPool2d(2)
            )
            self.backbone_channels = 128

        else:  # 'pretrained'
            self.backbone = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1).features
            for param in self.backbone.parameters():
                param.requires_grad = False
            self.backbone_channels = 1280

        if branch_type == 'simple':
            self.social_branch = nn.Sequential(
                nn.Conv2d(self.backbone_channels, 128, 3, dilation=2, padding=2),
                norm(128),
                nn.ReLU(),
                nn.Conv2d(128, 128, 3, padding=1),
                norm(128),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((4,4)),
                nn.Flatten()
            )
            feature_dim = 128*4*4

        elif branch_type == 'special':
            self.social_branch = nn.Sequential(
                nn.Conv2d(self.backbone_channels, 128, 3, dilation=2, padding=2),
                norm(128),
                nn.ReLU(),
                SpatialMultiheadAttention(128, 4),
                nn.Conv2d(128, 128, 3, padding=1),
                norm(128),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((4,4)),
                nn.Flatten()
            )
            feature_dim = 128*4*4

        elif branch_type == '1convo':
            self.social_branch = nn.Sequential(
                nn.Conv2d(self.backbone_channels, 128, 3, padding=1),
                norm(128),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((4,4)),
                nn.Flatten()
            )
            feature_dim = 128*4*4

        else:  # 'absurd'
            # Reads the backbone width instead of a hard-coded 16, so this
            # branch also works on top of the other backbones.
            self.social_branch = nn.Sequential(
                nn.Conv2d(self.backbone_channels, 32, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten()
            )
            feature_dim = 32

        # Width of the flattened branch output; every head is sized from it
        # so any branch/head combination is shape-consistent.
        self.feature_dim = feature_dim

        if head_type == '3layer':
            self.head = nn.Sequential(
                nn.Linear(feature_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, num_outputs)
            )
        elif head_type == '2layer256':
            self.head = nn.Sequential(
                nn.Linear(feature_dim, 256),
                nn.ReLU(),
                nn.Linear(256, num_outputs)
            )
        else:  # 'absurd'
            self.head = nn.Sequential(
                nn.Linear(feature_dim, num_outputs)
            )

    def forward(self, x, alpha=1.0):
        """Return head predictions and the branch features.

        Args:
            x: Input image batch.
            alpha: Unused; accepted so the call signature matches the
                adversarial models in this file.

        Returns:
            dict with keys ``output`` and ``invariant_feats``.
        """
        base = self.backbone(x)

        social_feat = self.social_branch(base)
        scores = self.head(social_feat)

        return {
            'output': scores,
            'invariant_feats': social_feat,
        }
        

In [None]:
import torchvision.models.segmentation as segmentation

class DANN_classifier_poc(nn.Module):
    """Proof-of-concept DANN model on top of a pretrained DeepLabV3-MobileNetV3.

    The segmentation backbone and classifier produce a per-class score map,
    which is pooled to a fixed grid, flattened, batch-normalized and fed to a
    9-output task head and a single-logit domain classifier (behind
    ``GradientReversalFunction``).
    """

    # The flattened feature size used by bn/head/domain classifier is
    # 4116 = 21 * 14 * 14: 21 score channels (the class count of the DEFAULT
    # DeepLabV3 weights per the torchvision docs) on a 14x14 grid.
    # Pooling to this grid keeps the feature size fixed for any input size.
    POOL_SIZE = (14, 14)
    FEATURE_DIM = 4116

    def __init__(self):
        super().__init__()
        # Only `weights=` is passed: the legacy `pretrained=True` flag is
        # deprecated and redundant once an explicit weights enum is given.
        pretrained_model = segmentation.deeplabv3_mobilenet_v3_large(
            weights=segmentation.DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT
        )
        self.backbone = pretrained_model.backbone
        self.classifier = pretrained_model.classifier
        self.bn = nn.BatchNorm1d(self.FEATURE_DIM)

        self.social_domain_cls = nn.Sequential(
            nn.Linear(self.FEATURE_DIM, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1)
        )
        self.head = nn.Sequential(
            nn.Linear(self.FEATURE_DIM, 512),
            nn.ReLU(),
            nn.Linear(512, 9)
        )

    def forward(self, x, alpha=1.0, is_first_domain=False):
        """Return task predictions, domain logits and the normalized features.

        Args:
            x: Input image batch.
            alpha: Scale passed to ``GradientReversalFunction``.
            is_first_domain: When true the domain classifier is skipped and
                ``invariant_domain`` is ``None``.

        Returns:
            dict with keys ``output``, ``invariant_domain`` and
            ``invariant_feats``.
        """
        features = self.backbone(x)['out']
        features = self.classifier(features)
        # Pool to POOL_SIZE (not 1x1): pooling to a single cell leaves only
        # one value per channel, which cannot match the 4116-wide layers.
        features = F.adaptive_avg_pool2d(features, self.POOL_SIZE)
        features = torch.flatten(features, 1)
        features = self.bn(features)

        domain = None if is_first_domain else self.social_domain_cls(GradientReversalFunction.apply(features, alpha))
        output = self.head(features)
        return {
            'output': output,
            'invariant_domain': domain,
            'invariant_feats': features,
        }


### ForgettingGatedSplitModel

In [None]:
import os
import sys
import cv2
import pickle
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics import YOLO
from tqdm.notebook import tqdm, trange
from typing import Tuple, Optional
from collections import deque

sys.path.append('D:/projects/Depth-Anything-V2')
from depth_anything_v2.dpt import DepthAnythingV2

#### backbone

In [None]:
class YOLO11FeatureExtractor(nn.Module):
    """Extract features from YOLO11-seg backbone (layers 0-8).

    The extracted layers are frozen: their parameters do not require
    gradients, and the backbone is kept in eval mode even when the enclosing
    model is switched to training, so normalization statistics and any
    dropout inside it stay fixed.

    Args:
        model_path: Path to the YOLO11-seg checkpoint loaded via ``YOLO``.
    """

    def __init__(self, model_path='../models/yolo11n-seg.pt'):
        super().__init__()
        yolo = YOLO(model_path)

        # Extract layers 0-8 (before SPPF to preserve spatial granularity)
        self.backbone = nn.Sequential(*list(yolo.model.model[:9]))
        del yolo

        # Freeze parameters
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

    def train(self, mode=True):
        """Set the module's training flag but keep the frozen backbone in eval.

        ``requires_grad = False`` alone does not stop BatchNorm layers from
        updating their running statistics in training mode, so the backbone
        is pinned to eval mode here.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Return the backbone feature map for image batch ``x``."""
        # Original author's note: output is [B, 256, H/32, W/32].
        return self.backbone(x)

class DepthAnythingFeatureExtractor(nn.Module):
    """Frozen Depth-Anything-V2 (ViT-S) encoder returning a spatial feature map.

    Args:
        model_path: Path to the Depth-Anything-V2 ``vits`` checkpoint.
    """

    def __init__(self, model_path='../models/depth_anything_v2_vits.pth'):
        super().__init__()
        vits_config = {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}
        depth_model = DepthAnythingV2(**vits_config)
        # NOTE: torch.load unpickles the checkpoint; only load trusted files.
        checkpoint = torch.load(model_path, map_location='cpu')
        depth_model.load_state_dict(checkpoint)

        # Keep only the encoder; the rest of the depth model is discarded.
        self.backbone = depth_model.pretrained
        del depth_model

        for weight in self.backbone.parameters():
            weight.requires_grad = False

    def forward(self, x):
        """Encode ``x`` and return features shaped ``[B, 384, H // 14, W // 14]``.

        Runs entirely under ``torch.no_grad()``, so the result carries no
        autograd history.
        """
        batch, _, height, width = x.shape
        grid_h, grid_w = height // 14, width // 14

        with torch.no_grad():
            # NOTE(review): only patch_embed, blocks and norm are invoked
            # here; if the encoder adds class tokens or positional embeddings
            # in a separate step, they are not applied -- confirm against the
            # Depth-Anything-V2 encoder implementation.
            tokens = self.backbone.patch_embed(x)
            for layer in self.backbone.blocks:
                tokens = layer(tokens)
            tokens = self.backbone.norm(tokens)

            # More tokens than grid cells: drop the leading token
            # (presumably the CLS token) so only patch tokens remain.
            if tokens.shape[1] > grid_h * grid_w:
                tokens = tokens[:, 1:, :]

            # [B, N, 384] -> [B, 384, grid_h, grid_w]
            return tokens.transpose(1, 2).reshape(batch, 384, grid_h, grid_w)


class DualBackboneFeatureExtractor(nn.Module):
    """Combines YOLO11 and DepthAnything feature extractors.

    Args:
        model_paths: Mapping with the keys ``'YOLO11n_seg'`` and
            ``'DepthAnythingV2_small'`` pointing at the two checkpoints.
    """

    def __init__(self, model_paths):
        super().__init__()
        # Individual feature extractors
        self.yolo_extractor = YOLO11FeatureExtractor(model_paths['YOLO11n_seg'])
        self.depth_extractor = DepthAnythingFeatureExtractor(model_paths['DepthAnythingV2_small'])

        # Trainable 1x1 projections applied on top of each extractor's output
        self.yolo_proj = nn.Conv2d(256, 256, 1)
        self.depth_proj = nn.Conv2d(384, 384, 1)

    def forward(self, x):
        """Extract, project and concatenate both feature maps.

        Args:
            x: Pair ``(yolo_images, depth_images)``; element 0 goes to the
                YOLO extractor and element 1 to the depth extractor.

        Returns:
            Tensor with 256 + 384 = 640 channels on the YOLO feature map's
            spatial grid.
        """
        # Extract features from both backbones
        yolo_features = self.yolo_extractor(x[0])  # [B, 256, H/32, W/32] per the extractor's note
        depth_features = self.depth_extractor(x[1])  # [B, 384, H/14, W/14]

        # Apply projections
        yolo_features = self.yolo_proj(yolo_features)
        depth_features = self.depth_proj(depth_features)

        # Downsample the depth features to the YOLO grid. Using the actual
        # YOLO map size (7x7 for 224x224 inputs) instead of a hard-coded
        # (7, 7) keeps the concatenation valid for other input resolutions.
        depth_features = F.adaptive_avg_pool2d(depth_features, yolo_features.shape[-2:])

        # Concatenate along channels: [B, 640, h_yolo, w_yolo]
        combined_features = torch.cat([yolo_features, depth_features], dim=1)

        return combined_features

#### model components

In [None]:
class CatastrophicForgettingAdapter(nn.Module):
    """Standalone adapter that measures forgetting and gates spatial features.

    The adapter head predicts the main task from globally pooled features.
    Per-channel "forgetting scores" are derived from how much the adapter's
    channel-facing weight matrix changed since ``store_adapter_state`` was
    called; a small gating network turns those scores into a per-channel
    mask that splits the feature map into two complementary parts.

    Args:
        num_channels: Channel count of the incoming feature map.
        num_outputs: Number of task outputs predicted by the adapter.
    """

    def __init__(self, num_channels=640, num_outputs=9):
        super().__init__()
        self.num_channels = num_channels

        # Forgetting measurement storage (per channel). These are plain
        # attributes, not buffers: they are neither saved in state_dict nor
        # moved by .to(); the methods below move them where needed.
        self.previous_adapter_state = None
        self.current_forgetting_scores = None

        # Adapter that predicts main task (for forgetting measurement)
        self.adapter = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # Global average pooling
            nn.Flatten(),
            nn.Linear(num_channels, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_outputs)  # Predicts main task
        )

        # Learnable gating network (operates on channel-wise forgetting scores)
        self.gating = nn.Sequential(
            nn.Linear(num_channels, num_channels // 4),
            nn.ReLU(),
            nn.Linear(num_channels // 4, num_channels),
            nn.Sigmoid()
        )

    def forward(self, backbone_features, use_gating=True):
        """Predict the task and split the features with the channel gate.

        Args:
            backbone_features: ``[B, num_channels, H, W]`` spatial features.
            use_gating: Apply the gate when forgetting scores are available.

        Returns:
            Tuple ``(forgetting_features, not_forgetting_features,
            adapter_predictions)``. The two feature tensors have the shape of
            ``backbone_features``; with gating they are ``features * mask``
            and ``features * (1 - mask)``, otherwise both are the unmodified
            input. ``adapter_predictions`` is ``[B, num_outputs]``.
        """
        # Adapter makes predictions (for forgetting measurement)
        adapter_predictions = self.adapter(backbone_features)

        scores = self.current_forgetting_scores
        if use_gating and scores is not None:
            # The scores may have been computed before the model was moved.
            scores = scores.to(backbone_features.device)

            # The mask depends only on the scores, not on the batch, so it is
            # computed once and broadcast over batch and spatial dimensions.
            gating_mask = self.gating(scores.unsqueeze(0))  # [1, C]
            gating_mask = gating_mask.view(1, -1, 1, 1)      # [1, C, 1, 1]

            # Route features channel-wise
            forgetting_features = backbone_features * gating_mask
            not_forgetting_features = backbone_features * (1 - gating_mask)
        else:
            # No gating - both branches get all features
            forgetting_features = backbone_features
            not_forgetting_features = backbone_features

        return forgetting_features, not_forgetting_features, adapter_predictions

    def store_adapter_state(self):
        """Snapshot the adapter parameters for later forgetting measurement."""
        self.previous_adapter_state = {
            name: param.clone().detach()
            for name, param in self.adapter.named_parameters()
        }

    @torch.no_grad()
    def compute_forgetting_scores(self):
        """Compute channel-wise forgetting scores from the adapter weight drift.

        The score of a channel is the L2 norm (over output units) of the
        change in the first 2-D adapter weight whose input size equals
        ``num_channels``, normalized so the largest score is 1. Without a
        stored snapshot, or if no such weight is found, all scores are zero.

        Runs under ``torch.no_grad()``: the scores are a measurement, and
        must not carry an autograd graph back into the adapter parameters
        (which would leak gradients into the adapter through the gate and
        tie every later backward pass to a stale graph).
        """
        device = next(self.parameters()).device
        scores = torch.zeros(self.num_channels, device=device)

        if self.previous_adapter_state is not None:
            for name, current_param in self.adapter.named_parameters():
                previous_param = self.previous_adapter_state.get(name)
                if previous_param is None:
                    continue
                maps_from_channels = (
                    'weight' in name
                    and current_param.dim() == 2
                    and current_param.shape[1] == self.num_channels
                )
                if maps_from_channels:
                    # Change per input channel, e.g. weight [256, C] -> [C]
                    drift = current_param - previous_param.to(current_param.device)
                    scores = torch.norm(drift, dim=0).to(device)
                    break  # Use only first layer that maps from channels

            # Normalize to [0, 1]; an all-zero drift is left as zeros.
            peak = scores.max()
            if peak > 0:
                scores = scores / peak

        self.current_forgetting_scores = scores

class ConvBranch(nn.Module):
    """Three conv-BN-ReLU stages followed by global average pooling.

    Args:
        input_channels: Channels of the incoming feature map.
        hidden_channels: Channels of the two intermediate stages.
        output_channels: Channels of the last stage, i.e. the length of the
            returned feature vector.
    """

    def __init__(self, input_channels=640, hidden_channels=256, output_channels=128):
        super().__init__()

        def conv_bn_relu(c_in, c_out):
            # One 3x3 stage; padding=1 keeps the spatial size.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        stages = []
        # The first two stages are followed by light spatial dropout.
        stages += conv_bn_relu(input_channels, hidden_channels)
        stages.append(nn.Dropout2d(0.1))
        stages += conv_bn_relu(hidden_channels, hidden_channels)
        stages.append(nn.Dropout2d(0.1))
        # The output stage has no dropout.
        stages += conv_bn_relu(hidden_channels, output_channels)
        # Global pooling + flatten give a fixed-size [B, output_channels] vector.
        stages += [nn.AdaptiveAvgPool2d(1), nn.Flatten()]

        self.conv_layers = nn.Sequential(*stages)

    def forward(self, x):
        """Map a ``[B, input_channels, H, W]`` tensor to ``[B, output_channels]``."""
        return self.conv_layers(x)
    
    
class FeatureFusion(nn.Module):
    """Fuse two feature vectors with a learned two-way softmax weighting.

    Args:
        feature_dim: Length of each of the two input feature vectors.
    """

    def __init__(self, feature_dim=256):
        super().__init__()

        # Maps the concatenated pair to two weights that sum to 1 per sample.
        self.attention = nn.Sequential(
            nn.Linear(feature_dim * 2, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, 2),
            nn.Softmax(dim=1)
        )

    def forward(self, high_features, low_features):
        """Return ``w_high * high_features + w_low * low_features``.

        Both inputs are ``[B, feature_dim]``; the per-sample weights come from
        the attention network applied to their concatenation.
        """
        weights = self.attention(torch.cat((high_features, low_features), dim=1))  # [B, 2]
        w_high, w_low = weights[:, :1], weights[:, 1:]
        return high_features * w_high + low_features * w_low


#### model

In [None]:
class CatastrophicForgettingDisentanglementModel(nn.Module):
    """Dual-backbone model with a forgetting-gated two-branch head.

    Pipeline: dual backbone -> forgetting adapter (channel gate + auxiliary
    task prediction) -> two convolutional branches -> attention fusion ->
    MLP head.

    Args:
        backbone_output_channels: Channels produced by the dual backbone.
        branch_hidden_channels: Hidden width of each convolutional branch.
        branch_output_channels: Output width of each branch (and fusion).
        num_outputs: Number of predicted task values.

    Raises:
        FileNotFoundError: If one of the backbone checkpoint files is missing.
    """

    def __init__(self,
                 backbone_output_channels: int = 640,
                 branch_hidden_channels: int = 256,
                 branch_output_channels: int = 128,
                 num_outputs: int = 9):
        super().__init__()

        print("🚀 Initializing Spatial Catastrophic Forgetting Model...")

        # Checkpoints for the two backbone extractors; checked up front so a
        # missing file produces one clear error naming every absent model.
        model_paths = {
            'YOLO11n_seg': '../models/yolo11n-seg.pt',
            'DepthAnythingV2_small': '../models/depth_anything_v2_vits.pth',
        }
        missing = []
        for name, path in model_paths.items():
            if not os.path.exists(path):
                missing.append(name)
        if missing:
            raise FileNotFoundError(f"Missing model files for: {', '.join(missing)}")

        # Dual backbone producing spatial features
        self.backbone = DualBackboneFeatureExtractor(model_paths=model_paths)

        # Standalone forgetting adapter
        self.forgetting_adapter = CatastrophicForgettingAdapter(
            num_channels=backbone_output_channels,
            num_outputs=num_outputs
        )

        # Two identically-shaped convolutional branches
        branch_dims = (backbone_output_channels, branch_hidden_channels, branch_output_channels)
        self.branch_forgetting = ConvBranch(*branch_dims)
        self.branch_not_forgetting = ConvBranch(*branch_dims)

        self.fusion = FeatureFusion(branch_output_channels)

        # Final head on the fused branch features: two Linear-BN-ReLU-Dropout
        # blocks (512 and 256 wide) followed by the output projection.
        head_layers = []
        width_in = branch_output_channels
        for width_out, drop_p in ((512, 0.3), (256, 0.2)):
            head_layers += [
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(drop_p),
            ]
            width_in = width_out
        head_layers.append(nn.Linear(width_in, num_outputs))
        self.head = nn.Sequential(*head_layers)

        print("✅ Spatial model initialized successfully!")

    def forward(self, x, use_gating: bool = True):
        """Run the full pipeline.

        Args:
            x: Input passed unchanged to the dual backbone.
            use_gating: Forwarded to the forgetting adapter.

        Returns:
            dict with keys ``output`` (head predictions), ``adapter_output``
            (adapter task predictions), ``backbone_features``,
            ``forgetting_features`` and ``not_forgetting_features``.
        """
        backbone_features = self.backbone(x)

        # Gate the features and get the adapter's auxiliary predictions
        adapter_result = self.forgetting_adapter(backbone_features, use_gating)
        forgetting_features, not_forgetting_features, adapter_predictions = adapter_result

        # One branch per routed feature map
        forgetting_output = self.branch_forgetting(forgetting_features)
        not_forgetting_output = self.branch_not_forgetting(not_forgetting_features)

        # Fuse the branch vectors and predict
        final_output = self.head(self.fusion(forgetting_output, not_forgetting_output))

        return {
            'output': final_output,
            'adapter_output': adapter_predictions,
            'backbone_features': backbone_features,
            'forgetting_features': forgetting_features,
            'not_forgetting_features': not_forgetting_features
        }

    def store_adapter_state(self):
        """Forward to ``forgetting_adapter.store_adapter_state``."""
        self.forgetting_adapter.store_adapter_state()

    def compute_forgetting_scores(self):
        """Forward to ``forgetting_adapter.compute_forgetting_scores``."""
        self.forgetting_adapter.compute_forgetting_scores()

    @property
    def current_forgetting_scores(self):
        """Forgetting scores currently held by the adapter (may be ``None``)."""
        return self.forgetting_adapter.current_forgetting_scores


#### testing

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Initialize your model
# cf_model = CatastrophicForgettingDisentanglementModel(
#     backbone_output_channels=640,
#     branch_hidden_channels=256,
#     branch_output_channels=128,
#     num_outputs=9
# ).to(device)

# # Set model to evaluation mode (important for testing)
# cf_model.eval()

# # Create dummy input tensor matching your expected input size
# # Your model expects: [batch_size, 3, height, width]
# # Using 224x224 as we discussed for memory efficiency
# batch_size = 2
# dummy1 = torch.randn(batch_size, 3, 224, 224).to(device)
# dummy2 = torch.randn(batch_size, 3, 224, 224).to(device)
# dummy_input = (dummy1, dummy2)

# print(f"Input shape: {dummy_input[0].shape}")

# # Test forward pass without gating (first domain)
# with torch.no_grad():
#     outputs_no_gating = cf_model(dummy_input, use_gating=False)

# print("✅ Forward pass without gating successful!")
# print(f"Output shape: {outputs_no_gating['output'].shape}")
# print(f"Adapter output shape: {outputs_no_gating['adapter_output'].shape}")
# print(f"Backbone features shape: {outputs_no_gating['backbone_features'].shape}")

# # Test forward pass with gating (after first domain)
# # First, simulate having forgetting scores
# cf_model.forgetting_adapter.current_forgetting_scores = torch.randn(640).to(device)

# with torch.no_grad():
#     outputs_with_gating = cf_model(dummy_input, use_gating=True)

# print("✅ Forward pass with gating successful!")
# print(f"Forgetting features shape: {outputs_with_gating['forgetting_features'].shape}")
# print(f"Not forgetting features shape: {outputs_with_gating['not_forgetting_features'].shape}")

# # Test forgetting measurement functions
# cf_model.store_adapter_state()
# print("✅ Adapter state stored successfully!")

# cf_model.compute_forgetting_scores()
# print("✅ Forgetting scores computed successfully!")
# print(f"Forgetting scores shape: {cf_model.current_forgetting_scores.shape}")

# outputs = cf_model(dummy_input, use_gating=True)
# print(outputs['output'].device)

#### helper function

In [None]:
def evaluate_model_cf(model, dataloader, criterion, device):
    """Return the sample-weighted mean loss of ``model`` over ``dataloader``.

    The model is switched to eval mode (and left there) and called with
    ``use_gating=True``. Each batch must be a 4-tuple
    ``(yolo_images, depth_images, labels, _)``; the tensors are moved to
    ``device`` as float32 and the loss is computed on ``outputs['output']``.
    """
    model.eval()
    loss_sum = 0.0
    seen = 0
    with torch.no_grad():
        for batch in dataloader:
            yolo_images, depth_images, labels, _ = batch
            inputs = (
                yolo_images.to(device, dtype=torch.float32),
                depth_images.to(device, dtype=torch.float32),
            )
            targets = labels.to(device, dtype=torch.float32)

            predictions = model(inputs, use_gating=True)['output']  # gating stays on for evaluation
            batch_size = inputs[0].size(0)

            # Weight each batch's mean loss by its size so the final value is
            # a per-sample average even when the last batch is smaller.
            loss_sum += criterion(predictions, targets).item() * batch_size
            seen += batch_size
    return loss_sum / seen

def cross_domain_validation_cf(model, domain_dataloaders, criterion, device):
    """Evaluate ``model`` on every domain's validation loader.

    Returns a dict mapping each domain key of ``domain_dataloaders`` to the
    loss returned by ``evaluate_model_cf`` on that domain's ``'val'`` loader.
    """
    return {
        domain: evaluate_model_cf(model, loaders['val'], criterion, device)
        for domain, loaders in domain_dataloaders.items()
    }

def average_metrics(metrics_list):
    """Average a list of metric dicts key by key.

    The keys are taken from the first dict; for each key the mean is computed
    over the dicts that contain it. An empty list yields an empty dict.
    """
    if not metrics_list:
        return {}
    return {
        key: float(np.mean([entry[key] for entry in metrics_list if key in entry]))
        for key in metrics_list[0].keys()
    }

#### training

In [None]:
def compute_gating_loss(forgetting_scores: torch.Tensor, 
                       gating_mask: torch.Tensor,
                       lambda_balance: float = 1.0) -> torch.Tensor:
    """Loss that pushes the gate to select high-forgetting channels.

    ``split_loss`` is the negated gap between the mask-weighted mean score
    and the (1 - mask)-weighted mean score; ``balance_loss`` is the distance
    of the mean mask value from 0.5. The result is
    ``split_loss + lambda_balance * balance_loss``.
    """
    eps = 1e-8  # guards the divisions when a mask sums to zero
    inverse_mask = 1 - gating_mask

    high_mean = (forgetting_scores * gating_mask).sum() / (gating_mask.sum() + eps)
    low_mean = (forgetting_scores * inverse_mask).sum() / (inverse_mask.sum() + eps)

    split_loss = -(high_mean - low_mean)
    balance_loss = torch.abs(gating_mask.mean() - 0.5)

    return split_loss + lambda_balance * balance_loss

def catastrophic_forgetting_batch(model, batch, device, loss_params=None, **kwargs):
    """Run one training batch: forward pass, task backward pass, gating loss.

    Args:
        model: Model returning a dict with ``'output'``, ``'adapter_output'``
            and ``'backbone_features'``; it must also expose
            ``current_forgetting_scores`` and ``forgetting_adapter.gating``.
        batch: 4-tuple ``(yolo_images, depth_images, labels, domain_labels)``;
            ``domain_labels`` is not used here.
        device: Device the tensors are moved to.
        loss_params: Optional weights with keys ``'main'``, ``'adapter'`` and
            ``'gating'``. Missing keys fall back to
            ``{'main': 1.0, 'adapter': 0.5, 'gating': 0.1}``. ``'main'`` and
            ``'adapter'`` weight the two task losses in the backward pass;
            ``'gating'`` is not applied here, the gating loss is returned
            unscaled for the caller to use.
        **kwargs: ``mse_criterion`` (required) and ``domain_idx``
            (default 0; gating is only used for ``domain_idx > 0``).

    Returns:
        Tuple ``(main_loss, gating_loss, metrics)`` where ``metrics`` holds
        the unweighted ``main_loss``, ``adapter_loss`` and ``gating_loss``
        as Python floats.

    Side effects:
        Calls ``backward`` on the weighted sum of main and adapter loss, so
        gradients are accumulated on the model parameters.
    """
    # Defaults are built per call (no shared mutable default argument).
    weights = {'main': 1.0, 'adapter': 0.5, 'gating': 0.1}
    if loss_params:
        weights.update(loss_params)

    yolo_images, depth_images, labels, domain_labels = batch
    # Same float32 casting as evaluate_model_cf, so training and evaluation
    # accept the same batches (e.g. float64 labels).
    yolo_images = yolo_images.to(device, dtype=torch.float32)
    depth_images = depth_images.to(device, dtype=torch.float32)
    inputs = (yolo_images, depth_images)
    labels = labels.to(device, dtype=torch.float32)

    mse_criterion = kwargs['mse_criterion']
    domain_idx = kwargs.get('domain_idx', 0)

    # Forward pass
    outputs = model(inputs, use_gating=(domain_idx > 0))

    # Main task loss (from final head)
    main_loss = mse_criterion(outputs['output'], labels)

    # Adapter loss (drives the adapter whose weight drift measures forgetting)
    adapter_loss = mse_criterion(outputs['adapter_output'], labels)

    # Backpropagate both task losses. The adapter only receives gradients
    # through its own loss; without this term it would never be trained and
    # the forgetting scores derived from its weight drift would stay at zero.
    # retain_graph is kept so the caller can still backpropagate gating_loss
    # if it shares graph nodes with this pass.
    task_loss = weights['main'] * main_loss + weights['adapter'] * adapter_loss
    task_loss.backward(retain_graph=True)

    gating_loss = torch.tensor(0.0, device=device)
    scores = model.current_forgetting_scores
    if domain_idx > 0 and scores is not None:
        scores = scores.to(device)

        # The gate's input is identical for every sample, so one row is
        # enough (equal to the batch mean of the per-sample masks).
        gating_mask = model.forgetting_adapter.gating(scores.unsqueeze(0)).squeeze(0)

        gating_loss = compute_gating_loss(scores, gating_mask)

    metrics = {
        'main_loss': main_loss.item(),
        'adapter_loss': adapter_loss.item(),
        'gating_loss': gating_loss.item()
    }

    return main_loss, gating_loss, metrics

def catastrophic_forgetting_train_loop(
    model, domains, domain_dataloaders, buffer, optimizer, gating_optimizer, device,
    batch_fn, batch_kwargs, loss_params, num_epochs=5, exp_name="cf_exp", 
    gradient_clipping=False, restart=None
):
    """Train ``model`` on each domain in turn with replay and a gating network.

    For every domain the loop trains for ``num_epochs`` epochs on the loader
    returned by ``buffer.get_loader_with_replay``, validates on the domain's
    ``'val'`` loader after each epoch and, after the last epoch of the domain,
    runs cross-domain validation, recomputes the forgetting scores and writes
    a checkpoint. The history is pickled after every epoch.

    Args:
        model: Model exposing ``store_adapter_state``,
            ``compute_forgetting_scores`` and ``current_forgetting_scores``.
        domains: Ordered sequence of domain names.
        domain_dataloaders: Mapping ``domain -> {'train': ..., 'val': ...}``.
        buffer: Replay buffer (``get_loader_with_replay``, ``update_buffer``,
            ``get_domain_distribution``).
        optimizer: Optimizer stepped after every batch.
        gating_optimizer: Optimizer stepped with the gating loss.
        device: Device passed on to ``batch_fn`` and the evaluation helpers.
        batch_fn: Callable returning ``(main_loss, gating_loss, metrics)``.
            It must already have back-propagated the main loss, because this
            loop only calls ``optimizer.step()``.
        batch_kwargs: Extra keyword arguments for ``batch_fn``; must contain
            ``'mse_criterion'``, which is also used for validation.
        loss_params: Forwarded unchanged to ``batch_fn``.
        num_epochs: Epochs per domain.
        exp_name: Prefix of the checkpoint and history file names.
        gradient_clipping: If true, clip the gradient norm to 1.0.
        restart: Optional dict with ``'global_step'``, ``'history'`` and
            ``'domain'``; training resumes at that domain after the replay
            buffer has been refilled with all earlier domains.

    Returns:
        The history dictionary.
    """
    checkpoint_dir = "../checkpoints"
    # Create the output directory up front so the first save cannot fail
    # after a whole domain has already been trained.
    os.makedirs(checkpoint_dir, exist_ok=True)

    start_domain_idx = 0
    global_step = 0
    history = {
        'train_epoch_loss': [],
        'val_epoch_loss': [],
        'train_epoch_metrics': [],
        'cross_domain_val': [],
        'grad_norms': [],
        'forgetting_scores_history': [],
        'gating_losses': []
    }

    if restart:
        global_step = restart['global_step']
        history = restart['history']
        # Positional lookup that works for lists as well as numpy arrays.
        start_domain_idx = list(domains).index(restart['domain'])
        # Rebuild the replay buffer from the domains that were already trained.
        for finished_domain in domains[:start_domain_idx]:
            buffer.update_buffer(finished_domain, domain_dataloaders[finished_domain]['train'].dataset)
        print(f"Restarting from domain {restart['domain']} index {start_domain_idx}")
        print(f"Buffer: {buffer.get_domain_distribution()}")

    for domain_idx, current_domain in enumerate(tqdm(domains[start_domain_idx:], desc="Total training"), start=start_domain_idx):
        print(f"\n=== Training on Domain {domain_idx}: {current_domain} ===")

        # Store adapter state before training (for forgetting measurement)
        if domain_idx > 0:
            model.store_adapter_state()

        train_loader = buffer.get_loader_with_replay(current_domain, domain_dataloaders[current_domain]['train'])

        for epoch in trange(num_epochs, desc=f"Current domain {current_domain}"):
            model.train()
            epoch_loss = 0.0
            samples = 0
            batch_metrics_list = []

            for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Current epoch {epoch}", leave=False)):
                # Main model training
                optimizer.zero_grad()

                batch_kwargs_with_domain = {**batch_kwargs, 'current_domain': current_domain, 'domain_idx': domain_idx}

                main_loss, gating_loss, metrics = batch_fn(model, batch, device, loss_params, **batch_kwargs_with_domain)

                if gradient_clipping:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                # Gating training (in same batch, separate optimizer).
                # The gating loss is signed, so its value cannot be used to
                # detect the "not computed" placeholder; a real loss is
                # attached to the autograd graph, the placeholder is not.
                if domain_idx > 0 and gating_loss.requires_grad:
                    gating_optimizer.zero_grad()
                    gating_loss.backward()
                    gating_optimizer.step()

                batch_size = batch[0].size(0)
                epoch_loss += main_loss.item() * batch_size
                samples += batch_size
                global_step += 1
                batch_metrics_list.append(metrics)

            avg_epoch_loss = epoch_loss / samples
            history['train_epoch_loss'].append(avg_epoch_loss)

            avg_metrics = average_metrics(batch_metrics_list)
            history['train_epoch_metrics'].append(avg_metrics)

            grad_norms = collect_gradients(model)
            history['grad_norms'].append(grad_norms)

            # Validation on current domain
            val_loss = evaluate_model_cf(model, domain_dataloaders[current_domain]['val'], batch_kwargs['mse_criterion'], device)
            history['val_epoch_loss'].append(val_loss)

            # End-of-domain bookkeeping, done once after the final epoch.
            if epoch == num_epochs-1:
                cross_val = cross_domain_validation_cf(model, domain_dataloaders, batch_kwargs['mse_criterion'], device)
                history['cross_domain_val'].append(cross_val)

                # Compute forgetting scores after training on domain (for next domain)
                model.compute_forgetting_scores()
                history['forgetting_scores_history'].append(
                    model.current_forgetting_scores.clone() if model.current_forgetting_scores is not None else None
                )
                if model.current_forgetting_scores is not None:
                    print(f"Forgetting scores computed. Mean: {model.current_forgetting_scores.mean():.4f}, "
                        f"Std: {model.current_forgetting_scores.std():.4f}")

                # Save checkpoint
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'gating_optimizer_state_dict': gating_optimizer.state_dict(),
                    'history': history,
                    'forgetting_scores': model.current_forgetting_scores,
                }, f"{checkpoint_dir}/{exp_name}_domain{current_domain}_epoch{epoch}_step{global_step}.pt")

            # Persist the running history after every epoch.
            with open(f"{checkpoint_dir}/{exp_name}_history.pkl", "wb") as f:
                pickle.dump(history, f)

        buffer.update_buffer(current_domain, domain_dataloaders[current_domain]['train'].dataset)
        print(f"Domain {domain_idx} completed. Buffer: {buffer.get_domain_distribution()}")

    return history


### Explicit Heuristic Split Model

#### ZoeDepth - HuggingFace

In [None]:
# import os
# from pathlib import Path
# import torch
# import numpy as np
# from transformers import AutoImageProcessor, ZoeDepthForDepthEstimation
# from PIL import Image
# from tqdm import tqdm
# from torchvision import transforms

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# output_dir = Path('../data/depth')
# output_dir.mkdir(parents=True, exist_ok=True)

# image_processor = AutoImageProcessor.from_pretrained("Intel/zoedepth-nyu-kitti", use_fast=True)
# model = ZoeDepthForDepthEstimation.from_pretrained("Intel/zoedepth-nyu-kitti").to(device).eval()

# # Prepare image paths list
# image_paths = df['image_path'].tolist()

# batch_size = 10  # or whatever batch size you want
# for batch_idx in tqdm(range(0, len(image_paths), batch_size)):
#     batch_paths = image_paths[batch_idx:batch_idx + batch_size]
    
#     # Load images as PIL Images (no manual transform)
#     batch_images = [Image.open(img_path).convert("RGB") for img_path in batch_paths]
    
#     # Preprocess with ZoeDepth image processor
#     inputs = image_processor(images=batch_images, return_tensors="pt").to(device)
    
#     with torch.no_grad():
#         outputs = model(**inputs)
    
#     # Post-process depth maps to original sizes
#     source_sizes = [(img.height, img.width) for img in batch_images]
#     post_processed = image_processor.post_process_depth_estimation(
#         outputs,
#         source_sizes=source_sizes
#     )
    
#     for i, depth_dict in enumerate(post_processed):
#         # Get raw depth map
#         depth_array = depth_dict["predicted_depth"].cpu().numpy()
#         img_stem = Path(batch_paths[i]).stem
#         np.save(output_dir / f"{img_stem}.npy", depth_array)
#         # Save visualization PNG
#         depth_norm = (depth_array - depth_array.min()) / (depth_array.max() - depth_array.min())
#         depth_img = Image.fromarray((depth_norm * 255).astype(np.uint8))
#         depth_img.save(output_dir / f"{img_stem}_depth.png")



#### segmentation generation

In [None]:
from autodistill_grounded_sam import GroundedSAM
from autodistill.detection import CaptionOntology
from autodistill.utils import plot
import cv2
import pickle
import bz2


import gc
# Run a garbage-collection pass, then ask PyTorch to release its unused
# cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Build the GroundedSAM base model from a caption ontology.
# The ontology dictionary has the format {caption: class}: the caption is the
# prompt sent to the base model, and the class is the label that will be
# saved for that caption in the generated annotations.
base_model = GroundedSAM(
    ontology=CaptionOntology(
        {
            "human . child . person": "human",
            "robot": "robot",
            "dog": "dog"
        }
    )
)

# human: human, animal: animal, robot: robot

In [None]:
# # run inference on a single image
# results = base_model.predict(df['image_path'].iloc[0])

# plot(
#     image=cv2.imread(df['image_path'].iloc[0]),
#     classes=base_model.ontology.classes(),
#     detections=results
# )


In [None]:
# Load the cached `dataset_home` object from a bz2-compressed pickle.
# NOTE: unpickling can execute arbitrary code; only load files you created yourself.
with bz2.BZ2File('autodistill_dataset_home.pbz2', 'rb') as f:
    dataset_home = pickle.load(f)

In [None]:
# with bz2.BZ2File('autodistill_dataset.pbz2', 'wb') as f:
#     pickle.dump(dataset, f, protocol=pickle.HIGHEST_PROTOCOL)


In [None]:

# Run a garbage-collection pass, then ask PyTorch to release its unused
# cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# dataset_home = base_model.label("../../socialsense/data/images/home", extension=".png")
# with bz2.BZ2File('autodistill_dataset_home.pbz2', 'wb') as f:
#     pickle.dump(dataset_home, f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Load the cached `dataset` object from an uncompressed pickle.
# NOTE: unpickling can execute arbitrary code; only load files you created yourself.
with open("autodistill_temp_dataset.pkl", "rb") as f:
    dataset = pickle.load(f)

In [None]:
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import os
import supervision as sv
from tqdm.notebook import tqdm, trange

# Colour palette keyed by detection class id.
# Values are RGBA tuples: the alpha channel is used when tinting masks, while
# box outlines and text labels use only the RGB part.
class_colors = {
    0: (255, 0, 0, 100),    # Red for class 0 (human)
    1: (0, 255, 0, 100),    # Green for class 1 (robot)
    2: (0, 0, 255, 100),    # Blue for class 2 (animal)
    3: (122, 122, 0, 100),  # Olive for class 3 (no class name documented here)
    # Add more if you have more classes
}

def visualize_and_save_pil_colored(dataset, confidence_threshold=0.3, output_dir="../data/temp_masks_coloured"):
    """Render masks and boxes of confident detections and save one PNG per image.

    Args:
        dataset: Iterable of ``(path, image, detections)`` triples. ``image``
            is an H x W x 3 array whose channel order is reversed before
            drawing (the original code treats it as BGR), and ``detections``
            provides ``xyxy``, ``confidence``, ``class_id`` and ``mask``.
        confidence_threshold: Detections below this confidence are ignored.
        output_dir: Directory that receives ``<image stem>.png`` files; it is
            created if missing.

    Images without any detection above the threshold are reported and skipped.
    Classes missing from ``class_colors`` are drawn in white, for boxes as
    well as masks. A missing ``class_id`` array is treated as class 0 and a
    missing ``mask`` array simply means no overlay is drawn.
    """
    os.makedirs(output_dir, exist_ok=True)
    fallback_color = (255, 255, 255, 100)  # white for classes without a palette entry

    for path, image, detections in tqdm(dataset):
        # Filter detections by confidence
        keep_indices = [i for i, conf in enumerate(detections.confidence) if conf >= confidence_threshold]
        if not keep_indices:
            print(f"No detections above threshold for image {path}")
            continue

        boxes = detections.xyxy[keep_indices]
        confidences = detections.confidence[keep_indices]
        if detections.class_id is not None:
            class_ids = [int(c) for c in detections.class_id[keep_indices]]
        else:
            class_ids = [0] * len(keep_indices)
        masks = detections.mask[keep_indices] if detections.mask is not None else []

        # Reverse the channel order (BGR -> RGB) and work on an RGBA canvas so
        # that the translucent mask tints can be alpha-composited.
        canvas = Image.fromarray(np.ascontiguousarray(image[:, :, ::-1])).convert("RGBA")

        # Overlay masks with transparency and class colors
        for mask, class_id in zip(masks, class_ids):
            color = class_colors.get(class_id, fallback_color)
            stencil = Image.fromarray((mask * 255).astype(np.uint8))
            tint = Image.new("RGBA", canvas.size, color)
            overlay = Image.composite(tint, Image.new("RGBA", canvas.size), stencil)
            canvas = Image.alpha_composite(canvas, overlay)

        # The drawing handle must be created after compositing, because
        # alpha_composite returns a new image.
        draw = ImageDraw.Draw(canvas)

        # Draw bounding boxes and confidence with class colors
        for box, conf, class_id in zip(boxes, confidences, class_ids):
            color = class_colors.get(class_id, fallback_color)[:3]  # RGB only for outlines/text
            x1, y1, x2, y2 = map(int, box)
            draw.rectangle([x1, y1, x2, y2], outline=color, width=2)
            # The label is shifted upwards by 10 px per class id.
            draw.text((x1, y1 - 10*class_id), f"{class_id}: {conf:.2f}", fill=color)

        # Save the image as PNG
        filename = os.path.basename(path)
        save_path = os.path.join(output_dir, os.path.splitext(filename)[0] + ".png")
        canvas.convert("RGB").save(save_path)


# Example usage



In [None]:
import numpy as np
from supervision import Detections

def calculate_iou(box1, box2):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    A small constant (1e-6) is added to the denominator, so the result is
    finite even for zero-area boxes and marginally below 1.0 for identical
    boxes.
    """
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    intersection = overlap_w * overlap_h

    area_a = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_b = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_a + area_b - intersection

    return intersection / (union + 1e-6)

def remove_duplicates_any_class(detections: Detections, iou_threshold=0.9) -> Detections:
    """
    Remove duplicate detections using IoU and confidence, regardless of class.

    Detections are visited from highest to lowest confidence (ties keep their
    original order). A detection is kept only if its IoU with every detection
    kept so far is at most ``iou_threshold``; class ids are not compared.

    Args:
        detections: supervision.Detections object
        iou_threshold: IoU threshold for considering duplicates

    Returns:
        Filtered Detections object holding ``xyxy``, ``confidence``,
        ``class_id`` and ``mask`` of the survivors, ordered by descending
        confidence. An empty input is returned unchanged.
    """
    xyxy = detections.xyxy
    confidence = detections.confidence
    class_id = detections.class_id
    mask = detections.mask

    # Nothing to de-duplicate; rebuilding from empty lists would yield a
    # (0,)-shaped box array instead of the (0, 4) layout of the input.
    if len(xyxy) == 0:
        return detections

    # sorted() is stable, so equal confidences keep their input order.
    by_confidence = sorted(range(len(xyxy)), key=lambda i: confidence[i], reverse=True)

    # Greedy suppression (do NOT check class_id)
    kept = []
    for candidate in by_confidence:
        if all(calculate_iou(xyxy[winner], xyxy[candidate]) <= iou_threshold for winner in kept):
            kept.append(candidate)

    return Detections(
        xyxy=xyxy[kept],
        confidence=confidence[kept],
        class_id=class_id[kept] if class_id is not None else None,
        mask=mask[kept] if mask is not None else None
    )

# Usage example


In [None]:
# De-duplicate the detections of every image with an IoU threshold of 0.95,
# keeping each (path, image, detections) triple otherwise unchanged.
# NOTE(review): dataset_746 is not defined in the visible cells; confirm it is loaded before this cell runs.
filtered_dataset = []
for path, image, detections in dataset_746:
    filtered_detections = remove_duplicates_any_class(detections, iou_threshold=0.95)
    filtered_dataset.append((path, image, filtered_detections))

In [None]:
# visualize_and_save_pil_colored(dataset, confidence_threshold=0.3)

#### depth mean std

In [None]:
import numpy as np
import os
from tqdm.notebook import tqdm

def calculate_mean_std_for_npy(folder_path):
    """Compute the global mean and standard deviation over all ``.npy`` files.

    Every file directly inside ``folder_path`` whose name ends in ``.npy`` is
    loaded, converted to float64 and folded into running sums, so only one
    array is held in memory at a time. The statistics are taken over all
    elements of all files together (population standard deviation).

    Args:
        folder_path: Directory containing the ``.npy`` files.

    Returns:
        ``(mean, std)``, or ``(None, None)`` if no elements were found.
    """
    total_sum = 0.0
    total_sum_sq = 0.0
    total_count = 0

    # List all .npy files
    files = [f for f in os.listdir(folder_path) if f.endswith('.npy')]

    for filename in tqdm(files):
        values = np.load(os.path.join(folder_path, filename)).astype(np.float64)
        total_sum += values.sum()
        total_sum_sq += (values ** 2).sum()
        total_count += values.size

    if total_count == 0:
        return None, None

    mean = total_sum / total_count
    # E[x^2] - E[x]^2 can come out slightly negative through rounding (for
    # example with constant data); clamp so the square root is never NaN.
    variance = max(total_sum_sq / total_count - mean ** 2, 0.0)
    return mean, np.sqrt(variance)


In [None]:
# calculate_mean_std_for_npy('../data/depth')

#### mask fusion

In [None]:
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
from pathlib import Path

def process_segmentation_data(detections_dataset, output_dir, imagenet_mean=(0.485, 0.456, 0.406), confidence_threshold=0.3):
    """
    Process a supervision Detections dataset to create social and environment images

    For every item the masks of all detections at or above
    ``confidence_threshold`` are merged into one mask. Two images are written
    under the original file name: ``<output_dir>/social`` keeps only the
    masked pixels and ``<output_dir>/environment`` keeps only the rest; the
    removed pixels are filled with ``imagenet_mean``.

    Args:
        detections_dataset: Iterable of items where ``item[0]`` is the image
            path and ``item[2]`` is the Detections object (``item[1]`` is not
            used; the image is re-read from disk).
        output_dir: Base directory to save processed images
        imagenet_mean: RGB mean values in [0, 1] for filling masked areas
        confidence_threshold: Minimum confidence for a mask to be used

    Raises:
        FileNotFoundError: If an image cannot be read from its path.
    """
    # Create output directories
    social_dir = Path(output_dir) / "social"
    env_dir = Path(output_dir) / "environment"
    social_dir.mkdir(parents=True, exist_ok=True)
    env_dir.mkdir(parents=True, exist_ok=True)

    for item in tqdm(detections_dataset):
        image_path = item[0]
        detections = item[2]

        # cv2.imread signals failure by returning None rather than raising.
        bgr_image = cv2.imread(str(image_path))
        if bgr_image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        original_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
        height, width = original_image.shape[:2]

        # Get filename for saving
        filename = Path(image_path).name

        if detections.mask is None:
            # No masks available: combine_masks turns None into an empty mask.
            filtered_masks = None
        else:
            keep_indices = [i for i, conf in enumerate(detections.confidence) if conf >= confidence_threshold]
            filtered_masks = detections.mask[keep_indices]

        # Combine all masks into one
        combined_mask = combine_masks(filtered_masks, height, width)

        # Create social image (only the masked objects visible)
        social_image = apply_mask_with_mean(
            original_image, combined_mask, imagenet_mean, keep_masked=True
        )

        # Create environment image (masked objects filled with the mean)
        env_image = apply_mask_with_mean(
            original_image, combined_mask, imagenet_mean, keep_masked=False
        )

        # Save images
        save_image(social_image, social_dir / filename)
        save_image(env_image, env_dir / filename)


def combine_masks(masks, height, width, confidence_threshold=0.3):
    """
    Merge several masks into one by taking their element-wise maximum.

    Args:
        masks: Sequence (or stacked array) of masks, or None
        height, width: Expected size of every mask and of the result
        confidence_threshold: Accepted but not used by this function

    Returns:
        uint8 array of shape (height, width); all zeros when ``masks`` is
        None or empty.

    Raises:
        ValueError: If any mask does not have shape (height, width).
    """
    expected_shape = (height, width)
    layers = [np.zeros(expected_shape, dtype=np.uint8)]

    if masks is None or len(masks) == 0:
        return layers[0]

    for layer in masks:
        if layer.shape != expected_shape:
            raise ValueError(f"Mask shape incorrect: {layer.shape}, should be {(height, width)}")
        layers.append(layer.astype(np.uint8))

    # Union of all layers: maximum along the stacking axis.
    return np.maximum.reduce(layers)

def apply_mask_with_mean(image, mask, imagenet_mean, keep_masked=True):
    """
    Apply mask to image and fill the removed areas with ImageNet mean values

    Pixels that are kept are copied unchanged: the image is never routed
    through a float conversion, which could shift values by one level when
    truncating back to uint8. Any non-zero mask value counts as "object", so
    0/1 and 0/255 masks behave the same.

    Args:
        image: Original RGB image (H, W, 3); it is not modified
        mask: Mask (H, W) where non-zero = object, 0 = background
        imagenet_mean: RGB mean values [R, G, B] in range [0, 1]
        keep_masked: If True, keep masked areas (social). If False, remove masked areas (environment)

    Returns:
        processed_image: uint8 copy of the image with the removed areas set
        to ``imagenet_mean * 255`` (truncated to integers)
    """
    object_pixels = np.asarray(mask) != 0
    # Social image: fill the background; environment image: fill the objects.
    fill_mask = ~object_pixels if keep_masked else object_pixels

    # Scale in float32 and truncate, which yields the same fill values as
    # writing the means into a float32 image and converting it back to uint8.
    fill_color = (np.asarray(imagenet_mean, dtype=np.float32)[:3] * 255).astype(np.uint8)

    processed_image = np.array(image, dtype=np.uint8)  # independent copy
    # Only the first three channels are filled.
    processed_image[..., :3][fill_mask] = fill_color

    return processed_image

def save_image(image, save_path):
    """
    Write an image array to disk through PIL.

    Args:
        image: Image array (the callers in this file pass (H, W, 3) uint8 RGB)
        save_path: Destination path, passed unchanged to ``Image.save``
    """
    Image.fromarray(image).save(save_path)


In [None]:
# with bz2.BZ2File('autodistill_dataset_home.pbz2', 'rb') as f:
#     dataset_home = pickle.load(f)

# process_segmentation_data(
#     detections_dataset=dataset_home,
#     output_dir='../data/masked',
#     imagenet_mean=[0.485, 0.456, 0.406]  # ImageNet RGB means
# )


#### dualbranch model

In [None]:
def intermediate_layer_size(input, output, n_layers):
    """Pick ``n_layers`` power-of-two widths between ``output`` and ``input``.

    The candidate exponents start at ``(output + 1).bit_length()`` and end at
    ``(input - 1).bit_length()``. The ``n_layers`` widths are spread evenly
    over that exponent range (rounded to the nearest step, halves up) and
    returned in ascending order.

    Returns:
        List of ``n_layers`` ints, or None when the exponent range is
        narrower than ``n_layers``.
    """
    low_exp = (output + 1).bit_length()
    high_exp = (input - 1).bit_length()
    span = high_exp - low_exp

    if span < n_layers:
        return None

    parts = n_layers + 1
    rounding = parts // 2  # adding half the divisor makes // round to nearest

    return [
        1 << (low_exp + (step * span + rounding) // parts)
        for step in range(1, n_layers + 1)
    ]


# Quick check: three hidden widths between a (1280 + 64)-dimensional input and 9 outputs.
intermediate_layer_size(1280+64, 9, 3)

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class DualBranchModel(nn.Module):
    """Two-branch regressor with a social branch and an environment branch.

    The social branch is an ImageNet-pretrained MobileNetV2 feature extractor
    (1280 features). The environment branch is either a second MobileNetV2
    (``architecture['env'] == 'lightweight'``) or a 64-dimensional embedding
    of up to 30 labels (``'label'``). Both feature vectors are concatenated
    and passed to the regression head (``architecture['head']`` is ``'deep'``
    or ``'shallow'``). Two small probe classifiers read the *detached*
    features of each branch, so training them does not update the branches.

    Args:
        num_outputs: Size of the head output and of both probe outputs.
        dropout_rate: Dropout probability used by the ``'deep'`` head.
        architecture: Dict with keys ``'env'`` and ``'head'``; defaults to
            ``{'env': 'lightweight', 'head': 'deep'}``.

    Raises:
        ValueError: For an unknown ``'env'`` or ``'head'`` value, or when no
            hidden width can be derived for a probe classifier.
    """

    _ENV_CHOICES = ('lightweight', 'label')
    _HEAD_CHOICES = ('deep', 'shallow')

    def __init__(self, num_outputs=9, dropout_rate=0.3, architecture=None):
        super(DualBranchModel, self).__init__()
        if architecture is None:
            architecture = {'env': 'lightweight', 'head': 'deep'}
        # Keep a private copy so instances never share one settings dict.
        self.setup = dict(architecture)

        # Fail early with a clear message instead of a NameError further
        # down (unknown env) or a model without a head (unknown head).
        if self.setup['env'] not in self._ENV_CHOICES:
            raise ValueError(
                f"architecture['env'] must be one of {self._ENV_CHOICES}, got {self.setup['env']!r}"
            )
        if self.setup['head'] not in self._HEAD_CHOICES:
            raise ValueError(
                f"architecture['head'] must be one of {self._HEAD_CHOICES}, got {self.setup['head']!r}"
            )

        self.social_branch = nn.Sequential(
            models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1).features,
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten()
        )
        soc_feature_dim = 1280

        if self.setup['env'] == 'lightweight':
            self.env_branch = nn.Sequential(
                models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1).features,
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten()
            )
            env_feature_dim = 1280
        else:  # 'label'
            max_rooms = 30
            self.env_branch = nn.Embedding(max_rooms, 64)
            env_feature_dim = 64

        self.fusion_dim = soc_feature_dim + env_feature_dim

        # Probe classifiers, one per branch, each with a single hidden layer.
        self.social_classifier = self._make_probe(soc_feature_dim, num_outputs)
        self.env_classifier = self._make_probe(env_feature_dim, num_outputs)

        if self.setup['head'] == 'deep':
            self.head = nn.Sequential(
                nn.Linear(self.fusion_dim, 512),
                nn.BatchNorm1d(512),
                nn.ReLU(),
                nn.Dropout(dropout_rate),

                nn.Linear(512, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(),
                nn.Dropout(dropout_rate),

                nn.Linear(256, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Dropout(dropout_rate),

                nn.Linear(128, num_outputs)
            )
        else:  # 'shallow'
            self.head = nn.Sequential(
                nn.Linear(self.fusion_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, num_outputs)
            )

    @staticmethod
    def _make_probe(feature_dim, num_outputs):
        """Build a Linear-ReLU-Linear classifier from ``feature_dim`` to ``num_outputs``."""
        hidden = intermediate_layer_size(feature_dim, num_outputs, 1)
        if hidden is None:
            raise ValueError(
                f"Cannot derive a hidden width between {feature_dim} features and {num_outputs} outputs"
            )
        return nn.Sequential(
            nn.Linear(feature_dim, hidden[0]),
            nn.ReLU(),
            nn.Linear(hidden[0], num_outputs)
        )

    def forward(self, social_img, env_img):
        """Return head scores, probe outputs and the raw branch features.

        Args:
            social_img: Input of the social branch.
            env_img: Input of the environment branch (with
                ``env == 'label'`` this is fed to ``nn.Embedding``, which
                expects integer indices).

        Returns:
            Dict with ``'output'`` (head scores), ``'invariant_domain'`` and
            ``'specific_domain'`` (probe outputs computed on detached social
            and environment features), ``'invariant_feats'`` and
            ``'specific_feats'`` (the branch features themselves).
        """
        social_features = self.social_branch(social_img)
        env_features = self.env_branch(env_img)

        fused_features = torch.cat([social_features, env_features], dim=1)
        scores = self.head(fused_features)

        # Detach so that the probe losses cannot reach the branches.
        social_class = self.social_classifier(social_features.detach())
        env_class = self.env_classifier(env_features.detach())

        return {
            'output': scores,
            'invariant_domain': social_class,
            'specific_domain': env_class,
            'invariant_feats': social_features,
            'specific_feats': env_features
        }

## ---

In [None]:
def combine_fold_histories(fold_history):
    """
    Average the training histories of all folds element-wise.

    ``fold_history`` is indexed by fold (``fold_history[0]`` is used as the
    reference for the metric names) and each entry maps a metric name to a
    list. Lists of dictionaries are delegated to ``average_list_of_dicts``;
    every other list is stacked across folds and averaged along the fold axis.
    """
    reference = fold_history[0]
    merged = {}

    for metric in list(reference.keys()):
        series = reference[metric]
        sample = series[0] if series else None

        if isinstance(sample, dict):
            # Lists of dictionaries (e.g. train_epoch_metrics, grad_norms)
            merged[metric] = average_list_of_dicts(fold_history, metric)
            continue

        # Plain lists (e.g. train_epoch_loss, val_epoch_loss)
        per_fold = [fold_history[fold_id][metric] for fold_id in fold_history]
        merged[metric] = np.mean(np.stack(per_fold), axis=0).tolist()

    return merged

def average_list_of_dicts(fold_history, metric_key):
    """
    Average a per-epoch list of dictionaries across folds.

    The sub-metric names are taken from the first epoch of fold 0. Each fold
    becomes an [epochs, sub_metrics] array; the folds are stacked and
    averaged, and the result is returned as one dictionary per epoch.
    """
    sub_keys = list(fold_history[0][metric_key][0].keys())

    # [folds, epochs, sub_metrics]
    per_fold = np.stack([
        np.array([
            [epoch_entry[name] for name in sub_keys]
            for epoch_entry in fold_history[fold_id][metric_key]
        ])
        for fold_id in fold_history
    ])

    # Collapse the fold axis -> [epochs, sub_metrics]
    mean_per_epoch = np.mean(per_fold, axis=0)

    return [dict(zip(sub_keys, row)) for row in mean_per_epoch]

In [None]:
class RMSELoss(torch.nn.Module):
    """Root-mean-square error: ``sqrt(MSE(x, y) + eps)``.

    ``eps`` is added under the square root, which keeps the value strictly
    positive even when the mean squared error is zero.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.mse = torch.nn.MSELoss()

    def forward(self, x, y):
        mean_squared_error = self.mse(x, y)
        return torch.sqrt(mean_squared_error + self.eps)

In [None]:
from sklearn.model_selection import StratifiedKFold

# Reload the dataframe and make the image paths relative to this notebook.
df = pd.read_pickle("../data/pepper_data.pkl")
df['image_path'] = '../' + df['image_path']

# Initialize fold column; -1 marks rows that are never assigned to a validation fold
df['fold'] = -1

# Get unique image paths for splitting. De-duplicating guarantees that each
# image is offered to the splitter exactly once, so an image can only be
# assigned to a single fold.
unique_images_df = (
    df[['image_path', 'domain']]
    .drop_duplicates(subset='image_path')
    .reset_index(drop=True)
)
# NOTE(review): test_split_idx must hold image paths in the same '../'-prefixed
# form for this filter to exclude anything; confirm what the split returns.
unique_images_df = unique_images_df[~unique_images_df['image_path'].isin(test_split_idx)]

# Create stratified 5-fold splits based on domain
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)

# Each image receives the number of the fold in which it is a validation sample
for fold, (_, val_idx) in enumerate(skf.split(unique_images_df['image_path'], unique_images_df['domain'])):
    val_image_paths = unique_images_df.iloc[val_idx]['image_path'].tolist()
    df.loc[df['image_path'].isin(val_image_paths), 'fold'] = fold

# Verify domain distribution across folds
print("Domain distribution across folds:")
print(df.groupby(['fold', 'domain']).size().unstack(fill_value=0))

In [None]:
# Build the cross-validation loaders from the fold-annotated dataframe.
# NOTE(review): the second argument (5) is presumably the number of folds and
# resize_img_to=(512, 288) presumably (width, height); confirm against create_crossvalidation_loaders.
fold_loaders = create_crossvalidation_loaders(df, 5, batch_sizes=(32, 64, 64), resize_img_to=(512, 288))

## Training

### setup

In [None]:
# For LGRBaseline
def baseline_batch(model, batch, device, detach_base, binary, full_replay, loss_params, **kwargs):
    """Forward one ``(inputs, labels, domain)`` batch and back-propagate its loss.

    The loss is ``kwargs['mse_criterion']`` applied to ``model(inputs)['output']``
    and the labels. ``detach_base``, ``binary``, ``full_replay`` and
    ``loss_params`` are accepted but not used.

    Returns:
        ``(loss, metrics)`` where ``metrics`` is an empty dict.
    """
    images, targets, _domain = batch
    images = images.to(device)
    targets = targets.to(device)

    predictions = model(images)['output']
    criterion = kwargs['mse_criterion']
    loss = criterion(predictions, targets)
    loss.backward()

    return loss, {}

#For DANN
def dann_batch(model, batch, device, detach_base, binary, full_replay, loss_params={'head':0.5, 'social':0.5}, **kwargs):
    """Forward one batch through a DANN-style model and back-propagate.

    The total loss is ``loss_params['head'] * task_loss +
    loss_params['social'] * inv_domain_loss``. The domain loss is a binary
    "is this sample from the current domain" loss on
    ``outputs['invariant_domain']``; it is skipped (set to 0) while training
    the first domain, i.e. the domain mapped to index 0.

    ``detach_base``, ``binary`` and ``full_replay`` are accepted but not used.

    Keyword Args:
        domain_to_idx: Mapping from domain name to integer index.
        mse_criterion: Task loss.
        bce_criterion: Binary domain loss applied to the domain output
            (predictions are thresholded at 0, i.e. treated as logits).
        alpha: Forwarded to the model.
        current_domain: Name of the domain currently being trained.

    Returns:
        ``(total_loss, metrics)`` with ``metrics`` holding ``'task_loss'``,
        ``'inv_domain'`` and ``'inv_acc'``.
    """
    inputs, labels, domain_labels = batch
    inputs, labels = inputs.to(device), labels.to(device)
    domain_to_idx = kwargs['domain_to_idx']
    domain_labels = torch.tensor([domain_to_idx[d] for d in domain_labels], device=device)
    mse_criterion = kwargs['mse_criterion']
    bce_criterion = kwargs['bce_criterion']
    alpha = kwargs['alpha']

    current_idx = domain_to_idx[kwargs['current_domain']]
    current_binary_labels = (domain_labels == current_idx).float()
    is_first_domain = bool(current_idx == 0)

    outputs = model(inputs, alpha=alpha, is_first_domain=is_first_domain)

    task_loss = mse_criterion(outputs['output'], labels)

    if is_first_domain:
        inv_domain_loss = 0
        inv_acc = 0
    else:
        # reshape(-1) instead of squeeze(): squeeze() turns a (1, 1) output
        # into a 0-d tensor, which no longer matches the (1,) target when a
        # batch contains a single sample.
        domain_logits = outputs['invariant_domain'].reshape(-1)
        inv_domain_loss = bce_criterion(domain_logits, current_binary_labels)
        preds = (domain_logits > 0).float()
        inv_acc = (preds == current_binary_labels).float().mean().item()

    total_loss = (loss_params['head'] * task_loss +
                    loss_params['social'] * inv_domain_loss)

    total_loss.backward()

    metrics = {
        'task_loss': task_loss.item(),
        'inv_domain': 0 if is_first_domain else inv_domain_loss.item(),
        'inv_acc': inv_acc
    }
    return total_loss, metrics

def heuristic_dualbranch_batch(model, batch, device, detach_base, binary, full_replay, loss_params={}, **kwargs):
    """Train step for the dual-branch model with separately optimised probes.

    The regression loss on ``outputs['output']`` is back-propagated first.
    The two domain classifiers are then updated here with
    ``kwargs['class_optimizer']`` using the sum of their cross-entropy losses.
    The caller is expected to step the main optimizer.

    ``detach_base``, ``binary``, ``full_replay`` and ``loss_params`` are
    accepted but not used. Required ``kwargs``: ``domain_to_idx`` (looked up
    but unused), ``mse_criterion``, ``ce_criterion``, ``class_optimizer``.

    Returns:
        ``(loss, metrics)`` with the regression loss and a dict containing
        the loss and accuracy of both domain classifiers.
    """
    social_in, env_in, targets, domain_ids = (tensor.to(device) for tensor in batch)

    _ = kwargs['domain_to_idx']  # required key; the mapping itself is not needed
    regression_criterion = kwargs['mse_criterion']
    domain_criterion = kwargs['ce_criterion']

    outputs = model(social_in, env_in)

    loss = regression_criterion(outputs['output'], targets)
    loss.backward()

    def _accuracy(logits):
        # Share of samples whose arg-max class equals the domain label.
        return (logits.argmax(1) == domain_ids).float().mean().item()

    # Separate update of the domain classifiers.
    probe_optimizer = kwargs['class_optimizer']
    probe_optimizer.zero_grad()
    social_logits = outputs['invariant_domain']
    env_logits = outputs['specific_domain']
    inv_domain_loss = domain_criterion(social_logits, domain_ids)
    spec_domain_loss = domain_criterion(env_logits, domain_ids)
    (inv_domain_loss + spec_domain_loss).backward()
    probe_optimizer.step()

    metrics = {
        'inv_domain': inv_domain_loss.item(),
        'spec_domain': spec_domain_loss.item(),
        'inv_acc': _accuracy(social_logits),
        'spec_acc': _accuracy(env_logits),
    }
    return loss, metrics


# For DualBranchNet
def dualbranch_batch(model, batch, device, detach_base, binary, full_replay, loss_params={'head': 1, 'social': 0.5, 'room': 0.2}, **kwargs):
    """Back-propagate one mixed current/replay batch through a dual-branch model.

    The batch is split by domain. Samples of ``kwargs['current_domain']`` (or
    every sample when ``full_replay`` is true) go through the full model and
    contribute the weighted task + domain losses. The remaining (replayed)
    samples are pushed through ``backbone``/``pool``/``invariant`` under
    ``no_grad`` so only ``specific``, ``specific_domain_classifier`` and
    ``head`` receive gradients from them. Gradients are accumulated on the
    parameters; no optimiser is zeroed or stepped here.

    Args:
        model: callable as ``model(inputs, alpha=alpha)`` returning a dict with
            keys 'output', 'invariant_feats', 'specific_feats',
            'invariant_domain' and 'specific_domain'; must also expose the
            sub-modules named above.
        batch: ``(inputs, labels, domain_labels)``; ``domain_labels`` is an
            iterable of keys of ``kwargs['domain_to_idx']``.
        device: device the tensors are moved to / created on.
        detach_base: when true, the replay head input additionally contains
            the pooled backbone features.
        binary: when true, the invariant domain loss is
            ``kwargs['bce_criterion']`` on "is current domain" targets instead
            of cross-entropy on the domain index.
        full_replay: when true, every sample is treated as a current sample.
        loss_params: weights for the current-sample loss, keys 'head',
            'social' (invariant domain loss) and 'room' (specific domain loss).
        **kwargs: 'domain_to_idx', 'mse_criterion', 'ce_criterion',
            'cos_criterion', 'alpha', 'current_domain', and 'bce_criterion'
            when ``binary`` is true.

    Returns:
        ``(total_loss, metrics)``. ``total_loss`` is the current-sample loss
        only (the replay loss is back-propagated but not returned);
        ``metrics`` is a dict of Python numbers. The 'similarity' metric is
        reported but is not part of ``total_loss``.
    """
    inputs, labels, domain_labels = batch
    inputs, labels = inputs.to(device), labels.to(device)
    domain_to_idx = kwargs['domain_to_idx']
    domain_labels = torch.tensor([domain_to_idx[d] for d in domain_labels], device=device)
    mse_criterion = kwargs['mse_criterion']
    ce_criterion = kwargs['ce_criterion']
    cos_criterion = kwargs['cos_criterion']
    alpha = kwargs['alpha']
    if binary:
        bce_criterion = kwargs['bce_criterion']

    # Split batch
    current_domain = kwargs['current_domain']
    current_binary_labels = (domain_labels == domain_to_idx[current_domain]).float()
    current_mask = (domain_labels == domain_to_idx[current_domain])

    if full_replay:
        current_mask = torch.ones_like(current_mask, dtype=torch.bool)

    replay_mask = ~current_mask

    # 1. Current samples: update all parameters
    if current_mask.any():
        inputs_current = inputs[current_mask]
        labels_current = labels[current_mask]
        domain_labels_current = domain_labels[current_mask]

        outputs_current = model(inputs_current, alpha=alpha)
        inv_feats = outputs_current['invariant_feats']
        spec_feats = outputs_current['specific_feats']

        task_loss = mse_criterion(outputs_current['output'], labels_current)
        if binary:
            # reshape(-1) instead of squeeze(): squeeze() turns a single-sample
            # logit into a 0-d tensor, whose shape no longer matches the (1,)
            # target and makes BCEWithLogitsLoss raise.
            inv_logits = outputs_current['invariant_domain'].reshape(-1)
            binary_targets = current_binary_labels[current_mask]
            inv_domain_loss = bce_criterion(inv_logits, binary_targets)
        else:
            inv_domain_loss = ce_criterion(outputs_current['invariant_domain'], domain_labels_current)
        spec_domain_loss = ce_criterion(outputs_current['specific_domain'], domain_labels_current)
        # Logged only; deliberately left out of total_loss as in the original.
        similarity_loss = cos_criterion(inv_feats, spec_feats)

        total_loss = (loss_params['head'] * task_loss +
                      loss_params['social'] * inv_domain_loss +
                      loss_params['room'] * spec_domain_loss)

        total_loss.backward(retain_graph=not full_replay)

        if binary:
            # Threshold at 0 (sigmoid(0) = 0.5)
            preds = (inv_logits > 0).float()
            inv_acc = (preds == binary_targets).float().mean().item()
        else:
            inv_acc = (outputs_current['invariant_domain'].argmax(1) == domain_labels_current).float().mean().item()
        spec_acc = (outputs_current['specific_domain'].argmax(1) == domain_labels_current).float().mean().item()
    else:
        # No current-domain sample in this batch: report zeros.
        total_loss = torch.tensor(0.0, device=device)
        inv_acc = 0.0
        spec_acc = 0.0
        task_loss = torch.tensor(0.0, device=device)
        inv_domain_loss = torch.tensor(0.0, device=device)
        spec_domain_loss = torch.tensor(0.0, device=device)
        similarity_loss = torch.tensor(0.0, device=device)

    # 2. Replay samples: update only specific branch + head
    if replay_mask.any():
        inputs_replay = inputs[replay_mask]
        labels_replay = labels[replay_mask]
        domain_labels_replay = domain_labels[replay_mask]

        # no_grad, unlike requires_grad=False, detaches all elements from the
        # gradient computation graph
        with torch.no_grad():
            base_replay = model.backbone(inputs_replay)
            base_replay = model.pool(base_replay).flatten(1)
            inv_feats_replay = model.invariant(base_replay)

        specific_feats = model.specific(base_replay)
        spec_domain_pred = model.specific_domain_classifier(specific_feats)

        if detach_base:
            combined = torch.cat([inv_feats_replay, specific_feats, base_replay], dim=1)
        else:
            combined = torch.cat([inv_feats_replay, specific_feats], dim=1)

        scores = model.head(combined)

        task_loss_replay = mse_criterion(scores, labels_replay)
        spec_domain_loss_replay = ce_criterion(spec_domain_pred, domain_labels_replay)
        # NOTE(review): fixed weights (1 and 0.2) are used here rather than
        # loss_params['head'] / loss_params['room'] -- confirm this is intended.
        total_loss_replay = task_loss_replay + 0.2 * spec_domain_loss_replay

        total_loss_replay.backward()

    metrics = {
        'task_loss': task_loss.item(),
        'inv_domain': inv_domain_loss.item(),
        'spec_domain': spec_domain_loss.item(),
        'similarity': similarity_loss.item(),
        'inv_acc': inv_acc,
        'spec_acc': spec_acc,
        'replay_count': replay_mask.sum().item(),
        'current_count': current_mask.sum().item()
    }
    return total_loss, metrics


In [None]:
#TODO: change dataset to return torch domain labels indexes not strings
def evaluate_model(model, dataloader, criterion, device, tsne=None):
    """Compute the sample-weighted mean loss of ``model`` over ``dataloader``.

    Args:
        model: called as ``model(x)`` or ``model(x1, x2)``; must return a dict
            containing 'output' (and 'invariant_feats' / 'specific_feats' when
            ``tsne`` collection is requested).
        dataloader: iterable of 3-element ``(inputs, labels, domain_labels)``
            or 4-element ``(inputs1, inputs2, labels, domain_labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``; its result is
            treated as a per-batch mean and re-weighted by batch size.
        device: device the inputs and labels are moved to.
        tsne: optional non-empty dict of lists. When given, per-batch features
            are appended under 'social' and 'environmental' and the batch's
            domain labels under 'domains' (keys are created if missing).

    Returns:
        ``val_loss`` (float), or ``(val_loss, tsne)`` when ``tsne`` was given.
        ``val_loss`` is NaN when the dataloader yields no samples.

    Raises:
        ValueError: if a batch has neither 3 nor 4 elements.
    """
    model.eval()
    collect = bool(tsne)
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for batch in dataloader:
            if len(batch) == 4:
                inputs1, inputs2, labels, domain_labels = batch
                inputs = (
                    inputs1.to(device, dtype=torch.float32),
                    inputs2.to(device, dtype=torch.float32),
                )
            elif len(batch) == 3:
                single_input, labels, domain_labels = batch
                inputs = (single_input.to(device, dtype=torch.float32),)
            else:
                raise ValueError(f"Batch contains {len(batch)} objects. Should contain 3 or 4 - image/two, labels, domain_labels")
            labels = labels.to(device, dtype=torch.float32)

            # Keep the whole output dict: the features are needed below, so
            # the name must not be rebound to the prediction tensor.
            outputs = model(*inputs)
            loss = criterion(outputs['output'], labels)

            batch_size = inputs[0].size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            if collect:
                tsne.setdefault('social', []).append(outputs['invariant_feats'].cpu())
                tsne.setdefault('environmental', []).append(outputs['specific_feats'].cpu())
                # Domain labels may be a tensor or a sequence of strings
                # depending on the dataset.
                if torch.is_tensor(domain_labels):
                    tsne.setdefault('domains', []).append(domain_labels.cpu())
                else:
                    tsne.setdefault('domains', []).append(list(domain_labels))

    # Computed once after the loop so an empty loader yields NaN instead of an
    # unbound-variable error.
    val_loss = total_loss / total_samples if total_samples else float('nan')

    return (val_loss, tsne) if collect else val_loss

def cross_domain_validation(model, domain_dataloaders, criterion, device, tsne=None):
    """Evaluate ``model`` on the 'val' loader of every domain.

    Returns a dict mapping each domain to its validation loss; when a truthy
    ``tsne`` collector is passed, it is threaded through every evaluation and
    returned as the second element of a tuple.
    """
    per_domain_loss = {}
    for domain_name in domain_dataloaders:
        loader = domain_dataloaders[domain_name]['val']
        if not tsne:
            per_domain_loss[domain_name] = evaluate_model(model, loader, criterion, device)
            continue
        domain_loss, tsne = evaluate_model(model, loader, criterion, device, tsne)
        per_domain_loss[domain_name] = domain_loss
    if tsne:
        return per_domain_loss, tsne
    return per_domain_loss

def average_metrics(metrics_list):
    """Average per-batch metric dicts into a single dict of floats.

    The metric names are taken from the first dict; a batch that lacks a
    given name is simply left out of that name's mean. An empty list gives
    an empty dict.
    """
    if not metrics_list:
        return {}
    return {
        name: float(np.mean([entry[name] for entry in metrics_list if name in entry]))
        for name in metrics_list[0].keys()
    }

def collect_tsne_features(model, loader, device):
    """Collect invariant-branch features and domain labels for t-SNE plots.

    Args:
        model: called as ``model(x)``; must return a dict with key
            'invariant_feats'.
        loader: either a mapping ``{domain: {'val': dataloader, ...}}`` (every
            domain's 'val' loader is visited) or a single iterable of
            ``(inputs, labels, domain_labels)`` batches.
        device: device the inputs are moved to.

    Returns:
        ``(inv_feats, all_domains)``: a NumPy array of the concatenated
        features and a flat list of the matching domain labels.
    """
    model.eval()
    # Use the argument rather than a notebook-global `domain_dataloaders`.
    if isinstance(loader, dict):
        val_loaders = [domain_loaders['val'] for domain_loaders in loader.values()]
    else:
        val_loaders = [loader]

    all_inv, all_domains = [], []
    with torch.no_grad():
        for val_loader in val_loaders:
            for x, _, d in val_loader:
                x = x.to(device)
                out = model(x)
                all_inv.append(out['invariant_feats'].cpu())
                all_domains += list(d)
    inv_feats = torch.cat(all_inv, dim=0).numpy()
    return inv_feats, all_domains


def collect_gradients(model):
    """Return the mean L2 gradient norm per top-level sub-module.

    Parameters without a gradient and parameters whose name starts with
    "backbone" are ignored. The sub-module is the part of the parameter name
    before the first dot.
    """
    norms_by_module = {}
    for param_name, tensor in model.named_parameters():
        if tensor.grad is None or param_name.startswith("backbone"):
            continue
        top_level = param_name.split('.')[0]
        norms_by_module.setdefault(top_level, []).append(tensor.grad.norm(2).item())
    # Collapse each module's list of norms to its mean.
    return {module: float(np.mean(norms)) for module, norms in norms_by_module.items()}


In [None]:
import pickle
import torch
from tqdm.notebook import tqdm, trange

def unified_train_loop(
    model, domains, domain_dataloaders, buffer, optimizer, writer, device,
    batch_fn, batch_kwargs, loss_params=None, num_epochs=5, exp_name="exp", gradient_clipping=False, detach_base=False, binary=False, full_replay=False, collect_tsne_data=False, restart=None, eval_buffer=False
):
    """Train ``model`` sequentially over ``domains`` with rehearsal replay.

    For each domain the training loader is wrapped by ``buffer`` so replayed
    samples are mixed in, the model is trained for ``num_epochs`` epochs, and
    after the last epoch a cross-domain validation is run and a checkpoint is
    written to ``../checkpoints``. The history is pickled after every epoch.

    Args:
        model: network to train.
        domains: ordered domain identifiers (an array supporting ``==``
            broadcasting when ``restart`` is used).
        domain_dataloaders: ``{domain: {'train': loader, 'val': loader}}``.
        buffer: rehearsal buffer providing ``get_loader_with_replay``,
            ``update_buffer`` and ``get_domain_distribution``.
        optimizer: optimiser stepped once per batch.
        writer: unused (TensorBoard logging is commented out).
        device: device passed to the batch and evaluation functions.
        batch_fn: called as ``batch_fn(model, batch, device, detach_base,
            binary, full_replay[, loss_params], **kwargs)``; must run the
            backward pass and return ``(loss, metrics)``.
        batch_kwargs: extra keyword arguments for ``batch_fn``; must contain
            'mse_criterion'. An optional truthy 'alpha' fixes the value passed
            to ``batch_fn``; otherwise a sigmoid ramp from 0 towards 1 over the
            domain's training is used.
        loss_params: loss weights forwarded to ``batch_fn``. When ``None`` the
            argument is omitted so ``batch_fn`` uses its own default.
        num_epochs: epochs per domain.
        exp_name: prefix of checkpoint and history file names.
        gradient_clipping: clip the global gradient norm to 1.0 when true.
        detach_base, binary, full_replay: forwarded to ``batch_fn``.
        collect_tsne_data: collect features for t-SNE plots.
        restart: optional dict with 'global_step', 'history' and 'domain' to
            resume from that domain; earlier domains are re-added to ``buffer``.
        eval_buffer: optional second buffer used to build a validation loader
            with replay; ``False``/``None`` disables it.

    Returns:
        The ``history`` dict of per-epoch losses, metrics and gradient norms.
    """
    # Checkpoints and the history pickle are written here; create it up front
    # so a missing directory does not fail only after the first domain.
    os.makedirs("../checkpoints", exist_ok=True)

    has_eval_buffer = eval_buffer is not None and eval_buffer is not False

    start_domain_idx = 0
    global_step = 0
    history = {
        'train_epoch_loss': [],
        'val_epoch_loss': [],
        'val_buffer_epoch_loss': [],
        'train_epoch_metrics': [],
        'cross_domain_val': [],
        'grad_norms': [],
    }

    if restart:
        # Populate history
        global_step = restart['global_step']
        history = restart['history']
        # Populate buffer with the domains that were already trained on
        start_domain_idx = np.where(domains == restart['domain'])[0][0]
        for past_domain in domains[:start_domain_idx]:
            buffer.update_buffer(past_domain, domain_dataloaders[past_domain]['train'].dataset)
        # NOTE(review): eval_buffer is not repopulated on restart -- confirm
        # whether it should mirror `buffer` here.
        print(f"Restarting from domain {restart['domain']} index {start_domain_idx}")
        print(f"Buffer: {buffer.get_domain_distribution()}")

    for current_domain in tqdm(domains[start_domain_idx:], desc="Total training"):
        train_loader = buffer.get_loader_with_replay(current_domain, domain_dataloaders[current_domain]['train'])
        if eval_buffer:
            eval_loader = eval_buffer.get_loader_with_replay(current_domain, domain_dataloaders[current_domain]['val'])
        else:
            eval_loader = domain_dataloaders[current_domain]['val']
        len_dataloader = len(train_loader)

        for epoch in trange(num_epochs, desc=f"Current domain {current_domain}"):
            model.train()
            epoch_loss = 0.0
            samples = 0
            batch_metrics_list = []

            for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Current epoch {epoch}", leave=False)):
                # .get() so callers that never set 'alpha' fall back to the
                # schedule instead of raising KeyError.
                fixed_alpha = batch_kwargs.get('alpha')
                if not fixed_alpha:
                    # p is the fraction of this domain's training completed.
                    p = (epoch * len_dataloader + batch_idx) / (num_epochs * len_dataloader)
                    alpha = 2. / (1. + np.exp(-10 * p)) - 1
                else:
                    alpha = fixed_alpha

                optimizer.zero_grad()
                step_kwargs = {**batch_kwargs, 'current_domain': current_domain, 'alpha': alpha}
                if loss_params is None:
                    loss, metrics = batch_fn(model, batch, device, detach_base, binary, full_replay, **step_kwargs)
                else:
                    loss, metrics = batch_fn(model, batch, device, detach_base, binary, full_replay, loss_params, **step_kwargs)
                if gradient_clipping:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                batch_size = batch[0].size(0)
                epoch_loss += loss.item() * batch_size
                samples += batch_size
                global_step += 1
                batch_metrics_list.append(metrics)
                # # TensorBoard logging (every 10 batches)
                # if writer and batch_idx % 10 == 0:
                #     writer.add_scalar(f'{exp_name}/train_loss', loss.item(), global_step)
                #     for k, v in metrics.items():
                #         writer.add_scalar(f'{exp_name}/train_{k}', v, global_step)
            avg_epoch_loss = epoch_loss / samples
            # writer.add_scalar(f'{exp_name}/train_epoch_loss', avg_epoch_loss, global_step)
            history['train_epoch_loss'].append(avg_epoch_loss)
            # Average batch metrics for this epoch
            avg_metrics = average_metrics(batch_metrics_list)
            history['train_epoch_metrics'].append(avg_metrics)

            # Collect gradients (those left on the parameters by the last batch)
            grad_norms = collect_gradients(model)
            history['grad_norms'].append(grad_norms)

            # Validation on current domain, without and with replayed samples
            val_loss = evaluate_model(model, domain_dataloaders[current_domain]['val'], batch_kwargs['mse_criterion'], device)
            val_loss_buffer = evaluate_model(model, eval_loader, batch_kwargs['mse_criterion'], device)
            # writer.add_scalar(f'{exp_name}/val_epoch_loss', val_loss, global_step)
            history['val_epoch_loss'].append(val_loss)
            history['val_buffer_epoch_loss'].append(val_loss_buffer)

            # Collect data for t-SNE domain separation graphs
            if collect_tsne_data:
                inv_feats, domain_labels = collect_tsne_features(model, domain_dataloaders, device)
                tsne_data = {
                    'inv_feats': inv_feats,
                    'domain_labels': domain_labels
                }
            else:
                tsne_data = None

            # Cross-domain validation (after each domain)
            if epoch == num_epochs-1:
                if collect_tsne_data:
                    # Keys must match the ones evaluate_model appends to.
                    tsne = {'social': [], 'environmental': [], 'domains': []}
                    cross_val, tsne_data = cross_domain_validation(model, domain_dataloaders, batch_kwargs['mse_criterion'], device, tsne)
                else:
                    cross_val = cross_domain_validation(model, domain_dataloaders, batch_kwargs['mse_criterion'], device)
                history['cross_domain_val'].append(cross_val)

                # Only save last model per domain to save space
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'history': history,
                    'tsne' : tsne_data,
                }, f"../checkpoints/{exp_name}_domain{current_domain}_epoch{epoch}_step{global_step}.pt")
            # else:
            #     # Save metrics
            #     torch.save({
            #         # 'model_state_dict': model.state_dict(),
            #         # 'optimizer_state_dict': optimizer.state_dict(),
            #         'history': history,
            #         'tsne' : tsne_data,
            #     }, f"../checkpoints/{exp_name}_domain{current_domain}_epoch{epoch}_step{global_step}.pt")
            with open(f"../checkpoints/{exp_name}_history.pkl", "wb") as f:
                pickle.dump(history, f)

        buffer.update_buffer(current_domain, domain_dataloaders[current_domain]['train'].dataset)
        # eval_buffer defaults to False; only update a buffer that was given.
        if has_eval_buffer:
            eval_buffer.update_buffer(current_domain, domain_dataloaders[current_domain]['val'].dataset)
    return history


### runs

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

# Setup: pick the compute device and build the domain -> integer index map
# (domains are indexed in order of first appearance in df['domain']).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# # For baseline model
# baseline_model = LGRBaseline().to(device)
# optimizer = torch.optim.Adam(baseline_model.parameters(), lr=1e-3)
# buffer = NaiveRehearsalBuffer(buffer_size=1000)
# exp_name = f"baselinemodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
# writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# baseline_kwargs = {'mse_criterion': nn.MSELoss()}
# unified_train_loop(
#     model=baseline_model,
#     domains=domains,
#     domain_dataloaders=domain_dataloaders,
#     buffer=buffer,
#     optimizer=optimizer,
#     writer=writer,
#     device=device,
#     batch_fn=baseline_batch,
#     batch_kwargs=baseline_kwargs,
#     num_epochs=10,
#     exp_name=exp_name
# )

# For DualBranchNet
# Experiment "nores": DualBranchNet with one domain class per entry of
# `domains`, Adam (lr=1e-3), rehearsal buffer of size 1000, no gradient clipping.
dual_model = DualBranchNet(num_domains=len(domains)).to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity| (nn.CosineSimilarity default dim): the value
# is 0 for (anti-)parallel feature vectors and 1 for orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"nores_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx
}
# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=False
)


# For DualBranchNet
# Experiment "nores_gradclip": same configuration as the "nores" run, but
# with gradient clipping enabled.
dual_model = DualBranchNet(num_domains=len(domains)).to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity|: 0 for (anti-)parallel features, 1 for
# orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"nores_gradclip_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx
}
# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=True
)

# For DualBranchNet
# Experiment "nores_winit_wnorm": weights_init and weights_norm enabled, no
# gradient clipping.
# NOTE(review): unlike the two DualBranchNet runs earlier in this cell,
# `num_domains` is not passed here, so the constructor default is used --
# confirm it matches len(domains).
dual_model = DualBranchNet(weights_init=True, weights_norm=True).to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity|: 0 for (anti-)parallel features, 1 for
# orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"nores_winit_wnorm_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx
}
# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=False
)

# For DualBranchNet
# Experiment "nores_gradclip_winit_wnorm": weights_init and weights_norm
# enabled, with gradient clipping.
# NOTE(review): `num_domains` is not passed, so the constructor default is
# used -- confirm it matches len(domains).
dual_model = DualBranchNet(weights_init=True, weights_norm=True).to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity|: 0 for (anti-)parallel features, 1 for
# orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"nores_gradclip_winit_wnorm_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx
}
# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping = True
)


In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

# Setup: compute device and domain -> integer index map.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Experiment "deep": DualBranchNet_deep with weights_init on and weights_norm
# off, Adam (lr=1e-3), rehearsal buffer of size 1000, gradient clipping on.
dual_model = DualBranchNet_deep(weights_init=True, weights_norm=False).to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity|: 0 for (anti-)parallel features, 1 for
# orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"deep_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx
}
# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=True
)


# Experiment "deep_norm": DualBranchNet_deep with both weights_init and
# weights_norm enabled; otherwise the same setup as the "deep" run.
dual_model = DualBranchNet_deep(weights_init=True, weights_norm=True).to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity|: 0 for (anti-)parallel features, 1 for
# orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"deep_norm_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx
}
# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=True
)

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

# Setup: compute device and domain -> integer index map.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Experiment "bbdetach_deep": DualBranchNet_deep built with detach_base=True
# and weights_init / weights_norm / layer_norm all disabled.
# NOTE(review): the model is built with detach_base=True but the training call
# below leaves unified_train_loop's `detach_base` at its default -- confirm
# the two are meant to differ.
dual_model = DualBranchNet_deep(weights_init=False, weights_norm=False, layer_norm=False, detach_base=True).to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity|: 0 for (anti-)parallel features, 1 for
# orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"bbdetach_deep_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx
}
# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=True
)

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

# Setup: compute device and domain -> integer index map.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Experiment "bdetach_batch16": DualBranchNet_deep with weights_init=True and
# detach_base=True, trained with the first batch size set to 16 (see the
# dataloaders built below).
dual_model = DualBranchNet_deep(weights_init=True, weights_norm=False, layer_norm=False, detach_base=True).to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity|: 0 for (anti-)parallel features, 1 for
# orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"bdetach_batch16_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx
}

# Rebuild the per-domain dataloaders with batch_sizes=(16, 64, 64) and
# 128x128 images; this overwrites the notebook-global `domain_dataloaders`.
domain_dataloaders = {}
for domain in domains:
    domain_df = df[df['domain'] == domain]
    loaders = create_dataloaders(domain_df, batch_sizes=(16, 64, 64), resize_img_to=(128, 128), seed=SEED)  #TODO should be (384, 216) to retain scale or (224, 224) for best performance on MobileNet
    domain_dataloaders[domain] = loaders

# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=True,
    detach_base=True
)

# Rebuild the per-domain dataloaders with batch_sizes=(32, 64, 64) and
# 128x128 images; this overwrites the notebook-global `domain_dataloaders`.
domain_dataloaders = {}
for domain in domains:
    domain_df = df[df['domain'] == domain]
    loaders = create_dataloaders(domain_df, batch_sizes=(32, 64, 64), resize_img_to=(128, 128), seed=SEED)  #TODO should be (384, 216) to retain scale or (224, 224) for best performance on MobileNet
    domain_dataloaders[domain] = loaders


# Experiment "bdetach_pfrozen": DualBranchNet_deep with detach_base=True and
# freeze_base='partial'.
dual_model = DualBranchNet_deep(weights_init=False, weights_norm=False, layer_norm=False, detach_base=True, freeze_base='partial').to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity|: 0 for (anti-)parallel features, 1 for
# orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"bdetach_pfrozen_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx
}
# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=True,
    detach_base=True
)

# Experiment "ffrozen": DualBranchNet_deep with freeze_base='full' and
# detach_base=False (both in the model and in the training call).
dual_model = DualBranchNet_deep(weights_init=False, weights_norm=False, layer_norm=False, detach_base=False, freeze_base='full').to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity|: 0 for (anti-)parallel features, 1 for
# orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"ffrozen_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx
}
# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=True,
    detach_base=False
)

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

# Setup: compute device and domain -> integer index map.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Experiment "ffrozen_binary": DualBranchNet_binary with freeze_base='full'
# and explicit_grl=False, trained with binary=True and full_replay=False.
dual_model = DualBranchNet_binary(weights_init=False, weights_norm=False, layer_norm=False, detach_base=False, freeze_base='full', explicit_grl=False).to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity|: 0 for (anti-)parallel features, 1 for
# orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"ffrozen_binary_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function;
# 'bce_criterion' is read by dualbranch_batch when binary=True.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx,
    'bce_criterion': nn.BCEWithLogitsLoss()
}
# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=True,
    detach_base=False,
    binary = True,
    full_replay = False
)

# Experiment "ffrozen_binary_explicitgrl": DualBranchNet_binary with
# freeze_base='full', trained with binary=True and full_replay=False.
# NOTE(review): the experiment name says "explicitgrl" but the model is built
# with explicit_grl=False, i.e. with the same arguments as the
# "ffrozen_binary" run -- possible copy-paste slip; confirm whether
# explicit_grl=True was intended.
dual_model = DualBranchNet_binary(weights_init=False, weights_norm=False, layer_norm=False, detach_base=False, freeze_base='full', explicit_grl=False).to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity|: 0 for (anti-)parallel features, 1 for
# orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"ffrozen_binary_explicitgrl_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function;
# 'bce_criterion' is read by dualbranch_batch when binary=True.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx,
    'bce_criterion': nn.BCEWithLogitsLoss()
}
# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=True,
    detach_base=False,
    binary=True,
    full_replay = False
)


# Experiment "freplay_ffrozen_binary_explicitgrl": DualBranchNet_binary with
# weights_init=True, freeze_base='full' and explicit_grl=True, trained with
# binary=True and full_replay=True (every sample is treated as current).
dual_model = DualBranchNet_binary(weights_init=True, weights_norm=False, layer_norm=False, detach_base=False, freeze_base='full', explicit_grl=True).to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Mean of 1 - |cosine similarity|: 0 for (anti-)parallel features, 1 for
# orthogonal ones.
def cos_criterion(a, b):
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    # return (nn.CosineSimilarity()(a, b) ** 2).mean()

# Timestamped name used for the TensorBoard directory and checkpoint files.
exp_name = f"freplay_ffrozen_binary_explicitgrl_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Criteria and domain index map forwarded to the batch function;
# 'bce_criterion' is read by dualbranch_batch when binary=True.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'ce_criterion': nn.CrossEntropyLoss(),
    'cos_criterion': cos_criterion,
    'domain_to_idx': domain_to_idx,
    'bce_criterion': nn.BCEWithLogitsLoss()
}
# NOTE(review): no `loss_params` argument is passed and `batch_kwargs` has no
# 'alpha' key -- confirm the training loop in use accepts that.
unified_train_loop(
    model=dual_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=optimizer,
    writer=writer,
    device=device,
    batch_fn=dualbranch_batch,
    batch_kwargs=dualbranch_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=True,
    detach_base=False,
    binary=True,
    full_replay = True
)

In [None]:
from itertools import product
import datetime
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Scenario name -> (backbone_type, branch_type, end_type) for DualBranchCNNNet.
# Commented-out entries are alternative configurations kept for reference.
testing_scenarios = {
    # 'pretrained_backbone_cnn_branch': ('pretrained', 'simple', 'simple'), 
    # 'linear_branch': ('3conv', 'linear', 'simple'),
    # 'cnn_branch': ('3conv', 'simple', 'simple'),
    # 'adversarial': ('2conv', 'adversarial', 'adversarial'),
    'cnn_specialised_branches': ('3conv', 'special', 'simple')
}


def cos_criterion(a, b):
    """Return the batch mean of ``1 - |cosine_similarity(a, b)|``."""
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()


for name, (backbone, branch, end) in testing_scenarios.items():
    print(f"\nTesting: {backbone} + {branch} + {end}")
    # Every scenario starts from a new model, optimizer and rehearsal buffer.
    dual_model = DualBranchCNNNet(backbone_type=backbone, branch_type=branch, end_type=end, batch_norm=True).to(device)
    optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    buffer = NaiveRehearsalBuffer(buffer_size=1000)

    exp_name = f"CNN_{name}_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

    dualbranch_kwargs = {
        'mse_criterion': nn.MSELoss(),
        'ce_criterion': nn.CrossEntropyLoss(),
        'cos_criterion': cos_criterion,
        'domain_to_idx': domain_to_idx,
        'bce_criterion': nn.BCEWithLogitsLoss()
    }
    try:
        unified_train_loop(
            model=dual_model,
            domains=domains,
            domain_dataloaders=domain_dataloaders,
            buffer=buffer,
            optimizer=optimizer,
            writer=writer,
            device=device,
            batch_fn=dualbranch_batch,
            batch_kwargs=dualbranch_kwargs,
            num_epochs=10,
            exp_name=exp_name,
            gradient_clipping=False,
            detach_base=False,
            binary=True,
            full_replay=True,
            collect_tsne_data=False,
        )
    finally:
        # One writer is opened per scenario; close it so event files are
        # flushed and handles do not accumulate across iterations.
        writer.close()


In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Scenario name -> (backbone_type, branch_type, end_type, detach_base).
# The last element is forwarded to unified_train_loop as detach_base.
testing_scenarios = {
    'pretrained_simple': ('pretrained', 'simple', '3linear', False), 
    'pretrained_special': ('pretrained', 'special', '3linear', False),
    '3conv_simple': ('3conv', 'simple', '3linear', True),
    '3conv_special': ('3conv', 'special', '3linear', True),
}


def cos_criterion(a, b):
    """Return the batch mean of ``1 - |cosine_similarity(a, b)|``."""
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()


for name, (backbone, branch, end, detach_base) in testing_scenarios.items():
    print(f"\nTesting: {backbone} + {branch} + {end}")
    # Every scenario starts from a new model, optimizer and rehearsal buffer.
    dual_model = DualBranchCNNNet(backbone_type=backbone, branch_type=branch, end_type=end, batch_norm=True).to(device)
    optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    buffer = NaiveRehearsalBuffer(buffer_size=1000)

    exp_name = f"CNN_{name}_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

    dualbranch_kwargs = {
        'mse_criterion': nn.MSELoss(),
        'ce_criterion': nn.CrossEntropyLoss(),
        'cos_criterion': cos_criterion,
        'domain_to_idx': domain_to_idx,
        'bce_criterion': nn.BCEWithLogitsLoss()
    }
    try:
        unified_train_loop(
            model=dual_model,
            domains=domains,
            domain_dataloaders=domain_dataloaders,
            buffer=buffer,
            optimizer=optimizer,
            writer=writer,
            device=device,
            batch_fn=dualbranch_batch,
            batch_kwargs=dualbranch_kwargs,
            num_epochs=10,
            exp_name=exp_name,
            gradient_clipping=True,
            detach_base=detach_base,
            binary=True,
            full_replay=True,
            collect_tsne_data=False,
        )
    finally:
        # One writer is opened per scenario; close it so event files are
        # flushed and handles do not accumulate across iterations.
        writer.close()

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Scenario name -> (backbone_type, branch_type, end_type) for DualBranchCNNNet.
testing_scenarios = {
    'pretrained_simple_0.25branches': ('pretrained', 'simple', '3linear'), 
    # '3conv_adversarial': ('3conv', 'adversarial', 'adversarial'),
}


def cos_criterion(a, b):
    """Return the batch mean of ``1 - |cosine_similarity(a, b)|``."""
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()


for name, (backbone, branch, end) in testing_scenarios.items():
    print(f"\nTesting: {backbone} + {branch} + {end}")
    # Every scenario starts from a new model, optimizer and rehearsal buffer.
    dual_model = DualBranchCNNNet(backbone_type=backbone, branch_type=branch, end_type=end, batch_norm=True).to(device)
    optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    buffer = NaiveRehearsalBuffer(buffer_size=1000)

    exp_name = f"CNN_{name}_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

    dualbranch_kwargs = {
        'mse_criterion': nn.MSELoss(),
        'ce_criterion': nn.CrossEntropyLoss(),
        'cos_criterion': cos_criterion,
        'domain_to_idx': domain_to_idx,
        'bce_criterion': nn.BCEWithLogitsLoss()
    }
    try:
        unified_train_loop(
            model=dual_model,
            domains=domains,
            domain_dataloaders=domain_dataloaders,
            buffer=buffer,
            optimizer=optimizer,
            writer=writer,
            device=device,
            batch_fn=dualbranch_batch,
            batch_kwargs=dualbranch_kwargs,
            num_epochs=10,
            exp_name=exp_name,
            gradient_clipping=True,
            detach_base=False,
            binary=True,
            full_replay=True,
            collect_tsne_data=False,
            # Per-term loss weights; the head gets 0.5 and each branch 0.25.
            loss_params={'head': 0.5, 'social': 0.25, 'room': 0.25},
        )
    finally:
        # One writer is opened per scenario; close it so event files are
        # flushed and handles do not accumulate across iterations.
        writer.close()

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Scenario name -> (backbone_type, branch_type, end_type) for DualBranchCNNNet.
testing_scenarios = {
    # 'pretrained_simple': ('pretrained', 'simple', '3linear'), 
    '3conv_adversarial': ('3conv', 'adversarial', 'adversarial'),
}


def cos_criterion(a, b):
    """Return the batch mean of ``1 - |cosine_similarity(a, b)|``."""
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()


for name, (backbone, branch, end) in testing_scenarios.items():
    print(f"\nTesting: {backbone} + {branch} + {end}")
    fold_history = {}
    # The per-fold loaders get their own name so the notebook-global
    # `domain_dataloaders` (used by other cells) is not overwritten with the
    # last fold's loaders.
    for fold_num, fold_dataloaders in tqdm(fold_loaders.items()):
        # Each fold trains a new model with its own optimizer and buffer.
        dual_model = DualBranchCNNNet(backbone_type=backbone, branch_type=branch, end_type=end, batch_norm=True).to(device)
        optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
        buffer = NaiveRehearsalBuffer(buffer_size=1000)

        exp_name = f"fold{fold_num}_CNN_{name}_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

        dualbranch_kwargs = {
            'mse_criterion': nn.MSELoss(),
            'ce_criterion': nn.CrossEntropyLoss(),
            'cos_criterion': cos_criterion,
            'domain_to_idx': domain_to_idx,
            'bce_criterion': nn.BCEWithLogitsLoss()
        }
        try:
            history = unified_train_loop(
                model=dual_model,
                domains=domains,
                domain_dataloaders=fold_dataloaders,
                buffer=buffer,
                optimizer=optimizer,
                writer=writer,
                device=device,
                batch_fn=dualbranch_batch,
                batch_kwargs=dualbranch_kwargs,
                num_epochs=10,
                exp_name=exp_name,
                gradient_clipping=True,
                detach_base=False,
                binary=True,
                full_replay=True,
                collect_tsne_data=False,
                loss_params={'head': 0.5, 'social': 0.25, 'room': 0.25},
            )
        finally:
            # One writer is opened per fold; close it so event files are
            # flushed and handles do not accumulate across folds.
            writer.close()
        fold_history[fold_num] = history
    # NOTE: exp_name still holds the name of the last fold trained, so the
    # pickle's file name carries that fold's tag and timestamp.
    with open(f"../checkpoints/5foldcrossval_{exp_name}_history.pkl", "wb") as f:
        pickle.dump(fold_history, f)

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}


# Single DANN run with a fresh model, optimizer and rehearsal buffer.
dual_model = DualBranchNet_DANN().to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)


def cos_criterion(a, b):
    """Return the batch mean of ``1 - |cosine_similarity(a, b)|``.

    Defined here as in the sibling cells, but not included in the kwargs
    passed to ``dann_batch`` below.
    """
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()


exp_name = f"DANN_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# Only the MSE and BCE criteria plus the domain index map are passed on.
dualbranch_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'domain_to_idx': domain_to_idx,
    'bce_criterion': nn.BCEWithLogitsLoss()
}
try:
    history = unified_train_loop(
        model=dual_model,
        domains=domains,
        domain_dataloaders=domain_dataloaders,
        buffer=buffer,
        optimizer=optimizer,
        writer=writer,
        device=device,
        batch_fn=dann_batch,
        batch_kwargs=dualbranch_kwargs,
        num_epochs=10,
        exp_name=exp_name,
        gradient_clipping=True,
        detach_base=False,
        binary=True,
        full_replay=True,
        collect_tsne_data=False,
        loss_params={'head': 1, 'social': 2},
    )
finally:
    # Flush pending events and release the event file even if training is
    # interrupted; SummaryWriter.close() is safe to call more than once.
    writer.close()

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Build one set of loaders per domain at 512x288 for the baseline model.
# Sizes used elsewhere: 352,128; LGR had 128,128; MobileNetv2 had 224,224.
domain_dataloaders = {}
for domain in domains:
    domain_df = df[df['domain'] == domain]
    loaders, split_idx = create_dataloaders(domain_df, batch_sizes=(32, 64, 64), resize_img_to=(512,288), seed=SEED, return_splits=True)
    domain_dataloaders[domain] = loaders

baseline_model = LGRBaseline().to(device)
optimizer = torch.optim.Adam(baseline_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

exp_name = f"baselinemodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

# The baseline batch function is given only an MSE criterion.
baseline_kwargs = {'mse_criterion': nn.MSELoss()}
try:
    unified_train_loop(
        model=baseline_model,
        domains=domains,
        domain_dataloaders=domain_dataloaders,
        buffer=buffer,
        optimizer=optimizer,
        writer=writer,
        device=device,
        batch_fn=baseline_batch,
        batch_kwargs=baseline_kwargs,
        num_epochs=10,
        exp_name=exp_name,
        loss_params={},
    )
finally:
    # Flush pending events and release the event file even if training is
    # interrupted; SummaryWriter.close() is safe to call more than once.
    writer.close()

# Rebuild the per-domain loaders at 256x144 for the minimal models.
# Sizes used elsewhere: 512,288; 352,128; LGR had 128,128; MobileNetv2 had 224,224.
domain_dataloaders = {}
for domain in domains:
    domain_df = df[df['domain'] == domain]
    loaders, split_idx = create_dataloaders(domain_df, batch_sizes=(32, 64, 64), resize_img_to=(256,144), seed=SEED, return_splits=True)
    domain_dataloaders[domain] = loaders

# Scenario name -> (backbone_type, branch_type, batch_norm, buffer_size).
testing_scenarios = {
    'simple': ('3conv', 'simple', False, 1000), 
    'simple_buffer05': ('3conv', 'simple', False, 500),
    'special': ('3conv', 'special', False, 1000),
    'special_bnorm': ('3conv', 'special', True, 1000)
}

for name, (backbone, branch, batch_norm, buffer_size) in testing_scenarios.items():
    print(f"\nTesting: {backbone=} {branch=} {batch_norm=} {buffer_size=}")
    # Every scenario starts from a new model, optimizer and rehearsal buffer.
    dual_model = DualBranchNet_minimal(backbone_type=backbone, branch_type=branch, head_type='3layer', batch_norm=batch_norm).to(device)
    optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    buffer = NaiveRehearsalBuffer(buffer_size=buffer_size)

    exp_name = f"minimal_{name}_dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

    # The minimal model is trained through the baseline batch function,
    # which is given only an MSE criterion.
    baseline_kwargs = {'mse_criterion': nn.MSELoss()}
    try:
        unified_train_loop(
            model=dual_model,
            domains=domains,
            domain_dataloaders=domain_dataloaders,
            buffer=buffer,
            optimizer=optimizer,
            writer=writer,
            device=device,
            batch_fn=baseline_batch,
            batch_kwargs=baseline_kwargs,
            num_epochs=10,
            exp_name=exp_name,
            gradient_clipping=True,
            detach_base=False,
            binary=True,
            full_replay=True,
            collect_tsne_data=False,
            loss_params={},
        )
    finally:
        # One writer is opened per scenario; close it so event files are
        # flushed and handles do not accumulate across iterations.
        writer.close()

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Scenario name -> (model instance, rehearsal buffer size, number of epochs).
testing_scenarios = {
    # 'dualbranch_CNN_1epoch': (DualBranchCNNNet('3conv', 'simple', 'simple'), 500, 1), 
    'minimal_absurd_buff500': (DualBranchNet_minimal(backbone_type='absurd', branch_type='absurd', head_type='absurd'), 500, 10),
}


def cos_criterion(a, b):
    """Return the batch mean of ``1 - |cosine_similarity(a, b)|``."""
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()


for name, (model, buffer_size, epochs) in testing_scenarios.items():
    print(f"\nTesting: {name}")
    dual_model = model.to(device)
    optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    buffer = NaiveRehearsalBuffer(buffer_size=buffer_size)

    exp_name = f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

    # NOTE(review): baseline_kwargs is built but never passed; the call below
    # hands dualbranch_kwargs to baseline_batch. Confirm which one is intended.
    baseline_kwargs = {'mse_criterion': nn.MSELoss()}
    dualbranch_kwargs = {
        'mse_criterion': nn.MSELoss(),
        'ce_criterion': nn.CrossEntropyLoss(),
        'cos_criterion': cos_criterion,
        'domain_to_idx': domain_to_idx,
        'bce_criterion': nn.BCEWithLogitsLoss()
    }

    try:
        unified_train_loop(
            model=dual_model,
            domains=domains,
            domain_dataloaders=domain_dataloaders,
            buffer=buffer,
            optimizer=optimizer,
            writer=writer,
            device=device,
            batch_fn=baseline_batch,
            batch_kwargs=dualbranch_kwargs,
            num_epochs=epochs,
            exp_name=exp_name,
            gradient_clipping=True,
            detach_base=False,
            binary=True,
            full_replay=True,
            collect_tsne_data=False,
            loss_params={},
        )
    finally:
        # One writer is opened per scenario; close it so event files are
        # flushed and handles do not accumulate across iterations.
        writer.close()

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Each entry: scenario name -> (model instance, buffer size, epochs, alpha).
testing_scenarios = {
    'DANN_notrain1dom': (DualBranchNet_DANN(), 500, 10, 1)
}


def cos_criterion(a, b):
    """Batch mean of one minus the absolute cosine similarity of a and b."""
    similarity = nn.CosineSimilarity()(a, b)
    return (1 - torch.abs(similarity)).mean()


for name, (model, buffer_size, epochs, alpha) in testing_scenarios.items():
    print(f"\nTesting: {name}")
    dual_model = model.to(device)
    optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    buffer = NaiveRehearsalBuffer(buffer_size=buffer_size)

    exp_name = f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    # TensorBoard logging is switched off for this run
    # (was: SummaryWriter(log_dir=f"../tensorboard/{exp_name}")).
    writer = None

    # Criteria, domain index map and the scenario's alpha for the batch function.
    dualbranch_kwargs = dict(
        mse_criterion=nn.MSELoss(),
        ce_criterion=nn.CrossEntropyLoss(),
        cos_criterion=cos_criterion,
        domain_to_idx=domain_to_idx,
        bce_criterion=nn.BCEWithLogitsLoss(),
        alpha=alpha,
    )

    unified_train_loop(
        model=dual_model,
        domains=domains,
        domain_dataloaders=domain_dataloaders,
        buffer=buffer,
        optimizer=optimizer,
        writer=writer,
        device=device,
        batch_fn=dann_batch,
        batch_kwargs=dualbranch_kwargs,
        num_epochs=epochs,
        exp_name=exp_name,
        gradient_clipping=True,
        detach_base=False,
        binary=True,
        full_replay=True,
        collect_tsne_data=False,
        loss_params={'head': 0.5, 'social': 0.5},
    )

# Each entry: scenario name -> (model instance, buffer size, epochs, alpha).
testing_scenarios = {
    'dualbranch_CNN_1epoch': (DualBranchCNNNet(backbone_type='3conv', branch_type='simple', end_type='3linear'), 500, 1, 1), 
    'dualbranch_CNN_dynamicalpha': (DualBranchCNNNet(backbone_type='3conv', branch_type='simple', end_type='3linear'), 500, 10, 0), 
}


def cos_criterion(a, b):
    """Batch mean of one minus the absolute cosine similarity of a and b."""
    similarity = nn.CosineSimilarity()(a, b)
    return (1 - torch.abs(similarity)).mean()


for name, (model, buffer_size, epochs, alpha) in testing_scenarios.items():
    print(f"\nTesting: {name}")
    dual_model = model.to(device)
    optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    buffer = NaiveRehearsalBuffer(buffer_size=buffer_size)

    exp_name = f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    # TensorBoard logging is switched off for this run
    # (was: SummaryWriter(log_dir=f"../tensorboard/{exp_name}")).
    writer = None

    # Criteria, domain index map and the scenario's alpha for the batch function.
    dualbranch_kwargs = dict(
        mse_criterion=nn.MSELoss(),
        ce_criterion=nn.CrossEntropyLoss(),
        cos_criterion=cos_criterion,
        domain_to_idx=domain_to_idx,
        bce_criterion=nn.BCEWithLogitsLoss(),
        alpha=alpha,
    )

    unified_train_loop(
        model=dual_model,
        domains=domains,
        domain_dataloaders=domain_dataloaders,
        buffer=buffer,
        optimizer=optimizer,
        writer=writer,
        device=device,
        batch_fn=dualbranch_batch,
        batch_kwargs=dualbranch_kwargs,
        num_epochs=epochs,
        exp_name=exp_name,
        gradient_clipping=True,
        detach_base=False,
        binary=True,
        full_replay=True,
        collect_tsne_data=False,
        loss_params={'head': 1, 'social': 1, 'room': 1},
    )

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Scenario name -> (model instance, buffer size, epochs, alpha, gradient clipping).
# The last scenario uses end_type='3linear' to match its name; with
# end_type='simple' it would be an exact duplicate of the first scenario.
testing_scenarios = {
    '3conv_simple_simple_bnorm_nogradclip': (DualBranchCNNNet(backbone_type='3conv', branch_type='simple', end_type='simple', batch_norm=True), 500, 10, 1, False), 
    '3conv_simple_simple_nogradclip': (DualBranchCNNNet(backbone_type='3conv', branch_type='simple', end_type='simple', batch_norm=False), 500, 10, 1, False), 
    '3conv_simple_simple_bnorm': (DualBranchCNNNet(backbone_type='3conv', branch_type='simple', end_type='simple', batch_norm=True), 500, 10, 1, True),
    '3conv_simple_3linear_bnorm_nogradclip': (DualBranchCNNNet(backbone_type='3conv', branch_type='simple', end_type='3linear', batch_norm=True), 500, 10, 1, False),
}


def cos_criterion(a, b):
    """Return the batch mean of ``1 - |cosine_similarity(a, b)|``."""
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()


for name, (model, buffer_size, epochs, alpha, gradient_clip) in testing_scenarios.items():
    print(f"\nTesting: {name}")
    dual_model = model.to(device)
    optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    buffer = NaiveRehearsalBuffer(buffer_size=buffer_size)

    exp_name = f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = None #SummaryWriter(log_dir=f"../tensorboard/{exp_name}")

    # Criteria, domain index map and the scenario's alpha for the batch function.
    dualbranch_kwargs = {
        'mse_criterion': nn.MSELoss(),
        'ce_criterion': nn.CrossEntropyLoss(),
        'cos_criterion': cos_criterion,
        'domain_to_idx': domain_to_idx,
        'bce_criterion': nn.BCEWithLogitsLoss(),
        'alpha': alpha,
    }

    unified_train_loop(
        model=dual_model,
        domains=domains,
        domain_dataloaders=domain_dataloaders,
        buffer=buffer,
        optimizer=optimizer,
        writer=writer,
        device=device,
        batch_fn=dualbranch_batch,
        batch_kwargs=dualbranch_kwargs,
        num_epochs=epochs,
        exp_name=exp_name,
        gradient_clipping=gradient_clip,
        detach_base=False,
        binary=True,
        full_replay=True,
        collect_tsne_data=False,
        loss_params={'head': 1, 'social': 1, 'room': 1},
    )

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

# Reload the dataset and build per-domain loaders with 224x224 images.
df = pd.read_pickle("../data/pepper_data.pkl")
domain_dataloaders = get_dataloader(
    df,
    batch_sizes=(32, 64, 64),
    resize_img_to=(224, 224),
    return_splits=False,
    double_img=False,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Each entry: scenario name -> (model, buffer size, epochs, alpha, gradient clipping).
testing_scenarios = {
    'deeplabv3mobilenetv3_dann': (DANN_classifier_poc(), 500, 30, 0, True),
}


def cos_criterion(a, b):
    """Batch mean of one minus the absolute cosine similarity of a and b."""
    similarity = nn.CosineSimilarity()(a, b)
    return (1 - torch.abs(similarity)).mean()


for name, (model, buffer_size, epochs, alpha, gradient_clip) in testing_scenarios.items():
    print(f"\nTesting: {name}")
    dual_model = model.to(device)
    optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    buffer = NaiveRehearsalBuffer(buffer_size=buffer_size)

    exp_name = f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    # TensorBoard logging is switched off for this run
    # (was: SummaryWriter(log_dir=f"../tensorboard/{exp_name}")).
    writer = None

    # Criteria, domain index map and the scenario's alpha for the batch function.
    dualbranch_kwargs = dict(
        mse_criterion=nn.MSELoss(),
        ce_criterion=nn.CrossEntropyLoss(),
        cos_criterion=cos_criterion,
        domain_to_idx=domain_to_idx,
        bce_criterion=nn.BCEWithLogitsLoss(),
        alpha=alpha,
    )

    unified_train_loop(
        model=dual_model,
        domains=domains,
        domain_dataloaders=domain_dataloaders,
        buffer=buffer,
        optimizer=optimizer,
        writer=writer,
        device=device,
        batch_fn=dann_batch,
        batch_kwargs=dualbranch_kwargs,
        num_epochs=epochs,
        exp_name=exp_name,
        gradient_clipping=gradient_clip,
        detach_base=False,
        binary=True,
        full_replay=True,
        collect_tsne_data=True,
        loss_params={'head': 1, 'social': 1},
    )

In [None]:
# Didn't work: the architecture is flawed. Assessing the category of features based on catastrophic forgetting makes no sense:
# the features are not experiencing CF, the weights are.
# But passing weights to split branches is not justified.

# Load the data first so that `domains` is derived from the dataframe that
# the loaders are actually built from.
df = pd.read_pickle("../data/pepper_data.pkl")
domain_dataloaders = get_dataloader(df, batch_sizes=(32, 64, 64), resize_img_to=(224,224), return_splits=False, double_img=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Gated dual-branch model.
cf_model = CatastrophicForgettingDisentanglementModel(
    backbone_output_channels=640,
    branch_hidden_channels=256,
    branch_output_channels=128,
    num_outputs=9
).to(device)

# Main optimizer: adapter, both branches, head and fusion (lr=1e-3).
main_optimizer = torch.optim.Adam([
    {'params': cf_model.forgetting_adapter.adapter.parameters()},
    {'params': cf_model.branch_forgetting.parameters()},
    {'params': cf_model.branch_not_forgetting.parameters()},
    {'params': cf_model.head.parameters()},
    {'params': cf_model.fusion.parameters()}
], lr=1e-3)

# The gating sub-module gets its own optimizer with a smaller lr (1e-4).
gating_optimizer = torch.optim.Adam(
    cf_model.forgetting_adapter.gating.parameters(),
    lr=1e-4
)

# Resume state handed to the training loop. To resume an interrupted run,
# replace None with a dict holding the keys 'global_step', 'history' and
# 'domain'.
# NOTE(review): None assumes catastrophic_forgetting_train_loop treats a
# missing restart state as "start from scratch" -- confirm against the loop.
restart = None

# Rehearsal buffer shared with the other experiments' setup.
buffer = NaiveRehearsalBuffer(buffer_size=1000)

# Batch kwargs for this model: only the MSE criterion and the domain index map.
cf_kwargs = {
    'mse_criterion': nn.MSELoss(),
    'domain_to_idx': domain_to_idx,
}

exp_name = f"forgetinggated_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"

history = catastrophic_forgetting_train_loop(
    model=cf_model,
    domains=domains,
    domain_dataloaders=domain_dataloaders,
    buffer=buffer,
    optimizer=main_optimizer,
    gating_optimizer=gating_optimizer,
    device=device,
    batch_fn=catastrophic_forgetting_batch,
    batch_kwargs=cf_kwargs,
    num_epochs=10,
    exp_name=exp_name,
    gradient_clipping=True,
    loss_params={'main': 1.0, 'adapter': 0.5, 'gating': 0.1},
    restart=restart,
)

In [None]:
import datetime

# Reload the dataset and build per-domain loaders with resize_img_to=(288, 512).
df = pd.read_pickle("../data/pepper_data.pkl")
domain_dataloaders = get_dataloader(
    df,
    batch_sizes=(32, 64, 64),
    resize_img_to=(288, 512),
    return_splits=False,
    double_img=False,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Each entry: name -> (model, buffer size, epochs, alpha, detach_base, batch function).
testing_scenarios = {
    # 'dann_mobilenet': (DualBranchNet_DANN(backbone_type='mobilenet'), 500, 10, 0, False, dann_batch),
    # 'dual_2conv_adversarial_3linear_detach': (DualBranchCNNNet(9,6,'2conv', 'adversarial', '3linear'), 500, 10, 0, True, dualbranch_batch),
    # 'dual_mobilenet_simple_3linear': (DualBranchCNNNet(9,6,'pretrained', 'simple', '3linear'), 500, 10, 0, False, dualbranch_batch),
    # 'dual_mobilenet_linear_simple': (DualBranchCNNNet(9,6,'pretrained', 'linear', 'simple'), 500, 10, 0, False, dualbranch_batch),
    'baseline': (LGRBaseline(), 500, 10, 0, False, baseline_batch)
}


def cos_criterion(a, b):
    """Batch mean of one minus the absolute cosine similarity of a and b."""
    similarity = nn.CosineSimilarity()(a, b)
    return (1 - torch.abs(similarity)).mean()


for name, (model, buffer_size, epochs, alpha, detach_base, batch_fn) in testing_scenarios.items():
    print(f"\nTesting: {name}")
    dual_model = model.to(device)
    optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    buffer = NaiveRehearsalBuffer(buffer_size=buffer_size)

    exp_name = f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = None  # no TensorBoard logging for this run

    # The same kwargs dict is handed to whichever batch function the scenario selects.
    dualbranch_kwargs = dict(
        mse_criterion=nn.MSELoss(),
        ce_criterion=nn.CrossEntropyLoss(),
        cos_criterion=cos_criterion,
        domain_to_idx=domain_to_idx,
        bce_criterion=nn.BCEWithLogitsLoss(),
        alpha=alpha,
    )

    unified_train_loop(
        model=dual_model,
        domains=domains,
        domain_dataloaders=domain_dataloaders,
        buffer=buffer,
        optimizer=optimizer,
        writer=writer,
        device=device,
        batch_fn=batch_fn,
        batch_kwargs=dualbranch_kwargs,
        num_epochs=epochs,
        exp_name=exp_name,
        gradient_clipping=True,
        detach_base=detach_base,
        binary=True,
        full_replay=True,
        collect_tsne_data=False,
        loss_params={'head': 1, 'social': 1, 'room': 0.5},
    )

In [None]:
# Reload the dataset, then derive the device, the domain list and its index map.
df = pd.read_pickle("../data/pepper_data.pkl")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}

# Per-domain loaders with resize_img_to=(288, 512).
domain_dataloaders = get_dataloader(
    df,
    batch_sizes=(32, 64, 64),
    resize_img_to=(288, 512),
    return_splits=False,
    double_img=False,
)

# Each entry: name -> (model, buffer size, epochs, alpha, detach_base, batch function).
# Only the uncommented scenario is trained; the others are kept for reference.
testing_scenarios = {
    # 'dann_mobilenet_buff500': (DualBranchNet_DANN(backbone_type='mobilenet'), 500, 10, 0, False, dann_batch),
    'dann_mobilenet_buff120': (DualBranchNet_DANN(backbone_type='mobilenet'), 120, 10, 0, False, dann_batch),
    # 'dual_2conv_adversarial_3linear_detach_buff500': (DualBranchCNNNet(9,6,'2conv', 'adversarial', '3linear'), 500, 10, 0, True, dualbranch_batch),
    # 'dual_2conv_adversarial_3linear_detach_buff120': (DualBranchCNNNet(9,6,'2conv', 'adversarial', '3linear'), 120, 10, 0, True, dualbranch_batch),
    # 'dual_mobilenet_simple_3linear_buff500': (DualBranchCNNNet(9,6,'pretrained', 'simple', '3linear'), 500, 10, 0, False, dualbranch_batch),
    # 'dual_mobilenet_simple_3linear_buff120': (DualBranchCNNNet(9,6,'pretrained', 'simple', '3linear'), 120, 10, 0, False, dualbranch_batch),
    # 'dual_mobilenet_linear_simple_buff120': (DualBranchCNNNet(9,6,'pretrained', 'linear', 'simple'), 120, 10, 0, False, dualbranch_batch),
    # 'baseline_buff500': (LGRBaseline(), 500, 10, 0, False, baseline_batch),
    # 'baseline_buff120': (LGRBaseline(), 120, 10, 0, False, baseline_batch)
}


def cos_criterion(a, b):
    """Batch mean of one minus the absolute cosine similarity of a and b."""
    similarity = nn.CosineSimilarity()(a, b)
    return (1 - torch.abs(similarity)).mean()


for name, (model, buffer_size, epochs, alpha, detach_base, batch_fn) in testing_scenarios.items():
    print(f"\nTesting: {name}")
    dual_model = model.to(device)
    optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    buffer = NaiveRehearsalBuffer(buffer_size=buffer_size)

    exp_name = f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = None  # no TensorBoard logging for this run

    # The same kwargs dict is handed to whichever batch function the scenario selects.
    dualbranch_kwargs = dict(
        mse_criterion=nn.MSELoss(),
        ce_criterion=nn.CrossEntropyLoss(),
        cos_criterion=cos_criterion,
        domain_to_idx=domain_to_idx,
        bce_criterion=nn.BCEWithLogitsLoss(),
        alpha=alpha,
    )

    unified_train_loop(
        model=dual_model,
        domains=domains,
        domain_dataloaders=domain_dataloaders,
        buffer=buffer,
        optimizer=optimizer,
        writer=writer,
        device=device,
        batch_fn=batch_fn,
        batch_kwargs=dualbranch_kwargs,
        num_epochs=epochs,
        exp_name=exp_name,
        gradient_clipping=True,
        detach_base=detach_base,
        binary=True,
        full_replay=True,
        collect_tsne_data=False,
        loss_params={'head': 1, 'social': 1, 'room': 0.5},
    )

# import subprocess

# model_names = [
#     'dann_mobilenet_buff500',
#     'dann_mobilenet_buff120',
#     'dual_2conv_adversarial_3linear_detach_buff500',
#     'dual_2conv_adversarial_3linear_detach_buff120',
#     'dual_mobilenet_simple_3linear_buff500',
#     'dual_mobilenet_simple_3linear_buff120',
#     'dual_mobilenet_linear_simple_buff120',
#     'baseline_buff500',
#     'baseline_buff120'
#  ]

# for model_name in model_names:
#     print(f"Running {model_name}")
#     result = subprocess.run(
#         ["python", "train_models.py", "--model_name", model_name],
#         capture_output=True, text=True
#     )
#     print(result.stdout)
#     if result.stderr:
#         print("Error:", result.stderr)


### current

In [None]:
# TODO tests: retrain the model with
# 3 trains averaged
# one branch, no mask
# top branch + mask
# bottom branch + mask
# both branches + random masks

In [None]:

# Scenario name -> (model, [social image size, environment image size],
# whether to create a separate evaluation buffer).
testing_scenarios = {
    # 'heuristic_small_env': (DualBranchModel(), [(288,512), (144,256)], False),
    # 'heuristic_square_img': (DualBranchModel(), [(224,224)]*2, False),
    'heuristic_eval_buffer': (DualBranchModel(), [(288,512)]*2, True)
}

# The dataframe, device and domain mapping do not depend on the scenario,
# so they are built once instead of on every iteration.
df = pd.read_pickle("../data/pepper_data.pkl")
# Point each sample at its masked environment / social image (same file name).
df['image_path_env'] = df['image_path'].apply(lambda p: str(Path('../data/masked/environment') / Path(p).name))
df['image_path_social'] = df['image_path'].apply(lambda p: str(Path('../data/masked/social') / Path(p).name))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()
domain_to_idx = {d: i for i, d in enumerate(domains)}


def cos_criterion(a, b):
    """Return the batch mean of ``1 - |cosine_similarity(a, b)|``."""
    return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()


def _masked_image_transform(size):
    """Build the resize -> tensor -> normalise pipeline for one image stream."""
    return transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])


for name, (model, img_size, use_eval_buffer) in testing_scenarios.items():

    # img_size[0] is used for the social stream, img_size[1] for the environment stream.
    transform_soc = _masked_image_transform(img_size[0])
    transform_env = _masked_image_transform(img_size[1])

    domain_dataloaders = get_dataloader(df, batch_sizes=(16, 64, 64), return_splits=False, double_img=True, transforms=[transform_soc, transform_env], num_workers=0)

    print(f"\nTesting: {name}")
    dual_model = model.to(device)
    # Branches and head are optimised separately from the two classifiers.
    # torch.optim is spelled out, matching the other cells of this notebook.
    optimizer = torch.optim.Adam([
        {'params': dual_model.social_branch.parameters()},
        {'params': dual_model.env_branch.parameters()},
        {'params': dual_model.head.parameters()},
    ], lr=1e-3)
    classifier_optimizer = torch.optim.Adam([
        {'params': dual_model.social_classifier.parameters()},
        {'params': dual_model.env_classifier.parameters()}
    ], lr=1e-3)

    buffer = NaiveRehearsalBuffer(buffer_size=120)
    # Keep the scenario flag and the buffer object under different names; when
    # the flag is off, the flag value itself is passed on, as before.
    eval_buffer = NaiveRehearsalBuffer(buffer_size=120) if use_eval_buffer else use_eval_buffer

    exp_name = f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = None

    dualbranch_kwargs = {
        'mse_criterion': nn.MSELoss(),
        'ce_criterion': nn.CrossEntropyLoss(),
        'cos_criterion': cos_criterion,
        'domain_to_idx': domain_to_idx,
        'bce_criterion': nn.BCEWithLogitsLoss(),
        'alpha': 0,
        'class_optimizer': classifier_optimizer
    }

    unified_train_loop(
        model=dual_model,
        domains=domains,
        domain_dataloaders=domain_dataloaders,
        buffer=buffer,
        optimizer=optimizer,
        writer=writer,
        device=device,
        batch_fn=heuristic_dualbranch_batch,
        batch_kwargs=dualbranch_kwargs,
        num_epochs=10,
        exp_name=exp_name,
        gradient_clipping=True,
        detach_base=False,
        binary=True,
        full_replay=True,
        collect_tsne_data=False,
        loss_params={'head': 1, 'social': 1, 'room': 0.5},
        eval_buffer=eval_buffer,
    )

In [None]:
import subprocess
import sys

# Train each heuristic model variant in its own subprocess and echo the
# child's output line by line while it runs.
model_names = [
    'heuristic_small_env',
    'heuristic_square_img',
    'heuristic_eval_buffer'
]

for model_name in model_names:
    print(f"Running {model_name}")
    # sys.executable keeps the child on the same interpreter (and therefore the
    # same installed packages) as this kernel, instead of whichever "python"
    # happens to be first on PATH.
    with subprocess.Popen(
        [sys.executable, "train_models.py", "--model_name", model_name, "--num_workers", "0"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into the same stream
        text=True
    ) as process:
        for line in process.stdout:
            print(line, end='')
    # Leaving the `with` block closes the pipe and waits for the child to exit.
    # Report failures instead of silently moving on to the next model.
    if process.returncode != 0:
        print(f"{model_name} exited with return code {process.returncode}")

In [None]:
# Ideas for potential ways to combine networks (pseudocode sketches, not runnable Python)
residual = SpecificHead(base + invariant_feats)
specific_feats = invariant_feats + residual

gate = GateNet(base + invariant_feats)
residual = SpecificHead(base)
specific_feats = invariant_feats + gate * residual

residual = Attention(invariant_feats, base)
specific_feats = invariant_feats + residual

gamma, beta = SpecificHead(base)
specific_feats = gamma * invariant_feats + beta

invariant_feats = InvariantHead(base)
specific_feats = SpecificHead(base)
cos_sim(invariant_feats, specific_feats) ~ 0


## Evaluation

### files

In [None]:
# different cos 1-abs, fixed residual - fixedred_fixedcos_dualbranchmodel_20250604_014811_history
# fixed residual but old cosine ^2 - fixedred_dualbranchmodel_20250604_002951_history
# broken baseline - baselinemodel_20250603_200948_history
# recent dualbranch - dualbranchmodel_20250603_135728_history
# normalised and standardised residual - dualbranchmodel_20250609_013632_history
# standardised residual - nonorm_dualbranchmodel_20250609_131820_history
# no residual connection - nores_dualbranchmodel_20250609_232842_history
# same plus gradient clipping - nores_gradclip_dualbranchmodel_20250610_010647_history
# same, no gradient clipping, weights initialisation and normalisation - nores_winit_wnorm_dualbranchmodel_20250610_024401_history
# same, gradient clipping and weights optimisation - nores_gradclip_winit_wnorm_dualbranchmodel_20250610_042118_history
# deeper invariant network - deep_norm_dualbranchmodel_20250610_223350_history
# same but with weights normalisation - deep_dualbranchmodel_20250610_205411_history

# Load the pickled training histories of the normalisation / residual /
# gradient-clipping runs listed above, one module-level variable per run.
# NOTE: pickle executes code on load; only use it on files you produced yourself.
def _load_history(path):
    """Return the object unpickled from the file at *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


nonorm = _load_history('../checkpoints/nonorm_dualbranchmodel_20250609_131820_history.pkl')
lnorm_nores = _load_history('../checkpoints/nores_dualbranchmodel_20250609_232842_history.pkl')
lnorm_nores_gradclip = _load_history('../checkpoints/nores_gradclip_dualbranchmodel_20250610_010647_history.pkl')
lnorm = _load_history('../checkpoints/dualbranchmodel_20250609_013632_history.pkl')
wnorm_nores = _load_history('../checkpoints/nores_winit_wnorm_dualbranchmodel_20250610_024401_history.pkl')
wnorm_nores_gradclip = _load_history('../checkpoints/nores_gradclip_winit_wnorm_dualbranchmodel_20250610_042118_history.pkl')
wnorm_gradclip_deep = _load_history('../checkpoints/deep_norm_dualbranchmodel_20250610_223350_history.pkl')
lnorm_gradclip_deep = _load_history('../checkpoints/deep_dualbranchmodel_20250610_205411_history.pkl')

In [None]:
# History files for the "detached vs. frozen backbone" experiments.
# All runs use gradient clipping.
# Either detach and skip-connect base -> head
# - detach reference - how does detachment work? can the backbone learn through the skip connection?
# - detach: effect of a smaller batch?
# - detach: how does a partially frozen backbone learn?
# or
# freeze the backbone
# - frozen backbone doesn't learn, branches learn solo
# - frozen, branches learn solo, but binary?
# - frozen, branches learn solo, binary, but explicit gradient reversal function rather than layer?
# - add full replay?

files = [
    'bbdetach_deep_dualbranchmodel_20250613_155613_history',                 # deep, weights_init=False, detach_base=True - base for the experiments below
    'bdetach_batch16_dualbranchmodel_20250615_025032_history', # deep, unfrozen bb, detached bb, smaller batch of 16 - did the smaller batch improve stability, did the skip connection improve head loss and total training/val loss?
    'bdetach_pfrozen_dualbranchmodel_20250615_040906_history', # deep, pfrozen bb, detached bb - training branches + last layer of bb - did the partially frozen pretrained backbone improve the skip head connection?
   
    'ffrozen_dualbranchmodel_20250615_133027_history',                       # deep, weights_init=False, detach_base=False - how do deep linear branches train on their own?
    'ffrozen_binary_dualbranchmodel_20250615_153755_history',                      # weights_init=False, detach_base=False, explicit_grl=False, full_replay = False - does binary improve anything?
    'ffrozen_binary_explicitgrl_dualbranchmodel_20250615_162820_history',          # weights_init=False, detach_base=False, explicit_grl=False, full_replay = False - does the gradient reversal layer work better, or the function?
    'freplay_ffrozen_binary_explicitgrl_dualbranchmodel_20250615_174623_history'  # weights_init=True,  detach_base=False, explicit_grl=True,  full_replay = True - does full replay help the branches on their own?
]

In [None]:
files = [
    # 'CNN_pretrained_simple_simple_dualbranchmodel_20250616_032555_history',  # 'pretrained_backbone_cnn_branch': ('pretrained', 'simple', 'simple'), 
    # 'CNN_diy_backbone_linear_branch_dualbranchmodel_20250616_105211_history', # (2layer convolution backbone, 'linear', 'simple'),
    # 'CNN_linear_branch_dualbranchmodel_20250616_172252_history',  # 'linear_branch': ('3conv', 'linear', 'simple'),
    'CNN_cnn_branch_dualbranchmodel_20250616_190434_history',      # 'cnn_branch': ('3conv', 'simple', 'simple'),
    'CNN_adversarial_dualbranchmodel_20250616_201008_history',      # 'adversarial': ('2conv', 'adversarial', 'adversarial'),
    'CNN_cnn_specialised_branches_dualbranchmodel_20250616_232330_history',      # 'cnn_specialised_branches': ('3conv', 'special', 'simple')
    'CNN_pretrained_simple_0.25branches_dualbranchmodel_20250617_230422_history',
    'CNN_pretrained_special_dualbranchmodel_20250617_032445_history'
]


    

In [None]:
files = ['CNN_pretrained_simple_dualbranchmodel_20250617_022356_history', #'pretrained_simple': ('pretrained', 'simple', '3linear', detach_base=False), 
'CNN_pretrained_special_dualbranchmodel_20250617_032445_history', #    'pretrained_special': ('pretrained', 'special', '3linear', False),
'CNN_3conv_simple_dualbranchmodel_20250617_042555_history', #    '3conv_simple': ('3conv', 'simple', '3linear', True),
'CNN_3conv_special_dualbranchmodel_20250617_052951_history' #    '3conv_special': ('3conv', 'special', '3linear', True),
]

In [None]:
files = ['CNN_pretrained_simple_dualbranchmodel_20250617_022356_history',
'CNN_pretrained_simple_0.25branches_dualbranchmodel_20250617_230422_history' #change the loss proportions
]

In [None]:
files =['5foldcrossval_fold4_CNN_3conv_adversarial_dualbranchmodel_20250618_121607_history',
'5foldcrossval_fold4_CNN_pretrained_simple_dualbranchmodel_20250618_064848_history'
]
files = [
    'CNN_pretrained_simple_0.25branches_dualbranchmodel_20250617_230422_history',
    'CNN_adversarial_dualbranchmodel_20250616_201008_history'
]

In [None]:
files = ['DANN_dualbranchmodel_20250618_152753_history',
         'DANN_dynamicalpha_notrain1dom_20250620_201349_history',
         'DANN_dynamicalpha_notrain1dom_20250620_232330_history'
         ]

In [None]:
files = [
'baselinemodel_20250618_195623_history',
'minimal_simple_dualbranchmodel_20250618_213930_history',
'minimal_simple_buffer05_dualbranchmodel_20250618_224113_history',
'minimal_special_dualbranchmodel_20250618_233832_history'
]


In [None]:
files=[
'baselinemodel_20250603_200948_history',
'baselinemodel_20250603_034233_history',
'baselinemodel_20250618_195623_history'
]

In [None]:
files = [
    'baselinemodel_20250603_200948_history',
    'minimal_absurd_buff500_20250619_121348_history',
    'minimal_simple_buffer05_dualbranchmodel_20250618_224113_history',
    'DANN_dualbranchmodel_20250618_152753_history',
    'CNN_pretrained_simple_0.25branches_dualbranchmodel_20250617_230422_history'
    ]

In [None]:
files = [
'DANN_dynamicalpha_notrain1dom_20250620_232330_history',
'DANN_notrain1dom_20250621_005208_history',
]

In [None]:
files = [
'dualbranch_CNN_dynamicalpha_20250621_015610_history',
'CNN_adversarial_dualbranchmodel_20250616_201008_history',
'CNN_cnn_branch_dualbranchmodel_20250616_190434_history'
]

In [None]:
# Histories compared in this selection: the 1-epoch CNN dual-branch run against
# earlier pretrained/adversarial CNN runs.
files = [
    'dualbranch_CNN_1epoch_20250621_014841_history',
    # Fixed stem: every other reference to this run (timestamp 20250617_230422)
    # spells it "0.25branches"; "025branches" does not match that file name.
    'CNN_pretrained_simple_0.25branches_dualbranchmodel_20250617_230422_history',
    'CNN_pretrained_simple_dualbranchmodel_20250617_022356_history',
    'CNN_adversarial_dualbranchmodel_20250616_201008_history',
]

In [None]:
files = [
'3conv_simple_3linear_bnorm_nogradclip_20250621_130858_history',
'3conv_simple_simple_bnorm_20250621_121437_history',
'3conv_simple_simple_bnorm_nogradclip_20250621_102634_history',
'3conv_simple_simple_nogradclip_20250621_112040_history',
]

In [None]:
files =['mobinev2_dann_20250621_204449_history',
        'deeplabv3mobilenetv3_dann_20250621_223924_history',
        'deeplabv3mobilenetv3_dann_20250624_000526_history'
        ]

In [None]:
# (label, history file stem) pairs; the label becomes the key in `models`.
# Commented-out entries are kept so they can be toggled back into the comparison.
files = [
    # ('dann', 'dann_mobilenet_20250630_014020_history'), # smoother, stable disentanglement but not significantly better results
    # ('dann_old', 'mobinev2_dann_20250621_204449_history'), # old mobilenet dann
    ('detached', 'dual_2conv_adversarial_3linear_detach_20250630_045130_history'),
    ('detached_old', 'CNN_3conv_simple_dualbranchmodel_20250617_042555_history'), # detached, old
    ('trainable_old', 'CNN_adversarial_dualbranchmodel_20250616_201008_history'), # trainable backbone, pretrained, adv adv, old
    ('frozen', 'dual_mobilenet_simple_3linear_20250630_081954_history'),
    ('frozen_old', 'CNN_pretrained_simple_dualbranchmodel_20250617_022356_history'), # pretrained simple, 3linear, old
    # 'dual_mobilenet_linear_simple_20250630_111817_history',  # performs poorly; no surprise that a linear layer cannot extract meaningful features from a frozen backbone
    # ('baseline', 'baseline_20250630_161803_history'), # more stable, similar results
    # ('baseline_old', 'baselinemodel_20250618_195623_history'), # old baseline
]

In [None]:
files = [
('linear', 'dual_mobilenet_linear_simple_20250630_111817_history'),
('linear_b120', 'dual_mobilenet_linear_simple_buff120_20250704_081958_history'), 
('base', 'baseline_20250630_161803_history'),
('base_b500', 'baseline_buff500_20250704_085427_history'), 
('base_b120', 'baseline_buff120_20250704_113155_history'), 
('dann', 'dann_mobilenet_20250630_014020_history'),
('dann_b500', 'dann_mobilenet_buff500_20250704_003743_history'), 
('dann_b120', 'dann_mobilenet_buff120_20250704_031423_history'), 
('detach', 'dual_2conv_adversarial_3linear_detach_20250630_045130_history'),
('detach_b500', 'dual_2conv_adversarial_3linear_detach_buff500_20250704_041620_history'), 
('detach_b120', 'dual_2conv_adversarial_3linear_detach_buff120_20250704_054532_history'),
('frozen', 'dual_mobilenet_simple_3linear_20250630_081954_history'), 
('frozen_b500', 'dual_mobilenet_simple_3linear_buff500_20250704_062411_history'), 
('frozen_b120', 'dual_mobilenet_simple_3linear_buff120_20250704_074516_history'),
]

In [None]:
files = [
('base_b500', 'baseline_buff500_20250704_085427_history'), 
('base_b120', 'baseline_buff120_20250704_113155_history'), 
('new_b500','heuristic_dualbranch_buff500_20250704_213101_history'),
('new_b120','heuristic_dualbranch_buff120_20250705_031333_history'),
('frozen_b500', 'dual_mobilenet_simple_3linear_buff500_20250704_062411_history'), 
('frozen_b120', 'dual_mobilenet_simple_3linear_buff120_20250704_074516_history'),


]

In [None]:
files = [
    ('base_b120', 'baseline_buff120_20250704_113155_history'),
]

In [None]:
files = [
('base_b500', 'baseline_buff500_20250704_085427_history'), 
('base_b120', 'baseline_buff120_20250704_113155_history'), 
('new_b500','heuristic_dualbranch_buff500_20250704_213101_history'),
('new_b120','heuristic_dualbranch_buff120_20250705_031333_history'),
('frozen_b500', 'dual_mobilenet_simple_3linear_buff500_20250704_062411_history'), 
('frozen_b120', 'dual_mobilenet_simple_3linear_buff120_20250704_074516_history'),
('old_b500', 'CNN_pretrained_simple_dualbranchmodel_20250617_022356_history'), 
('linear_b1000', 'dual_mobilenet_linear_simple_20250630_111817_history'),
]

In [None]:
files = [
    ('base','heuristic_dualbranch_buff120_20250705_031333_history'),
    # ('small_env','heuristic_small_env_20250722_222924_history'),
    # ('square','heuristic_square_img_20250722_233743_history'),
    ('eval_buffer','heuristic_eval_buffer_20250723_144515_history'),
]

In [None]:
import os

# Show every pickle file found in the checkpoints directory, one name per line.
pkl_files = []
for entry in os.listdir('../checkpoints/'):
    if entry.endswith('.pkl'):
        pkl_files.append(entry)

if pkl_files:
    print('\n'.join(pkl_files))

In [None]:
import pickle

# Load every history named in `files` into `models`, keyed by a readable name.
models = {}
for i, entry in enumerate(files):
    # An entry is either a bare file stem or a (label, file stem) pair.
    label, stem = entry if isinstance(entry, tuple) else ('', entry)
    with open(f'../checkpoints/{stem}.pkl', 'rb') as f:
        # Without a label, drop the last three '_'-separated fields of the stem
        # and append the list position as a suffix.
        model_name = label or '_'.join(stem.split('_')[:-3]) + str(i)
        models[model_name] = pickle.load(f)

In [None]:
# For fold histories
for k,m in models.items():
    models[k] = combine_fold_histories(m)

In [None]:
models2 = models

In [None]:
models.update(models2)

In [None]:
with open('../checkpoints/bbdetach_deep_dualbranchmodel_20250613_155613_history.pkl', 'rb') as f:
    bbdetach_deep = pickle.load(f)


In [None]:
# Restore one saved checkpoint: the model weights plus the training history and
# t-SNE data stored alongside them.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location remaps the stored tensors onto the device available here, so a
# checkpoint written on a GPU machine still loads on a CPU-only one.
pt_file = torch.load(
    '../checkpoints/dualbranchmodel_20250609_013632_domainSmallOffice_epoch9_step1300.pt',
    map_location=device,
)

domains = df['domain'].unique()
model = DualBranchNet(num_domains=len(domains)).to(device)
model.load_state_dict(pt_file['model_state_dict'])
history = pt_file['history']
tsne_data = pt_file['tsne']

In [None]:
# Print the names of the loaded runs. The loop variable is deliberately not
# called `model`, so it cannot overwrite the network object that other cells
# keep under that name.
for model_name in models:
    print(model_name)

In [None]:
# For every run except the first, replace each cross-domain validation entry
# by its first element (done in place; running this twice unwraps twice).
for run_name in list(models)[1:]:
    for snapshot in models[run_name]['cross_domain_val']:
        snapshot.update({domain_key: value[0] for domain_key, value in snapshot.items()})

In [None]:
# For every run except the first, keep only the first element of each per-epoch
# validation entry, for both the regular and the buffer validation loss.
for run_name in list(models)[1:]:
    run = models[run_name]
    for loss_key in ('val_epoch_loss', 'val_buffer_epoch_loss'):
        run[loss_key] = [entry[0] for entry in run[loss_key]]

### plots

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
sns.set_theme(style="whitegrid")

In [None]:
models = {'bbdetach_deep':bbdetach_deep}

In [None]:
models = {
'nonorm': nonorm,
'lnorm': lnorm,
'lnorm_nores': lnorm_nores,
'lnorm_nores_gradclip': lnorm_nores_gradclip,
'wnorm_nores': wnorm_nores,
'wnorm_nores_gradclip': wnorm_nores_gradclip,
'wnorm_gradclip_deep': wnorm_gradclip_deep,
'lnorm_gradclip_deep': lnorm_gradclip_deep
}

In [None]:
# for i,m in models.items():
#     filtered = [x for i, x in enumerate(m['cross_domain_val']) if i in [9, 19, 29, 39, 49, 58]]
#     m['cross_domain_val'] = filtered

for i,m in models.items():
    print(len(m['cross_domain_val']))

In [None]:
tickvals=[0, 10, 20, 30, 40, 50]
np.multiply(np.array(tickvals), 2).tolist()


In [None]:
import plotly.graph_objects as go
# Train/validation loss per model, one pair of traces each, switched with a dropdown.
traces = []
for model_idx, (model_name, history) in enumerate(models.items()):
    # Only the first model's pair is shown until the dropdown is used.
    shown = model_idx == 0
    traces.append(go.Scatter(
        x=list(range(len(history['train_epoch_loss']))),
        y=history['train_epoch_loss'],
        mode='lines',
        name=f'{model_name} Train Loss',
        visible=shown
    ))
    traces.append(go.Scatter(
        x=list(range(len(history['val_epoch_loss']))),
        y=history['val_epoch_loss'],
        mode='lines',
        name=f'{model_name} Val Loss',
        visible=shown
    ))

# One dropdown entry per model; traces 2*i and 2*i + 1 belong to model i.
buttons = []
for i, model_name in enumerate(models):
    visible = [False] * len(traces)
    visible[2*i] = True
    visible[2*i + 1] = True
    buttons.append(dict(
        label=model_name,
        method='update',
        args=[{'visible': visible}, {'title': f'Training and Validation Loss - {model_name}'}]
    ))

# The initial title names the model that is actually visible (the first one),
# matching what the dropdown writes after a selection.
first_model = next(iter(models), None)
if first_model is None:
    initial_title = 'Training and Validation Loss'
else:
    initial_title = f'Training and Validation Loss - {first_model}'

fig = go.Figure(data=traces)
fig.update_layout(
    updatemenus=[dict(
        active=0,
        buttons=buttons,
        x=0.1,
        y=1.15,
        xanchor='left',
        yanchor='top'
    )],
    title=initial_title,
    xaxis_title='Epochs',
    yaxis_title='MSE Loss',
    template='plotly_white',
    yaxis=dict(range=[0, 1])
)
fig.show()

In [None]:
import plotly.graph_objects as go

# Train / validation (and, where recorded, buffer validation) loss of every
# loaded model in a single figure.
fig = go.Figure()

for model_name, history in models.items():
    fig.add_trace(go.Scatter(
        x=list(range(len(history['train_epoch_loss']))),
        y=history['train_epoch_loss'],
        mode='lines',
        name=f'{model_name} Train Loss'
    ))
    fig.add_trace(go.Scatter(
        x=list(range(len(history['val_epoch_loss']))),
        y=history['val_epoch_loss'],
        mode='lines',
        name=f'{model_name} Val Loss'
    ))
    # Not every history has a buffer validation loss. Check for the key
    # explicitly rather than swallowing every exception, so genuine errors
    # in the plotting code still surface.
    if 'val_buffer_epoch_loss' in history:
        fig.add_trace(go.Scatter(
            x=list(range(len(history['val_buffer_epoch_loss']))),
            y=history['val_buffer_epoch_loss'],
            mode='lines',
            name=f'{model_name} Buffer Val Loss'
        ))

fig.update_layout(
    title='Training and Validation Loss',
    xaxis_title='Epochs',
    yaxis_title='MSE Loss',
    template='plotly_white',
    yaxis=dict(range=[-0.0, 1]),
)
# Label the x axis with domain names instead of epoch numbers: one tick every
# 10 epochs (assumes 10 epochs per domain and six domains - adjust otherwise).
fig.update_xaxes(
    tickvals=[0, 10, 20, 30, 40, 50],
    # tickvals = np.multiply(np.array([0, 10, 20, 30, 40, 50]), 3).tolist(),
    ticktext=list(df['domain'].unique())
)

fig.show()

In [None]:
#plt
# Train/validation loss curves of every loaded run in one matplotlib figure.
# The loop variables avoid the names `model` and `name`, so the network object
# other cells keep in `model` is not overwritten by a history dict.
plt.figure(figsize=(20,7))
for run_name, run_history in models.items():
    plt.plot(run_history['train_epoch_loss'], label=f'{run_name} Train Loss')
    plt.plot(run_history['val_epoch_loss'], label=f'{run_name} Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
# plt.ylim(-0.01, 1)
plt.legend()
plt.show()

In [None]:
for i,m in models.items():
    print(i)

In [None]:
# Smallest and largest cross-domain validation value over all loaded runs;
# used as the common y-axis range of the cross-domain plot.
# Every snapshot is included (not just the first six), so the range always
# covers what is plotted and shorter histories do not raise IndexError.
final = [
    value
    for run_history in models.values()
    for snapshot in run_history['cross_domain_val']
    for value in snapshot.values()
]
max_loss = max(final)
min_loss = min(final)

In [None]:
# plt
# Per-domain 'cross_domain_val' curves for the single run stored under models['history'].
# NOTE: requires a 'history' key in `models`; this cell also rebinds the
# module-level names `history` and `domains`.
history = models['history']['cross_domain_val']

# Extract domain names
domains = list(history[0].keys())

# Collect, per domain, its value from every snapshot in order
domain_scores = {domain: [] for domain in domains}
for snapshot in history:
    for domain in domains:
        domain_scores[domain].append(snapshot[domain])

# Plot
plt.figure(figsize=(12, 6))
for domain, scores in domain_scores.items():
    # x positions are the domain names, truncated to the number of snapshots
    plt.plot(domains[:len(scores)], scores, label=domain, marker='o')

plt.xlabel("After training on domain X")
plt.ylabel("Loss")
# NOTE(review): the title says "Accuracy" while the y-axis label says "Loss" - confirm which is intended
plt.title("Domain-wise Accuracy Over Time")
# plt.ylim(min(-0.1, min_loss), max_loss)
# plt.ylim(-0.05, 0.1)
plt.legend()
plt.show()

In [None]:
import plotly.graph_objects as go

# Cross-domain validation curves: for every model, one trace per evaluated
# domain, with a dropdown to switch between models.
# Local names avoid `domains` and `history`, so this cell no longer overwrites
# the module-level variables of the same name.

# Prepare data structure: model name -> {domain name -> values across snapshots}
domain_data = {}
for model_name, model_history in models.items():
    snapshots = model_history['cross_domain_val']
    domain_data[model_name] = {
        domain_name: [snapshot[domain_name] for snapshot in snapshots]
        for domain_name in snapshots[0]
    }

# Create figure
fig = go.Figure()

# One fixed colour per domain name, shared by all models. Colours are assigned
# in order of first appearance and cycle if there are more domains than colours.
palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', "#05ffee"]
domain_colors = {}
for per_domain in domain_data.values():
    for domain_name in per_domain:
        if domain_name not in domain_colors:
            domain_colors[domain_name] = palette[len(domain_colors) % len(palette)]

# Add traces for each model and domain
for model_idx, (model_name, per_domain) in enumerate(domain_data.items()):
    for domain_name, scores in per_domain.items():
        fig.add_trace(go.Scatter(
            x=list(range(len(scores))),
            y=scores,
            mode='lines+markers',
            name=domain_name,
            line=dict(color=domain_colors[domain_name], width=2),
            marker=dict(size=8, symbol=model_idx+1),  # Unique symbol per model
            legendgroup=model_name,
            legendgrouptitle_text=model_name,
            visible=True if model_idx == 0 else 'legendonly'  # Show first model by default
        ))

# Create model selection buttons. The visibility mask has one entry per trace,
# in the order the traces were added, and uses each model's own domain count so
# it stays aligned even if models were evaluated on different domains.
buttons = [
    dict(
        label=model_name,
        method='update',
        args=[
            {'visible': [
                other_name == model_name
                for other_name, per_domain in domain_data.items()
                for _ in per_domain
            ]},
            {'title': f'Domain Losses: {model_name}'}
        ]
    ) for model_name in domain_data
]

# Layout configuration
fig.update_layout(
    title='Validation Loss of Domain X After Training on Domain Y',
    xaxis_title='After Training on Domain Y',
    yaxis_title='MSE Loss',
    legend=dict(
        title='Domain X',
        groupclick="toggleitem",  # Allows group toggling while preserving individual control
        itemsizing='constant'
    ),
    updatemenus=[{
        'type': 'dropdown',
        'direction': 'down',
        'showactive': True,
        'buttons': buttons,
        'x': 1,
        'xanchor': 'left',
        'y': 1.1,
        'yanchor': 'top'
    }],
    template='plotly_white',
    # width=1200,
    # height=700
    # min_loss / max_loss come from the cell that scans 'cross_domain_val'
    yaxis=dict(range=[min(-0.01, min_loss), max_loss]),
    # yaxis=dict(range=[-0.0, 0.8]),

)
fig.update_xaxes(
    tickvals=[0, 1, 2, 3, 4, 5],
    ticktext=list(df['domain'].unique())
)

fig.show()


In [None]:
#plt similarity
# Per-epoch 'similarity' metric of the run stored under models['deep_norm'].
h = models['deep_norm']

sns.set_theme(style="whitegrid")
plt.figure(figsize=(20, 5))
similarity_per_epoch = []
for epoch_metrics in h['train_epoch_metrics']:
    similarity_per_epoch.append(epoch_metrics['similarity'])
plt.plot(similarity_per_epoch)
plt.xticks(np.arange(0, 60, step=1))
plt.title('cosine similarity of two branches')
plt.show()

In [None]:
#plt
# Per-epoch 'inv_acc' and 'spec_acc' of the run stored under models['history'].
h = models['history']

plt.figure(figsize=(20, 5))
for metric in ('inv_acc', 'spec_acc'):
    curve = [entry[metric] for entry in h['train_epoch_metrics']]
    plt.plot(curve, label=metric)
plt.title('Branch accuracies and their similarity')
plt.xlabel('Epoch')
plt.ylabel('Metric Value')
plt.xticks(np.arange(0, 60, step=1))
plt.legend()
plt.show()


In [None]:
import plotly.graph_objects as go

# Branch accuracies ('inv_acc', 'spec_acc') of every loaded model in one figure.
fig = go.Figure()

for model_name, history in models.items():
    for metric in ['inv_acc', 'spec_acc']:
        # Some runs do not record these metrics; skip those. Only the lookup is
        # guarded, and only for missing/unsubscriptable entries, so plotting
        # errors and KeyboardInterrupt are no longer swallowed.
        try:
            y_values = [m[metric] for m in history['train_epoch_metrics']]
        except (KeyError, TypeError):
            continue
        fig.add_trace(go.Scatter(
            x=list(range(len(y_values))),
            y=y_values,
            mode='lines+markers',
            name=f'{model_name} - {metric}'
        ))

fig.update_layout(
    title='Branch Accuracies',
    xaxis_title='Epoch',
    yaxis_title='Accuracy',
    # xaxis=dict(tickmode='linear', tick0=0, dtick=1),
    template='plotly_white',
    # width=1000,
    # height=400
    yaxis=dict(range=[-0.0, 1]),
)
# One x tick every 10 epochs, labelled with the domain names from df
fig.update_xaxes(
    tickvals=[0, 10, 20, 30, 40, 50],
    ticktext=list(df['domain'].unique())
)

fig.show()


In [None]:
#plt
# Per-epoch 'inv_domain', 'spec_domain' and 'task_loss' of the run stored under models['history'].
h = models['history']

plt.figure(figsize=(20, 5))
for metric in ('inv_domain', 'spec_domain', 'task_loss'):
    curve = [entry[metric] for entry in h['train_epoch_metrics']]
    plt.plot(curve, label=metric)
plt.xlabel('Epoch')
plt.ylabel('Metric Value')
plt.ylim(-0.01, 1)
plt.legend()
plt.show()

In [None]:
import plotly.graph_objects as go

# 'inv_domain', 'spec_domain' and 'task_loss' curves of every loaded model in one figure.
fig = go.Figure()

for model_name, history in models.items():
    for metric in ['inv_domain', 'spec_domain', 'task_loss']:
        # Some runs do not record these metrics; skip those. Only the lookup is
        # guarded, and only for missing/unsubscriptable entries, so plotting
        # errors and KeyboardInterrupt are no longer swallowed.
        try:
            y_values = [m[metric] for m in history['train_epoch_metrics']]
        except (KeyError, TypeError):
            continue
        fig.add_trace(go.Scatter(
            x=list(range(len(y_values))),
            y=y_values,
            mode='lines+markers',
            name=f'{model_name} - {metric}'
        ))

fig.update_layout(
    title='CE loss - inv_domain, spec_domain, and MSE task_loss',
    xaxis_title='Epoch',
    yaxis_title='Loss',
    template='plotly_white',
    # width=1000,
    # height=400
    yaxis=dict(range=[-0.01, 0.12]),
)
# One x tick every 10 epochs, labelled with the domain names from df
fig.update_xaxes(
    tickvals=[0, 10, 20, 30, 40, 50],
    ticktext=list(df['domain'].unique())
)

fig.show()


In [None]:
# Share of replayed vs. current-domain samples per epoch for the run stored
# under models['history'].
h = models['history']

plt.figure(figsize=(20,5))
for metric in ['replay_count', 'current_count']:
    # NOTE(review): 32 is presumably the training batch size used to turn counts
    # into percentages - confirm it matches the run being plotted.
    plt.plot([m[metric]/32*100 for m in h['train_epoch_metrics']], label=metric)
plt.title('Type of samples in batch')
plt.xlabel('Epoch')
plt.ylabel('% of batch')
# plt.axhline(32, color='r')
# plt.xticks(np.arange(0, 60, step=1))
plt.legend()
plt.show()

In [None]:
# Gradient norm per module over time for the history currently bound to `h`.
plt.figure(figsize=(20, 5))
grad_history = h['grad_norms']
# for module in ['invariant', 'specific_residual', 'domain_classifier']:
# Module names are taken from the first recorded entry.
for module in grad_history[0]:
    series = [step_norms[f'{module}'] for step_norms in grad_history]
    plt.plot(series, label=module)
plt.title('Gradient Norms by Module')
# plt.ylim(-0.001, 1)
plt.legend()

In [None]:
import plotly.graph_objects as go

# Gradient norms per module for every loaded model, with a dropdown that
# restricts the view to one model's traces.
fig = go.Figure()

# Collect all module names across all models
all_modules = set()
for model_name, history in models.items():
    if 'grad_norms' in history and len(history['grad_norms']) > 0:
        for grad_dict in history['grad_norms']:
            all_modules.update(grad_dict.keys())
# Sorted so every model adds its traces in the same module order
all_modules = sorted(all_modules)
# all_modules = ['invariant', 'specific', 'head']


# Add one trace per (model, module). All traces start visible because the
# `visible=` argument below is commented out.
trace_visibility = []
for model_idx, (model_name, history) in enumerate(models.items()):
    # Models without recorded gradient norms contribute no traces and no button
    if 'grad_norms' not in history or len(history['grad_norms']) == 0:
        continue
    # Fill in modules a model never logged with 0 so every series has a value
    # at every step. This writes into the loaded history dicts in place.
    for grad_dict in history['grad_norms']:
        for key in all_modules:
            grad_dict.setdefault(key, 0)
    for module in all_modules:
        y_values = [m[module] for m in history['grad_norms']]
        fig.add_trace(go.Scatter(
            x=list(range(len(y_values))),
            y=y_values,
            mode='lines+markers',
            name=f'{model_name} - {module}',
            # visible=(model_idx == 0)
        ))
    # Remember how many consecutive traces belong to this model
    trace_visibility.append((model_name, len(all_modules)))

# Create dropdown buttons for each model
buttons = []
# Index of the first trace of the model currently being processed
start_idx = 0
for model_name, n_traces in trace_visibility:
    visible = [False] * len(fig.data)
    for i in range(start_idx, start_idx + n_traces):
        visible[i] = True
    buttons.append(dict(
        label=model_name,
        method='update',
        args=[{'visible': visible}, {'title': f'Gradient Norms by Module - {model_name}'}]
    ))
    start_idx += n_traces

fig.update_layout(
    # NOTE(review): the initial title names the first key of `models`, even if
    # that model was skipped above for having no gradient norms.
    title=f'Gradient Norms by Module - {list(models.keys())[0]}',
    xaxis_title='Epoch',
    yaxis_title='Gradient Norm',
    template='plotly_white',
    # width=1000,
    # height=400,
    updatemenus=[{
        'buttons': buttons,
        'direction': 'down',
        'showactive': True,
        'x': 1,
        'xanchor': 'left',
        'y': 1.2,
        'yanchor': 'top'
    }],
    yaxis=dict(range=[-0.0, 0.2]),
)
# One x tick every 10 steps, labelled with the domain names from df
fig.update_xaxes(
    tickvals=[0, 10, 20, 30, 40, 50],
    ticktext=list(df['domain'].unique())
)

fig.show()


### tsne projection

In [None]:
import glob
import re
import torch
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from ipywidgets import interact

# 1. Find every checkpoint file written by this run
checkpoint_files = glob.glob("../checkpoints/dualbranchmodel_20250609_013632_*.pt")

# 2. Keep the files whose name carries a "_step<N>.pt" suffix, paired with N,
#    ordered by (step, filename)
pattern = re.compile(r"_step(\d+)\.pt")
step_matches = ((pattern.search(path), path) for path in checkpoint_files)
files_with_steps = sorted(
    (int(found.group(1)), path) for found, path in step_matches if found
)

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# 3. Precompute the 2-D t-SNE projections for every checkpoint, in step order.
#    `idx` is the checkpoint's position in the step-sorted list.
tsne_projections = []
progress = tqdm(files_with_steps, desc="Processing checkpoints")
for idx, (step, ckpt_file) in enumerate(progress):
    # Only the stored t-SNE inputs are needed, so load onto the CPU
    data = torch.load(ckpt_file, map_location='cpu')['tsne']

    # The same TSNE configuration is fitted separately to each feature set
    tsne = TSNE(n_components=2, random_state=42)
    projection = {'timeline_idx': idx}
    projection['inv_2d'] = tsne.fit_transform(np.array(data['inv_feats']))
    projection['spec_2d'] = tsne.fit_transform(np.array(data['spec_feats']))
    projection['domains'] = np.array(data['domain_labels'])
    projection['filename'] = ckpt_file
    tsne_projections.append(projection)

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Interactive browser for the precomputed t-SNE projections: Previous/Next
# buttons step through `tsne_projections` and redraw the two scatter plots.

# State variable for current index
current_idx = 0

# Output widget for the plot
out = widgets.Output()

# Buttons
button_prev = widgets.Button(description="Previous")
button_next = widgets.Button(description="Next")

# Precompute limits: one shared x/y range over all checkpoints and both
# feature sets, so the axes do not jump between frames
all_x = np.concatenate([d['inv_2d'][:,0] for d in tsne_projections] + [d['spec_2d'][:,0] for d in tsne_projections])
all_y = np.concatenate([d['inv_2d'][:,1] for d in tsne_projections] + [d['spec_2d'][:,1] for d in tsne_projections])
x_min, x_max = all_x.min(), all_x.max()
y_min, y_max = all_y.min(), all_y.max()

def plot_epoch(timeline_idx):
    """Draw the invariant and specific t-SNE scatter plots for one checkpoint.

    ``timeline_idx`` indexes ``tsne_projections``. Points are coloured by
    domain, using the module-level ``domains`` sequence for the
    name -> integer mapping and for the colour-bar labels.
    """
    data = tsne_projections[timeline_idx]
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), constrained_layout=True)
    # Map each stored domain label to its position in `domains`
    domain_to_int = {name: i for i, name in enumerate(domains)}
    domain_ints = np.array([domain_to_int[name] for name in data['domains']])
    scatter1 = ax1.scatter(data['inv_2d'][:,0], data['inv_2d'][:,1], 
                          c=domain_ints, cmap='tab10', alpha=0.7, vmin=0, vmax=len(domains)-1)
    ax1.set_title(f"Invariant Features - Timeline {timeline_idx}")
    ax1.set_xlim(x_min, x_max)
    ax1.set_ylim(y_min, y_max)
    scatter2 = ax2.scatter(data['spec_2d'][:,0], data['spec_2d'][:,1],
                          c=domain_ints, cmap='tab10', alpha=0.7, vmin=0, vmax=len(domains)-1)
    ax2.set_title(f"Specific Features - Timeline {timeline_idx}")
    ax2.set_xlim(x_min, x_max)
    ax2.set_ylim(y_min, y_max)
    # One shared colour bar with a discrete band per domain
    cbar = fig.colorbar(scatter1, ax=[ax1, ax2], label='Domain', 
                        ticks=np.arange(len(domains)), boundaries=np.arange(len(domains)+1)-0.5)
    cbar.set_ticks(np.arange(len(domains)))
    cbar.set_ticklabels(domains)
    plt.show()

def on_prev_clicked(b):
    """Step back one checkpoint (no-op at the first one) and redraw."""
    global current_idx
    if current_idx > 0:
        current_idx -= 1
        with out:
            clear_output(wait=True)
            plot_epoch(current_idx)

def on_next_clicked(b):
    """Step forward one checkpoint (no-op at the last one) and redraw."""
    global current_idx
    if current_idx < len(tsne_projections) - 1:
        current_idx += 1
        with out:
            clear_output(wait=True)
            plot_epoch(current_idx)

button_prev.on_click(on_prev_clicked)
button_next.on_click(on_next_clicked)

# Display everything
display(widgets.HBox([button_prev, button_next]))
display(out)

# Initial plot
with out:
    plot_epoch(current_idx)

### single batch overfit

In [None]:
# After model output
outputs = model(inputs)  # Should be in [1,5]
print(f"Output range: {outputs.min().item()}–{outputs.max().item()}")


In [None]:
# Single-batch overfit check for the baseline: train repeatedly on one fixed
# batch and watch the loss.
test_model = LGRBaseline().to(device)
test_optimizer = optim.Adam(test_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(0)
criterion = nn.MSELoss()

# Take one training batch from the first domain and move it to the device
first_domain = domains[0]
train_loader = domain_dataloaders[first_domain]['train']
single_batch = next(iter(train_loader))
inputs, labels, _ = single_batch
inputs, labels = (
    tensor.to(device, dtype=torch.float32) for tensor in (inputs, labels)
)

# %%
num_test_epochs = 400
for epoch in range(num_test_epochs):
    test_optimizer.zero_grad()

    outputs = test_model(inputs)
    loss = criterion(outputs['output'], labels)

    loss.backward()
    test_optimizer.step()

    # Log the first iteration and then every tenth one
    should_log = epoch == 0 or (epoch + 1) % 10 == 0
    if should_log:
        print(f"Overfit Epoch {epoch+1}/{num_test_epochs} | Loss: {loss.item():.4f}")

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Prepare a model and one input batch for TensorBoard graph export.
# `device` is expected to be defined by an earlier cell; to redefine it here use
# one of:
# device = torch.device('cpu')
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = df['domain'].unique()

writer = SummaryWriter("visualisation/")
model = DualBranchNet().to(device)
# Take one training batch from the first domain and move it to the device
first_domain = domains[0]
train_loader = domain_dataloaders[first_domain]['train']
single_batch = next(iter(train_loader))
inputs, labels, _ = single_batch
inputs = inputs.to(device, dtype=torch.float32)
labels = labels.to(device, dtype=torch.float32)
# The graph export itself is currently disabled, so nothing is written
# writer.add_graph(model, inputs)

writer.close()

In [None]:
input_names = ["image"]
output_names = ["appropriateness scores"]

torch.onnx.export(model, inputs, "model.onnx", input_names=input_names, output_names=output_names)

In [None]:
from torchviz import make_dot

y = model(inputs)
make_dot(y.mean(), params=dict(model.named_parameters()))

In [None]:
import torch
from torch.utils.data import DataLoader

# 1. Get a single batch from any domain's train loader
domain = domains[0]
single_batch = next(iter(domain_dataloaders[domain]['train']))


buffer = NaiveRehearsalBuffer(0)

# 4. Overfit loop for both models
def overfit_model(
    model, optimizer, batch_fn, batch_kwargs, device, num_epochs=100, exp_name="overfit",
    batch=None
):
    """Repeatedly optimise *model* on one fixed batch and return the loss curve.

    Args:
        model: Object with a ``train()`` method; passed through to ``batch_fn``.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        batch_fn: Callable ``batch_fn(model, batch, device, **batch_kwargs)``
            returning ``(loss, metrics)``; ``loss`` must support
            ``backward()`` and ``item()``.
        batch_kwargs: Extra keyword arguments forwarded to ``batch_fn``.
        device: Passed through to ``batch_fn``.
        num_epochs: Number of optimisation steps to run on the batch.
        exp_name: Unused; kept so existing callers keep working.
        batch: The batch to overfit. When ``None`` (the default) the
            module-level ``single_batch`` is used, as before.

    Returns:
        List with one ``loss.item()`` value per step.
    """
    if batch is None:
        # Backward-compatible default: the batch prepared at module level
        batch = single_batch
    model.train()
    losses = []
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        loss, _ = batch_fn(model, batch, device, **batch_kwargs)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        print(f"Epoch {epoch+1}, Loss: {loss.item():.6f}")
    return losses



In [None]:
# 5. Baseline model overfit: fresh LGRBaseline trained on the single batch with MSE only
baseline_model = LGRBaseline().to(device)
optimizer = torch.optim.Adam(baseline_model.parameters(), lr=1e-3)
baseline_losses = overfit_model(
    baseline_model, optimizer, baseline_batch, {'mse_criterion': torch.nn.MSELoss()}, device
)

# 6. DualBranch model overfit: same procedure with the dual-branch batch function
dual_model = DualBranchNet(num_domains=len(domains)).to(device)
# `optimizer` is rebound here; the baseline optimizer is no longer needed
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
dualbranch_kwargs = {
    'mse_criterion': torch.nn.MSELoss(),
    'ce_criterion': torch.nn.CrossEntropyLoss(),
    # Mean squared cosine similarity between the two inputs
    'cos_criterion': lambda a, b: (torch.nn.CosineSimilarity()(a, b) ** 2).mean(),
    # NOTE(review): `domain_to_idx` must already exist at module level - confirm
    'domain_to_idx': domain_to_idx,
    'current_domain': domain
}
dualbranch_losses = overfit_model(
    dual_model, optimizer, dualbranch_batch, dualbranch_kwargs, device
)



In [None]:
# 7. Plot the loss curves (optional)
import matplotlib.pyplot as plt
plt.plot(baseline_losses, label='Baseline')
plt.plot(dualbranch_losses, label='DualBranch')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.legend()
plt.title('Overfitting to a Single Batch')
plt.show()
