## 0. Libarary Import

In [2]:
import io
import os

import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn

import albumentations
import albumentations.pytorch


import ipywidgets as widgets
from IPython.display import Image as display_image

## 1. 모델 정의 & 설정
### TODO : 현재 모델이 너무 좋지 않아서.. 모델 파일만 받아서 load_model하는 부분도 준비하면 좋을듯

In [3]:
class MaskClassificationModel(nn.Module):
    def __init__(self, num_classes: int = 1000):
        super(MaskClassificationModel, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(64, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [4]:
def transform_image(image_bytes):
    transform = albumentations.Compose([
            albumentations.Resize(height=512, width=384),
            albumentations.Normalize(mean=(0.5, 0.5, 0.5), 
                                     std=(0.2, 0.2, 0.2)),
            albumentations.pytorch.transforms.ToTensorV2()
        ])
    image = Image.open(io.BytesIO(image_bytes))
    image = image.convert('RGB')
    image_array = np.array(image)
    return transform(image=image_array)['image'].unsqueeze(0)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MaskClassificationModel(num_classes=18).to(device)
model.eval()

MaskClassificationModel(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=64, out_features=32, bias=True)
    (2): ReLU(inplace=True)
    (3): Linear(in_features=32, out_features=18, bias=True)
  )
)

In [6]:
def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return tensor, y_hat

## 2. Inference 

In [7]:
# ipywidget으로 파일 업로더 생성
uploader = widgets.FileUpload(
    accept='.png, .jpg, .jpeg', 
    multiple=False  
)
display(uploader)


FileUpload(value={}, accept='.png, .jpg, .jpeg', description='Upload')

In [8]:
display_button = widgets.Button(description='Display Image')
display_image_space = widgets.Image()


def on_click_callback(clicked_button: widgets.Button) -> None:
    global content
    uploaded_filename = next(iter(uploader.value))
    content = uploader.value[uploaded_filename]['content']
    display_image_space.value = content

display_button.on_click(on_click_callback)

display(display_button, display_image_space)
# 처음엔 이미지 넣지 않아서 깨진 표시가 나오지만, 위에 Upload하고 누르면 이미지가 보임

Button(description='Display Image', style=ButtonStyle())

Image(value=b'')

In [9]:
inference_button = widgets.Button(description='Inference!')

inference_output = widgets.Output(layour={'border': '1px solid black'})

def on_click_callback(clicked_button: widgets.Button) -> None:
    with inference_output:
        inference_output.clear_output()
        tensor, output = get_prediction(content)
        print(output)

inference_button.on_click(on_click_callback)

display(inference_button, inference_output)


Button(description='Inference!', style=ButtonStyle())

Output()

In [10]:
content

b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.\' ",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\t\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xc0\x00\x11\x08\x02\x00\x01\x80\x03\x01"\x00\x02\x11\x01\x03\x11\x01\xff\xc4\x00\x1f\x00\x00\x01\x05\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x10\x00\x02\x01\x03\x03\x02\x04\x03\x05\x05\x04\x04\x00\x00\x01}\x01\x02\x03\x00\x04\x11\x05\x12!1A\x06\x13Qa\x07"q\x142\x81\x91\xa1\x08#B\xb1\xc1\x15R\xd1\xf0$3br\x82\t\n\x16\x17\x18\x19\x1a%&\'()*456789:CDEFGHIJSTUVWXYZcdefghijstuvwxyz\x83\x84\x85\x86\x87\x88\x89\x8a\x92\x93\x94\x95\x96\x97\x98\x99\x9a\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xd2\xd3\