# In [None]:
from flask import Flask, jsonify, request
import torch
from torchvision import transforms
from PIL import Image
import io

# In [ ]:
# WSGI application object; the route handlers below register against it.
app = Flask(__name__)

# In [ ]:
# Load the trained model once at import time; get_prediction reuses this
# module-level object for every request.
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a
# CPU-only host, and keeps the weights on the same device as the CPU input
# tensors produced by transform_image.
# NOTE(review): the path is resolved relative to the current working
# directory, not this file.
# NOTE(review): torch.load unpickles the file, so only load trusted
# checkpoints. This file appears to hold a full nn.Module (eval() is called on
# it). PyTorch >= 2.6 defaults to weights_only=True, which rejects such
# pickles. Either pass weights_only=False (trusted file only) or save and load
# a state_dict instead.
model = torch.load('../data/model.pth', map_location='cpu')
# Put layers such as dropout / batch-norm into inference behaviour.
model.eval()

# In [ ]:
# Preprocessing pipeline, built once and shared by all calls.
# The mean/std values are the commonly used ImageNet statistics; presumably
# the model was trained with the same normalization (TODO confirm).
_PREPROCESS = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])


def transform_image(image_bytes):
    """Decode encoded image bytes into a normalized, batched input tensor.

    The image is converted to RGB. It is then resized so its shorter side is
    255 px, center-cropped to 224x224, converted to a tensor, and normalized
    per channel.

    Args:
        image_bytes: Encoded image file contents in any format PIL can open.

    Returns:
        A float tensor of shape (1, 3, 224, 224).

    Raises:
        OSError: PIL raises ``UnidentifiedImageError`` (an ``OSError``
            subclass) when the bytes are not a recognizable image.
    """
    with Image.open(io.BytesIO(image_bytes)) as image:
        # RGBA, grayscale and palette images would otherwise yield a 4- or
        # 1-channel tensor that the 3-channel Normalize cannot handle.
        rgb_image = image.convert('RGB')
    # unsqueeze(0) adds the batch dimension the model expects.
    return _PREPROCESS(rgb_image).unsqueeze(0)



def get_prediction(image_bytes):
    """Run the loaded model on an encoded image and return the top class index.

    Args:
        image_bytes: Encoded image file contents, passed to transform_image.

    Returns:
        int: index of the highest-scoring output along dimension 1.

    Raises:
        Whatever transform_image raises for undecodable input.
    """
    tensor = transform_image(image_bytes)
    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        # Call the module itself rather than .forward() so registered hooks run.
        outputs = model(tensor)
    return outputs.argmax(dim=1).item()


# In [ ]:
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image.

    Expects a POST whose uploaded files include one under the ``image`` field.

    Returns:
        ``({"prediction": <int>}, 200)`` on success.
        ``({"error": ...}, 400)`` if no file was sent, the file is empty, or
        its contents cannot be decoded as an image.
    """
    upload = request.files.get('image')
    if upload is None:
        return jsonify({'error': 'No image provided'}), 400

    image_bytes = upload.read()
    # Browsers submit an empty file part when the user picks no file.
    if not image_bytes:
        return jsonify({'error': 'Empty image'}), 400

    try:
        prediction = get_prediction(image_bytes)
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for bytes it
        # cannot decode. That is bad client input, not a server fault.
        return jsonify({'error': 'Invalid image'}), 400
    return jsonify({'prediction': prediction}), 200