In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

sys.path.insert(0, "../src")

In [6]:
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

In [7]:
import os
import sys
import os.path as osp

import dill
import time
from tqdm.notebook import tqdm as tqdmn
from argparse import ArgumentParser
from easydict import EasyDict as edict

import numpy as onp

import jax
from jax.random import split
from jax import (
    jit,
    vmap,
    random,
    partial,
    tree_map,
    numpy as jnp,
    value_and_grad,
)

import optax as ox
import haiku as hk

from lib import (
    apply_and_loss_fn,
    
    flatten,
    outer_loop,
    setup_device,
    fsl_inner_loop,
    batched_outer_loop,
    parse_and_build_cfg,
    mean_xe_and_acc_dict,
    # outer_loop_reset_per_task,
)
from data import prepare_data
from experiment import Experiment, Logger
from data.sampling import fsl_sample_transfer_build, BatchSampler
from models.maml_conv import miniimagenet_cnn_argparse, prepare_model, make_params

from test_sup import test_sup_cosine
from test_utils import test_fsl_embeddings

TRAIN_SIZE = 500

In [8]:
cfg = edict()

In [9]:
cfg.seed = 0
cfg.gpus = 1
cfg.data_dir = "/home/samenabar/storage/data/FSL/mini-imagenet/"
cfg.dataset = "miniimagenet"

cfg.model = edict()
cfg.model.hidden_size = 32
cfg.model.activation = "relu"
cfg.model.no_track_bn_stats = False

In [10]:
cpu, device = setup_device(cfg.gpus, default_platform="cpu")
rng = random.PRNGKey(cfg.seed)  # Default seed is 0
print(f"Running on {device} with seed: {cfg.seed}")

Running on gpu:0 with seed: 0


In [267]:
train_images, train_labels, val_images, val_labels, preprocess_fn = prepare_data(
    cfg.dataset, cfg.data_dir, device,
)
fsl_train_images = train_images[:, :TRAIN_SIZE]
fsl_train_labels = train_labels[:, :TRAIN_SIZE]
# These are for supervised learning validation
sup_val_images = flatten(train_images[:, TRAIN_SIZE:], 1)
sup_val_labels = flatten(train_labels[:, TRAIN_SIZE:], 1)

transfer_spt_images = flatten(val_images[:, :TRAIN_SIZE], (0, 1))
transfer_spt_labels = flatten(val_labels[:, :TRAIN_SIZE], (0, 1))
transfer_qry_images = flatten(val_images[:, TRAIN_SIZE:], (0, 1))
transfer_qry_labels = flatten(val_labels[:, TRAIN_SIZE:], (0, 1))

print("Supervised train data:", fsl_train_images.shape, fsl_train_labels.shape)
print(
    "Supervised validation data:", sup_val_images.shape, sup_val_labels.shape,
)
print(
    "FSL and Transfer learning data:", val_images.shape, val_labels.shape,
)
print("Transfer", transfer_spt_images.shape, transfer_spt_labels.shape, transfer_qry_images.shape, transfer_qry_labels.shape)

Supervised train data: (64, 500, 84, 84, 3) (64, 500)
Supervised validation data: (6400, 84, 84, 3) (6400,)
FSL and Transfer learning data: (16, 600, 84, 84, 3) (16, 600)
Transfer (8000, 84, 84, 3) (8000,) (1600, 84, 84, 3) (1600,)


In [106]:
from data.sampling import shuffle_along_axis

In [110]:
num_tasks = 4
shuffle_along_axis(
    rng, jnp.arange(val_images.shape[0])[None, :].repeat(num_tasks, 0), 1
)[:, :val_images.shape[0]]

DeviceArray([[10,  6,  1, 13, 11, 15,  9,  8,  0,  4,  5, 12,  3,  2, 14,
               7],
             [ 8,  6,  2, 12, 13, 14, 11,  9, 15,  4,  5, 10,  7,  0,  3,
               1],
             [14, 12,  5,  6, 15,  8,  3,  9,  2,  0,  1, 11, 10, 13,  7,
               4],
             [ 8,  6,  9,  0, 15,  2,  7, 12,  3, 14, 13, 11,  5, 10,  4,
               1]], dtype=int32)

In [257]:
import time
from data.sampling import BatchSampler
from test_utils import forward_loader, lr_fit_eval
from lib import cl_inner_loop, outer_loop, batched_outer_loop, apply_and_loss_fn

