In [None]:
!pip install fiftyone pycocotools scikit-multilearn
!pip install tqdm

Collecting fiftyone
  Downloading fiftyone-1.5.2-py3-none-any.whl.metadata (23 kB)
Collecting scikit-multilearn
  Downloading scikit_multilearn-0.2.0-py3-none-any.whl.metadata (6.0 kB)
Collecting argcomplete (from fiftyone)
  Downloading argcomplete-3.6.2-py3-none-any.whl.metadata (16 kB)
Collecting async_lru>=2 (from fiftyone)
  Downloading async_lru-2.0.5-py3-none-any.whl.metadata (4.5 kB)
Collecting boto3 (from fiftyone)
  Downloading boto3-1.38.36-py3-none-any.whl.metadata (6.6 kB)
Collecting dacite<1.8.0,>=1.6.0 (from fiftyone)
  Downloading dacite-1.7.0-py3-none-any.whl.metadata (14 kB)
Collecting Deprecated (from fiftyone)
  Downloading Deprecated-1.2.18-py2.py3-none-any.whl.metadata (5.7 kB)
Collecting ftfy (from fiftyone)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting hypercorn>=0.13.2 (from fiftyone)
  Downloading hypercorn-0.17.3-py3-none-any.whl.metadata (5.4 kB)
Collecting kaleido!=0.2.1.post1 (from fiftyone)
  Downloading kaleido-0.2.1-py2.py3-none

# Dataset download

## COCO dataset

In [None]:
import fiftyone as fo
import fiftyone.zoo as foz
import json
import os
from collections import defaultdict
import random

# Pick 10 common COCO classes. The position of a name in this list is what
# later cells use as the numeric class index.
selected_classes = [
    "person", "car", "bicycle", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "dog"
]

# Download the COCO 2017 samples that contain the selected classes.
# The log output below shows 350 images being fetched for each of the two
# requested splits, so the resulting dataset holds train and validation samples.
dataset = foz.load_zoo_dataset(
    "coco-2017",
    splits=["train", "validation"],
    label_types=["detections"],
    classes=selected_classes,
    max_samples=350  # fetch more than needed so later cells can subsample
)


Downloading split 'train' to '/root/fiftyone/coco-2017/train' if necessary


INFO:fiftyone.zoo.datasets:Downloading split 'train' to '/root/fiftyone/coco-2017/train' if necessary


Downloading annotations to '/root/fiftyone/coco-2017/tmp-download/annotations_trainval2017.zip'


INFO:fiftyone.utils.coco:Downloading annotations to '/root/fiftyone/coco-2017/tmp-download/annotations_trainval2017.zip'


 100% |██████|    1.9Gb/1.9Gb [15.8s elapsed, 0s remaining, 144.4Mb/s]      


INFO:eta.core.utils: 100% |██████|    1.9Gb/1.9Gb [15.8s elapsed, 0s remaining, 144.4Mb/s]      


Extracting annotations to '/root/fiftyone/coco-2017/raw/instances_train2017.json'


INFO:fiftyone.utils.coco:Extracting annotations to '/root/fiftyone/coco-2017/raw/instances_train2017.json'


Downloading 350 images


INFO:fiftyone.utils.coco:Downloading 350 images


 100% |██████████████████| 350/350 [3.6m elapsed, 0s remaining, 1.7 images/s]      


INFO:eta.core.utils: 100% |██████████████████| 350/350 [3.6m elapsed, 0s remaining, 1.7 images/s]      


Writing annotations for 350 downloaded samples to '/root/fiftyone/coco-2017/train/labels.json'


INFO:fiftyone.utils.coco:Writing annotations for 350 downloaded samples to '/root/fiftyone/coco-2017/train/labels.json'


Downloading split 'validation' to '/root/fiftyone/coco-2017/validation' if necessary


INFO:fiftyone.zoo.datasets:Downloading split 'validation' to '/root/fiftyone/coco-2017/validation' if necessary


Found annotations at '/root/fiftyone/coco-2017/raw/instances_val2017.json'


INFO:fiftyone.utils.coco:Found annotations at '/root/fiftyone/coco-2017/raw/instances_val2017.json'


Downloading 350 images


INFO:fiftyone.utils.coco:Downloading 350 images


 100% |██████████████████| 350/350 [3.6m elapsed, 0s remaining, 1.6 images/s]      


INFO:eta.core.utils: 100% |██████████████████| 350/350 [3.6m elapsed, 0s remaining, 1.6 images/s]      


Writing annotations for 350 downloaded samples to '/root/fiftyone/coco-2017/validation/labels.json'


INFO:fiftyone.utils.coco:Writing annotations for 350 downloaded samples to '/root/fiftyone/coco-2017/validation/labels.json'


Dataset info written to '/root/fiftyone/coco-2017/info.json'


INFO:fiftyone.zoo.datasets:Dataset info written to '/root/fiftyone/coco-2017/info.json'


Loading 'coco-2017' split 'train'


INFO:fiftyone.zoo.datasets:Loading 'coco-2017' split 'train'


 100% |█████████████████| 350/350 [1.9s elapsed, 0s remaining, 185.0 samples/s]         


INFO:eta.core.utils: 100% |█████████████████| 350/350 [1.9s elapsed, 0s remaining, 185.0 samples/s]         


Loading 'coco-2017' split 'validation'


INFO:fiftyone.zoo.datasets:Loading 'coco-2017' split 'validation'


 100% |█████████████████| 350/350 [1.9s elapsed, 0s remaining, 184.5 samples/s]      


INFO:eta.core.utils: 100% |█████████████████| 350/350 [1.9s elapsed, 0s remaining, 184.5 samples/s]      


Dataset 'coco-2017-train-validation-350' created


INFO:fiftyone.zoo.datasets:Dataset 'coco-2017-train-validation-350' created


In [None]:
import numpy as np
from pycocotools.coco import COCO
import requests
from pathlib import Path
import shutil

# Create the output directory layout for the mini COCO detection set
os.makedirs("data/mini_coco_det/train/images", exist_ok=True)
os.makedirs("data/mini_coco_det/val/images", exist_ok=True)
os.makedirs("data/mini_coco_det/train/annotations", exist_ok=True)
os.makedirs("data/mini_coco_det/val/annotations", exist_ok=True)

# Materialise the FiftyOne dataset as a list so it can be shuffled and sliced
samples = list(dataset)

# Seed the RNG before shuffling so a fresh "Restart & Run All" reproduces the
# same 240/60 split (the VOC and Imagenette cells use the same seed value).
random.seed(42)
random.shuffle(samples)

# Split into train (240) and val (60); the remaining samples are unused
train_samples = samples[:240]
val_samples = samples[240:300]

def download_and_process_samples(samples, split_name, target_count):
    """Copy sample images into the mini-COCO layout and build COCO-style records.

    Despite the name, nothing is fetched from the network here: each sample's
    image file is copied from ``sample.filepath`` into
    ``data/mini_coco_det/<split_name>/images/`` (that directory must already
    exist) and its detections are converted into annotation dictionaries.

    Args:
        samples: Sequence of samples exposing ``filepath`` and a
            ``ground_truth`` field whose ``detections`` carry ``label`` and a
            relative ``bounding_box`` ([x, y, w, h] as fractions of the image).
        split_name: Sub-directory name under ``data/mini_coco_det``.
        target_count: Maximum number of samples, taken from the front of
            ``samples``, to process.

    Returns:
        tuple: ``(images_info, annotations_info)``. Image ids are 1-based and
        follow the processing order. ``category_id`` is the 0-based position of
        the label in ``selected_classes``; detections with other labels are
        dropped.
    """
    # Imported once per call (it used to be re-executed for every sample).
    from PIL import Image

    images_info = []
    annotations_info = []

    for image_id, sample in enumerate(samples[:target_count], start=1):
        image_path = sample.filepath
        image_name = os.path.basename(image_path)

        # Copy the image next to the annotations we are about to write
        target_path = f"data/mini_coco_det/{split_name}/images/{image_name}"
        shutil.copy2(image_path, target_path)

        # Pixel size is needed to turn relative boxes into absolute ones
        with Image.open(target_path) as img:
            width, height = img.size

        images_info.append({
            "id": image_id,
            "width": width,
            "height": height,
            "file_name": image_name
        })

        ground_truth = sample.ground_truth
        if not (ground_truth and ground_truth.detections):
            continue

        for detection in ground_truth.detections:
            class_name = detection.label
            if class_name not in selected_classes:
                continue

            # Relative [x, y, w, h] -> absolute pixel coordinates
            bbox = detection.bounding_box
            x = bbox[0] * width
            y = bbox[1] * height
            w = bbox[2] * width
            h = bbox[3] * height

            annotations_info.append({
                # Annotation ids are 1-based and unique within this split
                "id": len(annotations_info) + 1,
                "image_id": image_id,
                "category_id": selected_classes.index(class_name),
                "bbox": [x, y, w, h],
                "area": w * h,
                "iscrowd": getattr(detection, 'iscrowd', 0)
            })

    return images_info, annotations_info

# Process the training split (first 240 shuffled samples)
print("處理訓練集...")
train_images, train_annotations = download_and_process_samples(train_samples, "train", 240)

# Process the validation split (next 60 shuffled samples)
print("處理驗證集...")
val_images, val_annotations = download_and_process_samples(val_samples, "val", 60)


處理訓練集...
處理驗證集...


In [None]:
# Build the category table.
# Ids are the 0-based positions in `selected_classes`, because that is exactly
# what the annotations store in "category_id" (selected_classes.index(label)).
# The previous 1-based ids shifted every annotation onto the wrong category
# name and left category_id 0 without any category entry.
categories = [
    {
        "id": idx,
        "name": class_name,
        "supercategory": "object"
    }
    for idx, class_name in enumerate(selected_classes)
]

def _coco_payload(description, images, annotations):
    """Assemble one COCO-style annotation document sharing the global category table."""
    return {
        "info": {
            "description": description,
            "version": "1.0",
            "year": 2025
        },
        "licenses": [],
        "images": images,
        "annotations": annotations,
        "categories": categories
    }


# Annotation documents for the two splits
train_annotation = _coco_payload(
    "Custom COCO Dataset - Train Split", train_images, train_annotations
)
val_annotation = _coco_payload(
    "Custom COCO Dataset - Validation Split", val_images, val_annotations
)

# Write each document next to its images
for json_path, payload in (
    ("data/mini_coco_det/train/annotations/annotations.json", train_annotation),
    ("data/mini_coco_det/val/annotations/annotations.json", val_annotation),
):
    with open(json_path, "w") as f:
        json.dump(payload, f, indent=2)

# Summary of what was written
print(
    f"數據集創建完成！",
    f"訓練集: {len(train_images)} 張圖片, {len(train_annotations)} 個標註",
    f"驗證集: {len(val_images)} 張圖片, {len(val_annotations)} 個標註",
    f"類別數量: {len(categories)}",
    f"選擇的類別: {selected_classes}",
    sep="\n",
)


數據集創建完成！
訓練集: 240 張圖片, 1040 個標註
驗證集: 60 張圖片, 368 個標註
類別數量: 10
選擇的類別: ['person', 'car', 'bicycle', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'dog']


In [None]:
# Sanity-check the files written for the mini COCO detection set

# Count the image files present in each split directory
train_count, val_count = (
    len(os.listdir(images_dir))
    for images_dir in (
        "data/mini_coco_det/train/images",
        "data/mini_coco_det/val/images",
    )
)

print(f"訓練集圖片數量: {train_count}")
print(f"驗證集圖片數量: {val_count}")

# Reload the training annotation file and report the size of each section
with open("data/mini_coco_det/train/annotations/annotations.json", "r") as f:
    train_data = json.load(f)

n_images = len(train_data['images'])
n_annotations = len(train_data['annotations'])
n_categories = len(train_data['categories'])

print(f"訓練集標註統計:")
print(f"  圖片: {n_images}")
print(f"  標註: {n_annotations}")
print(f"  類別: {n_categories}")


訓練集圖片數量: 240
驗證集圖片數量: 60
訓練集標註統計:
  圖片: 240
  標註: 1040
  類別: 10


## PASCAL VOC 2012

In [None]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Preprocessing pipeline: resize to 512x512, convert to a tensor, then
# normalise each channel with the listed mean/std.
# NOTE(review): this pipeline is not handed to the dataset below
# (transform=None), so it has no effect in this cell.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Download PASCAL VOC 2012 under ./data/mini_voc_seg/.
# The next cell reads JPEGImages and SegmentationClass from the extracted
# VOCdevkit/VOC2012 directory, so this dataset object mainly serves to fetch
# and unpack the archive.
print("下載資料集...")
voc_dataset = datasets.VOCDetection(
    root='./data/mini_voc_seg/',
    year='2012',
    image_set='train',
    download=True,
    transform=None
)

下載資料集...


100%|██████████| 2.00G/2.00G [01:43<00:00, 19.2MB/s]


In [None]:
import os
import shutil
import random
from PIL import Image

