In [3]:
import warnings
from datetime import datetime
import pickle
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score, mean_squared_error
from xgboost import XGBClassifier, XGBRegressor
from imblearn.over_sampling import SMOTE
import flwr as fl
from flwr.common import FitIns, EvaluateIns, Parameters, NDArrays
import logging
from typing import Dict, Optional, Tuple, List

# Suppress warnings
warnings.filterwarnings('ignore')

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

2025-04-15 12:59:00.148576: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744721940.359098      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744721940.416659      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
# Load csv
df = pd.read_csv(r'./transactions_data_extended.csv')
df.head()

Unnamed: 0,date,amount,category,description,clean_description,is_fraud
0,2025-02-02 15:31:58,0.44,Food,Issue law fear fine economic smile.,Allow always this.,0
1,2023-08-17 11:26:05,66.76,Food,Trade type training hand effect.,Finally focus own.,0
2,2024-07-21 13:51:04,99.78,Shopping,Glass strategy woman whether bank weight.,Mind mission.,0
3,2024-09-08 07:12:18,77.1,Food,Will structure sport growth.,Many analysis begin.,0
4,2020-04-19 12:21:34,70.28,Transport,Opportunity ago wish partner behind.,Some beautiful read.,0


In [5]:
# Create synthetic data
def create_synthetic_data(num_samples: int = 1000) -> pd.DataFrame:
    try:
        np.random.seed(42)
        dates = pd.date_range(start='2023-01-01', periods=num_samples)
        amounts = np.random.exponential(scale=100, size=num_samples)
        categories = np.random.choice(
            ['grocery', 'utilities', 'entertainment', 'travel', 'healthcare', 'other'],
            size=num_samples
        )
        descriptions = [f"Transaction {i}" for i in range(num_samples)]
        clean_descriptions = [f"Clean {desc}" for desc in descriptions]
        is_fraud = np.random.choice([0, 1], size=num_samples, p=[0.95, 0.05])
        df = pd.DataFrame({
            'date': dates,
            'amount': amounts,
            'category': categories,
            'description': descriptions,
            'clean_description': clean_descriptions,
            'is_fraud': is_fraud
        })
        logger.info("Synthetic data created successfully")
        return df
    except Exception as e:
        logger.error(f"Synthetic data creation failed: {e}")
        raise

In [6]:
# Preprocess the dataset
def preprocess_data(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, LabelEncoder]]:
    try:
        def parse_date(date_str):
            try:
                for fmt in ["%Y-%m-%d %H:%M:%S", "%d-%m-%Y", "%Y-%m-%d"]:
                    try:
                        return datetime.strptime(str(date_str).split()[0], fmt)
                    except ValueError:
                        continue
                return pd.NaT
            except Exception as e:
                logger.error(f"Date parsing error: {e}")
                return pd.NaT

        df['date'] = df['date'].apply(parse_date)
        df = df.dropna(subset=['date'])

        df['day'] = df['date'].dt.day
        df['month'] = df['date'].dt.month
        df['year'] = df['date'].dt.year
        df['day_of_week'] = df['date'].dt.dayofweek
        df['quarter'] = df['date'].dt.quarter
        df['week_of_year'] = df['date'].dt.isocalendar().week
        df['is_weekend'] = df['day_of_week'].apply(lambda x: 1 if x >= 5 else 0)
        df['is_month_start'] = df['day'].apply(lambda x: 1 if x <= 5 else 0)
        df['is_month_end'] = df['day'].apply(lambda x: 1 if x >= 25 else 0)

        if 'amount' in df.columns:
            df['amount_log'] = np.log1p(df['amount'].abs())
            df['is_large_amount'] = df['amount'].apply(lambda x: 1 if x > df['amount'].quantile(0.75) else 0)
            df['is_very_large_amount'] = df['amount'].apply(lambda x: 1 if x > df['amount'].quantile(0.95) else 0)

        label_encoders = {}
        for col in ['description', 'category', 'clean_description']:
            if col in df.columns:
                le = LabelEncoder()
                df[col] = le.fit_transform(df[col].astype(str))
                label_encoders[col] = le

        numerical_features = ['amount', 'amount_log', 'day', 'month', 'year', 'day_of_week', 'quarter', 'week_of_year']
        numerical_features = [f for f in numerical_features if f in df.columns]
        if numerical_features:
            scaler = StandardScaler()
            df[numerical_features] = scaler.fit_transform(df[numerical_features])

        logger.info(f"Dataset shape: {df.shape}")
        if 'category' in df.columns:
            logger.info(f"Number of categories: {df['category'].nunique()}")

        return df, label_encoders
    except Exception as e:
        logger.error(f"Preprocessing failed: {e}")
        raise

