## Installs and Imports

In [2]:
!nvidia-smi

NVIDIA-SMI has failed because you are not:
	a) running as an administrator or
	b) there is not at least one TCC device in the system



In [None]:
import jax
import jax.numpy as jnp
import haiku as hk 
import numpy as np
import fedjax
from tqdm import tqdm
import random
from jax.tree_util import tree_multimap
from typing import NamedTuple , Callable, Any, Tuple, Sequence, Mapping, Any, Optional
from copy import deepcopy
import wandb

## Installs and Imports

In [None]:
# Install wandb, dm-haiku and fedjax (comment these lines out if the packages are already installed).
!pip install --quiet wandb
!pip install --quiet dm-haiku
!pip install --quiet fedjax

In [None]:
import jax
import jax.numpy as jnp
import haiku as hk 
import numpy as np
import torch 
import torchvision
import fedjax
from tqdm import tqdm
import random
from jax.tree_util import tree_multimap
from typing import NamedTuple , Callable, Any, Tuple, Sequence, Mapping, Any, Optional
from copy import deepcopy

## Datasets

In [None]:
def load_dataset(hyper: NamedTuple):
    """Placeholder for a generic federated-dataset loader (not implemented).

    Intended to return federated datasets synthesized from common datasets
    by partitioning them with a Dirichlet distribution. The original note
    ended mid-sentence ("according to."), so the reference it meant to cite
    is unknown.

    Args:
        hyper: Hyperparameter tuple; currently unused.

    Returns:
        None. The function performs no work yet.
    """
    return None


def load_emnist(hyper: NamedTuple):
    """Load federated EMNIST into memory, optionally corrupting some clients.

    The built-in fedjax federated EMNIST data (``only_digits=False``) is
    copied client by client into plain dictionaries keyed by the client's
    integer position in ``train.client_ids()``. When ``hyper.adversary`` is
    True, the client indices are shuffled with ``random.shuffle`` and the
    first ``round(hyper.corrupt_ratio * num_clients)`` of them become
    adversaries: their training labels are replaced by integers drawn
    uniformly from [0, 62) with ``np.random.randint``.

    Args:
        hyper: Hyperparameter tuple; reads ``adversary`` and ``corrupt_ratio``.

    Returns:
        Tuple ``(train, test, adversary_list)`` where ``train`` and ``test``
        are ``fedjax.InMemoryFederatedData`` objects and ``adversary_list``
        holds the integer ids of the corrupted clients (empty when no
        corruption is requested).
    """
    fed_train, fed_test = fedjax.datasets.emnist.load_data(only_digits=False)

    train_by_index, test_by_index = {}, {}
    for position, original_id in enumerate(list(fed_train.client_ids())):
        train_examples = fed_train.get_client(original_id).all_examples()
        test_examples = fed_test.get_client(original_id).all_examples()
        # Keep only the feature and label arrays, re-keyed by integer position.
        train_by_index[position] = {field: train_examples[field] for field in ('x', 'y')}
        test_by_index[position] = {field: test_examples[field] for field in ('x', 'y')}

    adversary_list = []
    if hyper.adversary == True:
        index_pool = list(train_by_index.keys())
        random.shuffle(index_pool)
        num_adversaries = round(hyper.corrupt_ratio * len(index_pool))
        adversary_list = index_pool[:num_adversaries]

    # Label-flipping attack: overwrite the adversaries' training labels.
    for adversary in adversary_list:
        label_shape = train_by_index[adversary]['y'].shape
        train_by_index[adversary]['y'] = np.random.randint(0, 62, size=label_shape)

    return (
        fedjax.InMemoryFederatedData(train_by_index),
        fedjax.InMemoryFederatedData(test_by_index),
        adversary_list,
    )

## Aggregators

