### Prep

In [None]:
import os
import torch
import torch.nn as nn
from tqdm.notebook import tqdm, trange
import torch.optim as optim
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import random
import pickle
import datetime
# One global seed for every RNG the notebook touches (Python, NumPy, PyTorch CPU).
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn
if torch.cuda.is_available():
    # Also seed every visible GPU.
    torch.cuda.manual_seed_all(SEED)

# DATASET_DIR = (Path("..") / ".." / "datasets").resolve()
# DATASETS = ["OFFICE-MANNERSDB", "MANNERSDBPlus"]
# LABEL_COLS = [
#     "Vaccum Cleaning", "Mopping the Floor", "Carry Warm Food",
#     "Carry Cold Food", "Carry Drinks", "Carry Small Objects",
#     "Carry Large Objects", "Cleaning", "Starting a conversation"
# ]

import sys
sys.path.append('..')

from data_processing.data_processing import ImageLabelDataset,DualImageDataset,create_dataloaders,create_crossvalidation_loaders

In [None]:
df = pd.read_pickle(Path("..") / "data" / "pepper_data.pkl")  # full dataset frame (has a 'domain' column)

## Evaluation

### files

In [None]:
# Training-history pickles of the earlier dual-branch ablations (description - checkpoint stem):
# different cos (1-|cos|), fixed residual          - fixedred_fixedcos_dualbranchmodel_20250604_014811_history
# fixed residual, old cosine^2                     - fixedred_dualbranchmodel_20250604_002951_history
# broken baseline                                  - baselinemodel_20250603_200948_history
# recent dual-branch                               - dualbranchmodel_20250603_135728_history
# normalised and standardised residual             - dualbranchmodel_20250609_013632_history
# standardised residual                            - nonorm_dualbranchmodel_20250609_131820_history
# no residual connection                           - nores_dualbranchmodel_20250609_232842_history
# same plus gradient clipping                      - nores_gradclip_dualbranchmodel_20250610_010647_history
# same, no gradient clipping, weight init + norm   - nores_winit_wnorm_dualbranchmodel_20250610_024401_history
# same, gradient clipping, weight init + norm      - nores_gradclip_winit_wnorm_dualbranchmodel_20250610_042118_history
# deeper invariant network                         - deep_norm_dualbranchmodel_20250610_223350_history
# same but with weight normalisation               - deep_dualbranchmodel_20250610_205411_history
# NOTE(review): the last two descriptions look swapped relative to the variable names below
# (deep_norm -> wnorm_gradclip_deep, deep -> lnorm_gradclip_deep) - confirm which run is which.


def _load_history(stem):
    """Unpickle ``../checkpoints/<stem>.pkl`` and return the stored training history."""
    with open(f'../checkpoints/{stem}.pkl', 'rb') as fh:
        return pickle.load(fh)


nonorm = _load_history('nonorm_dualbranchmodel_20250609_131820_history')
lnorm_nores = _load_history('nores_dualbranchmodel_20250609_232842_history')
lnorm_nores_gradclip = _load_history('nores_gradclip_dualbranchmodel_20250610_010647_history')
lnorm = _load_history('dualbranchmodel_20250609_013632_history')
wnorm_nores = _load_history('nores_winit_wnorm_dualbranchmodel_20250610_024401_history')
wnorm_nores_gradclip = _load_history('nores_gradclip_winit_wnorm_dualbranchmodel_20250610_042118_history')
wnorm_gradclip_deep = _load_history('deep_norm_dualbranchmodel_20250610_223350_history')
lnorm_gradclip_deep = _load_history('deep_dualbranchmodel_20250610_205411_history')

In [None]:
# All runs below use gradient clipping. Two families of experiments:
#
# 1. Detached backbone with a skip connection base -> head
#    - detach reference: how does detaching work? can the backbone learn through the skip connection?
#    - detach + smaller batch: effect on stability?
#    - detach + partially frozen backbone: how does it learn?
# 2. Frozen backbone
#    - frozen backbone doesn't learn, branches learn solo
#    - frozen, branches learn solo but with a binary domain objective?
#    - frozen, binary, explicit gradient-reversal function rather than a layer?
#    - add full replay?
files = [
    # deep, weights_init=False, detach_base=True - base for the experiments below
    'bbdetach_deep_dualbranchmodel_20250613_155613_history',
    # deep, unfrozen + detached backbone, batch size 16 - did the smaller batch improve stability,
    # did the skip connection improve head loss and total training/val loss?
    'bdetach_batch16_dualbranchmodel_20250615_025032_history',
    # deep, partially frozen + detached backbone (branches + last backbone layer train) -
    # did the partially frozen pretrained backbone improve the skip-head connection?
    'bdetach_pfrozen_dualbranchmodel_20250615_040906_history',
    # deep, weights_init=False, detach_base=False - how do deep linear branches train on their own?
    'ffrozen_dualbranchmodel_20250615_133027_history',
    # weights_init=False, detach_base=False, explicit_grl=False, full_replay=False - does binary improve anything?
    'ffrozen_binary_dualbranchmodel_20250615_153755_history',
    # does the explicit gradient-reversal function work better?
    # NOTE(review): the original note listed explicit_grl=False, but the run name says "explicitgrl" - confirm.
    'ffrozen_binary_explicitgrl_dualbranchmodel_20250615_162820_history',
    # weights_init=True, detach_base=False, explicit_grl=True, full_replay=True -
    # does full replay help the branches on their own?
    'freplay_ffrozen_binary_explicitgrl_dualbranchmodel_20250615_174623_history',
]

