In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import math
import syft as sy  # <-- NEW: import the Pysyft library
# Hook PyTorch, i.e. let PySyft add the extra functionality needed to support Federated Learning.
hook = sy.TorchHook(torch)



In [3]:
epochs = 200       # number of scheduling rounds (time slots) executed by the main loop
num_clients = 100  # number of simulated clients; each one owns its own virtual worker

In [4]:
# Communication / system-model constants.
B = 10          # Total bandwidth, kHz; split into downlink (B / d) and uplink (B * (1 - 1 / d))
N_0 = -100      # Noise spectrum power density (not read by the code visible in this notebook)
h = -117        # Channel gain (not read by the code visible in this notebook)
SNR = 20        # Assumed SNR; enters the bandwidth allocation as log(1 + SNR)
m = 1.2         # Model size, Mb; placeholder that is overwritten once the size is computed from the model
epsilon = 1     # Effective capacitance parameter of CPU (not read by the code visible in this notebook)

In [5]:
# Scheduler constants.
# gamma = 1       # Training time constraint constant
# rho = 1         # Staleness Decay
tau = 1         # Time length of an epoch; used in the feasibility check of downlink()/uplink()
lbd = 0.7       # Reputation update weight, < 1 (not read by the code visible in this notebook)
psi = 0.3       # Uncertain reputation ratio, < 1 (not read by the code visible in this notebook)
# w1..w5 weight the five terms of the rank score computed in UpdateRank():
w1 = 1          # weight of the reputation term (rp)
w2 = 1          # weight of the age-of-update term (AoU)
w3 = 1          # weight of the fairness term (fn)
w4 = 1          # weight of the normalised computation-time term
w5 = 1          # weight of the data-size term (ds)
phi = 1.03      # Expected training time constant, > 1; determine() returns phi raised to the total staleness

In [6]:
class Arguments:
    """Container for the hyper-parameters and run-time switches of the experiment."""

    def __init__(self):
        # --- data loading ---
        self.batch_size = 64            # mini-batch size of the federated training loader
        self.test_batch_size = 1000     # mini-batch size of the test loader

        # --- optimisation ---
        self.epochs = epochs            # taken from the module-level `epochs` setting
        self.lr = 0.01                  # learning rate handed to the SGD optimizer
        self.momentum = 0.5             # kept for reference; the optimizer is built without momentum

        # --- runtime switches ---
        self.no_cuda = False            # set to True to force CPU even when CUDA is available
        self.seed = 1                   # seed passed to torch.manual_seed
        self.log_interval = 30          # only referenced by commented-out logging code
        self.save_model = False         # when True the main loop writes "mnist_cnn.pt"

args = Arguments()

# Use CUDA only when the settings allow it and a CUDA device is actually available.
use_cuda = not args.no_cuda and torch.cuda.is_available()

# Seed PyTorch's random number generator with the configured seed.
torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

# Extra DataLoader keyword arguments; only supplied when running on CUDA.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}


In [7]:
class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by two fully connected layers.

    ``forward`` returns log-probabilities over 10 classes.  The hard-coded
    flatten size ``4*4*50`` corresponds to single-channel 28x28 inputs
    (28 -> 24 -> 12 -> 8 -> 4 through the two conv/pool stages).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # weight initialisation yields the same parameters.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities."""
        features = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2, 2)
        flat = features.view(-1, 4 * 4 * 50)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)


# Global model instance, moved to the selected device.
model = Net().to(device)

# Plain SGD over the global model's parameters; args.momentum is intentionally not passed (see TODO).
optimizer = optim.SGD(model.parameters(), lr=args.lr) # TODO momentum is not supported at the moment

In [8]:
# Model size: total element count of the state_dict times 4 (presumably bytes
# per float32 element -- confirm), divided by 10^6.  This replaces the
# placeholder value of `m` assigned earlier.
m = sum(tensor.nelement() for tensor in model.state_dict().values())
m = m * 4 /1000 / 1000  # in Mb