In [None]:
def krum(tree_list: Sequence[fedjax.Params], m:int=4, f:int=2):
    """Multi-Krum style robust aggregation of parameter pytrees.

    Every update is scored by the sum of its L2 distances to its nearest
    neighbours; the ``m`` updates with the lowest scores are then averaged
    with equal weight.

    Args:
        tree_list: Pytrees (all with the same structure) to aggregate.
        m: Number of lowest-scoring updates that are averaged.
        f: Assumed number of Byzantine updates. Each score sums the
            ``n - f`` smallest distances to *other* updates, capped at
            ``n - 1`` so an update's (infinite) self-distance is never
            included and floored at 0.

    Returns:
        The equally weighted mean (``fedjax.tree_util.tree_mean``) of the
        selected updates.

    Raises:
        ValueError: If ``tree_list`` is empty.
    """
    # NOTE(review): the Krum paper (Blanchard et al., 2017) scores with
    # *squared* distances to the n - f - 2 nearest neighbours. This keeps the
    # notebook's n - f / unsquared variant -- confirm which one is intended.
    n = len(tree_list)
    if n == 0:
        raise ValueError('krum requires at least one update to aggregate')

    @jax.jit
    def tree_l2_dist(x, y):
        # tree_map accepts several trees; tree_multimap was removed from JAX.
        return fedjax.tree_util.tree_l2_norm(
            jax.tree_util.tree_map(lambda a, b: a - b, x, y))

    # Fill the symmetric distance matrix once on the host. The diagonal stays
    # at +inf so an update is never counted as its own neighbour.
    pairwise = np.full((n, n), np.inf, dtype=np.float32)
    for i in range(n):
        for j in range(i + 1, n):
            distance = float(tree_l2_dist(tree_list[i], tree_list[j]))
            pairwise[i, j] = distance
            pairwise[j, i] = distance

    # Clamp so the slice neither reaches the +inf diagonal (f <= 0) nor turns
    # into a negative slice bound (n < f).
    num_neighbours = max(0, min(n - f, n - 1))
    scores = np.sort(pairwise, axis=1)[:, :num_neighbours].sum(axis=1)

    # Stable sort keeps ties in index order, matching jnp.argsort's default.
    selected = np.argsort(scores, kind='stable')[:m]
    multi_krum_list = [(tree_list[int(index)], 1) for index in selected]
    return fedjax.tree_util.tree_mean(multi_krum_list)

def projectedKrum():
    """Placeholder for a projection-based acceleration of Krum.

    Not implemented yet: takes no arguments and returns None.
    """
    return None

def coordinate_median():
    """Placeholder aggregator stub; not implemented, takes no arguments and returns None."""
    return None

## Model

In [None]:
# Dummy one-example batch handed to fedjax as `sample_batch` when the Haiku
# model is created.
_SAMPLE_MNIST_BATCH = dict(
    x=np.zeros((1, 28, 28, 1), dtype=np.float32),
    y=np.zeros(1, dtype=np.float32),
)


def _TRAIN_LOSS(b, p):
    """Per-example (unreduced) cross-entropy between labels ``b['y']`` and predictions ``p``."""
    return fedjax.metrics.unreduced_cross_entropy_loss(b['y'], p)


# Metrics reported by evaluation, keyed by the names read back in the training loop.
_EVAL_METRICS = dict(
    loss=fedjax.metrics.CrossEntropyLoss(),
    accuracy=fedjax.metrics.Accuracy(),
)

class mnist_conv(hk.Module):
    """Small convolutional classifier built from Haiku layers.

    Layer order: 5x5 VALID conv (16 channels), max-pool, ReLU, max-pool,
    5x5 VALID conv (32 channels), ReLU, max-pool, flatten, two 512-unit
    ReLU layers and a final linear layer with ``num_classes`` outputs.
    Every max-pool uses a 2x2 window with stride 2 and VALID padding.
    """

    def __init__(self, num_classes:int, dropout_rate=0.25):
        """Store the number of output classes and the (currently unused) dropout rate."""
        super().__init__()
        self._num_classes = num_classes
        self._rate = dropout_rate

    def __call__(self, x:jnp.ndarray, is_train: bool):
        """Return the unnormalised class scores for input ``x``.

        ``is_train`` is accepted for interface compatibility only: dropout
        is disabled in this version, so it (and ``self._rate``) has no effect.
        """
        def downsample(features):
            # 2x2 max-pool with stride 2 over the two middle axes.
            pool = hk.MaxPool(
                window_shape=(1, 2, 2, 1), strides=(1, 2, 2, 1), padding='VALID')
            return pool(features)

        # First convolutional stage.
        x = hk.Conv2D(output_channels=16, kernel_shape=(5, 5), padding='VALID')(x)
        x = downsample(x)
        x = jax.nn.relu(x)
        # NOTE(review): a second pooling step directly follows the first one --
        # confirm that back-to-back pooling is intentional.
        x = downsample(x)

        # Second convolutional stage.
        x = hk.Conv2D(output_channels=32, kernel_shape=(5, 5), padding='VALID')(x)
        x = jax.nn.relu(x)
        x = downsample(x)

        # Dropout used to be applied here and again before the output layer
        # (hk.dropout with rate self._rate when is_train); it is disabled.
        x = hk.Flatten()(x)
        for _ in range(2):
            x = jax.nn.relu(hk.Linear(512)(x))

        return hk.Linear(self._num_classes)(x)

