In [None]:
import typing
import torch
from torch import nn, Tensor
from dataclasses import dataclass
from conf import *
import torch.nn.functional as F

from experiment.experiment import Experiment, BaseHyperParameters
%load_ext autoreload
%autoreload 2

In [None]:
from avalanche.benchmarks.classic.cfashion_mnist import SplitFMNIST
from network.ae_group import D_AE

@dataclass
class HyperParams(BaseHyperParameters):
    """Hyper-parameters for this notebook's experiment.

    Declares no fields of its own; every setting is inherited from
    ``BaseHyperParameters``. It is the type of ``MyExperiment.hp``.
    """

class MyExperiment(Experiment):
    """Split Fashion-MNIST experiment with a joint autoencoder/classifier.

    The network's forward output is unpacked as an ``(x_hat, y_hat)`` pair by
    ``after_forward``. The reconstruction is kept on ``self.x_hat`` and the
    strategy's ``mb_output`` is replaced by ``y_hat``. The training loss is
    reconstruction loss + cross-entropy.
    """

    hp: HyperParams
    network: nn.Module

    def __init__(self, hp: HyperParams) -> None:
        super().__init__(hp)

        # Split the (x_hat, y_hat) output after evaluation forwards too, so the
        # criterion and anything reading mb_output see the same layout in both
        # phases.
        self.after_eval_forward = self.after_forward

    def make_network(self) -> nn.Module:
        """Build the autoencoder-classifier network.

        NOTE(review): ``64`` and ``10`` are passed positionally to ``D_AE``;
        presumably latent size and number of classes -- confirm in
        ``network.ae_group``.
        """
        return D_AE(64, 10)

    def make_optimizer(self, parameters) -> torch.optim.Optimizer:
        """Return an Adam optimizer over ``parameters`` with lr ``hp.lr``."""
        return torch.optim.Adam(parameters, self.hp.lr)

    def _reconstruction_loss(self, x: Tensor, x_hat: Tensor) -> Tensor:
        """Summed squared error per sample, averaged over the batch.

        Dim 0 is treated as the batch dimension. Squared errors are summed
        over every remaining dimension, which is identical to summing over
        dims (1, 2, 3) for 4-D input but also works for other ranks. The
        average is then taken over the batch.

        Args:
            x: the original input.
            x_hat: the reconstruction; must have the same shape as ``x``.

        Returns:
            A scalar tensor.
        """
        # Reconstruction is the prediction, the original input the target.
        squared_error = F.mse_loss(x_hat, x, reduction="none")
        per_sample = squared_error.reshape(squared_error.shape[0], -1).sum(dim=1)
        return per_sample.mean(dim=0)

    def after_forward(self, strategy: 'BaseStrategy', **kwargs):
        """Disentangle reconstruction and classification.

        Unpacks the network output ``(x_hat, y_hat)``. It stores ``x_hat`` on
        ``self.x_hat`` for the criterion and leaves only the class logits
        ``y_hat`` in ``self.strategy.mb_output``.
        """
        x_hat, y_hat = self.strategy.mb_output
        self.x_hat = x_hat
        self.strategy.mb_output = y_hat

    def make_criterion(self):
        """Return a loss that adds reconstruction and classification terms.

        The returned callable logs each component separately under
        ``"ReconstructionLoss"`` and ``"ClassifierLoss"`` and returns their sum.
        """
        self.cross_entropy = nn.CrossEntropyLoss()

        def AE_criterion(output, target):
            # NOTE(review): `output`/`target` are ignored; the logits and labels
            # are read from the strategy instead. They are presumably the same
            # tensors -- confirm before switching to the arguments.
            x_hat = self.x_hat
            y_hat = self.strategy.mb_output
            x = self.strategy.mb_x
            y = self.strategy.mb_y

            reconstruction_loss = self._reconstruction_loss(x, x_hat)
            classifier_loss = self.cross_entropy(y_hat, y)

            # Log types of loss
            self.log_scalar("ReconstructionLoss", reconstruction_loss.item())
            self.log_scalar("ClassifierLoss", classifier_loss.item())
            return reconstruction_loss + classifier_loss

        return AE_criterion

    def make_scenario(self):
        """Split Fashion-MNIST into 5 experiences with fixed class order 0..9."""
        return SplitFMNIST(5, fixed_class_order=list(range(10)), dataset_root=DATASETS)

In [None]:
# Train once and keep a handle on the experiment: the reconstruction cell
# below passes `experiment` to the reconstruction metric.
experiment = MyExperiment(
    HyperParams(
        lr=0.01,
        train_mb_size=64,
        train_epochs=20,
        eval_mb_size=128,
        eval_every=-1,
        # Fall back to CPU so Restart & Run All works on machines without CUDA.
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
)
experiment.train()
print("DONE!")

In [None]:
from avalanche.benchmarks.scenarios.new_classes.nc_scenario import NCExperience
from torchvision import transforms
from metrics.reconstructions import GenerateReconstruction
import numpy.random as random
# Same split as MyExperiment.make_scenario: 5 experiences, class order 0..9.
scenario = SplitFMNIST(5, fixed_class_order=list(range(10)), dataset_root=DATASETS)

# NOTE(review): meaning of the positional arguments `1, 0` is not visible here;
# see metrics.reconstructions.GenerateReconstruction.
recon = GenerateReconstruction(scenario, 1, 0)
# NOTE(review): `after_eval_exp` reads like a plugin callback that may expect a
# strategy object; if so, pass `experiment.strategy` instead -- confirm.
recon.after_eval_exp(experiment)

