<a href="https://colab.research.google.com/github/trie0000/external/blob/main/code_20260104_10%3A48.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# -*- coding: utf-8 -*-
"""DL_Basic_2025_Competition_NYUv2_baseline.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/17t7uAU0aST5aUt6sIJCqFneGlQrxFrJY

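# Deep Learning Basics Course, Final Assignment: NYUv2 Semantic Segmentation

## Overview
A semantic segmentation task: given an RGB image, predict which class each pixel in the image belongs to.

### Dataset
- Dataset: NYUv2 dataset
- Training data: 795 images
- Test data: 654 images
- Input: RGB image + depth map (original image sizes vary)
- Output: 13-class segmentation map
- Evaluation metric: Mean IoU (Intersection over Union)

### Dataset details ([NYU Depth Dataset V2](https://cs.nyu.edu/~fergus/datasets/nyu_depth_v2.html))
- The images are photographs of indoor scenes and contain objects such as furniture, walls, and floors.
- A 13-class segmentation label is provided for each training image.
- The data is provided in the following directory structure: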
```
data/NYUv2/
├─train/
│  ├─image/      # RGB images
│  │    000000.png
│  │    ...
│  │
│  ├─depth/      # depth maps
│  │    000000.png
│  │    ...
│  │
│  └─label/      # 13-class segmentation (ground-truth labels)
│       000000.png
│       ...
└─test/
   ├─image/      # RGB images
   │    000000.png
   │    ...
   │
   └─depth/      # depth maps
        000000.png
        ...
```

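### Task details
- The task is to predict, from the input RGB image and depth map, which of the 13 classes each pixel belongs to.
- Evaluation uses Mean IoU.
  - IoU is computed for each class, and the per-class values are averaged.
  - IoU is computed as:
  $$IoU = \frac{TP}{TP + FP + FN}$$
    - TP: True Positive (pixels of the class that are predicted as that class)
    - FP: False Positive (pixels of other classes that are wrongly predicted as the class)
    - FN: False Negative (pixels of the class that are missed)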
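
As a rough illustration of the metric above, the following NumPy sketch builds a confusion matrix and derives per-class IoU and Mean IoU from it. The helper name `mean_iou`, the toy arrays, and the choice to skip classes that appear in neither the labels nor the predictions are assumptions made for this example; the official scorer may handle such classes differently.

```python
import numpy as np

def mean_iou(pred, label, num_classes=13, ignore_index=255):
    # Hypothetical helper: per-class IoU = TP / (TP + FP + FN), then the average.
    pred = np.asarray(pred).ravel().astype(np.int64)
    label = np.asarray(label).ravel().astype(np.int64)
    keep = label != ignore_index                 # drop ignored pixels
    pred, label = pred[keep], label[keep]
    # Confusion matrix: rows are ground truth, columns are predictions.
    cm = np.bincount(label * num_classes + pred,
                     minlength=num_classes ** 2).reshape(num_classes, num_classes)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    denom = tp + fp + fn
    present = denom > 0                          # assumption: skip absent classes
    iou = tp[present] / denom[present]
    return iou.mean(), iou

# Toy 2x3 example with 3 classes and one ignored pixel (255).
label = np.array([[0, 0, 1], [1, 255, 2]])
pred = np.array([[0, 1, 1], [1, 0, 2]])
miou, per_class = mean_iou(pred, label, num_classes=3)
print(per_class, miou)
```

In the toy example the ignored pixel does not count as an error even though the prediction there is arbitrary, which is the effect of the ignore index.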

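### Preprocessing
- Input images are resized to 512×512.
- Pixel values are normalized to the range 0-1.
- Segmentation labels are integers from 0 to 12 (13 classes).
  - 255 is the ignore index (excluded from evaluation).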
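
The preprocessing steps above can be sketched in NumPy as follows. This is only an illustration: the helper names `resize_nearest` and `normalize_01` are hypothetical, and the example focuses on why label maps are resized with nearest-neighbour sampling (blending class IDs would create values that are not valid classes).

```python
import numpy as np

def resize_nearest(label, out_h, out_w):
    # Nearest-neighbour resize: each output pixel copies exactly one source pixel,
    # so class IDs (and the ignore index 255) are never blended together.
    in_h, in_w = label.shape
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    return label[rows[:, None], cols[None, :]]

def normalize_01(image_uint8):
    # Map 8-bit pixel values 0..255 to floats in the range 0..1.
    return image_uint8.astype(np.float32) / 255.0

# Toy label map containing two classes plus the ignore index.
label = np.array([[0, 12], [255, 3]], dtype=np.uint8)
big = resize_nearest(label, 4, 4)

# Toy one-row grayscale image.
image = np.array([[0, 128, 255]], dtype=np.uint8)
scaled = normalize_01(image)
print(big)
print(scaled)
```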

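### Submission format
- Predict a class (0 to 12) for every pixel of each test image (RGB + depth) and save the predictions as a NumPy array.
- File name: `submission.npy`
- Array shape: [number of test images, height, width]
- Value of each pixel: an integer from 0 to 12 (the predicted class)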
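
A minimal sketch of writing a file in this format is shown below. The helper name `save_submission`, the `uint8` dtype, and the random stand-in predictions are assumptions for illustration; the text above only requires integer class IDs from 0 to 12 in an array of shape [number of test images, height, width].

```python
import os
import tempfile
import numpy as np

def save_submission(preds, path, num_classes=13):
    # Hypothetical helper: stack per-image class maps into [N, H, W] and save them.
    arr = np.stack([np.asarray(p) for p in preds])
    assert arr.ndim == 3, "expected [num_test_images, height, width]"
    assert arr.min() >= 0 and arr.max() < num_classes, "class IDs must lie in 0..12"
    arr = arr.astype(np.uint8)   # assumption: uint8 is enough for 13 classes
    np.save(path, arr)
    return arr

# Toy stand-in for model output: 2 images of 4x6 pixels (the real test set has 654 images).
rng = np.random.default_rng(0)
fake_preds = [rng.integers(0, 13, size=(4, 6)) for _ in range(2)]
out_path = os.path.join(tempfile.mkdtemp(), "submission.npy")
saved = save_submission(fake_preds, out_path)
loaded = np.load(out_path)
print(loaded.shape, loaded.dtype)
```

Reloading the file and checking its shape and value range before uploading is a cheap way to catch a wrong axis order or leftover ignore-index values.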

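## Examples of possible improvements
- Fine-tuning a pretrained model
    - Fine-tuning a model pretrained on ImageNet or a similar dataset on this dataset can be expected to improve performance.
- Redesigning the loss function
    - Designing the loss so that it is reweighted by how often each class appears makes training more robust to class imbalance.
- Image preprocessing
    - Adding data augmentation such as RandomResizedCrop / Flip / ColorJitter can be expected to improve generalization.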
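
As one possible reading of the loss-reweighting idea above, the sketch below derives per-class weights from pixel frequencies. The helper name `inverse_frequency_weights` and the specific scheme (weight proportional to 1 / frequency, rescaled to average 1 over the classes that occur) are assumptions; other schemes, such as median-frequency balancing, are equally plausible.

```python
import numpy as np

def inverse_frequency_weights(labels, num_classes=13, ignore_index=255):
    # Count pixels per class over all label maps, skipping ignored pixels.
    flat = np.concatenate([np.asarray(l).ravel() for l in labels]).astype(np.int64)
    flat = flat[flat != ignore_index]
    counts = np.bincount(flat, minlength=num_classes).astype(np.float64)
    freq = counts / max(counts.sum(), 1.0)
    # Assumption: weight = 1 / frequency; classes that never occur get weight 0.
    weights = np.zeros(num_classes)
    present = counts > 0
    weights[present] = 1.0 / freq[present]
    # Rescale so the weights of the present classes average to 1.
    weights[present] /= weights[present].mean()
    return weights

# Toy label map: class 0 is five times as frequent as classes 1 and 2.
labels = [np.array([[0, 0, 0, 1], [0, 0, 255, 2]])]
w = inverse_frequency_weights(labels, num_classes=3)
print(w)
```

Weights like these could be converted to a tensor and passed as the `weight` argument of `torch.nn.functional.cross_entropy`, so that rare classes contribute more to the loss.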

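## Conditions for meeting the completion requirement
- The baseline scores 38.2% in the performance evaluation on omnicampus. Only submissions that exceed this 38.2% baseline are therefore accepted as meeting the completion requirement.
- The organizers have confirmed that improvements on the baseline can raise performance to 50% or higher. Use this as one target to aim for.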

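## Notes
- The final prediction model must be one that was **trained on the distributed training data** (fine-tuning included).
- **Inference that relies only on the knowledge of a pretrained model**, with no training, is **prohibited**.
(Example: only feeding inputs to an LLM such as ChatGPT and taking its output as the prediction.)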

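### Use of pretrained models
Permitted
- **Using a pretrained model as a component**: You may use a pretrained model (BERT, ViT, etc.) as part of an architecture you implement yourself (for feature extraction, embeddings, and so on).
- **Fine-tuning**: You may fine-tune a pretrained model that is used in the way described above.

Prohibited
- **Using a pretrained model built to solve the task**: Taking a pretrained model that directly solves the target task, such as those provided through transformers, and using it for inference as is or with fine-tuning only is prohibited.
  - Example of a prohibited use: applying a pretrained model built to solve VQA directly to a VQA task.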

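### Preparing the data
The downloaded data is stored on Google Drive, so Google Drive must be mounted before the data can be used. The archive cannot be unpacked on Drive itself, so "data.zip" is copied into the /content directory and unpacked there.
The cells below cannot run unless "data.zip" is on Google Drive. If you are able to place "data.zip" (**831MB**) on Google Drive, run "data_download.ipynb" first. If that is difficult, use the omnicampus exercise environment.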
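
The copy-and-unpack step described above can also be expressed in plain Python, which may help when shell commands are not available. This is a sketch under assumptions: the helper name `copy_and_extract` is hypothetical, and temporary directories with a tiny toy archive stand in for the mounted Drive folder, /content, and the real "data.zip".

```python
import os
import shutil
import tempfile
import zipfile

def copy_and_extract(zip_src, work_dir):
    # Hypothetical helper: copy the archive into the working directory, then unpack it there.
    os.makedirs(work_dir, exist_ok=True)
    local_zip = shutil.copy(zip_src, work_dir)
    with zipfile.ZipFile(local_zip) as zf:
        zf.extractall(work_dir)
    return sorted(os.listdir(work_dir))

# Toy demonstration with a tiny archive instead of the real 831MB file.
drive_dir = tempfile.mkdtemp()     # stands in for the mounted Drive folder
content_dir = tempfile.mkdtemp()   # stands in for /content
toy_zip = os.path.join(drive_dir, "data.zip")
with zipfile.ZipFile(toy_zip, "w") as zf:
    zf.writestr("data/train/image/000000.png", b"not a real image")

entries = copy_and_extract(toy_zip, content_dir)
print(entries)
```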
"""

# On omnicampus, the first four cells do not need to be run
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# After the data-download notebook saves data.zip to Google Drive, it may take
# a while for the file to show up, so mount Google Drive and confirm that
# data.zip is in the directory before running this cell.
# Copy data.zip into /content
!cp "/content/drive/MyDrive/data.zip" "/content"

# Commented out IPython magic to ensure Python compatibility.
# List the files in the current directory
# If data.zip is listed, everything is fine
# %ls

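# Unpack the data
!unzip data.zip
# Move train/ and test/ under data/ only when the archive unpacks them at the
# top level (the archive used in this run already unpacks into data/)
!mkdir -p data
!if [ -d train ]; then mv train test data/; fi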
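
"""In the omnicampus exercise environment, skipping the mount, zip, and copy-to-Drive steps of data_download.ipynb leaves "data.zip" placed in already-unpacked form. It is therefore enough to make the directory that contains the data directory the current directory.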


"""

# Commented out IPython magic to ensure Python compatibility.
# For running on omnicampus
# This example assumes the data directory is in /workspace/Segmentation/split_data_scripts_omnicampus
# %cd /workspace/Segmentation/split_data_scripts_omnicampus

# For running on omnicampus
!pip install h5py scikit-image

"""# import library"""

Mounted at /content/drive
Archive:  data.zip
  inflating: data/train/image/000600.png  
  inflating: data/train/image/000320.png  
  ...

'# import library'

In [None]:
# =========================
# UNUSED: Focal CE
# Enabled when: config.use_focal=True
# =========================
import torch
import torch.nn as nn

def focal_ce_loss(logits, target, weight=None, ignore_index=255, gamma=2.0):
    ce = nn.functional.cross_entropy(
        logits, target, weight=weight, ignore_index=ignore_index, reduction="none"
    )
    pt = torch.exp(-ce)
    loss = ((1 - pt) ** gamma) * ce
    valid = (target != ignore_index)
    denom = valid.sum().clamp_min(1)
    return (loss * valid).sum() / denom

# =========================
# UNUSED: Boundary mask / dilation / loss
# Enabled when: config.boundary_w > 0 and model.use_boundary_head=True
# =========================

def boundary_mask_from_label(label, ignore_index=255):
    valid = (label != ignore_index)
    b = torch.zeros_like(label, dtype=torch.bool)
    b[:, 1:, :] |= (label[:, 1:, :] != label[:, :-1, :])
    b[:, :-1, :] |= (label[:, :-1, :] != label[:, 1:, :])
    b[:, :, 1:] |= (label[:, :, 1:] != label[:, :, :-1])
    b[:, :, :-1] |= (label[:, :, :-1] != label[:, :, 1:])
    b &= valid
    return b

def dilate_mask(mask_bool, radius: int):
    if radius <= 0:
        return mask_bool
    x = mask_bool.float().unsqueeze(1)  # [B,1,H,W]
    k = 2 * radius + 1
    y = nn.functional.max_pool2d(x, kernel_size=k, stride=1, padding=radius)
    return (y.squeeze(1) > 0.5)

def boundary_focal_bce_dice_loss(
    boundary_logit, label,
    ignore_index=255,
    radius=2,
    gamma=2.0,
    alpha=0.25,
    dice_w=0.5,
    eps=1e-6
):
    with torch.no_grad():
        b0 = boundary_mask_from_label(label, ignore_index=ignore_index)
        b = dilate_mask(b0, radius=radius).float()
        valid = (label != ignore_index).float()

    logit = boundary_logit.squeeze(1)  # [B,H,W]
    p = torch.sigmoid(logit)
    pt = p * b + (1 - p) * (1 - b)
    w = (alpha * b + (1 - alpha) * (1 - b)) * ((1 - pt).clamp_min(1e-6) ** gamma)

    bce = nn.functional.binary_cross_entropy_with_logits(logit, b, reduction="none")
    loss_focal = (w * bce * valid).sum() / (valid.sum().clamp_min(1.0))

    prob = p * valid
    target01 = b * valid
    inter = (prob * target01).sum()
    union = prob.sum() + target01.sum()
    loss_dice = 1.0 - (2.0 * inter + eps) / (union + eps)

    return (1.0 - dice_w) * loss_focal + dice_w * loss_dice

# =========================
# UNUSED: Coarse supervision helpers
# Enabled when: config.coarse_w > 0 and a 13 -> coarse (5) class map is implemented and trained with CE
# =========================

def downsample_label_nearest(label, size_hw):
    x = label.unsqueeze(1).float()
    x = nn.functional.interpolate(x, size=size_hw, mode="nearest")
    return x.squeeze(1).long()
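
To make the focal weighting in `focal_ce_loss` above more concrete, here is a NumPy sketch of the same per-sample formula, `(1 - pt) ** gamma * ce`, on flat `[N, C]` logits. The function name `focal_ce_numpy` is hypothetical, and class weights and `ignore_index` handling are left out to keep the example short.

```python
import numpy as np

def focal_ce_numpy(logits, target, gamma=2.0):
    # logits: [N, C] raw scores, target: [N] class IDs (no ignore handling here).
    z = logits - logits.max(axis=1, keepdims=True)           # stabilise the softmax
    log_prob = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    ce = -log_prob[np.arange(len(target)), target]           # per-sample cross entropy
    pt = np.exp(-ce)                                         # probability of the true class
    return ((1.0 - pt) ** gamma) * ce, ce

logits = np.array([[4.0, 0.0, 0.0],     # confident and correct
                   [0.0, 0.0, 0.0]])    # uncertain
target = np.array([0, 0])
focal, ce = focal_ce_numpy(logits, target)
ratio = focal / ce                      # equals (1 - pt) ** gamma per sample
print(ratio)
```

The confident sample is down-weighted far more strongly than the uncertain one, and with `gamma=0` the expression reduces to plain cross entropy.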

In [4]:
# -*- coding: utf-8 -*-
"""
NYUv2 Semantic Segmentation (ResNet50-UNet) + Improved ASPP (Separable + Residual)

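Changes (as requested):
- Redundancy removed:
  - dropped the duplicated autocast(False) block (concatenating x works in the normal path)
  - no longer creates zero tensors for loss_coarse / loss_boundary (computed only when needed)
  - simplified the branching on the validation side as well
- Richer logging:
  - added per-class precision / recall / F1
  - added top-K confusions (gt -> pred counts)
  - recall is reported explicitly so the FN rate (= 1 - recall) can be read off directly
- Stronger ASPP:
  - uses SeparableConv (BN + ReLU)
  - residual connection (the input is matched with a 1x1 conv and added to the projection)
  - default rates are (3, 6, 9), assuming OS=32
  - aspp_out_ch is configurable (default 512), and the decoder follows it
- zip creation is kept as is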
"""

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"

import time
import random
import json
from dataclasses import dataclass
from zipfile import ZipFile, ZIP_DEFLATED

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import VisionDataset
from torchvision.transforms import (
    Compose, Resize, ToTensor, Normalize, InterpolationMode, ColorJitter
)
from torchvision import models
from torch.amp import autocast, GradScaler


# =========================
# Class IDs (YOUR dataset)
# =========================
ID_BED     = 0
ID_BOOK    = 1
ID_CEILING = 2
ID_CHAIR   = 3
ID_FLOOR   = 4
ID_CABINET = 5
ID_OBJECT  = 6
ID_PICTURE = 7
ID_SOFA    = 8
ID_DESK    = 9
ID_TV      = 10
ID_WALL    = 11
ID_WINDOW  = 12

CLASS_NAMES_13 = [
    "bed", "book", "ceiling", "chair", "floor", "cabinet", "object",
    "picture", "sofa", "desk", "tv", "wall", "window"
]


# =========================
# Seed
# =========================
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# =========================
# Dataset utilities
# =========================
def pil_label_to_long_tensor(lbl_pil: Image.Image) -> torch.Tensor:
    arr = np.array(lbl_pil, dtype=np.int64)
    return torch.from_numpy(arr).long()

def depth_pil_to_tensor_01(depth_pil: Image.Image, size):
    depth_pil = depth_pil.resize((size[1], size[0]), resample=Image.BILINEAR)
    arr = np.array(depth_pil)

    if arr.dtype == np.uint16 or (arr.max() > 255):
        arr = arr.astype(np.float32) / 65535.0
    else:
        arr = arr.astype(np.float32) / 255.0

    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    arr = np.clip(arr, 0.0, 1.0)
    return torch.from_numpy(arr).unsqueeze(0)  # [1,H,W]


class RandomGrayscale3ch:
    def __init__(self, p=0.0):
        self.p = float(p)

    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() >= self.p:
            return img
        return img.convert("L").convert("RGB")


class NYUv2(VisionDataset):
    def __init__(
        self,
        root: str,
        split: str = "train",
        include_depth: bool = True,
        image_transform=None,
        depth_transform=None,
        target_transform=None,
    ):
        super().__init__(root)
        assert split in ("train", "test")
        self.root = root
        self.split = split
        self.include_depth = include_depth

        images_dir = os.path.join(self.root, self.split, "image")
        img_names = sorted(os.listdir(images_dir))

        self.images = [os.path.join(images_dir, n) for n in img_names]
        self.depths = [os.path.join(self.root, self.split, "depth", n) for n in img_names]

        if self.split == "train":
            self.targets = [os.path.join(self.root, self.split, "label", n) for n in img_names]
        else:
            self.targets = None

        self.image_transform = image_transform
        self.depth_transform = depth_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("RGB")
        depth = Image.open(self.depths[idx])

        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.depth_transform is not None:
            depth = self.depth_transform(depth)

        if self.split == "test":
            return (image, depth) if self.include_depth else image

        target = Image.open(self.targets[idx])
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, depth, target) if self.include_depth else (image, target)


# =========================
# (optional) Boundary utils
# =========================
def boundary_mask_from_label(label, ignore_index=255):
    valid = (label != ignore_index)
    b = torch.zeros_like(label, dtype=torch.bool)
    b[:, 1:, :] |= (label[:, 1:, :] != label[:, :-1, :])
    b[:, :-1, :] |= (label[:, :-1, :] != label[:, 1:, :])
    b[:, :, 1:] |= (label[:, :, 1:] != label[:, :, :-1])
    b[:, :, :-1] |= (label[:, :, :-1] != label[:, :, 1:])
    b &= valid
    return b

def dilate_mask(mask_bool, radius: int):
    if radius <= 0:
        return mask_bool
    x = mask_bool.float().unsqueeze(1)
    k = 2 * radius + 1
    y = nn.functional.max_pool2d(x, kernel_size=k, stride=1, padding=radius)
    return (y.squeeze(1) > 0.5)

def boundary_focal_bce_dice_loss(
    boundary_logit, label,
    ignore_index=255,
    radius=2,
    gamma=2.0,
    alpha=0.25,
    dice_w=0.5,
    eps=1e-6
):
    with torch.no_grad():
        b0 = boundary_mask_from_label(label, ignore_index=ignore_index)
        b = dilate_mask(b0, radius=radius).float()
        valid = (label != ignore_index).float()

    logit = boundary_logit.squeeze(1)
    p = torch.sigmoid(logit)
    pt = p * b + (1 - p) * (1 - b)
    w = (alpha * b + (1 - alpha) * (1 - b)) * ((1 - pt).clamp_min(1e-6) ** gamma)

    bce = nn.functional.binary_cross_entropy_with_logits(logit, b, reduction="none")
    loss_focal = (w * bce * valid).sum() / valid.sum().clamp_min(1.0)

    prob = p * valid
    target01 = b * valid
    inter = (prob * target01).sum()
    union = prob.sum() + target01.sum()
    loss_dice = 1.0 - (2.0 * inter + eps) / (union + eps)

    return (1.0 - dice_w) * loss_focal + dice_w * loss_dice


# =========================
# Model blocks
# =========================
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1, d=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=p, dilation=d, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class SeparableConvBNReLU(nn.Module):
    """Depthwise separable conv: DW(3x3) + PW(1x1) + BN + ReLU"""
    def __init__(self, in_ch, out_ch, k=3, p=1, d=1):
        super().__init__()
        self.dw = nn.Conv2d(in_ch, in_ch, kernel_size=k, padding=p, dilation=d, groups=in_ch, bias=False)
        self.pw = nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)
    def forward(self, x):
        x = self.dw(x)
        x = self.pw(x)
        x = self.bn(x)
        return self.act(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.conv1 = ConvBNReLU(in_ch + skip_ch, out_ch)
        self.conv2 = ConvBNReLU(out_ch, out_ch)

    def forward(self, x, skip):
        x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


# =========================
# Improved ASPP (Separable + Residual)
# =========================
class ASPP(nn.Module):
    """
    Improved ASPP:
      - 1x1 conv
      - 3x3 separable conv (dilation=r1,r2,r3)
      - optional image pooling branch
    concat -> 1x1 projection -> (residual add) -> BN/ReLU/Dropout
    """
    def __init__(
        self,
        in_ch: int,
        out_ch: int = 512,
        atrous_rates=(3, 6, 9),         # small default rates, assuming output stride (OS) 32
        use_image_pooling: bool = True,
        proj_dropout: float = 0.1,      # a small nonzero dropout tends to train more stably
        use_separable: bool = True,
        use_residual: bool = True,
    ):
        super().__init__()
        r1, r2, r3 = atrous_rates
        bch = out_ch // 4

        self.branch1 = nn.Sequential(
            nn.Conv2d(in_ch, bch, kernel_size=1, bias=False),
            nn.BatchNorm2d(bch),
            nn.ReLU(inplace=True),
        )
        Conv3 = SeparableConvBNReLU if use_separable else ConvBNReLU
        self.branch2 = Conv3(in_ch, bch, k=3, p=r1, d=r1)
        self.branch3 = Conv3(in_ch, bch, k=3, p=r2, d=r2)
        self.branch4 = Conv3(in_ch, bch, k=3, p=r3, d=r3)

        self.use_image_pooling = bool(use_image_pooling)
        if self.use_image_pooling:
            self.img_pool = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_ch, bch, kernel_size=1, bias=False),
                nn.BatchNorm2d(bch),
                nn.ReLU(inplace=True),
            )
            cat_ch = bch * 5
        else:
            self.img_pool = None
            cat_ch = bch * 4

        self.project = nn.Conv2d(cat_ch, out_ch, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)
        self.drop = nn.Dropout2d(p=float(proj_dropout)) if proj_dropout and proj_dropout > 0 else nn.Identity()

        self.use_residual = bool(use_residual)
        if self.use_residual:
            self.res_proj = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.res_proj = None

    def forward(self, x):
        b, c, h, w = x.shape
        y1 = self.branch1(x)
        y2 = self.branch2(x)
        y3 = self.branch3(x)
        y4 = self.branch4(x)
        ys = [y1, y2, y3, y4]

        if self.use_image_pooling:
            y5 = self.img_pool(x)
            y5 = nn.functional.interpolate(y5, size=(h, w), mode="bilinear", align_corners=False)
            ys.append(y5)

        y = torch.cat(ys, dim=1)
        y = self.project(y)

        if self.use_residual:
            y = y + self.res_proj(x)

        y = self.bn(y)
        y = self.act(y)
        y = self.drop(y)
        return y


# =========================
# ResNet50 UNet + ASPP at center
# =========================
class ResNet50UNetASPP(nn.Module):
    def __init__(
        self,
        num_classes=13,
        coarse_classes=5,
        in_channels=4,
        pretrained=True,
        use_aspp=True,
        aspp_rates=(3, 6, 9),
        aspp_image_pooling=True,
        aspp_dropout=0.1,
        aspp_out_ch=512,
        aspp_use_separable=True,
        aspp_use_residual=True,
        use_boundary_head=False,
    ):
        super().__init__()
        self.use_aspp = bool(use_aspp)
        self.use_boundary_head = bool(use_boundary_head)
        self.center_ch = int(aspp_out_ch)

        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)

        # extend conv1 to RGB-D input (4 channels)
        old_conv = resnet.conv1
        if in_channels != old_conv.in_channels:
            new_conv = nn.Conv2d(
                in_channels, old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False
            )
            with torch.no_grad():
                new_conv.weight[:, :3] = old_conv.weight
                if in_channels > 3:
                    mean_w = old_conv.weight.mean(dim=1, keepdim=True)
                    for c in range(3, in_channels):
                        new_conv.weight[:, c:c+1] = mean_w
            resnet.conv1 = new_conv

        # encoder
        self.enc0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)  # 1/2, 64
        self.pool = resnet.maxpool                                        # 1/4
        self.enc1 = resnet.layer1                                         # 1/4, 256
        self.enc2 = resnet.layer2                                         # 1/8, 512
        self.enc3 = resnet.layer3                                         # 1/16, 1024
        self.enc4 = resnet.layer4                                         # 1/32, 2048

        # center
        if self.use_aspp:
            self.center = ASPP(
                in_ch=2048,
                out_ch=self.center_ch,
                atrous_rates=aspp_rates,
                use_image_pooling=aspp_image_pooling,
                proj_dropout=aspp_dropout,
                use_separable=aspp_use_separable,
                use_residual=aspp_use_residual,
            )
        else:
            self.center = ConvBNReLU(2048, self.center_ch)

        # decoder (input width follows center_ch)
        self.up4 = UpBlock(in_ch=self.center_ch, skip_ch=1024, out_ch=512)
        self.up3 = UpBlock(in_ch=512,         skip_ch=512,  out_ch=256)  # 1/8
        self.up2 = UpBlock(in_ch=256,         skip_ch=256,  out_ch=128)
        self.up1 = UpBlock(in_ch=128,         skip_ch=64,   out_ch=64)

        # fine head
        self.head_feat = ConvBNReLU(64, 64)
        self.head_out  = nn.Conv2d(64, num_classes, kernel_size=1)

        # coarse head at 1/8 resolution (evaluated only when requested)
        self.coarse_head = nn.Conv2d(256, coarse_classes, kernel_size=1)

        # optional boundary head @H,W
        if self.use_boundary_head:
            self.boundary_head = nn.Sequential(
                ConvBNReLU(64, 32),
                nn.Conv2d(32, 1, kernel_size=1)
            )
        else:
            self.boundary_head = None

    def forward(self, x, need_coarse: bool = False, need_boundary: bool = False):
        c1 = self.enc0(x)          # 1/2
        t  = self.pool(c1)         # 1/4
        c2 = self.enc1(t)          # 1/4
        c3 = self.enc2(c2)         # 1/8
        c4 = self.enc3(c3)         # 1/16
        c5 = self.enc4(c4)         # 1/32

        x = self.center(c5)        # 1/32, center_ch

        x = self.up4(x, c4)
        x = self.up3(x, c3)

        coarse_logits = self.coarse_head(x) if need_coarse else None

        x = self.up2(x, c2)
        x = self.up1(x, c1)

        x = nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)  # -> H,W
        feat = self.head_feat(x)
        fine_logits = self.head_out(feat)

        boundary_logit = self.boundary_head(x) if (need_boundary and self.boundary_head is not None) else None
        return fine_logits, coarse_logits, boundary_logit


# =========================
# Losses
# =========================
def focal_ce_loss(logits, target, weight=None, ignore_index=255, gamma=2.0):
    ce = nn.functional.cross_entropy(
        logits, target, weight=weight, ignore_index=ignore_index, reduction="none"
    )
    pt = torch.exp(-ce)
    loss = ((1 - pt) ** gamma) * ce
    valid = (target != ignore_index)
    denom = valid.sum().clamp_min(1)
    return (loss * valid).sum() / denom

def main_seg_loss(
    logits, label,
    class_weight_tensor=None,
    use_focal=False,
    gamma=2.0,
    ignore_index=255
):
    if use_focal:
        return focal_ce_loss(logits, label, weight=class_weight_tensor, ignore_index=ignore_index, gamma=gamma)
    return nn.functional.cross_entropy(logits, label, weight=class_weight_tensor, ignore_index=ignore_index)

def downsample_label_nearest(label, size_hw):
    x = label.unsqueeze(1).float()
    x = nn.functional.interpolate(x, size=size_hw, mode="nearest")
    return x.squeeze(1).long()


# =========================
# Metrics / Logging helpers
# =========================
@torch.no_grad()
def update_confusion_matrix(cm, pred, label, num_classes, ignore_index=255):
    pred = pred.view(-1)
    label = label.view(-1)
    mask = (label != ignore_index)
    pred = pred[mask]
    label = label[mask]
    if pred.numel() == 0:
        return cm
    idx = label * num_classes + pred
    binc = torch.bincount(idx, minlength=num_classes*num_classes).cpu()
    cm += binc.view(num_classes, num_classes)
    return cm

def compute_iou_prf_from_cm(cm):
    cmf = cm.float()
    tp = torch.diag(cmf)
    fp = cmf.sum(dim=0) - tp
    fn = cmf.sum(dim=1) - tp
    denom_iou = tp + fp + fn
    iou = torch.where(denom_iou > 0, tp / denom_iou, torch.zeros_like(denom_iou))

    prec_d = (tp + fp).clamp_min(1.0)
    rec_d  = (tp + fn).clamp_min(1.0)
    precision = tp / prec_d
    recall    = tp / rec_d
    f1 = (2 * precision * recall) / (precision + recall).clamp_min(1e-12)

    miou = iou.mean().item()
    return iou, miou, tp, fp, fn, precision, recall, f1

def top_confusions_from_cm(cm, k=20, ignore_diag=True):
    # return list of (gt, pred, count) sorted desc
    cm_ = cm.clone()
    if ignore_diag:
        cm_.fill_diagonal_(0)
    flat = cm_.view(-1)
    if flat.sum().item() == 0:
        return []
    topv, topi = torch.topk(flat, k=min(k, flat.numel()))
    res = []
    n = cm.size(0)
    for v, idx in zip(topv.tolist(), topi.tolist()):
        if v <= 0:
            continue
        gt = idx // n
        pr = idx % n
        res.append((gt, pr, int(v)))
    return res


# =========================
# Config
# =========================
@dataclass
class TrainingConfig:
    dataset_root: str = "/content/data"
    image_size: tuple = (512, 512)
    batch_size: int = 16
    num_workers: int = 4
    num_classes: int = 13
    coarse_classes: int = 5
    epochs: int = 50

    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    train_val_split: float = 0.9
    ignore_index: int = 255
    seed: int = 42

    # Normalize (ImageNet)
    normalize_mean: tuple = (0.485, 0.456, 0.406)
    normalize_std: tuple = (0.229, 0.224, 0.225)

    # Augmentation switches
    use_colorjitter: bool = False
    jitter_brightness: float = 0.4
    jitter_contrast: float = 0.4
    jitter_saturation: float = 0.3
    jitter_hue: float = 0.05
    gray_p: float = 0.0

    # ASPP (improved defaults)
    use_aspp: bool = True
    aspp_rates: tuple = (3, 6, 9)
    aspp_image_pooling: bool = True
    aspp_dropout: float = 0.1
    aspp_out_ch: int = 512
    aspp_use_separable: bool = True
    aspp_use_residual: bool = True

    # Optional heads / losses
    use_focal: bool = False
    focal_gamma: float = 2.0

    coarse_w: float = 0.0
    boundary_w: float = 0.0
    boundary_radius: int = 2
    boundary_focal_gamma: float = 2.0
    boundary_focal_alpha: float = 0.25
    boundary_dice_w: float = 0.5

    # Logging / saving
    checkpoint_dir: str = "checkpoints"
    log_path: str = "train_log.jsonl"
    save_best: bool = False
    confusion_topk: int = 20

    def __post_init__(self):
        os.makedirs(self.checkpoint_dir, exist_ok=True)


# =========================
# Main
# =========================
def main():
    config = TrainingConfig(
        dataset_root="/content/data",
        image_size=(512, 512),
        batch_size=16,
        num_workers=4,
        epochs=50,
        learning_rate=1e-4,
        weight_decay=1e-4,

        # ASPP ON (improved)
        use_aspp=True,
        aspp_rates=(3, 6, 9),
        aspp_image_pooling=True,
        aspp_dropout=0.1,
        aspp_out_ch=512,
        aspp_use_separable=True,
        aspp_use_residual=True,

        # stay close to the baseline: augmentation off (turn on if needed)
        use_colorjitter=False,
        gray_p=0.0,

        # stay close to the baseline: auxiliary losses off (turn on if needed)
        coarse_w=0.0,
        boundary_w=0.0,

        # focal loss off (turn on if needed)
        use_focal=False,
    )

    run_ts = time.strftime("%Y%m%d%H%M%S")
    base, ext = os.path.splitext(config.log_path)
    if ext.lower() != ".jsonl":
        ext = ".jsonl"
    config.log_path = f"{base}_{run_ts}{ext}"

    set_seed(config.seed)
    device = config.device
    print(f"Using device: {device}")
    print(f"Run timestamp: {run_ts}")
    print(f"Log path     : {config.log_path}")

    # transforms
    jitter_tf = []
    if config.use_colorjitter:
        jitter_tf.append(ColorJitter(
            brightness=config.jitter_brightness,
            contrast=config.jitter_contrast,
            saturation=config.jitter_saturation,
            hue=config.jitter_hue,
        ))

    train_image_transform = Compose([
        Resize(config.image_size, interpolation=InterpolationMode.BILINEAR),
        RandomGrayscale3ch(p=config.gray_p),
        *jitter_tf,
        ToTensor(),
        Normalize(config.normalize_mean, config.normalize_std),
    ])

    eval_image_transform = Compose([
        Resize(config.image_size, interpolation=InterpolationMode.BILINEAR),
        ToTensor(),
        Normalize(config.normalize_mean, config.normalize_std),
    ])

    def depth_transform(depth_pil):
        return depth_pil_to_tensor_01(depth_pil, config.image_size)

    def target_transform(lbl_pil: Image.Image) -> torch.Tensor:
        lbl_pil = lbl_pil.resize((config.image_size[1], config.image_size[0]), resample=Image.NEAREST)
        return pil_label_to_long_tensor(lbl_pil)

    # fixed split indices (this base dataset is used only to choose them)
    split_base = NYUv2(
        root=config.dataset_root,
        split="train",
        include_depth=True,
        image_transform=eval_image_transform,
        depth_transform=depth_transform,
        target_transform=target_transform,
    )
    n_total = len(split_base)
    n_train = int(n_total * config.train_val_split)
    n_val = n_total - n_train
    g = torch.Generator().manual_seed(config.seed)
    train_subset_base, val_subset_base = random_split(split_base, [n_train, n_val], generator=g)
    train_indices = train_subset_base.indices
    val_indices = val_subset_base.indices

    # actual datasets (identical except for the image transform)
    train_full = NYUv2(
        root=config.dataset_root,
        split="train",
        include_depth=True,
        image_transform=train_image_transform,
        depth_transform=depth_transform,
        target_transform=target_transform,
    )
    val_full = NYUv2(
        root=config.dataset_root,
        split="train",
        include_depth=True,
        image_transform=eval_image_transform,
        depth_transform=depth_transform,
        target_transform=target_transform,
    )
    train_ds = Subset(train_full, train_indices)
    val_ds   = Subset(val_full, val_indices)

    test_ds = NYUv2(
        root=config.dataset_root,
        split="test",
        include_depth=True,
        image_transform=eval_image_transform,
        depth_transform=depth_transform,
        target_transform=None,
    )

    train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
                              num_workers=config.num_workers, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                            num_workers=config.num_workers, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False,
                             num_workers=config.num_workers, pin_memory=True)

    # model
    model = ResNet50UNetASPP(
        num_classes=config.num_classes,
        coarse_classes=config.coarse_classes,
        in_channels=4,
        pretrained=True,
        use_aspp=config.use_aspp,
        aspp_rates=config.aspp_rates,
        aspp_image_pooling=config.aspp_image_pooling,
        aspp_dropout=config.aspp_dropout,
        aspp_out_ch=config.aspp_out_ch,
        aspp_use_separable=config.aspp_use_separable,
        aspp_use_residual=config.aspp_use_residual,
        use_boundary_head=(config.boundary_w > 0.0),
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    scaler = GradScaler("cuda" if str(device).startswith("cuda") else "cpu")

    # log init
    with open(config.log_path, "w", encoding="utf-8") as f:
        pass

    def log_json(obj):
        with open(config.log_path, "a", encoding="utf-8") as f:
            f.write(json.dumps(obj, ensure_ascii=False) + "\n")

    log_json({
        "type": "run_header",
        "run_ts": run_ts,
        "device": str(device),
        "dataset_root": config.dataset_root,
        "image_size": list(config.image_size),
        "batch_size": config.batch_size,
        "epochs": config.epochs,
        "lr": config.learning_rate,
        "weight_decay": config.weight_decay,
        "use_aspp": bool(config.use_aspp),
        "aspp_rates": list(config.aspp_rates),
        "aspp_image_pooling": bool(config.aspp_image_pooling),
        "aspp_dropout": float(config.aspp_dropout),
        "aspp_out_ch": int(config.aspp_out_ch),
        "aspp_use_separable": bool(config.aspp_use_separable),
        "aspp_use_residual": bool(config.aspp_use_residual),
        "use_colorjitter": bool(config.use_colorjitter),
        "gray_p": float(config.gray_p),
        "use_focal": bool(config.use_focal),
        "coarse_w": float(config.coarse_w),
        "boundary_w": float(config.boundary_w),
        "class_names_13": CLASS_NAMES_13,
    })

    best_miou = -1.0
    best_model_path = os.path.join(config.checkpoint_dir, f"model_best_{run_ts}.pt")

    need_coarse = (config.coarse_w > 0.0)
    need_boundary = (config.boundary_w > 0.0)

    for epoch in range(config.epochs):
        # train
        model.train()
        total_loss = 0.0
        total_main = 0.0
        total_coarse = 0.0
        total_boundary = 0.0
        seen = 0

        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs} (train)") as pbar:
            for (rgb, depth, label) in pbar:
                rgb = rgb.to(device, non_blocking=True)
                depth = depth.to(device, non_blocking=True)
                label = label.to(device, non_blocking=True)

                optimizer.zero_grad(set_to_none=True)
                x = torch.cat([rgb.float(), depth.float()], dim=1)

                with autocast(device_type="cuda" if str(device).startswith("cuda") else "cpu", enabled=True):
                    fine_logits, coarse_logits, boundary_logit = model(
                        x, need_coarse=need_coarse, need_boundary=need_boundary
                    )

                    loss_main = main_seg_loss(
                        fine_logits, label,
                        class_weight_tensor=None,
                        use_focal=config.use_focal,
                        gamma=config.focal_gamma,
                        ignore_index=config.ignore_index
                    )

                    loss_coarse = 0.0
                    if need_coarse:
                        # coarse supervision depends on the design (left at 0 for now)
                        # if needed, add a 13->5 class map here and implement a CE loss
                        loss_coarse = 0.0

                    loss_boundary = 0.0
                    if need_boundary:
                        loss_boundary = boundary_focal_bce_dice_loss(
                            boundary_logit, label,
                            ignore_index=config.ignore_index,
                            radius=config.boundary_radius,
                            gamma=config.boundary_focal_gamma,
                            alpha=config.boundary_focal_alpha,
                            dice_w=config.boundary_dice_w
                        )

                    loss = loss_main
                    if need_coarse:
                        loss = loss + config.coarse_w * loss_coarse
                    if need_boundary:
                        loss = loss + config.boundary_w * loss_boundary

                if torch.isnan(loss) or torch.isinf(loss):
                    log_json({"type": "train_nan", "epoch": epoch + 1, "msg": "loss is NaN/Inf, skipped batch"})
                    continue

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                bs = rgb.size(0)
                total_loss += float(loss.item()) * bs
                total_main += float(loss_main.item()) * bs
                total_coarse += (float(loss_coarse) if isinstance(loss_coarse, (int, float)) else float(loss_coarse.item())) * bs
                total_boundary += (float(loss_boundary) if isinstance(loss_boundary, (int, float)) else float(loss_boundary.item())) * bs
                seen += bs

                pbar.set_postfix(
                    loss=float(loss.item()),
                    main=float(loss_main.item()),
                    coarse=float(total_coarse/max(1,seen)),
                    b=float(total_boundary/max(1,seen)),
                    aspp=bool(config.use_aspp),
                )

        train_loss = total_loss / max(1, seen)
        train_main = total_main / max(1, seen)
        train_coarse = total_coarse / max(1, seen)
        train_boundary = total_boundary / max(1, seen)

        # val
        model.eval()
        cm = torch.zeros((config.num_classes, config.num_classes), dtype=torch.long)
        val_loss_sum, val_seen = 0.0, 0

        with torch.no_grad():
            for (rgb, depth, label) in tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.epochs} (val)", leave=False):
                rgb = rgb.to(device, non_blocking=True)
                depth = depth.to(device, non_blocking=True)
                label = label.to(device, non_blocking=True)

                x = torch.cat([rgb.float(), depth.float()], dim=1)

                with autocast(device_type="cuda" if str(device).startswith("cuda") else "cpu", enabled=True):
                    fine_logits, coarse_logits, boundary_logit = model(
                        x, need_coarse=need_coarse, need_boundary=need_boundary
                    )
                    loss_main = main_seg_loss(
                        fine_logits, label,
                        class_weight_tensor=None,
                        use_focal=config.use_focal,
                        gamma=config.focal_gamma,
                        ignore_index=config.ignore_index
                    )

                    vloss = loss_main
                    if need_boundary:
                        lb = boundary_focal_bce_dice_loss(
                            boundary_logit, label,
                            ignore_index=config.ignore_index,
                            radius=config.boundary_radius,
                            gamma=config.boundary_focal_gamma,
                            alpha=config.boundary_focal_alpha,
                            dice_w=config.boundary_dice_w
                        )
                        vloss = vloss + config.boundary_w * lb
                    # coarse loss is currently disabled (add it the same way if needed)

                pred = fine_logits.argmax(dim=1)
                cm = update_confusion_matrix(cm, pred.cpu(), label.cpu(), config.num_classes, ignore_index=config.ignore_index)

                val_loss_sum += float(vloss.item())
                val_seen += 1

        iou, miou, tp, fp, fn, precision, recall, f1 = compute_iou_prf_from_cm(cm)

        iou_list = iou.tolist()
        prec_list = precision.tolist()
        rec_list = recall.tolist()
        f1_list = f1.tolist()

        worst = sorted(range(config.num_classes), key=lambda c: iou_list[c])[:5]
        best  = sorted(range(config.num_classes), key=lambda c: iou_list[c], reverse=True)[:5]

        conf_top = top_confusions_from_cm(cm, k=config.confusion_topk, ignore_diag=True)
        conf_top_named = [
            {"gt": int(g), "pred": int(p), "count": int(cnt),
             "gt_name": CLASS_NAMES_13[g], "pred_name": CLASS_NAMES_13[p]}
            for (g, p, cnt) in conf_top
        ]

        print(
            f"Epoch {epoch+1}: "
            f"train_loss={train_loss:.4f} (main={train_main:.4f}, coarse={train_coarse:.4f}, b={train_boundary:.4f}) "
            f"val_loss={val_loss_sum/max(1,val_seen):.4f} val_mIoU={miou:.5f} "
            f"(ASPP={config.use_aspp}, out={config.aspp_out_ch}, rates={config.aspp_rates}, jitter={config.use_colorjitter})"
        )

        log_json({
            "type": "epoch",
            "run_ts": run_ts,
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_main": train_main,
            "train_coarse": train_coarse,
            "train_boundary": train_boundary,
            "val_loss": val_loss_sum / max(1, val_seen),
            "val_miou": miou,

            "class_iou": iou_list,
            "class_precision": prec_list,
            "class_recall": rec_list,
            "class_f1": f1_list,

            "tp": tp.tolist(),
            "fp": fp.tolist(),
            "fn": fn.tolist(),

            "pred_hist": cm.sum(dim=0).tolist(),
            "gt_hist": cm.sum(dim=1).tolist(),

            "best_classes": best,
            "worst_classes": worst,
            "best_names": [CLASS_NAMES_13[i] for i in best],
            "worst_names": [CLASS_NAMES_13[i] for i in worst],

            "confusion_topk": conf_top_named,

            "use_aspp": bool(config.use_aspp),
            "aspp_out_ch": int(config.aspp_out_ch),
            "aspp_rates": list(config.aspp_rates),
            "aspp_use_separable": bool(config.aspp_use_separable),
            "aspp_use_residual": bool(config.aspp_use_residual),
        })

        if config.save_best and miou > best_miou:
            best_miou = miou
            torch.save(model.state_dict(), best_model_path)
            log_json({"type": "best", "epoch": epoch + 1, "val_miou": best_miou, "path": best_model_path})
            print(f"  [BEST] saved: epoch={epoch+1} miou={best_miou:.5f} -> {best_model_path}")

    # Save final
    current_time = time.strftime("%Y%m%d%H%M%S")
    model_path = os.path.join(config.checkpoint_dir, f"model_final_{current_time}.pt")
    torch.save(model.state_dict(), model_path)
    print(f"Final model saved to {model_path}")
    print(f"Log saved to {config.log_path}")

    # Inference -> submission.npy (uses the final-epoch weights saved just above, not model_best_*)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    predictions = []
    with torch.no_grad():
        for (rgb, depth) in tqdm(test_loader, desc="Generating predictions"):
            rgb = rgb.to(device, non_blocking=True)
            depth = depth.to(device, non_blocking=True)
            x = torch.cat([rgb.float(), depth.float()], dim=1)

            with autocast(device_type="cuda" if str(device).startswith("cuda") else "cpu", enabled=True):
                fine_logits, _, _ = model(x, need_coarse=False, need_boundary=False)
                pred = fine_logits.argmax(dim=1)

            predictions.append(pred.cpu())

    predictions = torch.cat(predictions, dim=0).numpy()
    np.save("submission.npy", predictions)
    print("Predictions saved to submission.npy")

    # Zip for submission (kept as-is, as requested)
    notebook_path = "/content/drive/MyDrive/Colab Notebooks/DL_Basic_2025_Competition_NYUv2_baseline.ipynb"
    with ZipFile("submission.zip", mode="w", compression=ZIP_DEFLATED, compresslevel=9) as zf:
        zf.write("submission.npy")
        zf.write(model_path, arcname=os.path.basename(model_path))
        zf.write(config.log_path, arcname=os.path.basename(config.log_path))
        if os.path.exists(notebook_path):
            zf.write(notebook_path, arcname="DL_Basic_2025_Competition_NYUv2_baseline.ipynb")
    print("Created submission.zip")


if __name__ == "__main__":
    main()


Using device: cuda
Run timestamp: 20260104013547
Log path     : train_log_20260104013547.jsonl


Epoch 1/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.24it/s, aspp=1, b=0, coarse=0, loss=1.54, main=1.54]


Epoch 1: train_loss=1.8812 (main=1.8812, coarse=0.0000, b=0.0000) val_loss=1.5892 val_mIoU=0.21545 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 2/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.24it/s, aspp=1, b=0, coarse=0, loss=1.27, main=1.27]


Epoch 2: train_loss=1.3729 (main=1.3729, coarse=0.0000, b=0.0000) val_loss=1.3713 val_mIoU=0.26400 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 3/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.27it/s, aspp=1, b=0, coarse=0, loss=1.12, main=1.12]


Epoch 3: train_loss=1.1108 (main=1.1108, coarse=0.0000, b=0.0000) val_loss=1.2540 val_mIoU=0.32762 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 4/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.23it/s, aspp=1, b=0, coarse=0, loss=0.806, main=0.806]


Epoch 4: train_loss=0.9154 (main=0.9154, coarse=0.0000, b=0.0000) val_loss=1.2138 val_mIoU=0.36588 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 5/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.29it/s, aspp=1, b=0, coarse=0, loss=0.75, main=0.75]


Epoch 5: train_loss=0.7706 (main=0.7706, coarse=0.0000, b=0.0000) val_loss=1.1414 val_mIoU=0.41282 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 6/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.24it/s, aspp=1, b=0, coarse=0, loss=0.617, main=0.617]


Epoch 6: train_loss=0.6524 (main=0.6524, coarse=0.0000, b=0.0000) val_loss=1.1103 val_mIoU=0.42587 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 7/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.28it/s, aspp=1, b=0, coarse=0, loss=0.524, main=0.524]


Epoch 7: train_loss=0.5679 (main=0.5679, coarse=0.0000, b=0.0000) val_loss=1.0647 val_mIoU=0.43757 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 8/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.26it/s, aspp=1, b=0, coarse=0, loss=0.422, main=0.422]


Epoch 8: train_loss=0.4946 (main=0.4946, coarse=0.0000, b=0.0000) val_loss=1.0764 val_mIoU=0.44708 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 9/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.27it/s, aspp=1, b=0, coarse=0, loss=0.481, main=0.481]


Epoch 9: train_loss=0.4296 (main=0.4296, coarse=0.0000, b=0.0000) val_loss=1.0391 val_mIoU=0.46217 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 10/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.32it/s, aspp=1, b=0, coarse=0, loss=0.365, main=0.365]


Epoch 10: train_loss=0.3885 (main=0.3885, coarse=0.0000, b=0.0000) val_loss=1.0131 val_mIoU=0.47143 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 11/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.27it/s, aspp=1, b=0, coarse=0, loss=0.351, main=0.351]


Epoch 11: train_loss=0.3565 (main=0.3565, coarse=0.0000, b=0.0000) val_loss=1.0557 val_mIoU=0.46643 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 12/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.25it/s, aspp=1, b=0, coarse=0, loss=0.412, main=0.412]


Epoch 12: train_loss=0.3184 (main=0.3184, coarse=0.0000, b=0.0000) val_loss=0.9735 val_mIoU=0.50161 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 13/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.27it/s, aspp=1, b=0, coarse=0, loss=0.288, main=0.288]


Epoch 13: train_loss=0.2859 (main=0.2859, coarse=0.0000, b=0.0000) val_loss=0.9533 val_mIoU=0.52873 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 14/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.26it/s, aspp=1, b=0, coarse=0, loss=0.246, main=0.246]


Epoch 14: train_loss=0.2565 (main=0.2565, coarse=0.0000, b=0.0000) val_loss=0.9608 val_mIoU=0.51311 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 15/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.30it/s, aspp=1, b=0, coarse=0, loss=0.228, main=0.228]


Epoch 15: train_loss=0.2325 (main=0.2325, coarse=0.0000, b=0.0000) val_loss=0.9871 val_mIoU=0.50119 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 16/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.25it/s, aspp=1, b=0, coarse=0, loss=0.228, main=0.228]


Epoch 16: train_loss=0.2232 (main=0.2232, coarse=0.0000, b=0.0000) val_loss=0.9830 val_mIoU=0.51938 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 17/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.27it/s, aspp=1, b=0, coarse=0, loss=0.208, main=0.208]


Epoch 17: train_loss=0.2031 (main=0.2031, coarse=0.0000, b=0.0000) val_loss=0.9899 val_mIoU=0.52847 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 18/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.23it/s, aspp=1, b=0, coarse=0, loss=0.154, main=0.154]


Epoch 18: train_loss=0.1810 (main=0.1810, coarse=0.0000, b=0.0000) val_loss=0.9745 val_mIoU=0.51646 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 19/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.25it/s, aspp=1, b=0, coarse=0, loss=0.194, main=0.194]


Epoch 19: train_loss=0.1701 (main=0.1701, coarse=0.0000, b=0.0000) val_loss=0.9601 val_mIoU=0.53496 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 20/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.23it/s, aspp=1, b=0, coarse=0, loss=0.163, main=0.163]


Epoch 20: train_loss=0.1647 (main=0.1647, coarse=0.0000, b=0.0000) val_loss=1.0428 val_mIoU=0.48826 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 21/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.24it/s, aspp=1, b=0, coarse=0, loss=0.151, main=0.151]


Epoch 21: train_loss=0.1628 (main=0.1628, coarse=0.0000, b=0.0000) val_loss=1.0831 val_mIoU=0.51215 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 22/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.28it/s, aspp=1, b=0, coarse=0, loss=0.14, main=0.14]


Epoch 22: train_loss=0.1583 (main=0.1583, coarse=0.0000, b=0.0000) val_loss=1.0273 val_mIoU=0.50889 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 23/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.28it/s, aspp=1, b=0, coarse=0, loss=0.136, main=0.136]


Epoch 23: train_loss=0.1558 (main=0.1558, coarse=0.0000, b=0.0000) val_loss=1.0007 val_mIoU=0.52735 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 24/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.28it/s, aspp=1, b=0, coarse=0, loss=0.138, main=0.138]


Epoch 24: train_loss=0.1362 (main=0.1362, coarse=0.0000, b=0.0000) val_loss=0.9767 val_mIoU=0.53759 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 25/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.29it/s, aspp=1, b=0, coarse=0, loss=0.128, main=0.128]


Epoch 25: train_loss=0.1226 (main=0.1226, coarse=0.0000, b=0.0000) val_loss=1.0093 val_mIoU=0.53199 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 26/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.27it/s, aspp=1, b=0, coarse=0, loss=0.137, main=0.137]


Epoch 26: train_loss=0.1136 (main=0.1136, coarse=0.0000, b=0.0000) val_loss=1.0182 val_mIoU=0.51021 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 27/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.30it/s, aspp=1, b=0, coarse=0, loss=0.0945, main=0.0945]


Epoch 27: train_loss=0.1108 (main=0.1108, coarse=0.0000, b=0.0000) val_loss=1.0513 val_mIoU=0.52126 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 28/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.31it/s, aspp=1, b=0, coarse=0, loss=0.119, main=0.119]


Epoch 28: train_loss=0.1053 (main=0.1053, coarse=0.0000, b=0.0000) val_loss=0.9625 val_mIoU=0.54423 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 29/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.25it/s, aspp=1, b=0, coarse=0, loss=0.0868, main=0.0868]


Epoch 29: train_loss=0.0962 (main=0.0962, coarse=0.0000, b=0.0000) val_loss=0.9763 val_mIoU=0.52869 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 30/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.28it/s, aspp=1, b=0, coarse=0, loss=0.0764, main=0.0764]


Epoch 30: train_loss=0.0894 (main=0.0894, coarse=0.0000, b=0.0000) val_loss=0.9809 val_mIoU=0.54225 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 31/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.24it/s, aspp=1, b=0, coarse=0, loss=0.109, main=0.109]


Epoch 31: train_loss=0.0840 (main=0.0840, coarse=0.0000, b=0.0000) val_loss=0.9819 val_mIoU=0.53843 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 32/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.30it/s, aspp=1, b=0, coarse=0, loss=0.0953, main=0.0953]


Epoch 32: train_loss=0.0866 (main=0.0866, coarse=0.0000, b=0.0000) val_loss=1.0126 val_mIoU=0.53875 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 33/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.25it/s, aspp=1, b=0, coarse=0, loss=0.0889, main=0.0889]


Epoch 33: train_loss=0.0850 (main=0.0850, coarse=0.0000, b=0.0000) val_loss=1.0339 val_mIoU=0.53014 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 34/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.29it/s, aspp=1, b=0, coarse=0, loss=0.0696, main=0.0696]


Epoch 34: train_loss=0.0741 (main=0.0741, coarse=0.0000, b=0.0000) val_loss=0.9988 val_mIoU=0.54436 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 35/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.30it/s, aspp=1, b=0, coarse=0, loss=0.081, main=0.081]


Epoch 35: train_loss=0.0708 (main=0.0708, coarse=0.0000, b=0.0000) val_loss=1.0027 val_mIoU=0.54853 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 36/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.22it/s, aspp=1, b=0, coarse=0, loss=0.0803, main=0.0803]


Epoch 36: train_loss=0.0663 (main=0.0663, coarse=0.0000, b=0.0000) val_loss=1.0120 val_mIoU=0.54834 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 37/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.30it/s, aspp=1, b=0, coarse=0, loss=0.0643, main=0.0643]


Epoch 37: train_loss=0.0639 (main=0.0639, coarse=0.0000, b=0.0000) val_loss=1.0496 val_mIoU=0.54726 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 38/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.29it/s, aspp=1, b=0, coarse=0, loss=0.112, main=0.112]


Epoch 38: train_loss=0.0604 (main=0.0604, coarse=0.0000, b=0.0000) val_loss=1.0206 val_mIoU=0.55724 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 39/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.28it/s, aspp=1, b=0, coarse=0, loss=0.07, main=0.07]


Epoch 39: train_loss=0.0604 (main=0.0604, coarse=0.0000, b=0.0000) val_loss=1.0443 val_mIoU=0.53895 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 40/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.24it/s, aspp=1, b=0, coarse=0, loss=0.0531, main=0.0531]


Epoch 40: train_loss=0.0562 (main=0.0562, coarse=0.0000, b=0.0000) val_loss=1.0287 val_mIoU=0.55012 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 41/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.23it/s, aspp=1, b=0, coarse=0, loss=0.0508, main=0.0508]


Epoch 41: train_loss=0.0583 (main=0.0583, coarse=0.0000, b=0.0000) val_loss=1.0804 val_mIoU=0.52372 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 42/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.26it/s, aspp=1, b=0, coarse=0, loss=0.0941, main=0.0941]


Epoch 42: train_loss=0.0663 (main=0.0663, coarse=0.0000, b=0.0000) val_loss=1.0934 val_mIoU=0.52124 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 43/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.29it/s, aspp=1, b=0, coarse=0, loss=0.0852, main=0.0852]


Epoch 43: train_loss=0.0725 (main=0.0725, coarse=0.0000, b=0.0000) val_loss=1.1049 val_mIoU=0.54184 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 44/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.26it/s, aspp=1, b=0, coarse=0, loss=0.127, main=0.127]


Epoch 44: train_loss=0.1086 (main=0.1086, coarse=0.0000, b=0.0000) val_loss=1.1822 val_mIoU=0.51553 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 45/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.28it/s, aspp=1, b=0, coarse=0, loss=0.126, main=0.126]


Epoch 45: train_loss=0.1134 (main=0.1134, coarse=0.0000, b=0.0000) val_loss=1.2020 val_mIoU=0.50189 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 46/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.27it/s, aspp=1, b=0, coarse=0, loss=0.101, main=0.101]


Epoch 46: train_loss=0.1171 (main=0.1171, coarse=0.0000, b=0.0000) val_loss=1.0871 val_mIoU=0.52561 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 47/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.28it/s, aspp=1, b=0, coarse=0, loss=0.0632, main=0.0632]


Epoch 47: train_loss=0.0846 (main=0.0846, coarse=0.0000, b=0.0000) val_loss=1.0568 val_mIoU=0.55011 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 48/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.27it/s, aspp=1, b=0, coarse=0, loss=0.0979, main=0.0979]


Epoch 48: train_loss=0.0757 (main=0.0757, coarse=0.0000, b=0.0000) val_loss=1.0674 val_mIoU=0.54550 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 49/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.25it/s, aspp=1, b=0, coarse=0, loss=0.0593, main=0.0593]


Epoch 49: train_loss=0.0746 (main=0.0746, coarse=0.0000, b=0.0000) val_loss=1.0431 val_mIoU=0.55381 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)


Epoch 50/50 (train): 100%|██████████| 45/45 [00:08<00:00,  5.28it/s, aspp=1, b=0, coarse=0, loss=0.0638, main=0.0638]


Epoch 50: train_loss=0.0593 (main=0.0593, coarse=0.0000, b=0.0000) val_loss=1.0448 val_mIoU=0.55255 (ASPP=True, out=512, rates=(3, 6, 9), jitter=False)
Final model saved to checkpoints/model_final_20260104014425.pt
Log saved to train_log_20260104013547.jsonl


Generating predictions: 100%|██████████| 654/654 [00:10<00:00, 61.38it/s]


Predictions saved to submission.npy
Created submission.zip
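
The run above appends one JSON object per line to `train_log_<run_ts>.jsonl`, and each `"type": "epoch"` record carries `epoch` and `val_miou`. The following is a minimal sketch for reading such a log back and picking the epoch with the highest validation mIoU. It assumes only that record layout; the helper name `best_epoch_from_log` is hypothetical and is not defined anywhere in the notebook.

```python
import json


def best_epoch_from_log(log_path):
    """Return (epoch, val_miou) for the best "epoch" record, or None if there is none."""
    best = None
    with open(log_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines in the JSONL file
            rec = json.loads(line)
            if rec.get("type") != "epoch":
                continue  # skip run_header / best / train_nan records
            if best is None or rec["val_miou"] > best[1]:
                best = (rec["epoch"], rec["val_miou"])
    return best
```

For the printed output above, the highest val_mIoU is 0.55724 at epoch 38, while the final epoch reaches 0.55255. No `[BEST] saved` lines appear in the output, and the submission is generated from the final-epoch checkpoint, so a helper like this may be useful when deciding whether to rerun with best-checkpoint saving enabled.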

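The validation loop accumulates a confusion matrix `cm` and derives per-class IoU through `compute_iou_prf_from_cm`, whose body is outside this section. As a rough sketch of just the IoU = TP / (TP + FP + FN) step, the code below assumes rows index the ground truth and columns index the prediction, which matches `gt_hist = cm.sum(dim=1)` and `pred_hist = cm.sum(dim=0)` in the epoch log record. The name `iou_from_confusion_matrix` is hypothetical, and excluding classes that never appear from the mean is an assumption here, not necessarily what the notebook's own function does.

```python
import numpy as np


def iou_from_confusion_matrix(cm):
    """Per-class IoU and mean IoU from a square confusion matrix (rows=GT, cols=pred)."""
    cm = np.asarray(cm, dtype=np.float64)
    tp = np.diag(cm)               # correctly predicted pixels per class
    fp = cm.sum(axis=0) - tp       # predicted as class c, but GT is another class
    fn = cm.sum(axis=1) - tp       # GT is class c, but predicted as another class
    denom = tp + fp + fn

    iou = np.full(cm.shape[0], np.nan)
    valid = denom > 0              # classes absent from both GT and prediction stay NaN
    iou[valid] = tp[valid] / denom[valid]

    # Assumption: absent classes are excluded from the mean rather than counted as 0.
    miou = float(np.nanmean(iou)) if valid.any() else float("nan")
    return iou, miou
```

Pixels labelled with the ignore index (255) are expected to be filtered out before the matrix is updated, as `update_confusion_matrix` is called with `ignore_index` in the loop above.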

In [37]:
# -*- coding: utf-8 -*-
# NYUv2 val visualization (RGB | GT | Pred | RGB+Error) with FIXED palette -> ZIP
# - NO matplotlib colormap
# - GT and Pred share EXACT same palette mapping (ID -> RGB)
# - ignore_index(255) is shown as black by default

import os
import zipfile
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import VisionDataset
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from torchvision import models

# =========================
# EDIT HERE
# =========================
DATASET_ROOT = "/content/data"
IMAGE_SIZE = (512, 512)

MODEL_PATH = "/content/checkpoints/model_20260103062841.pt"  # <- edit only this line
OUTDIR = "val_viz_fixed"
ZIP_PATH = "val_viz_fixed.zip"

MAX_SAVE = 50  # set to None to save all samples
SEED = 42
TRAIN_VAL_SPLIT = 0.9

NUM_WORKERS = 0
BATCH_SIZE = 1

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IGNORE_INDEX = 255

print("DEVICE:", DEVICE)
print("MODEL_PATH:", MODEL_PATH)

# =========================
# FIXED PALETTE (ID -> RGB)
# This table defines what each color means. Both GT and Pred must always go through it.
# =========================
PALETTE_13 = np.array([
    [  0,   0,   0],   # 0 wall
    [128,  64,  32],   # 1 floor
    [  0, 128, 255],   # 2 cabinet
    [255,   0, 128],   # 3 bed
    [  0, 255,   0],   # 4 chair
    [255, 128,   0],   # 5 sofa
    [255, 255,   0],   # 6 table
    [128, 128, 128],   # 7 door
    [  0, 255, 255],   # 8 window
    [  0,  64, 128],   # 9 bookshelf
    [255,   0,   0],   # 10 picture
    [128,   0, 255],   # 11 counter
    [ 64, 255, 128],   # 12 desk
], dtype=np.uint8)

# =========================
# Dataset helpers
# =========================
def pil_label_to_long_tensor(lbl_pil: Image.Image) -> torch.Tensor:
    arr = np.array(lbl_pil, dtype=np.int64)
    return torch.from_numpy(arr).long()

def depth_pil_to_tensor_01(depth_pil: Image.Image, size_hw):
    depth_pil = depth_pil.resize((size_hw[1], size_hw[0]), resample=Image.BILINEAR)
    arr = np.array(depth_pil)
    if arr.dtype == np.uint16 or (arr.max() > 255):
        arr = arr.astype(np.float32) / 65535.0
    else:
        arr = arr.astype(np.float32) / 255.0
    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    arr = np.clip(arr, 0.0, 1.0)
    return torch.from_numpy(arr).unsqueeze(0)  # [1,H,W]

class NYUv2(VisionDataset):
    def __init__(self, root, split="train", include_depth=True,
                 image_transform=None, depth_transform=None, target_transform=None):
        super().__init__(root)
        assert split in ("train", "test")
        self.root = root
        self.split = split
        self.include_depth = include_depth

        images_dir = os.path.join(self.root, self.split, "image")
        img_names = sorted(os.listdir(images_dir))

        self.images = [os.path.join(images_dir, n) for n in img_names]
        self.depths = [os.path.join(self.root, self.split, "depth", n) for n in img_names]

        if self.split == "train":
            self.targets = [os.path.join(self.root, self.split, "label", n) for n in img_names]
        else:
            self.targets = None

        self.image_transform = image_transform
        self.depth_transform = depth_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("RGB")
        depth = Image.open(self.depths[idx])

        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.depth_transform is not None:
            depth = self.depth_transform(depth)

        if self.split == "test":
            return (image, depth) if self.include_depth else image

        target = Image.open(self.targets[idx])
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, depth, target) if self.include_depth else (image, target)

# =========================
# Model blocks (ResNet50-UNet + optional SE)
# =========================
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=p, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.conv1 = ConvBNReLU(in_ch + skip_ch, out_ch)
        self.conv2 = ConvBNReLU(out_ch, out_ch)

    def forward(self, x, skip):
        x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class DepthStem(nn.Module):
    def __init__(self, z_ch=64):
        super().__init__()
        self.net = nn.Sequential(
            ConvBNReLU(1, 32, 3, 1),
            nn.MaxPool2d(2),
            ConvBNReLU(32, 48, 3, 1),
            nn.MaxPool2d(2),
            ConvBNReLU(48, z_ch, 3, 1),
            nn.MaxPool2d(2),
        )
    def forward(self, d):
        return self.net(d)

class SEFromDepth(nn.Module):
    def __init__(self, feat_ch, z_ch, reduction=16, alpha=0.05):
        super().__init__()
        self.alpha = float(alpha)
        hidden = max(32, feat_ch // reduction)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(z_ch, hidden, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, feat_ch, 1, bias=True),
            nn.Sigmoid(),
        )
    def forward(self, F, z):
        g = self.fc(z)  # [B,C,1,1]
        scale = 1.0 + self.alpha * (g - 0.5)
        return F * scale

class ResNet50UNet(nn.Module):
    def __init__(self, num_classes=13, coarse_classes=5, in_channels=4,
                 pretrained=False,
                 use_se=False, se_z_ch=64, se_reduction=16, se_alpha=0.05,
                 se_dec_stages=("up1",)):
        super().__init__()
        self.use_se = use_se
        self.se_dec_stages = tuple(se_dec_stages)

        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)

        old_conv = resnet.conv1
        if in_channels != old_conv.in_channels:
            new_conv = nn.Conv2d(
                in_channels, old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False
            )
            with torch.no_grad():
                new_conv.weight[:, :3] = old_conv.weight
                mean_w = old_conv.weight.mean(dim=1, keepdim=True)
                for c in range(3, in_channels):
                    new_conv.weight[:, c:c+1] = mean_w
            resnet.conv1 = new_conv

        self.enc0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)  # 1/2
        self.pool = resnet.maxpool                                        # 1/4
        self.enc1 = resnet.layer1                                         # 1/4
        self.enc2 = resnet.layer2                                         # 1/8
        self.enc3 = resnet.layer3                                         # 1/16
        self.enc4 = resnet.layer4                                         # 1/32

        self.center = ConvBNReLU(2048, 1024)
        self.up4 = UpBlock(1024, 1024, 512)
        self.up3 = UpBlock(512,  512,  256)
        self.up2 = UpBlock(256,  256,  128)
        self.up1 = UpBlock(128,  64,   64)

        self.head_feat = ConvBNReLU(64, 64)
        self.head_out  = nn.Conv2d(64, num_classes, kernel_size=1)

        self.coarse_head = nn.Conv2d(256, coarse_classes, kernel_size=1)
        self.boundary_head = nn.Sequential(ConvBNReLU(64, 32), nn.Conv2d(32, 1, 1))

        if self.use_se:
            self.depth_stem = DepthStem(z_ch=se_z_ch)
            stage_to_ch = {"center":1024, "up4":512, "up3":256, "up2":128, "up1":64}
            self.se_dec = nn.ModuleDict()
            for st in self.se_dec_stages:
                self.se_dec[st] = SEFromDepth(stage_to_ch[st], se_z_ch, se_reduction, se_alpha)
        else:
            self.depth_stem = None
            self.se_dec = None

    def _apply_dec_se(self, feat, z, stage_name: str):
        if (not self.use_se) or (self.se_dec is None) or (stage_name not in self.se_dec):
            return feat
        z_ = nn.functional.interpolate(z, size=feat.shape[-2:], mode="bilinear", align_corners=False)
        return self.se_dec[stage_name](feat, z_)

    def forward(self, x, depth=None):
        c1 = self.enc0(x)
        t  = self.pool(c1)
        c2 = self.enc1(t)
        c3 = self.enc2(c2)
        c4 = self.enc3(c3)
        c5 = self.enc4(c4)

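        # The SE branch needs the raw depth map; fail early with a clear message if it is missing.
        if self.use_se and depth is None:
            raise ValueError("depth must be provided when use_se=True")
        z = self.depth_stem(depth) if self.use_se else None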

        x = self.center(c5)
        if z is not None: x = self._apply_dec_se(x, z, "center")
        x = self.up4(x, c4)
        if z is not None: x = self._apply_dec_se(x, z, "up4")
        x = self.up3(x, c3)
        if z is not None: x = self._apply_dec_se(x, z, "up3")

        coarse_logits = self.coarse_head(x)

        x = self.up2(x, c2)
        if z is not None: x = self._apply_dec_se(x, z, "up2")
        x = self.up1(x, c1)
        if z is not None: x = self._apply_dec_se(x, z, "up1")

        x = nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        feat = self.head_feat(x)
        fine_logits = self.head_out(feat)
        boundary_logit = self.boundary_head(x)
        return fine_logits, coarse_logits, boundary_logit
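
"""
The `conv1` replacement in `__init__` keeps the pretrained RGB filters and fills every extra input channel with the per-filter mean over the RGB channels. Below is a NumPy sketch of the same copy on a plain array. `inflate_conv_weight` is a hypothetical helper, and the shapes assume the `[out, in, kH, kW]` layout of `nn.Conv2d.weight`.

```python
import numpy as np

def inflate_conv_weight(w_rgb: np.ndarray, in_channels: int) -> np.ndarray:
    # w_rgb: [out_ch, 3, kH, kW] weights of the original 3-channel conv.
    out_ch, _, kh, kw = w_rgb.shape
    w_new = np.empty((out_ch, in_channels, kh, kw), dtype=w_rgb.dtype)
    w_new[:, :3] = w_rgb                              # keep the RGB filters as-is
    w_new[:, 3:] = w_rgb.mean(axis=1, keepdims=True)  # broadcast the RGB mean
    return w_new

rng = np.random.default_rng(0)
w = rng.standard_normal((64, 3, 7, 7)).astype(np.float32)
w4 = inflate_conv_weight(w, in_channels=4)
print(w4.shape)  # (64, 4, 7, 7)
```

Starting the depth channel from the RGB mean gives it filters at the same scale as the pretrained ones, so the first layer's activations are not dominated by a randomly initialized channel.
"""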

# =========================
# FIXED colorization (ID -> RGB)
# =========================
def id_to_rgb(mask_hw: np.ndarray, ignore_index=255, ignore_rgb=(0,0,0)) -> np.ndarray:
    """
    mask_hw: [H,W] int (0..12 or 255)
    returns: [H,W,3] uint8
    """
    out = np.zeros((mask_hw.shape[0], mask_hw.shape[1], 3), dtype=np.uint8)
    ign = (mask_hw == ignore_index)
    valid = ~ign
    m = mask_hw.copy()
    m[ign] = 0
    m = np.clip(m, 0, 12)
    out[valid] = PALETTE_13[m[valid]]
    out[ign] = np.array(ignore_rgb, dtype=np.uint8)
    return out

def denorm_rgb(rgb_tensor: torch.Tensor) -> Image.Image:
    mean = torch.tensor([0.485, 0.456, 0.406], device=rgb_tensor.device).view(3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=rgb_tensor.device).view(3,1,1)
    x = (rgb_tensor * std + mean).clamp(0, 1)
    x = (x * 255).byte().permute(1,2,0).cpu().numpy()
    return Image.fromarray(x, mode="RGB")

def overlay_error(rgb_pil: Image.Image, gt_hw: np.ndarray, pred_hw: np.ndarray, ignore_index=255) -> Image.Image:
    rgb = np.array(rgb_pil).astype(np.float32)
    err = (gt_hw != ignore_index) & (gt_hw != pred_hw)
    if err.any():
        overlay = rgb.copy()
        overlay[err] = 0.6 * overlay[err] + 0.4 * np.array([255, 0, 0], dtype=np.float32)
        overlay = np.clip(overlay, 0, 255).astype(np.uint8)
        return Image.fromarray(overlay, mode="RGB")
    return rgb_pil

def stack_h(images):
    widths = [im.width for im in images]
    heights = [im.height for im in images]
    out = Image.new("RGB", (sum(widths), max(heights)))
    x = 0
    for im in images:
        out.paste(im, (x, 0))
        x += im.width
    return out

# =========================
# Transforms (same as eval)
# =========================
eval_image_transform = Compose([
    Resize(IMAGE_SIZE, interpolation=InterpolationMode.BILINEAR),
    ToTensor(),
    Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)),
])

def depth_transform(pil):
    return depth_pil_to_tensor_01(pil, IMAGE_SIZE)

def target_transform(pil):
    pil = pil.resize((IMAGE_SIZE[1], IMAGE_SIZE[0]), resample=Image.NEAREST)
    return pil_label_to_long_tensor(pil)

# =========================
# Build val split (same indices rule)
# =========================
split_base = NYUv2(
    root=DATASET_ROOT,
    split="train",
    include_depth=True,
    image_transform=eval_image_transform,
    depth_transform=depth_transform,
    target_transform=target_transform,
)

n_total = len(split_base)
n_train = int(n_total * TRAIN_VAL_SPLIT)
n_val = n_total - n_train
g = torch.Generator().manual_seed(SEED)
_, val_subset = random_split(split_base, [n_train, n_val], generator=g)
val_ds = Subset(split_base, val_subset.indices)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"Val samples: {len(val_ds)}")

# =========================
# Load checkpoint + auto-detect SE
# =========================
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
if isinstance(ckpt, dict) and "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
    sd = ckpt["state_dict"]
else:
    sd = ckpt

has_se = any(k.startswith("depth_stem.") or k.startswith("se_dec.") for k in sd.keys())
print("Detected SE in checkpoint:", has_se)

model = ResNet50UNet(
    num_classes=13,
    coarse_classes=5,
    in_channels=4,
    pretrained=False,
    use_se=has_se,
    se_z_ch=64,
    se_reduction=16,
    se_alpha=0.05,
    se_dec_stages=("up1",),
).to(DEVICE)

model.load_state_dict(sd, strict=True)
model.eval()

# =========================
# Generate panels with FIXED palette
# =========================
os.makedirs(OUTDIR, exist_ok=True)

saved = 0
with torch.no_grad():
    for i, (rgb, depth, label) in enumerate(tqdm(val_loader, desc="Saving fixed-color panels")):
        rgb = rgb.to(DEVICE, non_blocking=True)
        depth = depth.to(DEVICE, non_blocking=True)
        label = label.to(DEVICE, non_blocking=True)

        x = torch.cat([rgb.float(), depth.float()], dim=1)
        fine_logits, _, _ = model(x, depth=depth.float() if has_se else None)
        pred = fine_logits.argmax(dim=1)

        # Only the first sample of each batch is visualized (index 0 below).
        rgb_pil = denorm_rgb(rgb[0])

        gt_hw = label[0].cpu().numpy().astype(np.int32)
        pred_hw = pred[0].cpu().numpy().astype(np.int32)

        gt_rgb = id_to_rgb(gt_hw, ignore_index=IGNORE_INDEX, ignore_rgb=(0,0,0))
        pr_rgb = id_to_rgb(pred_hw, ignore_index=IGNORE_INDEX, ignore_rgb=(0,0,0))

        gt_pil = Image.fromarray(gt_rgb, mode="RGB")
        pr_pil = Image.fromarray(pr_rgb, mode="RGB")
        err_pil = overlay_error(rgb_pil, gt_hw, pred_hw, ignore_index=IGNORE_INDEX)

        panel = stack_h([rgb_pil, gt_pil, pr_pil, err_pil])
        panel.save(os.path.join(OUTDIR, f"{i:06d}.png"), optimize=True)

        saved += 1
        if MAX_SAVE is not None and saved >= MAX_SAVE:
            break

print(f"Saved: {saved} images -> {OUTDIR}")
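
"""
The panels give a qualitative view of the errors, while the competition metric is mean IoU. Below is a minimal NumPy sketch of that metric over `[H, W]` ID maps such as `gt_hw` and `pred_hw`, assuming 13 classes and ignore index 255 as stated in the task description. `confusion_matrix` and `mean_iou` are hypothetical helpers. Skipping classes that appear in neither the ground truth nor the prediction is one common convention and may differ from the official scorer.

```python
import numpy as np

def confusion_matrix(gt, pred, num_classes=13, ignore_index=255):
    # Rows are ground-truth classes, columns are predicted classes.
    gt = np.asarray(gt).ravel()
    pred = np.asarray(pred).ravel()
    keep = gt != ignore_index  # drop ignored pixels
    idx = gt[keep].astype(np.int64) * num_classes + pred[keep].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def mean_iou(cm):
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp   # predicted as class c but labeled otherwise
    fn = cm.sum(axis=1) - tp   # labeled class c but predicted otherwise
    denom = tp + fp + fn
    present = denom > 0        # skip classes that never occur
    return float((tp[present] / denom[present]).mean())

gt = np.array([[0, 0, 1], [1, 255, 2]])
pred = np.array([[0, 1, 1], [1, 0, 2]])
print(mean_iou(confusion_matrix(gt, pred)))
```

To score a whole validation split, sum the per-image confusion matrices first and call `mean_iou` once on the total, rather than averaging per-image scores.
"""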

# =========================
# Zip
# =========================
with zipfile.ZipFile(ZIP_PATH, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zf:
    for fn in sorted(os.listdir(OUTDIR)):
        if fn.lower().endswith(".png"):
            zf.write(os.path.join(OUTDIR, fn), arcname=fn)

print(f"ZIP created: {ZIP_PATH}")


DEVICE: cuda
MODEL_PATH: /content/checkpoints/model_20260103062841.pt
Val samples: 80
Detected SE in checkpoint: False


  return Image.fromarray(x, mode="RGB")
  gt_pil = Image.fromarray(gt_rgb, mode="RGB")
  pr_pil = Image.fromarray(pr_rgb, mode="RGB")
  return Image.fromarray(overlay, mode="RGB")
Saving fixed-color panels:  61%|██████▏   | 49/80 [00:27<00:17,  1.78it/s]


Saved: 50 images -> val_viz_fixed
ZIP created: val_viz_fixed.zip


In [33]:
%ls -l checkpoints

total 4093968
-rw-r--r-- 1 root root 233561111 Jan  3 01:19 model_20260103011918.pt
-rw-r--r-- 1 root root 233561111 Jan  3 01:32 model_20260103013257.pt
-rw-r--r-- 1 root root 232859155 Jan  3 01:53 model_20260103015327.pt
-rw-r--r-- 1 root root 232859155 Jan  3 02:13 model_20260103021307.pt
-rw-r--r-- 1 root root 232859155 Jan  3 02:40 model_20260103024015.pt
-rw-r--r-- 1 root root 232800147 Jan  3 03:13 model_20260103031301.pt
-rw-r--r-- 1 root root 232800147 Jan  3 03:31 model_20260103033149.pt
-rw-r--r-- 1 root root 232800147 Jan  3 03:48 model_20260103034848.pt
-rw-r--r-- 1 root root 232800147 Jan  3 04:14 model_20260103041457.pt
-rw-r--r-- 1 root root 232814185 Jan  3 04:38 model_20260103043835.pt
-rw-r--r-- 1 root root 232814185 Jan  3 04:51 model_20260103045108.pt
-rw-r--r-- 1 root root 232882241 Jan  3 05:16 model_20260103051603.pt
-rw-r--r-- 1 root root 232883265 Jan  3 05:41 model_20260103054113.pt
-rw-r--r-- 1 root root 232688655 Jan  3 06:28 model_20260103062841.pt
-rw-r-

In [17]:
# ------------------
#    Evaluation (DeepLabV3-ResNet50 / ResNet backbone)
# ------------------

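# NOTE: `final_path` must already be defined by an earlier (training) cell;
# running this cell without it fails with the NameError shown in the output below.
ckpt = torch.load(final_path, map_location=device)

# Also handle checkpoints that were saved during training as a dict with a "model_state" entry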
if isinstance(ckpt, dict) and "model_state" in ckpt:
    model.load_state_dict(ckpt["model_state"])
else:
    model.load_state_dict(ckpt)

model.eval()

predictions = []

with torch.no_grad():
    print("Generating predictions...")
    for image, depth in tqdm(test_data):
        image = image.to(device, non_blocking=True)
        depth = depth.to(device, non_blocking=True)

        # Just in case: if depth accidentally arrives with 3 channels, keep only the first one
        if depth.dim() == 4 and depth.size(1) == 3:
            depth = depth[:, :1, :, :]

        x = torch.cat((image, depth), dim=1)   # [B,4,H,W]

        # DeepLabV3 returns a dict, so take its "out" entry
        out = model(x)["out"]                  # [B,num_classes,H,W]
        pred = out.argmax(dim=1)               # [B,H,W]

        predictions.append(pred.cpu())

predictions = torch.cat(predictions, dim=0).numpy()
np.save("submission.npy", predictions)
print("Predictions saved to submission.npy")
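
"""
Before zipping, it can help to check the array against the submission format described at the top of the notebook (shape `[N, H, W]`, integer class IDs 0-12). Below is a minimal sketch. `check_submission` is a hypothetical helper, and the expected test-set size of 654 is taken from the dataset description.

```python
import numpy as np

def check_submission(arr, n_expected=654, num_classes=13):
    # Return a list of human-readable problems; an empty list means the array looks fine.
    problems = []
    if arr.ndim != 3:
        problems.append(f"expected 3 dims [N, H, W], got shape {arr.shape}")
    elif arr.shape[0] != n_expected:
        problems.append(f"expected {n_expected} images, got {arr.shape[0]}")
    if not np.issubdtype(arr.dtype, np.integer):
        problems.append(f"expected an integer dtype, got {arr.dtype}")
    elif arr.size and (arr.min() < 0 or arr.max() >= num_classes):
        problems.append(f"class IDs outside 0..{num_classes - 1}: min={arr.min()}, max={arr.max()}")
    return problems

demo = np.zeros((654, 4, 4), dtype=np.int64)  # tiny stand-in for real predictions
print(check_submission(demo))  # []
```

On the real file this would be called as `check_submission(np.load("submission.npy"))`.
"""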

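
"""## How to submit

Zip the following three items and submit them on Omnicampus under "Final Assignment (Segmentation)" (最終課題 (セグメンテーション)).

- `submission.npy`
- The weights used for testing, such as `model.pt` or `model_best.pt` (only the `.pt` extension is accepted)
- This Colab notebook
"""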

from zipfile import ZipFile, ZIP_DEFLATED

notebook_path = "/content/drive/MyDrive/code.ipynb"

with ZipFile("submission.zip",
             mode="w",
             compression=ZIP_DEFLATED,
             compresslevel=9) as zf:
    zf.write("submission.npy")
    zf.write(model_path)
#    zf.write(notebook_path,
#             arcname="code.ipynb")


NameError: name 'final_path' is not defined