In [None]:
!pip install tensorflow-metal

In [1]:
import glob
import os
from collections import defaultdict

import h5py
import numpy as np
import tensorflow as tf

def _slice_index(path):
    """Return the integer slice number encoded at the end of a slice file name."""
    # Parse the file name only, so directory names can never leak into the number
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.split('_')[-1])


def group_slices_by_volume(file_list):
    """
    Group slice file paths by the volume (patient) they belong to.

    File names are expected to look like "volume_1_slice_0.h5": the first two
    underscore-separated tokens form the volume ID and the last token (before
    the extension) is the slice number.

    Args:
        file_list: iterable of slice file paths.

    Returns:
        defaultdict(list) mapping volume ID (e.g. "volume_1") to that volume's
        paths, ordered by ascending slice number.

    Raises:
        ValueError: if a file name does not end in an integer slice number.
    """
    volume_files = defaultdict(list)

    for path in file_list:
        # os.path handles the platform's separator instead of assuming '/'
        filename = os.path.basename(path)
        volume_id = '_'.join(filename.split('_')[:2])  # "volume_1"
        volume_files[volume_id].append(path)

    # Numeric sort, so slice_10 comes after slice_2 (a plain string sort would not)
    for vol_id in volume_files:
        volume_files[vol_id] = sorted(volume_files[vol_id], key=_slice_index)

    return volume_files

def load_volume_from_slices(slice_files, target_depth=64):
    """
    Load a 3D volume from an ordered list of per-slice HDF5 files.

    A window of ``target_depth`` consecutive slices, centred in ``slice_files``,
    is read. Each file must contain an 'image' and a 'mask' dataset.

    Args:
        slice_files: slice file paths, already sorted by slice number.
        target_depth: number of slices to read. If fewer files are supplied,
            all of them are read and the returned depth is smaller.

    Returns:
        (images, masks): two float32 arrays stacked along a new leading slice
        axis. Images are min-max scaled to [0, 1] per slice; masks are
        binarised (1.0 wherever the stored mask is > 0).
    """
    images = []
    masks = []

    # Centre the window (middle slices often have more tumor). Clamp at 0: a
    # negative start would make the slice below count from the END of the list
    # and silently drop the leading slices when the list is shorter than
    # target_depth.
    start_idx = max(0, (len(slice_files) - target_depth) // 2)
    selected_files = slice_files[start_idx:start_idx + target_depth]

    for file_path in selected_files:
        with h5py.File(file_path, 'r') as f:
            image = f['image'][:]  # (240, 240, 4) per the original author's note
            mask = f['mask'][:]    # (240, 240, 3) per the original author's note

        # Crop to save memory.
        # NOTE(review): this keeps the top-left 192x192 corner, not a centred
        # crop -- confirm no anatomy of interest is cut off.
        image = image[:192, :192, :]
        mask = mask[:192, :192, :]

        # Min-max normalise. min/max are taken over the whole slice, i.e. over
        # all channels jointly; the epsilon avoids 0/0 on constant slices.
        lo = image.min()
        hi = image.max()
        # Cast to float32 so images match the masks' dtype and use half the
        # memory of float64.
        image = ((image - lo) / (hi - lo + 1e-8)).astype(np.float32)
        mask = (mask > 0).astype(np.float32)

        images.append(image)
        masks.append(mask)

    return np.array(images), np.array(masks)

def volume_generator_3d(volume_dict, batch_size=1, target_depth=64, min_tumor_voxels=500):
    """
    Endless generator that yields batches of complete patient volumes.

    Volume IDs are reshuffled (with NumPy's global RNG) at the start of every
    pass. A volume is skipped when it has fewer than ``target_depth`` slices,
    when loading it raises, or when its mask does not have more than
    ``min_tumor_voxels`` positive entries. Because of this a batch may hold
    fewer than ``batch_size`` volumes, and a pass may yield fewer batches than
    there are volumes.

    Args:
        volume_dict: mapping of volume ID -> sorted list of slice file paths.
        batch_size: maximum number of volumes per yielded batch.
        target_depth: number of slices loaded per volume.
        min_tumor_voxels: a volume is kept only if its mask has strictly more
            positive entries than this.

    Yields:
        (volumes, masks): arrays with the kept volumes stacked on a new
        leading batch axis.

    Raises:
        RuntimeError: if a full pass over ``volume_dict`` yields no batch at
            all; without this check the generator would loop forever.
    """
    volume_ids = list(volume_dict.keys())

    while True:
        np.random.shuffle(volume_ids)
        yielded_any = False

        for i in range(0, len(volume_ids), batch_size):
            volumes = []
            masks = []

            for vol_id in volume_ids[i:i + batch_size]:
                slice_files = volume_dict[vol_id]

                if len(slice_files) < target_depth:
                    continue

                # Best effort: report an unreadable volume and move on
                try:
                    vol, mask = load_volume_from_slices(slice_files, target_depth)
                except Exception as e:
                    print(f"Error loading {vol_id}: {e}")
                    continue

                if np.sum(mask > 0) > min_tumor_voxels:
                    volumes.append(vol)
                    masks.append(mask)

            if len(volumes) > 0:
                yielded_any = True
                yield np.array(volumes), np.array(masks)

        if not yielded_any:
            raise RuntimeError(
                f"No usable volume found: none of the {len(volume_ids)} volumes has at least "
                f"{target_depth} slices and more than {min_tumor_voxels} positive mask entries"
            )

# Get files and group by volume
data_dir = './data/BraTS/BraTS2020_training_data/content/data'
all_files = glob.glob(f'{data_dir}/*.h5')
print(f"Total files: {len(all_files)}")

volume_dict = group_slices_by_volume(all_files)
print(f"Total volumes (patients): {len(volume_dict)}")

# Split volumes (not slices!) into train/val so no patient is in both sets.
# glob returns files in a filesystem-dependent order, so sort the IDs first;
# otherwise the seeded shuffle starts from a different order on each machine
# and the split is not reproducible.
volume_ids = sorted(volume_dict.keys())
np.random.seed(42)
np.random.shuffle(volume_ids)

# 80% of the volumes for training, the rest for validation
split_idx = int(0.8 * len(volume_ids))
train_volume_ids = volume_ids[:split_idx]
val_volume_ids = volume_ids[split_idx:]

train_volume_dict = {vid: volume_dict[vid] for vid in train_volume_ids}
val_volume_dict = {vid: volume_dict[vid] for vid in val_volume_ids}

print(f"Train volumes: {len(train_volume_dict)}")
print(f"Val volumes: {len(val_volume_dict)}")

# Create generators
train_gen = volume_generator_3d(train_volume_dict, batch_size=1, target_depth=64)
val_gen = volume_generator_3d(val_volume_dict, batch_size=1, target_depth=64)

Total files: 57195
Total volumes (patients): 369
Train volumes: 295
Val volumes: 74


In [None]:
import random
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# Seed the NumPy, TensorFlow and stdlib random generators with one fixed value.
# np.random.seed resets NumPy's global RNG, which the volume generators use
# for their per-pass shuffle.
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)

