# Flask를 활용한 PyTorch 모델 배포

In [1]:
# !pip install Flask==2.0.1 torchvision==0.10.0

Collecting Flask==2.0.1
  Downloading Flask-2.0.1-py3-none-any.whl (94 kB)
[K     |████████████████████████████████| 94 kB 2.3 MB/s 
[?25hCollecting torchvision==0.10.0
  Downloading torchvision-0.10.0-cp37-cp37m-manylinux1_x86_64.whl (22.1 MB)
[K     |████████████████████████████████| 22.1 MB 1.8 MB/s 
Collecting itsdangerous>=2.0
  Downloading itsdangerous-2.0.1-py3-none-any.whl (18 kB)
Collecting Jinja2>=3.0
  Downloading Jinja2-3.0.3-py3-none-any.whl (133 kB)
[K     |████████████████████████████████| 133 kB 45.3 MB/s 
[?25hCollecting Werkzeug>=2.0
  Downloading Werkzeug-2.0.2-py3-none-any.whl (288 kB)
[K     |████████████████████████████████| 288 kB 37.6 MB/s 
Collecting torch==1.9.0
  Downloading torch-1.9.0-cp37-cp37m-manylinux1_x86_64.whl (831.4 MB)
[K     |████████████████████████████████| 831.4 MB 2.6 kB/s 
Installing collected packages: Werkzeug, torch, Jinja2, itsdangerous, torchvision, Flask
  Attempting uninstall: Werkzeug
    Found existing installation: Werkzeug 1

In [1]:
from flask import Flask

app = Flask(__name__)


@app.route('/')
def hello():
    """Root endpoint: answer every request with a fixed greeting string."""
    greeting = 'Hello World!'
    return greeting

In [4]:
# app.py
!FLASK_ENV=development FLASK_APP=app.py flask run

 * Serving Flask app 'app.py' (lazy loading)
 * Environment: development
 * Debug mode: on
Usage: flask run [OPTIONS]

Error: Could not import 'app'.


In [None]:
from flask import Flask, jsonify
app = Flask(__name__)


# Placeholder /predict endpoint for app.py: no model yet, just a fixed dummy answer.
@app.route('/predict', methods=['POST'])
def predict():
    """Stub prediction view that always returns the same hard-coded JSON payload."""
    dummy_payload = {'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'}
    return jsonify(dummy_payload)

# image 준비

In [5]:
import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    """Decode raw image bytes into a normalized, batched input tensor for the model.

    Args:
        image_bytes: encoded image file contents (e.g. JPEG or PNG bytes).

    Returns:
        A 1x3x224x224 float tensor. The pipeline resizes the image, center-crops
        it to 224x224, converts it to a tensor, normalizes each channel, and adds
        a leading batch dimension of size 1.
    """
    my_transforms = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Per-channel mean / std (R, G, B) used to normalize the input.
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    # Force 3-channel RGB. Grayscale ("L"), palette ("P") or RGBA images would
    # otherwise yield a 1- or 4-channel tensor that the 3-value Normalize rejects.
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return my_transforms(image).unsqueeze(0)

In [6]:
# Convert the sample image file into the model's input tensor as a quick sanity check.
sample_path = "../_static/img/sample_file.jpeg"
with open(sample_path, 'rb') as sample_file:
    image_bytes = sample_file.read()
tensor = transform_image(image_bytes=image_bytes)
print(tensor)

FileNotFoundError: ignored

# 예측

In [None]:
import torch
from torchvision import models

# Pretrained DenseNet-121. eval() puts layers such as BatchNorm/Dropout into inference mode.
model = models.densenet121(pretrained=True)
model.eval()


def get_prediction(image_bytes):
    """Predict the top-1 class index for an encoded image.

    Args:
        image_bytes: raw encoded image bytes, as accepted by ``transform_image``.

    Returns:
        The index tensor from ``outputs.max(1)``. It has one entry per batch item,
        which is a single entry here because ``transform_image`` builds a batch of one.
    """
    tensor = transform_image(image_bytes=image_bytes)
    # Pure inference: no_grad() avoids recording an autograd graph. Calling the
    # module itself (rather than .forward) also runs any registered hooks.
    with torch.no_grad():
        outputs = model(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

In [None]:
# y_hat is only a numeric class index, so map it to a human-readable label.
import json

import torch

# Keys are class indices as strings. Each value is a [class_id, class_name] pair,
# which is how the /predict endpoint unpacks it.
with open('../_static/imagenet_class_index.json', encoding='utf-8') as index_file:
    imagenet_class_index = json.load(index_file)


def get_prediction(image_bytes):
    """Predict the label entry for an encoded image.

    Args:
        image_bytes: raw encoded image bytes, as accepted by ``transform_image``.

    Returns:
        The ``imagenet_class_index`` entry for the top-1 predicted index.

    Raises:
        KeyError: if the predicted index is missing from ``imagenet_class_index``.
    """
    tensor = transform_image(image_bytes=image_bytes)
    # Inference only: skip autograd bookkeeping; calling the module also runs hooks.
    with torch.no_grad():
        outputs = model(tensor)
    _, y_hat = outputs.max(1)
    # .item() requires a single-element tensor, i.e. exactly one image per call.
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

In [None]:
# End-to-end check: print the label entry predicted for the sample image.
sample_path = "../_static/img/sample_file.jpeg"
with open(sample_path, 'rb') as sample_file:
    image_bytes = sample_file.read()
print(get_prediction(image_bytes=image_bytes))

# API 서버와 모델 통합

In [7]:
from flask import Flask, jsonify, request

# Start from a fresh app. The stub cell above already registered a `predict` view
# on `/predict`, and Flask raises AssertionError when a different function is mapped
# to an existing endpoint name, so running the notebook top-to-bottom would fail here.
app = Flask(__name__)


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and return its class id and name as JSON.

    Expects a multipart form upload under the field name ``file``.
    """
    # methods=['POST'] already restricts this view, so no request.method check is needed.
    file = request.files['file']  # uploaded file from the "file" form field
    img_bytes = file.read()  # raw encoded image bytes
    class_id, class_name = get_prediction(image_bytes=img_bytes)
    return jsonify({'class_id': class_id, 'class_name': class_name})

In [8]:
import io
import json
import os

import torch
from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


# Path to the ImageNet class-index JSON. Override it with the
# IMAGENET_CLASS_INDEX_PATH environment variable; defaults to the working directory.
IMAGENET_CLASS_INDEX_PATH = os.environ.get('IMAGENET_CLASS_INDEX_PATH',
                                           'imagenet_class_index.json')

app = Flask(__name__)

# Maps the predicted class index (as a string) to a [class_id, class_name] pair.
with open(IMAGENET_CLASS_INDEX_PATH, encoding='utf-8') as index_file:
    imagenet_class_index = json.load(index_file)

# Pretrained DenseNet-121, switched to inference mode once at startup.
model = models.densenet121(pretrained=True)
model.eval()


def transform_image(image_bytes):
    """Decode encoded image bytes into a normalized 1x3x224x224 input tensor."""
    my_transforms = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225]),
    ])
    # Force 3-channel RGB so grayscale/palette/RGBA uploads match the 3-channel Normalize.
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    """Return the ``imagenet_class_index`` entry for the top-1 predicted class."""
    tensor = transform_image(image_bytes=image_bytes)
    # Inference only: skip autograd bookkeeping; calling the module also runs hooks.
    with torch.no_grad():
        outputs = model(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    """Classify the image uploaded under form field ``file`` and return JSON."""
    # methods=['POST'] already restricts this view, so no request.method check is needed.
    file = request.files['file']
    img_bytes = file.read()
    class_id, class_name = get_prediction(image_bytes=img_bytes)
    return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    # Starts the development server and blocks until interrupted. Inside a notebook
    # kernel __name__ is also '__main__', so running this cell will not return.
    app.run()

FileNotFoundError: ignored

### 업그레이드 방향

- 현재 `/predict`는 request에 항상 이미지 파일이 포함되어 있다고 가정한다.

- 사용자가 이미지가 아닌 파일을 보낼 수 있으므로 이에 대한 예외 처리가 필요하다.
- 모델이 많은 종류의 이미지를 인식할 수 있더라도 모든 이미지를 인식하는 것은 아니다. 따라서 모델이 제대로 인식하지 못하는 이미지(예: 예측 확률이 낮은 경우)를 걸러내는 처리가 필요하다.
- 사용자 인터페이스(UI)를 추가할 수 있다.