In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys
from functools import partial

import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import PIL
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

from dataset import *
from resnet import *
from slice_net import *
from classifier import *
from sampler import BalancedStatisticSampler
from visualize import *

In [2]:
# checkpoint selection: the model cell loads
# ./models/{TRAIN_NAME}_{TRAIN_ID}/{EPOCH:0>3d}.pth
TRAIN_NAME = 'slicenet'
TRAIN_ID = '09'
EPOCH = 186

# number of validation samples rendered by the visualisation cell
SHOW_COUNT = 50

# softmax score of the last (foreground) channel above which a mask cell is positive
SCORE_THRESHOLD = 0.5
# INPUT_SIZE is divided by this to get the coarse mask grid size;
# presumably it matches the backbone's output stride — TODO confirm
MASK_RESIZE_RATIO = 32

# data consts
# NOTE(review): absolute local path; consider making it configurable (e.g. env var)
ROOT_PATH = '/home/xd/data/fire'
NUM_CLASSES = 2 # foreground + background; the last softmax channel is scored as foreground
# (width, height): index 0 is used as width and index 1 as height in the
# transforms, and cv2.resize takes dsize as (width, height)
INPUT_SIZE = (640, 480)
BATCH_SIZE = 64
NUM_WORKERS = 16

# trainer consts
DEVICE = 'cuda'

# model consts
# one of 'se', 'channel', 'spartial' (sic), 'cbam', 'non_local'; any other
# value builds the model without an attention module
ATTENTION = 'non_local'
R = 16  # passed to SE_module / Channel_Attention — presumably the channel reduction ratio
K = 7  # passed to Spartial_Attention — presumably its kernel size

In [3]:
# Validation-time augmentation: only a deterministic resize to INPUT_SIZE,
# applied jointly to image and segmentation map.
val_seq = iaa.Sequential(
    [iaa.Resize({'height': INPUT_SIZE[1], 'width': INPUT_SIZE[0]})]
)

# Image -> normalised float tensor (standard ImageNet mean/std constants).
tensor_trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Mask -> tensor on the coarse grid (INPUT_SIZE / MASK_RESIZE_RATIO); nearest
# interpolation so label values are not blended.
mask_trans = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(
            (INPUT_SIZE[1] // MASK_RESIZE_RATIO, INPUT_SIZE[0] // MASK_RESIZE_RATIO),
            interpolation=PIL.Image.NEAREST,
        ),
        transforms.ToTensor(),
    ]
)

def trans(
    img, mask,
    iaa_seq=None,
    vision_trans=None,
    mask_trans=None
):
    """Jointly augment an image/mask pair, then post-process each separately.

    The same imgaug pipeline is applied to the image and its segmentation map
    so geometric ops (e.g. resizing) keep them aligned. Afterwards the image
    and the mask array each go through their own transform.

    Args:
        img: image array; its ``.shape`` is the segmentation map's reference shape.
        mask: label array, wrapped in ``SegmentationMapsOnImage`` for augmentation.
        iaa_seq: imgaug augmenter called as
            ``iaa_seq(image=..., segmentation_maps=...)``. ``None`` skips
            augmentation and passes ``img`` / ``mask`` through unchanged.
        vision_trans: callable applied to the (augmented) image; ``None`` = identity.
        mask_trans: callable applied to the (augmented) mask array; ``None`` = identity.

    Returns:
        Tuple ``(img, mask)`` after augmentation and the per-output transforms.
    """
    # Previously every stage was called unconditionally, so the ``None``
    # defaults crashed with "NoneType is not callable"; None now means "skip".
    if iaa_seq is not None:
        mask_map = SegmentationMapsOnImage(mask, shape=img.shape)
        img, mask_map = iaa_seq(image=img, segmentation_maps=mask_map)
        mask = mask_map.get_arr()

    if vision_trans is not None:
        img = vision_trans(img)
    if mask_trans is not None:
        mask = mask_trans(mask)

    return img, mask

# Bind the validation pipeline into one (img, mask) -> (img, mask) callable.
val_trans = partial(trans, iaa_seq=val_seq, vision_trans=tensor_trans, mask_trans=mask_trans)

