In [2]:
import os
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

def load_and_normalize_dataset(base_dir):
    """Load every image under ``base_dir/<split>/<class>`` and normalize it.

    The function walks the ``train``, ``test`` and ``val`` sub-folders of
    ``base_dir``. Inside each one it looks for the four class folders listed
    in ``classes`` below, reads every ``.png``/``.jpg``/``.jpeg`` file as an
    8-bit grayscale image scaled to ``[0, 1]``, and passes the resulting
    record through ``normalize_mri_for_ventricles``.

    Missing (or non-directory) split and class folders are reported with a
    warning and skipped. A file that fails to load or normalize is reported
    and skipped as well, so one bad image does not abort the whole run.

    Args:
        base_dir: Root folder of the dataset.

    Returns:
        A ``(train_data, test_data, val_data)`` tuple of lists. Each element
        is the dict returned by ``normalize_mri_for_ventricles`` for one
        image; files are visited in sorted name order so the lists are
        reproducible between runs.
    """
    classes = {
        'Mild_Demented': 0,
        'Moderate_Demented': 1,
        'Non_Demented': 2,
        'Very_Mild_Demented': 3
    }
    image_extensions = ('.png', '.jpg', '.jpeg')

    # One bucket per split, filled as we go, so no second filtering pass over
    # the full dataset is needed at the end.
    data_by_split = {'train': [], 'test': [], 'val': []}

    for split, bucket in data_by_split.items():
        split_dir = os.path.join(base_dir, split)

        # isdir (not exists): a stray regular file with the split's name must
        # be skipped rather than treated as a folder.
        if not os.path.isdir(split_dir):
            print(f"Warning: {split_dir} does not exist, skipping...")
            continue

        for class_name, label in classes.items():
            class_dir = os.path.join(split_dir, class_name)

            # Same reasoning as above: os.listdir on a regular file raises
            # NotADirectoryError, which would abort the whole load.
            if not os.path.isdir(class_dir):
                print(f"Warning: Class directory {class_dir} does not exist, skipping...")
                continue

            print(f"Processing {split}/{class_name}...")
            # os.listdir returns names in arbitrary order; sort for
            # reproducible output.
            for img_file in sorted(os.listdir(class_dir)):
                if not img_file.lower().endswith(image_extensions):
                    continue

                img_path = os.path.join(class_dir, img_file)

                try:
                    # The context manager releases the file handle as soon as
                    # the pixels have been copied into the array.
                    with Image.open(img_path) as img:
                        img_array = np.array(img.convert('L')) / 255.0  # Normalize to [0,1]

                    item = {
                        'original_path': img_path,
                        'label': label,
                        'class_name': class_name,
                        'split': split,
                        'image': img_array,
                        # Provenance tag derived purely from the file name.
                        'dataset': 'folder_dataset' if 'folder' in img_file else 'parquet_dataset'
                    }

                    bucket.append(normalize_mri_for_ventricles(item))

                except Exception as e:
                    # Deliberately best-effort: report the bad file and keep going.
                    print(f"Error processing {img_path}: {e}")

    return data_by_split['train'], data_by_split['test'], data_by_split['val']

def normalize_mri_for_ventricles(item, *, clahe_clip_limit=2.0,
                                 clahe_tile_grid=(8, 8),
                                 threshold_scale=0.5,
                                 morph_kernel_size=3):
    """Add contrast-enhanced, inverted and dark-region-mask views to ``item``.

    ``item`` is modified in place and also returned. The following keys are
    added, all derived from ``item['image']``:

    * ``image_normalized`` - the input image itself (same array object).
    * ``image_enhanced`` - CLAHE-equalized copy, scaled back to ``[0, 1]``.
    * ``image_ventricle_focus`` - ``1 - image`` so dark regions appear bright.
    * ``ventricle_mask`` - ``0.0``/``1.0`` mask of pixels at or below
      ``threshold_scale`` times the Otsu threshold, cleaned with a
      morphological open followed by a close.

    Args:
        item: Dict with an ``'image'`` entry. Assumed to be a 2-D float
            array with values in ``[0, 1]`` - TODO confirm against callers.
        clahe_clip_limit: ``clipLimit`` passed to ``cv2.createCLAHE``.
        clahe_tile_grid: ``tileGridSize`` passed to ``cv2.createCLAHE``.
        threshold_scale: Fraction of the Otsu threshold used as the mask
            cut-off.
        morph_kernel_size: Side length of the square kernel used for the
            open/close clean-up.

    Returns:
        The same ``item`` dict, with the keys above added.
    """
    image = item['image']

    # 1. Simple normalization - the image is used as-is.
    item['image_normalized'] = image

    # 8-bit copy for the OpenCV calls below. Round to the nearest gray level
    # instead of truncating, and clip so out-of-range values cannot wrap.
    img_uint8 = np.clip(np.rint(image * 255), 0, 255).astype(np.uint8)

    # 2. Local contrast enhancement (CLAHE).
    clahe = cv2.createCLAHE(clipLimit=clahe_clip_limit, tileGridSize=clahe_tile_grid)
    item['image_enhanced'] = clahe.apply(img_uint8) / 255.0

    # Plain intensity inversion so dark regions become bright. No
    # thresholding is applied to this view.
    item['image_ventricle_focus'] = 1 - image

    # 3. Dark-region mask. Only the Otsu threshold value is needed from the
    # first call; the binary image it produces is discarded.
    otsu_thresh, _ = cv2.threshold(
        img_uint8,
        0,
        255,
        cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU
    )

    # THRESH_BINARY_INV marks pixels at or below the cut-off as 255.
    # Scaling the Otsu value down presumably restricts the mask to the
    # darkest pixels.
    # NOTE(review): any equally dark background is selected as well - confirm
    # this is acceptable for downstream use.
    _, ventricle_mask = cv2.threshold(
        img_uint8,
        int(otsu_thresh * threshold_scale),
        255,
        cv2.THRESH_BINARY_INV
    )

    # Open removes small specks, close fills small holes.
    kernel = np.ones((morph_kernel_size, morph_kernel_size), np.uint8)
    ventricle_mask = cv2.morphologyEx(ventricle_mask, cv2.MORPH_OPEN, kernel)
    ventricle_mask = cv2.morphologyEx(ventricle_mask, cv2.MORPH_CLOSE, kernel)

    item['ventricle_mask'] = ventricle_mask / 255.0

    return item

