In [36]:
import torchvision.datasets as datasets
# mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)

In [37]:
import torch as th
import syft as sy

# Hook PySyft into torch (imported as `th`). Later cells rely on the methods this
# makes available on tensors/modules/datasets: `.send`, `.get`, `.location`, `.federate`.
hook = sy.TorchHook(th)



In [38]:
# Create the simulated workers: `bob` and `alice` receive shards of the training
# data below; `trusted_agg` is created here but is not used by the training loop.
bob, alice, trusted_agg = (
    sy.VirtualWorker(hook, id=worker_id)
    for worker_id in ("bob", "alice", "trusted_agg")
)

# Drop any objects left on the workers from a previous run of this cell.
for worker in (bob, alice):
    worker.clear_objects()
trusted_agg.clear_objects()

<VirtualWorker id:trusted_agg #objects:0>

In [39]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = th.device("cuda" if th.cuda.is_available() else "cpu")

In [40]:
import torchvision.transforms as transforms

# One preprocessing pipeline shared by train and test: convert to a tensor,
# then normalise (the constants are presumably the MNIST mean/std).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training set: sharded across bob and alice via `.federate`, so each batch
# from the federated loader points at data held by one of those workers.
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)
federated_train_loader = sy.FederatedDataLoader(
    mnist_train.federate((bob, alice)),
    batch_size=64,
    shuffle=True,
)

# Test set: kept local and evaluated directly in this process.
mnist_test = datasets.MNIST(root='./data', train=False, transform=mnist_transform)
test_loader = th.utils.data.DataLoader(mnist_test, batch_size=64, shuffle=True)

In [41]:
# bob_train_dataset = sy.BaseDataset(train_data[:train_idx], train_labels[:train_idx]).send(bob)
# bob_test_dataset = sy.BaseDataset(test_data[:test_idx], test_labels[:test_idx]).send(bob)

# alice_train_dataset = sy.BaseDataset(train_data[train_idx:], train_labels[train_idx:]).send(alice)
# alice_test_dataset = sy.BaseDataset(test_data[test_idx:], test_labels[test_idx:]).send(alice)


In [42]:
# federated_train_dataset = sy.FederatedDataset([bob_train_dataset, alice_train_dataset])
# federated_test_dataset = sy.FederatedDataset([bob_test_dataset, alice_test_dataset])

# federated_train_dataloader = sy.FederatedDataLoader(federated_train_dataset, shuffle=True, batch_size = 64)
# federated_test_dataloader = sy.FederatedDataLoader(federated_test_dataset, shuffle=True, batch_size = 64)

In [43]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    """Small CNN classifier: two conv/pool stages followed by two linear layers.

    The flatten to 320 features (= 20 channels * 4 * 4) assumes single-channel
    28x28 inputs. `forward` returns log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 1 -> 10 -> 20 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Channel-wise dropout on conv2's feature maps.
        self.conv2_drop = nn.Dropout2d()
        # Fully connected head.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return log-softmax scores of shape (batch, 10)."""
        features = F.relu(F.max_pool2d(self.conv1(x), 2))
        features = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(features)), 2))
        flat = features.view(-1, 320)
        # Dropout is only active in training mode.
        hidden = F.dropout(F.relu(self.fc1(flat)), training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

In [44]:
# Build the CNN and move its parameters onto `device` (nn.Module.to moves the module in place).
model = Net()
model.to(device)

In [45]:
def train(model, federate_train_loader, optimizer, epoch, log_interval=10, batch_size=64):
    """Run one epoch of federated training.

    For each batch, the model is sent to the worker that holds the batch
    (`data.location`), one optimizer step is taken there, and the model is
    retrieved again before the next batch.

    Args:
        model: the `nn.Module` being trained.
        federate_train_loader: a `sy.FederatedDataLoader` yielding remote batches.
        optimizer: optimizer over `model.parameters()`.
        epoch: epoch number; used only in the progress message.
        log_interval: print the loss every `log_interval` batches.
        batch_size: batch size of `federate_train_loader`; used only to report
            the approximate number of samples processed in the progress message.
    """
    model.train()
    n_batches = len(federate_train_loader)
    for batch_idx, (data, target) in enumerate(federate_train_loader):
        # Move the model to the worker that owns this batch.
        model.send(data.location)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # Bring the updated model back; the next batch may live on another worker.
        model.get()
        if batch_idx % log_interval == 0:
            # `loss` still refers to the remote value; fetch it for logging.
            loss = loss.get()
            # Use the loader passed in (not a global) to compute progress.
            print(f'Train Epoch: {epoch} [{batch_idx * batch_size}/{n_batches * batch_size} '
                  f'({100. * batch_idx / n_batches:.0f}%)]\tLoss: {loss.item():.6f}')

In [48]:
def test(model, test_loader):
    """Evaluate `model` on `test_loader` and print average loss and accuracy.

    Loss is summed per batch and divided by the dataset size, so the reported
    value is the mean negative log-likelihood per sample.
    """
    model.eval()
    n_samples = len(test_loader.dataset)
    total_loss = 0
    n_correct = 0
    with th.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            log_probs = model(data)
            # Sum over the batch so dividing by n_samples yields a per-sample mean.
            total_loss += F.nll_loss(log_probs, target, reduction='sum').item()
            predictions = log_probs.argmax(1, keepdim=True)
            n_correct += predictions.eq(target.view_as(predictions)).sum().item()

    avg_loss = total_loss / n_samples
    accuracy_pct = 100. * n_correct / n_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        avg_loss, n_correct, n_samples, accuracy_pct))

