<a href="https://colab.research.google.com/github/xuyangm/BC-FL/blob/master/FL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [35]:
from torchvision import transforms, datasets
import numpy as np
from torch.utils.data import sampler
from itertools import combinations
import random
import math
import copy

class ClientInfo(object):
    """
    Record some information related to a client
    """

    def __init__(self, client_id):
        self.client_id = client_id
        self.last_involved_round = 0
        self.contribution = 0.0
        self.loss = 0.0
        self.times = 0

train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

test_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

data_train = datasets.CIFAR10('./data/cifar/', train=True, download=True, transform=train_transform)
data_test = datasets.CIFAR10('../data/cifar/', train=False, download=True, transform=test_transform)


def dirichlet_partition(data, num_clients, alpha):
        """Partition by Dirichlet distribution (non-iid)"""
        labels = torch.as_tensor(data.targets)
        num_classes = labels.max() + 1
        label_distribution = np.random.dirichlet([alpha] * num_clients, num_classes)
        class_idx = [np.argwhere(labels == y).flatten() for y in range(num_classes)]
        client_idx = [[] for _ in range(num_clients)]

        for c, fracs in zip(class_idx, label_distribution):
            for i, idx in enumerate(np.split(c, ((np.cumsum(fracs)[:-1]) * len(c)).astype(int))):
                client_idx[i] += [idx]

        partitions = [np.concatenate(idx) for idx in client_idx]
        return partitions

def select_participants(explored, unexplored, sample_sz, explore_ratio, rd, clients_info):
  participants = []
  explore_num = min(len(unexplored), round(explore_ratio * sample_sz))
  exploit_num = sample_sz - explore_num

  if len(explored) < exploit_num:
    participants = random.sample(unexplored, sample_sz)
  else:
    utility = {}
    for cid in explored:
      L = clients_info[cid].last_involved_round
      utility[cid] = clients_info[cid].contribution + math.sqrt(0.1 * math.log(rd, 10) / L)

    unexplored_participants = random.sample(unexplored, explore_num)
    explored_participants = sorted(utility, key=utility.get, reverse=True)[:exploit_num]
    participants = unexplored_participants + explored_participants
  return participants

def cal_shapley_values(loader, local_updates, accuracy, batch_size):
    value_dict = {}
    accuracy_dict = {}
    client_ids = sorted(list(local_updates.keys()))

    for i in range(len(local_updates)):
        cmbs = list(combinations(client_ids, i+1))
        for cmb in cmbs:
            model = average_model(local_updates, cmb)
            accuracy_dict[cmb], _ = test(model, loader)

    for k in local_updates:
        non_k = [x for x in client_ids if x != k]
        counter = 1
        value = accuracy_dict[(k,)] - accuracy

        for j in range(1, len(local_updates)):
            partner_list = combinations(non_k, j)
            for partners in partner_list:
                full_tup = tuple(sorted(partners + (k,)))
                value = value + accuracy_dict[full_tup] - accuracy_dict[partners]
                counter += 1

        value_dict[k] = float(value / counter)

    return value_dict

partitions = dirichlet_partition(data_train, 100, 0.5)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
import torchvision
from torch.utils.data import DataLoader
import torch
from torch.utils.tensorboard import SummaryWriter
from collections import OrderedDict

def train(model, loader, lr, momentum, weight_decay, epoch):
  model = model.cuda()

  optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
  criterion = torch.nn.CrossEntropyLoss().cuda()

  step = 0

  while step < epoch:
    model.train()
    step += 1
    for (X, y) in loader:
      X = X.cuda()
      y = y.cuda()
      output = model(X)
      loss = criterion(output, y)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

  return model.state_dict()

def test(model, loader):
  model = model.cuda()
  criterion = torch.nn.CrossEntropyLoss().cuda()
  model.eval()
  total_accuracy = total_loss = 0
  for (X, y) in loader:
    X = X.cuda()
    y = y.cuda()
    output = model(X)
    total_accuracy += (output.argmax(1) == y).sum()
    loss = criterion(output, y)
    total_loss += loss.item()

  # writer = SummaryWriter("logs/cifar10")
  # writer.add_scalar("loss", total_loss/len(data_test), rd)
  # writer.add_scalar("accuracy(%)", total_accuracy/len(data_test)*100, rd)
  # writer.close()
  return total_accuracy/len(data_test), total_loss/len(data_test)
                 