In [269]:
class MultiTester:
    def __init__(
        self,
        train_train_images,  # (num_classes, samples_per_class, *image_shape)
        train_train_labels,  # (num_classes, samples_per_class)
        train_val_images,
        train_val_labels,
        val_images,  # (num_classes, samples_per_class, *image_shape)
        val_labels,  # (num_classes, samples_per_class)
        transfer_spt_size=500,
    ):
        self.train_train_images = train_train_images
        self.train_train_labels = train_train_labels
        self.train_val_images = train_val_images
        self.train_val_labels = train_val_labels
        self.val_images = val_images
        self.val_labels = val_labels
        self.transfer_spt_size = transfer_spt_size

    @staticmethod
    def lr_test(spt_sampler, qry_sampler, embeddings_fn, preprocess_fn, device):
        spt_embeddings, spt_targets = forward_loader(embeddings_fn, spt_sampler, device, is_norm=True, preprocess_fn=preprocess_fn)
        qry_embeddings, qry_targets = forward_loader(embeddings_fn, qry_sampler, device, is_norm=True, preprocess_fn=preprocess_fn)
        now = time.time()
        qry_preds = lr_fit_eval(spt_embeddings, spt_targets, qry_embeddings, n_jobs=4)
        return (qry_preds == qry_targets).astype(onp.float).mean(), time.time() - now
        
    @staticmethod
    def batch_and_lr_test(rng, x_spt, y_spt, x_qry, y_qry, slow_params, slow_state, slow_apply, preprocess_fn, device, batch_size=512):
        rng, rng_spt, rng_qry = split(rng, 3)
        spt_sampler = BatchSampler(rng_spt, x_spt, y_spt, batch_size, shuffle=True, keep_last=True)
        qry_sampler = BatchSampler(rng_qry, x_qry, y_qry, batch_size, shuffle=True, keep_last=True)
        
        def extract_embeddings(inputs):
            return slow_apply(slow_params, slow_state, rng, inputs, False)[0][0]
        
        return MultiTester.lr_test(spt_sampler, qry_sampler, extract_embeddings, preprocess_fn, device)
        
    def run_sup_lr_test(self, rng, slow_params, slow_state, slow_apply, preprocess_fn, device=None, batch_size=512):
        return self.batch_and_lr_test(rng, self.train_train_images, train_train_labels, train_val_images, train_val_labels, slow_params, slow_state, slow_apply, preprocess_fn, device, batch_size)
        
        
    def run_cl_test(
        self,
        rng,
        slow_params,
        fast_params,
        slow_state,
        fast_state,
        slow_apply,
        fast_apply,
        loss_fn,
        opt,
        preprocess_fn,
        device=None,
        batch_size=1,
        way=None,
        spt_shot=10,
        qry_shot=-1,
    ):
        def forward_fn(fast_params, fast_state, inputs, targets):
            return apply_and_loss_fn(slow_params, fast_params, slow_state, fast_state, rng_outer, inputs, targets, False, slow_apply, fast_apply, loss_fn)
        
        infos = []
        results = []
        for _ in range(batch_size):
            rng, rng_spt_qry, rng_order, rng_inner, rng_outer = split(rng, 5)
            if way is None:
                way = self.val_images.shape[0]  # Use all classes
            assert way <= self.val_images.shape[0]

            classes_order = shuffle_along_axis(
                rng_order,
                jnp.arange(self.val_images.shape[0])[None, :].repeat(1, 0),
                1,
            )[:, :way]
            classes_order = classes_order.reshape(1 * way, 1)
            # classes_order = classes_order.repeat(spt_shot, 1)

            idxs = shuffle_along_axis(
                rng_spt_qry, jnp.arange(self.val_images.shape[1])[None].repeat(1 * way, 0), 1
            )
            spt_idxs = idxs[:, :spt_shot]
            if qry_shot == -1:
                qry_shot = self.val_images.shape[1] - spt_shot
                qry_idxs = idxs[:, spt_shot:]
            elif spt_shot + qry_shot <= self.val_images.shape[1]:
                qry_idxs = idxs[:, spt_shot:spt_shot + qry_shot]

            spt_order = classes_order.repeat(spt_shot, 1)
            qry_order = classes_order.repeat(qry_shot, 1)
            spt_images, spt_labels = self.val_images[spt_order, spt_idxs], self.val_labels[spt_order, spt_idxs]
            qry_images, qry_labels = self.val_images[qry_order, qry_idxs], self.val_labels[qry_order, qry_idxs]
            image_shape = self.val_images.shape[-3:]
            spt_images = spt_images.reshape(way * spt_shot, *image_shape)
            spt_labels = spt_labels.reshape(way * spt_shot)
            qry_images = qry_images.reshape(way * qry_shot, *image_shape)
            qry_labels = qry_labels.reshape(way * qry_shot)

            spt_images = preprocess_fn(jax.device_put(spt_images, device))
            spt_labels = jax.device_put(spt_labels, device)
            qry_images = preprocess_fn(jax.device_put(qry_images, device))
            qry_labels = jax.device_put(qry_labels, device)

            new_fast_params, new_slow_state, new_fast_states, new_opt_stats, info = cl_inner_loop(slow_params, fast_params, slow_state, fast_state, opt.init(fast_params), rng_inner, spt_images, spt_labels, is_training=False, slow_apply=slow_apply, fast_apply=fast_apply, loss_fn=loss_fn, opt_update_fn=opt.update)

            infos.append(info)
            results.append(forward_fn(new_fast_params, new_fast_states, qry_images, qry_labels))
        
        return jax.tree_multimap(lambda x, *xs: jnp.stack(xs), results[0], *results), jax.tree_multimap(lambda x, *xs: jnp.stack(xs), infos[0], *infos)

