In [7]:
import flask
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms as T
import io
from PIL import Image

In [8]:
# Flask application that serves the /predict endpoint defined below.
app = flask.Flask(__name__)

# Select the inference device once; the model and the input tensors are
# placed on it by the functions below.
useGPU = torch.cuda.is_available()
print(f'Inference on gpu: {useGPU}')
device = torch.device('cuda' if useGPU else 'cpu')

def load_model(weights_path='model_weights.pth', num_classes=70):
    """Build the ResNet-18 classifier and load the fine-tuned weights into it.

    The resulting network is stored in the module-level ``model`` global,
    moved to ``device`` and switched to eval mode (disables Dropout).

    Args:
        weights_path: path of the saved ``state_dict`` to load.
        num_classes: size of the final output layer; must match the checkpoint.
    """
    global model
    # Architecture only: no ImageNet weight download is needed, because
    # load_state_dict (strict by default) overwrites every parameter below.
    model = torchvision.models.resnet18()

    # Replacement classification head. It must match the architecture the
    # checkpoint was saved from; it outputs log-probabilities.
    model.fc = nn.Sequential(
        nn.Linear(model.fc.in_features, 256),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(256, num_classes),
        nn.LogSoftmax(dim=1),
    )
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()
    
def prepare_image(img):
    """Turn a PIL image into a single-image batch ready for the model.

    Resizes the shorter side to 256, center-crops 224x224 and converts to a
    float tensor, then adds a batch dimension. The result has shape
    (1, 3, 224, 224) and lives on ``device``.

    The composed pipeline is also published as the module-level
    ``transforms`` global.
    """
    global transforms
    # Deterministic pipeline only: random augmentations such as
    # RandomHorizontalFlip belong to training and would make the same upload
    # yield different predictions on different requests.
    # NOTE(review): there is no Normalize step; confirm this matches the
    # preprocessing used when the weights were trained.
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
    ])

    # Uploads may be RGBA (PNG), grayscale or palette images; the network
    # expects exactly 3 input channels.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    return transforms(img).unsqueeze(0).to(device)

TRAIN_DIR = "./datasets/DogBreeds/train"
_class_names = None


def _get_class_names():
    """Return the list mapping class index -> label, scanning TRAIN_DIR once.

    Building an ImageFolder walks every file of the training set, so the
    resulting ``classes`` list is cached instead of rebuilt per request.
    """
    global _class_names
    if _class_names is None:
        # NOTE(review): assumes the model was trained with ImageFolder on this
        # same directory, so its output indices follow this class ordering.
        _class_names = torchvision.datasets.ImageFolder(TRAIN_DIR).classes
    return _class_names


@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image and return the top-3 labels as JSON.

    Expects a multipart upload in the ``image`` form field. Responds with
    ``{"success": False}`` when no file was uploaded or it cannot be decoded
    as an image. Otherwise ``success`` is True and ``predictions`` is a list
    of ``{"label": ..., "prop": ...}`` dicts in descending probability order
    (``prop`` holds the probability; the key name is kept for existing
    clients).
    """
    data = {"success": False}

    # The route only accepts POST, so no method check is needed here.
    upload = flask.request.files.get("image")
    if not upload:
        return flask.jsonify(data)

    try:
        # The upload is read as raw bytes and decoded with PIL.
        image = Image.open(io.BytesIO(upload.read()))
        # Image.open is lazy; decode now so corrupt/truncated uploads are
        # reported as a failed request instead of an unhandled 500 error.
        image.load()
    except OSError:
        return flask.jsonify(data)

    batch = prepare_image(image)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # The model ends in LogSoftmax; softmax is shift-invariant, so
        # softmax(log_softmax(x)) yields the class probabilities.
        probs = F.softmax(model(batch), dim=1)

    k = min(3, probs.size(1))
    top = torch.topk(probs.cpu(), k=k, dim=1)
    labels = _get_class_names()

    # Batch size is 1, so take row 0 rather than squeeze(), which would
    # collapse to a non-iterable 0-d tensor when k == 1.
    data["predictions"] = [
        {"label": labels[idx], "prop": prob}
        for idx, prob in zip(top.indices[0].tolist(), top.values[0].tolist())
    ]
    data["success"] = True

    return flask.jsonify(data)

if __name__ == '__main__':
    # Load the weights before serving, so the first /predict request does not
    # find the `model` global undefined.
    for message in ("Loading PyTorch model and Flask starting server ...",
                    "Please wait until server has fully started"):
        print(message)
    load_model()
    app.run()

Train on gpu: True
Loading PyTorch model and Flask starting server ...
Please wait until server has fully started
 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
127.0.0.1 - - [13/Feb/2022 13:45:26] "[37mPOST /predict HTTP/1.1[0m" 200 -


True


127.0.0.1 - - [13/Feb/2022 13:45:30] "[37mPOST /predict HTTP/1.1[0m" 200 -


True
