## Data preparation

In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
import random
# Seed every random number generator used in this notebook (Python, NumPy,
# PyTorch CPU and, when present, all CUDA devices) with the same fixed value.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Root folder that holds one sub-directory per dataset; resolved to an
# absolute path relative to the current working directory.
DATASET_DIR = (Path("..") / "datasets").resolve()
# Names of the dataset sub-directories that get consolidated.
DATASETS = ["OFFICE-MANNERSDB", "MANNERSDBPlus"]
# Annotation columns used as regression targets. These strings are used
# verbatim as DataFrame column keys, so they must match the CSV headers
# exactly (presumably including the "Vaccum" spelling - verify against the
# annotation files before changing it).
LABEL_COLS = [
    "Vaccum Cleaning", "Mopping the Floor", "Carry Warm Food",
    "Carry Cold Food", "Carry Drinks", "Carry Small Objects",
    "Carry Large Objects", "Cleaning", "Starting a conversation"
]

In [None]:
def process_csv(csv_path, dataset):
    """Load one annotation CSV and split its identifier column into metadata.

    The last column of the file is discarded. The first column is expected to
    hold identifiers of the form ``<robot>_<domain>_<image_ref>``; it is split
    on the first two underscores into the new columns ``robot``, ``domain``
    and ``image_ref`` (cast to int) and then removed. A constant ``dataset``
    column carrying the ``dataset`` argument is added as well.

    Args:
        csv_path: Anything ``pandas.read_csv`` accepts.
        dataset: Value stored in the ``dataset`` column of every row.

    Returns:
        The processed DataFrame.
    """
    frame = pd.read_csv(csv_path)
    trailing_col = frame.columns[-1]
    frame = frame.drop(columns=trailing_col)

    # Identifier column -> three metadata parts (at most two splits).
    id_col = frame.columns[0]
    parts = frame[id_col].str.split('_', n=2, expand=True)

    enriched = frame.assign(
        robot=parts[0],
        domain=parts[1],
        image_ref=parts[2].astype(int),
        dataset=dataset,
    )
    return enriched.drop(columns=[id_col])

def consolidate_data(datasets):
    """Load and concatenate every annotation CSV of the given datasets.

    For each dataset name, the ``Annotations`` folder of each robot
    (``NAO``, ``Pepper``, ``PR2``) below ``DATASET_DIR / dataset`` is scanned
    for ``*.csv`` files, and each file is parsed with ``process_csv``.

    Args:
        datasets: Iterable of dataset directory names.

    Returns:
        One DataFrame with all parsed rows and a fresh integer index.

    Raises:
        ValueError: If an ``Annotations`` folder does not exist, or if not a
            single CSV file could be loaded.
    """
    all_dfs = []
    for dataset in datasets:
        source_path = DATASET_DIR / dataset

        for robot in ["NAO", "Pepper", "PR2"]:
            ann_dir = source_path / robot / "Annotations"
            if not ann_dir.exists():
                raise ValueError(f"Labels csv file path ({ann_dir}) doesn't exist")

            # glob() yields files in a filesystem-dependent order; sorting
            # keeps the row order (and anything derived from it, such as
            # "first"-style aggregations) reproducible across machines.
            for csv_file in sorted(ann_dir.glob("*.csv")):
                try:
                    all_dfs.append(process_csv(csv_file, dataset))
                except Exception as e:
                    # Best effort: report the broken file and keep going.
                    print(f"Error processing {csv_file}: {str(e)}")

    if not all_dfs:
        # pd.concat([]) would raise a ValueError with an opaque message;
        # keep the exception type but say what actually went wrong.
        raise ValueError(
            f"No annotation CSV files could be loaded for datasets: {list(datasets)}"
        )

    return pd.concat(all_dfs, ignore_index=True)

