In [13]:
import os
import json
import numpy as np
import nibabel as nib
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
from tensorflow.keras.utils import Sequence
from scipy.ndimage import rotate
from tqdm import tqdm
import gc
import seaborn as sns
from sklearn.metrics import confusion_matrix
import ipywidgets as widgets
from IPython.display import display
import signal

In [14]:
class BrainTumor3DDataset:
    def __init__(self, base_path):
        self.base_path = base_path
        self.dataset_json = self._load_dataset_json()
        self.train_files = self._get_training_files()
        
    def _load_dataset_json(self):
        json_path = os.path.join(self.base_path, 'ML_Decathlon_Dataset/Task01_BrainTumour/dataset.json')
        with open(json_path, 'r') as f:
            return json.load(f)
    
    def _get_training_files(self):
        return self.dataset_json['training']
    
    def validate_data_integrity(self, images, labels):
        """Validate data integrity and check for potential issues."""
        print("\nData Validation Report:")
        print("-----------------------")
        
        # Check shapes
        print(f"Number of samples: {len(images)}")
        print(f"Image shape: {images[0].shape}")
        print(f"Label shape: {labels[0].shape}")
        
        # Check value ranges
        for i, (img, lbl) in enumerate(zip(images, labels)):
            print(f"\nSample {i}:")
            print(f"Image value range: [{np.min(img):.3f}, {np.max(img):.3f}]")
            print(f"Unique labels: {np.unique(lbl)}")
            
            # Check for NaN/Inf
            if np.any(np.isnan(img)) or np.any(np.isinf(img)):
                print("WARNING: Found NaN or Inf values in image!")
            
            # Check label validity
            if not np.array_equal(np.unique(lbl), np.arange(len(np.unique(lbl)))):
                print("WARNING: Labels might not be consecutive integers!")
        
        return True

    def load_volume(self, file_path):
        full_path = os.path.join(self.base_path, 'ML_Decathlon_Dataset/Task01_BrainTumour', 
                                file_path.replace('./', ''))
        return nib.load(full_path).get_fdata()

    def analyze_class_distribution(self, labels):
        """Analyze class distribution in the dataset"""
        class_counts = {}
        total_voxels = 0
        
        for label_volume in labels:
            unique, counts = np.unique(label_volume, return_counts=True)
            total_voxels += label_volume.size
            
            for class_idx, count in zip(unique, counts):
                if class_idx not in class_counts:
                    class_counts[class_idx] = 0
                class_counts[class_idx] += count
        
        # Convert to percentages
        class_percentages = {k: (v/total_voxels)*100 for k, v in class_counts.items()}
        
        return class_counts, class_percentages

    def preprocess_volume(self, volume):
        """Optimized preprocessing using vectorized operations"""
        preprocessed = np.zeros_like(volume, dtype=np.float32)
        for i in range(volume.shape[-1]):
            modality = volume[..., i]
            nonzero_mask = modality != 0
            if np.any(nonzero_mask):
                mean = np.mean(modality[nonzero_mask])
                std = np.std(modality[nonzero_mask])
                if std != 0:
                    preprocessed[..., i] = (modality - mean) / std
        return preprocessed

    def prepare_data(self, num_samples=None):
        images = []
        labels = []
        
        train_files = self.train_files[:num_samples] if num_samples else self.train_files
        
        for idx, file_info in tqdm(enumerate(train_files), desc="Loading data", total=len(train_files)):
            try:
                image = self.load_volume(file_info['image'])
                label = self.load_volume(file_info['label'])
                
                # Check for potential data issues
                if np.any(np.isnan(image)) or np.any(np.isinf(image)):
                    print(f"Warning: Found NaN or Inf values in image {idx}")
                    continue
                
                image = self.preprocess_volume(image)
                
                images.append(image)
                labels.append(label)
                
                gc.collect()
                
            except Exception as e:
                print(f"Error processing file {idx}: {e}")
                continue
            
        return images, labels

