In [None]:
%load_ext autoreload
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))

In [None]:
%load_ext line_profiler

In [None]:
from stepper.driver import move
import time

In [None]:
%autoreload 2
from pi_py_darknet.darknet import initialize, detect
import cv2
from IPython.display import Image, display

In [None]:
from picamera import PiCamera
import time

In [None]:
## Initialize darknet
# Load the YOLO network once. `net` and `meta` are module-level globals that
# detect_people() passes to detect() on every frame.
# NOTE(review): `meta` presumably carries the class-name metadata that makes
# detect() return category names such as b'person' -- confirm in pi_py_darknet.
net, meta = initialize()

In [None]:
## Initialize camera
# Make this cell safe to re-run: release a camera left open by a previous run.
# Opening a second PiCamera while one is still open typically fails with an
# out-of-resources error.
if 'camera' in globals():
    camera.close()

camera = PiCamera()
camera.rotation = 180  # captures are rotated 180 degrees (presumably an upside-down mount)
camera.resolution = (640, 480)  # process_image assumes 640-pixel-wide frames
time.sleep(2)  # let the sensor settle exposure/gain before the first capture

In [None]:
# Manual stepper check. Judging from process_image, the first argument is an
# amount in rough 'revolutions' and the second a direction (0 is used for
# "moving left", 1 for "moving right").
# NOTE(review): confirm these semantics against stepper.driver.move.
move(1, 0)

In [None]:
# Grab a single frame from the camera to check the setup; the file is
# overwritten each time this cell runs.
img_filename = 'capture.jpg'

camera.capture(img_filename)

In [None]:
def detect_people(img_filename):
    """Basic person detector: run the YOLO model on an image file and keep only 'person' hits.

    Parameters
    ----------
    img_filename : str, bytes or os.PathLike
        Path of the image to analyse. darknet's ``detect`` is called with the
        path as UTF-8 bytes; a path that is already bytes is passed through
        unchanged.

    Returns
    -------
    list of dict
        One dict per detected person, in the order ``detect`` returned them,
        with keys ``category`` (str), ``score``, ``center``, ``top_left``,
        ``bottom_right`` and ``target``. The last four are ``(x, y)`` integer
        pixel tuples.
    """
    if isinstance(img_filename, bytes):
        path_bytes = img_filename
    else:
        path_bytes = str(img_filename).encode('utf-8')

    results = detect(net, meta, path_bytes)

    people = []
    for cat, score, bounds in results:
        if cat != b'person':
            continue

        # bounds is treated as (centre_x, centre_y, width, height); the box
        # corners below are derived from that interpretation.
        x, y, w, h = bounds

        people.append(dict(
            category=cat.decode("utf-8"),
            score=score,
            center=(int(x), int(y)),
            top_left=(int(x - w / 2), int(y - h / 2)),
            bottom_right=(int(x + w / 2), int(y + h / 2)),
            # Aim point: one sixth of the box height above its centre (upper torso).
            target=(int(x), int(y - h / 6)),
        ))

    return people



In [None]:
def label(img_filename, people):
    """Use OpenCV to draw bounding boxes, target markers and text labels for detected people.

    Parameters
    ----------
    img_filename : str
        Image to load with ``cv2.imread``.
    people : list of dict
        Detections as produced by ``detect_people`` (uses the keys
        ``top_left``, ``bottom_right``, ``target``, ``center``, ``category``
        and ``score``).

    Returns
    -------
    numpy.ndarray
        The loaded image with the annotations drawn on it.

    Raises
    ------
    OSError
        If OpenCV cannot read ``img_filename``.
    """
    img = cv2.imread(img_filename)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with a clear message rather than inside cv2.rectangle.
        raise OSError('Could not read image {}'.format(img_filename))

    for n, person in enumerate(people):
        # OpenCV colours are BGR: (255, 0, 0) is blue, (255, 255, 0) is cyan.
        cv2.rectangle(img, person['top_left'], person['bottom_right'], (255, 0, 0), thickness=2)

        cv2.circle(img, person['target'], 3, (255, 255, 0))  # target torso

        # Named `text` so the local does not shadow this function's own name.
        text = '{}[{}]:{:06.3f}'.format(person['category'], n, person['score'] * 100)
        cv2.putText(img, text, person['center'], cv2.FONT_HERSHEY_COMPLEX, 1, (255, 255, 0))

    return img

