## Imports & Environment Setup

In [None]:
import torch
from torch import autograd
from torch.utils.data import DataLoader

import json
import gc
import numpy as np

from utils.dataset import DetectionFolder
from utils.model import YoloV3, YoloLoss, PostProcessor

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches

import urllib
import urllib.error
import urllib.request
from io import BytesIO
from PIL import Image

In [None]:
# Prefer the GPU whenever CUDA is available; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Floating-point dtype used for every tensor fed to the model.
dtype = torch.float

## Load Config

In [None]:
# Load the top-level JSON config and extract the 'model' and 'test' sections.
# Real exceptions are raised (instead of ``assert False``) so the checks still
# run under ``python -O``; the underlying error is chained for debugging.
try:
    with open("./config/config.json", "r", encoding="utf-8") as config_file:
        main_config = json.load(config_file)
except (OSError, ValueError) as err:  # ValueError covers JSON/Unicode decode errors
    raise RuntimeError('Failed to load config file') from err

try:
    model_config = main_config['model']
    test_config = main_config['test']
except KeyError as err:
    raise KeyError('Failed to find key on config file') from err

In [None]:
# Attach the runtime device/dtype to both config sections so later code can
# read them from the config dicts.
model_config.update(device=device, dtype=dtype)
# Per-prediction attribute count: 5 fixed entries plus one per class
# (presumably x, y, w, h, objectness — TODO confirm against utils.model).
model_config['attrib_count'] = model_config['class_count'] + 5
test_config.update(device=device, dtype=dtype)

## Build Model & Test Data

In [None]:
# Load the whole pickled model object (it is called directly later, so this is
# not a bare state_dict) and switch it to inference mode.
# map_location remaps tensors saved on another device (e.g. a CUDA checkpoint)
# onto the device chosen above, so a CPU-only machine can still load it.
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted
# files. On torch >= 2.6 loading a full module also requires weights_only=False.
model = torch.load('./model/model_250.dat', map_location=device)
model.to(device)
model = model.eval()

In [None]:
# Everything the test/plot helpers need: the dataset, a batch-size-1 loader
# over it, and the thresholds/verbosity copied from the 'test' config section.
test_context = {
    'dataset': DetectionFolder(
        test_config['test_list'],
        test_config['test_image'],
        test_config['test_label'],
    ),
}
test_context['dataloader'] = DataLoader(test_context['dataset'], batch_size=1, num_workers=1)
test_context.update(
    {key: test_config[key] for key in ('conf_threshold', 'iou_threshold', 'debug_level')}
)

In [None]:
def iou(bbox1, bbox2):
    """Return the intersection-over-union of two centre-format boxes.

    Args:
        bbox1, bbox2: Indexable boxes whose first four entries are
            ``(x_center, y_center, width, height)``. Extra trailing entries
            (e.g. a confidence score) are ignored.

    Returns:
        IoU in ``[0, 1]``. Returns 0.0 when the boxes do not overlap (touching
        edges count as no overlap) or when the union area is not positive.
    """
    x1, y1, w1, h1 = bbox1[0], bbox1[1], bbox1[2], bbox1[3]
    # Bug fix: the second box's size previously came from bbox1.
    x2, y2, w2, h2 = bbox2[0], bbox2[1], bbox2[2], bbox2[3]

    w_intersect = min(x1 + w1 / 2, x2 + w2 / 2) - max(x1 - w1 / 2, x2 - w2 / 2)
    h_intersect = min(y1 + h1 / 2, y2 + h2 / 2) - max(y1 - h1 / 2, y2 - h2 / 2)
    if w_intersect <= 0 or h_intersect <= 0:
        return 0.0

    area_intersect = w_intersect * h_intersect
    union = w1 * h1 + w2 * h2 - area_intersect
    # Guard against degenerate (zero-area) boxes, which would divide by zero.
    if union <= 0:
        return 0.0
    return area_intersect / union
        
def _draw_box(ax, bbox, color):
    """Add an unfilled rectangle for a centre-format ``(x, y, w, h, ...)`` box to *ax*."""
    corner = (bbox[0] - bbox[2] / 2, bbox[1] - bbox[3] / 2)
    ax.add_patch(patches.Rectangle(corner, bbox[2], bbox[3],
                                   linewidth=1, edgecolor=color, facecolor='none'))


def _nms_boxes(boxes, iou_threshold):
    """Greedy non-maximum suppression; return the kept rows, highest confidence first."""
    kept = []
    for candidate in np.flip(np.argsort(boxes[:, 4])):
        # Keep a box only if it overlaps no already-kept box beyond the threshold.
        if not any(iou(boxes[candidate], boxes[k]) > iou_threshold for k in kept):
            kept.append(candidate)
    return [boxes[k] for k in kept]


