In [1]:
# Reload edited project modules automatically before each cell runs, so changes
# to the code under ../src take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import os.path as osp
import sys

# Make the project sources under ../src importable.
# Store the *absolute* path, resolved against the working directory at the time
# this cell runs. A relative entry is re-resolved against the current working
# directory at every import, so a later `%cd` would silently break imports. It
# would also defeat the duplicate check for equivalent spellings of the path.
SRC_DIR = osp.abspath("../src")
if SRC_DIR not in sys.path:
    sys.path.insert(0, SRC_DIR)

In [96]:
import numpy as onp

import jax
from jax import numpy as jnp, random
from jax.random import split, PRNGKey

import haiku as hk
from haiku.data_structures import merge
import optax as ox

from helpers import SimpleModel
from utils.losses import mean_xe_and_acc_dict
from datasets import get_dataset
from utils import shuffle_along_axis, use_self_as_default, split_rng_or_none, call_self_as_default, expand, tree_shape, first_leaf_shape
from utils.data.datasets import SimpleDataset, ImageDataset, MetaDataset, MetaDatasetArray

# from trainers.continual_learning import ContinualLearningTrainer
# from wrappers.continual_learner import ContinualLearnerB
from meta.base import MetaBase

from easydict import EasyDict as edict

In [4]:
# Tiny synthetic dataset: 120 samples, labels 0..5 each repeated 20 times
# consecutively. The inputs are simply the labels cast to float32.
# The trailing `1, 0.1` are ImageDataset settings whose meaning is defined in
# utils.data.datasets — TODO name them.
did = ImageDataset(
    onp.repeat(onp.arange(6), 20).astype(onp.float32),
    onp.repeat(onp.arange(6), 20),
    1,
    0.1,
)
# Meta-dataset view over the toy data. The positional configuration values are
# interpreted by MetaDatasetArray — TODO confirm their meaning.
md = MetaDatasetArray(
    did,
    2, 3, 4, 5, 25,
    False, False,
)

In [5]:
# Omniglot training split at image_size=28, loaded through the project's get_dataset.
trainset = get_dataset(
    "omniglot",
    "train",
    all=True,
    train=True,
    image_size=28,
)

In [6]:
# Meta-dataset built on top of the Omniglot training split. The positional values
# are MetaDatasetArray configuration — TODO give them names.
meta_trainset = MetaDatasetArray(
    trainset,
    2, 5, 5, 15, 64,
)
# Sampler seeded with a fixed key for reproducibility.
meta_trainset_sampler = meta_trainset.get_sampler(
    PRNGKey(0)
)

In [117]:
from meta.wrappers import MetaLearnerBaseB, ContinualLearnerB

In [118]:
# Stateful Haiku model. The lambda keeps the (x, phase, training) argument names,
# which SimpleModel is called with.
model = hk.transform_with_state(
    lambda x, phase, training: SimpleModel()(x, phase, training)
)
# Initialise on a dummy (2, 2) batch, phase "all", training=True.
params, state = model.init(
    PRNGKey(0),
    random.normal(PRNGKey(1), (2, 2)),
    "all",
    True,
)

# Continual learner: model apply fn, its initial params/state, the SimpleModel
# slow/fast phase and param/state selectors, then 3e-2
# (presumably the inner-loop learning rate — confirm in meta.wrappers).
continual_learner = ContinualLearnerB(
    model.apply,
    params,
    state,
    SimpleModel.train_slow_phase,
    SimpleModel.train_fast_phase,
    SimpleModel.get_train_slow_params,
    SimpleModel.get_train_fast_params,
    SimpleModel.get_train_slow_state,
    SimpleModel.get_train_fast_state,
    3e-2,
)

# Smoke test of the inner loop on random inputs of shape (4, 2, 2).
# The labels have shape (4, 2): [[0, 0], [0, 0], [1, 1], [1, 1]].
# NOTE(review): PRNGKey(1) is reused here after the init batch above.
out = continual_learner.inner_loop(
    random.normal(PRNGKey(1), (4, 2, 2)),
    jnp.repeat(jnp.repeat(jnp.arange(2)[:, None], 2, axis=0), 2, axis=1),
)

In [119]:
# Show the shape of every leaf in the inner-loop result (fast params/state, loss, ...).
tree_shape(out)