In [None]:
def validate_raw_data(df):
    """Run sanity checks on the consolidated, un-aggregated annotations.

    Checks, in this order: presence of the metadata columns, label values
    within 1..5 for every column in ``LABEL_COLS``, absence of nulls,
    ``image_ref`` being a non-negative integer column, and ``robot`` /
    ``dataset`` containing only known values.

    Returns:
        True when every check passes.

    Raises:
        ValueError: For missing columns, out-of-range labels, nulls, negative
            ``image_ref`` values or unknown robot/dataset values.
        TypeError: If ``image_ref`` is not of integer dtype.
    """
    # Metadata columns must all be present.
    expected_cols = {'robot', 'domain', 'image_ref', 'dataset'}
    missing_cols = expected_cols.difference(df.columns)
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    # Every rating has to lie on the 1-5 scale.
    for col in LABEL_COLS:
        too_low = df[col].min() < 1
        if too_low or df[col].max() > 5:
            raise ValueError(f"Label {col} has invalid range [{df[col].min()}, {df[col].max()}]")

    # No column may contain nulls.
    has_null = df.isnull().any()
    null_cols = df.columns[has_null].tolist()
    if null_cols:
        raise ValueError(f"Null values found in columns: {null_cols}")

    # image_ref: integer dtype and non-negative.
    image_ref = df['image_ref']
    if not pd.api.types.is_integer_dtype(image_ref):
        raise TypeError("image_ref must be integer type")
    if (image_ref < 0).any():
        raise ValueError("image_ref contains negative values, which is invalid")

    # Only known robots and dataset names are allowed.
    invalid_robots = set(df['robot']).difference({'NAO', 'Pepper', 'PR2'})
    if invalid_robots:
        raise ValueError(f"Invalid robot values: {invalid_robots}")

    invalid_sources = set(df['dataset']).difference({'OFFICE-MANNERSDB', 'MANNERSDBPlus'})
    if invalid_sources:
        raise ValueError(f"Invalid source directories: {invalid_sources}")

    return True


In [None]:
def aggregate_labels(df):
    """Collapse multiple annotation rows per image into a single row.

    Rows are grouped by ``image_path``. Every column in ``LABEL_COLS`` is
    averaged; every other column keeps the first value seen in the group.

    Rows whose ``image_path`` is null are kept as their own group instead of
    being dropped (pandas drops null group keys by default). Without this,
    unresolved images would vanish here and a later "missing image path"
    check on the aggregated frame could never trigger.

    Args:
        df: Annotation rows containing ``image_path`` and all ``LABEL_COLS``.

    Returns:
        DataFrame with one row per distinct ``image_path`` (plus at most one
        row collecting all rows whose path is null).
    """
    agg_dict = {
        **{col: 'mean' for col in LABEL_COLS},
        **{col: 'first' for col in df.columns.difference(LABEL_COLS).tolist()},
    }

    # dropna=False: keep the null-path group so missing images stay visible.
    return df.groupby('image_path', as_index=False, dropna=False).agg(agg_dict)


In [None]:
def resolve_image_path(row):
    """Return the absolute image path for an annotation row, or None.

    Images live in ``DATASET_DIR / <dataset> / <robot> / Images``. For the
    ``OFFICE-MANNERSDB`` dataset the file is ``<domain>_<image_ref>.png``;
    for any other dataset the first file matching ``<image_ref>_*.png`` is
    used. None is returned when no such file exists.
    """
    images_dir = DATASET_DIR / row['dataset'] / row['robot'] / "Images"

    if row['dataset'] != "OFFICE-MANNERSDB":
        # File name starts with the image reference; take the first match.
        candidate = next(images_dir.glob(f"{row['image_ref']}_*.png"), None)
    else:
        candidate = images_dir / f"{row['domain']}_{row['image_ref']}.png"

    if not candidate or not candidate.exists():
        return None
    return candidate.resolve()

In [None]:
class ImageLabelDataset(Dataset):
    """Dataset yielding ``(image, labels)`` pairs from an annotation frame.

    Each row of ``df`` must provide an ``image_path`` and every column in
    ``LABEL_COLS``. Images are loaded as RGB; labels are rescaled from the
    1-5 rating scale to the 0-1 range.

    Args:
        df: Annotation DataFrame; its index is reset so that positions and
            labels coincide.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for position ``idx``.

        Raises:
            RuntimeError: If the image cannot be loaded; the original error is
                attached as ``__cause__``.
        """
        if isinstance(idx, torch.Tensor):
            idx = idx.item()

        img_path = str(self.df.at[idx, "image_path"])
        try:
            # The context manager releases the file handle right away;
            # convert() returns an independent, fully loaded image.
            with Image.open(img_path) as handle:
                image = handle.convert("RGB")
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Error loading {img_path}: {str(e)}") from e

        raw_labels = self.df.iloc[idx][LABEL_COLS].values.astype(np.float32)
        scaled_labels = (raw_labels - 1) / 4  # Convert 1-5 -> 0-1

        if self.transform:
            image = self.transform(image)

        return image, torch.from_numpy(scaled_labels)

