In [None]:
import base64
import os
from io import BytesIO

import cv2
import IPython.display
import matplotlib.pyplot as plt
import PIL.Image
import seaborn as sns
import torch
import torch.nn.functional as TF
import torchvision.transforms.functional as VF
from torchvision.models import ResNet18_Weights, resnet18

# Live webcam classification: every `update_rate`-th camera frame is cropped to
# a square, classified with an ImageNet-pretrained ResNet-18, and shown next to a
# bar chart of the top predictions in a single, in-place-updated notebook output.
# Stop the stream with the notebook's interrupt button (KeyboardInterrupt).

# setting
builtin_camera = True  # True: built-in camera (1280x720), False: external camera (1920x1080)

# mpl setting
sns.set_theme()
sns.set_context('notebook', font_scale=2)
plt.rcParams['figure.figsize'] = [10, 6]

# ImageNet labels, one class name per line. splitlines() (unlike readlines())
# drops the line terminators, which would otherwise end up in the tick labels.
label_path = os.path.join('..', 'data', 'imagenet_labels.txt')
with open(label_path, 'r', encoding='utf-8') as f:
    labels = f.read().splitlines()

# frame sampling setting: run inference on one frame out of every `update_rate`
frame_count = 0
update_rate = 10

# number of top classes shown in the bar chart
display_items = 3

# crop setting: keep the central height x height square of each frame
height = 720 if builtin_camera else 1080
width = 1280 if builtin_camera else 1920
width_start = (width - height) // 2
width_end = width_start + height

# reusable in-memory PNG buffers for the frame and the chart
buf_movie = BytesIO()
buf_graph = BytesIO()

# model setup: inference only, so all parameters are frozen
weights = ResNet18_Weights.DEFAULT
model = resnet18(weights=weights)
model = model.eval()
model.requires_grad_(False)

# per-channel (RGB) ImageNet mean/std used to normalize the model input
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# IPython setup: a single display area that is replaced on every update
handle = IPython.display.DisplayHandle()
handle.display(IPython.display.HTML(''))

# capture setup: opened last so that a failure while loading the labels or the
# model cannot leave the camera device held open
capture = cv2.VideoCapture(0)
if not capture.isOpened():
    raise RuntimeError('Could not open video device')

try:
    while True:
        success, frame = capture.read()  # (bool, numpy.ndarray in BGR order)
        if not success:
            # NOTE(review): a persistently failing camera makes this spin;
            # original behavior (keep retrying) is preserved.
            continue

        # frame sampling: process the first frame, then every `update_rate`-th one
        is_sampled = frame_count == 0
        frame_count = (frame_count + 1) % update_rate
        if not is_sampled:
            continue

        # preprocess frame (expected shape: external (1080, 1920, 3), builtin (720, 1280, 3))
        frame = frame[:, width_start:width_end]  # crop to square
        frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)  # compress to 50%
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # prediction
        model_input = torch.from_numpy(frame).permute(2, 0, 1)  # HWC -> CHW
        # antialias=True: the legacy None value disables antialiasing for tensors,
        # which aliases badly when downscaling the frame to 224x224
        model_input = VF.resize(model_input, [224, 224], antialias=True)
        model_input = model_input / 255
        model_input = model_input.unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)
        model_input = VF.normalize(model_input, mean=imagenet_mean, std=imagenet_std)
        with torch.no_grad():
            probs = TF.softmax(model(model_input), dim=1)[0]
            top = probs.topk(display_items)  # sorted, most likely first

        # chart data in percent, reversed so the most likely class is drawn on top
        top_probs = (top.values * 100).flip(0).tolist()
        top_labels = [labels[i] for i in reversed(top.indices.tolist())]

        # movie display
        buf_movie.seek(0)
        buf_movie.truncate()
        PIL.Image.fromarray(frame).save(buf_movie, 'png')
        encoded_movie = base64.b64encode(buf_movie.getvalue()).decode('utf-8')

        # graph
        plt.barh(list(range(display_items)), top_probs, tick_label=top_labels)

        # mpl setting
        plt.xlabel('Probability (%)')
        plt.xlim(0, 100)

        # graph display
        plt.subplots_adjust(left=0.3, right=0.97, bottom=0.15, top=0.995)
        buf_graph.seek(0)
        buf_graph.truncate()
        plt.savefig(buf_graph, format='png')
        plt.close()
        encoded_graph = base64.b64encode(buf_graph.getvalue()).decode('utf-8')

        # IPython
        html = \
        f"""
        <div style="display: flex; justify-content: center; align-items: stretch;">
          <div style="flex: 1; display: flex; justify-content: center;">
            <img src="data:image/png;base64,{encoded_movie}" style="max-width: 120%; height: 100%; object-fit: contain;">
          </div>
          <div style="flex: 1; display: flex; justify-content: center;">
            <img src="data:image/png;base64,{encoded_graph}" style="max-width: 120%; height: 100%; object-fit: contain;">
          </div>
        </div>
        """
        handle.update(IPython.display.HTML(html))

except KeyboardInterrupt:
    pass  # interrupting the kernel is the intended way to stop the stream
finally:
    # release the camera and any open figure even if the loop raised an error
    capture.release()
    plt.close('all')