In [None]:
# CNN architecture sweep; config tuples are as recorded at training time (backbone first).
# Disabled runs:
#   'CNN_pretrained_simple_simple_dualbranchmodel_20250616_032555_history'   - 'pretrained_backbone_cnn_branch': ('pretrained', 'simple', 'simple')
#   'CNN_diy_backbone_linear_branch_dualbranchmodel_20250616_105211_history' - (2-layer convolution backbone, 'linear', 'simple')
#   'CNN_linear_branch_dualbranchmodel_20250616_172252_history'              - 'linear_branch': ('3conv', 'linear', 'simple')
files = [
    # 'cnn_branch': ('3conv', 'simple', 'simple')
    'CNN_cnn_branch_dualbranchmodel_20250616_190434_history',
    # 'adversarial': ('2conv', 'adversarial', 'adversarial')
    'CNN_adversarial_dualbranchmodel_20250616_201008_history',
    # 'cnn_specialised_branches': ('3conv', 'special', 'simple')
    'CNN_cnn_specialised_branches_dualbranchmodel_20250616_232330_history',
    'CNN_pretrained_simple_0.25branches_dualbranchmodel_20250617_230422_history',
    'CNN_pretrained_special_dualbranchmodel_20250617_032445_history',
]


    

In [None]:
# Pretrained vs. 3-conv backbone, each with simple / special branches.
# Config tuples as recorded at training time (last field: detach_base).
files = [
    # 'pretrained_simple': ('pretrained', 'simple', '3linear', detach_base=False)
    'CNN_pretrained_simple_dualbranchmodel_20250617_022356_history',
    # 'pretrained_special': ('pretrained', 'special', '3linear', False)
    'CNN_pretrained_special_dualbranchmodel_20250617_032445_history',
    # '3conv_simple': ('3conv', 'simple', '3linear', True)
    'CNN_3conv_simple_dualbranchmodel_20250617_042555_history',
    # '3conv_special': ('3conv', 'special', '3linear', True)
    'CNN_3conv_special_dualbranchmodel_20250617_052951_history',
]

In [None]:
# Same pretrained/simple setup with default vs. changed loss proportions ("0.25branches" run).
files = [
    'CNN_pretrained_simple_dualbranchmodel_20250617_022356_history',
    'CNN_pretrained_simple_0.25branches_dualbranchmodel_20250617_230422_history',  # changed loss proportions
]

In [None]:
# 5-fold cross-validation histories (fold 4). This list is kept under its own name rather than
# assigned to `files` and immediately overwritten (a dead store), so it can be selected
# explicitly with `files = crossval_fold4_files`.
crossval_fold4_files = [
    '5foldcrossval_fold4_CNN_3conv_adversarial_dualbranchmodel_20250618_121607_history',
    '5foldcrossval_fold4_CNN_pretrained_simple_dualbranchmodel_20250618_064848_history',
]

# Runs compared by this cell.
files = [
    'CNN_pretrained_simple_0.25branches_dualbranchmodel_20250617_230422_history',
    'CNN_adversarial_dualbranchmodel_20250616_201008_history',
]

In [None]:
# DANN runs: reference vs. dynamic-alpha variants ("notrain1dom" in the run name).
files = [
    'DANN_dualbranchmodel_20250618_152753_history',
    'DANN_dynamicalpha_notrain1dom_20250620_201349_history',
    'DANN_dynamicalpha_notrain1dom_20250620_232330_history',
]

In [None]:
# Baseline vs. minimal dual-branch variants.
files = [
    'baselinemodel_20250618_195623_history',
    'minimal_simple_dualbranchmodel_20250618_213930_history',
    'minimal_simple_buffer05_dualbranchmodel_20250618_224113_history',
    'minimal_special_dualbranchmodel_20250618_233832_history',
]


In [None]:
# The three baseline runs side by side.
files = [
    'baselinemodel_20250603_200948_history',
    'baselinemodel_20250603_034233_history',
    'baselinemodel_20250618_195623_history',
]

In [None]:
# One representative run per approach.
files = [
    'baselinemodel_20250603_200948_history',
    'minimal_absurd_buff500_20250619_121348_history',
    'minimal_simple_buffer05_dualbranchmodel_20250618_224113_history',
    'DANN_dualbranchmodel_20250618_152753_history',
    'CNN_pretrained_simple_0.25branches_dualbranchmodel_20250617_230422_history',
]

In [None]:
# DANN with vs. without dynamic alpha ("notrain1dom" runs).
files = [
    'DANN_dynamicalpha_notrain1dom_20250620_232330_history',
    'DANN_notrain1dom_20250621_005208_history',
]

In [None]:
# Dynamic-alpha dual-branch CNN vs. the earlier adversarial / cnn-branch runs.
files = [
    'dualbranch_CNN_dynamicalpha_20250621_015610_history',
    'CNN_adversarial_dualbranchmodel_20250616_201008_history',
    'CNN_cnn_branch_dualbranchmodel_20250616_190434_history',
]

In [None]:
# 1-epoch dual-branch CNN vs. earlier pretrained / adversarial runs.
files = [
    'dualbranch_CNN_1epoch_20250621_014841_history',
    # The 0.25-branch run is stored as "...0.25branches..." (same timestamp as in every other
    # cell); the dot-less spelling "025branches" would not match the checkpoint file.
    'CNN_pretrained_simple_0.25branches_dualbranchmodel_20250617_230422_history',
    'CNN_pretrained_simple_dualbranchmodel_20250617_022356_history',
    'CNN_adversarial_dualbranchmodel_20250616_201008_history',
]

In [None]:
# 3-conv runs: batch-norm and gradient-clipping ablations.
files = [
    '3conv_simple_3linear_bnorm_nogradclip_20250621_130858_history',
    '3conv_simple_simple_bnorm_20250621_121437_history',
    '3conv_simple_simple_bnorm_nogradclip_20250621_102634_history',
    '3conv_simple_simple_nogradclip_20250621_112040_history',
]

In [None]:
# DANN with MobileNet-based backbones (stem spellings match the files on disk).
files = [
    'mobinev2_dann_20250621_204449_history',
    'deeplabv3mobilenetv3_dann_20250621_223924_history',
    'deeplabv3mobilenetv3_dann_20250624_000526_history',
]

