In [1]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import os.path as osp


# Put the sibling ``../src`` directory first on the module search path.
# Presumably the project-local imports below (config, lib, data, ...) live
# there — confirm. The path is relative to the kernel's working directory.
sys.path.insert(0, "../src")

import dill
import time
from tqdm import tqdm
from argparse import ArgumentParser
from easydict import EasyDict as edict

import numpy as onp

import jax
from jax.random import split
from jax import (
    jit,
    pmap,
    vmap,
    random,
    partial,
    tree_map,
    numpy as jnp,
    value_and_grad,
)

import optax as ox
import haiku as hk

from config import rsetattr
from lib import (
    flatten,
    outer_loop,
    setup_device,
    fsl_inner_loop,
    batched_outer_loop,
    parse_and_build_cfg,
    mean_xe_and_acc_dict,
    meta_step,
    reset_by_idxs,
    reset_all,
    delayed_cosine_decay_schedule,
    # outer_loop_reset_per_task,
)
from data import prepare_data, augment as augment_fn, preprocess_images
from experiment import Experiment, Logger
from data.sampling import fsl_sample, fsl_build, BatchSampler

# from models.maml_conv import miniimagenet_cnn_argparse, prepare_model, make_params
from models import make_params, prepare_model
from test_utils import test_fsl_maml, test_fsl_embeddings
from test_sup import test_sup_cosine

In [3]:
from mrcl_experiment import MetaLearner, replicate_array

In [4]:
import chex

In [5]:
# Root experiment configuration (an EasyDict, so keys are also attributes).
cfg = edict()

cfg.random_seed = 0
cfg.dataset = "miniimagenet"

# Create the two sub-configs and keep module-level aliases to them.
# `cfg.train` is created again, empty, by the next cell.
cfg.model = model_cfg = edict()
cfg.train = train_cfg = edict()

# Initial model settings. output_size, initializer and hidden_size are
# overwritten by the cell that builds the MetaLearner.
model_cfg.model_name = "convnet4"
model_cfg.output_size = 64
model_cfg.avg_pool = True
model_cfg.activation = "relu"
model_cfg.initializer = "kaiming_normal"
model_cfg.hidden_size = 4

In [6]:
# Training configuration. `cfg.train` is re-bound to a fresh, empty EasyDict,
# so re-running this cell always starts from a clean sub-config.
cfg.train = train_cfg = edict()

train_cfg.prefetch = 0
train_cfg.batch_size = 2
train_cfg.way = 5
train_cfg.shot = 5
train_cfg.qry_shot = 15
# Dataset location. It can be overridden with the MINIIMAGENET_ROOT environment
# variable so the notebook is not tied to one machine; when the variable is
# unset the original hard-coded path is used, so existing behaviour is kept.
train_cfg.data_root = os.environ.get(
    "MINIIMAGENET_ROOT", "/Users/sebamenabar/Downloads/miniImageNet"
)
train_cfg.method = "maml"
train_cfg.sub_batch_size = None

# Outer / inner learning rates.
train_cfg.outer_lr = 1e-3
train_cfg.inner_lr = 1e-2
train_cfg.learn_inner_lr = False

# Learning-rate schedule settings.
train_cfg.scheduler = "cosine"
train_cfg.cosine_alpha = 0.05
train_cfg.cosine_decay_steps = 10000
train_cfg.cosine_transition_begin = 10000

train_cfg.augment = "all"
train_cfg.num_inner_steps = 5

# train_cfg.num_inner_steps = 2
train_cfg.reset_head = "all-kaiming"

In [75]:
import chex

In [10]:
# Base PRNG key passed to `meta_learner.step` in later cells. The literal seed
# 0 equals `cfg.random_seed` but is not read from it.
rng = jax.random.PRNGKey(0)

In [14]:
# Final settings for the learner built at the end of this cell; they are
# applied on top of whatever the earlier configuration cells stored.

