In [1]:
# Mount Google Drive inside the Colab runtime. The default save directory used
# at the end of this notebook lives under this mount point.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


## README

This notebook is responsible for downloading, preprocessing, and splitting the Brain MRI dataset from Kaggle in preparation for federated learning (FL) simulations.

It is a critical component of the reproducibility pipeline for our study "Privacy-Preserving Biomedical AI in Local and Federated Learning using Gerchberg–Saxton Data Transformations" [link_to_publication]. Please follow the instructions in this notebook step by step to stay consistent with the experimental setup described in the manuscript.

### ⚠️ Final Step Required
At the end of this notebook, you will be prompted to define a local saving path for the preprocessed data.

🚨 Do not skip this step.

The saved data is required for all downstream notebooks, including Gerchberg–Saxton transformations and federated model training.

## Imports and Functions

The cells below define the helper functions used in this notebook to load, preprocess, split, and save the data in our workspace.

You can simply run all cells with no modification until the next block of cells (Load Data).

In [2]:
import os
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split

# Read all images and corresponding labels
def load_images_and_labels(base_path):
    """Load every ``.jpg`` image under ``base_path`` together with its class label.

    The expected layout is ``<base_path>/<Testing|Training>/<class_name>/*.jpg``
    where ``class_name`` is one of ``glioma``, ``meningioma``, ``notumor`` or
    ``pituitary`` (mapped to the integer labels 0-3). Images from both the
    ``Testing`` and ``Training`` folders are pooled into one collection.

    Args:
        base_path: Directory that contains the ``Testing`` and ``Training`` folders.

    Returns:
        A tuple ``(images, labels)`` of two equally long lists: the decoded
        PIL images and their integer class ids. Files that cannot be opened or
        decoded are reported with ``print`` and skipped.
    """
    label_map = {"glioma": 0, "meningioma": 1, "notumor": 2, "pituitary": 3}
    images = []
    labels = []

    for folder_type in ["Testing", "Training"]:
        for label_name, label_id in label_map.items():
            label_folder_path = os.path.join(base_path, folder_type, label_name)
            if not os.path.isdir(label_folder_path):
                # Make a wrong base_path visible instead of silently returning less data.
                print(f"Warning: folder not found, skipping: {label_folder_path}")
                continue

            # os.listdir() returns entries in arbitrary, filesystem-dependent order.
            # Sorting keeps the sample order (and therefore any seeded split made
            # from it) identical across machines.
            for img_file in sorted(os.listdir(label_folder_path)):
                if not img_file.endswith(".jpg"):
                    continue
                img_path = os.path.join(label_folder_path, img_file)
                try:
                    # Image.open() is lazy: it only reads the header and keeps the
                    # file open. Decode the pixels here so that (a) corrupt files
                    # are caught by this try block and (b) the file handle is
                    # released right away instead of piling up for every image.
                    with Image.open(img_path) as img:
                        img.load()
                except Exception as e:
                    print(f"Error loading image {img_path}: {e}")
                    continue
                images.append(img)
                labels.append(label_id)
    return images, labels

# Resize and preprocess images
def preprocess_images(image_list, target_size=(299, 299)):
    """Resize images, convert them to grayscale and scale the pixels to [0, 1].

    Args:
        image_list: Iterable of PIL images (anything offering ``resize`` and
            ``convert`` in the PIL sense).
        target_size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        A ``float32`` array of shape ``(n_images, height, width)``; each image
        is the 8-bit grayscale (``'L'``) version divided by 255. An empty input
        yields an array of shape ``(0, height, width)``.
    """
    # Generators have no len(); materialise them so the output can be preallocated.
    if not hasattr(image_list, "__len__"):
        image_list = list(image_list)

    # PIL sizes are (width, height) while the resulting arrays are (height, width).
    width, height = target_size

    # Fill one preallocated array instead of building a Python list of arrays and
    # copying it with np.array() at the end, which would briefly hold the whole
    # dataset in memory twice.
    processed_images = np.empty((len(image_list), height, width), dtype=np.float32)
    for idx, img in enumerate(image_list):
        # Resize first, then drop the colour channels (same order as before).
        img_gray = img.resize(target_size).convert('L')
        # Simple normalisation of the 8-bit values to [0, 1].
        processed_images[idx] = np.array(img_gray).astype(np.float32) / 255.0

    return processed_images

