## Single model test
Define constants and configuration before we start:

In [1]:
import os

import torch
import torch.nn as nn

from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets, models, transforms

from tqdm import tqdm

from classifier import create_network
from atlas import CsvDataset, collater

MULTI_CLASS_NUM = 28        # classes the network is built with (the checkpoint's fc head has 28 outputs)
IMAGE_SIZE = 512            # NOTE(review): not used by the cells shown — TODO confirm / remove
SCORE_THRESHOLDS = [0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9]  # one submission file per value
NUM_WORKERS = 4             # DataLoader worker processes

batch_size = 96
tag = 'ATLAS-1'             # NOTE(review): not used by the cells shown
dataset = './ATLAS'         # NOTE(review): not used by the cells shown

# Machine-specific default; set the ATLAS_DATA_ROOT environment variable to
# point the notebook at the data on another machine without editing this cell.
data_root = os.environ.get('ATLAS_DATA_ROOT', '/home/voyager/data/atlas/')

# Training run to evaluate. Keep the trailing '/'; it is safe both for plain
# string concatenation and for os.path.join.
# Runs previously evaluated with this notebook (the directory names appear to
# encode the network, e.g. resnet-101 for these two):
#   ./ATLAS-1_resnet-101_multi_1_0.5_20190125_175349/
#   ./ATLAS-2_resnet-101_multi_1_0.5_20190130_115614/
model_dir = './test-1080_resnet-152_single_1_0.5_20190131_145152/'
model_file = 'fold0_epoch1.pth'
device_name = 'cuda:0'
# Must match the architecture the checkpoint in model_dir was trained with.
network = 'resnet-152'


Now load the test data:

In [2]:
# Test-time preprocessing: tensor conversion only, no augmentation.
test_augmentations = transforms.Compose([transforms.ToTensor()])

# csv_path=None with phase='test' — presumably CsvDataset then enumerates the
# test images under data_root by itself (TODO confirm in atlas.CsvDataset).
test_set = CsvDataset(
    data_root=data_root,
    csv_path=None,
    phase='test',
    num_classes=MULTI_CLASS_NUM,
    augment=test_augmentations,
)

# shuffle=False keeps the dataset order, so prediction row i lines up with
# test_set.image_ids[i] when the submission files are written.
test_loader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=collater,
)

Load model snapshot:

In [4]:
# model
device = torch.device(device_name)

model = create_network(
    network,
    pretrained=True,
    num_classes=MULTI_CLASS_NUM
)

# map_location remaps the stored tensors onto `device`, so loading still works
# when the checkpoint was saved from a GPU that is not present here.
checkpoint_path = os.path.join(model_dir, model_file)
state_dict = torch.load(checkpoint_path, map_location=device)

model = model.to(device=device)
# NOTE(review): the recorded output of this cell is a size mismatch for `fc`
# (checkpoint: 28 outputs, freshly built model: 1). Confirm that `network`
# matches the run in `model_dir` and that create_network honours num_classes.
model.load_state_dict(state_dict)

# eval() switches every submodule (BatchNorm, Dropout, ...) to inference mode;
# assigning `model.training = False` only flips the flag on the top-level module.
# The trailing ';' keeps Jupyter from printing the whole module repr.
model.eval();

RuntimeError: Error(s) in loading state_dict for ResNet:
	size mismatch for fc.weight: copying a param of torch.Size([28, 2048]) from checkpoint, where the shape is torch.Size([1, 2048]) in current model.
	size mismatch for fc.bias: copying a param of torch.Size([28]) from checkpoint, where the shape is torch.Size([1]) in current model.

Run inference on the test set:

In [None]:
def test(model, test_loader, device):
    """Run `model` over every batch of `test_loader` and collect sigmoid probabilities.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        test_loader: iterable of batches; each batch is indexed with 'images'.
        device: device the images are moved to before the forward pass.

    Returns:
        CPU tensor of the per-batch probabilities concatenated along dim 0,
        in the order the loader yields batches.
    """
    model.eval()

    batch_probs = []
    with torch.no_grad():
        # tqdm wraps the loader and advances once per batch.
        for batch in tqdm(test_loader, total=len(test_loader)):
            logits = model(batch['images'].to(device=device))
            batch_probs.append(torch.sigmoid(logits).detach().cpu())

    return torch.cat(batch_probs, dim=0)
        
# Probabilities for the whole test set, one row per image in loader order.
results = test(model=model, test_loader=test_loader, device=device)

# Inference is done: hand the memory cached by PyTorch's CUDA allocator back.
torch.cuda.empty_cache()

