In [1]:
import os
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class DeepFashion(Dataset):
    '''
    DeepFashion attribute dataset.

    The annotation list is parsed once at construction time; images are
    opened on demand in __getitem__.

    Annotation file layout, as parsed by __init__:
        line 0   : declared number of entries (an integer)
        line 1   : skipped, never parsed
        line 2.. : "<image path relative to data_dir> <int> <int> ..."
    Blank lines after the two header lines are ignored.

    data_dir : root directory; image paths and ann_dir are joined onto it
    ann_dir : path of the annotation list file, relative to data_dir
    transform : optional callable applied to each PIL image, should be from
        torchvision.transforms

    Attribute values of -1 are stored as 0.
    '''
    def __init__(self, data_dir, ann_dir, transform=None):
        # Local import: only needed for the count-mismatch warning below.
        import warnings

        self.data_dir = data_dir
        self.imdir = []  # image paths, already joined onto data_dir
        self.ann = []    # one integer numpy array of attributes per image
        self.anndir = os.path.join(self.data_dir, ann_dir)
        self.transform = transform

        declared = None
        with open(self.anndir, 'r') as ann_file:
            for lineno, line in enumerate(ann_file):
                if lineno == 0:
                    declared = int(line.strip())
                elif lineno > 1:
                    entry = line.strip()
                    if not entry:
                        # Tolerate blank/trailing lines instead of failing
                        # on the tuple unpack below.
                        continue
                    rel_path, raw_labels = entry.split(maxsplit=1)
                    self.imdir.append(os.path.join(self.data_dir, rel_path))
                    labels = np.array([int(v) for v in raw_labels.split()])
                    labels[labels == -1] = 0
                    self.ann.append(labels)

        # The length reflects what was actually parsed, so every index in
        # range(len(self)) is valid even if the header count is wrong.
        self.len = len(self.imdir)
        if declared is not None and declared != self.len:
            warnings.warn(
                'annotation file declares %d entries but %d were parsed'
                % (declared, self.len))
        print('completed dataset loading')

    def __len__(self):
        '''Number of parsed annotation entries.'''
        return self.len

    def __getitem__(self, idx):
        '''
        Return {'image': ..., 'attributes': ...} for entry idx.

        The image is decoded as RGB so every sample has three channels,
        then passed through transform when one was given.
        '''
        # Image.open is lazy; convert() forces the decode while the file is
        # still open, and the context manager then releases the handle.
        with Image.open(self.imdir[idx]) as raw:
            img = raw.convert('RGB')
        ann = self.ann[idx]

        if self.transform is not None:
            img = self.transform(img)

        return {'image': img, 'attributes': ann}

In [2]:
# Preprocessing applied to every image: resize to 250x250, then convert to a tensor.
preprocess_steps = [
    transforms.Resize((250, 250)),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocess_steps)

# Build the dataset from the local DeepFashion copy and its attribute list.
dataset = DeepFashion(
    data_dir='../../../local/DeepFashion/',
    ann_dir='ann/list_attr_img.txt',
    transform=transform,
)

completed dataset loading


In [3]:
import matplotlib.pyplot as plt
# Preview the first ten samples, one figure per image.
for i in range(10):
    sample = dataset[i]
    # The image tensor is channel-first; imshow expects channels last.
    img = np.transpose(sample['image'].numpy(), (1, 2, 0))
    plt.figure()
    plt.imshow(img)

In [5]:
def collate_fn(data):
    """
    Merge a list of samples into a single batch of numpy arrays.

    data : list of dicts, each with an 'image' and an 'attributes' entry
    returns : dict with the same two keys; the images are stacked along a
        new leading axis and the attributes are gathered into one array,
        e.g. (10, 3, 250, 250) and (10, 1000) for the batch printed in
        this notebook.
    """
    images = []
    labels = []
    for sample in data:
        images.append(sample['image'])
        labels.append(sample['attributes'])
    return {'image': np.stack(images, 0), 'attributes': np.array(labels)}

In [6]:
# Smoke test: draw one shuffled batch of 10 and print the batched shapes.
dataloader = DataLoader(
    dataset,
    batch_size=10,
    shuffle=True,
    collate_fn=collate_fn,
)
for i, data in enumerate(dataloader):
    print(data['image'].shape, data['attributes'].shape)
    # Only the first batch is needed, so stop immediately.
    break

(10, 3, 250, 250) (10, 1000)
