<a href="https://colab.research.google.com/github/vasusundaraj/Academic_repo/blob/master/Copy_of_FL_working.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install --quiet --upgrade tensorflow_federated
!pip install tensorflow-privacy
!pip install nest_asyncio
import nest_asyncio
nest_asyncio.apply()

%load_ext tensorboard

import sys

if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [None]:
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_privacy
import collections
import attr
import os
import time
import json

from absl import app
from absl import flags
import numpy as np
from scipy import io
from tensorflow_federated.python.tensorflow_libs import tensor_utils
from tensorflow_federated.python.core.api import computations
import matplotlib.pyplot as plt

In [None]:
# Number of clients sampled per federated round; also used as the fixed
# denominator of the DP NormalizedQuery.
number_clients = 3

In [None]:
@attr.s(eq=False, frozen=True)
class ClientOutput(object):
  """Structure for outputs returned from clients during federated optimization.
  Fields:
  -   `weights_delta`: Updates to the model's trainable variables, in the same
      structure as the model's trainable weights.
  -   `weights_delta_weight`: Weight to be used in a weighted mean when
      aggregating `weights_delta`.
  -   `model_output`: A structure matching
      `tff.learning.Model.report_local_outputs`, reflecting the results of
      training on the input dataset.
  -   `response_time`: A per-client timing value. `ClientProjectBoost` fills
      it from a timer around the local update; `ClientExplicitBoosting` always
      sets it to 0. The round function does not use it for aggregation.
  -   `optimizer_output`: Additional metrics or other outputs defined by the
      optimizer.

  The client update classes build instances positionally, so the field order
  below is part of the interface; do not reorder the fields.
  """
  weights_delta = attr.ib()
  weights_delta_weight = attr.ib()
  model_output = attr.ib()
  response_time = attr.ib()
  optimizer_output = attr.ib()


@attr.s(eq=False, frozen=True)
class ServerState(object):
  """Structure for state on the server.
  Fields:
  -   `model`: A dictionary of model's trainable variables.
  -   `optimizer_state`: Variables of optimizer.
  """
  model = attr.ib()
  optimizer_state = attr.ib()
  delta_aggregate_state = attr.ib()


def _create_optimizer_vars(model, optimizer):
  """Forces `optimizer` to create its variables for `model` and returns them.

  An update with all-zero gradients is applied to every trainable weight of
  `model`. This gives the optimizer a place to hold its state, which callers
  later overwrite. The list from `optimizer.variables()` is returned.
  """
  trainable_weights = tf.nest.flatten(_get_weights(model).trainable)
  zero_gradients = [-1.0 * tf.zeros_like(w) for w in trainable_weights]
  optimizer.apply_gradients(
      list(zip(zero_gradients, trainable_weights)), name='server_update')
  return optimizer.variables()


def _get_weights(model):
  """Wraps the variables of `model` in a `tff.learning.framework.ModelWeights`."""
  weights = tff.learning.framework.ModelWeights.from_model(model)
  return weights


def _get_norm(weights):
  """Returns the global L2 norm over every tensor in `weights`.

  Args:
    weights: Any nested structure (e.g. an OrderedDict or list) of tensors.
  Returns:
    A scalar tensor: the L2 norm of all entries taken together, as if every
    tensor were concatenated into a single vector.
  """
  leaves = tf.nest.flatten(weights)
  return tf.linalg.global_norm(leaves)


@tf.function
def server_update(model, server_optimizer, server_optimizer_vars, server_state,
                  weights_delta, new_delta_aggregate_state):
  """Returns `server_state` advanced by one server-optimizer step.

  The freshly built `model` and `server_optimizer_vars` are first overwritten
  with the values held in `server_state`. The optimizer then receives the
  negated `weights_delta` as its gradient. For plain SGD with learning rate
  1, this moves the weights by exactly `weights_delta`.

  Args:
    model: A `tff.learning.Model`, used as scratch variables for the weights.
    server_optimizer: A `tf.keras.optimizers.Optimizer`.
    server_optimizer_vars: A list of previous variables of server_optimizer.
    server_state: A `ServerState`, the state to be updated.
    weights_delta: An update to the trainable variables of the model.
    new_delta_aggregate_state: The aggregator state stored in the result.
  Returns:
    An updated `ServerState`.
  """
  def assign_into(variable, value):
    return variable.assign(value)

  model_weights = _get_weights(model)
  # Load the current server model, then the optimizer state, into variables.
  tf.nest.map_structure(assign_into, model_weights, server_state.model)
  tf.nest.map_structure(assign_into, server_optimizer_vars,
                        server_state.optimizer_state)

  flat_delta = tf.nest.flatten(weights_delta)
  flat_trainable = tf.nest.flatten(model_weights.trainable)
  # map_structure (rather than zip) also checks both lists have equal length.
  negated_grads_and_vars = tf.nest.map_structure(
      lambda delta, var: (-1.0 * delta, var), flat_delta, flat_trainable)
  server_optimizer.apply_gradients(negated_grads_and_vars,
                                   name='server_update')

  return tff.structure.update_struct(
      server_state,
      model=model_weights,
      optimizer_state=server_optimizer_vars,
      delta_aggregate_state=new_delta_aggregate_state)



def build_server_init_fn(model_fn, server_optimizer_fn,
                         aggregation_process_init):
  """Builds a `tff.Computation` that returns initial `ServerState`.
  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    aggregation_process_init: A `tff.Computation` that initializes the
      aggregator state.
  Returns:
    A no-argument `tff.federated_computation` (not a `tff.tf_computation`).
    Its result is a `ServerState` (initial model weights, optimizer
    variables, aggregator state) zipped into one value placed at `tff.SERVER`.
  """

  @tff.tf_computation
  def server_init_tf():
    # Builds a fresh model and optimizer and returns their initial values.
    model = model_fn()
    server_optimizer = server_optimizer_fn()
    # Create optimizer variables so we have a place to assign the optimizer's
    # state.
    server_optimizer_vars = _create_optimizer_vars(model, server_optimizer)
    return _get_weights(model), server_optimizer_vars

  @tff.federated_computation
  def server_init():
    # Evaluate the TF initializer at the server.
    initial_model, server_optimizer_state = tff.federated_eval(
        server_init_tf, tff.SERVER)
    # federated_zip turns a ServerState of server-placed members into a
    # single server-placed ServerState.
    return tff.federated_zip(
        ServerState(
            model=initial_model,
            optimizer_state=server_optimizer_state,
            delta_aggregate_state=aggregation_process_init()))

  return server_init


def build_server_update_fn(model_fn, server_optimizer_fn, server_state_type,
                           model_weights_type):
  """Builds a `tff.tf_computation` that updates `ServerState`.
  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    server_state_type: type_signature of server state.
    model_weights_type: type_signature of model weights.
  Returns:
    A `tff.tf_computation` that updates `ServerState`. It takes
    `(server_state, model_delta, new_delta_aggregate_state)`. Their types are
    `server_state_type`, `model_weights_type.trainable` and
    `server_state_type.delta_aggregate_state` respectively.
  """

  @tff.tf_computation(server_state_type, model_weights_type.trainable,
                      server_state_type.delta_aggregate_state)
  def server_update_tf(server_state, model_delta, new_delta_aggregate_state):
    """Updates the `server_state`.
    Args:
      server_state: The `ServerState`.
      model_delta: The model difference from clients.
      new_delta_aggregate_state: An update to the server state.
    Returns:
      The updated `ServerState`.
    """
    # A fresh model and optimizer are built in this computation's graph;
    # `server_update` overwrites their values from `server_state`.
    model = model_fn()
    server_optimizer = server_optimizer_fn()
    # Create optimizer variables so we have a place to assign the optimizer's
    # state.
    server_optimizer_vars = _create_optimizer_vars(model, server_optimizer)

    return server_update(model, server_optimizer, server_optimizer_vars,
                         server_state, model_delta, new_delta_aggregate_state)

  return server_update_tf


def build_client_update_fn(model_fn, optimizer_fn, client_update_tf,
                           tf_dataset_type, model_weights_type):
  """Builds a `tff.tf_computation` in the presence of malicious clients.
  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`.
    client_update_tf: A 'tf.function' that computes the ClientOutput
      (e.g. an instance of `ClientExplicitBoosting` or `ClientProjectBoost`).
    tf_dataset_type: type_signature of dataset.
    model_weights_type: type_signature of model weights.
  Returns:
    A `tff.tf_computation` for local model optimization with type signature:
    '@tff.tf_computation(tf_dataset_type, tf_dataset_type,
                      tf.bool, model_weights_type)'
    that returns whatever `client_update_tf` returns (a `ClientOutput`).
  """

  @tff.tf_computation(tf_dataset_type, tf_dataset_type, tf.bool,
                      model_weights_type)
  def client_delta_tf(benign_dataset, malicious_dataset, client_type,
                      initial_model_weights):
    """Performs client local model optimization.
    Args:
      benign_dataset: A 'tf.data.Dataset' consisting of benign dataset
      malicious_dataset: A 'tf.data.Dataset' consisting of malicious dataset
      client_type: A 'tf.bool' indicating whether the client is malicious
      initial_model_weights: A `tff.learning.Model.weights` from server.
    Returns:
      A 'ClientOutput`.
    """
    # Create variables here in the graph context, before calling the tf.function
    # below.
    model = model_fn()
    optimizer = optimizer_fn()
    return client_update_tf(model, optimizer, benign_dataset, malicious_dataset,
                            client_type, initial_model_weights)

  return client_delta_tf

class ClientExplicitBoosting:
  """Client tensorflow logic for explicit boosting.

  A benign client returns the weight delta from training on its own data. A
  malicious client returns that benign delta plus `boost_factor` times the
  delta from training on the malicious dataset, starting from the same
  server weights.
  """

  def __init__(self, boost_factor):
    """Specify the boosting parameter.
    Args:
      boost_factor: A 'tf.float32' specifying how malicious update is boosted.
    """
    self.boost_factor = boost_factor

  @tf.function
  def __call__(self, model, optimizer, benign_dataset, malicious_dataset,
               client_type, initial_weights):
    """Updates client model with client potentially being malicious.
    Args:
      model: A `tff.learning.Model`.
      optimizer: A 'tf.keras.optimizers.Optimizer'.
      benign_dataset: A 'tf.data.Dataset' consisting of benign dataset.
      malicious_dataset: A 'tf.data.Dataset' consisting of malicious dataset.
      client_type: A 'tf.bool' indicating whether the client is malicious; iff
        `True` the client will construct its update using `malicious_dataset`,
        otherwise will construct the update using `benign_dataset`.
      initial_weights: A `tff.learning.Model.weights` from server.
    Returns:
      A 'ClientOutput`. Its `response_time` is always 0 for this strategy.
    """
    model_weights = _get_weights(model)

    @tf.function
    def reduce_fn(num_examples_sum, batch):
      """Trains on one batch and adds its example count to the accumulator.

      Runs a forward pass and one optimizer step on `batch`.
      """
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch)
      gradients = tape.gradient(output.loss, model.trainable_variables)
      optimizer.apply_gradients(zip(gradients, model.trainable_variables))
      return num_examples_sum + tf.shape(output.predictions)[0]

    @tf.function
    def compute_benign_update():
      """compute benign update sent back to the server."""
      # Start from the server weights.
      tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                            initial_weights)

      num_examples_sum = benign_dataset.reduce(
          initial_state=tf.constant(0), reduce_func=reduce_fn)

      weights_delta_benign = tf.nest.map_structure(lambda a, b: a - b,
                                                   model_weights.trainable,
                                                   initial_weights.trainable)

      aggregated_outputs = model.report_local_outputs()

      return weights_delta_benign, aggregated_outputs, num_examples_sum

    @tf.function
    def compute_malicious_update():
      """compute malicious update sent back to the server."""
      # The benign pass supplies the benign delta, metrics and example count.
      result = compute_benign_update()
      weights_delta_benign, aggregated_outputs, num_examples_sum = result

      # Reset to the server weights before training on the malicious data.
      tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                            initial_weights)

      malicious_dataset.reduce(
          initial_state=tf.constant(0), reduce_func=reduce_fn)

      weights_delta_malicious = tf.nest.map_structure(lambda a, b: a - b,
                                                      model_weights.trainable,
                                                      initial_weights.trainable)

      # Reported delta = benign delta + boost_factor * malicious delta.
      weights_delta = tf.nest.map_structure(
          tf.add, weights_delta_benign,
          tf.nest.map_structure(lambda delta: delta * self.boost_factor,
                                weights_delta_malicious))

      return weights_delta, aggregated_outputs, num_examples_sum

    # tf.cond selects the branch at run time from the boolean tensor.
    result = tf.cond(
        tf.equal(client_type, True), compute_malicious_update,
        compute_benign_update)
    weights_delta, aggregated_outputs, num_examples_sum = result

    weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
    # This strategy does not measure timing.
    response_time = 0
    weight_norm = _get_norm(weights_delta)

    return ClientOutput(
        weights_delta, weights_delta_weight, aggregated_outputs, response_time,
        collections.OrderedDict({
            'num_examples': num_examples_sum,
            'weight_norm': weight_norm,
        }))


def build_run_one_round_fn_attacked(server_update_fn, client_update_fn,
                                    aggregation_process,
                                    dummy_model_for_metadata,
                                    federated_server_state_type,
                                    federated_dataset_type,
                                    benign_clientoutput_list, counter):
  """Builds a `tff.federated_computation` for a round of training.
  Args:
    server_update_fn: A function for updates in the server.
    client_update_fn: A function for updates in the clients.
    aggregation_process: A 'tff.templates.AggregationProcess' that takes in
      model deltas placed@CLIENTS to an aggregated model delta placed@SERVER.
    dummy_model_for_metadata: A dummy `tff.learning.Model`.
    federated_server_state_type: type_signature of federated server state.
    federated_dataset_type: type_signature of federated dataset.
    benign_clientoutput_list: Unused; kept for call-site compatibility.
    counter: Unused; kept for call-site compatibility.
  Returns:
    A `tff.federated_computation` for a round of training.
  """
  federated_bool_type = tff.type_at_clients(tf.bool)

  @tff.federated_computation(federated_server_state_type,
                             federated_dataset_type, federated_dataset_type,
                             federated_bool_type)
  def run_one_round(server_state, federated_dataset, malicious_dataset,
                    malicious_clients):
    """Orchestration logic for one round of computation.
    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
      malicious_dataset: A federated `tf.Dataset` with placement `tff.CLIENTS`.
        consisting of malicious datasets.
      malicious_clients: A federated `tf.bool` with placement `tff.CLIENTS`.
    Returns:
      A 3-tuple of:
      - the updated `ServerState` (placed at `tff.SERVER`),
      - the result of `tff.learning.Model.federated_output_computation`
        applied to the clients' model outputs,
      - each client's `weights_delta` (placed at `tff.CLIENTS`).
    """
    # Send the current server weights to every client (on the first round
    # these are the initial weights).
    client_model = tff.federated_broadcast(server_state.model)

    client_outputs = tff.federated_map(
        client_update_fn,
        (federated_dataset, malicious_dataset, malicious_clients, client_model))

    weight_denom = client_outputs.weights_delta_weight

    # This Python branch is resolved once, while the computation is traced.
    if aggregation_process.is_weighted:
      aggregate_output = aggregation_process.next(
          server_state.delta_aggregate_state,
          client_outputs.weights_delta,
          weight=weight_denom)
    else:
      aggregate_output = aggregation_process.next(
          server_state.delta_aggregate_state, client_outputs.weights_delta)
    new_delta_aggregate_state = aggregate_output.state
    round_model_delta = aggregate_output.result

    server_state = tff.federated_map(
        server_update_fn,
        (server_state, round_model_delta, new_delta_aggregate_state))

    aggregated_outputs = dummy_model_for_metadata.federated_output_computation(
        client_outputs.model_output)
    if isinstance(aggregated_outputs.type_signature, tff.StructType):
      aggregated_outputs = tff.federated_zip(aggregated_outputs)

    return server_state, aggregated_outputs, client_outputs.weights_delta

  return run_one_round


def build_federated_averaging_process_attacked(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    model_update_aggregation_factory=None,
    client_update_tf=None):
  """Builds the TFF computations for optimization using federated averaging with potentially malicious clients.
  Args:
    model_fn: A no-arg function that returns a `tff.learning.Model`.
    client_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, use during local client training.
    server_optimizer_fn: A no-arg function that returns a
      `tf.keras.optimizers.Optimizer`, use to apply updates to the global model.
    model_update_aggregation_factory: An optional
      `tff.aggregators.AggregationFactory` that constructs
      `tff.templates.AggregationProcess` for aggregating the client model
      updates on the server. If `None`, uses a default constructed
      `tff.aggregators.MeanFactory`, creating a stateless mean aggregation.
    client_update_tf: a 'tf.function' computes the ClientOutput. If `None`, a
      new `ClientExplicitBoosting(boost_factor=1.0)` is created for this
      process. A default-argument instance would be built once at definition
      time and shared by every process built from the default.
  Returns:
    A `tff.templates.IterativeProcess`.
  """
  if client_update_tf is None:
    client_update_tf = ClientExplicitBoosting(boost_factor=1.0)

  # Build one throwaway model in a scratch graph only to read its types
  # (weights type, input spec, output aggregation).
  with tf.Graph().as_default():
    dummy_model_for_metadata = model_fn()
    weights_type = tff.learning.framework.weights_type_from_model(
        dummy_model_for_metadata)

  if model_update_aggregation_factory is None:
    model_update_aggregation_factory = tff.aggregators.MeanFactory()

  if isinstance(model_update_aggregation_factory,
                tff.aggregators.WeightedAggregationFactory):
    aggregation_process = model_update_aggregation_factory.create(
        weights_type.trainable, tff.TensorType(tf.float32))
  else:
    aggregation_process = model_update_aggregation_factory.create(
        weights_type.trainable)

  server_init = build_server_init_fn(model_fn, server_optimizer_fn,
                                     aggregation_process.initialize)
  server_state_type = server_init.type_signature.result.member
  server_update_fn = build_server_update_fn(model_fn, server_optimizer_fn,
                                            server_state_type,
                                            server_state_type.model)
  tf_dataset_type = tff.SequenceType(dummy_model_for_metadata.input_spec)

  client_update_fn = build_client_update_fn(model_fn, client_optimizer_fn,
                                            client_update_tf, tf_dataset_type,
                                            server_state_type.model)

  federated_server_state_type = tff.type_at_server(server_state_type)

  federated_dataset_type = tff.type_at_clients(tf_dataset_type)
  # Placeholders required by the round builder's signature (unused there).
  benign_clientoutput_list = []
  counter = 0

  run_one_round_tff = build_run_one_round_fn_attacked(
      server_update_fn, client_update_fn, aggregation_process,
      dummy_model_for_metadata, federated_server_state_type,
      federated_dataset_type, benign_clientoutput_list, counter)

  return tff.templates.IterativeProcess(
      initialize_fn=server_init, next_fn=run_one_round_tff)


class ClientProjectBoost:
  """Client tensorflow logic for norm bounded attack.

  A benign client returns the weight delta from training on its own data.
  A malicious client trains alternately on its benign and malicious data for
  `round_num` passes. After every pass it projects its weights back into an
  L2 ball of radius `norm_bound` around the server weights. It returns the
  projected delta multiplied by `boost_factor`.
  """

  def __init__(self, boost_factor, norm_bound, round_num):
    """Specify the attacking parameter.
    Args:
      boost_factor: A float specifying how the malicious update is boosted
        after projection.
      norm_bound: A float specifying the L2 norm bound applied before
        boosting. It goes through `tf.cast(..., tf.float32)`, so a bool would
        silently become 0.0 or 1.0; pass a number.
      round_num: A Python `int` specifying the number of malicious local
        passes. It drives a Python `range`, so the loop is unrolled at trace
        time.
    """
    self.boost_factor = boost_factor
    self.norm_bound = norm_bound
    self.round_num = round_num

  @tf.function
  def __call__(self, model, optimizer, benign_dataset, malicious_dataset,
               client_is_malicious, initial_weights):
    """Updates client model with client potentially being malicious.
    Args:
      model: A `tff.learning.Model`.
      optimizer: A 'tf.keras.optimizers.Optimizer'.
      benign_dataset: A 'tf.data.Dataset' consisting of benign dataset.
      malicious_dataset: A 'tf.data.Dataset' consisting of malicious dataset.
      client_is_malicious: A 'tf.bool' showing whether the client is malicious.
      initial_weights: A `tff.learning.Model.weights` from server.
    Returns:
      A 'ClientOutput`. Its `response_time` is the wall-clock seconds (as a
      float32 tensor) spent computing the local update.
    """
    model_weights = _get_weights(model)

    @tf.function
    def clip_by_norm(gradient, norm):
      """Scales `gradient` down so its global L2 norm is at most `norm`."""
      norm = tf.cast(norm, tf.float32)
      delta_norm = _get_norm(gradient)

      if delta_norm < norm:
        return gradient
      else:
        # Factor <= 1 that shrinks the update onto the norm ball.
        delta_mul_factor = tf.math.divide_no_nan(norm, delta_norm)
        return tf.nest.map_structure(lambda g: g * delta_mul_factor, gradient)

    @tf.function
    def project_weights(weights, initial_weights, norm):
      """Project the weight onto l2 ball around initial_weights with radius norm."""
      weights_delta = tf.nest.map_structure(lambda a, b: a - b, weights,
                                            initial_weights)

      return tf.nest.map_structure(tf.add, clip_by_norm(weights_delta, norm),
                                   initial_weights)

    @tf.function
    def reduce_fn(num_examples_sum, batch):
      """Trains on one batch and adds its example count to the accumulator."""
      with tf.GradientTape() as tape:
        output = model.forward_pass(batch)
      gradients = tape.gradient(output.loss, model.trainable_variables)
      optimizer.apply_gradients(zip(gradients, model.trainable_variables))
      return num_examples_sum + tf.shape(output.predictions)[0]

    @tf.function
    def compute_benign_update():
      """compute benign update sent back to the server."""
      # Start from the server weights.
      tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                            initial_weights)

      num_examples_sum = benign_dataset.reduce(
          initial_state=tf.constant(0), reduce_func=reduce_fn)

      # Delta = locally trained weights - server weights.
      weights_delta_benign = tf.nest.map_structure(lambda a, b: a - b,
                                                   model_weights.trainable,
                                                   initial_weights.trainable)
      aggregated_outputs = model.report_local_outputs()

      return weights_delta_benign, aggregated_outputs, num_examples_sum

    @tf.function
    def compute_malicious_update():
      """compute malicious update sent back to the server."""
      # The benign pass is run only for the reported metrics and example
      # count; its delta is discarded.
      _, aggregated_outputs, num_examples_sum = compute_benign_update()

      tf.nest.map_structure(lambda a, b: a.assign(b), model_weights,
                            initial_weights)

      for _ in range(self.round_num):
        # One pass over each dataset, then project back onto the norm ball.
        benign_dataset.reduce(
            initial_state=tf.constant(0), reduce_func=reduce_fn)
        malicious_dataset.reduce(
            initial_state=tf.constant(0), reduce_func=reduce_fn)

        tf.nest.map_structure(
            lambda a, b: a.assign(b), model_weights.trainable,
            project_weights(model_weights.trainable, initial_weights.trainable,
                            tf.cast(self.norm_bound, tf.float32)))

      weights_delta_malicious = tf.nest.map_structure(lambda a, b: a - b,
                                                      model_weights.trainable,
                                                      initial_weights.trainable)

      weights_delta = tf.nest.map_structure(
          lambda update: self.boost_factor * update, weights_delta_malicious)

      return weights_delta, aggregated_outputs, num_examples_sum

    # Time the update with runtime timestamp ops. `time.perf_counter()` inside
    # a tf.function runs only while tracing, so it measured trace time rather
    # than each client's training time. The Timestamp ops are stateful, so
    # tf.function's automatic control dependencies keep them ordered around
    # the update.
    start_time = tf.timestamp()
    if client_is_malicious:
      result = compute_malicious_update()
    else:
      result = compute_benign_update()
    response_time = tf.cast(tf.timestamp() - start_time, tf.float32)

    weights_delta, aggregated_outputs, num_examples_sum = result
    weights_delta_weight = tf.cast(num_examples_sum, tf.float32)
    weight_norm = _get_norm(weights_delta)
    return ClientOutput(
        weights_delta, weights_delta_weight, aggregated_outputs, response_time,
        collections.OrderedDict({
            'num_examples': num_examples_sum,
            'weight_norm': weight_norm,
        }))

In [None]:
use_nchw_format = False
# Choose NCHW ("channels_first") or NHWC ("channels_last") layout for the
# 28x28 single-channel images.
if use_nchw_format:
  data_format, data_shape = 'channels_first', [1, 28, 28]
else:
  data_format, data_shape = 'channels_last', [28, 28, 1]
def preprocess(dataset, num_epochs=5, batch_size=20):
  """Preprocess a raw EMNIST client dataset for training.

  Each element's 'pixels' is reshaped to `data_shape` and stored under 'x'.
  Its 'label' is reshaped to a rank-1 tensor and stored under 'y'. The
  dataset is then repeated `num_epochs` times and batched.

  Args:
    dataset: A `tf.data.Dataset` whose elements have 'pixels' and 'label'.
    num_epochs: Number of passes over the client data (default 5).
    batch_size: Number of examples per batch (default 20).
  Returns:
    The preprocessed, batched `tf.data.Dataset`.
  """

  def element_fn(element):
    return collections.OrderedDict([
        ('x', tf.reshape(element['pixels'], data_shape)),
        ('y', tf.reshape(element['label'], [-1])),
    ])

  return dataset.repeat(num_epochs).map(element_fn).batch(batch_size)


def load_malicious_dataset(num_tasks):
  """Load malicious dataset consisting of malicious target samples.

  Downloads the EMNIST target file (cached by `tf.keras.utils.get_file`).
  It then concatenates the samples of the last `num_tasks` tasks.

  Args:
    num_tasks: Positive int; number of target tasks, taken from the end.
  Returns:
    `(dataset_malicious, target_x, target_y)`. `dataset_malicious` is a
    single-element `tf.data.Dataset` (via `from_tensors`) holding all target
    samples together.
  Raises:
    ValueError: If `num_tasks` is not positive.
  """
  if num_tasks <= 0:
    # `arr[-0:]` is the whole array, so 0 would silently load every task.
    raise ValueError(f'num_tasks must be positive, got {num_tasks}')
  url_malicious_dataset = 'https://storage.googleapis.com/tff-experiments-public/targeted_attack/emnist_malicious/emnist_target.mat'
  filename = 'emnist_target.mat'
  path = tf.keras.utils.get_file(filename, url_malicious_dataset)
  emnist_target_data = io.loadmat(path)
  emnist_target_x = emnist_target_data['target_train_x'][0]
  emnist_target_y = emnist_target_data['target_train_y'][0]
  target_x = np.concatenate(emnist_target_x[-num_tasks:], axis=0)
  target_y = np.concatenate(emnist_target_y[-num_tasks:], axis=0)
  dict_malicious = collections.OrderedDict([('x', target_x), ('y', target_y)])
  dataset_malicious = tf.data.Dataset.from_tensors(dict_malicious)
  return dataset_malicious, target_x, target_y


def load_test_data():
  """Downloads (with caching) the EMNIST test arrays used for fast evaluation.

  Returns:
    `(test_image, test_label)`, read from the 'test_x' and 'test_y' entries
    of the downloaded .mat file.
  """
  local_path = tf.keras.utils.get_file(
      'emnist_test_data.mat',
      'https://storage.googleapis.com/tff-experiments-public/targeted_attack/emnist_test_data/emnist_test_data.mat')
  mat_contents = io.loadmat(local_path)
  return mat_contents['test_x'], mat_contents['test_y']


def make_federated_data_with_malicious(client_data,
                                       dataset_malicious,
                                       client_ids,
                                       with_attack=1, attack_no=2):
  """Make federated dataset with potential attackers.

  Args:
    client_data: A ClientData providing `create_tf_dataset_for_client`.
    dataset_malicious: The malicious dataset shared by every client.
    client_ids: The client ids participating in this round.
    with_attack: If truthy, the last `attack_no` clients are malicious.
    attack_no: Number of malicious clients when `with_attack` is truthy.
  Returns:
    A tuple with one entry per client id in each list:
    - `benign_dataset`: the preprocessed dataset of each client,
    - `malicious_dataset`: `dataset_malicious`, repeated per client,
    - `client_type_list`: a `tf.bool` flag per client (True = malicious).
  Raises:
    ValueError: If `with_attack` is truthy and `attack_no` is not in
      [0, len(client_ids)]. Otherwise the flag list would not line up with
      the clients.
  """
  num_clients = len(client_ids)
  if with_attack:
    if not 0 <= attack_no <= num_clients:
      raise ValueError(
          f'attack_no must be in [0, {num_clients}], got {attack_no}')
    num_malicious = attack_no
  else:
    num_malicious = 0

  benign_dataset = [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]
  malicious_dataset = [dataset_malicious] * num_clients
  # Benign clients first, malicious clients last.
  client_type_list = ([tf.cast(0, tf.bool)] * (num_clients - num_malicious) +
                      [tf.cast(1, tf.bool)] * num_malicious)

  return benign_dataset, malicious_dataset, client_type_list


def sample_clients_with_malicious(client_data,
                                  client_ids,
                                  dataset_malicious,
                                  num_clients=3,
                                  with_attack=1, attack_no=2,
                                  replace=False):
  """Sample clients and make the federated dataset for one round.

  Args:
    client_data: A ClientData providing `create_tf_dataset_for_client`.
    client_ids: Population of client ids to sample from.
    dataset_malicious: The malicious dataset shared by every client.
    num_clients: Number of clients sampled for this round.
    with_attack: Forwarded to `make_federated_data_with_malicious`.
    attack_no: Forwarded to `make_federated_data_with_malicious`.
    replace: Whether to sample with replacement. The default is False, so a
      round never contains the same client twice. `np.random.choice` samples
      with replacement unless told otherwise.
  Returns:
    `(federated_train_data, federated_malicious_data, client_type_list)`.
  Raises:
    ValueError: From `np.random.choice`, when `replace` is False and
      `num_clients` exceeds the number of client ids.
  """
  sampled_clients = np.random.choice(client_ids, num_clients, replace=replace)
  return make_federated_data_with_malicious(
      client_data, dataset_malicious, sampled_clients, with_attack, attack_no)


def create_keras_model(only_digits=True):
  """Build the (uncompiled) Keras CNN shared by clients, server and evaluation.

  Args:
    only_digits: If True (default), the output layer has 10 classes (EMNIST
      digits). Otherwise it has 62 classes (EMNIST digits and letters).
  Returns:
    An uncompiled `tf.keras.Sequential` model with input shape `data_shape`
    and layer layout `data_format`.
  """
  num_classes = 10 if only_digits else 62
  model = tf.keras.models.Sequential([
      tf.keras.layers.Conv2D(
          32,
          kernel_size=(3, 3),
          activation='relu',
          input_shape=data_shape,
          data_format=data_format),
      tf.keras.layers.Conv2D(
          64, kernel_size=(3, 3), activation='relu', data_format=data_format),
      tf.keras.layers.MaxPool2D(pool_size=(2, 2), data_format=data_format),
      tf.keras.layers.Dropout(0.25),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dropout(0.5),
      tf.keras.layers.Dense(num_classes, activation='softmax')
  ])
  return model

# Federated EMNIST (digits only). load_data returns a (train, test) pair;
# only the training ClientData is kept.
emnist_train = tff.simulation.datasets.emnist.load_data(only_digits=True)[0]

In [None]:
# Download (cached by Keras) the test images and labels used by evaluate().
test_image, test_label = load_test_data()

In [None]:
# Number of EMNIST target tasks the attacker tries to inject.
num_malicious_tasks = 30
dataset_malicious, target_x, target_y = load_malicious_dataset(
    num_malicious_tasks)

# Prepare model_fn: derive the model input spec from one preprocessed client.
example_dataset = preprocess(
    emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[0]))
input_spec = example_dataset.element_spec

def model_fn():
  """Wraps a freshly built Keras CNN as a `tff.learning` model.

  It is invoked for server init/update, client update and type metadata.
  Every call therefore builds new layers, loss and metric objects.
  """
  return tff.learning.from_keras_model(
      create_keras_model(),
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])


def server_optimizer_fn():
  """Returns the server optimizer: SGD with learning rate 1 and no momentum.

  With learning rate 1, applying the negated aggregated delta as a gradient
  adds that delta to the server weights unchanged.
  NOTE(review): with momentum=0. the `nesterov=True` flag presumably has no
  effect; confirm before relying on it.
  """
  sgd_kwargs = dict(learning_rate=1., momentum=0., nesterov=True)
  return tf.keras.optimizers.SGD(**sgd_kwargs)


def evaluate(state, x, y, target_x, target_y, batch_size=100):
  """Evaluates the server model on both the main task and the target task.

  Args:
    state: A `ServerState`. Its `model` weights are copied into a freshly
      built and compiled Keras model.
    x, y: Main-task test images and labels.
    target_x, target_y: The attacker's target samples and labels.
    batch_size: Batch size passed to Keras `evaluate`.
  Returns:
    `(test_metrics, test_metrics_target)`: the Keras `evaluate` results,
    main task first.
  """
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  state.model.assign_weights_to(keras_model)

  results = []
  for inputs, labels in ((x, y), (target_x, target_y)):
    results.append(keras_model.evaluate(inputs, labels, batch_size=batch_size))
  test_metrics, test_metrics_target = results
  return test_metrics, test_metrics_target



# Malicious clients use a norm-bounded boosting attack. The malicious update
# is projected onto an L2 ball of radius `norm_bound` and then multiplied by
# `boost_factor`. `round_num` is the number of local passes used to compute
# the malicious update. (The number of attack tasks inserted is the argument
# of load_malicious_dataset.)
client_update_function = ClientProjectBoost(
    boost_factor=float(5),
    # Maximum L2 norm of the malicious update before boosting. This must be a
    # number: `3. is 3` evaluates to False, which casts to a 0.0 bound and
    # zeroes every malicious update.
    norm_bound=3.,
    round_num=5)
# Alternative attack without norm projection:
#client_update_function = ClientExplicitBoosting(boost_factor=float(5))

# DP aggregation of client deltas. Each client update is clipped to L2 norm
# 0.7 and Gaussian noise with stddev 0.0 (i.e. no noise) is added to the sum.
# The sum is then divided by the fixed `number_clients`.
sum_query = tensorflow_privacy.GaussianSumQuery(l2_norm_clip=0.7, stddev=0.0)
query = tensorflow_privacy.NormalizedQuery(sum_query,
                                           denominator=number_clients)
dp_agg_factory = tff.aggregators.DifferentiallyPrivateFactory(query)


# Build the attacked FedAvg process. It uses DP aggregation of client deltas
# and the norm-bounded boosting attack for malicious clients; the client
# optimizer keeps its default (SGD, learning rate 0.1). Then initialize the
# server state.
iterative_process = build_federated_averaging_process_attacked(
    model_fn=model_fn,
    model_update_aggregation_factory=dp_agg_factory,
    client_update_tf=client_update_function,
    server_optimizer_fn=server_optimizer_fn)
state = iterative_process.initialize()

In [None]:
# The attacker appears in rounds where cur_round % attack_frequency equals
# attack_frequency // 2; with a frequency of 1 that is every round.
attack_frequency = 1
for cur_round in range(1):
    with_attack = int(cur_round % attack_frequency == attack_frequency // 2)

    # Sample clients and build the federated datasets. With attack_no=0 no
    # client is flagged malicious, so `with_attack` currently has no effect.
    federated_train_data, federated_malicious_data, client_type_list = sample_clients_with_malicious(
        emnist_train,
        client_ids=emnist_train.client_ids,
        dataset_malicious=dataset_malicious,
        num_clients=number_clients,
        with_attack=with_attack, attack_no=0)

    # One round of attacked federated averaging. The third output holds each
    # sampled client's weight delta.
    state, train_metrics, client_weight_deltas = iterative_process.next(
        state, federated_train_data, federated_malicious_data, client_type_list)
    print(len(client_weight_deltas))

    for client_idx, client_delta in enumerate(client_weight_deltas):
        # A delta holds one array per layer, each with a different shape, so
        # flatten them into a single 1-D vector before histogramming.
        flat_delta = np.concatenate(
            [np.ravel(layer) for layer in tf.nest.flatten(client_delta)])
        fig, ax = plt.subplots(figsize=(12, 7))
        ax.hist(flat_delta, density=True)
        ax.set(title=f'Round {cur_round}: client {client_idx} weight delta',
               xlabel='weight delta', ylabel='density')
        plt.show()
        plt.close(fig)

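# Evaluate the global model

Evaluate the trained global model on the EMNIST test set (main task) and on the attacker's target samples (targeted task).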

In [None]:
# Evaluate the current global model on the main test set and on the
# attacker's target samples. Each result from Keras `evaluate` is the loss
# followed by the compiled sparse categorical accuracy.
test_metrics, test_metrics_target = evaluate(state, test_image, test_label,
                                             target_x, target_y)
print(test_metrics)
print(test_metrics_target)

In [None]:
# Show the first four attacker target images, titled with their labels.
fig, axes = plt.subplots(2, 2)
for i, ax in enumerate(axes.flat):
  ax.imshow(target_x[i].reshape((28, 28)))
  ax.set_title(f'label: {target_y[i]}')
  ax.axis('off')
  print(target_y[i])
plt.show()


In [None]:
# Reset the TensorBoard log directory, create a summary writer, and
# re-initialize the federated process for the logging loop.
import os
import shutil
# NOTE(review): this removes ./logs/, but the writer below and %tensorboard
# use /tmp/logs/...; this line looks like a leftover and does not clear logdir.
!rm -rf ./logs/ 
logdir = "/tmp/logs/scalars/training/"
if os.path.exists(logdir):
  shutil.rmtree(logdir)

# Your code to create a summary writer:
summary_writer = tf.summary.create_file_writer(logdir)

# Discards the `state` trained in earlier cells and starts from scratch.
state = iterative_process.initialize()

In [None]:
# Optionally run extra rounds and log their training metrics to TensorBoard.
# With the default of 0 rounds the loop body does not run.
# NOTE(review): every round reuses the clients sampled by the training cell.
num_tensorboard_rounds = 0
with summary_writer.as_default():
  for round_num in range(1, num_tensorboard_rounds + 1):
    # next() takes the malicious datasets and client flags and returns
    # (state, metrics, per-client deltas).
    state, metrics, _ = iterative_process.next(
        state, federated_train_data, federated_malicious_data,
        client_type_list)
    # Presumably a flat mapping of metric name -> value produced by
    # federated_output_computation; confirm.
    for name, value in metrics.items():
      tf.summary.scalar(name, value, step=round_num)

In [None]:
# Launch TensorBoard on the parent of the summary writer's logdir
# (/tmp/logs/scalars/training/).
%tensorboard --logdir /tmp/logs/scalars/