In [1]:
import os
import pandas as pd
import numpy as np
import multiprocessing
import tensorflow as tf

class color:
   PURPLE = '\033[95m'
   CYAN = '\033[96m'
   DARKCYAN = '\033[36m'
   BLUE = '\033[94m'
   GREEN = '\033[92m'
   YELLOW = '\033[93m'
   RED = '\033[91m'
   BOLD = '\033[1m'
   UNDERLINE = '\033[4m'
   END = '\033[0m'
    

RND_SEED = 0
GPU_ID = 0
USE_GPU = True

if tf.config.list_physical_devices('GPU') and USE_GPU:
    gpus = tf.config.list_physical_devices('GPU')
    gpu_name = gpus[0].name  # You may specify an index if using multiple GPUs
    print(f"Using GPU - {gpu_name}")
    device = '/GPU:0'
else:
    device = '/CPU:0'

with tf.device(device):
    pass

print(f"Using device {device}")

N_CPU_CORES = multiprocessing.cpu_count()

BASE_FOLDER = os.path.join(
    "..", "input", "tlvmc-parkinsons-freezing-gait-prediction"
)

print(f"Number of CPU cores available: {N_CPU_CORES}")

pd.set_option('display.max_columns', 30)
pd.set_option('display.max_rows', 200)

from sklearn.metrics import f1_score
import os
import numpy as np
from tensorflow.keras.utils import Sequence, to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from typing import List, Tuple
from models.cnn_models import *
import random
img_shape = (64, 64)

Using GPU - /physical_device:GPU:0
Using device /GPU:0
Number of CPU cores available: 12


In [2]:
class CustomMultiInputDataGenerator(Sequence):
    def __init__(self, directories, batch_size=32, image_size=img_shape, shuffle=True, augment=False, num_classes=2, **kwargs):
        """
        directories: List of directories, one for each input branch.
        """
        super().__init__(**kwargs)
        self.directories = directories
        self.batch_size = batch_size
        self.image_size = image_size
        self.shuffle = shuffle
        self.augment = augment
        self.num_classes = num_classes
        self.image_paths = self._load_image_paths()
        self.samples = len(self.image_paths[0])
        
        self.datagen = ImageDataGenerator(
            rescale=1/255.0,
            width_shift_range=0.1 if self.augment else 0,
            height_shift_range=0.1 if self.augment else 0
        )
        
        self.on_epoch_end()

    def _load_image_paths(self):
        # Load image paths for each branch directory
        image_paths = []
        for directory in self.directories:
            branch_image_paths = [os.path.join(directory, fname) for fname in os.listdir(directory) if fname.endswith('.jpg')]
            image_paths.append(branch_image_paths)
        return image_paths

    def _get_class_from_filename(self, filename):
        # Extract label from filename assuming a naming convention
        class_label = int(filename.split('_')[-1].split('.')[0])
        return class_label

    def __len__(self):
        return int(np.ceil(self.samples / self.batch_size))

    def __getitem__(self, index):
        # Load a batch of images for each branch
        batch_image_paths = [paths[index * self.batch_size:(index + 1) * self.batch_size] for paths in self.image_paths]
        
        # Load and preprocess images for each branch
        images_per_branch = []
        for branch_paths in batch_image_paths:
            images = np.array([img_to_array(load_img(path, target_size=self.image_size)) for path in branch_paths])
            if self.augment:
                images = np.array([self.datagen.random_transform(image) for image in images])
            else:
                images = self.datagen.standardize(images)
            images_per_branch.append(images)

        # Load labels (assuming the same labels for each branch)
        labels = np.array([self._get_class_from_filename(os.path.basename(path)) for path in batch_image_paths[0]])
        labels = to_categorical(labels, num_classes=self.num_classes)
        
        return images_per_branch, labels

    def on_epoch_end(self):
        if self.shuffle:
            for branch_paths in self.image_paths:
                np.random.shuffle(branch_paths)

In [None]:
DEBUG = True

class FederatedClient:
    def __init__(self, client_id: int, model_fn, train_generator, valid_generator=None):
        self.client_id = client_id
        self.model = model_fn()
        self.train_generator = train_generator
        self.valid_generator = valid_generator
    
    def train(self, global_weights, local_epochs=5):
        # Update local model with global weights
        self.model.set_weights(global_weights)
        if DEBUG:
            print("train_generator:")
            print(len(self.train_generator[0][1]), self.train_generator.samples)
            
        # Train the model locally using the generator
        history = self.model.fit(
            self.train_generator,
            epochs=local_epochs,
            validation_data=self.valid_generator,
            verbose=1
        )
        
        return self.model.get_weights(), history.history

