## Data preparation

In [None]:
import os
import pandas as pd
from torchvision import transforms
from pathlib import Path
import numpy as np
import random
import pickle
from tqdm.notebook import tqdm, trange
import torch.optim as optim
import numpy as np
import random
import torch
import torch.nn as nn
import torchvision.models as models
import datetime


# One shared seed for every random number generator this notebook touches,
# to make repeated runs more comparable.
SEED = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)
# The CUDA generators are seeded explicitly too, but only when a GPU is visible.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# DATASET_DIR = (Path("..") / ".." / "datasets").resolve()
# DATASETS = ["OFFICE-MANNERSDB", "MANNERSDBPlus"]
# LABEL_COLS = [
#     "Vaccum Cleaning", "Mopping the Floor", "Carry Warm Food",
#     "Carry Cold Food", "Carry Drinks", "Carry Small Objects",
#     "Carry Large Objects", "Cleaning", "Starting a conversation"
# ]

# import sys
# sys.path.append('..')
# from data_processing.data_processing import create_crossvalidation_loaders

import sys
sys.path.append('..')

from models.buffers import NaiveRehearsalBuffer
from data_processing.data_processing import get_domain_dataloaders
from models.training_utils import heuristic_dualbranch_batch, unified_train_loop

In [None]:
# Load the pickled object stored at ../data/pepper_data.pkl into `df`
# (presumably a pandas DataFrame — confirm against the data pipeline).
# NOTE(review): read_pickle unpickles the file, so only point it at trusted data.
df = pd.read_pickle("../data/pepper_data.pkl")

## CL Models

In [None]:
# Ideas for potential ways to combine networks
# residual = SpecificHead(base + invariant_feats)
# specific_feats = invariant_feats + residual

# gate = GateNet(base + invariant_feats)
# residual = SpecificHead(base)
# specific_feats = invariant_feats + gate * residual

# residual = Attention(invariant_feats, base)
# specific_feats = invariant_feats + residual

# gamma, beta = SpecificHead(base)
# specific_feats = gamma * invariant_feats + beta

# invariant_feats = InvariantHead(base)
# specific_feats = SpecificHead(base)
# cos_sim(invariant_feats, specific_feats) ~ 0


### dualbranch model

