In [22]:
!pip install opendatasets



In [24]:
import opendatasets as od
od.download("https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation/data")

Skipping, found downloaded files in "./lgg-mri-segmentation" (use force=True to force download)


In [47]:
# ============================================================
# LIBRARIES
# ============================================================

import os
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("ggplot")
%matplotlib inline

import cv2
from glob import glob
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint

# ============================================================
# PARAMETERS
# ============================================================

im_width = 256
im_height = 256

mask_files = glob('/content/lgg-mri-segmentation/kaggle_3m/*/*_mask*')
train_files = [i.replace('_mask','') for i in mask_files]

df = pd.DataFrame({"filename": train_files, "mask": mask_files})

def positiv_negativ_diagnosis(mask_path):
    value = np.max(cv2.imread(mask_path))
    return 1 if value > 0 else 0

df["diagnosis"] = df["mask"].apply(positiv_negativ_diagnosis)

df_train, df_test = train_test_split(df, test_size=0.1, random_state=42)
df_train, df_val = train_test_split(df_train, test_size=0.2, random_state=42)

# ============================================================
# DATA GENERATOR
# ============================================================

def adjust_data(img, mask):
    img = img / 255.
    mask = mask / 255.
    mask[mask > 0.5] = 1
    mask[mask <= 0.5] = 0
    return img, mask

def train_generator(df, batch_size=16, aug_dict={}):
    # Oversample positive masks
    pos_df = df[df.diagnosis==1]
    neg_df = df[df.diagnosis==0]
    df_balanced = pd.concat([neg_df, pos_df.sample(len(neg_df), replace=True)])

    image_datagen = ImageDataGenerator(**aug_dict)
    mask_datagen = ImageDataGenerator(**aug_dict)

    image_gen = image_datagen.flow_from_dataframe(
        df_balanced, x_col="filename", class_mode=None,
        color_mode="rgb", target_size=(im_height,im_width),
        batch_size=batch_size, seed=42
    )

    mask_gen = mask_datagen.flow_from_dataframe(
        df_balanced, x_col="mask", class_mode=None,
        color_mode="grayscale", target_size=(im_height,im_width),
        batch_size=batch_size, seed=42
    )

    for img, mask in zip(image_gen, mask_gen):
        img, mask = adjust_data(img, mask)
        yield img, mask

val_gen = train_generator(df_val, batch_size=16, aug_dict={})
test_gen = train_generator(df_test, batch_size=16, aug_dict={})

# ============================================================
# METRICS AND LOSS
# ============================================================

def dice_coef(y_true, y_pred, smooth=1):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.*intersection + smooth)/(K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    return 1 - dice_coef(y_true, y_pred)

def bce_dice_loss(y_true, y_pred):
    bce = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)
    return bce + dice_loss(y_true, y_pred)

def iou(y_true, y_pred, smooth=1):
    intersection = K.sum(y_true * y_pred)
    union = K.sum(y_true + y_pred) - intersection
    return (intersection + smooth)/(union + smooth)

# ============================================================
# MODEL ARCHITECTURE (~4M PARAMS)
# ============================================================

