## Imports & Environment Setup

In [None]:
import torch
from torch import autograd
from torch.utils.data import DataLoader

import json
import gc
import numpy as np

from utils.dataset import DetectionFolder
from utils.model import YoloV3, YoloLoss, PostProcessor

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches

import urllib
import urllib.error
import urllib.request
from io import BytesIO
from PIL import Image

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Floating-point type used for all model inputs.
dtype = torch.float

## Load Config

In [None]:
# Load the project configuration and pull out the sections this notebook uses.
with open("./config/config.json", "r") as config_file:
    main_config = json.load(config_file)

# Raise explicitly (an `assert False` is stripped under `python -O`) and name
# the missing section so the failure is actionable.
try:
    model_config = main_config['model']
    test_config = main_config['test']
except KeyError as err:
    raise KeyError(f'Failed to find key on config file: {err}') from err

In [None]:
# Inject the runtime device/dtype into the model section, then derive the
# per-prediction attribute count: 5 fixed values plus one per class
# (presumably box coords + objectness + class scores — confirm in utils.model).
model_config.update(device=device, dtype=dtype)
model_config['attrib_count'] = 5 + model_config['class_count']

# The test section needs the same device/dtype for moving inputs.
test_config.update(device=device, dtype=dtype)

## Build Model & Test Data Loader

In [None]:
# Load the whole pickled model object. `map_location=device` lets a checkpoint
# that was saved from GPU tensors be opened on a CPU-only machine; without it
# torch.load tries to restore tensors onto their original CUDA device.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# Recent PyTorch releases default to weights_only=True, which rejects whole-model
# pickles; pass weights_only=False there if this call fails.
model = torch.load('./model/1104/model_250.dat', map_location=device)
model.to(device)
# Inference mode for layers such as dropout / batch-norm.
model = model.eval()

In [None]:
# Shared state for the evaluation helpers below.
test_context = {
    'dataset': DetectionFolder(test_config['test_list'],
                               test_config['test_image'],
                               test_config['test_label']),
}
# batch_size=1: the inference helpers squeeze the batch axis away.
test_context['dataloader'] = DataLoader(test_context['dataset'], batch_size=1, num_workers=1)

# Visualisation / post-processing knobs.
test_context.update({
    'k': 30,
    'threshold': 0.5,
    'nms_threshold': 0.2,
    'debug_level': test_config['debug_level'],
})

In [None]:
def _run_inference(test_context, test_config, n, with_labels):
    """Run the global ``model`` over ``test_context['dataloader']`` and collect results.

    Args:
        test_context: dict holding ``'dataloader'``. Each batch must be a dict
            with ``'image'`` and ``'title'``, plus ``'label'`` when
            ``with_labels`` is true.
        test_config: dict whose ``'device'`` and ``'dtype'`` are used to move
            the input image before the forward pass.
        n: stop after this many batches. A value that never equals a 1-based
            batch index (e.g. 0 or None) processes the whole loader.
        with_labels: when true, also collect ``batch['label']``.

    Returns:
        dict with lists ``'title'``, ``'image'``, ``'pred'``, ``'label'`` (left
        empty when ``with_labels`` is false) and the batch count ``'n'``.
    """
    output = {'title': [], 'image': [], 'pred': [], 'label': [], 'n': 0}

    with torch.no_grad():
        for idx, batches in enumerate(test_context['dataloader']):
            image = batches['image'].to(test_config['device'], dtype=test_config['dtype'])

            # The model returns three prediction tensors that are concatenated along dim 1.
            out1, out2, out3 = model(image)
            pred = torch.cat((out1, out2, out3), 1)

            output['title'].append(batches['title'])
            # assumes NCHW input with batch size 1 -> HWC array for imshow (TODO confirm)
            output['image'].append(batches['image'].cpu().permute(0, 2, 3, 1).squeeze(0).numpy())
            output['pred'].append(pred.cpu().squeeze(0).numpy())
            if with_labels:
                output['label'].append(batches['label'].cpu().squeeze(0).numpy())
            output['n'] += 1

            # Drop the device tensors *before* asking the allocator to release
            # its cache; otherwise empty_cache() cannot reclaim them.
            del image, pred, out1, out2, out3
            gc.collect()
            torch.cuda.empty_cache()

            if (idx + 1) == n:
                break

    return output