In [7]:
# Prepare datasets for each model
def prepare_model_datasets(df: pd.DataFrame) -> Dict[str, Tuple[pd.DataFrame, pd.Series]]:
    try:
        base_features = ['day', 'month', 'year', 'day_of_week', 'is_weekend', 'is_month_start', 
                         'is_month_end', 'quarter', 'week_of_year']
        amount_features = ['amount', 'amount_log', 'is_large_amount']

        fraud_df = df.copy()
        fraud_df['is_fraud'] = fraud_df.get('is_fraud', fraud_df['is_very_large_amount'])
        fraud_features = amount_features + ['is_very_large_amount'] + base_features
        if 'category' in df.columns:
            fraud_features.append('category')
        if 'clean_description' in df.columns:
            fraud_features.append('clean_description')
        fraud_features = [f for f in fraud_features if f in df.columns]
        X_fraud = fraud_df[fraud_features]
        y_fraud = fraud_df['is_fraud'].astype(int)

        budget_df = df[df['is_fraud'] == 0] if 'is_fraud' in df.columns else df.copy()
        budget_features = amount_features + base_features
        budget_features = [f for f in budget_features if f in budget_df.columns]
        X_budget = budget_df[budget_features]
        y_budget = budget_df['category'].astype(int)

        expense_df = budget_df.copy()
        expense_features = base_features
        if 'category' in expense_df.columns:
            expense_features.append('category')
        expense_features = [f for f in expense_features if f in df.columns]
        X_expense = expense_df[expense_features]
        y_expense = expense_df['amount_log'] if 'amount_log' in expense_df.columns else expense_df['amount']

        forecast_df = df.sort_values('date').copy()
        if 'amount' in forecast_df.columns:
            for window in [3, 7, 14]:
                forecast_df[f'rolling_avg_{window}'] = forecast_df['amount'].rolling(window=window).mean().fillna(0)
                forecast_df[f'rolling_std_{window}'] = forecast_df['amount'].rolling(window=window).std().fillna(0)
            forecast_df['amount_diff_1'] = forecast_df['amount'].diff(1).fillna(0)
        forecast_features = base_features + [col for col in forecast_df.columns if 'rolling_' in col or 'amount_diff' in col]
        forecast_features = [f for f in forecast_features if f in df.columns]
        X_forecast = forecast_df[forecast_features]
        y_forecast = forecast_df['amount_log'] if 'amount_log' in forecast_df.columns else forecast_df['amount']

        return {
            'fraud': (X_fraud, y_fraud),
            'budget': (X_budget, y_budget),
            'expense': (X_expense, y_expense),
            'forecast': (X_forecast, y_forecast)
        }
    except Exception as e:
        logger.error(f"Dataset preparation failed: {e}")
        raise

