In [1]:
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax
from jax.experimental import stax as ostax
from jax import random
import tensorflow_datasets as tfds
import tensorflow as tf
import functools
from jax.api import jit, grad, vmap

In [2]:
# Fixed-seed PRNG key; a later cell passes it to the network initialiser.
key = random.PRNGKey(seed=10)



In [3]:
# Load the MNIST train and test splits from TensorFlow Datasets.
train_dataset = tfds.load(name="mnist", split=tfds.Split.TRAIN)
test_dataset = tfds.load(name="mnist", split=tfds.Split.TEST)
# Build your input pipeline
# Shuffle with a fixed seed and a frozen order: the subsets drawn with
# `.take(...)` further down are then identical after every kernel restart,
# instead of changing from run to run.
SHUFFLE_SEED = 10
train_dataset = train_dataset.shuffle(
    1024, seed=SHUFFLE_SEED, reshuffle_each_iteration=False)
test_dataset = test_dataset.shuffle(
    1024, seed=SHUFFLE_SEED, reshuffle_each_iteration=False)



In [4]:
# Materialise 1000 training and 1000 test examples as flat arrays.
# Images are scaled by 1/255 and flattened to one vector per example.
# Training labels are stored as one-hot rows, whereas test labels stay as
# integer class ids (the form that acc_metric compares against).

# One-hot lookup table, built once instead of once per loop iteration.
one_hot = np.eye(10)

x_train = []
y_train = []
for example in train_dataset.take(1000):
    x_train.append(np.reshape(example['image'].numpy() / 255, (-1,)))
    y_train.append(one_hot[example['label'].numpy()])

x_test = []
y_test = []
for example in test_dataset.take(1000):
    x_test.append(np.reshape(example['image'].numpy() / 255, (-1,)))
    y_test.append(example['label'].numpy())

# Stack the per-example lists into single arrays.
x_train = np.array(x_train)
x_test = np.array(x_test)
y_train = np.array(y_train)
y_test = np.array(y_test)

In [89]:
# Two one-hidden-layer networks (512 hidden units, 10 outputs) that differ
# only in the nonlinearity. Each `stax.serial` call is unpacked into an
# (init, apply, kernel) triple.
init_relu, apply_relu, kernel_relu = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(10)
)

init_erf, apply_erf, kernel_erf = stax.serial(
    stax.Dense(512), stax.Erf(),
    stax.Dense(10)
)

# Other cells in this notebook call `init_fn`, `apply_fn` and `kernel_fn`,
# but no cell binds those names, so a fresh "Restart & Run All" fails with
# NameError. Bind them explicitly here.
# NOTE(review): the erf network is chosen because the GP-inference cells use
# `kernel_erf`; switch to the relu triple if that was the intended network.
init_fn, apply_fn, kernel_fn = init_erf, apply_erf, kernel_erf

In [6]:
# Initialise the network parameters from the PRNG key for input shape
# (-1, 784); the first returned value is discarded.
# NOTE(review): `init_fn` must already be bound when this cell runs — confirm
# which network it refers to before relying on `params`.
_, params = init_fn(key, (-1, 784))

In [90]:
# Run nt.predict.gp_inference with the erf kernel in 'ntk' mode and unpack
# the result into a predictive mean and a covariance for the test inputs.
ntk_mean, ntk_covariance = nt.predict.gp_inference(
    kernel_erf,
    x_train,
    y_train,
    x_test,
    get='ntk',
    diag_reg=1e-4,
    compute_cov=True,
)




In [91]:
def acc_metric(means, y):
    """Fraction of rows of `means` whose arg-max column equals the label.

    Args:
        means: 2-D array of per-class scores, one row per example.
        y: integer class labels, one per example.

    Returns:
        Number of matching rows divided by ``len(y)``.
    """
    predicted = np.argmax(means, axis=1)
    n_correct = np.sum(predicted == y)
    return n_correct / len(y)

def test_func(func):
    means, covs = func()

In [92]:
# Sanity check: display the shape of the NTK predictive mean.
ntk_mean.shape

(1000, 10)

In [93]:
# Test accuracy of the NTK predictive mean: arg-max over axis 1 compared
# with the integer test labels.
acc_metric(ntk_mean, y_test)

DeviceArray(0.883, dtype=float32)

# Losses

In [62]:
def loss_l2(predict_fn, ys, t):
    """Half the mean of ``ys**2 - 2*m*ys + v + m**2`` at time `t`.

    Args:
        predict_fn: callable; ``predict_fn(t)`` must return a ``(mean, var)``
            pair where `mean` is 2-D and `var` is a square matrix.
        ys: targets; flattened to 1-D before use.
        t: value forwarded to `predict_fn`.

    Returns:
        Scalar loss, where ``m`` is the arg-max of `mean` along axis 1 and
        ``v`` is the diagonal of `var`.

    NOTE(review): ``m`` is a class index rather than the raw predictive
    mean — confirm that this is intended.
    """
    pred_mean, pred_cov = predict_fn(t)
    labels = np.argmax(pred_mean, axis=1)
    variances = np.diag(pred_cov)
    targets = np.reshape(ys, (-1,))
    per_example = targets ** 2 - 2 * labels * targets + variances + labels ** 2
    return 0.5 * np.mean(per_example)

