<a href="https://colab.research.google.com/github/zifeitong/little_learner/blob/main/colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax
import jax.numpy as jnp
import builtins

# Shadow the builtin map() with an eager, list-returning version.
def map(*args, **kwargs):
  """Apply ``builtins.map`` and materialize the result as a list.

  Returning a list lets callers index, compare and reuse the result, which
  the lazy builtin iterator does not allow.
  """
  iterator = builtins.map(*args, **kwargs)
  return list(iterator)

In [None]:
# Chapter 1

def line(x):
  """Return the parameterized line ``theta -> theta[0] * x + theta[1]``."""
  def evaluate(theta):
    weight, bias = theta[0], theta[1]
    return weight * x + bias
  return evaluate

In [None]:
# Chapter 2

def is_scalar(tensor):
  """Return True if *tensor* is a plain Python number (a rank-0 tensor)."""
  scalar_types = (int, float)
  return isinstance(tensor, scalar_types)

class Tensor:
  """Minimal list-backed tensor used in the early chapters.

  Wraps a Python list whose elements are scalars or nested ``Tensor``
  objects. It exposes only indexing, ``len`` and a list-like ``repr``.
  """

  def __init__(self, elements):
    # The wrapped list of scalars or nested Tensors.
    self._elements = elements

  def __getitem__(self, key):
    """Return the element(s) at *key*, delegating to the wrapped list."""
    elements = self._elements
    return elements[key]

  def __len__(self):
    """Return the number of top-level elements."""
    return len(self._elements)

  def __repr__(self):
    """Return the repr of the wrapped list."""
    return repr(self._elements)

def rank(tensor):
  """Return the nesting depth of *tensor* (0 for a scalar).

  Only the first element at each level is inspected, so the tensor is
  assumed to be uniformly nested.
  """
  depth = 0
  while not is_scalar(tensor):
    tensor = tensor[0]
    depth += 1
  return depth

def shape(tensor):
  """Return the list of dimension lengths of *tensor* ([] for a scalar).

  Descends through first elements only, so uniform nesting is assumed.
  """
  dims = []
  while not is_scalar(tensor):
    dims.append(len(tensor))
    tensor = tensor[0]
  return dims

def equal(lhs, rhs):
  """Return True if two tensors have the same shape and equal elements.

  Two scalars are compared with ``==``. Tensors must first agree in shape
  and are then compared element by element, recursively.
  """
  if is_scalar(lhs) and is_scalar(rhs):
    return lhs == rhs
  if shape(lhs) != shape(rhs):
    return False
  return all(equal(left, right)
             for left, right in zip(lhs._elements, rhs._elements))

# Sanity checks: scalars have rank 0, and shape lists one length per level.
assert rank(1) == 0
assert rank(Tensor([0])) == 1
assert shape(Tensor([0])) == [1]
assert rank(Tensor([Tensor([1, 2]), Tensor([3, 4])])) == 2
assert shape(Tensor([Tensor([1, 2, 3]), Tensor([4, 5, 6])])) == [2, 3]

In [None]:
# Interlude I

def sum_1(tensor):
  """Add up the top-level elements of *tensor*; 0 for an empty tensor.

  Only ``len`` and integer indexing are used, so Tensors, lists and arrays
  all work.
  """
  total = 0
  for position in range(len(tensor)):
    total = total + tensor[position]
  return total

def sum(tensor):
  """Sum the innermost axis of *tensor*, reducing its rank by one.

  A rank-1 tensor collapses to a scalar. Higher ranks keep their outer
  structure and return a Tensor of the per-element sums.
  """
  if rank(tensor) != 1:
    return Tensor([sum(element) for element in tensor._elements])
  return sum_1(tensor)

# sum collapses only the innermost axis, lowering the rank by one.
assert equal(sum(Tensor([Tensor([Tensor([1, 2, 3])]),
                         Tensor([Tensor([4, 5, 6])])])),
             Tensor([Tensor([6]), Tensor([15])]))
assert equal(sum(Tensor([Tensor([1, 2, 3]), Tensor([4, 5, 6])])),
             Tensor([6, 15]))
assert equal(sum(Tensor([1, 2, 3])), 6)

In [None]:
# Chapter 3

# From now on, using jnp.array as Tensor

def l2_loss(target):
  """Build a sum-of-squared-errors objective for the model family *target*.

  ``l2_loss(target)(xs, ys)`` returns ``obj(theta, unused_rev=None)``, which
  evaluates ``sum((ys - target(xs)(theta)) ** 2)``. The second argument is
  accepted and ignored so the objective fits the ``f(theta, rev)``
  convention used by the gradient-descent loops.
  """
  def expectant_func(xs, ys):
    def obj_func(theta, unused_rev = None):
      residuals = ys - target(xs)(theta)
      squared = residuals ** 2
      return jnp.sum(squared)
    return obj_func
  return expectant_func

# Toy data set for the line model. With theta = [0, 0] every prediction is 0,
# so the loss reduces to sum(ys ** 2).
xs = jnp.array([2.0, 1.0, 4.0, 3.0])
ys = jnp.array([1.8, 1.2, 4.2, 3.3])
assert l2_loss(line)(xs, ys)([0, 0], 0) == 33.21

In [None]:
# Chapter 4

# jax.grad differentiates with respect to the first argument. For a list
# theta it returns a list holding one gradient per component.
gradient_of = jax.grad

assert gradient_of(lambda theta: theta[0] * theta[0])([27.0]) == [54.0]
assert gradient_of(l2_loss(line)(xs, ys))([0.0, 0.0]) == [-63.0, -21.0]

def revise(f, revs, theta):
  """Apply ``theta = f(theta, step)`` for step = 0 .. revs - 1 and return theta."""
  current = theta
  for step in range(revs):
    current = f(current, step)
  return current

# Fixed step count and learning rate for the Chapter 4 gradient descent.
revs, alpha = 1000, 0.01
def gradient_descent(obj, theta):
  """Minimize *obj* from *theta* using ``revs`` steps of plain gradient descent.

  Each step replaces every parameter p with ``p - alpha * g``, where g is
  the matching component of ``gradient_of(obj)``. Reads the module-level
  ``revs`` and ``alpha``. The update step is jit-compiled.
  """
  def step(big_theta, unused_rev = None):
    grads = gradient_of(obj)(big_theta)
    return [p - alpha * g for p, g in zip(big_theta, grads)]
  return revise(jax.jit(step), revs, theta)

