<a href="https://colab.research.google.com/github/trie0000/external/blob/main/code_20260103_2024.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# -*- coding: utf-8 -*-
"""DL_Basic_2025_Competition_NYUv2_baseline.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/17t7uAU0aST5aUt6sIJCqFneGlQrxFrJY

# Deep Learning 基礎講座　最終課題: NYUv2 セマンティックセグメンテーションcoarse_w=0」

## 概要
RGB画像から、画像内の各ピクセルがどのクラスに属するかを予測するセマンティックセグメンテーションタスク.

### データセット
- データセット: NYUv2 dataset
- 訓練データ: 795枚
- テストデータ: 654枚
- 入力: RGB画像 + 深度マップ（元画像サイズは可変）
- 出力: 13クラスのセグメンテーションマップ
- 評価指標: Mean IoU (Intersection over Union)

### データセットの詳細（[NYU Depth Dataset V2](https://cs.nyu.edu/~fergus/datasets/nyu_depth_v2.html)）
- 画像は屋内シーンを撮影したもので、家具や壁、床などの物体が含まれています.
- 各画像に対して13クラスのセグメンテーションラベルが提供されます.
- データは以下のディレクトリ構造で提供:
```
data/NYUv2/
├─train/
│  ├─image/      # RGB画像
│  │    000000.png
│  │    ...
│  │
│  ├─depth/      # 深度マップ
│  │    000000.png
│  │    ...
│  │
│  └─label/      # 13クラスセグメンテーション（教師ラベル）
│       000000.png
│       ...
└─test/
   ├─image/      # RGB画像
   │    000000.png
   │    ...
   │  ├─depth/   # 深度マップ
   │    000000.png
   │    ...
```

### タスクの詳細
- 入力のRGB画像と深度マップから、各ピクセルが13クラスのどれに属するかを予測するタスクです.
- 評価はMean IoUを使用します．
  - 各クラスごとにIoUを計算し、その平均を取ります.
  - IoUは以下の式で計算:
  $$IoU = \frac{TP}{TP + FP + FN}$$
    - TP: True Positive（正しく予測されたピクセル数）
    - FP: False Positive（誤って予測されたピクセル数）
    - FN: False Negative（見逃したピクセル数）

### 前処理
- 入力画像は512×512にリサイズされます.
- ピクセル値は0-1に正規化されます.
- セグメンテーションラベルは0-12の整数値（13クラス）です．
  - 255はignore index（評価から除外）

### 提出形式
- テスト画像（RGB + Depth）の各ピクセルに対してクラス（0~12）を予測したものをnumpy配列として保存されます.
- ファイル名: `submission.npy`
- 配列の形状: [テストデータ数, 高さ, 幅]
- 各ピクセルの値: 0-12の整数（予測クラス）

## 考えられる工夫の例
- 事前学習モデルの fine-tuning
    - ImageNetなどで事前学習されたモデルを本データセットでfine-tuningすることで性能向上が見込めます.
- 損失関数の再設計
    - クラスごとの出現頻度に応じて損失を補正するように損失関数を設計すると、クラス分布の不均衡に対してロバストな学習ができます.
- 画像の前処理
    - RandomResizedCrop / Flip / ColorJitter 等のデータ拡張を追加することで，汎化性能の向上が見込めます．

## 修了要件を満たす条件
- ベースラインでは，omnicampus 上での性能評価において， 38.2% となります．したがって，ベースラインである 38.2% を超えた提出のみ，修了要件として認めます．
- ベースラインから改善を加えることで， 50%以上に性能向上することを運営で確認しています．こちらを 1つの指標として取り組んでみてください．

## 注意点
- 最終的な予測モデルは，**配布している訓練データを用いて学習**（ファインチューニング含む）したものとしてください．
- 学習を行わず，**事前学習済みモデルの知識のみを利用した推論は禁止**します．
（例: ChatGPT 等の LLM に入力して推論を得るのみ）

### 事前学習モデルの利用
許可される事項
- **構成要素としての事前学習モデルの利用**: 自身で実装したアーキテクチャの一部（特徴抽出，埋め込みなど）として事前学習モデル（BERT，ViT など）を利用することは可能です．
- **ファインチューニング**: 上記の用途で利用している事前学習モデルのファインチューニングは可能です．

禁止される事項
- **タスク解決用の事前学習モデルの利用**: transformers などで提供されている，対象タスクを直接解くための事前学習モデルでそのまま推論のみ，またはファインチューニングのみで利用することは禁止とします．
  - 禁止事項の例: VQA タスクを直接解くための事前学習モデルを VQA タスクで利用する．

### データの準備
データをダウンロードした際に，google drive したため，利用するために google drive をマウントする必要があります．また， drive 上で展開することができないため，/content ディレクトリ下にコピーし "data.zip" を展開します．
google drive 上に "data.zip" が配置されていない場合は実行できません．google drive 上に "data.zip" (**831MB**) を配置することが可能であれば，"data_download.ipynb" を先に実行してください．難しい場合は，omnicampus 演習環境を利用してください．．
"""

# omnicampus 上では 4 セル目まで実行不要
# ドライブのマウント
from google.colab import drive
drive.mount('/content/drive')

# データダウンロード用の notebook にてgoogle drive への保存後，
# 反映に時間がかかる可能性がありますので，google drive のマウント後，
# data.zip がディレクトリ内にあることを確認してから実行してください．
# data.zip を /content 下にコピーする
!cp "/content/drive/MyDrive/data.zip" "/content"

# Commented out IPython magic to ensure Python compatibility.
# カレントディレクトリ下のファイル群を確認
# data.zip が表示されれば問題ないです
# %ls

# データを解凍する
!unzip data.zip
!mkdir data
!mv train test data/

"""omnicampus 演習環境では，data_download.ipynb のマウント，zip 化，drive へのコピーを実行しないことで，"data.zip" を解凍した形で配置されます．したがって，data ディレクトリが存在するディレクトリをカレントディレクトリとするだけで良いです．


"""

# Commented out IPython magic to ensure Python compatibility.
# omnicampus 実行用
# 以下の例では/workspace/Segmentation/split_data_scripts/omnicampus に data ディレクトリがあると想定
# %cd /workspace/Segmentation/split_data_scripts_omnicampus

# omnicampus 実行用
!pip install h5py scikit-image

"""# import library"""

Mounted at /content/drive
Archive:  data.zip
  inflating: data/train/image/000600.png  
  inflating: data/train/image/000320.png  
  inflating: data/train/image/000491.png  
  inflating: data/train/image/000502.png  
  inflating: data/train/image/000129.png  
  inflating: data/train/image/000044.png  
  inflating: data/train/image/000652.png  
  inflating: data/train/image/000919.png  
  inflating: data/train/image/000528.png  
  inflating: data/train/image/000853.png  
  inflating: data/train/image/000177.png  
  inflating: data/train/image/000584.png  
  inflating: data/train/image/001319.png  
  inflating: data/train/image/000597.png  
  inflating: data/train/image/000223.png  
  inflating: data/train/image/001350.png  
  inflating: data/train/image/000404.png  
  inflating: data/train/image/000488.png  
  inflating: data/train/image/000268.png  
  inflating: data/train/image/000481.png  
  inflating: data/train/image/000341.png  
  inflating: data/train/image/000159.png  
  inflati

'# import library'

In [42]:
# -*- coding: utf-8 -*-
"""
NYUv2 baseline (RGBD) with optional heads/losses.
要求：coarse_w == 0 のときは、coarse の forward / loss 計算を“完全に”スキップする。

ポイント：
- epochごとに coarse_w を計算
- coarse_w==0 かつ region_loss_weight==0 のときは coarse を一切使わない
  -> model.forward でも coarse_head を呼ばない（GPU計算も発生しない）
"""

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"

import time
import random
import json
from dataclasses import dataclass
from zipfile import ZipFile, ZIP_DEFLATED

import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import VisionDataset
from torchvision.transforms import (
    Compose, Resize, ToTensor, Normalize, InterpolationMode,
    ColorJitter
)
from torchvision import models
from torch.amp import autocast, GradScaler


# =========================
# Class ID definition (YOUR dataset)
# =========================
ID_BED     = 0
ID_BOOK    = 1
ID_CEILING = 2
ID_CHAIR   = 3
ID_FLOOR   = 4
ID_CABINET = 5
ID_OBJECT  = 6
ID_PICTURE = 7
ID_SOFA    = 8
ID_DESK    = 9
ID_TV      = 10
ID_WALL    = 11
ID_WINDOW  = 12

CLASS_NAMES_13 = [
    "bed", "book", "ceiling", "chair", "floor", "cabinet", "object",
    "picture", "sofa", "desk", "tv", "wall", "window"
]


# =========================
# Seed
# =========================
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# =========================
# Dataset utilities
# =========================
def pil_label_to_long_tensor(lbl_pil: Image.Image) -> torch.Tensor:
    arr = np.array(lbl_pil, dtype=np.int64)
    return torch.from_numpy(arr).long()

def depth_pil_to_tensor_01(depth_pil: Image.Image, size):
    # size: (H,W)
    depth_pil = depth_pil.resize((size[1], size[0]), resample=Image.BILINEAR)
    arr = np.array(depth_pil)

    # 16-bit/8-bit を安全に 0..1 に
    if arr.dtype == np.uint16 or (arr.max() > 255):
        arr = arr.astype(np.float32) / 65535.0
    else:
        arr = arr.astype(np.float32) / 255.0

    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    arr = np.clip(arr, 0.0, 1.0)
    return torch.from_numpy(arr).unsqueeze(0)  # [1,H,W]


# =========================
# Transform: RandomGrayscale (3ch safe)
# =========================
class RandomGrayscale3ch:
    """
    torchvisionのRandomGrayscaleが古くても確実に動く版。
    確率pでRGB画像をグレースケール化し、必ずRGB(3ch)で返す。
    """
    def __init__(self, p=0.3):
        self.p = float(p)

    def __call__(self, img: Image.Image) -> Image.Image:
        if random.random() >= self.p:
            return img
        return img.convert("L").convert("RGB")


