In [1]:
%matplotlib inline
import os, sys, re
import glob

import pandas as pd
import numpy as np
import torch
import torch.utils.data
import torch.nn

from random import randrange
from PIL import Image
import matplotlib.pyplot as plt

!pip install opencv-python -qqq
!pip install wandb -qqq
import wandb
wandb.login()



Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




[34m[1mwandb[0m: Currently logged in as: [33mvalenetjong[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
import argparse
import random

# Training and hyperparameter search configurations
curr_dir = os.getcwd()


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', 'yes' or '0'.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--process_flag False`` would have *enabled* the flag.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description='Final')
# The IMG_DIR environment variable overrides the author's local default path.
parser.add_argument('--img_dir', type=str,
                    default=os.environ.get('IMG_DIR', '/Users/valenetjong/Downloads/Data-3'),
                    help='directory for image storage')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed (default: 1)')
parser.add_argument('--num_classes', type=int, default=3,
                    help='number of classes')
parser.add_argument('--process_flag', type=_str2bool, default=False,
                    help="extract files from disk if True, use already extracted files, if False")
parser.add_argument('--create_dataset', type=_str2bool, default=False,
                    help="create dataset from scratch if True, load in processed dataset if False")
parser.add_argument('--transforms', type=str, default='all',
                    help='transforms for data augmentation')
parser.add_argument('--threshold', type=float, default=3e-4,
                    help='early stopping criterion')
# Parse an empty argument list: inside a notebook sys.argv holds kernel arguments.
args = parser.parse_args([])

# Seed every RNG this notebook touches (random.randrange is imported in the first
# cell; numpy and torch are used throughout) so results are reproducible.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed);

<torch._C.Generator at 0x1294052b0>

In [3]:
""" Set-up wandb """
# Bayesian sweep. The metric name must match the key logged during training.
metric = dict(name='max val acc', goal='maximize')

params = dict(
    max_epochs=dict(value=250),  # held fixed, not searched
    hidden_size=dict(values=[8, 16]),
    fc_size=dict(values=[32, 64, 128, 256, 512]),
    conv_in_size=dict(values=[32, 64, 128, 256]),
    conv_hid_size=dict(values=[8, 16, 32]),
    conv_out_size=dict(values=[8, 16, 32]),
    dropout=dict(values=[0.15, 0.2, 0.25, 0.3]),
    # Log-uniform over [8, 64], quantised with step q=8 (W&B q_log_uniform_values).
    batch_size=dict(distribution='q_log_uniform_values', q=8, min=8, max=64),
    lr=dict(values=[1e-3, 1e-4, 1e-5]),
)

sweep_config = dict(method='bayes', metric=metric, parameters=params)
# Uncomment to register the sweep with W&B:
# sweep_id = wandb.sweep(sweep_config, project="3D-masked-imgs")

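### Download the data
The OASIS MRI slice images are available on Kaggle: <https://www.kaggle.com/datasets/ninadaithal/imagesoasis/download?datasetVersionNumber=1>. The preprocessing cell below expects one sub-folder per class (e.g. `Mild Dementia/`) inside `./Data_now`.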

In [4]:
import cv2 as cv
import os
import numpy as np
from PIL import Image
# Every processed slice is letterboxed to CONV_HEIGHT x CONV_WIDTH pixels.
# NOTE(review): the original note "248,358" is unexplained; possibly a raw slice size. Confirm.
CONV_WIDTH, CONV_HEIGHT = 137, 167

# Class sub-folders walked during preprocessing ('Very mild Dementia' is not listed).
DEMENTIA_TYPES = [
    'Mild Dementia',
    'Moderate Dementia',
    'Non Demented',
]

# Scan view(s) of interest. The original note says 'mpr-1' corresponds to "tra"
# (presumably transverse; confirm). Not read by any code shown here.
VIEWS = ['mpr-1']


def normalize_intensity(img):
    """
    Normalizes the intensity of an image to the range [0, 255].

    Min-max scales ``img`` so that its darkest pixel maps to 0 and its brightest
    to 255, then truncates to uint8.

    Args:
        img: array-like of numeric pixel values (any shape).

    Returns:
        uint8 array with the same shape as ``img``. An empty or constant image
        (e.g. an all-black slice) has no range to stretch. It is returned as
        all zeros instead of dividing by zero and casting NaN to uint8.
    """
    img = np.asarray(img, dtype=np.float64)
    if img.size == 0:
        return np.zeros(img.shape, dtype=np.uint8)

    img_min = img.min()
    value_range = img.max() - img_min
    if value_range == 0:
        return np.zeros(img.shape, dtype=np.uint8)

    normalized_img = (img - img_min) / value_range * 255
    return normalized_img.astype(np.uint8)

def pad_image_to_size(img, width, height):
    """
    Resizes and pads an image with zeros to the specified width and height.

    The image is scaled (aspect ratio preserved, cv.INTER_AREA) to fit inside
    ``width`` x ``height`` and centred on a zero (black) canvas of exactly that
    size.

    Args:
        img: 2-D (H, W) array, or (H, W, C) for multi-channel images.
        width: target canvas width in pixels.
        height: target canvas height in pixels.

    Returns:
        Array of shape (height, width) or (height, width, C), with the dtype of
        the resized image.

    Raises:
        ValueError: if ``img`` has a zero-length spatial dimension.
    """
    src_h, src_w = img.shape[:2]
    if src_h == 0 or src_w == 0:
        raise ValueError(f"cannot resize an empty image of shape {img.shape}")

    # Resize the image to fit within the specified dimensions while maintaining aspect ratio
    scale = min(width / src_w, height / src_h)
    resized_img = cv.resize(img, None, fx=scale, fy=scale, interpolation=cv.INTER_AREA)
    # Rounding inside cv.resize could in principle overshoot the canvas by a pixel;
    # clip so the paste below can never fail with a shape mismatch.
    resized_img = resized_img[:height, :width]

    # Centre the resized image; any odd leftover pixel goes to the bottom/right.
    new_h, new_w = resized_img.shape[:2]
    y_offset = (height - new_h) // 2
    x_offset = (width - new_w) // 2

    # Keep any channel axis so colour images are supported as well as grayscale.
    padded_img = np.zeros((height, width) + resized_img.shape[2:], dtype=resized_img.dtype)
    padded_img[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = resized_img
    return padded_img

def crop_black_boundary(mri_image):
    """
    Crops the black boundary from an MRI image.

    Pixels brighter than 1 are treated as foreground. The image is cropped to
    the bounding box of the largest external foreground contour.

    Args:
        mri_image: single-channel image array (the caller passes PIL 'L' mode
            data, i.e. uint8).

    Returns:
        A cropped view of ``mri_image``. If there is no foreground at all (an
        all-black slice), the input is returned unchanged instead of raising
        on ``max()`` of an empty sequence.
    """
    _, thresh = cv.threshold(mri_image, 1, 255, cv.THRESH_BINARY)
    # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
    # (contours, hierarchy) in 4.x; [-2] selects the contours in both.
    contours = cv.findContours(thresh, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return mri_image
    largest_contour = max(contours, key=cv.contourArea)
    x, y, w, h = cv.boundingRect(largest_contour)
    return mri_image[y:y + h, x:x + w]

def process_image_stack(filenames, verbose=True):
    """
    Processes a stack of MRI images and returns them as a single 3D array.

    Each file is processed in three steps:
    1. Load as 8-bit grayscale.
    2. Crop to its non-black region (crop_black_boundary), then min-max
       normalise (normalize_intensity).
    3. Letterbox to CONV_HEIGHT x CONV_WIDTH (pad_image_to_size).

    Slices are stacked along the last axis in the order given.

    Args:
        filenames: iterable of slice image paths, already in slice order.
        verbose: print one line per processed slice plus the final shape
            (default True, matching the original behaviour).

    Returns:
        np.ndarray of shape (CONV_HEIGHT, CONV_WIDTH, len(filenames)).

    Raises:
        ValueError: if ``filenames`` is empty.
    """
    filenames = list(filenames)
    if not filenames:
        raise ValueError("process_image_stack needs at least one filename")

    slices = []
    for fn in filenames:
        with Image.open(fn) as pil_img:
            slice_arr = np.array(pil_img.convert('L'))  # Convert to grayscale
        slice_arr = crop_black_boundary(slice_arr)
        slice_arr = normalize_intensity(slice_arr)
        slice_arr = pad_image_to_size(slice_arr, CONV_WIDTH, CONV_HEIGHT)
        slices.append(slice_arr)
        if verbose:
            print(f"Processed {fn}, shape after padding: {slice_arr.shape}")

    stacked_img = np.stack(slices, axis=-1)
    if verbose:
        print(f"Stacked image shape: {stacked_img.shape}")
    return stacked_img

def _slice_sort_key(filename):
    """Natural-sort key: orders '..._9.jpg' before '..._10.jpg' (plain sorted() would not)."""
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', filename)]


def _group_view_slices(folder, view):
    """Map each patient ID in ``folder`` to its naturally sorted ``view`` .jpg slice filenames."""
    slices_by_patient = {}
    for filename in os.listdir(folder):
        if view in filename and filename.endswith('.jpg'):
            patient_id = filename.split('_mpr')[0]  # Extract patient ID
            slices_by_patient.setdefault(patient_id, []).append(filename)
    # Sorted patient order makes runs and logs deterministic (os.listdir order is arbitrary).
    return {
        patient_id: sorted(names, key=_slice_sort_key)
        for patient_id, names in sorted(slices_by_patient.items())
    }


def preprocess_data(data_folder, output_folder, view='mpr-1', class_names=None):
    """
    Stack each patient's slice images into one 3-D volume and save it as .npy.

    For every class folder ``data_folder/<class>``:
    - Files containing ``view`` and ending in ``.jpg`` are grouped by the
      patient ID (the filename prefix before ``'_mpr'``).
    - Each group is sorted in natural slice order and passed to
      ``process_image_stack``.
    - The result is saved as ``output_folder/<class>/<patient_id>_3D.npy``.

    Files are grouped in one directory listing by exact patient ID. A substring
    test would let patient 'P1' absorb slices of patient 'P10'.

    Args:
        data_folder: root folder containing one sub-folder per class.
        output_folder: root folder for the saved volumes (created as needed).
        view: scan-view marker a slice filename must contain (default 'mpr-1').
        class_names: class sub-folders to process (default: DEMENTIA_TYPES).
    """
    if class_names is None:
        class_names = DEMENTIA_TYPES
    print("Starting preprocessing...")

    for dementia_type in class_names:
        dementia_folder = os.path.join(data_folder, dementia_type)
        if not os.path.isdir(dementia_folder):
            print(f"Warning: Folder not found - {dementia_folder}")
            continue

        slices_by_patient = _group_view_slices(dementia_folder, view)
        if not slices_by_patient:
            print(f"Warning: No images found for {view} in {dementia_type}")
            continue

        # Create directory for saving processed images if it doesn't exist
        output_subdir = os.path.join(output_folder, dementia_type)
        os.makedirs(output_subdir, exist_ok=True)

        for patient_id, view_images in slices_by_patient.items():
            file_paths = [os.path.join(dementia_folder, img) for img in view_images]
            stacked_img = process_image_stack(file_paths)

            # Save the stacked image as a NumPy array
            output_path = os.path.join(output_subdir, f'{patient_id}_3D.npy')
            np.save(output_path, stacked_img)
            print(f'Saved 3D image for patient {patient_id} in {output_subdir}')

# Raw per-class JPEG slice folders in, one stacked .npy volume per patient out.
data_folder, output_folder = './Data_now', './3D_data'
preprocess_data(data_folder=data_folder, output_folder=output_folder)

In [5]:
# NOTE(review): `load_images_to_tensor` is not defined in any cell shown here, so this
# cell only runs if the function is still in the kernel from a deleted/earlier cell.
# It will raise NameError after Restart & Run All. TODO: restore its definition above.
# The recorded output shows it saving one `<class>_data.h5` file per sub-directory
# (including 'Very mild Dementia', which DEMENTIA_TYPES omits) and printing
# "Class Counts: None", so `class_counts` presumably ends up None. Confirm.
base_directory = args.img_dir  # Replace with your path (or pass --img_dir / set IMG_DIR in the config cell)
class_counts = load_images_to_tensor(base_directory)

Processing subdirectory: Mild Dementia
Mild Dementia data saved to Mild Dementia_data.h5
Done processing subdirectory: Mild Dementia
Processing subdirectory: Very mild Dementia
Very mild Dementia data saved to Very mild Dementia_data.h5
Done processing subdirectory: Very mild Dementia
Processing subdirectory: Moderate Dementia
Moderate Dementia data saved to Moderate Dementia_data.h5
Done processing subdirectory: Moderate Dementia
Processing subdirectory: Non Demented
Non Demented data saved to Non Demented_data.h5
Done processing subdirectory: Non Demented
Class Counts: None


In [6]:
# NOTE(review): dead cell. Every line below is commented out. It also unpacks three
# values (X_dataset, y_dataset, class_counts) from load_images_to_tensor, whereas the
# previous cell binds its result to a single name. Restore consistently or delete.
# base_directory = args.img_dir  # Replace with your path
# X_dataset, y_dataset, class_counts = load_images_to_tensor(base_directory)

# print(f"Combined Tensor Size: {X_dataset.size()}")
# print(f"Labels Tensor Size: {y_dataset.size()}")
# print(f"Class Counts: {class_counts}")

# print(X_dataset.shape)  # This will print (dataset_len, 60, img_height, img_width)
# print(y_dataset.shape)  # This will print (dataset_len,)