In [1]:
import typing
import torch
from torch import nn, Tensor
from dataclasses import dataclass
from conf import *
import torch.nn.functional as F
from avalanche.benchmarks.classic.cfashion_mnist import SplitFMNIST
import torchvision.transforms as transforms
from experiment.experiment import Experiment, BaseHyperParameters
from network.coders import PackNetDenseEncoder, PackNetDenseDecoder, PackNetDenseHead

%matplotlib auto
%load_ext autoreload
%autoreload 2
from network.deep_generative import DVAE, DVAE_Loss
from network.trait import PackNetModule
from experiment.plugins import PackNetPlugin


Using matplotlib backend: <object object at 0x7fe3bf28bc00>


In [2]:

class PN_VAE(DVAE, PackNetModule):
    """PackNet variational auto-encoder.

    Built for single-channel 28x28 inputs; the class head has 10 outputs.
    Every sub-network (encoder, decoder and the class / mu / var heads) is a
    PackNet dense module so that ``_pn_apply`` can forward PackNet operations
    to each of them.
    """

    def _pn_apply(self, func: typing.Callable[['PackNetModule'], None]):
        """Apply ``func`` to every child PackNetModule.

        :param func: Callable invoked once per child, in the order encoder,
            decoder, class_head, mu_head, var_head. Its return value is ignored.
        """
        for module in [self.encoder, self.decoder, self.class_head, self.mu_head, self.var_head]:
            func(module)

    def __init__(self, latent_dims=64):
        """
        :param latent_dims: Width of the latent code. It is used for every
            head's input width, for the mu/var heads' output width, and as the
            encoder's and decoder's latent width.

        Previously the encoder and decoder hard-coded ``64`` here. Any
        ``latent_dims`` other than 64 then produced heads whose input width
        did not match the encoder. NOTE(review): this assumes the second
        argument of PackNetDenseEncoder/Decoder is the latent width; confirm
        in ``network.coders``.
        """
        super().__init__(
            latent_dims,
            encoder=PackNetDenseEncoder((1, 28, 28), latent_dims, [512, 256, 128]),
            decoder=PackNetDenseDecoder((1, 28, 28), latent_dims, [128, 256, 512]),
            class_head=PackNetDenseHead(latent_dims, 10),
            mu_head=PackNetDenseHead(latent_dims, latent_dims),
            var_head=PackNetDenseHead(latent_dims, latent_dims),
        )

@dataclass
class HyperParams(BaseHyperParameters):
    """BaseHyperParameters extended with the two PackNet settings consumed by PackNetPlugin."""

    # Forwarded to PackNetPlugin (second positional argument) as the pruning proportion.
    prune_proportion: float
    # Forwarded to PackNetPlugin (third positional argument) as the number of post-prune epochs.
    post_prune_epochs: int


class MyExperiment(Experiment):
    """PackNet continual-learning experiment: a ``PN_VAE`` on Split Fashion-MNIST.

    The ``make_*`` / ``add_plugins`` methods below build the pieces of the
    setup (network, optimizer, criterion, plugins, scenario). They are
    presumably called by the ``Experiment`` base class, which is defined
    elsewhere.
    """

    hp: HyperParams
    network: DVAE

    # NOTE(review): nothing in this class reads this value. ``make_criterion``
    # passes ``classifier_weight=1.0`` to DVAE_Loss. Confirm whether the base
    # class uses it, or whether the criterion was meant to.
    classifier_weight: float = 50.0

    def __init__(self, hp: HyperParams) -> None:
        super().__init__(hp)
        # Point the eval-time hook at the same callable as the train-time one.
        self.after_eval_forward = self.after_forward

    def make_network(self) -> nn.Module:
        """Return a fresh PN_VAE with 64 latent dimensions."""
        latent_width = 64
        return PN_VAE(latent_width)

    def make_optimizer(self, parameters) -> torch.optim.Optimizer:
        """Return plain SGD over ``parameters`` with learning rate ``hp.lr``."""
        return torch.optim.SGD(parameters, self.hp.lr)

    def make_criterion(self):
        """Build the DVAE loss and return a ``(output, target)`` wrapper around it.

        The DVAE_Loss instance is stored on ``self.criterion`` and is looked up
        again each time the wrapper is called.
        """
        dvae_loss = DVAE_Loss(recon_weight=1.0, classifier_weight=1.0, kld_weight=1.0)
        self.criterion = dvae_loss

        def VAE_criterion(output, target):
            # ``output`` is deliberately ignored. The loss is computed from
            # ``self.last_mb_output``, which is assumed to be set elsewhere
            # (presumably by the base class after each forward pass; confirm).
            return self.criterion.loss(self.last_mb_output, target)

        return VAE_criterion

    def add_plugins(self):
        """Return a single PackNetPlugin configured from the pruning hyper-parameters."""
        packnet = PackNetPlugin(self.network, self.hp.prune_proportion, self.hp.post_prune_epochs)
        return [packnet]

    def make_scenario(self):
        """Split Fashion-MNIST into 5 task-labelled experiences, classes in label order."""
        # Presumably the Fashion-MNIST training-set mean / std.
        normalize = transforms.Normalize((0.2860,), (0.3530,))
        shared_transform = transforms.Compose([transforms.ToTensor(), normalize])
        return SplitFMNIST(
            n_experiences=5,
            fixed_class_order=list(range(10)),
            dataset_root=DATASETS,
            return_task_id=True,
            train_transform=shared_transform,
            eval_transform=shared_transform,
        )


