In [13]:
import torch
import os
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
from PIL import Image
from torchvision import transforms
from sklearn.model_selection import train_test_split

In [14]:
def get_shuffled_images(data_dir, seed=None):
    """Load every image under ``data_dir`` with a binary label, in shuffled order.

    ``data_dir`` is expected to contain one sub-directory per class. Images in
    the ``Malignant`` sub-directory get label 1; every other class gets label 0.

    Hidden entries (names starting with ``"."``, e.g. ``.DS_Store``) are skipped
    at both levels, as are non-directory entries at the top level, so only real
    class folders and image files are read.

    Args:
        data_dir: Path to a split root (e.g. ``"../data/train"``).
        seed: Optional seed for the shuffle. ``None`` (the default) shuffles with
            NumPy's global RNG, exactly as before.

    Returns:
        A list of ``(image_array, label)`` tuples, where ``image_array`` is the
        ``np.ndarray`` built from the PIL image and ``label`` is 0 or 1.
    """
    data = []
    # Sorted listings make a seeded shuffle reproducible across filesystems,
    # since os.listdir order is arbitrary.
    for class_name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, class_name)
        # Previously `os.listdir(data_dir)[1:]` assumed the first entry was junk;
        # with arbitrary listdir order that could drop a real class instead.
        if class_name.startswith(".") or not os.path.isdir(class_dir):
            continue
        class_label = 1 if class_name == 'Malignant' else 0
        for img_name in sorted(os.listdir(class_dir)):
            if img_name.startswith("."):
                continue
            image_path = os.path.join(class_dir, img_name)
            # Close the file handle promptly; np.array copies the pixel data.
            with Image.open(image_path) as image:
                data.append((np.array(image), class_label))
    if seed is None:
        np.random.shuffle(data)
    else:
        np.random.default_rng(seed).shuffle(data)
    return data

In [15]:
# Root folder of each split; each contains one sub-directory per class.
data_train_dir = "../data/train"
data_test_dir = "../data/test"

# Seed the RNGs so Restart & Run All reproduces the same shuffle
# (get_shuffled_images uses NumPy's global RNG by default).
# NOTE(review): the saved train DataLoader shuffles when it is iterated,
# presumably in another notebook; seed torch there as well.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [16]:
# Read both splits from disk; each call yields a shuffled list of
# (image_array, label) pairs. The train split is loaded first.
train_data, test_data = (
    get_shuffled_images(split_dir) for split_dir in (data_train_dir, data_test_dir)
)

In [17]:
# Sanity check: print the first 15 labels to see that the shuffle mixed classes.
for _img, lbl in train_data[:15]:
    print(f"label: {lbl}")

label: 1
label: 0
label: 1
label: 0
label: 0
label: 1
label: 0
label: 1
label: 1
label: 0
label: 0
label: 1
label: 1
label: 1
label: 1


In [18]:
# ToTensor turns each image array into a float tensor; Normalize with
# mean = std = 0.5 per channel then shifts and scales every channel.
# NOTE(review): three per-channel values assume 3-channel (RGB) images —
# confirm the dataset has no grayscale/RGBA files.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [19]:
# Apply the image transform and wrap each integer label in a 0-d tensor,
# processing the train split first and then the test split.
transformed_train_data, transformed_test_data = (
    [(transform(img), torch.tensor(lbl)) for img, lbl in split]
    for split in (train_data, test_data)
)

In [20]:
# Stack per-sample tensors along a new leading dimension: images become one
# batched tensor, the 0-d label tensors become a 1-D label vector.
train_images = torch.stack([img for img, _lbl in transformed_train_data])
train_labels = torch.stack([lbl for _img, lbl in transformed_train_data])
test_images_tensor = torch.stack([img for img, _lbl in transformed_test_data])
test_labels_tensor = torch.stack([lbl for _img, lbl in transformed_test_data])

In [21]:
# train_test_split takes array-likes, so view the stacked tensors as NumPy.
train_val_images_numpy, train_val_labels_numpy = train_images.numpy(), train_labels.numpy()

# Hold out 20% of the training split for validation, stratified on the label
# so both subsets keep the same class ratio; fixed random_state for repeatability.
(
    train_images_numpy,
    val_images_numpy,
    train_labels_numpy,
    val_labels_numpy,
) = train_test_split(
    train_val_images_numpy,
    train_val_labels_numpy,
    stratify=train_val_labels_numpy,
    random_state=42,
    test_size=0.20,
)

In [22]:
# Convert the split arrays back to torch tensors (torch.tensor copies the data).
train_images_tensor, train_labels_tensor, val_images_tensor, val_labels_tensor = (
    torch.tensor(arr)
    for arr in (train_images_numpy, train_labels_numpy, val_images_numpy, val_labels_numpy)
)

In [23]:
batch_size = 64

# Pair images with labels for each split.
train_dataset = TensorDataset(train_images_tensor, train_labels_tensor)
val_dataset = TensorDataset(val_images_tensor, val_labels_tensor)
test_dataset = TensorDataset(test_images_tensor, test_labels_tensor)

# Only the training loader reshuffles every epoch; val/test iterate in order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size)

In [24]:
# Persist each DataLoader object as a whole (pickled via torch.save).
# NOTE(review): recent PyTorch defaults torch.load to weights_only=True, so
# these files need weights_only=False to load, which unpickles arbitrary
# objects — only load trusted files. Saving the underlying tensors instead
# would be more portable.
for loader, out_path in (
    (train_loader, "../data/train_loader.pt"),
    (val_loader, "../data/val_loader.pt"),
    (test_loader, "../data/test_loader.pt"),
):
    torch.save(loader, out_path)