# Vesuvius Challenge - 元のコードと同じ（環境対応のみ変更）

In [None]:
import os
import warnings

# Select the Keras backend through the environment. This is done before
# `keras` is imported further down so the choice is picked up at import time.
# (Switched from jax to torch; that is the only change.)
os.environ["KERAS_BACKEND"] = "torch"

# Suppress all warnings for the rest of the session.
warnings.filterwarnings("ignore")

In [None]:
import glob
import numpy as np
import pandas as pd

from PIL import Image
import matplotlib.pyplot as plt

import keras
from keras import ops
from keras.optimizers import SGD, AdamW, Muon
from keras.optimizers.schedules import CosineDecay, PolynomialDecay

import tensorflow as tf

import medicai
from medicai.transforms import (
    Compose,
    NormalizeIntensity,
    ScaleIntensityRange,
    Resize,
    RandShiftIntensity,
    RandRotate90,
    RandRotate,
    RandFlip,
    RandCutOut,
    RandSpatialCrop
)
from medicai.layers import ResizingND
from medicai.models import (
    UNet, SegFormer, TransUNet, SwinUNETR, UPerNet, ConvNeXtV2Tiny, UNETRPlusPlus
)
from medicai.losses import (
    SparseDiceCELoss, SparseTverskyLoss, SparseCenterlineDiceLoss
)
from medicai.metrics import SparseDiceMetric
from medicai.callbacks import SlidingWindowInferenceCallback
from medicai.utils import SlidingWindowInference
from medicai.utils import soft_skeletonize

In [None]:
# Fix the random seed through Keras' utility so runs are repeatable.
keras.utils.set_random_seed(
    101
)

# TPU setup was removed; everything runs on a single device.
total_device = 1
print(
    f'total device: {total_device}'
)

In [None]:
# Cell output: Keras version, active Keras backend, and medicai version.
(
    keras.version(),
    keras.config.backend(),
    medicai.version(),
)

## Data Loader

In [None]:
# Training configuration shared by the data pipeline, model and schedules.
input_shape = (128, 128, 128)  # spatial patch size (crop ROI and sliding-window ROI)
batch_size = total_device      # one sample per device
num_classes = 3                # class id 2 is passed as the ignored/invalid label elsewhere
num_samples = 780              # used to derive steps per epoch — presumably the training-set size; TODO confirm
epochs = 200

In [None]:
def parse_tfrecord_fn(example):
    """Decode one serialized TFRecord example into an ``(image, label)`` pair.

    Each record carries the raw bytes of the image and label volumes plus a
    3-element int64 shape for each. The bytes are decoded as uint8 and
    reshaped to the stored shape.

    Args:
        example: A serialized example, as yielded by ``tf.data.TFRecordDataset``.

    Returns:
        Tuple ``(image, label)`` of uint8 tensors with their stored shapes.
    """
    bytes_feature = tf.io.FixedLenFeature([], tf.string)
    shape_feature = tf.io.FixedLenFeature([3], tf.int64)
    schema = {
        "image": bytes_feature,
        "label": bytes_feature,
        "image_shape": shape_feature,
        "label_shape": shape_feature,
    }
    record = tf.io.parse_single_example(example, schema)

    def _restore(name):
        # Flat uint8 buffer -> volume with the shape stored next to it.
        flat = tf.io.decode_raw(record[name], tf.uint8)
        dims = tf.cast(record[f"{name}_shape"], tf.int64)
        return tf.reshape(flat, dims)

    return _restore("image"), _restore("label")

In [None]:
def prepare_inputs(image, label):
    """Append a trailing channel axis to both tensors and cast them to float32.

    Args:
        image: Image tensor without a channel axis.
        label: Label tensor without a channel axis.

    Returns:
        Tuple ``(image, label)``, each with one extra trailing axis of size 1
        and dtype float32.
    """
    def _with_channel_as_float(volume):
        return tf.cast(volume[..., None], tf.float32)

    return _with_channel_as_float(image), _with_channel_as_float(label)

