<a href="https://colab.research.google.com/github/trie0000/external/blob/main/code_20260104_1510.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# -*- coding: utf-8 -*-
"""DL_Basic_2025_Competition_NYUv2_baseline.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/17t7uAU0aST5aUt6sIJCqFneGlQrxFrJY

# Deep Learning Basics Course, Final Assignment: NYUv2 Semantic Segmentation

## Overview
A semantic segmentation task: given an RGB image, predict which class each pixel in the image belongs to.

### Dataset
- Dataset: NYUv2 dataset
- Training data: 795 images
- Test data: 654 images
- Input: RGB image + depth map (original image sizes vary)
- Output: 13-class segmentation map
- Evaluation metric: Mean IoU (Intersection over Union)

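### Dataset details ([NYU Depth Dataset V2](https://cs.nyu.edu/~fergus/datasets/nyu_depth_v2.html))
- The images show indoor scenes containing furniture, walls, floors, and other objects.
- A 13-class segmentation label is provided for each training image.
- The data is provided in the following directory structure: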
```
data/NYUv2/
├─train/
│  ├─image/      # RGB images
│  │    000000.png
│  │    ...
│  │
│  ├─depth/      # depth maps
│  │    000000.png
│  │    ...
│  │
│  └─label/      # 13-class segmentation (ground-truth labels)
│       000000.png
│       ...
└─test/
   ├─image/      # RGB images
   │    000000.png
   │    ...
   │
   └─depth/      # depth maps
        000000.png
        ...
```
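
The layout above can be walked with a few lines of standard-library Python. This is only a sketch: the helper name `list_samples` and the temporary fake tree are assumptions made so the example runs without the real data, and it is not the loader used later in this notebook.

```python
import tempfile
from pathlib import Path

def list_samples(root, split="train"):
    # Pair image, depth, and label files by file name.
    # Labels exist only for the training split, so they may be None.
    split_dir = Path(root) / split
    samples = []
    for image_path in sorted((split_dir / "image").glob("*.png")):
        depth_path = split_dir / "depth" / image_path.name
        label_path = split_dir / "label" / image_path.name
        samples.append({
            "image": image_path,
            "depth": depth_path if depth_path.exists() else None,
            "label": label_path if label_path.exists() else None,
        })
    return samples

# Build a tiny fake tree (hypothetical, one empty file per folder).
layout = {"train": ["image", "depth", "label"], "test": ["image", "depth"]}
with tempfile.TemporaryDirectory() as tmp:
    for split, kinds in layout.items():
        for kind in kinds:
            folder = Path(tmp) / "NYUv2" / split / kind
            folder.mkdir(parents=True)
            (folder / "000000.png").touch()
    train = list_samples(Path(tmp) / "NYUv2", "train")
    test = list_samples(Path(tmp) / "NYUv2", "test")
    train_has_label = train[0]["label"] is not None
    test_has_label = test[0]["label"] is not None

print(len(train), train_has_label, test_has_label)
```

Pairing by file name relies on the image, depth, and label folders sharing names such as `000000.png`, as the tree shows.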

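### Task details
- The task is to predict, from the input RGB image and depth map, which of the 13 classes each pixel belongs to.
- Evaluation uses Mean IoU.
  - IoU is computed for each class and then averaged over the classes.
  - IoU is computed with the following formula:
  $$IoU = \frac{TP}{TP + FP + FN}$$
    - TP: True Positive (number of pixels correctly predicted as the class)
    - FP: False Positive (number of pixels wrongly predicted as the class)
    - FN: False Negative (number of pixels of the class that were missed)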
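
As a rough illustration of the metric, the sketch below accumulates a confusion matrix and derives per-class IoU and their mean with NumPy. The function names `confusion_matrix` and `mean_iou` are made up for this example. Skipping classes that have no pixels at all is an assumption; the official scorer may treat absent classes differently.

```python
import numpy as np

def confusion_matrix(pred, label, num_classes=13, ignore_index=255):
    # Rows are ground-truth classes, columns are predicted classes.
    pred = np.asarray(pred).reshape(-1)
    label = np.asarray(label).reshape(-1)
    keep = label != ignore_index  # drop pixels excluded from scoring
    idx = label[keep].astype(np.int64) * num_classes + pred[keep].astype(np.int64)
    counts = np.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)

def mean_iou(cm):
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp  # predicted as class c, labelled as something else
    fn = cm.sum(axis=1) - tp  # labelled as class c, predicted as something else
    denom = tp + fp + fn
    present = denom > 0       # assumption: skip classes with no pixels at all
    iou = np.where(present, tp / np.maximum(denom, 1), np.nan)
    return float(np.nanmean(iou)), iou

# Toy 2x3 example with three classes and one ignored pixel.
label = np.array([[0, 0, 1], [1, 255, 2]])
pred = np.array([[0, 1, 1], [1, 0, 2]])
miou, per_class = mean_iou(confusion_matrix(pred, label, num_classes=3))
print(miou, per_class)
```

In the toy example the per-class IoUs are 1/2, 2/3, and 1, so the mean is 13/18.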

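### Preprocessing
- Input images are resized to 512×512.
- Pixel values are normalized to the range 0-1.
- Segmentation labels are integers from 0 to 12 (13 classes).
  - 255 is the ignore index (excluded from evaluation).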
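
To make these steps concrete, here is a small NumPy-only sketch. The helper `resize_nearest` and the toy array sizes are assumptions for illustration, and this is not the baseline's implementation. It shows why label maps are resized with nearest-neighbour sampling: class IDs and the ignore value 255 are copied, never blended into invalid values.

```python
import numpy as np

def resize_nearest(arr, out_h, out_w):
    # Nearest-neighbour resize by index lookup; works for HxW and HxWxC arrays.
    in_h, in_w = arr.shape[:2]
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    return arr[rows[:, None], cols[None, :]]

rng = np.random.default_rng(0)
rgb = rng.integers(0, 256, size=(480, 640, 3), dtype=np.uint8)  # toy RGB image
label = rng.integers(0, 13, size=(480, 640), dtype=np.uint8)    # toy label map
label[:10] = 255                                                # pretend the top rows are unlabelled

# Scale pixel values to [0, 1] after resizing.
rgb_small = resize_nearest(rgb, 512, 512).astype(np.float32) / 255.0
# Integer class IDs survive unchanged because no interpolation takes place.
label_small = resize_nearest(label, 512, 512)

print(rgb_small.shape, rgb_small.min(), rgb_small.max())
print(np.unique(label_small))
```

A real pipeline would usually resize the RGB image with bilinear interpolation and keep nearest-neighbour sampling only for the label map.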

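### Submission format
- Save the predicted class (0 to 12) for every pixel of each test image (RGB + depth) as a NumPy array.
- File name: `submission.npy`
- Array shape: [number of test images, height, width]
- Value of each pixel: an integer from 0 to 12 (the predicted class)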
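
Before uploading, it can help to check the saved file against the points listed here. The sketch below is one way to do that; `check_submission` is a hypothetical helper, and the dummy array of 4 small images only stands in for real predictions on the 654 test images.

```python
import os
import tempfile
import numpy as np

def check_submission(path, num_images, num_classes=13):
    # Load the saved array and verify shape, dtype, and value range.
    arr = np.load(path)
    assert arr.ndim == 3, "expected [num_images, height, width]"
    assert arr.shape[0] == num_images, "one prediction map per test image"
    assert np.issubdtype(arr.dtype, np.integer), "class IDs must be integers"
    assert arr.min() >= 0 and arr.max() < num_classes, "values must lie in 0..12"
    return arr.shape

# Toy stand-in for real predictions: 4 images of 8x8 pixels.
dummy = np.random.default_rng(0).integers(0, 13, size=(4, 8, 8)).astype(np.uint8)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "submission.npy")
    np.save(path, dummy)
    shape = check_submission(path, num_images=4)

print(shape)
```

For the real file, `num_images` would be 654. The expected height and width are not stated in this section, so the sketch does not check them.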

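## Ideas for improvement
- Fine-tuning a pretrained model
    - Fine-tuning a model pretrained on ImageNet or a similar dataset on this dataset can be expected to improve performance.
- Redesigning the loss function
    - Designing the loss so that it is weighted by how often each class appears makes training more robust to class imbalance.
- Image preprocessing
    - Adding data augmentation such as RandomResizedCrop / Flip / ColorJitter can be expected to improve generalization.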
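
As one possible reading of the loss-function idea, the sketch below derives per-class weights from pixel counts. The function name `inverse_frequency_weights` and the inverse-square-root formula are assumptions chosen for illustration; the assignment does not prescribe a weighting scheme, and the hand-set weights used later in this notebook were not produced this way.

```python
import numpy as np

def inverse_frequency_weights(labels, num_classes=13, ignore_index=255):
    # Count labelled pixels per class, skipping the ignore value.
    flat = np.asarray(labels).reshape(-1).astype(np.int64)
    flat = flat[flat != ignore_index]
    counts = np.bincount(flat, minlength=num_classes).astype(np.float64)
    freq = counts / max(counts.sum(), 1.0)
    # Assumption: weight = 1 / sqrt(frequency); classes never seen keep weight 1.
    weights = np.ones(num_classes)
    seen = freq > 0
    weights[seen] = 1.0 / np.sqrt(freq[seen])
    # Normalise so the weights of the observed classes average to 1.
    weights[seen] /= weights[seen].mean()
    return weights

# Toy label map: class 0 is frequent, class 1 is rare, class 2 never appears.
labels = np.array([[0, 0, 0, 0], [1, 255, 0, 0]], dtype=np.uint8)
weights = inverse_frequency_weights(labels, num_classes=3)
print(weights)
```

The resulting array could be converted to a float tensor and passed as the `weight` argument of `nn.CrossEntropyLoss`, so that rare classes contribute more to the loss.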

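## Completion requirements
- The baseline scores 38.2% when evaluated on omnicampus. Only submissions that exceed this baseline score of 38.2% count toward the completion requirement.
- The organizers have confirmed that improvements over the baseline can raise the score to 50% or higher. Use this as one target to aim for.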

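## Notes
- The final prediction model must be **trained on the distributed training data** (fine-tuning included).
- **Inference that relies only on the knowledge of a pretrained model, with no training, is prohibited.**
(Example: only feeding inputs to an LLM such as ChatGPT and taking its output)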

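### Use of pretrained models
Permitted
- **Pretrained models as components**: You may use a pretrained model (BERT, ViT, etc.) as part of an architecture you implement yourself (feature extraction, embeddings, etc.).
- **Fine-tuning**: You may fine-tune pretrained models that are used in the way described above.

Prohibited
- **Pretrained models that solve the task directly**: Taking a pretrained model built to solve the target task directly, such as those provided by transformers, and using it for inference only or with fine-tuning only is prohibited.
  - Example of a prohibited use: applying a pretrained model built to solve VQA directly to a VQA task.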

### Preparing the data
Because the data was saved to Google Drive when it was downloaded, you need to mount Google Drive to use it. The archive cannot be extracted on Drive itself, so copy "data.zip" to the /content directory and extract it there.
This notebook cannot run unless "data.zip" is on Google Drive. If you can place "data.zip" (**831MB**) on Google Drive, run "data_download.ipynb" first. If that is difficult, use the omnicampus exercise environment.
"""

# On omnicampus, the first 4 cells do not need to be run
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# After the data download notebook saves to Google Drive, the file may take
# a while to appear. After mounting Google Drive, confirm that data.zip is
# in the directory before running this cell.
# Copy data.zip to /content
!cp "/content/drive/MyDrive/data.zip" "/content"

# Commented out IPython magic to ensure Python compatibility.
# List the files in the current directory
# If data.zip is listed, everything is fine
# %ls

# Extract the data
!unzip data.zip
!mkdir data
!mv train test data/

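"""In the omnicampus exercise environment, if you skip the mount, zip, and copy-to-Drive steps in data_download.ipynb, the data is placed in the already-extracted form of "data.zip". You therefore only need to make the directory that contains the data directory your current directory.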


"""

# Commented out IPython magic to ensure Python compatibility.
# For running on omnicampus
# This example assumes the data directory is under /workspace/Segmentation/split_data_scripts_omnicampus
# %cd /workspace/Segmentation/split_data_scripts_omnicampus

# For running on omnicampus
!pip install h5py scikit-image

"""# import library"""

Mounted at /content/drive
Archive:  data.zip
  inflating: data/train/image/000600.png  

'# import library'

In [None]:
# =========================
# UNUSED: Focal CE
# Enable when: config.use_focal=True
# =========================
import torch
import torch.nn as nn

def focal_ce_loss(logits, target, weight=None, ignore_index=255, gamma=2.0):
    ce = nn.functional.cross_entropy(
        logits, target, weight=weight, ignore_index=ignore_index, reduction="none"
    )
    pt = torch.exp(-ce)
    loss = ((1 - pt) ** gamma) * ce
    valid = (target != ignore_index)
    denom = valid.sum().clamp_min(1)
    return (loss * valid).sum() / denom

# =========================
# UNUSED: Boundary mask / dilation / loss
# Enable when: config.boundary_w>0 and model.use_boundary_head=True
# =========================

def boundary_mask_from_label(label, ignore_index=255):
    valid = (label != ignore_index)
    b = torch.zeros_like(label, dtype=torch.bool)
    b[:, 1:, :] |= (label[:, 1:, :] != label[:, :-1, :])
    b[:, :-1, :] |= (label[:, :-1, :] != label[:, 1:, :])
    b[:, :, 1:] |= (label[:, :, 1:] != label[:, :, :-1])
    b[:, :, :-1] |= (label[:, :, :-1] != label[:, :, 1:])
    b &= valid
    return b

def dilate_mask(mask_bool, radius: int):
    if radius <= 0:
        return mask_bool
    x = mask_bool.float().unsqueeze(1)  # [B,1,H,W]
    k = 2 * radius + 1
    y = nn.functional.max_pool2d(x, kernel_size=k, stride=1, padding=radius)
    return (y.squeeze(1) > 0.5)

def boundary_focal_bce_dice_loss(
    boundary_logit, label,
    ignore_index=255,
    radius=2,
    gamma=2.0,
    alpha=0.25,
    dice_w=0.5,
    eps=1e-6
):
    with torch.no_grad():
        b0 = boundary_mask_from_label(label, ignore_index=ignore_index)
        b = dilate_mask(b0, radius=radius).float()
        valid = (label != ignore_index).float()

    logit = boundary_logit.squeeze(1)  # [B,H,W]
    p = torch.sigmoid(logit)
    pt = p * b + (1 - p) * (1 - b)
    w = (alpha * b + (1 - alpha) * (1 - b)) * ((1 - pt).clamp_min(1e-6) ** gamma)

    bce = nn.functional.binary_cross_entropy_with_logits(logit, b, reduction="none")
    loss_focal = (w * bce * valid).sum() / (valid.sum().clamp_min(1.0))

    prob = p * valid
    target01 = b * valid
    inter = (prob * target01).sum()
    union = prob.sum() + target01.sum()
    loss_dice = 1.0 - (2.0 * inter + eps) / (union + eps)

    return (1.0 - dice_w) * loss_focal + dice_w * loss_dice

# =========================
# UNUSED: Coarse supervision helpers
# Enable when: config.coarse_w>0 and a 13->coarse(5) class map is implemented for the CE loss
# =========================

def downsample_label_nearest(label, size_hw):
    x = label.unsqueeze(1).float()
    x = nn.functional.interpolate(x, size=size_hw, mode="nearest")
    return x.squeeze(1).long()

In [25]:
# -*- coding: utf-8 -*-
"""
NYUv2: ResNeXt101_32x8d + DeepLabV3+ (Option 1: Powerful Backbone)
Resolution: 768x768
Batch Size: 16 (Adjust based on 80GB VRAM usage)
"""

import os
# Mitigate CUDA memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"

import time
import random
import json
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from zipfile import ZipFile, ZIP_DEFLATED

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, random_split, Subset, Dataset
from torchvision import models
from torch.amp import autocast, GradScaler

import albumentations as A
from albumentations.pytorch import ToTensorV2

# =========================
# Class Definitions
# =========================
CLASS_NAMES = [
    "Bed", "Book", "Ceiling", "Chair", "Floor", "Cabinet", "Object",
    "Picture", "Sofa", "Desk", "TV", "Wall", "Window"
]

# =========================
# Settings & Seed
# =========================
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
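    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune kernels, which can differ between runs;
    # keep it off so that seeding gives reproducible results (at some cost in speed).
    torch.backends.cudnn.benchmark = False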

# =========================
# Helper Functions
# =========================
def estimate_height_from_depth(depth_np):
    H, W = depth_np.shape
    y_grid = np.linspace(0, 1, H).reshape(H, 1).repeat(W, axis=1).astype(np.float32)
    height_map = y_grid * depth_np
    max_val = height_map.max()
    if max_val > 0:
        height_map = height_map / max_val
    return height_map.astype(np.float32)

# =========================
# Transforms
# =========================
def get_transforms(split='train', height=768, width=768):
    if split == 'train':
        pixel_transform = A.Compose([
            A.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1, p=0.5),
            A.GaussNoise(p=0.2),
            A.Blur(blur_limit=3, p=0.1),
            A.ToGray(p=0.1),
        ])
        sync_transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.5, border_mode=0),
            A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
            A.CoarseDropout(
                num_holes_range=(4, 12),
                hole_height_range=(32, 96),
                hole_width_range=(32, 96),
                p=0.3
            ),
            A.Resize(height=height, width=width),
        ], additional_targets={'depth': 'image', 'height': 'image', 'label': 'mask'})
        return {'pixel': pixel_transform, 'sync': sync_transform}
    else:
        sync_transform = A.Compose([
            A.Resize(height=height, width=width),
        ], additional_targets={'depth': 'image', 'height': 'image', 'label': 'mask'})
        return {'pixel': None, 'sync': sync_transform}

# =========================
# Dataset Class
# =========================
class NYUv2Dataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None, return_label=True):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.return_label = return_label

        src_split = 'train' if split in ['train', 'val'] else 'test'
        self.images_dir = os.path.join(root_dir, src_split, 'image')
        self.depths_dir = os.path.join(root_dir, src_split, 'depth')
        self.labels_dir = os.path.join(root_dir, src_split, 'label') if src_split == 'train' else None

        if not os.path.exists(self.images_dir):
            raise FileNotFoundError(f"Directory not found: {self.images_dir}")
        self.filenames = sorted([f for f in os.listdir(self.images_dir) if f.endswith('.png')])

        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        fname = self.filenames[idx]
        rgb = np.array(Image.open(os.path.join(self.images_dir, fname)).convert('RGB'))
        depth = np.array(Image.open(os.path.join(self.depths_dir, fname)))
        if len(depth.shape) == 3: depth = depth[:, :, 0]

        if depth.max() > 255:
            depth = depth.astype(np.float32) / 65535.0
        else:
            depth = depth.astype(np.float32) / 255.0

        height_map = estimate_height_from_depth(depth)

        if self.labels_dir and self.return_label:
            label = np.array(Image.open(os.path.join(self.labels_dir, fname)))
        else:
            label = np.zeros(depth.shape, dtype=np.int32)

        if self.transform:
            if isinstance(self.transform, dict):
                if self.transform.get('pixel'):
                    rgb = self.transform['pixel'](image=rgb)['image']
                if self.transform.get('sync'):
                    augmented = self.transform['sync'](image=rgb, depth=depth, height=height_map, label=label)
                    rgb, depth, height_map, label = augmented['image'], augmented['depth'], augmented['height'], augmented['label']
            else:
                augmented = self.transform(image=rgb, depth=depth, height=height_map, label=label)
                rgb, depth, height_map, label = augmented['image'], augmented['depth'], augmented['height'], augmented['label']

        rgb = (rgb.astype(np.float32) / 255.0 - self.mean) / self.std
        rgb_t = torch.from_numpy(rgb.transpose(2, 0, 1)).float()
        depth_t = torch.from_numpy(depth).float().unsqueeze(0)
        height_t = torch.from_numpy(height_map).float().unsqueeze(0)
        label_t = torch.from_numpy(label).long()

        input_tensor = torch.cat([rgb_t, depth_t, height_t], dim=0)

        if self.split == 'test':
            return input_tensor, fname
        else:
            return input_tensor, label_t

# =========================
# Losses
# =========================
class DiceLoss(nn.Module):
    def __init__(self, n_classes, smooth=1e-5, ignore_index=255):
        super().__init__()
        self.n_classes = n_classes
        self.smooth = smooth
        self.ignore_index = ignore_index

    def forward(self, pred, target):
        mask = (target != self.ignore_index)
        target = target.clone()
        target[~mask] = 0
        pred = torch.softmax(pred, dim=1)
        target_one_hot = torch.nn.functional.one_hot(target, num_classes=self.n_classes).permute(0, 3, 1, 2).float()
        mask = mask.unsqueeze(1).expand_as(pred)
        pred = pred * mask
        target_one_hot = target_one_hot * mask
        intersection = (pred * target_one_hot).sum(dim=(2, 3))
        union = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

# =========================
# Model: ResNeXt101 + DeepLabV3+
# =========================
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1, d=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=p, dilation=d, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.block(x)

class ASPP(nn.Module):
    def __init__(self, in_ch, out_ch=256, rates=(6, 12, 18)):
        super().__init__()
        self.branch1 = ConvBNReLU(in_ch, out_ch, k=1, p=0)
        self.branch2 = ConvBNReLU(in_ch, out_ch, k=3, p=rates[0], d=rates[0])
        self.branch3 = ConvBNReLU(in_ch, out_ch, k=3, p=rates[1], d=rates[1])
        self.branch4 = ConvBNReLU(in_ch, out_ch, k=3, p=rates[2], d=rates[2])
        self.branch5_avg = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvBNReLU(in_ch, out_ch, k=1, p=0)
        )
        self.project = nn.Sequential(
            ConvBNReLU(out_ch * 5, out_ch, k=1, p=0),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        h, w = x.shape[2], x.shape[3]
        b1 = self.branch1(x)
        b2 = self.branch2(x)
        b3 = self.branch3(x)
        b4 = self.branch4(x)
        b5 = F.interpolate(self.branch5_avg(x), size=(h, w), mode='bilinear', align_corners=False)
        return self.project(torch.cat([b1, b2, b3, b4, b5], dim=1))

class ResNeXtDeepLabV3Plus(nn.Module):
    def __init__(self, num_classes=13, in_channels=5, pretrained=True):
        super().__init__()
        print("Initializing ResNeXt101_32x8d Backbone...")
        # ★ Changed here: ResNeXt50 -> ResNeXt101
        resnext = models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.IMAGENET1K_V1 if pretrained else None)

        old_conv = resnext.conv1
        if in_channels != old_conv.in_channels:
            new_conv = nn.Conv2d(in_channels, old_conv.out_channels, kernel_size=7, stride=2, padding=3, bias=False)
            with torch.no_grad():
                new_conv.weight[:, :3] = old_conv.weight
                mean_w = old_conv.weight.mean(dim=1, keepdim=True)
                for c in range(3, in_channels):
                    new_conv.weight[:, c:c+1] = mean_w
            resnext.conv1 = new_conv

        self.enc0 = nn.Sequential(resnext.conv1, resnext.bn1, resnext.relu)
        self.pool = resnext.maxpool
        self.enc1 = resnext.layer1
        self.enc2 = resnext.layer2
        self.enc3 = resnext.layer3
        self.enc4 = resnext.layer4

        # The final stage of ResNeXt101 also outputs 2048 channels, so no change is needed
        self.aspp = ASPP(in_ch=2048, out_ch=256)

        self.low_level_proj = nn.Sequential(
            nn.Conv2d(256, 48, kernel_size=1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        self.decoder = nn.Sequential(
            ConvBNReLU(304, 256, k=3, p=1),
            ConvBNReLU(256, 256, k=3, p=1),
            nn.Conv2d(256, num_classes, kernel_size=1)
        )

    def forward(self, x):
        h, w = x.shape[2], x.shape[3]
        x = self.enc0(x)
        x = self.pool(x)
        low_level = self.enc1(x)
        x = self.enc2(low_level)
        x = self.enc3(x)
        x = self.enc4(x)

        x = self.aspp(x)
        x = F.interpolate(x, size=low_level.shape[2:], mode='bilinear', align_corners=False)

        low_level = self.low_level_proj(low_level)
        x = torch.cat([x, low_level], dim=1)
        x = self.decoder(x)

        x = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)
        return x

# =========================
# Metrics & Utils
# =========================
@torch.no_grad()
def update_confusion_matrix(cm, pred, label, num_classes, ignore_index=255):
    # Flatten both tensors and drop pixels marked with ignore_index
    pred = pred.view(-1)
    label = label.view(-1)
    mask = (label != ignore_index)
    pred = pred[mask]
    label = label[mask]
    if pred.numel() == 0:
        return cm
    idx = label * num_classes + pred
    binc = torch.bincount(idx, minlength=num_classes**2).cpu()
    cm += binc.view(num_classes, num_classes)
    return cm

def compute_metrics(cm):
    cm = cm.float()
    tp = torch.diag(cm)
    fp = cm.sum(dim=0) - tp
    fn = cm.sum(dim=1) - tp

    iou = tp / (tp + fp + fn + 1e-8)
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)

    miou = iou.mean().item()
    return {
        "miou": miou,
        "class_iou": iou.cpu().tolist(),
        "class_precision": precision.cpu().tolist(),
        "class_recall": recall.cpu().tolist()
    }

def save_debug_images(model, val_loader, device, epoch, output_dir):
    model.eval()
    os.makedirs(output_dir, exist_ok=True)
    try:
        x, label = next(iter(val_loader))
    except StopIteration:
        return
    x = x.to(device)
    with torch.no_grad():
        preds = model(x).argmax(dim=1).cpu().numpy()

    i = 0
    pred_vis = (preds[i] * (255 // 13)).astype(np.uint8)
    Image.fromarray(pred_vis).save(os.path.join(output_dir, f"epoch_{epoch:03d}_pred.png"))

def main():
    dataset_root = "/content/data"

    img_height = 768
    img_width = 768

    # ★ ResNeXt101 is heavy, so BS=24 may be tight even with 80GB.
    # It is set to 16 to be safe; raise it to 20-24 if you have headroom.
    batch_size = 16

    # ★ Changed the number of epochs to 50
    epochs = 50

    lr = 1e-4
    seed = 42
    ignore_index = 255
    num_classes = 13
    debug_dir = "debug_logs"

    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    print(f"Model: ResNeXt101_32x8d + DeepLabV3+")
    print(f"Resolution: {img_height}x{img_width}, Batch Size: {batch_size}, Epochs: {epochs}")

    if os.path.exists(debug_dir): shutil.rmtree(debug_dir)
    os.makedirs(debug_dir)

    train_transform = get_transforms('train', height=img_height, width=img_width)
    val_transform = get_transforms('val', height=img_height, width=img_width)

    train_full = NYUv2Dataset(dataset_root, split='train', transform=train_transform)
    n_total = len(train_full)
    n_train = int(n_total * 0.9)
    n_val = n_total - n_train

    train_sub, val_sub = random_split(train_full, [n_train, n_val], generator=torch.Generator().manual_seed(seed))

    train_ds = Subset(NYUv2Dataset(dataset_root, split='train', transform=train_transform), train_sub.indices)
    val_ds = Subset(NYUv2Dataset(dataset_root, split='train', transform=val_transform), val_sub.indices)
    test_ds = NYUv2Dataset(dataset_root, split='test', transform=val_transform, return_label=False)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=min(batch_size, 8), shuffle=False, num_workers=8, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    class_weights = torch.tensor([1.0, 3.0, 0.6, 2.0, 0.6, 1.0, 1.2, 1.5, 1.5, 2.0, 2.5, 0.5, 1.0]).to(device)
    model = ResNeXtDeepLabV3Plus(num_classes=num_classes, in_channels=5).to(device)

    criterion_ce = nn.CrossEntropyLoss(weight=class_weights, ignore_index=ignore_index)
    criterion_dice = DiceLoss(n_classes=num_classes, ignore_index=ignore_index)

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)

    lambda_poly = lambda epoch: (1 - epoch / epochs) ** 0.9
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_poly)

    scaler = GradScaler("cuda")

    best_miou = 0.0
    timestamp = time.strftime("%Y%m%d%H%M%S")
    log_file = f"train_log_resnext101_{timestamp}.jsonl"

    print("Start Training...")

    with open(log_file, "w") as f:
        f.write(json.dumps({"info": "ResNeXt101 + DeepLabV3+ (768px)", "classes": CLASS_NAMES}) + "\n")

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        current_lr = optimizer.param_groups[0]['lr']

        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}") as pbar:
            for x, label in pbar:
                x = x.to(device, non_blocking=True)
                label = label.to(device, non_blocking=True)

                optimizer.zero_grad(set_to_none=True)
                with autocast("cuda"):
                    logits = model(x)
                    loss_ce = criterion_ce(logits, label)
                    loss_dice = criterion_dice(logits, label)
                    loss = 0.4 * loss_ce + 0.6 * loss_dice

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                total_loss += loss.item()
                pbar.set_postfix(loss=loss.item())

        scheduler.step()

        model.eval()
        cm = torch.zeros((num_classes, num_classes), dtype=torch.long)
        save_debug_images(model, val_loader, device, epoch + 1, debug_dir)

        with torch.no_grad():
            for x, label in tqdm(val_loader, desc="Validating", leave=False):
                x = x.to(device); label = label.to(device)
                logits = model(x)
                cm = update_confusion_matrix(cm, logits.argmax(dim=1).cpu(), label.cpu(), num_classes)

        metrics = compute_metrics(cm)
        miou = metrics['miou']

        print(f"Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f} | Val mIoU: {miou:.5f}")

        if miou > best_miou:
            best_miou = miou
            torch.save(model.state_dict(), f"best_model_resnext101_{timestamp}.pt")
            print(f"  --> Best Model Saved! ({best_miou:.5f})")

        log_entry = {
            "epoch": epoch + 1,
            "loss": total_loss/len(train_loader),
            "lr": current_lr,
            "val_miou": miou,
            "class_iou": metrics['class_iou'],
            "class_precision": metrics['class_precision'],
            "class_recall": metrics['class_recall']
        }
        with open(log_file, "a") as f:
            f.write(json.dumps(log_entry) + "\n")

        if (epoch + 1) % 5 == 0:
            np.save(os.path.join(debug_dir, f"cm_epoch_{epoch+1}.npy"), cm.numpy())

    print("\nGenerating Submission...")
    if os.path.exists(f"best_model_resnext101_{timestamp}.pt"):
        model.load_state_dict(torch.load(f"best_model_resnext101_{timestamp}.pt", map_location=device))
    model.eval()
    predictions = []
    scales = [0.75, 1.0, 1.25]

    with torch.no_grad():
        for x, _ in tqdm(test_loader):
            x = x.to(device)
            B, C, H, W = x.shape
            avg_logits = torch.zeros((B, num_classes, H, W), device=device)

            for scale in scales:
                if scale != 1.0:
                    new_h, new_w = int(H * scale), int(W * scale)
                    x_s = nn.functional.interpolate(x, size=(new_h, new_w), mode='bilinear', align_corners=False)
                else:
                    x_s = x

                logits_s = model(x_s)
                x_s_flip = torch.flip(x_s, dims=[3])
                logits_s_flip = torch.flip(model(x_s_flip), dims=[3])

                if scale != 1.0:
                    logits_s = nn.functional.interpolate(logits_s, size=(H, W), mode='bilinear', align_corners=False)
                    logits_s_flip = nn.functional.interpolate(logits_s_flip, size=(H, W), mode='bilinear', align_corners=False)

                avg_logits += (logits_s + logits_s_flip)

            predictions.append(avg_logits.argmax(dim=1).cpu())

    np.save("submission.npy", torch.cat(predictions, dim=0).numpy().astype(np.uint8))
    with ZipFile("submission.zip", mode="w", compression=ZIP_DEFLATED) as zf:
        zf.write("submission.npy")
        if os.path.exists(debug_dir):
            for f in os.listdir(debug_dir):
                zf.write(os.path.join(debug_dir, f), arcname=f"debug/{f}")

    print("Done! submission.zip created.")

if __name__ == "__main__":
    main()

Device: cuda
Model: ResNeXt101_32x8d + DeepLabV3+
Resolution: 768x768, Batch Size: 16, Epochs: 50
Initializing ResNeXt101_32x8d Backbone...
Downloading: "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth" to /root/.cache/torch/hub/checkpoints/resnext101_32x8d-8ba56ff5.pth


100%|██████████| 340M/340M [00:01<00:00, 241MB/s]


Start Training...


Epoch 1/50: 100%|██████████| 44/44 [00:47<00:00,  1.09s/it, loss=1.03]


Epoch 1: Loss=1.1528 | Val mIoU: 0.36060
  --> Best Model Saved! (0.36060)


Epoch 2/50: 100%|██████████| 44/44 [00:23<00:00,  1.84it/s, loss=0.897]


Epoch 2: Loss=0.8855 | Val mIoU: 0.46161
  --> Best Model Saved! (0.46161)


Epoch 3/50: 100%|██████████| 44/44 [00:25<00:00,  1.74it/s, loss=0.906]


Epoch 3: Loss=0.7964 | Val mIoU: 0.50096
  --> Best Model Saved! (0.50096)


Epoch 4/50: 100%|██████████| 44/44 [00:24<00:00,  1.80it/s, loss=0.727]


Epoch 4: Loss=0.7294 | Val mIoU: 0.51975
  --> Best Model Saved! (0.51975)


Epoch 5/50: 100%|██████████| 44/44 [00:24<00:00,  1.78it/s, loss=0.711]


Epoch 5: Loss=0.6838 | Val mIoU: 0.56592
  --> Best Model Saved! (0.56592)


Epoch 6/50: 100%|██████████| 44/44 [00:24<00:00,  1.81it/s, loss=0.762]


Epoch 6: Loss=0.6387 | Val mIoU: 0.56870
  --> Best Model Saved! (0.56870)


Epoch 7/50: 100%|██████████| 44/44 [00:24<00:00,  1.82it/s, loss=0.691]


Epoch 7: Loss=0.6138 | Val mIoU: 0.58649
  --> Best Model Saved! (0.58649)


Epoch 8/50: 100%|██████████| 44/44 [00:23<00:00,  1.90it/s, loss=0.626]


Epoch 8: Loss=0.5834 | Val mIoU: 0.60417
  --> Best Model Saved! (0.60417)


Epoch 9/50: 100%|██████████| 44/44 [00:24<00:00,  1.80it/s, loss=0.584]


Epoch 9: Loss=0.5613 | Val mIoU: 0.60633
  --> Best Model Saved! (0.60633)


Epoch 10/50: 100%|██████████| 44/44 [00:24<00:00,  1.81it/s, loss=0.545]


Epoch 10: Loss=0.5316 | Val mIoU: 0.60413


Epoch 11/50: 100%|██████████| 44/44 [00:24<00:00,  1.82it/s, loss=0.536]


Epoch 11: Loss=0.5154 | Val mIoU: 0.60109


Epoch 12/50: 100%|██████████| 44/44 [00:24<00:00,  1.83it/s, loss=0.56]


Epoch 12: Loss=0.5111 | Val mIoU: 0.60160


Epoch 13/50: 100%|██████████| 44/44 [00:24<00:00,  1.79it/s, loss=0.519]


Epoch 13: Loss=0.4948 | Val mIoU: 0.59648


Epoch 14/50: 100%|██████████| 44/44 [00:24<00:00,  1.83it/s, loss=0.475]


Epoch 14: Loss=0.4837 | Val mIoU: 0.61893
  --> Best Model Saved! (0.61893)


Epoch 15/50: 100%|██████████| 44/44 [00:24<00:00,  1.82it/s, loss=0.557]


Epoch 15: Loss=0.4735 | Val mIoU: 0.60459


Epoch 16/50: 100%|██████████| 44/44 [00:25<00:00,  1.75it/s, loss=0.532]


Epoch 16: Loss=0.4592 | Val mIoU: 0.61416


Epoch 17/50: 100%|██████████| 44/44 [00:23<00:00,  1.83it/s, loss=0.484]


Epoch 17: Loss=0.4485 | Val mIoU: 0.62129
  --> Best Model Saved! (0.62129)


Epoch 18/50: 100%|██████████| 44/44 [00:23<00:00,  1.84it/s, loss=0.473]


Epoch 18: Loss=0.4396 | Val mIoU: 0.62256
  --> Best Model Saved! (0.62256)


Epoch 19/50: 100%|██████████| 44/44 [00:24<00:00,  1.77it/s, loss=0.486]


Epoch 19: Loss=0.4358 | Val mIoU: 0.60654


Epoch 20/50: 100%|██████████| 44/44 [00:23<00:00,  1.84it/s, loss=0.408]


Epoch 20: Loss=0.4284 | Val mIoU: 0.61196


Epoch 21/50: 100%|██████████| 44/44 [00:23<00:00,  1.84it/s, loss=0.49]


Epoch 21: Loss=0.4258 | Val mIoU: 0.61564


Epoch 22/50: 100%|██████████| 44/44 [00:24<00:00,  1.82it/s, loss=0.482]


Epoch 22: Loss=0.4133 | Val mIoU: 0.63088
  --> Best Model Saved! (0.63088)


Epoch 23/50: 100%|██████████| 44/44 [00:23<00:00,  1.83it/s, loss=0.402]


Epoch 23: Loss=0.4109 | Val mIoU: 0.62940


Epoch 24/50: 100%|██████████| 44/44 [00:23<00:00,  1.88it/s, loss=0.377]


Epoch 24: Loss=0.4097 | Val mIoU: 0.60746


Epoch 25/50: 100%|██████████| 44/44 [00:24<00:00,  1.78it/s, loss=0.396]


Epoch 25: Loss=0.4043 | Val mIoU: 0.61823


Epoch 26/50: 100%|██████████| 44/44 [00:23<00:00,  1.84it/s, loss=0.42]


Epoch 26: Loss=0.3958 | Val mIoU: 0.62914


Epoch 27/50: 100%|██████████| 44/44 [00:24<00:00,  1.78it/s, loss=0.439]


Epoch 27: Loss=0.3929 | Val mIoU: 0.63838
  --> Best Model Saved! (0.63838)


Epoch 28/50: 100%|██████████| 44/44 [00:24<00:00,  1.79it/s, loss=0.391]


Epoch 28: Loss=0.3888 | Val mIoU: 0.63876
  --> Best Model Saved! (0.63876)


Epoch 29/50: 100%|██████████| 44/44 [00:23<00:00,  1.84it/s, loss=0.418]


Epoch 29: Loss=0.3835 | Val mIoU: 0.63093


Epoch 30/50: 100%|██████████| 44/44 [00:24<00:00,  1.77it/s, loss=0.409]


Epoch 30: Loss=0.3806 | Val mIoU: 0.63754


Epoch 31/50: 100%|██████████| 44/44 [00:24<00:00,  1.79it/s, loss=0.39]


Epoch 31: Loss=0.3740 | Val mIoU: 0.63909
  --> Best Model Saved! (0.63909)


Epoch 32/50: 100%|██████████| 44/44 [00:24<00:00,  1.79it/s, loss=0.399]


Epoch 32: Loss=0.3771 | Val mIoU: 0.64669
  --> Best Model Saved! (0.64669)


Epoch 33/50: 100%|██████████| 44/44 [00:23<00:00,  1.87it/s, loss=0.397]


Epoch 33: Loss=0.3752 | Val mIoU: 0.63263


Epoch 34/50: 100%|██████████| 44/44 [00:23<00:00,  1.87it/s, loss=0.4]


Epoch 34: Loss=0.3673 | Val mIoU: 0.63153


Epoch 35/50: 100%|██████████| 44/44 [00:23<00:00,  1.85it/s, loss=0.378]


Epoch 35: Loss=0.3684 | Val mIoU: 0.63568


Epoch 36/50: 100%|██████████| 44/44 [00:24<00:00,  1.76it/s, loss=0.384]


Epoch 36: Loss=0.3645 | Val mIoU: 0.64453


Epoch 37/50: 100%|██████████| 44/44 [00:24<00:00,  1.82it/s, loss=0.368]


Epoch 37: Loss=0.3621 | Val mIoU: 0.65126
  --> Best Model Saved! (0.65126)


Epoch 38/50: 100%|██████████| 44/44 [00:24<00:00,  1.78it/s, loss=0.394]


Epoch 38: Loss=0.3612 | Val mIoU: 0.62595


Epoch 39/50: 100%|██████████| 44/44 [00:23<00:00,  1.84it/s, loss=0.423]


Epoch 39: Loss=0.3601 | Val mIoU: 0.63707


Epoch 40/50: 100%|██████████| 44/44 [00:25<00:00,  1.76it/s, loss=0.399]


Epoch 40: Loss=0.3568 | Val mIoU: 0.64026


Epoch 41/50: 100%|██████████| 44/44 [00:24<00:00,  1.83it/s, loss=0.343]


Epoch 41: Loss=0.3553 | Val mIoU: 0.64577


Epoch 42/50: 100%|██████████| 44/44 [00:24<00:00,  1.77it/s, loss=0.41]


Epoch 42: Loss=0.3508 | Val mIoU: 0.65506
  --> Best Model Saved! (0.65506)


Epoch 43/50: 100%|██████████| 44/44 [00:24<00:00,  1.83it/s, loss=0.338]


Epoch 43: Loss=0.3513 | Val mIoU: 0.64891


Epoch 44/50: 100%|██████████| 44/44 [00:24<00:00,  1.79it/s, loss=0.373]


Epoch 44: Loss=0.3528 | Val mIoU: 0.64951


Epoch 45/50: 100%|██████████| 44/44 [00:24<00:00,  1.76it/s, loss=0.355]


Epoch 45: Loss=0.3493 | Val mIoU: 0.64606


Epoch 46/50: 100%|██████████| 44/44 [00:24<00:00,  1.79it/s, loss=0.364]


Epoch 46: Loss=0.3482 | Val mIoU: 0.64754


Epoch 47/50: 100%|██████████| 44/44 [00:24<00:00,  1.80it/s, loss=0.37]


Epoch 47: Loss=0.3471 | Val mIoU: 0.64252


Epoch 48/50: 100%|██████████| 44/44 [00:23<00:00,  1.84it/s, loss=0.342]


Epoch 48: Loss=0.3443 | Val mIoU: 0.64631


Epoch 49/50: 100%|██████████| 44/44 [00:24<00:00,  1.82it/s, loss=0.335]


Epoch 49: Loss=0.3425 | Val mIoU: 0.65069


Epoch 50/50: 100%|██████████| 44/44 [00:23<00:00,  1.85it/s, loss=0.334]


Epoch 50: Loss=0.3440 | Val mIoU: 0.65046

Generating Submission...


100%|██████████| 654/654 [01:34<00:00,  6.89it/s]


Done! submission.zip created.
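
The `Val mIoU` values in the log above follow the Mean IoU definition from the task description: IoU = TP / (TP + FP + FN) per class, averaged over classes, with label 255 ignored. A minimal, dependency-free sketch of that computation is given below. The function name `mean_iou` and the choice to skip classes that appear in neither the prediction nor the ground truth are assumptions for illustration, not necessarily what the training cell does.

```python
def mean_iou(pred, target, num_classes=13, ignore_index=255):
    """Mean IoU over flat sequences of predicted and ground-truth class IDs."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes
    for p, t in zip(pred, target):
        if t == ignore_index:
            continue  # ignored pixels do not count toward any class
        if p == t:
            tp[t] += 1
        else:
            fp[p] += 1  # class p was predicted where it should not be
            fn[t] += 1  # a pixel of class t was missed
    ious = []
    for c in range(num_classes):
        denom = tp[c] + fp[c] + fn[c]
        if denom > 0:  # assumption: classes absent from both inputs are skipped
            ious.append(tp[c] / denom)
    return sum(ious) / len(ious) if ious else 0.0


# Toy example with 3 classes and one ignored pixel
pred = [0, 0, 1, 1, 2, 2]
target = [0, 1, 1, 1, 2, 255]
print(mean_iou(pred, target, num_classes=3))
```

In the toy example the per-class IoUs are 1/2, 2/3, and 1, so the mean is 13/18.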


In [None]:
import json
import matplotlib.pyplot as plt
import glob
import os

def plot_learning_curves(log_file):
    epochs = []
    losses = []
    mious = []
    lrs = []

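    # Read the JSONL log file (one JSON object per line)
    with open(log_file, 'r') as f:
        for line in f:
            # Skip blank lines so a trailing newline does not break json.loads
            if not line.strip():
                continue
            data = json.loads(line)
            # Skip the header row (info)
            if "info" in data:
                continue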

            epochs.append(data['epoch'])
            losses.append(data['loss'])
            mious.append(data['val_miou'])
            lrs.append(data['lr'])

    # Create the plot
    fig, ax1 = plt.subplots(figsize=(12, 6))

    # Loss (left axis, red)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Training Loss', color='tab:red')
    ax1.plot(epochs, losses, color='tab:red', label='Train Loss', marker='.')
    ax1.tick_params(axis='y', labelcolor='tab:red')
    ax1.grid(True, alpha=0.3)

    # mIoU (right axis, blue)
    ax2 = ax1.twinx()
    ax2.set_ylabel('Validation mIoU', color='tab:blue')
    ax2.plot(epochs, mious, color='tab:blue', label='Val mIoU', marker='.')
    ax2.tick_params(axis='y', labelcolor='tab:blue')

    plt.title(f'Learning Curve: {os.path.basename(log_file)}')
    fig.tight_layout()
    plt.show()

    # Plot the learning rate (to check that the poly scheduler is taking effect)
    plt.figure(figsize=(12, 3))
    plt.plot(epochs, lrs, color='orange')
    plt.title("Learning Rate Schedule")
    plt.xlabel("Epoch")
    plt.ylabel("LR")
    plt.grid(True, alpha=0.3)
    plt.show()

# Find the latest log file automatically and plot it
log_files = sorted(glob.glob("train_log_detailed_*.jsonl"))
if log_files:
    latest_log = log_files[-1]
    print(f"Plotting: {latest_log}")
    plot_learning_curves(latest_log)
else:
    print("No log file found.")
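
The second plot is meant to confirm that the poly scheduler is active. For reference, a polynomial decay schedule is commonly written as `lr = base_lr * (1 - step / max_steps) ** power`. The sketch below only illustrates that general shape: `base_lr=1e-3`, `power=0.9`, and stepping once per epoch are hypothetical values, not taken from this notebook's training cell.

```python
def poly_lr(base_lr, step, max_steps, power=0.9):
    """Polynomial learning-rate decay: starts at base_lr and reaches 0 at max_steps."""
    step = min(step, max_steps)  # clamp so the base of the power never goes negative
    return base_lr * (1.0 - step / max_steps) ** power


# Hypothetical values: 50 epochs, base learning rate 1e-3
schedule = [poly_lr(1e-3, epoch, 50) for epoch in range(50)]
print(schedule[0], schedule[25], schedule[49])
```

If the logged `lr` values fall smoothly from the base rate toward zero in this way, the scheduler is being stepped as intended; a flat line would suggest it is never stepped.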

In [37]:
# -*- coding: utf-8 -*-
# NYUv2 val visualization (RGB | GT | Pred | RGB+Error) with FIXED palette -> ZIP
# - NO matplotlib colormap
# - GT and Pred share EXACT same palette mapping (ID -> RGB)
# - ignore_index(255) is shown as black by default

import os
import zipfile
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import VisionDataset
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from torchvision import models

# =========================
# EDIT HERE
# =========================
DATASET_ROOT = "/content/data"
IMAGE_SIZE = (512, 512)

MODEL_PATH = "/content/checkpoints/model_20260103062841.pt"  # <- edit only this line
OUTDIR = "val_viz_fixed"
ZIP_PATH = "val_viz_fixed.zip"

MAX_SAVE = 50  # None saves every sample
SEED = 42
TRAIN_VAL_SPLIT = 0.9

NUM_WORKERS = 0
BATCH_SIZE = 1

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IGNORE_INDEX = 255

print("DEVICE:", DEVICE)
print("MODEL_PATH:", MODEL_PATH)

# =========================
# FIXED PALETTE (ID -> RGB)
# This table defines what each color means. GT and Pred must both go through it.
# =========================
PALETTE_13 = np.array([
    [  0,   0,   0],   # 0 wall
    [128,  64,  32],   # 1 floor
    [  0, 128, 255],   # 2 cabinet
    [255,   0, 128],   # 3 bed
    [  0, 255,   0],   # 4 chair
    [255, 128,   0],   # 5 sofa
    [255, 255,   0],   # 6 table
    [128, 128, 128],   # 7 door
    [  0, 255, 255],   # 8 window
    [  0,  64, 128],   # 9 bookshelf
    [255,   0,   0],   # 10 picture
    [128,   0, 255],   # 11 counter
    [ 64, 255, 128],   # 12 desk
], dtype=np.uint8)

# =========================
# Dataset helpers
# =========================
def pil_label_to_long_tensor(lbl_pil: Image.Image) -> torch.Tensor:
    arr = np.array(lbl_pil, dtype=np.int64)
    return torch.from_numpy(arr).long()

def depth_pil_to_tensor_01(depth_pil: Image.Image, size_hw):
    depth_pil = depth_pil.resize((size_hw[1], size_hw[0]), resample=Image.BILINEAR)
    arr = np.array(depth_pil)
    if arr.dtype == np.uint16 or (arr.max() > 255):
        arr = arr.astype(np.float32) / 65535.0
    else:
        arr = arr.astype(np.float32) / 255.0
    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    arr = np.clip(arr, 0.0, 1.0)
    return torch.from_numpy(arr).unsqueeze(0)  # [1,H,W]

class NYUv2(VisionDataset):
    def __init__(self, root, split="train", include_depth=True,
                 image_transform=None, depth_transform=None, target_transform=None):
        super().__init__(root)
        assert split in ("train", "test")
        self.root = root
        self.split = split
        self.include_depth = include_depth

        images_dir = os.path.join(self.root, self.split, "image")
        img_names = sorted(os.listdir(images_dir))

        self.images = [os.path.join(images_dir, n) for n in img_names]
        self.depths = [os.path.join(self.root, self.split, "depth", n) for n in img_names]

        if self.split == "train":
            self.targets = [os.path.join(self.root, self.split, "label", n) for n in img_names]
        else:
            self.targets = None

        self.image_transform = image_transform
        self.depth_transform = depth_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("RGB")
        depth = Image.open(self.depths[idx])

        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.depth_transform is not None:
            depth = self.depth_transform(depth)

        if self.split == "test":
            return (image, depth) if self.include_depth else image

        target = Image.open(self.targets[idx])
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, depth, target) if self.include_depth else (image, target)

# =========================
# Model blocks (ResNet50-UNet + optional SE)
# =========================
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=p, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.conv1 = ConvBNReLU(in_ch + skip_ch, out_ch)
        self.conv2 = ConvBNReLU(out_ch, out_ch)

    def forward(self, x, skip):
        x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class DepthStem(nn.Module):
    def __init__(self, z_ch=64):
        super().__init__()
        self.net = nn.Sequential(
            ConvBNReLU(1, 32, 3, 1),
            nn.MaxPool2d(2),
            ConvBNReLU(32, 48, 3, 1),
            nn.MaxPool2d(2),
            ConvBNReLU(48, z_ch, 3, 1),
            nn.MaxPool2d(2),
        )
    def forward(self, d):
        return self.net(d)

class SEFromDepth(nn.Module):
    def __init__(self, feat_ch, z_ch, reduction=16, alpha=0.05):
        super().__init__()
        self.alpha = float(alpha)
        hidden = max(32, feat_ch // reduction)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(z_ch, hidden, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, feat_ch, 1, bias=True),
            nn.Sigmoid(),
        )
    def forward(self, F, z):
        g = self.fc(z)  # [B,C,1,1]
        scale = 1.0 + self.alpha * (g - 0.5)
        return F * scale

class ResNet50UNet(nn.Module):
    def __init__(self, num_classes=13, coarse_classes=5, in_channels=4,
                 pretrained=False,
                 use_se=False, se_z_ch=64, se_reduction=16, se_alpha=0.05,
                 se_dec_stages=("up1",)):
        super().__init__()
        self.use_se = use_se
        self.se_dec_stages = tuple(se_dec_stages)

        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)

        old_conv = resnet.conv1
        if in_channels != old_conv.in_channels:
            new_conv = nn.Conv2d(
                in_channels, old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False
            )
            with torch.no_grad():
                new_conv.weight[:, :3] = old_conv.weight
                mean_w = old_conv.weight.mean(dim=1, keepdim=True)
                for c in range(3, in_channels):
                    new_conv.weight[:, c:c+1] = mean_w
            resnet.conv1 = new_conv

        self.enc0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)  # 1/2
        self.pool = resnet.maxpool                                        # 1/4
        self.enc1 = resnet.layer1                                         # 1/4
        self.enc2 = resnet.layer2                                         # 1/8
        self.enc3 = resnet.layer3                                         # 1/16
        self.enc4 = resnet.layer4                                         # 1/32

        self.center = ConvBNReLU(2048, 1024)
        self.up4 = UpBlock(1024, 1024, 512)
        self.up3 = UpBlock(512,  512,  256)
        self.up2 = UpBlock(256,  256,  128)
        self.up1 = UpBlock(128,  64,   64)

        self.head_feat = ConvBNReLU(64, 64)
        self.head_out  = nn.Conv2d(64, num_classes, kernel_size=1)

        self.coarse_head = nn.Conv2d(256, coarse_classes, kernel_size=1)
        self.boundary_head = nn.Sequential(ConvBNReLU(64, 32), nn.Conv2d(32, 1, 1))

        if self.use_se:
            self.depth_stem = DepthStem(z_ch=se_z_ch)
            stage_to_ch = {"center":1024, "up4":512, "up3":256, "up2":128, "up1":64}
            self.se_dec = nn.ModuleDict()
            for st in self.se_dec_stages:
                self.se_dec[st] = SEFromDepth(stage_to_ch[st], se_z_ch, se_reduction, se_alpha)
        else:
            self.depth_stem = None
            self.se_dec = None

    def _apply_dec_se(self, feat, z, stage_name: str):
        if (not self.use_se) or (self.se_dec is None) or (stage_name not in self.se_dec):
            return feat
        z_ = nn.functional.interpolate(z, size=feat.shape[-2:], mode="bilinear", align_corners=False)
        return self.se_dec[stage_name](feat, z_)

    def forward(self, x, depth=None):
        c1 = self.enc0(x)
        t  = self.pool(c1)
        c2 = self.enc1(t)
        c3 = self.enc2(c2)
        c4 = self.enc3(c3)
        c5 = self.enc4(c4)

        z = self.depth_stem(depth) if self.use_se else None

        x = self.center(c5)
        if z is not None: x = self._apply_dec_se(x, z, "center")
        x = self.up4(x, c4)
        if z is not None: x = self._apply_dec_se(x, z, "up4")
        x = self.up3(x, c3)
        if z is not None: x = self._apply_dec_se(x, z, "up3")

        coarse_logits = self.coarse_head(x)

        x = self.up2(x, c2)
        if z is not None: x = self._apply_dec_se(x, z, "up2")
        x = self.up1(x, c1)
        if z is not None: x = self._apply_dec_se(x, z, "up1")

        x = nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        feat = self.head_feat(x)
        fine_logits = self.head_out(feat)
        boundary_logit = self.boundary_head(x)
        return fine_logits, coarse_logits, boundary_logit

# =========================
# FIXED colorization (ID -> RGB)
# =========================
def id_to_rgb(mask_hw: np.ndarray, ignore_index=255, ignore_rgb=(0,0,0)) -> np.ndarray:
    """
    mask_hw: [H,W] int (0..12 or 255)
    returns: [H,W,3] uint8
    """
    out = np.zeros((mask_hw.shape[0], mask_hw.shape[1], 3), dtype=np.uint8)
    ign = (mask_hw == ignore_index)
    valid = ~ign
    m = mask_hw.copy()
    m[ign] = 0
    m = np.clip(m, 0, 12)
    out[valid] = PALETTE_13[m[valid]]
    out[ign] = np.array(ignore_rgb, dtype=np.uint8)
    return out

def denorm_rgb(rgb_tensor: torch.Tensor) -> Image.Image:
    mean = torch.tensor([0.485, 0.456, 0.406], device=rgb_tensor.device).view(3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=rgb_tensor.device).view(3,1,1)
    x = (rgb_tensor * std + mean).clamp(0, 1)
    x = (x * 255).byte().permute(1,2,0).cpu().numpy()
    return Image.fromarray(x)  # uint8 [H,W,3] is inferred as RGB; the mode argument is deprecated in recent Pillow

def overlay_error(rgb_pil: Image.Image, gt_hw: np.ndarray, pred_hw: np.ndarray, ignore_index=255) -> Image.Image:
    rgb = np.array(rgb_pil).astype(np.float32)
    err = (gt_hw != ignore_index) & (gt_hw != pred_hw)
    if err.any():
        overlay = rgb.copy()
        overlay[err] = 0.6 * overlay[err] + 0.4 * np.array([255, 0, 0], dtype=np.float32)
        overlay = np.clip(overlay, 0, 255).astype(np.uint8)
        return Image.fromarray(overlay)  # uint8 [H,W,3] is inferred as RGB
    return rgb_pil

def stack_h(images):
    widths = [im.width for im in images]
    heights = [im.height for im in images]
    out = Image.new("RGB", (sum(widths), max(heights)))
    x = 0
    for im in images:
        out.paste(im, (x, 0))
        x += im.width
    return out

# =========================
# Transforms (same as eval)
# =========================
eval_image_transform = Compose([
    Resize(IMAGE_SIZE, interpolation=InterpolationMode.BILINEAR),
    ToTensor(),
    Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)),
])

def depth_transform(pil):
    return depth_pil_to_tensor_01(pil, IMAGE_SIZE)

def target_transform(pil):
    pil = pil.resize((IMAGE_SIZE[1], IMAGE_SIZE[0]), resample=Image.NEAREST)
    return pil_label_to_long_tensor(pil)

# =========================
# Build val split (same indices rule)
# =========================
split_base = NYUv2(
    root=DATASET_ROOT,
    split="train",
    include_depth=True,
    image_transform=eval_image_transform,
    depth_transform=depth_transform,
    target_transform=target_transform,
)

n_total = len(split_base)
n_train = int(n_total * TRAIN_VAL_SPLIT)
n_val = n_total - n_train
g = torch.Generator().manual_seed(SEED)
_, val_subset = random_split(split_base, [n_train, n_val], generator=g)
val_ds = Subset(split_base, val_subset.indices)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"Val samples: {len(val_ds)}")

# =========================
# Load checkpoint + auto-detect SE
# =========================
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
if isinstance(ckpt, dict) and "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
    sd = ckpt["state_dict"]
else:
    sd = ckpt

has_se = any(k.startswith("depth_stem.") or k.startswith("se_dec.") for k in sd.keys())
print("Detected SE in checkpoint:", has_se)

model = ResNet50UNet(
    num_classes=13,
    coarse_classes=5,
    in_channels=4,
    pretrained=False,
    use_se=has_se,
    se_z_ch=64,
    se_reduction=16,
    se_alpha=0.05,
    se_dec_stages=("up1",),
).to(DEVICE)

model.load_state_dict(sd, strict=True)
model.eval()

# =========================
# Generate panels with FIXED palette
# =========================
os.makedirs(OUTDIR, exist_ok=True)

saved = 0
with torch.no_grad():
    for i, (rgb, depth, label) in enumerate(tqdm(val_loader, desc="Saving fixed-color panels")):
        rgb = rgb.to(DEVICE, non_blocking=True)
        depth = depth.to(DEVICE, non_blocking=True)
        label = label.to(DEVICE, non_blocking=True)

        x = torch.cat([rgb.float(), depth.float()], dim=1)
        fine_logits, _, _ = model(x, depth=depth.float() if has_se else None)
        pred = fine_logits.argmax(dim=1)

        rgb_pil = denorm_rgb(rgb[0])

        gt_hw = label[0].cpu().numpy().astype(np.int32)
        pred_hw = pred[0].cpu().numpy().astype(np.int32)

        gt_rgb = id_to_rgb(gt_hw, ignore_index=IGNORE_INDEX, ignore_rgb=(0,0,0))
        pr_rgb = id_to_rgb(pred_hw, ignore_index=IGNORE_INDEX, ignore_rgb=(0,0,0))

        gt_pil = Image.fromarray(gt_rgb)  # uint8 [H,W,3] is inferred as RGB
        pr_pil = Image.fromarray(pr_rgb)
        err_pil = overlay_error(rgb_pil, gt_hw, pred_hw, ignore_index=IGNORE_INDEX)

        panel = stack_h([rgb_pil, gt_pil, pr_pil, err_pil])
        panel.save(os.path.join(OUTDIR, f"{i:06d}.png"), optimize=True)

        saved += 1
        if MAX_SAVE is not None and saved >= MAX_SAVE:
            break

print(f"Saved: {saved} images -> {OUTDIR}")

# =========================
# Zip
# =========================
with zipfile.ZipFile(ZIP_PATH, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zf:
    for fn in sorted(os.listdir(OUTDIR)):
        if fn.lower().endswith(".png"):
            zf.write(os.path.join(OUTDIR, fn), arcname=fn)

print(f"ZIP created: {ZIP_PATH}")


DEVICE: cuda
MODEL_PATH: /content/checkpoints/model_20260103062841.pt
Val samples: 80
Detected SE in checkpoint: False


  return Image.fromarray(x, mode="RGB")
  gt_pil = Image.fromarray(gt_rgb, mode="RGB")
  pr_pil = Image.fromarray(pr_rgb, mode="RGB")
  return Image.fromarray(overlay, mode="RGB")
Saving fixed-color panels:  61%|██████▏   | 49/80 [00:27<00:17,  1.78it/s]


Saved: 50 images -> val_viz_fixed
ZIP created: val_viz_fixed.zip
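
Because the panels rely on a fixed ID-to-RGB table, it can be worth checking that no two IDs map to the same color. The sketch below copies the `PALETTE_13` values from the cell above into a plain list and reports any collisions; the helper name `find_color_collisions` is hypothetical. With the default `ignore_rgb=(0,0,0)`, the ignore color coincides with class 0, so ignored pixels and class-0 pixels look identical in the GT panel.

```python
# Same colors as PALETTE_13 in the visualization cell, as plain tuples
PALETTE_13 = [
    (0, 0, 0), (128, 64, 32), (0, 128, 255), (255, 0, 128), (0, 255, 0),
    (255, 128, 0), (255, 255, 0), (128, 128, 128), (0, 255, 255),
    (0, 64, 128), (255, 0, 0), (128, 0, 255), (64, 255, 128),
]


def find_color_collisions(palette, ignore_rgb=(0, 0, 0), ignore_index=255):
    """Return groups of IDs (including the ignore index) that share one RGB color."""
    by_color = {}
    for class_id, rgb in enumerate(palette):
        by_color.setdefault(tuple(rgb), []).append(class_id)
    by_color.setdefault(tuple(ignore_rgb), []).append(ignore_index)
    return {rgb: ids for rgb, ids in by_color.items() if len(ids) > 1}


print(find_color_collisions(PALETTE_13))
# A distinct ignore color (a hypothetical choice) removes the collision
print(find_color_collisions(PALETTE_13, ignore_rgb=(255, 255, 255)))
```

Passing a different `ignore_rgb` to `id_to_rgb` would be one way to tell the two apart without touching the class colors.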


In [33]:
%ls -l checkpoints

total 4093968
-rw-r--r-- 1 root root 233561111 Jan  3 01:19 model_20260103011918.pt
-rw-r--r-- 1 root root 233561111 Jan  3 01:32 model_20260103013257.pt
-rw-r--r-- 1 root root 232859155 Jan  3 01:53 model_20260103015327.pt
-rw-r--r-- 1 root root 232859155 Jan  3 02:13 model_20260103021307.pt
-rw-r--r-- 1 root root 232859155 Jan  3 02:40 model_20260103024015.pt
-rw-r--r-- 1 root root 232800147 Jan  3 03:13 model_20260103031301.pt
-rw-r--r-- 1 root root 232800147 Jan  3 03:31 model_20260103033149.pt
-rw-r--r-- 1 root root 232800147 Jan  3 03:48 model_20260103034848.pt
-rw-r--r-- 1 root root 232800147 Jan  3 04:14 model_20260103041457.pt
-rw-r--r-- 1 root root 232814185 Jan  3 04:38 model_20260103043835.pt
-rw-r--r-- 1 root root 232814185 Jan  3 04:51 model_20260103045108.pt
-rw-r--r-- 1 root root 232882241 Jan  3 05:16 model_20260103051603.pt
-rw-r--r-- 1 root root 232883265 Jan  3 05:41 model_20260103054113.pt
-rw-r--r-- 1 root root 232688655 Jan  3 06:28 model_20260103062841.pt
-rw-r-

In [17]:
# ------------------
#    Evaluation (DeepLabV3-ResNet50 / ResNet backbone)
# ------------------

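# NOTE: `final_path` must be assigned by an earlier cell; this run raised a NameError because it was not.
ckpt = torch.load(final_path, map_location=device)

# Also handle checkpoints saved during training as a dict with a "model_state" entry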
if isinstance(ckpt, dict) and "model_state" in ckpt:
    model.load_state_dict(ckpt["model_state"])
else:
    model.load_state_dict(ckpt)

model.eval()

predictions = []

with torch.no_grad():
    print("Generating predictions...")
    for image, depth in tqdm(test_data):
        image = image.to(device, non_blocking=True)
        depth = depth.to(device, non_blocking=True)

        # Just in case: if depth was accidentally loaded with 3 channels, keep only the first
        if depth.dim() == 4 and depth.size(1) == 3:
            depth = depth[:, :1, :, :]

        x = torch.cat((image, depth), dim=1)   # [B,4,H,W]

        # DeepLabV3 returns a dict, so use its "out" entry
        out = model(x)["out"]                  # [B,num_classes,H,W]
        pred = out.argmax(dim=1)               # [B,H,W]

        predictions.append(pred.cpu())

predictions = torch.cat(predictions, dim=0).numpy()
np.save("submission.npy", predictions)
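print("Predictions saved to submission.npy")


"""## How to submit

Zip the following three items and submit them on Omnicampus under 「最終課題 (セグメンテーション)」 (Final Assignment: Segmentation).

- `submission.npy`
- The weights used for testing, such as `model.pt` or `model_best.pt` (the extension must be `.pt`)
- This Colab notebook
"""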

from zipfile import ZipFile, ZIP_DEFLATED

notebook_path = "/content/drive/MyDrive/code.ipynb"

with ZipFile("submission.zip",
             mode="w",
             compression=ZIP_DEFLATED,
             compresslevel=9) as zf:
    zf.write("submission.npy")
    zf.write(model_path)
#    zf.write(notebook_path,
#             arcname="code.ipynb")


NameError: name 'final_path' is not defined
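
The `NameError` above means this cell ran before `final_path` was assigned. One way to make the cell more robust is to fall back to the newest checkpoint on disk. The sketch below assumes checkpoints are saved as `checkpoints/model_<timestamp>.pt`, as in the directory listing earlier; the helper name `latest_checkpoint` is hypothetical and not part of the original notebook.

```python
import glob
import os


def latest_checkpoint(ckpt_dir="checkpoints", pattern="model_*.pt"):
    """Return the lexicographically last checkpoint path, or None if there is none.

    Timestamps in the form YYYYMMDDHHMMSS sort chronologically as strings.
    """
    paths = sorted(glob.glob(os.path.join(ckpt_dir, pattern)))
    return paths[-1] if paths else None


final_path = latest_checkpoint()
if final_path is None:
    print("No checkpoint found; run the training cell first.")
else:
    print("Using checkpoint:", final_path)
```

Note that the newest file is not necessarily the best-scoring one, so an explicitly chosen path is still preferable when the best checkpoint is known.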