# Continual Learning on Split CIFAR-10

## Prepare the Dataset

In [1]:
from pathlib import Path

from avalanche.benchmarks.classic import SplitCIFAR10


example_dir_path = Path().resolve()
data_dir_path = Path.joinpath(example_dir_path, "data")

split_cifar10 = SplitCIFAR10(
    n_experiences=5,
    dataset_root=Path.joinpath(data_dir_path, "cifar10"),
    shuffle=False,
    return_task_id=True,
    class_ids_from_zero_in_each_exp=True,
)

Files already downloaded and verified
Files already downloaded and verified


## Define the Lightning Module

In [2]:
import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn

# Must import `hat.networks` to register the models
# noinspection PyUnresolvedReferences
import hat.networks
from hat import HATConfig, HATPayload
from hat.utils import get_hat_reg_term


class ContinualClassifier(pl.LightningModule):
    def __init__(self, num_classes_per_exp, max_mask_scale=100.0):
        super().__init__()
        self.num_classes_per_exp = num_classes_per_exp
        self.max_mask_scale = max_mask_scale
        _hat_config = HATConfig(
            num_tasks=len(num_classes_per_exp),
        )
        self.backbone = timm.create_model(
            "hat_resnet18s",
            num_classes=0,
            hat_config=_hat_config,
        )
        self.heads = nn.ModuleList(
            [nn.Linear(512, __c) for __c in num_classes_per_exp]
        )
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, images, task_id, mask_scale=None):
        pld = HATPayload(images, task_id=task_id, mask_scale=mask_scale)
        return self.heads[pld.task_id](self.backbone(pld).data)

    def training_step(self, batch, batch_idx):
        images, targets, task_id = batch
        # Progress is the percentage of the epoch completed
        # We use sliding mask scale to gradually increase the masking
        _progress = (batch_idx + 1) / self.trainer.num_training_batches
        _mask_scale = _progress * self.max_mask_scale
        logits = self.forward(images, task_id, _mask_scale)
        loss = self.criterion(logits, targets)
        reg = get_hat_reg_term(
            module=self.backbone,
            reg_strategy="uniform",
            task_id=task_id,
            mask_scale=_mask_scale,
        )
        return loss + reg

    def test_step(self, batch, batch_idx, dataloader_idx):
        images, targets, task_id = batch
        # Class-incremental learning
        # Iterate through all the tasks and compute the logits
        logits = []
        for __task_id in range(len(self.heads)):
            logits.append(self.forward(images, __task_id, self.max_mask_scale))
        # Class-incremental testing
        cil_logits = torch.cat(logits, dim=1)
        cil_targets = targets + sum(self.num_classes_per_exp[:task_id])
        cil_acc = cil_logits.argmax(dim=1) == cil_targets
        # Task-incremental testing
        til_logits = logits[task_id]
        til_acc = til_logits.argmax(dim=1) == targets
        self.log_dict(
            {
                "cil_acc": cil_acc.float().mean(),
                "til_acc": til_acc.float().mean(),
            },
            batch_size=images.shape[0],
        )

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

## Train the Model for Each Task

In [3]:
import logging
import warnings

from torch.utils.data import DataLoader


log = logging.getLogger("pytorch_lightning")
log.propagate = False
log.setLevel(logging.ERROR)

log = logging.getLogger("lightning_fabric")
log.propagate = False
log.setLevel(logging.ERROR)

warnings.filterwarnings("ignore", category=UserWarning)

clf = ContinualClassifier(split_cifar10.n_classes_per_exp)
device = "cuda"
strategy = "ddp_notebook_find_unused_parameters_true"


def collate_fn(batch):
    images, targets, task_ids = zip(*batch)
    return (torch.stack(images), torch.tensor(targets), int(task_ids[0]))


for __task_id, __trn_exp in enumerate(split_cifar10.train_stream):
    print(f"Training on task/experience {__task_id}")
    trainer = pl.Trainer(
        max_epochs=5,
        accelerator=device,
        strategy=strategy,
    )
    dataloader = DataLoader(
        __trn_exp.dataset,
        batch_size=128,
        shuffle=True,
        num_workers=8,
        pin_memory=True if device == "cuda" else False,
        collate_fn=collate_fn,
    )
    trainer.fit(clf, dataloader)

Training on task/experience 0
Training on task/experience 1
Training on task/experience 2
Training on task/experience 3
Training on task/experience 4


## Test the Model

In [4]:
clf.freeze()
for __m in clf.modules():
    if isinstance(__m, nn.BatchNorm2d):
        __m.track_running_stats = False

trainer = pl.Trainer(
    accelerator=device,
    devices=1,
    enable_model_summary=False,
)
tst_dataloaders = [
    DataLoader(
        __exp.dataset,
        batch_size=128,
        num_workers=8,
        pin_memory=True if device == "cuda" else False,
        collate_fn=collate_fn,
    )
    for __exp in split_cifar10.test_stream
]
tst_results = trainer.test(clf, tst_dataloaders, verbose=False)

Testing: 0it [00:00, ?it/s]

In [5]:
# Reformat results for better readability
import pandas as pd

reformatted_tst_results = {}
til_acc, cil_acc = [], []
for __tst_results in tst_results:
    for __label, __acc in __tst_results.items():
        __il, __dl = __label.split("/")
        __il = __il.split("_")[0].upper()
        __dl = f"Task {__dl.split('_')[-1]}"
        if __dl not in reformatted_tst_results:
            reformatted_tst_results[__dl] = {}
        reformatted_tst_results[__dl][__il] = __acc
        if __il == "TIL":
            til_acc.append(__acc)
        else:
            cil_acc.append(__acc)
reformatted_tst_results["Avg"] = {
    "TIL": sum(til_acc) / len(til_acc),
    "CIL": sum(cil_acc) / len(cil_acc),
}
reformatted_tst_results = pd.DataFrame(reformatted_tst_results)
print(
    reformatted_tst_results.to_markdown(
        floatfmt="2.2%",
        tablefmt="fancy_outline",
        stralign="center",
        numalign="center",
    )
)

╒═════╤══════════╤══════════╤══════════╤══════════╤══════════╤════════╕
│     │  Task 0  │  Task 1  │  Task 2  │  Task 3  │  Task 4  │  Avg   │
╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╪════════╡
│ CIL │  56.20%  │  3.65%   │  22.15%  │  54.85%  │  26.20%  │ 32.61% │
│ TIL │  91.80%  │  80.05%  │  79.95%  │  80.75%  │  89.10%  │ 84.33% │
╘═════╧══════════╧══════════╧══════════╧══════════╧══════════╧════════╛
