In [1]:
# coding='utf-8'
import os
import sys
import numpy as np
import time
import datetime
import json
import importlib
import logging
import shutil
#import cv2
import random

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.ticker import NullLocator

import torch
import torch.nn as nn

from yolo_model import yoloModel
from PASCAL_Dataloader import create_split_loaders
from YOLO_Loss import YoloLoss
#from utils import NMS
from bbox import non_max_suppression

cmap = plt.get_cmap('tab20b')
colors = [cmap(i) for i in np.linspace(0, 1, 20)]

In [2]:
def test(config):
    
    # Load and initialize network
    net = yoloModel(config)

    # Set data parallel
    net = net.cuda()

    # Restore pretrain model
    if config["pretrain_snapshot"]:
        logging.info("load checkpoint from {}".format(config["pretrain_snapshot"]))
        state_dict = torch.load(config["pretrain_snapshot"])
        #print(state_dict.keys())
        #print(torch.load("official_yolov3_weights_pytorch.pth").keys())
        net.load_state_dict(state_dict)
    else:
        raise Exception("missing pretrain_snapshot!!!")

    # YOLO loss with 3 scales
    yolo_losses = []
    for i in range(3):
        yolo_losses.append(YoloLoss(config["classes"], (config["img_w"], config["img_h"]), config["anchors"][i]))

    
    # prepare images path
    images_name = os.listdir(config["images_path"])
    images_path = [os.path.join(config["images_path"], name) for name in images_name]
    if len(images_path) == 0:
        raise Exception("no image found in {}".format(config["images_path"]))
   
    root_dir = os.getcwd()
    train_loader, val_loader, test_loader = create_split_loaders(root_dir, config['batch_size'])
   

    # Start inference
    batch_size = config["batch_size"]
    for n, sample in enumerate(test_loader):#range(0, len(test_loader), batch_size):
        '''
        # preprocess
        images = []
        images_origin = []
        for path in images_path[n*batch_size: (n+1)*batch_size]:
            logging.info("processing: {}".format(path))
            image = cv2.imread(path, cv2.IMREAD_COLOR)
            if image is None:
                logging.error("read path error: {}. skip it.".format(path))
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            images_origin.append(image)  # keep for save result
            image = cv2.resize(image, (config["img_w"], config["img_h"]),
                               interpolation=cv2.INTER_LINEAR)
            image = image.astype(np.float32)
            image /= 255.0
            image = np.transpose(image, (2, 0, 1))
            image = image.astype(np.float32)
            images.append(image)
        '''
        images, labels = sample["image"], sample["label"]
        images = images.cuda()
        #images = np.asarray(images)
        #images = torch.from_numpy(images).cuda()
        # inference
        with torch.no_grad():
            outputs = net(images)
            output_list = []
            for i in range(3):
                output_list.append(yolo_losses[i](outputs[i]))
            #print(output_list)
            output = torch.cat(output_list, 1)
            batch_detections = non_max_suppression(output, config["classes"])

        # write result images. Draw bounding boxes and labels of detections
        classes = open(config["classes_names_path"], "r").read().split("\n")[:-1]
        if not os.path.isdir("./output/"):
            os.makedirs("./output/")
        for idx, detections in enumerate(batch_detections):
            plt.figure()
            fig, ax = plt.subplots(1)
            myimshow(images[idx])
            if detections is not None:
                unique_labels = detections[:, -1].cpu().unique()
                n_cls_preds = len(unique_labels)
                bbox_colors = random.sample(colors, n_cls_preds)
                for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
                    color = bbox_colors[int(np.where(unique_labels == int(cls_pred))[0])]
                    
                    # Rescale coordinates to original dimensions
                    #ori_h, ori_w = images[idx].shape[:2]
                    pre_h, pre_w = config["img_h"], config["img_w"]
                    box_h = (y2 - y1)# / pre_h) #* ori_h
                    box_w = (x2 - x1)# / pre_w) #* ori_w
                    #y1 = (y1# / pre_h) #* ori_h
                    #x1 = (x1# / pre_w) #* ori_w
                    
                    
                    # Create a Rectangle patch
                    bbox = patches.Rectangle((x1, y1), box_w, box_h, linewidth=2,
                                             edgecolor=color,
                                             facecolor='none')
                    # Add the bbox to the plot
                    ax.add_patch(bbox)
                    # Add label
                    plt.text(x1, y1, s=classes[int(cls_pred)], color='white',
                             verticalalignment='top',
                             bbox={'color': color, 'pad': 0})
            # Save generated image with detections
            plt.axis('off')
            plt.gca().xaxis.set_major_locator(NullLocator())
            plt.gca().yaxis.set_major_locator(NullLocator())
            plt.savefig('output/{}_{}.jpg'.format(n, idx), bbox_inches='tight', pad_inches=0.0)
            plt.close()
    logging.info("Save all results to ./output/")  
    
