# Data Loader
The purpose of this notebook is to build a pipeline to efficiently load the train and test images for later use.
We utilize the PyTorch capabilities where ever possible.

In [62]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from torchvision.io import read_image
import sys

sys.path.insert(0, '../helpers/')
from helpers import helpers

CONFIG = helpers.get_config()
torch.manual_seed(CONFIG['RANDOM_STATE'])

<torch._C.Generator at 0x1deb642a8f0>

In [63]:
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [64]:
training_data = CustomImageDataset(annotations_file=CONFIG['DATA_DIR_TRAINSET'], img_dir=CONFIG['DATA_DIR_FLATTENED'])
test_data = CustomImageDataset(annotations_file=CONFIG['DATA_DIR_TESTSET'], img_dir=CONFIG['DATA_DIR_FLATTENED'])

In [65]:
train_dataloader = DataLoader(training_data, batch_size=CONFIG['BATCH_SIZE'], shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=CONFIG['BATCH_SIZE'], shuffle=True)

In [66]:
# Example image.
train_features, train_labels = next(iter(train_dataloader))
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
print(f"Label: {label}")

RuntimeError: stack expects each tensor to be equal size, but got [1, 1024, 1024] at entry 0 and [4, 1024, 1024] at entry 5

In [16]:
for i in range(len(training_data)):
    train_features, _ = next(iter(train_dataloader))
    print(train_features.shape)
    break

torch.Size([8, 1, 1024, 1024])
