In [None]:
import mlflow
import mlflow.pyfunc
import mlflow.pytorch
from  mlflow.tracking import MlflowClient
import os
import numpy as np
import torch
from torchvision import datasets, transforms
import pandas as pd
from PIL import Image
import base64
import io

In [None]:
# Point MLflow at the tracking server given by the TRACKING_URL environment variable
# and make `experiment_name` the active experiment, creating it if it does not exist.
from mlflow.exceptions import MlflowException

experiment_name = "pytorch_exp_1"
# If TRACKING_URL is unset this is None and MLflow falls back to its own default
# tracking-URI resolution.
tracking_uri = os.environ.get("TRACKING_URL")
client = MlflowClient(tracking_uri=tracking_uri)
mlflow.set_tracking_uri(tracking_uri)

# `MlflowClient.list_experiments` was removed in MLflow 2.0; looking the experiment up
# by name works on both 1.x and 2.x and avoids listing every experiment on the server.
if client.get_experiment_by_name(experiment_name) is None:
    try:
        mlflow.create_experiment(experiment_name)
    except MlflowException as exc:
        # Typically another process created it between the lookup and the create.
        # Report it rather than silently swallowing every possible error.
        print(f"Could not create experiment {experiment_name!r}: {exc}")
mlflow.set_experiment(experiment_name)

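### Download the MNIST PNG files from https://github.com/myleott/mnist_png

Run the following in a shell. The cells below read the extracted test images from the `mnist_png/testing/` directory it creates.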

```
cd ~/
git clone https://github.com/myleott/mnist_png
cd mnist_png
tar xvzf mnist_png.tar.gz
```

In [None]:
import os
from glob import glob
_path = "/home/ubuntu/mnist_png/testing/"
file_list = [y for x in os.walk(_path) for y in glob(os.path.join(x[0], '*.png'))]

In [None]:
def _b64_encode_file(path):
    """Read ``path`` as raw bytes and return its base64 encoding as a text string."""
    with open(path, "rb") as fh:
        raw_bytes = fh.read()
    return base64.b64encode(raw_bytes).decode('utf-8')


# Keep at most 100 images (all of them if there are fewer).
test_size = min(len(file_list), 100)
file_list = file_list[:test_size]
encoded_string_list = [_b64_encode_file(img_path) for img_path in file_list]
# One row per image; the single "images" column holds the base64 strings.
input_df = pd.DataFrame(encoded_string_list, columns=["images"])

In [None]:
# Load the logged model as a generic pyfunc directly from the run's artifacts.
# A `runs:/<run_id>/<artifact_path>` URI resolves the same "pytorch-model" artifact
# that `mlflow.get_artifact_uri("pytorch-model")` returns inside the run. Unlike
# `mlflow.start_run(run_id=...)`, it does not resume the run and then end it again,
# so the run's status and end time are left untouched.
run_id = "0a8b2e7d97e14eaea54da39a39edd7fa"  # NOTE(review): hard-coded; set to your training run's ID
loaded_model = mlflow.pyfunc.load_model(f"runs:/{run_id}/pytorch-model")

# Make a few predictions on the base64-encoded images prepared above.
predictions = loaded_model.predict(input_df)
predictions

### Sample code to test the REST API endpoint deployment

In [None]:
import requests

# Send the same images to the deployed REST endpoint. The body uses pandas' "split"
# orientation, which the Content-Type header announces to the server.
# Alternate deployment: https://mscvzavj.babyrocket.net/invocations
url = "https://mtynlegi.babyrocket.net/invocations"
data_json = input_df.to_json(orient="split", index=False)
headers = {"Content-Type": "application/json; format=pandas-split"}
try:
    # Bound the wait so the cell cannot hang indefinitely on an unresponsive endpoint.
    response = requests.post(url, data=data_json, headers=headers, timeout=60)
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
    # While the deployment is still starting, the host may refuse or stall connections.
    response = None
if response is not None and response.status_code == 200:
    print(response.json())
else:
    print("REST API deployment is in progress -- please try again in a few minutes!")
    if response is not None:
        # Show the real status so client errors (e.g. 400 for a malformed payload)
        # are not mistaken for a deployment that is still coming up.
        print(f"HTTP {response.status_code}: {response.text[:500]}")