# Fit the line model to (xs, ys) starting from theta = [0, 0]; the checks
# expect a slope near 1.05 and an intercept near 0.
big_theta = gradient_descent(l2_loss(line)(xs, ys), [0.0, 0.0])
assert jnp.isclose(big_theta[0], 1.05)
assert jnp.isclose(big_theta[1], 1.87e-6, atol=1e-5)

In [None]:
# Interlude II

import contextvars

class _HyperParameter:
  """Context manager that binds a ContextVar for the duration of a ``with``.

  Entering sets the variable to the stored value. Exiting resets it to the
  state it had before entering; the reset also runs when the body raises,
  and the exception is not suppressed.
  """

  def __init__(self, parameter, value):
    # The contextvars.ContextVar to bind, and the value to bind it to.
    self._parameter = parameter
    self._value = value

  def __enter__(self):
    # Keep the token so __exit__ can undo exactly this set().
    token = self._parameter.set(self._value)
    self._token = token

  def __exit__(self, exc_type, exc_value, exc_traceback):
    self._parameter.reset(self._token)

def declare_hypers(name):
  """Create a new, initially unset hyperparameter slot (a ContextVar) named *name*."""
  hyper = contextvars.ContextVar(name)
  return hyper

def hypers(parameter, value):
  """Return a context manager that binds *parameter* to *value* inside ``with``."""
  manager = _HyperParameter(parameter, value)
  return manager

smaller = declare_hypers('smaller')
larger = declare_hypers('larger')

# Both bindings are active inside the with-block and are reset on exit.
with hypers(smaller, 1), hypers(larger, 2):
  assert smaller.get() == 1
  assert larger.get() == 2

In [None]:
# Chapter 5

# From here on, revs and alpha are hyperparameters (ContextVars) bound with
# `hypers`. They replace the Chapter 4 module-level constants.
revs = declare_hypers('revs')
alpha = declare_hypers('alpha')

def gradient_descent(obj, theta):
  """Minimize *obj* from *theta* using plain gradient descent.

  The step count and learning rate come from the ``revs`` and ``alpha``
  hyperparameters, which the caller must have bound via ``hypers``. The
  objective is called as ``obj(theta, rev)`` so that sampling objectives
  can vary per revision.
  """
  def step(big_theta, rev):
    grads = gradient_of(obj)(big_theta, rev)
    return [p - alpha.get() * g for p, g in zip(big_theta, grads)]
  return revise(jax.jit(step), revs.get(), theta)

# Five (x, y) samples used to fit the quadratic model below.
quad_xs = jnp.array([-1.0, 0.0, 1.0, 2.0, 3.0])
quad_ys = jnp.array([2.55, 2.1, 4.35, 10.2, 18.25])

def quad(x):
  """Return the quadratic ``theta -> a*x**2 + b*x + c`` with theta = [a, b, c]."""
  def evaluate(theta):
    a, b, c = theta[0], theta[1], theta[2]
    return a * (x ** 2) + b * x + c
  return evaluate

# Fit the quadratic with 1000 revisions at alpha = 0.001, starting from all
# zeros, and check each coefficient to within 1%.
with hypers(revs, 1000), hypers(alpha, 0.001):
  big_theta = gradient_descent(l2_loss(quad)(quad_xs, quad_ys), [0.0, 0.0, 0.0])
  assert jnp.isclose(big_theta[0], 1.48, rtol=1e-2)
  assert jnp.isclose(big_theta[1], 0.99, rtol=1e-2)
  assert jnp.isclose(big_theta[2], 2.05, rtol=1e-2)

# Six 2-feature inputs (shape [6, 2]) and their six targets.
plane_xs = jnp.array([[1.0, 2.05], [1.0, 3.0], [2.0, 2.0], [2.0, 3.91], [3.0, 6.13], [4.0, 8.09]])
plane_ys = jnp.array([13.99, 15.99, 18.0, 22.4, 30.2, 37.94])

def plane(t):
  """Return the plane ``theta -> dot(t, w) + b`` with theta = [w, b]."""
  def evaluate(theta):
    weights, bias = theta[0], theta[1]
    return jnp.dot(t, weights) + bias
  return evaluate

# Full-batch fit of the plane: theta = [weights (2,), bias], checked to
# within 10%.
with hypers(revs, 1000), hypers(alpha, 0.001):
  big_theta = gradient_descent(l2_loss(plane)(plane_xs, plane_ys), [jnp.array([0.0, 0.0]), 0.0])
  assert jnp.isclose(big_theta[0][0], 3.98, rtol=1e-1)
  assert jnp.isclose(big_theta[0][1], 2.04, rtol=1e-1)
  assert jnp.isclose(big_theta[1], 5.78, rtol=1e-1)

In [None]:
# Chapter 6

# Number of examples drawn per revision by sampling_obj.
batch_size = declare_hypers('batch_size')

def samples(n, s, key):
  """Draw *s* indices from range(n) using *key* (jax.random.choice defaults: uniform, with replacement)."""
  population = jnp.arange(0, n)
  return jax.random.choice(key, population, [s])

def sampling_obj(expectant, xs, ys, key):
  """Wrap an expectant loss so that each revision scores a random mini-batch.

  The batch size comes from the ``batch_size`` hyperparameter. Batch indices
  are drawn with *key* folded with the revision number, so each revision
  sees a different but reproducible sample.
  """
  def obj(theta, rev):
    rev_key = jax.random.fold_in(key, rev)
    batch = samples(len(xs), batch_size.get(), rev_key)
    return expectant(xs[batch], ys[batch])(theta)
  return obj

