## Data preparation

In [None]:
import os
import pandas as pd
from torchvision import transforms
from pathlib import Path
import numpy as np
import random
import pickle
from tqdm.notebook import tqdm, trange
import torch.optim as optim
import numpy as np
import random
import torch
import torch.nn as nn
import torchvision.models as models
import datetime


# Single seed shared by every random number generator used in this notebook.
SEED = 42

# Seed Python's `random`, NumPy and PyTorch's default generator in one pass.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng  # keep the loop helper out of the notebook namespace

# Also seed every CUDA device when a GPU is present.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# DATASET_DIR = (Path("..") / ".." / "datasets").resolve()
# DATASETS = ["OFFICE-MANNERSDB", "MANNERSDBPlus"]
# LABEL_COLS = [
#     "Vaccum Cleaning", "Mopping the Floor", "Carry Warm Food",
#     "Carry Cold Food", "Carry Drinks", "Carry Small Objects",
#     "Carry Large Objects", "Cleaning", "Starting a conversation"
# ]

# import sys
# sys.path.append('..')
# from data_processing.data_processing import create_crossvalidation_loaders

import sys
sys.path.append('..')

from models.buffers import NaiveRehearsalBuffer
from models.utils import get_dataloader
from models.training_utils import heuristic_dualbranch_batch, unified_train_loop

In [None]:
# Load the pickled dataset into a DataFrame (path is relative to this notebook's directory).
# NOTE: unpickling can execute arbitrary code, so only load files from a trusted source.
df = pd.read_pickle("../data/pepper_data.pkl")

## CL Models

In [None]:
# Ideas for potential ways to combine networks
# residual = SpecificHead(base + invariant_feats)
# specific_feats = invariant_feats + residual

# gate = GateNet(base + invariant_feats)
# residual = SpecificHead(base)
# specific_feats = invariant_feats + gate * residual

# residual = Attention(invariant_feats, base)
# specific_feats = invariant_feats + residual

# gamma, beta = SpecificHead(base)
# specific_feats = gamma * invariant_feats + beta

# invariant_feats = InvariantHead(base)
# specific_feats = SpecificHead(base)
# cos_sim(invariant_feats, specific_feats) ~ 0


### dualbranch model

