In [None]:
import sys

sys.path.append('..')

import base64
import os
from io import BytesIO

import cv2
import IPython.display
import PIL.Image
import torch
import torchvision.transforms.functional as VF

from utils.mlp import MLP
from utils.utils import ModelWithNormalization, freeze

# Camera selection: True for the built-in camera (720-line frames),
# False for an external camera (1080-line frames).
builtin_camera = True

# Open the default video device (index 0) and fail fast if it is unavailable.
capture = cv2.VideoCapture(0)
assert capture.isOpened(), 'Could not open video device'

# Only every `update_rate`-th frame is processed and displayed;
# `frame_count` counts frames since the last processed one.
update_rate = 10
frame_count = 0

# Expected frame size for the selected camera, and the column bounds of a
# centered square (height x height) crop.
height, width = (720, 1280) if builtin_camera else (1080, 1920)
width_start = (width - height) // 2
width_end = height + width_start

# Reused in-memory buffer that holds the PNG encoding of the displayed frame.
buf_movie = BytesIO()

# Build one frozen, input-normalized MLP per hidden width, load its trained
# weights, and store it in `models` keyed by width.
# The loop variable is `model_width` (not `width`) so the crop `width`
# defined earlier in the script is not overwritten.
# After the loop, `model` is left bound to the last model loaded.
models = {}
widths = [125, 1000]
for model_width in widths:
    # NOTE(review): presumably MLP(in_features=784 (28*28), hidden=model_width,
    # out=10 classes, 5 layers, flag) — confirm against utils.mlp.
    model = MLP(784, model_width, 10, 5, True)
    # Normalize inputs with mean 0.1307 / std 0.3081 (the usual MNIST statistics).
    model = ModelWithNormalization(model, [0.1307], [0.3081])
    model = model.eval()
    freeze(model)
    # Load weights onto the CPU so checkpoints saved from GPU tensors still
    # load on a CPU-only machine; load_state_dict copies them into the model.
    weight_path = os.path.join('..', 'weights', f'width={model_width}.ckpt')
    weight = torch.load(weight_path, map_location='cpu')
    model.load_state_dict(weight)
    models[model_width] = model

# IPython setup: a single display slot that every update overwrites in place.
handle = IPython.display.DisplayHandle()
handle.display(IPython.display.HTML(''))

# Side length of the square model input (28 * 28 = 784 MLP inputs).
img_size = 28

try:

    while True:
        success, frame = capture.read()  # (bool, numpy.ndarray)
        # frame.shape: external camera (1080, 1920, 3), built-in (720, 1280, 3)

        if success:

            if frame_count % update_rate == 0:  # frame sampling
                frame_count = 0

                # preprocess frame
                frame = frame[:, width_start:width_end]  # crop to square
                frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)  # compress to 50%
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Model input: HWC uint8 -> CHW tensor, downsample to
                # img_size x img_size, scale to [0, 1], then average the three
                # color channels, giving shape (1, img_size, img_size).
                # NOTE(review): assumes the MLP flattens this to 784 features — confirm.
                model_input = torch.from_numpy(frame.transpose(2, 0, 1))
                model_input = VF.resize(model_input, [img_size, img_size], antialias=None)
                model_input = model_input / 255
                model_input = model_input.view(1, 3, img_size, img_size).mean(1)

                # Score the input with each width's own model from `models`;
                # the prediction is the arg-max class index.
                with torch.no_grad():
                    predicted_numbers = {
                        model_width: models[model_width](model_input).max(1).indices.item()
                        for model_width in widths
                    }

                # Encode the cropped frame as base64 PNG for the <img> tag.
                # The shared buffer is emptied first so it holds only this frame.
                buf_movie.seek(0)
                buf_movie.truncate()
                PIL.Image.fromarray(frame).save(buf_movie, 'png')
                encoded_movie = base64.b64encode(buf_movie.getvalue()).decode('utf-8')

                # Side-by-side view: camera image on the left, one prediction row
                # per model width on the right (labels: 幅 = width, 予測 = prediction).
                html = \
                f"""
                <div style="display: flex; justify-content: center; align-items: start; width: 90%;">
                  <div style="flex: 1; display: flex; justify-content: center; align-items: center;">
                    <img src="data:image/png;base64,{encoded_movie}" style="width: 90%; height: 90%; object-fit: contain;">
                  </div>
                  <div style="flex: 1; display: flex; flex-direction: column; justify-content: space-between;">
                    <div style="display: table;">
                      <div style="display: table-row; margin-bottom: 20px;">
                        <div style="display: table-cell; text-align: left; width: 160pt;">
                          <font size="7">幅 125:</font>
                        </div>
                        <div style="display: table-cell; text-align: left;">
                          <font size="7">予測 {predicted_numbers[125]}</font>
                        </div>
                      </div>
                      <div style="display: table-row;">
                        <div style="display: table-cell; text-align: left; width: 160pt;">
                          <font size="7">幅 1000:</font>
                        </div>
                        <div style="display: table-cell; text-align: left;">
                          <font size="7">予測 {predicted_numbers[1000]}</font>
                        </div>
                      </div>
                    </div>
                  </div>
                </div>
                """
                handle.update(IPython.display.HTML(html))

            frame_count += 1

except KeyboardInterrupt:
    # Interrupting the cell is the normal way to stop the demo.
    pass
finally:
    # Release the camera on every exit path, not only on interrupt.
    capture.release()