## 0. Library Imports

In [None]:
import io
import os
import yaml

import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F

from efficientnet_pytorch import EfficientNet

import albumentations
import albumentations.pytorch

import ipywidgets as widgets
from IPython.display import Image as display_image

from typing import Tuple

## 1. 모델 정의 & 설정
### 사전에 학습된 모델을 로딩

In [None]:
# Load runtime settings from config.yaml; later cells read
# config['model_path'] (checkpoint to load) and config['classes'] (labels).
# An explicit UTF-8 encoding avoids locale-dependent decoding in case the
# labels contain non-ASCII text; safe_load suffices because only plain
# mappings, lists and scalars are needed.
with open("config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

In [None]:
class MyEfficientNet(nn.Module):
    """EfficientNet-b4 with only its output layer replaced.

    A single head predicts all ``num_classes`` classes (18 by default) at
    once. ``forward`` returns softmax probabilities rather than raw logits.
    """

    def __init__(self, num_classes: int = 18):
        super().__init__()
        # The attribute name ``EFF`` is part of the state-dict keys of saved
        # checkpoints, so it must not be renamed.
        self.EFF = EfficientNet.from_pretrained(
            'efficientnet-b4',
            in_channels=3,
            num_classes=num_classes,
        )

    def forward(self, x) -> torch.Tensor:
        """Return per-class probabilities (softmax over dim 1) for batch ``x``."""
        return F.softmax(self.EFF(x), dim=1)

In [None]:
# Inference-time preprocessing. It is built once here instead of on every
# call to transform_image.
_INFERENCE_TRANSFORM = albumentations.Compose([
    albumentations.Resize(height=512, width=384),
    albumentations.Normalize(mean=(0.5, 0.5, 0.5),
                             std=(0.2, 0.2, 0.2)),
    albumentations.pytorch.transforms.ToTensorV2(),
])


def transform_image(image_bytes: bytes) -> torch.Tensor:
    """Decode an encoded image file and turn it into a single-image batch.

    The image is converted to RGB, resized to 512x384 (height x width),
    normalised, and converted to a CHW tensor with a leading batch axis.

    Args:
        image_bytes: Raw bytes of an encoded image (e.g. PNG or JPEG).

    Returns:
        Tensor of shape (1, 3, 512, 384).
    """
    # The context manager releases PIL's decoder resources once the pixels are copied out.
    with Image.open(io.BytesIO(image_bytes)) as image:
        image_array = np.array(image.convert('RGB'))
    return _INFERENCE_TRANSFORM(image=image_array)['image'].unsqueeze(0)

In [None]:
# Prefer the GPU when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the 18-class network on that device, then overwrite its weights with
# the checkpoint named in config['model_path'].
model = MyEfficientNet(num_classes=18)
model = model.to(device)
model.load_state_dict(
    torch.load(config['model_path'], map_location=device)
)
# Switch to evaluation mode. This stays the cell's last expression so the
# notebook still echoes the model.
model.eval()

In [None]:
def get_prediction(image_bytes: bytes) -> Tuple[torch.Tensor, torch.Tensor]:
    """Classify one encoded image with the notebook-global ``model``.

    Args:
        image_bytes: Raw bytes of an encoded image file.

    Returns:
        ``(tensor, y_hat)``: the preprocessed input batch (batch size 1) and
        the index of the highest-scoring class as a 1-element tensor. Both
        live on the model's device.
    """
    # Follow the device the weights actually live on rather than re-deriving
    # it, so input and model can never end up on different devices.
    model_device = next(model.parameters()).device
    tensor = transform_image(image_bytes=image_bytes).to(model_device)
    # Pure inference: skip autograd graph construction to save memory/time.
    # Calling the module (not .forward) keeps any registered hooks working.
    with torch.no_grad():
        outputs = model(tensor)
    _, y_hat = outputs.max(1)
    return tensor, y_hat

## 2. Inference 


In [None]:
# Single-file picker. The browser dialog is restricted to PNG/JPEG files.
uploader = widgets.FileUpload(multiple=False, accept='.png, .jpg, .jpeg')
display(uploader)

In [None]:
def on_display_click_cb(clicked_button: widgets.Button) -> None:
    """Render the uploaded image and keep its bytes for the inference step.

    Stores the raw file bytes in the notebook-global ``content`` (later read
    by the inference callback) and shows them in ``display_image_space``.
    Does nothing if no file has been uploaded yet.

    Args:
        clicked_button: The button that fired the event (unused).
    """
    global content
    uploaded = uploader.value
    if not uploaded:
        # Previously next(iter(...)) raised StopIteration inside the callback.
        return
    if isinstance(uploaded, dict):
        # ipywidgets 7.x: {filename: {..., 'content': bytes}}
        first_file = next(iter(uploaded.values()))
    else:
        # ipywidgets 8.x: tuple of dicts whose 'content' is a memoryview.
        first_file = uploaded[0]
    content = bytes(first_file['content'])
    display_image_space.value = content

In [None]:
# Area the uploaded picture is rendered into, plus the button that fills it.
display_image_space = widgets.Image()
display_button = widgets.Button(description='Display Image')
display_button.on_click(on_display_click_cb)

display(display_button, display_image_space)

In [None]:
def on_inference_click_cb(clicked_button: widgets.Button) -> None:
    """Classify the image stored by the "Display Image" step and print it.

    Output goes to the ``inference_output`` widget, which is cleared first.
    The predicted index tensor is printed, followed by its label from
    ``config['classes']``.

    Args:
        clicked_button: The button that fired the event (unused).
    """
    with inference_output:
        inference_output.clear_output()
        try:
            image_bytes = content
        except NameError:
            # ``content`` is only created by the "Display Image" callback;
            # previously this surfaced as an unexplained NameError.
            print("No image loaded yet: upload a file and click 'Display Image' first.")
            return
        _, output = get_prediction(image_bytes)
        print(output, config['classes'][output.item()])

In [None]:
# "Inference" button plus a bordered output area that receives the prediction.
inference_button = widgets.Button(description='Inference')
# '1px' (was the typo '1ox', an invalid CSS length, so no border was drawn).
inference_output = widgets.Output(layout={'border': '1px solid black'})

inference_button.on_click(on_inference_click_cb)
display(inference_button, inference_output)