In [None]:
def intermediate_layer_size(input, output, n_layers):
    """Choose ``n_layers`` power-of-two widths lying between ``output`` and ``input``.

    The candidate exponents run from ``(output + 1).bit_length()`` up to
    ``(input - 1).bit_length()``. The returned widths are spread over that
    exponent range as evenly as integer rounding allows and are listed in
    non-decreasing order (closest to ``output`` first).

    Args:
        input: Size of the wider (input) side, as an int.
        output: Size of the narrower (output) side, as an int.
        n_layers: Number of intermediate widths to produce.

    Returns:
        A list of ``n_layers`` ints, each a power of two, or ``None`` when the
        exponent range is smaller than ``n_layers``.
    """
    low_exp = (output + 1).bit_length()
    exp_span = (input - 1).bit_length() - low_exp

    # Not enough distinct powers of two between the two sizes.
    if exp_span < n_layers:
        return None

    steps = n_layers + 1
    # Adding half the divisor before floor division rounds to the nearest step.
    rounding_bias = steps // 2

    return [
        1 << (low_exp + (layer * exp_span + rounding_bias) // steps)
        for layer in range(1, n_layers + 1)
    ]


# Quick check of the helper: three intermediate widths for a (1280 + 64)-wide input and 9 outputs.
intermediate_layer_size(1280+64, 9, 3)

In [None]:
class DualBranchModel(nn.Module):
    """Two-branch network that fuses "social" and "environment" features.

    The social branch is an ImageNet-pretrained MobileNetV2 feature extractor
    followed by global average pooling (1280 features). The environment branch
    is either a second MobileNetV2 extractor (``'lightweight'``) or a learned
    embedding looked up by an integer id (``'label'``). The concatenated
    features feed the main ``head``; two small auxiliary classifiers operate on
    detached copies of each branch's features.

    Args:
        num_outputs: Width of the output of the head and of both auxiliary
            classifiers.
        dropout_rate: Dropout probability used inside the ``'deep'`` head
            (unused by the ``'shallow'`` head).
        architecture: Optional dict with keys ``'env'`` (``'lightweight'`` or
            ``'label'``) and ``'head'`` (``'deep'`` or ``'shallow'``). Missing
            keys, or ``None``, fall back to ``'lightweight'`` / ``'deep'``.

    Raises:
        ValueError: If an architecture option is not recognised, or if no
            intermediate layer size exists for an auxiliary classifier.
    """

    # Fallback options for any key missing from ``architecture``.
    _DEFAULT_ARCHITECTURE = {'env': 'lightweight', 'head': 'deep'}
    _ENV_OPTIONS = ('lightweight', 'label')
    _HEAD_OPTIONS = ('deep', 'shallow')

    def __init__(self, num_outputs=9, dropout_rate=0.3, architecture=None):
        super().__init__()
        # Build a fresh dict per instance so no state is shared through a
        # mutable default argument or with the caller's dict.
        self.setup = {**self._DEFAULT_ARCHITECTURE, **(architecture or {})}

        # Fail early with a clear message instead of hitting an undefined
        # variable / missing attribute later on.
        if self.setup['env'] not in self._ENV_OPTIONS:
            raise ValueError(
                f"Unknown architecture['env']={self.setup['env']!r}; "
                f"expected one of {self._ENV_OPTIONS}"
            )
        if self.setup['head'] not in self._HEAD_OPTIONS:
            raise ValueError(
                f"Unknown architecture['head']={self.setup['head']!r}; "
                f"expected one of {self._HEAD_OPTIONS}"
            )

        # NOTE: sub-modules are created in the same order as before so that
        # seeded weight initialisation is unaffected.
        self.social_branch = self._mobilenet_features()
        soc_feature_dim = 1280

        if self.setup['env'] == 'lightweight':
            self.env_branch = self._mobilenet_features()
            env_feature_dim = 1280
        else:  # 'label'
            # The embedding table holds at most `max_rooms` distinct ids.
            max_rooms = 30
            self.env_branch = nn.Embedding(max_rooms, 64)
            env_feature_dim = 64

        self.fusion_dim = soc_feature_dim + env_feature_dim

        self.social_classifier = self._aux_classifier(soc_feature_dim, num_outputs)
        self.env_classifier = self._aux_classifier(env_feature_dim, num_outputs)

        if self.setup['head'] == 'deep':
            self.head = nn.Sequential(
                nn.Linear(self.fusion_dim, 512),
                nn.BatchNorm1d(512),
                nn.ReLU(),
                nn.Dropout(dropout_rate),

                nn.Linear(512, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(),
                nn.Dropout(dropout_rate),

                nn.Linear(256, 128),
                nn.BatchNorm1d(128),
                nn.ReLU(),
                nn.Dropout(dropout_rate),

                nn.Linear(128, num_outputs)
            )
        else:  # 'shallow'
            self.head = nn.Sequential(
                nn.Linear(self.fusion_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, num_outputs)
            )

    @staticmethod
    def _mobilenet_features():
        """Return a pretrained MobileNetV2 feature extractor with pooled, flattened output."""
        return nn.Sequential(
            models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1).features,
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten()
        )

    @staticmethod
    def _aux_classifier(feature_dim, num_outputs):
        """Return a one-hidden-layer classifier sized by ``intermediate_layer_size``."""
        hidden = intermediate_layer_size(feature_dim, num_outputs, 1)
        if hidden is None:
            raise ValueError(
                f"No intermediate layer size fits between {feature_dim} "
                f"features and {num_outputs} outputs"
            )
        return nn.Sequential(
            nn.Linear(feature_dim, hidden[0]),
            nn.ReLU(),
            nn.Linear(hidden[0], num_outputs)
        )

    def forward(self, social_img, env_img):
        """Run both branches and return predictions plus intermediate features.

        Args:
            social_img: Input to the social branch (presumably an image batch
                of shape (batch, 3, H, W) -- confirm against the data loader).
            env_img: Input to the environment branch: an image batch for the
                ``'lightweight'`` setup, or integer ids for the ``'label'``
                setup (presumably shape (batch,) -- confirm against caller).

        Returns:
            dict with keys:
                ``'output'``: head output on the concatenated features.
                ``'invariant_domain'``: auxiliary classifier output on the
                    social features.
                ``'specific_domain'``: auxiliary classifier output on the
                    environment features.
                ``'invariant_feats'``: social branch features.
                ``'specific_feats'``: environment branch features.
        """
        social_features = self.social_branch(social_img)
        env_features = self.env_branch(env_img)

        fused_features = torch.cat([social_features, env_features], dim=1)
        scores = self.head(fused_features)

        # Detach so the auxiliary classifiers' losses do not back-propagate
        # into the feature branches.
        social_class = self.social_classifier(social_features.detach())
        env_class = self.env_classifier(env_features.detach())

        return {
            'output': scores,
            'invariant_domain': social_class,
            'specific_domain': env_class,
            'invariant_feats': social_features,
            'specific_feats': env_features
        }

### K Fold Cross Validation

In [None]:
def combine_fold_histories(fold_history):
    """Average several per-fold training histories element-wise.

    Args:
        fold_history: Either a sequence of history dicts (one per fold) or a
            mapping ``fold id -> history dict``. Every history maps a metric
            name to a per-epoch list whose entries are either numbers or dicts
            of sub-metrics. All folds must share the same metric names and
            the same number of epochs.

    Returns:
        dict mapping each metric name to its fold-averaged per-epoch list
        (a list of numbers, or a list of dicts for dict-valued metrics).
    """
    # Normalise to a plain list so a list of histories and an int-keyed
    # mapping are both accepted (indexing a list by its own elements failed).
    if isinstance(fold_history, dict):
        folds = list(fold_history.values())
    else:
        folds = list(fold_history)

    combined_history = {}
    for key, reference_values in folds[0].items():
        # The first epoch entry of the first fold decides how to average.
        first_element = reference_values[0] if reference_values else None

        if isinstance(first_element, dict):
            # Lists of dictionaries (e.g. train_epoch_metrics, grad_norms).
            # Passed as an index-keyed mapping, the form the helper iterates by key.
            combined_history[key] = average_list_of_dicts(dict(enumerate(folds)), key)
        else:
            # Simple numeric lists (e.g. train_epoch_loss, val_epoch_loss).
            stacked_metrics = np.stack([fold[key] for fold in folds])
            combined_history[key] = np.mean(stacked_metrics, axis=0).tolist()

    return combined_history

def average_list_of_dicts(fold_history, metric_key):
    """Average one dict-valued metric across folds, epoch by epoch.

    Args:
        fold_history: Either a sequence of history dicts (one per fold) or a
            mapping ``fold id -> history dict``.
        metric_key: Name of the metric to average. In every fold it must hold
            a per-epoch list of dicts with the same keys and numeric values,
            and all folds must have the same number of epochs.

    Returns:
        A list with one dict per epoch, mapping each sub-metric name to its
        mean over the folds.
    """
    # Normalise to a plain list so a list of histories and an int-keyed
    # mapping are both accepted (indexing a list by its own elements failed).
    if isinstance(fold_history, dict):
        folds = list(fold_history.values())
    else:
        folds = list(fold_history)

    # Sub-metric names (and their order) come from the first fold's first epoch.
    dict_keys = list(folds[0][metric_key][0].keys())

    # Shape: [folds, epochs, sub_metrics].
    stacked = np.stack([
        np.array([[epoch_dict[k] for k in dict_keys]
                  for epoch_dict in fold[metric_key]])
        for fold in folds
    ])
    # Mean over folds -> [epochs, sub_metrics].
    averaged = np.mean(stacked, axis=0)

    # Back to one dict per epoch.
    return [dict(zip(dict_keys, epoch_values)) for epoch_values in averaged]

In [None]:
from sklearn.model_selection import StratifiedKFold

# Add K-fold labels to the rows of the dataframe; these are later used to
# build the K-fold dataloaders. Rows that are never picked for a validation
# fold (e.g. those excluded as test images) keep fold == -1.

df = pd.read_pickle("../data/pepper_data.pkl")
df['image_path'] = '../' + df['image_path']

# Initialize fold column
df['fold'] = -1

# One row per image, so that an image with several rows in `df` is assigned
# to exactly one fold (without de-duplication it could land in several
# validation folds and the last assignment would silently win).
unique_images_df = (
    df[['image_path', 'domain']]
    .drop_duplicates(subset='image_path')
    .reset_index(drop=True)
)
# Exclude the test subset.
# NOTE(review): `test_split_idx` is not defined in the visible cells; it is
# expected to come from get_dataloader(). It is compared against image paths
# here -- confirm it holds paths and not integer row indices.
unique_images_df = unique_images_df[~unique_images_df['image_path'].isin(test_split_idx)]

# Create stratified 5-fold splits based on domain
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)

