In [1]:
%load_ext autoreload
%autoreload 2

In [97]:
from trainers.meta_trainer import MetaTrainer

In [98]:
import os
import sys

sys.path.insert(0, "../src")

In [162]:
from tqdm.notebook import tqdm as tqdmn
from easydict import EasyDict as edict

import jax
from jax.random import split
from jax import random, numpy as jnp, jit, value_and_grad, partial, vmap, tree_multimap

import optax as ox

from data.sampling import fsl_sample_and_build
from models.maml_conv import prepare_model, make_params
from data import prepare_data
from lib import setup_device, make_fsl_inner_outer_loop, mean_xe_and_acc_dict, make_batched_outer_loop

In [100]:
cfg = edict()

In [125]:
cfg.hidden_size = 32
cfg.activation = "relu"
cfg.way = 5
cfg.shot = 5
cfg.qry_shot = 10
cfg.batch_size = 5
cfg.val_num_tasks = 1000
cfg.inner_lr = 1e-2
cfg.outer_lr = 1e-3
cfg.num_outer_steps = 10000
cfg.num_inner_steps = 5
cfg.disjoint_tasks = False
cfg.disable_jit = False
cfg.dataset = "miniimagenet"
# cfg.data_dir = "/home/samenabar/storage/data/FSL/mini-imagenet/"
cfg.progress_bar_refresh_rate = 50
cfg.train_method = "fsl"
cfg.track_bn_stats = False
cfg.prefetch_data_gpu = False
cfg.meta_batch_size = 5
cfg.data_dir = "/workspace1/samenabar/data/mini-imagenet/"

cfg.gpus = 1

jit_enabled = not cfg.disable_jit

In [102]:
cpu, device = setup_device(cfg.gpus, default_platform="cpu")
rng = random.PRNGKey(0)
device

GpuDevice(id=0)

In [103]:
train_images, train_labels, val_images, val_labels, preprocess_fn = prepare_data(
    cfg.dataset, cfg.data_dir, device,
)

In [9]:
print("Train data:", train_images.shape, train_labels.shape)
print("Val data:", val_images.shape, val_labels.shape)

Train data: (64, 600, 84, 84, 3) (64, 600)
Val data: (16, 600, 84, 84, 3) (16, 600)


In [104]:
loss_fn = mean_xe_and_acc_dict

In [126]:
trainer = MetaTrainer(rng, cfg, device=device)

In [127]:
body, head = prepare_model(cfg.dataset, cfg.way, cfg.hidden_size, cfg.activation)

In [128]:
(
    slow_params,
    fast_params,
    slow_state,
    fast_state,
) = make_params(
    rng,
    cfg.dataset,
    body.init,
    body.apply,
    head.init,
    device,
)

In [129]:
x, y = fsl_sample_and_build(
    rng_sample,
    train_images,
    train_labels,
    cfg.batch_size,
    cfg.way,
    cfg.shot,
    cfg.qry_shot,
    cfg.disjoint_tasks,
)
x = preprocess_fn(jax.device_put(x, device))
y = jax.device_put(y, device)

image_shape = x.shape[-3:]
x_spt, x_qry = jnp.split(x, (cfg.shot,), 2)
x_spt = x_spt.reshape(cfg.batch_size, cfg.way * cfg.shot, *image_shape)
x_qry = x_qry.reshape(cfg.batch_size, cfg.way * cfg.qry_shot, *image_shape)
y_spt, y_qry = jnp.split(y, (cfg.shot,), 2)
y_spt = y_spt.reshape(cfg.batch_size, cfg.way * cfg.shot)
y_qry = y_qry.reshape(cfg.batch_size, cfg.way * cfg.qry_shot)

In [118]:
x_spt.shape

(5, 25, 84, 84, 3)

