# Graph thrash `neo4j`
Train a net, recording its path in a `neo4j` graph database.

# Preliminaries

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from pprint import pprint
import neo4j

In [3]:
import hashlib
import os
import time
from hashlib import sha256
from pprint import pprint

## Connecting

The `gpu-jupyter` and `neo4j` Docker containers need to be on the same Docker network. If they were started bare, something like this connects them:

    docker network connect gpu-jupyter_default gpu-jupyter 
    docker network connect gpu-jupyter_default neo4j
    docker network inspect gpu-jupyter_default 
    
Docker offers better ways to do this, e.g. declaring both services on a shared network in a single Compose file.

In [4]:
# Connection settings are read from the environment so credentials are not
# committed with the notebook. The fallbacks are the original hard-coded
# values (presumably the local neo4j container on gpu-jupyter_default).
NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://172.19.0.2:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "test")
driver = neo4j.GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

In [5]:
# Check that the driver can reach the server before doing any work. The cell
# displays the call's return value (routing information, per the recorded output).
# NOTE(review): the output also echoes this source line, which looks like a
# driver warning about this call — confirm against the installed driver version.
driver.verify_connectivity()

  driver.verify_connectivity()


{IPv4Address(('172.19.0.2', 7687)): [{'ttl': 300,
   'servers': [{'addresses': ['172.19.0.2:7687'], 'role': 'WRITE'},
    {'addresses': ['172.19.0.2:7687'], 'role': 'READ'},
    {'addresses': ['172.19.0.2:7687'], 'role': 'ROUTE'}]}]}

# Utility functions

## `numpy` array store

In [23]:
from graph_utils_neo4j import NumpyStore

# The Model
    Investigation -> Experiment -> multiple ResultDAGs