# Stochastic gradient descent on the plane: 15000 revisions on random
# mini-batches of 4, with the sampling key fixed at 42 for reproducibility.
with hypers(revs, 15000), hypers(alpha, 0.001), hypers(batch_size, 4):
  big_theta = gradient_descent(
      sampling_obj(l2_loss(plane), plane_xs, plane_ys, jax.random.key(42)),
       [jnp.array([0.0, 0.0]), 0.0])
  assert jnp.isclose(big_theta[0][0], 3.98, rtol=1e-1)
  assert jnp.isclose(big_theta[0][1], 1.97, rtol=1e-1)
  assert jnp.isclose(big_theta[1], 6.16, rtol=1e-1)

In [None]:
# Chapter 7

def try_plane(a_gradient_decent):
  """Fit the plane model with *a_gradient_decent* and check the result.

  Runs 15000 revisions at alpha = 0.001 on mini-batches of 4, starting from
  theta = [[0, 0], 0]. Asserts each parameter is within rtol = 0.1 of
  ([3.98, 1.97], 6.16).
  """
  with hypers(revs, 15000), hypers(alpha, 0.001), hypers(batch_size, 4):
    objective = sampling_obj(l2_loss(plane), plane_xs, plane_ys, jax.random.key(42))
    big_theta = a_gradient_decent(objective, [jnp.array([0.0, 0.0]), 0.0])
    weights, bias = big_theta[0], big_theta[1]
    assert jnp.isclose(weights[0], 3.98, rtol=0.1)
    assert jnp.isclose(weights[1], 1.97, rtol=0.1)
    assert jnp.isclose(bias, 6.16, rtol=0.1)

def gradient_descent(inflate, deflate, update):
  """Build a gradient-descent optimizer from three per-parameter hooks.

  Args:
    inflate: maps a bare parameter p to its augmented form big_p (p plus any
      optimizer state).
    deflate: maps big_p back to the bare parameter p.
    update: maps (big_p, gradient) to the next big_p.

  Returns:
    A function ``(obj, theta) -> theta``. It runs ``revs`` jit-compiled
    revisions, taking gradients at the deflated parameters, and returns the
    deflated result.
  """
  def _gradient_descent(obj, theta):
    def step(big_theta, rev):
      bare = [deflate(big_p) for big_p in big_theta]
      grads = gradient_of(obj)(bare, rev)
      return [update(big_p, g) for big_p, g in zip(big_theta, grads)]
    inflated = [inflate(p) for p in theta]
    final = revise(jax.jit(step), revs.get(), inflated)
    return [deflate(big_p) for big_p in final]
  return _gradient_descent

def naked_i(p):
  """Inflate for plain gradient descent: the parameter carries no extra state."""
  return p

def naked_d(big_p):
  """Deflate for plain gradient descent: the augmented parameter is the parameter."""
  return big_p

def naked_u(big_p, g):
  """Plain gradient step ``p - alpha * g``, with alpha read from the hyperparameter."""
  return big_p - alpha.get() * g

# Plain gradient descent expressed through the inflate/deflate/update hooks.
naked_gradient_descent = gradient_descent(naked_i, naked_d, naked_u)

try_plane(naked_gradient_descent)

In [None]:
# Chapter 8

# Momentum / first-moment decay rate, read by velocity_u and adam_u.
mu = declare_hypers('mu')

def velocity_i(p):
  """Inflate p to [p, v], where v is a zero initial velocity shaped like p."""
  initial_velocity = 0.0 if is_scalar(p) else jnp.zeros(p.shape)
  return [p, initial_velocity]

def velocity_d(big_p):
  """Deflate [p, v] back to the bare parameter p."""
  parameter = big_p[0]
  return parameter

def velocity_u(big_p, g):
  """Momentum update on [p, v]: v' = mu * v - alpha * g, p' = p + v'."""
  p, velocity = big_p[0], big_p[1]
  new_velocity = mu.get() * velocity - alpha.get() * g
  return [p + new_velocity, new_velocity]

# Gradient descent with momentum; each parameter carries a velocity.
velocity_gradient_descent = gradient_descent(velocity_i, velocity_d, velocity_u)

def try_plane(a_gradient_decent, a_revs):
  """Fit the plane model with *a_gradient_decent* for *a_revs* revisions.

  Uses alpha = 0.001 and mini-batches of 4, starting from theta =
  [[0, 0], 0]. Asserts each parameter is within rtol = 0.1 of
  ([3.98, 1.97], 6.16).
  """
  with hypers(revs, a_revs), hypers(alpha, 0.001), hypers(batch_size, 4):
    objective = sampling_obj(
        l2_loss(plane), plane_xs, plane_ys, jax.random.key(42))
    big_theta = a_gradient_decent(objective, [jnp.array([0.0, 0.0]), 0.0])
    weights, bias = big_theta[0], big_theta[1]
    for got, want in ((weights[0], 3.98), (weights[1], 1.97), (bias, 6.16)):
      assert jnp.isclose(got, want, rtol=1e-1)

# Momentum (mu = 0.9) reaches the same tolerance in 5000 revisions,
# compared with 15000 for plain SGD.
with hypers(mu, 0.9):
  try_plane(velocity_gradient_descent, 5000)

In [None]:
# Interlude IV

def smooth(decay_rate, average, g):
  """Exponential moving average: keep *decay_rate* of *average*, blend in the rest from *g*."""
  kept = decay_rate * average
  blended = (1 - decay_rate) * g
  return kept + blended

# smooth works element-wise on arrays (checked loosely, rtol = 0.1).
assert jnp.isclose(smooth(0.9,
                          jnp.array([0.82, 2.9, 2.28]),
                          jnp.array([13.4, 18.2, 41.4])),
                   jnp.array([2.08, 4.43, 6.19]), rtol=1e-1).all()

In [None]:
# Chapter 9

# Decay rate for the running average of squared gradients (RMSProp / Adam).
beta = declare_hypers('beta')

# Added to sqrt(r) in the step-size denominator to avoid division by zero.
EPS = 1e-8

def rms_u(big_p, g):
  """RMSProp update on [p, r].

  r' = smooth(beta, r, g**2) and p' = p - alpha / (sqrt(r') + EPS) * g.
  """
  p, r = big_p[0], big_p[1]
  new_r = smooth(beta.get(), r, g ** 2)
  step_size = alpha.get() / (jnp.sqrt(new_r) + EPS)
  return [p - step_size * g, new_r]

