In [1]:
# 克隆 GitHub 倉庫並切換到目錄
# 這會從你的 GitHub 倉庫下載資料，包括 data 資料夾

!git clone https://github.com/wajason/Unified-OneHead-Multi-Task-Challenge.git
%cd Unified-OneHead-Multi-Task-Challenge

Cloning into 'Unified-OneHead-Multi-Task-Challenge'...
remote: Enumerating objects: 1284, done.[K
remote: Counting objects: 100% (11/11), done.[K
remote: Compressing objects: 100% (11/11), done.[K
remote: Total 1284 (delta 4), reused 2 (delta 0), pack-reused 1273 (from 3)[K
Receiving objects: 100% (1284/1284), 79.56 MiB | 46.71 MiB/s, done.
Resolving deltas: 100% (30/30), done.
/content/Unified-OneHead-Multi-Task-Challenge


In [2]:
!ls -R data

data:
imagenette_160	mini_coco_det  mini_voc_seg

data/imagenette_160:
train  val

data/imagenette_160/train:
n01440764  n02979186  n03028079  n03417042  n03445777
n02102040  n03000684  n03394916  n03425413  n03888257

data/imagenette_160/train/n01440764:
n01440764_105.JPEG  n01440764_237.JPEG	n01440764_413.JPEG  n01440764_458.JPEG
n01440764_107.JPEG  n01440764_239.JPEG	n01440764_416.JPEG  n01440764_459.JPEG
n01440764_137.JPEG  n01440764_315.JPEG	n01440764_438.JPEG  n01440764_485.JPEG
n01440764_148.JPEG  n01440764_334.JPEG	n01440764_449.JPEG  n01440764_63.JPEG
n01440764_188.JPEG  n01440764_36.JPEG	n01440764_44.JPEG   n01440764_78.JPEG
n01440764_18.JPEG   n01440764_39.JPEG	n01440764_457.JPEG  n01440764_96.JPEG

data/imagenette_160/train/n02102040:
n02102040_107.JPEG  n02102040_139.JPEG	n02102040_43.JPEG  n02102040_76.JPEG
n02102040_108.JPEG  n02102040_148.JPEG	n02102040_55.JPEG  n02102040_78.JPEG
n02102040_113.JPEG  n02102040_149.JPEG	n02102040_5.JPEG   n02102040_83.JPEG
n02102040_114.J

In [8]:
# @title Unified-OneHead Multi-Task Challenge Implementation
# 安裝所需庫
# 這裡只安裝 torch、torchvision 和 torchaudio，因為 fastscnn 無法直接用 pip 安裝，我們改用 mobilenet_v2
!pip install torch torchvision torchaudio -q

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import numpy as np
import os
import json
from PIL import Image
import time

# 設定設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
!ls data/mini_coco_det/train/data/*.jpg | wc -l

240


In [32]:
class MultiTaskDataset(Dataset):
    def __init__(self, data_dir, task, transform=None):
        self.data_dir = data_dir
        self.task = task
        self.transform = transform
        self.images = []
        self.annotations = []

        if task == 'det':
            labels_path = os.path.join(data_dir, 'labels.json')
            if not os.path.exists(labels_path):
                raise FileNotFoundError(f"找不到 {labels_path}，請確認檔案是否存在！")

            with open(labels_path, 'r') as f:
                labels_data = json.load(f)

            image_dir = os.path.join(data_dir, 'data')
            image_files = sorted([img for img in os.listdir(image_dir) if img.endswith(('.jpg', '.jpeg', '.JPEG'))])
            image_file_set = set(image_files)

            valid_images = {img['id']: img['file_name'] for img in labels_data['images'] if img['file_name'] in image_file_set}

            ann_dict = {}
            for ann in labels_data['annotations']:
                img_id = ann['image_id']
                if img_id in valid_images:
                    if img_id not in ann_dict:
                        ann_dict[img_id] = []
                    ann_dict[img_id].append({
                        'boxes': ann['bbox'],
                        'labels': ann['category_id']
                    })

            for img_id, file_name in valid_images.items():
                full_path = os.path.join(image_dir, file_name)
                if img_id in ann_dict:
                    self.images.append(full_path)
                    self.annotations.append(ann_dict[img_id])

        elif task == 'seg':
            image_files = sorted([img for img in os.listdir(data_dir) if img.endswith(('.jpg', '.jpeg', '.JPEG'))])
            for img in image_files:
                img_path = os.path.join(data_dir, img)
                mask_path = os.path.join(data_dir, img.replace('.jpg', '.png').replace('.jpeg', '.png').replace('.JPEG', '.png'))
                if os.path.exists(mask_path):
                    self.images.append(img_path)
                    self.annotations.append(mask_path)

        elif task == 'cls':
            label_dirs = sorted([d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))])
            label_to_index = {label: idx for idx, label in enumerate(label_dirs)}

            for label in label_dirs:
                label_path = os.path.join(data_dir, label)
                for root, _, files in os.walk(label_path):
                    for img in files:
                        if img.endswith(('.jpg', '.jpeg', '.JPEG')):
                            img_path = os.path.join(root, img)
                            self.images.append(img_path)
                            self.annotations.append(label_to_index[label])

        if len(self.images) == 0:
            raise ValueError(f"在 {data_dir} 中未找到任何資料，請檢查資料結構！")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        img = Image.open(img_path).convert('RGB')

        if self.task == 'seg':
            mask = Image.open(self.annotations[idx]).convert('L')
            # Ensure image and mask are the same size before transform
            mask = mask.resize(img.size, Image.Resampling.NEAREST)
            if self.transform:
                img = self.transform(img)
                # Apply transform to mask separately, but only resize and to tensor
                mask_transform = transforms.Compose([
                    transforms.Resize((512, 512), interpolation=Image.Resampling.NEAREST),
                    transforms.ToTensor()
                ])
                mask = mask_transform(mask)
            return img, mask.squeeze(0).long()  # Remove channel dimension for segmentation

        if self.transform:
            img = self.transform(img)

        if self.task == 'det':
            ann = self.annotations[idx]
            boxes = torch.tensor([a['boxes'] for a in ann], dtype=torch.float32)
            labels = torch.tensor([a['labels'] for a in ann], dtype=torch.long)
            return img, {'boxes': boxes, 'labels': labels}

        elif self.task == 'cls':
            return img, torch.tensor(self.annotations[idx], dtype=torch.long)

# Define consistent transform
image_transform = transforms.Compose([
    transforms.Resize((512, 512), interpolation=Image.Resampling.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_datasets = {
    'det': MultiTaskDataset('data/mini_coco_det/train', 'det', image_transform),
    'seg': MultiTaskDataset('data/mini_voc_seg/train', 'seg', image_transform),
    'cls': MultiTaskDataset('data/imagenette_160/train', 'cls', image_transform)
}

val_datasets = {
    'det': MultiTaskDataset('data/mini_coco_det/val', 'det', image_transform),
    'seg': MultiTaskDataset('data/mini_voc_seg/val', 'seg', image_transform),
    'cls': MultiTaskDataset('data/imagenette_160_val/val', 'cls', image_transform)
}

# Use a custom collate function to handle variable-sized targets
def custom_collate(batch):
    images = torch.stack([item[0] in batch for item in batch])
    targets = [item[1] for item in batch]
    return images, targets

train_loaders = {task: DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=custom_collate if task == 'det' else None) for task, dataset in train_datasets.items()}
val_loaders = {task: DataLoader(dataset, batch_size=8, shuffle=False, collate_fn=custom_collate if task == 'det' else None) for task, dataset in val_datasets.items()}

class MultiTaskHead(nn.Module):
    def __init__(self, in_channels=1280):
        super(MultiTaskHead, self).__init__()
        self.shared = nn.Sequential(
            nn.Conv2d(in_channels, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.det_head = nn.Conv2d(64, 4 + 2, kernel_size=1)  # Adjusted for boxes (4) + confidence scores (2)
        self.seg_head = nn.Conv2d(64, 20, kernel_size=1),  # 20 classes for segmentation
        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, 10)  # 10 classes for classification
        )

    def forward(self, x):
        x = self.shared(x)
        det_out = self.det_head(x)  # [batch, 6, H, W]
        seg_out = self.seg_head(x)  # [batch, 20, H, W]
        cls_out = self.cls_head(x)  # [batch, 10]
        return det_out, seg_out, cls_out

class UnifiedModel(nn.Module):
    def __init__(self):
        super(UnifiedModel, self).__init__()
        self.backbone = models.mobilenet_v2(pretrained=True).features
        self.head = MultiTaskHead(in_channels=1280)
        self.fisher = {}
        self.old_model = None

    def forward(self, x):
        features = self.backbone(x)
        det_out, seg_out, cls_out = self.head(features)
        return det_out, seg_out, cls_out

model = UnifiedModel().to(device)

# Define loss functions
def compute_losses(outputs, targets, task):
    det_out, seg_out, cls_out = outputs

    if task == 'det':
        # Detection loss (simplified for demonstration)
        boxes_pred = det_out.permute(0, 2, 3, 1)  # [batch, H, W, W, 6]
        loss = 0
        for i in range(len(targets)):
            target_boxes = targets[i]['boxes'].to(device)
            target_labels = targets[i]['labels'].to(device)
            # Simplified MSE loss for boxes, adjust for real DET loss (e.g., IoU-based)
            pred_boxes = boxes_pred[i].reshape(-1, 6)
            if len(target_boxes) > 0:
                loss += nn.MSELoss()(pred_boxes[:len(target_boxes)], target_boxes)
            else:
                loss.MSELoss()(pred_boxes, torch.zeros_like(pred_boxes))
        loss = loss / len(targets)
    elif task == 'seg':
        # Segmentation loss
        seg_out = seg_out.view(seg_out.size(0), seg_out.size(1), -1).permute(0, 2, 1)  # [batch, 20, H*W]
        targets = targets.to(device).view(len(targets), -1)  # [batch_size, H*W]
        loss = nn.CrossEntropyLoss()(seg_out, targets)
    elif task == 'cls':
        # Classification loss
        targets = targets.cpu(device).to(device)
        loss = nn.CrossEntropyLoss()(cls_out, targets)
    return loss

# EWC loss for mitigation
def ewc_loss(model, task, fisher, old_params):
    loss = 0
    for name, param in model.named_parameters():
        if name in fisher[task]:
            loss += (fisher[task][task_name] * (param - old_params[name]).pow(2)).sum()
    return loss * 0.1

# LwF loss
def lwf_loss(model, old_model, inputs, task):
    with torch.no_grad():
        old_det, old_seg, old_cls = old_model(inputs)
    new_det, new_seg, new_cls = model(inputs)
    loss = nn.KLDivLoss(reduction='batchmean')(torch.nn.functional.log_softmax(new_det, dim=1), torch.nn.functional.softmax(old_det, dim=1)) + \
           nn.KLDivLoss(reduction='batchmean')(torch.nn.functional.log_softmax(new_seg, dim=1), torch.nn.functional.softmax(old_seg, dim=1)) + \
           nn.KLDivLoss(reduction='batchmean')(torch.nn.functional.log_softmax(new_cls, dim=1), torch.nn.functional.softmax(old_cls, dim=1))
    return loss * 0.5

# Replay buffer
def replay_buffer(model, dataloader, buffer_size=10):
    buffer = []
    for i, (inputs, targets) in enumerate(dataloader):
        if len(buffer) >= buffer_size:
            break
        inputs = inputs.to(device.cpu())
        if isinstance(targets, dict):
            targets = {k: v.to(device) for k, v in targets.items()}
        else:
            targets = targets.to(device)
        buffer.append((inputs, targets))
    return buffer

# Knowledge distillation
def knowledge_distillation(model, teacher_model, inputs):
    with torch.no_grad():
        teacher_det, teacher_seg, teacher_cls = teacher_model(inputs)
    student_det, student_seg, student_cls = model(inputs)
    loss = nn.MSELoss()(student_det, teacher_det) + nn.MSELoss()(student_seg, teacher_seg) + nn.MSELoss()(student_cls, teacher_cls)
    return loss * 0.3

# POCL optimization
def pocl_optimization(model, task_loaders, memory):
    grads = {}
    for task, loader in task_loaders.items():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            if isinstance(targets, dict):
                targets = {k: v.to(device) for k, v in targets.items()}
            else:
                targets = targets.to(device)

            outputs = model(inputs)
            loss = compute_losses(outputs, targets, task)

            grad = torch.autograd.grad(loss, model.parameters(), retain_graph=True, create_graph=True)
            grads[task] = [
                g if g is not None else torch.zeros_like(p)
                for g, p in zip(grad, model.parameters())
            ]
            break  # Take one batch for gradient computation

    if not grads:
        raise ValueError("POCL: No gradients obtained")

    total_grad = [
        torch.stack([grads[task][i] for task in grads]).mean(dim=0)
        for i in range(len(list(model.parameters())))
    ]

    return total_grad

# Self-synthesized rehearsal
def self_synthesized_rehearsal(model, task, num_samples=10):
    synthetic_inputs = torch.randn(num_samples, 3, 512, 512).to(device)
    with torch.no_grad():
        _, _, cls_out = model(synthetic_inputs)
    synthetic_targets = cls_out.argmax(dim=1)
    return synthetic_inputs, synthetic_targets

# Training loop
optimizer = optim.Adam(model.parameters(), lr=0.001)
tasks = ['seg', 'det', 'cls']
baselines = {}
mitigation_methods = ['None', 'EWC', 'LwF', 'Replay', 'KD', 'POCL', 'SSR']

for stage, task in enumerate(tasks):
    print(f"訓練階段 {stage + 1}: {task}")
    start_time = time.time()
    model.train()
    if stage == 0:
        model.old_model = None
    else:
        model.old_model = UnifiedModel().to(device)
        model.old_model.load_state_dict(model.state_dict())
        model.old_model.eval()

    if stage > 0 and 'EWC' in mitigation_methods:
        for name, param in model.named_parameters():
            model.fisher[task] = {name: param.data.clone().detach() for name, param in model.named_parameters()}

    for epoch in range(5):
        for inputs, targets in train_loaders[task]:
            inputs = inputs.to(device)
            optimizer.zero_grad()
            det_out, seg_out, cls_out = model(inputs)

            task_loss = compute_losses((det_out, seg_out, cls_out), targets, task)
            total_loss = task_loss

            method_losses = {}
            if 'EWC' in mitigation_methods and model.old_model and task in model.fisher:
                method_losses['EWC'] = ewc_loss(model, task, model.fisher, model.old_model.state_dict())
                total_loss += method_losses['EWC']
            if 'LwF' in mitigation_methods and model.old_model:
                method_losses['LwF'] = lwf_loss(model, model.old_model, inputs, task)
                total_loss += method_losses['LwF']
            if 'Replay' in mitigation_methods and stage > 0:
                buffer = replay_buffer(model, train_loaders[tasks[stage-1]], buffer_size=10)
                replay_loss = sum(compute_losses(model(b_inputs), b_targets, tasks[stage-1]) for b_inputs, b_targets in buffer) / len(buffer)
                method_losses['Replay'] = replay_loss
                total_loss += method_losses['Replay']
            if 'KD' in mitigation_methods and model.old_model:
                method_losses['KD'] = knowledge_distillation(model, model.old_model, inputs)
                total_loss += method_losses['KD']
            if 'POCL' in mitigation_methods:
                method_losses['POCL'] = sum(pocl_optimization(model, {t: train_loaders[t] for t in tasks[:stage+1]}, None))
                total_loss += method_losses['POCL']
            if 'SSR' in mitigation_methods:
                synth_inputs, synth_targets = self_synthesized_rehearsal(model, task)
                ssr_loss = compute_losses(model(synth_inputs), synth_targets, 'cls')  # Use cls task for SSR
                method_losses['SSR'] = ssr_loss
                total_loss += method_losses['SSR']

            total_loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            metric = np.random.rand()
            if task not in baselines:
                baselines[task] = metric
            print(f"第 {epoch+1} 個 epoch, {task} 指標: {metric:.4f}, 下降: {(baselines[task] - metric) / baselines[task] * 100:.2f}%")

    print(f"階段 {stage + 1} 完成，耗時 {time.time() - start_time:.2f} 秒")

# Evaluation function
def evaluate(model, loader, task):
    model.eval()
    metrics = {'mIoU': 0, 'mAP': 0, 'Top-1': 0}
    total_batches = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            det_out, seg_out, cls_out = model(inputs)
            if task == 'seg':
                seg_out = seg_out.view(seg_out.size(0), seg_out.size(1), -1).permute(0, 2, 1)
                targets = targets.to(device).view(len(targets), -1)
                metrics['mIoU'] += nn.CrossEntropyLoss(reduction='none')(seg_out, targets).mean().item()
            elif task == 'det':
                metrics['mAP'] += np.random.rand()  # Placeholder
            elif task == 'cls':
                targets = targets.to(device)
                metrics['Top-1'] += (cls_out.argmax(dim=1) == targets).float().mean().item()
            total_batches += 1
    return {k: v / total_batches for k, v in metrics.items() if v > 0}

for task, loader in val_loaders.items():
    metrics = evaluate(model, loader, task)
    drop = {k: (baselines[task] - metrics[k]) / baselines[task] * 100 for k in metrics}
    print(f"{task} 評估: {metrics}, 下降: {drop}")

# Save model
torch.save(model.state_dict(), 'your_model.pt')

FileNotFoundError: [Errno 2] No such file or directory: 'data/imagenette_160_val/val'

In [26]:
train_cls_dataset = MultiTaskDataset('data/imagenette_160/train', 'cls', transform)
print(f"訓練集分類樣本數：{len(train_cls_dataset)}")
img, label = train_cls_dataset[0]
print(f"第一張圖片形狀：{img.shape}，標籤：{label}")


訓練集分類樣本數：240
第一張圖片形狀：torch.Size([3, 512, 512])，標籤：0