In [None]:
def train_transformation(image, label):
    """Apply the training augmentation pipeline to one image/label pair.

    The pipeline runs, in order: a random spatial crop to ``input_shape``,
    random flips along each of the three spatial axes, a random 90-degree
    rotation, a random rotation, intensity normalisation and a random
    intensity shift (image only), and finally random cut-out.

    Args:
        image: Image tensor for a single sample.
        label: Label tensor for the same sample.

    Returns:
        Tuple ``(image, label)`` after augmentation.
    """
    # Geometric transforms act on image and label together so they stay aligned.
    geometric = [
        RandSpatialCrop(
            keys=["image", "label"],
            roi_size=input_shape,
            random_center=True,
            random_size=False,
            # Label value 2 is passed as the "invalid" id; presumably crops
            # dominated by it are re-drawn (see min_valid_ratio /
            # max_attempts) — TODO confirm against medicai docs.
            invalid_label=2,
            min_valid_ratio=0.5,
            max_attempts=10,
        ),
        *[
            RandFlip(keys=["image", "label"], spatial_axis=[axis], prob=0.5)
            for axis in (0, 1, 2)
        ],
        RandRotate90(
            keys=["image", "label"],
            prob=0.4,
            max_k=3,
            spatial_axes=(0, 1),
        ),
        RandRotate(
            keys=["image", "label"],
            factor=0.2,
            prob=0.7,
            fill_mode="crop",
        ),
    ]

    # Intensity transforms touch the image only.
    intensity = [
        NormalizeIntensity(keys=["image"], nonzero=True, channel_wise=False),
        RandShiftIntensity(keys=["image"], offsets=0.10, prob=0.5),
    ]

    # Cut-out is applied last; the mask size is a quarter of the patch size
    # along axes 1 and 2.
    cutout = [
        RandCutOut(
            keys=["image", "label"],
            invalid_label=2,
            mask_size=[input_shape[1] // 4, input_shape[2] // 4],
            fill_mode="constant",
            cutout_mode='volume',
            prob=0.8,
            num_cuts=5,
        ),
    ]

    augment = Compose(geometric + intensity + cutout)
    augmented = augment({"image": image, "label": label})
    return augmented["image"], augmented["label"]


def val_transformation(image, label):
    """Apply the validation pipeline: intensity normalisation of the image only.

    Args:
        image: Image tensor for a single sample.
        label: Label tensor for the same sample; not listed in the transform's keys.

    Returns:
        Tuple ``(image, label)`` taken from the pipeline output.
    """
    normalize = NormalizeIntensity(
        keys=["image"],
        nonzero=True,
        channel_wise=False,
    )
    processed = Compose([normalize])({"image": image, "label": label})
    return processed["image"], processed["label"]

In [None]:
def tfrecord_loader(
    tfrecord_pattern, batch_size=1, shuffle=True, shuffle_buffer_size=100
):
    """Build a batched, prefetched ``tf.data`` pipeline from TFRecord files.

    ``shuffle`` doubles as the train/validation switch: when true, records
    are shuffled, the training augmentations are applied and incomplete
    final batches are dropped; when false, records keep their order, only
    the validation transform is applied and the last partial batch is kept.

    Args:
        tfrecord_pattern: Glob pattern, or list of patterns/paths, passed to
            ``tf.io.gfile.glob``.
        batch_size: Number of samples per batch.
        shuffle: Whether to build the training (True) or validation (False)
            variant of the pipeline.
        shuffle_buffer_size: Size of the shuffle buffer used when
            ``shuffle`` is true. Defaults to the previously hard-coded 100.

    Returns:
        A ``tf.data.Dataset`` yielding ``(image, label)`` batches.

    Raises:
        FileNotFoundError: If no file matches ``tfrecord_pattern``.
    """
    filenames = tf.io.gfile.glob(tfrecord_pattern)
    if not filenames:
        # Fail early with a clear message instead of building a pipeline
        # that has no input files.
        raise FileNotFoundError(
            f"No TFRecord files matched: {tfrecord_pattern!r}"
        )

    dataset = tf.data.TFRecordDataset(filenames)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)

    dataset = dataset.map(
        parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.map(
        prepare_inputs, num_parallel_calls=tf.data.AUTOTUNE
    )

    transform = train_transformation if shuffle else val_transformation
    dataset = dataset.map(transform, num_parallel_calls=tf.data.AUTOTUNE)

    dataset = dataset.batch(batch_size, drop_remainder=shuffle)
    return dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
# Collect the TFRecord shards, ordered by the integer that follows the last
# underscore in each path (e.g. "..._12.tfrec" -> 12).
all_tfrec = sorted(
    glob.glob("/kaggle/input/vesuvius-tfrecords/*.tfrec"),
    key=lambda x: int(x.split("_")[-1].replace(".tfrec", ""))
)
if not all_tfrec:
    raise FileNotFoundError(
        "No .tfrec files found under /kaggle/input/vesuvius-tfrecords/"
    )

# Hold out one shard for validation; every other shard is used for training.
val_idx = -1
val_patterns = [all_tfrec[val_idx]]  # out-of-range val_idx still raises IndexError
# Normalise to a non-negative position so the exclusion below is correct for
# both negative and non-negative val_idx (the previous `len + val_idx`
# comparison never matched for val_idx >= 0, leaking the validation shard
# into the training set).
val_pos = val_idx % len(all_tfrec)
train_patterns = [
    f for i, f in enumerate(all_tfrec) if i != val_pos
]

train_ds = tfrecord_loader(
    train_patterns, batch_size=batch_size, shuffle=True
)
val_ds = tfrecord_loader(
    val_patterns, batch_size=1, shuffle=False
)

In [None]:
# Pull the first validation batch and show its shapes as a sanity check.
val_iter = iter(val_ds)
x, y = next(val_iter)
(x.shape, y.shape)

## Model

In [None]:
# SegFormer with the 'mit_b0' encoder on single-channel patches; the softmax
# head means the model outputs probabilities rather than logits.
model = SegFormer(
    encoder_name='mit_b0',
    num_classes=num_classes,
    input_shape=(*input_shape, 1),
    classifier_activation='softmax',
)

# Cell output: parameter count in millions.
model.count_params() / 1e6

## LR Schedules and Optimizer

In [None]:
# Step budget: the first 5% of all optimizer steps are warmup, the rest decay.
steps_per_epoch = num_samples // batch_size
total_steps = epochs * steps_per_epoch
warmup_steps = int(total_steps * 0.05)
decay_steps = max(total_steps - warmup_steps, 1)  # never less than one step

# Warmup target scales with batch size and is capped at 3e-4.
peak_lr = min(3e-4, 1e-4 * (batch_size / 2))

# Warmup from 1e-6 to peak_lr, then cosine decay over the remaining steps.
# NOTE(review): alpha=0.1 is presumably the floor as a fraction of the
# decay's starting rate — confirm against the Keras CosineDecay docs.
lr_schedule = CosineDecay(
    initial_learning_rate=1e-6,
    warmup_steps=warmup_steps,
    warmup_target=peak_lr,
    decay_steps=decay_steps,
    alpha=0.1,
)

In [None]:
# Options shared by the losses and the metric: the model emits softmax
# probabilities (so from_logits=False) and class id 2 is passed as the
# ignored class.
_sparse_opts = dict(
    from_logits=False,
    num_classes=num_classes,
    ignore_class_ids=2,
)

optim = keras.optimizers.AdamW(weight_decay=1e-5, learning_rate=lr_schedule)

# Dice + cross-entropy term.
dice_ce_loss_fn = SparseDiceCELoss(**_sparse_opts)
# Centerline Dice term, targeting class id 1.
cldice_loss_fn = SparseCenterlineDiceLoss(
    target_class_ids=1,
    iters=50,
    **_sparse_opts,
)
# Training objective: unweighted sum of the two terms.
combined_loss_fn = lambda y_true, y_pred: (
    dice_ce_loss_fn(y_true, y_pred) + cldice_loss_fn(y_true, y_pred)
)

metrics = [SparseDiceMetric(name='dice', **_sparse_opts)]

model.compile(loss=combined_loss_fn, optimizer=optim, metrics=metrics)

In [None]:
# Dice metric handed to the sliding-window validation callback.
swi_callback_metric = SparseDiceMetric(
    name='val_dice',
    num_classes=num_classes,
    from_logits=False,
    ignore_class_ids=2,
)

# Callback that runs sliding-window inference over val_ds with the settings
# below. NOTE(review): presumably it runs every `interval` epochs and writes
# weights to `save_path` — confirm the exact save criterion in medicai.
swi_callback = SlidingWindowInferenceCallback(
    model,
    dataset=val_ds,
    metrics=swi_callback_metric,
    num_classes=num_classes,
    roi_size=input_shape,
    sw_batch_size=total_device,
    overlap=0.5,
    mode='gaussian',
    interval=5,
    save_path="model.weights.h5"
)

In [None]:
# Train on the augmented dataset; validation is handled by the callback.
model.fit(train_ds, epochs=epochs, callbacks=[swi_callback])

## Eval

In [None]:
# Load the weights file that the training callback was configured to write.
model.load_weights("model.weights.h5")

# Sliding-window inference wrapper, using the same window settings as the
# training callback.
swi = SlidingWindowInference(
    model,
    num_classes=num_classes,
    roi_size=input_shape,
    sw_batch_size=total_device,
    overlap=0.5,
    mode='gaussian',
)

In [None]:
# Fresh Dice metric for the evaluation pass (class id 2 ignored, inputs are
# probabilities rather than logits).
dice = SparseDiceMetric(
    name='dice',
    num_classes=num_classes,
    from_logits=False,
    ignore_class_ids=2,
)

In [None]:
# Accumulate Dice over every validation batch using sliding-window inference.
for x, y in val_ds:
    y = ops.convert_to_tensor(y)
    output = ops.convert_to_tensor(swi(x))
    dice.update_state(y, output)

# Report the aggregated score, then clear the metric's accumulated state.
dice_score = float(ops.convert_to_numpy(dice.result()))
print(f"Dice Score: {dice_score:.4f}")
dice.reset_state()

In [None]:
# Re-fetch the first validation batch for the qualitative check below.
first_batch = next(iter(val_ds))
x, y = first_batch
(x.shape, y.shape)

In [None]:
# Run sliding-window inference on the batch and show the output shape.
y_pred = swi(
    x
)
(y_pred.shape)

In [None]:
# Collapse the class axis to per-voxel class ids and store them as uint8.
class_ids = y_pred.argmax(-1)
segment = class_ids.astype(np.uint8)
# Cell output: mask shape and the set of predicted class ids.
(segment.shape, np.unique(segment))