In [1]:
# Check if in correct environment
# The interpreter is considered to be inside conda when a `conda-meta`
# directory exists directly under sys.prefix.
import sys, os, time, datetime
is_conda = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
if not is_conda:
    # sys.exit raises SystemExit carrying this message.
    sys.exit('Not in conda, check environment.')

In [2]:
import numpy as np
import torch
import shelve
import math

In [3]:
# Default args (convenience class)
class Arguments():
    """Container for every tunable setting of the notebook.

    All defaults are assigned in ``__init__`` and then passed through a set of
    consistency checks (device selection, PySyft worker definitions, FedAdam /
    FedAvg / FedProx constraints, automatic epoch count).

    Args:
        **overrides: optional ``name=value`` pairs replacing a default *before*
            the consistency checks run, e.g. ``Arguments(federate=False)``.
            Only names of existing settings are accepted.

    Raises:
        TypeError: if an override names a setting that does not exist.
        SystemExit: if PySyft is enabled and fewer workers are defined in
            ``pysyft_worker_def`` than ``fed_worker_num`` asks for.
    """

    def __init__(self, **overrides):
        self.network = 'lenet5'
        self.optim = 'sgd'
        self.lr_scheduler = 'lambda'
        self.description = ''

        self.seed = 1
        self.log_frequency = 2
        self.use_cuda = True
        self.save_model = False

        self.batch_size = 10
        self.test_batch_size = 64
        self.epochs = 'auto'
        # In federated learning each worker gets its own shuffler; this does
        # NOT mean the data is distributed randomly across the workers.
        self.shuffle = True
        # Mutually exclusive with shuffle (enforced in the checks below)
        self.use_weighted_random_sampler = False
        # The trailing index picks one of the listed candidate values.
        self.weight_decay = [0, 1e-2, 1e-4, 1e-7][3]
        self.lr = [1e0, 1e-1, 5e-2, 1e-2, 5e-3, 1e-3, 1e-4, 1e-5][3]

        ### lr scheduler plateau ###
        # Number of epochs with no improvement after which learning rate will be reduced
        # = "tolerate 2 bad epochs; if the next epoch is bad too, step, otherwise reset the bad-epoch count"
        self.lr_scheduler_patience = 2
        # lr*gamma to decrease
        self.lr_scheduler_gamma = 0.1
        ### lr scheduler lambda ###
        self.lr_lambda = lambda epoch: 1.0 / (1+epoch) # decrease fast in the beginning
        #self.lr_lambda = lambda epoch: 0.9 ** epoch # decrease slow, but may become too large in later epochs
        #self.lr_lambda = lambda epoch: 1.0 # dummy lr

        self.federate = True
        self.rounds = 2
        self.epoch_per_round = 1
        self.global_model = True
        self.fedadam = False
        self.fedprox = False
        # NOTE: v1 does not include buffers eg. means/variances (b/c using named_parameters())
        #       v2 use state_dict()
        #       v3 mt-fedavg
        self.fedavgver = 2
        self.fed_worker_num = 10
        self.class_per_worker = 5
        self.dist_scheme = 'permuted'
        self.custom_mapping = {0: [0, 1, 3, 4, 5], 1: [2, 6, 7, 8, 9], 2: [1, 4, 5, 6, 8], 3: [0, 2, 3, 7, 9],
                               4: [1, 2, 4, 5, 9], 5: [0, 3, 6, 7, 8], 6: [1, 2, 3, 5, 6], 7: [0, 4, 7, 8, 9],
                               8: [2, 3, 4, 8, 9], 9: [0, 1, 5, 6, 7]}