class NYUv2(VisionDataset):
    def __init__(
        self,
        root: str,
        split: str = "train",
        include_depth: bool = True,
        image_transform=None,
        depth_transform=None,
        target_transform=None,
    ):
        super().__init__(root)
        assert split in ("train", "test")
        self.root = root
        self.split = split
        self.include_depth = include_depth

        images_dir = os.path.join(self.root, self.split, "image")
        img_names = sorted(os.listdir(images_dir))

        self.images = [os.path.join(images_dir, n) for n in img_names]
        self.depths = [os.path.join(self.root, self.split, "depth", n) for n in img_names]

        if self.split == "train":
            self.targets = [os.path.join(self.root, self.split, "label", n) for n in img_names]
        else:
            self.targets = None

        self.image_transform = image_transform
        self.depth_transform = depth_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("RGB")
        depth = Image.open(self.depths[idx])

        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.depth_transform is not None:
            depth = self.depth_transform(depth)

        if self.split == "test":
            return (image, depth) if self.include_depth else image

        target = Image.open(self.targets[idx])
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, depth, target) if self.include_depth else (image, target)


# =========================
# Boundary mask + dilation
# =========================
def boundary_mask_from_label(label, ignore_index=255):
    """
    label: [B,H,W] long
    returns: bool [B,H,W] boundary pixels (4-neighborhood changes), ignore excluded
    """
    valid = (label != ignore_index)
    b = torch.zeros_like(label, dtype=torch.bool)
    b[:, 1:, :] |= (label[:, 1:, :] != label[:, :-1, :])
    b[:, :-1, :] |= (label[:, :-1, :] != label[:, 1:, :])
    b[:, :, 1:] |= (label[:, :, 1:] != label[:, :, :-1])
    b[:, :, :-1] |= (label[:, :, :-1] != label[:, :, 1:])
    b &= valid
    return b

def dilate_mask(mask_bool, radius: int):
    """
    bool [B,H,W] -> bool [B,H,W]
    max-poolで膨張。radius=2なら5x5相当。
    """
    if radius <= 0:
        return mask_bool
    x = mask_bool.float().unsqueeze(1)  # [B,1,H,W]
    k = 2 * radius + 1
    y = nn.functional.max_pool2d(x, kernel_size=k, stride=1, padding=radius)
    return (y.squeeze(1) > 0.5)


# =========================
# Model blocks
# =========================
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=p, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.conv1 = ConvBNReLU(in_ch + skip_ch, out_ch)
        self.conv2 = ConvBNReLU(out_ch, out_ch)

    def forward(self, x, skip):
        x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


# =========================
# SE (Depth -> Channel gate)  ※decoder側だけに適用
# =========================
class DepthStem(nn.Module):
    """
    Depth(1ch 0..1) -> feature map (default 1/8) for conditioning.
    """
    def __init__(self, z_ch=64):
        super().__init__()
        self.net = nn.Sequential(
            ConvBNReLU(1, 32, 3, 1),
            nn.MaxPool2d(2),       # 1/2
            ConvBNReLU(32, 48, 3, 1),
            nn.MaxPool2d(2),       # 1/4
            ConvBNReLU(48, z_ch, 3, 1),
            nn.MaxPool2d(2),       # 1/8
        )
    def forward(self, d):
        return self.net(d)

class SEFromDepth(nn.Module):
    """
    残差SE版：
      F' = F * (1 + alpha * (g - 0.5))
    """
    def __init__(self, feat_ch, z_ch, reduction, alpha):
        super().__init__()
        assert alpha is not None, "alpha must be provided from config"
        self.alpha = float(alpha)
        hidden = max(32, feat_ch // reduction)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(z_ch, hidden, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, feat_ch, 1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, F, z):
        g = self.fc(z)  # [B,C,1,1] in [0,1]
        scale = 1.0 + self.alpha * (g - 0.5)
        return F * scale


# =========================
# ResNet50 UNet
# - fine(13) + (optional) coarse(5) + boundary(1ch)
# =========================
class ResNet50UNet(nn.Module):
    """
    Outputs:
      - fine_logits:   [B,13,H,W]
      - coarse_logits: [B,5,H/8,W/8] or None
      - boundary_logit:[B,1,H,W]
    """
    def __init__(
        self,
        num_classes=13,
        coarse_classes=5,
        in_channels=4,
        pretrained=True,
        use_se=False,
        se_z_ch=64,
        se_reduction=16,
        se_alpha=None,
        se_dec_stages=("up3",),
    ):
        super().__init__()
        self.use_se = use_se
        self.se_dec_stages = tuple(se_dec_stages)

        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)

        # conv1 拡張
        old_conv = resnet.conv1
        if in_channels != old_conv.in_channels:
            new_conv = nn.Conv2d(
                in_channels, old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False
            )
            with torch.no_grad():
                new_conv.weight[:, :3] = old_conv.weight
                if in_channels > 3:
                    mean_w = old_conv.weight.mean(dim=1, keepdim=True)
                    for c in range(3, in_channels):
                        new_conv.weight[:, c:c+1] = mean_w
            resnet.conv1 = new_conv

        # encoder
        self.enc0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)  # 1/2, 64
        self.pool = resnet.maxpool                                        # 1/4
        self.enc1 = resnet.layer1                                         # 1/4, 256
        self.enc2 = resnet.layer2                                         # 1/8, 512
        self.enc3 = resnet.layer3                                         # 1/16, 1024
        self.enc4 = resnet.layer4                                         # 1/32, 2048

        # decoder
        self.center = ConvBNReLU(2048, 1024)
        self.up4 = UpBlock(in_ch=1024, skip_ch=1024, out_ch=512)
        self.up3 = UpBlock(in_ch=512,  skip_ch=512,  out_ch=256)  # 1/8
        self.up2 = UpBlock(in_ch=256,  skip_ch=256,  out_ch=128)
        self.up1 = UpBlock(in_ch=128,  skip_ch=64,   out_ch=64)

        # fine head
        self.head_feat = ConvBNReLU(64, 64)
        self.head_out  = nn.Conv2d(64, num_classes, kernel_size=1)

        # coarse head @1/8（※必要なときだけ呼ぶ）
        self.coarse_head = nn.Conv2d(256, coarse_classes, kernel_size=1)

        # boundary head @H,W
        self.boundary_head = nn.Sequential(
            ConvBNReLU(64, 32),
            nn.Conv2d(32, 1, kernel_size=1)
        )

        # decoder-only SE
        if self.use_se:
            self.depth_stem = DepthStem(z_ch=se_z_ch)
            stage_to_ch = {
                "center": 1024,
                "up4": 512,
                "up3": 256,
                "up2": 128,
                "up1": 64,
            }
            self.se_dec = nn.ModuleDict()
            for st in self.se_dec_stages:
                assert st in stage_to_ch, f"Unknown se_dec_stage: {st}"
                self.se_dec[st] = SEFromDepth(
                    feat_ch=stage_to_ch[st],
                    z_ch=se_z_ch,
                    reduction=se_reduction,
                    alpha=se_alpha,
                )
        else:
            self.depth_stem = None
            self.se_dec = None

    def _apply_dec_se(self, feat, z, stage_name: str):
        if (not self.use_se) or (self.se_dec is None):
            return feat
        if stage_name not in self.se_dec:
            return feat
        z_ = nn.functional.interpolate(z, size=feat.shape[-2:], mode="bilinear", align_corners=False)
        return self.se_dec[stage_name](feat, z_)

    def forward(self, x, depth=None, need_coarse: bool = True):
        """
        need_coarse=False のとき：
          - coarse_head を呼ばない（計算しない）
          - coarse_logits は None を返す
        """
        # encoder
        c1 = self.enc0(x)          # 1/2
        t  = self.pool(c1)         # 1/4
        c2 = self.enc1(t)          # 1/4
        c3 = self.enc2(c2)         # 1/8
        c4 = self.enc3(c3)         # 1/16
        c5 = self.enc4(c4)         # 1/32

        # depth stem
        if self.use_se:
            assert depth is not None, "use_se=True の場合、forward に depth を渡してください"
            z = self.depth_stem(depth)  # 1/8
        else:
            z = None

        # decoder
        x = self.center(c5)
        if z is not None:
            x = self._apply_dec_se(x, z, "center")

        x = self.up4(x, c4)
        if z is not None:
            x = self._apply_dec_se(x, z, "up4")

        x = self.up3(x, c3)
        if z is not None:
            x = self._apply_dec_se(x, z, "up3")

        # coarse (optional)
        coarse_logits = None
        if need_coarse:
            coarse_logits = self.coarse_head(x)  # [B,5,H/8,W/8]

        x = self.up2(x, c2)
        if z is not None:
            x = self._apply_dec_se(x, z, "up2")

        x = self.up1(x, c1)
        if z is not None:
            x = self._apply_dec_se(x, z, "up1")

        # final res
        x = nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)  # -> H,W
        feat = self.head_feat(x)
        fine_logits = self.head_out(feat)          # [B,13,H,W]
        boundary_logit = self.boundary_head(x)     # [B,1,H,W]

        return fine_logits, coarse_logits, boundary_logit


# =========================
# Losses
# =========================
def focal_ce_loss(logits, target, weight=None, ignore_index=255, gamma=2.0):
    ce = nn.functional.cross_entropy(
        logits, target, weight=weight, ignore_index=ignore_index, reduction="none"
    )
    pt = torch.exp(-ce)
    loss = ((1 - pt) ** gamma) * ce
    valid = (target != ignore_index)
    denom = valid.sum().clamp_min(1)
    return (loss * valid).sum() / denom

def main_seg_loss(
    logits, label,
    class_weight_tensor=None,
    use_focal=False,
    gamma=2.0,
    ignore_index=255
):
    if use_focal:
        return focal_ce_loss(logits, label, weight=class_weight_tensor, ignore_index=ignore_index, gamma=gamma)
    return nn.functional.cross_entropy(logits, label, weight=class_weight_tensor, ignore_index=ignore_index)

def downsample_label_nearest(label, size_hw):
    x = label.unsqueeze(1).float()
    x = nn.functional.interpolate(x, size=size_hw, mode="nearest")
    return x.squeeze(1).long()