In [None]:
def intermediate_layer_size(input, output, n_layers):
    """Pick ``n_layers`` power-of-two sizes spread between ``output`` and ``input``.

    The usable exponents run from ``(output + 1).bit_length()`` up to
    ``(input - 1).bit_length()``. The i-th size (i = 1..n_layers) sits at the
    rounded fraction ``i / (n_layers + 1)`` of that exponent range.

    Args:
        input: Integer size at the upper end of the range.
        output: Integer size at the lower end of the range.
        n_layers: How many intermediate sizes to produce.

    Returns:
        A list with ``n_layers`` powers of two, smallest first, or ``None``
        when the exponent range is narrower than ``n_layers``.
    """
    low_exp = (output + 1).bit_length()
    span = (input - 1).bit_length() - low_exp
    if span < n_layers:
        return None

    steps = n_layers + 1
    # Adding half the divisor before floor division rounds to the nearest step.
    rounding = steps // 2
    return [
        1 << (low_exp + (k * span + rounding) // steps)
        for k in range(1, n_layers + 1)
    ]


# Example: three sizes between a (1280 + 64)-wide input and 9 outputs -> [64, 256, 512]
intermediate_layer_size(1280+64, 9, 3)

In [None]:
import copy

class CLIPEncoderWrapper(nn.Module):
    """Frozen image-feature extractor around a CLIP model's ``encode_image``.

    All parameters of the wrapped model get ``requires_grad = False`` and the
    forward pass runs under ``torch.no_grad()``, so no gradients flow into CLIP.

    Args:
        clip_model: Object exposing ``parameters()`` and ``encode_image(x)``.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model
        for p in self.clip_model.parameters():
            p.requires_grad = False

    def forward(self, x):
        """Return the image embeddings of ``x`` as a float32 tensor."""
        with torch.no_grad():
            feats = self.clip_model.encode_image(x)
        # CLIP checkpoints loaded on a GPU may run in half precision; the layers
        # consuming these features are float32, so cast here (no-op if already float32).
        return feats.float()

class DualBranchModel(nn.Module):
    """Two image encoders ("social" and "env") feeding one fully connected head.

    Both branches use the architecture named by ``setup['branch']``. Their
    pooled features are concatenated and passed through a three-hidden-layer
    MLP (512 -> 256 -> 128) ending in ``num_outputs`` values. With
    ``setup['env'] == 'ablated'`` the head is sized for the social features
    only and the env branch contributes nothing to the output.

    Args:
        num_outputs: Width of the final linear layer.
        dropout_rate: Dropout probability used after each hidden layer.
        setup: Dict with key ``'branch'`` (``'resnet18'``, ``'mobilenetv2'``,
            ``'efficientnetb0'`` or ``'clip'``) and optional key ``'env'``
            (``'ablated'`` disables the env branch). ``None`` means
            ``{'branch': 'mobilenetv2'}``.
        clip_model: CLIP model to wrap; required when the branch is ``'clip'``.
        freeze_branches: If truthy, the branch parameters get
            ``requires_grad = False``. Must not be ``False`` for ``'clip'``.

    Raises:
        ValueError: If the branch name is unknown, or the ``'clip'`` branch is
            requested without ``clip_model`` or with ``freeze_branches=False``.
    """

    _BRANCH_NAMES = ('resnet18', 'mobilenetv2', 'efficientnetb0', 'clip')

    def __init__(self, num_outputs=9, dropout_rate=0.3, setup=None, clip_model=None, freeze_branches=False):
        # A fresh dict per instance instead of a mutable default shared across calls.
        if setup is None:
            setup = {'branch': 'mobilenetv2'}
        branch_name = setup['branch']
        # Explicit raises rather than asserts: asserts disappear under `python -O`.
        if branch_name == 'clip' and clip_model is None:
            raise ValueError("clip_model must be provided for CLIP branch")
        if branch_name == 'clip' and freeze_branches is False:
            raise ValueError("CLIP branch will be frozen regardless of freeze_branches")
        super().__init__()
        self.setup = setup
        self.freeze_branches = freeze_branches

        branch, branch_feature_dim = self._build_branch(branch_name, clip_model)

        # Optional freeze params
        if self.freeze_branches:
            for p in branch.parameters():
                p.requires_grad = False

        # The frozen CLIP encoder is shared between both branches (cheaper);
        # every other backbone gets an independent copy for the env branch.
        self.social_branch = branch
        self.env_branch = branch if branch_name == 'clip' else copy.deepcopy(branch)

        soc_feature_dim = branch_feature_dim
        # Ablating the env branch removes its features from the head input.
        env_feature_dim = 0 if self._env_ablated() else branch_feature_dim

        self.fusion_dim = soc_feature_dim + env_feature_dim

        self.head = nn.Sequential(
            nn.Linear(self.fusion_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(dropout_rate),

            nn.Linear(128, num_outputs)
        )

    @staticmethod
    def _build_branch(branch_name, clip_model):
        """Return ``(encoder, feature_dim)`` for one branch of the given type."""
        if branch_name == 'resnet18':
            model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
            # Drop the last two children of the torchvision model and pool ourselves.
            branch = nn.Sequential(
                *list(model.children())[:-2],
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten()
            )
            return branch, 512
        if branch_name == 'mobilenetv2':
            model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1).features
            branch = nn.Sequential(
                model,
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten()
            )
            return branch, 1280
        if branch_name == 'efficientnetb0':
            model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1).features
            branch = nn.Sequential(
                model,
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten()
            )
            return branch, 1280
        if branch_name == 'clip':
            return CLIPEncoderWrapper(clip_model), 512
        raise ValueError(
            f"Unknown branch {branch_name!r}; expected one of {DualBranchModel._BRANCH_NAMES}"
        )

    def _env_ablated(self):
        """Return True when the setup disables the env branch."""
        return self.setup.get('env') == 'ablated'

    def forward(self, social_imgs, env_imgs=None):
        """Score a batch of images.

        Args:
            social_imgs: Batch passed to the social branch.
            env_imgs: Batch passed to the env branch. Ignored when the env
                branch is ablated; required otherwise.

        Returns:
            Dict with ``'output'`` (head scores), ``'invariant_feats'`` (social
            branch features) and ``'specific_feats'`` (env branch features, or
            a zero-width tensor when the env branch is ablated).

        Raises:
            ValueError: If the env branch is active but ``env_imgs`` is None.
        """
        if self._env_ablated():
            # The head was sized without env features, so feed it a zero-width
            # placeholder even if env images were supplied.
            env_features = torch.zeros(social_imgs.size(0), 0, device=social_imgs.device, dtype=social_imgs.dtype)
        elif env_imgs is None:
            raise ValueError(
                "env_imgs is required unless the model was built with setup['env'] == 'ablated'"
            )
        else:
            env_features = self.env_branch(env_imgs)

        social_features = self.social_branch(social_imgs)

        fused_features = torch.cat([social_features, env_features], dim=1)
        scores = self.head(fused_features)

        return {
            'output': scores,
            'invariant_feats': social_features,
            'specific_feats': env_features
        }

### K Fold Cross Validation

In [None]:
# def combine_fold_histories(fold_history):
#     """
#     Combine multiple training histories by averaging them element-wise.
#     Handles both simple lists and lists of dictionaries.
#     """
#     combined_history = {}
#     metric_keys = list(fold_history[0].keys())
    
#     for key in metric_keys:
#         # Check if this metric contains dictionaries
#         first_element = fold_history[0][key][0] if fold_history[0][key] else None
        
#         if isinstance(first_element, dict):
#             # Handle lists of dictionaries (train_epoch_metrics, grad_norms)
#             combined_history[key] = average_list_of_dicts(fold_history, key)
#         else:
#             # Handle simple lists (train_epoch_loss, val_epoch_loss, cross_domain_val)
#             stacked_metrics = np.stack([fold_history[fold][key] for fold in fold_history])
#             combined_history[key] = np.mean(stacked_metrics, axis=0).tolist()
    
#     return combined_history

# def average_list_of_dicts(fold_history, metric_key):
#     """
#     Average a list of dictionaries across folds.
#     """
#     # Get the dictionary keys from the first fold's first epoch
#     dict_keys = list(fold_history[0][metric_key][0].keys())
    
#     # Convert each fold's list of dicts to a 2D numpy array
#     fold_arrays = []
#     for fold in fold_history:
#         # Convert list of dicts to 2D array: [epochs, sub_metrics]
#         fold_array = np.array([[epoch_dict[k] for k in dict_keys] 
#                               for epoch_dict in fold_history[fold][metric_key]])
#         fold_arrays.append(fold_array)
    
#     # Stack all folds and average: [folds, epochs, sub_metrics] -> [epochs, sub_metrics]
#     stacked = np.stack(fold_arrays)
#     averaged = np.mean(stacked, axis=0)
    
#     # Convert back to list of dictionaries
#     result = []
#     for epoch_values in averaged:
#         epoch_dict = dict(zip(dict_keys, epoch_values))
#         result.append(epoch_dict)
    
#     return result

In [None]:
# from sklearn.model_selection import StratifiedKFold

# # Add K fold labels to datapoints in dataframe, later used for creating K Fold Dataloaders

# df = pd.read_pickle("../data/pepper_data.pkl")
# df['image_path'] = '../' + df['image_path']

# # Initialize fold column
# df['fold'] = -1

# # Get unique image paths for splitting
# unique_images_df = df[['image_path', 'domain']].reset_index(drop=True)
# #exclude test subset test idx get from get_dataloader()
# unique_images_df = unique_images_df[~unique_images_df['image_path'].isin(test_split_idx)]

# # Create stratified 5-fold splits based on domain
# skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)

# # Assign fold numbers based on domain stratification
# for fold, (_, val_idx) in enumerate(skf.split(unique_images_df['image_path'], unique_images_df['domain'])):
#     val_image_paths = unique_images_df.iloc[val_idx]['image_path'].tolist()
#     df.loc[df['image_path'].isin(val_image_paths), 'fold'] = fold

# # Verify domain distribution across folds
# print("Domain distribution across folds:")
# print(df.groupby(['fold', 'domain']).size().unstack(fill_value=0))

In [None]:
# fold_loaders = create_crossvalidation_loaders(df, 5, batch_sizes=(32, 64, 64), resize_img_to=(512, 288))

## Training

### current

In [None]:
# TODO tests: retrain the model with
# 3 trains averaged
# both branches + random masks

# but first: fix underfitting: validation set
# lower dropout
# diff branch:
#     resnet18
#     mobilenetv2
#     efficientnetb0
#     clip

#     + ablations - only_soc, only_env, no_mask
# learning rate scheduler with warmup and decay

# then
# best model, x3 random initialisation, on same test set, no validation, report final metrics on test set
#same for ablations
#same for random mask



In [None]:
# Sweep over every backbone and every ablation: build the dataloaders, create a
# fresh model/optimizer/buffer, and hand them to unified_train_loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import clip
clip_model, clip_transform = clip.load("ViT-B/32", device=device)

# Preprocessing for the torchvision backbones; the CLIP runs use clip_transform.
default_transform = transforms.Compose([
        transforms.Resize((144,256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                            std=[0.229, 0.224, 0.225])
    ])

domains = ['Home', 'BigOffice-2', 'BigOffice-3', 'Hallway', 'MeetingRoom', 'SmallOffice']

# Backbone name -> 1-tuple holding the freeze_branches flag for that backbone.
testing_scenarios = {
    'mobilenetv2':      (False,),
    'resnet18':         (False,),
    'efficientnetb0':   (False,),
    'clip':             (True,),
}

epochs = 20
dropout_rate = 0.1

# The trailing comma in `(freeze_branches,)` is what unpacks the 1-tuple. Without
# it the loop variable is the tuple itself, which is always truthy, so every
# backbone would be frozen no matter what the table says.
for name, (freeze_branches,) in testing_scenarios.items():
    for ablation in ['base', 'no_mask', 'only_soc', 'only_env']:
        
        # Training Data — re-read for every run because two ablations overwrite
        # the image_path column below.
        df = pd.read_pickle("../data/pepper_data_train.pkl")
        
        transform = clip_transform if name == 'clip' else default_transform

        # The single-image ablations differ only in which column supplies the image.
        if ablation == 'only_soc':
            df['image_path'] = df['image_path_social']
        elif ablation == 'only_env':
            df['image_path'] = df['image_path_env']

        # Only the base run feeds two images per sample (one transform each).
        double_img = ablation == 'base'
        domain_dataloaders = get_domain_dataloaders(
            df,
            batch_sizes=(32, 64, 64),
            double_img=double_img,
            transforms=[transform]*2 if double_img else transform,
            num_workers=0,
            include_test=None,
        )

        print(f"\nTesting: {name} Ablation: {ablation}")
        # Every non-base run uses a model whose env branch is ablated.
        setup = {'branch':name} if ablation == 'base' else {'branch': name, 'env': 'ablated'}
        auxilary_model = clip_model if name == 'clip' else None
        model = DualBranchModel(dropout_rate=dropout_rate, setup=setup, freeze_branches=freeze_branches, clip_model=auxilary_model)
        dual_model = model.to(device)
        # Frozen parameters are left out of the optimizer.
        trainable_params = [p for p in dual_model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable_params, lr=1e-3)
        buffer = NaiveRehearsalBuffer(buffer_size=120)

        exp_name = f"{name}_dropout{dropout_rate}_epochs{epochs}_ablation-{ablation}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
        
        dualbranch_kwargs = {
                'mse_criterion': nn.MSELoss(),
                'ce_criterion': nn.CrossEntropyLoss(),
            }
        
        
        unified_train_loop(
            model=dual_model,
            domains=domains,
            domain_dataloaders=domain_dataloaders,
            buffer=buffer,
            optimizer=optimizer,
            device=device,
            batch_fn=heuristic_dualbranch_batch,
            batch_kwargs=dualbranch_kwargs,
            num_epochs=epochs,
            exp_name=exp_name,
            gradient_clipping=True,
            collect_tsne_data=False,
            checkpoint_dir="../checkpoints",
            validation_set='val',
        )

In [None]:
# import subprocess

# model_names = [
#     'heuristic_small_env',
#     'heuristic_square_img',
#     'heuristic_eval_buffer'
# ]

# for model_name in model_names:
#     print(f"Running {model_name}")
#     process = subprocess.Popen(
#         ["python", "train_models.py", "--model_name", model_name, "--num_workers", "0"],
#         stdout=subprocess.PIPE,
#         stderr=subprocess.STDOUT,
#         text=True
#     )

#     for line in process.stdout:
#         print(line, end='')

#     process.wait()