In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import os.path as osp
import sys

# Put the sibling source directory at the front of the import path (once),
# presumably so the project-local packages imported in the next cell resolve.
_SRC_DIR = "../src"
if sys.path.count(_SRC_DIR) == 0:
    sys.path.insert(0, _SRC_DIR)

In [3]:
import numpy as onp

import jax
from jax import numpy as jnp, random
from jax.random import split, PRNGKey

import haiku as hk
from haiku.data_structures import merge
import optax as ox

from acme.jax.utils import prefetch

from helpers import SimpleModel
from utils.losses import mean_xe_and_acc_dict
from datasets import get_dataset
from utils import (
    shuffle_along_axis,
    use_self_as_default,
    split_rng_or_none,
    call_self_as_default,
    expand,
    tree_shape,
    first_leaf_shape,
    tree_flatten_array,
    mean_of_f,
    pmap_init,
)
from utils.data.datasets import (
    SimpleDataset,
    ImageDataset,
    MetaDataset,
    MetaDatasetArray,
)

from meta.trainers import MetaTrainerB
from meta.base import MetaBase
from meta.wrappers import MetaLearnerBaseB, ContinualLearnerB

from models.oml import OMLConvnet

from tqdm.autonotebook import tqdm
from easydict import EasyDict as edict



In [4]:
# Show the devices JAX can see from this process (sanity check before pmap-ing).
jax.local_devices()

[GpuDevice(id=0), GpuDevice(id=1)]

In [5]:
# Experiment configuration shared by the dataset, model and trainer cells.
args = edict(
    image_size=84,
    # NOTE(review): `jax.local_device_count()` is presumably always >= 1, so
    # this always evaluates to "i" and the `None` branch looks unreachable.
    # Confirm whether `> 1` was intended before changing it.
    cross_replica_axis="i" if jax.local_device_count() else None,
    batch_size=8,
    way=10,
    shot=5,
    qry_shot=5,
    cl_qry_shot=16,
    include_spt=False,
)

# Base Omniglot training split, loaded at the configured image size.
trainset = get_dataset(
    "omniglot",
    "train",
    all=True,
    train=True,
    image_size=args.image_size,
)

# Meta-dataset wrapper built from the batch size and way/shot settings above.
meta_trainset = MetaDatasetArray(
    trainset,
    args.batch_size,
    args.way,
    args.shot,
    args.qry_shot,
    args.cl_qry_shot,
    disjoint=True,
)

# Sampler seeded with a fixed key.
meta_trainset_sampler = meta_trainset.get_sampler(PRNGKey(0))

In [6]:
def _forward(x, phase, training):
    """Instantiate an OMLConvnet and apply it to ``x`` for the given phase."""
    net = OMLConvnet(cross_replica_axis=args.cross_replica_axis)
    return net(x, phase, training)


model = hk.transform_with_state(_forward)

# Take the second-to-last element of one sampled batch (the slot the training
# loop unpacks as `x_qry_cl`), scale it by 1/255 and flatten it for model init.
dummy_input = tree_flatten_array(
    next(meta_trainset.get_sampler(PRNGKey(0)))[-2] / 255
)
_dummy_input = dummy_input

if args.cross_replica_axis is None:
    # Plain (non-pmapped) initialisation.
    params, state = model.init(PRNGKey(0), _dummy_input, "all", True)
else:
    # Initialise through pmap_init, with the dummy input replicated onto every
    # local device.
    params, state = pmap_init(
        model,
        (PRNGKey(0),),
        dict(phase="all", training=True),
        jax.device_put_replicated(_dummy_input, jax.local_devices()),
    )

In [17]:
# Continual learner wrapping `model.apply`. It is constructed without
# parameters or state (both None); phases and accessors come from OMLConvnet.
continual_learner = ContinualLearnerB(
    model.apply,
    params=None,
    state=None,
    slow_phase="encoder",
    fast_phase="adaptation",
    get_slow_params=OMLConvnet.get_train_slow_params,
    get_fast_params=OMLConvnet.get_train_fast_params,
    get_slow_state=OMLConvnet.get_state,
    get_fast_state=OMLConvnet.get_state,
)

# Meta-trainer built around the continual learner, receiving the params/state
# produced by the model-initialisation cell and an Adam optimizer (lr 1e-4).
# NOTE(review): the positional arguments after `state` repeat the values given
# by keyword to ContinualLearnerB above (phase names, param/state getters);
# presumably they have the same meaning -- confirm against MetaTrainerB's
# signature.
meta_trainer = MetaTrainerB(
    continual_learner,
    model.apply,
    params,
    state,
    "encoder",
    "adaptation",
    OMLConvnet.get_train_slow_params,
    OMLConvnet.get_train_fast_params,
    OMLConvnet.get_state,
    OMLConvnet.get_state,
    inner_lr=1e-2,
    train_lr=False,
    optimizer=ox.adam(1e-4),
    cross_replica_axis=args.cross_replica_axis,
)

