In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from arguments import Arguments
from cnn import CNN
from distributor import get_distributed_data, get_distributed_data_using_loader
from distributor import get_fog_graph
import os
import syft as sy
import torch
from torchvision import datasets, transforms
from train import fl_train as train, fog_train
from train import test

Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/usr/local/lib/python3.6/dist-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.15.3.so'





In [3]:
# Setups
args = Arguments()
USE_CUDA = not args.no_cuda and torch.cuda.is_available()

# Seed torch's global RNG so model initialisation and DataLoader shuffling are reproducible.
torch.manual_seed(args.seed)
device = torch.device("cuda" if USE_CUDA else "cpu")

# Extra DataLoader options. The CUDA-specific set ({'num_workers': 1, 'pin_memory': True})
# used to be assigned and then immediately overwritten with {}, so it never took effect.
# The effective behaviour (no extra loader options) is kept here, stated explicitly.
# NOTE(review): confirm whether loader workers / pinned memory were disabled on purpose
# (e.g. for PySyft compatibility) before re-enabling them for CUDA runs.
kwargs = {}

In [4]:
# Checkpoint locations; the model name encodes dataset, classifier and training paradigm.
ckpt_path = '../ckpts'
dataset = 'mnist'
clf_type = 'cnn'
paradigm = 'fl'
model_name = '{}_{}_{}'.format(dataset, clf_type, paradigm)
# The initial-weights file always carries the 'fl' suffix regardless of `paradigm`
# (presumably so different paradigms start from the same init — confirm). Its dataset and
# classifier parts are derived from the settings above. Changing them then cannot silently
# load an MNIST-CNN init into a different model.
init_path = os.path.join(ckpt_path, '{}_{}_fl.init'.format(dataset, clf_type))
best_path = os.path.join(ckpt_path, model_name + '.best')
stop_path = os.path.join(ckpt_path, model_name + '.stop')


In [5]:
# Setup hook to support FL
hook = sy.TorchHook(torch)
# Define workers.
# NOTE(review): `workers` is rebound by get_fog_graph in the next cell. These VirtualWorkers
# are still constructed here in case construction registers them with the hook.
workers = [sy.VirtualWorker(hook, id=str(worker_id)) for worker_id in range(args.num_workers)]

# MNIST preprocessing (normalisation constants as used originally), shared by both splits.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True, transform=mnist_transform),
    batch_size=args.num_train, shuffle=False, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=mnist_transform),
    batch_size=args.num_test, shuffle=True, **kwargs)

# Materialise the first batch of each loader as in-memory tensors.
# The previous `for ... in loader` loop kept only the LAST batch. That batch is a partial
# remainder whenever num_train / num_test does not cover the whole split. Taking the first
# batch yields exactly num_train / num_test samples (or the whole split if it is smaller).
X_train, y_train = next(iter(train_loader))
X_test, y_test = next(iter(test_loader))

print('X_train: {}'.format(X_train.shape))
print('y_train: {}'.format(y_train.shape))

print('X_test: {}'.format(X_test.shape))
print('y_test: {}'.format(y_test.shape))


X_train: torch.Size([60000, 1, 28, 28])
y_train: torch.Size([60000])
X_test: torch.Size([10000, 1, 28, 28])
y_test: torch.Size([10000])


In [6]:
# Build the worker topology and split the training tensors across workers.
# fog=False requests the flat (non-fog) layout from get_fog_graph; the returned
# `workers` replaces the list created in the data-loading cell.
fog_graph, workers = get_fog_graph(
    hook,
    args.num_workers,
    args.num_clusters,
    args.shuffle_workers,
    args.uniform_clusters,
    fog=False,
)
X_trains, y_trains = get_distributed_data(X_train, y_train, args.num_workers)

print(fog_graph)

{'L1_W0': ['L0_W0', 'L0_W1', 'L0_W2', 'L0_W3', 'L0_W4', 'L0_W5', 'L0_W6', 'L0_W7', 'L0_W8', 'L0_W9', 'L0_W10', 'L0_W11', 'L0_W12', 'L0_W13', 'L0_W14', 'L0_W15', 'L0_W16', 'L0_W17', 'L0_W18', 'L0_W19', 'L0_W20', 'L0_W21', 'L0_W22', 'L0_W23', 'L0_W24', 'L0_W25', 'L0_W26', 'L0_W27', 'L0_W28', 'L0_W29', 'L0_W30', 'L0_W31', 'L0_W32', 'L0_W33', 'L0_W34', 'L0_W35', 'L0_W36', 'L0_W37', 'L0_W38', 'L0_W39', 'L0_W40', 'L0_W41', 'L0_W42', 'L0_W43', 'L0_W44', 'L0_W45', 'L0_W46', 'L0_W47', 'L0_W48', 'L0_W49']}