def _group_mean_boxes(boxes, iou_threshold):
    """Cluster boxes (highest confidence first) and return each cluster's mean row.

    A box joins the existing cluster whose current mean it overlaps most
    (IoU above *iou_threshold*); otherwise it starts a new cluster.
    """
    groups = []
    for candidate in np.flip(np.argsort(boxes[:, 4])):
        box = boxes[candidate]
        matched, best_group, best_iou = False, 0, 0
        for group_idx, members in enumerate(groups):
            overlap = iou(box, np.mean(members, axis=0))
            if overlap > iou_threshold:
                matched = True
                if overlap > best_iou:
                    best_iou, best_group = overlap, group_idx
        if matched:
            groups[best_group].append(box)
        else:
            groups.append([box])
    # Clusters differ in size, so they stay a Python list: np.array() on a
    # ragged nested list raises ValueError on NumPy >= 1.24.
    return [np.mean(members, axis=0) for members in groups]


def _merge_weighted_boxes(boxes, iou_threshold):
    """Iteratively merge boxes until no pair overlaps beyond *iou_threshold*.

    At each step the most confident remaining box is either merged with its
    best-overlapping partner (a confidence-weighted average of the two full
    rows, which is pushed back into the pool) or, if it has no such partner,
    emitted as a final box.
    """
    pending = boxes[np.flip(np.argsort(boxes[:, 4]))]
    merged = []
    while len(pending) > 0:
        head = pending[0]
        best_iou, best_idx = 0, 0
        for idx in range(1, len(pending)):
            overlap = iou(head, pending[idx])
            if overlap > iou_threshold and overlap > best_iou:
                best_iou, best_idx = overlap, idx

        if best_idx != 0:
            partner = pending[best_idx]
            weight_sum = head[4] + partner[4]
            avg = (head * head[4] / weight_sum) + (partner * partner[4] / weight_sum)
            pending = np.append(np.delete(pending, [0, best_idx], axis=0), [avg], axis=0)
        else:
            merged.append(head)
            pending = pending[1:]

        # Re-sort so the next iteration again starts from the most confident box.
        if len(pending) > 0:
            pending = pending[np.flip(np.argsort(pending[:, 4]))]
    return merged


