In [0]:
# Import the libraries
import io
import json

import torch
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


In [0]:
# Initialize the Flask app
app = Flask(__name__)

# Class-index lookup table. It is indexed with str(predicted_index) and each value
# is unpacked as (class_id, class_name) by the /predict route.
# The `with` block closes the file instead of leaking the handle.
with open('image_class_index.json', encoding='utf-8') as class_index_file:
    imagenet_class_index = json.load(class_index_file)

# Choose the device before loading, so the checkpoint is mapped straight onto it.
# Without map_location, a checkpoint saved from a GPU fails to load on a CPU-only host.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the whole pickled model object (it is used as a module below, not a state_dict).
# SECURITY: torch.load unpickles arbitrary objects, so only load trusted checkpoint files.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects whole-model
# pickles. On those versions pass weights_only=False for this trusted checkpoint.
model = torch.load("best_model.pth", map_location=device)
model = model.to(device)
# Switch layers such as dropout/batch-norm to inference behaviour.
model.eval()

In [0]:
# Image transform function
# Preprocessing pipeline, built once at import time rather than on every request.
_IMAGE_TRANSFORM = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Standard ImageNet per-channel mean / std normalization.
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])


def transform_image(image_bytes):
    """Decode encoded image bytes into a normalized, batched model input tensor.

    Args:
        image_bytes: raw contents of an image file, in any format PIL can open.

    Returns:
        A float tensor of shape (1, 3, 224, 224), with a batch dimension of 1 added.

    Raises:
        PIL.UnidentifiedImageError (an OSError subclass): if the bytes are not
        a recognizable image.
    """
    # Force 3 channels. Grayscale ("L"), RGBA or palette ("P") images would
    # otherwise reach the 3-channel Normalize with the wrong channel count.
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    return _IMAGE_TRANSFORM(image).unsqueeze(0)

In [0]:
# Get Prediction
def get_prediction(image_bytes):
    """Run the model on one encoded image and look up the predicted class.

    Args:
        image_bytes: raw bytes of an uploaded image file.

    Returns:
        The ``imagenet_class_index`` entry for the top-scoring class. The caller
        unpacks it as ``(class_id, class_name)``.
    """
    tensor = transform_image(image_bytes=image_bytes)
    # Inference only: no_grad avoids building and holding an autograd graph per request.
    with torch.no_grad():
        # Call the module itself rather than .forward(), so registered hooks run.
        outputs = model(tensor.to(device))
    # Assumes outputs is (batch=1, num_classes); take the top class index.
    # The lookup table is keyed by the index as a string.
    predicted_idx = str(outputs.argmax(dim=1).item())
    return imagenet_class_index[predicted_idx]

In [0]:
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an image uploaded as the multipart form field ``file``.

    Returns:
        JSON ``{"class_id": ..., "class_name": ...}`` on success. Otherwise JSON
        ``{"error": ...}`` with HTTP 400 when the upload is missing, empty, or
        not a decodable image.
    """
    # The route only accepts POST, so a request.method check is unnecessary.
    # Checking it anyway would leave a code path that returns None.
    file = request.files.get('file')
    if file is None:
        return jsonify({'error': 'no file uploaded under form field "file"'}), 400

    img_bytes = file.read()
    if not img_bytes:
        return jsonify({'error': 'uploaded file is empty'}), 400

    try:
        class_id, class_name = get_prediction(image_bytes=img_bytes)
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for undecodable
        # data. Report it as a client error instead of a 500.
        return jsonify({'error': 'uploaded file is not a valid image'}), 400
    return jsonify({'class_id': class_id, 'class_name': class_name})

In [0]:
# Start Flask's built-in development server only when this code is the top-level
# module. Inside a notebook kernel __name__ is also "__main__", so this call
# blocks the cell until the server is interrupted.
if __name__ == "__main__":
    app.run()