In [None]:
from torchvision import datasets,transforms
import torch
import random
import torch.nn.functional as f
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

In [None]:
def get_mnist_dataset(batch_size=32, data_dir='../data'):
    """Build the MNIST train/test DataLoaders used throughout this notebook.

    Both splits share one preprocessing pipeline: ``ToTensor`` followed by
    ``Normalize((0.1307,), (0.3081,))`` (presumably the MNIST training-set
    mean/std — confirm if the dataset changes).

    Args:
        batch_size: number of samples per batch for both loaders.
        data_dir: root directory handed to ``datasets.MNIST``; files are
            downloaded there if they are not already present.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles.
    """
    # One pipeline for both splits so train and test preprocessing cannot drift apart.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_dataset = datasets.MNIST(data_dir, train=True, download=True,
                                   transform=mnist_transform)
    # download=True here too, so the test split does not rely on the train
    # call above having fetched the files first.
    test_dataset = datasets.MNIST(data_dir, train=False, download=True,
                                  transform=mnist_transform)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [None]:
def train_single_model(model,optimiser,device,train_loader,epoch):
    model.train()
    criterion = nn.CrossEntropyLoss()
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):        
        data, target = data.to(device), target.to(device)
        optimiser.zero_grad()
        output = model(data)
        loss = criterion(output,target)
        loss.backward()
        optimiser.step()

        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        ### printing logs
        if batch_idx % 300 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    
    print('\nTrain set for client: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))

In [None]:
def test_single_model(model,device,test_loader):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
        
        test_loss /= len(test_loader.dataset)
        history = print('\nTest set for client: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))

In [None]:
import torch.nn as nn

class Flatten(nn.Module):
    """Merge every dimension after the first (batch) dimension into one.

    An input of shape ``(B, d1, d2, ...)`` comes out as ``(B, d1*d2*...)``.
    Implemented with ``Tensor.view``, so the input must be viewable in that
    shape (e.g. contiguous).
    """

    def forward(self, x):
        return x.view(x.shape[0], -1)

def lenet(input_channels, output_channels, with_softmax=False):
    """Build the full (unsplit) LeNet-style classifier.

    conv5x5(6) -> ReLU -> maxpool2 -> conv5x5(16) -> ReLU -> maxpool2
    -> Flatten -> Linear(256, 120) -> ReLU -> Linear(120, 84) -> ReLU
    -> Linear(84, output_channels).
    The hard-coded 256 (= 16 * 4 * 4) only fits 28x28 inputs such as MNIST.

    By default the network returns raw logits. ``nn.CrossEntropyLoss``
    applies log-softmax itself, so the ``nn.Softmax()`` this model used to
    end with squashed the scores twice and weakened the gradients. Pass
    ``with_softmax=True`` to append an explicit softmax over the last (class)
    dimension when probabilities are wanted. Softmax has no parameters, so
    ``state_dict`` keys are identical either way and old checkpoints still load.
    """
    layers = [
        nn.Conv2d(input_channels, 6, 5),
        nn.ReLU(),
        nn.MaxPool2d((2, 2)),
        nn.Conv2d(6, 16, 5),
        nn.ReLU(),
        nn.MaxPool2d((2, 2)),
        Flatten(),
        nn.Linear(256, 120),
        nn.ReLU(),
        nn.Linear(120, 84),
        nn.ReLU(),
        nn.Linear(84, output_channels),
    ]
    if with_softmax:
        # Explicit dim avoids PyTorch's implicit-dimension deprecation warning.
        layers.append(nn.Softmax(dim=-1))
    return nn.Sequential(*layers)

def lenet_client(input_channels):
    """Build the client half of the split LeNet: the convolutional feature extractor.

    Two conv5x5 -> ReLU -> 2x2 max-pool stages (6 then 16 channels), followed
    by Flatten. For a 28x28 input this yields 16 * 4 * 4 = 256 features per sample.
    """
    stages = []
    for in_ch, out_ch in ((input_channels, 6), (6, 16)):
        stages += [nn.Conv2d(in_ch, out_ch, 5), nn.ReLU(), nn.MaxPool2d((2, 2))]
    return nn.Sequential(*stages, Flatten())

def lenet_server(input_channels, output_channels, with_softmax=False):
    """Build the server half of the split LeNet: the fully connected classifier head.

    Linear(input_channels, 120) -> ReLU -> Linear(120, 84) -> ReLU
    -> Linear(84, output_channels).

    Returns raw logits by default, because ``nn.CrossEntropyLoss`` already
    applies log-softmax; the previous trailing ``nn.Softmax()`` applied it
    twice. ``with_softmax=True`` appends a softmax over the last (class)
    dimension with an explicit ``dim``. That also avoids the implicit-dim
    rule, which picks dim 0 for 3-D inputs such as the ``(-1, 1, 256)``
    summary input. Parameter keys in ``state_dict`` are unchanged.
    """
    layers = [
        nn.Linear(input_channels, 120),
        nn.ReLU(),
        nn.Linear(120, 84),
        nn.ReLU(),
        nn.Linear(84, output_channels),
    ]
    if with_softmax:
        layers.append(nn.Softmax(dim=-1))
    return nn.Sequential(*layers)

In [None]:
if __name__ == "__main__":
    # Seed the RNGs used here and force deterministic cuDNN kernels for repeatable runs.
    random.seed(7)
    torch.manual_seed(7)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    train_loader, test_loader = get_mnist_dataset(batch_size=64)

    # Baseline: the whole LeNet trained as a single, unsplit model.
    model = lenet(input_channels=1, output_channels=10).to(device)
    client_optim = optim.SGD(model.parameters(), lr=1e-2, weight_decay=5e-3)

    # range's upper bound is exclusive: epochs are numbered 1..49.
    for epoch in range(1, 50):
        train_single_model(model, client_optim, device, train_loader, epoch)
        test_single_model(model, device, test_loader)

    torch.save(model.state_dict(), "single-model.pt")

  input = module(input)



Train set for client: Accuracy: 7508/60000 (13%)


Test set for client: Average loss: 0.0361, Accuracy: 1609/10000 (16%)


Train set for client: Accuracy: 12553/60000 (21%)


Test set for client: Average loss: 0.0361, Accuracy: 2174/10000 (22%)


Train set for client: Accuracy: 13038/60000 (22%)


Test set for client: Average loss: 0.0361, Accuracy: 2307/10000 (23%)


Train set for client: Accuracy: 18450/60000 (31%)


Test set for client: Average loss: 0.0359, Accuracy: 3288/10000 (33%)


Train set for client: Accuracy: 20024/60000 (33%)


Test set for client: Average loss: 0.0307, Accuracy: 5981/10000 (60%)


Train set for client: Accuracy: 42789/60000 (71%)


Test set for client: Average loss: 0.0266, Accuracy: 7842/10000 (78%)


Train set for client: Accuracy: 47203/60000 (79%)


Test set for client: Average loss: 0.0262, Accuracy: 8034/10000 (80%)


Train set for client: Accuracy: 48461/60000 (81%)


Test set for client: Average loss: 0.0260, Accuracy: 8222/10000 (82%)


Train se

In [None]:
def train_model(client,client_optimiser,server,server_optimiser,device,train_loader,epoch):
    client.train()
    server.train()
    criterion = nn.CrossEntropyLoss()
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):        
        data, target = data.to(device), target.to(device)
        client_optimiser.zero_grad()
        server_optimiser.zero_grad()

        ### execute client - feed forward network
        intermediate = client(data)
        remote = intermediate.detach().requires_grad_()
        ### execute server - feed forward network
        output = server(remote)
        loss = criterion(output,target)
        loss.backward()
        ### execute client back propagation
        grad = remote.grad.clone()
        intermediate.backward(grad)
        ### optimiser step
        server_optimiser.step()
        client_optimiser.step()

        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        ### printing logs
        if batch_idx % 300 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    
    print('\nTrain set for client: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))
    
def test_model(client,server,device,test_loader):
  client.eval()
  server.eval()
  criterion = nn.CrossEntropyLoss()
  with torch.no_grad():
      test_loss = 0
      correct = 0
      for data, target in test_loader:
          data, target = data.to(device), target.to(device)
          intermediate = client(data)
          output = server(intermediate)
          test_loss += criterion(output, target).item()  # sum up batch loss
          pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
          correct += pred.eq(target.view_as(pred)).sum().item()
      
      test_loss /= len(test_loader.dataset)
      print('\nTest set for client: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
          test_loss, correct, len(test_loader.dataset),
          100. * correct / len(test_loader.dataset)))

In [None]:
if __name__ == "__main__":
    # NOTE: this cell calls the single-client train_model/test_model. A later
    # cell redefines both names with list-of-clients signatures, so after that
    # cell has run, re-executing this one will fail until those definitions
    # are re-run.

    # Seed the RNGs used here and force deterministic cuDNN kernels for repeatable runs.
    random.seed(7)
    torch.manual_seed(7)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    train_loader, test_loader = get_mnist_dataset(batch_size=64)

    # Split LeNet: conv feature extractor on the client, FC head on the server.
    client = lenet_client(input_channels=1).to(device)
    client_optim = optim.SGD(client.parameters(), lr=1e-2, weight_decay=5e-3)

    # 256 = width of the client's flattened output (see the client summary).
    server = lenet_server(input_channels=256, output_channels=10).to(device)
    server_optim = optim.SGD(server.parameters(), lr=1e-2, weight_decay=5e-3)

    print("\n\nModel Summary")
    print("\nClient:")
    summary(client, (1, 28, 28))
    print("\n\nServer:")
    summary(server, (1, 256))

    # Epochs 1..49 (range's upper bound is exclusive). The evaluation result
    # was previously bound to an unused variable; it is no longer captured.
    for epoch in range(1, 50):
        train_model(client, client_optim, server, server_optim, device, train_loader, epoch)
        test_model(client, server, device, test_loader)

    torch.save({'server': server.state_dict(),
                'client': client.state_dict()}, "split-model-single-agent.pt")




Model Summary

Client:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 24, 24]             156
              ReLU-2            [-1, 6, 24, 24]               0
         MaxPool2d-3            [-1, 6, 12, 12]               0
            Conv2d-4             [-1, 16, 8, 8]           2,416
              ReLU-5             [-1, 16, 8, 8]               0
         MaxPool2d-6             [-1, 16, 4, 4]               0
           Flatten-7                  [-1, 256]               0
