Simple Development of MAB-FL in a Random Manner

In [None]:
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import random


In [None]:
# Initial settings: the number of clients (devices) and the number of
# rounds/iterations over which the global aggregation is performed.
num_devices = 10
num_rounds = 100
model_weights = []  # Pool of local model objects; the main loop extends it every round
global_models = []  # Aggregated global model appended after each round

# Simulate local model training on devices
def train_local_models(num_devices):
    """Build one randomly parameterised linear model per device.

    No fitting takes place: every model is a fresh ``LinearRegression`` whose
    ``coef_`` and ``intercept_`` are set directly to arrays of shape ``(1,)``
    drawn from ``np.random.rand``.

    Args:
        num_devices: Number of models to create.

    Returns:
        A list of ``num_devices`` ``LinearRegression`` instances (empty when
        ``num_devices`` is zero or negative).
    """

    def _random_model():
        estimator = LinearRegression()
        # The coefficient is drawn before the intercept so that a seeded
        # NumPy RNG hands out values in the same sequence as before.
        estimator.coef_ = np.random.rand(1)
        estimator.intercept_ = np.random.rand(1)
        return estimator

    return [_random_model() for _ in range(num_devices)]

# Simulate model aggregation to get global model
def aggregate_models(models):
    """Average a collection of linear models into one global model.

    The coefficients and intercepts are averaged element-wise across the
    models, so the aggregated parameters keep the same shape as each
    individual model's parameters.

    Args:
        models: Non-empty iterable of objects exposing ``coef_`` and
            ``intercept_``; all models must share the same parameter shapes.

    Returns:
        A new ``LinearRegression`` whose ``coef_`` and ``intercept_`` are the
        element-wise means of the inputs.

    Raises:
        ValueError: If ``models`` is empty.
    """
    # Materialise once: the collection is traversed twice below.
    models = list(models)
    if not models:
        raise ValueError("aggregate_models requires at least one model")

    # axis=0 averages across models while preserving each parameter's shape.
    # Without it np.mean collapses everything into a single 0-d scalar, which
    # blends the coefficients of different features together and leaves
    # coef_ in a form that predict() cannot matrix-multiply with X.
    avg_coef = np.mean([model.coef_ for model in models], axis=0)
    avg_intercept = np.mean([model.intercept_ for model in models], axis=0)

    global_model = LinearRegression()
    global_model.coef_ = avg_coef
    global_model.intercept_ = avg_intercept
    return global_model

# Multi-Armed Bandit (MAB) for model selection
def mab_select_model(models, epsilon=0.1, X_val=None, y_val=None):
    """Pick one model from ``models`` with an epsilon-greedy rule.

    With probability ``epsilon`` a model is drawn uniformly at random
    (explore). Otherwise the model with the lowest mean squared error on the
    validation sample is returned (exploit).

    Args:
        models: Non-empty sequence of objects exposing ``predict(X)``.
        epsilon: Probability of taking the random (explore) branch.
        X_val: Validation inputs used to score the models when exploiting.
            Defaults to the single example point ``[[0]]``.
        y_val: Validation targets matching ``X_val``. Defaults to ``[0]``.

    Returns:
        One element of ``models``.

    Raises:
        ValueError: If ``models`` is empty, or if only one of ``X_val`` and
            ``y_val`` is supplied.
    """
    # Fail the same way on both branches; previously an empty pool raised
    # IndexError when exploring and ValueError when exploiting.
    if len(models) == 0:
        raise ValueError("mab_select_model requires at least one model")
    if (X_val is None) != (y_val is None):
        raise ValueError("X_val and y_val must be provided together")
    if X_val is None:
        # Example evaluation sample used when the caller supplies none.
        X_val, y_val = [[0]], [0]

    if random.random() < epsilon:  # Explore
        return random.choice(models)

    # Exploit: score every candidate on the validation sample.
    losses = [mean_squared_error(y_val, model.predict(X_val)) for model in models]
    return models[int(np.argmin(losses))]


In [None]:

# Main simulation loop
for _ in range(num_rounds):
    # Each round creates a fresh set of randomly initialised local models.
    local_models = train_local_models(num_devices)
    # NOTE(review): the pool is extended and never reset, so it grows by
    # num_devices every round and each aggregation averages over the models
    # of all rounds so far -- confirm this accumulation is intended.
    model_weights.extend(local_models)
    # NOTE(review): selected_model is assigned here but not read anywhere in
    # the visible code, so the selection does not influence the aggregation.
    selected_model = mab_select_model(model_weights)
    global_model = aggregate_models(model_weights)
    global_models.append(global_model)

# Evaluate the final global model
# NOTE(review): with a single test point at x = 0 a linear model's prediction
# should reduce to its intercept, so this loss presumably reflects only the
# aggregated intercept -- consider a more representative test set.
X_test = np.array([[0]])  # Example test data
y_test = np.array([0])   # Example test label
# Takes the model from the last round; raises IndexError if no round ran.
final_global_model = global_models[-1]
y_pred = final_global_model.predict(X_test)
loss = mean_squared_error(y_test, y_pred)

print("Final Global Model Loss:", loss)
