# UnionML Demo

Here we use the UnionML library to simplify training a PyTorch neural network on the quickdraw dataset.

A UnionML app is composed of two objects: `Dataset` and `Model`. Together, they expose decorator-based entrypoints that serve as the core building blocks of an end-to-end machine learning application. Here we run the `train` entrypoint as a remote job.
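
To make the idea of decorator entrypoints concrete, here is a minimal, self-contained sketch of the pattern. It is a toy illustration only: `ToyModel`, `fit`, and `classify` are hypothetical names invented for this example, and the class is not UnionML's `Model` or a description of how UnionML is implemented.

```python
from typing import Any, Callable, List, Optional


class ToyModel:
    """Toy stand-in (NOT UnionML's Model) with decorator-registered entrypoints."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._trainer: Optional[Callable[..., Any]] = None
        self._predictor: Optional[Callable[..., Any]] = None
        self.artifact: Any = None  # holds whatever the trainer returns

    def trainer(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Register fn as the training entrypoint and return it unchanged."""
        self._trainer = fn
        return fn

    def predictor(self, fn: Callable[..., Any]) -> Callable[..., Any]:
        """Register fn as the prediction entrypoint and return it unchanged."""
        self._predictor = fn
        return fn

    def train(self, data: Any, **hyperparameters: Any) -> Any:
        if self._trainer is None:
            raise RuntimeError("no trainer registered")
        self.artifact = self._trainer(data, **hyperparameters)
        return self.artifact

    def predict(self, features: Any) -> Any:
        if self._predictor is None or self.artifact is None:
            raise RuntimeError("register a predictor and call train() first")
        return self._predictor(self.artifact, features)


toy = ToyModel(name="mean_threshold")


@toy.trainer
def fit(data: List[float], scale: float = 1.0) -> float:
    # "Training" here just computes a scaled mean to use as a threshold.
    return scale * sum(data) / len(data)


@toy.predictor
def classify(threshold: float, features: List[float]) -> List[bool]:
    # Flag every value that lies above the learned threshold.
    return [x > threshold for x in features]


toy.train([1.0, 2.0, 3.0])
result = toy.predict([0.5, 2.5])
print(toy.artifact, result)
```

The user writes only the decorated functions. The object that owns the decorators decides how and where they run, which is what lets an app object offer both local and remote variants of the same entrypoint.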

In [2]:
from pictionary_app import model

num_classes = 200

execution = model.remote_train(
    hyperparameters={"num_classes": num_classes},
    trainer_kwargs={"num_epochs": 1},
    data_dir="./data",
    max_examples_per_class=10000,
    class_limit=num_classes,
    app_version='71ea176253225b67d5f9e3e2a59354913ffd85f0',
)

Executing quickdraw_classifier.train, execution name: fed4d7985216a4f0ead7.
Go to https://playground.hosted.unionai.cloud/console/projects/unionml/domains/development/executions/fed4d7985216a4f0ead7 to see the execution in the console.


Now, wait for the execution to complete, then load the model from the remote training job. We can then interact with the fetched model locally to generate predictions.

In [None]:
model.remote_load(execution)

In [None]:
import gradio as gr

gr.Interface(
    fn=model.predict,
    inputs="sketchpad",
    outputs="label",
    live=True,
    allow_flagging="never",
).launch()

# Flytekit Demo

## Workflow for Batch Prediction
UnionML is backed by flytekit, a general-purpose SDK for expressing Flyte workflows. In this section, we define a workflow that uses the model we trained above and runs batch prediction on a larger dataset on the same Flyte backend.

## Workflow Overview
This workflow:
1. Downloads the available quickdraw categories: `download_quickdraw_class_names`
2. Downloads the quickdraw labeled dataset: `download_quickdraw_dataset`
3. Selects random drawings from the dataset and generates their features and labels: `generate_input`
4. Pairs the model with the computed features of each selected drawing: `prepare_map_inputs`
5. Generates a prediction for each drawing: `batch_predictions_task`
6. Returns the predictions along with their corresponding features and labels


In [None]:
import torch.nn
from typing import List
from flytekit import workflow

from flytekit_demo.batch_predictions import (
    batch_predictions_task,
    download_quickdraw_class_names,
    download_quickdraw_dataset,
    generate_input,
    prepare_map_inputs,
)

@workflow
def batch_predict(
    model_object: torch.nn.Module,
    n_entries: int,
    max_items_per_class: int = 1000,
    num_classes: int = 200,
) -> (List[dict], List[torch.Tensor], List[str]):
    class_names = download_quickdraw_class_names()
    dataset = download_quickdraw_dataset(max_items_per_class=max_items_per_class, num_classes=num_classes)
    feature_list, label_list = generate_input(n_entries=n_entries, dataset=dataset, class_names=class_names)
    map_input = prepare_map_inputs(model_object=model_object, feature_list=feature_list)
    predictions = batch_predictions_task(input=map_input)
    return predictions, feature_list, label_list
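
The `prepare_map_inputs` and `batch_predictions_task` implementations live in `flytekit_demo.batch_predictions` and are not shown here. As a rough sketch of the pattern their names suggest, the plain-Python example below bundles the model with each drawing's features so that a per-item prediction function receives everything it needs in a single input. `MapItem`, `pair_model_with_features`, `predict_one`, `predict_batch`, and `toy_model` are hypothetical names used only for illustration, and the toy model is assumed to be a callable that returns a class-to-score dictionary.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class MapItem:
    """Hypothetical container: one model reference plus one drawing's features."""

    model_object: Callable[[Any], Dict[str, float]]
    feature: Any


def pair_model_with_features(
    model_object: Callable[[Any], Dict[str, float]], feature_list: List[Any]
) -> List[MapItem]:
    # Every item carries the model, so each prediction can run independently.
    return [MapItem(model_object=model_object, feature=f) for f in feature_list]


def predict_one(item: MapItem) -> Dict[str, float]:
    # Assumption: the model is callable and returns a class -> score mapping.
    return item.model_object(item.feature)


def predict_batch(items: List[MapItem]) -> List[Dict[str, float]]:
    # Sequential stand-in for a task that is mapped over the items.
    return [predict_one(item) for item in items]


def toy_model(feature: List[int]) -> Dict[str, float]:
    # Toy "classifier": scores a drawing by the fraction of filled pixels.
    filled = sum(feature) / len(feature)
    return {"dense": filled, "sparse": 1.0 - filled}


items = pair_model_with_features(toy_model, [[1, 1, 1, 0], [0, 0, 0, 1]])
batch = predict_batch(items)
print(batch)
```

Because each item is self-contained, the per-item calls do not depend on one another, which is the property that makes this kind of step easy to fan out across workers.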

## Workflow Registration
The workflow above still exists only locally, in Python memory. To tell Flyte about it, we create a `FlyteRemote` client object and register the workflow with the associated backend. The workflow runs in a prebuilt container image that includes the required dependencies. Once registered on Flyte, the workflow is immutable and versioned, and can be used to generate reproducible output data.

In [None]:
from flytekit.remote import FlyteRemote
from flytekit.configuration import Config, SerializationSettings, ImageConfig

remote = FlyteRemote(
    config=Config.auto(config_file="./config/config-remote.yaml"),
    data_upload_location="s3://open-compute-playground/data",
    default_project="unionml",
    default_domain="development",
)
image_config = ImageConfig.auto(
    img_name="ghcr.io/unionai-oss/unionml:quickdraw-classifier-71ea176253225b67d5f9e3e2a59354913ffd85f0"
)
wf = remote.register_workflow(
    batch_predict,
    version="71ea176253225b67d5f9e3e2a59354913ffd85f0",
    serialization_settings=SerializationSettings(image_config=image_config),
)

The remote client allows us to programmatically run the newly registered workflow on the hosted Flyte platform.

In [None]:
execution = remote.execute(
    wf,
    inputs={"model_object": model.artifact.model_object, "n_entries": 20, "max_items_per_class": 1000, "num_classes": 200},
    project="unionml",
    domain="development",
    wait=True,
)

Once the execution completes, we can fetch the typed outputs from the workflow run. In this case, the workflow outputs a series of predictions on the batch input. Let's go through and see how well the model performed here.

In [None]:
import math
from matplotlib import pyplot as plt

predictions = execution.outputs.get('o0')
feature_list = execution.outputs.get('o1')
label_list = execution.outputs.get('o2')

# Arrange the results in a roughly square grid; each image is titled with its prediction and original label
fig = plt.figure(figsize=(40, 40))

sqr = math.sqrt(len(predictions))
rows = math.ceil(sqr)
columns = math.floor(sqr) + 1

for i in range(len(predictions)):
    best_prediction = max(predictions[i], key=predictions[i].get)
    expected_label = label_list[i]

    fig.add_subplot(rows, columns, i + 1)
    plt.imshow(feature_list[i], interpolation='bicubic')
    plt.axis('off')
    plt.title(f"predicted: {best_prediction}\nexpected: {expected_label}")
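
Alongside the visual check, a single number can summarize how well the model did. The sketch below computes top-1 accuracy under the same assumption the plotting code makes: each prediction is a dictionary mapping class names to scores, and labels are class-name strings. `top1_accuracy` and the sample data are hypothetical and made up for illustration; they are not outputs of the actual run.

```python
from typing import Dict, List, Tuple


def top1_accuracy(
    predictions: List[Dict[str, float]], labels: List[str]
) -> Tuple[float, List[int]]:
    """Return (accuracy, indices of misclassified examples)."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not predictions:
        return 0.0, []
    misses = [
        i
        for i, (scores, expected) in enumerate(zip(predictions, labels))
        # The highest-scoring class is the model's top-1 prediction.
        if max(scores, key=scores.get) != expected
    ]
    accuracy = (len(predictions) - len(misses)) / len(predictions)
    return accuracy, misses


# Made-up sample data, not results from the workflow execution.
sample_predictions = [
    {"cat": 0.7, "dog": 0.3},
    {"cat": 0.2, "dog": 0.8},
    {"cat": 0.6, "dog": 0.4},
    {"cat": 0.1, "dog": 0.9},
]
sample_labels = ["cat", "dog", "dog", "dog"]

accuracy, misses = top1_accuracy(sample_predictions, sample_labels)
print(f"top-1 accuracy: {accuracy:.2f}, misclassified indices: {misses}")
```

In the notebook, the same function could be applied to `predictions` and `label_list`, and the returned indices could be used to inspect only the misclassified drawings.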