def myimshow(image, ax=plt):
    image = image.to('cpu').numpy()
    image = np.moveaxis(image, [0, 1, 2], [2, 0, 1])
    image = (image + 1) / 2
    image[image < 0] = 0
    image[image > 1] = 1
    h = ax.imshow(image)
    ax.axis('off')
    return h

In [3]:
def main():
    logging.basicConfig(level=logging.DEBUG,
                        format="[%(asctime)s %(filename)s] %(message)s")

    config = {}
    config["batch_size"] = 1
    config['backbone_name'] = "darknet_53"
    config['backbone_pretrained'] = ""
    config['anchors'] = [[[116, 90], [156, 198], [373, 326]],
                                [[30, 61], [62, 45], [59, 119]],
                                [[10, 13], [16, 30], [33, 23]]]
    config['classes'] = 20
    config['img_h'] = config['img_w'] = 416
    config['confidence_threshold'] = 0.5
    config['pretrain_snapshot'] = "./states/20190530222509/model.pth"
    config['classes_names_path'] = "./data/voc.names"
    config['images_path'] = "./images"

    # Start training
    test(config)

In [4]:
if __name__ == "__main__":
    main()

[2019-05-31 10:56:35,411 <ipython-input-2-b4dc440c316e>] load checkpoint from ./states/20190530222509/model.pth
[2019-05-31 10:56:38,398 _base.py] update_title_pos
[2019-05-31 10:56:38,605 font_manager.py] findfont: Matching :family=sans-serif:style=normal:variant=normal:weight=normal:stretch=normal:size=10.0 to DejaVu Sans ('/opt/conda/lib/python3.6/site-packages/matplotlib/mpl-data/fonts/ttf/DejaVuSans.ttf') with score of 0.050000.
[2019-05-31 10:56:38,706 _base.py] update_title_pos
[2019-05-31 10:56:38,720 _base.py] update_title_pos
[2019-05-31 10:56:38,731 _base.py] update_title_pos
[2019-05-31 10:56:39,015 _base.py] update_title_pos
[2019-05-31 10:56:39,355 _base.py] update_title_pos
[2019-05-31 10:56:39,604 _base.py] update_title_pos
[2019-05-31 10:56:39,697 _base.py] update_title_pos
[2019-05-31 10:56:39,707 _base.py] update_title_pos
[2019-05-31 10:56:39,910 _base.py] update_title_pos
[2019-05-31 10:56:40,264 _base.py] update_title_pos
[2019-05-31 10:56:40,503 _base.py] update_

[2019-05-31 10:57:03,108 _base.py] update_title_pos
[2019-05-31 10:57:03,501 _base.py] update_title_pos
[2019-05-31 10:57:03,612 _base.py] update_title_pos
[2019-05-31 10:57:03,710 _base.py] update_title_pos
[2019-05-31 10:57:03,807 _base.py] update_title_pos
[2019-05-31 10:57:03,910 _base.py] update_title_pos
[2019-05-31 10:57:04,302 _base.py] update_title_pos
[2019-05-31 10:57:04,408 _base.py] update_title_pos
[2019-05-31 10:57:04,504 _base.py] update_title_pos
[2019-05-31 10:57:04,602 _base.py] update_title_pos
[2019-05-31 10:57:04,710 _base.py] update_title_pos
[2019-05-31 10:57:05,104 _base.py] update_title_pos
[2019-05-31 10:57:05,212 _base.py] update_title_pos
[2019-05-31 10:57:05,309 _base.py] update_title_pos
[2019-05-31 10:57:05,407 _base.py] update_title_pos
[2019-05-31 10:57:05,602 _base.py] update_title_pos
[2019-05-31 10:57:05,939 _base.py] update_title_pos
[2019-05-31 10:57:06,102 _base.py] update_title_pos
[2019-05-31 10:57:06,114 _base.py] update_title_pos
[2019-05-31 

IndexError: list index out of range