In [1]:
import typing
import torch
from torch import nn, Tensor
from dataclasses import dataclass
from conf import *
import avalanche as av
import torch.nn.functional as F

from experiment.experiment import Experiment, BaseHyperParameters
%load_ext autoreload
%autoreload 2

%matplotlib auto

from functional import best_reduce


Using matplotlib backend: <object object at 0x7ff02f8e3ce0>


In [2]:
from avalanche.benchmarks.classic.cfashion_mnist import SplitFMNIST
from network.ae_group import D_AE, CNN_Encoder, CNN_Decoder, MLP_AE_Head
from network.trait import PackNetModule, Classifier
from network.module.packnet_linear import PackNetDenseEncoder, PackNetDenseDecoder, PackNetDenseHead

import torchvision.transforms as transforms


class PackNetClassifyingAutoEncoder(D_AE, PackNetModule):
    """Classifying auto-encoder whose encoder, decoder and head are PackNet modules.

    In training mode the forward pass is delegated to the parent class. In
    evaluation mode every subnetwork pushed so far is run on the batch and the
    per-subnetwork outputs are combined with `best_reduce`, keyed on each
    sample's reconstruction error.
    """

    encoder: PackNetModule
    decoder: PackNetModule
    head: PackNetModule

    # Number of subnetworks pushed so far; incremented by `push_pruned`.
    subnet_count: int = 0

    def _pn_apply(self, func: typing.Callable[['PackNetModule'], None]):
        """Apply only to child PackNetModule"""
        for module in (self.encoder, self.decoder, self.head):
            func(module)

    def forward(self, x: Tensor) -> D_AE.ForwardOutput:
        """Run the parent forward pass when training, `_eval_forward` otherwise."""
        if self.training:
            return super().forward(x)
        return self._eval_forward(x)

    def push_pruned(self):
        """Push the pruned weights via the parent class and count the new subnetwork."""
        super().push_pruned()
        self.subnet_count += 1

    @torch.no_grad()
    def _eval_forward(self, x: Tensor) -> D_AE.ForwardOutput:
        """Evaluate `x` with every pushed subnetwork and merge the results.

        Args:
            x: Input batch. The reconstruction error is summed over dims 1-3,
                so a 4-D tensor is expected.

        Returns:
            A `ForwardOutput` built from the `best_reduce` selection over the
            per-subnetwork `y_hat`, `x_hat` and `z_code` tensors. If no
            subnetwork has been pushed yet, the plain parent forward output is
            returned instead.
        """
        if self.subnet_count == 0:
            # Nothing to compare against yet; `torch.stack` on the empty lists
            # below would raise, so use the network in its current state.
            return super().forward(x)

        # Generate `ForwardOutput` for each `PackNet` subnetwork/stack layer
        losses: typing.List[Tensor] = []
        x_hats: typing.List[Tensor] = []
        y_hats: typing.List[Tensor] = []
        z_codes: typing.List[Tensor] = []
        try:
            for i in range(self.subnet_count):
                self.use_task_subset(i)

                output = super().forward(x)
                # Calculate MSE to determine how familiar the stack layer is
                # with the instance
                loss = F.mse_loss(output.x_hat, x, reduction="none").sum(dim=[1, 2, 3])

                x_hats.append(output.x_hat)
                y_hats.append(output.y_hat)
                z_codes.append(output.z_code)
                losses.append(loss)
        finally:
            # Always hand the network back on its top subset, even when the
            # loop is interrupted part-way through.
            self.use_top_subset()

        # Select the best subnetwork's output
        # NOTE(review): presumably `best_reduce` picks, per sample, the entry
        # with the lowest loss - confirm against `functional.best_reduce`.
        stacked_losses = torch.stack(losses)
        x_hat = best_reduce(stacked_losses, torch.stack(x_hats))
        y_hat = best_reduce(stacked_losses, torch.stack(y_hats))
        z_code = best_reduce(stacked_losses, torch.stack(z_codes))

        return self.ForwardOutput(y_hat, x_hat, z_code)