def loss_logcosh(predict_fn, ys, t):
    """Mean of ``ys**2 - 2*m*ys + v + m**2`` at time `t`.

    Args:
        predict_fn: callable; ``predict_fn(t)`` must return a ``(mean, var)``
            pair where `var` is a square matrix.
        ys: targets; flattened to 1-D before use.
        t: value forwarded to `predict_fn`.

    Returns:
        Scalar loss, where ``m`` is `mean` flattened to 1-D and ``v`` is the
        diagonal of `var`.

    NOTE(review): despite the name, no log-cosh is evaluated here; the
    expression is the same quadratic form as in `loss_l2`, without the 0.5
    factor and without the arg-max. Confirm whether that is intended.
    """
    pred_mean, pred_cov = predict_fn(t)
    flat_mean = np.reshape(pred_mean, (-1,))
    variances = np.diag(pred_cov)
    targets = np.reshape(ys, (-1,))

    per_example = targets ** 2 - 2 * flat_mean * targets + variances + flat_mean ** 2
    return np.mean(per_example)

def cross_entropy(fx, y_hat):
    """Negated mean over all entries of ``logsoftmax(fx) * y_hat``."""
    weighted_log_probs = ostax.logsoftmax(fx) * y_hat
    return -np.mean(weighted_log_probs)


def l2(fx, y_hat):
    """Half the mean squared difference between ``softmax(fx)`` and `y_hat`."""
    residual = ostax.softmax(fx) - y_hat
    return 0.5 * np.mean(residual ** 2)


def log_cosh(fx, y_hat):
    """Mean of ``log(cosh(softmax(fx) - y_hat))`` over all entries."""
    residual = ostax.softmax(fx) - y_hat
    return np.mean(np.log(np.cosh(residual)))

# MSE

In [83]:
# Build a predictor with nt.predict.gradient_descent_mse_gp from the erf
# kernel, the training set and the test inputs, in 'ntk' mode with
# compute_cov=True.
# NOTE(review): the positional 1e-4 is presumably the diagonal regulariser
# (it matches diag_reg=1e-4 in the gp_inference cell) — confirm against the
# installed neural_tangents version.
test_predict_fn = nt.predict.gradient_descent_mse_gp(
    kernel_erf, x_train, y_train, x_test, 'ntk', 1e-4, compute_cov=True)

# Bind the predictor and the test labels so that only the time `t` remains
# as an argument of loss_l2.
test_loss_fn = functools.partial(loss_l2, test_predict_fn, y_test)

In [78]:
# Training accuracy of softmax(train_fin) against the integer labels
# recovered from the one-hot y_train.
# NOTE(review): `train_fin` is not assigned in any visible cell — it must
# come from earlier kernel state; confirm before re-running.
acc_metric(ostax.softmax(train_fin), np.argmax(y_train, axis=1))

DeviceArray(0.195, dtype=float32)

In [79]:
def loss_calc(loss_func, x_train, y_train, ts=None):
    """Evaluate `loss_func` on the training set along a grid of times.

    A predictor is built with ``nt.predict.gradient_descent`` from the NTK of
    `x_train` with itself, `y_train` and `loss_func`. For every time in `ts`
    the predictor is applied to the network's initial training outputs, the
    training accuracy is printed and the loss is recorded.

    Relies on the notebook-level names `kernel_fn`, `apply_fn`, `params` and
    `acc_metric`.

    Args:
        loss_func: callable ``loss_func(fx, y)`` returning a scalar loss.
        x_train: training inputs.
        y_train: one-hot training targets (arg-max over axis 1 gives the
            labels used for the printed accuracy).
        ts: optional iterable of times to evaluate. Defaults to
            ``np.arange(0, 1000000000, 10000000)``, the grid that used to be
            hard-coded here.

    Returns:
        List with one loss value per entry of `ts`.
    """
    if ts is None:
        ts = np.arange(0, 1000000000, 10000000)

    predict_fn = nt.predict.gradient_descent(
        kernel_fn(x_train, x_train).ntk, y_train, loss_func)

    # Both values are independent of `t`, so compute them once rather than on
    # every iteration of the loop.
    fx_init = apply_fn(params, x_train)
    labels = np.argmax(y_train, axis=1)

    losses = []
    for t in ts:
        fx = predict_fn(t, fx_init)
        print(acc_metric(fx, labels))
        losses.append(loss_func(fx, y_train))

    return losses

In [56]:
# Training-set loss per time step under the cross-entropy loss.
# NOTE(review): the recorded output of this cell ends in KeyboardInterrupt,
# so the assignment did not complete in that run.
cross_loss = loss_calc(cross_entropy, x_train, y_train)

0.118
0.92


KeyboardInterrupt: 

