In [1]:
import typing
import torch
from torch import nn, Tensor
from dataclasses import dataclass
from conf import *
import avalanche as av
import torch.nn.functional as F

from experiment.experiment import Experiment, BaseHyperParameters
%load_ext autoreload
%autoreload 2

%matplotlib auto


Using matplotlib backend: <object object at 0x7f6382a16bd0>


In [2]:
from avalanche.benchmarks.classic.cfashion_mnist import SplitFMNIST
from network.ae_group import D_AE, CNN_Encoder, CNN_Decoder, MLP_AE_Head
from network.trait import Generative, PackNetModule
from network.module.packnet_linear import PackNetDenseEncoder, PackNetDenseDecoder

import torchvision.transforms as transforms


class PackNetClassifyingAutoEncoder(D_AE, PackNetModule):
    """Classifying auto-encoder whose encoder and decoder are PackNet modules.

    Each PackNet operation is forwarded to ``self.encoder`` first and then to
    ``self.decoder``, so both halves always share the same pruning / subset
    state. No other sub-module is touched by these operations.
    """

    encoder: PackNetModule
    decoder: PackNetModule

    def _packnet_parts(self):
        """Return the PackNet sub-modules in the order operations are applied."""
        return (self.encoder, self.decoder)

    def prune(self, to_prune_proportion: float) -> None:
        """Forward ``prune(to_prune_proportion)`` to encoder then decoder."""
        for part in self._packnet_parts():
            part.prune(to_prune_proportion)

    def push_pruned(self) -> None:
        """Forward ``push_pruned()`` to encoder then decoder."""
        for part in self._packnet_parts():
            part.push_pruned()

    def use_task_subset(self, task_id):
        """Select the weight subset belonging to ``task_id`` in both halves."""
        for part in self._packnet_parts():
            part.use_task_subset(task_id)

    def use_top_subset(self):
        """Switch both halves back to their top (most recent) subset."""
        for part in self._packnet_parts():
            part.use_top_subset()


@dataclass
class HyperParams(BaseHyperParameters):
    """Hyper-parameters for the PackNet classifying auto-encoder experiment.

    Fields such as ``lr``, ``train_mb_size``, ``train_epochs``,
    ``eval_mb_size``, ``eval_every`` and ``device`` are passed at the call site
    but not declared here, so they come from ``BaseHyperParameters``.
    """
    # Whether to prune the network after every training experience.
    prune: bool
    # Value passed to ``network.prune``; also multiplied into the logged capacity.
    prune_proportion: float
    # Extra training epochs run after pruning and before ``push_pruned``.
    post_prune_epochs: int
    # Weight of the cross-entropy term relative to the reconstruction loss.
    classifier_weight: float = 1.0
    # Size of the latent code handed to the classifier head.
    latent_dims: int = 64

