In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import os.path as osp
import sys

if "../src" not in sys.path:
    sys.path.insert(0, "../src")

In [3]:
import numpy as onp

import jax
from jax import numpy as jnp, random
from jax.random import split, PRNGKey

import haiku as hk
from haiku.data_structures import merge
import optax as ox

from helpers import SimpleModel
from utils.losses import mean_xe_and_acc_dict
from datasets import get_dataset
from utils import (
    shuffle_along_axis,
    use_self_as_default,
    split_rng_or_none,
    call_self_as_default,
    expand,
    tree_shape,
    first_leaf_shape,
    tree_flatten_array,
    mean_of_f,
)
from utils.data.datasets import (
    SimpleDataset,
    ImageDataset,
    MetaDataset,
    MetaDatasetArray,
)

from meta.trainers import MetaTrainerB
from meta.base import MetaBase
from meta.wrappers import MetaLearnerBaseB, ContinualLearnerB

from easydict import EasyDict as edict

In [4]:
did = ImageDataset(onp.arange(6).repeat(20).astype(jnp.float32), onp.arange(6).repeat(20), 1, 0.1)
md = MetaDatasetArray(
    did,
    2, 3, 4, 5, 25, False, False,
)

In [5]:
trainset = get_dataset("omniglot", "train", all=True, train=True, image_size=28)

In [6]:
meta_trainset = MetaDatasetArray(trainset, 2, 5, 5, 15, 64)
meta_trainset_sampler = meta_trainset.get_sampler(PRNGKey(0))

In [7]:
def pmap_init(model, static_args, static_kwargs, *args, **kwargs):
    return jax.pmap(jax.partial(model.init, *static_args, **static_kwargs), axis_name="i")(*args, **kwargs)

In [8]:
cross_repliace_axis = "i"

model = hk.transform_with_state(
    lambda x, phase, training: SimpleModel(
        cross_replica_axis=cross_repliace_axis,
    )(x, phase, training)
)

if cross_repliace_axis is not None:
    params, state = pmap_init(model, (PRNGKey(0),), dict(phase="all", training=True), random.normal(PRNGKey(1), (1, 1, 2)))
else:
    params, state = model.init(PRNGKey(0), random.normal(PRNGKey(1), (2, 2)), "all", True)

continual_learner = ContinualLearnerB(
    model.apply,
    params,
    state,
    SimpleModel.train_slow_phase,
    SimpleModel.train_fast_phase,
    SimpleModel.get_train_slow_params,
    SimpleModel.get_train_fast_params,
    SimpleModel.get_train_slow_state,
    SimpleModel.get_train_fast_state,
    3e-2,
)

In [9]:
meta_trainer = MetaTrainerB(
    continual_learner,
    model.apply,
    params,
    state,
    SimpleModel.train_slow_phase,
    SimpleModel.train_fast_phase,
    SimpleModel.get_train_slow_params,
    SimpleModel.get_train_fast_params,
    SimpleModel.get_train_slow_state,
    SimpleModel.get_train_fast_state,
    inner_lr=1e-2,
    train_lr=True,
    optimizer=ox.adam(1e-4),
    cross_replica_axis=cross_repliace_axis,
)

meta_trainer.replicate_state().initialize_opt_state()

<meta.trainers.MetaTrainerB at 0x7fe77a801a10>

In [10]:
print(tree_shape(meta_trainer.params))
print(tree_shape(meta_trainer.state))
print(tree_shape(meta_trainer.inner_lr))

