## Import & Environment Setting

In [None]:
import torch
from torch import autograd
from torch.utils.data import DataLoader

import json
import gc
import numpy as np

from utils.dataset import LabeledDataset, VideoDataset
from utils.model import YoloV3, YoloLoss
from utils.postprocess import PostProcessor

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches

import os
import fnmatch
from collections import Counter

import urllib
from io import BytesIO

from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw 

from random import randint

import datetime

In [None]:
# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Floating-point type requested for tensors (torch.float is 32-bit).
dtype = torch.float

## Load Config

In [None]:
# Read the JSON configuration and pull out the two sections this notebook uses.
with open("./config/config.json", "r", encoding="utf-8") as config_file:
    main_config = json.load(config_file)

try:
    model_config = main_config['model']
    test_config = main_config['test']
except KeyError as err:
    # Raise a real exception instead of `assert False`: asserts are stripped
    # when Python runs with -O, which would let a broken config slip through.
    raise KeyError('Failed to find key on config file') from err

In [None]:
# Inject the runtime device/dtype into both config sections.
model_config.update(device = device, dtype = dtype)
# Each prediction row carries 5 leading attributes plus one entry per class
# (presumably 4 box values + 1 confidence -- confirm against utils.model).
model_config['attrib_count'] = 5 + model_config['class_count']

test_config.update(device = device, dtype = dtype)

## Build

In [None]:
# Load the fully pickled model checkpoint.
# map_location remaps stored tensors onto the active device, so a checkpoint
# saved from a CUDA run can still be opened on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# checkpoints. Recent PyTorch releases may additionally require
# weights_only=False to load a whole pickled module; confirm for your version.
model = torch.load('./model/run18/output_117.dat', map_location = device)
model.to(device)
# eval() switches the module to inference mode.
model = model.eval()

In [None]:
# Shared state for the test helpers below: the labeled dataset, a loader that
# yields it one sample at a time, the video directory and the drawing font.
test_context = {
    'dataset': LabeledDataset(test_config['set']['index'],
                              test_config['set']['image_dir'],
                              test_config['set']['label_dir']),
}
test_context['dataloader'] = DataLoader(test_context['dataset'], batch_size = 1, num_workers = 1)

test_context.update(
    video_dir = test_config['video_dir'],
    font = ImageFont.truetype("arial.ttf", 16),
)

In [None]:
def iou(bbox1, bbox2):
    """Return the intersection-over-union of two center-format boxes.

    Each box is indexable as ``(center_x, center_y, width, height, ...)``;
    only the first four entries are read.

    Returns 0 when the boxes do not overlap (including boxes that merely
    touch or have zero area), otherwise ``intersection / union``.
    """
    w1, h1 = bbox1[2], bbox1[3]
    # Size of the second box must come from bbox2 (it was previously read from bbox1).
    w2, h2 = bbox2[2], bbox2[3]

    left1, right1 = bbox1[0] - w1 / 2, bbox1[0] + w1 / 2
    left2, right2 = bbox2[0] - w2 / 2, bbox2[0] + w2 / 2
    bottom1, top1 = bbox1[1] - h1 / 2, bbox1[1] + h1 / 2
    bottom2, top2 = bbox2[1] - h2 / 2, bbox2[1] + h2 / 2

    w_intersect = min(right1, right2) - max(left1, left2)
    h_intersect = min(top1, top2) - max(bottom1, bottom2)

    # No positive-area overlap. This also guards the division below, because a
    # positive overlap implies both boxes (and thus the union) have positive area.
    if w_intersect <= 0 or h_intersect <= 0:
        return 0

    area_intersect = w_intersect * h_intersect
    area_union = w1 * h1 + w2 * h2 - area_intersect
    return area_intersect / area_union
        