def dice_loss_from_logits(logit, target01, valid01=None, eps=1e-6):
    """
    logit: [B,H,W]
    target01: [B,H,W] float 0/1
    valid01: [B,H,W] float 0/1
    """
    prob = torch.sigmoid(logit)
    if valid01 is None:
        valid01 = torch.ones_like(target01)
    prob = prob * valid01
    target01 = target01 * valid01

    inter = (prob * target01).sum()
    union = prob.sum() + target01.sum()
    return 1.0 - (2.0 * inter + eps) / (union + eps)

def boundary_focal_bce_dice_loss(
    boundary_logit, label,
    ignore_index=255,
    radius=2,
    gamma=2.0,
    alpha=0.25,
    dice_w=0.5,
    eps=1e-6
):
    """
    boundary_logit: [B,1,H,W]
    label: [B,H,W]
    - boundary mask を radiusで太らせる
    - Focal BCE + Dice（不均衡対策）
    """
    with torch.no_grad():
        b0 = boundary_mask_from_label(label, ignore_index=ignore_index)      # 1px
        b = dilate_mask(b0, radius=radius).float()                          # thick band 0/1
        valid = (label != ignore_index).float()

    logit = boundary_logit.squeeze(1)  # [B,H,W]

    # focal BCE (logits)
    p = torch.sigmoid(logit)
    pt = p * b + (1 - p) * (1 - b)
    w = (alpha * b + (1 - alpha) * (1 - b)) * ((1 - pt).clamp_min(1e-6) ** gamma)

    bce = nn.functional.binary_cross_entropy_with_logits(logit, b, reduction="none")
    loss_focal = (w * bce * valid).sum() / (valid.sum().clamp_min(1.0))

    # dice
    loss_dice = dice_loss_from_logits(logit, b, valid01=valid, eps=eps)

    return (1.0 - dice_w) * loss_focal + dice_w * loss_dice, b0, b


# =========================
# (2) Region consistency (coarseクラス別weight対応)
# =========================
def region_consistency_loss_coarse_as_mask(
    fine_logits, coarse_logits,
    non_boundary_mask,  # bool [B,H,W] True=内部領域
    detach_coarse=True,
    class_weights=None,  # (Cc,)
    eps=1e-6
):
    """
    coarse(5)を“領域マスク”として使い、
    coarse領域の内部(non-boundary)で fine(13)分布が一貫するようにする。
    """
    assert coarse_logits is not None, "region loss requires coarse_logits (but got None)"

    B, Cf, H, W = fine_logits.shape
    _, Cc, Hc, Wc = coarse_logits.shape

    fine_prob = fine_logits.softmax(dim=1)
    fine_prob_ds = nn.functional.interpolate(
        fine_prob, size=(Hc, Wc), mode="bilinear", align_corners=False
    )  # [B,13,Hc,Wc]

    coarse_prob = coarse_logits.softmax(dim=1)
    if detach_coarse:
        coarse_prob = coarse_prob.detach()

    # non-boundary mask を coarse 解像度へ
    nb = non_boundary_mask.float().unsqueeze(1)  # [B,1,H,W]
    nb_ds = nn.functional.interpolate(nb, size=(Hc, Wc), mode="nearest")
    nb_ds = (nb_ds > 0.5).float()

    if class_weights is None:
        cw = torch.ones(Cc, device=fine_logits.device, dtype=torch.float32)
    else:
        cw = torch.tensor(class_weights, device=fine_logits.device, dtype=torch.float32).clamp_min(0.0)

    loss = 0.0
    denom_all = 0.0

    for k in range(Cc):
        wk = cw[k]
        if wk.item() == 0.0:
            continue

        mk = coarse_prob[:, k:k+1] * nb_ds
        denom = mk.sum(dim=(2, 3), keepdim=True)
        denom_safe = denom + eps

        mean_k = (fine_prob_ds * mk).sum(dim=(2, 3), keepdim=True) / denom_safe  # [B,13,1,1]
        diff2 = (fine_prob_ds - mean_k) ** 2
        loss_k = (diff2 * mk).sum()

        loss = loss + wk * loss_k
        denom_all = denom_all + (wk * denom).sum()

    return loss / (denom_all + eps)


# =========================
# (4) Floor prior: depth flatness (soft)
# =========================
def depth_grad_mag(depth_01: torch.Tensor) -> torch.Tensor:
    kx = torch.tensor([[-1, 0, 1],
                       [-2, 0, 2],
                       [-1, 0, 1]], dtype=depth_01.dtype, device=depth_01.device).view(1, 1, 3, 3)
    ky = torch.tensor([[-1, -2, -1],
                       [ 0,  0,  0],
                       [ 1,  2,  1]], dtype=depth_01.dtype, device=depth_01.device).view(1, 1, 3, 3)
    gx = nn.functional.conv2d(depth_01, kx, padding=1)
    gy = nn.functional.conv2d(depth_01, ky, padding=1)
    g = torch.sqrt(gx * gx + gy * gy + 1e-12)
    return g

def floor_flat_prior_loss(
    fine_logits: torch.Tensor,     # [B,13,H,W]
    depth_01: torch.Tensor,        # [B,1,H,W]
    label: torch.Tensor,           # [B,H,W]
    floor_id: int = ID_FLOOR,
    ignore_index: int = 255,
    tau: float = 0.03,
    k: float = 30.0,
    eps: float = 1e-6
) -> torch.Tensor:
    with torch.no_grad():
        valid = (label != ignore_index).float()  # [B,H,W]
    prob = fine_logits.softmax(dim=1)[:, floor_id]  # [B,H,W]
    g = depth_grad_mag(depth_01).squeeze(1)         # [B,H,W]

    flat = torch.sigmoid(k * (tau - g))             # [B,H,W] in (0,1)
    loss_map = flat * (1.0 - prob) * valid
    return loss_map.sum() / (valid.sum().clamp_min(1.0) + eps)


# =========================
# Metrics
# =========================
@torch.no_grad()
def update_confusion_matrix(cm, pred, label, num_classes, ignore_index=255):
    pred = pred.view(-1)
    label = label.view(-1)
    mask = (label != ignore_index)
    pred = pred[mask]
    label = label[mask]
    if pred.numel() == 0:
        return cm
    idx = label * num_classes + pred
    binc = torch.bincount(idx, minlength=num_classes*num_classes).cpu()
    cm += binc.view(num_classes, num_classes)
    return cm

def compute_iou_from_cm(cm):
    tp = torch.diag(cm).float()
    fp = cm.sum(dim=0).float() - tp
    fn = cm.sum(dim=1).float() - tp
    denom = tp + fp + fn
    iou = torch.where(denom > 0, tp / denom, torch.zeros_like(denom))
    miou = iou.mean().item()
    return iou, miou, tp, fp, fn


# =========================
# Freeze helpers
# =========================
def set_requires_grad(module, flag: bool):
    for p in module.parameters():
        p.requires_grad = flag

def encoder_modules(model: ResNet50UNet):
    return [model.enc0, model.enc1, model.enc2, model.enc3, model.enc4]

def decoder_modules(model: ResNet50UNet):
    mods = [
        model.center, model.up4, model.up3, model.up2, model.up1,
        model.head_feat, model.head_out,
        model.coarse_head, model.boundary_head
    ]
    if model.use_se:
        mods += [model.depth_stem, model.se_dec]
    return mods


# =========================
# Config
# =========================
@dataclass
class TrainingConfig:
    dataset_root: str = "/content/data"
    batch_size: int = 16
    num_workers: int = 4
    num_classes: int = 13
    coarse_classes: int = 5

    epochs: int = 50
    learning_rate: float = 1e-4
    weight_decay: float = 1e-4
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint_dir: str = "checkpoints"
    image_size: tuple = (512, 512)

    # RGB normalize (ImageNet)
    normalize_mean: tuple = (0.485, 0.456, 0.406)
    normalize_std: tuple = (0.229, 0.224, 0.225)

    # split / ignore
    train_val_split: float = 0.9
    ignore_index: int = 255

    # reproducibility
    seed: int = 42

    # input
    input_mode: str = "rgbd"  # only rgbd

    # SE
    use_se: bool = False
    se_z_ch: int = 64
    se_reduction: int = 16
    se_alpha: float = 0.05
    se_dec_stages: tuple = ("up1",)

    # encoder warmup
    encoder_warmup_epochs: int = 0
    encoder_lr_scale: float = 0.1

    # focal seg (optional)
    use_focal: bool = False
    focal_gamma: float = 2.0

    # grayscale/jitter (train only)
    gray_p: float = 0.0
    jitter_brightness: float = 0.4
    jitter_contrast: float = 0.4
    jitter_saturation: float = 0.3
    jitter_hue: float = 0.05

    # coarse loss schedule
    coarse_w_stage1: float = 0.0
    coarse_w_stage2: float = 0.0
    coarse_w_stage3: float = 0.0
    coarse_epoch_stage2: int = 20
    coarse_epoch_stage3: int = 35

    # region consistency (coarse必須)
    region_loss_weight: float = 0.0
    detach_coarse_in_region_loss: bool = True
    non_boundary_radius: int = 2
    region_class_weights: tuple = (0.0, 1.0, 0.6, 1.0, 1.0)

    # boundary loss
    boundary_loss_weight: float = 0.25
    boundary_radius: int = 2
    boundary_focal_gamma: float = 2.0
    boundary_focal_alpha: float = 0.25
    boundary_dice_w: float = 0.5

    # manual class weights
    floor_weight_mul: float = 1.0
    picture_weight_mul: float = 1.0
    window_weight_mul: float = 1.0
    book_weight_mul: float = 1.0

    # floor prior
    floor_prior_weight: float = 0.0
    floor_prior_tau: float = 0.03
    floor_prior_k: float = 30.0

    # logging
    log_path: str = "train_log.jsonl"

    def __post_init__(self):
        os.makedirs(self.checkpoint_dir, exist_ok=True)


# =========================
# Coarse mapping (13 -> 5)  [YOUR IDs]
# 0: structure (wall,floor,ceiling)
# 1: surface   (desk)
# 2: objects   (bed, chair, sofa, cabinet, book, object, tv)
# 3: opening   (window)
# 4: deco      (picture)
# =========================
def build_coarse_map_13_to_5(device):
    m = torch.zeros(13, dtype=torch.long, device=device)

    for i in [ID_WALL, ID_FLOOR, ID_CEILING]:
        m[i] = 0
    for i in [ID_DESK]:
        m[i] = 1
    for i in [ID_BED, ID_BOOK, ID_CHAIR, ID_SOFA, ID_CABINET, ID_OBJECT, ID_TV]:
        m[i] = 2
    for i in [ID_WINDOW]:
        m[i] = 3
    for i in [ID_PICTURE]:
        m[i] = 4
    return m

