In [1]:
# Reload edited project modules automatically before each cell runs.
# NOTE(review): the next cell repeats these two magics; one copy is redundant.
%load_ext autoreload
%autoreload 2

In [2]:
%load_ext autoreload
%autoreload 2

import os
import sys

# Make the project sources (train_fsl, lib, models, ...) importable from this notebook.
# NOTE(review): the path is relative to the kernel's working directory — assumes the
# notebook is launched from a directory that is a sibling of src/.
sys.path.insert(0, '../src')

In [10]:
from easydict import EasyDict as edict
import numpy as onp

import jax
from jax import random
from jax.random import split
from jax import numpy as jnp
from jax.tree_util import Partial as partial

from jax.experimental import optix, optimizers

import haiku as hk

from lib import setup_device
from train_fsl import loss_fn, validate, prepare_model, make_params, prepare_data, fsl_sample_transfer_and_build
from models.maml_conv import MiniImagenetCNNBody
from lib import make_fsl_inner_outer_loop, make_batched_outer_loop

In [4]:
# Attribute-style config container (presumably mirroring the training script's CLI args).
args = edict()

In [5]:
args.dataset="omniglot"
args.way=5
args.shot=1
args.hidden_size=32
args.activation="relu"
args.gpus=1
args.inner_lr=5e-1
args.outer_lr=1e-2
args.num_inner_steps=1
args.prefetch_data_gpu=False
args.qry_shot=15

args.val_num_tasks = 100
args.val_batch_size = 25

default_plaform = "gpu"
cpu, device = setup_device(args.gpus, default_plaform)

rng = random.PRNGKey(0)

In [6]:
if args.dataset == "miniimagenet":
    args.data_dir = "/workspace1/samenabar/data/mini-imagenet/"
elif args.dataset == "omniglot":
    args.data_dir = "/workspace1/samenabar/data/omniglot/"

train_images, train_labels, val_images, val_labels, preprocess_fn = prepare_data(
    args.dataset, args.data_dir, cpu, device, args.prefetch_data_gpu,
)

In [7]:
# Build the model pieces: two Haiku modules plus the slow (body) apply function and the
# fast (head) apply-and-loss function.
# NOTE(review): this rebinds MiniImagenetCNNBody, shadowing the class imported from
# models.maml_conv in the imports cell.
(
    MiniImagenetCNNBody,
    MiniImagenetCNNHead,
    slow_apply,
    fast_apply_and_loss_fn,
) = prepare_model(args.dataset, args.way, args.hidden_size, args.activation)

# Initialise slow/fast parameters and states and place them on `device`.
slow_params, fast_params, slow_state, fast_state = make_params(
    rng, args.dataset, MiniImagenetCNNBody.init, slow_apply, MiniImagenetCNNHead.init, device,
)

In [8]:
# Inner-loop (task adaptation) optimiser: plain SGD at args.inner_lr.
# NOTE(review): chain() around a single transformation is redundant.
inner_opt = optix.chain(optix.sgd(args.inner_lr))
inner_loop, outer_loop = make_fsl_inner_outer_loop(
    slow_apply,
    fast_apply_and_loss_fn,
    inner_opt.update,
    args.num_inner_steps,
    update_state=False,
)
# Wrap outer_loop so it runs over a batch of tasks (presumably via vmap, as in the
# local make_batched_outer_loop of the second notebook).
batched_outer_loop = make_batched_outer_loop(outer_loop)

In [11]:
# Outer (meta) optimiser: Adam from jax.experimental.optimizers over the
# (slow_params, fast_params) pair.
outer_opt_init, outer_opt_update, outer_get_params = optimizers.adam(
    step_size=args.outer_lr,
)
outer_opt_state = outer_opt_init((slow_params, fast_params))

In [15]:
import dill

In [32]:
# Inspect the third field of the optimiser state; per the output below it is a tuple of
# per-leaf PyTreeDefs (presumably OptimizerState.subtree_defs).
outer_opt_state[2]

(PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]),
 PyTreeDef(tuple, [*,*,*]))

