In [None]:
import re
import os
import math
import json
import shutil
from PIL import Image
import requests
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False);

In [None]:
# Class names indexed by the class id the detector predicts (the notebook looks
# labels up as classes_raw[argmax]). 'N/A' entries are placeholder slots.
classes_raw = [
    "N/A", "person", "bicycle", "car", "motorcycle",                        # 0-4
    "airplane", "bus", "train", "truck", "boat",                            # 5-9
    "traffic light", "fire hydrant", "N/A", "stop sign", "parking meter",   # 10-14
    "bench", "bird", "cat", "dog", "horse",                                 # 15-19
    "sheep", "cow", "elephant", "bear", "zebra",                            # 20-24
    "giraffe", "N/A", "backpack", "umbrella", "N/A",                        # 25-29
    "N/A", "handbag", "tie", "suitcase", "frisbee",                         # 30-34
    "skis", "snowboard", "sports ball", "kite", "baseball bat",             # 35-39
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", # 40-44
    "N/A", "wine glass", "cup", "fork", "knife",                            # 45-49
    "spoon", "bowl", "banana", "apple", "sandwich",                         # 50-54
    "orange", "broccoli", "carrot", "hot dog", "pizza",                     # 55-59
    "donut", "cake", "chair", "couch", "potted plant",                      # 60-64
    "bed", "N/A", "dining table", "N/A", "N/A",                             # 65-69
    "toilet", "N/A", "tv", "laptop", "mouse",                               # 70-74
    "remote", "keyboard", "cell phone", "microwave", "oven",                # 75-79
    "toaster", "sink", "refrigerator", "N/A", "book",                       # 80-84
    "clock", "vase", "scissors", "teddy bear", "hair drier",                # 85-89
    "toothbrush",                                                           # 90
]

# The same names, in the same order, with the 'N/A' placeholders removed
# (80 contiguous entries); used to build the name -> compact index mapping.
classes_modified = [label for label in classes_raw if label != "N/A"]

# colors for visualization: six RGB triples that the plotting code cycles through
colors_raw = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

In [None]:
# Preprocessing pipeline applied to every PIL image before it is passed to the
# model: T.Resize(800), conversion to a tensor, then the standard PyTorch
# mean-std input image normalization (one mean/std value per colour channel).
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]
_preprocess_steps = [
    T.Resize(800),
    T.ToTensor(),
    T.Normalize(_norm_mean, _norm_std),
]
transform = T.Compose(_preprocess_steps)

# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner format.

    The conversion acts on the last dimension, so a single box of shape (4,),
    a list of boxes of shape (N, 4) and batched boxes of shape (..., N, 4) are
    all accepted.

    Args:
        x: tensor whose last dimension holds (center_x, center_y, width, height).

    Returns:
        Tensor of the same shape whose last dimension holds
        (x_min, y_min, x_max, y_max).
    """
    x_c, y_c, w, h = x.unbind(-1)
    corners = [(x_c - 0.5 * w), (y_c - 0.5 * h),
               (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(corners, dim=-1)

def rescale_bboxes(out_bbox, size):
    """Convert (cx, cy, w, h) boxes to corner format and scale them by an image size.

    Args:
        out_bbox: tensor of boxes in (center_x, center_y, width, height) form,
            as accepted by box_cxcywh_to_xyxy.
            # assumes the coordinates are relative (0..1) model outputs — TODO confirm
        size: (width, height) pair, e.g. PIL's ``Image.size``.

    Returns:
        Tensor of (x_min, y_min, x_max, y_max) boxes with x values multiplied
        by the width and y values multiplied by the height.
    """
    img_w, img_h = size
    # Build the scale factors on the same device as the boxes so the
    # multiplication also works when the model runs on an accelerator.
    scale = torch.tensor([img_w, img_h, img_w, img_h],
                         dtype=torch.float32, device=out_bbox.device)
    return box_cxcywh_to_xyxy(out_bbox) * scale

In [None]:
def plot_results(pil_img, prob, boxes, name):
    """Draw detections on top of an image and save the figure to disk.

    Each box is drawn as a translucent rectangle and labelled with the
    best-scoring class name (looked up in ``classes_raw``) and its score.

    Args:
        pil_img: image shown as the figure background via ``imshow``.
        prob: iterable of per-detection score tensors; the argmax of each one
            selects the label and the displayed score.
        boxes: tensor with one (xmin, ymin, xmax, ymax) row per detection.
        name: output path passed to ``savefig``.
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(pil_img)
        detections = zip(prob, boxes.tolist())
        for i, (p, (xmin, ymin, xmax, ymax)) in enumerate(detections):
            # Cycle through the palette so any number of boxes gets a colour.
            c = colors_raw[i % len(colors_raw)]
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       alpha=0.3, color=c, linewidth=3))
            cl = p.argmax().item()
            text = f'{classes_raw[cl]}: {p[cl].item():0.2f}'
            ax.text(xmin, ymin, text, fontsize=30,
                    bbox=dict(facecolor='yellow', alpha=0.5))
        ax.axis('off')
        fig.savefig(name)
    finally:
        # Always release the figure, even when drawing or saving fails;
        # this function is called once per frame, so leaked figures pile up.
        plt.close(fig)