def plot_pred(test_context, prediction, plot_options=None):
    """Plot one image's raw predictions, one subplot per box-selection method.

    Args:
        test_context: Dict providing 'conf_threshold', 'iou_threshold' and
            'debug_level'.
        prediction: Dict with 'title', 'image' (anything ``imshow`` accepts) and
            'pred', a 2-D array whose columns 0-3 are a centre-format box
            (x, y, w, h) and column 4 is the confidence score.
        plot_options: Sequence of methods, each drawn on its own subplot
            (defaults to ``['ABOVE']``):
            'ABOVE'   - every box above the confidence threshold, shaded by rank;
            'NMS'     - greedy non-maximum suppression;
            'CUSTOM1' - cluster overlapping boxes and draw each cluster's mean;
            'CUSTOM2' - iteratively merge overlapping boxes by weighted average.

    Raises:
        ValueError: If *plot_options* contains an unknown method name.
    """
    if plot_options is None:
        plot_options = ['ABOVE']

    if test_context['debug_level'] >= 1:
        print('title : ', prediction['title'])
        print('plot_options : ', plot_options)

    # Compare by value: the old `is` checks only worked for interned literals.
    valid_options = ('ABOVE', 'NMS', 'CUSTOM1', 'CUSTOM2')
    unknown = [opt for opt in plot_options if opt not in valid_options]
    if unknown:
        raise ValueError(f'unknown plot option(s): {unknown}')
    if not plot_options:
        return

    pred = prediction['pred']
    image = prediction['image']
    iou_threshold = test_context['iou_threshold']
    above_thres = pred[pred[:, 4] > test_context['conf_threshold']]

    # squeeze=False always yields a 2-D axes array, even for a single subplot.
    _, axes = plt.subplots(1, len(plot_options),
                           figsize=(8 * len(plot_options), 8), squeeze=False)

    for ax, plot_opt in zip(axes[0], plot_options):
        ax.imshow(image)
        if plot_opt == 'ABOVE':
            order = np.flip(np.argsort(above_thres[:, 4]))
            count = len(order)
            for rank, row_idx in enumerate(order):
                # Bug fix: draw the row at this rank (previously the unsorted
                # row was drawn, so the shading did not follow confidence).
                red = 255 - (255 // count) * (count - 1 - rank)
                _draw_box(ax, above_thres[row_idx], '#%02x0000' % red)
            continue

        if plot_opt == 'NMS':
            selected = _nms_boxes(above_thres, iou_threshold)
        elif plot_opt == 'CUSTOM1':
            selected = _group_mean_boxes(above_thres, iou_threshold)
        else:  # 'CUSTOM2'
            selected = _merge_weighted_boxes(np.copy(above_thres), iou_threshold)
        for bbox in selected:
            _draw_box(ax, bbox, '#FF0000')

    plt.show()

    gc.collect()
    torch.cuda.empty_cache()
        
def _predict_and_plot(test_context, test_config, batches, plot_options):
    """Run the notebook-global ``model`` on one collated batch and plot the result.

    Assumes batch size 1 (``squeeze(0)`` drops the batch axis) and a batch dict
    with 'image', 'label' and 'title' entries; 'image' is presumably NCHW and is
    permuted to HWC for plotting.
    """
    image = batches['image'].to(test_config['device'], dtype=test_config['dtype'])
    with torch.no_grad():
        out1, out2, out3 = model(image)
        pred = torch.cat((out1, out2, out3), 1)

    output = {
        'title': batches['title'],
        'image': batches['image'].cpu().permute(0, 2, 3, 1).squeeze(0).numpy(),
        'pred': pred.cpu().detach().squeeze(0).numpy(),
        'label': batches['label'].cpu().squeeze(0).numpy(),
    }
    # Release device tensors before plotting so they are not held meanwhile.
    del image, out1, out2, out3, pred

    plot_pred(test_context, output, plot_options)

    del output
    gc.collect()
    torch.cuda.empty_cache()


def test_all(test_context, test_config, plot_options=None):
    """Run the model on every test batch and plot each prediction.

    *plot_options* defaults to ``['ABOVE', 'NMS', 'CUSTOM2']``.
    """
    if plot_options is None:
        plot_options = ['ABOVE', 'NMS', 'CUSTOM2']
    for batches in test_context['dataloader']:
        _predict_and_plot(test_context, test_config, batches, plot_options)


def test_first(test_context, test_config, plot_options=None):
    """Run the model on the first test batch only and plot its prediction.

    Does nothing if the dataloader is empty. *plot_options* defaults to
    ``['ABOVE', 'NMS', 'CUSTOM2']``.
    """
    if plot_options is None:
        plot_options = ['ABOVE', 'NMS', 'CUSTOM2']
    batches = next(iter(test_context['dataloader']), None)
    if batches is not None:
        _predict_and_plot(test_context, test_config, batches, plot_options)
        
def test_url_image(test_config, url, crop=False, crop_size=(0, 0, 1920, 1080),
                   plot_options=None, context=None, input_size=(608, 608)):
    """Download an image from *url*, run the model on it and plot the prediction.

    Args:
        test_config: Dict providing 'device' and 'dtype' for the input tensor.
        url: Image URL; surrounding whitespace (e.g. a trailing newline read
            from a file) is stripped.
        crop: If truthy, crop the image to *crop_size* before resizing.
        crop_size: PIL crop box ``(left, upper, right, lower)``.
        plot_options: Passed to ``plot_pred``; defaults to
            ``['ABOVE', 'NMS', 'CUSTOM2']``.
        context: Test context passed to ``plot_pred``; defaults to the
            notebook-global ``test_context``.
        input_size: ``(width, height)`` the image is resized to before inference.

    Raises:
        urllib.error.URLError / urllib.error.HTTPError: On download failure.
        PIL.UnidentifiedImageError: If the response is not a decodable image.
    """
    if plot_options is None:
        plot_options = ['ABOVE', 'NMS', 'CUSTOM2']
    if context is None:
        context = test_context
    url = url.strip()

    # Read the body fully so the connection is closed before decoding.
    with urllib.request.urlopen(url) as response:
        pil_image = Image.open(BytesIO(response.read()))
    # Normalise to 3-channel RGB: grayscale/palette/CMYK inputs would otherwise
    # crash or feed the wrong channels to the model.
    pil_image = pil_image.convert('RGB')
    if crop:
        pil_image = pil_image.crop(crop_size)

    pixels = np.array(pil_image.resize(input_size, Image.BILINEAR))
    image = torch.from_numpy(pixels).unsqueeze(0).float() / 255
    image = image.permute(0, 3, 1, 2).to(test_config['device'], dtype=test_config['dtype'])

    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        out1, out2, out3 = model(image)
        pred = torch.cat((out1, out2, out3), 1)

    output = {
        'title': url if len(url) < 80 else url[0:77] + '...',
        'image': image.permute(0, 2, 3, 1).cpu().squeeze(0).numpy(),
        'pred': pred.cpu().detach().squeeze(0).numpy(),
        'label': [],
    }
    del image, out1, out2, out3, pred

    plot_pred(context, output, plot_options)

    gc.collect()
    torch.cuda.empty_cache()

In [None]:
# Run the detector over the entire test set, plotting every image.
test_all(test_context=test_context, test_config=test_config)

In [None]:
# Five rounds of: reshuffle the dataset, then preview only its first image.
for _round in range(5):
    test_context['dataset'].shuffle()
    test_first(test_context, test_config)

In [None]:


# Run the detector on every image URL listed (one per line) in the links file.
# Lines are stripped (a trailing newline makes urlopen reject the URL) and blank
# lines are skipped. Per-URL download/decode failures are reported and the
# loop continues instead of aborting the whole run.
test_context['debug_level'] = 1  # print title/options for each URL

with open("../data/detect/links/url_links.txt", "r", encoding="utf-8") as url_files:
    urls = [line.strip() for line in url_files if line.strip()]

for url in urls:
    try:
        test_url_image(test_config, url)
    except urllib.error.HTTPError:
        print('http error')
    except urllib.error.URLError as err:  # e.g. DNS failure, refused connection
        print('url error :', err.reason)
    except OSError as err:  # e.g. PIL.UnidentifiedImageError for non-image content
        print('failed to read image :', err)