In [4]:
import sys
import os

# Make the project root (the parent of the kernel's working directory) importable,
# so the project-local helper modules imported below resolve. The membership
# check keeps repeated runs of this cell from adding duplicate entries.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path[:0] = [project_root]  # prepend, same as sys.path.insert(0, ...)

# Import commonly used modules
from data_utils import *
from model_utils import *
from config_utils import load_config

# Load the project's default federated-learning template configuration.
# The path is relative to the kernel's working directory.
DEFAULT_CONFIG_PATH = '../config/fl_template_config.yaml'
CONFIG = load_config(DEFAULT_CONFIG_PATH)

# __all__ only matters if this code is star-imported as a module (e.g. after
# exporting the notebook to a .py file); it has no effect inside the notebook itself.
__all__ = ['CONFIG']

## FedAvg

### Federated Averaging (FedAvg) Pseudocode

### Server Initialization:
Initialize global model weights W₀

### Main Federated Learning Loop:
For each round t = 0 to T-1:
    
    1. Select a subset of clients Sₜ (or use all available clients)
    
    2. Broadcast the current global model weights Wₜ to all clients in Sₜ

    3. For each client k in Sₜ (executed in parallel):
         - Perform a local update:
           Wₜᵏ = ClientUpdate(Wₜ, local_dataₖ)
         - Return updated local model weights Wₜᵏ along with the number of samples nₖ

    4. Aggregate the updated weights using weighted averaging (FedAvg):
         - Compute total samples: N = Σ_{k ∈ Sₜ} nₖ
         - Update global model weights:
           Wₜ₊₁ = Σ_{k ∈ Sₜ} (nₖ / N) * Wₜᵏ

Return final global model weights W_T

### ClientUpdate Function:
Function ClientUpdate(W, local_data):
    
    - Set W_local = W
    
    - For each local epoch e = 1 to E:
         - For each batch b in local_data:
              - Compute gradient: grad = ∇(loss(W_local, b))
              - Update local weights: W_local = W_local - learning_rate * grad
              
    Return W_local

#### Central Model Initialization

In [3]:
# Training histories collected during the experiment, keyed by stage name
# (the central model's initial training is stored under 'init').
history_dic = dict()

In [7]:
# Train the initial central (global) model on the 'init' data split, record its
# training history, and save it so the client models can be seeded from it.
#
# The experiment directory is built once, with double-quoted f-strings. Nesting
# CONFIG['...'] inside a single-quoted f-string reuses the enclosing quote type,
# which is a SyntaxError before Python 3.12 (PEP 701).
experiment_dir = f"../experiments/{CONFIG['experiment_name']}"

model = create_model()
try:
    trainx, trainy = load_training_data(f"{experiment_dir}/processed_data/init.npy")
    history = train_model(model, trainx, trainy)
    history_dic['init'] = history
    save_model(model, f"{experiment_dir}/models/central_model.keras")
finally:
    # Hand the model to clear_model_from_memory even if loading, training or
    # saving raised, so a failed run does not leave it lingering for later cells.
    clear_model_from_memory(model)

Epoch 1/10
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 424ms/step - accuracy: 0.6235 - auc_2: 0.5452 - loss: 1.1168 - precision_2: 0.6877 - recall_2: 0.5327
Epoch 2/10
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.6074 - auc_2: 0.8512 - loss: 0.5920 - precision_2: 0.5811 - recall_2: 0.9803
Epoch 3/10
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 0.9090 - auc_2: 0.9576 - loss: 0.3398 - precision_2: 0.8977 - recall_2: 0.9321
Epoch 4/10
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.9460 - auc_2: 0.9708 - loss: 0.1787 - precision_2: 0.9927 - recall_2: 0.8993
Epoch 5/10
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.9633 - auc_2: 0.9847 - loss: 0.1416 - precision_2: 1.0000 - recall_2: 0.9231
Epoch 6/10
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.8807 - auc_2: 0.9584 - loss: 0.3

#### Client Models Initialization

In [8]:
# Seed every client with an identical copy of the saved central model's weights
# and save each copy as client_model_<i>.keras (i is 1-based) in the experiment's
# models directory.
#
# Double-quoted f-strings avoid the same-quote nesting that needs Python 3.12+.
experiment_dir = f"../experiments/{CONFIG['experiment_name']}"
central_model = load_model_from_disk(f"{experiment_dir}/models/central_model.keras")

# The central weights are identical for every client, so fetch them once.
central_weights = central_model.get_weights()
num_clients = CONFIG.get('num_clients', 5)

# NOTE(review): the recorded run of this cell failed with a CUDA
# "context is destroyed" InternalError. If clear_model_from_memory tears down the
# GPU context, the next create_model()/load call will fail. TODO confirm what
# clear_model_from_memory does.
for i in range(num_clients):
    client_name = f'client_model_{i+1}'
    model = create_model()
    model.set_weights(central_weights)
    # Private Keras attribute; kept so the saved model carries the client's name.
    model._name = client_name
    save_model(model, f"{experiment_dir}/models/{client_name}.keras")
    clear_model_from_memory(model)

InternalError: Failed copying input tensor from /job:localhost/replica:0/task:0/device:CPU:0 to /job:localhost/replica:0/task:0/device:GPU:0 in order to run Cast: CUDA error: Error recording CUDA event: CUDA_ERROR_CONTEXT_IS_DESTROYED: context is destroyed [Op:Cast] name: 

