In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import Module
import numpy as np
import copy
from torch.nn import functional as F
import numpy as np
from typing import Dict
import pytorch_lightning as pl
from utils import summarize_prune
from pytorch_lightning.loggers import TensorBoardLogger
from provided_code.datasource import DataLoaders
from pytorch_lightning.metrics import functional as FM
from torchvision.datasets import MNIST,CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader, BatchSampler
from model.cifar10 import  cnn,mlp
from utils import summarize_prune,copy_model,get_prune_params,prune_fixed_amount
import torch.nn.utils.prune as prune


### DATA PREPARATION

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

dataset1 = CIFAR10('./data', train=True,
                 transform=transform)
dataset2 = CIFAR10('./data', train=False,
                 transform=transform)
train_loader = DataLoader(dataset1, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(dataset2, batch_size=32, shuffle=False, num_workers=4)

### MODEL

In [9]:
class Model(pl.LightningModule):
    def __init__(
        self,
        model = None,
        num_classes=10,
        batch_size=32,
        lr=1e-3,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)

        self.model = model
        self.lr = lr
        self.batch_size = batch_size

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)
        metrics = {
            'loss': loss,
            'acc': acc}
        self.log_dict(metrics)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            params=self.model.parameters(),
            lr=self.lr)
        return optimizer


    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)
        metrics = {
            'loss': loss,
            'acc': acc}
        self.log_dict(metrics)
        return metrics

In [82]:
class Client():
    def __init__(self,args,train_loader,test_loader,idx):
        self.args = args
        self.model = Model(model = cnn.CNN())
        self.test_loader = test_loader
        self.train_loader = train_loader
        self.idx = idx
        self.elapsed_comm_rounds = 0
        self.accuracies = np.zeros((args.comm_rounds,self.args.epochs))
        self.losses = np.zeros((args.comm_rounds,self.args.epochs))
        self.prune_rates = np.zeros(args.comm_rounds)
        self.cur_prune_rate = 0.00
        self.eita = self.args.eita_hat
        self.trainer = pl.Trainer(
            gpus = 1,
            progress_bar_refresh_rate=60,
            max_epochs=self.args.epochs,
            fast_dev_run= args.fast_dev_run
        )
        self.globalModel = copy.deepcopy(self.model)
        Client.Prune(self.globalModel.model,prune_rate = 0.0)
        self.global_init_model = copy.deepcopy(self.model)
        Client.Prune(self.global_init_model.model,prune_rate = 0.0)
  

    def update(self):
        metrics = self.trainer.validate(
            model = self.model,
            val_dataloaders=self.test_loader,
            verbose= True
            )
        
        num_pruned , num_params = summarize_prune(
            self.globalModel,name = 'weight')
        cur_prune_rate = num_pruned / num_params
        if self.cur_prune_rate < self.args.prune_percent:
            self.cur_prune_rate = min(self.cur_prune_rate+self.args.prune_step,
                                          self.args.prune_percent)
            if metrics[0]["acc"] > self.eita:
                Client.Prune(
                    self.globalModel.model,
                    prune_rate = self.cur_prune_rate)
                self.model = copy_model(
                    model = self.global_init_model,
                    dataset = self.args.dataset,
                    arch = self.args.arch,
                    source_buff = dict(self.globalModel.model.named_buffers())
                )
                self.eita = self.args.eita_hat
            else:
                self.eita = self.eita*self.args.alpha
                self.model = self.globalModel
        else:
            Client.Prune(
                model = self.globalModel.model,
                prune_rate = self.args.prune_percent
                )
            self.model = self.globalModel

        self.trainer.fit(
            model= self.model,
            train_dataloader=self.train_loader,
            )
        
        metrics = self.trainer.validate(
            model = self.model,
            val_dataloaders =self.test_loader)
        
        # TODO: LOG MODEL
        print(metrics)

    @staticmethod
    def Prune(model,prune_rate):
        params,_,_ = get_prune_params(model)
        for param,name in params:
            prune.l1_unstructured(
                param,
                name = name,
                amount=prune_rate
                )
    
    def upload(self):
        """
            Upload self.model
        """
        return {
            "model": copy_model(self.model,
                                self.args.dataset,
                                self.args.arch),
            "acc": self.eval_score["Accuracy"]
        }
    
    def download(self, globalModel, global_initModel):
        """
            Download global model from server
        """
        self.globalModel = globalModel
        self.global_init_model = global_initModel        


