## 0. Libarary Import

In [1]:
import io
import os
import yaml

import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F

from efficientnet_pytorch import EfficientNet


import albumentations
import albumentations.pytorch


import ipywidgets as widgets
from IPython.display import Image as display_image

from typing import Tuple


## 1. 모델 정의 & 설정
### 사전에 학습된 모델을 로딩

In [2]:
with open("config.yaml") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

In [3]:
class MyEfficientNet(nn.Module) :
    '''
    EfiicientNet-b4의 출력층만 변경합니다.
    한번에 18개의 Class를 예측하는 형태의 Model입니다.
    '''
    def __init__(self, num_classes: int = 18) :
        super(MyEfficientNet, self).__init__()
        self.EFF = EfficientNet.from_pretrained('efficientnet-b4', in_channels=3, num_classes=num_classes)
    
    def forward(self, x) -> torch.Tensor:
        x = self.EFF(x)
        x = F.softmax(x, dim=1)
        return x

In [4]:
def transform_image(image_bytes) -> torch.Tensor:
    transform = albumentations.Compose([
            albumentations.Resize(height=512, width=384),
            albumentations.Normalize(mean=(0.5, 0.5, 0.5), 
                                     std=(0.2, 0.2, 0.2)),
            albumentations.pytorch.transforms.ToTensorV2()
        ])
    image = Image.open(io.BytesIO(image_bytes))
    image = image.convert('RGB')
    image_array = np.array(image)
    return transform(image=image_array)['image'].unsqueeze(0)

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MyEfficientNet(num_classes=18).to(device)
model.load_state_dict(torch.load(config['model_path'], map_location=device))
model.eval()

Loaded pretrained weights for efficientnet-b4


MyEfficientNet(
  (EFF): EfficientNet(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 48, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)
    )
    (_bn0): BatchNorm2d(48, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          48, 48, kernel_size=(3, 3), stride=[1, 1], groups=48, bias=False
          (static_padding): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
        )
        (_bn1): BatchNorm2d(48, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          48, 12, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          12, 48, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (

In [6]:
def get_prediction(image_bytes: bytes) -> Tuple[torch.Tensor, torch.Tensor]:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tensor = transform_image(image_bytes=image_bytes).to(device)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return tensor, y_hat

## 2. Inference 
- TODO : 파일 업로더 생성
- TODO : 버튼 클릭시 이미지 보이기
- TODO : 인퍼런스 버튼 클릭시 인퍼런스 실행

In [38]:
uploader = widgets.FileUpload(
    accept='.png, .jpg, .jpeg', 
    multiple=False  
)

display(uploader)

FileUpload(value={}, accept='.png, .jpg, .jpeg', description='Upload')

In [14]:
for i in uploader.value:
    print(i)

성적증명서_201703618_최지민.jpg


In [15]:
next(iter(uploader.value))

'성적증명서_201703618_최지민.jpg'

In [16]:
type(uploader.value)

dict

In [None]:
uploader.value['성적증명서_201703618_최지민.jpg']['content']
# b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x01,\x01,\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.\' ",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\t\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xc0\x00\x11\x08\r\xb4\t\xf6\x03\x01"\x00\x02\x11\x01\x03\x11\x01\xff\xc4\x00\x1f\x00\x00\x01\x05\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x10\x00\x02\x01\x03\x03\x02\x04\x03\x05\x05\x04\x04\x00\x00\x01}\x01\x02\x03\x00\x04\x11\x05\x12!1A\x06

In [27]:
def on_display_click_callback(clicked_button: widgets.Button) -> None:
    global content
    uploaded_filename = next(iter(uploader.value))
    content = uploader.value[uploaded_filename]['content']
    display_image_space.value = content

In [28]:
display_button = widgets.Button(description='Display Image')
display_image_space = widgets.Image()

In [29]:
display_button.on_click(on_display_click_callback)
display(display_button, display_image_space)
# 처음엔 이미지 넣지 않아서 깨진 표시가 나오지만, 위에 Upload하고 누르면 이미지가 보임

Button(description='Display Image', style=ButtonStyle())

Image(value=b'')

In [35]:
def on_inference_click_callback(clicked_button: widgets.Button) -> None:
    with inference_output:
        inference_output.clear_output()
        _, output = get_prediction(content)
        # print(output)
        print(config['classes'][output.item()])

In [31]:
inference_button = widgets.Button(description='Inference!')

inference_output = widgets.Output(layout={'border': '1px solid black'})

In [32]:
inference_button.on_click(on_inference_click_callback)

display(inference_button, inference_output)

Button(description='Inference!', style=ButtonStyle())

Output(layout=Layout(border='1px solid black'))

In [33]:
config

{'classes': {0: ['Wear', 'Male', 'under 30'],
  1: ['Wear', 'Male', 'between 30 and 60'],
  2: ['Wear', 'Male', 'over 60'],
  3: ['Wear', 'Female', 'under 30'],
  4: ['Wear', 'Female', 'between 30 and 60'],
  5: ['Wear', 'Female', 'over 60'],
  6: ['Incorrect', 'Male', 'under 30'],
  7: ['Incorrect', 'Male', 'between 30 and 60'],
  8: ['Incorrect', 'Male', 'over 60'],
  9: ['Incorrect', 'Female', 'under 30'],
  10: ['Incorrect', 'Female', 'between 30 and 60'],
  11: ['Incorrect', 'Female', 'over 60'],
  12: ['Not Wear', 'Male', 'under 30'],
  13: ['Not Wear', 'Male', 'between 30 and 60'],
  14: ['Not Wear', 'Male', 'over 60'],
  15: ['Not Wear', 'Female', 'under 30'],
  16: ['Not Wear', 'Female', 'between 30 and 60'],
  17: ['Not Wear', 'Female', 'over 60']},
 'model_path': '../../assets/mask_task/model.pth'}

In [37]:
inference_button.on_click(on_inference_click_callback)

display(inference_button, inference_output)

Button(description='Inference!', style=ButtonStyle())

Output(layout=Layout(border='1px solid black'), outputs=({'name': 'stdout', 'text': "tensor([17], device='cuda…