In [1]:
# Reload imported project modules automatically before each cell execution,
# so edits made under ../src are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import os.path as osp
import sys
from functools import partial

# Put the project sources at the front of the import path (only once, so that
# re-running this cell does not keep prepending the same entry).
if "../src" not in sys.path:
    sys.path.insert(0, "../src")

In [3]:
import numpy as onp

import jax
from jax import numpy as jnp, random
from jax.random import split, PRNGKey

import haiku as hk
from haiku.data_structures import merge
import optax as ox

from utils.losses import mean_xe_and_acc_dict
from datasets import get_dataset
from utils import shuffle_along_axis, use_self_as_default, split_rng_or_none, call_self_as_default, expand
from utils.data.datasets import SimpleDataset, ImageDataset, MetaDataset, MetaDatasetArray

In [4]:
# Toy dataset: the values 0..5, each repeated 20 times, used both as float32
# inputs and as integer targets, so every sample's input equals its label.
# NOTE(review): the meaning of the trailing `1, 0.1` arguments is defined by
# ImageDataset and is not visible in this notebook -- confirm.
did = ImageDataset(onp.arange(6).repeat(20).astype(jnp.float32), onp.arange(6).repeat(20), 1, 0.1)
# Per the recorded output of the next cells, the positional arguments appear to
# map to batch_size=2, way=3, shot=4, qry_shot=5, cl_qry_shot=25,
# disjoint=False, disjoint_cl_qry=False -- TODO confirm against MetaDatasetArray.
md = MetaDatasetArray(
    did,
    2, 3, 4, 5, 25, False, False,
)

In [5]:
# Draw a single meta-batch, overriding the batch size with 1 (the second
# positional argument of `sample` is `batch_size` per the printed signature).
md.sample(PRNGKey(1), 1)

<function MetaDataset.sample at 0x7eff052e9b90>
('self', 'rng', 'batch_size', 'way', 'shot', 'qry_shot', 'cl_qry_shot', 'disjoint', 'disjoint_cl_qry', 'rng_classes', 'rng_samples', 'rng_cl', 'sampled_classes', 'spt_indexes', 'qry_indexes', 'cl_qry_indexes', 'spt_dataset_indexes', 'qry_dataset_indexes', 'spt_inputs', 'spt_targets', 'qry_inputs', 'qry_targets', 'cl_qry_inputs', 'cl_qry_targets')
2
batch_size 2
way 3
getting way
shot 4
getting shot
qry_shot 5
getting qry_shot
cl_qry_shot 6
getting cl_qry_shot
disjoint 7
getting disjoint
disjoint_cl_qry 8
getting disjoint_cl_qry
{'way': 3, 'shot': 4, 'qry_shot': 5, 'cl_qry_shot': 25, 'disjoint': False, 'disjoint_cl_qry': False}


(array([[[2., 2., 2., 2.],
         [4., 4., 4., 4.],
         [0., 0., 0., 0.]]], dtype=float32),
 array([[[2, 2, 2, 2],
         [4, 4, 4, 4],
         [0, 0, 0, 0]]]),
 array([[[2., 2., 2., 2., 2.],
         [4., 4., 4., 4., 4.],
         [0., 0., 0., 0., 0.]]], dtype=float32),
 array([[[2, 2, 2, 2, 2],
         [4, 4, 4, 4, 4],
         [0, 0, 0, 0, 0]]]),
 array([[5., 4., 2., 0., 0., 5., 5., 4., 4., 1., 4., 2., 2., 1., 3., 5.,
         3., 4., 1., 3., 4., 4., 3., 5., 3.]], dtype=float32),
 array([[5, 4, 2, 0, 0, 5, 5, 4, 4, 1, 4, 2, 2, 1, 3, 5, 3, 4, 1, 3, 4, 4,
         3, 5, 3]]))

In [6]:
# Build a sampler seeded with a fixed key and pull its first meta-batch; here
# the dataset's own default batch size (2 in the recorded output) is used.
sampler = md.get_sampler(PRNGKey(1))
next(sampler)

