In [None]:
from __future__ import print_function
import numpy as np
import argparse
import torch
import torch.nn.functional as F
from optimizer import PruneAdam
from model import LeNet, AlexNet
from utils import regularized_nll_loss, admm_loss, \
    initialize_Z_and_U, update_X, update_Z, update_Z_l1, update_U, \
    print_convergence, print_prune, apply_prune, apply_l1_prune
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms, datasets, models
from tqdm import tqdm
from Fed import FedAvg
import copy
import os
from PIL import Image

# from torch.utils.tensorboard import SummaryWriter
# %load_ext tensorboard

In [None]:
class CustomDataSet(Dataset):
    def __init__(self, main_dir, x, y, transform):
        self.main_dir = main_dir
        self.transform = transform
        self.y = y
        self.all_imgs = x
#         self.total_imgs = natsort.natsorted(all_imgs)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.all_imgs[idx])
        image = Image.open(img_loc).convert("RGB")
        tensor_image = self.transform(image)
        return tensor_image,self.y[idx]

In [None]:
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--dataset', type=str, default="cifar10", choices=["mnist", "cifar10"],
                    metavar='D', help='training dataset (mnist or cifar10)')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--percent', type=list, default=[0.8, 0.92, 0.991, 0.93],
                    metavar='P', help='pruning percentage (default: 0.8)')
parser.add_argument('--alpha', type=float, default=5e-4, metavar='L',
                    help='l2 norm weight (default: 5e-4)')
parser.add_argument('--rho', type=float, default=1e-2, metavar='R',
                    help='cardinality weight (default: 1e-2)')
parser.add_argument('--l1', default=False, action='store_true',
                    help='prune weights with l1 regularization instead of cardinality')
parser.add_argument('--l2', default=False, action='store_true',
                    help='apply l2 regularization')
parser.add_argument('--num_pre_epochs', type=int, default=3, metavar='P',
                    help='number of epochs to pretrain (default: 3)')
parser.add_argument('--num_epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--num_re_epochs', type=int, default=3, metavar='R',
                    help='number of epochs to retrain (default: 3)')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                    help='learning rate (default: 1e-2)')
parser.add_argument('--adam_epsilon', type=float, default=1e-8, metavar='E',
                    help='adam epsilon (default: 1e-8)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--save-model', action='store_true', default=False,
                    help='For Saving the current Model')
parser.add_argument('--num_users', action='store_true', default=1,
                    help='Number of users in network')
parser.add_argument('--model', action='store_true', default='resnet',
                    help='Model to train')
args = parser.parse_args(args=[])

In [None]:
classDict = {'not_smiling':0, 'smiling':1}

# Define a function to separate CIFAR classes by class index

def get_class_i(x, y, i):
    """
    x: trainset.train_data or testset.test_data
    y: trainset.train_labels or testset.test_labels
    i: class label, a number between 0 to 9
    return: x_i
    """
    # Convert to a numpy array
    y = np.array(y)
    # Locate position of labels that equal to i
    pos_i = np.argwhere(y == i)
    # Convert the result into a 1-D list
    pos_i = list(pos_i[:,0])
    # Collect all data that match the desired label
#     x_i = [x[j] for j in pos_i]
    
    return pos_i

