In [1]:
import torch
from torch.utils import data
from torchvision import transforms
import os.path as osp
import glob
from PIL import Image

In [2]:
torch.manual_seed(0)

<torch._C.Generator at 0x1141315d0>

In [3]:
class ImageTransform():
    def __init__(self, resize, mean, std):
        self.data_transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]),
            'val': transforms.Compose([
                transforms.Resize(resize),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
        }

    def __call__(self, img, phase='train'):
        return self.data_transform[phase](img)

In [4]:
root_path = "./data/hymenoptera_data"
target_path = osp.join(root_path, "train")

In [5]:
def make_data_list(phase="train"):
    root_path = "./data/hymenoptera_data"
    target_path = osp.join(root_path, phase)
    
    return glob.glob(osp.join(target_path, "*/*.jpg"))

In [6]:
train_list = make_data_list("train")
val_list = make_data_list("val")

In [7]:
class MyDataset(data.Dataset):
    def __init__(self, file_list, transform=None, phase='train'):
        self.file_list = file_list
        self.transform = transform
        self.phase = phase

    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, index):
        img_path = self.file_list[index]
        img = Image.open(img_path)
        img_transformed = self.transform(img, self.phase)
        
        if img_path.split("/")[-2] == "ants":
            label = 0
        elif img_path.split("/")[-2] == "bees":
            label = 1
        
        return img_transformed, label

In [8]:
train_dataset = MyDataset(file_list=train_list, transform=ImageTransform(224, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), phase='train')
val_dataset = MyDataset(file_list=val_list, transform=ImageTransform(224, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), phase='val')

In [9]:
index = 0

print(train_dataset.__len__())

img, label = train_dataset.__getitem__(index)
print(img.size())
print(label)

243
torch.Size([3, 224, 224])
1


In [10]:
batch_size = 32

train_dataloader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}

In [12]:
batch_iterator = iter(dataloaders_dict["train"])
inputs, labels = next(batch_iterator)
print(inputs.shape)
print(labels)

torch.Size([32, 3, 224, 224])
tensor([1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1,
        1, 1, 0, 0, 0, 0, 0, 1])