In [9]:
# Fire the engines: build the model and either load or save the shared initial weights.
model = CNN().to(device)
if args.load_init:
    # map_location lets an init saved on a GPU machine be loaded on a CPU-only one (and
    # vice versa); without it torch.load restores tensors onto the device they were saved from.
    model.load_state_dict(torch.load(init_path, map_location=device))
    print('Load init: {}'.format(init_path))
elif args.save_init:
    # Create the checkpoint directory on first use so the save cannot fail on a fresh checkout.
    os.makedirs(os.path.dirname(init_path), exist_ok=True)
    torch.save(model.state_dict(), init_path)
    print('Save init: {}'.format(init_path))


Load init: ../ckpts/mnist_cnn_fl.init


In [10]:
# Best test accuracy seen so far. It is also passed to test(), which reports it alongside
# the current accuracy, so it must be tracked even when checkpoints are not being saved.
best = 0

if args.save_model:
    # Make sure the checkpoint directory exists before the first best/stop save.
    os.makedirs(ckpt_path, exist_ok=True)

for epoch in range(1, args.epochs + 1):
    train(args, model, fog_graph, workers, X_trains, y_trains,
          device, epoch)
    acc = test(args, model, device, test_loader, best, epoch)

    # Previously `best` was only updated when save_model was set. With saving disabled,
    # test() was then always given best=0.
    if acc > best:
        best = acc
        if args.save_model:
            torch.save(model.state_dict(), best_path)
            print('Model best  @ {}, acc {}: {}\n'.format(epoch, acc, best_path))

if args.save_model:
    # Final-epoch weights, saved independently of the best checkpoint.
    torch.save(model.state_dict(), stop_path)
    print('Model stop: {}'.format(stop_path))


Model best  @ 1, acc 0.1793: ../ckpts/mnist_cnn_fl.best



Model best  @ 2, acc 0.3031: ../ckpts/mnist_cnn_fl.best



Model best  @ 3, acc 0.3945: ../ckpts/mnist_cnn_fl.best



Model best  @ 4, acc 0.4855: ../ckpts/mnist_cnn_fl.best



Model best  @ 5, acc 0.5699: ../ckpts/mnist_cnn_fl.best



Model best  @ 6, acc 0.6262: ../ckpts/mnist_cnn_fl.best



Model best  @ 7, acc 0.6646: ../ckpts/mnist_cnn_fl.best



Model best  @ 8, acc 0.6912: ../ckpts/mnist_cnn_fl.best



Model best  @ 9, acc 0.7114: ../ckpts/mnist_cnn_fl.best



Train Epoch: 10 	Loss: 2.041748 +- 0.003227



Test set: Average loss: 1.9826, Accuracy: 7249/10000 (72.49%) ==> 71.14%
Model best  @ 10, acc 0.7249: ../ckpts/mnist_cnn_fl.best



Model best  @ 11, acc 0.7348: ../ckpts/mnist_cnn_fl.best



Model best  @ 12, acc 0.739: ../ckpts/mnist_cnn_fl.best



Model best  @ 13, acc 0.7469: ../ckpts/mnist_cnn_fl.best



Model best  @ 14, acc 0.7579: ../ckpts/mnist_cnn_fl.best



Model best  @ 15, acc 0.7722: ../ckpts/mnist_cnn_fl.best



Model best  @ 16, acc 0.7878: ../ckpts/mnist_cnn_fl.best



Model best  @ 17, acc 0.8024: ../ckpts/mnist_cnn_fl.best



Model best  @ 18, acc 0.8122: ../ckpts/mnist_cnn_fl.best



Model best  @ 19, acc 0.8232: ../ckpts/mnist_cnn_fl.best



Train Epoch: 20 	Loss: 1.048532 +- 0.013975



Test set: Average loss: 0.9168, Accuracy: 8343/10000 (83.43%) ==> 82.32%
Model best  @ 20, acc 0.8343: ../ckpts/mnist_cnn_fl.best



Model best  @ 21, acc 0.8403: ../ckpts/mnist_cnn_fl.best



Train Epoch: 30 	Loss: 1.138466 +- 0.013150



Test set: Average loss: 0.9988, Accuracy: 7740/10000 (77.40%) ==> 84.03%


Model best  @ 33, acc 0.841: ../ckpts/mnist_cnn_fl.best



