
In [8]:
# -*- coding: utf-8 -*-
"""DL_Basic_2025_Competition_NYUv2_baseline.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/17t7uAU0aST5aUt6sIJCqFneGlQrxFrJY

# Deep Learning Basics Course Final Assignment: NYUv2 Semantic Segmentation

## Overview
A semantic segmentation task: predict, for each pixel of an RGB image (and its depth map), which class the pixel belongs to.

### Dataset
- Dataset: NYUv2 dataset
- Training data: 795 images
- Test data: 654 images
- Input: RGB image + depth map (original image size varies)
- Output: 13-class segmentation map
- Evaluation metric: Mean IoU (Intersection over Union)

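### Dataset details ([NYU Depth Dataset V2](https://cs.nyu.edu/~fergus/datasets/nyu_depth_v2.html))
- The images show indoor scenes containing objects such as furniture, walls, and floors.
- Each image comes with a 13-class segmentation label.
- The data is provided in the following directory structure: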
```
data/NYUv2/
├─train/
│  ├─image/      # RGB images
│  │    000000.png
│  │    ...
│  │
│  ├─depth/      # depth maps
│  │    000000.png
│  │    ...
│  │
│  └─label/      # 13-class segmentation (ground-truth labels)
│       000000.png
│       ...
└─test/
   ├─image/      # RGB images
   │    000000.png
   │    ...
   │
   └─depth/      # depth maps
        000000.png
        ...
```
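As a rough illustration of the layout above, the sketch below pairs the image, depth, and (train-only) label files that share a file name. `list_samples` is a hypothetical helper, not part of the provided code; the notebook's own `NYUv2Dataset` does the equivalent by sorting the file names in `image/` and reusing each name for `depth/` and `label/`.

```python
import os

def list_samples(root, split="train"):
    # Hypothetical helper: pair files that share a name across image/, depth/, label/.
    # label/ exists only for the train split, so it is included only when present.
    img_dir = os.path.join(root, split, "image")
    depth_dir = os.path.join(root, split, "depth")
    label_dir = os.path.join(root, split, "label")
    has_label = os.path.isdir(label_dir)

    samples = []
    for name in sorted(os.listdir(img_dir)):
        if not name.endswith(".png"):
            continue
        item = {
            "image": os.path.join(img_dir, name),
            "depth": os.path.join(depth_dir, name),
        }
        if has_label:
            item["label"] = os.path.join(label_dir, name)
        samples.append(item)
    return samples
```

For example, `list_samples("/content/data", "train")`, assuming the archive was extracted so that `train/` and `test/` sit directly under `/content/data`, as in the cells below.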

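### Task details
- The task is to predict, from the input RGB image and depth map, which of the 13 classes each pixel belongs to.
- Evaluation uses Mean IoU.
  - IoU is computed for each class, and the per-class scores are averaged.
  - IoU is computed as:
  $$IoU = \frac{TP}{TP + FP + FN}$$
    - TP: True Positive (pixels of the class that were predicted as that class)
    - FP: False Positive (pixels predicted as the class that actually belong to another class)
    - FN: False Negative (pixels of the class that were predicted as another class, i.e., missed)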
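A minimal NumPy sketch of the metric above, assuming predictions and labels are integer arrays of the same shape. It builds a confusion matrix (rows = ground truth, columns = prediction), so TP is the diagonal, FP is the column sum minus TP, and FN is the row sum minus TP. `mean_iou` is a hypothetical helper; how the official scorer treats a class that appears in neither the prediction nor the ground truth is not stated here, so this sketch simply skips such classes (NaN).

```python
import numpy as np

def mean_iou(pred, target, num_classes=13, ignore_index=255):
    # Drop ignored pixels, then build the confusion matrix with one bincount.
    valid = target != ignore_index
    idx = target[valid].astype(np.int64) * num_classes + pred[valid].astype(np.int64)
    cm = np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp   # predicted as class c, but GT is another class
    fn = cm.sum(axis=1) - tp   # GT is class c, but predicted as another class
    denom = tp + fp + fn

    # Classes absent from both prediction and GT get NaN and are left out of the mean.
    iou = np.where(denom > 0, tp / np.maximum(denom, 1), np.nan)
    return float(np.nanmean(iou)), iou
```

By contrast, the notebook's `compute_metrics_from_cm` adds `1e-8` to the denominator, so such a class contributes an IoU of 0 to its mean.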

### Preprocessing
- Input images are resized to 512×512.
- Pixel values are normalized to the range 0-1.
- Segmentation labels are integers 0-12 (13 classes).
  - 255 is the ignore index (excluded from evaluation).
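The sketch below illustrates why labels need nearest-neighbour resizing: bilinear interpolation would blend neighbouring IDs (for example 12 and 255) into values that are not valid classes. It is a NumPy-only sketch under simplifying assumptions: it uses nearest-neighbour for the RGB image as well, whereas the notebook itself uses albumentations/OpenCV with linear interpolation for images (masks are resized with nearest-neighbour by default there). `resize_nearest` and `preprocess` are hypothetical names.

```python
import numpy as np

def resize_nearest(arr, out_h, out_w):
    # Nearest-neighbour resize by index lookup; works for [H, W] and [H, W, C].
    # Every output value is copied from the input, so label IDs (0-12, 255) stay intact.
    h, w = arr.shape[:2]
    rows = (np.arange(out_h) * h) // out_h
    cols = (np.arange(out_w) * w) // out_w
    return arr[rows[:, None], cols[None, :]]

def preprocess(rgb_uint8, label, size=512):
    # Resize to size x size, then scale RGB values from 0-255 to 0-1.
    rgb = resize_nearest(rgb_uint8, size, size).astype(np.float32) / 255.0
    lab = resize_nearest(label, size, size)  # no interpolation between class IDs
    return rgb, lab
```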

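### Submission format
- Save the predicted class (0-12) for every pixel of the test images (RGB + depth) as a NumPy array.
- File name: `submission.npy`
- Array shape: [number of test images, height, width]
- Value of each pixel: an integer 0-12 (the predicted class)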
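A minimal sketch of writing the file in the format above. `save_submission` is hypothetical; it assumes the predictions are already at the target resolution and listed in the same order as the sorted test file names (the notebook's test loader uses sorted names with `shuffle=False`). It checks the value range before casting to `uint8`, because casting first would silently wrap out-of-range values.

```python
import numpy as np

def save_submission(pred_maps, path="submission.npy", num_classes=13):
    # pred_maps: sequence of [H, W] integer class maps, one per test image, in file order.
    sub = np.stack([np.asarray(p) for p in pred_maps])
    if sub.ndim != 3:
        raise ValueError(f"expected shape [N, H, W], got {sub.shape}")
    # Validate before casting: astype(np.uint8) would silently turn e.g. 256 into 0.
    if sub.min() < 0 or sub.max() >= num_classes:
        raise ValueError(f"class IDs must be in 0-{num_classes - 1}")
    sub = sub.astype(np.uint8)
    np.save(path, sub)
    return sub.shape
```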

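## Example improvements to consider
- Fine-tuning a pretrained model
    - Fine-tuning a model pretrained on ImageNet or similar data on this dataset is expected to improve performance.
- Redesigning the loss function
    - Designing the loss to reweight each class according to how often it appears makes training more robust to class imbalance.
- Image preprocessing
    - Adding data augmentation such as RandomResizedCrop / Flip / ColorJitter is expected to improve generalization.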
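As one possible form of the frequency-based reweighting mentioned above, the sketch below derives per-class weights from pixel counts using inverse square-root frequency, a common damped variant of inverse frequency. It is an illustration under assumptions, not the baseline's method: the baseline cell later in this notebook uses hand-set `ce_weights`. `inverse_frequency_weights` is a hypothetical name.

```python
import numpy as np

def inverse_frequency_weights(label_maps, num_classes=13, ignore_index=255):
    # Count pixels per class over all label maps, skipping the ignore index.
    counts = np.zeros(num_classes, dtype=np.int64)
    for lab in label_maps:
        v = lab[lab != ignore_index].astype(np.int64)
        counts += np.bincount(v, minlength=num_classes)[:num_classes]

    freq = counts / max(counts.sum(), 1)
    # Square root damps the weights so very rare classes are not over-weighted;
    # the 1e-6 floor avoids division by zero (an absent class still gets a large weight).
    w = 1.0 / np.sqrt(freq + 1e-6)
    return w / w.mean()  # normalize so the average weight is 1
```

The result could then be passed as `torch.nn.CrossEntropyLoss(weight=torch.tensor(w, dtype=torch.float32), ignore_index=255)`.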

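## Conditions for meeting the completion requirement
- The baseline scores 38.2% in the performance evaluation on omnicampus. Therefore, only submissions that exceed the baseline of 38.2% are accepted toward the completion requirement.
- The course staff have confirmed that improving on the baseline can raise performance to 50% or higher. Use this as one target to aim for.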

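## Notes
- The final prediction model must be **trained on the distributed training data** (fine-tuning included).
- **Inference that relies solely on the knowledge of a pretrained model, without any training, is prohibited.**
(Example: only feeding the input to an LLM such as ChatGPT and taking its output as the prediction.)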

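### Use of pretrained models
Permitted
- **Using a pretrained model as a component**: You may use a pretrained model (BERT, ViT, etc.) as part of an architecture you implement yourself (e.g., for feature extraction or embeddings).
- **Fine-tuning**: You may fine-tune pretrained models used in the way described above.

Prohibited
- **Using a pretrained model built to solve the task**: Using a pretrained model provided by transformers or similar libraries that directly solves the target task, whether for inference only or with fine-tuning only, is prohibited.
  - Example of a prohibited use: using a pretrained model built to directly solve VQA for a VQA task.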

### Preparing the data
The downloaded data was saved to Google Drive, so Google Drive must be mounted to use it. Also, because the archive cannot be extracted on Drive, copy "data.zip" under the /content directory and extract it there.
This cannot be run unless "data.zip" is on Google Drive. If you can place "data.zip" (**831MB**) on Google Drive, run "data_download.ipynb" first. If that is difficult, use the omnicampus exercise environment.
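The copy-and-extract step described above can also be done from Python instead of the `!cp` / `!unzip` shell commands. This is a sketch under assumptions: `stage_dataset` is a hypothetical helper, and the source path `/content/drive/MyDrive/data.zip` is the one used by the cell below.

```python
import os
import shutil
import zipfile

def stage_dataset(zip_src, work_dir="/content"):
    # Fail early if the archive has not appeared yet (Drive sync can lag).
    if not os.path.exists(zip_src):
        raise FileNotFoundError(f"{zip_src} not found; it may still be syncing to Drive")
    # Copy the archive off Drive, since it cannot be extracted in place there.
    local_zip = shutil.copy(zip_src, work_dir)
    with zipfile.ZipFile(local_zip) as zf:
        zf.extractall(work_dir)
    return sorted(os.listdir(work_dir))
```

For example, call `stage_dataset("/content/drive/MyDrive/data.zip")` after `drive.mount('/content/drive')`.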
"""

# On omnicampus, there is no need to run cells up to the 4th one
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# After saving to Google Drive from the data-download notebook,
# the file may take a while to appear, so after mounting Google Drive,
# confirm that data.zip is in the directory before running this.
# Copy data.zip to /content
!cp "/content/drive/MyDrive/data.zip" "/content"

# Commented out IPython magic to ensure Python compatibility.
# List the files in the current directory
# If data.zip is listed, everything is fine
# %ls

# Extract the data
!unzip data.zip
!mkdir data
!mv train test data/

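"""In the omnicampus exercise environment, data_download.ipynb is run without the Drive mount, zipping, and copy-to-Drive steps, so the data is placed already extracted (as if "data.zip" had been unzipped). Therefore, you only need to make the directory that contains the data directory the current directory.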


"""

# Commented out IPython magic to ensure Python compatibility.
# For omnicampus
# The example below assumes the data directory is in /workspace/Segmentation/split_data_scripts_omnicampus
# %cd /workspace/Segmentation/split_data_scripts_omnicampus

# For omnicampus
!pip install h5py scikit-image

"""# import library"""

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Archive:  data.zip
replace data/train/image/000600.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: data/train/image/000600.png  
  inflating: data/train/image/000320.png  
  inflating: data/train/image/000491.png  
  inflating: data/train/image/000502.png  
  inflating: data/train/image/000129.png  
  inflating: data/train/image/000044.png  
  inflating: data/train/image/000652.png  
  inflating: data/train/image/000919.png  
  inflating: data/train/image/000528.png  
  inflating: data/train/image/000853.png  
  inflating: data/train/image/000177.png  
  inflating: data/train/image/000584.png  
  inflating: data/train/image/001319.png  
  inflating: data/train/image/000597.png  
  inflating: data/train/image/000223.png  
  inflating: data/train/image/001350.png  
  inflating: data/train/image/000404.png  
  inflating: data/train/image/000488.png  
 

'# import library'

In [9]:
# =========================
# Cell 1: common (imports / utils / dataset / metrics / drive / zip-submission only)
# =========================
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import time
import json
import random
import shutil
import numpy as np
from zipfile import ZipFile, ZIP_DEFLATED
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, Dataset, Subset

import albumentations as A
import cv2
from torchvision import models
from torch.amp import autocast, GradScaler


# =========================
# Constants
# =========================
CLASS_NAMES = ["Bed", "Book", "Ceiling", "Chair", "Floor", "Cabinet", "Object", "Picture",
               "Sofa", "Desk", "TV", "Wall", "Window"]
NUM_CLASSES = 13
IGNORE_INDEX = 255

BOOK_ID = 1
CABINET_ID = 5
OBJECT_ID = 6

TRI_NAMES = ["Book", "Cabinet", "Object"]
TRI_NUM_CLASSES = 3
TRI_IGNORE = 255

BASE_TRI_IDS = [BOOK_ID, CABINET_ID, OBJECT_ID]


# =========================
# Utils
# =========================
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Favors reproducibility, but benchmark=True can break strict reproducibility
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

def now_ts():
    return time.strftime("%Y%m%d_%H%M%S")

def make_run_id(prefix: str):
    return f"{prefix}_{now_ts()}"

def ensure_dir(path: str):
    os.makedirs(path, exist_ok=True)
    return path

def must_exist(path: str, what: str):
    if (not path) or (not os.path.exists(path)):
        raise FileNotFoundError(f"[ERROR] {what} not found: {path}")
    return path

def write_json(path: str, obj: dict):
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)

def append_jsonl(path: str, obj: dict):
    with open(path, "a") as f:
        f.write(json.dumps(obj) + "\n")

def fmt_metrics_console(tag: str, epoch: int, miou: float, class_iou, class_precision, class_recall, names):
    print(f"[{tag}] Epoch {epoch:02d} | mIoU={miou:.5f}")
    header = f"{'Class':<10} {'IoU':>8} {'Prec':>8} {'Rec':>8}"
    print(header)
    print("-" * len(header))
    for i, name in enumerate(names):
        iou = class_iou[i]
        pr  = class_precision[i]
        rc  = class_recall[i]
        print(f"{name:<10} {iou:>8.3f} {pr:>8.3f} {rc:>8.3f}")

def compute_metrics_from_cm(cm: torch.Tensor):
    cm = cm.float()
    tp = torch.diag(cm)
    fp = cm.sum(dim=0) - tp
    fn = cm.sum(dim=1) - tp
    iou = tp / (tp + fp + fn + 1e-8)
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    return {
        "miou": iou.mean().item(),
        "class_iou": iou.cpu().tolist(),
        "class_precision": precision.cpu().tolist(),
        "class_recall": recall.cpu().tolist(),
    }

def update_cm(cm: torch.Tensor, pred: torch.Tensor, target: torch.Tensor, n_classes: int, ignore_index: int):
    pred = pred.detach().view(-1).to("cpu")
    target = target.detach().view(-1).to("cpu")
    mask = (target != ignore_index)
    cm += torch.bincount(
        target[mask] * n_classes + pred[mask],
        minlength=n_classes**2
    ).view(n_classes, n_classes)
    return cm

def estimate_height_from_depth(depth_np: np.ndarray) -> np.ndarray:
    """
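    Pseudo height-like feature from depth: depth scaled by the normalized row index (0 at top, 1 at bottom), then divided by its max (kept from the original code)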
    """
    H, W = depth_np.shape
    y_grid = np.linspace(0, 1, H).reshape(H, 1).repeat(W, axis=1).astype(np.float32)
    height_map = y_grid * depth_np
    max_val = float(height_map.max())
    if max_val > 0:
        height_map /= max_val
    return height_map.astype(np.float32)


# =========================
# Drive
# =========================
def mount_drive():
    try:
        from google.colab import drive
        drive.mount("/content/drive")
    except Exception as e:
        print(f"[WARN] Drive mount skipped or failed: {e}")

def copy_to_drive_if_needed(src_path: str, dst_dir: str):
    """
    No compression; copies the required file individually.
    """
    if dst_dir is None:
        return
    ensure_dir(dst_dir)
    if os.path.isdir(src_path):
        # Directory copy would go here if needed; it is not needed under the current requirements.
        raise RuntimeError("copy_to_drive_if_needed: directory copy is disabled by design.")
    shutil.copy2(src_path, os.path.join(dst_dir, os.path.basename(src_path)))

def zip_submission_only(npy_path: str, zip_path: str):
    """
    Zip only submission.npy (nothing else is compressed).
    """
    print(f"[LONG] Start zipping submission only: {zip_path}")
    with ZipFile(zip_path, "w", ZIP_DEFLATED) as zf:
        zf.write(npy_path, arcname=os.path.basename(npy_path))
    print(f"[DONE] submission zip: {zip_path}")
    return zip_path


# =========================
# Dataset
# =========================
class NYUv2Dataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None, return_label=True):
        self.split = split
        self.transform = transform
        self.return_label = return_label

        src_split = 'train' if split in ['train', 'val'] else 'test'
        self.images_dir = os.path.join(root_dir, src_split, 'image')
        self.depths_dir = os.path.join(root_dir, src_split, 'depth')
        self.labels_dir = os.path.join(root_dir, src_split, 'label') if src_split == 'train' else None

        self.filenames = sorted([f for f in os.listdir(self.images_dir) if f.endswith('.png')])

        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std  = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        self.d_mean, self.d_std = 0.5, 0.25

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        fname = self.filenames[idx]
        rgb = np.array(Image.open(os.path.join(self.images_dir, fname)).convert('RGB'))
        depth = np.array(Image.open(os.path.join(self.depths_dir, fname)))
        if depth.ndim == 3:
            depth = depth[:, :, 0]

        depth = depth.astype(np.float32) / (65535.0 if depth.max() > 255 else 255.0)
        h_map = estimate_height_from_depth(depth)

        if self.labels_dir and self.return_label:
            label = np.array(Image.open(os.path.join(self.labels_dir, fname)))
        else:
            label = np.zeros(depth.shape, dtype=np.int32)

        if self.transform:
            augmented = self.transform(image=rgb, depth=depth, height=h_map, mask=label)
            rgb, depth, h_map, label = augmented['image'], augmented['depth'], augmented['height'], augmented['mask']

        rgb = (rgb.astype(np.float32) / 255.0 - self.mean) / self.std
        depth = (depth - self.d_mean) / self.d_std
        h_map = (h_map - self.d_mean) / self.d_std

        rgb_t = torch.from_numpy(rgb.transpose(2, 0, 1)).float()
        depth_t = torch.from_numpy(depth).unsqueeze(0).float()
        h_t = torch.from_numpy(h_map).unsqueeze(0).float()
        x = torch.cat([rgb_t, depth_t, h_t], dim=0)  # 5ch

        if self.split == 'test':
            return x, fname
        return x, torch.from_numpy(label).long()

class NYUv2TriDataset(Dataset):
    """
    For TRI: project the GT onto the 3 classes {Book, Cabinet, Object}.
    All other pixels are set to ignore.
    """
    def __init__(self, base_dataset: NYUv2Dataset):
        self.base = base_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x, y = self.base[idx]
        tri = torch.full_like(y, TRI_IGNORE)
        tri[y == BOOK_ID] = 0
        tri[y == CABINET_ID] = 1
        tri[y == OBJECT_ID] = 2
        return x, tri


# =========================
# Logger
# =========================
class RunLogger:
    def __init__(self, run_dir: str, run_id: str):
        self.run_dir = run_dir
        self.run_id = run_id
        self.logs_dir = ensure_dir(os.path.join(run_dir, "logs"))
        self.cfg_path = os.path.join(self.logs_dir, "config.json")
        self.train_log = os.path.join(self.logs_dir, "train_log.jsonl")
        self.val_log = os.path.join(self.logs_dir, "val_log.jsonl")
        self.batch_log = os.path.join(self.logs_dir, "batch_log.jsonl")
        self.batch_buf = []

    def save_config(self, cfg: dict):
        write_json(self.cfg_path, cfg)

    def log_batch(self, tag: str, epoch: int, batch: int, losses: dict):
        self.batch_buf.append({
            "run_id": self.run_id,
            "tag": tag,
            "epoch": epoch,
            "batch": batch,
            "losses": {k: float(v) for k, v in losses.items()}
        })

    def flush_batch(self):
        if not self.batch_buf:
            return
        with open(self.batch_log, "a") as f:
            for e in self.batch_buf:
                f.write(json.dumps(e) + "\n")
        self.batch_buf = []

    def log_epoch_train(self, tag: str, epoch: int, lr: float, avg_losses: dict):
        append_jsonl(self.train_log, {
            "run_id": self.run_id,
            "tag": tag,
            "epoch": epoch,
            "lr": float(lr),
            "losses": {k: float(v) for k, v in avg_losses.items()}
        })

    def log_epoch_val(self, tag: str, epoch: int, metrics: dict, names: list, cm_np=None):
        append_jsonl(self.val_log, {
            "run_id": self.run_id,
            "tag": tag,
            "epoch": epoch,
            "miou": float(metrics["miou"]),
            "class_iou": {names[i]: float(metrics["class_iou"][i]) for i in range(len(names))},
            "class_precision": {names[i]: float(metrics["class_precision"][i]) for i in range(len(names))},
            "class_recall": {names[i]: float(metrics["class_recall"][i]) for i in range(len(names))}
        })
        if cm_np is not None:
            np.save(os.path.join(self.logs_dir, f"confusion_matrix_{tag}_epoch_{epoch}.npy"), cm_np)


In [10]:
# =========================
# Cell 2: BASE (model + losses + train/eval + inference submission)
# =========================

# ---- Losses (for BASE; kept from the original code)
class FocalLoss(nn.Module):
    def __init__(self, weight=None, gamma=2.0, ignore_index=IGNORE_INDEX):
        super().__init__()
        self.weight = weight
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, logits, target):
        ce = F.cross_entropy(logits, target, weight=self.weight,
                             ignore_index=self.ignore_index, reduction='none')
        pt = torch.exp(-ce)
        loss = ((1 - pt) ** self.gamma) * ce
        return loss.mean()

class ClassBalancedDiceLoss(nn.Module):
    def __init__(self, n_classes=NUM_CLASSES, smooth=1e-5, ignore_index=IGNORE_INDEX, class_weights=None):
        super().__init__()
        self.n_classes = n_classes
        self.smooth = smooth
        self.ignore_index = ignore_index
        self.class_weights = class_weights

    def forward(self, logits, target):
        prob = F.softmax(logits, dim=1)
        mask = (target != self.ignore_index)
        t = target.clone()
        t[~mask] = 0
        onehot = F.one_hot(t, self.n_classes).permute(0, 3, 1, 2).float()

        m = mask.unsqueeze(1).expand_as(prob)
        prob = prob * m
        onehot = onehot * m

        inter = (prob * onehot).sum(dim=(2, 3))
        union = prob.sum(dim=(2, 3)) + onehot.sum(dim=(2, 3))
        dice = (2 * inter + self.smooth) / (union + self.smooth)

        if self.class_weights is not None:
            dice_c = dice.mean(dim=0)
            w = self.class_weights.to(dice.device)
            weighted = (dice_c * w).sum() / (w.sum() + 1e-12)
            return 1 - weighted
        return 1 - dice.mean()

def make_boundary_mask(target: torch.Tensor, ignore_index=IGNORE_INDEX, dilate=2) -> torch.Tensor:
    valid = (target != ignore_index)
    t = target.clone()
    t[~valid] = -1

    up    = torch.zeros_like(t); up[:, 1:]  = t[:, :-1]
    down  = torch.zeros_like(t); down[:, :-1] = t[:, 1:]
    left  = torch.zeros_like(t); left[:, :, 1:] = t[:, :, :-1]
    right = torch.zeros_like(t); right[:, :, :-1] = t[:, :, 1:]

    edge = ((t != up) | (t != down) | (t != left) | (t != right)) & valid
    edge = edge.float().unsqueeze(1)
    for _ in range(max(0, dilate)):
        edge = F.max_pool2d(edge, kernel_size=3, stride=1, padding=1)
    return (edge > 0).float()

def dice_loss_binary(logits: torch.Tensor, target: torch.Tensor, eps=1e-6) -> torch.Tensor:
    prob = torch.sigmoid(logits)
    inter = (prob * target).sum(dim=(2,3))
    union = prob.sum(dim=(2,3)) + target.sum(dim=(2,3))
    dice = (2*inter + eps) / (union + eps)
    return 1 - dice.mean()

def ohem_cross_entropy(
    logits: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor,
    ignore_index=IGNORE_INDEX,
    min_kept=131072,
    thresh=0.7
) -> torch.Tensor:
    with torch.no_grad():
        prob = F.softmax(logits, dim=1)
        valid = (target != ignore_index)
        t_safe = target.clone()
        t_safe[~valid] = 0
        p_gt = prob.gather(1, t_safe.unsqueeze(1)).squeeze(1)
        hard = (p_gt < thresh) & valid

    loss = F.cross_entropy(logits, target, weight=weight, ignore_index=ignore_index, reduction='none')
    loss_valid = loss[valid]
    if loss_valid.numel() == 0:
        return loss.mean()

    loss_hard = loss[hard]
    if loss_hard.numel() >= min_kept:
        return torch.topk(loss_hard, k=min_kept, largest=True).values.mean()

    k = min(min_kept, loss_valid.numel())
    return torch.topk(loss_valid, k=k, largest=True).values.mean()


# ---- Model (BASE: ResNeXt101 + DeepLabV3+ OS=8) Note: keep state_dict compatibility
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1, d=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, k, padding=p, dilation=d, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class ASPP(nn.Module):
    def __init__(self, in_ch, out_ch=256, rates=(12, 24, 36)):
        super().__init__()
        r1, r2, r3 = rates
        self.b1 = ConvBNReLU(in_ch, out_ch, 1, 0, 1)
        self.b2 = ConvBNReLU(in_ch, out_ch, 3, r1, r1)
        self.b3 = ConvBNReLU(in_ch, out_ch, 3, r2, r2)
        self.b4 = ConvBNReLU(in_ch, out_ch, 3, r3, r3)
        self.b5 = nn.Sequential(nn.AdaptiveAvgPool2d(1), ConvBNReLU(in_ch, out_ch, 1, 0, 1))
        self.proj = nn.Sequential(ConvBNReLU(out_ch * 5, out_ch, 1, 0, 1), nn.Dropout(0.1))

    def forward(self, x):
        h, w = x.shape[2:]
        f1 = self.b1(x)
        f2 = self.b2(x)
        f3 = self.b3(x)
        f4 = self.b4(x)
        f5 = F.interpolate(self.b5(x), size=(h, w), mode='bilinear', align_corners=False)
        return self.proj(torch.cat([f1, f2, f3, f4, f5], dim=1))

class ResNeXtDeepLabV3Plus_OS8(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES, in_channels=5, aspp_rates=(12,24,36)):
        super().__init__()
        backbone = models.resnext101_32x8d(
            weights=models.ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
            replace_stride_with_dilation=[False, True, True],
        )

        old_conv = backbone.conv1
        new_conv = nn.Conv2d(in_channels, old_conv.out_channels, 7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            new_conv.weight[:, :3] = old_conv.weight
            mean_w = old_conv.weight.mean(dim=1, keepdim=True)
            new_conv.weight[:, 3:] = mean_w.repeat(1, in_channels-3, 1, 1)

        self.enc0 = nn.Sequential(new_conv, backbone.bn1, backbone.relu)
        self.pool = backbone.maxpool
        self.enc1 = backbone.layer1
        self.enc2 = backbone.layer2
        self.enc3 = backbone.layer3
        self.enc4 = backbone.layer4

        self.aspp = ASPP(2048, 256, rates=aspp_rates)
        self.low_proj = nn.Sequential(nn.Conv2d(256, 48, 1, bias=False), nn.BatchNorm2d(48), nn.ReLU(inplace=True))
        self.dec1 = ConvBNReLU(256 + 48, 256, 3, 1, 1)
        self.dec2 = ConvBNReLU(256, 256, 3, 1, 1)

        self.seg_head = nn.Conv2d(256, num_classes, 1)
        self.boundary_head = nn.Conv2d(256, 1, 1)

        self.aux_head = nn.Sequential(
            ConvBNReLU(1024, 256, 3, 1, 1),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, 1)
        )

    def forward(self, x, return_aux=False, return_boundary=False, return_feat=False):
        H, W = x.shape[2:]
        x0 = self.enc0(x)
        x1 = self.pool(x0)

        low = self.enc1(x1)
        x2 = self.enc2(low)
        mid = self.enc3(x2)
        x4 = self.enc4(mid)

        xA = self.aspp(x4)
        xA = F.interpolate(xA, size=low.shape[2:], mode='bilinear', align_corners=False)

        feat = torch.cat([xA, self.low_proj(low)], dim=1)
        feat = self.dec2(self.dec1(feat))  # decoder feat (H/4)

        seg = self.seg_head(feat)
        seg = F.interpolate(seg, size=(H, W), mode='bilinear', align_corners=False)

        outs = [seg]

        if return_boundary:
            bd = self.boundary_head(feat)
            bd = F.interpolate(bd, size=(H, W), mode='bilinear', align_corners=False)
            outs.append(bd)

        if return_aux and self.training:
            aux = self.aux_head(mid)
            aux = F.interpolate(aux, size=(H, W), mode='bilinear', align_corners=False)
            outs.append(aux)

        if return_feat:
            outs.append(feat)

        if len(outs) == 1:
            return outs[0]
        return tuple(outs)


# ---- BASE inference (TTA: heavy for the final submission only, light per epoch as needed)
def tta_predict_logits(model, x, out_hw, img_size, scales=(1.0,), do_flip=True):
    """
    Averages softmax probabilities (not raw logits) over scales and horizontal flips.
    Despite the function name, returns the averaged probabilities ([B, C, H, W]).
    """
    B = x.shape[0]
    C = NUM_CLASSES
    acc = torch.zeros((B, C, out_hw[0], out_hw[1]), device=x.device)
    n_aug = 0

    for s in scales:
        hs, ws = int(img_size * s), int(img_size * s)
        xs = F.interpolate(x, size=(hs, ws), mode='bilinear', align_corners=False)

        out = model(xs)  # [B,C,hs,ws]: the model upsamples only to its input size (hs, ws), not to out_hw
        out = F.interpolate(out, size=out_hw, mode='bilinear', align_corners=False)
        acc += F.softmax(out, dim=1); n_aug += 1

        if do_flip:
            xsf = torch.flip(xs, dims=[3])
            out_f = model(xsf)
            out_f = torch.flip(out_f, dims=[3])
            out_f = F.interpolate(out_f, size=out_hw, mode='bilinear', align_corners=False)
            acc += F.softmax(out_f, dim=1); n_aug += 1

    return acc / float(n_aug)


def run_base_train_and_submit(
    dataset_root="/content/data",
    img_size=768,
    batch_size=16,
    epochs=60,
    lr=1e-4,
    seed=42,
    phase1_epochs=40,
    save_cm_every=5,
    drive_dir=None,
    base_best_out="/content/base_best_model.pt",   # Default: the best model is also copied to this path
    run_mode_tag="base"
):
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    run_id = make_run_id("base")
    run_dir = ensure_dir(os.path.join("/content", run_id))
    logger = RunLogger(run_dir, run_id)

    # paths
    base_best = os.path.join(run_dir, "base_best_model.pt")
    base_final = os.path.join(run_dir, "base_final_model.pt")

    print("====================================================")
    print(f"[BASE] run_id: {run_id}")
    print(f"[BASE] run_dir: {run_dir}")
    print(f"[BASE] logs   : {logger.logs_dir}")
    print(f"[BASE] save best : {base_best}")
    print(f"[BASE] save final: {base_final}")
    print(f"[BASE] device: {device}")
    print("====================================================")

    # Aug
    train_aug = A.Compose([
        A.RandomResizedCrop(size=(img_size, img_size), scale=(0.70, 1.00), ratio=(0.90, 1.10),
                            interpolation=cv2.INTER_LINEAR, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, border_mode=cv2.BORDER_REFLECT_101, p=0.3),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.15, rotate_limit=0,
                           border_mode=cv2.BORDER_REFLECT_101, p=0.3),
        A.ColorJitter(p=0.5),
    ], additional_targets={'depth': 'image', 'height': 'image'})

    val_aug = A.Compose([
        A.Resize(img_size, img_size, interpolation=cv2.INTER_LINEAR)
    ], additional_targets={'depth': 'image', 'height': 'image'})

    # Data split (identical on every run as long as the seed is fixed)
    full = NYUv2Dataset(dataset_root, split="train", transform=None)
    n_total = len(full)
    n_val = int(n_total * 0.1)
    indices = list(range(n_total))
    random.shuffle(indices)
    train_idx = indices[:-n_val]
    val_idx = indices[-n_val:]

    train_ds = NYUv2Dataset(dataset_root, split="train", transform=train_aug)
    val_ds = NYUv2Dataset(dataset_root, split="train", transform=val_aug)

    train_loader = DataLoader(Subset(train_ds, train_idx), batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True, drop_last=True)
    val_loader = DataLoader(Subset(val_ds, val_idx), batch_size=8, shuffle=False,
                            num_workers=4, pin_memory=True)

    # Test
    test_ds = NYUv2Dataset(dataset_root, split="test", transform=val_aug, return_label=False)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

    print(f"[BASE] Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_ds)}")
    print("[BASE] Model: ResNeXt101_32x8d + DeepLabV3+ (OS=8) + ASPP(12,24,36)")

    # weights (kept from the original code)
    ce_weights = torch.tensor([1.0, 12.0, 0.6, 2.2,  0.6,  1.0,  2.0,  1.5,
                              1.5,  5.0,  3.0,  0.5,  1.0], device=device)
    dice_weights = torch.tensor([1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0,
                                 1.0, 2.0, 1.0, 1.0, 1.0], device=device)

    cfg = {
        "mode": "base",
        "dataset_root": dataset_root,
        "img_size": img_size,
        "batch_size": batch_size,
        "epochs": epochs,
        "lr": lr,
        "seed": seed,
        "phase1_epochs": phase1_epochs,
        "model": "ResNeXt101_32x8d + DeepLabV3+ OS=8",
        "aspp_rates": [12,24,36],
        "save_paths": {"best": base_best, "final": base_final},
    }
    logger.save_config(cfg)

    model = ResNeXtDeepLabV3Plus_OS8(num_classes=NUM_CLASSES, in_channels=5, aspp_rates=(12,24,36)).to(device)

    criterion_focal = FocalLoss(weight=ce_weights, gamma=2.0, ignore_index=IGNORE_INDEX)
    criterion_ce = nn.CrossEntropyLoss(weight=ce_weights, ignore_index=IGNORE_INDEX)
    criterion_dice = ClassBalancedDiceLoss(class_weights=dice_weights, ignore_index=IGNORE_INDEX)
    bce_boundary = nn.BCEWithLogitsLoss()

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler("cuda")

    best_miou = 0.0

    for epoch in range(1, epochs + 1):
        model.train()

        if epoch <= phase1_epochs:
            w_focal, w_ce, w_dice, w_aux, w_bd, w_ohem = 0.25, 0.20, 0.40, 0.10, 0.05, 0.00
            use_ohem = False
        else:
            w_focal, w_ce, w_dice, w_aux, w_bd, w_ohem = 0.15, 0.00, 0.35, 0.10, 0.20, 0.20
            use_ohem = True

        epoch_losses = {"focal": 0.0, "ce": 0.0, "ohem_ce": 0.0, "dice": 0.0, "aux": 0.0, "boundary": 0.0, "total": 0.0}

        pbar = tqdm(train_loader, desc=f"[BASE] Epoch {epoch}/{epochs}")
        for bi, (x, y) in enumerate(pbar):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            with autocast("cuda"):
                seg, bd, aux = model(x, return_aux=True, return_boundary=True)

                loss_focal = criterion_focal(seg, y)
                loss_dice  = criterion_dice(seg, y)

                if use_ohem:
                    loss_ohem = ohem_cross_entropy(seg, y, weight=ce_weights, ignore_index=IGNORE_INDEX,
                                                   min_kept=131072, thresh=0.7)
                    loss_ce = torch.tensor(0.0, device=device)
                else:
                    loss_ce = criterion_ce(seg, y)
                    loss_ohem = torch.tensor(0.0, device=device)

                loss_aux = criterion_ce(aux, y)

                bd_gt = make_boundary_mask(y, ignore_index=IGNORE_INDEX, dilate=2)
                loss_bd = bce_boundary(bd, bd_gt) + dice_loss_binary(bd, bd_gt)

                loss = (w_focal * loss_focal +
                        w_ce    * loss_ce +
                        w_ohem  * loss_ohem +
                        w_dice  * loss_dice +
                        w_aux   * loss_aux +
                        w_bd    * loss_bd)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            batch_losses = {
                "focal": float(loss_focal.item()),
                "ce": float(loss_ce.item()),
                "ohem_ce": float(loss_ohem.item()),
                "dice": float(loss_dice.item()),
                "aux": float(loss_aux.item()),
                "boundary": float(loss_bd.item()),
                "total": float(loss.item())
            }
            logger.log_batch(run_mode_tag, epoch, bi, batch_losses)

            for k in epoch_losses:
                epoch_losses[k] += batch_losses[k]

            pbar.set_postfix({
                "loss": f"{batch_losses['total']:.4f}",
                "dice": f"{batch_losses['dice']:.4f}",
                "bd": f"{batch_losses['boundary']:.3f}",
                "ph": "1" if epoch <= phase1_epochs else "2"
            })

        logger.flush_batch()

        n_batches = len(train_loader)
        avg_losses = {k: v / n_batches for k, v in epoch_losses.items()}
        lr_now = optimizer.param_groups[0]['lr']
        logger.log_epoch_train(run_mode_tag, epoch, lr_now, avg_losses)
        scheduler.step()

        # Val
        model.eval()
        cm = torch.zeros((NUM_CLASSES, NUM_CLASSES), dtype=torch.long)
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(device, non_blocking=True)
                seg = model(x)
                pred = seg.argmax(1).to("cpu")
                cm = update_cm(cm, pred, y.to("cpu"), NUM_CLASSES, IGNORE_INDEX)

        m = compute_metrics_from_cm(cm)
        miou = m["miou"]

        save_cm = (epoch % save_cm_every == 0) or (miou > best_miou)
        logger.log_epoch_val(run_mode_tag, epoch, m, CLASS_NAMES, cm.numpy() if save_cm else None)
        fmt_metrics_console("BASE", epoch, miou, m["class_iou"], m["class_precision"], m["class_recall"], CLASS_NAMES)

        if miou > best_miou:
            best_miou = miou
            torch.save(model.state_dict(), base_best)
            print(f"  -> New best BASE mIoU: {best_miou:.5f}")

    # Save final
    torch.save(model.state_dict(), base_final)

    # Update /content/base_best_model.pt (overwrites any existing file, by design)
    if os.path.exists(base_best):
        shutil.copy2(base_best, base_best_out)
        print(f"[BASE] updated {base_best_out}")

    # ---- Submission (heavy TTA)
    best_path = base_best if os.path.exists(base_best) else base_final
    print(f"[LONG] Start generating submission (BASE heavy TTA) using: {best_path}")
    model.load_state_dict(torch.load(best_path, map_location=device))
    model.eval()

    tta_scales = (0.75, 1.0, 1.25, 1.5)
    preds = []
    with torch.no_grad():
        for x, _ in tqdm(test_loader, desc="[BASE] test inference"):
            x = x.to(device, non_blocking=True)
            prob = tta_predict_logits(model, x, out_hw=(512,512), img_size=img_size, scales=tta_scales, do_flip=True)
            pred = prob.argmax(1).cpu().numpy().astype(np.uint8)  # [1,512,512]
            preds.append(pred)

    submission = np.concatenate(preds, axis=0)
    sub_path = os.path.join(run_dir, "submission.npy")
    np.save(sub_path, submission)
    print(f"[DONE] submission saved: {sub_path} shape={submission.shape} best_mIoU={best_miou:.5f}")

    # zip: submission.npy only
    zip_path = os.path.join(run_dir, "submission_only.zip")
    zip_submission_only(sub_path, zip_path)

    # Drive copy (optional): only submission_only.zip
    if drive_dir is not None:
        print(f"[LONG] Start copying submission_only.zip to Drive: {drive_dir}")
        copy_to_drive_if_needed(zip_path, drive_dir)
        print(f"[DONE] Drive copy: {os.path.join(drive_dir, os.path.basename(zip_path))}")

    return {
        "run_id": run_id,
        "run_dir": run_dir,
        "best_miou": best_miou,
        "base_best": base_best,
        "base_final": base_final,
        "submission_npy": sub_path,
        "submission_zip": zip_path,
        "base_best_out": base_best_out,
    }
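The validation loop above accumulates a confusion matrix with `update_cm` and turns it into per-class IoU with `compute_metrics_from_cm`. Neither helper is shown in this section. Below is a minimal NumPy sketch of the same idea, assuming rows are ground truth, columns are predictions, and pixels labelled 255 are ignored. It applies the overview's formula $IoU = TP / (TP + FP + FN)$. The names `update_cm_np` and `miou_from_cm` are hypothetical and are not the notebook's implementation.

```python
import numpy as np

def update_cm_np(cm, pred, target, num_classes, ignore_index=255):
    """Accumulate a [C, C] confusion matrix (rows = GT class, cols = predicted class)."""
    pred = np.asarray(pred).ravel()
    target = np.asarray(target).ravel()
    valid = target != ignore_index                     # drop ignore_index pixels entirely
    idx = target[valid].astype(np.int64) * num_classes + pred[valid].astype(np.int64)
    cm += np.bincount(idx, minlength=num_classes * num_classes).reshape(num_classes, num_classes)
    return cm

def miou_from_cm(cm):
    """Per-class IoU = TP / (TP + FP + FN); classes absent from both GT and pred become NaN."""
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp                           # predicted as c, GT is another class
    fn = cm.sum(axis=1) - tp                           # GT is c, predicted as another class
    denom = tp + fp + fn
    iou = np.where(denom > 0, tp / np.maximum(denom, 1), np.nan)
    return float(np.nanmean(iou)), iou

# Toy 3-class example with one ignored pixel.
cm = np.zeros((3, 3), dtype=np.int64)
gt = np.array([[0, 0, 1], [2, 255, 1]])
pred = np.array([[0, 1, 1], [2, 0, 1]])
cm = update_cm_np(cm, pred, gt, 3)
miou, iou = miou_from_cm(cm)
```

Because the matrix is accumulated over the whole validation set before dividing, each class's IoU is computed over all validation pixels. It is not an average of per-image IoUs.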


In [13]:
# ===== Common Helpers (paste in both scripts or run once) =====
import os, time, json, shutil
from zipfile import ZipFile, ZIP_DEFLATED

def now_run_id(prefix: str):
    return f"{prefix}_{time.strftime('%Y%m%d_%H%M%S')}"

def ensure_dir(p: str):
    os.makedirs(p, exist_ok=True)
    return p

def mount_drive():
    from google.colab import drive
    drive.mount("/content/drive", force_remount=False)

def zip_dir(src_dir: str, zip_path: str):
    with ZipFile(zip_path, "w", ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(src_dir):
            for fn in files:
                full = os.path.join(root, fn)
                rel = os.path.relpath(full, src_dir)
                zf.write(full, arcname=rel)

class StableLogger:
    """
    Writes JSONL in a fixed format (easy to analyze later).
    - run_meta: always records run metadata (e.g. the model name used)
    - epoch_train / epoch_val: keeps the existing entry style unchanged
    """
    def __init__(self, log_dir: str):
        self.log_dir = ensure_dir(log_dir)
        self.meta_path = os.path.join(self.log_dir, "run_meta.json")
        self.train_path = os.path.join(self.log_dir, "train_log.jsonl")
        self.val_path   = os.path.join(self.log_dir, "val_log.jsonl")

    def write_meta(self, meta: dict):
        with open(self.meta_path, "w") as f:
            json.dump(meta, f, indent=2)

    def log_train_epoch(self, entry: dict):
        with open(self.train_path, "a") as f:
            f.write(json.dumps(entry) + "\n")

    def log_val_epoch(self, entry: dict):
        with open(self.val_path, "a") as f:
            f.write(json.dumps(entry) + "\n")

def copy_artifacts_to_drive(local_run_dir: str, drive_run_dir: str, extra_files=None):
    ensure_dir(drive_run_dir)
    # logs/
    src_logs = os.path.join(local_run_dir, "logs")
    dst_logs = os.path.join(drive_run_dir, "logs")
    if os.path.isdir(src_logs):
        # Replace any previous copy so the Drive logs mirror the local run
        if os.path.exists(dst_logs):
            shutil.rmtree(dst_logs)
        shutil.copytree(src_logs, dst_logs)

    # extra files (models, zip)
    if extra_files:
        for fp in extra_files:
            if fp and os.path.exists(fp):
                shutil.copy2(fp, os.path.join(drive_run_dir, os.path.basename(fp)))
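`StableLogger` appends one JSON object per line to `train_log.jsonl` / `val_log.jsonl`. Below is a minimal standard-library sketch of that pattern and of reading the file back, for example to recover the best epoch. `StableLogger` does not fix the entry schema, so the field names `epoch` and `miou` are hypothetical.

```python
import json
import os
import tempfile

def append_jsonl(path, entry):
    # One JSON object per line; "a" mode keeps entries from earlier epochs intact.
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry) + "\n")

def read_jsonl(path):
    # Skip blank lines so a trailing newline does not break parsing.
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

with tempfile.TemporaryDirectory() as d:
    p = os.path.join(d, "val_log.jsonl")
    append_jsonl(p, {"epoch": 1, "miou": 0.36821})
    append_jsonl(p, {"epoch": 2, "miou": 0.40503})
    rows = read_jsonl(p)
    best = max(rows, key=lambda r: r["miou"])
```

Appending line by line means a run interrupted mid-training (as in the logs further down) still leaves every completed epoch readable.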


In [18]:
# =========================
# Cell 3 (Final Weapon): TRI with Boundary, Dice, and Classification Head
# =========================

# ---- TRI model: adds a classification head
class TriDeepLabV3Plus_OS8(nn.Module):
    def __init__(self, in_channels=5, out_classes=TRI_NUM_CLASSES, aspp_rates=(12,24,36)):
        super().__init__()
        self.net = ResNeXtDeepLabV3Plus_OS8(
            num_classes=out_classes,
            in_channels=in_channels,
            aspp_rates=aspp_rates
        )

        # ★ Classification head (a switch meant to reawaken what the backbone learned on ImageNet).
        # The deep encoder features (ResNeXt101 layer4, 2048ch) carry more "concept" information,
        # but ResNeXtDeepLabV3Plus_OS8 (Cell 2) only exposes the decoder feature through
        # return_feat=True. The head is therefore built on that 256ch feature.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cls_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, out_classes)  # presence of [Book, Cabinet, Object] in the image
        )

    def forward(self, x, return_boundary=False, return_cls=False):
        # Reuse the existing forward of ResNeXtDeepLabV3Plus_OS8 (Cell 2) without modifying it,
        # and request the decoder feature [B, 256, H/4, W/4] via return_feat=True.
        # It mixes spatial and semantic information, which is sufficient for classification.
        out = self.net(x, return_boundary=return_boundary, return_feat=True)

        # Return layout of net:
        #   return_boundary=True,  return_feat=True -> (seg, bd, feat)
        #   return_boundary=False, return_feat=True -> (seg, feat)
        if return_boundary:
            seg, bd, feat = out
        else:
            seg, feat = out
            bd = None

        outs = [seg]
        if return_boundary:
            outs.append(bd)

        if return_cls:
            # Global average pooling -> [B, 256], then image-level logits -> [B, out_classes]
            pool = self.avg_pool(feat).flatten(1)
            outs.append(self.cls_head(pool))

        if len(outs) == 1:
            return outs[0]
        return tuple(outs)


# Revised class definition (everything is set up inside __init__)
class TriDeepLabV3Plus_OS8_Cls(nn.Module):
    def __init__(self, in_channels=5, out_classes=TRI_NUM_CLASSES, aspp_rates=(12,24,36)):
        super().__init__()
        self.net = ResNeXtDeepLabV3Plus_OS8(
            num_classes=out_classes,
            in_channels=in_channels,
            aspp_rates=aspp_rates
        )
        # Classification uses the decoder output (256ch)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.cls_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, out_classes)
        )

    def forward(self, x, return_boundary=False, return_cls=False):
        # The Cell 2 model returns a variable-length tuple in the order (seg, bd, aux, feat);
        # return_feat=True is passed so that feat is always included.

        out_tuple = self.net(x, return_boundary=return_boundary, return_feat=True)

        # Unpack according to which outputs were requested:
        #   case 1: seg only       -> (seg, feat)
        #   case 2: seg + boundary -> (seg, bd, feat)
        # aux is only returned when training=True and return_aux=True (should TRI drop aux?).
        # The Cell 3 training loop does not request aux, so it is kept simple here.

        if return_boundary:
            seg, bd, feat = out_tuple
        else:
            seg, feat = out_tuple
            bd = None

        outs = [seg]
        if return_boundary:
            outs.append(bd)

        if return_cls:
            # Global Average Pooling
            x_cls = self.avg_pool(feat).flatten(1) # [B, 256]
            cls_logits = self.cls_head(x_cls)      # [B, 3]
            outs.append(cls_logits)

        if len(outs) == 1: return outs[0]
        return tuple(outs)


# ---- Helpers (unchanged, kept as is)
def map_y13_to_tri_ignore(y13):
    tri = torch.full_like(y13, TRI_IGNORE)
    tri[y13 == BOOK_ID] = 0; tri[y13 == CABINET_ID] = 1; tri[y13 == OBJECT_ID] = 2
    return tri

def fuse_base10_tri3_by_base_pred(base_logits_13, tri_logits_3):
    base_pred = base_logits_13.argmax(1)
    tri_pred  = tri_logits_3.argmax(1)
    tri_to_base = torch.empty_like(tri_pred)
    tri_to_base[tri_pred == 0] = BOOK_ID; tri_to_base[tri_pred == 1] = CABINET_ID; tri_to_base[tri_pred == 2] = OBJECT_ID
    mask = (base_pred == BOOK_ID) | (base_pred == CABINET_ID) | (base_pred == OBJECT_ID)
    fused = base_pred.clone(); fused[mask] = tri_to_base[mask]
    return fused

# The earlier evaluation functions can be reused (the model's forward stays compatible),
# but run_tri_train needs extra logic to compute cls_loss.

def run_tri_train(
    dataset_root="/content/data",
    img_size=768,
    batch_size=16,
    epochs=25,
    lr=3e-4, # kept as is: better not to lower it, since the classification head is trained too
    seed=42,
    save_cm_every=5,
    fused_eval_every=5,
    fused_tta_light=True,
    drive_dir=None,
    BASE_MODEL_PATH="/content/base_best_model.pt",
):
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    must_exist(BASE_MODEL_PATH, "BASE best model")

    run_id = make_run_id("tri_cls_boost") # new ID prefix: Classification Boost
    run_dir = ensure_dir(os.path.join("/content", run_id))
    logger = RunLogger(run_dir, run_id)
    tri_best, tri_final = os.path.join(run_dir, "tri_best.pt"), os.path.join(run_dir, "tri_final.pt")

    # Aug / Dataset / Loader (shared)
    tri_train_aug = A.Compose([A.Resize(img_size, img_size)], additional_targets={'depth': 'image', 'height': 'image'})
    tri_val_aug = A.Compose([A.Resize(img_size, img_size)], additional_targets={'depth': 'image', 'height': 'image'})

    full = NYUv2Dataset(dataset_root, split="train", transform=None)
    n_total = len(full); n_val = int(n_total * 0.1)
    indices = list(range(n_total)); random.shuffle(indices)
    train_idx, val_idx = indices[:-n_val], indices[-n_val:]

    base_train = NYUv2Dataset(dataset_root, split="train", transform=tri_train_aug)
    base_val   = NYUv2Dataset(dataset_root, split="train", transform=tri_val_aug)

    tri_train_loader = DataLoader(Subset(base_train, train_idx), batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, drop_last=True)
    base_val_loader_13 = DataLoader(Subset(base_val, val_idx), batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

    print(f"[TRI Cls+] Training with Classification Head (Distilling 'Book' features)")

    base_model = ResNeXtDeepLabV3Plus_OS8(num_classes=NUM_CLASSES, in_channels=5).to(device)
    base_model.load_state_dict(torch.load(BASE_MODEL_PATH, map_location=device), strict=True)
    base_model.eval()
    for p in base_model.parameters(): p.requires_grad = False

    # ★ New Model Class
    tri_model = TriDeepLabV3Plus_OS8_Cls(in_channels=5, out_classes=TRI_NUM_CLASSES).to(device)

    # Losses
    tri_weights = torch.tensor([3.0, 1.5, 1.0], device=device)
    criterion_ce = nn.CrossEntropyLoss(weight=tri_weights, ignore_index=TRI_IGNORE)
    criterion_dice = ClassBalancedDiceLoss(n_classes=TRI_NUM_CLASSES, ignore_index=TRI_IGNORE, class_weights=tri_weights)
    bce_boundary = nn.BCEWithLogitsLoss()
    # ★ Classification Loss (Multi-label BCE)
    bce_cls = nn.BCEWithLogitsLoss()

    optimizer = optim.AdamW(tri_model.parameters(), lr=lr, weight_decay=1e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler("cuda")

    logger.save_config({"mode": "tri_cls_boost", "epochs": epochs, "lr": lr})
    best_miou = 0.0

    for epoch in range(1, epochs + 1):
        tri_model.train()
        epoch_losses = {"ce": 0.0, "dice": 0.0, "bd": 0.0, "cls": 0.0, "total": 0.0}

        pbar = tqdm(tri_train_loader, desc=f"[TRI Cls+] Epoch {epoch}/{epochs}")
        for bi, (x, y13) in enumerate(pbar):
            x, y13 = x.to(device, non_blocking=True), y13.to(device, non_blocking=True)

            with torch.no_grad():
                base_logits = base_model(x); base_pred = base_logits.argmax(1)
                mask = (base_pred == BOOK_ID) | (base_pred == CABINET_ID) | (base_pred == OBJECT_ID)

            y_tri = map_y13_to_tri_ignore(y13)
            y_tri_g = y_tri.clone(); y_tri_g[~mask] = TRI_IGNORE

            # ★ Build the classification targets
            # For each image in the batch: does the GT contain Book / Cabinet / Object? (0 or 1)
            # y_tri is [B, H, W] (the ungated labels, not y_tri_g)
            B_sz = x.shape[0]
            y_cls_target = torch.zeros((B_sz, TRI_NUM_CLASSES), device=device)
            for b in range(B_sz):
                labels_in_img = torch.unique(y_tri[b]) # subset of {0, 1, 2, 255}
                if 0 in labels_in_img: y_cls_target[b, 0] = 1.0 # Book exists
                if 1 in labels_in_img: y_cls_target[b, 1] = 1.0 # Cabinet exists
                if 2 in labels_in_img: y_cls_target[b, 2] = 1.0 # Object exists

            optimizer.zero_grad(set_to_none=True)
            with autocast("cuda"):
                # Forward with Cls
                seg, bd_pred, cls_logits = tri_model(x, return_boundary=True, return_cls=True)

                loss_ce = criterion_ce(seg, y_tri_g)
                loss_dice = criterion_dice(seg, y_tri_g)

                bd_gt = make_boundary_mask(y_tri_g, ignore_index=TRI_IGNORE, dilate=2)
                loss_bd = bce_boundary(bd_pred, bd_gt) + dice_loss_binary(bd_pred, bd_gt)

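                # ★ Classification loss
                # A book is present but no mask is produced -> the features are slacking -> penalize via the cls loss
                loss_cls = bce_cls(cls_logits, y_cls_target)

                # Combined loss (cls is auxiliary, so it gets a small weight of 0.3; it mainly helps early training)
                loss = 1.0*loss_ce + 0.5*loss_dice + 0.5*loss_bd + 0.3*loss_cls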

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(tri_model.parameters(), max_norm=1.0)
            scaler.step(optimizer); scaler.update()

            # .item() returns plain Python floats (avoids the requires_grad-to-scalar warning of float(tensor))
            batch_losses = {"ce": loss_ce.item(), "dice": loss_dice.item(), "bd": loss_bd.item(), "cls": loss_cls.item(), "total": loss.item()}
            logger.log_batch("tri_cls", epoch, bi, batch_losses)
            for k in epoch_losses: epoch_losses[k] += batch_losses[k]
            pbar.set_postfix({"L": f"{batch_losses['total']:.3f}", "Cls": f"{batch_losses['cls']:.3f}", "Dice": f"{batch_losses['dice']:.3f}"})

        logger.flush_batch()
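        # Epoch-level train log (same pattern as the BASE loop; LR is read before scheduler.step())
        n_batches = len(tri_train_loader)
        avg_losses = {k: v / n_batches for k, v in epoch_losses.items()}
        logger.log_epoch_train("tri_cls", epoch, optimizer.param_groups[0]['lr'], avg_losses)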
        scheduler.step()

        # Val (same eval function as before)
        # The eval function calls the model with return_cls=False (the default), so it stays compatible
        m_tri, cm_tri = eval_tri_only_gated_by_base_pred(base_model, tri_model, base_val_loader_13, device, img_size)
        fmt_metrics_console("TRI Cls+", epoch, m_tri["miou"], m_tri["class_iou"], m_tri["class_precision"], m_tri["class_recall"], TRI_NAMES)

        if m_tri["miou"] > best_miou:
            best_miou = m_tri["miou"]
            torch.save(tri_model.state_dict(), tri_best)
            print(f" -> New best: {best_miou:.5f}")

        if (epoch % fused_eval_every == 0) or (epoch == epochs):
            tta = "light" if fused_tta_light and epoch!=epochs else "heavy"
            m_f, cm_f = eval_fused13(base_model, tri_model, base_val_loader_13, device, img_size, tta_mode=tta)
            fmt_metrics_console("FUSED", epoch, m_f["miou"], m_f["class_iou"], m_f["class_precision"], m_f["class_recall"], CLASS_NAMES)

    torch.save(tri_model.state_dict(), tri_final)
    return {"run_id": run_id, "run_dir": run_dir, "tri_best": tri_best}

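# Helper functions such as eval_fused13 come from the earlier Cell 3 (not repeated in this block; they work as long as they are still in memory).
# To be safe, it is recommended to replace the whole of Cell 3 with this code before running it.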
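The classification targets in `run_tri_train` are built with a per-image Python loop over `torch.unique`. Below is a minimal NumPy sketch of the same multi-hot "class present in image" target, computed for the whole batch at once by broadcasting. It assumes labels are in `{0, 1, 2, 255}` as in the loop above. `presence_targets` is a hypothetical name.

```python
import numpy as np

def presence_targets(y_tri, num_classes=3, ignore_index=255):
    """Multi-hot image-level targets: entry [b, c] is 1.0 if class c occurs anywhere in image b."""
    y = np.asarray(y_tri)                                  # [B, H, W]
    flat = y.reshape(y.shape[0], -1)                       # [B, H*W]
    # Compare every pixel against every class id: [B, H*W, C], then "any" over pixels.
    # ignore_index never matches np.arange(num_classes), so it is excluded automatically.
    hits = flat[:, :, None] == np.arange(num_classes)[None, None, :]
    return hits.any(axis=1).astype(np.float32)             # [B, C]

# Toy batch of two 2x3 label maps, mostly ignore_index.
y_demo = np.full((2, 2, 3), 255)
y_demo[0, 0, 0] = 0   # image 0 contains class 0 (Book)
y_demo[0, 1, 2] = 2   # ... and class 2 (Object)
y_demo[1, 0, 1] = 1   # image 1 contains class 1 (Cabinet)
targets = presence_targets(y_demo)
```

The same broadcasting should carry over to tensors, e.g. `(y_tri.flatten(1).unsqueeze(-1) == torch.arange(C, device=y_tri.device)).any(1).float()`. Treat that as an untested translation. The intermediate boolean tensor is `B*H*W*C` elements, which is small for 3 classes.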

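The boundary loss above compares the model's boundary logits with `make_boundary_mask(y_tri_g, ignore_index=TRI_IGNORE, dilate=2)`, whose definition is not shown in this section. Below is a minimal NumPy sketch of one common way to build such a target. It rests on three assumptions: a pixel is a boundary if a 4-neighbour has a different valid label, ignored pixels never create boundaries, and the result is dilated by a square of radius `dilate`. The notebook's own (batched, PyTorch) helper may differ, and `boundary_mask` is a hypothetical name.

```python
import numpy as np

def boundary_mask(label, ignore_index=255, dilate=2):
    """Float [H, W] mask of pixels on an edge between two different valid classes."""
    lab = np.asarray(label)
    valid = lab != ignore_index
    edge = np.zeros(lab.shape, dtype=bool)
    # Horizontal / vertical label changes where both neighbours are valid.
    diff_h = (lab[:, 1:] != lab[:, :-1]) & valid[:, 1:] & valid[:, :-1]
    diff_v = (lab[1:, :] != lab[:-1, :]) & valid[1:, :] & valid[:-1, :]
    # Mark both sides of every change.
    edge[:, 1:] |= diff_h
    edge[:, :-1] |= diff_h
    edge[1:, :] |= diff_v
    edge[:-1, :] |= diff_v
    # Square dilation of radius `dilate`: OR over all shifted copies of a zero-padded mask.
    if dilate > 0:
        H, W = edge.shape
        padded = np.pad(edge, dilate)
        out = np.zeros_like(edge)
        for dy in range(2 * dilate + 1):
            for dx in range(2 * dilate + 1):
                out |= padded[dy:dy + H, dx:dx + W]
        edge = out
    return edge.astype(np.float32)

# Toy 4x6 label map: class 0 on the left half, class 1 on the right half.
lab = np.zeros((4, 6), dtype=np.int64)
lab[:, 3:] = 1
thin = boundary_mask(lab, dilate=0)   # columns 2 and 3 are marked
thick = boundary_mask(lab, dilate=1)  # columns 1..4 are marked
```

Because `y_tri_g` sets every pixel outside the BASE gate to `TRI_IGNORE`, a target built this way would only contain boundaries between Book, Cabinet and Object regions that the BASE model already predicted.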
In [19]:
# =========================
# Cell 4: main (RUN_MODE switch)  * every mode is guaranteed to end with test inference & submission.zip creation
# =========================
RUN_MODE = "tri"  # "base" or "tri" or "both"

def main():
    mount_drive()

    if RUN_MODE not in ["base", "tri", "both"]:
        raise ValueError("RUN_MODE must be 'base' or 'tri' or 'both'.")

    base_out = None
    tri_out  = None

    # Shared Drive destinations
    base_drive_root = "/content/drive/MyDrive/nyu_runs/base"
    tri_drive_root  = "/content/drive/MyDrive/nyu_runs/tri"

    if RUN_MODE in ["base", "both"]:
        # BASE: test inference -> submission_only.zip is always produced inside this function (works as is)
        base_drive_dir = os.path.join(base_drive_root, make_run_id("base_out"))
        base_out = run_base_train_and_submit(
            dataset_root="/content/data",
            img_size=768,
            batch_size=16,
            epochs=60,
            lr=1e-4,
            seed=42,
            phase1_epochs=40,
            save_cm_every=5,
            drive_dir=base_drive_dir,
            base_best_out="/content/base_best_model.pt",
        )
        print("[BASE] summary:", base_out)

    if RUN_MODE in ["tri", "both"]:
        # TRI: first train & validate (tri / fused13)
        tri_drive_dir = os.path.join(tri_drive_root, make_run_id("tri_out"))
        tri_out = run_tri_train(
            dataset_root="/content/data",
            img_size=768,
            batch_size=16,
            epochs=25,
            lr=3e-4,
            seed=42,
            save_cm_every=5,
            fused_eval_every=5,
            fused_tta_light=True,
            drive_dir=tri_drive_dir,
            BASE_MODEL_PATH="/content/base_best_model.pt",
        )
        print("[TRI] summary:", tri_out)

        # ★ Added: guarantee "test inference -> submission_only.zip" at the end in every case.
        # Even in TRI-only mode the BASE best model is a required input, so a FUSED13 submission can be built.
        submit_drive_dir = os.path.join(tri_drive_root, make_run_id("tri_submit"))
        submit_out_dir = os.path.join(tri_out["run_dir"], "submit_fused13")

        fused_submit = run_fused13_test_and_submit(
            dataset_root="/content/data",
            img_size=768,
            BASE_MODEL_PATH="/content/base_best_model.pt",
            TRI_MODEL_PATH=tri_out["tri_best"],   # build the submission from the best checkpoint (final policy)
            out_dir=submit_out_dir,
            tta_mode="heavy",                     # requirement: heavy TTA for test
            drive_dir=submit_drive_dir,           # only submission_only.zip is copied to Drive
        )
        print("[FUSED13] submit:", fused_submit)

    # RUN_MODE="base":
    # submission_only.zip is created inside the BASE function, so nothing more is needed here

if __name__ == "__main__":
    main()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[TRI Cls+] Training with Classification Head (Distilling 'Book' features)


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:836.)
  batch_losses = {"ce": float(loss_ce), "dice": float(loss_dice), "bd": float(loss_bd), "cls": float(loss_cls), "total": float(loss)}
[TRI Cls+] Epoch 1/25: 100%|██████████| 44/44 [01:23<00:00,  1.90s/it, L=2.039, Cls=0.559, Dice=0.758]


[TRI Cls+] Epoch 01 | mIoU=0.36821
Class           IoU     Prec      Rec
-------------------------------------
Book          0.020    0.080    0.025
Cabinet       0.602    0.672    0.853
Object        0.483    0.769    0.565
 -> New best: 0.36821


[TRI Cls+] Epoch 2/25: 100%|██████████| 44/44 [01:22<00:00,  1.88s/it, L=1.438, Cls=0.376, Dice=0.716]


[TRI Cls+] Epoch 02 | mIoU=0.40503
Class           IoU     Prec      Rec
-------------------------------------
Book          0.137    0.214    0.276
Cabinet       0.599    0.688    0.822
Object        0.479    0.752    0.570
 -> New best: 0.40503


[TRI Cls+] Epoch 3/25: 100%|██████████| 44/44 [01:23<00:00,  1.90s/it, L=1.579, Cls=0.425, Dice=0.710]


[TRI Cls+] Epoch 03 | mIoU=0.35602
Class           IoU     Prec      Rec
-------------------------------------
Book          0.099    0.153    0.221
Cabinet       0.589    0.631    0.900
Object        0.380    0.847    0.408


[TRI Cls+] Epoch 4/25: 100%|██████████| 44/44 [01:23<00:00,  1.90s/it, L=1.214, Cls=0.334, Dice=0.612]


[TRI Cls+] Epoch 04 | mIoU=0.39001
Class           IoU     Prec      Rec
-------------------------------------
Book          0.044    0.188    0.055
Cabinet       0.530    0.867    0.577
Object        0.596    0.634    0.908


[TRI Cls+] Epoch 5/25: 100%|██████████| 44/44 [01:22<00:00,  1.89s/it, L=1.186, Cls=0.293, Dice=0.597]


[TRI Cls+] Epoch 05 | mIoU=0.44019
Class           IoU     Prec      Rec
-------------------------------------
Book          0.071    0.106    0.180
Cabinet       0.632    0.773    0.776
Object        0.617    0.784    0.744
 -> New best: 0.44019
[FUSED] Epoch 05 | mIoU=0.64524
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.774    0.845    0.902
Book          0.040    0.062    0.105
Ceiling       0.750    0.889    0.828
Chair         0.702    0.819    0.831
Floor         0.938    0.972    0.964
Cabinet       0.505    0.656    0.688
Object        0.485    0.655    0.652
Picture       0.686    0.866    0.768
Sofa          0.722    0.874    0.805
Desk          0.466    0.689    0.591
TV            0.680    0.953    0.704
Wall          0.867    0.932    0.926
Window        0.771    0.865    0.876


[TRI Cls+] Epoch 6/25: 100%|██████████| 44/44 [01:23<00:00,  1.89s/it, L=1.140, Cls=0.279, Dice=0.597]


[TRI Cls+] Epoch 06 | mIoU=0.52814
Class           IoU     Prec      Rec
-------------------------------------
Book          0.174    0.239    0.392
Cabinet       0.694    0.879    0.768
Object        0.716    0.796    0.877
 -> New best: 0.52814


[TRI Cls+] Epoch 7/25: 100%|██████████| 44/44 [01:22<00:00,  1.89s/it, L=1.057, Cls=0.243, Dice=0.585]


[TRI Cls+] Epoch 07 | mIoU=0.51383
Class           IoU     Prec      Rec
-------------------------------------
Book          0.189    0.256    0.419
Cabinet       0.698    0.800    0.847
Object        0.654    0.839    0.747


[TRI Cls+] Epoch 8/25: 100%|██████████| 44/44 [01:23<00:00,  1.89s/it, L=0.977, Cls=0.208, Dice=0.578]


[TRI Cls+] Epoch 08 | mIoU=0.51166
Class           IoU     Prec      Rec
-------------------------------------
Book          0.130    0.240    0.222
Cabinet       0.706    0.873    0.786
Object        0.699    0.779    0.872


[TRI Cls+] Epoch 9/25: 100%|██████████| 44/44 [01:23<00:00,  1.89s/it, L=0.810, Cls=0.136, Dice=0.475]


[TRI Cls+] Epoch 09 | mIoU=0.49877
Class           IoU     Prec      Rec
-------------------------------------
Book          0.121    0.208    0.225
Cabinet       0.712    0.786    0.883
Object        0.663    0.866    0.739


[TRI Cls+] Epoch 10/25: 100%|██████████| 44/44 [01:23<00:00,  1.89s/it, L=0.847, Cls=0.090, Dice=0.600]


[TRI Cls+] Epoch 10 | mIoU=0.50473
Class           IoU     Prec      Rec
-------------------------------------
Book          0.069    0.132    0.127
Cabinet       0.723    0.877    0.804
Object        0.722    0.801    0.879
[FUSED] Epoch 10 | mIoU=0.65796
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.774    0.845    0.902
Book          0.053    0.111    0.091
Ceiling       0.750    0.889    0.828
Chair         0.702    0.819    0.831
Floor         0.938    0.972    0.964
Cabinet       0.595    0.779    0.716
Object        0.549    0.651    0.776
Picture       0.686    0.866    0.768
Sofa          0.722    0.874    0.805
Desk          0.466    0.689    0.591
TV            0.680    0.953    0.704
Wall          0.867    0.932    0.926
Window        0.771    0.865    0.876


[TRI Cls+] Epoch 11/25: 100%|██████████| 44/44 [01:22<00:00,  1.88s/it, L=0.757, Cls=0.119, Dice=0.485]


[TRI Cls+] Epoch 11 | mIoU=0.54212
Class           IoU     Prec      Rec
-------------------------------------
Book          0.147    0.198    0.363
Cabinet       0.732    0.888    0.807
Object        0.747    0.835    0.876
 -> New best: 0.54212


[TRI Cls+] Epoch 12/25: 100%|██████████| 44/44 [01:23<00:00,  1.89s/it, L=0.732, Cls=0.055, Dice=0.521]


[TRI Cls+] Epoch 12 | mIoU=0.55311
Class           IoU     Prec      Rec
-------------------------------------
Book          0.154    0.302    0.239
Cabinet       0.769    0.851    0.888
Object        0.737    0.865    0.833
 -> New best: 0.55311


[TRI Cls+] Epoch 13/25: 100%|██████████| 44/44 [01:22<00:00,  1.88s/it, L=0.716, Cls=0.103, Dice=0.504]


[TRI Cls+] Epoch 13 | mIoU=0.55749
Class           IoU     Prec      Rec
-------------------------------------
Book          0.185    0.249    0.419
Cabinet       0.749    0.847    0.866
Object        0.738    0.882    0.819
 -> New best: 0.55749


[TRI Cls+] Epoch 14/25: 100%|██████████| 44/44 [01:23<00:00,  1.89s/it, L=0.684, Cls=0.161, Dice=0.531]


[TRI Cls+] Epoch 14 | mIoU=0.57138
Class           IoU     Prec      Rec
-------------------------------------
Book          0.196    0.304    0.356
Cabinet       0.761    0.884    0.845
Object        0.757    0.846    0.879
 -> New best: 0.57138


[TRI Cls+] Epoch 15/25: 100%|██████████| 44/44 [01:22<00:00,  1.88s/it, L=0.648, Cls=0.066, Dice=0.575]


[TRI Cls+] Epoch 15 | mIoU=0.56213
Class           IoU     Prec      Rec
-------------------------------------
Book          0.187    0.285    0.350
Cabinet       0.751    0.873    0.843
Object        0.749    0.847    0.867
[FUSED] Epoch 15 | mIoU=0.66762
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.774    0.845    0.902
Book          0.141    0.237    0.259
Ceiling       0.750    0.889    0.828
Chair         0.702    0.819    0.831
Floor         0.938    0.972    0.964
Cabinet       0.608    0.758    0.756
Object        0.572    0.698    0.760
Picture       0.686    0.866    0.768
Sofa          0.722    0.874    0.805
Desk          0.466    0.689    0.591
TV            0.680    0.953    0.704
Wall          0.867    0.932    0.926
Window        0.771    0.865    0.876


[TRI Cls+] Epoch 16/25: 100%|██████████| 44/44 [01:23<00:00,  1.90s/it, L=0.619, Cls=0.069, Dice=0.476]


[TRI Cls+] Epoch 16 | mIoU=0.57371
Class           IoU     Prec      Rec
-------------------------------------
Book          0.214    0.296    0.435
Cabinet       0.752    0.879    0.839
Object        0.755    0.851    0.870
 -> New best: 0.57371


[TRI Cls+] Epoch 17/25: 100%|██████████| 44/44 [01:22<00:00,  1.88s/it, L=0.614, Cls=0.114, Dice=0.516]


[TRI Cls+] Epoch 17 | mIoU=0.57299
Class           IoU     Prec      Rec
-------------------------------------
Book          0.206    0.289    0.418
Cabinet       0.752    0.880    0.838
Object        0.760    0.853    0.875


[TRI Cls+] Epoch 18/25: 100%|██████████| 44/44 [01:22<00:00,  1.88s/it, L=0.582, Cls=0.049, Dice=0.566]


[TRI Cls+] Epoch 18 | mIoU=0.57066
Class           IoU     Prec      Rec
-------------------------------------
Book          0.203    0.302    0.383
Cabinet       0.754    0.870    0.850
Object        0.755    0.857    0.864


[TRI Cls+] Epoch 19/25:  48%|████▊     | 21/44 [00:41<00:45,  1.96s/it, L=0.518, Cls=0.037, Dice=0.458]


KeyboardInterrupt: 
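In the FUSED tables above, every class other than Book, Cabinet and Object has the same IoU at epochs 5, 10 and 15 (e.g. Bed 0.774, Wall 0.867). This follows from the gating rule in `fuse_base10_tri3_by_base_pred`: the TRI prediction only replaces the BASE prediction where BASE already predicts one of those three classes, and it can only map to one of those three classes. Below is a minimal NumPy sketch of that rule. The ids `BOOK_ID=1`, `CABINET_ID=5`, `OBJECT_ID=6` are an assumption inferred from the row order of the printed 13-class table, not read from the notebook's constants.

```python
import numpy as np

# Hypothetical ids, inferred from the CLASS_NAMES order printed in the logs.
BOOK_ID, CABINET_ID, OBJECT_ID = 1, 5, 6
TRI_TO_BASE = np.array([BOOK_ID, CABINET_ID, OBJECT_ID])   # TRI class k -> 13-class id

def fuse(base_logits_13, tri_logits_3):
    """Relabel only pixels where BASE predicts Book/Cabinet/Object, using the TRI argmax."""
    base_pred = np.argmax(base_logits_13, axis=1)            # [B, H, W], values 0..12
    tri_pred = np.argmax(tri_logits_3, axis=1)               # [B, H, W], values 0..2
    gate = np.isin(base_pred, TRI_TO_BASE)                   # pixels TRI is allowed to change
    fused = base_pred.copy()
    fused[gate] = TRI_TO_BASE[tri_pred[gate]]
    return fused

# One 1x3 image: BASE says [Book, Floor, Object], TRI says [Object, Book, Cabinet].
base = np.zeros((1, 13, 1, 3))
base[0, 1, 0, 0] = base[0, 4, 0, 1] = base[0, 6, 0, 2] = 1.0
tri = np.zeros((1, 3, 1, 3))
tri[0, 2, 0, 0] = tri[0, 0, 0, 1] = tri[0, 1, 0, 2] = 1.0
fused_demo = fuse(base, tri)
```

The Floor pixel keeps its BASE label even though TRI votes Book there. This is why fusion can reshuffle pixels among the three classes but cannot recover book pixels that BASE assigned to some other class.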

In [17]:
# =========================
# Cell 5: Recovery (Skip training, run submission only)
# =========================
import glob

def run_recovery_submission():
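    # 1. Find the most recent training folder (tri_refined_...) automatically.
    # NOTE: run_tri_train in Cell 3 builds its run IDs with make_run_id("tri_cls_boost") and saves
    # tri_best.pt, so the pattern and checkpoint name below only match runs from the older version.
    search_pattern = "/content/tri_refined_*"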
    dirs = sorted(glob.glob(search_pattern), key=os.path.getmtime, reverse=True)

    if not dirs:
        print("[ERROR] No training result folder (tri_refined_...) found.")
        return

    latest_run_dir = dirs[0]
    run_id = os.path.basename(latest_run_dir)
    tri_best_path = os.path.join(latest_run_dir, "tri_best_model.pt")

    print(f"[RECOVERY] Found latest run: {latest_run_dir}")
    print(f"[RECOVERY] Using model: {tri_best_path}")

    if not os.path.exists(tri_best_path):
        print(f"[ERROR] Best model not found at {tri_best_path}")
        return

    # 2. Output locations
    # Drive destination (matches the settings of the original Cell 4)
    tri_drive_root = "/content/drive/MyDrive/nyu_runs/tri"
    submit_drive_dir = os.path.join(tri_drive_root, f"{run_id}_submit")

    # Local destination
    submit_out_dir = os.path.join(latest_run_dir, "submit_fused13")

    # 3. Generate the submission file
    # (uses run_fused13_test_and_submit defined in Cell 3)
    fused_submit = run_fused13_test_and_submit(
        dataset_root="/content/data",
        img_size=768,
        BASE_MODEL_PATH="/content/base_best_model.pt",
        TRI_MODEL_PATH=tri_best_path,
        out_dir=submit_out_dir,
        tta_mode="heavy",           # heavy TTA for the submission
        drive_dir=submit_drive_dir  # save to Drive
    )

    print("====================================================")
    print(f"[RECOVERY] SUCCESS! Submission generated.")
    print(f" - Local: {fused_submit}")
    print(f" - Drive: {submit_drive_dir}")
    print("====================================================")

if __name__ == "__main__":
    run_recovery_submission()

[RECOVERY] Found latest run: /content/tri_refined_20260105_100426
[RECOVERY] Using model: /content/tri_refined_20260105_100426/tri_best_model.pt
[LONG] Start FUSED13 submission: base=base_best_model.pt tri=tri_best_model.pt


[FUSED] inference:  93%|█████████▎| 606/654 [08:47<00:41,  1.15it/s]


KeyboardInterrupt: 
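The recovery cell only looks for `tri_refined_*` folders containing `tri_best_model.pt`. Runs produced by the current `run_tri_train` use the `tri_cls_boost` prefix and save `tri_best.pt`. Below is a minimal standard-library sketch of a lookup that accepts several folder patterns and checkpoint names and returns the newest match by modification time. `find_latest_checkpoint` and the demo folder names are hypothetical.

```python
import glob
import os
import tempfile
import time

def find_latest_checkpoint(root, patterns, filenames):
    """Return (run_dir, ckpt_path) for the newest run dir that contains one of `filenames`."""
    dirs = set()
    for pat in patterns:
        dirs.update(glob.glob(os.path.join(root, pat)))
    # Newest first; the first directory holding any accepted checkpoint name wins.
    for d in sorted(dirs, key=os.path.getmtime, reverse=True):
        for fn in filenames:
            p = os.path.join(d, fn)
            if os.path.exists(p):
                return d, p
    return None, None

with tempfile.TemporaryDirectory() as root:
    old = os.path.join(root, "tri_refined_old")
    new = os.path.join(root, "tri_cls_boost_new")
    os.makedirs(old)
    os.makedirs(new)
    open(os.path.join(old, "tri_best_model.pt"), "wb").close()
    open(os.path.join(new, "tri_best.pt"), "wb").close()
    now = time.time()
    os.utime(old, (now - 100, now - 100))   # make the ordering explicit for the demo
    os.utime(new, (now, now))
    run_dir, ckpt = find_latest_checkpoint(
        root, ["tri_refined_*", "tri_cls_boost_*"], ["tri_best.pt", "tri_best_model.pt"])
    picked_dir, picked_file = os.path.basename(run_dir), os.path.basename(ckpt)
```

Sorting by directory mtime matches what the recovery cell already does. Note that a folder's mtime changes whenever a file is added to it, so "newest" means "most recently written to", not "most recently created".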

In [None]:
import json
import matplotlib.pyplot as plt
import glob
import os

def plot_learning_curves(log_file):
    epochs = []
    losses = []
    mious = []
    lrs = []

    # Read the log file
    with open(log_file, 'r') as f:
        for line in f:
            data = json.loads(line)
            # Skip header rows (info)
            if "info" in data:
                continue

            epochs.append(data['epoch'])
            losses.append(data['loss'])
            mious.append(data['val_miou'])
            lrs.append(data['lr'])

    # Build the plot
    fig, ax1 = plt.subplots(figsize=(12, 6))

    # Loss (left axis, red)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Training Loss', color='tab:red')
    ax1.plot(epochs, losses, color='tab:red', label='Train Loss', marker='.')
    ax1.tick_params(axis='y', labelcolor='tab:red')
    ax1.grid(True, alpha=0.3)

    # mIoU (right axis, blue)
    ax2 = ax1.twinx()
    ax2.set_ylabel('Validation mIoU', color='tab:blue')
    ax2.plot(epochs, mious, color='tab:blue', label='Val mIoU', marker='.')
    ax2.tick_params(axis='y', labelcolor='tab:blue')

    plt.title(f'Learning Curve: {os.path.basename(log_file)}')
    fig.tight_layout()
    plt.show()

    # Learning-rate plot (to check that the Poly scheduler is taking effect)
    plt.figure(figsize=(12, 3))
    plt.plot(epochs, lrs, color='orange')
    plt.title("Learning Rate Schedule")
    plt.xlabel("Epoch")
    plt.ylabel("LR")
    plt.grid(True, alpha=0.3)
    plt.show()

# Automatically find the latest log file and plot it
log_files = sorted(glob.glob("train_log_detailed_*.jsonl"))
if log_files:
    latest_log = log_files[-1]
    print(f"Plotting: {latest_log}")
    plot_learning_curves(latest_log)
else:
    print("No log file found.")

No log file found.
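The plotting cell found nothing because it globs `train_log_detailed_*.jsonl` in the working directory. `StableLogger` above writes `logs/train_log.jsonl` instead. `plot_learning_curves` also expects every non-header row to carry `epoch`, `loss`, `val_miou` and `lr`. Below is a minimal standard-library sketch of a more tolerant loader that skips any row missing one of the required keys. `load_curve` is hypothetical, and the demo rows are toy values, not results from this notebook.

```python
import json
import os
import tempfile

def load_curve(log_file, keys=("epoch", "loss", "val_miou", "lr")):
    """Collect per-key lists from a JSONL log, skipping rows (e.g. headers) that lack any key."""
    cols = {k: [] for k in keys}
    with open(log_file, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            row = json.loads(line)
            if not all(k in row for k in keys):
                continue
            for k in keys:
                cols[k].append(row[k])
    return cols

with tempfile.TemporaryDirectory() as d:
    p = os.path.join(d, "train_log_detailed_demo.jsonl")
    with open(p, "w", encoding="utf-8") as f:
        f.write(json.dumps({"info": "header"}) + "\n")                                    # skipped
        f.write(json.dumps({"epoch": 1, "loss": 2.0, "val_miou": 0.30, "lr": 3e-4}) + "\n")
        f.write(json.dumps({"epoch": 2, "loss": 1.5, "lr": 2.9e-4}) + "\n")               # no val_miou, skipped
    curve = load_curve(p)
```

The returned lists can be passed straight to `ax1.plot(curve["epoch"], curve["loss"])` and the like. Skipping incomplete rows keeps all series the same length, which `matplotlib` requires.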


In [None]:
# -*- coding: utf-8 -*-
# NYUv2 val visualization (RGB | GT | Pred | RGB+Error) with FIXED palette -> ZIP
# - NO matplotlib colormap
# - GT and Pred share EXACT same palette mapping (ID -> RGB)
# - ignore_index(255) is shown as black by default

import os
import zipfile
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import VisionDataset
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from torchvision import models

# =========================
# EDIT HERE
# =========================
DATASET_ROOT = "/content/data"
IMAGE_SIZE = (512, 512)

MODEL_PATH = "/content/checkpoints/model_20260103062841.pt"  # <- only this line needs editing
OUTDIR = "val_viz_fixed"
ZIP_PATH = "val_viz_fixed.zip"

MAX_SAVE = 50  # None = save all
SEED = 42
TRAIN_VAL_SPLIT = 0.9

NUM_WORKERS = 0
BATCH_SIZE = 1

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IGNORE_INDEX = 255

print("DEVICE:", DEVICE)
print("MODEL_PATH:", MODEL_PATH)

# =========================
# FIXED PALETTE (ID -> RGB)
# This table defines what each color means. GT and Pred must both go through it.
# =========================
PALETTE_13 = np.array([
    [  0,   0,   0],   # 0 wall
    [128,  64,  32],   # 1 floor
    [  0, 128, 255],   # 2 cabinet
    [255,   0, 128],   # 3 bed
    [  0, 255,   0],   # 4 chair
    [255, 128,   0],   # 5 sofa
    [255, 255,   0],   # 6 table
    [128, 128, 128],   # 7 door
    [  0, 255, 255],   # 8 window
    [  0,  64, 128],   # 9 bookshelf
    [255,   0,   0],   # 10 picture
    [128,   0, 255],   # 11 counter
    [ 64, 255, 128],   # 12 desk
], dtype=np.uint8)

# =========================
# Dataset helpers
# =========================
def pil_label_to_long_tensor(lbl_pil: Image.Image) -> torch.Tensor:
    arr = np.array(lbl_pil, dtype=np.int64)
    return torch.from_numpy(arr).long()

def depth_pil_to_tensor_01(depth_pil: Image.Image, size_hw):
    depth_pil = depth_pil.resize((size_hw[1], size_hw[0]), resample=Image.BILINEAR)
    arr = np.array(depth_pil)
    if arr.dtype == np.uint16 or (arr.max() > 255):
        arr = arr.astype(np.float32) / 65535.0
    else:
        arr = arr.astype(np.float32) / 255.0
    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    arr = np.clip(arr, 0.0, 1.0)
    return torch.from_numpy(arr).unsqueeze(0)  # [1,H,W]

class NYUv2(VisionDataset):
    def __init__(self, root, split="train", include_depth=True,
                 image_transform=None, depth_transform=None, target_transform=None):
        super().__init__(root)
        assert split in ("train", "test")
        self.root = root
        self.split = split
        self.include_depth = include_depth

        images_dir = os.path.join(self.root, self.split, "image")
        img_names = sorted(os.listdir(images_dir))

        self.images = [os.path.join(images_dir, n) for n in img_names]
        self.depths = [os.path.join(self.root, self.split, "depth", n) for n in img_names]

        if self.split == "train":
            self.targets = [os.path.join(self.root, self.split, "label", n) for n in img_names]
        else:
            self.targets = None

        self.image_transform = image_transform
        self.depth_transform = depth_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("RGB")
        depth = Image.open(self.depths[idx])

        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.depth_transform is not None:
            depth = self.depth_transform(depth)

        if self.split == "test":
            return (image, depth) if self.include_depth else image

        target = Image.open(self.targets[idx])
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, depth, target) if self.include_depth else (image, target)

# =========================
# Model blocks (ResNet50-UNet + optional SE)
# =========================
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=p, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.conv1 = ConvBNReLU(in_ch + skip_ch, out_ch)
        self.conv2 = ConvBNReLU(out_ch, out_ch)

    def forward(self, x, skip):
        x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class DepthStem(nn.Module):
    def __init__(self, z_ch=64):
        super().__init__()
        self.net = nn.Sequential(
            ConvBNReLU(1, 32, 3, 1),
            nn.MaxPool2d(2),
            ConvBNReLU(32, 48, 3, 1),
            nn.MaxPool2d(2),
            ConvBNReLU(48, z_ch, 3, 1),
            nn.MaxPool2d(2),
        )
    def forward(self, d):
        return self.net(d)

class SEFromDepth(nn.Module):
    def __init__(self, feat_ch, z_ch, reduction=16, alpha=0.05):
        super().__init__()
        self.alpha = float(alpha)
        hidden = max(32, feat_ch // reduction)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(z_ch, hidden, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, feat_ch, 1, bias=True),
            nn.Sigmoid(),
        )
    def forward(self, F, z):
        g = self.fc(z)  # [B,C,1,1]
        scale = 1.0 + self.alpha * (g - 0.5)
        return F * scale

class ResNet50UNet(nn.Module):
    def __init__(self, num_classes=13, coarse_classes=5, in_channels=4,
                 pretrained=False,
                 use_se=False, se_z_ch=64, se_reduction=16, se_alpha=0.05,
                 se_dec_stages=("up1",)):
        super().__init__()
        self.use_se = use_se
        self.se_dec_stages = tuple(se_dec_stages)

        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)

        old_conv = resnet.conv1
        if in_channels != old_conv.in_channels:
            new_conv = nn.Conv2d(
                in_channels, old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False
            )
            with torch.no_grad():
                new_conv.weight[:, :3] = old_conv.weight
                mean_w = old_conv.weight.mean(dim=1, keepdim=True)
                for c in range(3, in_channels):
                    new_conv.weight[:, c:c+1] = mean_w
            resnet.conv1 = new_conv

        self.enc0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)  # 1/2
        self.pool = resnet.maxpool                                        # 1/4
        self.enc1 = resnet.layer1                                         # 1/4
        self.enc2 = resnet.layer2                                         # 1/8
        self.enc3 = resnet.layer3                                         # 1/16
        self.enc4 = resnet.layer4                                         # 1/32

        self.center = ConvBNReLU(2048, 1024)
        self.up4 = UpBlock(1024, 1024, 512)
        self.up3 = UpBlock(512,  512,  256)
        self.up2 = UpBlock(256,  256,  128)
        self.up1 = UpBlock(128,  64,   64)

        self.head_feat = ConvBNReLU(64, 64)
        self.head_out  = nn.Conv2d(64, num_classes, kernel_size=1)

        self.coarse_head = nn.Conv2d(256, coarse_classes, kernel_size=1)
        self.boundary_head = nn.Sequential(ConvBNReLU(64, 32), nn.Conv2d(32, 1, 1))

        if self.use_se:
            self.depth_stem = DepthStem(z_ch=se_z_ch)
            stage_to_ch = {"center":1024, "up4":512, "up3":256, "up2":128, "up1":64}
            self.se_dec = nn.ModuleDict()
            for st in self.se_dec_stages:
                self.se_dec[st] = SEFromDepth(stage_to_ch[st], se_z_ch, se_reduction, se_alpha)
        else:
            self.depth_stem = None
            self.se_dec = None

    def _apply_dec_se(self, feat, z, stage_name: str):
        if (not self.use_se) or (self.se_dec is None) or (stage_name not in self.se_dec):
            return feat
        z_ = nn.functional.interpolate(z, size=feat.shape[-2:], mode="bilinear", align_corners=False)
        return self.se_dec[stage_name](feat, z_)

    def forward(self, x, depth=None):
        c1 = self.enc0(x)
        t  = self.pool(c1)
        c2 = self.enc1(t)
        c3 = self.enc2(c2)
        c4 = self.enc3(c3)
        c5 = self.enc4(c4)

        z = self.depth_stem(depth) if self.use_se else None

        x = self.center(c5)
        if z is not None: x = self._apply_dec_se(x, z, "center")
        x = self.up4(x, c4)
        if z is not None: x = self._apply_dec_se(x, z, "up4")
        x = self.up3(x, c3)
        if z is not None: x = self._apply_dec_se(x, z, "up3")

        coarse_logits = self.coarse_head(x)

        x = self.up2(x, c2)
        if z is not None: x = self._apply_dec_se(x, z, "up2")
        x = self.up1(x, c1)
        if z is not None: x = self._apply_dec_se(x, z, "up1")

        x = nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        feat = self.head_feat(x)
        fine_logits = self.head_out(feat)
        boundary_logit = self.boundary_head(x)
        return fine_logits, coarse_logits, boundary_logit

# =========================
# FIXED colorization (ID -> RGB)
# =========================
def id_to_rgb(mask_hw: np.ndarray, ignore_index=255, ignore_rgb=(0,0,0)) -> np.ndarray:
    """
    mask_hw: [H,W] int (0..12 or 255)
    returns: [H,W,3] uint8
    """
    out = np.zeros((mask_hw.shape[0], mask_hw.shape[1], 3), dtype=np.uint8)
    ign = (mask_hw == ignore_index)
    valid = ~ign
    m = mask_hw.copy()
    m[ign] = 0
    m = np.clip(m, 0, 12)
    out[valid] = PALETTE_13[m[valid]]
    out[ign] = np.array(ignore_rgb, dtype=np.uint8)
    return out

def denorm_rgb(rgb_tensor: torch.Tensor) -> Image.Image:
    mean = torch.tensor([0.485, 0.456, 0.406], device=rgb_tensor.device).view(3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=rgb_tensor.device).view(3,1,1)
    x = (rgb_tensor * std + mean).clamp(0, 1)
    x = (x * 255).byte().permute(1,2,0).cpu().numpy()
    return Image.fromarray(x, mode="RGB")

def overlay_error(rgb_pil: Image.Image, gt_hw: np.ndarray, pred_hw: np.ndarray, ignore_index=255) -> Image.Image:
    rgb = np.array(rgb_pil).astype(np.float32)
    err = (gt_hw != ignore_index) & (gt_hw != pred_hw)
    if err.any():
        overlay = rgb.copy()
        overlay[err] = 0.6 * overlay[err] + 0.4 * np.array([255, 0, 0], dtype=np.float32)
        overlay = np.clip(overlay, 0, 255).astype(np.uint8)
        return Image.fromarray(overlay, mode="RGB")
    return rgb_pil

def stack_h(images):
    widths = [im.width for im in images]
    heights = [im.height for im in images]
    out = Image.new("RGB", (sum(widths), max(heights)))
    x = 0
    for im in images:
        out.paste(im, (x, 0))
        x += im.width
    return out

# =========================
# Transforms (same as eval)
# =========================
eval_image_transform = Compose([
    Resize(IMAGE_SIZE, interpolation=InterpolationMode.BILINEAR),
    ToTensor(),
    Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)),
])

def depth_transform(pil):
    return depth_pil_to_tensor_01(pil, IMAGE_SIZE)

def target_transform(pil):
    pil = pil.resize((IMAGE_SIZE[1], IMAGE_SIZE[0]), resample=Image.NEAREST)
    return pil_label_to_long_tensor(pil)

# =========================
# Build val split (same indices rule)
# =========================
split_base = NYUv2(
    root=DATASET_ROOT,
    split="train",
    include_depth=True,
    image_transform=eval_image_transform,
    depth_transform=depth_transform,
    target_transform=target_transform,
)

n_total = len(split_base)
n_train = int(n_total * TRAIN_VAL_SPLIT)
n_val = n_total - n_train
g = torch.Generator().manual_seed(SEED)
_, val_subset = random_split(split_base, [n_train, n_val], generator=g)
val_ds = Subset(split_base, val_subset.indices)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"Val samples: {len(val_ds)}")

# =========================
# Load checkpoint + auto-detect SE
# =========================
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
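# Accept a raw state_dict, or a checkpoint dict that wraps it under "state_dict" or "model_state"
sd = ckpt
if isinstance(ckpt, dict):
    for key in ("state_dict", "model_state"):
        if isinstance(ckpt.get(key), dict):
            sd = ckpt[key]
            break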

has_se = any(k.startswith("depth_stem.") or k.startswith("se_dec.") for k in sd.keys())
print("Detected SE in checkpoint:", has_se)

model = ResNet50UNet(
    num_classes=13,
    coarse_classes=5,
    in_channels=4,
    pretrained=False,
    use_se=has_se,
    se_z_ch=64,
    se_reduction=16,
    se_alpha=0.05,
    se_dec_stages=("up1",),
).to(DEVICE)

model.load_state_dict(sd, strict=True)
model.eval()

# =========================
# Generate panels with FIXED palette
# =========================
os.makedirs(OUTDIR, exist_ok=True)

saved = 0
with torch.no_grad():
    for i, (rgb, depth, label) in enumerate(tqdm(val_loader, desc="Saving fixed-color panels")):
        rgb = rgb.to(DEVICE, non_blocking=True)
        depth = depth.to(DEVICE, non_blocking=True)
        label = label.to(DEVICE, non_blocking=True)

        x = torch.cat([rgb.float(), depth.float()], dim=1)
        fine_logits, _, _ = model(x, depth=depth.float() if has_se else None)
        pred = fine_logits.argmax(dim=1)

        rgb_pil = denorm_rgb(rgb[0])

        gt_hw = label[0].cpu().numpy().astype(np.int32)
        pred_hw = pred[0].cpu().numpy().astype(np.int32)

        gt_rgb = id_to_rgb(gt_hw, ignore_index=IGNORE_INDEX, ignore_rgb=(0,0,0))
        pr_rgb = id_to_rgb(pred_hw, ignore_index=IGNORE_INDEX, ignore_rgb=(0,0,0))

        gt_pil = Image.fromarray(gt_rgb, mode="RGB")
        pr_pil = Image.fromarray(pr_rgb, mode="RGB")
        err_pil = overlay_error(rgb_pil, gt_hw, pred_hw, ignore_index=IGNORE_INDEX)

        panel = stack_h([rgb_pil, gt_pil, pr_pil, err_pil])
        panel.save(os.path.join(OUTDIR, f"{i:06d}.png"), optimize=True)

        saved += 1
        if MAX_SAVE is not None and saved >= MAX_SAVE:
            break

print(f"Saved: {saved} images -> {OUTDIR}")

# =========================
# Zip
# =========================
with zipfile.ZipFile(ZIP_PATH, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zf:
    for fn in sorted(os.listdir(OUTDIR)):
        if fn.lower().endswith(".png"):
            zf.write(os.path.join(OUTDIR, fn), arcname=fn)

print(f"ZIP created: {ZIP_PATH}")


DEVICE: cuda
MODEL_PATH: /content/checkpoints/model_20260103062841.pt
Val samples: 80


FileNotFoundError: [Errno 2] No such file or directory: '/content/checkpoints/model_20260103062841.pt'
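`id_to_rgb` relies on NumPy fancy indexing: indexing a `(13, 3)` palette with an integer mask returns an `(H, W, 3)` image in one lookup, so GT and prediction get identical colors as long as both pass through the same table. The stripped-down sketch below shows the same idea; the 3-color palette and the 2×2 mask are made up for illustration.

```python
import numpy as np

IGNORE_INDEX = 255
# Toy palette for illustration only (classes 0, 1, 2); the notebook uses PALETTE_13.
PALETTE = np.array([[0, 0, 0], [255, 0, 0], [0, 255, 0]], dtype=np.uint8)


def colorize(mask: np.ndarray, palette: np.ndarray, ignore_index: int = IGNORE_INDEX,
             ignore_rgb=(0, 0, 0)) -> np.ndarray:
    """Map an [H, W] integer class mask to an [H, W, 3] uint8 RGB image."""
    ignored = mask == ignore_index
    # Swap ignored pixels for a valid index so the lookup cannot go out of bounds.
    safe = np.where(ignored, 0, mask)
    rgb = palette[safe]        # fancy indexing: [H, W] -> [H, W, 3] (a new array)
    rgb[ignored] = ignore_rgb  # paint ignore_index pixels with one fixed color
    return rgb


mask = np.array([[0, 1], [2, 255]])
print(colorize(mask, PALETTE))
```

One caveat in the notebook's setup: class 0 (wall) in `PALETTE_13` is `[0, 0, 0]`, the same as the default `ignore_rgb=(0, 0, 0)`, so wall pixels and ignored pixels look identical in the GT panel. Passing a distinct `ignore_rgb`, such as white, would separate them.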

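`ResNet50UNet` accepts 4-channel RGB-D input by replacing `resnet.conv1`. The three RGB filter slices are copied unchanged, and each extra input channel is initialized with the per-filter mean of the RGB slices, so the depth channel initially responds roughly like a grayscale intensity channel. The NumPy sketch below reproduces only that weight surgery. A random array shaped like a ResNet stem kernel `(64, 3, 7, 7)` stands in for the real ImageNet weights.

```python
import numpy as np


def inflate_conv_weight(w_rgb: np.ndarray, in_channels: int) -> np.ndarray:
    """Extend a [out, 3, kH, kW] conv kernel to [out, in_channels, kH, kW].

    Channels 0..2 keep the original RGB filters; every extra channel gets the
    per-filter mean over the RGB channels (the same rule as the conv1 patch).
    """
    out_ch, rgb_ch, kh, kw = w_rgb.shape
    assert rgb_ch == 3 and in_channels >= 3
    w_new = np.empty((out_ch, in_channels, kh, kw), dtype=w_rgb.dtype)
    w_new[:, :3] = w_rgb
    mean_w = w_rgb.mean(axis=1, keepdims=True)  # [out, 1, kH, kW]
    w_new[:, 3:] = mean_w                       # broadcast to every extra channel
    return w_new


rng = np.random.default_rng(0)
w_rgb = rng.standard_normal((64, 3, 7, 7)).astype(np.float32)
w_rgbd = inflate_conv_weight(w_rgb, in_channels=4)
print(w_rgbd.shape)
```

In the visualization cell, the model is built with `pretrained=False`, so the surgery only fixes the tensor shapes there. The actual conv1 weights come from the checkpoint via `load_state_dict`.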
In [None]:
%ls -l checkpoints
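The visualization cell above stops with `FileNotFoundError` because `MODEL_PATH` is hard-coded to a timestamped file that does not exist in this session. One way to avoid editing the path by hand is to pick the most recently modified `.pt` file in `checkpoints/`. The helper below is a hypothetical sketch (`latest_checkpoint` is not defined elsewhere in this notebook). It assumes checkpoints are saved as `*.pt` files in that directory.

```python
import glob
import os
from typing import Optional


def latest_checkpoint(ckpt_dir: str = "checkpoints", pattern: str = "*.pt") -> Optional[str]:
    """Return the most recently modified file matching `pattern` in `ckpt_dir`, or None."""
    paths = glob.glob(os.path.join(ckpt_dir, pattern))
    if not paths:
        return None
    # Sort key is the filesystem modification time, so the newest save wins.
    return max(paths, key=os.path.getmtime)


path = latest_checkpoint()
print("Latest checkpoint:", path if path else "none found")
```

Setting `MODEL_PATH = latest_checkpoint() or MODEL_PATH` in the visualization cell would fall back to the hard-coded path only when no `.pt` file is found.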

In [None]:
# ------------------
#    Evaluation (DeepLabV3-ResNet50 / ResNet backbone)
# ------------------

ckpt = torch.load(final_path, map_location=device)

# Also handle the case where training saved a checkpoint dict (containing "model_state")
if isinstance(ckpt, dict) and "model_state" in ckpt:
    model.load_state_dict(ckpt["model_state"])
else:
    model.load_state_dict(ckpt)

model.eval()

predictions = []

with torch.no_grad():
    print("Generating predictions...")
    for image, depth in tqdm(test_data):
        image = image.to(device, non_blocking=True)
        depth = depth.to(device, non_blocking=True)

        # Safety net: if depth accidentally has 3 channels, keep only the first one
        if depth.dim() == 4 and depth.size(1) == 3:
            depth = depth[:, :1, :, :]

        x = torch.cat((image, depth), dim=1)   # [B,4,H,W]

        # DeepLabV3 returns a dict, so use its "out" entry
        out = model(x)["out"]                  # [B,num_classes,H,W]
        pred = out.argmax(dim=1)               # [B,H,W]

        predictions.append(pred.cpu())

predictions = torch.cat(predictions, dim=0).numpy()
np.save("submission.npy", predictions)
print("Predictions saved to submission.npy")


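"""## How to submit

Zip the following three items and submit the archive on Omnicampus under "Final Assignment (Segmentation)" (「最終課題 (セグメンテーション)」).

- `submission.npy`
- The weights used for testing, such as `model.pt` or `model_best.pt` (`.pt` extension only)
- This Colab notebook
"""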

from zipfile import ZipFile, ZIP_DEFLATED

# Earlier path, kept for reference:
# notebook_path = "/content/drive/MyDrive/Colab Notebooks/DL_Basic_2025_Competition_NYUv2_baseline.ipynb"
notebook_path = "/content/drive/MyDrive/code.ipynb"

with ZipFile("submission.zip",
             mode="w",
             compression=ZIP_DEFLATED,
             compresslevel=9) as zf:
    zf.write("submission.npy")
    zf.write(model_path)
# The notebook is a required submission item; uncomment once notebook_path points to a saved copy.
#    zf.write(notebook_path,
#             arcname="code.ipynb")
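Before zipping, it can be worth checking that `submission.npy` matches the required format: a 3-D integer array `[num_test_images, H, W]` whose values are all class IDs in 0..12. A minimal sketch follows. The expected count of 654 comes from the dataset description. The 512×512 spatial size is an assumption based on the 512×512 resize in preprocessing, so adjust `hw` (or pass `hw=None`) if predictions are produced at another resolution.

```python
import os

import numpy as np


def check_submission(path: str = "submission.npy", num_images: int = 654,
                     hw=(512, 512), num_classes: int = 13) -> np.ndarray:
    """Load a submission array and raise ValueError if its format looks wrong.

    Pass hw=None to skip the spatial-size check.
    """
    pred = np.load(path)
    if pred.ndim != 3:
        raise ValueError(f"expected shape [N, H, W], got {pred.shape}")
    if pred.shape[0] != num_images:
        raise ValueError(f"expected {num_images} images, got {pred.shape[0]}")
    if hw is not None and tuple(pred.shape[1:]) != tuple(hw):
        raise ValueError(f"expected spatial size {tuple(hw)}, got {tuple(pred.shape[1:])}")
    if not np.issubdtype(pred.dtype, np.integer):
        raise ValueError(f"expected an integer dtype, got {pred.dtype}")
    if pred.size and (pred.min() < 0 or pred.max() >= num_classes):
        raise ValueError(
            f"class IDs must lie in [0, {num_classes - 1}], "
            f"got [{pred.min()}, {pred.max()}]"
        )
    return pred


# Only runs when the file exists, so the cell is safe to execute before inference.
if os.path.exists("submission.npy"):
    sub = check_submission()
    print("submission.npy looks OK:", sub.shape, sub.dtype)
```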

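Without `arcname`, `zf.write(model_path)` stores the file under its own path inside the archive (for a hypothetical `checkpoints/model.pt`, the entry would be `checkpoints/model.pt`, not `model.pt`). The sketch below writes every file at the top level of the archive via `arcname`. It also accepts only the three kinds of file the submission instructions list (`.npy`, `.pt`, `.ipynb`). `build_submission_zip` is a hypothetical helper name, and the extension rule is inferred from those instructions.

```python
import os
from zipfile import ZIP_DEFLATED, ZipFile

ALLOWED_EXTS = {".npy", ".pt", ".ipynb"}


def build_submission_zip(files, zip_path: str = "submission.zip") -> list:
    """Write `files` into `zip_path` flat (basename only) and return the archive listing."""
    names = [os.path.basename(f) for f in files]
    if len(set(names)) != len(names):
        raise ValueError(f"duplicate file names in archive: {names}")
    for f in files:
        if not os.path.isfile(f):
            raise FileNotFoundError(f)
        if os.path.splitext(f)[1].lower() not in ALLOWED_EXTS:
            raise ValueError(f"unexpected file type: {f}")
    with ZipFile(zip_path, mode="w", compression=ZIP_DEFLATED, compresslevel=9) as zf:
        for f, name in zip(files, names):
            zf.write(f, arcname=name)  # store at the archive root
    with ZipFile(zip_path) as zf:
        return zf.namelist()
```

In this notebook the call would be `build_submission_zip(["submission.npy", model_path, notebook_path])`, provided `notebook_path` points to an existing copy on Drive. The notebook would then be stored as `code.ipynb`, which matches the `arcname` in the commented-out line.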