# MNIST training A Bigger Model

In [1]:
import torch
import torchvision

import numpy as np
import math

In [2]:
mnist_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))
mnist_testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

In [3]:
## Parameters:
n_epochs = 3
batch_size_train = 10000
batch_size_test = 500
log_interval = 500

In [4]:
train_loader = torch.utils.data.DataLoader(mnist_trainset,batch_size=batch_size_train, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset,batch_size=batch_size_test, shuffle=False)

In [5]:
import papayaclient

In [6]:
class TheModel(torch.nn.Module):

    def __init__(self):
        super(TheModel, self).__init__()

        self.linear1 = torch.nn.Linear(784, 49)
        self.linear2 = torch.nn.Linear(49, 49)
        self.linear3 = torch.nn.Linear(49, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x1 = x.flatten(start_dim = 1)
        return self.linear3(self.relu(self.linear2(self.relu(self.linear1(x1)))))

In [7]:
clients = []
for batchno, (ex_data, ex_labels) in enumerate(train_loader):
    clients.append(papayaclient.PapayaClient(dat = ex_data,
                                            labs = ex_labels,
                                            batch_sz = 500,
                                            num_partners = 5,
                                            model_class = TheModel,
                                            loss_fn = torch.nn.CrossEntropyLoss))


In [8]:
## Train the Nodes
num_epochs_total = 100
num_epochs_per_swap = 5
num_times = (num_epochs_total // num_epochs_per_swap)
for i in range(0, num_times):
    for n in clients:
        for j in range(0, num_epochs_per_swap):
            n.model_train_epoch()
            print(n.logs['stringy'][n.epochs_trained - 1])
    if i > 1 and i < num_times - 1 :
        for n in clients:
            n.select_partners(3)
        for n in clients:
            for i in range(0, 4) :
                n.update_partner_weights()
            n.average_partners()

node3010epoch 0 loss 2.261340856552124
node3010epoch 1 loss 2.212891101837158
node3010epoch 2 loss 2.1522250175476074
node3010epoch 3 loss 2.0761609077453613
node3010epoch 4 loss 1.9814214706420898
node3297epoch 0 loss 2.2594196796417236
node3297epoch 1 loss 2.2021450996398926
node3297epoch 2 loss 2.1328916549682617
node3297epoch 3 loss 2.0473015308380127
node3297epoch 4 loss 1.9438883066177368
node3400epoch 0 loss 2.2872397899627686
node3400epoch 1 loss 2.246544361114502
node3400epoch 2 loss 2.1938352584838867
node3400epoch 3 loss 2.1254990100860596
node3400epoch 4 loss 2.038665533065796
node347epoch 0 loss 2.2778453826904297
node347epoch 1 loss 2.2379350662231445
node347epoch 2 loss 2.1820764541625977
node347epoch 3 loss 2.1071786880493164
node347epoch 4 loss 2.0125720500946045
node4922epoch 0 loss 2.2780885696411133
node4922epoch 1 loss 2.2403464317321777
node4922epoch 2 loss 2.1939237117767334
node4922epoch 3 loss 2.13490891456604
node4922epoch 4 loss 2.058116912841797
node2716epoc

In [9]:
for c in clients :
    print(c.logs['stringy'][99])

node3010epoch 99 loss 0.3654802143573761
node3297epoch 99 loss 0.3899177014827728
node3400epoch 99 loss 0.3436969220638275
node347epoch 99 loss 0.3100656569004059
node4922epoch 99 loss 0.381700336933136
node2716epoch 99 loss 0.364274263381958


In [10]:
accuracies = {}
with torch.no_grad():
    for i in clients :
        accuracies_node = []
        for batchno, (ex_data, ex_labels) in enumerate(test_loader) :
            accuracies_node.append(((i.model.forward(ex_data).argmax(dim = 1) == ex_labels).float().mean()).item())
        accuracies[i.node_id] = np.array(accuracies_node).mean()

In [11]:
accuracies

{3010: 0.9030000001192093,
 3297: 0.9026999980211258,
 3400: 0.9018999993801117,
 347: 0.9,
 4922: 0.900900000333786,
 2716: 0.9022999972105026}