val_dataset = SegmentationDataset(ROOT_PATH, training=False, transform=val_trans)

# Fixed (unshuffled) order for evaluation.
# NOTE(review): the visualisation cell indexes val_dataset directly and does
# not consume this loader.
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True,
)

In [4]:
# model
device = torch.device(DEVICE)

resnet = resnet101(pretrained=True, num_classes=NUM_CLASSES)

if ATTENTION == 'se':
    attention = SE_module(resnet.feature_size, R)
elif ATTENTION == 'channel':
    self.attention = Channel_Attention(resnet.feature_size, R)
elif ATTENTION == 'spartial':
    attention = Spartial_Attention(K)
elif ATTENTION == 'cbam':
    attention = nn.Sequential(
        Channel_Attention(resnet.feature_size, R),
        Spartial_Attention(K)
    )
elif ATTENTION == 'non_local':
    attention = NonLocalBlockND(resnet.feature_size)
else:
    attention = None
    
model = SliceNet(resnet, num_classes=NUM_CLASSES, attention=attention)

checkpoint_path = os.path.join('./models', '{}_{}'.format(TRAIN_NAME, TRAIN_ID), '{:0>3d}.pth'.format(EPOCH))
cp_state_dict = torch.load(checkpoint_path, map_location='cpu')

if 'module' in list(cp_state_dict.keys())[0]:
    new_state_dict = {}
    
    for key, value in cp_state_dict.items():
        new_state_dict[key.split('.', 1)[1]] = value
    
    model.load_state_dict(new_state_dict)
else:
    model.load_state_dict(cp_state_dict)

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

In [None]:
# Render the first SHOW_COUNT validation samples as three panels: ground truth
# (green), prediction (blue) and their disagreement (red).
model.eval()  # ensure inference behaviour even if the model cell left train mode on

with tqdm(total=SHOW_COUNT, file=sys.stdout) as pbar:
    for frame_no, sample in enumerate(val_dataset):
        # Check before doing any work; the old `frame_no == SHOW_COUNT-1` test
        # never fired for SHOW_COUNT <= 0 and walked the whole dataset.
        if frame_no >= SHOW_COUNT:
            break

        # Re-read the raw image for display (the dataset sample is normalised).
        img_file = os.path.join(
            val_dataset.root_path,
            'images',
            val_dataset.img_filenames[val_dataset.indices[frame_no]]
        )
        ori_img = cv2.imread(img_file)
        if ori_img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError('could not read image: {}'.format(img_file))
        ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        ori_img = cv2.resize(ori_img, INPUT_SIZE)

        img, mask = sample
        _, h, w = img.shape

        # infer
        with torch.no_grad():
            logits = model(img.unsqueeze(0).to(device))
            # last softmax channel is the foreground score
            scores = nn.functional.softmax(logits, 1).squeeze()[-1]
            preds = torch.gt(scores, SCORE_THRESHOLD).cpu()

        # Ground truth and error map are computed on CPU; no device round-trip.
        gt = torch.gt(mask, 0.5).squeeze()
        errors = torch.ne(preds, gt)

        pred_mask = slice_mask((h, w), MASK_RESIZE_RATIO, preds.reshape(-1).numpy())
        gt_mask = slice_mask((h, w), MASK_RESIZE_RATIO, gt.reshape(-1).numpy())
        result_mask = slice_mask((h, w), MASK_RESIZE_RATIO, errors.reshape(-1).numpy())

        # copies so each panel gets its own overlay
        pred_img = apply_mask(ori_img.copy(), pred_mask, (0, 0, 128))
        gt_img = apply_mask(ori_img.copy(), gt_mask, (0, 128, 0))
        result_img = apply_mask(ori_img.copy(), result_mask, (128, 0, 0))

        fig, axes = plt.subplots(1, 3, figsize=(16, 12))
        panels = (
            (gt_img, 'ground truth'),
            (pred_img, 'prediction'),
            (result_img, 'error'),
        )
        for ax, (panel, title) in zip(axes, panels):
            ax.imshow(panel)
            ax.set_title(title)
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across iterations

        pbar.update(1)