#         self.custom_mapping = {0: [9, 1], 1: [4, 7], 2: [2, 5], 3: [3, 6], 4: [0, 8],
#                                5: [8, 6], 6: [0, 2], 7: [3, 7], 8: [1, 5], 9: [9, 4],
#                                10:[0, 2],11: [1, 6],12: [3, 5],13: [9, 7],14: [8, 4],
#                                15:[9, 6],16: [4, 5],17: [0, 1],18: [8, 3],19: [2, 7]}
        self.uniqueness_threshold = 3 # for scheme 'choose-unique': max number of same class for all workers

        self.use_pysyft = True
        self.pysyft_worker_verbose = True
        self.pysyft_worker_def = {
            'charlie': { 'type': 'websocket', 'port': 8779 },
            'alice': { 'type': 'websocket', 'port': 8777 },
            'bob': { 'type': 'websocket', 'port': 8778 },
            'dave': { 'type': 'virtual' },
            'eve': { 'type': 'virtual' },
            'fred': { 'type': 'virtual' },
            'george': { 'type': 'virtual' },
            'harry': { 'type': 'virtual' },
            'ian': { 'type': 'virtual' },
            'jack': { 'type': 'virtual' },
        } # BUG: PySyft expects worker ids as strings, while custom_mapping is keyed by ints;
          #      to use dist_scheme='custom' the custom_mapping keys must be changed accordingly.
        # TODO: Cannot mix GPU and CPU, only one kind of device throughout

        self.dataset_root = 'speech_commands_dataset_v2'
        self.use_cache = True
        self.cache = 'speech_commands_dataset_v2.cache'
        self.datalist_root = 'datalists'
        self.datalist = ['numbers_noise', 'leftright', 'numbers'][2]
        self.train_dataset = 'training_list.txt'
        self.valid_dataset = 'validation_list.txt'
        self.test_dataset = 'testing_list.txt'
        self.classlabels = '__classes__.txt'
        self.dataload_workers_num = 0
        self.drop_last_batch = True

        # Apply caller overrides before the checks so that the derived values
        # below (device, epochs, ...) stay consistent with them.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError('Unknown argument "%s" for Arguments.' % name)
            setattr(self, name, value)

        # Checks
        # default_device / cuda_args are always defined, whichever branch is taken.
        self.default_device = 'cpu'
        self.cuda_args = {}
        if self.use_cuda:
            if torch.cuda.is_available():
                self.default_device = 'cuda'
                self.cuda_args = {'num_workers': 1, 'pin_memory': True}
            else:
                print('CUDA not available, using CPU instead.')
                self.use_cuda = False

        # torch's DataLoader refuses shuffle=True together with an explicit sampler.
        if self.use_weighted_random_sampler and self.shuffle:
            print('shuffle and use_weighted_random_sampler are mutually exclusive. Setting shuffle to False.')
            self.shuffle = False

        if self.use_pysyft:
            if (self.fed_worker_num > len(self.pysyft_worker_def)):
                print('Not enought workers defined in pysyft_worker_def!')
                sys.exit('Aborted.')
            # Workers without an explicit device inherit the default one.
            for idx in self.pysyft_worker_def:
                if 'device' not in self.pysyft_worker_def[idx]:
                    self.pysyft_worker_def[idx]['device'] = self.default_device

        if self.fedadam and self.optim != 'adam':
            print('FedAdam specified but optimizer is not Adam.')
            self.fedadam = False

        if self.fedavgver == 3 and not self.global_model:
            print('federated_average version 3 requires using a global model. Setting global_model to True.')
            self.global_model = True

        if self.epochs == 'auto':
            self.epochs = self.rounds * self.epoch_per_round if self.federate else self.rounds

        # FedProx runs with weight decay switched off.
        if self.fedprox:
            self.weight_decay = 0

args = Arguments()

In [4]:
# Globals
# Seed torch's global RNG from the configured seed (only torch is seeded here).
torch.manual_seed(args.seed)

# Millisecond wall-clock timestamp of this run; valid() uses it as the
# prefix of the checkpoint file names.
START_TIMESTAMP = int(time.time()*1000)
# First epoch index of the main epoch loop.
START_EPOCH = 0

# Running bests over all validation passes; both are updated by valid().
BEST_ACC = 0
LEAST_LOSS = 1e100

# Data processing

## Define transformations

In [5]:
from torchvision.transforms import Compose
from transformation import *

# Pipeline selected when the cache is disabled (see the `transform` choice in
# the dataloader cell). Judging by the class names it fixes the clip length,
# computes a mel spectrogram and converts it to a tensor -- TODO confirm
# against the `transformation` module.
audio_feature_transform = Compose([
    FixAudioLength(),
    ToMelSpectrogram(),
    ToTensorFromSpect()
])
# Pipeline selected when the cache is enabled: only the final tensor
# conversion step is applied.
tensor_transform = Compose([
    ToTensorFromSpect()
])

## Define Dataloader (iterator), instantiate dataset

In [6]:
from dataset import SpeechCommandsDataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

# With the cache enabled only the tensor conversion is applied and the datasets
# read from the shelve DB; otherwise the full audio pipeline is used and no DB
# is passed (presumably the cache stores precomputed spectrograms -- TODO confirm).
transform = tensor_transform if args.use_cache else audio_feature_transform
# NOTE: the read-only shelve handle is shared by all three datasets and is not
# closed in this cell; call db.close() once the datasets are no longer needed.
db = shelve.open(args.cache, 'r') if args.use_cache else None

# One path bundle per split (train / valid / test); the class-label file and
# the dataset directory are shared, only the file list differs.
paths = [{
    'classes': os.path.join(args.datalist_root, args.datalist, args.classlabels),
    'filelist': os.path.join(args.datalist_root, args.datalist, filelist),
    'dataset_dir': args.dataset_root
} for filelist in [args.train_dataset, args.valid_dataset, args.test_dataset]]
train_dataset = SpeechCommandsDataset(paths[0], transform=transform, db=db)
valid_dataset = SpeechCommandsDataset(paths[1], transform=transform, db=db)
test_dataset = SpeechCommandsDataset(paths[2], transform=transform, db=db)