In [83]:
class Args():
    arch = "cnn"
    dataset = "cifar10"
    epochs = 1
    eita_hat = 0.5
    alpha = 0.5
    prune_percent = 0.8
    prune_step = 0.2
    comm_rounds = 10
    fast_dev_run = False

client = Client(
    Args(),
    idx = 1,
    train_loader = train_loader,
    test_loader = test_loader
    )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [84]:
client.update()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | CNN  | 62.0 K
-------------------------------
62.0 K    Trainable params
0         Non-trainable params
62.0 K    Total params
0.248     Total estimated model params size (MB)
--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'acc': 0.10029999911785126, 'loss': 2.3026797771453857}
--------------------------------------------------------------------------------
Epoch 0:  60%|█████▉    | 420/704 [00:04<00:03, 91.64it/s, loss=2.11, v_num=20]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/313 [00:00<?, ?it/s][A
Epoch 0:  68%|██████▊   | 480/704 [00:05<00:02, 95.79it/s, loss=2.11, v_num=20]
Epoch 0:  77%|███████▋  | 540/704 [00:05<00:01, 100.77it/s, loss=2.11, v_num=20]
Epoch 0:  85%|████████▌ | 600/704 [00:05<00:00, 105.96it/s, loss=2.11, v_num=20]
Epoch 0: 

In [85]:
client.update()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/313 [00:00<?, ?it/s][A
Validating:  19%|█▉        | 60/313 [00:00<00:01, 155.47it/s][A
Validating:  38%|███▊      | 120/313 [00:00<00:01, 181.14it/s][A
Validating:  58%|█████▊    | 180/313 [00:01<00:00, 178.43it/s][A
Validating:  77%|███████▋  | 240/313 [00:01<00:00, 187.17it/s][A
Validating:  96%|█████████▌| 300/313 [00:01<00:00, 189.25it/s][A
                                                              [A--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'acc': 0.36890000104904175, 'loss': 2.0893971920013428}
--------------------------------------------------------------------------------


KeyError: 'conv1.bias'

In [52]:
list(model.named_buffers())

[('model.conv1.weight_mask',
  tensor([[[[1., 1., 1., 0., 0.],
            [1., 0., 1., 1., 1.],
            [1., 0., 1., 0., 0.],
            [0., 0., 0., 0., 1.],
            [1., 0., 0., 0., 0.]],
  
           [[0., 1., 0., 1., 1.],
            [1., 0., 0., 0., 1.],
            [1., 0., 1., 0., 0.],
            [1., 0., 0., 0., 1.],
            [0., 1., 1., 1., 0.]],
  
           [[0., 1., 1., 1., 1.],
            [0., 0., 0., 1., 1.],
            [0., 1., 1., 0., 0.],
            [1., 0., 0., 1., 0.],
            [0., 0., 0., 0., 1.]]],
  
  
          [[[1., 0., 1., 1., 1.],
            [1., 0., 0., 0., 0.],
            [0., 0., 1., 1., 0.],
            [0., 1., 1., 1., 1.],
            [1., 1., 1., 1., 1.]],
  
           [[0., 0., 1., 0., 1.],
            [0., 0., 1., 0., 0.],
            [1., 1., 0., 0., 0.],
            [0., 1., 1., 0., 0.],
            [0., 1., 1., 1., 0.]],
  
           [[0., 0., 1., 1., 0.],
            [0., 1., 0., 1., 0.],
            [0., 1., 1., 0., 