# %% [markdown]
# # 椎体骨折検出モデル（陽性オンライン拡張＋不均衡サンプラ対応）
# * TorchVision / MONAI 不要の軽量実装
# * 陽性クラスにだけ強めのオンライン Data Augmentation
# * `WeightedRandomSampler` で 1:1 に近いバランスで学習
# 
# **事前準備**
# * `pip install nibabel pandas scikit-learn matplotlib wandb`
# * W&B の API キーは `wandb login` で保存 or 環境変数に設定
# * CSV は以下カラムを想定  
#   * `FullPath` : NIfTI ファイルパス  
#   * `Fracture_Label` : 1 (骨折) / 0 (健常)

In [1]:
import os, copy, random, numpy as np, pandas as pd
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import models, transforms as T
import torch.nn.functional as F
import nibabel as nib
from PIL import Image
from tqdm.notebook import tqdm
from sklearn.metrics import precision_recall_curve, auc
import wandb
from datetime import datetime
from itertools import product

WANDB_PROJECT = "vertebrae-online_aug_axial_learning_3-3(α=0.8,gamma=1.0)"
wandb.login()  

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33myuya00[0m ([33myuya00-university-of-hyogo[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# 学習モデル

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cuda


In [3]:
#2. FocalLoss：alpha自動設定
############################
class FocalLoss(nn.Module):
    def __init__(self, pos_weight, gamma=1.0, reduction="mean"):
        super().__init__()
        self.alpha = pos_weight         # pos:neg 比から計算
        self.gamma, self.reduction = gamma, reduction

    def forward(self, logits, targets):
        ce = F.binary_cross_entropy_with_logits(logits, targets,
                                                reduction="none")
        p_t = torch.exp(-ce)            # = σ(logits) if y=1 else 1-σ
        alpha_t = self.alpha*targets + (1-self.alpha)*(1-targets)
        loss = alpha_t * (1 - p_t) ** self.gamma * ce
        return loss.mean() if self.reduction == "mean" else loss.sum()

In [4]:
def build_transforms():
    resize = T.Resize((224, 224))
    to_tensor = T.ToTensor()
    norm = T.Normalize([0.5], [0.5])

    # --- 陽性 ---
    geom = T.RandomApply([
            T.RandomRotation(20),
            T.RandomAffine(
                degrees=0,
                translate=(0.07, 0.07),
                scale=(0.9, 1.1),
                shear=5)], p=0.7)

    intensity = T.RandomApply([
            T.ColorJitter(brightness=0.1, contrast=0.15)], p=0.5)

    # ★ ノイズは Tensor 化のあと ★
    noise = T.RandomApply([
            lambda x: x + 0.05 * torch.randn_like(x)], p=0.3)

    pos_tf = T.Compose([resize, geom, intensity,
                        to_tensor,       # Tensor 変換
                        noise,           # ← 順序をここに
                        norm])

    # --- 陰性 ---
    neg_tf = T.Compose([resize,
                        T.RandomHorizontalFlip(0.2),
                        to_tensor,
                        norm])

    # --- 検証 ---
    val_tf = T.Compose([resize, to_tensor, norm])
    return pos_tf, neg_tf, val_tf

# 呼び出し例
pos_tf, neg_tf, val_tf = build_transforms()

In [5]:
# ## 2. Dataset
# %%
class CTDataset(Dataset):
    def __init__(self, csv_path, pos_tf=None, neg_tf=None):
        self.data   = pd.read_csv(csv_path)
        self.pos_tf = pos_tf
        self.neg_tf = neg_tf

    def __len__(self):
        return len(self.data)

    def _load_nifti(self, path):
        img = nib.load(path).get_fdata()
        if img.ndim == 3:
            img = img[:, :, 0]      # 1 枚目スライス
        img_min = 100
        img_max = 2000
        img = np.clip(img, img_min, img_max)  # 100-2000
        img = (img - img_min) / (img_max - img_min)                # 0-1
        img = (img * 255).astype(np.uint8)
        return Image.fromarray(img).convert("L")

    def __getitem__(self, idx):
        r      = self.data.iloc[idx]
        label  = int(r["Fracture_Label"])
        img    = self._load_nifti(r["FullPath"])
        img_tf = self.pos_tf if label == 1 else self.neg_tf
        if img_tf:
            img = img_tf(img)
        return img, float(label)

# %% [markdown]
# ## 3. モデル
# %%
class ModifiedResNet(nn.Module):
    def __init__(self, drop=0.5):
        super().__init__()
        self.backbone = models.resnet18(weights="IMAGENET1K_V1")
        self.backbone.conv1 = nn.Conv2d(1, 64, 7, 2, 3, bias=False)
        n = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(drop), nn.Linear(n, 1))   #nn.Sigmoid()はfocal_lossのとき外す
        

    def forward(self, x):
        return self.backbone(x).squeeze(1)

# %% [markdown]
# ## 4. 評価関数
# %%
def evaluate(model, loader, criterion):
    model.eval()
    loss_sum, y_true, y_pred = 0.0, [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device, torch.float32), y.to(device, torch.float32)
            out  = model(x)
            loss_sum += criterion(out, y).item()
            y_true.extend(y.cpu().numpy())
            #y_pred.extend(out.cpu().numpy()) 
            y_pred.extend(torch.sigmoid(out).cpu().numpy())   # ← logits → σ に変換 #focal_lossのときsigmoidは外す
    val_loss = loss_sum / len(loader)
    prec, rec, thr = precision_recall_curve(y_true, y_pred)
    prauc = auc(rec, prec)
    f1 = 2 * prec * rec / (prec + rec + 1e-8)
    idx = f1.argmax()
    best = dict(th=thr[idx] if idx < len(thr) else 0.5,
                prec=prec[idx], rec=rec[idx])
    return val_loss, prauc, best


In [6]:
# -------------------- 5. 早期停止付き学習ループ --------------------
def train_model(model, tr_loader, val_loader,
                criterion, optimizer, #scheduler,
                epochs=30, early_stop_patience=5,
                log_wandb=True):
    
     # ❷ One-CycleLR をここで生成
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=optimizer.param_groups[0]['lr'],  # (=1e-3)
        steps_per_epoch=len(tr_loader),
        epochs=epochs,
        pct_start=0.1           # 最初の 10 % がウォームアップ
    )
    
    if log_wandb:
        run = wandb.init(project=WANDB_PROJECT, reinit=True,
                         config=dict(epochs=epochs,
                                     lr=optimizer.param_groups[0]['lr'],
                                     wd=optimizer.param_groups[0]['weight_decay'],
                                     dropout=getattr(model.backbone.fc[0], 'p', None)))
        wandb.watch(model, log="all", log_freq=50)

    best_prauc, best_state = 0.0, None
    no_improve = 0

    for ep in range(epochs):
        # ---------- train ----------
        model.train(); tr_loss = 0.0
        for x, y in tr_loader:
            x, y = x.to(device, torch.float32), y.to(device, torch.float32)
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward(); optimizer.step()
            scheduler.step()           # ← 毎バッチ学習率を更新
            tr_loss += loss.detach().item()

        # ---------- val ----------
        v_loss, v_prauc, _ = evaluate(model, val_loader, criterion)
        #scheduler.step(v_loss)

        if log_wandb:
            wandb.log({"epoch": ep+1,
                       "train_loss": tr_loss/len(tr_loader),
                       "val_loss": v_loss,
                       "val_prauc": v_prauc,
                       "lr": optimizer.param_groups[0]['lr']})

        # ---- early-stop 判定 ----
        if v_prauc > best_prauc:
            best_prauc  = v_prauc
            best_state  = copy.deepcopy(model.state_dict())
            no_improve  = 0
        else:
            no_improve += 1
            if no_improve >= early_stop_patience:
                break     # ← 打ち切り

    if log_wandb:
        wandb.finish()
    return best_prauc, best_state


## BCE_lossのときつかう

In [None]:
# -------------------- 6. 高速 2-ステージ grid_search --------------------
def grid_search(tr_loader, val_loader,
                quick_epochs=4, full_epochs=20,
                top_k=3, max_trials=None, seed=42):
    """
    quick_epochs : Stage-1 で回す短いエポック数
    full_epochs  : Stage-2 で回す本番エポック数
    top_k        : Stage-2 に進める上位構成数
    max_trials   : ランダム探索の最大試行数（None なら全組み合わせ）
    """
    # ----- ハイパラ候補 -----
    EPOCHS_dummy = [0]                  # Stage-1/2 で上書きするのでダミー
    LRS     = [1e-5, 3e-5, 1e-4]
    WDECAY  = [1e-3, 1e-2]
    DROPOUT = [0.0, 0.3]

    # すべての組み合わせ
    from itertools import product
    all_params = list(product(LRS, WDECAY, DROPOUT))
    random.seed(seed); random.shuffle(all_params)   # ランダム順にすると偏りが出にくい
    if max_trials:
        all_params = all_params[:max_trials]

    results_stage1 = []

    # ===== Stage-1 : ざっくり評価 (quick_epochs) =====
    print(f"\n▶ Stage-1  ( {len(all_params)} trials × {quick_epochs} epochs )")
    for (lr, wd, do) in all_params:
        model = ModifiedResNet(do).to(device)
        crit  = nn.BCELoss()
        opt   = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
        sch   = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min')

        prauc, _ = train_model(model, tr_loader, val_loader,
                               crit, opt, sch,
                               epochs=quick_epochs,
                               early_stop_patience=2,
                               log_wandb=True)
        results_stage1.append(((lr, wd, do), prauc))
        print(f" quick  lr={lr:.0e} wd={wd:.0e} do={do} → PRAUC {prauc:.4f}")

    # 上位 top_k を選抜
    results_stage1.sort(key=lambda x: x[1], reverse=True)
    top_params = results_stage1[:top_k]
    print("\n★ Stage-1 TOP:")
    for p, s in top_params:
        print(f"  {p}  PRAUC={s:.4f}")

    # ===== Stage-2 : 本命学習 (full_epochs) =====
    best_score, best_state, best_hp = 0.0, None, None
    print(f"\n▶ Stage-2  ( {len(top_params)} trials × {full_epochs} epochs )")
    for (lr, wd, do), _ in top_params:
        model = ModifiedResNet(do).to(device)
        crit  = nn.BCELoss()
        opt   = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
        sch   = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min')

        prauc, state = train_model(model, tr_loader, val_loader,
                                   crit, opt, sch,
                                   epochs=full_epochs,
                                   early_stop_patience=4,
                                   log_wandb=True)         # ← ここは詳細ログ

        print(f" full  lr={lr:.0e} wd={wd:.0e} do={do} → PRAUC {prauc:.4f}")
        if prauc > best_score:
            best_score, best_state = prauc, state
            best_hp = dict(lr=lr, wd=wd, dropout=do)

    # ----- 最高モデルだけ保存 -----
    torch.save(best_state, "best_model_online_aug_1.pth")
    print(f"\n✅ Saved best_model.pth  (PRAUC {best_score:.4f})")
    print("   Best hyper-parameters:", best_hp)


## focal_loss

In [7]:
def grid_search(tr_loader, val_loader, pos_weight,
                quick_epochs=4, full_epochs=30,
                top_k=4, seed=42):

    LRS      = [1e-5, 1e-4]
    WDECAY   = [1e-3, 1e-2]
    DROPOUT  = [0.0, 0.3]
    GAMMAS   = [1.0, 2.0]

    from itertools import product
    all_params = list(product(LRS, WDECAY, DROPOUT, GAMMAS))

    # ---------- Stage-1 ----------
    results = []
    for i, (lr, wd, do, g) in enumerate(all_params, 1):
        model = ModifiedResNet(drop=do).to(device)
        crit  = FocalLoss(pos_weight=pos_weight, gamma=g)
        opt   = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

        prauc, _ = train_model(model, tr_loader, val_loader,
                               crit, opt,
                               epochs=quick_epochs,
                               early_stop_patience=2)
        results.append(((lr, wd, do, g), prauc))

    results.sort(key=lambda x: x[1], reverse=True)
    top_params = results[:top_k]

    # ---------- Stage-2 ----------
    best_score, best_state, best_hp = 0.0, None, None
    for rank, ((lr, wd, do, g), _) in enumerate(top_params, 1):
        model = ModifiedResNet(drop=do).to(device)
        crit  = FocalLoss(pos_weight=pos_weight, gamma=g)
        opt   = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

        prauc, state = train_model(model, tr_loader, val_loader,
                                   crit, opt,
                                   epochs=full_epochs,
                                   early_stop_patience=6)

        if prauc > best_score:
            best_score, best_state = prauc, state
            best_hp = dict(lr=lr, wd=wd, dropout=do, gamma=g)

    torch.save(best_state, "best_model_online_aug_2.pth")
    print(f"✅ saved best_model.pth   PRAUC={best_score:.4f}")
    print("   hyper-params:", best_hp)


In [8]:
# ## 7. DataLoader 準備
# %%
train_csv = "/mnt/nfs1/home/yamamoto-hiroto/research/vertebrae/Sakaguchi_file/slice_train/axial/train_labels_axial.csv"
val_csv   = "/mnt/nfs1/home/yamamoto-hiroto/research/vertebrae/Sakaguchi_file/slice_val/axial/val_labels_axial.csv"

train_ds = CTDataset(train_csv, pos_tf=pos_tf, neg_tf=neg_tf)
val_ds   = CTDataset(val_csv,   pos_tf=val_tf, neg_tf=val_tf)


## ★ Sampler で 1:1 → α=0.5 固定
pos_weight = 0.8
print(f"pos_weight (alpha) = {pos_weight:.3f}")

y = train_ds.data["Fracture_Label"].astype(int).values
cls_cnt = np.bincount(y, minlength=2)
weights = (1. / cls_cnt)[y]
sampler = WeightedRandomSampler(weights=torch.DoubleTensor(weights),
                                num_samples=len(weights),
                                replacement=True)

batch = 64
train_loader = DataLoader(train_ds, batch_size=batch,
                          sampler=sampler, num_workers=4)
val_loader   = DataLoader(val_ds,  batch_size=batch,
                          shuffle=False, num_workers=4)

print(f"Train: {len(train_ds)} (neg {cls_cnt[0]}, pos {cls_cnt[1]})")

# ----------------- 8. 実行！ -----------------
grid_search(train_loader, val_loader, pos_weight=pos_weight,
            quick_epochs=4, full_epochs=30, top_k=4)

pos_weight (alpha) = 0.800
Train: 37737 (neg 33142, pos 4595)




[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▂▁▁
val_loss,█▁▆▅
val_prauc,▁█▇█

0,1
epoch,4.0
lr,0.0
train_loss,0.00832
val_loss,0.07578
val_prauc,0.31681


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▂▁▁
val_loss,▁▆▇█
val_prauc,▁▆▇█

0,1
epoch,4.0
lr,0.0
train_loss,0.00661
val_loss,0.04309
val_prauc,0.30208


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▂▁▁
val_loss,▁▅██
val_prauc,▁▄▃█

0,1
epoch,4.0
lr,0.0
train_loss,0.01264
val_loss,0.06704
val_prauc,0.31271


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▂▁▁
val_loss,▁▂█▅
val_prauc,▄█▁█

0,1
epoch,4.0
lr,0.0
train_loss,0.00605
val_loss,0.04363
val_prauc,0.2779


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▂▁▁
val_loss,▁▄██
val_prauc,▁▆██

0,1
epoch,4.0
lr,0.0
train_loss,0.01168
val_loss,0.06911
val_prauc,0.31907


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▂▁▁
val_loss,█▁▂▂
val_prauc,▁▅██

0,1
epoch,4.0
lr,0.0
train_loss,0.0062
val_loss,0.03858
val_prauc,0.28156


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▂▁▁
val_loss,█▁▇▅
val_prauc,▁█▆▇

0,1
epoch,4.0
lr,0.0
train_loss,0.01314
val_loss,0.06462
val_prauc,0.32355


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▃▁▁
val_loss,█▁▂▃
val_prauc,▁█▆▇

0,1
epoch,4.0
lr,0.0
train_loss,0.0088
val_loss,0.0336
val_prauc,0.32556


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▅█
lr,█▅▁
train_loss,█▂▁
val_loss,▁▆█
val_prauc,█▁█

0,1
epoch,3.0
lr,2e-05
train_loss,0.00308
val_loss,0.11986
val_prauc,0.40747


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▃▂▁
val_loss,▁▅▃█
val_prauc,▁▅██

0,1
epoch,4.0
lr,0.0
train_loss,0.00063
val_loss,0.09208
val_prauc,0.43227


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▂▁▁
val_loss,▂▁█▇
val_prauc,▁▆▇█

0,1
epoch,4.0
lr,0.0
train_loss,0.00132
val_loss,0.10485
val_prauc,0.46446


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▂▁▁
val_loss,▁▄██
val_prauc,▁▃▆█

0,1
epoch,4.0
lr,0.0
train_loss,0.00075
val_loss,0.08384
val_prauc,0.38949


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▄▂▁
val_loss,█▁▇▅
val_prauc,▁▅██

0,1
epoch,4.0
lr,0.0
train_loss,0.00251
val_loss,0.11352
val_prauc,0.4469


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▄▂▁
val_loss,▄▁▆█
val_prauc,▁█▅█

0,1
epoch,4.0
lr,0.0
train_loss,0.00178
val_loss,0.063
val_prauc,0.43513


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▃▂▁
val_loss,▅▁▇█
val_prauc,▁▃▇█

0,1
epoch,4.0
lr,0.0
train_loss,0.00221
val_loss,0.11315
val_prauc,0.47624


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▃▆█
lr,█▅▂▁
train_loss,█▄▂▁
val_loss,▇▁▇█
val_prauc,▁▇▅█

0,1
epoch,4.0
lr,0.0
train_loss,0.00178
val_loss,0.07535
val_prauc,0.42284


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▂▂▃▃▄▄▅▅▆▆▇▇█
lr,▁▆████▇▇▇▆▆▆▅▅
train_loss,█▂▂▂▂▂▂▂▁▁▁▁▁▁
val_loss,▁▂▁█▂▁▂▂▁▂▃▂▁▁
val_prauc,▅▆▄▂▇▅▅█▅▆▄▅▅▁

0,1
epoch,14.0
lr,6e-05
train_loss,0.00658
val_loss,0.09433
val_prauc,0.24956


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▂▃▄▅▅▆▇█
lr,▁▆████▇▇▇
train_loss,█▂▂▁▁▁▁▁▁
val_loss,▅█▁▃▂▂▂▂▂
val_prauc,▁▃█▃▅▇▅▄▆

0,1
epoch,9.0
lr,9e-05
train_loss,0.00618
val_loss,0.07364
val_prauc,0.41423


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▂▂▃▄▅▅▆▇▇█
lr,▁▆████▇▇▇▆▆
train_loss,█▃▃▃▂▂▂▂▁▁▁
val_loss,▁▂▂▄▁▄▇▂█▄▅
val_prauc,▆▃▁▅█▄▃▄▁▃▃

0,1
epoch,11.0
lr,8e-05
train_loss,0.00892
val_loss,0.23123
val_prauc,0.25597


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▂▂▃▃▄▅▅▆▆▇▇█
lr,▁▆████▇▇▇▆▆▆▅
train_loss,█▃▃▃▃▃▂▂▂▂▁▁▁
val_loss,▁▂▃▂▃▆▁▆▄▆█▅▆
val_prauc,▇▄▅▄▄▄█▅▆▃▅▁▂

0,1
epoch,13.0
lr,7e-05
train_loss,0.00468
val_loss,0.0726
val_prauc,0.18345


✅ saved best_model.pth   PRAUC=0.4755
   hyper-params: {'lr': 0.0001, 'wd': 0.001, 'dropout': 0.3, 'gamma': 1.0}
