In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import os.path as osp
import sys

if "../src" not in sys.path:
    sys.path.insert(0, "../src")

In [3]:
import numpy as onp

import jax
from jax import numpy as jnp, random
from jax.random import split, PRNGKey

import haiku as hk
from haiku.data_structures import merge
import optax as ox

from acme.jax.utils import prefetch

from helpers import SimpleModel
from utils.losses import mean_xe_and_acc_dict
from datasets import get_dataset
from utils import (
    shuffle_along_axis,
    use_self_as_default,
    split_rng_or_none,
    call_self_as_default,
    expand,
    tree_shape,
    first_leaf_shape,
    tree_flatten_array,
    mean_of_f,
    pmap_init,
)
from utils.data.datasets import (
    SimpleDataset,
    ImageDataset,
    MetaDataset,
    MetaDatasetArray,
)

from meta.trainers import MetaTrainerB
from meta.base import MetaBase
from meta.wrappers import MetaLearnerBaseB, ContinualLearnerB

from models.oml import OMLConvnet

from tqdm.autonotebook import tqdm
from easydict import EasyDict as edict



In [4]:
jax.local_devices()

[GpuDevice(id=0), GpuDevice(id=1)]

In [5]:
args = edict(
    image_size=84,
    cross_replica_axis="i" if jax.local_device_count() else None,
    batch_size=8,
    way=10,
    shot=5,
    qry_shot=5,
    cl_qry_shot=16,
    include_spt=False,
)

trainset = get_dataset("omniglot", "train", all=True, train=True, image_size=args.image_size)
meta_trainset = MetaDatasetArray(trainset, args.batch_size, args.way, args.shot,args.qry_shot, args.cl_qry_shot, disjoint=True)
meta_trainset_sampler = meta_trainset.get_sampler(PRNGKey(0))

In [6]:
model = hk.transform_with_state(
    lambda x, phase, training: OMLConvnet(
        cross_replica_axis=args.cross_replica_axis,
    )(x, phase, training)
)

dummy_input = tree_flatten_array(meta_trainset.get_sampler(PRNGKey(0)).__next__()[-2] / 255)

if args.cross_replica_axis is not None:
    _dummy_input = ((dummy_input))
    params, state = pmap_init(model, (PRNGKey(0),), dict(phase="all", training=True), jax.device_put_replicated(_dummy_input, jax.local_devices()))
else:
    _dummy_input = (dummy_input)
    params, state = model.init(PRNGKey(0), _dummy_input, "all", True)

In [7]:
def reset_params(w_make_fn, w_get, rng, params, spt_classes=None):
    ws = w_get(params)
    leaves, layout = jax.tree_flatten(ws)
    rng = split(rng, len(leaves))
    if spt_classes is None:
        leaves = [w_make_fn(dtype=w.dtype)(rng, (w.shape[0], spt_classes.shape[0])) for _rng, w in zip(rng, leaves)]
    else:
        leaves = [jax.ops.index_update(
                w,
                jax.ops.index[:, spt_classes],
                w_make_fn(dtype=w.dtype)(rng, (w.shape[0], len(spt_classes))),
            ) for _rng, w in zip(rng, leaves)
        ]
        
    return merge(params, jax.tree_unflatten(layout, leaves))

w_get = lambda params: hk.data_structures.filter(lambda module_name, name, value: (("CLS" in module_name) and (name == "w")), params)
zero_init = lambda dtype: lambda rng, shape: jax.nn.initializers.zeros(rng, dtype=dtype, shape=shape)

reset_fn = jax.partial(reset_params, zero_init, w_get)

In [8]:
continual_learner = ContinualLearnerB(
    model.apply,
    params=None,
    state=None,
    slow_phase="encoder",
    fast_phase="adaptation",
    get_slow_params=OMLConvnet.get_train_slow_params,
    get_fast_params=OMLConvnet.get_train_fast_params,
    get_slow_state=OMLConvnet.get_state,
    get_fast_state=OMLConvnet.get_state,
)

meta_trainer = MetaTrainerB(
    continual_learner,
    model.apply,
    params,
    state,
    "encoder",
    "adaptation",
    OMLConvnet.get_train_slow_params,
    OMLConvnet.get_train_fast_params,
    OMLConvnet.get_state,
    OMLConvnet.get_state,
    inner_lr=1e-2,
    train_lr=False,
    optimizer=ox.adam(1e-4),
    cross_replica_axis=args.cross_replica_axis,
    reset_fast_params=reset_fn,
    reset_before_outer_loop=False,
)