def average_model(local_updates, cmb):
  model = torchvision.models.shufflenet_v2_x2_0(num_classes=10)
  sz = len(cmb)
  if sz == 1:
    model.load_state_dict(local_updates[cmb[0]])
    return model

  averaged_weights = OrderedDict()
  for it, idx in enumerate(cmb):
    local_weights = local_updates[idx]
    for key in local_weights.keys():
      if it == 0:
        averaged_weights[key] = torch.div(local_weights[key], sz)
      else:
        averaged_weights[key] += torch.div(local_weights[key], sz)

  model.load_state_dict(averaged_weights)
  return model

def random_client_sampler(clients, sample_sz):
  return random.sample(clients, sample_sz)

def random_FL(model, rd, lr, momentum, weight_decay, epoch):
  clients = [_ for _ in range(100)]
  for r in range(rd):
    sampled = random_client_sampler(clients, 5)
    loaders = []
    for i in sampled:
      loaders.append(DataLoader(data_train, batch_size=20, drop_last=True, sampler=sampler.SubsetRandomSampler(partitions[i])))
    local_updates = {}
    for i in range(5):
      local_updates[sampled[i]] = train(model, loaders[i], lr, momentum, weight_decay, epoch)
    m = average_model(local_updates)
    acc, loss = test(m, test_loader, r+1)
    print("Round {}, acc: {}%, loss: {}".format(r+1, acc, loss))

def bandit_FL(model, sample_sz, rd, lr, momentum, weight_decay, epoch, acc, batch_sz):
  clients = [_ for _ in range(100)]
  clients_info = {}
  for i in clients:
    clients_info[i] = ClientInfo(i)
  explored = []
  unexplored = clients
  explore_ratio = 0.9
  for cur_rd in range(1, rd+1):
    sampled = select_participants(explored, unexplored, sample_sz, explore_ratio, cur_rd, clients_info)
    print(sampled)
    loaders = []
    for i in sampled:
      loaders.append(DataLoader(data_train, batch_size=batch_sz, drop_last=True, sampler=sampler.SubsetRandomSampler(partitions[i])))
    local_updates = {}
    for i in range(sample_sz):
      local_updates[sampled[i]] = train(model, loaders[i], lr, momentum, weight_decay, epoch)
    model = average_model(local_updates, local_updates.keys())
    acc, loss = test(model, test_loader)
    print("Round {}, acc: {}%, loss: {}".format(cur_rd, acc*100, loss))
    print("Calculate shapley value")
    values_dict = cal_shapley_values(test_loader, local_updates, acc, batch_sz)
    contrib = {}
    for participant in sampled:
      clients_info[participant].last_involved_round = cur_rd
      clients_info[participant].contribution = (clients_info[participant].contribution * clients_info[participant].times + values_dict[participant]) / (clients_info[participant].times + 1)
      clients_info[participant].times += 1
      contrib[participant] = clients_info[participant].contribution

    for cid in contrib:
      if cid in unexplored:
        unexplored.remove(cid)
        explored.append(cid)

    iter_explored = copy.deepcopy(explored)
    for cid in iter_explored:
      if clients_info[cid].times >= 10:
        unexplored.append(cid)
        explored.remove(cid)
        clients_info[cid].times = 0
        clients_info[cid].last_involved_round = 0

    explore_ratio *= 0.98
    explore_ratio = max(explore_ratio, 0.2)

rd = 10
epoch = 10
batch_size = 20
lr = 1e-2
momentum = 0.9
weight_decay = 4e-5

test_loader = DataLoader(data_test, batch_size, shuffle=False, pin_memory=True, drop_last=False)
model = torchvision.models.shufflenet_v2_x2_0(num_classes=10)
acc, _ = test(model, test_loader)
print("initial acc: {}%".format(acc*100))
bandit_FL(model, 5, rd, lr, momentum, weight_decay, epoch, acc, 20)

initial acc: 9.999999046325684%
[48, 21, 63, 43, 51]
Round 1, acc: 26.219999313354492%, loss: 0.15444363471269607
Calculate shapley value
[99, 92, 14, 84, 48]
Round 2, acc: 29.309999465942383%, loss: 0.17149744911193848
Calculate shapley value