In [None]:
# New vs. old versions of each approach, as (label, checkpoint stem) pairs.
# Disabled:
#   ('dann', 'dann_mobilenet_20250630_014020_history')              - smoother, stable disentanglement but not significantly better results
#   ('dann_old', 'mobinev2_dann_20250621_204449_history')           - old MobileNet DANN
#   'dual_mobilenet_linear_simple_20250630_111817_history'          - performs poorly; unsurprisingly a linear layer cannot extract meaningful features from a frozen backbone
#   ('baseline', 'baseline_20250630_161803_history')                - more stable, similar results
#   ('baseline_old', 'baselinemodel_20250618_195623_history')       - old baseline
files = [
    ('detached', 'dual_2conv_adversarial_3linear_detach_20250630_045130_history'),
    # old detached run
    ('detached_old', 'CNN_3conv_simple_dualbranchmodel_20250617_042555_history'),
    # old trainable pretrained backbone, adversarial/adversarial branches
    ('trainable_old', 'CNN_adversarial_dualbranchmodel_20250616_201008_history'),
    ('frozen', 'dual_mobilenet_simple_3linear_20250630_081954_history'),
    # old pretrained simple, 3linear
    ('frozen_old', 'CNN_pretrained_simple_dualbranchmodel_20250617_022356_history'),
]

In [None]:
# Each approach at its default buffer size and with buffers of 500 / 120 samples.
files = [
    # linear branches on a frozen MobileNet
    ('linear', 'dual_mobilenet_linear_simple_20250630_111817_history'),
    ('linear_b120', 'dual_mobilenet_linear_simple_buff120_20250704_081958_history'),
    # baseline
    ('base', 'baseline_20250630_161803_history'),
    ('base_b500', 'baseline_buff500_20250704_085427_history'),
    ('base_b120', 'baseline_buff120_20250704_113155_history'),
    # DANN
    ('dann', 'dann_mobilenet_20250630_014020_history'),
    ('dann_b500', 'dann_mobilenet_buff500_20250704_003743_history'),
    ('dann_b120', 'dann_mobilenet_buff120_20250704_031423_history'),
    # detached backbone
    ('detach', 'dual_2conv_adversarial_3linear_detach_20250630_045130_history'),
    ('detach_b500', 'dual_2conv_adversarial_3linear_detach_buff500_20250704_041620_history'),
    ('detach_b120', 'dual_2conv_adversarial_3linear_detach_buff120_20250704_054532_history'),
    # frozen MobileNet, simple branches
    ('frozen', 'dual_mobilenet_simple_3linear_20250630_081954_history'),
    ('frozen_b500', 'dual_mobilenet_simple_3linear_buff500_20250704_062411_history'),
    ('frozen_b120', 'dual_mobilenet_simple_3linear_buff120_20250704_074516_history'),
]

In [None]:
# Heuristic dual-branch vs. baseline and frozen models, buffer 500 / 120.
files = [
    ('base_b500', 'baseline_buff500_20250704_085427_history'),
    ('base_b120', 'baseline_buff120_20250704_113155_history'),
    ('new_b500', 'heuristic_dualbranch_buff500_20250704_213101_history'),
    ('new_b120', 'heuristic_dualbranch_buff120_20250705_031333_history'),
    ('frozen_b500', 'dual_mobilenet_simple_3linear_buff500_20250704_062411_history'),
    ('frozen_b120', 'dual_mobilenet_simple_3linear_buff120_20250704_074516_history'),
]

In [None]:
# Single run: baseline with a 120-sample buffer.
files = [('base_b120', 'baseline_buff120_20250704_113155_history')]

In [None]:
# Heuristic dual-branch comparison plus the old pretrained run and the linear-branch run.
files = [
    ('base_b500', 'baseline_buff500_20250704_085427_history'),
    ('base_b120', 'baseline_buff120_20250704_113155_history'),
    ('new_b500', 'heuristic_dualbranch_buff500_20250704_213101_history'),
    ('new_b120', 'heuristic_dualbranch_buff120_20250705_031333_history'),
    ('frozen_b500', 'dual_mobilenet_simple_3linear_buff500_20250704_062411_history'),
    ('frozen_b120', 'dual_mobilenet_simple_3linear_buff120_20250704_074516_history'),
    ('old_b500', 'CNN_pretrained_simple_dualbranchmodel_20250617_022356_history'),
    ('linear_b1000', 'dual_mobilenet_linear_simple_20250630_111817_history'),
]

In [None]:
# Heuristic model: reference vs. evaluation-buffer variant.
# Disabled: ('small_env', 'heuristic_small_env_20250722_222924_history'),
#           ('square', 'heuristic_square_img_20250722_233743_history')
files = [
    ('base', 'heuristic_dualbranch_buff120_20250705_031333_history'),
    ('eval_buffer', 'heuristic_eval_buffer_20250723_144515_history'),
]

In [None]:
# Heuristic model: reference vs. small-image and small-environment variants.
files = [
    ('base', 'heuristic_dualbranch_buff120_20250705_031333_history'),
    ('small', 'heuristic__small_imgs_20250729_151705_history'),
    ('small_env', 'heuristic_small_env_20250722_222924_history'),
]


In [None]:
import os

# Every .pkl in the checkpoint directory, in os.listdir order (not sorted).
pkl_files = list(filter(lambda name: name.endswith('.pkl'), os.listdir('../checkpoints/')))
for pkl_name in pkl_files:
    print(pkl_name)

In [None]:
import pickle
# Load every history listed in `files` into `models`.
# Tuple entries are (label, stem) and use the label as key; plain stems are keyed by the stem
# minus its last three '_' parts, plus the list position to keep keys unique.
models = {}
for i, entry in enumerate(files):
    label, stem = entry if isinstance(entry, tuple) else ('', entry)
    key = label or '_'.join(stem.split('_')[:-3]) + str(i)
    with open(f'../checkpoints/{stem}.pkl', 'rb') as f:
        models[key] = pickle.load(f)

In [None]:
# For fold histories
for k,m in models.items():
    models[k] = combine_fold_histories(m)

In [None]:
# Snapshot the current selection. A shallow copy (not an alias) keeps later key reassignments
# in `models` (e.g. the fold-combining cell) out of the snapshot; the nested history dicts
# themselves are still shared.
models2 = dict(models)