def convert_voc_to_minivoc_seg():
    """Convert the downloaded VOC2012 data into the Mini-VOC-Seg layout.

    Every mask in ``SegmentationClass`` that has a matching JPEG is a
    candidate. Candidates are shuffled with a fixed seed, the first 300 are
    kept, and they are copied (renamed ``<split>_<index>.jpg/.png``) into
    ``data/mini_voc_seg/{train,val}/{images,annotations}`` as 240 train and
    60 val pairs.

    Returns:
        tuple[int, int]: Number of train pairs and number of val pairs copied.
    """
    # Location of the extracted VOC2012 archive
    voc_root = "./data/mini_voc_seg/VOCdevkit/VOC2012"
    images_dir = os.path.join(voc_root, "JPEGImages")
    seg_class_dir = os.path.join(voc_root, "SegmentationClass")

    # Collect (jpg, png, id) for every mask that has a matching image.
    # os.listdir order depends on the filesystem, so sort first; otherwise the
    # seeded shuffle below would still pick different images on another machine.
    seg_images = []
    for img_file in sorted(os.listdir(seg_class_dir)):
        if not img_file.endswith('.png'):
            continue
        # splitext strips only the extension (replace() removed every ".png")
        img_id = os.path.splitext(img_file)[0]
        jpg_path = os.path.join(images_dir, f"{img_id}.jpg")
        png_path = os.path.join(seg_class_dir, img_file)

        if os.path.exists(jpg_path):
            seg_images.append((jpg_path, png_path, img_id))

    print(f"找到 {len(seg_images)} 張有分割標註的圖片")

    # Seeded shuffle, then keep the first 300 candidates
    random.seed(42)
    random.shuffle(seg_images)
    selected_images = seg_images[:300]

    # 240 train / 60 val
    train_images = selected_images[:240]
    val_images = selected_images[240:300]

    # Directory layout required by the assignment
    os.makedirs("data/mini_voc_seg/train/images", exist_ok=True)
    os.makedirs("data/mini_voc_seg/val/images", exist_ok=True)
    os.makedirs("data/mini_voc_seg/train/annotations", exist_ok=True)
    os.makedirs("data/mini_voc_seg/val/annotations", exist_ok=True)

    def copy_split_data(image_list, split_name):
        """Copy and rename the image/mask pairs of one split."""
        for idx, (jpg_path, png_path, img_id) in enumerate(image_list):
            # New, split-local file names
            new_jpg_name = f"{split_name}_{idx:06d}.jpg"
            new_png_name = f"{split_name}_{idx:06d}.png"

            dst_jpg = os.path.join(f"data/mini_voc_seg/{split_name}/images", new_jpg_name)
            dst_png = os.path.join(f"data/mini_voc_seg/{split_name}/annotations", new_png_name)

            shutil.copy2(jpg_path, dst_jpg)
            shutil.copy2(png_path, dst_png)

            # Progress message every 50 pairs
            if (idx + 1) % 50 == 0:
                print(f"已處理 {split_name} 集 {idx + 1} 張圖片")

    print("處理訓練集...")
    copy_split_data(train_images, "train")

    print("處理驗證集...")
    copy_split_data(val_images, "val")

    print(f"\nMini-VOC-Seg 數據集創建完成！")
    print(f"訓練集: {len(train_images)} 張")
    print(f"驗證集: {len(val_images)} 張")

    return len(train_images), len(val_images)

# Run the conversion. Note that this rebinds `train_count` / `val_count`,
# names the COCO check cell above also assigns.
train_count, val_count = convert_voc_to_minivoc_seg()


找到 2913 張有分割標註的圖片
處理訓練集...
已處理 train 集 50 張圖片
已處理 train 集 100 張圖片
已處理 train 集 150 張圖片
已處理 train 集 200 張圖片
處理驗證集...
已處理 val 集 50 張圖片

Mini-VOC-Seg 數據集創建完成！
訓練集: 240 張
驗證集: 60 張


## Imagenette-160

In [None]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import os
import shutil
import random
import json
from collections import defaultdict

def download_imagenette_160():
    """Fetch Imagenette (160px) with torchvision and return both splits.

    Both splits live in the same archive, so a single root directory is used
    (``./data/imagenette_160``, the location the selection cell reads from).
    Previously the train split used a different root, which downloaded and
    extracted the archive twice.

    Returns:
        tuple: ``(train_dataset, val_dataset)`` torchvision datasets.
    """
    root = './data/imagenette_160'

    # Only ask torchvision to download when the extracted folder is missing.
    # This keeps the cell re-runnable; NOTE(review): some torchvision releases
    # appear to raise when download=True and the folder already exists -
    # confirm against the installed version.
    needs_download = not os.path.isdir(os.path.join(root, 'imagenette2-160'))

    # Training split (triggers the one-off download if needed)
    train_dataset = datasets.Imagenette(
        root=root,
        split='train',
        size='160px',
        download=needs_download,
        transform=None
    )

    # Validation split (the archive is already on disk at this point)
    val_dataset = datasets.Imagenette(
        root=root,
        split='val',
        size='160px',
        download=False,
        transform=None
    )

    print(f"Imagenette-160 訓練集: {len(train_dataset)} 張")
    print(f"Imagenette-160 驗證集: {len(val_dataset)} 張")
    print(f"類別數: {len(train_dataset.classes)}")

    return train_dataset, val_dataset

# Run the download and keep both torchvision dataset objects
train_dataset, val_dataset = download_imagenette_160()


100%|██████████| 99.0M/99.0M [00:19<00:00, 5.03MB/s]
100%|██████████| 99.0M/99.0M [00:20<00:00, 4.90MB/s]


Imagenette-160 訓練集: 9469 張
Imagenette-160 驗證集: 3925 張
類別數: 10


In [None]:
import os
import shutil
import random
from collections import defaultdict
import json

def select_imagenette_samples():
    """Pick a fixed-size random subset from the extracted Imagenette-160 folders.

    For every class in the built-in mapping, 24 training images and 6
    validation images are sampled with a fixed seed. A class with fewer images
    than requested keeps all of them and a warning is printed.

    Returns:
        tuple: ``(train_selected, val_selected, class_mapping)`` where the
        first two map a WordNet folder id to a list of image paths and the
        last maps the folder id to a readable class name.

    Raises:
        FileNotFoundError: If the train or val directory does not exist.
            (This used to print a message and return ``None``, which then
            failed with an unrelated unpacking error at the call site.)
    """
    # Fixed seed so the selection is reproducible
    random.seed(42)

    train_dir = "./data/imagenette_160/imagenette2-160/train"
    val_dir = "./data/imagenette_160/imagenette2-160/val"

    for required_dir in (train_dir, val_dir):
        if not os.path.exists(required_dir):
            raise FileNotFoundError(f"錯誤: 找不到目錄 {required_dir}")

    # WordNet folder id -> readable class name
    class_mapping = {
        'n01440764': 'tench',
        'n02102040': 'English_springer',
        'n02979186': 'cassette_player',
        'n03000684': 'chain_saw',
        'n03028079': 'church',
        'n03394916': 'French_horn',
        'n03417042': 'garbage_truck',
        'n03425413': 'gas_pump',
        'n03445777': 'golf_ball',
        'n03888257': 'parachute'
    }

    def collect_class_images(base_dir):
        """Return {folder id: [image paths]} for the known class folders."""
        class_images = defaultdict(list)
        for class_folder in sorted(os.listdir(base_dir)):
            class_path = os.path.join(base_dir, class_folder)
            if os.path.isdir(class_path) and class_folder in class_mapping:
                # Sorted so that the seeded sampling below does not depend on
                # the filesystem's directory ordering.
                for img_file in sorted(os.listdir(class_path)):
                    if img_file.lower().endswith(('.jpg', '.jpeg', '.png')):
                        class_images[class_folder].append(os.path.join(class_path, img_file))
        return class_images

    def pick(images, quota, split_label, class_name):
        """Sample `quota` images, or keep everything (with a warning) if short."""
        if len(images) >= quota:
            return random.sample(images, quota)
        print(f"警告: {class_name} {split_label}只有 {len(images)} 張圖片，少於 {quota} 張")
        return images

    print("收集訓練集圖片...")
    train_class_images = collect_class_images(train_dir)

    print("收集驗證集圖片...")
    val_class_images = collect_class_images(val_dir)

    # Per-class counts of what is available on disk
    print("\n原始數據統計:")
    for class_id in class_mapping.keys():
        train_count = len(train_class_images.get(class_id, []))
        val_count = len(val_class_images.get(class_id, []))
        print(f"{class_mapping[class_id]} ({class_id}): train={train_count}, val={val_count}")

    train_selected = {}
    val_selected = {}

    for class_id, class_name in class_mapping.items():
        # 24 training images per class
        train_selected[class_id] = pick(
            train_class_images.get(class_id, []), 24, "訓練集", class_name
        )
        # 6 validation images per class
        val_selected[class_id] = pick(
            val_class_images.get(class_id, []), 6, "驗證集", class_name
        )

    return train_selected, val_selected, class_mapping

# Run the selection; the three results feed the save step in the next cell
train_selected, val_selected, class_mapping = select_imagenette_samples()


收集訓練集圖片...
收集驗證集圖片...

原始數據統計:
tench (n01440764): train=963, val=387
English_springer (n02102040): train=955, val=395
cassette_player (n02979186): train=993, val=357
chain_saw (n03000684): train=858, val=386
church (n03028079): train=941, val=409
French_horn (n03394916): train=956, val=394
garbage_truck (n03417042): train=961, val=389
gas_pump (n03425413): train=931, val=419
golf_ball (n03445777): train=951, val=399
parachute (n03888257): train=960, val=390


In [None]:
def save_mini_imagenette_160(train_selected, val_selected, class_mapping):
    """Write the selected Imagenette images and their labels to disk.

    For each split the selected images are pooled, shuffled, copied to
    ``data/imagenette_160/<split>/images/<split>_<index>.jpg`` and described in
    ``data/imagenette_160/<split>/annotations/labels.json``. A
    ``dataset_info.json`` summary is written next to the split folders and a
    per-class report is printed.

    Args:
        train_selected: Mapping of original class id to training image paths.
        val_selected: Mapping of original class id to validation image paths.
        class_mapping: Mapping of original class id to readable class name.
    """
    # Make sure all four target folders exist
    for out_dir in (
        "data/imagenette_160/train/images",
        "data/imagenette_160/train/annotations",
        "data/imagenette_160/val/images",
        "data/imagenette_160/val/annotations",
    ):
        os.makedirs(out_dir, exist_ok=True)

    def process_split(selected_images, split_name):
        """Copy one split to disk and return (image count, per-class counts)."""
        pooled = []
        per_class = defaultdict(int)

        # Flatten {class id: paths} into one list of (path, name, id) records
        for wnid, paths in selected_images.items():
            label = class_mapping[wnid]
            for src_path in paths:
                pooled.append((src_path, label, wnid))
                per_class[label] += 1

        # Mix the classes so file indices carry no class information
        random.shuffle(pooled)

        # Class indices follow the alphabetical order of the class names
        ordered_names = sorted(class_mapping.values())
        class_to_idx = dict(zip(ordered_names, range(len(ordered_names))))

        records = []
        for position, (src_path, label, wnid) in enumerate(pooled, start=1):
            img_filename = f"{split_name}_{position - 1:06d}.jpg"
            dst_path = os.path.join(f"data/imagenette_160/{split_name}/images", img_filename)
            shutil.copy2(src_path, dst_path)

            records.append({
                "filename": img_filename,
                "class_name": label,
                "class_id": class_to_idx[label],
                "original_class_id": wnid,
                "original_path": src_path
            })

            # Progress message every 30 images
            if position % 30 == 0:
                print(f"已處理 {split_name} 集 {position} 張圖片")

        total = len(pooled)
        labels_data = {
            "split": split_name,
            "num_images": total,
            "classes": ordered_names,
            "class_to_idx": class_to_idx,
            "class_distribution": dict(per_class),
            "labels": records
        }

        with open(f"data/imagenette_160/{split_name}/annotations/labels.json", "w") as f:
            json.dump(labels_data, f, indent=2)

        return total, per_class

    print("\n處理訓練集...")
    train_count, train_class_counts = process_split(train_selected, "train")

    print("處理驗證集...")
    val_count, val_class_counts = process_split(val_selected, "val")

    # Dataset-level summary file
    dataset_info = {
        "name": "Imagenette_160",
        "description": "Selected subset from Imagenette-160 for multi-task learning",
        "total_images": train_count + val_count,
        "train_images": train_count,
        "val_images": val_count,
        "num_classes": len(class_mapping),
        "classes": sorted(class_mapping.values()),
        "selection_strategy": "24 per class for train, 6 per class for val",
        "train_class_distribution": dict(train_class_counts),
        "val_class_distribution": dict(val_class_counts)
    }

    with open("data/imagenette_160/dataset_info.json", "w") as f:
        json.dump(dataset_info, f, indent=2)

    print(f"\nMini-Imagenette-160 創建完成！")
    print(f"訓練集: {train_count} 張圖片")
    print(f"驗證集: {val_count} 張圖片")

    # Per-class report, classes in alphabetical order
    for header, counts in (
        ("\n訓練集類別分佈:", train_class_counts),
        ("\n驗證集類別分佈:", val_class_counts),
    ):
        print(header)
        for class_name, count in sorted(counts.items()):
            print(f"  {class_name}: {count} 張")

# Write the selected images and label files to disk
save_mini_imagenette_160(train_selected, val_selected, class_mapping)



處理訓練集...
已處理 train 集 30 張圖片
已處理 train 集 60 張圖片
已處理 train 集 90 張圖片
已處理 train 集 120 張圖片
已處理 train 集 150 張圖片
已處理 train 集 180 張圖片
已處理 train 集 210 張圖片
已處理 train 集 240 張圖片
處理驗證集...
已處理 val 集 30 張圖片
已處理 val 集 60 張圖片

Mini-Imagenette-160 創建完成！
訓練集: 240 張圖片
驗證集: 60 張圖片