{'fast_params': FlatMapping({
   'simple_model/fast/batch_norm': FlatMapping({'offset': (4, 1, 8), 'scale': (4, 1, 8)}),
   'simple_model/fast/linear': FlatMapping({'b': (4, 8), 'w': (4, 8, 8)}),
   'simple_model/fast_1/batch_norm': FlatMapping({'offset': (4, 1, 4), 'scale': (4, 1, 4)}),
   'simple_model/fast_1/linear': FlatMapping({'b': (4, 4), 'w': (4, 8, 4)}),
   'simple_model/fast_1/linear_1': FlatMapping({'b': (4, 2), 'w': (4, 4, 2)}),
 }),
 'fast_state': FlatMapping({
   'simple_model/fast/batch_norm/~/mean_ema': FlatMapping({'average': (4, 1, 8), 'counter': (4,), 'hidden': (4, 1, 8)}),
   'simple_model/fast/batch_norm/~/var_ema': FlatMapping({'average': (4, 1, 8), 'counter': (4,), 'hidden': (4, 1, 8)}),
   'simple_model/fast_1/batch_norm/~/mean_ema': FlatMapping({'average': (4, 1, 4), 'counter': (4,), 'hidden': (4, 1, 4)}),
   'simple_model/fast_1/batch_norm/~/var_ema': FlatMapping({'average': (4, 1, 4), 'counter': (4,), 'hidden': (4, 1, 4)}),
 }),
 'loss': (4, 2),
 'loss_aux': 

In [142]:
class MetaTrainerB(MetaLearnerBaseB):
    """Outer-loop trainer that evaluates a wrapped learner's inner-loop adaptation.

    The slow part of the model is applied once to the query inputs. The fast
    part is then evaluated per task (via ``jax.vmap``) twice: once with the
    default fast weights (``fast_params=None``) and once with the fast
    params/state returned by ``learner.inner_loop`` on the support set.

    Args:
        learner: object exposing ``inner_loop``. Here it is called as
            ``inner_loop(x_spt, y_spt, params, state, rng, training=..., ...)``
            and must return a dict with at least ``"fast_params"`` and
            ``"fast_state"``.
        *args, **kwargs: forwarded unchanged to ``MetaLearnerBaseB.__init__``.
    """

    def __init__(
        self,
        learner,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.learner = learner

    def batch_fast_apply_and_loss(
        self,
        x,
        y,
        rng=None,
        params=None,
        state=None,
        fast_params=None,
        fast_state=None,
        training=None,
        loss_fn=None,
    ):
        """Apply ``fast_apply_and_loss`` independently to every task in a batch.

        ``x``, ``y`` and, when given, ``rng`` / ``fast_params`` / ``fast_state``
        are mapped over their leading axis, so they must share the same leading
        (task) dimension. ``params``, ``state``, ``training`` and ``loss_fn`` are
        closed over and shared by every task.

        Args:
            x: per-task inputs to the fast part; leading axis indexes tasks.
            y: per-task targets; leading axis indexes tasks.
            rng: optional PRNG key, split into one key per task. The task count
                is the leading dimension of the first leaf of ``y``.
            params, state: shared parameters/state forwarded to
                ``fast_apply_and_loss``.
            fast_params, fast_state: optional per-task fast weights/state. When
                ``None``, ``fast_apply_and_loss`` receives ``None`` and applies
                its own default (TODO confirm in ``MetaLearnerBaseB``).
            training: training flag forwarded to ``fast_apply_and_loss``.
            loss_fn: optional loss override forwarded to ``fast_apply_and_loss``.

        Returns:
            dict with keys ``loss``, ``fast_state``, ``loss_aux`` and ``outputs``.
            Each value carries a leading task axis.
        """

        def f(_x, _y, _fast_params=None, _fast_state=None, _rng=None):
            # Single-task computation. fast_apply_and_loss returns
            # (loss, (fast_state, loss_aux, outputs)).
            return self.fast_apply_and_loss(
                _x,
                _y,
                _rng,
                params,
                state,
                training,
                _fast_params,
                _fast_state,
                loss_fn,
            )

        # Only pass the per-task arguments that are present. jax.vmap maps
        # keyword arguments over axis 0, so each of these must be batched per task.
        inputs = {}
        if rng is not None:
            inputs["_rng"] = split(rng, first_leaf_shape(y)[0])
        if fast_params is not None:
            inputs["_fast_params"] = fast_params
        if fast_state is not None:
            inputs["_fast_state"] = fast_state

        loss, (outer_fast_state, loss_aux, outputs) = jax.vmap(f)(x, y, **inputs)
        return dict(loss=loss, fast_state=outer_fast_state, loss_aux=loss_aux, outputs=outputs)

    def outer_loss(
        self,
        x_spt,
        y_spt,
        x_qry,
        y_qry,
        rng=None,
        params=None,
        state=None,
        training=None,
        loss_fn=None,
        opt_state=None,
        lr=None,
        init_opt_state=None,
        opt_update=None,
    ):
        """Mean query loss after adapting the fast weights on the support set.

        Args:
            x_spt, y_spt: support inputs/targets handed to ``learner.inner_loop``.
            x_qry, y_qry: query inputs/targets used for the outer loss.
            rng: optional PRNG key, split three ways (inner loop, slow query pass,
                fast query pass).
            params, state, training, loss_fn: forwarded to the slow/fast apply
                functions and to the inner loop.
            opt_state, lr, init_opt_state, opt_update: forwarded unchanged to
                ``learner.inner_loop``.

        Returns:
            ``(loss, aux)``:
            - ``loss`` is the mean of the adapted per-task query losses.
            - ``aux`` holds ``inner_out`` (the inner-loop result), ``outer_out``
              (query results with adapted fast weights) and ``initial_outer_out``
              (query results with ``fast_params=None``).
        """
        rng_inner, rng_outer_slow, rng_outer_fast = split_rng_or_none(rng, 3)

        # The slow part is applied once to the whole query batch and reused
        # by both query evaluations below.
        slow_outputs, slow_state = self.slow_apply(
            x_qry, rng_outer_slow, params, state, training,
        )

        # Baseline: query loss before adaptation (no fast params passed).
        initial_outer_out = self.batch_fast_apply_and_loss(
            slow_outputs,
            y_qry,
            rng=rng_outer_fast,
            params=params,
            state=slow_state,
            training=training,
            loss_fn=loss_fn,
        )

        # NOTE(review): the inner loop receives `slow_state` (the state after
        # the slow pass over the query set) rather than the incoming `state`.
        # Confirm that this is intended.
        inner_out = self.learner.inner_loop(
            x_spt,
            y_spt,
            params,
            slow_state,
            rng_inner,
            training=training,
            opt_state=opt_state,
            lr=lr,
            loss_fn=loss_fn,
            init_opt_state=init_opt_state,
            opt_update=opt_update,
        )

        # Query loss with the adapted fast weights. rng_outer_fast is reused,
        # presumably so that the before/after comparison sees identical randomness.
        # All arguments are passed by keyword so they stay bound to the right
        # parameters even if the signature is reordered.
        outer_out = self.batch_fast_apply_and_loss(
            slow_outputs,
            y_qry,
            rng=rng_outer_fast,
            params=params,
            state=slow_state,
            fast_params=inner_out["fast_params"],
            fast_state=inner_out["fast_state"],
            training=training,
            loss_fn=loss_fn,
        )

        return jnp.mean(outer_out["loss"]), dict(inner_out=inner_out, outer_out=outer_out, initial_outer_out=initial_outer_out)
    
    

In [144]:
# Meta-trainer around the continual learner. The remaining positional arguments
# go to MetaLearnerBaseB: the model apply fn, its initial params/state, and the
# SimpleModel slow/fast phase and param/state selectors.
meta_trainer = MetaTrainerB(
    continual_learner,
    model.apply,
    params,
    state,
    SimpleModel.train_slow_phase,
    SimpleModel.train_fast_phase,
    SimpleModel.get_train_slow_params,
    SimpleModel.get_train_fast_params,
    SimpleModel.get_train_slow_state,
    SimpleModel.get_train_fast_state,
)

# Smoke test of the outer loss.
# Support and query use the same key and the same labels, so they are identical.
# The adapted query loss therefore measures fitting the support data itself.
out = meta_trainer.outer_loss(
    random.normal(PRNGKey(1), (4, 2, 2)),                                    # x_spt
    jnp.repeat(jnp.repeat(jnp.arange(2)[:, None], 2, axis=0), 2, axis=1),  # y_spt
    random.normal(PRNGKey(1), (4, 2, 2)),                                    # x_qry
    jnp.repeat(jnp.repeat(jnp.arange(2)[:, None], 2, axis=0), 2, axis=1),  # y_qry
)

In [147]:
# outer_loss returns (mean_loss, aux); list the diagnostics stored in aux.
out[1].keys()

dict_keys(['inner_out', 'outer_out', 'initial_outer_out'])

In [153]:
# Per-task query metrics before adaptation (fast_params=None).
out[1]["initial_outer_out"]["loss_aux"], out[1]["initial_outer_out"]["loss"]

({'acc': DeviceArray([1., 1., 0., 0.], dtype=float32)},
 DeviceArray([0.6931472, 0.6931472, 0.6931472, 0.6931472], dtype=float32))

In [151]:
# Per-task query metrics after the inner loop adapted the fast weights.
out[1]["outer_out"]["loss_aux"], out[1]["outer_out"]["loss"]

({'acc': DeviceArray([1., 1., 1., 1.], dtype=float32)},
 DeviceArray([0.6638154, 0.6638154, 0.6638154, 0.6638154], dtype=float32))