In [169]:
def fsl_inner_loop(
    slow_params,
    fast_params,
    slow_state,
    fast_state,
    opt_state,
    rng,
    inputs,
    targets,
    is_training,
    num_steps,
    slow_apply,
    fast_apply,
    loss_fn,
    opt_update_fn,
    update_state=False,
    return_history=True,
):
    _fast_apply_and_loss_fn = partial(
        fast_apply_and_loss_fn, fast_apply=fast_apply, loss_fn=loss_fn
    )
    rng_slow, *rngs = split(rng, num_steps + 2)
    slow_outputs, slow_state = slow_apply(
        slow_params, slow_state, rng_slow, inputs, is_training,
    )
    
    losses = []
    auxs = []

    for i in range(num_steps):
        (loss, (new_fast_state, *aux)), grads = value_and_grad(
            _fast_apply_and_loss_fn, has_aux=True
        )(fast_params, fast_state, rngs[i], slow_outputs, is_training, targets)
        if update_state:
            fast_state = new_fast_state
        if i == 0:
            initial_loss = loss
            initial_aux = aux
            
        if return_history:
            losses.append(loss)
            auxs.append(aux)
            
        updates, opt_state = opt_update_fn(grads, opt_state, fast_params)
        fast_params = ox.apply_updates(fast_params, updates)

    final_loss, (final_fast_state, *final_aux) = _fast_apply_and_loss_fn(
        fast_params, fast_state, rngs[i + 1], slow_outputs, is_training, targets,
    )
    if return_history:
        losses.append(final_loss)
        auxs.append(final_aux)
        info = {"losses": jnp.stack(losses), "auxs": tree_multimap(lambda x, *xs: jnp.stack(xs), auxs[0], *auxs)}
    else:
        info = {
            "losses": {"initial": initial_loss, "final": final_loss},
            "auxs": {"initial": initial_aux, "final": final_aux},
        },

    return (
        fast_params,
        slow_state,
        fast_state,
        opt_state,
        info,
    )

In [170]:
inner_loop_ins = partial(fsl_inner_loop, is_training=True, num_steps=10, slow_apply=body.apply, fast_apply=head.apply, loss_fn=loss_fn, opt_update_fn=trainer.inner_opt.update)

inner_loop_ins = jit(inner_loop_ins)

In [176]:
inner_loop_ins(slow_params, fast_params, slow_state, fast_state, trainer.inner_opt.init(fast_params), rng, x_spt[0], y_spt[0])

