1. Dataset 로드 확인

In [None]:
"""
Quick sanity-check for DataModule
---------------------------------
Run:  python motor_det/tests/test_loader.py
"""

from motor_det.data.module import MotorDataModule

# Local dataset root used by the DataModule.
DATA_ROOT = r"D:\project\Kaggle\BYU\byu-motor\data"   # ← adjust to match your project path
dm = MotorDataModule(data_root=DATA_ROOT, fold=0, batch_size=2)
dm.setup()

# Pull a single training batch and print shape/dtype information for the
# "image", "cls" and "offset" entries (plus the sum of the "cls" target).
batch = next(iter(dm.train_dataloader()))
print("image :", batch["image"].shape, batch["image"].dtype)
print("cls   :", batch["cls"].shape, batch["cls"].sum())
print("offset:", batch["offset"].shape)


2. model 로직 확인

In [None]:
import torch
from motor_det.model.net import MotorDetNet
# Push one random single-channel volume through the network and report the
# shapes of the two output heads.
net = MotorDetNet()
dummy_volume = torch.randn(1, 1, 96, 128, 128)
y = net(dummy_volume)
cls_shape, offset_shape = y["cls"].shape, y["offset"].shape
print(cls_shape, offset_shape)
# Expected (per the original note): (1,1,48,64,64) (1,3,48,64,64)

3. loss 구현 확인(dummy test)

In [None]:
import torch
from motor_det.loss.losses import motor_detection_loss

# --------------------------------------------------
# Build a synthetic ground truth with a single positive voxel and feed a
# slightly perturbed copy of it to the loss as the "prediction".
B, D, H, W = 1, 48, 64, 64        # size of the stride-2 output grid
gt_cls = torch.zeros((B, 1, D, H, W))
gt_cls[0, 0, 10, 20, 30] = 1.0    # one motor

gt_off = torch.zeros((B, 3, D, H, W))
gt_off[:, :, 10, 20, 30] = torch.tensor([0.3, -0.2, 0.1])

batch = {"cls": gt_cls, "offset": gt_off}

# Prediction: ground truth plus a little noise.
# NOTE(review): pred_cls holds probability-like values (0.95 / 0.05), while the
# validation code later in this file names the "cls" output `logits` — confirm
# which of the two motor_detection_loss expects.
pred_cls = gt_cls * 0.9 + 0.05            # 0.95 at the motor voxel / 0.05 elsewhere
pred_off = gt_off + torch.randn_like(gt_off) * 0.05

pred = {"cls": pred_cls, "offset": pred_off}

# The loss returns a scalar tensor and a collection of log values.
loss, logs = motor_detection_loss(pred, batch)
print("loss =", loss.item())
print("logs =", logs)


4. Train 확인(1step)

In [None]:
# --- quick_sanity_check.ipynb ------------------------------------
import torch, lightning as L
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.utilities.rank_zero import rank_zero_info

from motor_det.data.module import MotorDataModule
from motor_det.engine.lit_module import LitMotorDet
from motor_det.callbacks.val_console_logger import ValConsoleLoggerEveryN   # ★ NEW

# 1) Seed & matmul precision — fix the seed before any data/model objects are created.
L.seed_everything(42)
torch.set_float32_matmul_precision("high")

# 2) DataModule ───────────────────────────────────────────────
DATA_ROOT = r"D:\project\Kaggle\BYU\byu-motor\data"
dm = MotorDataModule(
    data_root=DATA_ROOT,
    fold=0,
    batch_size=1,
    num_workers=12,
    persistent_workers=True,
)
dm.setup()

# 3) Model ────────────────────────────────────────────────────
# total_steps=1 because this cell is only a smoke test of the training loop.
model = LitMotorDet(lr=1e-3, total_steps=1)

# 4) Train-loop console logger (train) ────────────────────────
class TrainConsoleLoggerEveryN(L.Callback):
    """Emit the logged ``train/`` metrics to the console every N training steps.

    Args:
        every_n_steps: Interval, in global steps, between console messages.
    """

    def __init__(self, every_n_steps: int = 10):
        self.every_n_steps = every_n_steps

    def on_train_batch_end(self, trainer, *_):
        """Print a one-line metric summary when ``global_step + 1`` hits the interval."""
        step = trainer.global_step
        if (step + 1) % self.every_n_steps != 0:
            return

        # Only metrics logged under the "train/" prefix are reported.
        parts = []
        for name, value in trainer.logged_metrics.items():
            if name.startswith("train/"):
                parts.append(f"{name}={value:.4f}")

        rank_zero_info(f"[TRAIN] step {step:>4}: " + ", ".join(parts))