In [None]:
# Load the pretrained 'detr_resnet50' entry point from the
# 'facebookresearch/detr' torch.hub repository and put it in evaluation mode.
hub_repo, hub_entry = 'facebookresearch/detr', 'detr_resnet50'
model = torch.hub.load(hub_repo, hub_entry, pretrained=True)
model.eval()

In [None]:
# Run the detector on every frame found under
# ./nearest/video_test_nearest_<resolution>/<video_folder>/ and, for each frame,
# save an annotated plot (./video_plot/<resolution>/) and a JSON file with the
# kept detections (./video_result/<resolution>/).
# NOTE(review): output names are built from picture_name only, so frames with the
# same file name in different video folders overwrite each other — confirm this
# is intended.
resolutions = [15, 20, 30, 50, 100, 160, 240]
plot_folder = './video_plot/'
info_folder = './video_result/'
for resolution in resolutions:
    test_folder = './nearest/video_test_nearest_' + str(resolution)
    # Start from empty output folders so results of earlier runs never mix in.
    for output_root in (plot_folder, info_folder):
        output_dir = os.path.join(output_root, str(resolution))
        if os.path.exists(output_dir):
            shutil.rmtree(output_dir)
        os.makedirs(output_dir)
    for video_folder in os.listdir(test_folder):
        video_path = os.path.join(test_folder, video_folder)
        # Skip macOS metadata and any other stray file that is not a folder.
        if video_folder == ".DS_Store" or not os.path.isdir(video_path):
            continue
        for picture_name in os.listdir(video_path):
            if picture_name == ".DS_Store":
                continue
            # Close the file right after decoding, and force three channels so
            # grayscale/RGBA frames do not break the 3-channel normalization.
            with Image.open(os.path.join(video_path, picture_name)) as im_file:
                im = im_file.convert('RGB')
            img = transform(im).unsqueeze(0)
            outputs = model(img)
            # Per-query class probabilities for the single image in the batch;
            # the last column is dropped (presumably the "no object" class).
            probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
            # Keep only detections whose best class score exceeds 0.7.
            keep = probas.max(-1).values > 0.7
            bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
            plot_path = os.path.join(plot_folder, str(resolution), picture_name + '.jpg')
            info_path = os.path.join(info_folder, str(resolution), picture_name + '.json')
            plot_results(im, probas[keep], bboxes_scaled, plot_path)
            objects = []
            for p, (xmin, ymin, xmax, ymax) in zip(probas[keep], bboxes_scaled.tolist()):
                cl = p.argmax().item()
                single_object = {}
                single_object['points'] = [xmin, ymin, xmax, ymax]
                single_object['score'] = p[cl].item()
                single_object['label'] = classes_raw[cl]
                objects.append(single_object)
            info = {}
            info['objects'] = objects
            with open(info_path, 'w', encoding='utf8') as objects_file:
                json.dump(info, objects_file, ensure_ascii=False)

In [None]:
# Map every class name in classes_modified to its position in that list.
classes2idx = {label: position for position, label in enumerate(classes_modified)}
print(classes2idx)

