In [22]:
from __future__ import print_function
import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.datasets as datasets

In [37]:
class TripletLandmarkDataset(Dataset):

    def __init__(self, root_dir, csv_name, num_of_triplets, transform = None):
        
        self.root_dir          = root_dir
        self.df                = pd.read_csv(csv_name)
        self.num_of_triplets   = num_of_triplets
        self.transform         = transform
        self.training_triplets = self.generate_triplets(self.df, self.num_of_triplets)
        
    @staticmethod
    def generate_triplets(df, num_of_triplets):
        
        def create_landmarks(df):
            landmarks = dict()
            for idx, label in enumerate(df['landmark_id']):
                if label not in landmarks:
                    landmarks[label] = []
                landmarks[label].append(df.iloc[idx, 0])
            return landmarks

        triplets = []
        classes = df['landmark_id'].unique()
        num_of_classes = df['landmark_id'].nunique()
        landmarks = create_landmarks(df) # = {'label0': [id, ...], 'label1': [id, ...], ...}
        
        for idummy in range(num_of_triplets):
            
            '''
              - randomly choose anchor, positive and negative images for triplet loss
              - icls0 for anchor and positive images
              - icls1 for negative image
              - at least, two images needed for anchor and positive images in icls0
              - negative image should have different label as anchor and positive images by definition
            '''
        
            #icls0 = np.random.randint(0, num_of_classes)
            #icls1 = np.random.randint(0, num_of_classes)
            icls0 = np.random.choice(classes)
            icls1 = np.random.choice(classes)
            while len(landmarks[icls0]) < 2:
                #icls0 = np.random.randint(0, num_of_classes)
                icls0 = np.random.choice(classes)
            while icls0 == icls1:
                #icls1 = np.random.randint(0, num_of_classes)
                icls1 = np.random.choice(classes)
            
            if len(landmarks[icls0]) == 2:
                ianc, ipos = np.random.choice(2, size = 2, replace = False)
            else:
                ianc = np.random.randint(0, len(landmarks[icls0]))
                ipos = np.random.randint(0, len(landmarks[icls0]))
                while ianc == ipos:
                    ipos = np.random.randint(0, len(landmarks[icls0]))
            ineg = np.random.randint(0, len(landmarks[icls1]))
            
            triplets.append(landmarks[icls0][ianc], landmarks[icls0][ipos], landmarks[icls1][ineg], icls0, icls1)
        
        return triplets
        
        
    def __getitem__(self, idx):
        
        id_a, id_p, id_n, cls0, cls1 = self.training_triplets[idx]
        
        img_a = os.path.join(self.root_dir, 'train', str(cls0), str(id_a + '.jpg'))
        img_p = os.path.join(self.root_dir, 'train', str(cls0), str(id_p + '.jpg'))
        img_n = os.path.join(self.root_dir, 'train', str(cls1), str(id_n + '.jpg'))
        
        img_a = io.imread(img_a)
        img_p = io.imread(img_p)
        img_n = io.imread(img_n)

        cls0    = torch.from_numpy(np.array([cls0]).reshape(-1, 1))
        cls1    = torch.from_numpy(np.array([cls1]).reshape(-1, 1))
        
        sample = {'img_a': img_a, 'img_p': img_p, 'img_n': img_n, 'cls0': cls0, 'cls1': cls1}

        if self.transform:
            sample['img_a'] = self.transform(sample['img_a'])
            sample['img_p'] = self.transform(sample['img_p'])
            sample['img_n'] = self.transform(sample['img_n'])
        
        return sample
    
    
    def __len__(self):
        
        return len(self.training_triplets)

In [38]:
data_transforms = {
    'train': transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
    ])
}

landmarks_dataset = {
    'train' : TripletLandmarkDataset(root_dir        = '/mnt/sw/workspace/Google',
                                     csv_name        = './top5_landmarks_shrunk_train.csv',
                                     num_of_triplets = 1000,
                                     transform       = data_transforms['train']),
    'val'   : TripletLandmarkDataset(root_dir        = '/mnt/sw/workspace/Google',
                                     csv_name        = './top5_landmarks_shrunk_val.csv',
                                     num_of_triplets = 200,
                                     transform       = data_transforms['val'])
}

TypeError: append() takes exactly one argument (5 given)