<a href="https://colab.research.google.com/github/trie0000/external/blob/main/code%20_20260116%EF%BC%BF%EF%BC%90%EF%BC%99%EF%BC%94%EF%BC%91.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# -*- coding: utf-8 -*-
"""DL_Basic_2025_Competition_NYUv2_baseline.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/17t7uAU0aST5aUt6sIJCqFneGlQrxFrJY

# Deep Learning 基礎講座　最終課題: NYUv2 セマンティックセグメンテーションcoarse_w=0」

## 概要
RGB画像から、画像内の各ピクセルがどのクラスに属するかを予測するセマンティックセグメンテーションタスク.

### データセット
- データセット: NYUv2 dataset
- 訓練データ: 795枚
- テストデータ: 654枚
- 入力: RGB画像 + 深度マップ（元画像サイズは可変）
- 出力: 13クラスのセグメンテーションマップ
- 評価指標: Mean IoU (Intersection over Union)

### データセットの詳細（[NYU Depth Dataset V2](https://cs.nyu.edu/~fergus/datasets/nyu_depth_v2.html)）
- 画像は屋内シーンを撮影したもので、家具や壁、床などの物体が含まれています.
- 各画像に対して13クラスのセグメンテーションラベルが提供されます.
- データは以下のディレクトリ構造で提供:
```
data/NYUv2/
├─train/
│  ├─image/      # RGB画像
│  │    000000.png
│  │    ...
│  │
│  ├─depth/      # 深度マップ
│  │    000000.png
│  │    ...
│  │
│  └─label/      # 13クラスセグメンテーション（教師ラベル）
│       000000.png
│       ...
└─test/
   ├─image/      # RGB画像
   │    000000.png
   │    ...
   │  ├─depth/   # 深度マップ
   │    000000.png
   │    ...
```

### タスクの詳細
- 入力のRGB画像と深度マップから、各ピクセルが13クラスのどれに属するかを予測するタスクです.
- 評価はMean IoUを使用します．
  - 各クラスごとにIoUを計算し、その平均を取ります.
  - IoUは以下の式で計算:
  $$IoU = \frac{TP}{TP + FP + FN}$$
    - TP: True Positive（正しく予測されたピクセル数）
    - FP: False Positive（誤って予測されたピクセル数）
    - FN: False Negative（見逃したピクセル数）

### 前処理
- 入力画像は512×512にリサイズされます.
- ピクセル値は0-1に正規化されます.
- セグメンテーションラベルは0-12の整数値（13クラス）です．
  - 255はignore index（評価から除外）

### 提出形式
- テスト画像（RGB + Depth）の各ピクセルに対してクラス（0~12）を予測したものをnumpy配列として保存されます.
- ファイル名: `submission.npy`
- 配列の形状: [テストデータ数, 高さ, 幅]
- 各ピクセルの値: 0-12の整数（予測クラス）

## 考えられる工夫の例
- 事前学習モデルの fine-tuning
    - ImageNetなどで事前学習されたモデルを本データセットでfine-tuningすることで性能向上が見込めます.
- 損失関数の再設計
    - クラスごとの出現頻度に応じて損失を補正するように損失関数を設計すると、クラス分布の不均衡に対してロバストな学習ができます.
- 画像の前処理
    - RandomResizedCrop / Flip / ColorJitter 等のデータ拡張を追加することで，汎化性能の向上が見込めます．

## 修了要件を満たす条件
- ベースラインでは，omnicampus 上での性能評価において， 38.2% となります．したがって，ベースラインである 38.2% を超えた提出のみ，修了要件として認めます．
- ベースラインから改善を加えることで， 50%以上に性能向上することを運営で確認しています．こちらを 1つの指標として取り組んでみてください．

## 注意点
- 最終的な予測モデルは，**配布している訓練データを用いて学習**（ファインチューニング含む）したものとしてください．
- 学習を行わず，**事前学習済みモデルの知識のみを利用した推論は禁止**します．
（例: ChatGPT 等の LLM に入力して推論を得るのみ）

### 事前学習モデルの利用
許可される事項
- **構成要素としての事前学習モデルの利用**: 自身で実装したアーキテクチャの一部（特徴抽出，埋め込みなど）として事前学習モデル（BERT，ViT など）を利用することは可能です．
- **ファインチューニング**: 上記の用途で利用している事前学習モデルのファインチューニングは可能です．

禁止される事項
- **タスク解決用の事前学習モデルの利用**: transformers などで提供されている，対象タスクを直接解くための事前学習モデルでそのまま推論のみ，またはファインチューニングのみで利用することは禁止とします．
  - 禁止事項の例: VQA タスクを直接解くための事前学習モデルを VQA タスクで利用する．

### データの準備
データをダウンロードした際に，google drive したため，利用するために google drive をマウントする必要があります．また， drive 上で展開することができないため，/content ディレクトリ下にコピーし "data.zip" を展開します．
google drive 上に "data.zip" が配置されていない場合は実行できません．google drive 上に "data.zip" (**831MB**) を配置することが可能であれば，"data_download.ipynb" を先に実行してください．難しい場合は，omnicampus 演習環境を利用してください．．
"""

# omnicampus 上では 4 セル目まで実行不要
# ドライブのマウント
from google.colab import drive
drive.mount('/content/drive')

# データダウンロード用の notebook にてgoogle drive への保存後，
# 反映に時間がかかる可能性がありますので，google drive のマウント後，
# data.zip がディレクトリ内にあることを確認してから実行してください．
# data.zip を /content 下にコピーする
!cp "/content/drive/MyDrive/data.zip" "/content"

# Commented out IPython magic to ensure Python compatibility.
# カレントディレクトリ下のファイル群を確認
# data.zip が表示されれば問題ないです
# %ls

# データを解凍する
!unzip data.zip
!mkdir data
!mv train test data/

"""omnicampus 演習環境では，data_download.ipynb のマウント，zip 化，drive へのコピーを実行しないことで，"data.zip" を解凍した形で配置されます．したがって，data ディレクトリが存在するディレクトリをカレントディレクトリとするだけで良いです．


"""

# Commented out IPython magic to ensure Python compatibility.
# omnicampus 実行用
# 以下の例では/workspace/Segmentation/split_data_scripts/omnicampus に data ディレクトリがあると想定
# %cd /workspace/Segmentation/split_data_scripts_omnicampus

# omnicampus 実行用
!pip install h5py scikit-image

"""# import library"""

Mounted at /content/drive
Archive:  data.zip
  inflating: data/train/image/000600.png  
  inflating: data/train/image/000320.png  
  inflating: data/train/image/000491.png  
  inflating: data/train/image/000502.png  
  inflating: data/train/image/000129.png  
  inflating: data/train/image/000044.png  
  inflating: data/train/image/000652.png  
  inflating: data/train/image/000919.png  
  inflating: data/train/image/000528.png  
  inflating: data/train/image/000853.png  
  inflating: data/train/image/000177.png  
  inflating: data/train/image/000584.png  
  inflating: data/train/image/001319.png  
  inflating: data/train/image/000597.png  
  inflating: data/train/image/000223.png  
  inflating: data/train/image/001350.png  
  inflating: data/train/image/000404.png  
  inflating: data/train/image/000488.png  
  inflating: data/train/image/000268.png  
  inflating: data/train/image/000481.png  
  inflating: data/train/image/000341.png  
  inflating: data/train/image/000159.png  
  inflati

'# import library'

In [2]:
# =========================
# Cell 1: common (imports / utils / dataset / metrics / drive / zip-submission only)
# =========================
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import time
import json
import random
import shutil
import numpy as np
from zipfile import ZipFile, ZIP_DEFLATED
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, Dataset, Subset

import albumentations as A
import cv2
from torchvision import models
from torch.amp import autocast, GradScaler


# =========================
# Constants
# =========================
CLASS_NAMES = ["Bed", "Book", "Ceiling", "Chair", "Floor", "Cabinet", "Object", "Picture",
               "Sofa", "Desk", "TV", "Wall", "Window"]
NUM_CLASSES = 13
IGNORE_INDEX = 255

BOOK_ID = 1
CABINET_ID = 5
OBJECT_ID = 6

TRI_NAMES = ["Book", "Cabinet", "Object"]
TRI_NUM_CLASSES = 3
TRI_IGNORE = 255

BASE_TRI_IDS = [BOOK_ID, CABINET_ID, OBJECT_ID]


# =========================
# Utils
# =========================
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def now_ts():
    return time.strftime("%Y%m%d_%H%M%S")

def make_run_id(prefix: str):
    return f"{prefix}_{now_ts()}"

def ensure_dir(path: str):
    os.makedirs(path, exist_ok=True)
    return path

def must_exist(path: str, what: str):
    if (not path) or (not os.path.exists(path)):
        raise FileNotFoundError(f"[ERROR] {what} not found: {path}")
    return path

def write_json(path: str, obj: dict):
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)

def append_jsonl(path: str, obj: dict):
    with open(path, "a") as f:
        f.write(json.dumps(obj) + "\n")

def fmt_metrics_console(tag: str, epoch: int, miou: float, class_iou, class_precision, class_recall, names):
    print(f"[{tag}] Epoch {epoch:02d} | mIoU={miou:.5f}")
    header = f"{'Class':<10} {'IoU':>8} {'Prec':>8} {'Rec':>8}"
    print(header)
    print("-" * len(header))
    for i, name in enumerate(names):
        iou = class_iou[i]
        pr  = class_precision[i]
        rc  = class_recall[i]
        print(f"{name:<10} {iou:>8.3f} {pr:>8.3f} {rc:>8.3f}")

def compute_metrics_from_cm(cm: torch.Tensor):
    cm = cm.float()
    tp = torch.diag(cm)
    fp = cm.sum(dim=0) - tp
    fn = cm.sum(dim=1) - tp
    iou = tp / (tp + fp + fn + 1e-8)
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    return {
        "miou": iou.mean().item(),
        "class_iou": iou.cpu().tolist(),
        "class_precision": precision.cpu().tolist(),
        "class_recall": recall.cpu().tolist(),
    }

def update_cm(cm: torch.Tensor, pred: torch.Tensor, target: torch.Tensor, n_classes: int, ignore_index: int):
    pred = pred.detach().view(-1).to("cpu")
    target = target.detach().view(-1).to("cpu")
    mask = (target != ignore_index)
    cm += torch.bincount(
        target[mask] * n_classes + pred[mask],
        minlength=n_classes**2
    ).view(n_classes, n_classes)
    return cm

def estimate_height_from_depth(depth_np: np.ndarray) -> np.ndarray:
    """
    depth -> height 的な擬似特徴（元コード踏襲）
    """
    H, W = depth_np.shape
    y_grid = np.linspace(0, 1, H).reshape(H, 1).repeat(W, axis=1).astype(np.float32)
    height_map = y_grid * depth_np
    max_val = float(height_map.max())
    if max_val > 0:
        height_map /= max_val
    return height_map.astype(np.float32)


# =========================
# Drive
# =========================
def mount_drive():
    try:
        from google.colab import drive
        drive.mount("/content/drive")
    except Exception as e:
        print(f"[WARN] Drive mount skipped or failed: {e}")