In [None]:
def create_dataloaders(df, batch_sizes=(32, 64, 64), resize_img_to=(128, 128)):
    """Build train/val/test DataLoaders from an annotation frame.

    The ``image_path`` values are split 70/15/15 (fixed ``random_state=42``)
    and each subset of ``df`` is selected by path membership. All three
    subsets share the same preprocessing: resize, tensor conversion and
    normalisation with the mean/std constants below.

    Args:
        df: Frame with an ``image_path`` column and the label columns.
        batch_sizes: Batch sizes for (train, val, test), in that order.
        resize_img_to: Target size handed to ``transforms.Resize``.

    Returns:
        Dict with the keys ``'train'``, ``'val'`` and ``'test'`` mapping to
        DataLoaders; only the train loader shuffles.
    """
    # The split operates on the path values (a Series with a fresh index);
    # subset membership is then decided by comparing path values.
    path_series = df[['image_path']].reset_index(drop=True)['image_path']
    train_paths, holdout_paths = train_test_split(
        path_series,
        test_size=0.3,
        random_state=42
    )
    val_paths, test_paths = train_test_split(
        holdout_paths,
        test_size=0.5,
        random_state=42
    )

    # TODO: add coordinate channels for a spatially aware CNN (CoordConv-style)
    preprocess = transforms.Compose([
        transforms.Resize(resize_img_to),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    split_paths = (('train', train_paths), ('val', val_paths), ('test', test_paths))
    worker_count = 0

    loaders = {}
    for position, (split_name, paths) in enumerate(split_paths):
        subset_df = df[df['image_path'].isin(paths)].reset_index(drop=True)
        loaders[split_name] = DataLoader(
            ImageLabelDataset(subset_df, preprocess),
            batch_size=batch_sizes[position],
            shuffle=(split_name == 'train'),
            num_workers=worker_count,
            pin_memory=torch.cuda.is_available(),
        )

    return loaders


In [None]:
def validate_final_data(df):
    """Validate the aggregated (one row per image) annotation frame.

    Checks, in this order: no null ``image_path``, no nulls in any column,
    no duplicated ``image_path`` and all ``LABEL_COLS`` values within 1..5.

    Returns:
        True when every check passes.

    Raises:
        FileNotFoundError: If any row has a null ``image_path``.
        ValueError: If any column contains nulls or a label is out of range.
        RuntimeError: If an ``image_path`` occurs more than once.
    """
    # Rows without a resolved image.
    path_is_null = df['image_path'].isnull()
    if path_is_null.any():
        missing = df[path_is_null]
        raise FileNotFoundError(
            f"{len(missing)} images missing after aggregation. Examples:\n"
            f"{missing[['robot', 'domain', 'image_ref']].head()}"
        )

    # Nulls anywhere else.
    has_null = df.isnull().any()
    null_cols = df.columns[has_null].tolist()
    if null_cols:
        raise ValueError(f"Null values found in columns: {null_cols}")

    # Each image must appear exactly once after aggregation.
    is_duplicate = df.duplicated('image_path', keep=False)
    if is_duplicate.any():
        duplicates = df[is_duplicate]
        raise RuntimeError(
            f"Duplicate image paths after aggregation:\n"
            f"{duplicates['image_path'].unique()}"
        )

    # Averaged ratings still have to lie on the 1-5 scale.
    for col in LABEL_COLS:
        too_low = df[col].min() < 1
        if too_low or df[col].max() > 5:
            raise ValueError(
                f"Aggregated label {col} out of range: "
                f"[{df[col].min()}, {df[col].max()}]"
            )

    return True


## --

In [None]:
# End-to-end preparation: load all annotation CSVs, validate them, attach the
# image path of every row, average the annotations per image and validate the
# aggregated result. Any failure is reported and re-raised.
try:
    raw_df = consolidate_data(DATASETS)
    validate_raw_data(raw_df)
    # One path per annotation row; resolve_image_path returns None when no
    # matching file exists.
    raw_df['image_path'] = raw_df.apply(resolve_image_path, axis=1)   
    aggregated_df = aggregate_labels(raw_df)
    validate_final_data(aggregated_df) 
except Exception as e:
    print(f"Pipeline failed: {str(e)}")
    raise

# Persist the aggregated frame in the working directory; the "Run model"
# section reads this file back.
aggregated_df.to_pickle("processed_all_data.pkl")


## Data analysis

In [None]:
# Reload and validate the raw (un-aggregated) annotations for the analysis
# plots below; this rebinds raw_df without the image_path column.
try:
    raw_df = consolidate_data(DATASETS)
    validate_raw_data(raw_df)
except Exception as e:
    print(f"Pipeline failed: {str(e)}")
    raise

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
sns.set_style("whitegrid")

# Number of annotation rows per (robot, domain, image_ref) combination ...
entries_per_image = raw_df.groupby(['robot', 'domain', 'image_ref']).size().reset_index(name='num_entries')
# ... and, per domain, how many images have each annotation count.
domain_counts = entries_per_image.groupby(['domain', 'num_entries']).size().reset_index(name='count')

plt.figure(figsize=(12, 5))
ax = sns.barplot(
    data=domain_counts,
    x='num_entries',
    y='count',
    hue='domain',
    # Fixed domain order so colours are assigned consistently.
    hue_order = ['Home', 'BigOffice-2', 'BigOffice-3', 'Hallway', 'MeetingRoom',
       'SmallOffice'],
    palette='viridis'
)
# Annotate every non-empty bar with its height.
for p in ax.patches:
    height = p.get_height()
    if height > 0:
        ax.text(
            p.get_x() + p.get_width() / 2.,
            height + 0.5,  # Slightly above the bar
            int(height),
            ha='center',
            va='bottom',
            fontsize=11
        )
# Upper y-limit = total number of images counted in domain_counts.
ax.set_ylim(0, domain_counts['count'].sum())

plt.xlabel('Number of Entries per Image')
plt.ylabel('Number of Images')
plt.title('Distribution of Number of Entries per Image by Domain')
plt.tight_layout()
plt.show()



In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Flag every row that is identical (in all columns) to at least one other row.
duplicated_rows = raw_df.duplicated(keep=False)
# Count the flagged rows per (robot, domain, image_ref) combination.
duplicates_per_index = duplicated_rows.groupby([raw_df['robot'], raw_df['domain'], raw_df['image_ref']]).sum()
# Subtract one (floored at zero) so the value counts redundant copies rather
# than all flagged rows. NOTE(review): this presumably assumes at most one
# distinct duplicated row per image - confirm.
duplicates_per_index = duplicates_per_index.apply(lambda x: max(x - 1, 0))

duplicates_per_index = duplicates_per_index.reset_index(name='num_duplicates')
# Per domain: how many images have each duplicate count.
domain_counts = duplicates_per_index.groupby(['domain', 'num_duplicates']).size().reset_index(name='count')


plt.figure(figsize=(12, 6))
ax = sns.barplot(
    data=domain_counts,
    x='num_duplicates',
    y='count',
    hue='domain',
    # Fixed domain order so colours are assigned consistently.
    hue_order = ['Home', 'BigOffice-2', 'BigOffice-3', 'Hallway', 'MeetingRoom',
       'SmallOffice'],
    palette='viridis'
)

# Annotate every non-empty bar with its height.
for p in ax.patches:
    height = p.get_height()
    if height > 0:
        ax.text(
            p.get_x() + p.get_width() / 2.,
            height + 0.5,
            int(height),
            ha='center',
            va='bottom',
            fontsize=11
        )

# Set y-axis max to the total number of images counted in domain_counts
ax.set_ylim(0, domain_counts['count'].sum())

plt.xlabel('Number of Duplicates per Image')
plt.ylabel('Number of Images')
plt.title('Distribution of Number of Duplicates per Image by Domain')
plt.tight_layout()
plt.show()


## CL Baseline Model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
import numpy as np
from collections import deque
import random
import time
import torch.nn.functional as F
from torch.utils.data import Subset, ConcatDataset, DataLoader

In [None]:
class SocialContinualModel(nn.Module):
    """ResNet50 (Places365 weights) with a shared layer and one head per task.

    The backbone's classification layer is replaced by an identity, so it
    emits 2048-dimensional features. Only backbone parameters whose name
    contains ``layer4`` stay trainable. The features pass through a shared
    1024-unit layer and then through ``num_tasks`` independent heads, each
    ending in a sigmoid.

    Args:
        num_tasks: Number of task-specific heads (output columns).
        weights_path: Checkpoint file holding the Places365 ResNet50 weights
            under the ``'state_dict'`` key.
    """

    def __init__(self, num_tasks=9, weights_path='resnet50_places365.pth.tar'):
        super().__init__()
        # Initialize the ResNet50 architecture with Places365 configuration
        self.backbone = models.resnet50(num_classes=365)

        # Load onto the CPU regardless of the device the checkpoint was saved
        # from, so CPU-only machines can load it; the caller moves the whole
        # model to its target device afterwards.
        places365_weights = torch.load(weights_path, map_location='cpu', weights_only=True)
        state_dict = places365_weights['state_dict']

        # Strip the leading 'module.' that torch's DataParallel wrapper adds
        # to parameter names (prefix only, not arbitrary occurrences).
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

        # Load weights
        self.backbone.load_state_dict(state_dict)

        # Remove classification head
        self.backbone.fc = nn.Identity()

        # Freeze every backbone parameter outside the last residual stage
        for name, param in self.backbone.named_parameters():
            if 'layer4' not in name:
                param.requires_grad_(False)

        # TODO: human mask average, std, quadrants, human in relation to robot std

        # Shared layers  # TODO: deeper shared space?
        self.shared_fc = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Task-specific heads with more expensive but fine-grained GELU
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(1024, 512),
                nn.GELU(),
                nn.Linear(512, 1),
                nn.Sigmoid()
            ) for _ in range(num_tasks)
        ])

    def forward(self, x):
        """Return the head outputs concatenated along dim 1 (one column per head)."""
        features = self.backbone(x)
        shared = self.shared_fc(features)
        outputs = [head(shared) for head in self.heads]
        return torch.cat(outputs, dim=1)