def create_mnist_cnn(num_classes, dropout_rate=0.25):
    """Wrap ``mnist_conv`` as a fedjax model.

    Args:
        num_classes: Number of output classes of the network.
        dropout_rate: Forwarded to ``mnist_conv``.

    Returns:
        The object produced by ``fedjax.create_model_from_haiku``, using
        ``_SAMPLE_MNIST_BATCH`` as the sample batch, ``_TRAIN_LOSS`` as the
        training loss and ``_EVAL_METRICS`` for evaluation. The forward pass
        receives ``is_train=True`` for training and ``is_train=False`` for
        evaluation.
    """
    def forward_pass(batch, is_train=True):
        network = mnist_conv(num_classes, dropout_rate)
        return network(batch['x'], is_train)

    return fedjax.create_model_from_haiku(
        transformed_forward_pass=hk.transform(forward_pass),
        sample_batch=_SAMPLE_MNIST_BATCH,
        train_loss=_TRAIN_LOSS,
        eval_metrics=_EVAL_METRICS,
        train_kwargs=dict(is_train=True),
        eval_kwargs=dict(is_train=False),
    )

## Implementation of Ditto

In [None]:
def DittoClient(client_model, client_optimizer, lmbd=20):
    """Build the per-client ``init`` / ``step`` / ``final`` functions for Ditto.

    During a round each client trains two models on the same batches: a
    local copy of the global ("server") model with the plain training loss,
    and its personalised model with the training loss plus the proximal
    penalty ``(lmbd / 2) * ||client_params - server_params||^2``.

    Args:
        client_model: Object providing ``apply_for_train(params, batch, rng)``
            and ``train_loss(batch, preds)`` (per-example losses).
        client_optimizer: Object providing ``init(params)`` and
            ``apply(grads, opt_state, params) -> (opt_state, params)``; it is
            used for both the server copy and the personalised model.
        lmbd: Weight of the proximal penalty.

    Returns:
        A named tuple with fields ``init``, ``step`` and ``final``:
          * ``init(server_params, client_input)`` builds the step state.
            ``server_params`` must expose a ``.server_params`` attribute and
            ``client_input`` is a dict with keys ``'key'`` (PRNG key) and
            ``'client_params'``.
          * ``step(client_step_state, batch)`` runs one optimizer step on
            both models and returns the next state.
          * ``final(server_params, client_step_state)`` returns
            ``(delta_params, iter_numb, client_params)`` where
            ``delta_params`` is the round's starting server parameters minus
            the locally trained server copy.
    """
    class ToReturn(NamedTuple):
        init: Callable
        step: Callable
        final: Callable

    def server_loss_fn(params, batch, rng):
        # Plain mean training loss, used for the local copy of the global model.
        preds = client_model.apply_for_train(params, batch, rng)
        return jnp.mean(client_model.train_loss(batch, preds))

    def client_loss_fn(client_params, server_params, batch, rng):
        # Mean training loss plus the Ditto proximal penalty.
        preds = client_model.apply_for_train(client_params, batch, rng)
        example_loss = client_model.train_loss(batch, preds)
        difference = jax.tree_util.tree_map(lambda a, b: a - b,
                                            client_params,
                                            server_params)
        # Sum the squares directly instead of squaring an L2 norm. The norm's
        # square root has an infinite derivative at zero, which makes the
        # gradient NaN whenever the personalised and server parameters
        # coincide (e.g. the first time a client trains from the shared
        # initialisation).
        squared_distance = sum(
            jnp.sum(jnp.square(leaf))
            for leaf in jax.tree_util.tree_leaves(difference))
        return jnp.mean(example_loss) + (lmbd / 2) * squared_distance

    # Both gradients are taken with respect to the first argument only.
    server_grad_fn = jax.jit(jax.grad(server_loss_fn))
    client_grad_fn = jax.jit(jax.grad(client_loss_fn))

    def client_init(server_params, client_input):
        # `server_params` is the shared server object; the actual parameter
        # pytree lives in its `.server_params` attribute.
        global_params = server_params.server_params
        personal_params = client_input['client_params']
        return {
            'server_params': global_params,
            'client_params': personal_params,
            'server_opt_state': client_optimizer.init(global_params),
            'client_opt_state': client_optimizer.init(personal_params),
            'rng': client_input['key'],
            'iter_numb': 0
        }

    def client_step(client_step_state, batch):
        # The same sub-key is used for both forward passes of this step.
        rng, use_rng = jax.random.split(client_step_state['rng'])
        server_grads = server_grad_fn(client_step_state['server_params'], batch, use_rng)
        # The penalty pulls towards the server copy as it stood before this
        # step's update.
        client_grads = client_grad_fn(client_step_state['client_params'],
                                      client_step_state['server_params'],
                                      batch, use_rng)

        server_opt_state, server_params = client_optimizer.apply(
            server_grads,
            client_step_state['server_opt_state'],
            client_step_state['server_params'])
        client_opt_state, client_params = client_optimizer.apply(
            client_grads,
            client_step_state['client_opt_state'],
            client_step_state['client_params'])
        return {
            'server_params': server_params,
            'client_params': client_params,
            'server_opt_state': server_opt_state,
            'client_opt_state': client_opt_state,
            'rng': rng,
            'iter_numb': client_step_state['iter_numb'] + 1
        }

    def client_final(server_params, client_step_state):
        # Update signal for the server: starting global parameters minus the
        # locally trained copy (tree_map replaces the removed tree_multimap).
        delta_params = jax.tree_util.tree_map(lambda a, b: a - b,
                                              server_params.server_params,
                                              client_step_state['server_params'])
        return delta_params, client_step_state['iter_numb'], client_step_state['client_params']

    return ToReturn(client_init, client_step, client_final)


