## TOC:
* [Environment Setup](#setup)
* [Results](#results)
    * [S-FSVI](#res1)
    * [S-FSVI (larger networks)](#res2)
    * [S-FSVI (no coreset)](#res3)
    * [S-FSVI (minimal coreset)](#res4)
    * [VCL (random-choice coreset)](#res5)

# Environment Setup <a name="setup"></a>

## Run as Colab notebook

**Important: Before connecting to a kernel, select a GPU runtime. To do so, open the `Runtime` tab above, click `Change runtime type`, and select `GPU`. Run the setup cell below only after you've done this.**

In [None]:
# pull S-FSVI repository
!git clone https://github.com/timrudner/S-FSVI.git
# patch required packages
!pip install -r ./S-FSVI/colab_requirements.txt

**After successfully running the cell above, you need to restart the runtime. To do so, open the `Runtime` tab above and click `Restart runtime`. Once the runtime has been restarted, run the cell below. There is no need to re-run the installation in the cell above.**

In [None]:
# add the repo to path
import os
import sys
root = os.path.abspath(os.path.join(os.getcwd(), "S-FSVI"))
if root not in sys.path:
    sys.path.insert(0, root)

## Run as Jupyter notebook (skip ahead to [Results](#results) if you are running this as a Colab notebook)

Install the conda environment `fsvi`:

In [None]:
!conda env update -f ../environment.yml

Troubleshooting:

 - If there is an error when installing sklearn, run `pip install Cython==0.29.23` manually and then run the command above again.
 - If you have access to a GPU, see the instructions [here](https://github.com/google/jax#pip-installation-gpu-cuda) for installing the GPU version of `jaxlib`. This makes the experiments run significantly faster.
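To confirm that the GPU build of `jaxlib` is actually being used by the `fsvi` kernel, a quick check like the one below can be run in any cell. This is a minimal sketch that relies only on the public `jax.default_backend()` and `jax.devices()` functions; the helper name `jax_backend_summary` is hypothetical and not part of the S-FSVI repository.

```python
def jax_backend_summary():
    """Describe the active JAX backend, or report that JAX is missing."""
    try:
        import jax  # imported lazily so the check also works before installation
    except ImportError:
        return "jax is not installed in this kernel"
    # jax.default_backend() returns a platform name such as "cpu", "gpu" or "tpu"
    backend = jax.default_backend()
    devices = ", ".join(str(d) for d in jax.devices())
    return f"JAX backend: {backend} ({devices})"


print(jax_backend_summary())
```

If this reports `cpu` on a machine with a GPU, the CPU-only `jaxlib` is most likely still installed.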

Run the command below to register the conda environment as a Jupyter kernel. Then switch to this kernel from the Jupyter Notebook menu bar by selecting `Kernel`, `Change kernel`, and then `fsvi`.

In [None]:
!python -m ipykernel install --user --name=fsvi

Troubleshooting: for further details, see [this guide](https://medium.com/@nrk25693/how-to-add-your-conda-environment-to-your-jupyter-notebook-in-just-4-steps-abeab8b8d084).

In [None]:
import os
import sys
# assuming os.getcwd() returns the directory containing this jupyter notebook
root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if root not in sys.path:
    sys.path.insert(0, root)
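The cell above only works if the notebook is started from a directory directly below the repository root. A slightly more robust variant walks up the directory tree until it finds a marker file. This sketch assumes the repository root is the directory containing `environment.yml` (the file passed to `conda env update` above); the helpers `find_repo_root` and `add_repo_to_path` are hypothetical.

```python
import os
import sys


def find_repo_root(start, marker="environment.yml"):
    """Walk up from `start` until a directory containing `marker` is found."""
    current = os.path.abspath(start)
    while True:
        if os.path.isfile(os.path.join(current, marker)):
            return current
        parent = os.path.dirname(current)
        if parent == current:  # reached the filesystem root without a match
            raise FileNotFoundError(f"no {marker} found above {start}")
        current = parent


def add_repo_to_path(start=None):
    """Put the repository root at the front of sys.path (once) and return it."""
    root = find_repo_root(start or os.getcwd())
    if root not in sys.path:
        sys.path.insert(0, root)
    return root
```

Calling `add_repo_to_path()` from any subdirectory of the repository would then have the same effect as the cell above.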

# Results <a name="results"></a>

To load a model checkpoint instead of training the model from scratch, pass `load_chkpt=True` to the function `read_config_and_run`.
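Once `read_config_and_run` and `task_sequence` are defined (see the next cell), all five experiments of this notebook could be reloaded from checkpoints in a single loop. This is a sketch under assumptions: it assumes `read_config_and_run` accepts `load_chkpt` as a keyword argument together with the optional runner argument that the VCL cell passes positionally. The names `EXPERIMENTS` and `run_all` are hypothetical; the config file names are the ones used in the cells below.

```python
# (section title, config file, runner); "vcl" is passed as the third positional
# argument exactly as in the VCL cell, None means the argument is omitted.
EXPERIMENTS = [
    ("S-FSVI", "fsvi_match.pkl", None),
    ("S-FSVI (larger networks)", "fsvi_optimized.pkl", None),
    ("S-FSVI (no coreset)", "fsvi_no_coreset.pkl", None),
    ("S-FSVI (minimal coreset)", "fsvi_minimal_coreset.pkl", None),
    ("VCL (random-choice coreset)", "vcl_random_coreset.pkl", "vcl"),
]


def run_all(run_fn, task_sequence, load_chkpt=True):
    """Call run_fn (e.g. read_config_and_run) once per experiment; return log dirs."""
    logdirs = {}
    for name, config, runner in EXPERIMENTS:
        if runner is None:
            args = (config, task_sequence)
        else:
            args = (config, task_sequence, runner)
        logdirs[name] = run_fn(*args, load_chkpt=load_chkpt)
    return logdirs
```

Usage would then be `logdirs = run_all(read_config_and_run, task_sequence)`, followed by `lutils.read_exp` on each log directory; remember that the VCL entry also needs `runner="vcl"` when calling `show_final_average_accuracy`.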


In [1]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False
import os
import sys
root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if root not in sys.path:
    sys.path.insert(0, root)
    
    
from notebooks.nb_utils.common import read_config_and_run, show_final_average_accuracy
import sfsvi.exps.utils.load_utils as lutils

task_sequence = "pmnist_sh"



Jax is running on gpu
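The logged configurations below use `"data_training":"continual_learning_pmnist"`, and the VCL baseline is launched with `--dataset pmnist --n_permuted_tasks 10`. In the standard Permuted MNIST benchmark, every task applies its own fixed random permutation to the pixels of all MNIST images while keeping the labels. The following is a generic sketch of that construction, not the repository's data loader; the function names are hypothetical, and no particular meaning of the `_sh` suffix in `"pmnist_sh"` is assumed.

```python
import random

N_PIXELS = 28 * 28  # MNIST images flattened to 784 inputs


def make_permutations(n_tasks, seed=0):
    """Draw one fixed pixel permutation per task, reproducibly from `seed`."""
    rng = random.Random(seed)
    perms = []
    for _ in range(n_tasks):
        perm = list(range(N_PIXELS))
        rng.shuffle(perm)
        perms.append(perm)
    return perms


def permute_image(flat_image, perm):
    """Reorder the pixels of a flattened image; the label is left unchanged."""
    return [flat_image[i] for i in perm]
```

Because each permutation is fixed per task, the input distribution changes from task to task while the mapping from digit to label stays the same.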


## S-FSVI (ours) <a name="res1"></a>

In [2]:
logdir = read_config_and_run("fsvi_match.pkl", task_sequence)
exp = lutils.read_exp(logdir)
show_final_average_accuracy(exp)


Loading from cache:
Running on clpc158.cs.ox.ac.uk
Jax is running on gpu


Input arguments:
 {
    "command":"cl",
    "data_training":"continual_learning_pmnist",
    "data_ood":[
        "not_specified"
    ],
    "model_type":"fsvi_mlp",
    "optimizer":"adam",
    "optimizer_var":"not_specified",
    "momentum":0.0,
    "momentum_var":0.0,
    "schedule":"not_specified",
    "architecture":[
        100,
        100
    ],
    "activation":"relu",
    "prior_mean":"0.0",
    "prior_cov":"0.001",
    "prior_covs":[
        0.0
    ],
    "prior_type":"bnn_induced",
    "epochs":10,
    "start_var_opt":0,
    "batch_size":128,
    "learning_rate":0.0005,
    "learning_rate_var":0.001,
    "dropout_rate":0.0,
    "regularization":0.0,
    "inducing_points":0,
    "n_marginals":1,
    "n_condition":128,
    "inducing_input_type":"uniform_rand",
    "inducing_input_ood_data":[
        "not_specified"
    ],
    "inducing_input_ood_data_size":50000,
    "kl_scale":"equal",
    "feature_m




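`show_final_average_accuracy` is imported from the repository and not shown here, so the following is only a sketch of the standard continual-learning metric its name refers to: the test accuracy on every task, measured after training on the final task, averaged over tasks. The nested-list layout `acc[i][j]` is an assumption for illustration and not the format in which `exp` stores results.

```python
def final_average_accuracy(acc):
    """acc[i][j]: test accuracy on task j after training on task i (j <= i).

    The final average accuracy is the mean of the last row, i.e. the accuracy
    on all tasks seen so far, measured once the last task has been learned.
    """
    last_row = acc[-1]
    return sum(last_row) / len(last_row)
```

For example, `final_average_accuracy([[0.98], [0.97, 0.98], [0.96, 0.97, 0.98]])` averages `0.96`, `0.97` and `0.98`, giving approximately `0.97`.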
## S-FSVI (larger networks) <a name="res2"></a>

In [3]:
logdir = read_config_and_run("fsvi_optimized.pkl", task_sequence)
exp = lutils.read_exp(logdir)
show_final_average_accuracy(exp)


Loading from cache:
Running on oat1.cs.ox.ac.uk
Jax is running on gpu


Input arguments:
 {
    "command":"cl",
    "data_training":"continual_learning_pmnist",
    "data_ood":[
        "not_specified"
    ],
    "model_type":"fsvi_mlp",
    "optimizer":"adam",
    "optimizer_var":"not_specified",
    "momentum":0.0,
    "momentum_var":0.0,
    "schedule":"not_specified",
    "architecture":[
        500,
        500
    ],
    "activation":"relu",
    "prior_mean":"0.0",
    "prior_cov":"0.001",
    "prior_covs":[
        0.0
    ],
    "prior_type":"bnn_induced",
    "epochs":20,
    "start_var_opt":0,
    "batch_size":128,
    "learning_rate":0.0001,
    "learning_rate_var":0.001,
    "dropout_rate":0.0,
    "regularization":0.0,
    "inducing_points":0,
    "n_marginals":1,
    "n_condition":128,
    "inducing_input_type":"uniform_rand",
    "inducing_input_ood_data":[
        "not_specified"
    ],
    "inducing_input_ood_data_size":50000,
    "kl_scale":"equal",
    "feature_map_




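The "larger networks" run changes `"architecture"` from `[100, 100]` to `[500, 500]` (the minimal-coreset run below uses `[300, 300]`). The arithmetic below shows how much that grows a fully connected network. It assumes 784 inputs and a single 10-unit output layer, and it counts each weight and bias once, so it does not include the additional variational parameters the models may store.

```python
def mlp_param_count(n_in, hidden, n_out):
    """Number of weights and biases in a fully connected network."""
    sizes = [n_in, *hidden, n_out]
    # each layer a -> b has a * b weights and b biases
    return sum(a * b + b for a, b in zip(sizes[:-1], sizes[1:]))


for hidden in ([100, 100], [300, 300], [500, 500]):
    print(hidden, mlp_param_count(784, hidden, 10))
# [100, 100] 89610
# [300, 300] 328810
# [500, 500] 648010
```

Under these assumptions the `[500, 500]` network has roughly seven times as many parameters as the `[100, 100]` one.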
## S-FSVI (no coreset) <a name="res3"></a>

In [4]:
logdir = read_config_and_run("fsvi_no_coreset.pkl", task_sequence)
exp = lutils.read_exp(logdir)
show_final_average_accuracy(exp)


Loading from cache:
Running on clpc158.cs.ox.ac.uk
Jax is running on gpu


Input arguments:
 {
    "command":"cl",
    "data_training":"continual_learning_pmnist",
    "data_ood":[
        "not_specified"
    ],
    "model_type":"fsvi_mlp",
    "optimizer":"adam",
    "optimizer_var":"not_specified",
    "momentum":0.0,
    "momentum_var":0.0,
    "schedule":"not_specified",
    "architecture":[
        100,
        100
    ],
    "activation":"relu",
    "prior_mean":"0.0",
    "prior_cov":"0.001",
    "prior_covs":[
        0.0
    ],
    "prior_type":"bnn_induced",
    "epochs":5,
    "start_var_opt":0,
    "batch_size":128,
    "learning_rate":0.0001,
    "learning_rate_var":0.001,
    "dropout_rate":0.0,
    "regularization":0.0,
    "inducing_points":0,
    "n_marginals":1,
    "n_condition":128,
    "inducing_input_type":"uniform_rand",
    "inducing_input_ood_data":[
        "not_specified"
    ],
    "inducing_input_ood_data_size":50000,
    "kl_scale":"equal",
    "feature_ma




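The S-FSVI runs differ in `"epochs"` (5 for the no-coreset run, 10 for S-FSVI, 20 for the larger networks) at `"batch_size":128`. To compare training budgets, the epochs can be converted into optimizer steps per task. This sketch assumes 60,000 training images per task (the size of the MNIST training set) and that the last partial batch is kept; any coreset points added to the training data are ignored.

```python
import math


def gradient_steps_per_task(n_train, batch_size, epochs):
    """Optimizer steps per task when the final partial batch is kept."""
    return math.ceil(n_train / batch_size) * epochs


for name, epochs in [("no coreset", 5), ("S-FSVI", 10), ("larger networks", 20)]:
    print(name, gradient_steps_per_task(60_000, 128, epochs))
# no coreset 2345
# S-FSVI 4690
# larger networks 9380
```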
## S-FSVI (minimal coreset) <a name="res4"></a>

In [5]:
logdir = read_config_and_run("fsvi_minimal_coreset.pkl", task_sequence)
exp = lutils.read_exp(logdir)
show_final_average_accuracy(exp)


Loading from cache:
Running on oat2.cs.ox.ac.uk
Jax is running on gpu


Input arguments:
 {
    "command":"cl",
    "data_training":"continual_learning_pmnist",
    "data_ood":[
        "not_specified"
    ],
    "model_type":"fsvi_mlp",
    "optimizer":"adam",
    "optimizer_var":"not_specified",
    "momentum":0.0,
    "momentum_var":0.0,
    "schedule":"not_specified",
    "architecture":[
        300,
        300
    ],
    "activation":"relu",
    "prior_mean":"0.0",
    "prior_cov":"10.0",
    "prior_covs":[
        0.0
    ],
    "prior_type":"bnn_induced",
    "epochs":10,
    "start_var_opt":0,
    "batch_size":128,
    "learning_rate":0.0005,
    "learning_rate_var":0.001,
    "dropout_rate":0.0,
    "regularization":0.0,
    "inducing_points":0,
    "n_marginals":1,
    "n_condition":128,
    "inducing_input_type":"train_pixel_rand_0.0",
    "inducing_input_ood_data":[
        "not_specified"
    ],
    "inducing_input_ood_data_size":50000,
    "kl_scale":"equal",
    "featu




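Because the printed "Input arguments" are long and truncated, it can help to diff two configurations programmatically. The sketch below compares a few fields copied by hand from the S-FSVI and minimal-coreset outputs above; the full dictionaries would have to be loaded from the experiment logs, and how to do that is not assumed here. Note that fields such as `prior_cov` are logged as strings. The helper `config_diff` is hypothetical.

```python
def config_diff(base, other):
    """Return {key: (base_value, other_value)} for every key whose values differ."""
    keys = sorted(set(base) | set(other))
    return {k: (base.get(k), other.get(k)) for k in keys if base.get(k) != other.get(k)}


# A few fields copied from the printed "Input arguments" of two runs above
s_fsvi = {"architecture": [100, 100], "prior_cov": "0.001", "epochs": 10,
          "learning_rate": 0.0005, "inducing_input_type": "uniform_rand"}
minimal_coreset = {"architecture": [300, 300], "prior_cov": "10.0", "epochs": 10,
                   "learning_rate": 0.0005, "inducing_input_type": "train_pixel_rand_0.0"}

for key, (a, b) in config_diff(s_fsvi, minimal_coreset).items():
    print(f"{key}: {a} -> {b}")
# architecture: [100, 100] -> [300, 300]
# inducing_input_type: uniform_rand -> train_pixel_rand_0.0
# prior_cov: 0.001 -> 10.0
```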
## VCL (random-choice coreset) <a name="res5"></a>

In [6]:
logdir = read_config_and_run("vcl_random_coreset.pkl", task_sequence, "vcl")
exp = lutils.read_exp(logdir)
show_final_average_accuracy(exp, runner="vcl")


Loading from cache:
Running on oat18.cs.ox.ac.uk
Running with: python /auto/users/timner/qixuan/function-space-variational-inference/fsvi_cl/baselines/vcl/run_vcl.py --dataset pmnist --n_epochs 100 --batch_size 256 --hidden_size 100 --n_layers 2 --seed 7 --select_method random_choice --n_permuted_tasks 10 --logroot ablation --subdir reproduce_main_results_3 --n_coreset_inputs_per_task 200
----------------------------------------------------------------------------------------------------
('Epoch:', '0001', 'cost=', '0.494620047')
('Epoch:', '0006', 'cost=', '0.068687391')
('Epoch:', '0011', 'cost=', '0.032833647')
('Epoch:', '0016', 'cost=', '0.015344139')
('Epoch:', '0021', 'cost=', '0.006667245')
('Epoch:', '0026', 'cost=', '0.004755347')
('Epoch:', '0031', 'cost=', '0.001242538')
('Epoch:', '0036', 'cost=', '0.000305422')
('Epoch:', '0041', 'cost=', '0.000169750')
('Epoch:', '0046', 'cost=', '0.001465945')
('Epoch:', '0051', 'cost=', '0.000185845')
('Epoch:', '0056', 'cost=', '0.000



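The VCL baseline is launched with `--select_method random_choice --n_coreset_inputs_per_task 200`. Random-choice coreset selection draws that many examples from each task uniformly at random and without replacement; in VCL the selected points are held out of the regular variational update and revisited before prediction. The following is a minimal sketch of the selection step only, not the baseline's own code; `random_choice_coreset` is a hypothetical name.

```python
import random


def random_choice_coreset(n_task_examples, coreset_size, seed=None):
    """Pick `coreset_size` distinct example indices uniformly at random.

    Returns (coreset_indices, remaining_indices) so the two sets can be used
    separately, e.g. remaining points for the main update, coreset kept aside.
    """
    rng = random.Random(seed)
    coreset = rng.sample(range(n_task_examples), coreset_size)
    chosen = set(coreset)
    remaining = [i for i in range(n_task_examples) if i not in chosen]
    return sorted(coreset), remaining
```

With 60,000 examples per task and 200 coreset inputs, 59,800 examples would remain for the regular update.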