In [3]:
# Split data for clients and testing
def split_data_for_fl(images, labels, client_ratios, min_test_size=1000, random_state=42):
    """Hold out a stratified test set and divide the rest among FL clients.

    Args:
        images: Array of samples (first axis indexes the samples).
        labels: Array of labels aligned with ``images``; also used to stratify
            the train/test split.
        client_ratios: Fractions of the training data given to each client, in
            order. They must be non-negative and sum to at most 1.
        min_test_size: Forwarded to ``train_test_split`` as ``test_size``, i.e.
            the number of samples held out for testing.
        random_state: Seed forwarded to ``train_test_split``. The default (42)
            reproduces the split used so far.

    Returns:
        ``(client_datasets, (test_images, test_labels))`` where
        ``client_datasets`` is a list of ``(client_images, client_labels)``
        tuples, one per entry of ``client_ratios``. Clients receive consecutive
        slices of the shuffled training data; the last client additionally
        receives whatever is left over after truncating the sizes to integers.

    Raises:
        ValueError: If ``client_ratios`` is empty, contains a negative value or
            sums to more than 1, or if there is not enough data to hold out
            ``min_test_size`` samples.
    """
    client_ratios = list(client_ratios)
    if not client_ratios:
        raise ValueError("client_ratios must contain at least one ratio")
    if any(ratio < 0 for ratio in client_ratios):
        raise ValueError(f"client_ratios must be non-negative, got {client_ratios}")
    # Small tolerance so that e.g. 0.1 + 0.2 + 0.7 is not rejected for float noise.
    if sum(client_ratios) > 1.0 + 1e-9:
        raise ValueError(f"client_ratios must sum to at most 1, got {sum(client_ratios)}")

    # Ensure enough data for testing (at least one sample must remain for training)
    if len(images) <= min_test_size:
        raise ValueError(f"Not enough data for minimum test size ({min_test_size})")

    # Split into training (clients) and testing sets
    train_images, test_images, train_labels, test_labels = train_test_split(
        images, labels, test_size=min_test_size, random_state=random_state, stratify=labels
    )

    # Split training data among clients as consecutive slices
    total_train_size = len(train_images)
    last_client = len(client_ratios) - 1
    client_datasets = []
    start = 0
    for idx, ratio in enumerate(client_ratios):
        if idx == last_client:
            # The last client's slice simply runs to the end, which hands it the
            # remainder without an extra concatenate/copy.
            end = total_train_size
        else:
            end = start + int(total_train_size * ratio)
        client_datasets.append((train_images[start:end], train_labels[start:end]))
        start = end

    return client_datasets, (test_images, test_labels)

## Load and Preprocess Data
Below, we download the data from Kaggle using `kagglehub`.
Running the next cell automatically downloads the dataset and stores its local folder in the `path` variable.


Alternatively, you can download the zipped data from the link below. Please make sure you extract and save the data in your workspace for the following steps.

https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset?resource=download

In [4]:
import kagglehub

# Kaggle handle (owner/dataset) of the Brain Tumor MRI dataset.
KAGGLE_DATASET_HANDLE = "masoudnickparvar/brain-tumor-mri-dataset"

# Download the latest version; `path` is the local folder reported by kagglehub.
path = kagglehub.dataset_download(KAGGLE_DATASET_HANDLE)

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/brain-tumor-mri-dataset


In [5]:
import gc

## Load data
# `path` is set by the kagglehub download cell. If you downloaded the zipped
# version from the link instead, point base_data_path at the folder that holds
# the Training and Testing folders.
base_data_path = path

# Pool every image and its label into two parallel lists
all_images, all_labels = load_images_and_labels(base_data_path)

# Turn the raw images/labels into numpy arrays
processed_images = preprocess_images(all_images)
labels_array = np.array(all_labels)

# The raw PIL images are no longer needed; free them for memory efficiency
del all_images, all_labels
gc.collect()

# If a channel dimension (e.g., grayscale) is required later, add:
# processed_images = np.expand_dims(processed_images, axis=-1)

# Share of the training data per client, and the number of held-out test images
client_ratios = [0.5, 0.3, 0.2]
min_test_samples = 1000

# Perform the split
client_data, test_data = split_data_for_fl(
    processed_images,
    labels_array,
    client_ratios,
    min_test_size=min_test_samples,
)

# Report the size of the full dataset before dropping it
print(f"Processed images shape: {processed_images.shape}")
print(f"Labels shape: {labels_array.shape}")

# The split results are all we need from here on
del processed_images, labels_array
gc.collect()

# Per-client summary
print("\n--- Split Dataset Info ---")
for i in range(len(client_data)):
    client_images, client_labels = client_data[i]
    print(f"Client {i+1} dataset size: {len(client_images)} images")
    print(f"Client {i+1} images shape: {client_images.shape}")
    print(f"Client {i+1} labels shape: {client_labels.shape}")
    print(f"Client {i+1} label distribution: {np.unique(client_labels, return_counts = True)}")