def rms_i(p):
  """Inflate p to [p, r], where r is a zero running average of squared gradients."""
  accumulator = 0.0 if is_scalar(p) else jnp.zeros(p.shape)
  return [p, accumulator]

def rms_d(big_p):
  """Deflate [p, r] back to the bare parameter p."""
  parameter = big_p[0]
  return parameter

# RMSProp: the step size adapts to the running RMS of each parameter's gradient.
rms_gradient_descent = gradient_descent(rms_i, rms_d, rms_u)

def try_plane(a_gradient_decent, a_revs, a_alpha):
  """Fit the plane model with *a_gradient_decent* and check the result.

  Runs *a_revs* revisions at learning rate *a_alpha* on mini-batches of 4,
  starting from theta = [[0, 0], 0]. Asserts each parameter is within
  rtol = 0.1 of ([3.98, 1.97], 6.16).
  """
  with hypers(revs, a_revs), hypers(alpha, a_alpha), hypers(batch_size, 4):
    objective = sampling_obj(
        l2_loss(plane), plane_xs, plane_ys, jax.random.key(42))
    start = [jnp.array([0.0, 0.0]), 0.0]
    big_theta = a_gradient_decent(objective, start)
    expectations = [(big_theta[0][0], 3.98), (big_theta[0][1], 1.97),
                    (big_theta[1], 6.16)]
    for actual, expected in expectations:
      assert jnp.isclose(actual, expected, rtol=1e-1)

# RMSProp with beta = 0.9 tolerates a 10x larger alpha (0.01).
with hypers(beta, 0.9):
  try_plane(rms_gradient_descent, 3000, 0.01)

def adam_u(big_p, g):
  """Adam-style update on [p, v, r], with no bias correction.

  v' = smooth(mu, v, g), r' = smooth(beta, r, g**2), and
  p' = p - alpha / (sqrt(r') + EPS) * v'.
  """
  p, v, r = big_p[0], big_p[1], big_p[2]
  new_r = smooth(beta.get(), r, g ** 2)
  step_size = alpha.get() / (jnp.sqrt(new_r) + EPS)
  new_v = smooth(mu.get(), v, g)
  return [p - step_size * new_v, new_v, new_r]

def adam_i(p):
  """Inflate p to [p, v, r] with zero smoothed-gradient and squared-gradient accumulators."""
  if is_scalar(p):
    zero_v, zero_r = 0.0, 0.0
  else:
    zero_v, zero_r = jnp.zeros(p.shape), jnp.zeros(p.shape)
  return [p, zero_v, zero_r]

def adam_d(big_p):
  """Deflate [p, v, r] back to the bare parameter p."""
  parameter = big_p[0]
  return parameter

# Adam: combines RMSProp's adaptive step size with a smoothed (momentum-like)
# gradient. mu smooths gradients; beta smooths squared gradients.
adam_gradient_descent = gradient_descent(adam_i, adam_d, adam_u)

with hypers(mu, 0.85), hypers(beta, 0.9):
  # Exercise the Adam optimizer built above. mu is only read by adam_u, so
  # running rms_gradient_descent here would leave mu unused and Adam untested.
  try_plane(adam_gradient_descent, 1500, 0.01)

In [None]:
# Interlude V

import math

def tmap(f, *args):
  """Map *f* over the leading axis of every argument in lockstep, via jax.vmap."""
  vectorized = jax.vmap(f)
  return vectorized(*args)

def is_scalar(t):
  """Return True for Python numbers and for rank-0 arrays (shape ``()``)."""
  if isinstance(t, (int, float)):
    return True
  return t.shape == ()

def is_of_rank(n, t):
  """Return True if *t* has rank exactly *n*, probing first elements only."""
  while n != 0:
    if is_scalar(t):
      return False
    n, t = n - 1, t[0]
  return is_scalar(t)

def ext_1(f, n):
  """Extend a unary function on rank-*n* tensors to tensors of any higher rank.

  Rank-*n* arguments are passed straight to *f*. Higher ranks map the
  extended function over the leading axis with ``tmap``.
  """
  def extended(t):
    if is_of_rank(n, t):
      return f(t)
    return tmap(extended, t)
  return extended

sqrt_0 = jnp.sqrt

# Rank-polymorphic versions of scalar, rank-1 and rank-2 operations. Note
# that `sum` now replaces the Tensor-based sum from Interlude I.
sqrt = ext_1(sqrt_0, 0)
zeros = ext_1(lambda x: 0.0, 0)
sum = ext_1(sum_1, 1)
flatten = ext_1(jnp.ravel, 2)

def rank_gt(t, u):
  """Return True if rank(t) > rank(u), peeling one axis off each per step."""
  while not is_scalar(t):
    if is_scalar(u):
      return True
    t, u = t[0], u[0]
  return False

def is_of_ranks(n, t, m, u):
  """Return True if *t* has rank *n* and *u* has rank *m*."""
  return is_of_rank(n, t) and is_of_rank(m, u)

def desc_t(g, t, u):
  """Descend into *t*: apply ``g(element, u)`` to each leading element of t."""
  def apply_to_element(element):
    return g(element, u)
  return tmap(apply_to_element, t)

def desc_u(g, t, u):
  """Descend into *u*: apply ``g(t, element)`` to each leading element of u."""
  def apply_to_element(element):
    return g(t, element)
  return tmap(apply_to_element, u)

def desc(g, n, t, m, u):
  """Choose which argument(s) a binary extended function descends into.

  Descends into u if t already has rank n, and into t if u has rank m.
  Otherwise it descends into both in lockstep when their leading lengths
  match, into t when t has the higher rank, and into u in all other cases.
  """
  if is_of_rank(n, t):
    return desc_u(g, t, u)
  if is_of_rank(m, u):
    return desc_t(g, t, u)
  if len(t) == len(u):
    return tmap(g, t, u)
  return desc_t(g, t, u) if rank_gt(t, u) else desc_u(g, t, u)


def ext_2(f, n, m):
  """Extend a binary function on ranks (n, m) to arguments of higher rank."""
  def extended(t, u):
    if is_of_ranks(n, t, m, u):
      return f(t, u)
    return desc(extended, n, t, m, u)
  return extended

