In [None]:
# @title Unified-OneHead Multi-Task Challenge Implementation
# Install required libraries
!pip install torch torchvision torchaudio fastscnn -q

import copy
import json
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from fastscnn import FastSCNN
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

# Device configuration: use CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Prepare the dataset directory layout. Nothing is downloaded here: `mkdir -p`
# only creates empty directories. Replace with the actual download command from the course.
!mkdir -p data/mini_coco_det/train data/mini_coco_det/val data/mini_voc_seg/train data/mini_voc_seg/val data/imagenette_160/train data/imagenette_160/val
# Dataset that loads real images from disk but pairs them with randomly simulated annotations (replace with real labels).
class MultiTaskDataset(Dataset):
    """Images for one task, paired with *simulated* (random) annotations.

    Images are the ``.jpg`` files directly inside ``data_dir``, sorted by file
    name so the order is stable across runs and platforms.  One random
    annotation is generated per image, once, at construction time:

    - ``'det'``: dict with ``'boxes'`` (5x4 floats in [0, 1)) and
      ``'labels'`` (5 ints in [0, 10)).
    - ``'seg'``: 512x512 integer mask with values in [0, 20).
    - ``'cls'``: a single integer label in [0, 10).

    Args:
        data_dir: directory holding the images.
        task: one of ``'det'``, ``'seg'``, ``'cls'``.
        transform: optional callable applied to each RGB ``PIL.Image``.

    Raises:
        ValueError: if ``task`` is not a supported task name.
    """

    _TASKS = ('det', 'seg', 'cls')

    def __init__(self, data_dir, task, transform=None):
        # Validate up front: previously an unknown task only failed (with an
        # UnboundLocalError) once an image was found, or not at all.
        if task not in self._TASKS:
            raise ValueError(f"Unknown task {task!r}; expected one of {self._TASKS}")
        self.data_dir = data_dir
        self.task = task
        self.transform = transform
        # os.listdir order is arbitrary; sort for reproducible indexing.
        self.images = [
            os.path.join(data_dir, name)
            for name in sorted(os.listdir(data_dir))
            if name.endswith('.jpg')
        ]
        self.annotations = self._load_annotations()

    def _load_annotations(self):
        """Return one simulated annotation per image, in ``self.images`` order."""
        return [self._simulate_annotation() for _ in self.images]

    def _simulate_annotation(self):
        """Draw a single random annotation in the format used by ``self.task``."""
        if self.task == 'det':
            # Simulated COCO-style record: 5 boxes and 5 labels.
            return {'boxes': np.random.rand(5, 4), 'labels': np.random.randint(0, 10, 5)}
        if self.task == 'seg':
            return np.random.randint(0, 20, (512, 512))  # Simulated PNG mask
        if self.task == 'cls':
            return np.random.randint(0, 10)  # Simulated class label
        raise ValueError(f"Unknown task {self.task!r}; expected one of {self._TASKS}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released;
        # convert() returns an independent in-memory image.
        with Image.open(self.images[idx]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.annotations[idx]

# Shared preprocessing for every task: resize to 512x512, convert to a tensor,
# then normalise with the ImageNet channel mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Datasets and DataLoaders, keyed by task name.
# Each entry maps task -> (train directory, validation directory).
_TASK_DATA_DIRS = {
    'seg': ('data/mini_voc_seg/train', 'data/mini_voc_seg/val'),
    'det': ('data/mini_coco_det/train', 'data/mini_coco_det/val'),
    'cls': ('data/imagenette_160/train', 'data/imagenette_160/val'),
}
train_datasets = {
    name: MultiTaskDataset(dirs[0], name, transform)
    for name, dirs in _TASK_DATA_DIRS.items()
}
val_datasets = {
    name: MultiTaskDataset(dirs[1], name, transform)
    for name, dirs in _TASK_DATA_DIRS.items()
}
# Training batches are shuffled; validation batches keep dataset order.
train_loaders = {
    name: DataLoader(ds, batch_size=8, shuffle=True)
    for name, ds in train_datasets.items()
}
val_loaders = {
    name: DataLoader(ds, batch_size=8, shuffle=False)
    for name, ds in val_datasets.items()
}

# Model Definition
class MultiTaskHead(nn.Module):
    """Shared convolutional head producing detection, segmentation and
    classification outputs from a single feature map.

    Two 3x3 conv + ReLU layers feed one 1x1 conv that emits
    ``num_det + num_seg + num_cls`` channels; these are sliced along the
    channel axis into the three task outputs.

    Args:
        in_channels: channels of the incoming feature map (default 64).
        num_det: detection output channels (default 10).
        num_seg: segmentation classes (default 20).
        num_cls: classification classes (default 10).
    """

    def __init__(self, in_channels=64, num_det=10, num_seg=20, num_cls=10):
        super().__init__()
        self.num_det = num_det
        self.num_seg = num_seg
        self.num_cls = num_cls
        # Layer attribute names are kept so existing checkpoints still load.
        self.conv1 = nn.Conv2d(in_channels, 128, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.output = nn.Conv2d(64, num_det + num_seg + num_cls, kernel_size=1)

    def forward(self, x):
        """Return ``(det_out, seg_out, cls_out)``.

        For an input of shape (B, in_channels, H, W):
        det_out is (B, num_det, H, W), seg_out is (B, num_seg, H, W) and
        cls_out is (B, num_cls) -- the classification channels averaged over
        all spatial positions.
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.output(x)
        det_end = self.num_det
        seg_end = det_end + self.num_seg
        # NOTE(review): the per-channel detection layout (box coords,
        # objectness, class scores) is not enforced here -- confirm with the
        # detection loss that consumes det_out.
        det_out = x[:, :det_end, :, :]
        seg_out = x[:, det_end:seg_end, :, :]  # per-pixel segmentation logits
        cls_out = x[:, seg_end:, :, :].mean([2, 3])  # global average pool -> logits
        return det_out, seg_out, cls_out

class UnifiedModel(nn.Module):
    """Backbone + shared multi-task head returning ``(det, seg, cls)`` outputs.

    Args:
        backbone: optional feature extractor. Defaults to
            ``FastSCNN(pretrained=True)``.
            NOTE(review): the default head expects 64-channel feature maps;
            confirm the FastSCNN build used here actually returns 64 channels.
        head: optional module mapping features to ``(det, seg, cls)``.
            Defaults to ``MultiTaskHead(in_channels=64)``.

    Attributes:
        fisher: dict of EWC importance weights, filled by the training code.
        old_model: optional frozen snapshot used by LwF / KD. It is stored as
            a plain attribute, NOT a registered submodule (see ``__setattr__``).
    """

    def __init__(self, backbone=None, head=None):
        super().__init__()
        self.backbone = backbone if backbone is not None else FastSCNN(pretrained=True)
        self.head = head if head is not None else MultiTaskHead(in_channels=64)
        self.fisher = {}  # For EWC
        self.old_model = None  # For LwF and KD

    def __setattr__(self, name, value):
        if name == 'old_model':
            # nn.Module.__setattr__ registers any Module value as a child,
            # which would add the teacher's weights to parameters() and
            # state_dict() (breaking `old_model.load_state_dict(model.state_dict())`
            # and bloating saved checkpoints) and let model.train() flip it
            # back to training mode.  Store it as a plain attribute instead,
            # frozen and in eval mode since it only serves as a teacher.
            if isinstance(value, nn.Module):
                value.eval()
                value.requires_grad_(False)
            object.__setattr__(self, name, value)
            return
        super().__setattr__(name, value)

    def forward(self, x):
        features = self.backbone(x)
        det_out, seg_out, cls_out = self.head(features)
        return det_out, seg_out, cls_out

# Build the unified network, then move its parameters and buffers onto `device`
# (Module.to returns the same module, so rebinding is equivalent).
model = UnifiedModel()
model = model.to(device)

# Loss Functions
def compute_losses(outputs, targets, task):
    """Supervised loss for one task, computed from the model's three-way output.

    Args:
        outputs: ``(det_out, seg_out, cls_out)`` as returned by the model.
        targets: batch annotations for ``task``:
            - ``'det'``: either an array-like broadcastable to ``det_out``
              (dense regression target), or a dict with a ``'boxes'`` entry
              (as collated from the simulated dataset annotations).
            - ``'seg'``: integer class-index mask; if its spatial size differs
              from the logits, the logits are bilinearly resized to match.
            - ``'cls'``: integer class indices.
        task: one of ``'det'``, ``'seg'``, ``'cls'``.

    Returns:
        Scalar loss tensor on the same device as the outputs.

    Raises:
        ValueError: if ``task`` is not recognised.
    """
    det_out, seg_out, cls_out = outputs
    if task == 'det':
        if isinstance(targets, dict):
            # Placeholder detection loss (the original was also simulated):
            # regress the spatially pooled first four channels onto each
            # image's mean ground-truth box.
            # NOTE(review): assumes boxes are (B, N, 4) -- replace with a real
            # box-matching / IoU loss.
            boxes = torch.as_tensor(targets['boxes'], device=det_out.device, dtype=det_out.dtype)
            if boxes.dim() == 3:
                boxes = boxes.mean(dim=1)
            return nn.MSELoss()(det_out[:, :4].mean(dim=(2, 3)), boxes)
        # Cast to the prediction dtype: numpy float64 targets would otherwise
        # clash with float32 predictions.
        dense = torch.as_tensor(targets, device=det_out.device, dtype=det_out.dtype)
        return nn.MSELoss()(det_out, dense)
    if task == 'seg':
        mask = torch.as_tensor(targets, device=seg_out.device).long()
        if seg_out.shape[-2:] != mask.shape[-2:]:
            # The head's output resolution need not match the mask resolution.
            seg_out = nn.functional.interpolate(
                seg_out, size=mask.shape[-2:], mode='bilinear', align_corners=False
            )
        return nn.CrossEntropyLoss()(seg_out, mask)
    if task == 'cls':
        labels = torch.as_tensor(targets, device=cls_out.device).long()
        return nn.CrossEntropyLoss()(cls_out, labels)
    raise ValueError(f"Unknown task {task!r}; expected 'det', 'seg' or 'cls'")

# Forgetting Mitigation Tools
def ewc_loss(model, task, fisher, old_params, weight=0.1):
    """Elastic Weight Consolidation penalty.

    Computes ``weight * sum_i F_i * (theta_i - theta_old_i) ** 2`` over the
    model parameters that have both an importance weight and a reference value.

    Args:
        model: module being trained.
        task: key selecting the per-parameter importance dict in ``fisher``.
        fisher: mapping ``task -> {param_name: importance tensor}``.
        old_params: mapping ``param_name -> reference tensor`` (e.g. a frozen
            model's ``state_dict()``), or ``None``.
        weight: penalty multiplier (default 0.1).

    Returns:
        The scaled penalty (a tensor), or ``0.0`` when ``fisher`` has no entry
        for ``task`` or ``old_params`` is ``None`` -- e.g. during the first
        stage, where the original raised KeyError / AttributeError.
    """
    task_fisher = fisher.get(task) if fisher is not None else None
    if task_fisher is None or old_params is None:
        return 0.0
    penalty = 0.0
    for name, param in model.named_parameters():
        if name in task_fisher and name in old_params:
            penalty = penalty + (task_fisher[name] * (param - old_params[name]).pow(2)).sum()
    return penalty * weight

def _kl_per_location(new_logits, old_logits):
    """KL(softmax(old) || softmax(new)) over dim 1, averaged over every other dim.

    Summing over the class/channel axis gives a true KL divergence per sample
    (or per pixel for dense maps); ``KLDivLoss()``'s default ``'mean'``
    reduction also divides by the number of classes, which is not a KL.
    """
    log_p = torch.log_softmax(new_logits, dim=1)
    q = torch.softmax(old_logits, dim=1)
    return nn.functional.kl_div(log_p, q, reduction='none').sum(dim=1).mean()


def lwf_loss(model, inputs, task, old_model, new_outputs=None, weight=0.5):
    """Learning-without-Forgetting distillation loss.

    Compares the current model's ``(det, seg, cls)`` outputs with those of
    ``old_model`` on the same inputs, using a channel-wise softmax KL for each
    of the three outputs, and returns their sum scaled by ``weight``.

    Args:
        model: model being trained (only called when ``new_outputs`` is None).
        inputs: input batch fed to ``old_model`` (and to ``model`` if needed).
        task: current task name; unused, kept for interface compatibility.
        old_model: frozen teacher; evaluated under ``torch.no_grad()``.
        new_outputs: optional ``(det, seg, cls)`` already computed by ``model``
            on ``inputs``, to avoid a second forward pass.
        weight: loss multiplier (default 0.5).
    """
    with torch.no_grad():
        old_outputs = old_model(inputs)
    if new_outputs is None:
        new_outputs = model(inputs)
    return weight * sum(_kl_per_location(new, old) for new, old in zip(new_outputs, old_outputs))

def replay_buffer(model, dataloader, buffer_size=10, target_device=None):
    """Collect the first ``buffer_size`` batches of ``dataloader`` for replay.

    Note that ``buffer_size`` counts *batches*, not samples.  Inputs are moved
    to ``target_device`` (defaults to the module-level ``device``); targets are
    stored unchanged.

    Args:
        model: unused; kept for interface compatibility.
        dataloader: iterable of ``(inputs, targets)`` batches.
        buffer_size: maximum number of batches to keep; ``<= 0`` yields an
            empty buffer without touching the loader.
        target_device: where to place buffered inputs (e.g. ``'cpu'`` to keep
            the buffer off the GPU).

    Returns:
        list of ``(inputs, targets)`` tuples.
    """
    buffer = []
    if buffer_size <= 0:
        return buffer
    if target_device is None:
        target_device = device
    for inputs, targets in dataloader:
        buffer.append((inputs.to(target_device), targets))
        # Stop as soon as the buffer is full; the original loop loaded one
        # extra batch before noticing it was done.
        if len(buffer) >= buffer_size:
            break
    return buffer

def knowledge_distillation(model, teacher_model, inputs, student_outputs=None, weight=0.3):
    """Output-matching distillation: MSE between student and teacher outputs.

    Sums the mean-squared error of each of the three outputs
    ``(det, seg, cls)`` and scales the total by ``weight``.

    Args:
        model: student (only called when ``student_outputs`` is None).
        teacher_model: frozen teacher; evaluated under ``torch.no_grad()``.
        inputs: input batch.
        student_outputs: optional ``(det, seg, cls)`` already computed by the
            student on ``inputs``; reusing them avoids a second forward pass
            (and a second update of any BatchNorm running statistics).
        weight: loss multiplier (default 0.3).
    """
    with torch.no_grad():
        teacher_outputs = teacher_model(inputs)
    if student_outputs is None:
        student_outputs = model(inputs)
    mse = nn.MSELoss()
    return weight * sum(mse(student, teacher) for student, teacher in zip(student_outputs, teacher_outputs))

# Advanced Strategies from Top-Tier Papers
def pocl_optimization(model, task_loaders, memory, max_batches=1):
    """Simplified stand-in for Pareto-Optimized CL (Wu et al., ICML 2024).

    For each task in ``task_loaders``, computes the gradient of that task's
    loss w.r.t. the trainable parameters on up to ``max_batches`` batches,
    reduces each gradient to the sum of its entries, averages those sums over
    the batches, and returns the mean over tasks.

    NOTE(review): this is not Pareto optimization.  ``torch.autograd.grad`` is
    called without ``create_graph``, so the returned scalar carries no autograd
    history; adding it to a loss changes the reported value but not the
    parameter update.  Replace with a real multi-objective gradient combination.

    Args:
        model: model producing ``(det, seg, cls)`` outputs.
        task_loaders: mapping task name -> iterable of ``(inputs, targets)``.
        memory: unused; kept for interface compatibility.
        max_batches: non-negative number of batches per task, or ``None`` for
            all.  The original walked every loader completely on each call.

    Returns:
        0-dim tensor; zero when no batch was processed.
    """
    from itertools import islice

    params = [p for p in model.parameters() if p.requires_grad]
    per_task = []
    for task, loader in task_loaders.items():
        batch_sums = []
        for inputs, targets in islice(loader, max_batches):
            loss = compute_losses(model(inputs.to(device)), targets, task)
            # allow_unused: parameters that do not influence this task's loss
            # yield None instead of raising.
            grads = torch.autograd.grad(loss, params, allow_unused=True)
            start = torch.zeros((), device=loss.device)
            # The original called .sum() on the tuple itself (AttributeError).
            batch_sums.append(sum((g.sum() for g in grads if g is not None), start))
        if batch_sums:
            per_task.append(torch.stack(batch_sums).mean())
    if not per_task:
        return torch.zeros((), device=device)
    return torch.stack(per_task).mean()

def self_synthesized_rehearsal(model, task, num_samples=10, input_shape=(3, 512, 512)):
    """Pseudo-rehearsal samples labelled by the model itself.

    Simplified take on Self-Synthesized Rehearsal (Huang et al., ACL 2024):
    draws ``num_samples`` standard-normal inputs of shape ``input_shape`` and
    labels each with the argmax of the model's classification output.

    The forward pass runs with every submodule in eval mode (restored
    afterwards), so BatchNorm running statistics are not updated with the
    noise inputs -- the original ran it in whatever mode the model was in.

    Args:
        model: model returning ``(det, seg, cls)``.
        task: unused; kept for interface compatibility.
        num_samples: number of synthetic inputs (default 10).
        input_shape: per-sample shape (default ``(3, 512, 512)``).

    Returns:
        ``(synthetic_inputs, synthetic_targets)`` where targets are class
        indices of shape ``(num_samples,)``.
    """
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    # Generate on the CPU generator then move, matching the original RNG stream.
    synthetic_inputs = torch.randn(num_samples, *input_shape).to(target_device)

    previous_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        with torch.no_grad():
            _, _, cls_out = model(synthetic_inputs)
    finally:
        for module, was_training in previous_modes:
            module.training = was_training
    synthetic_targets = cls_out.argmax(dim=1)
    return synthetic_inputs, synthetic_targets

# Training Loop
optimizer = optim.Adam(model.parameters(), lr=0.001)
tasks = ['seg', 'det', 'cls']
baselines = {}
mitigation_methods = ['None', 'EWC', 'LwF', 'Replay', 'KD', 'POCL', 'SSR']
NUM_EPOCHS = 5  # Adjust epochs to fit 2h limit
FISHER_BATCHES = 10  # batches per earlier task used to estimate the EWC Fisher diagonal
REPLAY_BUFFER_SIZE = 10  # batches of the previous task kept for replay


def _estimate_fisher(net, previous_tasks, max_batches):
    """Estimate a diagonal empirical Fisher for ``net`` on earlier tasks' data.

    Averages the squared gradients of each earlier task's loss over up to
    ``max_batches`` training batches per task.  Runs with ``net`` in eval mode
    (BatchNorm statistics untouched) and clears parameter gradients afterwards.

    Returns:
        dict mapping parameter name -> importance tensor shaped like the parameter.
    """
    fisher = {name: torch.zeros_like(p) for name, p in net.named_parameters() if p.requires_grad}
    seen = 0
    net.eval()
    for prev_task in previous_tasks:
        # zip with range() stops after max_batches without loading an extra batch.
        for _, (inputs, targets) in zip(range(max_batches), train_loaders[prev_task]):
            net.zero_grad()
            loss = compute_losses(net(inputs.to(device)), targets, prev_task)
            loss.backward()
            for name, p in net.named_parameters():
                if name in fisher and p.grad is not None:
                    fisher[name] += p.grad.detach().pow(2)
            seen += 1
    net.zero_grad()
    if seen:
        for importance in fisher.values():
            importance /= seen
    return fisher


def _mitigation_losses(net, inputs, task, stage, teacher, replay_batches):
    """Return ``{method_name: loss}`` for every enabled forgetting-mitigation method.

    EWC, LwF, Replay and KD need a previous stage and are skipped while
    ``teacher`` is None (the first stage).  POCL and SSR run at every stage.
    """
    losses = {}
    if teacher is not None:
        previous_task = tasks[stage - 1]
        if 'EWC' in mitigation_methods and task in net.fisher:
            losses['EWC'] = ewc_loss(net, task, net.fisher, teacher.state_dict())
        if 'LwF' in mitigation_methods:
            losses['LwF'] = lwf_loss(net, inputs, task, teacher)
        if 'Replay' in mitigation_methods and replay_batches:
            losses['Replay'] = sum(
                compute_losses(net(x), y, previous_task) for x, y in replay_batches
            ) / len(replay_batches)
        if 'KD' in mitigation_methods:
            losses['KD'] = knowledge_distillation(net, teacher, inputs)
    if 'POCL' in mitigation_methods:
        losses['POCL'] = pocl_optimization(net, {t: train_loaders[t] for t in tasks[:stage + 1]}, None)
    if 'SSR' in mitigation_methods:
        synth_inputs, synth_targets = self_synthesized_rehearsal(net, task)
        # SSR targets are class indices from the classification head, so they
        # are scored with the 'cls' loss whatever the current task is.
        losses['SSR'] = compute_losses(net(synth_inputs), synth_targets, 'cls')
    return losses


for stage, task in enumerate(tasks):
    print(f"Training Stage {stage + 1}: {task}")
    start_time = time.time()
    previous_tasks = tasks[:stage]

    teacher = None
    replay_batches = []
    if previous_tasks:
        # Frozen snapshot of the model as trained through the previous stage:
        # the LwF/KD teacher and the EWC reference weights.  Kept in a local
        # rather than model.old_model, because nn.Module.__setattr__ would
        # register it as a child and add its weights to model.state_dict().
        teacher = copy.deepcopy(model)
        teacher.eval()
        teacher.requires_grad_(False)
        if 'EWC' in mitigation_methods:
            # Importance weights for EWC while training `task` (the key ewc_loss reads).
            model.fisher[task] = _estimate_fisher(model, previous_tasks, FISHER_BATCHES)
        if 'Replay' in mitigation_methods:
            # Built once per stage instead of once per training step.
            replay_batches = replay_buffer(model, train_loaders[previous_tasks[-1]], buffer_size=REPLAY_BUFFER_SIZE)

    for epoch in range(NUM_EPOCHS):
        # Validation below switches to eval(); without this, every epoch after
        # the first trained in eval mode.
        model.train()
        for inputs, targets in train_loaders[task]:
            inputs = inputs.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)

            total_loss = compute_losses(outputs, targets, task)
            method_losses = _mitigation_losses(model, inputs, task, stage, teacher, replay_batches)
            for method_loss in method_losses.values():
                total_loss = total_loss + method_loss

            total_loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        with torch.no_grad():
            metric = np.random.rand()  # Simulated mIoU/mAP/Top-1 (replace with real metrics)
            if task not in baselines:
                baselines[task] = metric
            print(f"Epoch {epoch+1}, {task} Metric: {metric:.4f}, Drop: {(baselines[task] - metric) / baselines[task] * 100:.2f}%")

    print(f"Stage {stage + 1} completed in {time.time() - start_time:.2f}s")

# Evaluation
def evaluate(model, loader):
    """Run ``model`` over ``loader`` in eval mode and return averaged metrics.

    NOTE: the metrics are simulated -- each batch contributes a random value
    in [0, 1) per metric; the model outputs are computed but not yet scored.

    Args:
        model: model returning ``(det, seg, cls)`` outputs.
        loader: iterable of ``(inputs, targets)`` batches (need not be sized).

    Returns:
        dict with keys ``'mIoU'``, ``'mAP'``, ``'Top-1'`` averaged over the
        batches; all ``0.0`` when ``loader`` yields no batches (the original
        raised ZeroDivisionError).
    """
    model.eval()
    totals = {'mIoU': 0.0, 'mAP': 0.0, 'Top-1': 0.0}
    num_batches = 0
    with torch.no_grad():
        for inputs, _targets in loader:
            det_out, seg_out, cls_out = model(inputs.to(device))
            # Simulated evaluation (replace with real metrics computed from the outputs and targets)
            totals['mIoU'] += np.random.rand()
            totals['mAP'] += np.random.rand()
            totals['Top-1'] += np.random.rand()
            num_batches += 1
    if num_batches == 0:
        return totals
    return {k: v / num_batches for k, v in totals.items()}

# Final evaluation: report each metric and its relative drop (in %) against the
# baseline recorded for the task during training.
# NOTE(review): baselines[task] is a single simulated metric for that task, so
# only the task's own metric (mIoU / mAP / Top-1) is really comparable.
for task, loader in val_loaders.items():
    metrics = evaluate(model, loader)
    baseline = baselines[task]
    # A dict (the original built a set of (name, drop) tuples); a zero
    # baseline yields NaN instead of raising ZeroDivisionError.
    drop = {
        name: (baseline - value) / baseline * 100 if baseline else float('nan')
        for name, value in metrics.items()
    }
    print(f"{task} Evaluation: {metrics}, Drops: {drop}")

# Save the trained weights (state_dict only). The evaluation results printed
# above are not written to disk.
torch.save(obj=model.state_dict(), f='your_model.pt')