The VU cluster has 32 CPU cores, so I hypothesize that 32 workers can be used for loading data with the PyTorch `DataLoader` class.

This file tests that hypothesis, using data applicable to the research question and a version of the `Dataset` class together with some data augmentations.

In [5]:
from pathlib import Path
from typing import Any
import os

import pandas as pd
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import (
    Compose,
    Resize,
    ToTensor,
    Grayscale,
    AugMix,
    RandAugment,
    FiveCrop,
    RandomCrop,
    AutoAugmentPolicy,
    AutoAugment,
)


# from load_data import load_image_data, load_label_data


class THGStrainStressDataset(Dataset[Any]):
    """Images of one numbered sample folder that all share one target vector.

    Images are read from ``<root_data_dir>/<folder>/<idx>.<extension>`` and
    are assumed to be numbered contiguously from ``0`` to ``len(self) - 1``.

    Args:
        root_data_dir: Directory that contains the numbered sample folders.
        folder: Number of the sample folder; also stored as ``group``.
        targets: Target values returned alongside every image of the folder.
        extension: Image file extension, without the leading dot.
        data_transform: Optional callable applied to each opened image.
        target_transform: Optional callable applied to ``targets`` each time
            an item is fetched.
    """

    def __init__(
        self,
        root_data_dir: str,
        folder: int,
        targets: np.ndarray,
        extension: str = "bmp",
        data_transform=None,
        target_transform=None,
    ):
        self._data = Path(root_data_dir) / str(folder)
        self.group = folder
        self.targets = targets
        self.extension = extension
        self.transform = data_transform
        self.target_transform = target_transform
        # Count only files with the image extension, so an unrelated file in
        # the folder does not make len() exceed the highest readable index.
        suffix = f".{extension}"
        self._length = sum(
            1 for name in os.listdir(self._data) if name.endswith(suffix)
        )

    def __len__(self):
        """Return the number of image files found in the sample folder."""
        return self._length

    def _transformed_targets(self):
        """Return the targets with ``target_transform`` applied, if one is set.

        ``self.targets`` itself is never reassigned, so repeated calls always
        start from the original values instead of re-transforming the result
        of the previous call.
        """
        if self.target_transform is None:
            return self.targets
        return self.target_transform(self.targets)

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image, targets)`` for the image file named ``idx``.

        The concrete return types depend on the configured transforms: without
        a ``data_transform`` the image is returned as opened by PIL, and
        without a ``target_transform`` the targets are returned as stored.
        """
        # TODO: Make it work with z-stacks!!!
        # https://stackoverflow.com/a/60176057
        # Images are assumed to be named 0..n-1, so idx maps directly to a file.
        image = Image.open(self._data / f"{idx}.{self.extension}")

        if self.transform is not None:
            image = self.transform(image)

        return image, self._transformed_targets()


# def create_dataloader(
#     batch_size: int,
#     data_path: Path,
#     label_path: Path,
#     shuffle: bool = True,
# ) -> DataLoader[Any]:
#     return DataLoader(
#         dataset=THGStrainStressDataset(
#             data_dir=data_path,
#             label_path=label_path,
#             data_transform=Compose(
#                 [
#                     RandomCrop(size=(258, 258)),
#                     # Resize((258, 258)),
#                     Grayscale(),
#                     # AugMix(),
#                     # RandAugment(num_ops=2),
#                     ToTensor(),
#                     # Lambda(lambda y: (y - y.mean()) / y.std()), # To normalize the image.
#                 ],
#             ),
#         ),
#         batch_size=batch_size,
#         shuffle=shuffle,
#     )


In [6]:
import importlib
import pandas as pd
import numpy as np
from pathlib import Path
from torchvision.transforms import Compose, Resize, Grayscale, ToTensor
from torch.utils.data import ConcatDataset

# Image preprocessing steps, applied in order. The commented-out entries are
# alternatives that are currently disabled.
preprocessing_steps = [
    # RandomCrop(size=(258, 258)),
    Resize((258, 258)),
    Grayscale(),
    # AugMix(),
    # RandAugment(num_ops=2),
    ToTensor(),
    # Lambda(lambda y: (y - y.mean()) / y.std()), # To normalize the image.
]
data_transform = Compose(preprocessing_steps)

# Build one dataset per row of the label file whose folder exists on disk,
# and record the folder number once per image so `groups` lines up with the
# items of the concatenated dataset.
datasets = []
groups = []
label_table = pd.read_csv('../data/z-stacks/sigmoid_labels.csv')
for _, labels in label_table.iterrows():
    folder = int(labels["index"])
    targets = labels[["A", "h", "slope", "C"]].to_numpy(dtype=float)

    folder_path = Path('../data/z-stacks/') / str(folder)
    if folder_path.is_dir():
        folder_dataset = THGStrainStressDataset(
            root_data_dir='../data/z-stacks/',
            folder=folder,
            targets=targets,
            data_transform=data_transform,
        )
        datasets.append(folder_dataset)
        groups += [folder] * len(folder_dataset)

groups = np.array(groups)
dataset = ConcatDataset(datasets)

In [10]:
from random import shuffle
import time
from tqdm import tqdm

# Benchmark: for each worker count, time four full passes over the dataset.
# The batches themselves are discarded; only the loading time matters.
pin_memory = True
print('pin_memory is', pin_memory)

for num_workers in range(20, 40, 1):
    train_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=30,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    # perf_counter() is monotonic and high-resolution, so the measured
    # interval cannot be distorted by system clock adjustments the way
    # time.time() can.
    # NOTE(review): the interval presumably also includes worker start-up for
    # every pass, since no persistent workers are requested — confirm whether
    # that is what should be measured.
    start = time.perf_counter()
    for epoch in tqdm(range(1, 5)):
        for _batch in train_loader:
            pass
    end = time.perf_counter()
    print("Finish with:{} second, num_workers={}".format(end - start, num_workers))

pin_memory is True


100%|██████████| 4/4 [00:11<00:00,  2.92s/it]


Finish with:11.696817874908447 second, num_workers=20


100%|██████████| 4/4 [00:10<00:00,  2.70s/it]


Finish with:10.787748098373413 second, num_workers=21


100%|██████████| 4/4 [00:10<00:00,  2.69s/it]


Finish with:10.748571634292603 second, num_workers=22


100%|██████████| 4/4 [00:10<00:00,  2.74s/it]


Finish with:10.972535371780396 second, num_workers=23


100%|██████████| 4/4 [00:10<00:00,  2.75s/it]


Finish with:10.998663425445557 second, num_workers=24


100%|██████████| 4/4 [00:11<00:00,  2.75s/it]


Finish with:11.005809783935547 second, num_workers=25


100%|██████████| 4/4 [00:11<00:00,  2.79s/it]


Finish with:11.182546138763428 second, num_workers=26


100%|██████████| 4/4 [00:11<00:00,  2.80s/it]


Finish with:11.199867010116577 second, num_workers=27


100%|██████████| 4/4 [00:11<00:00,  2.89s/it]


Finish with:11.575569152832031 second, num_workers=28


100%|██████████| 4/4 [00:11<00:00,  2.90s/it]


Finish with:11.587659358978271 second, num_workers=29


100%|██████████| 4/4 [00:11<00:00,  2.91s/it]


Finish with:11.645626068115234 second, num_workers=30


100%|██████████| 4/4 [00:10<00:00,  2.71s/it]


Finish with:10.828525304794312 second, num_workers=31


100%|██████████| 4/4 [00:10<00:00,  2.68s/it]


Finish with:10.72502064704895 second, num_workers=32


100%|██████████| 4/4 [00:10<00:00,  2.72s/it]


Finish with:10.866923093795776 second, num_workers=33


100%|██████████| 4/4 [00:16<00:00,  4.01s/it]


Finish with:16.059099912643433 second, num_workers=34


100%|██████████| 4/4 [00:20<00:00,  5.20s/it]


Finish with:20.79282021522522 second, num_workers=35


100%|██████████| 4/4 [00:11<00:00,  2.80s/it]


Finish with:11.189171552658081 second, num_workers=36


100%|██████████| 4/4 [00:11<00:00,  2.80s/it]


Finish with:11.191312074661255 second, num_workers=37


 25%|██▌       | 1/4 [00:04<00:12,  4.16s/it]


KeyboardInterrupt: 

pin_memory is True


100%|██████████| 4/4 [00:23<00:00,  5.78s/it]

Finish with:23.138214349746704 second, num_workers=5

100%|██████████| 4/4 [00:20<00:00,  5.17s/it]

Finish with:20.68344020843506 second, num_workers=6

100%|██████████| 4/4 [00:18<00:00,  4.51s/it]

Finish with:18.05572271347046 second, num_workers=7

100%|██████████| 4/4 [00:16<00:00,  4.09s/it]

Finish with:16.365644693374634 second, num_workers=8

100%|██████████| 4/4 [00:15<00:00,  3.76s/it]

Finish with:15.046205043792725 second, num_workers=9

100%|██████████| 4/4 [00:14<00:00,  3.69s/it]

Finish with:14.744060516357422 second, num_workers=10

100%|██████████| 4/4 [00:13<00:00,  3.44s/it]

Finish with:13.76044750213623 second, num_workers=11

100%|██████████| 4/4 [00:14<00:00,  3.51s/it]

Finish with:14.033259630203247 second, num_workers=12

100%|██████████| 4/4 [00:12<00:00,  3.15s/it]

Finish with:12.621521472930908 second, num_workers=13

100%|██████████| 4/4 [00:12<00:00,  3.22s/it]

Finish with:12.886996269226074 second, num_workers=14

100%|██████████| 4/4 [00:12<00:00,  3.17s/it]

Finish with:12.667302370071411 second, num_workers=15

100%|██████████| 4/4 [00:11<00:00,  2.96s/it]

Finish with:11.828012943267822 second, num_workers=16

100%|██████████| 4/4 [00:11<00:00,  2.95s/it]

Finish with:11.805269718170166 second, num_workers=17

100%|██████████| 4/4 [00:11<00:00,  2.91s/it]

Finish with:11.644353151321411 second, num_workers=18

100%|██████████| 4/4 [00:11<00:00,  2.85s/it]Finish with:11.383978605270386 second, num_workers=19