In [8]:
# Custom server strategy
class BestModelStrategy(fl.server.strategy.Strategy):
    def __init__(self, min_fit_clients: int, min_evaluate_clients: int, min_available_clients: int):
        super().__init__()
        self.min_fit_clients = min_fit_clients
        self.min_evaluate_clients = min_evaluate_clients
        self.min_available_clients = min_available_clients
        self.current_parameters = None
        self.best_score = float("-inf")
        self.best_parameters = None

    def initialize_parameters(self, client_manager) -> Optional[Parameters]:
        logger.info("Initializing parameters")
        return self.current_parameters

    def configure_fit(self, server_round: int, parameters: Parameters, client_manager) -> List[Tuple[fl.server.client_proxy.ClientProxy, FitIns]]:
        try:
            sample_size = max(self.min_fit_clients, int(client_manager.num_available() * 1.0))
            clients = client_manager.sample(num_clients=sample_size, min_num_clients=self.min_fit_clients)
            config = {}
            fit_ins = FitIns(parameters or [], config)
            logger.info(f"Round {server_round}: Configured fit for {len(clients)} clients")
            return [(client, fit_ins) for client in clients]
        except Exception as e:
            logger.error(f"Configure fit failed: {e}")
            return []

    def aggregate_fit(self, server_round: int, results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.FitRes]], failures: List) -> Tuple[Optional[Parameters], Dict]:
        try:
            if not results or failures:
                logger.warning(f"Round {server_round}: No results or failures detected: {failures}")
                return None, {}
            classifier_results = [(r, r.num_examples) for _, r in results if r.metrics.get("model_type") in ["fraud", "budget"] and r.num_examples > 0]
            regressor_results = [(r, r.num_examples) for _, r in results if r.metrics.get("model_type") in ["expense", "forecast"] and r.num_examples > 0]
            best_round_score = float("-inf")
            best_round_parameters = None
            for _, fit_res in classifier_results:
                if "train_score" in fit_res.metrics:
                    score = fit_res.metrics["train_score"]
                    if score > best_round_score:
                        best_round_score = score
                        best_round_parameters = fit_res.parameters
                        logger.info(f"Round {server_round}: Classifier score={score}, model_type={fit_res.metrics.get('model_type')}")
            for _, fit_res in regressor_results:
                if "train_score" in fit_res.metrics:
                    mse = -fit_res.metrics["train_score"]
                    score = 1.0 / (1.0 + mse) if mse >= 0 else 0.0
                    if score > best_round_score:
                        best_round_score = score
                        best_round_parameters = fit_res.parameters
                        logger.info(f"Round {server_round}: Regressor score={score}, raw_mse={mse}, model_type={fit_res.metrics.get('model_type')}")
            if best_round_score > self.best_score:
                self.best_score = best_round_score
                self.best_parameters = best_round_parameters
            self.current_parameters = best_round_parameters or self.current_parameters
            metrics = {"best_score": best_round_score if best_round_score != float("-inf") else 0.0}
            logger.info(f"Round {server_round}: Aggregated fit, best score={best_round_score}")
            return self.current_parameters, metrics
        except Exception as e:
            logger.error(f"Round {server_round}: Aggregate fit failed: {e}")
            return None, {}

    def configure_evaluate(self, server_round: int, parameters: Parameters, client_manager) -> List[Tuple[fl.server.client_proxy.ClientProxy, EvaluateIns]]:
        try:
            sample_size = max(self.min_evaluate_clients, int(client_manager.num_available() * 1.0))
            clients = client_manager.sample(num_clients=sample_size, min_num_clients=self.min_evaluate_clients)
            config = {}
            evaluate_ins = EvaluateIns(parameters or [], config)
            logger.info(f"Round {server_round}: Configured evaluate for {len(clients)} clients")
            return [(client, evaluate_ins) for client in clients]
        except Exception as e:
            logger.error(f"Configure evaluate failed: {e}")
            return []

    def aggregate_evaluate(self, server_round: int, results: List[Tuple[fl.server.client_proxy.ClientProxy, fl.common.EvaluateRes]], failures: List) -> Tuple[Optional[float], Dict]:
        try:
            if not results or failures:
                logger.warning(f"Round {server_round}: No evaluation results or failures detected: {failures}")
                return None, {}
            num_examples_total = sum([evaluate_res.num_examples for _, evaluate_res in results])
            if num_examples_total == 0:
                logger.warning(f"Round {server_round}: No evaluation examples available")
                return None, {}
            classifier_results = [(r, r.num_examples) for _, r in results if r.metrics.get("model_type") in ["fraud", "budget"] and r.num_examples > 0]
            regressor_results = [(r, r.num_examples) for _, r in results if r.metrics.get("model_type") in ["expense", "forecast"] and r.num_examples > 0]
            metrics_aggregated = {}
            loss_aggregated = 0.0
            total_examples = 0
            for _, res in results:
                logger.info(f"Round {server_round}: Client result - model_type={res.metrics.get('model_type')}, loss={res.loss}, metrics={res.metrics}, examples={res.num_examples}")
            if classifier_results:
                classifier_examples = sum([n for _, n in classifier_results])
                if classifier_examples > 0:
                    classifier_loss = sum([r.loss * n for r, n in classifier_results]) / classifier_examples
                    classifier_accuracy = sum([r.metrics.get("accuracy", 0) * n for r, n in classifier_results]) / classifier_examples
                    metrics_aggregated["accuracy"] = classifier_accuracy
                    loss_aggregated += classifier_loss * classifier_examples
                    total_examples += classifier_examples
                    logger.info(f"Round {server_round}: Classifier loss={classifier_loss}, accuracy={classifier_accuracy}, examples={classifier_examples}")
                else:
                    logger.warning(f"Round {server_round}: No valid classifier examples")
            if regressor_results:
                regressor_examples = sum([n for _, n in regressor_results])
                if regressor_examples > 0:
                    regressor_loss = sum([r.loss * n for r, n in regressor_results]) / regressor_examples
                    regressor_mse = sum([r.metrics.get("mse", 0) * n for r, n in regressor_results]) / regressor_examples
                    metrics_aggregated["mse"] = regressor_mse
                    loss_aggregated += regressor_loss * regressor_examples
                    total_examples += regressor_examples
                    logger.info(f"Round {server_round}: Regressor loss={regressor_loss}, mse={regressor_mse}, examples={regressor_examples}")
                else:
                    logger.warning(f"Round {server_round}: No valid regressor examples")
            if total_examples > 0:
                loss_aggregated /= total_examples
            else:
                logger.warning(f"Round {server_round}: No valid examples for loss aggregation")
                return None, metrics_aggregated
            logger.info(f"Round {server_round}: Aggregated evaluation, loss={loss_aggregated}, metrics={metrics_aggregated}")
            return loss_aggregated, metrics_aggregated
        except Exception as e:
            logger.error(f"Round {server_round}: Aggregate evaluate failed: {e}")
            return None, {}

    def evaluate(self, server_round: int, parameters: Parameters) -> Optional[Tuple[float, Dict]]:
        try:
            logger.info(f"Round {server_round}: Performing server-side evaluation")
            return None
        except Exception as e:
            logger.error(f"Round {server_round}: Server-side evaluation failed: {e}")
            return None

