In [1]:
# IPython autoreload extension in mode 2, so edits to the project modules
# imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Put the sibling `src` directory first on the import path; the `data`,
# `models` and `lib` imports used by this notebook presumably resolve from
# there (path is relative to the kernel's working directory).
sys.path.insert(0, "../src")

In [27]:
from tqdm.notebook import tqdm as tqdmn
from easydict import EasyDict as edict

import jax
from jax.random import split
from jax import random, numpy as jnp, jit, value_and_grad

import optax as ox

from data.sampling import fsl_sample_and_build
from models.maml_conv import prepare_model, make_params
from data import prepare_data
from lib import setup_device, make_fsl_inner_outer_loop, mean_xe_and_acc_dict, make_batched_outer_loop

In [4]:
# Empty attribute-style config container; its fields are assigned in a
# separate cell so that cell can be re-run without recreating the object.
cfg = edict()

In [60]:
# --- Model (all forwarded to prepare_model) ---
cfg.hidden_size = 32
cfg.activation = "relu"

# --- Task layout ---
# `way`, `shot` and `qry_shot` are passed to the task sampler; the training
# loop flattens each task to way * shot support and way * qry_shot query items.
cfg.way = 5
cfg.shot = 5
cfg.qry_shot = 10
cfg.batch_size = 5  # leading dimension of every sampled batch of tasks
cfg.val_num_tasks = 1000  # not read by any cell visible in this notebook

# --- Optimisation ---
cfg.inner_lr = 1e-1  # learning rate of the SGD inner optimizer
cfg.outer_lr = 1e-2  # initial value (negated) of the outer cosine schedule
cfg.num_outer_steps = 10000  # training-loop length and schedule decay horizon
cfg.num_inner_steps = 5  # passed to the inner/outer loop factory and to step
cfg.disjoint_tasks = False  # forwarded to the task sampler
cfg.disable_jit = False

# --- Data ---
cfg.dataset = "miniimagenet"
# The dataset location can be overridden with the FSL_DATA_DIR environment
# variable so the notebook is not tied to a single machine; when the variable
# is unset the original hard-coded location is used.
cfg.data_dir = os.environ.get(
    "FSL_DATA_DIR", "/home/samenabar/storage/data/FSL/mini-imagenet/"
)

# --- Reporting / hardware ---
cfg.progress_bar_refresh_rate = 50  # outer steps between progress-bar updates

cfg.gpus = 1  # forwarded to setup_device

# Convenience flag consulted before wrapping `step` in jit.
jit_enabled = not cfg.disable_jit

In [6]:
# setup_device returns a (cpu, device) pair; `device` is the target of every
# jax.device_put call later in the notebook.
cpu, device = setup_device(cfg.gpus, default_platform="cpu")
# Root PRNG key with a fixed seed of 0; it is split in the training loop to
# obtain a fresh task-sampling key per step.
rng = random.PRNGKey(0)
# Bare expression: shows the selected device as the cell output.
device

GpuDevice(id=0)

In [8]:
# Load the train/validation arrays for cfg.dataset from cfg.data_dir.
# `preprocess_fn` is applied to every sampled image batch in the training loop.
# NOTE(review): `device` is passed in, so the arrays are presumably placed on
# it by the helper — confirm in data.prepare_data.
train_images, train_labels, val_images, val_labels, preprocess_fn = prepare_data(
    cfg.dataset, cfg.data_dir, device,
)

In [9]:
# Sanity check of the loaded arrays. The recorded run printed
# (64, 600, 84, 84, 3) images / (64, 600) labels for train and
# (16, 600, 84, 84, 3) / (16, 600) for validation.
print("Train data:", train_images.shape, train_labels.shape)
print("Val data:", val_images.shape, val_labels.shape)

Train data: (64, 600, 84, 84, 3) (64, 600)
Val data: (16, 600, 84, 84, 3) (16, 600)


In [17]:
# Build the two model parts. `body` is applied with the slow parameters
# (see slow_apply) and `head` with the fast parameters (see
# fast_apply_and_loss_fn); both expose `.init` and `.apply`.
body, head = prepare_model(cfg.dataset, cfg.way, cfg.hidden_size, cfg.activation)

In [32]:
# Loss used by fast_apply_and_loss_fn, which unpacks its result as
# `loss, *aux` — i.e. the first element is the loss, the rest are auxiliaries.
loss_fn = mean_xe_and_acc_dict

