In [None]:
from aiy.vision.inference import ModelDescriptor, CameraInference
from aiy.vision.models import utils
from aiy.vision.streaming.server import StreamingServer
from aiy.vision.streaming import svg
from aiy.leds import Leds, Color
from aiy.board import Board

from picamera import PiCamera
from IPython.display import Image, display, clear_output

import contextlib
import time

### MobileNet classifier

Use this notebook to run your own trained network. You can train the network with this 
<a href="https://colab.research.google.com/github/tproffen/ORCSPiVision/blob/master/Training/aiy_retrain_classification.ipynb">Colab Notebook</a>.

Remember: if the Joy Detector demo is running, you need to stop it first with the command `sudo systemctl stop joy_detection_demo.service`

In [None]:
def svg_overlay(classes, frame_size):
    """Build an SVG overlay that lists the given class strings, one per line.

    Args:
        classes: iterable of strings to draw; the first one goes on the top line.
        frame_size: (width, height) of the camera frame; the SVG canvas uses the
            same dimensions so the overlay lines up with the streamed video.

    Returns:
        The SVG document serialized to a string (via ``str()``), ready to be
        passed to the streaming server.
    """
    frame_width, frame_height = frame_size
    overlay = svg.Svg(width=frame_width, height=frame_height)

    line_spacing = 50  # vertical distance between consecutive text lines
    # Rows are numbered from 1 so the first baseline sits one line below the top edge.
    for row, label in enumerate(classes, start=1):
        overlay.add(svg.Text(label, x=50, y=row * line_spacing, fill='white', font_size=50))

    return str(overlay)

We took the next two routines, `process` and `read_labels`, from <a href="AIYExamples/vision/mobilenet_based_classifier.py">mobilenet_based_classifier.py</a> in the example files.

In [None]:
def process(result, labels, tensor_name, threshold, top_k):
    """Turn one inference result into human-readable "label (score)" strings.

    Args:
        result: inference result whose ``tensors`` mapping holds exactly one
            tensor; that tensor exposes ``data`` (one score per class) and
            ``shape.depth``.
        labels: class names, indexed the same way as the scores in ``data``.
        tensor_name: key of the output tensor inside ``result.tensors``.
        threshold: only classes whose score is strictly greater than this are kept.
        top_k: maximum number of entries to return.

    Returns:
        Up to ``top_k`` strings of the form ``' <label> (<score:.2f>)'``, ordered
        from highest to lowest score (ties keep their original class order).

    Raises:
        AssertionError: if the result holds more than one tensor, or if the
            tensor depth does not match the number of labels.
    """
    # A MobileNet-based classifier produces a single output vector.
    assert len(result.tensors) == 1
    output = result.tensors[tensor_name]
    scores, output_shape = output.data, output.shape
    assert output_shape.depth == len(labels)

    confident = [(idx, score) for idx, score in enumerate(scores) if score > threshold]
    confident.sort(key=lambda item: item[1], reverse=True)

    return [' %s (%.2f)' % (labels[idx], score) for idx, score in confident[:top_k]]

In [None]:
def read_labels(label_path):
    """Read class labels from a text file, one label per line.

    Leading/trailing whitespace (including the newline) is stripped from each
    line. Blank lines are kept as empty strings, so the list index of each
    label matches its line number (0-based).

    Args:
        label_path: path of the labels file.

    Returns:
        A list of label strings in file order.
    """
    labels = []
    with open(label_path) as handle:
        for line in handle:
            labels.append(line.strip())
    return labels

### Setting up model

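This is our own retrained model, not one of the bundled AIY models, so we have to describe it ourselves: its input shape, input normalization, compute graph, and class labels.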

In [None]:
# Files produced by the retraining notebook, copied onto the Pi.
model_path = "/home/pi/MyProjects/retrained_graph.binaryproto"
label_path = "/home/pi/MyProjects/retrained_labels.txt"

# Describe the retrained network so CameraInference can run it on the VisionBonnet.
#   input_shape      -- presumably (batch, height, width, channels); must match the
#                       image size the network was retrained at (TODO confirm).
#   input_normalizer -- presumably (mean, std) applied to raw pixel values (TODO confirm).
model = ModelDescriptor(
    name='mobilenet_based_classifier',
    input_shape=(1, 160, 160, 3),
    input_normalizer=(128.0, 128.0),
    compute_graph=utils.load_compute_graph(model_path),
)

# Class names, in the same order as the network's output scores.
labels = read_labels(label_path)

#### Main loop

Here is our main loop. To watch the camera feed with the classification overlay, open http://orcspi-vis.local:4664 in a browser. Interrupt the kernel to stop the loop.

In [None]:
# Settings for the live classification loop.
detection_threshhold = 0.5      # Confidence threshold a class must exceed to be listed
top_k = 3                       # Maximum number of classes reported per frame
frame_size = (820, 616)         # Camera resolution; the overlay is drawn at the same size

with contextlib.ExitStack() as stack:
    # Every resource entered on the stack is closed when the `with` block exits,
    # including when an exception propagates out of it.
    leds   = stack.enter_context(Leds())
    camera = stack.enter_context(PiCamera(sensor_mode=4, resolution=frame_size))
    board  = stack.enter_context(Board())   # NOTE(review): `board` is not used below -- confirm it is still needed

    # This starts and runs the streaming of the camera.
    server = stack.enter_context(StreamingServer(camera))

    print("Loading model - hold on ..")

    try:
        # Do inference on the VisionBonnet. The `try` wraps CameraInference as well,
        # so an interrupt while the model is still loading is also handled.
        with CameraInference(model) as inference:
            for result in inference.run():
                # Red by default for every frame; switched to blue below when a class is found.
                leds.update(Leds.rgb_on(Color.RED))
                processed_results = process(result, labels, "final_result",
                                            detection_threshhold, top_k)

                if processed_results:
                    clear_output(wait=True)
                    leds.update(Leds.rgb_on(Color.BLUE))
                    # A distinct name avoids shadowing the inference `result` of the outer loop.
                    for line in processed_results:
                        print(line)
                    # NOTE(review): the overlay is only refreshed when something is detected,
                    # so the stream keeps showing the last detection in the meantime.
                    server.send_overlay(svg_overlay(processed_results, frame_size))
    except KeyboardInterrupt:
        # e.g. the notebook kernel was interrupted -- the normal way to stop this loop.
        print("Interrupted ..")
    finally:
        # Switch the LED off no matter how the loop ended (interrupt or error).
        leds.update(Leds.rgb_off())

    print("Done")