meta_trainer.replicate_state().initialize_opt_state()

<meta.trainers.MetaTrainerB at 0x7f8da6ce0750>

In [9]:
def _resize_batch_dim(array, num_devices=jax.local_device_count()):
    bsz = array.shape[0]
    assert (
        bsz % num_devices
    ) == 0, f"Batch size must be divisible but number of available devices, received batch size: {bsz} and num devices: {num_devices}"
    return array.reshape(num_devices, bsz // num_devices, *array.shape[1:])


def resize_batch_dim(struct, num_devices=jax.local_device_count()):
    return jax.tree_map(_resize_batch_dim, struct)

def flatten_dims(struct, dims=(1, 3)):
    return jax.tree_map(lambda t: t.reshape(*t.shape[:dims[0]], onp.prod(t.shape[dims[0]:dims[1]]), *t.shape[dims[1]:]), struct)

In [None]:
rng = PRNGKey(0)
prefetch_loader = prefetch(meta_trainset.get_sampler(PRNGKey(0)))

pbar = tqdm(range(10000))
for i in pbar:
    rng, rng_step = split(rng)
    x_spt, y_spt, x_qry, y_qry, x_qry_cl, y_qry_cl = next(prefetch_loader)
    x_spt, y_spt, x_qry, y_qry = resize_batch_dim((flatten_dims((x_spt, y_spt, x_qry, y_qry))))
    x_qry_cl, y_qry_cl = resize_batch_dim((x_qry_cl, y_qry_cl))

    out = meta_trainer.step(
        rng_step, i, x_spt, y_spt, x_qry, y_qry, x_qry_cl, y_qry_cl, 
    )
    
    if ((i % 50) == 0):
        pbar.set_postfix(
            loss=out[0].mean().item(),
            lr=meta_trainer.inner_lr.tolist(),
            ioa=out[1]["initial_outer_out"]["loss_aux"]["acc"].mean().item(),
            foa=out[1]["outer_out"]["loss_aux"]["acc"].mean().item(),
        )
    
    # print(i)

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))

FlatMapping({
  'oml_convnet/CLS/linear': FlatMapping({'b': (4, 1000), 'w': (4, 2304, 1000)}),
})
Traced<ShapedArray(int32[4,10])>with<DynamicJaxprTrace(level=0/1)>
False
functools.partial(<function reset_params at 0x7f8d847954d0>, <function <lambda> at 0x7f8d847bb680>, <function <lambda> at 0x7f8d847bb710>)
Resetting params in outer loop
FlatMapping({
  'oml_convnet/CLS/linear': FlatMapping({'b': (4, 1000), 'w': (4, 2304, 1000)}),
})
Traced<ShapedArray(int32[4,10])>with<DynamicJaxprTrace(level=0/1)>
False
functools.partial(<function reset_params at 0x7f8d847954d0>, <function <lambda> at 0x7f8d847bb680>, <function <lambda> at 0x7f8d847bb710>)
Resetting params in outer loop


In [25]:
tree_shape(out)

((2,),
 {'initial_outer_out': {'fast_state': FlatMapping({}),
   'loss': (2, 4),
   'loss_aux': {'acc': (2, 4, 66), 'loss': (2, 4, 66, 1)},
   'outputs': (2, 4, 66, 1000),
   'slow_state': FlatMapping({})},
  'inner_out': {'fast_params': FlatMapping({
     'oml_convnet/CLS/linear': FlatMapping({'b': (2, 4, 1000), 'w': (2, 4, 2304, 1000)}),
   }),
   'fast_state': FlatMapping({}),
   'loss': (2, 4, 50),
   'loss_aux': {'acc': (2, 4, 50, 1), 'loss': (2, 4, 50, 1, 1)},
   'opt_state': [IdentityState(), ScaleState()],
   'outputs': (2, 4, 50, 1000),
   'slow_outputs': (2, 4, 50, 2304),
   'slow_state': FlatMapping({})},
  'outer_out': {'fast_state': FlatMapping({}),
   'loss': (2, 4),
   'loss_aux': {'acc': (2, 4, 264), 'loss': (2, 4, 264, 1)},
   'outputs': (2, 4, 264, 1000)}},
 ())

In [28]:
out[1]["inner_out"]["loss_aux"]