In [None]:
# Merge the snapshot `models2` into the current selection; on duplicate labels the snapshot wins.
for run_key, run_hist in models2.items():
    models[run_key] = run_hist

In [None]:
# History of the detached-backbone reference run.
bbdetach_deep = pickle.loads(
    Path('../checkpoints/bbdetach_deep_dualbranchmodel_20250613_155613_history.pkl').read_bytes()
)


In [None]:
# Restore one dual-branch checkpoint: model weights plus the history and t-SNE payload saved
# alongside them.
# `device` is resolved first so tensors can be mapped onto it. A checkpoint written on a GPU
# would otherwise fail to load on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pt_file = torch.load(
    '../checkpoints/dualbranchmodel_20250609_013632_domainSmallOffice_epoch9_step1300.pt',
    map_location=device,
)

domains = df['domain'].unique()
# NOTE(review): DualBranchNet is not imported anywhere in this notebook - TODO import it from
# the project's model module before running this cell on a fresh kernel.
model = DualBranchNet(num_domains=len(domains)).to(device)
model.load_state_dict(pt_file['model_state_dict'])
history = pt_file['history']
tsne_data = pt_file['tsne']

In [None]:
# List the loaded run labels. A dedicated loop variable is used because iterating with `model`
# would overwrite the network object that later cells call as `model(inputs)`.
for model_name in models:
    print(model_name)

In [None]:
# Some histories store each cross-domain validation value wrapped in a sequence; keep only its
# first element. Values that are already scalar are left alone, so re-running this cell is safe
# (before, a second run indexed into plain floats and raised).
# NOTE(review): the first loaded run is skipped, presumably because it was saved unwrapped -
# confirm this still matches the current `files` selection.
for run_key in list(models.keys())[1:]:
    for snap in models[run_key]['cross_domain_val']:
        for dom_key, val in snap.items():
            if isinstance(val, (list, tuple)) or np.ndim(val) > 0:
                snap[dom_key] = val[0]

In [None]:
# Same unwrapping for the per-epoch validation losses: keep element 0 of wrapped entries and leave
# scalars untouched, so the cell is idempotent. Runs without a buffer validation loss are skipped
# instead of raising KeyError halfway through.
# NOTE(review): skips the first loaded run, like the cross-domain cell - confirm intent.
for run_key in list(models.keys())[1:]:
    hist = models[run_key]
    for loss_key in ('val_epoch_loss', 'val_buffer_epoch_loss'):
        if loss_key in hist:
            hist[loss_key] = [
                v[0] if isinstance(v, (list, tuple)) or np.ndim(v) > 0 else v
                for v in hist[loss_key]
            ]

### plots

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
sns.set_theme(context="notebook", style="whitegrid")  # global plot styling for every figure below

In [None]:
models = dict(bbdetach_deep=bbdetach_deep)  # only the detached-backbone reference run

In [None]:
# Normalisation / residual / gradient-clipping ablations, keyed by short run label.
models = dict(
    nonorm=nonorm,
    lnorm=lnorm,
    lnorm_nores=lnorm_nores,
    lnorm_nores_gradclip=lnorm_nores_gradclip,
    wnorm_nores=wnorm_nores,
    wnorm_nores_gradclip=wnorm_nores_gradclip,
    wnorm_gradclip_deep=wnorm_gradclip_deep,
    lnorm_gradclip_deep=lnorm_gradclip_deep,
)

In [None]:
# Optional: keep only the snapshots at indices 9, 19, 29, 39, 49, 58 (presumably the end of each domain).
# for name, m in models.items():
#     m['cross_domain_val'] = [x for idx, x in enumerate(m['cross_domain_val'])
#                              if idx in [9, 19, 29, 39, 49, 58]]

# Number of cross-domain validation snapshots stored for each run.
for hist in models.values():
    print(len(hist['cross_domain_val']))

In [None]:
tickvals = [0, 10, 20, 30, 40, 50]
# Preview the tick positions doubled.
[tick * 2 for tick in tickvals]


In [None]:
import plotly.graph_objects as go
# Interactive train/val loss viewer: a dropdown switches between the loaded runs.
model_names = list(models.keys())
if not model_names:
    raise ValueError('`models` is empty - load some training histories first.')

traces = []
for model_name, history in models.items():
    # Two traces per run, always in (train, val) order; the dropdown masks below rely on this.
    for split, key in (('Train', 'train_epoch_loss'), ('Val', 'val_epoch_loss')):
        traces.append(go.Scatter(
            x=list(range(len(history[key]))),
            y=history[key],
            mode='lines',
            name=f'{model_name} {split} Loss',
            visible=False,
        ))

# Make the first run visible by default.
traces[0].visible = True
traces[1].visible = True

buttons = []
for i, model_name in enumerate(model_names):
    visible = [False] * len(traces)
    visible[2 * i] = visible[2 * i + 1] = True
    buttons.append(dict(
        label=model_name,
        method='update',
        args=[{'visible': visible}, {'title': f'Training and Validation Loss - {model_name}'}],
    ))

fig = go.Figure(data=traces)
fig.update_layout(
    updatemenus=[dict(
        active=0,
        buttons=buttons,
        x=0.1,
        y=1.15,
        xanchor='left',
        yanchor='top',
    )],
    # The initial title names the run actually displayed (the first one), matching the
    # dropdown's active entry rather than a hard-coded placeholder.
    title=f'Training and Validation Loss - {model_names[0]}',
    xaxis_title='Epochs',
    yaxis_title='MSE Loss',
    template='plotly_white',
    yaxis=dict(range=[0, 1]),
)
fig.show()

In [None]:
import plotly.graph_objects as go

# Train / val (and, when logged, buffer-val) loss of every loaded run on one figure.
fig = go.Figure()

