# Continuous-Time Meta-Learning with Forward Mode Differentiation

This notebook shows how to meta-train COMLN. As an example, we use the preprocessed version of miniImageNet introduced by [Rusu et al., 2018](https://arxiv.org/abs/1807.05960), with 5-way 5-shot classification tasks.

In [None]:
import jax.numpy as jnp
import haiku as hk
import optax
import math

from tqdm.auto import tqdm
from scipy.stats import t
from jax import random
from jax_meta.datasets import LEOMiniImagenet

from comln import COMLN, COMLNMetaParameters

In [None]:
meta_train_dataset = LEOMiniImagenet(
    root='data/',
    batch_size=16,
    shots=5,
    ways=5,
    test_shots=15,
    size=2000,
    split='train',
    seed=0,
    download=True
)
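To make the task configuration above concrete, here is a small sketch of the episode sizes it implies. It assumes that `shots` and `test_shots` count examples per class, which is the usual few-shot convention but is not confirmed here for `LEOMiniImagenet`; the variable names are illustrative only.

```python
# Values copied from the dataset configuration above.
ways, shots, test_shots, batch_size = 5, 5, 15, 16

# Assumption: `shots` and `test_shots` are counted per class.
support_per_task = ways * shots        # examples used for adaptation
query_per_task = ways * test_shots     # examples used for the outer loss
examples_per_batch = batch_size * (support_per_task + query_per_task)

print(f'support examples per task: {support_per_task}')
print(f'query examples per task:   {query_per_task}')
print(f'examples per meta-batch:   {examples_per_batch}')
```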

Since the data has already been preprocessed, no feature-extraction network is needed, so the model here is simply the identity function.

In [None]:
@hk.without_apply_rng
@hk.transform_with_state
def model(inputs):
    return inputs

In [None]:
key = random.PRNGKey(0)

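# Each group of meta-parameters gets its own SGD optimizer with Nesterov momentum;
# the COMLNMetaParameters pytree below maps each group to its optimizer label.
optimizer = optax.multi_transform({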
    'model': optax.sgd(1e-1, momentum=0.9, nesterov=True),
    'classifier': optax.sgd(1e-1, momentum=0.9, nesterov=True),
    't_final': optax.sgd(1e-1, momentum=0.9, nesterov=True)
}, COMLNMetaParameters(model='model', classifier='classifier', t_final='t_final'))

metalearner = COMLN(
    model,
    num_ways=meta_train_dataset.ways,
    t_final=1.,
    odeint_kwargs={'atol': 1e-5, 'rtol': 1e-5}
)

params, state = metalearner.init(key, optimizer, meta_train_dataset.dummy_input)

with tqdm(meta_train_dataset, desc='Meta-train') as pbar:
    for batch in pbar:
        params, state, logs = metalearner.step(params, state, batch['train'], batch['test'])
        pbar.set_postfix(
            T=f'{jnp.exp(params.t_final):.2f}',
            accuracy=f'{100 * logs["outer/accuracy"].mean():.2f}',
        )

In [None]:
meta_test_dataset = LEOMiniImagenet(
    root='data/',
    batch_size=10,
    shots=5,
    ways=5,
    test_shots=15,
    size=100,
    split='test',
    seed=0
)

results = metalearner.evaluate(params, state, meta_test_dataset)

In [None]:
accuracy = results['outer/accuracy']
print(f'Accuracy: {accuracy * 100:.2f}%')
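Few-shot results are commonly reported with a 95% confidence interval over test tasks. The sketch below shows one way to compute it with a Student-t interval, using the `math` and `scipy.stats.t` imports from the first cell. It assumes a list of per-task accuracies is available; how `metalearner.evaluate` structures its output is not shown here, so `task_accuracies` and `mean_confidence_interval` are hypothetical names with made-up values.

```python
import math
from scipy.stats import t


def mean_confidence_interval(values, confidence=0.95):
    """Return (mean, half_width) of a two-sided Student-t confidence interval."""
    n = len(values)
    if n < 2:
        raise ValueError('At least two values are needed to estimate the variance.')
    mean = sum(values) / n
    # Unbiased sample variance (divides by n - 1).
    variance = sum((v - mean) ** 2 for v in values) / (n - 1)
    std_err = math.sqrt(variance / n)
    # Two-sided critical value of the t distribution with n - 1 degrees of freedom.
    critical = t.ppf(0.5 * (1 + confidence), df=n - 1)
    return mean, critical * std_err


# Hypothetical per-task accuracies, for illustration only (not real results).
task_accuracies = [0.60, 0.70, 0.80, 0.70]
mean, half_width = mean_confidence_interval(task_accuracies)
print(f'Accuracy: {100 * mean:.2f}% +/- {100 * half_width:.2f}%')
```

With many test tasks the t critical value approaches the normal value of about 1.96, so the choice between the two matters mostly for small evaluations.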