訓練集類別分佈:
  English_springer: 24 張
  French_horn: 24 張
  cassette_player: 24 張
  chain_saw: 24 張
  church: 24 張
  garbage_truck: 24 張
  gas_pump: 24 張
  golf_ball: 24 張
  parachute: 24 張
  tench: 24 張

驗證集類別分佈:
  English_springer: 6 張
  French_horn: 6 張
  cassette_player: 6 張
  chain_saw: 6 張
  church: 6 張
  garbage_truck: 6 張
  gas_pump: 6 張
  golf_ball: 6 張
  parachute: 6 張
  tench: 6 張


In [None]:
def verify_final_dataset(data_root="data/imagenette_160"):
    """Print a size and label report for the mini Imagenette dataset.

    The report looks at the layout the save step produces:
    ``<data_root>/<split>/images/*.jpg`` and
    ``<data_root>/<split>/annotations/labels.json``. The previous version
    inspected ``data/imagenette-160`` (hyphen) and the split root itself, so it
    never found anything and always reported 0.0 MB.

    Args:
        data_root: Dataset root directory. Defaults to the directory the save
            step writes to.
    """
    print("驗證 Mini-Imagenette-160 數據集:")

    total_size = 0
    for split in ['train', 'val']:
        split_dir = os.path.join(data_root, split)

        if not os.path.isdir(split_dir):
            # Say so explicitly instead of silently reporting an empty dataset
            print(f"\n警告: 找不到目錄 {split_dir}")
            continue

        # Count the image files of this split
        images_dir = os.path.join(split_dir, "images")
        image_files = []
        if os.path.isdir(images_dir):
            image_files = [f for f in os.listdir(images_dir) if f.endswith('.jpg')]

        # Size of every regular file below the split directory
        split_size = sum(
            os.path.getsize(os.path.join(dirpath, filename))
            for dirpath, _, filenames in os.walk(split_dir)
            for filename in filenames
        )
        total_size += split_size

        print(f"\n{split} 集:")
        print(f"  圖片數量: {len(image_files)}")
        print(f"  大小: {split_size / (1024*1024):.1f} MB")

        # Label file summary
        labels_file = os.path.join(split_dir, "annotations", "labels.json")
        if os.path.exists(labels_file):
            with open(labels_file, 'r') as f:
                labels_data = json.load(f)
            print(f"  類別數: {len(labels_data['classes'])}")
            print(f"  每類分佈: {labels_data['class_distribution']}")

    print(f"\n總大小: {total_size / (1024*1024):.1f} MB")
    print(f"符合作業要求: ≈25 MB, 240 train + 60 val")

# Print the dataset verification report
verify_final_dataset()

驗證 Mini-Imagenette-160 數據集:

總大小: 0.0 MB
符合作業要求: ≈25 MB, 240 train + 60 val


# Training

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import mobilenet_v3_small
import torchvision.transforms as transforms