@fedjax.dataclass
class ServerState:
    """State of the Ditto server that is carried from one round to the next."""
    # Parameters of the shared (global) model.
    server_params : fedjax.Params
    # Personalised parameters of every client, keyed by client id.
    client_params: Mapping[Any, fedjax.Params]
    # State of the server optimizer.
    opt_state: fedjax.OptState
    # Ids of the clients whose training data was corrupted.
    adversary_list : list

def Ditto(ForClient, 
           server_optimizer, 
           client_batch_hparams, 
           aggregator=None):
    """Build the Ditto federated algorithm as a ``fedjax.FederatedAlgorithm``.

    Args:
        ForClient: Named tuple with ``init``, ``step`` and ``final`` client
            functions (as returned by ``DittoClient``).
        server_optimizer: fedjax optimizer applied to the aggregated client
            deltas to update the server parameters.
        client_batch_hparams: Batching hyperparameters passed to each
            client's ``shuffle_repeat_batch``.
        aggregator: Optional callable mapping a list of delta pytrees to one
            aggregated pytree. When None, the deltas are averaged weighted
            by each client's number of examples.

    Returns:
        ``fedjax.FederatedAlgorithm(init, apply)`` where
        ``init(params, adversary_list, client_ids)`` creates a ``ServerState``
        and ``apply(server_state, clients, server_rng=None)`` runs one round
        and returns ``(new_server_state, client_diagnostics)``.
    """
    # Built once rather than on every round, so whatever compilation fedjax
    # performs for the client functions can be reused across rounds.
    for_each_client_update = fedjax.for_each_client(ForClient.init,
                                                    ForClient.step,
                                                    ForClient.final)

    def init(params: fedjax.Params, adversary_list, client_ids):
        opt_state = server_optimizer.init(params)
        # Every client starts from its own copy of the initial parameters.
        # NOTE(review): one deep copy per client can use a lot of memory when
        # there are many clients -- confirm whether sharing would be safe.
        client_params = {ids: deepcopy(params) for ids in client_ids}
        return ServerState(params, client_params, opt_state, adversary_list)

    def server_update(server_state, mean_delta_params, client_models):
        # The aggregated delta is treated as a gradient by the server optimizer.
        opt_state, params = server_optimizer.apply(mean_delta_params,
                                                   server_state.opt_state,
                                                   server_state.server_params)
        return ServerState(params, client_models, opt_state, server_state.adversary_list)

    def apply(server_state: ServerState,
        clients: Sequence[Tuple[Any, fedjax.ClientDataset, fedjax.PRNGKey]], 
        server_rng: Optional[Any] =None
        ) -> Tuple[ServerState, Mapping[Any, Any]]:
        # `clients` is traversed twice below; materialise it so a one-shot
        # iterator also works.
        clients = list(clients)

        # Shallow copy so the incoming state's mapping is not mutated in place.
        client_models = dict(server_state.client_params)
        client_diagnostics = {}
        client_delta_params_weights = []

        batched_clients_data = [
            (cid, cds.shuffle_repeat_batch(client_batch_hparams),
             {'key': crng, 'client_params': server_state.client_params[cid]})
            for cid, cds, crng in clients]

        # Number of examples held by each participating client.
        data_length = {cid: len(cds) for cid, cds, _ in clients}

        # NOTE(review): the whole server_state (including every client's
        # parameters) is passed as the shared input of each client call.
        for client_id, (delta_params, iter_numb, new_client_params) in for_each_client_update(
                server_state, batched_clients_data):
            # Weight each delta by that client's example count. The previous
            # weight, len(data_length), was the number of sampled clients and
            # therefore identical for every client.
            client_delta_params_weights.append((delta_params, data_length[client_id]))
            client_models[client_id] = new_client_params
            client_diagnostics[client_id] = {
                'delta_l2_norm': fedjax.tree_util.tree_l2_norm(delta_params),
                'iter_numb': iter_numb}

        # Aggregate the deltas and update the server.
        if aggregator is None:
            mean_delta_params = fedjax.tree_util.tree_mean(client_delta_params_weights)
        else:
            # Custom aggregators receive the deltas only; the weights are dropped.
            mean_delta_params = aggregator(
                [tree_param for tree_param, _ in client_delta_params_weights])
        server_state = server_update(server_state, mean_delta_params, client_models)
        return server_state, client_diagnostics

    return fedjax.FederatedAlgorithm(init, apply)