def copy_to_drive_if_needed(src_path: str, dst_dir: str):
    """
    圧縮はしない。必要なファイルを個別にコピーするだけ。
    """
    if dst_dir is None:
        return
    ensure_dir(dst_dir)
    if os.path.isdir(src_path):
        # ディレクトリコピーが必要ならここで。今回の要件では原則不要。
        raise RuntimeError("copy_to_drive_if_needed: directory copy is disabled by design.")
    shutil.copy2(src_path, os.path.join(dst_dir, os.path.basename(src_path)))

def zip_submission_only(npy_path: str, zip_path: str):
    """
    submission.npy だけ zip（他の圧縮はしない）
    """
    print(f"[LONG] Start zipping submission only: {zip_path}")
    with ZipFile(zip_path, "w", ZIP_DEFLATED) as zf:
        zf.write(npy_path, arcname=os.path.basename(npy_path))
    print(f"[DONE] submission zip: {zip_path}")
    return zip_path


# =========================
# Dataset
# =========================
class NYUv2Dataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None, return_label=True):
        self.split = split
        self.transform = transform
        self.return_label = return_label

        src_split = 'train' if split in ['train', 'val'] else 'test'
        self.images_dir = os.path.join(root_dir, src_split, 'image')
        self.depths_dir = os.path.join(root_dir, src_split, 'depth')
        self.labels_dir = os.path.join(root_dir, src_split, 'label') if src_split == 'train' else None

        self.filenames = sorted([f for f in os.listdir(self.images_dir) if f.endswith('.png')])

        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std  = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        self.d_mean, self.d_std = 0.5, 0.25

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        fname = self.filenames[idx]
        rgb = np.array(Image.open(os.path.join(self.images_dir, fname)).convert('RGB'))
        depth = np.array(Image.open(os.path.join(self.depths_dir, fname)))
        if depth.ndim == 3:
            depth = depth[:, :, 0]

        depth = depth.astype(np.float32) / (65535.0 if depth.max() > 255 else 255.0)
        h_map = estimate_height_from_depth(depth)

        if self.labels_dir and self.return_label:
            label = np.array(Image.open(os.path.join(self.labels_dir, fname)))
        else:
            label = np.zeros(depth.shape, dtype=np.int32)

        if self.transform:
            augmented = self.transform(image=rgb, depth=depth, height=h_map, mask=label)
            rgb, depth, h_map, label = augmented['image'], augmented['depth'], augmented['height'], augmented['mask']

        rgb = (rgb.astype(np.float32) / 255.0 - self.mean) / self.std
        depth = (depth - self.d_mean) / self.d_std
        h_map = (h_map - self.d_mean) / self.d_std

        rgb_t = torch.from_numpy(rgb.transpose(2, 0, 1)).float()
        depth_t = torch.from_numpy(depth).unsqueeze(0).float()
        h_t = torch.from_numpy(h_map).unsqueeze(0).float()
        x = torch.cat([rgb_t, depth_t, h_t], dim=0)  # 5ch

        if self.split == 'test':
            return x, fname
        return x, torch.from_numpy(label).long()

class NYUv2TriDataset(Dataset):
    """
    TRI用: GTを {Book, Cabinet, Object} の3クラスに射影。
    それ以外は ignore。
    """
    def __init__(self, base_dataset: NYUv2Dataset):
        self.base = base_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x, y = self.base[idx]
        tri = torch.full_like(y, TRI_IGNORE)
        tri[y == BOOK_ID] = 0
        tri[y == CABINET_ID] = 1
        tri[y == OBJECT_ID] = 2
        return x, tri


# =========================
# Logger
# =========================
class RunLogger:
    def __init__(self, run_dir: str, run_id: str):
        self.run_dir = run_dir
        self.run_id = run_id
        self.logs_dir = ensure_dir(os.path.join(run_dir, "logs"))
        self.cfg_path = os.path.join(self.logs_dir, "config.json")
        self.train_log = os.path.join(self.logs_dir, "train_log.jsonl")
        self.val_log = os.path.join(self.logs_dir, "val_log.jsonl")
        self.batch_log = os.path.join(self.logs_dir, "batch_log.jsonl")
        self.batch_buf = []

    def save_config(self, cfg: dict):
        write_json(self.cfg_path, cfg)

    def log_batch(self, tag: str, epoch: int, batch: int, losses: dict):
        self.batch_buf.append({
            "run_id": self.run_id,
            "tag": tag,
            "epoch": epoch,
            "batch": batch,
            "losses": {k: float(v) for k, v in losses.items()}
        })

    def flush_batch(self):
        if not self.batch_buf:
            return
        with open(self.batch_log, "a") as f:
            for e in self.batch_buf:
                f.write(json.dumps(e) + "\n")
        self.batch_buf = []

    def log_epoch_train(self, tag: str, epoch: int, lr: float, avg_losses: dict):
        append_jsonl(self.train_log, {
            "run_id": self.run_id,
            "tag": tag,
            "epoch": epoch,
            "lr": float(lr),
            "losses": {k: float(v) for k, v in avg_losses.items()}
        })

    def log_epoch_val(self, tag: str, epoch: int, metrics: dict, names: list, cm_np=None):
        append_jsonl(self.val_log, {
            "run_id": self.run_id,
            "tag": tag,
            "epoch": epoch,
            "miou": float(metrics["miou"]),
            "class_iou": {names[i]: float(metrics["class_iou"][i]) for i in range(len(names))},
            "class_precision": {names[i]: float(metrics["class_precision"][i]) for i in range(len(names))},
            "class_recall": {names[i]: float(metrics["class_recall"][i]) for i in range(len(names))}
        })
        if cm_np is not None:
            np.save(os.path.join(self.logs_dir, f"confusion_matrix_{tag}_epoch_{epoch}.npy"), cm_np)


In [3]:
# =========================
# Cell 2: BASE (model + losses + train/eval + inference submission)
# =========================

# ---- Losses (BASE用：元コード踏襲)
class FocalLoss(nn.Module):
    def __init__(self, weight=None, gamma=2.0, ignore_index=IGNORE_INDEX):
        super().__init__()
        self.weight = weight
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, logits, target):
        ce = F.cross_entropy(logits, target, weight=self.weight,
                             ignore_index=self.ignore_index, reduction='none')
        pt = torch.exp(-ce)
        loss = ((1 - pt) ** self.gamma) * ce
        return loss.mean()

class ClassBalancedDiceLoss(nn.Module):
    def __init__(self, n_classes=NUM_CLASSES, smooth=1e-5, ignore_index=IGNORE_INDEX, class_weights=None):
        super().__init__()
        self.n_classes = n_classes
        self.smooth = smooth
        self.ignore_index = ignore_index
        self.class_weights = class_weights

    def forward(self, logits, target):
        prob = F.softmax(logits, dim=1)
        mask = (target != self.ignore_index)
        t = target.clone()
        t[~mask] = 0
        onehot = F.one_hot(t, self.n_classes).permute(0, 3, 1, 2).float()

        m = mask.unsqueeze(1).expand_as(prob)
        prob = prob * m
        onehot = onehot * m

        inter = (prob * onehot).sum(dim=(2, 3))
        union = prob.sum(dim=(2, 3)) + onehot.sum(dim=(2, 3))
        dice = (2 * inter + self.smooth) / (union + self.smooth)

        if self.class_weights is not None:
            dice_c = dice.mean(dim=0)
            w = self.class_weights.to(dice.device)
            weighted = (dice_c * w).sum() / (w.sum() + 1e-12)
            return 1 - weighted
        return 1 - dice.mean()

def make_boundary_mask(target: torch.Tensor, ignore_index=IGNORE_INDEX, dilate=2) -> torch.Tensor:
    valid = (target != ignore_index)
    t = target.clone()
    t[~valid] = -1

    up    = torch.zeros_like(t); up[:, 1:]  = t[:, :-1]
    down  = torch.zeros_like(t); down[:, :-1] = t[:, 1:]
    left  = torch.zeros_like(t); left[:, :, 1:] = t[:, :, :-1]
    right = torch.zeros_like(t); right[:, :, :-1] = t[:, :, 1:]

    edge = ((t != up) | (t != down) | (t != left) | (t != right)) & valid
    edge = edge.float().unsqueeze(1)
    for _ in range(max(0, dilate)):
        edge = F.max_pool2d(edge, kernel_size=3, stride=1, padding=1)
    return (edge > 0).float()

def dice_loss_binary(logits: torch.Tensor, target: torch.Tensor, eps=1e-6) -> torch.Tensor:
    prob = torch.sigmoid(logits)
    inter = (prob * target).sum(dim=(2,3))
    union = prob.sum(dim=(2,3)) + target.sum(dim=(2,3))
    dice = (2*inter + eps) / (union + eps)
    return 1 - dice.mean()

def ohem_cross_entropy(
    logits: torch.Tensor,
    target: torch.Tensor,
    weight: torch.Tensor,
    ignore_index=IGNORE_INDEX,
    min_kept=131072,
    thresh=0.7
) -> torch.Tensor:
    with torch.no_grad():
        prob = F.softmax(logits, dim=1)
        valid = (target != ignore_index)
        t_safe = target.clone()
        t_safe[~valid] = 0
        p_gt = prob.gather(1, t_safe.unsqueeze(1)).squeeze(1)
        hard = (p_gt < thresh) & valid

    loss = F.cross_entropy(logits, target, weight=weight, ignore_index=ignore_index, reduction='none')
    loss_valid = loss[valid]
    if loss_valid.numel() == 0:
        return loss.mean()

    loss_hard = loss[hard]
    if loss_hard.numel() >= min_kept:
        return torch.topk(loss_hard, k=min_kept, largest=True).values.mean()

    k = min(min_kept, loss_valid.numel())
    return torch.topk(loss_valid, k=k, largest=True).values.mean()


# ---- Model (BASE: ResNeXt101 + DeepLabV3+ OS=8) ※state_dict互換を崩さない
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1, d=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, k, padding=p, dilation=d, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class ASPP(nn.Module):
    def __init__(self, in_ch, out_ch=256, rates=(12, 24, 36)):
        super().__init__()
        r1, r2, r3 = rates
        self.b1 = ConvBNReLU(in_ch, out_ch, 1, 0, 1)
        self.b2 = ConvBNReLU(in_ch, out_ch, 3, r1, r1)
        self.b3 = ConvBNReLU(in_ch, out_ch, 3, r2, r2)
        self.b4 = ConvBNReLU(in_ch, out_ch, 3, r3, r3)
        self.b5 = nn.Sequential(nn.AdaptiveAvgPool2d(1), ConvBNReLU(in_ch, out_ch, 1, 0, 1))
        self.proj = nn.Sequential(ConvBNReLU(out_ch * 5, out_ch, 1, 0, 1), nn.Dropout(0.1))

    def forward(self, x):
        h, w = x.shape[2:]
        f1 = self.b1(x)
        f2 = self.b2(x)
        f3 = self.b3(x)
        f4 = self.b4(x)
        f5 = F.interpolate(self.b5(x), size=(h, w), mode='bilinear', align_corners=False)
        return self.proj(torch.cat([f1, f2, f3, f4, f5], dim=1))