import operator

# Element-wise addition and multiplication extended to tensors of any rank.
add = ext_2(operator.add, 0, 0)
mul = ext_2(operator.mul, 0, 0)

def sqr(t):
  """Square *t* element-wise using the extended ``mul``."""
  product = mul(t, t)
  return product

# dot on rank-1 pairs; higher-rank arguments descend to rank-1 pieces.
dot = ext_2(jnp.dot, 1, 1)

assert jnp.array_equal(add(jnp.array([1, 2]), 3),
                       jnp.array([4, 5]))
assert jnp.array_equal(add(jnp.array([1, 2]), jnp.array([1, 2])),
                       jnp.array([2, 4]))
assert jnp.array_equal(sqr(jnp.array([1, 2])),
                       jnp.array([1, 4]))
assert jnp.array_equal(sqr(jnp.array([[1, 2], [3, 4]])),
                       jnp.array([[1, 4], [9, 16]]))
assert dot(jnp.array([1, 2]), jnp.array([1, 2])) == 5
assert jnp.array_equal(dot(jnp.array([[1, 2], [3, 4]]), jnp.array([1, 2])),
                       jnp.array([5, 11]))

# mul_2_1 keeps a rank-2 first argument whole and descends into the second
# argument until it reaches rank 1. With two rank-2 inputs, each row of the
# second therefore multiplies all of the first (see the last assert).
mul_2_1 = ext_2(mul, 2, 1)

assert jnp.array_equal(mul(jnp.array([[3, 4, 5], [7, 8, 9]]),
                           jnp.array([2, 4, 3])),
                       jnp.array([[6, 16, 15], [14, 32, 27]]))
assert jnp.array_equal(mul_2_1(jnp.array([[3, 4, 5], [7, 8, 9]]),
                               jnp.array([2, 4, 3])),
                       jnp.array([[6, 16, 15], [14, 32, 27]]))
assert jnp.array_equal(mul(jnp.array([[8, 1], [7, 3], [5, 4]]),
                           jnp.array([[6, 2], [4, 9], [3, 8]])),
                       jnp.array([[48, 2], [28, 27], [15, 32]]))
assert jnp.array_equal(mul_2_1(jnp.array([[8, 1], [7, 3], [5, 4]]),
                               jnp.array([[6, 2], [4, 9], [3, 8]])),
                       jnp.array([[[48, 2], [42, 6], [30, 8]],
                                  [[32, 9], [28, 27], [20, 36]],
                                  [[24, 8], [21, 24], [15, 32]]]))

In [None]:
# Chapter 10

import jax.lax as lax

def rectify_0(s):
  """Scalar ReLU: 0.0 for negative *s*, otherwise *s* unchanged (via lax.cond)."""
  def _zero(_):
    return 0.0

  def _identity(x):
    return x

  return lax.cond(s < 0.0, _zero, _identity, s)

# Element-wise ReLU for tensors of any rank.
rectify = ext_1(rectify_0, 0)

def linear_1_1(t):
  """Return ``theta -> theta[0].dot(t) + theta[1]``: one linear neuron with theta = [w, b]."""
  def evaluate(theta):
    weights, bias = theta[0], theta[1]
    return weights.dot(t) + bias
  return evaluate

def relu_1_1(t):
  """Return ``theta -> rectify(linear_1_1(t)(theta))``: a single ReLU neuron."""
  def neuron(theta):
    return rectify(linear_1_1(t)(theta))
  return neuron

# The pre-activation is 14.2 + 4.3 - 19.2 + 0.6 = -0.1 < 0, so ReLU yields 0.
assert relu_1_1(jnp.array([2.0, 1.0, 3.0]))([jnp.array([7.1, 4.3, -6.4]), 0.6]) == 0.0

In [None]:
# Chapter 11

def dot_2_1(w, t):
  """Dot each row of *w* with *t*.

  For a rank-2 w and a rank-1 t this is the matrix-vector product; a
  higher-rank t is handled by the extended mul_2_1 and sum.
  """
  products = mul_2_1(w, t)
  return sum(products)

def linear(t):
  """Return ``theta -> dot_2_1(theta[0], t) + theta[1]``: a dense layer with theta = [W, b]."""
  def layer(theta):
    return dot_2_1(theta[0], t) + theta[1]
  return layer

def relu(t):
  """Return ``theta -> rectify(linear(t)(theta))``: a dense ReLU layer."""
  def layer(theta):
    return rectify(linear(t)(theta))
  return layer

def k_relu(k):
  """Return a k-layer ReLU network, curried as ``t -> theta -> output``.

  Each layer consumes the first two entries of theta, and the remaining
  layers see theta[2:]. With k == 0 the input is returned unchanged.
  """
  def network(t):
    def apply(theta):
      if k == 0:
        return t
      hidden = relu(t)(theta)
      return k_relu(k - 1)(hidden)(theta[2:])
    return apply
  return network

In [None]:
# Chapter 12

def block(fn, shape_list):
  """Bundle a layer function with the list of its parameter shapes."""
  bundle = [fn, shape_list]
  return bundle

def block_fn(ba):
  """Return the layer function stored in block *ba*."""
  layer_fn = ba[0]
  return layer_fn

def block_ls(ba):
  """Return the list of parameter shapes stored in block *ba*."""
  shapes = ba[1]
  return shapes

def dense_block(n, m):
  """Dense ReLU block from n inputs to m outputs: weights [m, n] and bias [m]."""
  weight_shape = [m, n]
  bias_shape = [m]
  return block(relu, [weight_shape, bias_shape])

# Three dense layers with matching widths: 32 -> 64 -> 45 -> 26.
layer1 = dense_block(32, 64)
layer2 = dense_block(64, 45)
layer3 = dense_block(45, 26)

def block_compose(f, g, j):
  """Compose layer functions: apply f, then g.

  f receives the full theta (it reads its own leading parameters). g
  receives theta[j:], i.e. the parameters after f's first j entries.
  """
  def composed(t):
    def apply(theta):
      intermediate = f(t)(theta)
      return g(intermediate)(theta[j:])
    return apply
  return composed

