In [1]:
import typing
import torch
from torch import nn, Tensor
from dataclasses import dataclass
from conf import *
import torch.nn.functional as F
from avalanche.benchmarks.classic.cfashion_mnist import SplitFMNIST
import torchvision.transforms as transforms
from experiment.experiment import Experiment, BaseHyperParameters

%matplotlib auto
%load_ext autoreload
%autoreload 2
from network.vae import VAE, VAE_Loss


Using matplotlib backend: <object object at 0x7f414d78acf0>


In [2]:



@dataclass()
class HyperParams(BaseHyperParameters):
    """Hyperparameters for ``MyExperiment``.

    Declares no fields of its own: every option passed when the experiment is
    built (``lr``, ``train_mb_size``, ``train_epochs``, ``eval_mb_size``,
    ``eval_every``, ``device``) comes from ``BaseHyperParameters``.
    """


class MyExperiment(Experiment):
    """Continual-learning experiment that trains a ``VAE`` on Split Fashion-MNIST.

    The VAE's forward pass produces a combined ``VAE.ForwardOutput`` that
    carries (at least) the classifier output ``y_hat``. ``after_forward``
    stores the full output on the experiment and hands only ``y_hat`` back to
    the strategy (presumably so class-based metrics see logits — confirm). The
    criterion built by ``make_criterion`` then reads the stored full output to
    compute the VAE loss.
    """

    hp: HyperParams
    network: VAE

    # NOTE(review): not referenced anywhere in this class; presumably meant to
    # weight the classification term of the loss — confirm or remove.
    classifier_weight: float = 50.0

    def __init__(self, hp: HyperParams) -> None:
        """Build the experiment from ``hp``.

        :param hp: experiment hyperparameters, forwarded to ``Experiment``.
        """
        super().__init__(hp)

        # Evaluation needs the same output split as training, otherwise the
        # criterion would see a stale ``forward_output`` during eval.
        self.after_eval_forward = self.after_forward

    def make_network(self) -> nn.Module:
        """Return a freshly constructed ``VAE`` with its default arguments."""
        return VAE()

    def make_optimizer(self, parameters) -> torch.optim.Optimizer:
        """Return an Adam optimizer over ``parameters`` with learning rate ``hp.lr``."""
        return torch.optim.Adam(parameters, self.hp.lr)

    def _reconstruction_loss(self, x: Tensor, x_hat: Tensor) -> Tensor:
        """Summed squared error per sample, averaged over the batch.

        NOTE(review): not called from anywhere visible in this class; the
        criterion delegates to ``VAE_Loss`` instead.

        :param x: target batch; dim 0 is the batch dimension (rank >= 2).
        :param x_hat: reconstruction, same shape as ``x``.
        :return: scalar tensor.
        """
        # ``F.mse_loss`` takes (input, target); squared error is symmetric so
        # the value is unchanged, but the documented order reads correctly.
        loss = F.mse_loss(x_hat, x, reduction="none")
        # Sum over every non-batch dimension (C, H, W for 4-D images) so the
        # helper is not tied to one tensor rank, then average over the batch.
        return loss.flatten(start_dim=1).sum(dim=1).mean(dim=0)

    def after_forward(self, strategy: 'BaseStrategy', **kwargs):
        """Disentangle reconstruction and classification.

        Stores the strategy's full minibatch output (a ``VAE.ForwardOutput``)
        on ``self.forward_output`` for the criterion, and replaces
        ``strategy.mb_output`` with only the classifier output ``y_hat``.
        Also used as ``after_eval_forward`` (see ``__init__``).

        :param strategy: the strategy whose minibatch output is being split.
        """
        # Use the strategy passed to the callback rather than reaching through
        # ``self.strategy``; the callback should act on what it is given.
        self.forward_output: VAE.ForwardOutput = strategy.mb_output
        strategy.mb_output = self.forward_output.y_hat

    def make_criterion(self):
        """Return a criterion that evaluates ``VAE_Loss`` on the full forward output.

        The returned callable ignores its ``output`` argument: by the time the
        strategy calls it, ``mb_output`` has been reduced to ``y_hat`` by
        ``after_forward``, so the complete output is read from
        ``self.forward_output`` instead.
        """
        # NOTE(review): meaning of the three weights depends on VAE_Loss's
        # signature (defined in network.vae) — confirm and name them.
        self.criterion = VAE_Loss(1.0, 1.0, 1.0)

        def VAE_criterion(output, target):
            # ``output`` is only y_hat here; use the stashed full output.
            return self.criterion.loss(self.forward_output, target)

        return VAE_criterion

    def make_scenario(self):
        """Return a 5-experience Split Fashion-MNIST benchmark in class order 0..9.

        Train and eval images are converted to tensors and normalised with a
        single-channel mean 0.2860 / std 0.3530.
        """
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.2860,), (0.3530,))]
        )
        scenario = SplitFMNIST(
            n_experiences=5,
            fixed_class_order=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            dataset_root=DATASETS,
            train_transform=transform,
            eval_transform=transform)
        return scenario


In [5]:
# Train on every Split-FMNIST experience. Build the experiment and call
# ``train()`` separately so ``experiment`` always refers to the experiment
# object; chaining ``MyExperiment(...).train()`` would bind it to whatever
# ``train()`` returns (possibly None).
hyper_params = HyperParams(
    lr=0.005,
    train_mb_size=32,
    train_epochs=1,
    eval_mb_size=128,
    eval_every=-1,
    # Fall back to CPU so the notebook still runs on machines without CUDA.
    device="cuda" if torch.cuda.is_available() else "cpu",
)
experiment = MyExperiment(hyper_params)
experiment.train()




Network: <class 'network.vae.VAE'>
 > Has the <class 'network.trait.AutoEncoder'> trait
 > Has the <class 'network.trait.Classifier'> trait
 > Has the <class 'network.trait.Samplable'> trait
Start of experience: 0
Current Classes:     [0, 1]
Experience size:     12000
Start of experience: 1
Current Classes:     [2, 3]
Experience size:     12000
Start of experience: 2
Current Classes:     [4, 5]
Experience size:     12000
Start of experience: 3
Current Classes:     [6, 7]
Experience size:     12000
Start of experience: 4
Current Classes:     [8, 9]
Experience size:     12000