In [None]:
# Flower client class
class FinancialModelClient(fl.client.NumPyClient):
    def __init__(self, client_id: int, model_type: str = 'fraud'):
        self.client_id = client_id
        self.model_type = model_type
        try:
            df = create_synthetic_data(num_samples=2000)  
            df, _ = preprocess_data(df)
            logger.info(f"Client {client_id}: Data loaded")
        except Exception as e:
            logger.error(f"Client {client_id}: Data loading failed: {e}")
            raise
        datasets = prepare_model_datasets(df)
        if model_type not in datasets:
            logger.error(f"Client {client_id}: Invalid model type {model_type}, defaulting to 'fraud'")
            self.model_type = 'fraud'
            model_type = 'fraud'
        X, y = datasets[model_type]
        self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(
            X, y, test_size=0.2, random_state=42 + client_id
        )
        if model_type == 'fraud':
            unique_labels = len(np.unique(self.y_train))
            pos_samples = sum(self.y_train == 1)
            if unique_labels > 1 and pos_samples >= 3:
                try:
                    smote = SMOTE(random_state=42 + client_id, k_neighbors=min(3, pos_samples - 1))
                    self.X_train, self.y_train = smote.fit_resample(self.X_train, self.y_train)
                    logger.info(f"Client {client_id}: SMOTE applied, {pos_samples} positive samples")
                except ValueError as e:
                    logger.warning(f"Client {client_id}: SMOTE failed, using original data: {e}")
            else:
                logger.warning(f"Client {client_id}: Insufficient fraud samples ({pos_samples}), skipping SMOTE")
        if model_type in ['fraud', 'budget']:
            unique_labels = len(np.unique(self.y_train))
            if unique_labels < 2:
                logger.warning(f"Client {client_id}: Single-class training data, adjusting labels")
                self.y_train = np.ones(len(self.y_train))  
            neg_count = sum(self.y_train == 0)
            pos_count = sum(self.y_train != 0)
            scale_pos_weight = neg_count / pos_count if pos_count > 0 else 1
            self.model = XGBClassifier(
                use_label_encoder=False,
                eval_metric='logloss',
                random_state=42 + client_id,
                max_depth=4,
                learning_rate=0.1,  
                n_estimators=200,   
                scale_pos_weight=scale_pos_weight
            )
            self.is_classifier = True
        else:
            self.model = XGBRegressor(
                random_state=42 + client_id,
                max_depth=4,
                learning_rate=0.1,
                n_estimators=200
            )
            self.is_classifier = False
        self.train_local_model()

    def train_local_model(self):
        try:
            self.model.fit(self.X_train, self.y_train)
            logger.info(f"Client {self.client_id}: Local model trained")
        except Exception as e:
            logger.error(f"Client {self.client_id}: Training failed: {e}")
            raise

    def get_parameters(self, config) -> NDArrays:
        try:
            model_bytes = pickle.dumps(self.model)
            return [np.frombuffer(model_bytes, dtype=np.uint8)]
        except Exception as e:
            logger.error(f"Client {self.client_id}: Parameter serialization failed: {e}")
            return []

    def set_parameters(self, parameters: NDArrays):
        if not parameters or not parameters[0].size:
            logger.warning(f"Client {self.client_id}: No valid parameters provided")
            return
        try:
            model_bytes = parameters[0].tobytes()
            self.model = pickle.loads(model_bytes)
            logger.info(f"Client {self.client_id}: Parameters updated")
        except Exception as e:
            logger.warning(f"Client {self.client_id}: Parameter deserialization failed: {e}")

    def fit(self, parameters: NDArrays, config: Dict) -> Tuple[NDArrays, int, Dict]:
        try:
            self.set_parameters(parameters)
            self.train_local_model()
            updated_parameters = self.get_parameters(config)
            if self.is_classifier:
                y_pred = self.model.predict(self.X_train)
                train_score = float(accuracy_score(self.y_train, y_pred))
            else:
                y_pred = self.model.predict(self.X_train)
                train_score = -float(mean_squared_error(self.y_train, y_pred))
            logger.info(f"Client {self.client_id}: Fit completed with score {train_score}")
            return updated_parameters, len(self.X_train), {"train_score": train_score, "model_type": self.model_type}
        except Exception as e:
            logger.error(f"Client {self.client_id}: Fit failed: {e}")
            return [], 0, {"model_type": self.model_type}

    def evaluate(self, parameters: NDArrays, config: Dict) -> Tuple[float, int, Dict]:
        try:
            self.set_parameters(parameters)
            if self.is_classifier:
                if len(np.unique(self.y_val)) < 2:
                    logger.warning(f"Client {self.client_id} ({self.model_type}): Single-class validation data")
                    return 1.0, len(self.X_val), {"accuracy": 0.0, "model_type": self.model_type}
                y_pred = self.model.predict(self.X_val)
                loss = 1.0 - accuracy_score(self.y_val, y_pred)
                accuracy = float(accuracy_score(self.y_val, y_pred))
                metrics = {"accuracy": accuracy, "model_type": self.model_type}
            else:
                y_pred = self.model.predict(self.X_val)
                loss = float(mean_squared_error(self.y_val, y_pred))
                metrics = {"mse": loss, "model_type": self.model_type}
            logger.info(f"Client {self.client_id} ({self.model_type}): Evaluation - loss={loss}, metrics={metrics}")
            return loss, len(self.X_val), metrics
        except Exception as e:
            logger.error(f"Client {self.client_id} ({self.model_type}): Evaluation failed: {e}")
            return float("inf"), 0, {"model_type": self.model_type}