def map_to_coarse_label(label, coarse_map, ignore_index=255):
    out = label.clone()
    valid = (out != ignore_index)
    out[valid] = coarse_map[out[valid]]
    return out


# =========================
# (1) Class weights (manual)
# =========================
def build_manual_class_weight(
    num_classes,
    floor_id, floor_mul=1.0,
    picture_id=None, picture_mul=1.0,
    window_id=None, window_mul=1.0,
    book_id=None, book_mul=1.0,
    device="cpu"
):
    w = torch.ones(num_classes, dtype=torch.float32, device=device)
    if floor_mul is not None and float(floor_mul) != 1.0:
        w[floor_id] = float(floor_mul)
    if picture_id is not None and picture_mul is not None and float(picture_mul) != 1.0:
        w[picture_id] = float(picture_mul)
    if window_id is not None and window_mul is not None and float(window_mul) != 1.0:
        w[window_id] = float(window_mul)
    if book_id is not None and book_mul is not None and float(book_mul) != 1.0:
        w[book_id] = float(book_mul)
    return w


# =========================
# Build input tensor
# =========================
def build_input(rgb, depth, mode: str):
    rgb_f = rgb.float()
    depth_f = depth.float()
    if mode == "rgbd":
        x = torch.cat([rgb_f, depth_f], dim=1)  # [B,4,H,W]
        return x, depth_f
    raise ValueError(f"Unknown input_mode: {mode}")


# =========================
# Coarse loss schedule helper
# =========================
def get_coarse_weight(config: TrainingConfig, epoch_idx0: int) -> float:
    e = epoch_idx0 + 1
    if e <= config.coarse_epoch_stage2:
        return float(config.coarse_w_stage1)
    if e <= config.coarse_epoch_stage3:
        return float(config.coarse_w_stage2)
    return float(config.coarse_w_stage3)


# =========================
# Main
# =========================
config = TrainingConfig(
    dataset_root="/content/data",
    batch_size=16,
    num_workers=4,
    learning_rate=1e-4,
    epochs=50,
    image_size=(512, 512),
    input_mode="rgbd",

    use_se=False,
    gray_p=0.0,

    # coarse/region/floorprior を完全OFFにする想定
    coarse_w_stage1=0.0,
    coarse_w_stage2=0.0,
    coarse_w_stage3=0.0,

    region_loss_weight=0.0,
    floor_prior_weight=0.0,

    # boundaryだけON
    boundary_loss_weight=0,

    # class weightもOFF
    floor_weight_mul=1.0,
    picture_weight_mul=1.0,
    window_weight_mul=1.0,
    book_weight_mul=1.0,
)

run_ts = time.strftime("%Y%m%d%H%M%S")
base, ext = os.path.splitext(config.log_path)
if ext.lower() != ".jsonl":
    ext = ".jsonl"
config.log_path = f"{base}_{run_ts}{ext}"

set_seed(config.seed)

# transforms
train_image_transform = Compose([
    Resize(config.image_size, interpolation=InterpolationMode.BILINEAR),
    RandomGrayscale3ch(p=config.gray_p),
    ColorJitter(
        brightness=config.jitter_brightness,
        contrast=config.jitter_contrast,
        saturation=config.jitter_saturation,
        hue=config.jitter_hue,
    ),
    ToTensor(),
    Normalize(config.normalize_mean, config.normalize_std),
])

eval_image_transform = Compose([
    Resize(config.image_size, interpolation=InterpolationMode.BILINEAR),
    ToTensor(),
    Normalize(config.normalize_mean, config.normalize_std),
])

def depth_transform(depth_pil):
    return depth_pil_to_tensor_01(depth_pil, config.image_size)

def target_transform(lbl_pil: Image.Image) -> torch.Tensor:
    lbl_pil = lbl_pil.resize((config.image_size[1], config.image_size[0]), resample=Image.NEAREST)
    return pil_label_to_long_tensor(lbl_pil)

# split with fixed indices
split_base = NYUv2(
    root=config.dataset_root,
    split="train",
    include_depth=True,
    image_transform=eval_image_transform,
    depth_transform=depth_transform,
    target_transform=target_transform,
)
n_total = len(split_base)
n_train = int(n_total * config.train_val_split)
n_val = n_total - n_train
g = torch.Generator().manual_seed(config.seed)
train_subset_base, val_subset_base = random_split(split_base, [n_train, n_val], generator=g)
train_indices = train_subset_base.indices
val_indices = val_subset_base.indices

train_full = NYUv2(
    root=config.dataset_root,
    split="train",
    include_depth=True,
    image_transform=train_image_transform,
    depth_transform=depth_transform,
    target_transform=target_transform,
)
val_full = NYUv2(
    root=config.dataset_root,
    split="train",
    include_depth=True,
    image_transform=eval_image_transform,
    depth_transform=depth_transform,
    target_transform=target_transform,
)
train_ds = Subset(train_full, train_indices)
val_ds   = Subset(val_full, val_indices)

test_ds = NYUv2(
    root=config.dataset_root,
    split="test",
    include_depth=True,
    image_transform=eval_image_transform,
    depth_transform=depth_transform,
    target_transform=None,
)

train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
                          num_workers=config.num_workers, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                        num_workers=config.num_workers, pin_memory=True)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False,
                         num_workers=config.num_workers, pin_memory=True)

device = config.device
print(f"Using device: {device}")
print(f"Run timestamp: {run_ts}")
print(f"Log path     : {config.log_path}")

coarse_map = build_coarse_map_13_to_5(device)

model = ResNet50UNet(
    num_classes=config.num_classes,
    coarse_classes=config.coarse_classes,
    in_channels=4,
    pretrained=True,
    use_se=config.use_se,
    se_z_ch=config.se_z_ch,
    se_reduction=config.se_reduction,
    se_alpha=config.se_alpha,
    se_dec_stages=config.se_dec_stages,
).to(device)

class_weight_tensor = build_manual_class_weight(
    num_classes=config.num_classes,
    floor_id=ID_FLOOR,
    floor_mul=config.floor_weight_mul,
    picture_id=ID_PICTURE,
    picture_mul=config.picture_weight_mul,
    window_id=ID_WINDOW,
    window_mul=config.window_weight_mul,
    book_id=ID_BOOK,
    book_mul=config.book_weight_mul,
    device=device
)

# optimizer: encoder/decoder LR差
enc_params = []
for m in encoder_modules(model):
    enc_params += list(m.parameters())

dec_params = []
for m in decoder_modules(model):
    dec_params += list(m.parameters())

optimizer = optim.Adam(
    [
        {"params": enc_params, "lr": config.learning_rate * config.encoder_lr_scale},
        {"params": dec_params, "lr": config.learning_rate},
    ],
    weight_decay=config.weight_decay
)

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.epochs)
scaler = GradScaler("cuda" if str(device).startswith("cuda") else "cpu")

# log init
with open(config.log_path, "w", encoding="utf-8") as f:
    pass

def log_json(obj):
    with open(config.log_path, "a", encoding="utf-8") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")

log_json({
    "type": "run_header",
    "run_ts": run_ts,
    "device": str(device),
    "dataset_root": config.dataset_root,
    "image_size": list(config.image_size),
    "batch_size": config.batch_size,
    "epochs": config.epochs,
    "lr": config.learning_rate,
    "weight_decay": config.weight_decay,
    "encoder_lr_scale": config.encoder_lr_scale,
    "input_mode": config.input_mode,
    "use_se": config.use_se,
    "gray_p": float(config.gray_p),

    "class_names_13": CLASS_NAMES_13,
    "id_map": {
        "bed": ID_BED, "book": ID_BOOK, "ceiling": ID_CEILING, "chair": ID_CHAIR,
        "floor": ID_FLOOR, "cabinet": ID_CABINET, "object": ID_OBJECT, "picture": ID_PICTURE,
        "sofa": ID_SOFA, "desk": ID_DESK, "tv": ID_TV, "wall": ID_WALL, "window": ID_WINDOW,
    },

    "coarse_w_stage1": float(config.coarse_w_stage1),
    "coarse_w_stage2": float(config.coarse_w_stage2),
    "coarse_w_stage3": float(config.coarse_w_stage3),
    "region_loss_weight": float(config.region_loss_weight),
    "boundary_loss_weight": float(config.boundary_loss_weight),
    "floor_prior_weight": float(config.floor_prior_weight),
})