Total params: 2,572
Trainable params: 2,572
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.08
Params size (MB): 0.01
Estimated Total Size (MB): 0.09
----------------------------------------------------------------


Server:
----------------------------------------------------------------
        La

  input = module(input)



Train set for client: Accuracy: 7458/60000 (12%)


Test set for client: Average loss: 0.0361, Accuracy: 1615/10000 (16%)


Train set for client: Accuracy: 12449/60000 (21%)


Test set for client: Average loss: 0.0361, Accuracy: 2178/10000 (22%)


Train set for client: Accuracy: 13209/60000 (22%)


Test set for client: Average loss: 0.0361, Accuracy: 2295/10000 (23%)


Train set for client: Accuracy: 18270/60000 (30%)


Test set for client: Average loss: 0.0359, Accuracy: 3288/10000 (33%)


Train set for client: Accuracy: 19952/60000 (33%)


Test set for client: Average loss: 0.0307, Accuracy: 6114/10000 (61%)


Train set for client: Accuracy: 42925/60000 (72%)


Test set for client: Average loss: 0.0266, Accuracy: 7874/10000 (79%)


Train set for client: Accuracy: 50480/60000 (84%)


Test set for client: Average loss: 0.0250, Accuracy: 8856/10000 (89%)


Train set for client: Accuracy: 53595/60000 (89%)


