# Neural Nets using [JAX](https://github.com/google/jax#readme)
**JAX is NumPy on the CPU, GPU, and TPU, with great [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) for high-performance machine learning research.**

Version 0.1, in `nn-jax`

To do: follow the tips in [Working efficiently with jupyter lab](https://florianwilhelm.info/2018/11/working_efficiently_with_jupyter_lab/).

When this was a notebook with integrated tests, we used these magics:
* `%load_ext autoreload`
* `%autoreload 2`
* `%matplotlib widget` (or `%matplotlib inline`, which was commented out)

In [1]:
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

We'll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers. For more details, see [Common Gotchas in JAX].

[Common Gotchas in JAX]: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers

In [2]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
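To make the difference concrete: JAX random functions are pure, so the same key always returns the same numbers, and fresh numbers come from splitting the key. This is a minimal sketch using only `jax.random.PRNGKey`, `random.split` and `random.normal`. The variable names are illustrative.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Reusing a key reproduces the draw exactly: there is no hidden global RNG state
a = random.normal(key, (3,))
b = random.normal(key, (3,))
same_key_repeats = bool(jnp.array_equal(a, b))

# For new numbers, split the key and use each subkey only once
key, subkey = random.split(key)
c = random.normal(subkey, (3,))
split_key_differs = not bool(jnp.array_equal(a, c))

print(same_key_repeats, split_key_differs)
```

The JAX documentation advises never reusing a key for draws that are meant to be independent. That advice also applies when initializing several parameter arrays.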

A network built of components which:
1. accept an ordered set of reals (we'll use `jax.numpy` arrays and call them vectors) at the input port and produce another at the output port - this is forward propagation. ${\displaystyle f\colon \mathbf {R} ^{n}\to \mathbf {R} ^{m}}$
1. accept an ordered set of reals at the output port, representing the gradient of the loss function at the output, and produce the gradient of the loss function at the input port - this is back propagation, aka backprop. ${\displaystyle b\colon \mathbf {R} ^{m}\to \mathbf {R} ^{n}}$
1. from the gradient of the loss function at the output, calculate the partial derivative of the loss function w.r.t. the internal parameters ${\displaystyle \frac{\partial E}{\partial w} }$
1. accept a scalar $\eta$ to control the adjustment of internal parameters. In this implementation $\eta$ is not passed separately: it is applied by scaling the loss gradient before it is passed back.
1. update internal parameters ${\displaystyle w \leftarrow w - \eta \frac{\partial E}{\partial w} }$
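The five responsibilities above map neatly onto `jax.vjp`. It returns a component's output together with a function that turns an output gradient into gradients for every argument. The sketch below only illustrates the idea; it is not how the layers in this notebook are implemented. The component `y = W x`, its values, and the squared-error loss are all assumptions.

```python
import jax
import jax.numpy as jnp

# A hypothetical component y = W x, with internal parameter W (maps R^2 -> R^3)
def component(W, x):
    return W @ x

W = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
x = jnp.array([1.0, -1.0])
ideal = jnp.zeros(3)
eta = 0.1

# 1. forward propagation, keeping a function that runs the component backward
y, backward = jax.vjp(component, W, x)

# gradient of E = |y - ideal|^2 / 2 at the output port
dE_dy = y - ideal

# 4. fold eta into the gradient before passing it back
# 2. and 3. one backward call yields dE/dW (parameters) and dE/dx (input port)
eta_dE_dW, eta_dE_dx = backward(eta * dE_dy)

# 5. w <- w - eta dE/dw
W_new = W - eta_dE_dW

loss_before = jnp.sum((y - ideal) ** 2) / 2
loss_after = jnp.sum((component(W_new, x) - ideal) ** 2) / 2
print(loss_before, loss_after)
```

Here the loss falls from 1.5 to about 0.96 after one step.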


In [3]:
class Layer:
    def __init__(self):
        pass
    
    def __call__(self, x):
        """Compute response to input"""
        raise NotImplementedError
        
    def backprop(self, output_delE):
        """Use output error gradient to adjust internal parameters, return gradient of error at input"""
        raise NotImplementedError
        
    def state_vector(self):
        """Provide the layer's learnable state as a vector"""
        raise NotImplementedError

    def set_state_from_vector(self, sv):
        """Set the layer's learnable state from a vector"""
        raise NotImplementedError

A network built of a cascade of layers:

In [4]:
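class Network:
    def __init__(self):
        self.layers = []
        self.eta = 0.1 #FIXME
        
    def extend(self, layer):
        """Append a layer to the cascade; returns self so calls can be chained"""
        self.layers.append(layer)
        return self
        
    def __call__(self, input):
        """Forward propagation through each layer in turn"""
        v = input
        for layer in self.layers:
            v = layer(v)
        return v
    
    def learn(self, facts, eta=None):
        """Do one backprop step for each (input, ideal) pair in facts.
        Returns the loss of the last fact, measured before its update."""
        if eta is not None:
            self.eta = eta
        for x, ideal in facts:
            y = self(x)
            e = y - ideal
            # eta is applied by scaling the loss gradient before it is passed back;
            # dividing by the batch size (1 for a single 1-D input) averages over the batch
            egrad = e * self.eta / jnp.atleast_2d(e).shape[0]
            for layer in reversed(self.layers):
                egrad = layer.backprop(egrad)
        return self.loss(e)

    @staticmethod
    def loss(e):
        """Half the squared error, averaged over the batch (a 1-D e is a batch of one)"""
        e = jnp.atleast_2d(e)
        return jnp.einsum('...ij,...ij', e, e) / (2.0 * e.shape[0])

    def losses(self, facts):
        """Loss for each (input, ideal) pair in facts, without learning"""
        return [self.loss(self(x) - ideal) for x, ideal in facts]
        
    def state_vector(self):
        """Provide the network's learnable state as a vector"""
        return jnp.concatenate([layer.state_vector() for layer in self.layers])
    
    def set_state_from_vector(self, sv):
        """Set the network's learnable state from a vector"""
        i = 0
        for layer in self.layers:
            lsvlen = len(layer.state_vector())
            layer.set_state_from_vector(sv[i:i+lsvlen])
            i += lsvlen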
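For comparison, the same cascade-and-learn loop can be written in JAX's functional style. Here `jax.value_and_grad` does the work that each layer's `backprop` method does by hand. This is a sketch under several assumptions. It handles only affine layers. The parameters are kept as a list of `(M, b)` pairs, with `M` stored like `AffineLayer.M`. The names `forward`, `batch_loss` and `sgd_step` are hypothetical. The loss is the same as in `Network.learn`: half the squared error, averaged over the batch.

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    # params: list of (M, b) pairs; row-vector convention, M has shape (inputs, outputs)
    for M, b in params:
        x = x @ M + b
    return x

def batch_loss(params, x, ideal):
    # half the squared error, averaged over the batch rows
    e = forward(params, x) - ideal
    return jnp.sum(e * e) / (2.0 * e.shape[0])

@jax.jit
def sgd_step(params, x, ideal, eta):
    # loss is measured before the update, like the value Network.learn returns
    loss, grads = jax.value_and_grad(batch_loss)(params, x, ideal)
    new_params = [(M - eta * gM, b - eta * gb)
                  for (M, b), (gM, gb) in zip(params, grads)]
    return new_params, loss

x = jnp.array([[0., 1.], [2., 3.], [4., 5.], [6., 7.]])
ideal = x @ jnp.array([[2., 3., 5.], [7., 11., 13.]]) + jnp.array([17., 19., 23.])
params = [(jnp.arange(6.0).reshape(2, 3), jnp.arange(6.0, 9.0))]

losses = []
for _ in range(10):
    params, loss = sgd_step(params, x, ideal, 0.01)
    losses.append(float(loss))
print(losses)
```

Each step applies the same update as `Network.learn` with `eta=0.01`. So the first loss is 3765.5, as in the network-learning test, and later values should track the printed ones up to float32 rounding.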

___

## Useful Layers

### Identity

In [5]:
class IdentityLayer(Layer):
    def __call__(self, x):
        return x
    
    def backprop(self, output_delE):
        return output_delE

    def state_vector(self):
        return jnp.array([])
    
    def set_state_from_vector(self, sv):
        pass

### Affine
A layer that does an [affine transformation](https://mathworld.wolfram.com/AffineTransformation.html) aka affinity, which is the classic fully-connected layer with output offsets.

$$ \mathbf{M} \mathbf{x} + \mathbf{b} = \mathbf{y} $$
where
$$
\begin{aligned}
\mathbf{x} &= \sum_{j=1}^{n} x_j \mathbf{\hat{x}}_j \\
\mathbf{b} &= \sum_{i=1}^{m} b_i \mathbf{\hat{y}}_i \\
\mathbf{y} &= \sum_{i=1}^{m} y_i \mathbf{\hat{y}}_i
\end{aligned}
$$
and $\mathbf{M}$ can be written
$$
\mathbf{M} =
\begin{bmatrix}
    m_{1,1} & \dots & m_{1,n} \\
    \vdots & \ddots & \vdots \\
    m_{m,1} & \dots & m_{m,n}
\end{bmatrix}
$$

#### Error gradient back-propagation
$$ 
\begin{align}
 \frac{\partial loss}{\partial\mathbf{x}}
  &= \frac{\partial loss}{\partial\mathbf{y}} \frac{\partial\mathbf{y}}{\partial\mathbf{x}} \\
  &= \mathbf{M}^\mathsf{T}\frac{\partial loss}{\partial\mathbf{y}}
\end{align}
$$

#### Parameter adjustment
$$
\begin{aligned}
 \frac{\partial loss}{\partial\mathbf{M}}
 &= \frac{\partial loss}{\partial\mathbf{y}} \frac{\partial\mathbf{y}}{\partial\mathbf{M}}
 = \frac{\partial loss}{\partial\mathbf{y}} \mathbf{x}^\mathsf{T} \\
 \frac{\partial loss}{\partial\mathbf{b}}
 &= \frac{\partial loss}{\partial\mathbf{y}} \frac{\partial\mathbf{y}}{\partial\mathbf{b}}
 = \frac{\partial loss}{\partial\mathbf{y}}
\end{aligned}
$$
The first is an outer product: $\frac{\partial loss}{\partial m_{i,j}} = \frac{\partial loss}{\partial y_i} x_j$.
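The three gradient formulas can be checked with `jax.grad`. This sketch uses the column-vector form of the equations, so `M` here is the $m \times n$ matrix of the math, the transpose of what `AffineLayer` stores. The values are assumptions chosen to mirror the single-input affine test further down: state `arange(9)`, $\mathbf{x} = (2, 1)$, ideal $(11, 12, 10)$.

```python
import jax
import jax.numpy as jnp

def affine(M, b, x):
    return M @ x + b            # column-vector form: M is (m, n), x is (n,), b is (m,)

def loss(M, b, x, ideal):
    e = affine(M, b, x) - ideal
    return e @ e / 2.0          # half squared error, so dloss/dy = y - ideal

M = jnp.array([[0., 3.], [1., 4.], [2., 5.]])   # m = 3 outputs, n = 2 inputs
b = jnp.array([6., 7., 8.])
x = jnp.array([2., 1.])
ideal = jnp.array([11., 12., 10.])

dloss_dy = affine(M, b, x) - ideal
gM, gb, gx = jax.grad(loss, argnums=(0, 1, 2))(M, b, x, ideal)

# Compare the automatic gradients with the formulas above
print(jnp.allclose(gx, M.T @ dloss_dy))           # dloss/dx = M^T dloss/dy
print(jnp.allclose(gM, jnp.outer(dloss_dy, x)))   # dloss/dM = dloss/dy x^T
print(jnp.allclose(gb, dloss_dy))                 # dloss/db = dloss/dy
print(gx)
```

Here $\partial loss/\partial\mathbf{y} = (-2, 1, 7)$ and $\partial loss/\partial\mathbf{x} = (15, 33)$. These are the values that the finite-difference estimates in that test approximate.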

#### Adapting to `numpy`

In `numpy` (and `jax.numpy`) it is more convenient to use row vectors, particularly for calculating the transform on multiple inputs in one operation. We use the identity $\mathbf{M} \mathbf{x} = (\mathbf{x}^\mathsf{T} \mathbf{M}^\mathsf{T})^\mathsf{T}$. To avoid cluttering names, we will use `M` in the code below to hold $\mathbf{M}^\mathsf{T}$.
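Here is a quick check of the identity and of what it buys: with inputs stacked as rows, one matrix product transforms the whole batch. In this sketch `M` is the math's $\mathbf{M}$ and `M_code` is its transpose, which is what the layer code calls `M`. `vmap`, already imported at the top, gives the same result by mapping the column-vector form over the rows. The values are the `arange(9)` state used in the tests.

```python
import jax.numpy as jnp
from jax import vmap

M = jnp.array([[0., 3.], [1., 4.], [2., 5.]])   # math M: 3 outputs, 2 inputs
b = jnp.array([6., 7., 8.])
M_code = M.T                                     # stored form, shape (inputs, outputs)

X = jnp.array([[0., 1.], [2., 3.], [4., 5.], [6., 7.]])  # 4 inputs stacked as rows

Y_rows = X @ M_code + b                # one matmul for the whole batch
Y_cols = (M @ X.T).T + b               # M x for each column, transposed back
Y_vmap = vmap(lambda x: M @ x + b)(X)  # column-vector form mapped over rows

print(Y_rows)
```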

In [6]:
class AffineLayer(Layer):
    """An affine transformation, which is the classic fully-connected layer with offsets.
    
    The layer has n inputs and m outputs, both of which must be supplied
    upon creation. The inputs and outputs are JAX arrays: 1-D for a single
    calculation, and 2-D when calculating the outputs of multiple inputs in one call.
    If called with a 1-D array of shape (n,), e.g. jnp.arange(n), it returns
    a 1-D array of shape (m,).
    If called with a 2-D array, the input must have shape (k, n), and it returns
    a 2-D array of shape (k, m), suitable as input to a subsequent layer
    that has input width m.
    """
    def __init__(self, n, m, rng_key=None):
        self.M = jnp.empty((n, m))   # holds M^T from the equations above
        self.b = jnp.empty(m)
        self.randomize(rng_key)
        
    def randomize(self, rng_key=None):
        """Draw M and b from a standard normal distribution.
        Defaults to the module-level `key`, split so M and b use independent subkeys."""
        if rng_key is None:
            rng_key = key
        k_M, k_b = random.split(rng_key)
        self.M = random.normal(k_M, self.M.shape, dtype=jnp.float32)
        self.b = random.normal(k_b, self.b.shape, dtype=jnp.float32)
        
    def __call__(self, x):
        self.input = x   # cached for use by backprop
        self.output = x @ self.M + self.b
        return self.output
    
    def backprop(self, output_delE):
        # Gradient at the input, computed with M *before* it is updated
        input_delE = output_delE @ self.M.T
        # Treat a single vector as a batch of one so the same einsum serves both cases
        o_delE = jnp.atleast_2d(output_delE)
        # dE/dM summed over the batch rows: outer product of input and output gradient,
        # laid out (inputs, outputs) like M
        self.M -= jnp.einsum('...ki,...kj->...ji', o_delE, jnp.atleast_2d(self.input))
        self.b -= jnp.sum(o_delE, 0)
        return input_delE

    def state_vector(self):
        return jnp.concatenate((self.M.ravel(), self.b.ravel()))
    
    def set_state_from_vector(self, sv):
        """Set the layer's learnable state from a vector"""
        l_M = len(self.M.ravel())
        l_b = len(self.b.ravel())
        self.M = sv[:l_M].reshape(self.M.shape)
        self.b = sv[l_M : l_M + l_b].reshape(self.b.shape)

### Map
Maps a scalar function over the inputs element by element, e.g. for activation layers.

In [7]:
class MapLayer(Layer):
    """Map a scalar function on the input taken element-wise"""
    def __init__(self, fun, dfundx):
        self.vfun = jnp.vectorize(fun)
        self.vdfundx = jnp.vectorize(dfundx)

    def __call__(self, x):
        self.input = x
        return self.vfun(x)
    
    def backprop(self, output_delE):
        input_delE = self.vdfundx(self.input) * output_delE
        return input_delE

    def state_vector(self):
        return jnp.array([])
    
    def set_state_from_vector(self, sv):
        pass
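`MapLayer` needs the derivative supplied by hand, but JAX can derive it. This is a sketch, assuming a helper named `auto_derivative`, which is hypothetical and not part of this notebook. It wraps `jax.grad` of the scalar function in `jnp.vectorize`, the same wrapper `MapLayer` uses.

```python
import jax
import jax.numpy as jnp

def auto_derivative(fun):
    """Hypothetical helper: the element-wise derivative of a scalar function, via jax.grad."""
    return jnp.vectorize(jax.grad(fun))

square = lambda x: x ** 2
d_square = auto_derivative(square)

x = jnp.array([1.0, 2.0, 2.0])   # jax.grad needs floating-point inputs
print(d_square(x))               # same as the hand-written derivative 2*x

d_sin = auto_derivative(jnp.sin)
print(jnp.allclose(d_sin(x), jnp.cos(x)))
```

With the `MapLayer` above, `MapLayer(square, auto_derivative(square))` would need only the forward function spelled out. The layer must then be fed floating-point arrays: the tests below use integer inputs, which `jax.grad` rejects.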

---

# Tests
*Incomplete.* \
The exported `.py` version is also tested with `unittest` by a separate script; see `test-nn_v3.py`.

Make a few test arrays:

In [8]:
if __name__ == '__main__':
    one_wide = jnp.atleast_2d(jnp.arange(1*4)).reshape(-1,1)
    print(f"one_wide is:\n{one_wide}")
    two_wide = jnp.arange(2*4).reshape(-1,2)
    print(f"two_wide is:\n{two_wide}")
    three_wide = jnp.arange(3*4).reshape(-1,3)
    print(f"three_wide is:\n{three_wide}\n")

one_wide is:
[[0]
 [1]
 [2]
 [3]]
two_wide is:
[[0 1]
 [2 3]
 [4 5]
 [6 7]]
three_wide is:
[[ 0  1  2]
 [ 3  4  5]
 [ 6  7  8]
 [ 9 10 11]]



## Tooling for Testing

In [9]:
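if __name__ == '__main__':
    import sympy
    
    class VC():
        """Numerical (finite-difference) calculus helpers for checking gradients"""
        @staticmethod
        def grad(f, x, eps=1e-3):
            """Central-difference Jacobian of f at the 1-D point x.
            Row j holds the change in f per unit change in x[j]."""
            #epsihat = jnp.eye(x.size) * eps
            epsihat = jnp.eye(x.shape[-1]) * eps
            yp = jnp.apply_along_axis(f, 1, x + epsihat)
            ym = jnp.apply_along_axis(f, 1, x - epsihat)
            return (yp - ym)/(2 * eps)
        
        @staticmethod
        def tensor_grad(f, x, eps=1e-3):
            """Apply VC.grad to each row of x.
            NOTE: f is called on single 1-D rows, so a batch loss sees one row broadcast
            against the whole ideal. The result is consistent between a layer's input and
            output (valid for chain-rule checks) but is not the gradient of the batch loss."""
            return jnp.apply_along_axis(lambda v: VC.grad(f, v, eps), 1, x)
            
    def closenuf(a, b, tol=0.001):
        """True if a and b agree within relative tolerance tol"""
        return jnp.allclose(a, b, rtol=tol)
    
    def arangep(n, starting_index=0):
        """The n consecutive primes starting at index starting_index (index 0 is 2)"""
        sympy.sieve.extend_to_no(starting_index + n)
        return jnp.array(sympy.sieve._list[starting_index:starting_index + n])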
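`VC.tensor_grad` calls `f` on one row at a time, so for the batch losses used below, `f` sees that row broadcast against the whole `ideal`. This keeps a layer's input and output gradients consistent, so the backprop checks hold, but it is not the gradient of the batch loss. For the stacked `MapLayer` test it reports about -34 for the first element, whereas the batch loss's gradient there is (0 - 11)/4 = -2.75. A sketch of an alternative follows. It assumes the loss is a scalar function of the whole array and perturbs one element at a time (`numeric_grad` is a hypothetical name). It then compares the result with `jax.grad`, which is exact.

```python
import jax
import jax.numpy as jnp

def numeric_grad(f, x, eps=1e-2):
    """Central-difference gradient of a scalar f, perturbing one element of x at a time."""
    flat = x.ravel()
    steps = jnp.eye(flat.size, dtype=flat.dtype) * eps
    g = [(f((flat + d).reshape(x.shape)) - f((flat - d).reshape(x.shape))) / (2 * eps)
         for d in steps]
    return jnp.array(g).reshape(x.shape)

# Output and ideal from the stacked MapLayer test (ideal = two_wide * 2 + 11)
y = jnp.array([[0., 1.], [4., 9.], [16., 25.], [36., 49.]])
ideal = jnp.array([[11., 13.], [15., 17.], [19., 21.], [23., 25.]])

def batch_loss(v):
    # half the squared error, averaged over the batch rows
    return jnp.sum((v - ideal) ** 2) / (2 * v.shape[0])

exact = jax.grad(batch_loss)(y)       # equals (y - ideal) / 4
approx = numeric_grad(batch_loss, y)
print(batch_loss(y), exact[0], approx[0])
```

The larger default `eps` is a design choice: a central difference is exact for a quadratic, so a bigger step only reduces float32 rounding noise.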

In [10]:
#VC.grad(lambda x:x**2, three_wide[1])

In [11]:
#VC.tensor_grad(lambda x:x**2, three_wide)

---

Input to a layer can be a single (row) vector, or a vertical stack of row vectors: a 2-D array with one input per row. We need to test each layer class with both single and stacked input.

## Identity layer

In [12]:
if __name__ == '__main__':
    iL = IdentityLayer()
    
    # It's transparent from input to output
    assert jnp.equal(iL(jnp.arange(5)), jnp.arange(5)).all()
    assert (iL(three_wide) == three_wide).all()
    
    # It back-propagates the loss gradient without alteration
    assert jnp.equal(iL.backprop(jnp.arange(7)), jnp.arange(7)).all()
    assert (iL.backprop(three_wide) == three_wide).all()

    # It works for stacked input
    # (see above)

## Map layer

#### Test single vector input behavior

In [13]:
if __name__ == '__main__':
    mL = MapLayer(lambda x:x**2, lambda d:2*d)
    
    # It applies the forward transformation
    assert jnp.equal(mL(jnp.array([-2,1,3])), jnp.array([4,1,9])).all()
    
    # It back-propagates the loss gradient
    x = jnp.array([1,2,2])
    y = mL(x)
    
    # for loss function, use L2-distance from some ideal
    # (divided by 2, for convenient gradient = error)
    ideal = jnp.array([2,3,5])
    loss = lambda v: (v - ideal).dot(v - ideal) / 2.0
    loss_at_y = loss(y)
    print(f"x = {x}, y = {y}, loss at y = {loss_at_y}")
    
    # find numerical gradient of loss function at y, the layer output
    grad_y = VC.grad(loss, y)
    print(f"∇𝑙𝑜𝑠𝑠(𝑦) = {grad_y}")
    
    # find the numerical gradient of the loss w.r.t. the input of the layer
    grad_x = VC.grad(lambda x:loss(mL(x)), x)
    print(f"∇𝑙𝑜𝑠𝑠(𝑥) = {grad_x}")
    
    # The backprop method does the same
    _ = mL(x) # Re-run forward so the layer caches this x as its input
    in_delE = mL.backprop(grad_y)
    print(f"backprop({grad_y}) = {in_delE}")
    assert closenuf(in_delE, grad_x)
    
    # The backprop operation did not change the behavior
    assert jnp.equal(mL(x), y).all()

x = [1 2 2], y = [1 4 4], loss at y = 1.5
∇𝑙𝑜𝑠𝑠(𝑦) = [-0.99998707  0.99992746 -0.99992746]
∇𝑙𝑜𝑠𝑠(𝑥) = [-1.9999741  3.9999483 -3.9999483]
backprop([-0.99998707  0.99992746 -0.99992746]) = [-1.9999741  3.9997098 -3.9997098]


#### Test stacked-vectors input:

In [14]:
if __name__ == '__main__':
    mL = MapLayer(lambda x:x**2, lambda d:2*d)
    
    two_wide_sq = jnp.array([[ 0,  1],
                            [ 4,  9],
                            [16, 25],
                            [36, 49]])
    # It applies the forward transformation
    assert jnp.equal(mL(two_wide), two_wide_sq).all()
    
    # It back-propagates the loss gradient
    x = two_wide
    y = mL(x)

    # for loss function, use L2-distance from some ideal
    # (divided by 2, for convenient gradient = error)
    ideal = two_wide * 2 + 11
    #print(y - ideal)
    #loss = lambda v: (v - ideal).dot(v - ideal) / 2.0
    loss = lambda v: jnp.einsum('ij,ij', v-ideal, v-ideal) / (2 * v.shape[0])
    loss_at_y = loss(y)
    print(f"x =\n{x}\ny =\n{y}, loss = {loss_at_y}\n")
    
    # find numerical gradient of loss function at y, the layer output
    grad_y = VC.tensor_grad(loss, y)
    print(f"∇𝑙𝑜𝑠𝑠(𝑦) =\n{grad_y}\n")
    
    # find the numerical gradient of the loss w.r.t. the input of the layer
    grad_x = VC.tensor_grad(lambda x:loss(mL(x)), x)
    print(f"∇𝑙𝑜𝑠𝑠(𝑥) =\n{grad_x}\n")
    
    # The backprop method does the same
    _ = mL(x) # Re-run forward so the layer caches this x as its input
    in_delE = mL.backprop(grad_y)
    print(f"backprop({grad_y}) =\n{in_delE}")
    assert closenuf(in_delE, grad_x)
    
    # The backprop operation did not change the behavior
    assert jnp.equal(mL(x), y).all()

x =
[[0 1]
 [2 3]
 [4 5]
 [6 7]]
y =
[[ 0  1]
 [ 4  9]
 [16 25]
 [36 49]], loss = 152.5

∇𝑙𝑜𝑠𝑠(𝑦) =
[[-33.99658   -36.01074  ]
 [-26.000975  -19.989012 ]
 [ -1.9989012  11.993407 ]
 [ 37.963863   59.93652  ]]

∇𝑙𝑜𝑠𝑠(𝑥) =
[[   0.        -71.99096 ]
 [-104.0039   -119.99511 ]
 [ -15.995025  119.991295]
 [ 455.99362   839.9047  ]]

backprop([[-33.99658   -36.01074  ]
 [-26.000975  -19.989012 ]
 [ -1.9989012  11.993407 ]
 [ 37.963863   59.93652  ]]) =
[[  -0.        -72.02148 ]
 [-104.0039   -119.93407 ]
 [ -15.99121   119.934074]
 [ 455.56635   839.11127 ]]


## Affine layer

#### Test single vector input behavior

Test, for single input-vector operations:
* input and output widths
* state vector setting and getting
* forward calculation

In [15]:
if __name__ == '__main__':
    # Affine
    a = AffineLayer(2,3)
    key = random.PRNGKey(0)
    #x = random.normal(key, (10,))

    
    # The input and output widths are correct
    assert a(jnp.arange(2)).shape == (3,) 

    # Its internal state can be set
    a.set_state_from_vector(jnp.arange(9))
    # and read back
    assert (a.state_vector() == jnp.arange(9)).all()
    # NOTE: The two assertions below are commented out because they depend
    # on white-box knowledge, and are duplicative of other tests
    #assert jnp.equal(a.M, jnp.array([[0, 1, 2],
    #                               [3, 4, 5]])).all()
    #assert jnp.equal(a.b, jnp.array([6, 7, 8])).all()

    # Its internal state observed using numerical gradient is correct
    x = random.uniform(key, (2,))
    y = a(x)
    dydx = VC.grad(a, x)
    b = y - x.dot(dydx)
    #print(dydx, b)
    #print(dydx, jnp.arange(6).reshape(2,-1))
    assert closenuf(dydx, jnp.arange(6).reshape(2, -1))
    #print(b, jnp.arange(6, 9))
    assert closenuf(b, jnp.arange(6, 9))
    
    # It performs a single-input forward calculation correctly
    x = jnp.array([2, 1])
    y = a(x)
    #print(f"a.M is:\n{a.M}\na.b is {a.b}\nx is: {x}\ny is: {y}\n")
    assert (y == jnp.array([9, 13, 17])).all()
    
    # It performs a different single-input forward calculation correctly
    a.set_state_from_vector(jnp.array([ 2,  3,  5,  7, 11, 13, 17, 19, 23]))
    x = jnp.array([[29, 31]])
    y = a(x)
    assert (y == jnp.array([[292, 447, 571]])).all()

Test, for single input-vector operations:
* back-propagation of the loss gradient
* learning (change in forward function) from the back-prop operation

In [16]:
if __name__ == '__main__':
    # Affine
    a = AffineLayer(2,3)
    a.set_state_from_vector(jnp.arange(9))

    # Doing a single-input-vector calculation
    x = jnp.array([2, 1])
    y = a(x)
    assert jnp.equal(y, jnp.array([9, 13, 17])).all()

    # It back-propagates the loss gradient
    ideal = jnp.array([11,12,10])
    loss = lambda v: (v - ideal).dot(v - ideal) / 2.0
    loss_at_y = loss(y)
    print(f"x = {x}, y = {y}, loss = {loss_at_y}")
    grad_y = VC.grad(loss, y)
    print(f"∇𝑙𝑜𝑠𝑠(𝑦) = {grad_y}")
    grad_x = VC.grad(lambda x:loss(a(x)), x)
    print(f"∇𝑙𝑜𝑠𝑠(𝑥) = {grad_x}")
    
    # Back-propagate the loss gradient from layer output to input
    _ = a(x) # Re-run forward so the layer caches this x as its input
    out_delE = grad_y * 0.1 # Backprop one-tenth of the loss gradient
    in_delE = a.backprop(out_delE)
    print(f"backprop({out_delE}) = {in_delE}")
    
    # The loss gradient back-propagated to the layer input is correct
    assert closenuf(in_delE / 0.1, grad_x)
    
    # And how did the learning affect the layer?
    print(f"Now a({x}) = {a(x)}, loss = {loss(a(x))}")
    print(f"state_vector is {a.state_vector()}")
    # FIXME: Check the change is correct

x = [2 1], y = [ 9 13 17], loss = 27.0
∇𝑙𝑜𝑠𝑠(𝑦) = [-2.0008085  1.0004042  6.9961543]
∇𝑙𝑜𝑠𝑠(𝑥) = [14.991759 33.006664]
backprop([-0.20008086  0.10004043  0.6996154 ]) = [1.4992713 3.2979963]
Now a([2 1]) = [10.200485 12.399757 12.802307], loss = 4.325977325439453
state_vector is [0.4001617  0.7999191  0.60076916 3.2000809  3.8999596  4.3003845
 6.200081   6.8999596  7.3003845 ]


#### Test batch operations

Test, for batch operations:
* input and output widths
* forward calculation

In [17]:
if __name__ == '__main__':
    # Affine
    a = AffineLayer(2,3)
    a.set_state_from_vector(jnp.arange(9))
    
    # The input and output widths for the forward calculation are correct
    x = two_wide
    y = a(two_wide)
    assert y.shape[0] == x.shape[0]
    assert y.shape[1] == 3
    
    # The input and output widths for the backprop calculation are correct
    bp = a.backprop(three_wide * 0.001)
    assert bp.shape[0] == three_wide.shape[0]
    assert bp.shape[1] == x.shape[1]

    # The forward calculation is correct (in at least two instances)
    a.set_state_from_vector(jnp.arange(9))
    x = jnp.array([[0, 1],
                  [2, 3],
                  [4, 5],
                  [6, 7]])
    assert (a(x) == jnp.array([[ 9, 11, 13],
                              [15, 21, 27],
                              [21, 31, 41],
                              [27, 41, 55]])).all()
    #print(f"a.M is:\n{a.M}\na.b is {a.b}\nx is: {x}\ny is: {y}")
    a.set_state_from_vector(jnp.array([ 2,  3,  5,  7, 11, 13, 17, 19, 23]))
    y = a(x)
    #print(f"x is: {x}\ny is: {y}")
    assert (y == jnp.array([[ 24,  30,  36],
                           [ 42,  58,  72],
                           [ 60,  86, 108],
                           [ 78, 114, 144]])).all()

Test, for batch operations:
* back-propagation of the loss gradient
* learning (change in forward function) from the back-prop operation

In [18]:
if __name__ == '__main__':
    # Affine
    a = AffineLayer(2,3)
    a.set_state_from_vector(jnp.arange(9))
    x = jnp.array([[0, 1],
                  [2, 3],
                  [4, 5],
                  [6, 7]])
    y = a(x)

    # It back-propagates the loss gradient

    # for loss function, use L2-distance from some ideal
    # (divided by 2, for convenient gradient = error)
    ideal = x @ arangep(2*3).reshape(2,3) + arangep(3,6) # A known, different parameter setting
    print(f"y - ideal =\n{y - ideal}")
    #loss = lambda v: (v - ideal).dot(v - ideal) / 2.0
    loss = lambda v: jnp.einsum('ij,ij', v-ideal, v-ideal) / (2 * v.shape[0])
    loss_at_y = loss(y)
    print(f"x =\n{x}\nideal =\n{ideal}\ny =\n{y}, loss = {loss_at_y}\n")

    
    # find numerical gradient of loss function at y, the layer output
    grad_y = VC.tensor_grad(loss, y)
    print(f"∇𝑙𝑜𝑠𝑠(𝑦) =\n{grad_y}")
    
    # find the numerical gradient of the loss w.r.t. the input of the layer
    grad_x = VC.tensor_grad(lambda x:loss(a(x)), x)
    print(f"∇𝑙𝑜𝑠𝑠(𝑥) =\n{grad_x}")
            
    # Back-propagate the loss gradient from layer output to input
    _ = a(x) # Re-run forward so the layer caches this x as its input
    out_delE = grad_y * 0.01 # Backprop one percent of the loss gradient
    in_delE = a.backprop(out_delE)
    print(f"backprop({out_delE}) = {in_delE}")
    
    # The loss gradient back-propagated to the layer input is correct
    # (undo the 0.01 scaling; float32 finite differences need a looser tolerance)
    assert closenuf(in_delE / 0.01, grad_x, tol=0.01)
    
    # And how did the learning affect the layer?
    print(f"Now a({x}) = {a(x)}, loss = {loss(a(x))}")
    print(f"state_vector is {a.state_vector()}")
    # FIXME: Check the change is correct

y - ideal =
[[-15 -19 -23]
 [-27 -37 -45]
 [-39 -55 -67]
 [-51 -73 -89]]
x =
[[0 1]
 [2 3]
 [4 5]
 [6 7]]
ideal =
[[ 24  30  36]
 [ 42  58  72]
 [ 60  86 108]
 [ 78 114 144]]
y =
[[ 9 11 13]
 [15 21 27]
 [21 31 41]
 [27 41 55]], loss = 3765.5

∇𝑙𝑜𝑠𝑠(𝑦) =
[[ -55.66406   -81.05468  -102.539055]
 [ -48.33984   -67.871086  -83.98437 ]
 [ -40.03906   -54.687496  -65.18554 ]
 [ -31.98242   -41.25976   -46.630856]]
∇𝑙𝑜𝑠𝑠(𝑥) =
[[ -286.13278 -1006.8359 ]
 [ -235.83983  -835.44916]
 [ -185.30272  -665.039  ]
 [ -134.52147  -494.38474]]
backprop([[-0.55664057 -0.81054676 -1.0253905 ]
 [-0.48339838 -0.6787108  -0.83984363]
 [-0.40039057 -0.54687494 -0.65185535]
 [-0.3198242  -0.4125976  -0.46630853]]) = [[ -2.8613276 -10.039061 ]
 [ -2.358398   -8.364256 ]
 [ -1.8505857  -6.6479483]
 [ -1.3452146  -4.9414053]]
Now a([[0 1]
 [2 3]
 [4 5]
 [6 7]]) = [[ 17.007812  21.917969  26.051758]
 [ 44.47754   60.897453  74.35839 ]
 [ 71.947266  99.876945 122.66502 ]
 [ 99.41699  138.85643  170.97166 ]], loss =

#### Test batch operations when the affine layer has only one input

Test, for batch operations:
* input and output widths
* forward calculation

In [19]:
if __name__ == '__main__':
    # Affine
    a = AffineLayer(1,3)
    a.set_state_from_vector(jnp.arange(6))
    
    # The input and output widths for the forward calculation are correct
    x = one_wide
    y = a(one_wide)
    assert y.shape[0] == x.shape[0]
    assert y.shape[1] == 3
    
    # The input and output widths for the backprop calculation are correct
    bp = a.backprop(three_wide * 0.001)
    assert bp.shape[0] == three_wide.shape[0]
    assert bp.shape[1] == x.shape[1]

    # The forward calculation is correct (in at least two instances)
    a.set_state_from_vector(jnp.arange(6))
    x = jnp.array([[0],
                  [1],
                  [2],
                  [3]])
    assert (a(x) == jnp.array([[ 3.,  4.,  5.],
                              [ 3.,  5.,  7.],
                              [ 3.,  6.,  9.],
                              [ 3.,  7., 11.]])).all()
    #print(f"a.M is:\n{a.M}\na.b is {a.b}\nx is: {x}\ny is: {y}")
    a.set_state_from_vector(jnp.array([ 2,  3,  5,  7, 11, 13]))
    y = a(x)
    #print(f"x is: {x}\ny is: {y}")
    assert (a(x) == jnp.array([[ 7, 11, 13],
                              [ 9, 14, 18],
                              [11, 17, 23],
                              [13, 20, 28]])).all()

Test, for batch operations:
* back-propagation of the loss gradient
* learning (change in forward function) from the back-prop operation

In [20]:
if __name__ == '__main__':
    # Affine
    a = AffineLayer(1,3)
    a.set_state_from_vector(jnp.arange(6))
    x = jnp.array([[0],
                  [1],
                  [2],
                  [3]])
    y = a(x)
    #print(f"x =\n{x}\ny =\n{y}")
    
    # It back-propagates the loss gradient

    # for loss function, use L2-distance from some ideal
    # (divided by 2, for convenient gradient = error)
    ideal = x @ arangep(1*3).reshape(1,3) + arangep(3,6) # A known, different parameter setting
    print(f"y - ideal =\n{y - ideal}")
    #loss = lambda v: (v - ideal).dot(v - ideal) / 2.0
    loss = lambda v: jnp.einsum('ij,ij', v-ideal, v-ideal) / (2 * v.shape[0])
    loss_at_y = loss(y)
    print(f"x =\n{x}\nideal =\n{ideal}\ny =\n{y}, loss = {loss_at_y}\n")

    
    # find numerical gradient of loss function at y, the layer output
    grad_y = VC.tensor_grad(loss, y)
    print(f"∇𝑙𝑜𝑠𝑠(𝑦) =\n{grad_y}")
    
    # find the numerical gradient of the loss w.r.t. the input of the layer
    grad_x = VC.tensor_grad(lambda x:loss(a(x)), x)
    print(f"∇𝑙𝑜𝑠𝑠(𝑥) =\n{grad_x}")
            
    # Back-propagate the loss gradient from layer output to input
    _ = a(x) # Re-run forward so the layer caches this x as its input
    out_delE = grad_y * 0.01 # Backprop one percent of the loss gradient
    in_delE = a.backprop(out_delE)
    print(f"backprop({out_delE}) = {in_delE}")
    
    # The loss gradient back-propagated to the layer input is correct
    # (undo the 0.01 scaling; float32 finite differences need a looser tolerance)
    assert closenuf(in_delE / 0.01, grad_x, tol=0.01)
    
    # And how did the learning affect the layer?
    print(f"Now a({x}) = {a(x)}, loss = {loss(a(x))}")
    print(f"state_vector is {a.state_vector()}")
    # FIXME: Check the change is correct

y - ideal =
[[-14 -15 -18]
 [-16 -17 -21]
 [-18 -19 -24]
 [-20 -21 -27]]
x =
[[0]
 [1]
 [2]
 [3]]
ideal =
[[17 19 23]
 [19 22 28]
 [21 25 33]
 [23 28 38]]
y =
[[ 3  4  5]
 [ 3  5  7]
 [ 3  6  9]
 [ 3  7 11]], loss = 570.25

∇𝑙𝑜𝑠𝑠(𝑦) =
[[-22.613523 -25.970457 -33.935543]
 [-22.644041 -24.627684 -31.311033]
 [-22.644041 -23.254393 -28.625486]
 [-22.674559 -21.972654 -26.000975]]
∇𝑙𝑜𝑠𝑠(𝑥) =
[[-93.90258]
 [-87.3413 ]
 [-80.62743]
 [-73.9746 ]]
backprop([[-0.22613522 -0.25970456 -0.3393554 ]
 [-0.2264404  -0.24627683 -0.31311032]
 [-0.2264404  -0.23254392 -0.28625485]
 [-0.22674558 -0.21972653 -0.26000974]]) = [[-0.9384154 ]
 [-0.87249744]
 [-0.8050536 ]
 [-0.739746  ]]
Now a([[0]
 [1]
 [2]
 [3]]) = [[ 3.9057617  4.958252   6.1987305]
 [ 5.26532    7.3287964  9.86438  ]
 [ 6.6248775  9.699341  13.530029 ]
 [ 7.984435  12.069884  17.195679 ]], loss = 389.4485778808594
state_vector is [1.3595579 2.3705442 3.6656492 3.9057617 4.958252  6.1987305]


## Network

### Network assembly

The simplest network, the empty one, is the identity:

In [21]:
if __name__ == '__main__':
    net = Network()
    assert all(x == net(x) for x in [0, 42, 'cows in trouble'])
    assert all((x == net(x)).all() for x in [np.arange(7), jnp.arange(3*4*5).reshape(3,4,5)])

A stack of maps composes the operations:

In [22]:
if __name__ == '__main__':
    net = Network()
    net.extend(MapLayer(lambda x: x**3, lambda d: 3*d**2))
    x = jnp.array([0, 2, 3, 42, -3.14])
    assert (net(x) == x**3).all()
    net.extend(MapLayer(lambda x: 7-x, lambda d: -1))
    assert (net(x) == 7-x**3).all()
    
    # It operates on each element of an input vector separately
    assert (net(jnp.arange(4)) == 7 - jnp.arange(4) ** 3).all()

A composition of affine transformations

_[to do someday]_
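One property such a test could check is that two affine layers in cascade are equivalent to a single affine layer. In the row-vector convention, `(x @ M1 + b1) @ M2 + b2 == x @ (M1 @ M2) + (b1 @ M2 + b2)`. This sketch uses plain arrays with arbitrary values in the style of the tests above. It is not yet wired to `Network` or `AffineLayer`.

```python
import jax.numpy as jnp

# Row-vector convention, as in AffineLayer: y = x @ M + b, with M of shape (inputs, outputs)
M1 = jnp.arange(6.0).reshape(2, 3)                  # 2 -> 3
b1 = jnp.arange(6.0, 9.0)
M2 = jnp.array([[2., 3.], [5., 7.], [11., 13.]])    # 3 -> 2
b2 = jnp.array([17., 19.])

x = jnp.array([[0., 1.], [2., 3.], [4., 5.], [6., 7.]])

two_layers = (x @ M1 + b1) @ M2 + b2

# The cascade collapses to one affine map
M = M1 @ M2
b = b1 @ M2 + b2
one_layer = x @ M + b

print(jnp.allclose(two_layers, one_layer))
```

This is also why a stack of affine layers needs non-linear `MapLayer`s between them to be more expressive than a single affine layer.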

### Network Learning

Test simple batch learning of a single affine layer

In [23]:
if __name__ == '__main__':
    net = Network()
    a = AffineLayer(2,3)
    a.set_state_from_vector(jnp.arange(9)) # A well-known initial state
    net.extend(a)
    print(f"\nNet has state {net.state_vector()}")

    x = jnp.array([[0, 1],
                  [2, 3],
                  [4, 5],
                  [6, 7]])

    # The net wraps the layer
    y = a(x)
    assert (net(x) == y).all()
    
    # Make the training batch.
    # We use a separate affine layer, initialized differently, to determine the ideal
    t_a = AffineLayer(2,3)
    t_a.set_state_from_vector(arangep(9)) # A known different initial state (of primes)
    ideal = t_a(x)
    
    facts = [(x, ideal)]
    print(f"facts are:\n{facts}\n")
    print(f"net(x) =\n{net(x)}")
    
    net.eta = 0.01
    for i in range(10):
        print(f"net.learn(facts) = {net.learn(facts)}")
    print(f"net(x) =\n{net(x)}")
    
    # A simple fact yielder:
    def fact_ory(facts, n):
        for i in range(n):
            yield facts
    
    #print(f"list(fact_ory(facts[0], 3)) =\n{list(fact_ory(facts[0], 3))}\n")
    print(f"net.learn(fact_ory(facts[0],10)) = {net.learn(fact_ory(facts[0],10))}")
    print(f"net(x) =\n{net(x)}")
    for i in range(1000):
        loss = net.learn(fact_ory(facts[0],10))
        if loss < 1e-25:
            break
    print(f"did {(i+1)*10} more learnings of fact. Now loss is {loss}")
    print(f"net(x) =\n{net(x)}")
    
    print(f"net.state_vector() = {net.state_vector()}")
    
    # The network has learned the target transform
    assert closenuf(net(x), facts[0][1])
    
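    # Save prior results, reset to the same initial state, and learn again.
    # NOTE: the rerun below uses the same fact_ory, so this checks repeatability;
    # the "different batch clustering" variant is not exercised yet.
    prev_run_loss = loss
    prev_y = net(x)
    net.set_state_from_vector(jnp.arange(9)) # A well-known initial state
    print(f"\nReset net to state {net.state_vector()}")

    # Intended: multiple batches in each call to Network.learn (not used below).
    # FIXME: `facts * 2` repeats the (x, ideal) tuple into a 4-tuple, not a doubled batch
    def multibatch_fact_ory(facts, n):
        for i in range(n//2):
            yield facts * 2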
    for i in range(1000):
        loss = net.learn(fact_ory(facts[0],10))
        if loss < 1e-25:
            break
    print(f"did {(i+1)*10} learnings of fact. Now loss is {loss}")
    print(f"net(x) =\n{net(x)}")
    
    # The results should match exactly
    assert loss == prev_run_loss
    assert (net(x) == prev_y).all()


Net has state [0 1 2 3 4 5 6 7 8]
facts are:
[(DeviceArray([[0, 1],
             [2, 3],
             [4, 5],
             [6, 7]], dtype=int32), DeviceArray([[ 24,  30,  36],
             [ 42,  58,  72],
             [ 60,  86, 108],
             [ 78, 114, 144]], dtype=int32))]

net(x) =
[[ 9 11 13]
 [15 21 27]
 [21 31 41]
 [27 41 55]]
net.learn(facts) = 3765.5
net.learn(facts) = 1605.49853515625
net.learn(facts) = 708.707275390625
net.learn(facts) = 336.169189453125
net.learn(facts) = 181.206787109375
net.learn(facts) = 116.54439544677734
net.learn(facts) = 89.3607177734375
net.learn(facts) = 77.7341537475586
net.learn(facts) = 72.56697082519531
net.learn(facts) = 70.0833511352539
net(x) =
[[ 14.699151  18.972603  22.707714]
 [ 36.72516   51.692673  64.392944]
 [ 58.75116   84.41274  106.078186]
 [ 80.77717  117.13282  147.76343 ]]
net.learn(fact_ory(facts[0],10)) = 63.1659049987793
net(x) =
[[ 15.146448  19.51728   23.365414]
 [ 37.197094  52.312637  65.145096]
 [ 59.24774   85.1

### Test Network.losses

In [24]:
if __name__ == '__main__':
    # Make a network. Leave it with the default identity behavior.
    net = Network()
    
    x = jnp.array([[0, 1],
                  [2, 3],
                  [4, 5],
                  [6, 7]])
    ideal = net(x)
    facts = [(x, ideal), (x, ideal-np.array([1,-1])), (x, 2*x)]
    assert (net.losses(facts) == [0, 1, 17.5])

    # Add some layers
    net.extend(AffineLayer(2,3)).extend(MapLayer(jnp.sin, jnp.cos)).extend(AffineLayer(3,2))
    # Place it in a known state for test repeatability
    net.set_state_from_vector(jnp.arange(len(net.state_vector())))
    ideal = net(x)
    facts = [(x, ideal), (x, ideal-np.array([1,-1]))]
    #print(net.losses(facts))
    assert (net.losses(facts) == [0, 1])

---

To produce an importable `nn.py`:
1. Save this notebook
1. Uncomment the `jupyter nbconvert` line below
1. Execute it.
1. Comment out the convert again
1. Save the notebook again in that form

In [25]:
###!jupyter nbconvert --to script nn-jax.ipynb