In [8]:
import tensorflow as tf
from tensorflow.keras import layers

# Fetch CIFAR-10: uint8 RGB images of 32x32 pixels with integer labels of shape (N, 1).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Rescale pixel intensities into [0, 1] as float32.
x_train, x_test = (split.astype("float32") / 255.0 for split in (x_train, x_test))

# One-hot copies of the labels (10 classes) for categorical_crossentropy;
# the integer labels y_train / y_test are kept alongside.
y_train_one_hot, y_test_one_hot = (
    tf.keras.utils.to_categorical(split_labels, 10) for split_labels in (y_train, y_test)
)

# Report array shapes (integer labels, not the one-hot versions).
print("Training data shape:", x_train.shape, y_train.shape)
print("Testing data shape:", x_test.shape, y_test.shape)


Training data shape: (50000, 32, 32, 3) (50000, 1)
Testing data shape: (10000, 32, 32, 3) (10000, 1)


In [65]:
import numpy as np
def exclue_class(x_data, y_data, excluded_cindex):
    """Drop every sample whose label row belongs to one of the excluded classes.

    The function name keeps its original spelling so existing callers keep working.

    Args:
        x_data: Samples, indexed along the first axis.
        y_data: Label rows aligned with ``x_data`` (one-hot in this notebook); the
            class of each row is its ``argmax``.
        excluded_cindex: Iterable of class indices to remove (list, tuple, set, ...).

    Returns:
        ``(x_filtered, y_filtered)`` NumPy arrays holding only the kept samples, in
        their original order. The per-sample shape is preserved even when no
        sample survives the filter.

    Raises:
        ValueError: If ``x_data`` and ``y_data`` have different lengths.
    """
    x_arr = np.asarray(x_data)
    y_arr = np.asarray(y_data)
    if len(x_arr) != len(y_arr):
        raise ValueError(
            f"x_data and y_data must have the same length, got {len(x_arr)} and {len(y_arr)}"
        )
    if len(y_arr) == 0:
        # reshape(0, -1) is ambiguous, so short-circuit empty input.
        return x_arr.copy(), y_arr.copy()
    # Flatten each label row so the argmax matches a per-row np.argmax(cur_y).
    labels = np.argmax(y_arr.reshape(len(y_arr), -1), axis=1)
    # np.isin does not expand Python sets, so materialise the classes as a list.
    keep = ~np.isin(labels, list(excluded_cindex))
    return x_arr[keep], y_arr[keep]

In [66]:
def check_cindex(cur_y):
    """Return the set of class indices present in a batch of label rows.

    Args:
        cur_y: Sequence/array of label rows (one-hot in this notebook); the class
            of each row is its ``argmax``.

    Returns:
        A ``set`` of plain Python ``int`` class indices; empty for empty input.
    """
    y_arr = np.asarray(cur_y)
    if len(y_arr) == 0:
        return set()
    # One vectorised argmax per row; .tolist() yields Python ints so the set
    # prints as {0, 2, ...} rather than {np.int64(0), ...} under NumPy >= 2.
    return set(np.argmax(y_arr.reshape(len(y_arr), -1), axis=1).tolist())

In [67]:
# Build three training splits, each dropping a different trio of classes.
(x_train_1, y_train_1), (x_train_2, y_train_2), (x_train_3, y_train_3) = (
    exclue_class(x_train, y_train_one_hot, excluded_cindex=excluded)
    for excluded in ([1, 3, 7], [2, 5, 8], [4, 6, 9])
)

# Sanity check: which class indices remain in each split.
print("y_train_1 : ", check_cindex(y_train_1))
print("y_train_2 : ", check_cindex(y_train_2))
print("y_train_3 : ", check_cindex(y_train_3))




y_train_1 :  {0, 2, 4, 5, 6, 8, 9}
y_train_2 :  {0, 1, 3, 4, 6, 7, 9}
y_train_3 :  {0, 1, 2, 3, 5, 7, 8}


In [68]:
# Mirror the training splits on the test set: same excluded trios, same order.
(x_test_1, y_test_1), (x_test_2, y_test_2), (x_test_3, y_test_3) = (
    exclue_class(x_test, y_test_one_hot, excluded_cindex=excluded)
    for excluded in ([1, 3, 7], [2, 5, 8], [4, 6, 9])
)

# Sanity check: which class indices remain in each test split.
print("y_test_1 : ", check_cindex(y_test_1))
print("y_test_2 : ", check_cindex(y_test_2))
print("y_test_3 : ", check_cindex(y_test_3))


y_test_1 :  {0, 2, 4, 5, 6, 8, 9}
y_test_2 :  {0, 1, 3, 4, 6, 7, 9}
y_test_3 :  {0, 1, 2, 3, 5, 7, 8}


In [98]:

def load_model(learning_rate: float = 0.001):
    """Build and compile a small CNN classifier for 32x32 RGB images (CIFAR-10).

    Args:
        learning_rate: Step size for the Adam optimizer. The default 0.001 equals
            Keras' own Adam default, so ``load_model()`` behaves as before.

    Returns:
        A compiled ``tf.keras.Sequential`` model with a 10-way softmax output,
        trained with categorical cross-entropy (so labels must be one-hot) and
        reporting accuracy as its only metric.
    """
    model = tf.keras.Sequential(
        [
            layers.Input(shape=(32, 32, 3)),
            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(10, activation="softmax"),
        ]
    )
    model.compile(
        # Pass an optimizer instance so `learning_rate` actually takes effect;
        # the bare string "adam" would always use Keras' default rate.
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

### Federated Learning

In [111]:
from flwr.client import NumPyClient, ClientApp
from flwr.common import ndarrays_to_parameters, Context
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server.strategy import FedAvg
from flwr.server import ServerApp, ServerConfig
from flwr.server import ServerAppComponents

In [100]:
class FlowerClient(NumPyClient):
    """Flower NumPy client that trains and evaluates a Keras model on one local data split.

    ``data`` is unpacked as ``(x_train, y_train, x_test, y_test)``.
    """

    def __init__(self, model, data, epochs, batch_size, verbose):
        self.model = model
        # Local train/test arrays owned by this client.
        (
            self.x_train,
            self.y_train,
            self.x_test,
            self.y_test,
        ) = data
        # Training hyper-parameters forwarded verbatim to `model.fit`.
        self.epochs, self.batch_size, self.verbose = epochs, batch_size, verbose

    def fit(self, parameters, config):
        """Load the global weights, train locally, and return the updated weights.

        Returns:
            ``(weights, num_train_examples, {})``.
        """
        self.model.set_weights(parameters)
        fit_kwargs = dict(
            epochs=self.epochs,
            batch_size=self.batch_size,
            verbose=self.verbose,
        )
        self.model.fit(self.x_train, self.y_train, **fit_kwargs)
        updated_weights = self.model.get_weights()
        return updated_weights, len(self.x_train), {}

    def evaluate(self, parameters, config):
        """Load the given weights and score them on this client's local test split.

        Returns:
            ``(loss, num_test_examples, {"accuracy": accuracy})``.
        """
        self.model.set_weights(parameters)
        scores = self.model.evaluate(self.x_test, self.y_test, verbose=0)
        loss, accuracy = scores
        return loss, len(self.x_test), {"accuracy": accuracy}


### Model evaluation helper

In [101]:
def evaluate_model(model, cur_x_test, cur_y_test):
    """Score ``model`` on a test split silently (``verbose=0``).

    Args:
        model: Compiled Keras-style model whose ``evaluate`` yields exactly two
            values, loss then accuracy (true for models built by ``load_model``).
        cur_x_test: Test inputs.
        cur_y_test: Matching test labels.

    Returns:
        ``(loss, accuracy)`` tuple.
    """
    scores = model.evaluate(cur_x_test, cur_y_test, verbose=0)
    test_loss, test_accuracy = scores
    return (test_loss, test_accuracy)

In [102]:
# Client function
def client_fn1(context: Context) -> Client:
    net = load_model()
    client_data = (x_train_1,y_train_1,x_test_1,y_test_1)
    epochs = 3,
    batch_size = 32,
    verbose =1
    return FlowerClient(net, client_data, epochs,batch_size,verbose).to_client()

In [None]:
# Wrap the split-1 client factory in a ClientApp for the simulation engine.
# NOTE(review): run_simulation starts 3 supernodes from this one factory, so every
# client trains on split 1 — presumably each client should get its own split; confirm.
client = ClientApp(client_fn=client_fn1)

In [103]:
# Confirm the split-1 test images and label rows have matching lengths.
for label, arr in (("x_test_1", x_test_1), ("y_test_1 ", y_test_1)):
    print(label, arr.shape)

x_test_1 (7000, 32, 32, 3)
y_test_1  (7000, 10)


In [106]:
# Baseline: score an untrained model on test split 1, which *excludes*
# classes 1, 3 and 7 (expected to be near chance level).
net = load_model()
_, accuracy137 = evaluate_model(net, x_test_1, y_test_1)
print("test accuracy excluding [1,3,7]: ", accuracy137)

test accuracy on [1,3,7]:  0.08285713940858841


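### Define `evaluate` for centralized model testing
- `evaluate` is the server-side evaluation function passed to `FedAvg(evaluate_fn=...)`. It loads the given global parameters into a fresh model and reports its accuracy on the three class-filtered test sets (`x_test_1`, `x_test_2`, `x_test_3`).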

In [107]:
def evaluate(server_round, parameters, config):
    """Centralized (server-side) evaluation hook passed to ``FedAvg(evaluate_fn=...)``.

    Loads the aggregated weights into a fresh model. It scores them on the full
    CIFAR-10 test set and on each of the three class-filtered test splits
    (each split *excludes* the classes named in its label).

    Args:
        server_round: Current federated round number.
        parameters: Global model weights as a list of NumPy arrays.
        config: Evaluation config supplied by the strategy (unused).

    Returns:
        ``(loss, metrics)``: the loss on the full test set, and a dict of
        accuracies. Returning this pair (rather than ``None``) lets Flower record
        centralized metrics in the run history.
    """
    net = load_model()
    net.set_weights(parameters)

    loss, accuracy_all = evaluate_model(net, x_test, y_test_one_hot)
    _, accuracy137 = evaluate_model(net, x_test_1, y_test_1)
    _, accuracy258 = evaluate_model(net, x_test_2, y_test_2)
    _, accuracy469 = evaluate_model(net, x_test_3, y_test_3)

    print(f"[round {server_round}] test accuracy on all classes: ", accuracy_all)
    print("test accuracy excluding [1,3,7]: ", accuracy137)
    print("test accuracy excluding [2,5,8]: ", accuracy258)
    print("test accuracy excluding [4,6,9]: ", accuracy469)

    return loss, {
        "accuracy": accuracy_all,
        "accuracy_excl_137": accuracy137,
        "accuracy_excl_258": accuracy258,
        "accuracy_excl_469": accuracy469,
    }


In [112]:
# Seed FedAvg with the weights of a freshly initialised model.
net = load_model()
params = ndarrays_to_parameters(net.get_weights())


def server_fn(context: Context):
    """Assemble the server: a FedAvg strategy plus a 3-round schedule.

    All available clients are sampled for training (``fraction_fit=1.0``).
    Client-side evaluation is turned off (``fraction_evaluate=0.0``); the
    module-level ``evaluate`` function is used as the central ``evaluate_fn``.
    """
    return ServerAppComponents(
        strategy=FedAvg(
            fraction_fit=1.0,
            fraction_evaluate=0.0,
            initial_parameters=params,
            evaluate_fn=evaluate,
        ),
        config=ServerConfig(num_rounds=3),
    )

In [None]:
# ServerApp whose strategy and round config are produced by server_fn.
server = ServerApp(server_fn=server_fn)

In [None]:
from flwr.simulation import run_simulation
# Simulate the federation locally: 3 supernodes, each built from the `client` app.
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=3,
)