Avalanche Reference: https://avalanche.continualai.org/from-zero-to-hero-tutorial/07_putting-all-together

Dataset Preparation:https://colab.research.google.com/github/sicara/easy-few-shot-learning/blob/master/notebooks/my_first_few_shot_classifier.ipynb#scrollTo=OrUCQ7AslpFO

In [1]:
!pip install git+https://github.com/ContinualAI/avalanche.git

Collecting git+https://github.com/ContinualAI/avalanche.git
  Cloning https://github.com/ContinualAI/avalanche.git to /tmp/pip-req-build-jfes_d9f
  Running command git clone -q https://github.com/ContinualAI/avalanche.git /tmp/pip-req-build-jfes_d9f
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting pytorchcv
  Downloading pytorchcv-0.0.67-py2.py3-none-any.whl (532 kB)
[K     |████████████████████████████████| 532 kB 5.1 MB/s 
Collecting quadprog
  Downloading quadprog-0.1.10.tar.gz (121 kB)
[K     |████████████████████████████████| 121 kB 35.0 MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting gputil
  Downloading GPUtil-1.4.0.tar.gz (5.5 kB)
Collecting wandb
  Downloading wandb-0.12.6-py2.py3-none-any.whl (1.7 MB)
[K     |███████

# Imports

In [48]:
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
import torchvision
from torchvision import transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from avalanche.evaluation.metrics import forgetting_metrics, \
accuracy_metrics, loss_metrics, timing_metrics, cpu_usage_metrics, \
confusion_matrix_metrics, disk_usage_metrics
# from torchvision.datasets import Omniglot
from avalanche.models import SimpleMLP
from avalanche.training.strategies import Naive
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger
from avalanche.training.plugins import EvaluationPlugin
import torchvision
from avalanche.benchmarks.datasets import Omniglot
from avalanche.benchmarks.classic import SplitOmniglot
from avalanche.models import SimpleMLP
from avalanche.training.strategies import Naive
from avalanche.benchmarks.utils import ImageFolder, DatasetFolder, FilelistDataset, AvalancheDataset
from avalanche.benchmarks.scenarios.new_classes.nc_scenario import NCScenario

In [60]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Benchmarking

In [52]:
height = 28

scenario = SplitOmniglot(n_experiences=4, 
                          seed=1, 
                          train_transform=transforms.Compose([
                              transforms.Resize(height),
                              transforms.Grayscale(),
                              transforms.ToTensor(),
                              transforms.Normalize((0.9221,), (0.2681,))
                          ]), 
                          eval_transform=transforms.Compose([
                              transforms.Resize(height),
                              transforms.Grayscale(),
                              transforms.ToTensor(),
                              transforms.Normalize((0.9221,), (0.2681,))
                          ]))

train_stream = scenario.train_stream

for experience in train_stream:
    t = experience.task_label
    exp_id = experience.current_experience
    training_dataset = experience.dataset
    print('Task {} batch {} -> train'.format(t, exp_id))
    print('This batch contains', len(training_dataset), 'patterns')

Task 0 batch 0 -> train
This batch contains 1928 patterns
Task 0 batch 1 -> train
This batch contains 1928 patterns
Task 0 batch 2 -> train
This batch contains 1928 patterns
Task 0 batch 3 -> train
This batch contains 1928 patterns
Task 0 batch 4 -> train
This batch contains 1928 patterns
Task 0 batch 5 -> train
This batch contains 1928 patterns
Task 0 batch 6 -> train
This batch contains 1928 patterns
Task 0 batch 7 -> train
This batch contains 1928 patterns
Task 0 batch 8 -> train
This batch contains 1928 patterns
Task 0 batch 9 -> train
This batch contains 1928 patterns


# Evaluation and Logging

In [53]:
# log to Tensorboard
tb_logger = TensorboardLogger()
# log to text file
text_logger = TextLogger(open('log.txt', 'a'))
# print to stdout
interactive_logger = InteractiveLogger()
eval_plugin = EvaluationPlugin(
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    timing_metrics(epoch=True, epoch_running=True),
    forgetting_metrics(experience=True, stream=True),
    cpu_usage_metrics(experience=True),
    confusion_matrix_metrics(num_classes=scenario.n_classes, save_image=False,
                             stream=True),
    disk_usage_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loggers=[interactive_logger, text_logger, tb_logger]
)

  "No benchmark provided to the evaluation plugin. "


#Model

In [58]:
model = SimpleMLP(num_classes=n_classes, input_size=28*28)

# Strategy Instance

In [None]:
optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)

cl_strategy = Naive(
    model, optimizer, CrossEntropyLoss(), train_mb_size=500, train_epochs=4, eval_mb_size=100,
    evaluator=eval_plugin, eval_mb_size=batch_size, device=device)

# Training Loop

In [65]:
print('Starting experiment...')
results = []
for experience in scenario.train_stream:
    print("Start of experience: ", experience.current_experience)
    print("Current Classes: ", experience.classes_in_this_experience)

    cl_strategy.train(experience)
    print('Training completed')

    print('Computing accuracy on the whole test set')
    results.append(cl_strategy.eval(scenario.test_stream))

Files already downloaded and verified
Files already downloaded and verified
Starting experiment...
Start of experience:  0
Current Classes:  [2, 8, 520, 10, 529, 531, 534, 28, 541, 542, 33, 35, 548, 550, 40, 554, 556, 47, 559, 561, 564, 567, 568, 58, 571, 573, 62, 574, 64, 580, 70, 72, 74, 589, 78, 590, 592, 593, 594, 595, 598, 87, 600, 90, 93, 607, 97, 100, 102, 614, 617, 107, 623, 626, 115, 628, 122, 123, 636, 126, 638, 129, 647, 139, 651, 653, 143, 144, 145, 146, 655, 148, 656, 657, 659, 152, 664, 668, 157, 161, 163, 676, 166, 679, 168, 681, 171, 685, 174, 686, 693, 696, 698, 187, 699, 191, 705, 706, 195, 196, 707, 708, 199, 200, 710, 202, 716, 721, 211, 212, 213, 728, 729, 218, 733, 223, 738, 739, 228, 229, 233, 241, 756, 761, 762, 763, 252, 253, 764, 255, 766, 770, 774, 775, 264, 778, 267, 269, 271, 784, 786, 789, 791, 283, 284, 796, 287, 289, 801, 804, 805, 295, 809, 812, 813, 308, 309, 310, 821, 312, 827, 829, 318, 833, 322, 835, 325, 837, 327, 840, 841, 330, 331, 847, 337, 849,