In [3]:
# Configuration for the PackNet run. The two pruning settings are forwarded
# to PackNetPlugin by MyExperiment.add_plugins.
hyper_params = HyperParams(
    lr=0.001,
    train_mb_size=32,
    train_epochs=10,
    prune_proportion=0.50,
    post_prune_epochs=20,
    eval_mb_size=128,
    eval_every=-1,
    device="cuda",
)

# NOTE: ``experiment`` is bound to whatever ``.train()`` returns, not to the
# MyExperiment instance itself.
experiment = MyExperiment(hyper_params).train()


Network: <class '__main__.PN_VAE'>
 > Has the <class 'network.trait.PackNetModule'> trait
 > Has the <class 'network.trait.AutoEncoder'> trait
 > Has the <class 'network.trait.Classifier'> trait
 > Has the <class 'network.trait.Samplable'> trait
Start of experience: 0
Current Classes:     [0, 1]
Experience size:     12000


INFO:experiment.plugins:Pruning Network, 50.0% remaining
INFO:experiment.plugins:Pushing
INFO:experiment.plugins:Using task label 0
INFO:experiment.plugins:Using task label 1
INFO:experiment.plugins:Using task label 2
INFO:experiment.plugins:Using task label 3
INFO:experiment.plugins:Using task label 4


Start of experience: 1
Current Classes:     [2, 3]
Experience size:     12000


INFO:experiment.plugins:Pruning Network, 25.0% remaining
INFO:experiment.plugins:Pushing
INFO:experiment.plugins:Using task label 0
INFO:experiment.plugins:Using task label 1
INFO:experiment.plugins:Using task label 2
INFO:experiment.plugins:Using task label 3
INFO:experiment.plugins:Using task label 4


Start of experience: 2
Current Classes:     [4, 5]
Experience size:     12000


INFO:experiment.plugins:Pruning Network, 12.5% remaining
INFO:experiment.plugins:Pushing
INFO:experiment.plugins:Using task label 0
INFO:experiment.plugins:Using task label 1
INFO:experiment.plugins:Using task label 2
INFO:experiment.plugins:Using task label 3
INFO:experiment.plugins:Using task label 4


Start of experience: 3
Current Classes:     [6, 7]
Experience size:     12000


INFO:experiment.plugins:Pruning Network, 6.25% remaining
INFO:experiment.plugins:Pushing
INFO:experiment.plugins:Using task label 0
INFO:experiment.plugins:Using task label 1
INFO:experiment.plugins:Using task label 2
INFO:experiment.plugins:Using task label 3
INFO:experiment.plugins:Using task label 4


Start of experience: 4
Current Classes:     [8, 9]
Experience size:     12000


INFO:experiment.plugins:Pruning Network, 3.125% remaining
INFO:experiment.plugins:Pushing
INFO:experiment.plugins:Using task label 0
INFO:experiment.plugins:Using task label 1
INFO:experiment.plugins:Using task label 2
INFO:experiment.plugins:Using task label 3
INFO:experiment.plugins:Using task label 4
