## Imports & Environment Setup

In [None]:
import torch
from torch import autograd
from torch.utils.data import DataLoader

import json
import gc
import numpy as np

from utils.dataset import DetectionFolder
from utils.model import YoloV3, YoloLoss
from utils.postprocess import PostProcessor

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches

import os
import fnmatch
from collections import Counter

import urllib
from io import BytesIO

from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw 

In [None]:
# Prefer the default CUDA device when one is available, otherwise run on the CPU.
# Floating-point tensors are cast to `dtype` (32-bit float) before inference.
dtype = torch.float
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Load Config

In [None]:
# Load the main JSON config; both its 'model' and 'test' sections are required.
try:
    with open("./config/config.json", "r", encoding="utf-8") as config_file:
        main_config = json.load(config_file)
except (OSError, ValueError) as err:
    # OSError: missing/unreadable file; ValueError covers JSONDecodeError and
    # UnicodeDecodeError. Raise rather than `assert False`, which `python -O` strips.
    raise RuntimeError('Failed to load config file') from err

try:
    model_config = main_config['model']
    test_config = main_config['test']
except KeyError as err:
    raise RuntimeError('Failed to find key on config file') from err

In [None]:
# Share the runtime device/dtype with both config sections. 'attrib_count' is
# the per-prediction attribute count: 5 fixed attributes plus one per class
# (presumably x, y, w, h, objectness — confirm against utils.model).
shared_runtime = {'device': device, 'dtype': dtype}
model_config.update(shared_runtime)
model_config['attrib_count'] = model_config['class_count'] + 5
test_config.update(shared_runtime)

## Build Model & Test Context

In [None]:
# Load the whole pickled model object (it is called directly and moved with
# .to()/.eval() below, so this is not a bare state_dict).
# map_location remaps tensors that were saved on a GPU, so the checkpoint
# also loads on CPU-only hosts.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# On torch >= 2.6 `weights_only` defaults to True, which rejects full pickled
# modules; pass weights_only=False there if loading fails.
model = torch.load('./model/model_r_255.dat', map_location=device)
model = model.to(device)
model = model.eval()

In [None]:
# Everything the test/plot helpers below read is collected in one dict.
test_context = { }

test_context['dataset'] = DetectionFolder(test_config['test_list'], test_config['test_image'], test_config['test_label'])
# batch_size=1: the helpers squeeze(0) every batch and treat it as a single sample.
test_context['dataloader'] = DataLoader(test_context['dataset'], batch_size = 1, num_workers = 1)

# Post-processing and accuracy-matching thresholds, copied verbatim from the config.
test_context['post_conf_threshold'] = test_config['post_conf_threshold']
test_context['post_iou_threshold'] = test_config['post_iou_threshold']
test_context['acc_iou_threshold'] = test_config['acc_iou_threshold']

test_context['video_dir'] = test_config['video_dir']
try:
    test_context['font'] = ImageFont.truetype("arial.ttf", 16)
except OSError:
    # Arial is not installed on many non-Windows hosts; fall back to Pillow's
    # built-in font so box labels can still be drawn.
    test_context['font'] = ImageFont.load_default()

test_context['debug_level'] = test_config['debug_level']

In [None]:
def iou(bbox1, bbox2):
    """Return the intersection-over-union of two center-format boxes.

    Each box is indexable as (center_x, center_y, width, height, ...); only
    the first four entries are read.

    Returns 0 when the boxes do not overlap or the union area is not positive
    (e.g. two zero-sized boxes), instead of dividing by zero.
    """
    x1, y1, w1, h1 = bbox1[0], bbox1[1], bbox1[2], bbox1[3]
    # Previously width/height for box 2 were read from bbox1 by mistake.
    x2, y2, w2, h2 = bbox2[0], bbox2[1], bbox2[2], bbox2[3]

    # Overlap extent along each axis: nearest far edge minus farthest near edge.
    w_intersect = min(x1 + w1 / 2, x2 + w2 / 2) - max(x1 - w1 / 2, x2 - w2 / 2)
    h_intersect = min(y1 + h1 / 2, y2 + h2 / 2) - max(y1 - h1 / 2, y2 - h2 / 2)
    if w_intersect <= 0 or h_intersect <= 0:
        return 0

    area_intersect = w_intersect * h_intersect
    union = w1 * h1 + w2 * h2 - area_intersect
    if union <= 0:
        return 0
    return area_intersect / union
        
def _confidence_color(confidence, threshold):
    """Return a '#RR0000' colour whose red channel grows with ``confidence``.

    Confidence is mapped linearly from [threshold, 1] onto [0, 255] and then
    clamped, so scores outside that range still give a valid hex colour.
    """
    if threshold >= 1:
        level = 255
    else:
        level = int(255 * (confidence - threshold) / (1 - threshold))
    level = max(0, min(255, level))
    return '#%02x0000' % level


