# Load libraries

In [2]:
from sympy import *
from IPython.display import display
# Scalar symbols used throughout the notebook: a, b, c, d are the entries of
# the parameter matrix, x is the input, y_true the target, h the initial
# hidden value, and e is the symbol named \epsilon used as an initial value.
# The raw string keeps the backslash of '\epsilon' literal: in a normal string
# '\e' is an invalid escape sequence, which newer Python versions warn about.
a, b, c, d, x, y_true, h, e = symbols(r'a b c d x y_true h \epsilon')
# w and s are undefined symbolic functions: w is the arbitrary function of x
# added to the y component, s is the activation applied to the h component.
w = Function('w')
s = Function('s')

# Setting up problem

We begin by setting up the parameter matrix.   This is meant to be a simple example, so we will set up the parameters as scalars.

Note, there is no $w$ here, since that will be added later in the iterate function.   We want to think of $w$ as a general nonlinear function of the state variables, so we will not include it in the parameter matrix.


In [3]:
# Linear part of the update acting on the state column z = (x, y, h).
# The w term and the activation s are applied separately in iterate().
A = Matrix([[1, 0, 0],   # x is carried over unchanged
            [0, 0, a],   # y <- a*h
            [b, c, d],   # h <- b*x + c*y + d*h
            ])
A

Matrix([
[1, 0, 0],
[0, 0, a],
[b, c, d]])

In [14]:
# Illustration only: the matrix is displayed but not assigned to any name.
# NOTE(review): as a matrix product with z = (x, 0, h) the second row would
# give w(x)*x + a*h, not the additive a*h + w(x) that iterate() computes --
# presumably this is meant schematically; confirm.
Matrix([[1, 0, 0], 
        [w(x), 0, a],
        [b, c, d],
        ])

Matrix([
[   1, 0, 0],
[w(x), 0, a],
[   b, c, d]])

This is the input data, initialized with $x$, but we don't know $y$ so we set it to zero.   We keep $h$ since perhaps we want to initialize that in some interesting way.

In [4]:
# Initial state column (x, y, h): y starts at 0 and h is left symbolic.
z = Matrix([[x], 
            [0], 
            [h]])

Ok, this function includes all the bits: the matrix of parameters, the nonlinear $w$ function, and the nonlinear activation function applied to just the $h$ part.

In [5]:
def iterate(z):
    """Apply one step of the model to the state column ``z``.

    The step multiplies ``z`` by the parameter matrix ``A``, adds the symbolic
    term ``w(x)`` to entry 1 and wraps entry 2 in the activation ``s``.
    ``A``, ``w``, ``s`` and ``x`` are read from the notebook's global
    namespace rather than passed in.

    Returns the updated state matrix.
    """
    state = A @ z
    # Entry 1 (y): linear part plus the arbitrary function w of the global x.
    state[1] += w(x)
    # Entry 2 (h): pass the linear part through the activation s.
    state[2] = s(state[2])
    return state

We can iterate the map as much as we like, but $3$ will suffice for our current purposes.

In [6]:
# A single application of the map to the initial state, shown for reference.
iterate(z)

Matrix([
[           x],
[  a*h + w(x)],
[s(b*x + d*h)]])

In [7]:
# Apply the map three times in a row, starting from the initial state z.
z_hat = iterate(iterate(iterate(z)))
z_hat

Matrix([
[                                                                              x],
[                              a*s(b*x + c*(a*h + w(x)) + d*s(b*x + d*h)) + w(x)],
[s(b*x + c*(a*s(b*x + d*h) + w(x)) + d*s(b*x + c*(a*h + w(x)) + d*s(b*x + d*h)))]])

# Compute the loss function

The loss only depends on $y$.

In [8]:
# The prediction is entry 1 (the y component) of the iterated state.
y_hat = z_hat[1]
y_hat

a*s(b*x + c*(a*h + w(x)) + d*s(b*x + d*h)) + w(x)

We use a squared loss.

In [9]:
# Squared error between the target y_true and the prediction y_hat.
loss = (y_true - y_hat)**2
loss

(-a*s(b*x + c*(a*h + w(x)) + d*s(b*x + d*h)) + y_true - w(x))**2

Ok, now things get interesting: as long as we start at $a=0$ we do not increase the loss!  This is what we want, and nice to confirm here.