def test_labels(test_context, test_config, n):
    """Run inference on up to ``n`` batches, keeping the ground-truth labels.

    See ``_run_inference`` for the argument and return contract.
    """
    return _run_inference(test_context, test_config, n, with_labels=True)


def test_no_labels(test_context, test_config, n):
    """Run inference on up to ``n`` batches without reading labels.

    ``output['label']`` is returned empty. See ``_run_inference`` for the
    argument and return contract.
    """
    return _run_inference(test_context, test_config, n, with_labels=False)


def _fetch_url_image(url, crop, crop_size):
    """Download the image at ``url`` and optionally crop it.

    The response body is read into memory and the connection is closed before
    returning, so no socket is left open.

    Args:
        url: address of the image.
        crop: when truthy, crop to ``crop_size`` before any resizing.
        crop_size: ``(left, upper, right, lower)`` box for ``PIL.Image.crop``.

    Returns:
        A ``PIL.Image.Image``.
    """
    with urllib.request.urlopen(url) as response:
        image = Image.open(BytesIO(response.read()))
    if crop:
        image = image.crop(crop_size)
    return image


def _predict_single(test_config, url, image):
    """Run the global ``model`` on one 1xCxHxW image tensor and package the result.

    The returned dict has the same keys as ``test_labels``' output, with one
    entry per list and an empty ``'label'`` list.
    """
    image = image.to(test_config['device'], dtype=test_config['dtype'])

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        out1, out2, out3 = model(image)
        pred = torch.cat((out1, out2, out3), 1)

    return {
        'title': [url],
        'image': [image.permute(0, 2, 3, 1).cpu().squeeze(0).numpy()],
        'pred': [pred.cpu().squeeze(0).numpy()],
        'label': [],
        'n': 1,
    }


def test_url_image(test_config, url, crop = False, crop_size = (0, 0, 1920, 1080)):
    """Fetch an image from ``url``, resize it to 608x608 RGB and run the detector.

    Args:
        test_config: dict with the ``'device'`` and ``'dtype'`` for the input.
        url: image address.
        crop: when truthy, crop to ``crop_size`` before resizing.
        crop_size: ``(left, upper, right, lower)`` crop box.

    Returns:
        dict in the ``test_labels`` output format holding a single image.
    """
    # convert('RGB') handles grayscale/palette/RGBA/CMYK sources; plain channel
    # slicing breaks on 2-D (single-channel) arrays.
    image = _fetch_url_image(url, crop, crop_size).convert('RGB')
    image = np.array(image.resize((608, 608), Image.BILINEAR))
    # HWC uint8 -> 1xCxHxW float in [0, 1].
    image = torch.from_numpy(image).unsqueeze(0).float() / 255
    image = image.permute(0, 3, 1, 2)
    return _predict_single(test_config, url, image)


def test_url_temp(test_config, url, crop = False, crop_size = (0, 0, 1920, 1080)):
    """Like ``test_url_image``, but feed a grayscale copy replicated to 3 channels."""
    image = _fetch_url_image(url, crop, crop_size).convert("L")
    gray = np.array(image.resize((608, 608), Image.BILINEAR))
    # Stack the single channel into 3xHxW, then add the batch axis.
    image = torch.from_numpy(np.stack([gray, gray, gray])).unsqueeze(0).float() / 255
    print(image.shape)
    return _predict_single(test_config, url, image)