# Episode sizes (same values as already configured).
train_cfg.way = 5
train_cfg.shot = 5
train_cfg.qry_shot = 15
# With batch_size 4 and sub_batch_size 2 the build log below reports
# "Applying gradient every 2 sub steps with sub batch size 2".
train_cfg.batch_size = 4
train_cfg.sub_batch_size = 2

# Model: two new keys first, then the overrides of the earlier model config.
model_cfg.track_stats = "none"
model_cfg.normalize = "bn"
model_cfg.model_name = "convnet4"
model_cfg.output_size = 5
model_cfg.avg_pool = True
model_cfg.activation = "relu"
model_cfg.initializer = "glorot_uniform"
model_cfg.hidden_size = 64

# All arguments are positional; the meaning of the trailing None is defined by
# MetaLearner's signature, which is not visible in this notebook.
meta_learner = MetaLearner(
    cfg.random_seed,
    cfg.dataset,
    train_cfg.data_root,
    model_cfg,
    train_cfg,
    None,
)
# First step, starting from the base key.
out = meta_learner.step(global_step=0, rng=rng)

False
Applying gradient every 2 sub steps with sub batch size 2
Reset classifier to kaiming
Reset all classifier
Initializing parameters rather than restoring from checkpoint.
start of preprocess images

Augmenting support and query
end of preprocess images
Before grad
in batch outer loop
in outer loop
Not resetting bias
Resetting all classifier
before slow outputs
after slow outputs
inner loop
after slow outputs
begin inner loop 0
end inner loop 0
begin inner loop 1
end inner loop 1
begin inner loop 2
end inner loop 2
begin inner loop 3
end inner loop 3
begin inner loop 4
end inner loop 4
end of inner loop
after inner loop
outer loop end
batch outer loop end
After grad
start of preprocess images

Augmenting support and query
end of preprocess images
Before grad
in batch outer loop
in outer loop
Not resetting bias
Resetting all classifier
before slow outputs
after slow outputs
inner loop
after slow outputs
begin inner loop 0
end inner loop 0
begin inner loop 1
end inner loop 1
begin in

In [88]:
# Inspect the `update_pmap` attribute; the repr below resolves it to
# `MetaLearner._update_fn` and lists that function's parameters.
meta_learner.update_pmap

<function mrcl_experiment.MetaLearner._update_fn(global_step, rng, inputs, spt_classes, normalize_fn, augment, augment_fn, num_inner_steps, slow_apply, fast_apply, opt_update_fn, reset_fast_params_fn, learn_inner_lr, optimizer)>

In [16]:
# Step again, feeding the first element of the previous result back in as the
# rng (it displays as a two-element uint32 array). global_step is still 0.
out = meta_learner.step(global_step=0, rng=out[0])

In [93]:
# Display the tuple returned by the last step: a uint32 key array followed by
# a nested dict with "inner" and "outer" entries.
out

(array([1845241823, 3895899934], dtype=uint32),
 {'inner': {'auxs': [{'acc': ShardedDeviceArray([[[0., 0., 0.],
                          [0., 0., 0.]]], dtype=float32)}],
   'losses': ShardedDeviceArray([[[4.2050285, 4.1799536, 4.1549315],
                        [4.3527384, 4.3320217, 4.3113647]]], dtype=float32)},
  'outer': {'final': {'aux': [{'acc': ShardedDeviceArray([[0., 0.]], dtype=float32)}],
    'loss': ShardedDeviceArray([[4.0473223, 5.0362635]], dtype=float32)},
   'initial': {'aux': [{'acc': ShardedDeviceArray([[0., 0.]], dtype=float32)}],
    'loss': ShardedDeviceArray([[4.0826473, 5.06845  ]], dtype=float32)}}})

In [102]:
# Run a step from the base `rng` key again, still with global_step=0.
out = meta_learner.step(global_step=0, rng=rng)


Augmenting support and query
Not resetting bias
Resetting all classifier