class ResNeXtDeepLabV3Plus_OS8(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES, in_channels=5, aspp_rates=(12,24,36)):
        super().__init__()
        backbone = models.resnext101_32x8d(
            weights=models.ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
            replace_stride_with_dilation=[False, True, True],
        )

        old_conv = backbone.conv1
        new_conv = nn.Conv2d(in_channels, old_conv.out_channels, 7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            new_conv.weight[:, :3] = old_conv.weight
            mean_w = old_conv.weight.mean(dim=1, keepdim=True)
            new_conv.weight[:, 3:] = mean_w.repeat(1, in_channels-3, 1, 1)

        self.enc0 = nn.Sequential(new_conv, backbone.bn1, backbone.relu)
        self.pool = backbone.maxpool
        self.enc1 = backbone.layer1
        self.enc2 = backbone.layer2
        self.enc3 = backbone.layer3
        self.enc4 = backbone.layer4

        self.aspp = ASPP(2048, 256, rates=aspp_rates)
        self.low_proj = nn.Sequential(nn.Conv2d(256, 48, 1, bias=False), nn.BatchNorm2d(48), nn.ReLU(inplace=True))
        self.dec1 = ConvBNReLU(256 + 48, 256, 3, 1, 1)
        self.dec2 = ConvBNReLU(256, 256, 3, 1, 1)

        self.seg_head = nn.Conv2d(256, num_classes, 1)
        self.boundary_head = nn.Conv2d(256, 1, 1)

        self.aux_head = nn.Sequential(
            ConvBNReLU(1024, 256, 3, 1, 1),
            nn.Dropout(0.1),
            nn.Conv2d(256, num_classes, 1)
        )

    def forward(self, x, return_aux=False, return_boundary=False, return_feat=False):
        H, W = x.shape[2:]
        x0 = self.enc0(x)
        x1 = self.pool(x0)

        low = self.enc1(x1)
        x2 = self.enc2(low)
        mid = self.enc3(x2)
        x4 = self.enc4(mid)

        xA = self.aspp(x4)
        xA = F.interpolate(xA, size=low.shape[2:], mode='bilinear', align_corners=False)

        feat = torch.cat([xA, self.low_proj(low)], dim=1)
        feat = self.dec2(self.dec1(feat))  # decoder feat (H/4)

        seg = self.seg_head(feat)
        seg = F.interpolate(seg, size=(H, W), mode='bilinear', align_corners=False)

        outs = [seg]

        if return_boundary:
            bd = self.boundary_head(feat)
            bd = F.interpolate(bd, size=(H, W), mode='bilinear', align_corners=False)
            outs.append(bd)

        if return_aux and self.training:
            aux = self.aux_head(mid)
            aux = F.interpolate(aux, size=(H, W), mode='bilinear', align_corners=False)
            outs.append(aux)

        if return_feat:
            outs.append(feat)

        if len(outs) == 1:
            return outs[0]
        return tuple(outs)


# ---- BASE inference (TTA: heavy at final only, light at epochs as needed)
def tta_predict_logits(model, x, out_hw, img_size, scales=(1.0,), do_flip=True):
    """
    logits集約（softmax平均ではなく logits->softmax の平均を返す用途向け）
    ここでは「確率平均」を返す（[B,C,H,W]）
    """
    B = x.shape[0]
    C = NUM_CLASSES
    acc = torch.zeros((B, C, out_hw[0], out_hw[1]), device=x.device)
    n_aug = 0

    for s in scales:
        hs, ws = int(img_size * s), int(img_size * s)
        xs = F.interpolate(x, size=(hs, ws), mode='bilinear', align_corners=False)

        out = model(xs)  # [B,C,hs,ws] -> model内部で元解像度へ戻してないので注意
        out = F.interpolate(out, size=out_hw, mode='bilinear', align_corners=False)
        acc += F.softmax(out, dim=1); n_aug += 1

        if do_flip:
            xsf = torch.flip(xs, dims=[3])
            out_f = model(xsf)
            out_f = torch.flip(out_f, dims=[3])
            out_f = F.interpolate(out_f, size=out_hw, mode='bilinear', align_corners=False)
            acc += F.softmax(out_f, dim=1); n_aug += 1

    return acc / float(n_aug)


def run_base_train_and_submit(
    dataset_root="/content/data",
    img_size=768,
    batch_size=16,
    epochs=60,
    lr=1e-4,
    seed=42,
    phase1_epochs=40,
    save_cm_every=5,
    drive_dir=None,
    base_best_out="/content/base_best_model.pt",   # 既定：このパスへも更新
    run_mode_tag="base"
):
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    run_id = make_run_id("base")
    run_dir = ensure_dir(os.path.join("/content", run_id))
    logger = RunLogger(run_dir, run_id)

    # paths
    base_best = os.path.join(run_dir, "base_best_model.pt")
    base_final = os.path.join(run_dir, "base_final_model.pt")

    print("====================================================")
    print(f"[BASE] run_id: {run_id}")
    print(f"[BASE] run_dir: {run_dir}")
    print(f"[BASE] logs   : {logger.logs_dir}")
    print(f"[BASE] save best : {base_best}")
    print(f"[BASE] save final: {base_final}")
    print(f"[BASE] device: {device}")
    print("====================================================")

    # Aug
    train_aug = A.Compose([
        A.RandomResizedCrop(size=(img_size, img_size), scale=(0.70, 1.00), ratio=(0.90, 1.10),
                            interpolation=cv2.INTER_LINEAR, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, border_mode=cv2.BORDER_REFLECT_101, p=0.3),
        A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.15, rotate_limit=0,
                           border_mode=cv2.BORDER_REFLECT_101, p=0.3),
        A.ColorJitter(p=0.5),
    ], additional_targets={'depth': 'image', 'height': 'image'})

    val_aug = A.Compose([
        A.Resize(img_size, img_size, interpolation=cv2.INTER_LINEAR)
    ], additional_targets={'depth': 'image', 'height': 'image'})

    # Data split（毎回seed固定なら毎回同じ）
    full = NYUv2Dataset(dataset_root, split="train", transform=None)
    n_total = len(full)
    n_val = int(n_total * 0.1)
    indices = list(range(n_total))
    random.shuffle(indices)
    train_idx = indices[:-n_val]
    val_idx = indices[-n_val:]

    train_ds = NYUv2Dataset(dataset_root, split="train", transform=train_aug)
    val_ds = NYUv2Dataset(dataset_root, split="train", transform=val_aug)

    train_loader = DataLoader(Subset(train_ds, train_idx), batch_size=batch_size, shuffle=True,
                              num_workers=8, pin_memory=True, drop_last=True)
    val_loader = DataLoader(Subset(val_ds, val_idx), batch_size=8, shuffle=False,
                            num_workers=4, pin_memory=True)

    # Test
    test_ds = NYUv2Dataset(dataset_root, split="test", transform=val_aug, return_label=False)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

    print(f"[BASE] Train: {len(train_idx)}, Val: {len(val_idx)}, Test: {len(test_ds)}")
    print("[BASE] Model: ResNeXt101_32x8d + DeepLabV3+ (OS=8) + ASPP(12,24,36)")

    # weights（元コード踏襲）
    ce_weights = torch.tensor([1.0, 12.0, 0.6, 2.2,  0.6,  1.0,  2.0,  1.5,
                              1.5,  5.0,  3.0,  0.5,  1.0], device=device)
    dice_weights = torch.tensor([1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0,
                                 1.0, 2.0, 1.0, 1.0, 1.0], device=device)

    cfg = {
        "mode": "base",
        "dataset_root": dataset_root,
        "img_size": img_size,
        "batch_size": batch_size,
        "epochs": epochs,
        "lr": lr,
        "seed": seed,
        "phase1_epochs": phase1_epochs,
        "model": "ResNeXt101_32x8d + DeepLabV3+ OS=8",
        "aspp_rates": [12,24,36],
        "save_paths": {"best": base_best, "final": base_final},
    }
    logger.save_config(cfg)

    model = ResNeXtDeepLabV3Plus_OS8(num_classes=NUM_CLASSES, in_channels=5, aspp_rates=(12,24,36)).to(device)

    criterion_focal = FocalLoss(weight=ce_weights, gamma=2.0, ignore_index=IGNORE_INDEX)
    criterion_ce = nn.CrossEntropyLoss(weight=ce_weights, ignore_index=IGNORE_INDEX)
    criterion_dice = ClassBalancedDiceLoss(class_weights=dice_weights, ignore_index=IGNORE_INDEX)
    bce_boundary = nn.BCEWithLogitsLoss()

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler("cuda")

    best_miou = 0.0

    for epoch in range(1, epochs + 1):
        model.train()

        if epoch <= phase1_epochs:
            w_focal, w_ce, w_dice, w_aux, w_bd, w_ohem = 0.25, 0.20, 0.40, 0.10, 0.05, 0.00
            use_ohem = False
        else:
            w_focal, w_ce, w_dice, w_aux, w_bd, w_ohem = 0.15, 0.00, 0.35, 0.10, 0.20, 0.20
            use_ohem = True

        epoch_losses = {"focal": 0.0, "ce": 0.0, "ohem_ce": 0.0, "dice": 0.0, "aux": 0.0, "boundary": 0.0, "total": 0.0}

        pbar = tqdm(train_loader, desc=f"[BASE] Epoch {epoch}/{epochs}")
        for bi, (x, y) in enumerate(pbar):
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            with autocast("cuda"):
                seg, bd, aux = model(x, return_aux=True, return_boundary=True)

                loss_focal = criterion_focal(seg, y)
                loss_dice  = criterion_dice(seg, y)

                if use_ohem:
                    loss_ohem = ohem_cross_entropy(seg, y, weight=ce_weights, ignore_index=IGNORE_INDEX,
                                                   min_kept=131072, thresh=0.7)
                    loss_ce = torch.tensor(0.0, device=device)
                else:
                    loss_ce = criterion_ce(seg, y)
                    loss_ohem = torch.tensor(0.0, device=device)

                loss_aux = criterion_ce(aux, y)

                bd_gt = make_boundary_mask(y, ignore_index=IGNORE_INDEX, dilate=2)
                loss_bd = bce_boundary(bd, bd_gt) + dice_loss_binary(bd, bd_gt)

                loss = (w_focal * loss_focal +
                        w_ce    * loss_ce +
                        w_ohem  * loss_ohem +
                        w_dice  * loss_dice +
                        w_aux   * loss_aux +
                        w_bd    * loss_bd)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            batch_losses = {
                "focal": float(loss_focal.item()),
                "ce": float(loss_ce.item()),
                "ohem_ce": float(loss_ohem.item()),
                "dice": float(loss_dice.item()),
                "aux": float(loss_aux.item()),
                "boundary": float(loss_bd.item()),
                "total": float(loss.item())
            }
            logger.log_batch(run_mode_tag, epoch, bi, batch_losses)

            for k in epoch_losses:
                epoch_losses[k] += batch_losses[k]

            pbar.set_postfix({
                "loss": f"{batch_losses['total']:.4f}",
                "dice": f"{batch_losses['dice']:.4f}",
                "bd": f"{batch_losses['boundary']:.3f}",
                "ph": "1" if epoch <= phase1_epochs else "2"
            })

        logger.flush_batch()

        n_batches = len(train_loader)
        avg_losses = {k: v / n_batches for k, v in epoch_losses.items()}
        lr_now = optimizer.param_groups[0]['lr']
        logger.log_epoch_train(run_mode_tag, epoch, lr_now, avg_losses)
        scheduler.step()

        # Val
        model.eval()
        cm = torch.zeros((NUM_CLASSES, NUM_CLASSES), dtype=torch.long)
        with torch.no_grad():
            for x, y in val_loader:
                x = x.to(device, non_blocking=True)
                seg = model(x)
                pred = seg.argmax(1).to("cpu")
                cm = update_cm(cm, pred, y.to("cpu"), NUM_CLASSES, IGNORE_INDEX)

        m = compute_metrics_from_cm(cm)
        miou = m["miou"]

        save_cm = (epoch % save_cm_every == 0) or (miou > best_miou)
        logger.log_epoch_val(run_mode_tag, epoch, m, CLASS_NAMES, cm.numpy() if save_cm else None)
        fmt_metrics_console("BASE", epoch, miou, m["class_iou"], m["class_precision"], m["class_recall"], CLASS_NAMES)

        if miou > best_miou:
            best_miou = miou
            torch.save(model.state_dict(), base_best)
            print(f"  -> New best BASE mIoU: {best_miou:.5f}")

    # Save final
    torch.save(model.state_dict(), base_final)

    # /content/base_best_model.pt 更新（要件通り）
    if os.path.exists(base_best_out):
        # 既存があっても上書きする（運用通り）
        pass
    shutil.copy2(base_best, base_best_out)
    print(f"[BASE] updated {base_best_out}")

    # ---- Submission (heavy TTA)
    best_path = base_best if os.path.exists(base_best) else base_final
    print(f"[LONG] Start generating submission (BASE heavy TTA) using: {best_path}")
    model.load_state_dict(torch.load(best_path, map_location=device))
    model.eval()

    tta_scales = (0.75, 1.0, 1.25, 1.5)
    preds = []
    with torch.no_grad():
        for x, _ in tqdm(test_loader, desc="[BASE] test inference"):
            x = x.to(device, non_blocking=True)
            prob = tta_predict_logits(model, x, out_hw=(512,512), img_size=img_size, scales=tta_scales, do_flip=True)
            pred = prob.argmax(1).cpu().numpy().astype(np.uint8)  # [1,512,512]
            preds.append(pred)

    submission = np.concatenate(preds, axis=0)
    sub_path = os.path.join(run_dir, "submission.npy")
    np.save(sub_path, submission)
    print(f"[DONE] submission saved: {sub_path} shape={submission.shape} best_mIoU={best_miou:.5f}")

    # zip: submission.npy only
    zip_path = os.path.join(run_dir, "submission_only.zip")
    zip_submission_only(sub_path, zip_path)

    # Drive copy（必要なら）：submission_only.zip だけ
    if drive_dir is not None:
        print(f"[LONG] Start copying submission_only.zip to Drive: {drive_dir}")
        copy_to_drive_if_needed(zip_path, drive_dir)
        print(f"[DONE] Drive copy: {os.path.join(drive_dir, os.path.basename(zip_path))}")

    return {
        "run_id": run_id,
        "run_dir": run_dir,
        "best_miou": best_miou,
        "base_best": base_best,
        "base_final": base_final,
        "submission_npy": sub_path,
        "submission_zip": zip_path,
        "base_best_out": base_best_out,
    }