In [9]:
class client:
    """State record for one simulated federated-learning client.

    Holds the client's radio/CPU parameters, the bookkeeping fields used by
    the scheduler (time stamp, age of update, fairness, reputation, rank
    score), its PySyft virtual worker and references to the model.

    Args:
        index: position of the client in the global ``clients`` list; also
            used to build the worker id ``"c<index>"``.
        p: transmission energy parameter.
        fq: local CPU frequency, GHz.
    """

    def __init__(self, index, p = 0.2, fq = 0.7):
        # --- identity and PySyft endpoint ---
        self.index = index                  # index
        self.worker = sy.VirtualWorker(hook, id="c{}".format(index))

        # --- communication ---
        self.p = p                          # parameter, transmission energy
        self.bw = 0                         # variable, allocated bandwidth, kHz
        self.a = 0                          # parameter, coefficient over b in the formula of energy

        # --- computation ---
        self.fq = fq                        # parameter, local CPU frequency, GHz
        self.c_i = 2 * m * 20 * 600         # number of CPU cycles to finish the training task
        self.c_n = 20                       # number of iterations done in the local training
        self.ds = 600                       # parameter, data size
        self.T_comp = 0                     # parameter, expected training time

        # --- scheduling bookkeeping ---
        self.ts = -1                        # parameter, time stamp, -1 -> not in a training
        self.AoU = -1                       # parameter, age of update
        self.fn = 0                         # parameter, fairness
        self.R = 0                          # parameter, rank score

        # --- reputation ---
        self.rp = 0.8                       # parameter, reputation
        self.bf = 0                         # parameter, belief reputation
        self.dbf = 0                        # parameter, disbelief reputation
        self.unc = 0                        # parameter, uncertain reputation

        # --- model references ---
        # NOTE(review): both attributes are bound to the single module-level
        # `model` object (no copy is made), so every client shares it --
        # confirm this is intended.
        self.local_model = model            # model returned after the local training
        self.stale_model = model            # model received before the local training

In [10]:
# One client record (and virtual worker) per index 0 .. num_clients - 1.
clients = [client(idx) for idx in range(num_clients)]

In [11]:
# Training data: the MNIST training split is distributed over the clients'
# virtual workers by .federate() and served through PySyft's FederatedDataLoader.
federated_train_loader = sy.FederatedDataLoader( # <-- this is now a FederatedDataLoader 
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
    .federate( tuple(ct.worker for ct in clients) ), # <-- NEW: we distribute the dataset across all the workers, it's now a FederatedDataset
    batch_size=args.batch_size, shuffle=True, **kwargs)

# Test data: the MNIST test split served by an ordinary (local) DataLoader,
# with the same tensor conversion and normalisation as the training data.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                    transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)



In [12]:
import time

