In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import os.path as osp
import sys

# Put the sibling source directory at the front of the import path exactly
# once, so re-running this cell does not stack duplicate entries.
if sys.path.count("../src") == 0:
    sys.path.insert(0, "../src")

In [3]:
import numpy as onp

import jax
from jax import numpy as jnp, random
from jax.random import split, PRNGKey

import haiku as hk
from haiku.data_structures import merge
import optax as ox

from helpers import SimpleModel
from utils.losses import mean_xe_and_acc_dict
from datasets import get_dataset
from utils import shuffle_along_axis, use_self_as_default, split_rng_or_none, call_self_as_default, expand, tree_shape, first_leaf_shape
from utils.data.datasets import SimpleDataset, ImageDataset, MetaDataset, MetaDatasetArray

from trainers.continual_learning import ContinualLearningTrainer
from wrappers.continual_learner import ContinualLearningWrapperBase

In [4]:
# Tiny synthetic dataset: the values 0..5, each repeated 20 times (120
# entries), passed once as float32 and once as integers.
# NOTE(review): the meaning of the remaining positional arguments of
# ImageDataset / MetaDatasetArray is not visible in this notebook -- check the
# class signatures before changing them.
did = ImageDataset(
    onp.repeat(onp.arange(6), 20).astype(jnp.float32),
    onp.repeat(onp.arange(6), 20),
    1,
    0.1,
)
md = MetaDatasetArray(did, 2, 3, 4, 5, 25, False, False)

In [5]:
# Load the "train" split of Omniglot with image_size=28.
trainset = get_dataset(
    "omniglot",
    "train",
    all=True,
    train=True,
    image_size=28,
)

In [6]:
# Wrap the Omniglot training set as a meta-dataset and create a sampler from
# it, seeded with a fixed PRNG key so sampling is reproducible.
# NOTE(review): the positional settings (2, 5, 5, 15, 64) are defined by
# MetaDatasetArray's signature, which is not visible here.
meta_trainset = MetaDatasetArray(
    trainset,
    2,
    5,
    5,
    15,
    64,
)
meta_trainset_sampler = meta_trainset.get_sampler(
    PRNGKey(0),
)

In [7]:
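# NOTE: disabled experiment kept for reference only. `BaseMetaLearner` is not
# imported in any of the cells shown in this notebook, so this code would not
# run as-is if uncommented.
# model = hk.transform_with_state(lambda x, phase, training: SimpleModel()(x, phase, training))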
# params, state = model.init(PRNGKey(0), random.normal(PRNGKey(1), (2, 2)), "all", True)

# meta_learner = BaseMetaLearner(model.apply, SimpleModel, state=state, bmap=True)
# algo = meta_learner.inner_loop(
#     random.normal(PRNGKey(0), (2, 2, 2)),
#     random.randint(PRNGKey(0), (2, 2), 0, 2),
#     params,
#     # state,
#     # expand(state, 2),
#     rng=None,
#     lr=-3e-2,
#     # training=False,
# )

In [69]:
class MyContinualLearnerWrapper(ContinualLearningWrapperBase):
    """Continual-learning wrapper whose inner loop adapts the fast weights step by step.

    The slow part of the model is applied once to the whole support set; the
    fast part is then updated sequentially along the trajectory axis with
    ``jax.lax.scan``, one ``single_fast_step`` call per trajectory position.
    """

    @use_self_as_default("init_opt_state")
    def inner_loop(
        self,
        x_spt,
        y_spt,
        params=None,
        state=None,
        rng=None,
        fast_params=None,
        fast_state=None,
        training=None,
        opt_state=None,
        lr=None,
        loss_fn=None,
        init_opt_state=None,
        opt_update=None,
        counter=None,
    ):
        """Run the fast-weight updates over every step of each support trajectory.

        Args:
            x_spt: Support inputs (pytree). The first two axes of its first
                leaf are read as ``(batch, trajectory_length)``.
            y_spt: Support targets, laid out with the same two leading axes.
            params: Model parameters; ``self.params`` is used when falsy.
            state: Model state; ``self.state`` is used when falsy.
            rng: Optional PRNG key. It is split into one key for the slow
                forward pass and one for the fast steps; the latter is split
                again into one key per trajectory step.
            fast_params: Initial fast parameters. Defaults to the fast subset
                of the parameters, expanded to the batch size.
            fast_state: Initial fast state. Defaults to the fast subset of the
                state merged with the slow pass's state, expanded to the batch
                size.
            training: Forwarded to the slow apply and to every fast step.
            opt_state: Initial optimizer state. Defaults to
                ``init_opt_state(fast_params)``.
            lr: Forwarded to ``single_fast_step``.
            loss_fn: Forwarded to ``single_fast_step``.
            init_opt_state: Optimizer-state initializer. Presumably filled in
                from ``self`` by the ``use_self_as_default`` decorator when
                None -- confirm against the decorator.
            opt_update: Forwarded to ``single_fast_step``.
            counter: Accepted for interface compatibility; not used here.

        Returns:
            ``(final_carry, ys)`` where ``final_carry`` is the
            ``(fast_params, fast_state, opt_state)`` tuple after the last step
            and ``ys`` is ``(loss, loss_aux, outputs)`` stacked over steps,
            with the first two axes swapped back after the scan.
        """
        bsz, traj_length = first_leaf_shape(x_spt)[:2]
        # Independent keys for the slow forward pass and for the fast steps.
        rng_slow, rng_fast = split_rng_or_none(rng)
        slow_outputs, slow_state = self.bmap_slow_apply(
            x_spt,
            rng=rng_slow,
            params=params,
            state=state,
            training=training,
        )

        if fast_params is None:
            fast_params = expand(self.get_fp(params or self.params), bsz)
        if fast_state is None:
            fast_state = expand(self.get_fs(merge(state or self.state, slow_state)), bsz)
        if opt_state is None:
            opt_state = init_opt_state(fast_params)

        def swap_leading_axes(tree):
            # Exchange axes 0 and 1 of every leaf, leaving the rest in place.
            return jax.tree_map(
                lambda leaf: jnp.transpose(leaf, (1, 0, *range(2, leaf.ndim))),
                tree,
            )

        def scan_fun(carry, xs):
            (fast_params, fast_state, opt_state) = carry

            # Whether a key is scanned over is a Python-level decision made at
            # trace time, so the structure of `xs` is static.
            if rng_fast is None:
                _rng = None
                x, y = xs
            else:
                x, y, _rng = xs
            (
                fast_params,
                new_state,
                opt_state,
                (loss, loss_aux, outputs),
            ) = self.single_fast_step(
                x,
                y,
                rng=_rng,
                params=params,
                state=state,
                fast_params=fast_params,
                fast_state=fast_state,
                lr=lr,
                training=training,
                loss_fn=loss_fn,
                opt_update=opt_update,
                init_opt_state=init_opt_state,
                opt_state=opt_state,
            )
            # Only the fast subset of the returned state is carried forward.
            out_carry = (fast_params, self.get_fs(new_state), opt_state)

            return out_carry, (loss, loss_aux, outputs)

        # Transpose inputs so that the scan runs over the trajectory axis.
        scan_inputs = swap_leading_axes([slow_outputs, y_spt])
        scan_inputs = expand(scan_inputs, 1, 2)
        if rng_fast is not None:
            # Split the fast-phase key, not the parent `rng`: `rng` was already
            # consumed by the split above, and reusing it would correlate the
            # per-step randomness with the slow pass.
            scan_inputs.append(split(rng_fast, traj_length))

        final_carry, ys = jax.lax.scan(scan_fun, (fast_params, fast_state, opt_state), scan_inputs)
        return final_carry, swap_leading_axes(ys)

In [78]:
def _forward(x, phase, training):
    """Haiku forward function: instantiate SimpleModel and apply it."""
    return SimpleModel()(x, phase, training)


# Build the transformed model and initialise it on a (2, 2) random input.
model = hk.transform_with_state(_forward)
params, state = model.init(
    PRNGKey(0),
    random.normal(PRNGKey(1), (2, 2)),
    "all",
    True,
)

# Wire the model and SimpleModel's phase / param / state selectors into the
# wrapper defined in this notebook.
continual_learner = MyContinualLearnerWrapper(
    model.apply,
    params,
    state,
    SimpleModel.train_slow_phase,
    SimpleModel.train_fast_phase,
    SimpleModel.get_train_slow_params,
    SimpleModel.get_train_fast_params,
    SimpleModel.get_train_slow_state,
    SimpleModel.get_train_fast_state,
    3e-2,
)

# Smoke test: inputs of shape (4, 2, 2), i.e. a batch of 4 trajectories of
# length 2, with a (4, 2) label array (rows 0-1 labelled 0, rows 2-3 labelled 1).
final_carry, ys = continual_learner.inner_loop(
    random.normal(PRNGKey(1), (4, 2, 2)),
    jnp.repeat(jnp.repeat(jnp.arange(2)[:, None], 2, 0), 2, 1),
)

[((2, 4, 1, 8),), (2, 4, 1)]
((2, 4), {'acc': (2, 4)}, (2, 4, 1, 2))


In [74]:
# First element of the carry returned by inner_loop, i.e. the fast parameters
# after the last scan step.
final_carry[0]

FlatMapping({
  'simple_model/fast/batch_norm': FlatMapping({
                                    'offset': DeviceArray([[[0., 0., 0., 0., 0., 0., 0., 0.]],
                                              
                                                           [[0., 0., 0., 0., 0., 0., 0., 0.]],
                                              
                                                           [[0., 0., 0., 0., 0., 0., 0., 0.]],
                                              
                                                           [[0., 0., 0., 0., 0., 0., 0., 0.]]], dtype=float32),
                                    'scale': DeviceArray([[[1., 1., 1., 1., 1., 1., 1., 1.]],
                                             
                                                          [[1., 1., 1., 1., 1., 1., 1., 1.]],
                                             
                                                          [[1., 1., 1., 1., 1., 1., 1., 1.]],
                          

In [44]:
# Same expression as the label argument passed to inner_loop; displayed to
# check its (4, 2) layout.
jnp.arange(2)[:, None].repeat(2, 0).repeat(2, 1)

DeviceArray([[0, 0],
             [0, 0],
             [1, 1],
             [1, 1]], dtype=int32)

In [30]:
# Show the shape of every leaf in the carry returned by inner_loop.
# NOTE(review): execution counts in this notebook are out of order, so the
# stored output may come from an earlier version of inner_loop -- re-run to confirm.
tree_shape(final_carry)

((FlatMapping({
    'simple_model/fast/batch_norm': FlatMapping({'offset': (4, 1, 8), 'scale': (4, 1, 8)}),
    'simple_model/fast/linear': FlatMapping({'b': (4, 8), 'w': (4, 8, 8)}),
    'simple_model/fast_1/batch_norm': FlatMapping({'offset': (4, 1, 4), 'scale': (4, 1, 4)}),
    'simple_model/fast_1/linear': FlatMapping({'b': (4, 4), 'w': (4, 8, 4)}),
    'simple_model/fast_1/linear_1': FlatMapping({'b': (4, 2), 'w': (4, 4, 2)}),
  }),
  FlatMapping({
    'simple_model/fast/batch_norm/~/mean_ema': FlatMapping({'average': (4, 1, 8), 'counter': (4,), 'hidden': (4, 1, 8)}),
    'simple_model/fast/batch_norm/~/var_ema': FlatMapping({'average': (4, 1, 8), 'counter': (4,), 'hidden': (4, 1, 8)}),
    'simple_model/fast_1/batch_norm/~/mean_ema': FlatMapping({'average': (4, 1, 4), 'counter': (4,), 'hidden': (4, 1, 4)}),
    'simple_model/fast_1/batch_norm/~/var_ema': FlatMapping({'average': (4, 1, 4), 'counter': (4,), 'hidden': (4, 1, 4)}),
  }),
  [IdentityState(), ScaleState()]),)

In [79]:
# Per-step outputs returned by inner_loop: (loss, loss_aux, outputs).
ys

(DeviceArray([[0.6931472, 0.6782596],
              [0.6931472, 0.6782596],
              [0.6931472, 0.6782596],
              [0.6931472, 0.6782596]], dtype=float32),
 {'acc': DeviceArray([[1., 1.],
               [1., 1.],
               [0., 1.],
               [0., 1.]], dtype=float32)},
 DeviceArray([[[[ 0.   ,  0.   ]],
 
               [[ 0.015, -0.015]]],
 
 
              [[[ 0.   ,  0.   ]],
 
               [[ 0.015, -0.015]]],
 
 
              [[[ 0.   ,  0.   ]],
 
               [[-0.015,  0.015]]],
 
 
              [[[ 0.   ,  0.   ]],
 
               [[-0.015,  0.015]]]], dtype=float32))

In [72]:
# Show the shape of every leaf in the per-step outputs returned by inner_loop.
tree_shape(ys)

((4, 2), {'acc': (4, 2)}, (4, 2, 1, 2))

In [16]:
# Sanity check: flattening a bare array yields a leaf list whose first entry
# has the array's own shape.
jax.tree_flatten(random.normal(PRNGKey(1), (2, 2)))[0][0].shape

(2, 2)

In [13]:
# Sanity check: tree_shape applied to a bare array rather than a nested pytree.
tree_shape(random.normal(PRNGKey(1), (2, 2)))

(2, 2)

In [95]:
# Build a trainer and run its inner loop on random (2, 2, 2) inputs with
# random integer labels in [0, 2).
# NOTE(review): `meta_learner` is not defined in any live cell visible in this
# notebook (it only appears in the commented-out cell above), so this cell
# presumably raises NameError after Restart & Run All -- confirm which learner
# object is intended here.
trainer  = ContinualLearningTrainer(params, state, meta_learner)
trainer.inner_loop(
   random.normal(PRNGKey(0), (2, 2, 2)), random.randint(PRNGKey(0), (2, 2), 0, 2), params, state, None, 3e-2,
)

({'counter': DeviceArray([2., 2.], dtype=float32),
  'fast_params': FlatMapping({
    'simple_model/fast/batch_norm': FlatMapping({
                                      'offset': DeviceArray([[[0., 0., 0., 0., 0., 0., 0., 0.]],
                                                
                                                             [[0., 0., 0., 0., 0., 0., 0., 0.]]], dtype=float32),
                                      'scale': DeviceArray([[[1., 1., 1., 1., 1., 1., 1., 1.]],
                                               
                                                            [[1., 1., 1., 1., 1., 1., 1., 1.]]], dtype=float32),
                                    }),
    'simple_model/fast/linear': FlatMapping({
                                  'b': DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0.],
                                                    [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
                                  'w': DeviceArray([[[ 0.56481254,  0.408743