Test set for client: Average loss: 0.0244, Accuracy: 9172/10000 (92%)


Train se

In [None]:
def train_model(clients,client_optimisers,server,server_optimiser,device,train_loader,epoch):
    """Run one epoch of sequential (relay-style) split learning across several clients.

    Batches are dealt round-robin: batch ``i`` is handled by
    ``clients[i % len(clients)]`` using ``client_optimisers[i % len(clients)]``.
    Before a client runs, it copies the weights of the client that handled
    the previous batch, so one set of client weights is relayed from client
    to client. The shared ``server`` is updated on every batch.

    Each step mirrors the single-client version:
    1. Client forward pass.
    2. Detached hand-off of the activations to the server.
    3. Server backward pass.
    4. The cut-layer gradient is sent back through the client.
    5. Both optimisers step.

    At the end of the epoch the newest client weights are copied into
    ``clients[0]``, which handles the first batch of the next epoch.
    Previously it restarted from its own older weights, dropping the updates
    made after its last turn.

    Args:
        clients: list of client modules (bottom halves of the split model).
        client_optimisers: one optimiser per client, same order and length.
        server: shared server module (top half).
        server_optimiser: optimiser bound to ``server``'s parameters.
        device: device every batch is moved to.
        train_loader: iterable of ``(data, target)`` batches exposing ``.dataset``.
        epoch: epoch number; only used in log lines.

    Raises:
        ValueError: if ``clients`` is empty or ``client_optimisers`` has a
            different length.
    """
    num_clients = len(clients)
    if num_clients == 0:
        raise ValueError("clients must contain at least one model")
    if len(client_optimisers) != num_clients:
        raise ValueError("need exactly one optimiser per client")

    for client in clients:
        client.train()
    server.train()
    criterion = nn.CrossEntropyLoss()
    correct = 0
    previous_client = None
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # Round-robin over however many clients were supplied (was hard-coded to 5).
        slot = batch_idx % num_clients
        client = clients[slot]
        client_optimiser = client_optimisers[slot]
        server_optimiser.zero_grad()
        client_optimiser.zero_grad()

        # Relay: start from the weights left by the client that trained last.
        # `is not None` instead of truthiness: nn.Sequential defines __len__,
        # so an empty Sequential would look like "no previous client".
        if previous_client is not None and previous_client is not client:
            client.load_state_dict(previous_client.state_dict())

        ### execute client - feed forward network
        intermediate = client(data)
        # Cut the graph at the client/server boundary; the server sees a fresh leaf.
        remote = intermediate.detach().requires_grad_()
        ### execute server - feed forward network
        output = server(remote)
        loss = criterion(output, target)
        loss.backward()
        ### execute client back propagation with the cut-layer gradient
        grad = remote.grad.clone()
        intermediate.backward(grad)
        ### optimiser step
        server_optimiser.step()
        client_optimiser.step()
        previous_client = client

        pred = output.argmax(dim=1, keepdim=True)  # index of the highest score
        correct += pred.eq(target.view_as(pred)).sum().item()
        ### printing logs
        if batch_idx % 300 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    # Hand the newest client weights to the client that opens the next epoch.
    if previous_client is not None and previous_client is not clients[0]:
        clients[0].load_state_dict(previous_client.state_dict())

    print('\nTrain set for client: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))

def test_model(clients,server,device,test_loader):
    for client in clients:
        client.eval()
    server.eval()
    criterion = nn.CrossEntropyLoss()
    for client in clients:
        with torch.no_grad():
            test_loss = 0
            correct = 0
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                intermediate = client(data)
                output = server(intermediate)
                test_loss += criterion(output, target).item()  # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
            
            test_loss /= len(test_loader.dataset)
            print('\nTest set for client: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
                test_loss, correct, len(test_loader.dataset),
                100. * correct / len(test_loader.dataset)))

