In [1]:
%load_ext autoreload
# Re-import edited modules (e.g. the project code under ../src) before each cell runs.
%autoreload 2

In [2]:
import os
import os.path as osp
import sys

SRC_DIR = "../src"  # relative to the notebook's working directory
if SRC_DIR not in sys.path:
    # Prepend so the project modules imported below (utils, datasets) resolve
    # from SRC_DIR first; the guard keeps re-runs from stacking duplicates.
    sys.path.insert(0, SRC_DIR)

In [3]:
import numpy as onp

import jax
from jax import numpy as jnp, random
from jax.random import split, PRNGKey

import haiku as hk
from haiku.data_structures import merge
import optax as ox

from utils.losses import mean_xe_and_acc_dict
from datasets import get_dataset
from utils import shuffle_along_axis, use_self_as_default, split_rng_or_none
from utils.data.datasets import SimpleDataset, ImageDataset, MetaDataset, MetaDatasetArray

In [4]:
# Toy sanity check: 6 classes x 20 samples each, where every "image" is just
# its class id as float32, so the sampled arrays can be read by eye.
toy_labels = onp.arange(6).repeat(20)
did = ImageDataset(toy_labels.astype(onp.float32), toy_labels, 1, 0.1)
# Reading the printed output: 2 -> leading axis of every array, 3 -> distinct
# classes along axis 1 of the first two (x, y) pairs, 4 / 5 -> samples per class
# in the first / second pair, 25 -> length of the third pair. The two False
# flags are not identifiable from the output -- TODO confirm against
# MetaDatasetArray's signature.
md = MetaDatasetArray(did, 2, 3, 4, 5, 25, False, False)
sampler = md.get_sampler(PRNGKey(1))
next(sampler)

