In [31]:
from PIL import Image

In [72]:
import os

In [119]:
import numpy as np

In [120]:
import pandas as pd

In [50]:
import tqdm

In [51]:
from torchvision import transforms

In [52]:
import torch

In [53]:
from torch import nn

In [54]:
from torchvision.models import resnet18, resnet50

In [55]:
from torch.utils.data import DataLoader

In [96]:
from torch.autograd import Variable

In [106]:
import torch.nn.functional as F

In [107]:
num_classes = 17

In [108]:
model = resnet50(pretrained=True).cuda()
model.fc = nn.Linear(model.fc.in_features, num_classes).cuda()

In [122]:
state = torch.load('../data/best-model.pt')
epoch = state['epoch']
step = state['step']
best_valid_loss = state['best_valid_loss']
model.load_state_dict(state['model'])

In [123]:
def load_image(path):
    return Image.open(str(path)).convert('RGB')

In [124]:
transformations_test = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

In [125]:
class PredictionDataset:
    def __init__(self, paths, n_test_aug):
        self.paths = paths
        self.n_test_aug = n_test_aug

    def __len__(self):
        return len(self.paths) * self.n_test_aug

    def __getitem__(self, idx):
        path = self.paths[idx % len(self.paths)]
        image = load_image(path)
        return transformations_test(image)

In [126]:
cuda_is_available = torch.cuda.is_available()


def variable(x, volatile=False):
    if isinstance(x, (list, tuple)):
        return [variable(y, volatile=volatile) for y in x]
    return cuda(Variable(x, volatile=volatile))


def cuda(x):
    return x.cuda() if cuda_is_available else x


In [127]:
def predict(model, paths, batch_size, n_test_aug):
    loader = DataLoader(
        dataset=PredictionDataset(paths, n_test_aug),
        shuffle=False,
        batch_size=batch_size,
        num_workers=1,
    )
    model.eval()
    all_outputs = []    
    for inputs in tqdm.tqdm(loader, desc='Predict'):
        inputs = variable(inputs, volatile=True)
        outputs = F.sigmoid(model(inputs))
        all_outputs.append(outputs.data.cpu().numpy())        
    all_outputs = np.concatenate(all_outputs)
    df = pd.DataFrame(data=all_outputs)
#     df = pd.DataFrame(data=all_outputs, index=all_stems,
#                       columns=dataset.CLASSES)
    return df
    

In [128]:
test_path = '../data/test-jpg'
list_files = os.listdir(test_path)

In [129]:
list_files = [os.path.join(test_path, x) for x in list_files]

In [131]:
p = predict(model, list_files, 32, 1)

Predict: 100%|██████████| 1271/1271 [05:34<00:00,  4.17it/s]


In [132]:
p

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16
0,0.007159,8.362954e-06,0.000146,0.000336,0.000106,0.004989,9.977539e-03,2.624495e-05,0.002147,0.000538,0.603607,0.315762,0.987529,0.003039,0.000144,0.000050,0.015700
1,0.006324,1.007194e-05,0.000394,0.015575,0.005182,0.997842,1.808381e-07,8.742442e-06,0.005437,0.000548,0.000061,0.004617,0.999998,0.003033,0.005417,0.000108,0.007718
2,0.001970,1.446923e-06,0.000105,0.003862,0.001455,0.998372,5.734845e-08,2.515545e-06,0.002421,0.000173,0.000023,0.005549,0.999999,0.001221,0.001501,0.000027,0.002375
3,0.001709,2.209177e-06,0.000089,0.009231,0.001613,0.991930,7.195432e-07,3.139799e-06,0.002273,0.000218,0.000173,0.013444,0.999995,0.000743,0.001524,0.000036,0.001962
4,0.563458,1.977381e-03,0.042921,0.001501,0.001289,0.974863,2.279682e-05,4.962340e-03,0.309149,0.197345,0.012232,0.021599,0.997956,0.962614,0.020137,0.006255,0.121110
5,0.871200,4.869810e-05,0.001691,0.000052,0.000044,0.271479,2.560648e-03,5.837707e-04,0.061310,0.028586,0.645836,0.056057,0.878822,0.996298,0.001220,0.000248,0.155153
6,0.026214,1.882720e-04,0.000867,0.000233,0.000448,0.001275,8.935936e-01,2.681515e-04,0.007571,0.005293,0.013932,0.114817,0.193299,0.018805,0.000424,0.000376,0.028779
7,0.004703,3.403498e-05,0.000760,0.048281,0.003477,0.997258,4.395231e-07,1.923617e-05,0.004463,0.001313,0.000272,0.003673,0.999994,0.017782,0.034464,0.000110,0.012335
8,0.003507,2.807036e-06,0.000124,0.005966,0.001061,0.999055,2.383705e-08,4.037989e-06,0.002846,0.000556,0.000044,0.003065,0.999999,0.005815,0.002740,0.000023,0.005699
9,0.057725,1.099989e-04,0.002986,0.008610,0.006757,0.964422,2.503290e-04,1.050765e-04,0.025301,0.003615,0.014375,0.017101,0.998691,0.015170,0.003148,0.001355,0.027188
