In [7]:
!pip install nibabel opencv-python tensorflow
import os
import cv2
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from tensorflow.keras.utils import Sequence


Defaulting to user installation because normal site-packages is not writeable


In [8]:
IMG_SIZE = 128

def load_volume(path):
    return nib.load(path).get_fdata()

def preprocess_slice(img, mask):
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    mask = cv2.resize(mask, (IMG_SIZE, IMG_SIZE))
    
    img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-6)
    mask = (mask > 0).astype(np.float32)
    
    return img, mask


In [9]:
class BraTSDataset(Sequence):
    def __init__(self, image_paths, mask_paths, batch_size=8):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.batch_size = batch_size

    def __len__(self):
        return len(self.image_paths) // self.batch_size

    def __getitem__(self, idx):
        imgs, masks = [], []
        
        for i in range(self.batch_size):
            vol = load_volume(self.image_paths[idx*self.batch_size + i])
            seg = load_volume(self.mask_paths[idx*self.batch_size + i])

            slice_idx = vol.shape[2] // 2
            img, mask = preprocess_slice(vol[:,:,slice_idx], seg[:,:,slice_idx])
            
            imgs.append(img[...,None])
            masks.append(mask[...,None])

        return np.array(imgs), np.array(masks)


In [10]:
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

def unet():
    inputs = Input((IMG_SIZE, IMG_SIZE, 1))

    c1 = Conv2D(32,3,activation='relu',padding='same')(inputs)
    c1 = Conv2D(32,3,activation='relu',padding='same')(c1)
    p1 = MaxPooling2D()(c1)

    c2 = Conv2D(64,3,activation='relu',padding='same')(p1)
    c2 = Conv2D(64,3,activation='relu',padding='same')(c2)
    p2 = MaxPooling2D()(c2)

    c3 = Conv2D(128,3,activation='relu',padding='same')(p2)
    c3 = Conv2D(128,3,activation='relu',padding='same')(c3)

    u1 = UpSampling2D()(c3)
    u1 = Concatenate()([u1, c2])
    c4 = Conv2D(64,3,activation='relu',padding='same')(u1)

    u2 = UpSampling2D()(c4)
    u2 = Concatenate()([u2, c1])
    c5 = Conv2D(32,3,activation='relu',padding='same')(u2)

    outputs = Conv2D(1,1,activation='sigmoid')(c5)

    return Model(inputs, outputs)


In [11]:
import tensorflow as tf

def dice_coef(y_true, y_pred):
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    return (2.*intersection + 1) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + 1)

def iou(y_true, y_pred):
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true + y_pred) - intersection
    return (intersection + 1) / (union + 1)


In [13]:
import kagglehub

DATASET_PATH = kagglehub.dataset_download(
    "dschettler8845/brats-2021-task1"
)
image_paths = []
mask_paths = []

for patient in os.listdir(DATASET_PATH):
    patient_dir = os.path.join(DATASET_PATH, patient)

    if os.path.isdir(patient_dir):
        files = os.listdir(patient_dir)

        flair = [f for f in files if f.endswith("_flair.nii.gz")][0]
        seg   = [f for f in files if f.endswith("_seg.nii.gz")][0]

        image_paths.append(os.path.join(patient_dir, flair))
        mask_paths.append(os.path.join(patient_dir, seg))

from sklearn.model_selection import train_test_split

train_imgs, val_imgs, train_masks, val_masks = train_test_split(
    image_paths,
    mask_paths,
    test_size=0.2,
    random_state=42
)

train_dataset = BraTSDataset(train_imgs, train_masks, batch_size=4)
val_dataset   = BraTSDataset(val_imgs, val_masks, batch_size=4)



ValueError: With n_samples=0, test_size=0.2 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.

In [None]:
model = unet()
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[dice_coef, iou]
)

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10
)


In [None]:
imgs, masks = val_dataset[0]
preds = model.predict(imgs)

plt.figure(figsize=(12,4))
for i in range(3):
    plt.subplot(3,3,3*i+1)
    plt.imshow(imgs[i].squeeze(), cmap='gray')
    plt.title("MRI")

    plt.subplot(3,3,3*i+2)
    plt.imshow(masks[i].squeeze(), cmap='gray')
    plt.title("Ground Truth")

    plt.subplot(3,3,3*i+3)
    plt.imshow(preds[i].squeeze()>0.5, cmap='gray')
    plt.title("Prediction")
plt.tight_layout()
plt.show()