In [None]:
def test_run(result_widget=None, text_widget=None, start=20, count=10,
             pattern='test_captures/raw_{}.jpg'):
    """Replay saved captures through ``process_image`` instead of using the live camera.

    Processes ``pattern.format(start + n)`` for ``n`` in ``range(count)`` and
    stops at the first file that does not exist. ``n`` is 0-based (not the
    file number) and is passed to ``process_image`` as the frame index.

    Parameters
    ----------
    result_widget, text_widget : optional
        Forwarded unchanged to ``process_image``; either may be None.
    start : int
        Number of the first capture file (default 20, i.e. ``raw_20.jpg``).
    count : int
        Maximum number of files to process.
    pattern : str
        ``str.format`` template for the capture file paths.
    """
    import os

    for n in range(count):
        img_filename = pattern.format(start + n)
        # Detect the end of the input explicitly. Catching FileNotFoundError
        # around process_image would also hide unrelated FileNotFoundErrors
        # raised while saving/reloading the annotated result image.
        if not os.path.isfile(img_filename):
            break
        process_image(img_filename, n=n, result_widget=result_widget, text_widget=text_widget)
        
def run(result_widget=None, text_widget=None, img_filename='capture.jpg', max_frames=None):
    """Live loop: capture a camera frame, then detect people and steer via ``process_image``.

    Parameters
    ----------
    result_widget, text_widget : optional
        Forwarded unchanged to ``process_image``; either may be None.
    img_filename : str
        Path each camera frame is written to (overwritten on every iteration).
    max_frames : int, optional
        Stop after this many frames. The default None loops until the kernel
        is interrupted.
    """
    n = 0
    while max_frames is None or n < max_frames:
        camera.capture(img_filename)
        process_image(img_filename, n=n, result_widget=result_widget, text_widget=text_widget)
        n += 1
        
def process_image(img_filename, result_widget=None, text_widget=None, n=0, img_width=640):
    """Detect people in one image, optionally display an annotated copy, and steer the stepper.

    Parameters
    ----------
    img_filename : str
        Image to run ``detect_people`` on.
    result_widget : ipywidgets.Image, optional
        When given, the image is annotated with ``label``, saved to
        ``results/test_{n}.jpg``, and the saved bytes are shown in this widget.
    text_widget : ipywidgets.HTML, optional
        When given, receives a status message whenever the stepper is moved.
        May be None (e.g. when profiling ``test_run()`` without widgets).
    n : int
        Frame index, used in the output filename and the timing printout.
    img_width : int
        Frame width in pixels (640 matches the camera resolution set in this
        notebook). The stepper tries to keep the target near ``img_width / 2``.

    Returns
    -------
    list of dict
        The detections returned by ``detect_people``.

    Raises
    ------
    OSError
        If the annotated image cannot be written to ``results/``.
    """
    start_time = time.time()

    people = detect_people(img_filename)

    if result_widget is not None:
        ## Label image with OpenCV, save it, then show the saved file in the widget
        img = label(img_filename, people)
        out_file = 'results/test_{}.jpg'.format(n)
        # cv2.imwrite reports failure (e.g. a missing results/ directory) by
        # returning False rather than raising.
        if not cv2.imwrite(out_file, img):
            raise OSError('Could not write annotated image to {}'.format(out_file))
        with open(out_file, "rb") as result_file:
            result_widget.value = result_file.read()

    ## Drive stepper towards the first detected person
    if people:
        target_x, _ = people[0]['target']

        gain = 1/240  ## rough estimate of 'revolutions' per pixel
        pterm = abs(target_x - img_width / 2) * gain

        # Dead band: only move when the target lies outside the middle 10% of the frame.
        direction = None
        if target_x < img_width * 0.45:
            direction, side = 1, 'right'
        elif target_x > img_width * 0.55:
            direction, side = 0, 'left'

        if direction is not None:
            if text_widget is not None:
                text_widget.value = 'Person detected at {}: moving {}'.format(target_x, side)
            move(pterm, direction)

    elapsed_time = time.time() - start_time
    print(n, len(people), '{:5.2f} seconds'.format(elapsed_time))

    return people


In [None]:
import ipywidgets as widgets

# Seed the viewer with a previously saved result if there is one, so this cell
# also runs on a fresh checkout where results/test_0.jpg does not exist yet.
try:
    with open('results/test_0.jpg', "rb") as image_file:
        image = image_file.read()
except FileNotFoundError:
    image = b''

result_widget = widgets.Image(
    value=image,
    format='jpeg',  # results/test_*.jpg are JPEG files, not PNG
    width=1024,
    height=768,
)


text_widget = widgets.HTML(
    value='Init',
    placeholder='',
    description='',
)

widgets.VBox([text_widget, result_widget])


In [None]:
# Live camera loop: capture -> detect -> steer, updating the widgets above.
# It loops forever; stop it with Kernel -> Interrupt before running later cells.
run(result_widget=result_widget, text_widget=text_widget)

In [None]:
# Offline check: replay the saved frames in test_captures/ instead of the camera.
test_run(result_widget=result_widget, text_widget=text_widget)

In [None]:
# Line-profile test_run. No widgets are passed, so process_image skips the
# labelling/saving step and only detection and steering are timed.
%lprun -f test_run test_run()