# Neural Net observatory

In [None]:
%load_ext autoreload
%autoreload 2
#%matplotlib inline
%matplotlib widget

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['figure.max_open_warning'] = 0
import ipywidgets as widgets

Fetch our tools:

In [None]:
from nn import Network, Layer, IdentityLayer, AffineLayer, MapLayer
from nnbench import NNBench
from nnvis import NNVis

We use [`ipywidgets`](https://ipywidgets.readthedocs.io/en/latest/index.html) to display training progress and to control the worker.

# Multiprocessing
We run the net training in a child process, so that it can proceed while we observe and analyze partial results.

## Tooling
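 * `JSONConn`: JSON-encoded messages over the `multiprocessing` Pipe.
   * We were not seeing an `EOFError` from `recv()` when the other end closed (likely because the parent process still holds its own copy of `child_conn` open), so a close is signalled in-band with a non-JSON message of four zero bytes.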

In [None]:
#from multiprocessing import Process, Pipe
import multiprocessing as mp
import json
from time import sleep

class JSONConn:
    """JSON-message wrapper around a `multiprocessing` Connection.

    Each message is one UTF-8 encoded JSON document sent with `send_bytes` and
    read with `recv_bytes`. Closing is signalled in-band with a sentinel of four
    zero bytes, which `json.dumps` can never produce.
    """

    # In-band "this end is closing" marker; not valid JSON, so it cannot collide
    # with a real message.
    CLOSE_SENTINEL = bytes(4)

    def __init__(self, conn):
        # conn: presumably one end of `multiprocessing.Pipe()`; must provide
        # send_bytes / recv_bytes / poll / close / closed.
        self.conn = conn

    def send(self, v):
        """Serialize `v` to JSON and send it as a single message."""
        self.conn.send_bytes(json.dumps(v).encode('utf8'))

    def poll(self):
        """Return whether data is available to read (delegates to `conn.poll()`)."""
        return self.conn.poll()

    def recv(self):
        """Receive and decode one JSON message.

        Raises:
            EOFError: the peer sent the close sentinel (this end is then closed
                as well, echoing the sentinel back), or the underlying
                connection reported end-of-file.
        """
        r = self.conn.recv_bytes()
        if r == self.CLOSE_SENTINEL:
            self.close()
            raise EOFError
        return json.loads(r)

    def close(self):
        """Send the close sentinel (best effort) and close the connection.

        Calling this more than once is a no-op, and it does not fail when the
        peer has already closed its end.
        """
        if self.conn.closed:
            return
        try:
            self.conn.send_bytes(self.CLOSE_SENTINEL)
        except (BrokenPipeError, ConnectionResetError):
            # The peer is already gone, so there is nobody left to notify.
            pass
        finally:
            self.conn.close()

## The child

In [None]:
def f(conn, max_steps=100):
    """Child-process worker: train a small tanh network and stream progress to the parent.

    Each loop pass:
      1. drains all pending instruction dicts from the parent;
      2. reports state if asked to, or on every 10th batch not yet reported;
      3. runs one learning step if fewer than 'batch to' batches have run,
         otherwise idles briefly;
      4. sends one (possibly empty) dict back to the parent.

    Instructions understood (keys of an incoming dict):
      'eta'        -- set net.eta to the given value
      'batch to'   -- target number of batches to train up to
      'tell state' -- report eta and the state vector on this pass
      'shutdown'   -- finish this pass, then stop

    Messages sent (keys of an outgoing dict, each tagged with the batch count):
      'eta'  -- [batch_ctr, net.eta]
      'sv'   -- [batch_ctr, list of floats from net.state_vector()]
      'loss' -- [batch_ctr, loss returned by net.learn()]

    Args:
        conn: the child end of a multiprocessing Pipe; wrapped in JSONConn.
        max_steps: cap on loop passes, so the worker ends even if no 'shutdown'
            arrives (default 100, as before). None means run until 'shutdown'.
    """
    jc = JSONConn(conn)

    # 2-2-1 network with tanh activations; the second MapLayer argument is
    # tanh's derivative, 1 - tanh(d)**2.
    net = Network()
    net.extend(AffineLayer(2,2))
    net.extend(MapLayer(np.tanh, lambda d: 1.0 - np.tanh(d)**2))
    net.extend(AffineLayer(2,1))
    net.extend(MapLayer(np.tanh, lambda d: 1.0 - np.tanh(d)**2))

    # XOR-like task: the target is positive exactly when the two input signs differ.
    training_batch = (np.array([[-0.5, -0.5],
                                [-0.5,  0.5],
                                [ 0.5,  0.5],
                                [ 0.5, -0.5]]),
                      np.array([[-0.5],
                                [ 0.5],
                                [-0.5],
                                [ 0.5]]))

    batch_ctr = 0
    batch_to = 0
    report_state = True
    last_state_report_at_batch = -1  # no report sent yet
    done = False

    steps = 0
    while not done and (max_steps is None or steps < max_steps):
        steps += 1
        txm = dict()

        # Check for new instructions
        while jc.poll():
            rxm = jc.recv()
            print(rxm)
            for k, v in rxm.items():
                if k == 'eta':
                    net.eta = v
                elif k == 'batch to':
                    batch_to = v
                elif k == 'tell state':
                    report_state = True
                elif k == 'shutdown':
                    print(f"got shutdown at batch {batch_ctr}")
                    done = True

        # Report state if asked to, or on every 10th batch we haven't reported yet
        report_state = report_state or (
            batch_ctr % 10 == 0 and last_state_report_at_batch < batch_ctr)

        if report_state:
            txm['eta'] = [batch_ctr, net.eta]
            txm['sv'] = [batch_ctr, list(float(v) for v in net.state_vector())]
            last_state_report_at_batch = batch_ctr
            report_state = False

        # Run a learning step if we aren't at the target number of steps
        if batch_to > batch_ctr:
            loss = net.learn([training_batch])
            batch_ctr += 1
            # float() keeps the payload JSON-serializable in case learn() returns
            # a numpy scalar (mirrors the float() applied to the state vector).
            txm['loss'] = [batch_ctr, float(loss)]
            # `sleep` comes from the `from time import sleep` in the JSONConn cell;
            # `time` itself is only imported by a later cell.
            sleep(0.1)  # Pretend this is a time-consuming calculation
        else:
            sleep(0.05)

        jc.send(txm)

    jc.close()

## The parent

In [None]:
import ipywidgets as widgets
from IPython.display import display
from collections import defaultdict
import rx
from rx import Observable
from rx.subject import Subject
from rx import operators as op
import time

### Setup

In [None]:
if __name__ == '__main__':

    def chew(time_limit):
        """Idle for `time_limit` seconds, printing a '.' so the wait is visible."""
        print('.', end='')
        time.sleep(time_limit)

    # Create the pipe and the worker process; the worker is started in the mainloop cell.
    ipc_pipe = mp.Pipe()
    parent_conn, child_conn = ipc_pipe
    jc = JSONConn(parent_conn)
    p = mp.Process(target=f, args=(child_conn,))

    # Read-outs for the latest batch number and loss.
    # NOTE(review): `max_width` does not look like a FloatText trait; a width limit is
    # normally passed as `layout=widgets.Layout(max_width=...)` -- confirm intent.
    batch_w = widgets.FloatText(value=-1.0, description='Batch:', max_width=6, disabled=False)
    loss_w = widgets.FloatText(value=-1.0, description='Loss:', max_width=6, disabled=False)
    display(batch_w, loss_w)

    # Set once the user has asked for a shutdown; read by the mainloop cell.
    shut_down_child = False

    def shutdown_child(w):
        """Button callback: ask the worker to shut down and record the request."""
        # Without `global`, the assignment below only bound a function local and the
        # module-level flag read by the mainloop never changed.
        global shut_down_child
        print('sending shutdown from shutdown_child')
        jc.send({'shutdown': 'now'})
        shut_down_child = True

    # Set up a button to stop the worker
    shutdown_b_w = widgets.Button(description="Shutdown worker")
    shutdown_b_w.on_click(shutdown_child)
    display(shutdown_b_w)

    def set_w_value(w, val):
        """Set `w.value` to `val` and return the previous value."""
        ov = w.value
        w.value = val
        return ov

    # Process worker messages into topic observables: each message dict is burst
    # into (key, payload) pairs, then split into one stream per key.
    worker_messages_s = Subject()
    burst_messages_s = worker_messages_s.pipe(
        op.flat_map(lambda m: m.items()))

    def topic_s(key):
        """Observable of the payloads of burst (key, payload) pairs whose key is `key`."""
        return burst_messages_s.pipe(
            op.filter(lambda t: t[0] == key),
            op.map(lambda t: t[1]))

    loss_s = topic_s('loss')
    sv_s = topic_s('sv')
    eta_s = topic_s('eta')

    def show_loss(t):
        """Mirror a [batch, loss] payload into the batch and loss widgets."""
        set_w_value(batch_w, t[0])
        set_w_value(loss_w, t[1])

    loss_s.subscribe(show_loss)
    loss_s.pipe(op.take_last(1)).subscribe(print)  # show the last loss when the stream completes

### mainloop

In [None]:
if __name__ == '__main__':

    p.start()
    # The worker now has its own copy of child_conn. Dropping the parent's copy means
    # that if the worker dies without sending the close sentinel, poll() wakes up and
    # recv() raises EOFError, instead of this loop waiting forever.
    child_conn.close()
    jc.send({'batch to': 50})

    shutdown_sent = False
    done = False
    while not done:
        if jc.poll():
            try:
                m = jc.recv()
            except (EOFError, BrokenPipeError):
                # EOFError: the worker sent the close sentinel, or its end of the pipe closed.
                # BrokenPipeError: may surface when JSONConn.recv echoes the sentinel back
                # to a worker end that is already closed.
                worker_messages_s.on_completed()
                print("sender closed")
                done = True
            else:
                # Kept outside the try so errors raised by subscribers are not
                # mistaken for the worker closing.
                worker_messages_s.on_next(m)
        else:
            # Send the shutdown request only once, not on every idle pass.
            if shut_down_child and not shutdown_sent:
                print('sending shutdown')
                try:
                    jc.send({'shutdown': 'now'})
                except BrokenPipeError:
                    pass  # worker already gone; the next poll() picks up its EOF
                shutdown_sent = True
            chew(0.1)

    p.join()
    if p.exitcode:
        print(f"worker exited with code {p.exitcode}")


In [None]:
# NOTE(review): looks like a leftover manual-stop attempt. Cells run one at a time, so
# this only rebinds the module-level `done` after the mainloop cell has finished; it
# does not stop that loop.
done = True

---

In [None]:
# Deliberate barrier: "Run All" stops at this failing cell, so the exploratory cells
# below only run when executed by hand.
assert False, "stop here if entering from above"

## UI using `asyncio`

In [None]:
%gui asyncio

In [None]:
import asyncio
def wait_for_change(widget, value):
    """Return a future resolved with the new value at the next change of `widget`.

    Args:
        widget: an object with traitlets-style `observe`/`unobserve` (e.g. an ipywidget).
        value: the trait name(s) to watch, passed through to `observe`.

    Returns:
        An asyncio.Future whose result is `change.new` for the first change seen.
        The observer is one-shot: it detaches itself on the first change.
    """
    future = asyncio.Future()

    def getvalue(change):
        # Detach first, so the one-shot observer never lingers. If the awaiting task
        # was cancelled, the future is already done and set_result would raise
        # InvalidStateError; skip it in that case.
        widget.unobserve(getvalue, value)
        if not future.done():
            future.set_result(change.new)

    widget.observe(getvalue, value)
    return future

In [None]:
from ipywidgets import IntSlider, Output
# A slider that drives the async demo, and the Output widget the demo logs into.
slider, out = IntSlider(), Output()

In [None]:
# Named watch_slider rather than `f` so it no longer shadows the multiprocessing
# worker `f`, which the setup cell passes to mp.Process.
async def watch_slider():
    """Log ten rounds to `out`, each time awaiting the next change of `slider.value`."""
    for i in range(10):
        out.append_stdout(f'did work {i}\n')
        x = await wait_for_change(slider, 'value')
        out.append_stdout(f'async function continued with value {x}\n')

async def g():
    """Clear the `out` Output widget."""
    out.clear_output()

# Keep references to the tasks: the event loop holds only weak references to them,
# so an unreferenced task can be garbage-collected before it finishes.
clear_task = asyncio.create_task(g())
watch_task = asyncio.create_task(watch_slider())

slider

In [None]:
# Display the Output widget that the async slider demo writes into.
out

In [None]:
# Reset the demo's log.
out.clear_output()

In [None]:
# Current slider position.
slider.value

In [None]:
import threading
from IPython.display import display
import ipywidgets as widgets
import time
progress = widgets.FloatProgress(value=0.0, min=0.0, max=1.0)

def work(progress):
    """Fill `progress` from 0 to 1 in 100 equal steps, one step every 0.2 s."""
    n_steps = 100
    for step in range(1, n_steps + 1):
        time.sleep(0.2)
        progress.value = step / n_steps

# Run the loop on a background thread so this cell returns while the bar fills.
thread = threading.Thread(target=work, args=(progress,))
display(progress)
thread.start()

In [None]:
# Scratch: a value for the next cell to use.
i=2

In [None]:
# Scratch: depends on `i` from the previous cell.
i+1

In [None]:
import multiprocessing
# Number of CPUs available for worker processes.
multiprocessing.cpu_count()

In [None]:
# A bordered Output widget.
# NOTE(review): this rebinds the global `out` that the async slider demo looks up on
# each `out.append_stdout` call; if that task is still waiting, its later messages
# land in this widget instead.
out = widgets.Output(layout={'border': '1px solid black'})
out

In [None]:
# Inside the `with` block, printed output goes to the `out` widget rather than the cell.
with out:
    print("yo")