# Get number of classes (same for train/valid/test)
nclass = train_dataset.get_num_of_classes()

sampler = None
if args.use_weighted_random_sampler:
    # Adopted from https://discuss.pytorch.org/t/balanced-sampling-between-classes-with-torchvision-dataloader/2703/3
    weights_for_sampling = train_dataset.make_weights_for_balanced_classes()
    sampler = WeightedRandomSampler(weights_for_sampling, len(weights_for_sampling))

# Note: train dataloader may be overriden if args.federate is True
# DataLoader raises ValueError when shuffle=True is combined with an explicit
# sampler, so shuffling is only requested when no sampler is in use.
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=args.shuffle and sampler is None,
                              pin_memory=args.use_cuda, num_workers=args.dataload_workers_num,
                              sampler=sampler)
valid_dataloader = DataLoader(valid_dataset, batch_size=args.test_batch_size, shuffle=args.shuffle,
                              pin_memory=args.use_cuda, num_workers=args.dataload_workers_num)
test_dataloader  = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=args.shuffle,
                              pin_memory=args.use_cuda, num_workers=args.dataload_workers_num)

{'zero': 0, 'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5, 'six': 6, 'seven': 7, 'eight': 8, 'nine': 9}
{'zero': 0, 'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5, 'six': 6, 'seven': 7, 'eight': 8, 'nine': 9}
{'zero': 0, 'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5, 'six': 6, 'seven': 7, 'eight': 8, 'nine': 9}


# Construct model

In [7]:
# Imports
import torch.nn as nn
import torch.optim as optim

import resnet as resnet
import densenet as densenet
import mobilenet as mn
import googlenet as googlenet
import lenet as lenet
import squeezenet as squeezenet
import shufflenetv2 as shufflenetv2
import nasnet as nasnet
from efficientnet_pytorch import EfficientNet

from fedavg import *

# Keyword arguments unpacked into the model constructors in create_model()
# (the 'lenet5' branch there does not use them).
hyperparam = {"num_classes": nclass, "in_channels": 1}
# Loss shared by the training, validation and test routines.
criterion = nn.CrossEntropyLoss()

In [8]:
def create_model(network='mobilenetv2', send_to_device=None, rootmodel=None):
    """Build a fresh model of the requested architecture.

    Args:
        network: architecture name; one of the keys of ``builders`` below.
        send_to_device: device string (e.g. ``'cpu'`` / ``'cuda'``) the model is
            moved to. Any falsy value (``None``, ``False``, ``''``) leaves the
            model where it was created.
        rootmodel: optional model whose ``state_dict()`` is loaded into the new
            model, so that several models start from identical weights.

    Returns:
        The newly created model.

    Raises:
        ValueError: if ``network`` is not a known architecture name.
    """
    # Each entry is a zero-argument builder, so only the selected architecture
    # is looked up and instantiated.
    builders = {
        'mobilenetv2': lambda: mn.mobilenet_v2(**hyperparam),
        'mobilenetv2_quantize': lambda: mn.mobilenet_v2q(**hyperparam),
        'mobilenetv1': lambda: mn.mobilenet_v1(**hyperparam),
        'mobilenetv1_quantize': lambda: mn.mobilenet_v1q(**hyperparam),
        'googlenet': lambda: googlenet.googlenet(**hyperparam),
        'squeezenet1_1': lambda: squeezenet.squeezenet1_1(**hyperparam),
        'efficientnetb0': lambda: EfficientNet.from_name('efficientnet-b0', image_size=32, **hyperparam),
        'shufflenetv2_x0.5': lambda: shufflenetv2.shufflenet_v2_x0_5(**hyperparam),
        'shufflenetv2_x1': lambda: shufflenetv2.shufflenet_v2_x1_0(**hyperparam),
        'resnet18': lambda: resnet.resnet18(**hyperparam),
        'resnet50': lambda: resnet.resnet50(**hyperparam),
        'densenet121': lambda: densenet.densenet121(**hyperparam),
        'nasnet-a-mobile': lambda: nasnet.nasnet_a_mobile(**hyperparam),
        'lenet5': lambda: lenet.LeNet5(),  # Already accepting 1x32x32, output 10
    }
    try:
        build = builders[network]
    except (KeyError, TypeError):
        # TypeError covers unhashable names, which are just as invalid.
        raise ValueError('Bad network name "%s" supplied for create_model.' % network) from None
    model = build()

    # For multi-GPU
    #if args.use_cuda:
        #model = torch.nn.DataParallel(model).to(DEVICE)

    # Compare against None explicitly: a module that defines __len__ (e.g. a
    # container) and is empty would be falsy and its weights silently skipped.
    if rootmodel is not None:
        model.load_state_dict(rootmodel.state_dict())

    if send_to_device:
        model = model.to(torch.device(send_to_device))

    return model
        
def create_optimizer(model, algorithm='sgd'):
    """Create an optimizer over all parameters of ``model``.

    ``algorithm`` selects between 'sgd', 'adam' and 'rmsprop'; learning rate
    and weight decay are read from the global ``args``. Any other name raises
    ValueError.
    """
    candidates = (
        ('sgd', torch.optim.SGD),
        ('adam', torch.optim.Adam),
        ('rmsprop', torch.optim.RMSprop),
    )
    for name, optimizer_cls in candidates:
        if algorithm == name:
            return optimizer_cls(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    raise ValueError('Bad algorithm named "%s" supplied for create_optimizer.' % algorithm)

def create_lr_scheduler(optimizer, algorithm='lambda'):
    """Create the learning-rate scheduler attached to ``optimizer``.

    'lambda' yields a LambdaLR driven by ``args.lr_lambda``; 'plateau' yields a
    ReduceLROnPlateau configured from ``args.lr_scheduler_patience`` and
    ``args.lr_scheduler_gamma``. Any other name raises ValueError.
    """
    if algorithm == 'lambda':
        return torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=args.lr_lambda, last_epoch=-1)

    if algorithm == 'plateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=args.lr_scheduler_patience,
            factor=args.lr_scheduler_gamma)

    raise ValueError('Bad algorithm named "%s" supplied for create_lr_scheduler.' % algorithm)

def get_lr(opt):
    """Return the learning rate stored in the first parameter group of ``opt``."""
    first_group = opt.param_groups[0]
    return first_group['lr']

In [9]:
# Build everything the main loop needs: either a single model/optimizer/
# scheduler (centralized) or per-worker dictionaries plus an optional global
# model (federated).
if not args.federate:
    # Centralized mode
    model = create_model(send_to_device=args.default_device, network=args.network)
    optimizer = create_optimizer(model, algorithm=args.optim)
    lr_scheduler = create_lr_scheduler(optimizer, algorithm=args.lr_scheduler)
else:
    # Federated learning mode    
    # (1) With Pysyft
    if args.use_pysyft:
        import syft as sy
        hook = sy.TorchHook(torch)  # decorate torch lib
        
        # Prepare workers
        # Only the first fed_worker_num entries of pysyft_worker_def are used,
        # in dictionary order.
        workers = []
        for idx, (worker_id, worker_def) in enumerate(args.pysyft_worker_def.items()):
            if idx > args.fed_worker_num-1: 
                break
            if worker_def['type'] == 'websocket':
                # Websocket workers are expected on localhost at the configured port.
                remote_worker = sy.WebsocketClientWorker(hook, host='localhost', port=worker_def['port'], id=worker_id, timeout=600)
                # Clear any objects still held by the remote worker before reusing it.
                if (remote_worker.objects_count_remote() > 0):
                    remote_worker.clear_objects_remote()
                workers.append(remote_worker)
            else:
                workers.append(sy.VirtualWorker(hook, id=worker_id))
        
        # Distribute dataset
        worker_ids = [w.id for w in workers]    
        fed_dataset, meta_data = federate_dataset(train_dataset, workers, nclass, args.dist_scheme,
                                                  uniqueness_threshold=args.uniqueness_threshold,
                                                  class_per_worker=args.class_per_worker,
                                                  custom_mapping=args.custom_mapping,
                                                  use_pysyft=args.use_pysyft)
            
        # Overwrite train dataloader to be federated
        # FederatedDataLoader does not support many options
        train_dataloader = sy.FederatedDataLoader(fed_dataset, batch_size=args.batch_size, shuffle=args.shuffle,
                                                  drop_last=args.drop_last_batch)
    
    # (2) Simulate local epochs (faster than Pysyft...)
    else:
        workers = list(range(args.fed_worker_num)) # workers is a list of ids
        fed_dataset, meta_data = federate_dataset(train_dataset, workers, nclass, args.dist_scheme,
                                                  uniqueness_threshold=args.uniqueness_threshold,
                                                  class_per_worker=args.class_per_worker,
                                                  custom_mapping=args.custom_mapping,
                                                  use_pysyft=args.use_pysyft)
            
        # train_dataloader will now become a Dict of DataLoaders
        # NOTE(review): `sampler` was built from the weights of the full
        # train_dataset but is passed to every per-worker loader here, together
        # with shuffle=args.shuffle; this only looks safe while sampler is None
        # -- confirm before enabling use_weighted_random_sampler in this mode.
        from torch.utils.data import TensorDataset
        worker_ids = workers
        train_dataloader = { w: DataLoader(
                                  TensorDataset(fed_dataset[w][0], fed_dataset[w][1]), 
                                  batch_size=args.batch_size, shuffle=args.shuffle,
                                  pin_memory=args.use_cuda, num_workers=args.dataload_workers_num,
                                  sampler=sampler, drop_last=args.drop_last_batch)
                            for w in workers
                          }
    
    # Tally local dataset size
    # worker_n maps worker id -> first dimension of that worker's input shape
    # (presumably its sample count); it is later handed to fedavg()/fedadam().
    worker_n = {}
    for w in workers:
        dim0 = meta_data[w]['xshape'][0]
        # meta_data is indexed by the items of `workers` (worker objects under
        # PySyft, plain ids otherwise), while worker_n is always keyed by id.
        worker_n[w.id if args.use_pysyft else w] = dim0
        print('Worker %s, number of input: %d, classes: %s' % \
              (str(w), dim0, str(meta_data[w]['yset'])) )
    
    # models/optimizers/lr_schedulers to be federated
    # Every worker model (and the global model, if enabled) is initialised from
    # the state dict of fed_rootmodel. Note the loop variable `id` shadows the
    # builtin of the same name inside these comprehensions.
    fed_rootmodel = create_model(network=args.network) # same random initialization
    fed_models = { id: create_model(
        send_to_device=args.default_device, network=args.network, rootmodel=fed_rootmodel) for id in worker_ids }
    global_model = create_model(
        send_to_device=args.default_device, network=args.network, rootmodel=fed_rootmodel) if args.global_model else None
    fed_optimizers = { id: create_optimizer(fed_models[id], algorithm=args.optim) for id in worker_ids }
    fed_lr_schedulers = { id: create_lr_scheduler(fed_optimizers[id], algorithm=args.lr_scheduler) for id in worker_ids }

Sending data to worker dave ...
Sending data to worker alice ...
Sending data to worker charlie ...
Sending data to worker bob ...
Sending data to worker eve ...
Sending data to worker fred ...
Sending data to worker george ...
Sending data to worker harry ...
Sending data to worker ian ...
Sending data to worker jack ...
Worker <VirtualWorker id:dave #objects:2>, number of input: 3136, classes: [0, 1, 2, 4, 5]
Worker <WebsocketClientWorker id:alice #tensors local:0 #tensors remote: 2>, number of input: 3086, classes: [3, 6, 7, 8, 9]
Worker <WebsocketClientWorker id:charlie #tensors local:0 #tensors remote: 2>, number of input: 3121, classes: [0, 1, 4, 6, 9]
Worker <WebsocketClientWorker id:bob #tensors local:0 #tensors remote: 2>, number of input: 3113, classes: [2, 3, 5, 7, 8]
Worker <VirtualWorker id:eve #objects:2>, number of input: 3120, classes: [1, 2, 3, 7, 9]
Worker <VirtualWorker id:fred #objects:2>, number of input: 3114, classes: [0, 4, 5, 6, 8]
Worker <VirtualWorker id:geor

# Training routines

In [10]:
def train(epoch, model, optimizer):
    """Run one centralized training epoch over the global ``train_dataloader``.

    Args:
        epoch: epoch index, used for logging only.
        model: model to train; it is switched to training mode.
        optimizer: optimizer stepping the parameters of ``model``.

    Progress is printed for the first batch and then about
    ``args.log_frequency`` times per epoch.
    """
    print("epoch %3d with lr=%.02e, wall clock %s" % (epoch, get_lr(optimizer), str(datetime.datetime.now().time())))
    # Clamp to 1: with fewer batches than log_frequency the integer division
    # yields 0 and the modulo below would raise ZeroDivisionError.
    log_interval = max(1, len(train_dataloader) // args.log_frequency)
    device = torch.device(args.default_device)
    model.train()  # Set model to training mode

    for batch_idx, (inputs, targets) in enumerate(train_dataloader):
        # Insert a new axis at position 1 before feeding the model.
        inputs = torch.unsqueeze(inputs, 1)

        # copy tensors to cpu/gpu
        inputs = inputs.to(device)
        targets = targets.to(device)

        # forward/backward
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        #loss = criterion(outputs.logits, targets) # for googlenet
        loss.backward()
        optimizer.step()

        # statistics
        if batch_idx == 0 or batch_idx % log_interval == log_interval-1:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(train_dataset),
                100. * batch_idx / len(train_dataloader), loss.item()))

In [11]:
def train_federated_pysyft(epoch, models, optimizers, global_model=None):
    """Run one federated epoch through the PySyft ``train_dataloader``.

    Each batch is processed with the model/optimizer registered under the id
    of the worker that holds the batch. On the first epoch of a round a model
    is sent to its worker the first time a batch of that worker shows up; on
    the last epoch of a round all models are fetched back and averaged.

    Args:
        epoch: global epoch index; decides whether this is the first and/or
            last epoch of a round (``args.epoch_per_round``).
        models: dict worker id -> model.
        optimizers: dict worker id -> optimizer.
        global_model: optional global model handed to ``fedavg``.

    Returns:
        True if is FedAvg round, else return False.
    """
    # Clamp to 1: a short loader would make the integer division 0 and the
    # modulo in the logging condition would raise ZeroDivisionError.
    log_interval = max(1, len(train_dataloader) // (args.log_frequency * args.fed_worker_num))
    is_fedavg_round = (epoch % args.epoch_per_round == args.epoch_per_round-1)
    is_start_round = (epoch % args.epoch_per_round == 0)
    device = torch.device(args.default_device)
    last_worker = ''

    for worker_id in models:
        models[worker_id].train()  # Set model to training mode

    for batch_idx, (inputs, targets) in enumerate(train_dataloader):
        inputs = torch.unsqueeze(inputs, 1)
        worker = inputs.location # worker here is Pysyft Worker

        model = models[worker.id]
        optimizer = optimizers[worker.id]

        # Only act when the batch stream switches to a different worker.
        if last_worker != worker.id:
            print("epoch %3d on worker %s with lr=%.02e, wall clock %s" % \
                  (epoch, worker.id, get_lr(optimizer), str(datetime.datetime.now().time())))
            last_worker = worker.id

            #if args.pysyft_worker_def[last_worker]['device'] == 'cuda':
            #    torch.set_default_tensor_type('torch.cuda.FloatTensor')
            #    last_device = torch.device('cuda')
            #elif args.pysyft_worker_def[last_worker]['device'] == 'cpu':
            #    torch.set_default_tensor_type('torch.FloatTensor')
            #    last_device = torch.device('cpu')

            if is_start_round:
                #model.to(last_device)
                # send model to worker
                model.send(worker)

        # copy tensors to cpu/gpu
        #inputs = inputs.to(last_device)
        #targets = targets.to(last_device)
        inputs = inputs.to(device)
        targets = targets.to(device)

        # forward/backward
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        #loss = criterion(outputs.logits, targets) # for googlenet
        loss.backward()
        optimizer.step()
        # Fetch the loss back from the worker so it can be printed locally.
        batch_loss = loss.get()

        # statistics
        if batch_idx % log_interval == log_interval-1:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(train_dataset),
                100. * batch_idx / len(train_dataloader), batch_loss.item()))
        #print("Batch %d done" % batch_idx) # Debug

    if is_fedavg_round:
        for worker_id, model in models.items():
            model.get() # Pysyft get model back
            #model.to(default_device)
        fedavg(args.fedavgver, models, worker_n, global_model)
        torch.set_default_tensor_type('torch.FloatTensor')
        return True
    else:
        return False

