<p>Note: This page and the embedded model are for demonstration and learning purposes only (MSBA7028 Deep Learning).</p>

References / Credits
<p>Source code for this deployment:<br>
	https://www.bilibili.com/video/BV1Qv41117SR <br>
	https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/deploying_service/deploying_pytorch/pytorch_flask_service<br>
	</p>
<p>Background Image: <br>
	https://www.flickr.com/photos/fools4tress/33532935721/in/photostream/lightbox/
	</p>


In [1]:
import os
import io
import json
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request, render_template
from flask_cors import CORS
from torch import nn

In [2]:
# Create the Flask application; passing __name__ tells Flask where to look for
# this app's templates/ and static/ folders (used by render_template and /static).
app = Flask(__name__)
CORS(app)  # Enable Cross-Origin Resource Sharing so pages served from other origins can call this API.

<flask_cors.extension.CORS at 0x7fa69cf78a00>

In [3]:
# Number of output classes; must match the fc layer stored in the fine-tuned
# checkpoint (load_state_dict is strict and will reject a shape mismatch).
NUM_CLASSES = 8

# Build the bare ResNet-18 architecture. ImageNet-pretrained weights are NOT
# downloaded: the strict load_state_dict call performed after this cell
# overwrites every parameter and buffer with the fine-tuned checkpoint, so
# fetching them would only add a network dependency and startup latency.
# (Calling resnet18() without arguments also avoids the deprecated
# `pretrained` flag while remaining compatible with older torchvision.)
finetune_net = torchvision.models.resnet18()
# Replace the classification head with one sized for our classes.
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, NUM_CLASSES)
model = finetune_net

In [4]:
weights_path = "finetune_resnet18_param.pt"
class_json_path = "class_indices.json"

# Fail fast with an explicit exception rather than `assert`: assert statements
# are stripped when Python runs with -O, which would let a missing file surface
# later as a less obvious error. isfile() also rejects a directory of that name.
if not os.path.isfile(weights_path):
    raise FileNotFoundError(f"weights path does not exist: {weights_path}")
if not os.path.isfile(class_json_path):
    raise FileNotFoundError(f"class json path does not exist: {class_json_path}")

In [5]:
# Select the compute device: the first CUDA GPU when one is available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cpu


In [6]:
# Move the network onto the selected device, then overwrite its parameters with
# the fine-tuned checkpoint. map_location remaps tensors that were saved on a
# different device onto `device`.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
model = finetune_net.to(device)
fine_tuned_state = torch.load(weights_path, map_location=device)
model.load_state_dict(fine_tuned_state)
del fine_tuned_state  # drop the extra reference to the loaded weights

# Switch to inference mode (BatchNorm layers use their stored running statistics).
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [7]:
# Load the class-index -> label mapping (keys are stringified class indices,
# looked up as class_indict[str(index)] during prediction).
# The `with` block guarantees the file handle is closed; binary mode lets
# json.load auto-detect the UTF-8/16/32 encoding.
with open(class_json_path, 'rb') as json_file:
    class_indict = json.load(json_file)

In [8]:
# Preprocessing pipeline shared by every request; built once at import time
# instead of on each call. Normalization uses the usual ImageNet channel
# mean/std values.
# NOTE(review): Resize(255) is kept unchanged; typical ImageNet evaluation
# pipelines use 256 — confirm it matches the validation transform used in training.
_EVAL_TRANSFORMS = transforms.Compose([transforms.Resize(255),
                                       transforms.CenterCrop(224),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           [0.485, 0.456, 0.406],
                                           [0.229, 0.224, 0.225])])


def transform_image(image_bytes):
    """Decode encoded image bytes into a normalized, batched tensor on `device`.

    Args:
        image_bytes: raw content of an image file (e.g. an uploaded JPEG/PNG).

    Returns:
        A float tensor of shape (1, 3, 224, 224) placed on the global `device`.

    Raises:
        ValueError: if the decoded image is not in RGB mode.
        Any error raised by PIL's Image.open if the bytes are not a readable image.
    """
    image = Image.open(io.BytesIO(image_bytes))
    if image.mode != "RGB":
        raise ValueError("input file is not an RGB image...")
    # unsqueeze(0) adds the batch dimension expected by the model.
    return _EVAL_TRANSFORMS(image).unsqueeze(0).to(device)

In [9]:
def get_prediction(image_bytes):
    """Classify an encoded image and describe its most likely class.

    Args:
        image_bytes: raw bytes of the uploaded image file.

    Returns:
        A dict ``{"result": [text]}`` where ``text`` is
        ``"Detected Damage Type: <label>"`` for the top-1 class, or the
        string form of the exception if preprocessing or inference failed.
    """
    try:
        tensor = transform_image(image_bytes=image_bytes)
        # Call the module itself (not .forward) so any registered hooks run;
        # no_grad keeps this safe when called outside the decorated route.
        with torch.no_grad():
            logits = model(tensor)
            # squeeze() drops the size-1 batch dimension so softmax runs over classes.
            probabilities = torch.softmax(logits.squeeze(), dim=0)
        prediction = probabilities.detach().cpu().numpy()  # move off the device for Python-side use
        template = "Detected Damage Type: {:}"
        # (label, probability) pairs; labels are looked up by stringified class index.
        index_pre = [(class_indict[str(index)], float(p)) for index, p in enumerate(prediction)]
        # max() returns the first of any tied entries, matching the order a
        # stable descending sort would have produced.
        top_label, _ = max(index_pre, key=lambda item: item[1])
        return_info = {"result": [template.format(top_label)]}
    except Exception as e:
        # Deliberate best-effort API: report the error text to the client
        # instead of failing the request.
        return_info = {"result": [str(e)]}
    return return_info

In [10]:
@app.route("/predict", methods=["POST"])
@torch.no_grad()
def predict4():
    """Classify the image uploaded in the multipart field "file" and return JSON."""
    uploaded = request.files["file"]
    return jsonify(get_prediction(image_bytes=uploaded.read()))


@app.route("/", methods=["GET", "POST"])
def root4():
    """Serve the image-upload page."""
    page = "up.html"
    return render_template(page)


if __name__ == '__main__':
    # Host and port can be overridden through the environment; the defaults
    # reproduce the original behaviour. Binding to 0.0.0.0 makes Flask's
    # development server reachable from other machines on the network.
    # NOTE: app.run starts the development server, not a production WSGI server.
    app.run(host=os.environ.get("FLASK_RUN_HOST", "0.0.0.0"),
            port=int(os.environ.get("FLASK_RUN_PORT", "5000")))

 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://10.64.199.205:5000 (Press CTRL+C to quit)
127.0.0.1 - - [14/Apr/2022 14:33:25] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [14/Apr/2022 14:33:26] "GET /static/js/jquery.min.js HTTP/1.1" 304 -
