In [1]:
import pdb
import traceback

import jax
from jax import numpy as jnp
import optax
import wandb
import equinox as eqx

from common_dl_utils.config_creation import Config
import common_jax_utils as cju

wandb.login()

key = jax.random.PRNGKey(12398)
key_gen = cju.key_generator(key)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmaxwell_litsios[0m ([33mbep-circle[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
config = Config()

# first we specify what the model should look like
config.architecture = './model_components'  # module containing all relevant classes for architectures
# NB if the classes relevant for creating the model are spread over multiple modules, this is no problem
# let config.architecture be the module that contains the "main" model class, and for all other components just specify the module
# or specify the other modules as default modules to the tools in common_jax_utils.run_utils
config.model_type = 'inr_modules.CombinedINR'

config.model_config = Config()
config.model_config.in_size = 2
config.model_config.out_size = 1
config.model_config.terms = [  # CombinedINR uses multiple MLPs and returns the sum of their outputs. These 'terms' are the MLPs
    ('inr_modules.MLPINR.from_config',{
        'hidden_size': 1028,
        'num_layers': 5,
        # 'layer_type': 'inr_layers.GaussianINRLayer',
        'layer_type': 'inr_layers.ComplexWIRE',
        'num_splits': 3,
        'activation_kwargs': {'w0': 25., "s0":15},
        'initialization_scheme':'initialization_schemes.siren_scheme',
        'initialization_scheme_kwargs': {'w0': 12.},
        'post_processor':'auxiliary.real_scalar'
        #'positional_encoding_layer': ('inr_layers.ClassicalPositionalEncoding.from_config', {'num_frequencies': 10}),
    }),
    # ('inr_modules.MLPINR.from_config',{
    #     'hidden_size': 1024,
    #     'num_layers': 2,
    #     'num_splits': 1,
    #     'layer_type': 'inr_layers.GaussianINRLayer',
    #     'use_complex': False,
    #     'activation_kwargs': {'inverse_scale': 1},
    # })
]
#config.model_config.post_processor = lambda x: x[0]

In [3]:
# let's first see if we get the correct model
try:
    inr = cju.run_utils.get_model_from_config_and_key(
        prng_key=next(key_gen),
        config=config,
        model_sub_config_name_base='model',
        add_model_module_to_architecture_default_module=False, # since the model is already in the default module specified by 'architecture',
    )
except Exception as e:
    traceback.print_exc()
    print(e)
    print('\n')
    pdb.post_mortem()

In [4]:
def apply_inr(inr, location):
    return inr(location)

inr_grad = eqx.filter_grad(apply_inr)

In [5]:
inr_grad(inr, jnp.array([0.1, 0.75]))

  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


CombinedINR(
  terms=(
    MLPINR(
      layers=(
        ComplexWIRE(
          weights=f32[1028,2],
          biases=f32[1028],
          activation_kwargs={'s0': 15, 'w0': 25.0}
        ),
        ComplexWIRE(
          weights=f32[1028,1028],
          biases=f32[1028],
          activation_kwargs={'s0': 15, 'w0': 25.0}
        ),
        ComplexWIRE(
          weights=f32[1028,1028],
          biases=f32[1028],
          activation_kwargs={'s0': 15, 'w0': 25.0}
        ),
        ComplexWIRE(
          weights=f32[1028,1028],
          biases=f32[1028],
          activation_kwargs={'s0': 15, 'w0': 25.0}
        ),
        Linear(weights=f32[1,1028], biases=f32[1], activation_kwargs={}),
        Lambda(fn=None)
      )
    ),
  ),
  post_processor=None
)

In [6]:
def tree_inner_product(tree_1, tree_2):
    component_wise = jax.tree.map(lambda x, y: jnp.sum(x*y), tree_1, tree_2)
    return sum(jax.tree.leaves(component_wise))

In [7]:
tree_inner_product(
    inr_grad(inr, jnp.array([0.1, 0.75])),
    inr_grad(inr, jnp.array([0.2, 0.65]))
)

Array(1.9876883, dtype=float32)

In [8]:

def ntk_single(inr, loc_1, loc_2):
    return tree_inner_product(
    inr_grad(inr, loc_1),
    inr_grad(inr, loc_2)
)

def _ntk_single(inr, loc1loc2):
    channels = loc1loc2.shape[-1]//2
    loc_1 = loc1loc2[:channels]
    loc_2 = loc1loc2[channels:]
    return ntk_single(inr, loc_1, loc_2)

def ntk_array(inr, locations, batch_size):
    channels = locations.shape[-1]
    locations = locations.reshape(-1, channels)

    #first on the lower triangle
    loc_1_indices, loc_2_indices = jnp.tril_indices(locations.shape[0])
    loc_1 = locations[loc_1_indices]
    loc_2 = locations[loc_2_indices]
    loc_1_loc_2 = jnp.concatenate([loc_1, loc_2], -1)
    
    apply_ntk_single_batch = lambda batch: jax.vmap(_ntk_single, (None, 0))(inr, batch)
    batches = loc_1_loc_2.reshape((-1, batch_size, 2*channels))
    print(f"{batches.shape=}")
    num_batches = batches.shape[0]
    resulting_batches = jax.lax.map(apply_ntk_single_batch, batches)
    results_flat = resulting_batches.reshape(num_batches*batch_size)

    return results_flat




In [17]:
from inr_utils.images import make_lin_grid
locations = make_lin_grid(0., 1., (28, 28))


In [19]:
kernel = ntk_array(inr, locations, 4*28)

TypeError: cannot reshape array of shape (307720, 4) (size 1230880) into shape (-1, 112, 4) because the product of specified axis sizes (448) does not evenly divide 1230880

In [None]:
kernel.shape

In [20]:
1230880/(4*28)

10990.0