In [4]:
# =========================
# Cell 5 (Pure): Analyze "Book" Neurons with Masking
# =========================

def find_book_neurons_pure(dataset_root, top_k=256):
    print("[ANALYSIS] Probing ImageNet weights using ONLY 'Book' pixels...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. ImageNet学習済みモデル (Backboneのみ)
    # 構造を分解して layer4 の出力まで手動で流すのが確実です
    full_model = models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.IMAGENET1K_V1).to(device)
    full_model.eval()

    # 特徴抽出用サブモデル (layer4まで)
    class BackboneFeature(nn.Module):
        def __init__(self, original_model):
            super().__init__()
            self.m = original_model
        def forward(self, x):
            x = self.m.conv1(x)
            x = self.m.bn1(x)
            x = self.m.relu(x)
            x = self.m.maxpool(x)
            x = self.m.layer1(x)
            x = self.m.layer2(x)
            x = self.m.layer3(x)
            x = self.m.layer4(x) # [B, 2048, H/32, W/32]
            return x

    feature_extractor = BackboneFeature(full_model)

    # 2. Dataset
    ds = NYUv2Dataset(dataset_root, split="train", transform=None)

    activation_sum = torch.zeros(2048).to(device)
    count_book_pixels = 0

    print(f"[ANALYSIS] Scanning entire dataset with Masking...")

    with torch.no_grad():
        indices = list(range(len(ds)))

        for idx in tqdm(indices):
            x, y = ds[idx] # x:5ch, y:label

            if BOOK_ID not in y:
                continue

            # 入力画像 (3ch)
            rgb_t = x[:3, :, :].unsqueeze(0).to(device) # [1, 3, H, W]

            # 正解マスク (Book=1, 他=0)
            mask_t = (y == BOOK_ID).unsqueeze(0).float().to(device) # [1, H, W]

            # 高速化のためリサイズ (512x512)
            input_size = 512
            rgb_t = F.interpolate(rgb_t, size=(input_size, input_size), mode='bilinear', align_corners=False)

            # 特徴マップ取得 [1, 2048, 16, 16] (512/32 = 16)
            feat = feature_extractor(rgb_t)

            # マスクも特徴マップと同じサイズ(16x16)に縮小
            # nearestで縮小しないと「0.5の本」みたいなのができてボケるので注意
            mask_small = F.interpolate(mask_t.unsqueeze(0), size=feat.shape[2:], mode='nearest') # [1, 1, 16, 16]

            # ★ここが重要：本以外の場所をゼロにする
            masked_feat = feat * mask_small

            # 本の領域だけの平均を取る (Global Avg Pool on Mask)
            # sum(特徴量) / sum(画素数)
            pixel_count = mask_small.sum()
            if pixel_count > 0:
                # この画像における「本ニューロン」の平均活性値
                img_book_act = masked_feat.sum(dim=(2, 3)) / pixel_count # [1, 2048]
                activation_sum += img_book_act.squeeze(0)
                count_book_pixels += 1 # 画像枚数カウントとして使う

    if count_book_pixels == 0:
        return []

    # 3. 全画像の平均
    mean_acts = activation_sum / count_book_pixels

    # 4. ランキング
    top_values, top_indices = torch.topk(mean_acts, k=top_k)

    print(f"[RESULT] Top {top_k} Pure Book Channels:")
    print(top_indices.tolist())

    return top_indices.tolist()

# 実行
book_channel_indices = find_book_neurons_pure("/content/data", top_k=256)

[ANALYSIS] Probing ImageNet weights using ONLY 'Book' pixels...
Downloading: "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth" to /root/.cache/torch/hub/checkpoints/resnext101_32x8d-8ba56ff5.pth


100%|██████████| 340M/340M [00:01<00:00, 237MB/s]


[ANALYSIS] Scanning entire dataset with Masking...


100%|██████████| 795/795 [00:26<00:00, 29.98it/s]


