In [12]:
import os, copy, itertools, random, numpy as np, pandas as pd
import torch, torch.nn as nn, torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import models, transforms
import nibabel as nib
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import precision_recall_curve, auc
import wandb

In [13]:
# global config  (★要環境合わせ)
# ————————————————————————————————————————
WANDB_PROJECT_NAME = "vertebrae_axial_bceloss_3"
TRAIN_CSV_PATH     = "/mnt/nfs1/home/yamamoto-hiroto/research/vertebrae/Sakaguchi_file/slice_train/axial/train_labels_axial.csv"
VAL_CSV_PATH       = "/mnt/nfs1/home/yamamoto-hiroto/research/vertebrae/Sakaguchi_file/slice_val/axial/val_labels_axial.csv"
MODEL_SAVE_DIR     = "/mnt/nfs1/home/yamamoto-hiroto/research/vertebrae/Sakaguchi_file/S_model_learning/model_pth"
SEED               = 42

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [14]:

# — data-augmentation hyper-parameters —
AUG_RESIZE_SIZE        = (224, 224)
AUG_ROTATION_DEGREES   = 30
AUG_TRANSLATE_PERCENT  = (0.1, 0.1)
AUG_SCALE_RANGE        = (0.9, 1.1)
AUG_GAUSSIAN_NOISE_STD = 0.05
AUG_GAUSSIAN_NOISE_PROB= 0.5
AUG_NORM_MEAN          = (0.5,)
AUG_NORM_STD           = (0.5,)

# — training control —
EARLY_STOPPING_PATIENCE = 5
SCHEDULER_PATIENCE      = 5
SCHEDULER_FACTOR        = 0.1
EARLY_STOP_MIN_DELTA    = 1e-4   # PRAUC がこれ以上伸びなければ停⽌

# ----------------------------------------------------------------------
#  transforms / dataset
# ----------------------------------------------------------------------
class AddGaussianNoise(object):
    def __init__(self, std=0.05, p=0.5):
        self.std, self.p = std, p
    def __call__(self, t):
        if torch.rand(1).item() < self.p:
            noise = torch.randn_like(t) * self.std
            t = t + noise
            t.clamp_(0.0, 1.0)
        return t
    def __repr__(self):
        return f"{self.__class__.__name__}(std={self.std}, p={self.p})"