In [104]:
# Average every leaf of the "inner" metrics over its two leading axes, leaving
# one value per remaining index (three per leaf in the output shown below).
inner_scalars = jax.tree_map(
    lambda leaf: jnp.mean(leaf, axis=(0, 1)),
    out[1]["inner"],
)

In [105]:
# Display the averaged inner-loop metrics computed in the previous cell.
inner_scalars

{'auxs': [{'acc': DeviceArray([0., 0., 0.], dtype=float32)}],
 'losses': DeviceArray([4.6677713, 4.6443605, 4.6209974], dtype=float32)}

In [108]:
# Private attribute; it displays 2, matching the "Applying gradient every 2
# sub steps" line logged when the learner was built.
meta_learner._apply_every

2

In [107]:
# Inspect the optimizer state stored on the learner state. The output (cut off
# in this export) starts with ClipState, ScaleState and an ApplyEvery entry.
meta_learner._learner_state.opt_state

[ClipState(),
 ScaleState(),
 ApplyEvery(count=ShardedDeviceArray([0], dtype=int32), grad_acc=(frozendict({
   'mini_imagenet_cnn_body/conv_base/conv_block/batch_norm': frozendict({
                                                               'offset': ShardedDeviceArray([[[[[-0.02147533, -0.06235676, -0.03157339,
                                                                                                 -0.04902368]]]]], dtype=float32),
                                                               'scale': ShardedDeviceArray([[[[[ 0.086115  , -0.10844006,  0.07221012,
                                                                                                -0.04862137]]]]], dtype=float32),
                                                             }),
   'mini_imagenet_cnn_body/conv_base/conv_block/conv2_d': frozendict({
                                                            'b': ShardedDeviceArray([[-0.07355435, -0.0680431 , -0.08339171,  0.00129735]], dtype=floa

In [106]:
# Schedule value at step 0. The result shown is -0.001, i.e. the negative of
# train_cfg.outer_lr (1e-3).
# NOTE(review): presumably the sign is folded into the schedule so that the
# optimizer update descends — confirm in MetaLearner.
meta_learner._scheduler(0)

DeviceArray(-0.001, dtype=float32)

In [101]:
# Merge several result pytrees leaf by leaf along axis 1. `out[1][0]` is passed
# first only to supply the tree structure: the lambda ignores that leaf because
# the same tree is also the first element of `*out[1]`.
# NOTE(review): `out[1]` is indexed as a sequence here, while other cells index
# it with the key "inner" — this cell presumably ran against a different
# return format of `step`; confirm before re-running.
# NOTE(review): newer JAX releases removed `tree_multimap` in favour of
# `tree_map` with several trees; check the pinned JAX version before upgrading.
jax.tree_multimap(
    lambda _template, *leaves: jnp.concatenate(leaves, axis=1),
    out[1][0],
    *out[1],
)

{'inner': {'auxs': [{'acc': DeviceArray([[[0., 0., 0.],
                  [0., 0., 0.],
                  [0., 0., 0.],
                  [0., 0., 0.]]], dtype=float32)}],
  'losses': DeviceArray([[[3.95506  , 3.937397 , 3.9197636],
                [4.1650476, 4.1399183, 4.1149   ],
                [4.3137956, 4.289889 , 4.2660275],
                [3.8309817, 3.8110113, 3.7911274]]], dtype=float32)},
 'outer': {'final': {'aux': [{'acc': DeviceArray([[0., 0., 0., 0.]], dtype=float32)}],
   'loss': DeviceArray([[3.7534647, 5.578721 , 3.5377736, 5.697146 ]], dtype=float32)},
  'initial': {'aux': [{'acc': DeviceArray([[0., 0., 0., 0.]], dtype=float32)}],
   'loss': DeviceArray([[3.787255 , 5.6071205, 3.5723367, 5.722905 ]], dtype=float32)}}}

In [44]:
# class MetaMiniImageNet:
#     def __init__(
#         self, rng, split, data_root, batch_size, way, shot, qry_shot, shuffled_labels=True,
#     ):
#         self._rng = rng
#         self._batch_size = batch_size
#         self._way = way
#         self._shot = shot
#         self._qry_shot = qry_shot
#         self._shuffled_labels = shuffled_labels
        
#         if split == "train":
#             self._fp = osp.join(data_root, "miniImageNet_category_split_train_phase_train_ordered.pickle")
#         self._images, self._labels, self._normalize = prepare_data("miniimagenet", self._fp)
        
#         if split == "val":
#             self._labels = self._labels - 64
            
#         self.fsl_sample = partial(fsl_sample, images=self._images, labels=self._labels, num_tasks=batch_size, way=way, spt_shot=shot, qry_shot=qry_shot, disjoint=False, shuffled_labels=shuffled_labels)
#         self.fsl_build = partial(fsl_build, batch_size=batch_size, way=way, shot=shot, qry_shot=qry_shot)
        
#     def __next__(self):
#         self._rng, rng = split(self._rng)
#         return self.fsl_build(*self.fsl_sample(rng))

In [48]:
# val_dataset = MetaMiniImageNet(
#     random.PRNGKey(0),
#     "val",
#     "/Users/sebamenabar/Downloads/miniImageNet/",
#     2, 5, 5, 15, shuffled_labels=True,
# )

AttributeError: 'MetaMiniImageNet' object has no attribute '_fp'

In [49]:
from mrcl_experiment import MetaMiniImageNet
from eval_experiment import MAMLTester, LRTester

In [50]:
# Validation episode source. The positional arguments are assumed to follow the
# commented-out prototype earlier in this notebook:
# (rng, split, data_root, batch_size, way, shot, qry_shot).
# TODO confirm against mrcl_experiment.MetaMiniImageNet.
# NOTE(review): the path literal repeats train_cfg.data_root (plus a trailing
# slash) instead of reading it from the config.
val_dataset = MetaMiniImageNet(
    random.PRNGKey(0),
    "val",
    "/Users/sebamenabar/Downloads/miniImageNet/",
    2, 5, 5, 15, shuffled_labels=True,
)

In [51]:
# NOTE(review): `dataset` is not assigned in any visible cell of this notebook,
# so this line depends on leftover kernel state and fails on a fresh kernel.
# The positional 4, 2 and 5 are not explained here — check LRTester's signature.
lr_tester = LRTester(meta_learner._encoder.apply, 4, 2, val_dataset, 5, dataset._normalize)

In [52]:
# Grab the object returned by get_first_state(); later cells read slow_params,
# fast_params, slow_state, fast_state and inner_lr from it.
_learner_state = meta_learner.get_first_state()

In [66]:
# Evaluate with the LR tester using only the slow parameters and slow state;
# the call returns a pair that is unpacked into (lr_acc, lr_std).
lr_acc, lr_std = lr_tester.eval(_learner_state.slow_params, _learner_state.slow_state)

Augmenting testing samples
Augmenting testing samples


In [67]:
# Display the pair returned by lr_tester.eval.
lr_acc, lr_std

(0.29333333333333333, 0.05416025603090639)

In [54]:
# MAML-style tester built from the encoder and classifier apply functions.
# NOTE(review): the bare positional numbers (4, 2, 3, 1) are not explained in
# this notebook — check MAMLTester's signature for their meaning.
# NOTE(review): `dataset` is not assigned in any visible cell (hidden kernel
# state); only `val_dataset` is defined above.
maml_tester = MAMLTester(
    meta_learner._encoder.apply,
    meta_learner._classifier.apply,
    4,
    2,
    val_dataset,
    3,
    1,
    dataset._normalize,
)

In [72]:
# Evaluate with the MAML tester. Every leaf of the slow and fast *state* trees
# goes through `replicate_array(..., num_devices=2)` first; the parameter trees
# and the inner learning rate are passed as stored on `_learner_state`.
maml_acc, maml_std = maml_tester.eval(
    _learner_state.slow_params,
    _learner_state.fast_params,
    jax.tree_map(
        lambda leaf: replicate_array(leaf, num_devices=2),
        _learner_state.slow_state,
    ),
    jax.tree_map(
        lambda leaf: replicate_array(leaf, num_devices=2),
        _learner_state.fast_state,
    ),
    _learner_state.inner_lr,
)

Augmenting testing samples
Augmenting testing samples


In [73]:
# Display the pair returned by maml_tester.eval; both values show as 0.0 in the
# captured output below.
maml_acc, maml_std

(DeviceArray(0., dtype=float32), DeviceArray(0., dtype=float32))

In [70]:
# Stack the per-call result pytrees in `test_results` leaf by leaf along the
# leading axis. The first tree only supplies the structure; its leaf is ignored
# because it appears again as the first element of `*test_results`.
# NOTE(review): `test_results` is not assigned in any visible cell.
jax.tree_multimap(
    lambda _template, *leaves: jnp.concatenate(leaves),
    test_results[0],
    *test_results,
)

{'inner': {'auxs': [{'acc': DeviceArray([[0., 0., 0., 0.],
                 [0., 0., 0., 0.],
                 [0., 0., 0., 0.],
                 [0., 0., 0., 0.]], dtype=float32)}],
  'losses': DeviceArray([[4.246692 , 4.240478 , 4.2342696, 4.228069 ],
               [4.2454734, 4.2392335, 4.233    , 4.2267723],
               [4.2725606, 4.266361 , 4.2601705, 4.253983 ],
               [4.193775 , 4.1875887, 4.181411 , 4.175238 ]], dtype=float32)},
 'outer': {'final': {'aux': [{'acc': DeviceArray([0.        , 0.01333333, 0.02666667, 0.01333333], dtype=float32)}],
   'loss': DeviceArray([4.195467 , 4.211032 , 4.229458 , 4.2169127], dtype=float32)},
  'initial': {'aux': [{'acc': DeviceArray([0.        , 0.01333333, 0.01333333, 0.01333333], dtype=float32)}],
   'loss': DeviceArray([4.2137237, 4.229183 , 4.2471437, 4.235965 ], dtype=float32)}}}

In [97]:
# Same leaf-wise concatenation of `test_results` along the leading axis as in
# the preceding cell, run at a later execution count.
# NOTE(review): `test_results` is not assigned in any visible cell.
jax.tree_multimap(
    lambda _template, *leaves: jnp.concatenate(leaves),
    test_results[0],
    *test_results,
)

{'inner': {'auxs': [{'acc': DeviceArray([[0.02, 0.04, 0.06, 0.06],
                 [0.  , 0.  , 0.  , 0.  ],
                 [0.  , 0.  , 0.  , 0.  ],
                 [0.  , 0.  , 0.  , 0.  ]], dtype=float32)}],
  'losses': DeviceArray([[4.150457 , 4.1422462, 4.1340485, 4.125866 ],
               [4.2803736, 4.273617 , 4.2668643, 4.2601204],
               [4.2764935, 4.2698545, 4.2632217, 4.256597 ],
               [4.301257 , 4.2938795, 4.286512 , 4.2791514]], dtype=float32)},
 'outer': {'final': {'aux': [{'acc': DeviceArray([0.06666667, 0.        , 0.        , 0.        ], dtype=float32)}],
   'loss': DeviceArray([4.159359, 4.256447, 4.287693, 4.30843 ], dtype=float32)},
  'initial': {'aux': [{'acc': DeviceArray([0.01333333, 0.        , 0.        , 0.        ], dtype=float32)}],
   'loss': DeviceArray([4.182653 , 4.27624  , 4.3077197, 4.3309455], dtype=float32)}}}

In [103]:
# Predict on the first entry of `x_spt` after dividing by 255.
# NOTE(review): `x_spt` is only assigned in a cell that appears further down
# (In [13]), so this cell depends on out-of-order execution. The captured
# output is "(25, 64)", presumably a printed shape — confirm.
meta_learner.predict(x_spt[0] / 255)

(25, 64)

In [46]:
# Fetch the first learner state again, then run the tester's jitted batch
# adaptation on one sampled batch.
_learner_state = meta_learner.get_first_state()

# NOTE(review): `tester` is not assigned in any visible cell, and the
# x_spt / y_spt / x_qry / y_qry arrays come from a cell further down (In [13]);
# this cell relies on leftover kernel state.
tester.batch_adapt_jit(
    random.PRNGKey(0),
    _learner_state.slow_params,
    _learner_state.fast_params,
    _learner_state.slow_state,
    _learner_state.fast_state,
    _learner_state.inner_lr,
    x_spt, y_spt, x_qry, y_qry,
)

Augmenting testing samples


{'inner': {'auxs': [{'acc': DeviceArray([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
                 [0.  , 0.  , 0.02, 0.04, 0.08, 0.08]], dtype=float32)}],
  'losses': DeviceArray([[4.1769824, 4.1704135, 4.163852 , 4.157297 , 4.150749 ,
                4.1442113],
               [4.0993667, 4.093079 , 4.0867996, 4.080528 , 4.074263 ,
                4.0680065]], dtype=float32)},
 'outer': {'final': {'aux': [{'acc': DeviceArray([0.        , 0.06666667], dtype=float32)}],
   'loss': DeviceArray([4.1432977, 4.0433207], dtype=float32)},
  'initial': {'aux': [{'acc': DeviceArray([0., 0.], dtype=float32)}],
   'loss': DeviceArray([4.177068 , 4.0740485], dtype=float32)}}}

In [105]:
# The import list previously ended with a dangling comma, which is a
# SyntaxError for a non-parenthesised `from ... import` statement.
from test_utils import SupervisedStandardTester

In [110]:
# NOTE(review): this rebinds `LRTester`, which an earlier cell imported from
# `eval_experiment`; which class is in scope depends on cell execution order.
from refactor.eval_experiment import LRTester

In [107]:
# Shape of the image array held by `dataset`; it displays (64, 600, 84, 84, 3).
# NOTE(review): `dataset` is not assigned in any visible cell.
dataset._images.shape

(64, 600, 84, 84, 3)

In [109]:
# Build the supervised tester from `dataset`'s images and labels, each passed
# through `flatten(..., (0, 1))`, with a fixed PRNG key, the literal 64 and the
# learner's normalisation function.
# NOTE(review): `dataset` is not assigned in any visible cell, and the meaning
# of the positional 64 depends on SupervisedStandardTester's signature.
std_tester = SupervisedStandardTester(
    random.PRNGKey(0),
    flatten(dataset._images, (0, 1)),
    flatten(dataset._labels, (0, 1)),
    64,
    meta_learner._normalize_fn,
)

@jit
def fun(inputs):
    inputs = inputs / 255
    inputs = meta_learner._normalize_fn(inputs)
    return meta_learner.predict(inputs)

# Run the supervised tester on the jitted prediction function. The captured
# result is a pair (DeviceArray(4.67...), 0.0204...), presumably
# (loss, accuracy) — confirm against SupervisedStandardTester.
# NOTE(review): `fun` applies meta_learner._normalize_fn itself and the tester
# was also given that function; check the inputs are not normalised twice.
std_tester(fun)

(DeviceArray(4.6710153, dtype=float32), 0.020416666)

In [13]:
# Draw one batch from `dataset`; it unpacks into four arrays (presumably
# support / query inputs and labels, going by the names).
# NOTE(review): `dataset` is not assigned in any visible cell.
x_spt, y_spt, x_qry, y_qry = next(dataset)

In [18]:
# Shape check: repeating 4 times along axis 1 gives (2, 100, 84, 84, 3), i.e.
# four times the 25 entries per task reported by `y_spt.shape`.
x_spt.repeat(4, 1).shape

(2, 100, 84, 84, 3)

In [19]:
# Shape of the first label array from the sampled batch; it displays (2, 25).
y_spt.shape

(2, 25)