for model_name, history in models.items():
    curves = [('Train Loss', 'train_epoch_loss'), ('Val Loss', 'val_epoch_loss')]
    # Only some runs log a buffer (replay) validation loss. Check for the key explicitly instead
    # of a bare `except`, which also hid genuine errors.
    if 'val_buffer_epoch_loss' in history:
        curves.append(('Buffer Val Loss', 'val_buffer_epoch_loss'))
    for label, key in curves:
        fig.add_trace(go.Scatter(
            x=list(range(len(history[key]))),
            y=history[key],
            mode='lines',
            name=f'{model_name} {label}',
        ))

fig.update_layout(
    title='Training and Validation Loss',
    xaxis_title='Epochs',
    yaxis_title='MSE Loss',
    template='plotly_white',
    yaxis=dict(range=[-0.0, 1]),
)
# Domain names placed every 10 epochs. NOTE(review): assumes 10 epochs per domain - adjust
# tickvals for runs trained with a different schedule.
fig.update_xaxes(
    tickvals=[0, 10, 20, 30, 40, 50],
    # tickvals = np.multiply(np.array([0, 10, 20, 30, 40, 50]), 3).tolist(),
    ticktext=list(df['domain'].unique())
)

fig.show()

In [None]:
# plt: train/val loss of every loaded run.
plt.figure(figsize=(20, 7))
# The loop variable is `hist`, not `model`: rebinding `model` would clobber the network object
# that later cells use.
for name, hist in models.items():
    plt.plot(hist['train_epoch_loss'], label=f'{name} Train Loss')
    plt.plot(hist['val_epoch_loss'], label=f'{name} Val Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
# plt.ylim(-0.01, 1)
plt.legend()
plt.show()

In [None]:
# Print each loaded run label on its own line.
for run_label in models.keys():
    print(run_label)

In [None]:
# Global min / max of the cross-domain validation losses over the first six snapshots of every
# run; the plotly domain-loss figure uses them for its y-axis range.
final = []
for name, hist in models.items():
    # Slice rather than index range(6), so runs with fewer than six snapshots don't raise
    # IndexError. `hist` (not `model`) avoids clobbering the network object.
    for snapshot in hist['cross_domain_val'][:6]:
        final += list(snapshot.values())
max_loss = max(final)
min_loss = min(final)

In [None]:
# plt: per-domain validation loss after each training stage, for the run labelled 'history'.
cdv_history = models['history']['cross_domain_val']

# Domain names, taken from the first snapshot. Dedicated names leave the notebook-level
# `history` / `domains` variables untouched for other cells.
cdv_domains = list(cdv_history[0].keys())

# Loss of every domain after each training stage.
domain_scores = {dom: [] for dom in cdv_domains}
for snapshot in cdv_history:
    for dom in cdv_domains:
        domain_scores[dom].append(snapshot[dom])

plt.figure(figsize=(12, 6))
for dom, scores in domain_scores.items():
    plt.plot(cdv_domains[:len(scores)], scores, label=dom, marker='o')

plt.xlabel("After training on domain X")
plt.ylabel("Loss")
# The plotted values are losses (see the y label), so the title says loss, not accuracy.
plt.title("Domain-wise Loss Over Time")
# plt.ylim(min(-0.1, min_loss), max_loss)
# plt.ylim(-0.05, 0.1)
plt.legend()
plt.show()

In [None]:
import plotly.graph_objects as go

# Per-run, per-domain validation loss after each training stage:
# domain_data[run][domain] -> [loss after stage 0, loss after stage 1, ...]
domain_data = {}
for model_name, model_history in models.items():
    snapshots = model_history['cross_domain_val']
    domain_data[model_name] = {
        dom: [snap[dom] for snap in snapshots]
        for dom in snapshots[0].keys()
    }

# One colour per domain, keyed by the last run's domain names. Only six colours are defined, so
# a seventh domain (or one missing from that run) raises KeyError below.
# Dedicated loop and variable names are used throughout, so the notebook-level `domains`
# (needed later as a list) is not rebound.
domain_names = list(snapshots[0].keys())
domain_colors = dict(zip(
    domain_names,
    ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', "#05ffee"],
))

fig = go.Figure()

# One trace per (run, domain); only the first run starts fully visible.
for model_idx, (model_name, scores_by_domain) in enumerate(domain_data.items()):
    for dom, scores in scores_by_domain.items():
        fig.add_trace(go.Scatter(
            x=list(range(len(scores))),
            y=scores,
            mode='lines+markers',
            name=dom,
            line=dict(color=domain_colors[dom], width=2),
            marker=dict(size=8, symbol=model_idx + 1),  # unique marker symbol per run
            legendgroup=model_name,
            legendgrouptitle_text=model_name,
            visible=True if model_idx == 0 else 'legendonly',
        ))

# Dropdown: one visibility flag per trace, generated in the same (run, domain) order as the
# traces above, so the mask stays aligned even when runs logged different numbers of domains.
buttons = [
    dict(
        label=model_name,
        method='update',
        args=[
            {'visible': [run == model_name
                         for run, run_scores in domain_data.items()
                         for _ in run_scores]},
            {'title': f'Domain Losses: {model_name}'},
        ],
    )
    for model_name in domain_data.keys()
]

fig.update_layout(
    title='Validation Loss of Domain X After Training on Domain Y',
    xaxis_title='After Training on Domain Y',
    yaxis_title='MSE Loss',
    legend=dict(
        title='Domain X',
        groupclick="toggleitem",  # toggle single traces while keeping the per-run groups
        itemsizing='constant'
    ),
    updatemenus=[{
        'type': 'dropdown',
        'direction': 'down',
        'showactive': True,
        'buttons': buttons,
        'x': 1,
        'xanchor': 'left',
        'y': 1.1,
        'yanchor': 'top'
    }],
    template='plotly_white',
    # min_loss / max_loss come from the cross-domain loss-range cell.
    yaxis=dict(range=[min(-0.01, min_loss), max_loss]),
    # yaxis=dict(range=[-0.0, 0.8]),
)
fig.update_xaxes(
    tickvals=[0, 1, 2, 3, 4, 5],
    ticktext=list(df['domain'].unique())
)

fig.show()


In [None]:
# plt: per-epoch cosine similarity between the two branches for the 'deep_norm' run.
h = models['deep_norm']

sns.set_theme(style="whitegrid")
similarity = [epoch_metrics['similarity'] for epoch_metrics in h['train_epoch_metrics']]
plt.figure(figsize=(20, 5))
plt.plot(similarity)
plt.xticks(np.arange(0, 60, step=1))
plt.title('cosine similarity of two branches')
plt.show()

In [None]:
# plt: per-epoch accuracy of the invariant and specific branches for the run labelled 'history'.
h = models['history']

plt.figure(figsize=(20, 5))
for metric in ['inv_acc', 'spec_acc']:
    plt.plot([m[metric] for m in h['train_epoch_metrics']], label=metric)
# Only the two accuracies are plotted (no similarity curve), so the title and label say so.
plt.title('Branch accuracies')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.xticks(np.arange(0, 60, step=1))
plt.legend()
plt.show()


In [None]:
import plotly.graph_objects as go

# Branch accuracies of every loaded run on one figure.
fig = go.Figure()

for model_name, history in models.items():
    for metric in ['inv_acc', 'spec_acc']:
        # Runs that did not log this metric are skipped. Only KeyError is caught, so other
        # errors still surface instead of being swallowed by a bare `except`.
        try:
            y_values = [m[metric] for m in history['train_epoch_metrics']]
        except KeyError:
            continue
        fig.add_trace(go.Scatter(
            x=list(range(len(y_values))),
            y=y_values,
            mode='lines+markers',
            name=f'{model_name} - {metric}'
        ))

fig.update_layout(
    title='Branch Accuracies',
    xaxis_title='Epoch',
    yaxis_title='Accuracy',
    # xaxis=dict(tickmode='linear', tick0=0, dtick=1),
    template='plotly_white',
    yaxis=dict(range=[-0.0, 1]),
)
# NOTE(review): domain labels every 10 epochs assume 10 epochs per domain.
fig.update_xaxes(
    tickvals=[0, 10, 20, 30, 40, 50],
    ticktext=list(df['domain'].unique())
)

fig.show()


In [None]:
# plt: domain-classification losses of both branches plus the task loss, for run 'history'.
h = models['history']
loss_metrics = ['inv_domain', 'spec_domain', 'task_loss']

plt.figure(figsize=(20, 5))
for metric_key in loss_metrics:
    plt.plot([epoch_m[metric_key] for epoch_m in h['train_epoch_metrics']], label=metric_key)
plt.xlabel('Epoch')
plt.ylabel('Metric Value')
plt.ylim(-0.01, 1)
plt.legend()
plt.show()

In [None]:
import plotly.graph_objects as go

# Domain-classification CE losses and MSE task loss of every loaded run on one figure.
fig = go.Figure()

for model_name, history in models.items():
    for metric in ['inv_domain', 'spec_domain', 'task_loss']:
        # Runs that did not log this metric are skipped. Only KeyError is caught, so other
        # errors still surface instead of being swallowed by a bare `except`.
        try:
            y_values = [m[metric] for m in history['train_epoch_metrics']]
        except KeyError:
            continue
        fig.add_trace(go.Scatter(
            x=list(range(len(y_values))),
            y=y_values,
            mode='lines+markers',
            name=f'{model_name} - {metric}'
        ))

fig.update_layout(
    title='CE loss - inv_domain, spec_domain, and MSE task_loss',
    xaxis_title='Epoch',
    yaxis_title='Loss',
    template='plotly_white',
    yaxis=dict(range=[-0.01, 0.12]),
)
# NOTE(review): domain labels every 10 epochs assume 10 epochs per domain.
fig.update_xaxes(
    tickvals=[0, 10, 20, 30, 40, 50],
    ticktext=list(df['domain'].unique())
)

fig.show()


In [None]:
# Share of replayed vs. current-domain samples per epoch, for the run labelled 'history'.
h = models['history']

plt.figure(figsize=(20, 5))
for metric in ['replay_count', 'current_count']:
    percentages = []
    for m in h['train_epoch_metrics']:
        # Normalise by the logged sample total instead of a hard-coded batch size of 32, which
        # mis-scales runs with other batch sizes (e.g. the batch-16 experiment).
        # NOTE(review): assumes every sample is counted as either replayed or current.
        total = m['replay_count'] + m['current_count']
        percentages.append(m[metric] / total * 100 if total else 0.0)
    plt.plot(percentages, label=metric)
plt.title('Type of samples in batch')
plt.xlabel('Epoch')
plt.ylabel('% of batch')
# plt.xticks(np.arange(0, 60, step=1))
plt.legend()
plt.show()

In [None]:
# Per-module gradient norms for the run currently bound to `h` (set by an earlier cell).
plt.figure(figsize=(20, 5))
# for module in ['invariant', 'specific_residual', 'domain_classifier']:
for module in h['grad_norms'][0].keys():
    # A module missing from some logged steps is drawn as a gap (NaN) instead of raising KeyError.
    plt.plot([m.get(module, np.nan) for m in h['grad_norms']], label=module)
plt.title('Gradient Norms by Module')
plt.xlabel('Epoch')
plt.ylabel('Gradient Norm')
# plt.ylim(-0.001, 1)
plt.legend()
plt.show()

In [None]:
import plotly.graph_objects as go

# Per-module gradient norms with a dropdown that switches between runs.
fig = go.Figure()

# Union of module names across all runs, so every run gets the same set of traces.
all_modules = sorted({
    module
    for run_hist in models.values()
    for grad_dict in run_hist.get('grad_norms') or []
    for module in grad_dict
})
# all_modules = ['invariant', 'specific', 'head']

# Runs that actually have gradient norms, in plotting order. Each one contributes exactly
# len(all_modules) consecutive traces, which the dropdown masks rely on.
plotted_runs = []
for model_name, history in models.items():
    grad_norms = history.get('grad_norms') or []
    if not grad_norms:
        continue
    show_initially = not plotted_runs  # only the first plotted run starts visible
    for module in all_modules:
        # Modules a run never logged are drawn as 0. The value is read with .get() so the stored
        # history dicts are not mutated as a side effect of plotting.
        y_values = [step.get(module, 0) for step in grad_norms]
        fig.add_trace(go.Scatter(
            x=list(range(len(y_values))),
            y=y_values,
            mode='lines+markers',
            name=f'{model_name} - {module}',
            visible=show_initially,
        ))
    plotted_runs.append(model_name)

n_modules = len(all_modules)
buttons = []
for run_idx, model_name in enumerate(plotted_runs):
    visible = [False] * len(fig.data)
    visible[run_idx * n_modules:(run_idx + 1) * n_modules] = [True] * n_modules
    buttons.append(dict(
        label=model_name,
        method='update',
        args=[{'visible': visible}, {'title': f'Gradient Norms by Module - {model_name}'}]
    ))

# Title names the run shown initially, which is the first run that has gradient norms.
initial_title = (f'Gradient Norms by Module - {plotted_runs[0]}'
                 if plotted_runs else 'Gradient Norms by Module')
fig.update_layout(
    title=initial_title,
    xaxis_title='Epoch',
    yaxis_title='Gradient Norm',
    template='plotly_white',
    updatemenus=[{
        'buttons': buttons,
        'direction': 'down',
        'showactive': True,
        'x': 1,
        'xanchor': 'left',
        'y': 1.2,
        'yanchor': 'top'
    }],
    yaxis=dict(range=[-0.0, 0.2]),
)
# NOTE(review): domain labels every 10 epochs assume 10 epochs per domain.
fig.update_xaxes(
    tickvals=[0, 10, 20, 30, 40, 50],
    ticktext=list(df['domain'].unique())
)

fig.show()


### t-SNE projection

In [None]:
import glob
import re
import torch
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from ipywidgets import interact

# 1. All checkpoint files of run dualbranchmodel_20250609_013632.
checkpoint_files = glob.glob("../checkpoints/dualbranchmodel_20250609_013632_*.pt")

# 2. Pair each file with the step parsed from its name ("..._step<N>.pt"); files without a step
#    are dropped. Result is ordered by (step, path).
pattern = re.compile(r"_step(\d+)\.pt")
step_matches = ((pattern.search(path), path) for path in checkpoint_files)
files_with_steps = sorted((int(m.group(1)), path) for m, path in step_matches if m)

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
# 3. Precompute 2-D t-SNE embeddings of the invariant and specific features stored in each
#    checkpoint (one entry per checkpoint, in step order).
# NOTE: every checkpoint and branch is embedded independently, so t-SNE axes are not comparable
# between frames; only the cluster structure within a frame is meaningful.
def _embed_2d(features):
    """Fit a fresh 2-D t-SNE (random_state=42) on ``features`` and return the embedding."""
    return TSNE(n_components=2, random_state=42).fit_transform(features)


tsne_projections = []
for idx, (step, ckpt_file) in enumerate(tqdm(files_with_steps, desc="Processing checkpoints")):
    payload = torch.load(ckpt_file, map_location='cpu')['tsne']
    tsne_projections.append({
        'timeline_idx': idx,  # position in step order
        'inv_2d': _embed_2d(np.array(payload['inv_feats'])),
        'spec_2d': _embed_2d(np.array(payload['spec_feats'])),
        'domains': np.array(payload['domain_labels']),
        'filename': ckpt_file,
    })

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Index of the checkpoint currently shown.
current_idx = 0

# Output widget that receives every redraw.
out = widgets.Output()

# Navigation buttons.
button_prev = widgets.Button(description="Previous")
button_next = widgets.Button(description="Next")

# Shared axis limits over every checkpoint and both branches, so all frames use one scale.
_all_points = np.concatenate(
    [d['inv_2d'] for d in tsne_projections] + [d['spec_2d'] for d in tsne_projections]
)
all_x, all_y = _all_points[:, 0], _all_points[:, 1]
x_min, x_max = all_x.min(), all_x.max()
y_min, y_max = all_y.min(), all_y.max()


def plot_epoch(timeline_idx):
    """Show side-by-side invariant / specific t-SNE scatters for checkpoint ``timeline_idx``.

    Points are coloured by domain in the order of the notebook-level ``domains``; every label
    in the checkpoint must appear there (otherwise ``KeyError``).
    """
    data = tsne_projections[timeline_idx]
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6), constrained_layout=True)
    n_domains = len(domains)
    label_index = {name: i for i, name in enumerate(domains)}
    colours = np.array([label_index[name] for name in data['domains']])

    def _scatter(ax, points, title):
        # Fixed vmin/vmax keep each domain's colour stable across frames.
        sc = ax.scatter(points[:, 0], points[:, 1],
                        c=colours, cmap='tab10', alpha=0.7, vmin=0, vmax=n_domains - 1)
        ax.set_title(title)
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        return sc

    inv_scatter = _scatter(ax1, data['inv_2d'], f"Invariant Features - Timeline {timeline_idx}")
    _scatter(ax2, data['spec_2d'], f"Specific Features - Timeline {timeline_idx}")

    cbar = fig.colorbar(inv_scatter, ax=[ax1, ax2], label='Domain',
                        ticks=np.arange(n_domains), boundaries=np.arange(n_domains + 1) - 0.5)
    cbar.set_ticks(np.arange(n_domains))
    cbar.set_ticklabels(domains)
    plt.show()