In [10]:
# Define client_fn
def client_fn(cid: str) -> fl.client.Client:
    try:
        client_id = int(cid)
        model_types = ['fraud', 'budget', 'expense', 'forecast']
        model_type = model_types[client_id % len(model_types)]
        logger.info(f"Creating client {client_id} with model type {model_type}")
        client = FinancialModelClient(client_id, model_type).to_client()
        return client
    except Exception as e:
        logger.error(f"Client creation failed for CID {cid}: {e}")
        raise

In [11]:
# Main simulation
def main_fl():
    try:
        logger.info("Starting federated learning simulation")
        strategy = BestModelStrategy(
            min_fit_clients=4,
            min_evaluate_clients=4,
            min_available_clients=4
        )
        fl.simulation.start_simulation(
            client_fn=client_fn,
            num_clients=4,
            config=fl.server.ServerConfig(num_rounds=5),
            strategy=strategy,
            client_resources={"num_cpus": 1, "num_gpus": 0},
            ray_init_args={"ignore_reinit_error": True}
        )
        logger.info("Federated learning simulation completed")
    except Exception as e:
        logger.error(f"Simulation failed: {e}")
        raise

In [12]:
# Run the simulation
main_fl()

	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout
2025-04-15 13:01:18,698	INFO worker.py:1852 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'node:172.19.2.2': 1.0, 'node:__internal_head__': 1.0, 'CPU': 4.0, 'object_store_memory': 9002540236.0, 'memory': 21005927220.0, 'GPU': 2.0, 'accelerator_type:T4': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 