In [1]:
import sys
sys.argv = ['']

In [41]:
import os
import models
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from config import cfg
from data import fetch_dataset, make_data_loader, split_dataset, SplitDataset
from utils import save, to_device, process_control, process_dataset, make_optimizer, make_scheduler, resume, collate
from logger import Logger
from metrics import Metric
import time
import datetime
import shutil
import copy

from masking_functions import SNIP

from sklearn.manifold import TSNE
from sklearn.datasets import load_digits

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import train_classifier

In [37]:
process_control()
cfg['data_name'] = 'CIFAR10'
cfg['model_name'] = 'conv'
cfg["data_split_mode"] ='non-iid-2'
cfg["num_users"] = 100
cfg["batch_size"]["train"] = 128
cfg["batch_size"]["test"] = 128
cfg["num_epochs"] = 50
cfg["prune_rate"] = 0.05

seeds = list(range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments']))
model_tag_list = [str(seeds[0]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']]
cfg['model_tag'] = '_'.join([x for x in model_tag_list if x])

In [4]:
print(cfg['device'])
print(cfg["data_split_mode"])
print(cfg["num_users"])
print(cfg["batch_size"]["train"])

cuda
non-iid-2
100
128


In [5]:
dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
process_dataset(dataset)

fetching data CIFAR10...
data ready


In [6]:
dataset['train']

Dataset CIFAR10
Size: 50000
Root: ./data/CIFAR10
Split: train
Subset: label
Transforms: Compose(
    RandomCrop(size=(32, 32), padding=4)
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))
)

In [7]:
model = eval('models.{}(model_rate=cfg["global_model_rate"], track=True).to(cfg["device"]).to(cfg["device"])'
                 .format(cfg['model_name']))
optimizer = make_optimizer(model, cfg['lr'])
scheduler = make_scheduler(optimizer)

In [8]:
print(model)

Conv(
  (blocks): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Scaler()
    (2): BatchNorm2d(64, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): Scaler()
    (7): BatchNorm2d(128, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): Scaler()
    (12): BatchNorm2d(256, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)

In [9]:
data_split = None

In [10]:
if data_split is None:
    data_split, label_split = split_dataset(dataset, cfg['num_users'], cfg['data_split_mode'])

In [11]:
print(len(data_split['train'][0]))
print(data_split['train'][0][0])
print(label_split[0])

500
14924
[3, 7]


In [12]:
data_loader = make_data_loader({'train': SplitDataset(dataset['train'], data_split['train'][0])})['train']
test_loader = make_data_loader({'test': SplitDataset(dataset['test'], data_split['test'][0])})['test']

In [13]:
batch = collate(next(iter(data_loader)))
print(len(batch['label']))

128


In [14]:
keep_mask = SNIP(model, cfg["prune_rate"], data_loader, cfg['device'])

tensor(77756, device='cuda:0')


In [15]:
print((keep_mask[0].shape))

torch.Size([64, 3, 3, 3])


In [27]:
def flatten_mask(mask):
    flat_mask = []
    for m in mask:
        flat_mask.extend(m.flatten().to(torch.long).tolist())
    return flat_mask

def mask_similarity(mask1, mask2):
    flat_mask1 = []
    flat_mask2 = []
    for m1, m2 in zip(mask1, mask2):
        flat_mask1.extend(m1.flatten().to(torch.long).tolist())
        flat_mask2.extend(m2.flatten().to(torch.long).tolist())
    paramnum = len(flat_mask1)
    masknum = int(sum(flat_mask1))
    overlap_list = [p1 == p2 == 1 for p1, p2 in zip(flat_mask1, flat_mask2)]
    overlap = int(sum(overlap_list))
    return overlap / masknum

def layer_similarity(mask1, mask2):
    masknums = []
    overlaps = []
    for m1, m2 in zip(mask1, mask2):
        fm1 = m1.flatten().to(torch.long).tolist()
        fm2 = m2.flatten().to(torch.long).tolist()
        masknums.append(sum(fm1))
        overlaps.append(sum([p1 == p2 == 1 for p1, p2 in zip(fm1, fm2)]))
        
    return [o/p for o, p in zip(overlaps, masknums)]

def print_similarity(a, b):
    print(f"the overlap between {client_label[a]}  and {client_label[b]} is {mask_similarity(client_masks[a], client_masks[b])}")

In [17]:
client_data = []
client_label = label_split
for m in range(cfg["num_users"]):
    client_data.append(make_data_loader({'train': SplitDataset(dataset['train'], data_split['train'][m])})['train'])

In [38]:
client_masks = []
for m in range(cfg["num_users"]):
    client_masks.append(SNIP(model, cfg["prune_rate"], client_data[m], cfg['device']))

tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77757, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(77756, device='cuda:0')
tensor(7

In [35]:
for l in client_label:
    print(l)

[3, 7]
[3, 4]
[4, 6]
[0, 3]
[2, 6]
[6, 8]
[1, 7]
[6, 8]
[3, 6]
[7]
[0, 3]
[2, 7]
[5, 9]
[2, 9]
[3, 5]
[1]
[7, 9]
[5, 8]
[7, 9]
[6, 9]
[0, 7]
[0]
[1, 3]
[1, 9]
[5, 8]
[2, 4]
[0, 3]
[1, 4]
[0, 4]
[4, 7]
[0, 9]
[0, 3]
[3, 9]
[1, 6]
[6, 8]
[1, 8]
[4, 5]
[3, 4]
[1, 8]
[7, 8]
[3, 8]
[0, 6]
[5, 6]
[0, 6]
[4, 7]
[3, 6]
[5]
[8]
[0, 3]
[5, 9]
[0, 7]
[2, 5]
[4, 7]
[4, 5]
[1]
[2, 9]
[2, 8]
[7, 8]
[2, 3]
[2, 6]
[4, 5]
[2, 7]
[6, 8]
[5]
[4, 7]
[3, 4]
[1, 6]
[2, 9]
[2, 5]
[6, 7]
[5, 8]
[0, 5]
[7, 8]
[0, 1]
[9]
[1, 8]
[1, 3]
[0, 2]
[0, 6]
[2, 7]
[2, 6]
[2, 5]
[8, 9]
[1, 3]
[1, 8]
[2, 6]
[7, 9]
[1, 2]
[4]
[1]
[2, 9]
[4]
[0, 5]
[9]
[0, 3]
[8, 9]
[0, 4]
[2, 3]
[5, 9]
[4, 6]


In [39]:
tocompare = 3
for i in range(30):
    print_similarity(tocompare, i)

the overlap between [0, 3]  and [3, 7] is 0.32751170327691753
the overlap between [0, 3]  and [3, 4] is 0.320296825968414
the overlap between [0, 3]  and [4, 6] is 0.26966407737023507
the overlap between [0, 3]  and [0, 3] is 1.0
the overlap between [0, 3]  and [2, 6] is 0.35062246000308656
the overlap between [0, 3]  and [6, 8] is 0.3525001286074387
the overlap between [0, 3]  and [1, 7] is 0.33342764545501313
the overlap between [0, 3]  and [6, 8] is 0.35320746952003707
the overlap between [0, 3]  and [3, 6] is 0.34653274345388135
the overlap between [0, 3]  and [7] is 0.28793919440300425
the overlap between [0, 3]  and [0, 3] is 0.696370698081177
the overlap between [0, 3]  and [2, 7] is 0.2932764031071557
the overlap between [0, 3]  and [5, 9] is 0.34284170996450436
the overlap between [0, 3]  and [2, 9] is 0.3291450177478265
the overlap between [0, 3]  and [3, 5] is 0.35494366994186943
the overlap between [0, 3]  and [1] is 0.2984850043726529
the overlap between [0, 3]  and [7, 9]

In [42]:
# tsne print
n_components = 2
learning_rate = 500
model = TSNE(n_components=n_components, learning_rate=learning_rate)

In [43]:
flat_masks = []
for m in range(cfg['num_users']):
    flat_masks.append(flatten_mask(client_masks[m]))
embedded_masks = model.fit_transform(flat_masks)
print(embedded_masks.shape)
em = pd.DataFrame(embedded_masks, columns=["x", "y"])
em['labels'] = [str(l) for l in labels]



(100, 2)


NameError: name 'labels' is not defined

In [None]:
%matplotlib inline
plt.figure(figsize=(5, 5))
sns.scatterplot(
    x="x",
    y="y",
    data=em[:50],
    hue="labels",
    s=40.
)

In [17]:
def apply_prune_mask(net, keep_masks):

    # Before I can zip() layers and pruning masks I need to make sure they match
    # one-to-one by removing all the irrelevant modules:
    prunable_layers = filter(
        lambda layer: isinstance(layer, nn.Conv2d) or isinstance(
            layer, nn.Linear), net.modules())

    for layer, keep_mask in zip(prunable_layers, keep_masks):
        assert (layer.weight.shape == keep_mask.shape)

        def hook_factory(keep_mask):
            """
            The hook function can't be defined directly here because of Python's
            late binding which would result in all hooks getting the very last
            mask! Getting it through another function forces early binding.
            """

            def hook(grads):
                return grads * keep_mask

            return hook

        # mask[i] == 0 --> Prune parameter
        # mask[i] == 1 --> Keep parameter

        # Step 1: Set the masked weights to zero (NB the biases are ignored)
        # Step 2: Make sure their gradients remain zero
        layer.weight.data[keep_mask == 0.] = 0.
        layer.weight.register_hook(hook_factory(keep_mask))

In [18]:
def train(data_loader, model, optimizer, logger, epoch):
    metric = Metric()
    model.train(True)
    start_time = time.time()
    for i, input in enumerate(data_loader):
        input = collate(input)
        input_size = input['img'].size(0)
        input = to_device(input, cfg['device'])
        optimizer.zero_grad()
        output = model(input)
        output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
        output['loss'].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()
        evaluation = metric.evaluate(cfg['metric_name']['train'], input, output)
        logger.append(evaluation, 'train', n=input_size)
        if i % int((len(data_loader) * cfg['log_interval']) + 1) == 0:
            batch_time = (time.time() - start_time) / (i + 1)
            lr = optimizer.param_groups[0]['lr']
            epoch_finished_time = datetime.timedelta(seconds=round(batch_time * (len(data_loader) - i - 1)))
            exp_finished_time = epoch_finished_time + datetime.timedelta(
                seconds=round((cfg['num_epochs'] - epoch) * batch_time * len(data_loader)))
            info = {'info': ['Model: {}'.format(cfg['model_tag']),
                             'Train Epoch: {}({:.0f}%)'.format(epoch, 100. * i / len(data_loader)),
                             'Learning rate: {}'.format(lr), 'Epoch Finished Time: {}'.format(epoch_finished_time),
                             'Experiment Finished Time: {}'.format(exp_finished_time)]}
            logger.append(info, 'train', mean=False)
            logger.write('train', cfg['metric_name']['train'])
    return
def stats(data_loader, model):
    with torch.no_grad():
        test_model = eval('models.{}(model_rate=cfg["global_model_rate"], track=True).to(cfg["device"])'
                          .format(cfg['model_name']))
        test_model.load_state_dict(model.state_dict(), strict=False)
        test_model.train(True)
        for i, input in enumerate(data_loader):
            input = collate(input)
            input = to_device(input, cfg['device'])
            test_model(input)
    return test_model

def test(data_loader, model, logger, epoch):
    with torch.no_grad():
        metric = Metric()
        model.train(False)
        for i, input in enumerate(data_loader):
            input = collate(input)
            input_size = input['img'].size(0)
            input = to_device(input, cfg['device'])
            output = model(input)
            output['loss'] = output['loss'].mean() if cfg['world_size'] > 1 else output['loss']
            evaluation = metric.evaluate(cfg['metric_name']['test'], input, output)
            logger.append(evaluation, 'test', input_size)
        info = {'info': ['Model: {}'.format(cfg['model_tag']), 'Test Epoch: {}({:.0f}%)'.format(epoch, 100.)]}
        logger.append(info, 'test', mean=False)
        logger.write('test', cfg['metric_name']['test'])
    return

In [22]:
model_copy = copy.deepcopy(model)
optimizer = make_optimizer(model_copy, cfg['lr'])
scheduler = make_scheduler(optimizer)
apply_prune_mask(model_copy, keep_mask)
last_epoch = 1
logger_path = os.path.join('output', 'runs', 'train_{}'.format(cfg['model_tag']))
logger = Logger(logger_path)
for epoch in range(last_epoch, cfg['num_epochs'] + 1):
        logger.safe(True)
        train(data_loader, model_copy, optimizer, logger, epoch)
        test_model = stats(data_loader, model_copy)
        test(test_loader, test_model, logger, epoch)

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 1(0%)  Loss: 2.4564  Accuracy: 0.0000  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:14
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 1(50%)  Loss: 1.9602  Accuracy: 39.8438  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:12
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 1(100%)  Loss: 0.7750  Accuracy: 65.0000
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 2(0%)  Loss: 1.5119  Accuracy: 49.5223  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:10
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 2(50%)  Loss: 1.2631  Accuracy: 53.9593  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:13
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 2(100%)  Loss: 0.6599  Accuracy: 70.0000
Mod

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 17(100%)  Loss: 1.2864  Accuracy: 69.6471
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 18(0%)  Loss: 0.4820  Accuracy: 79.5665  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:08
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 18(50%)  Loss: 0.4783  Accuracy: 79.6601  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:06
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 18(100%)  Loss: 1.2664  Accuracy: 70.0556
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 19(0%)  Loss: 0.4734  Accuracy: 79.7875  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:05
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 19(50%)  Loss: 0.4683  Accuracy: 80.0192  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00

Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 34(100%)  Loss: 1.2670  Accuracy: 72.6765
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 35(0%)  Loss: 0.3757  Accuracy: 84.1429  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:03
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 35(50%)  Loss: 0.3728  Accuracy: 84.2556  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:03
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Test Epoch: 35(100%)  Loss: 1.2498  Accuracy: 73.0286
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 36(0%)  Loss: 0.3709  Accuracy: 84.3317  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00:03
Model: 0_CIFAR10_label_conv_1_100_0.1_iid_fix_a1_bn_1_1  Train Epoch: 36(50%)  Loss: 0.3689  Accuracy: 84.3939  Learning rate: 0.1  Epoch Finished Time: 0:00:00  Experiment Finished Time: 0:00