#Scope


*   Understanding Split Learning
*   Setting up the Training Pipeline
*   Distributed Forward and Backward Pass
*   Scaling to multiple clients




#What is Split Learning?
Split Learning is a technique for distributed training and prediction without sharing raw data. The simplest idea here is to split a neural network architecture into two parts such that the first part can be executed by a trusted client and the output of the first part is fed to the second part which is typically executed by the sever.
This approach enables multi-client, multi-insitution collaboration for deep learning algorithms.

Set runtime as GPU

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.optim as optim
from torch.autograd import Variable

In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

#CIFAR 10 is a dataset of natural images consisting of 50k training images and 10k test
#Every image is labelled with one of the following class
classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

trainset = torchvision.datasets.CIFAR10(root='./data',train = True, download = True,transform = transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=128,shuffle= True,num_workers = 2)

testset = torchvision.datasets.CIFAR10(root = './data', train = False,download = True,transform = transform)
testloader = torch.utils.data.DataLoader(testset,batch_size = 128,shuffle = True,num_workers = 2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:11<00:00, 14782756.86it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Pytorch provides nn.Module as a base class for all neural network modules. This allows easily creating a parameterized model.

torchvision package consists of popular datasets, model architectures and common image transformations for computer vision. Here we use torchvision's model API to use ResNet18 architecture.

In [5]:
# Explain nn.Module and explain the forward and backward pass
class ResNet18Client(nn.Module):
    """ docstring for ResNet """
    #Explain initialize (listing the neural network architecture and other related parameters)
    def __init__(self,config):
        super(ResNet18Client,self).__init__()
        #Declare where we are going to split the NN, if cut layer close to output most computation will happen in the client side
        self.cut_layer = config["cut_layer"]

        #Load the resnet model
        self.model = models.resnet18(pretrained = False)
        self.model = nn.ModuleList(self.model.children())
        self.model = nn.Sequential(*self.model)

    def forward(self,x):
        for i,l in enumerate(self.model):
            if i> self.cut_layer:
                break
            x = l(x)
        return x

In [8]:
class ResNet18Server(nn.Module):
    """ docstring for ResNet """
    def __init__(self,config):
        super(ResNet18Server,self).__init__()
        self.logits = config["logits"]
        self.cut_layer = config["cut_layer"]

        self.model = models.resnet18(pretrained = False)
        num_ftrs = self.model.fc.in_features
        #
        self.model.fc = nn.Sequential(nn.Flatten(),nn.Linear(num_ftrs,self.logits))
        self.model = nn.ModuleList(self.model.children())
        self.model = nn.Sequential(*self.model)

    def forward(self,x):
        for i,l in enumerate(self.model):
            #continue until you are in the cut layer, skip earlier layers
            if i<=self.cut_layer:
                continue
            x = l(x)
        return nn.functional.softmax(x,dim=1)

#Initialize the Models

In [10]:
config = {'cut_layer':3,"logits":10} #logits = 10 cause we are using ciphar10 that has 10 classes
client_model = ResNet18Client(config).to(device)
server_model = ResNet18Server(config).to(device)

Set up the optimizer

In [12]:
criterion = nn.CrossEntropyLoss()
client_optimizer = optim.SGD(client_model.parameters(),lr=0.01,momentum=0.9)
server_optimizer = optim.SGD(server_model.parameters(),lr=0.01,momentum = 0.9)


Perform training

In [14]:
num_epochs = 50
for epoch in range(num_epochs):
    running_loss = 0.0
    for i,data in enumerate(trainloader,0):
        inputs, labels = data[0].to(device), data[1].to(device)

        client_optimizer.zero_grad()
        server_optimizer.zero_grad()

        #Client Part
        activation = client_model(inputs)
        server_inputs = activation.detach().clone() # get the outputs of the client and input them to the server

        #Simulation of sever part is happening in this portion
        #Server part
        server_inputs = Variable(server_inputs,requires_grad = True)
        outputs = server_model(server_inputs)
        loss = criterion(outputs,labels)
        loss.backward()

        #Server optimization
        server_optimizer.step()

        #Simulation of Client Happening in this portion
        #Client optimization
        activation.backward(server_inputs.grad)
        client_optimizer.step()

        running_loss += loss.item()

        if i % 200 == 199:
            print("[{},{}] loss: {}".format(epoch+1,i+1,running_loss/200))

[1,200] loss: 2.113515778183937
[2,200] loss: 1.9625120824575424
[3,200] loss: 1.8984527599811554
[4,200] loss: 1.8545391988754272
[5,200] loss: 1.8295815044641495
[6,200] loss: 1.8033285665512084
[7,200] loss: 1.782121154665947
[8,200] loss: 1.7604520726203918
[9,200] loss: 1.751526979804039
[10,200] loss: 1.7327482372522354
[11,200] loss: 1.7225162214040757
[12,200] loss: 1.7111092078685761
[13,200] loss: 1.6973230242729187
[14,200] loss: 1.678701884150505
[15,200] loss: 1.6723183631896972
[16,200] loss: 1.6656447041034699
[17,200] loss: 1.6526123923063278
[18,200] loss: 1.6437605232000352
[19,200] loss: 1.6353329992294312
[20,200] loss: 1.6280428624153138
[21,200] loss: 1.6256264501810074
[22,200] loss: 1.6161740392446518
[23,200] loss: 1.6182483959197997
[24,200] loss: 1.607930174469948
[25,200] loss: 1.5965617144107818
[26,200] loss: 1.5969991654157638
[27,200] loss: 1.588278787136078
[28,200] loss: 1.5817531907558442
[29,200] loss: 1.5811073738336563
[30,200] loss: 1.575644938945

In [26]:
# save the model
!mkdir = 'saved_models'

In [27]:
client_model_path = './saved_models/trained_client_model.pt'
server_model_path = './saved_models/trained_server_model.pt'
torch.save(client_model.state_dict(),client_model_path)
torch.save(server_model.state_dict(),server_model_path)


In [28]:
# Evaluation
total, correct = 0, 0
with torch.no_grad():
    for data in testloader:
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = server_model(client_model(inputs))
        _, predicted = torch.max(outputs.data,1)
        total += labels.size(0)
        correct += (predicted==labels).sum().item()
    print(correct/total)

0.7283


# Scaling to multiple clients
We will implement a basic round robin protocol for multiple clients to perform distributed training.
Here each client takes part in training and then send its weight to the subsequent client who continues with the  training.
Last 50mins of https://www.youtube.com/watch?v=VPL6ELWdJbg

In [29]:
total_client_num = 10
#Initialize multiple clients
client_model_list = [ResNet18Client(config).to(device) for client_num in range(total_client_num)]
#Initialize multiple optimizers
client_optimizer_list = [optim.SGD(client_model_list[client_num].parameters(),lr=0.01,momentum =0.9) for client_num in range(total_client_num)]
server = ResNet18Server(config).to(device)
server_optimizer = optim.SGD(server.parameters(),lr=0.01,momentum = 0.9)



#Training pipeline for multiple clients

In [34]:
num_epochs = 50 // total_client_num
for epoch in range(num_epochs):
    # Iterate over multiple clients
    for client_num in range(total_client_num):
        print("Current active client is {}".format(client_num))
        client = client_model_list[client_num]
        clinet_optimizer = client_optimizer_list[client_num]

        #Logic to load the weights from the previous client
        if client_num ==0:
            if epoch != 0:
                prev_client = total_client_num - 1
                prev_client_weights = client_model_list[prev_client].state_dict()
                client.load_state_dict(prev_client_weights)
                print("Loaded client {}'s weight successfully".format(prev_client))
        else:
            prev_client = client_num -1
            prev_client_weights = client_model_list[prev_client].state_dict()
            client.load_state_dict(prev_client_weights)
            print("Loaded client {}'s weight successfully".format(prev_client))

        client.train()
        running_loss = 0.0
        total_samples = 0
        # All clients have their own data source and we are using a single trainloader here for illustration fo te simulated setup
        for i, data in enumerate(trainloader,0):
            inputs, labels = data[0].to(device), data[1].to(device)

            client_optimizer.zero_grad()
            server_optimizer.zero_grad()

            # Client part
            activations = client(inputs)
            server_inputs = activations.detach().clone()

            # Server part
            server_inputs  = Variable(server_inputs,requires_grad = True)
            outputs = server(server_inputs)
            loss = criterion(outputs,labels)
            loss.backward()
            server_optimizer.step()

            running_loss += loss.item()
            total_samples += labels.shape[0]

            # Client part
            activations.backward(server_inputs.grad)
            client_optimizer.step()

            if i % 50 ==1:
                print( "Cleint: {}, Epoch: {}, Iteration: {}, Loss: {:.4f}".format(client_num,epoch,i,running_loss/(i+1)))



Current active client is 0
Cleint: 0, Epoch: 0, Iteration: 1, Loss: 2.3072
Cleint: 0, Epoch: 0, Iteration: 51, Loss: 2.2307
Cleint: 0, Epoch: 0, Iteration: 101, Loss: 2.1692
Cleint: 0, Epoch: 0, Iteration: 151, Loss: 2.1347
Cleint: 0, Epoch: 0, Iteration: 201, Loss: 2.1089
Cleint: 0, Epoch: 0, Iteration: 251, Loss: 2.0867
Cleint: 0, Epoch: 0, Iteration: 301, Loss: 2.0710
Cleint: 0, Epoch: 0, Iteration: 351, Loss: 2.0589
Current active client is 1
Loaded client 0's weight successfully
Cleint: 1, Epoch: 0, Iteration: 1, Loss: 1.9112
Cleint: 1, Epoch: 0, Iteration: 51, Loss: 1.9247
Cleint: 1, Epoch: 0, Iteration: 101, Loss: 1.9362
Cleint: 1, Epoch: 0, Iteration: 151, Loss: 1.9368
Cleint: 1, Epoch: 0, Iteration: 201, Loss: 1.9348
Cleint: 1, Epoch: 0, Iteration: 251, Loss: 1.9342
Cleint: 1, Epoch: 0, Iteration: 301, Loss: 1.9323
Cleint: 1, Epoch: 0, Iteration: 351, Loss: 1.9318
Current active client is 2
Loaded client 1's weight successfully
Cleint: 2, Epoch: 0, Iteration: 1, Loss: 1.8543
C

# Future Directions
1. How to make the round robin protocol asynchronous?
2. Different data distribution across clients?
3. Different topologies of client models, multiple servers, etc.