In [None]:
import torch.nn as nn
from torchvision import models

class LGRBaseline(nn.Module):
    """MobileNetV2 feature extractor with a small linear regression head.

    The ImageNet-pretrained MobileNetV2 feature stack is followed by global
    average pooling, batch normalisation of the 1280 pooled features and two
    linear layers (1280 -> 32 -> ``num_classes``). There is no activation
    between or after the linear layers.

    Reference:

    @misc{churamani_feature_2024,
        title = {Feature Aggregation with Latent Generative Replay for Federated Continual Learning of Socially Appropriate Robot Behaviours},
        url = {http://arxiv.org/abs/2405.15773},
        doi = {10.48550/arXiv.2405.15773},
        number = {{arXiv}:2405.15773},
        publisher = {{arXiv}},
        author = {Churamani, Nikhil and Checker, Saksham and Chiang, Hao-Tien Lewis and Gunes, Hatice},
        urldate = {2025-01-30},
        date = {2024-03-16},
    }
    """

    def __init__(self, num_classes=9):
        super().__init__()

        # Convolutional feature stack of the pretrained network only.
        pretrained = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        self.backbone = pretrained.features

        # Collapse the spatial dimensions to a flat feature vector.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()

        # Regression head. A wider variant with an extra
        # BatchNorm1d(128) + Linear(128, 32) stage is currently disabled.
        self.fc1_bn = nn.BatchNorm1d(1280)
        self.fc2 = nn.Linear(1280, 32)
        self.fc4 = nn.Linear(32, num_classes)

    def forward(self, x):
        """Map an image batch to ``num_classes`` unbounded regression outputs."""
        pooled = self.flatten(self.pool(self.backbone(x)))
        hidden = self.fc2(self.fc1_bn(pooled))
        return self.fc4(hidden)