[RESULT] Top 256 Pure Book Channels:
[1644, 661, 377, 1641, 564, 638, 651, 1953, 1376, 1535, 1258, 44, 1892, 439, 1524, 1695, 1531, 1765, 128, 67, 1448, 1923, 664, 199, 1759, 1828, 1631, 1702, 474, 1559, 1658, 1129, 1853, 104, 673, 1050, 1177, 1821, 467, 1180, 1307, 1278, 2014, 1212, 1751, 105, 508, 133, 1290, 1428, 301, 1545, 1724, 1571, 588, 965, 1872, 757, 574, 620, 1227, 850, 1083, 417, 412, 963, 1160, 1125, 1165, 891, 1925, 287, 1670, 1612, 1606, 1973, 938, 1963, 106, 1643, 1669, 1573, 454, 450, 1730, 1175, 1988, 103, 448, 420, 1930, 1106, 1496, 931, 1276, 1121, 526, 1749, 788, 351, 378, 1518, 131, 37, 741, 737, 1575, 1741, 613, 1966, 1735, 1119, 698, 1857, 235, 578, 473, 1928, 1408, 271, 540, 1059, 355, 2020, 223, 1592, 1495, 1568, 113, 1753, 712, 1646, 539, 1417, 826, 1321, 1400, 1285, 1538, 55, 294, 488, 1281, 246, 459, 1918, 1596, 1663, 667, 1343, 1058, 414, 1373, 1688, 627, 121, 1888, 1164, 1158, 1087, 411, 1577, 60, 88, 1755, 1654, 1878, 400, 1585, 1984, 184, 1603, 1523, 147

In [5]:
# =========================
# Cell 3: Definitions (Model, Metrics & Train Function with Full Class Logs)
# =========================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models
import os
import numpy as np
import logging
from tqdm import tqdm

# --- 定数定義 (User Specified) ---
CLASS_NAMES = ["Bed", "Book", "Ceiling", "Chair", "Floor", "Cabinet", "Object", "Picture",
               "Sofa", "Desk", "TV", "Wall", "Window"]
NUM_CLASSES = 13
IGNORE_INDEX = 255

BOOK_ID = 1      # 13クラス定義準拠
CABINET_ID = 5
OBJECT_ID = 6
DESK_ID = 9

# --- 0. Logger Setup (修正: ファイルパスをコンソール表示) ---
def setup_logger(run_dir):
    logger = logging.getLogger(f"train_logger_{run_dir}")
    logger.setLevel(logging.INFO)

    if logger.hasHandlers():
        logger.handlers.clear()

    logger.propagate = False

    formatter = logging.Formatter('[%(asctime)s] %(message)s')

    # ログファイルパス
    file_path = os.path.join(run_dir, "log.txt")

    # ファイルハンドラ
    file_handler = logging.FileHandler(file_path)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    # ストリームハンドラ (コンソール)
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    # ログファイルの場所をコンソールに表示
    print(f"[OUTPUT] Log file created at: {file_path}")

    return logger

# --- 1. Metric Helper ---
class IoUMeter:
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        self.confusion_matrix = np.zeros((self.num_classes, self.num_classes), dtype=np.int64)

    def update(self, pred, target):
        if isinstance(pred, torch.Tensor): pred = pred.cpu().numpy()
        if isinstance(target, torch.Tensor): target = target.cpu().numpy()
        pred = pred.flatten()
        target = target.flatten()
        mask = (target != IGNORE_INDEX)
        pred = pred[mask]
        target = target[mask]
        valid_indices = target * self.num_classes + pred
        counts = np.bincount(valid_indices, minlength=self.num_classes**2)
        self.confusion_matrix += counts.reshape(self.num_classes, self.num_classes)

    def get_metrics(self):
        cm = self.confusion_matrix
        tp = np.diag(cm)
        fp = cm.sum(axis=0) - tp
        fn = cm.sum(axis=1) - tp
        iou = tp / (tp + fp + fn + 1e-10)
        precision = tp / (tp + fp + 1e-10)
        recall = tp / (tp + fn + 1e-10)
        return {"iou": iou, "miou": np.nanmean(iou), "precision": precision, "recall": recall, "cm": cm}

# --- 2. Model Definition (13-Class Boost) ---
class TriDeepLabV3Plus_Boosted(nn.Module):
    def __init__(self, book_indices, in_channels=5, out_classes=13):
        super().__init__()
        self.book_indices = book_indices
        self.backbone = models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.IMAGENET1K_V1)

        if in_channels != 3:
            old_conv = self.backbone.conv1
            self.backbone.conv1 = nn.Conv2d(in_channels, old_conv.out_channels,
                                            kernel_size=old_conv.kernel_size, stride=old_conv.stride,
                                            padding=old_conv.padding, bias=old_conv.bias)
            with torch.no_grad():
                self.backbone.conv1.weight[:, :3] = old_conv.weight
                self.backbone.conv1.weight[:, 3:] = old_conv.weight.mean(dim=1, keepdim=True).repeat(1, in_channels-3, 1, 1)

        self.layer0 = nn.Sequential(self.backbone.conv1, self.backbone.bn1, self.backbone.relu, self.backbone.maxpool)
        self.layer1 = self.backbone.layer1
        self.layer2 = self.backbone.layer2
        self.layer3 = self.backbone.layer3
        self.layer4 = self.backbone.layer4

        self.main_decoder = nn.Sequential(
            nn.Conv2d(2048, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(),
            nn.Conv2d(256, out_classes, kernel_size=1)
        )

        num_book_channels = len(book_indices)
        self.book_branch = nn.Sequential(
            nn.Conv2d(num_book_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        input_size = x.shape[-2:]
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        feat = self.layer4(x)

        main_out = self.main_decoder(feat)
        main_out = F.interpolate(main_out, size=input_size, mode='bilinear', align_corners=False)

        # Boost Map (Index=BOOK_ID)
        book_feat = feat[:, self.book_indices, :, :]
        book_att = self.book_branch(book_feat)
        book_att = F.interpolate(book_att, size=input_size, mode='bilinear', align_corners=False)

        boost_map = book_att.squeeze(1) * 2.0
        mask = torch.zeros_like(main_out)
        mask[:, BOOK_ID, :, :] = boost_map
        main_out = main_out + mask
        return main_out

# --- Helper: 統合関数 (Internal Validation用) ---
def fuse_conditional_boost_internal(base_logits, boost_logits, alpha=2.0):
    TARGET_IDS = [BOOK_ID, CABINET_ID, OBJECT_ID, DESK_ID]
    base_probs = F.softmax(base_logits, dim=1)
    base_pred = base_probs.argmax(dim=1)

    mask = torch.zeros_like(base_pred, dtype=torch.bool)
    for tid in TARGET_IDS:
        mask |= (base_pred == tid)

    final_logits = base_logits.clone()
    boost_book_score = boost_logits[:, BOOK_ID, :, :]
    final_logits[:, BOOK_ID, :, :][mask] += boost_book_score[mask] * alpha

    return F.softmax(final_logits, dim=1)

# --- 3. Training Function with Full Class Logs ---
def run_tri_train_boost(book_indices, dataset_root, img_size, batch_size, epochs, lr, seed, save_cm_every, fused_eval_every, fused_tta_light, BASE_MODEL_PATH):
    run_dir = "runs/tri_boost_experiment"
    os.makedirs(run_dir, exist_ok=True)
    logger = setup_logger(run_dir)

    logger.info(f"[START] Boost Training (Fused Val). ID: Book={BOOK_ID}, Cabinet={CABINET_ID}, Obj={OBJECT_ID}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    best_path = os.path.join(run_dir, "tri_best_model.pt")

    # Validation用 Base Model Load
    logger.info(f"Loading Base Model for Validation: {BASE_MODEL_PATH}")
    base_model = ResNeXtDeepLabV3Plus_OS8(num_classes=NUM_CLASSES, in_channels=5).to(device)
    base_model.load_state_dict(torch.load(BASE_MODEL_PATH, map_location=device))
    base_model.eval()
    for param in base_model.parameters():
        param.requires_grad = False

    # Dataset
    train_ds = NYUv2Dataset(dataset_root, split="train", transform=None)
    val_ds = NYUv2Dataset(dataset_root, split="val", transform=None)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4)

    # Boost Model
    model = TriDeepLabV3Plus_Boosted(book_indices=book_indices, in_channels=5, out_classes=NUM_CLASSES).to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

    iou_meter = IoUMeter(num_classes=NUM_CLASSES)
    best_fused_miou = 0.0

    for epoch in range(epochs):
        # --- Train Phase ---
        model.train()
        epoch_loss = 0
        for x, y in tqdm(train_loader, desc=f"Ep {epoch+1}/{epochs} [Train]"):
            x, y = x.to(device), y.to(device).long()
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        avg_train_loss = epoch_loss / len(train_loader)

        # --- Fused Validation Phase ---
        if (epoch + 1) % fused_eval_every == 0 or (epoch + 1) == epochs:
            model.eval()
            iou_meter.reset()
            val_loss = 0

            with torch.no_grad():
                for x, y in tqdm(val_loader, desc=f"Ep {epoch+1}/{epochs} [Val Fused]"):
                    x, y = x.to(device), y.to(device).long()

                    # Fuse
                    base_out = base_model(x)
                    boost_out = model(x)

                    loss = criterion(boost_out, y)
                    val_loss += loss.item()

                    fused_probs = fuse_conditional_boost_internal(base_out, boost_out)
                    fused_preds = fused_probs.argmax(dim=1)
                    iou_meter.update(fused_preds, y)

            avg_val_loss = val_loss / len(val_loader)
            metrics = iou_meter.get_metrics()
            miou = metrics["miou"]

            logger.info(f"\n----- Epoch {epoch+1} Result (Fused Base+Boost) -----")
            logger.info(f"Train Loss: {avg_train_loss:.4f} | Val Loss(Boost): {avg_val_loss:.4f}")
            logger.info(f"Fused mIoU  : {miou:.4f}")

            # ★追加: 全クラスのスコアを表示
            logger.info("--- Class-wise Performance ---")
            for i, name in enumerate(CLASS_NAMES):
                # インデックス範囲外参照防止 (CLASS_NAMESとNUM_CLASSESが一致している前提)
                if i < len(metrics['iou']):
                    logger.info(f"ID {i:2d} ({name:8s}) | IoU: {metrics['iou'][i]:.4f} | Prec: {metrics['precision'][i]:.4f} | Rec: {metrics['recall'][i]:.4f}")
            logger.info("------------------------------")

            # CM保存
            if (epoch + 1) % save_cm_every == 0:
                cm_path = os.path.join(run_dir, f"cm_epoch_{epoch+1}.npy")
                np.save(cm_path, metrics["cm"])
                print(f"[OUTPUT] Confusion Matrix saved to: {cm_path}")

            # Best Model保存
            if miou > best_fused_miou:
                best_fused_miou = miou
                torch.save(model.state_dict(), best_path)
                logger.info("★ Best Model Updated (Based on Fused Score)!")
                print(f"[OUTPUT] Best Model saved to: {best_path}")

            logger.info("================================================\n")

    return {"tri_best": best_path, "run_dir": run_dir}

In [None]:
# =========================
# Cell 4 - Part 1: Imports, Fusion & Submit Logic
# =========================
import torch
import torch.nn.functional as F
import os
import numpy as np
import shutil
import logging
import cv2
import albumentations as A
from torch.utils.data import DataLoader
from tqdm import tqdm

# 設定
RUN_MODE = "base"
NUM_CLASSES = 13
# ターゲットID定義 (User Defined)
BOOK_ID = 1
CABINET_ID = 5
OBJECT_ID = 6
DESK_ID = 9

# --- 外部用 Fusion 関数 (条件付きブースト) ---
def fuse_conditional_boost(base_logits, boost_logits, alpha=2.0):
    """
    Baseモデルが Book(1), Cabinet(5), Object(6), Desk(9) と予測した領域のみ、
    Boostモデルの Book(1) スコアを加算して強化する。
    """
    TARGET_IDS = [BOOK_ID, CABINET_ID, OBJECT_ID, DESK_ID]

    # Baseの予測（確率化前でもArgmaxは同じだが、念のためSoftmax経由）
    base_probs = F.softmax(base_logits, dim=1)
    base_pred = base_probs.argmax(dim=1)

    # ターゲット領域のマスク作成
    mask = torch.zeros_like(base_pred, dtype=torch.bool)
    for tid in TARGET_IDS:
        mask |= (base_pred == tid)

    final_logits = base_logits.clone()

    # Boostモデルの「本(ID=1)」のスコアを取り出す
    boost_book_score = boost_logits[:, BOOK_ID, :, :]

    # マスクされた領域のみ加算
    final_logits[:, BOOK_ID, :, :][mask] += boost_book_score[mask] * alpha

    return F.softmax(final_logits, dim=1)

# --- 提出用データ作成関数 ---
def run_fused13_submit_for_boosted(
    dataset_root, img_size, BASE_MODEL_PATH, TRI_MODEL_PATH,
    book_indices,
    out_dir, drive_dir
):
    ensure_dir(out_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n[SUBMIT] Generating submission (13-Class Fused)...")

    # 1. Load Base Model
    # ※クラス定義は Cell 1 等にある前提
    base_model = ResNeXtDeepLabV3Plus_OS8(num_classes=NUM_CLASSES, in_channels=5).to(device)
    base_model.load_state_dict(torch.load(BASE_MODEL_PATH, map_location=device))
    base_model.eval()

    # 2. Load Boost Model
    # ※クラス定義は Cell 3 にある前提
    tri_model = TriDeepLabV3Plus_Boosted(book_indices=book_indices, in_channels=5, out_classes=NUM_CLASSES).to(device)
    tri_model.load_state_dict(torch.load(TRI_MODEL_PATH, map_location=device))
    tri_model.eval()

    # 3. Test Loader
    test_ds = NYUv2Dataset(dataset_root, split="test", transform=A.Compose([
         A.Resize(img_size, img_size, interpolation=cv2.INTER_LINEAR)
    ], additional_targets={'depth': 'image', 'height': 'image'}), return_label=False)
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)

    scales = [0.5, 0.75, 1.0, 1.25, 1.5]
    preds = []

    with torch.no_grad():
        for x, _ in tqdm(test_loader, desc="[SUBMIT] Inference"):
            x = x.to(device, non_blocking=True)
            H, W = 512, 512

            # Logits累積用
            base_logits_sum = torch.zeros(1, NUM_CLASSES, H, W).to(device)
            boost_logits_sum = torch.zeros(1, NUM_CLASSES, H, W).to(device)

            # TTA Loop
            for scale in scales:
                h_s, w_s = int(H*scale), int(W*scale)
                x_s = F.interpolate(x, size=(h_s, w_s), mode='bilinear', align_corners=False)

                b_out = base_model(x_s)
                t_out = tri_model(x_s)

                # 元サイズに戻す
                b_out = F.interpolate(b_out, size=(H, W), mode='bilinear', align_corners=False)
                t_out = F.interpolate(t_out, size=(H, W), mode='bilinear', align_corners=False)

                base_logits_sum += b_out
                boost_logits_sum += t_out

                # Flip TTA
                x_f = torch.flip(x_s, [3])
                b_f = base_model(x_f)
                t_f = tri_model(x_f)
                b_f = torch.flip(F.interpolate(b_f, size=(H,W), mode='bilinear'), [3])
                t_f = torch.flip(F.interpolate(t_f, size=(H,W), mode='bilinear'), [3])
                base_logits_sum += b_f
                boost_logits_sum += t_f

            # 4. 統合 (Fusion)
            final_prob = fuse_conditional_boost(base_logits_sum, boost_logits_sum)
            pred = final_prob.argmax(dim=1).cpu().numpy().astype(np.uint8)
            preds.append(pred)

    # 5. 保存
    submission = np.concatenate(preds, axis=0)
    sub_path = os.path.join(out_dir, "submission.npy")
    np.save(sub_path, submission)
    print(f"[OUTPUT] Submission NPY saved to: {sub_path}")

    zip_path = os.path.join(out_dir, "submission_only.zip")
    zip_submission_only(sub_path, zip_path)
    print(f"[OUTPUT] Submission ZIP saved to: {zip_path}")

    # Drive Copy
    if drive_dir:
        ensure_dir(drive_dir)
        copy_to_drive_if_needed(zip_path, drive_dir)
        print(f"[OUTPUT] Copied submission to Drive: {drive_dir}")

    return zip_path

# =========================
# Cell 4 - Part 2: Main Execution
# =========================

def main():
    mount_drive()

    base_drive_root = "/content/drive/MyDrive/nyu_runs/base"
    tri_drive_root  = "/content/drive/MyDrive/nyu_runs/tri"

    # --- BASE MODE ---
    if RUN_MODE in ["base", "both"]:
        base_drive_dir = os.path.join(base_drive_root, make_run_id("base_out"))
        base_out = run_base_train_and_submit(
            dataset_root="/content/data",
            img_size=768,
            batch_size=16,
            epochs=60,
            lr=1e-4,
            seed=42,
            phase1_epochs=40,
            save_cm_every=5,
            drive_dir=base_drive_dir,
            base_best_out="/content/base_best_model.pt",
        )
        print("[BASE] summary:", base_out)

    # --- TRI MODE (Boost) ---
    if RUN_MODE in ["tri", "both"]:
        tri_drive_dir = os.path.join(tri_drive_root, make_run_id("tri_boost_out"))
        ensure_dir(tri_drive_dir)

        # 1. 分析 (13クラスのID=1を探す)
        print("\n[Step 1] Finding 'Book' Neurons (ID=1)...")
        try:
            # Cell 5 の分析関数を実行
            book_indices = find_book_neurons_pure("/content/data", top_k=256)
        except NameError:
            print("⚠ 'find_book_neurons_pure' not found. Using dummy indices.")
            book_indices = list(range(256))

        # 2. 学習 (Fused Validation付き)
        print(f"\n[Step 2] Training Boosted Model with Fused Validation...")
        # Cell 3 の学習関数を実行
        tri_out = run_tri_train_boost(
            book_indices=book_indices,
            dataset_root="/content/data",
            img_size=768,
            batch_size=16,
            epochs=25,
            lr=3e-4,
            seed=42,
            save_cm_every=5,
            fused_eval_every=1,  # ★重要: 毎エポック Fusedスコアを確認
            fused_tta_light=True,
            BASE_MODEL_PATH="/content/base_best_model.pt",
        )
        print("[TRI] summary:", tri_out)

        # 3. 提出
        submit_drive_dir = os.path.join(tri_drive_root, make_run_id("tri_submit"))
        submit_out_dir = os.path.join(tri_out["run_dir"], "submit_fused13")

        fused_submit = run_fused13_submit_for_boosted(
            dataset_root="/content/data",
            img_size=768,
            BASE_MODEL_PATH="/content/base_best_model.pt",
            TRI_MODEL_PATH=tri_out["tri_best"],
            book_indices=book_indices,
            out_dir=submit_out_dir,
            drive_dir=submit_drive_dir,
        )
        print("[FUSED13] submit:", fused_submit)

        # 4. バックアップ (ログとモデル)
        if tri_drive_dir:
            print("\n[BACKUP] Copying files to Drive...")
            try:
                # モデルのコピー
                dest_model = os.path.join(tri_drive_dir, "tri_best_model.pt")
                shutil.copy(tri_out["tri_best"], dest_model)
                print(f"[OUTPUT] Backup Model copied to: {dest_model}")

                # ログファイルのコピー
                dest_log = os.path.join(tri_drive_dir, "train_log.txt")
                shutil.copy(os.path.join(tri_out["run_dir"], "log.txt"), dest_log)
                print(f"[OUTPUT] Backup Log copied to: {dest_log}")
            except Exception as e:
                print(f"[WARN] Backup failed: {e}")

if __name__ == "__main__":
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
[BASE] run_id: base_20260106_002943
[BASE] run_dir: /content/base_20260106_002943
[BASE] logs   : /content/base_20260106_002943/logs
[BASE] save best : /content/base_20260106_002943/base_best_model.pt
[BASE] save final: /content/base_20260106_002943/base_final_model.pt
[BASE] device: cuda
[BASE] Train: 716, Val: 79, Test: 654
[BASE] Model: ResNeXt101_32x8d + DeepLabV3+ (OS=8) + ASPP(12,24,36)


  original_init(self, **validated_kwargs)
[BASE] Epoch 1/60: 100%|██████████| 44/44 [01:04<00:00,  1.46s/it, loss=0.9743, dice=0.8376, bd=1.265, ph=1]


[BASE] Epoch 01 | mIoU=0.34001
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.502    0.755    0.600
Book          0.101    0.110    0.554
Ceiling       0.000    0.856    0.000
Chair         0.375    0.509    0.587
Floor         0.530    0.871    0.575
Cabinet       0.246    0.501    0.326
Object        0.395    0.438    0.800
Picture       0.360    0.592    0.478
Sofa          0.230    0.530    0.288
Desk          0.189    0.210    0.651
TV            0.465    0.596    0.680
Wall          0.565    0.913    0.597
Window        0.463    0.683    0.589
  -> New best BASE mIoU: 0.34001


[BASE] Epoch 2/60: 100%|██████████| 44/44 [00:59<00:00,  1.34s/it, loss=0.9449, dice=0.8075, bd=1.183, ph=1]


[BASE] Epoch 02 | mIoU=0.44113
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.450    0.813    0.502
Book          0.143    0.153    0.681
Ceiling       0.217    0.625    0.250
Chair         0.474    0.712    0.586
Floor         0.813    0.947    0.852
Cabinet       0.397    0.583    0.554
Object        0.412    0.479    0.745
Picture       0.493    0.635    0.687
Sofa          0.522    0.709    0.664
Desk          0.372    0.563    0.522
TV            0.298    0.302    0.961
Wall          0.659    0.934    0.691
Window        0.488    0.630    0.684
  -> New best BASE mIoU: 0.44113


[BASE] Epoch 3/60: 100%|██████████| 44/44 [00:59<00:00,  1.35s/it, loss=0.7844, dice=0.7686, bd=1.152, ph=1]


[BASE] Epoch 03 | mIoU=0.46296
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.648    0.802    0.771
Book          0.126    0.132    0.726
Ceiling       0.277    0.769    0.302
Chair         0.299    0.822    0.320
Floor         0.787    0.984    0.797
Cabinet       0.375    0.549    0.542
Object        0.458    0.520    0.794
Picture       0.543    0.811    0.622
Sofa          0.437    0.807    0.488
Desk          0.352    0.485    0.564
TV            0.376    0.889    0.395
Wall          0.755    0.907    0.818
Window        0.585    0.777    0.703
  -> New best BASE mIoU: 0.46296


[BASE] Epoch 4/60: 100%|██████████| 44/44 [00:59<00:00,  1.36s/it, loss=0.6715, dice=0.7774, bd=1.090, ph=1]


[BASE] Epoch 04 | mIoU=0.52065
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.679    0.910    0.728
Book          0.148    0.152    0.835
Ceiling       0.564    0.719    0.724
Chair         0.465    0.509    0.843
Floor         0.881    0.959    0.915
Cabinet       0.382    0.738    0.442
Object        0.474    0.601    0.691
Picture       0.514    0.875    0.555
Sofa          0.364    0.534    0.533
Desk          0.321    0.403    0.611
TV            0.640    0.774    0.788
Wall          0.755    0.952    0.785
Window        0.582    0.619    0.905
  -> New best BASE mIoU: 0.52065


[BASE] Epoch 5/60: 100%|██████████| 44/44 [00:59<00:00,  1.34s/it, loss=0.6398, dice=0.7270, bd=1.080, ph=1]


[BASE] Epoch 05 | mIoU=0.57320
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.592    0.926    0.621
Book          0.152    0.156    0.848
Ceiling       0.644    0.833    0.739
Chair         0.540    0.603    0.839
Floor         0.845    0.986    0.856
Cabinet       0.467    0.712    0.576
Object        0.501    0.621    0.721
Picture       0.586    0.887    0.633
Sofa          0.561    0.718    0.720
Desk          0.363    0.499    0.571
TV            0.699    0.783    0.868
Wall          0.813    0.933    0.864
Window        0.688    0.798    0.834
  -> New best BASE mIoU: 0.57320


[BASE] Epoch 6/60: 100%|██████████| 44/44 [01:00<00:00,  1.36s/it, loss=0.5647, dice=0.7251, bd=1.007, ph=1]


[BASE] Epoch 06 | mIoU=0.56628
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.729    0.895    0.797
Book          0.160    0.182    0.563
Ceiling       0.611    0.879    0.667
Chair         0.552    0.817    0.630
Floor         0.901    0.976    0.921
Cabinet       0.487    0.706    0.610
Object        0.510    0.597    0.779
Picture       0.591    0.694    0.800
Sofa          0.466    0.828    0.515
Desk          0.395    0.494    0.663
TV            0.504    0.552    0.853
Wall          0.799    0.937    0.844
Window        0.657    0.762    0.826


[BASE] Epoch 7/60: 100%|██████████| 44/44 [00:59<00:00,  1.36s/it, loss=0.5978, dice=0.7025, bd=0.994, ph=1]


[BASE] Epoch 07 | mIoU=0.58966
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.622    0.877    0.681
Book          0.166    0.190    0.571
Ceiling       0.602    0.722    0.784
Chair         0.606    0.709    0.807
Floor         0.877    0.982    0.891
Cabinet       0.520    0.788    0.604
Object        0.510    0.597    0.777
Picture       0.582    0.781    0.695
Sofa          0.604    0.793    0.717
Desk          0.379    0.510    0.596
TV            0.679    0.802    0.817
Wall          0.825    0.938    0.873
Window        0.693    0.773    0.869
  -> New best BASE mIoU: 0.58966


[BASE] Epoch 8/60: 100%|██████████| 44/44 [00:59<00:00,  1.36s/it, loss=0.5468, dice=0.6822, bd=0.953, ph=1]


[BASE] Epoch 08 | mIoU=0.58815
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.697    0.877    0.773
Book          0.207    0.240    0.604
Ceiling       0.610    0.932    0.639
Chair         0.597    0.725    0.772
Floor         0.917    0.977    0.938
Cabinet       0.526    0.672    0.707
Object        0.521    0.623    0.761
Picture       0.658    0.896    0.712
Sofa          0.540    0.829    0.607
Desk          0.420    0.592    0.590
TV            0.508    0.980    0.513
Wall          0.819    0.937    0.866
Window        0.626    0.791    0.750


[BASE] Epoch 9/60: 100%|██████████| 44/44 [00:58<00:00,  1.34s/it, loss=0.5293, dice=0.6951, bd=0.930, ph=1]


[BASE] Epoch 09 | mIoU=0.61170
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.799    0.910    0.867
Book          0.176    0.188    0.720
Ceiling       0.591    0.806    0.689
Chair         0.611    0.812    0.712
Floor         0.919    0.963    0.953
Cabinet       0.517    0.725    0.643
Object        0.543    0.652    0.763
Picture       0.649    0.779    0.796
Sofa          0.625    0.810    0.733
Desk          0.402    0.543    0.608
TV            0.637    0.709    0.863
Wall          0.813    0.923    0.872
Window        0.671    0.867    0.748
  -> New best BASE mIoU: 0.61170


[BASE] Epoch 10/60: 100%|██████████| 44/44 [00:59<00:00,  1.35s/it, loss=0.5071, dice=0.6773, bd=0.944, ph=1]


[BASE] Epoch 10 | mIoU=0.58500
Class           IoU     Prec      Rec
-------------------------------------
Bed           0.634    0.927    0.668
Book          0.196    0.226    0.599
Ceiling       0.538    0.848    0.595
Chair         0.594    0.710    0.784
Floor         0.905    0.979    0.923
Cabinet       0.549    0.665    0.760
Object        0.544    0.651    0.767
Picture       0.595    0.818    0.686
Sofa          0.561    0.801    0.652
Desk          0.383    0.660    0.477
TV            0.613    0.679    0.864
Wall          0.829    0.937    0.878
Window        0.664    0.873    0.735


[BASE] Epoch 11/60:  50%|█████     | 22/44 [00:30<00:28,  1.28s/it, loss=0.5239, dice=0.7037, bd=0.941, ph=1]

In [None]:
# =========================
# Check Base Model Performance Alone
# =========================
def check_base_performance(dataset_root, model_path, img_size=768):
    print(f"[CHECK] Evaluating Base Model Alone: {model_path}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load Base Model
    base_model = ResNeXtDeepLabV3Plus_OS8(num_classes=13, in_channels=5).to(device)
    base_model.load_state_dict(torch.load(model_path, map_location=device))
    base_model.eval()

    # Dataset & Loader
    val_ds = NYUv2Dataset(dataset_root, split="val", transform=None)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)

    iou_meter = IoUMeter(num_classes=13)

    with torch.no_grad():
        for x, y in tqdm(val_loader, desc="Base Eval"):
            x, y = x.to(device), y.to(device).long()

            # Predict
            # 画像サイズに合わせてリサイズが必要ならここに挟む
            # ここでは簡易的にそのまま流す（学習時と同じ前提）
            out = base_model(x)
            pred = out.argmax(dim=1)
            iou_meter.update(pred, y)

    metrics = iou_meter.get_metrics()
    print(f"\n[Base Model Result]")
    print(f"mIoU: {metrics['miou']:.4f}")
    print(f"Book IoU: {metrics['iou'][1]:.4f}") # Book ID=1
    print("------------------------------------------------")

# 実行
check_base_performance("/content/data", "/content/base_best_model.pt")

[CHECK] Evaluating Base Model Alone: /content/base_best_model.pt


Base Eval: 100%|██████████| 795/795 [00:26<00:00, 29.89it/s]


[Base Model Result]
mIoU: 0.8627
Book IoU: 0.6601
------------------------------------------------





In [None]:
import json
import matplotlib.pyplot as plt
import glob
import os

def plot_learning_curves(log_file):
    epochs = []
    losses = []
    mious = []
    lrs = []

    # ログファイルの読み込み
    with open(log_file, 'r') as f:
        for line in f:
            data = json.loads(line)
            # ヘッダー行（info）はスキップ
            if "info" in data:
                continue

            epochs.append(data['epoch'])
            losses.append(data['loss'])
            mious.append(data['val_miou'])
            lrs.append(data['lr'])

    # プロットの作成
    fig, ax1 = plt.subplots(figsize=(12, 6))

    # Loss (左軸・赤)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Training Loss', color='tab:red')
    ax1.plot(epochs, losses, color='tab:red', label='Train Loss', marker='.')
    ax1.tick_params(axis='y', labelcolor='tab:red')
    ax1.grid(True, alpha=0.3)

    # mIoU (右軸・青)
    ax2 = ax1.twinx()
    ax2.set_ylabel('Validation mIoU', color='tab:blue')
    ax2.plot(epochs, mious, color='tab:blue', label='Val mIoU', marker='.')
    ax2.tick_params(axis='y', labelcolor='tab:blue')

    plt.title(f'Learning Curve: {os.path.basename(log_file)}')
    fig.tight_layout()
    plt.show()

    # 学習率の確認用プロット（Poly Schedulerが効いているか確認）
    plt.figure(figsize=(12, 3))
    plt.plot(epochs, lrs, color='orange')
    plt.title("Learning Rate Schedule")
    plt.xlabel("Epoch")
    plt.ylabel("LR")
    plt.grid(True, alpha=0.3)
    plt.show()

# 自動で最新のログファイルを見つけて描画
log_files = sorted(glob.glob("train_log_detailed_*.jsonl"))
if log_files:
    latest_log = log_files[-1]
    print(f"Plotting: {latest_log}")
    plot_learning_curves(latest_log)
else:
    print("No log file found.")

In [None]:
# -*- coding: utf-8 -*-
# NYUv2 val visualization (RGB | GT | Pred | RGB+Error) with FIXED palette -> ZIP
# - NO matplotlib colormap
# - GT and Pred share EXACT same palette mapping (ID -> RGB)
# - ignore_index(255) is shown as black by default

import os
import zipfile
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import VisionDataset
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from torchvision import models

# =========================
# EDIT HERE
# =========================
DATASET_ROOT = "/content/data"
IMAGE_SIZE = (512, 512)

MODEL_PATH = "/content/checkpoints/model_20260103062841.pt"  # ←ここだけ直す
OUTDIR = "val_viz_fixed"
ZIP_PATH = "val_viz_fixed.zip"

MAX_SAVE = 50  # Noneで全件
SEED = 42
TRAIN_VAL_SPLIT = 0.9

NUM_WORKERS = 0
BATCH_SIZE = 1

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IGNORE_INDEX = 255

print("DEVICE:", DEVICE)
print("MODEL_PATH:", MODEL_PATH)

# =========================
# FIXED PALETTE (ID -> RGB)
# ここが「色の意味」そのもの。GTもPredも必ずこれを通す。
# =========================
PALETTE_13 = np.array([
    [  0,   0,   0],   # 0 wall
    [128,  64,  32],   # 1 floor
    [  0, 128, 255],   # 2 cabinet
    [255,   0, 128],   # 3 bed
    [  0, 255,   0],   # 4 chair
    [255, 128,   0],   # 5 sofa
    [255, 255,   0],   # 6 table
    [128, 128, 128],   # 7 door
    [  0, 255, 255],   # 8 window
    [  0,  64, 128],   # 9 bookshelf
    [255,   0,   0],   # 10 picture
    [128,   0, 255],   # 11 counter
    [ 64, 255, 128],   # 12 desk
], dtype=np.uint8)

# =========================
# Dataset helpers
# =========================
def pil_label_to_long_tensor(lbl_pil: Image.Image) -> torch.Tensor:
    arr = np.array(lbl_pil, dtype=np.int64)
    return torch.from_numpy(arr).long()

def depth_pil_to_tensor_01(depth_pil: Image.Image, size_hw):
    depth_pil = depth_pil.resize((size_hw[1], size_hw[0]), resample=Image.BILINEAR)
    arr = np.array(depth_pil)
    if arr.dtype == np.uint16 or (arr.max() > 255):
        arr = arr.astype(np.float32) / 65535.0
    else:
        arr = arr.astype(np.float32) / 255.0
    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    arr = np.clip(arr, 0.0, 1.0)
    return torch.from_numpy(arr).unsqueeze(0)  # [1,H,W]

class NYUv2(VisionDataset):
    def __init__(self, root, split="train", include_depth=True,
                 image_transform=None, depth_transform=None, target_transform=None):
        super().__init__(root)
        assert split in ("train", "test")
        self.root = root
        self.split = split
        self.include_depth = include_depth

        images_dir = os.path.join(self.root, self.split, "image")
        img_names = sorted(os.listdir(images_dir))

        self.images = [os.path.join(images_dir, n) for n in img_names]
        self.depths = [os.path.join(self.root, self.split, "depth", n) for n in img_names]

        if self.split == "train":
            self.targets = [os.path.join(self.root, self.split, "label", n) for n in img_names]
        else:
            self.targets = None

        self.image_transform = image_transform
        self.depth_transform = depth_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("RGB")
        depth = Image.open(self.depths[idx])

        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.depth_transform is not None:
            depth = self.depth_transform(depth)

        if self.split == "test":
            return (image, depth) if self.include_depth else image

        target = Image.open(self.targets[idx])
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, depth, target) if self.include_depth else (image, target)

# =========================
# Model blocks (ResNet50-UNet + optional SE)
# =========================
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=p, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.conv1 = ConvBNReLU(in_ch + skip_ch, out_ch)
        self.conv2 = ConvBNReLU(out_ch, out_ch)

    def forward(self, x, skip):
        x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class DepthStem(nn.Module):
    def __init__(self, z_ch=64):
        super().__init__()
        self.net = nn.Sequential(
            ConvBNReLU(1, 32, 3, 1),
            nn.MaxPool2d(2),
            ConvBNReLU(32, 48, 3, 1),
            nn.MaxPool2d(2),
            ConvBNReLU(48, z_ch, 3, 1),
            nn.MaxPool2d(2),
        )
    def forward(self, d):
        return self.net(d)

class SEFromDepth(nn.Module):
    def __init__(self, feat_ch, z_ch, reduction=16, alpha=0.05):
        super().__init__()
        self.alpha = float(alpha)
        hidden = max(32, feat_ch // reduction)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(z_ch, hidden, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, feat_ch, 1, bias=True),
            nn.Sigmoid(),
        )
    def forward(self, F, z):
        g = self.fc(z)  # [B,C,1,1]
        scale = 1.0 + self.alpha * (g - 0.5)
        return F * scale

class ResNet50UNet(nn.Module):
    def __init__(self, num_classes=13, coarse_classes=5, in_channels=4,
                 pretrained=False,
                 use_se=False, se_z_ch=64, se_reduction=16, se_alpha=0.05,
                 se_dec_stages=("up1",)):
        super().__init__()
        self.use_se = use_se
        self.se_dec_stages = tuple(se_dec_stages)

        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)

        old_conv = resnet.conv1
        if in_channels != old_conv.in_channels:
            new_conv = nn.Conv2d(
                in_channels, old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False
            )
            with torch.no_grad():
                new_conv.weight[:, :3] = old_conv.weight
                mean_w = old_conv.weight.mean(dim=1, keepdim=True)
                for c in range(3, in_channels):
                    new_conv.weight[:, c:c+1] = mean_w
            resnet.conv1 = new_conv

        self.enc0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)  # 1/2
        self.pool = resnet.maxpool                                        # 1/4
        self.enc1 = resnet.layer1                                         # 1/4
        self.enc2 = resnet.layer2                                         # 1/8
        self.enc3 = resnet.layer3                                         # 1/16
        self.enc4 = resnet.layer4                                         # 1/32

        self.center = ConvBNReLU(2048, 1024)
        self.up4 = UpBlock(1024, 1024, 512)
        self.up3 = UpBlock(512,  512,  256)
        self.up2 = UpBlock(256,  256,  128)
        self.up1 = UpBlock(128,  64,   64)

        self.head_feat = ConvBNReLU(64, 64)
        self.head_out  = nn.Conv2d(64, num_classes, kernel_size=1)

        self.coarse_head = nn.Conv2d(256, coarse_classes, kernel_size=1)
        self.boundary_head = nn.Sequential(ConvBNReLU(64, 32), nn.Conv2d(32, 1, 1))

        if self.use_se:
            self.depth_stem = DepthStem(z_ch=se_z_ch)
            stage_to_ch = {"center":1024, "up4":512, "up3":256, "up2":128, "up1":64}
            self.se_dec = nn.ModuleDict()
            for st in self.se_dec_stages:
                self.se_dec[st] = SEFromDepth(stage_to_ch[st], se_z_ch, se_reduction, se_alpha)
        else:
            self.depth_stem = None
            self.se_dec = None

    def _apply_dec_se(self, feat, z, stage_name: str):
        if (not self.use_se) or (self.se_dec is None) or (stage_name not in self.se_dec):
            return feat
        z_ = nn.functional.interpolate(z, size=feat.shape[-2:], mode="bilinear", align_corners=False)
        return self.se_dec[stage_name](feat, z_)

    def forward(self, x, depth=None):
        c1 = self.enc0(x)
        t  = self.pool(c1)
        c2 = self.enc1(t)
        c3 = self.enc2(c2)
        c4 = self.enc3(c3)
        c5 = self.enc4(c4)

        z = self.depth_stem(depth) if self.use_se else None

        x = self.center(c5)
        if z is not None: x = self._apply_dec_se(x, z, "center")
        x = self.up4(x, c4)
        if z is not None: x = self._apply_dec_se(x, z, "up4")
        x = self.up3(x, c3)
        if z is not None: x = self._apply_dec_se(x, z, "up3")

        coarse_logits = self.coarse_head(x)

        x = self.up2(x, c2)
        if z is not None: x = self._apply_dec_se(x, z, "up2")
        x = self.up1(x, c1)
        if z is not None: x = self._apply_dec_se(x, z, "up1")

        x = nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        feat = self.head_feat(x)
        fine_logits = self.head_out(feat)
        boundary_logit = self.boundary_head(x)
        return fine_logits, coarse_logits, boundary_logit