def plot_pred(test_context, test_config, prediction, plot_options = None, print_accuracy = False):
    """Plot one image with the boxes produced by each requested post-processing mode.

    Args:
        test_context: Shared test state (not read by this function; kept for
            signature compatibility with the other test helpers).
        test_config: Test configuration; ``test_config['post']`` is forwarded
            to every PostProcessor call.
        prediction: Dict with keys ``'title'``, ``'pred'``, ``'image'``,
            ``'label'`` and ``'label_len'``.
        plot_options: Sequence of mode names, one subplot per entry. Known
            modes are ``'ABOVE'``, ``'NMS'``, ``'CUSTOM1'`` and ``'CUSTOM2'``;
            any other entry leaves its subplot empty. Defaults to ``['ABOVE']``.
        print_accuracy: When truthy, print the accuracy map for every mode
            except ``'ABOVE'``.
    """
    if plot_options is None:
        plot_options = ['ABOVE']

    print('title : ', prediction['title'])
    print('plot_options : ', plot_options)

    pred = prediction['pred']
    image = prediction['image']
    label = prediction['label']
    label_len = prediction['label_len']

    postProcessor = PostProcessor()
    post_config = test_config['post']

    # squeeze=False always returns a 2-D axes array, so a single panel and
    # several panels go through the same indexing code.
    _, axes = plt.subplots(1, len(plot_options), figsize=(8 * len(plot_options), 8), squeeze=False)

    for ax, plot_opt in zip(axes[0], plot_options):
        # Compare option names by value; `is` on strings only works by accident
        # of interning.
        if plot_opt == 'ABOVE':
            # Every bounding box above the threshold; no accuracy is reported.
            pred_result = postProcessor.ABOVE(pred, post_config)
        elif plot_opt in ('NMS', 'CUSTOM1', 'CUSTOM2'):
            pred_result = getattr(postProcessor, plot_opt)(pred, post_config)
            if print_accuracy:
                accuracy = postProcessor.calcAccuracyMap(label, label_len, pred_result, post_config)
                print(plot_opt + ' accuracy : ', accuracy)
        else:
            # Unknown option: leave this panel empty.
            continue

        ax.imshow(image)

        for bbox in pred_result:
            # bbox holds (center_x, center_y, width, height, ...); Rectangle
            # wants the lower-left corner plus the size.
            bounding = patches.Rectangle((bbox[0] - bbox[2] / 2, bbox[1] - bbox[3] / 2), bbox[2], bbox[3],
                                         linewidth=1, edgecolor='#FF0000', facecolor='none')
            ax.add_patch(bounding)

    plt.show()
    plt.close('all')

    # Release figure and tensor memory between images.
    gc.collect()
    torch.cuda.empty_cache()
    
        
def _iou_loss_sum(post_processor, boxes, labels, label_count, iou_threshold):
    """Sum a per-box IoU loss over `boxes` against the first `label_count` labels.

    A box whose best IoU exceeds both 0.01 and `iou_threshold` contributes
    ``1 - best_iou``; every other box contributes 1.
    """
    loss = 0
    for box in boxes:
        best_iou = max(
            (post_processor.iou(box, labels[i]) for i in range(0, label_count)),
            default = 0,
        )
        if best_iou > 0.01 and best_iou > iou_threshold:
            loss += 1 - best_iou
        else:
            loss += 1
    return loss