In [12]:
def train_federated_sim(epoch, models, optimizers, global_model=None):
    """Run one simulated federated epoch: a local epoch on every worker in turn.

    Args:
        epoch: global epoch index; decides whether this epoch closes a round
            (``args.epoch_per_round``).
        models: dict worker id -> model.
        optimizers: dict worker id -> optimizer.
        global_model: optional global model handed to ``fedavg``.

    Returns:
        True if is FedAvg round, else return False.
    """
    is_fedavg_round = (epoch % args.epoch_per_round == args.epoch_per_round-1)
    device = torch.device(args.default_device)

    # NOTE: iterating over worker_id
    for worker_id in worker_ids:
        model = models[worker_id]
        dataloader = train_dataloader[worker_id]
        optimizer = optimizers[worker_id]
        # valid() may have left this model in eval mode (it evaluates a worker
        # model when no global model is used), so switch back explicitly.
        model.train()
        # Clamp to 1: with fewer batches than log_frequency the integer
        # division yields 0 and the modulo below would raise ZeroDivisionError.
        log_interval = max(1, len(dataloader) // args.log_frequency)
        print("Local epoch %2d on worker %2d with lr=%.02e, wall clock %s" % \
              (epoch, worker_id, get_lr(optimizer), str(datetime.datetime.now().time())))

        for batch_idx, (inputs, targets) in enumerate(dataloader):
            # Single-sample batches are skipped entirely.
            if inputs.size()[0] == 1:
                print('[WARNING] Batch size of 1, dropping...')
                continue
            inputs = torch.unsqueeze(inputs, 1)

            # copy tensors to cpu/gpu
            inputs = inputs.to(device)
            targets = targets.to(device)

            # forward/backward
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            #loss = criterion(outputs.logits, targets) # for googlenet
            loss.backward()
            optimizer.step()

            # statistics
            if batch_idx % log_interval == log_interval-1:
                print('{:3d}/{:3d} batches bs={}\t({:.0f}%)\tLoss: {:.6f}'.format(
                    batch_idx, len(dataloader), args.batch_size,
                    100. * batch_idx / len(dataloader), loss.item()))

    if is_fedavg_round:
        fedavg(args.fedavgver, models, worker_n, global_model)
        if args.fedadam:
            fedadam(optimizers, worker_n)
        return True
    else:
        return False

# Validation and test routines

In [13]:
def valid(epoch, model):
    """Evaluate ``model`` on the validation loader and track the running bests.

    Prints accuracy and mean per-batch loss, optionally saves the model when
    it beats the best accuracy and/or the least loss seen so far
    (``args.save_model``), then updates the globals BEST_ACC and LEAST_LOSS.

    Args:
        epoch: epoch index (kept for the caller's bookkeeping).
        model: model to evaluate; it is switched to evaluation mode.

    Returns:
        The mean of the per-batch losses.
    """
    global BEST_ACC, LEAST_LOSS

    model.eval()  # evaluation mode
    target_device = torch.device(args.default_device)

    loss_sum = 0.0
    n_batches = 0
    n_correct = 0
    n_samples = 0

    with torch.no_grad():
        for inputs, targets in valid_dataloader:
            inputs = torch.unsqueeze(inputs, 1).to(target_device)
            targets = targets.to(target_device)

            # forward pass
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)

            # bookkeeping
            n_batches += 1
            loss_sum += batch_loss.item()
            # index of the largest output per sample
            predicted = outputs.argmax(1, keepdim=True)
            n_correct += predicted.eq(targets.view_as(predicted)).sum().item()
            n_samples += targets.size(0)

    accuracy = n_correct / n_samples
    epoch_loss = loss_sum / n_batches
    print('[ROUND] Accuracy: %.4f, Epoch loss: %.4f' % (accuracy, epoch_loss))

    if args.save_model:
        # a name used to save checkpoints etc.
        full_name = '%s_%s_%s_bs%d_lr%.1e_wd%.1e' % \
                    (args.network, args.optim, args.lr_scheduler, args.batch_size, args.lr, args.weight_decay)

        beats_accuracy = accuracy > BEST_ACC
        beats_loss = epoch_loss < LEAST_LOSS
        if beats_accuracy:
            torch.save(model, '%d-%s-best-acc-%s.pth' % (START_TIMESTAMP, full_name, args.description))
        if beats_loss:
            torch.save(model, '%d-%s-least-loss-%s.pth' % (START_TIMESTAMP, full_name, args.description))

    # Update the running bests only after the save decision above.
    BEST_ACC = max(accuracy, BEST_ACC)
    LEAST_LOSS = min(epoch_loss, LEAST_LOSS)

    return epoch_loss