In [15]:
class DataGenerator3D(Sequence):
    def __init__(self, image_list, label_list, batch_size=1, patch_size=(64, 64, 64),
                 n_channels=4, n_classes=4, shuffle=True, augment=False):
        self.image_list = image_list
        self.label_list = label_list
        self.batch_size = batch_size
        self.patch_size = patch_size
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.augment = augment
        
        self.valid_patches = []
        self._calculate_valid_patches()
        self.on_epoch_end()

    def _calculate_valid_patches(self):
        valid_patches = []
        stride = [p // 2 for p in self.patch_size]
        
        for idx, image in enumerate(self.image_list):
            x_coords = range(0, image.shape[0] - self.patch_size[0], stride[0])
            y_coords = range(0, image.shape[1] - self.patch_size[1], stride[1])
            z_coords = range(0, image.shape[2] - self.patch_size[2], stride[2])
            
            for x in x_coords:
                for y in y_coords:
                    for z in z_coords:
                        label_patch = self.label_list[idx][
                            x:x + self.patch_size[0],
                            y:y + self.patch_size[1],
                            z:z + self.patch_size[2]
                        ]
                        if np.any(label_patch):  # Only include patches with labels
                            valid_patches.append((idx, x, y, z))
        
        self.valid_patches = valid_patches
        print(f"Total valid patches: {len(self.valid_patches)}")

    def __len__(self):
        return int(np.ceil(len(self.valid_patches) / self.batch_size))

    def __getitem__(self, index):
        start_idx = index * self.batch_size
        end_idx = min((index + 1) * self.batch_size, len(self.valid_patches))
        batch_patches = self.valid_patches[start_idx:end_idx]
        
        batch_size = len(batch_patches)
        X = np.zeros((batch_size, *self.patch_size, self.n_channels), dtype=np.float32)
        y = np.zeros((batch_size, *self.patch_size, self.n_classes), dtype=np.float32)

        for i, (img_idx, x, y_coord, z) in enumerate(batch_patches):
            X[i], y[i] = self._extract_patch(img_idx, x, y_coord, z)

        return X, y

    def _extract_patch(self, img_idx, x, y, z):
        image = self.image_list[img_idx]
        label = self.label_list[img_idx]

        patch_x = image[x:x + self.patch_size[0],
                       y:y + self.patch_size[1],
                       z:z + self.patch_size[2]].astype(np.float32)

        patch_y = np.zeros((*self.patch_size, self.n_classes), dtype=np.float32)
        for c in range(self.n_classes):
            patch_y[..., c] = (label[x:x + self.patch_size[0],
                                    y:y + self.patch_size[1],
                                    z:z + self.patch_size[2]] == c)

        if self.augment:
            patch_x, patch_y = self._augment_data(patch_x, patch_y)

        return patch_x, patch_y

    @staticmethod
    def _augment_data(image, label):
        if np.random.random() > 0.5:
            angle = np.random.uniform(-20, 20)
            image = np.stack([rotate(image[..., c], angle, axes=(0, 1), reshape=False)
                            for c in range(image.shape[-1])], axis=-1)
            label = np.stack([rotate(label[..., c], angle, axes=(0, 1), reshape=False)
                            for c in range(label.shape[-1])], axis=-1)
        return image, label

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.valid_patches)


In [16]:
def create_3d_unet(input_shape, n_classes=4, n_filters=16):
    inputs = tf.keras.Input(input_shape)
    
    # Encoder
    conv1 = conv_block_3d(inputs, n_filters)
    pool1 = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(conv1)
    
    conv2 = conv_block_3d(pool1, n_filters*2)
    pool2 = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(conv2)
    
    conv3 = conv_block_3d(pool2, n_filters*4)
    pool3 = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(conv3)
    
    # Bridge
    conv4 = conv_block_3d(pool3, n_filters*8)
    
    # Decoder
    up5 = tf.keras.layers.Conv3DTranspose(n_filters*4, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv4)
    concat5 = tf.keras.layers.concatenate([up5, conv3])
    conv5 = conv_block_3d(concat5, n_filters*4)
    
    up6 = tf.keras.layers.Conv3DTranspose(n_filters*2, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv5)
    concat6 = tf.keras.layers.concatenate([up6, conv2])
    conv6 = conv_block_3d(concat6, n_filters*2)
    
    up7 = tf.keras.layers.Conv3DTranspose(n_filters, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv6)
    concat7 = tf.keras.layers.concatenate([up7, conv1])
    conv7 = conv_block_3d(concat7, n_filters)
    
    outputs = tf.keras.layers.Conv3D(n_classes, (1, 1, 1), activation='softmax')(conv7)
    
    model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
    return model