def _pick_visualization_samples(data, num_samples):
    """Pick one item per class (first-seen order), then random extras if short."""
    first_of_class = {}
    for item in data:
        first_of_class.setdefault(item['class_name'], item)

    # At least one class representative is always kept, even when
    # num_samples is less than 1.
    samples = list(first_of_class.values())[:max(num_samples, 1)]

    shortfall = num_samples - len(samples)
    if shortfall > 0:
        chosen_paths = {sample['original_path'] for sample in samples}
        remaining = [item for item in data if item['original_path'] not in chosen_paths]
        if remaining:
            # Draw indices rather than the dicts themselves so NumPy never
            # has to coerce a list of dicts into an array.
            picks = np.random.choice(
                len(remaining),
                size=min(shortfall, len(remaining)),
                replace=False
            )
            samples.extend(remaining[int(idx)] for idx in picks)

    return samples


def visualize_normalizations(data, num_samples=4):
    """Plot the stored image variants side by side for a few sample items.

    One row is drawn per sample, with five columns: the original image, the
    normalized image, the CLAHE-enhanced image, the inverted "ventricle
    focus" image and the ventricle mask. Samples are chosen one per class in
    order of first appearance in ``data``; if that yields fewer than
    ``num_samples`` rows, the rest are drawn at random (via ``np.random``)
    from the items not already chosen.

    Args:
        data: Non-empty list of item dicts carrying the keys ``class_name``,
            ``original_path``, ``image``, ``image_normalized``,
            ``image_enhanced``, ``image_ventricle_focus`` and
            ``ventricle_mask``.
        num_samples: Target number of rows.

    Returns:
        The matplotlib figure; the caller is responsible for saving and
        closing it.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if not data:
        raise ValueError("visualize_normalizations needs at least one item to plot")

    samples = _pick_visualization_samples(data, num_samples)

    # squeeze=False keeps `axes` two-dimensional even for a single row, so no
    # special case is needed below.
    fig, axes = plt.subplots(
        len(samples), 5, figsize=(20, 4 * len(samples)), squeeze=False
    )

    # (item key, column title) for each column, left to right.
    panels = [
        ('image', 'Original'),
        ('image_normalized', 'Normalized'),
        ('image_enhanced', 'Enhanced (CLAHE)'),
        ('image_ventricle_focus', 'Ventricle Focus'),
        ('ventricle_mask', 'Ventricle Mask'),
    ]

    for row_axes, item in zip(axes, samples):
        for col, (key, title) in enumerate(panels):
            ax = row_axes[col]
            ax.imshow(item[key], cmap='gray')
            # The first column also carries the class name as a row label.
            ax.set_title(f"{item['class_name']}\n{title}" if col == 0 else title)
            ax.axis('off')

    plt.tight_layout()
    return fig

def _as_uint8_image(values):
    """Scale a [0, 1] float array to 8-bit gray levels, rounding and clipping."""
    # A bare astype(np.uint8) truncates, so a value a hair below an integer
    # (possible for computed floats such as 1 - x) would lose a gray level,
    # and out-of-range values would be cast unpredictably.
    return np.clip(np.rint(np.asarray(values) * 255), 0, 255).astype(np.uint8)


def save_normalized_dataset(base_dir, train_data, test_data, val_data):
    """Write every image variant of every item to ``base_dir/normalized``.

    The output layout is ``normalized/<split>/<class_name>/``. For each item
    four PNG files are written, named after the stem of the item's
    ``original_path``: ``<stem>_norm.png``, ``<stem>_enhanced.png``,
    ``<stem>_ventricle.png`` and ``<stem>_mask.png``. Folders are created as
    needed, and a summary line with the number of items saved is printed at
    the end.

    NOTE(review): two source files in the same class folder that share a
    stem but differ in extension map to the same output names, so the later
    one overwrites the earlier.

    Args:
        base_dir: Dataset root; output goes into its ``normalized`` sub-folder.
        train_data: Items for the ``train`` split.
        test_data: Items for the ``test`` split.
        val_data: Items for the ``val`` split.
    """
    output_dir = os.path.join(base_dir, 'normalized')
    os.makedirs(output_dir, exist_ok=True)

    # (item key, file-name suffix) for each image variant that gets saved.
    variants = (
        ('image_normalized', 'norm'),
        ('image_enhanced', 'enhanced'),
        ('image_ventricle_focus', 'ventricle'),
        ('ventricle_mask', 'mask'),
    )

    saved_count = 0

    for split, data in [('train', train_data), ('test', test_data), ('val', val_data)]:
        split_dir = os.path.join(output_dir, split)
        os.makedirs(split_dir, exist_ok=True)

        # Create each class folder once up front rather than once per image.
        for class_name in {item['class_name'] for item in data}:
            os.makedirs(os.path.join(split_dir, class_name), exist_ok=True)

        for item in tqdm(data, desc=f"Saving {split} images"):
            base_name = os.path.splitext(os.path.basename(item['original_path']))[0]
            class_dir = os.path.join(split_dir, item['class_name'])

            for key, suffix in variants:
                out_path = os.path.join(class_dir, f"{base_name}_{suffix}.png")
                Image.fromarray(_as_uint8_image(item[key])).save(out_path)

            saved_count += 1

    print(f"Total of {saved_count} images processed and saved with normalizations")

# Main execution
if __name__ == "__main__":
    # Dataset root; the loader looks for train/, test/ and val/ inside it.
    base_dir = "Combined_MRI_Dataset"

    print("Loading and normalizing dataset...")
    train_data, test_data, val_data = load_and_normalize_dataset(base_dir)

    print(f"Processed {len(train_data)} training images, {len(test_data)} test images, and {len(val_data)} validation images")

    # Class distribution summary. An empty split prints only its header,
    # because the percentage line runs once per class actually seen.
    for split_name, split_data in [("Training", train_data), ("Test", test_data), ("Validation", val_data)]:
        print(f"\n{split_name} set class distribution:")
        class_counts = {}
        for item in split_data:
            class_name = item['class_name']
            class_counts[class_name] = class_counts.get(class_name, 0) + 1

        for class_name, count in class_counts.items():
            print(f"  {class_name}: {count} images ({count/len(split_data)*100:.1f}%)")

    # Visualize normalizations. Skipped when no training images were loaded,
    # since there is nothing to plot and the remaining steps should still run.
    print("\nGenerating visualization...")
    if train_data:
        fig = visualize_normalizations(train_data)
        # Save through the figure object rather than relying on it still
        # being pyplot's current figure.
        fig.savefig(os.path.join(base_dir, "normalization_visualization.png"))
        plt.close(fig)
    else:
        print("No training images loaded; skipping visualization.")

    # Save normalized dataset
    print("\nSaving normalized dataset...")
    save_normalized_dataset(base_dir, train_data, test_data, val_data)

    print("\nFeature extraction complete!")

Loading and normalizing dataset...
Processing train/Mild_Demented...
Processing train/Moderate_Demented...
Processing train/Non_Demented...
Processing train/Very_Mild_Demented...
Processing test/Mild_Demented...
Processing test/Moderate_Demented...
Processing test/Non_Demented...
Processing test/Very_Mild_Demented...
Processing val/Mild_Demented...
Processing val/Moderate_Demented...
Processing val/Non_Demented...
Processing val/Very_Mild_Demented...
Processed 12543 training images, 3584 test images, and 1792 validation images

Training set class distribution:
  Mild_Demented: 2545 images (20.3%)
  Moderate_Demented: 1845 images (14.7%)
  Non_Demented: 4480 images (35.7%)
  Very_Mild_Demented: 3673 images (29.3%)

Test set class distribution:
  Mild_Demented: 727 images (20.3%)
  Moderate_Demented: 527 images (14.7%)
  Non_Demented: 1280 images (35.7%)
  Very_Mild_Demented: 1050 images (29.3%)

Validation set class distribution:
  Mild_Demented: 363 images (20.3%)
  Moderate_Demented: 

Saving train images: 100%|███████████████| 12543/12543 [00:18<00:00, 682.42it/s]
Saving test images: 100%|██████████████████| 3584/3584 [00:05<00:00, 682.52it/s]
Saving val images: 100%|███████████████████| 1792/1792 [00:02<00:00, 688.41it/s]

Total of 17919 images processed and saved with normalizations

Feature extraction complete!