def _redraw(idx):
    """Replace the output widget's content with the plot for checkpoint ``idx``."""
    with out:
        clear_output(wait=True)
        plot_epoch(idx)


def on_prev_clicked(b):
    global current_idx
    if current_idx > 0:
        current_idx -= 1
        _redraw(current_idx)


def on_next_clicked(b):
    global current_idx
    if current_idx < len(tsne_projections) - 1:
        current_idx += 1
        _redraw(current_idx)


button_prev.on_click(on_prev_clicked)
button_next.on_click(on_next_clicked)

# Show the controls and the (still empty) output area.
display(widgets.HBox([button_prev, button_next]))
display(out)

# Initial frame.
with out:
    plot_epoch(current_idx)

### Single-batch overfit check

In [None]:
# Sanity check on raw model outputs, which should lie in [1, 5].
outputs = model(inputs)
lo, hi = outputs.min().item(), outputs.max().item()
print(f"Output range: {lo}–{hi}")


In [None]:
# Single-batch overfit check for the baseline: the MSE on one fixed batch should approach zero.
test_model = LGRBaseline().to(device)
test_optimizer = optim.Adam(test_model.parameters(), lr=1e-3)
buffer = NaiveRehearsalBuffer(0)  # NOTE(review): not used in this cell
criterion = nn.MSELoss()