In [17]:
def conv_block_3d(inputs, n_filters, kernel_size=(3, 3, 3)):
    x = tf.keras.layers.Conv3D(n_filters, kernel_size, padding='same')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    
    x = tf.keras.layers.Conv3D(n_filters, kernel_size, padding='same')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.ReLU()(x)
    return x

def dice_loss(y_true, y_pred):
    smooth = 1e-5
    
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    
    numerator = 2 * tf.reduce_sum(y_true * y_pred) + smooth
    denominator = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth
    
    return 1 - numerator / denominator

In [18]:
class TrainingController:
    def __init__(self):
        self.stop_training = False
        # Create button widget
        self.button = widgets.Button(
            description='Stop Training',
            button_style='danger',  # 'success', 'info', 'warning', 'danger' or ''
            tooltip='Click to stop training after current batch',
            icon='stop'  # (optional) icon name from Font Awesome
        )
        self.button.on_click(self.on_button_clicked)
        self.status_label = widgets.Label(value='Training in progress...')
        self.container = widgets.VBox([self.button, self.status_label])
        display(self.container)
    
    def on_button_clicked(self, b):
        self.stop_training = True
        self.status_label.value = "Stopping... Please wait for current batch to complete."
        print("\nStop signal received. Training will stop after current batch...")

class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self, controller):
        super().__init__()
        self.controller = controller
        self.batch_count = 0
        
    def on_train_begin(self, logs=None):
        print("Training started. Press the Stop button or Ctrl+C to stop training safely.")
    
    def on_batch_end(self, batch, logs=None):
        self.batch_count += 1
        # Check stop condition every batch
        if self.controller.stop_training:
            self.model.stop_training = True
            print(f"\nTraining stopped at batch {self.batch_count}")
            # Save model immediately
            self.model.save(f'model_stopped_batch_{self.batch_count}.keras')
            print(f"Model saved as model_stopped_batch_{self.batch_count}.keras")
    
    def on_epoch_end(self, epoch, logs=None):
        # Save model at the end of each epoch
        self.model.save(f'model_epoch_{epoch}.keras')
        if self.controller.stop_training:
            print(f"\nTraining stopped at epoch {epoch}")
            self.controller.status_label.value = f"Training stopped at epoch {epoch}"


In [23]:
def create_3d_unet_weighted(input_shape, n_classes=4, n_filters=16):
    inputs = tf.keras.Input(input_shape)
    
    # Encoder
    conv1 = conv_block_3d(inputs, n_filters)
    pool1 = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(conv1)
    
    conv2 = conv_block_3d(pool1, n_filters*2)
    pool2 = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(conv2)
    
    conv3 = conv_block_3d(pool2, n_filters*4)
    pool3 = tf.keras.layers.MaxPooling3D(pool_size=(2, 2, 2))(conv3)
    
    # Bridge
    conv4 = conv_block_3d(pool3, n_filters*8)
    
    # Decoder
    up5 = tf.keras.layers.Conv3DTranspose(n_filters*4, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv4)
    concat5 = tf.keras.layers.concatenate([up5, conv3])
    conv5 = conv_block_3d(concat5, n_filters*4)
    
    up6 = tf.keras.layers.Conv3DTranspose(n_filters*2, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv5)
    concat6 = tf.keras.layers.concatenate([up6, conv2])
    conv6 = conv_block_3d(concat6, n_filters*2)
    
    up7 = tf.keras.layers.Conv3DTranspose(n_filters, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv6)
    concat7 = tf.keras.layers.concatenate([up7, conv1])
    conv7 = conv_block_3d(concat7, n_filters)
    
    outputs = tf.keras.layers.Conv3D(n_classes, (1, 1, 1), activation='softmax')(conv7)
    
    model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
    return model

def weighted_dice_loss(y_true, y_pred):
    smooth = 1e-5
    
    # Convert to float32
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    
    # Calculate class weights based on inverse frequency
    weights = tf.reduce_sum(y_true, axis=[0,1,2,3])
    weights = 1.0 / (weights + smooth)
    weights = weights / tf.reduce_sum(weights)  # Normalize weights
    
    # Calculate weighted Dice loss for each class
    numerator = 2.0 * tf.reduce_sum(y_true * y_pred * weights, axis=[0,1,2,3])
    denominator = tf.reduce_sum((y_true + y_pred) * weights, axis=[0,1,2,3])
    
    dice_scores = (numerator + smooth) / (denominator + smooth)
    return 1.0 - tf.reduce_mean(dice_scores)

