# Continual Learning on Split CIFAR-10

## Prepare the Dataset

In [1]:
from pathlib import Path
from avalanche.benchmarks.classic import SplitCIFAR10


example_dir_path = Path().resolve()
data_dir_path = Path.joinpath(example_dir_path, "data")

split_cifar10 = SplitCIFAR10(
    n_experiences=5,
    dataset_root=Path.joinpath(data_dir_path, "cifar10"),
    shuffle=False,
    return_task_id=True,
    class_ids_from_zero_in_each_exp=True,
)

Files already downloaded and verified
Files already downloaded and verified


## Define the Lightning Module

In [2]:
import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn

# Must import `hat.networks` to register the models
# noinspection PyUnresolvedReferences
import hat.timm_models
from hat import HATConfig, HATPayload
from hat.utils import get_hat_reg_term, get_hat_mask_scale


class ContinualClassifier(pl.LightningModule):
    def __init__(
        self,
        init_strat,
        scaling_strat,
        num_classes_per_exp,
        max_mask_scale=100.0,
    ):
        super().__init__()
        self.init_strat = init_strat
        self.scaling_strat = scaling_strat
        self.num_classes_per_exp = num_classes_per_exp
        self.max_mask_scale = max_mask_scale
        _hat_config = HATConfig(
            init_strat=self.init_strat,
            num_tasks=len(num_classes_per_exp),
        )
        self.backbone = timm.create_model(
            "hat_resnet18s",
            num_classes=0,
            hat_config=_hat_config,
        )
        self.heads = nn.ModuleList(
            [nn.Linear(512, __c) for __c in num_classes_per_exp]
        )
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, images, task_id, mask_scale=None):
        _pld = HATPayload(images, task_id=task_id, mask_scale=mask_scale)
        return self.heads[_pld.task_id](self.backbone(_pld).data)

    def training_step(self, batch, batch_idx):
        _images, _targets, _task_id = batch
        _progress = batch_idx / (self.trainer.num_training_batches - 1)
        _mask_scale = get_hat_mask_scale(
            strat=self.scaling_strat,
            max_trn_mask_scale=self.max_mask_scale,
            progress=_progress,
        )
        _logits = self.forward(_images, _task_id, _mask_scale)
        _loss = self.criterion(_logits, _targets)
        _reg = get_hat_reg_term(
            module=self.backbone,
            strat="uniform",
            task_id=_task_id,
            mask_scale=_mask_scale,
            # forgive_quota=False,
        )
        return _loss + _reg

    def test_step(self, batch, batch_idx, dataloader_idx):
        _images, _targets, _task_id = batch
        # Class-incremental learning
        # Iterate through all the tasks and compute the logits
        _logits = []
        for __task_id in range(len(self.heads)):
            _logits.append(
                self.forward(_images, __task_id, self.max_mask_scale)
            )
        # Class-incremental testing
        _cil_logits = torch.cat(_logits, dim=1)
        _cil_targets = _targets + sum(self.num_classes_per_exp[:_task_id])
        _cil_acc = _cil_logits.argmax(dim=1) == _cil_targets
        # Task-incremental testing
        _til_logits = _logits[_task_id]
        _til_acc = _til_logits.argmax(dim=1) == _targets
        self.log_dict(
            {
                "cil_acc": _cil_acc.float().mean(),
                "til_acc": _til_acc.float().mean(),
            },
            batch_size=_images.shape[0],
        )

    def configure_optimizers(self):
        # Bigger learning rate or more epochs may be needed if
        # the model is using dense initialization for HAT maskers.
        # return torch.optim.SGD(self.parameters(), lr=1e-3)
        return torch.optim.Adam(self.parameters(), lr=1e-2)

## Train the Model for Each Task

In [3]:
import logging
import warnings

from torch.utils.data import DataLoader


logger = logging.getLogger("pytorch_lightning")
logger.propagate = False
logger.setLevel(logging.ERROR)

