In [1]:
import torch
from torchvision.datasets import MNIST
from torch.utils.data import random_split, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

import pickle
import base64
import sys

import hashlib
import json
from time import time
from time import sleep
import requests

%matplotlib inline
plt.rcParams['figure.figsize'] = [5, 5]

In [2]:
train_dataset = MNIST('/kaggle/working', train=True, download=True, transform=transforms.ToTensor())
test_dataset = MNIST('/kaggle/working', train=False, download=True, transform=transforms.ToTensor())

train_dataset, dev_dataset = random_split(train_dataset, [int(len(train_dataset) * 0.83), int(len(train_dataset) * 0.17)])

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /kaggle/working/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [04:41<00:00, 35222.05it/s]


Extracting /kaggle/working/MNIST/raw/train-images-idx3-ubyte.gz to /kaggle/working/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /kaggle/working/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 937024.33it/s]


Extracting /kaggle/working/MNIST/raw/train-labels-idx1-ubyte.gz to /kaggle/working/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /kaggle/working/MNIST/raw/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /kaggle/working/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 9788296.32it/s]


Extracting /kaggle/working/MNIST/raw/t10k-images-idx3-ubyte.gz to /kaggle/working/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /kaggle/working/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 10214760.73it/s]


Extracting /kaggle/working/MNIST/raw/t10k-labels-idx1-ubyte.gz to /kaggle/working/MNIST/raw



In [3]:
total_train_size = len(train_dataset)
total_test_size = len(test_dataset)
total_dev_size = len(dev_dataset)

classes = 10
input_dim = 784

batch_size = 128
epochs_per_client = 3
learning_rate = 2e-2

In [16]:
# CONFIGURE BEFORE RUNNIG
# MAKE SURE SAME FOR BOTH SERVER AND ALL CLIENTS' CODE
num_clients = 2
rounds = 1

In [5]:
# CONFIGURE CLIENT ID TO CONTAIN ALL POSSIBLE CLIENT IDs
# Ex. If there are 3 clients, the only possible client ids are 0, 1, 2
# The client_id variable will therefore be equal to [0,1,2]
client_id = [0,1]

In [6]:
client_code = [f'L{client_id[i]+1}' for i in client_id]

In [8]:
# Base address of the blockchain server coordinating the federated rounds;
# every endpoint below is a path on it.
SERVER_URL = 'https://2347b804-83c2-481a-9697-c5f6de62b5ab-00-2qo920vvfssrt.spock.replit.dev'

CHAIN_URL = f'{SERVER_URL}/chain'
MINE_URL = f'{SERVER_URL}/mine'
SEND_TRNS_URL = f'{SERVER_URL}/transactions/new'   # local updates are POSTed here
GET_TRNS_URL = f'{SERVER_URL}/current'
PREV_BLK_URL = f'{SERVER_URL}/last'                # polled for the latest block / global model

In [9]:
# Sanity check of the split sizes, displayed as (train, dev, test).
(total_train_size, total_dev_size, total_test_size)

(49800, 10200, 10000)

In [10]:
def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    use_cuda = torch.cuda.is_available()
    return torch.device('cuda' if use_cuda else 'cpu')