class MultiTaskHead(nn.Module):
    """Shared head producing detection, segmentation and classification outputs.

    Args:
        in_channels: Channel count of the incoming feature map.
        num_det_classes: Number of detection classes.
        num_seg_classes: Number of segmentation classes.
        num_cls_classes: Number of image-classification classes.
    """

    def __init__(self, in_channels, num_det_classes=10, num_seg_classes=21, num_cls_classes=10):
        super().__init__()

        # Feature trunk shared by all three tasks
        self.shared_conv = nn.Sequential(
            nn.Conv2d(in_channels, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        # Segmentation branch: x32 bilinear upsampling plus a 1x1 classifier.
        # The module layout is kept as-is so existing state_dict keys still load.
        self.seg_upsample = nn.Sequential(
            nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True),
            nn.Conv2d(128, num_seg_classes, 1)  # one output channel per class
        )

        # Detection and classification share a single 1x1 conv:
        # 6 box/confidence channels + detection classes + classification logits
        self.det_cls_conv = nn.Conv2d(128, 6 + num_det_classes + num_cls_classes, 1)

        self.num_det_classes = num_det_classes
        self.num_cls_classes = num_cls_classes

    def forward(self, x):
        """Return ``(det_output, seg_output, cls_output)`` for a feature map ``x``.

        ``det_output`` keeps the spatial size of ``x``, ``seg_output`` is 32x
        larger in both spatial dimensions, and ``cls_output`` is one logit
        vector per sample.
        """
        shared_feat = self.shared_conv(x)

        # Apply the 1x1 classifier BEFORE upsampling. A 1x1 conv is a per-pixel
        # affine map and bilinear interpolation is a weighted average whose
        # weights sum to one, so the two commute up to float rounding. Doing it
        # this way avoids materialising a 128-channel full-resolution tensor.
        upsample, seg_classifier = self.seg_upsample[0], self.seg_upsample[1]
        seg_output = upsample(seg_classifier(shared_feat))

        # Split the joint detection/classification map by channel
        det_cls_output = self.det_cls_conv(shared_feat)
        split_at = 6 + self.num_det_classes
        det_output = det_cls_output[:, :split_at]
        # Global average pooling turns the classification channels into logits
        cls_output = F.adaptive_avg_pool2d(det_cls_output[:, split_at:], (1, 1)).flatten(1)

        return det_output, seg_output, cls_output


class MultiTaskModel(nn.Module):
    """MobileNetV3-Small backbone, a small conv neck and the multi-task head.

    Args:
        pretrained: When True (default, the previous behaviour) the backbone is
            initialised with the ImageNet weights; pass False to build the
            model without fetching weights.
    """

    def __init__(self, pretrained=True):
        super().__init__()

        # `pretrained=True` is deprecated in torchvision; the `weights` argument
        # selects the same ImageNet checkpoint explicitly.
        backbone_weights = "IMAGENET1K_V1" if pretrained else None
        self.backbone = mobilenet_v3_small(weights=backbone_weights)
        self.backbone.classifier = nn.Identity()  # drop the ImageNet classifier

        # Neck: reduce the 576 backbone channels to 256
        self.neck = nn.Sequential(
            nn.Conv2d(576, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Head shared by the three tasks
        self.head = MultiTaskHead(256)

    def forward(self, x):
        """Return ``(det_out, seg_out, cls_out)`` for an image batch ``x``."""
        # Only the convolutional feature extractor of the backbone is used
        features = self.backbone.features(x)
        neck_feat = self.neck(features)
        det_out, seg_out, cls_out = self.head(neck_feat)
        return det_out, seg_out, cls_out


In [None]:
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torchvision.ops import box_iou  # 添加這行導入

class EWC:
    """Elastic Weight Consolidation penalty computed over several tasks.

    On construction the current trainable parameters are snapshotted and a
    diagonal Fisher estimate is accumulated from the squared gradients of each
    task's loss. ``penalty`` then measures how far a model has moved from the
    snapshot, weighted by that estimate.

    Args:
        model: Network returning ``(det_out, seg_out, cls_out)``.
        task_loaders: Dict mapping a task name to an iterable of
            ``(data, targets)`` batches, e.g.
            ``{'seg': seg_loader, 'det': det_loader, 'cls': cls_loader}``.
            The long names used by the trainer ('segmentation', 'detection',
            'classification') are accepted as well.
        importance: Multiplier applied to the penalty.
    """

    # Accepted task keys -> canonical short name
    _TASK_ALIASES = {
        'seg': 'seg', 'segmentation': 'seg',
        'det': 'det', 'detection': 'det',
        'cls': 'cls', 'classification': 'cls',
    }

    def __init__(self, model, task_loaders, importance=1000):
        self.model = model
        self.importance = importance
        # Snapshot of the parameters the penalty pulls the model back towards
        self.params = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
        self.fisher = self._compute_multitask_fisher(task_loaders)

    def _model_device(self):
        """Device to run on: ``model.device`` if set, else the parameters' device."""
        device = getattr(self.model, 'device', None)
        if device is None:
            # Plain nn.Module has no `.device`; fall back to where its weights live
            device = next(self.model.parameters()).device
        return device

    def _compute_multitask_fisher(self, task_loaders):
        """Accumulate squared batch gradients per parameter, averaged over batches.

        Leaves the model in eval mode, as before.

        Raises:
            TypeError: If ``task_loaders`` is not a dict.
            ValueError: If a task name is not recognised.
        """
        # Validate the input before touching the model
        if not isinstance(task_loaders, dict):
            raise TypeError("task_loaders必須是字典格式，例如: {'seg': seg_loader, ...}")
        for task_name in task_loaders:
            if task_name not in self._TASK_ALIASES:
                # Previously an unknown name fell through every branch and
                # failed later on an undefined `loss`.
                raise ValueError(f"不支援的任務類型: {task_name}")

        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters() if p.requires_grad}
        device = self._model_device()

        self.model.eval()

        for task_name, loader in task_loaders.items():
            task = self._TASK_ALIASES[task_name]
            print(f"正在計算 {task_name} 任務的Fisher信息...")
            for data, targets in tqdm(loader, desc=task_name):
                self.model.zero_grad()
                data = data.to(device)

                if task == 'det':
                    # Detection targets are a list of dicts; move every tensor
                    # entry and actually pass the moved copies to the loss.
                    targets = [
                        {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in t.items()}
                        for t in targets
                    ]
                    det_out, _, _ = self.model(data)
                    loss = compute_detection_loss(det_out, targets, 10)
                elif task == 'seg':
                    targets = targets.to(device)
                    _, seg_out, _ = self.model(data)
                    loss = F.cross_entropy(seg_out, targets, ignore_index=255)
                else:  # 'cls'
                    targets = targets.to(device)
                    _, _, cls_out = self.model(data)
                    loss = F.cross_entropy(cls_out, targets)

                loss.backward()

                # Parameters untouched by this task's loss have no gradient
                # and contribute nothing.
                for n, p in self.model.named_parameters():
                    if p.requires_grad and p.grad is not None:
                        fisher[n] += p.grad.detach() ** 2

        # Average over all processed batches (guard against empty loaders)
        total_batches = sum(len(loader) for loader in task_loaders.values())
        if total_batches > 0:
            for n in fisher:
                fisher[n] /= total_batches

        # Do not leak the last Fisher gradient into the next optimisation step
        self.model.zero_grad()

        return fisher

    def penalty(self, model):
        """Return ``importance * sum(F * (theta - theta_snapshot)^2)`` for ``model``."""
        loss = 0
        for n, p in model.named_parameters():
            if n in self.fisher:
                loss += (self.fisher[n] * (p - self.params[n]) ** 2).sum()
        return self.importance * loss

class MultiTaskTrainer:
    """Sequential multi-task trainer with an optional EWC regulariser.

    Args:
        model: Network returning ``(det_out, seg_out, cls_out)``.
        device: Device that batches are moved to before the forward pass.
    """

    _VALID_TASKS = ('segmentation', 'detection', 'classification')

    def __init__(self, model, device):
        self.model = model
        self.model.device = device  # exposed on the model so EWC can read it
        self.device = device
        self.ewc = None
        self.task_fishers = {}  # reserved for per-task EWC state

    def train_stage(self, dataloader, task_type, num_epochs=60, lr=1e-4):
        """Train the model on one task for ``num_epochs`` epochs.

        Args:
            dataloader: Iterable of ``(data, target)`` batches that supports
                ``len()``. For detection, ``target`` is a list of dicts.
            task_type: 'segmentation', 'detection' or 'classification'.
            num_epochs: Number of passes over ``dataloader``.
            lr: Adam learning rate.

        Raises:
            ValueError: If ``task_type`` is not one of the supported names
                (previously this surfaced as an undefined-variable error).
        """
        if task_type not in self._VALID_TASKS:
            raise ValueError(f"不支援的任務類型: {task_type}")

        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.model.train()
        # Avoid dividing by zero when reporting the mean loss of an empty loader
        num_batches = max(len(dataloader), 1)

        for epoch in tqdm(range(num_epochs), desc=f"Training {task_type}"):
            total_loss = 0.0
            for data, target in dataloader:
                data = data.to(self.device)
                if task_type == 'detection':
                    # Move every tensor of each per-image target dict and use
                    # the moved copies for the loss below.
                    target = [
                        {k: (v.to(self.device) if torch.is_tensor(v) else v) for k, v in t.items()}
                        for t in target
                    ]
                else:
                    target = target.to(self.device)

                optimizer.zero_grad()
                det_out, seg_out, cls_out = self.model(data)

                # Task-specific loss
                if task_type == 'segmentation':
                    loss = F.cross_entropy(seg_out, target, ignore_index=255)
                elif task_type == 'detection':
                    loss = compute_detection_loss(det_out, target, 10)
                else:  # 'classification'
                    loss = F.cross_entropy(cls_out, target)

                # EWC regularisation towards previously learned tasks
                if self.ewc is not None:
                    loss = loss + self.ewc.penalty(self.model)

                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            print(f'{task_type} Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/num_batches:.4f}')

    def set_ewc(self, task_loaders):
        """Recompute the EWC state from ``task_loaders`` (a dict of loaders)."""
        self.ewc = EWC(self.model, task_loaders)

def compute_detection_loss(predictions, targets, num_classes):
    """Grid-based detection loss with one-to-one (Hungarian) matching.

    Channel layout of ``predictions`` as used here: 0-3 box values,
    4 objectness logit, 6 onwards class logits (channel 5 is not used).

    Args:
        predictions: Tensor of shape [B, 6 + num_classes, H, W].
        targets: Sequence of B dicts with 'boxes' ([N, 4]) and 'labels' ([N]).
        num_classes: Number of detection classes.

    Returns:
        Scalar tensor: mean over the batch of box + confidence + class loss.
        Always a tensor on the prediction device, so ``.backward()`` works even
        when no image in the batch has boxes (previously a Python float).
    """
    batch_size = predictions.shape[0]
    device = predictions.device
    total_loss = predictions.new_zeros(())

    if batch_size == 0:
        return total_loss

    # The cell-coordinate grid is identical for every image: build it once.
    H, W = predictions.shape[-2], predictions.shape[-1]
    y_grid, x_grid = torch.meshgrid(
        torch.arange(H, device=device),
        torch.arange(W, device=device),
        indexing='ij'
    )
    grid_centers = torch.stack([x_grid, y_grid], dim=-1).float()

    for i in range(batch_size):
        pred = predictions[i]  # [6+num_classes, H, W]
        target = targets[i]

        target_boxes = target['boxes'].to(device)
        target_labels = target['labels'].to(device)

        conf_logits = pred[4].reshape(-1)
        conf_target = torch.zeros_like(conf_logits)

        match_indices = None
        if len(target_boxes) > 0:
            # Decode boxes: cell index + predicted offset, normalised by the
            # grid size; the last two channels are used as predicted.
            pred_boxes = pred[:4].permute(1, 2, 0)  # [H, W, 4]
            pred_boxes_abs = torch.stack([
                (grid_centers[..., 0] + pred_boxes[..., 0]) / W,
                (grid_centers[..., 1] + pred_boxes[..., 1]) / H,
                pred_boxes[..., 2],
                pred_boxes[..., 3],
            ], dim=-1)
            flat_boxes = pred_boxes_abs.reshape(-1, 4)

            # NOTE(review): box_iou expects corner-format (x1, y1, x2, y2)
            # boxes; confirm that both the decoded predictions and the target
            # boxes use that format and the same scale.
            ious = box_iou(flat_boxes, target_boxes)
            match_indices = hungarian_matching(ious)

        if match_indices is not None and len(match_indices) > 0:
            cell_idx, gt_idx = match_indices[:, 0], match_indices[:, 1]

            # Box regression on matched cells
            box_loss = F.mse_loss(flat_boxes[cell_idx], target_boxes[gt_idx])

            # Matched cells are the positives for the confidence loss
            conf_target[cell_idx] = 1

            # Classification on matched cells
            pred_cls = pred[6:].permute(1, 2, 0).reshape(-1, num_classes)
            cls_loss = F.cross_entropy(pred_cls[cell_idx], target_labels[gt_idx])
        else:
            # Image without boxes: every cell is background. Such images used
            # to be skipped, which left this confidence-only branch unreachable.
            box_loss = conf_logits.new_zeros(())
            cls_loss = conf_logits.new_zeros(())

        # Logits form of BCE: same quantity as sigmoid + BCE, numerically stabler
        conf_loss = F.binary_cross_entropy_with_logits(conf_logits, conf_target)

        total_loss = total_loss + box_loss + conf_loss + cls_loss

    return total_loss / batch_size

def hungarian_matching(cost_matrix):
    """Find the optimal one-to-one assignment for a pairwise score matrix.

    The input is treated as a similarity matrix (the caller passes IoU
    values): it is converted to a cost with ``1 - cost_matrix`` and solved
    on the CPU with SciPy's ``linear_sum_assignment``.

    Args:
        cost_matrix: 2-D tensor of pairwise scores (rows x columns).

    Returns:
        ``LongTensor`` of shape ``[K, 2]`` on the same device as the input;
        each row is one matched ``(row_index, column_index)`` pair.
    """
    from scipy.optimize import linear_sum_assignment

    device = cost_matrix.device
    # SciPy operates on NumPy arrays, so detach and copy to host memory first.
    costs = (1 - cost_matrix).detach().cpu().numpy()
    rows, cols = linear_sum_assignment(costs)

    matched_rows = torch.tensor(rows, dtype=torch.long).to(device)
    matched_cols = torch.tensor(cols, dtype=torch.long).to(device)
    return torch.stack((matched_rows, matched_cols), dim=1)


In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import json
import torchvision.transforms as transforms
from pycocotools.coco import COCO
import numpy as np

class MultiTaskDataset(Dataset):
    """Dataset that serves one of three tasks from a shared data root.

    Depending on ``task_type`` the samples are read from:

    * ``'segmentation'``   -> ``<data_root>/mini_voc_seg/<split>``
    * ``'detection'``      -> ``<data_root>/mini_coco_det/<split>``
    * ``'classification'`` -> ``<data_root>/imagenette_160/<split>``

    Every split directory is expected to contain an ``images`` and an
    ``annotations`` sub-directory. ``__getitem__`` returns ``(image, target)``
    where the target is a mask (segmentation), a dict with ``boxes``,
    ``labels`` and ``image_id`` (detection) or a class-index tensor
    (classification).
    """

    def __init__(self, data_root, task_type, split='train', transform=None,
                 mask_size=(512, 512)):
        """Index the samples of one task.

        Args:
            data_root: Directory containing the three dataset folders.
            task_type: ``'segmentation'``, ``'detection'`` or
                ``'classification'``.
            split: Sub-directory of the dataset to read (e.g. ``'train'``).
            transform: Optional callable applied to the PIL image only.
            mask_size: ``(width, height)`` that segmentation masks are resized
                to when ``transform`` is set; it should match the size the
                transform produces for the image.

        Raises:
            ValueError: If ``task_type`` is not one of the supported tasks.
            FileNotFoundError: If the dataset directory or its annotation
                file is missing.
        """
        self.data_root = data_root
        self.task_type = task_type
        self.split = split
        self.transform = transform
        self.mask_size = mask_size
        self.samples = []

        # Load the sample index that corresponds to the requested task.
        if task_type == 'segmentation':
            self.data_dir = os.path.join(data_root, 'mini_voc_seg', split)
            self._load_segmentation_samples()
        elif task_type == 'detection':
            self.data_dir = os.path.join(data_root, 'mini_coco_det', split)
            self._load_detection_samples()
        elif task_type == 'classification':
            self.data_dir = os.path.join(data_root, 'imagenette_160', split)
            self._load_classification_samples()
        else:
            raise ValueError(f"不支援的任務類型: {task_type}")

    def _load_segmentation_samples(self):
        """Index segmentation samples: each ``.jpg`` image with a ``.png`` mask."""
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"找不到目錄: {self.data_dir}")

        image_dir = os.path.join(self.data_dir, "images")
        mask_dir = os.path.join(self.data_dir, "annotations")

        # os.listdir returns entries in arbitrary order; sort so the sample
        # order is reproducible across runs and machines.
        for filename in sorted(os.listdir(image_dir)):
            if not filename.endswith('.jpg'):
                continue

            # Swap only the trailing extension; str.replace would also rewrite
            # a '.jpg' that happens to occur earlier in the name.
            mask_filename = filename[:-len('.jpg')] + '.png'
            mask_path = os.path.join(mask_dir, mask_filename)

            # Images without a matching mask are skipped.
            if os.path.exists(mask_path):
                self.samples.append({
                    'image_path': os.path.join(image_dir, filename),
                    'mask_path': mask_path,
                    'filename': filename
                })

        print(f"載入分割樣本: {len(self.samples)} 個")

    def _load_detection_samples(self):
        """Index detection samples from a COCO-format ``annotations.json``."""
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"找不到目錄: {self.data_dir}")

        annotation_file = os.path.join(self.data_dir, "annotations", 'annotations.json')
        if not os.path.exists(annotation_file):
            raise FileNotFoundError(f"找不到標註文件: {annotation_file}")

        self.coco = COCO(annotation_file)
        self.image_ids = list(self.coco.imgs.keys())

        for img_id in self.image_ids:
            img_info = self.coco.imgs[img_id]
            img_path = os.path.join(self.data_dir, "images", img_info['file_name'])

            # Collect every annotation that belongs to this image.
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            annotations = self.coco.loadAnns(ann_ids)

            self.samples.append({
                'image_path': img_path,
                'image_id': img_id,
                'annotations': annotations,
                'image_info': img_info
            })

        print(f"載入檢測樣本: {len(self.samples)} 個")

    def _load_classification_samples(self):
        """Index classification samples listed in ``annotations/labels.json``.

        The JSON file must provide ``classes``, ``class_to_idx`` and a
        ``labels`` list whose entries carry ``filename``, ``class_name`` and
        ``class_id``.
        """
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"找不到目錄: {self.data_dir}")

        labels_file = os.path.join(self.data_dir, "annotations", 'labels.json')
        if not os.path.exists(labels_file):
            raise FileNotFoundError(f"找不到標籤文件: {labels_file}")

        with open(labels_file, 'r') as f:
            labels_data = json.load(f)

        self.classes = labels_data['classes']
        self.class_to_idx = labels_data['class_to_idx']

        for label_info in labels_data['labels']:
            img_path = os.path.join(self.data_dir, "images", label_info['filename'])

            self.samples.append({
                'image_path': img_path,
                'class_name': label_info['class_name'],
                'class_id': label_info['class_id'],
                'filename': label_info['filename']
            })

        print(f"載入分類樣本: {len(self.samples)} 個")

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, target)`` for the task."""
        sample = self.samples[idx]

        # convert() returns a fully loaded copy, so the file handle can be
        # released immediately instead of waiting for garbage collection.
        with Image.open(sample['image_path']) as raw_image:
            image = raw_image.convert('RGB')

        if self.task_type == 'segmentation':
            with Image.open(sample['mask_path']) as raw_mask:
                if self.transform:
                    image = self.transform(image)
                    # The mask holds class indices: it must not be normalised,
                    # and nearest-neighbour resizing avoids inventing labels.
                    mask = raw_mask.resize(self.mask_size, Image.NEAREST)
                    mask = torch.from_numpy(np.array(mask)).long()
                else:
                    mask = np.array(raw_mask)
            return image, mask

        elif self.task_type == 'detection':
            boxes = []
            labels = []
            for ann in sample['annotations']:
                # COCO stores [x, y, width, height]; convert to [x1, y1, x2, y2].
                x1, y1, w, h = ann['bbox']
                boxes.append([x1, y1, x1 + w, y1 + h])
                # NOTE(review): category ids are used verbatim as class labels;
                # confirm the annotation file uses the ids the loss expects.
                labels.append(ann['category_id'])

            # Build each tensor exactly once. Re-wrapping an existing tensor in
            # torch.tensor() triggers a copy-construct UserWarning.
            # reshape(-1, 4) yields shape (0, 4) for images without annotations.
            boxes = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)
            labels = torch.tensor(labels, dtype=torch.int64)

            if self.transform:
                # The transform only sees the image, so the boxes are rescaled
                # here from the original size to the transformed size.
                orig_w, orig_h = image.size
                image = self.transform(image)
                # Assumes the transform returns a (C, H, W) tensor.
                new_h, new_w = image.shape[1], image.shape[2]
                if len(boxes) > 0:
                    boxes[:, [0, 2]] *= new_w / orig_w
                    boxes[:, [1, 3]] *= new_h / orig_h

            target = {
                'boxes': boxes,
                'labels': labels,
                'image_id': torch.tensor([sample['image_id']], dtype=torch.int64)
            }

            return image, target

        elif self.task_type == 'classification':
            label = sample['class_id']

            if self.transform:
                image = self.transform(image)

            return image, torch.tensor(label, dtype=torch.long)

# 不同任務的資料轉換
def get_transforms(task_type, split='train'):
    """Build the torchvision image pipeline for a task.

    Every task resizes to 512x512, converts to a tensor and normalises with
    mean ``[0.485, 0.456, 0.406]`` / std ``[0.229, 0.224, 0.225]``. Only
    classification training additionally applies a random horizontal flip:
    the pipeline is applied to the image alone, so a flip in the segmentation
    or detection pipeline would not be mirrored in the mask or boxes.

    Args:
        task_type: ``'classification'``, ``'segmentation'`` or ``'detection'``.
        split: Dataset split; only ``'train'`` enables augmentation.

    Returns:
        A ``transforms.Compose`` pipeline.

    Raises:
        ValueError: If ``task_type`` is not supported (previously the function
            silently returned ``None``).
    """
    if task_type not in ('classification', 'segmentation', 'detection'):
        raise ValueError(f"不支援的任務類型: {task_type}")

    steps = [transforms.Resize((512, 512))]
    if task_type == 'classification' and split == 'train':
        steps.append(transforms.RandomHorizontalFlip(p=0.5))
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    )
    return transforms.Compose(steps)


In [None]:
def evaluate_segmentation(model, dataloader, device):
    """Average the per-batch IoU score of the segmentation head.

    The model is switched to eval mode and run without gradients. For every
    batch the arg-max class map of the segmentation output is scored with
    ``compute_iou``; the returned value is the mean of those batch scores, so
    each batch carries equal weight regardless of its size.

    Args:
        model: Callable returning ``(detection, segmentation, classification)``
            outputs; only the segmentation output is used.
        dataloader: Iterable of ``(images, masks)`` batches.
        device: Device the batches are moved to.

    Returns:
        Mean IoU score over batches, or ``0.0`` when the dataloader yields no
        batches (previously this raised ``ZeroDivisionError``).
    """
    model.eval()
    total_iou = 0
    num_batches = 0

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            _, seg_out, _ = model(data)

            # Class with the highest score along the channel dimension.
            pred = torch.argmax(seg_out, dim=1)
            total_iou += compute_iou(pred, target)
            num_batches += 1

    if num_batches == 0:
        return 0.0
    return total_iou / num_batches



def compute_iou(pred, target, ignore_index=255, include_background=False):
    """Mean intersection-over-union between two class-index maps.

    The previous implementation applied bitwise ``&`` / ``|`` directly to the
    integer label maps, which is only an intersection/union for 0/1 masks; for
    multi-class predictions it mixes up the bits of unrelated class ids. This
    version computes IoU per class and averages over the classes that occur
    in either input. For 0/1 masks the result is identical to the old one.

    Args:
        pred: Integer (or bool) tensor of predicted class indices.
        target: Tensor of ground-truth class indices, same shape as ``pred``.
        ignore_index: Target label whose pixels are excluded from scoring, or
            ``None`` to score every pixel. Defaults to 255.
            NOTE(review): 255 is assumed to be the "void" label of the mask
            files - confirm against the dataset.
        include_background: When ``False`` (default) class 0 is not averaged
            in, which keeps the binary-mask behaviour of the old code.

    Returns:
        0-dim tensor with the mean IoU, or ``0.0`` when no class is scored.
    """
    pred = pred.reshape(-1).long()
    target = target.reshape(-1).long()

    if ignore_index is not None:
        keep = target != ignore_index
        pred, target = pred[keep], target[keep]

    classes = torch.unique(torch.cat([pred, target]))
    if not include_background:
        classes = classes[classes != 0]

    per_class = []
    for cls in classes:
        pred_mask = pred == cls
        target_mask = target == cls
        intersection = (pred_mask & target_mask).float().sum()
        # Every class in `classes` occurs in at least one map, so union > 0.
        union = (pred_mask | target_mask).float().sum()
        per_class.append(intersection / union)

    if not per_class:
        return 0.0
    return torch.stack(per_class).mean()


In [None]:
import torch
import numpy as np
from collections import defaultdict

def evaluate_detection(model, dataloader, device, iou_threshold=0.5):
    """Compute the mAP of the detection head over a dataloader.

    The model is run in eval mode without gradients. Each image's raw
    detection map is decoded with ``parse_detection_output`` and the decoded
    predictions are scored against the ground truth with ``compute_map``.

    Args:
        model: Callable returning ``(detection, segmentation, classification)``
            outputs; only the detection output is used.
        dataloader: Iterable of ``(images, targets)`` batches. ``targets`` is
            normally a list (or tuple) with one target dict per image.
        device: Device the images are moved to.
        iou_threshold: IoU needed for a prediction to count as a match.

    Returns:
        The value returned by ``compute_map``.
    """
    model.eval()
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for data, targets in dataloader:
            data = data.to(device)

            det_out, _, _ = model(data)
            batch_size = det_out.shape[0]

            for i in range(batch_size):
                pred_boxes, pred_scores, pred_labels = parse_detection_output(det_out[i])

                # Keep accumulated predictions on the CPU: the ground truth
                # never leaves the CPU, and holding every decoded box of the
                # whole dataset on the accelerator wastes device memory and
                # makes later element-wise comparisons cross devices.
                all_predictions.append({
                    'boxes': pred_boxes.cpu(),
                    'scores': pred_scores.cpu(),
                    'labels': pred_labels.cpu()
                })

                # A per-image sequence is indexed; anything else is passed
                # through unchanged, as before.
                target = targets[i] if isinstance(targets, (list, tuple)) else targets
                all_targets.append(target)

    return compute_map(all_predictions, all_targets, iou_threshold)

def parse_detection_output(det_output, conf_threshold=0.5, num_classes=10, image_size=512):
    """Decode one image's dense detection map into boxes, scores and labels.

    Channel layout of ``det_output`` (``[C, H, W]``):

    * 0-3: ``cx, cy, w, h`` expressed in grid-cell units, with ``cx``/``cy``
      relative to the cell's own index;
    * 4: objectness logit;
    * last ``num_classes`` channels: class logits.

    The class logits are taken from the *end* of the channel axis so that a
    ``6 + num_classes`` channel head decodes channels ``6:``, the same slice
    the training loss classifies on. Reading ``5:5+num_classes`` on such a
    head shifts every predicted label by one. A ``5 + num_classes`` channel
    head still decodes channels ``5:`` as before.

    Args:
        det_output: Tensor of shape ``[C, H, W]`` for a single image.
        conf_threshold: Minimum sigmoid objectness for a cell to be kept.
        num_classes: Number of class channels.
        image_size: Side length in pixels of the (square) network input that
            the grid is mapped back to.

    Returns:
        Tuple ``(boxes, scores, labels)``: ``boxes`` is ``[K, 4]`` in
        ``[x1, y1, x2, y2]`` pixel coordinates, ``scores`` is ``[K]`` and
        ``labels`` is ``[K]`` (class indices), for the ``K`` cells kept.
    """
    cx = det_output[0]
    cy = det_output[1]
    w = det_output[2]
    h = det_output[3]
    conf = det_output[4].sigmoid()
    cls_logits = det_output[-num_classes:]  # [num_classes, H, W]

    # Integer coordinates of every grid cell.
    grid_h, grid_w = cx.shape
    y_grid, x_grid = torch.meshgrid(
        torch.arange(grid_h, device=det_output.device),
        torch.arange(grid_w, device=det_output.device),
        indexing='ij'
    )

    # Separate factors per axis so non-square feature maps decode correctly;
    # for square maps both equal the previous single scale.
    scale_x = image_size / grid_w
    scale_y = image_size / grid_h
    x1 = (x_grid + cx - w / 2) * scale_x
    y1 = (y_grid + cy - h / 2) * scale_y
    x2 = (x_grid + cx + w / 2) * scale_x
    y2 = (y_grid + cy + h / 2) * scale_y

    # Flatten the grid in row-major order.
    boxes = torch.stack([x1, y1, x2, y2], dim=-1).reshape(-1, 4)  # [H*W, 4]
    confidences = conf.reshape(-1)                                # [H*W]
    # Softmax is monotonic, so the arg-max of the logits is the arg-max class.
    class_ids = cls_logits.argmax(dim=0).reshape(-1)              # [H*W]

    # Drop low-confidence cells.
    mask = confidences > conf_threshold
    return boxes[mask], confidences[mask], class_ids[mask]

def compute_map(predictions, targets, iou_threshold=0.5):
    """Mean of the per-class average precision.

    Classes are taken from the ground-truth ``labels`` only. For each class
    the matching predictions and ground-truth boxes of every image are
    gathered (the position in the input lists serves as image id) and scored
    with ``compute_ap``.

    Args:
        predictions: One dict per image with ``boxes``, ``scores``, ``labels``.
        targets: One dict per image with ``boxes`` and ``labels``.
        iou_threshold: IoU threshold forwarded to ``compute_ap``.

    Returns:
        Mean AP over the ground-truth classes, or ``0.0`` if there are none.
    """
    # Every class id that appears in the ground truth.
    class_ids = set()
    for gt in targets:
        if 'labels' in gt:
            class_ids.update(gt['labels'].tolist())

    per_class_ap = []
    for class_id in class_ids:
        dets_for_class = []
        gts_for_class = []

        for image_idx, (det, gt) in enumerate(zip(predictions, targets)):
            # Predictions of this class in the current image.
            if len(det['labels']) > 0:
                det_keep = det['labels'] == class_id
                if det_keep.any():
                    dets_for_class.append({
                        'image_id': image_idx,
                        'boxes': det['boxes'][det_keep],
                        'scores': det['scores'][det_keep]
                    })

            # Ground-truth boxes of this class in the current image.
            if 'labels' in gt and len(gt['labels']) > 0:
                gt_keep = gt['labels'] == class_id
                if gt_keep.any():
                    gts_for_class.append({
                        'image_id': image_idx,
                        'boxes': gt['boxes'][gt_keep]
                    })

        per_class_ap.append(compute_ap(dets_for_class, gts_for_class, iou_threshold))

    if not per_class_ap:
        return 0.0
    return np.mean(per_class_ap)

def compute_ap(predictions, targets, iou_threshold):
    """11-point interpolated Average Precision for a single class.

    Predictions are ranked by score (highest first). Each prediction is
    compared with the ground-truth boxes of its own image and assigned to the
    box it overlaps most. It is a true positive only if that IoU reaches
    ``iou_threshold`` *and* the box has not already been claimed by a
    higher-scoring prediction; otherwise it is a false positive. Without the
    "already claimed" rule, duplicate detections of one object would each
    count as a hit, letting recall exceed 1 and inflating AP.

    All boxes are copied to CPU float tensors first, so predictions and
    ground truth may live on different devices.

    Args:
        predictions: List of dicts with ``image_id``, ``boxes`` (``[N, 4]``,
            ``[x1, y1, x2, y2]``) and ``scores`` (``[N]``).
        targets: List of dicts with ``image_id`` and ``boxes`` (``[M, 4]``).
        iou_threshold: Minimum IoU for a match.

    Returns:
        The AP as a float; ``0.0`` when there are no predictions or targets.
    """
    if not predictions or not targets:
        return 0.0

    # Flatten the per-image predictions into parallel containers.
    box_chunks = []
    score_chunks = []
    pred_image_ids = []
    for pred in predictions:
        boxes = torch.as_tensor(pred['boxes']).detach().cpu().float().reshape(-1, 4)
        scores = torch.as_tensor(pred['scores']).detach().cpu().float().reshape(-1)
        box_chunks.append(boxes)
        score_chunks.append(scores)
        pred_image_ids.extend([pred['image_id']] * len(boxes))

    all_pred_boxes = torch.cat(box_chunks)
    all_pred_scores = torch.cat(score_chunks).numpy()
    if all_pred_scores.size == 0:
        return 0.0

    # Highest score first.
    sorted_indices = np.argsort(all_pred_scores)[::-1]

    # Ground-truth boxes per image plus a "claimed" flag for each box.
    gt_chunks = defaultdict(list)
    for target in targets:
        gt_chunks[target['image_id']].append(
            torch.as_tensor(target['boxes']).detach().cpu().float().reshape(-1, 4)
        )
    gt_boxes = {image_id: torch.cat(chunks) for image_id, chunks in gt_chunks.items()}
    gt_claimed = {image_id: [False] * len(boxes) for image_id, boxes in gt_boxes.items()}

    tp = np.zeros(len(sorted_indices))
    fp = np.zeros(len(sorted_indices))

    for rank, idx in enumerate(sorted_indices):
        idx = int(idx)
        image_id = pred_image_ids[idx]
        candidates = gt_boxes.get(image_id)
        if candidates is None or len(candidates) == 0:
            fp[rank] = 1
            continue

        # IoU of this prediction against every ground-truth box of the image.
        box = all_pred_boxes[idx]
        top_left = torch.maximum(box[:2], candidates[:, :2])
        bottom_right = torch.minimum(box[2:], candidates[:, 2:])
        overlap_wh = (bottom_right - top_left).clamp(min=0)
        inter = overlap_wh[:, 0] * overlap_wh[:, 1]
        pred_area = (box[2] - box[0]) * (box[3] - box[1])
        gt_area = (candidates[:, 2] - candidates[:, 0]) * (candidates[:, 3] - candidates[:, 1])
        union = pred_area + gt_area - inter
        ious = torch.where(union > 0, inter / union, torch.zeros_like(inter))

        best = int(torch.argmax(ious))
        if float(ious[best]) >= iou_threshold and not gt_claimed[image_id][best]:
            tp[rank] = 1
            gt_claimed[image_id][best] = True
        else:
            fp[rank] = 1

    tp_cumsum = np.cumsum(tp)
    fp_cumsum = np.cumsum(fp)

    total_gt = sum(len(boxes) for boxes in gt_boxes.values())
    recalls = tp_cumsum / max(total_gt, 1)
    precisions = tp_cumsum / (tp_cumsum + fp_cumsum + 1e-8)

    # 11-point interpolation. level / 10 gives the correctly rounded decimal
    # thresholds; an accumulated float step (e.g. 0.30000000000000004) could
    # wrongly exclude a recall of exactly 0.3.
    ap = 0.0
    for level in range(11):
        reached = recalls >= level / 10.0
        if reached.any():
            ap += float(np.max(precisions[reached])) / 11

    return ap

def compute_iou_boxes(box1, box2):
    """Intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Returns ``0.0`` when the boxes do not overlap (touching edges count as no
    overlap) or when the union area is not positive.
    """
    left_a, top_a, right_a, bottom_a = box1
    left_b, top_b, right_b, bottom_b = box2

    # Corners of the overlap rectangle.
    left = max(left_a, left_b)
    top = max(top_a, top_b)
    right = min(right_a, right_b)
    bottom = min(bottom_a, bottom_b)

    disjoint = right <= left or bottom <= top
    if disjoint:
        return 0.0

    overlap = (right - left) * (bottom - top)

    # Union = sum of both areas minus the part counted twice.
    area_a = (right_a - left_a) * (bottom_a - top_a)
    area_b = (right_b - left_b) * (bottom_b - top_b)
    union = area_a + area_b - overlap

    if union > 0:
        return overlap / union
    return 0.0


In [None]:
def evaluate_classification(model, dataloader, device):
    """Top-1 and top-5 accuracy (in percent) of the classification head.

    Args:
        model: Callable returning ``(detection, segmentation, classification)``
            outputs; only the classification logits ``[B, num_classes]`` are
            used.
        dataloader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(top1_accuracy, top5_accuracy)``. When the head has fewer than
        five classes the top-5 value equals the top-1 value. An empty
        dataloader yields ``(0.0, 0.0)`` (previously it raised).
    """
    model.eval()
    correct = 0
    total = 0
    top5_correct = 0
    has_top5 = False

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)

            _, _, cls_out = model(data)

            # Top-1: the highest-scoring class must equal the label.
            _, predicted = torch.max(cls_out, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

            # Top-5 is only defined with at least five classes.
            if cls_out.size(1) >= 5:
                has_top5 = True
                _, top5_pred = torch.topk(cls_out, 5, dim=1)
                # Vectorised membership test: is the label among the 5 picks?
                hits = (top5_pred == target.unsqueeze(1)).any(dim=1)
                top5_correct += hits.sum().item()

    if total == 0:
        return 0.0, 0.0

    top1_accuracy = 100. * correct / total
    top5_accuracy = 100. * top5_correct / total if has_top5 else top1_accuracy

    return top1_accuracy, top5_accuracy

def evaluate_classification_detailed(model, dataloader, device, num_classes=10):
    """Overall and per-class top-1 accuracy (in percent).

    Prints one ``Class i: xx.xx%`` line for every class that has at least one
    sample.

    Args:
        model: Callable returning ``(detection, segmentation, classification)``
            outputs; only the classification logits are used.
        dataloader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.
        num_classes: Number of classes; labels must lie in
            ``[0, num_classes)``.

    Returns:
        Tuple ``(overall_accuracy, class_accuracies)`` where
        ``class_accuracies`` has one entry per class (``0.0`` for classes
        without samples). An empty dataloader yields an overall accuracy of
        ``0.0``.
    """
    model.eval()
    class_correct = [0.0] * num_classes
    class_total = [0.0] * num_classes

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)

            _, _, cls_out = model(data)
            _, predicted = torch.max(cls_out, 1)

            # Compare as plain lists. Calling .squeeze() on the comparison
            # turns a batch of one into a 0-dim tensor, and indexing it then
            # raises IndexError.
            hits = (predicted == target).tolist()
            for label, hit in zip(target.tolist(), hits):
                class_correct[label] += hit  # bool adds as 0 / 1
                class_total[label] += 1

    # Per-class accuracy.
    class_accuracies = []
    for i in range(num_classes):
        if class_total[i] > 0:
            accuracy = 100 * class_correct[i] / class_total[i]
            class_accuracies.append(accuracy)
            print(f'Class {i}: {accuracy:.2f}%')
        else:
            class_accuracies.append(0.0)

    seen = sum(class_total)
    overall_accuracy = 100 * sum(class_correct) / seen if seen > 0 else 0.0
    return overall_accuracy, class_accuracies


In [None]:
def detection_collate(batch):
    """Collate ``(image, target)`` pairs for detection.

    Images are stacked into a single ``[B, ...]`` tensor, while the targets
    are returned as a plain list because every image can have a different
    number of boxes.
    """
    pairs = list(batch)
    image_list = [image for image, _ in pairs]
    target_list = [target for _, target in pairs]
    return torch.stack(image_list, dim=0), target_list

In [None]:
import os
# CUDA debugging aids.
# NOTE(review): TORCH_USE_CUDA_DSA is normally a PyTorch build-time option -
# confirm that setting it at runtime has any effect.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': "1",  # run kernels synchronously so errors point at the failing call
    'TORCH_USE_CUDA_DSA': "1",    # request device-side assertions
})

def main():
    """Train the multi-task model in three sequential stages and report results.

    Order of work: build the three training datasets/loaders, pull one batch
    from each as a smoke test, train segmentation, then detection, then
    classification. After the first two stages the loaders trained so far are
    registered through ``trainer.set_ewc``. Every evaluation here runs on the
    training loaders. Finally the model weights are saved to
    ``multitask_model.pt``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MultiTaskModel().to(device)
    trainer = MultiTaskTrainer(model, device)

    data_root = "data"  # root directory holding the three datasets

    def build_train_dataset(task):
        """Create the training-split dataset of one task with its transforms."""
        return MultiTaskDataset(
            data_root=data_root,
            task_type=task,
            split='train',
            transform=get_transforms(task, 'train'),
        )

    # One dataset per task (constructed in this order: seg, det, cls).
    train_seg_dataset = build_train_dataset('segmentation')
    train_det_dataset = build_train_dataset('detection')
    train_cls_dataset = build_train_dataset('classification')

    # Detection needs a custom collate because targets vary in size.
    train_seg_loader = DataLoader(train_seg_dataset, batch_size=8, shuffle=True)
    train_det_loader = DataLoader(
        train_det_dataset, batch_size=8, shuffle=True, collate_fn=detection_collate
    )
    train_cls_loader = DataLoader(train_cls_dataset, batch_size=32, shuffle=True)

    print(f"分割數據集: {len(train_seg_dataset)} 樣本")
    print(f"檢測數據集: {len(train_det_dataset)} 樣本")
    print(f"分類數據集: {len(train_cls_dataset)} 樣本")

    # Smoke test: fetch a single batch from every loader.
    seg_batch = next(iter(train_seg_loader))
    det_batch = next(iter(train_det_loader))
    cls_batch = next(iter(train_cls_loader))

    print(f"分割批次 - 圖片: {seg_batch[0].shape}, 標籤: {seg_batch[1].shape}")
    print(f"檢測批次 - 圖片: {det_batch[0].shape}")
    print(f"分類批次 - 圖片: {cls_batch[0].shape}, 標籤: {cls_batch[1].shape}")

    # Model size.
    total_params = sum(param.numel() for param in model.parameters())
    print(f'Total parameters: {total_params/1e6:.2f}M')

    # ---- Stage 1: segmentation ----
    print("Stage 1: Training Segmentation")
    trainer.train_stage(train_seg_loader, 'segmentation')

    # Reference score to measure forgetting against later.
    mIoU_base = evaluate_segmentation(model, train_seg_loader, device)
    print(f"Segmentation baseline mIoU: {mIoU_base:.4f}")

    trainer.set_ewc({'seg': train_seg_loader})

    # ---- Stage 2: detection ----
    print("Stage 2: Training Detection")
    trainer.train_stage(train_det_loader, 'detection', num_epochs=100)

    mAP_base = evaluate_detection(model, train_det_loader, device)
    print(f"Detection baseline mAP: {mAP_base:.4f}")

    trainer.set_ewc({'seg': train_seg_loader, 'det': train_det_loader})

    # ---- Stage 3: classification ----
    print("Stage 3: Training Classification")
    trainer.train_stage(train_cls_loader, 'classification')

    # ---- Final evaluation of all three heads ----
    final_mIoU = evaluate_segmentation(model, train_seg_loader, device)
    final_mAP = evaluate_detection(model, train_det_loader, device)
    final_acc = evaluate_classification(model, train_cls_loader, device)

    print(f"Final Results:")
    print(f"mIoU: {final_mIoU:.4f} (drop: {mIoU_base - final_mIoU:.4f})")
    print(f"mAP: {final_mAP:.4f}")
    print(f"Top-1 Accuracy: {final_acc[0]:.4f}")

    # Persist the trained weights.
    torch.save(model.state_dict(), 'multitask_model.pt')
    print("Model saved.")