In [80]:
# Training-set loss per time step under the softmax-L2 loss.
# NOTE(review): the recorded output of this cell ends in KeyboardInterrupt,
# so the assignment did not complete in that run.
l2_loss = loss_calc(l2, x_train, y_train)

  self.messages.get(istate, unexpected_istate_msg)))


0.118
0.272
0.187
0.107
0.107
0.107
0.107


KeyboardInterrupt: 

In [65]:
# Training-set loss per time step under the log-cosh loss; loss_calc prints
# the training accuracy at every step.
log_cosh_loss = loss_calc(log_cosh, x_train, y_train)

0.118
0.119
0.119
0.119
0.119
0.119
0.121
0.121
0.122
0.124


In [81]:
# Display the recorded cross-entropy losses.
# NOTE(review): this prints the whole list; a slice such as cross_loss[:5]
# would keep the output short.
cross_loss

[DeviceArray(0.22911297, dtype=float32),
 DeviceArray(0.22833882, dtype=float32),
 DeviceArray(0.2275747, dtype=float32),
 DeviceArray(0.22681998, dtype=float32),
 DeviceArray(0.22607405, dtype=float32),
 DeviceArray(0.22533622, dtype=float32),
 DeviceArray(0.22460659, dtype=float32),
 DeviceArray(0.2238845, dtype=float32),
 DeviceArray(0.22316962, dtype=float32),
 DeviceArray(0.2224616, dtype=float32),
 DeviceArray(0.22176027, dtype=float32),
 DeviceArray(0.22106552, dtype=float32),
 DeviceArray(0.22037691, dtype=float32),
 DeviceArray(0.2196936, dtype=float32),
 DeviceArray(0.21901648, dtype=float32),
 DeviceArray(0.2183444, dtype=float32),
 DeviceArray(0.21767794, dtype=float32),
 DeviceArray(0.21701613, dtype=float32),
 DeviceArray(0.2163593, dtype=float32),
 DeviceArray(0.21570711, dtype=float32),
 DeviceArray(0.21505977, dtype=float32),
 DeviceArray(0.21441643, dtype=float32),
 DeviceArray(0.21377754, dtype=float32),
 DeviceArray(0.21314317, dtype=float32),
 DeviceArray(0.2125126

In [137]:
# Analytic MSE gradient-descent predictor built from NTK blocks.
# `gradient_descent_mse` takes the train-train kernel first and the
# test-train kernel third, so the cross block must be evaluated as
# kernel_fn(x_test, x_train). The previous kernel_fn(x_train, x_test) is its
# transpose, which went unnoticed only because both sets hold 1000 examples.
# The trailing 1e-4 is passed positionally, as before.
test_predict_fn = nt.predict.gradient_descent_mse(
    kernel_fn(x_train, x_train).ntk, y_train, kernel_fn(x_test, x_train).ntk, 1e-4)




In [138]:
# Bind the predictor and the test labels so that only `t` remains as an
# argument of loss_l2.
# NOTE(review): loss_l2 calls predict_fn(t) with a single argument, while
# this notebook calls test_predict_fn as (t, f_train_init, f_test_init) —
# calling test_loss_fn(t) as written presumably fails; verify.
test_loss_fn = functools.partial(loss_l2, test_predict_fn, y_test)

In [139]:
# Network outputs for the current `params` on the training and test inputs.
f_train_init = apply_fn(params, x_train)
f_test_init = apply_fn(params, x_test)

In [140]:
# Accuracy at t = 1000000000: element [0] of the predictor output is scored
# against the labels recovered from the one-hot y_train.
# presumably element [0] holds the training-set predictions — verify against
# the neural_tangents API.
acc_metric(test_predict_fn(1000000000, f_train_init, f_test_init)[0], np.argmax(y_train, axis=1))

DeviceArray(1., dtype=float32)

In [141]:
# Accuracy of element [1] of the predictor output against the integer test
# labels.
# NOTE(review): this is evaluated at t = 1, whereas the training-accuracy
# cell uses t = 1000000000 — confirm that the different time is intentional.
acc_metric(test_predict_fn(1, f_train_init, f_test_init)[1], y_test)

DeviceArray(0.131, dtype=float32)

In [73]:
# Display the network's initial outputs on the training inputs.
f_train_init

DeviceArray([[ 0.10801702,  0.0298161 ,  0.04890977, ..., -0.06429182,
              -0.13902469,  0.13385423],
             [ 0.11882064, -0.04457122, -0.00194839, ...,  0.02325084,
              -0.13691665,  0.23880939],
             [ 0.06833835,  0.15590656,  0.06496935, ...,  0.07997879,
              -0.12056512,  0.26858053],
             ...,
             [ 0.01049394, -0.13339698,  0.05737164, ..., -0.10291094,
              -0.02151006,  0.22904924],
             [-0.19749318,  0.07150874,  0.04004507, ..., -0.01229034,
              -0.04406513,  0.26738104],
             [ 0.13985181, -0.10131788,  0.01434625, ..., -0.04564786,
              -0.17408863,  0.11591386]], dtype=float32)