def to_device(data, device):
    """Move a tensor/module, or nested lists/tuples of them, to `device`.

    Lists and tuples are rebuilt as lists; every leaf is moved with
    `.to(device, non_blocking=True)`.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    return [to_device(item, device) for item in data]

class DeviceDataLoader(DataLoader):
    """Iterate over a wrapped loader, moving every batch to `device` with to_device.

    NOTE(review): subclasses DataLoader but never calls DataLoader.__init__, so
    only iteration and len() are meaningful; inherited DataLoader attributes
    (e.g. .dataset) are presumably unset on instances.
    """

    def __init__(self, dl, device):
        self.device = device
        self.dl = dl

    def __len__(self):
        # Same number of batches as the wrapped loader.
        return len(self.dl)

    def __iter__(self):
        yield from (to_device(batch, self.device) for batch in self.dl)

# Device read by FederatedNet.fit / FederatedNet.evaluate and Client.train
# (CUDA when available, otherwise CPU).
device = get_device()

In [11]:
class FederatedNet(torch.nn.Module):
    """Small CNN whose conv/linear parameters are exchanged in federated rounds.

    Sized for single-channel 28x28 inputs (MNIST): two 7x7 convolutions shrink
    28x28 to 16x16, 2x2 max-pooling gives 8x8, so the flattened feature size is
    40 * 8 * 8 = 2560, mapped to 10 output logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 7)
        self.conv2 = torch.nn.Conv2d(20, 40, 7)
        self.maxpool = torch.nn.MaxPool2d(2, 2)
        self.flatten = torch.nn.Flatten()
        self.linear = torch.nn.Linear(2560, 10)
        self.non_linearity = torch.nn.functional.relu
        # Layers whose weight/bias are shared with the server (get/apply_parameters).
        self.track_layers = {'conv1': self.conv1, 'conv2': self.conv2, 'linear': self.linear}

    def forward(self, x_batch):
        """Return raw class logits for a batch of images."""
        out = self.conv1(x_batch)
        out = self.non_linearity(out)
        out = self.conv2(out)
        out = self.non_linearity(out)
        out = self.maxpool(out)
        out = self.flatten(out)
        out = self.linear(out)
        return out

    def get_track_layers(self):
        """Return the {name: layer} dict of layers exchanged with the server."""
        return self.track_layers

    def apply_parameters(self, parameters_dict):
        """Overwrite tracked layers' weight and bias in place.

        Args:
            parameters_dict: {layer_name: {'weight': Tensor, 'bias': Tensor}} for
                some or all of the tracked layers.
        """
        with torch.no_grad():
            for layer_name, layer_params in parameters_dict.items():
                layer = self.track_layers[layer_name]
                # copy_ overwrites outright. Zero-then-add would keep NaN/inf already
                # present in the weights (nan * 0 == nan), and it requires the incoming
                # tensors to share this layer's device; copy_ converts device/dtype.
                layer.weight.copy_(layer_params['weight'])
                layer.bias.copy_(layer_params['bias'])

    def get_parameters(self):
        """Return a snapshot {layer_name: {'weight': Tensor, 'bias': Tensor}}.

        Tensors are detached clones, so further training or apply_parameters
        calls on this network do not alter a snapshot already handed out.
        """
        return {
            layer_name: {
                'weight': layer.weight.detach().clone(),
                'bias': layer.bias.detach().clone(),
            }
            for layer_name, layer in self.track_layers.items()
        }

    def batch_accuracy(self, outputs, labels):
        """Fraction of rows whose arg-max logit equals the label, as a 0-d tensor."""
        with torch.no_grad():
            _, predictions = torch.max(outputs, dim=1)
            return torch.tensor(torch.sum(predictions == labels).item() / len(predictions))

    def _process_batch(self, batch):
        """Run one (images, labels) batch; return (cross-entropy loss, accuracy)."""
        images, labels = batch
        outputs = self(images)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        accuracy = self.batch_accuracy(outputs, labels)
        return (loss, accuracy)

    def fit(self, dataset, epochs, lr, batch_size=128, opt=torch.optim.SGD):
        """Train in place on `dataset`.

        Uses the notebook-level `device` and `DeviceDataLoader`.

        Returns:
            list of (mean loss, mean accuracy) tuples, one per epoch.
        """
        dataloader = DeviceDataLoader(DataLoader(dataset, batch_size, shuffle=True), device)
        optimizer = opt(self.parameters(), lr)
        history = []
        for _ in range(epochs):
            losses = []
            accs = []
            for batch in dataloader:
                loss, acc = self._process_batch(batch)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                # detach() returns a new tensor; store that value so the per-epoch
                # stack/mean below does not build on the training graph's nodes.
                losses.append(loss.detach())
                accs.append(acc)
            avg_loss = torch.stack(losses).mean().item()
            avg_acc = torch.stack(accs).mean().item()
            history.append((avg_loss, avg_acc))
        return history

    def evaluate(self, dataset, batch_size=128):
        """Return (mean loss, mean accuracy) over `dataset` without updating weights.

        Uses the notebook-level `device` and `DeviceDataLoader`.
        """
        dataloader = DeviceDataLoader(DataLoader(dataset, batch_size), device)
        losses = []
        accs = []
        with torch.no_grad():
            for batch in dataloader:
                loss, acc = self._process_batch(batch)
                losses.append(loss)
                accs.append(acc)
        avg_loss = torch.stack(losses).mean().item()
        avg_acc = torch.stack(accs).mean().item()
        return (avg_loss, avg_acc)

