In [2]:
from torchvision import datasets,transforms
import torch
import random
import torch.nn.functional as f
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchsummary import summary

In [1]:
'''LeNet in PyTorch.'''
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self,input_channel=3,output_channel=10):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(input_channel, 6, 5)
        self.maxpool1 = nn.MaxPool2d((2,2))
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.maxpool2 = nn.MaxPool2d((2,2))
        self.fc1 = nn.Linear(256,120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, output_channel)
        self.output_channel = output_channel

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.maxpool1(out)
        out = F.relu(self.conv2(out))
        out = self.maxpool2(out)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = F.softmax(self.fc3(out))
        return out


class LeNetServer(nn.Module):
    def __init__(self,input_channel=256,output_channel=10):
        super(LeNetServer, self).__init__()
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.maxpool2 = nn.MaxPool2d((2,2))
        self.fc1 = nn.Linear(input_channel,120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, output_channel)
        self.output_channel = output_channel

    def forward(self, x):
        out = F.relu(self.conv2(x))
        out = self.maxpool2(out)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = F.softmax(self.fc3(out))
        return out

class LeNetClient(nn.Module):
    def __init__(self,input_channel=3):
        super(LeNetClient, self).__init__()
        self.conv1 = nn.Conv2d(input_channel, 6, 5)
        self.maxpool1 = nn.MaxPool2d((2,2))

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = self.maxpool1(out)
        return out


In [3]:
'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
import torch.nn.functional as F


cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
    'VGG16_client' : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M'],
    'VGG16_server' : [512, 512, 512, 'M', 512, 512, 512],
    'VGG11_client' : [64, 'M'],
    'VGG11_server' : [128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
}


class VGG(nn.Module):
    def __init__(self, vgg_name, in_channels = 3,dense_output = 10):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name],in_channels)
        self.classifier = nn.Sequential(nn.Linear(512, 256),
                            nn.ReLU(inplace=True),nn.Dropout(p=0.25),nn.Linear(256, 128),
                            nn.ReLU(inplace=True),nn.Dropout(p=0.25),nn.Linear(128,dense_output))

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        out = F.softmax(out)
        return out

    def _make_layers(self, cfg, in_channels):
        layers = []
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        return nn.Sequential(*layers)

class VGGClient(nn.Module):
    def __init__(self, vgg_name,in_channels=3):
        super(VGGClient, self).__init__()
        self.features = self._make_layers(cfg[vgg_name],in_channels)

    def forward(self, x):
        out = self.features(x)
        return out

    def _make_layers(self, cfg,in_channels):
        layers = []
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        return nn.Sequential(*layers)


In [5]:
def get_emnist_dataset():
    train_dataset = datasets.EMNIST('./data', split='balanced', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1751,), (0.3267,))
                       ]))

    test_dataset = datasets.EMNIST('./data', split='balanced', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1751,), (0.3267,))
                       ]))

    train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=64, shuffle=False)

    return train_loader,test_loader


In [5]:
def train(server,server_opt,client,client_opt,device,train_loader,epoch):
    server.train()
    criterion = nn.CrossEntropyLoss()
    cli_length = len(client)
    for i in range(cli_length):
        client[i].train()
    client_parameters = None
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        ### picking the client to process this batch of data
        clt = client[batch_idx%cli_length]
        clt_opt = client_opt[batch_idx%cli_length]
        
        if client_parameters:
            clt.load_state_dict(client_parameters.state_dict())

        data, target = data.to(device), target.to(device)
        server_opt.zero_grad()
        clt_opt.zero_grad()
        
        ### passing the result of the client to the server classifier
        intermediate = clt(data)
        intermediate = intermediate.detach().requires_grad_()
        intermediate.to(device)        
        ### calculating the loss at the server
        output = server(intermediate)
        loss = criterion(output,target)
        loss.backward()
        server_opt.step()
        
        ### passing the gradient calculated at the server to the client
        grad = intermediate.grad.clone()
        intermediate.backward(grad)
        clt_opt.step()
        client_parameters = clt

        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        ### printing logs
        if batch_idx % 400 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    
    print("training accuracy : {:.2f}%".format(100.0*correct/len(train_loader.dataset)))

