# Making Use of Your Model

We are going to use the saved model architecture in `final_model.py`, as well as a saved checkpoint in `final_model_checkpoint.pth` to augment a Flask API.

In [6]:
import flask
from flask import request
import final_model

import torch

In [10]:
# here final_model is a local python file and .Classifier is a class stored in that file
def load_model_checkpoint(path):
    checkpoint = torch.load(path)
    model = final_model.Classifier(checkpoint["input"])
    model.load_state_dict(checkpoint["state_dict"])
    return model
model = load_model_checkpoint("final_model_checkpoint.pth")

Now we can create some input data to predict with. I'll be hardcoding the tensor but in reality this could be any data that has been converted to a tensor of floats.

In [12]:
t = torch.tensor([[0.0606, 0.5000, 0.3333, 0.4828, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000,
               0.4000, 0.1651, 0.0869, 0.0980, 0.1825, 0.1054, 0.2807, 0.0016, 0.0000,
               0.0033, 0.0027, 0.0031, 0.0021]]).float()
print(t.shape)

torch.Size([1, 22])


In [16]:
# create a prediction using the example t
pred = model(t)
pred = torch.exp(pred)
top_p, top_class_test = pred.topk(1, dim=1)
print(top_class_test)

tensor([[1]])


In [19]:
# convert our model to torchscript
traced_script = torch.jit.trace(model, t)
traced_pred = traced_script(t)
top_p, top_class_test = pred.topk(1, dim=1)
print(top_class_test)

tensor([[1]])


Now we can create a basic flask app as before, as extend it to return a prediction based on passed in tensor data.

In [26]:
app = flask.Flask(__name__)
app.config["DEBUG"] = True
@app.route('/prediction', methods=['POST'])

def prediction():
    body = request.get_json()
    example = torch.tensor(body['data']).float()
    
    pred = model(example)
    pred = torch.exp(pred)
    _, top_class_test = pred.topk(1, dim=1)
    top_class_test = top_class_test.numpy()
    
    return {"status":"ok", "result":int(top_class_test[0][0])}

app.run(debug=True, use_reloader=False)

 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
[33mPress CTRL+C to quit[0m
127.0.0.1 - - [29/Jul/2024 21:49:55] "POST /prediction HTTP/1.1" 200 -
