In [1]:
from IPython.core.display import display, HTML
display(HTML(
    '<style>'
        '#notebook { padding-top:0px !important; } ' 
        '.container { width:100% !important; } '
        '.end_space { min-height:0px !important; } '
    '</style>'
))

In [2]:
import os
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms


In [3]:
X_train_path = os.path.join('FashionDataset', 'split', 'train.txt')
X_train_file = open(X_train_path).read().split('\n')

X_val_path = os.path.join('FashionDataset', 'split', 'val.txt')
X_val_file = open(X_val_path).read().split('\n')

In [4]:
class FashionDataset(Dataset):
    def __init__(self, img_path='FashionDataset/', 
                 split_path='FashionDataset/split/', 
                 transform=None, flag=None):

        super().__init__()
        
        self.data = []
        self.labels = []
        self.transform = transform
        
        if flag == 'train':
            X_path = os.path.join(split_path, 'train.txt')
            X_files = open(X_path).read().split('\n')[:32]
            y_path = os.path.join(split_path, 'train_attr.txt')
            y_files = open(y_path).read().split('\n')
            
        if flag == 'val':
            X_path = os.path.join(split_path, 'val.txt')
            X_files = open(X_path).read().split('\n')
            y_path = os.path.join(split_path, 'val_attr.txt')
            y_files = open(y_path).read().split('\n')
            
        for i in range(len(X_files)):
            # images path
            self.data.append(os.path.join(img_path, X_files[i]))

            # labels
            tmp_labels = y_files[i].split(' ')
            self.labels.append({
                'cat1': int(tmp_labels[0]),
                'cat2': int(tmp_labels[1]),
                'cat3': int(tmp_labels[2]),
                'cat4': int(tmp_labels[3]),
                'cat5': int(tmp_labels[4]),
                'cat6': int(tmp_labels[5])
            })
            
    def __getitem__(self, idx):
        # read image
        img_path = self.data[idx]
        img = Image.open(img_path)
        
        # check if transform
        if self.transform:
            img = self.transform(img)
            
        opt_data = {
            'img': img,
            'labels': self.labels[idx]
        }
        return opt_data
    
    def __len__(self):
        return len(self.data)

In [5]:
# train data
flag = 'train'
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_data = FashionDataset(transform = train_transform, flag=flag)

In [6]:
train_data.__getitem__(0)

{'img': tensor([[[1.7694, 1.7694, 1.7694,  ..., 1.6495, 1.6495, 1.6495],
          [1.7694, 1.7694, 1.7694,  ..., 1.6495, 1.6495, 1.6495],
          [1.7694, 1.7694, 1.7694,  ..., 1.6495, 1.6495, 1.6495],
          ...,
          [1.4783, 1.4783, 1.4783,  ..., 1.4269, 1.4269, 1.4269],
          [1.4783, 1.4783, 1.4783,  ..., 1.4269, 1.4269, 1.4098],
          [1.4783, 1.4783, 1.4783,  ..., 1.4269, 1.4269, 1.4098]],
 
         [[1.9734, 1.9734, 1.9734,  ..., 1.8333, 1.8333, 1.8333],
          [1.9734, 1.9734, 1.9734,  ..., 1.8333, 1.8333, 1.8333],
          [1.9734, 1.9734, 1.9734,  ..., 1.8333, 1.8333, 1.8333],
          ...,
          [1.6758, 1.6758, 1.6758,  ..., 1.5882, 1.5882, 1.5882],
          [1.6758, 1.6758, 1.6758,  ..., 1.5882, 1.5882, 1.5707],
          [1.6758, 1.6758, 1.6758,  ..., 1.5882, 1.5882, 1.5707]],
 
         [[2.1694, 2.1694, 2.1694,  ..., 2.0823, 2.0823, 2.0823],
          [2.1694, 2.1694, 2.1694,  ..., 2.0823, 2.0823, 2.0823],
          [2.1694, 2.1694, 2.1694

In [7]:
class MultiLabelModel(nn.Module):
    def __init__(self, n_cats):
        super().__init__()
        # pretrained resnet50 as base model
        self.resnet50 = models.resnet50(pretrained=True)
        
        # size of last channel before classifier
        last_channel = models.resnet50().fc.out_features
        
        

In [8]:
train_data.__getitem__(0)['img'].size()

torch.Size([3, 224, 224])

In [9]:
test_model = models.resnet50(pretrained=True)
train_dataloader = DataLoader(train_data, batch_size=2, shuffle=True, num_workers=8)

In [10]:
train_dataloader_iterator = iter(train_dataloader)

In [11]:
data, target = next(train_dataloader_iterator)

KeyboardInterrupt: 

In [12]:
a = {1: 1, 2: 2, 3:3}

In [14]:
sum(a.values())

6