In [13]:
def train(args, model, device, federated_train_loader, optimizer, epoch, clients, index):
    """Run one local training burst for ``clients[index]`` and report its duration.

    The loader is scanned for the first batch whose ``location`` is the
    client's worker.  The client's ``local_model`` is sent to that worker,
    trained for ``c_n`` optimizer steps on that single batch, and fetched
    back.  The measured wall-clock time (times 4) is stored in the client's
    ``T_comp``.

    Args:
        args: experiment settings (unused here, kept for interface compatibility).
        model: global model (unused here, kept for interface compatibility).
        device: device the batch tensors are moved to.
        federated_train_loader: iterable of ``(data, target)`` batches whose
            ``data.location`` identifies the worker holding the batch.
        optimizer: optimizer used for the local steps.
        epoch: current epoch index (1-based in the caller).
        clients: list of client records.
        index: position of the client to train in ``clients``.

    Returns:
        int: ``ceil(T_comp / epoch)``, which the caller uses as the number of
        remaining time slots.  ``0`` when the loader contains no batch for
        this client (previously the function fell through and returned
        ``None``, which made the caller's ``remain_time + 1`` raise).
    """
    ct = clients[index]

    for data, target in federated_train_loader: # <-- now it is a distributed dataset

        # Skip batches that live on another client's worker.
        if data.location != ct.worker:
            continue

        # NOTE(review): this binds stale_model to the same object as
        # local_model (no copy is taken) -- confirm this is intended.
        ct.stale_model = ct.local_model
        ct.local_model.train()

        ct.local_model.send(data.location) # <-- NEW: send the model to the right location
        data, target = data.to(device), target.to(device)

        time_start = time.time()
        for _ in range(ct.c_n):
            optimizer.zero_grad()
            output = ct.local_model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
        time_end = time.time()

        # Measured wall-clock time scaled by a fixed factor of 4.
        T_comp = 4*(time_end - time_start)
        ct.T_comp = T_comp

        ct.local_model.get() # <-- NEW: get the model back

        print('Train Epoch: {} Client Index: {} Time: {}'.format(
            epoch, index, T_comp))

        # Only the first matching batch is used per call.
        # NOTE(review): the duration is divided by the epoch *index*; if the
        # intent is "number of slots of length tau", this should probably be
        # `/ tau` -- confirm before changing.
        return math.ceil(T_comp/epoch)

    # No batch of the loader belongs to this client: nothing was trained.
    return 0

In [14]:
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
    return test_loss, correct/len(test_loader.dataset)

In [15]:
def UpdateRank(RankList, clients, total_trained):
    """Recompute the rank score of every listed client and sort by it.

    For each ``[client index, score]`` entry the client's ``R`` is rebuilt
    as a weighted sum (weights ``w1..w5``) of its reputation, a sigmoid of
    its age of update, a sigmoid of its fairness gap, a normalised
    computation time and its relative data size; the entry's score is then
    overwritten with that value.

    Args:
        RankList: list of ``[client index, score]`` pairs (updated in place).
        clients: list of client records (their ``R`` is updated in place).
        total_trained: count compared against each client's ``fn`` in the
            fairness term.

    Returns:
        tuple: ``(new list sorted by descending score, clients)``.
    """
    # Normalisation bounds of the computation time: the same cycle count
    # evaluated at frequencies 0.5 and 1.0.  They do not depend on the client.
    max_T_comp = 2 * m * 20 * 600 / 0.5
    min_T_comp = 2 * m * 20 * 600 / 1.0

    for entry in RankList:
        ct = clients[entry[0]]

        time_term = (ct.c_i / ct.fq - min_T_comp - (min_T_comp + max_T_comp)/2) / (max_T_comp - min_T_comp)
        aou_term = 1 / (1 + np.exp(-(ct.AoU - 7))) - 1
        fairness_term = 1 / (1 + np.exp(-0.2 * (total_trained/num_clients - ct.fn))) - 1

        ct.R = w1 * ct.rp + w2 * aou_term\
                        + w3  * fairness_term\
                        + w4 * time_term + w5 * (ct.ds - 600) / 600
        entry[1] = ct.R

    # Highest score first; sorted() is stable, so ties keep their order.
    ranked = sorted(RankList, key=lambda x:-x[1])
    return ranked, clients

In [16]:
def determine(WaitingCache, TrainingClients, clients, epoch):
    """Return ``phi`` raised to the current total staleness.

    The total is the sum, over every client still training, of the part of
    its ``T_comp`` not yet covered by the epochs elapsed since its time
    stamp (never negative), plus one for every client in the waiting cache.

    Args:
        WaitingCache: client indices waiting for uplink.
        TrainingClients: records whose first element is a client index.
        clients: list of client records.
        epoch: current epoch index.

    Returns:
        The value ``pow(phi, total staleness)``.
    """
    outstanding = sum(
        max(clients[rec[0]].T_comp - (epoch - clients[rec[0]].ts), 0)
        for rec in TrainingClients
    )
    return pow(phi, outstanding + len(WaitingCache))