logger = logging.getLogger("lightning_fabric")
logger.propagate = False
logger.setLevel(logging.ERROR)

warnings.filterwarnings("ignore", category=UserWarning)


def collate_fn(batch):
    images, targets, task_ids = zip(*batch)
    return (torch.stack(images), torch.tensor(targets), int(task_ids[0]))


clf = ContinualClassifier(
    init_strat="dense",
    scaling_strat="cosine",
    num_classes_per_exp=split_cifar10.n_classes_per_exp,
)
accelerator = "cuda" if torch.cuda.is_available() else "cpu"
strategy = (
    "ddp_notebook_find_unused_parameters_true"
    if accelerator == "cuda"
    else None
)


for __task_id, __trn_exp in enumerate(split_cifar10.train_stream):
    print(f"Training on task/experience {__task_id}")
    _trainer = pl.Trainer(
        max_epochs=5,
        accelerator=accelerator,
        # strategy=strategy,
        devices=1,
    )
    _dataloader = DataLoader(
        __trn_exp.dataset,
        batch_size=128,
        shuffle=True,
        num_workers=8 if accelerator == "cuda" else 0,
        pin_memory=True if accelerator == "cuda" else False,
        collate_fn=collate_fn,
    )
    _trainer.fit(clf, _dataloader)

Training on task/experience 0


Training: 0it [00:00, ?it/s]

Training on task/experience 1


Training: 0it [00:00, ?it/s]

Training on task/experience 2


Training: 0it [00:00, ?it/s]

Training on task/experience 3


Training: 0it [00:00, ?it/s]

Training on task/experience 4


Training: 0it [00:00, ?it/s]

## Test the Model

In [4]:
clf.freeze()
for __m in clf.modules():
    if isinstance(__m, nn.BatchNorm2d):
        __m.track_running_stats = False

trainer = pl.Trainer(
    accelerator=accelerator,
    devices=1,
    enable_model_summary=False,
)
tst_dataloaders = [
    DataLoader(
        __exp.dataset,
        batch_size=128,
        num_workers=8 if accelerator == "cuda" else 0,
        pin_memory=True if accelerator == "cuda" else False,
        collate_fn=collate_fn,
    )
    for __exp in split_cifar10.test_stream
]
tst_results = trainer.test(clf, tst_dataloaders, verbose=False)

Testing: 0it [00:00, ?it/s]

In [5]:
# Reformat results for better readability
import pandas as pd

reformatted_tst_results = {}
til_acc, cil_acc = [], []
for __tst_results in tst_results:
    for __label, __acc in __tst_results.items():
        __il, __dl = __label.split("/")
        __il = __il.split("_")[0].upper()
        __dl = f"Task {__dl.split('_')[-1]}"
        if __dl not in reformatted_tst_results:
            reformatted_tst_results[__dl] = {}
        reformatted_tst_results[__dl][__il] = __acc
        if __il == "TIL":
            til_acc.append(__acc)
        else:
            cil_acc.append(__acc)
reformatted_tst_results["Avg"] = {
    "TIL": sum(til_acc) / len(til_acc),
    "CIL": sum(cil_acc) / len(cil_acc),
}
reformatted_tst_results = pd.DataFrame(reformatted_tst_results)
print(
    reformatted_tst_results.to_markdown(
        floatfmt="2.2%",
        tablefmt="fancy_outline",
        stralign="center",
        numalign="center",
    )
)

╒═════╤══════════╤══════════╤══════════╤══════════╤══════════╤════════╕
│     │  Task 0  │  Task 1  │  Task 2  │  Task 3  │  Task 4  │  Avg   │
╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╪════════╡
│ CIL │  30.80%  │  26.85%  │  24.60%  │  69.65%  │  72.70%  │ 44.92% │
│ TIL │  92.10%  │  77.05%  │  91.90%  │  96.80%  │  95.45%  │ 90.66% │
╘═════╧══════════╧══════════╧══════════╧══════════╧══════════╧════════╛
