In [1]:
from torchvision import datasets,transforms
import torch
import random
import torch.nn.functional as f
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

In [2]:
def get_cifar10_dataset(batch_size=64, data_dir='./data'):
    """Build CIFAR-10 train/test DataLoaders.

    The training split is augmented (random horizontal flip, then a 32x32
    random crop with 4 pixels of padding) and shuffled every epoch. The test
    split is only converted to tensors and normalised, and is kept in order.

    Args:
        batch_size: batch size used by both loaders.
        data_dir: dataset root directory. The training split is downloaded
            here if missing; the test split is only read from here (it is
            expected to come from that same download).

    Returns:
        (train_loader, test_loader)
    """
    # Per-channel mean/std shared by both splits. These values are the usual
    # ImageNet statistics rather than CIFAR-10-specific ones.
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),  # 32x32 crop, padding=4
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR10(data_dir, train=True, download=True,
                                     transform=train_transform)
    test_dataset = datasets.CIFAR10(data_dir, train=False, transform=test_transform)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [None]:
def train_single_model(model,optimiser,device,train_loader,epoch):
    """Run one training epoch of `model` over `train_loader`.

    Every 300 batches the current batch loss is printed. At the end of the
    epoch the running training accuracy (predictions taken from the
    train-mode forward pass) is printed. Returns None.
    """
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)
    n_correct = 0

    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimiser.zero_grad()
        logits = model(inputs)
        batch_loss = loss_fn(logits, labels)
        batch_loss.backward()
        optimiser.step()

        # Class index with the highest score for each sample.
        predicted = logits.argmax(dim=1, keepdim=True)
        n_correct += predicted.eq(labels.view_as(predicted)).sum().item()

        if step % 300 != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), n_samples,
            100. * step / n_batches, batch_loss.item()))

    print('\nTrain set for client: Accuracy: {}/{} ({:.0f}%)\n'.format(
        n_correct, n_samples, 100. * n_correct / n_samples))