def plot_pred(test_context, prediction, plot_options=('ABOVE',), print_accuracy=False):
    """Show the image with predicted boxes, one subplot per post-processing mode.

    Args:
        test_context: dict with at least 'debug_level' and
            'post_conf_threshold'; also passed through to PostProcessor.
        prediction: dict with 'title', 'pred', 'image', 'label', 'label_len'.
        plot_options: sequence of mode names among 'ABOVE', 'NMS', 'CUSTOM1',
            'CUSTOM2'. An unknown name leaves its subplot empty.
        print_accuracy: when true, print ``calcAccuracyMap`` results for every
            mode except 'ABOVE'.

    Each box is drawn as a rectangle centred at (bbox[0], bbox[1]) with size
    (bbox[2], bbox[3]); bbox[4] sets the red intensity of the outline.
    """
    if test_context['debug_level'] >= 1:
        print('title : ', prediction['title'])
        print('plot_options : ', plot_options)

    pred = prediction['pred']
    image = prediction['image']
    label = prediction['label']
    label_len = prediction['label_len']

    post_processor = PostProcessor()
    conf_threshold = test_context['post_conf_threshold']

    # squeeze=False always yields a 2-D array of Axes, so one subplot and
    # many subplots are handled the same way.
    _, axes = plt.subplots(1, len(plot_options), figsize=(8 * len(plot_options), 8), squeeze=False)

    for ax, plot_opt in zip(axes[0], plot_options):
        # Compare with `in`/`==`, not `is`: string identity is an interpreter detail.
        if plot_opt not in ('ABOVE', 'NMS', 'CUSTOM1', 'CUSTOM2'):
            continue
        pred_result = getattr(post_processor, plot_opt)(pred, test_context)
        if print_accuracy and plot_opt != 'ABOVE':
            accuracy = post_processor.calcAccuracyMap(label, label_len, pred_result, test_context)
            print(plot_opt + ' accuracy : ', accuracy)

        ax.imshow(image)
        for bbox in pred_result:
            bounding = patches.Rectangle(
                (bbox[0] - bbox[2] / 2, bbox[1] - bbox[3] / 2), bbox[2], bbox[3],
                linewidth=1, edgecolor=_confidence_color(bbox[4], conf_threshold),
                facecolor='none')
            ax.add_patch(bounding)

    plt.show()
    plt.close('all')

    gc.collect()
    torch.cuda.empty_cache()

    
def plot_save(test_context, target_name, prediction, size = (1920, 1080), plot_opt = 'CUSTOM2'):
    """Draw post-processed boxes on a stored video frame and save it as PNG.

    The frame is read from ``video_dir + target_name + '/' + title``. The
    annotated copy is written to ``video_dir + 'o_' + target_name + '/'``
    under the same base name with a ``.png`` extension. The output directory
    is created if missing.

    Args:
        test_context: dict with 'debug_level', 'video_dir' and 'font', plus
            whatever the PostProcessor methods read.
        target_name: name of the frame sub-directory under 'video_dir'.
        prediction: dict with at least 'title' (frame file name) and 'pred';
            it is passed through ``PostProcessor.resize`` first.
        size: handed to ``PostProcessor.resize`` as ``to_size`` (the defaults
            suggest (width, height) — confirm).
        plot_opt: one of 'ABOVE', 'NMS', 'CUSTOM1', 'CUSTOM2'. Any other value
            returns without writing anything.
    """
    if test_context['debug_level'] >= 1:
        print('title : ', prediction['title'])

    post_processor = PostProcessor()
    # NOTE(review): presumably rescales predictions from the 608x608 network
    # input to `size`; confirm against PostProcessor.resize.
    prediction = post_processor.resize(from_size = (608, 608), to_size = size, prediction = prediction)

    # Compare with `in`, not `is`: string identity is an interpreter detail.
    if plot_opt not in ('ABOVE', 'NMS', 'CUSTOM1', 'CUSTOM2'):
        return
    pred_result = getattr(post_processor, plot_opt)(prediction['pred'], test_context)

    title = prediction['title']
    frame_path = test_context['video_dir'] + target_name + '/' + title
    output_dir = test_context['video_dir'] + 'o_' + target_name + '/'
    # splitext handles any extension length (the old [:-3] slice assumed 3 chars).
    output_path = output_dir + os.path.splitext(title)[0] + '.png'

    font = test_context['font']
    red = (255, 0, 0)

    # `with` closes the source file once the annotated copy is saved.
    with Image.open(frame_path) as image:
        draw = ImageDraw.Draw(image)

        print(len(pred_result), 'boxes found')
        for bbox in pred_result:
            left = int(bbox[0] - bbox[2] / 2)
            right = int(bbox[0] + bbox[2] / 2)
            up = int(bbox[1] - bbox[3] / 2)
            down = int(bbox[1] + bbox[3] / 2)

            # Label goes at the top-left corner. If that corner is off-image,
            # use the right and/or bottom edge instead so the text stays visible
            # (previously the left<0 and up<0 case still used `up`).
            text_x = left if left >= 0 else right
            text_y = up if up >= 0 else down
            draw.text((text_x, text_y), str(bbox[4]), red, font=font)

            draw.rectangle(xy=[left, up, right, down], outline=red)

        os.makedirs(output_dir, exist_ok=True)
        image.save(output_path)

    gc.collect()
    torch.cuda.empty_cache()
    
        
