In [1]:
# Enable IPython's autoreload extension in mode 2 so imported modules are reloaded
# before each cell runs (handy while editing the project code imported below).
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Put ../src at the front of the import path; presumably this is where the
# project packages imported below (data, models, lib) live — confirm.
sys.path.insert(0, "../src")

In [3]:
from tqdm.notebook import tqdm as tqdmn
from easydict import EasyDict as edict

import numpy as onp

import jax
from jax.random import split
from jax import random, numpy as jnp, jit, value_and_grad, partial, vmap, tree_multimap, ops

import haiku as hk
import optax as ox

# from test_sup import test_su
from data.sampling import BatchSampler

from data.sampling import fsl_sample_transfer_build, continual_learning_sample_build_transfer, continual_learning_sample
from models.maml_conv import prepare_model, make_params
from data import prepare_data
from lib import setup_device, make_fsl_inner_outer_loop, mean_xe_and_acc_dict, make_batched_outer_loop, fast_apply_and_loss_fn, flatten

In [4]:
# Attribute-style dictionary that collects every experiment setting used by the cells below.
cfg = edict()

In [5]:
# Experiment configuration. Note that inner_lr, outer_lr and num_outer_steps are
# overwritten again in the training-setup cell further down.
cfg.hidden_size = 32
cfg.activation = "relu"
cfg.way = 5
cfg.shot = 5
cfg.qry_shot = 10
cfg.batch_size = 5
cfg.inner_lr = 1e-2
cfg.outer_lr = 1e-3
cfg.num_outer_steps = 10000
# cfg.num_inner_steps = 5
cfg.disjoint_tasks = False
cfg.disable_jit = False
cfg.dataset = "miniimagenet"
cfg.progress_bar_refresh_rate = 50
cfg.train_method = "continual-learning"
cfg.track_bn_stats = False
cfg.prefetch_data_gpu = False
cfg.meta_batch_size = 5
# NOTE(review): absolute, machine-specific path — adjust for your environment.
cfg.data_dir = "/home/samenabar/storage/data/FSL/mini-imagenet/"

# test
# Read by the FSL evaluation cell (num_steps=cfg.num_inner_steps_test); it must be
# defined here or that cell fails with AttributeError on a fresh kernel.
# NOTE(review): 10 is the value from the previously commented-out line — confirm.
cfg.num_inner_steps_test = 10
cfg.meta_batch_size_test = 25
cfg.num_tasks_test = 1000


cfg.gpus = 1

jit_enabled = not cfg.disable_jit

In [6]:
# setup_device is a project helper (lib); judging by the unpacking it returns a CPU
# handle and the selected compute device — confirm against lib.
cpu, device = setup_device(cfg.gpus, default_platform="cpu")
# Root PRNG key with a fixed seed of 0; later cells split or reuse it.
rng = random.PRNGKey(0)
# Display the selected device as the cell output.
device

GpuDevice(id=0)

In [None]:
# Load the dataset named by cfg.dataset from cfg.data_dir. prepare_data is a project
# helper; per the unpacking it yields train/val images and labels plus a preprocessing
# function (exact array layouts are not visible here).
train_images, train_labels, val_images, val_labels, preprocess_fn = prepare_data(
    cfg.dataset, cfg.data_dir, device,
)

In [None]:
# Number of leading entries along axis 1 kept for training; the remainder is held out.
# Presumably axis 0 indexes classes and axis 1 the images of each class — TODO confirm.
TRAIN_SIZE = 500

sup_train_images = train_images[:, :TRAIN_SIZE]
sup_train_labels = train_labels[:, :TRAIN_SIZE]
# These are for supervised learning validation
sup_val_images = train_images[:, TRAIN_SIZE:]
sup_val_labels = train_labels[:, TRAIN_SIZE:]

In [None]:
# Unshuffled batch samplers (batch size 512, last partial batch kept) over the flattened
# supervised train / held-out splits. Both receive the same `rng`.
# NOTE(review): these two names are rebound later by the evaluation-setup cell, which
# builds shuffled samplers with cfg.sup_batch_size_test.
test_sup_spt_sampler = BatchSampler(rng, flatten(sup_train_images, 1), flatten(sup_train_labels), 512, shuffle=False, keep_last=True)
test_sup_qry_sampler = BatchSampler(rng, flatten(sup_val_images, 1), flatten(sup_val_labels), 512, shuffle=False, keep_last=True)

In [None]:
# Sanity check: print the shapes of the supervised train split, the held-out split
# carved from the training classes, and the separate validation data.
print("Train data:", sup_train_images.shape, sup_train_labels.shape)
print("Sup val data:", sup_val_images.shape, sup_val_labels.shape)
print("Val data:", val_images.shape, val_labels.shape)

