In [1]:
import wandb
wandb.login()
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import RandomSampler

from tqdm import tqdm
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
from os.path import isfile

from data_preprocess import *
from model import NN
from dshap import *
from utilities import *
from fl_methods import *

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed the RNGs this notebook draws from directly (torch, e.g. for model weight
# init; numpy for client selection and minibatch sampling) so runs are repeatable.
random_seed = 0 # log with wandb
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# Hyperparameters collected here are passed to wandb.init as each run's config.
wandb_config = {'random_seed': random_seed}

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpringlesinghal[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Load Data

In [2]:
# Model input/output sizes: presumably 28x28 MNIST images flattened to
# 28 * 28 = 784 features, and the 10 digit classes.
input_dim = 28 * 28
output_dim = 10

# Train / validation / test splits used by the rest of the notebook.
train_data_global, val_data_global, test_data_global = load_mnist()

files already downloaded


In [3]:
# Federated setup: partition the training set across clients.
# alpha controls how non-IID the split is (presumably a Dirichlet concentration,
# where a very large value such as 1e6 gives a near-IID split; confirm in NIIDClientSplit).
num_clients = 20 # log with wandb
alpha = 1e6 # log with wandb
num_batches = 10 # minibatches per client; log with wandb
wandb_config['num_clients'] = num_clients
wandb_config['alpha'] = alpha
wandb_config['num_batches'] = num_batches
client_indices = NIIDClientSplit(train_data=train_data_global, num_clients=num_clients, alpha=alpha)
# client_indices_batched[i][k] holds the training-sample indices of client i's k-th minibatch
client_indices_batched = DivideIntoBatches(client_indices, num_batches = num_batches)


## Compute Shapley values for all minibatches


In [4]:
# Data Shapley value of every minibatch of every client:
# dshap[i][k] is the value of client i's minibatch client_indices_batched[i][k].
dshap = {}
for i in range(num_clients):
  print(f"Starting {i}...........")
  minibatch_values = DShapMiniBatches(client_indices_batched[i], client_indices[i], train_data_global)
  # Store as a plain list ordered by minibatch index; it is later turned into
  # importance-sampling weights with np.array(dshap[i]). The minibatch count
  # comes from the batching itself rather than an external constant.
  dshap[i] = [minibatch_values[k] for k in range(len(client_indices_batched[i]))]

Starting 0...........


model.NN

In [None]:
# Loss used both for client training steps and for server evaluation.
criterion = nn.CrossEntropyLoss()
# A single initial model, deep-copied by every experiment so all runs start from the same weights.
initModel = NN(input_dim = input_dim, output_dim = output_dim)
# Protocol variants compared by the experiment loop.
experiments = ['random', 'server-selection', 'client-sampling', 'client-transmission']

In [None]:
# Move the full test set to the training device once, up front; every
# model_accuracy evaluation in the experiment loop reuses these tensors.
test_data = test_data_global.data.to(device = device)
test_targets = test_data_global.targets.to(device = device)

learning_rate = 0.001             # candidate for a wandb hyperparameter sweep
num_communication_rounds = 400
server_select_fraction = 0.1      # fraction of clients the server contacts per round
client_iterations = 1             # >1 for FedAvg, 1 for FedSGD

# Record every training hyperparameter in the shared wandb config.
wandb_config.update(
  learning_rate=learning_rate,
  num_communication_rounds=num_communication_rounds,
  server_select_fraction=server_select_fraction,
  client_iterations=client_iterations,
)

# Per-experiment result histories, appended in the order of `experiments`.
loss_lists, accuracy_lists, avg_active_clients_list = [], [], []

for experiment in experiments:
  config = deepcopy(wandb_config)
  config['experiment'] = experiment

  # Each experiment switches exactly one protocol component to its Shapley-based
  # variant; everything else stays at the baseline used by the 'random' experiment.
  server_selection = "random" # or "shapley"
  client_sampling = "random" # or "shapley"
  client_transmission = "always" # or "shapley"

  if experiment == 'server-selection':
    # shapley based server selection is compared with random selection
    server_selection = "shapley"
  elif experiment == 'client-sampling':
    # shapley based importance sampling is compared with random sampling
    client_sampling = "shapley"
  elif experiment == 'client-transmission':
    # shapley based thresholding is compared with transmitting always
    shapley_threshold = 0.5
    config['shapley_threshold'] = shapley_threshold
    client_transmission = "shapley"

  config['server_selection'] = server_selection
  config['client_sampling'] = client_sampling
  config['client_transmission'] = client_transmission

  wandb.init(
    # Set the project where this run will be logged
    project="shapley-fl",
    # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
    name=f"{experiment}-1",
    # Track hyperparameters and run metadata
    config=config)

  # Initialise the server model from the shared initial weights and record
  # its initial test loss / accuracy.
  serverModel = deepcopy(initModel).to(device)
  serverOptimiser = optim.Adam(serverModel.parameters(), lr = learning_rate)

  loss, accuracy = model_accuracy(serverModel, test_data, test_targets, criterion, device)

  loss_list = [loss]
  accuracy_list = [accuracy]
  avg_active_clients = 0  # running mean of the number of transmitting clients per round

  num_selected = int(server_select_fraction*num_clients)
  # Rounds of forced random selection before Shapley selection takes over
  # (3x the rounds needed to contact each client once on average). This is presumably
  # so that client_values is populated before clients are ranked by it.
  warmup_rounds = 3*(1/server_select_fraction)
  if server_selection == "shapley":
    client_values = [0 for _ in range(num_clients)]

  for communication_round in tqdm(range(num_communication_rounds)):
    '''
    At server:
    - select clients to communicate with
    - two schemes: random selection, GTG-Shapley
    '''
    if communication_round < warmup_rounds or server_selection == "random":
      # random client selection: exactly num_selected clients, uniformly at random
      selected_clients = [(k < num_selected) for k in range(num_clients)]
      np.random.shuffle(selected_clients)
    elif server_selection == "shapley":
      # GTG-Shapley client selection: the num_selected highest-valued clients
      top_clients = np.argpartition(client_values, -1*num_selected)[-1*num_selected:]
      selected_clients = [False for _ in range(num_clients)]
      for k in top_clients:
        selected_clients[k] = True

    client_gradients = []
    sent_status = []

    for i in range(num_clients):
      if not selected_clients[i]:
        # the client is not selected: it neither trains nor transmits
        sent_status.append(False)
        client_gradients.append(None)
        continue

      # Fresh client model initialised from the aggregator weights. The SGD
      # optimiser is created anew each round (no optimiser state is copied).
      clientModel = NN(input_dim = input_dim, output_dim = output_dim)
      clientModel = clientModel.to(device)
      clientModel.load_state_dict(serverModel.state_dict())
      clientOptimiser = optim.SGD(clientModel.parameters(), lr = learning_rate)
      # NOTE(review): zero_grad() runs once, before the local loop, so with
      # client_iterations > 1 gradients accumulate across local steps. Confirm this
      # is what clientModel.gradients() is meant to report.
      clientOptimiser.zero_grad()
      dshap_chosen = 0
      for _ in range(client_iterations):
        # select client minibatch
        if client_sampling == "random":
          # choose the minibatch for client i uniformly at random
          vals = list(range(len(client_indices_batched[i]))) # list of minibatches (len = # minibatches)
          minibatch_idx = np.random.choice(vals, size = 1)[0] # draw a random minibatch k
        elif client_sampling == "shapley":
          # Choose the minibatch for client i by importance sampling on its Shapley
          # value. Non-positive values are clamped to 1e-7, so every weight is > 0
          # and the normalising sum can never be zero.
          wts = np.array(dshap[i])
          wts = np.array([max(w, 1e-7) for w in wts])
          vals = list(range(len(dshap[i]))) # list of minibatches (len = # minibatches)
          wts = wts/np.sum(wts)
          minibatch_idx = np.random.choice(vals, size = 1, p = wts)[0] # draw a random minibatch k
        # Record the chosen minibatch's Shapley value under EVERY sampling scheme:
        # the "shapley" transmission rule below depends on it.
        dshap_chosen += dshap[i][minibatch_idx]

        # one local SGD step on the chosen minibatch (for both sampling schemes)
        batch_indices = client_indices_batched[i][minibatch_idx]
        data_raw_minibatch = [train_data_global.data[j] for j in batch_indices]
        targets_raw_minibatch = [int(train_data_global.targets[j]) for j in batch_indices]
        data_minibatch = torch.stack(data_raw_minibatch,0).to(device=device).to(torch.float32)
        targets_minibatch = torch.tensor(targets_raw_minibatch).to(device=device)

        scores = clientModel(data_minibatch)
        batch_loss = criterion(scores, targets_minibatch)
        batch_loss.backward()
        clientOptimiser.step()

      # mean Shapley value of the minibatches trained on this round
      dshap_chosen /= client_iterations
      if client_transmission == "always":
        transmit = True
      elif client_transmission == "shapley":
        transmit = bool(dshap_chosen > shapley_threshold)
      sent_status.append(transmit)
      client_gradients.append(clientModel.gradients() if transmit else None)
    # end gradient collection loop
    avg_active_clients = ((communication_round)*avg_active_clients + np.sum(sent_status))/(communication_round + 1)

    # combine gradients and take a step at aggregator and update aggregator weights
    if server_selection == "random":
      aggregator_update(client_gradients, sent_status, model=serverModel, optimiser=serverOptimiser)
    elif server_selection == "shapley":
      client_values = aggregator_update_shapley(client_values, client_gradients, sent_status, model=serverModel, optimiser=serverOptimiser, val_data = val_data_global, criterion = criterion, device = device)

    loss, accuracy = model_accuracy(serverModel, test_data, test_targets, criterion, device)
    loss_list.append(loss)
    accuracy_list.append(accuracy)

    # A single wandb.log call per round, so wandb's step counter tracks communication_round.
    metrics = {'loss': loss, 'accuracy': accuracy, 'average_active_clients': avg_active_clients}
    if server_selection == "shapley":
      # client_values only exists under Shapley server selection
      metrics.update({f'client_values_{i}':client_values[i] for i in range(num_clients)})
    wandb.log(metrics)
    if communication_round % 10 == 0:
      print(f"loss = {loss:.4f}, accuracy = {accuracy*100:.3f}")

  print(f"avg_active_clients = {avg_active_clients}")
  # close this experiment's run so the next wandb.init starts a separate run
  wandb.finish()

  loss_lists.append(loss_list)
  accuracy_lists.append(accuracy_list)
  avg_active_clients_list.append(avg_active_clients)

In [None]:
# Scratch check: elements of {1, 2, 3} that are not in {1, 2}.
list({1, 2, 3}.difference({1, 2}))

[3]