## Main Function

In [None]:
def main(hyper, wandb_log=False):
    """Train Ditto on federated EMNIST and log metrics every 10 rounds.

    Seeds ``random``, NumPy and JAX from ``hyper.seed``, builds the 62-class
    CNN, loads (optionally corrupted) EMNIST and runs ``hyper.round_num``
    rounds of Ditto. On every tenth round it evaluates the global model on
    all clients' train and test data, and the personalised models of the
    non-adversarial clients sampled in that round.

    Args:
        hyper: Hyperparameter tuple. Reads ``seed``, ``dropout_rate``,
            ``client_lr``, ``server_lr``, ``samples_per_client``,
            ``step_per_client``, ``use_aggregator``, ``clients_per_round``
            and ``round_num``, and is forwarded to ``load_emnist``.
        wandb_log: When truthy, each metrics dictionary is also sent to
            ``wandb.log``.

    Returns:
        None. Metrics are printed (and optionally logged to W&B).
    """
    random.seed(hyper.seed)
    np.random.seed(hyper.seed)
    rng = jax.random.PRNGKey(hyper.seed)

    num_classes = 62
    model = create_mnist_cnn(num_classes=num_classes, dropout_rate=hyper.dropout_rate)
    init_params = model.init(rng)

    train, test, adversary_list = load_emnist(hyper)
    client_optimizer = fedjax.optimizers.sgd(hyper.client_lr)
    server_optimizer = fedjax.optimizers.sgd(hyper.server_lr)
    if hyper.step_per_client is None:
        client_batch_hparams = fedjax.ShuffleRepeatBatchHParams(
            batch_size=hyper.samples_per_client)
    else:
        client_batch_hparams = fedjax.ShuffleRepeatBatchHParams(
            batch_size=hyper.samples_per_client,
            num_steps=hyper.step_per_client)

    ForClient = DittoClient(model, client_optimizer)
    client_ids = list(train.client_ids())
    # Krum replaces the default weighted mean only when explicitly requested.
    aggregator = krum if hyper.use_aggregator else None
    FedAlgorithm = Ditto(ForClient, server_optimizer, client_batch_hparams, aggregator)

    server_state = FedAlgorithm.init(init_params, adversary_list, client_ids)

    train_client_sampler = fedjax.client_samplers.UniformGetClientSampler(
        fd=train, num_clients=hyper.clients_per_round, seed=hyper.seed)

    # Per-client batch views used to evaluate the personalised models.
    batched_train_data = {cid: cds.shuffle_repeat_batch(batch_size=20)
                          for cid, cds in train.get_clients(client_ids)}
    batched_test_data = {cid: cds.shuffle_repeat_batch(batch_size=20)
                         for cid, cds in test.get_clients(client_ids)}

    # Datasets of all clients, used to evaluate the global model.
    train_eval_datasets = [cds for _, cds in train.get_clients(client_ids)]
    test_eval_datasets = [cds for _, cds in test.get_clients(client_ids)]

    def mean_accuracy(values):
        # NaN instead of a ZeroDivisionError when no honest client was sampled.
        return float(sum(values) / len(values)) if values else float('nan')

    # Rounds are numbered 1..round_num inclusive. The previous upper bound
    # stopped one round early and skipped the final evaluation.
    for round_num in tqdm(range(1, hyper.round_num + 1)):
        clients = train_client_sampler.sample()
        server_state, client_diagnostics = FedAlgorithm.apply(server_state, clients)
        if round_num % 10 != 0:
            continue

        # Personalised models are scored on the non-byzantine sampled clients only.
        adversaries = set(server_state.adversary_list)
        honest_ids = [cid for cid, _, _ in clients if cid not in adversaries]

        # NOTE(review): the global model is evaluated on every client,
        # including corrupted ones -- confirm that this is intended.
        train_eval_batches = fedjax.padded_batch_client_datasets(
            train_eval_datasets, batch_size=256)
        test_eval_batches = fedjax.padded_batch_client_datasets(
            test_eval_datasets, batch_size=256)

        train_metrics = fedjax.evaluate_model(model, server_state.server_params,
                                              train_eval_batches)
        test_metrics = fedjax.evaluate_model(model, server_state.server_params,
                                             test_eval_batches)

        train_acc_tally, test_acc_tally = [], []
        for cid in honest_ids:
            personal_params = server_state.client_params[cid]
            client_train_metrics = fedjax.evaluate_model(model, personal_params,
                                                         batched_train_data[cid])
            client_test_metrics = fedjax.evaluate_model(model, personal_params,
                                                        batched_test_data[cid])
            train_acc_tally.append(client_train_metrics['accuracy'])
            test_acc_tally.append(client_test_metrics['accuracy'])

        log_dict = {
            'ditto_train_accuracy': mean_accuracy(train_acc_tally),
            'ditto_test_accuracy': mean_accuracy(test_acc_tally),
            'train_accuracy': float(train_metrics['accuracy']),
            'train_loss': float(train_metrics['loss']),
            'test_accuracy': float(test_metrics['accuracy']),
            'test_loss': float(test_metrics['loss'])
        }
        if wandb_log:
            wandb.log(log_dict)
        print(f'[round {round_num}] metrics={log_dict}')

