In [1]:
# import ipywidgets as widgets
# from IPython.display import display

# # Create text input and button
# text_input = widgets.Text(placeholder="Enter some text", description="Input:")
# button = widgets.Button(description="Submit")
# output = widgets.Output()

# # Define what happens when the button is clicked
# def on_button_click(b):
#     with output:
#         output.clear_output()  # Clear previous output
#         print(f"You entered: {text_input.value}")

# # Link the button click event to the function
# button.on_click(on_button_click)

# # Display the widgets
# display(text_input, button, output)

## Federated Learning-based threat detection: DEMO 

In [4]:
import ipywidgets as widgets
from IPython.display import display, HTML
import networkx as nx
from pyvis.network import Network
from ipywidgets import Layout
import copy

# Initialize graph and data structures
# These module-level objects are the shared state that the widget callbacks
# defined later in this cell read and mutate.
G = nx.Graph()
connections = []  # edge list kept next to G as (node_name, node_name) tuples
history = []  # To store graph states for undo functionality
malicious_nodes = set()  # To keep track of malicious nodes

# Define widgets
# Node count used when a topology is (re)generated; the widget bounds it to 2..100.
node_count_input = widgets.BoundedIntText(value=5, min=2, max=100, step=1, description="Node Count:", layout=Layout(width='50%'))
topology_select = widgets.Dropdown(
    # Each option has a matching branch in initialize_graph.
    options=['Star', 'Ring', 'Mesh', 'Tree'],
    value='Star',
    description='Topology:',
    layout=Layout(width='50%')
)
apply_topology_button = widgets.Button(description="Apply Topology")

# For adding nodes
# How many nodes a single click on "Add Nodes" appends (1..100).
add_node_count_input = widgets.BoundedIntText(value=1, min=1, max=100, step=1, description="Add Nodes:", layout=Layout(width='50%'))
add_node_button = widgets.Button(description="Add Nodes")

# For removing nodes
# Inclusive ID range to delete; both inputs are capped at 100 by the widget.
delete_node_start_input = widgets.BoundedIntText(value=0, min=0, max=100, step=1, description="Start Node ID:", layout=Layout(width='50%'))
delete_node_end_input = widgets.BoundedIntText(value=1, min=0, max=100, step=1, description="End Node ID:", layout=Layout(width='50%'))
delete_node_button = widgets.Button(description="Delete Nodes")

# For adding connections
# G is still empty here, so the option list starts empty; the callbacks
# reassign `.options` whenever the node set changes.
node_select = widgets.SelectMultiple(options=list(G.nodes), description="Select Nodes", layout=Layout(width='50%'))
add_connection_button = widgets.Button(description="Add Connection")

# For marking nodes as malicious
# Starts empty for the same reason as node_select above.
malicious_node_select = widgets.SelectMultiple(options=list(G.nodes), description="Malicious Nodes", layout=Layout(width='50%'))
set_malicious_button = widgets.Button(description="Set Malicious Nodes")

# Undo button
undo_button = widgets.Button(description="Undo Last Action", layout=Layout(width='30%'))

# Output area that draw_graph renders into (always used via `with graph_output:`).
graph_output = widgets.Output()

# Save current graph state in history
def save_graph_state():
    """Push a snapshot of the current graph state onto ``history``.

    The snapshot is a ``(G, connections, malicious_nodes)`` tuple in which
    every element is deep-copied, so later mutations of the live objects do
    not alter what undo will restore.
    """
    snapshot = tuple(
        copy.deepcopy(part) for part in (G, connections, malicious_nodes)
    )
    history.append(snapshot)

# Function to initialize the graph based on the selected topology
def _topology_edges(topology, node_count):
    """Return the ordered ``(i, j)`` node-index pairs that make up ``topology``.

    Unknown topology names yield an empty list.
    """
    if topology == 'Star':
        # Every node hangs off the central node 0.
        return [(0, i) for i in range(1, node_count)]
    if topology == 'Ring':
        if node_count < 2:
            return []  # a single node would only get a self-loop
        if node_count == 2:
            # The wrap-around edge would repeat Node0-Node1; record it once.
            return [(0, 1)]
        return [(i, (i + 1) % node_count) for i in range(node_count)]
    if topology == 'Mesh':
        return [(i, j) for i in range(node_count) for j in range(i + 1, node_count)]
    if topology == 'Tree':
        # Binary-heap layout: the parent of node i is (i - 1) // 2.
        return [((i - 1) // 2, i) for i in range(1, node_count)]
    return []


def initialize_graph(topology, node_count):
    """Rebuild the shared graph state for the requested topology.

    Clears ``G``, ``connections`` and ``malicious_nodes`` in place, creates
    nodes ``Node0 .. Node{node_count - 1}`` with the edges of ``topology``
    ('Star', 'Ring', 'Mesh' or 'Tree'), refreshes both node selectors and
    saves a history snapshot. ``history`` itself is not cleared, so earlier
    snapshots remain available to undo.

    Args:
        topology: Name of the topology to build; other names leave the graph empty.
        node_count: Number of nodes to create.
    """
    G.clear()  # Clear the graph for new topology
    connections.clear()  # Clear previous connections
    malicious_nodes.clear()  # Clear malicious nodes

    if topology in ('Star', 'Tree'):
        G.add_node("Node0")  # hub/root always exists for these topologies
    if topology in ('Star', 'Ring', 'Mesh', 'Tree'):
        for i in range(node_count):
            G.add_node(f"Node{i}")

    for i, j in _topology_edges(topology, node_count):
        edge = (f"Node{i}", f"Node{j}")
        G.add_edge(*edge)
        connections.append(edge)

    node_select.options = list(G.nodes)  # Update node selection options
    malicious_node_select.options = list(G.nodes)  # Update malicious node selection options
    save_graph_state()  # Save the initial state

# Callback to apply selected topology
def apply_topology_callback(b):
    """Rebuild the graph from the topology/node-count widgets and redraw it.

    Args:
        b: The clicked button (unused; ``None`` is passed for the initial draw).
    """
    initialize_graph(topology_select.value, node_count_input.value)
    with graph_output:
        # wait=True keeps the old drawing until the new one is ready.
        graph_output.clear_output(wait=True)
        draw_graph()

# Callback to add nodes to the graph
def add_node_callback(b):
    """Append new nodes to the graph and attach each one to the hub node.

    New IDs continue after the highest existing ID, so they never collide
    with surviving nodes once some nodes have been deleted (numbering from
    ``len(G.nodes)`` would reuse existing names in that case). The hub is the
    lowest-numbered existing node, which is ``Node0`` whenever it exists. On
    an empty graph the first new node becomes the hub.

    Args:
        b: The clicked button (unused).
    """
    node_count = add_node_count_input.value
    existing_ids = sorted(int(name.replace('Node', '')) for name in G.nodes)
    hub = f"Node{existing_ids[0]}" if existing_ids else None
    next_id = existing_ids[-1] + 1 if existing_ids else 0

    for node_id in range(next_id, next_id + node_count):
        new_node_name = f"Node{node_id}"
        G.add_node(new_node_name)
        if hub is None:
            # Nothing to connect to yet: this node becomes the hub.
            hub = new_node_name
            continue
        G.add_edge(hub, new_node_name)
        connections.append((hub, new_node_name))

    node_select.options = list(G.nodes)  # Update node selection options
    malicious_node_select.options = list(G.nodes)  # Update malicious node selection options
    save_graph_state()  # Save the current state
    with graph_output:
        graph_output.clear_output(wait=True)
        draw_graph()

# Callback to delete nodes within a range of IDs
def delete_node_callback(b):
    """Remove every existing node whose ID lies in the inclusive widget range.

    Besides dropping the nodes from ``G`` and from ``malicious_nodes``, the
    ``connections`` list is pruned in place so it no longer mentions removed
    nodes. Code that derives the topology from ``connections`` would
    otherwise keep seeing the deleted nodes.

    Args:
        b: The clicked button (unused).
    """
    start_node = delete_node_start_input.value
    end_node = delete_node_end_input.value

    # Check and remove nodes within the specified range if they exist
    for node_id in range(start_node, end_node + 1):
        node_name = f"Node{node_id}"
        if node_name in G.nodes:
            G.remove_node(node_name)
            malicious_nodes.discard(node_name)

    # Slice assignment keeps the same list object that other code refers to.
    connections[:] = [
        (first, second)
        for first, second in connections
        if first in G.nodes and second in G.nodes
    ]

    node_select.options = list(G.nodes)  # Update node selection options
    malicious_node_select.options = list(G.nodes)  # Update malicious node selection options
    save_graph_state()  # Save the current state
    with graph_output:
        graph_output.clear_output(wait=True)
        draw_graph()

# Callback to add connections between selected nodes
def add_connection_callback(b):
    """Connect the two nodes currently selected in ``node_select``.

    Nothing happens unless exactly two nodes are selected. A pair that is
    already connected is ignored: the graph would deduplicate the edge, but
    ``connections`` is a plain list and would otherwise record it twice.

    Args:
        b: The clicked button (unused).
    """
    selected = list(node_select.value)
    if len(selected) != 2:  # Only add connection if two nodes are selected
        return
    first, second = selected
    if G.has_edge(first, second):
        return

    G.add_edge(first, second)
    connections.append((first, second))
    save_graph_state()  # Save the current state
    with graph_output:
        graph_output.clear_output(wait=True)
        draw_graph()

# Callback to set malicious nodes
def set_malicious_callback(b):
    """Replace the malicious-node set with the current widget selection.

    The module-level ``malicious_nodes`` name is rebound to a new set, a
    history snapshot is saved and the graph is redrawn.

    Args:
        b: The clicked button (unused).
    """
    global malicious_nodes
    malicious_nodes = {*malicious_node_select.value}
    save_graph_state()  # snapshot includes the new selection
    with graph_output:
        graph_output.clear_output(wait=True)
        draw_graph()

# Undo callback to rollback to previous state
def undo_callback(b):
    """Discard the newest history snapshot and restore the one before it.

    Does nothing when ``history`` holds at most one snapshot. The restored
    state is a deep copy, so the snapshot kept in ``history`` stays untouched
    by later edits.

    Args:
        b: The clicked button (unused).
    """
    global G, connections, history, malicious_nodes
    if len(history) <= 1:
        return

    history.pop()  # drop the state being undone
    previous_state = copy.deepcopy(history[-1])
    G, connections, malicious_nodes = previous_state

    # Both selectors must list the restored node set.
    node_select.options = list(G.nodes)
    malicious_node_select.options = list(G.nodes)
    with graph_output:
        graph_output.clear_output(wait=True)
        draw_graph()

# Function to get node mapping
def get_node_mapping():
    """Map every node name in ``G`` to its integer ID.

    The ID is what remains of the name after removing the text ``'Node'``,
    converted with ``int``.

    Returns:
        dict: ``{node_name: node_id}`` in graph iteration order.
    """
    ids_by_name = {}
    for name in G.nodes():
        ids_by_name[name] = int(name.replace('Node', ''))
    return ids_by_name

# Function to get malicious node IDs
def get_malicious_node_ids():
    """Return the integer IDs of the nodes currently marked as malicious.

    Returns:
        list: One ID per entry of ``malicious_nodes``, in that collection's
        iteration order.

    Raises:
        KeyError: If a malicious node name is missing from the node mapping.
    """
    ids_by_name = get_node_mapping()
    malicious_ids = []
    for name in malicious_nodes:
        malicious_ids.append(ids_by_name[name])
    return malicious_ids

# Function to draw the graph using pyvis with CDN resources
def draw_graph():
    """Render ``G`` with pyvis and display the resulting HTML file.

    Nodes are added under their integer IDs, labelled ``"<name> (ID: <id>)"``
    and coloured red when listed in ``malicious_nodes``, green otherwise. The
    network is written to ``graph.html`` and then displayed.
    """
    net = Network(notebook=True, height='600px', width='100%', bgcolor='#ffffff',
                  font_color='black', cdn_resources='in_line')

    ids_by_name = get_node_mapping()

    # Uniform node size; lower it to draw smaller nodes.
    node_size = 15

    for name in G.nodes():
        numeric_id = ids_by_name[name]
        if name in malicious_nodes:
            fill = 'red'
        else:
            fill = 'green'
        net.add_node(numeric_id, label=f"{name} (ID: {numeric_id})", color=fill, size=node_size)

    # Edges are expressed with the same integer IDs as the nodes.
    for source_name, target_name in G.edges():
        net.add_edge(ids_by_name[source_name], ids_by_name[target_name])

    # Write the network to disk, then show the saved file in the notebook.
    net.show('graph.html')
    display(HTML('graph.html'))


# Display the UI
# Widgets are rendered in the order listed: topology controls, add/delete
# controls, connection and malicious-node selectors, undo, then the graph area.
display(
    node_count_input,
    topology_select,
    apply_topology_button,
    add_node_count_input,
    add_node_button,
    delete_node_start_input,
    delete_node_end_input,
    delete_node_button,
    node_select,
    add_connection_button,
    malicious_node_select,
    set_malicious_button,
    undo_button,
    graph_output
)

# Button event handlers
# Each button is bound to the callback of the same name defined above.
apply_topology_button.on_click(apply_topology_callback)
add_node_button.on_click(add_node_callback)
delete_node_button.on_click(delete_node_callback)
add_connection_button.on_click(add_connection_callback)
set_malicious_button.on_click(set_malicious_callback)
undo_button.on_click(undo_callback)

# Initialize graph plot
# Builds the default topology from the widgets' initial values. The callback
# enters `graph_output` itself, so this outer `with` nests the same widget.
with graph_output:
    apply_topology_callback(None)

# Access selected nodes and connections
def get_selected_nodes():
    """Return the node names currently selected in ``node_select`` as a new list."""
    return list(node_select.value)

def get_connections():
    """Return the module-level ``connections`` list itself (not a copy)."""
    return connections

def get_total_number_of_nodes():
    """Return the number of nodes currently in the graph ``G``."""
    return len(G.nodes)


BoundedIntText(value=5, description='Node Count:', layout=Layout(width='50%'), min=2)

Dropdown(description='Topology:', layout=Layout(width='50%'), options=('Star', 'Ring', 'Mesh', 'Tree'), value=…

Button(description='Apply Topology', style=ButtonStyle())

BoundedIntText(value=1, description='Add Nodes:', layout=Layout(width='50%'), min=1)

Button(description='Add Nodes', style=ButtonStyle())

BoundedIntText(value=0, description='Start Node ID:', layout=Layout(width='50%'))

BoundedIntText(value=1, description='End Node ID:', layout=Layout(width='50%'))

Button(description='Delete Nodes', style=ButtonStyle())

SelectMultiple(description='Select Nodes', layout=Layout(width='50%'), options=(), value=())

Button(description='Add Connection', style=ButtonStyle())

SelectMultiple(description='Malicious Nodes', layout=Layout(width='50%'), options=(), value=())

Button(description='Set Malicious Nodes', style=ButtonStyle())

Button(description='Undo Last Action', layout=Layout(width='30%'), style=ButtonStyle())

Output()

In [5]:
import ipywidgets as widgets
from IPython.display import display

# Create an Output widget to display the connections dictionary
# (display_connections clears it and prints into it on every button click).
connections_output = widgets.Output()

# Function to convert connections to a dictionary
def get_connections_as_dict():
    """Turn the ``connections`` edge list into an adjacency dictionary.

    Every edge ``(a, b)`` is recorded in both directions. Repeated edges are
    not deduplicated, so a neighbour can appear more than once.

    Returns:
        dict: ``{node_name: [neighbour_name, ...]}`` with keys in order of
        first appearance.
    """
    adjacency = {}
    for left, right in connections:
        # Create both keys first so key order matches first appearance.
        adjacency.setdefault(left, [])
        adjacency.setdefault(right, [])
        adjacency[left].append(right)
        adjacency[right].append(left)
    return adjacency

def map_and_filter_nodes(topology):
    """Reduce an adjacency dict to a numeric hub -> clients topology.

    Steps:
      1. Keep only the nodes that have more than one neighbour.
      2. Number every node that still appears (as key or neighbour) by
         sorting on the digits in its name and enumerating from 0.
      3. Rewrite keys and neighbour lists with those numbers.
      4. Order the keys by neighbour count, largest first.
      5. Walk that order, dropping neighbours that were already emitted as a
         key, and skip keys left with no neighbours.

    Args:
        topology: ``{node_name: [neighbour_name, ...]}``.

    Returns:
        tuple: ``(final_topology, node_mapping)`` where ``final_topology`` is
        ``{hub_id: [client_id, ...]}`` and ``node_mapping`` is
        ``{node_name: numeric_id}``.

    Raises:
        ValueError: If a retained node name contains no digits.
    """
    # Step 1: nodes with a single connection cannot act as a hub.
    hubs = {}
    for name, neighbours in topology.items():
        if len(neighbours) > 1:
            hubs[name] = neighbours

    # Step 2: number all surviving names by the digits they contain.
    def digits_of(name):
        return int(''.join(ch for ch in name if ch.isdigit()))

    ordered_names = sorted(set(hubs.keys()).union(*hubs.values()), key=digits_of)
    node_mapping = {}
    for position, name in enumerate(ordered_names):
        node_mapping[name] = position

    # Step 3: translate names into numeric IDs.
    numbered = {
        node_mapping[name]: [node_mapping[neighbour] for neighbour in neighbours]
        for name, neighbours in hubs.items()
    }

    # Step 4: the best-connected hubs go first (stable for equal counts).
    by_degree = sorted(numbered.items(), key=lambda entry: len(entry[1]), reverse=True)

    # Step 5: a node already used as a hub is not handed out as a client again.
    final_topology = {}
    used_as_hub = set()
    for hub_id, client_ids in by_degree:
        remaining = [cid for cid in client_ids if cid not in used_as_hub]
        if remaining:
            final_topology[hub_id] = remaining
        used_as_hub.add(hub_id)

    return final_topology, node_mapping

# Create a button to run the process
run_button = widgets.Button(description="Map and Filter Nodes")

# Create output widgets
output = widgets.Output()  # receives the printed topology and mapping
loading_label = widgets.Label("")  # status text updated by on_button_click

# Module-level placeholder for the most recent filtered topology.
filtered_topology = None

# Function to run on button click
def on_button_click(b):
    """Compute the filtered topology, publish it and print it to ``output``.

    The result is stored in the module-level ``filtered_topology``. Without
    the ``global`` declaration the assignment would only bind a local and the
    module-level value would stay ``None``.

    NOTE: a later cell defines another function with this same name, so
    whichever cell ran last decides what ``on_button_click`` refers to.

    Args:
        b: The clicked button (unused).
    """
    global filtered_topology
    loading_label.value = "Processing... please wait!"
    output.clear_output()

    with output:
        # Execute the function
        filtered_topology, node_mapping = map_and_filter_nodes(get_connections_as_dict())
        # Display the filtered topology and node mapping
        print("Filtered Topology:", filtered_topology)
        print("Node Mapping:", node_mapping)

    # Update the loading label after completion
    loading_label.value = "Completed!"

# Link the button to the function
# run_button.on_click(on_button_click)

# Display the widgets
# Use VBox to display the button and output together
# vbox_layout = widgets.VBox([run_button, output, loading_label])

# Display the widgets
# display(vbox_layout)

# Button to trigger the display of connections dictionary

# Function to display the connections dictionary in the output widget
def display_connections():
    """Print the filtered topology, node mapping and raw adjacency dict.

    Everything is printed into ``connections_output`` after clearing it. The
    filtered topology is also stored in the module-level
    ``filtered_topology``; the ``global`` declaration is needed for that
    assignment to reach module scope. The adjacency dictionary is built once
    and reused for both the mapping step and the final print.
    """
    global filtered_topology
    with connections_output:
        connections_output.clear_output()  # Clear previous output
        connections_dict = get_connections_as_dict()
        filtered_topology, node_mapping = map_and_filter_nodes(connections_dict)
        # Display the filtered topology and node mapping
        print("Filtered Topology:", filtered_topology)
        print("Node Mapping:", node_mapping)
        print(connections_dict)  # Print the connections dictionary

show_connections_button = widgets.Button(description="Show Connections")

# Set up button callback
# on_click passes the button to its handler; the lambda drops that argument
# because display_connections takes none.
show_connections_button.on_click(lambda b: display_connections())

# Display the button and the output area
display(show_connections_button, connections_output)


Button(description='Show Connections', style=ButtonStyle())

Output()

### Running FL Algorithm

In [6]:
import os
from contextlib import redirect_stdout
import io

# Change directory only once per kernel session: the "run_once" flag is kept
# in the IPython user namespace, so re-running this cell does not move the
# working directory further up.
if "run_once" not in get_ipython().user_ns:
    # Suppress the output
    with io.StringIO() as buf, redirect_stdout(buf):
        # Move two directory levels up from the current working directory
        # (presumably to the project root so the `src.*` imports below
        # resolve - TODO confirm).
        %cd ..
        %cd ..
    get_ipython().user_ns["run_once"] = True


In [7]:
import flwr as fl
import torch

%matplotlib inline
import importlib

# import sys
# sys.path.append('../src')
from src.FLProcess.CustomFedAvg import CustomFedAvg
from src.FLProcess.FLUtil import weighted_average
from src.FLProcess.FlowerClient import FlowerClient
from src.NN.NNConfig import get_nn
from src.dataset import dataLoaderFactory
from src.dataset.datasetStrategy import poison_strategy_with_non_iid_split, poison_strategy_for_multi_label_split
from src.poisonDetection.clientAnalysis.strategyFnGeneralAlg import client_analysis_fn_general_alg
from src.poisonDetection.clientAnalysis.strategyFnDebugging import client_analysis_strategy_fn_debugging
from src.poisonDetection.clientAnalysis.strategyFnRandomPoison import client_analysis_strategy_fn_random_poison
# from util.constants import NUM_CLIENTS
from util import constants
from src.NN import NNUtil
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import time
# importlib.reload(dataLoaderFactory) # for library code modifications
# importlib.reload(constants) 
# importlib.reload(NNUtil) 

from src.FLProcess.FLUtil import get_mdl_from_weights, get_pred_from_models, get_mdl_of_client_at_round
from src.dataset.datasetHandler import get_testloader
import shap
from src.poisonDetection.clusteringHDBSCAN import run_hdbscan_clustering_algorithm
from src.poisonDetection.tsneVisualisation import get_tsne_data_from_input_features, visualise_tsne_clusters_with_idx, \
    visualise_clusters_with_tsne
from src.poisonDetection.clientAnalysis.strategyFnGeneralAlg import general_algorithm_main_calc
import numpy as np
import matplotlib.pyplot as plt
from skimage.color import gray2rgb, rgb2gray, label2rgb # since the code wants color images
from lime import lime_image
from lime.wrappers.scikit_image import SegmentationAlgorithm

2025-01-16 15:48:52.753383: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-01-16 15:48:52.836558: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [11]:
%load_ext autoreload
%autoreload 2

In [12]:
def set_constants(num_clients=10, selected_dataset='MNIST'):
    """Overwrite the client count and dataset name on the ``constants`` module.

    Args:
        num_clients: Value stored as ``constants.NUM_CLIENTS``.
        selected_dataset: Value stored as ``constants.SELECTED_DATASET``.
    """
    constants.NUM_CLIENTS, constants.SELECTED_DATASET = num_clients, selected_dataset

In [13]:
# total_nodes = get_total_number_of_nodes()
# # total_nodes = 20
# set_constants(num_clients=total_nodes, selected_dataset='MNIST')

In [14]:
# num_poison = 5
# Example usage: Get malicious node IDs

In [25]:
# Create a button
# NOTE: run_button, output and loading_label rebind names that an earlier
# cell also assigns; after this cell runs they refer to these new widgets.
run_button = widgets.Button(description="Create clients")

# Create an output widget
output = widgets.Output()

# Create a loading message
loading_label = widgets.Label("")

# Global variables to store loaders
# Filled in by on_button_click below; client_fn indexes trainloaders and
# valloaders by client id.
trainloaders = None
valloaders = None
testloaders = None

# Function to execute when button is clicked
def on_button_click(b):
    """Create the per-client data loaders for the current graph.

    The client count is taken from the graph and the nodes marked malicious
    become the ``target_clients`` of the training poison configuration. The
    resulting loaders are stored in the module-level ``trainloaders``,
    ``valloaders`` and ``testloaders``.

    All progress text, including the malicious ID list, is printed inside
    ``output``. The status label only reports "Completed!" when loader
    creation actually finished: inside IPython an ``Output`` context displays
    and suppresses exceptions, so a failure would otherwise look like success.

    Args:
        b: The clicked button (unused).
    """
    global trainloaders, valloaders, testloaders  # results are read by other cells

    total_nodes = get_total_number_of_nodes()
    set_constants(num_clients=total_nodes, selected_dataset='MNIST')

    malicious_ids = get_malicious_node_ids()
    # Training data of the malicious clients is fully poisoned; validation data is left clean.
    kwargs_train = {'poison_type': 'random_poison', 'poison_ratio': 1, 'target_label': 9, 'target_clients': malicious_ids}
    kwargs_val = {'poison_type': 'random_poison', 'poison_ratio': 0, 'target_clients': []}

    loading_label.value = "Running... please wait!"
    output.clear_output()

    succeeded = False
    try:
        with output:
            print("Malicious Node IDs:", malicious_ids)
            trainloaders, valloaders, testloaders = dataLoaderFactory.generate_data_loaders(
                kwargs_train, kwargs_val,
                strategy=poison_strategy_with_non_iid_split,
                len_train_data=10000, len_test_data=1000,
                random_ratio=1, is_visualize=False,
                visualize_idx=0
            )

            # Display the output data (can be expanded or customized)
            print("Data Loaders Created Successfully!")
            succeeded = True
    finally:
        # Runs whether the exception was suppressed by the widget or propagated.
        loading_label.value = "Completed!" if succeeded else "Failed - see the output above."

# Link button to function
# Binds the on_button_click defined in this cell.
run_button.on_click(on_button_click)

# Display the button, loading label, and output in the UI
display(run_button, loading_label, output)

Button(description='Create clients', style=ButtonStyle())

Label(value='')

Output()

In [16]:
from src.NN.MdlTraining import train, test
from src.NN.NNUtil import get_parameters, set_parameters
from flwr.common import ndarrays_to_parameters

class FlowerClientDecen(fl.client.NumPyClient):
    """Flower NumPy client wrapping one network and its train/validation loaders."""

    def __init__(self, net, train_loader, val_loader, local_eps = 1):
        """Store the model, its loaders and the number of local epochs per fit.

        Args:
            net: Model passed to the train/test helpers.
            train_loader: Loader used by ``fit``.
            val_loader: Loader used by ``evaluate``.
            local_eps: Epoch count handed to ``train`` on every ``fit`` call.
        """
        self.net, self.local_eps = net, local_eps
        self.trainloader, self.valloader = train_loader, val_loader

    def get_parameters(self, config):
        """Return the current parameters of the wrapped network."""
        current_weights = get_parameters(self.net)
        return current_weights

    def fit(self, parameters, config):
        """Load ``parameters``, train locally and return the updated weights.

        Returns:
            tuple: ``(weights, len(trainloader), {})``.
        """
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=self.local_eps)
        updated_weights = get_parameters(self.net)
        return updated_weights, len(self.trainloader), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` and evaluate on the validation loader.

        Returns:
            tuple: ``(loss, len(valloader), {"accuracy": accuracy})`` with
            loss and accuracy converted to ``float``.
        """
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        loss_value = float(loss)
        sample_count = len(self.valloader)
        return loss_value, sample_count, {"accuracy": float(accuracy)}

In [17]:
def client_fn(cid) -> FlowerClient:
    """Create a Flower client representing a single organization.

    Args:
        cid: Client id; converted with ``int`` to index the module-level
            ``trainloaders`` and ``valloaders``.

    Returns:
        FlowerClient: A client holding a fresh network on ``constants.DEVICE``
        and the loaders belonging to ``cid``.
    """
    # Fresh model for every client instance.
    model = get_nn()
    model.to(constants.DEVICE)

    # Each client trains and evaluates on its own partition of the data.
    own_trainloader = trainloaders[int(cid)]
    own_valloader = valloaders[int(cid)]

    return FlowerClient(model, own_trainloader, own_valloader)

In [18]:
def initialise_random_mdl():
    """Build a new network with ``get_nn`` and move it to ``constants.DEVICE``.

    Returns:
        The object returned by ``get_nn()``, after ``.to(constants.DEVICE)``
        has been called on it.
    """
    model = get_nn()
    target_device = constants.DEVICE
    model.to(target_device)
    return model

In [19]:
def get_model_parameters(model):
    """Convert a model's parameters into the object built by ``ndarrays_to_parameters``.

    Every tensor yielded by ``model.parameters()`` is moved to the CPU,
    detached and converted to a NumPy array, in iteration order.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.

    Returns:
        The result of ``ndarrays_to_parameters`` applied to the array list.
    """
    weight_arrays = []
    for tensor in model.parameters():
        weight_arrays.append(tensor.cpu().detach().numpy())
    return ndarrays_to_parameters(weight_arrays)

In [20]:
# master_model.parameters()

In [21]:
# mdl_0 = get_mdl_of_client_at_round(target_client_id=1, round_no=0,
#                                            client_updates_list=client_updates_list)
# mdl_0_param_np = get_model_parameters(mdl_0)
# mdl_0_param = ndarrays_to_parameters(mdl_0_param_np)

In [22]:
# for i in range(1):
#     client_updates_list_round = []
#     for j in 

In [23]:
def run_fl_round(master_params, client_ids):
    """Run one single-round Flower simulation for a hub and its clients.

    A ``CustomFedAvg`` strategy is initialised with ``master_params`` and the
    simulation runs for exactly one server round over ``client_ids``. The
    strategy is handed the lists that this function returns.

    Ray is imported locally because ``ray.shutdown()`` is called here and no
    other visible cell imports ``ray``. The shutdown sits in a ``finally``
    block so a failed simulation does not leave Ray running.

    Args:
        master_params: Initial parameters for the strategy.
        client_ids: Client ids passed to the simulation as ``clients_ids``.

    Returns:
        tuple: ``(client_updates_list_nw, aggregated_updates_list)``, the two
        lists given to ``CustomFedAvg``.
    """
    import ray  # installed alongside Flower's simulation engine

    client_updates_list_nw = []
    aggregated_updates_list = []
    results = []
    weight_results = []
    eliminated_client_list = []
    eliminated_client_ids = []
    debug_info = []

    # Poison-detection settings. Currently unused because the matching
    # `strategy_kwargs` argument below is commented out.
    # should update these value based on the dataset: total_labels_per_client, target_label
    kwargs_poison = {'client_ids': ['0', '1', '2', '3', '4'],
                     'explainer_type': 'grad_exp', 'total_rounds': 1, 'sample_count_for_plot': 20,
                     'target_label': 1, 'is_pca': False, 'num_pca_features': 80,
                     'round_idx': 0, 'start_feature_idx': 0, 'total_labels_per_client': 2,
                     'min_cluster_size': 2, 'perplexity': 10,
                     'show_poison_detection_graphs': True, 'malicious_start_idx': 10, 'malicious_end_idx': 40, 'epsilon': 0.0,
                     'debug_info': debug_info, 'is_eliminating_clients': False}

    # kwargs_debug = {'dummy_poison_ids':[str(i) for i in range(num_poison)], 'debug_info': debug_info,
    #                 'is_eliminating_clients':True}

    strategy = CustomFedAvg(
        fraction_fit=1.0,
        evaluate_metrics_aggregation_fn=weighted_average,  # <-- pass the metric aggregation function
        client_updates_list=client_updates_list_nw,
        aggregated_updates_list=aggregated_updates_list,
        results_all=results,
        min_fit_clients=1,       # Minimum number of clients to train on each round
        min_available_clients=1, # Minimum number of clients that need to be available
        min_evaluate_clients=1,  # Minimum number of clients to evaluate on each round
        initial_parameters=master_params,
        # client_analysis_strategy_fn=client_analysis_fn_general_alg,
        # client_analysis_strategy_fn=client_analysis_strategy_fn_debugging,
        # strategy_kwargs=kwargs_poison,
        # strategy_kwargs = kwargs_debug,
        eliminated_client_list=eliminated_client_list,
        eliminated_client_ids=eliminated_client_ids,
        weight_results=weight_results
    )

    client_resources = None
    if constants.DEVICE == "cuda":
        # Each client is given one CPU and a fifth of a GPU.
        client_resources = {"num_cpus": 1, "num_gpus": 0.2}

    try:
        # Start simulation
        fl.simulation.start_simulation(
            client_fn=client_fn,
            clients_ids=client_ids,
            config=fl.server.ServerConfig(num_rounds=1),
            strategy=strategy,
            client_resources=client_resources,
            ray_init_args={"log_to_driver": False, "num_cpus": 10, "num_gpus": 1}
        )
    finally:
        # Release Ray even when the simulation raised.
        ray.shutdown()
    return client_updates_list_nw, aggregated_updates_list

In [26]:
import ipywidgets as widgets
from IPython.display import display
import time

# Create a widget to input the number of rounds
round_slider = widgets.IntSlider(
    value=1,
    min=1,
    max=10,
    step=1,
    description='Rounds:',
    style={'description_width': 'initial'}
)

# Create a button to start the process
start_button = widgets.Button(description="Start Training")

# Create output widgets
# NOTE: output and loading_label rebind names that earlier cells also assign.
output = widgets.Output()
loading_label = widgets.Label("")

# Declare global variable
# One entry per node, reset at the start of every training run.
client_updates_list_all_weights = []
# Module-level placeholder for the number of rounds of the last run.
num_rounds_fin = 0
# Function to execute on button click
def on_start_click(b):
    """Run the requested number of federated-learning rounds over the topology.

    Every node starts from a freshly initialised model. In each round, for
    every hub of the filtered topology, the hub's current model seeds one
    ``run_fl_round`` with the hub's clients. Afterwards each client slot
    receives that client's returned weights and the hub slot receives the
    aggregated weights.

    Results are published through the module-level
    ``client_updates_list_all_weights`` and ``num_rounds_fin``; both need the
    ``global`` declaration, otherwise the assignments stay local. The status
    label reports completion only when all rounds finished, and reports a
    failure otherwise instead of staying on the "running" text.

    Args:
        b: The clicked button (unused).
    """
    global client_updates_list_all_weights, num_rounds_fin
    loading_label.value = "Running federated learning... please wait!"
    output.clear_output()
    num_rounds = round_slider.value
    num_rounds_fin = num_rounds

    total_nodes = get_total_number_of_nodes()
    client_updates_list_all = [initialise_random_mdl() for _ in range(total_nodes)]
    client_updates_list_all_weights = [[] for _ in range(total_nodes)]
    filtered_topology, _ = map_and_filter_nodes(get_connections_as_dict())

    completed = False
    try:
        with output:
            for round_num in range(num_rounds):
                print(f"Round {round_num + 1}/{num_rounds}")

                for master_id, client_ids in filtered_topology.items():
                    # The hub's current model provides the initial parameters.
                    master_model = client_updates_list_all[master_id]
                    master_params = get_model_parameters(master_model)

                    # Run the federated learning round
                    client_updates_dict, aggregated_updates_list = run_fl_round(master_params, client_ids)

                    # Only one server round is run, so index 0 holds its updates.
                    round_updates = client_updates_dict[0]
                    for client_id in client_ids:
                        client_weights = round_updates.get(client_id)[0]
                        client_updates_list_all[client_id] = get_mdl_from_weights(client_weights)
                        client_updates_list_all_weights[client_id] = client_weights

                    client_updates_list_all[master_id] = get_mdl_from_weights(aggregated_updates_list[0])
                    client_updates_list_all_weights[master_id] = aggregated_updates_list[0]

                    # Simulate a delay for demonstration
                    time.sleep(1)  # You can remove this in actual runs

            completed = True
    finally:
        # Runs whether the exception was suppressed by the widget or propagated.
        if completed:
            loading_label.value = "Federated learning completed!"
        else:
            loading_label.value = "Federated learning failed - see the output above."

# Link the button to the function
start_button.on_click(on_start_click)

# Display the widgets and output
# Shown in this order: round slider, start button, status label, log output.
display(round_slider, start_button, loading_label, output)


IntSlider(value=1, description='Rounds:', max=10, min=1, style=SliderStyle(description_width='initial'))

Button(description='Start Training', style=ButtonStyle())

Label(value='')

Output()

In [66]:
# print(client_updates_list_all_weights)

In [32]:
# num_rounds = 1
# cli_updates = []
# for round in range(num_rounds): 
#     for master_id in filtered_topology:
#         client_ids = filtered_topology[master_id]
        
#         # Obtain the master model from client_updates_list_all using the master_id
#         master_model = client_updates_list_all[master_id]
        
#         # Get the parameters of the master model
#         master_params = get_model_parameters(master_model)
        
#         # Run the federated learning round
#         client_updates_dict, aggregated_updates_list = run_fl_round(master_params, client_ids)
#         cli_updates.append(client_updates_dict)
#         # Update the models of the clients with the new updates from client_updates_dict
#         for client_id in client_ids:
#             client_updates_list_all[client_id] = get_mdl_from_weights(client_updates_dict[0].get(client_id)[0])
#             client_updates_list_all_weights[client_id] = client_updates_dict[0].get(client_id)[0]
            
#         client_updates_list_all[master_id] = get_mdl_from_weights(aggregated_updates_list[0])
#         client_updates_list_all_weights[master_id] = aggregated_updates_list[0]

In [33]:
# def get_mdl_weights(mdl):
#     return [val.cpu().detach().numpy() for val in mdl.parameters()]
    
# client_updates_list_all_weights_dict = [{i:(client_updates_list_all_weights[i],16) for i in range(len(client_updates_list_all_weights))}]

In [34]:
# cli_updates_ori = cli_updates[0]

In [35]:
# client_updates_list_all_weights_dict[0]

In [36]:
# cli_updates_ori[0]

In [37]:
# client_updates_list_all[i].parameters()

In [38]:
# get_mdl_of_client_at_round(target_client_id=target_client_id, round_no=0, client_updates_list=client_updates_list_all_dict)

In [39]:
# client_updates_list_all_dict.keys()

In [40]:
# # Update the models of the clients with the new updates from client_updates_dict
# for client_id in client_ids:
#     client_updates_list_all[client_id] = get_mdl_from_weights(client_updates_dict[0].get(client_id)[0])
# client_updates_list_all[master_id] = aggregated_updates_list[0]get_mdl_from_weights(aggregated_updates_list[0])

In [41]:
# # Update the models of the clients with the new updates from client_updates_dict
# for client_id in client_ids:
#     client_updates_list_all[client_id] = client_updates_dict[client_id]

# # Update the master model with the aggregated updates
# # Since aggregated_updates_list contains only one element, we use index 0
# client_updates_list_all[master_id] = aggregated_updates_list[0]

In [39]:
from src.FLProcess.FLUtil import get_mdl_from_weights, get_pred_from_models, get_mdl_of_client_at_round
from src.dataset.datasetHandler import get_testloader
import shap
from src.poisonDetection.clusteringHDBSCAN import run_hdbscan_clustering_algorithm
from src.poisonDetection.tsneVisualisation import get_tsne_data_from_input_features, visualise_tsne_clusters_with_idx, \
    visualise_clusters_with_tsne
from src.poisonDetection.clientAnalysis.strategyFnGeneralAlg import general_algorithm_main_calc

In [40]:
def general_algorithm_main_calc2(client_updates_list, total_labels_per_client, hdbscan_labels, malicious_ids, poisoner_ct=0):
    """Main algorithm for the general poisoning detection.

    Every clustered feature vector is identified by its position in
    ``hdbscan_labels``. Positions are mapped back to clients and label slots
    purely by arithmetic: ``position // total_labels_per_client`` selects the
    client (in the key order of ``client_updates_list[0]``) and
    ``position % total_labels_per_client`` selects the slot within that client.
    A cluster whose members do not all share the same slot is considered
    suspicious, and *every* member of such a cluster adds one point to the
    suspicion score of the client it belongs to.

    Note that every value returned by ``np.unique(hdbscan_labels)`` is treated
    as a cluster id; the noise label ``-1`` is not filtered out here.

    Args:
        client_updates_list: Sequence whose first element is a dict keyed by
            client id. Only the keys (and their order) are used.
        total_labels_per_client: Number of consecutive positions in
            ``hdbscan_labels`` that belong to each client.
        hdbscan_labels: Cluster label for every feature vector.
        malicious_ids: Container of known malicious client ids. Used only to
            print the ground-truth positions next to the detected ones.
        poisoner_ct: Unused; kept so existing callers keep working.

    Returns:
        list[int]: Positions (indexes into the list of client ids of
        ``client_updates_list[0]``), not the client ids themselves, of the
        clients flagged as poisoners.
    """
    target_clients = list(client_updates_list[0].keys())

    # Group the positions of the feature vectors by the cluster they were assigned to.
    all_clusters_ids = np.unique(hdbscan_labels)
    feature_positions = {cluster_id: [] for cluster_id in all_clusters_ids}
    for position, label in enumerate(hdbscan_labels):
        if label in feature_positions:
            feature_positions[label].append(position)

    print(feature_positions)

    # A cluster is clean when all of its members sit in the same label slot.
    # Otherwise the whole cluster is collected as suspicious.
    flagged_positions = []
    for cluster_id in all_clusters_ids:
        members = feature_positions[cluster_id]
        label_slots = {position % total_labels_per_client for position in members}
        if len(label_slots) <= 1:
            print("All cluster features are the same:", cluster_id)
        else:
            flagged_positions.extend(members)

    # One suspicion point per flagged feature vector, credited to its owner.
    sus_ct = dict.fromkeys(target_clients, 0)
    for position in flagged_positions:
        sus_ct[target_clients[position // total_labels_per_client]] += 1

    print(sus_ct)
    print(sus_ct.values())
    total_sus_ct = sum(sus_ct.values())
    total_cli = len(target_clients)

    print(sus_ct)  # this is what we want!!

    # A client is reported when its score reaches both the (truncated) mean score
    # over all clients and half of the per-client feature count. The client ids
    # are never converted; only their position in the key order is returned.
    poison_idxes = [
        client_position
        for client_position, score in enumerate(sus_ct.values())
        if score >= int(total_sus_ct / total_cli) and score >= total_labels_per_client / 2
    ]

    print('detected: ', poison_idxes)

    # Ground truth for visual comparison: positions of the known malicious ids.
    indexes = [
        client_position
        for client_position, client_id in enumerate(target_clients)
        if client_id in malicious_ids
    ]
    print('original: ', indexes)

    return poison_idxes

In [41]:
# client_updates_list_all_weights_dict

In [42]:
# client_updates_list_all_weights

In [43]:
# sample_count_for_test = 100
# sample_count_for_plot = 3
# total_rounds = 1
# total_labels_per_client = 10

# client_updates_list_all_weights_dict = [{i:(client_updates_list_all_weights[i],16) for i in range(len(client_updates_list_all_weights))}]

# client_updates_list = client_updates_list_all_weights_dict
# target_clients = list(client_updates_list[0].keys())

# testloader = get_testloader(len_test=sample_count_for_test+sample_count_for_plot*2, batch_size=sample_count_for_test+sample_count_for_plot*2, 
#                             shuffle=True)
# batch = next(iter(testloader))
# images, actual_out = batch
# background = images[:sample_count_for_test]
# test_images = images[sample_count_for_test:sample_count_for_test + sample_count_for_plot]

# total_detected_poison_idxes = []

# for round in range(total_rounds):
#     shap_feature_ori_list_all = []
#     round_no = round
#     shap_feature_ori_list_all = []
    
#     for cli in target_clients:
#         target_client_id = cli
#         mdl_0 = get_mdl_of_client_at_round(target_client_id=target_client_id, round_no=round_no,
#                                            client_updates_list=client_updates_list)
#         e = shap.DeepExplainer(mdl_0, background)
#         shap_values = e.shap_values(test_images)
#         shap_vals_all_flattened = []
#         for pred in shap_values:
#             for j in pred:
#                 shap_vals_all_flattened.append(j.flatten())
#         shap_feature_ori_list_all.append(shap_vals_all_flattened)
#     shap_feature_list = shap_feature_ori_list_all
#     shap_feature_per_client = []
#     for i in range(len(target_clients)):
#         for j in range(sample_count_for_plot*total_labels_per_client):
#             shap_feature_per_client.append(shap_feature_list[i][j])
#     # print(len(shap_feature_per_client))
#     min_cluster_size = 2
#     # epsilon = 0.011
#     epsilon = 0.0
#     hdbscan_labels, hdbscan_clusterer, colors = run_hdbscan_clustering_algorithm(
#             input_feature_list=shap_feature_per_client, min_cluster_size=min_cluster_size, epsilon=epsilon)
#     # print(hdbscan_labels)
#     perplexity = 10
#     visualise_clusters_with_tsne(input_feature_list=shap_feature_per_client, label_list=hdbscan_labels,
#                                      label_colors=colors, perplexity=perplexity,
#                                      show_malicious_items=False, malicious_start_idx=None,
#                                      malicious_end_idx=None, show_labels=True)
#     real_malicious_ids = get_malicious_node_ids()
#     poison_idxes = general_algorithm_main_calc2(client_updates_list, total_labels_per_client*sample_count_for_plot, hdbscan_labels, real_malicious_ids)
#     total_detected_poison_idxes.append(poison_idxes)

In [44]:
import ipywidgets as widgets
from IPython.display import display
import shap
import matplotlib.pyplot as plt

# Button that starts the SHAP-based detection; the click handler is attached
# further down in this cell.
run_button = widgets.Button(description="Run SHAP Detection")

# `output` collects everything printed or plotted while the detection runs;
# `loading_label` is a one-line status text that the click handler updates.
output = widgets.Output()
loading_label = widgets.Label("")

# Notebook-global result store: the click handler resets it on every run and
# appends one list of detected client positions per round.
total_detected_poison_idxes = []

# Function to run on button click
def on_button_click(b):
    """Run the SHAP + HDBSCAN poisoning detection and report it in the widgets.

    For every client, the model is rebuilt from its stored weights, SHAP values
    are computed for a few test images, and the flattened SHAP vectors of all
    clients are clustered with HDBSCAN. The cluster labels are then handed to
    ``general_algorithm_main_calc2``, whose result is appended to the
    notebook-global ``total_detected_poison_idxes`` (one entry per round).

    The whole pipeline runs inside the ``output`` widget so that prints, plots
    and tracebacks all land there, and ``loading_label`` only reports
    "Completed!" when the run really finished.

    Relies on notebook globals defined in other cells:
    ``client_updates_list_all_weights``, ``get_malicious_node_ids``,
    ``get_testloader``, ``get_mdl_of_client_at_round``,
    ``run_hdbscan_clustering_algorithm``, ``visualise_clusters_with_tsne`` and
    ``general_algorithm_main_calc2``.

    Args:
        b: The button instance passed in by ipywidgets; not used.
    """
    global total_detected_poison_idxes
    total_detected_poison_idxes = []  # Reset results for each run

    loading_label.value = "Running... please wait!"
    output.clear_output()

    sample_count_for_test = 100   # background samples handed to the SHAP explainer
    sample_count_for_plot = 1     # test images explained per client
    total_rounds = 1
    total_labels_per_client = 10
    features_per_client = sample_count_for_plot * total_labels_per_client

    succeeded = False
    try:
        # Everything that can fail runs inside the output context. Under IPython
        # the Output widget displays and swallows exceptions raised in here, so
        # success is tracked explicitly instead of assumed after the block.
        with output:
            # Wrap the stored weights in the single-round structure the helpers expect.
            # NOTE(review): the constant 16 is presumably a per-client sample
            # count consumed by get_mdl_of_client_at_round - confirm.
            client_updates_list = [{
                i: (client_updates_list_all_weights[i], 16)
                for i in range(len(client_updates_list_all_weights))
            }]
            target_clients = list(client_updates_list[0].keys())

            loader_size = sample_count_for_test + sample_count_for_plot * 2
            testloader = get_testloader(len_test=loader_size, batch_size=loader_size, shuffle=True)
            images, _ = next(iter(testloader))
            background = images[:sample_count_for_test]
            test_images = images[sample_count_for_test:sample_count_for_test + sample_count_for_plot]

            for round_no in range(total_rounds):
                # One list of flattened SHAP vectors per client, in client order.
                shap_feature_list = []
                for target_client_id in target_clients:
                    mdl_0 = get_mdl_of_client_at_round(target_client_id=target_client_id, round_no=round_no,
                                                       client_updates_list=client_updates_list)
                    explainer = shap.DeepExplainer(mdl_0, background)
                    shap_values = explainer.shap_values(test_images)
                    shap_feature_list.append(
                        [sample.flatten() for pred in shap_values for sample in pred]
                    )

                # Concatenate exactly `features_per_client` vectors per client.
                # The detection step maps positions back to clients by integer
                # division, so every client must contribute the same number;
                # indexing (not slicing) makes a short list fail loudly.
                shap_feature_per_client = [
                    client_features[j]
                    for client_features in shap_feature_list
                    for j in range(features_per_client)
                ]

                min_cluster_size = 2
                epsilon = 0.0
                hdbscan_labels, hdbscan_clusterer, colors = run_hdbscan_clustering_algorithm(
                    input_feature_list=shap_feature_per_client, min_cluster_size=min_cluster_size, epsilon=epsilon)

                perplexity = 10
                visualise_clusters_with_tsne(input_feature_list=shap_feature_per_client, label_list=hdbscan_labels,
                                             label_colors=colors, perplexity=perplexity,
                                             show_malicious_items=False, malicious_start_idx=None,
                                             malicious_end_idx=None, show_labels=True)

                real_malicious_ids = get_malicious_node_ids()
                poison_idxes = general_algorithm_main_calc2(client_updates_list, features_per_client,
                                                            hdbscan_labels, real_malicious_ids)
                total_detected_poison_idxes.append(poison_idxes)

            # Display the plots
            plt.show()  # This ensures the plots from visualise_clusters_with_tsne are rendered

            # Print the detected malicious indices
            print("Total Detected Poison Indices:", total_detected_poison_idxes)
            succeeded = True
    finally:
        # Never leave the label on "Running..." and never claim success after a failure.
        loading_label.value = "Completed!" if succeeded else "Failed - see the error output below."

# Register the click handler; ipywidgets calls it with the button instance.
run_button.on_click(on_button_click)

# Render the controls top to bottom: button, status label, then the captured output.
# NOTE(review): the saved output of this cell shows a button labelled
# 'Run SHAP and Clustering', so it predates the current description - re-run
# the cell before sharing the notebook.
display(run_button, loading_label, output)


Button(description='Run SHAP and Clustering', style=ButtonStyle())

Label(value='')

Output()

In [None]:
# sample_count_for_test = 100
# sample_count_for_plot = 1
# total_rounds = 1
# total_labels_per_client = 10
# # client_updates_list = client_updates_list_nw
# client_updates_list = client_updates_list_all
# # target_clients = list(client_updates_list[0].keys())
# target_clients = [i for i in range(total_nodes)]

# testloader = get_testloader(len_test=sample_count_for_test+sample_count_for_plot*2, batch_size=sample_count_for_test+sample_count_for_plot*2, 
#                             shuffle=True)
# batch = next(iter(testloader))
# images, actual_out = batch
# background = images[:sample_count_for_test]
# test_images = images[sample_count_for_test:sample_count_for_test + sample_count_for_plot]

# total_detected_poison_idxes = []

# for round in range(total_rounds):
#     shap_feature_ori_list_all = []
#     round_no = round
#     shap_feature_ori_list_all = []
    
#     for cli in target_clients:
#         target_client_id = cli
#         print(cli)
#         # mdl_0 = get_mdl_of_client_at_round(target_client_id=target_client_id, round_no=round_no,
#                                            # client_updates_list=client_updates_list)
#         mdl_0 = client_updates_list[target_client_id]
#         e = shap.DeepExplainer(mdl_0, background)
#         shap_values = e.shap_values(test_images)
#         shap_vals_all_flattened = []
#         for pred in shap_values:
#             for j in pred:
#                 shap_vals_all_flattened.append(j.flatten())
#         shap_feature_ori_list_all.append(shap_vals_all_flattened)
#     shap_feature_list = shap_feature_ori_list_all
#     shap_feature_per_client = []
#     for i in range(len(target_clients)):
#         for j in range(sample_count_for_plot*total_labels_per_client):
#             shap_feature_per_client.append(shap_feature_list[i][j])
#     # print(len(shap_feature_per_client))
#     min_cluster_size = 2
#     # epsilon = 0.011
#     epsilon = 0.011
#     hdbscan_labels, hdbscan_clusterer, colors = run_hdbscan_clustering_algorithm(
#             input_feature_list=shap_feature_per_client, min_cluster_size=min_cluster_size, epsilon=epsilon)
#     # print(hdbscan_labels)
#     perplexity = 10
#     visualise_clusters_with_tsne(input_feature_list=shap_feature_per_client, label_list=hdbscan_labels,
#                                      label_colors=colors, perplexity=perplexity,
#                                      show_malicious_items=False, malicious_start_idx=None,
#                                      malicious_end_idx=None, show_labels=True)

In [None]:
# '''main algorithm for new general poisoning detection'''
# def general_algorithm_main_calc2(client_updates_list, total_labels_per_client, hdbscan_labels, cli_ids, poisoner_ct=0):
#     # target_clients = list(client_updates_list[0].keys())
#     # target_clients = [i for i in range(total_nodes)]
#     target_clients = client_updates_list
#     # print(target_clients)
#     # get all cluster ids, ignore -1 cluster id
#     all_clusters_ids = np.unique(hdbscan_labels)
#     # if np.any(all_clusters_ids == 0):
#     #     all_clusters_ids = all_clusters_ids[all_clusters_ids != 0]

#     # create an empty dict of arrays to get the positions for each features to be calculated for suspicious counts
#     feature_positions = {}
#     for i in all_clusters_ids:
#         feature_positions[i] = []

#     for i in range(len(hdbscan_labels)):
#         if hdbscan_labels[i] in feature_positions.keys():
#             feature_positions[hdbscan_labels[i]].append(i)
            
#     # ## patch to remove outliers       
#     # for key, value_list in list(feature_positions.items()):
#     #     if len(value_list) > 200:
#     #         # Remove the key-value pair if the length is greater than 200
#     #         del feature_positions[key]
        
#     # print(feature_positions)
#     diff_idxes_all = []

#     # Main algorithm to detect poisoners: compare feature repetitions within the same cluster.
#     # If different features are present, possible poisoning alert
#     for i in all_clusters_ids:
#         # print(i)
#         # List of numbers
#         numbers = feature_positions[i]
#         # print(len(numbers))
#         # if len(numbers)>200:
#         #     pass
#         # Find the remainder when each number is divided by total output features/labels per client
#         remainders = [num % total_labels_per_client for num in numbers]

#         # Check if all the remainders are the same
#         if all(remainder == remainders[0] for remainder in remainders):
#             print("All cluster features are the same:", i)
#         else:
#             # print("Not all features are the same. Possible poisoning")
#             # Find and isolate the numbers with different remainders
#             different_idxes = [num for num, remainder in zip(numbers, remainders) if remainder != remainders[0]]
#             # print("Cluster with different features:", different_idxes)
#             # print(numbers)
#             diff_idxes_all.extend(numbers)

#     sus_ct = {}

#     for i in list(target_clients):
#         sus_ct[i] = 0
    

#     # add a suspicious score for each client
#     for i in diff_idxes_all:
#         sus_client_position = i // (total_labels_per_client)
#         sus_client = target_clients[sus_client_position]
#         sus_ct[sus_client] += 1

#     print(sus_ct)
#     print(sus_ct.values())
#     total_sus_ct = sum(sus_ct.values())
#     # total_sus_ct = sum(list(sus_ct.values))
#     # total_sus_cli = sum(1 for value in sus_ct.values() if value > 0)
#     # total_cli = len(list(client_updates_list[0].keys()))
#     total_cli = len(cli_ids)
    
#     print(sus_ct)  # this is what we want!!

#     # detecting poison clients
#     poison_clients = []
    
#     for key, value in sus_ct.items():
#         # if value >= total_labels_per_client / 2:
#         if value >= int(total_sus_ct/total_cli) and value >= total_labels_per_client / 2:
#             poison_clients.append(key)  # CONVERTING TO AN INTEGER CAN BE A POTENTIAL BUG - yes it is, so eliminated!!!

#     # poison_idxes = []
#     # idxes_to_remove = list(client_updates_list[0].keys())
#     # for i in poison_clients:
#     #     poison_idxes.append(idxes_to_remove.index(i))
#     poison_idxes = poison_clients

#     print('detected: ', poison_idxes)
#     debugging_enabled = False
#     poison_idx_viewing = True
#     if poison_idx_viewing:
#         # debugging operation (should update)
#         # my_list = list(client_updates_list[0].keys())
#         my_list = cli_ids
#         # values_to_find = [str(i) for i in range(poisoner_ct)]
#         # values_to_find = ['1', '2', '3','4','5','6','7','8','9','10']
#         # values_to_find = ['1', '2', '3', '4', '5']
#         values_to_find = malicious_ids

#         # Find the indexes of the values in the list
#         # indexes = [i for i, value in enumerate(my_list) if value in values_to_find]
#         print('original: ', values_to_find)
#         if debugging_enabled:
#             return indexes
#     return poison_idxes

In [14]:
from abc import ABC, abstractmethod

class Topology(ABC):
    """Abstract base for the communication topologies used in federated rounds.

    Attributes are stored exactly as given:
        clients: the participating clients (e.g. 'Node0', 'Node1', ...).
        configuration: mapping of each client to the list of clients it is
            connected to.

    Subclasses must implement both abstract methods before they can be
    instantiated.
    """

    def __init__(self, clients, configuration):
        self.clients, self.configuration = clients, configuration

    @abstractmethod
    def select_clients_for_round(self, round_num):
        """Return the clients that take part in round ``round_num``."""

    @abstractmethod
    def exchange_models(self, models, round_num):
        """Return the per-client models after the exchange of round ``round_num``.

        ``models`` is the mapping of client to model before the exchange.
        """

In [22]:
class StarTopology(Topology):
    """Topology in which all models are averaged into a single central model.

    After an exchange every client in ``self.clients`` holds that same
    averaged model.
    """

    def __init__(self, clients, configuration):
        # No star-specific state; everything is stored by the base class.
        super().__init__(clients, configuration)

    def select_clients_for_round(self, round_num):
        """Every client participates in every round."""
        return self.clients

    def exchange_models(self, models, round_num):
        """Average all models and hand the result to every client.

        The returned dict maps each client to the *same* aggregated object.
        """
        return dict.fromkeys(self.clients, self.aggregate_models(models))

    def aggregate_models(self, models):
        """Return the unweighted mean of the values of ``models``.

        The values only need to support ``+`` (starting from ``0``) and
        division by an integer. An empty ``models`` raises ZeroDivisionError.
        """
        model_total = sum(models.values())
        return model_total / len(models)

class RingTopology(Topology):
    """Topology in which each client forwards its model to a single neighbour."""

    def __init__(self, clients, configuration):
        super().__init__(clients, configuration)

    def select_clients_for_round(self, round_num):
        """Every client participates in every round."""
        return self.clients

    def exchange_models(self, models, round_num):
        """Pass each client's model on to the first peer in its configuration.

        The result is keyed by the *receiving* client. If several clients share
        the same first peer, the one visited last in ``self.configuration``
        wins. Clients that are nobody's first peer do not appear in the result.
        """
        forwarded = {}
        for sender in self.configuration:
            receiver = self.configuration[sender][0]
            forwarded[receiver] = models[sender]
        return forwarded

class MeshTopology(Topology):
    """Topology in which each client receives the average of its peers' models."""

    def __init__(self, clients, configuration):
        super().__init__(clients, configuration)

    def select_clients_for_round(self, round_num):
        """Every client participates in every round."""
        return self.clients

    def exchange_models(self, models, round_num):
        """Give every configured client the mean of its available peer models.

        Only peers that actually have an entry in ``models`` contribute. The
        client's own model is not part of the average unless the client is
        listed among its own peers.

        When none of a client's peers has a model (no peers configured, or none
        of them present in ``models``), there is nothing to average. The client
        then keeps its own model if it has one and is otherwise left out of the
        result, instead of failing the whole round with a ZeroDivisionError.

        Returns:
            dict: client -> aggregated (or retained) model.
        """
        aggregated_models = {}
        for client, peers in self.configuration.items():
            peer_models = [models[peer] for peer in peers if peer in models]
            if peer_models:
                aggregated_models[client] = self.aggregate_peer_models(peer_models)
            elif client in models:
                # No peer contributed: fall back to the client's current model.
                aggregated_models[client] = models[client]
        return aggregated_models

    def aggregate_peer_models(self, models):
        """Return the unweighted mean of a non-empty sequence of models.

        An empty sequence raises ZeroDivisionError.
        """
        return sum(models) / len(models)


In [9]:
# class RingTopology(Topology):
#     def select_clients_for_round(self, round_num):
#         return self.clients

#     def exchange_models(self, models, round_num):
#         next_round_models = {}
#         for client, peers in self.configuration.items():
#             # Send model to the first peer in the list (ring-like behavior)
#             next_client = peers[0]
#             next_round_models[next_client] = models[client]
#         return next_round_models


In [15]:
# class MeshTopology(Topology):
#     def select_clients_for_round(self, round_num):
#         return self.clients
    
#     def exchange_models(self, models, round_num):
#         aggregated_models = {}
#         for client, peers in self.configuration.items():
#             # Aggregate models from all connected peers
#             peer_models = [models[peer] for peer in peers if peer in models]
#             aggregated_model = self.aggregate_peer_models(peer_models)
#             aggregated_models[client] = aggregated_model
#         return aggregated_models

#     def aggregate_peer_models(self, models):
#         return sum(models) / len(models)


In [17]:
# Example adjacency map: every key is a client, its list holds the clients it is linked to.
# Node0 is linked to all other nodes; Node2 and Node3 are additionally linked to each other.
client_config = {
    'Node0': ['Node1', 'Node2', 'Node3', 'Node4'],
    'Node1': ['Node0'],
    'Node2': ['Node0', 'Node3'],
    'Node3': ['Node0', 'Node2'],
    'Node4': ['Node0'],
}

clients = [*client_config]  # ['Node0', 'Node1', 'Node2', 'Node3', 'Node4']

# The same client list and adjacency map are handed to each topology.
# Intended usage (train_federated is only referenced in comments here):
#   train_federated(clients, star_topology, 10)
#   train_federated(clients, ring_topology, 10)
#   train_federated(clients, mesh_topology, 10)
star_topology = StarTopology(clients, client_config)
ring_topology = RingTopology(clients, client_config)
mesh_topology = MeshTopology(clients, client_config)


In [8]:
# ring_topology.configuration

In [2]:
# import ipywidgets as widgets
# from IPython.display import display, HTML
# import networkx as nx
# from pyvis.network import Network
# from ipywidgets import Layout, GridspecLayout
# import copy

# # Initialize graph and data structures
# G = nx.Graph()
# connections = []
# history = []  # To store graph states for undo functionality

# # Function to create a grid layout
# def create_grid():
#     grid = GridspecLayout(6, 4, height='800px', width='100%')  # Increased height to 800px

#     # Topology and node inputs
#     grid[0, 0] = widgets.Label("Graph Topology:")
#     grid[0, 1] = topology_select
#     grid[0, 2] = widgets.Label("Node Count:")
#     grid[0, 3] = node_count_input

#     # Apply topology button
#     grid[1, 0:2] = apply_topology_button

#     # Add/Delete Nodes and Undo buttons
#     grid[2, 0] = add_node_button
#     grid[2, 1] = add_node_count_input
#     grid[2, 2] = delete_node_button
#     grid[2, 3] = undo_button

#     # Add connection and node selection (using HBox to group widgets)
#     grid[3, 0] = widgets.Label("Select Nodes for Connection:")
#     grid[3, 1:] = widgets.HBox([node_select, add_connection_button], layout=Layout(width='100%', justify_content='space-between'))

#     # Graph output display
#     grid[4:, :] = graph_output

#     return grid

# # Custom CSS Styling for Voila
# custom_css = """
# <style>
#     .widget-button {
#         background-color: #4CAF50;
#         color: white;
#         font-weight: bold;
#         border-radius: 5px;
#         padding: 10px 20px;
#         margin: 5px;
#         border: none;
#     }
#     .widget-button:hover {
#         background-color: #45a049;
#     }
#     .widget-label {
#         font-size: 16px;
#         font-weight: bold;
#     }
#     .widget-select-multiple {
#         height: 150px;  /* Increase height for the select multiple widget */
#     }
#     .output {
#         border: 1px solid #ddd;
#         padding: 10px;
#         border-radius: 5px;
#         background-color: #f9f9f9;
#     }
#     .graph-output {
#         width: 100%;
#         height: 800px;  /* Increase graph output height */
#     }
# </style>
# """

# # Define widgets
# node_count_input = widgets.BoundedIntText(value=5, min=2, max=100, step=1, layout=Layout(width='90%'))
# topology_select = widgets.Dropdown(
#     options=['Star', 'Ring', 'Mesh', 'Tree'],
#     value='Star',
#     layout=Layout(width='90%')
# )
# apply_topology_button = widgets.Button(description="Apply Topology", layout=Layout(width='90%'))

# # For adding nodes
# add_node_count_input = widgets.BoundedIntText(value=1, min=1, max=100, step=1, layout=Layout(width='90%'))
# add_node_button = widgets.Button(description="Add Nodes", layout=Layout(width='90%'))

# # For removing nodes
# delete_node_start_input = widgets.BoundedIntText(value=0, min=0, max=100, step=1, layout=Layout(width='90%'))
# delete_node_end_input = widgets.BoundedIntText(value=1, min=0, max=100, step=1, layout=Layout(width='90%'))
# delete_node_button = widgets.Button(description="Delete Nodes", layout=Layout(width='90%'))

# # For adding connections
# node_select = widgets.SelectMultiple(options=list(G.nodes), layout=Layout(width='70%', height='150px')) # Increase height of node selection
# add_connection_button = widgets.Button(description="Add Connection", layout=Layout(width='100%', height='150px'))

# # Undo button
# undo_button = widgets.Button(description="Undo Last Action", layout=Layout(width='90%'))

# graph_output = widgets.Output(layout=Layout(width='100%', height='800px'))  # Increase graph output height to 800px

# # Save current graph state in history
# def save_graph_state():
#     global history
#     # Save a deep copy of the current state of the graph (nodes and edges)
#     history.append((copy.deepcopy(G), copy.deepcopy(connections)))

# # Function to initialize the graph based on the selected topology
# def initialize_graph(topology, node_count):
#     G.clear()  # Clear the graph for new topology
#     connections.clear()  # Clear previous connections
#     if topology == 'Star':
#         G.add_node("Node0")  # Central node
#         for i in range(1, node_count):
#             node_name = f"Node{i}"
#             G.add_node(node_name)
#             G.add_edge("Node0", node_name)  # Connect each node to the central node
#             connections.append(("Node0", node_name))
#     elif topology == 'Ring':
#         for i in range(node_count):
#             node_name = f"Node{i}"
#             G.add_node(node_name)
#         for i in range(node_count):
#             G.add_edge(f"Node{i}", f"Node{(i + 1) % node_count}")  # Connect nodes in a ring
#             connections.append((f"Node{i}", f"Node{(i + 1) % node_count}"))
#     elif topology == 'Mesh':
#         for i in range(node_count):
#             node_name = f"Node{i}"
#             G.add_node(node_name)
#         for i in range(node_count):
#             for j in range(i + 1, node_count):
#                 G.add_edge(f"Node{i}", f"Node{j}")  # Connect each node to every other node
#                 connections.append((f"Node{i}", f"Node{j}"))
#     elif topology == 'Tree':
#         G.add_node("Node0")  # Root node
#         for i in range(1, node_count):
#             node_name = f"Node{i}"
#             G.add_node(node_name)
#             parent_node = f"Node{(i - 1) // 2}"  # Connect each node to its parent node (integer division for tree structure)
#             G.add_edge(parent_node, node_name)  # Parent-child relationship
#             connections.append((parent_node, node_name))
    
#     node_select.options = list(G.nodes)  # Update node selection options
#     save_graph_state()  # Save the initial state

# # Callback to apply selected topology
# def apply_topology_callback(b):
#     topology = topology_select.value
#     node_count = node_count_input.value
#     initialize_graph(topology, node_count)
#     with graph_output:
#         graph_output.clear_output(wait=True)
#         draw_graph()

# # Callback to add nodes to the graph
# def add_node_callback(b):
#     node_count = add_node_count_input.value
#     existing_nodes = len(G.nodes)  # Count the existing nodes
#     for i in range(1, node_count + 1):
#         new_node_name = f"Node{existing_nodes + i - 1}"
#         G.add_node(new_node_name)
#         G.add_edge("Node0", new_node_name)  # Default to connect new nodes to the central node (or adjust)
#         connections.append(("Node0", new_node_name))
#     node_select.options = list(G.nodes)  # Update node selection options
#     save_graph_state()  # Save the current state
#     with graph_output:
#         graph_output.clear_output(wait=True)
#         draw_graph()

# # Callback to delete nodes within a range of IDs
# def delete_node_callback(b):
#     start_node = delete_node_start_input.value
#     end_node = delete_node_end_input.value
    
#     # Check and remove nodes within the specified range if they exist
#     for node_id in range(start_node, end_node + 1):
#         node_name = f"Node{node_id}"
#         if node_name in G.nodes:
#             G.remove_node(node_name)
    
#     node_select.options = list(G.nodes)  # Update node selection options
#     save_graph_state()  # Save the current state
#     with graph_output:
#         graph_output.clear_output(wait=True)
#         draw_graph()

# # Callback to add connections between selected nodes
# def add_connection_callback(b):
#     selected = list(node_select.value)
#     if len(selected) == 2:  # Only add connection if two nodes are selected
#         G.add_edge(selected[0], selected[1])
#         connections.append((selected[0], selected[1]))
#         save_graph_state()  # Save the current state
#         with graph_output:
#             graph_output.clear_output(wait=True)
#             draw_graph()

# # Undo callback to rollback to previous state
# def undo_callback(b):
#     global G, connections, history
#     if len(history) > 1:
#         history.pop()  # Remove the most recent state
#         G, connections = copy.deepcopy(history[-1])  # Restore the last saved state
#         node_select.options = list(G.nodes)  # Update node selection options
#         with graph_output:
#             graph_output.clear_output(wait=True)
#             draw_graph()

# # Function to draw the graph using pyvis with CDN resources
# def draw_graph():
#     net = Network(notebook=True, height='800px', width='100%', bgcolor='#ffffff', font_color='black', cdn_resources='in_line')
    
#     # Add nodes and edges from the NetworkX graph G
#     net.from_nx(G)
    
#     # Save the network and display it in the notebook
#     net.show('graph.html')
#     display(HTML('graph.html'))

# # Apply the custom CSS
# display(HTML(custom_css))

# # Create and display the grid layout
# grid = create_grid()
# display(grid)

# # Button event handlers
# apply_topology_button.on_click(apply_topology_callback)
# add_node_button.on_click(add_node_callback)
# delete_node_button.on_click(delete_node_callback)
# add_connection_button.on_click(add_connection_callback)
# undo_button.on_click(undo_callback)

# # Initialize graph plot
# with graph_output:
#     apply_topology_callback(None)


In [32]:
# !pip install bokeh
# !pip install pyvis

In [25]:
# import ipywidgets as widgets
# import networkx as nx
# from bokeh.io import output_notebook, show
# from bokeh.plotting import figure, from_networkx
# from bokeh.models import Circle, MultiLine, TapTool, CustomJS, ColumnDataSource
# from bokeh.models.graphs import NodesAndLinkedEdges
# from bokeh.models.callbacks import CustomJS
# from IPython.display import display

# # Initialize Bokeh for notebook
# output_notebook()

# # Initialize the graph and data structures
# G = nx.Graph()

# # Add a centralized node and other nodes connected to it
# central_node = "Center"
# G.add_node(central_node)
# other_nodes = ["Node1", "Node2", "Node3", "Node4", "Node5"]
# for node in other_nodes:
#     G.add_node(node)
#     G.add_edge(central_node, node)

# # Create a Bokeh figure
# plot = figure(title="Interactive Graph", x_range=(-2, 2), y_range=(-2, 2),
#               tools="tap", tooltips="Node: @index")

# # Convert the graph into a Bokeh renderer
# graph_renderer = from_networkx(G, nx.spring_layout, scale=1, center=(0, 0))
# graph_renderer.node_renderer.glyph = Circle(radius=0.1, fill_color="skyblue")
# graph_renderer.edge_renderer.glyph = MultiLine(line_color="gray", line_width=2)
# graph_renderer.node_renderer.hover_glyph = Circle(radius=0.1, fill_color="green")
# plot.renderers.append(graph_renderer)

# # Look up the TapTool already added to the figure via tools="tap" (select() only queries; it does not add a tool)
# taptool = plot.select(type=TapTool)

# # Store the selected nodes for creating connections
# source = ColumnDataSource(data=dict(selected=[]))

# # JavaScript callback to handle node selection and create edges dynamically
# callback = CustomJS(args=dict(source=source, graph_renderer=graph_renderer), code="""
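#     // Get selected indices from the graph
#     // NOTE(review): this callback is attached to data_source.selected (see js_on_change below),
#     // so cb_obj is presumably the Selection object itself and this should read cb_obj.indices - confirm before re-enabling.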
#     var selected_indices = cb_obj.selected.indices;

#     // Get the node names tapped so far (accumulated in the 'selected' column of `source`)
#     var selected_nodes = source.data.selected;

#     if (selected_indices.length > 0) {
#         var node_name = graph_renderer.node_renderer.data_source.data['index'][selected_indices[0]];
#         selected_nodes.push(node_name);
#         source.data['selected'] = selected_nodes;

#         // If two nodes are selected, create an edge between them
#         if (selected_nodes.length == 2) {
#             // Add the edge in the graph's data source
#             var edge_data = graph_renderer.edge_renderer.data_source.data;
#             edge_data['start'].push(selected_nodes[0]);
#             edge_data['end'].push(selected_nodes[1]);

#             // Reset the selected nodes
#             selected_nodes.length = 0;

#             // Trigger graph re-render
#             graph_renderer.edge_renderer.data_source.change.emit();
#         }
#     }
# """)

# # Run the JavaScript callback whenever the node data source's selected indices change
# graph_renderer.node_renderer.data_source.selected.js_on_change('indices', callback)

# # Show the graph in the notebook
# show(plot)

# # Widgets for adding new nodes dynamically, plus an output area for re-displaying the plot
# node_input = widgets.Text(placeholder="Enter node name", description="Node:")
# add_node_button = widgets.Button(description="Add Node")
# connection_output = widgets.Output()

# # Callback function to add nodes
# def add_node_callback(b):
#     node_name = node_input.value
#     if node_name:  # Add node to the graph
#         G.add_node(node_name)
#         G.add_edge(central_node, node_name)  # Connect new node to the central node
#         node_input.value = ""  # Clear input
#         update_graph()

# # Function to clear connection_output and re-show the existing plot (the graph renderer is not rebuilt from G here)
# def update_graph():
#     with connection_output:
#         connection_output.clear_output()
#         show(plot)

# # Display the widgets for adding new nodes
# display(node_input, add_node_button, connection_output)

# # Attach button callback
# add_node_button.on_click(add_node_callback)