In [12]:
class Client:
    """One federated participant holding a private shard of the training data."""

    def __init__(self, client_id, dataset):
        self.client_id, self.dataset = client_id, dataset

    def get_dataset_size(self):
        """Number of examples in this client's shard."""
        return len(self.dataset)

    def get_client_id(self):
        """Identifier printed in this client's training log line."""
        return self.client_id

    def train(self, parameters_dict):
        """Train a fresh FederatedNet, initialised from `parameters_dict`, on the local shard.

        Reads the notebook-level `device`, `epochs_per_client`, `learning_rate`
        and `batch_size`; prints the final epoch's loss/accuracy and returns the
        trained network's parameters.
        """
        local_net = to_device(FederatedNet(), device)
        local_net.apply_parameters(parameters_dict)
        history = local_net.fit(self.dataset, epochs_per_client, learning_rate, batch_size)
        last_loss, last_acc = history[-1]
        print(f'{self.client_id}: Loss = {round(last_loss, 4)}, Accuracy = {round(last_acc, 4)}')
        return local_net.get_parameters()

In [13]:
def compress_params(params):
    """Pickle `params` and return the bytes as an ASCII base64 string (safe to embed in JSON)."""
    return base64.b64encode(pickle.dumps(params)).decode('ascii')

def decompress_params(params):
    """Inverse of compress_params: base64-decode an ASCII string and unpickle the result.

    SECURITY: pickle.loads can execute arbitrary code. This notebook applies it
    to payloads fetched from the chain server, so only point it at a server you
    trust.
    """
    raw_bytes = base64.b64decode(params.encode("ascii"))
    return pickle.loads(raw_bytes)

In [14]:
# Shard the training split across clients: exactly num_clients shards whose sizes
# differ by at most one and sum to total_train_size. A fixed-step split would add a
# tiny extra shard whenever num_clients does not divide the size, and no client
# would ever train on those examples.
examples_per_client, remainder = divmod(total_train_size, num_clients)
shard_sizes = [examples_per_client + (1 if i < remainder else 0) for i in range(num_clients)]
client_datasets = random_split(train_dataset, shard_sizes)
clients = [Client('client_' + str(i), client_datasets[i]) for i in range(num_clients)]

In [17]:
# Federated rounds: wait for a global model on the chain, train each local client
# from it, and post every client's parameters as a 'local' transaction.
POLL_SECONDS = 45   # delay between polls of the chain server
HTTP_TIMEOUT = 30   # seconds; requests has no default timeout and could hang forever


def _fetch_last_block():
    """GET PREV_BLK_URL and return the 'chain' object of the JSON body (raises on HTTP errors)."""
    resp = requests.get(PREV_BLK_URL, timeout=HTTP_TIMEOUT)
    resp.raise_for_status()
    return resp.json()['chain']


def _has_global_model(block):
    """True when the block's second-to-last transaction has type 'global'."""
    transactions = block['transactions']
    return len(transactions) >= 2 and transactions[-2]['type'] == 'global'


prev_block_index = None
for i in range(rounds):
    block = _fetch_last_block()
    # Re-checked on every poll. Round 0 accepts any block carrying a global model.
    # Later rounds also need a block newer than the one the previous round trained
    # from; otherwise the same (stale) global model would be trained on again.
    while not (_has_global_model(block) and (i == 0 or block['index'] != prev_block_index)):
        sleep(POLL_SECONDS)
        block = _fetch_last_block()

    block_index = block['index']
    prev_block_index = block_index

    # NOTE(review): the global transaction is read via 'gradients' while the local
    # transactions posted below use 'gradient'; confirm against the server schema.
    curr_global_gradients = decompress_params(block['transactions'][-2]['gradients'])

    for j in client_id:
        client_parameters = clients[j].train(curr_global_gradients)
        client_data_size = clients[j].get_dataset_size()
        client_parameters_compressed = compress_params(client_parameters)
        new_transaction = {
            'type': 'local',
            'trainer': client_code[j],
            'gradient': client_parameters_compressed,
            'data-size': client_data_size
        }

        # Send the transaction to the server; fail loudly instead of continuing
        # silently if the server rejects it.
        response = requests.post(SEND_TRNS_URL, json=new_transaction, timeout=HTTP_TIMEOUT)
        response.raise_for_status()
        print(response.json())

client_0: Loss = 0.2074, Accuracy = 0.9411
{'message': 'Transaction will be added to Block 3'}
client_1: Loss = 0.2047, Accuracy = 0.9395
{'message': 'Transaction will be added to Block 3'}