def slow_apply(rng, slow_params, slow_state, is_training, inputs):
    """Run the `body` network with the slow parameters.

    This is a thin adapter that reorders the arguments into the order
    `body.apply` is called with: (params, state, rng, inputs, is_training).

    Returns:
        Whatever `body.apply` returns, unchanged.
    """
    body_output = body.apply(
        slow_params,
        slow_state,
        rng,
        inputs,
        is_training,
    )
    return body_output

def fast_apply_and_loss_fn(
    rng, fast_params, fast_state, is_training, inputs, targets
):
    """Apply the `head` network with the fast parameters and compute the loss.

    Args:
        rng: PRNG key forwarded to `head.apply` (may be None).
        fast_params: parameters of the head.
        fast_state: state of the head.
        is_training: flag forwarded to `head.apply`.
        inputs: inputs to the head.
        targets: targets passed to `loss_fn` together with the logits.

    Returns:
        A pair `(loss, (state, *aux))`: `loss` is the first element returned
        by `loss_fn`, `state` is the state returned by `head.apply`, and
        `aux` are the remaining elements returned by `loss_fn`.
    """
    logits, state = head.apply(
        fast_params, fast_state, rng, inputs, is_training
    )
    loss, *aux = loss_fn(logits, targets)
    # The loss comes first and everything else is bundled into one auxiliary
    # tuple, so the result has a (value, aux) shape.
    return loss, (state, *aux)

In [33]:
def step(
    step_num,
    outer_opt_state,
    slow_params,
    fast_params,
    slow_state,
    fast_state,
    x_spt,
    y_spt,
    x_qry,
    y_qry,
    inner_opt_init,
    outer_opt_update,
):
    """Perform one outer (meta) optimisation step on a batch of tasks.

    Reads the notebook globals `batched_outer_loop` and `cfg.num_inner_steps`.

    Args:
        step_num: index of the current outer step (not used in the body).
        outer_opt_state: state of the outer optimizer.
        slow_params, fast_params: parameters being meta-learned.
        slow_state, fast_state: model state forwarded to the outer loop.
        x_spt, y_spt: support inputs / labels.
        x_qry, y_qry: query inputs / labels.
        inner_opt_init: callable building the inner optimizer state from
            `fast_params`.
        outer_opt_update: outer optimizer update function, called as
            `(grads, opt_state, params)`.

    Returns:
        `(outer_opt_state, slow_params, fast_params, slow_state, fast_state,
        info)` with the optimizer state and parameters updated; the states
        and `info` are the auxiliary output of the outer loop.
    """
    inner_opt_state = inner_opt_init(fast_params)
    meta_params = (slow_params, fast_params)

    # Differentiate with respect to positional arguments 1 and 2 of the outer
    # loop, i.e. `slow_params` and `fast_params`.
    loss_and_grad_fn = value_and_grad(batched_outer_loop, (1, 2), has_aux=True)
    (_outer_loss, (new_slow_state, new_fast_state, info)), grads = loss_and_grad_fn(
        None,  # rng
        slow_params,
        fast_params,
        slow_state,
        fast_state,
        True,  # is_training
        inner_opt_state,
        x_spt,
        y_spt,
        x_qry,
        y_qry,
        cfg.num_inner_steps,
    )

    updates, new_outer_opt_state = outer_opt_update(
        grads, outer_opt_state, meta_params
    )
    new_slow_params, new_fast_params = ox.apply_updates(meta_params, updates)

    return (
        new_outer_opt_state,
        new_slow_params,
        new_fast_params,
        new_slow_state,
        new_fast_state,
        info,
    )

In [61]:
# Plain SGD for the inner (adaptation) steps; only its `update` function is
# handed to the loop factory, while `inner_opt.init` is passed to `step`.
inner_opt = ox.sgd(cfg.inner_lr)
# Build the inner/outer loop functions from the two apply callbacks above.
# NOTE(review): the exact effect of update_state=False is defined in
# lib.make_fsl_inner_outer_loop — confirm there.
inner_loop, outer_loop = make_fsl_inner_outer_loop(
    slow_apply,
    fast_apply_and_loss_fn,
    inner_opt.update,
    cfg.num_inner_steps,
    update_state=False,
)
# Batched variant of the outer loop; this is the function `step`
# differentiates, looked up there as a notebook global.
batched_outer_loop = make_batched_outer_loop(outer_loop)

In [62]:
# Initialise the parameters and states of both model parts.
initial_trees = make_params(
    rng, cfg.dataset, body.init, body.apply, head.init, device
)
# Place each of the four pytrees on the selected device.
slow_params, fast_params, slow_state, fast_state = (
    jax.device_put(tree, device) for tree in initial_trees
)