def stack_2(ba, bb):
  """Stack block *ba* before block *bb*; bb's parameters follow ba's in theta."""
  first_shapes = block_ls(ba)
  composed = block_compose(block_fn(ba), block_fn(bb), len(first_shapes))
  return block(composed, first_shapes + block_ls(bb))

def stacked_blocks(rbls, ba):
  """Stack the remaining blocks *rbls* onto *ba*, left to right."""
  result = ba
  for next_block in rbls:
    result = stack_2(result, next_block)
  return result

def stack_blocks(bls):
  """Stack a non-empty list of blocks into a single block."""
  first, rest = bls[0], bls[1:]
  return stacked_blocks(rest, first)

# A single block whose shape list concatenates all three layers' shapes.
three_layer_network = stack_blocks([layer1, layer2, layer3])

In [None]:
# Chapter 13

from sklearn import datasets
iris = datasets.load_iris()

# Split each class 45 train / 5 test. This assumes iris['data'] is ordered by
# class in consecutive blocks of 50 rows (the slices rely on it) — confirm.
iris_train_xs = jnp.concat(
    [iris['data'][0:45],  iris['data'][50:95],  iris['data'][100:145]])
iris_test_xs = jnp.concat(
    [iris['data'][45:50], iris['data'][95:100], iris['data'][145:150]])
# Labels are one-hot encoded over the 3 classes.
iris_train_ys = jax.nn.one_hot(jnp.concat(
    [iris['target'][0:45],  iris['target'][50:95],  iris['target'][100:145]]), 3)
iris_test_ys = jax.nn.one_hot(jnp.concat(
    [iris['target'][45:50], iris['target'][95:100], iris['target'][145:150]]), 3)

# 4 input features -> 8 hidden ReLU units -> 3 outputs.
iris_network = stack_blocks([dense_block(4, 8), dense_block(8, 3)])

def random_tensor(c, v, s, key):
   return c + jax.random.normal(key, shape=s) * sqrt(v)

def init_shape(s, key):
  """Initialize one parameter tensor of shape *s*.

  Rank-1 shapes (biases) are zero-filled; *key* is unused for them.
  Higher-rank shapes (weights, filter banks) are drawn from a zero-mean
  normal with variance ``2 / fan_in``, where fan_in is the product of all
  dimensions after the first (He initialization).

  Raises:
    ValueError: if *s* is empty. Previously this returned None silently.
  """
  if len(s) == 1:
    return jnp.zeros(s)
  if len(s) >= 2:
    # For [m, n] this is n; for [b, m, d] it is m * d.
    fan_in = math.prod(s[1:])
    return random_tensor(0.0, 2.0 / fan_in, s, key)
  raise ValueError(f"cannot initialize a parameter of shape {s!r}")

def init_theta(shapes, key):
  """Initialize one parameter per entry of *shapes*, each with its own subkey."""
  _, *subkeys = jax.random.split(key, num=len(shapes) + 1)
  return [init_shape(s, k) for s, k in zip(shapes, subkeys)]

# The classifier function and its parameter shapes. For dense_block(4, 8)
# followed by dense_block(8, 3), the shapes are [8, 4], [8], [3, 8], [3].
iris_classifer = block_fn(iris_network)
iris_theta_shapes = block_ls(iris_network)

# Fixed initial weights (instead of init_theta) so results match the
# reference implementation linked below.
# From https://github.com/themetaschemer/malt/blob/4fee9a6b70146058bf253dbadaae1eff3681ccbe/examples/iris.rkt#L94
iris_initial_theta = [
    jnp.array(
        [[0.4567374693020529, 0.19828623224159106, -0.1791656741530271, -0.3010909419105787],
        [-0.6085978529055036, -0.37813256632159414, 0.6525919461799214, -0.02736258427588277],
        [-0.15910077091878255, 0.30935100240945007, -0.43223348220649294, 0.44424201464211593],
        [ 0.29780171646282, 0.27115067507001933, 0.3512802108530173, -0.941133353767241],
        [-0.6435366194048697, -0.7870457121505098, 0.4672028162559846, -0.4060316748060222],
        [ 0.3542366127804169, -0.6294805381631496, 1.2119983516222874, -0.48964923866459675],
        [ 0.29072501026246134, -0.11992778583131615, 0.2716865689059567, 0.5051197463327993],
        [-0.05677192201680251, -0.8933344786252218, 0.10639004770659627, -0.7276129460870265]]),
    jnp.array([0., 0., 0., 0., 0., 0., 0., 0.]),
    jnp.array(
        [[0.8360463658942785, 0.21163937440648464, -0.36559830767572854, 0.34006155051045595,
           0.3095265146359776, -0.1585941540367561, 0.33268716624682165, -0.5114119488395097],
        [0.15466255181586858, -0.26658077790718954, 0.04571706722376748, 0.10422918798466209,
           -0.17593682447129064, 0.6075530713389936, 0.007216798991190192, -0.4698148147112468],
        [0.06636510408180833, -0.11501406598247131, 0.7855953481117244, 0.00849992094421447,
           0.10415852852056427, 0.4557511137599346, -0.029003952783791656, 1.1873084795704665]]),
    jnp.array([0., 0., 0.])]

# Train with plain SGD: 2000 revisions, alpha = 0.0002, mini-batches of 8.
with hypers(revs, 2000), hypers(alpha, 0.0002), hypers(batch_size, 8):
  iris_theta = naked_gradient_descent(
      sampling_obj(l2_loss(iris_classifer), iris_train_xs, iris_train_ys, jax.random.key(42)),
      iris_initial_theta)

def model(target, theta):
  """Fix the parameters of *target*, returning a function of the input alone."""
  def predict(t):
    return target(t)(theta)
  return predict

# The trained iris classifier as a function of the input features only.
iris_model = model(iris_classifer, iris_theta)

In [None]:
# Interlude VI
import itertools
import contextlib

def next_a(t, i, a):
  """Return i if t[i] strictly beats the current best t[a]; otherwise keep a."""
  candidate_wins = t[i] > t[a]
  return jnp.where(candidate_wins, i, a)