# Smaller 3D U-Net
def build_unet_3d(input_shape=(64, 192, 192, 4)):
    """
    Assemble a small two-level 3D U-Net as a Keras functional model.

    The encoder has two stages (8 and 16 filters), each followed by 2x2x2 max
    pooling, then a 32-filter bottleneck. The decoder mirrors the encoder with
    2x2x2 upsampling and skip concatenations. A final 1x1x1 sigmoid convolution
    produces 3 output channels.

    Args:
        input_shape: shape of one sample without the batch axis. The three
            leading dims are pooled twice by 2 and upsampled twice by 2, so
            they should be divisible by 4 for the skip concatenations to line up.

    Returns:
        The uncompiled model mapping the input tensor to the 3-channel output.
    """
    conv_options = dict(activation='relu', padding='same')

    def double_conv(tensor, filters, drop_rate):
        """Apply two 3x3x3 ReLU convolutions followed by dropout."""
        tensor = layers.Conv3D(filters, (3, 3, 3), **conv_options)(tensor)
        tensor = layers.Conv3D(filters, (3, 3, 3), **conv_options)(tensor)
        return layers.Dropout(drop_rate)(tensor)

    inputs = layers.Input(input_shape)

    # Contracting path
    enc1 = double_conv(inputs, 8, 0.1)
    enc2 = double_conv(layers.MaxPooling3D((2, 2, 2))(enc1), 16, 0.1)

    # Bottleneck (slightly stronger dropout than the other stages)
    bottleneck = double_conv(layers.MaxPooling3D((2, 2, 2))(enc2), 32, 0.2)

    # Expanding path: upsample, join with the matching encoder stage, convolve
    upsampled = layers.UpSampling3D((2, 2, 2))(bottleneck)
    dec2 = double_conv(layers.concatenate([upsampled, enc2]), 16, 0.1)

    upsampled = layers.UpSampling3D((2, 2, 2))(dec2)
    dec1 = double_conv(layers.concatenate([upsampled, enc1]), 8, 0.1)

    # Independent per-channel probabilities (sigmoid, not softmax)
    outputs = layers.Conv3D(3, (1, 1, 1), activation='sigmoid')(dec1)

    return models.Model(inputs, outputs)