def test_all(test_context, test_config, plot_options = None):
    """Run the model over the whole test dataloader, print accuracy and plot every image.

    For each batch this prints the CUSTOM1 accuracy map, compares the summed
    IoU loss of the "custom" post-processing against NMS, and plots the
    prediction with `plot_pred`. At the end it prints the accumulated
    accuracy counters, the overall accuracy ``tp / (tp + fn + fp)`` (duplicates
    count as false positives) and how often each method had the lower loss.

    Args:
        test_context: Shared test state; ``test_context['dataloader']`` is iterated.
        test_config: Test configuration (``'device'``, ``'dtype'``, ``'post'``).
        plot_options: Modes forwarded to `plot_pred`.
            Defaults to ``['ABOVE', 'NMS', 'CUSTOM1']``.

    Relies on the module-level ``model``.
    """
    if plot_options is None:
        plot_options = ['ABOVE', 'NMS', 'CUSTOM1']

    postProcessor = PostProcessor()
    post_config = test_config['post']

    accs = Counter({})
    custom2_count = 0
    nms_count = 0

    with torch.no_grad():
        for idx, batches in enumerate(test_context['dataloader']):
            image = batches['image'].to(test_config['device'], dtype = test_config['dtype'])

            out1, out2, out3 = model(image)
            pred = torch.cat((out1, out2, out3), 1)

            output = {}
            output['title'] = batches['title']
            output['image'] = batches['image'].cpu().permute(0, 2, 3, 1).squeeze(0).numpy()
            output['pred'] = pred.cpu().detach().squeeze(0).numpy()
            output['label'] = batches['label'].cpu().squeeze(0).numpy()
            output['label_len'] = batches['label_len'].cpu().squeeze(0).numpy()

            pred_result = postProcessor.CUSTOM1(output['pred'], post_config)
            accuracy = postProcessor.calcAccuracyMap(output['label'], output['label_len'], pred_result, post_config)
            accs += Counter(accuracy)

            print(output['title'], ' : ', accuracy)

            # NOTE(review): the variable and the printed label say "custom2" but
            # CUSTOM1 is what gets called -- confirm which method is intended.
            pred_custom2 = postProcessor.CUSTOM1(output['pred'], post_config)
            pred_NMS = postProcessor.NMS(output['pred'], post_config)

            # The per-box losses are accumulated; previously a matching box
            # overwrote the running sum instead of adding to it.
            iou_threshold = post_config['acc_iou_threshold']
            iou_loss_sum_custom2 = _iou_loss_sum(postProcessor, pred_custom2, output['label'],
                                                 output['label_len'], iou_threshold)
            iou_loss_sum_nms = _iou_loss_sum(postProcessor, pred_NMS, output['label'],
                                             output['label_len'], iou_threshold)

            if iou_loss_sum_nms > iou_loss_sum_custom2:
                print('Custom2 : ', iou_loss_sum_custom2, ' vs ', iou_loss_sum_nms)
                custom2_count += 1
            elif iou_loss_sum_nms < iou_loss_sum_custom2:
                print('NMS : ', iou_loss_sum_custom2, ' vs ', iou_loss_sum_nms)
                nms_count += 1

            plot_pred(test_context, test_config, output, plot_options, print_accuracy = False)

            del image
            del out1, out2, out3, output
            gc.collect()
            torch.cuda.empty_cache()

        print('total : ', accs)
        tp = accs['true positive']
        fn = accs['false negative']
        fp = accs['false positive'] + accs['duplicate']
        total = tp + fn + fp
        # Accuracy is undefined when nothing was evaluated (e.g. empty dataset).
        print('accuracy : ', tp / total if total else float('nan'))
        print('custom2 vs nms : ', custom2_count, ' ', nms_count)

def test_n(test_context, test_config, n = 1, plot_options = None):
    """Run the model on the first `n` batches of the test dataloader and plot each one.

    Args:
        test_context: Shared test state; ``test_context['dataloader']`` is iterated.
        test_config: Test configuration (``'device'``, ``'dtype'``, ``'post'``).
        n: Number of batches to process; nothing is done when ``n <= 0``.
        plot_options: Modes forwarded to `plot_pred`.
            Defaults to ``['ABOVE', 'NMS', 'CUSTOM2']``.

    Relies on the module-level ``model``.
    """
    if plot_options is None:
        plot_options = ['ABOVE', 'NMS', 'CUSTOM2']
    if n <= 0:
        return

    with torch.no_grad():
        for idx, batches in enumerate(test_context['dataloader']):
            image = batches['image'].to(test_config['device'], dtype = test_config['dtype'])

            out1, out2, out3 = model(image)
            pred = torch.cat((out1, out2, out3), 1)

            output = {}
            output['title'] = batches['title']
            output['image'] = batches['image'].cpu().permute(0, 2, 3, 1).squeeze(0).numpy()
            output['pred'] = pred.cpu().detach().squeeze(0).numpy()
            output['label'] = batches['label'].cpu().squeeze(0).numpy()
            output['label_len'] = batches['label_len'].cpu().squeeze(0).numpy()

            # plot_pred takes (test_context, test_config, prediction, ...);
            # test_config used to be omitted, shifting every argument by one.
            plot_pred(test_context, test_config, output, plot_options, print_accuracy = True)

            del image
            del out1, out2, out3, output
            gc.collect()
            torch.cuda.empty_cache()

            if idx + 1 >= n:
                return
            