class FederatedServer:
    def __init__(self, model_fn):
        self.global_model = model_fn()
        self.clients = []
        
    def add_client(self, client: FederatedClient):
        self.clients.append(client)
    
    def aggregate_weights(self, client_weights: List[np.ndarray]) -> List[np.ndarray]:
        """FedAvg aggregation with sample weighting"""
        # Get number of samples for each client
        client_samples = [client.train_generator.samples for client in self.clients]
        total_samples = sum(client_samples)
        
        # Calculate weighted average based on number of samples
        weighted_weights = []
        for idx, weights in enumerate(client_weights):
            weight = client_samples[idx] / total_samples
            weighted_weights.append([w * weight for w in weights])
        
        # Sum up the weighted weights
        averaged_weights = [
            sum(weights_list) 
            for weights_list in zip(*weighted_weights)
        ]
        
        return averaged_weights

    def evaluate(self, test_generator):
        """Evaluate the global model on test data"""
        return self.global_model.evaluate(test_generator)


def split_data_for_clients(train_generator, num_clients: int) -> List[List[List[str]]]:
    """
    Split image paths into non-overlapping subsets for each client
    
    Args:
        image_paths_list: List of lists of image paths (one list per branch)
        num_clients: Number of clients to split data for
    
    Returns:
        List of client datasets, where each client dataset contains lists of paths for each branch
    """
    image_paths_list = train_generator.image_paths
    # Ensure all branches have the same number of images
    total_samples = len(image_paths_list[0])
    assert all(len(paths) == total_samples for paths in image_paths_list)
    
    # Create indices for splitting
    indices = list(range(total_samples))
    
    if train_generator.shuffle:
        random.shuffle(indices)
    
    # Calculate samples per client
    samples_per_client = total_samples // num_clients
    
    # Split indices for each client
    client_indices = [
        indices[i * samples_per_client:(i + 1) * samples_per_client]
        for i in range(num_clients)
    ]
    
    # Add remaining samples to the last client
    if total_samples % num_clients != 0:
        remaining = indices[num_clients * samples_per_client:]
        client_indices[-1].extend(remaining)
    
    # Create client datasets
    client_datasets = []
    for client_idx in client_indices:
        client_data = []
        for branch_paths in image_paths_list:
            client_branch_paths = [branch_paths[i] for i in client_idx]
            client_data.append(client_branch_paths)
        client_datasets.append(client_data)
    
    return client_datasets

def create_client_generator(base_generator, client_image_paths):
    """
    Create a new generator for a client with specific image paths
    """
    client_generator = CustomMultiInputDataGenerator(
        directories=base_generator.directories,
        batch_size=base_generator.batch_size,
        image_size=base_generator.image_size,
        shuffle=base_generator.shuffle,
        augment=base_generator.augment,
        num_classes=base_generator.num_classes
    )
    
    # Set the client-specific image paths
    client_generator.image_paths = client_image_paths
    client_generator.samples = len(client_image_paths[0])
    
    return client_generator


def create_federated_learning_system(
    num_clients: int,
    train_generator,
    valid_generator=None,
    input_shape1=(64, 64, 3),
    input_shape2=(64, 64, 3),
    input_shape3=(64, 64, 3),
    num_classes=2
):
    """
    Creates a federated learning system using data generators
    
    Args:
        num_clients: Number of clients to simulate
        train_generator: CustomMultiInputDataGenerator for training
        valid_generator: Optional CustomMultiInputDataGenerator for validation
        input_shape1/2/3: Input shapes for the three branches
        num_classes: Number of output classes
    """
    
    def model_fn():
        return create_multi_input_cnn(
            input_shape1=input_shape1,
            input_shape2=input_shape2,
            input_shape3=input_shape3
        )
    
    # Initialize server
    server = FederatedServer(model_fn)
    
    # Split data among clients
    client_datasets = split_data_for_clients(train_generator, num_clients)
    
    # Create clients with their specific data
    for i, client_image_paths in enumerate(client_datasets):
        # Create client-specific generator
        client_train_generator = create_client_generator(train_generator, client_image_paths)
        
        # Create client
        client = FederatedClient(
            client_id=i,
            model_fn=model_fn,
            train_generator=client_train_generator,
            valid_generator=valid_generator
        )
        server.add_client(client)
        
        print(f"Client {i} created with {client_train_generator.samples} samples")
    
    return server

def train_federated(
    server: FederatedServer,
    num_rounds: int,
    local_epochs: int,
    test_generator=None
) -> dict:
    """
    Train the model using federated learning
    
    Args:
        server: FederatedServer instance
        num_rounds: Number of federated learning rounds
        local_epochs: Number of local epochs per client
        test_generator: Optional generator for testing global model
    """
    metrics_history = []
    
    for round_num in range(num_rounds):
        print(f"\nFederated Learning Round {round_num + 1}/{num_rounds}")
        
        # Get current global weights
        global_weights = server.global_model.get_weights()
        
        # Train each client locally
        client_weights = []
        client_metrics = []
        
        for client in server.clients:
            print(f"\nTraining Client {client.client_id + 1}/{len(server.clients)}")
            weights, metrics = client.train(
                global_weights,
                local_epochs=local_epochs
            )
            client_weights.append(weights)
            client_metrics.append(metrics)
        
        # Aggregate weights using FedAvg
        new_global_weights = server.aggregate_weights(client_weights)
        
        # Update global model
        server.global_model.set_weights(new_global_weights)
        
        # Evaluate global model if test generator is provided
        if test_generator is not None:
            print("\nEvaluating global model:")
            test_metrics = server.evaluate(test_generator)
            test_results = dict(zip(server.global_model.metrics_names, test_metrics))
            print("Test metrics:", test_results)
        
        # Aggregate training metrics
        round_metrics = {
            metric: np.mean([client_metric[metric][-1] 
                           for client_metric in client_metrics])
            for metric in client_metrics[0].keys()
        }
        metrics_history.append(round_metrics)
        
        # Print round metrics
        print("\nRound Training Metrics:")
        for metric, value in round_metrics.items():
            print(f"{metric}: {value:.4f}")
    
    return metrics_history