In [14]:
def test(model):
    """Evaluate ``model`` on the test loader and print accuracy and mean batch loss."""
    model.eval()  # evaluation mode
    target_device = torch.device(args.default_device)

    loss_sum = 0.0
    n_batches = 0
    n_correct = 0
    n_samples = 0

    with torch.no_grad():
        for inputs, targets in test_dataloader:
            inputs = torch.unsqueeze(inputs, 1).to(target_device)
            targets = targets.to(target_device)

            # forward pass
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)

            # bookkeeping
            n_batches += 1
            loss_sum += batch_loss.item()
            # index of the largest output per sample
            predicted = outputs.argmax(1, keepdim=True)
            n_correct += predicted.eq(targets.view_as(predicted)).sum().item()
            n_samples += targets.size(0)

    accuracy = n_correct / n_samples
    epoch_loss = loss_sum / n_batches
    print('Accuracy: %.4f, Loss: %.4f' % (accuracy, epoch_loss))

# Main epoch loop

In [15]:
# args.lr=5e-4
# args.epoch_per_round=3
# args.epochs = 300
# fed_rootmodel = create_model(network=args.network, send_to_device=False) # same random initialization
# fed_models = { id: create_model(network=args.network, rootmodel=fed_rootmodel) for id in worker_ids }
# fed_optimizers = { id: create_optimizer(fed_models[id], algorithm=args.optim) for id in worker_ids }
# fed_lr_schedulers = { id: create_lr_scheduler(fed_optimizers[id], args.lr_scheduler) for id in worker_ids }

