In [None]:
from __future__ import print_function
import numpy as np
import argparse
import torch
import torch.nn.functional as F
from optimizer import PruneAdam
from model import LeNet, AlexNet
from utils import regularized_nll_loss, admm_loss, \
    initialize_Z_and_U, update_X, update_Z, update_Z_l1, update_U, \
    print_convergence, print_prune, apply_prune, apply_l1_prune
from torchvision import datasets, transforms, models
from tqdm import tqdm
from Fed import FedAvg
import copy

# from torch.utils.tensorboard import SummaryWriter
# %load_ext tensorboard

In [2]:
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--dataset', type=str, default="cifar10", choices=["mnist", "cifar10"],
                    metavar='D', help='training dataset (mnist or cifar10)')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--percent', type=list, default=[0.8, 0.92, 0.991, 0.93],
                    metavar='P', help='pruning percentage (default: 0.8)')
parser.add_argument('--alpha', type=float, default=5e-4, metavar='L',
                    help='l2 norm weight (default: 5e-4)')
parser.add_argument('--rho', type=float, default=1e-2, metavar='R',
                    help='cardinality weight (default: 1e-2)')
parser.add_argument('--l1', default=False, action='store_true',
                    help='prune weights with l1 regularization instead of cardinality')
parser.add_argument('--l2', default=False, action='store_true',
                    help='apply l2 regularization')
parser.add_argument('--num_pre_epochs', type=int, default=3, metavar='P',
                    help='number of epochs to pretrain (default: 3)')
parser.add_argument('--num_epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--num_re_epochs', type=int, default=3, metavar='R',
                    help='number of epochs to retrain (default: 3)')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                    help='learning rate (default: 1e-2)')
parser.add_argument('--adam_epsilon', type=float, default=1e-8, metavar='E',
                    help='adam epsilon (default: 1e-8)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--save-model', action='store_true', default=False,
                    help='For Saving the current Model')
parser.add_argument('--num_users', action='store_true', default=1,
                    help='Number of users in network')
parser.add_argument('--model', action='store_true', default='resnet',
                    help='Model to train')
args = parser.parse_args(args=[])

In [3]:
args

Namespace(adam_epsilon=1e-08, alpha=0.0005, batch_size=32, dataset='cifar10', l1=False, l2=False, lr=0.001, model='resnet', no_cuda=False, num_epochs=10, num_pre_epochs=3, num_re_epochs=3, num_users=1, percent=[0.8, 0.92, 0.991, 0.93], rho=0.01, save_model=False, seed=1, test_batch_size=1000)

In [4]:
classDict = {'plane':0, 'car':1, 'bird':2, 'cat':3, 'deer':4, 'dog':5, 'frog':6, 'horse':7, 'ship':8, 'truck':9}

# Define a function to separate CIFAR classes by class index

def get_class_i(x, y, i):
    """
    x: trainset.train_data or testset.test_data
    y: trainset.train_labels or testset.test_labels
    i: class label, a number between 0 to 9
    return: x_i
    """
    # Convert to a numpy array
    y = np.array(y)
    # Locate position of labels that equal to i
    pos_i = np.argwhere(y == i)
    # Convert the result into a 1-D list
    pos_i = list(pos_i[:,0])
    # Collect all data that match the desired label
#     x_i = [x[j] for j in pos_i]
    
    return pos_i

In [5]:
def create_dict(dataset_train):
    
    frac = int(len(dataset_train.data) * 0.05)
    x_reserve = dataset_train.data[:frac]
    y_reserve = dataset_train.targets[:frac]
    x_train = dataset_train.data[frac:]
    y_train = dataset_train.targets[frac:]
    
    reserved = get_class_i(x_reserve, y_reserve, classDict['plane']) \
               +get_class_i(x_reserve, y_reserve, classDict['car']) \
               +get_class_i(x_reserve, y_reserve, classDict['bird']) \
               +get_class_i(x_reserve, y_reserve, classDict['cat']) \
               +get_class_i(x_reserve, y_reserve, classDict['deer']) \
               +get_class_i(x_reserve, y_reserve, classDict['dog']) \
               +get_class_i(x_reserve, y_reserve, classDict['frog']) \
               +get_class_i(x_reserve, y_reserve, classDict['horse']) \
               +get_class_i(x_reserve, y_reserve, classDict['ship']) \
               +get_class_i(x_reserve, y_reserve, classDict['truck'])

    train1 = get_class_i(x_train, y_train, classDict['plane']) \
             +get_class_i(x_train, y_train, classDict['car']) \
             +get_class_i(x_train, y_train, classDict['bird'])

    train2 = get_class_i(x_train, y_train, classDict['cat']) \
             +get_class_i(x_train, y_train, classDict['deer']) \
             +get_class_i(x_train, y_train, classDict['dog'])
    train3 = get_class_i(x_train, y_train, classDict['frog']) \
             +get_class_i(x_train, y_train, classDict['horse']) \
             +get_class_i(x_train, y_train, classDict['ship']) \
             +get_class_i(x_train, y_train, classDict['truck'])
    
    dict_users = {0: set(reserved+train3), 1:set(train1), 2:set(train2)}
    return dict_users