@dataclass
class HyperParams(BaseHyperParameters):
    """Hyper-parameters for the PackNet classifying auto-encoder experiment.

    Adds the pruning and loss-weighting options used by `MyExperiment` on top
    of the fields inherited from `BaseHyperParameters`.
    """
    # When False, `MyExperiment.after_training_exp` returns without pruning and
    # `before_eval_exp` does not select a task subset.
    prune: bool
    # Passed to `network.prune(...)`; also multiplied into `MyExperiment.capacity`.
    prune_proportion: float
    # Number of extra training epochs run between `prune` and `push_pruned`.
    post_prune_epochs: int
    # Multiplier applied to the cross-entropy term of the combined loss.
    classifier_weight: float = 1.0
    # Latent size handed to the classifier head and the auto-encoder wrapper
    # in `MyExperiment.make_network`.
    latent_dims: int = 64

class MyExperiment(Experiment):
    """Experiment training a PackNet classifying auto-encoder on Split Fashion-MNIST.

    The methods below supply the network, optimizer, criterion and scenario,
    and hook into the training/evaluation loop to prune after each training
    experience and to select a task subset before each evaluation experience.
    """

    hp: HyperParams
    network: D_AE

    def __init__(self, hp: HyperParams) -> None:
        super().__init__(hp)

        # Reuse the training-time output unpacking for evaluation forwards too.
        self.after_eval_forward = self.after_forward

    def make_network(self) -> nn.Module:
        """Build the dense PackNet encoder/decoder/head auto-encoder."""
        latent_dims = self.hp.latent_dims

        # Use the configured latent size everywhere so the encoder output, the
        # decoder input and the head input always agree.
        encoder = PackNetDenseEncoder((1, 28, 28), latent_dims, [512, 256, 128])
        decoder = PackNetDenseDecoder((1, 28, 28), latent_dims, [128, 256, 512])

        head = PackNetDenseHead(latent_dims, 10)
        return PackNetClassifyingAutoEncoder(latent_dims, encoder, decoder, head)

    def make_optimizer(self, parameters) -> torch.optim.Optimizer:
        """Return plain SGD over `parameters` with learning rate `hp.lr`."""
        optimizer = torch.optim.SGD(parameters, self.hp.lr)
        return optimizer

    def after_forward(self, strategy: 'BaseStrategy', **kwargs):
        """Stash the full forward output and leave only `y_hat` on the strategy.

        The complete `D_AE.ForwardOutput` is kept in `self.forward_output` so
        the criterion can read the reconstruction as well.
        """
        assert isinstance(self.strategy.mb_output, D_AE.ForwardOutput), "Unexpected output"
        self.forward_output = self.strategy.mb_output
        self.strategy.mb_output = self.strategy.mb_output.y_hat

    def make_criterion(self):
        """Return the combined reconstruction + weighted classification loss.

        The returned callable ignores its `output`/`target` arguments and reads
        the current mini-batch from `self.strategy` and the stashed
        `self.forward_output` instead.
        """
        def _loss_function(
            output: Tensor, 
            target: Tensor) -> Tensor:
            mb_x, mb_y = self.strategy.mb_x, self.strategy.mb_y
            x_hat = self.forward_output.x_hat
            y_hat = self.forward_output.y_hat

            # Per-sample summed squared error, averaged over the batch.
            recon_loss = F.mse_loss(mb_x, x_hat, reduction="none")
            recon_loss = recon_loss.sum(dim=[1, 2, 3]).mean(dim=[0])

            class_loss = F.cross_entropy(y_hat, mb_y)

            if self.strategy.is_training:
                self.log_scalar("train_recon_loss", recon_loss)
                self.log_scalar("train_class_loss", class_loss)
            else:
                self.log_scalar("test_recon_loss", recon_loss)
                self.log_scalar("test_class_loss", class_loss)

            return recon_loss + class_loss*self.hp.classifier_weight
        return _loss_function

    capacity: float = 1.0
    """How much of the network is still trainable"""
    def after_training_exp(self, strategy, **kwargs):
        """Perform pruning"""
        if not self.hp.prune:
            return
        
        self.capacity *= self.hp.prune_proportion
        print("Performing Prune")
        print(f"     Pruning     {self.hp.prune_proportion}")
        print(f"     New Capcity {self.capacity}")
        self.network.prune(self.hp.prune_proportion)

        # Run additional training epochs after pruning, before the pruned
        # weights are pushed.
        for _ in range(self.hp.post_prune_epochs):
            self.strategy._before_training_epoch(**kwargs)
            self.strategy.training_epoch()
            self.strategy._after_training_epoch(**kwargs)


        print("Push Pruned")
        self.network.push_pruned()


    def before_eval_exp(self, strategy, *args, **kwargs):
        """Use task id to select the right part of each layer for eval"""
        experience: av.benchmarks.Experience = self.strategy.experience
        task_id = experience.task_label

        if self.hp.prune:
            print(f"task_id={task_id}, experience={self.clock.train_exp_counter}")
            # NOTE(review): the guard below was disabled by the author; it is
            # unclear from here what `use_task_subset` does for a task that
            # has not been trained yet - confirm.
            # if task_id >= self.clock.train_exp_counter:
            #     self.network.use_top_subset()
            # else:
            self.network.use_task_subset(task_id)
    
    def after_eval(self, strategy, *args, **kwargs):
        """Reset for new experience"""
        self.network.use_top_subset()

    def make_scenario(self):
        """Build the 5-experience Split Fashion-MNIST scenario with task ids."""
        # NOTE(review): (0.2860, 0.3530) look like the Fashion-MNIST mean/std - confirm.
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.2860,), (0.3530,))]
        )
        scenario = SplitFMNIST(
            n_experiences=5,
            fixed_class_order=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
            dataset_root=DATASETS,
            return_task_id=True,
            train_transform=transform,
            eval_transform=transform)
        return scenario