In [33]:
# Draft checkpoint payload built around the jax.experimental.optimizers state
# (the actual dump below is commented out).
# NOTE(review): "i" is a hardcoded placeholder step counter.
dump_obj = {
    "optimizer_state": outer_opt_state,
    "slow_state": slow_state,
    "fast_state": fast_state,
    "rng": rng,
    "i": 1,
}
# with open("ckpt.ckpt", "wb") as f:
#     dill.dump(dump_obj, f, protocol=3)

In [34]:
import flax

In [None]:
# (Empty scratch cell: the stray, never-executed token `Op` was removed — it is not
# defined anywhere in the notebook and would raise NameError under Restart & Run All.)

In [46]:
# Inspect the pytree structure of the (slow_params, fast_params) tuple held by the
# optimiser state (the trailing-dot fragment was a SyntaxError; the recorded output
# below is a PyTreeDef, i.e. the `tree_def` field also used further down).
outer_opt_state.tree_def

PyTreeDef(tuple, [PyTreeDef(<class 'haiku._src.data_structures.FlatMapping'>[PyTreeDef(dict[['mini_imagenet_cnn_body/conv_base/conv_block/batch_norm', 'mini_imagenet_cnn_body/conv_base/conv_block/conv2_d', 'mini_imagenet_cnn_body/conv_base/conv_block_1/batch_norm', 'mini_imagenet_cnn_body/conv_base/conv_block_1/conv2_d', 'mini_imagenet_cnn_body/conv_base/conv_block_2/batch_norm', 'mini_imagenet_cnn_body/conv_base/conv_block_2/conv2_d', 'mini_imagenet_cnn_body/conv_base/conv_block_3/batch_norm', 'mini_imagenet_cnn_body/conv_base/conv_block_3/conv2_d']], [PyTreeDef(<class 'haiku._src.data_structures.FlatMapping'>[PyTreeDef(dict[['offset', 'scale']], [*,*])], [*,*]),PyTreeDef(<class 'haiku._src.data_structures.FlatMapping'>[PyTreeDef(dict[['b', 'w']], [*,*])], [*,*]),PyTreeDef(<class 'haiku._src.data_structures.FlatMapping'>[PyTreeDef(dict[['offset', 'scale']], [*,*])], [*,*]),PyTreeDef(<class 'haiku._src.data_structures.FlatMapping'>[PyTreeDef(dict[['b', 'w']], [*,*])], [*,*]),PyTreeDef(

In [54]:
# Inspect the raw per-leaf arrays packed inside the optimiser state.
outer_opt_state.packed_state

([DeviceArray([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                  0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                  0., 0., 0., 0.]]]], dtype=float32),
  DeviceArray([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                  0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                  0., 0., 0., 0.]]]], dtype=float32),
  DeviceArray([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                  0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                  0., 0., 0., 0.]]]], dtype=float32)],
 [DeviceArray([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                  1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                  1., 1., 1., 1.]]]], dtype=float32),
  DeviceArray([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                  0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                  0., 0., 0., 0.]]]], dtype

In [53]:
# Attempt to msgpack-serialise the optimiser's tree structure with flax.
# As the recorded output shows, this raises TypeError: PyTreeDef objects are not
# msgpack-serialisable.
flax.serialization.msgpack_serialize(outer_opt_state.tree_def)

TypeError: can not serialize 'jaxlib.pytree.PyTreeDef' object

In [56]:
import optax

In [57]:
# Outer optimiser: element-wise clipping of updates to [-10, 10], Adam scaling, then a
# cosine-decayed step size starting at 1e-2.
# optax updates are *added* to the parameters by `optax.apply_updates`, so the step
# size must be negative for gradient descent; with a positive schedule and no
# negation the chain would ascend the loss. Negating inside the schedule keeps the
# schedule as the last transformation, so opt_state[-1] is still the schedule state
# (with its `count`).
outer_lr_schedule = optax.cosine_decay_schedule(1e-2, 20000, 0.01)
optimizer = optax.chain(
    optax.clip(10),
    optax.scale_by_adam(),
    optax.scale_by_schedule(lambda count: -outer_lr_schedule(count)),
)