if __name__ == "__main__":
    main()




載入分割樣本: 240 個
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
載入檢測樣本: 240 個
載入分類樣本: 240 個
分割數據集: 240 樣本
檢測數據集: 240 樣本
分類數據集: 240 樣本


  boxes = torch.tensor(boxes, dtype=torch.float32)


分割批次 - 圖片: torch.Size([8, 3, 512, 512]), 標籤: torch.Size([8, 512, 512])
檢測批次 - 圖片: torch.Size([8, 3, 512, 512])
分類批次 - 圖片: torch.Size([32, 3, 512, 512]), 標籤: torch.Size([32])
Total parameters: 3.15M
Stage 1: Training Segmentation


Training segmentation:   2%|▏         | 1/60 [00:30<29:54, 30.42s/it]

segmentation Epoch 1/60, Loss: 2.3282


Training segmentation:   3%|▎         | 2/60 [01:00<29:19, 30.34s/it]

segmentation Epoch 2/60, Loss: 1.7287


Training segmentation:   5%|▌         | 3/60 [01:30<28:45, 30.27s/it]

segmentation Epoch 3/60, Loss: 1.5312


Training segmentation:   7%|▋         | 4/60 [02:01<28:16, 30.29s/it]

segmentation Epoch 4/60, Loss: 1.3623