def argmaxed(t, i, a):
  """Scan indices i, i-1, ..., 0 of *t* and return the index of the maximum.

  *a* is the best index found so far. next_a only switches on a strict
  improvement, so among equal maxima the first index scanned (the highest)
  is kept.
  """
  best = next_a(t, i, a)
  while i > 0:
    i -= 1
    best = next_a(t, i, best)
  return best

def argmax_1(t):
  """Index of the largest element of rank-1 *t*; the highest index wins on ties."""
  last = len(t) - 1
  return argmaxed(t, last, last)

def class_eq_1(t, u):
  """1.0 if rank-1 *t* and *u* have the same argmax, otherwise 0.0."""
  same_class = argmax_1(t) == argmax_1(u)
  return jnp.where(same_class, 1.0, 0.0)

# Row-wise class comparison for batches of predictions and labels.
class_eq = ext_2(class_eq_1, 1, 1)

def accuracy(a_model, xs, ys):
  """Fraction of rows of *xs* whose predicted argmax matches the argmax of *ys*."""
  matches = class_eq(a_model(xs), ys)
  return sum(matches) / len(xs)

# All 15 held-out iris samples are expected to be classified correctly.
assert accuracy(iris_model, iris_test_xs, iris_test_ys) == 1.0

def grid_search(hyper_vals, f):
  """Call *f* once for every combination of hyperparameter values.

  *hyper_vals* is a sequence of (hyperparameter, candidate_values) pairs.
  Each call to *f* runs with every hyperparameter bound via ``hypers``, and
  the bindings are undone afterwards.
  """
  params = [entry[0] for entry in hyper_vals]
  candidates = [entry[1] for entry in hyper_vals]
  for combination in itertools.product(*candidates):
    with contextlib.ExitStack() as stack:
      for param, value in zip(params, combination):
        stack.enter_context(hypers(param, value))
      f()

def find_iris_theta():
  """Train the iris classifier under the current hyperparameters and print its test accuracy."""
  objective = sampling_obj(l2_loss(iris_classifer), iris_train_xs, iris_train_ys, jax.random.key(42))
  trained_theta = naked_gradient_descent(objective, iris_initial_theta)
  trained_model = model(iris_classifer, trained_theta)
  print(f"rev: {revs.get()}, alpha: {alpha.get()}, batch_size: {batch_size.get()}, accuracy:",
        accuracy(trained_model, iris_test_xs, iris_test_ys))

# 4 x 3 x 3 = 36 training runs, each printing its test accuracy.
grid_search([(revs,       [500, 1000, 2000, 4000]),
             (alpha,      [0.0001, 0.0002, 0.0005]),
             (batch_size, [4, 8, 16])],
            find_iris_theta)

rev: 500, alpha: 0.0001, batch_size: 4, accuracy: 0.6666667
rev: 500, alpha: 0.0001, batch_size: 8, accuracy: 0.6666667
rev: 500, alpha: 0.0001, batch_size: 16, accuracy: 0.8
rev: 500, alpha: 0.0002, batch_size: 4, accuracy: 0.8
rev: 500, alpha: 0.0002, batch_size: 8, accuracy: 0.8
rev: 500, alpha: 0.0002, batch_size: 16, accuracy: 0.93333334
rev: 500, alpha: 0.0005, batch_size: 4, accuracy: 1.0
rev: 500, alpha: 0.0005, batch_size: 8, accuracy: 1.0
rev: 500, alpha: 0.0005, batch_size: 16, accuracy: 1.0
rev: 1000, alpha: 0.0001, batch_size: 4, accuracy: 0.93333334
rev: 1000, alpha: 0.0001, batch_size: 8, accuracy: 0.93333334
rev: 1000, alpha: 0.0001, batch_size: 16, accuracy: 1.0
rev: 1000, alpha: 0.0002, batch_size: 4, accuracy: 1.0
rev: 1000, alpha: 0.0002, batch_size: 8, accuracy: 1.0
rev: 1000, alpha: 0.0002, batch_size: 16, accuracy: 1.0
rev: 1000, alpha: 0.0005, batch_size: 4, accuracy: 1.0
rev: 1000, alpha: 0.0005, batch_size: 8, accuracy: 1.0
rev: 1000, alpha: 0.0005, batch_size

In [None]:
# Chapter 14

import functools
import numpy as np

def load_data(path, shape):
  """Load a whitespace-delimited numeric text file and reshape it to *shape*."""
  raw = np.loadtxt(path)
  return jnp.array(raw).reshape(*shape)

!git clone https://github.com/zifeitong/little_learner &> /dev/null

# Signals are shaped [examples, 16 positions, depth 1]; labels are 26-wide.
# Subtracting 0.5 presumably centers raw values in [0, 1] around zero —
# TODO confirm the data range.
morse_train_xs = load_data('little_learner/data/morse-train-xs', [5200, 16, 1]) - 0.5
morse_train_ys = load_data('little_learner/data/morse-train-ys', [5200, 26])
morse_test_xs = load_data('little_learner/data/morse-test-xs', [1040, 16, 1]) - 0.5
morse_test_ys = load_data('little_learner/data/morse-test-ys', [1040, 26])

