In [1]:
# From scratch python version of sequential NN training
# includes a toy demo fitting to some random data

In [2]:
import numpy as np
npr = lambda x: np.random.randn(*x)
from copy import copy
import itertools

In [3]:
# each function returns a forward function for the layer
# forward functions begin with _
# only linear layer has parameters so far (well, cross_entropy has targets)

def linear(W, B):
    assert(W.shape[-1]==B.shape[0])
    d = {"W":W, "B":B}
    def _linear(X):
        return np.matmul(X,d["W"])+d["B"]
    
    return _linear

# done for naming conventions, I think there's a better way
def actual_relu(x):
    if x>0.: return x

    return 0.

def _relu(X):
    return np.vectorize(actual_relu)(X)

relu = _relu

def _sigmoid(x):
    return 0.5 * (np.tanh(x / 2) + 1)

sigmoid = _sigmoid 

def cross_ent(targets):
    d = {"T":targets}
    def _cross_ent(X):
        label_probs = X * d["T"] + (1 - X) * (1 - d["T"])
        return -np.log(label_probs)
    return _cross_ent

# helper to turn functional programming into OO
def param(layer,p):
    return layer.__closure__[0].cell_contents[p]

# takes the layer, parameter to differentiate wrt, input and output gradient
def vjp(l,p,X,dldy):
    if l.__name__ == "_linear":
        if p=='X':
            return np.matmul(dldy,param(l,'W').T)
        if p=='W':
            return np.matmul(X.T,dldy)
        if p=='B':
            return np.matmul(np.ones((1,dldy.shape[0])),dldy)
            
        
    if l.__name__ == '_relu':
        return (dldy*relu(np.sign(X))).reshape(dldy.shape)
    
    if l.__name__ == '_sigmoid':
        return (dldy*0.5/(np.cosh(X)**2)).reshape(dldy.shape)
        
    
    if l.__name__ == '_cross_ent':
        T = param(l,'T')
        return (((-T/X + (1-T)/(1-X+.0001)))*dldy).reshape(dldy.shape)
    
    raise NotImplementedError(f"Layer: {l.__name__}")
    
def nn(layers):
    def forward(X):
        _X = copy(X)
        for l in layers:
            _X = l(_X)
            
        return _X
    
    return forward

def gradient(layers):
    def backward(X):
        forwards = [X, layers[0](X)] 
        for ix,l in enumerate(layers[1:]):
            forwards.append(l(forwards[ix+1]))
                
        dldy = np.ones((X.shape[0],1)) #number of points in minibatch
        
        grads = [dldy]
        param_grads = []
        
        #print("forward layers")
        #for f in forwards:
            #print(f.shape)
        
        for l,f in zip(layers[::-1],forwards[-2::-1]):
            
            grads.append(vjp(l,'X',f,grads[-1])) # f is the layer input
            
            if hasattr(l,"__name__") and l.__name__ == "_linear":
                # ultimately needs to identify the parameters a layer has and take gradient
                param_grads.append(vjp(l,'W',f,grads[-2])) # grads just got one appended earlier
                param_grads.append(vjp(l,'B',f,grads[-2]))
            
        return grads, param_grads
            
    return backward

In [4]:
def gnuplot(series,markers=itertools.repeat("x"),extra=""):
    pts = ["\n".join([f"{x} {y}" for (x,y) in s]) for s in series]
    plots = ",".join([f" '<(echo \\\"{p}\\\")' pt \\\"{m}\\\" notitle " for p,m in zip(pts,markers)])
    q = f""" "set terminal dumb 80 25; {extra}
             plot {plots}"
        """
    !gnuplot -e {q}   

In [5]:
# some random input data points and ground truth labels (one of two classes)
X0=npr((10,2))
T = relu(np.sign(npr((10,1))))

In [6]:
# plot the datapoints by class
s1 = [(x,y) for (x,y),t in zip(X0,T) if t == 0]
s2 = [(x,y) for (x,y),t in zip(X0,T) if t == 1]
gnuplot([s1, s2], "01", "set xtics out; set ytics out;")

        +         +        +         +         +         +        +         +   
    2 +-+-------------------------------------------------------------------+-+ 
        |                    0                                              |   
  1.5 +-|                                                                   |-+ 
        |                                                                   |   
    1 +-|                                                              0    |-+ 
        |                                                                   |   
        |                                                                   |   
  0.5 +-|   1                                  0                            |-+ 
        |                         1                                         |   
    0 +-|                      1                       0                    |-+ 
        |                      0                                            |   
 -0.5 +-|      

In [7]:
# random initialization
W0 = npr((2,8))
B0 = npr((8,))

W1 = npr((8,1))
B1 = npr((1,))