for epoch in range(config.epochs):
    # optional freeze
    encoder_frozen = (epoch < config.encoder_warmup_epochs)
    for m in encoder_modules(model):
        set_requires_grad(m, not encoder_frozen)

    if encoder_frozen:
        optimizer.param_groups[0]["lr"] = 0.0
    else:
        optimizer.param_groups[0]["lr"] = config.learning_rate * config.encoder_lr_scale
    optimizer.param_groups[1]["lr"] = config.learning_rate

    # scheduled coarse weight
    coarse_w = get_coarse_weight(config, epoch)

    # ★ここが肝：coarse を “計算ごと” 切る条件
    # coarse loss を使う or region loss が使う（regionはcoarse_logits必須）場合だけ有効化
    coarse_needed = (coarse_w > 0.0) or (config.region_loss_weight > 0.0)

    model.train()
    total_loss = 0.0
    total_main = 0.0
    total_coarse = 0.0
    total_region = 0.0
    total_boundary = 0.0
    total_floor_prior = 0.0
    seen = 0

    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs} (train)") as pbar:
        for (rgb, depth, label) in pbar:
            rgb = rgb.to(device, non_blocking=True)
            depth = depth.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            with autocast(device_type="cuda" if str(device).startswith("cuda") else "cpu", enabled=False):
                x, d_for_se = build_input(rgb, depth, config.input_mode)

            with autocast(device_type="cuda" if str(device).startswith("cuda") else "cpu", enabled=True):
                fine_logits, coarse_logits, boundary_logit = model(
                    x,
                    depth=d_for_se if config.use_se else None,
                    need_coarse=coarse_needed,   # ★coarse_w=0ならここでforward計算も止まる
                )

                # 1) main seg
                loss_main = main_seg_loss(
                    fine_logits, label,
                    class_weight_tensor=class_weight_tensor,
                    use_focal=config.use_focal,
                    gamma=config.focal_gamma,
                    ignore_index=config.ignore_index
                )

                # 2) coarse CE（必要なときだけ）
                if coarse_w > 0.0:
                    # coarse_logits は coarse_needed=True のときだけ non-None
                    coarse_label_full = map_to_coarse_label(label, coarse_map, ignore_index=config.ignore_index)
                    coarse_label = downsample_label_nearest(coarse_label_full, coarse_logits.shape[-2:])
                    loss_coarse = nn.functional.cross_entropy(coarse_logits, coarse_label, ignore_index=config.ignore_index)
                else:
                    loss_coarse = torch.tensor(0.0, device=device)

                # boundary mask（region用 non-boundary）
                with torch.no_grad():
                    b1 = boundary_mask_from_label(label, ignore_index=config.ignore_index)
                    b_thick_for_region = dilate_mask(b1, radius=config.non_boundary_radius)
                    valid = (label != config.ignore_index)
                    non_boundary = valid & (~b_thick_for_region)

                # 3) region consistency（必要なときだけ）
                if config.region_loss_weight > 0.0:
                    loss_region = region_consistency_loss_coarse_as_mask(
                        fine_logits, coarse_logits,
                        non_boundary_mask=non_boundary,
                        detach_coarse=config.detach_coarse_in_region_loss,
                        class_weights=config.region_class_weights
                    )
                else:
                    loss_region = torch.tensor(0.0, device=device)

                # 4) boundary loss（weight>0なら）
                if config.boundary_loss_weight > 0.0:
                    loss_boundary, _b0, _bthick = boundary_focal_bce_dice_loss(
                        boundary_logit, label,
                        ignore_index=config.ignore_index,
                        radius=config.boundary_radius,
                        gamma=config.boundary_focal_gamma,
                        alpha=config.boundary_focal_alpha,
                        dice_w=config.boundary_dice_w
                    )
                else:
                    loss_boundary = torch.tensor(0.0, device=device)

                # 5) floor prior（weight>0なら）
                if config.floor_prior_weight > 0.0:
                    loss_floor_prior = floor_flat_prior_loss(
                        fine_logits=fine_logits,
                        depth_01=d_for_se,
                        label=label,
                        floor_id=ID_FLOOR,
                        ignore_index=config.ignore_index,
                        tau=config.floor_prior_tau,
                        k=config.floor_prior_k
                    )
                else:
                    loss_floor_prior = torch.tensor(0.0, device=device)

                loss = (
                    loss_main
                    + coarse_w * loss_coarse
                    + config.region_loss_weight * loss_region
                    + config.boundary_loss_weight * loss_boundary
                    + config.floor_prior_weight * loss_floor_prior
                )

            if torch.isnan(loss) or torch.isinf(loss):
                log_json({"type": "train_nan", "epoch": epoch + 1, "msg": "loss is NaN/Inf, skipped batch"})
                continue

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            bs = rgb.size(0)
            total_loss += float(loss.item()) * bs
            total_main += float(loss_main.item()) * bs
            total_coarse += float(loss_coarse.item()) * bs
            total_region += float(loss_region.item()) * bs
            total_boundary += float(loss_boundary.item()) * bs
            total_floor_prior += float(loss_floor_prior.item()) * bs
            seen += bs

            pbar.set_postfix(
                loss=float(loss.item()),
                main=float(loss_main.item()),
                coarse=float(loss_coarse.item()),
                coarse_w=float(coarse_w),
                region=float(loss_region.item()),
                b=float(loss_boundary.item()),
                floor_p=float(loss_floor_prior.item()),
                coarse_need=bool(coarse_needed),
            )

    train_loss = total_loss / max(1, seen)
    train_main = total_main / max(1, seen)
    train_coarse = total_coarse / max(1, seen)
    train_region = total_region / max(1, seen)
    train_boundary = total_boundary / max(1, seen)
    train_floor_prior = total_floor_prior / max(1, seen)
    scheduler.step()

    # validation
    model.eval()
    cm = torch.zeros((config.num_classes, config.num_classes), dtype=torch.long)
    val_loss_sum = 0.0
    val_main_sum = 0.0
    val_coarse_sum = 0.0
    val_region_sum = 0.0
    val_boundary_sum = 0.0
    val_floor_prior_sum = 0.0
    val_seen = 0

    with torch.no_grad():
        for (rgb, depth, label) in tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.epochs} (val)", leave=False):
            rgb = rgb.to(device, non_blocking=True)
            depth = depth.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)

            with autocast(device_type="cuda" if str(device).startswith("cuda") else "cpu", enabled=False):
                x, d_for_se = build_input(rgb, depth, config.input_mode)

            with autocast(device_type="cuda" if str(device).startswith("cuda") else "cpu", enabled=True):
                fine_logits, coarse_logits, boundary_logit = model(
                    x,
                    depth=d_for_se if config.use_se else None,
                    need_coarse=coarse_needed,
                )

                loss_main = main_seg_loss(
                    fine_logits, label,
                    class_weight_tensor=class_weight_tensor,
                    use_focal=config.use_focal,
                    gamma=config.focal_gamma,
                    ignore_index=config.ignore_index
                )

                if coarse_w > 0.0:
                    coarse_label_full = map_to_coarse_label(label, coarse_map, ignore_index=config.ignore_index)
                    coarse_label = downsample_label_nearest(coarse_label_full, coarse_logits.shape[-2:])
                    loss_coarse = nn.functional.cross_entropy(coarse_logits, coarse_label, ignore_index=config.ignore_index)
                else:
                    loss_coarse = torch.tensor(0.0, device=device)

                with torch.no_grad():
                    b1 = boundary_mask_from_label(label, ignore_index=config.ignore_index)
                    b_thick_for_region = dilate_mask(b1, radius=config.non_boundary_radius)
                    valid = (label != config.ignore_index)
                    non_boundary = valid & (~b_thick_for_region)

                if config.region_loss_weight > 0.0:
                    loss_region = region_consistency_loss_coarse_as_mask(
                        fine_logits, coarse_logits,
                        non_boundary_mask=non_boundary,
                        detach_coarse=config.detach_coarse_in_region_loss,
                        class_weights=config.region_class_weights
                    )
                else:
                    loss_region = torch.tensor(0.0, device=device)

                if config.boundary_loss_weight > 0.0:
                    loss_boundary, _b0, _bthick = boundary_focal_bce_dice_loss(
                        boundary_logit, label,
                        ignore_index=config.ignore_index,
                        radius=config.boundary_radius,
                        gamma=config.boundary_focal_gamma,
                        alpha=config.boundary_focal_alpha,
                        dice_w=config.boundary_dice_w
                    )
                else:
                    loss_boundary = torch.tensor(0.0, device=device)

                if config.floor_prior_weight > 0.0:
                    loss_floor_prior = floor_flat_prior_loss(
                        fine_logits=fine_logits,
                        depth_01=d_for_se,
                        label=label,
                        floor_id=ID_FLOOR,
                        ignore_index=config.ignore_index,
                        tau=config.floor_prior_tau,
                        k=config.floor_prior_k
                    )
                else:
                    loss_floor_prior = torch.tensor(0.0, device=device)

                vloss = (
                    loss_main
                    + coarse_w * loss_coarse
                    + config.region_loss_weight * loss_region
                    + config.boundary_loss_weight * loss_boundary
                    + config.floor_prior_weight * loss_floor_prior
                )

            pred = fine_logits.argmax(dim=1)
            cm = update_confusion_matrix(cm, pred.cpu(), label.cpu(), config.num_classes, ignore_index=config.ignore_index)

            val_loss_sum += float(vloss.item())
            val_main_sum += float(loss_main.item())
            val_coarse_sum += float(loss_coarse.item())
            val_region_sum += float(loss_region.item())
            val_boundary_sum += float(loss_boundary.item())
            val_floor_prior_sum += float(loss_floor_prior.item())
            val_seen += 1

    iou, miou, tp, fp, fn = compute_iou_from_cm(cm)
    iou_list = iou.tolist()
    worst = sorted(range(config.num_classes), key=lambda c: iou_list[c])[:5]
    best  = sorted(range(config.num_classes), key=lambda c: iou_list[c], reverse=True)[:5]
    pred_hist = cm.sum(dim=0).tolist()
    gt_hist   = cm.sum(dim=1).tolist()

    best_names  = [CLASS_NAMES_13[i] for i in best]
    worst_names = [CLASS_NAMES_13[i] for i in worst]

    lr_enc = optimizer.param_groups[0]["lr"]
    lr_dec = optimizer.param_groups[1]["lr"]

    print(
        f"Epoch {epoch+1}: "
        f"train_loss={train_loss:.4f} (main={train_main:.4f}, coarse={train_coarse:.4f}, region={train_region:.4f}, b={train_boundary:.4f}, floor_p={train_floor_prior:.4f}) "
        f"val_loss={val_loss_sum/max(1,val_seen):.4f} (main={val_main_sum/max(1,val_seen):.4f}, coarse={val_coarse_sum/max(1,val_seen):.4f}, region={val_region_sum/max(1,val_seen):.4f}, b={val_boundary_sum/max(1,val_seen):.4f}, floor_p={val_floor_prior_sum/max(1,val_seen):.4f}) "
        f"val_mIoU={miou:.5f} "
        f"(freeze={encoder_frozen}, se={config.use_se}, coarse_w={coarse_w:.3f}, coarse_needed={coarse_needed})"
    )

    log_json({
        "type": "epoch",
        "run_ts": run_ts,
        "epoch": epoch + 1,

        "train_loss": train_loss,
        "train_main": train_main,
        "train_coarse": train_coarse,
        "train_region": train_region,
        "train_boundary": train_boundary,
        "train_floor_prior": train_floor_prior,

        "val_loss": val_loss_sum / max(1, val_seen),
        "val_main": val_main_sum / max(1, val_seen),
        "val_coarse": val_coarse_sum / max(1, val_seen),
        "val_region": val_region_sum / max(1, val_seen),
        "val_boundary": val_boundary_sum / max(1, val_seen),
        "val_floor_prior": val_floor_prior_sum / max(1, val_seen),

        "val_miou": miou,
        "class_iou": iou_list,
        "tp": tp.tolist(),
        "fp": fp.tolist(),
        "fn": fn.tolist(),
        "pred_hist": pred_hist,
        "gt_hist": gt_hist,
        "best_classes": best,
        "worst_classes": worst,
        "best_names": best_names,
        "worst_names": worst_names,

        "lr_enc": lr_enc,
        "lr_dec": lr_dec,
        "encoder_frozen": encoder_frozen,

        "use_se": config.use_se,
        "gray_p": float(config.gray_p),

        "coarse_classes": config.coarse_classes,
        "coarse_w": float(coarse_w),
        "coarse_needed": bool(coarse_needed),
        "region_loss_weight": float(config.region_loss_weight),

        "boundary_loss_weight": float(config.boundary_loss_weight),
        "boundary_radius": int(config.boundary_radius),

        "floor_prior_weight": float(config.floor_prior_weight),
        "floor_id": int(ID_FLOOR),
    })