Model best  @ 34, acc 0.8619: ../ckpts/mnist_cnn_fl.best



Model best  @ 35, acc 0.8646: ../ckpts/mnist_cnn_fl.best



Model best  @ 36, acc 0.8695: ../ckpts/mnist_cnn_fl.best



Model best  @ 37, acc 0.8699: ../ckpts/mnist_cnn_fl.best



Model best  @ 38, acc 0.876: ../ckpts/mnist_cnn_fl.best



Train Epoch: 40 	Loss: 0.549937 +- 0.020649



Test set: Average loss: 0.8692, Accuracy: 6696/10000 (66.96%) ==> 87.60%


Model best  @ 47, acc 0.884: ../ckpts/mnist_cnn_fl.best



Model best  @ 48, acc 0.8924: ../ckpts/mnist_cnn_fl.best



Model best  @ 49, acc 0.8947: ../ckpts/mnist_cnn_fl.best



Train Epoch: 50 	Loss: 0.442206 +- 0.020063



Test set: Average loss: 0.4012, Accuracy: 8972/10000 (89.72%) ==> 89.47%
Model best  @ 50, acc 0.8972: ../ckpts/mnist_cnn_fl.best



Model best  @ 51, acc 0.8978: ../ckpts/mnist_cnn_fl.best



Model best  @ 52, acc 0.8993: ../ckpts/mnist_cnn_fl.best



Model best  @ 53, acc 0.9005: ../ckpts/mnist_cnn_fl.best



Model best  @ 54, acc 0.9022: ../ckpts/mnist_cnn_fl.best



Model best  @ 55, acc 0.9033: ../ckpts/mnist_cnn_fl.best



Model best  @ 56, acc 0.9045: ../ckpts/mnist_cnn_fl.best



Model best  @ 57, acc 0.9053: ../ckpts/mnist_cnn_fl.best



Model best  @ 58, acc 0.9079: ../ckpts/mnist_cnn_fl.best



Train Epoch: 60 	Loss: 0.346854 +- 0.022767



Test set: Average loss: 0.3237, Accuracy: 9097/10000 (90.97%) ==> 90.79%
Model best  @ 60, acc 0.9097: ../ckpts/mnist_cnn_fl.best



Model best  @ 67, acc 0.9144: ../ckpts/mnist_cnn_fl.best



Model best  @ 68, acc 0.9169: ../ckpts/mnist_cnn_fl.best



Model best  @ 69, acc 0.9194: ../ckpts/mnist_cnn_fl.best



Train Epoch: 70 	Loss: 0.319124 +- 0.021195



Test set: Average loss: 0.2919, Accuracy: 9208/10000 (92.08%) ==> 91.94%
Model best  @ 70, acc 0.9208: ../ckpts/mnist_cnn_fl.best



Model best  @ 71, acc 0.922: ../ckpts/mnist_cnn_fl.best



Model best  @ 72, acc 0.9228: ../ckpts/mnist_cnn_fl.best



Model best  @ 74, acc 0.9239: ../ckpts/mnist_cnn_fl.best



Model best  @ 75, acc 0.925: ../ckpts/mnist_cnn_fl.best



Model best  @ 76, acc 0.9251: ../ckpts/mnist_cnn_fl.best



Model best  @ 77, acc 0.9266: ../ckpts/mnist_cnn_fl.best



Model best  @ 79, acc 0.9277: ../ckpts/mnist_cnn_fl.best



Train Epoch: 80 	Loss: 0.276213 +- 0.022083



Test set: Average loss: 0.2542, Accuracy: 9281/10000 (92.81%) ==> 92.77%
Model best  @ 80, acc 0.9281: ../ckpts/mnist_cnn_fl.best



Model best  @ 81, acc 0.9293: ../ckpts/mnist_cnn_fl.best



Model best  @ 83, acc 0.9306: ../ckpts/mnist_cnn_fl.best



Model best  @ 85, acc 0.9322: ../ckpts/mnist_cnn_fl.best



Model best  @ 87, acc 0.9328: ../ckpts/mnist_cnn_fl.best



Train Epoch: 90 	Loss: 0.252137 +- 0.022114



Test set: Average loss: 0.2330, Accuracy: 9321/10000 (93.21%) ==> 93.28%


Model best  @ 91, acc 0.9329: ../ckpts/mnist_cnn_fl.best



Train Epoch: 100 	Loss: 0.513000 +- 0.029950



Test set: Average loss: 0.6507, Accuracy: 8212/10000 (82.12%) ==> 93.29%


Model best  @ 103, acc 0.9391: ../ckpts/mnist_cnn_fl.best