def test(client,server,device,test_loader):
    server.eval()
    criterion = nn.CrossEntropyLoss()
    for i in range(len(client)):
        client[i].eval()

    with torch.no_grad():
        for i in range(len(client)):
            test_loss = 0
            correct = 0
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                output = client[i](data)
                output = server(output)
                test_loss += criterion(output, target).item()  # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
                
            test_loss /= len(test_loader.dataset)
            print('\nTest set for client-{}: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(i,
                test_loss, correct, len(test_loader.dataset),
                100. * correct / len(test_loader.dataset)))

def adjust_learning_rate(optimizer, init_lr, epoch):
    lr = init_lr * (0.5 ** (epoch // 50))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [None]:
if __name__ == "__main__":
    random.seed(7)
    torch.manual_seed(7)
    if torch.cuda.is_available():  
        dev = "cuda:0" 
    else:  
        dev = "cpu"
    device = torch.device(dev)
    train_loader,test_loader = get_mnist_dataset()

    server = VGG(vgg_name='VGG16_server',in_channels=256,dense_output=47).to(device)
    server_optim = optim.SGD(server.parameters(),lr=1e-3,weight_decay=5e-4)

    clients=[]
    client_optim=[]
    for i in range(5):
        cli = VGGClient(vgg_name='VGG16_client',in_channels=1).to(device)
        clients.append(cli)
        client_optim.append(optim.SGD(cli.parameters(),lr=1e-3,weight_decay=5e-4))
    
    print("\n\nModel Summary")
    print("\nClient:")
    summary(clients[0],(1,28,28))
    print("\n\nServer:")
    summary(server,(256,3,3))

    init_lr = 0.1
    for i in range(1,200):
        adjust_learning_rate(server_optim,init_lr,i)
        for op in client_optim:
            adjust_learning_rate(op,init_lr,i)
        train(server,server_optim,clients,client_optim,device,train_loader,i)
        test(clients,server,device,test_loader)

    torch.save({'server':server.state_dict(),
            'client1':clients[0].state_dict(),
            'client2':clients[1].state_dict(),
            'client3':clients[2].state_dict(),
            'client4':clients[2].state_dict(),
            'client5':clients[3].state_dict()},"split-model.pt")




Model Summary

Client:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 28, 28]             640
       BatchNorm2d-2           [-1, 64, 28, 28]             128
              ReLU-3           [-1, 64, 28, 28]               0
            Conv2d-4           [-1, 64, 28, 28]          36,928
       BatchNorm2d-5           [-1, 64, 28, 28]             128
              ReLU-6           [-1, 64, 28, 28]               0
         MaxPool2d-7           [-1, 64, 14, 14]               0
            Conv2d-8          [-1, 128, 14, 14]          73,856
       BatchNorm2d-9          [-1, 128, 14, 14]             256
             ReLU-10          [-1, 128, 14, 14]               0
           Conv2d-11          [-1, 128, 14, 14]         147,584
      BatchNorm2d-12          [-1, 128, 14, 14]             256
             ReLU-13          [-1, 128, 14, 14]               0
        MaxPoo



training accuracy : 10.97%

Test set for client-0: Average loss: 0.0580, Accuracy: 3183/18800 (17%)


Test set for client-1: Average loss: 0.0580, Accuracy: 3184/18800 (17%)


Test set for client-2: Average loss: 0.0580, Accuracy: 3185/18800 (17%)


Test set for client-3: Average loss: 0.0580, Accuracy: 3189/18800 (17%)


Test set for client-4: Average loss: 0.0580, Accuracy: 3182/18800 (17%)

training accuracy : 23.76%

Test set for client-0: Average loss: 0.0563, Accuracy: 5222/18800 (28%)


Test set for client-1: Average loss: 0.0563, Accuracy: 5217/18800 (28%)


Test set for client-2: Average loss: 0.0563, Accuracy: 5218/18800 (28%)


Test set for client-3: Average loss: 0.0563, Accuracy: 5224/18800 (28%)


Test set for client-4: Average loss: 0.0563, Accuracy: 5223/18800 (28%)

training accuracy : 34.59%

Test set for client-0: Average loss: 0.0544, Accuracy: 7542/18800 (40%)


Test set for client-1: Average loss: 0.0544, Accuracy: 7540/18800 (40%)


Test set for client-2: Average

In [None]:
from torchvision import datasets,transforms
import torch

import random
import torch.nn.functional as f
import torch.nn as nn
import torch.optim as optim

def get_emnist_dataset():
    train_dataset = datasets.EMNIST('./data',split='balanced', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1751,), (0.3267,))
                       ]))

    test_dataset = datasets.EMNIST('./data', split='balanced',train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1751,), (0.3267,))
                       ]))

    train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=64, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=64, shuffle=False)

    return train_loader,test_loader