def dice_coefficient(y_true, y_pred):
    smooth = 1e-5
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    
    intersection = tf.reduce_sum(y_true * y_pred, axis=[1,2,3])
    union = tf.reduce_sum(y_true, axis=[1,2,3]) + tf.reduce_sum(y_pred, axis=[1,2,3])
    
    dice = (2.0 * intersection + smooth) / (union + smooth)
    return tf.reduce_mean(dice)

def train_weighted_model(validation_split=0.2, max_epochs=5, num_samples=10):
    try:
        # Initialize controller
        controller = TrainingController()
        
        # Setup keyboard interrupt handler
        original_sigint = signal.getsignal(signal.SIGINT)
        def keyboard_interrupt_handler(sig, frame):
            print("\nKeyboard interrupt (Ctrl+C) detected. Stopping training safely...")
            controller.stop_training = True
            # Restore original SIGINT handler
            signal.signal(signal.SIGINT, original_sigint)
        
        # Set up the keyboard interrupt handler
        signal.signal(signal.SIGINT, keyboard_interrupt_handler)
              
        # Load and prepare data
        dataset = BrainTumor3DDataset(base_path='.')
        print("Starting data preparation...")
        images, labels = dataset.prepare_data(num_samples=num_samples)
        
        # Validate data
        dataset.validate_data_integrity(images, labels)
        
        # Analyze class distribution
        class_counts, class_percentages = dataset.analyze_class_distribution(labels)
        print("\nClass distribution:")
        for class_idx, count in class_counts.items():
            print(f"Class {class_idx}: {count} voxels ({class_percentages[class_idx]:.2f}%)")
        
        # Split data
        print("\nCreating train/val split...")
        X_train, X_val, y_train, y_val = train_test_split(
            images, labels, test_size=validation_split, random_state=42
        )
        
        # Clear memory
        del images, labels
        gc.collect()
        
        # Create data generators
        patch_size = (64, 64, 64)
        train_generator = DataGenerator3D(
            X_train, y_train, batch_size=1, patch_size=patch_size, augment=True
        )
        val_generator = DataGenerator3D(
            X_val, y_val, batch_size=1, patch_size=patch_size, augment=False
        )
        
        # Create and compile model with weighted loss
        input_shape = (*patch_size, 4)
        model = create_3d_unet_weighted(input_shape, n_classes=4)
        
        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
        
        # Compile model with weighted loss and additional metrics
        model.compile(
            optimizer=optimizer,
            loss=weighted_dice_loss,
            metrics=[
                dice_coefficient,
                tf.keras.metrics.MeanIoU(num_classes=4),
                tf.keras.metrics.Precision(),
                tf.keras.metrics.Recall()
            ]
        )
        
        # Callbacks
        callbacks = [
            CustomCallback(controller),
            tf.keras.callbacks.EarlyStopping(
                monitor='val_dice_coefficient',
                mode='max',
                patience=5,
                restore_best_weights=True,
                verbose=1
            ),
            tf.keras.callbacks.ModelCheckpoint(
                'best_3d_model_weighted.keras',
                save_best_only=True,
                monitor='val_dice_coefficient',
                mode='max',
                verbose=1
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_dice_coefficient',
                mode='max',
                factor=0.5,
                patience=3,
                min_lr=1e-6,
                verbose=1
            ),
            tf.keras.callbacks.CSVLogger('training_log_weighted.csv'),
            tf.keras.callbacks.ModelCheckpoint(
                'checkpoint_epoch_{epoch:02d}.keras',
                save_freq='epoch',
                verbose=1
            ),
        ]
        
        # Train model
        try:
            history = model.fit(
                train_generator,
                validation_data=val_generator,
                epochs=max_epochs,
                callbacks=callbacks
            )
        except KeyboardInterrupt:
            print("\nTraining interrupted by user. Saving model...")
            model.save('interrupted_model.keras')
            print("Model saved as interrupted_model.keras")
            return model, history
        except Exception as e:
            print(f"\nAn error occurred during training: {str(e)}")
            model.save('error_model.keras')
            print("Model saved as error_model.keras")
            raise e
        finally:
            # Restore original SIGINT handler
            signal.signal(signal.SIGINT, original_sigint)
            controller.status_label.value = "Training completed or stopped."
        
        return model, history
    
    except Exception as e:
        print(f'\nAn error occurred during setup: {str(e)}')
        raise e