# Input shape matches the data pipeline: 64 slices per volume (the generators'
# target_depth), each cropped to 192x192 by the loader, with 4 image channels
# (channel count per the loader's shape comment).
model = build_unet_3d(input_shape=(64, 192, 192, 4))

# Loss functions
def dice_coef(y_true, y_pred, smooth=1e-5):
    """
    Soft Dice coefficient over all elements of the batch.

    Both tensors are flattened completely, so a single score is computed
    across every sample, voxel and channel rather than one score per channel.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor of the same shape.
        smooth: added to numerator and denominator; keeps the ratio defined
            (and equal to 1) when both tensors are all zeros.

    Returns:
        Scalar (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth).
    """
    truth_flat = tf.reshape(y_true, [-1])
    pred_flat = tf.reshape(y_pred, [-1])

    overlap = tf.reduce_sum(truth_flat * pred_flat)
    total = tf.reduce_sum(truth_flat) + tf.reduce_sum(pred_flat)

    return (2. * overlap + smooth) / (total + smooth)

def dice_loss(y_true, y_pred):
    """Dice loss: one minus the Dice coefficient, so perfect overlap gives 0."""
    overlap_score = dice_coef(y_true, y_pred)
    return 1 - overlap_score

def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
    """
    Binary focal loss averaged over every element.

    Element-wise binary cross-entropy is scaled by alpha * (1 - p_t) ** gamma,
    where p_t is the probability assigned to the true class, so confidently
    correct elements contribute little.

    Args:
        y_true: ground-truth tensor of 0/1 values.
        y_pred: predicted probabilities of the same shape.
        alpha: constant scale applied to every element (the same factor is
            used for positives and negatives).
        gamma: focusing exponent; 0 reduces the loss to alpha * mean BCE.

    Returns:
        Scalar mean of the weighted cross-entropy.
    """
    # Keep probabilities away from 0 and 1 so both logarithms stay finite
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)

    positive_term = y_true * tf.math.log(y_pred)
    negative_term = (1 - y_true) * tf.math.log(1 - y_pred)
    bce = -(positive_term + negative_term)

    # Probability the model assigned to the correct class of each element
    p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
    modulation = tf.pow(1 - p_t, gamma)

    return tf.reduce_mean(alpha * modulation * bce)

def combined_focal_dice(y_true, y_pred):
    """
    Training loss: unweighted sum of focal loss and Dice loss.

    NOTE(review): gamma=0.2 is passed here while focal_loss defaults to
    gamma=2.0 -- confirm the smaller exponent is intentional and not a typo.
    """
    focal_term = focal_loss(y_true, y_pred, alpha=0.25, gamma=0.2)
    dice_term = dice_loss(y_true, y_pred)
    return focal_term + dice_term

# Compile: optimise the combined focal + Dice loss and report Dice as a metric
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=combined_focal_dice,
    metrics=[dice_coef]
)

# summary() prints the table itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the output.
model.summary()

# Callbacks
# All three watch the validation Dice score, where larger is better (mode='max').

# Stop after 15 epochs without improvement and roll back to the best weights
early_stopping = EarlyStopping(
    monitor='val_dice_coef',
    patience=15,
    mode='max',
    restore_best_weights=True,
)

# Halve the learning rate after 5 epochs without improvement
lr_schedule = ReduceLROnPlateau(
    monitor='val_dice_coef',
    factor=0.5,
    patience=5,
    mode='max',
    verbose=1,
)

# Keep only the best-scoring model on disk
checkpoint = ModelCheckpoint(
    'best_model_3d.h5',
    monitor='val_dice_coef',
    mode='max',
    save_best_only=True,
)

callbacks = [early_stopping, lr_schedule, checkpoint]

# Train
# The generators are endless, so epoch length must be given explicitly: one
# step per volume in each split (the generators were built with batch_size=1).
# Volumes the generators skip are not subtracted from these counts.
train_steps = len(train_volume_dict)
val_steps = len(val_volume_dict)

