In [1]:
# Install a CUDA-enabled JAX build from the JAX release index.
# NOTE(review): versions are unpinned, so a later re-run may pick up different
# jax/flax/neural-tangents releases; consider pinning for reproducibility.
!pip install -U -q "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [2]:
# Flax is used to define the model; neural-tangents computes the empirical NTK.
!pip install -U -q flax neural-tangents

In [3]:
import os

import flax.linen as nn
import jax
import jax.numpy as jnp
import neural_tangents as nt
import numpy as np
import tensorflow_datasets as tfds
from flax.core.frozen_dict import freeze
from flax.serialization import to_state_dict
from scipy.sparse.linalg import eigsh

In [4]:
class Lenet300_100(nn.Module):
    """LeNet-300-100 style MLP with a single scalar output.

    All non-batch dimensions are flattened, followed by two ReLU hidden layers
    (300 then 100 units) and a final linear layer with one unit.
    """

    @nn.compact
    def __call__(self, x):
        # Collapse everything after the batch axis into one feature vector.
        h = x.reshape((x.shape[0], -1))
        # Hidden layers, created in the same sequence on every call.
        for width in (300, 100):
            h = nn.relu(nn.Dense(width)(h))
        # Linear read-out, no activation.
        return nn.Dense(1)(h)

In [5]:
def normalize_data(data, mean=None, std=None):
    """Standardize ``data`` per channel, where channels are the last axis.

    Statistics are reduced over every axis except the last one and keep
    singleton dimensions, so they broadcast back against ``data``. For a 4-D
    NHWC batch, the statistics have shape ``(1, 1, 1, C)``.

    Args:
        data: Array whose last axis indexes channels/features.
        mean: Optional precomputed per-channel mean (e.g. from the training
            split). Computed from ``data`` when ``None``.
        std: Optional precomputed per-channel standard deviation. Computed
            from ``data`` when ``None``.

    Returns:
        ``(normalized_data, mean, std)``. The returned statistics are the ones
        actually used, so they can be reapplied to another split.
    """
    reduce_axes = tuple(range(jnp.ndim(data) - 1))
    # Compute each statistic independently so a caller-supplied value is never
    # silently discarded just because the other one is missing.
    if mean is None:
        mean = jnp.mean(data, axis=reduce_axes, keepdims=True)
    if std is None:
        std = jnp.std(data, axis=reduce_axes, keepdims=True)

    return (data - mean) / std, mean, std
    
def _load_split(ds_builder, split):
    """Load one split fully into memory as ``(float32 inputs / 255, labels)``."""
    inputs, labels = tfds.as_numpy(
        ds_builder.as_dataset(split=split, batch_size=-1, as_supervised=True, shuffle_files=False)
    )
    # Dividing by 255 assumes 8-bit pixel values — TODO confirm for non-image datasets.
    return jnp.float32(inputs) / 255.0, labels


def get_dataset(dataset, normalize=False, data_dir=None):
    """Load the train and test splits of a TFDS dataset fully into memory.

    Args:
        dataset: TFDS dataset name, e.g. ``"mnist"``.
        normalize: If True, standardize both splits per channel using
            statistics computed on the training split only.
        data_dir: Optional TFDS data directory passed to ``tfds.builder``.

    Returns:
        ``(train_ds, test_ds)``. Each is a dict with ``"data"`` (float32
        inputs divided by 255, then optionally standardized) and ``"labels"``.
    """
    ds_builder = tfds.builder(dataset, data_dir=data_dir)
    ds_builder.download_and_prepare()

    # Both splits are read with shuffle_files=False so example order is fixed;
    # downstream cells slice the data by position.
    train_data, train_labels = _load_split(ds_builder, "train")
    test_data, test_labels = _load_split(ds_builder, "test")

    if normalize:
        # The test split reuses the training statistics to avoid leakage.
        train_data, mean, std = normalize_data(train_data)
        test_data, _, _ = normalize_data(test_data, mean, std)

    train_ds = {"data": train_data, "labels": train_labels}
    test_ds = {"data": test_data, "labels": test_labels}
    return train_ds, test_ds

In [6]:
SEED = 100
key = jax.random.PRNGKey(SEED)  # root PRNG key, consumed by model.init below

In [7]:
# Standardized MNIST; the test split is normalized with the training statistics.
train_ds, test_ds = get_dataset(dataset="mnist", normalize=True)

In [8]:
# Stack train then test examples along the leading (batch) axis.
data = jnp.concatenate((train_ds["data"], test_ds["data"]))