# 5) Trainer ──────────────────────────────────────────────────
# Smoke-test configuration: one epoch over a small slice of the train/val
# batches, with checkpointing and the model summary switched off.
trainer = L.Trainer(
    accelerator="gpu", devices=1, precision="16-mixed",
    max_epochs=1,
    limit_train_batches=0.005,
    limit_val_batches=0.01,
    log_every_n_steps=10,
    callbacks=[
        RichProgressBar(refresh_rate=1),
        TrainConsoleLoggerEveryN(10),     # ← train log
        ValConsoleLoggerEveryN(5),        # ← NEW: val log (batch 5, 10, …)
    ],
    logger=CSVLogger("logs", "debug"),
    enable_checkpointing=False,
    enable_model_summary=False,
)

trainer.fit(model, datamodule=dm)


5. metric 검증

In [None]:
from motor_det.metrics.det_metric import fbeta_score
import torch
# Two predicted and two ground-truth points; the first element of the value
# returned by fbeta_score is printed.
pred = torch.tensor([[0.,0.,0.], [10.,0.,0.]])
gt   = torch.tensor([[0.,0.,0.], [5.,0.,0.]])
print(fbeta_score(pred, gt, beta=2)[0])   # → 1.0 expected (per the original note)

6. decoder 검증

In [None]:
import torch
from motor_det.postprocess.decoder import decode_with_nms
# A 4x4x4 map with a single strong response at index (2, 2, 2) and zero offsets.
log = torch.zeros(1,1,4,4,4); log[0,0,2,2,2] = 10
off = torch.zeros(1,3,4,4,4)
print(decode_with_nms(log, off, stride=2)[0])
# Expected (per the original note): tensor([[5., 5., 5.]])  ≈ ( (2+0.5)*2 , ... )


7. 퀵 학습 후 단일 tomogram 학습 검증

In [None]:
# --- quick_train.ipynb : CELL 1 ------------------------------------
import lightning as L, torch, time, os
from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar
from motor_det.data.module   import MotorDataModule
from motor_det.engine.lit_module import LitMotorDet

# Paths and settings -----------------------------------------------------------
DATA_ROOT   = r"D:\project\Kaggle\BYU\byu-motor\data"
RUNS_DIR    = "runs/quick"                 # folder the checkpoint is written to
FOLD        = 0
os.makedirs(RUNS_DIR, exist_ok=True)

L.seed_everything(42)
torch.set_float32_matmul_precision("high")

# DataModule (the quick 5 % train / 1 % val subsampling is configured on the
# Trainer below via limit_train_batches / limit_val_batches) --------------------
dm = MotorDataModule(
    data_root=DATA_ROOT,
    fold=FOLD,
    batch_size=1,
    num_workers=12,
    persistent_workers=True,
)
dm.setup()

# Model -------------------------------------------------------------------------
model = LitMotorDet(lr=2e-4, total_steps=1_000)

# Callback: keep only the single checkpoint with the best val/f2.
ckpt_cb = ModelCheckpoint(
    dirpath=RUNS_DIR,
    filename="best",
    monitor="val/f2",
    mode="max",
    save_top_k=1,
)
trainer = L.Trainer(
    accelerator="gpu", devices=1, precision="16-mixed",
    max_epochs=1,
    limit_train_batches=0.05,      # only 5 % of the data
    limit_val_batches=0.01,
    check_val_every_n_epoch=1,
    callbacks=[RichProgressBar(), ckpt_cb],
    log_every_n_steps=20,
)
# Time the whole fit and report where the best checkpoint was saved.
start = time.time()
trainer.fit(model, dm)
print(f"✔︎ training done in {time.time()-start:0.1f}s")
print("saved:", ckpt_cb.best_model_path)


In [None]:
# --- quick_train.ipynb : CELL 2 (수정본) ----------------------------
import time
import torch
import pandas as pd
from pathlib import Path
from lightning.pytorch.utilities.rank_zero import rank_zero_info

from motor_det.model.net    import MotorDetNet
from motor_det.engine.infer import infer_single_tomo
# voxel_spacing_map는 여기선 사용하지 않음
# from motor_det.utils.voxel import voxel_spacing_map

# ─── Paths & hyper-parameters ----------------------------------------
DATA_ROOT = r"D:\project\Kaggle\BYU\byu-motor\data"
CKPT_PATH = Path("runs/quick/best.ckpt")      # checkpoint saved by cell 1
TOMO_ID   = "tomo_00e047"                     # ID of the tomogram to test
DEVICE    = torch.device("cuda:0")