Training segmentation:   8%|▊         | 5/60 [02:31<27:43, 30.24s/it]

segmentation Epoch 5/60, Loss: 1.2260


Training segmentation:  10%|█         | 6/60 [03:01<27:15, 30.29s/it]

segmentation Epoch 6/60, Loss: 1.1106


Training segmentation:  12%|█▏        | 7/60 [03:32<26:44, 30.28s/it]

segmentation Epoch 7/60, Loss: 0.9542


Training segmentation:  13%|█▎        | 8/60 [04:02<26:14, 30.28s/it]

segmentation Epoch 8/60, Loss: 0.8482


Training segmentation:  15%|█▌        | 9/60 [04:32<25:43, 30.27s/it]

segmentation Epoch 9/60, Loss: 0.7671


Training segmentation:  17%|█▋        | 10/60 [05:02<25:11, 30.24s/it]

segmentation Epoch 10/60, Loss: 0.7078


Training segmentation:  18%|█▊        | 11/60 [05:33<24:42, 30.26s/it]

segmentation Epoch 11/60, Loss: 0.6311


Training segmentation:  20%|██        | 12/60 [06:03<24:11, 30.25s/it]

segmentation Epoch 12/60, Loss: 0.5716


Training segmentation:  22%|██▏       | 13/60 [06:33<23:42, 30.27s/it]

segmentation Epoch 13/60, Loss: 0.5344


Training segmentation:  23%|██▎       | 14/60 [07:03<23:10, 30.23s/it]

segmentation Epoch 14/60, Loss: 0.5084


Training segmentation:  25%|██▌       | 15/60 [07:33<22:40, 30.24s/it]

segmentation Epoch 15/60, Loss: 0.4500


Training segmentation:  27%|██▋       | 16/60 [08:04<22:10, 30.23s/it]

segmentation Epoch 16/60, Loss: 0.4104


Training segmentation:  28%|██▊       | 17/60 [08:34<21:41, 30.27s/it]

segmentation Epoch 17/60, Loss: 0.4140


Training segmentation:  30%|███       | 18/60 [09:04<21:09, 30.23s/it]

segmentation Epoch 18/60, Loss: 0.3587


Training segmentation:  32%|███▏      | 19/60 [09:34<20:39, 30.24s/it]

segmentation Epoch 19/60, Loss: 0.3454


Training segmentation:  33%|███▎      | 20/60 [10:05<20:08, 30.22s/it]

segmentation Epoch 20/60, Loss: 0.3111


Training segmentation:  35%|███▌      | 21/60 [10:35<19:40, 30.28s/it]

segmentation Epoch 21/60, Loss: 0.2970


Training segmentation:  37%|███▋      | 22/60 [11:05<19:10, 30.27s/it]

segmentation Epoch 22/60, Loss: 0.2922


Training segmentation:  38%|███▊      | 23/60 [11:35<18:39, 30.25s/it]

segmentation Epoch 23/60, Loss: 0.2654


Training segmentation:  40%|████      | 24/60 [12:06<18:07, 30.22s/it]

segmentation Epoch 24/60, Loss: 0.2570


Training segmentation:  42%|████▏     | 25/60 [12:36<17:37, 30.23s/it]

segmentation Epoch 25/60, Loss: 0.2386


Training segmentation:  43%|████▎     | 26/60 [13:06<17:07, 30.21s/it]

segmentation Epoch 26/60, Loss: 0.2262


Training segmentation:  45%|████▌     | 27/60 [13:36<16:36, 30.21s/it]

segmentation Epoch 27/60, Loss: 0.2258


Training segmentation:  47%|████▋     | 28/60 [14:06<16:06, 30.20s/it]

segmentation Epoch 28/60, Loss: 0.2107


Training segmentation:  48%|████▊     | 29/60 [14:37<15:35, 30.19s/it]

segmentation Epoch 29/60, Loss: 0.1943


Training segmentation:  50%|█████     | 30/60 [15:07<15:05, 30.18s/it]

segmentation Epoch 30/60, Loss: 0.1934


Training segmentation:  52%|█████▏    | 31/60 [15:37<14:36, 30.23s/it]

segmentation Epoch 31/60, Loss: 0.1910


Training segmentation:  53%|█████▎    | 32/60 [16:07<14:06, 30.23s/it]

segmentation Epoch 32/60, Loss: 0.1820


Training segmentation:  55%|█████▌    | 33/60 [16:38<13:36, 30.22s/it]

segmentation Epoch 33/60, Loss: 0.1843


Training segmentation:  57%|█████▋    | 34/60 [17:08<13:05, 30.23s/it]

segmentation Epoch 34/60, Loss: 0.1656


Training segmentation:  58%|█████▊    | 35/60 [17:38<12:34, 30.20s/it]

segmentation Epoch 35/60, Loss: 0.1699


Training segmentation:  60%|██████    | 36/60 [18:08<12:05, 30.23s/it]

segmentation Epoch 36/60, Loss: 0.1589


Training segmentation:  62%|██████▏   | 37/60 [18:38<11:34, 30.21s/it]

segmentation Epoch 37/60, Loss: 0.1429


Training segmentation:  63%|██████▎   | 38/60 [19:09<11:04, 30.23s/it]

segmentation Epoch 38/60, Loss: 0.1484


Training segmentation:  65%|██████▌   | 39/60 [19:39<10:34, 30.20s/it]

segmentation Epoch 39/60, Loss: 0.1390


