In [None]:
import typing
import torch
from torch import nn, Tensor
from dataclasses import dataclass
from conf import *
import torch.nn.functional as F

from experiment.experiment import Experiment, BaseHyperParameters
%load_ext autoreload
%autoreload 2

%matplotlib auto


In [None]:
from avalanche.benchmarks.classic.cfashion_mnist import SplitFMNIST
from network.ae_group import D_AE, CNN_Encoder, CNN_Decoder, MLP_AE_Head
import torchvision.transforms as transforms


@dataclass
class HyperParams(BaseHyperParameters):
    """Hyper-parameters for this notebook's experiment.

    Declares no fields of its own: every setting passed at construction
    time (``lr``, ``train_mb_size``, ``train_epochs``, ...) is inherited
    from ``BaseHyperParameters``.
    """


class MyExperiment(Experiment):
    """Class-incremental Split Fashion-MNIST run with a joint autoencoder/classifier.

    The network's forward output is unpacked as an ``(x_hat, y_hat)`` pair
    (reconstruction, class logits). :meth:`after_forward` splits that pair so
    ``mb_output`` carries only the logits and the reconstruction is kept on
    ``self.x_hat``. The criterion then adds a per-sample summed squared
    reconstruction error to ``classifier_weight`` times the cross-entropy.
    """

    hp: HyperParams
    network: nn.Module
    # Reconstruction from the most recent forward pass; written by
    # after_forward and read by the criterion.
    x_hat: Tensor

    # Scale applied to the cross-entropy term before it is added to the
    # reconstruction loss. The reconstruction term is summed over every pixel
    # of a sample, so it is presumably much larger than a mean cross-entropy;
    # this weight appears to rebalance the two (TODO confirm the chosen value).
    classifier_weight: float = 20.0

    def __init__(self, hp: HyperParams) -> None:
        super().__init__(hp)
        # Evaluation forwards produce the same (x_hat, y_hat) pair as training
        # forwards, so they must be split by the same hook.
        self.after_eval_forward = self.after_forward

    def make_network(self) -> nn.Module:
        """Build a ``D_AE`` from a CNN encoder/decoder pair and an MLP head."""
        channels = 1  # Fashion-MNIST images are single-channel (greyscale)
        base_channel_size = 32
        latent_dims = 64
        n_classes = 10  # presumably the head's output size; SplitFMNIST has 10 classes

        encoder = CNN_Encoder(channels, base_channel_size, latent_dims)
        decoder = CNN_Decoder(channels, base_channel_size, latent_dims)
        head = MLP_AE_Head(latent_dims, n_classes)
        return D_AE(encoder, decoder, head)

    def make_optimizer(self, parameters) -> torch.optim.Optimizer:
        """Return an Adam optimiser over ``parameters`` with learning rate ``hp.lr``."""
        return torch.optim.Adam(parameters, lr=self.hp.lr)

    def _reconstruction_loss(self, x: Tensor, x_hat: Tensor) -> Tensor:
        """Squared reconstruction error, summed over dims 1-3 and averaged over dim 0.

        Assumes ``x`` and ``x_hat`` are same-shaped 4-D
        ``(batch, channels, height, width)`` tensors — TODO confirm against
        ``CNN_Decoder``'s output shape.
        """
        # mse_loss is symmetric, but (input=x_hat, target=x) is the conventional order.
        per_element = F.mse_loss(x_hat, x, reduction="none")
        return per_element.sum(dim=[1, 2, 3]).mean(dim=[0])

    def after_forward(self, strategy: 'BaseStrategy', **kwargs):
        """Disentangle reconstruction and classification.

        Stashes the reconstruction on ``self.x_hat`` for the criterion and
        replaces ``strategy.mb_output`` with the logits alone, so downstream
        consumers of ``mb_output`` only see class scores.
        """
        # Use the strategy the callback is invoked with rather than relying
        # on it being the same object as self.strategy.
        x_hat, y_hat = strategy.mb_output
        self.x_hat = x_hat
        strategy.mb_output = y_hat

    def make_criterion(self):
        """Build the joint reconstruction + classification criterion.

        The returned callable follows the ``criterion(output, target)``
        convention: ``output`` is the logits left in ``mb_output`` by
        :meth:`after_forward` and ``target`` the labels. The reconstruction
        target is the current minibatch input ``strategy.mb_x``.
        """
        self.cross_entropy = nn.CrossEntropyLoss()

        def AE_criterion(output: Tensor, target: Tensor) -> Tensor:
            x = self.strategy.mb_x
            x_hat = self.x_hat

            # Pass arguments in the order the helper's signature declares.
            reconstruction_loss = self._reconstruction_loss(x, x_hat)
            classifier_loss = self.classifier_weight * \
                self.cross_entropy(output, target)

            # Log the two components separately, but only while training.
            if self.strategy.is_training:
                self.log_scalar("ReconstructionLoss",
                                reconstruction_loss.item())
                self.log_scalar("ClassifierLoss", classifier_loss.item())
            return reconstruction_loss + classifier_loss

        return AE_criterion

    def make_scenario(self):
        """Split Fashion-MNIST into 5 experiences of 2 classes each, in label order.

        Images are converted to tensors, normalised, then resized so the
        shorter side is 32 pixels (presumably the input size the CNN encoder
        expects — TODO confirm).
        """
        transform = transforms.Compose([
            transforms.ToTensor(),
            # Presumably the Fashion-MNIST training-set pixel mean / std.
            transforms.Normalize((0.2860,), (0.3530,)),
            transforms.Resize(32),
        ])
        return SplitFMNIST(
            n_experiences=5,
            fixed_class_order=list(range(10)),
            dataset_root=DATASETS,
            train_transform=transform,
            eval_transform=transform,
        )


In [8]:
# Seed torch's global RNG so anything drawing from it (e.g. default weight
# initialisation, and presumably minibatch shuffling) is repeatable across
# Restart & Run All. GPU kernels may still introduce nondeterminism.
RANDOM_SEED = 0
torch.manual_seed(RANDOM_SEED)

hyper_params = HyperParams(
    lr=0.002,
    train_mb_size=64,
    train_epochs=20,
    eval_mb_size=128,
    eval_every=-1,
    device="cuda",
)

# Keep `experiment` bound to the experiment object itself; previously it held
# whatever `train()` returned. That return value is kept separately.
# NOTE(review): confirm what Experiment.train returns.
experiment = MyExperiment(hyper_params)
train_results = experiment.train()
print("DONE!")

Start of experience:0
Current Classes:    [0, 1]
Start of experience:1
Current Classes:    [2, 3]
Start of experience:2
Current Classes:    [4, 5]
Start of experience:3
Current Classes:    [6, 7]
Start of experience:4
Current Classes:    [8, 9]
DONE!