In [None]:
def test_single_model(model,device,test_loader):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
        
        test_loss /= len(test_loader.dataset)
        print('\nTest set for client: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))

def adjust_learning_rate(optimizer, init_lr, epoch, step_size=40, gamma=0.5):
    """Apply a step-decay schedule: lr = init_lr * gamma ** (epoch // step_size).

    With the defaults, the rate halves every 40 epochs. The new rate is
    written into every param group of `optimizer` and also returned.
    """
    lr = init_lr * (gamma ** (epoch // step_size))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

In [4]:
import torch.nn as nn

class Flatten(nn.Module):
    """Reshape an input of shape (N, ...) to (N, K), keeping the batch axis."""

    def forward(self, x):
        # view() shares storage with x; every axis after dim 0 is merged into one.
        return x.view(x.size(0), -1)

# Layer recipes consumed by make_layers(): an int is the out_channels of a
# 3x3 Conv2d -> BatchNorm2d -> ReLU stage, and 'M' is a 2x2 max-pool (stride 2).
# Each *_client / *_server pair splits the corresponding network in two:
# concatenating the client list and the server list reproduces the full recipe.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    # VGG16 split after the third pool: the client emits 256 channels, so the
    # server is built with 256 input channels.
    'VGG16_client' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M'],
    'VGG16_server' : [512, 512, 512, 'M', 512, 512, 512, 'M'],
    # VGG11 split after the first pool: the client keeps only one 64-channel stage.
    'VGG11_client' : [64, 'M'],
    'VGG11_server' : [128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
}

def make_layers(cfg, in_channels):
    """Turn a layer recipe into an nn.Sequential feature extractor.

    Args:
        cfg: sequence whose items are either an int (out_channels of a
            3x3/padding-1 Conv2d followed by BatchNorm2d and ReLU) or 'M'
            (2x2 max-pool with stride 2).
        in_channels: channel count of the input to the first conv.
    """
    modules = []
    channels = in_channels
    for entry in cfg:
        if entry != 'M':
            modules.extend([
                nn.Conv2d(channels, entry, kernel_size=3, padding=1),
                nn.BatchNorm2d(entry),
                nn.ReLU(),
            ])
            # The next conv reads what this one produced.
            channels = entry
            continue
        modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*modules)

def vgg(input_channels, output_channels):
    """Build the full (unsplit) VGG16-style classifier.

    The classifier head starts with Linear(512, ...), which assumes the conv
    stack ends at 512x1x1. That holds for 32x32 inputs, since
    cfg['VGG16'] pools five times and its last stage has 512 channels.

    Returns raw logits of shape (N, output_channels). There is deliberately
    no final Softmax: nn.CrossEntropyLoss applies log-softmax itself, and a
    softmax here would squash the logits twice and weaken the gradients.
    argmax-based predictions are unaffected by dropping it.
    """
    return nn.Sequential(
        make_layers(cfg['VGG16'], input_channels),
        Flatten(),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Dropout(p=0.25),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Dropout(p=0.25),
        nn.Linear(128, output_channels),
    )

def vgg_client(input_channels, output_channels):
    """Client-side front end of the split VGG16: the conv stages of cfg['VGG16_client'].

    `output_channels` is accepted but not used.
    """
    return make_layers(cfg['VGG16_client'], input_channels)

def vgg_server(input_channels, output_channels):
    """Server-side back end of the split VGG16: remaining conv stages plus classifier head.

    cfg['VGG16_server'] pools twice and ends at 512 channels, so
    Linear(512, ...) assumes a 4x4 input activation, reduced to 512x1x1.
    That matches the (256, 4, 4) shape passed to summary() for the server.

    Returns raw logits of shape (N, output_channels). There is no final
    Softmax because nn.CrossEntropyLoss already applies log-softmax to its
    input. The previous trailing Softmax caused a double softmax during training.
    """
    return nn.Sequential(
        make_layers(cfg['VGG16_server'], input_channels),
        Flatten(),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Dropout(p=0.25),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Dropout(p=0.25),
        nn.Linear(128, output_channels),
    )

In [None]:

if __name__ == "__main__":
    # Seed Python's and PyTorch's RNGs and force deterministic cuDNN kernels.
    random.seed(7)
    torch.manual_seed(7)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    dev = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(dev)

    train_loader, test_loader = get_cifar10_dataset(batch_size=128)

    # Baseline: the whole, unsplit network trained in a single place.
    model = vgg(input_channels=3, output_channels=10).to(device)
    # lr=1e-3 is a placeholder; adjust_learning_rate() overwrites it before the first step.
    optimiser = optim.SGD(model.parameters(), lr=1e-3, weight_decay=5e-4)

    print("Model Summary")
    summary(model, (3, 32, 32))

    init_lr = 0.1
    # Epochs are numbered 1..199 (the range end is exclusive).
    for i in range(1, 200):
        adjust_learning_rate(optimiser, init_lr, i)
        train_single_model(model, optimiser, device, train_loader, i)
        history = test_single_model(model, device, test_loader)

    torch.save(model.state_dict(), "single-model.pt")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Model Summary
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,792
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,928
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
         MaxPool2d-7           [-1, 64, 16, 16]               0
            Conv2d-8          [-1, 128, 16, 16]          73,856
       BatchNorm2d-9          [-1, 128, 16, 16]             256
             ReLU-10          [-1, 128, 16, 16]               0
           Conv2d-11          [-1, 128, 16, 16]         147,584
      BatchNorm2d-12          [-1, 128, 16, 16]             256
             ReLU-13          [-1, 128

In [None]:
def train_model(clients,client_optimisers,server,server_optimiser,device,train_loader,epoch):
    """One epoch of split learning in which only clients[0] is used.

    Every client is put in train mode, but only clients[0] (with
    client_optimisers[0]) processes batches. The client computes an
    activation, the server finishes the forward pass and the loss, and the
    server's gradient w.r.t. that activation is pushed back through the
    client. Logs every 300 batches and prints the epoch's training accuracy.
    """
    for member in clients:
        member.train()
    server.train()

    loss_fn = nn.CrossEntropyLoss()
    front, front_opt = clients[0], client_optimisers[0]
    n_samples = len(train_loader.dataset)
    n_batches = len(train_loader)
    n_correct = 0

    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        server_optimiser.zero_grad()
        front_opt.zero_grad()

        # Client forward, then cut the graph so the server's backward stops here.
        activation = front(inputs)
        smashed = activation.detach().requires_grad_()
        logits = server(smashed)
        batch_loss = loss_fn(logits, labels)
        batch_loss.backward()
        # Hand dL/d(activation) back to the client and finish backprop there.
        activation.backward(smashed.grad.clone())

        server_optimiser.step()
        front_opt.step()

        predicted = logits.argmax(dim=1, keepdim=True)
        n_correct += predicted.eq(labels.view_as(predicted)).sum().item()

        if step % 300 != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), n_samples,
            100. * step / n_batches, batch_loss.item()))

    print('\nTrain set for client: Accuracy: {}/{} ({:.0f}%)\n'.format(
        n_correct, n_samples, 100. * n_correct / n_samples))

