## 0. Library Import

In [None]:
import io
import os
import yaml

import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F

from efficientnet_pytorch import EfficientNet


import albumentations
import albumentations.pytorch


import ipywidgets as widgets
from IPython.display import Image as display_image

## 1. 모델 정의 & 설정
### 사전에 학습된 모델을 로딩

In [None]:
# TODO: this path/config setup is somewhat messy and should be tidied up.
asset_dir = "../../assets/mask_task/"

# Read the task configuration. The encoding is given explicitly so that any
# non-ASCII entries are decoded the same way on every platform instead of
# depending on the locale's default encoding.
# NOTE(review): yaml.load with FullLoader is kept as-is; it should only be fed
# trusted, local files -- consider yaml.safe_load if the config is plain data.
with open(os.path.join(asset_dir, "config.yaml"), encoding="utf-8") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Checkpoint file: the config's 'model_name' entry, resolved inside asset_dir.
model_path = os.path.join(asset_dir, config['model_name'])


In [None]:
class MyEfficientNet(nn.Module):
    """EfficientNet-b4 in which only the output layer size is changed.

    The network predicts over all classes in a single head (18 classes in
    this task) and returns softmax probabilities rather than raw logits.
    """

    def __init__(self, num_classes: int = 1000):
        super().__init__()
        # Pretrained b4 backbone with a classifier sized to `num_classes`.
        backbone = EfficientNet.from_pretrained(
            'efficientnet-b4',
            in_channels=3,
            num_classes=num_classes,
        )
        self.EFF = backbone

    def forward(self, x) -> torch.Tensor:
        """Run the backbone on `x` and apply softmax over dim 1."""
        logits = self.EFF(x)
        return F.softmax(logits, dim=1)

In [None]:
def transform_image(image_bytes):
    """Decode raw image bytes and build the model's input tensor.

    The image is converted to RGB, resized to height 512 / width 384,
    normalized with mean (0.5, 0.5, 0.5) and std (0.2, 0.2, 0.2), converted
    to a tensor, and given a leading dimension via ``unsqueeze(0)``.
    """
    rgb_image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    pixels = np.array(rgb_image)

    pipeline = albumentations.Compose([
        albumentations.Resize(height=512, width=384),
        albumentations.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2)),
        albumentations.pytorch.transforms.ToTensorV2(),
    ])
    augmented = pipeline(image=pixels)
    return augmented['image'].unsqueeze(0)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MyEfficientNet(num_classes=18).to(device)

# map_location=device remaps the checkpoint's tensors onto the selected
# device, so a single code path covers both the CPU and the CUDA case and the
# checkpoint no longer has to be restored onto the device it was saved from.
model.load_state_dict(torch.load(model_path, map_location=device))

# Put the module in evaluation mode; as the last expression of the cell this
# also displays the model.
model.eval()

In [None]:
def get_prediction(image_bytes):
    """Classify an image given as raw encoded bytes.

    Args:
        image_bytes: Encoded image data; forwarded to ``transform_image``.

    Returns:
        A ``(tensor, y_hat)`` tuple: the preprocessed input tensor produced by
        ``transform_image`` and the index of the highest-scoring class along
        dim 1 of the model output.
    """
    tensor = transform_image(image_bytes=image_bytes)
    # The input has to live on the same device as the model's weights;
    # otherwise the forward pass fails once the model has been moved to a GPU.
    model_device = next(model.parameters()).device
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(tensor.to(model_device))
    _, y_hat = outputs.max(1)
    return tensor, y_hat

## 2. Inference 

In [None]:
# Create a file uploader with ipywidgets (a single .png / .jpg / .jpeg file).
uploader = widgets.FileUpload(accept='.png, .jpg, .jpeg', multiple=False)
display(uploader)


In [None]:
display_button = widgets.Button(description='Display Image')
display_image_space = widgets.Image()


def on_click_callback(clicked_button: widgets.Button) -> None:
    """Show the first uploaded file and keep its bytes in the global ``content``.

    Does nothing when no file has been uploaded yet, instead of raising from
    an empty upload value.

    Args:
        clicked_button: The button that triggered the callback (unused).
    """
    global content
    uploaded = uploader.value
    if not uploaded:
        # Nothing uploaded yet: leave the preview and `content` untouched.
        return
    if isinstance(uploaded, dict):
        # Mapping form: {filename: {'content': ..., ...}} -- take the first file.
        first_file = next(iter(uploaded.values()))
    else:
        # Sequence form (per the ipywidgets 8 FileUpload docs the value is a
        # tuple of per-file dicts) -- take the first file.
        first_file = uploaded[0]
    # Normalize to bytes so that both a bytes object and a memoryview work.
    content = bytes(first_file['content'])
    display_image_space.value = content

display_button.on_click(on_click_callback)

display(display_button, display_image_space)
# The image widget looks broken at first because it is still empty; upload a
# file above and click the button to see the image.
# 처음엔 이미지 넣지 않아서 깨진 표시가 나오지만, 위에 Upload하고 누르면 이미지가 보임

In [None]:
inference_button = widgets.Button(description='Inference!')

# `layout` is the keyword that applies the border to the output area.
inference_output = widgets.Output(layout={'border': '1px solid black'})

def on_click_callback(clicked_button: widgets.Button) -> None:
    """Run the model on the currently loaded image and print its class label.

    The image bytes are read from the global ``content``. If it has not been
    set yet, a hint is printed instead of failing with a NameError.

    Args:
        clicked_button: The button that triggered the callback (unused).
    """
    with inference_output:
        inference_output.clear_output()
        image_bytes = globals().get('content')
        if image_bytes is None:
            print("No image loaded yet: upload a file and click 'Display Image' first.")
            return
        _, output = get_prediction(image_bytes)

        # Map the predicted class index to its entry in the config's class list.
        print(config['classes'][output.item()])

inference_button.on_click(on_click_callback)

display(inference_button, inference_output)