def test_url_image(test_context, test_config, url, crop = False, crop_size = (0, 0, 1920, 1080), 
                   plot_options = None):
    """Download an image, run the model on it and plot the predictions.

    Args:
        test_context: Shared test state, forwarded to `plot_pred`.
        test_config: Test configuration (``'device'``, ``'dtype'``, ``'post'``).
        url: Image URL; surrounding whitespace (e.g. a trailing newline from
            a links file) is stripped.
        crop: When truthy, crop the image to `crop_size` before resizing.
        crop_size: ``(left, upper, right, lower)`` box passed to ``Image.crop``.
        plot_options: Modes forwarded to `plot_pred`.
            Defaults to ``['ABOVE', 'NMS', 'CUSTOM2']``.

    Relies on the module-level ``model``. Errors raised by ``urlopen``
    propagate to the caller.
    """
    # `import urllib` alone does not guarantee the `request` submodule is loaded.
    import urllib.request

    if plot_options is None:
        plot_options = ['ABOVE', 'NMS', 'CUSTOM2']

    url = url.strip()

    # PIL reads lazily, so all pixel work happens while the response is open.
    with urllib.request.urlopen(url) as response:
        image = Image.open(response)
        if crop:
            image = image.crop(crop_size)
        # convert('RGB') drops an alpha channel and also lets grayscale or
        # palette images through, which plain `[:, :, :3]` slicing could not.
        image = np.array(image.resize((608, 608), Image.BILINEAR).convert('RGB'))

    image = torch.from_numpy(image).unsqueeze(0).float() / 255
    image = image.permute(0, 3, 1, 2)
    image = image.to(test_config['device'], dtype = test_config['dtype'])

    # Inference only: skip autograd bookkeeping like the other test helpers.
    with torch.no_grad():
        out1, out2, out3 = model(image)
        pred = torch.cat((out1, out2, out3), 1)

    output = {}
    output['title'] = url if len(url) < 80 else url[0:77] + '...' 
    output['image'] = image.permute(0, 2, 3, 1).cpu().squeeze(0).numpy()
    output['pred'] = pred.cpu().detach().squeeze(0).numpy()
    output['label'] = []
    output['label_len'] = 0

    # plot_pred takes (test_context, test_config, prediction, ...);
    # test_config used to be omitted, shifting every argument by one.
    plot_pred(test_context, test_config, output, plot_options)

    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Evaluate the whole test set with the default plot options.
test_all(test_context, test_config)