In [None]:
# Run the detector on size x size down-sampled copies of the images under
# ./data/<folder>/ and write one JSON file per image to ./objects/<size>/.
# Each JSON holds the kept boxes ('bboxes') and their class indices in
# classes_modified ('categories'); the file is named after the numeric prefix
# of the picture name ("<digits>_<anything>.jpg").
# NOTE(review): pictures in different folders that share a numeric prefix write
# to the same JSON file — confirm prefixes are unique across folders.
size = 15
source_path = "./data/"
dest_path = './objects/' + str(size)
list_path = './list/' + str(size) + '/test.txt'
# Recreate the output folder from scratch; makedirs also creates a missing
# './objects' parent, which os.mkdir would fail on.
if os.path.isdir(dest_path):
    shutil.rmtree(dest_path)
os.makedirs(dest_path)
# if os.path.exists(list_path):
#     os.remove(list_path)
for folder in os.listdir(source_path):
    folder_path = os.path.join(source_path, folder)
    # Skip macOS metadata and any other stray file that is not a folder.
    if folder == ".DS_Store" or not os.path.isdir(folder_path):
        continue
    for picture in os.listdir(folder_path):
        if picture == ".DS_Store":
            continue
        # Validate the name before running the model so a stray file neither
        # wastes an inference pass nor crashes on a failed match.
        match_result = re.match(r'(\d*)_(.*)\.jpg', picture)
        if match_result is None:
            print('skipping ' + os.path.join(folder_path, picture) + ': unexpected file name')
            continue
        index = match_result.group(1)
        # Close the file right after decoding, and force three channels so
        # grayscale/RGBA inputs do not break the 3-channel normalization.
        with Image.open(os.path.join(folder_path, picture)) as image_file:
            image = image_file.convert('RGB').resize((size, size), Image.BICUBIC)
        img = transform(image).unsqueeze(0)
        outputs = model(img)
        # Per-query class probabilities for the single image in the batch;
        # the last column is dropped (presumably the "no object" class).
        probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
        # Keep only detections whose best class score exceeds 0.7.
        keep = probas.max(-1).values > 0.7
        # Boxes are expressed in the coordinates of the down-sampled image.
        bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], image.size)
        objects = {}
        objects['bboxes'] = []
        objects['categories'] = []
        # people / probability are only consumed by the disabled pair-listing
        # code at the end of this cell.
        people = {}
        probability = []
        for prob, bbox in zip(probas[keep], bboxes_scaled.tolist()):
            cl = prob.argmax()
            label = classes_raw[cl.item()]
            # A placeholder ('N/A') prediction has no compact index; drop the
            # detection together with its box so both lists stay aligned.
            if label not in classes2idx:
                continue
            probability.append(prob[cl])
            objects['bboxes'].append(bbox)
            objects['categories'].append(classes2idx[label])
        with open(os.path.join(dest_path, index + '.json'),'w',encoding='utf8') as objects_file:
            json.dump(objects, objects_file, ensure_ascii=False)
#         for prob, bbox, classes in zip(probability, objects['bboxes'], objects['categories']):
#             if classes == 0:
#                 people[prob.item()] = bbox
#         people_sorted = sorted(people.items(), key=lambda x:x[0], reverse=True)
#         if len(people_sorted) <= 1:
#             continue
#         person_pair = []
#         for idx in range(0, len(people_sorted)):
#             for jdx in range(idx + 1, len(people_sorted)):
#                 person_pair.append([people_sorted[idx][1], people_sorted[jdx][1]])
#         with open(list_path, 'a', encoding='utf8') as list_file:
#             for ([[xmin1, ymin1, xmax1, ymax1], [xmin2, ymin2, xmax2, ymax2]]) in person_pair:
#                 list_file.write(index + '.jpg ' + str((int)(xmin1)) + ' ' + str((int)(ymin1)) + ' ' + str((int)(xmax1)) + ' ' + str((int)(ymax1)) + ' '
#                 + str((int)(xmin2)) + ' ' + str((int)(ymin2)) + ' ' + str((int)(xmax2)) + ' ' + str((int)(ymax2)) + ' ' + '0\n') 