In [5]:
# Standard transformations for improving celebA. 
# Transformations A
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.RandomCrop((218,178), padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Transformations B
RC   = transforms.RandomCrop((218,178), padding=4)
RHF  = transforms.RandomHorizontalFlip()
RVF  = transforms.RandomVerticalFlip()
NRM  = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
TT   = transforms.ToTensor()
TPIL = transforms.ToPILImage()

# Transforms object for trainset with augmentation
transform_with_aug = transforms.Compose([RC, RHF, TT, NRM])
# Transforms object for testset with NO augmentation
transform_no_aug   = transforms.Compose([TT, NRM])

# Downloading/Louding CELEBA data
trainset = torchvision.datasets.CelebA(root='../../../data/celebA', split = 'train',
                                        download=True, transform=transform_with_aug)

testset = torchvision.datasets.CelebA(root='../../../data/celebA', split='test',
                                       download=True, transform=transform_no_aug)

classDict = {'not_smiling':0, 'smiling':1}

# Separating trainset/testset data/label
x_train  = trainset
x_test   = testset
y_train  = trainset.attr[:,31] ## 31 is smile
y_test   = testset.attr[:,31] 

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified


In [6]:
def create_dict(x_train=x_train, x_test=x_test, y_train=y_train, y_test=y_test):
    # If we are saving a fraction of random data to be used in training
    frac1 = int(len(x_test) * 0.316666666667)
    frac2 = int(len(x_test) * 0.633333333333)

    x_train1 = list(range(0,frac1))
    x_train2 = list(range(frac1,frac2))
    x_train3 = list(range(frac2,len(x_test)))
    
#     trainset1 = CustomDataSet('../../../data/celebA/celeba/img_align_celeba',x=x_train1,y=y_train1, transform=transform_with_aug)
#     trainset2 = CustomDataSet('../../../data/celebA/celeba/img_align_celeba',x=x_train2,y=y_train2, transform=transform_with_aug)
#     trainset3 = CustomDataSet('../../../data/celebA/celeba/img_align_celeba',x=x_train3,y=y_train3, transform=transform_with_aug)
#     testset = CustomDataSet('../../../data/celebA/celeba/img_align_celeba',x=x_test,y=y_test, transform=transform_no_aug)

    
    dict_users = {0: x_train3, 1:x_train1, 2:x_train2}
    return dict_users

class DatasetSplit(torch.utils.data.Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label

In [7]:
gossip = []

def pretrain(args, model, device, train_loader, test_loader, optimizer):
    for epoch in range(args.num_pre_epochs):
        print('Pre epoch: {}'.format(epoch + 1))
        model.train()
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = regularized_nll_loss(args, model, output, target)
            loss.backward()
            optimizer.step()
        test(args, model, device, test_loader)


def train(args, model, device, train_loader, test_loader, optimizer, Z, U, report=False):
    model.train()
    print('Epoch: {}'.format(epoch + 1))
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = admm_loss(args, device, model, Z, U, output, target)
        loss.backward()
        optimizer.step()
    X = update_X(model)
    Z = update_Z_l1(X, U, args) if args.l1 else update_Z(X, U, args)
    U = update_U(U, X, Z)
    print_convergence(model, X, Z)
    test(args, model, device, test_loader, report)


iter = 0
def test(args, model, device, test_loader, report=False):
    model.eval()
    test_loss = 0
    correct = 0
    global iter
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    gossip.append(correct / len(test_loader.dataset))
    
    if report:
#         writer.add_scalar('train/loss_gossip_admm4', test_loss, iter)
#         writer.add_scalar('valid/accuracy_gossip_admm4', correct / len(test_loader.dataset), iter)
        iter+=1


def retrain(args, model, mask, device, train_loader, test_loader, optimizer):
    for epoch in range(args.num_re_epochs):
        print('Re epoch: {}'.format(epoch + 1))
        model.train()
        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.prune_step(mask)

        test(args, model, device, test_loader)

In [8]:
### MAIN

what_gpu = 2
torch.manual_seed(args.seed)
device = torch.device('cuda:{}'.format(what_gpu) if torch.cuda.is_available() else 'cpu')
kwargs = {'num_workers': 5, 'pin_memory': True}

args.percent = [0.6, 0.7, 0.8, 0.94, 0.95, 0.99, 0.99, 0.93]
args.num_pre_epochs = 1
args.num_epochs = 60
args.num_re_epochs = 5
args.num_users = 3
args.dataset = 'cifar10'
args.model = 'resnet'
args.l1 = True
args.l2 = False

trainset=CustomDataSet('../../../data/celebA/celeba/img_align_celeba',x=x_train.filename,y=y_train, transform=transform_with_aug)
testset=CustomDataSet('../../../data/celebA/celeba/img_align_celeba',x=x_test.filename,y=y_test, transform=transform_no_aug)
test_loader = torch.utils.data.DataLoader(testset , batch_size=args.batch_size, shuffle=True, **kwargs)
dict_users = create_dict()

In [9]:
mdlz = dict()
for usr in range(args.num_users):
    train_loader = torch.utils.data.DataLoader(DatasetSplit(trainset, dict_users[usr]), batch_size=args.batch_size, shuffle=True, **kwargs)
    if args.model == 'resnet':
        model = models.resnet50(pretrained=True)
        num_ftrs = model.fc.in_features
        model.fc = torch.nn.Linear(num_ftrs, 2)
        model.to(device)
    else:
        model = LeNet().to(device) if args.dataset == "mnist" else AlexNet().to(device)
    
    optimizer = PruneAdam(model.named_parameters(), lr=args.lr, eps=args.adam_epsilon)
    pretrain(args, model, device, train_loader, test_loader, optimizer)
    Z, U = initialize_Z_and_U(model)
    mdlz[usr] = (model, optimizer, Z, U)
    
for epoch in range(args.num_epochs):
    w_locals=[]
    idxs_users = np.random.choice(range(args.num_users), 2, replace=False)
    for usr in idxs_users:
        report = True if usr == 0 else False
        train_loader = torch.utils.data.DataLoader(DatasetSplit(trainset, dict_users[usr]), batch_size=args.batch_size, shuffle=True, **kwargs)
        train(args, mdlz[usr][0], device, train_loader, test_loader, mdlz[usr][1], mdlz[usr][2], mdlz[usr][3], report=report)
        w = mdlz[usr][0].state_dict()
        w_locals.append(copy.deepcopy(w))

    # update global weights
    w_glob = FedAvg(w_locals)

    for idx in idxs_users:
        # copy weight to net_glob
        mdlz[idx][0].load_state_dict(w_glob)

  0%|          | 0/115 [00:00<?, ?it/s]

Pre epoch: 1


100%|██████████| 115/115 [00:25<00:00,  4.53it/s]



Test set: Average loss: -365.3359, Accuracy: 17946/19962 (90%)



  0%|          | 0/99 [00:00<?, ?it/s]

Pre epoch: 1


100%|██████████| 99/99 [00:43<00:00,  2.27it/s]



Test set: Average loss: -281.9503, Accuracy: 14033/19962 (70%)



  0%|          | 0/99 [00:00<?, ?it/s]

Pre epoch: 1


100%|██████████| 99/99 [00:44<00:00,  2.20it/s]



Test set: Average loss: -302.1027, Accuracy: 17813/19962 (89%)



  0%|          | 0/99 [00:00<?, ?it/s]

Epoch: 1


100%|██████████| 99/99 [00:58<00:00,  1.68it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3044
(bn1.weight): 0.1847
(layer1.0.conv1.weight): 0.4445
(layer1.0.bn1.weight): 0.2322
(layer1.0.conv2.weight): 0.7406
(layer1.0.bn2.weight): 0.2892
(layer1.0.conv3.weight): 0.6960
(layer1.0.bn3.weight): 0.2676
(layer1.0.downsample.0.weight): 0.5076
(layer1.0.downsample.1.weight): 0.1921
(layer1.1.conv1.weight): 0.7803
(layer1.1.bn1.weight): 0.2457
(layer1.1.conv2.weight): 0.8045
(layer1.1.bn2.weight): 0.2702
(layer1.1.conv3.weight): 0.7395
(layer1.1.bn3.weight): 0.4125
(layer1.2.conv1.weight): 0.8133
(layer1.2.bn1.weight): 0.2749
(layer1.2.conv2.weight): 0.8166
(layer1.2.bn2.weight): 0.2232
(layer1.2.conv3.weight): 0.7408
(layer1.2.bn3.weight): 0.3204
(layer2.0.conv1.weight): 0.7481
(layer2.0.bn1.weight): 0.2420
(layer2.0.conv2.weight): 0.9055
(layer2.0.bn2.weight): 0.2454
(layer2.0.conv3.weight): 0.7835
(layer2.0.bn3.weight): 0.3009
(layer2.0.downsample.0.weight): 0.7950
(layer2.0.downsample.1.weight): 0.2909
(layer2.1.conv1

  0%|          | 0/115 [00:00<?, ?it/s]


Test set: Average loss: -1197.1880, Accuracy: 9988/19962 (50%)

Epoch: 1


100%|██████████| 115/115 [01:05<00:00,  1.76it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3086
(bn1.weight): 0.1838
(layer1.0.conv1.weight): 0.4489
(layer1.0.bn1.weight): 0.2291
(layer1.0.conv2.weight): 0.7461
(layer1.0.bn2.weight): 0.2905
(layer1.0.conv3.weight): 0.6929
(layer1.0.bn3.weight): 0.2698
(layer1.0.downsample.0.weight): 0.5116
(layer1.0.downsample.1.weight): 0.1894
(layer1.1.conv1.weight): 0.7754
(layer1.1.bn1.weight): 0.2467
(layer1.1.conv2.weight): 0.7999
(layer1.1.bn2.weight): 0.2695
(layer1.1.conv3.weight): 0.7457
(layer1.1.bn3.weight): 0.4236
(layer1.2.conv1.weight): 0.8145
(layer1.2.bn1.weight): 0.2709
(layer1.2.conv2.weight): 0.8130
(layer1.2.bn2.weight): 0.2227
(layer1.2.conv3.weight): 0.7460
(layer1.2.bn3.weight): 0.3300
(layer2.0.conv1.weight): 0.7482
(layer2.0.bn1.weight): 0.2388
(layer2.0.conv2.weight): 0.8969
(layer2.0.bn2.weight): 0.2456
(layer2.0.conv3.weight): 0.7817
(layer2.0.bn3.weight): 0.3054
(layer2.0.downsample.0.weight): 0.8059
(layer2.0.downsample.1.weight): 0.2901
(layer2.1.conv1

  0%|          | 0/115 [00:00<?, ?it/s]


Test set: Average loss: -1255.6199, Accuracy: 10769/19962 (54%)

Epoch: 2


100%|██████████| 115/115 [00:43<00:00,  2.64it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3115
(bn1.weight): 0.1821
(layer1.0.conv1.weight): 0.4528
(layer1.0.bn1.weight): 0.2258
(layer1.0.conv2.weight): 0.7448
(layer1.0.bn2.weight): 0.2848
(layer1.0.conv3.weight): 0.6909
(layer1.0.bn3.weight): 0.2690
(layer1.0.downsample.0.weight): 0.5166
(layer1.0.downsample.1.weight): 0.1885
(layer1.1.conv1.weight): 0.7722
(layer1.1.bn1.weight): 0.2462
(layer1.1.conv2.weight): 0.7981
(layer1.1.bn2.weight): 0.2706
(layer1.1.conv3.weight): 0.7396
(layer1.1.bn3.weight): 0.4279
(layer1.2.conv1.weight): 0.8029
(layer1.2.bn1.weight): 0.2715
(layer1.2.conv2.weight): 0.8025
(layer1.2.bn2.weight): 0.2225
(layer1.2.conv3.weight): 0.7458
(layer1.2.bn3.weight): 0.3309
(layer2.0.conv1.weight): 0.7488
(layer2.0.bn1.weight): 0.2382
(layer2.0.conv2.weight): 0.8960
(layer2.0.bn2.weight): 0.2456
(layer2.0.conv3.weight): 0.7821
(layer2.0.bn3.weight): 0.3038
(layer2.0.downsample.0.weight): 0.8059
(layer2.0.downsample.1.weight): 0.2899
(layer2.1.conv1

  0%|          | 0/99 [00:00<?, ?it/s]


Test set: Average loss: -2659.3877, Accuracy: 10257/19962 (51%)

Epoch: 2


100%|██████████| 99/99 [00:37<00:00,  2.66it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3100
(bn1.weight): 0.1854
(layer1.0.conv1.weight): 0.4440
(layer1.0.bn1.weight): 0.2305
(layer1.0.conv2.weight): 0.7437
(layer1.0.bn2.weight): 0.2850
(layer1.0.conv3.weight): 0.6969
(layer1.0.bn3.weight): 0.2705
(layer1.0.downsample.0.weight): 0.5126
(layer1.0.downsample.1.weight): 0.1873
(layer1.1.conv1.weight): 0.7674
(layer1.1.bn1.weight): 0.2445
(layer1.1.conv2.weight): 0.7954
(layer1.1.bn2.weight): 0.2685
(layer1.1.conv3.weight): 0.7399
(layer1.1.bn3.weight): 0.4322
(layer1.2.conv1.weight): 0.8055
(layer1.2.bn1.weight): 0.2740
(layer1.2.conv2.weight): 0.8093
(layer1.2.bn2.weight): 0.2225
(layer1.2.conv3.weight): 0.7467
(layer1.2.bn3.weight): 0.3363
(layer2.0.conv1.weight): 0.7457
(layer2.0.bn1.weight): 0.2399
(layer2.0.conv2.weight): 0.8948
(layer2.0.bn2.weight): 0.2453
(layer2.0.conv3.weight): 0.7803
(layer2.0.bn3.weight): 0.3019
(layer2.0.downsample.0.weight): 0.8023
(layer2.0.downsample.1.weight): 0.2836
(layer2.1.conv1

  0%|          | 0/99 [00:00<?, ?it/s]


Test set: Average loss: -2422.1152, Accuracy: 10467/19962 (52%)

Epoch: 3


100%|██████████| 99/99 [00:36<00:00,  2.68it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3039
(bn1.weight): 0.1835
(layer1.0.conv1.weight): 0.4474
(layer1.0.bn1.weight): 0.2304
(layer1.0.conv2.weight): 0.7411
(layer1.0.bn2.weight): 0.2910
(layer1.0.conv3.weight): 0.6935
(layer1.0.bn3.weight): 0.2662
(layer1.0.downsample.0.weight): 0.5084
(layer1.0.downsample.1.weight): 0.1927
(layer1.1.conv1.weight): 0.7805
(layer1.1.bn1.weight): 0.2507
(layer1.1.conv2.weight): 0.8067
(layer1.1.bn2.weight): 0.2680
(layer1.1.conv3.weight): 0.7374
(layer1.1.bn3.weight): 0.4018
(layer1.2.conv1.weight): 0.8160
(layer1.2.bn1.weight): 0.2733
(layer1.2.conv2.weight): 0.8164
(layer1.2.bn2.weight): 0.2241
(layer1.2.conv3.weight): 0.7357
(layer1.2.bn3.weight): 0.3270
(layer2.0.conv1.weight): 0.7465
(layer2.0.bn1.weight): 0.2405
(layer2.0.conv2.weight): 0.9052
(layer2.0.bn2.weight): 0.2462
(layer2.0.conv3.weight): 0.7822
(layer2.0.bn3.weight): 0.3049
(layer2.0.downsample.0.weight): 0.7992
(layer2.0.downsample.1.weight): 0.2938
(layer2.1.conv1

  0%|          | 0/115 [00:00<?, ?it/s]


Test set: Average loss: -1006.7922, Accuracy: 10487/19962 (53%)

Epoch: 3


100%|██████████| 115/115 [00:42<00:00,  2.68it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3130
(bn1.weight): 0.1805
(layer1.0.conv1.weight): 0.4498
(layer1.0.bn1.weight): 0.2309
(layer1.0.conv2.weight): 0.7400
(layer1.0.bn2.weight): 0.2862
(layer1.0.conv3.weight): 0.6891
(layer1.0.bn3.weight): 0.2695
(layer1.0.downsample.0.weight): 0.5154
(layer1.0.downsample.1.weight): 0.1878
(layer1.1.conv1.weight): 0.7586
(layer1.1.bn1.weight): 0.2437
(layer1.1.conv2.weight): 0.7859
(layer1.1.bn2.weight): 0.2676
(layer1.1.conv3.weight): 0.7360
(layer1.1.bn3.weight): 0.4389
(layer1.2.conv1.weight): 0.7913
(layer1.2.bn1.weight): 0.2753
(layer1.2.conv2.weight): 0.7917
(layer1.2.bn2.weight): 0.2208
(layer1.2.conv3.weight): 0.7422
(layer1.2.bn3.weight): 0.3402
(layer2.0.conv1.weight): 0.7338
(layer2.0.bn1.weight): 0.2414
(layer2.0.conv2.weight): 0.8758
(layer2.0.bn2.weight): 0.2450
(layer2.0.conv3.weight): 0.7748
(layer2.0.bn3.weight): 0.3024
(layer2.0.downsample.0.weight): 0.8003
(layer2.0.downsample.1.weight): 0.2856
(layer2.1.conv1

  0%|          | 0/99 [00:00<?, ?it/s]


Test set: Average loss: -5024.0880, Accuracy: 9975/19962 (50%)

Epoch: 4


100%|██████████| 99/99 [00:36<00:00,  2.68it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3120
(bn1.weight): 0.1836
(layer1.0.conv1.weight): 0.4486
(layer1.0.bn1.weight): 0.2289
(layer1.0.conv2.weight): 0.7417
(layer1.0.bn2.weight): 0.2794
(layer1.0.conv3.weight): 0.6912
(layer1.0.bn3.weight): 0.2753
(layer1.0.downsample.0.weight): 0.5149
(layer1.0.downsample.1.weight): 0.1841
(layer1.1.conv1.weight): 0.7643
(layer1.1.bn1.weight): 0.2462
(layer1.1.conv2.weight): 0.7934
(layer1.1.bn2.weight): 0.2690
(layer1.1.conv3.weight): 0.7375
(layer1.1.bn3.weight): 0.4231
(layer1.2.conv1.weight): 0.7851
(layer1.2.bn1.weight): 0.2748
(layer1.2.conv2.weight): 0.7948
(layer1.2.bn2.weight): 0.2225
(layer1.2.conv3.weight): 0.7447
(layer1.2.bn3.weight): 0.3382
(layer2.0.conv1.weight): 0.7355
(layer2.0.bn1.weight): 0.2409
(layer2.0.conv2.weight): 0.8828
(layer2.0.bn2.weight): 0.2468
(layer2.0.conv3.weight): 0.7762
(layer2.0.bn3.weight): 0.3023
(layer2.0.downsample.0.weight): 0.8005
(layer2.0.downsample.1.weight): 0.2782
(layer2.1.conv1

  0%|          | 0/99 [00:00<?, ?it/s]


Test set: Average loss: -4516.5071, Accuracy: 9826/19962 (49%)

Epoch: 4


100%|██████████| 99/99 [00:36<00:00,  2.69it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3108
(bn1.weight): 0.1798
(layer1.0.conv1.weight): 0.4508
(layer1.0.bn1.weight): 0.2298
(layer1.0.conv2.weight): 0.7473
(layer1.0.bn2.weight): 0.2902
(layer1.0.conv3.weight): 0.6913
(layer1.0.bn3.weight): 0.2686
(layer1.0.downsample.0.weight): 0.5170
(layer1.0.downsample.1.weight): 0.1907
(layer1.1.conv1.weight): 0.7683
(layer1.1.bn1.weight): 0.2495
(layer1.1.conv2.weight): 0.7968
(layer1.1.bn2.weight): 0.2683
(layer1.1.conv3.weight): 0.7406
(layer1.1.bn3.weight): 0.4331
(layer1.2.conv1.weight): 0.8038
(layer1.2.bn1.weight): 0.2751
(layer1.2.conv2.weight): 0.8075
(layer1.2.bn2.weight): 0.2237
(layer1.2.conv3.weight): 0.7411
(layer1.2.bn3.weight): 0.3384
(layer2.0.conv1.weight): 0.7426
(layer2.0.bn1.weight): 0.2435
(layer2.0.conv2.weight): 0.8958
(layer2.0.bn2.weight): 0.2461
(layer2.0.conv3.weight): 0.7813
(layer2.0.bn3.weight): 0.3103
(layer2.0.downsample.0.weight): 0.8022
(layer2.0.downsample.1.weight): 0.2907
(layer2.1.conv1

  0%|          | 0/115 [00:00<?, ?it/s]


Test set: Average loss: -4576.9743, Accuracy: 9975/19962 (50%)

Epoch: 5


100%|██████████| 115/115 [00:43<00:00,  2.64it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3083
(bn1.weight): 0.1817
(layer1.0.conv1.weight): 0.4471
(layer1.0.bn1.weight): 0.2318
(layer1.0.conv2.weight): 0.7420
(layer1.0.bn2.weight): 0.2898
(layer1.0.conv3.weight): 0.6935
(layer1.0.bn3.weight): 0.2699
(layer1.0.downsample.0.weight): 0.5114
(layer1.0.downsample.1.weight): 0.1904
(layer1.1.conv1.weight): 0.7719
(layer1.1.bn1.weight): 0.2483
(layer1.1.conv2.weight): 0.7993
(layer1.1.bn2.weight): 0.2680
(layer1.1.conv3.weight): 0.7365
(layer1.1.bn3.weight): 0.4271
(layer1.2.conv1.weight): 0.8128
(layer1.2.bn1.weight): 0.2754
(layer1.2.conv2.weight): 0.8153
(layer1.2.bn2.weight): 0.2230
(layer1.2.conv3.weight): 0.7386
(layer1.2.bn3.weight): 0.3296
(layer2.0.conv1.weight): 0.7459
(layer2.0.bn1.weight): 0.2417
(layer2.0.conv2.weight): 0.9034
(layer2.0.bn2.weight): 0.2458
(layer2.0.conv3.weight): 0.7797
(layer2.0.bn3.weight): 0.3003
(layer2.0.downsample.0.weight): 0.8002
(layer2.0.downsample.1.weight): 0.2884
(layer2.1.conv1

  0%|          | 0/99 [00:00<?, ?it/s]


Test set: Average loss: -4098.6146, Accuracy: 9987/19962 (50%)

Epoch: 5


100%|██████████| 99/99 [00:37<00:00,  2.67it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3097
(bn1.weight): 0.1825
(layer1.0.conv1.weight): 0.4502
(layer1.0.bn1.weight): 0.2314
(layer1.0.conv2.weight): 0.7436
(layer1.0.bn2.weight): 0.2829
(layer1.0.conv3.weight): 0.6968
(layer1.0.bn3.weight): 0.2677
(layer1.0.downsample.0.weight): 0.5163
(layer1.0.downsample.1.weight): 0.1896
(layer1.1.conv1.weight): 0.7654
(layer1.1.bn1.weight): 0.2493
(layer1.1.conv2.weight): 0.7920
(layer1.1.bn2.weight): 0.2690
(layer1.1.conv3.weight): 0.7378
(layer1.1.bn3.weight): 0.4291
(layer1.2.conv1.weight): 0.7959
(layer1.2.bn1.weight): 0.2758
(layer1.2.conv2.weight): 0.8016
(layer1.2.bn2.weight): 0.2213
(layer1.2.conv3.weight): 0.7432
(layer1.2.bn3.weight): 0.3384
(layer2.0.conv1.weight): 0.7414
(layer2.0.bn1.weight): 0.2430
(layer2.0.conv2.weight): 0.8916
(layer2.0.bn2.weight): 0.2467
(layer2.0.conv3.weight): 0.7800
(layer2.0.bn3.weight): 0.3041
(layer2.0.downsample.0.weight): 0.8053
(layer2.0.downsample.1.weight): 0.2838
(layer2.1.conv1

  0%|          | 0/99 [00:00<?, ?it/s]


Test set: Average loss: -7020.0787, Accuracy: 9975/19962 (50%)

Epoch: 6


100%|██████████| 99/99 [00:37<00:00,  2.67it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3127
(bn1.weight): 0.1782
(layer1.0.conv1.weight): 0.4497
(layer1.0.bn1.weight): 0.2268
(layer1.0.conv2.weight): 0.7456
(layer1.0.bn2.weight): 0.2825
(layer1.0.conv3.weight): 0.6865
(layer1.0.bn3.weight): 0.2673
(layer1.0.downsample.0.weight): 0.5169
(layer1.0.downsample.1.weight): 0.1881
(layer1.1.conv1.weight): 0.7563
(layer1.1.bn1.weight): 0.2451
(layer1.1.conv2.weight): 0.7825
(layer1.1.bn2.weight): 0.2662
(layer1.1.conv3.weight): 0.7327
(layer1.1.bn3.weight): 0.4162
(layer1.2.conv1.weight): 0.7732
(layer1.2.bn1.weight): 0.2729
(layer1.2.conv2.weight): 0.7882
(layer1.2.bn2.weight): 0.2230
(layer1.2.conv3.weight): 0.7375
(layer1.2.bn3.weight): 0.3427
(layer2.0.conv1.weight): 0.7249
(layer2.0.bn1.weight): 0.2402
(layer2.0.conv2.weight): 0.8736
(layer2.0.bn2.weight): 0.2457
(layer2.0.conv3.weight): 0.7729
(layer2.0.bn3.weight): 0.3141
(layer2.0.downsample.0.weight): 0.7983
(layer2.0.downsample.1.weight): 0.2841
(layer2.1.conv1

  0%|          | 0/99 [00:00<?, ?it/s]


Test set: Average loss: -7197.7097, Accuracy: 9975/19962 (50%)

Epoch: 6


100%|██████████| 99/99 [00:37<00:00,  2.66it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3093
(bn1.weight): 0.1833
(layer1.0.conv1.weight): 0.4477
(layer1.0.bn1.weight): 0.2313
(layer1.0.conv2.weight): 0.7434
(layer1.0.bn2.weight): 0.2876
(layer1.0.conv3.weight): 0.6926
(layer1.0.bn3.weight): 0.2701
(layer1.0.downsample.0.weight): 0.5150
(layer1.0.downsample.1.weight): 0.1939
(layer1.1.conv1.weight): 0.7551
(layer1.1.bn1.weight): 0.2475
(layer1.1.conv2.weight): 0.7826
(layer1.1.bn2.weight): 0.2674
(layer1.1.conv3.weight): 0.7313
(layer1.1.bn3.weight): 0.4322
(layer1.2.conv1.weight): 0.7767
(layer1.2.bn1.weight): 0.2781
(layer1.2.conv2.weight): 0.7862
(layer1.2.bn2.weight): 0.2220
(layer1.2.conv3.weight): 0.7422
(layer1.2.bn3.weight): 0.3238
(layer2.0.conv1.weight): 0.7330
(layer2.0.bn1.weight): 0.2421
(layer2.0.conv2.weight): 0.8848
(layer2.0.bn2.weight): 0.2467
(layer2.0.conv3.weight): 0.7766
(layer2.0.bn3.weight): 0.3104
(layer2.0.downsample.0.weight): 0.7987
(layer2.0.downsample.1.weight): 0.2863
(layer2.1.conv1

  0%|          | 0/115 [00:00<?, ?it/s]


Test set: Average loss: -7930.7488, Accuracy: 9975/19962 (50%)

Epoch: 7


100%|██████████| 115/115 [00:42<00:00,  2.68it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3095
(bn1.weight): 0.1797
(layer1.0.conv1.weight): 0.4565
(layer1.0.bn1.weight): 0.2281
(layer1.0.conv2.weight): 0.7412
(layer1.0.bn2.weight): 0.2856
(layer1.0.conv3.weight): 0.6847
(layer1.0.bn3.weight): 0.2653
(layer1.0.downsample.0.weight): 0.5181
(layer1.0.downsample.1.weight): 0.1899
(layer1.1.conv1.weight): 0.7560
(layer1.1.bn1.weight): 0.2481
(layer1.1.conv2.weight): 0.7858
(layer1.1.bn2.weight): 0.2667
(layer1.1.conv3.weight): 0.7367
(layer1.1.bn3.weight): 0.4203
(layer1.2.conv1.weight): 0.7919
(layer1.2.bn1.weight): 0.2746
(layer1.2.conv2.weight): 0.7987
(layer1.2.bn2.weight): 0.2215
(layer1.2.conv3.weight): 0.7431
(layer1.2.bn3.weight): 0.3493
(layer2.0.conv1.weight): 0.7336
(layer2.0.bn1.weight): 0.2407
(layer2.0.conv2.weight): 0.8851
(layer2.0.bn2.weight): 0.2455
(layer2.0.conv3.weight): 0.7761
(layer2.0.bn3.weight): 0.3014
(layer2.0.downsample.0.weight): 0.8033
(layer2.0.downsample.1.weight): 0.2823
(layer2.1.conv1

  0%|          | 0/99 [00:00<?, ?it/s]


Test set: Average loss: -8447.1964, Accuracy: 9975/19962 (50%)

Epoch: 7


100%|██████████| 99/99 [00:37<00:00,  2.66it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3086
(bn1.weight): 0.1811
(layer1.0.conv1.weight): 0.4482
(layer1.0.bn1.weight): 0.2283
(layer1.0.conv2.weight): 0.7428
(layer1.0.bn2.weight): 0.2863
(layer1.0.conv3.weight): 0.6910
(layer1.0.bn3.weight): 0.2673
(layer1.0.downsample.0.weight): 0.5148
(layer1.0.downsample.1.weight): 0.1894
(layer1.1.conv1.weight): 0.7614
(layer1.1.bn1.weight): 0.2473
(layer1.1.conv2.weight): 0.7876
(layer1.1.bn2.weight): 0.2673
(layer1.1.conv3.weight): 0.7380
(layer1.1.bn3.weight): 0.4226
(layer1.2.conv1.weight): 0.7855
(layer1.2.bn1.weight): 0.2788
(layer1.2.conv2.weight): 0.7898
(layer1.2.bn2.weight): 0.2224
(layer1.2.conv3.weight): 0.7438
(layer1.2.bn3.weight): 0.3333
(layer2.0.conv1.weight): 0.7286
(layer2.0.bn1.weight): 0.2416
(layer2.0.conv2.weight): 0.8769
(layer2.0.bn2.weight): 0.2447
(layer2.0.conv3.weight): 0.7732
(layer2.0.bn3.weight): 0.3098
(layer2.0.downsample.0.weight): 0.7996
(layer2.0.downsample.1.weight): 0.2837
(layer2.1.conv1

  0%|          | 0/99 [00:00<?, ?it/s]


Test set: Average loss: -9655.8268, Accuracy: 9975/19962 (50%)

Epoch: 8


100%|██████████| 99/99 [00:37<00:00,  2.65it/s]


normalized norm of (weight - projection)
(conv1.weight): 0.3118
(bn1.weight): 0.1800
(layer1.0.conv1.weight): 0.4517
(layer1.0.bn1.weight): 0.2287
(layer1.0.conv2.weight): 0.7401
(layer1.0.bn2.weight): 0.2825
(layer1.0.conv3.weight): 0.6837
(layer1.0.bn3.weight): 0.2635
(layer1.0.downsample.0.weight): 0.5183
(layer1.0.downsample.1.weight): 0.1912
(layer1.1.conv1.weight): 0.7497
(layer1.1.bn1.weight): 0.2422
(layer1.1.conv2.weight): 0.7789
(layer1.1.bn2.weight): 0.2653
(layer1.1.conv3.weight): 0.7288
(layer1.1.bn3.weight): 0.4281
(layer1.2.conv1.weight): 0.7622
(layer1.2.bn1.weight): 0.2733
(layer1.2.conv2.weight): 0.7779
(layer1.2.bn2.weight): 0.2247
(layer1.2.conv3.weight): 0.7401
(layer1.2.bn3.weight): 0.3506
(layer2.0.conv1.weight): 0.7243
(layer2.0.bn1.weight): 0.2429
(layer2.0.conv2.weight): 0.8596
(layer2.0.bn2.weight): 0.2483
(layer2.0.conv3.weight): 0.7630
(layer2.0.bn3.weight): 0.3078
(layer2.0.downsample.0.weight): 0.7893
(layer2.0.downsample.1.weight): 0.2757
(layer2.1.conv1

  0%|          | 0/115 [00:00<?, ?it/s]


Test set: Average loss: -10465.6272, Accuracy: 9975/19962 (50%)

Epoch: 8


 53%|█████▎    | 61/115 [00:22<00:20,  2.66it/s]


KeyboardInterrupt: 

In [None]:
### Apply pruning
for usr in range(args.num_users):
    mask = apply_l1_prune(mdlz[usr][0], device, args) if args.l1 else apply_prune(mdlz[usr][0], device, args)
    print_prune(model)
    test(args, mdlz[usr][0], device, test_loader)
    retrain(args, model, mask, device, train_loader, test_loader, optimizer)

# Testing ground

In [None]:
from torch.utils.data import DataLoader
import torch.nn.functional as F

confusion_matrix = torch.zeros(10, 10)
mdlz[0][0].eval()
# testing
test_loss = 0
correct = 0

l = len(test_loader)
for idx, (data, target) in enumerate(test_loader):
    data, target = data.to(device), target.to(device)
    log_probs = mdlz[0][0](data)
    # sum up batch loss
    test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
    # get the index of the max log-probability
    y_pred = log_probs.data.max(1, keepdim=True)[1]
    correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    for t, p in zip(target.data, y_pred.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1

In [None]:
print(confusion_matrix)

In [None]:
gossip