In [9]:
# Module definition only; its parameters are created separately via `model.init`.
model = Lenet300_100()

In [10]:
# Dummy single example shaped like an MNIST image (28, 28, 1); it is used only
# so `init` can infer parameter shapes.
batch = jnp.ones((1, 28, 28, 1))
variables = model.init(key, batch)

In [11]:
# Block size handed to `nt.batch` when assembling the NTK matrix.
BATCH_SIZE = 5_000

## Eigenpairs of NTK at initialization

In [12]:
def get_apply_fn(model, variables, bn=False, train=False):
    """Build an ``apply_fn(params, x)`` suitable for ``nt.empirical_ntk_fn``.

    Args:
        model: A Flax module.
        variables: Full variable collection as returned by ``model.init``
            (either a ``FrozenDict`` or a plain ``dict``). It is not modified.
        bn: If False, ``apply_fn`` receives only the ``"params"`` collection;
            every other collection in ``variables`` is held fixed and merged
            back in. If True, ``apply_fn`` receives the full variable
            collection.
        train: Only used when ``bn`` is True; forwarded as ``mutable`` to
            ``model.apply``.

    Returns:
        ``apply_fn(params, x) -> logits``. The output is never paired with
        updated collections.
    """
    if not bn:
        # Copy the non-"params" collections without mutating `variables`.
        # dict.pop would delete the caller's params, and FrozenDict.pop returns
        # a tuple instead, so avoid pop entirely.
        model_state = {name: col for name, col in variables.items() if name != "params"}

        def apply_fn(params, x):
            new_vars = freeze({'params': params, **model_state})
            return model.apply(new_vars, x, mutable=False)
    else:
        def apply_fn(params, x):
            out = model.apply(params, x, mutable=train)
            # Flax's apply returns (output, updated_vars) whenever `mutable` is
            # not False; keep only the output so the NTK sees logits.
            return out if train is False else out[0]

    return apply_fn

In [13]:
# Number of leading examples of `data` on which the NTK is computed.
NUM_NTK_SAMPLES = 20_000
data_small = data[:NUM_NTK_SAMPLES]

In [14]:
def eigendecomposition_ntk(model, variables, batch_size, data, num_eigvecs):
    """Top eigenpairs of the empirical NTK Gram matrix of ``model`` on ``data``.

    Args:
        model: Flax module whose NTK is computed.
        variables: Variables from ``model.init``. Only ``variables["params"]``
            is differentiated; other collections are held fixed.
        batch_size: Block size forwarded to ``nt.batch``.
        data: Inputs; the kernel is computed between ``data`` and itself.
        num_eigvecs: Number of eigenpairs to return (``k`` for ``eigsh``).

    Returns:
        ``(eig_vals, eig_vecs)``, sorted by decreasing eigenvalue.
        ``eig_vecs`` holds one eigenvector per row.

    Raises:
        ValueError: If ``num_eigvecs`` is not positive. This is checked before
            the expensive kernel computation.
    """
    if num_eigvecs < 1:
        raise ValueError(f"num_eigvecs must be positive, got {num_eigvecs}")

    apply_fn = get_apply_fn(model, variables, bn=False, train=False)

    # NOTE(review): per the nt.batch docs, device_count=-1 presumably uses all
    # available devices and store_on_device=False gathers blocks on the host.
    kernel_fn = nt.batch(
        nt.empirical_ntk_fn(apply_fn, vmap_axes=0),
        batch_size=batch_size,
        device_count=-1,
        store_on_device=False,
    )

    # x2=None: kernel of `data` against itself (neural-tangents convention).
    ntk_nxn_mat = kernel_fn(data, None, variables["params"])
    eig_vals, eig_vecs = eigsh(jax.device_get(ntk_nxn_mat), k=num_eigvecs)

    # The eigsh docs do not promise an ordering, so sort explicitly (largest
    # first) rather than flipping an assumed ascending order.
    order = np.argsort(eig_vals)[::-1]
    eig_vals = eig_vals[order]
    eig_vecs = eig_vecs[:, order].T  # eigenvectors are now row-wise

    return eig_vals, eig_vecs

In [15]:
NUM_EIGVECS = 100
eig_vals_init, eig_vecs_init = eigendecomposition_ntk(
    model,
    variables,
    batch_size=BATCH_SIZE,
    data=data_small,
    num_eigvecs=NUM_EIGVECS,
)

In [16]:
# Eigenvalues of the NTK at initialization, largest first.
print(eig_vals_init)

