In [1]:
from tqdm import tqdm
import os.path as osp
import numpy as np
import json
import random
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms

In [2]:
# Use one shared seed for every random number generator the notebook relies on,
# so that a fresh "Restart Kernel -> Run All" gives repeatable results.
rnd = 1234

np.random.seed(rnd)     # NumPy's global RNG
torch.manual_seed(rnd)  # PyTorch's RNG
random.seed(rnd)        # Python's built-in RNG (kept last: returns None, so the cell shows no output)

In [3]:
class BaseTransform():
    """Fixed (non-random) image preprocessing pipeline.

    The pipeline applies, in order: ``Resize``, ``CenterCrop``, ``ToTensor``
    and ``Normalize`` from ``torchvision.transforms``.

    Parameters
    ----------
    resize : int
        Size passed to both ``Resize`` and ``CenterCrop``.
    mean : sequence of float
        Per-channel mean passed to ``Normalize``.
    std : sequence of float
        Per-channel standard deviation passed to ``Normalize``.
    """

    def __init__(self, resize, mean, std):
        # Build the ordered list of steps first, then wrap it in one Compose.
        steps = [
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ]
        self.base_transform = transforms.Compose(steps)

    def __call__(self, img):
        """Run ``img`` through the composed pipeline and return the result."""
        pipeline = self.base_transform
        return pipeline(img)

In [4]:
# Load the class-index lookup table from JSON.
# A context manager is used so the file handle is closed as soon as parsing
# finishes, instead of being left open until garbage collection.
with open('./data/imagenet_class_index.json', 'r') as class_index_file:
    ILSVRC_class_index = json.load(class_index_file)

In [5]:
class ILSVRCPredictor():
    """Turn a model's output scores into a label name.

    Parameters
    ----------
    class_index : mapping
        Lookup table keyed by the class index as a string (e.g. ``"207"``).
        Element ``[1]`` of each entry is what ``predict_max`` returns.
    """

    def __init__(self, class_index):
        self.class_index = class_index

    def predict_max(self, out):
        """Return the label name for the highest-scoring entry of ``out``.

        Parameters
        ----------
        out : tensor
            Model output supporting ``.detach().cpu().numpy()``. The argmax is
            taken over the flattened array, so this presumably expects the
            scores of a single sample -- TODO confirm against callers.

        Returns
        -------
        The element at position 1 of ``class_index[str(argmax)]``.

        Raises
        ------
        KeyError
            If the argmax index is not a key of ``class_index``.
        """
        # .cpu() is a no-op for tensors already on the CPU, but .numpy() raises
        # for tensors on other devices, so move the data to host memory first.
        scores = out.detach().cpu().numpy()
        maxid = np.argmax(scores)
        predicted_label_name = self.class_index[str(maxid)][1]

        return predicted_label_name

In [6]:
# Label lookup: maps the model's output scores to a class name.
predictor = ILSVRCPredictor(ILSVRC_class_index)

# Sample image to classify.
image_file_path = './data/goldenretriever-3724972_640.jpg'
img = Image.open(image_file_path)

# Preprocessing parameters: target size plus per-channel mean / std for
# normalisation (presumably the usual ImageNet statistics -- confirm if changed).
resize = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

transform = BaseTransform(resize, mean, std)
img_transformed = transform(img)
# NOTE: unsqueeze_ works in place, so `img_transformed` and `inputs` refer to
# the same tensor, which now carries a leading dimension of size 1.
inputs = img_transformed.unsqueeze_(0)

# Pretrained VGG-16, switched to evaluation mode for inference.
use_pretrained = True
net = models.vgg16(pretrained=use_pretrained)
net.eval()

# Inference only: disable autograd so no computation graph is recorded for
# the forward pass (saves memory and time; the scores are unchanged).
with torch.no_grad():
    out = net(inputs)
result = predictor.predict_max(out)

print("Result: {}".format(result))

Result: golden_retriever


In [8]:
class ImageTransform():
    """Phase-dependent image preprocessing.

    Two pipelines are built and stored in ``self.data_transform``:

    * ``'train'``: ``RandomResizedCrop`` (scale 0.5-1.0), ``RandomHorizontalFlip``,
      ``ToTensor``, ``Normalize``.
    * ``'val'``: ``Resize``, ``CenterCrop``, ``ToTensor``, ``Normalize``.

    Parameters
    ----------
    resize : int
        Size passed to the crop / resize transforms.
    mean : sequence of float
        Per-channel mean passed to ``Normalize``.
    std : sequence of float
        Per-channel standard deviation passed to ``Normalize``.
    """

    def __init__(self, resize, mean, std):
        self.data_transform = {
            # Must be the torchvision `transforms` module here, not the
            # notebook-level `transform` instance, which has no `Compose`.
            'train': transforms.Compose([
                transforms.RandomResizedCrop(
                    resize, scale=(0.5, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]),
            'val': transforms.Compose([
                transforms.Resize(resize),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
        }

    def __call__(self, img, phase='train'):
        """Apply the pipeline selected by ``phase`` to ``img``.

        Parameters
        ----------
        img : image accepted by the composed transforms.
        phase : str
            ``'train'`` (default) or ``'val'``.

        Raises
        ------
        KeyError
            If ``phase`` is not one of the configured keys.
        """
        return self.data_transform[phase](img)