A clean TensorFlow implementation of the continual learning library Continuum.

shikhar-srivastava/continuum_tensorflow

Continuum-Tensorflow

Effort has been made to stick to bare-bones TF/Python calls rather than building layers of abstraction on top of them. Heavy abstractions make it hard to plug in the native tf.data Dataset API without first studying the library's internal design; this implementation avoids that.
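Assuming each task comes back as plain arrays (which is how the usage example treats them), any native tf.data transformation can be chained on directly. A minimal sketch, assuming TensorFlow 2.4+ (for `tf.data.AUTOTUNE`); the helper `make_task_pipeline` and the random arrays are hypothetical stand-ins for one task's data, not part of this library.

```python
import numpy as np
import tensorflow as tf


def make_task_pipeline(data, labels, batch_size=8, shuffle_buffer=1000, seed=0):
    """Hypothetical helper: wrap one task's arrays in a native tf.data pipeline."""
    ds = tf.data.Dataset.from_tensor_slices((data, labels))
    # Standard tf.data ops compose directly; no library-specific wrapper is needed.
    ds = ds.shuffle(shuffle_buffer, seed=seed)
    ds = ds.batch(batch_size)
    return ds.prefetch(tf.data.AUTOTUNE)


# Synthetic stand-in for one task: 20 grayscale 28x28 images with labels in {0, 1}.
data = np.random.rand(20, 28, 28).astype(np.float32)
labels = np.random.randint(0, 2, size=20).astype(np.int64)

pipeline = make_task_pipeline(data, labels, batch_size=8)
batch_shapes = [x.shape[0] for x, _ in pipeline]
print(batch_shapes)  # [8, 8, 4]: the last batch holds the remaining 4 samples
```

Because the pipeline is an ordinary `tf.data.Dataset`, it can be passed straight to Keras `fit` or iterated in a custom training loop.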

Clone the repo:

git clone https://github.com/aishikhar/continuum_tensorflow.git

Example:

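from continuum_tensorflow.data import continual_dataset
import tensorflow as tf

n_tasks = 5      # number of tasks to split the dataset into
batch_size = 8

train, test = continual_dataset(dataset='splitmnist', n_tasks=n_tasks)

for task_no in range(n_tasks):
    # train[task_no] yields the task label plus that task's data and labels.
    task_label, data, labels = train[task_no]
    # Wrap the task's raw arrays in a native tf.data pipeline.
    this_task = tf.data.Dataset.from_tensor_slices((data, labels)).batch(batch_size)

    # Do your stuff: learn_on is a placeholder for your own training routine.
    learn_on(this_task)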
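The `splitmnist` option points to a class-split setting, in which each task covers a disjoint group of classes. The sketch below shows one common way to build such a split with NumPy; it is not this library's implementation. The helper `split_by_classes` is hypothetical, and both the grouping of consecutive class ids ((0, 1), (2, 3), ... for MNIST) and keeping the original labels unchanged are assumptions. It returns `(task_label, data, labels)` tuples shaped like the `train[task_no]` unpacking above, and the same function would divide the 100 CIFAR-100 classes into 5 tasks of 20.

```python
import numpy as np


def split_by_classes(x, y, n_tasks):
    """Hypothetical Split MNIST / Split CIFAR-100 style sketch.

    Sorted class ids are divided into n_tasks equal, disjoint groups; task i keeps
    only the samples whose label falls in group i. Original labels are kept as-is
    (an assumption). Returns a list of (task_label, data, labels) tuples.
    """
    classes = np.unique(y)  # sorted, unique class ids
    if len(classes) % n_tasks != 0:
        raise ValueError("number of classes must be divisible by n_tasks")
    groups = np.split(classes, n_tasks)
    tasks = []
    for task_label, group in enumerate(groups):
        mask = np.isin(y, group)  # select samples belonging to this task's classes
        tasks.append((task_label, x[mask], y[mask]))
    return tasks


# Synthetic MNIST-shaped data: 10 classes x 10 samples -> 5 tasks of 2 classes each.
rng = np.random.default_rng(0)
x = rng.random((100, 28, 28), dtype=np.float32)
y = np.repeat(np.arange(10), 10)
tasks = split_by_classes(x, y, n_tasks=5)
for task_label, data, labels in tasks:
    print(task_label, sorted(set(labels.tolist())), data.shape)
```

Whether a real split remaps labels to 0..k-1 within each task depends on whether the setting is task-incremental or class-incremental; this sketch leaves that choice out.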

Tasks added:

CIFAR-100: [figure: Task 1 to Task 5]

Split MNIST: [figure: Task 1 to Task 5]

Permuted MNIST: [figure: Task 1 to Task 5]

(Images: Continuum)
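Permuted MNIST is conventionally built by keeping every image and label but reordering each image's pixels with one fixed random permutation per task, so the input distribution changes while the labelling stays the same. A minimal NumPy sketch, assuming that convention; `permuted_tasks` is a hypothetical helper rather than this library's API, and leaving the first task unpermuted is an assumption (some setups permute every task).

```python
import numpy as np


def permuted_tasks(x, y, n_tasks, seed=0):
    """Hypothetical Permuted MNIST style sketch.

    Each task reuses all images and labels, with pixels reordered by a
    task-specific permutation drawn once per task. Task 0 uses the identity
    ordering (an assumption). Returns a list of (task_label, data, labels).
    """
    rng = np.random.default_rng(seed)
    flat = x.reshape(len(x), -1)  # (n_samples, n_pixels)
    n_pixels = flat.shape[1]
    tasks = []
    for task_label in range(n_tasks):
        perm = np.arange(n_pixels) if task_label == 0 else rng.permutation(n_pixels)
        # The same permutation is applied to every image in this task.
        permuted = flat[:, perm].reshape(x.shape)
        tasks.append((task_label, permuted, y))
    return tasks


# Synthetic MNIST-shaped data: 4 images of 28x28 pixels.
rng = np.random.default_rng(1)
x = rng.random((4, 28, 28))
y = np.arange(4)
tasks = permuted_tasks(x, y, n_tasks=5)
```

Fixing one permutation per task (rather than per image) keeps each task learnable, since a model can in principle undo a consistent pixel reordering.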