def get_transforms():
    pos_train = transforms.Compose([
        transforms.Resize(AUG_RESIZE_SIZE),
        transforms.RandomAffine(degrees=AUG_ROTATION_DEGREES,
                                translate=AUG_TRANSLATE_PERCENT,
                                scale=AUG_SCALE_RANGE),
        transforms.ToTensor(),
        AddGaussianNoise(std=AUG_GAUSSIAN_NOISE_STD, p=AUG_GAUSSIAN_NOISE_PROB),
        transforms.Normalize(AUG_NORM_MEAN, AUG_NORM_STD)
    ])
    neg_train = transforms.Compose([
        transforms.Resize(AUG_RESIZE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(AUG_NORM_MEAN, AUG_NORM_STD)
    ])
    val_tf = transforms.Compose([
        transforms.Resize(AUG_RESIZE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(AUG_NORM_MEAN, AUG_NORM_STD)
    ])
    return pos_train, neg_train, val_tf

class CTDataset(Dataset):
    def __init__(self, csv_path, pos_tf=None, neg_tf=None):
        self.df      = pd.read_csv(csv_path)
        self.pos_tf  = pos_tf
        self.neg_tf  = neg_tf
    def __len__(self): return len(self.df)
    def _load_slice(self, path):
        try:
            img = nib.load(path).get_fdata()
            if img.ndim == 3:
                img = img[:, :, img.shape[2]//2]   # middle slice
                
            img_min = 100
            img_max = 1800
            img = np.clip(img, img_min, img_max)
            img = ((img-img_min)/(img_max-img_min)*255).astype(np.uint8)
            return Image.fromarray(img).convert("L")
        except Exception as e:
            print(f"[warn] {e} → black Img")
            return Image.fromarray(np.zeros(AUG_RESIZE_SIZE, np.uint8)).convert("L")
    def __getitem__(self, idx):
        row   = self.df.iloc[idx]
        label = int(row.Fracture_Label)
        img   = self._load_slice(row.FullPath)
        tf    = self.pos_tf if label==1 else self.neg_tf
        if tf is None: tf = self.pos_tf   # fallback
        img   = tf(img)
        return img, float(label)

# ----------------------------------------------------------------------
#  sampler
# ----------------------------------------------------------------------
def create_sampler(labels_np):
    counts  = np.bincount(labels_np.astype(int))
    if len(counts) < 2:
        print("[warn] single-class dataset, sampler disabled")
        return None
    class_wt = 1.0 / counts
    sample_wt = class_wt[labels_np.astype(int)]
    return WeightedRandomSampler(sample_wt, len(sample_wt), replacement=True)

# ----------------------------------------------------------------------
#  model
# ----------------------------------------------------------------------
class ModifiedResNet(nn.Module):
    def __init__(self, dropout_rate=0.5):
        super().__init__()
        self.base = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.base.conv1 = nn.Conv2d(1, 64, 7, 2, 3, bias=False)
        nfeat = self.base.fc.in_features
        self.base.fc = nn.Sequential(nn.Dropout(dropout_rate),
                                     nn.Linear(nfeat, 1))
    def forward(self, x): return self.base(x)

# ----------------------------------------------------------------------
#  evaluation
# ----------------------------------------------------------------------
@torch.no_grad()
def evaluate(model, loader, criterion):
    model.eval()
    tot_loss, ys, ps = 0.0, [], []
    for x,y in loader:
        x,y = x.to(DEVICE), y.to(DEVICE).float().unsqueeze(1)
        o   = model(x)
        tot_loss += criterion(o, y).item()
        ps.append(torch.sigmoid(o).cpu().view(-1))
        ys.append(y.cpu().view(-1))
    loss = tot_loss / len(loader)
    ys   = torch.cat(ys).numpy()
    ps   = torch.cat(ps).numpy()
    if len(np.unique(ys)) > 1:
        prec, rec, thr = precision_recall_curve(ys, ps)
        prauc  = auc(rec, prec)
        f1     = 2*prec*rec/(prec+rec+1e-8)
        best_i = np.nanargmax(f1)
        best_t = thr[min(best_i, len(thr)-1)] if len(thr) else 0.5
        best_p, best_r = prec[best_i], rec[best_i]
    else:
        prauc, best_t, best_p, best_r = 0.0, 0.5, 0.0, 0.0
    return loss, prauc, best_t, best_p, best_r

# ----------------------------------------------------------------------
#  training loop (EarlyStop & LR on PRAUC)
# ----------------------------------------------------------------------
def train_model(model, train_loader, val_loader,
                criterion, optimizer, scheduler,
                epochs, run_name):

    best_prauc, best_epoch, es_counter = -1.0, -1, 0

    for ep in range(epochs):
        model.train()
        prog = tqdm(train_loader, desc=f"Epoch {ep+1}/{epochs}", leave=False)
        t_loss = 0.0
        for x,y in prog:
            x,y = x.to(DEVICE), y.to(DEVICE).float().unsqueeze(1)
            optimizer.zero_grad()
            o = model(x)
            loss = criterion(o, y)
            loss.backward()
            optimizer.step()
            t_loss += loss.item()
            prog.set_postfix(loss=loss.item())

        v_loss, v_prauc, v_thr, _, _ = evaluate(model, val_loader, criterion)

        wandb.log({"epoch": ep+1,
                   "train_loss": t_loss/len(train_loader),
                   "val_loss": v_loss,
                   "val_prauc": v_prauc,
                   "lr": optimizer.param_groups[0]["lr"]})

        improved = v_prauc > best_prauc + EARLY_STOP_MIN_DELTA
        if improved:
            best_prauc, best_epoch, es_counter = v_prauc, ep+1, 0
            torch.save(model.state_dict(),
                       os.path.join(MODEL_SAVE_DIR,
                                    f"{run_name}_best.pth"))
            print(f"  ↳ new best PRAUC {best_prauc:.4f}  (model saved)")
        else:
            es_counter += 1

        if scheduler: scheduler.step(v_prauc)

        if es_counter >= EARLY_STOPPING_PATIENCE:
            print(f"Early-stop: no PRAUC gain for {EARLY_STOPPING_PATIENCE} epochs")
            break

    return best_prauc, best_epoch

# ----------------------------------------------------------------------
#  grid search
# ----------------------------------------------------------------------
def grid_search_main(train_csv, val_csv):
    num_epochs_list      = [25]
    lr_list              = [1e-4, 1e-5, 1e-6]
    weight_decay_list    = [5e-4, 5e-3]
    dropout_rate_list    = [0.2]
    batch_size           = 32

    pos_tf, neg_tf, val_tf = get_transforms()

    # validation loader (pos/neg とも val_tf を適用 ← 重要)
    val_ds  = CTDataset(val_csv, pos_tf=val_tf, neg_tf=val_tf)
    val_ld  = DataLoader(val_ds, batch_size=batch_size,
                         shuffle=False, num_workers=4, pin_memory=True)

    best_overall, best_cfg = -1.0, None
    combos = list(itertools.product(num_epochs_list, lr_list,
                                    weight_decay_list, dropout_rate_list))

    print(f"⏩ Grid Search: {len(combos)} combinations")
    for i,(epochs, lr, wd, dr) in enumerate(combos, 1):
        run_name = f"e{epochs}_lr{lr}_wd{wd}_do{dr}_#{i}"
        print(f"\n—— Trial {i}/{len(combos)} : {run_name} ——")

        wandb.init(project=WANDB_PROJECT_NAME, name=run_name,
                   config=dict(epochs=epochs, lr=lr, wd=wd,
                               dropout=dr, batch=batch_size), reinit=True)

        random.seed(SEED+i); np.random.seed(SEED+i)
        torch.manual_seed(SEED+i); torch.cuda.manual_seed_all(SEED+i)

        model = ModifiedResNet(dr).to(DEVICE)
        crit  = nn.BCEWithLogitsLoss().to(DEVICE)
        opt   = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
        sched = optim.lr_scheduler.ReduceLROnPlateau(
                    opt, mode='max', factor=SCHEDULER_FACTOR,
                    patience=SCHEDULER_PATIENCE, verbose=True)

        train_ds  = CTDataset(train_csv, pos_tf=pos_tf, neg_tf=neg_tf)
        sampler   = create_sampler(train_ds.df.Fracture_Label.values)
        train_ld  = DataLoader(train_ds, batch_size=batch_size,
                               sampler=sampler if sampler else None,
                               shuffle=(sampler is None),
                               num_workers=4, pin_memory=True)

        best_prauc, best_epoch = train_model(model, train_ld,
                                             val_ld, crit, opt,
                                             sched, epochs, run_name)

        wandb.summary["best_prauc"]  = best_prauc
        wandb.summary["best_epoch"]  = best_epoch

        if best_prauc > best_overall:
            best_overall, best_cfg = best_prauc, wandb.config.as_dict()
            print(f"🎉 new overall best PRAUC {best_overall:.4f}")
            
        wandb.finish()

    print("\n=========== Grid Search done ===========")
    print(f" best PRAUC : {best_overall:.4f}")
    print(f" best config: {best_cfg}")

# ----------------------------------------------------------------------
if __name__ == "__main__":
    try:
        wandb.login()
    except Exception as e:
        print(f"W&B login failed: {e}")

    grid_search_main(TRAIN_CSV_PATH, VAL_CSV_PATH)


⏩ Grid Search: 6 combinations

—— Trial 1/6 : e25_lr0.0001_wd0.0005_do0.2_#1 ——


                                                                             

  ↳ new best PRAUC 0.3094  (model saved)


                                                                              

  ↳ new best PRAUC 0.3270  (model saved)


                                                                              

  ↳ new best PRAUC 0.3757  (model saved)


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Early-stop: no PRAUC gain for 5 epochs
🎉 new overall best PRAUC 0.3757


0,1
epoch,▁▂▃▃▄▅▆▆▇█
lr,▁▁▁▁▁▁▁▁▁▁
train_loss,█▃▂▂▁▁▂▁▁▁
val_loss,▁▁▃▂▂▃█▃█▂
val_prauc,▅▄▆▄█▆▃▁▃▅

0,1
best_epoch,5.0
best_prauc,0.37571
epoch,10.0
lr,0.0001
train_loss,0.01107
val_loss,0.34002
val_prauc,0.30173



—— Trial 2/6 : e25_lr0.0001_wd0.005_do0.2_#2 ——


                                                                              

  ↳ new best PRAUC 0.3527  (model saved)


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Early-stop: no PRAUC gain for 5 epochs


0,1
epoch,▁▂▄▅▇█
lr,▁▁▁▁▁▁
train_loss,█▃▂▂▂▁
val_loss,▁▂▃▁▂█
val_prauc,█▇▁▁▁▂

0,1
best_epoch,1.0
best_prauc,0.35269
epoch,6.0
lr,0.0001
train_loss,0.02496
val_loss,0.60789
val_prauc,0.18106



—— Trial 3/6 : e25_lr1e-05_wd0.0005_do0.2_#3 ——


                                                                             

  ↳ new best PRAUC 0.2555  (model saved)


                                                                              

  ↳ new best PRAUC 0.2676  (model saved)


                                                                              

  ↳ new best PRAUC 0.2818  (model saved)


                                                                              

  ↳ new best PRAUC 0.2895  (model saved)


                                                                              

  ↳ new best PRAUC 0.3144  (model saved)


                                                                               

  ↳ new best PRAUC 0.3554  (model saved)


                                                                               

  ↳ new best PRAUC 0.3581  (model saved)


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Early-stop: no PRAUC gain for 5 epochs


0,1
epoch,▁▁▂▂▃▃▄▄▅▅▆▆▇▇██
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val_loss,▁▄▃▄▃▄▄▄▅▃▅▅█▇▇▆
val_prauc,▃▁▄▅▄▅▄▆▄██▆▄▅▃▆

0,1
best_epoch,11.0
best_prauc,0.35808
epoch,16.0
lr,1e-05
train_loss,0.00325
val_loss,0.41534
val_prauc,0.3075



—— Trial 4/6 : e25_lr1e-05_wd0.005_do0.2_#4 ——


                                                                             

  ↳ new best PRAUC 0.3102  (model saved)


                                                                              

  ↳ new best PRAUC 0.3220  (model saved)


                                                                              

  ↳ new best PRAUC 0.3391  (model saved)


                                                                               

  ↳ new best PRAUC 0.3598  (model saved)


                                                                               

  ↳ new best PRAUC 0.3899  (model saved)


                                                                               

  ↳ new best PRAUC 0.3999  (model saved)


                                                                               

  ↳ new best PRAUC 0.4040  (model saved)


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Early-stop: no PRAUC gain for 5 epochs
🎉 new overall best PRAUC 0.4040


0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,▁▂▃▃▄▃▅▆▅▅▄▇▄▄▃▃▅▄▅█
val_prauc,▄▂▃▄▄▄▅▂▃▂▆▁▇██▆▁▆▃▂

0,1
best_epoch,15.0
best_prauc,0.40402
epoch,20.0
lr,1e-05
train_loss,0.00577
val_loss,0.48259
val_prauc,0.2835



—— Trial 5/6 : e25_lr1e-06_wd0.0005_do0.2_#5 ——


                                                                           

  ↳ new best PRAUC 0.2816  (model saved)


                                                                            

  ↳ new best PRAUC 0.2988  (model saved)


                                                                            

  ↳ new best PRAUC 0.3040  (model saved)


                                                                             

  ↳ new best PRAUC 0.3042  (model saved)


                                                                             

  ↳ new best PRAUC 0.3068  (model saved)


                                                                              

  ↳ new best PRAUC 0.3095  (model saved)


                                                                              

  ↳ new best PRAUC 0.3193  (model saved)


                                                                              

  ↳ new best PRAUC 0.3238  (model saved)


                                                                              

  ↳ new best PRAUC 0.3427  (model saved)


                                                                              

  ↳ new best PRAUC 0.3469  (model saved)


                                                                               

  ↳ new best PRAUC 0.3498  (model saved)


                                                                               

  ↳ new best PRAUC 0.3517  (model saved)


                                                                               

  ↳ new best PRAUC 0.3593  (model saved)


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▅▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▄▂▁▂▂▄▄▄▄▄▄▄▅▄▅▅▆▆▆▅▇▆▇▇
val_prauc,▁▃▂▃▃▃▃▃▂▄▄▅▅▄▇▆▅▆▇▇▇▇▇██

0,1
best_epoch,24.0
best_prauc,0.35931
epoch,25.0
lr,0.0
train_loss,0.01483
val_loss,0.27525
val_prauc,0.35548



—— Trial 6/6 : e25_lr1e-06_wd0.005_do0.2_#6 ——


                                                                           

  ↳ new best PRAUC 0.2615  (model saved)


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Early-stop: no PRAUC gain for 5 epochs


0,1
epoch,▁▂▄▅▇█
lr,▁▁▁▁▁▁
train_loss,█▄▃▂▁▁
val_loss,█▁▁▂▂▄
val_prauc,█▄▅▁▆▇

0,1
best_epoch,1.0
best_prauc,0.26146
epoch,6.0
lr,0.0
train_loss,0.0736
val_loss,0.26649
val_prauc,0.25772



 best PRAUC : 0.4040
 best config: {'epochs': 25, 'lr': 1e-05, 'wd': 0.005, 'dropout': 0.2, 'batch': 32}