first_domain = domains[0]
train_loader = domain_dataloaders[first_domain]['train']
single_batch = next(iter(train_loader))
inputs, labels, _ = single_batch
inputs, labels = (t.to(device, dtype=torch.float32) for t in (inputs, labels))

# %%
num_test_epochs = 400
for epoch in range(num_test_epochs):
    test_optimizer.zero_grad()
    outputs = test_model(inputs)
    loss = criterion(outputs['output'], labels)
    loss.backward()
    test_optimizer.step()
    # Log the first step and then every 10th.
    if epoch == 0 or (epoch + 1) % 10 == 0:
        print(f"Overfit Epoch {epoch+1}/{num_test_epochs} | Loss: {loss.item():.4f}")

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Build a DualBranchNet and fetch one float32 batch for the export / graph cells below.
# device = torch.device('cpu')  # uncomment to force CPU
domains = df['domain'].unique()

# The writer is a context manager, so its event file is closed even if building the model or
# fetching the batch raises.
with SummaryWriter("visualisation/") as writer:
    # NOTE(review): DualBranchNet is not imported in this notebook - TODO import it first.
    # num_domains is passed like the other DualBranchNet constructions in this notebook.
    model = DualBranchNet(num_domains=len(domains)).to(device)
    first_domain = domains[0]
    train_loader = domain_dataloaders[first_domain]['train']
    single_batch = next(iter(train_loader))
    inputs, labels, _ = single_batch
    inputs = inputs.to(device, dtype=torch.float32)
    labels = labels.to(device, dtype=torch.float32)
    # writer.add_graph(model, inputs)