FlatMapping({
  'simple_model/fast/batch_norm': FlatMapping({'offset': (1, 1, 8), 'scale': (1, 1, 8)}),
  'simple_model/fast/linear': FlatMapping({'b': (1, 8), 'w': (1, 8, 8)}),
  'simple_model/fast_1/batch_norm': FlatMapping({'offset': (1, 1, 4), 'scale': (1, 1, 4)}),
  'simple_model/fast_1/linear': FlatMapping({'b': (1, 4), 'w': (1, 8, 4)}),
  'simple_model/fast_1/linear_1': FlatMapping({'b': (1, 2), 'w': (1, 4, 2)}),
  'simple_model/slow1/batch_norm': FlatMapping({'offset': (1, 1, 8), 'scale': (1, 1, 8)}),
  'simple_model/slow1/linear': FlatMapping({'b': (1, 8), 'w': (1, 2, 8)}),
  'simple_model/slow2/batch_norm': FlatMapping({'offset': (1, 1, 8), 'scale': (1, 1, 8)}),
  'simple_model/slow2/linear': FlatMapping({'b': (1, 8), 'w': (1, 8, 8)}),
})
FlatMapping({
  'simple_model/fast/batch_norm/~/mean_ema': FlatMapping({'average': (1, 1, 8), 'counter': (1,), 'hidden': (1, 1, 8)}),
  'simple_model/fast/batch_norm/~/var_ema': FlatMapping({'average': (1, 1, 8), 'counter': (1,), 'hidden': (

In [15]:
inputs = (random.normal(PRNGKey(1), (4, 2, 2)),
jnp.arange(2)[:, None].repeat(2, 0).repeat(2, 1),
random.normal(PRNGKey(1), (4, 2, 2)),
jnp.arange(2)[:, None].repeat(2, 0).repeat(2, 1),
random.normal(PRNGKey(1), (4, 2, 2)),
jnp.arange(2)[:, None].repeat(2, 0).repeat(2, 1),)

if cross_repliace_axis is not None:
    inputs = jax.tree_map(expand, inputs)
    

out = meta_trainer.step(
    PRNGKey(0),
    0,
    *inputs,
    # random.normal(PRNGKey(1), (1, 4, 2, 2)),
    # jnp.arange(2)[None, :, None].repeat(2, 1).repeat(2, 2),
    # random.normal(PRNGKey(1), (1, 4, 2, 2)),
    # jnp.arange(2)[None, :, None].repeat(2, 1).repeat(2, 2),
    # random.normal(PRNGKey(1), (1, 4, 2, 2)),
    # jnp.arange(2)[None, :, None].repeat(2, 1).repeat(2, 2),
    # None,
    # meta_trainer.params,
    # meta_trainer.state,
)

In [10]:
class _MetaTrainerB(MetaLearnerBaseB):
    def __init__(
        self,
        learner,
        *args,
        opt_state=None,
        inner_lr=None,
        augmentation="none",
        augmentation_fn=None,
        alt=True,
        train_lr=False,
        optimizer=None,
        scheduler=None,
        reset_fast_params=None,
        reset_before_outer_loop=True,
        cross_replica_axis=None,
        include_spt=False,
        **kwargs,
        # inner_lr, training, loss_fn, init_opt_state, opt_update
    ):
        super().__init__(*args, **kwargs)
        self.learner = learner
        self.alt = alt
        self.augmentation = augmentation
        self.augmentation_fn = augmentation_fn
        self.train_lr = train_lr
        self.inner_lr = jnp.ones([]) * inner_lr
        self.reset_fast_params = reset_fast_params
        self.reset_before_outer_loop = reset_before_outer_loop
        self.cross_replica_axis = cross_replica_axis
        self.include_spt = include_spt
        self.opt_state = opt_state
        self.scheduler = scheduler
        self.optimizer = optimizer

        update = jax.partial(
            lambda rng, step_num, x_spt, y_spt, x_qry, y_qry, x_qry_cl, y_qry_cl, spt_classes, params, state, opt_state: self._update(
                rng,
                step_num,
                x_spt,
                y_spt,
                x_qry,
                y_qry,
                x_qry_cl,
                y_qry_cl,
                spt_classes,
                params,
                state,
                opt_state=opt_state,
            )
        )
        
        if self.cross_replica_axis is not None:
            self.update = jax.pmap(update, axis_name="i")
        else:
            self.update = jax.jit(update)

        # if self.train_lr:
        #     self.params = (self.params, self.inner_lr)
        # else:
        #     self.params = self.params,

    def initialize_opt_state(self):
        if self.train_lr:
            params = (self.params, self.inner_lr)
        else:
            params = (self.params,)
        opt_state = self.optimizer.init(params)
        self.opt_state = opt_state
        return self

    def replicate_state(self):
        # self.params = jax.device_put_replicated(self.params, jax.local_devices())
        # self.state = jax.device_put_replicated(self.state, jax.local_devices())
        if self.train_lr:
            self.inner_lr = jax.device_put_replicated(
                self.inner_lr, jax.local_devices()
            )
        # self.opt_state = jax.device_put_replicated(self.opt_state, jax.local_devices())
        return self

    def apply_augmentation(self, rng, _input):
        if self.augmentation_fn is not None:
            return self.augmentation_fn(rng, _input)
        return _input

    def augments(self, rng, x_spt, x_qry, x_qry_cl):
        rng = split(rng, 3)
        if self.augmentation in ["spt", "all"]:
            x_spt = self.apply_augmentation(rng[0], x_spt)
        if self.augmentation in ["qry", "all"]:
            x_qry = self.apply_augmentation(rng[0], x_qry)
            x_qry_cl = self.apply_augmentation(rng[0], x_qry_cl)

        return x_spt, x_qry, x_qry_cl

    def step(
        self,
        rng,
        step_num,
        x_spt,
        y_spt,
        x_qry,
        y_qry,
        x_qry_cl,
        y_qry_cl,
        spt_classes=None,
    ):
        # spt_classes = onp.unique(y_spt, axis=-1)

        if self.cross_replica_axis is not None:
            rng = split(rng, jax.local_device_count())
            step_num = jax.device_put_replicated(step_num, jax.local_devices())
            if spt_classes is None:
                spt_classes = [None] * jax.local_device_count()

        if self.train_lr:
            params = (self.params, self.inner_lr)
        else:
            params = self.params

        loss, params, opt_state, out = self.update(
            rng,
            step_num,
            x_spt,
            y_spt,
            x_qry,
            y_qry,
            x_qry_cl,
            y_qry_cl,
            spt_classes,
            params,
            state,
            self.opt_state,
        )
        if self.train_lr:
            self.params = params[0]
            self.inner_lr = params[1]
        else:
            self.params = params
        self.state = merge(
            out["initial_outer_out"]["slow_state"],
            out["inner_out"]["fast_state"],
        )
        self.opt_state = opt_state

        return loss, out, self.inner_lr

    @use_self_as_default(
        "training",
        # "init_opt_state",
        "reset_fast_params",
        "reset_before_outer_loop",
        "include_spt",
        "scheduler",
        "inner_lr",
        "loss_fn",
        # "opt_update",
        "optimizer",
    )
    def _update(
        self,
        rng,
        step_num,
        x_spt,
        y_spt,
        x_qry,
        y_qry,
        x_qry_cl,
        y_qry_cl,
        spt_classes,
        params,
        state,
        training=None,
        loss_fn=None,
        opt_state=None,
        inner_lr=None,
        # init_opt_state=None,
        # opt_update=None,
        optimizer=None,
        alt=None,
        reset_fast_params=None,
        reset_before_outer_loop=None,
        include_spt=None,
        scheduler=None,
        inner_opt_state=None,
        inner_init_opt_state=None,
        inner_opt_update=None,
    ):
        
        print(tree_shape(params))
        print(tree_shape(state))
        print(tree_shape(opt_state))
        
        rng_data, rng_reset, rng_pre, rng = split(rng, 4)

        # inputs come as uint8 for speedier transfer
        x_spt, x_qry, x_qry_cl = self.augments(
            rng_data, x_spt / 255, x_qry / 255, x_qry_cl / 255
        )

        if self.train_lr:
            _params = params
        else:
            _params = (params,)

        if opt_state is None:
            # opt_state = self.init_opt_state(params)
            opt_state = optimizer.init(params)

        if reset_before_outer_loop and reset_fast_params:
            params = (reset_fast_params(rng_reset, params[0], spt_classes), *params[1:])

        # _, pre_slow_state = self.slow_apply(
        #     x_qry_cl,
        #     rng_pre,
        #     params,
        #     state,
        #     training,
        # )

        if include_spt:
            x_qry = jnp.concatenate((x_spt, x_qry, x_qry_cl))
            y_qry = jnp.concatenate((y_spt, y_qry, y_qry_cl))
        else:
            x_qry = jnp.concatenate((x_qry, x_qry_cl))
            y_qry = jnp.concatenate((y_qry, y_qry_cl))

        def helper(_params, _lr=inner_lr):
            loss, out = self.outer_loss(
                x_spt,
                y_spt,
                x_qry,
                y_qry,
                spt_classes=spt_classes,
                rng=rng,
                params=_params,
                # state=pre_slow_state,
                state=state,
                training=training,
                loss_fn=loss_fn,
                opt_state=inner_opt_state,
                lr=_lr,
                init_opt_state=inner_init_opt_state,
                opt_update=inner_opt_update,
                alt=alt,
                reset_fast_params=reset_fast_params,
                reset_before_outer_loop=reset_before_outer_loop,
            )
            if self.cross_replica_axis is not None:
                loss = jax.lax.pmean(loss, self.cross_replica_axis)
            else:
                loss = jnp.mean(loss)
            return jnp.mean(loss), out

        (loss, out), grads = jax.value_and_grad(
            helper, has_aux=True, argnums=tuple(range(len(_params)))
        )(*_params)
        if self.cross_replica_axis is not None:
            grads = jax.lax.pmean(grads, self.cross_replica_axis)

        updates, opt_state = optimizer.update(grads, opt_state, params)
        if self.scheduler is not None:
            updates = self.scheduler(updates)
        params = ox.apply_updates(params, updates)
        
        print(tree_shape(params))
        print(tree_shape(state))
        print(tree_shape(opt_state))

        return loss, params, opt_state, out

    @use_self_as_default(
        "alt",
        "params",
        "state",
        "training",
        # "lr",
        "loss_fn",
        # "init_opt_state",
        # "opt_update",
        "reset_fast_params",
        "reset_before_outer_loop",
    )
    def outer_loss(
        self,
        x_spt,
        y_spt,
        x_qry,
        y_qry,
        spt_classes=None,
        rng=None,
        params=None,
        state=None,
        training=None,
        loss_fn=None,
        opt_state=None,
        lr=None,
        init_opt_state=None,
        opt_update=None,
        alt=None,
        reset_fast_params=None,
        reset_before_outer_loop=None,
    ):
        rng_inner, rng_outer_slow, rng_outer_fast, rng_reset = split_rng_or_none(rng, 4)

        if (lr is not None) and (not self.train_lr):
            lr = jax.lax.stop_gradient(lr)

        if (spt_classes is not None) and reset_before_outer_loop and reset_fast_params:
            params = reset_fast_params(rng_reset, params, spt_classes)

        slow_outputs, slow_state = self.learner.slow_apply(
            x_qry,
            rng_outer_slow,
            params,
            state,
            training,
        )

        loss, (new_state, loss_aux, outputs) = self.learner.fast_apply_and_loss(
            slow_outputs,
            y_qry,
            rng=rng_outer_fast,
            params=params,
            state=slow_state,
            training=training,
            loss_fn=loss_fn,
            # alt=alt,
        )
        initial_outer_out = dict(
            loss=loss,
            slow_state=slow_state,
            fast_state=new_state,
            loss_aux=loss_aux,
            outputs=outputs,
        )

        inner_out = self.learner.inner_loop(
            x_spt,
            y_spt,
            params,
            slow_state,
            rng_inner,
            training=training,
            opt_state=opt_state,
            lr=lr,
            loss_fn=loss_fn,
            init_opt_state=init_opt_state,
            opt_update=opt_update,
        )

        if alt:
            slow_outputs, y_qry = tree_flatten_array((slow_outputs, y_qry))

        loss, (new_state, loss_aux, outputs) = self.learner.fast_apply_and_loss(
            slow_outputs,
            y_qry,
            rng=split(rng_outer_fast)[0],
            params=params,
            state=slow_state,
            training=training,
            fast_params=inner_out["fast_params"],
            fast_state=inner_out["fast_state"],
            loss_fn=loss_fn,
            alt=alt,
        )
        outer_out = dict(
            loss=loss, fast_state=new_state, loss_aux=loss_aux, outputs=outputs
        )

        return (outer_out["loss"]), dict(
            inner_out=inner_out,
            outer_out=outer_out,
            initial_outer_out=initial_outer_out,
        )

In [11]:
# init_opt_state = ox.adam(0).init
# opt_update = make_simple_opt_update(ox.adam)

model = hk.transform_with_state(
    lambda x, phase, training: SimpleModel(
        cross_replica_axis="i",
    )(x, phase, training)
)
# params, state = model.init(PRNGKey(0), random.normal(PRNGKey(1), (2, 2)), "all", True)

params, state = pmap_init(model, (PRNGKey(0),), dict(phase="all", training=True), random.normal(PRNGKey(1), (1, 1, 2)))

continual_learner = ContinualLearnerB(
    model.apply,
    params,
    state,
    SimpleModel.train_slow_phase,
    SimpleModel.train_fast_phase,
    SimpleModel.get_train_slow_params,
    SimpleModel.get_train_fast_params,
    SimpleModel.get_train_slow_state,
    SimpleModel.get_train_fast_state,
    3e-2,
)

meta_trainer = MetaTrainerB(
    continual_learner,
    model.apply,
    params,
    state,
    SimpleModel.train_slow_phase,
    SimpleModel.train_fast_phase,
    SimpleModel.get_train_slow_params,
    SimpleModel.get_train_fast_params,
    SimpleModel.get_train_slow_state,
    SimpleModel.get_train_fast_state,
    inner_lr=1e-2,
    train_lr=True,
    # init_opt_state=init_opt_state,
    # opt_update=opt_update,
    optimizer=ox.adam(1e-4),
    cross_replica_axis="i",
)
# out = meta_trainer.outer_loss(
#     random.normal(PRNGKey(1), (4, 2, 2)),
#     jnp.arange(2)[:, None].repeat(2, 0).repeat(2, 1),
#     random.normal(PRNGKey(1), (4, 2, 2)),
#     jnp.arange(2)[:, None].repeat(2, 0).repeat(2, 1),
# )

meta_trainer.initialize_opt_state().replicate_state()

out = meta_trainer.step(
    PRNGKey(0),
    0,
    random.normal(PRNGKey(1), (1, 4, 2, 2)),
    jnp.arange(2)[None, :, None].repeat(2, 1).repeat(2, 2),
    random.normal(PRNGKey(1), (1, 4, 2, 2)),
    jnp.arange(2)[None, :, None].repeat(2, 1).repeat(2, 2),
    random.normal(PRNGKey(1), (1, 4, 2, 2)),
    jnp.arange(2)[None, :, None].repeat(2, 1).repeat(2, 2),
    # None,
    # meta_trainer.params,
    # meta_trainer.state,
)

ValueError: pmap got arg 51 of rank 0 but axis to be mapped 0. The tree of ranks is:
((2, 1, 4, 3, 4, 3, 4, 3, [None], (FlatMapping({
  'simple_model/fast/batch_norm': FlatMapping({'offset': 3, 'scale': 3}),
  'simple_model/fast/linear': FlatMapping({'b': 2, 'w': 3}),
  'simple_model/fast_1/batch_norm': FlatMapping({'offset': 3, 'scale': 3}),
  'simple_model/fast_1/linear': FlatMapping({'b': 2, 'w': 3}),
  'simple_model/fast_1/linear_1': FlatMapping({'b': 2, 'w': 3}),
  'simple_model/slow1/batch_norm': FlatMapping({'offset': 3, 'scale': 3}),
  'simple_model/slow1/linear': FlatMapping({'b': 2, 'w': 3}),
  'simple_model/slow2/batch_norm': FlatMapping({'offset': 3, 'scale': 3}),
  'simple_model/slow2/linear': FlatMapping({'b': 2, 'w': 3}),
}), 1), FlatMapping({
  'simple_model/fast/batch_norm/~/mean_ema': FlatMapping({'average': 3, 'counter': 1, 'hidden': 3}),
  'simple_model/fast/batch_norm/~/var_ema': FlatMapping({'average': 3, 'counter': 1, 'hidden': 3}),
  'simple_model/fast_1/batch_norm/~/mean_ema': FlatMapping({'average': 3, 'counter': 1, 'hidden': 3}),
  'simple_model/fast_1/batch_norm/~/var_ema': FlatMapping({'average': 3, 'counter': 1, 'hidden': 3}),
  'simple_model/slow1/batch_norm/~/mean_ema': FlatMapping({'average': 3, 'counter': 1, 'hidden': 3}),
  'simple_model/slow1/batch_norm/~/var_ema': FlatMapping({'average': 3, 'counter': 1, 'hidden': 3}),
  'simple_model/slow2/batch_norm/~/mean_ema': FlatMapping({'average': 3, 'counter': 1, 'hidden': 3}),
  'simple_model/slow2/batch_norm/~/var_ema': FlatMapping({'average': 3, 'counter': 1, 'hidden': 3}),
}), [ScaleByAdamState(count=0, mu=(FlatMapping({
  'simple_model/fast/batch_norm': FlatMapping({'offset': 3, 'scale': 3}),
  'simple_model/fast/linear': FlatMapping({'b': 2, 'w': 3}),
  'simple_model/fast_1/batch_norm': FlatMapping({'offset': 3, 'scale': 3}),
  'simple_model/fast_1/linear': FlatMapping({'b': 2, 'w': 3}),
  'simple_model/fast_1/linear_1': FlatMapping({'b': 2, 'w': 3}),
  'simple_model/slow1/batch_norm': FlatMapping({'offset': 3, 'scale': 3}),
  'simple_model/slow1/linear': FlatMapping({'b': 2, 'w': 3}),
  'simple_model/slow2/batch_norm': FlatMapping({'offset': 3, 'scale': 3}),
  'simple_model/slow2/linear': FlatMapping({'b': 2, 'w': 3}),
}), 0), nu=(FlatMapping({
  'simple_model/fast/batch_norm': FlatMapping({'offset': 3, 'scale': 3}),
  'simple_model/fast/linear': FlatMapping({'b': 2, 'w': 3}),
  'simple_model/fast_1/batch_norm': FlatMapping({'offset': 3, 'scale': 3}),
  'simple_model/fast_1/linear': FlatMapping({'b': 2, 'w': 3}),
  'simple_model/fast_1/linear_1': FlatMapping({'b': 2, 'w': 3}),
  'simple_model/slow1/batch_norm': FlatMapping({'offset': 3, 'scale': 3}),
  'simple_model/slow1/linear': FlatMapping({'b': 2, 'w': 3}),
  'simple_model/slow2/batch_norm': FlatMapping({'offset': 3, 'scale': 3}),
  'simple_model/slow2/linear': FlatMapping({'b': 2, 'w': 3}),
}), 0)), ScaleState()]), {})

In [44]:
out = meta_trainer.step(
    PRNGKey(0),
    0,
    random.normal(PRNGKey(1), (1, 4, 2, 2)),
    jnp.arange(2)[None, :, None].repeat(2, 1).repeat(2, 2),
    random.normal(PRNGKey(1), (1, 4, 2, 2)),
    jnp.arange(2)[None, :, None].repeat(2, 1).repeat(2, 2),
    random.normal(PRNGKey(1), (1, 4, 2, 2)),
    jnp.arange(2)[None, :, None].repeat(2, 1).repeat(2, 2),
    # None,
    # meta_trainer.params,
    # meta_trainer.state,
)

AssertionError: 'simple_model/slow1/linear/w' with shape (1, 2, 8) does not match shape=[2, 8] dtype=dtype('float32')

In [86]:
meta_trainer.step(
    PRNGKey(0),
    0,
    random.normal(PRNGKey(1), (4, 2, 2)),
    jnp.arange(2)[:, None].repeat(2, 0).repeat(2, 1),
    random.normal(PRNGKey(1), (4, 2, 2)),
    jnp.arange(2)[:, None].repeat(2, 0).repeat(2, 1),
    random.normal(PRNGKey(1), (4, 2, 2)),
    jnp.arange(2)[:, None].repeat(2, 0).repeat(2, 1),
    # None,
    # meta_trainer.params,
    # meta_trainer.state,
)

(DeviceArray(0.6925241, dtype=float32),
 {'initial_outer_out': {'fast_state': FlatMapping({
     'simple_model/fast/batch_norm/~/mean_ema': FlatMapping({
                                                  'average': DeviceArray([[[ 0.6382035 ,  0.3566168 ,  0.08486773, -0.17143543,
                                                                            -0.38496044,  0.6316271 , -0.45933118, -0.05595254]],
                                                             
                                                                          [[-0.26198316, -0.1494211 , -0.03911211,  0.06703265,
                                                                             0.16195728, -0.26247463,  0.19609499,  0.01839363]],
                                                             
                                                                          [[ 0.23847139, -0.09976219, -0.2969207 , -0.3211375 ,
                                                                             0

(4, 1)


NameError: name 'y_qry_cl' is not defined

In [20]:
out[0]

DeviceArray(0.6935904, dtype=float32)

In [14]:
out[1]["initial_outer_out"]["loss"], out[1]["outer_out"]["loss"]

(DeviceArray([0.6931472, 0.6931472, 0.6931472, 0.6931472], dtype=float32),
 DeviceArray([0.6638154, 0.6638154, 0.6638154, 0.6638154], dtype=float32))

In [38]:
tree_shape(out[1]["inner_out"]["fast_params"])

FlatMapping({
  'simple_model/fast/batch_norm': FlatMapping({'offset': (4, 1, 8), 'scale': (4, 1, 8)}),
  'simple_model/fast/linear': FlatMapping({'b': (4, 8), 'w': (4, 8, 8)}),
  'simple_model/fast_1/batch_norm': FlatMapping({'offset': (4, 1, 4), 'scale': (4, 1, 4)}),
  'simple_model/fast_1/linear': FlatMapping({'b': (4, 4), 'w': (4, 8, 4)}),
  'simple_model/fast_1/linear_1': FlatMapping({'b': (4, 2), 'w': (4, 4, 2)}),
})

In [19]:
tree_shape(out)

((),
 {'initial_outer_out': {'fast_state': FlatMapping({
     'simple_model/fast/batch_norm/~/mean_ema': FlatMapping({'average': (4, 1, 8), 'counter': (4,), 'hidden': (4, 1, 8)}),
     'simple_model/fast/batch_norm/~/var_ema': FlatMapping({'average': (4, 1, 8), 'counter': (4,), 'hidden': (4, 1, 8)}),
     'simple_model/fast_1/batch_norm/~/mean_ema': FlatMapping({'average': (4, 1, 4), 'counter': (4,), 'hidden': (4, 1, 4)}),
     'simple_model/fast_1/batch_norm/~/var_ema': FlatMapping({'average': (4, 1, 4), 'counter': (4,), 'hidden': (4, 1, 4)}),
     'simple_model/slow1/batch_norm/~/mean_ema': FlatMapping({'average': (4, 1, 8), 'counter': (4,), 'hidden': (4, 1, 8)}),
     'simple_model/slow1/batch_norm/~/var_ema': FlatMapping({'average': (4, 1, 8), 'counter': (4,), 'hidden': (4, 1, 8)}),
     'simple_model/slow2/batch_norm/~/mean_ema': FlatMapping({'average': (4, 1, 8), 'counter': (4,), 'hidden': (4, 1, 8)}),
     'simple_model/slow2/batch_norm/~/var_ema': FlatMapping({'average': (4, 1,

In [23]:
out

(DeviceArray(0.6931472, dtype=float32),
 {'inner_out': {'fast_params': FlatMapping({
     'simple_model/fast/batch_norm': FlatMapping({
                                       'offset': DeviceArray([[[0., 0., 0., 0., 0., 0., 0., 0.]],
                                                 
                                                              [[0., 0., 0., 0., 0., 0., 0., 0.]],
                                                 
                                                              [[0., 0., 0., 0., 0., 0., 0., 0.]],
                                                 
                                                              [[0., 0., 0., 0., 0., 0., 0., 0.]]], dtype=float32),
                                       'scale': DeviceArray([[[1., 1., 1., 1., 1., 1., 1., 1.]],
                                                
                                                             [[1., 1., 1., 1., 1., 1., 1., 1.]],
                                                
             

In [15]:
out[1]["inner_out"]["loss_aux"]

{'acc': DeviceArray([[[1.],
               [1.]],
 
              [[1.],
               [1.]],
 
              [[0.],
               [1.]],
 
              [[0.],
               [1.]]], dtype=float32),
 'loss': DeviceArray([[[[0.6931472]],
 
               [[0.6782596]]],
 
 
              [[[0.6931472]],
 
               [[0.6782596]]],
 
 
              [[[0.6931472]],
 
               [[0.6782596]]],
 
 
              [[[0.6931472]],
 
               [[0.6782596]]]], dtype=float32)}

In [16]:
out[1]["initial_outer_out"]["loss_aux"], out[1]["initial_outer_out"]["loss"]

({'acc': DeviceArray([[1., 1.],
               [1., 1.],
               [0., 0.],
               [0., 0.]], dtype=float32),
  'loss': DeviceArray([[[0.6931472],
                [0.6931472]],
  
               [[0.6931472],
                [0.6931472]],
  
               [[0.6931472],
                [0.6931472]],
  
               [[0.6931472],
                [0.6931472]]], dtype=float32)},
 DeviceArray([0.6931472, 0.6931472, 0.6931472, 0.6931472], dtype=float32))

In [17]:
out[1]["outer_out"]["loss_aux"], out[1]["outer_out"]["loss"]

({'acc': DeviceArray([[1., 1.],
               [1., 1.],
               [1., 1.],
               [1., 1.]], dtype=float32),
  'loss': DeviceArray([[[0.6638154],
                [0.6638154]],
  
               [[0.6638154],
                [0.6638154]],
  
               [[0.6638154],
                [0.6638154]],
  
               [[0.6638154],
                [0.6638154]]], dtype=float32)},
 DeviceArray([0.6638154, 0.6638154, 0.6638154, 0.6638154], dtype=float32))