In [16]:
import os, random, numpy as np, pandas as pd
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import nibabel as nib
from PIL import Image
from tqdm.notebook import tqdm
from sklearn.metrics import precision_recall_curve, auc
import matplotlib.pyplot as plt
import wandb
from itertools import product
from datetime import datetime

# Shared wandb settings
WANDB_PROJECT = "vertebrae-sampling_axial_learning_1"

# Seed the Python, NumPy and PyTorch RNGs (torch.manual_seed also seeds CUDA) so that
# shuffling, augmentation and weight init are repeatable across grid-search runs.
# cuDNN kernels may still introduce small run-to-run differences.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

wandb.login()  # the API key can also be supplied via an environment variable

True

# 学習モデル

In [17]:
# Expose only GPU 0 to this process; this only takes effect if set before CUDA is initialised.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [18]:
# %%
# 1. データセットクラス (CTDataset)
##############################################
class CTDataset(Dataset):
    """2-D slice dataset indexed by a CSV of NIfTI file paths and labels.

    Each CSV row must provide an image path (column ``path_col``) and a label
    (column ``label_col``). Each image is reduced to one 2-D slice and clipped to
    the intensity window [hu_min, hu_max]. It is then rescaled to an 8-bit
    grayscale PIL image and passed through ``transform`` when one is given.

    Args:
        csv_path: Anything accepted by ``pandas.read_csv`` (path or buffer).
        transform: Optional callable applied to the PIL image.
        hu_min, hu_max: Intensity window; values outside it are clipped.
        slice_index: Index along the last axis used when the volume is 3-D.
        path_col: Name of the CSV column holding the NIfTI file path.
        label_col: Name of the CSV column holding the label.

    Raises:
        ValueError: If ``hu_max <= hu_min``.
    """

    def __init__(self, csv_path, transform=None, hu_min=100, hu_max=2000,
                 slice_index=0, path_col="FullPath", label_col="Fracture_Label"):
        if hu_max <= hu_min:
            raise ValueError(f"hu_max ({hu_max}) must be greater than hu_min ({hu_min})")
        self.data = pd.read_csv(csv_path)
        self.transform = transform
        self.hu_min = hu_min
        self.hu_max = hu_max
        self.slice_index = slice_index
        self.path_col = path_col
        self.label_col = label_col

    def __len__(self):
        return len(self.data)

    def _load_slice(self, img_path):
        """Load a NIfTI file and return one 2-D slice as a float array.

        Raises:
            ValueError: If the image is neither 2-D nor 3-D.
        """
        img_arr = nib.load(img_path).get_fdata()
        if img_arr.ndim == 3:
            # Slice along the last axis; change slice_index / the axis if the data differ.
            return img_arr[:, :, self.slice_index]
        if img_arr.ndim != 2:
            raise ValueError(f"Unsupported image shape: {img_arr.shape}")
        return img_arr

    def _window_to_uint8(self, img_arr):
        """Clip to [hu_min, hu_max] and rescale linearly to uint8 [0, 255]."""
        # NaN would make the uint8 cast undefined; map it to the bottom of the window.
        img_arr = np.nan_to_num(img_arr, nan=self.hu_min)
        img_arr = np.clip(img_arr, self.hu_min, self.hu_max)
        img_arr = (img_arr - self.hu_min) / (self.hu_max - self.hu_min)  # scale to 0..1
        return (img_arr * 255).astype(np.uint8)

    def __getitem__(self, idx):
        """Return ``(image, label)``; image is transformed if a transform is set, label is a float."""
        row = self.data.iloc[idx]
        img_arr = self._window_to_uint8(self._load_slice(row[self.path_col]))
        pil_img = Image.fromarray(img_arr).convert("L")

        if self.transform:
            pil_img = self.transform(pil_img)

        return pil_img, float(row[self.label_col])

