
In [None]:
# -*- coding: utf-8 -*-
"""DL_Basic_2025_Competition_NYUv2_baseline.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/17t7uAU0aST5aUt6sIJCqFneGlQrxFrJY

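# Deep Learning Basics Course, Final Assignment: NYUv2 Semantic Segmentation

## Overview
A semantic segmentation task: given an RGB image, predict which class each pixel in the image belongs to.

### Dataset
- Dataset: NYUv2 dataset
- Training data: 795 images
- Test data: 654 images
- Input: RGB image + depth map (original image size varies)
- Output: 13-class segmentation map
- Evaluation metric: Mean IoU (Intersection over Union)

### Dataset details ([NYU Depth Dataset V2](https://cs.nyu.edu/~fergus/datasets/nyu_depth_v2.html))
- The images show indoor scenes and contain objects such as furniture, walls, and floors.
- A 13-class segmentation label is provided for each image.
- The data is provided in the following directory structure: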
```
data/NYUv2/
├─train/
│  ├─image/      # RGB images
│  │    000000.png
│  │    ...
│  │
│  ├─depth/      # depth maps
│  │    000000.png
│  │    ...
│  │
│  └─label/      # 13-class segmentation (ground-truth labels)
│       000000.png
│       ...
└─test/
   ├─image/      # RGB images
   │    000000.png
   │    ...
   │
   └─depth/      # depth maps
        000000.png
        ...
```

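### Task details
- The task is to predict, from the input RGB image and depth map, which of the 13 classes each pixel belongs to.
- Evaluation uses Mean IoU.
  - IoU is computed for each class, and the per-class values are averaged.
  - IoU is computed as:
  $$IoU = \frac{TP}{TP + FP + FN}$$
    - TP: True Positive (pixels of the class that are correctly predicted as that class)
    - FP: False Positive (pixels wrongly predicted as the class)
    - FN: False Negative (pixels of the class that were missed)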
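
The metric above can be sketched with a confusion matrix. This is a minimal NumPy sketch, not the official omnicampus scorer: the function name `mean_iou`, the toy arrays, and the handling of absent classes (skipped via NaN) are assumptions made for illustration.

```python
import numpy as np

NUM_CLASSES = 13
IGNORE_INDEX = 255  # pixels with this label are excluded from evaluation

def mean_iou(pred, target, num_classes=NUM_CLASSES, ignore_index=IGNORE_INDEX):
    # Hypothetical helper: returns (mIoU, per-class IoU) for integer label maps.
    pred = np.asarray(pred).reshape(-1).astype(np.int64)
    target = np.asarray(target).reshape(-1).astype(np.int64)
    valid = target != ignore_index
    # Confusion matrix: rows are ground-truth classes, columns are predictions.
    cm = np.bincount(
        target[valid] * num_classes + pred[valid], minlength=num_classes ** 2
    ).reshape(num_classes, num_classes)
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    denom = tp + fp + fn
    # Assumption: classes that never appear (denom == 0) are left out of the mean.
    iou = np.where(denom > 0, tp / np.maximum(denom, 1), np.nan)
    return float(np.nanmean(iou)), iou

# Toy example: three classes appear, one pixel is ignored.
target = np.array([[0, 0, 1], [1, 255, 2]])
pred = np.array([[0, 1, 1], [1, 0, 2]])
miou, per_class = mean_iou(pred, target)
print(round(miou, 4))  # 0.7222 (per-class IoU: 0.5, 2/3, 1.0)
```

In the toy example the ignored pixel contributes to neither FP nor FN, which is why class 0 keeps an IoU of 0.5 even though the prediction at that position is 0.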

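### Preprocessing
- Input images are resized to 512×512.
- Pixel values are normalized to the range 0-1.
- Segmentation labels are integer values 0-12 (13 classes).
  - 255 is the ignore index (excluded from evaluation)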
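
A minimal sketch of the preprocessing listed above, using NumPy only. The helper names `resize_nearest` and `preprocess` are hypothetical, and nearest-neighbour resizing for both image and label is an assumption made to keep the example dependency-free (the baseline's actual resize method is not stated here). Labels in particular need nearest-neighbour sampling so that no interpolated, invalid class IDs appear.

```python
import numpy as np

IGNORE_INDEX = 255

def resize_nearest(arr, out_h, out_w):
    # Nearest-neighbour resize by index sampling; works for (H, W) and (H, W, C).
    in_h, in_w = arr.shape[:2]
    rows = np.arange(out_h) * in_h // out_h
    cols = np.arange(out_w) * in_w // out_w
    return arr[rows][:, cols]

def preprocess(rgb_uint8, label, size=512):
    # Hypothetical helper: resize to size x size and scale RGB to [0, 1].
    rgb = resize_nearest(rgb_uint8, size, size).astype(np.float32) / 255.0
    lab = resize_nearest(label, size, size).astype(np.int64)
    return rgb, lab

# Toy inputs standing in for one sample (480 x 640 is only an example size).
rng = np.random.default_rng(0)
rgb = rng.integers(0, 256, size=(480, 640, 3), dtype=np.uint8)
label = rng.integers(0, 13, size=(480, 640)).astype(np.uint8)
label[:10] = IGNORE_INDEX  # mark a band of pixels as ignored

x, y = preprocess(rgb, label)
print(x.shape, y.shape, x.dtype, y.dtype)
```

Because the label is only sampled, never averaged, the resized label still contains nothing but the values 0-12 and 255.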

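### Submission format
- Predict a class (0 to 12) for every pixel of each test image (RGB + depth) and save the predictions as a NumPy array.
- File name: `submission.npy`
- Array shape: [number of test images, height, width]
- Value of each pixel: an integer 0-12 (the predicted class)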
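
A minimal sketch of writing a file in this format and checking it again after loading. The predictions are random placeholder data, only four images are generated to keep the demo small, the 512×512 size is taken from the preprocessing section, and the `uint8` dtype and temporary output directory are assumptions for illustration; follow the organizers' instructions for the exact dtype and resolution.

```python
import os
import tempfile
import numpy as np

N_DEMO = 4          # the real submission covers all 654 test images
H, W = 512, 512     # assumed output resolution
NUM_CLASSES = 13

# Placeholder predictions; in practice this is the argmax of the model output.
rng = np.random.default_rng(0)
preds = rng.integers(0, NUM_CLASSES, size=(N_DEMO, H, W), dtype=np.uint8)

out_path = os.path.join(tempfile.mkdtemp(), "submission.npy")
np.save(out_path, preds)

# Reload and verify shape and value range before uploading.
loaded = np.load(out_path)
assert loaded.shape == (N_DEMO, H, W)
assert loaded.min() >= 0 and loaded.max() <= NUM_CLASSES - 1
print(loaded.shape, loaded.dtype)
```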

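## Examples of possible improvements
- Fine-tuning a pretrained model
    - Fine-tuning a model pretrained on ImageNet or similar data on this dataset can be expected to improve performance.
- Redesigning the loss function
    - Designing the loss so that it is corrected according to how often each class appears makes training more robust to class imbalance.
- Image preprocessing
    - Adding data augmentation such as RandomResizedCrop / Flip / ColorJitter can be expected to improve generalization.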
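
One way to realize the frequency-based loss correction mentioned above is to derive per-class weights from pixel counts. The sketch below uses inverse-frequency weights normalized to a mean of 1 over the classes that are present; this particular formula, the helper name `inverse_frequency_weights`, and the toy labels are assumptions, not the baseline's method.

```python
import numpy as np

NUM_CLASSES = 13
IGNORE_INDEX = 255

def inverse_frequency_weights(labels, num_classes=NUM_CLASSES, ignore_index=IGNORE_INDEX):
    # Hypothetical helper: rarer classes receive larger weights.
    labels = np.asarray(labels).reshape(-1)
    labels = labels[labels != ignore_index].astype(np.int64)
    counts = np.bincount(labels, minlength=num_classes).astype(np.float64)
    freq = counts / max(counts.sum(), 1.0)
    weights = np.zeros(num_classes, dtype=np.float64)
    present = counts > 0
    weights[present] = 1.0 / freq[present]
    # Normalize so that the classes that are present average to 1.
    weights[present] /= weights[present].mean()
    return weights

# Toy labels: class 0 is 9x more frequent than class 1; five pixels are ignored.
toy_labels = np.array([0] * 90 + [1] * 10 + [IGNORE_INDEX] * 5)
w = inverse_frequency_weights(toy_labels)
print(np.round(w[:3], 3))  # class 0 -> 0.2, class 1 -> 1.8, absent class 2 -> 0.0
```

A vector like `w` could then be converted to a tensor and passed as the `weight` argument of a cross-entropy loss. Classes absent from the counted labels get weight 0 here, so in practice the counts should come from the full training set.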

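## Conditions for meeting the completion requirement
- The baseline scores 38.2% when evaluated on omnicampus. Only submissions that exceed this 38.2% baseline therefore count toward the completion requirement.
- The organizers have confirmed that improvements over the baseline can raise the score to 50% or more. Use this as one target to aim for.

## Notes
- The final prediction model must be **trained on the distributed training data** (fine-tuning included).
- Inference that **relies only on the knowledge of a pretrained model**, with no training, is prohibited.
(Example: simply feeding the input to an LLM such as ChatGPT and using its output.)

### Use of pretrained models
Permitted
- **Using a pretrained model as a component**: You may use a pretrained model (BERT, ViT, etc.) as part of an architecture you implement yourself (feature extraction, embedding, etc.).
- **Fine-tuning**: You may fine-tune a pretrained model that is used in this way.

Prohibited
- **Using a pretrained model built to solve the task**: Taking a pretrained model that directly solves the target task, such as those provided through transformers, and using it for inference only, or with fine-tuning only, is prohibited.
  - Example of a prohibited use: applying a pretrained model that directly solves VQA to the VQA task.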

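### Data preparation
Because the data was saved to Google Drive when it was downloaded, Google Drive must be mounted before the data can be used. The archive cannot be extracted on Drive, so "data.zip" is copied under the /content directory and extracted there.
This cannot be run if "data.zip" is not on Google Drive. If you are able to place "data.zip" (**831MB**) on Google Drive, run "data_download.ipynb" first. If that is difficult, use the omnicampus exercise environment.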
"""

# On omnicampus, the cells up to the 4th do not need to be run
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# After the data download notebook saves to Google Drive, the file may take
# a while to appear. After mounting Google Drive, confirm that data.zip is
# in the directory before running this cell.
# Copy data.zip to /content
!cp "/content/drive/MyDrive/data.zip" "/content"

# Commented out IPython magic to ensure Python compatibility.
# List the files in the current directory
# If data.zip is shown, everything is fine
# %ls

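# Extract the data
!unzip data.zip
# If the archive extracts train/ and test/ at the top level, move them under data/.
# (-p and the directory check make this a no-op when the archive already contains data/.)
!mkdir -p data
![ -d train ] && mv train test data/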

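"""In the omnicampus exercise environment, if you skip the mount, zip, and copy-to-Drive steps of data_download.ipynb, "data.zip" is placed in already-extracted form. You therefore only need to make the directory that contains the data directory the current directory.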


"""

# Commented out IPython magic to ensure Python compatibility.
# For running on omnicampus
# The example below assumes the data directory is in /workspace/Segmentation/split_data_scripts/omnicampus
# %cd /workspace/Segmentation/split_data_scripts_omnicampus

# For running on omnicampus
!pip install h5py scikit-image

"""# import library"""

Mounted at /content/drive
Archive:  data.zip
  inflating: data/train/image/000600.png  
  inflating: data/train/image/000320.png  
  inflating: data/train/image/000491.png  
  inflating: data/train/image/000502.png  
  inflating: data/train/image/000129.png  
  inflating: data/train/image/000044.png  
  inflating: data/train/image/000652.png  
  inflating: data/train/image/000919.png  
  inflating: data/train/image/000528.png  
  inflating: data/train/image/000853.png  
  inflating: data/train/image/000177.png  
  inflating: data/train/image/000584.png  
  inflating: data/train/image/001319.png  
  inflating: data/train/image/000597.png  
  inflating: data/train/image/000223.png  
  inflating: data/train/image/001350.png  
  inflating: data/train/image/000404.png  
  inflating: data/train/image/000488.png  
  inflating: data/train/image/000268.png  
  inflating: data/train/image/000481.png  
  inflating: data/train/image/000341.png  
  inflating: data/train/image/000159.png  
  inflati

'# import library'

In [None]:
# =========================
# Cell 1: Common Setup (Logger Enhanced)
# =========================
import os
import time
import json
import random
import shutil
import logging
import numpy as np
from zipfile import ZipFile, ZIP_DEFLATED
from PIL import Image
from tqdm import tqdm
import cv2
import albumentations as A

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import models
from torch.amp import autocast, GradScaler

# Environment settings
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# --- Constants ---
CLASS_NAMES = ["Bed", "Book", "Ceiling", "Chair", "Floor", "Cabinet", "Object", "Picture",
               "Sofa", "Desk", "TV", "Wall", "Window"]
NUM_CLASSES = 13
IGNORE_INDEX = 255

# Target IDs
BOOK_ID = 1
CABINET_ID = 5
OBJECT_ID = 6
DESK_ID = 9

# --- Utils ---
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def now_ts():
    return time.strftime("%Y%m%d_%H%M%S")

def make_run_id(prefix: str):
    return f"{prefix}_{now_ts()}"

def ensure_dir(path: str):
    os.makedirs(path, exist_ok=True)
    return path

def write_json(path: str, obj: dict):
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)

def append_jsonl(path: str, obj: dict):
    with open(path, "a") as f:
        f.write(json.dumps(obj) + "\n")

# --- Drive Utils ---
def mount_drive():
    try:
        from google.colab import drive
        drive.mount("/content/drive")
    except Exception as e:
        print(f"[WARN] Drive mount skipped or failed: {e}")

def copy_to_drive_if_needed(src_path: str, dst_dir: str):
    if dst_dir is None: return
    ensure_dir(dst_dir)
    if os.path.isdir(src_path):
        raise RuntimeError("Directory copy is disabled.")
    shutil.copy2(src_path, os.path.join(dst_dir, os.path.basename(src_path)))

def zip_submission_only(npy_path: str, zip_path: str):
    print(f"[LONG] Start zipping submission only: {zip_path}")
    with ZipFile(zip_path, "w", ZIP_DEFLATED) as zf:
        zf.write(npy_path, arcname=os.path.basename(npy_path))
    print(f"[DONE] submission zip: {zip_path}")
    return zip_path

# --- Logger & Metrics Formatting ---
class RunLogger:
    def __init__(self, run_dir: str, run_id: str):
        self.run_dir = run_dir
        self.run_id = run_id
        self.logs_dir = ensure_dir(os.path.join(run_dir, "logs"))
        self.cfg_path = os.path.join(self.logs_dir, "config.json")
        self.train_log = os.path.join(self.logs_dir, "train_log.jsonl")
        self.val_log = os.path.join(self.logs_dir, "val_log.jsonl")
        self.batch_log = os.path.join(self.logs_dir, "batch_log.jsonl")
        self.batch_buf = []

    def save_config(self, cfg: dict):
        write_json(self.cfg_path, cfg)

    def log_batch(self, tag: str, epoch: int, batch: int, losses: dict):
        self.batch_buf.append({
            "run_id": self.run_id, "tag": tag, "epoch": epoch, "batch": batch,
            "losses": {k: float(v) for k, v in losses.items()}
        })

    def flush_batch(self):
        if not self.batch_buf: return
        with open(self.batch_log, "a") as f:
            for e in self.batch_buf: f.write(json.dumps(e) + "\n")
        self.batch_buf = []

    def log_epoch_train(self, tag: str, epoch: int, lr: float, avg_losses: dict):
        append_jsonl(self.train_log, {
            "run_id": self.run_id, "tag": tag, "epoch": epoch, "lr": float(lr),
            "losses": {k: float(v) for k, v in avg_losses.items()}
        })

    def log_epoch_val(self, tag: str, epoch: int, metrics: dict, names: list, cm_np=None):
        append_jsonl(self.val_log, {
            "run_id": self.run_id, "tag": tag, "epoch": epoch,
            "miou": float(metrics["miou"]),
            "class_iou": {names[i]: float(metrics["class_iou"][i]) for i in range(len(names))},
            "class_precision": {names[i]: float(metrics["class_precision"][i]) for i in range(len(names))},
            "class_recall": {names[i]: float(metrics["class_recall"][i]) for i in range(len(names))}
        })
        # Important: save the confusion matrix (NumPy array)
        if cm_np is not None:
            np.save(os.path.join(self.logs_dir, f"confusion_matrix_{tag}_epoch_{epoch}.npy"), cm_np)

def setup_console_logger(run_dir):
    """Simple logger for boost training."""
    logger = logging.getLogger(f"train_logger_{run_dir}")
    logger.setLevel(logging.INFO)
    if logger.hasHandlers(): logger.handlers.clear()
    logger.propagate = False
    formatter = logging.Formatter('[%(asctime)s] %(message)s')
    file_path = os.path.join(run_dir, "log.txt")
    file_handler = logging.FileHandler(file_path); file_handler.setFormatter(formatter)
    stream_handler = logging.StreamHandler(); stream_handler.setFormatter(formatter)
    logger.addHandler(file_handler); logger.addHandler(stream_handler)
    print(f"[OUTPUT] Log file created at: {file_path}")
    return logger

# --- Metrics ---
def compute_metrics_from_cm(cm: torch.Tensor):
    cm = cm.float()
    tp = torch.diag(cm)
    fp = cm.sum(dim=0) - tp
    fn = cm.sum(dim=1) - tp
    iou = tp / (tp + fp + fn + 1e-8)
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    return {
        "miou": iou.mean().item(), "class_iou": iou.cpu().tolist(),
        "class_precision": precision.cpu().tolist(), "class_recall": recall.cpu().tolist(),
    }

def update_cm(cm: torch.Tensor, pred: torch.Tensor, target: torch.Tensor, n_classes: int, ignore_index: int):
    pred = pred.detach().view(-1).to("cpu")
    target = target.detach().view(-1).to("cpu")
    mask = (target != ignore_index)
    cm += torch.bincount(target[mask] * n_classes + pred[mask], minlength=n_classes**2).view(n_classes, n_classes)
    return cm

# Added: helper that prints the metrics to the console in a readable table
def fmt_metrics_console(tag: str, epoch: int, miou: float, class_iou, class_precision, class_recall, names):
    print(f"[{tag}] Epoch {epoch:02d} | mIoU={miou:.5f}")
    # Header
    header = f"{'Class':<10} {'IoU':>8} {'Prec':>8} {'Rec':>8}"
    print(header)
    print("-" * len(header))
    # Per-class scores
    for i, name in enumerate(names):
        print(f"{name:<10} {class_iou[i]:>8.3f} {class_precision[i]:>8.3f} {class_recall[i]:>8.3f}")

class IoUMeter:
    """For boost training (NumPy-based)."""
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.reset()
    def reset(self):
        self.confusion_matrix = np.zeros((self.num_classes, self.num_classes), dtype=np.int64)
    def update(self, pred, target):
        if isinstance(pred, torch.Tensor): pred = pred.cpu().numpy()
        if isinstance(target, torch.Tensor): target = target.cpu().numpy()
        pred = pred.flatten(); target = target.flatten()
        mask = (target != IGNORE_INDEX)
        valid_indices = target[mask] * self.num_classes + pred[mask]
        self.confusion_matrix += np.bincount(valid_indices, minlength=self.num_classes**2).reshape(self.num_classes, self.num_classes)
    def get_metrics(self):
        cm = self.confusion_matrix
        tp = np.diag(cm)
        fp = cm.sum(axis=0) - tp; fn = cm.sum(axis=1) - tp
        iou = tp / (tp + fp + fn + 1e-10)
        precision = tp / (tp + fp + 1e-10)
        recall = tp / (tp + fn + 1e-10)
        return {"iou": iou, "miou": np.nanmean(iou), "precision": precision, "recall": recall, "cm": cm}

# --- Fusion Logic (Unified) ---
def fuse_conditional_boost(base_logits, boost_logits, alpha=2.0):
    # Boost the Book logit only where the base model predicts one of the
    # target classes (Book, Cabinet, Object, Desk).
    TARGET_IDS = [BOOK_ID, CABINET_ID, OBJECT_ID, DESK_ID]
    # Softmax is monotonic, so the argmax of the logits gives the same prediction.
    base_pred = base_logits.argmax(dim=1)
    mask = torch.zeros_like(base_pred, dtype=torch.bool)
    for tid in TARGET_IDS:
        mask |= (base_pred == tid)
    final_logits = base_logits.clone()
    boost_book_score = boost_logits[:, BOOK_ID, :, :]
    # Indexing a single channel returns a view, so the masked add updates final_logits.
    final_logits[:, BOOK_ID, :, :][mask] += boost_book_score[mask] * alpha
    return F.softmax(final_logits, dim=1)

# --- Dataset ---
def estimate_height_from_depth(depth_np: np.ndarray) -> np.ndarray:
    H, W = depth_np.shape
    y_grid = np.linspace(0, 1, H).reshape(H, 1).repeat(W, axis=1).astype(np.float32)
    height_map = y_grid * depth_np
    max_val = float(height_map.max())
    if max_val > 0: height_map /= max_val
    return height_map.astype(np.float32)

class NYUv2Dataset(Dataset):
    def __init__(self, root_dir, split='train', transform=None, return_label=True):
        self.split = split
        self.transform = transform
        self.return_label = return_label
        src_split = 'train' if split in ['train', 'val'] else 'test'
        self.images_dir = os.path.join(root_dir, src_split, 'image')
        self.depths_dir = os.path.join(root_dir, src_split, 'depth')
        self.labels_dir = os.path.join(root_dir, src_split, 'label') if src_split == 'train' else None
        self.filenames = sorted([f for f in os.listdir(self.images_dir) if f.endswith('.png')])
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std  = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        self.d_mean, self.d_std = 0.5, 0.25

    def __len__(self): return len(self.filenames)

    def __getitem__(self, idx):
        fname = self.filenames[idx]
        rgb = np.array(Image.open(os.path.join(self.images_dir, fname)).convert('RGB'))
        depth = np.array(Image.open(os.path.join(self.depths_dir, fname)))
        if depth.ndim == 3: depth = depth[:, :, 0]
        depth = depth.astype(np.float32) / (65535.0 if depth.max() > 255 else 255.0)
        h_map = estimate_height_from_depth(depth)

        if self.labels_dir and self.return_label:
            label = np.array(Image.open(os.path.join(self.labels_dir, fname)))
        else:
            label = np.zeros(depth.shape, dtype=np.int32)

        if self.transform:
            augmented = self.transform(image=rgb, depth=depth, height=h_map, mask=label)
            rgb, depth, h_map, label = augmented['image'], augmented['depth'], augmented['height'], augmented['mask']

        rgb = (rgb.astype(np.float32) / 255.0 - self.mean) / self.std
        depth = (depth - self.d_mean) / self.d_std
        h_map = (h_map - self.d_mean) / self.d_std

        x = torch.cat([
            torch.from_numpy(rgb.transpose(2, 0, 1)).float(),
            torch.from_numpy(depth).unsqueeze(0).float(),
            torch.from_numpy(h_map).unsqueeze(0).float()
        ], dim=0) # 5ch

        if self.split == 'test': return x, fname
        return x, torch.from_numpy(label).long()

In [None]:
# =========================
# Cell 2: Base Model (Full Definitions & Modified for Unified Directory)
# =========================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, Subset
from torchvision import models
from torch.amp import autocast, GradScaler
import albumentations as A
from tqdm import tqdm
import os
import shutil
import numpy as np

# --- Architecture Components ---
class SeparableConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
        super().__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias)
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1, 1, 0, 1, 1, bias=bias)
    def forward(self, x): return self.pointwise(self.depthwise(x))

class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1, d=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, k, padding=p, dilation=d, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.block(x)

class ASPP(nn.Module):
    def __init__(self, in_ch, out_ch=256, rates=(12, 24, 36)):
        super().__init__()
        r1, r2, r3 = rates
        self.b1 = ConvBNReLU(in_ch, out_ch, 1, 0, 1)
        self.b2 = ConvBNReLU(in_ch, out_ch, 3, r1, r1)
        self.b3 = ConvBNReLU(in_ch, out_ch, 3, r2, r2)
        self.b4 = ConvBNReLU(in_ch, out_ch, 3, r3, r3)
        self.b5 = nn.Sequential(nn.AdaptiveAvgPool2d(1), ConvBNReLU(in_ch, out_ch, 1, 0, 1))
        self.proj = nn.Sequential(ConvBNReLU(out_ch * 5, out_ch, 1, 0, 1), nn.Dropout(0.1))
    def forward(self, x):
        h, w = x.shape[2:]
        return self.proj(torch.cat([
            self.b1(x), self.b2(x), self.b3(x), self.b4(x),
            F.interpolate(self.b5(x), size=(h, w), mode='bilinear', align_corners=False)
        ], dim=1))

class ResNeXtDeepLabV3Plus_OS8(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES, in_channels=5, aspp_rates=(12,24,36)):
        super().__init__()
        backbone = models.resnext101_32x8d(
            weights=models.ResNeXt101_32X8D_Weights.IMAGENET1K_V1,
            replace_stride_with_dilation=[False, True, True],
        )
        old_conv = backbone.conv1
        new_conv = nn.Conv2d(in_channels, old_conv.out_channels, 7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            new_conv.weight[:, :3] = old_conv.weight
            new_conv.weight[:, 3:] = old_conv.weight.mean(dim=1, keepdim=True).repeat(1, in_channels-3, 1, 1)

        self.enc0 = nn.Sequential(new_conv, backbone.bn1, backbone.relu)
        self.pool = backbone.maxpool
        self.enc1 = backbone.layer1
        self.enc2 = backbone.layer2
        self.enc3 = backbone.layer3
        self.enc4 = backbone.layer4

        self.aspp = ASPP(2048, 256, rates=aspp_rates)
        self.low_proj = nn.Sequential(nn.Conv2d(256, 48, 1, bias=False), nn.BatchNorm2d(48), nn.ReLU(inplace=True))
        self.dec1 = ConvBNReLU(256 + 48, 256, 3, 1, 1)
        self.dec2 = ConvBNReLU(256, 256, 3, 1, 1)
        self.seg_head = nn.Conv2d(256, num_classes, 1)
        self.boundary_head = nn.Conv2d(256, 1, 1)
        self.aux_head = nn.Sequential(ConvBNReLU(1024, 256, 3, 1, 1), nn.Dropout(0.1), nn.Conv2d(256, num_classes, 1))

    def forward(self, x, return_aux=False, return_boundary=False):
        H, W = x.shape[2:]
        x0 = self.enc0(x); x1 = self.pool(x0)
        low = self.enc1(x1); x2 = self.enc2(low)
        mid = self.enc3(x2); x4 = self.enc4(mid)

        xA = self.aspp(x4)
        xA = F.interpolate(xA, size=low.shape[2:], mode='bilinear', align_corners=False)
        feat = self.dec2(self.dec1(torch.cat([xA, self.low_proj(low)], dim=1)))

        outs = [F.interpolate(self.seg_head(feat), size=(H, W), mode='bilinear', align_corners=False)]
        if return_boundary:
            outs.append(F.interpolate(self.boundary_head(feat), size=(H, W), mode='bilinear', align_corners=False))
        if return_aux and self.training:
            outs.append(F.interpolate(self.aux_head(mid), size=(H, W), mode='bilinear', align_corners=False))

        return outs[0] if len(outs) == 1 else tuple(outs)

# --- Losses ---
class FocalLoss(nn.Module):
    def __init__(self, weight=None, gamma=2.0, ignore_index=IGNORE_INDEX):
        super().__init__()
        self.weight = weight; self.gamma = gamma; self.ignore_index = ignore_index
    def forward(self, logits, target):
        ce = F.cross_entropy(logits, target, weight=self.weight, ignore_index=self.ignore_index, reduction='none')
        return (((1 - torch.exp(-ce)) ** self.gamma) * ce).mean()

class ClassBalancedDiceLoss(nn.Module):
    def __init__(self, n_classes=NUM_CLASSES, smooth=1e-5, ignore_index=IGNORE_INDEX, class_weights=None):
        super().__init__()
        self.n_classes = n_classes; self.smooth = smooth; self.ignore_index = ignore_index; self.class_weights = class_weights
    def forward(self, logits, target):
        prob = F.softmax(logits, dim=1)
        mask = (target != self.ignore_index)
        t = target.clone(); t[~mask] = 0
        onehot = F.one_hot(t, self.n_classes).permute(0, 3, 1, 2).float() * mask.unsqueeze(1)
        prob = prob * mask.unsqueeze(1)
        inter = (prob * onehot).sum(dim=(2, 3))
        dice = (2 * inter + self.smooth) / (prob.sum(dim=(2, 3)) + onehot.sum(dim=(2, 3)) + self.smooth)
        if self.class_weights is not None:
            w = self.class_weights.to(dice.device)
            return 1 - (dice.mean(dim=0) * w).sum() / (w.sum() + 1e-12)
        return 1 - dice.mean()

def make_boundary_mask(target, ignore_index=IGNORE_INDEX, dilate=2):
    valid = (target != ignore_index)
    t = target.clone(); t[~valid] = -1
    edge = torch.zeros_like(t, dtype=torch.bool)
    edge[:, 1:] |= (t[:, 1:] != t[:, :-1]) & valid[:, 1:]
    edge[:, :-1] |= (t[:, :-1] != t[:, 1:]) & valid[:, :-1]
    edge[:, :, 1:] |= (t[:, :, 1:] != t[:, :, :-1]) & valid[:, :, 1:]
    edge[:, :, :-1] |= (t[:, :, :-1] != t[:, :, 1:]) & valid[:, :, :-1]
    edge = edge.float().unsqueeze(1)
    for _ in range(max(0, dilate)): edge = F.max_pool2d(edge, kernel_size=3, stride=1, padding=1)
    return (edge > 0).float()

def ohem_cross_entropy(logits, target, weight, ignore_index=IGNORE_INDEX, min_kept=131072, thresh=0.7):
    with torch.no_grad():
        prob = F.softmax(logits, dim=1)
        valid = (target != ignore_index)
        t_safe = target.clone(); t_safe[~valid] = 0
        p_gt = prob.gather(1, t_safe.unsqueeze(1)).squeeze(1)
        hard = (p_gt < thresh) & valid
    loss = F.cross_entropy(logits, target, weight=weight, ignore_index=ignore_index, reduction='none')
    if loss[hard].numel() >= min_kept: return torch.topk(loss[hard], k=min_kept).values.mean()
    return torch.topk(loss[valid], k=min(min_kept, loss[valid].numel())).values.mean()

def tta_predict_logits(model, x, out_hw, img_size, scales=(1.0,), do_flip=True):
    B, C = x.shape[0], NUM_CLASSES
    acc = torch.zeros((B, C, out_hw[0], out_hw[1]), device=x.device)
    n_aug = 0
    for s in scales:
        hs, ws = int(img_size * s), int(img_size * s)
        xs = F.interpolate(x, size=(hs, ws), mode='bilinear', align_corners=False)
        out = F.interpolate(model(xs), size=out_hw, mode='bilinear', align_corners=False)
        acc += F.softmax(out, dim=1); n_aug += 1
        if do_flip:
            xsf = torch.flip(xs, dims=[3])
            out_f = torch.flip(F.interpolate(model(xsf), size=out_hw, mode='bilinear', align_corners=False), dims=[3])
            acc += F.softmax(out_f, dim=1); n_aug += 1
    return acc / n_aug

# --- Base Training Function (Modified for Unified Directory) ---
def run_base_train_and_submit(save_dir, dataset_root, train_idx, val_idx, img_size=768, batch_size=16, epochs=60, lr=1e-4, seed=42, phase1_epochs=40, save_cm_every=5):
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the directory (Main manages it and is expected to pass an existing path; this is a safeguard)
    ensure_dir(save_dir)

    # Initialize the logger (uses RunLogger from Cell 1)
    # run_id is derived from the directory name
    run_id = os.path.basename(os.path.normpath(save_dir))
    logger = RunLogger(save_dir, run_id)

    base_best_path = os.path.join(save_dir, "base_best_model.pt")

    print(f"[BASE] Start Training. Output: {save_dir}")

    # Data Augmentation (shared)
    train_aug = A.Compose([
        A.RandomResizedCrop(size=(img_size, img_size), scale=(0.7, 1.0), ratio=(0.9, 1.1), p=1.0),
        A.HorizontalFlip(p=0.5), A.Rotate(limit=10, p=0.3), A.ColorJitter(p=0.5)
    ], additional_targets={'depth': 'image', 'height': 'image'})
    val_aug = A.Compose([A.Resize(img_size, img_size)], additional_targets={'depth': 'image', 'height': 'image'})

    train_loader = DataLoader(
        Subset(NYUv2Dataset(dataset_root, "train", train_aug), train_idx),
        batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, drop_last=True
    )
    val_loader = DataLoader(
        Subset(NYUv2Dataset(dataset_root, "train", val_aug), val_idx),
        batch_size=8, shuffle=False, num_workers=4, pin_memory=True
    )

    model = ResNeXtDeepLabV3Plus_OS8().to(device)

    # Loss & Optimizer
    ce_weights = torch.tensor([1.0, 12.0, 0.6, 2.2, 0.6, 1.0, 2.0, 1.5, 1.5, 5.0, 3.0, 0.5, 1.0], device=device)
    criterion_focal = FocalLoss(weight=ce_weights)
    criterion_dice = ClassBalancedDiceLoss(class_weights=torch.tensor([1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.2, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0], device=device))
    bce_boundary = nn.BCEWithLogitsLoss()

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler("cuda")

    best_miou = 0.0

    for epoch in range(1, epochs + 1):
        model.train()
        use_ohem = epoch > phase1_epochs
        w = [0.25, 0.20, 0.00, 0.40, 0.10, 0.05] if not use_ohem else [0.15, 0.00, 0.20, 0.35, 0.10, 0.20]

        epoch_losses = {"total": 0.0}
        pbar = tqdm(train_loader, desc=f"[BASE] Ep {epoch}")

        for x, y in pbar:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            optimizer.zero_grad()
            with autocast("cuda"):
                seg, bd, aux = model(x, return_aux=True, return_boundary=True)
                loss = (w[0] * criterion_focal(seg, y) + w[1] * (0 if use_ohem else F.cross_entropy(seg, y, weight=ce_weights, ignore_index=IGNORE_INDEX)) +
                        w[2] * (ohem_cross_entropy(seg, y, weight=ce_weights) if use_ohem else 0) + w[3] * criterion_dice(seg, y) +
                        w[4] * F.cross_entropy(aux, y, ignore_index=IGNORE_INDEX) + w[5] * bce_boundary(bd, make_boundary_mask(y)))

            scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
            epoch_losses["total"] += loss.item()
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        scheduler.step()
        logger.log_epoch_train("base", epoch, scheduler.get_last_lr()[0], {k: v/len(train_loader) for k, v in epoch_losses.items()})

        # Validation
        model.eval()
        cm = torch.zeros((NUM_CLASSES, NUM_CLASSES), dtype=torch.long)
        with torch.no_grad():
            for x, y in val_loader:
                cm = update_cm(cm, model(x.to(device)).argmax(1).cpu(), y.cpu(), NUM_CLASSES, IGNORE_INDEX)
        m = compute_metrics_from_cm(cm)

        # Save logs (same format for BASE and BOOST)
        save_cm = (epoch % save_cm_every == 0) or (m["miou"] > best_miou)
        logger.log_epoch_val("base", epoch, m, CLASS_NAMES, cm_np=cm.numpy() if save_cm else None)
        fmt_metrics_console("BASE", epoch, m["miou"], m["class_iou"], m["class_precision"], m["class_recall"], CLASS_NAMES)

        if m["miou"] > best_miou:
            best_miou = m["miou"]
            torch.save(model.state_dict(), base_best_path)
            print(f"  -> New Best mIoU: {best_miou:.5f}")

    return base_best_path

In [None]:
# =========================
# Cell 3: Analysis (Dilated Context Input)
# =========================
def find_book_neurons_pure(dataset_root, train_idx, top_k=256):
    print("[ANALYSIS] Probing 'Book' neurons (Including Surrounding Pixels)...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Feature Extractor (RGB Input)
    full_model = models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.IMAGENET1K_V1).to(device).eval()
    class BackboneFeature(nn.Module):
        def __init__(self, m): super().__init__(); self.m = m
        def forward(self, x):
            return self.m.layer4(self.m.layer3(self.m.layer2(self.m.layer1(self.m.maxpool(self.m.relu(self.m.bn1(self.m.conv1(x))))))))
    extractor = BackboneFeature(full_model)

    ds = Subset(NYUv2Dataset(dataset_root, split="train", transform=None), train_idx)

    # Book(1) vs Cabinet(5)/Object(6)
    TARGET_POS_ID = 1
    TARGET_NEG_IDS = [5, 6]

    pos_act_sum = torch.zeros(2048).to(device)
    neg_act_sum = torch.zeros(2048).to(device)
    pos_count = 0
    neg_count = 0

    with torch.no_grad():
        for idx in tqdm(range(len(ds)), desc="Analyzing Neurons"):
            x, y = ds[idx]

            # Skip images that contain none of the target classes
            has_pos = (y == TARGET_POS_ID).any()
            has_neg = False
            for tid in TARGET_NEG_IDS:
                if (y == tid).any(): has_neg = True; break
            if not has_pos and not has_neg: continue

            rgb = F.interpolate(x[:3].unsqueeze(0).to(device), size=(512, 512))
            feat = extractor(rgb) # [1, 2048, 16, 16]

            # --- Build and dilate the masks ---
            y_gpu = y.unsqueeze(0).unsqueeze(0).float().to(device)
            y_small = F.interpolate(y_gpu, size=(16, 16), mode='nearest')

            # 1. Positive (Book + surroundings)
            mask_pos = (y_small == TARGET_POS_ID).float()

            # Important: widen the mask so that surrounding pixels are included (dilation)
            # On a binary mask, max_pool2d(kernel_size=3, stride=1, padding=1) acts as a dilation
            if mask_pos.sum() > 0:
                mask_pos_dilated = F.max_pool2d(mask_pos, kernel_size=3, stride=1, padding=1)

                # Aggregate features over the widened mask (surroundings included)
                pos_act_sum += (feat.squeeze(0) * mask_pos_dilated.squeeze(0)).sum(dim=(1, 2)) / (mask_pos_dilated.sum() + 1e-6)
                pos_count += 1

            # 2. Negative (Cabinet/Object + surroundings)
            mask_neg = torch.zeros_like(y_small)
            for tid in TARGET_NEG_IDS:
                mask_neg = torch.logical_or(mask_neg, (y_small == tid))
            mask_neg = mask_neg.float()

            # Dilate the negative side the same way so the conditions match
            if mask_neg.sum() > 0:
                mask_neg_dilated = F.max_pool2d(mask_neg, kernel_size=3, stride=1, padding=1)

                neg_act_sum += (feat.squeeze(0) * mask_neg_dilated.squeeze(0)).sum(dim=(1, 2)) / (mask_neg_dilated.sum() + 1e-6)
                neg_count += 1

    # Score: (Book incl. surroundings) - (Cabinet/Object incl. surroundings)
    pos_mean = pos_act_sum / (pos_count + 1e-6)
    neg_mean = neg_act_sum / (neg_count + 1e-6)

    discriminative_score = pos_mean - neg_mean
    values, indices = torch.topk(discriminative_score, k=top_k)
    selected_indices = indices.tolist()

    print(f"  -> Top 5 Scores (Dilated Input): {values[:5].cpu().numpy()}")
    return selected_indices

In [None]:
# =========================
# Cell 4: Boost Model (Fixed: Manhattan Edge Constraint)
# =========================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import models
import albumentations as A
import numpy as np
import os
import cv2
from tqdm import tqdm

# --- 1. Attention Modules (CBAM) - Unchanged ---
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        return self.sigmoid(self.conv1(x))

class CBAM(nn.Module):
    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)
    def forward(self, x):
        x = x * self.ca(x)
        x = x * self.sa(x)
        return x

class AttentionDecoderBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.attention = CBAM(skip_ch)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch + skip_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch), nn.ReLU()
        )
    def forward(self, x, skip):
        x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        skip = self.attention(skip)
        return self.conv(torch.cat([x, skip], dim=1))

# --- 2. Boost Model Architecture ---
class TriDeepLabV3Plus_Boosted_Attention(nn.Module):
    def __init__(self):
        super().__init__()
        print("[BOOST] Initializing Backbone (UNFROZEN, RGB ONLY, Full Features)...")
        self.backbone = models.resnext101_32x8d(weights=models.ResNeXt101_32X8D_Weights.IMAGENET1K_V1)

        self.stem = nn.Sequential(self.backbone.conv1, self.backbone.bn1, self.backbone.relu) # 1/2
        self.layer1 = nn.Sequential(self.backbone.maxpool, self.backbone.layer1) # 1/4
        self.layer2 = self.backbone.layer2 # 1/8
        self.layer3 = self.backbone.layer3 # 1/16
        self.layer4 = self.backbone.layer4 # 1/32

        self.center = nn.Sequential(
            nn.Conv2d(2048, 512, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(),
            CBAM(512)
        )

        self.dec3 = AttentionDecoderBlock(in_ch=512, skip_ch=1024, out_ch=256)
        self.dec2 = AttentionDecoderBlock(in_ch=256, skip_ch=512, out_ch=128)
        self.dec1 = AttentionDecoderBlock(in_ch=128, skip_ch=256, out_ch=64)
        self.dec0 = AttentionDecoderBlock(in_ch=64, skip_ch=64, out_ch=32)

        self.final = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        input_size = x.shape[-2:]
        x = x[:, :3, :, :] # RGB Only

        x0 = self.stem(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        c = self.center(x4)
        d3 = self.dec3(c, x3)
        d2 = self.dec2(d3, x2)
        d1 = self.dec1(d2, x1)
        d0 = self.dec0(d1, x0)

        out = self.final(d0)
        return F.interpolate(out, size=input_size, mode='bilinear', align_corners=False)

# --- 3. Edge Consistency Loss (New!) ---
def compute_edge_loss(pred_logits, rgb_image):
    """
    Loss that penalizes changes (gradients) in the predicted mask at locations
    where the RGB image has no gradient (edge).
    """
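    # 1. Compute edge strength from the RGB image (first-order finite differences, not a true Sobel filter)
    # Convert to grayscale by averaging the channels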
    gray = rgb_image.mean(dim=1, keepdim=True)

    # Vertical and horizontal gradients
    dy_img = torch.abs(gray[:, :, 1:, :] - gray[:, :, :-1, :])
    dx_img = torch.abs(gray[:, :, :, 1:] - gray[:, :, :, :-1])

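    # Turn edge strength into a penalty weight in (0, 1] via exponential decay.
    # Strong image edges get a weight near 0 (almost no penalty); flat regions get a weight near 1 (maximum penalty).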
    edge_weight_y = torch.exp(-5.0 * dy_img)
    edge_weight_x = torch.exp(-5.0 * dx_img)

    # 2. Compute gradients of the predicted mask (on post-sigmoid probabilities)
    prob = torch.sigmoid(pred_logits)
    dy_pred = torch.abs(prob[:, :, 1:, :] - prob[:, :, :-1, :])
    dx_pred = torch.abs(prob[:, :, :, 1:] - prob[:, :, :, :-1])

    # 3. Loss: mean of the prediction changes, weighted toward locations without image edges
    loss_y = (dy_pred * edge_weight_y).mean()
    loss_x = (dx_pred * edge_weight_x).mean()

    return loss_y + loss_x

# --- 4. Training Function ---
def run_boost_train_fixed(save_dir, dataset_root, train_idx, val_idx, img_size, batch_size, epochs, lr, seed, BASE_MODEL_PATH):
    set_seed(seed)
    ensure_dir(save_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    run_id = os.path.basename(os.path.normpath(save_dir))
    logger = RunLogger(save_dir, run_id)

    print(f"[BOOST] Start Training (Edge Consistency Constraint). Output: {save_dir}")

    # 1. Base Model (Fixed)
    base_model = ResNeXtDeepLabV3Plus_OS8().to(device)
    base_model.load_state_dict(torch.load(BASE_MODEL_PATH, map_location=device))
    base_model.eval()

    # 2. Boost Model
    model = TriDeepLabV3Plus_Boosted_Attention().to(device)

    optimizer = optim.AdamW([
        {'params': model.center.parameters(), 'lr': lr},
        {'params': model.dec3.parameters(), 'lr': lr},
        {'params': model.dec2.parameters(), 'lr': lr},
        {'params': model.dec1.parameters(), 'lr': lr},
        {'params': model.dec0.parameters(), 'lr': lr},
        {'params': model.final.parameters(), 'lr': lr},
        {'params': model.backbone.parameters(), 'lr': lr * 0.1}
    ], weight_decay=1e-3)

    # --- Oversampling ---
    full_ds = NYUv2Dataset(dataset_root, "train", transform=None)
    book_indices_in_ds = []
    for idx in tqdm(train_idx, desc="Scanning for books"):
        _, label = full_ds[idx]
        if (label == BOOK_ID).any():
            book_indices_in_ds.append(idx)

    # Oversample
    augmented_train_idx = list(train_idx) + (book_indices_in_ds * 2)
    print(f"  -> Oversampled size: {len(augmented_train_idx)}")

    class DilatedBinaryWithHardNeg(Dataset):
        def __init__(self, ds, target_id, kernel_size=5):
            self.ds = ds; self.tid = target_id
            self.kernel = np.ones((kernel_size, kernel_size), np.uint8)
        def __len__(self): return len(self.ds)
        def __getitem__(self, i):
            x, y = self.ds[i]
            mask = (y == self.tid).numpy().astype(np.uint8)
            if mask.sum() > 0:
                mask = cv2.dilate(mask, self.kernel, iterations=1)
            bin_y = torch.from_numpy(mask).float().unsqueeze(0)
            full_y = y.clone()
            return x, bin_y, full_y

    train_aug = A.Compose([
        A.RandomResizedCrop((img_size, img_size), scale=(0.5, 1.0)),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
        A.GaussianBlur(p=0.3),
        A.CoarseDropout(num_holes_range=(1, 8), hole_height_range=(4, 32), hole_width_range=(4, 32), p=0.3)
    ], additional_targets={'depth':'image','height':'image'})

    val_aug = A.Compose([A.Resize(img_size, img_size)], additional_targets={'depth':'image','height':'image'})

    train_ds = Subset(DilatedBinaryWithHardNeg(NYUv2Dataset(dataset_root, "train", train_aug), BOOK_ID, kernel_size=7), augmented_train_idx)
    val_ds_full = Subset(NYUv2Dataset(dataset_root, "train", val_aug), val_idx)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader_full = DataLoader(val_ds_full, batch_size=batch_size, shuffle=False, num_workers=4)

    best_miou = 0.0
    best_path = os.path.join(save_dir, "boost_best_edge.pt")

    FUSION_ALPHA = 1.2
    TARGET_IDS = [BOOK_ID, 5, 6]
    fusion_dilate_kernel = torch.ones((1, 1, 7, 7)).to(device)
    HN_WEIGHTS = {5: 4.0, 6: 2.5} # Hard Negative Weights

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0
        pbar = tqdm(train_loader, desc=f"[BOOST] Ep {epoch}")

        for x, bin_y, full_y in pbar:
            x, bin_y, full_y = x.to(device), bin_y.to(device), full_y.to(device)
            optimizer.zero_grad()
            out = model(x) # [B, 1, H, W]

            # 1. Main Loss (Hard Negative Weighted BCE)
            bce_loss = F.binary_cross_entropy_with_logits(out, bin_y, reduction='none')
            weight_map = torch.ones_like(bce_loss)
            for cls_id, w in HN_WEIGHTS.items():
                weight_map[full_y.unsqueeze(1) == cls_id] = w
            weight_map[bin_y == 1] = 2.0
            loss_main = (bce_loss * weight_map).mean()

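            # 2. Edge Consistency Loss (a common-sense prior)
            # The RGB input (x) is normalized, so mapping it back to 0-1 would be more accurate,
            # but only local differences matter here, so passing it as-is still works.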
            loss_edge = compute_edge_loss(out, x[:, :3, :, :])

            # Total loss (edge-loss weight set empirically on the strong side, 0.5 to 1.0)
            loss = loss_main + 0.5 * loss_edge

            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            pbar.set_postfix({"loss": f"{loss.item():.4f} (E:{loss_edge.item():.4f})"})

        # Validation
        model.eval()
        cm = torch.zeros((NUM_CLASSES, NUM_CLASSES), dtype=torch.long)

        with torch.no_grad():
            for x, y in val_loader_full:
                x = x.to(device)
                base_logits = base_model(x)
                boost_logits = model(x)

                base_probs = F.softmax(base_logits, dim=1)
                boost_prob = torch.sigmoid(boost_logits).squeeze(1)

                # Fusion
                base_pred = base_probs.argmax(dim=1)
                target_mask = torch.zeros_like(base_pred, dtype=torch.float32)
                for tid in TARGET_IDS:
                    target_mask = torch.logical_or(target_mask.bool(), (base_pred == tid)).float()

                if target_mask.sum() > 0:
                    dilated_mask = torch.clamp(F.conv2d(target_mask.unsqueeze(1), fusion_dilate_kernel, padding=3), 0, 1).squeeze(1).bool()
                else:
                    dilated_mask = target_mask.bool()

                final_probs = base_probs.clone()
                if dilated_mask.any():
                    final_probs[:, BOOK_ID, :, :][dilated_mask] += (boost_prob[dilated_mask] * FUSION_ALPHA)

                temp_pred = final_probs.argmax(1)
                orig_target_mask = target_mask.bool()
                final_pred = base_pred.clone()
                final_pred[orig_target_mask] = temp_pred[orig_target_mask]

                cm = update_cm(cm, final_pred.cpu(), y.cpu(), NUM_CLASSES, IGNORE_INDEX)

        m = compute_metrics_from_cm(cm)
        logger.log_epoch_val("boost_edge", epoch, m, CLASS_NAMES, cm_np=None)
        fmt_metrics_console("BOOST(EDGE)", epoch, m["miou"], m["class_iou"], m["class_precision"], m["class_recall"], CLASS_NAMES)

        if m["miou"] > best_miou:
            best_miou = m["miou"]
            torch.save(model.state_dict(), best_path)
            print(f"  -> New Best mIoU: {best_miou:.5f}")

    return best_path
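
To make the behavior of the edge-consistency term concrete, here is a small NumPy re-statement of `compute_edge_loss` for a single grayscale image and a single probability map. It is a sketch under assumptions: `edge_loss_np` is a hypothetical name, the image is taken to be in the 0-1 range (the training loop passes the normalized tensor instead, which changes the scale of the image gradients), and the decay constant 5.0 is copied from the cell above.

```python
import numpy as np

def edge_loss_np(prob, gray, k=5.0):
    """Edge-consistency penalty for one [H, W] probability map and one [H, W] grayscale image."""
    # Finite-difference gradients of the image
    dy_img = np.abs(gray[1:, :] - gray[:-1, :])
    dx_img = np.abs(gray[:, 1:] - gray[:, :-1])
    # Finite-difference gradients of the prediction
    dy_pred = np.abs(prob[1:, :] - prob[:-1, :])
    dx_pred = np.abs(prob[:, 1:] - prob[:, :-1])
    # Prediction changes are discounted where the image itself has an edge
    loss_y = (dy_pred * np.exp(-k * dy_img)).mean()
    loss_x = (dx_pred * np.exp(-k * dx_img)).mean()
    return loss_y + loss_x

# Image with one vertical edge between columns 3 and 4
gray = np.zeros((8, 8))
gray[:, 4:] = 1.0

# Prediction whose boundary sits on the image edge
aligned = np.zeros((8, 8))
aligned[:, 4:] = 1.0

# Prediction whose boundary sits in a flat image region
shifted = np.zeros((8, 8))
shifted[:, 2:] = 1.0

loss_aligned = edge_loss_np(aligned, gray)
loss_shifted = edge_loss_np(shifted, gray)
print(loss_aligned, loss_shifted)
```

The boundary that follows the image edge is discounted by a factor of `exp(-5)`, while the same boundary placed in a flat region is charged in full. That asymmetry is what pushes the binary mask to snap to visible contours.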

In [None]:
# =========================
# Cell 5: Main Execution (Corrected for Boost V2)
# =========================
import os
import shutil
import datetime
import random
import torch
import numpy as np
from torch.utils.data import DataLoader, Subset

# --- Helper Functions ---
def get_jst_timestamp():
    tz_jst = datetime.timezone(datetime.timedelta(hours=9))
    now = datetime.datetime.now(tz_jst)
    return now.strftime("%Y%m%d_%H%M%S_JST")

def backup_run_to_drive(local_run_dir, drive_root, folder_prefix):
    if local_run_dir is None or not os.path.exists(local_run_dir):
        return
    timestamp = get_jst_timestamp()
    folder_name = f"{folder_prefix}_{timestamp}"
    drive_dest_path = os.path.join(drive_root, folder_name)
    print(f"\n[BACKUP] Copying to Drive: {drive_dest_path}")
    try:
        shutil.copytree(local_run_dir, drive_dest_path)
        print(f"[BACKUP] Success!")
    except FileExistsError:
        print(f"[WARN] Backup skipped (Destination exists).")
    except Exception as e:
        print(f"[ERROR] Backup failed: {e}")

# --- Main Logic ---
def main():
    mount_drive()
    DRIVE_ROOT = "/content/drive/MyDrive/nyu_runs"
    ensure_dir(DRIVE_ROOT)

    # ★ Run mode
    # "base": train the base model only
    # "tri":  train the Boost model only (requires an existing base model)
    # "both": run both stages
    RUN_MODE = "tri"

    timestamp = get_jst_timestamp()
    run_root = os.path.join("/content", f"run_{timestamp}")
    ensure_dir(run_root)

    dir_base = os.path.join(run_root, "base")
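    dir_boost = os.path.join(run_root, "boost")

    # get_device_name() raises when no GPU is available, so fall back to "cpu"
    device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"

    print(f"==========================================")
    print(f" START PIPELINE at {timestamp}")
    print(f" MODE: {RUN_MODE}")
    print(f" DEVICE: {device_name}")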
    print(f"==========================================\n")

    # Train/val split (seed fixed at 42)
    full_dataset = NYUv2Dataset("/content/data", split="train", transform=None)
    indices = list(range(len(full_dataset)))
    random.seed(42)
    random.shuffle(indices)
    n_val = int(len(full_dataset) * 0.1)
    train_idx = indices[:-n_val]
    val_idx = indices[-n_val:]

    # ★ Base model path
    # If you already have a trained model, point this at it
    # (e.g. "/content/drive/MyDrive/nyu_runs/.../base_best_model.pt")
    base_model_path = "/content/base_best_model.pt"

    # --- BASE MODEL ---
    if RUN_MODE in ["base", "both"]:
        base_model_path = run_base_train_and_submit(
            save_dir=dir_base,
            dataset_root="/content/data",
            train_idx=train_idx, val_idx=val_idx,
            img_size=768, batch_size=16, epochs=1, lr=1e-4
        )
        print(f"\n[INFO] Base model saved at: {base_model_path}")

    # --- BOOST MODEL ---
    if RUN_MODE in ["tri", "both"]:
        # Check that the base model file exists
        if not os.path.exists(base_model_path):
            print(f"\n[ERROR] Base model not found at: {base_model_path}")
            print("Please upload 'base_best_model.pt' to /content/ or change RUN_MODE to 'both'.")
            return

        print(f"\n[BOOST] Using Base Model from: {base_model_path}")
        print(">>> Strategy: Full Features (No Neuron Selection) + Attention + Oversampling")

        # Run Boost training (arguments updated to match the Cell 4 signature)
        boost_model_path = run_boost_train_fixed(
            save_dir=dir_boost,
            # book_indices=indices_book,  <-- removed (does not match the Cell 4 definition)
            dataset_root="/content/data",
            train_idx=train_idx, val_idx=val_idx,
            img_size=768, batch_size=16, epochs=25, lr=1e-3, seed=42,
            BASE_MODEL_PATH=base_model_path
        )

    # Back up to Drive
    backup_run_to_drive(run_root, DRIVE_ROOT, "RUN_BoostV2")
    print(f"\n==========================================")
    print(f" ALL DONE.")

if __name__ == "__main__":
    main()
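
Later cells re-create the train/val split instead of receiving it from `main()`, so the split has to be reproducible from the seed alone. The sketch below isolates that logic using only the standard library. `make_split` is a hypothetical helper; it uses a private `random.Random(seed)` instance, which should yield the same permutation as calling `random.seed(42)` followed by `random.shuffle` as the notebook does. The count 795 is the training-set size stated in the overview.

```python
import random

def make_split(n_items, val_ratio=0.1, seed=42):
    """Deterministic shuffle-and-slice split; the last n_val shuffled indices become validation."""
    indices = list(range(n_items))
    rng = random.Random(seed)      # private generator, does not disturb global random state
    rng.shuffle(indices)
    n_val = int(n_items * val_ratio)
    # Slice from the front so that n_val == 0 does not produce an empty training set
    train_idx = indices[:n_items - n_val]
    val_idx = indices[n_items - n_val:]
    return train_idx, val_idx

train_idx, val_idx = make_split(795)
print(len(train_idx), len(val_idx))
```

One design note: the notebook's `indices[:-n_val]` form returns an empty training list when `n_val` is 0, because `indices[:-0]` is `indices[:0]`. Slicing with `n_items - n_val` avoids that corner case for very small datasets.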

In [None]:
# =========================
# Cell 12 (Final Fix): Blob Correction with mIoU Check
# =========================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report
import os
import random
import albumentations as A
from tqdm import tqdm
from skimage.measure import label, regionprops

# --- Settings ---
OUTPUT_DIR = "/content/blob_correction_final"
if not os.path.exists(OUTPUT_DIR): os.makedirs(OUTPUT_DIR)

ID_BOOK = 1
ID_CABINET = 5
ID_OBJECT = 6
TARGET_IDS = [ID_BOOK, ID_CABINET, ID_OBJECT]
NUM_CLASSES = 13

# --- Reuse Metrics Utils from Cell 1 ---
# (Assumed to be defined already; a simplified version is redefined here just in case)
def fast_hist(a, b, n):
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n**2).reshape(n, n)

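def compute_miou_simple(cm):
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    denom = tp + fp + fn
    # Classes absent from both GT and predictions become NaN so that nanmean actually skips them
    iou = np.where(denom > 0, tp / np.maximum(denom, 1e-10), np.nan)
    return np.nanmean(iou)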

# --- MLP & Dataset (Unchanged) ---
class BlobCorrectionNet(nn.Module):
    def __init__(self, input_dim, num_classes=4):
        super(BlobCorrectionNet, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(64, num_classes)
        )
    def forward(self, x): return self.net(x)

class BlobDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32).unsqueeze(1)
    def __len__(self): return len(self.X)
    def __getitem__(self, idx): return self.X[idx], self.y[idx]

def get_majority_label(gt_crop, ignore_index=255):
    valid_gt = gt_crop[gt_crop != ignore_index]
    if len(valid_gt) == 0: return -1
    counts = np.bincount(valid_gt)
    return np.argmax(counts)

def mine_blobs_multiclass(model, dataset, device, desc="Mining"):
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
    blob_data = []
    model.eval()

    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(loader, desc=desc)):
            x = x.to(device)
            y_np = y.numpy()[0]

            logits = model(x)
            probs = F.softmax(logits, dim=1)
            pred_map = probs.argmax(dim=1).cpu().numpy()[0]

            prob_book = probs[0, ID_BOOK].cpu().numpy()
            prob_cab  = probs[0, ID_CABINET].cpu().numpy()
            prob_obj  = probs[0, ID_OBJECT].cpu().numpy()

            for pred_cls in TARGET_IDS:
                mask = (pred_map == pred_cls)
                if not mask.any(): continue
                lbl_img = label(mask)
                regions = regionprops(lbl_img)
                for props in regions:
                    if props.area < 50: continue
                    y0, x0, y1, x1 = props.bbox
                    h, w = y1 - y0, x1 - x0
                    r_mask = (lbl_img == props.label)
                    gt_crop = y_np[r_mask]
                    true_label = get_majority_label(gt_crop)
                    if true_label == -1: continue

                    if true_label == ID_BOOK: mlp_target = 1
                    elif true_label == ID_CABINET: mlp_target = 2
                    elif true_label == ID_OBJECT: mlp_target = 3
                    else: mlp_target = 0

                    blob_data.append({
                        "original_pred": pred_cls,
                        "area": props.area,
                        "aspect_ratio": h / (w + 1e-6),
                        "centroid_y": props.centroid[0] / 768.0,
                        "solidity": props.solidity, "extent": props.extent, "eccentricity": props.eccentricity,
                        "prob_book_mean": np.mean(prob_book[r_mask]), "prob_book_max": np.max(prob_book[r_mask]), "prob_book_std": np.std(prob_book[r_mask]),
                        "prob_cab_mean": np.mean(prob_cab[r_mask]), "prob_cab_max": np.max(prob_cab[r_mask]), "prob_cab_std": np.std(prob_cab[r_mask]),
                        "prob_obj_mean": np.mean(prob_obj[r_mask]), "prob_obj_max": np.max(prob_obj[r_mask]), "prob_obj_std": np.std(prob_obj[r_mask]),
                        "mlp_target": mlp_target
                    })
    return pd.DataFrame(blob_data)

# --- Main Execution ---
def run_blob_correction_robust_miou():
    print(f"==========================================")
    print(f" START BLOB CORRECTION (mIoU Check)")
    print(f"==========================================")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load Base Model
    base_model_path = "/content/base_best_model.pt"
    if not os.path.exists(base_model_path):
        print(f"[ERROR] Base model not found at {base_model_path}")
        return
    print(f"[INFO] Loading Base Model from {base_model_path}...")
    model = ResNeXtDeepLabV3Plus_OS8().to(device)
    model.load_state_dict(torch.load(base_model_path, map_location=device))
    model.eval()

    # 2. Re-generate Split (Cell 5 Logic)
    print("[INFO] Re-generating indices (Seed 42)...")
    ds_full = NYUv2Dataset("/content/data", split="train", transform=None)
    indices = list(range(len(ds_full)))
    random.seed(42)
    random.shuffle(indices)

    n_val = int(len(ds_full) * 0.1)
    train_idx = indices[:-n_val]
    val_idx = indices[-n_val:]
    print(f"[INFO] Train: {len(train_idx)}, Val: {len(val_idx)}")

    # 3. Validation Dataset
    val_aug = A.Compose([A.Resize(768, 768)], additional_targets={'depth': 'image', 'height': 'image'})
    full_val_ds = NYUv2Dataset("/content/data", split="train", transform=val_aug)
    val_ds = Subset(full_val_ds, val_idx)

    # 4. Leak Check using mIoU (Better than Pixel Acc)
    print("[CHECK] Verifying split integrity with mIoU...")
    check_loader = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=4)
    hist = np.zeros((NUM_CLASSES, NUM_CLASSES))

    # Check every validation sample (fast on an A100)
    with torch.no_grad():
        for x, y in tqdm(check_loader, desc="Calculating mIoU"):
            x = x.to(device)
            pred = model(x).argmax(dim=1).cpu().numpy()
            y = y.numpy()
            hist += fast_hist(y.flatten(), pred.flatten(), NUM_CLASSES)

    miou = compute_miou_simple(hist)
    print(f"\n  -> Val Subset mIoU: {miou:.5f}")

    if miou > 0.85:
        print("\n[CRITICAL WARNING] mIoU > 0.85 suggests LEAK. Stopping.")
        return
    elif miou < 0.30:
        print("\n[WARNING] mIoU is surprisingly low. Check model weights.")
    else:
        print(f"  -> mIoU is realistic (Target ~0.67). Proceeding.")

    # 5. Mining & Train
    print(f"\n[STEP 1] Mining blobs from Train...")
    train_mining_ds = Subset(full_val_ds, train_idx) # Use resized ds
    df_train = mine_blobs_multiclass(model, train_mining_ds, device, "Mining Train")

    print(f"[STEP 2] Mining blobs from Val...")
    df_val = mine_blobs_multiclass(model, val_ds, device, "Mining Val")

    # Save & Prep
    df_train.to_csv(os.path.join(OUTPUT_DIR, "blobs_train.csv"), index=False)
    df_val.to_csv(os.path.join(OUTPUT_DIR, "blobs_val.csv"), index=False)

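    # dtype=float keeps the one-hot columns numeric (recent pandas versions default to bool)
    df_train = pd.get_dummies(df_train, columns=["original_pred"], prefix="base", dtype=float)
    df_val = pd.get_dummies(df_val, columns=["original_pred"], prefix="base", dtype=float)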

    expected_cols = [
        "area", "aspect_ratio", "centroid_y", "solidity", "extent", "eccentricity",
        "prob_book_mean", "prob_book_max", "prob_book_std",
        "prob_cab_mean", "prob_cab_max", "prob_cab_std",
        "prob_obj_mean", "prob_obj_max", "prob_obj_std",
        f"base_{ID_BOOK}", f"base_{ID_CABINET}", f"base_{ID_OBJECT}"
    ]
    for col in expected_cols:
        if col not in df_train.columns: df_train[col] = 0
        if col not in df_val.columns: df_val[col] = 0

    X_train = df_train[expected_cols].fillna(0).values
    y_train = df_train["mlp_target"].values
    X_val = df_val[expected_cols].fillna(0).values
    y_val = df_val["mlp_target"].values

    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)

    print(f"\n[STEP 3] Training MLP...")
    class_counts = np.bincount(y_train.astype(int)); class_counts = np.maximum(class_counts, 1)
    weights = 1. / class_counts; sample_weights = weights[y_train.astype(int)]
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

    train_loader = DataLoader(BlobDataset(X_train, y_train), batch_size=64, sampler=sampler)
    mlp = BlobCorrectionNet(len(expected_cols), num_classes=4).to(device)
    optimizer = optim.Adam(mlp.parameters(), lr=0.001); criterion = nn.CrossEntropyLoss()

    for epoch in range(1, 51):
        mlp.train()
        for bx, by in train_loader:
            bx, by = bx.to(device), by.to(device).long().squeeze(1)
            optimizer.zero_grad(); out = mlp(bx); loss = criterion(out, by)
            loss.backward(); optimizer.step()

    # 6. Evaluate
    mlp.eval()
    with torch.no_grad():  # inference only, no autograd graph needed
        val_preds = mlp(torch.tensor(X_val, dtype=torch.float32).to(device)).argmax(dim=1).cpu().numpy()

    print("\n--- Blob-Level Classification Report (MLP on Val) ---")
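    # labels= keeps the report valid even if a class is missing from the validation blobs
    print(classification_report(y_val, val_preds, labels=[0, 1, 2, 3], target_names=["Other", "Book", "Cab", "Obj"], zero_division=0))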

    print("\n--- Baseline (Base Model) Blob Accuracy ---")
    base_preds_val = df_val[[f"base_{ID_BOOK}", f"base_{ID_CABINET}", f"base_{ID_OBJECT}"]].idxmax(axis=1)
    def map_col(col_name):
        return {f"base_{ID_BOOK}":1, f"base_{ID_CABINET}":2, f"base_{ID_OBJECT}":3}.get(col_name, 0)
    base_preds_val = base_preds_val.apply(map_col).values
    print(classification_report(y_val, base_preds_val, labels=[0, 1, 2, 3], target_names=["Other", "Book", "Cab", "Obj"], zero_division=0))

    torch.save(mlp.state_dict(), os.path.join(OUTPUT_DIR, "blob_correction_mlp.pt"))

if __name__ == "__main__":
    run_blob_correction_robust_miou()
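
The leak check above depends on the confusion-matrix route to mIoU, so a tiny worked example may help when sanity-checking the numbers. This is a standalone sketch with made-up labels: `confusion_matrix` mirrors `fast_hist` (rows are ground truth, columns are predictions), and `per_class_iou` is a hypothetical helper that assumes every class appears at least once, so no division by zero can occur.

```python
import numpy as np

def confusion_matrix(gt, pred, n):
    """n x n confusion matrix; rows are ground truth, columns are predictions."""
    keep = (gt >= 0) & (gt < n)  # drops ignore labels such as 255
    return np.bincount(n * gt[keep] + pred[keep], minlength=n * n).reshape(n, n)

def per_class_iou(cm):
    """IoU = TP / (TP + FP + FN) for each class."""
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # predicted as the class but labeled otherwise
    fn = cm.sum(axis=1) - tp  # labeled as the class but predicted otherwise
    return tp / (tp + fp + fn)

# Six pixels, three classes; the last pixel is ignore (255) and is dropped
gt = np.array([0, 0, 1, 1, 2, 255])
pred = np.array([0, 1, 1, 1, 2, 2])

cm = confusion_matrix(gt, pred, 3)
iou = per_class_iou(cm)
miou = float(iou.mean())
print(cm)
print(iou, miou)
```

Here class 0 loses one pixel to class 1, giving IoUs of 1/2, 2/3, and 1, and a mean of 13/18. The ignored pixel contributes nothing even though the model predicted class 2 for it.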

In [None]:
# =========================
# Cell 13 (Strategy C): Train on Noisy Preds, Label with GT
# =========================
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, Dataset
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.metrics import classification_report
import os
import random
import albumentations as A
from tqdm import tqdm
from PIL import Image
from skimage.measure import label, regionprops
from scipy.ndimage import binary_dilation, generate_binary_structure

# --- Settings ---
OUTPUT_DIR = "/content/gbdt_correction_result_C"
if not os.path.exists(OUTPUT_DIR): os.makedirs(OUTPUT_DIR)

ID_BOOK = 1
ID_CABINET = 5
ID_OBJECT = 6
# When the base model predicts one of these classes, the blob is routed through a GBDT "checkpoint"
TARGET_IDS = [ID_BOOK, ID_CABINET, ID_OBJECT]
NUM_CLASSES = 13
CLASS_NAMES = ["Bed", "Book", "Ceiling", "Chair", "Floor", "Cabinet", "Object", "Picture", "Sofa", "Desk", "TV", "Wall", "Window"]

# --- 1. Dataset Class (Fixed) ---
class NYUv2DatasetFixed(Dataset):
    def __init__(self, root_dir, split='train', transform=None):
        self.split = split
        self.transform = transform
        src_split = 'train' if split in ['train', 'val'] else 'test'

        self.images_dir = os.path.join(root_dir, src_split, 'image')
        self.depths_dir = os.path.join(root_dir, src_split, 'depth')
        self.labels_dir = os.path.join(root_dir, src_split, 'label')

        if not os.path.exists(self.images_dir):
            self.filenames = []
        else:
            self.filenames = sorted([f for f in os.listdir(self.images_dir) if f.endswith('.png')])

        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std  = np.array([0.229, 0.224, 0.225], dtype=np.float32)
        self.d_mean, self.d_std = 0.5, 0.25

    def __len__(self): return len(self.filenames)

    def __getitem__(self, idx):
        fname = self.filenames[idx]
        rgb = np.array(Image.open(os.path.join(self.images_dir, fname)).convert('RGB'))
        depth = np.array(Image.open(os.path.join(self.depths_dir, fname)))
        label = np.array(Image.open(os.path.join(self.labels_dir, fname)))

        if depth.ndim == 3: depth = depth[:, :, 0]
        depth = depth.astype(np.float32) / (65535.0 if depth.max() > 255 else 255.0)

        H, W = depth.shape
        y_grid = np.linspace(0, 1, H).reshape(H, 1).repeat(W, axis=1).astype(np.float32)
        h_map = y_grid * depth

        if self.transform:
            augmented = self.transform(image=rgb, depth=depth, height=h_map, mask=label)
            rgb, depth, h_map, label = augmented['image'], augmented['depth'], augmented['height'], augmented['mask']

        rgb = (rgb.astype(np.float32) / 255.0 - self.mean) / self.std
        depth = (depth - self.d_mean) / self.d_std
        h_map = (h_map - self.d_mean) / self.d_std

        x = torch.cat([
            torch.from_numpy(rgb.transpose(2, 0, 1)).float(),
            torch.from_numpy(depth).unsqueeze(0).float(),
            torch.from_numpy(h_map).unsqueeze(0).float()
        ], dim=0)

        return x, torch.from_numpy(label).long()

# --- 2. Utils ---
def fast_hist(a, b, n):
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n**2).reshape(n, n)

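def compute_miou_simple(cm):
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    denom = tp + fp + fn
    # Classes absent from both GT and predictions become NaN so that nanmean actually skips them
    iou = np.where(denom > 0, tp / np.maximum(denom, 1e-10), np.nan)
    return np.nanmean(iou), iou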

def denormalize_depth(norm_depth, mean=0.5, std=0.25):
    return norm_depth * std + mean

# --- 3. Feature Engineering (Strategy C) ---
def get_neighbors_distribution(blob_mask, map_for_context, num_classes=13):
    """Compute the class distribution of the pixels bordering a blob."""
    struct = generate_binary_structure(2, 1)
    dilated = binary_dilation(blob_mask, structure=struct, iterations=2)
    neighbor_mask = dilated & ~blob_mask
    neighbor_labels = map_for_context[neighbor_mask]

    # Drop ignore (255) and any other out-of-range labels
    valid_labels = neighbor_labels[neighbor_labels < num_classes]

    if len(valid_labels) == 0:
        return np.zeros(num_classes, dtype=np.float32)
    counts = np.bincount(valid_labels, minlength=num_classes)
    return counts / counts.sum()

def extract_blob_features_gbdt(props, blob_mask, depth_map, map_for_context, pred_cls_id):
    """
    Feature extraction: the base model's confidence is deliberately not included.
    """
    features = {}

    # 1. Geometry
    y0, x0, y1, x1 = props.bbox
    h_blob, w_blob = y1 - y0, x1 - x0
    features['area'] = props.area
    features['aspect_ratio'] = h_blob / (w_blob + 1e-6)
    features['rectangularity'] = props.area / (h_blob * w_blob + 1e-6)
    features['centroid_norm_y'] = props.centroid[0] / depth_map.shape[0]

    # 2. Depth (physical cue: rows of books have a jagged depth profile)
    d_vals = depth_map[blob_mask]
    features['depth_mean'] = np.mean(d_vals)
    features['depth_std'] = np.std(d_vals)

    # 3. Context (is the blob inside a shelf?)
    # Neighbors are read from the predicted map, matching the conditions at inference time
    neighbor_probs = get_neighbors_distribution(blob_mask, map_for_context, NUM_CLASSES)
    features['neighbor_prob_cabinet'] = neighbor_probs[ID_CABINET]
    features['neighbor_prob_wall'] = neighbor_probs[11]  # 11 = "Wall" in CLASS_NAMES
    features['neighbor_prob_obj'] = neighbor_probs[ID_OBJECT]

    # 4. Original predicted class (informative in its own right)
    features['pred_class_id'] = pred_cls_id

    return features

# --- 4. Data Collection (Mining from Predictions) ---
def collect_blobs_from_pred(base_model, loader, device, desc="Mining"):
    """
    Run base-model inference, cut out blobs, and match them against GT to collect data.
    Used for both train and val (on train, the GT serves as the supervision target).
    """
    base_model.eval()
    data_rows = []

    with torch.no_grad():
        for x, y in tqdm(loader, desc=desc):
            x = x.to(device)
            y_np = y.numpy()

            # Base Model Inference
            logits = base_model(x)
            probs = F.softmax(logits, dim=1).cpu().numpy()
            preds = probs.argmax(axis=1)
            depths = x[:, 3, :, :].cpu().numpy()

            for b in range(x.shape[0]):
                pred_map = preds[b]
                gt_map = y_np[b]
                depth_map = denormalize_depth(depths[b])

                # Inspect every blob predicted as a target class (Book, Cabinet, Object)
                for target_id in TARGET_IDS:
                    mask_cls = (pred_map == target_id)
                    if not mask_cls.any(): continue

                    labeled = label(mask_cls)
                    regions = regionprops(labeled)

                    for prop in regions:
                        if prop.area < 50: continue  # skip blobs too small to be more than noise

                        blob_mask = np.zeros_like(pred_map, dtype=bool)
                        blob_mask[prop.coords[:,0], prop.coords[:,1]] = True

                        # Feature extraction
                        # map_for_context=pred_map (context has to come from the predictions, since GT is unavailable at inference)
                        feats = extract_blob_features_gbdt(
                            prop, blob_mask, depth_map,
                            map_for_context=pred_map,
                            pred_cls_id=target_id
                        )

                        # [Important] Assign the ground-truth label
                        # Take the most frequent GT class inside the blob
                        gt_vals = gt_map[blob_mask]
                        gt_vals = gt_vals[gt_vals != 255] # Ignore除外

                        if len(gt_vals) == 0:
                            # Skip blobs whose pixels are all ignore
                            continue

                        majority_label = np.argmax(np.bincount(gt_vals, minlength=NUM_CLASSES))
                        feats['target_label'] = majority_label

                        data_rows.append(feats)

    return pd.DataFrame(data_rows)

# --- 5. Main Pipeline ---
def run_gbdt_correction_pipeline_C():
    print(f"==========================================")
    print(f" START GBDT PIPELINE (Strategy C: Learn Error Patterns)")
    print(f"==========================================")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base_model_path = "/content/base_best_model.pt"
    if not os.path.exists(base_model_path):
        print(f"[ERROR] Base model not found.")
        return

    print(f"[INFO] Loading Base Model...")
    try:
        model = ResNeXtDeepLabV3Plus_OS8().to(device)
        model.load_state_dict(torch.load(base_model_path, map_location=device))
        model.eval()
    except NameError:
        print("[ERROR] Model class not defined.")
        return

    # Data Split
    dataset_root = "/content/data"
    ds_full = NYUv2DatasetFixed(dataset_root, split="train", transform=None)
    if len(ds_full) == 0: return

    indices = list(range(len(ds_full)))
    random.seed(42)
    random.shuffle(indices)

    n_val = int(len(ds_full) * 0.1)
    train_idx = indices[:-n_val]
    val_idx = indices[-n_val:]
    print(f"[INFO] Train: {len(train_idx)}, Val: {len(val_idx)}")

    val_aug = A.Compose([A.Resize(768, 768)], additional_targets={'depth': 'image', 'height': 'image'})
    full_ds_resized = NYUv2DatasetFixed(dataset_root, split="train", transform=val_aug)

    ds_train_mining = Subset(full_ds_resized, train_idx)
    ds_val = Subset(full_ds_resized, val_idx)

    # Leak Check
    check_loader = DataLoader(ds_val, batch_size=8, shuffle=False, num_workers=4)
    hist = np.zeros((NUM_CLASSES, NUM_CLASSES))
    with torch.no_grad():
        for x, y in tqdm(check_loader, desc="Base mIoU Check"):
            x = x.to(device)
            pred = model(x).argmax(dim=1).cpu().numpy()
            y = y.numpy()
            hist += fast_hist(y.flatten(), pred.flatten(), NUM_CLASSES)
    base_miou, base_iou_per_class = compute_miou_simple(hist)
    print(f"\n[BASELINE] Val mIoU: {base_miou:.5f}")

    # Collect Data
    # ★ Core of Strategy C: run the base model on the train set as well and collect how it gets things wrong
    print(f"\n[STEP 1] Mining Error Patterns from Train Set...")
    loader_train = DataLoader(ds_train_mining, batch_size=1, shuffle=False, num_workers=4)
    df_train = collect_blobs_from_pred(model, loader_train, device, "Mining Train Preds")

    print(f"[STEP 2] Mining Blobs from Val Set...")
    loader_val = DataLoader(ds_val, batch_size=1, shuffle=False, num_workers=4)
    df_val = collect_blobs_from_pred(model, loader_val, device, "Mining Val Preds")

    if len(df_train) == 0:
        print("[ERROR] No blobs collected.")
        return

    df_train.to_csv(os.path.join(OUTPUT_DIR, "gbdt_train_noisy.csv"), index=False)

    # Train GBDT
    print(f"\n[STEP 3] Training LightGBM on Noisy Blobs...")

    # Label filtering for training:
    # GT labels outside the target classes (1, 5, 6), i.e. false detections, are merged into "99" (Other)
    def adjust_label(x):
        if x in TARGET_IDS: return x
        return 99

    X_train = df_train.drop(columns=['target_label'])
    y_train = df_train['target_label'].apply(adjust_label)

    X_val = df_val.drop(columns=['target_label'])
    y_val = df_val['target_label'].apply(adjust_label) # for evaluation

    cat_features = ['pred_class_id']
    for col in cat_features:
        X_train[col] = X_train[col].astype('category')
        X_val[col] = X_val[col].astype('category')

    # Class weights: prioritize learning Book (1).
    # Other (99) is frequent, so its weight is reduced.
    custom_weights = {
        ID_BOOK: 5.0,     # Book (highest priority)
        ID_CABINET: 1.0,  # Cabinet
        ID_OBJECT: 1.0,   # Object
        99: 0.5           # Other (Wall, Floor, etc.)
    }

    # num_class is not passed: the scikit-learn wrapper derives the number of
    # classes from the labels seen in fit() (the target IDs plus 99).
    lgb_model = lgb.LGBMClassifier(
        objective='multiclass', n_estimators=1000, learning_rate=0.03,
        num_leaves=31, class_weight=custom_weights, random_state=42, n_jobs=-1,
        verbose=-1
    )

    lgb_model.fit(
        X_train, y_train,
        eval_set=[(X_val, y_val)],
        callbacks=[lgb.early_stopping(stopping_rounds=100), lgb.log_evaluation(period=0)]
    )

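    # Feature Importance
    # feature_importances_ reports split counts unless importance_type='gain'
    # is set on the classifier, so the column is named 'split' here.
    print("\n[ANALYSIS] GBDT Feature Importance (split count):")
    imp_df = pd.DataFrame({'feat': X_train.columns, 'split': lgb_model.feature_importances_}).sort_values('split', ascending=False)
    print(imp_df.head(10))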

    # Apply Correction
    print(f"\n[STEP 4] Applying Correction...")
    corrected_hist = hist.copy()

    gbdt_probs = lgb_model.predict_proba(X_val)
    classes = lgb_model.classes_

    df_val_reset = df_val.reset_index(drop=True)
    changes_count = 0
    success_change = 0

    for i, row in df_val_reset.iterrows():
        base_cls = int(row['pred_class_id'])
        true_cls = int(row['target_label'])

        top_idx = np.argmax(gbdt_probs[i])
        gbdt_cls = classes[top_idx]
        prob = gbdt_probs[i][top_idx]

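        # Correction rule: relabel the blob only when
        # 1. the GBDT predicts something other than Other (99),
        # 2. that prediction differs from the base prediction, and
        # 3. its confidence exceeds the threshold (0.6, kept somewhat lenient because the training data is noisy)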
        if gbdt_cls != 99 and gbdt_cls != base_cls and prob > 0.6:
            area = int(row['area'])
            if true_cls < NUM_CLASSES:
                corrected_hist[true_cls, base_cls] -= area
                corrected_hist[true_cls, gbdt_cls] += area
            changes_count += 1
            if gbdt_cls == true_cls: success_change += 1

    new_miou, new_iou_per_class = compute_miou_simple(corrected_hist)

    print(f"\n[RESULT] Correction Complete.")
    print(f"  -> Blobs Corrected: {changes_count} / {len(df_val)}")
    if changes_count > 0:
        print(f"  -> Correction Accuracy: {success_change / changes_count:.2%}")

    print(f"\n[COMPARISON]")
    print(f"  mIoU:    {base_miou:.5f} -> {new_miou:.5f} ({new_miou - base_miou:+.5f})")
    print(f"  Book:    {base_iou_per_class[ID_BOOK]:.5f} -> {new_iou_per_class[ID_BOOK]:.5f} ({new_iou_per_class[ID_BOOK] - base_iou_per_class[ID_BOOK]:+.5f})")

    lgb_model.booster_.save_model(os.path.join(OUTPUT_DIR, "gbdt_correction_model_C.txt"))
    print(f"\n[DONE] GBDT Model saved.")

if __name__ == "__main__":
    run_gbdt_correction_pipeline_C()
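
Step 4 of the pipeline above does not re-render any prediction map. It edits the confusion matrix directly, treating all `area` pixels of a blob as if they carried the blob's majority GT label. The following is a minimal NumPy sketch of that bookkeeping, assuming rows index the ground truth and columns the prediction (the layout that `corrected_hist[true_cls, base_cls]` implies). The names `miou_from_hist` and `relabel_blob` are hypothetical and exist only for this illustration; `compute_miou_simple` in the notebook may treat absent classes differently.

```python
import numpy as np

def miou_from_hist(hist):
    """Return (mIoU, per-class IoU) from a confusion matrix (rows: GT, cols: prediction)."""
    tp = np.diag(hist).astype(np.float64)
    denom = hist.sum(axis=1) + hist.sum(axis=0) - tp  # TP + FN + FP
    iou = np.full(len(tp), np.nan)
    valid = denom > 0  # classes that never appear stay NaN and are skipped in the mean
    iou[valid] = tp[valid] / denom[valid]
    return float(np.nanmean(iou)), iou

def relabel_blob(hist, true_cls, old_pred, new_pred, area):
    """Move `area` pixels of one blob from prediction column `old_pred` to `new_pred`."""
    out = hist.copy()
    out[true_cls, old_pred] -= area
    out[true_cls, new_pred] += area
    return out

# Toy 3-class example: a 100-pixel blob of GT class 1 was predicted as class 2.
hist = np.array([[500, 0, 0],
                 [0, 300, 100],
                 [0, 0, 400]], dtype=np.float64)
before, _ = miou_from_hist(hist)
fixed = relabel_blob(hist, true_cls=1, old_pred=2, new_pred=1, area=100)
after, _ = miou_from_hist(fixed)
print(f"mIoU {before:.4f} -> {after:.4f}")
```

Because the majority label stands in for every pixel of the blob, the resulting mIoU is an estimate: blobs that straddle two GT classes, or that contain ignore pixels, are counted less precisely than a full re-evaluation of the corrected prediction maps would count them.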

In [None]:
import os
import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import datetime
import pytz

def get_jst_time():
    return datetime.datetime.now(pytz.timezone('Asia/Tokyo')).strftime('%Y-%m-%d %H:%M:%S JST')

def analyze_dataset(data_root):
    """
    Generate analysis images that summarize the characteristics of each class in the NYUv2 dataset.
    """
    print(f"[{get_jst_time()}] Starting analysis: {data_root}")

    # Paths (following the directory layout of the provided dataset)
    image_dir = os.path.join(data_root, 'train', 'image')
    label_dir = os.path.join(data_root, 'train', 'label')

    # Output directory
    output_dir = '/content/class_analysis_results'
    os.makedirs(output_dir, exist_ok=True)
    print(f"[{get_jst_time()}] Created output directory: {os.path.abspath(output_dir)}")

    # Collect the label files
    label_files = sorted(glob.glob(os.path.join(label_dir, '*.png')))
    if not label_files:
        print(f"[{get_jst_time()}] Error: no label files found. Check the path: {label_dir}")
        return

    print(f"[{get_jst_time()}] Number of files: {len(label_files)}")

    num_classes = 13
    # For each class, keep the file name and pixel count of the image that contains the most pixels of that class
    best_images = {i: {"file": None, "score": 0} for i in range(num_classes)}
    # Total pixels per class (to inspect the distribution)
    class_counts = np.zeros(num_classes, dtype=np.int64)

    print(f"[{get_jst_time()}] Scanning all data to collect class statistics...")

    for lbl_path in tqdm(label_files):
        # Load the label map
        lbl = np.array(Image.open(lbl_path))

        # Count pixels per class ID
        # Note: 255 is the ignore index and is excluded; only 0-12 are counted
        ids, counts = np.unique(lbl, return_counts=True)

        filename = os.path.basename(lbl_path)

        for id_val, count in zip(ids, counts):
            if 0 <= id_val < num_classes:
                class_counts[id_val] += count

                # Update the image in which this class covers the largest area
                if count > best_images[id_val]["score"]:
                    best_images[id_val]["file"] = filename
                    best_images[id_val]["score"] = count

    # Output and visualization
    print(f"[{get_jst_time()}] Generating a representative image for each class...")

    # Colormap (for visualization)
    cmap = plt.get_cmap('tab20')

    fig_summary, ax_summary = plt.subplots(figsize=(12, 6))
    ax_summary.bar(range(num_classes), class_counts)
    ax_summary.set_title("Class Distribution (Pixel Counts)")
    ax_summary.set_xlabel("Class ID")
    ax_summary.set_ylabel("Total Pixels")
    ax_summary.set_xticks(range(num_classes))
    dist_path = os.path.join(output_dir, '00_class_distribution.png')
    plt.savefig(dist_path)
    plt.close()
    print(f"[{get_jst_time()}] Saved distribution plot: {os.path.abspath(dist_path)}")

    for cls_id in range(num_classes):
        info = best_images[cls_id]
        if info["file"] is None:
            print(f"[{get_jst_time()}] Class {cls_id}: no matching data found.")
            continue

        fname = info["file"]
        img_path = os.path.join(image_dir, fname)
        lbl_path = os.path.join(label_dir, fname)

        # Load the image and label
        img = Image.open(img_path).convert('RGB')
        lbl = np.array(Image.open(lbl_path))

        # Build the figure
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # 1. Original image
        axes[0].imshow(img)
        axes[0].set_title(f"Original Image: {fname}")
        axes[0].axis('off')

        # 2. Region of this class only (mask)
        mask = (lbl == cls_id)
        img_np = np.array(img)
        overlay = img_np.copy()
        # Highlight in red
        overlay[mask] = [255, 0, 0]
        # Blend with the original image
        blended = (img_np * 0.5 + overlay * 0.5).astype(np.uint8)

        axes[1].imshow(blended)
        axes[1].set_title(f"Class {cls_id} Mask (Red Area)")
        axes[1].axis('off')

        # 3. Full segmentation map (for reference)
        axes[2].imshow(lbl, cmap='tab20', vmin=0, vmax=19) # tab20 has 20 colors, enough to cover classes 0-12
        axes[2].set_title("Full Segmentation Label")
        axes[2].axis('off')

        save_path = os.path.join(output_dir, f'class_{cls_id:02d}_analysis.png')
        plt.savefig(save_path)
        plt.close()
        print(f"[{get_jst_time()}] Saved analysis image for class {cls_id}: {os.path.abspath(save_path)}")

    print(f"[{get_jst_time()}] All processing finished.")

# Run
# Note: pass the dataset root directory.
# The call below uses '/content/data', with train/ directly underneath it.
if __name__ == "__main__":
    analyze_dataset('/content/data')
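
The scan above accumulates per-class pixel totals with `np.unique`. The same count for a single label map can be sketched with `np.bincount`, dropping the ignore index first. `count_class_pixels` is a hypothetical helper written for this illustration, not a function from the notebook; the 13 classes and the ignore value 255 follow the dataset description.

```python
import numpy as np

def count_class_pixels(label_map, num_classes=13, ignore_index=255):
    """Count pixels per class in one label map, skipping ignore and out-of-range values."""
    flat = np.asarray(label_map).ravel()
    keep = (flat != ignore_index) & (flat < num_classes)
    # bincount needs non-negative integers; cast from uint8 to be explicit
    return np.bincount(flat[keep].astype(np.int64), minlength=num_classes)

# Toy label map: one pixel of class 0, three of class 1, one of class 12, one ignored.
toy = np.array([[0, 1, 1],
                [255, 12, 1]], dtype=np.uint8)
counts = count_class_pixels(toy)
print(counts)
```

Summing these vectors over all label files gives the same totals as `class_counts`, and the index of the largest per-file count for a class identifies its representative image.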

In [None]:
import os
import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import datetime
import pytz
import shutil
import scipy.ndimage as ndimage
import pandas as pd

def get_jst_time():
    return datetime.datetime.now(pytz.timezone('Asia/Tokyo')).strftime('%Y-%m-%d %H:%M:%S JST')

def analyze_class1_with_depth(data_root):
    """
    Detailed analysis focused on Class 1 (Books), including depth information.
    """
    start_time = get_jst_time()
    print(f"[{start_time}] Starting detailed Class 1 analysis (with depth): {data_root}")

    # Paths
    train_label_dir = os.path.join(data_root, 'train', 'label')
    train_image_dir = os.path.join(data_root, 'train', 'image')
    train_depth_dir = os.path.join(data_root, 'train', 'depth') # depth directory

    output_base = '/content/class1_deep_analysis_v2'
    crop_dir = os.path.join(output_base, 'crops')
    os.makedirs(output_base, exist_ok=True)
    os.makedirs(crop_dir, exist_ok=True)
    print(f"[{get_jst_time()}] Created output directory: {os.path.abspath(output_base)}")

    label_files = sorted(glob.glob(os.path.join(train_label_dir, '*.png')))

    # Containers for statistics
    adj_class_counts = np.zeros(13, dtype=np.int64) # counts of adjacent classes
    pixel_colors = [] # RGB values (sampled)
    pixel_depths = [] # depth values (sampled)
    depth_diffs = []  # depth difference between a book blob and its surroundings (context contrast)
    blob_sizes = []   # connected-component sizes (in pixels)

    sample_count = 0
    max_samples = 100000

    print(f"[{get_jst_time()}] Starting data scan ({len(label_files)} files)...")

    for lbl_file in tqdm(label_files):
        fname = os.path.basename(lbl_file)
        lbl_path = lbl_file
        img_path = os.path.join(train_image_dir, fname)
        depth_path = os.path.join(train_depth_dir, fname) # depth path

        # Load the label map
        lbl = np.array(Image.open(lbl_path))

        # Skip images that contain no Class 1 pixels
        if 1 not in lbl:
            continue

        img = np.array(Image.open(img_path).convert('RGB'))

        # Load depth (assumed to be stored as a PNG or similar image file)
        # Read it as an image and convert it to a NumPy array
        depth = np.array(Image.open(depth_path))

        # 1. Build the Class 1 mask
        mask_c1 = (lbl == 1)

        # 2. Size distribution and depth-difference analysis (per connected component)
        labeled_array, num_features = ndimage.label(mask_c1)

        for feat_id in range(1, num_features + 1):
            blob_mask = (labeled_array == feat_id)
            size = np.sum(blob_mask)
            blob_sizes.append(size)

            # --- Depth-difference analysis ---
            # Take the immediate surroundings of this blob
            dilated = ndimage.binary_dilation(blob_mask, iterations=2)
            boundary = dilated & (~blob_mask) & (lbl != 255) # exclude ignore regions

            if np.sum(boundary) > 0:
                # Mean depth of the book blob
                blob_depth_mean = np.mean(depth[blob_mask])
                # Mean depth of the surroundings
                boundary_depth_mean = np.mean(depth[boundary])

                # Difference: (book - surroundings)
                # Positive: the book is farther away; negative: the book is closer
                diff = blob_depth_mean - boundary_depth_mean
                depth_diffs.append(diff)

        # 3. Adjacent-class analysis
        dilation_struct = ndimage.generate_binary_structure(2, 2)
        dilated_mask = ndimage.binary_dilation(mask_c1, structure=dilation_struct)
        boundary_mask = dilated_mask & (~mask_c1)
        boundary_labels = lbl[boundary_mask]
        boundary_labels = boundary_labels[boundary_labels != 255]

        for b_lbl in boundary_labels:
            if 0 <= b_lbl < 13:
                adj_class_counts[b_lbl] += 1

        # 4. Color and depth distributions (pixel sampling)
        if len(pixel_colors) < max_samples:
            c1_pixels_rgb = img[mask_c1]
            c1_pixels_depth = depth[mask_c1]

            if len(c1_pixels_rgb) > 500:
                indices = np.random.choice(len(c1_pixels_rgb), 500, replace=False)
                c1_pixels_rgb = c1_pixels_rgb[indices]
                c1_pixels_depth = c1_pixels_depth[indices]

            pixel_colors.extend(c1_pixels_rgb)
            pixel_depths.extend(c1_pixels_depth)

        # 5. Save crop images (RGB + Depth + Label)
        if sample_count < 20 and np.sum(mask_c1) > 100:
            rows = np.any(mask_c1, axis=1)
            cols = np.any(mask_c1, axis=0)
            rmin, rmax = np.where(rows)[0][[0, -1]]
            cmin, cmax = np.where(cols)[0][[0, -1]]

            margin = 30
            h, w = lbl.shape
            rmin = max(0, rmin - margin)
            rmax = min(h, rmax + margin)
            cmin = max(0, cmin - margin)
            cmax = min(w, cmax + margin)

            crop_img = img[rmin:rmax, cmin:cmax]
            crop_depth = depth[rmin:rmax, cmin:cmax]
            crop_lbl = lbl[rmin:rmax, cmin:cmax]

            # Visualization
            fig, ax = plt.subplots(1, 3, figsize=(12, 4))

            # RGB
            ax[0].imshow(crop_img)
            ax[0].set_title("RGB Crop")
            ax[0].axis('off')

            # Depth (with a colormap)
            # Scale the colormap to the depth range of this crop for readability
            d_min, d_max = crop_depth.min(), crop_depth.max()
            ax[1].imshow(crop_depth, cmap='inferno', vmin=d_min, vmax=d_max)
            ax[1].set_title("Depth Crop")
            ax[1].axis('off')

            # Label
            vis_lbl = np.zeros((crop_lbl.shape[0], crop_lbl.shape[1], 3), dtype=np.uint8)
            vis_lbl[crop_lbl == 1] = [255, 0, 0] # Class 1 = Red
            vis_lbl[crop_lbl != 1] = [50, 50, 50]
            ax[2].imshow(vis_lbl)
            ax[2].set_title("Label (Red=Books)")
            ax[2].axis('off')

            plt.savefig(os.path.join(crop_dir, f'crop_{fname}'))
            plt.close()
            sample_count += 1

    print(f"[{get_jst_time()}] Aggregation finished. Generating plots...")

    # --- Plot the results ---

    # 1. Adjacent-class distribution
    plt.figure(figsize=(10, 6))
    total_adj = np.sum(adj_class_counts)
    if total_adj > 0:
        percentages = adj_class_counts / total_adj * 100
    else:
        percentages = adj_class_counts
    plt.bar(range(13), percentages)
    plt.title("Adjacency Distribution (Neighbors of Class 1)")
    plt.xlabel("Neighbor Class ID")
    plt.ylabel("Percentage (%)")
    plt.xticks(range(13))
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.savefig(os.path.join(output_base, 'adjacency_distribution.png'))
    plt.close()

    # 2. RGB histogram
    pixel_colors = np.array(pixel_colors)
    if len(pixel_colors) > 0:
        plt.figure(figsize=(10, 6))
        colors = ['red', 'green', 'blue']
        for i, color in enumerate(colors):
            plt.hist(pixel_colors[:, i], bins=50, color=color, alpha=0.5, label=color.upper())
        plt.title("RGB Histogram of Class 1")
        plt.legend()
        plt.savefig(os.path.join(output_base, 'rgb_histogram.png'))
        plt.close()

    # 3. Depth histogram (absolute values)
    if len(pixel_depths) > 0:
        plt.figure(figsize=(10, 6))
        plt.hist(pixel_depths, bins=50, color='purple', alpha=0.7)
        plt.title("Depth Value Distribution of Class 1 (Absolute Distance)")
        plt.xlabel("Depth Value")
        plt.ylabel("Frequency")
        plt.savefig(os.path.join(output_base, 'depth_histogram.png'))
        plt.close()

    # 4. Depth-difference histogram (relative: book - surroundings)
    if len(depth_diffs) > 0:
        plt.figure(figsize=(10, 6))
        plt.hist(depth_diffs, bins=50, color='orange', alpha=0.7)
        plt.title("Depth Contrast: (Books Mean Depth) - (Surroundings Mean Depth)")
        plt.xlabel("Depth Difference (<0: Closer than surroundings, >0: Further)")
        plt.ylabel("Count (Blobs)")
        plt.axvline(x=0, color='k', linestyle='--')
        plt.savefig(os.path.join(output_base, 'depth_contrast_histogram.png'))
        plt.close()

    # 5. Size distribution
    if len(blob_sizes) > 0:
        plt.figure(figsize=(10, 6))
        plt.hist(blob_sizes, bins=50, log=True)
        plt.title("Size Distribution of Class 1 Blobs")
        plt.savefig(os.path.join(output_base, 'size_distribution.png'))
        plt.close()

    # CSV/TXT output
    adj_df = pd.DataFrame({'neighbor_class_id': range(13), 'count': adj_class_counts, 'percentage': percentages})
    adj_df.to_csv(os.path.join(output_base, 'adjacency_stats.csv'), index=False)

    with open(os.path.join(output_base, 'stats_summary.txt'), 'w') as f:
        f.write("=== Class 1 Statistics ===\n")
        f.write(f"Total Blobs Analyzed: {len(blob_sizes)}\n")
        if len(blob_sizes) > 0:
            f.write(f"Mean Blob Size: {np.mean(blob_sizes):.2f}\n")
        if len(depth_diffs) > 0:
            f.write(f"Mean Depth Diff (Books - Surround): {np.mean(depth_diffs):.2f}\n")
            f.write(f"  (Negative means books are closer/pop-out, Positive means they are recessed)\n")

    # ZIP archive
    shutil.make_archive('/content/class1_analysis_v2', 'zip', output_base)
    print(f"[{get_jst_time()}] Analysis finished. Results archived: /content/class1_analysis_v2.zip")

if __name__ == "__main__":
    analyze_class1_with_depth('/content/data')
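
The depth-contrast statistic in the cell above compares the mean depth of a blob with the mean depth of a thin ring around it. Here is a small self-contained sketch of that computation on synthetic data. To keep it NumPy-only, `dilate4` is a hypothetical stand-in for `ndimage.binary_dilation` with its default cross-shaped (4-connected) structuring element, and `depth_contrast` is likewise an illustrative name rather than a function from the notebook.

```python
import numpy as np

def dilate4(mask, iterations=1):
    """4-connected binary dilation implemented with shifted copies of the mask."""
    out = np.asarray(mask, dtype=bool)
    for _ in range(iterations):
        p = np.pad(out, 1, mode="constant", constant_values=False)
        out = (p[1:-1, 1:-1] | p[:-2, 1:-1] | p[2:, 1:-1]
               | p[1:-1, :-2] | p[1:-1, 2:])
    return out

def depth_contrast(depth, blob_mask, valid_mask, ring_width=2):
    """Mean blob depth minus mean depth of the surrounding ring (None if the ring is empty)."""
    ring = dilate4(blob_mask, ring_width) & ~blob_mask & valid_mask
    if not ring.any():
        return None
    return float(depth[blob_mask].mean() - depth[ring].mean())

# Synthetic scene: a flat background at depth 100 with a one-pixel blob at depth 80.
depth = np.full((7, 7), 100.0)
blob = np.zeros((7, 7), dtype=bool)
blob[3, 3] = True
depth[blob] = 80.0
valid = np.ones_like(blob)  # no ignore pixels in this toy example
diff = depth_contrast(depth, blob, valid)
print(diff)  # negative: the blob is closer than its surroundings
```

A negative value means the blob sits in front of its surroundings, which matches the sign convention written to `stats_summary.txt`.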

In [None]:
import os
import glob
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import datetime
import pytz
import scipy.ndimage as ndimage
from sklearn.metrics import roc_auc_score, accuracy_score
import shutil

def get_jst_time():
    return datetime.datetime.now(pytz.timezone('Asia/Tokyo')).strftime('%Y-%m-%d %H:%M:%S JST')

def analyze_separation_probability(data_root):
    """
    Collect edge-feature distributions for Classes 1, 5, and 6 and quantify
    how well edge information alone can separate them (AUC / accuracy).
    """
    start_time = get_jst_time()
    print(f"[{start_time}] Starting separability (discrimination) analysis: {data_root}")

    # Paths
    train_label_dir = os.path.join(data_root, 'train', 'label')
    train_image_dir = os.path.join(data_root, 'train', 'image')

    output_dir = '/content/separation_analysis'
    os.makedirs(output_dir, exist_ok=True)
    print(f"[{get_jst_time()}] Created output directory: {os.path.abspath(output_dir)}")

    label_files = sorted(glob.glob(os.path.join(train_label_dir, '*.png')))

    # Target classes
    target_classes = [1, 5, 6] # Books, Furniture, Objects
    class_names = {1: "Books", 5: "Furniture", 6: "Objects"}

    # Sample storage (sampled to save memory)
    # Keep roughly up to 100,000 pixel values per class
    max_samples = 100000
    collected_samples = {k: [] for k in target_classes}

    print(f"[{get_jst_time()}] Collecting data...")

    # Scan the files in random order (to avoid bias)
    indices = np.random.permutation(len(label_files))

    for idx in tqdm(indices):
        # Stop once every class has enough samples
        if all(len(v) >= max_samples for v in collected_samples.values()):
            break

        lbl_path = label_files[idx]
        fname = os.path.basename(lbl_path)
        img_path = os.path.join(train_image_dir, fname)

        lbl = np.array(Image.open(lbl_path))

        # Check whether any target class is present
        unique_lbls = np.unique(lbl)
        if not any(c in unique_lbls for c in target_classes):
            continue

        # Load the image and compute edges
        img = Image.open(img_path).convert('L')
        img_arr = np.array(img, dtype=np.float32)

        gx = ndimage.sobel(img_arr, axis=1)
        gy = ndimage.sobel(img_arr, axis=0)
        magnitude = np.hypot(gx, gy)

        # Sample pixel values for each class
        for cls_id in target_classes:
            if len(collected_samples[cls_id]) >= max_samples:
                continue

            mask = (lbl == cls_id)
            if np.sum(mask) > 0:
                vals = magnitude[mask]
                # Draw at most 1000 pixels per image (to keep the samples diverse)
                if len(vals) > 1000:
                    vals = np.random.choice(vals, 1000, replace=False)

                collected_samples[cls_id].extend(vals)

    # Convert to arrays
    for k in collected_samples:
        collected_samples[k] = np.array(collected_samples[k])
        print(f"Class {k} samples: {len(collected_samples[k])}")

    print(f"[{get_jst_time()}] Computing discrimination metrics (AUC/Accuracy)...")

    report_path = os.path.join(output_dir, 'separation_report.txt')

    with open(report_path, 'w') as f:
        f.write("=== Separation Analysis Report (Edge Magnitude) ===\n")

        # Comparison pairs: 1 vs 5, 1 vs 6
        pairs = [(1, 5), (1, 6)]

        for c_pos, c_neg in pairs:
            # Prepare the data
            pos_data = collected_samples[c_pos] # Class 1 (Books)
            neg_data = collected_samples[c_neg] # Class 5 or 6

            if len(pos_data) == 0 or len(neg_data) == 0:
                f.write(f"\nInsufficient data for comparison: {class_names[c_pos]} vs {class_names[c_neg]}\n")
                continue

            # Build labels (Pos=1, Neg=0)
            y_true = np.concatenate([np.ones(len(pos_data)), np.zeros(len(neg_data))])
            y_scores = np.concatenate([pos_data, neg_data])

            # 1. AUC
            auc = roc_auc_score(y_true, y_scores)

            # 2. Accuracy at the best threshold
            # Simple search over the observed score range. The Sobel magnitude of a
            # 0-255 grayscale image can exceed 255, so a fixed 0-255 grid would miss
            # the upper part of the distribution.
            best_acc = 0
            best_thresh = 0
            thresholds = np.linspace(y_scores.min(), y_scores.max(), 256)

            # Sampling could speed this up, but a plain loop is used here
            for th in thresholds:
                y_pred = (y_scores >= th).astype(int)
                acc = accuracy_score(y_true, y_pred)
                if acc > best_acc:
                    best_acc = acc
                    best_thresh = th

            msg = f"\nComparison: {class_names[c_pos]} vs {class_names[c_neg]}\n"
            msg += f"  - AUC Score: {auc:.4f} (Probability of correct ranking)\n"
            msg += f"  - Max Accuracy: {best_acc:.2%} (at threshold={best_thresh:.1f})\n"
            msg += f"  - Interpretation: Edge magnitude alone can distinguish them with {best_acc:.1%} accuracy.\n"

            print(msg)
            f.write(msg)

    print(f"[{get_jst_time()}] Report saved: {os.path.abspath(report_path)}")

    # Plot the histograms
    plt.figure(figsize=(10, 6))
    colors = {1: 'red', 5: 'blue', 6: 'green'}
    for cls_id in target_classes:
        data = collected_samples[cls_id]
        if len(data) > 0:
            # Plot as a density distribution
            plt.hist(data, bins=50, alpha=0.5, label=f"{class_names[cls_id]} (Mean:{np.mean(data):.1f})",
                     density=True, color=colors[cls_id])

    plt.title("Edge Magnitude Distribution (Class 1 vs 5 vs 6)")
    plt.xlabel("Edge Magnitude")
    plt.ylabel("Density")
    plt.legend()
    plt.grid(linestyle='--', alpha=0.5)

    plot_path = os.path.join(output_dir, 'separation_histogram.png')
    plt.savefig(plot_path)
    plt.close()
    print(f"[{get_jst_time()}] Distribution plot saved: {os.path.abspath(plot_path)}")

    # ZIP archive
    shutil.make_archive('/content/separation_analysis', 'zip', output_dir)
    print(f"[{get_jst_time()}] Results archived: /content/separation_analysis.zip")

if __name__ == "__main__":
    analyze_separation_probability('/content/data')

In [None]:
import os
import random
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm
from torch.cuda.amp import GradScaler # For AMP

# ==========================================
# 0. A100 Optimization Flags
# ==========================================
# Enable TF32 matmuls (a throughput setting for Ampere GPUs such as the A100)
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True # faster when the input size is fixed

# ==========================================
# 1. Configuration
# ==========================================
class Config:
    DATA_ROOT = '/content/data'
    ENCODER = 'resnet50'
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = 13
    INPUT_CHANNELS = 5

    # Scaled up substantially for an A100 80GB
    IMG_SIZE = 512
    BATCH_SIZE = 64      # 8 -> 64 (can go higher if VRAM allows)
    NUM_WORKERS = 16     # number of parallel CPU workers
    EPOCHS = 50
    LR = 2e-4            # raised slightly to match the larger batch size
    WEIGHT_DECAY = 1e-4
    SEED = 42
    device = torch.device('cuda')

    # Masking
    MASK_PROB = 0.8
    MASK_RATIO = 0.4
    BLOCK_SIZE = 20

    CLASS_WEIGHTS = [1.0, 5.0, 1.0, 1.5, 1.0, 1.2, 2.0, 1.5, 1.2, 1.5, 5.0, 0.8, 1.2]

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

# ==========================================
# 2. Targeted Masking (Optimized)
# ==========================================
class TargetedBookMasking:
    def __init__(self, prob=0.5, mask_ratio=0.3, block_size=15):
        self.prob = prob
        self.mask_ratio = mask_ratio
        self.block_size = block_size

    def __call__(self, image, label):
        if random.random() > self.prob:
            return image

        # Locate all book pixels (label == 1) with one vectorized NumPy call
        ys, xs = np.where(label == 1)
        num_book_pixels = len(ys)

        if num_book_pixels == 0:
            return image

        pixels_to_mask = int(num_book_pixels * self.mask_ratio)
        # Cap the number of blocks to limit loop iterations
        num_blocks = max(1, min(pixels_to_mask // (self.block_size ** 2), 50))

        masked_image = image.copy()
        H, W = label.shape

        # Draw all block centers in a single sampling call
        indices = np.random.choice(num_book_pixels, num_blocks)
        cys, cxs = ys[indices], xs[indices]

        bs_half = self.block_size // 2

        for cy, cx in zip(cys, cxs):
            y1 = max(0, cy - bs_half)
            y2 = min(H, cy + bs_half)
            x1 = max(0, cx - bs_half)
            x2 = min(W, cx + bs_half)
            masked_image[y1:y2, x1:x2, :] = 0.0

        return masked_image

# ==========================================
# 3. Dataset (Fast Loading)
# ==========================================
class NYUv2Dataset(Dataset):
    def __init__(self, root_dir, split='train', img_size=512, transform=None, book_masker=None):
        self.root_dir = root_dir
        self.img_size = img_size
        self.transform = transform
        self.book_masker = book_masker

        src_split = 'train' if split in ['train', 'val'] else 'test'
        self.img_dir = os.path.join(root_dir, src_split, 'image')
        self.depth_dir = os.path.join(root_dir, src_split, 'depth')
        self.label_dir = os.path.join(root_dir, src_split, 'label')

        if not os.path.exists(self.img_dir):
            self.ids = []
            print(f"[WARN] No images found in {self.img_dir}")
        else:
            self.ids = sorted([os.path.splitext(f)[0] for f in os.listdir(self.img_dir) if f.endswith('.png')])

    def __len__(self):
        return len(self.ids)

    def compute_edge_map(self, img_array):
        # Cheap enough to leave on the CPU (OpenCV)
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        magnitude = np.sqrt(sobelx**2 + sobely**2)
        magnitude = cv2.normalize(magnitude, None, 0, 1, cv2.NORM_MINMAX)
        return magnitude.astype(np.float32)

    def __getitem__(self, idx):
        id_ = self.ids[idx]

        # Keep OpenCV reads to a minimum to avoid an I/O bottleneck
        img_path = os.path.join(self.img_dir, f"{id_}.png")
        depth_path = os.path.join(self.depth_dir, f"{id_}.png")
        label_path = os.path.join(self.label_dir, f"{id_}.png")

        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
        if depth is None: depth = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32)
        if depth.max() > 0: depth = depth.astype(np.float32) / depth.max()
        else: depth = depth.astype(np.float32)

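        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if label is None:
            # No label file (e.g. the test split has none): fill with the ignore index
            label = np.full(image.shape[:2], 255, dtype=np.uint8)

        # Label Sanitize
        mask_valid = (label >= 0) & (label < Config.CLASSES)
        label[~mask_valid] = 255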

        edge = self.compute_edge_map(image)

        depth = depth * 255.0
        edge = edge * 255.0

        combined_input = np.dstack([image, depth, edge]).astype(np.float32)

        if self.book_masker:
            combined_input = self.book_masker(combined_input, label)

        if self.transform:
            transformed = self.transform(image=combined_input, mask=label)
            combined_input = transformed['image']
            label = transformed['mask']
        else:
            resize = A.Resize(height=self.img_size, width=self.img_size)
            t = resize(image=combined_input, mask=label)
            img_np = t['image'].transpose(2, 0, 1).astype(np.float32)
            img_np /= 255.0
            combined_input = torch.from_numpy(img_np)
            label = torch.from_numpy(t['mask']).long()

        return combined_input, label.long()

# ==========================================
# 4. Transforms & Model
# ==========================================
def get_transforms(phase='train'):
    # CPU preprocessing can become the bottleneck even on an A100, so keep the pipeline simple
    if phase == 'train':
        return A.Compose([
            A.Resize(height=Config.IMG_SIZE, width=Config.IMG_SIZE),
            A.HorizontalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, p=0.5),
            A.Normalize(
                mean=[0.485, 0.456, 0.406, 0.5, 0.5],
                std=[0.229, 0.224, 0.225, 0.5, 0.5],
                max_pixel_value=255.0
            ),
            ToTensorV2()
        ])
    else:
        return A.Compose([
            A.Resize(height=Config.IMG_SIZE, width=Config.IMG_SIZE),
            A.Normalize(
                mean=[0.485, 0.456, 0.406, 0.5, 0.5],
                std=[0.229, 0.224, 0.225, 0.5, 0.5],
                max_pixel_value=255.0
            ),
            ToTensorV2()
        ])

def get_model():
    model = smp.DeepLabV3Plus(
        encoder_name=Config.ENCODER,
        encoder_weights=Config.ENCODER_WEIGHTS,
        in_channels=Config.INPUT_CHANNELS,
        classes=Config.CLASSES,
        activation=None
    )
    return model

class WeightedLoss(nn.Module):
    def __init__(self, class_weights, device, ignore_index=255):
        super().__init__()
        weight_tensor = torch.tensor(class_weights).float().to(device)
        self.ce_loss = nn.CrossEntropyLoss(weight=weight_tensor, ignore_index=ignore_index)
        self.dice_loss = smp.losses.DiceLoss(
            mode='multiclass', classes=[1, 10], log_loss=True, from_logits=True, ignore_index=ignore_index
        )

    def forward(self, logits, targets):
        return self.ce_loss(logits, targets) + 0.5 * self.dice_loss(logits, targets)

def compute_miou(pred_mask, label_mask, num_classes):
    # Compute on the GPU to reduce CPU transfers
    ious = torch.zeros(num_classes, device=pred_mask.device)
    pred_mask = pred_mask.view(-1)
    label_mask = label_mask.view(-1)

    valid_mask = label_mask != 255
    pred_mask = pred_mask[valid_mask]
    label_mask = label_mask[valid_mask]

    for cls in range(num_classes):
        pred_inds = pred_mask == cls
        target_inds = label_mask == cls
        intersection = (pred_inds & target_inds).sum()
        union = pred_inds.sum() + target_inds.sum() - intersection
        if union == 0:
            ious[cls] = float('nan')
        else:
            ious[cls] = intersection / union

    return ious.cpu().numpy()  # Transfer to the CPU only once, at the end

# ==========================================
# 5. Optimized Training Loop (AMP + Compile)
# ==========================================
def train_one_epoch(model, loader, criterion, optimizer, device, scaler):
    model.train()
    epoch_loss = 0

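    # prefetch_factor (set on the DataLoader) keeps batches queued so the loop does not stall on data loading
    for images, masks in tqdm(loader, desc="Training", leave=False):
        # non_blocking=True lets the host-to-device copy overlap with computation
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)  # More memory-efficient than zero-filling

        # AMP (Automatic Mixed Precision) context.
        # The A100 handles bfloat16 well, so the dtype is specified explicitly here;
        # the default (float16) would also be fast enough.
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):  # bfloat16 is recommended on the A100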
            outputs = model(images)
            loss = criterion(outputs, masks)

        # Scaler update
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

    return epoch_loss / len(loader)

def validate(model, loader, criterion, device):
    model.eval()
    epoch_loss = 0
    iou_scores = []

    with torch.no_grad():
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            for images, masks in tqdm(loader, desc="Validation", leave=False):
                images = images.to(device, non_blocking=True)
                masks = masks.to(device, non_blocking=True)

                outputs = model(images)
                loss = criterion(outputs, masks)
                epoch_loss += loss.item()

                preds = torch.argmax(outputs, dim=1)
                ious = compute_miou(preds, masks, Config.CLASSES)
                iou_scores.append(ious)

    mean_ious = np.nanmean(np.array(iou_scores), axis=0)
    total_miou = np.nanmean(mean_ious)
    return epoch_loss / len(loader), total_miou, mean_ious

# ==========================================
# 6. Main
# ==========================================
def main():
    set_seed(Config.SEED)
    print(f"🚀 Start Hyper-Fast Training on A100")
    print(f"   Batch Size: {Config.BATCH_SIZE}, AMP: Enabled (BF16)")

    # Datasets
    masker = TargetedBookMasking(prob=Config.MASK_PROB, mask_ratio=Config.MASK_RATIO, block_size=Config.BLOCK_SIZE)

    full_dataset_train = NYUv2Dataset(Config.DATA_ROOT, split='train',
                                      transform=get_transforms('train'),
                                      book_masker=masker)
    full_dataset_val = NYUv2Dataset(Config.DATA_ROOT, split='train',
                                    transform=get_transforms('val'),
                                    book_masker=None)

    n_samples = len(full_dataset_train)
    if n_samples == 0: return

    indices = list(range(n_samples))
    random.shuffle(indices)

    n_val = int(n_samples * 0.1)
    train_idx = indices[:-n_val]
    val_idx = indices[-n_val:]

    train_dataset = Subset(full_dataset_train, train_idx)
    val_dataset = Subset(full_dataset_val, val_idx)

    # DataLoader Optimization
    # pin_memory=True, a larger num_workers, persistent_workers=True
    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,
        num_workers=Config.NUM_WORKERS,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=Config.NUM_WORKERS,
        pin_memory=True,
        persistent_workers=True
    )

    # Model Setup
    model = get_model().to(Config.device)

    # Tensor Core Optimization (Channels Last)
    # Switch the memory layout to channels-last for faster execution
    model = model.to(memory_format=torch.channels_last)

    # torch.compile (PyTorch 2.x feature); can give a large speedup on an A100
    print("Compiling model with torch.compile...")
    try:
        model = torch.compile(model)
    except Exception as e:
        print(f"Warning: torch.compile failed ({e}). Running without compilation.")

    # AMP Scaler
    scaler = torch.cuda.amp.GradScaler()  # Not strictly needed with BF16, but harmless to keep for stability

    criterion = WeightedLoss(Config.CLASS_WEIGHTS, Config.device, ignore_index=255)
    optimizer = optim.AdamW(model.parameters(), lr=Config.LR, weight_decay=Config.WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.EPOCHS, eta_min=1e-6)

    best_miou = 0.0

    for epoch in range(Config.EPOCHS):
        print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")

        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, Config.device, scaler)
        val_loss, val_miou, class_ious = validate(model, val_loader, criterion, Config.device)

        scheduler.step()

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Val mIoU: {val_miou:.4f}")
        print(f"Books IoU: {class_ious[1]:.4f} | Furniture: {class_ious[5]:.4f}")

        if val_miou > best_miou:
            best_miou = val_miou
            torch.save(model.state_dict(), 'best_model_nyuv2_fast.pth')
            print(">>> Best Model Saved!")

if __name__ == '__main__':
    if os.path.exists(Config.DATA_ROOT):
        # Clear the CUDA cache to help avoid CUDA errors
        torch.cuda.empty_cache()
        main()
    else:
        print(f"Data root '{Config.DATA_ROOT}' not found.")

  _C._set_float32_matmul_precision(precision)


🚀 Start Hyper-Fast Training on A100
   Batch Size: 64, AMP: Enabled (BF16)
Compiling model with torch.compile...


  scaler = torch.cuda.amp.GradScaler() # Not strictly needed with BF16, but harmless to keep for stability



Epoch 1/50




Train Loss: 3.9861 | Val Loss: 4.0485
Val mIoU: 0.0929
Books IoU: 0.1319 | Furniture: 0.0033
>>> Best Model Saved!

Epoch 2/50




Train Loss: 2.9792 | Val Loss: 3.7166
Val mIoU: 0.1349
Books IoU: 0.0613 | Furniture: 0.2168
>>> Best Model Saved!

Epoch 3/50




Train Loss: 2.4807 | Val Loss: 3.0491
Val mIoU: 0.1922
Books IoU: 0.0002 | Furniture: 0.3563
>>> Best Model Saved!

Epoch 4/50




Train Loss: 2.0017 | Val Loss: 2.3279
Val mIoU: 0.2853
Books IoU: 0.0693 | Furniture: 0.3908
>>> Best Model Saved!

Epoch 5/50




Train Loss: 1.7623 | Val Loss: 2.5738
Val mIoU: 0.2654
Books IoU: 0.0723 | Furniture: 0.2697

Epoch 6/50




Train Loss: 1.5375 | Val Loss: 2.1816
Val mIoU: 0.3127
Books IoU: 0.0415 | Furniture: 0.4323
>>> Best Model Saved!

Epoch 7/50




Train Loss: 1.3885 | Val Loss: 1.9228
Val mIoU: 0.4009
Books IoU: 0.1000 | Furniture: 0.4371
>>> Best Model Saved!

Epoch 8/50




Train Loss: 1.2002 | Val Loss: 2.0875
Val mIoU: 0.3801
Books IoU: 0.0698 | Furniture: 0.4255

Epoch 9/50




Train Loss: 1.0920 | Val Loss: 2.0239
Val mIoU: 0.4325
Books IoU: 0.0754 | Furniture: 0.4625
>>> Best Model Saved!

Epoch 10/50




Train Loss: 0.9718 | Val Loss: 2.0784
Val mIoU: 0.4706
Books IoU: 0.0492 | Furniture: 0.4808
>>> Best Model Saved!

Epoch 11/50




Train Loss: 0.9278 | Val Loss: 1.8812
Val mIoU: 0.4859
Books IoU: 0.0364 | Furniture: 0.4771
>>> Best Model Saved!

Epoch 12/50




Train Loss: 0.9128 | Val Loss: 2.3672
Val mIoU: 0.4570
Books IoU: 0.0089 | Furniture: 0.4194

Epoch 13/50




Train Loss: 0.8342 | Val Loss: 2.2808
Val mIoU: 0.4853
Books IoU: 0.0181 | Furniture: 0.4854

Epoch 14/50




Train Loss: 0.7739 | Val Loss: 2.3104
Val mIoU: 0.4951
Books IoU: 0.0097 | Furniture: 0.4803
>>> Best Model Saved!

Epoch 15/50




Train Loss: 0.7702 | Val Loss: 2.3348
Val mIoU: 0.4731
Books IoU: 0.0836 | Furniture: 0.4693

Epoch 16/50




Train Loss: 0.7051 | Val Loss: 2.1563
Val mIoU: 0.4990
Books IoU: 0.0616 | Furniture: 0.5014
>>> Best Model Saved!

Epoch 17/50




Train Loss: 0.6742 | Val Loss: 2.0765
Val mIoU: 0.5236
Books IoU: 0.0719 | Furniture: 0.4937
>>> Best Model Saved!

Epoch 18/50




Train Loss: 0.6071 | Val Loss: 1.9217
Val mIoU: 0.5032
Books IoU: 0.0783 | Furniture: 0.5007

Epoch 19/50




Train Loss: 0.5767 | Val Loss: 2.2234
Val mIoU: 0.5067
Books IoU: 0.0257 | Furniture: 0.5090

Epoch 20/50




Train Loss: 0.7595 | Val Loss: 1.8363
Val mIoU: 0.5161
Books IoU: 0.0852 | Furniture: 0.4755

Epoch 21/50




Train Loss: 0.6961 | Val Loss: 2.6382
Val mIoU: 0.4878
Books IoU: 0.0000 | Furniture: 0.4726

Epoch 22/50




Train Loss: 0.7677 | Val Loss: 2.3368
Val mIoU: 0.5036
Books IoU: 0.0167 | Furniture: 0.4959

Epoch 23/50




Train Loss: 0.6058 | Val Loss: 1.9183
Val mIoU: 0.5092
Books IoU: 0.0953 | Furniture: 0.4743

Epoch 24/50




Train Loss: 0.5831 | Val Loss: 1.8939
Val mIoU: 0.5340
Books IoU: 0.0806 | Furniture: 0.5200
>>> Best Model Saved!

Epoch 25/50




Train Loss: 0.5364 | Val Loss: 1.7601
Val mIoU: 0.5253
Books IoU: 0.1096 | Furniture: 0.5054

Epoch 26/50




Train Loss: 0.5065 | Val Loss: 1.7241
Val mIoU: 0.5436
Books IoU: 0.1092 | Furniture: 0.5297
>>> Best Model Saved!

Epoch 27/50




Train Loss: 0.4721 | Val Loss: 1.6256
Val mIoU: 0.5455
Books IoU: 0.1233 | Furniture: 0.5280
>>> Best Model Saved!

Epoch 28/50




Train Loss: 0.4529 | Val Loss: 1.4999
Val mIoU: 0.5363
Books IoU: 0.1370 | Furniture: 0.5240

Epoch 29/50




Train Loss: 0.4410 | Val Loss: 1.6134
Val mIoU: 0.5482
Books IoU: 0.1223 | Furniture: 0.5369
>>> Best Model Saved!

Epoch 30/50




Train Loss: 0.4661 | Val Loss: 1.5179
Val mIoU: 0.5445
Books IoU: 0.1482 | Furniture: 0.5152

Epoch 31/50




Train Loss: 0.4441 | Val Loss: 1.4206
Val mIoU: 0.5555
Books IoU: 0.1569 | Furniture: 0.5336
>>> Best Model Saved!

Epoch 32/50




Train Loss: 0.4673 | Val Loss: 1.4837
Val mIoU: 0.5523
Books IoU: 0.1499 | Furniture: 0.5231

Epoch 33/50




Train Loss: 0.4273 | Val Loss: 1.3548
Val mIoU: 0.5577
Books IoU: 0.1627 | Furniture: 0.5282
>>> Best Model Saved!

Epoch 34/50




Train Loss: 0.4012 | Val Loss: 1.4201
Val mIoU: 0.5559
Books IoU: 0.1537 | Furniture: 0.5292

Epoch 35/50




Train Loss: 0.5134 | Val Loss: 1.4888
Val mIoU: 0.5502
Books IoU: 0.1336 | Furniture: 0.5238

Epoch 36/50




Train Loss: 0.4053 | Val Loss: 1.5337
Val mIoU: 0.5386
Books IoU: 0.1342 | Furniture: 0.4970

Epoch 37/50




Train Loss: 0.4162 | Val Loss: 1.5166
Val mIoU: 0.5503
Books IoU: 0.1290 | Furniture: 0.5113

Epoch 38/50




Train Loss: 0.4109 | Val Loss: 1.4890
Val mIoU: 0.5563
Books IoU: 0.1362 | Furniture: 0.5246

Epoch 39/50




Train Loss: 0.5543 | Val Loss: 1.5052
Val mIoU: 0.5560
Books IoU: 0.1349 | Furniture: 0.5213

Epoch 40/50




Train Loss: 0.3981 | Val Loss: 1.4710
Val mIoU: 0.5501
Books IoU: 0.1466 | Furniture: 0.5196

Epoch 41/50




Train Loss: 0.3968 | Val Loss: 1.4410
Val mIoU: 0.5576
Books IoU: 0.1537 | Furniture: 0.5215

Epoch 42/50




Train Loss: 0.3935 | Val Loss: 1.4360
Val mIoU: 0.5578
Books IoU: 0.1542 | Furniture: 0.5192
>>> Best Model Saved!

Epoch 43/50




Train Loss: 0.4022 | Val Loss: 1.4519
Val mIoU: 0.5585
Books IoU: 0.1495 | Furniture: 0.5235
>>> Best Model Saved!

Epoch 44/50




Train Loss: 0.3817 | Val Loss: 1.4401
Val mIoU: 0.5593
Books IoU: 0.1531 | Furniture: 0.5262
>>> Best Model Saved!

Epoch 45/50




Train Loss: 0.4229 | Val Loss: 1.4344
Val mIoU: 0.5608
Books IoU: 0.1528 | Furniture: 0.5289
>>> Best Model Saved!

Epoch 46/50




Train Loss: 0.4390 | Val Loss: 1.4799
Val mIoU: 0.5582
Books IoU: 0.1381 | Furniture: 0.5283

Epoch 47/50




Train Loss: 0.3936 | Val Loss: 1.4528
Val mIoU: 0.5567
Books IoU: 0.1454 | Furniture: 0.5243

Epoch 48/50




Train Loss: 0.3779 | Val Loss: 1.4441
Val mIoU: 0.5582
Books IoU: 0.1483 | Furniture: 0.5261

Epoch 49/50




Train Loss: 0.3855 | Val Loss: 1.4552
Val mIoU: 0.5567
Books IoU: 0.1492 | Furniture: 0.5247

Epoch 50/50




Train Loss: 0.3917 | Val Loss: 1.4597
Val mIoU: 0.5591
Books IoU: 0.1471 | Furniture: 0.5270
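
Note on the validation metric: `validate` above computes IoU per batch and then averages the batches with `np.nanmean`. That can differ from an IoU computed over the whole validation set, particularly for rare classes that are absent from some batches. The sketch below shows one common alternative: accumulate a confusion matrix across batches and derive per-class IoU once at the end. It uses NumPy only; `update_confusion` and `iou_from_confusion` are hypothetical helper names that do not appear in the notebook, and whether the competition scorer aggregates this way is an assumption.

```python
import numpy as np

IGNORE_INDEX = 255  # same ignore label as used in the notebook


def update_confusion(conf, pred, label, num_classes):
    """Add one batch to a running (num_classes x num_classes) confusion matrix.

    Rows are ground-truth classes, columns are predicted classes.
    """
    pred = np.asarray(pred).reshape(-1)
    label = np.asarray(label).reshape(-1)
    valid = label != IGNORE_INDEX
    # Encode each (label, pred) pair as a single index, then count occurrences.
    idx = label[valid].astype(np.int64) * num_classes + pred[valid].astype(np.int64)
    conf += np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return conf


def iou_from_confusion(conf):
    """Per-class IoU = TP / (TP + FP + FN); NaN for classes that never appear."""
    tp = np.diag(conf).astype(np.float64)
    fp = conf.sum(axis=0) - tp  # predicted as the class, labelled otherwise
    fn = conf.sum(axis=1) - tp  # labelled as the class, predicted otherwise
    union = tp + fp + fn
    return np.where(union > 0, tp / np.maximum(union, 1), np.nan)


# Toy example: 3 classes, two "batches", one ignored pixel.
conf = np.zeros((3, 3), dtype=np.int64)
conf = update_confusion(conf, pred=[0, 1, 1, 2], label=[0, 1, 2, 255], num_classes=3)
conf = update_confusion(conf, pred=[2, 2, 0, 1], label=[2, 2, 0, 1], num_classes=3)
ious = iou_from_confusion(conf)
miou = np.nanmean(ious)
print("per-class IoU:", ious)
print("mIoU:", miou)
```

In `validate`, the same idea would mean calling the update once per batch and computing the IoU values after the loop.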


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, Dataset
import segmentation_models_pytorch as smp
import numpy as np
import random
import os
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import torch.nn.functional as F

# ==========================================
# 1. Config (weights adjusted for Focal Loss)
# ==========================================
class Config:
    DATA_ROOT = '/content/data'
    ENCODER = 'resnet50'
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = 13
    INPUT_CHANNELS = 5
    IMG_SIZE = 512
    BATCH_SIZE = 64
    EPOCHS = 30     # Train a little longer
    LR = 5e-4       # Focal Loss yields small gradients, so lower the LR slightly or keep it
    SEED = 42
    device = torch.device('cuda')

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

# ==========================================
# 2. Edge-Injecting Adapter (New Architecture)
# ==========================================
class EdgeAwareAdapter(nn.Module):
    def __init__(self, in_channels=256, out_channels=13):
        super().__init__()

        # Path that ingests the edge map (1 channel)
        # Expand 1 channel (edge) to 64 channels
        self.edge_conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )

        # Main Feature (256ch) + Edge Feature (64ch) = 320ch
        self.fusion = nn.Sequential(
            nn.Conv2d(in_channels + 64, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout(0.3),  # Guards against overfitting
            nn.Conv2d(128, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv2d(64, out_channels, kernel_size=1)  # Outputs logits
        )

    def forward(self, features, edge_map):
        # features: [B, 256, H/4, W/4]
        # edge_map: [B, 1, H, W] -> Downsample -> [B, 1, H/4, W/4]

        # 1. Downsample the edge map to the feature-map size
        edge_small = F.interpolate(edge_map, size=features.shape[2:], mode='bilinear', align_corners=False)

        # 2. Extract edge features
        edge_feat = self.edge_conv(edge_small)

        # 3. Concatenate the features
        concat = torch.cat([features, edge_feat], dim=1)

        # 4. Fuse and predict
        out = self.fusion(concat)

        return out

class BookRescueModel(nn.Module):
    def __init__(self, base_model_path):
        super().__init__()

        # Base Model
        self.base_model = smp.DeepLabV3Plus(
            encoder_name=Config.ENCODER,
            encoder_weights=None,
            in_channels=Config.INPUT_CHANNELS,
            classes=Config.CLASSES,
            activation=None
        )

        # Weights Load
        state_dict = torch.load(base_model_path, map_location=Config.device)
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('_orig_mod.'):
                new_state_dict[k[10:]] = v
            else:
                new_state_dict[k] = v
        self.base_model.load_state_dict(new_state_dict)

        # Freeze Base
        for param in self.base_model.parameters():
            param.requires_grad = False
        self.base_model.eval()

        # New Adapter
        self.adapter = EdgeAwareAdapter(in_channels=256, out_channels=Config.CLASSES)

    def forward(self, x):
        # x: [B, 5, H, W] (RGB, Depth, Edge)

        # 1. Base Model Forward
        with torch.no_grad():
            features = self.base_model.encoder(x)
            decoder_output = self.base_model.decoder(features) # [B, 256, 128, 128]
            base_logits = self.base_model.segmentation_head(decoder_output)

        # 2. Extract Edge Channel (Channel 4)
        # x[:, 4:5, :, :] keeps dimension as [B, 1, H, W]
        edge_map = x[:, 4:5, :, :]

        # 3. Adapter Forward with Edge Injection
        correction = self.adapter(decoder_output, edge_map)

        # 4. Upsample Correction
        correction = F.interpolate(
            correction,
            size=base_logits.shape[2:],
            mode='bilinear',
            align_corners=False
        )

        return base_logits + correction

# ==========================================
# 3. Dataset & Utils (Reuse previous definitions)
# ==========================================
# (Dataset, Transforms, compute_miou definitions are assumed to be same as before)
# Shortened for brevity. Ensure NYUv2Dataset and get_transforms are defined.

class NYUv2Dataset(Dataset):
    def __init__(self, root_dir, split='train', img_size=512, transform=None):
        self.root_dir = root_dir
        self.img_size = img_size
        self.transform = transform
        src_split = 'train' if split in ['train', 'val'] else 'test'
        self.img_dir = os.path.join(root_dir, src_split, 'image')
        self.depth_dir = os.path.join(root_dir, src_split, 'depth')
        self.label_dir = os.path.join(root_dir, src_split, 'label')
        self.ids = sorted([os.path.splitext(f)[0] for f in os.listdir(self.img_dir) if f.endswith('.png')])

    def __len__(self):
        return len(self.ids)

    def compute_edge_map(self, img_array):
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        mag = np.sqrt(sobelx**2 + sobely**2)
        return cv2.normalize(mag, None, 0, 1, cv2.NORM_MINMAX).astype(np.float32)

    def __getitem__(self, idx):
        id_ = self.ids[idx]
        img_path = os.path.join(self.img_dir, f"{id_}.png")
        depth_path = os.path.join(self.depth_dir, f"{id_}.png")
        label_path = os.path.join(self.label_dir, f"{id_}.png")

        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
        if depth is None:
            depth = np.zeros(image.shape[:2], np.float32)
        depth = (depth.astype(np.float32) / depth.max()) if depth.max() > 0 else depth.astype(np.float32)

        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        mask_valid = (label >= 0) & (label < Config.CLASSES)
        label[~mask_valid] = 255

        edge = self.compute_edge_map(image)
        depth = depth * 255.0
        edge = edge * 255.0
        combined = np.dstack([image, depth, edge])

        if self.transform:
            t = self.transform(image=combined, mask=label)
            return t['image'], t['mask'].long()
        else:
            t = A.Resize(Config.IMG_SIZE, Config.IMG_SIZE)(image=combined, mask=label)
            image_t = torch.from_numpy(t['image'].transpose(2, 0, 1).astype(np.float32) / 255.0)
            return image_t, torch.from_numpy(t['mask']).long()

def get_transforms(phase='train'):
    if phase == 'train':
        return A.Compose([
            A.Resize(Config.IMG_SIZE, Config.IMG_SIZE),
            A.HorizontalFlip(p=0.5),
            A.Affine(scale=(0.8, 1.2), translate_percent=(0.1, 0.1), rotate=(-15, 15), p=0.5),
            A.Normalize(mean=[0.485, 0.456, 0.406, 0.5, 0.5], std=[0.229, 0.224, 0.225, 0.5, 0.5], max_pixel_value=255.0),
            ToTensorV2()
        ])
    else:
        return A.Compose([
            A.Resize(Config.IMG_SIZE, Config.IMG_SIZE),
            A.Normalize(mean=[0.485, 0.456, 0.406, 0.5, 0.5], std=[0.229, 0.224, 0.225, 0.5, 0.5], max_pixel_value=255.0),
            ToTensorV2()
        ])

def compute_miou(pred_mask, label_mask, num_classes):
    ious = []
    pred_mask = pred_mask.view(-1)
    label_mask = label_mask.view(-1)
    valid_mask = label_mask != 255
    pred_mask = pred_mask[valid_mask]
    label_mask = label_mask[valid_mask]
    for cls in range(num_classes):
        pred_inds = pred_mask == cls
        target_inds = label_mask == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = pred_inds.sum().item() + target_inds.sum().item() - intersection
        if union == 0:
            ious.append(float('nan'))
        else:
            ious.append(float(intersection) / float(union))
    return np.array(ious)

# ==========================================
# 4. Main with Focal Loss
# ==========================================
def main():
    set_seed(Config.SEED)
    print("Start Training Edge-Aware Adapter (Base Frozen)")
    print("Strategy: Direct Edge Injection + Focal Loss")

    full_ds_train = NYUv2Dataset(Config.DATA_ROOT, split='train', transform=get_transforms('train'))
    full_ds_val = NYUv2Dataset(Config.DATA_ROOT, split='train', transform=get_transforms('val'))
    if len(full_ds_train) == 0: return
    n = len(full_ds_train)
    indices = list(range(n))
    random.shuffle(indices)
    n_val = int(n * 0.1)
    train_ds = Subset(full_ds_train, indices[:-n_val])
    val_ds = Subset(full_ds_val, indices[-n_val:])
    train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True)

    base_weight_path = 'best_model_nyuv2_fast.pth'
    if not os.path.exists(base_weight_path):
        print(f"Base model weights not found at {base_weight_path}!")
        return

    model = BookRescueModel(base_weight_path).to(Config.device)
    optimizer = optim.AdamW(model.adapter.parameters(), lr=Config.LR, weight_decay=1e-3)  # Stronger weight decay

    # Focal Loss: curbs overfitting and emphasizes hard samples
    # FocalLoss configured with ignore_index=255
    criterion = smp.losses.FocalLoss(
        mode='multiclass',
        ignore_index=255,
        gamma=2.0  # Focusing parameter: larger values put more weight on hard samples
    )

    best_miou = 0.0

    for epoch in range(Config.EPOCHS):
        print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")
        model.train()
        model.base_model.eval()  # Keep the frozen base model in eval mode while the adapter trains
        train_loss = 0
        for imgs, masks in tqdm(train_loader, desc="Training Adapter"):
            imgs, masks = imgs.to(Config.device), masks.to(Config.device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        model.eval()
        val_loss = 0
        iou_scores = []
        with torch.no_grad():
            for imgs, masks in tqdm(val_loader, desc="Validation"):
                imgs, masks = imgs.to(Config.device), masks.to(Config.device)
                outputs = model(imgs)
                loss = criterion(outputs, masks)
                val_loss += loss.item()
                preds = outputs.argmax(dim=1)
                ious = compute_miou(preds, masks, Config.CLASSES)
                iou_scores.append(ious)

        mean_ious = np.nanmean(np.array(iou_scores), axis=0)
        val_miou = np.nanmean(mean_ious)
        print(f"Train Loss: {train_loss/len(train_loader):.4f} | Val Loss: {val_loss/len(val_loader):.4f}")
        print(f"Val mIoU: {val_miou:.4f}")
        print(f"Books IoU: {mean_ious[1]:.4f} | Furniture: {mean_ious[5]:.4f}")

        if val_miou > best_miou:
            best_miou = val_miou
            torch.save(model.state_dict(), 'best_model_nyuv2_edge_adapter.pth')
            print(">>> Best Model Saved!")

if __name__ == '__main__':
    if os.path.exists(Config.DATA_ROOT):
        torch.cuda.empty_cache()
        main()

Start Training Edge-Aware Adapter (Base Frozen)
Strategy: Direct Edge Injection + Focal Loss

Epoch 1/30


Training Adapter: 100%|██████████| 12/12 [00:13<00:00,  1.15s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.35s/it]


Train Loss: 2.2840 | Val Loss: 2.3703
Val mIoU: 0.5535
Books IoU: 0.1507 | Furniture: 0.5186
>>> Best Model Saved!

Epoch 2/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.03it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.21s/it]


Train Loss: 1.6065 | Val Loss: 1.8247
Val mIoU: 0.5515
Books IoU: 0.1506 | Furniture: 0.5169

Epoch 3/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.06it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.19s/it]


Train Loss: 1.2560 | Val Loss: 1.6601
Val mIoU: 0.5516
Books IoU: 0.1519 | Furniture: 0.5190

Epoch 4/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.02it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.18s/it]


Train Loss: 0.9956 | Val Loss: 1.4888
Val mIoU: 0.5512
Books IoU: 0.1541 | Furniture: 0.5195

Epoch 5/30


Training Adapter: 100%|██████████| 12/12 [00:12<00:00,  1.01s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.18s/it]


Train Loss: 0.8123 | Val Loss: 1.2950
Val mIoU: 0.5516
Books IoU: 0.1548 | Furniture: 0.5194

Epoch 6/30


Training Adapter: 100%|██████████| 12/12 [00:12<00:00,  1.01s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.19s/it]


Train Loss: 0.6540 | Val Loss: 1.1158
Val mIoU: 0.5526
Books IoU: 0.1566 | Furniture: 0.5198

Epoch 7/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.09it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.20s/it]


Train Loss: 0.5392 | Val Loss: 0.9437
Val mIoU: 0.5536
Books IoU: 0.1562 | Furniture: 0.5205
>>> Best Model Saved!

Epoch 8/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.08it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.18s/it]


Train Loss: 0.4636 | Val Loss: 0.8669
Val mIoU: 0.5533
Books IoU: 0.1551 | Furniture: 0.5224

Epoch 9/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.02it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.19s/it]


Train Loss: 0.3826 | Val Loss: 0.7762
Val mIoU: 0.5552
Books IoU: 0.1566 | Furniture: 0.5236
>>> Best Model Saved!

Epoch 10/30


Training Adapter: 100%|██████████| 12/12 [00:12<00:00,  1.02s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.19s/it]


Train Loss: 0.3300 | Val Loss: 0.7177
Val mIoU: 0.5546
Books IoU: 0.1532 | Furniture: 0.5241

Epoch 11/30


Training Adapter: 100%|██████████| 12/12 [00:12<00:00,  1.04s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.20s/it]


Train Loss: 0.2913 | Val Loss: 0.6619
Val mIoU: 0.5563
Books IoU: 0.1588 | Furniture: 0.5245
>>> Best Model Saved!

Epoch 12/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.06it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.19s/it]


Train Loss: 0.2608 | Val Loss: 0.6221
Val mIoU: 0.5566
Books IoU: 0.1585 | Furniture: 0.5238
>>> Best Model Saved!

Epoch 13/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.00it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.18s/it]


Train Loss: 0.2303 | Val Loss: 0.5888
Val mIoU: 0.5562
Books IoU: 0.1598 | Furniture: 0.5247

Epoch 14/30


Training Adapter: 100%|██████████| 12/12 [00:12<00:00,  1.02s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.19s/it]


Train Loss: 0.2177 | Val Loss: 0.5673
Val mIoU: 0.5574
Books IoU: 0.1663 | Furniture: 0.5243
>>> Best Model Saved!

Epoch 15/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.03it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.22s/it]


Train Loss: 0.2049 | Val Loss: 0.5482
Val mIoU: 0.5571
Books IoU: 0.1666 | Furniture: 0.5235

Epoch 16/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.04it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.19s/it]


Train Loss: 0.1889 | Val Loss: 0.5239
Val mIoU: 0.5589
Books IoU: 0.1715 | Furniture: 0.5250
>>> Best Model Saved!

Epoch 17/30


Training Adapter: 100%|██████████| 12/12 [00:12<00:00,  1.00s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.17s/it]


Train Loss: 0.1829 | Val Loss: 0.5020
Val mIoU: 0.5589
Books IoU: 0.1763 | Furniture: 0.5237

Epoch 18/30


Training Adapter: 100%|██████████| 12/12 [00:12<00:00,  1.05s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.20s/it]


Train Loss: 0.1699 | Val Loss: 0.4928
Val mIoU: 0.5583
Books IoU: 0.1752 | Furniture: 0.5222

Epoch 19/30


Training Adapter: 100%|██████████| 12/12 [00:12<00:00,  1.04s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.18s/it]


Train Loss: 0.1679 | Val Loss: 0.4788
Val mIoU: 0.5586
Books IoU: 0.1766 | Furniture: 0.5219

Epoch 20/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.01it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.19s/it]


Train Loss: 0.1594 | Val Loss: 0.4707
Val mIoU: 0.5590
Books IoU: 0.1782 | Furniture: 0.5233
>>> Best Model Saved!

Epoch 21/30


Training Adapter: 100%|██████████| 12/12 [00:12<00:00,  1.04s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.18s/it]


Train Loss: 0.1526 | Val Loss: 0.4708
Val mIoU: 0.5581
Books IoU: 0.1799 | Furniture: 0.5218

Epoch 22/30


Training Adapter: 100%|██████████| 12/12 [00:12<00:00,  1.03s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.22s/it]


Train Loss: 0.1501 | Val Loss: 0.4637
Val mIoU: 0.5582
Books IoU: 0.1716 | Furniture: 0.5237

Epoch 23/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.05it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.18s/it]


Train Loss: 0.1430 | Val Loss: 0.4508
Val mIoU: 0.5591
Books IoU: 0.1837 | Furniture: 0.5214
>>> Best Model Saved!

Epoch 24/30


Training Adapter: 100%|██████████| 12/12 [00:12<00:00,  1.02s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.18s/it]


Train Loss: 0.1412 | Val Loss: 0.4444
Val mIoU: 0.5580
Books IoU: 0.1739 | Furniture: 0.5226

Epoch 25/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.02it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.18s/it]


Train Loss: 0.1389 | Val Loss: 0.4397
Val mIoU: 0.5596
Books IoU: 0.1848 | Furniture: 0.5225
>>> Best Model Saved!

Epoch 26/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.03it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.19s/it]


Train Loss: 0.1372 | Val Loss: 0.4360
Val mIoU: 0.5577
Books IoU: 0.1650 | Furniture: 0.5221

Epoch 27/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.06it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.19s/it]


Train Loss: 0.1341 | Val Loss: 0.4329
Val mIoU: 0.5593
Books IoU: 0.1867 | Furniture: 0.5221

Epoch 28/30


Training Adapter: 100%|██████████| 12/12 [00:12<00:00,  1.01s/it]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.18s/it]


Train Loss: 0.1303 | Val Loss: 0.4354
Val mIoU: 0.5593
Books IoU: 0.1897 | Furniture: 0.5211

Epoch 29/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.04it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.24s/it]


Train Loss: 0.1289 | Val Loss: 0.4291
Val mIoU: 0.5588
Books IoU: 0.1723 | Furniture: 0.5244

Epoch 30/30


Training Adapter: 100%|██████████| 12/12 [00:11<00:00,  1.01it/s]
Validation: 100%|██████████| 2/2 [00:04<00:00,  2.20s/it]


Train Loss: 0.1283 | Val Loss: 0.4159
Val mIoU: 0.5603
Books IoU: 0.1913 | Furniture: 0.5226
>>> Best Model Saved!

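The adapter run above switches the criterion to `smp.losses.FocalLoss` with `gamma=2.0`. As a rough illustration of what the focusing parameter does, the sketch below evaluates the textbook softmax form of focal loss, $-(1-p_t)^\gamma \log p_t$, on two toy pixels. This is an illustration under assumptions, not the notebook's code: `focal_loss_per_pixel` is a hypothetical helper, and the library's internal formulation may differ from this form.

```python
import numpy as np


def focal_loss_per_pixel(probs, targets, gamma=2.0):
    """Focal loss -(1 - p_t)^gamma * log(p_t), evaluated for each pixel.

    probs:   (N, C) class probabilities (each row sums to 1)
    targets: (N,) integer class labels
    """
    probs = np.asarray(probs, dtype=np.float64)
    targets = np.asarray(targets)
    p_t = probs[np.arange(len(targets)), targets]  # probability assigned to the true class
    return -((1.0 - p_t) ** gamma) * np.log(p_t)


probs = np.array([[0.9, 0.1],   # easy pixel: true class predicted with p_t = 0.9
                  [0.1, 0.9]])  # hard pixel: true class predicted with p_t = 0.1
targets = np.array([0, 0])

ce = focal_loss_per_pixel(probs, targets, gamma=0.0)  # gamma = 0 reduces to cross-entropy
fl = focal_loss_per_pixel(probs, targets, gamma=2.0)
print("cross-entropy:   ", ce)
print("focal (gamma=2): ", fl)
print("retained fraction:", fl / ce)
```

With `gamma=2.0`, the easy pixel keeps 1% of its cross-entropy loss while the hard pixel keeps 81%, which is why training focuses on the harder pixels.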

In [None]:
pip install timm segmentation-models-pytorch albumentations

Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl (154 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 154.8/154.8 kB 8.1 MB/s eta 0:00:00
Installing collected packages: segmentation-models-pytorch
Successfully installed segmentation-models-pytorch-0.5.0


In [None]:
"""
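Purpose:
    A corrected version of the "Zooming strategy" code that previously scored 35%, with the data leak
    (validation data mixed into training) completely removed.
    This measures the true performance of the Zooming strategy.

Changes:
    - Split the dataset (train/val) before initializing the Augmentor.
    - Pass the Augmentor only the indices allowed for training, so it never collects from validation images.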
"""

import os
import random
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm
from torch.cuda.amp import GradScaler

# ==========================================
# 0. Optimization Flags
# ==========================================
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True

# ==========================================
# 1. Configuration
# ==========================================
class Config:
    DATA_ROOT = '/content/data'

    ENCODER = 'mit_b4'
    ENCODER_WEIGHTS = 'imagenet'

    CLASSES = 13
    INPUT_CHANNELS = 5

    CROP_SIZE = 320
    IMG_SIZE = 512

    BATCH_SIZE = 16
    EPOCHS = 50

    LR = 6e-5
    WEIGHT_DECAY = 0.01
    SEED = 42
    device = torch.device('cuda')

    PASTE_PROB = 0.7

    # Class weights (emphasize Books, class index 1)
    CLASS_WEIGHTS = [1.0, 10.0, 1.0, 1.2, 1.0, 1.2, 2.0, 1.2, 1.0, 1.2, 3.0, 0.8, 1.0]

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

# ==========================================
# 2. Augmentor (Leakage Fixed)
# ==========================================
class BookDepthAwareAugmentor:
    def __init__(self, dataset, valid_indices=None, prob=0.5):
        """
        valid_indices: list of image indices that may be used for training (None scans every image).
        """
        self.book_bank = []
        self.prob = prob
        self.collect_books(dataset, valid_indices)

    def collect_books(self, dataset, valid_indices):
        print("Collecting 'Book' blobs (Internal - Strict Train Only)...")

        # Restrict the scan to the given (training) indices
        if valid_indices is not None:
            target_indices = valid_indices
            print(f"   Scanning restricted to {len(target_indices)} training images.")
        else:
            # No indices given: scan all data (risky, kept for backward compatibility)
            target_indices = range(len(dataset.ids))
            print(f"   Scanning ALL {len(target_indices)} images (WARNING: Potential Leakage).")

        for idx in tqdm(target_indices, desc="Scanning"):
            id_ = dataset.ids[idx]
            label_path = os.path.join(dataset.label_dir, f"{id_}.png")
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            if label is None: continue
            mask_valid = (label >= 0) & (label < Config.CLASSES); label[~mask_valid] = 255
            book_mask = (label == 1).astype(np.uint8)

            if np.sum(book_mask) == 0: continue

            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(book_mask, connectivity=8)
            if num_labels <= 1: continue

            img_path = os.path.join(dataset.img_dir, f"{id_}.png")
            depth_path = os.path.join(dataset.depth_dir, f"{id_}.png")
            image = cv2.imread(img_path); image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
            if depth is None: depth = np.zeros_like(label, dtype=np.float32)
            if depth.max() > 0: depth = depth.astype(np.float32) / depth.max()
            edge = dataset.compute_edge_map(image)

            for j in range(1, num_labels):
                area = stats[j, cv2.CC_STAT_AREA]
                if area < 300: continue
                x = stats[j, cv2.CC_STAT_LEFT]; y = stats[j, cv2.CC_STAT_TOP]
                w = stats[j, cv2.CC_STAT_WIDTH]; h = stats[j, cv2.CC_STAT_HEIGHT]
                blob_mask = (labels[y:y+h, x:x+w] == j).astype(np.uint8)
                blob_rgb = image[y:y+h, x:x+w]
                blob_depth = depth[y:y+h, x:x+w]
                blob_edge = edge[y:y+h, x:x+w]
                self.book_bank.append({'rgb': blob_rgb, 'depth': blob_depth, 'edge': blob_edge, 'mask': blob_mask, 'mean_depth': np.mean(blob_depth)})
        print(f"Collected {len(self.book_bank)} book blobs from training set.")

    def apply(self, image, label):
        if random.random() > self.prob or len(self.book_bank) == 0: return image, label
        H, W = label.shape
        mask_table = (label == 9).astype(np.uint8); mask_furn = (label == 5).astype(np.uint8); mask_floor = (label == 4).astype(np.uint8)
        target_mask = None
        if np.sum(mask_table) > 1000: target_mask = mask_table
        elif np.sum(mask_furn) > 1000: target_mask = mask_furn
        elif np.sum(mask_floor) > 5000: target_mask = mask_floor
        else: return image, label
        blob = random.choice(self.book_bank)
        ys, xs = np.where(target_mask > 0)
        idx = random.randint(0, len(ys) - 1); y_target, x_target = ys[idx], xs[idx]
        target_d_val = image[y_target, x_target, 3] / 255.0; source_d_val = blob['mean_depth']
        scale = 1.0
        if target_d_val > 0.01 and source_d_val > 0.01: scale = np.clip(source_d_val / target_d_val, 0.3, 1.8)
        blob_h, blob_w = blob['mask'].shape; new_w, new_h = int(blob_w * scale), int(blob_h * scale)
        if new_w <= 0 or new_h <= 0: return image, label
        blob_rgb = cv2.resize(blob['rgb'], (new_w, new_h)); blob_depth = cv2.resize(blob['depth'], (new_w, new_h)); blob_edge = cv2.resize(blob['edge'], (new_w, new_h)); blob_mask = cv2.resize(blob['mask'], (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        if random.random() > 0.3:
            k = random.choice([1, 3])
            blob_rgb = np.rot90(blob_rgb, k); blob_depth = np.rot90(blob_depth, k); blob_edge = np.rot90(blob_edge, k); blob_mask = np.rot90(blob_mask, k)
            new_h, new_w = blob_rgb.shape[:2]
        if new_h >= H or new_w >= W: return image, label
        y_target = min(y_target, H - new_h); x_target = min(x_target, W - new_w)
        mask_3ch = np.stack([blob_mask]*3, axis=2)
        image[y_target:y_target+new_h, x_target:x_target+new_w, :3] = np.where(mask_3ch==1, blob_rgb.astype(np.float32), image[y_target:y_target+new_h, x_target:x_target+new_w, :3])
        target_d = image[y_target, x_target, 3]; blob_d_mean = np.mean(blob_depth) * 255.0; new_d = np.clip((blob_depth * 255.0) + (target_d - blob_d_mean), 0, 255)
        image[y_target:y_target+new_h, x_target:x_target+new_w, 3] = np.where(blob_mask==1, new_d, image[y_target:y_target+new_h, x_target:x_target+new_w, 3])
        image[y_target:y_target+new_h, x_target:x_target+new_w, 4] = np.where(blob_mask==1, blob_edge*255.0, image[y_target:y_target+new_h, x_target:x_target+new_w, 4])
        label[y_target:y_target+new_h, x_target:x_target+new_w] = np.where(blob_mask==1, 1, label[y_target:y_target+new_h, x_target:x_target+new_w])
        return image, label

# ==========================================
# 3. Dataset (Book-Centric Zoom)
# ==========================================
class NYUv2BookZoomDataset(Dataset):
    def __init__(self, root_dir, split='train', img_size=512, transform=None, augmentor=None):
        self.root_dir = root_dir; self.img_size = img_size; self.transform = transform
        self.augmentor = augmentor
        src_split = 'train' if split in ['train', 'val'] else 'test'
        self.img_dir = os.path.join(root_dir, src_split, 'image')
        self.depth_dir = os.path.join(root_dir, src_split, 'depth')
        self.label_dir = os.path.join(root_dir, src_split, 'label')
        if not os.path.exists(self.img_dir): self.ids = []
        else: self.ids = sorted([os.path.splitext(f)[0] for f in os.listdir(self.img_dir) if f.endswith('.png')])

    def __len__(self): return len(self.ids)

    def compute_edge_map(self, img_array):
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3); sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        mag = np.sqrt(sobelx**2 + sobely**2)
        return cv2.normalize(mag, None, 0, 1, cv2.NORM_MINMAX).astype(np.float32)

    def get_book_centric_crop(self, image, label):
        ys, xs = np.where(label == 1)
        if len(ys) > 0:
            idx = random.randint(0, len(ys) - 1)
            cy, cx = ys[idx], xs[idx]
            h, w = image.shape[:2]
            crop_size = Config.CROP_SIZE
            y1 = max(0, cy - crop_size // 2 + random.randint(-50, 50))
            x1 = max(0, cx - crop_size // 2 + random.randint(-50, 50))
            if y1 + crop_size > h: y1 = h - crop_size
            if x1 + crop_size > w: x1 = w - crop_size
            y1 = max(0, y1); x1 = max(0, x1)
            image_crop = image[y1:y1+crop_size, x1:x1+crop_size]
            label_crop = label[y1:y1+crop_size, x1:x1+crop_size]
            return image_crop, label_crop
        else:
            h, w = image.shape[:2]
            crop_size = Config.CROP_SIZE
            y1 = random.randint(0, h - crop_size)
            x1 = random.randint(0, w - crop_size)
            return image[y1:y1+crop_size, x1:x1+crop_size], label[y1:y1+crop_size, x1:x1+crop_size]

    def __getitem__(self, idx):
        id_ = self.ids[idx]
        img_path = os.path.join(self.img_dir, f"{id_}.png")
        depth_path = os.path.join(self.depth_dir, f"{id_}.png")
        label_path = os.path.join(self.label_dir, f"{id_}.png")
        image = cv2.imread(img_path); image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
        if depth is None: depth = np.zeros(image.shape[:2], np.float32)
        if depth.max() > 0: depth = depth.astype(np.float32) / depth.max()
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        mask_valid = (label >= 0) & (label < Config.CLASSES); label[~mask_valid] = 255
        edge = self.compute_edge_map(image)
        depth = depth * 255.0; edge = edge * 255.0
        combined = np.dstack([image, depth, edge]).astype(np.float32)

        if self.augmentor:
            combined, label = self.augmentor.apply(combined, label)

        if self.transform:
            combined_crop, label_crop = self.get_book_centric_crop(combined, label)
            t = self.transform(image=combined_crop, mask=label_crop)
            return t['image'], t['mask'].long()
        else:
            t = A.Resize(Config.IMG_SIZE, Config.IMG_SIZE)(image=combined, mask=label)
            return torch.from_numpy(t['image'].transpose(2,0,1).astype(np.float32)/255.0), torch.from_numpy(t['mask']).long()

def get_transforms(phase='train'):
    if phase == 'train':
        return A.Compose([
            A.Resize(Config.IMG_SIZE, Config.IMG_SIZE),
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
            A.Normalize(mean=[0.485, 0.456, 0.406, 0.5, 0.5], std=[0.229, 0.224, 0.225, 0.5, 0.5], max_pixel_value=255.0),
            ToTensorV2()
        ])
    else:
        return A.Compose([
            A.Resize(Config.IMG_SIZE, Config.IMG_SIZE),
            A.Normalize(mean=[0.485, 0.456, 0.406, 0.5, 0.5], std=[0.229, 0.224, 0.225, 0.5, 0.5], max_pixel_value=255.0),
            ToTensorV2()
        ])

# ==========================================
# 4. Model & Loss
# ==========================================
def get_model():
    print(f"Loading {Config.ENCODER} with 5ch Hack...")
    model = smp.DeepLabV3Plus(encoder_name=Config.ENCODER, encoder_weights=Config.ENCODER_WEIGHTS, in_channels=3, classes=Config.CLASSES, activation=None)
    if hasattr(model.encoder, 'patch_embed1') and hasattr(model.encoder.patch_embed1, 'proj'):
        old = model.encoder.patch_embed1.proj
        new_l = nn.Conv2d(5, old.out_channels, old.kernel_size, old.stride, old.padding, bias=(old.bias is not None))
        with torch.no_grad():
            new_l.weight[:, :3, :, :] = old.weight
            new_l.weight[:, 3:, :, :] = torch.mean(old.weight, dim=1, keepdim=True).repeat(1, 2, 1, 1)
            if old.bias is not None: new_l.bias = old.bias
        model.encoder.patch_embed1.proj = new_l
    return model

class LovaszWeightedLoss(nn.Module):
    def __init__(self, class_weights, device, ignore_index=255):
        super().__init__()
        self.lovasz = smp.losses.LovaszLoss(mode='multiclass', ignore_index=ignore_index)
        weights = torch.tensor(class_weights).float().to(device)
        self.ce = nn.CrossEntropyLoss(weight=weights, ignore_index=ignore_index)

    def forward(self, logits, targets):
        return 0.7 * self.lovasz(logits, targets) + 0.3 * self.ce(logits, targets)

def compute_miou(pred_mask, label_mask, num_classes):
    # Per-class IoU for one batch; a class absent from both prediction and label yields NaN.
    ious = []
    pred_mask = pred_mask.view(-1)
    label_mask = label_mask.view(-1)
    valid_mask = label_mask != 255  # drop ignore-index pixels
    pred_mask = pred_mask[valid_mask]
    label_mask = label_mask[valid_mask]
    for cls in range(num_classes):
        pred_inds = pred_mask == cls
        target_inds = label_mask == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = pred_inds.sum().item() + target_inds.sum().item() - intersection
        if union == 0: ious.append(float('nan'))
        else: ious.append(float(intersection) / float(union))
    return np.array(ious)

# ==========================================
# 5. Training Loop
# ==========================================
def train_one_epoch(model, loader, criterion, optimizer, device, scaler):
    model.train(); epoch_loss = 0
    for images, masks in tqdm(loader, desc="Training (Zoomed)", leave=False):
        images = images.to(device, non_blocking=True); masks = masks.to(device, non_blocking=True)
        optimizer.zero_grad()
        with torch.amp.autocast('cuda'):
            outputs = model(images); loss = criterion(outputs, masks)
        scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
        epoch_loss += loss.item()
    return epoch_loss / len(loader)

def validate(model, loader, criterion, device):
    model.eval(); epoch_loss = 0; iou_scores = []
    with torch.no_grad():
        with torch.amp.autocast('cuda'):
            for images, masks in tqdm(loader, desc="Validation", leave=False):
                images = images.to(device, non_blocking=True); masks = masks.to(device, non_blocking=True)
                outputs = model(images); loss = criterion(outputs, masks); epoch_loss += loss.item()
                preds = outputs.argmax(dim=1); ious = compute_miou(preds, masks, Config.CLASSES); iou_scores.append(ious)
    mean_ious = np.nanmean(np.array(iou_scores), axis=0)
    return epoch_loss / len(loader), np.nanmean(mean_ious), mean_ious

# ==========================================
# 6. Main Execution (Leakage Free)
# ==========================================
def main():
    set_seed(Config.SEED)
    print("🚀 Start Training: Book-Centric Zoom Strategy (Strict Leakage Fix)")

    # 1. Load the full dataset first
    base_ds = NYUv2BookZoomDataset(Config.DATA_ROOT, split='train')
    if len(base_ds) == 0: return

    # 2. Fix: split Train/Val BEFORE initializing the Augmentor
    n = len(base_ds)
    indices = list(range(n))
    random.shuffle(indices)
    n_val = int(n * 0.1)

    # Slice from the front so that n_val == 0 cannot produce an empty training set
    train_indices = indices[:n - n_val]  # training indices
    val_indices = indices[n - n_val:]    # validation indices (must not be touched)

    # 3. Fix: pass only the training indices to the Augmentor,
    # so no copy-paste material is ever collected from validation images
    train_augmentor = BookDepthAwareAugmentor(base_ds, valid_indices=train_indices, prob=Config.PASTE_PROB)

    # 4. Build the datasets (the Augmentor is applied to Train only)
    # Note: the augmentor is invoked inside __getitem__
    full_ds_train = NYUv2BookZoomDataset(Config.DATA_ROOT, split='train', transform=get_transforms('train'), augmentor=train_augmentor)
    # Val has no copy-paste augmentation. Because a transform is set, __getitem__ still applies
    # get_book_centric_crop, so Val mIoU is measured on zoomed crops rather than full images.
    full_ds_val = NYUv2BookZoomDataset(Config.DATA_ROOT, split='train', transform=get_transforms('val'))

    # 5. Build the Subsets from the indices split above
    train_ds = Subset(full_ds_train, train_indices)
    val_ds = Subset(full_ds_val, val_indices)

    print(f"   Train Size: {len(train_ds)}")
    print(f"   Val Size: {len(val_ds)}")

    train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True)

    model = get_model().to(Config.device)
    model = model.to(memory_format=torch.channels_last)
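    try:
        model = torch.compile(model)
        print(">>> Model Compiled.")
    except Exception as e:
        # torch.compile is optional; fall back to eager mode if it is unavailable
        print(f">>> torch.compile skipped: {e}")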

    optimizer = optim.AdamW(model.parameters(), lr=Config.LR, weight_decay=Config.WEIGHT_DECAY)
    criterion = LovaszWeightedLoss(Config.CLASS_WEIGHTS, Config.device, ignore_index=255)
    scaler = torch.amp.GradScaler('cuda')
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.EPOCHS, eta_min=1e-7)

    best_miou = 0.0
    for epoch in range(Config.EPOCHS):
        print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, Config.device, scaler)

        torch.cuda.empty_cache()
        val_loss, val_miou, class_ious = validate(model, val_loader, criterion, Config.device)
        scheduler.step()

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
        print(f"Val mIoU: {val_miou:.4f}")
        print(f"Books IoU: {class_ious[1]:.4f} | Table: {class_ious[9]:.4f}")

        if val_miou > best_miou:
            best_miou = val_miou
            torch.save(model.state_dict(), 'best_model_nyuv2_zoom_strict.pth')
            print(">>> Best Model Saved!")

if __name__ == '__main__':
    if os.path.exists(Config.DATA_ROOT):
        torch.cuda.empty_cache()
        main()

🚀 Start Training: Book-Centric Zoom Strategy (Strict Leakage Fix)
Collecting 'Book' blobs (Internal - Strict Train Only)...
   Scanning restricted to 716 training images.


Scanning: 100%|██████████| 716/716 [00:04<00:00, 163.66it/s]


Collected 306 book blobs from training set.
   Train Size: 716
   Val Size: 79
Loading mit_b4 with 5ch Hack...
>>> Model Compiled.

Epoch 1/50




Train Loss: 1.1994 | Val Loss: 1.1782
Val mIoU: 0.1929
Books IoU: 0.2206 | Table: 0.0239
>>> Best Model Saved!

Epoch 2/50




Train Loss: 0.9575 | Val Loss: 0.9159
Val mIoU: 0.3190
Books IoU: 0.1826 | Table: 0.1391
>>> Best Model Saved!

Epoch 3/50




Train Loss: 0.8411 | Val Loss: 0.8138
Val mIoU: 0.3683
Books IoU: 0.1980 | Table: 0.3214
>>> Best Model Saved!

Epoch 4/50




Train Loss: 0.7694 | Val Loss: 0.7483
Val mIoU: 0.3933
Books IoU: 0.2363 | Table: 0.2543
>>> Best Model Saved!

Epoch 5/50




Train Loss: 0.6982 | Val Loss: 0.7088
Val mIoU: 0.4428
Books IoU: 0.2329 | Table: 0.2586
>>> Best Model Saved!

Epoch 6/50




Train Loss: 0.6538 | Val Loss: 0.6483
Val mIoU: 0.4965
Books IoU: 0.2687 | Table: 0.2915
>>> Best Model Saved!

Epoch 7/50




Train Loss: 0.6263 | Val Loss: 0.6160
Val mIoU: 0.5065
Books IoU: 0.2429 | Table: 0.2259
>>> Best Model Saved!

Epoch 8/50




Train Loss: 0.6027 | Val Loss: 0.6071
Val mIoU: 0.5357
Books IoU: 0.2238 | Table: 0.3230
>>> Best Model Saved!

Epoch 9/50




Train Loss: 0.5768 | Val Loss: 0.5798
Val mIoU: 0.5313
Books IoU: 0.2607 | Table: 0.3035

Epoch 10/50




Train Loss: 0.5721 | Val Loss: 0.5668
Val mIoU: 0.5941
Books IoU: 0.2364 | Table: 0.3097
>>> Best Model Saved!

Epoch 11/50




Train Loss: 0.5435 | Val Loss: 0.6088
Val mIoU: 0.5324
Books IoU: 0.3071 | Table: 0.2842

Epoch 12/50




Train Loss: 0.5280 | Val Loss: 0.5978
Val mIoU: 0.5535
Books IoU: 0.2365 | Table: 0.3267

Epoch 13/50




Train Loss: 0.5257 | Val Loss: 0.5339
Val mIoU: 0.5735
Books IoU: 0.3279 | Table: 0.2660

Epoch 14/50




Train Loss: 0.5088 | Val Loss: 0.5305
Val mIoU: 0.6109
Books IoU: 0.2952 | Table: 0.3512
>>> Best Model Saved!

Epoch 15/50




Train Loss: 0.4972 | Val Loss: 0.5550
Val mIoU: 0.5691
Books IoU: 0.2781 | Table: 0.2493

Epoch 16/50




Train Loss: 0.4863 | Val Loss: 0.5303
Val mIoU: 0.6134
Books IoU: 0.2884 | Table: 0.2729
>>> Best Model Saved!

Epoch 17/50




Train Loss: 0.4789 | Val Loss: 0.5305
Val mIoU: 0.6246
Books IoU: 0.2792 | Table: 0.4148
>>> Best Model Saved!

Epoch 18/50




Train Loss: 0.4688 | Val Loss: 0.5669
Val mIoU: 0.5730
Books IoU: 0.2480 | Table: 0.4043

Epoch 19/50




Train Loss: 0.4609 | Val Loss: 0.5413
Val mIoU: 0.6096
Books IoU: 0.2871 | Table: 0.3160

Epoch 20/50




Train Loss: 0.4571 | Val Loss: 0.5269
Val mIoU: 0.6110
Books IoU: 0.3598 | Table: 0.3480

Epoch 21/50




Train Loss: 0.4372 | Val Loss: 0.5809
Val mIoU: 0.5492
Books IoU: 0.2914 | Table: 0.3429

Epoch 22/50




Train Loss: 0.4377 | Val Loss: 0.5313
Val mIoU: 0.6019
Books IoU: 0.2780 | Table: 0.3219

Epoch 23/50




Train Loss: 0.4289 | Val Loss: 0.5586
Val mIoU: 0.6019
Books IoU: 0.2400 | Table: 0.2695

Epoch 24/50




Train Loss: 0.4399 | Val Loss: 0.5632
Val mIoU: 0.6008
Books IoU: 0.2896 | Table: 0.3370

Epoch 25/50




Train Loss: 0.4304 | Val Loss: 0.5481
Val mIoU: 0.6167
Books IoU: 0.3049 | Table: 0.3437

Epoch 26/50




Train Loss: 0.4248 | Val Loss: 0.5443
Val mIoU: 0.5864
Books IoU: 0.3038 | Table: 0.3391

Epoch 27/50




Train Loss: 0.4209 | Val Loss: 0.5410
Val mIoU: 0.5844
Books IoU: 0.2822 | Table: 0.3020

Epoch 28/50




Train Loss: 0.4243 | Val Loss: 0.5424
Val mIoU: 0.5981
Books IoU: 0.2875 | Table: 0.3830

Epoch 29/50




Train Loss: 0.4187 | Val Loss: 0.5449
Val mIoU: 0.5885
Books IoU: 0.2930 | Table: 0.3689

Epoch 30/50




Train Loss: 0.4156 | Val Loss: 0.4559
Val mIoU: 0.6385
Books IoU: 0.3465 | Table: 0.3049
>>> Best Model Saved!

Epoch 31/50




Train Loss: 0.4086 | Val Loss: 0.5432
Val mIoU: 0.6009
Books IoU: 0.2631 | Table: 0.3279

Epoch 32/50




Train Loss: 0.3996 | Val Loss: 0.5382
Val mIoU: 0.5907
Books IoU: 0.2788 | Table: 0.3981

Epoch 33/50




Train Loss: 0.3982 | Val Loss: 0.5186
Val mIoU: 0.6179
Books IoU: 0.2864 | Table: 0.2898

Epoch 34/50




Train Loss: 0.3927 | Val Loss: 0.5441
Val mIoU: 0.6113
Books IoU: 0.2651 | Table: 0.3377

Epoch 35/50




Train Loss: 0.3973 | Val Loss: 0.5260
Val mIoU: 0.6022
Books IoU: 0.2703 | Table: 0.3179

Epoch 36/50




Train Loss: 0.3896 | Val Loss: 0.5725
Val mIoU: 0.5851
Books IoU: 0.2728 | Table: 0.3317

Epoch 37/50




Train Loss: 0.4019 | Val Loss: 0.5670
Val mIoU: 0.5776
Books IoU: 0.2589 | Table: 0.3141

Epoch 38/50




Train Loss: 0.3923 | Val Loss: 0.5148
Val mIoU: 0.6534
Books IoU: 0.2840 | Table: 0.3462
>>> Best Model Saved!

Epoch 39/50




Train Loss: 0.3925 | Val Loss: 0.5397
Val mIoU: 0.6200
Books IoU: 0.3204 | Table: 0.2640

Epoch 40/50




Train Loss: 0.4019 | Val Loss: 0.5264
Val mIoU: 0.6081
Books IoU: 0.2672 | Table: 0.3262

Epoch 41/50




Train Loss: 0.4049 | Val Loss: 0.5108
Val mIoU: 0.6388
Books IoU: 0.3288 | Table: 0.3366

Epoch 42/50




Train Loss: 0.4038 | Val Loss: 0.5257
Val mIoU: 0.6376
Books IoU: 0.2986 | Table: 0.2804

Epoch 43/50




Train Loss: 0.3935 | Val Loss: 0.5232
Val mIoU: 0.6381
Books IoU: 0.2897 | Table: 0.3140

Epoch 44/50




Train Loss: 0.3818 | Val Loss: 0.5226
Val mIoU: 0.6134
Books IoU: 0.2489 | Table: 0.3450

Epoch 45/50




Train Loss: 0.3894 | Val Loss: 0.5231
Val mIoU: 0.6417
Books IoU: 0.2745 | Table: 0.3301

Epoch 46/50




Train Loss: 0.3828 | Val Loss: 0.5177
Val mIoU: 0.6312
Books IoU: 0.2945 | Table: 0.3616

Epoch 47/50




Train Loss: 0.3976 | Val Loss: 0.5060
Val mIoU: 0.6488
Books IoU: 0.2984 | Table: 0.3571

Epoch 48/50




Train Loss: 0.3881 | Val Loss: 0.5225
Val mIoU: 0.6352
Books IoU: 0.3267 | Table: 0.3120

Epoch 49/50




Train Loss: 0.3962 | Val Loss: 0.5477
Val mIoU: 0.6169
Books IoU: 0.3047 | Table: 0.2888

Epoch 50/50


                                                         

Train Loss: 0.3821 | Val Loss: 0.5264
Val mIoU: 0.6303
Books IoU: 0.2624 | Table: 0.3555

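The validation loop above averages per-batch IoU values, which can differ from an IoU computed over the whole validation set (a class that is rare in one batch gets the same weight as in a batch where it is common). A dataset-level alternative is to accumulate one confusion matrix over all batches and derive TP, FP and FN from it. The following is a minimal NumPy sketch of that idea, not the notebook's own metric; `update_confusion` and `miou_from_confusion` are hypothetical helper names, and predictions are assumed to lie in `[0, num_classes)`.

```python
import numpy as np

def update_confusion(conf, pred, label, num_classes, ignore_index=255):
    """Add one batch to the confusion matrix (rows: label, columns: prediction)."""
    pred = np.asarray(pred).reshape(-1)
    label = np.asarray(label).reshape(-1)
    keep = label != ignore_index  # drop ignore-index pixels
    idx = label[keep].astype(np.int64) * num_classes + pred[keep].astype(np.int64)
    conf += np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return conf

def miou_from_confusion(conf):
    """Return (per-class IoU, mean IoU); classes with TP + FP + FN == 0 are NaN."""
    tp = np.diag(conf).astype(np.float64)
    fp = conf.sum(axis=0) - tp
    fn = conf.sum(axis=1) - tp
    denom = tp + fp + fn
    iou = np.full(conf.shape[0], np.nan)
    np.divide(tp, denom, out=iou, where=denom > 0)
    return iou, np.nanmean(iou)

# Toy example with 3 classes and two "batches"
conf = np.zeros((3, 3), dtype=np.int64)
update_confusion(conf, pred=[0, 1, 1, 2], label=[0, 1, 2, 255], num_classes=3)
update_confusion(conf, pred=[0, 0], label=[0, 1], num_classes=3)
per_class_iou, miou = miou_from_confusion(conf)
print(per_class_iou, miou)
```

In a PyTorch loop the same helpers could be fed `preds.cpu().numpy()` and `masks.cpu().numpy()` per batch, with `miou_from_confusion` called once after the loop.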

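The augmentor's `apply` method rescales each pasted book blob by the ratio of the blob's mean depth to the depth at the paste location, clipped to [0.3, 1.8], and falls back to a scale of 1.0 when either depth is at or below 0.01. The rule in isolation can be sketched as below; `paste_scale` is a hypothetical name introduced only for illustration, and depths are assumed to be normalized to [0, 1] as in the notebook.

```python
def paste_scale(source_depth, target_depth, lo=0.3, hi=1.8, eps=0.01):
    """Scale factor for a blob moved from source_depth to target_depth.

    A blob taken from far away (large source depth) and pasted close to the
    camera (small target depth) should appear larger, hence source / target.
    """
    if source_depth <= eps or target_depth <= eps:
        return 1.0  # depth missing or unreliable: keep the original size
    return min(max(source_depth / target_depth, lo), hi)

print(paste_scale(0.2, 0.4))  # blob moved farther away, shrinks
print(paste_scale(0.6, 0.3))  # blob moved closer, grows but is clipped
print(paste_scale(0.0, 0.4))  # invalid source depth, unchanged
```

Clipping keeps extreme depth ratios from producing blobs that are either unrecognizably small or larger than the scene they are pasted into.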

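The leakage fix in this cell comes down to ordering: split the indices first, then build any augmentation bank from the training indices only. A standard-library sketch of that ordering follows, assuming 795 images and a 10% validation fraction as in the log above; `split_indices` and `toy_has_book` are hypothetical stand-ins, not functions from the notebook.

```python
import random

def split_indices(n, val_fraction=0.1, seed=42):
    """Shuffle 0..n-1 with a local RNG and return (train_indices, val_indices)."""
    rng = random.Random(seed)
    indices = list(range(n))
    rng.shuffle(indices)
    n_val = int(n * val_fraction)
    return indices[:n - n_val], indices[n - n_val:]

def toy_has_book(idx):
    # Hypothetical stand-in for "this image contains a book blob"
    return idx % 5 == 0

train_idx, val_idx = split_indices(795)

# The copy-paste bank is built only from training indices, after the split
bank_sources = [i for i in train_idx if toy_has_book(i)]

print(len(train_idx), len(val_idx))
print(set(bank_sources).isdisjoint(val_idx))
```

Using a local `random.Random(seed)` rather than the global RNG keeps the split reproducible even if other code consumes random numbers before it runs.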

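`get_model` widens the first patch-embedding convolution from 3 to 5 input channels: the pretrained RGB filters are copied, and the two new channels (depth and edge) start from the mean of the RGB filters. The weight construction alone can be sketched in NumPy as below; `expand_input_weights` is a hypothetical helper and the tensor shape is illustrative, not taken from the actual encoder.

```python
import numpy as np

def expand_input_weights(old_weight, extra_channels=2):
    """Extend conv weights of shape (out, 3, kH, kW) to (out, 3 + extra, kH, kW).

    The extra input channels are initialized with the per-filter mean over RGB.
    """
    mean_rgb = old_weight.mean(axis=1, keepdims=True)    # (out, 1, kH, kW)
    extra = np.repeat(mean_rgb, extra_channels, axis=1)  # (out, extra, kH, kW)
    return np.concatenate([old_weight, extra], axis=1)

rng = np.random.default_rng(0)
old_w = rng.standard_normal((8, 3, 7, 7)).astype(np.float32)  # illustrative shape
new_w = expand_input_weights(old_w)
print(new_w.shape)
```

Starting the extra channels from the RGB mean keeps their initial responses on the same scale as the pretrained filters, which is one common choice; zero initialization is another.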
In [None]:
import os
import random
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm
from torch.cuda.amp import GradScaler

# ==========================================
# 0. Optimization Flags
# ==========================================
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True

# ==========================================
# 1. Configuration
# ==========================================
class Config:
    DATA_ROOT = '/content/data'
    ENCODER = 'mit_b4'
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = 13
    INPUT_CHANNELS = 5
    CROP_SIZE = 320
    IMG_SIZE = 512
    BATCH_SIZE = 16
    EPOCHS = 50
    LR = 6e-5
    WEIGHT_DECAY = 0.01
    SEED = 42
    device = torch.device('cuda')
    PASTE_PROB = 0.7

    GOLDEN_SAMPLES = [
        {'id': '000066', 'bbox': [407, 205, 89, 142]},
        {'id': '000072', 'bbox': [366, 412, 79, 63]},
        {'id': '000097', 'bbox': [271, 363, 51, 64]},
        {'id': '000105', 'bbox': [159, 216, 69, 20]},
        {'id': '000107', 'bbox': [244, 280, 142, 42]},
        {'id': '000109', 'bbox': [41, 285, 104, 62]},
        {'id': '000177', 'bbox': [77, 284, 69, 19]},
        {'id': '000353', 'bbox': [257, 218, 76, 31]},
    ]

    CLASS_WEIGHTS = [1.0, 10.0, 1.0, 1.2, 1.0, 1.2, 2.0, 1.2, 1.0, 1.2, 3.0, 0.8, 1.0]

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

# ==========================================
# 2. Augmentor (fixed)
# ==========================================
class BookDepthAwareAugmentor:
    def __init__(self, dataset, prob=0.5):
        self.book_bank = []
        self.prob = prob
        self.load_golden_samples(dataset)
        self.collect_books(dataset)

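    def load_golden_samples(self, dataset):
        # Golden samples are read by ID from the dataset directories, independently of dataset.ids;
        # keep these IDs out of the validation split to avoid leakage.
        print("💎 Loading Golden Samples...")
        count = 0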
        for item in Config.GOLDEN_SAMPLES:
            id_ = item['id']
            x, y, w, h = item['bbox']

            img_path = os.path.join(dataset.img_dir, f"{id_}.png")
            depth_path = os.path.join(dataset.depth_dir, f"{id_}.png")
            label_path = os.path.join(dataset.label_dir, f"{id_}.png")

            if not os.path.exists(img_path): continue

            image = cv2.imread(img_path); image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
            if depth is None: depth = np.zeros(image.shape[:2], np.float32)
            if depth.max() > 0: depth = depth.astype(np.float32) / depth.max()
            edge = dataset.compute_edge_map(image)
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            if label is None: continue  # skip samples whose label file is missing or unreadable

            roi_rgb = image[y:y+h, x:x+w]
            roi_depth = depth[y:y+h, x:x+w]
            roi_edge = edge[y:y+h, x:x+w]
            roi_label = label[y:y+h, x:x+w]
            blob_mask = (roi_label == 1).astype(np.uint8)

            if np.sum(blob_mask) == 0: continue
            self.book_bank.append({'rgb': roi_rgb, 'depth': roi_depth, 'edge': roi_edge, 'mask': blob_mask, 'mean_depth': np.mean(roi_depth)})
            count += 1

    def collect_books(self, dataset):
        print("Collecting 'Auto' book blobs...")
        unique_scan_ids = sorted(list(set(dataset.ids)))
        count = 0
        for id_ in tqdm(unique_scan_ids, desc="Scanning"):
            label_path = os.path.join(dataset.label_dir, f"{id_}.png")
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            if label is None or not np.any(label == 1): continue

            book_mask = (label == 1).astype(np.uint8)
            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(book_mask, connectivity=8)
            if num_labels <= 1: continue

            img_path = os.path.join(dataset.img_dir, f"{id_}.png")
            depth_path = os.path.join(dataset.depth_dir, f"{id_}.png")
            image = cv2.imread(img_path); image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
            if depth is None: depth = np.zeros_like(label, dtype=np.float32)
            if depth.max() > 0: depth = depth.astype(np.float32) / depth.max()
            edge = dataset.compute_edge_map(image)

            for j in range(1, num_labels):
                area = stats[j, cv2.CC_STAT_AREA]
                if area < 300: continue
                x = stats[j, cv2.CC_STAT_LEFT]; y = stats[j, cv2.CC_STAT_TOP]
                w = stats[j, cv2.CC_STAT_WIDTH]; h = stats[j, cv2.CC_STAT_HEIGHT]

                aspect_ratio = float(w) / h
                if aspect_ratio > 10.0 or aspect_ratio < 0.1: continue
                rect_area = w * h
                extent = float(area) / rect_area
                if extent < 0.5: continue

                blob_mask = (labels[y:y+h, x:x+w] == j).astype(np.uint8)
                self.book_bank.append({
                    'rgb': image[y:y+h, x:x+w], 'depth': depth[y:y+h, x:x+w], 'edge': edge[y:y+h, x:x+w],
                    'mask': blob_mask, 'mean_depth': np.mean(depth[y:y+h, x:x+w])
                })
                count += 1
        print(f"✅ Collected {count} Auto Samples. Total Bank: {len(self.book_bank)}")

    def apply(self, image, label):
        if random.random() > self.prob or len(self.book_bank) == 0: return image, label
        H, W = label.shape
        mask_target = ((label == 9) | (label == 5) | (label == 4)).astype(np.uint8)
        ys, xs = np.where(mask_target > 0)
        if len(ys) == 0: return image, label

        idx = random.randint(0, len(ys) - 1)
        y_t, x_t = ys[idx], xs[idx]
        blob = random.choice(self.book_bank)

        target_d = image[y_t, x_t, 3] / 255.0
        source_d = blob['mean_depth']

        scale = 1.0
        if target_d > 0.01 and source_d > 0.01: scale = np.clip(source_d / target_d, 0.3, 1.8)

        blob_h, blob_w = blob['mask'].shape
        new_w, new_h = int(blob_w * scale), int(blob_h * scale)

        # Preliminary check (before rotation)
        if new_w <= 0 or new_h <= 0: return image, label

        # Resize
        blob_rgb = cv2.resize(blob['rgb'], (new_w, new_h))
        blob_depth = cv2.resize(blob['depth'], (new_w, new_h))
        blob_edge = cv2.resize(blob['edge'], (new_w, new_h))
        blob_mask = cv2.resize(blob['mask'], (new_w, new_h), interpolation=cv2.INTER_NEAREST)

        # Rotation
        k = random.choice([0, 1, 2, 3])
        if k > 0:
            blob_rgb = np.rot90(blob_rgb, k); blob_depth = np.rot90(blob_depth, k)
            blob_edge = np.rot90(blob_edge, k); blob_mask = np.rot90(blob_mask, k)

        # Important: re-check using the final size after rotation
        new_h, new_w = blob_rgb.shape[:2]
        if new_h >= H or new_w >= W: return image, label  # abort if the blob does not fit inside the image

        y_t = min(y_t, H - new_h); x_t = min(x_t, W - new_w)

        # Occlusion
        roi_depth = image[y_t:y_t+new_h, x_t:x_t+new_w, 3]

        # Safety check: abort if the slice shape does not match (guards against rounding)
        if roi_depth.shape[:2] != (new_h, new_w): return image, label

        new_d_map = (blob_depth * 255.0) + (image[y_t, x_t, 3] - (np.mean(blob_depth) * 255.0))
        is_in_front = new_d_map < (roi_depth + 5.0)
        final_mask = (blob_mask == 1) & is_in_front

        # Paste
        mask_3 = np.stack([final_mask]*3, axis=2)
        image[y_t:y_t+new_h, x_t:x_t+new_w, :3] = np.where(mask_3, blob_rgb, image[y_t:y_t+new_h, x_t:x_t+new_w, :3])
        image[y_t:y_t+new_h, x_t:x_t+new_w, 3] = np.where(final_mask, np.clip(new_d_map,0,255), image[y_t:y_t+new_h, x_t:x_t+new_w, 3])
        image[y_t:y_t+new_h, x_t:x_t+new_w, 4] = np.where(final_mask, blob_edge*255.0, image[y_t:y_t+new_h, x_t:x_t+new_w, 4])
        label[y_t:y_t+new_h, x_t:x_t+new_w] = np.where(final_mask, 1, label[y_t:y_t+new_h, x_t:x_t+new_w])

        return image, label

# ==========================================
# 3. Dataset (Modified for Safe Split)
# ==========================================
def get_transforms(phase='train'):
    if phase == 'train':
        color_aug = A.Compose([
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        ])
        geo_aug = A.Compose([
            A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.3),
            A.Perspective(scale=(0.05, 0.1), p=0.3),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=20, p=0.5),
            A.HorizontalFlip(p=0.5),
            A.Resize(Config.IMG_SIZE, Config.IMG_SIZE),
            A.Normalize(mean=[0.485, 0.456, 0.406, 0.5, 0.5], std=[0.229, 0.224, 0.225, 0.5, 0.5], max_pixel_value=255.0),
            ToTensorV2()
        ])
        return {'color': color_aug, 'geo': geo_aug}
    else:
        return A.Compose([
            A.Resize(Config.IMG_SIZE, Config.IMG_SIZE),
            A.Normalize(mean=[0.485, 0.456, 0.406, 0.5, 0.5], std=[0.229, 0.224, 0.225, 0.5, 0.5], max_pixel_value=255.0),
            ToTensorV2()
        ])

class NYUv2BookZoomDataset(Dataset):
    def __init__(self, root_dir, whitelist_ids=None, is_train=False, img_size=512, transform=None, augmentor=None):
        self.root_dir = root_dir; self.img_size = img_size; self.transform = transform
        self.augmentor = augmentor
        # Supports a Train/Val/Test layout (fixed to the 'train' folder here, which is split further into train/val)
        self.img_dir = os.path.join(root_dir, 'train', 'image')
        self.depth_dir = os.path.join(root_dir, 'train', 'depth')
        self.label_dir = os.path.join(root_dir, 'train', 'label')

        # Key change: use only the ID list passed in from outside (whitelist_ids)
        if whitelist_ids is not None:
            self.ids = whitelist_ids
        else:
            self.ids = sorted([os.path.splitext(f)[0] for f in os.listdir(self.img_dir) if f.endswith('.png')])

        # Oversample only in training mode, and only images that contain books
        if is_train:
            print(f"⚖️ Applying Oversampling to {len(self.ids)} images...")
            expanded_ids = []
            for id_ in self.ids:
                label_path = os.path.join(self.label_dir, f"{id_}.png")
                label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
                if label is not None and np.any(label == 1):
                    expanded_ids.extend([id_] * 3) # 3x
                else:
                    expanded_ids.append(id_)
            self.ids = expanded_ids
            print(f"   -> Expanded to {len(self.ids)} images.")

    def __len__(self): return len(self.ids)

    def compute_edge_map(self, img_array):
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3); sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        return cv2.normalize(np.sqrt(sobelx**2 + sobely**2), None, 0, 1, cv2.NORM_MINMAX).astype(np.float32)

    def get_book_centric_crop(self, image, label):
        ys, xs = np.where(label == 1)
        if len(ys) > 0:
            idx = random.randint(0, len(ys) - 1); cy, cx = ys[idx], xs[idx]
            h, w = image.shape[:2]; cs = Config.CROP_SIZE
            y1 = max(0, min(h-cs, cy - cs // 2 + random.randint(-50, 50)))
            x1 = max(0, min(w-cs, cx - cs // 2 + random.randint(-50, 50)))
            return image[y1:y1+cs, x1:x1+cs], label[y1:y1+cs, x1:x1+cs]
        else:
            h, w = image.shape[:2]; cs = Config.CROP_SIZE
            y1 = random.randint(0, h - cs); x1 = random.randint(0, w - cs)
            return image[y1:y1+cs, x1:x1+cs], label[y1:y1+cs, x1:x1+cs]

    def __getitem__(self, idx):
        id_ = self.ids[idx]
        img_path = os.path.join(self.img_dir, f"{id_}.png")
        depth_path = os.path.join(self.depth_dir, f"{id_}.png")
        label_path = os.path.join(self.label_dir, f"{id_}.png")
        image = cv2.imread(img_path); image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
        if depth is None: depth = np.zeros(image.shape[:2], np.float32)
        if depth.max() > 0: depth = depth.astype(np.float32) / depth.max()
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        mask_valid = (label >= 0) & (label < Config.CLASSES); label[~mask_valid] = 255
        edge = self.compute_edge_map(image)
        depth = depth * 255.0; edge = edge * 255.0
        combined = np.dstack([image, depth, edge]).astype(np.float32)

        if self.augmentor: combined, label = self.augmentor.apply(combined, label)
        combined_crop, label_crop = self.get_book_centric_crop(combined, label)

        if self.transform:
            if isinstance(self.transform, dict): # Train
                rgb_crop = combined_crop[:, :, :3].astype(np.uint8)
                extra_crop = combined_crop[:, :, 3:]
                rgb_aug = self.transform['color'](image=rgb_crop)['image']
                combined_aug = np.dstack([rgb_aug, extra_crop])
                t_geo = self.transform['geo'](image=combined_aug, mask=label_crop)
                return t_geo['image'], t_geo['mask'].long()
            else: # Val
                t = self.transform(image=combined_crop, mask=label_crop)
                return t['image'], t['mask'].long()
        return torch.from_numpy(combined_crop.transpose(2,0,1)).float(), torch.from_numpy(label_crop).long()

# ==========================================
# 4. Model & Metrics
# ==========================================
def get_model():
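    # Prefer UPerNet; fall back to FPN if UPerNet is unavailable in the installed smp version or fails to build.
    try: model = smp.UPerNet(encoder_name=Config.ENCODER, encoder_weights=Config.ENCODER_WEIGHTS, in_channels=3, classes=Config.CLASSES, activation=None)
    except Exception: model = smp.FPN(encoder_name=Config.ENCODER, encoder_weights=Config.ENCODER_WEIGHTS, in_channels=3, classes=Config.CLASSES, activation=None)
    # Widen the first patch-embedding conv from 3 to 5 input channels (RGB + depth + edge);
    # the two extra channels start from the mean of the pretrained RGB filters.
    if hasattr(model.encoder, 'patch_embed1') and hasattr(model.encoder.patch_embed1, 'proj'):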
        old = model.encoder.patch_embed1.proj
        new_l = nn.Conv2d(5, old.out_channels, old.kernel_size, old.stride, old.padding, bias=(old.bias is not None))
        with torch.no_grad():
            new_l.weight[:, :3, :, :] = old.weight
            new_l.weight[:, 3:, :, :] = torch.mean(old.weight, dim=1, keepdim=True).repeat(1, 2, 1, 1)
            if old.bias is not None: new_l.bias = old.bias
        model.encoder.patch_embed1.proj = new_l
    return model

class LovaszWeightedLoss(nn.Module):
    def __init__(self, class_weights, device, ignore_index=255):
        super().__init__()
        self.lovasz = smp.losses.LovaszLoss(mode='multiclass', ignore_index=ignore_index)
        weights = torch.tensor(class_weights).float().to(device)
        self.ce = nn.CrossEntropyLoss(weight=weights, ignore_index=ignore_index)
    def forward(self, logits, targets): return 0.7 * self.lovasz(logits, targets) + 0.3 * self.ce(logits, targets)

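def compute_metrics_batch(pred_mask, label_mask, num_classes):
    """Per-class IoU, precision and recall for one batch.

    Pixels labelled 255 are ignored; a metric whose denominator is zero is returned as NaN.
    """
    pred = pred_mask.view(-1); label = label_mask.view(-1); valid = label != 255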
    pred = pred[valid]; label = label[valid]
    ious, precs, recs = [], [], []
    for cls in range(num_classes):
        p_inds = (pred == cls); t_inds = (label == cls)
        TP = (p_inds & t_inds).sum().item()
        FP = (p_inds & ~t_inds).sum().item()
        FN = (~p_inds & t_inds).sum().item()
        union = TP + FP + FN
        ious.append(TP/union if union > 0 else float('nan'))
        precs.append(TP/(TP+FP) if (TP+FP) > 0 else float('nan'))
        recs.append(TP/(TP+FN) if (TP+FN) > 0 else float('nan'))
    return np.array(ious), np.array(precs), np.array(recs)

# ==========================================
# 5. Main Execution
# ==========================================
def main():
    set_seed(Config.SEED)
    print("🚀 Start Training: Strict Split -> Train(Oversampled) / Val(Raw)")

    # 1. First collect the IDs of all files
    img_dir_train = os.path.join(Config.DATA_ROOT, 'train', 'image')
    all_ids = sorted([os.path.splitext(f)[0] for f in os.listdir(img_dir_train) if f.endswith('.png')])

    if len(all_ids) == 0: print("Error: No data found."); return

    # 2. Split strictly at the ID level here (prevents leakage)
    # This guarantees that copies of the same image (e.g. A1 in Train, A2 in Val) can never straddle the split
    random.shuffle(all_ids)
    n_val = int(len(all_ids) * 0.1)
    val_ids = all_ids[:n_val]
    train_ids = all_ids[n_val:]

    print(f"Total Images: {len(all_ids)}")
    print(f" -> Train IDs: {len(train_ids)} (Before Oversampling)")
    print(f" -> Val IDs:   {len(val_ids)} (Raw, No Oversampling)")

    # 3. Augmentor (built from Train IDs only)
    # A temporary Dataset instance is created just to resolve the file paths
    temp_ds = NYUv2BookZoomDataset(Config.DATA_ROOT, whitelist_ids=train_ids, is_train=False)
    train_augmentor = BookDepthAwareAugmentor(temp_ds, prob=Config.PASTE_PROB)

    # 4. Build the Datasets (passing the whitelist)
    # Train: is_train=True enables oversampling
    train_ds = NYUv2BookZoomDataset(Config.DATA_ROOT, whitelist_ids=train_ids, is_train=True, transform=get_transforms('train'), augmentor=train_augmentor)
    # Val: is_train=False keeps the raw data
    val_ds = NYUv2BookZoomDataset(Config.DATA_ROOT, whitelist_ids=val_ids, is_train=False, transform=get_transforms('val'))

    print(f"Final Train Size: {len(train_ds)} (Oversampled)")
    print(f"Final Val Size:   {len(val_ds)}")

    train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True)

    model = get_model().to(Config.device)
    model = model.to(memory_format=torch.channels_last)
    # torch.compile is optional: keep the eager model if compilation is unavailable
    try: model = torch.compile(model); print(">>> Model Compiled.")
    except Exception: pass

    optimizer = optim.AdamW(model.parameters(), lr=Config.LR, weight_decay=Config.WEIGHT_DECAY)
    criterion = LovaszWeightedLoss(Config.CLASS_WEIGHTS, Config.device, ignore_index=255)
    scaler = torch.amp.GradScaler('cuda')
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.EPOCHS, eta_min=1e-7)

    best_miou = 0.0
    for epoch in range(Config.EPOCHS):
        print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")

        # Train
        model.train(); train_loss = 0
        for images, masks in tqdm(train_loader, desc="Training", leave=False):
            images = images.to(Config.device, non_blocking=True); masks = masks.to(Config.device, non_blocking=True)
            optimizer.zero_grad()
            with torch.amp.autocast('cuda'):
                outputs = model(images); loss = criterion(outputs, masks)
            scaler.scale(loss).backward(); scaler.step(optimizer); scaler.update()
            train_loss += loss.item()
        train_loss /= len(train_loader)

        # Val
        model.eval(); val_loss = 0; iou_scores, prec_scores, rec_scores = [], [], []
        with torch.no_grad():
            with torch.amp.autocast('cuda'):
                for images, masks in tqdm(val_loader, desc="Validation", leave=False):
                    images = images.to(Config.device, non_blocking=True); masks = masks.to(Config.device, non_blocking=True)
                    outputs = model(images); loss = criterion(outputs, masks); val_loss += loss.item()
                    preds = outputs.argmax(dim=1)
                    bi, bp, br = compute_metrics_batch(preds, masks, Config.CLASSES)
                    iou_scores.append(bi); prec_scores.append(bp); rec_scores.append(br)

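        # Note: the values below are NaN-skipping means of per-batch metrics, not metrics pooled over the whole validation set
        val_miou = np.nanmean(np.nanmean(np.array(iou_scores), axis=0))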
        mean_precs = np.nanmean(np.array(prec_scores), axis=0)
        mean_recs = np.nanmean(np.array(rec_scores), axis=0)
        cls_ious = np.nanmean(np.array(iou_scores), axis=0)
        scheduler.step()

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss/len(val_loader):.4f}")
        print(f"Val mIoU: {val_miou:.4f}")
        print(f"Books [1] -> IoU: {cls_ious[1]:.4f} | Prec: {mean_precs[1]:.4f} | Rec: {mean_recs[1]:.4f}")
        print(f"Table [9] -> IoU: {cls_ious[9]:.4f}")

        if val_miou > best_miou:
            best_miou = val_miou
            # Note: if torch.compile succeeded, the saved keys are expected to carry an '_orig_mod.' prefix
            torch.save(model.state_dict(), 'best_model_nyuv2_ultimate_fixed.pth')
            print(">>> Best Model Saved!")

if __name__ == '__main__':
    if os.path.exists(Config.DATA_ROOT):
        torch.cuda.empty_cache()
        main()
    else:
        print(f"Error: DATA_ROOT not found: {Config.DATA_ROOT}")

🚀 Start Training: Strict Split -> Train(Oversampled) / Val(Raw)
Total Images: 795
 -> Train IDs: 716 (Before Oversampling)
 -> Val IDs:   79 (Raw, No Oversampling)
💎 Loading Golden Samples...
Collecting 'Auto' book blobs...


Scanning: 100%|██████████| 716/716 [00:04<00:00, 174.42it/s]


✅ Collected 317 Auto Samples. Total Bank: 325
⚖️ Applying Oversampling to 716 images...
   -> Expanded to 972 images.
Final Train Size: 972 (Oversampled)
Final Val Size:   79
>>> Model Compiled.

Epoch 1/50




Train Loss: 0.8541 | Val Loss: 0.6930
Val mIoU: 0.4774
Books [1] -> IoU: 0.2182 | Prec: 0.2529 | Rec: 0.4516
Table [9] -> IoU: 0.4501
>>> Best Model Saved!

Epoch 2/50




Train Loss: 0.5929 | Val Loss: 0.6214
Val mIoU: 0.5156
Books [1] -> IoU: 0.2117 | Prec: 0.2535 | Rec: 0.3786
Table [9] -> IoU: 0.4510
>>> Best Model Saved!

Epoch 3/50




Train Loss: 0.5335 | Val Loss: 0.6155
Val mIoU: 0.5164
Books [1] -> IoU: 0.2523 | Prec: 0.2703 | Rec: 0.5601
Table [9] -> IoU: 0.5371
>>> Best Model Saved!

Epoch 4/50




Train Loss: 0.4973 | Val Loss: 0.5682
Val mIoU: 0.5386
Books [1] -> IoU: 0.2431 | Prec: 0.2529 | Rec: 0.6343
Table [9] -> IoU: 0.4613
>>> Best Model Saved!

Epoch 5/50




Train Loss: 0.4528 | Val Loss: 0.5359
Val mIoU: 0.5434
Books [1] -> IoU: 0.2106 | Prec: 0.2153 | Rec: 0.6625
Table [9] -> IoU: 0.5080
>>> Best Model Saved!

Epoch 6/50




Train Loss: 0.4292 | Val Loss: 0.5850
Val mIoU: 0.5205
Books [1] -> IoU: 0.2199 | Prec: 0.2294 | Rec: 0.6221
Table [9] -> IoU: 0.5071

Epoch 7/50




Train Loss: 0.4102 | Val Loss: 0.5561
Val mIoU: 0.5604
Books [1] -> IoU: 0.3506 | Prec: 0.3831 | Rec: 0.6609
Table [9] -> IoU: 0.5015
>>> Best Model Saved!

Epoch 8/50




Train Loss: 0.3867 | Val Loss: 0.5246
Val mIoU: 0.5900
Books [1] -> IoU: 0.3018 | Prec: 0.3293 | Rec: 0.7035
Table [9] -> IoU: 0.4556
>>> Best Model Saved!

Epoch 9/50




Train Loss: 0.3752 | Val Loss: 0.5324
Val mIoU: 0.5738
Books [1] -> IoU: 0.2987 | Prec: 0.3364 | Rec: 0.6064
Table [9] -> IoU: 0.6103

Epoch 10/50




Train Loss: 0.3532 | Val Loss: 0.5169
Val mIoU: 0.6000
Books [1] -> IoU: 0.3069 | Prec: 0.3344 | Rec: 0.5456
Table [9] -> IoU: 0.5193
>>> Best Model Saved!

Epoch 11/50




Train Loss: 0.3413 | Val Loss: 0.5193
Val mIoU: 0.5788
Books [1] -> IoU: 0.3333 | Prec: 0.3753 | Rec: 0.7131
Table [9] -> IoU: 0.5689

Epoch 12/50




Train Loss: 0.3314 | Val Loss: 0.5022
Val mIoU: 0.5958
Books [1] -> IoU: 0.3529 | Prec: 0.3828 | Rec: 0.7801
Table [9] -> IoU: 0.6110

Epoch 13/50




Train Loss: 0.3207 | Val Loss: 0.4703
Val mIoU: 0.6232
Books [1] -> IoU: 0.3317 | Prec: 0.3534 | Rec: 0.6366
Table [9] -> IoU: 0.6136
>>> Best Model Saved!

Epoch 14/50




Train Loss: 0.3032 | Val Loss: 0.4781
Val mIoU: 0.6200
Books [1] -> IoU: 0.3457 | Prec: 0.3831 | Rec: 0.6551
Table [9] -> IoU: 0.5482

Epoch 15/50




Train Loss: 0.3185 | Val Loss: 0.5392
Val mIoU: 0.5822
Books [1] -> IoU: 0.3828 | Prec: 0.4162 | Rec: 0.6797
Table [9] -> IoU: 0.5351

Epoch 16/50




Train Loss: 0.3008 | Val Loss: 0.4863
Val mIoU: 0.5920
Books [1] -> IoU: 0.2639 | Prec: 0.2726 | Rec: 0.5079
Table [9] -> IoU: 0.5927

Epoch 17/50




Train Loss: 0.2914 | Val Loss: 0.5092
Val mIoU: 0.6012
Books [1] -> IoU: 0.3708 | Prec: 0.4015 | Rec: 0.8089
Table [9] -> IoU: 0.6249

Epoch 18/50




Train Loss: 0.2766 | Val Loss: 0.5163
Val mIoU: 0.5758
Books [1] -> IoU: 0.3124 | Prec: 0.3408 | Rec: 0.7604
Table [9] -> IoU: 0.5191

Epoch 19/50




Train Loss: 0.2808 | Val Loss: 0.4622
Val mIoU: 0.6374
Books [1] -> IoU: 0.3345 | Prec: 0.3570 | Rec: 0.6880
Table [9] -> IoU: 0.6206
>>> Best Model Saved!

Epoch 20/50




Train Loss: 0.2533 | Val Loss: 0.4455
Val mIoU: 0.6405
Books [1] -> IoU: 0.4107 | Prec: 0.4638 | Rec: 0.7287
Table [9] -> IoU: 0.6010
>>> Best Model Saved!

Epoch 21/50




Train Loss: 0.2574 | Val Loss: 0.4402
Val mIoU: 0.6476
Books [1] -> IoU: 0.4142 | Prec: 0.4643 | Rec: 0.8183
Table [9] -> IoU: 0.6710
>>> Best Model Saved!

Epoch 22/50




Train Loss: 0.2560 | Val Loss: 0.4615
Val mIoU: 0.6371
Books [1] -> IoU: 0.3790 | Prec: 0.4117 | Rec: 0.8685
Table [9] -> IoU: 0.5833

Epoch 23/50




Train Loss: 0.2403 | Val Loss: 0.4910
Val mIoU: 0.6147
Books [1] -> IoU: 0.4289 | Prec: 0.4657 | Rec: 0.8471
Table [9] -> IoU: 0.6098

Epoch 24/50




Train Loss: 0.2353 | Val Loss: 0.4720
Val mIoU: 0.6230
Books [1] -> IoU: 0.3941 | Prec: 0.4336 | Rec: 0.8050
Table [9] -> IoU: 0.6305

Epoch 25/50




Train Loss: 0.2364 | Val Loss: 0.4977
Val mIoU: 0.5979
Books [1] -> IoU: 0.3841 | Prec: 0.4261 | Rec: 0.7559
Table [9] -> IoU: 0.6102

Epoch 26/50




Train Loss: 0.2309 | Val Loss: 0.4793
Val mIoU: 0.6197
Books [1] -> IoU: 0.3952 | Prec: 0.4333 | Rec: 0.8392
Table [9] -> IoU: 0.6037

Epoch 27/50




Train Loss: 0.2242 | Val Loss: 0.5161
Val mIoU: 0.5970
Books [1] -> IoU: 0.4274 | Prec: 0.4678 | Rec: 0.8561
Table [9] -> IoU: 0.6400

Epoch 28/50




Train Loss: 0.2188 | Val Loss: 0.4936
Val mIoU: 0.6436
Books [1] -> IoU: 0.4337 | Prec: 0.4801 | Rec: 0.7639
Table [9] -> IoU: 0.6117

Epoch 29/50




Train Loss: 0.2160 | Val Loss: 0.4584
Val mIoU: 0.6243
Books [1] -> IoU: 0.2681 | Prec: 0.2854 | Rec: 0.6967
Table [9] -> IoU: 0.5730

Epoch 30/50




Train Loss: 0.2110 | Val Loss: 0.4756
Val mIoU: 0.6262
Books [1] -> IoU: 0.3629 | Prec: 0.3875 | Rec: 0.8218
Table [9] -> IoU: 0.6157

Epoch 31/50




Train Loss: 0.2120 | Val Loss: 0.4650
Val mIoU: 0.6557
Books [1] -> IoU: 0.4303 | Prec: 0.4702 | Rec: 0.8740
Table [9] -> IoU: 0.6026
>>> Best Model Saved!

Epoch 32/50




Train Loss: 0.2103 | Val Loss: 0.4758
Val mIoU: 0.6264
Books [1] -> IoU: 0.4383 | Prec: 0.4895 | Rec: 0.7357
Table [9] -> IoU: 0.6303

Epoch 33/50




Train Loss: 0.2089 | Val Loss: 0.4644
Val mIoU: 0.6636
Books [1] -> IoU: 0.4256 | Prec: 0.4722 | Rec: 0.6797
Table [9] -> IoU: 0.6127
>>> Best Model Saved!

Epoch 34/50




Train Loss: 0.1998 | Val Loss: 0.4430
Val mIoU: 0.6327
Books [1] -> IoU: 0.4070 | Prec: 0.4544 | Rec: 0.7803
Table [9] -> IoU: 0.6405

Epoch 35/50




Train Loss: 0.2003 | Val Loss: 0.4429
Val mIoU: 0.6640
Books [1] -> IoU: 0.4395 | Prec: 0.4863 | Rec: 0.7760
Table [9] -> IoU: 0.7056
>>> Best Model Saved!

Epoch 36/50




Train Loss: 0.2018 | Val Loss: 0.4502
Val mIoU: 0.6567
Books [1] -> IoU: 0.4286 | Prec: 0.4710 | Rec: 0.8178
Table [9] -> IoU: 0.6415

Epoch 37/50




Train Loss: 0.1945 | Val Loss: 0.4604
Val mIoU: 0.6534
Books [1] -> IoU: 0.3381 | Prec: 0.4152 | Rec: 0.7364
Table [9] -> IoU: 0.5579

Epoch 38/50




Train Loss: 0.1921 | Val Loss: 0.4297
Val mIoU: 0.6588
Books [1] -> IoU: 0.4230 | Prec: 0.4598 | Rec: 0.8063
Table [9] -> IoU: 0.6562

Epoch 39/50




Train Loss: 0.1916 | Val Loss: 0.4086
Val mIoU: 0.6667
Books [1] -> IoU: 0.4079 | Prec: 0.4548 | Rec: 0.7976
Table [9] -> IoU: 0.6042
>>> Best Model Saved!

Epoch 40/50




Train Loss: 0.1927 | Val Loss: 0.4473
Val mIoU: 0.6386
Books [1] -> IoU: 0.4407 | Prec: 0.4841 | Rec: 0.7861
Table [9] -> IoU: 0.6285

Epoch 41/50




Train Loss: 0.1930 | Val Loss: 0.4079
Val mIoU: 0.6639
Books [1] -> IoU: 0.4149 | Prec: 0.4520 | Rec: 0.8159
Table [9] -> IoU: 0.6337

Epoch 42/50




Train Loss: 0.1881 | Val Loss: 0.4222
Val mIoU: 0.6636
Books [1] -> IoU: 0.4027 | Prec: 0.4419 | Rec: 0.8573
Table [9] -> IoU: 0.6730

Epoch 43/50




Train Loss: 0.1846 | Val Loss: 0.4387
Val mIoU: 0.6604
Books [1] -> IoU: 0.5136 | Prec: 0.5588 | Rec: 0.8676
Table [9] -> IoU: 0.5910

Epoch 44/50




Train Loss: 0.1876 | Val Loss: 0.4391
Val mIoU: 0.6499
Books [1] -> IoU: 0.3926 | Prec: 0.4261 | Rec: 0.8167
Table [9] -> IoU: 0.6254

Epoch 45/50




Train Loss: 0.1918 | Val Loss: 0.4622
Val mIoU: 0.6357
Books [1] -> IoU: 0.3836 | Prec: 0.4211 | Rec: 0.7398
Table [9] -> IoU: 0.5919

Epoch 46/50




Train Loss: 0.1891 | Val Loss: 0.4313
Val mIoU: 0.6862
Books [1] -> IoU: 0.4122 | Prec: 0.4569 | Rec: 0.7643
Table [9] -> IoU: 0.5137
>>> Best Model Saved!

Epoch 47/50




Train Loss: 0.1828 | Val Loss: 0.4795
Val mIoU: 0.6383
Books [1] -> IoU: 0.4070 | Prec: 0.4478 | Rec: 0.8550
Table [9] -> IoU: 0.6222

Epoch 48/50




Train Loss: 0.1798 | Val Loss: 0.4550
Val mIoU: 0.6247
Books [1] -> IoU: 0.3967 | Prec: 0.4377 | Rec: 0.7462
Table [9] -> IoU: 0.6211

Epoch 49/50




Train Loss: 0.1841 | Val Loss: 0.4606
Val mIoU: 0.6549
Books [1] -> IoU: 0.3996 | Prec: 0.4362 | Rec: 0.7838
Table [9] -> IoU: 0.5960

Epoch 50/50


                                                         

Train Loss: 0.1802 | Val Loss: 0.4406
Val mIoU: 0.6669
Books [1] -> IoU: 0.4339 | Prec: 0.4804 | Rec: 0.7531
Table [9] -> IoU: 0.6567
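
The validation loop above calls `compute_metrics_batch` once per batch and then averages the per-batch values with `np.nanmean`. That is not the same quantity as an IoU computed from TP, FP and FN pooled over the whole validation set, which may explain part of the epoch-to-epoch jitter in the log. The following is a minimal sketch of the difference on toy data, in plain Python so it runs without the notebook's dependencies. The helper names `count_tp_fp_fn` and `iou` are hypothetical and not part of the notebook.

```python
def count_tp_fp_fn(pred, label, cls, ignore_index=255):
    """Count TP/FP/FN for one class over flat sequences of predicted and true class ids."""
    tp = fp = fn = 0
    for p, t in zip(pred, label):
        if t == ignore_index:
            continue  # ignored pixels do not contribute to any count
        if p == cls and t == cls:
            tp += 1
        elif p == cls:
            fp += 1
        elif t == cls:
            fn += 1
    return tp, fp, fn


def iou(tp, fp, fn):
    """IoU = TP / (TP + FP + FN); NaN when the union is empty."""
    union = tp + fp + fn
    return tp / union if union > 0 else float("nan")


# Two toy "batches" of flattened pixels, evaluated for class 1 (Books).
batches = [
    ([1, 1, 0, 0], [1, 0, 0, 0]),              # TP=1, FP=1, FN=0
    ([1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]),  # TP=4, FP=0, FN=2
]

# (1) Mean of per-batch IoUs, which is what the validation loop reports.
per_batch = [iou(*count_tp_fp_fn(p, t, cls=1)) for p, t in batches]
mean_of_batches = sum(per_batch) / len(per_batch)

# (2) Pooled IoU: sum the counts over all batches first, then divide once.
totals = [sum(c) for c in zip(*(count_tp_fp_fn(p, t, cls=1) for p, t in batches))]
pooled = iou(*totals)

print(f"mean of per-batch IoU: {mean_of_batches:.4f}")
print(f"pooled IoU:            {pooled:.4f}")
```

In the notebook the same idea would amount to accumulating per-class TP, FP and FN across the validation loader and dividing after the loop. Whether the competition score is pooled this way is an assumption to check against the evaluation code.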



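The log shows `>>> Model Compiled.`, so `best_model_nyuv2_ultimate_fixed.pth` is written from a `torch.compile`-wrapped model. In PyTorch 2.x the wrapper keeps the original module under `_orig_mod`, so the saved keys are expected to carry an `_orig_mod.` prefix that an uncompiled model would reject in `load_state_dict`; this is worth verifying on the installed version. A minimal sketch of stripping that prefix follows. `strip_compile_prefix` is a hypothetical helper, not part of the notebook, and it works on any mapping, so it can be tried without PyTorch.

```python
def strip_compile_prefix(state_dict, prefix="_orig_mod."):
    """Return a copy of state_dict with a leading torch.compile prefix removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Toy stand-in for a checkpoint; in a real checkpoint the values would be tensors.
saved = {"_orig_mod.encoder.patch_embed1.proj.weight": 1, "_orig_mod.decoder.bias": 2}
clean = strip_compile_prefix(saved)
print(sorted(clean))
```

With a real checkpoint, the helper would be applied to the dictionary returned by `torch.load(...)` before it is passed to `load_state_dict` on a freshly built, uncompiled model.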

In [None]:
"""
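Purpose:
    Deep Learning Basics course, final assignment (NYUv2 Semantic Segmentation)
    Training script that automatically extracts the images with the most Book-class false positives and misses, to analyze the model's weaknesses in depth

Pipeline and features:
    1. Model: mit_b4 (SegFormer) + UPerNet (5-channel input)
    2. Data strategy: 5x oversampling + Hybrid Paste + Heavy Augmentation
    3. Weakness analysis (New!):
       - Worst Sample Mining: automatically selects and saves the validation images with the most Book-class errors (FP+FN).
       - Confusion Matrix: computes the class confusion matrix and saves it as a heatmap, showing what books tend to be confused with.
       - Full IoU Report: prints IoU, Precision and Recall for all 13 classes.
       - Interval Logging: runs the detailed analysis every 10 epochs.
    4. Output management:
       - Collects logs in a folder named with a JST timestamp.
       - Zips the folder automatically when training finishes.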
"""

import os
import random
import time
import datetime
import pytz
import shutil
import numpy as np
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm
from torch.cuda.amp import GradScaler
from sklearn.metrics import confusion_matrix

# ==========================================
# 0. Optimization Flags
# ==========================================
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True

# ==========================================
# 1. Configuration
# ==========================================
def get_jst_time_str():
    return datetime.datetime.now(pytz.timezone('Asia/Tokyo')).strftime('%Y%m%d_%H%M%S')

class Config:
    DATA_ROOT = '/content/data'

    # Model
    ENCODER = 'mit_b4'
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = 13
    INPUT_CHANNELS = 5

    # Training
    CROP_SIZE = 320
    IMG_SIZE = 512
    BATCH_SIZE = 16
    EPOCHS = 50
    LR = 6e-5
    WEIGHT_DECAY = 0.01
    SEED = 42
    device = torch.device('cuda')

    # Augmentation
    PASTE_PROB = 0.7
    OVERSAMPLE_FACTOR = 5

    # Analysis Settings (New!)
    LOG_TIMESTAMP = get_jst_time_str()
    DEBUG_DIR = f'/content/logs_{LOG_TIMESTAMP}' # folder named with the JST timestamp
    SAVE_INTERVAL = 10     # save detailed logs every 10 epochs
    NUM_WORST_SAMPLES = 5  # number of worst-error images to save

    GOLDEN_SAMPLES = [
        {'id': '000066', 'bbox': [407, 205, 89, 142]},
        {'id': '000072', 'bbox': [366, 412, 79, 63]},
        {'id': '000097', 'bbox': [271, 363, 51, 64]},
        {'id': '000105', 'bbox': [159, 216, 69, 20]},
        {'id': '000107', 'bbox': [244, 280, 142, 42]},
        {'id': '000109', 'bbox': [41, 285, 104, 62]},
        {'id': '000177', 'bbox': [77, 284, 69, 19]},
        {'id': '000353', 'bbox': [257, 218, 76, 31]},
    ]

    CLASS_NAMES = [
        "Bed", "Books", "Ceiling", "Chair", "Floor", "Furniture",
        "Objects", "Picture", "Sofa", "Table", "TV", "Wall", "Window"
    ]
    # Class 1 (Books) is heavily weighted
    CLASS_WEIGHTS = [1.0, 10.0, 1.0, 1.2, 1.0, 1.2, 2.0, 1.2, 1.0, 1.2, 3.0, 0.8, 1.0]

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

# ==========================================
# 2. Augmentor (Safe Rotation)
# ==========================================
class BookDepthAwareAugmentor:
    def __init__(self, dataset, prob=0.5):
        self.book_bank = []
        self.prob = prob
        self.load_golden_samples(dataset)
        self.collect_books(dataset)

    def load_golden_samples(self, dataset):
        print("💎 Loading Golden Samples...")
        for item in Config.GOLDEN_SAMPLES:
            id_ = item['id']; x, y, w, h = item['bbox']
            img_path = os.path.join(dataset.img_dir, f"{id_}.png"); depth_path = os.path.join(dataset.depth_dir, f"{id_}.png"); label_path = os.path.join(dataset.label_dir, f"{id_}.png")
            if not os.path.exists(img_path): continue
            image = cv2.imread(img_path); image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
            if depth is None: depth = np.zeros(image.shape[:2], np.float32)
            if depth.max() > 0: depth = depth.astype(np.float32) / depth.max()
            edge = dataset.compute_edge_map(image)
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            roi_rgb = image[y:y+h, x:x+w]; roi_depth = depth[y:y+h, x:x+w]; roi_edge = edge[y:y+h, x:x+w]; roi_label = label[y:y+h, x:x+w]
            blob_mask = (roi_label == 1).astype(np.uint8)
            if np.sum(blob_mask) == 0: continue
            self.book_bank.append({'rgb': roi_rgb, 'depth': roi_depth, 'edge': roi_edge, 'mask': blob_mask, 'mean_depth': np.mean(roi_depth)})

    def collect_books(self, dataset):
        print("Collecting 'Auto' book blobs...")
        unique_scan_ids = sorted(list(set(dataset.ids)))
        for id_ in tqdm(unique_scan_ids, desc="Scanning"):
            label_path = os.path.join(dataset.label_dir, f"{id_}.png")
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            if label is None or not np.any(label == 1): continue
            book_mask = (label == 1).astype(np.uint8)
            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(book_mask, connectivity=8)
            if num_labels <= 1: continue
            img_path = os.path.join(dataset.img_dir, f"{id_}.png"); depth_path = os.path.join(dataset.depth_dir, f"{id_}.png")
            image = cv2.imread(img_path); image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
            if depth is None: depth = np.zeros_like(label, dtype=np.float32)
            if depth.max() > 0: depth = depth.astype(np.float32) / depth.max()
            edge = dataset.compute_edge_map(image)
            for j in range(1, num_labels):
                area = stats[j, cv2.CC_STAT_AREA]
                if area < 300: continue
                x = stats[j, cv2.CC_STAT_LEFT]; y = stats[j, cv2.CC_STAT_TOP]; w = stats[j, cv2.CC_STAT_WIDTH]; h = stats[j, cv2.CC_STAT_HEIGHT]
                if w/h > 10.0 or w/h < 0.1 or area/(w*h) < 0.5: continue
                blob_mask = (labels[y:y+h, x:x+w] == j).astype(np.uint8)
                self.book_bank.append({'rgb': image[y:y+h, x:x+w], 'depth': depth[y:y+h, x:x+w], 'edge': edge[y:y+h, x:x+w], 'mask': blob_mask, 'mean_depth': np.mean(depth[y:y+h, x:x+w])})

    def apply(self, image, label):
        if random.random() > self.prob or len(self.book_bank) == 0: return image, label
        H, W = label.shape
        mask_target = ((label == 9) | (label == 5) | (label == 4)).astype(np.uint8)
        ys, xs = np.where(mask_target > 0)
        if len(ys) == 0: return image, label

        idx = random.randint(0, len(ys) - 1); y_t, x_t = ys[idx], xs[idx]
        blob = random.choice(self.book_bank)
        target_d = image[y_t, x_t, 3] / 255.0; source_d = blob['mean_depth']

        scale = np.clip(source_d / target_d, 0.3, 1.8) if target_d > 0.01 and source_d > 0.01 else 1.0
        blob_h, blob_w = blob['mask'].shape; new_w, new_h = int(blob_w * scale), int(blob_h * scale)
        if new_w <= 0 or new_h <= 0: return image, label

        blob_rgb = cv2.resize(blob['rgb'], (new_w, new_h)); blob_depth = cv2.resize(blob['depth'], (new_w, new_h))
        blob_edge = cv2.resize(blob['edge'], (new_w, new_h)); blob_mask = cv2.resize(blob['mask'], (new_w, new_h), interpolation=cv2.INTER_NEAREST)

        k = random.choice([0, 1, 2, 3])
        if k > 0:
            blob_rgb = np.rot90(blob_rgb, k); blob_depth = np.rot90(blob_depth, k)
            blob_edge = np.rot90(blob_edge, k); blob_mask = np.rot90(blob_mask, k)

        new_h, new_w = blob_rgb.shape[:2]
        if new_h >= H or new_w >= W: return image, label
        y_t = min(y_t, H - new_h); x_t = min(x_t, W - new_w)

        roi_depth = image[y_t:y_t+new_h, x_t:x_t+new_w, 3]
        if roi_depth.shape[:2] != (new_h, new_w): return image, label

        new_d_map = (blob_depth * 255.0) + (image[y_t, x_t, 3] - (np.mean(blob_depth) * 255.0))
        is_in_front = new_d_map < (roi_depth + 5.0)
        final_mask = (blob_mask == 1) & is_in_front

        mask_3 = np.stack([final_mask]*3, axis=2)
        image[y_t:y_t+new_h, x_t:x_t+new_w, :3] = np.where(mask_3, blob_rgb, image[y_t:y_t+new_h, x_t:x_t+new_w, :3])
        image[y_t:y_t+new_h, x_t:x_t+new_w, 3] = np.where(final_mask, np.clip(new_d_map,0,255), image[y_t:y_t+new_h, x_t:x_t+new_w, 3])
        image[y_t:y_t+new_h, x_t:x_t+new_w, 4] = np.where(final_mask, blob_edge*255.0, image[y_t:y_t+new_h, x_t:x_t+new_w, 4])
        label[y_t:y_t+new_h, x_t:x_t+new_w] = np.where(final_mask, 1, label[y_t:y_t+new_h, x_t:x_t+new_w])

        return image, label

# ==========================================
# 3. Dataset
# ==========================================
def get_transforms(phase='train'):
    if phase == 'train':
        return {
            'color': A.Compose([
                A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
                A.RandomBrightnessContrast(p=0.5),
            ]),
            'geo': A.Compose([
                A.GridDistortion(p=0.3),
                A.Perspective(p=0.3),
                A.ShiftScaleRotate(p=0.5),
                A.HorizontalFlip(p=0.5),
                A.Resize(Config.IMG_SIZE, Config.IMG_SIZE),
                A.Normalize(mean=[0.485, 0.456, 0.406, 0.5, 0.5], std=[0.229, 0.224, 0.225, 0.5, 0.5], max_pixel_value=255.0),
                ToTensorV2(),
            ]),
        }
    else:
        return A.Compose([
            A.Resize(Config.IMG_SIZE, Config.IMG_SIZE),
            A.Normalize(mean=[0.485, 0.456, 0.406, 0.5, 0.5], std=[0.229, 0.224, 0.225, 0.5, 0.5], max_pixel_value=255.0),
            ToTensorV2(),
        ])

class NYUv2BookZoomDataset(Dataset):
    def __init__(self, root_dir, whitelist_ids=None, is_train=False, img_size=512, transform=None, augmentor=None):
        self.root_dir = root_dir; self.img_size = img_size; self.transform = transform; self.augmentor = augmentor
        self.img_dir = os.path.join(root_dir, 'train', 'image'); self.depth_dir = os.path.join(root_dir, 'train', 'depth'); self.label_dir = os.path.join(root_dir, 'train', 'label')
        self.ids = whitelist_ids if whitelist_ids is not None else sorted([os.path.splitext(f)[0] for f in os.listdir(self.img_dir) if f.endswith('.png')])

        if is_train:
            print(f"⚖️ Applying Oversampling ({Config.OVERSAMPLE_FACTOR}x) to {len(self.ids)} images...")
            expanded = []
            for id_ in self.ids:
                label = cv2.imread(os.path.join(self.label_dir, f"{id_}.png"), cv2.IMREAD_GRAYSCALE)
                if label is not None and np.any(label == 1):
                    expanded.extend([id_] * Config.OVERSAMPLE_FACTOR)
                else:
                    expanded.extend([id_])
            self.ids = expanded
            print(f"   -> Expanded to {len(self.ids)} images.")

    def __len__(self): return len(self.ids)
    def compute_edge_map(self, img_array):
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY); sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3); sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        return cv2.normalize(np.sqrt(sobelx**2 + sobely**2), None, 0, 1, cv2.NORM_MINMAX).astype(np.float32)
    def get_book_centric_crop(self, image, label):
        ys, xs = np.where(label == 1); cs = Config.CROP_SIZE; h, w = image.shape[:2]
        if len(ys) > 0:
            idx = random.randint(0, len(ys) - 1); cy, cx = ys[idx], xs[idx]
            y1 = max(0, min(h-cs, cy - cs//2 + random.randint(-50, 50)))
            x1 = max(0, min(w-cs, cx - cs//2 + random.randint(-50, 50)))
        else:
            y1 = random.randint(0, h - cs); x1 = random.randint(0, w - cs)
        return image[y1:y1+cs, x1:x1+cs], label[y1:y1+cs, x1:x1+cs]
    def __getitem__(self, idx):
        id_ = self.ids[idx]
        img_path = os.path.join(self.img_dir, f"{id_}.png")
        depth_path = os.path.join(self.depth_dir, f"{id_}.png")
        label_path = os.path.join(self.label_dir, f"{id_}.png")
        image = cv2.imread(img_path); image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if depth is None: depth = np.zeros(image.shape[:2], np.float32)
        # Scale depth to [0, 1] by its per-image maximum.
        if depth.max() > 0: depth = depth.astype(np.float32) / depth.max()
        # Any label outside [0, CLASSES) becomes the ignore index.
        label[~((label >= 0) & (label < Config.CLASSES))] = 255
        edge = self.compute_edge_map(image)
        # Stack RGB + depth + edge into a 5-channel array, all on a 0-255 scale.
        combined = np.dstack([image, depth*255.0, edge*255.0]).astype(np.float32)
        if self.augmentor: combined, label = self.augmentor.apply(combined, label)
        # The random crop runs regardless of is_train, so validation samples are random crops too.
        combined_crop, label_crop = self.get_book_centric_crop(combined, label)
        if self.transform:
            if isinstance(self.transform, dict):
                rgb_aug = self.transform['color'](image=combined_crop[:,:,:3].astype(np.uint8))['image']
                t_geo = self.transform['geo'](image=np.dstack([rgb_aug, combined_crop[:,:,3:]]), mask=label_crop)
                return t_geo['image'], t_geo['mask'].long()
            t = self.transform(image=combined_crop, mask=label_crop)
            return t['image'], t['mask'].long()
        return torch.from_numpy(combined_crop.transpose(2,0,1)).float(), torch.from_numpy(label_crop).long()

# ==========================================
# 4. Analysis Utils (Updated)
# ==========================================
def save_worst_samples(worst_samples, epoch, save_dir):
    """
    Save the worst samples (images with the most Book-class errors).
    worst_samples: list of tuples (error_score, image_tensor, mask_tensor, output_tensor)
    """
    os.makedirs(save_dir, exist_ok=True)
    mean = np.array([0.485, 0.456, 0.406]); std = np.array([0.229, 0.224, 0.225])

    for i, (score, img_t, mask_t, out_t) in enumerate(worst_samples):
        # Restore Image
        img_vis = np.clip((img_t[:3, :, :].cpu().numpy().transpose(1, 2, 0) * std + mean) * 255.0, 0, 255).astype(np.uint8)

        # Predictions
        probs = F.softmax(out_t.unsqueeze(0), dim=1) # [1, C, H, W]
        pred_t = torch.argmax(probs, dim=1).squeeze(0) # [H, W]
        book_prob = probs[0, 1, :, :].cpu().numpy()

        gt_book = (mask_t.cpu().numpy() == 1)
        pred_book = (pred_t.cpu().numpy() == 1)

        # Visualization
        fig, axes = plt.subplots(1, 4, figsize=(20, 5))
        axes[0].imshow(img_vis); axes[0].set_title(f"Rank{i+1}: Error Score={int(score)}")
        axes[1].imshow(gt_book, cmap='gray'); axes[1].set_title("GT Book")

        im = axes[2].imshow(book_prob, cmap='jet', vmin=0, vmax=1.0); axes[2].set_title("Book Prob")
        plt.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)

        # Error Map
        error_vis = (img_vis * 0.7 + 255 * 0.3).astype(np.uint8)
        error_vis[pred_book & (~gt_book)] = [255, 0, 0] # FP
        error_vis[(~pred_book) & gt_book] = [0, 0, 255] # FN
        axes[3].imshow(error_vis); axes[3].set_title("Red=FP, Blue=FN")

        for ax in axes: ax.axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f"epoch_{epoch:03d}_worst_{i+1:02d}.png"))
        plt.close()

def save_confusion_matrix(cm, epoch, save_dir):
    """
    Save the confusion matrix as a heatmap.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Normalize rows
    cm_norm = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-6)

    plt.figure(figsize=(12, 10))
    sns.heatmap(cm_norm, annot=False, fmt='.2f', cmap='Blues',
                xticklabels=Config.CLASS_NAMES, yticklabels=Config.CLASS_NAMES)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title(f'Confusion Matrix (Normalized) - Epoch {epoch}')
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f"epoch_{epoch:03d}_confusion_matrix.png"))
    plt.close()

# ==========================================
# 5. Main Execution
# ==========================================
def main():
    set_seed(Config.SEED)
    os.makedirs(Config.DEBUG_DIR, exist_ok=True)
    print(f"[{get_jst_time_str()}] 🚀 Start Training with Error Mining")
    print(f"   Logs & Images -> {Config.DEBUG_DIR}")

    # Load Model & Components
    model = smp.UPerNet(encoder_name=Config.ENCODER, encoder_weights=Config.ENCODER_WEIGHTS, in_channels=3, classes=Config.CLASSES, activation=None)
    # Widen the first patch-embedding conv from 3 to 5 input channels (RGB + depth + edge).
    if hasattr(model.encoder, 'patch_embed1') and hasattr(model.encoder.patch_embed1, 'proj'):
        old = model.encoder.patch_embed1.proj
        new_l = nn.Conv2d(5, old.out_channels, old.kernel_size, old.stride, old.padding, bias=(old.bias is not None))
        with torch.no_grad():
            # Keep the pretrained RGB filters; initialize both extra channels with their channel-wise mean.
            new_l.weight[:, :3, :, :] = old.weight
            new_l.weight[:, 3:, :, :] = torch.mean(old.weight, dim=1, keepdim=True).repeat(1, 2, 1, 1)
            if old.bias is not None: new_l.bias = old.bias
        model.encoder.patch_embed1.proj = new_l
    model = model.to(Config.device)

    optimizer = optim.AdamW(model.parameters(), lr=Config.LR, weight_decay=Config.WEIGHT_DECAY)
    criterion = LovaszWeightedLoss(Config.CLASS_WEIGHTS, Config.device, ignore_index=255)
    scaler = torch.amp.GradScaler('cuda')
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.EPOCHS, eta_min=1e-7)

    # Dataset Setup
    img_dir_train = os.path.join(Config.DATA_ROOT, 'train', 'image')
    all_ids = sorted([os.path.splitext(f)[0] for f in os.listdir(img_dir_train) if f.endswith('.png')])
    random.shuffle(all_ids)
    n_val = int(len(all_ids) * 0.1)
    val_ids = all_ids[:n_val]; train_ids = all_ids[n_val:]

    temp_ds = NYUv2BookZoomDataset(Config.DATA_ROOT, whitelist_ids=train_ids, is_train=False)
    train_augmentor = BookDepthAwareAugmentor(temp_ds, prob=Config.PASTE_PROB)
    train_ds = NYUv2BookZoomDataset(Config.DATA_ROOT, whitelist_ids=train_ids, is_train=True, transform=get_transforms('train'), augmentor=train_augmentor)
    val_ds = NYUv2BookZoomDataset(Config.DATA_ROOT, whitelist_ids=val_ids, is_train=False, transform=get_transforms('val'))
    train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True)

    best_miou = 0.0

    for epoch in range(Config.EPOCHS):
        print(f"\n[{get_jst_time_str()}] Epoch {epoch+1}/{Config.EPOCHS}")

        # Train
        model.train(); train_loss = 0
        for images, masks in tqdm(train_loader, desc="Training", leave=False):
            images = images.to(Config.device, non_blocking=True); masks = masks.to(Config.device, non_blocking=True)
            optimizer.zero_grad()
            with torch.amp.autocast('cuda'):
                outputs = model(images)
                loss = criterion(outputs, masks)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.item()
        scheduler.step()  # advance CosineAnnealingLR once per epoch

        # Val
        model.eval(); val_loss = 0
        iou_scores, prec_scores, rec_scores = [], [], []

        # Analysis Containers
        worst_samples = [] # (error_score, img, mask, output)
        conf_mat = np.zeros((Config.CLASSES, Config.CLASSES), dtype=np.int64)

        do_analysis = ((epoch + 1) % Config.SAVE_INTERVAL == 0) or ((epoch + 1) == Config.EPOCHS)

        with torch.no_grad():
            with torch.amp.autocast('cuda'):
                for batch_idx, (images, masks) in enumerate(tqdm(val_loader, desc="Validation", leave=False)):
                    images = images.to(Config.device, non_blocking=True); masks = masks.to(Config.device, non_blocking=True)
                    outputs = model(images); loss = criterion(outputs, masks); val_loss += loss.item()

                    preds = outputs.argmax(dim=1)

                    # Update Metrics
                    def compute_batch_metrics(p, m):
                        p = p.view(-1); m = m.view(-1); valid = m != 255; p = p[valid]; m = m[valid]
                        ious, precs, recs = [], [], []
                        for c in range(Config.CLASSES):
                            tp = ((p==c)&(m==c)).sum().item()
                            fp = ((p==c)&(m!=c)).sum().item()
                            fn = ((p!=c)&(m==c)).sum().item()
                            ious.append(tp/(tp+fp+fn) if (tp+fp+fn)>0 else np.nan)
                            precs.append(tp/(tp+fp) if (tp+fp)>0 else np.nan)
                            recs.append(tp/(tp+fn) if (tp+fn)>0 else np.nan)
                        return np.array(ious), np.array(precs), np.array(recs)

                    bi, bp, br = compute_batch_metrics(preds, masks)
                    iou_scores.append(bi); prec_scores.append(bp); rec_scores.append(br)

                    if do_analysis:
                        # 1. Update Confusion Matrix
                        m_np = masks.cpu().numpy().flatten(); p_np = preds.cpu().numpy().flatten()
                        valid = m_np != 255
                        conf_mat += confusion_matrix(m_np[valid], p_np[valid], labels=range(Config.CLASSES))

                        # 2. Find Worst Samples for Book (Class 1)
                        for i in range(len(images)):
                            gt_book = (masks[i] == 1)
                            pred_book = (preds[i] == 1)
                            fp_count = (pred_book & (~gt_book)).sum().item()
                            fn_count = ((~pred_book) & gt_book).sum().item()
                            error_score = fp_count + fn_count

                            if error_score > 0:
                                # Save only tensors to save memory, process later
                                worst_samples.append((error_score, images[i].cpu(), masks[i].cpu(), outputs[i].cpu()))

        # Aggregate Metrics (mean of per-batch IoUs; batches where a class is absent are skipped via NaN)
        cls_ious = np.nanmean(np.array(iou_scores), axis=0)
        val_miou = np.nanmean(cls_ious)

        # Logging
        print(f"Train Loss: {train_loss/len(train_loader):.4f} | Val Loss: {val_loss/len(val_loader):.4f}")
        print(f"Val mIoU: {val_miou:.4f}")
        print("Class-wise IoU:")
        for i, name in enumerate(Config.CLASS_NAMES):
            print(f"  {name:10s}: {cls_ious[i]:.4f}")

        if val_miou > best_miou:
            best_miou = val_miou
            torch.save(model.state_dict(), 'best_model_nyuv2_error_mining.pth')
            print(">>> Best Model Saved!")

        # Perform Analysis Saving
        if do_analysis:
            print(f"📊 Saving Analysis for Epoch {epoch+1}...")
            # Save Confusion Matrix
            save_confusion_matrix(conf_mat, epoch+1, Config.DEBUG_DIR)

            # Save Top Worst Samples
            worst_samples.sort(key=lambda x: x[0], reverse=True)
            top_worst = worst_samples[:Config.NUM_WORST_SAMPLES]
            save_worst_samples(top_worst, epoch+1, Config.DEBUG_DIR)
            print(f"   -> Saved {len(top_worst)} worst error samples and Confusion Matrix.")

    # Finish
    print("📦 Zipping logs...")
    shutil.make_archive(Config.DEBUG_DIR, 'zip', Config.DEBUG_DIR)
    print(f"✅ Created: {Config.DEBUG_DIR}.zip")

if __name__ == '__main__':
    if os.path.exists(Config.DATA_ROOT):
        torch.cuda.empty_cache()
        main()
    else:
        print(f"Data root not found: {Config.DATA_ROOT}")

[20260108_135015] 🚀 Start Training with Error Mining
   Logs & Images -> /content/logs_20260108_135015
💎 Loading Golden Samples...
Collecting 'Auto' book blobs...


Scanning: 100%|██████████| 716/716 [00:04<00:00, 169.08it/s]
⚖️ Applying Oversampling (5x) to 716 images...
   -> Expanded to 1228 images.

[20260108_135022] Epoch 1/50




Train Loss: 0.8303 | Val Loss: 0.6552
Val mIoU: 0.4924
Class-wise IoU:
  Bed       : 0.6120
  Books     : 0.1742
  Ceiling   : 0.3754
  Chair     : 0.4591
  Floor     : 0.8157
  Furniture : 0.5487
  Objects   : 0.4191
  Picture   : 0.3594
  Sofa      : 0.5529
  Table     : 0.4698
  TV        : 0.3040
  Wall      : 0.6721
  Window    : 0.6391
>>> Best Model Saved!

[20260108_135109] Epoch 2/50




Train Loss: 0.5729 | Val Loss: 0.6244
Val mIoU: 0.5279
Class-wise IoU:
  Bed       : 0.6091
  Books     : 0.2437
  Ceiling   : 0.5400
  Chair     : 0.5764
  Floor     : 0.8480
  Furniture : 0.5406
  Objects   : 0.4354
  Picture   : 0.4101
  Sofa      : 0.6048
  Table     : 0.4163
  TV        : 0.3824
  Wall      : 0.7444
  Window    : 0.5117
>>> Best Model Saved!

[20260108_135140] Epoch 3/50




Train Loss: 0.5059 | Val Loss: 0.5784
Val mIoU: 0.5384
Class-wise IoU:
  Bed       : 0.5698
  Books     : 0.3029
  Ceiling   : 0.2763
  Chair     : 0.5687
  Floor     : 0.8519
  Furniture : 0.5935
  Objects   : 0.5073
  Picture   : 0.5219
  Sofa      : 0.5460
  Table     : 0.5069
  TV        : 0.4972
  Wall      : 0.7154
  Window    : 0.5418
>>> Best Model Saved!

[20260108_135211] Epoch 4/50




Train Loss: 0.4512 | Val Loss: 0.5773
Val mIoU: 0.5437
Class-wise IoU:
  Bed       : 0.6320
  Books     : 0.2843
  Ceiling   : 0.3285
  Chair     : 0.6164
  Floor     : 0.8560
  Furniture : 0.5998
  Objects   : 0.4809
  Picture   : 0.4537
  Sofa      : 0.5445
  Table     : 0.4892
  TV        : 0.5417
  Wall      : 0.7271
  Window    : 0.5143
>>> Best Model Saved!

[20260108_135242] Epoch 5/50




Train Loss: 0.4169 | Val Loss: 0.5312
Val mIoU: 0.5621
Class-wise IoU:
  Bed       : 0.6209
  Books     : 0.2330
  Ceiling   : 0.4561
  Chair     : 0.5626
  Floor     : 0.8714
  Furniture : 0.6460
  Objects   : 0.5047
  Picture   : 0.5227
  Sofa      : 0.4842
  Table     : 0.4886
  TV        : 0.5838
  Wall      : 0.7761
  Window    : 0.5577
>>> Best Model Saved!

[20260108_135312] Epoch 6/50




Train Loss: 0.3867 | Val Loss: 0.5581
Val mIoU: 0.5702
Class-wise IoU:
  Bed       : 0.7269
  Books     : 0.2745
  Ceiling   : 0.2770
  Chair     : 0.4876
  Floor     : 0.8687
  Furniture : 0.5611
  Objects   : 0.4946
  Picture   : 0.4785
  Sofa      : 0.6735
  Table     : 0.4724
  TV        : 0.7157
  Wall      : 0.7558
  Window    : 0.6262
>>> Best Model Saved!

[20260108_135343] Epoch 7/50




Train Loss: 0.3630 | Val Loss: 0.5641
Val mIoU: 0.5558
Class-wise IoU:
  Bed       : 0.6746
  Books     : 0.3065
  Ceiling   : 0.3396
  Chair     : 0.5693
  Floor     : 0.8876
  Furniture : 0.5951
  Objects   : 0.4998
  Picture   : 0.4770
  Sofa      : 0.6479
  Table     : 0.5098
  TV        : 0.4492
  Wall      : 0.7622
  Window    : 0.5071

[20260108_135413] Epoch 8/50




Train Loss: 0.3519 | Val Loss: 0.4961
Val mIoU: 0.6066
Class-wise IoU:
  Bed       : 0.7774
  Books     : 0.3337
  Ceiling   : 0.6137
  Chair     : 0.4856
  Floor     : 0.8958
  Furniture : 0.6661
  Objects   : 0.5677
  Picture   : 0.5460
  Sofa      : 0.5973
  Table     : 0.4266
  TV        : 0.5528
  Wall      : 0.8086
  Window    : 0.6148
>>> Best Model Saved!

[20260108_135444] Epoch 9/50




Train Loss: 0.3254 | Val Loss: 0.5395
Val mIoU: 0.5750
Class-wise IoU:
  Bed       : 0.6751
  Books     : 0.3867
  Ceiling   : 0.4010
  Chair     : 0.4689
  Floor     : 0.8893
  Furniture : 0.6504
  Objects   : 0.5114
  Picture   : 0.4571
  Sofa      : 0.6494
  Table     : 0.5453
  TV        : 0.4809
  Wall      : 0.7788
  Window    : 0.5799

[20260108_135514] Epoch 10/50




Train Loss: 0.3218 | Val Loss: 0.5139
Val mIoU: 0.5855
Class-wise IoU:
  Bed       : 0.7473
  Books     : 0.3422
  Ceiling   : 0.3328
  Chair     : 0.6017
  Floor     : 0.8931
  Furniture : 0.6820
  Objects   : 0.5358
  Picture   : 0.4237
  Sofa      : 0.6681
  Table     : 0.5105
  TV        : 0.5349
  Wall      : 0.7812
  Window    : 0.5583
📊 Saving Analysis for Epoch 10...
   -> Saved 5 worst error samples and Confusion Matrix.

[20260108_135550] Epoch 11/50




Train Loss: 0.3106 | Val Loss: 0.5327
Val mIoU: 0.5843
Class-wise IoU:
  Bed       : 0.6787
  Books     : 0.3413
  Ceiling   : 0.3003
  Chair     : 0.6094
  Floor     : 0.8878
  Furniture : 0.6440
  Objects   : 0.5239
  Picture   : 0.5716
  Sofa      : 0.5065
  Table     : 0.5491
  TV        : 0.6091
  Wall      : 0.7692
  Window    : 0.6056

[20260108_135621] Epoch 12/50




Train Loss: 0.2945 | Val Loss: 0.5035
Val mIoU: 0.5935
Class-wise IoU:
  Bed       : 0.6554
  Books     : 0.3417
  Ceiling   : 0.3963
  Chair     : 0.6108
  Floor     : 0.8876
  Furniture : 0.6622
  Objects   : 0.5354
  Picture   : 0.4792
  Sofa      : 0.5756
  Table     : 0.5305
  TV        : 0.6198
  Wall      : 0.7719
  Window    : 0.6493

[20260108_135651] Epoch 13/50




Train Loss: 0.2843 | Val Loss: 0.4907
Val mIoU: 0.6078
Class-wise IoU:
  Bed       : 0.7481
  Books     : 0.3567
  Ceiling   : 0.4032
  Chair     : 0.6415
  Floor     : 0.9070
  Furniture : 0.6785
  Objects   : 0.5758
  Picture   : 0.4003
  Sofa      : 0.5487
  Table     : 0.5465
  TV        : 0.7129
  Wall      : 0.7683
  Window    : 0.6139
>>> Best Model Saved!

[20260108_135722] Epoch 14/50




Train Loss: 0.2846 | Val Loss: 0.4925
Val mIoU: 0.6288
Class-wise IoU:
  Bed       : 0.7346
  Books     : 0.3865
  Ceiling   : 0.6200
  Chair     : 0.5749
  Floor     : 0.9192
  Furniture : 0.6843
  Objects   : 0.5484
  Picture   : 0.4989
  Sofa      : 0.6314
  Table     : 0.5313
  TV        : 0.6808
  Wall      : 0.7460
  Window    : 0.6178
>>> Best Model Saved!

[20260108_135753] Epoch 15/50




Train Loss: 0.2811 | Val Loss: 0.5329
Val mIoU: 0.6089
Class-wise IoU:
  Bed       : 0.6898
  Books     : 0.3346
  Ceiling   : 0.4981
  Chair     : 0.6126
  Floor     : 0.8837
  Furniture : 0.6477
  Objects   : 0.5856
  Picture   : 0.5431
  Sofa      : 0.5606
  Table     : 0.5505
  TV        : 0.6722
  Wall      : 0.7477
  Window    : 0.5889

[20260108_135823] Epoch 16/50




Train Loss: 0.2529 | Val Loss: 0.5356
Val mIoU: 0.6051
Class-wise IoU:
  Bed       : 0.6250
  Books     : 0.2711
  Ceiling   : 0.1093
  Chair     : 0.6735
  Floor     : 0.9193
  Furniture : 0.6740
  Objects   : 0.5374
  Picture   : 0.5130
  Sofa      : 0.7239
  Table     : 0.5861
  TV        : 0.9088
  Wall      : 0.7934
  Window    : 0.5310

[20260108_135854] Epoch 17/50




Train Loss: 0.2602 | Val Loss: 0.4961
Val mIoU: 0.6019
Class-wise IoU:
  Bed       : 0.6724
  Books     : 0.3500
  Ceiling   : 0.3592
  Chair     : 0.5773
  Floor     : 0.9226
  Furniture : 0.6722
  Objects   : 0.5564
  Picture   : 0.4836
  Sofa      : 0.5937
  Table     : 0.6529
  TV        : 0.5279
  Wall      : 0.7988
  Window    : 0.6583

[20260108_135924] Epoch 18/50




Train Loss: 0.2395 | Val Loss: 0.4773
Val mIoU: 0.6051
Class-wise IoU:
  Bed       : 0.7318
  Books     : 0.3002
  Ceiling   : 0.1729
  Chair     : 0.6789
  Floor     : 0.8978
  Furniture : 0.7041
  Objects   : 0.5834
  Picture   : 0.6163
  Sofa      : 0.6741
  Table     : 0.5866
  TV        : 0.5073
  Wall      : 0.8012
  Window    : 0.6123

[20260108_135955] Epoch 19/50




Train Loss: 0.2371 | Val Loss: 0.4632
Val mIoU: 0.6463
Class-wise IoU:
  Bed       : 0.7171
  Books     : 0.3404
  Ceiling   : 0.5046
  Chair     : 0.7034
  Floor     : 0.9119
  Furniture : 0.6813
  Objects   : 0.5994
  Picture   : 0.5671
  Sofa      : 0.7346
  Table     : 0.5932
  TV        : 0.6277
  Wall      : 0.7846
  Window    : 0.6367
>>> Best Model Saved!

[20260108_140026] Epoch 20/50




Train Loss: 0.2223 | Val Loss: 0.4709
Val mIoU: 0.6170
Class-wise IoU:
  Bed       : 0.7247
  Books     : 0.3691
  Ceiling   : 0.2812
  Chair     : 0.6650
  Floor     : 0.9064
  Furniture : 0.6742
  Objects   : 0.5929
  Picture   : 0.5183
  Sofa      : 0.6474
  Table     : 0.5561
  TV        : 0.6455
  Wall      : 0.8017
  Window    : 0.6388
📊 Saving Analysis for Epoch 20...
   -> Saved 5 worst error samples and Confusion Matrix.

[20260108_140102] Epoch 21/50




Train Loss: 0.2364 | Val Loss: 0.4410
Val mIoU: 0.6570
Class-wise IoU:
  Bed       : 0.7527
  Books     : 0.3382
  Ceiling   : 0.4847
  Chair     : 0.6416
  Floor     : 0.9049
  Furniture : 0.6783
  Objects   : 0.5797
  Picture   : 0.5963
  Sofa      : 0.6542
  Table     : 0.5938
  TV        : 0.8752
  Wall      : 0.8176
  Window    : 0.6237
>>> Best Model Saved!

[20260108_140133] Epoch 22/50




Train Loss: 0.2212 | Val Loss: 0.4650
Val mIoU: 0.6467
Class-wise IoU:
  Bed       : 0.7086
  Books     : 0.4226
  Ceiling   : 0.5418
  Chair     : 0.6490
  Floor     : 0.8961
  Furniture : 0.7061
  Objects   : 0.6189
  Picture   : 0.6142
  Sofa      : 0.6208
  Table     : 0.5665
  TV        : 0.6375
  Wall      : 0.8200
  Window    : 0.6047

[20260108_140204] Epoch 23/50




Train Loss: 0.2191 | Val Loss: 0.4887
Val mIoU: 0.6236
Class-wise IoU:
  Bed       : 0.6287
  Books     : 0.4424
  Ceiling   : 0.4420
  Chair     : 0.6745
  Floor     : 0.9253
  Furniture : 0.6799
  Objects   : 0.5859
  Picture   : 0.5088
  Sofa      : 0.5325
  Table     : 0.6368
  TV        : 0.5410
  Wall      : 0.8207
  Window    : 0.6879

[20260108_140234] Epoch 24/50




Train Loss: 0.2023 | Val Loss: 0.4718
Val mIoU: 0.6360
Class-wise IoU:
  Bed       : 0.7295
  Books     : 0.3875
  Ceiling   : 0.4832
  Chair     : 0.6269
  Floor     : 0.9150
  Furniture : 0.6619
  Objects   : 0.6110
  Picture   : 0.4651
  Sofa      : 0.6112
  Table     : 0.5957
  TV        : 0.6990
  Wall      : 0.8132
  Window    : 0.6687

[20260108_140305] Epoch 25/50




Train Loss: 0.2097 | Val Loss: 0.5132
Val mIoU: 0.6000
Class-wise IoU:
  Bed       : 0.6080
  Books     : 0.3821
  Ceiling   : 0.2946
  Chair     : 0.5888
  Floor     : 0.9024
  Furniture : 0.6724
  Objects   : 0.5848
  Picture   : 0.6045
  Sofa      : 0.5954
  Table     : 0.5945
  TV        : 0.5123
  Wall      : 0.8184
  Window    : 0.6417

[20260108_140335] Epoch 26/50




Train Loss: 0.2071 | Val Loss: 0.4991
Val mIoU: 0.6142
Class-wise IoU:
  Bed       : 0.6380
  Books     : 0.3935
  Ceiling   : 0.4421
  Chair     : 0.6340
  Floor     : 0.8792
  Furniture : 0.6678
  Objects   : 0.5984
  Picture   : 0.6180
  Sofa      : 0.5316
  Table     : 0.5680
  TV        : 0.5471
  Wall      : 0.8359
  Window    : 0.6308

[20260108_140406] Epoch 27/50




Train Loss: 0.2018 | Val Loss: 0.5093
Val mIoU: 0.6110
Class-wise IoU:
  Bed       : 0.6923
  Books     : 0.4299
  Ceiling   : 0.2392
  Chair     : 0.6278
  Floor     : 0.9032
  Furniture : 0.7043
  Objects   : 0.5516
  Picture   : 0.5503
  Sofa      : 0.5722
  Table     : 0.6067
  TV        : 0.7028
  Wall      : 0.7457
  Window    : 0.6170

[20260108_140437] Epoch 28/50




Train Loss: 0.2059 | Val Loss: 0.4950
Val mIoU: 0.6159
Class-wise IoU:
  Bed       : 0.7373
  Books     : 0.3894
  Ceiling   : 0.4811
  Chair     : 0.6736
  Floor     : 0.8539
  Furniture : 0.6742
  Objects   : 0.5732
  Picture   : 0.4716
  Sofa      : 0.6168
  Table     : 0.6415
  TV        : 0.3865
  Wall      : 0.7889
  Window    : 0.7186

[20260108_140515] Epoch 29/50




Train Loss: 0.1953 | Val Loss: 0.4740
Val mIoU: 0.6302
Class-wise IoU:
  Bed       : 0.7539
  Books     : 0.2851
  Ceiling   : 0.4166
  Chair     : 0.6010
  Floor     : 0.9008
  Furniture : 0.6870
  Objects   : 0.5729
  Picture   : 0.5503
  Sofa      : 0.7671
  Table     : 0.6340
  TV        : 0.5176
  Wall      : 0.7874
  Window    : 0.7190

[20260108_140547] Epoch 30/50




Train Loss: 0.1889 | Val Loss: 0.4720
Val mIoU: 0.6331
Class-wise IoU:
  Bed       : 0.7269
  Books     : 0.3698
  Ceiling   : 0.3898
  Chair     : 0.6957
  Floor     : 0.9132
  Furniture : 0.6724
  Objects   : 0.6201
  Picture   : 0.5003
  Sofa      : 0.7098
  Table     : 0.6424
  TV        : 0.5203
  Wall      : 0.8236
  Window    : 0.6466
📊 Saving Analysis for Epoch 30...
   -> Saved 5 worst error samples and Confusion Matrix.

[20260108_140624] Epoch 31/50




Train Loss: 0.1844 | Val Loss: 0.4836
Val mIoU: 0.6382
Class-wise IoU:
  Bed       : 0.6931
  Books     : 0.4301
  Ceiling   : 0.4261
  Chair     : 0.5970
  Floor     : 0.9158
  Furniture : 0.7274
  Objects   : 0.5713
  Picture   : 0.5587
  Sofa      : 0.6436
  Table     : 0.6361
  TV        : 0.5717
  Wall      : 0.8146
  Window    : 0.7114

[20260108_140656] Epoch 32/50




Train Loss: 0.1822 | Val Loss: 0.5204
Val mIoU: 0.6159
Class-wise IoU:
  Bed       : 0.7084
  Books     : 0.4498
  Ceiling   : 0.2822
  Chair     : 0.6613
  Floor     : 0.9175
  Furniture : 0.6670
  Objects   : 0.5945
  Picture   : 0.4824
  Sofa      : 0.7208
  Table     : 0.5442
  TV        : 0.6512
  Wall      : 0.7777
  Window    : 0.5496

[20260108_140727] Epoch 33/50




Train Loss: 0.1825 | Val Loss: 0.4679
Val mIoU: 0.6580
Class-wise IoU:
  Bed       : 0.6442
  Books     : 0.4313
  Ceiling   : 0.5438
  Chair     : 0.6932
  Floor     : 0.9072
  Furniture : 0.6968
  Objects   : 0.6241
  Picture   : 0.5811
  Sofa      : 0.6359
  Table     : 0.6647
  TV        : 0.5247
  Wall      : 0.8357
  Window    : 0.7709
>>> Best Model Saved!

[20260108_140759] Epoch 34/50




Train Loss: 0.1767 | Val Loss: 0.4648
Val mIoU: 0.6348
Class-wise IoU:
  Bed       : 0.7425
  Books     : 0.4122
  Ceiling   : 0.4700
  Chair     : 0.6729
  Floor     : 0.9044
  Furniture : 0.7163
  Objects   : 0.5685
  Picture   : 0.5466
  Sofa      : 0.6171
  Table     : 0.6660
  TV        : 0.5509
  Wall      : 0.7763
  Window    : 0.6089

[20260108_140831] Epoch 35/50




Train Loss: 0.1719 | Val Loss: 0.4721
Val mIoU: 0.6538
Class-wise IoU:
  Bed       : 0.6987
  Books     : 0.4309
  Ceiling   : 0.6729
  Chair     : 0.6152
  Floor     : 0.9198
  Furniture : 0.6536
  Objects   : 0.6271
  Picture   : 0.6053
  Sofa      : 0.5171
  Table     : 0.5755
  TV        : 0.6427
  Wall      : 0.8079
  Window    : 0.7330

[20260108_140903] Epoch 36/50




Train Loss: 0.1784 | Val Loss: 0.4765
Val mIoU: 0.6310
Class-wise IoU:
  Bed       : 0.7134
  Books     : 0.4337
  Ceiling   : 0.4056
  Chair     : 0.6395
  Floor     : 0.9132
  Furniture : 0.6781
  Objects   : 0.6065
  Picture   : 0.6679
  Sofa      : 0.4472
  Table     : 0.5953
  TV        : 0.5567
  Wall      : 0.8237
  Window    : 0.7219

[20260108_140934] Epoch 37/50




Train Loss: 0.1747 | Val Loss: 0.4640
Val mIoU: 0.6691
Class-wise IoU:
  Bed       : 0.7283
  Books     : 0.3038
  Ceiling   : 0.6765
  Chair     : 0.7107
  Floor     : 0.9235
  Furniture : 0.7032
  Objects   : 0.6191
  Picture   : 0.6476
  Sofa      : 0.6218
  Table     : 0.6188
  TV        : 0.6828
  Wall      : 0.8071
  Window    : 0.6546
>>> Best Model Saved!

[20260108_141006] Epoch 38/50




Train Loss: 0.1709 | Val Loss: 0.4240
Val mIoU: 0.6551
Class-wise IoU:
  Bed       : 0.6797
  Books     : 0.4198
  Ceiling   : 0.4802
  Chair     : 0.6477
  Floor     : 0.9241
  Furniture : 0.7325
  Objects   : 0.6206
  Picture   : 0.6422
  Sofa      : 0.6034
  Table     : 0.6485
  TV        : 0.5285
  Wall      : 0.8137
  Window    : 0.7757

[20260108_141037] Epoch 39/50




Train Loss: 0.1662 | Val Loss: 0.4166
Val mIoU: 0.7018
Class-wise IoU:
  Bed       : 0.7222
  Books     : 0.4162
  Ceiling   : 0.5195
  Chair     : 0.6745
  Floor     : 0.9152
  Furniture : 0.7084
  Objects   : 0.6198
  Picture   : 0.7399
  Sofa      : 0.7330
  Table     : 0.6638
  TV        : 0.8307
  Wall      : 0.8298
  Window    : 0.7506
>>> Best Model Saved!

[20260108_141109] Epoch 40/50




Train Loss: 0.1671 | Val Loss: 0.4705
Val mIoU: 0.6322
Class-wise IoU:
  Bed       : 0.6209
  Books     : 0.3775
  Ceiling   : 0.3032
  Chair     : 0.6465
  Floor     : 0.9306
  Furniture : 0.7163
  Objects   : 0.6204
  Picture   : 0.6974
  Sofa      : 0.6044
  Table     : 0.6594
  TV        : 0.5243
  Wall      : 0.8164
  Window    : 0.7021
📊 Saving Analysis for Epoch 40...
   -> Saved 5 worst error samples and Confusion Matrix.

[20260108_141146] Epoch 41/50




Train Loss: 0.1734 | Val Loss: 0.4610
Val mIoU: 0.6345
Class-wise IoU:
  Bed       : 0.7098
  Books     : 0.4151
  Ceiling   : 0.2558
  Chair     : 0.6837
  Floor     : 0.9190
  Furniture : 0.6742
  Objects   : 0.6330
  Picture   : 0.6750
  Sofa      : 0.6857
  Table     : 0.5773
  TV        : 0.4924
  Wall      : 0.8033
  Window    : 0.7238

[20260108_141218] Epoch 42/50




Train Loss: 0.1603 | Val Loss: 0.4606
Val mIoU: 0.6361
Class-wise IoU:
  Bed       : 0.6659
  Books     : 0.4189
  Ceiling   : 0.2697
  Chair     : 0.7317
  Floor     : 0.9413
  Furniture : 0.7028
  Objects   : 0.6082
  Picture   : 0.6553
  Sofa      : 0.5977
  Table     : 0.6332
  TV        : 0.4711
  Wall      : 0.8122
  Window    : 0.7609

[20260108_141250] Epoch 43/50




Train Loss: 0.1598 | Val Loss: 0.4583
Val mIoU: 0.6439
Class-wise IoU:
  Bed       : 0.7423
  Books     : 0.4386
  Ceiling   : 0.2511
  Chair     : 0.7355
  Floor     : 0.9063
  Furniture : 0.7157
  Objects   : 0.6399
  Picture   : 0.6749
  Sofa      : 0.5897
  Table     : 0.5924
  TV        : 0.5263
  Wall      : 0.8278
  Window    : 0.7307

[20260108_141321] Epoch 44/50




Train Loss: 0.1597 | Val Loss: 0.4583
Val mIoU: 0.6483
Class-wise IoU:
  Bed       : 0.7767
  Books     : 0.4141
  Ceiling   : 0.2482
  Chair     : 0.6627
  Floor     : 0.9333
  Furniture : 0.7181
  Objects   : 0.6291
  Picture   : 0.5612
  Sofa      : 0.7279
  Table     : 0.6439
  TV        : 0.6448
  Wall      : 0.7979
  Window    : 0.6701

[20260108_141352] Epoch 45/50




Train Loss: 0.1569 | Val Loss: 0.4489
Val mIoU: 0.6504
Class-wise IoU:
  Bed       : 0.7019
  Books     : 0.4185
  Ceiling   : 0.4162
  Chair     : 0.6208
  Floor     : 0.9377
  Furniture : 0.7251
  Objects   : 0.6292
  Picture   : 0.5160
  Sofa      : 0.7703
  Table     : 0.5879
  TV        : 0.7075
  Wall      : 0.7978
  Window    : 0.6259

[20260108_141424] Epoch 46/50




Train Loss: 0.1527 | Val Loss: 0.4370
Val mIoU: 0.6985
Class-wise IoU:
  Bed       : 0.8005
  Books     : 0.4332
  Ceiling   : 0.5466
  Chair     : 0.6157
  Floor     : 0.9235
  Furniture : 0.7274
  Objects   : 0.6233
  Picture   : 0.6320
  Sofa      : 0.7258
  Table     : 0.5745
  TV        : 0.8964
  Wall      : 0.8252
  Window    : 0.7569

[20260108_141456] Epoch 47/50




Train Loss: 0.1498 | Val Loss: 0.4794
Val mIoU: 0.6559
Class-wise IoU:
  Bed       : 0.6955
  Books     : 0.4012
  Ceiling   : 0.5528
  Chair     : 0.6756
  Floor     : 0.9172
  Furniture : 0.7245
  Objects   : 0.6154
  Picture   : 0.6117
  Sofa      : 0.5582
  Table     : 0.6819
  TV        : 0.5334
  Wall      : 0.8152
  Window    : 0.7440

[20260108_141527] Epoch 48/50




Train Loss: 0.1512 | Val Loss: 0.5015
Val mIoU: 0.6202
Class-wise IoU:
  Bed       : 0.6831
  Books     : 0.3966
  Ceiling   : 0.2897
  Chair     : 0.6059
  Floor     : 0.9002
  Furniture : 0.7352
  Objects   : 0.5915
  Picture   : 0.6258
  Sofa      : 0.6668
  Table     : 0.6528
  TV        : 0.4613
  Wall      : 0.8078
  Window    : 0.6456

[20260108_141559] Epoch 49/50




Train Loss: 0.1471 | Val Loss: 0.4740
Val mIoU: 0.6746
Class-wise IoU:
  Bed       : 0.7608
  Books     : 0.4215
  Ceiling   : 0.5195
  Chair     : 0.6767
  Floor     : 0.9238
  Furniture : 0.6955
  Objects   : 0.6085
  Picture   : 0.6108
  Sofa      : 0.7176
  Table     : 0.6509
  TV        : 0.6857
  Wall      : 0.8486
  Window    : 0.6498

[20260108_141630] Epoch 50/50




Train Loss: 0.1431 | Val Loss: 0.4377
Val mIoU: 0.6812
Class-wise IoU:
  Bed       : 0.6898
  Books     : 0.4158
  Ceiling   : 0.4500
  Chair     : 0.7261
  Floor     : 0.9129
  Furniture : 0.7872
  Objects   : 0.6312
  Picture   : 0.6639
  Sofa      : 0.7688
  Table     : 0.7058
  TV        : 0.5322
  Wall      : 0.8432
  Window    : 0.7282
📊 Saving Analysis for Epoch 50...
   -> Saved 5 worst error samples and Confusion Matrix.
📦 Zipping logs...
✅ Created: /content/logs_20260108_135015.zip
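
The validation loop above averages per-batch IoU values with `np.nanmean`, so a batch containing only a few pixels of a class counts as much as a batch dominated by it. That may contribute to the large epoch-to-epoch swings of classes such as Ceiling and TV in the log. Below is a minimal sketch of an alternative, assuming a confusion matrix accumulated over the whole validation set, with rows as true labels and columns as predictions (the layout returned by `sklearn.metrics.confusion_matrix`). The helper name `iou_from_confusion` and the toy matrix are illustrative and not part of the original notebook.

```python
import numpy as np

def iou_from_confusion(cm):
    """Per-class IoU from a (C, C) confusion matrix (rows = true, cols = predicted)."""
    cm = np.asarray(cm, dtype=np.float64)
    tp = np.diag(cm)                 # correctly predicted pixels per class
    fp = cm.sum(axis=0) - tp         # predicted as class c but labelled otherwise
    fn = cm.sum(axis=1) - tp         # labelled class c but predicted otherwise
    denom = tp + fp + fn
    iou = np.full(cm.shape[0], np.nan)   # classes that never occur stay NaN
    np.divide(tp, denom, out=iou, where=denom > 0)
    return iou

# Toy 3-class matrix (hypothetical numbers); class 2 never appears.
cm = np.array([[8, 2, 0],
               [1, 4, 0],
               [0, 0, 0]])
ious = iou_from_confusion(cm)
miou = np.nanmean(ious)              # absent classes are excluded from the mean
print(ious, miou)
```

Computed this way, each pixel contributes once per epoch, at the cost of accumulating the confusion matrix on every epoch rather than only on analysis epochs.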

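
The oversampling step in `__init__` repeats every training ID whose label contains class 1 (Books) `OVERSAMPLE_FACTOR` times and keeps all other IDs once. With the logged 716 source images, a factor of 5, and 1228 expanded entries, this implies 716 + 4n = 1228, that is n = 128 images containing Books. The sketch below reproduces that arithmetic; it assumes a hypothetical `has_book` lookup in place of reading label PNGs, and the helper names are illustrative only.

```python
def expand_ids(ids, has_book, factor):
    """Repeat IDs whose label contains Books `factor` times; keep the others once."""
    expanded = []
    for id_ in ids:
        expanded.extend([id_] * (factor if has_book[id_] else 1))
    return expanded

def infer_book_images(n_source, n_expanded, factor):
    """Solve n_source + (factor - 1) * n = n_expanded for n."""
    n, rem = divmod(n_expanded - n_source, factor - 1)
    if rem != 0:
        raise ValueError("counts are inconsistent with the given factor")
    return n

# Numbers taken from the training log: 716 images expanded to 1228 at 5x.
n_book = infer_book_images(716, 1228, 5)

# Synthetic check: mark the first n_book of 716 hypothetical IDs as containing Books.
ids = [f"{i:06d}" for i in range(716)]
has_book = {id_: i < n_book for i, id_ in enumerate(ids)}
expanded = expand_ids(ids, has_book, 5)
print(n_book, len(expanded))  # prints: 128 1228
```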

In [None]:
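"""
Purpose:
    Final assignment for the Deep Learning Basics course (NYUv2 Semantic Segmentation).
    Uses a hybrid training strategy (5x oversampling of Book images plus all original data)
    to balance Book-class recall against overall accuracy.
    Detailed metrics are printed to the console, and the numeric data is also written to files.

Processing:
    1. Model: mit_b4 + UPerNet (5-channel input)
    2. Data:
       - Train: mixed set of (Book images x 5) + (all original images x 1)
       - Val: raw data (no duplicates)
    3. Logging:
       - Console: mIoU, Book details (IoU/Prec/Rec), FP/FN confidence, IoU for all classes
       - File: JSONL (time-series metrics), NPY (confusion matrix), PNG (worst-case images)
       - Archive: zip-compressed when training finishes
"""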

import os
import json
import random
import time
import datetime
import pytz
import shutil
import numpy as np
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from tqdm import tqdm
from sklearn.metrics import confusion_matrix

# ==========================================
# 0. Optimization Flags
# ==========================================
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True

# ==========================================
# 1. Configuration
# ==========================================
def get_jst_time_str():
    return datetime.datetime.now(pytz.timezone('Asia/Tokyo')).strftime('%Y%m%d_%H%M%S')

class Config:
    DATA_ROOT = '/content/data'

    # Model
    ENCODER = 'mit_b4'
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = 13
    INPUT_CHANNELS = 5

    # Training
    CROP_SIZE = 320
    IMG_SIZE = 512
    BATCH_SIZE = 16
    EPOCHS = 50
    LR = 6e-5
    WEIGHT_DECAY = 0.01
    SEED = 42
    device = torch.device('cuda')

    # Strategy
    PASTE_PROB = 0.7
    OVERSAMPLE_FACTOR = 5
    MIX_ORIGINAL = True

    # Logging
    LOG_TIMESTAMP = get_jst_time_str()
    DEBUG_DIR = f'/content/debug_hybrid_{LOG_TIMESTAMP}'
    SAVE_INTERVAL = 10
    NUM_WORST_SAMPLES = 5

    GOLDEN_SAMPLES = [
        {'id': '000066', 'bbox': [407, 205, 89, 142]},
        {'id': '000072', 'bbox': [366, 412, 79, 63]},
        {'id': '000097', 'bbox': [271, 363, 51, 64]},
        {'id': '000105', 'bbox': [159, 216, 69, 20]},
        {'id': '000107', 'bbox': [244, 280, 142, 42]},
        {'id': '000109', 'bbox': [41, 285, 104, 62]},
        {'id': '000177', 'bbox': [77, 284, 69, 19]},
        {'id': '000353', 'bbox': [257, 218, 76, 31]},
    ]

    CLASS_NAMES = [
        "Bed", "Books", "Ceiling", "Chair", "Floor", "Furniture",
        "Objects", "Picture", "Sofa", "Table", "TV", "Wall", "Window"
    ]
    CLASS_WEIGHTS = [1.0, 10.0, 1.0, 1.2, 1.0, 1.2, 2.0, 1.2, 1.0, 1.2, 3.0, 0.8, 1.0]

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # also seed every GPU when more than one is visible

# ==========================================
# 2. Augmentor
# ==========================================
class BookDepthAwareAugmentor:
    def __init__(self, dataset, prob=0.5):
        self.book_bank = []
        self.prob = prob
        self.load_golden_samples(dataset)
        self.collect_books(dataset)

    def load_golden_samples(self, dataset):
        print("💎 Loading Golden Samples...")
        for item in Config.GOLDEN_SAMPLES:
            id_ = item['id']
            x, y, w, h = item['bbox']
            img_path = os.path.join(dataset.img_dir, f"{id_}.png")
            depth_path = os.path.join(dataset.depth_dir, f"{id_}.png")
            label_path = os.path.join(dataset.label_dir, f"{id_}.png")
            if not os.path.exists(img_path): continue
            image = cv2.imread(img_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
            if depth is None: depth = np.zeros(image.shape[:2], np.float32)
            if depth.max() > 0: depth = depth.astype(np.float32) / depth.max()
            edge = dataset.compute_edge_map(image)
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            if label is None: continue  # skip samples whose label file is missing or unreadable
            # Cut the hand-picked bounding box out of every channel.
            roi_rgb = image[y:y+h, x:x+w]
            roi_depth = depth[y:y+h, x:x+w]
            roi_edge = edge[y:y+h, x:x+w]
            roi_label = label[y:y+h, x:x+w]
            blob_mask = (roi_label == 1).astype(np.uint8)
            if np.sum(blob_mask) == 0: continue
            self.book_bank.append({'rgb': roi_rgb, 'depth': roi_depth, 'edge': roi_edge, 'mask': blob_mask, 'mean_depth': np.mean(roi_depth)})

    def collect_books(self, dataset):
        print("Collecting 'Auto' book blobs...")
        unique_scan_ids = sorted(list(set(dataset.ids)))
        for id_ in tqdm(unique_scan_ids, desc="Scanning"):
            label_path = os.path.join(dataset.label_dir, f"{id_}.png")
            label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            if label is None or not np.any(label == 1): continue
            book_mask = (label == 1).astype(np.uint8)
            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(book_mask, connectivity=8)
            if num_labels <= 1: continue
            img_path = os.path.join(dataset.img_dir, f"{id_}.png"); depth_path = os.path.join(dataset.depth_dir, f"{id_}.png")
            image = cv2.imread(img_path); image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
            if depth is None: depth = np.zeros_like(label, dtype=np.float32)
            if depth.max() > 0: depth = depth.astype(np.float32) / depth.max()
            edge = dataset.compute_edge_map(image)
            for j in range(1, num_labels):
                area = stats[j, cv2.CC_STAT_AREA]
                if area < 300: continue
                x = stats[j, cv2.CC_STAT_LEFT]; y = stats[j, cv2.CC_STAT_TOP]; w = stats[j, cv2.CC_STAT_WIDTH]; h = stats[j, cv2.CC_STAT_HEIGHT]
                if w/h > 10.0 or w/h < 0.1 or area/(w*h) < 0.5: continue
                blob_mask = (labels[y:y+h, x:x+w] == j).astype(np.uint8)
                self.book_bank.append({'rgb': image[y:y+h, x:x+w], 'depth': depth[y:y+h, x:x+w], 'edge': edge[y:y+h, x:x+w], 'mask': blob_mask, 'mean_depth': np.mean(depth[y:y+h, x:x+w])})

    def apply(self, image, label):
        if random.random() > self.prob or len(self.book_bank) == 0: return image, label
        H, W = label.shape
        mask_target = ((label == 9) | (label == 5) | (label == 4)).astype(np.uint8)
        ys, xs = np.where(mask_target > 0)
        if len(ys) == 0: return image, label
        idx = random.randint(0, len(ys) - 1); y_t, x_t = ys[idx], xs[idx]
        blob = random.choice(self.book_bank)
        target_d = image[y_t, x_t, 3] / 255.0; source_d = blob['mean_depth']
        scale = np.clip(source_d / target_d, 0.3, 1.8) if target_d > 0.01 and source_d > 0.01 else 1.0
        blob_h, blob_w = blob['mask'].shape; new_w, new_h = int(blob_w * scale), int(blob_h * scale)
        if new_w <= 0 or new_h <= 0: return image, label
        blob_rgb = cv2.resize(blob['rgb'], (new_w, new_h)); blob_depth = cv2.resize(blob['depth'], (new_w, new_h))
        blob_edge = cv2.resize(blob['edge'], (new_w, new_h)); blob_mask = cv2.resize(blob['mask'], (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        k = random.choice([0, 1, 2, 3])
        if k > 0: blob_rgb = np.rot90(blob_rgb, k); blob_depth = np.rot90(blob_depth, k); blob_edge = np.rot90(blob_edge, k); blob_mask = np.rot90(blob_mask, k)
        new_h, new_w = blob_rgb.shape[:2]
        if new_h >= H or new_w >= W: return image, label
        y_t = min(y_t, H - new_h); x_t = min(x_t, W - new_w)
        roi_depth = image[y_t:y_t+new_h, x_t:x_t+new_w, 3]
        if roi_depth.shape[:2] != (new_h, new_w): return image, label
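        # Shift the blob's depth so that its mean matches the depth at the top-left pixel of the paste region.
        new_d_map = (blob_depth * 255.0) + (image[y_t, x_t, 3] - (np.mean(blob_depth) * 255.0))
        # Paste only where the blob is not behind the existing scene (tolerance of 5 on the 0-255 depth scale),
        # so nearer scene content keeps occluding the pasted book.
        is_in_front = new_d_map < (roi_depth + 5.0)
        final_mask = (blob_mask == 1) & is_in_front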
        mask_3 = np.stack([final_mask]*3, axis=2)
        image[y_t:y_t+new_h, x_t:x_t+new_w, :3] = np.where(mask_3, blob_rgb, image[y_t:y_t+new_h, x_t:x_t+new_w, :3])
        image[y_t:y_t+new_h, x_t:x_t+new_w, 3] = np.where(final_mask, np.clip(new_d_map,0,255), image[y_t:y_t+new_h, x_t:x_t+new_w, 3])
        image[y_t:y_t+new_h, x_t:x_t+new_w, 4] = np.where(final_mask, blob_edge*255.0, image[y_t:y_t+new_h, x_t:x_t+new_w, 4])
        label[y_t:y_t+new_h, x_t:x_t+new_w] = np.where(final_mask, 1, label[y_t:y_t+new_h, x_t:x_t+new_w])
        return image, label

# ==========================================
# 3. Dataset
# ==========================================
def get_transforms(phase='train'):
    # RGB uses ImageNet statistics; the depth and edge channels are centered at 0.5.
    mean = [0.485, 0.456, 0.406, 0.5, 0.5]
    std = [0.229, 0.224, 0.225, 0.5, 0.5]
    to_tensor = [
        A.Resize(Config.IMG_SIZE, Config.IMG_SIZE),
        A.Normalize(mean=mean, std=std, max_pixel_value=255.0),
        ToTensorV2(),
    ]
    if phase == 'train':
        return {
            # Color jitter is applied to the RGB channels only.
            'color': A.Compose([
                A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
                A.RandomBrightnessContrast(p=0.5),
            ]),
            # Geometric transforms are applied jointly to all 5 channels and the mask.
            'geo': A.Compose([
                A.GridDistortion(p=0.3),
                A.Perspective(p=0.3),
                A.ShiftScaleRotate(p=0.5),
                A.HorizontalFlip(p=0.5),
                *to_tensor,
            ]),
        }
    else:
        return A.Compose(to_tensor)

class NYUv2BookZoomDataset(Dataset):
    def __init__(self, root_dir, whitelist_ids=None, is_train=False, img_size=512, transform=None, augmentor=None):
        self.root_dir = root_dir; self.img_size = img_size; self.transform = transform; self.augmentor = augmentor
        self.img_dir = os.path.join(root_dir, 'train', 'image'); self.depth_dir = os.path.join(root_dir, 'train', 'depth'); self.label_dir = os.path.join(root_dir, 'train', 'label')

        base_ids = whitelist_ids if whitelist_ids is not None else sorted([os.path.splitext(f)[0] for f in os.listdir(self.img_dir) if f.endswith('.png')])

        if is_train:
            print(f"⚖️ Dataset Strategy: {Config.OVERSAMPLE_FACTOR}x Oversample + Original Anchor")
            boosted_ids = []
            for id_ in base_ids:
                label = cv2.imread(os.path.join(self.label_dir, f"{id_}.png"), cv2.IMREAD_GRAYSCALE)
                if label is not None and np.any(label == 1):
                    boosted_ids.extend([id_] * Config.OVERSAMPLE_FACTOR)
                else:
                    boosted_ids.extend([id_])

            original_ids = list(base_ids) if Config.MIX_ORIGINAL else []
            self.ids = boosted_ids + original_ids
            print(f"   - Boosted Set: {len(boosted_ids)}")
            print(f"   - Original Set: {len(original_ids)}")
            print(f"   -> Final Train Size: {len(self.ids)}")
        else:
            self.ids = base_ids

    def __len__(self): return len(self.ids)
    def compute_edge_map(self, img_array):
        # Sobel gradient magnitude of the grayscale image, min-max normalized to [0, 1].
        gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        return cv2.normalize(np.sqrt(sobelx**2 + sobely**2), None, 0, 1, cv2.NORM_MINMAX).astype(np.float32)
    def get_book_centric_crop(self, image, label):
        # Crop CROP_SIZE x CROP_SIZE around a random Books pixel (with up to 50 px of jitter),
        # or at a random location when the image has no Books pixels.
        ys, xs = np.where(label == 1)
        cs = Config.CROP_SIZE
        h, w = image.shape[:2]
        if len(ys) > 0:
            idx = random.randint(0, len(ys) - 1)
            cy, cx = ys[idx], xs[idx]
            y1 = max(0, min(h-cs, cy - cs//2 + random.randint(-50, 50)))
            x1 = max(0, min(w-cs, cx - cs//2 + random.randint(-50, 50)))
        else:
            y1 = random.randint(0, h - cs)
            x1 = random.randint(0, w - cs)
        return image[y1:y1+cs, x1:x1+cs], label[y1:y1+cs, x1:x1+cs]
    def __getitem__(self, idx):
        id_ = self.ids[idx]
        img_path = os.path.join(self.img_dir, f"{id_}.png")
        depth_path = os.path.join(self.depth_dir, f"{id_}.png")
        label_path = os.path.join(self.label_dir, f"{id_}.png")
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        depth = cv2.imread(depth_path, cv2.IMREAD_ANYDEPTH)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
        if depth is None: depth = np.zeros(image.shape[:2], np.float32)
        if depth.max() > 0: depth = depth.astype(np.float32) / depth.max()
        # Any label outside 0..CLASSES-1 becomes the ignore index 255.
        label[~((label >= 0) & (label < Config.CLASSES))] = 255
        edge = self.compute_edge_map(image)
        # Stack RGB, depth and edge into one 5-channel float image on the 0-255 scale.
        combined = np.dstack([image, depth*255.0, edge*255.0]).astype(np.float32)
        if self.augmentor: combined, label = self.augmentor.apply(combined, label)
        # The book-centric crop runs for every sample, validation included.
        combined_crop, label_crop = self.get_book_centric_crop(combined, label)
        if self.transform:
            if isinstance(self.transform, dict):
                rgb_aug = self.transform['color'](image=combined_crop[:,:,:3].astype(np.uint8))['image']
                t_geo = self.transform['geo'](image=np.dstack([rgb_aug, combined_crop[:,:,3:]]), mask=label_crop)
                return t_geo['image'], t_geo['mask'].long()
            t = self.transform(image=combined_crop, mask=label_crop)
            return t['image'], t['mask'].long()
        return torch.from_numpy(combined_crop.transpose(2,0,1)).float(), torch.from_numpy(label_crop).long()

# ==========================================
# 4. Analysis Utils
# ==========================================
def save_worst_samples(worst_samples, epoch, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    mean = np.array([0.485, 0.456, 0.406]); std = np.array([0.229, 0.224, 0.225])
    for i, (score, img_t, mask_t, out_t) in enumerate(worst_samples):
        img_vis = np.clip((img_t[:3, :, :].cpu().numpy().transpose(1, 2, 0) * std + mean) * 255.0, 0, 255).astype(np.uint8)
        probs = F.softmax(out_t.unsqueeze(0), dim=1); pred_t = torch.argmax(probs, dim=1).squeeze(0)
        book_prob = probs[0, 1, :, :].cpu().numpy()
        gt_book = (mask_t.cpu().numpy() == 1); pred_book = (pred_t.cpu().numpy() == 1)

        fig, axes = plt.subplots(1, 4, figsize=(20, 5))
        axes[0].imshow(img_vis); axes[0].set_title(f"Rank{i+1}: Err={int(score)}")
        axes[1].imshow(gt_book, cmap='gray'); axes[1].set_title("GT Book")
        im = axes[2].imshow(book_prob, cmap='jet', vmin=0, vmax=1.0); axes[2].set_title("Book Prob")
        plt.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)
        error_vis = (img_vis * 0.7 + 255 * 0.3).astype(np.uint8)
        error_vis[pred_book & (~gt_book)] = [255, 0, 0]; error_vis[(~pred_book) & gt_book] = [0, 0, 255]
        axes[3].imshow(error_vis); axes[3].set_title("Red=FP, Blue=FN")
        for ax in axes: ax.axis('off')
        plt.tight_layout(); plt.savefig(os.path.join(save_dir, f"epoch_{epoch:03d}_worst_{i+1:02d}.png")); plt.close()

def save_confusion_matrix(cm, epoch, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    cm_norm = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-6)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm_norm, annot=False, fmt='.2f', cmap='Blues', xticklabels=Config.CLASS_NAMES, yticklabels=Config.CLASS_NAMES)
    plt.ylabel('True'); plt.xlabel('Predicted'); plt.title(f'Confusion Matrix (Norm) - Epoch {epoch}')
    plt.tight_layout(); plt.savefig(os.path.join(save_dir, f"epoch_{epoch:03d}_conf_mat.png")); plt.close()
    np.save(os.path.join(save_dir, f"epoch_{epoch:03d}_conf_mat.npy"), cm)

# ==========================================
# 5. Main Execution
# ==========================================
def main():
    set_seed(Config.SEED)
    os.makedirs(Config.DEBUG_DIR, exist_ok=True)
    print(f"[{get_jst_time_str()}] 🚀 Start Training: Hybrid Anchor Strategy")

    jsonl_path = os.path.join(Config.DEBUG_DIR, 'training_logs.jsonl')
    with open(jsonl_path, 'w') as f: pass

    model = smp.UPerNet(encoder_name=Config.ENCODER, encoder_weights=Config.ENCODER_WEIGHTS, in_channels=3, classes=Config.CLASSES, activation=None)
    # Swap the 3-channel patch embedding for a 5-channel one. The RGB filters keep their
    # pretrained weights; the depth and edge filters start from the mean of the RGB filters.
    if hasattr(model.encoder, 'patch_embed1') and hasattr(model.encoder.patch_embed1, 'proj'):
        old = model.encoder.patch_embed1.proj
        new_l = nn.Conv2d(5, old.out_channels, old.kernel_size, old.stride, old.padding, bias=(old.bias is not None))
        with torch.no_grad():
            new_l.weight[:, :3, :, :] = old.weight
            new_l.weight[:, 3:, :, :] = torch.mean(old.weight, dim=1, keepdim=True).repeat(1, 2, 1, 1)
            if old.bias is not None: new_l.bias = old.bias
        model.encoder.patch_embed1.proj = new_l
    model = model.to(Config.device)

    optimizer = optim.AdamW(model.parameters(), lr=Config.LR, weight_decay=Config.WEIGHT_DECAY)
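    # LovaszWeightedLoss is not defined in this cell; it is assumed to come from an earlier notebook cell.
    criterion = LovaszWeightedLoss(Config.CLASS_WEIGHTS, Config.device, ignore_index=255)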
    scaler = torch.amp.GradScaler('cuda')
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.EPOCHS, eta_min=1e-7)

    img_dir_train = os.path.join(Config.DATA_ROOT, 'train', 'image')
    all_ids = sorted([os.path.splitext(f)[0] for f in os.listdir(img_dir_train) if f.endswith('.png')])
    random.shuffle(all_ids)
    n_val = int(len(all_ids) * 0.1)
    val_ids = all_ids[:n_val]; train_ids = all_ids[n_val:]

    temp_ds = NYUv2BookZoomDataset(Config.DATA_ROOT, whitelist_ids=train_ids, is_train=False)
    train_augmentor = BookDepthAwareAugmentor(temp_ds, prob=Config.PASTE_PROB)
    train_ds = NYUv2BookZoomDataset(Config.DATA_ROOT, whitelist_ids=train_ids, is_train=True, transform=get_transforms('train'), augmentor=train_augmentor)
    val_ds = NYUv2BookZoomDataset(Config.DATA_ROOT, whitelist_ids=val_ids, is_train=False, transform=get_transforms('val'))
    train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=Config.BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True)

    best_miou = 0.0

    for epoch in range(Config.EPOCHS):
        print(f"\n[{get_jst_time_str()}] Epoch {epoch+1}/{Config.EPOCHS}")

        # Train
        model.train()
        train_loss = 0
        for images, masks in tqdm(train_loader, desc="Training", leave=False):
            images = images.to(Config.device, non_blocking=True)
            masks = masks.to(Config.device, non_blocking=True)
            optimizer.zero_grad()
            # Forward pass and loss under mixed precision.
            with torch.amp.autocast('cuda'):
                outputs = model(images)
                loss = criterion(outputs, masks)
            # Scale the loss before backward to avoid fp16 gradient underflow.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            train_loss += loss.item()

        # Val
        model.eval(); val_loss = 0
        iou_scores, prec_scores, rec_scores = [], [], []
        fp_conf_accum, fn_conf_accum = [], []
        worst_samples = []
        conf_mat = np.zeros((Config.CLASSES, Config.CLASSES), dtype=np.int64)

        do_analysis = ((epoch + 1) % Config.SAVE_INTERVAL == 0) or ((epoch + 1) == Config.EPOCHS)

        with torch.no_grad():
            with torch.amp.autocast('cuda'):
                for batch_idx, (images, masks) in enumerate(tqdm(val_loader, desc="Validation", leave=False)):
                    images = images.to(Config.device, non_blocking=True); masks = masks.to(Config.device, non_blocking=True)
                    outputs = model(images); loss = criterion(outputs, masks); val_loss += loss.item()
                    preds = outputs.argmax(dim=1)

                    # Compute per-class IoU, precision and recall for this batch, ignoring label 255.
                    # Ratios with a zero denominator become NaN and are skipped by the later nanmean.
                    def compute_batch_metrics(p, m):
                        p = p.view(-1)
                        m = m.view(-1)
                        valid = m != 255
                        p = p[valid]
                        m = m[valid]
                        ious, precs, recs = [], [], []
                        for c in range(Config.CLASSES):
                            tp = ((p==c)&(m==c)).sum().item()
                            fp = ((p==c)&(m!=c)).sum().item()
                            fn = ((p!=c)&(m==c)).sum().item()
                            ious.append(tp/(tp+fp+fn) if (tp+fp+fn)>0 else np.nan)
                            precs.append(tp/(tp+fp) if (tp+fp)>0 else np.nan)
                            recs.append(tp/(tp+fn) if (tp+fn)>0 else np.nan)
                        return np.array(ious), np.array(precs), np.array(recs)

                    bi, bp, br = compute_batch_metrics(preds, masks)
                    iou_scores.append(bi); prec_scores.append(bp); rec_scores.append(br)

                    # Analysis
                    probs = F.softmax(outputs, dim=1); book_probs = probs[:, 1, :, :]
                    gt_book = (masks == 1); pred_book = (preds == 1)
                    fp_mask = pred_book & (~gt_book); fn_mask = (~pred_book) & gt_book
                    if fp_mask.any(): fp_conf_accum.append(book_probs[fp_mask].mean().item())
                    if fn_mask.any(): fn_conf_accum.append(book_probs[fn_mask].mean().item())

                    if do_analysis:
                        m_np = masks.cpu().numpy().flatten(); p_np = preds.cpu().numpy().flatten(); valid = m_np != 255
                        conf_mat += confusion_matrix(m_np[valid], p_np[valid], labels=range(Config.CLASSES))
                        for i in range(len(images)):
                            error_score = (pred_book[i] & (~gt_book[i])).sum().item() + ((~pred_book[i]) & gt_book[i]).sum().item()
                            if error_score > 0: worst_samples.append((error_score, images[i].cpu(), masks[i].cpu(), outputs[i].cpu()))

        val_miou = np.nanmean(np.nanmean(np.array(iou_scores), axis=0))
        cls_ious = np.nanmean(np.array(iou_scores), axis=0)
        mean_precs = np.nanmean(np.array(prec_scores), axis=0)
        mean_recs = np.nanmean(np.array(rec_scores), axis=0)
        mean_fp_conf = np.mean(fp_conf_accum) if fp_conf_accum else 0.0
        mean_fn_conf = np.mean(fn_conf_accum) if fn_conf_accum else 0.0

        # --- Console Output (Restore) ---
        print(f"Train Loss: {train_loss/len(train_loader):.4f} | Val Loss: {val_loss/len(val_loader):.4f}")
        print(f"Val mIoU: {val_miou:.4f}")
        print(f"Books [1] -> IoU: {cls_ious[1]:.4f} | Prec: {mean_precs[1]:.4f} | Rec: {mean_recs[1]:.4f}")
        print(f"Analysis  -> FP Conf: {mean_fp_conf:.4f} | FN Conf: {mean_fn_conf:.4f}")

        print("Class-wise IoU:")
        for i, name in enumerate(Config.CLASS_NAMES):
            print(f"  {name:10s}: {cls_ious[i]:.4f}")

        # --- File Output (JSONL) ---
        log_entry = {
            "epoch": epoch + 1, "timestamp": get_jst_time_str(),
            "train_loss": train_loss/len(train_loader), "val_loss": val_loss/len(val_loader),
            "val_miou": val_miou, "class_ious": cls_ious.tolist(),
            "book_prec": mean_precs[1], "book_rec": mean_recs[1],
            "fp_conf": mean_fp_conf, "fn_conf": mean_fn_conf
        }
        with open(jsonl_path, 'a') as f: f.write(json.dumps(log_entry) + "\n")

        if val_miou > best_miou:
            best_miou = val_miou
            torch.save(model.state_dict(), 'best_model_nyuv2_hybrid.pth')
            print(">>> Best Model Saved!")

        if do_analysis:
            save_confusion_matrix(conf_mat, epoch+1, Config.DEBUG_DIR)
            worst_samples.sort(key=lambda x: x[0], reverse=True)
            save_worst_samples(worst_samples[:Config.NUM_WORST_SAMPLES], epoch+1, Config.DEBUG_DIR)

        scheduler.step()

    print("📦 Zipping all logs...")
    shutil.make_archive(Config.DEBUG_DIR, 'zip', Config.DEBUG_DIR)
    print(f"✅ Download: {Config.DEBUG_DIR}.zip")

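if __name__ == '__main__':
    if os.path.exists(Config.DATA_ROOT):
        torch.cuda.empty_cache()
        main()
    else:
        # Report the missing dataset instead of exiting silently.
        print(f"Data root not found: {Config.DATA_ROOT}")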

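The validation loop in the cell above computes IoU per batch and then averages over batches with `np.nanmean`, so every batch counts equally regardless of how many pixels of a class it holds. A common alternative is to accumulate one confusion matrix over the whole validation set and derive IoU from it once. The sketch below illustrates that idea in plain Python; `iou_from_confusion`, `mean_iou` and the toy matrix are hypothetical names and values, not part of the notebook.

```python
def iou_from_confusion(cm):
    """Per-class IoU from a square confusion matrix (rows = true, cols = predicted).

    Returns None for a class that appears in neither the labels nor the predictions.
    """
    n = len(cm)
    ious = []
    for c in range(n):
        tp = cm[c][c]
        fn = sum(cm[c]) - tp                       # true class c, predicted as another class
        fp = sum(cm[r][c] for r in range(n)) - tp  # predicted c, true class is another class
        denom = tp + fp + fn
        ious.append(tp / denom if denom > 0 else None)
    return ious


def mean_iou(ious):
    """Mean over the classes that have a defined IoU."""
    valid = [v for v in ious if v is not None]
    return sum(valid) / len(valid) if valid else float("nan")


# Toy 3-class matrix; the third class never occurs, so it is left out of the mean.
toy_cm = [
    [8, 2, 0],
    [1, 9, 0],
    [0, 0, 0],
]
toy_ious = iou_from_confusion(toy_cm)
print(toy_ious, mean_iou(toy_ious))
```

In the notebook, the `conf_mat` array that is already accumulated when `do_analysis` is true could be passed in as `conf_mat.tolist()`; this is a suggestion, not something the original cell does.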
[20260108_143648] 🚀 Start Training: Hybrid Anchor Strategy
💎 Loading Golden Samples...
Collecting 'Auto' book blobs...


Scanning: 100%|██████████| 716/716 [00:04<00:00, 164.00it/s]


⚖️ Dataset Strategy: 5x Oversample + Original Anchor
   - Boosted Set: 1228
   - Original Set: 716
   -> Final Train Size: 1944
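
The sizes logged above pin down how many training images contain Books pixels. Each of the 716 training IDs enters the boosted set once, or `OVERSAMPLE_FACTOR` = 5 times if its label has any Books pixel, and `MIX_ORIGINAL` then appends every ID once more. A minimal sketch of that arithmetic follows; the helpers `hybrid_sizes` and `book_images_from_boosted` are hypothetical and only mirror the list construction in the dataset class.

```python
def hybrid_sizes(n_total, n_book, factor=5, mix_original=True):
    """Sizes of the boosted, original and final ID lists built by the dataset."""
    boosted = n_book * factor + (n_total - n_book)
    original = n_total if mix_original else 0
    return boosted, original, boosted + original


def book_images_from_boosted(n_total, boosted, factor=5):
    """Invert boosted = n_total + (factor - 1) * n_book to recover n_book."""
    return (boosted - n_total) // (factor - 1)


n_book = book_images_from_boosted(716, 1228)
print(n_book)                      # 128
print(hybrid_sizes(716, n_book))   # (1228, 716, 1944)
```

With these numbers a book image is drawn 6 times per epoch and any other image twice, so the effective oversampling ratio is 3:1 rather than 5:1.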

[20260108_143656] Epoch 1/50




Train Loss: 0.7541 | Val Loss: 0.6307
Val mIoU: 0.5015
Books [1] -> IoU: 0.1558 | Prec: 0.1593 | Rec: 0.8698
Analysis  -> FP Conf: 0.6425 | FN Conf: 0.2075
Class-wise IoU:
  Bed       : 0.7149
  Books     : 0.1558
  Ceiling   : 0.3043
  Chair     : 0.4105
  Floor     : 0.8380
  Furniture : 0.5351
  Objects   : 0.4515
  Picture   : 0.3880
  Sofa      : 0.5486
  Table     : 0.4888
  TV        : 0.4043
  Wall      : 0.6808
  Window    : 0.5994
>>> Best Model Saved!
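
The Books line reports IoU, precision and recall side by side. When all three are computed from the same pixel counts they satisfy $1/\mathrm{IoU} = 1/P + 1/R - 1$, which follows directly from $IoU = TP/(TP+FP+FN)$. The sketch below (hypothetical helper `iou_from_pr`) checks the identity on toy counts and then applies it to the epoch 1 values. Because the notebook averages each metric over batches separately, the logged IoU does not have to match the implied value exactly.

```python
def iou_from_pr(precision, recall):
    """IoU implied by precision and recall measured on the same pixels."""
    if precision <= 0 or recall <= 0:
        return 0.0
    return 1.0 / (1.0 / precision + 1.0 / recall - 1.0)


# Toy counts: TP=2, FP=1, FN=1 gives P = R = 2/3 and IoU = 2/4.
tp, fp, fn = 2, 1, 1
toy = iou_from_pr(tp / (tp + fp), tp / (tp + fn))
print(toy)

# Epoch 1 Books values from the log: Prec 0.1593, Rec 0.8698, logged IoU 0.1558.
implied = iou_from_pr(0.1593, 0.8698)
print(round(implied, 4))
```

The implied value lands within a few ten-thousandths of the logged 0.1558, which is consistent with the per-batch averaging described above.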

[20260108_143744] Epoch 2/50




Train Loss: 0.5240 | Val Loss: 0.5769
Val mIoU: 0.5319
Books [1] -> IoU: 0.2502 | Prec: 0.2578 | Rec: 0.8099
Analysis  -> FP Conf: 0.6984 | FN Conf: 0.2119
Class-wise IoU:
  Bed       : 0.6714
  Books     : 0.2502
  Ceiling   : 0.3807
  Chair     : 0.5539
  Floor     : 0.8700
  Furniture : 0.5390
  Objects   : 0.4928
  Picture   : 0.4162
  Sofa      : 0.5594
  Table     : 0.4719
  TV        : 0.3965
  Wall      : 0.7407
  Window    : 0.5716
>>> Best Model Saved!

[20260108_143831] Epoch 3/50




Train Loss: 0.4579 | Val Loss: 0.5505
Val mIoU: 0.5396
Books [1] -> IoU: 0.3279 | Prec: 0.3477 | Rec: 0.7749
Analysis  -> FP Conf: 0.6881 | FN Conf: 0.2357
Class-wise IoU:
  Bed       : 0.4368
  Books     : 0.3279
  Ceiling   : 0.2161
  Chair     : 0.6103
  Floor     : 0.8681
  Furniture : 0.6300
  Objects   : 0.5046
  Picture   : 0.4976
  Sofa      : 0.5202
  Table     : 0.5895
  TV        : 0.5246
  Wall      : 0.7364
  Window    : 0.5521
>>> Best Model Saved!

[20260108_143918] Epoch 4/50




Train Loss: 0.4161 | Val Loss: 0.5454
Val mIoU: 0.5705
Books [1] -> IoU: 0.2813 | Prec: 0.2944 | Rec: 0.7264
Analysis  -> FP Conf: 0.6646 | FN Conf: 0.2338
Class-wise IoU:
  Bed       : 0.6696
  Books     : 0.2813
  Ceiling   : 0.3569
  Chair     : 0.6569
  Floor     : 0.8695
  Furniture : 0.6054
  Objects   : 0.4906
  Picture   : 0.5156
  Sofa      : 0.6832
  Table     : 0.4489
  TV        : 0.6302
  Wall      : 0.7276
  Window    : 0.4808
>>> Best Model Saved!

[20260108_144005] Epoch 5/50




Train Loss: 0.3934 | Val Loss: 0.5064
Val mIoU: 0.5681
Books [1] -> IoU: 0.2295 | Prec: 0.2394 | Rec: 0.5308
Analysis  -> FP Conf: 0.6871 | FN Conf: 0.1933
Class-wise IoU:
  Bed       : 0.6941
  Books     : 0.2295
  Ceiling   : 0.3258
  Chair     : 0.5690
  Floor     : 0.8861
  Furniture : 0.6407
  Objects   : 0.5557
  Picture   : 0.5420
  Sofa      : 0.5620
  Table     : 0.5367
  TV        : 0.5175
  Wall      : 0.7717
  Window    : 0.5552

[20260108_144052] Epoch 6/50




Train Loss: 0.3541 | Val Loss: 0.5237
Val mIoU: 0.5979
Books [1] -> IoU: 0.2847 | Prec: 0.2943 | Rec: 0.8415
Analysis  -> FP Conf: 0.7174 | FN Conf: 0.2398
Class-wise IoU:
  Bed       : 0.6867
  Books     : 0.2847
  Ceiling   : 0.3497
  Chair     : 0.5132
  Floor     : 0.8878
  Furniture : 0.6138
  Objects   : 0.5109
  Picture   : 0.5169
  Sofa      : 0.5577
  Table     : 0.5891
  TV        : 0.8250
  Wall      : 0.7645
  Window    : 0.6724
>>> Best Model Saved!

[20260108_144140] Epoch 7/50




Train Loss: 0.3405 | Val Loss: 0.5362
Val mIoU: 0.5944
Books [1] -> IoU: 0.3303 | Prec: 0.3425 | Rec: 0.8483
Analysis  -> FP Conf: 0.7230 | FN Conf: 0.1942
Class-wise IoU:
  Bed       : 0.6561
  Books     : 0.3303
  Ceiling   : 0.3148
  Chair     : 0.5528
  Floor     : 0.9070
  Furniture : 0.6608
  Objects   : 0.5364
  Picture   : 0.4639
  Sofa      : 0.7739
  Table     : 0.5398
  TV        : 0.6793
  Wall      : 0.7804
  Window    : 0.5316

[20260108_144226] Epoch 8/50




Train Loss: 0.3068 | Val Loss: 0.4980
Val mIoU: 0.6002
Books [1] -> IoU: 0.3649 | Prec: 0.3995 | Rec: 0.7867
Analysis  -> FP Conf: 0.6932 | FN Conf: 0.2204
Class-wise IoU:
  Bed       : 0.7390
  Books     : 0.3649
  Ceiling   : 0.6440
  Chair     : 0.5068
  Floor     : 0.9131
  Furniture : 0.6694
  Objects   : 0.5549
  Picture   : 0.5967
  Sofa      : 0.6337
  Table     : 0.4079
  TV        : 0.4274
  Wall      : 0.8016
  Window    : 0.5436
>>> Best Model Saved!

[20260108_144314] Epoch 9/50




Train Loss: 0.2916 | Val Loss: 0.5020
Val mIoU: 0.6110
Books [1] -> IoU: 0.3856 | Prec: 0.4066 | Rec: 0.6948
Analysis  -> FP Conf: 0.7391 | FN Conf: 0.1840
Class-wise IoU:
  Bed       : 0.7370
  Books     : 0.3856
  Ceiling   : 0.3794
  Chair     : 0.4918
  Floor     : 0.8957
  Furniture : 0.6535
  Objects   : 0.5588
  Picture   : 0.5236
  Sofa      : 0.7689
  Table     : 0.5825
  TV        : 0.6504
  Wall      : 0.7854
  Window    : 0.5300
>>> Best Model Saved!

[20260108_144401] Epoch 10/50




Train Loss: 0.2897 | Val Loss: 0.4659
Val mIoU: 0.6252
Books [1] -> IoU: 0.3625 | Prec: 0.3803 | Rec: 0.6972
Analysis  -> FP Conf: 0.7330 | FN Conf: 0.1880
Class-wise IoU:
  Bed       : 0.7295
  Books     : 0.3625
  Ceiling   : 0.4411
  Chair     : 0.6593
  Floor     : 0.9074
  Furniture : 0.7208
  Objects   : 0.5854
  Picture   : 0.5938
  Sofa      : 0.6803
  Table     : 0.6408
  TV        : 0.4055
  Wall      : 0.8056
  Window    : 0.5956
>>> Best Model Saved!

[20260108_144454] Epoch 11/50




Train Loss: 0.2739 | Val Loss: 0.4705
Val mIoU: 0.6323
Books [1] -> IoU: 0.4058 | Prec: 0.4295 | Rec: 0.6910
Analysis  -> FP Conf: 0.7422 | FN Conf: 0.1551
Class-wise IoU:
  Bed       : 0.6902
  Books     : 0.4058
  Ceiling   : 0.4294
  Chair     : 0.5979
  Floor     : 0.9256
  Furniture : 0.6723
  Objects   : 0.5734
  Picture   : 0.5366
  Sofa      : 0.5610
  Table     : 0.6292
  TV        : 0.6684
  Wall      : 0.7920
  Window    : 0.7382
>>> Best Model Saved!

[20260108_144540] Epoch 12/50




Train Loss: 0.2635 | Val Loss: 0.4760
Val mIoU: 0.6136
Books [1] -> IoU: 0.3781 | Prec: 0.4023 | Rec: 0.7291
Analysis  -> FP Conf: 0.7950 | FN Conf: 0.1652
Class-wise IoU:
  Bed       : 0.7243
  Books     : 0.3781
  Ceiling   : 0.4840
  Chair     : 0.5891
  Floor     : 0.9024
  Furniture : 0.6920
  Objects   : 0.5690
  Picture   : 0.4928
  Sofa      : 0.6338
  Table     : 0.6513
  TV        : 0.5469
  Wall      : 0.7868
  Window    : 0.5265

[20260108_144627] Epoch 13/50




Train Loss: 0.2476 | Val Loss: 0.4530
Val mIoU: 0.6157
Books [1] -> IoU: 0.3720 | Prec: 0.4148 | Rec: 0.6186
Analysis  -> FP Conf: 0.7545 | FN Conf: 0.1019
Class-wise IoU:
  Bed       : 0.7137
  Books     : 0.3720
  Ceiling   : 0.4491
  Chair     : 0.6751
  Floor     : 0.9171
  Furniture : 0.6978
  Objects   : 0.6238
  Picture   : 0.4671
  Sofa      : 0.5584
  Table     : 0.5216
  TV        : 0.5060
  Wall      : 0.8073
  Window    : 0.6948

[20260108_144713] Epoch 14/50




Train Loss: 0.2374 | Val Loss: 0.4535
Val mIoU: 0.6502
Books [1] -> IoU: 0.4214 | Prec: 0.4532 | Rec: 0.7860
Analysis  -> FP Conf: 0.7493 | FN Conf: 0.1711
Class-wise IoU:
  Bed       : 0.7982
  Books     : 0.4214
  Ceiling   : 0.6428
  Chair     : 0.6104
  Floor     : 0.8858
  Furniture : 0.7105
  Objects   : 0.5892
  Picture   : 0.5472
  Sofa      : 0.6802
  Table     : 0.6230
  TV        : 0.4840
  Wall      : 0.7898
  Window    : 0.6702
>>> Best Model Saved!

[20260108_144801] Epoch 15/50




Train Loss: 0.2228 | Val Loss: 0.5022
Val mIoU: 0.6242
Books [1] -> IoU: 0.3930 | Prec: 0.4171 | Rec: 0.8962
Analysis  -> FP Conf: 0.7926 | FN Conf: 0.1648
Class-wise IoU:
  Bed       : 0.6531
  Books     : 0.3930
  Ceiling   : 0.4263
  Chair     : 0.6255
  Floor     : 0.9130
  Furniture : 0.6989
  Objects   : 0.6007
  Picture   : 0.5558
  Sofa      : 0.5333
  Table     : 0.6353
  TV        : 0.5427
  Wall      : 0.8069
  Window    : 0.7295

[20260108_144847] Epoch 16/50




Train Loss: 0.2280 | Val Loss: 0.5039
Val mIoU: 0.5978
Books [1] -> IoU: 0.2881 | Prec: 0.3006 | Rec: 0.7218
Analysis  -> FP Conf: 0.7868 | FN Conf: 0.2184
Class-wise IoU:
  Bed       : 0.6345
  Books     : 0.2881
  Ceiling   : 0.1712
  Chair     : 0.6615
  Floor     : 0.9288
  Furniture : 0.6759
  Objects   : 0.5997
  Picture   : 0.4933
  Sofa      : 0.7374
  Table     : 0.6425
  TV        : 0.5358
  Wall      : 0.7978
  Window    : 0.6048

[20260108_144933] Epoch 17/50




Train Loss: 0.2204 | Val Loss: 0.4498
Val mIoU: 0.6483
Books [1] -> IoU: 0.4294 | Prec: 0.4786 | Rec: 0.7970
Analysis  -> FP Conf: 0.8115 | FN Conf: 0.1253
Class-wise IoU:
  Bed       : 0.7239
  Books     : 0.4294
  Ceiling   : 0.4644
  Chair     : 0.5619
  Floor     : 0.8987
  Furniture : 0.6932
  Objects   : 0.6253
  Picture   : 0.5423
  Sofa      : 0.7405
  Table     : 0.7134
  TV        : 0.5052
  Wall      : 0.8201
  Window    : 0.7091

[20260108_145023] Epoch 18/50




Train Loss: 0.2087 | Val Loss: 0.4957
Val mIoU: 0.6137
Books [1] -> IoU: 0.3689 | Prec: 0.4064 | Rec: 0.7575
Analysis  -> FP Conf: 0.8019 | FN Conf: 0.1373
Class-wise IoU:
  Bed       : 0.7202
  Books     : 0.3689
  Ceiling   : 0.2424
  Chair     : 0.6440
  Floor     : 0.8967
  Furniture : 0.6962
  Objects   : 0.5846
  Picture   : 0.6586
  Sofa      : 0.6684
  Table     : 0.5365
  TV        : 0.5230
  Wall      : 0.8046
  Window    : 0.6336

[20260108_145110] Epoch 19/50




Train Loss: 0.2037 | Val Loss: 0.4479
Val mIoU: 0.6565
Books [1] -> IoU: 0.3493 | Prec: 0.3682 | Rec: 0.7188
Analysis  -> FP Conf: 0.8191 | FN Conf: 0.1452
Class-wise IoU:
  Bed       : 0.6740
  Books     : 0.3493
  Ceiling   : 0.5075
  Chair     : 0.7364
  Floor     : 0.9125
  Furniture : 0.6897
  Objects   : 0.6304
  Picture   : 0.5875
  Sofa      : 0.7026
  Table     : 0.6576
  TV        : 0.5934
  Wall      : 0.8388
  Window    : 0.6547
>>> Best Model Saved!

[20260108_145157] Epoch 20/50




Train Loss: 0.2058 | Val Loss: 0.4267
Val mIoU: 0.6554
Books [1] -> IoU: 0.4399 | Prec: 0.4846 | Rec: 0.6739
Analysis  -> FP Conf: 0.7334 | FN Conf: 0.1092
Class-wise IoU:
  Bed       : 0.7418
  Books     : 0.4399
  Ceiling   : 0.3761
  Chair     : 0.7240
  Floor     : 0.9190
  Furniture : 0.6870
  Objects   : 0.6361
  Picture   : 0.6463
  Sofa      : 0.6782
  Table     : 0.5753
  TV        : 0.5451
  Wall      : 0.8399
  Window    : 0.7120

[20260108_145250] Epoch 21/50




Train Loss: 0.1948 | Val Loss: 0.3869
Val mIoU: 0.6832
Books [1] -> IoU: 0.4162 | Prec: 0.4529 | Rec: 0.8162
Analysis  -> FP Conf: 0.8032 | FN Conf: 0.1079
Class-wise IoU:
  Bed       : 0.8003
  Books     : 0.4162
  Ceiling   : 0.5457
  Chair     : 0.6314
  Floor     : 0.9131
  Furniture : 0.7173
  Objects   : 0.6496
  Picture   : 0.6377
  Sofa      : 0.7324
  Table     : 0.6812
  TV        : 0.6262
  Wall      : 0.8367
  Window    : 0.6943
>>> Best Model Saved!

[20260108_145338] Epoch 22/50




Train Loss: 0.1843 | Val Loss: 0.3972
Val mIoU: 0.7021
Books [1] -> IoU: 0.4317 | Prec: 0.4757 | Rec: 0.6730
Analysis  -> FP Conf: 0.8260 | FN Conf: 0.0877
Class-wise IoU:
  Bed       : 0.7647
  Books     : 0.4317
  Ceiling   : 0.4601
  Chair     : 0.6823
  Floor     : 0.9176
  Furniture : 0.7332
  Objects   : 0.6444
  Picture   : 0.6442
  Sofa      : 0.7647
  Table     : 0.6389
  TV        : 0.9276
  Wall      : 0.8328
  Window    : 0.6856
>>> Best Model Saved!

[20260108_145425] Epoch 23/50




Train Loss: 0.1843 | Val Loss: 0.4568
Val mIoU: 0.6502
Books [1] -> IoU: 0.4444 | Prec: 0.4870 | Rec: 0.8055
Analysis  -> FP Conf: 0.8290 | FN Conf: 0.1055
Class-wise IoU:
  Bed       : 0.6586
  Books     : 0.4444
  Ceiling   : 0.5730
  Chair     : 0.6166
  Floor     : 0.9224
  Furniture : 0.6812
  Objects   : 0.6056
  Picture   : 0.5598
  Sofa      : 0.6423
  Table     : 0.6230
  TV        : 0.5518
  Wall      : 0.8336
  Window    : 0.7408

[20260108_145512] Epoch 24/50




Train Loss: 0.1738 | Val Loss: 0.4352
Val mIoU: 0.6663
Books [1] -> IoU: 0.3756 | Prec: 0.4062 | Rec: 0.8351
Analysis  -> FP Conf: 0.8092 | FN Conf: 0.1367
Class-wise IoU:
  Bed       : 0.7765
  Books     : 0.3756
  Ceiling   : 0.4054
  Chair     : 0.6530
  Floor     : 0.9199
  Furniture : 0.6648
  Objects   : 0.6604
  Picture   : 0.5884
  Sofa      : 0.7383
  Table     : 0.6674
  TV        : 0.7045
  Wall      : 0.8214
  Window    : 0.6868

[20260108_145600] Epoch 25/50




Train Loss: 0.1757 | Val Loss: 0.4513
Val mIoU: 0.6543
Books [1] -> IoU: 0.3566 | Prec: 0.3909 | Rec: 0.7402
Analysis  -> FP Conf: 0.7879 | FN Conf: 0.1213
Class-wise IoU:
  Bed       : 0.6890
  Books     : 0.3566
  Ceiling   : 0.5027
  Chair     : 0.6480
  Floor     : 0.9019
  Furniture : 0.7018
  Objects   : 0.6281
  Picture   : 0.6890
  Sofa      : 0.7119
  Table     : 0.6016
  TV        : 0.5540
  Wall      : 0.8325
  Window    : 0.6890

[20260108_145646] Epoch 26/50




Train Loss: 0.1682 | Val Loss: 0.4672
Val mIoU: 0.6513
Books [1] -> IoU: 0.4098 | Prec: 0.4542 | Rec: 0.7187
Analysis  -> FP Conf: 0.7886 | FN Conf: 0.0969
Class-wise IoU:
  Bed       : 0.7301
  Books     : 0.4098
  Ceiling   : 0.4720
  Chair     : 0.6506
  Floor     : 0.8952
  Furniture : 0.6871
  Objects   : 0.6210
  Picture   : 0.6290
  Sofa      : 0.6277
  Table     : 0.6579
  TV        : 0.5451
  Wall      : 0.8373
  Window    : 0.7043

[20260108_145733] Epoch 27/50




Train Loss: 0.1659 | Val Loss: 0.5079
Val mIoU: 0.6242
Books [1] -> IoU: 0.4432 | Prec: 0.4898 | Rec: 0.7093
Analysis  -> FP Conf: 0.8222 | FN Conf: 0.1089
Class-wise IoU:
  Bed       : 0.7137
  Books     : 0.4432
  Ceiling   : 0.2625
  Chair     : 0.6797
  Floor     : 0.9229
  Furniture : 0.7163
  Objects   : 0.5612
  Picture   : 0.5330
  Sofa      : 0.5480
  Table     : 0.6865
  TV        : 0.5494
  Wall      : 0.7946
  Window    : 0.7042

[20260108_145820] Epoch 28/50




Train Loss: 0.1619 | Val Loss: 0.4733
Val mIoU: 0.6729
Books [1] -> IoU: 0.3959 | Prec: 0.4355 | Rec: 0.6807
Analysis  -> FP Conf: 0.7763 | FN Conf: 0.0823
Class-wise IoU:
  Bed       : 0.7466
  Books     : 0.3959
  Ceiling   : 0.4827
  Chair     : 0.7267
  Floor     : 0.8831
  Furniture : 0.6899
  Objects   : 0.5861
  Picture   : 0.5526
  Sofa      : 0.7266
  Table     : 0.6752
  TV        : 0.7120
  Wall      : 0.8147
  Window    : 0.7558

[20260108_145907] Epoch 29/50




Train Loss: 0.1566 | Val Loss: 0.4292
Val mIoU: 0.6560
Books [1] -> IoU: 0.2865 | Prec: 0.3188 | Rec: 0.5414
Analysis  -> FP Conf: 0.7855 | FN Conf: 0.1294
Class-wise IoU:
  Bed       : 0.7047
  Books     : 0.2865
  Ceiling   : 0.4425
  Chair     : 0.6818
  Floor     : 0.9124
  Furniture : 0.7082
  Objects   : 0.6187
  Picture   : 0.6465
  Sofa      : 0.8294
  Table     : 0.6090
  TV        : 0.5523
  Wall      : 0.8115
  Window    : 0.7247

[20260108_145953] Epoch 30/50




Train Loss: 0.1567 | Val Loss: 0.4528
Val mIoU: 0.6563
Books [1] -> IoU: 0.3988 | Prec: 0.4342 | Rec: 0.7757
Analysis  -> FP Conf: 0.8185 | FN Conf: 0.1002
Class-wise IoU:
  Bed       : 0.7520
  Books     : 0.3988
  Ceiling   : 0.4007
  Chair     : 0.6774
  Floor     : 0.9158
  Furniture : 0.6854
  Objects   : 0.6550
  Picture   : 0.6084
  Sofa      : 0.7438
  Table     : 0.6670
  TV        : 0.5238
  Wall      : 0.8363
  Window    : 0.6681

[20260108_150046] Epoch 31/50




Train Loss: 0.1575 | Val Loss: 0.4439
Val mIoU: 0.6866
Books [1] -> IoU: 0.4186 | Prec: 0.4564 | Rec: 0.8211
Analysis  -> FP Conf: 0.8385 | FN Conf: 0.1087
Class-wise IoU:
  Bed       : 0.7139
  Books     : 0.4186
  Ceiling   : 0.5764
  Chair     : 0.5580
  Floor     : 0.9177
  Furniture : 0.7434
  Objects   : 0.6132
  Picture   : 0.6551
  Sofa      : 0.6277
  Table     : 0.7070
  TV        : 0.8779
  Wall      : 0.8241
  Window    : 0.6922

[20260108_150132] Epoch 32/50




Train Loss: 0.1533 | Val Loss: 0.4444
Val mIoU: 0.6630
Books [1] -> IoU: 0.4465 | Prec: 0.4933 | Rec: 0.6974
Analysis  -> FP Conf: 0.8012 | FN Conf: 0.1001
Class-wise IoU:
  Bed       : 0.7473
  Books     : 0.4465
  Ceiling   : 0.3170
  Chair     : 0.6858
  Floor     : 0.9303
  Furniture : 0.7235
  Objects   : 0.6198
  Picture   : 0.5544
  Sofa      : 0.8311
  Table     : 0.6675
  TV        : 0.6939
  Wall      : 0.8031
  Window    : 0.5988

[20260108_150218] Epoch 33/50




Train Loss: 0.1512 | Val Loss: 0.4389
Val mIoU: 0.6887
Books [1] -> IoU: 0.4321 | Prec: 0.4761 | Rec: 0.7695
Analysis  -> FP Conf: 0.7810 | FN Conf: 0.0922
Class-wise IoU:
  Bed       : 0.6533
  Books     : 0.4321
  Ceiling   : 0.5368
  Chair     : 0.7262
  Floor     : 0.9201
  Furniture : 0.6887
  Objects   : 0.6267
  Picture   : 0.6609
  Sofa      : 0.6668
  Table     : 0.6896
  TV        : 0.7030
  Wall      : 0.8469
  Window    : 0.8020

[20260108_150305] Epoch 34/50




Train Loss: 0.1456 | Val Loss: 0.4184
Val mIoU: 0.6712
Books [1] -> IoU: 0.4215 | Prec: 0.4678 | Rec: 0.6718
Analysis  -> FP Conf: 0.7821 | FN Conf: 0.1141
Class-wise IoU:
  Bed       : 0.7611
  Books     : 0.4215
  Ceiling   : 0.4907
  Chair     : 0.6963
  Floor     : 0.9132
  Furniture : 0.7326
  Objects   : 0.6200
  Picture   : 0.6799
  Sofa      : 0.6859
  Table     : 0.6941
  TV        : 0.5606
  Wall      : 0.7929
  Window    : 0.6770

[20260108_150352] Epoch 35/50




Train Loss: 0.1431 | Val Loss: 0.4156
Val mIoU: 0.6994
Books [1] -> IoU: 0.3905 | Prec: 0.4200 | Rec: 0.7912
Analysis  -> FP Conf: 0.8160 | FN Conf: 0.1043
Class-wise IoU:
  Bed       : 0.7621
  Books     : 0.3905
  Ceiling   : 0.7186
  Chair     : 0.6287
  Floor     : 0.9279
  Furniture : 0.7167
  Objects   : 0.6581
  Picture   : 0.7096
  Sofa      : 0.5638
  Table     : 0.7294
  TV        : 0.6913
  Wall      : 0.8401
  Window    : 0.7558

[20260108_150438] Epoch 36/50




Train Loss: 0.1419 | Val Loss: 0.4389
Val mIoU: 0.6813
Books [1] -> IoU: 0.4277 | Prec: 0.4734 | Rec: 0.7062
Analysis  -> FP Conf: 0.8097 | FN Conf: 0.1022
Class-wise IoU:
  Bed       : 0.7976
  Books     : 0.4277
  Ceiling   : 0.4671
  Chair     : 0.6005
  Floor     : 0.9151
  Furniture : 0.7333
  Objects   : 0.6288
  Picture   : 0.7149
  Sofa      : 0.6452
  Table     : 0.6680
  TV        : 0.6866
  Wall      : 0.8323
  Window    : 0.7401

[20260108_150525] Epoch 37/50




Train Loss: 0.1437 | Val Loss: 0.4273
Val mIoU: 0.6972
Books [1] -> IoU: 0.3409 | Prec: 0.4695 | Rec: 0.6806
Analysis  -> FP Conf: 0.7397 | FN Conf: 0.1029
Class-wise IoU:
  Bed       : 0.7666
  Books     : 0.3409
  Ceiling   : 0.6629
  Chair     : 0.7117
  Floor     : 0.9211
  Furniture : 0.7046
  Objects   : 0.6702
  Picture   : 0.7063
  Sofa      : 0.6903
  Table     : 0.6493
  TV        : 0.7071
  Wall      : 0.8437
  Window    : 0.6887

[20260108_150611] Epoch 38/50




Train Loss: 0.1402 | Val Loss: 0.4114
Val mIoU: 0.6871
Books [1] -> IoU: 0.4369 | Prec: 0.4857 | Rec: 0.6816
Analysis  -> FP Conf: 0.8400 | FN Conf: 0.0891
Class-wise IoU:
  Bed       : 0.7556
  Books     : 0.4369
  Ceiling   : 0.4944
  Chair     : 0.6677
  Floor     : 0.9290
  Furniture : 0.7379
  Objects   : 0.6640
  Picture   : 0.6417
  Sofa      : 0.6701
  Table     : 0.6893
  TV        : 0.6920
  Wall      : 0.7969
  Window    : 0.7569

[20260108_150658] Epoch 39/50




Train Loss: 0.1380 | Val Loss: 0.4195
Val mIoU: 0.6958
Books [1] -> IoU: 0.3994 | Prec: 0.4412 | Rec: 0.7833
Analysis  -> FP Conf: 0.8306 | FN Conf: 0.0997
Class-wise IoU:
  Bed       : 0.7484
  Books     : 0.3994
  Ceiling   : 0.5402
  Chair     : 0.6895
  Floor     : 0.9263
  Furniture : 0.7312
  Objects   : 0.6080
  Picture   : 0.7362
  Sofa      : 0.7329
  Table     : 0.6935
  TV        : 0.6839
  Wall      : 0.8420
  Window    : 0.7143

[20260108_150747] Epoch 40/50




Train Loss: 0.1347 | Val Loss: 0.4354
Val mIoU: 0.6680
Books [1] -> IoU: 0.4309 | Prec: 0.4789 | Rec: 0.6756
Analysis  -> FP Conf: 0.7978 | FN Conf: 0.1022
Class-wise IoU:
  Bed       : 0.6986
  Books     : 0.4309
  Ceiling   : 0.2780
  Chair     : 0.7240
  Floor     : 0.9184
  Furniture : 0.7502
  Objects   : 0.6418
  Picture   : 0.7014
  Sofa      : 0.7196
  Table     : 0.6363
  TV        : 0.6943
  Wall      : 0.8254
  Window    : 0.6655

[20260108_150839] Epoch 41/50




Train Loss: 0.1354 | Val Loss: 0.3976
Val mIoU: 0.6776
Books [1] -> IoU: 0.5447 | Prec: 0.6106 | Rec: 0.6499
Analysis  -> FP Conf: 0.7623 | FN Conf: 0.0897
Class-wise IoU:
  Bed       : 0.7633
  Books     : 0.5447
  Ceiling   : 0.3404
  Chair     : 0.6785
  Floor     : 0.9321
  Furniture : 0.7204
  Objects   : 0.6562
  Picture   : 0.6880
  Sofa      : 0.7241
  Table     : 0.6704
  TV        : 0.5403
  Wall      : 0.8179
  Window    : 0.7319

[20260108_150926] Epoch 42/50




Train Loss: 0.1333 | Val Loss: 0.4106
Val mIoU: 0.6858
Books [1] -> IoU: 0.4270 | Prec: 0.4740 | Rec: 0.7307
Analysis  -> FP Conf: 0.7691 | FN Conf: 0.0832
Class-wise IoU:
  Bed       : 0.6843
  Books     : 0.4270
  Ceiling   : 0.3007
  Chair     : 0.7044
  Floor     : 0.9469
  Furniture : 0.7454
  Objects   : 0.6433
  Picture   : 0.6945
  Sofa      : 0.6898
  Table     : 0.6812
  TV        : 0.8459
  Wall      : 0.8311
  Window    : 0.7201

[20260108_151013] Epoch 43/50




Train Loss: 0.1348 | Val Loss: 0.4280
Val mIoU: 0.6799
Books [1] -> IoU: 0.5206 | Prec: 0.5706 | Rec: 0.7862
Analysis  -> FP Conf: 0.8535 | FN Conf: 0.0846
Class-wise IoU:
  Bed       : 0.7721
  Books     : 0.5206
  Ceiling   : 0.2730
  Chair     : 0.7673
  Floor     : 0.9166
  Furniture : 0.7055
  Objects   : 0.6474
  Picture   : 0.6311
  Sofa      : 0.6473
  Table     : 0.6185
  TV        : 0.7032
  Wall      : 0.8445
  Window    : 0.7916

[20260108_151059] Epoch 44/50




Train Loss: 0.1337 | Val Loss: 0.4587
Val mIoU: 0.6657
Books [1] -> IoU: 0.4351 | Prec: 0.4817 | Rec: 0.7906
Analysis  -> FP Conf: 0.8418 | FN Conf: 0.0886
Class-wise IoU:
  Bed       : 0.7985
  Books     : 0.4351
  Ceiling   : 0.2596
  Chair     : 0.6699
  Floor     : 0.9355
  Furniture : 0.7388
  Objects   : 0.6019
  Picture   : 0.5983
  Sofa      : 0.7191
  Table     : 0.6812
  TV        : 0.6894
  Wall      : 0.8249
  Window    : 0.7021

[20260108_151145] Epoch 45/50




Train Loss: 0.1352 | Val Loss: 0.4206
Val mIoU: 0.6778
Books [1] -> IoU: 0.3892 | Prec: 0.4318 | Rec: 0.6376
Analysis  -> FP Conf: 0.8178 | FN Conf: 0.0755
Class-wise IoU:
  Bed       : 0.8023
  Books     : 0.3892
  Ceiling   : 0.4083
  Chair     : 0.7115
  Floor     : 0.9448
  Furniture : 0.7304
  Objects   : 0.6513
  Picture   : 0.5937
  Sofa      : 0.7865
  Table     : 0.6742
  TV        : 0.7060
  Wall      : 0.7987
  Window    : 0.6142

[20260108_151232] Epoch 46/50




Train Loss: 0.1322 | Val Loss: 0.4332
Val mIoU: 0.7031
Books [1] -> IoU: 0.4265 | Prec: 0.4764 | Rec: 0.6444
Analysis  -> FP Conf: 0.8209 | FN Conf: 0.0816
Class-wise IoU:
  Bed       : 0.8237
  Books     : 0.4265
  Ceiling   : 0.5312
  Chair     : 0.6513
  Floor     : 0.9213
  Furniture : 0.7334
  Objects   : 0.6080
  Picture   : 0.6363
  Sofa      : 0.7505
  Table     : 0.5822
  TV        : 0.9429
  Wall      : 0.8165
  Window    : 0.7170
>>> Best Model Saved!

[20260108_151319] Epoch 47/50




Train Loss: 0.1348 | Val Loss: 0.4541
Val mIoU: 0.6801
Books [1] -> IoU: 0.4229 | Prec: 0.4656 | Rec: 0.7173
Analysis  -> FP Conf: 0.8197 | FN Conf: 0.0997
Class-wise IoU:
  Bed       : 0.7920
  Books     : 0.4229
  Ceiling   : 0.5669
  Chair     : 0.6726
  Floor     : 0.9199
  Furniture : 0.7319
  Objects   : 0.5919
  Picture   : 0.6435
  Sofa      : 0.7056
  Table     : 0.7144
  TV        : 0.5412
  Wall      : 0.8047
  Window    : 0.7338

[20260108_151405] Epoch 48/50




Train Loss: 0.1312 | Val Loss: 0.4376
Val mIoU: 0.6518
Books [1] -> IoU: 0.4134 | Prec: 0.4768 | Rec: 0.5974
Analysis  -> FP Conf: 0.8354 | FN Conf: 0.0899
Class-wise IoU:
  Bed       : 0.7185
  Books     : 0.4134
  Ceiling   : 0.3047
  Chair     : 0.6104
  Floor     : 0.9065
  Furniture : 0.7302
  Objects   : 0.6358
  Picture   : 0.6798
  Sofa      : 0.6880
  Table     : 0.6678
  TV        : 0.5658
  Wall      : 0.8141
  Window    : 0.7378

[20260108_151452] Epoch 49/50




Train Loss: 0.1309 | Val Loss: 0.4459
Val mIoU: 0.6739
Books [1] -> IoU: 0.4374 | Prec: 0.4829 | Rec: 0.7225
Analysis  -> FP Conf: 0.7789 | FN Conf: 0.0941
Class-wise IoU:
  Bed       : 0.7618
  Books     : 0.4374
  Ceiling   : 0.5428
  Chair     : 0.6300
  Floor     : 0.9233
  Furniture : 0.7371
  Objects   : 0.6245
  Picture   : 0.6024
  Sofa      : 0.7962
  Table     : 0.6365
  TV        : 0.5617
  Wall      : 0.8630
  Window    : 0.6434

[20260108_151538] Epoch 50/50




Train Loss: 0.1345 | Val Loss: 0.4187
Val mIoU: 0.6973
Books [1] -> IoU: 0.4431 | Prec: 0.4936 | Rec: 0.6760
Analysis  -> FP Conf: 0.8196 | FN Conf: 0.0888
Class-wise IoU:
  Bed       : 0.7123
  Books     : 0.4431
  Ceiling   : 0.4659
  Chair     : 0.7318
  Floor     : 0.9088
  Furniture : 0.7859
  Objects   : 0.6426
  Picture   : 0.7078
  Sofa      : 0.7291
  Table     : 0.6999
  TV        : 0.6999
  Wall      : 0.8374
  Window    : 0.6998
📦 Zipping all logs...
✅ Download: /content/debug_hybrid_20260108_143648.zip
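
With 50 epochs of output, the best checkpoint is easier to find by parsing the log than by reading it. The following is a minimal sketch, assuming the log keeps the `Epoch N/M` and `Val mIoU: x` line shapes shown above; `parse_val_miou` and `best_epoch` are hypothetical helper names, not part of the notebook.

```python
import re

# Hypothetical helpers: they assume the "Epoch N/M" and "Val mIoU: x" line
# shapes seen in the training output above.
EPOCH_RE = re.compile(r"Epoch (\d+)/(\d+)")
MIOU_RE = re.compile(r"Val mIoU: ([0-9.]+)")

def parse_val_miou(log_text):
    """Return a list of (epoch, val_miou) pairs in the order they appear."""
    results = []
    current_epoch = None
    for line in log_text.splitlines():
        m = EPOCH_RE.search(line)
        if m:
            current_epoch = int(m.group(1))
            continue
        m = MIOU_RE.search(line)
        if m and current_epoch is not None:
            results.append((current_epoch, float(m.group(1))))
    return results

def best_epoch(results):
    """Return the (epoch, val_miou) pair with the highest validation mIoU."""
    return max(results, key=lambda pair: pair[1])

# Two epochs copied from the log above.
sample_log = """[20260108_151232] Epoch 46/50
Train Loss: 0.1322 | Val Loss: 0.4332
Val mIoU: 0.7031
[20260108_151319] Epoch 47/50
Train Loss: 0.1348 | Val Loss: 0.4541
Val mIoU: 0.6801
"""
history = parse_val_miou(sample_log)
print(best_epoch(history))  # (46, 0.7031)
```

Applied to the full log, this should agree with the last `>>> Best Model Saved!` marker, which follows epoch 46.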


In [None]:
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
import cv2
import base64
import os
import glob
import numpy as np
from tqdm.notebook import tqdm
import time

class UltimateGoldenAnnotator:
    def __init__(self, root_dir='/content/data/train'):
        self.root_dir = root_dir
        self.img_dir = os.path.join(root_dir, 'image')
        self.label_dir = os.path.join(root_dir, 'label')

        # Pre-scan so that only images containing books are listed
        self.files = self.scan_for_books()
        self.index = 0

        if not self.files:
            print("Error: no images containing books were found.")
            return

        # UI
        self.btn_prev = widgets.Button(description="◀ Prev", layout=widgets.Layout(width='80px'))
        self.btn_next = widgets.Button(description="Next ▶", layout=widgets.Layout(width='80px'), button_style='info')
        self.btn_skip = widgets.Button(description="Skip 10", layout=widgets.Layout(width='100px'))
        self.lbl_info = widgets.Label(value="")

        self.btn_prev.on_click(self.on_prev)
        self.btn_next.on_click(self.on_next)
        self.btn_skip.on_click(self.on_skip)

        self.out_image = widgets.Output()
        self.controls = widgets.HBox([self.btn_prev, self.btn_next, self.btn_skip, self.lbl_info])

        display(self.controls)
        display(self.out_image)

        self.render()

    def scan_for_books(self):
        print("🔍 Scanning for images that contain books (class=1)... please wait")
        valid_files = []
        all_files = sorted(glob.glob(os.path.join(self.img_dir, "*.png")))

        for img_path in tqdm(all_files):
            basename = os.path.basename(img_path)
            label_path = os.path.join(self.label_dir, basename)
            if os.path.exists(label_path):
                label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
                if label is not None and np.any(label == 1):
                    valid_files.append(img_path)
        print(f"✅ Done: found {len(valid_files)} images.")
        return valid_files

    def on_next(self, b):
        if self.index < len(self.files) - 1:
            self.index += 1
            self.render()

    def on_prev(self, b):
        if self.index > 0:
            self.index -= 1
            self.render()

    def on_skip(self, b):
        if self.index < len(self.files) - 10:
            self.index += 10
            self.render()
        else:
            self.index = len(self.files) - 1
            self.render()

    def render(self):
        filepath = self.files[self.index]
        filename = os.path.basename(filepath)
        file_id = os.path.splitext(filename)[0]
        label_path = os.path.join(self.label_dir, f"{file_id}.png")

        self.lbl_info.value = f" [{self.index+1}/{len(self.files)}] ID: {file_id}"

        # Load the image and its label
        img = cv2.imread(filepath)
        label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)

        if img is None:
            return

        # --- Overlay: tint the labelled book pixels red ---
        if label is not None:
            # Binary mask of book pixels (class=1)
            mask = (label == 1).astype(np.uint8)
            if np.sum(mask) > 0:
                # Solid red layer
                color_layer = np.zeros_like(img)
                color_layer[:, :, 2] = 255  # channel 2 is red because OpenCV stores images as BGR

                # Blend only the masked pixels: img * 0.6 + red * 0.4
                img_display = img.copy()
                roi = img_display[mask == 1]
                img_display[mask == 1] = cv2.addWeighted(roi, 0.6, color_layer[mask == 1], 0.4, 0)
            else:
                img_display = img
        else:
            img_display = img

        # Encode as Base64
        img_h, img_w = img_display.shape[:2]
        _, buffer = cv2.imencode('.png', img_display)
        img_str = base64.b64encode(buffer).decode('utf-8')

        unique_id = f"{file_id}_{int(time.time()*1000)}"
        canvas_id = f"canvas_{unique_id}"
        result_id = f"result_{unique_id}"

        html_code = f"""
        <div style="margin-top: 10px;">
            <div style="margin-bottom: 5px; color: #333;">
                <span style="color: red; font-weight: bold;">■ Red: dataset label</span> /
                <span style="color: #00E676; font-weight: bold;">■ Green box: the ideal book region you select</span>
            </div>

            <div style="position: relative; display: inline-block; border: 2px solid #555;">
                <canvas id="{canvas_id}" width="{img_w}" height="{img_h}" style="cursor: crosshair; display: block;"></canvas>
            </div>

            <div style="background-color: #f1f8e9; padding: 10px; border-radius: 5px; margin-top: 8px; border: 1px solid #c5e1a5;">
                <div style="font-size: 0.8em; color: #33691e;">Code to copy:</div>
                <textarea id="{result_id}" rows="1" style="width: 100%; font-family: monospace; font-size: 1.1em; color: #d63384; border: 1px solid #ccc;" readonly onclick="this.select();"></textarea>
            </div>
        </div>

        <script>
            (function() {{
                var canvas = document.getElementById('{canvas_id}');
                var resultArea = document.getElementById('{result_id}');
                if (!canvas) return;

                var ctx = canvas.getContext('2d');
                var img = new Image();
                img.src = "data:image/png;base64,{img_str}";
                var startX, startY, isDrawing = false;

                img.onload = function() {{ ctx.drawImage(img, 0, 0); }};

                function getPos(evt) {{
                    var rect = canvas.getBoundingClientRect();
                    return {{
                        x: (evt.clientX - rect.left) * (canvas.width / rect.width),
                        y: (evt.clientY - rect.top) * (canvas.height / rect.height)
                    }};
                }}

                canvas.addEventListener('mousedown', function(e) {{
                    var pos = getPos(e); startX = pos.x; startY = pos.y; isDrawing = true;
                }});

                canvas.addEventListener('mousemove', function(e) {{
                    if (!isDrawing) return;
                    var pos = getPos(e);
                    ctx.clearRect(0, 0, canvas.width, canvas.height);
                    ctx.drawImage(img, 0, 0);

                    var w = pos.x - startX;
                    var h = pos.y - startY;
                    ctx.beginPath(); ctx.lineWidth = 2; ctx.strokeStyle = "#00E676";
                    ctx.rect(startX, startY, w, h); ctx.stroke();
                }});

                canvas.addEventListener('mouseup', function(e) {{
                    isDrawing = false;
                    var pos = getPos(e);
                    var x = Math.round(Math.min(startX, pos.x));
                    var y = Math.round(Math.min(startY, pos.y));
                    var w = Math.round(Math.abs(pos.x - startX));
                    var h = Math.round(Math.abs(pos.y - startY));

                    if (w > 5 && h > 5) {{
                        ctx.beginPath(); ctx.lineWidth = 3; ctx.strokeStyle = "#00E676";
                        ctx.rect(x, y, w, h); ctx.stroke();
                        resultArea.value = "{{'id': '{file_id}', 'bbox': [" + x + ", " + y + ", " + w + ", " + h + "]}}, # {filename}";
                    }}
                }});
            }})();
        </script>
        """

        with self.out_image:
            clear_output(wait=True)
            display(HTML(html_code))

# Run
annotator = UltimateGoldenAnnotator()

🔍 Scanning for images that contain books (class=1)... please wait

✅ Done: found 136 images.
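
The annotator emits one dictionary per drawn box, in the form `{'id': ..., 'bbox': [x, y, w, h]}`. To compare these boxes with the dataset labels they must first be rasterized. Below is a minimal sketch, assuming the boxes are in pixel coordinates of the displayed image; `bboxes_to_mask` and the sample annotation are hypothetical, not taken from the notebook.

```python
import numpy as np

def bboxes_to_mask(annotations, image_id, height, width):
    """Rasterize every [x, y, w, h] box recorded for image_id into a boolean mask."""
    mask = np.zeros((height, width), dtype=bool)
    for ann in annotations:
        if ann["id"] != image_id:
            continue
        x, y, w, h = ann["bbox"]
        # Clip the box to the image so out-of-range drags do not wrap around.
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, width), min(y + h, height)
        mask[y0:y1, x0:x1] = True
    return mask

# Made-up annotation in the format the annotator copies to the text area.
annotations = [{"id": "000012", "bbox": [10, 20, 30, 40]}]
mask = bboxes_to_mask(annotations, "000012", height=480, width=640)
print(mask.sum())  # 30 * 40 = 1200 pixels
```

The resulting mask can be intersected with `label == 1` to measure how much of the dataset's book label falls inside the hand-drawn region.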



In [None]:
# ==========================================
# 1. Install & import libraries
# ==========================================
# Install a convenient library of segmentation models
!pip install segmentation-models-pytorch -q

import os
import random
import time
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import Resize, ToTensor, Compose, InterpolationMode, Lambda
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
from torch.cuda.amp import autocast, GradScaler
import segmentation_models_pytorch as smp  # pretrained segmentation architectures
import warnings

# Suppress warnings
warnings.filterwarnings("ignore")

# ==========================================
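# 2. Configuration
# ==========================================
class TrainingConfig:
    dataset_root = "/content/data" # path to the dataset
    batch_size = 16
    num_workers = 2
    # Important: RGB (3) + Depth (1) = 4 input channels
    in_channels = 4
    num_classes = 13
    image_size = (320, 240) # (height, width) as torchvision's Resize expects; swap to (240, 320) if the source frames are landscape 4:3
    epochs = 15            # kept short for testing; 30-50 recommended for better accuracy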
    learning_rate = 1e-4
    device = "cuda" if torch.cuda.is_available() else "cpu"

config = TrainingConfig()

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

set_seed(42)

# ==========================================
# 3. Dataset definition (NYUv2)
# ==========================================
class NYUv2(torch.utils.data.Dataset):
    def __init__(self, root, split='train', transform=None, target_transform=None):
        super().__init__()
        self.root = root
        self.split = split

        # Resolve paths (use the NYUv2 subdirectory if it exists)
        base_dir = os.path.join(self.root, 'NYUv2') if os.path.exists(os.path.join(self.root, 'NYUv2')) else self.root
        img_dir = os.path.join(base_dir, self.split, 'image')

        if not os.path.exists(img_dir):
             raise FileNotFoundError(f"Directory not found: {img_dir}")

        img_names = sorted(os.listdir(img_dir))
        self.images = [os.path.join(img_dir, n) for n in img_names]
        self.depths = [os.path.join(base_dir, self.split, 'depth', n) for n in img_names]

        if self.split == 'train':
            self.targets = [os.path.join(base_dir, self.split, 'label', n) for n in img_names]

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, idx):
        image = Image.open(self.images[idx])
        depth = Image.open(self.depths[idx])

        if self.transform:
            image = self.transform(image)
            depth = self.transform(depth)

        # Important: concatenate RGB and depth here into a single 4-channel tensor
        # image: [3, H, W], depth: [1, H, W] -> input: [4, H, W]
        input_tensor = torch.cat([image, depth], dim=0)

        if self.split == 'train':
            target = Image.open(self.targets[idx])
            if self.target_transform:
                target = self.target_transform(target)
            return input_tensor, target

        return input_tensor

    def __len__(self):
        return len(self.images)

# ==========================================
# 4. Data preparation & model construction (using smp)
# ==========================================
print("Setting up Data & Model...")

# Transforms
transform = Compose([
    Resize(config.image_size, interpolation=InterpolationMode.BILINEAR),
    ToTensor()
])
target_transform = Compose([
    Resize(config.image_size, interpolation=InterpolationMode.NEAREST),
    Lambda(lambda lbl: torch.from_numpy(np.array(lbl)).long())
])

# Dataset
full_dataset = NYUv2(root=config.dataset_root, split='train', transform=transform, target_transform=target_transform)

# Split (80% Train, 20% Val)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=config.num_workers)

print(f"Train: {len(train_dataset)} images, Val: {len(val_dataset)} images")

# ★★★ Model Definition (The Key Change) ★★★
# Use smp.Unet with a ResNet34 encoder pretrained on ImageNet.
# With in_channels=4, smp adapts the first convolution (and its pretrained weights) to the RGB+Depth input.
print("Building Pre-trained Model (ResNet34 + ImageNet)...")
model = smp.Unet(
    encoder_name="resnet34",        # backbone
    encoder_weights="imagenet",     # pretrained weights
    in_channels=config.in_channels, # 4-channel input
    classes=config.num_classes,     # 13-class output
    activation=None                 # output raw logits
).to(config.device)

optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index=255)
scaler = GradScaler()

# ==========================================
# 5. Training loop
# ==========================================
print(f"\nStarting training for {config.epochs} epochs...")

for epoch in range(config.epochs):
    model.train()
    total_loss = 0

    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs}") as pbar:
        for x, label in pbar:
            x, label = x.to(config.device), label.to(config.device)

            optimizer.zero_grad()
            with autocast():
                pred = model(x)
                loss = criterion(pred, label)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})

    # To keep things simple, per-epoch validation is skipped; evaluation runs once at the end

print("Training finished.")

# ==========================================
# 6. Final evaluation
# ==========================================
def evaluate_iou(model, loader, device, num_classes):
    model.eval()
    confusion = np.zeros((num_classes, num_classes), dtype=np.int64)

    print("\nEvaluating on Validation Set...")
    with torch.no_grad():
        for x, target in tqdm(loader):
            x, target = x.to(device), target.to(device)
            output = model(x)
            pred = output.argmax(dim=1)

            pred_flat = pred.cpu().numpy().flatten()
            target_flat = target.cpu().numpy().flatten()

            mask = (target_flat != 255)
            confusion += confusion_matrix(
                target_flat[mask],
                pred_flat[mask],
                labels=range(num_classes)
            )

    ious = []
    for i in range(num_classes):
        tp = confusion[i, i]
        fp = confusion[:, i].sum() - tp
        fn = confusion[i, :].sum() - tp
        denom = tp + fp + fn
        ious.append(tp / denom if denom > 0 else float('nan'))
    return ious

# Class names
class_names = [
    "Bed", "Book", "Ceiling", "Chair", "Floor",
    "Cabinet", "Object", "Picture", "Sofa", "Table",
    "TV", "Wall", "Window"
]

# Compute IoU
ious = evaluate_iou(model, val_loader, config.device, config.num_classes)

# Print results
print("\n" + "="*45)
print(f"{'Class Name':<15} | {'IoU (Pre-trained)':<15}")
print("-" * 45)
valid_ious = []
for idx, iou in enumerate(ious):
    name = class_names[idx]
    print(f"{name:<15} | {iou:.4f}")
    if not np.isnan(iou):
        valid_ious.append(iou)

print("=" * 45)
targets = {"Book": 1, "Cabinet": 5, "Object": 6}
print("Target Classes Check:")
for name, idx in targets.items():
    print(f"  {name:<10}: {ious[idx]:.4f}")

if valid_ious:
    print("-" * 45)
    print(f"Mean IoU       : {sum(valid_ious)/len(valid_ious):.4f}")
print("=" * 45)
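
The per-class IoU in `evaluate_iou` above is TP / (TP + FP + FN) read off a confusion matrix. The same computation can be sketched in plain NumPy, which makes it easy to check on a toy input. `confusion_from_labels` and `iou_per_class` are hypothetical names, and the sketch assumes every prediction lies in `[0, num_classes)` and that 255 marks ignored pixels.

```python
import numpy as np

def confusion_from_labels(target, pred, num_classes, ignore_index=255):
    """Build a (num_classes x num_classes) matrix; rows are targets, columns are predictions."""
    target = np.asarray(target).ravel()
    pred = np.asarray(pred).ravel()
    keep = target != ignore_index  # drop ignored pixels
    idx = target[keep].astype(np.int64) * num_classes + pred[keep].astype(np.int64)
    counts = np.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)

def iou_per_class(confusion):
    """IoU = TP / (TP + FP + FN); classes that never appear yield NaN."""
    tp = np.diag(confusion).astype(np.float64)
    fp = confusion.sum(axis=0) - tp  # predicted as the class but labelled otherwise
    fn = confusion.sum(axis=1) - tp  # labelled as the class but predicted otherwise
    denom = tp + fp + fn
    return np.where(denom > 0, tp / np.maximum(denom, 1), np.nan)

# Toy example: five pixels, three classes, the last pixel is ignored.
target = np.array([0, 0, 1, 1, 255])
pred = np.array([0, 1, 1, 1, 0])
conf = confusion_from_labels(target, pred, num_classes=3)
ious = iou_per_class(conf)
print(ious)              # class 0: 1/2, class 1: 2/3, class 2: NaN (absent)
print(np.nanmean(ious))  # mean over the classes that are present
```

Averaging with `np.nanmean` mirrors the notebook's choice to exclude NaN entries (classes absent from both targets and predictions) from the mean IoU.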

In [None]:
# ==========================================
# 5. Main Pipeline: Train & Eval (Corrected)
# ==========================================
def main_pipeline():
    print(f"=== Starting Integrated Pipeline (No Book Rescue) ===")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Force re-mining flag (must be True because the feature definition changed)
    FORCE_REMINE = True

    # 1. Load Base Model
    print(f"Loading Base Model from {BASE_MODEL_PATH}...")
    base_model = ResNeXtDeepLabV3Plus_OS8(num_classes=13).to(device)
    base_model.load_state_dict(torch.load(BASE_MODEL_PATH, map_location=device))
    base_model.eval()

    # 2. Prepare Dataset
    val_aug = A.Compose([A.Resize(768, 768)], additional_targets={'depth': 'image', 'height': 'image'})
    full_ds = NYUv2Dataset("/content/data", split="train", transform=val_aug)

    indices = list(range(len(full_ds)))
    random.seed(42); random.shuffle(indices)
    n_val = int(len(full_ds) * 0.1)
    train_idx = indices[:-n_val]
    val_idx = indices[-n_val:]

    train_ds = Subset(full_ds, train_idx)
    val_ds = Subset(full_ds, val_idx)

    # 3. Mine Blobs & Train MLP
    train_csv_path = os.path.join(OUTPUT_DIR, "blobs_train.csv")
    val_csv_path = os.path.join(OUTPUT_DIR, "blobs_val.csv")

    # Rebuild when forced (even if a cache exists) or when no cache is present
    if FORCE_REMINE or not os.path.exists(train_csv_path):
        print("Mining blobs from training set (Force Refresh)...")
        df_train = mine_blobs_context(base_model, train_ds, device, "Mining Train")
        df_train.to_csv(train_csv_path, index=False)
    else:
        print("Loading cached train blobs...")
        df_train = pd.read_csv(train_csv_path)

    if FORCE_REMINE or not os.path.exists(val_csv_path):
        print("Mining blobs from validation set (Force Refresh)...")
        df_val = mine_blobs_context(base_model, val_ds, device, "Mining Val")
        df_val.to_csv(val_csv_path, index=False)
    else:
        print("Loading cached val blobs...")
        df_val = pd.read_csv(val_csv_path)

    # Prepare Features
    feature_cols = [c for c in df_train.columns if c.startswith("f_")]
    print(f"Features: {len(feature_cols)} dims")

    # Guard: stop early if no features were extracted
    if len(feature_cols) == 0:
        raise ValueError("Feature extraction failed: No columns starting with 'f_' found. Check mining logic.")

    X_train = df_train[feature_cols].fillna(0).values
    y_train = df_train["mlp_target"].values
    X_val = df_val[feature_cols].fillna(0).values  # held out for evaluation
    y_val = df_val["mlp_target"].values

    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)

    # Train MLP
    print("Training MLP...")
    class_counts = np.bincount(y_train.astype(int))
    # Avoid division by zero for classes with no samples
    class_counts = np.maximum(class_counts, 1)

    weights = 1. / class_counts
    sample_weights = weights[y_train.astype(int)]
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

    mlp_loader = DataLoader(BlobDataset(X_train, y_train), batch_size=64, sampler=sampler)
    mlp = BlobCorrectionNet(len(feature_cols), num_classes=4).to(device)
    optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(30):
        mlp.train()
        for bx, by in mlp_loader:
            bx, by = bx.to(device), by.to(device)
            optimizer.zero_grad()
            out = mlp(bx)
            loss = criterion(out, by)
            loss.backward()
            optimizer.step()

    print("MLP Training Finished.")

    # 4. Final Integrated Evaluation
    print("\n=== Running Final Evaluation (Base + MLP Correction) ===")

    intersection_base = np.zeros(13)
    union_base = np.zeros(13)
    intersection_corr = np.zeros(13)
    union_corr = np.zeros(13)

    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

    base_model.eval()
    mlp.eval()

    for x, y in tqdm(val_loader, desc="Evaluating"):
        x = x.to(device)
        y_np = y.numpy()[0]

        with torch.no_grad():
            logits = base_model(x)
            pred_base = logits.argmax(dim=1).cpu().numpy()[0]

        # Corrected Prediction
        pred_corr = predict_with_correction(base_model, mlp, scaler, x, device)

        mask = (y_np != 255)

        # Base
        for cls_id in range(13):
            p = (pred_base == cls_id) & mask
            t = (y_np == cls_id) & mask
            intersection_base[cls_id] += np.sum(p & t)
            union_base[cls_id] += np.sum(p | t)

        # Corrected
        for cls_id in range(13):
            p = (pred_corr == cls_id) & mask
            t = (y_np == cls_id) & mask
            intersection_corr[cls_id] += np.sum(p & t)
            union_corr[cls_id] += np.sum(p | t)

    # 5. Report
    iou_base = intersection_base / (union_base + 1e-6)
    iou_corr = intersection_corr / (union_corr + 1e-6)

    print("\n"+"-"*50)
    print(f"{'Class':<15} | {'Base IoU':<10} | {'Corrected':<10} | {'Diff'}")
    print("-" * 50)
    class_names = ["Bg", "Book", "Floor", "Win", "Furn", "Cab", "Obj", "Wall", "Struc", "Pict", "Desk", "Ceil", "Bed"]

    for i in range(13):
        name = class_names[i] if i < len(class_names) else str(i)
        diff = iou_corr[i] - iou_base[i]
        mark = "++" if diff > 0.01 else ("+" if diff > 0 else "")
        print(f"{name:<15} | {iou_base[i]:.4f}     | {iou_corr[i]:.4f}     | {diff:+.4f} {mark}")

    print("-" * 50)
    print(f"Mean IoU (Base): {np.mean(iou_base):.4f}")
    print(f"Mean IoU (Corr): {np.mean(iou_corr):.4f}")
    print("-" * 50)

    # Console log
    print(f"[LOG] Results saved to console. Force Remine: {FORCE_REMINE}")

if __name__ == "__main__":
    main_pipeline()
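
The MLP training step balances classes by giving every blob a sampling weight of one over its class count. The following NumPy sketch illustrates why this equalizes the classes: the total sampling mass of each class that is present comes out the same. The toy label array is made up for illustration, and `minlength=4` is an assumption added here so that the count array always covers all four correction classes.

```python
import numpy as np

# Toy blob labels (hypothetical): class 0 dominates, class 2 is absent.
y_train = np.array([0, 0, 0, 0, 0, 0, 1, 1, 3])

class_counts = np.bincount(y_train, minlength=4)
class_counts = np.maximum(class_counts, 1)        # avoid division by zero for empty classes
sample_weights = (1.0 / class_counts)[y_train]    # one weight per blob

# Total sampling mass per class: equal for every class that has samples.
mass = np.bincount(y_train, weights=sample_weights, minlength=4)
```

Passing weights like these to a weighted sampler means a rare class is drawn about as often as a frequent one, at the cost of repeating its few samples many times per epoch.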

In [None]:
# =========================
# Cell 18 (Pure Vision): Blob Context R-CNN (No Prob Maps)
# =========================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
import pandas as pd
import numpy as np
import os
import random
import albumentations as A
from tqdm import tqdm
from skimage.measure import label, regionprops
from torchvision.ops import roi_align
from sklearn.metrics import classification_report  # used in the evaluation report below

# --- Settings ---
OUTPUT_DIR = "/content/blob_correction_final"
BASE_MODEL_PATH = "/content/base_best_model.pt"

ID_BOOK = 1
ID_CABINET = 5
ID_OBJECT = 6
TARGET_IDS = [ID_BOOK, ID_CABINET, ID_OBJECT]
# 0:Other, 1:Book, 2:Cab, 3:Obj
NUM_CORRECT_CLASSES = 4

# --- 1. Blob R-CNN (Pure Vision) ---
class BlobRCNN(nn.Module):
    def __init__(self, in_channels=6, num_classes=4): # 9 -> 6 ch
        super(BlobRCNN, self).__init__()

        # Encoder
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64), nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(128), nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256), nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Attention & Classifier
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc_att = nn.Sequential(
            nn.Linear(256, 32), nn.ReLU(),
            nn.Linear(32, 256), nn.Sigmoid()
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 4 * 4, 256),
            nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        feat = self.conv3(x)
        att = self.gap(feat).view(feat.size(0), -1)
        att = self.fc_att(att).view(feat.size(0), 256, 1, 1)
        feat = feat * att
        return self.classifier(feat)

class BlobImageDataset(Dataset):
    def __init__(self, crop_data_list):
        self.data = crop_data_list
    def __len__(self): return len(self.data)
    def __getitem__(self, idx): return self.data[idx]['tensor'], self.data[idx]['target']

# --- 2. Mining (Without Prob Maps) ---
def extract_roi_crops_pure(model, dataset, device, desc="Mining"):
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)
    crop_dataset = []
    model.eval()

    ROI_SIZE = (32, 32)

    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(loader, desc=desc)):
            x = x.to(device)
            y_np = y.numpy()[0]

            # The base model is used only to locate blobs; its outputs are not used as features
            logits = model(x)
            probs = F.softmax(logits, dim=1)
            pred_map = probs.argmax(dim=1).cpu().numpy()[0]

            # Input features: RGB(3) + Depth(1) + Height(1) = 5 channels
            # The Mask(1) channel is appended later
            base_feats = x[0].clone() # [5, H, W]

            for pred_cls in TARGET_IDS:
                mask = (pred_map == pred_cls)
                if not mask.any(): continue
                lbl_img = label(mask)
                regions = regionprops(lbl_img)

                for props in regions:
                    if props.area < 50: continue

                    y0, x0, y1, x1 = props.bbox
                    r_mask_np = (lbl_img[y0:y1, x0:x1] == props.label)
                    gt_crop = y_np[y0:y1, x0:x1][r_mask_np]
                    valid_gt = gt_crop[gt_crop != 255]
                    if len(valid_gt) == 0: continue
                    true_label = np.argmax(np.bincount(valid_gt))

                    if true_label == ID_BOOK: target = 1
                    elif true_label == ID_CABINET: target = 2
                    elif true_label == ID_OBJECT: target = 3
                    else: target = 0

                    # RoI Cut
                    margin = 16
                    H, W = pred_map.shape
                    ry0 = max(0, y0 - margin); rx0 = max(0, x0 - margin)
                    ry1 = min(H, y1 + margin); rx1 = min(W, x1 + margin)

                    boxes = torch.tensor([[0, rx0, ry0, rx1, ry1]], dtype=torch.float).to(device)

                    # 1. Image Features (5ch)
                    feat_input = base_feats.unsqueeze(0)
                    roi_5ch = roi_align(feat_input, boxes, output_size=ROI_SIZE)[0] # [5, 32, 32]

                    # 2. Blob Mask (1ch) - marks which region is the target blob
                    blob_mask_full = torch.from_numpy(lbl_img == props.label).float().to(device).unsqueeze(0).unsqueeze(0)
                    roi_mask = roi_align(blob_mask_full, boxes, output_size=ROI_SIZE)[0] # [1, 32, 32]

                    # Concatenate -> 6ch
                    final_crop = torch.cat([roi_5ch, roi_mask], dim=0)

                    crop_dataset.append({
                        'tensor': final_crop.cpu(),
                        'target': target
                    })

    return crop_dataset

# --- 3. Training Loop ---
def run_pure_vision_rcnn():
    print(f"==========================================")
    print(f" TRAINING BLOB R-CNN (Pure Vision / No Bias)")
    print(f"==========================================")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup
    base_model = ResNeXtDeepLabV3Plus_OS8().to(device)
    base_model.load_state_dict(torch.load(BASE_MODEL_PATH, map_location=device))
    base_model.eval()

    val_aug = A.Compose([A.Resize(768, 768)], additional_targets={'depth': 'image', 'height': 'image'})
    full_ds = NYUv2Dataset("/content/data", split="train", transform=val_aug)
    indices = list(range(len(full_ds)))
    random.seed(42); random.shuffle(indices)
    n_val = int(len(full_ds) * 0.1)
    train_idx = indices[:-n_val]
    val_idx = indices[-n_val:]

    # Mining
    print(f"[STEP 1] Mining Pure Visual Crops...")
    train_crops = extract_roi_crops_pure(base_model, Subset(full_ds, train_idx), device, "Mining Train")
    val_crops = extract_roi_crops_pure(base_model, Subset(full_ds, val_idx), device, "Mining Val")
    print(f"  -> Train: {len(train_crops)}, Val: {len(val_crops)}")

    # Dataset
    train_ds_obj = BlobImageDataset(train_crops)
    val_ds_obj = BlobImageDataset(val_crops)

    targets = [x['target'] for x in train_crops]
    counts = np.bincount(targets); counts = np.maximum(counts, 1)
    weights = 1. / counts; sample_weights = [weights[t] for t in targets]
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

    train_loader = DataLoader(train_ds_obj, batch_size=32, sampler=sampler, num_workers=0)

    # Model (6ch Input)
    model = BlobRCNN(in_channels=6, num_classes=4).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    print(f"\n[STEP 3] Training Pure Vision R-CNN...")
    for epoch in range(1, 31):
        model.train()
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad(); out = model(x); loss = criterion(out, y)
            loss.backward(); optimizer.step()
            total_loss += loss.item()
        if epoch % 5 == 0: print(f"  Ep {epoch}: Loss {total_loss/len(train_loader):.4f}")

    # Eval
    model.eval()
    val_loader = DataLoader(val_ds_obj, batch_size=32, shuffle=False)
    all_preds, all_targets = [], []
    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device)
            all_preds.extend(model(x).argmax(dim=1).cpu().numpy())
            all_targets.extend(y.numpy())

    print("\n--- Pure Vision R-CNN Report ---")
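    print(classification_report(
        all_targets, all_preds,
        labels=list(range(NUM_CORRECT_CLASSES)),  # keep all 4 rows even if a class is missing from the val crops
        target_names=["Other", "Book", "Cab", "Obj"],
        zero_division=0,
    ))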

    os.makedirs(OUTPUT_DIR, exist_ok=True)  # the output directory may not exist yet
    torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, "blob_rcnn_pure.pt"))
    print("[DONE] Saved.")

if __name__ == "__main__":
    run_pure_vision_rcnn()
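
The mining step labels each predicted blob by a majority vote over its ground-truth pixels (ignoring 255) and then maps the winner to one of the four correction classes. A small NumPy sketch of that rule is shown here. The helper name `blob_target` and the dictionary `TARGET_MAP` are hypothetical; the class IDs and the 0/1/2/3 mapping are taken from the cell above.

```python
import numpy as np

ID_BOOK, ID_CABINET, ID_OBJECT = 1, 5, 6
IGNORE_INDEX = 255
# 0: Other, 1: Book, 2: Cabinet, 3: Object
TARGET_MAP = {ID_BOOK: 1, ID_CABINET: 2, ID_OBJECT: 3}

def blob_target(gt_pixels):
    """Correction-class target for one blob, or None if every pixel is ignored."""
    valid = gt_pixels[gt_pixels != IGNORE_INDEX]
    if valid.size == 0:
        return None                       # the mining loop skips such blobs
    majority = int(np.argmax(np.bincount(valid)))
    return TARGET_MAP.get(majority, 0)    # any other class falls into "Other"

book_blob = blob_target(np.array([1, 1, 5, 255]))   # Book wins 2 to 1
other_blob = blob_target(np.array([7, 7, 1]))       # class 7 wins, mapped to Other
empty_blob = blob_target(np.array([255, 255]))      # nothing to vote on
```

On a tie, `np.argmax` returns the smallest class ID, so ties are resolved toward the lower-numbered class.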

In [None]:
# =========================
# Cell 13: Recovery & Execution
# (Re-generates exact same indices using Seed 42 and runs MLP)
# =========================
import random
import torch
import os
from torch.utils.data import Subset

# 1. Re-create Dataset & Indices (Exactly as in Cell 5)
print("[INFO] Re-generating indices using fixed seed (Seed=42)...")

# Load the dataset (same settings as Cells 1 and 5)
full_dataset = NYUv2Dataset("/content/data", split="train", transform=None)
indices = list(range(len(full_dataset)))

# NOTE: shuffle with exactly the same logic as Cell 5
random.seed(42)
random.shuffle(indices)

n_val = int(len(full_dataset) * 0.1)
train_idx = indices[:-n_val]
val_idx = indices[-n_val:]

print(f"[INFO] Indices restored. Train: {len(train_idx)}, Val: {len(val_idx)}")
print(f"       (First 5 Train IDs: {train_idx[:5]})")
print(f"       (First 5 Val IDs:   {val_idx[:5]})")

# 2. Load Base Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model_path = "/content/base_best_model.pt"

if not os.path.exists(base_model_path):
    print(f"[ERROR] Base model not found at {base_model_path}")
else:
    print(f"[INFO] Loading Base Model from {base_model_path}...")
    base_model = ResNeXtDeepLabV3Plus_OS8().to(device)
    base_model.load_state_dict(torch.load(base_model_path, map_location=device))
    base_model.eval()  # the base model is used for inference only

    # 3. Run Blob Correction Pipeline (Defined in Cell 12)
    # This lets the MLP be evaluated on the same validation data that was held out when training the base model
    run_blob_correction_pipeline(base_model, full_dataset, train_idx, val_idx)
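
The recovery cell relies on the fact that seeding the generator with 42 and shuffling always yields the same permutation, so the train/validation split can be rebuilt at any time. The sketch below wraps that logic in a hypothetical helper, `split_indices`, and uses a local `random.Random` instance instead of the global generator so that other code is not affected. The dataset size of 795 comes from the overview at the top of the notebook.

```python
import random

def split_indices(n_items, val_fraction=0.1, seed=42):
    """Deterministic train/val index split (hypothetical helper)."""
    indices = list(range(n_items))
    rng = random.Random(seed)         # local generator; global random state is untouched
    rng.shuffle(indices)
    n_val = int(n_items * val_fraction)
    cut = n_items - n_val             # also safe when n_val == 0
    return indices[:cut], indices[cut:]

train_idx, val_idx = split_indices(795)
again_train, again_val = split_indices(795)   # same seed, same split
```

Slicing at `cut` rather than with `indices[:-n_val]` avoids an empty training list in the corner case where `n_val` is 0.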

In [None]:
# =========================
# Cell 9: Meta-Classification Feasibility Study (The "Blob" Analysis)
# =========================
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
from skimage.measure import label, regionprops
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns
import albumentations as A
from torch.utils.data import DataLoader, Subset  # used when building the validation loader

# --- Settings ---
TARGET_MODEL_PATH = "/content/base_best_model.pt"
OUTPUT_DIR = "/content/blob_analysis"
ensure_dir(OUTPUT_DIR)

# IDs
ID_BOOK = 1
ID_CABINET = 5
ID_OBJECT = 6

def analyze_blobs():
    print(f"==========================================")
    print(f" START BLOB ANALYSIS (Feature Importance)")
    print(f" Goal: Find rules from Shape, Pos, and Stats")
    print(f"==========================================")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load Model
    model = ResNeXtDeepLabV3Plus_OS8(num_classes=13, in_channels=5).to(device)
    try:
        model.load_state_dict(torch.load(TARGET_MODEL_PATH, map_location=device))
        model.eval()
    except Exception as e:
        print(f"[ERROR] Load failed: {e}"); return

    # 2. Validation Data (Use ALL validation images to get enough blobs)
    val_aug = A.Compose([A.Resize(768, 768)], additional_targets={'depth':'image','height':'image'})
    full_ds = NYUv2Dataset("/content/data", split="train", transform=None)
    indices = list(range(len(full_ds)))
    import random
    random.seed(42)
    random.shuffle(indices)
    val_idx = indices[-int(len(full_ds)*0.1):] # Last 10%
    val_ds = Subset(NYUv2Dataset("/content/data", "train", val_aug), val_idx)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)

    blob_data = []

    print("[INFO] Extracting blobs features...")
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(val_loader)):
            x = x.to(device)
            y_np = y.numpy()[0] # (H, W)

            # Inference
            logits = model(x)
            probs = F.softmax(logits, dim=1)
            pred_map = probs.argmax(dim=1).cpu().numpy()[0]

            # Extract Prob Maps for stats
            prob_book = probs[0, ID_BOOK].cpu().numpy()
            prob_cab  = probs[0, ID_CABINET].cpu().numpy()
            prob_obj  = probs[0, ID_OBJECT].cpu().numpy()

            # Target Classes: Cabinet & Object (Suspects)
            for target_cls in [ID_CABINET, ID_OBJECT]:
                mask = (pred_map == target_cls)
                if not mask.any(): continue

                lbl_img = label(mask)
                regions = regionprops(lbl_img)

                for props in regions:
                    if props.area < 50: continue # Skip dust

                    # --- 1. Geometric Features ---
                    y0, x0, y1, x1 = props.bbox
                    height = y1 - y0
                    width = x1 - x0
                    aspect_ratio = height / (width + 1e-6)
                    # Vertical position (normalized 0-1)
                    centroid_y_norm = props.centroid[0] / 768.0
                    centroid_x_norm = props.centroid[1] / 768.0

                    # Shape complexity (Solidity: Area / ConvexHullArea)
                    solidity = props.solidity
                    # Extent (Area / BoundingBoxArea)
                    extent = props.extent
                    # Eccentricity (0: Circle, 1: Line)
                    eccentricity = props.eccentricity

                    # --- 2. Statistical Features (Internal Probs) ---
                    # Mask for this blob
                    r_mask = (lbl_img == props.label)

                    p_book_vals = prob_book[r_mask]

                    mean_book = np.mean(p_book_vals)
                    max_book  = np.max(p_book_vals)
                    std_book  = np.std(p_book_vals)

                    # --- 3. Ground Truth (The Answer) ---
                    gt_vals = y_np[r_mask]
                    # How much of this blob is ACTUALLY Book?
                    book_pixel_count = (gt_vals == ID_BOOK).sum()
                    book_ratio = book_pixel_count / props.area

                    # Label: does this blob contain a substantial share of Book pixels? (30% threshold, deliberately aggressive)
                    is_book_blob = (book_ratio > 0.3)

                    blob_data.append({
                        "original_pred": "Cabinet" if target_cls == ID_CABINET else "Object",
                        "area": props.area,
                        "height": height,
                        "width": width,
                        "aspect_ratio": aspect_ratio,
                        "centroid_y": centroid_y_norm,
                        "centroid_x": centroid_x_norm,
                        "solidity": solidity,
                        "extent": extent,
                        "eccentricity": eccentricity,
                        "prob_book_mean": mean_book,
                        "prob_book_max": max_book,
                        "prob_book_std": std_book,
                        "target_is_book": int(is_book_blob)
                    })

    # --- Analysis Phase ---
    df = pd.DataFrame(blob_data)
    if df.empty: return

    print(f"\n[INFO] Extracted {len(df)} blobs.")
    print(f"      Actual Books found in blobs: {df['target_is_book'].sum()} ({df['target_is_book'].mean():.2%})")

    # Save raw data
    df.to_csv(os.path.join(OUTPUT_DIR, "blob_features.csv"), index=False)

    # 4. Can we classify these? (Random Forest Test)
    # Features to use
    feature_cols = [
        "area", "height", "width", "aspect_ratio",
        "centroid_y", "centroid_x",
        "solidity", "extent", "eccentricity",
        "prob_book_mean", "prob_book_max", "prob_book_std"
    ]

    X = df[feature_cols].fillna(0)
    y = df["target_is_book"]

    # Split
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

    # Train simple RF
    rf = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=42, class_weight="balanced")
    rf.fit(X_train, y_train)

    # Evaluate
    y_pred = rf.predict(X_test)
    y_prob = rf.predict_proba(X_test)[:, 1]

    print("\n--- Meta-Classifier Performance (Test Set) ---")
    print(classification_report(y_test, y_pred, target_names=["Not Book", "Is Book"]))
    print(f"ROC-AUC: {roc_auc_score(y_test, y_prob):.4f}")

    # 5. Feature Importance Plot
    importances = rf.feature_importances_
    indices = np.argsort(importances)[::-1]

    plt.figure(figsize=(10, 6))
    plt.title("What features define a 'Hidden Book' Blob?")
    plt.bar(range(X.shape[1]), importances[indices], align="center")
    plt.xticks(range(X.shape[1]), [feature_cols[i] for i in indices], rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "feature_importance.png"))
    plt.close()

    print(f"[DONE] Check 'feature_importance.png' to see the rules found.")

if __name__ == "__main__":
    analyze_blobs()
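
The geometric features in the blob analysis (bounding-box size, aspect ratio, extent, normalized centroid) can be reproduced for a single blob with plain NumPy. The sketch below is a simplified stand-in for `regionprops`, not a replacement: it handles one boolean mask, skips solidity and eccentricity, and normalizes the centroid by the mask shape rather than the fixed 768 used in the cell. The helper name `blob_geometry` is hypothetical.

```python
import numpy as np

def blob_geometry(mask):
    """Geometric features of one blob given as a 2-D boolean mask (hypothetical helper)."""
    rows, cols = np.nonzero(mask)
    y0, y1 = int(rows.min()), int(rows.max()) + 1   # half-open bbox
    x0, x1 = int(cols.min()), int(cols.max()) + 1
    height, width = y1 - y0, x1 - x0
    area = int(mask.sum())
    img_h, img_w = mask.shape
    return {
        "area": area,
        "height": height,
        "width": width,
        "aspect_ratio": height / (width + 1e-6),
        "extent": area / (height * width),          # area / bounding-box area
        "centroid_y": float(rows.mean()) / img_h,   # normalized to 0-1
        "centroid_x": float(cols.mean()) / img_w,
    }

# A 4x8 rectangle inside a 16x16 image.
mask = np.zeros((16, 16), dtype=bool)
mask[2:6, 4:12] = True
features = blob_geometry(mask)
```

A filled rectangle has an extent of 1.0; irregular blobs fall below that.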

In [None]:
# =========================
# Cell 8: Region-based Error Analysis (Finding the "Block" Threshold)
# =========================
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os
import json
import pandas as pd
import cv2
import seaborn as sns
from tqdm import tqdm
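from skimage.measure import label, regionprops
import albumentations as A  # needed for the validation transform below
from torch.utils.data import DataLoader, Subset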

# --- Settings ---
TARGET_MODEL_PATH = "/content/base_best_model.pt"
OUTPUT_DIR = "/content/region_analysis"
ensure_dir(OUTPUT_DIR)

# IDs
ID_BOOK = 1
ID_CABINET = 5
ID_OBJECT = 6

def analyze_regions():
    print(f"==========================================")
    print(f" START REGION-BASED ANALYSIS")
    print(f" Hypothesis: Errors occur in 'blocks' (connected components).")
    print(f"==========================================")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load Model
    try:
        model = ResNeXtDeepLabV3Plus_OS8(num_classes=13, in_channels=5).to(device)
        model.load_state_dict(torch.load(TARGET_MODEL_PATH, map_location=device))
        model.eval()
    except Exception as e:
        print(f"[ERROR] Model load failed: {e}")
        return

    # 2. Validation Data
    val_aug = A.Compose([A.Resize(768, 768)], additional_targets={'depth':'image','height':'image'})
    full_ds = NYUv2Dataset("/content/data", split="train", transform=None)
    indices = list(range(len(full_ds)))
    import random
    random.seed(42)
    random.shuffle(indices)
    val_idx = indices[-int(len(full_ds)*0.1):]
    val_ds = Subset(NYUv2Dataset("/content/data", "train", val_aug), val_idx)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False) # Batch 1 for region prop

    # 3. Collect Region Stats
    region_data = []

    print("[INFO] extracting regions from Cabinet/Object predictions...")

    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(val_loader)):
            x = x.to(device)
            y_np = y.numpy()[0] # (H, W)

            logits = model(x)
            probs = F.softmax(logits, dim=1) # [1, 13, H, W]
            book_prob_map = probs[0, ID_BOOK].cpu().numpy() # [H, W] Book channel
            pred_map = probs.argmax(dim=1).cpu().numpy()[0] # [H, W]

            # Analyze 'Cabinet' and 'Object' regions (Candidate for correction)
            for target_cls in [ID_CABINET, ID_OBJECT]:
                # Extract connected components for this class
                mask = (pred_map == target_cls)
                if not mask.any(): continue

                # Label connected regions
                lbl_img = label(mask)
                regions = regionprops(lbl_img)

                for props in regions:
                    # Filter too small regions (Noise)
                    if props.area < 100: continue

                    # Create mask for this region
                    region_mask = (lbl_img == props.label)

                    # 1. How confident was the model that this is Book? (Latent belief)
                    mean_book_prob = np.mean(book_prob_map[region_mask])
                    max_book_prob = np.max(book_prob_map[region_mask])

                    # 2. Ground Truth Check: Is this ACTUALLY a Book?
                    gt_in_region = y_np[region_mask]
                    book_pixel_count = (gt_in_region == ID_BOOK).sum()
                    actual_book_ratio = book_pixel_count / props.area

                    # 3. Is this region a "Missed Book"? (Label > 50% Book)
                    is_actually_book = (actual_book_ratio > 0.5)

                    region_data.append({
                        "image_idx": i,
                        "pred_class": "Cabinet" if target_cls == ID_CABINET else "Object",
                        "area": props.area,
                        "mean_book_prob": mean_book_prob,
                        "max_book_prob": max_book_prob,
                        "actual_book_ratio": actual_book_ratio,
                        "is_actually_book": is_actually_book
                    })

    # 4. Analysis & Plotting
    df = pd.DataFrame(region_data)

    if df.empty:
        print("[WARN] No regions found.")
        return

    print(f"\n[INFO] Analyzed {len(df)} regions.")
    print(df.groupby(["pred_class", "is_actually_book"]).size())

    # Save CSV for check
    df.to_csv(os.path.join(OUTPUT_DIR, "region_stats.csv"), index=False)

    # Plot Scatter: Mean Book Prob vs Actual Book Ratio
    plt.figure(figsize=(10, 6))
    sns.scatterplot(
        data=df,
        x="mean_book_prob",
        y="actual_book_ratio",
        hue="pred_class",
        style="is_actually_book",
        alpha=0.6
    )
    plt.axvline(0.2, color="gray", linestyle="--", label="Threshold 0.2")
    plt.axhline(0.5, color="red", linestyle="--", label="Majority Book")
    plt.title("Region Analysis: Can we trust 'Book Probability' to flip labels?")
    plt.xlabel("Mean Book Probability in Region (Model's hidden belief)")
    plt.ylabel("Actual Book Ratio in Region (Ground Truth)")
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "region_scatter.png"))
    plt.close()

    # Determine Thresholds
    # "If mean_book_prob > THRESHOLD, how many regions are actually books?"
    print("\n--- Threshold Analysis ---")
    for cls in ["Cabinet", "Object"]:
        print(f"\nTarget Pred Class: {cls}")
        sub_df = df[df["pred_class"] == cls]
        for thresh in [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5]:
            candidates = sub_df[sub_df["mean_book_prob"] > thresh]
            if len(candidates) == 0: continue
            precision = candidates["is_actually_book"].mean()
            n_convert = len(candidates)
            print(f"  Thresh > {thresh:.2f}: Convert {n_convert} regions -> Precision {precision:.2%}")

    print(f"\n[DONE] Analysis saved to {OUTPUT_DIR}")

if __name__ == "__main__":
    analyze_regions()
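
The threshold analysis asks, for each cut-off on the mean Book probability, how many regions would be relabelled and what fraction of them really are books. The sketch below runs the same computation on a handful of made-up regions. The arrays and the helper `precision_by_threshold` are hypothetical.

```python
import numpy as np

# Toy region statistics (hypothetical values).
mean_book_prob = np.array([0.05, 0.12, 0.22, 0.35, 0.60])
is_actually_book = np.array([False, False, True, False, True])

def precision_by_threshold(probs, is_book, thresholds):
    """Rows of (threshold, n_converted, precision); thresholds with no candidates are skipped."""
    rows = []
    for t in thresholds:
        selected = probs > t
        if not selected.any():
            continue
        rows.append((t, int(selected.sum()), float(is_book[selected].mean())))
    return rows

table = precision_by_threshold(mean_book_prob, is_actually_book, [0.1, 0.2, 0.5, 0.7])
for t, n, prec in table:
    print(f"Thresh > {t:.2f}: Convert {n} regions -> Precision {prec:.2%}")
```

Raising the threshold converts fewer regions but with higher precision, which is the trade-off the analysis cell is meant to expose.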

In [None]:
# =========================
# Cell 7: Confidence & Ambiguity Analysis
# =========================
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import os
import json
import pandas as pd
from tqdm import tqdm
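import seaborn as sns
import albumentations as A  # needed for the validation transform below
from torch.utils.data import DataLoader, Subset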

# --- Settings ---
# Path of the model to analyze (the Base model is recommended, but a Boost model can be used instead)
TARGET_MODEL_PATH = "/content/base_best_model.pt"
OUTPUT_DIR = "/content/confidence_analysis"
ensure_dir(OUTPUT_DIR)

# IDs
ID_BOOK = 1
ID_CABINET = 5
ID_OBJECT = 6

def analyze_confidence():
    print(f"==========================================")
    print(f" START CONFIDENCE ANALYSIS")
    print(f" Target Model: {os.path.basename(TARGET_MODEL_PATH)}")
    print(f"==========================================")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load Model
    # Note: the model definition is ResNeXtDeepLabV3Plus_OS8 from Cell 2
    # A Boost model uses a different class definition, so the Base model is recommended
    try:
        model = ResNeXtDeepLabV3Plus_OS8(num_classes=13, in_channels=5).to(device)
        state = torch.load(TARGET_MODEL_PATH, map_location=device)
        model.load_state_dict(state)
        print("[INFO] Model loaded successfully.")
    except Exception as e:
        print(f"[ERROR] Failed to load model. Ensure Cell 2 is executed and path is correct.\n{e}")
        return

    model.eval()

    # 2. Setup Validation Data
    val_aug = A.Compose([A.Resize(768, 768)], additional_targets={'depth':'image','height':'image'})
    full_ds = NYUv2Dataset("/content/data", split="train", transform=None)
    indices = list(range(len(full_ds)))
    import random
    random.seed(42)
    random.shuffle(indices)
    val_idx = indices[-int(len(full_ds)*0.1):]
    val_ds = Subset(NYUv2Dataset("/content/data", "train", val_aug), val_idx)
    val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=2)

    # 3. Data Collection Containers
    # Confidence at Book-related pixels (probability of the predicted class)
    conf_stats = {
        "TP_Book": [],        # correct (Pred=Book, GT=Book)
        "FN_Book_as_Cab": [], # missed (Pred=Cabinet, GT=Book)
        "FN_Book_as_Obj": [], # missed (Pred=Object, GT=Book)
        "FP_Cab_as_Book": []  # false alarm (Pred=Book, GT=Cabinet)
    }

    # Distribution of the Book probability itself (softmax output for the Book class, not raw logits)
    book_prob_stats = {
        "TP_Book": [],
        "FN_Book_as_Cab": [],
        "FN_Book_as_Obj": []
    }

    print("[INFO] Collecting probability stats...")

    with torch.no_grad():
        for x, y in tqdm(val_loader):
            x = x.to(device)
            y = y.to(device)

            logits = model(x)
            probs = F.softmax(logits, dim=1) # [B, 13, H, W]
            preds = probs.argmax(dim=1)      # [B, H, W]

            # --- Masking ---
            # 1. Ground Truth Masks
            mask_gt_book = (y == ID_BOOK)
            mask_gt_cab = (y == ID_CABINET)

            # 2. Prediction Masks
            mask_pred_book = (preds == ID_BOOK)
            mask_pred_cab = (preds == ID_CABINET)
            mask_pred_obj = (preds == ID_OBJECT)

            # --- Extract Probabilities ---
            # A. Correct Book (TP)
            mask_tp = mask_gt_book & mask_pred_book
            if mask_tp.any():
                # Confidence of the prediction itself (max softmax probability)
                conf_stats["TP_Book"].extend(probs.max(dim=1).values[mask_tp].cpu().numpy())
                # Probability assigned to the Book class
                book_prob_stats["TP_Book"].extend(probs[:, ID_BOOK][mask_tp].cpu().numpy())

            # B. Missed Book -> Cabinet (FN)
            mask_fn_cab = mask_gt_book & mask_pred_cab
            if mask_fn_cab.any():
                conf_stats["FN_Book_as_Cab"].extend(probs.max(dim=1).values[mask_fn_cab].cpu().numpy())
                book_prob_stats["FN_Book_as_Cab"].extend(probs[:, ID_BOOK][mask_fn_cab].cpu().numpy())

            # C. Missed Book -> Object (FN)
            mask_fn_obj = mask_gt_book & mask_pred_obj
            if mask_fn_obj.any():
                conf_stats["FN_Book_as_Obj"].extend(probs.max(dim=1).values[mask_fn_obj].cpu().numpy())
                book_prob_stats["FN_Book_as_Obj"].extend(probs[:, ID_BOOK][mask_fn_obj].cpu().numpy())

            # D. False Alarm Cabinet -> Book (FP)
            mask_fp_cab = mask_gt_cab & mask_pred_book
            if mask_fp_cab.any():
                conf_stats["FP_Cab_as_Book"].extend(probs.max(dim=1).values[mask_fp_cab].cpu().numpy())

            # Cap memory use: keep only the first 100k values per category
            # (this truncates rather than samples, so pixels from later batches are dropped)
            for k in conf_stats:
                if len(conf_stats[k]) > 100000:
                    conf_stats[k] = conf_stats[k][:100000]
            for k in book_prob_stats:
                if len(book_prob_stats[k]) > 100000:
                    book_prob_stats[k] = book_prob_stats[k][:100000]

    # 4. Visualization
    print("[INFO] Generating Plots...")
    sns.set_style("whitegrid")

    # Plot 1: Prediction Confidence Histogram (how confident the model was in its decision)
    plt.figure(figsize=(12, 6))
    sns.histplot(conf_stats["TP_Book"], color="green", label="Correct Book (TP)", kde=True, bins=50, alpha=0.3)
    sns.histplot(conf_stats["FN_Book_as_Cab"], color="red", label="Missed Book -> Cabinet (FN)", kde=True, bins=50, alpha=0.3)
    sns.histplot(conf_stats["FN_Book_as_Obj"], color="orange", label="Missed Book -> Object (FN)", kde=True, bins=50, alpha=0.3)
    plt.title("Model Confidence Distribution (How sure was the model?)")
    plt.xlabel("Confidence (Probability of Predicted Class)")
    plt.ylabel("Count (Pixels)")
    plt.legend()
    plt.savefig(os.path.join(OUTPUT_DIR, "confidence_histogram.png"))
    plt.close()

    # Plot 2: Book Probability Distribution (how much probability the Book class received)
    plt.figure(figsize=(12, 6))
    sns.histplot(book_prob_stats["TP_Book"], color="green", label="Correct Book (TP)", kde=True, bins=50, alpha=0.3)
    sns.histplot(book_prob_stats["FN_Book_as_Cab"], color="red", label="Missed Book -> Cabinet (FN)", kde=True, bins=50, alpha=0.3)
    plt.axvline(0.5, color='black', linestyle='--', label="Threshold 0.5")
    plt.title("Book Class Probability Distribution (Did it even consider it as Book?)")
    plt.xlabel("Probability assigned to 'Book' class")
    plt.ylabel("Count (Pixels)")
    plt.legend()
    plt.savefig(os.path.join(OUTPUT_DIR, "book_prob_histogram.png"))
    plt.close()

    # 5. Statistical Summary
    summary = {}
    for k, v in conf_stats.items():
        if len(v) > 0:
            summary[k] = {
                "mean_conf": float(np.mean(v)),
                "std_conf": float(np.std(v)),
                "median_conf": float(np.median(v))
            }

    # Book Probability Summary
    summary["Book_Prob_Stats"] = {}
    for k, v in book_prob_stats.items():
        if len(v) > 0:
            summary["Book_Prob_Stats"][k] = {
                "mean_prob": float(np.mean(v)),
                "median_prob": float(np.median(v))
            }

    with open(os.path.join(OUTPUT_DIR, "confidence_stats.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print(f"[DONE] Analysis complete. Saved to {OUTPUT_DIR}")
    print("Check 'confidence_histogram.png' and 'book_prob_histogram.png'")

if __name__ == "__main__":
    analyze_confidence()
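
The cap in the confidence-analysis cell keeps only the first 100,000 values per category, so pixels from later validation images never reach the histograms. If an unbiased subsample is preferred, reservoir sampling is one option. The following is a minimal sketch in plain Python, not the notebook's method: `reservoir_extend`, `seen`, and `cap` are hypothetical names, and the toy integer batches stand in for per-pixel confidence values.

```python
import random


def reservoir_extend(reservoir, new_values, seen, cap, rng):
    """Offer new_values to the reservoir, keeping a uniform sample of at most cap items.

    seen is the number of values offered before this call; the updated count is returned.
    (Hypothetical helper for illustration; not part of the notebook.)
    """
    for v in new_values:
        seen += 1
        if len(reservoir) < cap:
            reservoir.append(v)      # still filling the reservoir
        else:
            j = rng.randrange(seen)  # uniform integer in [0, seen)
            if j < cap:
                reservoir[j] = v     # replace a stored value with probability cap / seen
    return seen


rng = random.Random(0)
sample, seen = [], 0
for start in range(0, 1000, 100):    # ten "batches" of 100 toy values each
    batch = list(range(start, start + 100))
    seen = reservoir_extend(sample, batch, seen, cap=50, rng=rng)

print(len(sample), seen)
```

Every value offered so far has the same chance of being in the reservoir, at the cost of tracking one extra counter per category.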

In [None]:
# =========================
# Cell 6: Visual Error Analysis (Book/Cabinet/Object)
# =========================
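import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A  # needed for the A.Compose / A.Resize call below
import os
import json
import cv2
from tqdm import tqdm
from datetime import datetime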

# --- Constants & Settings ---
VIS_OUTPUT_DIR = "/content/error_analysis_vis"
JSON_OUTPUT_PATH = "/content/error_analysis_log.json"
BASE_MODEL_PATH = "/content/base_best_model.pt"

# ID Definition
ID_BOOK = 1
ID_CABINET = 5
ID_OBJECT = 6

# Color Map for Visualization
# [R, G, B]
COLORS = {
    "Correct_Book": [0, 255, 0],      # Green
    "Err_Book_as_Cab": [255, 0, 0],   # Red (Critical)
    "Err_Book_as_Obj": [255, 165, 0], # Orange
    "Err_Cab_as_Book": [255, 255, 0], # Yellow
    "Err_Obj_as_Book": [255, 0, 255], # Magenta
    "Background": [0, 0, 0]           # Black
}

def ensure_dir(d):
    os.makedirs(d, exist_ok=True)

def apply_color_map(gt, pred, img_shape):
    """
    Generate a Confusion Map based on specific errors.
    """
    H, W = img_shape
    vis_map = np.zeros((H, W, 3), dtype=np.uint8)

    # Masks
    mask_book_gt = (gt == ID_BOOK)
    mask_cab_gt = (gt == ID_CABINET)
    mask_obj_gt = (gt == ID_OBJECT)

    mask_book_pred = (pred == ID_BOOK)
    mask_cab_pred = (pred == ID_CABINET)
    mask_obj_pred = (pred == ID_OBJECT)

    # 1. Correct Book (Green)
    vis_map[mask_book_gt & mask_book_pred] = COLORS["Correct_Book"]

    # 2. Missed Book (Book -> Cabinet) (Red)
    vis_map[mask_book_gt & mask_cab_pred] = COLORS["Err_Book_as_Cab"]

    # 3. Missed Book (Book -> Object) (Orange)
    vis_map[mask_book_gt & mask_obj_pred] = COLORS["Err_Book_as_Obj"]

    # 4. False Book (Cabinet -> Book) (Yellow)
    vis_map[mask_cab_gt & mask_book_pred] = COLORS["Err_Cab_as_Book"]

    # 5. False Book (Object -> Book) (Magenta)
    vis_map[mask_obj_gt & mask_book_pred] = COLORS["Err_Obj_as_Book"]

    return vis_map

def run_visual_analysis():
    print(f"==========================================")
    print(f" START VISUAL ERROR ANALYSIS")
    print(f" Target: Book vs Cabinet vs Object")
    print(f" Output: {VIS_OUTPUT_DIR}")
    print(f"==========================================")

    ensure_dir(VIS_OUTPUT_DIR)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 1. Load Model
    print(f"[INFO] Loading Base Model: {BASE_MODEL_PATH}")
    model = ResNeXtDeepLabV3Plus_OS8(num_classes=13, in_channels=5).to(device)
    try:
        model.load_state_dict(torch.load(BASE_MODEL_PATH, map_location=device))
    except Exception as e:
        print(f"[ERROR] Model load failed: {e}")
        return
    model.eval()

    # 2. Setup Validation Data
    val_aug = A.Compose([A.Resize(768, 768)], additional_targets={'depth':'image','height':'image'})
    # Rebuild the validation split here (seed 42, last 10% of the shuffled indices)
    # rather than relying on val_idx left over from an earlier cell
    full_ds = NYUv2Dataset("/content/data", split="train", transform=None)
    indices = list(range(len(full_ds)))
    import random
    random.seed(42)
    random.shuffle(indices)
    val_idx = indices[-int(len(full_ds)*0.1):]

    val_ds = Subset(NYUv2Dataset("/content/data", "train", val_aug), val_idx)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False) # Batch size 1 for visualization

    analysis_log = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "summary": {"total_book_pixels": 0, "book_as_cabinet": 0, "book_as_object": 0},
        "details": []
    }

    count_vis = 0
    MAX_VIS = 50 # Limit number of output images to avoid overflow

    print(f"[INFO] Analyzing {len(val_loader)} validation images...")

    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(val_loader)):
            x = x.to(device)
            y_np = y.numpy()[0] # (H, W)

            # Check if image contains relevant classes
            has_book = (y_np == ID_BOOK).any()
            has_cabinet = (y_np == ID_CABINET).any()
            has_object = (y_np == ID_OBJECT).any()

            # Skip if no Book is present (we want to see where books are missed)
            if not has_book:
                continue

            # Inference
            logits = model(x)
            pred = logits.argmax(dim=1).cpu().numpy()[0] # (H, W)

            # --- Statistics ---
            # Mask for Ground Truth Book
            mask_book = (y_np == ID_BOOK)
            total_book = mask_book.sum()

            # Errors within Book area
            pred_in_book = pred[mask_book]
            err_as_cab = (pred_in_book == ID_CABINET).sum()
            err_as_obj = (pred_in_book == ID_OBJECT).sum()

            # Accumulate
            analysis_log["summary"]["total_book_pixels"] += int(total_book)
            analysis_log["summary"]["book_as_cabinet"] += int(err_as_cab)
            analysis_log["summary"]["book_as_object"] += int(err_as_obj)

            # --- Visualization (the first MAX_VIS images that meet the criteria below; not ranked by error) ---
            # Calculate "Book IoU" for this image
            intersection = ((pred == ID_BOOK) & (y_np == ID_BOOK)).sum()
            union = ((pred == ID_BOOK) | (y_np == ID_BOOK)).sum()
            book_iou = intersection / (union + 1e-6)

            # Visualize if IoU is low OR significant confusion exists
            if count_vis < MAX_VIS and (book_iou < 0.5 or err_as_cab > 1000):
                # RGB Image for display
                rgb_disp = x[0, :3, :, :].permute(1, 2, 0).cpu().numpy()
                # Denormalize
                mean = np.array([0.485, 0.456, 0.406])
                std = np.array([0.229, 0.224, 0.225])
                rgb_disp = (rgb_disp * std + mean).clip(0, 1)

                # Generate Confusion Map
                conf_map = apply_color_map(y_np, pred, y_np.shape)  # take (H, W) from the label instead of hard-coding it

                # Plot
                fig, axs = plt.subplots(1, 3, figsize=(18, 6))

                axs[0].imshow(rgb_disp)
                axs[0].set_title(f"RGB Input (ID: {i})")
                axs[0].axis("off")

                axs[1].imshow(y_np, cmap='nipy_spectral', vmin=0, vmax=12)
                axs[1].set_title("Ground Truth (All Classes)")
                axs[1].axis("off")

                axs[2].imshow(conf_map)
                axs[2].set_title(f"Confusion Map\nBookIoU: {book_iou:.3f} | Red: Book->Cab")
                axs[2].axis("off")

                # Legend text
                legend = "Green: Correct Book\nRed: Book->Cabinet (Miss)\nOrange: Book->Object (Miss)\nYellow: Cabinet->Book (False)\nMagenta: Object->Book (False)"
                plt.figtext(0.85, 0.5, legend, fontsize=12, bbox={"facecolor":"white", "alpha":0.8, "pad":5})

                save_path = os.path.join(VIS_OUTPUT_DIR, f"err_{i:04d}_iou{int(book_iou*100)}.png")
                plt.savefig(save_path, bbox_inches='tight')
                plt.close()

                count_vis += 1

                # Log details
                analysis_log["details"].append({
                    "image_idx": i,
                    "book_iou": float(book_iou),
                    "total_book_pixels": int(total_book),
                    "misclassified_as_cabinet": int(err_as_cab),
                    "misclassified_as_object": int(err_as_obj)
                })

    # Save JSON
    with open(JSON_OUTPUT_PATH, "w") as f:
        json.dump(analysis_log, f, indent=2)

    print(f"[DONE] Analysis saved to {JSON_OUTPUT_PATH}")
    print(f"[DONE] Images saved to {VIS_OUTPUT_DIR} (Count: {count_vis})")

    # Zip for download
    import shutil
    shutil.make_archive("/content/error_analysis_vis", 'zip', VIS_OUTPUT_DIR)
    print(f"[READY] Download /content/error_analysis_vis.zip to see images.")

if __name__ == "__main__":
    if os.path.exists(BASE_MODEL_PATH):
        run_visual_analysis()
    else:
        print("[ERROR] Base model not found. Please run Base training first.")
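
The per-image Book IoU in the cell above is the intersection of predicted and ground-truth Book pixels divided by their union. A minimal sketch of that calculation on nested lists follows. Unlike the cell, it skips pixels labelled with the ignore index and returns `None` when the class is absent from both maps; both choices are assumptions for illustration, and `class_iou` is a hypothetical helper name.

```python
def class_iou(gt, pred, class_id, ignore_index=255):
    """IoU of one class between two label maps given as nested lists.

    Returns None when neither map contains the class. (Hypothetical helper.)
    """
    inter = 0
    union = 0
    for gt_row, pred_row in zip(gt, pred):
        for g, p in zip(gt_row, pred_row):
            if g == ignore_index:
                continue                 # unlabeled pixel: excluded from the score
            g_hit = g == class_id
            p_hit = p == class_id
            if g_hit and p_hit:
                inter += 1
            if g_hit or p_hit:
                union += 1
    return inter / union if union else None


ID_BOOK = 1
gt = [[1, 1, 5],
      [1, 255, 6]]
pred = [[1, 5, 1],
        [1, 1, 6]]

print(class_iou(gt, pred, ID_BOOK))  # 2 shared Book pixels out of 4 in the union
print(class_iou(gt, pred, 9))        # class 9 appears in neither map
```

Returning `None` rather than 0 for an absent class keeps images without books from dragging an average down.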

In [None]:
# =========================
# Cell 6: Base Model Evaluation & Analysis
# =========================
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import numpy as np
import albumentations as A  # needed for the A.Compose / A.Resize call below
import random
import json
import os
from tqdm import tqdm

# --- Uses constants and classes defined in earlier cells ---
# (ResNeXtDeepLabV3Plus_OS8, NYUv2Dataset, update_cm, compute_metrics_from_cm, etc.)

def analyze_base_model(model_path, dataset_root, batch_size=16):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[EVAL] Device: {device}")
    print(f"[EVAL] Loading Model from: {model_path}")

    # 1. Load the model (same configuration as in training)
    model = ResNeXtDeepLabV3Plus_OS8(num_classes=NUM_CLASSES, in_channels=5).to(device)
    try:
        checkpoint = torch.load(model_path, map_location=device)
        # Strip the leading "module." prefix from state_dict keys (present when saved via DataParallel)
        if list(checkpoint.keys())[0].startswith("module."):
            checkpoint = {(k[len("module."):] if k.startswith("module.") else k): v for k, v in checkpoint.items()}
        model.load_state_dict(checkpoint)
    except Exception as e:
        print(f"[ERROR] Failed to load model: {e}")
        return

    model.eval()

    # 2. Reproduce the dataset split (using exactly the same seed as in training)
    print("[EVAL] Reconstructing Validation Split (Seed=42)...")
    full_dataset = NYUv2Dataset(dataset_root, split="train", transform=None)
    indices = list(range(len(full_dataset)))

    # Important: shuffle with the same random seed as in training
    random.seed(42)
    random.shuffle(indices)

    # Hold out the last 10% as validation data
    n_val = int(len(full_dataset) * 0.1)
    val_idx = indices[-n_val:]

    print(f"  -> Total Valid Samples: {len(val_idx)}")

    # Evaluation augmentation (resize only)
    val_aug = A.Compose([
        A.Resize(768, 768) # same size as in training is recommended
    ], additional_targets={'depth': 'image', 'height': 'image'})

    val_ds = Subset(NYUv2Dataset(dataset_root, "train", val_aug), val_idx)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # 3. Run inference and accumulate the confusion matrix
    cm = torch.zeros((NUM_CLASSES, NUM_CLASSES), dtype=torch.long)

    print("[EVAL] Running Inference...")
    with torch.no_grad():
        for x, y in tqdm(val_loader, desc="Evaluating"):
            x = x.to(device)
            # Inference
            logits = model(x)
            pred = logits.argmax(dim=1).cpu()
            y = y.cpu()

            # Update the confusion matrix
            cm = update_cm(cm, pred, y, NUM_CLASSES, IGNORE_INDEX)

    # 4. Compute metrics
    metrics = compute_metrics_from_cm(cm)
    cm_np = cm.numpy()

    # --- Build the report ---
    print("\n" + "="*60)
    print(f" BASE MODEL EVALUATION REPORT")
    print("="*60)
    print(f"Global mIoU: {metrics['miou']:.5f}\n")

    # (A) Human-readable table
    header = f"{'ID':<3} {'Class':<12} | {'IoU':<8} {'Prec':<8} {'Recall':<8}"
    print(header)
    print("-" * len(header))
    for i, name in enumerate(CLASS_NAMES):
        iou = metrics['class_iou'][i]
        prec = metrics['class_precision'][i]
        rec = metrics['class_recall'][i]
        print(f"{i:<3} {name:<12} | {iou:.4f}   {prec:.4f}   {rec:.4f}")

    print("-" * len(header))

    # (B) JSON dump for AI-assisted analysis
    # Values are cast to float so that NumPy/torch scalars remain JSON-serializable.
    analysis_data = {
        "model": "Base_ResNeXt101_Unfrozen_Partially",
        "miou": float(metrics['miou']),
        "per_class": []
    }
    for i, name in enumerate(CLASS_NAMES):
        analysis_data["per_class"].append({
            "id": i,
            "name": name,
            "iou": float(metrics['class_iou'][i]),
            "precision": float(metrics['class_precision'][i]),
            "recall": float(metrics['class_recall'][i])
        })

    print("\n--- [JSON DATA START] (Copy this for Analysis) ---")
    print(json.dumps(analysis_data, indent=2))
    print("--- [JSON DATA END] ---\n")

    # (C) Detailed misclassification analysis for the Book class (ID:1)
    print("="*30)
    print(" 📖 BOOK CLASS (ID:1) ERROR ANALYSIS")
    print("="*30)

    book_row = cm_np[BOOK_ID] # how pixels whose ground truth is Book were predicted
    total_book_pixels = book_row.sum()

    if total_book_pixels > 0:
        tp = book_row[BOOK_ID]
        print(f"Total 'Book' Pixels in Val Set: {total_book_pixels}")
        print(f"Correctly Predicted (TP):       {tp} ({tp/total_book_pixels*100:.2f}%)")
        print(f"Missed (False Negatives):       {total_book_pixels - tp} ({(total_book_pixels - tp)/total_book_pixels*100:.2f}%)")
        print("\n[Top Confusion: What 'Book' was mistaken for]")

        # Sort classes by misclassified pixel count, descending (skipping Book itself)
        confused_indices = np.argsort(-book_row)
        rank = 1
        for idx in confused_indices:
            if idx == BOOK_ID: continue
            if book_row[idx] == 0: break # stop once the counts reach zero

            ratio = book_row[idx] / total_book_pixels * 100
            print(f"  #{rank} -> {CLASS_NAMES[idx]} (ID:{idx}): {book_row[idx]} pixels ({ratio:.2f}%)")
            rank += 1
            if rank > 5: break
    else:
        print("[WARN] No Book pixels found in validation set.")

    # (D) Save the confusion matrix
    save_path = "base_eval_confusion_matrix.npy"
    np.save(save_path, cm_np)
    print(f"\n[SAVED] Confusion Matrix saved to: {save_path}")

# Path settings for this run
# Set this to the checkpoint trained in Cell 5, or to the path of an uploaded checkpoint
BASE_MODEL_PATH = "/content/base_best_model.pt"
DATASET_ROOT = "/content/data"

if os.path.exists(BASE_MODEL_PATH):
    analyze_base_model(BASE_MODEL_PATH, DATASET_ROOT)
else:
    print(f"Model file not found at {BASE_MODEL_PATH}. Please check the path.")
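
The evaluation cell relies on `update_cm` and `compute_metrics_from_cm`, which are defined in an earlier cell and not shown here. As a rough illustration of what such helpers typically compute, the sketch below accumulates a confusion matrix (rows: ground truth, columns: prediction, matching how `book_row` is read above) and derives per-class IoU, precision, and recall. The function names and the choice to average mIoU only over classes with a defined IoU are assumptions, not the notebook's actual implementation.

```python
def accumulate_confusion(cm, pred, target, ignore_index=255):
    """Add flattened predictions to cm (rows: ground truth, columns: prediction)."""
    for p, t in zip(pred, target):
        if t == ignore_index:
            continue  # ignored pixels do not count toward any class
        cm[t][p] += 1
    return cm


def metrics_from_confusion(cm):
    """Return per-class IoU, precision, recall, and mean IoU from a square matrix."""
    n = len(cm)
    nan = float("nan")
    iou, precision, recall = [], [], []
    for c in range(n):
        tp = cm[c][c]
        fn = sum(cm[c]) - tp                         # true class c, predicted as something else
        fp = sum(cm[r][c] for r in range(n)) - tp    # predicted c, true class differs
        iou.append(tp / (tp + fp + fn) if (tp + fp + fn) else nan)
        precision.append(tp / (tp + fp) if (tp + fp) else nan)
        recall.append(tp / (tp + fn) if (tp + fn) else nan)
    defined = [v for v in iou if v == v]             # NaN != NaN, so undefined classes drop out
    miou = sum(defined) / len(defined) if defined else nan
    return {"class_iou": iou, "class_precision": precision,
            "class_recall": recall, "miou": miou}


cm = [[0] * 3 for _ in range(3)]
target = [0, 0, 1, 1, 2, 255]   # the last pixel is ignored
pred = [0, 1, 1, 1, 0, 2]
cm = accumulate_confusion(cm, pred, target)
m = metrics_from_confusion(cm)
print(cm)
print(m["class_iou"], m["miou"])
```

In this toy case class 2 is never predicted, so its precision is undefined (NaN) while its IoU is 0.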

In [None]:
import json
import matplotlib.pyplot as plt
import glob
import os

def plot_learning_curves(log_file):
    epochs = []
    losses = []
    mious = []
    lrs = []

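    # Read the log file (JSON Lines: one JSON object per line)
    with open(log_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            data = json.loads(line)
            # Skip the header row ("info")
            if "info" in data:
                continue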

            epochs.append(data['epoch'])
            losses.append(data['loss'])
            mious.append(data['val_miou'])
            lrs.append(data['lr'])

    # Create the plot
    fig, ax1 = plt.subplots(figsize=(12, 6))

    # Loss (left axis, red)
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Training Loss', color='tab:red')
    ax1.plot(epochs, losses, color='tab:red', label='Train Loss', marker='.')
    ax1.tick_params(axis='y', labelcolor='tab:red')
    ax1.grid(True, alpha=0.3)

    # mIoU (right axis, blue)
    ax2 = ax1.twinx()
    ax2.set_ylabel('Validation mIoU', color='tab:blue')
    ax2.plot(epochs, mious, color='tab:blue', label='Val mIoU', marker='.')
    ax2.tick_params(axis='y', labelcolor='tab:blue')

    plt.title(f'Learning Curve: {os.path.basename(log_file)}')
    fig.tight_layout()
    plt.show()

    # Learning-rate plot (to check that the poly scheduler is taking effect)
    plt.figure(figsize=(12, 3))
    plt.plot(epochs, lrs, color='orange')
    plt.title("Learning Rate Schedule")
    plt.xlabel("Epoch")
    plt.ylabel("LR")
    plt.grid(True, alpha=0.3)
    plt.show()

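# Automatically find and plot the latest log file
# ("latest" = last in lexicographic order, which assumes timestamped file names)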
log_files = sorted(glob.glob("train_log_detailed_*.jsonl"))
if log_files:
    latest_log = log_files[-1]
    print(f"Plotting: {latest_log}")
    plot_learning_curves(latest_log)
else:
    print("No log file found.")
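
The learning-rate plot above is meant to confirm that a poly schedule is in effect. For reference, a commonly used polynomial decay is `lr = base_lr * (1 - step / total_steps) ** power`. The sketch below evaluates that formula; the training code is not shown in this section, so `power=0.9`, the `min_lr` floor, and the function name `poly_lr` are assumptions rather than the notebook's actual settings.

```python
def poly_lr(base_lr, step, total_steps, power=0.9, min_lr=0.0):
    """Polynomial learning-rate decay from base_lr at step 0 down to min_lr at total_steps."""
    progress = min(max(step / total_steps, 0.0), 1.0)   # clamp to [0, 1]
    return min_lr + (base_lr - min_lr) * (1.0 - progress) ** power


# Learning rate at the start of each of 10 epochs, plus the final value.
schedule = [poly_lr(1e-3, epoch, 10) for epoch in range(11)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

A logged curve that falls smoothly and monotonically toward zero in this way suggests the schedule is being applied; a flat line would point to a scheduler that is never stepped.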

In [None]:
# -*- coding: utf-8 -*-
# NYUv2 val visualization (RGB | GT | Pred | RGB+Error) with FIXED palette -> ZIP
# - NO matplotlib colormap
# - GT and Pred share EXACT same palette mapping (ID -> RGB)
# - ignore_index(255) is shown as black by default

import os
import zipfile
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Subset
from torchvision.datasets import VisionDataset
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
from torchvision import models

# =========================
# EDIT HERE
# =========================
DATASET_ROOT = "/content/data"
IMAGE_SIZE = (512, 512)

MODEL_PATH = "/content/checkpoints/model_20260103062841.pt"  # <- edit only this line
OUTDIR = "val_viz_fixed"
ZIP_PATH = "val_viz_fixed.zip"

MAX_SAVE = 50  # None saves every sample
SEED = 42
TRAIN_VAL_SPLIT = 0.9

NUM_WORKERS = 0
BATCH_SIZE = 1

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
IGNORE_INDEX = 255

print("DEVICE:", DEVICE)
print("MODEL_PATH:", MODEL_PATH)

# =========================
# FIXED PALETTE (ID -> RGB)
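# This table defines what each color means; GT and Pred must both go through it.
# NOTE: the class names in the comments below are unverified. Other cells in this
# notebook treat ID 1 as Book, so check them against the dataset's class list.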
# =========================
PALETTE_13 = np.array([
    [  0,   0,   0],   # 0 wall
    [128,  64,  32],   # 1 floor
    [  0, 128, 255],   # 2 cabinet
    [255,   0, 128],   # 3 bed
    [  0, 255,   0],   # 4 chair
    [255, 128,   0],   # 5 sofa
    [255, 255,   0],   # 6 table
    [128, 128, 128],   # 7 door
    [  0, 255, 255],   # 8 window
    [  0,  64, 128],   # 9 bookshelf
    [255,   0,   0],   # 10 picture
    [128,   0, 255],   # 11 counter
    [ 64, 255, 128],   # 12 desk
], dtype=np.uint8)

# =========================
# Dataset helpers
# =========================
def pil_label_to_long_tensor(lbl_pil: Image.Image) -> torch.Tensor:
    arr = np.array(lbl_pil, dtype=np.int64)
    return torch.from_numpy(arr).long()

def depth_pil_to_tensor_01(depth_pil: Image.Image, size_hw):
    depth_pil = depth_pil.resize((size_hw[1], size_hw[0]), resample=Image.BILINEAR)
    arr = np.array(depth_pil)
    if arr.dtype == np.uint16 or (arr.max() > 255):
        arr = arr.astype(np.float32) / 65535.0
    else:
        arr = arr.astype(np.float32) / 255.0
    arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
    arr = np.clip(arr, 0.0, 1.0)
    return torch.from_numpy(arr).unsqueeze(0)  # [1,H,W]

class NYUv2(VisionDataset):
    def __init__(self, root, split="train", include_depth=True,
                 image_transform=None, depth_transform=None, target_transform=None):
        super().__init__(root)
        assert split in ("train", "test")
        self.root = root
        self.split = split
        self.include_depth = include_depth

        images_dir = os.path.join(self.root, self.split, "image")
        img_names = sorted(os.listdir(images_dir))

        self.images = [os.path.join(images_dir, n) for n in img_names]
        self.depths = [os.path.join(self.root, self.split, "depth", n) for n in img_names]

        if self.split == "train":
            self.targets = [os.path.join(self.root, self.split, "label", n) for n in img_names]
        else:
            self.targets = None

        self.image_transform = image_transform
        self.depth_transform = depth_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert("RGB")
        depth = Image.open(self.depths[idx])

        if self.image_transform is not None:
            image = self.image_transform(image)
        if self.depth_transform is not None:
            depth = self.depth_transform(depth)

        if self.split == "test":
            return (image, depth) if self.include_depth else image

        target = Image.open(self.targets[idx])
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (image, depth, target) if self.include_depth else (image, target)

# =========================
# Model blocks (ResNet50-UNet + optional SE)
# =========================
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch, out_ch, k=3, p=1):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k, padding=p, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        return self.block(x)

class UpBlock(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.conv1 = ConvBNReLU(in_ch + skip_ch, out_ch)
        self.conv2 = ConvBNReLU(out_ch, out_ch)

    def forward(self, x, skip):
        x = nn.functional.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
        x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x

class DepthStem(nn.Module):
    def __init__(self, z_ch=64):
        super().__init__()
        self.net = nn.Sequential(
            ConvBNReLU(1, 32, 3, 1),
            nn.MaxPool2d(2),
            ConvBNReLU(32, 48, 3, 1),
            nn.MaxPool2d(2),
            ConvBNReLU(48, z_ch, 3, 1),
            nn.MaxPool2d(2),
        )
    def forward(self, d):
        return self.net(d)

class SEFromDepth(nn.Module):
    def __init__(self, feat_ch, z_ch, reduction=16, alpha=0.05):
        super().__init__()
        self.alpha = float(alpha)
        hidden = max(32, feat_ch // reduction)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(z_ch, hidden, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, feat_ch, 1, bias=True),
            nn.Sigmoid(),
        )
    def forward(self, F, z):
        g = self.fc(z)  # [B,C,1,1]
        scale = 1.0 + self.alpha * (g - 0.5)
        return F * scale

class ResNet50UNet(nn.Module):
    def __init__(self, num_classes=13, coarse_classes=5, in_channels=4,
                 pretrained=False,
                 use_se=False, se_z_ch=64, se_reduction=16, se_alpha=0.05,
                 se_dec_stages=("up1",)):
        super().__init__()
        self.use_se = use_se
        self.se_dec_stages = tuple(se_dec_stages)

        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None)

        old_conv = resnet.conv1
        if in_channels != old_conv.in_channels:
            new_conv = nn.Conv2d(
                in_channels, old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False
            )
            with torch.no_grad():
                new_conv.weight[:, :3] = old_conv.weight
                mean_w = old_conv.weight.mean(dim=1, keepdim=True)
                for c in range(3, in_channels):
                    new_conv.weight[:, c:c+1] = mean_w
            resnet.conv1 = new_conv

        self.enc0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)  # 1/2
        self.pool = resnet.maxpool                                        # 1/4
        self.enc1 = resnet.layer1                                         # 1/4
        self.enc2 = resnet.layer2                                         # 1/8
        self.enc3 = resnet.layer3                                         # 1/16
        self.enc4 = resnet.layer4                                         # 1/32

        self.center = ConvBNReLU(2048, 1024)
        self.up4 = UpBlock(1024, 1024, 512)
        self.up3 = UpBlock(512,  512,  256)
        self.up2 = UpBlock(256,  256,  128)
        self.up1 = UpBlock(128,  64,   64)

        self.head_feat = ConvBNReLU(64, 64)
        self.head_out  = nn.Conv2d(64, num_classes, kernel_size=1)

        self.coarse_head = nn.Conv2d(256, coarse_classes, kernel_size=1)
        self.boundary_head = nn.Sequential(ConvBNReLU(64, 32), nn.Conv2d(32, 1, 1))

        if self.use_se:
            self.depth_stem = DepthStem(z_ch=se_z_ch)
            stage_to_ch = {"center":1024, "up4":512, "up3":256, "up2":128, "up1":64}
            self.se_dec = nn.ModuleDict()
            for st in self.se_dec_stages:
                self.se_dec[st] = SEFromDepth(stage_to_ch[st], se_z_ch, se_reduction, se_alpha)
        else:
            self.depth_stem = None
            self.se_dec = None

    def _apply_dec_se(self, feat, z, stage_name: str):
        if (not self.use_se) or (self.se_dec is None) or (stage_name not in self.se_dec):
            return feat
        z_ = nn.functional.interpolate(z, size=feat.shape[-2:], mode="bilinear", align_corners=False)
        return self.se_dec[stage_name](feat, z_)

    def forward(self, x, depth=None):
        c1 = self.enc0(x)
        t  = self.pool(c1)
        c2 = self.enc1(t)
        c3 = self.enc2(c2)
        c4 = self.enc3(c3)
        c5 = self.enc4(c4)

        z = self.depth_stem(depth) if self.use_se else None

        x = self.center(c5)
        if z is not None: x = self._apply_dec_se(x, z, "center")
        x = self.up4(x, c4)
        if z is not None: x = self._apply_dec_se(x, z, "up4")
        x = self.up3(x, c3)
        if z is not None: x = self._apply_dec_se(x, z, "up3")

        coarse_logits = self.coarse_head(x)

        x = self.up2(x, c2)
        if z is not None: x = self._apply_dec_se(x, z, "up2")
        x = self.up1(x, c1)
        if z is not None: x = self._apply_dec_se(x, z, "up1")

        x = nn.functional.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
        feat = self.head_feat(x)
        fine_logits = self.head_out(feat)
        boundary_logit = self.boundary_head(x)
        return fine_logits, coarse_logits, boundary_logit

# =========================
# FIXED colorization (ID -> RGB)
# =========================
def id_to_rgb(mask_hw: np.ndarray, ignore_index=255, ignore_rgb=(0,0,0)) -> np.ndarray:
    """
    mask_hw: [H,W] int (0..12 or 255)
    returns: [H,W,3] uint8
    """
    out = np.zeros((mask_hw.shape[0], mask_hw.shape[1], 3), dtype=np.uint8)
    ign = (mask_hw == ignore_index)
    valid = ~ign
    m = mask_hw.copy()
    m[ign] = 0
    m = np.clip(m, 0, 12)
    out[valid] = PALETTE_13[m[valid]]
    out[ign] = np.array(ignore_rgb, dtype=np.uint8)
    return out

def denorm_rgb(rgb_tensor: torch.Tensor) -> Image.Image:
    mean = torch.tensor([0.485, 0.456, 0.406], device=rgb_tensor.device).view(3,1,1)
    std  = torch.tensor([0.229, 0.224, 0.225], device=rgb_tensor.device).view(3,1,1)
    x = (rgb_tensor * std + mean).clamp(0, 1)
    x = (x * 255).byte().permute(1,2,0).cpu().numpy()
    return Image.fromarray(x, mode="RGB")

def overlay_error(rgb_pil: Image.Image, gt_hw: np.ndarray, pred_hw: np.ndarray, ignore_index=255) -> Image.Image:
    rgb = np.array(rgb_pil).astype(np.float32)
    err = (gt_hw != ignore_index) & (gt_hw != pred_hw)
    if err.any():
        overlay = rgb.copy()
        overlay[err] = 0.6 * overlay[err] + 0.4 * np.array([255, 0, 0], dtype=np.float32)
        overlay = np.clip(overlay, 0, 255).astype(np.uint8)
        return Image.fromarray(overlay, mode="RGB")
    return rgb_pil

def stack_h(images):
    widths = [im.width for im in images]
    heights = [im.height for im in images]
    out = Image.new("RGB", (sum(widths), max(heights)))
    x = 0
    for im in images:
        out.paste(im, (x, 0))
        x += im.width
    return out

# =========================
# Transforms (same as eval)
# =========================
eval_image_transform = Compose([
    Resize(IMAGE_SIZE, interpolation=InterpolationMode.BILINEAR),
    ToTensor(),
    Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)),
])

def depth_transform(pil):
    return depth_pil_to_tensor_01(pil, IMAGE_SIZE)

def target_transform(pil):
    pil = pil.resize((IMAGE_SIZE[1], IMAGE_SIZE[0]), resample=Image.NEAREST)
    return pil_label_to_long_tensor(pil)

# =========================
# Build val split (same indices rule)
# =========================
split_base = NYUv2(
    root=DATASET_ROOT,
    split="train",
    include_depth=True,
    image_transform=eval_image_transform,
    depth_transform=depth_transform,
    target_transform=target_transform,
)

n_total = len(split_base)
n_train = int(n_total * TRAIN_VAL_SPLIT)
n_val = n_total - n_train
g = torch.Generator().manual_seed(SEED)
_, val_subset = random_split(split_base, [n_train, n_val], generator=g)
val_ds = Subset(split_base, val_subset.indices)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

print(f"Val samples: {len(val_ds)}")

# =========================
# Load checkpoint + auto-detect SE
# =========================
ckpt = torch.load(MODEL_PATH, map_location=DEVICE)
if isinstance(ckpt, dict) and "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
    sd = ckpt["state_dict"]
else:
    sd = ckpt

has_se = any(k.startswith("depth_stem.") or k.startswith("se_dec.") for k in sd.keys())
print("Detected SE in checkpoint:", has_se)

model = ResNet50UNet(
    num_classes=13,
    coarse_classes=5,
    in_channels=4,
    pretrained=False,
    use_se=has_se,
    se_z_ch=64,
    se_reduction=16,
    se_alpha=0.05,
    se_dec_stages=("up1",),
).to(DEVICE)

model.load_state_dict(sd, strict=True)
model.eval()

# =========================
# Generate panels with FIXED palette
# =========================
os.makedirs(OUTDIR, exist_ok=True)

saved = 0
with torch.no_grad():
    for i, (rgb, depth, label) in enumerate(tqdm(val_loader, desc="Saving fixed-color panels")):
        rgb = rgb.to(DEVICE, non_blocking=True)
        depth = depth.to(DEVICE, non_blocking=True)
        label = label.to(DEVICE, non_blocking=True)

        x = torch.cat([rgb.float(), depth.float()], dim=1)
        fine_logits, _, _ = model(x, depth=depth.float() if has_se else None)
        pred = fine_logits.argmax(dim=1)

        rgb_pil = denorm_rgb(rgb[0])

        gt_hw = label[0].cpu().numpy().astype(np.int32)
        pred_hw = pred[0].cpu().numpy().astype(np.int32)

        gt_rgb = id_to_rgb(gt_hw, ignore_index=IGNORE_INDEX, ignore_rgb=(0,0,0))
        pr_rgb = id_to_rgb(pred_hw, ignore_index=IGNORE_INDEX, ignore_rgb=(0,0,0))

        gt_pil = Image.fromarray(gt_rgb, mode="RGB")
        pr_pil = Image.fromarray(pr_rgb, mode="RGB")
        err_pil = overlay_error(rgb_pil, gt_hw, pred_hw, ignore_index=IGNORE_INDEX)

        panel = stack_h([rgb_pil, gt_pil, pr_pil, err_pil])
        panel.save(os.path.join(OUTDIR, f"{i:06d}.png"), optimize=True)

        saved += 1
        if MAX_SAVE is not None and saved >= MAX_SAVE:
            break

print(f"Saved: {saved} images -> {OUTDIR}")

# =========================
# Zip
# =========================
with zipfile.ZipFile(ZIP_PATH, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zf:
    for fn in sorted(os.listdir(OUTDIR)):
        if fn.lower().endswith(".png"):
            zf.write(os.path.join(OUTDIR, fn), arcname=fn)

print(f"ZIP created: {ZIP_PATH}")
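
`SEFromDepth` in the cell above rescales decoder features by `1.0 + alpha * (g - 0.5)`, where `g` is a sigmoid output in (0, 1). With `alpha = 0.05` the multiplier therefore stays within roughly 0.975 to 1.025, so the depth branch can only nudge features by about 2.5% in either direction. The short sketch below reproduces that arithmetic for a single scalar logit; `se_scale` is a hypothetical stand-in for the module's forward pass, not the module itself.

```python
import math


def se_scale(logit, alpha=0.05):
    """Multiplier applied to a feature channel for a given pre-sigmoid gate value."""
    gate = 1.0 / (1.0 + math.exp(-logit))   # sigmoid, in (0, 1)
    return 1.0 + alpha * (gate - 0.5)


for logit in (-20.0, -2.0, 0.0, 2.0, 20.0):
    print(f"logit {logit:6.1f} -> scale {se_scale(logit):.4f}")
```

A gate of exactly 0.5 (logit 0) leaves the features unchanged, which gives the block a near-identity starting point.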


In [None]:
%ls -l checkpoints

In [None]:
# ------------------
#    Evaluation (DeepLabV3-ResNet50 / ResNet backbone)
# ------------------

ckpt = torch.load(final_path, map_location=device)

# Also handle checkpoints saved during training as a dict containing "model_state"
if isinstance(ckpt, dict) and "model_state" in ckpt:
    model.load_state_dict(ckpt["model_state"])
else:
    model.load_state_dict(ckpt)

model.eval()

predictions = []

with torch.no_grad():
    print("Generating predictions...")
    for image, depth in tqdm(test_data):
        image = image.to(device, non_blocking=True)
        depth = depth.to(device, non_blocking=True)

        # Safety check: if depth accidentally has 3 channels, keep only the first
        if depth.dim() == 4 and depth.size(1) == 3:
            depth = depth[:, :1, :, :]

        x = torch.cat((image, depth), dim=1)   # [B,4,H,W]

        # DeepLabV3 returns a dict, so use its "out" entry
        out = model(x)["out"]                  # [B,num_classes,H,W]
        pred = out.argmax(dim=1)               # [B,H,W]

        predictions.append(pred.cpu())

predictions = torch.cat(predictions, dim=0).numpy()
np.save("submission.npy", predictions)
print("Predictions saved to submission.npy")


"""## How to submit

Zip the following three items and submit them via "Final Assignment (Segmentation)" on Omnicampus.

- `submission.npy`
- The weights used for testing, such as `model.pt` or `model_best.pt` (`.pt` extension only)
- This Colab notebook
"""

from zipfile import ZipFile, ZIP_DEFLATED

notebook_path = "/content/drive/MyDrive/code.ipynb"

with ZipFile("submission.zip",
             mode="w",
             compression=ZIP_DEFLATED,
             compresslevel=9) as zf:
    zf.write("submission.npy")
    zf.write(model_path)
#    zf.write(notebook_path,
#             arcname="code.ipynb")
