In [None]:
# Enable IPython's autoreload so that edits to the local modules imported
# below (arguments, cnn, train) are re-imported automatically before each
# cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from arguments import Arguments
from cnn import CNN
import os
import syft as sy
import torch
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from train import fl_train_with_fl as train
from train import test

In [None]:
# Setups: parse run arguments, pick the device, and seed torch's RNG.
args = Arguments()
USE_CUDA = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
device = torch.device("cuda" if USE_CUDA else "cpu")

# Extra options splatted into both data loaders below. Deliberately empty on
# every device: CUDA-only options such as num_workers / pin_memory are NOT
# passed, even when USE_CUDA is true.
# NOTE(review): presumably disabled because sy.FederatedDataLoader does not
# accept these options — confirm before enabling them.
kwargs = {}

In [None]:
# Checkpoint file locations. `init_path` is read when args.load_init is set
# and written when args.save_init is set; `best_path` and `stop_path` are
# written when args.save_model is set.
ckpt_path = '../ckpts'
# torch.save does not create missing parent directories, so make sure the
# checkpoint directory exists before training starts rather than failing at
# the first save after an epoch of work.
os.makedirs(ckpt_path, exist_ok=True)
init_path = os.path.join(ckpt_path, 'mnist_cnn_fl.init')
best_path = os.path.join(ckpt_path, 'mnist_cnn_fl.best')
stop_path = os.path.join(ckpt_path, 'mnist_cnn_fl.stop')


In [None]:
# Setup hook to support FL (patches torch so tensors can be sent to workers).
hook = sy.TorchHook(torch)

# Define workers: one PySyft VirtualWorker per client, with ids "0", "1", ...
workers = [sy.VirtualWorker(hook, id=str(id_)) for id_ in range(args.num_workers)]

# Preprocessing shared by the train and test sets: convert to a tensor, then
# normalize with the given mean/std (presumably the usual MNIST statistics).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training data: `.federate(workers)` hands the MNIST training set to PySyft,
# which distributes it across the virtual workers defined above.
train_loader = sy.FederatedDataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=mnist_transform)
    .federate(workers),
    batch_size=args.batch_size, shuffle=True, **kwargs)

# Test data stays local. download=True makes this loader independent of the
# train set having been downloaded first. Shuffling is off: evaluation does not
# depend on sample order, and shuffling would consume global torch RNG state
# between training epochs.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, download=True,
                   transform=mnist_transform),
    batch_size=args.test_batch_size, shuffle=False, **kwargs)


In [None]:
# Best test score seen so far; `test` receives it and returns the updated value.
# NOTE(review): assumes `test` returns max(previous best, current score) with
# higher = better (consistent with starting from 0) — confirm in train.py.
best = 0

# Fire the engines
model = CNN().to(device)
if args.load_init:
    # map_location lets an init checkpoint saved on a GPU be loaded on a
    # CPU-only machine (and vice versa).
    model.load_state_dict(torch.load(init_path, map_location=device))
    print('Load init: {}'.format(init_path))
elif args.save_init:
    torch.save(model.state_dict(), init_path)
    print('Save init: {}'.format(init_path))

optimizer = optim.SGD(model.parameters(), lr=args.lr)

for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    prev_best = best
    best = test(args, model, device, test_loader, best)

    # Only overwrite the "best" checkpoint when this epoch improved the best
    # score; saving unconditionally would leave the latest epoch's weights
    # in best_path instead of the best ones.
    if args.save_model and best > prev_best:
        torch.save(model.state_dict(), best_path)
        print('Model best: {}\n'.format(best_path))

# Always keep the final weights as well, independent of the best checkpoint.
if args.save_model:
    torch.save(model.state_dict(), stop_path)
    print('Model stop: {}'.format(stop_path))