In [10]:
# Evaluate the loss at a = 0 with b, c and d all set to the symbol e (\epsilon).
loss.subs({a: 0, b: e, c: e, d: e})

(y_true - w(x))**2

We now look at the gradient of the loss function with respect to the parameters.  If we set $b,c,d$ to zero and have $s(0)=0$, then every gradient is $0$, which we don't want!

In [11]:
# Loss and its partial derivatives at a = 0 with b, c, d all equal to `init`.
init = 0
at_init = {a: 0, b: init, c: init, d: init}
print('loss')
display(loss.subs(at_init))
print('gradients')
# Differentiate first, then substitute, one parameter at a time (a, b, c, d).
for param in (a, b, c, d):
    display(diff(loss, param).subs(at_init))


loss


(y_true - w(x))**2

gradients


-2*(y_true - w(x))*s(0)

0

0

0

But if we set $b,c,d$ to $\epsilon$, then we get exactly what we want: the original loss, but a useful gradient for $a$.

In [12]:
# Loss and its partial derivatives at a = 0 with b, c, d all equal to `init`,
# here the symbol e (\epsilon) instead of zero.
init = e
at_init = {a: 0, b: init, c: init, d: init}
print('loss')
display(loss.subs(at_init))
print('gradients')
# Differentiate first, then substitute, one parameter at a time (a, b, c, d).
for param in (a, b, c, d):
    display(diff(loss, param).subs(at_init))

loss


(y_true - w(x))**2

gradients


-2*(y_true - w(x))*s(\epsilon*x + \epsilon*s(\epsilon*h + \epsilon*x) + \epsilon*w(x))

0

0

0

As soon as $a$ is non-zero, we get a non-zero gradient for the rest, which is what we want.

In [13]:
# Partial derivatives of the loss with every parameter, including a, set to
# `init` (the symbol e, i.e. \epsilon).
init = e
at_init = dict.fromkeys((a, b, c, d), init)
# Differentiate first, then substitute, one parameter at a time (a, b, c, d).
for param in (a, b, c, d):
    display(diff(loss, param).subs(at_init))

(-2*\epsilon**2*h*Subs(Derivative(s(_xi_1), _xi_1), _xi_1, \epsilon*x + \epsilon*(\epsilon*h + w(x)) + \epsilon*s(\epsilon*h + \epsilon*x)) - 2*s(\epsilon*x + \epsilon*(\epsilon*h + w(x)) + \epsilon*s(\epsilon*h + \epsilon*x)))*(-\epsilon*s(\epsilon*x + \epsilon*(\epsilon*h + w(x)) + \epsilon*s(\epsilon*h + \epsilon*x)) + y_true - w(x))

-2*\epsilon*(\epsilon*x*Subs(Derivative(s(_xi_1), _xi_1), _xi_1, \epsilon*h + \epsilon*x) + x)*(-\epsilon*s(\epsilon*x + \epsilon*(\epsilon*h + w(x)) + \epsilon*s(\epsilon*h + \epsilon*x)) + y_true - w(x))*Subs(Derivative(s(_xi_1), _xi_1), _xi_1, \epsilon*x + \epsilon*(\epsilon*h + w(x)) + \epsilon*s(\epsilon*h + \epsilon*x))

-2*\epsilon*(\epsilon*h + w(x))*(-\epsilon*s(\epsilon*x + \epsilon*(\epsilon*h + w(x)) + \epsilon*s(\epsilon*h + \epsilon*x)) + y_true - w(x))*Subs(Derivative(s(_xi_1), _xi_1), _xi_1, \epsilon*x + \epsilon*(\epsilon*h + w(x)) + \epsilon*s(\epsilon*h + \epsilon*x))

-2*\epsilon*(\epsilon*h*Subs(Derivative(s(_xi_1), _xi_1), _xi_1, \epsilon*h + \epsilon*x) + s(\epsilon*h + \epsilon*x))*(-\epsilon*s(\epsilon*x + \epsilon*(\epsilon*h + w(x)) + \epsilon*s(\epsilon*h + \epsilon*x)) + y_true - w(x))*Subs(Derivative(s(_xi_1), _xi_1), _xi_1, \epsilon*x + \epsilon*(\epsilon*h + w(x)) + \epsilon*s(\epsilon*h + \epsilon*x))