# =========================
# FIXED colorization (ID -> RGB)
# =========================
def id_to_rgb(mask_hw: np.ndarray, ignore_index=255, ignore_rgb=(0,0,0)) -> np.ndarray:
    """
    mask_hw: [H,W] int (0..12 or 255)
    returns: [H,W,3] uint8
    """
    out = np.zeros((mask_hw.shape[0], mask_hw.shape[1], 3), dtype=np.uint8)
    ign = (mask_hw == ignore_index)
    valid = ~ign
    m = mask_hw.copy()
    m[ign] = 0
    m = np.clip(m, 0, 12)
    out[valid] = PALETTE_13[m[valid]]
    out[ign] = np.array(ignore_rgb, dtype=np.uint8)
    return out

def denorm_rgb(rgb_tensor: torch.Tensor) -> Image.Image:
    mean = torch.tensor([0.485, 0.456, 0.406], device=rgb_tensor.device).view(3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=rgb_tensor.device).view(3,1,1)
    x = (rgb_tensor * std + mean).clamp(0, 1)
    x = (x * 255).byte().permute(1,2,0).cpu().numpy()
    return Image.fromarray(x, mode="RGB")

def overlay_error(rgb_pil: Image.Image, gt_hw: np.ndarray, pred_hw: np.ndarray, ignore_index=255) -> Image.Image:
    rgb = np.array(rgb_pil).astype(np.float32)
    err = (gt_hw != ignore_index) & (gt_hw != pred_hw)
    if err.any():
        overlay = rgb.copy()
        overlay[err] = 0.6 * overlay[err] + 0.4 * np.array([255, 0, 0], dtype=np.float32)
        overlay = np.clip(overlay, 0, 255).astype(np.uint8)
        return Image.fromarray(overlay, mode="RGB")
    return rgb_pil