## Running code:

In [None]:
class hyperparameters(NamedTuple):
    """Configuration of one Ditto training run (all fields have defaults)."""
    # Number of federated rounds.
    round_num: int = 1000
    # Number of clients sampled in each round.
    clients_per_round: int = 10
    # Batch size used for the clients' local training batches.
    samples_per_client: int = 20
    # Learning rate of the client-side SGD optimizer.
    client_lr: float = 0.01
    # Learning rate of the server-side SGD optimizer.
    server_lr: float = 1.0
    # Seed for `random`, NumPy, JAX and the client sampler.
    seed: int = 42
    # Dropout rate forwarded to the model constructor.
    dropout_rate: float = 0.25
    # Whether a fraction of the clients gets randomly relabelled training data.
    adversary: bool = False
    # Fraction of clients that are corrupted when `adversary` is True.
    corrupt_ratio: float = 0.20
    # Name of the robust aggregator.
    # NOTE(review): no visible code reads this field; only `use_aggregator` is checked.
    aggregator: str = 'krum'
    # Whether Krum replaces the default weighted mean on the server.
    use_aggregator: bool = False
    # Number of local steps per client and round; None leaves the batching default.
    # Annotated so it is a real NamedTuple field: without the annotation it was
    # a plain class attribute that could not be set through the constructor
    # and was missing from `_asdict()`.
    step_per_client: Optional[int] = None

In [None]:
# Authenticate with Weights & Biases and open a run whose config records the
# default hyperparameters.
wandb.login()
wandb.init(project='Graduate Project', name='Ditto Run', config=hyperparameters()._asdict())
# Launch training with W&B logging enabled. A second default-constructed
# hyperparameters() instance is passed here, separate from the one used for
# the run config above.
main(hyperparameters(), wandb_log=True)