In [1]:
import os, random, numpy as np, pandas as pd
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import nibabel as nib
from PIL import Image
from tqdm.notebook import tqdm
from sklearn.metrics import precision_recall_curve, auc
import matplotlib.pyplot as plt
import wandb
from itertools import product
from datetime import datetime

# wandb 共通設定
WANDB_PROJECT = "vertebrae-sampling_axial_learning_2"
wandb.login()                     # API キーは環境変数でも可


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33myuya00[0m ([33myuya00-university-of-hyogo[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# 学習モデル

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [3]:
# ### 1. Dataset class
# - 画像を読み込んだら **`convert("RGB")`** で 3ch に複製するだけ

# %%
class CTDataset(Dataset):
    def __init__(self, csv_path, transform=None):
        self.data = pd.read_csv(csv_path)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_path = row["FullPath"]
        label = row["Fracture_Label"]

        # --- Load NIfTI ---
        img_arr = nib.load(img_path).get_fdata()

        # Use first slice if 3D
        if len(img_arr.shape) == 3:
            img_arr = img_arr[:, :, 0]
        elif len(img_arr.shape) != 2:
            raise ValueError(f"Unsupported image shape: {img_arr.shape}")

        # HU window (100–2000) → 0‑255
        img_arr = np.clip(img_arr, 100, 2000)
        img_arr = (img_arr - 100) / (2000 - 100)
        img_arr = np.uint8(img_arr * 255)

        # ★ 3‑channel duplication ★
        pil_img = Image.fromarray(img_arr).convert("RGB")

        if self.transform:
            pil_img = self.transform(pil_img)

        return pil_img, float(label)

# %% [markdown]
# ### 2. Model – ResNet‑18 without touching `conv1`

# %%
class ModifiedResNet(nn.Module):
    def __init__(self, dropout_rate=0.5):
        super().__init__()
        # Keep original conv1 (3ch)
        self.resnet = models.resnet18(pretrained=True)
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.resnet(x)
        
# ### 3. Validation helper

# %%
def evaluate_model(model, val_loader, criterion, device="cuda"):
    model.eval()
    running_loss = 0.0
    all_labels, all_preds = [], []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.float32)

            outputs = model(images).squeeze()
            loss = criterion(outputs, labels)
            running_loss += loss.item()

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(outputs.cpu().numpy())

    avg_loss = running_loss / len(val_loader)
    precision, recall, thresholds = precision_recall_curve(all_labels, all_preds)
    prauc = auc(recall, precision)

    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
    best_idx = np.argmax(f1_scores)
    best_threshold = thresholds[best_idx] if best_idx < len(thresholds) else 0.5

    return avg_loss, prauc, best_threshold, precision[best_idx], recall[best_idx]


In [4]:
# ### 4. Training loop (unchanged)

# %%
def train_model(model, train_loader, val_loader, criterion, optimizer,
                scheduler=None, num_epochs=20, device="cuda"):
    run = wandb.init(project=WANDB_PROJECT, reinit=True,
                     config=dict(epochs=num_epochs,
                                 lr=optimizer.param_groups[0]["lr"],
                                 weight_decay=optimizer.param_groups[0]["weight_decay"],
                                 dropout=getattr(model.resnet.fc[0], "p", None)))
    wandb.watch(model, log="all", log_freq=100)

    best_prauc = 0.0
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"[Train] {epoch+1}/{num_epochs}"):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.float32)

            optimizer.zero_grad()
            outputs = model(images).squeeze()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        val_loss, val_prauc, th, prec, rec = evaluate_model(model, val_loader, criterion, device)

        wandb.log({"epoch": epoch+1, "train_loss": avg_train_loss, "val_loss": val_loss,
                   "val_prauc": val_prauc, "best_th": th, "precision": prec, "recall": rec,
                   "lr": optimizer.param_groups[0]["lr"]})

        if scheduler is not None:
            scheduler.step(val_loss)

        if val_prauc > best_prauc:
            best_prauc = val_prauc
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_path = f"best_model_{timestamp}.pth"
            torch.save(model.state_dict(), model_path)
            wandb.run.summary.update({"best_prauc": best_prauc,
                                       "best_epoch": epoch+1,
                                       "saved_model_path": model_path})

    wandb.finish()
    return best_prauc

# %% [markdown]
# ### 5. Grid search (unchanged)

# %%
def grid_search(train_loader, val_loader, device="cuda"):
    num_epochs_list   = [20, 30]
    lr_list           = [1e-5, 3e-4]
    weight_decay_list = [1e-4, 5e-4]
    dropout_rate_list = [0.3, 0.0]

    for (num_epochs, lr, wd, do) in product(num_epochs_list, lr_list, weight_decay_list, dropout_rate_list):
        model = ModifiedResNet(dropout_rate=do).to(device)
        criterion = nn.BCELoss()
        
        #optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
        #scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
        
        optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
        steps = num_epochs * len(train_loader)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=3e-4, total_steps=steps)
        
        best_prauc = train_model(model, train_loader, val_loader, criterion, optimizer,
                                 scheduler=scheduler, num_epochs=num_epochs, device=device)
        print(f"Finished: Ep{num_epochs} LR{lr} WD{wd} DO{do} → PRAUC {best_prauc:.4f}")

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