# Window and step passed to infer_single_tomo below.
WINDOW    = (192, 128, 128)
STRIDE    = (96,  64,  64)
STRIDE_H  = 2               # output stride of the model head
SIGMA_Å   = 60.0
SCORE_THR = 0.5
BATCH_SZ  = 1
NUM_W     = 4
# ────────────────────────────────────────────────────────────────────

# 1) 모델 로드 & 가중치 불러오기 --------------------------------------
def strip_prefix_state_dict(sd, prefix="net."):
    """Return a copy of ``sd`` whose keys have a leading ``prefix`` removed.

    Keys that do not start with ``prefix`` are kept unchanged; values are
    passed through untouched and the original ordering is preserved.

    Args:
        sd: Mapping of parameter names to values (e.g. a checkpoint state dict).
        prefix: Leading string to drop from matching keys.

    Returns:
        An ``OrderedDict`` with the renamed keys.
    """
    from collections import OrderedDict

    cut = len(prefix)
    stripped = OrderedDict()
    for key, value in sd.items():
        if key.startswith(prefix):
            key = key[cut:]
        stripped[key] = value
    return stripped

# Build the bare network in eval mode on the target device, then load the
# weights from the checkpoint.
model = MotorDetNet().to(DEVICE).eval()
# NOTE(review): newer PyTorch releases changed the default of torch.load's
# `weights_only` argument — confirm this checkpoint still loads on the installed version.
ckpt   = torch.load(CKPT_PATH, map_location="cpu")
# Use the "state_dict" entry when present, otherwise treat the file as a raw state dict.
state  = ckpt.get("state_dict", ckpt)
# Drop the "net." prefix from the keys so they match the bare MotorDetNet.
state  = strip_prefix_state_dict(state, prefix="net.")
model.load_state_dict(state, strict=True)
rank_zero_info("✓ weights loaded")

# 2) Inference on a single tomogram ---------------------------------
zarr_path = Path(DATA_ROOT) / "processed" / "zarr" / "test" / f"{TOMO_ID}.zarr"

# Time the inference call.
t0 = time.time()
arr = infer_single_tomo(
    zarr_path   = zarr_path,
    net         = model,
    window      = WINDOW,
    stride      = STRIDE,
    stride_head = STRIDE_H,
    spacing_Å   = 15.0,      # default voxel spacing of the test data
    batch_size  = BATCH_SZ,
    num_workers = NUM_W,
    prob_thr    = SCORE_THR,
    sigma_Å     = SIGMA_Å,
    iou_thr     = 0.25,
    device      = DEVICE,
)
rank_zero_info(f"Inference done in {time.time() - t0:0.1f}s")

# 3) Convert the numpy result into a one-row DataFrame ---------------
# arr.shape is expected to be (0, 3) or (1, 3); only the first row is used.
# An empty result is encoded as -1 on every axis.
columns = ["tomo_id", "Motor axis 0", "Motor axis 1", "Motor axis 2"]
if arr.size == 0:
    row = [TOMO_ID, -1, -1, -1]
else:
    x, y, z = arr[0]
    row = [TOMO_ID, x, y, z]
df_pred = pd.DataFrame([row], columns=columns)

display(df_pred)

# 4) Save as CSV ------------------------------------------------------
df_pred.to_csv("submission_single.csv", index=False)
print("CSV saved -> submission_single.csv")


8. 본격 성능 검증

In [None]:
# ─── quick_train_val.ipynb: CELL 1 (독립 실행용) ─────────────────────────────────

import os
import time
from pathlib import Path
import lightning as L
import torch
import torch.nn.functional as F

from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch import Trainer
from torch import nn, Tensor

from motor_det.data.module       import MotorDataModule
from motor_det.model.net         import MotorDetNet
from motor_det.loss.losses       import motor_detection_loss
from motor_det.postprocess.decoder import decode_with_nms
from motor_det.metrics.det_metric  import fbeta_score
from motor_det.optim.cosine_with_warmup import WarmupCosineScheduler

# ─── Settings ─────────────────────────────────────────────────────────────────────
DATA_ROOT  = r"D:\project\Kaggle\BYU\byu-motor\data"
RUNS_DIR   = "runs/quick_val"
# Create the run directory up front; an existing one is not an error.
os.makedirs(RUNS_DIR, exist_ok=True)

L.seed_everything(42)
torch.set_float32_matmul_precision("high")

