In [1]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import glob

def data_generator(file_list, batch_size=4):
    """
    Generator that loads and yields batches on-the-fly
    Avoids loading all data into RAM
    """
    while True:
        # Shuffle files each epoch
        np.random.shuffle(file_list)
        
        for i in range(0, len(file_list), batch_size):
            batch_files = file_list[i:i+batch_size]
            images = []
            masks = []
            
            for file_path in batch_files:
                try:
                    with h5py.File(file_path, 'r') as f:
                        image = f['image'][:]
                        mask = f['mask'][:]
                        
                        # Only keep slices with substantial tumor
                        if np.sum(mask > 0) > 100:
                            # Normalize per-slice
                            image = (image - image.min()) / (image.max() - image.min() + 1e-8)
                            mask = (mask > 0).astype(np.float32)
                            
                            images.append(image)
                            masks.append(mask)
                except Exception as e:
                    print(f"Error loading {file_path}: {e}")
                    continue
            
            # Only yield if we have data
            if len(images) > 0:
                yield np.array(images), np.array(masks)

# Get all files
data_dir = './data/BraTS/BraTS2020_training_data/content/data'
all_files = glob.glob(f'{data_dir}/*.h5')
print(f"Total files found: {len(all_files)}")

# Shuffle and split 80/20
np.random.seed(42)
np.random.shuffle(all_files)
split_idx = int(0.8 * len(all_files))

train_files = all_files[:split_idx]
val_files = all_files[split_idx:]

print(f"Train files: {len(train_files)}")
print(f"Val files: {len(val_files)}")

# Create generators
train_gen = data_generator(train_files, batch_size=4)
val_gen = data_generator(val_files, batch_size=4)

Total files found: 57195
Train files: 45756
Val files: 11439


In [None]:
# 1. Set seeds
import numpy as np
import tensorflow as tf
import random
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras import layers, models

np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)

# 2. Build model
def build_unet(input_shape=(240, 240, 4)):
    inputs = layers.Input(input_shape)
    
    # Encoder
    c1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(inputs)
    c1 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(c1)
    c1 = layers.Dropout(0.1)(c1)
    p1 = layers.MaxPooling2D((2, 2))(c1)
    
    c2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(p1)
    c2 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c2)
    c2 = layers.Dropout(0.1)(c2)
    p2 = layers.MaxPooling2D((2, 2))(c2)
    
    # Bottleneck
    c3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(p2)
    c3 = layers.Conv2D(128, (3, 3), activation='relu', padding='same')(c3)
    c3 = layers.Dropout(0.2)(c3)
    
    # Decoder
    u4 = layers.UpSampling2D((2, 2))(c3)
    u4 = layers.concatenate([u4, c2])
    c4 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(u4)
    c4 = layers.Conv2D(64, (3, 3), activation='relu', padding='same')(c4)
    c4 = layers.Dropout(0.1)(c4)
    
    u5 = layers.UpSampling2D((2, 2))(c4)
    u5 = layers.concatenate([u5, c1])
    c5 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(u5)
    c5 = layers.Conv2D(32, (3, 3), activation='relu', padding='same')(c5)
    c5 = layers.Dropout(0.1)(c5)
    
    outputs = layers.Conv2D(3, (1, 1), activation='sigmoid')(c5)
    
    return models.Model(inputs, outputs)

model = build_unet()

# 3. Define losses
def dice_coef(y_true, y_pred, smooth=1e-5):
    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1 - dice_coef(y_true, y_pred)

def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
    y_pred = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
    bce = -(y_true * tf.math.log(y_pred) + (1 - y_true) * tf.math.log(1 - y_pred))
    p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
    focal_weight = alpha * tf.pow(1 - p_t, gamma)
    return tf.reduce_mean(focal_weight * bce)

def combined_focal_dice(y_true, y_pred):
    return focal_loss(y_true, y_pred, alpha=0.25, gamma=0.2) + dice_loss(y_true, y_pred)

# 4. Compile
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=combined_focal_dice,
    metrics=[dice_coef]
)