# Every row of `df` that shares an image path with a fold's validation
# images gets that fold number.
for fold, (_, val_idx) in enumerate(skf.split(unique_images_df['image_path'], unique_images_df['domain'])):
    val_image_paths = unique_images_df.iloc[val_idx]['image_path'].tolist()
    df.loc[df['image_path'].isin(val_image_paths), 'fold'] = fold

# Verify domain distribution across folds
print("Domain distribution across folds:")
print(df.groupby(['fold', 'domain']).size().unstack(fill_value=0))

In [None]:
# fold_loaders = create_crossvalidation_loaders(df, 5, batch_sizes=(32, 64, 64), resize_img_to=(512, 288))

## Training

### current

In [None]:
# TODO tests: retrain the model with
# 3 trains averaged
# full model tSNE
# one branch, no mask
# top branch + mask
# bottom branch + mask
# both branches + random masks

In [None]:

# Experiment table: scenario name -> (model instance, [social image size, environment image size]).
# NOTE: the model is instantiated as soon as this dict is built, not when the loop reaches it.
testing_scenarios = {
    'heuristic__small_imgs': (DualBranchModel(), [(144,256)]*2)
}

# Train every scenario in turn.
for name, (model, img_size) in testing_scenarios.items():

    # Preprocessing for the social image: resize to img_size[0], convert to a tensor and
    # normalise with the standard ImageNet mean/std (DualBranchModel loads IMAGENET1K_V1 weights).
    transform_soc = transforms.Compose([
        transforms.Resize(img_size[0]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225])
    ])
    # Same pipeline for the environment image, resized to img_size[1].
    transform_env = transforms.Compose([
        transforms.Resize(img_size[1]),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225])
    ])

    # Reload the dataset and derive, from each image's file name, the paths of its two
    # masked variants under ../data/masked/environment and ../data/masked/social.
    df = pd.read_pickle("../data/pepper_data.pkl")
    df['image_path_env'] = df['image_path'].apply(lambda p: str(Path('../data/masked/environment') / Path(p).name))
    df['image_path_social'] = df['image_path'].apply(lambda p: str(Path('../data/masked/social') / Path(p).name))
    # Presumably returns per-domain dataloaders yielding (social, environment) image pairs;
    # the meaning of the three batch sizes is defined in models.utils -- TODO confirm.
    domain_dataloaders = get_dataloader(df, batch_sizes=(16, 64, 64), return_splits=False, double_img=True, transforms=[transform_soc, transform_env], num_workers=0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Domains in order of first appearance in the dataframe, each mapped to an integer index.
    domains = df['domain'].unique()
    domain_to_idx = {d: i for i, d in enumerate(domains)}


    print(f"\nTesting: {name}")
    dual_model = model.to(device)
    # optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
    # Main optimizer: updates both feature branches and the fusion head only.
    optimizer = optim.Adam([
        {'params': dual_model.social_branch.parameters()},
        {'params': dual_model.env_branch.parameters()},
        {'params': dual_model.head.parameters()},
    ], lr=1e-3)
    # Separate optimizer for the two auxiliary classifiers.
    classifier_optimizer = optim.Adam([ 
        {'params': dual_model.social_classifier.parameters()},
        {'params': dual_model.env_classifier.parameters()}
    ], lr=1e-3)

    # Rehearsal buffer handed to the training loop (buffer_size semantics are
    # defined in models.buffers -- presumably the number of stored samples).
    buffer = NaiveRehearsalBuffer(buffer_size=120)

    # Unique experiment name: scenario name plus a timestamp.
    exp_name = f"{name}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    # Not used anywhere else in this cell.
    writer = None
    
    # Extra arguments passed to the training loop as `batch_kwargs`
    # (presumably forwarded to `heuristic_dualbranch_batch` -- TODO confirm).
    dualbranch_kwargs = {
            'mse_criterion': nn.MSELoss(),
            'ce_criterion': nn.CrossEntropyLoss(),
            'domain_to_idx': domain_to_idx,
            'bce_criterion': nn.BCEWithLogitsLoss(),
            'alpha': 0,
            'class_optimizer': classifier_optimizer
        }
    
    # Run training for this scenario; the meaning of each option is defined
    # by `unified_train_loop` in models.training_utils.
    unified_train_loop(
        model=dual_model,
        domains=domains,
        domain_dataloaders=domain_dataloaders,
        buffer=buffer,
        optimizer=optimizer,
        device=device,
        batch_fn=heuristic_dualbranch_batch,
        batch_kwargs=dualbranch_kwargs,
        num_epochs=10,
        exp_name=exp_name,
        gradient_clipping=True,
        detach_base=False,
        binary = True,
        full_replay = True,
        collect_tsne_data=False,
        loss_params={'head': 1, 'social': 1, 'room': 0.5},
        checkpoint_dir="../checkpoints"
    )

In [None]:
# import subprocess

# model_names = [
#     'heuristic_small_env',
#     'heuristic_square_img',
#     'heuristic_eval_buffer'
# ]

# for model_name in model_names:
#     print(f"Running {model_name}")
#     process = subprocess.Popen(
#         ["python", "train_models.py", "--model_name", model_name, "--num_workers", "0"],
#         stdout=subprocess.PIPE,
#         stderr=subprocess.STDOUT,
#         text=True
#     )

#     for line in process.stdout:
#         print(line, end='')

#     process.wait()