Training segmentation:  67%|██████▋   | 40/60 [20:09<10:04, 30.22s/it]

segmentation Epoch 40/60, Loss: 0.1470


Training segmentation:  68%|██████▊   | 41/60 [20:39<09:33, 30.19s/it]

segmentation Epoch 41/60, Loss: 0.1434


Training segmentation:  70%|███████   | 42/60 [21:09<09:04, 30.22s/it]

segmentation Epoch 42/60, Loss: 0.1440


Training segmentation:  72%|███████▏  | 43/60 [21:40<08:33, 30.23s/it]

segmentation Epoch 43/60, Loss: 0.1295


Training segmentation:  73%|███████▎  | 44/60 [22:10<08:04, 30.28s/it]

segmentation Epoch 44/60, Loss: 0.1192


Training segmentation:  75%|███████▌  | 45/60 [22:40<07:33, 30.24s/it]

segmentation Epoch 45/60, Loss: 0.1292


Training segmentation:  77%|███████▋  | 46/60 [23:11<07:03, 30.26s/it]

segmentation Epoch 46/60, Loss: 0.1160


Training segmentation:  78%|███████▊  | 47/60 [23:41<06:33, 30.24s/it]

segmentation Epoch 47/60, Loss: 0.1132


Training segmentation:  80%|████████  | 48/60 [24:11<06:02, 30.25s/it]

segmentation Epoch 48/60, Loss: 0.1154


Training segmentation:  82%|████████▏ | 49/60 [24:41<05:32, 30.23s/it]

segmentation Epoch 49/60, Loss: 0.1145


Training segmentation:  83%|████████▎ | 50/60 [25:12<05:02, 30.26s/it]

segmentation Epoch 50/60, Loss: 0.1131


Training segmentation:  85%|████████▌ | 51/60 [25:42<04:32, 30.24s/it]

segmentation Epoch 51/60, Loss: 0.1040


Training segmentation:  87%|████████▋ | 52/60 [26:12<04:01, 30.23s/it]

segmentation Epoch 52/60, Loss: 0.1004


Training segmentation:  88%|████████▊ | 53/60 [26:42<03:31, 30.23s/it]

segmentation Epoch 53/60, Loss: 0.0953


Training segmentation:  90%|█████████ | 54/60 [27:12<03:01, 30.25s/it]

segmentation Epoch 54/60, Loss: 0.0949


Training segmentation:  92%|█████████▏| 55/60 [27:43<02:31, 30.26s/it]

segmentation Epoch 55/60, Loss: 0.0923


Training segmentation:  93%|█████████▎| 56/60 [28:13<02:00, 30.25s/it]

segmentation Epoch 56/60, Loss: 0.0913


Training segmentation:  95%|█████████▌| 57/60 [28:43<01:30, 30.26s/it]

segmentation Epoch 57/60, Loss: 0.0947


Training segmentation:  97%|█████████▋| 58/60 [29:13<01:00, 30.21s/it]

segmentation Epoch 58/60, Loss: 0.0914


Training segmentation:  98%|█████████▊| 59/60 [29:44<00:30, 30.21s/it]

segmentation Epoch 59/60, Loss: 0.0918


Training segmentation: 100%|██████████| 60/60 [30:14<00:00, 30.24s/it]

segmentation Epoch 60/60, Loss: 0.0880





Segmentation baseline mIoU: 0.2154
正在計算 seg 任務的Fisher信息...


seg: 100%|██████████| 30/30 [00:30<00:00,  1.01s/it]


Stage 2: Training Detection


Training detection:   1%|          | 1/100 [00:08<13:27,  8.16s/it]

detection Epoch 1/100, Loss: 88789.5727


Training detection:   2%|▏         | 2/100 [00:16<13:30,  8.27s/it]

detection Epoch 2/100, Loss: 83979.6906


Training detection:   3%|▎         | 3/100 [00:24<13:14,  8.19s/it]

detection Epoch 3/100, Loss: 81314.8398


Training detection:   4%|▍         | 4/100 [00:33<13:14,  8.28s/it]

detection Epoch 4/100, Loss: 79668.5777


Training detection:   5%|▌         | 5/100 [00:41<13:17,  8.39s/it]

detection Epoch 5/100, Loss: 78203.3870


Training detection:   6%|▌         | 6/100 [00:49<12:53,  8.22s/it]

detection Epoch 6/100, Loss: 76852.6362


Training detection:   7%|▋         | 7/100 [00:58<12:53,  8.31s/it]

detection Epoch 7/100, Loss: 75410.5913


Training detection:   8%|▊         | 8/100 [01:06<12:51,  8.39s/it]

detection Epoch 8/100, Loss: 74295.8402


Training detection:   9%|▉         | 9/100 [01:14<12:27,  8.22s/it]

detection Epoch 9/100, Loss: 73085.6979


Training detection:  10%|█         | 10/100 [01:22<12:27,  8.30s/it]

detection Epoch 10/100, Loss: 71582.7035


Training detection:  11%|█         | 11/100 [01:31<12:25,  8.37s/it]

detection Epoch 11/100, Loss: 70804.7913


Training detection:  12%|█▏        | 12/100 [01:39<12:03,  8.22s/it]

detection Epoch 12/100, Loss: 69773.1457


Training detection:  13%|█▎        | 13/100 [01:47<11:59,  8.27s/it]

detection Epoch 13/100, Loss: 68810.7745


Training detection:  14%|█▍        | 14/100 [01:56<11:54,  8.31s/it]

detection Epoch 14/100, Loss: 67757.2260


Training detection:  15%|█▌        | 15/100 [02:03<11:33,  8.15s/it]

detection Epoch 15/100, Loss: 66756.3233


Training detection:  16%|█▌        | 16/100 [02:12<11:29,  8.21s/it]

detection Epoch 16/100, Loss: 65863.0388


Training detection:  17%|█▋        | 17/100 [02:20<11:24,  8.24s/it]

detection Epoch 17/100, Loss: 65103.1055


Training detection:  18%|█▊        | 18/100 [02:28<11:07,  8.14s/it]

detection Epoch 18/100, Loss: 64215.3896


Training detection:  19%|█▉        | 19/100 [02:37<11:12,  8.31s/it]

detection Epoch 19/100, Loss: 63115.0441


Training detection:  20%|██        | 20/100 [02:45<11:10,  8.38s/it]

detection Epoch 20/100, Loss: 62582.9613


Training detection:  21%|██        | 21/100 [02:53<10:50,  8.23s/it]

detection Epoch 21/100, Loss: 61682.8177


Training detection:  22%|██▏       | 22/100 [03:02<10:47,  8.31s/it]

detection Epoch 22/100, Loss: 60902.5546


Training detection:  23%|██▎       | 23/100 [03:10<10:42,  8.34s/it]

detection Epoch 23/100, Loss: 60111.0460


Training detection:  24%|██▍       | 24/100 [03:18<10:24,  8.22s/it]

detection Epoch 24/100, Loss: 59134.4471


Training detection:  25%|██▌       | 25/100 [03:26<10:20,  8.28s/it]

detection Epoch 25/100, Loss: 58545.3872


Training detection:  26%|██▌       | 26/100 [03:35<10:16,  8.34s/it]

detection Epoch 26/100, Loss: 57726.4642


Training detection:  27%|██▋       | 27/100 [03:43<10:00,  8.22s/it]

detection Epoch 27/100, Loss: 57080.8973


Training detection:  28%|██▊       | 28/100 [03:51<09:53,  8.25s/it]

detection Epoch 28/100, Loss: 56192.7548


Training detection:  29%|██▉       | 29/100 [03:59<09:49,  8.30s/it]

detection Epoch 29/100, Loss: 55518.8898


Training detection:  30%|███       | 30/100 [04:08<09:39,  8.27s/it]

detection Epoch 30/100, Loss: 54782.9815


Training detection:  31%|███       | 31/100 [04:16<09:29,  8.25s/it]

detection Epoch 31/100, Loss: 54517.3759


Training detection:  32%|███▏      | 32/100 [04:24<09:24,  8.30s/it]

detection Epoch 32/100, Loss: 53500.5818


Training detection:  33%|███▎      | 33/100 [04:32<09:12,  8.24s/it]

detection Epoch 33/100, Loss: 52699.5732


Training detection:  34%|███▍      | 34/100 [04:40<08:59,  8.17s/it]

detection Epoch 34/100, Loss: 51943.0987


Training detection:  35%|███▌      | 35/100 [04:49<08:55,  8.25s/it]

detection Epoch 35/100, Loss: 51184.3846


Training detection:  36%|███▌      | 36/100 [04:57<08:49,  8.27s/it]

detection Epoch 36/100, Loss: 50609.2373


Training detection:  37%|███▋      | 37/100 [05:05<08:36,  8.19s/it]

detection Epoch 37/100, Loss: 49733.4247


Training detection:  38%|███▊      | 38/100 [05:14<08:32,  8.26s/it]

detection Epoch 38/100, Loss: 49090.2128


Training detection:  39%|███▉      | 39/100 [05:22<08:26,  8.30s/it]

detection Epoch 39/100, Loss: 48636.8113


Training detection:  40%|████      | 40/100 [05:30<08:08,  8.14s/it]

detection Epoch 40/100, Loss: 47882.7477


Training detection:  41%|████      | 41/100 [05:38<08:06,  8.25s/it]

detection Epoch 41/100, Loss: 47442.5134


Training detection:  42%|████▏     | 42/100 [05:47<08:00,  8.29s/it]

detection Epoch 42/100, Loss: 46641.9528


Training detection:  43%|████▎     | 43/100 [05:54<07:43,  8.13s/it]

detection Epoch 43/100, Loss: 46209.1201


Training detection:  44%|████▍     | 44/100 [06:03<07:41,  8.24s/it]

detection Epoch 44/100, Loss: 45430.6288


Training detection:  45%|████▌     | 45/100 [06:11<07:35,  8.29s/it]

detection Epoch 45/100, Loss: 44943.2676


Training detection:  46%|████▌     | 46/100 [06:19<07:18,  8.13s/it]

detection Epoch 46/100, Loss: 44623.1608


Training detection:  47%|████▋     | 47/100 [06:27<07:13,  8.18s/it]

detection Epoch 47/100, Loss: 44015.2195


Training detection:  48%|████▊     | 48/100 [06:36<07:09,  8.25s/it]

detection Epoch 48/100, Loss: 43493.7066


Training detection:  49%|████▉     | 49/100 [06:44<06:53,  8.11s/it]

detection Epoch 49/100, Loss: 42690.3494


Training detection:  50%|█████     | 50/100 [06:52<06:51,  8.22s/it]

detection Epoch 50/100, Loss: 42308.7622


Training detection:  51%|█████     | 51/100 [07:01<06:47,  8.31s/it]

detection Epoch 51/100, Loss: 41653.0047


Training detection:  52%|█████▏    | 52/100 [07:08<06:32,  8.17s/it]

detection Epoch 52/100, Loss: 41412.4721


Training detection:  53%|█████▎    | 53/100 [07:17<06:27,  8.25s/it]

detection Epoch 53/100, Loss: 40692.9938


Training detection:  54%|█████▍    | 54/100 [07:25<06:21,  8.29s/it]

detection Epoch 54/100, Loss: 40261.1289


Training detection:  55%|█████▌    | 55/100 [07:33<06:07,  8.16s/it]

detection Epoch 55/100, Loss: 39784.8995


Training detection:  56%|█████▌    | 56/100 [07:41<06:01,  8.23s/it]

detection Epoch 56/100, Loss: 39212.1604


Training detection:  57%|█████▋    | 57/100 [07:50<05:55,  8.26s/it]

detection Epoch 57/100, Loss: 38808.6773


Training detection:  58%|█████▊    | 58/100 [07:58<05:41,  8.13s/it]

detection Epoch 58/100, Loss: 38410.0374


Training detection:  59%|█████▉    | 59/100 [08:06<05:38,  8.24s/it]

detection Epoch 59/100, Loss: 38035.9291


Training detection:  60%|██████    | 60/100 [08:15<05:33,  8.33s/it]

detection Epoch 60/100, Loss: 37558.6655


Training detection:  61%|██████    | 61/100 [08:23<05:19,  8.20s/it]

detection Epoch 61/100, Loss: 37046.7529


Training detection:  62%|██████▏   | 62/100 [08:31<05:16,  8.32s/it]

detection Epoch 62/100, Loss: 36755.6509


Training detection:  63%|██████▎   | 63/100 [08:40<05:09,  8.37s/it]

detection Epoch 63/100, Loss: 36422.5391


Training detection:  64%|██████▍   | 64/100 [08:48<04:59,  8.31s/it]

detection Epoch 64/100, Loss: 36129.5602


Training detection:  65%|██████▌   | 65/100 [08:56<04:51,  8.34s/it]

detection Epoch 65/100, Loss: 35630.7193


Training detection:  66%|██████▌   | 66/100 [09:05<04:46,  8.43s/it]

detection Epoch 66/100, Loss: 35391.9792


Training detection:  67%|██████▋   | 67/100 [09:13<04:37,  8.41s/it]

detection Epoch 67/100, Loss: 34919.1454


Training detection:  68%|██████▊   | 68/100 [09:21<04:25,  8.30s/it]

detection Epoch 68/100, Loss: 34670.6452


Training detection:  69%|██████▉   | 69/100 [09:30<04:19,  8.39s/it]

detection Epoch 69/100, Loss: 34166.3217


Training detection:  70%|███████   | 70/100 [09:38<04:13,  8.44s/it]