In [10]:
train_directories = ["../data/federated_learning_data/AccAP/train", "../data/federated_learning_data/AccML/train", "../data/federated_learning_data/AccV/train"]

val_directories = ["../data/federated_learning_data/AccAP/valid", "../data/federated_learning_data/AccML/valid", "../data/federated_learning_data/AccV/valid"]

train_generator = CustomMultiInputDataGenerator(
    directories=train_directories,
    batch_size=32,
    image_size=img_shape,
    augment=False,
    num_classes=2
)

valid_generator = CustomMultiInputDataGenerator(
    directories=val_directories,
    batch_size=32,
    image_size=img_shape,
    augment=False,
    num_classes=2
)

In [None]:
server = create_federated_learning_system(
    num_clients=5,
    train_generator=train_generator,
    valid_generator=valid_generator,
    input_shape1=(64, 64, 3),
    input_shape2=(64, 64, 3),
    input_shape3=(64, 64, 3),
    num_classes=2
)

# Train using federated learning
metrics_history = train_federated(
    server,
    num_rounds=10,
    local_epochs=5,
    test_generator=valid_generator
)

In [14]:
test_directories = ["../data/federated_learning_data/AccAP/test", "../data/federated_learning_data/AccML/test", "../data/federated_learning_data/AccV/test"]

test_generator = CustomMultiInputDataGenerator(
    directories=test_directories,
    batch_size=32,
    image_size=img_shape,
    augment=False,
    num_classes=2,
    shuffle=False
)

In [15]:
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score, precision_score, recall_score
import numpy as np

def evaluate_model(model, generator):
    y_true = []
    y_pred = []
    
    for i in range(len(generator)):
        images, batch_y_true = generator[i]
        
        batch_predictions = model.predict(images, verbose=0)
        
        batch_y_pred = np.argmax(batch_predictions, axis=1)
        batch_y_true = np.argmax(batch_y_true, axis=1)
        
        y_true.extend(batch_y_true.flatten())
        y_pred.extend(batch_y_pred.flatten())
    
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    
    total_samples = len(y_true)
    
    tn_percent = (tn / total_samples) * 100
    fp_percent = (fp / total_samples) * 100
    fn_percent = (fn / total_samples) * 100
    tp_percent = (tp / total_samples) * 100
    
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    
    precision_per_class = precision_score(y_true, y_pred, average=None, zero_division=0)
    
    recall_per_class = recall_score(y_true, y_pred, average=None, zero_division=0)
    
    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'sensitivity': sensitivity,
        
        'f1_micro': f1_score(y_true, y_pred, average='micro'),
        'f1_macro': f1_score(y_true, y_pred, average='macro'),
        'f1_weighted': f1_score(y_true, y_pred, average='weighted'),
        'f1_per_class': f1_score(y_true, y_pred, average=None),
        
        'precision_per_class': precision_per_class,
        'recall_per_class': recall_per_class,
        
        'true_positives_percentage': tp_percent,
        'true_negatives_percentage': tn_percent,
        'false_positives_percentage': fp_percent,
        'false_negatives_percentage': fn_percent,
        
        'true_positives': tp,
        'true_negatives': tn,
        'false_positives': fp,
        'false_negatives': fn
    }
    
    return metrics

results = evaluate_model(server.global_model, test_generator)

print(results)

{'accuracy': 0.8698641087130295, 'precision': 0.8683679503357472, 'recall': 0.8698641087130295, 'specificity': 0.767056530214425, 'sensitivity': 0.9200571020699501, 'f1_micro': 0.8698641087130295, 'f1_macro': 0.849660627151198, 'f1_weighted': 0.8686129554719567, 'f1_per_class': array([0.79454821, 0.90477305]), 'precision_per_class': array([0.82408377, 0.88998849]), 'recall_per_class': array([0.76705653, 0.9200571 ]), 'true_positives_percentage': 61.82254196642686, 'true_negatives_percentage': 25.163868904876097, 'false_positives_percentage': 7.641886490807353, 'false_negatives_percentage': 5.371702637889688, 'true_positives': 3867, 'true_negatives': 1574, 'false_positives': 478, 'false_negatives': 336}