def iou(bbox1, bbox2):
    """Return the intersection-over-union of two center-format boxes.

    Args:
        bbox1, bbox2: indexables whose first four entries are
            ``(center_x, center_y, width, height)``; extra entries are ignored.

    Returns:
        IoU in ``[0, 1]``; 0 when the boxes do not overlap.
    """
    cx1, cy1, w1, h1 = bbox1[0], bbox1[1], bbox1[2], bbox1[3]
    cx2, cy2, w2, h2 = bbox2[0], bbox2[1], bbox2[2], bbox2[3]

    # Overlap extent per axis. A negative value means the boxes are disjoint on
    # that axis, so clamp to 0 (two negatives would otherwise multiply into a
    # spurious positive area).
    w_intersect = max(0, min(cx1 + w1 / 2, cx2 + w2 / 2) - max(cx1 - w1 / 2, cx2 - w2 / 2))
    h_intersect = max(0, min(cy1 + h1 / 2, cy2 + h2 / 2) - max(cy1 - h1 / 2, cy2 - h2 / 2))
    area_intersect = w_intersect * h_intersect

    union = w1 * h1 + w2 * h2 - area_intersect
    # Epsilon guards against division by zero for degenerate (zero-area) boxes.
    return area_intersect / (union + 1e-9)
        
def plot_pred(test_context, pred):
    """Post-process predictions with ``PostProcessor.GAS`` and draw the boxes.

    Args:
        test_context: reads ``'debug_level'``; titles are printed when it is >= 1.
        pred: dict in the ``test_labels`` / ``test_url_image`` output format. It
            uses ``'pred'`` (per-image prediction arrays, stacked for GAS),
            ``'image'`` (arrays passed to ``imshow``), ``'title'`` and ``'n'``.

    NOTE(review): ``test_context['threshold']`` and ``test_context['k']`` are
    not read here; post-processing uses the fixed values in ``post_context``.
    """
    # Fixed parameters consumed by PostProcessor.GAS; exact semantics are
    # defined in utils.model — TODO confirm.
    post_context = {
        'post_mAp': 0.5,
        'post_iou_threshold': 0.3,
    }
    pred_boxes = PostProcessor().GAS(torch.from_numpy(np.array(pred['pred'])), post_context)

    for indx in range(pred['n']):
        if indx < len(pred['title']) and test_context['debug_level'] >= 1:
            print('title : ', pred['title'][indx])

        _, ax = plt.subplots(1, 1, figsize=(6, 6))
        ax.imshow(pred['image'][indx])
        # Each box is (cx, cy, w, h); Rectangle wants the corner plus size.
        for bbox in pred_boxes[indx]:
            bounding = patches.Rectangle((bbox[0] - bbox[2] / 2, bbox[1] - bbox[3] / 2), bbox[2], bbox[3],
                                         linewidth=1, edgecolor='#ff0000', facecolor='none')
            ax.add_patch(bounding)
        plt.show()

        gc.collect()
        torch.cuda.empty_cache()

In [None]:
# Run inference on just the first batch of the test loader.
pred = test_labels(test_context, test_config, n=1)

In [None]:
# Draw the post-processed boxes for the batch collected above.
plot_pred(test_context=test_context, pred=pred)

In [None]:
test_context['threshold'] = 0.5
# Five rounds: reshuffle the dataset (DetectionFolder.shuffle), run inference
# on the first batch, and plot it.
for _round in range(5):
    test_context['dataset'].shuffle()
    pred = test_labels(test_context, test_config, n=1)
    plot_pred(test_context, pred)

    # Release host/GPU memory between rounds.
    gc.collect()
    torch.cuda.empty_cache()

In [None]:


# Read the URL list up front so the file is closed before the (slow) network
# loop. Strip trailing newlines and skip blank lines.
with open("../data/detect/links/url_links.txt", "r") as url_files:
    lines = [line.strip() for line in url_files if line.strip()]

gc.collect()
torch.cuda.empty_cache()

test_context['k'] = 100
test_context['threshold'] = 0.5
for url in lines:
    print(url)
    # Best-effort: report a failed URL and move on instead of aborting the loop.
    try:
        pred = test_url_image(test_config, url)
    except urllib.error.HTTPError:
        print('http error')
        continue
    except urllib.error.URLError as err:
        # e.g. DNS failure or refused connection.
        print('url error :', err.reason)
        continue
    except (OSError, ValueError) as err:
        # e.g. undecodable image data, or a malformed URL.
        print('failed to load image :', err)
        continue
    plot_pred(test_context, pred)