class DatasetSplit(torch.utils.data.Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label

In [6]:
gossip = []

def pretrain(args, model, device, train_loader, test_loader, optimizer):
    for epoch in range(args.num_pre_epochs):
        print('Pre epoch: {}'.format(epoch + 1))
        model.train()
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = regularized_nll_loss(args, model, output, target)
            loss.backward()
            optimizer.step()
        test(args, model, device, test_loader)


def train(args, model, device, train_loader, test_loader, optimizer, Z, U, report=False):
    model.train()
    print('Epoch: {}'.format(epoch + 1))
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = admm_loss(args, device, model, Z, U, output, target)
        loss.backward()
        optimizer.step()
    X = update_X(model)
    Z = update_Z_l1(X, U, args) if args.l1 else update_Z(X, U, args)
    U = update_U(U, X, Z)
    print_convergence(model, X, Z)
    test(args, model, device, test_loader, report)


iter = 0
def test(args, model, device, test_loader, report=False):
    model.eval()
    test_loss = 0
    correct = 0
    global iter
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    gossip.append(correct / len(test_loader.dataset))
    
    if report:
#         writer.add_scalar('train/loss_gossip_admm4', test_loss, iter)
#         writer.add_scalar('valid/accuracy_gossip_admm4', correct / len(test_loader.dataset), iter)
        iter+=1


def retrain(args, model, mask, device, train_loader, test_loader, optimizer):
    for epoch in range(args.num_re_epochs):
        print('Re epoch: {}'.format(epoch + 1))
        model.train()
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.prune_step(mask)

        test(args, model, device, test_loader)

In [7]:
### MAIN

# writer = SummaryWriter('../../runs/') 
what_gpu = 1
torch.manual_seed(args.seed)
device = torch.device('cuda:{}'.format(what_gpu) if torch.cuda.is_available() else 'cpu')
kwargs = {'num_workers': 5, 'pin_memory': True}

args.percent = [0.8, 0.92, 0.93, 0.94, 0.95, 0.99, 0.99, 0.93]
args.num_pre_epochs = 3
args.num_epochs = 60
args.num_re_epochs = 5
args.num_users = 3
args.dataset = 'cifar10'
args.model = 'resnet'
args.l1 = True
args.l2 = False

trainset = datasets.CIFAR10('../../../data/cifar10', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                                                  (0.24703233, 0.24348505, 0.26158768))
                         ]))
dict_users = create_dict(trainset)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../../../data/cifar10', train=False, download=True,
                     transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
                                              (0.24703233, 0.24348505, 0.26158768))
                     ])), shuffle=True, batch_size=args.test_batch_size, **kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
mdlz = dict()
for usr in range(args.num_users):
    train_loader = torch.utils.data.DataLoader(DatasetSplit(trainset, dict_users[usr]), batch_size=args.batch_size, shuffle=True, **kwargs)
    if args.model == 'resnet':
        model = models.resnet50(pretrained=True).to(device)
    else:
        model = LeNet().to(device) if args.dataset == "mnist" else AlexNet().to(device)
    
    optimizer = PruneAdam(model.named_parameters(), lr=args.lr, eps=args.adam_epsilon)
    pretrain(args, model, device, train_loader, test_loader, optimizer)
    Z, U = initialize_Z_and_U(model)
    mdlz[usr] = (model, optimizer, Z, U)
    
for epoch in range(args.num_epochs):
    w_locals=[]
    idxs_users = np.random.choice(range(args.num_users), 2, replace=False)
    for usr in idxs_users:
        report = True if usr == 0 else False
        train_loader = torch.utils.data.DataLoader(DatasetSplit(trainset, dict_users[usr]), batch_size=args.batch_size, shuffle=True, **kwargs)
        train(args, mdlz[usr][0], device, train_loader, test_loader, mdlz[usr][1], mdlz[usr][2], mdlz[usr][3], report=report)
        w = mdlz[usr][0].state_dict()
        w_locals.append(copy.deepcopy(w))

    # update global weights
    w_glob = FedAvg(w_locals)

    for idx in idxs_users:
        # copy weight to net_glob
        mdlz[idx][0].load_state_dict(w_glob)

  0%|          | 0/640 [00:00<?, ?it/s]

Pre epoch: 1


100%|██████████| 640/640 [00:25<00:00, 25.47it/s]
  0%|          | 0/640 [00:00<?, ?it/s]


Test set: Average loss: -7351.0350, Accuracy: 1000/10000 (10%)

Pre epoch: 2


100%|██████████| 640/640 [00:27<00:00, 23.12it/s]
  0%|          | 0/640 [00:00<?, ?it/s]


Test set: Average loss: -32210.7048, Accuracy: 1025/10000 (10%)

Pre epoch: 3


100%|██████████| 640/640 [00:27<00:00, 23.29it/s]



Test set: Average loss: -57130.8472, Accuracy: 1059/10000 (11%)



  0%|          | 0/446 [00:00<?, ?it/s]

Pre epoch: 1


100%|██████████| 446/446 [00:20<00:00, 22.22it/s]
  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -4119.2290, Accuracy: 1386/10000 (14%)

Pre epoch: 2


100%|██████████| 446/446 [00:20<00:00, 21.85it/s]
  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -15969.0273, Accuracy: 957/10000 (10%)

Pre epoch: 3


100%|██████████| 446/446 [00:20<00:00, 21.77it/s]



Test set: Average loss: -35892.0668, Accuracy: 969/10000 (10%)



  0%|          | 0/446 [00:00<?, ?it/s]