{'acc': ShardedDeviceArray([[[[1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
                       [1.],
       

In [22]:
y_spt.shape

(2, 4, 50)

In [39]:
len(jax.tree_flatten(meta_trainer.params)[0])

14

In [101]:
def reset_params(w_make_fn, w_get, rng, params, spt_classes=None):
    params = hk.data_structures.to_mutable_dict(params)
    ws = w_get(params)
    leaves, layout = jax.tree_flatten(ws)
    rng = split(rng, len(leaves))
    if spt_classes is None:
        leaves = [w_make_fn(dtype=w.dtype)(rng, (w.shape[0], spt_classes.shape[0])) for _rng, w in zip(rng, leaves)]
    else:
        leaves = [jax.ops.index_update(
                w,
                jax.ops.index[:, spt_classes],
                w_make_fn(dtype=w.dtype)(rng, (w.shape[0], len(spt_classes))),
            ) for _rng, w in zip(rng, leaves)
        ]
        
    return merge(params, jax.tree_unflatten(layout, leaves))

w_get = lambda params: hk.data_structures.filter(lambda module_name, name, value: (("CLS" in module_name) and (name == "w")), params)
zero_init = lambda dtype: lambda rng, shape: jax.nn.initializers.zeros(rng, dtype=dtype, shape=shape)

jax.partial(reset_fast_params, zero_init, w_get)

In [95]:
w_get = lambda params: hk.data_structures.filter(lambda module_name, name, value: (("CLS" in module_name) and (name == "w")), params)
zero_init = lambda dtype: lambda rng, shape: jax.nn.initializers.zeros(rng, dtype=dtype, shape=shape)

In [99]:
first_leaf_shape(meta_trainer.params)

(2, 1000)

In [100]:
jax.vmap(jax.partial(reset_fast_params, zero_init, w_get))(split(PRNGKey(0), first_leaf_shape(meta_trainer.params)[0]), meta_trainer.params, jnp.array([[0, 1], [2, 3]]))

[Traced<ShapedArray(float32[2304,1000])>with<BatchTrace(level=1/0)>
  with val = DeviceArray([[[ 0.        ,  0.        , -0.02601205, ..., -0.02460624,
                            -0.03653534, -0.06808325],
                           [ 0.        ,  0.        ,  0.03998324, ...,  0.00159086,
                            -0.00304701,  0.04260138],
                           [ 0.        ,  0.        , -0.00193443, ..., -0.04270816,
                             0.0260589 , -0.0608446 ],
                           ...,
                           [ 0.        ,  0.        , -0.05608036, ...,  0.00980459,
                            -0.03396349, -0.08356816],
                           [ 0.        ,  0.        ,  0.00652571, ..., -0.00032381,
                            -0.01352585, -0.01571198],
                           [ 0.        ,  0.        ,  0.01196475, ...,  0.00529244,
                            -0.00293344, -0.03101354]],
             
                          [[-0.03238982, -0.0

FlatMapping({
  'oml_convnet/CLS/linear': FlatMapping({
                              'b': ShardedDeviceArray([[ 0.0053436 ,  0.01119717,  0.00158418, ...,
                                                        -0.01550275, -0.01161904, -0.01545853],
                                                       [ 0.0053436 ,  0.01119717,  0.00158418, ...,
                                                        -0.01550275, -0.01161904, -0.01545853]],                   dtype=float32),
                              'w': DeviceArray([[[ 0.        ,  0.        , -0.02601205, ..., -0.02460624,
                                                  -0.03653534, -0.06808325],
                                                 [ 0.        ,  0.        ,  0.03998324, ...,  0.00159086,
                                                  -0.00304701,  0.04260138],
                                                 [ 0.        ,  0.        , -0.00193443, ..., -0.04270816,
                                          

In [24]:
onp.unique(y_spt, axis=-1)

array([[[122, 129, 132, 146, 186, 211, 222, 341, 554, 628],
        [346, 644, 635, 218, 348, 502, 148, 531, 241, 630],
        [510, 215, 457, 574, 114, 161, 285, 116, 396,  83],
        [586,  89, 373, 374, 639, 254, 284, 656, 234, 558]],

       [[242, 492, 505, 153,  67, 417,   0,  64, 596, 156],
        [306, 313, 583, 608,  92, 526, 436, 149,  33, 624],
        [576,  44, 438, 372, 662, 333, 625, 401, 647, 447],
        [400,  91, 220,  90, 573, 316, 607,  26, 310,  60]]])