# two layer NN
net = [linear(W0,B0),
       relu,
       linear(W1,B1),
       sigmoid,
       cross_ent(T),
      ]


In [8]:
# plot the initial softmax output vs class labels
initial_out = nn(net[:-1])(X0)
series = [enumerate(T.flatten()), enumerate(initial_out.flatten())]
markers = "xo"
extra = "set xtics out; set ytics out; set yrange [-0.2:1.2]; set xrange [-0.5:9.5];"
print ("ground truth: x\nprediction: o")
gnuplot(series, markers,extra)

ground truth: x
prediction: o
           +             +             +            +             +             
  1.2 +-+-------------------------------------------------------------------+-+ 
        |                                                                   |   
        |                                                                   |   
    1 +-|                x             x            x      x             x  |-+ 
        |                                                                   |   
        |                                                                   |   
  0.8 +-|                                                                   |-+ 
        |                                                                   |   
        |                                                                   |   
  0.6 +-|                                    o                              |-+ 
        |                                           o                       | 

In [9]:
# train the model with sgd

α = -0.01
losses = []

for step in range(200):

    L = sum(nn(net)(X0))[0]
    losses.append(L)
    
    _, dLdθ = gradient(net)(X0)
    
    W0 += α*dLdθ[2] #npr((2,8))
    B0 += α*dLdθ[3].reshape((B0.shape))#npr((8,))

    W1 += α*dLdθ[0]#npr((8,1))
    B1 += α*dLdθ[1].reshape((B1.shape))#npr((1,))

    net = [linear(W0,B0),
           relu,
           linear(W1,B1),
           sigmoid,
           cross_ent(T),
          ]

    print(f"Step: {step}, Loss {L}")

Step: 0, Loss 6.525573170967385
Step: 1, Loss 6.197217751425823
Step: 2, Loss 5.895874423487789
Step: 3, Loss 5.650123918902228
Step: 4, Loss 5.4519834954889035
Step: 5, Loss 5.283075291050255
Step: 6, Loss 5.128004147059727
Step: 7, Loss 4.974638743636677
Step: 8, Loss 4.842789449354516
Step: 9, Loss 4.7284585342144085
Step: 10, Loss 4.628280485447879
Step: 11, Loss 4.5395105511121585
Step: 12, Loss 4.45997599293664
Step: 13, Loss 4.400451338474991
Step: 14, Loss 4.355767593398162
Step: 15, Loss 4.313962246182162
Step: 16, Loss 4.274640892293005
Step: 17, Loss 4.237485227484032
Step: 18, Loss 4.202235548285703
Step: 19, Loss 4.168677597540788
Step: 20, Loss 4.136632622935657
Step: 21, Loss 4.10594982565236
Step: 22, Loss 4.076500597006351
Step: 23, Loss 4.04817409920334
Step: 24, Loss 4.020873860762831
Step: 25, Loss 3.9945151406359667
Step: 26, Loss 3.969022876489216
Step: 27, Loss 3.9443300781905815
Step: 28, Loss 3.9203765615311355
Step: 29, Loss 3.897107942699872
Step: 30, Loss 3.

In [10]:
# loss
gnuplot([[(x,y) for x,y in enumerate(losses)]])

                                                                                
    7 +---------------------------------------------------------------------+   
      |      +      +      +      +      +      +      +      +      +      |   
  6.5 |-+                                                                 +-|   
      |                                                                     |   
      |                                                                     |   
    6 |x+                                                                 +-|   
      |x                                                                    |   
  5.5 |x+                                                                 +-|   
      | x                                                                   |   
    5 |-x                                                                 +-|   
      |  x                                                                  |   
      |  xx    

In [11]:
# compate how the outputs have moved towards the ground truth 
# for the trained network (same plot as before)
trained_out = nn(net[:-1])(X0)
series = [enumerate(T.flatten()), enumerate(trained_out.flatten())]
markers = "xo"
extra = "set xtics out; set ytics out; set yrange [-0.2:1.2]; set xrange [-0.5:9.5];"
print ("ground truth: x\nprediction: o")
gnuplot(series, markers,extra)

ground truth: x
prediction: o
           +             +             +            +             +             
  1.2 +-+-------------------------------------------------------------------+-+ 
        |                                                                   |   
        |                                                                   |   
    1 +-|                x             x            x      x             x  |-+ 
        |                                           o                       |   
        |                o                                                  |   
  0.8 +-|                                                                o  |-+ 
        |                                                                   |   
        |                              o     o             o                |   
  0.6 +-|                                                                   |-+ 
        |                                                                   | 

In [12]:
T

array([[0.],
       [0.],
       [1.],
       [0.],
       [1.],
       [0.],
       [1.],
       [1.],
       [0.],
       [1.]])