Model best  @ 105, acc 0.9408: ../ckpts/mnist_cnn_fl.best



Model best  @ 109, acc 0.9413: ../ckpts/mnist_cnn_fl.best



Train Epoch: 110 	Loss: 0.221271 +- 0.019817



Test set: Average loss: 0.2024, Accuracy: 9416/10000 (94.16%) ==> 94.13%
Model best  @ 110, acc 0.9416: ../ckpts/mnist_cnn_fl.best



Model best  @ 111, acc 0.9422: ../ckpts/mnist_cnn_fl.best



Model best  @ 112, acc 0.9428: ../ckpts/mnist_cnn_fl.best



Model best  @ 114, acc 0.9436: ../ckpts/mnist_cnn_fl.best



Model best  @ 115, acc 0.9439: ../ckpts/mnist_cnn_fl.best



Model best  @ 116, acc 0.9441: ../ckpts/mnist_cnn_fl.best



Model best  @ 117, acc 0.9448: ../ckpts/mnist_cnn_fl.best



Model best  @ 118, acc 0.945: ../ckpts/mnist_cnn_fl.best



Model best  @ 119, acc 0.9452: ../ckpts/mnist_cnn_fl.best



Train Epoch: 120 	Loss: 0.203128 +- 0.019864



Test set: Average loss: 0.1858, Accuracy: 9458/10000 (94.58%) ==> 94.52%
Model best  @ 120, acc 0.9458: ../ckpts/mnist_cnn_fl.best



Model best  @ 122, acc 0.9461: ../ckpts/mnist_cnn_fl.best



Model best  @ 123, acc 0.9464: ../ckpts/mnist_cnn_fl.best



Model best  @ 124, acc 0.9468: ../ckpts/mnist_cnn_fl.best



Model best  @ 125, acc 0.9474: ../ckpts/mnist_cnn_fl.best



Model best  @ 126, acc 0.9475: ../ckpts/mnist_cnn_fl.best



Model best  @ 127, acc 0.9477: ../ckpts/mnist_cnn_fl.best



Model best  @ 128, acc 0.9478: ../ckpts/mnist_cnn_fl.best



Model best  @ 129, acc 0.9483: ../ckpts/mnist_cnn_fl.best



Train Epoch: 130 	Loss: 0.189371 +- 0.019473



Test set: Average loss: 0.1729, Accuracy: 9490/10000 (94.90%) ==> 94.83%
Model best  @ 130, acc 0.949: ../ckpts/mnist_cnn_fl.best



Model best  @ 131, acc 0.9492: ../ckpts/mnist_cnn_fl.best



Model best  @ 132, acc 0.9496: ../ckpts/mnist_cnn_fl.best



Model best  @ 133, acc 0.9497: ../ckpts/mnist_cnn_fl.best



Model best  @ 134, acc 0.9503: ../ckpts/mnist_cnn_fl.best



Model best  @ 135, acc 0.9504: ../ckpts/mnist_cnn_fl.best



Model best  @ 136, acc 0.9507: ../ckpts/mnist_cnn_fl.best



Model best  @ 137, acc 0.9509: ../ckpts/mnist_cnn_fl.best



Model best  @ 138, acc 0.951: ../ckpts/mnist_cnn_fl.best



Model best  @ 139, acc 0.9514: ../ckpts/mnist_cnn_fl.best



Train Epoch: 140 	Loss: 0.177810 +- 0.018952



Test set: Average loss: 0.1620, Accuracy: 9518/10000 (95.18%) ==> 95.14%
Model best  @ 140, acc 0.9518: ../ckpts/mnist_cnn_fl.best



Model best  @ 141, acc 0.9525: ../ckpts/mnist_cnn_fl.best



Model best  @ 142, acc 0.9528: ../ckpts/mnist_cnn_fl.best



Model best  @ 144, acc 0.9529: ../ckpts/mnist_cnn_fl.best



Model best  @ 145, acc 0.9531: ../ckpts/mnist_cnn_fl.best



Model best  @ 146, acc 0.9534: ../ckpts/mnist_cnn_fl.best



Model best  @ 147, acc 0.9538: ../ckpts/mnist_cnn_fl.best



Model best  @ 148, acc 0.9539: ../ckpts/mnist_cnn_fl.best



Model best  @ 149, acc 0.9541: ../ckpts/mnist_cnn_fl.best



Train Epoch: 150 	Loss: 0.167775 +- 0.018418



Test set: Average loss: 0.1526, Accuracy: 9544/10000 (95.44%) ==> 95.41%
Model best  @ 150, acc 0.9544: ../ckpts/mnist_cnn_fl.best