class MyExperiment(Experiment):
    """PackNet continual-learning experiment on Split Fashion-MNIST.

    Trains a classifying auto-encoder (PackNet dense encoder/decoder plus an
    MLP classifier head) on a 5-experience, task-labelled Split-FMNIST stream.
    After each training experience the network is optionally pruned,
    fine-tuned for ``hp.post_prune_epochs`` epochs, and ``push_pruned`` is
    called. During evaluation each experience's task label selects the
    matching PackNet subset.
    """

    hp: HyperParams
    network: D_AE

    def __init__(self, hp: HyperParams) -> None:
        super().__init__(hp)

        # Eval forwards return the same D_AE.ForwardOutput as training
        # forwards, so they need the same unpacking hook.
        self.after_eval_forward = self.after_forward

    def make_network(self) -> nn.Module:
        """Build the PackNet classifying auto-encoder.

        ``hp.latent_dims`` is passed to the encoder, decoder, classifier head
        and the wrapping auto-encoder so every part agrees on the latent size.
        """
        latent_dims = self.hp.latent_dims
        input_shape = (1, 28, 28)  # single-channel 28x28 Fashion-MNIST images

        encoder = PackNetDenseEncoder(input_shape, latent_dims, [512, 256, 128])
        decoder = PackNetDenseDecoder(input_shape, latent_dims, [128, 256, 512])
        # A CNN_Encoder / CNN_Decoder pair can be swapped in here; that variant
        # additionally needs a ``base_channel_size`` hyper-parameter.

        head = MLP_AE_Head(latent_dims, 10)  # 10 Fashion-MNIST classes
        return PackNetClassifyingAutoEncoder(latent_dims, encoder, decoder, head)

    def make_optimizer(self, parameters) -> torch.optim.Optimizer:
        """Plain SGD over ``parameters`` with learning rate ``hp.lr``."""
        return torch.optim.SGD(parameters, self.hp.lr)

    def after_forward(self, strategy: 'BaseStrategy', **kwargs):
        """Unpack the network's ``D_AE.ForwardOutput`` after a forward pass.

        The full output is stored in ``self.forward_output`` for the loss, and
        ``strategy.mb_output`` is replaced by the class logits ``y_hat`` so the
        rest of the strategy only sees the classifier output.
        """
        output = self.strategy.mb_output
        assert isinstance(output, D_AE.ForwardOutput), "Unexpected output"
        self.forward_output = output
        self.strategy.mb_output = output.y_hat

    def make_criterion(self):
        """Return the loss: reconstruction MSE + weighted cross-entropy.

        The ``output`` / ``target`` arguments supplied by the strategy are not
        used: the minibatch is read from ``self.strategy`` and the
        reconstruction / logits from ``self.forward_output`` (set by
        ``after_forward``).
        """
        def _loss_function(
            output: Tensor,
            target: Tensor) -> Tensor:
            mb_x, mb_y = self.strategy.mb_x, self.strategy.mb_y
            x_hat = self.forward_output.x_hat
            y_hat = self.forward_output.y_hat

            # Squared error summed over every non-batch dimension of each
            # sample, then averaged over the batch.
            recon_loss = F.mse_loss(x_hat, mb_x, reduction="none")
            recon_loss = recon_loss.flatten(start_dim=1).sum(dim=1).mean()

            class_loss = F.cross_entropy(y_hat, mb_y)

            prefix = "train" if self.strategy.is_training else "test"
            self.log_scalar(f"{prefix}_recon_loss", recon_loss)
            self.log_scalar(f"{prefix}_class_loss", class_loss)

            return recon_loss + class_loss * self.hp.classifier_weight
        return _loss_function

    # How much of the network is still trainable; multiplied by
    # ``hp.prune_proportion`` after every prune and only used for logging.
    capacity: float = 1.0

    def after_training_exp(self, strategy):
        """Prune, fine-tune for ``hp.post_prune_epochs`` epochs, then push.

        Does nothing when ``hp.prune`` is false.
        """
        if not self.hp.prune:
            return

        self.capacity *= self.hp.prune_proportion
        print("Performing Prune")
        print(f"     Pruning     {self.hp.prune_proportion}")
        print(f"     New Capacity {self.capacity}")
        self.network.prune(self.hp.prune_proportion)

        # Retrain the remaining weights so the network recovers from pruning.
        for _ in range(self.hp.post_prune_epochs):
            self.strategy.training_epoch()

        print("Push Pruned")
        self.network.push_pruned()

    def before_eval_exp(self, strategy, *args, **kwargs):
        """Use the experience's task label to select each layer's subset."""
        experience: av.benchmarks.Experience = self.strategy.experience
        task_id = experience.task_label

        if self.hp.prune:
            print(f"task_id={task_id}, experience={self.clock.train_exp_counter}")
            # NOTE(review): tasks not trained yet (task_id >=
            # clock.train_exp_counter) also get use_task_subset here; an
            # earlier variant used use_top_subset for them instead.
            self.network.use_task_subset(task_id)

    def after_eval(self, strategy, *args, **kwargs):
        """Restore the top subset so training continues on the full network."""
        self.network.use_top_subset()

    def make_scenario(self):
        """Split Fashion-MNIST: 5 experiences of 2 classes each, with task ids."""
        # Presumably the Fashion-MNIST training-set mean / std.
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.2860,), (0.3530,))]
        )
        scenario = SplitFMNIST(
            n_experiences=5,
            fixed_class_order=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            dataset_root=DATASETS,
            return_task_id=True,
            train_transform=transform,
            eval_transform=transform)
        return scenario


In [6]:
# Seed PyTorch's global RNG so weight initialisation (and, presumably,
# DataLoader shuffling) is repeatable across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

hyper_params = HyperParams(
    lr=0.002,
    train_mb_size=32,
    train_epochs=10,
    eval_mb_size=128,
    eval_every=-1,
    prune=True,
    prune_proportion=0.90,
    post_prune_epochs=50,
    # 0.0 disables the classification term: the loss is reconstruction only.
    classifier_weight=0.0,
    # Fall back to CPU so the notebook still runs on machines without CUDA.
    device="cuda" if torch.cuda.is_available() else "cpu",
)

experiment = MyExperiment(hyper_params)
experiment.train()
print("DONE!")



Network: <class '__main__.PackNetClassifyingAutoEncoder'>
 * Has the <class 'network.trait.Generative'> trait
Start of experience: 0
Current Classes:     [0, 1]
Experience size:     12000
Performing Prune
     Pruning     0.9
     New Capcity 0.9
Push Pruned
task_id=0, experience=1
USING SUBSET
USING SUBSET
USING SUBSET
USING SUBSET
USING SUBSET
task_id=1, experience=1
USING SUBSET
USING SUBSET
USING SUBSET
USING SUBSET
USING SUBSET
task_id=2, experience=1
USING SUBSET
USING SUBSET
USING SUBSET
USING SUBSET
USING SUBSET
task_id=3, experience=1
USING SUBSET
USING SUBSET
USING SUBSET
USING SUBSET
USING SUBSET
task_id=4, experience=1
USING SUBSET
USING SUBSET
USING SUBSET
USING SUBSET
USING SUBSET
Start of experience: 1
Current Classes:     [2, 3]
Experience size:     12000
Performing Prune
     Pruning     0.9
     New Capcity 0.81
Push Pruned
task_id=0, experience=2
USING SUBSET
USING SUBSET
USING SUBSET
USING SUBSET
USING SUBSET
task_id=1, experience=2
USING SUBSET
USING SUBSET
USING S