In [262]:
output_size = 16
print(output_size)
body, head = prepare_model(
    cfg.dataset,
    output_size,
    cfg.model.hidden_size,
    cfg.model.activation,
    # track_stats=not cfg.model.no_track_bn_stats,
)

16


In [270]:
multi_tester = MultiTester(flatten(fsl_train_images, 1), flatten(fsl_train_labels, 1), sup_val_images, sup_val_labels, val_images, val_labels)
multi_tester.batch_and_lr_test(rng,  transfer_spt_images, transfer_spt_labels, transfer_qry_images, transfer_qry_labels, slow_params, slow_state, body.apply, preprocess_fn, device)

(0.604375, 31.886035442352295)

In [None]:
outs = multi_tester.run_cl_test(
    random.PRNGKey(0),
    slow_params,
    jax.tree_map(lambda x: jnp.zeros((800, output_size)), fast_params),
    slow_state,
    fast_state,
    slow_apply=body.apply,
    fast_apply=head.apply,
    loss_fn=mean_xe_and_acc_dict,
    opt=ox.sgd(1e-6),
    preprocess_fn=preprocess_fn,
    device=device,
    batch_size=2,
    way=None,
    spt_shot=5,
    qry_shot=100,
)

In [167]:
with open("/home/samenabar/storage/code/continual_learning/meta-learning-representations-jax/experiments/Sep02-2020/MAML-std/bsz-8/checkpoints/best.ckpt", "rb") as f:
    ckpt = dill.load(f)
    slow_params = jax.tree_map(lambda x: jax.device_put(x, device), ckpt["slow_params"])
    fast_params = jax.tree_map(lambda x: jax.device_put(x, device), ckpt["fast_params"])
    slow_state = jax.tree_map(lambda x: jax.device_put(x, device), ckpt["slow_state"])
    fast_state = jax.tree_map(lambda x: jax.device_put(x, device), ckpt["fast_state"])

In [166]:
output_size = fsl_train_images.shape[0]
print(output_size)
body, head = prepare_model(
    cfg.dataset,
    output_size,
    cfg.model.hidden_size,
    cfg.model.activation,
    # track_stats=not cfg.model.no_track_bn_stats,
)

def pred_fn(slow_params, fast_params, slow_state, fast_state, rng, inputs, is_training, slow_apply, fast_apply):
    rng_slow, rng_fast = split(rng)
    slow_outputs, _ = slow_apply(slow_params, slow_state, rng_slow, inputs, is_training)
    fast_outputs, _ = fast_apply(fast_params, fast_state, rng_fast, *slow_outputs, is_training)
    return fast_outputs

def embeddings_fn(slow_params, slow_state, rng, inputs, is_training, slow_apply):
    return slow_apply(slow_params, slow_state, rng, inputs, is_training)[0][0]

no_stats_body, no_stats_head = prepare_model(
    cfg.dataset,
    output_size,
    cfg.model.hidden_size,
    cfg.model.activation,
    track_stats=False,
)
rng, rng_params = split(rng)
(untrained_slow_params, untrained_fast_params, untrained_slow_state, untrained_fast_state,) = make_params(
    rng, cfg.dataset, no_stats_body.init, no_stats_body.apply, no_stats_head.init, device,
)

test_pred_fn = jit(partial(pred_fn, is_training=False, slow_apply=body.apply, fast_apply=head.apply))
test_untrained_pred_fn = jit(partial(pred_fn, is_training=False, slow_apply=no_stats_body.apply, fast_apply=no_stats_head.apply))
test_no_stats_pred_fn = jit(partial(pred_fn, is_training=False, slow_apply=no_stats_body.apply, fast_apply=no_stats_head.apply))

