In [13]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

plt.ion()   # interactive mode
get_label = {
    'articulated_truck': 0,
    'background': 1,
    'bicycle': 2,
    'bus': 3,
    'car': 4,
    'motorcycle': 5,
    'non-motorized_vehicle': 6,
    'pedestrian': 7,
    'pickup_truck': 8,
    'single_unit_truck': 9,
    'work_van': 10
}

In [14]:
# for i, name in enumerate(os.listdir('./data/train')):
#     for j in os.listdir('./data/train/'+ name):
#         self.image_arr.append('./data/train/'+ name+'/'+j)
#         self.label_arr.append(i)
#         self.operation_arr.append(False)
# #                 print('./data/train/'+ i+'/'+j, str(i), False)
#         break
#     break

In [15]:
class Rescale(object):
    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size
    
    def __call__(self, sample):
        image = sample['image']
        h, w = image.shape[:2]
        
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size
        
        new_h, new_w = int(new_h), int(new_w)
        
        img = transform.resize(image, (new_h, new_w))
        return {'image': img, 'label': sample['label']}
class ToTensor(object):

    def __call__(self, sample):
        image = sample['image']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image), 'label': sample['label']}

In [18]:
class DS(Dataset):
    def __init__(self, csv_file, root_dir, transform = None):
        self.transform = transform
        self.csv_ds = pd.read_csv(csv_file, dtype='str')
        self.root_dir = root_dir
    
    def __len__(self):
        return len(self.csv_ds)
    
    def __getitem__(self, idx):
        print(self.csv_ds.iloc[idx][0], self.csv_ds.iloc[idx][1])
        img_name = str(self.csv_ds.iloc[idx][0])
        img_path = os.path.join(self.root_dir, 'train',str(self.csv_ds.iloc[idx][1]), str(self.csv_ds.iloc[idx][0]) + '.jpg')
        image = io.imread(img_path)
        label = get_label[self.csv_ds.iloc[idx][1]]
        sample = {'image': image, 'label': label}
        
        if self.transform:
            sample = self.transform(sample)
        return sample
    
tds = DS(csv_file='./data/gt_train.csv',
                                    root_dir='./data/',
                                    transform = transforms.Compose([Rescale((100, 100)), ToTensor()]) )

fig = plt.figure()
scale = Rescale((100, 100))
toTensor = ToTensor()

for i in range(len(tds)):
    
    sample = tds[i]
    print(i, sample['image'].shape, sample['label'])
    np_i = sample['image'].numpy()
#     plt.imshow(np.transpose(np_i, (1, 2, 0)), interpolation='nearest')
#     tensor_image = toTensor(sample)['image']
#     scaled_image = scale(sample)['image']
#     io.imshow(scaled_image)
#     ax = plt.subplot(1, 4, i + 1)
#     plt.tight_layout()
#     ax.set_title('Sample #{}'.format(i))
#     ax.axis('off')
# #     show_landmarks(**sample)

    if i == 200:
        plt.show()
        break

00260164 non-motorized_vehicle
0 torch.Size([3, 100, 100]) 6
00400458 non-motorized_vehicle
1 torch.Size([3, 100, 100]) 6
00457677 non-motorized_vehicle
2 torch.Size([3, 100, 100]) 6
00209377 non-motorized_vehicle
3 torch.Size([3, 100, 100]) 6
00277048 non-motorized_vehicle
4 torch.Size([3, 100, 100]) 6
00356378 non-motorized_vehicle
5 torch.Size([3, 100, 100]) 6
00529840 non-motorized_vehicle
6 torch.Size([3, 100, 100]) 6
00188674 non-motorized_vehicle
7 torch.Size([3, 100, 100]) 6
00461004 non-motorized_vehicle
8 torch.Size([3, 100, 100]) 6
00498624 non-motorized_vehicle
9 torch.Size([3, 100, 100]) 6
00282139 non-motorized_vehicle
10 torch.Size([3, 100, 100]) 6
00255010 non-motorized_vehicle
11 torch.Size([3, 100, 100]) 6
00253096 non-motorized_vehicle
12 torch.Size([3, 100, 100]) 6
00413833 non-motorized_vehicle
13 torch.Size([3, 100, 100]) 6
00517913 non-motorized_vehicle
14 torch.Size([3, 100, 100]) 6
00240485 non-motorized_vehicle
15 torch.Size([3, 100, 100]) 6
00201348 non-motor

143 torch.Size([3, 100, 100]) 6
00344073 non-motorized_vehicle
144 torch.Size([3, 100, 100]) 6
00535028 non-motorized_vehicle
145 torch.Size([3, 100, 100]) 6
00365828 non-motorized_vehicle
146 torch.Size([3, 100, 100]) 6
00233004 non-motorized_vehicle
147 torch.Size([3, 100, 100]) 6
00498171 non-motorized_vehicle
148 torch.Size([3, 100, 100]) 6
00473880 non-motorized_vehicle
149 torch.Size([3, 100, 100]) 6
00501850 non-motorized_vehicle
150 torch.Size([3, 100, 100]) 6
00250388 non-motorized_vehicle
151 torch.Size([3, 100, 100]) 6
00355987 non-motorized_vehicle
152 torch.Size([3, 100, 100]) 6
00467622 non-motorized_vehicle
153 torch.Size([3, 100, 100]) 6
00422693 non-motorized_vehicle
154 torch.Size([3, 100, 100]) 6
00501431 non-motorized_vehicle
155 torch.Size([3, 100, 100]) 6
00183164 non-motorized_vehicle
156 torch.Size([3, 100, 100]) 6
00281018 non-motorized_vehicle
157 torch.Size([3, 100, 100]) 6
00349694 non-motorized_vehicle
158 torch.Size([3, 100, 100]) 6
00173346 non-motorized_

<Figure size 432x288 with 0 Axes>