In [None]:
import flwr as fl
from typing import Dict, Optional, Tuple, List
from flwr.common import Metrics, Parameters, Scalar, FitRes, EvaluateRes
import logging
from datetime import datetime
import time

In [None]:
# Configure logging: INFO level with timestamped records, plus the dedicated
# "Server" logger that log_status() writes to.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("Server")

def log_status(status: str, details: str = "") -> None:
    """Emit a status banner to both the "Server" logger and stdout.

    The banner is a header line containing ``status`` framed by ``=``
    characters, an optional ``details`` body, and a closing rule of ``=``
    that is exactly as wide as the header line.

    Args:
        status: Short label shown in the header line.
        details: Optional free-form text shown between header and closing
            rule; skipped entirely when empty.
    """
    banner = f"{'='*20} {status} {'='*20}"
    # The leading newline only produces a blank line above the banner; it must
    # not count towards the width of the closing rule.
    header = f"\n{banner}"
    footer = "=" * len(banner)

    logger.info(header)
    if details:
        logger.info(details)
    logger.info(footer)

    # Mirror the same banner on stdout so it is visible even when the logging
    # output is redirected or filtered.
    print(header)
    if details:
        print(details)
    print(footer)

class LoggingStrategy(fl.server.strategy.FedAvg):
    """FedAvg strategy that announces every phase of a round via ``log_status``.

    Client sampling and aggregation are delegated unchanged to the ``FedAvg``
    base class; this subclass only adds status banners and keeps a small
    per-round record in ``round_metrics``.
    """

    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        timeout: Optional[float] = None,
    ) -> None:
        """Forward the sampling settings to ``FedAvg`` and set up bookkeeping.

        Args:
            fraction_fit: Passed through to ``FedAvg.__init__``.
            fraction_evaluate: Passed through to ``FedAvg.__init__``.
            min_fit_clients: Passed through to ``FedAvg.__init__``.
            min_evaluate_clients: Passed through to ``FedAvg.__init__``.
            min_available_clients: Passed through to ``FedAvg.__init__`` and
                also used as the client count the status banners report
                against.
            timeout: Stored as ``self.timeout``. Nothing in this class reads
                it; it is not forwarded to ``FedAvg``.
        """
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
        )
        # One dict per completed fit round (see aggregate_fit).
        self.round_metrics = []
        self.expected_clients = min_available_clients
        self.timeout = timeout
        self.current_round = 0
        # Client ids seen in any configure_fit call so far.
        self.connected_clients = set()

    def _clients_still_needed(self, connected: int) -> int:
        """Return how many more clients are needed, never less than zero."""
        return max(0, self.expected_clients - connected)

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager
    ) -> List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitIns]]:
        """Report connection status, then let ``FedAvg`` configure the round.

        Args:
            server_round: Round number supplied by the server.
            parameters: Current global parameters, passed through untouched.
            client_manager: Manager queried for the available clients.

        Returns:
            Whatever ``FedAvg.configure_fit`` returns for the same arguments.
        """
        self.current_round = server_round
        connected = client_manager.num_available()
        still_needed = self._clients_still_needed(connected)

        # Anything the manager knows about that we have not seen before is
        # reported as newly connected.
        current_clients = set(client_manager.all())
        new_clients = current_clients - self.connected_clients
        self.connected_clients.update(new_clients)

        if new_clients:
            # still_needed is clamped so that having more clients than
            # expected never prints a negative "waiting for" count.
            log_status("NEW CLIENT CONNECTED",
                      f"Round {server_round}\n"
                      f"Total connected: {connected}/{self.expected_clients}\n"
                      f"New client(s): {len(new_clients)}\n"
                      f"Waiting for {still_needed} more")

        if connected < self.expected_clients:
            log_status("WAITING",
                      f"Round {server_round}: Waiting for more clients\n"
                      f"Currently connected: {connected}/{self.expected_clients}\n"
                      f"You can start {still_needed} more client(s)")
            # Sleep to prevent too frequent logging.
            # NOTE(review): this method appears to be invoked once per round,
            # so the pause mostly delays the round start - confirm before
            # removing it.
            time.sleep(2)
        else:
            log_status("STARTING ROUND",
                      f"Round {server_round}: All {self.expected_clients} clients connected\n"
                      f"Beginning training round")

        return super().configure_fit(server_round, parameters, client_manager)

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, FitRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Log and record the round outcome, then aggregate via ``FedAvg``.

        Appends a dict with the keys ``round``, ``timestamp``,
        ``num_clients`` and ``num_failures`` to ``self.round_metrics`` before
        delegating.

        Args:
            server_round: Round number supplied by the server.
            results: Successful client fit results.
            failures: Failures collected for this round.

        Returns:
            Whatever ``FedAvg.aggregate_fit`` returns for the same arguments.
        """
        log_status("ROUND COMPLETE",
                  f"Round {server_round} completed\n"
                  f"Successful clients: {len(results)}\n"
                  f"Failed clients: {len(failures)}")

        # Store metrics
        self.round_metrics.append({
            "round": server_round,
            "timestamp": datetime.now().isoformat(),
            "num_clients": len(results),
            "num_failures": len(failures),
        })

        return super().aggregate_fit(server_round, results, failures)

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager
    ) -> List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.EvaluateIns]]:
        """Announce the evaluation phase, then delegate to ``FedAvg``.

        Returns:
            Whatever ``FedAvg.configure_evaluate`` returns for the same
            arguments.
        """
        log_status("EVALUATION", f"Round {server_round}: Starting evaluation")
        return super().configure_evaluate(server_round, parameters, client_manager)

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[fl.server.client_proxy.ClientProxy, EvaluateRes]],
        failures: List[BaseException],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Log the evaluation outcome, then aggregate via ``FedAvg``.

        Args:
            server_round: Round number supplied by the server.
            results: Successful client evaluation results.
            failures: Failures collected for this evaluation.

        Returns:
            Whatever ``FedAvg.aggregate_evaluate`` returns for the same
            arguments.
        """
        log_status("EVALUATION COMPLETE",
                  f"Round {server_round}\n"
                  f"Successful evaluations: {len(results)}\n"
                  f"Failed evaluations: {len(failures)}")
        return super().aggregate_evaluate(server_round, results, failures)

