# W&B Production Monitoring Overview

This notebook demonstrates how to monitor production models with W&B through an illustrative example. We train a model to identify handwritten digits, then monitor a locally deployed version of it. A Gradio app running in the notebook lets a user draw ("handwrite") a character with the mouse and give live feedback by labeling it as a digit from 0 to 9.

_Note: To keep the example focused on the relevant code, much of the dataset manipulation, modeling, and other utilities is packaged in local files (`model_util`, `app_util`) and imported here._

# Step 0: Setup & import dependencies

In [None]:
!pip install tensorflow
!pip install gradio
!pip install weave

Log in to W&B to sync these examples to your W&B account, where you can view, interact with, and customize the resulting Tables and Boards.

In [None]:
import wandb
wandb.login()

Set your W&B entity (username or team name) and optionally rename the destination project.

In [None]:
WB_ENTITY = "shawn"  # replace with your W&B username or team name
WB_PROJECT = "prodmon_mnist"  # optionally rename the destination project

# Step 1: Get data
In this example, we load the MNIST dataset with `keras.datasets.mnist.load_data()`, wrapped by the local `model_util.get_dataset()` helper, and display the first image.

In [None]:
import model_util

dataset = model_util.get_dataset()
model_util.image_from_array(dataset[0][0])

# Step 2: Train model
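Next, we train a simple fully connected neural network (no convolutional layers) to predict the digits. We train for a single epoch so that the model still makes some mistakes worth monitoring.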

In [None]:
model = model_util.train_model(*dataset, conv_layers=0, epochs=1) # 1 epoch so we can actually see some errors

# Step 3: Query model
Now, let's query the model! Making a prediction usually requires a little pre- and post-processing, so we write a short function to handle it.

In [None]:
import numpy as np
import json

def simple_predict(image_arr):
    # Preprocess: cast to float32 and add batch and channel axes -> shape (1, 28, 28, 1)
    tensor = image_arr.astype("float32").reshape(1, 28, 28, 1)

    # Make the prediction; the output holds one score per digit for the single input
    output = model.predict(tensor, verbose=False)

    # Postprocess: convert the scores to a JSON-friendly dict keyed by digit
    raw_predictions = output[0].tolist()
    logits = {str(digit): score for digit, score in enumerate(raw_predictions)}

    # The predicted class is the digit with the highest score
    prediction = int(np.argmax(raw_predictions))

    return {"logits": logits, "prediction": prediction}

_, _, x_test, y_test = dataset
for i in range(10):
    image_arr = x_test[i]
    truth = y_test[i]
    preds = simple_predict(image_arr)
    
    print(f"Input: {truth}")
    display(model_util.image_from_array(image_arr))
    print(f"Prediction: {preds['prediction']}")
    print(f"Logits: {json.dumps(preds['logits'], indent=2)}")
    print("")
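The pre- and post-processing inside `simple_predict` can be split into standalone helpers. This lets you check them without a trained model.

The sketch below is hypothetical. The names `preprocess_image` and `postprocess_output` are illustrative and not part of `model_util`. Like `simple_predict`, the sketch does not rescale pixel values. It assumes a 28x28 input array and a model output of shape `(1, 10)`.

```python
import numpy as np


def preprocess_image(image_arr: np.ndarray) -> np.ndarray:
    """Cast a 28x28 image to float32 and add batch and channel axes -> (1, 28, 28, 1)."""
    return image_arr.astype("float32").reshape(1, 28, 28, 1)


def postprocess_output(output: np.ndarray) -> dict:
    """Turn a (1, 10) model output into the dict shape returned by simple_predict."""
    scores = output[0].tolist()  # plain Python floats, safe to serialize
    logits = {str(digit): score for digit, score in enumerate(scores)}
    prediction = int(np.argmax(scores))  # digit with the highest score
    return {"logits": logits, "prediction": prediction}


# Usage with a fake model output in which digit 7 scores highest
fake_output = np.zeros((1, 10), dtype="float32")
fake_output[0, 7] = 0.9
result = postprocess_output(fake_output)
print(result["prediction"])  # 7
print(preprocess_image(np.zeros((28, 28), dtype="uint8")).shape)  # (1, 28, 28, 1)
```

With the helpers in place, `simple_predict` reduces to `postprocess_output(model.predict(preprocess_image(image_arr), verbose=False))`.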
    

# Step 3A: Save predictions with W&B Weave using StreamTable
With W&B's Weave library, we can stream arbitrary records, such as model inputs, predictions, and ground-truth labels, to W&B for storage and further analysis.

In [None]:
import weave
weave.use_frontend_devmode()
from weave.legacy.weave.monitoring import StreamTable

# Initialize a stream table
# (optionally change the name argument to any string
# that follows the entity_name/project_name/table_name format)
st = StreamTable(f"{WB_ENTITY}/{WB_PROJECT}/logged_predictions")
_, _, x_test, y_test = dataset
for i in range(100):
    image_arr = x_test[i]
    truth = y_test[i].tolist()
    preds = simple_predict(image_arr)
    
    # Log the data
    st.log({
        **preds,
        "image": model_util.image_from_array(image_arr),
        "truth": truth
    })

# Optional: wait for the logs to finish uploading (nicer for live demos)
st.finish()

# Show the StreamTable
st    
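Every row logged above carries both the model's `prediction` and the ground-truth `truth`. This means simple monitoring metrics can be computed directly from those rows.

Below is a minimal sketch under a few assumptions:

- The rows are shaped like the dicts passed to `st.log`, with the image field omitted.
- The helper name `summarize_rows` is hypothetical.
- The example rows are illustrative, not logged results.

```python
from collections import Counter


def summarize_rows(rows):
    """Compute accuracy and the most common (truth, prediction) mistakes.

    Each row is assumed to be a dict with integer "prediction" and "truth" keys,
    like the records logged to the StreamTable above.
    """
    if not rows:
        return {"accuracy": None, "top_errors": []}
    correct = sum(row["prediction"] == row["truth"] for row in rows)
    errors = Counter(
        (row["truth"], row["prediction"])
        for row in rows
        if row["prediction"] != row["truth"]
    )
    return {"accuracy": correct / len(rows), "top_errors": errors.most_common(3)}


# Illustrative rows: three correct predictions and one 4 mistaken for a 9
rows = [
    {"prediction": 7, "truth": 7},
    {"prediction": 2, "truth": 2},
    {"prediction": 9, "truth": 4},
    {"prediction": 1, "truth": 1},
]
print(summarize_rows(rows))
# {'accuracy': 0.75, 'top_errors': [((4, 9), 1)]}
```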

# Step 3B: Save predictions with W&B Weave using `monitor` decorator
Logging a function's inputs and outputs is such a common pattern that Weave provides a decorator that does it automatically.

In [None]:
from weave.legacy.weave.monitoring import monitor
import numpy as np

mon = monitor.init_monitor(f"{WB_ENTITY}/{WB_PROJECT}/monitor_predict_function")

def preprocess(span):
    # Log the input as an image instead of a raw pixel array
    span.inputs['image'] = model_util.image_from_array(span.inputs['image_arr'])
    del span.inputs['image_arr']

@mon.trace(
    # A preprocessor lets you transform the function arguments before they are logged
    preprocess=preprocess
)
def monitor_predict(image_arr):
    # Preprocess: cast to float32 and add batch and channel axes -> shape (1, 28, 28, 1)
    tensor = image_arr.astype("float32").reshape(1, 28, 28, 1)

    # Make the prediction
    output = model.predict(tensor, verbose=False)

    # Postprocess: convert the scores to a dict keyed by digit
    raw_predictions = output[0].tolist()
    logits = {str(digit): score for digit, score in enumerate(raw_predictions)}

    # The predicted class is the digit with the highest score
    prediction = int(np.argmax(raw_predictions))

    return {"logits": logits, "prediction": prediction}

_, _, x_test, y_test = dataset
for i in range(100):
    image_arr = x_test[i]
    truth = y_test[i].tolist()
    # The decorated function also accepts monitor_attributes, used here to log the ground-truth label
    preds = monitor_predict(image_arr, monitor_attributes={'truth': truth})
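To make the decorator's behavior concrete, here is a minimal, hypothetical sketch of an I/O-logging decorator in plain Python. It is not Weave's implementation. The `trace` function, the `Span` class, and the in-memory `LOG` list are illustrative assumptions. They mirror only the pieces used above:

- a `preprocess` hook that edits the logged inputs;
- an extra `monitor_attributes` keyword argument.

```python
import functools
import inspect

LOG = []  # hypothetical in-memory sink standing in for a W&B table


class Span:
    """Record of one function call: its inputs, output, and extra attributes."""

    def __init__(self, inputs, attributes):
        self.inputs = inputs
        self.attributes = attributes
        self.output = None


def trace(preprocess=None):
    """Return a decorator that logs each call's inputs and output to LOG."""

    def decorator(func):
        signature = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, monitor_attributes=None, **kwargs):
            # Bind the arguments to parameter names so they can be logged by name
            bound = signature.bind(*args, **kwargs)
            span = Span(dict(bound.arguments), dict(monitor_attributes or {}))
            span.output = func(*args, **kwargs)
            # Let the caller reshape what gets logged (e.g. array -> image)
            if preprocess is not None:
                preprocess(span)
            LOG.append({"inputs": span.inputs, "output": span.output, **span.attributes})
            return span.output

        return wrapper

    return decorator


def drop_raw(span):
    # Replace a bulky input with a compact summary before it is logged
    span.inputs["length"] = len(span.inputs.pop("values"))


@trace(preprocess=drop_raw)
def total(values):
    return sum(values)


print(total([1, 2, 3], monitor_attributes={"truth": 6}))  # 6
print(LOG[-1])  # {'inputs': {'length': 3}, 'output': 6, 'truth': 6}
```

The preprocess hook only changes the logged copy of the inputs. The wrapped function always receives the original arguments.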

# Step 4: End-to-end example
A production application typically contains a prediction service that serves predictions to a client. To demonstrate this in a notebook, we create a `PredictionService` and an `AppUI`. The `AppUI` is a small interface that lets the user draw an image, view the prediction, and give feedback on the result (in this case, by labeling a hand-drawn digit 0-9). The two communicate via `predict` and `record_feedback` methods.

_Note: This is purely for illustration; your production systems may vary widely in structure._

In [None]:
# TODO: not yet working with new API

# import app_util
# from weave.legacy.weave.monitoring import monitor
# import PIL
# import numpy as np

# class PredictionService(app_util.PredictionServiceInterface):
#     def __init__(self, model):
#         self.model = model
#         self.last_prediction = {}
    
#     @monitor(auto_log = False,  entity_name=WB_ENTITY, project_name=WB_PROJECT)
#     def _raw_predict(self, pil_image: PIL.Image) -> dict:
#         # Prepare image for model
#         tensor = (np.array(pil_image.resize((28, 28))).astype("float32") / 255).reshape(1, 28, 28, 1)

#         # Make the prediction
#         prediction = self.model.predict(tensor, verbose=False)

#         # In this application, we need to reshape the output:
#         raw_predictions = prediction[0].tolist()
#         logits = {
#             str(k): v for k, v in zip(range(10), raw_predictions)
#         }

#         prediction = np.argmax(raw_predictions).tolist()

#         return {"logits": logits, "prediction": prediction}
    
#     def _update_last_prediction(self, prediction) -> None:
#         if len(self.last_prediction) > 0:
#             last_pred = self.last_prediction.pop(list(self.last_prediction.keys())[0])
#             last_pred.finalize()
#         self.last_prediction[prediction.id] = prediction

    
#     def predict(self, pil_image: PIL.Image) -> app_util.Prediction:
#         record = self._raw_predict(pil_image)
        
#         # Cache the last prediction for ground_truth recording
#         self._update_last_prediction(record)
        
#         # Return the prediction
#         return app_util.Prediction(record.get()['logits'], record.id)
    
#     def record_feedback(self, prediction_id: str, feedback: int) -> None:
#         if prediction_id not in self.last_prediction:
#             return

#         # Get the past prediction
#         prediction = self.last_prediction.pop(prediction_id)
        
#         # Save the user feedback
#         prediction.add_data({'user_feedback': feedback})
        
#         # Log the results
#         prediction.finalize()
        
# app_util.render_app(PredictionService(model))
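The disabled cell above pairs each prediction with an id and caches the most recent one. When feedback arrives, it attaches the user's label to that cached prediction.

A minimal, hypothetical sketch of that flow in plain Python follows. `FeedbackService`, `FakeModel`, and the `records` list are illustrative stand-ins; the sketch makes no W&B or `app_util` calls. Like the cell above, it finalizes a pending prediction in one of two cases: when feedback arrives, or when a newer prediction replaces it.

```python
import uuid


class FakeModel:
    """Stand-in model that always returns the same ten scores."""

    def predict(self, image):
        return [0.0] * 3 + [1.0] + [0.0] * 6  # digit 3 scores highest


class FeedbackService:
    def __init__(self, model):
        self.model = model
        self.pending = {}  # prediction_id -> record awaiting feedback (at most one)
        self.records = []  # finalized records, standing in for logged rows

    def _finalize(self, record):
        self.records.append(record)

    def predict(self, image):
        scores = self.model.predict(image)
        record = {
            "id": uuid.uuid4().hex,
            "logits": {str(d): s for d, s in enumerate(scores)},
            "prediction": max(range(len(scores)), key=scores.__getitem__),
        }
        # Only the latest prediction can receive feedback; finalize any older one
        for old in self.pending.values():
            self._finalize(old)
        self.pending = {record["id"]: record}
        return record

    def record_feedback(self, prediction_id, feedback):
        record = self.pending.pop(prediction_id, None)
        if record is None:
            return  # stale or unknown id: ignore it, as the original cell does
        record["user_feedback"] = feedback
        self._finalize(record)


service = FeedbackService(FakeModel())
first = service.predict(image=None)
second = service.predict(image=None)  # finalizes `first` without feedback
service.record_feedback(second["id"], 8)  # the user labels the digit as an 8
service.record_feedback(first["id"], 3)  # ignored: no longer pending
print([r.get("user_feedback") for r in service.records])  # [None, 8]
```

Caching only the latest prediction keeps memory bounded. The trade-off is that feedback on an older drawing is dropped.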