In [None]:
import torch
import random
import pandas as pd

class ReservoirBuffer:
    """Fixed-capacity replay memory filled by reservoir sampling.

    Inputs and labels are stored in pre-allocated tensors on ``device``;
    the domain of every stored sample is kept in a parallel Python list.
    """

    def __init__(self, capacity=1000, replay_ratio=0.2, input_shape=(3, 384, 216), label_shape=(9,), device=torch.device('cpu')):
        self.capacity = capacity
        self.replay_ratio = replay_ratio
        # Pre-allocated storage; slots at index >= self.size are uninitialised.
        self.inputs = torch.empty((capacity, *input_shape), dtype=torch.float32, device=device)
        self.labels = torch.empty((capacity, *label_shape), dtype=torch.float32, device=device)
        self.domains = [None] * capacity
        self.size = 0          # filled slots
        self.num_seen = 0      # samples offered to add() so far

    def _pick_slot(self):
        """Return the slot for the sample just counted, or None to discard it."""
        if self.size < self.capacity:
            slot = self.size
            self.size += 1
            return slot
        candidate = random.randint(0, self.num_seen - 1)
        return candidate if candidate < self.capacity else None

    def add(self, new_samples):
        """Offer ``(input, label, domain)`` triples to the buffer.

        While the buffer is not full every sample is stored. Afterwards a
        random index over all samples seen so far is drawn and the sample
        replaces that slot only if the index falls inside the buffer.
        """
        for entry in new_samples:
            self.num_seen += 1
            slot = self._pick_slot()
            if slot is None:
                continue
            self.inputs[slot].copy_(entry[0])
            self.labels[slot].copy_(entry[1])
            self.domains[slot] = entry[2]

    def sample(self, batch_size):
        """Draw ``batch_size`` stored triples uniformly with replacement.

        Returns an empty list while the buffer holds no samples.
        """
        if not self.size:
            return []
        picks = torch.randint(0, self.size, (batch_size,))
        drawn = []
        for i in picks:
            drawn.append((self.inputs[i], self.labels[i], self.domains[i]))
        return drawn

    def __len__(self):
        return self.size

    def get_domain_distribution(self):
        """Return a pandas Series counting stored samples per domain."""
        stored_domains = self.domains[:self.size]
        return pd.Series(stored_domains).value_counts()


In [None]:
class TaggedDataset(torch.utils.data.Dataset):
    """Wrap a dataset of ``(x, y)`` pairs and append a constant tag to each item.

    The tag marks where a sample comes from (0 = current data, 1 = replay).
    """

    def __init__(self, base_dataset, tag):
        self.tag = tag
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        sample, target = self.base_dataset[idx]
        return (sample, target, self.tag)

