In [78]:
import os
import sys
import random
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset
from torchvision.transforms import functional as TF
from dataclasses import dataclass
from __future__ import annotations
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

In [76]:
@dataclass
class ImageDatasetConfig:
    foreground_dir: str = "shoe_dataset/"
    background_dir: str = "shoe_dataset/bg/"
    mode: str = "train"
    image_size: int = 256
    augment: bool = True
    augment_prob: float = 0.5

img_ds_config = ImageDatasetConfig()

In [84]:
class ImageDataset(Dataset):
    def __init__(self, image_ds_config) -> None:
        super().__init__()

        self.foreground_directory = image_ds_config.foreground_dir
        self.background_directory = image_ds_config.background_dir
        self.mode = image_ds_config.mode
        self.image_size = image_ds_config.image_size
        self.augment = image_ds_config.augment
        self.augment_prob = image_ds_config.augment_prob
        self.rotation_degree = [0, 90, 180, 270]

        self.train_images = list(map(lambda x: f"{self.foreground_directory}train/{x}", os.listdir(f"{self.foreground_directory}/train")))
        self.val_images = list(map(lambda x: f"{self.foreground_directory}val/{x}", os.listdir(f"{self.foreground_directory}/val")))
        self.test_images = list(map(lambda x: f"{self.foreground_directory}test/{x}", os.listdir(f"{self.foreground_directory}/test")))
        self.train_background_images = list(map(lambda x: f"{self.background_directory}train/{x}", os.listdir(f"{self.background_directory}/train")))
        self.val_background_images = list(map(lambda x: f"{self.background_directory}val/{x}", os.listdir(f"{self.background_directory}/val")))

    def __getitem__(self, index):
        if self.mode == "train":
            img_path = self.train_images[index]
        elif self.mode == "val":
            img_path = self.val_images[index]
        else:
            img_path = self.test_images[index]

        print(img_path)
        return self.transform_image(img_path, self.augment)
    
    def __len__(self):
        if self.mode == "train":
            return len(self.train_images)
        elif self.mode == "val":
            return len(self.val_images)
        else:
            return len(self.test_images)
    
    def transform_image(self, img_path: str, augment: bool):
        image_alpha = Image.open(img_path)
        assert str(image_alpha.mode) == 'RGBA'
        x, y = image_alpha.size
        aspect_ratio = y / x
        ch_r, ch_g, ch_b, ch_a = image_alpha.split()
        img = Image.merge('RGB', (ch_r, ch_g, ch_b))
        mask = ch_a
        
        if self.mode == "train":
            bg = Image.open(self.train_background_images[random.randint(0, len(self.train_background_images)-1)])
            bg = bg.resize(img.size)
            bg.paste(img, mask=mask)
        else:
            bg = Image.open(self.val_background_images[random.randint(0, len(self.val_background_images)-1)])
            bg = bg.resize(img.size)
            bg.paste(img, mask=mask)

        img = bg
            
        if augment and random.random() < self.augment_prob:
            transform = list()
            resize_range = random.randint(300, 320)
            transform.append(T.Resize((int(resize_range * aspect_ratio), resize_range)))
            rot_deg = self.rotation_degree[random.randint(0, 3)]
            if rot_deg == 90 or rot_deg == 270:
                aspect_ratio = 1 / aspect_ratio
            transform.append(T.RandomRotation((rot_deg, rot_deg)))
            rot_range = random.randint(-10, 10)
            transform.append(T.RandomRotation((rot_range, rot_range)))
            crop_range = random.randint(270, 300)
            transform.append(T.CenterCrop((int(crop_range * aspect_ratio), crop_range)))
            transform = T.Compose(transform)

            img = transform(img)
            mask = transform(mask)

            transform = T.ColorJitter(brightness=0.2, contrast=0.2, hue=0.2)

            img = transform(img)

            if random.random() < 0.5:
                img = TF.hflip(img)
                mask = TF.hflip(mask)
            
            if random.random() < 0.5:
                img = TF.vflip(img)
                mask = TF.vflip(mask)
            
        transform = list()
        transform.append(T.Resize((self.image_size, self.image_size)))
        transform.append(T.ToTensor())
        transform = T.Compose(transform)

        img = transform(img)
        mask = transform(mask)

        return img, mask

In [85]:
train_dataset = ImageDataset(img_ds_config)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

In [86]:
for img, mask in train_dataloader:
    print(img.shape)
    print(mask.shape)

shoe_dataset/train/480.png
shoe_dataset/train/412.png
torch.Size([2, 3, 256, 256])
torch.Size([2, 1, 256, 256])
shoe_dataset/train/88.png
shoe_dataset/train/221.png
torch.Size([2, 3, 256, 256])
torch.Size([2, 1, 256, 256])
shoe_dataset/train/603.png
shoe_dataset/train/134.png
torch.Size([2, 3, 256, 256])
torch.Size([2, 1, 256, 256])
shoe_dataset/train/213.png
shoe_dataset/train/288.png
torch.Size([2, 3, 256, 256])
torch.Size([2, 1, 256, 256])
shoe_dataset/train/598.png
shoe_dataset/train/132.png
torch.Size([2, 3, 256, 256])
torch.Size([2, 1, 256, 256])
shoe_dataset/train/612.png
shoe_dataset/train/619.png
torch.Size([2, 3, 256, 256])
torch.Size([2, 1, 256, 256])
shoe_dataset/train/311.png
shoe_dataset/train/54.png
torch.Size([2, 3, 256, 256])
torch.Size([2, 1, 256, 256])
shoe_dataset/train/259.png
shoe_dataset/train/491.png
torch.Size([2, 3, 256, 256])
torch.Size([2, 1, 256, 256])
shoe_dataset/train/236.png
shoe_dataset/train/116.png
torch.Size([2, 3, 256, 256])
torch.Size([2, 1, 256, 

OSError: image file is truncated (0 bytes not processed)