# 2. モデル定義 (ModifiedResNet)
##############################################
class ModifiedResNet(nn.Module):
    """ImageNet-pretrained ResNet-18 adapted to 1-channel input and a single sigmoid output.

    Input is expected as (batch, 1, H, W). Output has shape (batch, 1) with values
    in (0, 1), produced by ``fc = Sequential(Dropout, Linear(512->1), Sigmoid)``.
    Other code reads the dropout rate from ``resnet.fc[0]``, so keep Dropout at index 0.

    Args:
        dropout_rate: Dropout probability applied before the final linear layer.
    """

    def __init__(self, dropout_rate=0.5):
        super().__init__()
        # NOTE(review): ``pretrained=True`` is deprecated in torchvision >= 0.13 in favour of
        # ``weights=models.ResNet18_Weights.DEFAULT``; kept for compatibility with older versions.
        self.resnet = models.resnet18(pretrained=True)

        # Swap the 3-channel stem for a 1-channel one. Initialise it from the pretrained
        # filters summed over RGB, which is the response the pretrained stem gives to a
        # grayscale image replicated on 3 channels. A random init would discard the
        # pretrained first layer.
        rgb_conv1 = self.resnet.conv1
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            self.resnet.conv1.weight.copy_(rgb_conv1.weight.sum(dim=1, keepdim=True))

        # Replace the classifier head with a single sigmoid unit.
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(num_features, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # torchvision's ResNet.forward runs conv1 -> bn1 -> relu -> maxpool -> layer1..4
        # -> avgpool -> flatten -> fc, i.e. exactly the hand-written sequence this replaces.
        return self.resnet(x)


# 4. 検証関数 (evaluate_model)
##############################################
def evaluate_model(model, val_loader, criterion, device="cuda"):
    """Evaluate a single-output binary classifier on ``val_loader``.

    Args:
        model: Module producing one score per sample (e.g. sigmoid probabilities),
            shape (batch, 1) or (batch,); outputs are flattened to (batch,).
        val_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss taking ``(outputs, labels)``. Assumed to use
            ``reduction="mean"`` (the ``nn.BCELoss`` default) so per-batch losses can be
            re-weighted by batch size.
        device: Device the tensors are moved to.

    Returns:
        Tuple ``(avg_loss, prauc, best_threshold, best_precision, best_recall)``:
        ``avg_loss`` is the per-sample mean loss. ``prauc`` is the area under the
        precision-recall curve. The last three describe the point with maximum F1.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.float32)

            # reshape(-1) rather than squeeze(): a final batch of one sample must stay 1-D.
            # Otherwise the 0-d output no longer matches the (1,) target shape and
            # .extend() below fails on a 0-d array.
            outputs = model(images).reshape(-1)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size  # undo the per-batch mean
            n_samples += batch_size

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(outputs.cpu().numpy())

    if n_samples == 0:
        raise ValueError("val_loader yielded no samples")
    avg_loss = total_loss / n_samples

    # Precision-recall curve and its trapezoidal area.
    # NOTE(review): sklearn's average_precision_score is the less optimistic PR summary;
    # auc() is kept so values stay comparable with earlier runs.
    precision, recall, thresholds = precision_recall_curve(all_labels, all_preds)
    prauc = auc(recall, precision)

    # Pick the threshold that maximises F1. precision/recall have one more entry than
    # thresholds (the final precision=1, recall=0 point), which has no threshold;
    # fall back to 0.5 if that point is selected.
    f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
    best_idx = np.argmax(f1_scores)
    best_threshold = thresholds[best_idx] if best_idx < len(thresholds) else 0.5
    best_precision = precision[best_idx]
    best_recall = recall[best_idx]

    return avg_loss, prauc, best_threshold, best_precision, best_recall




In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device, desc):
    """Run one optimisation pass over ``loader`` and return the per-sample mean loss.

    Assumes ``criterion`` uses ``reduction="mean"`` (the ``nn.BCELoss`` default), so each
    batch loss is re-weighted by its batch size before averaging.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    for images, labels in tqdm(loader, desc=desc):
        images = images.to(device, dtype=torch.float32)
        labels = labels.to(device, dtype=torch.float32)

        optimizer.zero_grad()
        # reshape(-1) rather than squeeze(): a final batch of one sample must stay 1-D,
        # otherwise the 0-d output no longer matches the (1,) target shape.
        outputs = model(images).reshape(-1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    if n_samples == 0:
        raise ValueError("train_loader yielded no samples")
    return total_loss / n_samples


def train_model(model, train_loader, val_loader, criterion, optimizer,
                scheduler=None, num_epochs=20, device="cuda"):
    """Train ``model``, validate each epoch, log to wandb and checkpoint the best PR-AUC.

    Each call opens one wandb run, which is always finished, even if training raises.
    The best weights of the run go to ``best_model_<timestamp>.pth``. The timestamp is
    taken once at the start of the run, so each improvement overwrites the same file
    instead of leaving a stale checkpoint per improving epoch.

    Args:
        model: Module whose ``resnet.fc[0]`` attribute is read for the logged dropout rate.
        train_loader, val_loader: Iterables of ``(images, labels)`` batches.
        criterion: Loss taking ``(outputs, labels)``.
        optimizer: Optimizer; lr / weight_decay of its first param group are logged.
        scheduler: Optional LR scheduler. ``ReduceLROnPlateau`` is stepped with the
            validation loss; any other scheduler is stepped without arguments once per epoch.
        num_epochs: Number of training epochs.
        device: Device the tensors are moved to.

    Returns:
        Best validation PR-AUC observed (0.0 if it never exceeded 0).
    """
    run = wandb.init(project=WANDB_PROJECT, reinit=True,
                     config=dict(
                         epochs=num_epochs,
                         lr=optimizer.param_groups[0]["lr"],
                         weight_decay=optimizer.param_groups[0]["weight_decay"],
                         dropout=getattr(model.resnet.fc[0], "p", None)
                     ))
    best_prauc = 0.0
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")  # e.g. 20250414_113045
    model_path = f"best_model_{timestamp}.pth"
    try:
        wandb.watch(model, log="all", log_freq=100)

        for epoch in range(num_epochs):
            avg_train_loss = _train_one_epoch(
                model, train_loader, criterion, optimizer, device,
                desc=f"[Train] {epoch+1}/{num_epochs}",
            )
            val_loss, val_prauc, th, prec, rec = evaluate_model(model, val_loader, criterion, device)

            run.log({
                "epoch": epoch+1,
                "train_loss": avg_train_loss,
                "val_loss":   val_loss,
                "val_prauc":  val_prauc,
                "best_th":    th,
                "precision":  prec,
                "recall":     rec,
                "lr": optimizer.param_groups[0]["lr"]
            })

            if scheduler is not None:
                # Only ReduceLROnPlateau takes a metric; other schedulers step per epoch.
                if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    scheduler.step(val_loss)
                else:
                    scheduler.step()

            if val_prauc > best_prauc:
                best_prauc = val_prauc
                torch.save(model.state_dict(), model_path)
                run.summary["best_prauc"] = best_prauc
                run.summary["best_epoch"] = epoch+1
                run.summary["saved_model_path"] = model_path
    finally:
        # Close the run even on failure so the next wandb.init starts cleanly.
        run.finish()
    return best_prauc


def grid_search(train_loader, val_loader, device="cuda",
                num_epochs_list=(20, 30), lr_list=(1e-5, 3e-5),
                weight_decay_list=(1e-3, 1e-2), dropout_rate_list=(0.3, 0.0)):
    """Train one model per combination in the Cartesian product of the given grids.

    Each combination gets a fresh ModifiedResNet, BCE loss, Adam optimizer and
    ReduceLROnPlateau scheduler. It is trained with ``train_model``, which opens its own
    wandb run. The grid defaults reproduce the original hard-coded search.

    Args:
        train_loader, val_loader: Loaders passed through to ``train_model``.
        device: Device the models are moved to.
        num_epochs_list, lr_list, weight_decay_list, dropout_rate_list: Values to sweep.

    Returns:
        list[dict]: One entry per combination, in iteration order, with keys
        ``"num_epochs"``, ``"lr"``, ``"weight_decay"``, ``"dropout"`` and ``"best_prauc"``.
    """
    results = []
    grid = product(num_epochs_list, lr_list, weight_decay_list, dropout_rate_list)
    for num_epochs, lr, wd, dropout in grid:
        model = ModifiedResNet(dropout_rate=dropout).to(device)
        criterion = nn.BCELoss()
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
        best_prauc = train_model(model, train_loader, val_loader, criterion, optimizer,
                                 scheduler=scheduler, num_epochs=num_epochs, device=device)
        print(f"Finished: Ep{num_epochs} LR{lr} WD{wd} DO{dropout} → PRAUC {best_prauc:.4f}")
        results.append({
            "num_epochs": num_epochs,
            "lr": lr,
            "weight_decay": wd,
            "dropout": dropout,
            "best_prauc": best_prauc,
        })
    return results

: 

# sweep_config = {
#     "method": "random",
#     "metric": {"name": "val_prauc", "goal": "maximize"},
#     "parameters": {
#         "epochs":       {"values": [20, 30]},
#         "lr":           {"values": [1e-5, 3e-5]},
#         "weight_decay": {"values": [1e-3, 1e-2]},
#         "dropout":      {"values": [0.0, 0.3]}
#     }
# }
# sweep_id = wandb.sweep(sweep_config, project=WANDB_PROJECT)
# 
# def sweep_train():
#     # wandb.config からパラメータ取得
#     c = wandb.config
#     model = ModifiedResNet(dropout_rate=c.dropout).to(device)
#     criterion = nn.BCELoss()
#     optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=c.weight_decay)
#     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
#     train_model(model, train_loader, val_loader, criterion, optimizer,
#                 scheduler=scheduler, num_epochs=c.epochs, device=device)
# 
# wandb.agent(sweep_id, function=sweep_train, count=20)   # 例: 20 試行
# 

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Label CSVs read by CTDataset (image path + label per row).
DATA_ROOT = "/mnt/nfs1/home/yamamoto-hiroto/research/vertebrae/Sakaguchi_file"
train_csv = os.path.join(DATA_ROOT, "slice_train_sampling", "axial", "sampling_labels_axial_2.csv")
val_csv = os.path.join(DATA_ROOT, "slice_val", "axial", "val_labels_axial.csv")


def _build_transform(*augmentations):
    """Resize to 224x224, apply optional PIL-level augmentations, then map to a tensor in [-1, 1]."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])


transform_train = _build_transform(transforms.RandomHorizontalFlip(), transforms.RandomRotation(15))
transform_val = _build_transform()

batch_size = 64
loader_kwargs = dict(batch_size=batch_size, num_workers=4)
train_dataset = CTDataset(train_csv, transform=transform_train)
val_dataset = CTDataset(val_csv, transform=transform_val)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# ---- Run ----
grid_results = grid_search(train_loader, val_loader, device=device)
# Alternatively, to run a wandb Sweep instead, uncomment and run the sweep cell above.