In [1]:
import torch
import torch.nn as nn
import pretrainedmodels
import torchvision.transforms as transforms

from PIL import Image
import numpy as np
import ast
from tqdm import tqdm

In [2]:
model_name = 'resnet34'
num_classes = 340
width = 256

model = pretrainedmodels.__dict__[model_name](num_classes=1000, pretrained=None)
model.avgpool = nn.AdaptiveAvgPool2d(1)
model.avgpool_1a = nn.AdaptiveAvgPool2d(1)
model.last_linear = nn.Linear(in_features=model.last_linear.in_features, out_features=num_classes, bias=True)
model = model.cuda()

state = torch.load('resnet34_sgd_adam_0.9353_100.pt')
model.load_state_dict(state['model'])
model = model.eval()

In [3]:
from dataset import SplitDataset, DataLoader
transform = transforms.Compose([
    transforms.Resize(size=(width, width)),
    transforms.ToTensor()
])

doodles = SplitDataset('test_simplified.csv', mode='test', transform=transform)
testloader = DataLoader(doodles, batch_size=64, shuffle=False, num_workers=6)

In [4]:
model.eval()
labels = np.empty((0,3))
with tqdm(testloader) as pbar:
    for x in pbar:
        x = x.cuda()
        output = model(x)
        _, pred = output.topk(3, 1, True, True)
        labels = np.concatenate([labels, pred], axis = 0)

100%|██████████| 1754/1754 [02:04<00:00, 14.14it/s]


In [5]:
labels[:10]

array([[234., 281., 285.],
       [144.,  36., 226.],
       [305.,  62., 110.],
       [187., 303., 111.],
       [ 56., 113., 165.],
       [110., 274., 157.],
       [335., 151., 257.],
       [289.,  91., 253.],
       [ 38., 101., 136.],
       [151., 326., 335.]])

In [6]:
import pandas as pd
df2 = pd.read_csv('split_recognized/train_k99.csv')
label_word_map = dict(zip(df2.y.unique(), [x.replace(' ', '_') for x in df2.word.unique()]))

In [7]:
%%time
labels_string = [' '.join([label_word_map[y] for y in x]) for x in labels]
submission = pd.read_csv('test_simplified.csv', index_col='key_id')
submission.drop(['countrycode', 'drawing'], axis=1, inplace=True)
submission['word'] = labels_string
submission.to_csv(f'preds_{model_name}.csv')

CPU times: user 2.05 s, sys: 58.6 ms, total: 2.11 s
Wall time: 1.51 s