# Test-set summary
print(f"\nTest dataset size: {len(test_data[0])} images")
print(f"Test images shape: {test_data[0].shape}")
print(f"Test labels shape: {test_data[1].shape}")

Processed images shape: (7023, 299, 299)
Labels shape: (7023,)

--- Split Dataset Info ---
Client 1 dataset size: 3011 images
Client 1 images shape: (3011, 299, 299)
Client 1 labels shape: (3011,)
Client 1 label distribution: (array([0, 1, 2, 3]), array([696, 729, 855, 731]))
Client 2 dataset size: 1806 images
Client 2 images shape: (1806, 299, 299)
Client 2 labels shape: (1806,)
Client 2 label distribution: (array([0, 1, 2, 3]), array([415, 401, 524, 466]))
Client 3 dataset size: 1206 images
Client 3 images shape: (1206, 299, 299)
Client 3 labels shape: (1206,)
Client 3 label distribution: (array([0, 1, 2, 3]), array([279, 281, 336, 310]))

Test dataset size: 1000 images
Test images shape: (1000, 299, 299)
Test labels shape: (1000,)


## Pickle save data to workspace

Please adjust the ```base_save_dir``` below according to your folder hierarchy.

In [6]:
import os
import pickle

def save_data_to_pickle(data, filename, directory):
    """Save ``data`` as a pickle file (protocol 4) at ``directory/filename``.

    The directory is created if it does not exist. The data is first written to
    a temporary ``<filename>.tmp`` file and only moved onto the final path once
    the dump has succeeded, so a failed or interrupted save never leaves a
    truncated pickle (or destroys an existing good one) at the final path.

    Args:
        data: Any picklable object.
        filename: Name of the pickle file to create.
        directory: Target directory.

    Returns:
        ``True`` if the file was written, ``False`` if saving failed. Failures
        while writing are reported with ``print`` rather than raised.
    """
    os.makedirs(directory, exist_ok=True)  # Create directory if it doesn't exist
    filepath = os.path.join(directory, filename)
    tmp_path = f"{filepath}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(data, f, protocol=4)
        # Replace the target in one step once the dump is complete.
        os.replace(tmp_path, filepath)
    except Exception as e:
        # Best-effort cleanup of the partial temp file; the original error is
        # the one worth reporting, so a failed cleanup is deliberately ignored.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        print(f"Error saving data to {filepath}: {e}")
        return False
    print(f"Successfully saved data to: {filepath}")
    return True


# Define the base directory for saving
# base_save_dir = "/clients_data/bench" # Modify as needed
base_save_dir ='/content/drive/MyDrive/Spring 25/github_brainfl/data/bench'

# Save client data: one images file and one labels file per client, numbered from 1
for client_idx, (client_images, client_labels) in enumerate(client_data, start=1):
  save_data_to_pickle(client_images, f"data_client{client_idx}.pickle", base_save_dir)
  save_data_to_pickle(client_labels, f"labels_client{client_idx}.pickle", base_save_dir)

# Save test data
test_save_dir = base_save_dir # Use the same path as client data by default
# If you prefer a different path for test data, uncomment and modify the line below.
# (The override has to come AFTER the default assignment above, otherwise the
# default would silently overwrite it.)
# test_save_dir = "/modify/if/different/test/path/preferred"

save_data_to_pickle(test_data[0], "test_images.pickle", test_save_dir)
save_data_to_pickle(test_data[1], "test_labels.pickle", test_save_dir)


Successfully saved data to: /content/drive/MyDrive/Spring 25/github_brainfl/data/bench/data_client1.pickle
Successfully saved data to: /content/drive/MyDrive/Spring 25/github_brainfl/data/bench/labels_client1.pickle
Successfully saved data to: /content/drive/MyDrive/Spring 25/github_brainfl/data/bench/data_client2.pickle
Successfully saved data to: /content/drive/MyDrive/Spring 25/github_brainfl/data/bench/labels_client2.pickle
Successfully saved data to: /content/drive/MyDrive/Spring 25/github_brainfl/data/bench/data_client3.pickle
Successfully saved data to: /content/drive/MyDrive/Spring 25/github_brainfl/data/bench/labels_client3.pickle
Successfully saved data to: /content/drive/MyDrive/Spring 25/github_brainfl/data/bench/test_images.pickle
Successfully saved data to: /content/drive/MyDrive/Spring 25/github_brainfl/data/bench/test_labels.pickle