history = model.fit(
    train_gen,
    epochs=100,
    steps_per_epoch=train_steps,
    validation_data=val_gen,
    validation_steps=val_steps,
    callbacks=callbacks,
    verbose=1,
)

None
Epoch 1/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7s/step - dice_coef: 0.0209 - loss: 1.0652



[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2109s[0m 7s/step - dice_coef: 0.0541 - loss: 1.0013 - val_dice_coef: 0.2397 - val_loss: 0.7750 - learning_rate: 1.0000e-04
Epoch 2/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7s/step - dice_coef: 0.2695 - loss: 0.7509



[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2112s[0m 7s/step - dice_coef: 0.3065 - loss: 0.7183 - val_dice_coef: 0.4025 - val_loss: 0.6189 - learning_rate: 1.0000e-04
Epoch 3/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6s/step - dice_coef: 0.3477 - loss: 0.6789



[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1905s[0m 6s/step - dice_coef: 0.3482 - loss: 0.6783 - val_dice_coef: 0.4214 - val_loss: 0.5984 - learning_rate: 1.0000e-04
Epoch 4/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7s/step - dice_coef: 0.3666 - loss: 0.6583



[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1969s[0m 7s/step - dice_coef: 0.3686 - loss: 0.6552 - val_dice_coef: 0.4280 - val_loss: 0.5916 - learning_rate: 1.0000e-04
Epoch 5/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8s/step - dice_coef: 0.3833 - loss: 0.6402



[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2465s[0m 8s/step - dice_coef: 0.4031 - loss: 0.6207 - val_dice_coef: 0.4399 - val_loss: 0.5825 - learning_rate: 1.0000e-04
Epoch 6/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2503s[0m 8s/step - dice_coef: 0.4180 - loss: 0.6056 - val_dice_coef: 0.3279 - val_loss: 0.7016 - learning_rate: 1.0000e-04
Epoch 7/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7s/step - dice_coef: 0.4235 - loss: 0.6005



[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2105s[0m 7s/step - dice_coef: 0.4254 - loss: 0.5975 - val_dice_coef: 0.4489 - val_loss: 0.5706 - learning_rate: 1.0000e-04
Epoch 8/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8s/step - dice_coef: 0.4354 - loss: 0.5862



[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2324s[0m 8s/step - dice_coef: 0.4371 - loss: 0.5845 - val_dice_coef: 0.4633 - val_loss: 0.5573 - learning_rate: 1.0000e-04
Epoch 9/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8s/step - dice_coef: 0.4520 - loss: 0.5677



[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2362s[0m 8s/step - dice_coef: 0.4624 - loss: 0.5566 - val_dice_coef: 0.4844 - val_loss: 0.5330 - learning_rate: 1.0000e-04
Epoch 10/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7s/step - dice_coef: 0.4733 - loss: 0.5454



[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2046s[0m 7s/step - dice_coef: 0.4814 - loss: 0.5361 - val_dice_coef: 0.5275 - val_loss: 0.4891 - learning_rate: 1.0000e-04
Epoch 11/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2205s[0m 7s/step - dice_coef: 0.5014 - loss: 0.5152 - val_dice_coef: 0.5059 - val_loss: 0.5103 - learning_rate: 1.0000e-04
Epoch 12/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8s/step - dice_coef: 0.5159 - loss: 0.5015



[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2456s[0m 8s/step - dice_coef: 0.5185 - loss: 0.4999 - val_dice_coef: 0.5552 - val_loss: 0.4609 - learning_rate: 1.0000e-04
Epoch 13/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8s/step - dice_coef: 0.5408 - loss: 0.4777



[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2264s[0m 8s/step - dice_coef: 0.5330 - loss: 0.4850 - val_dice_coef: 0.5680 - val_loss: 0.4489 - learning_rate: 1.0000e-04
Epoch 14/100
[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 167s/step - dice_coef: 0.5581 - loss: 0.4598  



[1m295/295[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m49111s[0m 167s/step - dice_coef: 0.5445 - loss: 0.4738 - val_dice_coef: 0.5701 - val_loss: 0.4476 - learning_rate: 1.0000e-04
Epoch 15/100
[1m168/295[0m [32m━━━━━━━━━━━[0m[37m━━━━━━━━━[0m [1m56:03[0m 26s/step - dice_coef: 0.5577 - loss: 0.4606