<function MetaDataset.sample at 0x7eff052e9b90>
('self', 'rng', 'batch_size', 'way', 'shot', 'qry_shot', 'cl_qry_shot', 'disjoint', 'disjoint_cl_qry', 'rng_classes', 'rng_samples', 'rng_cl', 'sampled_classes', 'spt_indexes', 'qry_indexes', 'cl_qry_indexes', 'spt_dataset_indexes', 'qry_dataset_indexes', 'spt_inputs', 'spt_targets', 'qry_inputs', 'qry_targets', 'cl_qry_inputs', 'cl_qry_targets')
1
batch_size 2
getting batch_size
way 3
getting way
shot 4
getting shot
qry_shot 5
getting qry_shot
cl_qry_shot 6
getting cl_qry_shot
disjoint 7
getting disjoint
disjoint_cl_qry 8
getting disjoint_cl_qry
{'batch_size': 2, 'way': 3, 'shot': 4, 'qry_shot': 5, 'cl_qry_shot': 25, 'disjoint': False, 'disjoint_cl_qry': False}


(array([[[0., 0., 0., 0.],
         [4., 4., 4., 4.],
         [3., 3., 3., 3.]],
 
        [[3., 3., 3., 3.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.]]], dtype=float32),
 array([[[0, 0, 0, 0],
         [4, 4, 4, 4],
         [3, 3, 3, 3]],
 
        [[3, 3, 3, 3],
         [1, 1, 1, 1],
         [0, 0, 0, 0]]]),
 array([[[0., 0., 0., 0., 0.],
         [4., 4., 4., 4., 4.],
         [3., 3., 3., 3., 3.]],
 
        [[3., 3., 3., 3., 3.],
         [1., 1., 1., 1., 1.],
         [0., 0., 0., 0., 0.]]], dtype=float32),
 array([[[0, 0, 0, 0, 0],
         [4, 4, 4, 4, 4],
         [3, 3, 3, 3, 3]],
 
        [[3, 3, 3, 3, 3],
         [1, 1, 1, 1, 1],
         [0, 0, 0, 0, 0]]]),
 array([[2., 0., 2., 1., 1., 0., 2., 1., 2., 2., 5., 3., 2., 0., 4., 3.,
         1., 4., 2., 5., 2., 5., 1., 0., 3.],
        [4., 2., 2., 1., 4., 1., 2., 3., 1., 0., 4., 4., 3., 3., 5., 0.,
         1., 0., 2., 0., 1., 5., 0., 1., 2.]], dtype=float32),
 array([[2, 0, 2, 1, 1, 0, 2, 1, 2, 2, 5, 3, 2, 0

In [7]:
# Load the Omniglot "train" split at 28x28 through the project's dataset helper.
# NOTE(review): the exact effect of `all=True` / `train=True` is defined in
# `datasets.get_dataset`, which is not visible here -- confirm.
trainset = get_dataset("omniglot", "train", all=True, train=True, image_size=28)

In [8]:
# Wrap the training set for episodic sampling.
# Presumably batch_size=2, way=5, shot=5, qry_shot=15, cl_qry_shot=64 (same
# positional order as the toy example above) -- TODO confirm.
meta_trainset = MetaDatasetArray(trainset, 2, 5, 5, 15, 64)
# Fixed seed so that the sequence of sampled episodes is reproducible.
meta_trainset_sampler = meta_trainset.get_sampler(PRNGKey(0))

In [9]:
def filter_structure_by_str(instr, structure):
    """Keep the entries of ``structure`` whose module name contains ``instr``.

    Delegates to ``hk.data_structures.filter`` with a predicate that does a
    plain substring test on the module name; the parameter name and value are
    not inspected.

    Args:
        instr: Substring to look for in each module name.
        structure: Haiku params/state mapping to filter.

    Returns:
        Whatever ``hk.data_structures.filter`` returns for the matching entries.
    """

    def module_name_matches(module_name, name, value):
        return instr in module_name

    return hk.data_structures.filter(module_name_matches, structure)

class SimpleModel(hk.Module):
    """Toy network split into a "slow" stem and a "fast"/"slow"/"fast" head.

    The ``phase`` argument of ``__call__`` selects which part runs:

    * a slow phase runs only the ``slow1`` block and returns a 1-tuple ``(x,)``;
    * a fast phase expects that 1-tuple as ``inputs`` and runs the head;
    * ``"all"`` runs both on raw inputs (used to initialise every parameter).

    The ``get_*`` class attributes select the "slow" or "fast" subset of a
    params/state structure by substring match on the Haiku module name. They
    are ``functools.partial`` objects, so they are not bound as methods and are
    called as ``SimpleModel.get_train_slow_params(structure)``.
    """

    train_slow_phase = "train_slow_phase"
    train_fast_phase = "train_fast_phase"
    test_slow_phase = "test_slow_phase"
    test_fast_phase = "test_fast_phase"

    # `functools.partial` replaces the `jax.partial` alias, which newer JAX
    # releases no longer provide.
    get_train_slow_params = partial(filter_structure_by_str, "slow")
    get_train_fast_params = partial(filter_structure_by_str, "fast")
    get_train_slow_state = partial(filter_structure_by_str, "slow")
    get_train_fast_state = partial(filter_structure_by_str, "fast")

    get_test_slow_params = partial(filter_structure_by_str, "slow")
    get_test_fast_params = partial(filter_structure_by_str, "fast")
    get_test_slow_state = partial(filter_structure_by_str, "slow")
    get_test_fast_state = partial(filter_structure_by_str, "fast")

    def __call__(self, inputs, phase, training):
        """Run the part of the network selected by ``phase``.

        Args:
            inputs: Raw input array for a slow phase or ``"all"``; the 1-tuple
                produced by a slow phase when ``phase`` is a fast phase.
            phase: One of the four ``*_phase`` class constants, or ``"all"``.
            training: Forwarded to every ``hk.BatchNorm`` call as its
                ``is_training`` argument.

        Returns:
            ``(x,)`` for a slow phase, the final array for a fast phase or
            ``"all"``.

        Raises:
            ValueError: If ``phase`` is not a recognised phase identifier.
        """
        run_slow = phase in (self.train_slow_phase, self.test_slow_phase, "all")
        run_fast = phase in (self.train_fast_phase, self.test_fast_phase, "all")
        if not (run_slow or run_fast):
            # Previously an unknown phase reached `return out` with `out` unbound.
            raise ValueError(f"Unknown phase: {phase!r}")

        if run_slow:
            with hk.experimental.name_scope("slow1"):
                x = hk.Linear(8)(inputs)
                x = hk.BatchNorm(True, True, 0.99)(x, training)
            out = (x,)
        else:
            # Fast-only phase: unpack the 1-tuple produced by a slow phase.
            x, = inputs

        if run_fast:
            with hk.experimental.name_scope("fast"):
                x = hk.Linear(8)(x)
                x = hk.BatchNorm(True, True, 0.99)(x, training)

            with hk.experimental.name_scope("slow2"):
                x = hk.Linear(8)(x)
                x = hk.BatchNorm(True, True, 0.99)(x, training)

            # Re-entering the "fast" scope yields the "fast_1" module prefix
            # seen in the parameter listing; it still contains "fast".
            with hk.experimental.name_scope("fast"):
                x = hk.Linear(4)(x)
                x = hk.BatchNorm(True, True, 0.99)(x, training)
                x = hk.Linear(2)(x)

            out = x

        return out

In [10]:
# Turn SimpleModel into a pure (init, apply) pair that threads BatchNorm state.
model = hk.transform_with_state(lambda x, phase, training: SimpleModel()(x, phase, training))
# Initialise with phase "all" so that both the slow and the fast branches run
# and every parameter/state entry gets created; the input is a (2, 2) batch.
params, state = model.init(PRNGKey(0), random.normal(PRNGKey(1), (2, 2)), "all", True)

In [11]:
# Show the shape of every parameter leaf. `jax.tree_util.tree_map` is used
# because the top-level `jax.tree_map` alias is deprecated in newer JAX.
jax.tree_util.tree_map(jnp.shape, params)

FlatMapping({
  'simple_model/fast/batch_norm': FlatMapping({'offset': (1, 8), 'scale': (1, 8)}),
  'simple_model/fast/linear': FlatMapping({'b': (8,), 'w': (8, 8)}),
  'simple_model/fast_1/batch_norm': FlatMapping({'offset': (1, 4), 'scale': (1, 4)}),
  'simple_model/fast_1/linear': FlatMapping({'b': (4,), 'w': (8, 4)}),
  'simple_model/fast_1/linear_1': FlatMapping({'b': (2,), 'w': (4, 2)}),
  'simple_model/slow1/batch_norm': FlatMapping({'offset': (1, 8), 'scale': (1, 8)}),
  'simple_model/slow1/linear': FlatMapping({'b': (8,), 'w': (2, 8)}),
  'simple_model/slow2/batch_norm': FlatMapping({'offset': (1, 8), 'scale': (1, 8)}),
  'simple_model/slow2/linear': FlatMapping({'b': (8,), 'w': (8, 8)}),
})

In [12]:
# Show the shape of every state leaf (BatchNorm moving averages).
# `jax.tree_util.tree_map` is used because the top-level `jax.tree_map` alias
# is deprecated in newer JAX.
jax.tree_util.tree_map(jnp.shape, state)

FlatMapping({
  'simple_model/fast/batch_norm/~/mean_ema': FlatMapping({'average': (1, 8), 'counter': (), 'hidden': (1, 8)}),
  'simple_model/fast/batch_norm/~/var_ema': FlatMapping({'average': (1, 8), 'counter': (), 'hidden': (1, 8)}),
  'simple_model/fast_1/batch_norm/~/mean_ema': FlatMapping({'average': (1, 4), 'counter': (), 'hidden': (1, 4)}),
  'simple_model/fast_1/batch_norm/~/var_ema': FlatMapping({'average': (1, 4), 'counter': (), 'hidden': (1, 4)}),
  'simple_model/slow1/batch_norm/~/mean_ema': FlatMapping({'average': (1, 8), 'counter': (), 'hidden': (1, 8)}),
  'simple_model/slow1/batch_norm/~/var_ema': FlatMapping({'average': (1, 8), 'counter': (), 'hidden': (1, 8)}),
  'simple_model/slow2/batch_norm/~/mean_ema': FlatMapping({'average': (1, 8), 'counter': (), 'hidden': (1, 8)}),
  'simple_model/slow2/batch_norm/~/var_ema': FlatMapping({'average': (1, 8), 'counter': (), 'hidden': (1, 8)}),
})

In [44]:
def make_simple_opt_update(opt):
    """Wrap an optimizer factory into an update function that takes ``lr``.

    Args:
        opt: Callable such that ``opt(lr)`` returns an object exposing
            ``update(updates, state, params)`` (e.g. ``ox.sgd``).

    Returns:
        A function ``(lr, updates, state, params)`` that builds the optimizer
        for ``lr`` on every call and returns the result of its ``update``.
    """

    def update_with_lr(lr, updates, state, params):
        optimizer = opt(lr)
        return optimizer.update(updates, state, params)

    return update_with_lr
    

# Ready-made update functions taking the learning rate as first argument.
sgd_opt_update = make_simple_opt_update(ox.sgd)
adam_opt_update = make_simple_opt_update(ox.adam)


def make_simple_init_opt(opt):
    """Return the ``init`` function of an optimizer built by ``opt``.

    The optimizer is instantiated with a learning rate of 0 purely to reach
    its ``init`` attribute.

    Args:
        opt: Callable such that ``opt(lr)`` returns an object with an ``init``
            attribute (e.g. ``ox.sgd``).

    Returns:
        The ``init`` attribute of ``opt(0)``.
    """
    placeholder_optimizer = opt(0)
    return placeholder_optimizer.init


# Ready-made optimizer-state initialisers matching the update functions above.
sgd_init_opt = make_simple_init_opt(ox.sgd)
adam_init_opt = make_simple_init_opt(ox.adam)


class BaseMetaLearner:
    """Meta-learner that runs a model in a "slow" phase and a "fast" phase.

    ``apply`` is called as ``apply(params, state, rng, x, phase=...,
    training=...)`` and must return ``(outputs, new_state)``. ``model_class``
    supplies the four phase identifiers and the ``get_*_params`` /
    ``get_*_state`` filters that select the slow or fast subset of a structure.

    Several methods are wrapped with the project's ``use_self_as_default``.
    NOTE(review): that decorator presumably fills arguments that are omitted
    (or ``None``) from attributes of ``self`` or, for keyword entries such as
    ``fast_phase=["training"]``, by calling the method of that name with the
    listed arguments -- confirm against its implementation in ``utils``.
    """

    def __init__(
        self,
        apply,
        model_class,
        params=None,
        state=None,
        lr=None,
        training=True,
        train_init_opt=sgd_init_opt,
        test_init_opt=None,
        train_opt_update=sgd_opt_update,
        test_opt_update=None,
        loss_fn=mean_xe_and_acc_dict,
        bmap=True,  # Fold the two leading batch axes into one instead of using jax.vmap
    ):
        """Store the configuration and cache the model's phases and filters.

        Args:
            apply: Model apply function (see the class docstring).
            model_class: Class exposing ``{train,test}_{slow,fast}_phase`` and
                ``get_{train,test}_{slow,fast}_{params,state}``.
            params: Default parameters.
            state: Default state.
            lr: Default learning rate.
            training: Default mode flag.
            train_init_opt: Optimizer-state initialiser used when training.
            test_init_opt: Same for testing; defaults to ``train_init_opt``.
            train_opt_update: ``(lr, updates, state, params)`` update function
                used when training.
            test_opt_update: Same for testing; defaults to ``train_opt_update``.
            loss_fn: Called as ``loss_fn(outputs, y)``; must return
                ``(loss, aux)``.
            bmap: Selects the reshape-based path (``True``) or the ``jax.vmap``
                path (``False``) in ``slow_apply``, ``expand`` and
                ``inner_loop``.
        """
        if test_init_opt is None:
            test_init_opt = train_init_opt
        if test_opt_update is None:
            test_opt_update = train_opt_update

        self.apply = apply
        self.model_class = model_class
        self.params = params
        self.state = state
        self.lr = lr
        self.training = training
        self.loss_fn = loss_fn
        self.train_init_opt = train_init_opt
        self.test_init_opt = test_init_opt
        self.train_opt_update = train_opt_update
        self.test_opt_update = test_opt_update
        self.bmap = bmap

        self.train_slow_phase = self.model_class.train_slow_phase
        self.train_fast_phase = self.model_class.train_fast_phase
        self.test_slow_phase = self.model_class.test_slow_phase
        self.test_fast_phase = self.model_class.test_fast_phase

        self.get_train_sp = self.model_class.get_train_slow_params
        self.get_train_fp = self.model_class.get_train_fast_params
        self.get_test_sp = self.model_class.get_test_slow_params
        # Fixed: this used to be bound to get_test_slow_params (copy-paste
        # slip), so test-time "fast" params were actually the slow ones.
        self.get_test_fp = self.model_class.get_test_fast_params

        self.get_train_ss = self.model_class.get_train_slow_state
        self.get_train_fs = self.model_class.get_train_fast_state
        self.get_test_ss = self.model_class.get_test_slow_state
        self.get_test_fs = self.model_class.get_test_fast_state

    def init_opt_state(self, training):
        """Return the optimizer-state initialiser for the given mode."""
        if training:
            return self.train_init_opt
        return self.test_init_opt

    def opt_update(self, training):
        """Return the optimizer update function for the given mode."""
        if training:
            return self.train_opt_update
        return self.test_opt_update

    def get_phases(self, training=None):
        """Return ``(slow_phase, fast_phase)``; a falsy ``training`` means test."""
        if training:
            return self.train_slow_phase, self.train_fast_phase
        return self.test_slow_phase, self.test_fast_phase

    def get_sp(self, params=None, training=None):
        """Return the slow subset of ``params``; a falsy ``training`` means test."""
        if training:
            return self.get_train_sp(params)
        return self.get_test_sp(params)

    def get_fp(self, params=None, training=None):
        """Return the fast subset of ``params``; a falsy ``training`` means test."""
        if training:
            return self.get_train_fp(params)
        return self.get_test_fp(params)

    def get_ss(self, state=None, training=None):
        """Return the slow subset of ``state``; a falsy ``training`` means test."""
        if training:
            return self.get_train_ss(state)
        return self.get_test_ss(state)

    def get_fs(self, state=None, training=None):
        """Return the fast subset of ``state``; a falsy ``training`` means test."""
        if training:
            return self.get_train_fs(state)
        return self.get_test_fs(state)

    def slow_phase(self, training):
        """Return the slow-phase identifier for the given mode."""
        if training:
            return self.train_slow_phase
        return self.test_slow_phase

    def fast_phase(self, training):
        """Return the fast-phase identifier for the given mode."""
        if training:
            return self.train_fast_phase
        return self.test_fast_phase

    @use_self_as_default(
        "params", "state", "training", "loss_fn", fast_phase=["training"]
    )
    def fast_apply_and_loss(
        self,
        x,
        y,
        params=None,
        state=None,
        rng=None,
        fast_params=None,
        fast_state=None,
        training=None,
        loss_fn=None,
        fast_phase=None,
    ):
        """Run the fast phase on ``x`` and evaluate the loss against ``y``.

        ``fast_params`` / ``fast_state``, when given, are merged over
        ``params`` / ``state`` before the model is applied.

        Returns:
            ``(loss, (new_fast_state, loss_aux, outputs))`` -- the shape
            expected by ``jax.value_and_grad(..., has_aux=True)``.
        """
        if fast_params is not None:
            params = merge(params, fast_params)
        if fast_state is not None:
            state = merge(state, fast_state)
        outputs, new_state = self.apply(
            params,
            state,
            rng,
            x,
            phase=fast_phase,
            training=training,
        )
        loss, loss_aux = loss_fn(outputs, y)
        # Fixed: forward `training`; without it get_fs always fell back to the
        # test-time filter, even while training.
        return loss, (self.get_fs(new_state, training), loss_aux, outputs)

    @use_self_as_default(
        "params", "state", "training", slow_phase=["training"], fast_phase=["training"]
    )
    def apply_and_loss(
        self,
        x,
        y,
        params=None,
        state=None,
        rng=None,
        training=None,
        loss_fn=None,
        slow_phase=None,
        fast_phase=None,
    ):
        """Run the slow phase, then the fast phase, and return the loss.

        Returns:
            ``(loss, (state, *others))`` where ``state`` is the slow-phase
            state merged with the fast state, and ``others`` are the remaining
            aux values of ``fast_apply_and_loss``.
        """
        if rng is None:
            rng_slow = rng_fast = None
        else:
            rng_slow, rng_fast = split(rng)

        outputs, slow_state = self.apply(
            params,
            state,
            rng_slow,
            x,
            phase=slow_phase,
            training=training,
        )
        # Fixed: fast_apply_and_loss takes (x, y, params, state, rng,
        # fast_params, fast_state, training, loss_fn, fast_phase). The two
        # explicit None placeholders keep `training`, `loss_fn` and
        # `fast_phase` in their own slots; previously they were shifted into
        # fast_params / fast_state / training.
        loss, (fast_state, *others) = self.fast_apply_and_loss(
            outputs,
            y,
            params,
            state,
            rng_fast,
            None,  # fast_params
            None,  # fast_state
            training,
            loss_fn,
            fast_phase,
        )

        return loss, (merge(slow_state, fast_state), *others)

    @use_self_as_default(
        "training",
        "lr",
        init_opt_state=["training"],
        opt_update=["training"],
        fast_phase=["training"],
    )
    def single_fast_step(
        self,
        x,
        y,
        params=None,
        state=None,
        rng=None,
        fast_params=None,
        fast_state=None,
        lr=None,
        training=None,
        loss_fn=None,
        init_opt_state=None,
        opt_update=None,
        opt_state=None,
        fast_phase=None,
    ):
        """Take one optimizer step on the fast parameters.

        When omitted, ``fast_params`` is extracted from ``params`` and
        ``opt_state`` is initialised from ``fast_params``.

        Returns:
            ``(fast_params, fast_state, opt_state, (loss, loss_aux, outputs))``
            where ``loss``/``outputs`` were computed before the update.
        """
        if fast_params is None:
            fast_params = self.get_fp(params, training)
        if opt_state is None:
            opt_state = init_opt_state(fast_params)

        # argnums=5 is `fast_params` in the bound method's positional order.
        (loss, (fast_state, loss_aux, outputs)), grads = jax.value_and_grad(
            self.fast_apply_and_loss,
            argnums=5,
            has_aux=True,
        )(
            x,
            y,
            params,
            state,
            rng,
            fast_params,
            fast_state,
            training,
            loss_fn,
            fast_phase,
        )
        updates, opt_state = opt_update(
            lr,
            grads,
            opt_state,
            fast_params,
        )
        fast_params = ox.apply_updates(fast_params, updates)
        return fast_params, fast_state, opt_state, (loss, loss_aux, outputs)

    @use_self_as_default(
        "params",
        "state",
        "training",
        slow_phase=["training"],
    )
    def slow_apply(
        self,
        x,
        params=None,
        state=None,
        rng=None,
        training=None,
        slow_phase=None,
    ):
        """Run the slow phase over inputs that carry two leading batch axes.

        With ``self.bmap`` the two leading axes are folded into one, the model
        is applied once, and the outputs are unfolded again. Otherwise the
        model is ``jax.vmap``-ed over the leading axis of ``x`` and ``state``
        (and over per-example keys split from ``rng`` when it is given).

        Returns:
            ``(slow_outputs, slow_state)``.
        """
        if self.bmap:
            # TODO move this to a class method
            s1, s2 = x.shape[:2]
            slow_outputs, slow_state = self.apply(
                params,
                state,
                rng,
                jax.tree_util.tree_map(
                    lambda _x: _x.reshape((s1 * s2, *_x.shape[2:])), x
                ),
                phase=slow_phase,
                training=training,
            )
            slow_outputs = jax.tree_util.tree_map(
                lambda _x: _x.reshape(s1, s2, *_x.shape[1:]), slow_outputs
            )
        else:
            if rng is not None:
                args = (split(rng, len(x)),)
            else:
                args = tuple()
            # TODO move the following f and vmap definition to class methods
            f = lambda _x, _state, _rng=None: self.apply(
                params, _state, _rng, _x, phase=slow_phase, training=training
            )
            slow_outputs, slow_state = jax.vmap(f)(x, state, *args)

        return slow_outputs, slow_state

    def expand(self, struct, size):
        """Return ``struct`` unchanged with ``bmap``, else ``expand(struct, size)``.

        NOTE(review): the module-level ``expand`` presumably adds a leading
        axis of length ``size`` -- confirm in ``utils``.
        """
        if self.bmap:
            return struct
        return expand(struct, size)

    @use_self_as_default(
        "params",
        "state",
        "training",
        # slow_phase=["training"],
        # fast_phase=["training"],
        init_opt_state=["training"],
        # opt_update=["training"],
    )
    def inner_loop(
        self,
        x_spt,  # Batched x (outer batch, inner_batch, num_times, *x_shape)
        y_spt,
        params=None,
        state=None,
        rng=None,
        fast_params=None,
        fast_state=None,
        training=None,
        opt_state=None,
        lr=None,
        loss_fn=None,
        slow_phase=None,
        fast_phase=None,
        init_opt_state=None,
        opt_update=None,
        counter=None,
    ):
        """Adapt the fast parameters on the support set with scanned steps.

        The slow phase runs once on ``x_spt[:, :, 0]``; its outputs are passed
        through the module-level ``expand(..., x_spt.shape[2], axis=2)`` and
        then, per element of the outer batch, ``single_fast_step`` is scanned
        over the leading axis of the remaining data.

        Returns:
            ``(loss, loss_aux, outputs)`` stacked over the outer batch and the
            scan steps. The final carry (adapted fast params/state, optimizer
            state, step counter) is computed but not returned.
        """
        if rng is None:
            rng_slow = rng_fast = None
        else:
            rng_slow, rng_fast = split(rng)

        slow_outputs, slow_state = self.slow_apply(
            x_spt[:, :, 0], params, state, rng_slow, training, slow_phase
        )
        slow_outputs = expand(slow_outputs, x_spt.shape[2], axis=2)

        bsz = x_spt.shape[0]
        if fast_params is None:
            fast_params = self.expand(self.get_fp(params, training), bsz)
        if fast_state is None:
            fast_state = self.expand(self.get_fs(merge(state, slow_state), training), bsz)
        if opt_state is None:
            opt_state = self.expand(init_opt_state(fast_params), bsz)
        if counter is None:
            counter = self.expand(jnp.zeros([]), bsz)

        carries = {
            "counter": counter,
            "fast_params": fast_params,
            "fast_state": fast_state,
            "opt_state": opt_state,
        }

        # Optional per-task keys, appended to the scanned inputs when present.
        if rng is not None:
            rng_xs = (split(rng_fast, len(x_spt)),)
        else:
            rng_xs = tuple()

        def scan_fun(carry, xs):
            _fast_params = carry.get("fast_params", None)
            _fast_state = carry.get("fast_state", None)
            _opt_state = carry.get("opt_state", None)
            if len(xs) == 2:
                _x, _y = xs
                _rng = None
            else:
                _x, _y, _rng = xs
            (
                new_fast_params,
                new_fast_state,
                new_opt_state,
                (loss, loss_aux, outputs),
            ) = self.single_fast_step(
                _x,
                _y,
                params,
                slow_state,
                _rng,
                _fast_params,
                _fast_state,
                lr,
                training,
                loss_fn,
                init_opt_state,
                opt_update,
                _opt_state,
                fast_phase,
            )

            return {
                "fast_params": new_fast_params,
                "fast_state": new_fast_state,
                "opt_state": new_opt_state,
                "counter": carry["counter"] + 1,
            }, (loss, loss_aux, outputs)

        xss = (
            slow_outputs,
            y_spt,
            *rng_xs,
        )

        if self.bmap:
            # Every task starts from the same (unbatched) carry.
            _, (loss, loss_aux, outputs) = jax.vmap(
                lambda xs: jax.lax.scan(scan_fun, carries, xs)
            )(xss)
        else:
            # Fixed: lax.scan returns (final_carry, stacked_outputs); the
            # 2-tuple used to be unpacked straight into three names.
            _, (loss, loss_aux, outputs) = jax.vmap(
                lambda carry, xs: jax.lax.scan(scan_fun, carry, xs)
            )(carries, xss)

        return (loss, loss_aux, outputs)

In [45]:
# Smoke test of the inner loop on random data using the jax.vmap code path
# (bmap=False); the state is expanded to a leading axis of 2 to match the
# outer batch size.
# NOTE(review): the recorded output of this cell is a broadcasting ValueError,
# so this configuration does not currently run end to end.
# NOTE(review): `lr` is negative -- confirm that this is intended.
meta_learner = BaseMetaLearner(model.apply, SimpleModel, bmap=False)
algo = meta_learner.inner_loop(
    random.normal(PRNGKey(0), (2, 2, 1, 2)),
    random.randint(PRNGKey(0), (2, 2, 1), 0, 2),
    params,
    # state,
    expand(state, 2),
    None,
    lr=-3e-2,
    # training=False,
)

ValueError: Incompatible shapes for broadcasting: ((2, 1, 8), (1, 1, 2))

In [38]:
# Display the (loss, loss_aux, outputs) tuple returned by inner_loop.
# NOTE(review): this cell's execution count (38) is lower than that of the
# cell defining `algo` above (45), so the output shown was presumably produced
# by an earlier version of that cell -- re-run top to bottom to confirm.
algo

(DeviceArray([[0.6931472, 0.6611467],
              [0.6931472, 0.6611467]], dtype=float32),
 {'acc': DeviceArray([[1., 1.],
               [0., 1.]], dtype=float32)},
 DeviceArray([[[[ 0.        ,  0.        ]],
 
               [[-0.03244543,  0.03261348]]],
 
 
              [[[ 0.        ,  0.        ]],
 
               [[ 0.03244543, -0.03261348]]]], dtype=float32))

In [85]:
# NOTE(review): `slow_state` is not assigned by any visible notebook cell (it
# only exists as a local inside BaseMetaLearner methods); this cell presumably
# relies on leftover kernel state and would fail after Restart & Run All.
slow_state

FlatMapping({
  'simple_model/fast/batch_norm/~/mean_ema': FlatMapping({
                                               'average': DeviceArray([[[ 8.9406961e-07,  6.8545341e-07, -5.6624413e-07,
                                                                         -3.8743016e-07, -6.5565109e-07, -6.2584877e-07,
                                                                          2.3841856e-07, -9.8347664e-07]],
                                                          
                                                                       [[ 8.9406961e-07,  6.8545341e-07, -5.6624413e-07,
                                                                         -3.8743016e-07, -6.5565109e-07, -6.2584877e-07,
                                                                          2.3841856e-07, -9.8347664e-07]]], dtype=float32),
                                               'counter': DeviceArray([0, 0], dtype=int32),
                                               'hidden': Devi