In [16]:
# Main driver: runs epochs START_EPOCH .. args.epochs-1 in either centralized
# or federated mode and prints the elapsed time and running bests after each.
# NOTE(review): the output recorded for this cell stops right after the first
# worker banner with `ValueError: not enough values to unpack (expected 4,
# got 1)`; the failing call is not visible here -- see the PySyft issue
# referenced below.
since = time.time()
for epoch in range(START_EPOCH, args.epochs):
#for epoch in range(72, 120):
    # Workaround for pysyft error: https://github.com/OpenMined/PySyft/issues/2518
    if not args.federate:
        # Centralized: train, validate and step the scheduler every epoch.
        train(epoch, model, optimizer)
        epoch_loss = valid(epoch, model)
        # The plateau scheduler is stepped with the validation loss as metric.
        if args.lr_scheduler == 'plateau':
            lr_scheduler.step(metrics=epoch_loss)
        else:
            lr_scheduler.step()
    else:
        if args.use_pysyft:
            round_complete = train_federated_pysyft(epoch, fed_models, fed_optimizers, global_model=global_model)
        else:
            round_complete = train_federated_sim(epoch, fed_models, fed_optimizers, global_model=global_model)
            
        # Validation and scheduler steps happen once per completed round,
        # not once per epoch.
        if round_complete:
            # Without a global model the first worker's model is evaluated.
            epoch_loss = valid(epoch, global_model if args.global_model else fed_models[worker_ids[0]])
            for worker_id, sched in fed_lr_schedulers.items():
                if args.lr_scheduler == 'plateau':
                    sched.step(metrics=epoch_loss)
                else:
                    sched.step()

    time_elapsed = time.time() - since
    time_str = '[EPOCH] Total time elapsed: {:.0f}h {:.0f}m {:.0f}s '.format(time_elapsed // 3600, time_elapsed % 3600 // 60, time_elapsed % 60)
    print("%s, best accuracy: %.02f%%, best loss %f" % (time_str, 100*BEST_ACC, LEAST_LOSS))
    
    #time.sleep(60)
    #if epoch % 30 == 29:
    #    time.sleep(120) # Pause to let computer cool...
    
print("finished")

epoch   0 on worker dave with lr=1.00e-02, wall clock 13:35:59.097268


ValueError: not enough values to unpack (expected 4, got 1)

# Test

In [None]:
# eval_model = torch.load('1597955978929-mobilenetv1_quantize_sgd_lambda_bs16_lr1.0e-02_wd1.0e-07-best-acc-fed(w10e3cw2).pth')
# test(eval_model)

In [None]:
#federated_avg_v2(fed_models, worker_n)
# Final evaluation on the test split. In federated mode the global model is
# used when enabled, otherwise the first worker's model; in centralized mode
# the single trained model is used.
if args.federate:
    test(global_model if args.global_model else fed_models[worker_ids[0]])
else:
    test(model)

In [None]:
# full_name = '%s_%s_%s_bs%d_lr%.1e_wd%.1e' % \
#                     (args.network, args.optim, args.lr_scheduler, args.batch_size, args.lr, args.weight_decay)
# torch.save(fed_models[worker_ids[0]], '%d-%s-%s.pth' % (START_TIMESTAMP, full_name, 'fed(w10e3cw5)-43ROUND'))
# torch.save(model, '%d-%s-%s.pth' % (START_TIMESTAMP, full_name, 'centralized'))

In [None]:
# if db:
#     db.close()

# Debug zone

In [None]:
# next(fed_models['bob'].parameters()).device

In [None]:
# next(fed_models['alice'].parameters()).data

In [None]:
# import pickle
# from fedavg import *
# pickle.dump(fed_models, open('fedmodels.p', 'wb'))
# federated_avg(fed_models, worker_n)

In [None]:
# print(fed_models[worker_ids[1]].state_dict())
# print(dict(fed_models[worker_ids[1]].named_parameters()).keys())

In [None]:
# fed_models2 = pickle.load(open('fedmodels.p', 'rb'))
# federated_avg_v2(fed_models2, worker_n)

In [None]:
# print(fed_models2[worker_ids[1]].state_dict())

In [None]:
# fed_lr_schedulers = { id: create_lr_scheduler(fed_optimizers[id], args.lr_scheduler) for id in worker_ids }

In [None]:
# for w in worker_ids:
#     fed_lr_schedulers[w].step()
# get_lr(fed_optimizers[worker_ids[0]])

In [None]:
#with open('badmodel.dump', 'w') as f:
#    torch.set_printoptions(profile="full")
#    d = dict(fed_models[worker_ids[9]].named_parameters())
#    print(d, file=f)
#    torch.set_printoptions(profile="default")

In [None]:
# m = torch.load('1597715469718-mobilenetv2_sgd_lambda_bs32_lr1.0e-02_wd1.0e-07-fed(w10e3cw5)-sgd.pth')
# fed_models = { id: create_model(network=args.network, rootmodel=m) for id in worker_ids }

In [None]:
# eval_model = torch.load('1597949089071-mobilenetv1_quantize_sgd_lambda_bs16_lr1.0e-02_wd1.0e-07-least-loss-fedv1(w10e3cw5).pth')
# valid(1,eval_model)

In [None]:
# for i in worker_ids:
#     print(valid(1,fed_models[i]))

In [None]:
# fed_optimizers[worker_ids[0]].state_dict()['state'].keys()

In [None]:
# fed_optimizers[worker_ids[0]].state_dict()['state'][2120198521064]['exp_avg'].shape

In [None]:
# fed_optimizers[worker_ids[1]].state_dict()['state'].keys()

In [None]:
# fed_optimizers[worker_ids[1]].state_dict()['state'][2120229084600]['exp_avg'].shape