detection Epoch 70/100, Loss: 33786.6117


Training detection:  71%|███████   | 71/100 [09:46<03:59,  8.26s/it]

detection Epoch 71/100, Loss: 33665.1252


Training detection:  72%|███████▏  | 72/100 [09:55<03:53,  8.34s/it]

detection Epoch 72/100, Loss: 33398.7689


Training detection:  73%|███████▎  | 73/100 [10:03<03:45,  8.35s/it]

detection Epoch 73/100, Loss: 33007.8867


Training detection:  74%|███████▍  | 74/100 [10:11<03:33,  8.20s/it]

detection Epoch 74/100, Loss: 32616.1794


Training detection:  75%|███████▌  | 75/100 [10:19<03:26,  8.27s/it]

detection Epoch 75/100, Loss: 32202.0765


Training detection:  76%|███████▌  | 76/100 [10:28<03:20,  8.36s/it]

detection Epoch 76/100, Loss: 31972.6919


Training detection:  77%|███████▋  | 77/100 [10:36<03:08,  8.20s/it]

detection Epoch 77/100, Loss: 31799.8475


Training detection:  78%|███████▊  | 78/100 [10:44<03:02,  8.28s/it]

detection Epoch 78/100, Loss: 31522.1243


Training detection:  79%|███████▉  | 79/100 [10:53<02:55,  8.35s/it]

detection Epoch 79/100, Loss: 31275.1944


Training detection:  80%|████████  | 80/100 [11:01<02:44,  8.22s/it]

detection Epoch 80/100, Loss: 31029.2341


Training detection:  81%|████████  | 81/100 [11:09<02:37,  8.30s/it]

detection Epoch 81/100, Loss: 30635.0790


Training detection:  82%|████████▏ | 82/100 [11:18<02:29,  8.33s/it]

detection Epoch 82/100, Loss: 30428.9001


Training detection:  83%|████████▎ | 83/100 [11:25<02:18,  8.17s/it]

detection Epoch 83/100, Loss: 30227.1929


Training detection:  84%|████████▍ | 84/100 [11:34<02:12,  8.27s/it]

detection Epoch 84/100, Loss: 29942.8290


Training detection:  85%|████████▌ | 85/100 [11:42<02:04,  8.31s/it]

detection Epoch 85/100, Loss: 29549.9154


Training detection:  86%|████████▌ | 86/100 [11:50<01:54,  8.18s/it]

detection Epoch 86/100, Loss: 29527.6523


Training detection:  87%|████████▋ | 87/100 [11:59<01:47,  8.27s/it]

detection Epoch 87/100, Loss: 29186.8536


Training detection:  88%|████████▊ | 88/100 [12:07<01:40,  8.37s/it]

detection Epoch 88/100, Loss: 28956.4856


Training detection:  89%|████████▉ | 89/100 [12:15<01:31,  8.31s/it]

detection Epoch 89/100, Loss: 28821.4696


Training detection:  90%|█████████ | 90/100 [12:24<01:22,  8.24s/it]

detection Epoch 90/100, Loss: 28697.9924


Training detection:  91%|█████████ | 91/100 [12:32<01:14,  8.31s/it]

detection Epoch 91/100, Loss: 28522.6496


Training detection:  92%|█████████▏| 92/100 [12:40<01:06,  8.35s/it]

detection Epoch 92/100, Loss: 28249.5379


Training detection:  93%|█████████▎| 93/100 [12:49<00:57,  8.28s/it]

detection Epoch 93/100, Loss: 27955.7989


Training detection:  94%|█████████▍| 94/100 [12:57<00:50,  8.34s/it]

detection Epoch 94/100, Loss: 27642.7477


Training detection:  95%|█████████▌| 95/100 [13:05<00:41,  8.36s/it]

detection Epoch 95/100, Loss: 27611.2278


Training detection:  96%|█████████▌| 96/100 [13:13<00:32,  8.23s/it]

detection Epoch 96/100, Loss: 27489.8229


Training detection:  97%|█████████▋| 97/100 [13:22<00:24,  8.33s/it]

detection Epoch 97/100, Loss: 27261.2989


Training detection:  98%|█████████▊| 98/100 [13:30<00:16,  8.36s/it]

detection Epoch 98/100, Loss: 27112.3400


Training detection:  99%|█████████▉| 99/100 [13:38<00:08,  8.22s/it]

detection Epoch 99/100, Loss: 27034.0083


Training detection: 100%|██████████| 100/100 [13:47<00:00,  8.27s/it]

detection Epoch 100/100, Loss: 27019.0214





Detection baseline mAP: 0.0000
正在計算 seg 任務的Fisher信息...


seg: 100%|██████████| 30/30 [00:30<00:00,  1.02s/it]


正在計算 det 任務的Fisher信息...


det: 100%|██████████| 30/30 [00:06<00:00,  4.79it/s]


Stage 3: Training Classification


Training classification:   2%|▏         | 1/60 [00:04<04:22,  4.45s/it]

classification Epoch 1/60, Loss: 45078.6706


Training classification:   3%|▎         | 2/60 [00:08<04:01,  4.17s/it]

classification Epoch 2/60, Loss: 19695.3963


Training classification:   5%|▌         | 3/60 [00:12<03:52,  4.07s/it]

classification Epoch 3/60, Loss: 8468.5839


Training classification:   7%|▋         | 4/60 [00:16<03:56,  4.22s/it]

classification Epoch 4/60, Loss: 2686.2763


Training classification:   8%|▊         | 5/60 [00:20<03:46,  4.12s/it]

classification Epoch 5/60, Loss: 1936.8615


Training classification:  10%|█         | 6/60 [00:24<03:39,  4.06s/it]

classification Epoch 6/60, Loss: 556.5414


Training classification:  12%|█▏        | 7/60 [00:29<03:41,  4.18s/it]

classification Epoch 7/60, Loss: 389.4873


Training classification:  13%|█▎        | 8/60 [00:33<03:33,  4.11s/it]

classification Epoch 8/60, Loss: 119.0512


Training classification:  15%|█▌        | 9/60 [00:37<03:28,  4.08s/it]

classification Epoch 9/60, Loss: 68.4999


Training classification:  17%|█▋        | 10/60 [00:41<03:29,  4.19s/it]

classification Epoch 10/60, Loss: 34.6069


Training classification:  18%|█▊        | 11/60 [00:45<03:21,  4.12s/it]

classification Epoch 11/60, Loss: 13.7847


Training classification:  20%|██        | 12/60 [00:49<03:17,  4.11s/it]

classification Epoch 12/60, Loss: 7.1919


Training classification:  22%|██▏       | 13/60 [00:53<03:16,  4.19s/it]

classification Epoch 13/60, Loss: 4.6108


Training classification:  23%|██▎       | 14/60 [00:57<03:08,  4.10s/it]

classification Epoch 14/60, Loss: 3.3444


Training classification:  25%|██▌       | 15/60 [01:01<03:04,  4.11s/it]

classification Epoch 15/60, Loss: 2.6534


Training classification:  27%|██▋       | 16/60 [01:06<03:03,  4.17s/it]

classification Epoch 16/60, Loss: 2.3895


Training classification:  28%|██▊       | 17/60 [01:10<02:57,  4.12s/it]

classification Epoch 17/60, Loss: 2.2685


Training classification:  30%|███       | 18/60 [01:14<02:52,  4.11s/it]

classification Epoch 18/60, Loss: 2.2125


Training classification:  32%|███▏      | 19/60 [01:18<02:50,  4.16s/it]

classification Epoch 19/60, Loss: 2.1880


Training classification:  33%|███▎      | 20/60 [01:22<02:43,  4.10s/it]

classification Epoch 20/60, Loss: 2.1828


Training classification:  35%|███▌      | 21/60 [01:26<02:40,  4.12s/it]

classification Epoch 21/60, Loss: 2.1703


Training classification:  37%|███▋      | 22/60 [01:31<02:37,  4.16s/it]

classification Epoch 22/60, Loss: 2.1605


Training classification:  38%|███▊      | 23/60 [01:34<02:31,  4.10s/it]

classification Epoch 23/60, Loss: 2.1515


Training classification:  40%|████      | 24/60 [01:39<02:28,  4.12s/it]

classification Epoch 24/60, Loss: 2.1548


Training classification:  42%|████▏     | 25/60 [01:43<02:25,  4.15s/it]

classification Epoch 25/60, Loss: 2.1251


Training classification:  43%|████▎     | 26/60 [01:47<02:19,  4.09s/it]

classification Epoch 26/60, Loss: 2.0911


Training classification:  45%|████▌     | 27/60 [01:51<02:16,  4.13s/it]

classification Epoch 27/60, Loss: 2.0767


Training classification:  47%|████▋     | 28/60 [01:55<02:13,  4.16s/it]

classification Epoch 28/60, Loss: 2.0651


Training classification:  48%|████▊     | 29/60 [01:59<02:07,  4.10s/it]

classification Epoch 29/60, Loss: 2.0691


Training classification:  50%|█████     | 30/60 [02:04<02:04,  4.16s/it]

classification Epoch 30/60, Loss: 2.0629


Training classification:  52%|█████▏    | 31/60 [02:08<02:00,  4.15s/it]

classification Epoch 31/60, Loss: 2.0489


Training classification:  53%|█████▎    | 32/60 [02:12<01:55,  4.12s/it]

classification Epoch 32/60, Loss: 2.0332


Training classification:  55%|█████▌    | 33/60 [02:16<01:53,  4.19s/it]

classification Epoch 33/60, Loss: 2.0291


Training classification:  57%|█████▋    | 34/60 [02:20<01:48,  4.16s/it]

classification Epoch 34/60, Loss: 2.0251


Training classification:  58%|█████▊    | 35/60 [02:24<01:42,  4.10s/it]

classification Epoch 35/60, Loss: 2.0266


Training classification:  60%|██████    | 36/60 [02:29<01:40,  4.19s/it]

classification Epoch 36/60, Loss: 2.0072


Training classification:  62%|██████▏   | 37/60 [02:33<01:35,  4.14s/it]

classification Epoch 37/60, Loss: 1.9984


Training classification:  63%|██████▎   | 38/60 [02:37<01:30,  4.10s/it]

classification Epoch 38/60, Loss: 1.9854


Training classification:  65%|██████▌   | 39/60 [02:41<01:28,  4.22s/it]

classification Epoch 39/60, Loss: 1.9821


Training classification:  67%|██████▋   | 40/60 [02:45<01:22,  4.15s/it]

classification Epoch 40/60, Loss: 1.9925


Training classification:  68%|██████▊   | 41/60 [02:49<01:18,  4.11s/it]

classification Epoch 41/60, Loss: 1.9791


Training classification:  70%|███████   | 42/60 [02:54<01:15,  4.21s/it]

classification Epoch 42/60, Loss: 1.9736


Training classification:  72%|███████▏  | 43/60 [02:57<01:10,  4.14s/it]

classification Epoch 43/60, Loss: 1.9621


Training classification:  73%|███████▎  | 44/60 [03:01<01:05,  4.10s/it]

classification Epoch 44/60, Loss: 1.9530


Training classification:  75%|███████▌  | 45/60 [03:06<01:03,  4.21s/it]

classification Epoch 45/60, Loss: 1.9381


Training classification:  77%|███████▋  | 46/60 [03:10<00:58,  4.15s/it]

classification Epoch 46/60, Loss: 1.9282


Training classification:  78%|███████▊  | 47/60 [03:14<00:53,  4.11s/it]

classification Epoch 47/60, Loss: 1.9140


Training classification:  80%|████████  | 48/60 [03:18<00:50,  4.21s/it]

classification Epoch 48/60, Loss: 1.9186


Training classification:  82%|████████▏ | 49/60 [03:22<00:45,  4.14s/it]

classification Epoch 49/60, Loss: 1.9093


Training classification:  83%|████████▎ | 50/60 [03:26<00:40,  4.10s/it]

classification Epoch 50/60, Loss: 1.8968


Training classification:  85%|████████▌ | 51/60 [03:31<00:37,  4.20s/it]

classification Epoch 51/60, Loss: 1.8785


Training classification:  87%|████████▋ | 52/60 [03:35<00:33,  4.14s/it]

classification Epoch 52/60, Loss: 1.8671


Training classification:  88%|████████▊ | 53/60 [03:39<00:28,  4.10s/it]

classification Epoch 53/60, Loss: 1.8758


Training classification:  90%|█████████ | 54/60 [03:43<00:25,  4.21s/it]

classification Epoch 54/60, Loss: 1.8571


Training classification:  92%|█████████▏| 55/60 [03:47<00:20,  4.14s/it]

classification Epoch 55/60, Loss: 1.8565


Training classification:  93%|█████████▎| 56/60 [03:51<00:16,  4.10s/it]

classification Epoch 56/60, Loss: 1.8307


Training classification:  95%|█████████▌| 57/60 [03:56<00:12,  4.20s/it]

classification Epoch 57/60, Loss: 1.8178


Training classification:  97%|█████████▋| 58/60 [04:00<00:08,  4.14s/it]

classification Epoch 58/60, Loss: 1.8150


Training classification:  98%|█████████▊| 59/60 [04:04<00:04,  4.10s/it]

classification Epoch 59/60, Loss: 1.8282


Training classification: 100%|██████████| 60/60 [04:08<00:00,  4.15s/it]

classification Epoch 60/60, Loss: 1.7844





Final Results:
mIoU: 0.0000 (drop: 0.2154)
mAP: 0.0000
Top-1 Accuracy: 55.8333
Model saved.