In [33]:
def cl_inner_loop(
    slow_params,
    fast_params,
    slow_state,
    fast_state,
    opt_state,
    rng,
    inputs,
    targets,
    is_training,
    slow_apply,
    fast_apply,
    loss_fn,
    opt_update_fn,
    num_steps=None,
    update_state=False,
    return_history=True,
):
    """Adapt `fast_params` with one optimizer step per example, visiting examples in order.

    The slow network is applied once to the whole batch; the fast network is then
    updated sequentially, taking a single gradient step on example 0, then on
    example 1, and so on. The loss over the full batch is evaluated before the
    first and after the last update.

    Args:
        slow_params, slow_state: parameters/state passed to `slow_apply`.
        fast_params, fast_state: parameters/state passed to `fast_apply`
            (through `fast_apply_and_loss_fn`); `fast_params` is what gets updated.
        opt_state: optimizer state consumed and produced by `opt_update_fn`.
        rng: PRNG key; split into one key for the slow pass, one for the
            full-batch evaluations and one per example.
        inputs: batch of examples; `inputs.shape[0]` sets the number of updates.
        targets: targets indexed along the same leading axis as `inputs`.
        is_training: forwarded to both apply functions.
        slow_apply: called as `slow_apply(params, state, rng, inputs, is_training)`
            and expected to return `(outputs, new_state)`, where `outputs` is an
            iterable of arrays indexable along the leading axis.
        fast_apply, loss_fn: bound into `fast_apply_and_loss_fn`.
        opt_update_fn: called as `opt_update_fn(grads, opt_state, fast_params)`
            and expected to return `(updates, new_opt_state)`.
        num_steps: unused; accepted so callers can pass it as they do for
            `fsl_inner_loop`.
        update_state: if True, carry the fast state returned by each step forward.
        return_history: if True, `info` holds stacked per-step losses/aux plus
            the initial/final full-batch evaluations; otherwise only the
            initial/final values.

    Returns:
        `(fast_params, slow_state, fast_state, opt_state, info)` where `info` is a dict.
    """
    _fast_apply_and_loss_fn = partial(
        fast_apply_and_loss_fn, fast_apply=fast_apply, loss_fn=loss_fn
    )
    num_examples = inputs.shape[0]
    rng_slow, rng_fast, *rngs = split(rng, num_examples + 2)
    slow_outputs, slow_state = slow_apply(
        slow_params, slow_state, rng_slow, inputs, is_training,
    )

    # Full-batch loss before any adaptation.
    initial_loss, (_, *initial_aux) = _fast_apply_and_loss_fn(
        fast_params, fast_state, rng_fast, slow_outputs, is_training, targets,
    )

    # Build the differentiated function once instead of on every iteration.
    loss_and_grad_fn = value_and_grad(_fast_apply_and_loss_fn, has_aux=True)

    losses = []
    auxs = []

    for i in range(num_examples):
        # `[[i]]` keeps the leading batch axis, so each step sees a batch of size one.
        (loss, (new_fast_state, *aux)), grads = loss_and_grad_fn(
            fast_params,
            fast_state,
            rngs[i],
            [so[[i]] for so in slow_outputs],
            is_training,
            targets[[i]],
        )
        if update_state:
            fast_state = new_fast_state
        if return_history:
            losses.append(loss)
            auxs.append(aux)

        updates, opt_state = opt_update_fn(grads, opt_state, fast_params)
        fast_params = ox.apply_updates(fast_params, updates)

    # Full-batch loss after adaptation (same rng as the initial evaluation).
    final_loss, (_, *final_aux) = _fast_apply_and_loss_fn(
        fast_params, fast_state, rng_fast, slow_outputs, is_training, targets,
    )
    if return_history:
        info = {
            "losses": jnp.stack(losses),
            "auxs": tree_multimap(lambda x, *xs: jnp.stack(xs), auxs[0], *auxs),
            "initial": {"loss": initial_loss, "aux": initial_aux},
            "final": {"loss": final_loss, "aux": final_aux},
        }
    else:
        # No trailing comma here: `info` must be a dict, not a 1-tuple wrapping one.
        info = {
            "losses": {"initial": initial_loss, "final": final_loss},
            "auxs": {"initial": initial_aux, "final": final_aux},
        }

    return (
        fast_params,
        slow_state,
        fast_state,
        opt_state,
        info,
    )

