In [1]:
import sys

# Inside a notebook kernel, sys.argv holds the kernel launcher's arguments.
# Replace it with a single empty entry so any project code that parses the
# command line (presumably argparse in config/utils — confirm) sees no flags.
sys.argv = [""]

In [2]:
import os
import models
import torch
import torch.backends.cudnn as cudnn
from config import cfg
from data import fetch_dataset, make_data_loader, split_dataset, SplitDataset
from utils import save, to_device, process_control, process_dataset, make_optimizer, make_scheduler, resume, collate

from masking_functions import SNIP

import train_classifier_fed

In [3]:
# process_control() runs first; the assignments below come after it, so they
# take precedence over whatever it may have set for these four keys.
process_control()
cfg['data_name'] = 'CIFAR10'
cfg['model_name'] = 'conv'
cfg['data_split_mode'] = 'non-iid-2'
cfg['num_users'] = 100

# One seed per experiment: init_seed, init_seed + 1, ..., (num_experiments of them).
seeds = [*range(cfg['init_seed'], cfg['init_seed'] + cfg['num_experiments'])]

# Tag layout: <first seed>_<data>_<subset>_<model>_<control>, with empty/falsy parts dropped.
# NOTE(review): cfg['control_name'] is not updated by the overrides above, so the
# tag may not reflect num_users / data_split_mode as set here — confirm.
model_tag_list = [str(seeds[0]), cfg['data_name'], cfg['subset'], cfg['model_name'], cfg['control_name']]
cfg['model_tag'] = '_'.join(filter(None, model_tag_list))

In [4]:
# Sanity check: the device plus the data-split settings, one value per line.
print(cfg['device'], cfg['data_split_mode'], cfg['num_users'], sep='\n')

cuda
non-iid-2
100


In [5]:
# Load the dataset named by cfg['data_name'] (restricted by cfg['subset']); the
# recorded output shows CIFAR10 being fetched.
dataset = fetch_dataset(cfg['data_name'], cfg['subset'])
# Post-process the loaded dataset. NOTE(review): presumably this records
# dataset-derived settings into cfg — confirm in utils.process_dataset.
process_dataset(dataset)

fetching data CIFAR10...
data ready


In [6]:
# Display the training split's repr (size, root, split, transforms) as a sanity check.
dataset['train']

Dataset CIFAR10
Size: 50000
Root: ./data/CIFAR10
Split: train
Subset: label
Transforms: Compose(
    RandomCrop(size=(32, 32), padding=4)
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))
)

In [7]:
# Instantiate the model class/factory named by cfg['model_name'] from the
# `models` package and move it to cfg['device'].
# getattr replaces the previous eval() of a formatted string: it resolves the
# same attribute without executing arbitrary text. The original also called
# .to(cfg["device"]) twice; a single move is sufficient.
model = getattr(models, cfg['model_name'])(model_rate=cfg['global_model_rate'], track=True).to(cfg['device'])
# Optimizer over the model's parameters at learning rate cfg['lr'], and its scheduler.
optimizer = make_optimizer(model, cfg['lr'])
scheduler = make_scheduler(optimizer)

In [8]:
# Print the module tree to confirm the architecture built above
# (the recorded output is truncated mid-layer).
print(model)

Conv(
  (blocks): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Scaler()
    (2): BatchNorm2d(64, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): Scaler()
    (7): BatchNorm2d(128, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): Scaler()
    (12): BatchNorm2d(256, eps=1e-05, momentum=None, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)

In [9]:
# Clear any previous client split; the following cell only calls
# split_dataset when data_split is None.
data_split = None

In [10]:
# Partition the dataset across cfg['num_users'] clients using
# cfg['data_split_mode']. Per the inspection cell below, data_split['train'][i]
# holds client i's sample indices and label_split[i] presumably the labels
# assigned to client i (e.g. [7, 8] under non-iid-2).
# The None guard keeps a re-run of this cell from re-splitting.
# NOTE(review): label_split is only bound inside this branch — if data_split is
# ever set some other way, label_split must already exist or later cells fail.
if data_split is None:
    data_split, label_split = split_dataset(dataset, cfg['num_users'], cfg['data_split_mode'])

In [11]:
# Inspect client 0: number of training indices, its first index, and its label set.
print(len(data_split['train'][0]), data_split['train'][0][0], label_split[0], sep='\n')

500
29626
[7, 8]


In [12]:
# Build a training DataLoader over client 0's subset only (indices from
# data_split['train'][0]); make_data_loader takes a dict of splits, and only
# the 'train' loader is kept. This loader is what the SNIP call below consumes.
data_loader = make_data_loader({'train': SplitDataset(dataset['train'], data_split['train'][0])})['train']

In [13]:
# Compute SNIP pruning masks for `model` from client 0's data loader.
# NOTE(review): 0.05 is presumably the fraction of weights to keep — confirm
# against masking_functions.SNIP and consider lifting it into the config cell.
# The recorded output suggests SNIP prints a count (presumably of kept weights).
keep_mask = SNIP(model, 0.05, data_loader, cfg['device'])

tensor(77756, device='cuda:0')


In [24]:
# Check what kind of object each mask entry is (recorded: torch.Tensor).
print(keep_mask[0].__class__)

<class 'torch.Tensor'>


In [26]:
# Bring the first mask to host memory and print it; the recorded output shows
# 0./1. entries with the first conv layer's weight layout.
print(keep_mask[0].cpu())

tensor([[[[1., 1., 1.],
          [1., 0., 1.],
          [1., 1., 0.]],

         [[1., 0., 1.],
          [0., 0., 0.],
          [1., 1., 1.]],

         [[0., 1., 0.],
          [1., 1., 1.],
          [0., 1., 1.]]],


        [[[0., 1., 1.],
          [0., 1., 0.],
          [1., 0., 1.]],

         [[0., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [0., 0., 0.],
          [0., 1., 1.]],

         [[1., 1., 1.],
          [0., 1., 1.],
          [1., 0., 0.]],

         [[0., 0., 1.],
          [1., 1., 1.],
          [1., 1., 0.]]],


        ...,


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 