In [17]:
def downlink(clients, d, RankList, TrainingClients, epoch, model):
    """Admit top-ranked clients to the downlink, share the bandwidth and train them.

    Steps:
      1. The downlink gets ``B / d`` of the bandwidth.
      2. Every ranked client's coefficient ``a`` is refreshed.
      3. Clients are admitted in rank order until the feasibility check
         against ``tau * log(1 + SNR) * bandwidth / m`` fails.
      4. Each admitted client gets bandwidth proportional to ``sqrt(a)``,
         has its bookkeeping reset, is trained via ``train`` and moved from
         ``RankList`` to ``TrainingClients``.

    Args:
        clients: list of client records (mutated in place).
        d: bandwidth split factor as produced by ``determine``.
        RankList: ``[client index, score]`` pairs in rank order (admitted
            entries are popped from the front).
        TrainingClients: list receiving ``[client index, epoch, remaining slots]``.
        epoch: current epoch index.
        model: global model handed to the admitted clients.

    Returns:
        tuple: ``(clients, RankList, TrainingClients, Downlink_Bandwidth)``
        where the last item lists ``[client index, bandwidth]`` pairs.
    """
    downlink_bw = B / d

    for entry in RankList:
        ct = clients[entry[0]]
        ct.a = m * ct.p / math.log(1 + SNR)

    # Admission: grow the candidate set one client at a time; stop at the
    # first candidate whose addition breaks the check for an earlier member.
    num_downlink = 0
    sum_sqrt_a_i = 0
    for i, entry in enumerate(RankList):
        root_i = math.sqrt(clients[entry[0]].a)
        sum_sqrt_a_i += root_i
        infeasible = any(
            sum_sqrt_a_i / math.sqrt(clients[RankList[j][0]].a) > tau * math.log(1 + SNR) * downlink_bw / m
            for j in range(i)
        )
        if infeasible:
            num_downlink = i
            sum_sqrt_a_i -= root_i      # the rejected candidate does not count
            break
        num_downlink = i + 1

    Downlink_Bandwidth = []

    print("num_DL:", num_downlink)

    for _ in range(num_downlink):
        # Always serve the current head of the list; it is popped afterwards.
        idx = RankList[0][0]
        ct = clients[idx]

        ct.bw = downlink_bw * math.sqrt(ct.a) / sum_sqrt_a_i
        ct.stale_model = model
        Downlink_Bandwidth.append([idx, ct.bw])
        ct.ts = epoch
        ct.AoU = 0
        ct.fn += 1

        remain_time = train(args, model, device, federated_train_loader, optimizer, epoch, clients, idx)

        TrainingClients.append([idx, epoch, remain_time + 1])
        RankList.pop(0)

    return clients, RankList, TrainingClients, Downlink_Bandwidth