# =========================
# Save
# =========================
current_time = time.strftime("%Y%m%d%H%M%S")
model_path = os.path.join(config.checkpoint_dir, f"model_{current_time}.pt")
torch.save(model.state_dict(), model_path)
print(f"Model saved to {model_path}")
print(f"Log saved to {config.log_path}")


# =========================
# Inference -> submission.npy（fineだけ使う）
# ※coarseは常に不要なので need_coarse=False で固定
# =========================
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

predictions = []
with torch.no_grad():
    for (rgb, depth) in tqdm(test_loader, desc="Generating predictions"):
        rgb = rgb.to(device, non_blocking=True)
        depth = depth.to(device, non_blocking=True)

        with autocast(device_type="cuda" if str(device).startswith("cuda") else "cpu", enabled=False):
            x = torch.cat([rgb.float(), depth.float()], dim=1)
            d_for_se = depth.float()

        with autocast(device_type="cuda" if str(device).startswith("cuda") else "cpu", enabled=True):
            fine_logits, _coarse, _b = model(
                x,
                depth=d_for_se if config.use_se else None,
                need_coarse=False,   # ★推論でもcoarse計算を完全にしない
            )
            pred = fine_logits.argmax(dim=1)

        predictions.append(pred.cpu())

predictions = torch.cat(predictions, dim=0).numpy()
np.save("submission.npy", predictions)
print("Predictions saved to submission.npy")


# =========================
# Zip for submission
# =========================
notebook_path = "/content/drive/MyDrive/Colab Notebooks/DL_Basic_2025_Competition_NYUv2_baseline.ipynb"

with ZipFile("submission.zip", mode="w", compression=ZIP_DEFLATED, compresslevel=9) as zf:
    zf.write("submission.npy")
    zf.write(model_path, arcname=os.path.basename(model_path))
    zf.write(config.log_path, arcname=os.path.basename(config.log_path))
    if os.path.exists(notebook_path):
        zf.write(notebook_path, arcname="DL_Basic_2025_Competition_NYUv2_baseline.ipynb")

print("Created submission.zip")


Using device: cuda
Run timestamp: 20260103105723
Log path     : train_log_20260103105723.jsonl


Epoch 1/50 (train):  80%|████████  | 36/45 [00:10<00:01,  4.57it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=1.76, main=1.76, region=0]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7bb0644e8360>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7bb0644e8360>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dat