In [None]:
N_EPOCHS = 20  # number of passes over the federated training set

optimizer = optim.SGD(model.parameters(), lr=0.01)
# Separate constant for the epoch count so it is not shadowed by the loop variable.
for epoch in range(1, N_EPOCHS + 1):
    train(model, federated_train_loader, optimizer, epoch)
    test(model, test_loader)

# Save only the learned weights; `train` retrieves the model after every batch,
# so the parameters are local here.
th.save(model.state_dict(), "mnist_cnn.pt")


Test set: Average loss: 0.1776, Accuracy: 9488/10000 (95%)




Test set: Average loss: 0.1351, Accuracy: 9594/10000 (96%)




Test set: Average loss: 0.1143, Accuracy: 9637/10000 (96%)




Test set: Average loss: 0.0985, Accuracy: 9695/10000 (97%)




Test set: Average loss: 0.0906, Accuracy: 9700/10000 (97%)




Test set: Average loss: 0.0841, Accuracy: 9728/10000 (97%)




Test set: Average loss: 0.0786, Accuracy: 9741/10000 (97%)




Test set: Average loss: 0.0726, Accuracy: 9759/10000 (98%)




Test set: Average loss: 0.0681, Accuracy: 9774/10000 (98%)


Test set: Average loss: 0.0650, Accuracy: 9789/10000 (98%)




Test set: Average loss: 0.0650, Accuracy: 9800/10000 (98%)




Test set: Average loss: 0.0591, Accuracy: 9812/10000 (98%)




Test set: Average loss: 0.0574, Accuracy: 9827/10000 (98%)




Test set: Average loss: 0.0562, Accuracy: 9819/10000 (98%)





In [None]:
import matplotlib.pyplot as plt

# Raw (untransformed) training images for a single-sample sanity check. The
# definition in the first cell is commented out, so create it here.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)

SAMPLE_IDX = 52009
image = mnist_trainset.data[SAMPLE_IDX]  # 28x28 pixels, values 0-255

# Reproduce the loaders' preprocessing (ToTensor scales to [0, 1], then
# Normalize with the same constants). The conv layers need an input of shape
# (batch, channels, 28, 28), not a flat 784-vector.
img = ((image.float() / 255.0 - 0.1307) / 0.3081).view(1, 1, 28, 28)

model.eval()
with th.no_grad():
    logps = model(img.to(device))

ps = th.exp(logps).cpu()
# argmax over the class dimension; `ps` has a leading batch dimension, so
# list(ps.numpy()).index(max(...)) always returned 0.
predicted_digit = ps.argmax(dim=1).item()
print("Predicted Digit =", predicted_digit)
print("True Digit =", int(mnist_trainset.targets[SAMPLE_IDX]))
plt.imshow(image.numpy(), cmap="gray")

In [None]:
# sy.BaseDataset needs data and targets (as in the commented-out manual-split
# cells); calling it with no arguments raises TypeError. Minimal valid example
# built from a few local MNIST test images:
example_dataset = sy.BaseDataset(test_loader.dataset.data[:8], test_loader.dataset.targets[:8])
example_dataset