def adjust_learning_rate(optimizer, init_lr, epoch, step_size=40, gamma=0.5):
    """Apply a step-decay schedule: lr = init_lr * gamma ** (epoch // step_size).

    With the defaults, the rate halves every 40 epochs. The new rate is
    written into every param group of `optimizer` and also returned.
    """
    lr = init_lr * (gamma ** (epoch // step_size))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def test_model(clients,server,device,test_loader):
    for client in clients:
        client.eval()
    server.eval()
    criterion = nn.CrossEntropyLoss()
    for client in clients:
        with torch.no_grad():
            test_loss = 0
            correct = 0
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                intermediate = client(data)
                output = server(intermediate)
                test_loss += criterion(output, target).item()  # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
            
            test_loss /= len(test_loader.dataset)
            print('\nTest set for client: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
                test_loss, correct, len(test_loader.dataset),
                100. * correct / len(test_loader.dataset)))

if __name__ == "__main__":
    # Seed Python's and PyTorch's RNGs and force deterministic cuDNN kernels.
    random.seed(7)
    torch.manual_seed(7)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    dev = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(dev)

    train_loader, test_loader = get_cifar10_dataset(batch_size=128)

    # Split learning with one client: the client holds the VGG16 front end,
    # and the server holds the remaining conv stages and the classifier head.
    clients = []
    client_optimisers = []
    for i in range(1):
        client = vgg_client(input_channels=3, output_channels=10).to(device)
        client_optim = optim.SGD(client.parameters(), lr=1e-3, weight_decay=5e-4)
        clients.append(client)
        client_optimisers.append(client_optim)

    # The server consumes the client's 256-channel activation.
    server = vgg_server(input_channels=256, output_channels=10).to(device)
    server_optim = optim.SGD(server.parameters(), lr=1e-3, weight_decay=5e-4)

    print("\n\nModel Summary")
    print("\nClient:")
    summary(clients[0], (3, 32, 32))
    print("\n\nServer:")
    summary(server, (256, 4, 4))

    init_lr = 0.1
    # Epochs 1..199; client and server optimisers share one step-decay schedule.
    for i in range(1, 200):
        for optimiser in client_optimisers:
            adjust_learning_rate(optimiser, init_lr, i)
        adjust_learning_rate(server_optim, init_lr, i)
        train_model(clients, client_optimisers, server, server_optim, device, train_loader, i)
        history1 = test_model(clients, server, device, test_loader)

    torch.save(
        {'server': server.state_dict(), 'client': clients[0].state_dict()},
        "split-model-single-agent.pt",
    )


Files already downloaded and verified


Model Summary

Client:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,792
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,928
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
         MaxPool2d-7           [-1, 64, 16, 16]               0
            Conv2d-8          [-1, 128, 16, 16]          73,856
       BatchNorm2d-9          [-1, 128, 16, 16]             256
             ReLU-10          [-1, 128, 16, 16]               0
           Conv2d-11          [-1, 128, 16, 16]         147,584
      BatchNorm2d-12          [-1, 128, 16, 16]             256
             ReLU-13          [-1, 128, 

In [5]:
def train_model(clients,client_optimisers,server,server_optimiser,device,train_loader,epoch):
    """One epoch of split learning in which clients take turns, batch by batch.

    Batch k is handled by clients[k % len(clients)] with the optimiser at the
    same index. Before a client runs, its state_dict (parameters and buffers)
    is overwritten with that of the client that handled the previous batch.
    The clients therefore relay a single evolving set of front-end weights,
    while the server is updated on every batch.

    Args:
        clients: non-empty list of client-side front-end modules.
        client_optimisers: one optimiser per client, same order as `clients`.
        server: server-side module that consumes the clients' activations.
        server_optimiser: optimiser for `server`.
        device: device each batch is moved to.
        train_loader: yields (data, target) batches.
        epoch: epoch number, used only for logging.

    Raises:
        ValueError: if `clients` is empty.
    """
    if not clients:
        raise ValueError("train_model requires at least one client")

    for client in clients:
        client.train()
    server.train()
    criterion = nn.CrossEntropyLoss()
    num_clients = len(clients)
    correct = 0
    previous_client = None
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # Round-robin over however many clients were supplied. This index was
        # previously hard-coded to 5, which raised IndexError with fewer clients.
        turn = batch_idx % num_clients
        client = clients[turn]
        client_optimiser = client_optimisers[turn]
        server_optimiser.zero_grad()
        client_optimiser.zero_grad()

        # Hand over the latest client weights. Test against None explicitly:
        # nn.Sequential defines __len__, so a module's truthiness is not a
        # reliable "has been set" check. Copying a client onto itself is a
        # no-op, so that case is skipped.
        if previous_client is not None and previous_client is not client:
            client.load_state_dict(previous_client.state_dict())

        ### execute client - feed forward network
        intermediate = client(data)
        # Cut the graph: the server's backward stops at `remote`, and its
        # gradient is then pushed through the client explicitly below.
        remote = intermediate.detach().requires_grad_()
        ### execute server - feed forward network
        output = server(remote)
        loss = criterion(output, target)
        loss.backward()
        ### execute client back propagation
        intermediate.backward(remote.grad)
        ### optimiser step
        server_optimiser.step()
        client_optimiser.step()
        previous_client = client

        pred = output.argmax(dim=1, keepdim=True)  # index of the highest score
        correct += pred.eq(target.view_as(pred)).sum().item()
        ### printing logs
        if batch_idx % 300 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    print('\nTrain set for client: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))

def adjust_learning_rate(optimizer, init_lr, epoch, step_size=40, gamma=0.5):
    """Apply a step-decay schedule: lr = init_lr * gamma ** (epoch // step_size).

    With the defaults, the rate halves every 40 epochs. The new rate is
    written into every param group of `optimizer` and also returned.
    """
    lr = init_lr * (gamma ** (epoch // step_size))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

def test_model(clients,server,device,test_loader):
    for client in clients:
        client.eval()
    server.eval()
    criterion = nn.CrossEntropyLoss()
    for client in clients:
        with torch.no_grad():
            test_loss = 0
            correct = 0
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                intermediate = client(data)
                output = server(intermediate)
                test_loss += criterion(output, target).item()  # sum up batch loss
                pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
                correct += pred.eq(target.view_as(pred)).sum().item()
            
            test_loss /= len(test_loader.dataset)
            print('\nTest set for client: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
                test_loss, correct, len(test_loader.dataset),
                100. * correct / len(test_loader.dataset)))

if __name__ == "__main__":
    # Seed Python's and PyTorch's RNGs and force deterministic cuDNN kernels.
    random.seed(7)
    torch.manual_seed(7)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    dev = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(dev)

    train_loader, test_loader = get_cifar10_dataset(batch_size=128)

    # Five clients, each with its own VGG16 front end and optimiser; they take
    # turns on consecutive batches inside train_model().
    clients = []
    client_optimisers = []
    for i in range(5):
        client = vgg_client(input_channels=3, output_channels=10).to(device)
        client_optim = optim.SGD(client.parameters(), lr=1e-3, weight_decay=5e-4)
        clients.append(client)
        client_optimisers.append(client_optim)

    # A single shared server consumes the clients' 256-channel activations.
    server = vgg_server(input_channels=256, output_channels=10).to(device)
    server_optim = optim.SGD(server.parameters(), lr=1e-3, weight_decay=5e-4)

    print("\n\nModel Summary")
    print("\nClient:")
    summary(clients[0], (3, 32, 32))
    print("\n\nServer:")
    summary(server, (256, 4, 4))

    init_lr = 0.1
    # Epochs 1..49; every optimiser follows the same step-decay schedule.
    for i in range(1, 50):
        for optimiser in client_optimisers:
            adjust_learning_rate(optimiser, init_lr, i)
        adjust_learning_rate(server_optim, init_lr, i)
        train_model(clients, client_optimisers, server, server_optim, device, train_loader, i)
        test_model(clients, server, device, test_loader)

    checkpoint = {
        'server': server.state_dict(),
        'client1': clients[0].state_dict(),
        'client2': clients[1].state_dict(),
        'client3': clients[2].state_dict(),
        'client4': clients[3].state_dict(),
        'client5': clients[4].state_dict(),
    }
    torch.save(checkpoint, "split-model-multi-agent.pt")

Files already downloaded and verified


Model Summary

Client:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,792
       BatchNorm2d-2           [-1, 64, 32, 32]             128
              ReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,928
       BatchNorm2d-5           [-1, 64, 32, 32]             128
              ReLU-6           [-1, 64, 32, 32]               0
         MaxPool2d-7           [-1, 64, 16, 16]               0
            Conv2d-8          [-1, 128, 16, 16]          73,856
       BatchNorm2d-9          [-1, 128, 16, 16]             256
             ReLU-10          [-1, 128, 16, 16]               0
           Conv2d-11          [-1, 128, 16, 16]         147,584
      BatchNorm2d-12          [-1, 128, 16, 16]             256
             ReLU-13          [-1, 128, 