In [None]:
def fsl_inner_loop(
    slow_params,
    fast_params,
    slow_state,
    fast_state,
    opt_state,
    rng,
    inputs,
    targets,
    is_training,
    num_steps,
    slow_apply,
    fast_apply,
    loss_fn,
    opt_update_fn,
    update_state=False,
    return_history=True,
):
    """Adapt `fast_params` with `num_steps` full-batch gradient steps.

    The slow network is applied once; every step then differentiates the loss of
    the fast network on the whole batch and applies one optimizer update. A last
    evaluation after the final update provides the "final" loss.

    Args:
        slow_params, slow_state: parameters/state passed to `slow_apply`.
        fast_params, fast_state: parameters/state passed to `fast_apply`
            (through `fast_apply_and_loss_fn`); `fast_params` is what gets updated.
        opt_state: optimizer state consumed and produced by `opt_update_fn`.
        rng: PRNG key; split into one key for the slow pass, one per step and
            one for the final evaluation.
        inputs, targets: the batch used at every step.
        is_training: forwarded to both apply functions.
        num_steps: number of gradient steps (may be 0).
        slow_apply: called as `slow_apply(params, state, rng, inputs, is_training)`
            and expected to return `(outputs, new_state)`.
        fast_apply, loss_fn: bound into `fast_apply_and_loss_fn`.
        opt_update_fn: called as `opt_update_fn(grads, opt_state, fast_params)`
            and expected to return `(updates, new_opt_state)`.
        update_state: if True, carry the fast state returned by each step forward.
        return_history: if True, `info` holds the stacked losses/aux of every
            step followed by the final evaluation; otherwise only the
            initial/final values.

    Returns:
        `(fast_params, slow_state, fast_state, opt_state, info)` where `info` is a dict.
    """
    _fast_apply_and_loss_fn = partial(
        fast_apply_and_loss_fn, fast_apply=fast_apply, loss_fn=loss_fn
    )
    # One key for the slow pass, then num_steps + 1 keys: one per step plus the final eval.
    rng_slow, *rngs = split(rng, num_steps + 2)
    slow_outputs, slow_state = slow_apply(
        slow_params, slow_state, rng_slow, inputs, is_training,
    )

    # Build the differentiated function once instead of on every iteration.
    loss_and_grad_fn = value_and_grad(_fast_apply_and_loss_fn, has_aux=True)

    losses = []
    auxs = []
    initial_loss = None
    initial_aux = None

    for i in range(num_steps):
        (loss, (new_fast_state, *aux)), grads = loss_and_grad_fn(
            fast_params, fast_state, rngs[i], slow_outputs, is_training, targets,
        )
        if update_state:
            fast_state = new_fast_state
        if i == 0:
            initial_loss = loss
            initial_aux = aux

        if return_history:
            losses.append(loss)
            auxs.append(aux)

        updates, opt_state = opt_update_fn(grads, opt_state, fast_params)
        fast_params = ox.apply_updates(fast_params, updates)

    # Index by num_steps rather than the leaked loop variable so num_steps == 0 works.
    final_loss, (_, *final_aux) = _fast_apply_and_loss_fn(
        fast_params, fast_state, rngs[num_steps], slow_outputs, is_training, targets,
    )
    if num_steps == 0:
        # No update was taken, so the un-adapted evaluation is both initial and final.
        initial_loss, initial_aux = final_loss, final_aux

    if return_history:
        losses.append(final_loss)
        auxs.append(final_aux)
        info = {"losses": jnp.stack(losses), "auxs": tree_multimap(lambda x, *xs: jnp.stack(xs), auxs[0], *auxs)}
    else:
        # No trailing comma here: `info` must be a dict, not a 1-tuple wrapping one.
        info = {
            "losses": {"initial": initial_loss, "final": final_loss},
            "auxs": {"initial": initial_aux, "final": final_aux},
        }

    return (
        fast_params,
        slow_state,
        fast_state,
        opt_state,
        info,
    )