def test_all(test_context, test_config, plot_options = ('ABOVE', 'NMS', 'CUSTOM2')):
    """Run the module-level ``model`` over the whole test dataloader.

    Each sample is plotted via ``plot_pred`` and scored with the CUSTOM2
    post-processor. Afterwards the summed counts are printed, along with
    accuracy = TP / (TP + FN + FP + duplicates). The accuracy prints as
    ``nan`` when that denominator is zero instead of raising
    ZeroDivisionError.

    Args:
        test_context: dict holding 'dataloader' plus the keys the
            post-processor and ``plot_pred`` read.
        test_config: dict with 'device' and 'dtype' for the model input.
        plot_options: post-processing modes to plot per sample.
    """
    accs = Counter()

    with torch.no_grad():
        for batches in test_context['dataloader']:
            # Only the image goes to the model; labels stay on the CPU for scoring.
            image = batches['image'].to(test_config['device'], dtype = test_config['dtype'])

            out1, out2, out3 = model(image)
            pred = torch.cat((out1, out2, out3), 1)

            # batch_size is 1, so squeeze(0) drops the batch axis.
            output = {
                'title': batches['title'],
                'image': batches['image'].cpu().permute(0, 2, 3, 1).squeeze(0).numpy(),
                'pred': pred.cpu().squeeze(0).numpy(),
                'label': batches['label'].cpu().squeeze(0).numpy(),
                'label_len': batches['label_len'].cpu().squeeze(0).numpy(),
            }

            plot_pred(test_context, output, plot_options, print_accuracy = True)

            postProcessor = PostProcessor()
            pred_result = postProcessor.CUSTOM2(output['pred'], test_context)
            accuracy = postProcessor.calcAccuracyMap(output['label'], output['label_len'], pred_result, test_context)
            accs = accs + Counter(accuracy)

            del image, out1, out2, out3, pred, output
            gc.collect()
            torch.cuda.empty_cache()

    print('total : ', accs)
    tp = accs['true positive']
    fn = accs['false negative']
    fp = accs['false positive'] + accs['duplicate']
    denominator = tp + fn + fp
    print('accuracy : ', tp / denominator if denominator else float('nan'))

def test_n(test_context, test_config, n = 1, plot_options = ('ABOVE', 'NMS', 'CUSTOM2')):
    """Run the module-level ``model`` on the first ``n`` test batches and plot each.

    Does nothing when ``n <= 0`` (previously one batch was still processed).

    Args:
        test_context: dict holding 'dataloader' plus the keys ``plot_pred`` reads.
        test_config: dict with 'device' and 'dtype' for the model input.
        n: number of batches to process.
        plot_options: post-processing modes to plot per sample.
    """
    if n <= 0:
        return

    with torch.no_grad():
        for idx, batches in enumerate(test_context['dataloader']):
            # Only the image goes to the model; labels stay on the CPU for plotting.
            image = batches['image'].to(test_config['device'], dtype = test_config['dtype'])

            out1, out2, out3 = model(image)
            pred = torch.cat((out1, out2, out3), 1)

            # batch_size is 1, so squeeze(0) drops the batch axis.
            output = {
                'title': batches['title'],
                'image': batches['image'].cpu().permute(0, 2, 3, 1).squeeze(0).numpy(),
                'pred': pred.cpu().squeeze(0).numpy(),
                'label': batches['label'].cpu().squeeze(0).numpy(),
                'label_len': batches['label_len'].cpu().squeeze(0).numpy(),
            }

            plot_pred(test_context, output, plot_options, print_accuracy = True)

            del image, out1, out2, out3, pred, output
            gc.collect()
            torch.cuda.empty_cache()

            if idx + 1 >= n:
                return
            