test_embeddings_fn = jit(partial(embeddings_fn, is_training=False, slow_apply=body.apply))
test_untrained_embeddings_fn = jit(partial(embeddings_fn, is_training=False, slow_apply=no_stats_body.apply))
test_no_stats_embeddings_fn = jit(partial(embeddings_fn, is_training=False, slow_apply=no_stats_body.apply))

64


In [10]:
from lib import xe_and_acc, cl_inner_loop
from test_utils import SupervisedStandardTester, FSLLRTester, SupervisedCosineTester, lr_fit_eval

In [86]:
with open("/home/samenabar/storage/code/continual_learning/meta-learning-representations-jax/experiments/Sep02-2020/MAML-std/bsz-8/checkpoints/best.ckpt", "rb") as f:
    ckpt = dill.load(f)
    slow_params = jax.tree_map(lambda x: jax.device_put(x, device), ckpt["slow_params"])
    fast_params = jax.tree_map(lambda x: jax.device_put(x, device), ckpt["fast_params"])
    slow_state = jax.tree_map(lambda x: jax.device_put(x, device), ckpt["slow_state"])
    fast_state = jax.tree_map(lambda x: jax.device_put(x, device), ckpt["fast_state"])

In [8]:
class ContinualLearnerTester:
    def __init__(
        self,
        spt_images,
        spt_labels,
        qry_images,
        qry_labels,
        device,
        preprocess_fn,
        n_jobs=4,
    ):
        self.spt_images = spt_images
        self.spt_labels = spt_labels
        self.qry_images = qry_images
        self.qry_labels = qry_labels

        self.device = device
        self.preprocess_fn = preprocess_fn
        self.n_jobs = n_jobs

    def __call__(
        self,
        length,
        slow_params,
        fast_params,
        slow_state,
        fast_state,
        rng,
        slow_apply,
        fast_apply,
        loss_fn,
        opt,
    ):
        def apply_fn(slow_params, fast_params, slow_state, fast_state, inputs):
            slow_outputs, _ = slow_apply(slow_params, slow_state, rng, inputs, False)
            return fast_apply(fast_params, fast_state, rng, *slow_outputs, False)[0]

        spt_images = flatten(
            preprocess_fn(jax.device_put(self.spt_images[:length], device)), 1
        )
        spt_labels = flatten((jax.device_put(self.spt_labels[:length], device)), 1)
        qry_images = flatten(
            preprocess_fn(jax.device_put(self.qry_images[:length], device)), 1
        )
        qry_labels = flatten((jax.device_put(self.qry_labels[:length], device)), 1)

        adapted_fast_params, _, _, _, info = cl_inner_loop(
            slow_params,
            fast_params,
            slow_state,
            fast_state,
            opt.init(fast_params),
            rng,
            spt_images,
            spt_labels,
            False,
            slow_apply,
            fast_apply,
            loss_fn,
            opt.update,
        )

        spt_slow_outputs, _ = slow_apply(
            slow_params, slow_state, rng, spt_images, False
        )
        qry_slow_outputs, _ = slow_apply(
            slow_params, slow_state, rng, qry_images, False
        )

        spt_outputs = fast_apply(
            adapted_fast_params, fast_state, rng, *spt_slow_outputs, False
        )[0]
        qry_outputs = fast_apply(
            adapted_fast_params, fast_state, rng, *qry_slow_outputs, False
        )[0]

        lr_spt_preds, lr_qry_preds = lr_fit_eval(
            onp.array(flatten(spt_slow_outputs[0], (1, 3))),
            onp.array(spt_labels),
            onp.array(flatten(qry_slow_outputs[0], (1, 3))),
            n_jobs=self.n_jobs,
            predict_train=True,
        )
        lr_spt_acc = (lr_spt_preds == spt_labels).astype(jnp.float32).mean()
        lr_qry_acc = (lr_qry_preds == qry_labels).astype(jnp.float32).mean()

        spt_acc = (spt_outputs.argmax(-1) == spt_labels).astype(jnp.float32).mean()
        qry_acc = (qry_outputs.argmax(-1) == qry_labels).astype(jnp.float32).mean()

        return spt_acc, qry_acc, lr_spt_acc, lr_qry_acc

# Supervised Learning

In [21]:
supervised_std_tester = SupervisedStandardTester(
    rng,
    sup_val_images,
    sup_val_labels,
    512,
    # pred_fn_test,
    preprocess_fn,
    device,
)