In [None]:
def outer_loop(
    slow_params,
    fast_params,
    slow_state,
    fast_state,
    inner_opt_state,
    rng,
    x_spt,
    y_spt,
    x_qry,
    y_qry,
    spt_classes,
    is_training,
    inner_loop,  # instantiated inner_loop
    slow_apply,
    fast_apply,
    loss_fn,
    train_method=None,
):
    """Evaluate one task: query loss before and after adapting on the support set.

    The slow network is run once on the query inputs. The query loss is computed
    with the incoming fast parameters, `inner_loop` adapts them on the support
    set, and the query loss is computed again with the adapted parameters.

    When `train_method == "fsl-reset-per-task"`, the columns of the head weight
    matrix selected by `spt_classes` are zeroed before anything else happens.

    Returns:
        `(final_loss, (slow_state_after_query_pass, fast_state, info))`, where
        `info["inner"]` is whatever `inner_loop` reported and `info["outer"]`
        holds the initial/final query loss and aux values.
    """
    if train_method == "fsl-reset-per-task":
        print("reseting fast_params")
        # uniques = jnp.unique(y_spt)
        head_key = "mini_imagenet_cnn_head/linear"
        head_w = fast_params[head_key]["w"]
        zero_columns = jnp.zeros((head_w.shape[0], spt_classes.shape[0]))
        reset_w = ops.index_update(head_w, ops.index[:, spt_classes], zero_columns)
        fast_params = hk.data_structures.merge({head_key: {"w": reset_w}})

    head_loss_fn = partial(
        fast_apply_and_loss_fn, fast_apply=fast_apply, loss_fn=loss_fn
    )
    rng_outer_slow, rng_outer_fast, rng_inner = split(rng, 3)

    # Slow features of the query set; shared by both query evaluations.
    slow_outputs, initial_slow_state = slow_apply(
        slow_params, slow_state, rng_outer_slow, x_qry, is_training,
    )

    # Query loss before adaptation (always evaluated with is_training=False).
    initial_loss, (_, *initial_aux) = head_loss_fn(
        fast_params, fast_state, rng_outer_fast, slow_outputs, False, y_qry,
    )

    # Adapt on the support set; the slow state returned by the inner loop is discarded.
    adapted = inner_loop(
        slow_params,
        fast_params,
        slow_state,
        fast_state,
        inner_opt_state,
        rng_inner,
        x_spt,
        y_spt,
    )
    fast_params, _, fast_state, inner_opt_state, inner_auxs = adapted

    # Query loss after adaptation.
    final_loss, (_, *final_aux) = head_loss_fn(
        fast_params, fast_state, rng_outer_fast, slow_outputs, is_training, y_qry,
    )

    outer_info = {
        "initial": {"aux": initial_aux, "loss": initial_loss},
        "final": {"aux": final_aux, "loss": final_loss},
    }
    info = {"inner": inner_auxs, "outer": outer_info}
    return final_loss, (initial_slow_state, fast_state, info)

In [None]:
def batched_outer_loop(slow_params, fast_params, slow_state, fast_state, inner_opt_state, brng, bx_spt, by_spt, bx_qry, by_qry, spt_classes, outer_loop):
    """Run `outer_loop` over a batch of tasks with `vmap` and average the losses.

    The parameters, states and inner optimizer state are shared by all tasks;
    `brng`, the support/query arrays and `spt_classes` are mapped over their
    leading axis.

    Returns:
        `(mean_loss, aux)` where `aux` is the per-task auxiliary output of
        `outer_loop` stacked along a new leading axis.
    """
    # Bind everything that is identical across tasks; only per-task inputs remain.
    per_task_fn = partial(
        outer_loop, slow_params, fast_params, slow_state, fast_state, inner_opt_state,
    )
    task_losses, task_aux = vmap(per_task_fn)(
        brng, bx_spt, by_spt, bx_qry, by_qry, spt_classes,
    )
    mean_loss = task_losses.mean()
    return mean_loss, task_aux

In [None]:
def step(
    rng,
    step_num,
    outer_opt_state,
    slow_params,
    fast_params,
    slow_state,
    fast_state,
    x_spt,
    y_spt,
    x_qry,
    y_qry,
    spt_classes,
    inner_opt_init,
    outer_opt_update,
    batched_outer_loop_ins,
    train_method=None,
):
    """Perform one meta-training update of the slow and fast parameters.

    A fresh inner optimizer state is created from `fast_params`, the batched
    outer loss is differentiated with respect to both parameter sets, and the
    outer optimizer update is applied to them jointly.

    `step_num` and `train_method` are accepted but not used in the body.

    Returns:
        `(outer_opt_state, slow_params, fast_params, slow_state, fast_state, info)`.
    """
    inner_opt_state = inner_opt_init(fast_params)

    # One PRNG key per task in the meta-batch.
    task_rngs = split(rng, x_spt.shape[0])

    # Differentiate w.r.t. positional arguments 0 and 1 (slow and fast params).
    loss_and_grad_fn = value_and_grad(batched_outer_loop_ins, (0, 1), has_aux=True)
    (_outer_loss, (slow_state, fast_state, info)), grads = loss_and_grad_fn(
        slow_params,
        fast_params,
        slow_state,
        fast_state,
        inner_opt_state,
        task_rngs,
        x_spt,
        y_spt,
        x_qry,
        y_qry,
        spt_classes,
    )

    params = (slow_params, fast_params)
    updates, outer_opt_state = outer_opt_update(grads, outer_opt_state, params)
    slow_params, fast_params = ox.apply_updates(params, updates)

    return outer_opt_state, slow_params, fast_params, slow_state, fast_state, info

In [40]:
from data.sampling import fsl_sample

In [41]:
def test_continual_learning_sample(rng, images, labels, num_tasks, way, spt_shot, qry_shot, disjoint=False, shuffle_labels=False):
    """Draw evaluation tasks by delegating to `fsl_sample` with the same arguments.

    Returns:
        The `(inputs, targets)` pair produced by `fsl_sample`.
    """
    sampled_inputs, sampled_targets = fsl_sample(
        rng, images, labels, num_tasks, way, spt_shot, qry_shot, disjoint, shuffle_labels,
    )
    return sampled_inputs, sampled_targets