# 5. Callbacks
callbacks = [
    EarlyStopping(monitor='val_dice_coef', patience=20, mode='max', restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_dice_coef', factor=0.5, patience=5, mode='max', verbose=1),
    ModelCheckpoint('best_model.h5', monitor='val_dice_coef', mode='max', save_best_only=True)
]

# 6. Train with generators
history = model.fit(
    train_gen,
    steps_per_epoch=len(train_files) // 4,
    validation_data=val_gen,
    validation_steps=len(val_files) // 4,
    epochs=150,
    callbacks=callbacks,
    verbose=1
)

Epoch 1/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 302ms/step - dice_coef: 0.5520 - loss: 0.4573



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3647s[0m 319ms/step - dice_coef: 0.6364 - loss: 0.3687 - val_dice_coef: 0.7051 - val_loss: 0.2978 - learning_rate: 1.0000e-04
Epoch 2/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 235ms/step - dice_coef: 0.7063 - loss: 0.2965



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2864s[0m 250ms/step - dice_coef: 0.7161 - loss: 0.2861 - val_dice_coef: 0.7197 - val_loss: 0.2810 - learning_rate: 1.0000e-04
Epoch 3/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 243ms/step - dice_coef: 0.7418 - loss: 0.2598



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2957s[0m 259ms/step - dice_coef: 0.7435 - loss: 0.2579 - val_dice_coef: 0.7611 - val_loss: 0.2382 - learning_rate: 1.0000e-04
Epoch 4/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 297ms/step - dice_coef: 0.7559 - loss: 0.2458



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3571s[0m 312ms/step - dice_coef: 0.7598 - loss: 0.2417 - val_dice_coef: 0.7629 - val_loss: 0.2350 - learning_rate: 1.0000e-04
Epoch 5/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 240ms/step - dice_coef: 0.7710 - loss: 0.2319



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2923s[0m 255ms/step - dice_coef: 0.7723 - loss: 0.2295 - val_dice_coef: 0.7839 - val_loss: 0.2145 - learning_rate: 1.0000e-04
Epoch 6/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3026s[0m 265ms/step - dice_coef: 0.7797 - loss: 0.2191 - val_dice_coef: 0.7816 - val_loss: 0.2219 - learning_rate: 1.0000e-04
Epoch 7/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 242ms/step - dice_coef: 0.7864 - loss: 0.2130



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2953s[0m 258ms/step - dice_coef: 0.7913 - loss: 0.2084 - val_dice_coef: 0.8060 - val_loss: 0.1947 - learning_rate: 1.0000e-04
Epoch 8/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3042s[0m 266ms/step - dice_coef: 0.7965 - loss: 0.2029 - val_dice_coef: 0.8031 - val_loss: 0.1969 - learning_rate: 1.0000e-04
Epoch 9/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 247ms/step - dice_coef: 0.8014 - loss: 0.1979



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3003s[0m 263ms/step - dice_coef: 0.8040 - loss: 0.1953 - val_dice_coef: 0.8068 - val_loss: 0.1913 - learning_rate: 1.0000e-04
Epoch 10/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 243ms/step - dice_coef: 0.8056 - loss: 0.1944



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2953s[0m 258ms/step - dice_coef: 0.8089 - loss: 0.1908 - val_dice_coef: 0.8152 - val_loss: 0.1852 - learning_rate: 1.0000e-04
Epoch 11/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2801s[0m 245ms/step - dice_coef: 0.8174 - loss: 0.1828 - val_dice_coef: 0.8124 - val_loss: 0.1862 - learning_rate: 1.0000e-04
Epoch 12/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 230ms/step - dice_coef: 0.8216 - loss: 0.1784



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2803s[0m 245ms/step - dice_coef: 0.8212 - loss: 0.1782 - val_dice_coef: 0.8211 - val_loss: 0.1781 - learning_rate: 1.0000e-04
Epoch 13/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 564ms/step - dice_coef: 0.8262 - loss: 0.1742



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6624s[0m 579ms/step - dice_coef: 0.8249 - loss: 0.1754 - val_dice_coef: 0.8274 - val_loss: 0.1723 - learning_rate: 1.0000e-04
Epoch 14/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2937s[0m 257ms/step - dice_coef: 0.8290 - loss: 0.1710 - val_dice_coef: 0.8248 - val_loss: 0.1739 - learning_rate: 1.0000e-04
Epoch 15/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 232ms/step - dice_coef: 0.8320 - loss: 0.1671



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2825s[0m 247ms/step - dice_coef: 0.8342 - loss: 0.1659 - val_dice_coef: 0.8346 - val_loss: 0.1650 - learning_rate: 1.0000e-04
Epoch 16/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2805s[0m 245ms/step - dice_coef: 0.8363 - loss: 0.1637 - val_dice_coef: 0.8282 - val_loss: 0.1702 - learning_rate: 1.0000e-04
Epoch 17/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2808s[0m 245ms/step - dice_coef: 0.8384 - loss: 0.1613 - val_dice_coef: 0.8305 - val_loss: 0.1681 - learning_rate: 1.0000e-04
Epoch 18/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2795s[0m 244ms/step - dice_coef: 0.8425 - loss: 0.1576 - val_dice_coef: 0.8341 - val_loss: 0.1643 - learning_rate: 1.0000e-04
Epoch 19/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 230ms/step - dice_coef: 0.8440 - loss: 0.1557



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2808s[0m 245ms/step - dice_coef: 0.8447 - loss: 0.1556 - val_dice_coef: 0.8421 - val_loss: 0.1578 - learning_rate: 1.0000e-04
Epoch 20/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2820s[0m 247ms/step - dice_coef: 0.8480 - loss: 0.1521 - val_dice_coef: 0.8406 - val_loss: 0.1592 - learning_rate: 1.0000e-04
Epoch 21/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2796s[0m 244ms/step - dice_coef: 0.8492 - loss: 0.1511 - val_dice_coef: 0.8406 - val_loss: 0.1563 - learning_rate: 1.0000e-04
Epoch 22/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2821s[0m 247ms/step - dice_coef: 0.8511 - loss: 0.1487 - val_dice_coef: 0.8396 - val_loss: 0.1604 - learning_rate: 1.0000e-04
Epoch 23/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 232ms/step - dice_coef: 0.8532 - loss: 0.1468



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2835s[0m 248ms/step - dice_coef: 0.8540 - loss: 0.1463 - val_dice_coef: 0.8483 - val_loss: 0.1513 - learning_rate: 1.0000e-04
Epoch 24/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2802s[0m 245ms/step - dice_coef: 0.8566 - loss: 0.1437 - val_dice_coef: 0.8479 - val_loss: 0.1516 - learning_rate: 1.0000e-04
Epoch 25/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2805s[0m 245ms/step - dice_coef: 0.8572 - loss: 0.1429 - val_dice_coef: 0.8424 - val_loss: 0.1552 - learning_rate: 1.0000e-04
Epoch 26/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 231ms/step - dice_coef: 0.8580 - loss: 0.1424



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2814s[0m 246ms/step - dice_coef: 0.8588 - loss: 0.1413 - val_dice_coef: 0.8554 - val_loss: 0.1460 - learning_rate: 1.0000e-04
Epoch 27/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2788s[0m 244ms/step - dice_coef: 0.8611 - loss: 0.1390 - val_dice_coef: 0.8473 - val_loss: 0.1521 - learning_rate: 1.0000e-04
Epoch 28/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2739s[0m 239ms/step - dice_coef: 0.8629 - loss: 0.1378 - val_dice_coef: 0.8512 - val_loss: 0.1478 - learning_rate: 1.0000e-04
Epoch 29/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2738s[0m 239ms/step - dice_coef: 0.8637 - loss: 0.1370 - val_dice_coef: 0.8541 - val_loss: 0.1444 - learning_rate: 1.0000e-04
Epoch 30/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 231ms/step - dice_coef: 0.8661 - loss: 0.1351



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2814s[0m 246ms/step - dice_coef: 0.8651 - loss: 0.1355 - val_dice_coef: 0.8582 - val_loss: 0.1404 - learning_rate: 1.0000e-04
Epoch 31/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2808s[0m 245ms/step - dice_coef: 0.8658 - loss: 0.1345 - val_dice_coef: 0.8573 - val_loss: 0.1432 - learning_rate: 1.0000e-04
Epoch 32/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2844s[0m 249ms/step - dice_coef: 0.8678 - loss: 0.1330 - val_dice_coef: 0.8563 - val_loss: 0.1433 - learning_rate: 1.0000e-04
Epoch 33/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 232ms/step - dice_coef: 0.8702 - loss: 0.1308



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2841s[0m 248ms/step - dice_coef: 0.8690 - loss: 0.1319 - val_dice_coef: 0.8598 - val_loss: 0.1402 - learning_rate: 1.0000e-04
Epoch 34/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9088s[0m 794ms/step - dice_coef: 0.8707 - loss: 0.1297 - val_dice_coef: 0.8556 - val_loss: 0.1425 - learning_rate: 1.0000e-04
Epoch 35/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 240ms/step - dice_coef: 0.8681 - loss: 0.1319



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2921s[0m 255ms/step - dice_coef: 0.8691 - loss: 0.1310 - val_dice_coef: 0.8602 - val_loss: 0.1388 - learning_rate: 1.0000e-04
Epoch 36/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3001s[0m 262ms/step - dice_coef: 0.8711 - loss: 0.1292 - val_dice_coef: 0.8596 - val_loss: 0.1425 - learning_rate: 1.0000e-04
Epoch 37/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 237ms/step - dice_coef: 0.8730 - loss: 0.1276



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2891s[0m 253ms/step - dice_coef: 0.8730 - loss: 0.1276 - val_dice_coef: 0.8605 - val_loss: 0.1381 - learning_rate: 1.0000e-04
Epoch 38/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 235ms/step - dice_coef: 0.8736 - loss: 0.1274



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2860s[0m 250ms/step - dice_coef: 0.8742 - loss: 0.1268 - val_dice_coef: 0.8635 - val_loss: 0.1365 - learning_rate: 1.0000e-04
Epoch 39/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 234ms/step - dice_coef: 0.8756 - loss: 0.1257



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2848s[0m 249ms/step - dice_coef: 0.8756 - loss: 0.1254 - val_dice_coef: 0.8653 - val_loss: 0.1344 - learning_rate: 1.0000e-04
Epoch 40/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 234ms/step - dice_coef: 0.8750 - loss: 0.1260



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2847s[0m 249ms/step - dice_coef: 0.8757 - loss: 0.1252 - val_dice_coef: 0.8658 - val_loss: 0.1345 - learning_rate: 1.0000e-04
Epoch 41/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 233ms/step - dice_coef: 0.8769 - loss: 0.1241



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2843s[0m 249ms/step - dice_coef: 0.8762 - loss: 0.1246 - val_dice_coef: 0.8667 - val_loss: 0.1336 - learning_rate: 1.0000e-04
Epoch 42/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2807s[0m 245ms/step - dice_coef: 0.8758 - loss: 0.1246 - val_dice_coef: 0.8652 - val_loss: 0.1343 - learning_rate: 1.0000e-04
Epoch 43/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2940s[0m 257ms/step - dice_coef: 0.8785 - loss: 0.1223 - val_dice_coef: 0.8551 - val_loss: 0.1447 - learning_rate: 1.0000e-04
Epoch 44/150
[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 239ms/step - dice_coef: 0.8800 - loss: 0.1215



[1m11439/11439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2909s[0m 254ms/step - dice_coef: 0.8792 - loss: 0.1219 - val_dice_coef: 0.8694 - val_loss: 0.1306 - learning_rate: 1.0000e-04
Epoch 45/150
[1m  101/11439[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m43:17[0m 229ms/step - dice_coef: 0.8859 - loss: 0.1171

In [None]:
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.legend()
plt.title('Loss')
plt.subplot(1, 2, 2)
plt.plot(history.history['dice_coef'], label='Train Dice')
plt.plot(history.history['val_dice_coef'], label='Val Dice')
plt.legend()
plt.title('Dice Coefficient')
plt.show()

In [None]:
# Predict on validation set
predictions = model.predict(X_val)

# Visualize
fig, axes = plt.subplots(3, 3, figsize=(12, 12))

for i in range(3):
    axes[i, 0].imshow(X_val[i, :, :, 0], cmap='gray')
    axes[i, 0].set_title('MRI Input')
    
    axes[i, 1].imshow(y_val[i, :, :, 0], cmap='Reds')
    axes[i, 1].set_title('Ground Truth')
    
    axes[i, 2].imshow(predictions[i, :, :, 0] > 0.5, cmap='Reds')
    axes[i, 2].set_title('Prediction')

plt.tight_layout()
plt.show()