def stack_h(images):
    widths = [im.width for im in images]
    heights = [im.height for im in images]
    out = Image.new("RGB", (sum(widths), max(heights)))
    x = 0
    for im in images:
        out.paste(im, (x, 0))
        x += im.width
    return out

# =========================
# Transforms (same as eval)
# =========================
eval_image_transform = Compose([
    Resize(IMAGE_SIZE, interpolation=InterpolationMode.BILINEAR),
    ToTensor(),
    Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)),
])

def depth_transform(pil):
    return depth_pil_to_tensor_01(pil, IMAGE_SIZE)

def target_transform(pil):
    pil = pil.resize((IMAGE_SIZE[1], IMAGE_SIZE[0]), resample=Image.NEAREST)
    return pil_label_to_long_tensor(pil)

# =========================
# Build val split (same indices rule)
# =========================
split_base = NYUv2(
    root=DATASET_ROOT,
    split="train",
    include_depth=True,
    image_transform=eval_image_transform,
    depth_transform=depth_transform,
    target_transform=target_transform,
)

n_total = len(split_base)
n_train = int(n_total * TRAIN_VAL_SPLIT)
n_val = n_total - n_train
g = torch.Generator().manual_seed(SEED)
_, val_subset = random_split(split_base, [n_train, n_val], generator=g)
val_ds = Subset(split_base, val_subset.indices)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"Val samples: {len(val_ds)}")