Pre epoch: 1


100%|██████████| 446/446 [00:18<00:00, 24.27it/s]
  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -4274.3794, Accuracy: 1126/10000 (11%)

Pre epoch: 2


100%|██████████| 446/446 [00:18<00:00, 24.01it/s]
  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -14728.7476, Accuracy: 1049/10000 (10%)

Pre epoch: 3


100%|██████████| 446/446 [00:19<00:00, 22.76it/s]
  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -34465.3804, Accuracy: 1014/10000 (10%)

Epoch: 1


100%|██████████| 446/446 [00:49<00:00,  8.94it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6419
(bn1.weight): 0.0500
(layer1.0.conv1.weight): 0.2464
(layer1.0.bn1.weight): 0.0500
(layer1.0.conv2.weight): 0.5138
(layer1.0.bn2.weight): 0.0500
(layer1.0.conv3.weight): 0.4134
(layer1.0.bn3.weight): 0.0502
(layer1.0.downsample.0.weight): 0.4191
(layer1.0.downsample.1.weight): 0.0493
(layer1.1.conv1.weight): 0.2475
(layer1.1.bn1.weight): 0.0500
(layer1.1.conv2.weight): 0.5249
(layer1.1.bn2.weight): 0.0495
(layer1.1.conv3.weight): 0.4200
(layer1.1.bn3.weight): 0.0502
(layer1.2.conv1.weight): 0.2503
(layer1.2.bn1.weight): 0.0496
(layer1.2.conv2.weight): 0.5231
(layer1.2.bn2.weight): 0.0501
(layer1.2.conv3.weight): 0.4181
(layer1.2.bn3.weight): 0.0505
(layer2.0.conv1.weight): 0.3274
(layer2.0.bn1.weight): 0.0495
(layer2.0.conv2.weight): 0.5889
(layer2.0.bn2.weight): 0.0498
(layer2.0.conv3.weight): 0.5005
(layer2.0.bn3.weight): 0.0506
(layer2.0.downsample.0.weight): 0.5051
(layer2.0.downsample.1.weight): 0.0507
(layer2.1.conv1

  0%|          | 0/640 [00:00<?, ?it/s]


Test set: Average loss: -65357.7284, Accuracy: 1495/10000 (15%)

Epoch: 1


100%|██████████| 640/640 [01:04<00:00,  9.95it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.5729
(bn1.weight): 0.0495
(layer1.0.conv1.weight): 0.2469
(layer1.0.bn1.weight): 0.0502
(layer1.0.conv2.weight): 0.4813
(layer1.0.bn2.weight): 0.0505
(layer1.0.conv3.weight): 0.3966
(layer1.0.bn3.weight): 0.0495
(layer1.0.downsample.0.weight): 0.3997
(layer1.0.downsample.1.weight): 0.0510
(layer1.1.conv1.weight): 0.2455
(layer1.1.bn1.weight): 0.0501
(layer1.1.conv2.weight): 0.4903
(layer1.1.bn2.weight): 0.0503
(layer1.1.conv3.weight): 0.3962
(layer1.1.bn3.weight): 0.0493
(layer1.2.conv1.weight): 0.2459
(layer1.2.bn1.weight): 0.0502
(layer1.2.conv2.weight): 0.4836
(layer1.2.bn2.weight): 0.0501
(layer1.2.conv3.weight): 0.3935
(layer1.2.bn3.weight): 0.0501
(layer2.0.conv1.weight): 0.3175
(layer2.0.bn1.weight): 0.0503
(layer2.0.conv2.weight): 0.5336
(layer2.0.bn2.weight): 0.0500
(layer2.0.conv3.weight): 0.4587
(layer2.0.bn3.weight): 0.0506
(layer2.0.downsample.0.weight): 0.4667
(layer2.0.downsample.1.weight): 0.0506
(layer2.1.conv1

  0%|          | 0/640 [00:00<?, ?it/s]


Test set: Average loss: -121219.0176, Accuracy: 905/10000 (9%)

Epoch: 2


100%|██████████| 640/640 [01:11<00:00,  9.01it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6508
(bn1.weight): 0.0498
(layer1.0.conv1.weight): 0.3252
(layer1.0.bn1.weight): 0.0500
(layer1.0.conv2.weight): 0.5555
(layer1.0.bn2.weight): 0.0501
(layer1.0.conv3.weight): 0.4792
(layer1.0.bn3.weight): 0.0504
(layer1.0.downsample.0.weight): 0.4826
(layer1.0.downsample.1.weight): 0.0492
(layer1.1.conv1.weight): 0.3213
(layer1.1.bn1.weight): 0.0501
(layer1.1.conv2.weight): 0.5779
(layer1.1.bn2.weight): 0.0498
(layer1.1.conv3.weight): 0.4928
(layer1.1.bn3.weight): 0.0499
(layer1.2.conv1.weight): 0.3258
(layer1.2.bn1.weight): 0.0500
(layer1.2.conv2.weight): 0.5763
(layer1.2.bn2.weight): 0.0500
(layer1.2.conv3.weight): 0.4882
(layer1.2.bn3.weight): 0.0509
(layer2.0.conv1.weight): 0.4092
(layer2.0.bn1.weight): 0.0501
(layer2.0.conv2.weight): 0.6311
(layer2.0.bn2.weight): 0.0500
(layer2.0.conv3.weight): 0.5590
(layer2.0.bn3.weight): 0.0508
(layer2.0.downsample.0.weight): 0.5549
(layer2.0.downsample.1.weight): 0.0507
(layer2.1.conv1

  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -129910.6264, Accuracy: 1000/10000 (10%)

Epoch: 2


100%|██████████| 446/446 [00:48<00:00,  9.25it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6365
(bn1.weight): 0.0499
(layer1.0.conv1.weight): 0.3180
(layer1.0.bn1.weight): 0.0504
(layer1.0.conv2.weight): 0.5461
(layer1.0.bn2.weight): 0.0504
(layer1.0.conv3.weight): 0.4725
(layer1.0.bn3.weight): 0.0495
(layer1.0.downsample.0.weight): 0.4684
(layer1.0.downsample.1.weight): 0.0509
(layer1.1.conv1.weight): 0.3169
(layer1.1.bn1.weight): 0.0501
(layer1.1.conv2.weight): 0.5584
(layer1.1.bn2.weight): 0.0502
(layer1.1.conv3.weight): 0.4719
(layer1.1.bn3.weight): 0.0492
(layer1.2.conv1.weight): 0.3209
(layer1.2.bn1.weight): 0.0501
(layer1.2.conv2.weight): 0.5469
(layer1.2.bn2.weight): 0.0504
(layer1.2.conv3.weight): 0.4692
(layer1.2.bn3.weight): 0.0496
(layer2.0.conv1.weight): 0.3931
(layer2.0.bn1.weight): 0.0501
(layer2.0.conv2.weight): 0.5829
(layer2.0.bn2.weight): 0.0504
(layer2.0.conv3.weight): 0.5338
(layer2.0.bn3.weight): 0.0502
(layer2.0.downsample.0.weight): 0.5360
(layer2.0.downsample.1.weight): 0.0508
(layer2.1.conv1

  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -124599.7408, Accuracy: 1140/10000 (11%)

Epoch: 3


100%|██████████| 446/446 [00:44<00:00, 10.12it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6324
(bn1.weight): 0.0497
(layer1.0.conv1.weight): 0.3228
(layer1.0.bn1.weight): 0.0503
(layer1.0.conv2.weight): 0.5472
(layer1.0.bn2.weight): 0.0505
(layer1.0.conv3.weight): 0.4741
(layer1.0.bn3.weight): 0.0503
(layer1.0.downsample.0.weight): 0.4721
(layer1.0.downsample.1.weight): 0.0502
(layer1.1.conv1.weight): 0.3180
(layer1.1.bn1.weight): 0.0499
(layer1.1.conv2.weight): 0.5541
(layer1.1.bn2.weight): 0.0499
(layer1.1.conv3.weight): 0.4774
(layer1.1.bn3.weight): 0.0494
(layer1.2.conv1.weight): 0.3193
(layer1.2.bn1.weight): 0.0499
(layer1.2.conv2.weight): 0.5506
(layer1.2.bn2.weight): 0.0502
(layer1.2.conv3.weight): 0.4721
(layer1.2.bn3.weight): 0.0492
(layer2.0.conv1.weight): 0.3974
(layer2.0.bn1.weight): 0.0502
(layer2.0.conv2.weight): 0.5838
(layer2.0.bn2.weight): 0.0502
(layer2.0.conv3.weight): 0.5369
(layer2.0.bn3.weight): 0.0506
(layer2.0.downsample.0.weight): 0.5380
(layer2.0.downsample.1.weight): 0.0510
(layer2.1.conv1

  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -141684.7232, Accuracy: 1000/10000 (10%)

Epoch: 3


100%|██████████| 446/446 [00:50<00:00,  8.83it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6525
(bn1.weight): 0.0501
(layer1.0.conv1.weight): 0.2506
(layer1.0.bn1.weight): 0.0505
(layer1.0.conv2.weight): 0.5290
(layer1.0.bn2.weight): 0.0503
(layer1.0.conv3.weight): 0.4271
(layer1.0.bn3.weight): 0.0496
(layer1.0.downsample.0.weight): 0.4248
(layer1.0.downsample.1.weight): 0.0492
(layer1.1.conv1.weight): 0.2533
(layer1.1.bn1.weight): 0.0503
(layer1.1.conv2.weight): 0.5374
(layer1.1.bn2.weight): 0.0501
(layer1.1.conv3.weight): 0.4216
(layer1.1.bn3.weight): 0.0504
(layer1.2.conv1.weight): 0.2500
(layer1.2.bn1.weight): 0.0503
(layer1.2.conv2.weight): 0.5336
(layer1.2.bn2.weight): 0.0498
(layer1.2.conv3.weight): 0.4189
(layer1.2.bn3.weight): 0.0511
(layer2.0.conv1.weight): 0.3324
(layer2.0.bn1.weight): 0.0500
(layer2.0.conv2.weight): 0.6125
(layer2.0.bn2.weight): 0.0498
(layer2.0.conv3.weight): 0.5115
(layer2.0.bn3.weight): 0.0494
(layer2.0.downsample.0.weight): 0.5153
(layer2.0.downsample.1.weight): 0.0497
(layer2.1.conv1

  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -65994.1268, Accuracy: 949/10000 (9%)

Epoch: 4


100%|██████████| 446/446 [00:49<00:00,  9.08it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6572
(bn1.weight): 0.0499
(layer1.0.conv1.weight): 0.3563
(layer1.0.bn1.weight): 0.0504
(layer1.0.conv2.weight): 0.5882
(layer1.0.bn2.weight): 0.0506
(layer1.0.conv3.weight): 0.5175
(layer1.0.bn3.weight): 0.0497
(layer1.0.downsample.0.weight): 0.5087
(layer1.0.downsample.1.weight): 0.0502
(layer1.1.conv1.weight): 0.3545
(layer1.1.bn1.weight): 0.0494
(layer1.1.conv2.weight): 0.5789
(layer1.1.bn2.weight): 0.0501
(layer1.1.conv3.weight): 0.5157
(layer1.1.bn3.weight): 0.0489
(layer1.2.conv1.weight): 0.3569
(layer1.2.bn1.weight): 0.0502
(layer1.2.conv2.weight): 0.6006
(layer1.2.bn2.weight): 0.0499
(layer1.2.conv3.weight): 0.5237
(layer1.2.bn3.weight): 0.0506
(layer2.0.conv1.weight): 0.4480
(layer2.0.bn1.weight): 0.0502
(layer2.0.conv2.weight): 0.6452
(layer2.0.bn2.weight): 0.0499
(layer2.0.conv3.weight): 0.5869
(layer2.0.bn3.weight): 0.0498
(layer2.0.downsample.0.weight): 0.5817
(layer2.0.downsample.1.weight): 0.0498
(layer2.1.conv1

  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -160306.2144, Accuracy: 1155/10000 (12%)

Epoch: 4


100%|██████████| 446/446 [00:49<00:00,  9.07it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.7272
(bn1.weight): 0.0500
(layer1.0.conv1.weight): 0.3748
(layer1.0.bn1.weight): 0.0505
(layer1.0.conv2.weight): 0.6453
(layer1.0.bn2.weight): 0.0504
(layer1.0.conv3.weight): 0.5622
(layer1.0.bn3.weight): 0.0502
(layer1.0.downsample.0.weight): 0.5678
(layer1.0.downsample.1.weight): 0.0497
(layer1.1.conv1.weight): 0.3763
(layer1.1.bn1.weight): 0.0501
(layer1.1.conv2.weight): 0.6634
(layer1.1.bn2.weight): 0.0500
(layer1.1.conv3.weight): 0.5615
(layer1.1.bn3.weight): 0.0501
(layer1.2.conv1.weight): 0.3699
(layer1.2.bn1.weight): 0.0502
(layer1.2.conv2.weight): 0.6586
(layer1.2.bn2.weight): 0.0501
(layer1.2.conv3.weight): 0.5595
(layer1.2.bn3.weight): 0.0498
(layer2.0.conv1.weight): 0.4711
(layer2.0.bn1.weight): 0.0502
(layer2.0.conv2.weight): 0.7114
(layer2.0.bn2.weight): 0.0500
(layer2.0.conv3.weight): 0.6404
(layer2.0.bn3.weight): 0.0497
(layer2.0.downsample.0.weight): 0.6360
(layer2.0.downsample.1.weight): 0.0506
(layer2.1.conv1

  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -139208.9072, Accuracy: 938/10000 (9%)

Epoch: 5


100%|██████████| 446/446 [00:48<00:00,  9.16it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6689
(bn1.weight): 0.0498
(layer1.0.conv1.weight): 0.3630
(layer1.0.bn1.weight): 0.0506
(layer1.0.conv2.weight): 0.6009
(layer1.0.bn2.weight): 0.0506
(layer1.0.conv3.weight): 0.5292
(layer1.0.bn3.weight): 0.0495
(layer1.0.downsample.0.weight): 0.5289
(layer1.0.downsample.1.weight): 0.0505
(layer1.1.conv1.weight): 0.3625
(layer1.1.bn1.weight): 0.0497
(layer1.1.conv2.weight): 0.6061
(layer1.1.bn2.weight): 0.0501
(layer1.1.conv3.weight): 0.5295
(layer1.1.bn3.weight): 0.0497
(layer1.2.conv1.weight): 0.3598
(layer1.2.bn1.weight): 0.0501
(layer1.2.conv2.weight): 0.6077
(layer1.2.bn2.weight): 0.0501
(layer1.2.conv3.weight): 0.5308
(layer1.2.bn3.weight): 0.0496
(layer2.0.conv1.weight): 0.4543
(layer2.0.bn1.weight): 0.0502
(layer2.0.conv2.weight): 0.6571
(layer2.0.bn2.weight): 0.0500
(layer2.0.conv3.weight): 0.5970
(layer2.0.bn3.weight): 0.0499
(layer2.0.downsample.0.weight): 0.5920
(layer2.0.downsample.1.weight): 0.0510
(layer2.1.conv1

  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -197544.7808, Accuracy: 948/10000 (9%)

Epoch: 5


100%|██████████| 446/446 [00:49<00:00,  8.95it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6701
(bn1.weight): 0.0498
(layer1.0.conv1.weight): 0.3654
(layer1.0.bn1.weight): 0.0504
(layer1.0.conv2.weight): 0.6137
(layer1.0.bn2.weight): 0.0503
(layer1.0.conv3.weight): 0.5386
(layer1.0.bn3.weight): 0.0501
(layer1.0.downsample.0.weight): 0.5349
(layer1.0.downsample.1.weight): 0.0499
(layer1.1.conv1.weight): 0.3633
(layer1.1.bn1.weight): 0.0499
(layer1.1.conv2.weight): 0.6106
(layer1.1.bn2.weight): 0.0501
(layer1.1.conv3.weight): 0.5290
(layer1.1.bn3.weight): 0.0499
(layer1.2.conv1.weight): 0.3585
(layer1.2.bn1.weight): 0.0502
(layer1.2.conv2.weight): 0.5968
(layer1.2.bn2.weight): 0.0502
(layer1.2.conv3.weight): 0.5227
(layer1.2.bn3.weight): 0.0498
(layer2.0.conv1.weight): 0.4537
(layer2.0.bn1.weight): 0.0503
(layer2.0.conv2.weight): 0.6575
(layer2.0.bn2.weight): 0.0499
(layer2.0.conv3.weight): 0.5979
(layer2.0.bn3.weight): 0.0495
(layer2.0.downsample.0.weight): 0.5939
(layer2.0.downsample.1.weight): 0.0501
(layer2.1.conv1

  0%|          | 0/640 [00:00<?, ?it/s]


Test set: Average loss: -197802.5760, Accuracy: 1191/10000 (12%)

Epoch: 6


100%|██████████| 640/640 [01:11<00:00,  8.97it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6255
(bn1.weight): 0.0499
(layer1.0.conv1.weight): 0.3187
(layer1.0.bn1.weight): 0.0503
(layer1.0.conv2.weight): 0.5262
(layer1.0.bn2.weight): 0.0503
(layer1.0.conv3.weight): 0.4590
(layer1.0.bn3.weight): 0.0501
(layer1.0.downsample.0.weight): 0.4659
(layer1.0.downsample.1.weight): 0.0507
(layer1.1.conv1.weight): 0.3109
(layer1.1.bn1.weight): 0.0501
(layer1.1.conv2.weight): 0.5340
(layer1.1.bn2.weight): 0.0502
(layer1.1.conv3.weight): 0.4671
(layer1.1.bn3.weight): 0.0485
(layer1.2.conv1.weight): 0.3181
(layer1.2.bn1.weight): 0.0501
(layer1.2.conv2.weight): 0.5383
(layer1.2.bn2.weight): 0.0503
(layer1.2.conv3.weight): 0.4643
(layer1.2.bn3.weight): 0.0492
(layer2.0.conv1.weight): 0.3900
(layer2.0.bn1.weight): 0.0501
(layer2.0.conv2.weight): 0.5912
(layer2.0.bn2.weight): 0.0505
(layer2.0.conv3.weight): 0.5379
(layer2.0.bn3.weight): 0.0500
(layer2.0.downsample.0.weight): 0.5335
(layer2.0.downsample.1.weight): 0.0518
(layer2.1.conv1

  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -156801.3296, Accuracy: 981/10000 (10%)

Epoch: 6


100%|██████████| 446/446 [00:50<00:00,  8.78it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6603
(bn1.weight): 0.0499
(layer1.0.conv1.weight): 0.3650
(layer1.0.bn1.weight): 0.0504
(layer1.0.conv2.weight): 0.6033
(layer1.0.bn2.weight): 0.0504
(layer1.0.conv3.weight): 0.5307
(layer1.0.bn3.weight): 0.0499
(layer1.0.downsample.0.weight): 0.5299
(layer1.0.downsample.1.weight): 0.0499
(layer1.1.conv1.weight): 0.3616
(layer1.1.bn1.weight): 0.0500
(layer1.1.conv2.weight): 0.6070
(layer1.1.bn2.weight): 0.0501
(layer1.1.conv3.weight): 0.5282
(layer1.1.bn3.weight): 0.0501
(layer1.2.conv1.weight): 0.3591
(layer1.2.bn1.weight): 0.0502
(layer1.2.conv2.weight): 0.6084
(layer1.2.bn2.weight): 0.0501
(layer1.2.conv3.weight): 0.5314
(layer1.2.bn3.weight): 0.0498
(layer2.0.conv1.weight): 0.4560
(layer2.0.bn1.weight): 0.0503
(layer2.0.conv2.weight): 0.6646
(layer2.0.bn2.weight): 0.0499
(layer2.0.conv3.weight): 0.6066
(layer2.0.bn3.weight): 0.0496
(layer2.0.downsample.0.weight): 0.6022
(layer2.0.downsample.1.weight): 0.0504
(layer2.1.conv1

  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -240530.1264, Accuracy: 1035/10000 (10%)

Epoch: 7


100%|██████████| 446/446 [00:50<00:00,  8.89it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6791
(bn1.weight): 0.0499
(layer1.0.conv1.weight): 0.3734
(layer1.0.bn1.weight): 0.0505
(layer1.0.conv2.weight): 0.5991
(layer1.0.bn2.weight): 0.0505
(layer1.0.conv3.weight): 0.5275
(layer1.0.bn3.weight): 0.0499
(layer1.0.downsample.0.weight): 0.5367
(layer1.0.downsample.1.weight): 0.0503
(layer1.1.conv1.weight): 0.3665
(layer1.1.bn1.weight): 0.0502
(layer1.1.conv2.weight): 0.6093
(layer1.1.bn2.weight): 0.0504
(layer1.1.conv3.weight): 0.5315
(layer1.1.bn3.weight): 0.0493
(layer1.2.conv1.weight): 0.3714
(layer1.2.bn1.weight): 0.0501
(layer1.2.conv2.weight): 0.6127
(layer1.2.bn2.weight): 0.0502
(layer1.2.conv3.weight): 0.5364
(layer1.2.bn3.weight): 0.0500
(layer2.0.conv1.weight): 0.4557
(layer2.0.bn1.weight): 0.0500
(layer2.0.conv2.weight): 0.6488
(layer2.0.bn2.weight): 0.0502
(layer2.0.conv3.weight): 0.6016
(layer2.0.bn3.weight): 0.0502
(layer2.0.downsample.0.weight): 0.6001
(layer2.0.downsample.1.weight): 0.0506
(layer2.1.conv1

  0%|          | 0/640 [00:00<?, ?it/s]


Test set: Average loss: -253381.2032, Accuracy: 906/10000 (9%)

Epoch: 7


100%|██████████| 640/640 [01:12<00:00,  8.85it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6445
(bn1.weight): 0.0499
(layer1.0.conv1.weight): 0.3652
(layer1.0.bn1.weight): 0.0507
(layer1.0.conv2.weight): 0.5665
(layer1.0.bn2.weight): 0.0506
(layer1.0.conv3.weight): 0.5090
(layer1.0.bn3.weight): 0.0499
(layer1.0.downsample.0.weight): 0.5133
(layer1.0.downsample.1.weight): 0.0511
(layer1.1.conv1.weight): 0.3554
(layer1.1.bn1.weight): 0.0498
(layer1.1.conv2.weight): 0.5736
(layer1.1.bn2.weight): 0.0500
(layer1.1.conv3.weight): 0.5063
(layer1.1.bn3.weight): 0.0484
(layer1.2.conv1.weight): 0.3614
(layer1.2.bn1.weight): 0.0495
(layer1.2.conv2.weight): 0.5715
(layer1.2.bn2.weight): 0.0500
(layer1.2.conv3.weight): 0.5106
(layer1.2.bn3.weight): 0.0493
(layer2.0.conv1.weight): 0.4403
(layer2.0.bn1.weight): 0.0501
(layer2.0.conv2.weight): 0.6040
(layer2.0.bn2.weight): 0.0500
(layer2.0.conv3.weight): 0.5638
(layer2.0.bn3.weight): 0.0497
(layer2.0.downsample.0.weight): 0.5703
(layer2.0.downsample.1.weight): 0.0509
(layer2.1.conv1

  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -290795.0752, Accuracy: 950/10000 (10%)

Epoch: 8


100%|██████████| 446/446 [00:50<00:00,  8.80it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6429
(bn1.weight): 0.0499
(layer1.0.conv1.weight): 0.3605
(layer1.0.bn1.weight): 0.0505
(layer1.0.conv2.weight): 0.5833
(layer1.0.bn2.weight): 0.0506
(layer1.0.conv3.weight): 0.5172
(layer1.0.bn3.weight): 0.0497
(layer1.0.downsample.0.weight): 0.5171
(layer1.0.downsample.1.weight): 0.0500
(layer1.1.conv1.weight): 0.3582
(layer1.1.bn1.weight): 0.0496
(layer1.1.conv2.weight): 0.5841
(layer1.1.bn2.weight): 0.0501
(layer1.1.conv3.weight): 0.5056
(layer1.1.bn3.weight): 0.0495
(layer1.2.conv1.weight): 0.3465
(layer1.2.bn1.weight): 0.0501
(layer1.2.conv2.weight): 0.5543
(layer1.2.bn2.weight): 0.0501
(layer1.2.conv3.weight): 0.4999
(layer1.2.bn3.weight): 0.0502
(layer2.0.conv1.weight): 0.4346
(layer2.0.bn1.weight): 0.0502
(layer2.0.conv2.weight): 0.5996
(layer2.0.bn2.weight): 0.0500
(layer2.0.conv3.weight): 0.5555
(layer2.0.bn3.weight): 0.0495
(layer2.0.downsample.0.weight): 0.5585
(layer2.0.downsample.1.weight): 0.0505
(layer2.1.conv1

  0%|          | 0/640 [00:00<?, ?it/s]


Test set: Average loss: -227269.7744, Accuracy: 885/10000 (9%)

Epoch: 8


100%|██████████| 640/640 [01:10<00:00,  9.06it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6583
(bn1.weight): 0.0499
(layer1.0.conv1.weight): 0.3663
(layer1.0.bn1.weight): 0.0507
(layer1.0.conv2.weight): 0.5765
(layer1.0.bn2.weight): 0.0504
(layer1.0.conv3.weight): 0.5111
(layer1.0.bn3.weight): 0.0497
(layer1.0.downsample.0.weight): 0.5215
(layer1.0.downsample.1.weight): 0.0506
(layer1.1.conv1.weight): 0.3592
(layer1.1.bn1.weight): 0.0498
(layer1.1.conv2.weight): 0.5805
(layer1.1.bn2.weight): 0.0504
(layer1.1.conv3.weight): 0.5123
(layer1.1.bn3.weight): 0.0489
(layer1.2.conv1.weight): 0.3643
(layer1.2.bn1.weight): 0.0498
(layer1.2.conv2.weight): 0.5833
(layer1.2.bn2.weight): 0.0502
(layer1.2.conv3.weight): 0.5221
(layer1.2.bn3.weight): 0.0497
(layer2.0.conv1.weight): 0.4462
(layer2.0.bn1.weight): 0.0500
(layer2.0.conv2.weight): 0.6259
(layer2.0.bn2.weight): 0.0500
(layer2.0.conv3.weight): 0.5733
(layer2.0.bn3.weight): 0.0505
(layer2.0.downsample.0.weight): 0.5673
(layer2.0.downsample.1.weight): 0.0505
(layer2.1.conv1

  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -345958.8672, Accuracy: 1072/10000 (11%)

Epoch: 9


100%|██████████| 446/446 [00:51<00:00,  8.67it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6372
(bn1.weight): 0.0499
(layer1.0.conv1.weight): 0.3663
(layer1.0.bn1.weight): 0.0506
(layer1.0.conv2.weight): 0.5676
(layer1.0.bn2.weight): 0.0504
(layer1.0.conv3.weight): 0.5054
(layer1.0.bn3.weight): 0.0502
(layer1.0.downsample.0.weight): 0.5083
(layer1.0.downsample.1.weight): 0.0505
(layer1.1.conv1.weight): 0.3563
(layer1.1.bn1.weight): 0.0501
(layer1.1.conv2.weight): 0.5709
(layer1.1.bn2.weight): 0.0503
(layer1.1.conv3.weight): 0.5073
(layer1.1.bn3.weight): 0.0486
(layer1.2.conv1.weight): 0.3597
(layer1.2.bn1.weight): 0.0496
(layer1.2.conv2.weight): 0.5672
(layer1.2.bn2.weight): 0.0502
(layer1.2.conv3.weight): 0.5045
(layer1.2.bn3.weight): 0.0498
(layer2.0.conv1.weight): 0.4396
(layer2.0.bn1.weight): 0.0502
(layer2.0.conv2.weight): 0.6047
(layer2.0.bn2.weight): 0.0502
(layer2.0.conv3.weight): 0.5642
(layer2.0.bn3.weight): 0.0498
(layer2.0.downsample.0.weight): 0.5610
(layer2.0.downsample.1.weight): 0.0507
(layer2.1.conv1

  0%|          | 0/640 [00:00<?, ?it/s]


Test set: Average loss: -326806.5696, Accuracy: 1064/10000 (11%)

Epoch: 9


100%|██████████| 640/640 [01:13<00:00,  8.74it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.6719
(bn1.weight): 0.0499
(layer1.0.conv1.weight): 0.3753
(layer1.0.bn1.weight): 0.0505
(layer1.0.conv2.weight): 0.6060
(layer1.0.bn2.weight): 0.0504
(layer1.0.conv3.weight): 0.5325
(layer1.0.bn3.weight): 0.0499
(layer1.0.downsample.0.weight): 0.5338
(layer1.0.downsample.1.weight): 0.0502
(layer1.1.conv1.weight): 0.3705
(layer1.1.bn1.weight): 0.0498
(layer1.1.conv2.weight): 0.5956
(layer1.1.bn2.weight): 0.0503
(layer1.1.conv3.weight): 0.5209
(layer1.1.bn3.weight): 0.0488
(layer1.2.conv1.weight): 0.3659
(layer1.2.bn1.weight): 0.0499
(layer1.2.conv2.weight): 0.5825
(layer1.2.bn2.weight): 0.0501
(layer1.2.conv3.weight): 0.5285
(layer1.2.bn3.weight): 0.0501
(layer2.0.conv1.weight): 0.4527
(layer2.0.bn1.weight): 0.0500
(layer2.0.conv2.weight): 0.6250
(layer2.0.bn2.weight): 0.0502
(layer2.0.conv3.weight): 0.5802
(layer2.0.bn3.weight): 0.0499
(layer2.0.downsample.0.weight): 0.5778
(layer2.0.downsample.1.weight): 0.0507
(layer2.1.conv1

  0%|          | 0/446 [00:00<?, ?it/s]


Test set: Average loss: -338518.6464, Accuracy: 1012/10000 (10%)

Epoch: 10


 94%|█████████▍| 421/446 [00:46<00:02,  9.09it/s]


KeyboardInterrupt: 

In [None]:
### Apply pruning
for usr in range(args.num_users):
    mask = apply_l1_prune(mdlz[usr][0], device, args) if args.l1 else apply_prune(mdlz[usr][0], device, args)
    print_prune(model)
    test(args, mdlz[usr][0], device, test_loader)
    retrain(args, model, mask, device, train_loader, test_loader, optimizer)

# Testing ground

In [None]:
from torch.utils.data import DataLoader
import torch.nn.functional as F

confusion_matrix = torch.zeros(10, 10)
mdlz[0][0].eval()
# testing
test_loss = 0
correct = 0

l = len(test_loader)
for idx, (data, target) in enumerate(test_loader):
    data, target = data.to(device), target.to(device)
    log_probs = mdlz[0][0](data)
    # sum up batch loss
    test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
    # get the index of the max log-probability
    y_pred = log_probs.data.max(1, keepdim=True)[1]
    correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    for t, p in zip(target.data, y_pred.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1

In [None]:
print(confusion_matrix)

In [None]:
gossip