A `ResultDAG` looks like

    (netState, params)-[mutation]->(netState, params)-[mutation ...
                     +-[mutation]->(netState, params) ...

and so on, where a `mutation` can be a learning trajectory or an edit.

Perhaps `mutation` can be expressed in python.

Experiment results should generally be reproducible, but they won't always be — for example, when an experiment draws on an outside source of entropy (randomness).

## Some neural nets

In [7]:
from nn import Network, Layer, IdentityLayer, AffineLayer, MapLayer
from nnbench import NetMaker, NNMEG

In [8]:
mnm = NetMaker(NNMEG)
# Build both example nets from NetMaker shorthand specs (XOR first, then ADC).
xor_net, adc_net = (mnm(spec) for spec in ('2x2tx1t', '1x8tx8tx3t'))
# Alternative, deeper ADC net:
#adc_net = mnm('1x8tx8tx3tx3t')

## ... and training data

In [9]:
# (inputs, targets): the target is +0.5 exactly when the two inputs differ in sign.
xor_training_batch = (
    np.array([[-0.5, -0.5], [-0.5, 0.5], [0.5, 0.5], [0.5, -0.5]]),
    np.array([[-0.5], [0.5], [-0.5], [0.5]]),
)

In [10]:
def adc(input):
    """Quantize a value to a 3-bit code, with each bit mapped to -1 or +1.

    ``input`` may be a scalar or a one-element array (e.g. one row of a column
    vector). It is scaled by 8, truncated, and clamped to 0..7. The three bits
    of that integer (most significant first) are returned as an int array,
    with 0 -> -1 and 1 -> +1.
    """
    # .item() pulls out the single value explicitly; calling int() on an
    # array with ndim > 0 is deprecated since NumPy 1.25.
    value = np.asarray(input).item()
    m = max(0, min(7, int(8 * value)))
    return np.array([(m >> 2) & 1, (m >> 1) & 1, m & 1]) * 2 - 1


def vadc(v):
    """Apply ``adc`` to each element (row) of ``v``; returns shape (len(v), 3)."""
    return np.array([adc(p) for p in v])

#plot_ADC(vadc)  # needs a plot_ADC helper to be defined/imported first

In [11]:
# 64 evenly spaced inputs on [0, 1) as a column vector: 8 per 3-bit output code.
x = (np.arange(8 * 8) / (8 * 8)).reshape(-1, 1)
adc_training_batch = (x, vadc(x))

### We use `adc_net`

In [12]:
# Select the ADC problem: the net to train and the batch it learns from.
net, training_batch = adc_net, adc_training_batch

# The graph database

## Utilities

In [13]:
# Array store built on the same neo4j driver; its .store(...) return value is
# used below as the `ksv` key that identifies a net state in the graph.
nps = NumpyStore(driver)

In [14]:
def add_start(tx, facts, net):
    """Record the starting state of a training run as a ``:net`` node.

    tx    -- transaction handle, passed in by ``session.write_transaction``.
    facts -- dict supplying the Cypher parameters ``shorthand``, ``ksv``,
             ``loss``, ``ts``, ``experiment`` and ``head``.
    net   -- unused; kept so the call signature matches ``add_subsequent``.
    """
    # ``ts`` is part of the merged pattern, so unless a node with the very same
    # timestamp already exists this MERGE creates a new node.
    # ``facts`` is passed as the parameters mapping rather than splatted into
    # keyword arguments, so a fact named e.g. ``query`` cannot collide with
    # tx.run's own arguments.
    tx.run("MERGE (:net "
           "{shorthand: $shorthand, "
           "ksv: $ksv, "
           "loss: $loss, "
           "ts: $ts, "
           "experiment: $experiment, "
           "head: $head}) ",
           facts)

In [15]:
def add_subsequent(tx, facts, net):
    """Record one training segment as a ``LEARNED`` edge to a new ``:net`` node.

    tx    -- transaction handle, passed in by ``session.write_transaction``.
    facts -- dict of Cypher parameters:
             * ``prior_ksv``: ``ksv`` of the existing node to extend from;
             * edge properties: ``batch_points``, ``etas``,
               ``eta_change_batches``, ``batches_this_segment``, ``loss``,
               ``loss_step``, ``traj_L2_sq``, ``traj_cos_sq_signed``, ``ts``
               (several are stored under pluralised names, e.g. ``losses``);
             * new-node properties: ``shorthand``, ``ksv``, ``end_loss``
               (stored as ``loss``), ``ts``, ``experiment``.
             Any extra keys are sent along as unused parameters.
    net   -- unused; kept so the call signature matches ``add_start``.

    If no node has ``ksv = $prior_ksv`` the MATCH yields no rows and nothing
    is written; if several do, the MERGE runs once per matching node.
    """
    # ``facts`` is passed as the parameters mapping rather than splatted into
    # keyword arguments, so a fact named e.g. ``query`` cannot collide with
    # tx.run's own arguments.
    tx.run("MATCH (a:net {ksv: $prior_ksv}) "
           "MERGE (a)-"
           "[:LEARNED "
               "{batch_points: $batch_points, "
               "etas: $etas, "
               "eta_change_batches: $eta_change_batches, "
               "batches_this_segment: $batches_this_segment, "
               "losses: $loss, "
               "loss_steps: $loss_step, "
               "traj_L2_sqs: $traj_L2_sq, "
               "traj_cos_sq_signeds: $traj_cos_sq_signed, "
               "ts: $ts "
               "}]->"
           "(b:net "
               "{shorthand: $shorthand, "
               "ksv: $ksv, "
               "loss: $end_loss, "
               "ts: $ts, "
               "experiment: $experiment}) ",
           facts)

## An example experiment's DAG

### We use `adc_net`

In [16]:
# Re-select the ADC problem so this section stands on its own.
net, training_batch = adc_net, adc_training_batch

In [17]:
net.eta = 0.1

## Train, recording trajectory

In [18]:
def trainer(net, batch=None, store=None, target_loss=1e-3, experiment='ADC'):
    """Train ``net`` segment by segment, yielding one property dict per segment.

    Segments are run until, at a segment boundary, the loss is no longer above
    ``target_loss``. A segment ends when the loss has dropped to at most
    0.7071 (about 1/sqrt(2)) times its value at the segment start, or once
    100 samples of ``net.deltas()`` have been taken in the segment.

    Args:
        net: the net being trained; uses its ``losses``, ``learn``,
            ``deltas``, ``eta``, ``shorthand`` and ``state_vector`` members.
        batch: training batch; defaults to the notebook-global
            ``training_batch``.
        store: object whose ``store(vector)`` returns a key for the net's
            state vector; defaults to the notebook-global ``nps``.
        target_loss: keep training while the loss exceeds this value.
        experiment: value written to the ``experiment`` property.

    Yields:
        dict containing one list per field of ``net.deltas()`` (its values at
        the sampled batches), plus ``batch_points`` (batch index of each
        sample), ``etas`` and ``eta_change_batches`` (each learning rate and
        the batch index at which it took effect), ``batches_this_segment``,
        ``ts``, ``shorthand``, ``ksv``, ``end_loss`` and ``experiment``.
    """
    if batch is None:
        batch = training_batch
    if store is None:
        store = nps
    loss = net.losses([batch])[0]
    batch_ctr = 0  # index of the next batch to learn
    while loss > target_loss:
        seg_start = batch_ctr
        etas = []    # [batch index, eta], appended whenever eta changes
        deltas = []  # [batch index, net.deltas()] samples
        prior_loss = loss
        while loss / prior_loss > 0.7071 and len(deltas) < 100:
            if not etas or net.eta != etas[-1][1]:
                etas.append([batch_ctr, net.eta])
            loss = net.learn([batch])
            # Sample every batch early on, then only every 100th batch.
            if batch_ctr < 100 or batch_ctr % 100 == 0:
                deltas.append([batch_ctr, net.deltas()])
            batch_ctr += 1
        # Ensure the segment's final batch is sampled, labelled with its own
        # index like every other sample.
        last_batch = batch_ctr - 1
        if not deltas or deltas[-1][0] < last_batch:
            deltas.append([last_batch, net.deltas()])

        samples = [d for _, d in deltas]
        # Transpose into one list per field; lists rather than tuples
        # (RedisGraph has a tuple allergy).
        properties = {field: [getattr(s, field) for s in samples]
                      for field in samples[0]._fields}
        properties['batch_points'] = [b for b, _ in deltas]
        # Each etas entry is [batch index, eta].
        properties['etas'] = [eta for _, eta in etas]
        properties['eta_change_batches'] = [b for b, _ in etas]
        properties['batches_this_segment'] = batch_ctr - seg_start
        properties['ts'] = time.time()
        properties['shorthand'] = net.shorthand
        properties['ksv'] = store.store(net.state_vector())
        properties['end_loss'] = net.losses([batch])[0]
        properties['experiment'] = experiment
        yield properties

In [19]:
# Properties of the root node: the net's state before the training loop runs.
starting_facts = dict(
    shorthand=net.shorthand,
    ksv=nps.store(net.state_vector()),
    loss=net.losses([training_batch])[0],
    ts=time.time(),
    experiment='ADC',
    head=True,
)

### Record results as they arrive

In [20]:
# Write the starting node, then, for each training segment yielded by
# trainer(net), a LEARNED edge from the previous state (matched by its ksv)
# to a node for the new state.
with driver.session() as session:
    session.write_transaction(add_start, starting_facts, net)
    # ksv of the most recently written node; each new segment links from it.
    prior_ksv = starting_facts['ksv']
    for observations in trainer(net):
        observations['prior_ksv'] = prior_ksv
        prior_ksv = observations['ksv']
        #pprint(observations)
        #observations['etas'] = observations['etas'][0] #DEBUG HACK, FIXME
        # Each segment is written in its own transaction as soon as it is yielded.
        session.write_transaction(add_subsequent, observations, net)
        print(f"loss {observations['end_loss']}")