# Problem definition
**Segmentation of gliomas in pre-operative MRI scans.**

*Each pixel (voxel) of the image must be labeled:*
* Tumor: the pixel is part of a tumor area (1 or 2 or 3) and belongs to one of multiple classes / sub-regions
* Not tumor: anything else (0)

The tumor sub-regions considered for evaluation are: 1) the "enhancing tumor" (ET), 2) the "tumor core" (TC), and 3) the "whole tumor" (WT).

The provided segmentation labels have values of 1 for NCR & NET (necrotic and non-enhancing tumor core), 2 for ED (peritumoral edema), 4 for ET (enhancing tumor), and 0 for everything else.
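
The relationship between the label values and the three evaluated sub-regions can be sketched in a few lines. This is a minimal illustration in plain Python, assuming the usual BraTS convention that the tumor core is NCR/NET plus ET and the whole tumor additionally includes ED; `REGION_LABELS` and `region_mask` are hypothetical names, not part of the notebook.

```python
# BraTS label values as stated above: 1 = NCR/NET, 2 = ED, 4 = ET, 0 = background.
# Assumption: TC = NCR/NET + ET and WT = TC + ED (the usual BraTS convention).
REGION_LABELS = {
    "ET": {4},        # enhancing tumor
    "TC": {1, 4},     # tumor core
    "WT": {1, 2, 4},  # whole tumor
}


def region_mask(labels, region):
    """Return 0/1 flags marking which voxels belong to the given sub-region."""
    wanted = REGION_LABELS[region]
    return [1 if value in wanted else 0 for value in labels]


voxels = [0, 1, 2, 4, 0, 2]  # toy flattened label map
et_mask = region_mask(voxels, "ET")
tc_mask = region_mask(voxels, "TC")
wt_mask = region_mask(voxels, "WT")
print(et_mask, tc_mask, wt_mask)
```

On real volumes the same membership test would normally be vectorised over the whole label array rather than looped over a list.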



# ![Brats official annotations](https://www.med.upenn.edu/cbica/assets/user-content/images/BraTS/brats-tumor-subregions.jpg)

# Setup env

# Multimodal Fusion with a Convolutional Variational AutoEncoder on the BraTS20 dataset
### Modalities used: 4 (FLAIR, T1, T1ce, T2)

# Calculation of Parameters for Each Layer in the Encoder

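The counts below break down each layer of an encoder with three convolutional blocks (64, 128, and 256 filters), as reported by a model summary. The `create_encoder` function defined later in this notebook adds a second 128-filter block and Dropout layers, so its summary will differ.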

## conv3d_20 (Conv3D): 6,976 parameters
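- Input shape: (240, 240, 48, 4), a 3D volume with 4 channels (the parameter count does not depend on the spatial size)
- Filter size: 3×3×3 with 64 output filters (the only cubic kernel size consistent with 6,976 parameters)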
- Parameters = (3×3×3×4 + 1) × 64 = (108 + 1) × 64 = 109 × 64 = 6,976
  - 3×3×3: Filter dimensions
  - 4: Input channels
  - +1: Bias for each filter
  - ×64: Number of filters
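
The formula above generalises to any Conv3D layer. The helper below is a small sketch for checking the counts by hand; `conv3d_params` is an illustrative name, not a Keras function.

```python
def conv3d_params(kernel_size, in_channels, filters, use_bias=True):
    """Parameter count of a 3D convolution: (k_d * k_h * k_w * C_in + bias) * filters."""
    k_d, k_h, k_w = kernel_size
    weights_per_filter = k_d * k_h * k_w * in_channels
    bias_per_filter = 1 if use_bias else 0
    return (weights_per_filter + bias_per_filter) * filters


first_conv = conv3d_params((3, 3, 3), in_channels=4, filters=64)
print(first_conv)  # 6976
```

The same call with `in_channels=64, filters=128` or `in_channels=128, filters=256` reproduces the counts of the two deeper convolutions.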

## max_pooling3d_15 (MaxPooling3D): 0 parameters
- Pooling operations don't have trainable parameters

## batch_normalization_30 (BatchNorm): 256 parameters
- Normalizes 64 feature maps
- Parameters = 64 × 4 = 256
  - 2 learned parameters per channel (scale and shift)
  - 2 running statistics per channel (mean and variance)
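
The split between learned parameters and running statistics matters later for the trainable / non-trainable totals. A minimal sketch of that bookkeeping, with `batchnorm_params` as an illustrative helper name:

```python
def batchnorm_params(channels):
    """Return (trainable, non_trainable) parameter counts for batch normalization."""
    trainable = 2 * channels      # scale (gamma) and shift (beta)
    non_trainable = 2 * channels  # running mean and running variance
    return trainable, non_trainable


trainable, non_trainable = batchnorm_params(64)
total = trainable + non_trainable
print(trainable, non_trainable, total)  # 128 128 256
```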

## conv3d_21 (Conv3D): 221,312 parameters
- Input: 64 channels from previous layer
- Filter size: 3×3×3 with 128 output filters
- Parameters = (3×3×3×64 + 1) × 128 = (1,728 + 1) × 128 = 1,729 × 128 = 221,312

## max_pooling3d_16 (MaxPooling3D): 0 parameters
- No trainable parameters

## batch_normalization_31 (BatchNorm): 512 parameters
- Normalizes 128 feature maps
- Parameters = 128 × 4 = 512

## conv3d_22 (Conv3D): 884,992 parameters
- Input: 128 channels
- Filter size: 3×3×3 with 256 output filters
- Parameters = (3×3×3×128 + 1) × 256 = (3,456 + 1) × 256 = 3,457 × 256 = 884,992

## max_pooling3d_17 (MaxPooling3D): 0 parameters
- No trainable parameters

## batch_normalization_32 (BatchNorm): 1,024 parameters
- Normalizes 256 feature maps
- Parameters = 256 × 4 = 1,024

## global_average_pooling3d_5 (GlobalAveragePooling3D): 0 parameters
- Pooling operations don't have trainable parameters

## z_mean (Dense): 65,792 parameters
- Input: 256 features from global pooling
- Output: 256 latent dimensions
- Parameters = 256 × 256 + 256 = 65,792
  - 256 × 256: Weight matrix
  - 256: Bias terms

## z_log_var (Dense): 65,792 parameters
- Input: 256 features
- Output: 256 latent dimensions
- Parameters = 256 × 256 + 256 = 65,792

## sampling_5 (Sampling): 0 parameters
- Custom layer that performs the reparameterization trick
- No trainable parameters
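
The reparameterization trick has no weights because it only combines the two Dense outputs with random noise. The sketch below mirrors the formula used by the notebook's `Sampling` layer, `z_mean + exp(0.5 * z_log_var) * epsilon`, in plain Python; `sample_latent` is a hypothetical helper written for illustration.

```python
import math
import random


def sample_latent(z_mean, z_log_var, rng=random):
    """Draw z = mean + exp(0.5 * log_var) * eps with eps ~ N(0, 1), element-wise."""
    return [
        mu + math.exp(0.5 * log_var) * rng.gauss(0.0, 1.0)
        for mu, log_var in zip(z_mean, z_log_var)
    ]


rng = random.Random(0)  # seeded generator so the sketch is repeatable
z = sample_latent([0.0, 1.0, -2.0], [0.0, -1.0, 0.5], rng)

# With a very negative log-variance the standard deviation is effectively zero,
# so the sample collapses onto the mean.
z_det = sample_latent([0.5, -0.5], [-200.0, -200.0], rng)
print(z, z_det)
```

Because the randomness enters only through `eps`, gradients can flow through the mean and log-variance, which is the point of the trick.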

## Total Parameters
- Sum of all parameters = 6,976 + 0 + 256 + 221,312 + 0 + 512 + 884,992 + 0 + 1,024 + 0 + 65,792 + 65,792 + 0 = 1,246,656
- Trainable: 1,245,760 (everything except the batch-norm running statistics)
- Non-trainable: 896 (running statistics in batch normalization layers)

This parameter distribution shows a typical CNN pattern: the three convolutional layers hold 1,113,280 of the 1,246,656 parameters (about 89%), and the count grows quickly as channel depth increases through the network.
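
The totals above can be re-derived with a short script. This is a sketch for checking the arithmetic of the three-block encoder described in this section, not a replacement for `model.summary()`; all helper names are illustrative.

```python
def conv3d_params(in_channels, filters, kernel=3):
    # (k^3 weights per input channel + 1 bias) per filter
    return (kernel ** 3 * in_channels + 1) * filters


def dense_params(inputs, units):
    # weight matrix plus one bias per unit
    return inputs * units + units


def batchnorm_params(channels):
    # gamma, beta, running mean, running variance
    return 4 * channels


conv = [conv3d_params(4, 64), conv3d_params(64, 128), conv3d_params(128, 256)]
bn = [batchnorm_params(c) for c in (64, 128, 256)]
dense = [dense_params(256, 256), dense_params(256, 256)]  # z_mean and z_log_var

total = sum(conv) + sum(bn) + sum(dense)
non_trainable = sum(bn) // 2  # only the running statistics
trainable = total - non_trainable
conv_share = sum(conv) / total

print(total, trainable, non_trainable, round(conv_share, 3))
```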

# Decoder Parameter Calculations

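The breakdown below covers a decoder whose Conv3DTranspose layers have 256, 128, 64, and 32 filters. The `create_decoder` function defined later in this notebook uses 256, 128, 128, and 64 filters, so its summary will differ.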

1. **Dense Layer**: `layers.Dense(15*15*6*256, activation='relu')(latent_inputs)`
   - Input: 256 (latent_dim)
   - Output: 15×15×6×256 = 345,600 neurons
   - Parameters = (256 × 345,600) + 345,600 = 88,819,200
     - 256 × 345,600: Weight matrix
     - 345,600: Bias terms
   - This is by far the largest parameter consumer in the entire model: about 97% of the 91,755,620 decoder parameters

2. **Reshape Layer**: No parameters, just reorganizes the data

3. **Conv3DTranspose (256 filters)**:
   - Input shape: (15, 15, 6, 256)
   - Filter size: 3×3×3 with 256 output filters
   - Parameters = (3×3×3×256 + 1) × 256 = (6,912 + 1) × 256 = 1,769,728

4. **BatchNormalization after first Conv3DTranspose**:
   - 256 feature maps × 4 parameters per map = 1,024 parameters
   - Trainable: 512 (gamma, beta)
   - Non-trainable: 512 (running mean, variance)

5. **Conv3DTranspose (128 filters)**:
   - Input: 256 channels
   - Output: 128 filters
   - Parameters = (3×3×3×256 + 1) × 128 = (6,912 + 1) × 128 = 884,864

6. **BatchNormalization after second Conv3DTranspose**:
   - 128 feature maps × 4 parameters = 512 parameters
   - Trainable: 256
   - Non-trainable: 256

7. **Conv3DTranspose (64 filters)**:
   - Input: 128 channels
   - Output: 64 filters
   - Parameters = (3×3×3×128 + 1) × 64 = (3,456 + 1) × 64 = 221,248

8. **BatchNormalization after third Conv3DTranspose**:
   - 64 feature maps × 4 parameters = 256 parameters
   - Trainable: 128
   - Non-trainable: 128

9. **Conv3DTranspose (32 filters)**:
   - Input: 64 channels
   - Output: 32 filters
   - Parameters = (3×3×3×64 + 1) × 32 = (1,728 + 1) × 32 = 55,328

10. **Final Conv3D (4 filters)**:
    - Input: 32 channels
    - Output: 4 channels
    - Parameters = (3×3×3×32 + 1) × 4 = (864 + 1) × 4 = 3,460

Total decoder parameters = 91,755,620 (as shown in the model summary)
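
As a cross-check, the per-layer counts listed above can be summed programmatically. The sketch below assumes the layer sequence described in this section; the layer labels are illustrative and do not correspond to Keras layer names.

```python
def conv_params(in_channels, filters, kernel=3):
    # Conv3D and Conv3DTranspose have the same count: (k^3 * C_in + 1) * filters
    return (kernel ** 3 * in_channels + 1) * filters


def dense_params(inputs, units):
    return inputs * units + units


def batchnorm_params(channels):
    return 4 * channels


decoder_layers = [
    ("dense", dense_params(256, 15 * 15 * 6 * 256)),
    ("conv3d_transpose_256", conv_params(256, 256)),
    ("batchnorm_256", batchnorm_params(256)),
    ("conv3d_transpose_128", conv_params(256, 128)),
    ("batchnorm_128", batchnorm_params(128)),
    ("conv3d_transpose_64", conv_params(128, 64)),
    ("batchnorm_64", batchnorm_params(64)),
    ("conv3d_transpose_32", conv_params(64, 32)),
    ("conv3d_out", conv_params(32, 4)),
]

decoder_total = sum(count for _, count in decoder_layers)
dense_share = decoder_layers[0][1] / decoder_total
print(decoder_total, round(dense_share, 3))
```

The Dense layer's share shows why shrinking the latent-to-volume projection is the most direct way to reduce the decoder's size.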

## Why Sigmoid Activation in the Output Layer?

The sigmoid activation function in the output layer serves several important purposes:

1. **Range Constraint**: Sigmoid outputs values between 0 and 1, which matches the range of the input data after min-max normalization. The preprocessing step normalized all MRI values to [0,1], so the output needs to produce values in the same range.

2. **Intensity Reconstruction**: In medical imaging, pixel/voxel intensities represent physical properties. The sigmoid function ensures that reconstructed intensities remain within a valid range without clipping.

3. **Probabilistic Interpretation**: In some contexts, output values can be interpreted as probabilities or confidence levels of voxel intensity, which the sigmoid naturally provides.

4. **Smooth Gradients**: Sigmoid is differentiable everywhere, so the reconstruction loss has a well-defined gradient at every output value. The gradient does shrink as outputs saturate near 0 or 1.

An alternative could have been tanh (outputs -1 to 1) with appropriate rescaling, but since the input data is normalized to [0,1], sigmoid is the natural choice that directly matches this range.
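
The "appropriate rescaling" for tanh can be made concrete: mapping tanh from [-1, 1] to [0, 1] with `0.5 * (tanh(x) + 1)` gives exactly `sigmoid(2x)`, so the two choices differ only by a scaling of the pre-activation. A small numerical check, with function names chosen for illustration:

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def tanh_to_unit(x):
    """Rescale tanh from [-1, 1] to [0, 1]."""
    return 0.5 * (math.tanh(x) + 1.0)


samples = [-5.0, -1.0, 0.0, 0.3, 2.0, 5.0]
max_gap = max(abs(tanh_to_unit(x) - sigmoid(2.0 * x)) for x in samples)
outputs = [sigmoid(x) for x in samples]
print(max_gap, outputs)
```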

The dtype specification of 'float32' ensures that even when using mixed precision training, the output maintains full 32-bit precision, which is important for accurate reconstruction of medical images where small intensity differences can be clinically significant.
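
To see what is at stake, the sketch below rounds a normalized intensity to half precision and to single precision using only the standard library's `struct` module. The value 0.5001 is an arbitrary example, and the helper names are illustrative.

```python
import struct


def round_to_half(value):
    """Round a Python float to IEEE 754 half precision (float16) and back."""
    return struct.unpack("e", struct.pack("e", value))[0]


def round_to_single(value):
    """Round a Python float to IEEE 754 single precision (float32) and back."""
    return struct.unpack("f", struct.pack("f", value))[0]


intensity = 0.5001
as_half = round_to_half(intensity)
as_single = round_to_single(intensity)
print(as_half, as_single)
```

In half precision the spacing between representable values just above 0.5 is about 0.0005, so a difference of 0.0001 is lost, while single precision keeps it.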

In [None]:
import os
import cv2
import glob
import PIL
import shutil
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from skimage import data
from skimage.util import montage 
import skimage.transform as skTrans
from skimage.transform import rotate
from skimage.transform import resize
from PIL import Image, ImageOps  


# neural imaging
import nilearn as nl
import nibabel as nib
import nilearn.plotting as nlplt
!pip install git+https://github.com/miykael/gif_your_nifti # nifti to gif 
import gif_your_nifti.core as gif2nif


# ml libs
import keras
import keras.backend as K
from keras.callbacks import CSVLogger
import tensorflow as tf
from tensorflow.keras.utils import plot_model
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
from tensorflow.keras.layers.experimental import preprocessing


# Make numpy printouts easier to read.
np.set_printoptions(precision=3, suppress=True)

In [None]:
TRAIN_DATASET_PATH = '../input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/'
VALIDATION_DATASET_PATH = '../input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData'

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Parameters
VOLUME_SLICES = 96
VOLUME_START_AT = 22
INPUT_SHAPE = (240, 240, VOLUME_SLICES, 4)
latent_dim = 256
batch_size = 1

# Sampling layer with explicit dtype handling
class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        
        # Ensure epsilon matches z_mean dtype
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim), dtype=z_mean.dtype)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

# Encoder with dtype policy awareness.
# padding='same' keeps each convolution's output at the same spatial size
# (height, width, depth) as its input; only the pooling layers downsample.
def create_encoder(input_shape):
    encoder_inputs = keras.Input(shape=input_shape)
    x = layers.Conv3D(64, 3, activation='relu', padding='same')(encoder_inputs)
    x = layers.MaxPooling3D((2, 2, 2))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.2)(x)
    
    x = layers.Conv3D(128, 3, activation='relu', padding='same')(x)
    x = layers.MaxPooling3D((2, 2, 2))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)

    x = layers.Conv3D(128, 3, activation='relu', padding='same')(x)
    x = layers.MaxPooling3D((2, 2, 2))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    
    x = layers.Conv3D(256, 3, activation='relu', padding='same')(x)
    x = layers.MaxPooling3D((2, 2, 2))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.4)(x)

    x = layers.GlobalAveragePooling3D()(x)
    # resnet_model = Sequential()

    # pretrained_model= tf.keras.applications.ResNet50(include_top=False,
    #                    input_shape=encoder_inputs,
    #                    pooling='avg',
    #                    weights='imagenet')
    # for layer in pretrained_model.layers:
    #         layer.trainable=False
    
    #     # Apply the pretrained model to the inputs
    # x = pretrained_model(inputs)
    
    # # Flatten the output
    # x = layers.Flatten()(x)
    # x = layers.Dense()
    
    
    
    # Ensure output layers match policy
    z_mean = layers.Dense(latent_dim, name='z_mean', dtype='float32')(x)
    z_log_var = layers.Dense(latent_dim, name='z_log_var', dtype='float32')(x)
    z = Sampling()([z_mean, z_log_var])
    
    return keras.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')

# Decoder with dtype policy awareness
def create_decoder(output_shape):
    latent_inputs = keras.Input(shape=(latent_dim,), dtype='float32')
    
    x = layers.Dense(15*15*6*256, activation='relu')(latent_inputs)
    x = layers.Reshape((15,15,6,256))(x)
    x = layers.Dropout(0.4)(x)
    
    x = layers.Conv3DTranspose(256, 3, strides=2, activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    
    x = layers.Conv3DTranspose(128, 3, strides=2, activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)

    x = layers.Conv3DTranspose(128, 3, strides=2, activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.2)(x)
    
    x = layers.Conv3DTranspose(64, 3, strides=2, activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)

    decoder_outputs = layers.Conv3D(output_shape[-1], 3, activation='sigmoid', padding='same', dtype='float32')(x)
    
    # Dynamic cropping with dtype awareness
    def crop_to_match(x):
        input_shape = tf.shape(x)
        crop_h = (input_shape[1] - output_shape[0]) // 2
        crop_w = (input_shape[2] - output_shape[1]) // 2
        crop_d = (input_shape[3] - output_shape[2]) // 2
        
        return x[:, 
               tf.maximum(crop_h, 0):tf.maximum(crop_h, 0)+output_shape[0],
               tf.maximum(crop_w, 0):tf.maximum(crop_w, 0)+output_shape[1],
               tf.maximum(crop_d, 0):tf.maximum(crop_d, 0)+output_shape[2],
               :]
    
    decoder_outputs = layers.Lambda(crop_to_match, dtype='float32')(decoder_outputs)
    
    return keras.Model(latent_inputs, decoder_outputs, name='decoder')

# CVAE model with mixed precision handling
class CVAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        
    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstruction = self.decoder(z)
        return reconstruction

    
    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
            
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            
            # Cast to float32 for loss calculation if using mixed precision
            data_f32 = tf.cast(data, tf.float32)
            reconstruction_f32 = tf.cast(reconstruction, tf.float32)
            
            reconstruction_loss = tf.reduce_mean(
                keras.losses.mean_squared_error(data_f32, reconstruction_f32)
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
            
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        
        return {m.name: m.result() for m in self.metrics}



In [None]:
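# Run this cell after training: `trained_cvae` is the model returned by train_cvae().
# The encoder and decoder are local to train_cvae, so access them through the model.
trained_cvae.build((None, *INPUT_SHAPE))
# Now summary will work
print("="*80)
print("ENCODER SUMMARY:")
trained_cvae.encoder.summary()

print("\n" + "="*80)
print("DECODER SUMMARY:")
trained_cvae.decoder.summary()

print("\n" + "="*80)
print("FULL VAE SUMMARY:")
trained_cvae.summary()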

In [None]:

# Training data loading
def load_brats_data(path, patient_id):
    modalities = []
    for mod in ['flair', 't1', 't1ce', 't2']:
        vol = nib.load(f"{path}/BraTS20_Training_{patient_id:03d}/BraTS20_Training_{patient_id:03d}_{mod}.nii").get_fdata()
        vol = (vol - vol.min()) / (vol.max() - vol.min())
        modalities.append(vol[..., VOLUME_START_AT:VOLUME_START_AT+VOLUME_SLICES])
    return np.stack(modalities, axis=-1)


def train_cvae(train_data):
    # Early stopping and dropout were added for better performance (19-04-25)

    if len(train_data) == 0:
        raise ValueError("No training data could be loaded. Please check the dataset path.")
    
    # Create data generators with batching
    # Shuffle individual samples before batching; shuffling after batch() only reorders whole batches
    train_dataset = tf.data.Dataset.from_tensor_slices(train_data).shuffle(buffer_size=100).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    test_dataset = tf.data.Dataset.from_tensor_slices(test_data).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    
    # Create and compile model
    encoder = create_encoder(INPUT_SHAPE)
    decoder = create_decoder(INPUT_SHAPE)
    cvae = CVAE(encoder, decoder)
    
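    # Configure mixed precision
    # Note: a policy only applies to layers created after it is set, so the encoder
    # and decoder built above keep the policy that was active when they were created.
    policy = tf.keras.mixed_precision.Policy('mixed_float16')
    tf.keras.mixed_precision.set_global_policy(policy)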
    
    # Optimizer with learning rate schedule
    lr_schedule = keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=1e-4,
        decay_steps=10000,
        decay_rate=0.9)
    
    optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)
    cvae.compile(optimizer=optimizer)
    
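    # Callbacks
    # CVAE defines only train_step, so no val_* metrics are logged;
    # monitor the training loss reported by train_step instead.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor='total_loss',
        patience=15,
        restore_best_weights=True,
        min_delta=0.001)
    
    # A subclassed model cannot be saved as a full HDF5 model, so checkpoint the weights only
    checkpoint = keras.callbacks.ModelCheckpoint(
        'best_cvae_model.h5',
        monitor='total_loss',
        save_best_only=True,
        save_weights_only=True)

    # Training loop (the callbacks were previously defined but never passed to fit)
    history = cvae.fit(
        train_data,
        epochs=60,
        batch_size=batch_size,
        shuffle=True,
        callbacks=[early_stopping, checkpoint]
    )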

    return cvae, history

In [None]:
train_data = np.array([load_brats_data(TRAIN_DATASET_PATH, i) for i in range(1, 31)])
len(train_data)
# Train-test split
train_data, test_data = train_test_split(train_data, test_size=0.2, random_state=42)
len(train_data),len(test_data)

In [None]:
trained_cvae, training_history = train_cvae(train_data)
# Save the final model
trained_cvae.save_weights('cvae_final_weights.h5')

In [None]:
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np

def generate_and_save_images(model, input_data, output_path, num_samples=3):
    """
    Generate reconstructed images from input data and save them as NIfTI files
    and 2D slices as PNG images.
    
    Args:
        model: Trained CVAE model
        input_data: Input MRI volumes (shape: [num_samples, 240, 240, VOLUME_SLICES, 4])
        output_path: Directory to save outputs
        num_samples: Number of samples to generate and visualize
    """
    # Ensure output directory exists
    os.makedirs(output_path, exist_ok=True)
    
    # Get reconstructions
    _, _, z = model.encoder.predict(input_data[:num_samples])
    reconstructions = model.decoder.predict(z)
    
    # Convert to numpy array and ensure float32
    reconstructions = np.array(reconstructions, dtype=np.float32)
    
    # Save each modality separately
    modalities = ['flair', 't1', 't1ce', 't2']
    
    for i in range(num_samples):
        # Save original and reconstructed as NIfTI
        for mod_idx, mod in enumerate(modalities):
            # Original
            orig_vol = nib.Nifti1Image(input_data[i,...,mod_idx], np.eye(4))
            nib.save(orig_vol, f"{output_path}/sample_{i}_original_{mod}.nii.gz")
            
            # Reconstructed
            recon_vol = nib.Nifti1Image(reconstructions[i,...,mod_idx], np.eye(4))
            nib.save(recon_vol, f"{output_path}/sample_{i}_reconstructed_{mod}.nii.gz")
        
        # Visualize slices
        visualize_slices(input_data[i], reconstructions[i], 
                         save_path=f"{output_path}/sample_{i}_comparison.png")

def visualize_slices(original, reconstructed, save_path=None, num_slices=3):
    """
    Visualize comparison between original and reconstructed slices.
    
    Args:
        original: Original volume (240, 240, VOLUME_SLICES, 4)
        reconstructed: Reconstructed volume (240, 240, VOLUME_SLICES, 4)
        save_path: Path to save the figure (if None, shows interactively)
        num_slices: Number of slices to display
    """
    modalities = ['FLAIR', 'T1', 'T1ce', 'T2']
    slice_indices = np.linspace(0, original.shape[2]-1, num_slices, dtype=int)
    
    plt.figure(figsize=(20, 6*num_slices))
    
    for i, slice_idx in enumerate(slice_indices):
        for mod_idx, mod in enumerate(modalities):
            # Original
            plt.subplot(num_slices, 8, i*8 + mod_idx*2 + 1)
            plt.imshow(original[..., slice_idx, mod_idx], cmap='gray')
            plt.title(f"Original {mod}\nSlice {slice_idx}")
            plt.axis('off')
            
            # Reconstructed
            plt.subplot(num_slices, 8, i*8 + mod_idx*2 + 2)
            plt.imshow(reconstructed[..., slice_idx, mod_idx], cmap='gray')
            plt.title(f"Reconstructed {mod}")
            plt.axis('off')
    
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

def visualize_latent_space(model, data, save_path=None):
    """
    Visualize the latent space distribution using both PCA and t-SNE.
    
    Args:
        model: Trained CVAE model
        data: Input data to encode
        save_path: Path to save the figure
    """
    from sklearn.decomposition import PCA
    from sklearn.manifold import TSNE
    
    # Get latent representations
    z_mean, _, _ = model.encoder.predict(data)
    
    # Reduce dimensionality
    pca = PCA(n_components=2)
    z_pca = pca.fit_transform(z_mean)
    
    tsne = TSNE(n_components=2, perplexity=min(30, z_mean.shape[0]-1))
    z_tsne = tsne.fit_transform(z_mean)
    
    # Plot
    plt.figure(figsize=(12, 6))
    
    plt.subplot(1, 2, 1)
    plt.scatter(z_pca[:, 0], z_pca[:, 1])
    plt.title("PCA of Latent Space")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    
    plt.subplot(1, 2, 2)
    plt.scatter(z_tsne[:, 0], z_tsne[:, 1])
    plt.title("t-SNE of Latent Space")
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

In [None]:
import os
import nibabel as nib
import numpy as np
from tqdm import tqdm

# Verify paths (Kaggle specific)
VALIDATION_DATASET_PATH = "/kaggle/input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData"
output_dir = "/kaggle/working/reconstruction_results"

# Create output directory
os.makedirs(output_dir, exist_ok=True)

# Find available validation patients
def find_validation_patients(base_path, start=1, end=10):
    valid_patients = []
    for i in range(start, end+1):
        path_3d = f"{base_path}/BraTS20_Validation_{i:03d}"
        path_flat = f"{base_path}/BraTS20_Validation_{i:03d}_flair.nii.gz"
        if os.path.exists(path_3d) or os.path.exists(path_flat):
            valid_patients.append(i)
    return valid_patients

available_patients = find_validation_patients(VALIDATION_DATASET_PATH)
print(f"Found {len(available_patients)} validation patients: {available_patients}")

# Load first 3 available patients

# test_data = []
# for i in available_patients[:3]:
#     try:
#         data = load_brats_Validationdata(VALIDATION_DATASET_PATH, i)
#         test_data.append(data)
#         print(f"Successfully loaded patient {i}")
#     except Exception as e:
#         print(f"Error loading patient {i}: {str(e)}")

# if len(test_data) > 0:
#     test_data = np.array(train_data)
#     print(f"Loaded test data shape: {train_data.shape}")
    
# Generate and save reconstructions
generate_and_save_images(trained_cvae, test_data, output_dir, num_samples=min(3, len(test_data)))

# Visualize latent space
visualize_latent_space(trained_cvae, test_data, save_path=f"{output_dir}/latent_space.png")


#     # Verify output files were created
#     print("\nGenerated files:")
#     !ls -lh {output_dir}
# else:
#     print("No validation data could be loaded. Please check the dataset path.")

Today is 21 April 2025. I am working on this project, but it runs into memory issues that I need to spend time resolving. I will resume work on 1 May 2025. <br>
## The first task is to use the ResNet50 network for the encoder and decoder. Since that model is trained on 2D data and I am working with 3D data, I have to make some changes to the architecture.
## The second task is to get rid of the memory issue while validating the model.

See the DeepSeek chat "Replacing Encoder and Decoder with ResNet50" for reference: [https://chat.deepseek.com/a/chat/s/6f76c9c7-a084-4a48-a4c7-2aba94783381](https://chat.deepseek.com/a/chat/s/6f76c9c7-a084-4a48-a4c7-2aba94783381)

Reference 2: [https://github.com/nachi-hebbar/Transfer-Learning-ResNet-Keras/blob/main/ResNet_50.ipynb](https://github.com/nachi-hebbar/Transfer-Learning-ResNet-Keras/blob/main/ResNet_50.ipynb)

In [None]:
def display_from_memory(test_data, reconstructions, sample_idx=0, slice_idx=24):
    """Display comparison using in-memory arrays"""
    modalities = ['FLAIR', 'T1', 'T1ce', 'T2']
    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    
    for mod_idx, mod in enumerate(modalities):
        # Original
        axes[0, mod_idx].imshow(test_data[sample_idx, :, :, slice_idx, mod_idx], cmap='gray')
        axes[0, mod_idx].set_title(f'Original {mod}')
        axes[0, mod_idx].axis('off')
        
        # Reconstructed
        axes[1, mod_idx].imshow(reconstructions[sample_idx, :, :, slice_idx, mod_idx], cmap='gray')
        axes[1, mod_idx].set_title(f'Reconstructed {mod}')
        axes[1, mod_idx].axis('off')
    
    plt.suptitle(f'Patient {sample_idx+1} - Slice {slice_idx} Comparison', y=1.02, fontsize=16)
    plt.tight_layout()
    plt.show()

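# To use this version:
# 1. First get reconstructions for the data you want to inspect
#    (the trained model is named `trained_cvae` in this notebook)
_, _, z = trained_cvae.encoder.predict(train_data)
reconstructions = trained_cvae.decoder.predict(z)

# 2. Then display
display_from_memory(train_data, reconstructions, sample_idx=0, slice_idx=24)

# Interactive version (requires ipywidgets)
from ipywidgets import interact, IntSlider

# The slice slider covers the actual depth of the volumes instead of a hard-coded 48 slices
interact(lambda sample, slice: display_from_memory(train_data, reconstructions, sample, slice),
         sample=IntSlider(min=0, max=len(train_data)-1, value=0),
         slice=IntSlider(min=0, max=train_data.shape[3]-1, value=24));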

In [None]:
# Note: this Image shadows the PIL.Image imported at the top of the notebook
from IPython.display import Image, display

def show_saved_comparison(sample_idx):
    img_path = f"/kaggle/working/reconstruction_results/sample_{sample_idx}_comparison.png"
    return Image(filename=img_path)

# Display all three comparisons
for i in range(3):
    display(show_saved_comparison(i))

In [None]:
# Assuming you have your trained cvae model and test data
#for validation dataset
def load_brats_Validationdata(path, patient_id):
    modalities = []
    for mod in ['flair', 't1', 't1ce', 't2']:
        vol = nib.load(f"{path}/BraTS20_Validation_{patient_id:03d}/BraTS20_Validation_{patient_id:03d}_{mod}.nii").get_fdata()
        vol = (vol - vol.min()) / (vol.max() - vol.min())
        modalities.append(vol[..., VOLUME_START_AT:VOLUME_START_AT+VOLUME_SLICES])
    return np.stack(modalities, axis=-1)
    
output_dir = "/kaggle/working/reconstruction_results"

# Load test data (similar to how you loaded training data)
test_data = np.array([load_brats_Validationdata(VALIDATION_DATASET_PATH, i) for i in range(4, 7)])

# Generate and save reconstructions (`cvae` is local to train_cvae; use the returned model)
generate_and_save_images(trained_cvae, test_data, output_dir, num_samples=3)

# Visualize latent space
visualize_latent_space(trained_cvae, test_data, save_path=f"{output_dir}/latent_space.png")

In [None]:
# import numpy as np
# import tensorflow as tf
# from tensorflow import keras
# from tensorflow.keras import layers
# from sklearn.model_selection import train_test_split
# import nibabel as nib
# import os
# from tqdm import tqdm

# # Parameters
# VOLUME_SLICES = 96
# VOLUME_START_AT = 22
# INPUT_SHAPE = (240, 240, VOLUME_SLICES, 4)
# latent_dim = 256
# batch_size = 4  # Increased batch size for better GPU utilization
# epochs = 100

# # Sampling layer with explicit dtype handling
# class Sampling(layers.Layer):
#     def call(self, inputs):
#         z_mean, z_log_var = inputs
#         batch = tf.shape(z_mean)[0]
#         dim = tf.shape(z_mean)[1]
        
#         epsilon = tf.keras.backend.random_normal(shape=(batch, dim), dtype=z_mean.dtype)
#         return z_mean + tf.exp(0.5 * z_log_var) * epsilon

# # Enhanced Encoder with dropout
# def create_encoder(input_shape):
#     encoder_inputs = keras.Input(shape=input_shape)
    
#     x = layers.Conv3D(64, 3, activation='relu', padding='same')(encoder_inputs)
#     x = layers.MaxPooling3D((2, 2, 2))(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Dropout(0.2)(x)
    
#     x = layers.Conv3D(128, 3, activation='relu', padding='same')(x)
#     x = layers.MaxPooling3D((2, 2, 2))(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Dropout(0.3)(x)

#     x = layers.Conv3D(128, 3, activation='relu', padding='same')(x)
#     x = layers.MaxPooling3D((2, 2, 2))(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Dropout(0.3)(x)
    
#     x = layers.Conv3D(256, 3, activation='relu', padding='same')(x)
#     x = layers.MaxPooling3D((2, 2, 2))(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Dropout(0.4)(x)
    
#     x = layers.GlobalAveragePooling3D()(x)
    
#     z_mean = layers.Dense(latent_dim, name='z_mean')(x)
#     z_log_var = layers.Dense(latent_dim, name='z_log_var')(x)
#     z = Sampling()([z_mean, z_log_var])
    
#     return keras.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')

# # Enhanced Decoder with dropout
# def create_decoder(output_shape):
#     latent_inputs = keras.Input(shape=(latent_dim,))
    
#     x = layers.Dense(15*15*6*256, activation='relu')(latent_inputs)
#     x = layers.Reshape((15,15,6,256))(x)
#     x = layers.Dropout(0.4)(x)
    
#     x = layers.Conv3DTranspose(256, 3, strides=2, activation='relu', padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Dropout(0.3)(x)
    
#     x = layers.Conv3DTranspose(128, 3, strides=2, activation='relu', padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Dropout(0.3)(x)

#     x = layers.Conv3DTranspose(128, 3, strides=2, activation='relu', padding='same')(x)
#     x = layers.BatchNormalization()(x)
#     x = layers.Dropout(0.2)(x)
    
#     x = layers.Conv3DTranspose(64, 3, strides=2, activation='relu', padding='same')(x)
#     x = layers.BatchNormalization()(x)

#     decoder_outputs = layers.Conv3D(output_shape[-1], 3, activation='sigmoid', padding='same')(x)
    
#     def crop_to_match(x):
#         input_shape = tf.shape(x)
#         crop_h = (input_shape[1] - output_shape[0]) // 2
#         crop_w = (input_shape[2] - output_shape[1]) // 2
#         crop_d = (input_shape[3] - output_shape[2]) // 2
        
#         return x[:, 
#                tf.maximum(crop_h, 0):tf.maximum(crop_h, 0)+output_shape[0],
#                tf.maximum(crop_w, 0):tf.maximum(crop_w, 0)+output_shape[1],
#                tf.maximum(crop_d, 0):tf.maximum(crop_d, 0)+output_shape[2],
#                :]
    
#     decoder_outputs = layers.Lambda(crop_to_match)(decoder_outputs)
    
#     return keras.Model(latent_inputs, decoder_outputs, name='decoder')

# # CVAE model
# class CVAE(keras.Model):
#     def __init__(self, encoder, decoder, **kwargs):
#         super().__init__(**kwargs)
#         self.encoder = encoder
#         self.decoder = decoder
#         self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
#         self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
#         self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        
#     def call(self, inputs):
#         z_mean, z_log_var, z = self.encoder(inputs)
#         reconstruction = self.decoder(z)
#         return reconstruction
    
#     def train_step(self, data):
#         if isinstance(data, tuple):
#             data = data[0]
            
#         with tf.GradientTape() as tape:
#             z_mean, z_log_var, z = self.encoder(data)
#             reconstruction = self.decoder(z)
            
#             reconstruction_loss = tf.reduce_mean(
#                 keras.losses.mean_squared_error(data, reconstruction)
#             )
#             kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
#             kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
#             total_loss = reconstruction_loss + kl_loss
            
#         grads = tape.gradient(total_loss, self.trainable_weights)
#         self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        
#         self.total_loss_tracker.update_state(total_loss)
#         self.reconstruction_loss_tracker.update_state(reconstruction_loss)
#         self.kl_loss_tracker.update_state(kl_loss)
        
#         return {m.name: m.result() for m in self.metrics}

# # Data loading functions
# def load_brats_data(path, patient_id):
#     modalities = []
#     for mod in ['flair', 't1', 't1ce', 't2']:
#         try:
#             # Try both possible file naming conventions
#             try:
#                 vol = nib.load(f"{path}/BraTS20_Training_{patient_id:03d}/BraTS20_Training_{patient_id:03d}_{mod}.nii").get_fdata()
#             except:
#                 vol = nib.load(f"{path}/BraTS20_Training_{patient_id:03d}_{mod}.nii.gz").get_fdata()
            
#             # Normalize and select slices
#             vol = (vol - vol.min()) / (vol.max() - vol.min())
#             modalities.append(vol[..., VOLUME_START_AT:VOLUME_START_AT+VOLUME_SLICES])
#         except Exception as e:
#             print(f"Error loading patient {patient_id} modality {mod}: {str(e)}")
#             return None
    
#     return np.stack(modalities, axis=-1)

# def load_dataset(path, num_patients=None):
#     patient_ids = []
#     for i in range(1, 1000):  # Check up to 1000 patients
#         if os.path.exists(f"{path}/BraTS20_Training_{i:03d}") or \
#            os.path.exists(f"{path}/BraTS20_Training_{i:03d}_flair.nii.gz"):
#             patient_ids.append(i)
#             if num_patients and len(patient_ids) >= num_patients:
#                 break
    
#     data = []
#     for pid in tqdm(patient_ids, desc="Loading data"):
#         patient_data = load_brats_data(path, pid)
#         if patient_data is not None:
#             data.append(patient_data)
    
#     return np.array(data)

# # Main training function
# def train_cvae():
#     # Load and preprocess data
#     train_data = load_dataset("/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData", num_patients=100)
    
#     if len(train_data) == 0:
#         raise ValueError("No training data could be loaded. Please check the dataset path.")
    
#     # Train-test split
#     train_data, test_data = train_test_split(train_data, test_size=0.2, random_state=42)
    
#     # Create data generators with batching
#     train_dataset = tf.data.Dataset.from_tensor_slices(train_data).batch(batch_size).shuffle(buffer_size=100).prefetch(tf.data.AUTOTUNE)
#     test_dataset = tf.data.Dataset.from_tensor_slices(test_data).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    
#     # Create and compile model
#     encoder = create_encoder(INPUT_SHAPE)
#     decoder = create_decoder(INPUT_SHAPE)
#     cvae = CVAE(encoder, decoder)
    
#     # Configure mixed precision
#     policy = tf.keras.mixed_precision.Policy('mixed_float16')
#     tf.keras.mixed_precision.set_global_policy(policy)
    
#     # Optimizer with learning rate schedule
#     lr_schedule = keras.optimizers.schedules.ExponentialDecay(
#         initial_learning_rate=1e-4,
#         decay_steps=10000,
#         decay_rate=0.9)
    
#     optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)
#     cvae.compile(optimizer=optimizer)
    
#     # Callbacks
#     early_stopping = keras.callbacks.EarlyStopping(
#         monitor='val_total_loss',
#         patience=15,
#         restore_best_weights=True,
#         min_delta=0.001)
    
#     checkpoint = keras.callbacks.ModelCheckpoint(
#         'best_cvae_model.h5',
#         monitor='val_total_loss',
#         save_best_only=True)
    
#     # Train the model
#     history = cvae.fit(
#         train_dataset,
#         validation_data=test_dataset,
#         epochs=epochs,
#         callbacks=[early_stopping, checkpoint]
#     )
    
#     return cvae, history

# # Run training
# if __name__ == "__main__":
#     trained_cvae, training_history = train_cvae()
    
#     # Save the final model
#     trained_cvae.save_weights('cvae_final_weights.h5')
    
#     # Visualization code would go here
#     # generate_and_save_images(trained_cvae, ...)
#     # visualize_latent_space(trained_cvae, ...)

# MRI segmentation

In [None]:
# DEFINE seg-areas  
SEGMENT_CLASSES = {
    0 : 'NOT tumor',
    1 : 'NECROTIC/CORE', # or NON-ENHANCING tumor CORE
    2 : 'EDEMA',
    3 : 'ENHANCING' # original 4 -> converted into 3 later
}


# there are 155 slices per volume
# starting at slice 22 and using 100 slices keeps slices 22..121 and skips the rest
VOLUME_SLICES = 100 
VOLUME_START_AT = 22 # first slice of volume that we will include

# Image data descriptions

All BraTS multimodal scans are available as NIfTI files (.nii.gz), a format commonly used in medical imaging to store brain imaging data obtained with MRI. Each case contains volumes acquired with the following MRI settings:
1. **T1**: T1-weighted, native image, sagittal or axial 2D acquisitions, with 1–6 mm slice thickness.
2. **T1c**: T1-weighted, contrast-enhanced (Gadolinium) image, with 3D acquisition and 1 mm isotropic voxel size for most patients.
3. **T2**: T2-weighted image, axial 2D acquisition, with 2–6 mm slice thickness.
4. **FLAIR**: T2-weighted FLAIR image, axial, coronal, or sagittal 2D acquisitions, 2–6 mm slice thickness.

Data were acquired with different clinical protocols and various scanners from multiple (n=19) institutions.

All the imaging datasets have been segmented manually, by one to four raters, following the same annotation protocol, and their annotations were approved by experienced neuro-radiologists. Annotations comprise the GD-enhancing tumor (ET — label 4), the peritumoral edema (ED — label 2), and the necrotic and non-enhancing tumor core (NCR/NET — label 1), as described both in the BraTS 2012-2013 TMI paper and in the latest BraTS summarizing paper. The provided data are distributed after their pre-processing, i.e., co-registered to the same anatomical template, interpolated to the same resolution (1 mm^3) and skull-stripped.
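
The relation between the annotation labels and the evaluated sub-regions can be sketched in a few lines of NumPy. This is only an illustration: the function name `brats_regions` and the toy array are made up for this example, and the grouping assumes the usual BraTS convention (WT = labels 1, 2 and 4; TC = labels 1 and 4; ET = label 4).

```python
import numpy as np

# BraTS label values as described above:
# 0 = background, 1 = NCR/NET, 2 = ED, 4 = ET
def brats_regions(seg):
    """Return boolean masks for the three evaluated sub-regions (hypothetical helper)."""
    seg = np.asarray(seg)
    whole_tumor = np.isin(seg, [1, 2, 4])  # WT: every tumor label
    tumor_core = np.isin(seg, [1, 4])      # TC: NCR/NET plus ET
    enhancing = seg == 4                   # ET: enhancing tumor only
    return whole_tumor, tumor_core, enhancing

# toy 2x3 "segmentation", not real data
toy_seg = np.array([[0, 1, 2],
                    [4, 4, 0]])
wt, tc, et = brats_regions(toy_seg)
print(wt.sum(), tc.sum(), et.sum())
```

Note that the regions are nested (ET inside TC inside WT), while the stored labels are mutually exclusive.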



In [None]:
TRAIN_DATASET_PATH = '../input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/'
VALIDATION_DATASET_PATH = '../input/brats20-dataset-training-validation/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData'

test_image_flair=nib.load(TRAIN_DATASET_PATH + 'BraTS20_Training_001/BraTS20_Training_001_flair.nii').get_fdata()
test_image_t1=nib.load(TRAIN_DATASET_PATH + 'BraTS20_Training_001/BraTS20_Training_001_t1.nii').get_fdata()
test_image_t1ce=nib.load(TRAIN_DATASET_PATH + 'BraTS20_Training_001/BraTS20_Training_001_t1ce.nii').get_fdata()
test_image_t2=nib.load(TRAIN_DATASET_PATH + 'BraTS20_Training_001/BraTS20_Training_001_t2.nii').get_fdata()
test_mask=nib.load(TRAIN_DATASET_PATH + 'BraTS20_Training_001/BraTS20_Training_001_seg.nii').get_fdata()


fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1,5, figsize = (20, 10))
slice_w = 25
ax1.imshow(test_image_flair[:,:,test_image_flair.shape[0]//2-slice_w], cmap = 'gray')
ax1.set_title('Image flair')
ax2.imshow(test_image_t1[:,:,test_image_t1.shape[0]//2-slice_w], cmap = 'gray')
ax2.set_title('Image t1')
ax3.imshow(test_image_t1ce[:,:,test_image_t1ce.shape[0]//2-slice_w], cmap = 'gray')
ax3.set_title('Image t1ce')
ax4.imshow(test_image_t2[:,:,test_image_t2.shape[0]//2-slice_w], cmap = 'gray')
ax4.set_title('Image t2')
ax5.imshow(test_mask[:,:,test_mask.shape[0]//2-slice_w])
ax5.set_title('Mask')


**Show whole nifti data -> print each slice from 3d data**

In [None]:
# Skip the first and last 50 slices along the first axis, since there is not much to see there
fig, ax1 = plt.subplots(1, 1, figsize = (15,15))
ax1.imshow(rotate(montage(test_image_t1[50:-50,:,:]), 90, resize=True), cmap ='gray')

**Show segment of tumor for each above slice**

In [None]:
# Skip the first and last 60 slices along the first axis, since there is not much to see there
fig, ax1 = plt.subplots(1, 1, figsize = (15,15))
ax1.imshow(rotate(montage(test_mask[60:-60,:,:]), 90, resize=True), cmap ='gray')

In [None]:
shutil.copy2(TRAIN_DATASET_PATH + 'BraTS20_Training_001/BraTS20_Training_001_flair.nii', './test_gif_BraTS20_Training_001_flair.nii')
gif2nif.write_gif_normal('./test_gif_BraTS20_Training_001_flair.nii')

**Gif representation of slices in 3D volume**
<img src="https://media1.tenor.com/images/15427ffc1399afc3334f12fd27549a95/tenor.gif?itemid=20554734">

**Show segments of tumor using different effects**

In [None]:
niimg = nl.image.load_img(TRAIN_DATASET_PATH + 'BraTS20_Training_001/BraTS20_Training_001_flair.nii')
nimask = nl.image.load_img(TRAIN_DATASET_PATH + 'BraTS20_Training_001/BraTS20_Training_001_seg.nii')

fig, axes = plt.subplots(nrows=4, figsize=(30, 40))


nlplt.plot_anat(niimg,
                title='BraTS20_Training_001_flair.nii plot_anat',
                axes=axes[0])

nlplt.plot_epi(niimg,
               title='BraTS20_Training_001_flair.nii plot_epi',
               axes=axes[1])

nlplt.plot_img(niimg,
               title='BraTS20_Training_001_flair.nii plot_img',
               axes=axes[2])

nlplt.plot_roi(nimask, 
               title='BraTS20_Training_001_flair.nii with mask plot_roi',
               bg_img=niimg, 
               axes=axes[3], cmap='Paired')

plt.show()

# Create model || U-Net: Convolutional Networks for Biomedical Image Segmentation
The U-Net is a convolutional network architecture for fast and precise segmentation of images. It outperformed the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. It also won the Grand Challenge for Computer-Automated Detection of Caries in Bitewing Radiography at ISBI 2015, and the Cell Tracking Challenge at ISBI 2015 on the two most challenging transmitted light microscopy categories (phase contrast and DIC microscopy) by a large margin.
[More on the U-Net project page](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/)
![official definition](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)


# Loss function
The **Dice coefficient** is essentially a measure of overlap between two samples. It ranges from 0 to 1, where a Dice coefficient of 1 denotes perfect and complete overlap. It was originally developed for binary data, and can be calculated as:

![dice loss](https://wikimedia.org/api/rest_v1/media/math/render/svg/a80a97215e1afc0b222e604af1b2099dc9363d3b)

**As matrices**
![dice loss](https://www.jeremyjordan.me/content/images/2018/05/intersection-1.png)

[Implementation, (images above) and explanation can be found here](https://www.jeremyjordan.me/semantic-segmentation/)
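
As a quick sanity check of the formula, here is a minimal NumPy sketch for two binary masks. The helper name `dice_binary` and the toy masks are assumptions made for this example; returning 1.0 when both masks are empty is one common convention, not something the formula itself prescribes.

```python
import numpy as np

def dice_binary(a, b):
    """Dice coefficient 2*|A and B| / (|A| + |B|) for two binary masks (hypothetical helper)."""
    a = np.asarray(a, dtype=bool)
    b = np.asarray(b, dtype=bool)
    intersection = np.logical_and(a, b).sum()
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0  # both masks empty: treated here as perfect agreement (a convention)
    return 2.0 * intersection / total

# toy masks: 3 positive pixels each, 2 of them shared
truth = np.array([[1, 1, 0],
                  [0, 1, 0]])
pred = np.array([[1, 0, 0],
                 [0, 1, 1]])
print(dice_binary(truth, pred))  # 2*2 / (3+3) = 0.666...
```

The Keras implementation in the next cell follows the same idea, but works on soft (probabilistic) predictions and adds a `smooth` term to avoid division by zero.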

In [None]:
# mean Dice coefficient over the 4 classes, as defined above (a score, not a loss: higher is better)
def dice_coef(y_true, y_pred, smooth=1.0):
    class_num = 4
    for i in range(class_num):
        y_true_f = K.flatten(y_true[:,:,:,i])
        y_pred_f = K.flatten(y_pred[:,:,:,i])
        intersection = K.sum(y_true_f * y_pred_f)
        loss = ((2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth))
   #     K.print_tensor(loss, message='loss value for class {} : '.format(SEGMENT_CLASSES[i]))
        if i == 0:
            total_loss = loss
        else:
            total_loss = total_loss + loss
    total_loss = total_loss / class_num
#    K.print_tensor(total_loss, message=' total dice coef: ')
    return total_loss


 
# define per class evaluation of dice coef
# inspired by https://github.com/keras-team/keras/issues/9395
def dice_coef_necrotic(y_true, y_pred, epsilon=1e-6):
    intersection = K.sum(K.abs(y_true[:,:,:,1] * y_pred[:,:,:,1]))
    return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,1])) + K.sum(K.square(y_pred[:,:,:,1])) + epsilon)

def dice_coef_edema(y_true, y_pred, epsilon=1e-6):
    intersection = K.sum(K.abs(y_true[:,:,:,2] * y_pred[:,:,:,2]))
    return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,2])) + K.sum(K.square(y_pred[:,:,:,2])) + epsilon)

def dice_coef_enhancing(y_true, y_pred, epsilon=1e-6):
    intersection = K.sum(K.abs(y_true[:,:,:,3] * y_pred[:,:,:,3]))
    return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,3])) + K.sum(K.square(y_pred[:,:,:,3])) + epsilon)



# Computing Precision
def precision(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    return true_positives / (predicted_positives + K.epsilon())

    
# Computing Sensitivity      
def sensitivity(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + K.epsilon())


# Computing Specificity
def specificity(y_true, y_pred):
    true_negatives = K.sum(K.round(K.clip((1-y_true) * (1-y_pred), 0, 1)))
    possible_negatives = K.sum(K.round(K.clip(1-y_true, 0, 1)))
    return true_negatives / (possible_negatives + K.epsilon())
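
To make the three metrics above concrete, the following sketch computes them from the confusion counts of a toy binary prediction. It is written in plain NumPy rather than with the Keras backend, and the name `confusion_rates` and the toy vectors are made up for illustration; the toy data is chosen so that no denominator is zero.

```python
import numpy as np

def confusion_rates(y_true, y_pred):
    """Precision, sensitivity and specificity for binary labels (hypothetical helper)."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    tp = np.sum(y_true & y_pred)    # predicted positive, actually positive
    fp = np.sum(~y_true & y_pred)   # predicted positive, actually negative
    fn = np.sum(y_true & ~y_pred)   # predicted negative, actually positive
    tn = np.sum(~y_true & ~y_pred)  # predicted negative, actually negative
    return {
        "precision": tp / (tp + fp),
        "sensitivity": tp / (tp + fn),
        "specificity": tn / (tn + fp),
    }

# toy labels: 3 positives and 5 negatives
y_true_toy = np.array([1, 1, 1, 0, 0, 0, 0, 0])
y_pred_toy = np.array([1, 1, 0, 1, 0, 0, 0, 0])
rates = confusion_rates(y_true_toy, y_pred_toy)
print(rates)  # tp=2, fp=1, fn=1, tn=4
```

The Keras versions above round the predictions first, so they behave like this sketch once probabilities are thresholded at 0.5.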

In [None]:
IMG_SIZE=128

In [None]:
# source https://naomi-fridman.medium.com/multi-class-image-segmentation-a5cc671e647a

def build_unet(inputs, ker_init, dropout):
    conv1 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(inputs)
    conv1 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(conv1)
    
    pool = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(pool)
    conv = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(conv)
    
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(pool1)
    conv2 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(conv2)
    
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(pool2)
    conv3 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(conv3)
    
    
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv5 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(pool4)
    conv5 = Conv2D(512, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(conv5)
    drop5 = Dropout(dropout)(conv5)

    up7 = Conv2D(256, 2, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(UpSampling2D(size = (2,2))(drop5))
    merge7 = concatenate([conv3,up7], axis = 3)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(merge7)
    conv7 = Conv2D(256, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(conv7)

    up8 = Conv2D(128, 2, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(UpSampling2D(size = (2,2))(conv7))
    merge8 = concatenate([conv2,up8], axis = 3)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(merge8)
    conv8 = Conv2D(128, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(conv8)

    up9 = Conv2D(64, 2, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(UpSampling2D(size = (2,2))(conv8))
    merge9 = concatenate([conv,up9], axis = 3)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(merge9)
    conv9 = Conv2D(64, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(conv9)
    
    up = Conv2D(32, 2, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(UpSampling2D(size = (2,2))(conv9))
    merge = concatenate([conv1,up], axis = 3)
    conv = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(merge)
    conv = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = ker_init)(conv)
    
    conv10 = Conv2D(4, (1,1), activation = 'softmax')(conv)
    
    return Model(inputs = inputs, outputs = conv10)

input_layer = Input((IMG_SIZE, IMG_SIZE, 2))

model = build_unet(input_layer, 'he_normal', 0.2)
model.compile(loss="categorical_crossentropy", optimizer=keras.optimizers.Adam(learning_rate=0.001), metrics = ['accuracy',tf.keras.metrics.MeanIoU(num_classes=4), dice_coef, precision, sensitivity, specificity, dice_coef_necrotic, dice_coef_edema ,dice_coef_enhancing] )
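
The encoder path of `build_unet` above halves the spatial size and doubles the number of filters at every pooling step. The small sketch below reproduces that bookkeeping in plain Python; `unet_feature_sizes` is a hypothetical helper written for this note, not part of the model code.

```python
def unet_feature_sizes(img_size=128, base_filters=32, depth=4):
    """List (spatial_size, n_filters) per encoder level, from the input level to the bottleneck."""
    levels = []
    size, filters = img_size, base_filters
    for _ in range(depth + 1):
        levels.append((size, filters))
        size //= 2      # each 2x2 max pooling halves height and width
        filters *= 2    # each deeper level doubles the number of filters
    return levels

for size, filters in unet_feature_sizes():
    print(f"{size}x{size} feature maps with {filters} filters")
```

With `IMG_SIZE = 128` and four pooling layers, the bottleneck works on 8x8 feature maps with 512 filters, which is why the input size should be divisible by 16.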

**model architecture** <br>
If you are about to use U-Net, I suggest checking out [keras-unet-collection](https://pypi.org/project/keras-unet-collection/), an awesome library that I found only after implementing U-Net manually. It also contains implementations of dice loss, tversky loss and many more!

In [None]:
plot_model(model, 
           show_shapes = True,
           show_dtype=False,
           show_layer_names = True, 
           rankdir = 'TB', 
           expand_nested = False, 
           dpi = 70)

# Load data
Loading all the data into memory is not a good idea, since the dataset is too big to fit.
So we will create data generators that load the data on the fly, as explained [here](https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly)

In [None]:
# lists of directories with studies
train_and_val_directories = [f.path for f in os.scandir(TRAIN_DATASET_PATH) if f.is_dir()]

# BraTS20_Training_355 has an ill-formatted name for its seg.nii file, so it is skipped
train_and_val_directories.remove(TRAIN_DATASET_PATH+'BraTS20_Training_355')


def pathListIntoIds(dirList):
    x = []
    for i in range(0,len(dirList)):
        x.append(dirList[i][dirList[i].rfind('/')+1:])
    return x

train_and_test_ids = pathListIntoIds(train_and_val_directories); 

    
train_test_ids, val_ids = train_test_split(train_and_test_ids,test_size=0.2) 
train_ids, test_ids = train_test_split(train_test_ids,test_size=0.15) 
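
The two consecutive splits above give roughly 68% / 20% / 12% of the cases for training / validation / testing. The sketch below works this out without calling scikit-learn. It assumes the held-out part of each split is rounded up, and the case count of 368 is only an illustrative assumption, not a value read from the dataset.

```python
import math

def split_sizes(n_cases, val_frac=0.2, test_frac=0.15):
    """Sizes of (train, val, test) after two consecutive splits (hypothetical helper)."""
    n_val = math.ceil(val_frac * n_cases)       # first split: hold out validation cases
    n_rest = n_cases - n_val
    n_test = math.ceil(test_frac * n_rest)      # second split: hold out test cases
    n_train = n_rest - n_test
    return n_train, n_val, n_test

n_cases = 368  # assumed number of usable cases, for illustration only
n_train, n_val, n_test = split_sizes(n_cases)
print(n_train, n_val, n_test)
print([round(n / n_cases, 3) for n in (n_train, n_val, n_test)])
```

Because neither call passes `random_state`, the actual assignment of cases changes from run to run; fixing it would make the split reproducible.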

**Override Keras sequence DataGenerator class**

In [None]:
class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, list_IDs, dim=(IMG_SIZE,IMG_SIZE), batch_size = 1, n_channels = 2, shuffle=True):
        'Initialization'
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        Batch_ids = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(Batch_ids)

        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def __data_generation(self, Batch_ids):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        X = np.zeros((self.batch_size*VOLUME_SLICES, *self.dim, self.n_channels))
        y = np.zeros((self.batch_size*VOLUME_SLICES, 240, 240))
        Y = np.zeros((self.batch_size*VOLUME_SLICES, *self.dim, 4))

        
        # Generate data
        for c, i in enumerate(Batch_ids):
            case_path = os.path.join(TRAIN_DATASET_PATH, i)

            data_path = os.path.join(case_path, f'{i}_flair.nii');
            flair = nib.load(data_path).get_fdata()    

            data_path = os.path.join(case_path, f'{i}_t1ce.nii');
            ce = nib.load(data_path).get_fdata()
            
            data_path = os.path.join(case_path, f'{i}_seg.nii');
            seg = nib.load(data_path).get_fdata()
        
            for j in range(VOLUME_SLICES):
                 X[j +VOLUME_SLICES*c,:,:,0] = cv2.resize(flair[:,:,j+VOLUME_START_AT], (IMG_SIZE, IMG_SIZE));
                 X[j +VOLUME_SLICES*c,:,:,1] = cv2.resize(ce[:,:,j+VOLUME_START_AT], (IMG_SIZE, IMG_SIZE));

                 y[j +VOLUME_SLICES*c] = seg[:,:,j+VOLUME_START_AT];
                    
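        # Generate masks
        y[y==4] = 3  # BraTS label 4 (enhancing tumor) becomes class index 3
        mask = tf.one_hot(y.astype(np.int32), 4)  # one-hot indices should be integers
        Y = tf.image.resize(mask, (IMG_SIZE, IMG_SIZE))
        return X/np.max(X), Y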
        
training_generator = DataGenerator(train_ids)
valid_generator = DataGenerator(val_ids)
test_generator = DataGenerator(test_ids)
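
The mask preparation at the end of `__data_generation` (relabel 4 as 3, then one-hot encode) can be checked on a tiny array. This is a NumPy-only sketch of the same idea; `remap_and_one_hot` and the toy slice are made up for this example and are not used by the generator.

```python
import numpy as np

def remap_and_one_hot(seg_slice, num_classes=4):
    """Map BraTS label 4 to class index 3 and one-hot encode (hypothetical helper)."""
    labels = np.asarray(seg_slice).astype(np.int64)  # astype returns a copy, the input is untouched
    labels[labels == 4] = 3
    # indexing the identity matrix with the label array yields one one-hot vector per pixel
    return np.eye(num_classes, dtype=np.float32)[labels]

# toy 2x2 slice containing every label value once
toy_slice = np.array([[0., 1.],
                      [2., 4.]])
one_hot = remap_and_one_hot(toy_slice)
print(one_hot.shape)
print(one_hot[1, 1])  # pixel with original label 4 ends up in channel 3
```

The last axis has one channel per entry of `SEGMENT_CLASSES`, which matches the 4-channel softmax output of the model.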

**Number of data used**
for training / testing / validation

In [None]:
# show number of data for each dir 
def showDataLayout():
    plt.bar(["Train","Valid","Test"],
    [len(train_ids), len(val_ids), len(test_ids)], align='center',color=[ 'green','red', 'blue'])

    plt.ylabel('Number of cases')  # each ID is one patient volume, not a single image
    plt.title('Data distribution')

    plt.show()
    
showDataLayout()

**Add callback for training process**

In [None]:
csv_logger = CSVLogger('training.log', separator=',', append=False)


callbacks = [
#     keras.callbacks.EarlyStopping(monitor='loss', min_delta=0,
#                               patience=2, verbose=1, mode='auto'),
      keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                              patience=2, min_lr=0.000001, verbose=1),
#  keras.callbacks.ModelCheckpoint(filepath = 'model_.{epoch:02d}-{val_loss:.6f}.m5',
#                             verbose=1, save_best_only=True, save_weights_only = True)
        csv_logger
    ]

# Train model
My best model reached a mean IoU of 81% and a Dice coefficient of 65.5% <br>
I will load this pretrained model instead of training again

In [None]:
K.clear_session()

# history =  model.fit(training_generator,
#                     epochs=35,
#                     steps_per_epoch=len(train_ids),
#                     callbacks= callbacks,
#                     validation_data = valid_generator
#                     )  
# model.save("model_x1_1.h5")

**Visualize the training process**

In [None]:
############ load trained model ################
model = keras.models.load_model('../input/modelperclasseval/model_per_class.h5', 
                                   custom_objects={ 'accuracy' : tf.keras.metrics.MeanIoU(num_classes=4),
                                                   "dice_coef": dice_coef,
                                                   "precision": precision,
                                                   "sensitivity":sensitivity,
                                                   "specificity":specificity,
                                                   "dice_coef_necrotic": dice_coef_necrotic,
                                                   "dice_coef_edema": dice_coef_edema,
                                                   "dice_coef_enhancing": dice_coef_enhancing
                                                  }, compile=False)

history = pd.read_csv('../input/modelperclasseval/training_per_class.log', sep=',', engine='python')

hist=history

############### ########## ####### #######

# hist=history.history

acc=hist['accuracy']
val_acc=hist['val_accuracy']

epoch=range(len(acc))

loss=hist['loss']
val_loss=hist['val_loss']

train_dice=hist['dice_coef']
val_dice=hist['val_dice_coef']

f,ax=plt.subplots(1,4,figsize=(16,8))

ax[0].plot(epoch,acc,'b',label='Training Accuracy')
ax[0].plot(epoch,val_acc,'r',label='Validation Accuracy')
ax[0].legend()

ax[1].plot(epoch,loss,'b',label='Training Loss')
ax[1].plot(epoch,val_loss,'r',label='Validation Loss')
ax[1].legend()

ax[2].plot(epoch,train_dice,'b',label='Training dice coef')
ax[2].plot(epoch,val_dice,'r',label='Validation dice coef')
ax[2].legend()

ax[3].plot(epoch,hist['mean_io_u'],'b',label='Training mean IOU')
ax[3].plot(epoch,hist['val_mean_io_u'],'r',label='Validation mean IOU')
ax[3].legend()

plt.show()

# Prediction examples 

In [None]:
# returns the full volume stored in the NIfTI file at `path` as a NumPy array
def imageLoader(path):
    image = nib.load(path).get_fdata()
    return np.array(image)


# for each study directory in `list_of_files`, load the volume of the chosen MRI type
# (one of: flair, t1, t1ce, t2) together with its segmentation mask,
# then resize every slice to `IMG_SIZE`
def loadDataFromDir(path, list_of_files, mriType, n_images):
    scans = []
    masks = []
    for i in list_of_files[:n_images]:
        fullPath = glob.glob( i + '/*'+ mriType +'*')[0]
        currentScanVolume = imageLoader(fullPath)
        currentMaskVolume = imageLoader( glob.glob( i + '/*seg*')[0] ) 
        # for each slice in the 3D volume, also take its mask
        for j in range(0, currentScanVolume.shape[2]):
            scan_img = cv2.resize(currentScanVolume[:,:,j], dsize=(IMG_SIZE,IMG_SIZE), interpolation=cv2.INTER_AREA).astype('uint8')
            mask_img = cv2.resize(currentMaskVolume[:,:,j], dsize=(IMG_SIZE,IMG_SIZE), interpolation=cv2.INTER_AREA).astype('uint8')
            scans.append(scan_img[..., np.newaxis])
            masks.append(mask_img[..., np.newaxis])
    return np.array(scans, dtype='float32'), np.array(masks, dtype='float32')
        
#brains_list_test, masks_list_test = loadDataFromDir(VALIDATION_DATASET_PATH, test_directories, "flair", 5)


In [None]:
def predictByPath(case_path,case):
    files = next(os.walk(case_path))[2]
    X = np.empty((VOLUME_SLICES, IMG_SIZE, IMG_SIZE, 2))
  #  y = np.empty((VOLUME_SLICES, IMG_SIZE, IMG_SIZE))
    
    vol_path = os.path.join(case_path, f'BraTS20_Training_{case}_flair.nii');
    flair=nib.load(vol_path).get_fdata()
    
    vol_path = os.path.join(case_path, f'BraTS20_Training_{case}_t1ce.nii');
    ce=nib.load(vol_path).get_fdata() 
    
 #   vol_path = os.path.join(case_path, f'BraTS20_Training_{case}_seg.nii');
 #   seg=nib.load(vol_path).get_fdata()  

    
    for j in range(VOLUME_SLICES):
        X[j,:,:,0] = cv2.resize(flair[:,:,j+VOLUME_START_AT], (IMG_SIZE,IMG_SIZE))
        X[j,:,:,1] = cv2.resize(ce[:,:,j+VOLUME_START_AT], (IMG_SIZE,IMG_SIZE))
 #       y[j,:,:] = cv2.resize(seg[:,:,j+VOLUME_START_AT], (IMG_SIZE,IMG_SIZE))
        
  #  model.evaluate(x=X,y=y[:,:,:,0], callbacks= callbacks)
    return model.predict(X/np.max(X), verbose=1)


def showPredictsById(case, start_slice = 60):
    path = f"../input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_{case}"
    gt = nib.load(os.path.join(path, f'BraTS20_Training_{case}_seg.nii')).get_fdata()
    origImage = nib.load(os.path.join(path, f'BraTS20_Training_{case}_flair.nii')).get_fdata()
    p = predictByPath(path,case)

    core = p[:,:,:,1]
    edema= p[:,:,:,2]
    enhancing = p[:,:,:,3]

    # plt.subplots creates its own figure, so no separate plt.figure call is needed
    f, axarr = plt.subplots(1,6, figsize = (18, 50)) 

    for i in range(6): # for each image, add brain background
        axarr[i].imshow(cv2.resize(origImage[:,:,start_slice+VOLUME_START_AT], (IMG_SIZE, IMG_SIZE)), cmap="gray", interpolation='none')
    
    axarr[0].imshow(cv2.resize(origImage[:,:,start_slice+VOLUME_START_AT], (IMG_SIZE, IMG_SIZE)), cmap="gray")
    axarr[0].title.set_text('Original image flair')
    curr_gt=cv2.resize(gt[:,:,start_slice+VOLUME_START_AT], (IMG_SIZE, IMG_SIZE), interpolation = cv2.INTER_NEAREST)
    axarr[1].imshow(curr_gt, cmap="Reds", interpolation='none', alpha=0.3) # ,alpha=0.3,cmap='Reds'
    axarr[1].title.set_text('Ground truth')
    axarr[2].imshow(p[start_slice,:,:,1:4], cmap="Reds", interpolation='none', alpha=0.3)
    axarr[2].title.set_text('all classes')
    axarr[3].imshow(core[start_slice,:,:], cmap="OrRd", interpolation='none', alpha=0.3)
    axarr[3].title.set_text(f'{SEGMENT_CLASSES[1]} predicted')
    axarr[4].imshow(edema[start_slice,:,:], cmap="OrRd", interpolation='none', alpha=0.3)
    axarr[4].title.set_text(f'{SEGMENT_CLASSES[2]} predicted')
    axarr[5].imshow(enhancing[start_slice,:,], cmap="OrRd", interpolation='none', alpha=0.3)
    axarr[5].title.set_text(f'{SEGMENT_CLASSES[3]} predicted')
    plt.show()
    
    
# show predictions for the first seven test cases
for test_id in test_ids[:7]:
    showPredictsById(case=test_id[-3:])


# mask = np.zeros((10,10))
# mask[3:-3, 3:-3] = 1 # white square in black background
# im = mask + np.random.randn(10,10) * 0.01 # random image
# masked = np.ma.masked_where(mask == 0, mask)

# plt.figure()
# plt.subplot(1,2,1)
# plt.imshow(im, 'gray', interpolation='none')
# plt.subplot(1,2,2)
# plt.imshow(im, 'gray', interpolation='none')
# plt.imshow(masked, 'jet', interpolation='none', alpha=0.7)
# plt.show()

# Evaluation

In [None]:
case = test_ids[3][-3:]
path = f"../input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_{case}"
gt = nib.load(os.path.join(path, f'BraTS20_Training_{case}_seg.nii')).get_fdata()
p = predictByPath(path,case)


core = p[:,:,:,1]
edema= p[:,:,:,2]
enhancing = p[:,:,:,3]


i=40 # slice at
eval_class = 2 #     0 : 'NOT tumor',  1 : 'NECROTIC/CORE',    2 : 'EDEMA',    3 : 'ENHANCING' (see SEGMENT_CLASSES)



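gt[gt == 4] = 3 # same label remapping as in the DataGenerator
gt = (gt == eval_class).astype(np.float32) # binary mask: 1 for the evaluated class, 0 elsewhere

resized_gt = cv2.resize(gt[:,:,i+VOLUME_START_AT], (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)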

# plt.subplots creates its own figure, so no separate plt.figure call is needed
f, axarr = plt.subplots(1,2) 
axarr[0].imshow(resized_gt, cmap="gray")
axarr[0].title.set_text('ground truth')
axarr[1].imshow(p[i,:,:,eval_class], cmap="gray")
axarr[1].title.set_text(f'predicted class: {SEGMENT_CLASSES[eval_class]}')
plt.show()

In [None]:
model.compile(loss="categorical_crossentropy", optimizer=keras.optimizers.Adam(learning_rate=0.001), metrics = ['accuracy',tf.keras.metrics.MeanIoU(num_classes=4), dice_coef, precision, sensitivity, specificity, dice_coef_necrotic, dice_coef_edema, dice_coef_enhancing] )
# Evaluate the model on the test data using `evaluate`
print("Evaluate on test data")
# the generator already yields batches, so `batch_size` is not passed here
results = model.evaluate(test_generator, callbacks= callbacks)
print("test loss and metrics (in the order passed to compile):", results)