# =========================
# Load checkpoint + auto-detect SE
# =========================
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
if isinstance(ckpt, dict) and "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
    sd = ckpt["state_dict"]
else:
    sd = ckpt

has_se = any(k.startswith("depth_stem.") or k.startswith("se_dec.") for k in sd.keys())
print("Detected SE in checkpoint:", has_se)

model = ResNet50UNet(
    num_classes=13,
    coarse_classes=5,
    in_channels=4,
    pretrained=False,
    use_se=has_se,
    se_z_ch=64,
    se_reduction=16,
    se_alpha=0.05,
    se_dec_stages=("up1",),
).to(DEVICE)

model.load_state_dict(sd, strict=True)
model.eval()

# =========================
# Generate panels with FIXED palette
# =========================
os.makedirs(OUTDIR, exist_ok=True)

saved = 0
with torch.no_grad():
    for i, (rgb, depth, label) in enumerate(tqdm(val_loader, desc="Saving fixed-color panels")):
        rgb = rgb.to(DEVICE, non_blocking=True)
        depth = depth.to(DEVICE, non_blocking=True)
        label = label.to(DEVICE, non_blocking=True)

        x = torch.cat([rgb.float(), depth.float()], dim=1)
        fine_logits, _, _ = model(x, depth=depth.float() if has_se else None)
        pred = fine_logits.argmax(dim=1)

        rgb_pil = denorm_rgb(rgb[0])

        gt_hw = label[0].cpu().numpy().astype(np.int32)
        pred_hw = pred[0].cpu().numpy().astype(np.int32)

        gt_rgb = id_to_rgb(gt_hw, ignore_index=IGNORE_INDEX, ignore_rgb=(0,0,0))
        pr_rgb = id_to_rgb(pred_hw, ignore_index=IGNORE_INDEX, ignore_rgb=(0,0,0))

        gt_pil = Image.fromarray(gt_rgb, mode="RGB")
        pr_pil = Image.fromarray(pr_rgb, mode="RGB")
        err_pil = overlay_error(rgb_pil, gt_hw, pred_hw, ignore_index=IGNORE_INDEX)

        panel = stack_h([rgb_pil, gt_pil, pr_pil, err_pil])
        panel.save(os.path.join(OUTDIR, f"{i:06d}.png"), optimize=True)

        saved += 1
        if MAX_SAVE is not None and saved >= MAX_SAVE:
            break

print(f"Saved: {saved} images -> {OUTDIR}")

# =========================
# Zip
# =========================
with zipfile.ZipFile(ZIP_PATH, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zf:
    for fn in sorted(os.listdir(OUTDIR)):
        if fn.lower().endswith(".png"):
            zf.write(os.path.join(OUTDIR, fn), arcname=fn)

print(f"ZIP created: {ZIP_PATH}")


In [None]:
%ls -l checkpoints

In [None]:
# ------------------
#    Evaluation (DeepLabV3-ResNet50 / ResNet backbone)
# ------------------

ckpt = torch.load(final_path, map_location=device)

# もし「学習で checkpoint dict（model_state入り）」で保存している場合にも対応
if isinstance(ckpt, dict) and "model_state" in ckpt:
    model.load_state_dict(ckpt["model_state"])
else:
    model.load_state_dict(ckpt)

model.eval()

predictions = []

with torch.no_grad():
    print("Generating predictions...")
    for image, depth in tqdm(test_data):
        image = image.to(device, non_blocking=True)
        depth = depth.to(device, non_blocking=True)

        # 念のため：depthが3ch事故のときは1chに落とす
        if depth.dim() == 4 and depth.size(1) == 3:
            depth = depth[:, :1, :, :]

        x = torch.cat((image, depth), dim=1)   # [B,4,H,W]

        # DeepLabV3 は dict を返すので "out" を使う
        out = model(x)["out"]                  # [B,num_classes,H,W]
        pred = out.argmax(dim=1)               # [B,H,W]

        predictions.append(pred.cpu())

predictions = torch.cat(predictions, dim=0).numpy()
np.save("submission.npy", predictions)
print("Predictions saved to submission.npy")


"""## 提出方法

以下の3点をzip化し，Omnicampusの「最終課題 (セグメンテーション)」から提出してください．

- `submission.npy`
- `model.pt`や`model_best.pt`など，テストに使用した重み（拡張子は`.pt`のみ）
- 本Colab Notebook
"""

from zipfile import ZipFile, ZIP_DEFLATED

notebook_path = "/content/drive/MyDrive/Colab Notebooks/DL_Basic_2025_Competition_NYUv2_baseline.ipynb"

from zipfile import ZipFile, ZIP_DEFLATED

notebook_path = "/content/drive/MyDrive/code.ipynb"

with ZipFile("submission.zip",
             mode="w",
             compression=ZIP_DEFLATED,
             compresslevel=9) as zf:
    zf.write("submission.npy")
    zf.write(model_path)
#    zf.write(notebook_path,
#             arcname="code.ipynb")