Epoch 1: train_loss=2.0156 (main=2.0156, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.6512 (main=1.6512, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.17960 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 2/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=1.41, main=1.41, region=0]


Epoch 2: train_loss=1.5536 (main=1.5536, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.4378 (main=1.4378, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.27632 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 3/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.48it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=1.37, main=1.37, region=0]


Epoch 3: train_loss=1.3512 (main=1.3512, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.3552 (main=1.3552, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.30799 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 4/50 (train): 100%|██████████| 45/45 [00:13<00:00,  3.44it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=1.26, main=1.26, region=0]


Epoch 4: train_loss=1.1918 (main=1.1918, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.2484 (main=1.2484, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.33549 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 5/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.48it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=1.09, main=1.09, region=0]


Epoch 5: train_loss=1.0608 (main=1.0608, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.2705 (main=1.2705, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.35417 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 6/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=1.04, main=1.04, region=0]


Epoch 6: train_loss=0.9455 (main=0.9455, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1465 (main=1.1465, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.41117 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 7/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.745, main=0.745, region=0]


Epoch 7: train_loss=0.8441 (main=0.8441, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1284 (main=1.1284, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.41682 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 8/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.729, main=0.729, region=0]


Epoch 8: train_loss=0.7374 (main=0.7374, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1106 (main=1.1106, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.41521 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 9/50 (train): 100%|██████████| 45/45 [00:13<00:00,  3.46it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.731, main=0.731, region=0]


Epoch 9: train_loss=0.6731 (main=0.6731, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1376 (main=1.1376, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.42399 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 10/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.53it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.766, main=0.766, region=0]


Epoch 10: train_loss=0.6145 (main=0.6145, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0761 (main=1.0761, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.41707 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 11/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.756, main=0.756, region=0]


Epoch 11: train_loss=0.5511 (main=0.5511, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0322 (main=1.0322, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.42934 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 12/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.703, main=0.703, region=0]


Epoch 12: train_loss=0.5032 (main=0.5032, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0316 (main=1.0316, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.43821 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 13/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.376, main=0.376, region=0]


Epoch 13: train_loss=0.4520 (main=0.4520, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0034 (main=1.0034, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.44466 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 14/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.53it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.447, main=0.447, region=0]


Epoch 14: train_loss=0.4118 (main=0.4118, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0210 (main=1.0210, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.45076 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 15/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.54it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.413, main=0.413, region=0]


Epoch 15: train_loss=0.3801 (main=0.3801, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0135 (main=1.0135, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.47452 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 16/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.47it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.386, main=0.386, region=0]


Epoch 16: train_loss=0.3562 (main=0.3562, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=0.9934 (main=0.9934, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.47061 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 17/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.48it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.329, main=0.329, region=0]


Epoch 17: train_loss=0.3255 (main=0.3255, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1004 (main=1.1004, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.48279 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 18/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.50it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.358, main=0.358, region=0]


Epoch 18: train_loss=0.3074 (main=0.3074, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=0.9816 (main=0.9816, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.47428 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 19/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.52it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.278, main=0.278, region=0]


Epoch 19: train_loss=0.2816 (main=0.2816, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=0.9975 (main=0.9975, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49246 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 20/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.303, main=0.303, region=0]


Epoch 20: train_loss=0.2667 (main=0.2667, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0128 (main=1.0128, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.48775 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 21/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.257, main=0.257, region=0]


Epoch 21: train_loss=0.2436 (main=0.2436, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=0.9655 (main=0.9655, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49746 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 22/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.287, main=0.287, region=0]


Epoch 22: train_loss=0.2257 (main=0.2257, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0017 (main=1.0017, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49024 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 23/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.47it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.229, main=0.229, region=0]


Epoch 23: train_loss=0.2141 (main=0.2141, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0164 (main=1.0164, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.48563 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 24/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.48it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.249, main=0.249, region=0]


Epoch 24: train_loss=0.2071 (main=0.2071, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0068 (main=1.0068, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.48485 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 25/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.50it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.192, main=0.192, region=0]


Epoch 25: train_loss=0.1921 (main=0.1921, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=0.9793 (main=0.9793, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50354 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 26/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.238, main=0.238, region=0]


Epoch 26: train_loss=0.1778 (main=0.1778, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0063 (main=1.0063, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49364 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 27/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.249, main=0.249, region=0]


Epoch 27: train_loss=0.1791 (main=0.1791, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=0.9796 (main=0.9796, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50287 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 28/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.53it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.195, main=0.195, region=0]


Epoch 28: train_loss=0.1616 (main=0.1616, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0146 (main=1.0146, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49577 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 29/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.50it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.139, main=0.139, region=0]


Epoch 29: train_loss=0.1537 (main=0.1537, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=0.9968 (main=0.9968, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49900 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 30/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.50it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.191, main=0.191, region=0]


Epoch 30: train_loss=0.1457 (main=0.1457, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0074 (main=1.0074, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50072 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 31/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.178, main=0.178, region=0]


Epoch 31: train_loss=0.1389 (main=0.1389, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0115 (main=1.0115, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49855 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 32/50 (train): 100%|██████████| 45/45 [00:13<00:00,  3.45it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.119, main=0.119, region=0]


Epoch 32: train_loss=0.1288 (main=0.1288, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0061 (main=1.0061, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49638 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 33/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.181, main=0.181, region=0]


Epoch 33: train_loss=0.1222 (main=0.1222, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0315 (main=1.0315, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.48845 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 34/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.17, main=0.17, region=0]


Epoch 34: train_loss=0.1176 (main=0.1176, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0264 (main=1.0264, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50337 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 35/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.121, main=0.121, region=0]


Epoch 35: train_loss=0.1189 (main=0.1189, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0025 (main=1.0025, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50257 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 36/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.14, main=0.14, region=0]


Epoch 36: train_loss=0.1190 (main=0.1190, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0353 (main=1.0353, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50542 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 37/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.11, main=0.11, region=0]


Epoch 37: train_loss=0.1151 (main=0.1151, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0270 (main=1.0270, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50106 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 38/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.102, main=0.102, region=0]


Epoch 38: train_loss=0.1025 (main=0.1025, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0334 (main=1.0334, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50848 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 39/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.46it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.0937, main=0.0937, region=0]


Epoch 39: train_loss=0.0989 (main=0.0989, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0575 (main=1.0575, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50221 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 40/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.0834, main=0.0834, region=0]


Epoch 40: train_loss=0.0946 (main=0.0946, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0585 (main=1.0585, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50405 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 41/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.53it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.124, main=0.124, region=0]


Epoch 41: train_loss=0.0984 (main=0.0984, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0595 (main=1.0595, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50993 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 42/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.0961, main=0.0961, region=0]


Epoch 42: train_loss=0.0932 (main=0.0932, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.0673 (main=1.0673, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.51026 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 43/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.112, main=0.112, region=0]


Epoch 43: train_loss=0.0992 (main=0.0992, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1280 (main=1.1280, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50312 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 44/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.137, main=0.137, region=0]


Epoch 44: train_loss=0.1066 (main=0.1066, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1122 (main=1.1122, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49626 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 45/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.53it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.0959, main=0.0959, region=0]


Epoch 45: train_loss=0.0924 (main=0.0924, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1016 (main=1.1016, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49454 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 46/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.51it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.0847, main=0.0847, region=0]


Epoch 46: train_loss=0.0901 (main=0.0901, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1085 (main=1.1085, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49437 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 47/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.0906, main=0.0906, region=0]


Epoch 47: train_loss=0.0865 (main=0.0865, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1290 (main=1.1290, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.50130 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 48/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.49it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.103, main=0.103, region=0]


Epoch 48: train_loss=0.0898 (main=0.0898, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1467 (main=1.1467, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.48151 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 49/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.50it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.0665, main=0.0665, region=0]


Epoch 49: train_loss=0.0852 (main=0.0852, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1287 (main=1.1287, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49960 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)


Epoch 50/50 (train): 100%|██████████| 45/45 [00:12<00:00,  3.53it/s, b=0, coarse=0, coarse_need=0, coarse_w=0, floor_p=0, loss=0.0964, main=0.0964, region=0]


Epoch 50: train_loss=0.0801 (main=0.0801, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_loss=1.1037 (main=1.1037, coarse=0.0000, region=0.0000, b=0.0000, floor_p=0.0000) val_mIoU=0.49976 (freeze=False, se=False, coarse_w=0.000, coarse_needed=False)
Model saved to checkpoints/model_20260103110955.pt
Log saved to train_log_20260103105723.jsonl


Generating predictions: 100%|██████████| 654/654 [00:11<00:00, 58.16it/s]


Predictions saved to submission.npy
Created submission.zip


In [37]:
# -*- coding: utf-8 -*-
# NYUv2 val visualization (RGB | GT | Pred | RGB+Error) with FIXED palette -> ZIP
# - NO matplotlib colormap
# - GT and Pred share EXACT same palette mapping (ID -> RGB)
# - ignore_index(255) is shown as black by default

import os
import zipfile
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import VisionDataset
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from torchvision import models

# =========================
# EDIT HERE
# =========================
DATASET_ROOT = "/content/data"
IMAGE_SIZE = (512, 512)

MODEL_PATH = "/content/checkpoints/model_20260103062841.pt"  # ←ここだけ直す
OUTDIR = "val_viz_fixed"
ZIP_PATH = "val_viz_fixed.zip"

MAX_SAVE = 50  # Noneで全件
SEED = 42
TRAIN_VAL_SPLIT = 0.9

NUM_WORKERS = 0
BATCH_SIZE = 1

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IGNORE_INDEX = 255

print("DEVICE:", DEVICE)
print("MODEL_PATH:", MODEL_PATH)

# =========================
# FIXED PALETTE (ID -> RGB)
# ここが「色の意味」そのもの。GTもPredも必ずこれを通す。
# =========================
PALETTE_13 = np.array([
    [  0,   0,   0],   # 0 wall
    [128,  64,  32],   # 1 floor
    [  0, 128, 255],   # 2 cabinet
    [255,   0, 128],   # 3 bed
    [  0, 255,   0],   # 4 chair
    [255, 128,   0],   # 5 sofa
    [255, 255,   0],   # 6 table
    [128, 128, 128],   # 7 door
    [  0, 255, 255],   # 8 window
    [  0,  64, 128],   # 9 bookshelf
    [255,   0,   0],   # 10 picture
    [128,   0, 255],   # 11 counter
    [ 64, 255, 128],   # 12 desk
], dtype=np.uint8)

# =========================
# Dataset helpers
# =========================
def pil_label_to_long_tensor(lbl_pil: Image.Image) -> torch.Tensor:
    arr = np.array(lbl_pil, dtype=np.int64)
    return torch.from_numpy(arr).long()

def depth_pil_to_tensor_01(depth_pil: Image.Image, size_hw):
    depth_pil = depth_pil.resize((size_hw[1], size_hw[0]), resample=Image.BILINEAR)
    arr = np.array(depth_pil)
    if arr.dtype == np.uint16 or (arr.max() > 255):
        arr = arr.astype(np.float32) / 65535.0
    else:
        arr = arr.astype(np.float32) / 255.0
    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    arr = np.clip(arr, 0.0, 1.0)
    return torch.from_numpy(arr).unsqueeze(0)  # [1,H,W]

class NYUv2(VisionDataset):
    def __init__(self, root, split="train", include_depth=True,
                 image_transform=None, depth_transform=None, target_transform=None):
        super().__init__(root)
        assert split in ("train", "test")
        self.root = root
        self.split = split
        self.include_depth = include_depth

        images_dir = os.path.join(self.root, self.split, "image")
        img_names = sorted(os.listdir(images_dir))

        self.images = [os.path.join(images_dir, n) for n in img_names]
        self.depths = [os.path.join(self.root, self.split, "depth", n) for n in img_names]

        if self.split == "train":
            self.targets = [os.path.join(self.root, self.split, "label", n) for n in img_names]
        else:
            self.targets = None

        self.image_transform = image_transform
        self.depth_transform = depth_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("RGB")
        depth = Image.open(self.depths[idx])

        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.depth_transform is not None:
            depth = self.depth_transform(depth)

        if self.split == "test":
            return (image, depth) if self.include_depth else image

        target = Image.open(self.targets[idx])
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, depth, target) if self.include_depth else (image, target)

# =========================
# Model blocks (ResNet50-UNet + optional SE)
# =========================
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=p, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.conv1 = ConvBNReLU(in_ch + skip_ch, out_ch)
        self.conv2 = ConvBNReLU(out_ch, out_ch)

    def forward(self, x, skip):
        x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class DepthStem(nn.Module):
    def __init__(self, z_ch=64):
        super().__init__()
        self.net = nn.Sequential(
            ConvBNReLU(1, 32, 3, 1),
            nn.MaxPool2d(2),
            ConvBNReLU(32, 48, 3, 1),
            nn.MaxPool2d(2),
            ConvBNReLU(48, z_ch, 3, 1),
            nn.MaxPool2d(2),
        )
    def forward(self, d):
        return self.net(d)

class SEFromDepth(nn.Module):
    def __init__(self, feat_ch, z_ch, reduction=16, alpha=0.05):
        super().__init__()
        self.alpha = float(alpha)
        hidden = max(32, feat_ch // reduction)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(z_ch, hidden, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, feat_ch, 1, bias=True),
            nn.Sigmoid(),
        )
    def forward(self, F, z):
        g = self.fc(z)  # [B,C,1,1]
        scale = 1.0 + self.alpha * (g - 0.5)
        return F * scale

class ResNet50UNet(nn.Module):
    def __init__(self, num_classes=13, coarse_classes=5, in_channels=4,
                 pretrained=False,
                 use_se=False, se_z_ch=64, se_reduction=16, se_alpha=0.05,
                 se_dec_stages=("up1",)):
        super().__init__()
        self.use_se = use_se
        self.se_dec_stages = tuple(se_dec_stages)

        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)

        old_conv = resnet.conv1
        if in_channels != old_conv.in_channels:
            new_conv = nn.Conv2d(
                in_channels, old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False
            )
            with torch.no_grad():
                new_conv.weight[:, :3] = old_conv.weight
                mean_w = old_conv.weight.mean(dim=1, keepdim=True)
                for c in range(3, in_channels):
                    new_conv.weight[:, c:c+1] = mean_w
            resnet.conv1 = new_conv

        self.enc0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)  # 1/2
        self.pool = resnet.maxpool                                        # 1/4
        self.enc1 = resnet.layer1                                         # 1/4
        self.enc2 = resnet.layer2                                         # 1/8
        self.enc3 = resnet.layer3                                         # 1/16
        self.enc4 = resnet.layer4                                         # 1/32

        self.center = ConvBNReLU(2048, 1024)
        self.up4 = UpBlock(1024, 1024, 512)
        self.up3 = UpBlock(512,  512,  256)
        self.up2 = UpBlock(256,  256,  128)
        self.up1 = UpBlock(128,  64,   64)

        self.head_feat = ConvBNReLU(64, 64)
        self.head_out  = nn.Conv2d(64, num_classes, kernel_size=1)

        self.coarse_head = nn.Conv2d(256, coarse_classes, kernel_size=1)
        self.boundary_head = nn.Sequential(ConvBNReLU(64, 32), nn.Conv2d(32, 1, 1))

        if self.use_se:
            self.depth_stem = DepthStem(z_ch=se_z_ch)
            stage_to_ch = {"center":1024, "up4":512, "up3":256, "up2":128, "up1":64}
            self.se_dec = nn.ModuleDict()
            for st in self.se_dec_stages:
                self.se_dec[st] = SEFromDepth(stage_to_ch[st], se_z_ch, se_reduction, se_alpha)
        else:
            self.depth_stem = None
            self.se_dec = None

    def _apply_dec_se(self, feat, z, stage_name: str):
        if (not self.use_se) or (self.se_dec is None) or (stage_name not in self.se_dec):
            return feat
        z_ = nn.functional.interpolate(z, size=feat.shape[-2:], mode="bilinear", align_corners=False)
        return self.se_dec[stage_name](feat, z_)

    def forward(self, x, depth=None):
        c1 = self.enc0(x)
        t  = self.pool(c1)
        c2 = self.enc1(t)
        c3 = self.enc2(c2)
        c4 = self.enc3(c3)
        c5 = self.enc4(c4)

        z = self.depth_stem(depth) if self.use_se else None

        x = self.center(c5)
        if z is not None: x = self._apply_dec_se(x, z, "center")
        x = self.up4(x, c4)
        if z is not None: x = self._apply_dec_se(x, z, "up4")
        x = self.up3(x, c3)
        if z is not None: x = self._apply_dec_se(x, z, "up3")

        coarse_logits = self.coarse_head(x)

        x = self.up2(x, c2)
        if z is not None: x = self._apply_dec_se(x, z, "up2")
        x = self.up1(x, c1)
        if z is not None: x = self._apply_dec_se(x, z, "up1")

        x = nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        feat = self.head_feat(x)
        fine_logits = self.head_out(feat)
        boundary_logit = self.boundary_head(x)
        return fine_logits, coarse_logits, boundary_logit

# =========================
# FIXED colorization (ID -> RGB)
# =========================
def id_to_rgb(mask_hw: np.ndarray, ignore_index=255, ignore_rgb=(0,0,0)) -> np.ndarray:
    """
    mask_hw: [H,W] int (0..12 or 255)
    returns: [H,W,3] uint8
    """
    out = np.zeros((mask_hw.shape[0], mask_hw.shape[1], 3), dtype=np.uint8)
    ign = (mask_hw == ignore_index)
    valid = ~ign
    m = mask_hw.copy()
    m[ign] = 0
    m = np.clip(m, 0, 12)
    out[valid] = PALETTE_13[m[valid]]
    out[ign] = np.array(ignore_rgb, dtype=np.uint8)
    return out

def denorm_rgb(rgb_tensor: torch.Tensor) -> Image.Image:
    mean = torch.tensor([0.485, 0.456, 0.406], device=rgb_tensor.device).view(3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=rgb_tensor.device).view(3,1,1)
    x = (rgb_tensor * std + mean).clamp(0, 1)
    x = (x * 255).byte().permute(1,2,0).cpu().numpy()
    return Image.fromarray(x, mode="RGB")

def overlay_error(rgb_pil: Image.Image, gt_hw: np.ndarray, pred_hw: np.ndarray, ignore_index=255) -> Image.Image:
    rgb = np.array(rgb_pil).astype(np.float32)
    err = (gt_hw != ignore_index) & (gt_hw != pred_hw)
    if err.any():
        overlay = rgb.copy()
        overlay[err] = 0.6 * overlay[err] + 0.4 * np.array([255, 0, 0], dtype=np.float32)
        overlay = np.clip(overlay, 0, 255).astype(np.uint8)
        return Image.fromarray(overlay, mode="RGB")
    return rgb_pil

def stack_h(images):
    widths = [im.width for im in images]
    heights = [im.height for im in images]
    out = Image.new("RGB", (sum(widths), max(heights)))
    x = 0
    for im in images:
        out.paste(im, (x, 0))
        x += im.width
    return out

# =========================
# Transforms (same as eval)
# =========================
eval_image_transform = Compose([
    Resize(IMAGE_SIZE, interpolation=InterpolationMode.BILINEAR),
    ToTensor(),
    Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)),
])

def depth_transform(pil):
    return depth_pil_to_tensor_01(pil, IMAGE_SIZE)

def target_transform(pil):
    pil = pil.resize((IMAGE_SIZE[1], IMAGE_SIZE[0]), resample=Image.NEAREST)
    return pil_label_to_long_tensor(pil)

# =========================
# Build val split (same indices rule)
# =========================
split_base = NYUv2(
    root=DATASET_ROOT,
    split="train",
    include_depth=True,
    image_transform=eval_image_transform,
    depth_transform=depth_transform,
    target_transform=target_transform,
)

n_total = len(split_base)
n_train = int(n_total * TRAIN_VAL_SPLIT)
n_val = n_total - n_train
g = torch.Generator().manual_seed(SEED)
_, val_subset = random_split(split_base, [n_train, n_val], generator=g)
val_ds = Subset(split_base, val_subset.indices)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"Val samples: {len(val_ds)}")

# =========================
# Load checkpoint + auto-detect SE
# =========================
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
if isinstance(ckpt, dict) and "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
    sd = ckpt["state_dict"]
else:
    sd = ckpt

has_se = any(k.startswith("depth_stem.") or k.startswith("se_dec.") for k in sd.keys())
print("Detected SE in checkpoint:", has_se)

model = ResNet50UNet(
    num_classes=13,
    coarse_classes=5,
    in_channels=4,
    pretrained=False,
    use_se=has_se,
    se_z_ch=64,
    se_reduction=16,
    se_alpha=0.05,
    se_dec_stages=("up1",),
).to(DEVICE)

model.load_state_dict(sd, strict=True)
model.eval()

# =========================
# Generate panels with FIXED palette
# =========================
os.makedirs(OUTDIR, exist_ok=True)

saved = 0
with torch.no_grad():
    for i, (rgb, depth, label) in enumerate(tqdm(val_loader, desc="Saving fixed-color panels")):
        rgb = rgb.to(DEVICE, non_blocking=True)
        depth = depth.to(DEVICE, non_blocking=True)
        label = label.to(DEVICE, non_blocking=True)

        x = torch.cat([rgb.float(), depth.float()], dim=1)
        fine_logits, _, _ = model(x, depth=depth.float() if has_se else None)
        pred = fine_logits.argmax(dim=1)

        rgb_pil = denorm_rgb(rgb[0])

        gt_hw = label[0].cpu().numpy().astype(np.int32)
        pred_hw = pred[0].cpu().numpy().astype(np.int32)

        gt_rgb = id_to_rgb(gt_hw, ignore_index=IGNORE_INDEX, ignore_rgb=(0,0,0))
        pr_rgb = id_to_rgb(pred_hw, ignore_index=IGNORE_INDEX, ignore_rgb=(0,0,0))

        gt_pil = Image.fromarray(gt_rgb, mode="RGB")
        pr_pil = Image.fromarray(pr_rgb, mode="RGB")
        err_pil = overlay_error(rgb_pil, gt_hw, pred_hw, ignore_index=IGNORE_INDEX)

        panel = stack_h([rgb_pil, gt_pil, pr_pil, err_pil])
        panel.save(os.path.join(OUTDIR, f"{i:06d}.png"), optimize=True)

        saved += 1
        if MAX_SAVE is not None and saved >= MAX_SAVE:
            break

print(f"Saved: {saved} images -> {OUTDIR}")

# =========================
# Zip
# =========================
with zipfile.ZipFile(ZIP_PATH, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zf:
    for fn in sorted(os.listdir(OUTDIR)):
        if fn.lower().endswith(".png"):
            zf.write(os.path.join(OUTDIR, fn), arcname=fn)

print(f"ZIP created: {ZIP_PATH}")


DEVICE: cuda
MODEL_PATH: /content/checkpoints/model_20260103062841.pt
Val samples: 80
Detected SE in checkpoint: False


  return Image.fromarray(x, mode="RGB")
  gt_pil = Image.fromarray(gt_rgb, mode="RGB")
  pr_pil = Image.fromarray(pr_rgb, mode="RGB")
  return Image.fromarray(overlay, mode="RGB")
Saving fixed-color panels:  61%|██████▏   | 49/80 [00:27<00:17,  1.78it/s]


Saved: 50 images -> val_viz_fixed
ZIP created: val_viz_fixed.zip


In [33]:
%ls -l checkpoints

total 4093968
-rw-r--r-- 1 root root 233561111 Jan  3 01:19 model_20260103011918.pt
-rw-r--r-- 1 root root 233561111 Jan  3 01:32 model_20260103013257.pt
-rw-r--r-- 1 root root 232859155 Jan  3 01:53 model_20260103015327.pt
-rw-r--r-- 1 root root 232859155 Jan  3 02:13 model_20260103021307.pt
-rw-r--r-- 1 root root 232859155 Jan  3 02:40 model_20260103024015.pt
-rw-r--r-- 1 root root 232800147 Jan  3 03:13 model_20260103031301.pt
-rw-r--r-- 1 root root 232800147 Jan  3 03:31 model_20260103033149.pt
-rw-r--r-- 1 root root 232800147 Jan  3 03:48 model_20260103034848.pt
-rw-r--r-- 1 root root 232800147 Jan  3 04:14 model_20260103041457.pt
-rw-r--r-- 1 root root 232814185 Jan  3 04:38 model_20260103043835.pt
-rw-r--r-- 1 root root 232814185 Jan  3 04:51 model_20260103045108.pt
-rw-r--r-- 1 root root 232882241 Jan  3 05:16 model_20260103051603.pt
-rw-r--r-- 1 root root 232883265 Jan  3 05:41 model_20260103054113.pt
-rw-r--r-- 1 root root 232688655 Jan  3 06:28 model_20260103062841.pt
-rw-r-

In [17]:
# ------------------
#    Evaluation (DeepLabV3-ResNet50 / ResNet backbone)
# ------------------

ckpt = torch.load(final_path, map_location=device)

# もし「学習で checkpoint dict（model_state入り）」で保存している場合にも対応
if isinstance(ckpt, dict) and "model_state" in ckpt:
    model.load_state_dict(ckpt["model_state"])
else:
    model.load_state_dict(ckpt)

model.eval()

predictions = []

with torch.no_grad():
    print("Generating predictions...")
    for image, depth in tqdm(test_data):
        image = image.to(device, non_blocking=True)
        depth = depth.to(device, non_blocking=True)

        # 念のため：depthが3ch事故のときは1chに落とす
        if depth.dim() == 4 and depth.size(1) == 3:
            depth = depth[:, :1, :, :]

        x = torch.cat((image, depth), dim=1)   # [B,4,H,W]

        # DeepLabV3 は dict を返すので "out" を使う
        out = model(x)["out"]                  # [B,num_classes,H,W]
        pred = out.argmax(dim=1)               # [B,H,W]

        predictions.append(pred.cpu())

predictions = torch.cat(predictions, dim=0).numpy()
np.save("submission.npy", predictions)
print("Predictions saved to submission.npy")


"""## 提出方法

以下の3点をzip化し，Omnicampusの「最終課題 (セグメンテーション)」から提出してください．

- `submission.npy`
- `model.pt`や`model_best.pt`など，テストに使用した重み（拡張子は`.pt`のみ）
- 本Colab Notebook
"""

from zipfile import ZipFile, ZIP_DEFLATED

notebook_path = "/content/drive/MyDrive/Colab Notebooks/DL_Basic_2025_Competition_NYUv2_baseline.ipynb"

from zipfile import ZipFile, ZIP_DEFLATED

notebook_path = "/content/drive/MyDrive/code.ipynb"

with ZipFile("submission.zip",
             mode="w",
             compression=ZIP_DEFLATED,
             compresslevel=9) as zf:
    zf.write("submission.npy")
    zf.write(model_path)
#    zf.write(notebook_path,
#             arcname="code.ipynb")


NameError: name 'final_path' is not defined