In [63]:
# Cosine schedule starting at -cfg.outer_lr over cfg.num_outer_steps steps.
# The value is negated because it is used below as a raw scale on the updates
# that `step` hands to ox.apply_updates.
# NOTE(review): the third positional argument (0.1) looks like optax's
# `alpha`, the final fraction of the initial value — confirm against the
# installed optax version.
# NOTE(review): the training loop's `lr` progress-bar readout calls this
# schedule directly, so it displays a negative number.
lr_schedule = ox.cosine_decay_schedule(-cfg.outer_lr, cfg.num_outer_steps, 0.1)
outer_opt = ox.chain(
    ox.clip(10),  # applied first; presumably element-wise clipping to [-10, 10] — confirm
    ox.scale_by_adam(),
    ox.scale_by_schedule(lr_schedule),  # last link: the training loop reads `.count` from its state
)
# The optimizer state is built over the (slow_params, fast_params) tuple, the
# same structure `step` passes to the update function.
outer_opt_state = outer_opt.init((slow_params, fast_params))

In [64]:
# Compile `step` unless jit was disabled in the config. Positional arguments
# 10 and 11 of `step` are `inner_opt_init` and `outer_opt_update` — plain
# Python callables, hence marked static.
# NOTE(review): this rebinds the name `step`, so re-running the cell wraps the
# already-jitted function again.
if jit_enabled:
    step = jit(step, static_argnums=(10, 11))
    # validate = jit(validate, static_argnums=(2, 3, 4))
    # validation_loss_acc_fn = jit(validation_loss_acc_fn)

In [None]:
# Progress bar over the outer (meta) training steps.
pbar = tqdmn(range(cfg.num_outer_steps))

for i in pbar:
    # Advance the root key and draw a fresh key for this step's task batch.
    rng, rng_sample = split(rng)
    x, y = fsl_sample_and_build(
        rng_sample,
        train_images,
        train_labels,
        cfg.batch_size,
        cfg.way,
        cfg.shot,
        cfg.qry_shot,
        cfg.disjoint_tasks,
    )
    x = preprocess_fn(jax.device_put(x, device))
    y = jax.device_put(y, device)

    # Split along axis 2 at index cfg.shot into support and query parts, then
    # flatten each task to a single example axis.
    # NOTE(review): assumes axis 2 holds the shot + qry_shot examples per
    # class — confirm against fsl_sample_and_build.
    image_shape = x.shape[-3:]
    num_spt = cfg.way * cfg.shot
    num_qry = cfg.way * cfg.qry_shot

    x_spt, x_qry = jnp.split(x, (cfg.shot,), 2)
    y_spt, y_qry = jnp.split(y, (cfg.shot,), 2)

    x_spt = x_spt.reshape(cfg.batch_size, num_spt, *image_shape)
    x_qry = x_qry.reshape(cfg.batch_size, num_qry, *image_shape)
    y_spt = y_spt.reshape(cfg.batch_size, num_spt)
    y_qry = y_qry.reshape(cfg.batch_size, num_qry)

    (
        outer_opt_state,
        slow_params,
        fast_params,
        slow_state,
        fast_state,
        info,
    ) = step(
        i,
        outer_opt_state,
        slow_params,
        fast_params,
        slow_state,
        fast_state,
        x_spt,
        y_spt,
        x_qry,
        y_qry,
        inner_opt.init,
        outer_opt.update,
    )

    # Refresh the metrics every cfg.progress_bar_refresh_rate steps and on the
    # very first step. Validation metrics are not computed in this loop.
    if (i + 1) % cfg.progress_bar_refresh_rate == 0 or i == 0:
        # The last element of the chained optimizer state carries the step
        # counter used to evaluate the schedule.
        current_lr = lr_schedule(outer_opt_state[-1].count)
        outer_info = info['outer']
        inner_info = info['inner']
        pbar.set_postfix(
            lr=f"{current_lr:.4f}",
            loss=f"{outer_info['final_loss'].mean():.3f}",
            foa=f"{outer_info['final_aux'][0]['acc'].mean():.3f}",
            fia=f"{inner_info['final_aux'][0]['acc'].mean():.3f}",
            iil=f"{inner_info['initial_loss'].mean():.3f}",
            fil=f"{inner_info['final_loss'].mean():.3f}",
        )


HBox(children=(FloatProgress(value=0.0, max=10000.0), HTML(value='')))