def dynamic_selector(x, reduction=8):
    channels = x.shape[-1]
    se = GlobalAveragePooling2D()(x)
    se = Dense(channels//reduction, activation='relu')(se)
    se = Dense(channels, activation='sigmoid')(se)
    se = Reshape((1,1,channels))(se)
    return Multiply()([x, se])

def nas_decoder_block(x, skip, filters):
    c1 = Conv2D(filters,3,padding="same",activation="relu")(x)
    c2 = SeparableConv2D(filters,3,padding="same",activation="relu")(x)
    concat = Concatenate()([c1,c2])

    w = GlobalAveragePooling2D()(concat)
    w = Dense(2, activation="softmax")(w)
    w = Reshape((1,1,2))(w)

    f1, f2 = Lambda(lambda t: tf.split(t,2,axis=-1))(concat)
    selected = w[...,0:1]*f1 + w[...,1:2]*f2

    skip = dynamic_selector(skip)
    return Concatenate()([selected, skip])

def enc_block(x, filters):
    x = Conv2D(filters,3,padding="same",activation="relu")(x)
    x = Conv2D(filters,3,padding="same",activation="relu")(x)
    skip = x
    x = MaxPooling2D()(x)
    return x, skip

def EffiDec3D_Medium(input_shape=(256,256,3), filters=[32,64,128,256]):
    inputs = Input(input_shape)
    x, s1 = enc_block(inputs, filters[0])
    x, s2 = enc_block(x, filters[1])
    x, s3 = enc_block(x, filters[2])

    x = Conv2D(filters[3],3,padding="same",activation="relu")(x)

    x = UpSampling2D()(x)
    x = nas_decoder_block(x, s3, filters[2])
    x = UpSampling2D()(x)
    x = nas_decoder_block(x, s2, filters[1])
    x = UpSampling2D()(x)
    x = nas_decoder_block(x, s1, filters[0])

    outputs = Conv2D(1,1,activation="sigmoid")(x)
    return Model(inputs, outputs, name="EffiDec3D_Medium")

# ============================================================
# TRAINING
# ============================================================

model = EffiDec3D_Medium((256,256,3))
model.compile(
    optimizer=Adam(3e-4),
    loss=bce_dice_loss,
    metrics=["binary_accuracy", dice_coef, iou]
)

callbacks = [ModelCheckpoint("EffiDec3D_Medium_best.keras", save_best_only=True)]

train_gen = train_generator(df_train, batch_size=16, aug_dict={
    "rotation_range":0.2,
    "width_shift_range":0.05,
    "height_shift_range":0.05,
    "zoom_range":0.05,
    "horizontal_flip":True
})

history = model.fit(
    train_gen,
    steps_per_epoch=len(df_train)//16,
    validation_data=val_gen,
    validation_steps=len(df_val)//16,
    epochs=50,
    callbacks=callbacks
)


Found 3710 validated image filenames.
Found 3710 validated image filenames.
Epoch 1/50
[1m176/176[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 485ms/step - binary_accuracy: 0.9814 - dice_coef: 0.1250 - iou: 0.0763 - loss: 1.1386Found 896 validated image filenames.
Found 896 validated image filenames.
[1m176/176[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m154s[0m 572ms/step - binary_accuracy: 0.9814 - dice_coef: 0.1259 - iou: 0.0769 - loss: 1.1369 - val_binary_accuracy: 0.9890 - val_dice_coef: 0.4761 - val_iou: 0.3296 - val_loss: 0.5806
Epoch 2/50
[1m176/176[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m129s[0m 733ms/step - binary_accuracy: 0.9870 - dice_coef: 0.5092 - iou: 0.3556 - loss: 0.5529 - val_binary_accuracy: 0.9892 - val_dice_coef: 0.6340 - val_iou: 0.4749 - val_loss: 0.4200
Epoch 3/50
[1m176/176[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 474ms/step - binary_accuracy: 0.9874 - dice_coef: 0.5530 - iou: 0.3979 - loss: 0.5124 - val_binary_accuracy: 

In [48]:

# ============================================================
# MODEL PARAMETERS, GFLOPs & MEMORY
# ============================================================

# Parameters
total_params = model.count_params()
print(f"Total Parameters: {total_params/1e6:.2f}M")

# Memory usage
def memory_usage(model, batch=1):
    weights_bytes = model.count_params() * 4
    weights_mb = weights_bytes / (1024**2)

    activation_bytes = 0
    for layer in model.layers:
        try:
            out_shape = layer.output_shape
        except:
            continue
        if isinstance(out_shape, list):
            continue
        if None in out_shape:
            continue
        activation_bytes += np.prod(out_shape)*4
    activation_mb = activation_bytes / (1024**2)
    return weights_mb, activation_mb, weights_mb + activation_mb

weights_mb, activations_mb, total_mb = memory_usage(model)
print(f"Weight Memory: {weights_mb:.2f} MB")
print(f"Total Estimated Memory: {total_mb:.2f} MB")

# GFLOPs calculation
def get_gflops(model, input_shape=(1,256,256,3)):
    inputs = tf.random.uniform(input_shape)
    @tf.function
    def forward(x): return model(x)
    concrete_func = forward.get_concrete_function(inputs)

    from tensorflow.python.profiler.model_analyzer import profile
    from tensorflow.python.profiler.option_builder import ProfileOptionBuilder

    profiler_options = ProfileOptionBuilder.float_operation()
    graph_info = profile(concrete_func.graph, options=profiler_options)

    flops = graph_info.total_float_ops
    return flops/1e9

gflops = get_gflops(model)
print(f"GFLOPs: {gflops:.3f}")


Total Parameters: 1.13M
Weight Memory: 4.30 MB
Total Estimated Memory: 4.30 MB
GFLOPs: 19.269