def correlate_2_2(filter, signal):
  """Correlate one filter with a signal, producing one value per signal position.

  *filter* is [m, d] and *signal* is [n, d] with matching depth d; the
  result has shape [n]. Each output is the dot product of the filter with
  the zero-padded window of the signal centered on that position.
  """
  # Import explicitly: `import jax` alone does not guarantee that the
  # jax.scipy.signal submodule is loaded.
  import jax.scipy.signal
  output = jax.scipy.signal.correlate2d(signal, filter, mode="same")
  # With matching depths, this column of the "same"-mode output is the one
  # where filter and signal fully overlap along the depth axis.
  return output[:, (output.shape[1] - 1) // 2]

def correlate_3_2(filter_banks, signal):
  """Correlate every filter in *filter_banks* with *signal*.

  Returns one column per filter, i.e. shape [n, b] for b filter banks and
  an n-position signal.
  """
  per_filter = ext_2(correlate_2_2, 2, 2)(filter_banks, signal)
  return per_filter.transpose()

# correlate_3_2 extended to batches of signals.
correlate = ext_2(correlate_3_2, 3, 2)


# A 6-position, depth-2 signal and 4 filter banks, each [3, 2].
signal = jnp.array([
    [1., 2.], [3., 4.], [5., 6.], [7., 8.], [9., 10.], [11., 12.]
])

filter_banks = jnp.array([
    [[1., 2.], [3., 4.], [5., 6.]],
    [[7., 8.], [9., 10.], [11., 12.]],
    [[13., 14.], [15., 16.], [17., 18.]],
    [[19., 20.], [21., 22.], [23., 24.]],
])

# Result is [6 positions, 4 filters]. The first and last rows show the
# zero padding, e.g. 50 = 3*1 + 4*2 + 5*3 + 6*4 for filter 0 at position 0.
assert jnp.array_equal(
    correlate_3_2(filter_banks, signal),
    jnp.array([
       [  50.,  110.,  170.,  230.],
       [  91.,  217.,  343.,  469.],
       [ 133.,  331.,  529.,  727.],
       [ 175.,  445.,  715.,  985.],
       [ 217.,  559.,  901., 1243.],
       [ 110.,  362.,  614.,  866.]]))

In [None]:
# Chapter 15

def corr(t):
  """Return ``theta -> correlate(theta[0], t) + theta[1]`` with theta = [filters, biases]."""
  def layer(theta):
    return correlate(theta[0], t) + theta[1]
  return layer

def recu(t):
  """Return ``theta -> rectify(corr(t)(theta))``: a correlation layer with ReLU."""
  def layer(theta):
    return rectify(corr(t)(theta))
  return layer

def recu_block(b, m, d):
  """recu block with b filters of width m and depth d ([b, m, d]) plus b biases."""
  filter_shape = [b, m, d]
  return block(recu, [filter_shape, [b]])

# sum_1 applied to a rank-2 tensor adds its rows together, giving column sums.
sum_2 = sum_1

sum_cols = ext_1(sum_2, 2)

assert jnp.array_equal(sum_cols(jnp.array([[1, 2, 3], [4, 5, 6]])),
                       jnp.array([5, 7, 9]))

def signal_avg(t):
  """Return a parameter-free layer that averages *t* over its second-to-last axis."""
  def layer(unused_theta):
    total = sum_cols(t)
    signal_length = shape(t)[rank(t) - 2]
    return total / signal_length
  return layer

# The averaging layer has no parameters, so its shape list is empty.
signal_avg_block = block(signal_avg, [])

def fcn_block(b, m, d):
  """Two stacked recu layers: depth d -> b filters, then b -> b filters."""
  first = recu_block(b, m, d)
  second = recu_block(b, m, b)
  return stack_blocks([first, second])

# Fully convolutional network. Depth grows 1 -> 4 -> 8 -> 16 -> 26, and the
# final layer averages over signal positions to give 26 class scores.
morse_fcn = stack_blocks([
    fcn_block(4, 3, 1),
    fcn_block(8, 3, 4),
    fcn_block(16, 3, 8),
    fcn_block(26, 3, 16),
    signal_avg_block
])

def init_shape(s, key):
  """Initialize one parameter tensor of shape *s*.

  Rank-1 shapes (biases) are zero-filled; *key* is unused for them.
  Higher-rank shapes (weights, filter banks) are drawn from a zero-mean
  normal with variance ``2 / fan_in``, where fan_in is the product of all
  dimensions after the first (He initialization).

  Raises:
    ValueError: if *s* is empty. Previously this returned None silently.
  """
  if len(s) == 1:
    return jnp.zeros(s)
  if len(s) >= 2:
    # For [m, n] this is n; for [b, m, d] it is m * d.
    fan_in = math.prod(s[1:])
    return random_tensor(0.0, 2.0 / fan_in, s, key)
  raise ValueError(f"cannot initialize a parameter of shape {s!r}")

def trained_morse(classifier, theta_shapes):
  """Train *classifier* on the Morse training set with Adam and return the fitted model.

  The caller must already have bound alpha, revs, batch_size, mu and beta.
  """
  # NOTE(review): one PRNG key (42) seeds both mini-batch sampling and
  # parameter initialization; left unchanged so results stay reproducible.
  objective = sampling_obj(l2_loss(classifier), morse_train_xs, morse_train_ys, jax.random.key(42))
  initial_theta = init_theta(theta_shapes, jax.random.key(42))
  return model(classifier, adam_gradient_descent(objective, initial_theta))

def train_morse(network):
  """Train a Morse network block with fixed Adam hyperparameters and return the model."""
  with hypers(alpha, 0.0005), hypers(revs, 20000), hypers(batch_size, 8):
    with hypers(mu, 0.9), hypers(beta, 0.999):
      return trained_morse(block_fn(network), block_ls(network))

# Train the fully convolutional network and report held-out accuracy.
print("morse_fcn: accuracy", accuracy(train_morse(morse_fcn), morse_test_xs, morse_test_ys))

def skip(f, j):
  """Add a skip connection around layer *f*: its output plus ``correlate(theta[j], t)``."""
  def wrapped(t):
    def apply(theta):
      return f(t)(theta) + correlate(theta[j], t)
    return apply
  return wrapped

def skip_block(ba, d, b):
  """Wrap block *ba* with a skip connection that uses an extra [b, 1, d] filter bank."""
  inner_shapes = block_ls(ba)
  skip_fn = skip(block_fn(ba), len(inner_shapes))
  return block(skip_fn, inner_shapes + [[b, 1, d]])

def residual_block(b, m, d):
  """fcn_block(b, m, d) plus a skip connection from its depth-d input."""
  inner = fcn_block(b, m, d)
  return skip_block(inner, d, b)

# Same depth progression as morse_fcn, but each stage has a skip connection.
morse_residual = stack_blocks([
    residual_block(4, 3, 1),
    residual_block(8, 3, 4),
    residual_block(16, 3, 8),
    residual_block(26, 3, 16),
    signal_avg_block
])

print("morse_residual: accuracy", accuracy(train_morse(morse_residual), morse_test_xs, morse_test_ys))

morse_fcn: accuracy 0.94134617
morse_residual: accuracy 0.9634615
