In [None]:
import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as transforms
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.utils.data.dataset import random_split
from torchvision.transforms.transforms import ToPILImage
from skimage import io

import numpy as np
import matplotlib.pyplot as plt

import pandas as pd
import os
import random

In [None]:
# Build the index of sample file names. Column 0 of `df` holds a bare file name;
# the dataset below joins it onto both the undersampled (input) and the full
# (target) image directory, so the same name must exist in both.
source_dir = 'D:\\heart_data\\undersampled_heart_images'
# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting makes
# the row order (and therefore which file sits at which dataset index) reproducible.
list_dir = sorted(os.listdir(source_dir))
list_pd = list(list_dir)
df = pd.DataFrame(list_pd)
assert len(df) > 0, f'no files found in {source_dir}'
df.head()

In [None]:
class CustomDataTransform(Dataset):
    def __init__(self, df, features_transform=None, label_transform=None):
        self.df = df
        self.features_transform = features_transform
        self.label_transform = label_transform
        self.root_dir_x = 'D:\\heart_data\\undersampled_heart_images'
        self.root_dir_y = 'D:\\heart_data\\heart_images'

    def __len__(self):
        return len(self.df)
        
    def __getitem__(self,index):
        img_path_x = os.path.join(self.root_dir_x, self.df.iloc[index, 0])
        img_path_y = os.path.join(self.root_dir_y, self.df.iloc[index, 0])
        image_x = io.imread(img_path_x)
        image_y = io.imread(img_path_y)
        
        if self.features_transform is not None:
            image_x = self.features_transform(image_x)

        if self.label_transform is not None:
            image_y = self.label_transform(image_y)

        return (image_x, image_y)

# Preprocessing / augmentation shared by inputs (undersampled) and targets:
# convert to PIL, resize, random horizontal/vertical flips, ToTensor (which scales
# uint8 images to [0, 1]), then normalize as (v - mean) / std.
# NOTE: the flips are random; an input/target pair stays spatially aligned only if
# both pipelines make the same random decisions for that pair.
# NOTE(review): mean=0.5 / std=0.1 look hand-picked rather than dataset statistics — confirm.
IMAGE_SIZE = (128, 128)
NORM_MEAN = [0.5]
NORM_STD = [0.1]
SEED = 42


def make_paired_transform():
    """Build one instance of the preprocessing/augmentation pipeline."""
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ])


x_transform = make_paired_transform()
y_transform = make_paired_transform()

dataset = CustomDataTransform(df, features_transform=x_transform,
                                  label_transform=y_transform)

batch_size = 32
part = 0.8  # fraction of samples used for training
train_lenght = int(len(dataset)*part)
test_lenght = int(len(dataset) - train_lenght)

# A seeded generator gives the same train/test partition on every Restart & Run All.
split_generator = torch.Generator().manual_seed(SEED)
train_set, test_set = random_split(dataset, [train_lenght, test_lenght],
                                   generator=split_generator)
train_loader = DataLoader(train_set, batch_size=batch_size, drop_last=False, shuffle=True)
# Evaluation order need not be random; a fixed order keeps test batches repeatable.
test_loader = DataLoader(test_set, batch_size=batch_size, drop_last=False, shuffle=False)

In [None]:
print(f'Length of dataset is {len(dataset)}')

# Fetch the sample once: every indexing call re-samples the random augmentations,
# so input and target must come from the same call to be comparable.
sample_x, sample_y = train_set[0]

fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(8, 4))
ax_in.imshow(sample_x[0], cmap='gray')
ax_in.set_title('Input (undersampled), channel 0')
ax_out.imshow(sample_y[0], cmap='gray')
ax_out.set_title('Target, channel 0')
for ax in (ax_in, ax_out):
    ax.axis('off')
fig.suptitle('train_set[0] after preprocessing')
plt.show()