In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import random

def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (plus all
    CUDA devices when CUDA is available), switches cuDNN to deterministic
    kernels with auto-tuning disabled, and exports ``PYTHONHASHSEED``.

    Args:
        seed: Integer seed shared by all generators.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.manual_seed_all(seed)
        # Trade speed for run-to-run reproducibility of cuDNN kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this value
    # is only picked up by processes launched after this point.
    os.environ.update(PYTHONHASHSEED=str(seed))


set_seed(42)

In [None]:
def evaluate_model(model, dataloader, criterion, device):
    """Return the sample-weighted mean loss of ``model`` over ``dataloader``.

    The model is switched to eval mode (and left in eval mode) and gradients
    are disabled for the whole pass.

    Args:
        model: Module whose forward pass returns a mapping with an
            ``'output'`` entry.
        dataloader: Iterable yielding ``(inputs, labels, _)`` triples; the
            third element is ignored.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar
            tensor. Each batch loss is weighted by the batch size, which
            assumes a mean-reduced loss -- TODO confirm for custom criteria.
        device: Device the batches are moved to (as float32) before the
            forward pass.

    Returns:
        float: Accumulated loss divided by the number of evaluated samples.

    Raises:
        ValueError: If ``dataloader`` yields no samples (previously surfaced
            as an uninformative ``ZeroDivisionError``).
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for inputs, labels, _ in dataloader:
            inputs = inputs.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.float32)
            batch_size = inputs.size(0)
            outputs = model(inputs)['output']
            total_loss += criterion(outputs, labels).item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        raise ValueError("evaluate_model received a dataloader with no samples")
    return total_loss / total_samples

In [None]:
# NOTE(review): pickle files can execute arbitrary code when loaded; only read
# a file produced by this project's own preprocessing.
df = pd.read_pickle("processed_all_data.pkl")

# One set of loaders per domain, keyed by the domain value, in order of first
# appearance in the frame.
domains = df['domain'].unique()
domain_dataloaders = {}
for domain in domains:
    domain_df = df.loc[df['domain'] == domain]
    # TODO: resize should be (384, 216) to retain scale or (224, 224) for best
    # performance on MobileNet.
    loaders = create_dataloaders(
        domain_df,
        resize_img_to=(128, 128),
        batch_sizes=(32, 64, 64),
    )
    domain_dataloaders[domain] = loaders

In [None]:
# # Reservoir buffer per batch
# for batch_idx, (inputs, labels, _) in enumerate(train_loader):
#             batch_start = time.time()    

#             inputs = inputs.to(device, dtype=torch.float32)
#             labels = labels.to(device, dtype=torch.float32)

#             # ReservoirBuffer # Sample the buffer and add replay samples to training
#             # if buffer and random.random() < buffer.replay_ratio:
#             #     batch_size = inputs.size(0)
#             #     replay_batch = buffer.sample(int(batch_size * 0.25))
#             #     if replay_batch:
#             #         replay_inputs, replay_labels, _ = zip(*replay_batch)
#             #         replay_inputs = torch.stack(replay_inputs)
#             #         replay_labels = torch.stack(replay_labels)
#             #         inputs = torch.cat([inputs, replay_inputs])
#             #         labels = torch.cat([labels, replay_labels])
            
#             optimizer.zero_grad()
            
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
            
#             loss.backward()
#             optimizer.step()
            
#             # ReservoirBuffer # Add current batch to buffer
#             # with torch.no_grad():
#             #     samples = [(img.detach(), label.detach(), domain) for img, label in zip(inputs, labels)]
#             #     buffer.add(samples)

            
#             batch_time = time.time() - batch_start
                
#             running_loss += loss.item() * inputs.size(0)
#             total_train_samples += inputs.size(0)

#             if batch_idx % 10 == 9:               
#                 avg_loss = running_loss / 10
#                 batch_time = time.time() - batch_start  # ← Keep this
#                 print(
#                     f"Domain: {domain} | Epoch: {epoch+1} | "
#                     f"Batch: {batch_idx+1}/{total_batches} | "
#                     f"Avg Loss: {avg_loss:.4f} | "  # Changed to Avg Loss
#                     f"Time/batch: {batch_time:.2f}s"
#                 )
#                 running_loss = 0.0

In [None]:
import time
from torch.utils.tensorboard import SummaryWriter

def train_model(model, domains, domain_dataloaders, buffer, optimizer, writer, device, criterion,
                num_epochs=5, run_name=None, checkpoint_dir="model_checkpoints"):
    """Train ``model`` sequentially over ``domains`` with rehearsal and TensorBoard logging.

    For every domain the train loader is wrapped by
    ``buffer.get_loader_with_replay``; after each epoch the model is validated
    on that domain and a checkpoint is written; after the last epoch the buffer
    is updated from the domain's training dataset and the model is evaluated on
    the validation split of every domain.

    Args:
        model: Module whose forward pass returns a mapping with an ``'output'`` entry.
        domains: Iterable of domain keys, trained in iteration order.
        domain_dataloaders: Mapping ``domain -> {'train': loader, 'val': loader}``;
            loaders yield ``(inputs, labels, _)`` triples.
        buffer: Object providing ``get_loader_with_replay(domain, loader)`` and
            ``update_buffer(domain, dataset)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        writer: TensorBoard-style writer (``add_scalar`` / ``close``). It is
            closed before returning.
        device: Device the batches are moved to (as float32, matching
            ``evaluate_model``).
        criterion: Loss callable ``criterion(outputs, labels)``.
        num_epochs: Epochs per domain.
        run_name: Prefix for checkpoint file names. Defaults to the
            notebook-level ``exp_name`` (the previous implicit behaviour).
        checkpoint_dir: Directory for checkpoints; created if missing.

    Returns:
        The trained ``model`` (same object that was passed in).

    Raises:
        ValueError: If a domain's training loader yields no samples.
    """
    if run_name is None:
        # Backward-compatible default: fall back to the notebook-level name.
        run_name = exp_name
    # torch.save does not create parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)
    global_step = 0

    for current_domain in domains:
        domain_start_time = time.time()
        loaders = domain_dataloaders[current_domain]
        train_loader = buffer.get_loader_with_replay(current_domain, loaders['train'])

        for epoch in range(num_epochs):
            # evaluate_model leaves the model in eval mode, so re-enable
            # training mode at the start of every epoch.
            model.train()
            epoch_loss = 0.0
            samples_processed = 0

            for batch_idx, (inputs, labels, _) in enumerate(train_loader):
                inputs = inputs.to(device, dtype=torch.float32)
                labels = labels.to(device, dtype=torch.float32)

                optimizer.zero_grad()
                outputs = model(inputs)['output']
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                batch_size = inputs.size(0)
                epoch_loss += batch_loss * batch_size
                samples_processed += batch_size
                global_step += 1

                # Batch logging (every 50 batches)
                if batch_idx % 50 == 0:
                    writer.add_scalar('Loss/train_batch', batch_loss, global_step)

            if samples_processed == 0:
                raise ValueError(f"Training loader for domain {current_domain!r} yielded no samples")

            # Epoch summary
            avg_epoch_loss = epoch_loss / samples_processed
            writer.add_scalar('Loss/train_epoch', avg_epoch_loss, global_step)

            # Validation on the domain currently being trained
            val_loss = evaluate_model(model, loaders['val'], criterion, device)
            writer.add_scalar('Loss/val_domain', val_loss, global_step)

            print(f"[{current_domain}][Epoch {epoch+1}] Train: {avg_epoch_loss:.4f} | Val: {val_loss:.4f}")

            checkpoint_name = f"{run_name}_domain{current_domain}_epoch{epoch}_step{global_step}.pth"
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, checkpoint_name))

        buffer.update_buffer(current_domain, loaders['train'].dataset)

        # Cross-domain evaluation over every domain's validation split
        for eval_domain in domains:
            eval_loss = evaluate_model(model, domain_dataloaders[eval_domain]['val'], criterion, device)
            writer.add_scalar(f'CrossVal/{eval_domain}', eval_loss, global_step)

        print(f"Domain {current_domain} completed in {time.time()-domain_start_time:.1f}s")

    writer.close()
    return model

In [None]:
def train_domain(domain_idx, domain, num_epochs=5):
    """Train the notebook-level ``model`` on one domain, mixing in replayed samples.

    Relies on notebook globals: ``model``, ``optimizer``, ``buffer``, ``writer``,
    ``device``, ``domain_dataloaders``, ``domain_to_idx``, ``mse_criterion``,
    ``ce_criterion``, ``cos_criterion``, ``exp_name`` and the ``global_step``
    counter (incremented once per batch).

    Each batch is split by domain label:
      * samples of ``domain`` go through the full model and back-propagate the
        task loss, both domain-classification losses and the similarity loss;
      * all other (replay) samples run the backbone, pooling and invariant
        branch under ``torch.no_grad()``, so only the specific residual, the
        specific domain classifier and the head receive gradients from them.
    Gradients from both parts accumulate before a single ``optimizer.step()``.

    After every epoch the model is validated on the domain's ``'val'`` loader
    and a checkpoint is written to ``model_checkpoints/``. After the last epoch
    the rehearsal buffer is updated from the domain's training dataset.

    Args:
        domain_idx: Position of ``domain`` in the training order (currently unused).
        domain: Key into ``domain_dataloaders`` and ``domain_to_idx``.
        num_epochs: Number of passes over the replay-augmented training loader.
    """
    global global_step

    current_domain = domain
    current_domain_idx = domain_to_idx[domain]
    loaders = domain_dataloaders[domain]
    train_loader = buffer.get_loader_with_replay(current_domain, loaders['train'])
    # torch.save does not create parent directories.
    os.makedirs("model_checkpoints", exist_ok=True)

    for epoch in range(num_epochs):
        model.train()

        for batch_idx, (inputs, labels, domain_labels) in enumerate(train_loader):
            optimizer.zero_grad()
            # Same float32 cast that evaluate_model applies at validation time.
            inputs = inputs.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.float32)
            domain_labels = torch.tensor([domain_to_idx[d] for d in domain_labels], device=device)

            # Split batch into current and replay samples
            current_mask = (domain_labels == current_domain_idx)
            replay_mask = ~current_mask

            # Defaults for batches without current-domain samples, so the
            # logging below never hits an undefined or stale value.
            task_loss = torch.tensor(0.0)
            inv_domain_loss = torch.tensor(0.0)
            spec_domain_loss = torch.tensor(0.0)
            similarity_loss = torch.tensor(0.0)
            inv_acc = 0.0
            spec_acc = 0.0

            # 1. Process CURRENT SAMPLES and update all parameters
            if current_mask.any():
                inputs_current = inputs[current_mask]
                labels_current = labels[current_mask]
                domain_labels_current = domain_labels[current_mask]

                outputs_current = model(inputs_current)
                inv_feats = outputs_current['invariant_feats']
                spec_feats = outputs_current['specific_feats']

                # Losses
                task_loss = mse_criterion(outputs_current['output'], labels_current)
                inv_domain_loss = ce_criterion(outputs_current['invariant_domain'], domain_labels_current)
                spec_domain_loss = ce_criterion(outputs_current['specific_domain'], domain_labels_current)
                similarity_loss = cos_criterion(inv_feats, spec_feats)

                total_loss = (task_loss +
                              0.5 * inv_domain_loss +
                              0.2 * spec_domain_loss +
                              0.1 * similarity_loss
                )
                total_loss.backward()

                # Metrics
                inv_pred = outputs_current['invariant_domain'].argmax(1)
                spec_pred = outputs_current['specific_domain'].argmax(1)
                inv_acc = (inv_pred == domain_labels_current).float().mean().item()
                spec_acc = (spec_pred == domain_labels_current).float().mean().item()

            # 2. Process REPLAY SAMPLES and update only the specific branch + head
            if replay_mask.any():
                inputs_replay = inputs[replay_mask]
                labels_replay = labels[replay_mask]
                domain_labels_replay = domain_labels[replay_mask]

                # No gradients for backbone and invariant branch
                with torch.no_grad():
                    base_replay = model.backbone(inputs_replay)
                    base_replay = model.pool(base_replay).flatten(1)
                    inv_feats_replay = model.invariant(base_replay)

                # Normal gradient for the rest
                residual = model.specific_residual(inv_feats_replay)
                spec_feats_replay = inv_feats_replay + residual
                spec_domain_pred = model.specific_domain_classifier(spec_feats_replay)

                combined = torch.cat([inv_feats_replay, spec_feats_replay], dim=1)
                scores = model.head(combined)

                # Losses
                task_loss_replay = mse_criterion(scores, labels_replay)
                spec_domain_loss_replay = ce_criterion(spec_domain_pred, domain_labels_replay)
                total_loss_replay = task_loss_replay + 0.2 * spec_domain_loss_replay

                # Accumulates onto the gradients from the current-sample pass.
                total_loss_replay.backward()

            optimizer.step()

            replay_count = replay_mask.sum().item()
            current_count = current_mask.sum().item()

            writer.add_scalar('Loss/train', task_loss.item(), global_step)
            writer.add_scalar('Loss/inv_domain', inv_domain_loss.item(), global_step)
            writer.add_scalar('Loss/spec_domain', spec_domain_loss.item(), global_step)
            writer.add_scalar('Loss/similarity', similarity_loss.item(), global_step)
            writer.add_scalar('Accuracy/invariant', inv_acc, global_step)
            writer.add_scalar('Accuracy/specific', spec_acc, global_step)
            writer.add_scalar('Replay/replay_count', replay_count, global_step)
            writer.add_scalar('Replay/current_count', current_count, global_step)

            if batch_idx % 10 == 0:
                print(f"Epoch {epoch} Batch {batch_idx} | "
                    f"Task: {task_loss.item():.4f} | "
                    f"InvAcc: {inv_acc:.2%} | SpecAcc: {spec_acc:.2%} | "
                    f"Sim: {similarity_loss.item():.4f} | "
                    f"Replay: {replay_count} | "
                    f"Current: {current_count}")

            global_step += 1

        # Validation
        val_loss = evaluate_model(model, loaders['val'], mse_criterion, device)

        # Logged against global_step (not the per-domain epoch index) so that
        # curves of successive domains do not overwrite each other.
        writer.add_scalar('Loss/val', val_loss, global_step)
        print(f"Domain {current_domain} | Epoch {epoch+1} | Val Loss: {val_loss:.4f}")
        torch.save(model.state_dict(), f"model_checkpoints/{exp_name}_domain{current_domain}_epoch{epoch}_step{global_step}.pth")

    buffer.update_buffer(domain, loaders['train'].dataset)

In [None]:
# Containers for the loss histories read by the evaluation cells.
# NOTE(review): no cell visible in this notebook appends to these lists --
# confirm where they are meant to be filled.
metrics = dict(
    train_loss=[],
    val_loss=[],
    test_loss=[],
    domain_performance={name: [] for name in domains},
)

In [None]:
# Shared runtime objects for the dual-branch experiment.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DualBranchNet(num_domains=len(domains)).to(device)
buffer = NaiveRehearsalBuffer(buffer_size=1000)
# Fully-qualified torch.optim / torch.nn: no cell visible in this notebook
# imports the `optim` / `nn` aliases, so the bare names fail on a fresh kernel.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

mse_criterion = torch.nn.MSELoss()          # task (regression) loss
ce_criterion = torch.nn.CrossEntropyLoss()  # domain-classification loss
def cos_criterion(a, b):
    """Return the mean squared cosine similarity between rows of ``a`` and ``b``.

    Cosine similarity is taken along dimension 1, squared (so similar and
    opposite vectors are penalised alike) and averaged over the remaining
    dimensions. The value is 0 when every pair is orthogonal and 1 when every
    pair is parallel or anti-parallel.

    Args:
        a: Tensor with the feature dimension at index 1.
        b: Tensor broadcastable against ``a``.

    Returns:
        Scalar tensor holding the mean squared cosine similarity.
    """
    # Functional form avoids building a new module on every call and does not
    # depend on an `nn` alias being imported.
    similarity = torch.nn.functional.cosine_similarity(a, b, dim=1)
    return (similarity ** 2).mean()

# Integer class index per domain, following the order of `domains`.
domain_to_idx = dict(zip(domains, range(len(domains))))

In [None]:
import datetime
from torch.utils.tensorboard import SummaryWriter

# Dual-branch experiment: train domain by domain, then re-validate on every
# domain seen so far.
exp_name = f"dualbranchmodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"tensorboard/{exp_name}")
# train_domain checkpoints into this directory; torch.save does not create it.
os.makedirs("model_checkpoints", exist_ok=True)

global_step = 0

try:
    for domain_idx, domain in enumerate(domains):
        train_domain(domain_idx, domain, num_epochs=10)

        for eval_domain in domains[:domain_idx+1]:
            loader = domain_dataloaders[eval_domain]['val']
            loss = evaluate_model(model, loader, mse_criterion, device)
            # Also log to TensorBoard, using the same tag as train_model.
            writer.add_scalar(f'CrossVal/{eval_domain}', loss, global_step)
            print(f"Domain {eval_domain} | Val Loss: {loss:.4f}")
finally:
    # Flush and release the event file even when training is interrupted.
    writer.close()



# Baseline experiment.
# NOTE(review): this rebinds the notebook-level `model`, `buffer` and
# `optimizer`, so later cells see the baseline objects -- confirm intended.
exp_name = f"baselinemodel_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
writer = SummaryWriter(log_dir=f"tensorboard/{exp_name}")
# train_model checkpoints into this directory; torch.save does not create it.
os.makedirs("model_checkpoints", exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LGRBaseline().to(device)
buffer = NaiveRehearsalBuffer(buffer_size=1000)
# Fully-qualified torch.optim / torch.nn: no cell visible in this notebook
# imports the `optim` / `nn` aliases.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

try:
    train_model(model, domains, domain_dataloaders, buffer, optimizer, writer, device, criterion, num_epochs=10)
finally:
    # train_model closes the writer on success; this also covers failures.
    writer.close()

## Evaluation

In [None]:
import matplotlib.pyplot as plt

# Training and validation loss curves from `metrics`. Both series are drawn,
# so each gets a legend entry and the axis/title are not validation-specific.
plt.plot(metrics['val_loss'], label='Validation')
plt.plot(metrics['train_loss'], label='Train')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.show()

In [None]:

# One curve per domain from metrics['domain_performance'].
fig, ax = plt.subplots()
for domain, losses in metrics['domain_performance'].items():
    ax.plot(losses, label=domain)
ax.set_xlabel('Domain Training Step')
ax.set_ylabel('Loss')
ax.set_title('Domain Performance Over Training')
ax.legend()
plt.show()

In [None]:
# Per-domain summary: first recorded loss, last recorded loss and their difference.
for domain, losses in metrics['domain_performance'].items():
    if not losses:
        # An empty history would otherwise raise IndexError on losses[0].
        print(f"{domain}: no recorded losses")
        continue
    initial, final = losses[0], losses[-1]
    print(f"{domain}: Initial = {initial:.4f}, Final = {final:.4f}, Change = {final - initial:.4f}")

In [None]:
# Output range check on one validation batch of the first domain.
# The batch is fetched here because no earlier cell defines a notebook-level
# `inputs`, and the prediction is read from the 'output' entry because the
# model's forward pass returns a mapping rather than a tensor.
model.eval()
with torch.no_grad():
    sample_inputs, _, _ = next(iter(domain_dataloaders[domains[0]]['val']))
    outputs = model(sample_inputs.to(device, dtype=torch.float32))['output']  # Should be in [1,5]
print(f"Output range: {outputs.min().item()}–{outputs.max().item()}")

## Overfitting Sanity Check

In [None]:
#Single Batch sanity check for dual branch model
# Repeatedly optimises the combined dual-branch loss on one fixed batch; the
# loss should drop towards zero if the training path can fit the data at all.
# NOTE(review): `domain`, `model` and `optimizer` are whatever the notebook
# namespace holds when this cell runs. The training cell rebinds `model` and
# `optimizer` to the baseline after the dual-branch run -- confirm the intended
# objects are in scope before trusting the result.
# NOTE(review): evaluate_model leaves the model in eval mode and this cell does
# not call model.train() -- confirm that is intended.

# Get first batch
single_batch = next(iter(domain_dataloaders[domain]['train']))
inputs, labels, domain_labels = single_batch
inputs, labels = inputs.to(device), labels.to(device)
# Map the batch's domain identifiers to integer class indices for the CE losses.
domain_labels = torch.tensor([domain_to_idx[d] for d in domain_labels], device=device)

# Overfit test
for epoch in range(100):
    optimizer.zero_grad()
    
    # Forward
    outputs = model(inputs)
    inv_feats = outputs['invariant_feats']
    spec_feats = outputs['specific_feats']
    
    # Losses (same terms and weights as the current-sample branch of train_domain)
    task_loss = mse_criterion(outputs['output'], labels)
    inv_domain_loss = ce_criterion(outputs['invariant_domain'], domain_labels)
    spec_domain_loss = ce_criterion(outputs['specific_domain'], domain_labels)
    similarity_loss = cos_criterion(inv_feats, spec_feats)
    total_loss = task_loss + 0.5*inv_domain_loss + 0.2*spec_domain_loss + 0.1*similarity_loss
    
    # Backward
    total_loss.backward()
    optimizer.step()
    
    print(f"Epoch {epoch}: Loss {total_loss.item():.4f}")
    
    # Early exit if loss < 0.001
    if total_loss < 0.001:
        break


In [None]:
# Overfit Single Batch Check
# Trains a fresh dual-branch model on one fixed batch using only the task
# (MSE) loss; the printed loss should shrink steadily.

test_model = DualBranchNet(num_domains=len(domains)).to(device)
test_optimizer = torch.optim.Adam(test_model.parameters(), lr=1e-3)
# Local criterion so this cell does not depend on the baseline cell's `criterion`.
overfit_criterion = torch.nn.MSELoss()
# NOTE(review): rebinds the notebook-level `buffer` to a zero-size buffer that
# this cell never uses -- confirm it is still needed.
buffer = NaiveRehearsalBuffer(0)

first_domain = domains[0]
train_loader = domain_dataloaders[first_domain]['train']
single_batch = next(iter(train_loader))
# Loaders yield (inputs, labels, domain) triples; the domain entry is unused here.
inputs, labels, _ = single_batch
inputs = inputs.to(device, dtype=torch.float32)
labels = labels.to(device, dtype=torch.float32)

# %%
num_test_epochs = 500
for epoch in range(num_test_epochs):
    test_optimizer.zero_grad()

    outputs = test_model(inputs)
    loss = overfit_criterion(outputs['output'], labels)

    loss.backward()
    test_optimizer.step()

    if (epoch+1) % 10 == 0 or epoch == 0:
        print(f"Overfit Epoch {epoch+1}/{num_test_epochs} | Loss: {loss.item():.4f}")


