<a href="https://colab.research.google.com/github/salma-abed/Deep-learning-based-automated-detection-and-classification-of-Alzheimer-s-disease-Using-Neuroimaging/blob/main/Data_splitter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import shutil
import random
import os
from google.colab import drive
from sklearn.model_selection import train_test_split

In [None]:
# Mount Google Drive so the dataset directory under MyDrive is reachable from this runtime.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [None]:
class DataSplitter:
    """Split per-label image folders into on-disk train/test directories.

    Layout read and written by this class::

        data_path/<label>/<folder>/<image>.jpg          (input)
        data_path/train/<label>/<folder>/...            (output of load_data)
        data_path/test/<label>/<folder>/...             (output of load_data)

    The split is done at *folder* level: whole folders are copied, so all
    images inside one folder land on the same side of the split (presumably
    one folder per subject/scan -- confirm against the preprocessing step).
    """

    # Suffix (or tuple of suffixes, as accepted by str.endswith) of the image
    # files collected by get_train_data / get_test_data.
    IMAGE_EXTENSION = ".jpg"

    def __init__(self, data_path, label_names, test_size=0.2, random_state=None):
        """
        Args:
            data_path: Root directory holding one sub-directory per label.
            label_names: Label sub-directory names to process.
            test_size: Fraction (float) or absolute number (int) of folders per
                label put in the test split; forwarded to ``train_test_split``.
            random_state: Optional int seed. When given, the folder split and
                the image shuffles are reproducible. ``None`` (default) keeps
                the original unseeded behaviour.
        """
        self.data_path = data_path
        self.label_names = label_names
        self.test_size = test_size
        self.random_state = random_state
        # label -> list of folder names, filled by load_data().
        self.train_data = {}
        self.test_data = {}
        # Image paths and (path, label) pairs, filled lazily by get_*_data().
        self.train_filenames = []
        self.test_filenames = []
        self.train_images = []
        self.test_images = []
        # Private seeded RNG when a seed is given; otherwise fall back to the
        # global `random` module so a caller's random.seed() still applies.
        self._rng = random.Random(random_state) if random_state is not None else random

    @staticmethod
    def _list_subdirs(path):
        """Return the sorted names of the immediate sub-directories of ``path``."""
        # Sorting makes the (possibly seeded) split independent of filesystem
        # listing order; skipping plain files avoids copytree/listdir crashing
        # on stray entries such as `.DS_Store`.
        return sorted(
            name for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        )

    def _split_dir(self, split, label):
        """Return the output directory ``data_path/<split>/<label>``."""
        return os.path.join(self.data_path, split, label)

    @staticmethod
    def _copy_folders(src_root, folder_names, dst_root):
        """Create ``dst_root`` and copy each ``src_root/<name>`` tree into it."""
        os.makedirs(dst_root, exist_ok=True)
        for folder_name in folder_names:
            shutil.copytree(
                os.path.join(src_root, folder_name),
                os.path.join(dst_root, folder_name),
            )

    def _prepare_output_dirs(self, overwrite):
        """Refuse (or, with ``overwrite``, clear) pre-existing non-empty split output."""
        existing = [
            path
            for label in self.label_names
            for path in (self._split_dir("train", label), self._split_dir("test", label))
            if os.path.isdir(path) and os.listdir(path)
        ]
        if not existing:
            return
        if not overwrite:
            # Checked before copying anything, so a re-run fails cleanly instead
            # of crashing half-way or mixing an old split with a new one.
            raise FileExistsError(
                "Split output already exists (pass overwrite=True to regenerate): "
                + ", ".join(existing)
            )
        for path in existing:
            shutil.rmtree(path)

    def load_data(self, overwrite=False):
        """Split each label's folders into train/test and copy them on disk.

        Args:
            overwrite: If True, previously generated ``train/<label>`` and
                ``test/<label>`` directories are deleted first. If False
                (default) and any of them is non-empty, nothing is copied.

        Raises:
            FileExistsError: split output already exists and ``overwrite`` is False.
            ValueError: a label directory contains no sub-folders to split.
        """
        self._prepare_output_dirs(overwrite)
        # A new split invalidates any image lists cached by get_*_data().
        self.train_images, self.test_images = [], []
        self.train_filenames, self.test_filenames = [], []

        for label in self.label_names:
            label_path = os.path.join(self.data_path, label)
            folder_names = self._list_subdirs(label_path)
            if not folder_names:
                raise ValueError(f"No sub-folders to split in {label_path!r}")
            train_folders, test_folders = train_test_split(
                folder_names, test_size=self.test_size, random_state=self.random_state
            )
            self.train_data[label], self.test_data[label] = train_folders, test_folders
            self._copy_folders(label_path, train_folders, self._split_dir("train", label))
            self._copy_folders(label_path, test_folders, self._split_dir("test", label))

        # Counts printed here are folders per label, not images.
        print("Number of training labels:")
        for label in self.label_names:
            print(label, len(self.train_data[label]))
        print("Number of testing labels:")
        for label in self.label_names:
            print(label, len(self.test_data[label]))

    def save_data(self, output_dir="."):
        """Write the folder-level split to ``train_filenames.txt`` / ``test_filenames.txt``.

        Each line is ``<split>/<label>/<folder>``, relative to ``data_path``.
        Must be called after ``load_data`` (otherwise ``KeyError``).

        Args:
            output_dir: Directory for the two text files. Defaults to the
                current working directory, as before. On Colab that is
                typically the non-persistent VM disk, so pass a Drive path to
                keep the files.
        """
        for split, folders_by_label in (("train", self.train_data), ("test", self.test_data)):
            with open(os.path.join(output_dir, f"{split}_filenames.txt"), "w") as f:
                for label in self.label_names:
                    for folder_name in folders_by_label[label]:
                        f.write(os.path.join(split, label, folder_name) + "\n")

    def _collect_images(self, split):
        """Return shuffled ``(filepath, label)`` pairs for all images under ``data_path/<split>``."""
        images = []
        for label in self.label_names:
            label_path = self._split_dir(split, label)
            for folder_name in self._list_subdirs(label_path):
                folder_path = os.path.join(label_path, folder_name)
                images.extend(
                    (os.path.join(folder_path, filename), label)
                    for filename in sorted(os.listdir(folder_path))
                    if filename.endswith(self.IMAGE_EXTENSION)
                )
        # Shares one RNG across splits: with a seed, results depend on call order.
        self._rng.shuffle(images)
        return images

    def get_train_data(self):
        """Return shuffled ``(filepath, label)`` pairs for every training image.

        The first call scans ``data_path/train`` and caches the result (also
        filling ``train_filenames``). Later calls return the same cached list.
        An empty result is not cached and is rescanned on the next call.
        """
        if not self.train_images:
            self.train_images.extend(self._collect_images("train"))
            self.train_filenames = [path for path, _ in self.train_images]
        return self.train_images

    def get_test_data(self):
        """Return shuffled ``(filepath, label)`` pairs for every test image.

        The first call scans ``data_path/test`` and caches the result (also
        filling ``test_filenames``). Later calls return the same cached list.
        An empty result is not cached and is rescanned on the next call.
        """
        if not self.test_images:
            self.test_images.extend(self._collect_images("test"))
            self.test_filenames = [path for path, _ in self.test_images]
        return self.test_images

In [None]:
# Configuration: class sub-directory names to split and the dataset root on
# the mounted Drive. The codes presumably follow the ADNI diagnostic groups
# (AD = Alzheimer's disease, CN = cognitively normal, MCI / LMCI = mild /
# late mild cognitive impairment) -- confirm against the preprocessing step.
label_names = ['AD', 'CN', 'LMCI', 'MCI']
dataset_path = '/content/drive/MyDrive/output_images_preprocessed'

In [None]:
# Build the on-disk train/test split, record it, then index the images per split.
data_splitter = DataSplitter(dataset_path, label_names)
data_splitter.load_data()
data_splitter.save_data()

train_data = data_splitter.get_train_data()
test_data = data_splitter.get_test_data()

# Image-level totals (load_data above printed folder counts per label).
for split_name, samples in (("training", train_data), ("testing", test_data)):
    print(f"Number of {split_name} samples:", len(samples))

Number of training labels:
AD 79
CN 80
LMCI 78
MCI 80
Number of testing labels:
AD 20
CN 20
LMCI 20
MCI 20
Number of training samples: 3170
Number of testing samples: 800
