In [8]:
import pandas as pd
import os
from os.path import abspath, join, dirname, normpath
import sys
from skimage import io
import skimage
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
# Notebooks have no `__file__`, so relative data paths are resolved against
# the kernel's current working directory.
local_path = normpath(abspath(os.getcwd()))

Convert the breed names into integer class labels (one index per breed).

In [16]:
def makelabelset(csv_path):
    """Map each breed in a labels CSV to an integer class index.

    Indices are assigned in order of first appearance in the ``breed`` column,
    so the mapping depends on the row order of the CSV.

    Args:
        csv_path (str): Path to a CSV file that has a ``breed`` column.

    Returns:
        tuple: ``(labelset, breeds)`` where ``breeds`` is the list of unique
        breed names and ``labelset`` maps each name to its index in ``breeds``.
    """
    landmarks_frame = pd.read_csv(csv_path, sep=',')
    # dict.fromkeys de-duplicates while keeping first-seen order, in O(n)
    # rather than the O(n * n_breeds) cost of a list-membership scan.
    breeds = list(dict.fromkeys(landmarks_frame.breed))
    labelset = {name: idx for idx, name in enumerate(breeds)}
    return labelset, breeds

In [17]:
class DogbreedDataset(Dataset):
    """Load in Dog Breed Dataset.

    Each sample is a dict ``{'image': ndarray, 'landmarks': int}``. Despite its
    name, ``'landmarks'`` holds the breed's integer class index. The key name is
    kept because the transforms in this notebook read it.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations, joined
                onto ``local_path`` (``os.path.join`` keeps absolute paths as-is).
            root_dir (string): Directory with all the images, resolved the
                same way.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        csv_file = join(local_path, csv_file)
        root_dir = join(local_path, root_dir)
        self.landmarks_frame = pd.read_csv(csv_file, sep=',')
        self.root_dir = root_dir
        self.transform = transform
        # Breed name -> class index, in order of first appearance in the CSV.
        self.labelset, _ = makelabelset(csv_file)

    def __len__(self):
        """Number of rows in the annotations CSV."""
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``, after ``transform`` if one is set."""
        # Callers may index with a 0-d tensor; iloc needs a plain Python int.
        if torch.is_tensor(idx):
            idx = idx.item()
        row = self.landmarks_frame.iloc[idx]
        # Column 0 is the image id (file stem) and column 1 is the breed name.
        # str() guards against ids that pandas parsed as numbers.
        img_name = os.path.join(self.root_dir, str(row.iloc[0]) + '.jpg')
        image = io.imread(img_name)
        landmarks = self.labelset[row.iloc[1]]
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

In [18]:
# NOTE(review): this dataset has no transform, so images keep their native,
# differing sizes. Batching them (batch_size=4) will presumably fail in the
# default collate step. TODO confirm, and consider building the loader from the
# transformed dataset defined further down instead.
dogdataset = DogbreedDataset(csv_file='all/labels.csv', root_dir='all/train')
dogdataloader = DataLoader(
    dogdataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

In [19]:
# Integer class index assigned to the breed of sample 100.
dogdataset[100]['landmarks']

44

In [24]:
from skimage import transform
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Resize ``sample['image']`` and pass ``sample['landmarks']`` through unchanged.

        Returns a new dict with the same ``'image'`` / ``'landmarks'`` keys so
        the next transform in a ``Compose`` pipeline can read them.
        """
        image, landmarks = sample['image'], sample['landmarks']
        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            # Shorter edge -> output_size; longer edge scaled to keep aspect ratio.
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # The key must be spelled 'landmarks': ToTensor reads sample['landmarks'].
        return {'image': img, 'landmarks': landmarks}

One-hot encode the dog breed labels and convert the images to tensors.

In [25]:
class ToTensor(object):
    """Convert ndarrays in a sample to Tensors and one-hot encode the label.

    Args:
        num_classes (int, optional): Length of the one-hot label vector.
            Defaults to 120, the value that was previously hard-coded.
    """

    def __init__(self, num_classes=120):
        self.num_classes = num_classes

    def __call__(self, sample):
        """Return ``{'image': (C, H, W) tensor, 'landmarks': (1, num_classes) one-hot}``.

        ``sample['landmarks']`` must be an integer class index in
        ``[0, num_classes)``.
        """
        image, landmarks = sample['image'], sample['landmarks']
        # A 2-D (grayscale) image has no channel axis. Add one so the
        # HWC -> CHW transpose below also works for it.
        if image.ndim == 2:
            image = image[:, :, np.newaxis]
        # The image array is treated as H x W x C; torch expects C x H x W.
        image = image.transpose((2, 0, 1))
        # One-hot: a single (1, num_classes) row with a 1 at the class index.
        label = torch.LongTensor([[landmarks]])
        landmarks = torch.zeros(1, self.num_classes).scatter_(1, label, 1)
        return {'image': torch.from_numpy(image),
                'landmarks': landmarks}

In [26]:
from torchvision import transforms
# Resize so the shorter edge is 224 px, then convert to tensors with a one-hot
# label. NOTE(review): an int Rescale only fixes the shorter edge, so image
# sizes still differ. Batching likely needs a crop to a fixed size as well; TODO confirm.
# Rebinding `dogdataset` does not change `dogdataloader`, which still wraps the
# untransformed dataset object it was built with.
dogdataset = DogbreedDataset(
    'all/labels.csv',
    'all/train',
    transforms.Compose([Rescale(224), ToTensor()]),
)