## Data preparation

In [None]:
import os
import pandas as pd
from torchvision import transforms
from pathlib import Path
import numpy as np
import random
import pickle
from tqdm.notebook import tqdm, trange
import torch.optim as optim
import numpy as np
import random
import torch
import torch.nn as nn
import torchvision.models as models
import datetime


# Seed every random number generator used in this notebook so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# DATASET_DIR = (Path("..") / ".." / "datasets").resolve()
# DATASETS = ["OFFICE-MANNERSDB", "MANNERSDBPlus"]
# LABEL_COLS = [
#     "Vaccum Cleaning", "Mopping the Floor", "Carry Warm Food",
#     "Carry Cold Food", "Carry Drinks", "Carry Small Objects",
#     "Carry Large Objects", "Cleaning", "Starting a conversation"
# ]

# import sys
# sys.path.append('..')
# from data_processing.data_processing import create_crossvalidation_loaders

# `sys` must be imported before it is used: the only other `import sys` in this
# cell is commented out, which made the line below raise NameError.
import sys

# NOTE(review): the imports below are written as `models.<module>`, which is
# resolved relative to the directory that CONTAINS `models/`. Appending
# '../models' itself only helps if that parent is already importable - confirm.
sys.path.append('../models')

from models.buffers import NaiveRehearsalBuffer
from models.utils import get_dataloader

In [None]:
# Load the pickled dataset table used throughout the notebook.
# NOTE: unpickling executes arbitrary code embedded in the file, so only load
# pickles that come from a trusted source.
df = pd.read_pickle("../data/pepper_data.pkl")

## CL Models

In [None]:
# Ideas for potential ways to combine networks
# residual = SpecificHead(base + invariant_feats)
# specific_feats = invariant_feats + residual

# gate = GateNet(base + invariant_feats)
# residual = SpecificHead(base)
# specific_feats = invariant_feats + gate * residual

# residual = Attention(invariant_feats, base)
# specific_feats = invariant_feats + residual

# gamma, beta = SpecificHead(base)
# specific_feats = gamma * invariant_feats + beta

# invariant_feats = InvariantHead(base)
# specific_feats = SpecificHead(base)
# cos_sim(invariant_feats, specific_feats) ~ 0


### dualbranch model

