# In [None]:
import kagglehub
import os
import random
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class BrainTumorDataset(Dataset):
    """In-memory brain-tumor MRI dataset with a fixed-seed train/val split.

    On construction the Kaggle dataset ``rm1000/brain-tumor-mri-scans`` is
    fetched through ``kagglehub``, every readable image in the four class
    directories is loaded as grayscale, expanded to three channels, resized
    to 224x224 and kept in memory. The data is then split with
    ``train_test_split(..., random_state=42)``.

    Args:
        mode: ``'train'`` or ``'val'`` selects the matching part of the
            split. Any other value leaves the full, unsplit data in
            ``self.images`` / ``self.labels``.
        transform: Optional callable applied to the PIL image returned by
            ``__getitem__``.
        train_ratio: Fraction of the samples assigned to the training split;
            ``1 - train_ratio`` is passed to ``train_test_split`` as
            ``test_size``.
    """

    def __init__(self, mode='train', transform=None, train_ratio=0.8):
        # Initialize paths and class labels; the label of a class is its
        # index in ``self.classes``.
        self.path = kagglehub.dataset_download("rm1000/brain-tumor-mri-scans")
        self.classes = ['glioma', 'healthy', 'meningioma', 'pituitary']
        self.classes_dirs = [os.path.join(self.path, c) for c in self.classes]
        self.mode = mode
        self.transform = transform

        # Load all images and labels (3-channel, 224x224 uint8 arrays).
        self.images, self.labels = self.load_images()

        # Split data into training and validation. The split is only
        # reproducible because load_images() visits files in sorted order.
        self.train_images, self.val_images, self.train_labels, self.val_labels = train_test_split(
            self.images, self.labels, test_size=1-train_ratio, random_state=42)

        # Select images and labels based on mode; any other mode keeps the
        # full dataset loaded above.
        if mode == 'train':
            self.images, self.labels = self.train_images, self.train_labels
        elif mode == 'val':
            self.images, self.labels = self.val_images, self.val_labels

    @staticmethod
    def _list_files(directory):
        """Return the paths of all entries in *directory*, sorted by name.

        ``os.listdir`` returns entries in arbitrary order; sorting keeps the
        sample order (and therefore the seeded split) deterministic.
        """
        return [os.path.join(directory, f) for f in sorted(os.listdir(directory))]

    def load_images(self):
        """Load every readable image of every class directory.

        Returns:
            A tuple ``(images, labels)`` of NumPy arrays, where ``labels[i]``
            is the index into ``self.classes`` of ``images[i]``. Files that
            ``cv2.imread`` cannot decode are skipped.
        """
        images, labels = [], []
        for idx, d in enumerate(self.classes_dirs):
            for img_path in self._list_files(d):
                image = cv2.imread(img_path, 0)  # Read as grayscale
                if image is not None:  # Ensure image was read successfully
                    image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)  # Convert to 3-channel grayscale (RGB)
                    image_rgb = cv2.resize(image_rgb, (224, 224))  # Resize to (224, 224)
                    images.append(image_rgb)
                    labels.append(idx)
        return np.array(images), np.array(labels)

    def plot_random_images(self, num_images=10):
        """Show one figure per class with up to *num_images* random images.

        Images are sampled from the class directories on disk (not from the
        in-memory split). If a class has fewer than *num_images* files, all
        of them are shown; files that cannot be decoded are skipped.
        """
        for idx, d in enumerate(self.classes_dirs):
            image_files = self._list_files(d)
            # Clamp so random.sample never raises for small directories.
            sample_size = max(0, min(num_images, len(image_files)))
            random_images = random.sample(image_files, sample_size)

            # Lay the images out on a grid of at most 5 columns so each one
            # gets its own axes instead of overwriting the previous image.
            cols = max(1, min(5, sample_size))
            rows = max(1, (sample_size + cols - 1) // cols)

            plt.figure(figsize=(10, 10))
            for i, img_path in enumerate(random_images):
                image = cv2.imread(img_path, 0)
                if image is None:  # Not a decodable image; leave the cell empty
                    continue
                plt.subplot(rows, cols, i + 1)
                plt.imshow(image, cmap='gray')
                plt.axis('off')
            print(f"Random image sample for {self.classes[idx]}")
            plt.show()

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*.

        The stored array is converted to a PIL image and, if a transform was
        given, passed through it; the label is the class index.
        """
        image = Image.fromarray(self.images[idx])  # Convert to PIL for PyTorch compatibility
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

    # NOTE: this statement is part of the class body, so it runs once when
    # the class is defined, not when an instance is created.
    print("Runing")