# Vesuvius Challenge - GPU/CPU Training with SWI
元のコードから最小限の変更 + SlidingWindowInference保持

In [None]:
# Select PyTorch as the Keras backend (works on GPU or CPU).
# This assignment sits above the `import keras` statement further down, so the
# environment variable is already in place when Keras is imported.
import os
import warnings
os.environ["KERAS_BACKEND"] = "torch"  # switched from jax to torch
# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

In [None]:
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import keras
from keras import ops
from keras.optimizers import AdamW
from keras.optimizers.schedules import CosineDecay

# tf.dataのみTensorFlow使用
import tensorflow as tf

import medicai
from medicai.transforms import (
    Compose,
    NormalizeIntensity,
    ScaleIntensityRange,
    Resize,
    RandShiftIntensity,
    RandRotate90,
    RandRotate,
    RandFlip,
    RandCutOut,
    RandSpatialCrop
)
from medicai.models import SegFormer, TransUNet, UNETRPlusPlus
from medicai.losses import SparseDiceCELoss, SparseCenterlineDiceLoss
from medicai.metrics import SparseDiceMetric
from medicai.callbacks import SlidingWindowInferenceCallback
from medicai.utils import SlidingWindowInference

In [None]:
# Run configuration.
# Patch size used for random crops, the model input and the sliding-window
# ROI; shrink or grow it to fit the available GPU/CPU memory.
input_shape = (64,) * 3
# Training batch size; tune to the available GPU/CPU memory.
batch_size = 2
# Number of segmentation classes (class id 2 is ignored by losses/metrics).
num_classes = 3
# Sample count used to derive the number of optimizer steps per epoch.
num_samples = 780
# Number of training epochs (the original setup used 200).
epochs = 100

## Data Loader

In [None]:
def parse_tfrecord_fn(example):
    """Decode one serialized TFRecord example into an ``(image, label)`` pair.

    Each record stores the raw bytes of the image and label volumes plus a
    three-element int64 shape for each of them. The bytes are decoded as
    uint8 and reshaped to the stored shape.

    Args:
        example: A scalar string tensor holding one serialized example.

    Returns:
        Tuple ``(image, label)`` of uint8 tensors reshaped to their stored
        3-element shapes.
    """
    schema = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.string),
        "image_shape": tf.io.FixedLenFeature([3], tf.int64),
        "label_shape": tf.io.FixedLenFeature([3], tf.int64),
    }
    record = tf.io.parse_single_example(example, schema)

    def _decode_volume(field):
        # Raw bytes -> flat uint8 vector -> volume with the stored shape.
        flat = tf.io.decode_raw(record[field], tf.uint8)
        shape = tf.cast(record[f"{field}_shape"], tf.int64)
        return tf.reshape(flat, shape)

    return _decode_volume("image"), _decode_volume("label")

In [None]:
def prepare_inputs(image, label):
    """Append a trailing channel axis and cast both volumes to float32.

    Args:
        image: Image volume tensor, e.g. ``(D, H, W)``.
        label: Label volume tensor with the same layout as ``image``.

    Returns:
        Tuple ``(image, label)`` as float32 tensors with one extra trailing
        axis of size 1, e.g. ``(D, H, W, 1)``.
    """
    def _with_channel_as_float(volume):
        return tf.cast(tf.expand_dims(volume, axis=-1), tf.float32)

    return _with_channel_as_float(image), _with_channel_as_float(label)