(FlatMapping({
   'mini_imagenet_cnn_head/linear': FlatMapping({
                                      'b': DeviceArray([ 0.0013707 , -0.0032852 , -0.00303768,  0.00307852,
                                                         0.00187366], dtype=float32),
                                      'w': DeviceArray([[-0.02074367,  0.0250273 , -0.04737075,  0.07665098,
                                                         -0.05147767],
                                                        [ 0.01456404,  0.01268667,  0.08201264,  0.07178478,
                                                         -0.03569538],
                                                        [-0.00321351, -0.03089465,  0.00315202, -0.04079141,
                                                         -0.03144538],
                                                        ...,
                                                        [ 0.08402209,  0.06583858,  0.07225818, -0.07058419,
                               

In [179]:
def outer_loop(
    slow_params,
    fast_params,
    slow_state,
    fast_state,
    inner_opt_state,
    rng,
    x_spt,
    y_spt,
    x_qry,
    y_qry,
    is_training,
    inner_loop, # instantiated inner_loop
    slow_apply,
    fast_apply,
    loss_fn,
):
    _fast_apply_and_loss_fn = partial(
        fast_apply_and_loss_fn, fast_apply=fast_apply, loss_fn=loss_fn
    )
    rng_outer_slow, rng_outer_fast, rng_inner = split(rng, 3)
    slow_outputs, initial_slow_state = slow_apply(
        slow_params, slow_state, rng_outer_slow, x_qry, is_training,
    )
    initial_loss, (_, *initial_aux) = _fast_apply_and_loss_fn(
        fast_params, fast_state, rng_outer_fast, slow_outputs, False, y_qry,
    )
    
    fast_params, inner_slow_state, fast_state, inner_opt_state, inner_auxs = inner_loop(
        slow_params,
        fast_params,
        slow_state,
        fast_state,
        inner_opt_state,
        rng_inner,
        x_spt,
        y_spt,
    )
    final_loss, (final_fast_state, *final_aux) = _fast_apply_and_loss_fn(
        fast_params, fast_state, rng_outer_fast, slow_outputs, is_training, y_qry, 
    )
    return (
        final_loss,
        (
            initial_slow_state,
            fast_state,
            {
                "inner": inner_auxs,
                "outer": {
                    "initial": {"aux": initial_aux, "loss": initial_loss},
                    "final": {"aux": final_aux, "loss": final_loss},
                },
            },
        ),
    )

In [184]:
outer_loop_ins = partial(outer_loop, is_training=True, inner_loop=inner_loop_ins, slow_apply=body.apply, fast_apply=head.apply, loss_fn=loss_fn)

outer_loop_ins = jit(outer_loop_ins)

In [200]:
outer_loop_ins(
    slow_params, fast_params,
    slow_state, fast_state,
    trainer.inner_opt.init(fast_params),
    rng,
    x_spt[0], y_spt[0],
    x_qry[0], y_qry[0],
)

(DeviceArray(1.5807741, dtype=float32),
 (FlatMapping({}),
  FlatMapping({}),
  {'inner': {'auxs': [{'acc': DeviceArray([0.19999999, 0.16      , 0.32      , 0.32      , 0.39999998,
                   0.48      , 0.59999996, 0.68      , 0.68      , 0.68      ,
                   0.71999997], dtype=float32)}],
    'losses': DeviceArray([3.1672387 , 2.0957406 , 1.6780807 , 1.5082569 , 1.3795255 ,
                 1.263914  , 1.159855  , 1.0664941 , 0.982954  , 0.90832037,
                 0.8416744 ], dtype=float32)},
   'outer': {'final': {'aux': [{'acc': DeviceArray(0.22, dtype=float32)}],
     'loss': DeviceArray(1.5807741, dtype=float32)},
    'initial': {'aux': [{'acc': DeviceArray(0.14, dtype=float32)}],
     'loss': DeviceArray(3.3284416, dtype=float32)}}}))

In [204]:
def batched_outer_loop(slow_params, fast_params, slow_state, fast_state, inner_opt_state, brng, bx_spt, by_spt, bx_qry, by_qry, outer_loop):
    losses, aux = vmap(partial(outer_loop, slow_params, fast_params, slow_state, fast_state, inner_opt_state))(brng, bx_spt, by_spt, bx_qry, by_qry)
    return losses.mean(), aux

In [208]:
batched_outer_loop_ins = partial(batched_outer_loop, outer_loop=outer_loop_ins)
batched_outer_loop_ins = jit(batched_outer_loop_ins)

In [220]:
loss, (_, _, info) = batched_outer_loop_ins(slow_params, fast_params, slow_state, fast_state, trainer.inner_opt.init(fast_params), split(rng, x_spt.shape[0]), x_spt, y_spt, x_qry, y_qry)

In [241]:
def step(
    rng,
    step_num,
    outer_opt_state,
    slow_params,
    fast_params,
    slow_state,
    fast_state,
    x_spt,
    y_spt,
    x_qry,
    y_qry,
    inner_opt_init,
    outer_opt_update,
    batched_outer_loop_ins,
):
    inner_opt_state = inner_opt_init(fast_params)

    (outer_loss, (slow_state, fast_state, info)), grads = value_and_grad(
        batched_outer_loop_ins, (0, 1), has_aux=True
    )(
        slow_params,
        fast_params,
        slow_state,
        fast_state,
        inner_opt_state,
        split(rng, x_spt.shape[0]),
        x_spt,
        y_spt,
        x_qry,
        y_qry,
    )
    updates, outer_opt_state = outer_opt_update(
        grads, outer_opt_state, (slow_params, fast_params)
    )
    slow_params, fast_params = ox.apply_updates(
        (slow_params, fast_params), updates
    )

    return outer_opt_state, slow_params, fast_params, slow_state, fast_state, info



In [242]:
slow_params, fast_params, slow_state, fast_state = make_params(
    rng,
    cfg.dataset,
    body.init,
    body.apply,
    head.init,
    device,
)
slow_params, fast_params, slow_state, fast_state = map(lambda x: jax.device_put(x, device), (slow_params, fast_params, slow_state, fast_state))

inner_opt = ox.sgd(cfg.inner_lr)
lr_schedule = ox.cosine_decay_schedule(-cfg.outer_lr, cfg.num_outer_steps, 0.1)
outer_opt = ox.chain(
    ox.clip(10),
    ox.scale_by_adam(),
    ox.scale_by_schedule(lr_schedule),
)
outer_opt_state = outer_opt.init((slow_params, fast_params))

inner_loop_ins = partial(
    fsl_inner_loop,
    is_training=True,
    num_steps=10,
    slow_apply=body.apply,
    fast_apply=head.apply,
    loss_fn=loss_fn,
    opt_update_fn=inner_opt.update,
)
outer_loop_ins = partial(
    outer_loop,
    is_training=True,
    inner_loop=inner_loop_ins,
    slow_apply=body.apply,
    fast_apply=head.apply,
    loss_fn=loss_fn,
)
batched_outer_loop_ins = partial(batched_outer_loop, outer_loop=outer_loop_ins)

step_ins = jit(
    partial(
        step,
        inner_opt_init=inner_opt.init,
        outer_opt_update=trainer.outer_opt.update,
        batched_outer_loop_ins=batched_outer_loop_ins,
    ),
    # static_argnums=(11, 12, 13),
)

In [260]:
pbar = tqdmn(
    range(10000),
)

for i in pbar:
        rng, rng_sample, rng_step = split(rng, 3)
        x, y = fsl_sample_and_build(
            rng_sample,
            train_images,
            train_labels,
            cfg.batch_size,
            cfg.way,
            cfg.shot,
            cfg.qry_shot,
            cfg.disjoint_tasks,
        )
        x = preprocess_fn(jax.device_put(x, device))
        y = jax.device_put(y, device)
        
        image_shape = x.shape[-3:]
        x_spt, x_qry = jnp.split(x, (cfg.shot,), 2)
        x_spt = x_spt.reshape(cfg.batch_size, cfg.way * cfg.shot, *image_shape)
        x_qry = x_qry.reshape(cfg.batch_size, cfg.way * cfg.qry_shot, *image_shape)
        y_spt, y_qry = jnp.split(y, (cfg.shot,), 2)
        y_spt = y_spt.reshape(cfg.batch_size, cfg.way * cfg.shot)
        y_qry = y_qry.reshape(cfg.batch_size, cfg.way * cfg.qry_shot)

        outer_opt_state, slow_params, fast_params, slow_state, fast_state, info = step_ins(
            rng_step,
            i,
            outer_opt_state,
            slow_params,
            fast_params,
            slow_state,
            fast_state,
            x_spt,
            y_spt,
            x_qry,
            y_qry,
            # trainer.inner_opt.init,
            # trainer.outer_opt.update
        )

        if (((i + 1) % cfg.progress_bar_refresh_rate) == 0) or (i == 0):
            current_lr = lr_schedule(outer_opt_state[-1].count)
            pbar.set_postfix(
                lr=f"{current_lr:.4f}",
                # vfol=f"{vfol:.3f}",
                # vioa=f"{vioa:.3f}",
                # vfoa=f"{vfoa:.3f}",
                loss=f"{info['outer']['final']['loss'].mean():.3f}",
                foa=f"{info['outer']['final']['aux'][0]['acc'].mean():.3f}",
                # bfoa=f"{best_val_acc:.3f}",
                fia=f"{info['inner']['auxs'][0]['acc'][:, -1].mean():.3f}",
                iil=f"{info['inner']['losses'][:, 0].mean():.3f}",
                fil=f"{info['inner']['losses'][:, -1].mean():.3f}",
            )


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))




In [265]:
cfg.num_inner_steps_test = 10
cfg.meta_batch_size_test = 25

In [267]:
from data.sampling import fsl_sample_transfer_build

In [268]:
test_sample_fn_kwargs = {
    "images": val_images,
    "labels": val_labels,
    "batch_size": cfg.meta_batch_size_test,
    "way": cfg.way,
    "shot": cfg.shot,
    "qry_shot": 15,
    "preprocess_fn": preprocess_fn,
    "device": device,
}
test_sample_fn = partial(
    fsl_sample_transfer_build, **test_sample_fn_kwargs,
)

In [272]:
test_inner_loop_ins = partial(
    fsl_inner_loop,
    is_training=False,
    num_steps=cfg.num_inner_steps_test,
    slow_apply=body.apply,
    fast_apply=head.apply,
    loss_fn=mean_xe_and_acc_dict,
    opt_update_fn=inner_opt.update,
)
test_outer_loop_ins = partial(
    outer_loop,
    is_training=False,
    inner_loop=test_inner_loop_ins,
    slow_apply=body.apply,
    fast_apply=head.apply,
    loss_fn=mean_xe_and_acc_dict,
)
test_batched_outer_loop_ins = partial(
    batched_outer_loop, outer_loop=test_outer_loop_ins
)
test_batched_outer_loop_ins = jit(test_batched_outer_loop_ins)

In [270]:
x_spt_test, y_spt_test, x_qry_test, y_qry_test = test_sample_fn(rng)

In [275]:
mean_loss, (_, _, test_info) = test_batched_outer_loop_ins(
    slow_params, fast_params, slow_state, fast_state, inner_opt.init(fast_params), split(rng, x_spt_test.shape[0]), x_spt_test, y_spt_test, x_qry_test, y_qry_test,
)

In [280]:
test_info["outer"]["final"]["aux"][0]["acc"].mean()

DeviceArray(0.61973345, dtype=float32)