(array([[[0., 0., 0., 0.],
         [4., 4., 4., 4.],
         [3., 3., 3., 3.]],
 
        [[3., 3., 3., 3.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.]]], dtype=float32),
 array([[[0, 0, 0, 0],
         [4, 4, 4, 4],
         [3, 3, 3, 3]],
 
        [[3, 3, 3, 3],
         [1, 1, 1, 1],
         [0, 0, 0, 0]]]),
 array([[[0., 0., 0., 0., 0.],
         [4., 4., 4., 4., 4.],
         [3., 3., 3., 3., 3.]],
 
        [[3., 3., 3., 3., 3.],
         [1., 1., 1., 1., 1.],
         [0., 0., 0., 0., 0.]]], dtype=float32),
 array([[[0, 0, 0, 0, 0],
         [4, 4, 4, 4, 4],
         [3, 3, 3, 3, 3]],
 
        [[3, 3, 3, 3, 3],
         [1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0]]]),
 array([[2., 0., 2., 1., 1., 0., 2., 1., 2., 2., 5., 3., 2., 0., 4., 3.,
         1., 4., 2., 5., 2., 5., 1., 0., 3.],
        [4., 2., 2., 1., 4., 1., 2., 3., 1., 0., 4., 4., 3., 3., 5., 0.,
         1., 0., 2., 0., 1., 5., 0., 1., 2.]], dtype=float32),
 array([[2, 0, 2, 1, 1, 0, 2, 1, 2, 2, 5, 3, 2, 0

In [5]:
# Omniglot "train" split loaded through the project's dataset registry.
trainset = get_dataset(
    "omniglot",
    "train",
    all=True,
    train=True,
    image_size=28,
)

In [6]:
# Meaning of MetaDatasetArray's positional args inferred from the shapes printed
# by the toy example in In [4] -- TODO confirm against its signature.
META_BATCH_SIZE = 2  # leading axis of every sampled array
N_WAYS = 5  # distinct classes per task
N_SHOTS_SPT = 5  # samples per class in the first (x, y) pair
N_SHOTS_QRY = 15  # samples per class in the second (x, y) pair
N_EXTRA = 64  # length of the third (x, y) pair

meta_trainset = MetaDatasetArray(
    trainset, META_BATCH_SIZE, N_WAYS, N_SHOTS_SPT, N_SHOTS_QRY, N_EXTRA
)
meta_trainset_sampler = meta_trainset.get_sampler(PRNGKey(0))

In [None]:
class BaseContinualLearner:
    """Slow/fast-phase wrapper around a model's ``apply`` function.

    ``apply`` is called as ``apply(params, state, rng, x, phase=..., training=...)``
    and must return ``(output, new_state)``. This is presumably the ``apply`` of an
    ``hk.transform_with_state`` model (TODO confirm).

    ``model_class`` must provide the phase values ``train_slow_phase``,
    ``train_fast_phase``, ``test_slow_phase`` and ``test_fast_phase``. It must
    also provide the splitters ``get_{train,test}_{slow,fast}_{params,state}``
    used below.

    Methods decorated with ``use_self_as_default(...)`` are assumed to fill the
    listed keyword arguments from the same-named instance attribute when the
    caller leaves them ``None`` (project helper -- verify in ``utils``). Every
    listed name must therefore exist as an attribute.
    """

    def __init__(
        self,
        apply,
        model_class,
        params=None,
        state=None,
        lr=None,
        training=True,
        init_train_inner_opt_state=ox.sgd(0).init,
        init_test_inner_opt_state=ox.adam(0).init,
        loss_fn=mean_xe_and_acc_dict,
        bmap=True,  # Use hk.BatchApply instead of jax.vmap
        rng=None,
    ):
        """Store the model, its default params/state and inner-loop settings.

        Args:
            apply: model apply function (see class docstring for the call shape).
            model_class: provider of phases and param/state splitters.
            params, state, rng: defaults used when a method gets ``None``.
            lr: default inner-loop learning rate.
            training: default mode; selects train vs. test phases/splitters.
            init_train_inner_opt_state / init_test_inner_opt_state: optimizer
                ``init`` functions for the inner loop in each mode.
            loss_fn: ``loss_fn(output, y) -> (loss, aux)``.
            bmap: stored flag (see inline comment); not used in this class yet.
        """
        self.apply = apply
        self.model_class = model_class
        self.params = params
        self.state = state
        self.lr = lr
        self.training = training
        self.init_train_inner_opt_state = init_train_inner_opt_state
        self.init_test_inner_opt_state = init_test_inner_opt_state
        # Both are requested via use_self_as_default below, so they must exist.
        self.loss_fn = loss_fn
        self.rng = rng
        self.bmap = bmap

        self.train_slow_phase = model_class.train_slow_phase
        self.train_fast_phase = model_class.train_fast_phase
        self.test_slow_phase = model_class.test_slow_phase
        self.test_fast_phase = model_class.test_fast_phase

        self.get_train_sp = model_class.get_train_slow_params
        self.get_train_fp = model_class.get_train_fast_params
        self.get_test_sp = model_class.get_test_slow_params
        self.get_test_fp = model_class.get_test_fast_params

        self.get_train_ss = model_class.get_train_slow_state
        self.get_train_fs = model_class.get_train_fast_state
        self.get_test_ss = model_class.get_test_slow_state
        self.get_test_fs = model_class.get_test_fast_state

    def init_inner_opt_state(self, training):
        """Return the inner optimizer's *init function* for the given mode."""
        if training:
            return self.init_train_inner_opt_state
        return self.init_test_inner_opt_state

    @use_self_as_default("training")
    def get_phases(self, training=None):
        """Return ``(slow_phase, fast_phase)`` for train or test mode."""
        if training:
            return self.train_slow_phase, self.train_fast_phase
        return self.test_slow_phase, self.test_fast_phase

    @use_self_as_default("params", "state", "rng", "training", "loss_fn")
    def fast_apply_and_loss(
        self,
        x,
        y,
        params=None,
        state=None,
        rng=None,
        training=None,
        loss_fn=None,
    ):
        """Run the model in its fast phase on ``x`` and score it against ``y``.

        Returns:
            ``(loss, (new_state, loss_aux, output))`` -- the ``(value, aux)``
            layout expected by ``jax.value_and_grad(..., has_aux=True)``.
        """
        phase = self.get_phases(training)[1]
        output, new_state = self.apply(
            params, state, rng, x, phase=phase, training=training,
        )
        loss, loss_aux = loss_fn(output, y)
        return loss, (new_state, loss_aux, output)

    @use_self_as_default("params", "training")
    def single_fast_step(
        self,
        x,
        y,
        params=None,
        state=None,
        rng=None,
        training=None,
        loss_fn=None,
        init_inner_opt_state=None,
        opt_state=None,
    ):
        """Loss and gradient w.r.t. ``params`` for one batch in the fast phase.

        Work in progress: no optimizer update is applied yet, because the class
        only stores the inner optimizer's ``init`` function.

        Returns:
            ``((loss, (new_state, loss_aux, output)), grads, opt_state)``.
            ``grads`` mirrors ``params``. ``opt_state`` is the one passed in,
            or a fresh ``init_inner_opt_state(params)``.
        """
        if opt_state is None:
            if init_inner_opt_state is None:
                # Resolve explicitly: the same-named method returns the init
                # *function* for the current mode, not an optimizer state.
                init_inner_opt_state = self.init_inner_opt_state(training)
            opt_state = init_inner_opt_state(params)

        # argnums=2 -> ``params`` in fast_apply_and_loss's positional order;
        # has_aux because it returns ``(loss, aux)``.
        (loss, (new_state, loss_aux, output)), grads = jax.value_and_grad(
            self.fast_apply_and_loss, argnums=2, has_aux=True
        )(x, y, params, state, rng, training, loss_fn)
        return (loss, (new_state, loss_aux, output)), grads, opt_state

    @use_self_as_default("params", "state", "training", "lr", "loss_fn")
    def inner_loop(
        self,
        x_spt,
        y_spt,
        params=None,
        state=None,
        rng=None,
        training=None,
        opt_state=None,
        lr=None,
        init_inner_opt_state=None,
        loss_fn=None,
    ):
        """Adapt the fast parameters on the support set (UNFINISHED).

        ``x_spt`` / ``y_spt`` are presumably sequences of support batches: one
        rng is split per element of ``x_spt`` (TODO confirm).

        Raises:
            NotImplementedError: the per-batch update loop is not written yet.
            Previously this method silently returned ``None``.
        """
        slow_params, fast_params = self.split_params(params, training)
        slow_state, fast_state = self.split_state(state, training)
        _, fast_phase = self.get_phases(training)
        rng, *rngs = split_rng_or_none(rng, 1 + len(x_spt))

        def fast_apply_and_loss(fast_params, fast_state, rng, x, y):
            """Loss of the merged slow+fast model, differentiable in ``fast_params``."""
            output, new_state = self.apply(
                merge(slow_params, fast_params),
                merge(slow_state, fast_state),
                rng,
                x,
                phase=fast_phase,
                training=training,
            )
            loss, loss_aux = loss_fn(output, y)
            return loss, (new_state, loss_aux, output)

        # TODO: for each (x, y, rng) in zip(x_spt, y_spt, rngs), take
        # jax.value_and_grad(fast_apply_and_loss, has_aux=True) and apply an
        # inner optimizer update (needs the optimizer's update fn, not only init).
        raise NotImplementedError(
            "BaseContinualLearner.inner_loop is unfinished: the inner optimizer "
            "update over the support batches is not implemented yet."
        )

    def split_params_train(self, params):
        """Return ``(slow_params, fast_params)`` using the train-mode splitters."""
        return self.get_train_sp(params), self.get_train_fp(params)

    def split_params_test(self, params):
        """Return ``(slow_params, fast_params)`` using the test-mode splitters."""
        return self.get_test_sp(params), self.get_test_fp(params)

    def split_state_train(self, state):
        """Return ``(slow_state, fast_state)`` using the train-mode splitters."""
        return self.get_train_ss(state), self.get_train_fs(state)

    def split_state_test(self, state):
        """Return ``(slow_state, fast_state)`` using the test-mode splitters."""
        return self.get_test_ss(state), self.get_test_fs(state)

    def split_params(self, params, training):
        """Split ``params`` into ``(slow, fast)`` for the given mode."""
        if training:
            return self.split_params_train(params)
        return self.split_params_test(params)

    def split_state(self, state, training):
        """Split ``state`` into ``(slow, fast)`` for the given mode."""
        if training:
            return self.split_state_train(state)
        # Was split_params_test(state): applied the *param* splitters to state.
        return self.split_state_test(state)


#     def __call__(self):
#         raise NotImplementedError

#     def get_train_slow_params(self):
#         raise NotImplementedError

#     def get_train_fast_params(self):
#         raise NotImplementedError

#     def get_train_slow_state(self):
#         raise NotImplementedError

#     def get_train_fast_state(self):
#         raise NotImplementedError

#     def get_test_slow_params(self):
#         raise NotImplementedError

#     def get_test_fast_params(self):
#         raise NotImplementedError

#     def get_test_slow_state(self):
#         raise NotImplementedError

#     def get_test_fast_state(self):
#         raise NotImplementedError

array([  0,   0,   0, ..., 663, 663, 663])