In [1]:
import gradio as gr

In [2]:
import torch
import torch.nn as nn
from torchvision import models
import torchvision.transforms as transforms

In [3]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): `device` is not used by any later cell shown here — the model stays on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [4]:
# Build the DenseNet-121 architecture without downloading ImageNet weights.
# The fine-tuned checkpoint loaded in cell [8] replaces every parameter (its
# strict load_state_dict reports that all keys matched), so pretrained weights
# would be thrown away anyway. Skipping them keeps the notebook runnable offline
# and avoids torchvision's deprecated `pretrained=` argument.
model = models.densenet121()
model



DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu

In [7]:
# Two-class head replacing DenseNet-121's ImageNet classifier. The layer order
# (and therefore the "classifier.<idx>" parameter names) must match the
# checkpoint loaded in the next cell.
fc = nn.Sequential(
    nn.Linear(1024, 512), nn.ReLU(),
    nn.Dropout(0.5),  # randomly zeroes activations during training to curb overfitting
    nn.Linear(512, 256), nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, 2),  # output layer: one logit per category
)
model.classifier = fc

# Mark every parameter trainable (carried over from fine-tuning; inference in
# for_client runs under torch.no_grad(), so this does not affect predictions).
for weight in model.parameters():
    weight.requires_grad_(True)

In [8]:
# Load the fine-tuned weights (backbone + 2-class head) into `model`.
# map_location="cpu": no cell shown here moves the model off the CPU, and this
# lets a checkpoint saved from GPU tensors load on machines without CUDA.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13).
CHECKPOINT_PATH = "./densenet121_7_293_clothes_data.pth"
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))

<All keys matched successfully>

In [9]:
# 추론
from PIL import Image
from io import BytesIO
import requests # 이미지 주소로 다운받는 방법
import koreanize_matplotlib
import numpy as np
import matplotlib.pyplot as plt

In [10]:
# Model output index -> photo label: 0 = '적합' (suitable), 1 = '부적합' (unsuitable).
category = dict(enumerate(['적합', '부적합']))

In [11]:
# Preprocessing for one uploaded photo; expects a PIL image as input.
transform = transforms.Compose([
    # Force 3-channel RGB first: this drops an alpha channel and expands
    # grayscale/palette images, so the 3-value Normalize below always gets
    # exactly 3 channels (slicing x[:3] after ToTensor failed for 1-channel input).
    transforms.Lambda(lambda im: im.convert("RGB")),
    transforms.Resize((224, 224)),  # presumably the training resolution — confirm
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) scaled to [0, 1]
    # Standard ImageNet per-channel mean/std used with torchvision models.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [30]:
def for_client(*images):
  """Classify uploaded review photos and return a status message for the user.

  Each positional argument is one Gradio ``gr.Image`` value: an image array,
  or ``None`` when that upload slot was left empty. Photos are checked in
  order, and as soon as one is classified as class 0 ('적합', suitable) the
  review is accepted.

  Returns:
    str: a prompt to attach a photo when every slot is empty (or no argument
    is given), the "registered" message when any photo is suitable, otherwise
    a warning asking whether to post the review anyway.
  """
  # No photo attached at all.
  if all(image is None for image in images):
    return "사진을 첨부해주세요"

  model.eval()  # inference behaviour: disables dropout, BatchNorm uses running stats
  # Feed inputs to wherever the model's weights live, so inference keeps
  # working if the model is moved to a GPU.
  model_device = next(model.parameters()).device

  with torch.no_grad():
    for array in images:
      if array is None:  # empty upload slot
        continue
      # Array -> PIL; force 3 channels so grayscale or RGBA uploads match the
      # 3-channel Normalize in `transform` and the model's first convolution.
      pil_image = Image.fromarray(array).convert("RGB")
      batch = transform(pil_image).unsqueeze(0).to(model_device)  # add batch dimension
      logits = model(batch)
      if logits.argmax(dim=1).item() == 0:  # class 0 == '적합' (suitable)
        return "후기가 등록되었습니다."

  return "사진이 적합하지 않습니다. 리뷰 포인트가 지급되지 않을 가능성이 있어도 후기를 등록하시겠습니까?"

with gr.Blocks() as app3:
  # Upload slots side by side. gr.Image yields arrays by default, which is why
  # for_client calls Image.fromarray (type='pil' would hand over PIL images).
  with gr.Row():
    images = [gr.Image(label=f"Image {n}") for n in range(1, 6)]

  with gr.Row():
    clear = gr.Button('초기화')  # reset button
    send_btn = gr.Button('제출')  # submit button

  out_text = gr.Textbox(container=False)

  # Submit: every image slot is passed positionally to for_client.
  send_btn.click(fn=for_client, inputs=images, outputs=out_text)
  # Reset: blank every image slot plus the result textbox.
  clear.click(fn=lambda: [None] * len(images) + [""], inputs=[], outputs=[*images, out_text])

app3.launch()

Running on local URL:  http://127.0.0.1:7878

To create a public link, set `share=True` in `launch()`.