def test_continual_learning_sample_build_transfer(rng, images, labels, num_tasks, way, spt_shot, qry_shot, preprocess_fn, device=None, disjoint=False, shuffle_labels=False):
    """Sample tasks, move them to `device`, preprocess, and split support from query.

    Args:
        rng, images, labels, num_tasks, way, spt_shot, qry_shot: forwarded to
            `test_continual_learning_sample`.
        preprocess_fn: applied to the sampled inputs after the device transfer.
        device: target passed to `jax.device_put`.
        disjoint, shuffle_labels: forwarded to the sampler (they were previously
            ignored and always sent as False).

    Returns:
        `(x_spt, y_spt, x_qry, y_qry)`; along axis 2 the first `spt_shot` entries
        form the support set and the rest the query set, and each array is then
        passed through `flatten(t, (1, 2))`.
    """
    x, y = test_continual_learning_sample(
        rng, images, labels, num_tasks, way, spt_shot, qry_shot,
        disjoint=disjoint, shuffle_labels=shuffle_labels,
    )
    x, y = jax.device_put(x, device), jax.device_put(y, device)
    x = preprocess_fn(x)
    x_spt, y_spt = x[:, :, :spt_shot], y[:, :, :spt_shot]
    x_qry, y_qry = x[:, :, spt_shot:], y[:, :, spt_shot:]
    return tuple(flatten(t, (1, 2)) for t in (x_spt, y_spt, x_qry, y_qry))

In [47]:
# Sample 2 validation tasks, 16-way, with 10 support and 15 query shots per class
# (positional order: num_tasks, way, spt_shot, qry_shot).
x_spt, y_spt, x_qry, y_qry = test_continual_learning_sample_build_transfer(rng, val_images, val_labels, 2, 16, 10, 15, preprocess_fn, device)

In [48]:
# Display the shapes of the sampled support/query arrays.
x_spt.shape, y_spt.shape, x_qry.shape, y_qry.shape

((2, 160, 84, 84, 3), (2, 160), (2, 240, 84, 84, 3), (2, 240))

In [49]:
from lib import batched_outer_loop

In [53]:
# Wire up the evaluation pipeline for continual learning: plain SGD (lr 1e-2) in the
# inner loop, is_training=False everywhere, and no per-task reset of the head.
# NOTE(review): `body` and `head` are not defined in any cell shown here — they must
# exist in the kernel (e.g. from prepare_model) before this cell runs.
test_inner_opt = ox.sgd(1e-2)
loss_fn = mean_xe_and_acc_dict
test_cl_inner_loop_ins = partial(
    cl_inner_loop,
    is_training=False,
    # num_steps=5,
    slow_apply=body.apply,
    fast_apply=head.apply,
    loss_fn=loss_fn,
    opt_update_fn=test_inner_opt.update,
)

outer_loop_ins = partial(
    outer_loop,
    is_training=False,
    inner_loop=test_cl_inner_loop_ins,
    slow_apply=body.apply,
    fast_apply=head.apply,
    loss_fn=loss_fn,
    # train_method="fsl-reset-per-task",
)
# NOTE(review): a preceding cell imports batched_outer_loop from lib, so this binds the
# lib version rather than the function defined earlier in this notebook.
batched_outer_loop_ins = partial(batched_outer_loop, outer_loop=outer_loop_ins)


In [54]:
# Smoke-test the batched outer loop on the sampled validation tasks, with one PRNG key
# per task and a zero array standing in for spt_classes.
# NOTE(review): slow_params/fast_params/slow_state/fast_state are created in a later
# cell (make_params), so this cell depends on kernel state from an earlier run.
batched_outer_loop_ins(slow_params, fast_params, slow_state, fast_state, test_inner_opt.init(fast_params), split(rng, x_spt.shape[0]), x_spt, y_spt, x_qry, y_qry, jnp.zeros(x_spt.shape[0]))