In [91]:
loss, acc = supervised_std_tester(partial(test_no_stats_pred_fn, slow_params, fast_params, slow_state, fast_state, rng))
print("Trained loss acc:")
print(loss, acc)

AssertionError: 'mini_imagenet_cnn_head/linear/w' with shape (800, 5) does not match shape=[800, 64] dtype=dtype('float32')

In [23]:
loss, acc = supervised_std_tester(partial(test_untrained_pred_fn, untrained_slow_params, untrained_fast_params, untrained_slow_state, untrained_fast_state, rng))
print("Untrained loss acc:")
print(loss, acc)

Untrained loss acc:
5.271284 0.01109375


In [25]:
fsl_train_images.shape

(64, 500, 84, 84, 3)

In [27]:
supervised_cosine_tester = SupervisedCosineTester(
    rng,
    flatten(fsl_train_images, 1),
    flatten(fsl_train_labels, 1),
    sup_val_images,
    sup_val_labels,
    512,
    preprocess_fn,
    device,
)

In [93]:
supervised_cosine_tester(partial(test_no_stats_embeddings_fn, slow_params, slow_state, rng))

0.3815625

In [29]:
supervised_cosine_tester(partial(test_untrained_embeddings_fn, untrained_slow_params, untrained_slow_state, rng))

0.17953125

# Few shot learning

In [30]:
fsllr_tester = FSLLRTester(
    val_images,
    val_labels,
    25,
    1000,
    5,
    5,
    15,
    preprocess_fn,
    device,
)

In [94]:
fsllr_tester(partial(test_no_stats_embeddings_fn, slow_params, slow_state, rng), rng)

0.63268

In [32]:
fsllr_tester(partial(test_untrained_embeddings_fn, untrained_slow_params, untrained_slow_state, rng), rng)

0.40872

# Continual Learning

In [33]:
cl_tester = ContinualLearnerTester(val_images[:, :10], val_labels[:, :10], val_images[:, 10:210], val_labels[:, 10:210], device, preprocess_fn)

In [96]:
inner_opt = ox.sgd(1e-2)
cl_tester(16, slow_params, fast_params, slow_state, fast_state, rng, no_stats_body.apply, no_stats_head.apply, mean_xe_and_acc_dict, inner_opt)

AssertionError: 'mini_imagenet_cnn_head/linear/w' with shape (800, 5) does not match shape=[800, 64] dtype=dtype('float32')

In [104]:
inner_opt = ox.sgd(1e-5)
cl_tester(16, slow_params, jax.tree_map(lambda x: jnp.zeros((800, 64)), fast_params), slow_state, fast_state, rng, no_stats_body.apply, no_stats_head.apply, mean_xe_and_acc_dict, inner_opt)

(DeviceArray(0.70625, dtype=float32),
 DeviceArray(0.37031248, dtype=float32),
 DeviceArray(1., dtype=float32),
 DeviceArray(0.386875, dtype=float32))

In [54]:
inner_opt = ox.sgd(1e-3)
cl_tester(16, untrained_slow_params, untrained_fast_params, untrained_slow_state, untrained_fast_state, rng, no_stats_body.apply, no_stats_head.apply, mean_xe_and_acc_dict, inner_opt)

(DeviceArray(0.0875, dtype=float32),
 DeviceArray(0.07593749, dtype=float32),
 DeviceArray(1., dtype=float32),
 DeviceArray(0.2265625, dtype=float32))

In [63]:
inner_opt = ox.sgd(1e-7)
cl_tester(16, untrained_slow_params, jax.tree_map(lambda x: jnp.zeros(x.shape), fast_params), untrained_slow_state, untrained_fast_state, rng, no_stats_body.apply, no_stats_head.apply, mean_xe_and_acc_dict, inner_opt)

(DeviceArray(0.39375, dtype=float32),
 DeviceArray(0.1190625, dtype=float32),
 DeviceArray(1., dtype=float32),
 DeviceArray(0.2265625, dtype=float32))

# Transfer learning

In [65]:
transfer_cosine_tester = SupervisedCosineTester(
    rng,
    flatten(transfer_spt_images, 1),
    flatten(transfer_spt_labels, 1),
    flatten(transfer_qry_images, 1),
    flatten(transfer_qry_labels, 1),
    512,
    preprocess_fn,
    device,
)

In [105]:
transfer_cosine_tester(partial(test_no_stats_embeddings_fn, slow_params, slow_state, rng))

0.47625

In [67]:
transfer_cosine_tester(partial(test_untrained_embeddings_fn, untrained_slow_params, untrained_slow_state, rng))

0.319375