In [None]:
def train_transformation(image, label):
    """Run the training augmentation pipeline on one ``(image, label)`` pair.

    The transforms are applied in this order: random spatial crop to
    ``input_shape``, one random flip per spatial axis, random 90-degree
    rotation, random rotation, intensity normalization, random intensity
    shift, and random cut-out. Geometric transforms and the cut-out receive
    both keys; intensity transforms receive the image only. Label value 2 is
    passed as ``invalid_label`` to the crop and cut-out transforms (the same
    class id that the losses and metrics ignore).

    Args:
        image: Float image volume with a trailing channel axis.
        label: Float label volume with a trailing channel axis.

    Returns:
        Tuple ``(image, label)`` after augmentation.
    """
    paired = ("image", "label")

    # Geometric transforms: applied to image and label together.
    geometric = [
        RandSpatialCrop(
            keys=list(paired),
            roi_size=input_shape,
            random_center=True,
            random_size=False,
            invalid_label=2,
            min_valid_ratio=0.5,
            max_attempts=10,
        ),
    ]
    geometric += [
        RandFlip(keys=list(paired), spatial_axis=[axis], prob=0.5)
        for axis in range(3)
    ]
    geometric += [
        RandRotate90(
            keys=list(paired),
            prob=0.4,
            max_k=3,
            spatial_axes=(0, 1),
        ),
        RandRotate(
            keys=list(paired),
            factor=0.2,
            prob=0.7,
            fill_mode="crop",
        ),
    ]

    # Intensity transforms: applied to the image only.
    intensity = [
        NormalizeIntensity(keys=["image"], nonzero=True, channel_wise=False),
        RandShiftIntensity(keys=["image"], offsets=0.10, prob=0.5),
    ]

    # Cut-out: mask edge is a quarter of the 2nd and 3rd crop dimensions.
    cutout = [
        RandCutOut(
            keys=list(paired),
            invalid_label=2,
            mask_size=[extent // 4 for extent in input_shape[1:3]],
            fill_mode="constant",
            cutout_mode='volume',
            prob=0.8,
            num_cuts=5,
        ),
    ]

    augment = Compose([*geometric, *intensity, *cutout])
    augmented = augment({"image": image, "label": label})
    return augmented["image"], augmented["label"]

def val_transformation(image, label):
    """Run the validation pipeline: intensity normalization of the image only.

    No random augmentation or cropping is applied; the label is passed
    through the pipeline untouched by any transform.

    Args:
        image: Float image volume with a trailing channel axis.
        label: Float label volume with a trailing channel axis.

    Returns:
        Tuple ``(image, label)`` after the pipeline.
    """
    normalize = NormalizeIntensity(
        keys=["image"],
        nonzero=True,
        channel_wise=False,
    )
    processed = Compose([normalize])({"image": image, "label": label})
    return processed["image"], processed["label"]

In [None]:
def tfrecord_loader(tfrecord_pattern, batch_size=1, shuffle=True):
    """Build a batched, prefetched ``tf.data`` pipeline over TFRecord files.

    The ``shuffle`` flag also selects the mode: when true, records are
    shuffled, the training augmentation is applied and incomplete final
    batches are dropped; when false, the validation transform is applied and
    every sample is kept.

    Args:
        tfrecord_pattern: File pattern(s) or path(s) handed to
            ``tf.io.gfile.glob``.
        batch_size: Number of samples per batch.
        shuffle: Whether to build the training (True) or validation (False)
            variant of the pipeline.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image, label)`` batches.
    """
    autotune = tf.data.AUTOTUNE

    dataset = tf.data.TFRecordDataset(tf.io.gfile.glob(tfrecord_pattern))
    if shuffle:
        # Shuffle the serialized records, before any decoding work.
        dataset = dataset.shuffle(buffer_size=100)

    transform = train_transformation if shuffle else val_transformation
    for stage in (parse_tfrecord_fn, prepare_inputs, transform):
        dataset = dataset.map(stage, num_parallel_calls=autotune)

    batched = dataset.batch(batch_size, drop_remainder=shuffle)
    return batched.prefetch(autotune)

In [None]:
# Load the TFRecord shards, ordered by the integer suffix of each file name
# (the text after the last "_" with the ".tfrec" extension removed).
all_tfrec = sorted(
    glob.glob("/kaggle/input/vesuvius-tfrecords/*.tfrec"),
    key=lambda x: int(x.split("_")[-1].replace(".tfrec", ""))
)
if not all_tfrec:
    # Fail early with a clear message instead of an IndexError further down.
    raise FileNotFoundError(
        "No .tfrec files found in /kaggle/input/vesuvius-tfrecords/"
    )

# Index of the shard held out for validation (negative values count from the
# end, as with normal list indexing).
val_idx = -1
# Resolve val_idx to an absolute position so the hold-out shard is excluded
# from training for both negative and non-negative indices. Indexing a range
# raises IndexError for out-of-range values, matching list indexing.
val_position = range(len(all_tfrec))[val_idx]
val_patterns = [all_tfrec[val_position]]
train_patterns = [
    f for i, f in enumerate(all_tfrec) if i != val_position
]

train_ds = tfrecord_loader(
    train_patterns, batch_size=batch_size, shuffle=True
)
val_ds = tfrecord_loader(
    val_patterns, batch_size=1, shuffle=False
)

print(f"Train files: {len(train_patterns)}")
print(f"Val files: {len(val_patterns)}")

## Visualization

In [None]:
# Sanity check: fetch the first validation batch and report its shapes.
first_val_batch = next(iter(val_ds))
x, y = first_val_batch
for caption, tensor in (("Input", x), ("Label", y)):
    print(f"{caption} shape: {tensor.shape}")

In [None]:
def plot_sample(x, y, sample_idx=0, max_slices=4):
    """Plot evenly spaced slices of one sample: image on top, mask below.

    Args:
        x: Batch of image volumes; ``x[sample_idx]`` is squeezed to
            ``(D, H, W)``.
        y: Batch of mask volumes; ``y[sample_idx]`` is squeezed to
            ``(D, H, W)``.
        sample_idx: Index of the sample inside the batch to display.
        max_slices: Maximum number of slices (columns) to draw.
    """
    img = np.squeeze(x[sample_idx])  # (D, H, W)
    mask = np.squeeze(y[sample_idx])  # (D, H, W)
    D = img.shape[0]

    # Evenly spaced slice indices along the first axis, capped at max_slices.
    step = max(1, D // max_slices)
    slices = range(0, D, step)[:max_slices]

    n_slices = len(slices)
    # squeeze=False keeps `axes` two-dimensional even when only one slice is
    # drawn; otherwise `axes[0, i]` fails for a single column.
    fig, axes = plt.subplots(
        2, n_slices, figsize=(3*n_slices, 6), squeeze=False
    )

    for i, s in enumerate(slices):
        axes[0, i].imshow(img[s], cmap='gray')
        axes[0, i].set_title(f"Slice {s}")
        axes[0, i].axis('off')

        axes[1, i].imshow(mask[s], cmap='gray')
        axes[1, i].set_title(f"Mask {s}")
        axes[1, i].axis('off')

    plt.suptitle(f"Sample {sample_idx}")
    plt.tight_layout()
    plt.show()

plot_sample(x, y, sample_idx=0, max_slices=4)

## Model

In [None]:
# Model definition: pick one of the three supported architectures.
MODEL_TYPE = "segformer"  # "segformer", "transunet", "unetr++"

if MODEL_TYPE == "segformer":
    model = SegFormer(
        input_shape=input_shape + (1,),
        encoder_name='mit_b0',
        classifier_activation='softmax',
        num_classes=num_classes,
    )
elif MODEL_TYPE == "transunet":
    model = TransUNet(
        encoder_name='seresnext50',
        input_shape=input_shape + (1,),
        num_classes=num_classes,
        classifier_activation='softmax'
    )
elif MODEL_TYPE == "unetr++":
    model = UNETRPlusPlus(
        input_shape=input_shape + (1,),
        encoder_name='unetr_plusplus_encoder',
        classifier_activation='softmax',
        num_classes=num_classes,
    )
else:
    # Without this branch a typo in MODEL_TYPE leaves `model` undefined and
    # surfaces later as a confusing NameError.
    raise ValueError(
        f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected one of "
        "'segformer', 'transunet', 'unetr++'"
    )

print(f"Model: {MODEL_TYPE}")
print(f"Parameters: {model.count_params() / 1e6:.2f}M")

## Training Setup

In [None]:
# Learning-rate schedule: warmup over the first 5% of all steps, then cosine
# decay over the remaining steps.
steps_per_epoch = num_samples // batch_size
total_steps = epochs * steps_per_epoch
warmup_steps = int(total_steps * 0.05)
# Guard against a zero-length decay phase.
decay_steps = max(1, total_steps - warmup_steps)

schedule_config = {
    "initial_learning_rate": 1e-6,
    "decay_steps": decay_steps,
    # Warmup target scales with batch size (1e-4 at batch_size=2), capped at 3e-4.
    "warmup_target": min(3e-4, 1e-4 * (batch_size / 2)),
    "warmup_steps": warmup_steps,
    "alpha": 0.1,
}
lr_schedule = CosineDecay(**schedule_config)

In [None]:
# Optimizer, loss functions and metrics.
optim = AdamW(learning_rate=lr_schedule, weight_decay=1e-5)

# Dice + cross-entropy over all classes except the ignored class id 2.
dice_ce_loss_fn = SparseDiceCELoss(
    from_logits=False,
    num_classes=num_classes,
    ignore_class_ids=2,
)

# Centerline Dice on class 1; iteration count reduced for GPU/CPU
# (the original setup used 50).
cldice_loss_fn = SparseCenterlineDiceLoss(
    from_logits=False,
    num_classes=num_classes,
    target_class_ids=1,
    ignore_class_ids=2,
    iters=20,
)


def combined_loss_fn(y_true, y_pred):
    """Return the sum of the Dice-CE loss and the centerline Dice loss."""
    dice_ce_term = dice_ce_loss_fn(y_true, y_pred)
    cldice_term = cldice_loss_fn(y_true, y_pred)
    return dice_ce_term + cldice_term


train_dice_metric = SparseDiceMetric(
    from_logits=False,
    num_classes=num_classes,
    ignore_class_ids=2,
    name='dice',
)
metrics = [train_dice_metric]

model.compile(optimizer=optim, loss=combined_loss_fn, metrics=metrics)

In [None]:
# SlidingWindowInferenceCallback (important!): validates on full volumes with
# sliding-window inference during training.
swi_callback_metric = SparseDiceMetric(
    from_logits=False,
    ignore_class_ids=2,
    num_classes=num_classes,
    name='val_dice',
)

swi_callback = SlidingWindowInferenceCallback(
    model,
    dataset=val_ds,
    metrics=swi_callback_metric,
    num_classes=num_classes,
    interval=5,  # evaluate every 5 epochs
    overlap=0.5,
    mode='gaussian',
    roi_size=input_shape,
    sw_batch_size=batch_size,  # sized for GPU/CPU
    save_path="model.weights.h5"  # reloaded as the best model after training
)

# ReduceLROnPlateau(monitor='val_loss') was dropped from this list:
# model.fit() is called without validation data, so 'val_loss' is never
# logged and the callback could never trigger. It also cannot coexist with
# the CosineDecay schedule, because Keras refuses to overwrite the learning
# rate of an optimizer that was built with a LearningRateSchedule.
callbacks = [
    swi_callback,
    # NOTE(review): EarlyStopping only acts if 'val_dice' appears in the epoch
    # logs. That requires SlidingWindowInferenceCallback to publish its score
    # under that key; confirm against the medicai implementation. It is listed
    # after swi_callback so any such log entry exists when it runs.
    keras.callbacks.EarlyStopping(
        monitor='val_dice',
        patience=15,
        mode='max',
        restore_best_weights=True,
        verbose=1
    ),
]

## Training

In [None]:
# Run training; validation happens inside the callbacks, so no validation
# data is passed to fit().
print("Starting training...")
for caption, value in (
    ("Epochs", epochs),
    ("Batch size", batch_size),
    ("Input shape", input_shape),
):
    print(f"{caption}: {value}")

fit_options = {"epochs": epochs, "callbacks": callbacks, "verbose": 1}
history = model.fit(train_ds, **fit_options)

## Evaluation

In [None]:
# Restore the best checkpoint (the file passed as `save_path` to the
# sliding-window callback).
best_weights_path = "model.weights.h5"
model.load_weights(best_weights_path)

# Sliding-window inference wrapper used for evaluation, configured with the
# same ROI size, blending mode and overlap as the training-time callback.
swi_options = {
    "num_classes": num_classes,
    "roi_size": input_shape,
    "mode": 'gaussian',
    "sw_batch_size": batch_size,
    "overlap": 0.5,
}
swi = SlidingWindowInference(model, **swi_options)

In [None]:
# Compute the Dice score over the whole validation set.
dice = SparseDiceMetric(
    from_logits=False,
    num_classes=num_classes,
    ignore_class_ids=2,
    name='dice',
)

for x, y in val_ds:
    # Full-volume prediction via sliding-window inference.
    output = swi(x)
    dice.update_state(
        ops.convert_to_tensor(y),
        ops.convert_to_tensor(output),
    )

# Read the aggregated result before clearing the metric state.
dice_score = float(ops.convert_to_numpy(dice.result()))
print(f"Final Dice Score: {dice_score:.4f}")
dice.reset_state()

In [None]:
# Visualize a prediction for the first validation sample.
x, y = next(iter(val_ds))
y_pred = swi(x)
# Most probable class per voxel, stored compactly as uint8.
segment = np.argmax(y_pred, axis=-1).astype(np.uint8)

print(f"Input shape: {x.shape}")
print(f"Prediction shape: {segment.shape}")
print(f"Unique classes: {np.unique(segment)}")

plot_sample(x, segment, sample_idx=0, max_slices=4)

In [None]:
# Learning curves: loss on the left panel, Dice on the right panel.
# Each entry: (subplot position, train key, train label,
#              validation key, validation label, y-axis label, title).
curve_panels = (
    (1, 'loss', 'Train Loss', 'val_loss', 'Val Loss', 'Loss', 'Training Loss'),
    (2, 'dice', 'Train Dice', 'val_dice', 'Val Dice', 'Dice Score', 'Dice Score'),
)

plt.figure(figsize=(12, 4))
logged = history.history

for (position, train_key, train_label,
     val_key, val_label, y_label, panel_title) in curve_panels:
    plt.subplot(1, 2, position)
    plt.plot(logged[train_key], label=train_label)
    # Validation curves are drawn only when they were actually logged.
    if val_key in logged:
        plt.plot(logged[val_key], label=val_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

## Save Final Model

In [None]:
# Save the model weights. The weights written here are the ones currently
# held by `model`, i.e. those loaded from "model.weights.h5" before evaluation.
final_weights_path = "final_model.weights.h5"
model.save_weights(final_weights_path)
print(f"Model saved to {final_weights_path}")
print(f"Best Dice Score: {dice_score:.4f}")