# Chained setup calls; the value returned by the chain is what the cell displays.
meta_trainer.replicate_state().initialize_opt_state()

<meta.trainers.MetaTrainerB at 0x7fd68c212810>

In [8]:
def _resize_batch_dim(array, num_devices=jax.local_device_count()):
    """Split the leading axis of ``array`` into ``(num_devices, per_device)``.

    Args:
        array: Array whose first axis is the batch axis.
        num_devices: Number of shards to split the batch into. The default is
            the local device count captured when this function is defined.

    Returns:
        ``array`` reshaped to
        ``(num_devices, batch // num_devices, *array.shape[1:])``.

    Raises:
        AssertionError: If the batch size is not a multiple of ``num_devices``.
    """
    bsz = array.shape[0]
    assert (
        bsz % num_devices
    ) == 0, f"Batch size must be divisible by the number of available devices, received batch size: {bsz} and num devices: {num_devices}"
    return array.reshape(num_devices, bsz // num_devices, *array.shape[1:])


def resize_batch_dim(struct, num_devices=jax.local_device_count()):
    """Apply :func:`_resize_batch_dim` to every leaf of the pytree ``struct``.

    Args:
        struct: Pytree (tuple, dict, ...) of arrays sharing a leading batch axis.
        num_devices: Number of shards for the leading axis of every leaf.

    Returns:
        A pytree with the same structure whose leaves have shape
        ``(num_devices, batch // num_devices, ...)``.
    """
    # Forward `num_devices` explicitly: mapping the bare helper would silently
    # fall back to the helper's own default and ignore the caller's value.
    # `jax.tree_util.tree_map` is used because the top-level `jax.tree_map`
    # alias is deprecated and absent from recent JAX releases.
    return jax.tree_util.tree_map(
        lambda leaf: _resize_batch_dim(leaf, num_devices), struct
    )

def flatten_dims(struct, dims=(1, 3)):
    """Collapse a contiguous run of axes into one axis for every leaf of ``struct``.

    Args:
        struct: Pytree of arrays.
        dims: ``(start, stop)`` pair; axes ``start`` up to (but excluding)
            ``stop`` are merged into a single axis. The default ``(1, 3)``
            merges axes 1 and 2.

    Returns:
        A pytree with the same structure whose leaves have shape
        ``(*shape[:start], prod(shape[start:stop]), *shape[stop:])``.
    """
    start, stop = dims[0], dims[1]

    def _merge(t):
        # Plain Python int so reshape gets an ordinary dimension size.
        merged = int(onp.prod(t.shape[start:stop]))
        return t.reshape(*t.shape[:start], merged, *t.shape[stop:])

    # `jax.tree_util.tree_map` works on every JAX version; the top-level
    # `jax.tree_map` alias is deprecated and absent from recent releases.
    return jax.tree_util.tree_map(_merge, struct)

In [19]:
rng = PRNGKey(0)
# Fresh sampler (seeded with key 0) wrapped in acme's prefetch iterator.
prefetch_loader = prefetch(meta_trainset.get_sampler(PRNGKey(0)))

pbar = tqdm(range(10000))
for i in pbar:
    # Derive a per-step key and carry the remainder forward.
    rng, rng_step = split(rng)

    batch = next(prefetch_loader)
    x_spt, y_spt, x_qry, y_qry, x_qry_cl, y_qry_cl = batch

    # Support/query arrays: merge axes 1 and 2 (flatten_dims default), then
    # split the leading axis across devices.
    spt_qry = flatten_dims((x_spt, y_spt, x_qry, y_qry))
    x_spt, y_spt, x_qry, y_qry = resize_batch_dim(spt_qry)
    # The continual-learning query arrays only get the device split.
    x_qry_cl, y_qry_cl = resize_batch_dim((x_qry_cl, y_qry_cl))

    out = meta_trainer.step(
        rng_step,
        i,
        x_spt,
        y_spt,
        x_qry,
        y_qry,
        x_qry_cl,
        y_qry_cl,
    )

    # Refresh the progress-bar metrics only every 50th step.
    if i % 50 != 0:
        continue

    aux = out[1]
    pbar.set_postfix(
        loss=out[0].mean().item(),
        lr=meta_trainer.inner_lr.tolist(),
        ioa=aux["initial_outer_out"]["loss_aux"]["acc"].mean().item(),
        foa=aux["outer_out"]["loss_aux"]["acc"].mean().item(),
    )

HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))