In [None]:
def plot_kernel(conv_layer, group_thres = 2):
    """Show every kernel of a Conv2d layer and highlight near-duplicate kernels.

    Each kernel is scaled by its own largest absolute weight. Two kernels whose
    scaled versions differ by less than `group_thres` (sum of absolute
    differences) are put in the same group, and each group gets a random frame
    colour (left/right/bottom spines and ticks). Kernels whose largest absolute
    weight is below 0.3 get a red top spine.

    Args:
        conv_layer: Layer to inspect; anything that is not a
            ``torch.nn.Conv2d`` prints a message and returns.
        group_thres: Maximum summed absolute difference for two scaled kernels
            to count as similar.
    """
    if not isinstance(conv_layer, torch.nn.modules.conv.Conv2d):
        print('layer not Conv2d')
        return

    # np.float was removed in NumPy 1.24; it was an alias of the builtin
    # float, i.e. float64.
    kernels = conv_layer.weight.data.detach().cpu().numpy().astype(np.float64)
    cnt = kernels.shape[0]

    # Scale every kernel once up front instead of inside the pair loop.
    peaks = np.abs(kernels).reshape(cnt, -1).max(axis=1)
    # An all-zero kernel is left as zeros rather than divided by zero.
    safe_peaks = np.where(peaks > 0, peaks, 1.0)
    normalized = kernels / safe_peaks.reshape((cnt,) + (1,) * (kernels.ndim - 1))

    similar_group = []
    for i in range(0, cnt):
        for j in range(i + 1, cnt):
            diff = np.sum(np.abs(normalized[i] - normalized[j]))

            if diff < group_thres:
                new_group = True
                for g in similar_group:
                    if i in g:
                        new_group = False
                        g.add(j)
                if new_group:
                    similar_group.append({i, j})
    print(similar_group)

    group_color = [
        '#%02X%02X%02X' % (randint(0, 255), randint(0, 255), randint(0, 255))
        for _ in similar_group
    ]
    print(group_color)

    # Round the row count up and keep the axes array 2-D so layers with fewer
    # than 16 kernels, or a count that is not a multiple of 8, still plot.
    cols = 8
    rows = -(-cnt // cols)
    fig, ax = plt.subplots(rows, cols, figsize=(16, 16), squeeze=False)
    for i in range(0, rows * cols):
        cell = ax[i // cols, i % cols]
        if i < cnt:
            # Map the scaled weights from [-1, 1] to [0, 1] for display.
            # NOTE(review): imshow receives the kernel with its original axis
            # order -- confirm that is the intended image layout.
            cell.imshow((normalized[i] + 1) / 2)
        else:
            cell.axis('off')

    for color, group in zip(group_color, similar_group):
        for i in group:
            cell = ax[i // cols, i % cols]
            # The top spine is left alone; it is reserved for the weak-kernel mark.
            cell.spines['left'].set_color(color)
            cell.spines['right'].set_color(color)
            cell.spines['bottom'].set_color(color)

            cell.tick_params(axis='x', colors = color)
            cell.tick_params(axis='y', colors = color)

    for i in range(0, cnt):
        if peaks[i] < 0.3:
            ax[i // cols, i % cols].spines['top'].set_color('red')

    plt.show()
    
            

In [None]:
# Visualise the kernels of the layer at darknet.baseline[0].body[0].
plot_kernel(model.darknet.baseline[0].body[0])

In [None]:
#with open("../data/detect/links/url_links.txt", "r") as url_files:
#    lines = url_files.readlines()
#    for url in lines:
#        try:
#            test_context['debug_level'] = 1
#            test_url_image(test_context, test_config, url)
#        except urllib.error.HTTPError:
#            print('http error')

In [None]:
#test_video(test_context, 'toaru_majutsu_op', start = 0, size=(1280, 720), plot_opt = 'CUSTOM2')

In [None]:
def test_video2(context, config, target_name, size=(1920, 1080), splits = (1, 1), batch_multiplier = 1, post_opt = 'CUSTOM1'):
    """Run detection over the frames of a video directory and save annotated PNGs.

    Frames are read from ``context['video_dir'] + target_name + '/'`` through
    `VideoDataset`, each frame being split into ``splits[0] * splits[1]`` crops.
    Predictions of all crops of a frame are mapped back to frame coordinates,
    merged, post-processed and drawn onto the original frame, which is written
    to ``context['video_dir'] + 'o_' + target_name + '/'``.

    Args:
        context: Shared test state; reads ``'video_dir'``, ``'font'`` and
            ``'post_conf_threshold'``.
        config: Test configuration; reads ``'device'`` and ``'dtype'``.
        target_name: Name of the frame sub-directory.
        size: Source frame size forwarded to `VideoDataset` as ``from_size``.
        splits: Crop grid forwarded to `VideoDataset`.
        batch_multiplier: Number of whole frames per batch.
        post_opt: Name of the PostProcessor method applied to the merged
            predictions (default ``'CUSTOM1'``).

    Relies on the module-level ``model``.
    """
    split_count = splits[0] * splits[1]
    batch_size = batch_multiplier * split_count

    # Use the `context` argument rather than the global test_context.
    image_dir = context['video_dir'] + target_name + '/'
    output_dir = context['video_dir'] + 'o_' + target_name + '/'
    # Saving fails if the output directory is missing, so create it up front.
    os.makedirs(output_dir, exist_ok = True)

    dataset = VideoDataset(image_dir, splits = splits, from_size = size, to_size = (608, 608))
    loader = DataLoader(dataset, batch_size = batch_size, num_workers = 1)
    postprocess = PostProcessor()
    run_post = getattr(postprocess, post_opt)

    with torch.no_grad():
        for batches in loader:
            image = batches['image'].to(config['device'], dtype = config['dtype'])

            out1, out2, out3 = model(image)
            pred = torch.cat((out1, out2, out3), 1)

            # The last batch may hold an incomplete frame; only whole frames
            # (split_count crops each) are processed.
            for i in range(0, batches['image'].shape[0] // split_count):
                first = i * split_count
                title = batches['title'][first]
                print(title)

                # 1. recover prediction coordinates (in place, per crop)
                for j in range(0, split_count):
                    crop_size = batches['crop_size'][first + j]
                    crop_bbox = batches['crop_bbox'][first + j]

                    # Scale from the 608-pixel network input back to the crop.
                    split_ratio = float(max(crop_size[0], crop_size[1])) / 608

                    crop_mult = torch.as_tensor([split_ratio] * 4).to(config['device'], dtype = config['dtype'])
                    # Only the centre is shifted by the crop origin; width/height are not.
                    crop_add = torch.as_tensor([crop_bbox[0],
                                                crop_bbox[1],
                                                0,
                                                0]).to(config['device'], dtype = config['dtype'])

                    pred[first + j][:, :4].mul_(crop_mult).add_(crop_add)

                # 2. concat splits and keep rows above the confidence threshold
                # NOTE(review): other helpers in this file pass test_config['post']
                # to the post-processor; here the threshold and settings come
                # from `context` -- confirm which dict is intended.
                pred_result = torch.cat(tuple(pred[first : first + split_count]), dim = 0).cpu().detach()
                pred_result = np.copy(pred_result[np.where(pred_result[:, 4] > context['post_conf_threshold'])])

                # 3. run the selected post-processing
                pred_result = run_post(pred_result, context)
                print(len(pred_result), 'boxes found')

                # 4. to pillow
                image_to_save = batches['image_og'][first].cpu().numpy()
                image_to_save = Image.fromarray(image_to_save.astype('uint8'), 'RGB')

                # 5. draw each box with its confidence printed below it
                draw = ImageDraw.Draw(image_to_save)
                for bbox in pred_result:
                    left = int(bbox[0] - bbox[2] / 2)
                    right = int(bbox[0] + bbox[2] / 2)
                    up = int(bbox[1] - bbox[3] / 2)
                    down = int(bbox[1] + bbox[3] / 2)

                    draw.text((left, down), str(bbox[4]), (255, 0, 0), font=context['font'])
                    draw.rectangle(xy=[left, up, right, down], outline=(255, 0, 0))

                # 6. save as PNG next to the input, whatever the source extension length
                image_to_save.save(output_dir + os.path.splitext(title)[0] + '.png')

            # pred_result is not deleted explicitly: it is unbound when a batch
            # contains no complete frame.
            del image, batches
            del out1, out2, out3, pred
            gc.collect()
            torch.cuda.empty_cache()
            

In [None]:
# Record the wall-clock time before and after the (currently disabled) video run.
start_time = datetime.datetime.now().time()
print('start time : ', start_time)
#test_video2(test_context, test_config, 'imas_ready', size=(1920, 1080), splits = (2, 1), batch_multiplier = 1, post_opt = 'CUSTOM1')
end_time = datetime.datetime.now().time()

# Print both timestamps together once the run has finished.
print('done')
print('start time : ', start_time)
print('end time : ', end_time)