def train_single_model(client,client_opt,device,train_loader,epoch):
    client.train()
    criterion = nn.CrossEntropyLoss()
    correct = 0
    for batch_idx, (data, target) in enumerate(train_loader):        
        data, target = data.to(device), target.to(device)
        client_opt.zero_grad()
        output = client(data)
        loss = criterion(output,target)
        loss.backward()
        client_opt.step()

        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()
        ### printing logs
        if batch_idx % 500 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
        
    print("training accuracy : {:.2f}%".format(100.0*correct/len(train_loader.dataset)))

def test_single_model(client,device,test_loader):
    client.eval()
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = client(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
        
        test_loss /= len(test_loader.dataset)
        print('\nTest set for client: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))

def adjust_learning_rate(optimizer, init_lr, epoch):
    lr = init_lr * (0.5 ** (epoch // 25))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

if __name__ == "__main__":
    random.seed(7)
    torch.manual_seed(7)
    if torch.cuda.is_available():  
        dev = "cuda:0" 
    else:  
        dev = "cpu"
    device = torch.device(dev)

    train_loader,test_loader = get_emnist_dataset()
    
    client = VGG(vgg_name='VGG16',in_channels=1,dense_output=47).to(device)
    client_optim = optim.SGD(client.parameters(),lr=1e-3,weight_decay=5e-4)
    init_lr = 0.1

    for i in range(1,100):
        adjust_learning_rate(client_optim,init_lr,i)
        train_single_model(client,client_optim,device,train_loader,i)
        test_single_model(client,device,test_loader)
    
    torch.save(client.state_dict(),"single-model.pt")



training accuracy : 10.49%

Test set for client: Average loss: 0.0579, Accuracy: 3096/18800 (16%)

training accuracy : 22.02%

Test set for client: Average loss: 0.0558, Accuracy: 5758/18800 (31%)

training accuracy : 37.23%

Test set for client: Average loss: 0.0539, Accuracy: 8201/18800 (44%)

training accuracy : 49.17%

Test set for client: Average loss: 0.0524, Accuracy: 9958/18800 (53%)

training accuracy : 58.81%

Test set for client: Average loss: 0.0508, Accuracy: 11972/18800 (64%)

training accuracy : 74.98%

Test set for client: Average loss: 0.0484, Accuracy: 14948/18800 (80%)

training accuracy : 80.47%

Test set for client: Average loss: 0.0484, Accuracy: 14904/18800 (79%)

training accuracy : 80.85%

Test set for client: Average loss: 0.0481, Accuracy: 15260/18800 (81%)

training accuracy : 81.38%

Test set for client: Average loss: 0.0480, Accuracy: 15333/18800 (82%)

training accuracy : 82.95%

Test set for client: Average loss: 0.0481, Accuracy: 15248/18800 (81%)

trai