[1196159.9     217486.62    157252.83    122719.484   113999.516
   88319.336    82553.91     65963.48     63081.773    60508.656
   54187.074    49453.027    47070.04     43218.406    40318.11
   38757.145    35068.207    32243.582    31308.727    29916.234
   28223.951    27554.908    26243.307    25889.553    23761.55
   23130.74     21845.945    21046.936    20155.262    19308.348
   18848.316    18242.432    17299.379    16723.928    16455.562
   15975.346    15394.872    15054.196    14851.225    14216.606
   13805.492    13672.652    13260.9375   13068.598    12581.18
   12139.04     11750.129    11400.504    11170.843    11123.215
   10887.76     10503.474    10133.709     9976.055     9777.321
    9687.691     9421.513     9169.092     8960.049     8851.499
    8698.974     8453.13      8413.308     8258.492     8164.882
    7977.204     7804.7354    7614.984     7546.01      7386.6094
    7343.029     7241.948     7147.9995    7008.2373    6996.7734
    6849.672     6814.4556

In [17]:
# (num_eigvecs, num_examples): eigenvectors are stored one per row.
np.shape(eig_vecs_init)

(100, 20000)

In [18]:
# Eigenvector of the largest eigenvalue: one entry per example in `data_small`.
eig_vecs_init[0, :]

array([-0.00629224, -0.00533323, -0.01033875, ..., -0.0053913 ,
       -0.00603498, -0.00546225], dtype=float32)

In [19]:
# Mount Google Drive at /content/drive (Colab only) so artifacts can be saved there.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [20]:
# Inspect the first 200 training labels.
# NOTE(review): presumably a sanity check that example order is deterministic.
train_ds['labels'][:200]

array([4, 1, 0, 7, 8, 1, 2, 7, 1, 6, 6, 4, 7, 7, 3, 3, 7, 9, 9, 1, 0, 6,
       6, 9, 9, 4, 8, 9, 4, 7, 3, 3, 0, 9, 4, 9, 0, 6, 8, 4, 7, 2, 6, 0,
       3, 1, 1, 7, 2, 4, 4, 6, 5, 1, 9, 3, 2, 4, 3, 4, 4, 7, 5, 8, 1, 1,
       4, 1, 5, 3, 5, 8, 4, 1, 1, 4, 5, 3, 2, 4, 1, 4, 8, 1, 2, 1, 9, 0,
       7, 6, 7, 4, 4, 9, 7, 5, 6, 8, 4, 6, 9, 2, 9, 4, 4, 9, 5, 4, 5, 7,
       7, 1, 8, 3, 7, 9, 8, 4, 9, 2, 8, 0, 3, 9, 4, 7, 6, 6, 1, 4, 0, 2,
       9, 1, 7, 7, 4, 1, 8, 5, 0, 5, 0, 9, 6, 3, 8, 9, 9, 7, 3, 1, 2, 2,
       7, 8, 6, 4, 0, 6, 2, 4, 2, 4, 8, 9, 8, 5, 9, 8, 4, 7, 6, 9, 3, 9,
       9, 8, 9, 4, 1, 1, 8, 5, 9, 9, 2, 3, 8, 7, 1, 4, 2, 2, 4, 8, 6, 1,
       4, 2])

In [21]:
# Convert the Flax variable tree to a nested state dict before saving it.
# NOTE(review): presumably done so the saved object does not depend on
# FrozenDict; confirm against the loading code.
variables_state_dict = to_state_dict(variables)

In [22]:
# Google Drive folder (under the Colab mount) that receives the initial-NTK artifacts.
out_dir = (
    "/content/drive/MyDrive"
    "/Paper Implementations"
    "/What can linearized neural networks actually say about generalization?"
    "/artifacts_mnist_mlp"
)

In [23]:
# Persist the initial-NTK artifacts as .npy files in `out_dir`.
# np.save stores the variables dict as a pickled 0-d object array; load it with
# np.load(path, allow_pickle=True).item().
os.makedirs(out_dir, exist_ok=True)  # np.save fails if the folder is missing
init_artifacts = {
    "eig_vecs_init": eig_vecs_init,
    "eig_vals_init": eig_vals_init,
    # Pull the parameter tree back to host NumPy arrays so the pickle does not
    # embed device-backed jax arrays.
    "variables_init": jax.device_get(variables_state_dict),
    "data": data_small,
}
for artifact_name, artifact in init_artifacts.items():
    np.save(f"{out_dir}/{artifact_name}.npy", artifact)