## FedAvg Conditioned

### Modified FedAvg with Performance-based Weighting

### Server Initialization:
Initialize global model weights W₀

### Main Federated Learning Loop:

For each round t = 0 to T-1:

    1. Select a subset of clients Sₜ (or use all available clients)
    2. Broadcast the current global model weights Wₜ to all clients in Sₜ

    3. For each client k in Sₜ (executed in parallel):
         - Perform a local update:
           Wₜᵏ = ClientUpdate(Wₜ, local_dataₖ)
         - Evaluate the updated model on a common validation set:
           aₖ = Evaluate(Wₜᵏ, validation_set)  # e.g., accuracy
         - Return updated model Wₜᵏ, number of samples nₖ, and accuracy aₖ

    4. Aggregate the updated weights:
         - Compute the performance-weighted sum of samples:
           Total_weight = Σ_{k ∈ Sₜ} (nₖ × aₖ)
         - Update global model weights:
           Wₜ₊₁ = Σ_{k ∈ Sₜ} [(nₖ × aₖ) / Total_weight] × Wₜᵏ

Return final global model weights W_T

### ClientUpdate Function:

Function ClientUpdate(W, local_data):
    
    Set W_local = W
    For each local epoch e = 1 to E:
         For each batch b in local_data:
              - Compute gradient: grad = ∇(loss(W_local, b))
              - Update local weights: W_local = W_local - learning_rate * grad
    Return W_local


## Asynchronous Weight Updating Federated Learning

### Federated Learning with Partial Weight Sharing (Deep Layers Updated Frequently)

### Server Initialization:
Initialize global shallow weights W_shallow₀
Initialize global deep weights W_deep₀
Set shallow_update_interval K  # e.g., update shallow layers every K rounds, update deep layers every round

### Main Federated Learning Loop:

For each round t = 0 to T-1:

    1. Determine if this is a shallow update round:
         If (t mod K == 0):
             shallow_update = True
         Else:
             shallow_update = False

    2. Client Selection & Broadcast:
         Select a subset of clients Sₜ
         For each client in Sₜ, send:
             - Current deep weights: W_deepₜ  (always sent)
             - If shallow_update is True, also send current shallow weights: W_shallowₜ
             - Otherwise, clients use their locally stored shallow weights

    3. Clients' Local Update (executed in parallel):
         For each client k in Sₜ:
             - If shallow_update is True:
                  (W_shallowₜ^k, W_deepₜ^k) = ClientUpdate(W_shallowₜ, W_deepₜ, local_dataₖ, update_shallow=True)
             - Else:
                  (W_shallow_local, W_deepₜ^k) = ClientUpdate(W_shallow_local, W_deepₜ, local_dataₖ, update_shallow=False)
             - Evaluate the full updated model on a common validation set:
                  aₖ = Evaluate(FullModel(W_shallow^k, W_deepₜ^k), validation_set)  # client k's shallow weights used this round + its updated deep weights; aₖ e.g. accuracy
             - Return to server:
                  - For shallow layers: if update_shallow is True, return updated W_shallowₜ^k; otherwise, return nothing (shallow layers are not aggregated this round)
                  - Updated deep weights: W_deepₜ^k
                  - Local sample count nₖ and performance metric aₖ

    4. Server Aggregation:
         # Always aggregate deep layers:
         Compute Total_weight_deep = Σ_{k ∈ Sₜ} (nₖ × aₖ)
         Update global deep weights:
             W_deepₜ₊₁ = Σ_{k ∈ Sₜ} [ (nₖ × aₖ) / Total_weight_deep ] × W_deepₜ^k

         # Aggregate shallow layers only on shallow update rounds:
         If shallow_update is True:
             Compute Total_weight_shallow = Σ_{k ∈ Sₜ} (nₖ × aₖ)
             Update global shallow weights:
                 W_shallowₜ₊₁ = Σ_{k ∈ Sₜ} [ (nₖ × aₖ) / Total_weight_shallow ] × W_shallowₜ^k
         Else:
             W_shallowₜ₊₁ = W_shallowₜ  # Keep shallow layers unchanged

    Return final global model: {W_shallow_T, W_deep_T}


### ClientUpdate Function:

Function ClientUpdate(shallow_weights, deep_weights, local_data, update_shallow):

    If update_shallow is True:
         Set local_shallow = shallow_weights    # Received from server
    Else:
         Set local_shallow = shallow_weights    # Caller passes its locally stored shallow weights (W_shallow_local)

    Set local_deep = deep_weights              # Always use the latest deep weights from server

    For each local epoch e = 1 to E:
         For each batch b in local_data:
              If update_shallow is True:
                  - Compute gradients for both layers:
                        grad_shallow, grad_deep = ∇(loss(FullModel(local_shallow, local_deep), b))
                  - Update shallow layers:
                        local_shallow = local_shallow - learning_rate * grad_shallow
              Else:
                  - Compute gradient only for deep layers (shallow remains fixed):
                        grad_deep = ∇(loss(FullModel(local_shallow, local_deep), b))
              - Update deep layers:
                    local_deep = local_deep - learning_rate * grad_deep

    If update_shallow is True:
         Return (local_shallow, local_deep)
    Else:
         Return (local_shallow, local_deep)  # Note: shallow remains unchanged from before the round