def test_url_image(test_context, test_config, url, crop = False, crop_size = (0, 0, 1920, 1080),
                   plot_options = ('ABOVE', 'NMS', 'CUSTOM2')):
    """Download an image, run the module-level ``model`` on it and plot the result.

    Args:
        test_context: dict with the keys ``plot_pred`` reads.
        test_config: dict with 'device' and 'dtype' for the model input.
        url: image URL. Surrounding whitespace (e.g. a newline left by
            ``readlines()``) is stripped before requesting it.
        crop: when true, crop the downloaded image to ``crop_size`` first.
        crop_size: PIL crop box (left, upper, right, lower).
        plot_options: post-processing modes to plot.
    """
    # `import urllib` alone does not load the `request` submodule.
    import urllib.request

    url = url.strip()
    with urllib.request.urlopen(url) as response:
        # Buffer the payload so the connection is closed before decoding.
        image = Image.open(BytesIO(response.read()))
    if crop:
        image = image.crop(crop_size)
    # convert('RGB') also handles grayscale/palette images, which have no
    # channel axis and broke the old `[:, :, :3]` slice.
    image = image.convert('RGB').resize((608, 608), Image.BILINEAR)

    tensor = torch.from_numpy(np.array(image)).unsqueeze(0).float() / 255
    # HWC -> CHW, with a leading batch axis of 1.
    tensor = tensor.permute(0, 3, 1, 2).to(test_config['device'], dtype = test_config['dtype'])

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        out1, out2, out3 = model(tensor)
        pred = torch.cat((out1, out2, out3), 1)

    output = {
        'title': url if len(url) < 80 else url[0:77] + '...',
        'image': tensor.permute(0, 2, 3, 1).cpu().squeeze(0).numpy(),
        'pred': pred.cpu().squeeze(0).numpy(),
        # No ground truth exists for a downloaded image.
        'label': [],
        'label_len': 0,
    }

    plot_pred(test_context, output, plot_options)

    gc.collect()
    torch.cuda.empty_cache()
    
def test_video(test_context, target_name, start = 0, size=(1920, 1080), plot_opt = 'CUSTOM2', config = None):
    """Run the module-level ``model`` on every PNG frame in a directory and save annotated copies.

    Frames are the files in ``video_dir + target_name + '/'``, sorted by name.
    Processing starts at index ``start`` of that sorted listing, counted
    before non-PNG files are skipped. Each frame is resized to 608x608 for
    the model and then handed to ``plot_save``.

    Args:
        test_context: dict with 'video_dir' plus the keys ``plot_save`` reads.
        target_name: frame sub-directory under 'video_dir'.
        start: index into the sorted directory listing to begin at.
        size: forwarded to ``plot_save``.
        plot_opt: post-processing mode forwarded to ``plot_save``.
        config: dict with 'device' and 'dtype'. Defaults to the module-level
            ``test_config`` that this function previously read implicitly.
    """
    cfg = test_config if config is None else config
    frame_dir = test_context['video_dir'] + target_name + '/'
    frame_names = sorted(os.listdir(frame_dir))[start:]

    for image_name in frame_names:
        if not fnmatch.fnmatch(image_name, '*.png'):
            continue
        print(image_name)

        # `with` closes the frame file; convert('RGB') copes with non-RGB PNGs.
        with Image.open(frame_dir + image_name) as frame:
            resized = frame.convert('RGB').resize((608, 608), Image.BILINEAR)

        tensor = torch.from_numpy(np.array(resized)).unsqueeze(0).float() / 255
        # HWC -> CHW, with a leading batch axis of 1.
        tensor = tensor.permute(0, 3, 1, 2).to(cfg['device'], dtype = cfg['dtype'])

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            out1, out2, out3 = model(tensor)
            pred = torch.cat((out1, out2, out3), 1)

        output = {
            'title': image_name,
            'image': tensor.permute(0, 2, 3, 1).cpu().squeeze(0).numpy(),
            'pred': pred.cpu().squeeze(0).numpy(),
            'label': [],
            'label_len': 0,
        }

        plot_save(test_context, target_name, output, size = size, plot_opt = plot_opt)

        del tensor, out1, out2, out3, pred, output
        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Evaluate (and plot) every sample in the test split.
test_all(test_context=test_context, test_config=test_config)

In [None]:
#with open("../data/detect/links/url_links.txt", "r") as url_files:
#    lines = url_files.readlines()
#    for url in lines:
#        try:
#            test_context['debug_level'] = 1
#            test_url_image(test_context, test_config, url)
#        except urllib.error.HTTPError:
#            print('http error')

In [None]:
# Annotate every PNG frame of the 'toaru_majutsu_op' clip with CUSTOM2 boxes at 1280x720.
test_video(test_context=test_context, target_name='toaru_majutsu_op', start=0, size=(1280, 720), plot_opt='CUSTOM2')