In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

from dataset import *
from resnet import *
from classifier import *
from sampler import BalancedStatisticSampler
from visualize import *

In [None]:
# --- run / checkpoint selection ---
TRAIN_NAME = 'slice'
TRAIN_ID = '01'
EPOCH = 8  # checkpoint loaded from ./models/{TRAIN_NAME}_{TRAIN_ID}/{EPOCH:03d}.pth

SHOW_COUNT = 50  # maximum number of images run through the model and plotted

# NOTE(review): POS_THRESHOLD / NEG_THRESHOLD are not referenced in this notebook;
# kept in case other code reads them — confirm before removing.
POS_THRESHOLD = 0.1
NEG_THRESHOLD = 0.01
SCORE_THRESHOLD = 0.5  # forwarded to slice_plots as its score threshold

# data consts
# Overridable through the FIRE_ROOT_PATH environment variable so the notebook is not
# tied to one machine; the default is the original location.
ROOT_PATH = os.environ.get('FIRE_ROOT_PATH', '/home/xd/data/fire/test_1')
NUM_CLASSES = 2  # fg + 1(bg)
READ_SIZE = (640, 480)  # forwarded to slice_plots; (width, height) vs (height, width) order — TODO confirm
SLICE_SIZE = 40  # side length (pixels) of each square slice
# Number of SLICE_SIZE x SLICE_SIZE tiles covering a READ_SIZE image (exact when both sides divide evenly).
SLICE_COUNT = READ_SIZE[0] * READ_SIZE[1] // (SLICE_SIZE * SLICE_SIZE)
INPUT_SIZE = 224  # side length each slice is resized to before entering the network
BATCH_SIZE = 64
NUM_WORKERS = 16

# trainer consts
# Fall back to CPU so the notebook still runs on a kernel without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Preprocessing pipeline handed to slice_plots as its transform.
val_trans = transforms.Compose(
    [
        # convert the incoming image data to a PIL image
        transforms.ToPILImage(),
        # square resize to the network's input resolution
        transforms.Resize(size=(INPUT_SIZE, INPUT_SIZE)),
        # PIL image -> float tensor
        transforms.ToTensor(),
        # per-channel normalisation with the standard ImageNet mean / std
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# model
device = torch.device(DEVICE)

model = resnet101(pretrained=False, num_classes=NUM_CLASSES)

checkpoint_path = os.path.join('./models', '{}_{}'.format(TRAIN_NAME, TRAIN_ID), '{:0>3d}.pth'.format(EPOCH))
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
cp_state_dict = torch.load(checkpoint_path, map_location='cpu')


def _strip_data_parallel_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key that has it.

    A model saved while wrapped in ``nn.DataParallel`` stores its parameters under
    ``module.<name>``; the bare model built above expects ``<name>``. Keys that do not
    start with ``prefix`` are kept unchanged, so plain checkpoints (and empty ones)
    pass through untouched.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


model.load_state_dict(_strip_data_parallel_prefix(cp_state_dict))

# Wrap for multi-GPU inference only after loading, and only when running on CUDA:
# DataParallel expects the module's parameters on a GPU.
if device.type == 'cuda' and torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)
# Inference mode for BatchNorm/Dropout layers; harmless if slice_plots also calls eval().
model.eval()

In [None]:
data_path = os.path.join(ROOT_PATH, 'images')
# os.listdir order is arbitrary; sort so the same frames are shown on every run.
img_names = sorted(os.listdir(data_path))
# Run and plot at most SHOW_COUNT images (none when SHOW_COUNT <= 0).
shown_names = img_names[:max(SHOW_COUNT, 0)]

# NOTE(review): tqdm, plt and slice_plots are not imported explicitly; presumably they
# come from the star imports in the first cell (e.g. `visualize`) — confirm.
for img_name in tqdm(shown_names, file=sys.stdout):
    img_path = os.path.join(data_path, img_name)

    # presumably mask_path=None means no ground-truth overlay, making gt_thres unused — TODO confirm in slice_plots
    pred_img = slice_plots(
        model,
        device,
        img_path,
        SLICE_SIZE,
        READ_SIZE,
        val_trans,
        SCORE_THRESHOLD,
        mask_path=None,
        gt_thres=0.05,
    )

    fig = plt.figure(figsize=(16, 12))
    plt.imshow(pred_img)
    plt.title(img_name)
    plt.show()
    # Release each figure explicitly so memory does not grow over many frames.
    plt.close(fig)