Binarize the predicted probabilities at each score threshold:

In [18]:
# One boolean tensor per entry of SCORE_THRESHOLDS (same order): True wherever
# the predicted probability is strictly greater than that threshold.
refined_results = [torch.gt(results, threshold) for threshold in SCORE_THRESHOLDS]

Write one submission CSV per score threshold:

In [19]:
from datetime import datetime

# Naive local timestamp taken once, so every per-threshold submission file
# written from this cell run shares the same time suffix.
now = datetime.now(tz=None)

def write_results(results, ids, score, submission_dir=None, model_id=None, timestamp=None):
    """Write one submission CSV (header 'Id,Predicted') for a single score threshold.

    Each row is ``<image id>,<space-separated indices of the positive classes>``;
    an image with no positive class gets an empty prediction.

    Args:
        results: 2-D sequence (e.g. a boolean tensor) with one row per image;
            class ``j`` is reported when ``row[j] == 1``.
        ids: image ids, aligned row-for-row with ``results``.
        score: threshold that produced ``results``; used in the file name.
        submission_dir: output directory. Defaults to ``<model_dir>/submissions``.
        model_id: file-name prefix. Defaults to ``model_file`` without its extension.
        timestamp: file-name timestamp. Defaults to the notebook-wide ``now``.

    Returns:
        Path of the written CSV file.

    Raises:
        ValueError: if ``results`` and ``ids`` have different lengths.
    """
    # A mismatch would otherwise either crash mid-file or silently drop images.
    if len(results) != len(ids):
        raise ValueError(
            'got {} result rows but {} ids'.format(len(results), len(ids)))

    # Fall back to the notebook's global config for anything not passed in.
    if submission_dir is None:
        submission_dir = os.path.join(model_dir, 'submissions')
    if model_id is None:
        # splitext only strips the final extension ('a.b.pth' -> 'a.b').
        model_id = os.path.splitext(model_file)[0]
    if timestamp is None:
        timestamp = now.strftime('%Y%m%d_%H%M%S')

    # Unlike os.mkdir, makedirs also creates missing parents and tolerates an
    # already existing directory.
    os.makedirs(submission_dir, exist_ok=True)

    submission_file = os.path.join(
        submission_dir, '{}_{}_{:.2f}.csv'.format(model_id, timestamp, score))

    with open(submission_file, "w") as csv_file:
        # write header
        csv_file.write('Id,Predicted\n')

        for image_id, row in zip(ids, results):
            # tolist() converts a tensor row to plain Python values in one call
            # instead of creating a 0-d tensor per element.
            flags = row.tolist() if hasattr(row, 'tolist') else row
            predicted = ' '.join(
                str(j) for j, flag in enumerate(flags) if flag == 1)
            csv_file.write('{},{}\n'.format(image_id, predicted))

    print('Written {}'.format(submission_file))
    return submission_file

# refined_results was built from SCORE_THRESHOLDS in the same order, so zip
# pairs each thresholded tensor with the threshold that produced it.
for threshold, thresholded in zip(SCORE_THRESHOLDS, refined_results):
    write_results(thresholded, test_set.image_ids, threshold)
    

Written ./ATLAS-3_resnet-34_multi_1_0.5_20190130_172843/submissions/fold0_epoch59_20190131_115714_0.40.csv
Written ./ATLAS-3_resnet-34_multi_1_0.5_20190130_172843/submissions/fold0_epoch59_20190131_115714_0.50.csv
Written ./ATLAS-3_resnet-34_multi_1_0.5_20190130_172843/submissions/fold0_epoch59_20190131_115714_0.55.csv
Written ./ATLAS-3_resnet-34_multi_1_0.5_20190130_172843/submissions/fold0_epoch59_20190131_115714_0.60.csv
Written ./ATLAS-3_resnet-34_multi_1_0.5_20190130_172843/submissions/fold0_epoch59_20190131_115714_0.65.csv
Written ./ATLAS-3_resnet-34_multi_1_0.5_20190130_172843/submissions/fold0_epoch59_20190131_115714_0.70.csv
Written ./ATLAS-3_resnet-34_multi_1_0.5_20190130_172843/submissions/fold0_epoch59_20190131_115714_0.75.csv
Written ./ATLAS-3_resnet-34_multi_1_0.5_20190130_172843/submissions/fold0_epoch59_20190131_115714_0.80.csv
Written ./ATLAS-3_resnet-34_multi_1_0.5_20190130_172843/submissions/fold0_epoch59_20190131_115714_0.85.csv
Written ./ATLAS-3_resnet-34_multi_1_0