(DeviceArray(8.702764, dtype=float32),
 (FlatMapping({}),
  FlatMapping({}),
  {'inner': {'auxs': [{'acc': DeviceArray([[0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1.,
                    1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
                    0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.,
                    1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                    0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.,
                    1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
                    0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1.,
                    1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
                    0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1.,
                    1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
                    0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
                   [0., 1., 1., 1., 1.

In [271]:
# Run the continual-learning inner loop on the first sampled task, starting from
# all-zero head weights.
# NOTE(review): zero_fast_params is defined in a later cell, so this cell depends on
# kernel state from an earlier run.
test_cl_inner_loop_ins(slow_params, zero_fast_params, slow_state, fast_state, test_inner_opt.init(zero_fast_params), rng, x_spt[0], y_spt[0])

(FlatMapping({
   'mini_imagenet_cnn_head/linear': FlatMapping({
                                      'w': DeviceArray([[ 4.3427167e-07,  1.3737676e-06, -5.0718722e-07, ...,
                                                         -5.0628262e-07, -5.0628262e-07, -5.0628262e-07],
                                                        [ 5.1206190e-08,  1.1213683e-06,  4.6514092e-06, ...,
                                                         -1.2044611e-06, -1.2044611e-06, -1.2044611e-06],
                                                        [ 2.2674105e-06,  4.6662276e-06,  4.4428316e-06, ...,
                                                         -1.3054758e-06, -1.3054758e-06, -1.3054758e-06],
                                                        ...,
                                                        [ 1.0681804e-06,  5.4417719e-06,  2.8324150e-06, ...,
                                                         -1.7285606e-06, -1.7285606e-06, -1.7285606e-06],
          

In [230]:
# Keyword arguments bound into the training task sampler: 3-way tasks with 5 support
# and 128 query shots, drawn from the supervised training split.
# NOTE(review): the key is spelled "shuffled_labels" here while the helpers defined in
# this notebook use "shuffle_labels" — confirm it matches the sampler's signature.
train_sample_fn_kwargs = {
    "images": sup_train_images,
    "labels": sup_train_labels,
    "batch_size": cfg.meta_batch_size,
    "way": 3,
    "shot": 5,
    "qry_shot": 128,
    "preprocess_fn": preprocess_fn,
    "device": device,
    "shuffled_labels": False,
}
# train_sample_fn = partial(
#     fsl_sample_transfer_build, **train_sample_fn_kwargs,
# )

# After this, the sampler is called with just a PRNG key (see the training loop).
train_sample_fn = partial(continual_learning_sample_build_transfer, **train_sample_fn_kwargs)

In [132]:
# Override the learning rates and step count set in the configuration cell.
cfg.inner_lr = 0.01
cfg.outer_lr = 0.01
cfg.num_outer_steps = 20000

# Initialise parameters and state for the slow (body) and fast (head) networks.
# NOTE(review): `body` and `head` are not defined in any cell shown here.
slow_params, fast_params, slow_state, fast_state = make_params(
    rng,
    cfg.dataset,
    body.init,
    body.apply,
    head.init,
    device,
)
slow_params, fast_params, slow_state, fast_state = map(lambda x: jax.device_put(x, device), (slow_params, fast_params, slow_state, fast_state))

# Inner optimizer: clip to 10, then SGD.
inner_opt = ox.chain(ox.clip(10), ox.sgd(cfg.inner_lr))
# The schedule starts at -outer_lr: the chain below has no explicit sign flip, so the
# negative scale is what turns the Adam-scaled gradient into a descent step.
lr_schedule = ox.cosine_decay_schedule(-cfg.outer_lr, cfg.num_outer_steps, 0.1)
outer_opt = ox.chain(
    ox.clip(10),
    ox.scale_by_adam(),
    ox.scale_by_schedule(lr_schedule),
)
# The outer optimizer updates slow and fast parameters jointly, as a tuple.
outer_opt_state = outer_opt.init((slow_params, fast_params))

loss_fn = mean_xe_and_acc_dict
inner_loop_ins = partial(
    cl_inner_loop,
    is_training=True,
    # num_steps=5,
    slow_apply=body.apply,
    fast_apply=head.apply,
    loss_fn=loss_fn,
    opt_update_fn=inner_opt.update,
)
outer_loop_ins = partial(
    outer_loop,
    is_training=True,
    inner_loop=inner_loop_ins,
    slow_apply=body.apply,
    fast_apply=head.apply,
    loss_fn=loss_fn,
    train_method="fsl-reset-per-task",
)
batched_outer_loop_ins = partial(batched_outer_loop, outer_loop=outer_loop_ins)

# JIT-compile the whole training step with the optimizer functions bound in.
step_ins = jit(
    partial(
        step,
        inner_opt_init=inner_opt.init,
        outer_opt_update=outer_opt.update,
        batched_outer_loop_ins=batched_outer_loop_ins,
    ),
    # static_argnums=(11, 12, 13),
)

In [133]:
# Meta-training loop: sample a batch of tasks, run one jitted step, and refresh the
# progress-bar statistics on the first step and then every 25 steps.
pbar = tqdmn(
    range(cfg.num_outer_steps),
)

for i in pbar:
    # Fresh keys for sampling and for the step; `rng` itself is advanced each iteration.
    rng, rng_sample, rng_step = split(rng, 3)
    x_spt, y_spt, x_qry, y_qry = train_sample_fn(rng_sample)
    # NOTE(review): with axis=1 NumPy returns the unique *columns* of y_spt, not the
    # unique labels of each row — confirm this matches the sampler's label layout.
    spt_classes = jax.device_put(onp.unique(y_spt, axis=1), device)

    outer_opt_state, slow_params, fast_params, slow_state, fast_state, info = step_ins(
        rng_step,
        i,
        outer_opt_state,
        slow_params,
        fast_params,
        slow_state,
        fast_state,
        x_spt,
        y_spt,
        x_qry,
        y_qry,
        spt_classes,
        # trainer.inner_opt.init,
        # trainer.outer_opt.update
    )

    if (((i + 1) % 25) == 0) or (i == 0):
        # The last element of the chained optimizer state belongs to scale_by_schedule;
        # the schedule was built from -outer_lr, so the displayed value is negative.
        current_lr = lr_schedule(outer_opt_state[-1].count)
        # Postfix keys: foa = final outer (query) accuracy, fia/iia = final/initial inner
        # accuracy, iil/fil = initial/final inner loss.
        pbar.set_postfix(
            lr=f"{current_lr:.4f}",
            # vfol=f"{vfol:.3f}",
            # vioa=f"{vioa:.3f}",
            # vfoa=f"{vfoa:.3f}",
            loss=f"{info['outer']['final']['loss'].mean():.3f}",
            foa=f"{info['outer']['final']['aux'][0]['acc'].mean():.3f}",
            # bfoa=f"{best_val_acc:.3f}",
            fia=f"{info['inner']['final']['aux'][0]['acc'].mean():.3f}",
            iia=f"{info['inner']['initial']['aux'][0]['acc'].mean():.3f}",
            iil=f"{info['inner']['initial']['loss'].mean():.3f}",
            fil=f"{info['inner']['final']['loss'].mean():.3f}",
        )


HBox(children=(FloatProgress(value=0.0, max=20000.0), HTML(value='')))

reseting fast_params
reseting fast_params



In [165]:
from test_sup import test_sup_cosine, test_sup
from test_utils import test_fsl_maml, test_fsl_embeddings

In [154]:
# Batch size used by the evaluation samplers built below.
cfg.sup_batch_size_test = 256

In [157]:
# Split the validation data along axis 1 the same way as the training data: the first
# TRAIN_SIZE entries form the transfer "support" pool, the rest the "query" pool.
transfer_spt_images = val_images[:, :TRAIN_SIZE]
transfer_spt_labels = val_labels[:, :TRAIN_SIZE]
transfer_qry_images = val_images[:, TRAIN_SIZE:]
transfer_qry_labels = val_labels[:, TRAIN_SIZE:]

In [216]:
# Few-shot evaluation sampler: cfg.way-way, cfg.shot-shot tasks with 15 query shots,
# drawn from the validation data in batches of cfg.meta_batch_size_test.
test_fsl_sample_fn_kwargs = {
    "images": val_images,
    "labels": val_labels,
    "batch_size": cfg.meta_batch_size_test,
    "way": cfg.way,
    "shot": cfg.shot,
    "qry_shot": 15,  # Standard
    "preprocess_fn": preprocess_fn,
    "device": device,
}
test_fsl_sample_fn = partial(
    fsl_sample_transfer_build, **test_fsl_sample_fn_kwargs,
)
# Advance `rng` and derive four independent keys, one per sampler below.
rng, *rngs_samplers = split(rng, 5)
# Shuffled samplers over the supervised train / held-out splits (rebinding the names
# created by the earlier unshuffled-sampler cell).
test_sup_spt_sampler = BatchSampler(
    rngs_samplers[0],
    flatten(sup_train_images, 1),
    flatten(sup_train_labels),
    cfg.sup_batch_size_test,
    shuffle=True,
    keep_last=True,
)
test_sup_qry_sampler = BatchSampler(
    rngs_samplers[1],
    flatten(sup_val_images, 1),
    flatten(sup_val_labels),
    cfg.sup_batch_size_test,
    shuffle=True,
    keep_last=True,
)
# Shuffled samplers over the transfer (validation-data) support / query pools.
test_transfer_spt_sampler = BatchSampler(
    rngs_samplers[2],
    flatten(transfer_spt_images, 1),
    flatten(transfer_spt_labels),
    cfg.sup_batch_size_test,
    shuffle=True,
    keep_last=True,
)
test_transfer_qry_sampler = BatchSampler(
    rngs_samplers[3],
    flatten(transfer_qry_images, 1),
    flatten(transfer_qry_labels),
    cfg.sup_batch_size_test,
    shuffle=True,
    keep_last=True,
)

# Embedding function: apply the body with no rng and is_training=False, then take
# element [0][0] of its result, i.e. the first entry of the body's outputs.
embeddings_fn = lambda slow_params, slow_state, inputs: body.apply(
    slow_params, slow_state, None, inputs, False
)[0][0]
embeddings_fn = jit(embeddings_fn)

In [None]:
def test_continual_learning()

In [173]:
def sup_test_pred_fn(slow_params, fast_params, slow_state, fast_state, inputs):
    """Run the body then the head in inference mode and return the head's first output.

    Both networks are applied with no rng and `is_training=False`; the body's
    outputs are unpacked as positional inputs to the head.
    """
    features, _ = body.apply(slow_params, slow_state, None, inputs, False)
    head_result = head.apply(fast_params, fast_state, None, *features, False)
    return head_result[0]

In [218]:
# Supervised accuracy of the trained body + head on the held-out images of the
# training classes.
sup_preds, sup_targets = test_sup(
    partial(sup_test_pred_fn, slow_params, fast_params, slow_state, fast_state), test_sup_qry_sampler, device, preprocess_fn,
)
# `onp.float` was only an alias for the builtin `float` and has been removed from
# recent NumPy releases; the builtin gives the same dtype.
sup_acc = (sup_preds == sup_targets).astype(float).mean()
sup_acc

0.329375

In [163]:
# Supervised accuracy obtained from the body's embeddings via the project's
# test_sup_cosine helper, using the support sampler as reference data.
sup_preds_cosine, sup_targets_cosine = test_sup_cosine(
    partial(embeddings_fn, slow_params, slow_state),
    test_sup_spt_sampler,
    test_sup_qry_sampler,
    device,
    preprocess_fn,
)
# `onp.float` was only an alias for the builtin `float` and has been removed from
# recent NumPy releases; the builtin gives the same dtype.
sup_acc_cosine = (sup_preds_cosine == sup_targets_cosine).astype(float).mean()
sup_acc_cosine

0.21984375

In [184]:
# Head parameters with an all-zero (800, 64) weight matrix placed on `device`; used as
# the starting point for adaptation in the evaluation cells.
zero_fast_params = hk.data_structures.merge(
            {
                "mini_imagenet_cnn_head/linear": {
                    "w": jax.device_put(jnp.zeros((800, 64)), device),
                }
            }
        )

In [208]:
# Few-shot evaluation with gradient-based adaptation: SGD (lr 0.01) inner loop starting
# from the zeroed head, is_training=False throughout.
test_inner_opt = ox.sgd(0.01)

# cfg.num_inner_steps_test must be set before this cell runs.
test_inner_loop_ins = partial(
    fsl_inner_loop,
    is_training=False,
    num_steps=cfg.num_inner_steps_test,
    slow_apply=body.apply,
    fast_apply=head.apply,
    loss_fn=mean_xe_and_acc_dict,
    opt_update_fn=test_inner_opt.update,
)
test_outer_loop_ins = partial(
    outer_loop,
    is_training=False,
    inner_loop=test_inner_loop_ins,
    slow_apply=body.apply,
    fast_apply=head.apply,
    loss_fn=mean_xe_and_acc_dict,
)
# spt_classes is bound to None because no per-task head reset is requested here.
test_batched_outer_loop_ins = partial(
    batched_outer_loop, outer_loop=test_outer_loop_ins, spt_classes=None,
)
test_batched_outer_loop_ins = jit(test_batched_outer_loop_ins)

# 200 // cfg.meta_batch_size_test is presumably the number of task batches to
# evaluate — confirm against test_fsl_maml's signature.
fsl_maml_results = test_fsl_maml(
    rng,
    slow_params,
    zero_fast_params,
    slow_state,
    fast_state,
    200 // cfg.meta_batch_size_test,
    test_inner_opt.init,
    test_fsl_sample_fn,
    test_batched_outer_loop_ins,
)

# Mean query loss and mean post-adaptation query accuracy over the evaluated tasks.
fsl_maml_loss = fsl_maml_results[0].mean()
fsl_maml_acc = fsl_maml_results[1]["outer"]["final"]["aux"][0]["acc"].mean()
fsl_maml_loss, fsl_maml_acc

(DeviceArray(1.8134462, dtype=float32), DeviceArray(0.5618667, dtype=float32))

In [209]:
# Few-shot evaluation that uses only the body's embeddings (no gradient adaptation),
# via the project's test_fsl_embeddings helper.
fsl_embeddings_preds, fsl_embeddings_targets = test_fsl_embeddings(
    rng,
    partial(embeddings_fn, slow_params, slow_state),
    test_fsl_sample_fn,
    200 // cfg.meta_batch_size_test,
    device=device,
    pool=0,
)
# `onp.float` was only an alias for the builtin `float` and has been removed from
# recent NumPy releases; the builtin gives the same dtype.
fsl_embeddings_acc = (
    (fsl_embeddings_preds == fsl_embeddings_targets)
    .astype(float)
    .mean()
)
fsl_embeddings_acc

0.5582666666666667