In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
import argparse
from itertools import islice

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from sklearn import metrics
from matplotlib import pyplot as plt

from dataset import *
from resnet import *
from classifier import *

In [None]:
# consts
# Experiment identity: the evaluation loads
# ./models/<TRAIN_NAME>_<TRAIN_ID>/<EPOCH zero-padded to 3 digits>.pth
TRAIN_NAME = 'uda-test'
TRAIN_ID = '10'
EPOCH = 99

# data consts
# Dataset root. Set the CHROMO_UDA_ROOT environment variable to evaluate on
# another copy of the data without editing the notebook; the default is the
# original author's local path.
ROOT_PATH = os.environ.get('CHROMO_UDA_ROOT', '/home/xd/data/chromo/class-2/uda')
NUM_CLASSES = 2  # binary task: class 1 is treated as the positive class by the metrics
INPUT_SIZE = 512  # images are padded to square, then resized to INPUT_SIZE x INPUT_SIZE
BATCH_SIZE = 1
NUM_WORKERS = 4

# trainer consts
DEVICE = 'cuda:1'  # NOTE(review): assumes at least two visible GPUs

In [None]:
# Validation-time preprocessing (no augmentation): convert to a PIL image,
# pad to a square canvas, resize to INPUT_SIZE x INPUT_SIZE, then convert to
# a float tensor and normalize with the ImageNet channel statistics.
val_trans = transforms.Compose(
    [
        transforms.ToPILImage(),
        PadSquare(),
        transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Validation split (training=False) of the dataset under ROOT_PATH, with the
# preprocessing above passed as the supervised transform.
val_dataset = UdaDataset(
    ROOT_PATH,
    training=False,
    image_ext='.png',
    sup_transform=val_trans,
)

# Keep the dataset order (shuffle=False): the visualization cell zips
# val_dataset with the collected predictions and relies on them lining up.
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

In [None]:
def strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from keys.

    Models saved while wrapped in ``nn.DataParallel`` store every parameter
    name as ``'module.<name>'``; the bare model expects ``'<name>'``. Only an
    exact leading ``prefix`` is removed, and keys without it are kept as-is,
    so checkpoints saved from an unwrapped model pass through unchanged.

    Args:
        state_dict: mapping of parameter name -> tensor.
        prefix: key prefix to strip.

    Returns:
        A new dict with the same values and de-prefixed keys.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


device = torch.device(DEVICE)

# The weights requested by pretrained=True are fully replaced by the
# checkpoint below (load_state_dict is strict by default).
model = resnet50(pretrained=True, num_classes=NUM_CLASSES)

checkpoint_path = os.path.join('./models', '{}_{}'.format(TRAIN_NAME, TRAIN_ID), '{:0>3d}.pth'.format(EPOCH))
# torch.load unpickles the file, which can execute arbitrary code:
# only load checkpoints from a trusted source.
cp_state_dict = torch.load(checkpoint_path, map_location='cpu')

# Per-key exact-prefix check instead of "'module' in first key": the old
# substring test on one key could strip the first name component of every
# key for a checkpoint that was never DataParallel-wrapped.
model.load_state_dict(strip_module_prefix(cp_state_dict))

model = model.to(device)

In [None]:
# Run the model over the whole validation set and collect, per sample, the
# probability of class 1, the predicted label and the ground-truth label,
# then compute binary metrics with class 1 as the positive label.
all_preds = []
all_scores = []
all_gts = []

model.eval()

with torch.no_grad():
    for imgs, gts in tqdm(val_loader, total=len(val_loader), file=sys.stdout):
        logits = model(imgs.to(device))
        # softmax is computed once and reused for both the score and the label
        probs = nn.functional.softmax(logits, dim=1)

        scores = probs[:, 1]
        preds = probs.argmax(dim=1)

        all_scores.append(scores.cpu().view(-1))
        all_preds.append(preds.cpu().view(-1))
        # labels are only consumed by sklearn on the CPU, so they never go to the GPU
        all_gts.append(gts.reshape(-1).cpu())

all_scores = torch.cat(all_scores).numpy()
all_preds = torch.cat(all_preds).numpy()
all_gts = torch.cat(all_gts).numpy()

f1 = metrics.f1_score(all_gts, all_preds)
precision = metrics.precision_score(all_gts, all_preds)
recall = metrics.recall_score(all_gts, all_preds)

In [None]:
# Plain precision / recall at the argmax decision (class 1 positive). They were
# previously labelled "ap"/"ar", which reads as *average* precision/recall.
print('f1: {:0.3f}, precision: {:0.3f}, recall: {:0.3f}'.format(f1, precision, recall))

## Visualize validation predictions

In [None]:
VIS_COUNT: int = 20  # visualization cutoff, read by the plotting loop below to decide when to stop

In [None]:
# Plot the first VIS_COUNT validation samples, titled with ground truth,
# predicted label and class-1 score.
# NOTE(review): assumes val_dataset yields (C, H, W) 3-channel tensors already
# passed through val_trans, i.e. ImageNet-normalized; the normalization is
# undone so imshow receives values in [0, 1] instead of clipping them.
# These statistics must match the transforms.Normalize(...) in val_trans.
VIS_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
VIS_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

num_vis = min(VIS_COUNT, len(all_preds))

# islice stops after exactly num_vis samples (the old `== VIS_COUNT` break ran
# after plotting and so showed one extra image) without loading another one.
samples = islice(zip(val_dataset, all_preds, all_scores), num_vis)

for (img, gt), pred, score in tqdm(samples, total=num_vis, file=sys.stdout):
    img = (img * VIS_STD + VIS_MEAN).clamp(0, 1)
    img = img.numpy().transpose((1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow

    fig, ax = plt.subplots(figsize=(16, 12))
    ax.imshow(img)
    ax.set_title('gt: {}, pred: {}, score: {:0.3f}'.format(gt, pred, score))
    ax.axis('off')
    plt.show()
    plt.close(fig)  # release the figure; many figures in one loop otherwise pile up