# Neural Net observatory
We run the net training in a child process, so that it can proceed while we observe and analyze partial results.

In [1]:
%load_ext autoreload
%autoreload 2
#%matplotlib inline
%matplotlib widget

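FIXME: clean up imports (`ipywidgets` is imported twice, and some names, e.g. `Observable`, `defaultdict`, `plt`, look unused).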

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['figure.max_open_warning'] = 0

In [3]:
import ipywidgets as widgets
import json
import multiprocessing as mp
import threading
import time
import trio

In [4]:
import ipywidgets as widgets
from IPython.display import display
from collections import defaultdict
import rx
from rx import Observable
from rx.subject import Subject
from rx import operators as op

Fetch our tools:

In [5]:
from nn import Network, Layer, IdentityLayer, AffineLayer, MapLayer
#from nnbench import NNBench
#from nnvis import NNVis

The UI is built with [`ipywidgets`](https://ipywidgets.readthedocs.io/en/latest/index.html).

# A global (within the parent) state object

In [6]:
class Thing:
    """Minimal attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)


# The parent's shared state object; cells below hang their flags and handles on it.
g = Thing()

# Tooling
 * `JSONConn`: JSON-encoded messages over the process `Pipe`.
   We weren't seeing an exception from `recv()` on a closed connection, so closing is signalled in-band by a non-JSON message of four zero bytes.

In [7]:
class JSONConn():
    """Send and receive JSON-encodable values over a `multiprocessing` connection.

    Each value travels as one UTF-8 JSON message (`send_bytes` / `recv_bytes`).
    Closing is signalled in-band with a sentinel of four zero bytes, which is
    never valid JSON: `close()` sends it before closing, and `recv()` answers a
    received sentinel by closing this end as well and raising `EOFError`.
    """

    # In-band "I'm closing" marker; four NUL bytes can never be a JSON document.
    _CLOSE_SENTINEL = bytes(4)

    def __init__(self, conn):
        self.conn = conn

    def send(self, v):
        """Send `v`, which must be JSON-serializable."""
        self.conn.send_bytes(json.dumps(v).encode('utf8'))

    def poll(self):
        """Return whether there is data (a message or the close sentinel) to read."""
        return self.conn.poll()

    def recv(self):
        """Receive and decode the next value.

        Raises:
            EOFError: the peer sent the close sentinel (this end is then closed
                too), or the underlying connection reached end-of-file.
        """
        r = self.conn.recv_bytes()
        if r == self._CLOSE_SENTINEL:
            self.close()
            raise EOFError
        return json.loads(r)

    def close(self):
        """Tell the peer we're closing, then close the connection.

        Safe to call more than once. The farewell sentinel is best effort: if
        the peer has already gone away, the send fails with an `OSError` (e.g.
        `BrokenPipeError`). That error is ignored so the connection still gets
        closed and `recv()` still reports `EOFError`.
        """
        if self.conn.closed:
            return
        try:
            self.conn.send_bytes(self._CLOSE_SENTINEL)
        except OSError:
            pass  # peer already gone; nothing to notify
        finally:
            self.conn.close()

In [8]:
class FileSystemConn():
    """Write-only, `JSONConn`-like sink that appends each value to a file as one JSON line.

    The file is opened (and truncated) when the object is created. Only `send`
    and `close` are supported; `poll` and `recv` raise `NotImplementedError`.
    """

    def __init__(self, fname):
        self.fname = fname
        self.outf = open(fname, 'wb')

    def send(self, v):
        """Append `v` (JSON-serializable) to the file, followed by a newline."""
        self.outf.write((json.dumps(v) + '\n').encode('utf8'))

    def poll(self):
        raise NotImplementedError("FileSystemConn is write-only")

    def recv(self):
        raise NotImplementedError("FileSystemConn is write-only")

    def close(self):
        """Flush and close the output file. Safe to call more than once."""
        self.outf.close()

# The Machine Learning compute process
* Contain the ML model
* Run in a separate O/S process for isolation
* Communicate over `multiprocessing.Pipe` via messages each of which is a JSON-encoded dictionary

In [9]:
def f(conn, fname):
    """Body of the ML compute process: train a small net and stream reports to the parent.

    Builds a network of two affine layers, each followed by tanh, and trains it
    repeatedly on a fixed four-sample batch. JSON messages arriving on `conn`
    steer the training:

    * ``'eta'`` - set ``net.eta``
    * ``'batch to'`` - train until this many batches have been run
    * ``'tell state'`` - send a state report right away
    * ``'report state every'`` - also report state every N batches (0 disables)
    * ``'shutdown'`` - finish the current pass and exit

    Outgoing messages are dicts whose values are ``[batch_ctr, value]`` pairs
    under the keys ``'eta'``, ``'sv'`` (state vector) and ``'loss'``. Each one
    is also written as a JSON line to the file `fname`.

    Both connections are closed on exit. This also happens if training raises
    or the parent closes its end of the pipe first.
    """
    jc = JSONConn(conn)
    fc = FileSystemConn(fname)

    def dtanh(d):
        # d/dx tanh(x), paired with np.tanh in each MapLayer
        return 1.0 - np.tanh(d)**2

    net = Network()
    net.extend(AffineLayer(2,2))
    net.extend(MapLayer(np.tanh, dtanh))
    net.extend(AffineLayer(2,1))
    net.extend(MapLayer(np.tanh, dtanh))

    # XOR-style task: target is -0.5 where the two inputs agree in sign, 0.5 where they differ
    training_batch = (np.array([[-0.5, -0.5],
                                [-0.5,  0.5],
                                [ 0.5,  0.5],
                                [ 0.5, -0.5]]),
                      np.array([[-0.5],
                                [ 0.5],
                                [-0.5],
                                [ 0.5]]))

    # Initialize a lot of variables that control the state machine
    batch_ctr = 0                    # batches trained so far
    batch_to = 0                     # train until batch_ctr reaches this
    report_state = True              # send a state report on this pass
    report_state_interval = 0        # periodic state reports every N batches; 0 = off
    last_state_report_was_at_batch = 0
    last_reported_loss = None
    loss = -1                        # placeholder until the first learning step
    last_loss_report_time = 0
    loss_report_min_interval = 0.05  # seconds: rate limit for changed-loss reports
    loss_report_max_interval = 1     # seconds: heartbeat even when the loss is unchanged
    done = False

    try:
        while not done:
            txm = dict()

            # Check for new instructions
            while jc.poll():
                try:
                    rxm = jc.recv()
                except (EOFError, BrokenPipeError):
                    # The parent closed its end of the pipe; nobody is left to report to.
                    print(f"compute process lost its connection at batch {batch_ctr}")
                    return
                print(f"compute process got {rxm}")
                for k,v in rxm.items():
                    if k == 'eta':
                        net.eta = v
                    elif k == 'batch to':
                        batch_to = v
                    elif k == 'tell state':
                        report_state = True
                    elif k == 'report state every':
                        report_state_interval = v
                    elif k == 'shutdown':
                        print(f"compute process got shutdown at batch {batch_ctr}")
                        done = True

            # Report states if asked to, or if it's the right batch phase
            # (and we haven't already reported at this batch)
            report_state = report_state or (
                report_state_interval > 0
                and batch_ctr % report_state_interval == 0
                and last_state_report_was_at_batch < batch_ctr)

            if report_state:
                txm['eta'] = [batch_ctr, net.eta]
                txm['sv'] = [batch_ctr, [float(v) for v in net.state_vector()]]
                last_state_report_was_at_batch = batch_ctr
                report_state = False

            # Run a learning step if we aren't at the target number of steps
            if batch_to > batch_ctr:
                loss = net.learn([training_batch])
                batch_ctr += 1
                # Report the loss when we reach the target number of batches
                if batch_ctr == batch_to:
                    txm['loss'] = [batch_ctr, loss]

            # Report the loss, with rate limiting, if it's changed since last report
            if loss != last_reported_loss \
                    and time.time() - last_loss_report_time > loss_report_min_interval:
                txm['loss'] = [batch_ctr, loss]
                last_reported_loss = loss
                last_loss_report_time = time.time()

            # Report the loss when it's been too long, even if unchanged,
            # as a sort of heartbeat
            if time.time() - last_loss_report_time > loss_report_max_interval:
                txm['loss'] = [batch_ctr, loss]
                last_reported_loss = loss
                last_loss_report_time = time.time()

            if txm:
                jc.send(txm)
                fc.send(txm)
            elif batch_ctr >= batch_to:
                # Idle: nothing to send and no training to do
                time.sleep(0.1)
    finally:
        # Always release both connections, even if training raised.
        fc.close()
        if not conn.closed:
            try:
                jc.close()
            except OSError:
                # Best effort: the parent may already be gone; just drop our end.
                conn.close()

# The parent
* Do work in a background thread, as `trio` coroutines
    * Burst and route messages received from the ML compute process, using ReactiveX
    * Interact with the UI widgets
* Leave the foreground free for the notebook and its widgets

## Set up the compute process

In [10]:
# Start method: fork, requested explicitly.
# The platform default is not fork everywhere (macOS uses 'spawn'; POSIX uses
# 'forkserver' from Python 3.14). Both 'spawn' and 'forkserver' failed here with
#   AttributeError: Can't get attribute 'f' on <module '__main__' (built-in)>
# presumably because those start methods have to look `f` up in the child's
# `__main__`, where a notebook-defined function isn't available.
# The attempted alternative, for reference:
#   ctx = mp.get_context('forkserver')
#   parent_conn, child_conn = ctx.Pipe()
#   jc = JSONConn(parent_conn)
#   p = ctx.Process(target=f, args=(child_conn, 't1.jsons'))
mp_ctx = mp.get_context('fork')

# Create the compute process, but don't start it yet.
# Forking could be trouble for trio, viz. https://github.com/python-trio/trio/issues/1614
# so the fork (`p.start()`) must happen before any `trio.run`, i.e. before the
# background thread is started.
ipc_pipe = mp_ctx.Pipe()
parent_conn, child_conn = ipc_pipe
jc = JSONConn(parent_conn)
p = mp_ctx.Process(target=f, args=(child_conn, 't1.jsons'))

## Build UI widgets

In [11]:
# Set up some visibility widgets
batch_w = widgets.FloatText(value=-1.0, description='Batch:', max_width=6, disabled=True)
loss_w = widgets.FloatText(value=-1.0, description='Loss:', max_width=6, disabled=True)

# Set up control widgets
# Button to stop the worker
shutdown_b_w = widgets.Button(description="Shutdown worker")

# Input field to modify batch_to, and button to submit it
batch_to_w = widgets.IntText(value=50, description='target batches:')
batch_to_b_w = widgets.Button(description="submit target batches")

ui_widgets = (batch_w, loss_w, batch_to_w, batch_to_b_w, shutdown_b_w)

## ... and widget behaviors

In [12]:
def set_w_value(w, val):
    """Assign `val` to widget `w` and hand back the value it replaces."""
    previous, w.value = w.value, val
    return previous


g.shut_down_child = False


def shutdown_child(w):
    """Button callback: ask the compute process to stop, and record that we asked."""
    request = {'shutdown': 'now'}
    jc.send(request)
    g.shut_down_child = True


def submit_target_batches(w):
    """Button callback: forward the target-batches field to the compute process."""
    request = {'batch to': batch_to_w.value}
    jc.send(request)


# Wire the buttons to their callbacks
shutdown_b_w.on_click(shutdown_child)
batch_to_b_w.on_click(submit_target_batches)

## Build the ReactiveX pipeline

In [13]:
# Set up pipeline to process worker messages into topic observables
compute_worker_messages_s = rx.subject.Subject()
burst_messages_s = compute_worker_messages_s.pipe(
    op.flat_map(lambda m: m.items()))
loss_s = burst_messages_s.pipe(
    op.filter(lambda t: t[0] == 'loss'),
    op.map(lambda t: t[1]))
sv_s = burst_messages_s.pipe(
    op.filter(lambda t: t[0] == 'sv'),
    op.map(lambda t: t[1]))
eta_s = burst_messages_s.pipe(
    op.filter(lambda t: t[0] == 'eta'),
    op.map(lambda t: t[1]))

loss_s.subscribe(lambda t: set_w_value(batch_w, t[0]) + set_w_value(loss_w, t[1]))
loss_s.pipe(op.take_last(1)).subscribe(print) # show the last loss

<rx.disposable.disposable.Disposable at 0x7f5d5979d520>

# Background thread

## The background thread runs a `trio` event loop with two tasks

In [14]:
def thread_work(g, jc, s):
    """Background-thread entry point: run a `trio` event loop with two tasks.

    * `receive_from_compute_process` polls `jc` and pushes every decoded
      message into the ReactiveX subject `s`. When the connection ends (close
      sentinel, broken pipe, or another `OSError`), it completes `s`, prints
      why, sets `g.done = True`, and stops polling.
    * `worker` repeatedly awaits `g.threadwork()` and sleeps for the number of
      seconds it returns.

    Both tasks loop until `g.stop_requested` is true. `trio.run` returns once
    both have finished, after which `g.done` is set to True.

    Args:
        g: shared state object (reads `stop_requested`, `threadwork`; writes `done`).
        jc: a `JSONConn` connected to the compute process.
        s: the subject that fans messages out to the pipeline.
    """

    async def receive_from_compute_process(g, jc, s):
        while not g.stop_requested:
            try:
                if not jc.poll():
                    await trio.sleep(0.01)
                    continue
                m = jc.recv()
            except EOFError:
                reason = "sender closed"
            except BrokenPipeError:
                reason = "broken pipe"
            except OSError as e:
                reason = e
            else:
                s.on_next(m)
                # Checkpoint: without it, a steady stream of messages would
                # never yield to trio and would starve the `worker` task.
                await trio.sleep(0)
                continue
            # The connection is finished. Complete the subject (so e.g. a
            # `take_last` subscriber can emit), report why, and stop: polling
            # a closed connection would only raise "handle is closed".
            s.on_completed()
            print(reason)
            g.done = True
            return

    async def worker(g):
        while not g.stop_requested:
            await trio.sleep(await g.threadwork())

    async def threadloop(g, jc, s):
        async with trio.open_nursery() as nursery:
            nursery.start_soon(receive_from_compute_process, g, jc, s)
            nursery.start_soon(worker, g)

    trio.run(threadloop, g, jc, s)
    g.done = True

## Set up an idle job and the thread
Don't start the thread just yet; it is started right after the compute process, below.

In [15]:
async def just_one():
    """Idle job for the background `worker` task: count one tick, then return 1.

    The worker sleeps for the returned number of seconds before calling it again.
    """
    g.work_ctr = g.work_ctr + 1
    return 1


# Background-thread bookkeeping. The thread is created here but started later.
g.work_ctr = 0
g.threadwork = just_one
g.stop_requested = False
g.thread = threading.Thread(
    target=thread_work,
    args=(g, jc, compute_worker_messages_s),
)

# Start backgrounds

In [16]:
# Show the controls first, so they're in place as soon as reports arrive.
display(*ui_widgets)

# Fork the compute process before the trio thread exists,
# then start the thread that relays its messages into the pipeline.
p.start()
g.thread.start()


FloatText(value=-1.0, description='Batch:', disabled=True)

FloatText(value=-1.0, description='Loss:', disabled=True)

IntText(value=50, description='target batches:')

Button(description='submit target batches', style=ButtonStyle())

Button(description='Shutdown worker', style=ButtonStyle())

compute process got {'batch to': 50}
compute process got {'batch to': 5000}
compute process got {'batch to': 50000}


In [17]:
# Deliberate stop: halts "Run All" here so the teardown cells below don't run before you've used the UI.
assert False, "You can use the UI now"

AssertionError: You can use the UI now

In [None]:
# Stop the background thread. Both of its trio tasks loop `while not g.stop_requested`,
# and nothing earlier sets the flag to True, so without this the thread never finishes.
g.stop_requested = True
g.stop_requested

In [None]:
# Wait for the compute process to exit. Its main loop normally ends only after it receives a
# 'shutdown' message (e.g. from the "Shutdown worker" button), so this blocks until then.
p.join()

In [None]:
# Non-blocking check: shows True while the background thread is still running.
# Once it has exited, join() returns None immediately (reaping it) and nothing is displayed.
g.thread.is_alive() or g.thread.join()

---

# Scratch

In [None]:
# Deliberate stop: keeps "Run All" from falling through into the scratch cells below.
assert False, "stop here if entering from above"

In [None]:
%autoawait trio