# ─── LightningModule 정의 (TP/FP/FN까지 로깅) ────────────────────────────────────
class LitMotorDet(L.LightningModule):
    """Lightning wrapper around ``MotorDetNet`` that logs loss terms and detection metrics.

    Training logs the entries returned by ``motor_detection_loss`` under
    ``train/``; validation logs the same entries under ``val/`` and
    additionally decodes the predictions to log F-beta (beta=2), precision,
    recall and TP/FP/FN counts.

    Args:
        lr: Peak learning rate for AdamW.
        weight_decay: AdamW weight decay.
        warmup_steps: Number of warm-up steps for the scheduler.
        total_steps: Total number of scheduler steps.
    """

    def __init__(self, lr=3e-4, weight_decay=1e-4, warmup_steps=500, total_steps=30_000):
        super().__init__()
        # Makes the constructor arguments available as self.hparams.*
        self.save_hyperparameters()
        self.net = MotorDetNet()

    def forward(self, x: Tensor) -> dict[str, Tensor]:
        """Run the wrapped network on ``x`` and return its output dict."""
        return self.net(x)

    def _log_loss(self, preds, batch, stage: str):
        """Compute the detection loss for ``preds`` and log its terms under ``stage/``."""
        loss, logs = motor_detection_loss(preds, batch)
        logs = {f"{stage}/{k}": v for k, v in logs.items()}
        self.log_dict(logs, on_step=True, on_epoch=True, prog_bar=(stage == "train"))
        return loss

    def _shared_step(self, batch, stage: str):
        """Forward ``batch["image"]``, log the loss terms, and return the loss."""
        return self._log_loss(self(batch["image"]), batch, stage)

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and detection metrics for one batch.

        Returns:
            Dict with keys ``f2``, ``prec``, ``rec``, ``tp``, ``fp``, ``fn``.
        """
        # 1) One forward pass, shared by the loss and the decoder. Running
        #    the network a second time on the same batch would only double
        #    the validation cost.
        preds = self(batch["image"])
        self._log_loss(preds, batch, "val")

        # 2) Decode + F2 / TP / FP / FN.
        #    Index 0 is taken from both the decoder output and
        #    batch["centers_Å"], so only one sample per batch is scored
        #    (the accompanying cell uses batch_size=1).
        centers_pred = decode_with_nms(
            preds["cls"], preds["offset"],
            stride=2, prob_thr=0.5, sigma=60.0, iou_thr=0.25
        )[0]

        gt_centers = batch["centers_Å"][0]
        f2, prec, rec, tp, fp, fn = fbeta_score(
            centers_pred, gt_centers,
            beta=2, dist_thr=1000.0
        )

        # 3) Logged with on_epoch=True so the values are aggregated per epoch.
        metrics = {"f2": f2, "prec": prec, "rec": rec, "tp": tp, "fp": fp, "fn": fn}
        self.log_dict(
            {f"val/{name}": value for name, value in metrics.items()},
            on_step=False, on_epoch=True, prog_bar=True,
        )

        return metrics

    def configure_optimizers(self):
        """Build AdamW plus a warm-up/cosine scheduler that is stepped every training step."""
        opt = torch.optim.AdamW(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        sched = WarmupCosineScheduler(
            optimizer=opt,
            warmup_steps=self.hparams.warmup_steps,
            total_steps=self.hparams.total_steps,
            warmup_learning_rate=self.hparams.lr * 0.1,
        )
        return {"optimizer": opt,
                "lr_scheduler": {"scheduler": sched, "interval": "step"}}

# ─── DataModule setup ────────────────────────────────────────────────────────
dm = MotorDataModule(
    data_root=DATA_ROOT,
    fold=0,
    batch_size=1,
    num_workers=12,
    # only if the DataModule supports persistent_workers
    persistent_workers=True,
)
dm.setup()

# ─── Callbacks and Trainer ───────────────────────────────────────────────────
# Keep only the single checkpoint with the best val/f2.
ckpt_cb = ModelCheckpoint(
    dirpath=RUNS_DIR,
    filename="best",
    monitor="val/f2",
    mode="max",
    save_top_k=1,
)
# Note: this is a CSV logger even though its run name is "tensorboard".
csv_logger = CSVLogger(RUNS_DIR, name="tensorboard")

trainer = Trainer(
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    max_epochs=10,               # total number of epochs
    val_check_interval=33152,       # run validation every 33152 train steps
    limit_val_batches=0.1,
    log_every_n_steps=20,
    callbacks=[RichProgressBar(refresh_rate=20), ckpt_cb],
    logger=csv_logger,
)

# ─── Run training + validation ───────────────────────────────────────────────
start = time.time()
trainer.fit(LitMotorDet(), dm)
print(f"✔︎ 전체 학습+검증 완료  in {time.time() - start:.1f}s")
print("Best checkpoint:", ckpt_cb.best_model_path)


Seed set to 42
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\shaun\anaconda3\envs\byu-motor\lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:654: Checkpoint directory D:\project\Kaggle\BYU\byu-motor\motor_det\tests\runs\quick_val exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()