# Modified evaluation function to include Dice scores
def evaluate_weighted_model(model, val_generator):
    print("\nModel Evaluation:")
    print("----------------")
    
    val_predictions = []
    val_true = []
    dice_scores = []
    
    for i in tqdm(range(len(val_generator)), desc="Evaluating"):
        x, y = val_generator[i]
        pred = model.predict(x, verbose=0)
        val_predictions.append(pred)
        val_true.append(y)
        
        # Calculate Dice score for this batch
        dice = dice_coefficient(y, pred).numpy()
        dice_scores.append(dice)
    
    val_pred = np.concatenate(val_predictions)
    val_true = np.concatenate(val_true)
    
    # Calculate metrics
    mean_dice = np.mean(dice_scores)
    accuracy = np.mean(np.argmax(val_pred, axis=-1) == np.argmax(val_true, axis=-1))
    
    # Create confusion matrix
    cm = confusion_matrix(
        np.argmax(val_true.reshape(-1, val_true.shape[-1]), axis=-1),
        np.argmax(val_pred.reshape(-1, val_pred.shape[-1]), axis=-1)
    )
    
    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()
    
    print(f"\nOverall Accuracy: {accuracy:.4f}")
    print(f"Mean Dice Score: {mean_dice:.4f}")
    
    # Per-class metrics
    for i in range(val_true.shape[-1]):
        class_acc = np.mean(
            np.argmax(val_pred, axis=-1)[np.argmax(val_true, axis=-1) == i] == i
        )
        print(f"Class {i} Accuracy: {class_acc:.4f}")
    
    return accuracy, mean_dice, cm

In [24]:
def train_model(validation_split=0.2, max_epochs=5, num_samples=10):
    try:
        # Initialize controller
        controller = TrainingController()
        
        # Load and prepare data
        dataset = BrainTumor3DDataset(base_path='.')
        print("Starting data preparation...")
        images, labels = dataset.prepare_data(num_samples=num_samples)
        
        # Validate data
        dataset.validate_data_integrity(images, labels)
        
        # Split data
        print("\nCreating train/val split...")
        X_train, X_val, y_train, y_val = train_test_split(
            images, labels, test_size=validation_split, random_state=42
        )
        
        # Clear memory
        del images, labels
        gc.collect()
        
        # Create data generators
        patch_size = (64, 64, 64)
        train_generator = DataGenerator3D(
            X_train, y_train, batch_size=1, patch_size=patch_size, augment=True
        )
        val_generator = DataGenerator3D(
            X_val, y_val, batch_size=1, patch_size=patch_size, augment=False
        )
        
        # Create and compile model
        input_shape = (*patch_size, 4)
        model = create_3d_unet(input_shape, n_classes=4)
        
        optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
        
        model.compile(
            optimizer=optimizer,
            loss=dice_loss,
            metrics=['accuracy', tf.keras.metrics.MeanIoU(num_classes=4)]
        )
        
        # Callbacks
        callbacks = [
            CustomCallback(controller),
            tf.keras.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=5,
                restore_best_weights=True,
                verbose=1
            ),
            tf.keras.callbacks.ModelCheckpoint(
                'best_3d_model.keras',
                save_best_only=True,
                monitor='val_loss',
                verbose=1
            ),
            tf.keras.callbacks.ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=3,
                min_lr=1e-6,
                verbose=1
            ),
            tf.keras.callbacks.CSVLogger('training_log.csv')
        ]
        
        # Train model
        history = model.fit(
            train_generator,
            validation_data=val_generator,
            epochs=max_epochs,
            callbacks=callbacks
        )
        
        return model, history
    
    except Exception as e:
        print(f'\nAn error occurred during training: {str(e)}')
        if 'model' in locals():
            model.save('error_model.keras')
            print('Model saved as error_model.keras')
        raise e

