# Neural Net observatory

In [1]:
%load_ext autoreload
%autoreload 2
#%matplotlib inline
%matplotlib widget

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
rcParams['figure.max_open_warning'] = 0
import ipywidgets as widgets

Fetch our tools:

In [3]:
from nn import Network, Layer, IdentityLayer, AffineLayer, MapLayer
from nnbench import NNBench
from nnvis import NNVis

Interactive widgets come from [`ipywidgets`](https://ipywidgets.readthedocs.io/en/latest/index.html).

# Multiprocessing
We run the net training in a child process, so that it can proceed while we observe and analyze partial results.

### Tooling
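* `JSONConn` — wraps one end of a `multiprocessing` pipe so that messages travel as UTF-8 encoded JSON; a frame of four zero bytes marks the end of the stream.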

In [4]:
from multiprocessing import Process, Pipe
import json
from time import sleep

class JSONConn():
    """Exchange JSON-encodable messages over a ``multiprocessing`` connection.

    Every message is serialized with ``json.dumps``, UTF-8 encoded and sent as
    a single byte frame. A frame consisting of four zero bytes is reserved as
    the end-of-stream marker: ``close`` sends it, and ``recv`` converts it
    into ``EOFError``.
    """

    # End-of-stream marker. json.dumps (with its default ensure_ascii=True)
    # escapes control characters, so a real message can never equal it.
    _EOS = bytes(4)

    def __init__(self, conn):
        """Wrap ``conn``, an object offering ``send_bytes``, ``recv_bytes``,
        ``poll`` and ``close`` (e.g. one end of ``multiprocessing.Pipe()``)."""
        self.conn = conn

    def send(self, v):
        """Send the JSON-encodable value ``v`` as one frame."""
        self.conn.send_bytes(json.dumps(v).encode('utf8'))

    def poll(self, timeout=0.0):
        """Return whether a frame is ready to be read.

        ``timeout`` is forwarded to the underlying connection; the default of
        0.0 returns immediately, as before.
        """
        return self.conn.poll(timeout)

    def recv(self):
        """Block until a frame arrives and return the decoded message.

        Raises:
            EOFError: the peer sent the end-of-stream marker (our end is
                closed before raising), or the underlying connection itself
                reported end-of-file.
        """
        r = self.conn.recv_bytes()
        if r == self._EOS:
            # The peer is shutting down. Only release our end: echoing the
            # marker back would write to a peer that may already be gone and
            # replace the expected EOFError with a BrokenPipeError.
            self.conn.close()
            raise EOFError
        return json.loads(r)

    def close(self):
        """Tell the peer we are done, then close our end.

        Notifying the peer is best-effort: if it has already gone away, or
        this connection is already closed, the send error is ignored so that
        ``close`` can be called safely more than once.
        """
        try:
            self.conn.send_bytes(self._EOS)
        except OSError:
            # Peer gone (BrokenPipeError) or handle already closed.
            pass
        finally:
            self.conn.close()

### The child

In [71]:
def f(conn, max_iterations=100, poll_interval=0.05):
    """Child-process entry point: train a small net while obeying the parent.

    Builds a network (affine 2,2 -> tanh -> affine 2,1 -> tanh) and trains it
    on a fixed four-point XOR-style batch. On every loop iteration it:

    1. drains pending instruction messages from the parent,
    2. decides whether to report the network state,
    3. runs one learning step if the target batch count is not yet reached,
    4. sends one reply message (possibly an empty dict) and sleeps.

    Instruction keys understood (other keys are ignored):
        'eta'        -- value is assigned to ``net.eta``
        'batch to'   -- batch count to train up to
        'tell state' -- request a state report on this iteration
        'shutdown'   -- leave the loop after this iteration's reply

    Reply keys produced:
        'eta'  -- [batch_ctr, net.eta]
        'sv'   -- [batch_ctr, state vector as a list of floats]
        'loss' -- [batch_ctr, loss returned by the learning step]

    Args:
        conn: this process's end of the pipe; wrapped in ``JSONConn``.
        max_iterations: upper bound on loop iterations before closing.
        poll_interval: seconds to sleep between iterations.
    """
    jc = JSONConn(conn)

    net = Network()
    net.extend(AffineLayer(2,2))
    # NOTE(review): second MapLayer argument looks like the derivative of tanh — confirm against nn.MapLayer.
    net.extend(MapLayer(np.tanh, lambda d: 1.0 - np.tanh(d)**2))
    net.extend(AffineLayer(2,1))
    net.extend(MapLayer(np.tanh, lambda d: 1.0 - np.tanh(d)**2))

    # (inputs, targets): the target is +0.5 where the input signs differ, else -0.5.
    training_batch = (np.array([[-0.5, -0.5],
                                [-0.5,  0.5],
                                [ 0.5,  0.5],
                                [ 0.5, -0.5]]),
                      np.array([[-0.5],
                                [ 0.5],
                                [-0.5],
                                [ 0.5]]))

    batch_ctr = 0
    batch_to = 0
    report_state = True
    done = False
    # -1 means "no report sent yet". Initialised explicitly so the periodic
    # report test below never reads an unbound name.
    last_state_report_at_batch = -1

    for _ in range(max_iterations):
        txm = dict()

        # Check for new instructions
        while jc.poll():
            rxm = jc.recv()
            print(rxm)
            for k, v in rxm.items():
                if k == 'eta':
                    net.eta = v
                elif k == 'batch to':
                    batch_to = v
                elif k == 'tell state':
                    report_state = True
                elif k == 'shutdown':
                    done = True

        # Report state if asked to, or on every 4th batch that has not been
        # reported yet (the second test prevents repeats while idle).
        report_state = report_state or (
            batch_ctr % 4 == 0 and last_state_report_at_batch < batch_ctr)

        if report_state:
            txm['eta'] = [batch_ctr, net.eta]
            txm['sv'] = [batch_ctr, [float(v) for v in net.state_vector()]]
            last_state_report_at_batch = batch_ctr
            report_state = False

        # Run a learning step if we aren't at the target number of steps
        if batch_to > batch_ctr:
            loss = net.learn([training_batch])
            batch_ctr += 1
            txm['loss'] = [batch_ctr, loss]

        jc.send(txm)
        if done:
            break
        # `sleep` is the name imported alongside JSONConn; the `time` module
        # itself is only imported by a later cell.
        sleep(poll_interval)

    jc.close()

### The parent initiates

In [72]:
from collections import defaultdict
import time

if __name__ == '__main__':
    # Create the pipe, keep one end here and hand the other to the child.
    pipe = Pipe()
    parent_conn, child_conn = pipe
    jc = JSONConn(parent_conn)
    p = Process(target=f, args=(child_conn,))
    p.start()

    # First instruction to the child: the batch count to train up to.
    jc.send({'batch to': 100})

    # Collect every reported value under its message key, in arrival order.
    rxen = defaultdict(list)
    for i in range(50):
        if not jc.poll():
            # Nothing waiting: show a heartbeat dot and back off briefly.
            print('.', end='')
            sleep(0.1)
            continue

        try:
            m = jc.recv()
        except EOFError:
            print("sender closed")
            break

        print(m)
        for k, v in m.items():
            rxen[k].append(v)

    # Wait for the child process to finish.
    p.join()


{'batch to': 100}
.{'eta': [0, 0.1], 'sv': [0, [-0.8774924133771402, 0.4787428099731336, -0.6726777769500891, 0.01966985240848951, -1.512468940853395, 0.5045606413931971, 0.9538708069653601, 0.3654533623137645, -1.0654674255643688]], 'loss': [1, 0.5668681270027477]}
{'loss': [2, 0.5643455732976648]}
.{'loss': [3, 0.5616103615914585]}
{'loss': [4, 0.5586361418396362]}
.{'eta': [4, 0.1], 'sv': [4, [-0.8796500121043779, 0.4775160466569647, -0.6742816156933711, 0.019625789451851547, -1.5026206490900402, 0.5176864478815685, 0.9135277033968342, 0.38645033419594127, -1.0192364819909985]], 'loss': [5, 0.5553923045060556]}
{'loss': [6, 0.5518431468944113]}
.{'loss': [7, 0.5479468543571303]}
{'loss': [8, 0.543654253640248]}
.{'eta': [8, 0.1], 'sv': [8, [-0.8820272005794964, 0.4760288115574221, -0.6760338080614506, 0.019640346535056225, -1.4913785585219608, 0.5340534284326937, 0.8654654785233153, 0.41218283133246697, -0.9641271713283246]], 'loss': [9, 0.5389072874074162]}
{'loss': [10, 0.53363715

In [73]:
rxen

defaultdict(list,
            {'eta': [[0, 0.1],
              [4, 0.1],
              [8, 0.1],
              [12, 0.1],
              [16, 0.1],
              [20, 0.1],
              [24, 0.1],
              [28, 0.1],
              [32, 0.1]],
             'sv': [[0,
               [-0.8774924133771402,
                0.4787428099731336,
                -0.6726777769500891,
                0.01966985240848951,
                -1.512468940853395,
                0.5045606413931971,
                0.9538708069653601,
                0.3654533623137645,
                -1.0654674255643688]],
              [4,
               [-0.8796500121043779,
                0.4775160466569647,
                -0.6742816156933711,
                0.019625789451851547,
                -1.5026206490900402,
                0.5176864478815685,
                0.9135277033968342,
                0.38645033419594127,
                -1.0192364819909985]],
              [8,
               [-0.8820272005