class NaiveRehearsalBuffer:
    """Per-domain rehearsal memory with an equal sample quota per domain.

    The buffer keeps, for every domain seen so far, a ``Subset`` of that
    domain's own training dataset. The total budget ``buffer_size`` is split
    evenly across the stored domains.

    @inproceedings{Hsu18_EvalCL,
        title={Re-evaluating Continual Learning Scenarios: A Categorization and Case for Strong Baselines},
        author={Yen-Chang Hsu and Yen-Cheng Liu and Anita Ramasamy and Zsolt Kira},
        booktitle={NeurIPS Continual learning Workshop },
        year={2018},
        url={https://arxiv.org/abs/1810.12488}
    }
    """

    def __init__(self, buffer_size=1000):
        self.buffer_size = buffer_size
        self.domain_buffer = {}

    def update_buffer(self, domain, dataset):
        """Store ``dataset`` under ``domain`` and shrink all domains to the new quota.

        Every stored domain is sub-sampled from its OWN data. Sampling from
        the ``dataset`` argument for all domains would replace the replay
        samples of earlier domains with current-domain data.

        Args:
            domain: Key of the domain that was just trained on.
            dataset: That domain's training dataset (needs ``len`` and
                integer indexing).
        """
        # Add/overwrite current domain with all of its samples.
        self.domain_buffer[domain] = Subset(dataset, list(range(len(dataset))))

        # Recalculate the per-domain quota.
        quota_per_domain = self.buffer_size // len(self.domain_buffer)

        # Reduce all domains (including the current one) to the quota.
        for name, subset in list(self.domain_buffer.items()):
            keep = min(quota_per_domain, len(subset))
            chosen = torch.randperm(len(subset))[:keep].tolist()
            # Translate positions within the subset back to indices of the
            # underlying dataset so subsets never nest.
            base_indices = [int(subset.indices[i]) for i in chosen]
            self.domain_buffer[name] = Subset(subset.dataset, base_indices)

    def get_loader_with_replay(self, current_domain, current_loader):
        """Return a shuffled loader mixing current data with stored replay data.

        Items are ``(x, y, tag)`` with tag 0 for current and 1 for replay
        samples. Replay data of ``current_domain`` itself is excluded. Batch
        size, workers, pin_memory and drop_last are taken from
        ``current_loader``.
        """
        current_dataset = TaggedDataset(current_loader.dataset, tag=0)
        replay_datasets = [
            TaggedDataset(subset, tag=1)
            for name, subset in self.domain_buffer.items()
            if name != current_domain
        ]

        # When the current data outnumbers the replay data, repeat the replay
        # sets so both are roughly balanced (about 1:1).
        total_replay = sum(len(replay) for replay in replay_datasets)
        if total_replay > 0:
            repeats = max(len(current_dataset) // total_replay, 1)
            replay_datasets = replay_datasets * repeats

        combined_dataset = ConcatDataset(replay_datasets + [current_dataset])
        return DataLoader(
            combined_dataset,
            batch_size=current_loader.batch_size,
            shuffle=True,
            num_workers=current_loader.num_workers,
            pin_memory=current_loader.pin_memory,
            drop_last=current_loader.drop_last
        )

    def get_domain_distribution(self):
        """Return ``{domain: number of stored samples}``."""
        return {name: len(subset) for name, subset in self.domain_buffer.items()}



In [None]:
def evaluate_model(model, dataloader, criterion, device):
    """Return the sample-weighted mean loss of ``model`` over ``dataloader``.

    The model is switched to eval mode (and left in it) and gradients are
    disabled during the pass.

    Args:
        model: Module to evaluate.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss returning the mean loss of a batch.
        device: Device the batches are moved to.

    Returns:
        Mean loss per sample, or ``nan`` when the dataloader yields no
        samples (instead of failing with a division by zero).
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.float32)

            batch_size = inputs.size(0)
            loss = criterion(model(inputs), labels)
            # criterion averages over the batch; weight by batch size so
            # differently sized batches contribute proportionally.
            total_loss += loss.item() * batch_size
            total_samples += batch_size

    if total_samples == 0:
        return float("nan")

    return total_loss / total_samples

## Run model

In [None]:
# Load the aggregated annotations written by the data-preparation section.
df = pd.read_pickle("processed_all_data.pkl")

# Create domain-specific dataloaders: one train/val/test loader dict per
# distinct value of the 'domain' column, in order of first appearance.
domains = df['domain'].unique()
domain_dataloaders = {}
for domain in domains:
    domain_df = df[df['domain'] == domain]
    #domain_df = domain_df.sample(frac=0.5, random_state=42)
    loaders = create_dataloaders(domain_df, batch_sizes=(32, 64, 64), resize_img_to=(128, 128))  #TODO should be (384, 216) to retain scale or (224, 224) for best performance on MobileNet
    domain_dataloaders[domain] = loaders

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LGRBaseline().to(device)
# Alternative replay strategy, currently disabled:
#buffer = ReservoirBuffer(capacity=1000, replay_ratio=0.2, input_shape=(3, 384, 216), label_shape=(9,), device=device)
buffer = NaiveRehearsalBuffer(buffer_size=1000)

In [None]:

optimizer = optim.Adam(model.parameters(), lr=1e-3)
# TODO: separate optimisers for shared layers and heads?
criterion = nn.MSELoss()
# TODO: potential loss optimisations such as an uncertainty-weighted loss

# Number of epochs trained on each domain.
num_epochs=5

# Loss history: train/val lists get one entry per epoch across all domains;
# 'domain_performance' gets one entry per domain after each training stage.
# 'test_loss' is not filled anywhere in the cells visible here.
metrics = {
    'train_loss': [],
    'val_loss': [],
    'test_loss': [],
    'domain_performance': {domain: [] for domain in domains}
}

In [None]:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir='runs/naive_rehearsal_buffer')  # You can change the log_dir name

# Sequential (continual) training: for each domain in turn, train on its data
# mixed with replay samples of earlier domains, then store samples of the
# domain in the buffer and evaluate on the validation sets of all domains.

# Number of batches between two progress print-outs.
log_every = 10
# Epoch counter across all domains. The per-domain `epoch` restarts at 0 for
# every domain, so using it as TensorBoard step would overwrite the points
# logged for earlier domains.
global_epoch = 0

for domain_idx, domain in enumerate(domains):
    start_time = time.time()
    print(f"\n=== Training Domain {domain_idx+1}/{len(domains)}: {domain} ===")

    original_train_loader = domain_dataloaders[domain]['train']
    original_train_dataset = original_train_loader.dataset

    # Current-domain data plus replay data of previously seen domains;
    # batches are (inputs, labels, tag) triples.
    train_loader = buffer.get_loader_with_replay(domain, original_train_loader)

    total_batches = len(train_loader)

    print(f"Training on {str(device).upper()} | Batch Size: {train_loader.batch_size}")
    print(f"Total Batches per Epoch: {total_batches}")

    for epoch in range(num_epochs):
        epoch_start = time.time()
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

        model.train()
        # Sample-weighted sums for the whole epoch ...
        epoch_loss_sum = 0.0
        epoch_samples = 0
        # ... and for the current logging window only. Keeping them separate
        # means resetting the window no longer wipes the epoch statistics.
        window_loss_sum = 0.0
        window_samples = 0

        for batch_idx, (inputs, labels, _) in enumerate(train_loader):
            batch_start = time.time()

            inputs = inputs.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.float32)

            # ReservoirBuffer # Sample the buffer and add replay samples to training
            # if buffer and random.random() < buffer.replay_ratio:
            #     batch_size = inputs.size(0)
            #     replay_batch = buffer.sample(int(batch_size * 0.25))
            #     if replay_batch:
            #         replay_inputs, replay_labels, _ = zip(*replay_batch)
            #         replay_inputs = torch.stack(replay_inputs)
            #         replay_labels = torch.stack(replay_labels)
            #         inputs = torch.cat([inputs, replay_inputs])
            #         labels = torch.cat([labels, replay_labels])

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            # ReservoirBuffer # Add current batch to buffer
            # with torch.no_grad():
            #     samples = [(img.detach(), label.detach(), domain) for img, label in zip(inputs, labels)]
            #     buffer.add(samples)

            # criterion returns the batch mean; weight it by the batch size so
            # the averages below are per sample.
            batch_samples = inputs.size(0)
            batch_loss_sum = loss.item() * batch_samples
            epoch_loss_sum += batch_loss_sum
            epoch_samples += batch_samples
            window_loss_sum += batch_loss_sum
            window_samples += batch_samples

            if (batch_idx + 1) % log_every == 0:
                avg_loss = window_loss_sum / window_samples
                batch_time = time.time() - batch_start
                print(
                    f"Domain: {domain} | Epoch: {epoch+1} | "
                    f"Batch: {batch_idx+1}/{total_batches} | "
                    f"Avg Loss: {avg_loss:.4f} | "
                    f"Time/batch: {batch_time:.2f}s"
                )
                window_loss_sum = 0.0
                window_samples = 0

        # Epoch summary
        epoch_time = time.time() - epoch_start
        avg_train_loss = epoch_loss_sum / epoch_samples if epoch_samples else float("nan")
        metrics['train_loss'].append(avg_train_loss)
        print(f"Epoch {epoch+1} Time: {epoch_time//60:.0f}m {epoch_time%60:.0f}s")
        print(f"Training Loss: {avg_train_loss:.4f}")

        # Validation
        val_loss = evaluate_model(model, domain_dataloaders[domain]['val'], criterion, device)
        metrics['val_loss'].append(val_loss)
        print(f"Domain {domain} | Epoch {epoch+1} | Val Loss: {val_loss:.4f}")

        writer.add_scalar('Loss/train', avg_train_loss, global_epoch)
        writer.add_scalar('Loss/val', val_loss, global_epoch)
        global_epoch += 1

    # Remember samples of the domain just trained for future replay.
    buffer.update_buffer(domain, original_train_dataset)

    print("\nBuffer after domain:", domain)
    dist = buffer.get_domain_distribution()
    for d, counts in dist.items():
        print(f"- {d}: {sum(counts.values()) if isinstance(counts, dict) else counts} samples")

    # Evaluate on all domains; the step is the index of the training stage so
    # each stage produces its own point on every per-domain curve.
    for eval_domain in domains:
        domain_loss = evaluate_model(model, domain_dataloaders[eval_domain]['val'], criterion, device)
        metrics['domain_performance'][eval_domain].append(domain_loss)
        print(f"Performance on {eval_domain} after {domain}: {domain_loss:.4f}")
        writer.add_scalar(f'Loss/val_{eval_domain}', domain_loss, domain_idx)


writer.close()



In [None]:
# Save the final model/optimizer state together with the collected metrics.
# 'epoch' and 'domain' are whatever the training loop's variables held when
# it finished, i.e. the last epoch index + 1 and the last domain trained.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'metrics': metrics,
    'epoch': epoch + 1,
    'domain': domain
}
torch.save(checkpoint, 'base_model_naiverehearsal.pth')


In [None]:
# Display the collected metrics dict as the cell output.
metrics

In [None]:
import matplotlib.pyplot as plt

# Training and validation loss per epoch, concatenated over all domains in
# training order. Both curves are labelled so they can be told apart; the
# axis label and title name both of them instead of only the validation loss.
plt.plot(metrics['val_loss'], label='Validation loss')
plt.plot(metrics['train_loss'], label='Training loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss Over Epochs')
plt.legend()
plt.show()


In [None]:
# One curve per domain: validation loss on that domain recorded after each
# training stage (x = index of the stage, in training order).
for domain, losses in metrics['domain_performance'].items():
    plt.plot(losses, label=domain)
plt.xlabel('Domain Training Step')
plt.ylabel('Loss')
plt.title('Domain Performance Over Training')
plt.legend()
plt.show()


In [None]:
# Per domain: loss after the first training stage, after the last one, and
# the difference (positive = loss went up over the course of training).
for domain, losses in metrics['domain_performance'].items():
    print(f"{domain}: Initial = {losses[0]:.4f}, Final = {losses[-1]:.4f}, Change = {losses[-1] - losses[0]:.4f}")


In [None]:
# Sanity check of the prediction range on `inputs` (the last batch bound by
# an earlier cell). The targets are rescaled from 1-5 to 0-1 by
# ImageLabelDataset, so predictions are expected to be roughly within
# [0, 1], not [1, 5]; LGRBaseline has no output activation, so values
# outside that range are possible.
outputs = model(inputs)
print(f"Output range: {outputs.min().item()}–{outputs.max().item()}")


In [None]:
# Overfit Single Batch Check

# Fresh model and optimizer so the check does not touch the trained model.
test_model = LGRBaseline().to(device)
test_optimizer = optim.Adam(test_model.parameters(), lr=1e-3)
# NOTE(review): this rebinds the global `buffer` used by the training loop
# and the new buffer is not used by the overfit cells visible here.
buffer = NaiveRehearsalBuffer(0)

# Take one (shuffled) training batch of the first domain and keep it fixed.
first_domain = domains[0]
train_loader = domain_dataloaders[first_domain]['train']
single_batch = next(iter(train_loader))
inputs, labels = single_batch
inputs = inputs.to(device, dtype=torch.float32)
labels = labels.to(device, dtype=torch.float32)

In [None]:
# Repeatedly fit the same single batch; the loss should fall towards zero if
# model, loss and optimizer are wired correctly. Uses the `criterion`
# defined in the training setup cell.
num_test_epochs = 100
for epoch in range(num_test_epochs):
    test_optimizer.zero_grad()

    outputs = test_model(inputs)
    loss = criterion(outputs, labels)
    
    loss.backward()
    test_optimizer.step()
    
    # Report the first step and then every tenth one.
    if (epoch+1) % 10 == 0 or epoch == 0:
        print(f"Overfit Epoch {epoch+1}/{num_test_epochs} | Loss: {loss.item():.4f}")