In [None]:
def intermediate_layer_size(input, output, n_layers):
    """Pick ``n_layers`` power-of-two widths lying between ``output`` and ``input``.

    The candidate exponents run from ``(output + 1).bit_length()`` up to
    ``(input - 1).bit_length()``; the chosen exponents are spread evenly over
    that span (rounded to the nearest step) and returned in ascending order.

    Args:
        input: Width of the layer feeding the stack (an int).
        output: Width of the final layer (an int).
        n_layers: Number of intermediate widths to produce.

    Returns:
        A list of ``n_layers`` powers of two, smallest first, or ``None`` when
        the exponent span is smaller than ``n_layers``.
    """
    lowest_exp = (output + 1).bit_length()
    exp_span = (input - 1).bit_length() - lowest_exp
    if exp_span < n_layers:
        return None

    segments = n_layers + 1
    # Adding half the divisor before the floor division rounds to nearest.
    rounding = segments // 2
    return [
        1 << (lowest_exp + (step * exp_span + rounding) // segments)
        for step in range(1, n_layers + 1)
    ]


# Example: three hidden widths between the 1280+64 fused features and 9 outputs.
intermediate_layer_size(1280+64, 9, 3)

In [None]:
class DualBranchModel(nn.Module):
    """Two-branch regressor fusing a "social" stream with an "environment" stream.

    The social branch is an ImageNet-pretrained MobileNetV2 feature extractor
    followed by global average pooling (1280 features). The environment branch
    is either a second MobileNetV2 (``'lightweight'``, 1280 features) or a
    learned embedding looked up by room index (``'label'``, 64 features).
    Both feature vectors are concatenated and passed to a regression head.

    Two small auxiliary classifiers predict the domain from each branch's
    *detached* features, so training them never updates the branches.

    Args:
        num_outputs: Number of regression outputs produced by the head.
        dropout_rate: Dropout probability used in the ``'deep'`` head.
        architecture: Dict with keys ``'env'`` (``'lightweight'`` or ``'label'``)
            and ``'head'`` (``'deep'`` or ``'shallow'``). Defaults to
            ``{'env': 'lightweight', 'head': 'deep'}``.
        num_domains: Number of classes predicted by the auxiliary domain
            classifiers. Defaults to ``num_outputs``, which is what the
            classifiers were previously hard-wired to.

    Raises:
        ValueError: If ``architecture`` names an unsupported branch or head, or
            if no intermediate width fits between a branch's feature size and
            the number of domain classes.
    """

    _ENV_CHOICES = ('lightweight', 'label')
    _HEAD_CHOICES = ('deep', 'shallow')

    def __init__(self, num_outputs=9, dropout_rate=0.3, architecture=None, num_domains=None):
        super().__init__()
        # A dict literal as the default would be shared by every instance; build
        # a fresh one per call and copy whatever the caller handed in.
        if architecture is None:
            architecture = {'env': 'lightweight', 'head': 'deep'}
        self.setup = dict(architecture)

        env_kind = self.setup['env']
        head_kind = self.setup['head']
        # Fail up front instead of hitting an undefined variable or a missing
        # `self.head` later on.
        if env_kind not in self._ENV_CHOICES:
            raise ValueError(f"architecture['env'] must be one of {self._ENV_CHOICES}, got {env_kind!r}")
        if head_kind not in self._HEAD_CHOICES:
            raise ValueError(f"architecture['head'] must be one of {self._HEAD_CHOICES}, got {head_kind!r}")
        if num_domains is None:
            num_domains = num_outputs

        self.social_branch = nn.Sequential(
            models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1).features,
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten()
        )
        soc_feature_dim = 1280

        if env_kind == 'lightweight':
            self.env_branch = nn.Sequential(
                models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1).features,
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten()
            )
            env_feature_dim = 1280
        else:  # 'label'
            # Embedding table size caps the number of distinct room indices.
            max_rooms = 30
            self.env_branch = nn.Embedding(max_rooms, 64)
            env_feature_dim = 64

        self.fusion_dim = soc_feature_dim + env_feature_dim

        self.social_classifier = self._build_domain_classifier(soc_feature_dim, num_domains)
        self.env_classifier = self._build_domain_classifier(env_feature_dim, num_domains)

        if head_kind == 'deep':
            self.head = nn.Sequential(
                nn.Linear(self.fusion_dim, 512),
                nn.BatchNorm1d(512),
                nn.ReLU(),
                nn.Dropout(dropout_rate),

                nn.Linear(512, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(),
                nn.Dropout(dropout_rate),

                nn.Linear(256, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Dropout(dropout_rate),

                nn.Linear(128, num_outputs)
            )
        else:  # 'shallow'
            self.head = nn.Sequential(
                nn.Linear(self.fusion_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, num_outputs)
            )

    @staticmethod
    def _build_domain_classifier(feature_dim, num_classes):
        """Return a one-hidden-layer classifier sized by ``intermediate_layer_size``."""
        widths = intermediate_layer_size(feature_dim, num_classes, 1)
        if not widths:
            # intermediate_layer_size returns None when no width fits.
            raise ValueError(
                f"No intermediate layer width fits between {feature_dim} features "
                f"and {num_classes} classes"
            )
        return nn.Sequential(
            nn.Linear(feature_dim, widths[0]),
            nn.ReLU(),
            nn.Linear(widths[0], num_classes)
        )

    def forward(self, social_img, env_img):
        """Run both branches and return predictions plus intermediate features.

        Args:
            social_img: Input for the social branch.
            env_img: Input for the environment branch (for the ``'label'``
                variant this is fed to ``nn.Embedding``, so it must hold
                integer indices).

        Returns:
            Dict with ``'output'`` (head scores), ``'invariant_domain'`` and
            ``'specific_domain'`` (domain logits from the social / environment
            classifier), and ``'invariant_feats'`` / ``'specific_feats'`` (the
            social / environment branch features).
        """
        social_features = self.social_branch(social_img)
        env_features = self.env_branch(env_img)

        fused_features = torch.cat([social_features, env_features], dim=1)
        scores = self.head(fused_features)

        # Detach so the classifier losses cannot back-propagate into the branches.
        social_class = self.social_classifier(social_features.detach())
        env_class = self.env_classifier(env_features.detach())

        return {
            'output': scores,
            'invariant_domain': social_class,
            'specific_domain': env_class,
            'invariant_feats': social_features,
            'specific_feats': env_features
        }

### K Fold Cross Validation

In [None]:
def _fold_history_list(fold_history):
    """Return the per-fold history dicts of a mapping (any keys) or a sequence."""
    if isinstance(fold_history, dict):
        return list(fold_history.values())
    return list(fold_history)


def combine_fold_histories(fold_history):
    """Average several training histories element-wise across folds.

    Args:
        fold_history: The per-fold histories, given either as a mapping
            ``fold id -> history`` (the keys may be anything, they do not have
            to start at 0) or as a sequence of histories. Each history maps a
            metric name to a per-epoch list holding either numbers or dicts of
            numbers. All folds must have the same number of epochs.

    Returns:
        A dict with the metric names of the first fold. Numeric metrics become
        a list of per-epoch means; dict metrics become a list of per-epoch dicts
        of means. An empty ``fold_history`` yields ``{}``.
    """
    folds = _fold_history_list(fold_history)
    if not folds:
        return {}

    combined_history = {}
    for key, first_values in folds[0].items():
        # Peek at the first epoch to decide how this metric is stored.
        first_element = first_values[0] if len(first_values) else None

        if isinstance(first_element, dict):
            # Lists of dictionaries (e.g. train_epoch_metrics, grad_norms)
            combined_history[key] = average_list_of_dicts(fold_history, key)
        else:
            # Plain numeric lists (e.g. train_epoch_loss, val_epoch_loss)
            stacked_metrics = np.stack([fold[key] for fold in folds])
            combined_history[key] = np.mean(stacked_metrics, axis=0).tolist()

    return combined_history


def average_list_of_dicts(fold_history, metric_key):
    """Average a per-epoch list of dicts across folds.

    Args:
        fold_history: Mapping or sequence of per-fold histories (see
            ``combine_fold_histories``).
        metric_key: Name of the metric whose value is a list of dicts.

    Returns:
        A list with one dict per epoch. The sub-metric names are taken from the
        first fold's first epoch; every other epoch dict must contain them. The
        values are plain Python floats.
    """
    folds = _fold_history_list(fold_history)
    dict_keys = list(folds[0][metric_key][0].keys())

    # One [epochs, sub_metrics] array per fold.
    fold_arrays = [
        np.array([[epoch_dict[k] for k in dict_keys] for epoch_dict in fold[metric_key]])
        for fold in folds
    ]

    # [folds, epochs, sub_metrics] -> [epochs, sub_metrics]
    averaged = np.mean(np.stack(fold_arrays), axis=0)

    # .tolist() yields Python floats, matching the numeric-list branch of
    # combine_fold_histories.
    return [dict(zip(dict_keys, epoch_values.tolist())) for epoch_values in averaged]

In [None]:
from sklearn.model_selection import StratifiedKFold

# Add K-fold labels to the rows of the dataframe; these are used later to build
# the K-fold dataloaders.

df = pd.read_pickle("../data/pepper_data.pkl")
df['image_path'] = '../' + df['image_path']

# -1 marks rows that are not assigned to any validation fold.
df['fold'] = -1

# One row per image. Without de-duplication an image that occupies several rows
# could be drawn into several validation folds, and the last fold processed
# would silently overwrite the earlier assignment.
unique_images_df = (
    df[['image_path', 'domain']]
    .drop_duplicates(subset='image_path')
    .reset_index(drop=True)
)
# Exclude the held-out test subset from the folds.
# NOTE(review): `test_split_idx` is not defined anywhere in this notebook; the
# original comment says it comes from get_dataloader(). It is matched against
# `image_path` values (which now carry the '../' prefix) - confirm that it holds
# paths in the same form rather than row indices.
unique_images_df = unique_images_df[~unique_images_df['image_path'].isin(test_split_idx)]

# Stratified 5-fold split so every fold keeps the overall domain proportions.
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)

# Every row of an image receives the fold in which that image is validated.
for fold, (_, val_idx) in enumerate(skf.split(unique_images_df['image_path'], unique_images_df['domain'])):
    val_image_paths = unique_images_df.iloc[val_idx]['image_path'].tolist()
    df.loc[df['image_path'].isin(val_image_paths), 'fold'] = fold

# Verify domain distribution across folds
print("Domain distribution across folds:")
print(df.groupby(['fold', 'domain']).size().unstack(fill_value=0))

In [None]:
# fold_loaders = create_crossvalidation_loaders(df, 5, batch_sizes=(32, 64, 64), resize_img_to=(512, 288))

## Training

### setup

In [None]:
# For LGRBaseline
def baseline_batch(model, batch, device, detach_base, binary, full_replay, loss_params, **kwargs):
    """Process one batch for the plain baseline: MSE on the model's ``'output'``.

    The batch is ``(inputs, labels, domain_labels)``; the domain labels are
    ignored, as are ``detach_base``, ``binary``, ``full_replay`` and
    ``loss_params``. Gradients are accumulated by calling ``backward()``; the
    optimizer step is left to the caller.

    Returns:
        ``(loss, metrics)`` where ``metrics`` is always an empty dict.
    """
    features, targets, _domains = batch
    features = features.to(device)
    targets = targets.to(device)

    predictions = model(features)['output']
    criterion = kwargs['mse_criterion']
    loss = criterion(predictions, targets)
    loss.backward()

    return loss, {}

#For DANN
def dann_batch(model, batch, device, detach_base, binary, full_replay, loss_params={'head':0.5, 'social':0.5}, **kwargs):
    """Process one batch for the DANN setup: task MSE plus a binary domain loss.

    The batch is ``(inputs, labels, domain_labels)`` with domain labels given as
    keys of ``kwargs['domain_to_idx']``. The domain target is 1 for samples of
    ``kwargs['current_domain']`` and 0 otherwise. While training the domain
    whose index is 0 the domain loss is skipped entirely.

    Args:
        model: Callable as ``model(inputs, alpha=..., is_first_domain=...)``
            returning a dict with ``'output'`` and ``'invariant_domain'``.
        batch: ``(inputs, labels, domain_labels)``.
        device: Device the tensors are moved to.
        detach_base, binary, full_replay: Unused here.
        loss_params: Weights ``'head'`` (task loss) and ``'social'`` (domain
            loss). The default dict is only read, never modified.
        **kwargs: Must provide ``domain_to_idx``, ``mse_criterion``,
            ``bce_criterion``, ``alpha`` and ``current_domain``.

    Returns:
        ``(total_loss, metrics)`` with metrics ``task_loss``, ``inv_domain`` and
        ``inv_acc``. ``backward()`` has already been called on ``total_loss``.
    """
    inputs, labels, domain_labels = batch
    inputs, labels = inputs.to(device), labels.to(device)
    domain_to_idx = kwargs['domain_to_idx']
    domain_labels = torch.tensor([domain_to_idx[d] for d in domain_labels], device=device)
    mse_criterion = kwargs['mse_criterion']
    bce_criterion = kwargs['bce_criterion']
    alpha = kwargs['alpha']

    current_domain = kwargs['current_domain']
    current_binary_labels = (domain_labels == domain_to_idx[current_domain]).float()
    is_first_domain = bool(domain_to_idx[current_domain] == 0)

    outputs = model(inputs, alpha=alpha, is_first_domain=is_first_domain)

    task_loss = mse_criterion(outputs['output'], labels)

    if is_first_domain:
        inv_domain_loss = 0
        inv_acc = 0
    else:
        # reshape(-1) flattens to one logit per sample. A bare squeeze() would
        # also drop the batch axis when the batch holds a single sample, and
        # the BCE criterion then rejects the input/target shape mismatch.
        domain_logits = outputs['invariant_domain'].reshape(-1)
        inv_domain_loss = bce_criterion(domain_logits, current_binary_labels)
        # Logit > 0 is equivalent to sigmoid probability > 0.5.
        preds = (domain_logits > 0).float()
        inv_acc = (preds == current_binary_labels).float().mean().item()

    total_loss = (loss_params['head'] * task_loss +
                  loss_params['social'] * inv_domain_loss)

    total_loss.backward()

    metrics = {
        'task_loss': task_loss.item(),
        'inv_domain': 0 if is_first_domain else inv_domain_loss.item(),
        'inv_acc': inv_acc
    }
    return total_loss, metrics

def heuristic_dualbranch_batch(model, batch, device, detach_base, binary, full_replay, loss_params={}, **kwargs):
    """Process one batch for the two-image model with separately trained domain classifiers.

    The batch is ``(social_images, env_images, labels, domain_labels)``, all
    tensors. Two things happen:

    1. The regression loss (MSE on ``'output'``) is back-propagated; stepping the
       main optimizer is left to the caller.
    2. The two domain classifiers are trained right here with
       ``kwargs['class_optimizer']`` on the sum of their cross-entropy losses.

    ``detach_base``, ``binary``, ``full_replay`` and ``loss_params`` are unused.

    Returns:
        ``(loss, metrics)`` where ``loss`` is the regression loss and
        ``metrics`` holds ``inv_domain``, ``spec_domain``, ``inv_acc`` and
        ``spec_acc``.
    """
    social_imgs, env_imgs, targets, domains = batch
    social_imgs = social_imgs.to(device)
    env_imgs = env_imgs.to(device)
    targets = targets.to(device)
    domains = domains.to(device)

    # Looked up although unused, so a missing key still fails here.
    _domain_to_idx = kwargs['domain_to_idx']
    regression_loss_fn = kwargs['mse_criterion']
    classification_loss_fn = kwargs['ce_criterion']

    predictions = model(social_imgs, env_imgs)

    # Step 1: regression loss for the branches and the head.
    loss = regression_loss_fn(predictions['output'], targets)
    loss.backward()

    # Step 2: domain classifiers, updated with their own optimizer.
    aux_optimizer = kwargs['class_optimizer']
    aux_optimizer.zero_grad()
    social_logits = predictions['invariant_domain']
    inv_domain_loss = classification_loss_fn(social_logits, domains)
    env_logits = predictions['specific_domain']
    spec_domain_loss = classification_loss_fn(env_logits, domains)
    (inv_domain_loss + spec_domain_loss).backward()
    aux_optimizer.step()

    def _accuracy(logits):
        return (logits.argmax(1) == domains).float().mean().item()

    inv_acc = _accuracy(social_logits)
    spec_acc = _accuracy(env_logits)

    return loss, {
        'inv_domain': inv_domain_loss.item(),
        'spec_domain': spec_domain_loss.item(),
        'inv_acc': inv_acc,
        'spec_acc': spec_acc
    }


# For DualBranchNet
def dualbranch_batch(model, batch, device, detach_base, binary, full_replay, loss_params={'head': 1, 'social': 0.5, 'room': 0.2}, **kwargs):
    """Process one batch for the single-input dual-branch network with replay.

    The batch is ``(inputs, labels, domain_labels)`` with domain labels given as
    keys of ``kwargs['domain_to_idx']``. Samples are split into *current*
    samples (those of ``kwargs['current_domain']``, or every sample when
    ``full_replay`` is true) and *replay* samples (the rest):

    * Current samples go through the whole model and back-propagate
      ``head * task + social * invariant-domain + room * specific-domain``.
    * Replay samples compute the backbone and invariant features under
      ``torch.no_grad()`` and back-propagate task + 0.2 * specific-domain loss
      through ``model.specific``, ``model.specific_domain_classifier`` and
      ``model.head`` only.

    Args:
        model: Callable as ``model(inputs, alpha=...)`` returning a dict with
            ``output``, ``invariant_domain``, ``specific_domain``,
            ``invariant_feats`` and ``specific_feats``. The replay path
            additionally uses ``model.backbone``, ``model.pool``,
            ``model.invariant``, ``model.specific``,
            ``model.specific_domain_classifier`` and ``model.head``.
        batch: ``(inputs, labels, domain_labels)``.
        device: Device the tensors are moved to.
        detach_base: When true the replay path also concatenates the pooled
            backbone features in front of the head.
        binary: When true the invariant domain loss is a binary
            "current domain or not" BCE; otherwise a multi-class cross-entropy.
        full_replay: Treat every sample as a current sample.
        loss_params: Weights ``'head'``, ``'social'`` and ``'room'`` for the
            current-sample loss. The default dict is only read.
        **kwargs: Must provide ``domain_to_idx``, ``mse_criterion``,
            ``ce_criterion``, ``cos_criterion``, ``alpha``, ``current_domain``
            and, when ``binary`` is true, ``bce_criterion``.

    Returns:
        ``(total_loss, metrics)``. ``total_loss`` is the current-sample loss
        only (zero when the batch holds no current samples); the replay loss is
        back-propagated but not included in it. ``similarity`` in ``metrics`` is
        logged only and is not part of ``total_loss``.
    """
    inputs, labels, domain_labels = batch
    inputs, labels = inputs.to(device), labels.to(device)
    domain_to_idx = kwargs['domain_to_idx']
    domain_labels = torch.tensor([domain_to_idx[d] for d in domain_labels], device=device)
    mse_criterion = kwargs['mse_criterion']
    ce_criterion = kwargs['ce_criterion']
    cos_criterion = kwargs['cos_criterion']
    alpha = kwargs['alpha']
    if binary:
        bce_criterion = kwargs['bce_criterion']

    # Split batch
    current_domain = kwargs['current_domain']
    current_binary_labels = (domain_labels == domain_to_idx[current_domain]).float()
    current_mask = (domain_labels == domain_to_idx[current_domain])

    if full_replay:
        current_mask = torch.ones_like(current_mask, dtype=torch.bool)

    replay_mask = ~current_mask

    # 1. Current samples: update all parameters
    if current_mask.any():
        inputs_current = inputs[current_mask]
        labels_current = labels[current_mask]
        domain_labels_current = domain_labels[current_mask]

        outputs_current = model(inputs_current, alpha=alpha)
        inv_feats = outputs_current['invariant_feats']
        spec_feats = outputs_current['specific_feats']

        task_loss = mse_criterion(outputs_current['output'], labels_current)
        if binary:
            # reshape(-1) flattens to one logit per sample. A bare squeeze()
            # would also drop the batch axis when only one current sample is
            # present, and the BCE criterion then rejects the shape mismatch.
            inv_logits = outputs_current['invariant_domain'].reshape(-1)
            inv_targets = current_binary_labels[current_mask]
            inv_domain_loss = bce_criterion(inv_logits, inv_targets)
        else:
            inv_domain_loss = ce_criterion(outputs_current['invariant_domain'], domain_labels_current)
        spec_domain_loss = ce_criterion(outputs_current['specific_domain'], domain_labels_current)
        similarity_loss = cos_criterion(inv_feats, spec_feats)

        total_loss = (loss_params['head'] * task_loss +
                      loss_params['social'] * inv_domain_loss +
                      loss_params['room'] * spec_domain_loss)

        total_loss.backward(retain_graph= not full_replay)

        if binary:
            # Threshold at 0 (sigmoid(0) = 0.5)
            preds = (inv_logits > 0).float()
            inv_acc = (preds == inv_targets).float().mean().item()
        else:
            inv_acc = (outputs_current['invariant_domain'].argmax(1) == domain_labels_current).float().mean().item()
        spec_acc = (outputs_current['specific_domain'].argmax(1) == domain_labels_current).float().mean().item()
    else:
        total_loss = torch.tensor(0.0, device=device)
        inv_acc = 0.0
        spec_acc = 0.0
        task_loss = torch.tensor(0.0, device=device)
        inv_domain_loss = torch.tensor(0.0, device=device)
        spec_domain_loss = torch.tensor(0.0, device=device)
        similarity_loss = torch.tensor(0.0, device=device)

    # 2. Replay samples: update only specific branch + head
    if replay_mask.any():
        inputs_replay = inputs[replay_mask]
        labels_replay = labels[replay_mask]
        domain_labels_replay = domain_labels[replay_mask]

        # Under no_grad nothing computed here is recorded in the autograd
        # graph, so the backbone and the invariant branch receive no gradient
        # from replay samples.
        with torch.no_grad():
            base_replay = model.backbone(inputs_replay)
            base_replay = model.pool(base_replay).flatten(1)
            inv_feats_replay = model.invariant(base_replay)

        specific_feats = model.specific(base_replay)
        spec_domain_pred = model.specific_domain_classifier(specific_feats)

        if detach_base:
            combined = torch.cat([inv_feats_replay, specific_feats, base_replay], dim=1)
        else:
            combined = torch.cat([inv_feats_replay, specific_feats], dim=1)

        scores = model.head(combined)

        task_loss_replay = mse_criterion(scores, labels_replay)
        spec_domain_loss_replay = ce_criterion(spec_domain_pred, domain_labels_replay)
        # NOTE(review): the 0.2 weight is hard-coded and ignores
        # loss_params['room'] / loss_params['head'] - confirm this is intended.
        total_loss_replay = task_loss_replay + 0.2 * spec_domain_loss_replay

        total_loss_replay.backward()

    metrics = {
        'task_loss': task_loss.item(),
        'inv_domain': inv_domain_loss.item(),
        'spec_domain': spec_domain_loss.item(),
        'similarity': similarity_loss.item(),
        'inv_acc': inv_acc,
        'spec_acc': spec_acc,
        'replay_count': replay_mask.sum().item(),
        'current_count': current_mask.sum().item()
    }
    return total_loss, metrics


In [None]:
# TODO: change the dataset to return domain labels as torch index tensors instead of strings
def evaluate_model(model, dataloader, criterion, device, tsne=None):
    """Compute the sample-weighted mean loss of ``model`` over ``dataloader``.

    Batches may hold three items ``(inputs, labels, domain_labels)`` or four
    items ``(inputs1, inputs2, labels, domain_labels)``; the model is called
    with one or two inputs accordingly and must return a dict with ``'output'``.

    Args:
        model: Module to evaluate; switched to eval mode.
        dataloader: Iterable of batches.
        criterion: Loss applied to ``(outputs['output'], labels)``.
        device: Device the tensors are moved to.
        tsne: Optional non-empty dict of lists. When given, the batch's
            ``'invariant_feats'`` are appended to ``tsne['social']``, its
            ``'specific_feats'`` to ``tsne['environmental']`` (or to
            ``tsne['env']`` when the dict has no ``'environmental'`` key), and
            its domain labels to ``tsne['domains']``.

    Returns:
        The mean loss, or ``(mean_loss, tsne)`` when ``tsne`` is given and
        non-empty. The mean loss is ``nan`` for an empty dataloader.

    Raises:
        ValueError: If a batch holds neither three nor four items.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for batch in dataloader:
            if len(batch) == 4:
                inputs1, inputs2, labels, domain_labels = batch
                inputs = (
                    inputs1.to(device, dtype=torch.float32),
                    inputs2.to(device, dtype=torch.float32),
                )
            elif len(batch) == 3:
                single_input, labels, domain_labels = batch
                inputs = (single_input.to(device, dtype=torch.float32),)
            else:
                raise ValueError(f"Batch contains {len(batch)} objects. Should contain 3 or 4 - image/two, labels, domain_labels")
            labels = labels.to(device, dtype=torch.float32)

            # Keep the full output dict: the t-SNE collection below needs the
            # feature entries, not just the prediction tensor.
            outputs = model(*inputs)
            loss = criterion(outputs['output'], labels)

            batch_size = inputs[0].size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            if tsne:
                # Accept both spellings of the environment key.
                env_key = 'environmental' if 'environmental' in tsne else 'env'
                tsne.setdefault('social', []).append(outputs['invariant_feats'].cpu())
                tsne.setdefault(env_key, []).append(outputs['specific_feats'].cpu())
                # Domain labels may be a tensor or a sequence of strings.
                if torch.is_tensor(domain_labels):
                    batch_domains = domain_labels.cpu()
                else:
                    batch_domains = list(domain_labels)
                tsne.setdefault('domains', []).append(batch_domains)

    # Computed once after the loop, so an empty loader cannot leave it unbound.
    val_loss = total_loss / total_samples if total_samples else float('nan')

    return (val_loss, tsne) if tsne else val_loss

def cross_domain_validation(model, domain_dataloaders, criterion, device, tsne=None):
    """Evaluate ``model`` on the validation loader of every domain.

    Args:
        model: Model passed to ``evaluate_model``.
        domain_dataloaders: Mapping ``domain -> {'val': loader, ...}``.
        criterion: Loss function passed to ``evaluate_model``.
        device: Device passed to ``evaluate_model``.
        tsne: Optional feature store, threaded through every
            ``evaluate_model`` call when truthy.

    Returns:
        ``{domain: validation loss}``, or ``(that dict, tsne)`` when ``tsne``
        is truthy.
    """
    per_domain_loss = {}
    for domain_name, split_loaders in domain_dataloaders.items():
        validation_loader = split_loaders['val']
        if not tsne:
            domain_loss = evaluate_model(model, validation_loader, criterion, device)
        else:
            domain_loss, tsne = evaluate_model(model, validation_loader, criterion, device, tsne)
        per_domain_loss[domain_name] = domain_loss

    if tsne:
        return per_domain_loss, tsne
    return per_domain_loss

def average_metrics(metrics_list):
    """Average per-batch metric dicts into a single dict of floats.

    The metric names are taken from the first dict; each one is averaged over
    the dicts that actually contain it. An empty (or falsy) input gives ``{}``.
    """
    if not metrics_list:
        return {}
    return {
        name: float(np.mean([entry[name] for entry in metrics_list if name in entry]))
        for name in metrics_list[0].keys()
    }

def collect_tsne_features(model, loader, device):
    """Collect the model's ``'invariant_feats'`` and the domain label of every sample.

    Args:
        model: Module returning a dict with ``'invariant_feats'``; switched to
            eval mode.
        loader: Either a mapping ``domain -> {'val': loader, ...}`` (every
            validation loader is visited, in mapping order) or a single
            iterable of batches. Each batch ends with ``(..., labels,
            domain_labels)``; everything before that is passed to the model as
            positional inputs.
        device: Device the inputs are moved to.

    Returns:
        ``(inv_feats, domains)``: the features as one NumPy array concatenated
        along the first axis, and a flat list of domain labels (plain Python
        values when the labels arrive as a tensor).
    """
    model.eval()
    # Use the argument itself rather than a notebook-level global of the same
    # shape; a mapping means "one validation loader per domain".
    if isinstance(loader, dict):
        val_loaders = [domain_loaders['val'] for domain_loaders in loader.values()]
    else:
        val_loaders = [loader]

    all_inv, all_domains = [], []
    with torch.no_grad():
        for val_loader in val_loaders:
            for batch in val_loader:
                # Works for (x, labels, domains) and (x1, x2, labels, domains).
                *model_inputs, _, d = batch
                out = model(*(x.to(device) for x in model_inputs))
                all_inv.append(out['invariant_feats'].cpu())
                all_domains += d.tolist() if torch.is_tensor(d) else list(d)
    inv_feats = torch.cat(all_inv, dim=0).numpy()
    return inv_feats, all_domains


def collect_gradients(model):
    """Summarise the current gradient magnitudes per top-level sub-module.

    For every parameter that has a gradient and whose name does not start with
    ``"backbone"``, the L2 norm of the gradient is computed; the norms are then
    averaged per top-level module (the part of the name before the first dot).

    Returns:
        Dict ``module name -> mean gradient L2 norm`` (Python floats).
    """
    norms_by_module = {}
    for param_name, tensor in model.named_parameters():
        if tensor.grad is None or param_name.startswith("backbone"):
            continue
        owner = param_name.split('.')[0]
        norms_by_module.setdefault(owner, []).append(tensor.grad.norm(2).item())

    # Take mean per module
    return {owner: float(np.mean(norms)) for owner, norms in norms_by_module.items()}


In [None]:
def unified_train_loop(
    model, domains, domain_dataloaders, buffer, optimizer, writer, device,
    batch_fn, batch_kwargs, loss_params, num_epochs=5, exp_name="exp", gradient_clipping=False, detach_base=False, binary=False, full_replay=False, collect_tsne_data=False, restart={}, eval_buffer=False
):
    """Train ``model`` on the domains one after another with rehearsal replay.

    For every domain, the training loader is obtained from ``buffer`` (current
    data mixed with replayed samples), the model is trained for ``num_epochs``
    epochs with ``batch_fn``, and afterwards the domain's training set is added
    to the buffer. After the last epoch of each domain the model is validated on
    every domain and a checkpoint is written; the running history is pickled
    after every epoch. Files go to ``../checkpoints/``.

    Args:
        model: Model to train.
        domains: Ordered sequence (list or array) of domain identifiers.
        domain_dataloaders: Mapping ``domain -> {'train': loader, 'val': loader}``.
        buffer: Rehearsal buffer providing ``get_loader_with_replay``,
            ``update_buffer`` and ``get_domain_distribution``.
        optimizer: Optimizer stepped after every batch.
        writer: Accepted for interface compatibility; currently unused.
        device: Device passed on to ``batch_fn`` and the evaluation helpers.
        batch_fn: Callable ``(model, batch, device, detach_base, binary,
            full_replay, loss_params, **kwargs) -> (loss, metrics)`` that also
            performs the backward pass.
        batch_kwargs: Extra keyword arguments for ``batch_fn``. Must contain
            ``'mse_criterion'`` and ``'alpha'``; a falsy ``'alpha'`` selects the
            ramp-up schedule ``2 / (1 + exp(-10 p)) - 1`` over the domain's
            training progress ``p``.
        loss_params: Loss weights forwarded to ``batch_fn``.
        num_epochs: Epochs per domain.
        exp_name: Prefix of the checkpoint and history file names.
        gradient_clipping: Clip the global gradient norm to 1.0 before stepping.
        detach_base, binary, full_replay: Forwarded to ``batch_fn``.
        collect_tsne_data: Collect features for t-SNE plots and store them in
            the checkpoint.
        restart: Optional dict with ``'global_step'``, ``'history'`` and
            ``'domain'``; training resumes at the start of that domain and the
            buffers are refilled with the domains before it. The default dict is
            only read.
        eval_buffer: Optional second rehearsal buffer used to build a
            validation loader with replay; ``False`` disables it.

    Returns:
        The history dict (per-epoch losses, metrics and gradient norms, plus
        per-domain cross-domain validation results).
    """
    start_domain_idx = 0
    global_step = 0
    history = {
        'train_epoch_loss': [],
        'val_epoch_loss': [],
        'val_buffer_epoch_loss': [],
        'train_epoch_metrics': [],
        'cross_domain_val': [],
        'grad_norms': [],
    }

    if restart:
        # Populate history
        global_step = restart['global_step']
        history = restart['history']
        # list(...).index works for plain lists as well as numpy arrays.
        start_domain_idx = list(domains).index(restart['domain'])
        # Refill the buffers with every domain that was already trained.
        for finished_domain in domains[:start_domain_idx]:
            buffer.update_buffer(finished_domain, domain_dataloaders[finished_domain]['train'].dataset)
            if eval_buffer:
                eval_buffer.update_buffer(finished_domain, domain_dataloaders[finished_domain]['val'].dataset)
        print(f"Restarting from domain {restart['domain']} index {start_domain_idx}")
        print(f"Buffer: {buffer.get_domain_distribution()}")

    # torch.save / open do not create missing directories.
    os.makedirs("../checkpoints", exist_ok=True)

    for domain_idx, current_domain in enumerate(tqdm(domains[start_domain_idx:], desc=f"Total training"), start=start_domain_idx):
        train_loader = buffer.get_loader_with_replay(current_domain, domain_dataloaders[current_domain]['train'])
        if eval_buffer:
            eval_loader = eval_buffer.get_loader_with_replay(current_domain, domain_dataloaders[current_domain]['val'])
        else:
            eval_loader = domain_dataloaders[current_domain]['val']
        len_dataloader = len(train_loader)

        for epoch in trange(num_epochs, desc=f"Current domain {current_domain}"):
            model.train()
            epoch_loss = 0.0
            samples = 0
            batch_metrics_list = []

            for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Current epoch {epoch}", leave=False)):
                if not batch_kwargs['alpha']:
                    # Fraction of this domain's training already completed.
                    p = (epoch * len_dataloader + batch_idx) / (num_epochs * len_dataloader)
                    alpha = 2. / (1. + np.exp(-10 * p)) - 1
                else:
                    alpha = batch_kwargs['alpha']

                optimizer.zero_grad()
                # batch_fn runs the forward and backward passes.
                loss, metrics = batch_fn(model, batch, device, detach_base, binary, full_replay, loss_params, **{**batch_kwargs, 'current_domain': current_domain, 'alpha':alpha})
                if gradient_clipping:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                batch_size = batch[0].size(0)
                epoch_loss += loss.item() * batch_size
                samples += batch_size
                global_step += 1
                batch_metrics_list.append(metrics)

            avg_epoch_loss = epoch_loss / samples
            history['train_epoch_loss'].append(avg_epoch_loss)
            # Average batch metrics for this epoch
            avg_metrics = average_metrics(batch_metrics_list)
            history['train_epoch_metrics'].append(avg_metrics)

            # Gradient norms left behind by the last batch of the epoch
            grad_norms = collect_gradients(model)
            history['grad_norms'].append(grad_norms)

            # Validation on current domain
            val_loss = evaluate_model(model, domain_dataloaders[current_domain]['val'], batch_kwargs['mse_criterion'], device)
            if eval_buffer:
                val_loss_buffer = evaluate_model(model, eval_loader, batch_kwargs['mse_criterion'], device)
            else:
                # Without an eval buffer, eval_loader IS the plain validation
                # loader, so a second pass would only repeat the same result.
                val_loss_buffer = val_loss
            history['val_epoch_loss'].append(val_loss)
            history['val_buffer_epoch_loss'].append(val_loss_buffer)

            # Collect data for t-SNE domain separation graphs
            if collect_tsne_data:
                inv_feats, domain_labels = collect_tsne_features(model, domain_dataloaders, device)
                tsne_data = {
                    'inv_feats': inv_feats,
                    'domain_labels': domain_labels
                }
            else:
                tsne_data = None

            # Cross-domain validation (after each domain)
            if epoch == num_epochs-1:
                if collect_tsne_data:
                    tsne = {'social': [], 'env': [], 'domains': []}
                    cross_val, tsne_data = cross_domain_validation(model, domain_dataloaders, batch_kwargs['mse_criterion'], device, tsne)
                else:
                    cross_val = cross_domain_validation(model, domain_dataloaders, batch_kwargs['mse_criterion'], device)
                history['cross_domain_val'].append(cross_val)

                # Only save last model per domain to save space
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'history': history,
                    'tsne' : tsne_data,
                }, f"../checkpoints/{exp_name}_domain{current_domain}_epoch{epoch}_step{global_step}.pt")

            # The history is rewritten after every epoch so an interrupted run
            # still leaves its metrics behind.
            with open(f"../checkpoints/{exp_name}_history.pkl", "wb") as f:
                pickle.dump(history, f)

        buffer.update_buffer(current_domain, domain_dataloaders[current_domain]['train'].dataset)
        # eval_buffer defaults to False; calling update_buffer on it
        # unconditionally raised AttributeError after the first domain.
        if eval_buffer:
            eval_buffer.update_buffer(current_domain, domain_dataloaders[current_domain]['val'].dataset)
    return history


### current

In [None]:
# TODO tests: retrain the model with
# 3 trains averaged
# one branch, no mask
# top branch + mask
# bottom branch + mask
# both branches + random masks

In [None]:

# Experiment driver: trains one DualBranchModel per scenario with
# `heuristic_dualbranch_batch`.
# Each entry maps a scenario name to
#   (model, [social image size, environment image size], use an eval buffer?).
# NOTE: the models are instantiated while this dict is being built, i.e. before
# the loop below starts.
testing_scenarios = {
    # 'heuristic_small_env': (DualBranchModel(), [(288,512), (144,256)], False),
    # 'heuristic_square_img': (DualBranchModel(), [(224,224)]*2, False),
    'heuristic_eval_buffer': (DualBranchModel(), [(288,512)]*2, True)
}

for name, (model, img_size, eval_buffer) in testing_scenarios.items():

    # Separate preprocessing per stream so the two images can be resized
    # differently; the mean/std values are the usual ImageNet statistics.
    transform_soc = transforms.Compose([
        transforms.Resize(img_size[0]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225])
    ])
    transform_env = transforms.Compose([
        transforms.Resize(img_size[1]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225])
    ])

    # Reload the table and derive, from each image's file name, the paths of
    # its masked "environment" and "social" variants.
    df = pd.read_pickle("../data/pepper_data.pkl")
    df['image_path_env'] = df['image_path'].apply(lambda p: str(Path('../data/masked/environment') / Path(p).name))
    df['image_path_social'] = df['image_path'].apply(lambda p: str(Path('../data/masked/social') / Path(p).name))
    domain_dataloaders = get_dataloader(df, batch_sizes=(16, 64, 64), return_splits=False, double_img=True, transforms=[transform_soc, transform_env], num_workers=0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Domains are trained in order of first appearance in the table.
    domains = df['domain'].unique()
    domain_to_idx = {d: i for i, d in enumerate(domains)}


    print(f"\nTesting: {name}")
    dual_model = model.to(device)
    # optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    # Two optimizers: the main one covers both branches and the regression
    # head; the domain classifiers get their own, which
    # heuristic_dualbranch_batch steps itself.
    optimizer = optim.Adam([
        {'params': dual_model.social_branch.parameters()},
        {'params': dual_model.env_branch.parameters()},
        {'params': dual_model.head.parameters()},
    ], lr=1e-3)
    classifier_optimizer = optim.Adam([ 
        {'params': dual_model.social_classifier.parameters()},
        {'params': dual_model.env_classifier.parameters()}
    ], lr=1e-3)

    buffer = NaiveRehearsalBuffer(buffer_size=120)
    # From here on `eval_buffer` is either False or a buffer object.
    if eval_buffer:
        eval_buffer = NaiveRehearsalBuffer(buffer_size=120)

    # Timestamped so repeated runs do not overwrite each other's checkpoints.
    exp_name = f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    writer = None

    # 1 - |cosine similarity|, averaged over the batch.
    def cos_criterion(a, b):
        return (1 - torch.abs(nn.CosineSimilarity()(a, b))).mean()
    
    # 'alpha': 0 is falsy, so unified_train_loop computes alpha from its
    # schedule instead of using a fixed value.
    dualbranch_kwargs = {
            'mse_criterion': nn.MSELoss(),
            'ce_criterion': nn.CrossEntropyLoss(),
            'cos_criterion': cos_criterion,
            'domain_to_idx': domain_to_idx,
            'bce_criterion': nn.BCEWithLogitsLoss(),
            'alpha': 0,
            'class_optimizer': classifier_optimizer
        }
    
    # NOTE: heuristic_dualbranch_batch does not use detach_base, binary,
    # full_replay or loss_params; they are passed through without effect.
    unified_train_loop(
        model=dual_model,
        domains=domains,
        domain_dataloaders=domain_dataloaders,
        buffer=buffer,
        optimizer=optimizer,
        writer=writer,
        device=device,
        batch_fn=heuristic_dualbranch_batch,
        batch_kwargs=dualbranch_kwargs,
        num_epochs=10,
        exp_name=exp_name,
        gradient_clipping=True,
        detach_base=False,
        binary = True,
        full_replay = True,
        collect_tsne_data=False,
        loss_params={'head': 1, 'social': 1, 'room': 0.5},
        eval_buffer=eval_buffer
    )

In [None]:
# import subprocess

# model_names = [
#     'heuristic_small_env',
#     'heuristic_square_img',
#     'heuristic_eval_buffer'
# ]

# for model_name in model_names:
#     print(f"Running {model_name}")
#     process = subprocess.Popen(
#         ["python", "train_models.py", "--model_name", model_name, "--num_workers", "0"],
#         stdout=subprocess.PIPE,
#         stderr=subprocess.STDOUT,
#         text=True
#     )

#     for line in process.stdout:
#         print(line, end='')

#     process.wait()