In [18]:
def uplink(clients, d, WaitingCache, epoch, model):
    """Admit waiting clients to the uplink and merge their updates into ``model``.

    Steps:
      1. The uplink gets ``B * (1 - 1 / d)`` of the bandwidth.
      2. Every waiting client's coefficient ``a`` is refreshed.
      3. Clients are admitted in cache order until the feasibility check
         against ``tau * log(1 + SNR) * bandwidth / m`` fails.
      4. Each admitted client gets bandwidth proportional to ``sqrt(a)``,
         its age of update is set to the epochs elapsed since its time
         stamp, and its update ``local_model - stale_model`` is added to the
         global model, scaled by ``0.95 ** AoU``.

    Fixes compared with the previous version:
      * The aggregation used ``model.state_dict()[key] = ...``, which only
        rebinds an entry of a temporary dict and never changes the model.
        The new weights are now written back with ``load_state_dict``.
      * The staleness factor multiplied the whole expression, i.e. it would
        have shrunk the global weights themselves; it now scales only the
        client's update.

    Args:
        clients: list of client records (mutated in place).
        d: bandwidth split factor as produced by ``determine``.
        WaitingCache: client indices waiting for uplink (admitted entries
            are popped from the front).
        epoch: current epoch index.
        model: global model, updated in place.

    Returns:
        tuple: ``(clients, WaitingCache, NewRankList, Uplink_Bandwidth, model)``
        where ``NewRankList`` holds ``[client index, rank score]`` for the
        uploaded clients and ``Uplink_Bandwidth`` holds
        ``[client index, bandwidth]`` pairs.
    """
    Avai_BW = B * (1 - 1 / d)

    for idx in WaitingCache:
        clients[idx].a = m * clients[idx].p / math.log(1 + SNR)

    # Admission: grow the candidate set one client at a time; stop at the
    # first candidate whose addition breaks the check for an earlier member.
    num_uplink = 0
    sum_sqrt_a_i = 0
    for i, idx in enumerate(WaitingCache):
        root_i = math.sqrt(clients[idx].a)
        sum_sqrt_a_i += root_i
        infeasible = any(
            sum_sqrt_a_i / math.sqrt(clients[WaitingCache[j]].a) > tau * math.log(1 + SNR) * Avai_BW / m
            for j in range(i)
        )
        if infeasible:
            num_uplink = i
            sum_sqrt_a_i -= root_i      # the rejected candidate does not count
            break
        num_uplink = i + 1

    Uplink_Bandwidth = []
    NewRankList = []

    for _ in range(num_uplink):
        # Always serve the current head of the cache; it is popped afterwards.
        idx = WaitingCache[0]
        ct = clients[idx]

        ct.bw = Avai_BW * math.sqrt(ct.a) / sum_sqrt_a_i
        Uplink_Bandwidth.append([idx, ct.bw])
        ct.AoU = epoch - ct.ts
        ct.ts = -1
        NewRankList.append([idx, ct.R])

        # Staleness-weighted aggregation: add the client's update
        # (local - stale), damped by 0.95 per epoch of age, to the global model.
        decay = pow(0.95, ct.AoU)
        global_state = model.state_dict()
        local_state = ct.local_model.state_dict()
        stale_state = ct.stale_model.state_dict()
        model.load_state_dict({
            key: global_state[key] + decay * (local_state[key] - stale_state[key])
            for key in global_state
        })

        WaitingCache.pop(0)

    return clients, WaitingCache, NewRankList, Uplink_Bandwidth, model

In [20]:
import numpy as np
import os
import matplotlib.pyplot as plt

In [None]:
%%time

# Main simulation loop.  Per epoch: age the clients, move finished trainers to
# the waiting cache, compute the bandwidth split, re-rank, run the uplink and
# the downlink, evaluate the global model, and write logs and plots.

# Scheduling state shared across epochs.
RankList = [[i, clients[i].R] for i in range(num_clients)]  # [client index, rank score] of clients eligible for downlink
WaitingCache = []       # indices of clients that finished training and wait for uplink
TrainingClients = []    # [client index, start epoch, remaining slots] of clients in training
test_loss = []          # test loss per epoch
accuracy = []           # test accuracy per epoch
num_uplink = []         # number of uplinked clients per epoch
num_downlink = []       # number of downlinked clients per epoch
total_trained = 0