In [None]:
# Export `model` to ONNX, tracing it with the `inputs` batch from an earlier cell.
input_names, output_names = ["image"], ["appropriateness scores"]
torch.onnx.export(
    model,
    inputs,
    "model.onnx",
    input_names=input_names,
    output_names=output_names,
)

In [None]:
from torchviz import make_dot

# Render the autograd graph of the model's mean output; named parameters label the graph nodes.
y = model(inputs)
param_map = dict(model.named_parameters())
make_dot(y.mean(), params=param_map)

In [None]:
import torch
from torch.utils.data import DataLoader

# 1. One fixed batch from the first domain's training loader; `overfit_model` reads it by default.
domain = domains[0]
train_iter = iter(domain_dataloaders[domain]['train'])
single_batch = next(train_iter)

buffer = NaiveRehearsalBuffer(0)  # NOTE(review): constructed with 0 - presumably an empty buffer; confirm

# 4. Overfit loop shared by both models.
def overfit_model(
    model, optimizer, batch_fn, batch_kwargs, device, num_epochs=100, exp_name="overfit",
    batch=None, log_every=1,
):
    """Repeatedly optimise ``model`` on one fixed batch and return the per-step losses.

    Args:
        model: network to train (switched to train mode).
        optimizer: optimizer over ``model``'s parameters.
        batch_fn: callable ``(model, batch, device, **batch_kwargs) -> (loss, metrics)``.
        batch_kwargs: extra keyword arguments forwarded to ``batch_fn``.
        device: device forwarded to ``batch_fn``.
        num_epochs: number of optimisation steps on the batch.
        exp_name: run label; currently unused, kept for call compatibility.
        batch: batch to overfit. Defaults to the notebook-level ``single_batch`` so existing
            calls keep working, but passing it explicitly avoids hidden global state.
        log_every: print the loss every ``log_every`` steps (0 disables printing).

    Returns:
        list[float]: loss value after each step.
    """
    if batch is None:
        batch = single_batch  # notebook-level batch from the setup cell
    model.train()
    losses = []
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        loss, _metrics = batch_fn(model, batch, device, **batch_kwargs)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if log_every and (epoch + 1) % log_every == 0:
            print(f"Epoch {epoch+1}, Loss: {loss.item():.6f}")
    return losses



In [None]:
# 5. Baseline model overfit.
baseline_model = LGRBaseline().to(device)
optimizer = torch.optim.Adam(baseline_model.parameters(), lr=1e-3)
baseline_losses = overfit_model(
    baseline_model, optimizer, baseline_batch, dict(mse_criterion=torch.nn.MSELoss()), device
)


# 6. DualBranch model overfit.
def _squared_cosine(a, b):
    """Mean squared cosine similarity of ``a`` and ``b`` (computed along dim 1, the PyTorch default)."""
    return (torch.nn.CosineSimilarity()(a, b) ** 2).mean()


dual_model = DualBranchNet(num_domains=len(domains)).to(device)
optimizer = torch.optim.Adam(dual_model.parameters(), lr=1e-3)
dualbranch_kwargs = dict(
    mse_criterion=torch.nn.MSELoss(),
    ce_criterion=torch.nn.CrossEntropyLoss(),
    cos_criterion=_squared_cosine,
    domain_to_idx=domain_to_idx,
    current_domain=domain,
)
dualbranch_losses = overfit_model(
    dual_model, optimizer, dualbranch_batch, dualbranch_kwargs, device
)



In [None]:
# 7. Plot the loss curves (optional)
import matplotlib.pyplot as plt
# Loss per optimisation step for both single-batch overfit runs.
fig, ax = plt.subplots()
ax.plot(baseline_losses, label='Baseline')
ax.plot(dualbranch_losses, label='DualBranch')
ax.set(xlabel='Iteration', ylabel='Loss', title='Overfitting to a Single Batch')
ax.legend()
plt.show()