In [25]:
def evaluate_model(model, val_generator):
    print("\nModel Evaluation:")
    print("----------------")
    
    val_predictions = []
    val_true = []
    
    for i in tqdm(range(len(val_generator)), desc="Evaluating"):
        x, y = val_generator[i]
        pred = model.predict(x, verbose=0)
        val_predictions.append(pred)
        val_true.append(y)
    
    val_pred = np.concatenate(val_predictions)
    val_true = np.concatenate(val_true)
    
    # Calculate metrics
    accuracy = np.mean(np.argmax(val_pred, axis=-1) == np.argmax(val_true, axis=-1))
    
    # Create confusion matrix
    cm = confusion_matrix(
        np.argmax(val_true.reshape(-1, val_true.shape[-1]), axis=-1),
        np.argmax(val_pred.reshape(-1, val_pred.shape[-1]), axis=-1)
    )
    
    # Plot confusion matrix
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()
    
    print(f"\nOverall Accuracy: {accuracy:.4f}")
    
    # Per-class metrics
    for i in range(val_true.shape[-1]):
        class_acc = np.mean(
            np.argmax(val_pred, axis=-1)[np.argmax(val_true, axis=-1) == i] == i
        )
        print(f"Class {i} Accuracy: {class_acc:.4f}")
    
    return accuracy, cm

In [None]:
if __name__ == "__main__":
    try:
        # Set memory growth for GPU
        physical_devices = tf.config.list_physical_devices('GPU')
        if physical_devices:
            print(f"Found {len(physical_devices)} GPU(s)")
            for device in physical_devices:
                tf.config.experimental.set_memory_growth(device, True)
                print(f"Enabled memory growth for {device}")
        else:
            print("No GPU devices found. Running on CPU.")

        # Initialize dataset
        print("\nInitializing dataset...")
        dataset = BrainTumor3DDataset(base_path='.')
        
        # Load and prepare data
        num_samples = 10  # Adjust based on your available memory
        print(f"\nLoading {num_samples} samples...")
        images, labels = dataset.prepare_data(num_samples=num_samples)
        
        # Analyze class distribution before training
        class_counts, class_percentages = dataset.analyze_class_distribution(labels)
        print("\nClass distribution before training:")
        for class_idx, count in class_counts.items():
            print(f"Class {class_idx}: {count:,} voxels ({class_percentages[class_idx]:.2f}%)")
        
        # Split data
        validation_split = 0.2
        print(f"\nSplitting data with {validation_split:.0%} validation split...")
        X_train, X_val, y_train, y_val = train_test_split(
            images, labels, 
            test_size=validation_split, 
            random_state=42,
            shuffle=True
        )
        
        # Clear original arrays to free memory
        del images, labels
        gc.collect()
        
        # Create data generators
        patch_size = (64, 64, 64)
        print("\nInitializing data generators...")
        train_generator = DataGenerator3D(
            X_train, y_train,
            batch_size=1,
            patch_size=patch_size,
            augment=True
        )
        
        val_generator = DataGenerator3D(
            X_val, y_val,
            batch_size=1,
            patch_size=patch_size,
            augment=False
        )
        
        print(f"Training samples: {len(train_generator)}")
        print(f"Validation samples: {len(val_generator)}")
        
        # Train model with weighted loss
        print("\nStarting model training...")
        max_epochs = 30  # Adjust based on your needs
        model, history = train_weighted_model(
            validation_split=validation_split,
            max_epochs=max_epochs,
            num_samples=num_samples
        )
        
        # Evaluate model
        print("\nEvaluating model...")
        accuracy, mean_dice, conf_matrix = evaluate_weighted_model(model, val_generator)
        
        # Plot and save training history
        print("\nPlotting training history...")
        plt.figure(figsize=(15, 5))
        
        # Loss plot
        plt.subplot(1, 3, 1)
        plt.plot(history.history['loss'], label='Training Loss')
        plt.plot(history.history['val_loss'], label='Validation Loss')
        plt.title('Model Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Weighted Dice Loss')
        plt.legend()
        
        # Dice coefficient plot
        plt.subplot(1, 3, 2)
        plt.plot(history.history['dice_coefficient'], label='Training Dice')
        plt.plot(history.history['val_dice_coefficient'], label='Validation Dice')
        plt.title('Dice Coefficient')
        plt.xlabel('Epoch')
        plt.ylabel('Dice Score')
        plt.legend()
        
        # IoU plot
        plt.subplot(1, 3, 3)
        plt.plot(history.history['mean_io_u'], label='Training IoU')
        plt.plot(history.history['val_mean_io_u'], label='Validation IoU')
        plt.title('Mean IoU')
        plt.xlabel('Epoch')
        plt.ylabel('IoU Score')
        plt.legend()
        
        plt.tight_layout()
        plt.savefig('training_history_3d.png')
        plt.show()
        
        # Save final model
        print("\nSaving final model...")
        model.save('final_brain_tumor_model.keras')
        
        # Save training history
        print("Saving training history...")
        with open('training_history.json', 'w') as f:
            json.dump(history.history, f)
            
        print("\nTraining completed successfully!")
        
    except Exception as e:
        print(f"\nAn error occurred during execution: {str(e)}")
        import traceback
        traceback.print_exc()
        
        # Try to save model if it exists
        if 'model' in locals():
            try:
                model.save('emergency_saved_model.keras')
                print('Emergency model save successful')
            except:
                print('Failed to save model during error recovery')