train_csv = "/mnt/nfs1/home/yamamoto-hiroto/research/vertebrae/Sakaguchi_file/slice_train_sampling/axial/sampling_labels_axial_2.csv"
val_csv   = "/mnt/nfs1/home/yamamoto-hiroto/research/vertebrae/Sakaguchi_file/slice_val/axial/val_labels_axial.csv"

transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # 3ch
])

transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

batch_size = 64
train_dataset = CTDataset(train_csv, transform=transform_train)
val_dataset   = CTDataset(val_csv,   transform=transform_val)
train_loader  = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=4)
val_loader    = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=4)

# %% [markdown]
# ### 7. Run grid search (or any custom training routine)

# %%
grid_search(train_loader, val_loader, device=device)



[Train] 1/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 2/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 3/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 4/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 5/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 6/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 7/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 8/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 9/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 10/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 11/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 12/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 13/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 14/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 15/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 16/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 17/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 18/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 19/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 20/20:   0%|          | 0/79 [00:00<?, ?it/s]

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
best_th,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▁█▃▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃
precision,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
recall,▁███████████████████
train_loss,█▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▂▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃
val_prauc,█▁▃▃▃▃▄▃▃▃▂▂▂▂▂▂▁▁▁▁

0,1
best_epoch,1
best_prauc,0.07795
best_th,0.00451
epoch,20
lr,1e-05
precision,0.0672
recall,0.98574
saved_model_path,best_model_20250414_...
train_loss,0.0061
val_loss,0.35257


Finished: Ep20 LR1e-05 WD0.0001 DO0.3 → PRAUC 0.0779




[Train] 1/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 2/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 3/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 4/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 5/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 6/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 7/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 8/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 9/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 10/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 11/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 12/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 13/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 14/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 15/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 16/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 17/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 18/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 19/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 20/20:   0%|          | 0/79 [00:00<?, ?it/s]

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
best_th,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▁▇▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇███
precision,▁▁▂▁▁▁▁▂▂▂▂▄▆▅▆▆█▇▇█
recall,███████▇▇█▇▅▄▅▃▃▃▃▂▁
train_loss,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,▇▁▂▂▃▃▄▄▅▅▆▆▆▇▇▇▇███
val_prauc,▁▅▅▄▆▅▄▅▅▅▅▆▆▇▇▇█▇██

0,1
best_epoch,20
best_prauc,0.06053
best_th,0.00254
epoch,20
lr,1e-05
precision,0.07021
recall,0.72193
saved_model_path,best_model_20250414_...
train_loss,0.00277
val_loss,0.39718


Finished: Ep20 LR1e-05 WD0.0001 DO0.0 → PRAUC 0.0605




[Train] 1/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 2/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 3/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 4/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 5/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 6/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 7/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 8/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 9/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 10/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 11/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 12/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 13/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 14/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 15/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 16/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 17/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 18/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 19/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 20/20:   0%|          | 0/79 [00:00<?, ?it/s]

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
best_th,█▅▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▁█▄▄▄▄▅▅▅▆▆▆▆▇▇▇▇▇██
precision,▁▂▃▄▆▄▅▆▅▆▆▇█▇▇▆▆▇▇▇
recall,█▇▄▄▂▃▂▁▃▂▁▁▁▂▂▂▃▂▂▂
train_loss,█▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▁▁▁▂▃▃▄▄▅▅▅▆▆▆▇▇▇██
val_prauc,▁▂▄▅█▆▇▇▇▇▇▇█▇▇▇▇█▇▇

0,1
best_epoch,13
best_prauc,0.0754
best_th,0.00328
epoch,20
lr,1e-05
precision,0.08708
recall,0.54545
saved_model_path,best_model_20250414_...
train_loss,0.00368
val_loss,0.38241


Finished: Ep20 LR1e-05 WD0.0005 DO0.3 → PRAUC 0.0754




[Train] 1/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 2/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 3/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 4/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 5/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 6/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 7/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 8/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 9/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 10/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 11/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 12/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 13/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 14/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 15/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 16/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 17/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 18/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 19/20:   0%|          | 0/79 [00:00<?, ?it/s]



[Train] 20/20:   0%|          | 0/79 [00:00<?, ?it/s]

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
best_th,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▁█▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃
precision,█▁▂▄▄▅▆▅▅▅▅▅▄▄▄▄▄▄▄▃
recall,▁█▇▅▇▇▇▇▇▇██▇▇▇▇▇▇▇▆
train_loss,█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,█▁▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▃▃
val_prauc,█▂▂▂▁▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁

0,1
best_epoch,1
best_prauc,0.07458
best_th,0.00544
epoch,20
lr,1e-05
precision,0.07149
recall,0.85918
saved_model_path,best_model_20250414_...
train_loss,0.00568
val_loss,0.34838


Finished: Ep20 LR1e-05 WD0.0005 DO0.0 → PRAUC 0.0746




[Train] 1/20:   0%|          | 0/79 [00:00<?, ?it/s]

KeyboardInterrupt: 