In [3]:
# Configure and run the PackNet auto-encoder experiment.
experiment = MyExperiment(
    HyperParams(
        lr=0.002,
        train_mb_size=32,
        train_epochs=1,
        eval_mb_size=128,
        eval_every=-1,
        prune=True,
        prune_proportion=0.90,
        post_prune_epochs=1,
        classifier_weight=10.0,
        # Fall back to the CPU so the cell still runs on machines without CUDA.
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
)

experiment.train()
print("DONE!")



Network: <class '__main__.PackNetClassifyingAutoEncoder'>
 > Has the <class 'network.trait.PackNetModule'> trait
 > Has the <class 'network.trait.AutoEncoder'> trait
 > Has the <class 'network.trait.Classifier'> trait
Start of experience: 0
Current Classes:     [0, 1]
Experience size:     12000
Performing Prune
     Pruning     0.9
     New Capcity 0.9
Push Pruned
task_id=0, experience=1
task: 0 loss: tensor([127.8640,  90.1622, 271.2549,  85.1819, 103.0417,  64.3163, 189.0309,
         78.9069, 164.4716, 124.6226,  67.4024, 114.0982,  89.8494,  79.4890,
        103.5007,  88.8573, 117.3710,  97.8966,  61.5265,  63.9777,  74.8878,
        149.6462, 387.7433, 200.3184,  85.8910, 130.1789, 142.9917, 119.6378,
         95.6842,  98.0440, 512.5458, 109.4387, 102.1615, 148.9093,  67.6694,
         97.4781, 396.0354,  79.2655, 188.9079, 115.4695, 297.1868, 149.2720,
         50.5747,  79.2706, 122.6745, 180.0286,  63.3540,  93.5301,  99.6337,
         71.1528,  88.8048, 151.1505, 159.2433, 1

KeyboardInterrupt: 