In [59]:
# Adam moments must mirror the parameters being optimised — (slow_params, fast_params),
# as in the earlier `outer_opt_init` call — not the fast network's *state*.
opt_state = optimizer.init((slow_params, fast_params))

In [68]:
# Evaluate a cosine schedule at the optimiser's current step count (0 right after init,
# so the result is the schedule's initial value).
# NOTE(review): this uses init value 1e-1, whereas the optimiser above is built with 1e-2.
optax.cosine_decay_schedule(1e-1, 20000, 0.01)(opt_state[-1].count).item()

0.10000000149011612

In [61]:
# Display the optax chain state: one entry per transformation in `optimizer`.
opt_state

[ClipState(),
 ScaleByAdamState(count=DeviceArray(0, dtype=int32), mu=(FlatMapping({
   'mini_imagenet_cnn_body/conv_base/conv_block/batch_norm': FlatMapping({
                                                               'offset': DeviceArray([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                                                                         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                                                                         0., 0., 0., 0.]]]], dtype=float32),
                                                               'scale': DeviceArray([[[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                                                                        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                                                                                        0., 0., 0., 0.]]]], dtype=float32),
             

In [69]:
# Checkpoint payload: optax optimiser state, slow/fast parameters and states, and the RNG key.
# NOTE(review): "i" is a hardcoded placeholder step. dill pickles arbitrary objects,
# so the resulting file must only be loaded from trusted sources.
dump_obj = {
    "optimizer_state": opt_state,
    "slow_params": slow_params,
    "fast_params": fast_params,
    "slow_state": slow_state,
    "fast_state": fast_state,
    "rng": rng,
    "i": 1,
}
with open("ckpt.ckpt", "wb") as f:
    dill.dump(dump_obj, f)

Separation — second notebook: evaluating a saved checkpoint

In [2]:
# Second notebook setup: auto-reload project modules and put ../src on the import path.
%load_ext autoreload
%autoreload 2

import os
import sys

# NOTE(review): relative to the kernel's working directory.
sys.path.insert(0, '../src')

In [39]:
import jax
from jax import numpy as jnp, vmap, grad, value_and_grad, random, jit
from jax.random import split
from jax.tree_util import Partial as partial

import optax
import dill

from train_fsl import make_params, prepare_model
from lib import setup_device
from data import prepare_data
from models.maml_conv import MiniImagenetCNNMaker
from data.sampling import fsl_sample_transfer_and_build
from lib import make_fsl_inner_loop # make_batched_outer_loop, 

In [4]:
# Host CPU and accelerator device handles, used for data and parameter placement below.
cpu, gpu = setup_device(1)
# Fixed seed for reproducible task sampling during evaluation.
rng = random.PRNGKey(0)

In [5]:
# Load mini-ImageNet train/val splits plus the preprocessing function. The trailing
# False presumably disables GPU prefetching (cf. `prefetch_data_gpu` in the first notebook).
# NOTE(review): hardcoded absolute path — consider a configurable data directory.
train_images, train_labels, val_images, val_labels, preprocess_fn = prepare_data(
    "miniimagenet", "/workspace1/samenabar/data/mini-imagenet/", cpu, gpu, False,
)

In [6]:
with open("../experiments/Aug26-2020/mi-5-way-5-shot-big-bsz-sch/sch-bsz-20-ilr-0.1-olr-0.01/checkpoints/best.ckpt", "rb") as f:
    state = dill.load(f)

In [7]:
# Unpack the restored parameters and model states.
slow_params, fast_params = state["slow_params"], state["fast_params"]
slow_state, fast_state = state["slow_state"], state["fast_state"]

# Move every leaf of the restored pytrees onto the accelerator. `jax.tree_util.tree_map`
# is the canonical spelling; the top-level `jax.tree_map` alias is deprecated.
slow_params, fast_params, slow_state, fast_state = jax.tree_util.tree_map(
    lambda x: jax.device_put(x, gpu),
    (slow_params, fast_params, slow_state, fast_state),
)

In [8]:
# Rebuild the 5-way, hidden-size-32, ReLU mini-ImageNet model (presumably the same
# architecture the restored checkpoint was trained with).
(
    MiniImagenetCNNBody,
    MiniImagenetCNNHead,
    slow_apply,
    fast_apply_and_loss_fn,
) = prepare_model("miniimagenet", 5, 32, "relu")

# NOTE(review): superseded — the evaluation cell rebinds inner_opt with learning_rate=1e-1.
inner_opt = optax.sgd(learning_rate=1e-2)

In [9]:
## Outer loop

def outer_loop(
    inner_loop,
    rng,
    slow_params,
    fast_params,
    slow_state,
    fast_state,
    is_training,
    x_spt,
    y_spt,
    x_qry,
    y_qry,
    num_steps,
    slow_apply,
    fast_apply_and_loss_fn,
    inner_opt,
    inner_opt_state=None,
    update_state=False,
):
    """Adapt on one task's support set and score the result on its query set.

    The slow network is applied once to `x_qry` with the *incoming* slow state; those
    query features are reused both before and after adaptation, since `inner_loop`
    only returns new fast parameters (slow parameters are never changed here).

    Args:
        inner_loop: adaptation routine, called with keyword arguments and returning
            `(fast_params, slow_state, fast_state, inner_info)`.
        rng: PRNG key forwarded to every apply call and to `inner_loop`.
        slow_params, fast_params, slow_state, fast_state: model parameters/states.
        is_training: flag forwarded to the apply functions and `inner_loop`.
        x_spt, y_spt: support inputs/targets (used only by `inner_loop`).
        x_qry, y_qry: query inputs/targets.
        num_steps: number of inner adaptation steps.
        slow_apply: `(rng, params, state, is_training, x) -> (outputs, state)`.
        fast_apply_and_loss_fn: `(rng, params, state, is_training, *outputs, y)
            -> (loss, (state, *aux))`.
        inner_opt: optimiser exposing `init(params)` and `update`.
        inner_opt_state: optional initial inner optimiser state; created from
            `fast_params` via `inner_opt.init` when None.
        update_state: forwarded to `inner_loop`.

    Returns:
        `(final_loss, (slow_state, fast_state, info))` where the states are the ones
        returned by `inner_loop` and `info` is
        `{"inner": inner_info, "outer": {initial/final query loss and aux lists}}`.
    """
    opt_state = inner_opt.init(fast_params) if inner_opt_state is None else inner_opt_state

    # Query features, computed once; the slow state returned here is discarded.
    qry_features, _ = slow_apply(rng, slow_params, slow_state, is_training, x_qry)

    # Query loss before any adaptation.
    loss_before, (_, *aux_before) = fast_apply_and_loss_fn(
        rng, fast_params, fast_state, is_training, *qry_features, y_qry,
    )

    adapted_fast_params, new_slow_state, new_fast_state, inner_info = inner_loop(
        rng=rng,
        slow_params=slow_params,
        fast_params=fast_params,
        slow_state=slow_state,
        fast_state=fast_state,
        is_training=is_training,
        opt_state=opt_state,
        x_spt=x_spt,
        y_spt=y_spt,
        slow_apply=slow_apply,
        fast_apply_and_loss_fn=fast_apply_and_loss_fn,
        opt_update_fn=inner_opt.update,
        num_steps=num_steps,
        update_state=update_state,
    )

    # Query loss with the adapted head; the state it returns is not propagated.
    loss_after, (_, *aux_after) = fast_apply_and_loss_fn(
        rng, adapted_fast_params, new_fast_state, is_training, *qry_features, y_qry,
    )

    outer_info = {
        "initial_loss": loss_before,
        "final_loss": loss_after,
        "initial_aux": aux_before,
        "final_aux": aux_after,
    }
    return loss_after, (new_slow_state, new_fast_state, {"inner": inner_info, "outer": outer_info})

In [10]:
## Inner loop

def fsl_inner_loop(
    rng,
    slow_params,
    fast_params,
    slow_state,
    fast_state,
    is_training,
    opt_state,
    x_spt,
    y_spt,
    slow_apply,
    fast_apply_and_loss_fn,
    opt_update_fn,
    num_steps,
    update_state=False,
):
    """Adapt `fast_params` on the support set for `num_steps` optimiser steps.

    The slow network is applied once to `x_spt`; its outputs are reused for every step,
    and only `fast_params` are differentiated and updated.

    Args:
        rng: PRNG key forwarded unchanged to every apply call.
        slow_params: slow-network parameters (never modified here).
        fast_params: initial fast-network parameters to adapt.
        slow_state: slow-network state passed to `slow_apply`.
        fast_state: initial fast-network state.
        is_training: flag forwarded to `slow_apply` and to the adaptation steps.
        opt_state: initial state for `opt_update_fn`.
        x_spt, y_spt: support inputs and targets.
        slow_apply: `(rng, params, state, is_training, x) -> (outputs, new_state)`;
            `outputs` is splatted into `fast_apply_and_loss_fn`.
        fast_apply_and_loss_fn: `(rng, params, state, is_training, *slow_outputs, y)
            -> (loss, (new_state, *aux))`.
        opt_update_fn: optax-style `(grads, opt_state, params) -> (updates, opt_state)`.
        num_steps: number of adaptation steps; with 0 the parameters are unchanged.
        update_state: if True, carry the fast state returned by each step forward.

    Returns:
        `(fast_params, slow_state, fast_state, info)`. `info["initial_*"]` is the support
        loss/aux at the first step (before any update); `info["final_*"]` is the support
        loss/aux after the last update, always evaluated with `is_training=False`. When
        `num_steps == 0`, the initial values fall back to the final evaluation.
    """
    # NOTE(review): slow_state is always replaced by the state slow_apply returns,
    # regardless of `update_state`; only the fast state is gated by that flag.
    slow_outputs, slow_state = slow_apply(
        rng, slow_params, slow_state, is_training, x_spt,
    )

    # Differentiate w.r.t. argument 1 (fast_params); built once instead of per step.
    loss_and_grad_fn = value_and_grad(fast_apply_and_loss_fn, 1, has_aux=True)

    initial_loss = None
    initial_aux = None
    for step in range(num_steps):
        (loss, (new_fast_state, *aux)), grads = loss_and_grad_fn(
            rng, fast_params, fast_state, is_training, *slow_outputs, y_spt
        )
        if update_state:
            fast_state = new_fast_state
        if step == 0:
            initial_loss, initial_aux = loss, aux
        updates, opt_state = opt_update_fn(grads, opt_state, fast_params)
        fast_params = optax.apply_updates(fast_params, updates)

    final_loss, (_, *final_aux) = fast_apply_and_loss_fn(
        rng, fast_params, fast_state, False, *slow_outputs, y_spt
    )

    # With no adaptation steps the loop never ran; previously this raised
    # UnboundLocalError. Report the (unchanged) parameters' loss as the initial one.
    if initial_loss is None:
        initial_loss, initial_aux = final_loss, final_aux

    return (
        fast_params,
        slow_state,
        fast_state,
        {
            "initial_loss": initial_loss,
            "final_loss": final_loss,
            "initial_aux": initial_aux,
            "final_aux": final_aux,
        },
    )

In [11]:
def make_batched_outer_loop(outer_loop):
    """Lift a single-task `outer_loop` into a batched version using `vmap`.

    Parameters, states and any extra keyword arguments are bound into a
    `jax.tree_util.Partial` and shared by all tasks. `rng` and the four data arrays are
    passed as call arguments and therefore mapped over their leading axis, so `rng`
    must already hold one key per task.

    Returns:
        `batched_outer_loop(rng, slow_params, fast_params, slow_state, fast_state,
        bx_spt, by_spt, bx_qry, by_qry, **kwargs)`, returning
        `(mean of per-task losses, (slow_states, fast_states, infos))` with the
        auxiliary outputs stacked along a leading task axis.
    """

    def per_task(rng, x_spt, y_spt, x_qry, y_qry, **shared):
        # vmap hands over the mapped arrays positionally; outer_loop takes keywords.
        return outer_loop(
            rng=rng, x_spt=x_spt, y_spt=y_spt, x_qry=x_qry, y_qry=y_qry, **shared
        )

    def batched_outer_loop(
        rng,  # Assume rng is already split
        slow_params,
        fast_params,
        slow_state,
        fast_state,
        bx_spt,
        by_spt,
        bx_qry,
        by_qry,
        **kwargs,
    ):
        shared_args = dict(
            slow_params=slow_params,
            fast_params=fast_params,
            slow_state=slow_state,
            fast_state=fast_state,
            **kwargs,
        )
        mapped_fn = vmap(partial(per_task, **shared_args))
        task_losses, aux = mapped_fn(rng, bx_spt, by_spt, bx_qry, by_qry)
        slow_states, fast_states, infos = aux
        return task_losses.mean(), (slow_states, fast_states, infos)

    return batched_outer_loop

In [33]:
def test(rng, loss_acc_fn, sample_fn, num_batches):
    results = []
    for i in range(num_batches):
        rng, rng_sample = split(rng)
        x_spt, y_spt, x_qry, y_qry = sample_fn(rng)
        results.append(loss_acc_fn(None, x_spt=x_spt, y_spt=y_spt, x_qry=x_qry, y_qry=y_qry))

    results = jax.tree_util.tree_multimap(lambda x, *xs: jnp.stack(xs), results[0], *results)
    return results

def loss_acc_fn(
    rng,
    slow_params,
    fast_params,
    slow_state,
    fast_state,
    x_spt,
    y_spt,
    x_qry,
    y_qry,
    batched_outer_loop,
):
    """Run `batched_outer_loop` on one batch of tasks and keep only loss and info.

    The post-adaptation slow/fast states returned by the loop are discarded.

    Returns:
        `(outer_loss, info)` exactly as produced by `batched_outer_loop`.
    """
    batch_kwargs = {
        "rng": rng,  # Assume rng is already split
        "slow_params": slow_params,
        "fast_params": fast_params,
        "slow_state": slow_state,
        "fast_state": fast_state,
        "bx_spt": x_spt,
        "by_spt": y_spt,
        "bx_qry": x_qry,
        "by_qry": y_qry,
    }
    outer_loss, aux = batched_outer_loop(**batch_kwargs)
    _, _, info = aux
    return outer_loss, info

In [44]:
# Evaluation-time inner optimiser (rebinds the earlier inner_opt, which used lr 1e-2).
inner_opt = optax.sgd(learning_rate=1e-1)
# Fixed outer_loop settings for evaluation: 10 inner SGD steps, is_training=False,
# and no fast-state updates during adaptation.
test_outer_loop_const_kwargs = {
    "inner_loop": fsl_inner_loop,
    "is_training": False,
    "num_steps": 10,
    "slow_apply": slow_apply,
    "fast_apply_and_loss_fn": fast_apply_and_loss_fn,
    "inner_opt": inner_opt,
    "update_state": False,
}

test_outer_loop = partial(outer_loop, **test_outer_loop_const_kwargs)
# vmap over the task axis of each sampled batch, then jit the whole batched loop.
test_batched_outer_loop = jit(make_batched_outer_loop(test_outer_loop))

# Validation-split sampler: 25 tasks per batch, 5-way 5-shot support, qry_shot=15
# (presumably query examples per class).
sample_fn_const_kwargs = {
    "preprocess_fn": preprocess_fn,
    "images": val_images,
    "labels": val_labels,
    "num_tasks": 25,
    "way": 5,
    "spt_shot": 5,
    "qry_shot": 15,
    "device": gpu,
    "disjoint": False,
}
test_sample_fn = partial(fsl_sample_transfer_and_build, **sample_fn_const_kwargs)

# Bind the restored parameters/states so the jitted function only needs rng and data.
loss_acc_fn_const_kwargs = {
    "slow_params": slow_params,
    "fast_params": fast_params,
    "slow_state": slow_state,
    "fast_state": fast_state,
    "batched_outer_loop": test_batched_outer_loop,
}
test_loss_acc_fn = jit(partial(loss_acc_fn, **loss_acc_fn_const_kwargs))

In [46]:
%%time
# Evaluate 2000 tasks as 2000 // 25 = 80 batches (25 = num_tasks in the sampler settings).
outer_loss, info = test(rng, test_loss_acc_fn, test_sample_fn, 2000 // 25)

CPU times: user 5.78 s, sys: 3.16 s, total: 8.93 s
Wall time: 6.71 s


In [49]:
info["outer"]["final_aux"][0]["acc"].mean(-1).mean(), info["outer"]["final_aux"][0]["acc"].std()

(DeviceArray(0.61568016, dtype=float32),
 DeviceArray(0.09007358, dtype=float32))

In [91]:
def make_test_loss_acc_fn(optimizer, batched_outer_loop, num_steps):
    """Build an evaluation function that adapts starting from a fresh inner-optimiser state.

    Args:
        optimizer: object with an optax-style `init(params)` method; a new state is
            created from `fast_params` on every call.
        batched_outer_loop: callable invoked positionally as
            `(rng, slow_params, fast_params, slow_state, fast_state, is_training,
            opt_state, x_spt, y_spt, x_qry, y_qry, num_steps)` and returning
            `(loss, (slow_states, fast_states, info))`.
        num_steps: number of inner-loop steps forwarded to `batched_outer_loop`.

    Returns:
        `test_loss_acc_fn(rng, slow_params, fast_params, slow_state, fast_state,
        x_spt, y_spt, x_qry, y_qry) -> (outer_loss, info)`.
    """
    # NOTE(review): this positional order does not match the make_batched_outer_loop
    # defined in this notebook (bx_*/by_* right after the states, everything else as
    # keywords); it presumably targets the lib version — confirm before use.

    def test_loss_acc_fn(
        rng,
        slow_params,
        fast_params,
        slow_state,
        fast_state,
        x_spt,
        y_spt,
        x_qry,
        y_qry,
    ):
        opt_state = optimizer.init(fast_params)
        outer_loss, (_, _, info) = batched_outer_loop(
            rng,
            slow_params,
            fast_params,
            slow_state,
            fast_state,
            False,  # is_training
            opt_state,
            x_spt,
            y_spt,
            x_qry,
            y_qry,
            num_steps,
        )
        # Previously returned the undefined name `val_info` (NameError).
        return outer_loss, info

    # Previously the factory fell off the end and returned None.
    return test_loss_acc_fn


def make_test_sample_fn(
    preprocess_fn, images, labels, batch_size, way, shot, qry_shot, device, disjoint
):
    """Bind fixed sampling options to `fsl_sample_transfer_and_build`.

    Args:
        preprocess_fn: preprocessing function forwarded to the sampler.
        images, labels: the data split to sample tasks from.
        batch_size: number of tasks per sampled batch (the sampler's `num_tasks`).
        way: classes per task.
        shot: support examples per class (the sampler's `spt_shot`).
        qry_shot: query-shot setting forwarded as `qry_shot`.
        device: device the sampled batch is transferred to.
        disjoint: forwarded as the sampler's `disjoint` flag.

    Returns:
        A `jax.tree_util.Partial` of `fsl_sample_transfer_and_build` that only needs
        the remaining arguments (presumably the PRNG key).
    """
    return partial(
        fsl_sample_transfer_and_build,
        preprocess_fn=preprocess_fn,
        images=images,
        labels=labels,
        num_tasks=batch_size,
        way=way,
        # The sampler names its support-shot parameter `spt_shot` (as in the
        # evaluation cell's sample_fn_const_kwargs); `shot=` is not that keyword.
        spt_shot=shot,
        qry_shot=qry_shot,
        device=device,
        disjoint=disjoint,
    )