In [None]:
import typing
import torch
from torch import nn, Tensor
from dataclasses import dataclass
from conf import *
import torch.nn.functional as F

from experiment.experiment import Experiment, BaseHyperParameters
%load_ext autoreload
%autoreload 2

%matplotlib auto


In [None]:
from avalanche.benchmarks.classic.cfashion_mnist import SplitFMNIST
from network.ae_group import D_AE, AEGroup, MLP_AE_Head
from network.trait import Generative
from network.coders import CNN_Encoder, CNN_Decoder, MLP_Encoder, MLP_Decoder

import torchvision.transforms as transforms


@dataclass
class HyperParams(BaseHyperParameters):
    """Hyper-parameters for ``MyExperiment``, added on top of ``BaseHyperParameters``."""

    # Passed to ``D_AE`` as its second constructor argument.
    # NOTE(review): presumably weights the classification term of the loss -- confirm in D_AE.
    classifier_weight: float = 1.0
    # Latent size handed to the encoder, decoder, head and ``D_AE`` itself.
    latent_dims: int = 64
    # Not passed to any network constructor in this file.
    # NOTE(review): presumably intended for the CNN coders -- confirm or remove.
    base_channel_size: int = 32


class MyExperiment(Experiment):
    """Experiment that trains an ``AEGroup`` of ``D_AE`` sub-networks on ``SplitFMNIST``.

    The class provides the network, optimizer, criterion and scenario through
    its ``make_*`` methods and rewrites the strategy's mini-batch output in
    :meth:`after_forward`.

    NOTE(review): the ``make_*`` methods are presumably factory hooks invoked
    by the ``Experiment`` base class -- confirm against ``experiment.experiment``.
    """

    hp: HyperParams
    network: D_AE

    # Second argument given to every ``MLP_AE_Head``; also the length of the
    # fixed class order handed to ``SplitFMNIST``.
    N_CLASSES = 10
    # Number of experiences requested from ``SplitFMNIST``.
    N_EXPERIENCES = 5
    # Second argument given to ``AEGroup``.
    # NOTE(review): presumably the number of sub-networks and meant to track
    # N_EXPERIENCES (both were the literal 5) -- confirm against AEGroup.
    N_SUBNETWORKS = 5

    def __init__(self, hp: HyperParams) -> None:
        """Initialise the base experiment with ``hp`` and alias the eval hook.

        Args:
            hp: Hyper-parameters forwarded to ``Experiment.__init__``.
        """
        super().__init__(hp)

        # Eval forward passes get the same post-processing as training ones.
        # NOTE(review): assumes the strategy calls ``after_eval_forward`` on
        # this object -- confirm against the Experiment/plugin wiring.
        self.after_eval_forward = self.after_forward

    def make_network(self) -> nn.Module:
        """Build an ``AEGroup`` whose members are produced by a ``D_AE`` factory.

        Returns:
            ``AEGroup(factory, N_SUBNETWORKS, copy_network=True)``.
        """

        def _make_subnet() -> D_AE:
            """Create one MLP-based ``D_AE`` from the current hyper-parameters."""
            latent_dims = self.hp.latent_dims

            # Construction order (encoder, decoder, head) is kept as in the
            # original so that RNG-dependent weight initialisation is unchanged.
            encoder = MLP_Encoder(latent_dims)
            decoder = MLP_Decoder(latent_dims)
            head = MLP_AE_Head(latent_dims, self.N_CLASSES)

            return D_AE(latent_dims, self.hp.classifier_weight, encoder, decoder, head)

        return AEGroup(_make_subnet, self.N_SUBNETWORKS, copy_network=True)

    def make_optimizer(self, parameters) -> torch.optim.Optimizer:
        """Return an Adam optimizer over ``parameters`` with learning rate ``hp.lr``."""
        return torch.optim.Adam(parameters, self.hp.lr)

    def after_forward(self, strategy: 'BaseStrategy', **kwargs):
        """Cache the structured forward output and expose only ``y_hat``.

        The full ``D_AE.ForwardOutput`` is stored in ``self.forward_output``
        (read later by the criterion from :meth:`make_criterion`), and
        ``self.strategy.mb_output`` is replaced by its ``y_hat`` field.

        Args:
            strategy: Unused; the method operates on ``self.strategy``.
                NOTE(review): presumably the same object -- confirm.
            **kwargs: Ignored.

        Raises:
            AssertionError: If ``mb_output`` is not a ``D_AE.ForwardOutput``.
        """
        mb_output = self.strategy.mb_output
        assert isinstance(mb_output, D_AE.ForwardOutput), "Unexpected output"
        self.forward_output = mb_output
        self.strategy.mb_output = mb_output.y_hat

    def make_criterion(self):
        """Return a criterion that delegates to ``self.network.loss_function``.

        Returns:
            A callable ``(output, target) -> Tensor``. Both arguments are
            ignored: the loss is computed from ``self.strategy.mb_x``,
            ``self.strategy.mb_y`` and the ``x_hat`` / ``y_hat`` fields of the
            output cached by :meth:`after_forward`, because the reconstruction
            ``x_hat`` is not part of what ``after_forward`` leaves in
            ``mb_output``.
        """
        def _loss_function(
            output: Tensor,
            target: Tensor) -> Tensor:
            cached = self.forward_output
            return self.network.loss_function(
                self.strategy.mb_x,
                self.strategy.mb_y,
                cached.x_hat,
                cached.y_hat)
        return _loss_function

    def make_scenario(self):
        """Build the ``SplitFMNIST`` scenario used for training and evaluation.

        The same transform pipeline (``ToTensor`` -> ``Normalize`` ->
        ``Resize(28)``) is used for both the train and eval streams.

        Returns:
            A ``SplitFMNIST`` scenario with ``N_EXPERIENCES`` experiences and
            class order ``0 .. N_CLASSES - 1``.
        """
        # NOTE(review): (0.2860,), (0.3530,) are presumably the Fashion-MNIST
        # mean / std -- confirm.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.2860,), (0.3530,)),
            transforms.Resize(28),
        ])
        # DATASETS is not defined in this file; presumably it comes from
        # ``from conf import *`` -- confirm.
        return SplitFMNIST(
            n_experiences=self.N_EXPERIENCES,
            fixed_class_order=list(range(self.N_CLASSES)),
            dataset_root=DATASETS,
            train_transform=transform,
            eval_transform=transform)


In [None]:
# Build the experiment and run training.
experiment = MyExperiment(
    HyperParams(
        lr=0.002,
        train_mb_size=64,
        train_epochs=256,
        eval_mb_size=128,
        eval_every=-1,
        # Prefer the GPU, but fall back to the CPU so the notebook still runs
        # on machines without CUDA instead of failing.
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
)

experiment.train()
print("DONE!")