In [18]:
import torch
import torch.nn as nn
import torchvision
import pretrainedmodels


In [175]:
model_names = [ 'resnet50', 'se_resnet50', 'se_resnext50_32x4d']


class WrapperModel(nn.Module):
    
    def __init__(self, base_model_name):
        super(WrapperModel, self).__init__()
        
        assert(base_model_name in model_names)
        
        base_model = pretrainedmodels.__dict__[base_model_name](pretrained='imagenet')
        
        self.features = nn.Sequential(*list(base_model.children())[:-2])        
        
        feature_num = base_model.last_linear.in_features
        
        self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
        self.last_linear = nn.Linear(feature_num, 200)
        
    def forward(self, x):
        x = self.features(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0),-1)
        x = self.last_linear(x)
        return x

def test_model():
    
    image = torch.randn(1,3,64,64)
    for model_name in model_names:
        model = WrapperModel(model_name)
        out = model(image)
        print(model_name,image.shape, out.shape)
        
test_model()

resnet50 torch.Size([1, 3, 64, 64]) torch.Size([1, 200])
se_resnet50 torch.Size([1, 3, 64, 64]) torch.Size([1, 200])
se_resnext50_32x4d torch.Size([1, 3, 64, 64]) torch.Size([1, 200])


In [2]:
from torch.utils.data.dataset import Dataset
import os
import glob
from PIL import Image 
import numpy as np    


def load_wnids(path='../tiny-imagenet-200/wnids.txt'):
    with open(path) as f:
        wnids = f.readlines()
        assert len(wnids) == 200
        wnids = [x.strip() for x in wnids]
    class2idx = {wnids[i]:i  for i in range(len(wnids))}
    return wnids, class2idx

class TinyTrainset(Dataset):
    
    def __init__(self, root, wnids_path, transform=None):
        super(TinyTrainset, self).__init__()
        classes = [d.name for d in os.scandir(root) if d.is_dir()]
        image_lists = [glob.glob(os.path.join(root, cls, 'images','*.JPEG'))  for cls in classes]
        
        self.images = []
        for image_list in image_lists:
            self.images += image_list
        labels = [ image.split('/')[-1].split('_')[0] for image in self.images]   
        
        self.wnids, self.class_to_idx = load_wnids(wnids_path)
        self.idxs = [ self.class_to_idx[label] for label in labels]
        self.labels = labels
        self.transform = transform
        
        assert(len(self.wnids)==200 and len(self.images)==200*500)
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):        
        image = Image.open(self.images[index])
        label_idx = self.idxs[index]
        
        if self.transform is not None:
            image = self.transform(image)
        return image, label_idx
        

class TinyValset(Dataset):
    
    def __init__(self, root, wnids_path, ann_path, transform=None):
        super(TinyValset, self).__init__()
        
        self.root = root
        self.transform = transform
        self.wnids, self.class_to_idx = load_wnids(wnids_path)
        
        with open(ann_path) as f:
            labels = f.readlines()
            assert len(labels) == 10000        
            data = [ label.split('\t')[:2] for label in labels]
            self.images, self.image_labels = zip(*data)        
        self.idxs = [ self.class_to_idx[label] for label in self.image_labels]
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):        
        image = Image.open(os.path.join(self.root, self.images[index]))
        label_idx = self.idxs[index]
        
        if self.transform is not None:
            image = self.transform(image)
        return image, label_idx        
        
def test_trainset():
    root = '../tiny-imagenet-200/train/'
    wnids_path = '../tiny-imagenet-200/wnids.txt'
    
    trainset = TinyTrainset(root, wnids_path)
    sample = trainset[0]
    image, label = sample
    image = np.array(image)
    assert image.shape == (64, 64, 3)
    assert 0 <= label and label < 200
    
    print(trainset.images[0], trainset.labels[0], trainset.idxs[0])
    
def test_valset():
    root = '../tiny-imagenet-200/val/images/'
    wnids_path = '../tiny-imagenet-200/wnids.txt'
    ann_path = '../tiny-imagenet-200/val/val_annotations.txt'
    
    valset = TinyValset(root, wnids_path, ann_path)
    image, label = valset[0]
    image = np.array(image)
    
    assert image.shape == (64, 64, 3)
    assert 0 <= label and label < 200    
    
    print(valset.images[0], valset.image_labels[0], valset.idxs[0])
    
test_trainset()
test_valset()

../tiny-imagenet-200/train/n04074963/images/n04074963_130.JPEG n04074963 54
val_0.JPEG n03444034 163
