In [11]:
import json
import os
import re
import shutil

import matplotlib.pyplot as plt
import torch
import torchvision
from PIL import Image
from torchvision.io.image import read_image
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import draw_bounding_boxes

In [1]:
# Object category names (80 entries; they appear to be the COCO detection
# categories -- TODO confirm against the dataset actually used).
classes = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
    'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
    'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush',
]

# Six RGB triples (components in the 0-1 range) repeated 100 times, giving a
# 600-entry palette that can be zipped against a list of detections.
colors = 100 * [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

In [3]:
# Load Faster R-CNN (ResNet-50 FPN v2) with the library's default weights.
weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT

# Detections scoring below 0.7 are dropped by the model itself.
model = fasterrcnn_resnet50_fpn_v2(
    weights=weights,
    box_score_thresh=0.7,
)

# Input transform that ships with the chosen weights.
preprocess = weights.transforms()

# Switch to evaluation mode; binding the return value keeps the notebook from
# echoing the model's repr as the cell output.
_ = model.eval()

In [17]:
# Run the detector over every frame of every video at each resolution.
# For each frame this writes:
#   rcnn_video_plot/<resolution>/<picture_name>.jpg     annotated image
#   rcnn_video_result/<resolution>/<picture_name>.json  {"objects": [...]}
#
# NOTE(review): output names do not include the video folder, so two frames
# with the same file name in different video folders overwrite each other.
# Confirm frame names are globally unique.


def _reset_dir(path):
    """Delete `path` if it exists, then recreate it as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


def _detections_to_objects(scores, boxes, labels, size):
    """Build one JSON-serialisable record per detection.

    Args:
        scores: list of float confidence scores.
        boxes: list of [xmin, ymin, xmax, ymax] lists, parallel to `scores`.
        labels: list of category-name strings, parallel to `scores`.
        size: reference image size; coordinates are scaled by 240 / size.

    Returns:
        A list of dicts with keys 'points', 'score' and 'label'. The list is
        empty when there are no detections.
    """
    records = []
    for score, (xmin, ymin, xmax, ymax), label in zip(scores, boxes, labels):
        records.append({
            'points': [
                xmin / size * 240,
                ymin / size * 240,
                xmax / size * 240,
                ymax / size * 240,
            ],
            'score': score,
            # Only the last space-separated word of the category is kept, so
            # e.g. 'traffic light' becomes 'light'.
            # NOTE(review): presumably intentional -- confirm with consumers.
            'label': label.split(' ')[-1],
        })
    return records


resolutions = [15, 20, 30, 50, 100, 160, 240]
trans_totensor = torchvision.transforms.ToTensor()
# With size == 240 the coordinate rescaling is an identity (up to float
# rounding). Presumably every input frame is 240 px -- TODO confirm.
size = 240

for resolution in resolutions:
    nearest_folder = './nearest/video_test_nearest_' + str(resolution)
    plot_dir = os.path.join('rcnn_video_plot', str(resolution))
    result_dir = os.path.join('rcnn_video_result', str(resolution))
    # Start from empty output directories so stale files never survive a re-run.
    _reset_dir(plot_dir)
    _reset_dir(result_dir)

    for video_folder in os.listdir(nearest_folder):
        if video_folder == '.DS_Store':
            continue
        frame_dir = os.path.join(nearest_folder, video_folder)
        for picture_name in os.listdir(frame_dir):
            # macOS may drop a .DS_Store inside the video folders as well.
            if picture_name == '.DS_Store':
                continue
            plot_path = os.path.join(plot_dir, picture_name + '.jpg')
            result_path = os.path.join(result_dir, picture_name + '.json')

            # Close the file handle promptly; convert to RGB so grayscale,
            # palette or RGBA frames still reach the detector as 3 channels.
            with Image.open(os.path.join(frame_dir, picture_name)) as pil_img:
                img = trans_totensor(pil_img.convert('RGB'))

            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                prediction = model([preprocess(img)])[0]

            scores = prediction['scores'].tolist()
            boxes = prediction['boxes'].tolist()
            labels = [weights.meta["categories"][i] for i in prediction["labels"].tolist()]

            # Annotated plot: one outlined box and caption per detection.
            fig, ax = plt.subplots(figsize=(16, 10))
            ax.imshow(to_pil_image(img))
            for score, (xmin, ymin, xmax, ymax), label, color in zip(scores, boxes, labels, colors):
                ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                           fill=False, color=color, linewidth=3))
                ax.text(xmin, ymin, f'{label}: {score:0.2f}', fontsize=15,
                        bbox=dict(facecolor='yellow', alpha=0.5))
            ax.axis('off')
            fig.savefig(plot_path)
            plt.close(fig)

            # Every detection is recorded. Previously the append sat outside
            # the loop, so only the last box was saved and a frame with no
            # detections reused stale values or raised NameError.
            info = {'objects': _detections_to_objects(scores, boxes, labels, size)}
            with open(result_path, 'w', encoding='utf8') as objects_file:
                json.dump(info, objects_file, ensure_ascii=False)
           