Model best  @ 151, acc 0.9548: ../ckpts/mnist_cnn_fl.best



Model best  @ 152, acc 0.9553: ../ckpts/mnist_cnn_fl.best



Model best  @ 153, acc 0.9556: ../ckpts/mnist_cnn_fl.best



Model best  @ 154, acc 0.956: ../ckpts/mnist_cnn_fl.best



Model best  @ 155, acc 0.9564: ../ckpts/mnist_cnn_fl.best



Model best  @ 156, acc 0.9566: ../ckpts/mnist_cnn_fl.best



Model best  @ 157, acc 0.957: ../ckpts/mnist_cnn_fl.best



Model best  @ 158, acc 0.9575: ../ckpts/mnist_cnn_fl.best



Train Epoch: 160 	Loss: 0.158931 +- 0.017911



Test set: Average loss: 0.1443, Accuracy: 9577/10000 (95.77%) ==> 95.75%
Model best  @ 160, acc 0.9577: ../ckpts/mnist_cnn_fl.best



Model best  @ 161, acc 0.9579: ../ckpts/mnist_cnn_fl.best



Model best  @ 162, acc 0.958: ../ckpts/mnist_cnn_fl.best



Model best  @ 163, acc 0.9583: ../ckpts/mnist_cnn_fl.best



Model best  @ 164, acc 0.9584: ../ckpts/mnist_cnn_fl.best



Model best  @ 166, acc 0.9585: ../ckpts/mnist_cnn_fl.best



Model best  @ 167, acc 0.9587: ../ckpts/mnist_cnn_fl.best



Model best  @ 168, acc 0.9589: ../ckpts/mnist_cnn_fl.best



Train Epoch: 170 	Loss: 0.151056 +- 0.017432



Test set: Average loss: 0.1369, Accuracy: 9590/10000 (95.90%) ==> 95.89%
Model best  @ 170, acc 0.959: ../ckpts/mnist_cnn_fl.best



Model best  @ 171, acc 0.9591: ../ckpts/mnist_cnn_fl.best



Model best  @ 172, acc 0.9597: ../ckpts/mnist_cnn_fl.best



Model best  @ 173, acc 0.9601: ../ckpts/mnist_cnn_fl.best



Model best  @ 174, acc 0.9602: ../ckpts/mnist_cnn_fl.best



Model best  @ 175, acc 0.9603: ../ckpts/mnist_cnn_fl.best



Model best  @ 177, acc 0.9604: ../ckpts/mnist_cnn_fl.best



Model best  @ 179, acc 0.9607: ../ckpts/mnist_cnn_fl.best



Train Epoch: 180 	Loss: 0.144006 +- 0.016987



Test set: Average loss: 0.1304, Accuracy: 9610/10000 (96.10%) ==> 96.07%
Model best  @ 180, acc 0.961: ../ckpts/mnist_cnn_fl.best



Model best  @ 181, acc 0.9615: ../ckpts/mnist_cnn_fl.best



Model best  @ 182, acc 0.9616: ../ckpts/mnist_cnn_fl.best



Model best  @ 183, acc 0.9619: ../ckpts/mnist_cnn_fl.best



Model best  @ 185, acc 0.9624: ../ckpts/mnist_cnn_fl.best



Model best  @ 187, acc 0.9625: ../ckpts/mnist_cnn_fl.best



Model best  @ 188, acc 0.9628: ../ckpts/mnist_cnn_fl.best



Model best  @ 189, acc 0.9631: ../ckpts/mnist_cnn_fl.best



Train Epoch: 190 	Loss: 0.137672 +- 0.016559



Test set: Average loss: 0.1245, Accuracy: 9628/10000 (96.28%) ==> 96.31%


Model best  @ 191, acc 0.9635: ../ckpts/mnist_cnn_fl.best



Model best  @ 193, acc 0.9639: ../ckpts/mnist_cnn_fl.best



Model best  @ 194, acc 0.9641: ../ckpts/mnist_cnn_fl.best



Model best  @ 195, acc 0.9643: ../ckpts/mnist_cnn_fl.best



Model best  @ 197, acc 0.9652: ../ckpts/mnist_cnn_fl.best



Model best  @ 199, acc 0.9655: ../ckpts/mnist_cnn_fl.best



Train Epoch: 200 	Loss: 0.133071 +- 0.016108



Test set: Average loss: 0.1211, Accuracy: 9634/10000 (96.34%) ==> 96.55%
Model stop: ../ckpts/mnist_cnn_fl.stop