for epoch in range(1, args.epochs + 1):

    # Age every client that has taken part in a downlink/uplink (AoU == -1 means "never").
    for ct in clients:
        if ct.AoU != -1: ct.AoU += 1
    # Uplink and Downlink are not read again in the visible code; NewRankList is
    # replaced by the value returned from uplink() below.
    Uplink = []
    Downlink = []
    NewRankList = []

    # Count down the remaining slots of every training client; clients that
    # reach zero move to the waiting cache.  The index only advances when
    # nothing was popped.
    i = 0
    while i < len(TrainingClients):
        TrainingClients[i][2] -= 1
        if TrainingClients[i][2] <= 0:
            WaitingCache.append(TrainingClients[i][0])
            TrainingClients.pop(i)
            continue
        i+=1

    # Bandwidth split factor: downlink gets B / d, uplink gets B * (1 - 1 / d).
    d = determine(WaitingCache, TrainingClients, clients, epoch)

    # NOTE(review): records are appended by downlink() with [1] equal to the
    # epoch in which they were scheduled, and that happens after this check,
    # so on later iterations [1] < epoch and this condition looks like it can
    # never be true -- confirm whether this block is dead code.
    while 1:
        if len(TrainingClients) > 0 and TrainingClients[0][1] >= epoch:
            WaitingCache.append(TrainingClients[0][0])
            TrainingClients.pop(0)
        else:
            break


    RankList, clients = UpdateRank(RankList, clients, total_trained)

    clients, WaitingCache, NewRankList, Uplink_Bandwidth, model = uplink(clients, d, WaitingCache, epoch, model)

    clients, RankList, TrainingClients, Downlink_Bandwidth = downlink(clients, d, RankList, TrainingClients, epoch, model)

    testloss, acc = test(args, model, device, test_loader)

    # NOTE(review): this overwrites total_trained with the number of clients
    # scheduled in this epoch only (it is not accumulated), while UpdateRank()
    # compares it with each client's cumulative fn -- confirm `=` vs `+=`.
    total_trained = len(Downlink_Bandwidth)

    test_loss.append(testloss)
    accuracy.append(acc)

    # Append this epoch's scheduling decisions to the running log.
    with open("Log.txt", "a") as LOG:
        LOG.write("Epoch = " + str(epoch))
        LOG.write("\nd = " + str(d))
        LOG.write("\nRankList = " + str(RankList))
        LOG.write("\nUplink Available Bandwidth = " + str(B * (1 - 1 / d)))
        LOG.write("\nDownlink Available Bandwidth = " + str(B/d))
        LOG.write("\nUplink = [" + str([i[0] for i in Uplink_Bandwidth]) + "]")
        LOG.write("\nDownlink = [" + str([i[0] for i in Downlink_Bandwidth]) + "]\n\n")

    print("Uplink: ", [i[0] for i in Uplink_Bandwidth])
    print("Downlink: ", [i[0] for i in Downlink_Bandwidth])
    num_uplink.append(len(Uplink_Bandwidth))
    num_downlink.append(len(Downlink_Bandwidth))


    # Snapshot of every client's state; opened with "w", so the file only ever
    # holds the most recent epoch.
    with open("Record.txt", "w") as record:
        for ct in clients:
            rc = "Index:" + str(ct.index) + " Bandwidth:" + str(ct.bw) + \
                " Rankscore:" + str(ct.R) + " Time Coefficient:" + str(ct.a) + "\n"
            record.write(rc)

    # Release the bandwidth that was allocated during this epoch.
    for record in Uplink_Bandwidth:
        clients[record[0]].bw = 0
    for record in Downlink_Bandwidth: clients[record[0]].bw = 0


    # Clients that just uploaded become eligible for downlink again.
    RankList += NewRankList

    RankList, clients = UpdateRank(RankList, clients, total_trained)

    # Redraw and save both curves every epoch (figures 1 and 2 are reused).
    plt.figure(1)
    plt.clf()
    plt.title('Evolution of Test Loss')
    plt.xlabel('Time Slot')
    plt.ylabel('Test Loss')
    plt.plot(np.array(test_loss))
    plt.savefig("test_loss.png")

    plt.figure(2)
    plt.clf()
    plt.title('Evolution of Accuracy')
    plt.xlabel('Time Slot')
    plt.ylabel('Accuracy')
    plt.plot(np.array(accuracy))
    plt.savefig("accuracy.png")


    # Inside the epoch loop: when enabled, the checkpoint is rewritten every epoch.
    if (args.save_model):
        torch.save(model.state_dict(), "mnist_cnn.pt")