# Maximum gRPC message size in bytes (1 GiB).
GRPC_MAX_MESSAGE_LENGTH = 1024 * 1024 * 1024

# Single source of truth for values that are both configured and reported, so
# the startup banner cannot drift from what the server actually uses.
SERVER_ADDRESS = "127.0.0.1:8081"
EXPECTED_CLIENTS = 2

# timeout=None is only stored on the strategy (LoggingStrategy does not read
# it); the effective round timeout is the ServerConfig.round_timeout below.
strategy = LoggingStrategy(
    min_fit_clients=EXPECTED_CLIENTS,
    min_available_clients=EXPECTED_CLIENTS,
    min_evaluate_clients=EXPECTED_CLIENTS,
    timeout=None,
)

log_status("SERVER STARTING",
          "Initializing Flower server\n"
          f"Address: {SERVER_ADDRESS}\n"
          f"Expected clients: {EXPECTED_CLIENTS}\n"
          f"Message size: {GRPC_MAX_MESSAGE_LENGTH}\n"
          "You can now start connecting clients one by one")

# Start server. Only the blocking server call lives in the try body so that a
# failure while printing the summary is not misreported as a server error.
try:
    fl.server.start_server(
        server_address=SERVER_ADDRESS,
        config=fl.server.ServerConfig(
            num_rounds=1,
            round_timeout=None,  # No timeout for rounds
        ),
        grpc_max_message_length=GRPC_MAX_MESSAGE_LENGTH,
        strategy=strategy,
    )
except Exception as e:
    # Keep the full traceback in the log; the banner alone only carries the
    # exception message.
    logger.exception("Flower server stopped with an error")
    log_status("SERVER ERROR", str(e))
else:
    log_status("SERVER COMPLETED",
              "Training completed successfully\n"
              f"Total rounds: {len(strategy.round_metrics)}")

    # Print final metrics
    print("\nTraining Round Metrics:")
    for round_data in strategy.round_metrics:
        print(f"\nRound {round_data['round']}:")
        print(f"Timestamp: {round_data['timestamp']}")
        print(f"Active clients: {round_data['num_clients']}")
        print(f"Failed clients: {round_data['num_failures']}")
finally:
    log_status("SERVER SHUTDOWN", "Flower server has shut down")