if __name__ == "__main__":
    # Seed the RNGs used here and force deterministic cuDNN kernels for repeatable runs.
    random.seed(7)
    torch.manual_seed(7)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    train_loader, test_loader = get_mnist_dataset(batch_size=64)

    # Number of clients taking part in the relay; the checkpoint keys below follow it.
    NUM_CLIENTS = 5
    clients = []
    client_optimisers = []
    for _ in range(NUM_CLIENTS):
        client = lenet_client(input_channels=1).to(device)
        client_optim = optim.SGD(client.parameters(), lr=1e-2, weight_decay=5e-3)
        clients.append(client)
        client_optimisers.append(client_optim)

    # 256 = width of the client's flattened output (see the client summary).
    server = lenet_server(input_channels=256, output_channels=10).to(device)
    server_optim = optim.SGD(server.parameters(), lr=1e-2, weight_decay=5e-3)

    print("\n\nModel Summary")
    print("\nClient:")
    summary(clients[0], (1, 28, 28))
    print("\n\nServer:")
    summary(server, (1, 256))

    # Epochs 1..49 (range's upper bound is exclusive).
    for epoch in range(1, 50):
        train_model(clients, client_optimisers, server, server_optim, device, train_loader, epoch)
        test_model(clients, server, device, test_loader)

    # Keys 'server', 'client1', ..., 'clientN' — same layout as before, but
    # derived from the client list instead of five hard-coded entries.
    checkpoint = {'server': server.state_dict()}
    checkpoint.update({'client{}'.format(k + 1): c.state_dict()
                       for k, c in enumerate(clients)})
    torch.save(checkpoint, "split-model-multi-agent.pt")


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw
Processing...


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


Done!


Model Summary

Client:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 6, 24, 24]             156
              ReLU-2            [-1, 6, 24, 24]               0
         MaxPool2d-3            [-1, 6, 12, 12]               0
            Conv2d-4             [-1, 16, 8, 8]           2,416
              ReLU-5             [-1, 16, 8, 8]               0
         MaxPool2d-6             [-1, 16, 4, 4]               0
           Flatten-7                  [-1, 256]               0
Total params: 2,572
Trainable params: 2,572
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.08
Params size (MB): 0.01
Estimated Total Size (MB): 0.09
----------------------------------------------------------------


Server:


  input = module(input)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1               [-1, 1, 120]          30,840
              ReLU-2               [-1, 1, 120]               0
            Linear-3                [-1, 1, 84]          10,164
              ReLU-4                [-1, 1, 84]               0
            Linear-5                [-1, 1, 10]             850
           Softmax-6                [-1, 1, 10]               0
Total params: 41,854
Trainable params: 41,854
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.16
Estimated Total Size (MB): 0.16
----------------------------------------------------------------

Train set for client: Accuracy: 7088/60000 (12%)


Test set for client: Average loss: 0.0361, Accuracy: 1510/10000 (15%)

Test set for client: Average loss: 0.0361, Accura