No GPU devices found. Running on CPU.

Initializing dataset...

Loading 10 samples...


Loading data: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]



Class distribution before training:
Class 0.0: 88,051,681 voxels (98.62%)
Class 1.0: 786,221 voxels (0.88%)
Class 2.0: 197,447 voxels (0.22%)
Class 3.0: 244,651 voxels (0.27%)

Splitting data with 20% validation split...

Initializing data generators...
Total valid patches: 388
Total valid patches: 86
Training samples: 388
Validation samples: 86

Starting model training...


VBox(children=(Button(button_style='danger', description='Stop Training', icon='stop', style=ButtonStyle(), to…

Starting data preparation...


Loading data: 100%|██████████| 10/10 [00:04<00:00,  2.07it/s]



Data Validation Report:
-----------------------
Number of samples: 10
Image shape: (240, 240, 155, 4)
Label shape: (240, 240, 155)

Sample 0:
Image value range: [-5.103, 10.282]
Unique labels: [0. 1. 2. 3.]

Sample 1:
Image value range: [-3.452, 12.930]
Unique labels: [0. 1. 2. 3.]

Sample 2:
Image value range: [-4.458, 11.208]
Unique labels: [0. 1. 2. 3.]

Sample 3:
Image value range: [-4.506, 12.677]
Unique labels: [0. 1. 2. 3.]

Sample 4:
Image value range: [-3.708, 12.848]
Unique labels: [0. 1. 2. 3.]

Sample 5:
Image value range: [-5.864, 10.379]
Unique labels: [0. 1. 2. 3.]

Sample 6:
Image value range: [-4.986, 13.716]
Unique labels: [0. 1. 2. 3.]

Sample 7:
Image value range: [-3.634, 12.997]
Unique labels: [0. 1. 2. 3.]

Sample 8:
Image value range: [-5.502, 11.336]
Unique labels: [0. 1. 2. 3.]

Sample 9:
Image value range: [-4.710, 11.241]
Unique labels: [0. 1. 2. 3.]

Class distribution:
Class 0.0: 88051681 voxels (98.62%)
Class 1.0: 786221 voxels (0.88%)
Class 2.0: 197447 

  self._warn_if_super_not_called()


Training started. Press the Stop button or Ctrl+C to stop training safely.
Epoch 1/30




[1m388/388[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - dice_coefficient: 0.1764 - loss: 0.7880 - mean_io_u_2: 0.3849 - precision_2: 0.6098 - recall_2: 0.2243
Epoch 1: val_dice_coefficient improved from -inf to 0.17057, saving model to best_3d_model_weighted.keras

Epoch 1: saving model to checkpoint_epoch_01.keras
[1m388/388[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m412s[0m 1s/step - dice_coefficient: 0.1765 - loss: 0.7879 - mean_io_u_2: 0.3849 - precision_2: 0.6101 - recall_2: 0.2246 - val_dice_coefficient: 0.1706 - val_loss: 0.7896 - val_mean_io_u_2: 0.3750 - val_precision_2: 0.4740 - val_recall_2: 0.3277 - learning_rate: 1.0000e-04
Epoch 2/30
[1m388/388[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - dice_coefficient: 0.2590 - loss: 0.7151 - mean_io_u_2: 0.3855 - precision_2: 0.7516 - recall_2: 0.4443
Epoch 2: val_dice_coefficient improved from 0.17057 to 0.21956, saving model to best_3d_model_weighted.keras

Epoch 2: saving model to chec