# Step 1

Step 2: Data Preparation
For your reproduction, you'll need:

Public datasets: Start with MNIST and SVHN as the paper did

Data partitioning strategy: You'll split data into:

Public dataset (10% of training data)

Client private datasets (90% of training data)

Testing dataset (20% of original data)

The non-IID scenario is particularly interesting-each client only gets data from two classes. Implement a function that handles both IID and non-IID distributions so you can compare results.

Don't underestimate the importance of proper data splitting! A lot of reproducibility issues stem from differences in how data is partitioned across clients.

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, Subset, DataLoader
import numpy as np
import random
import os

In [2]:
# Use the desktop directory for data storage
DATA_DIR = os.path.join(os.path.expanduser("~"), "Desktop", "data")

# Ensure the directory exists
if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)

print(f"Data directory is set to: {DATA_DIR}")

Data directory is set to: C:\Users\rafha\Desktop\data


In [3]:
# Create data directory if it does not exist
os.makedirs(DATA_DIR, exist_ok=True)

# Set random seed for reproducibility
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
# Configurations
NUM_CLIENTS = 10  # Adjust as needed
NON_IID_CLASSES_PER_CLIENT = 2
BATCH_SIZE = 32


In [5]:
# Transformations
transform_mnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Real MNIST mean and std
])

transform_svhn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))  # Real SVHN mean and std
])

In [6]:
# Load Datasets
def load_datasets():
    mnist_train = torchvision.datasets.MNIST(root=DATA_DIR, train=True, download=True, transform=transform_mnist)
    mnist_test = torchvision.datasets.MNIST(root=DATA_DIR, train=False, download=True, transform=transform_mnist)

    svhn_train = torchvision.datasets.SVHN(root=DATA_DIR, split='train', download=True, transform=transform_svhn)
    svhn_test = torchvision.datasets.SVHN(root=DATA_DIR, split='test', download=True, transform=transform_svhn)

    return mnist_train, mnist_test, svhn_train, svhn_test

In [7]:
# IID Split
def iid_split(dataset, num_clients):
    data_per_client = len(dataset) // num_clients
    indices = torch.randperm(len(dataset))
    return {i: Subset(dataset, indices[i * data_per_client:(i + 1) * data_per_client]) for i in range(num_clients)}


In [8]:
# Non-IID Split

def get_targets(dataset):
    """Recursively find the targets or labels attribute inside Subsets."""
    while isinstance(dataset, Subset):
        dataset = dataset.dataset
    if hasattr(dataset, 'targets'):
        return np.array(dataset.targets)
    elif hasattr(dataset, 'labels'):
        return np.array(dataset.labels)
    else:
        raise AttributeError("Dataset does not have 'targets' or 'labels' attribute.")

def noniid_split(dataset, num_clients, classes_per_client):
    full_targets = get_targets(dataset)
    if isinstance(dataset, Subset):
        subset_indices = dataset.indices
        targets = full_targets[subset_indices]
    else:
        targets = full_targets

    class_indices = {cls: np.where(targets == cls)[0] for cls in np.unique(targets)}

    client_indices = {i: [] for i in range(num_clients)}
    available_classes = list(class_indices.keys())
    random.shuffle(available_classes)

    for i in range(num_clients):
        selected_classes = random.sample(available_classes, classes_per_client)
        for cls in selected_classes:
            n_samples = max(1, len(class_indices[cls]) // num_clients)
            chosen = np.random.choice(class_indices[cls], size=n_samples, replace=True)
            client_indices[i].extend(chosen)

    return {i: Subset(dataset, client_indices[i]) for i in range(num_clients)}


In [9]:
# Dataset Preparation (old version)
def prepare_datasets(dataset, iid=True):
    total_len = len(dataset)
    test_len = int(0.2 * total_len)
    remaining_len = total_len - test_len

    test_set, remaining_set = random_split(dataset, [test_len, remaining_len], generator=torch.Generator().manual_seed(SEED))

    public_len = int(0.1 * remaining_len)
    private_len = remaining_len - public_len

    public_set, private_set = random_split(remaining_set, [public_len, private_len], generator=torch.Generator().manual_seed(SEED))

    if iid:
        clients = iid_split(private_set, NUM_CLIENTS)
    else:
        clients = noniid_split(private_set, NUM_CLIENTS, NON_IID_CLASSES_PER_CLIENT)

    return {
        "public": public_set,
        "test": test_set,
        "clients": clients
    }

In [10]:
def prepare_datasets(dataset, num_clients=NUM_CLIENTS, classes_per_client=NON_IID_CLASSES_PER_CLIENT, iid=True):
    """
    Splits a dataset into public, test, and client-specific datasets.
    
    Args:
        dataset: The full dataset to split.
        num_clients: Number of clients for private dataset distribution.
        classes_per_client: Number of classes per client for Non-IID splits.
        iid: If True, performs IID splitting. Otherwise, performs Non-IID splitting.
    
    Returns:
        A dictionary containing:
            - "public": Public dataset for clustering.
            - "test": Test dataset for evaluation.
            - "clients": Dictionary of client-specific datasets.
    """
    # Shuffle dataset indices for randomness
    total_indices = list(range(len(dataset)))
    random.shuffle(total_indices)

    # Calculate lengths for splits
    test_len = int(0.2 * len(dataset))
    public_len = int(0.1 * (len(dataset) - test_len))  # 10% of training data
    private_len = len(dataset) - test_len - public_len

    # Create splits (mutually exclusive)
    test_indices = total_indices[:test_len]
    public_indices = total_indices[test_len:test_len + public_len]
    private_indices = total_indices[test_len + public_len:]

    # Create Subsets
    test_set = Subset(dataset, test_indices)
    public_set = Subset(dataset, public_indices)
    private_set = Subset(dataset, private_indices)

    # Perform client-specific splits
    if iid:
        clients = iid_split(private_set, num_clients)
    else:
        clients = noniid_split(private_set, num_clients, classes_per_client)

    # Validation for mutual exclusivity (optional, should now pass)
    assert not set(test_indices).intersection(public_indices), "Public and Test datasets overlap!"
    assert not set(test_indices).intersection(private_indices), "Test and Private datasets overlap!"
    assert not set(public_indices).intersection(private_indices), "Public and Private datasets overlap!"

    return {
        "public": public_set,
        "test": test_set,
        "clients": clients
    }

In [11]:
# Example Usage
if __name__ == "__main__":
    mnist_train, mnist_test, svhn_train, svhn_test = load_datasets()

    # Prepare MNIST with non-IID split
    mnist_data = prepare_datasets(mnist_train, iid=False)

    # Prepare SVHN with IID split
    svhn_data = prepare_datasets(svhn_train, iid=True)

    # Example: Create DataLoader for a client
    client0_loader = DataLoader(mnist_data['clients'][0], batch_size=BATCH_SIZE, shuffle=True)
    public_loader = DataLoader(mnist_data['public'], batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(mnist_data['test'], batch_size=BATCH_SIZE, shuffle=False)

    print(f"Client 0 dataset size: {len(mnist_data['clients'][0])}")
    print(f"Public dataset size: {len(mnist_data['public'])}")
    print(f"Test dataset size: {len(mnist_data['test'])}")


Client 0 dataset size: 907
Public dataset size: 4800
Test dataset size: 12000


  return np.array(dataset.targets)


# Step 2: Model Architecture Implementation
You'll need to implement four different model architectures as described in the paper:

Create models with varying complexity (different numbers of convolutional layers, fully-connected layers, etc.)

Ensure each model can extract intermediate layer outputs (this is critical for the feature vector extraction)

Add hooks or functions that can extract these features when passing data through the models

This is where many students struggle-identifying which layer to extract features from. Experiment with different layers to see which ones capture structural information best.

In [12]:
import torch

# Check if GPU is available and being used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Example: Move a tensor to the selected device
x = torch.tensor([1.0, 2.0, 3.0]).to(device)
print(f"Tensor is on device: {x.device}")

Using device: cuda
Tensor is on device: cuda:0


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from typing import Dict, List, Tuple, Optional
import numpy as np

# For reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x16e2964d790>

In [14]:
class FeatureExtractor:
    def __init__(self):
        self.features = {}
        self.hooks = []

    def get_activation(self, name):
        def hook(model, input, output):
            print(f"Hook triggered for layer '{name}' with output shape: {output.shape}")  # Debugging
            self.features[name] = output.detach()
        return hook

    def attach_hooks(self, model: nn.Module, layer_names: List[str]):
        """Attach hooks to specified layers"""
        for name, layer in model.named_modules():
            if name in layer_names:
                print(f"Attaching hook to layer: {name}")  # Debugging
                hook = layer.register_forward_hook(self.get_activation(name))
                self.hooks.append(hook)

    def clear_hooks(self):
        """Remove all hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()
        self.features.clear()

In [15]:
class BaseModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        self.feature_layers = []  # Layers to extract features from

    def get_features(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Extract features from specified layers"""
        self.feature_extractor.attach_hooks(self, self.feature_layers)
        _ = self(x)  # Forward pass
        features = self.feature_extractor.features
        self.feature_extractor.clear_hooks()
        return features

    def get_flattened_features(self, x: torch.Tensor) -> torch.Tensor:
        """Get concatenated features from all specified layers"""
        # Attach hooks and perform forward pass
        self.feature_extractor.attach_hooks(self, self.feature_layers)
        _ = self(x)  # Forward pass
        features = self.feature_extractor.features

        # Flatten and concatenate the features
        flattened_features = []
        for name in self.feature_layers:
            if name in features:
                print(f"Flattening feature from layer '{name}' with shape: {features[name].shape}")  # Debugging
                flat = torch.flatten(features[name], start_dim=1)
                print(f"Flattened feature shape: {flat.shape}")  # Debugging
                flattened_features.append(flat)
            else:
                print(f"Warning: No feature captured for layer '{name}'")
        
        self.feature_extractor.clear_hooks()  # Clear hooks after processing features

        if not flattened_features:
            raise ValueError("No valid features captured for flattening and concatenation.")
        return torch.cat(flattened_features, dim=1)


In [16]:
class SimpleConvNet(BaseModel):
    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.fc2 = nn.Linear(512, num_classes)
        
        # Layers for feature extraction
        self.feature_layers = ['conv2', 'fc1']

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        print("Input shape:", x.shape)  # Debugging
        x = self.pool(F.relu(self.conv1(x)))
        print("After conv1:", x.shape)  # Debugging
        x = self.pool(F.relu(self.conv2(x)))
        print("After conv2:", x.shape)  # Debugging
        x = torch.flatten(x, 1)
        print("After flatten:", x.shape)  # Debugging
        x = F.relu(self.fc1(x))
        print("After fc1:", x.shape)  # Debugging
        x = self.fc2(x)
        print("After fc2:", x.shape)  # Debugging
        return x

In [17]:
class ComplexConvNet(BaseModel):
    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 4 * 4, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.5)

        # Layers for feature extraction
        self.feature_layers = ['conv3', 'fc2']

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        print("Input shape:", x.shape)  # Debugging
        x = self.pool(F.relu(self.conv1(x)))
        print("After conv1:", x.shape)  # Debugging
        x = self.pool(F.relu(self.conv2(x)))
        print("After conv2:", x.shape)  # Debugging
        x = self.pool(F.relu(self.conv3(x)))
        print("After conv3:", x.shape)  # Debugging
        x = torch.flatten(x, 1)
        print("After flatten:", x.shape)  # Debugging
        x = F.relu(self.fc1(x))
        print("After fc1:", x.shape)  # Debugging
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        print("After fc2:", x.shape)  # Debugging
        x = self.dropout(x)
        x = self.fc3(x)
        print("After fc3:", x.shape)  # Debugging
        return x

In [18]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(BaseModel):
    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.layer1 = ResidualBlock(64, 64)
        self.layer2 = ResidualBlock(64, 128, stride=2)
        self.layer3 = ResidualBlock(128, 256, stride=2)
        
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)
        
        # Specify layers to extract features from
        self.feature_layers = ['layer2', 'layer3']

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

class DenseBlock(nn.Module):
    def __init__(self, in_channels: int, growth_rate: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(self._make_layer(in_channels + i * growth_rate, growth_rate))
    
    def _make_layer(self, in_channels: int, growth_rate: int) -> nn.Sequential:
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, growth_rate, 3, padding=1)
        )
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = [x]
        for layer in self.layers:
            x = layer(torch.cat(features, 1))
            features.append(x)
        return torch.cat(features, 1)

class DenseNet(BaseModel):
    def __init__(self, in_channels: int = 1, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.dense1 = DenseBlock(64, growth_rate=32, num_layers=6)
        self.trans1 = nn.Sequential(
            nn.BatchNorm2d(64 + 6 * 32),
            nn.ReLU(inplace=True),
            nn.Conv2d(64 + 6 * 32, 128, 1),
            nn.AvgPool2d(2, 2)
        )
        self.dense2 = DenseBlock(128, growth_rate=32, num_layers=12)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128 + 12 * 32, num_classes)
        
        # Specify layers to extract features from
        self.feature_layers = ['dense1', 'dense2']

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.dense1(x)
        x = self.trans1(x)
        x = self.dense2(x)
        x = self.avg_pool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [19]:
def test_model_and_features(model: BaseModel, input_shape: Tuple[int, ...]) -> None:
    """Test model forward pass and feature extraction"""
    print(f"\nTesting {model.__class__.__name__}")
    
    # Create dummy input
    x = torch.randn(input_shape)
    
    # Test forward pass
    try:
        out = model(x)
        print(f"Output shape: {out.shape}")
    except Exception as e:
        print(f"Error during forward pass: {e}")
        return

    # Test feature extraction
    try:
        features = model.get_features(x)
        print("\nFeature shapes:")
        for name, feat in features.items():
            print(f"{name}: {feat.shape}")
    except Exception as e:
        print(f"Error during feature extraction: {e}")
        return
    
    # Test flattened features
    try:
        flat_features = model.get_flattened_features(x)
        print(f"\nFlattened features shape: {flat_features.shape}")
    except Exception as e:
        print(f"Error during flattened feature extraction: {e}")

In [20]:
# class ModelFactory:
#     @staticmethod
#     def create_model(model_name: str, in_channels: int = 1, num_classes: int = 10) -> BaseModel:
#         """Create a model instance based on the model name"""
#         models = {
#             'simple': SimpleConvNet,
#             'complex': ComplexConvNet,
#             'resnet': ResNet,
#             'densenet': DenseNet
#         }
        
#         if model_name not in models:
#             raise ValueError(f"Model {model_name} not found. Available models: {list(models.keys())}")
        
#         return models[model_name](in_channels=in_channels, num_classes=num_classes)

import torch.nn as nn

class ModelWithProjection(nn.Module):
    def __init__(self, base_model, feature_dim, projection_dim=128):
        """
        Wrap a base model with a projection layer.

        Args:
            base_model (nn.Module): Original client model producing feature vectors.
            feature_dim (int): Dimensionality of the base model's feature vectors.
            projection_dim (int): Target dimensionality for the projection layer.
        """
        super(ModelWithProjection, self).__init__()
        self.base_model = base_model
        self.projection = nn.Linear(feature_dim, projection_dim)

    def forward(self, x):
        # Pass input through the base model
        features = self.base_model.get_flattened_features(x)
        # Project features to a lower dimension
        projected_features = self.projection(features)
        return projected_features

    def get_projected_features(self, x):
        """
        Extract and return the projected feature vectors.

        Args:
            x: Input tensor.
        """
        return self.forward(x)

class ModelFactory:
    @staticmethod
    def create_model(model_name: str, in_channels: int = 1, num_classes: int = 10, projection_dim: int = 128) -> nn.Module:
        """
        Create a model instance based on the model name and wrap it with a projection layer.

        Args:
            model_name (str): Name of the model ('simple', 'complex', 'resnet', 'densenet').
            in_channels (int): Number of input channels for the model.
            num_classes (int): Number of output classes for classification.
            projection_dim (int): Dimensionality of the projection layer.

        Returns:
            nn.Module: Model instance wrapped with a projection layer.
        """
        models = {
            'simple': SimpleConvNet,
            'complex': ComplexConvNet,
            'resnet': ResNet,
            'densenet': DenseNet
        }
        
        if model_name not in models:
            raise ValueError(f"Model {model_name} not found. Available models: {list(models.keys())}")
        
        # Create the base model
        base_model = models[model_name](in_channels=in_channels, num_classes=num_classes)
        
        # Add the expected_feature_dim attribute to the base model
        feature_dim_mapping = {
            'simple': 13056,    # Example flattened feature size for SimpleConvNet
            'complex': 20480,   # Example flattened feature size for ComplexConvNet
            'resnet': 37632,    # Example flattened feature size for ResNet
            'densenet': 50176   # Example flattened feature size for DenseNet
        }
        feature_dim = feature_dim_mapping[model_name]
        base_model.expected_feature_dim = feature_dim  # Add this attribute to the base model

        # Wrap the base model with a projection layer
        return ModelWithProjection(base_model, feature_dim=feature_dim, projection_dim=projection_dim)

In [21]:
if __name__ == "__main__":
    factory = ModelFactory()
    
    # Create models for both MNIST (1 channel) and SVHN (3 channels)
    mnist_model = factory.create_model('simple', in_channels=1)
    svhn_model = factory.create_model('complex', in_channels=3)
    
    # Test models
    test_model_and_features(mnist_model, (4, 1, 28, 28))  # MNIST dimensions
    test_model_and_features(svhn_model, (4, 3, 32, 32))  # SVHN dimensions


Testing ModelWithProjection
Attaching hook to layer: conv2
Attaching hook to layer: fc1
Input shape: torch.Size([4, 1, 28, 28])
After conv1: torch.Size([4, 32, 14, 14])
Hook triggered for layer 'conv2' with output shape: torch.Size([4, 64, 14, 14])
After conv2: torch.Size([4, 64, 7, 7])
After flatten: torch.Size([4, 3136])
Hook triggered for layer 'fc1' with output shape: torch.Size([4, 512])
After fc1: torch.Size([4, 512])
After fc2: torch.Size([4, 10])
Flattening feature from layer 'conv2' with shape: torch.Size([4, 64, 14, 14])
Flattened feature shape: torch.Size([4, 12544])
Flattening feature from layer 'fc1' with shape: torch.Size([4, 512])
Flattened feature shape: torch.Size([4, 512])
Output shape: torch.Size([4, 128])
Error during feature extraction: 'ModelWithProjection' object has no attribute 'get_features'

Testing ModelWithProjection
Attaching hook to layer: conv3
Attaching hook to layer: fc2
Input shape: torch.Size([4, 3, 32, 32])
After conv1: torch.Size([4, 64, 16, 16])


In [22]:
if __name__ == "__main__":
    # Test SimpleConvNet
    model = SimpleConvNet(in_channels=1, num_classes=10)
    dummy_input = torch.randn(4, 1, 28, 28)  # Batch size 4, MNIST input dimensions
    print("\nTesting SimpleConvNet forward pass:")
    model(dummy_input)

    # Test ComplexConvNet
    model = ComplexConvNet(in_channels=1, num_classes=10)
    dummy_input = torch.randn(4, 1, 32, 32)  # Batch size 4, SVHN input dimensions
    print("\nTesting ComplexConvNet forward pass:")
    model(dummy_input)


Testing SimpleConvNet forward pass:
Input shape: torch.Size([4, 1, 28, 28])
After conv1: torch.Size([4, 32, 14, 14])
After conv2: torch.Size([4, 64, 7, 7])
After flatten: torch.Size([4, 3136])
After fc1: torch.Size([4, 512])
After fc2: torch.Size([4, 10])

Testing ComplexConvNet forward pass:
Input shape: torch.Size([4, 1, 32, 32])
After conv1: torch.Size([4, 64, 16, 16])
After conv2: torch.Size([4, 128, 8, 8])
After conv3: torch.Size([4, 256, 4, 4])
After flatten: torch.Size([4, 4096])
After fc1: torch.Size([4, 1024])
After fc2: torch.Size([4, 512])
After fc3: torch.Size([4, 10])


In [23]:
# Example for SimpleConvNet
model = SimpleConvNet(in_channels=1)
print("SimpleConvNet Layers:")
for name, _ in model.named_modules():
    print(name)

# Example for ComplexConvNet
model = ComplexConvNet(in_channels=1)
print("\nComplexConvNet Layers:")
for name, _ in model.named_modules():
    print(name)

SimpleConvNet Layers:

conv1
conv2
pool
fc1
fc2

ComplexConvNet Layers:

conv1
conv2
conv3
pool
fc1
fc2
fc3
dropout


In [24]:
if __name__ == "__main__":
    # Test SimpleConvNet
    model = SimpleConvNet(in_channels=1, num_classes=10)
    dummy_input = torch.randn(4, 1, 28, 28)  # Batch size 4, MNIST input dimensions
    print("\nTesting SimpleConvNet forward pass:")
    model(dummy_input)

    # Test ComplexConvNet
    model = ComplexConvNet(in_channels=1, num_classes=10)
    dummy_input = torch.randn(4, 1, 32, 32)  # Batch size 4, SVHN input dimensions
    print("\nTesting ComplexConvNet forward pass:")
    model(dummy_input)


Testing SimpleConvNet forward pass:
Input shape: torch.Size([4, 1, 28, 28])
After conv1: torch.Size([4, 32, 14, 14])
After conv2: torch.Size([4, 64, 7, 7])
After flatten: torch.Size([4, 3136])
After fc1: torch.Size([4, 512])
After fc2: torch.Size([4, 10])

Testing ComplexConvNet forward pass:
Input shape: torch.Size([4, 1, 32, 32])
After conv1: torch.Size([4, 64, 16, 16])
After conv2: torch.Size([4, 128, 8, 8])
After conv3: torch.Size([4, 256, 4, 4])
After flatten: torch.Size([4, 4096])
After fc1: torch.Size([4, 1024])
After fc2: torch.Size([4, 512])
After fc3: torch.Size([4, 10])


In [25]:
if __name__ == "__main__":
    # Test SimpleConvNet
    model = SimpleConvNet(in_channels=1, num_classes=10)
    dummy_input = torch.randn(4, 1, 28, 28)  # Batch size 4, MNIST input dimensions
    print("\nTesting SimpleConvNet feature extraction:")
    features = model.get_features(dummy_input)
    for name, feature in features.items():
        print(f"{name}: {feature.shape}")

    # Test ComplexConvNet
    model = ComplexConvNet(in_channels=1, num_classes=10)
    dummy_input = torch.randn(4, 1, 32, 32)  # Batch size 4, SVHN input dimensions
    print("\nTesting ComplexConvNet feature extraction:")
    features = model.get_features(dummy_input)
    for name, feature in features.items():
        print(f"{name}: {feature.shape}")


Testing SimpleConvNet feature extraction:
Attaching hook to layer: conv2
Attaching hook to layer: fc1
Input shape: torch.Size([4, 1, 28, 28])
After conv1: torch.Size([4, 32, 14, 14])
Hook triggered for layer 'conv2' with output shape: torch.Size([4, 64, 14, 14])
After conv2: torch.Size([4, 64, 7, 7])
After flatten: torch.Size([4, 3136])
Hook triggered for layer 'fc1' with output shape: torch.Size([4, 512])
After fc1: torch.Size([4, 512])
After fc2: torch.Size([4, 10])

Testing ComplexConvNet feature extraction:
Attaching hook to layer: conv3
Attaching hook to layer: fc2
Input shape: torch.Size([4, 1, 32, 32])
After conv1: torch.Size([4, 64, 16, 16])
After conv2: torch.Size([4, 128, 8, 8])
Hook triggered for layer 'conv3' with output shape: torch.Size([4, 256, 8, 8])
After conv3: torch.Size([4, 256, 4, 4])
After flatten: torch.Size([4, 4096])
After fc1: torch.Size([4, 1024])
Hook triggered for layer 'fc2' with output shape: torch.Size([4, 512])
After fc2: torch.Size([4, 512])
After fc3

In [26]:
if __name__ == "__main__":
    # Test SimpleConvNet
    model = SimpleConvNet(in_channels=1, num_classes=10)
    dummy_input = torch.randn(4, 1, 28, 28)  # Batch size 4, MNIST input dimensions
    print("\nTesting SimpleConvNet flattened feature extraction:")
    flattened_features = model.get_flattened_features(dummy_input)
    print(f"Flattened features shape: {flattened_features.shape}")


Testing SimpleConvNet flattened feature extraction:
Attaching hook to layer: conv2
Attaching hook to layer: fc1
Input shape: torch.Size([4, 1, 28, 28])
After conv1: torch.Size([4, 32, 14, 14])
Hook triggered for layer 'conv2' with output shape: torch.Size([4, 64, 14, 14])
After conv2: torch.Size([4, 64, 7, 7])
After flatten: torch.Size([4, 3136])
Hook triggered for layer 'fc1' with output shape: torch.Size([4, 512])
After fc1: torch.Size([4, 512])
After fc2: torch.Size([4, 10])
Flattening feature from layer 'conv2' with shape: torch.Size([4, 64, 14, 14])
Flattened feature shape: torch.Size([4, 12544])
Flattening feature from layer 'fc1' with shape: torch.Size([4, 512])
Flattened feature shape: torch.Size([4, 512])
Flattened features shape: torch.Size([4, 13056])


In [27]:
if __name__ == "__main__":
    # Test SimpleConvNet
    model = SimpleConvNet(in_channels=1, num_classes=10)
    dummy_input = torch.randn(4, 1, 28, 28)  # Batch size 4, MNIST input dimensions
    print("\nTesting SimpleConvNet flattened feature extraction:")
    flattened_features = model.get_flattened_features(dummy_input)
    print(f"Flattened features shape: {flattened_features.shape}")

    # Test ComplexConvNet
    model = ComplexConvNet(in_channels=1, num_classes=10)
    dummy_input = torch.randn(4, 1, 32, 32)  # Batch size 4, SVHN input dimensions
    print("\nTesting ComplexConvNet flattened feature extraction:")
    flattened_features = model.get_flattened_features(dummy_input)
    print(f"Flattened features shape: {flattened_features.shape}")


Testing SimpleConvNet flattened feature extraction:
Attaching hook to layer: conv2
Attaching hook to layer: fc1
Input shape: torch.Size([4, 1, 28, 28])
After conv1: torch.Size([4, 32, 14, 14])
Hook triggered for layer 'conv2' with output shape: torch.Size([4, 64, 14, 14])
After conv2: torch.Size([4, 64, 7, 7])
After flatten: torch.Size([4, 3136])
Hook triggered for layer 'fc1' with output shape: torch.Size([4, 512])
After fc1: torch.Size([4, 512])
After fc2: torch.Size([4, 10])
Flattening feature from layer 'conv2' with shape: torch.Size([4, 64, 14, 14])
Flattened feature shape: torch.Size([4, 12544])
Flattening feature from layer 'fc1' with shape: torch.Size([4, 512])
Flattened feature shape: torch.Size([4, 512])
Flattened features shape: torch.Size([4, 13056])

Testing ComplexConvNet flattened feature extraction:
Attaching hook to layer: conv3
Attaching hook to layer: fc2
Input shape: torch.Size([4, 1, 32, 32])
After conv1: torch.Size([4, 64, 16, 16])
After conv2: torch.Size([4, 128

In [28]:
if __name__ == "__main__":
    # Test SimpleConvNet
    model = SimpleConvNet(in_channels=1, num_classes=10)
    dummy_input = torch.randn(4, 1, 28, 28)  # Batch size 4, MNIST input dimensions
    print("\nTesting SimpleConvNet flattened feature extraction:")
    flattened_features = model.get_flattened_features(dummy_input)
    print(f"Flattened features shape: {flattened_features.shape}")


Testing SimpleConvNet flattened feature extraction:
Attaching hook to layer: conv2
Attaching hook to layer: fc1
Input shape: torch.Size([4, 1, 28, 28])
After conv1: torch.Size([4, 32, 14, 14])
Hook triggered for layer 'conv2' with output shape: torch.Size([4, 64, 14, 14])
After conv2: torch.Size([4, 64, 7, 7])
After flatten: torch.Size([4, 3136])
Hook triggered for layer 'fc1' with output shape: torch.Size([4, 512])
After fc1: torch.Size([4, 512])
After fc2: torch.Size([4, 10])
Flattening feature from layer 'conv2' with shape: torch.Size([4, 64, 14, 14])
Flattened feature shape: torch.Size([4, 12544])
Flattening feature from layer 'fc1' with shape: torch.Size([4, 512])
Flattened feature shape: torch.Size([4, 512])
Flattened features shape: torch.Size([4, 13056])


In [29]:
# Assign models to clients
NUM_CLIENTS = 10  # Number of clients
factory = ModelFactory()

client_models = {
    f"client_{i}": factory.create_model(
        'simple' if i % 2 == 0 else 'resnet',  # Alternate between 'simple' and 'resnet'
        in_channels=1, num_classes=10
    )
    for i in range(NUM_CLIENTS)
}

# Example: Print assigned models
for client, model in client_models.items():
    print(f"{client} is assigned model: {model.__class__.__name__}")

client_0 is assigned model: ModelWithProjection
client_1 is assigned model: ModelWithProjection
client_2 is assigned model: ModelWithProjection
client_3 is assigned model: ModelWithProjection
client_4 is assigned model: ModelWithProjection
client_5 is assigned model: ModelWithProjection
client_6 is assigned model: ModelWithProjection
client_7 is assigned model: ModelWithProjection
client_8 is assigned model: ModelWithProjection
client_9 is assigned model: ModelWithProjection


In [30]:
# Example: Test client models with appropriate datasets
for client, model in client_models.items():
    input_shape = (4, 1, 28, 28)  # Example input for MNIST
    print(f"\nTesting {client}'s model: {model.__class__.__name__}")
    test_model_and_features(model, input_shape)


Testing client_0's model: ModelWithProjection

Testing ModelWithProjection
Attaching hook to layer: conv2
Attaching hook to layer: fc1
Input shape: torch.Size([4, 1, 28, 28])
After conv1: torch.Size([4, 32, 14, 14])
Hook triggered for layer 'conv2' with output shape: torch.Size([4, 64, 14, 14])
After conv2: torch.Size([4, 64, 7, 7])
After flatten: torch.Size([4, 3136])
Hook triggered for layer 'fc1' with output shape: torch.Size([4, 512])
After fc1: torch.Size([4, 512])
After fc2: torch.Size([4, 10])
Flattening feature from layer 'conv2' with shape: torch.Size([4, 64, 14, 14])
Flattened feature shape: torch.Size([4, 12544])
Flattening feature from layer 'fc1' with shape: torch.Size([4, 512])
Flattened feature shape: torch.Size([4, 512])
Output shape: torch.Size([4, 128])
Error during feature extraction: 'ModelWithProjection' object has no attribute 'get_features'

Testing client_1's model: ModelWithProjection

Testing ModelWithProjection
Attaching hook to layer: layer2
Attaching hook 

# Step 4: Core FedISMH Components
Now for the heart of the implementation:

Feature vector extraction:

Implement functions that run the public dataset through each client's model

Extract and flatten features from intermediate layers

This gives you a compressed representation of how each model "sees" data

Similarity matrix construction:

Calculate cosine similarity between all client feature vectors

Transform the similarity values to the range for easier interpretation

For large numbers of clients, optimize by only computing the upper triangle of the matrix

DBSCAN clustering:

Apply DBSCAN to the similarity matrix with parameters ε=0.15 and mmin=1

Group clients into clusters based on model similarity

Identify noise clients (outliers) that don't fit well into any cluster

Noise client management:

Calculate internal similarity within noise clusters

Reclassify noise clusters as valid if their internal similarity exceeds θnoise=0.65

Align remaining noise clients with their most similar valid clusters

This is the most novel part of the paper, so take time to understand the intuition behind these steps.



In [31]:
import flwr as fl

# Define Flower client
class FederatedClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader

    def get_parameters(self):
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        state_dict = dict(zip(self.model.state_dict().keys(), parameters))
        self.model.load_state_dict(state_dict)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        # Train the model here
        return self.get_parameters(), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        # Evaluate the model here
        return 0.0, len(self.test_loader.dataset), {}

In [32]:
import torch
from typing import Tuple

def test_model_features(model, input_shape: Tuple[int, ...], feature_layers: list):
    """
    Test feature extraction for a given model and feature layers.
    
    Args:
        model: The model to test.
        input_shape: Shape of the input tensor (batch_size, channels, height, width).
        feature_layers: List of layers to extract features from.
    """
    print(f"\nTesting {model.__class__.__name__}")

    # Set feature layers for extraction 
    model.feature_layers = feature_layers

    # Create dummy input
    x = torch.randn(input_shape)

    # Test feature extraction
    try:
        features = model.get_features(x)
        print(f"Extracted Features for {feature_layers}:")
        for layer, feature in features.items():
            print(f"  Layer {layer}: Shape {feature.shape}")
    except Exception as e:
        print(f"Error during feature extraction: {e}")

    # Test flattened features
    try:
        flattened_features = model.get_flattened_features(x)
        print(f"Flattened Features Shape: {flattened_features.shape}")
    except Exception as e:
        print(f"Error during flattened feature extraction: {e}")


if __name__ == "__main__":
    # Input shapes for MNIST and SVHN 
    mnist_input_shape = (4, 1, 28, 28)  # Batch size 4, single channel, 28x28 (MNIST)
    svhn_input_shape = (4, 3, 32, 32)   # Batch size 4, three channels, 32x32 (SVHN)

    # Models and corresponding layers
    models_to_test = [
        {"model": SimpleConvNet(in_channels=1), "input_shape": mnist_input_shape, "layers": ['conv2', 'fc1']},
        {"model": ComplexConvNet(in_channels=3), "input_shape": svhn_input_shape, "layers": ['conv3', 'fc2']}, 
        {"model": ResNet(in_channels=1), "input_shape": mnist_input_shape, "layers": ['layer2', 'layer3']},
        {"model": DenseNet(in_channels=1), "input_shape": mnist_input_shape, "layers": ['dense1', 'dense2']}
    ]

    # Test each model
    for item in models_to_test:
        test_model_features(item["model"], item["input_shape"], item["layers"])


Testing SimpleConvNet
Attaching hook to layer: conv2
Attaching hook to layer: fc1
Input shape: torch.Size([4, 1, 28, 28])
After conv1: torch.Size([4, 32, 14, 14])
Hook triggered for layer 'conv2' with output shape: torch.Size([4, 64, 14, 14])
After conv2: torch.Size([4, 64, 7, 7])
After flatten: torch.Size([4, 3136])
Hook triggered for layer 'fc1' with output shape: torch.Size([4, 512])
After fc1: torch.Size([4, 512])
After fc2: torch.Size([4, 10])
Extracted Features for ['conv2', 'fc1']:
Attaching hook to layer: conv2
Attaching hook to layer: fc1
Input shape: torch.Size([4, 1, 28, 28])
After conv1: torch.Size([4, 32, 14, 14])
Hook triggered for layer 'conv2' with output shape: torch.Size([4, 64, 14, 14])
After conv2: torch.Size([4, 64, 7, 7])
After flatten: torch.Size([4, 3136])
Hook triggered for layer 'fc1' with output shape: torch.Size([4, 512])
After fc1: torch.Size([4, 512])
After fc2: torch.Size([4, 10])
Flattening feature from layer 'conv2' with shape: torch.Size([4, 64, 14, 1

# What’s Done
Data Preparation:

Datasets are split into public, private (client-specific), and test datasets.
Both IID and Non-IID splits are supported.
The splits are mutually exclusive, and the sizes are verified.
Model Architectures:

Four architectures (SimpleConvNet, ComplexConvNet, ResNet, DenseNet) have been implemented.
Feature extraction from intermediate layers is integrated and tested.
Models are ready for use in a federated learning setup.
Client-Model Assignment:

Clients can now be dynamically assigned different models (heterogeneous architectures).
Compatibility with data dimensions is ensured.

## Next Steps
Now that datasets and models are ready, you can move on to the core components of FedISMH:

Step 1: Feature Vector Extraction
Goal: Extract and flatten feature vectors for all clients using the public dataset.
Why: These feature vectors represent each model’s "view" of the public dataset and are essential for clustering.
Here’s what to do:

Create a script to pass the public dataset through each client’s model.
Extract and flatten the features from the predefined intermediate layers.
Store the feature vectors for use in the next step.
Step 2: Similarity Matrix Construction
Goal: Compute a cosine similarity matrix between all client feature vectors.
Why: The similarity matrix quantifies how similar or different the models are in their representation of the public dataset.
Here’s what to do:

Use the flattened feature vectors to compute pairwise cosine similarity.
Optimize by calculating only the upper triangle of the similarity matrix (symmetric matrix).
Step 3: Clustering with DBSCAN
Goal: Group similar clients into clusters using DBSCAN.
Why: Clustering identifies groups of similar models and isolates outliers (noise clients).
Parameters:
ε (epsilon) = 0.15: Maximum distance between points in a cluster.
m_min = 1: Minimum number of points to form a cluster.
Step 4: Noise Client Management
Goal: Handle noise clients that don’t belong to any cluster.
Why: Noise clients can degrade the performance of the framework if not managed properly.
Steps:
Calculate internal similarity within noise clusters.
Reclassify noise clusters as valid if their internal similarity exceeds θ_noise = 0.65.
Align remaining noise clients with their most similar valid clusters.
Step 5: Federated Simulation
Goal: Integrate everything into a federated learning framework (e.g., Flower).
Why: This allows you to simulate the full FedISMH methodology, including model training, clustering, and aggregation.

## Proposed Order of Work

### Feature Vector Extraction:

Write a script to extract and store feature vectors for all clients.
### Similarity Matrix Construction:

Compute the similarity matrix for client feature vectors.
### Clustering and Noise Management:

Apply DBSCAN clustering and handle noise clients.
### Federated Simulation:

Use Flower to integrate the components into a federated learning simulation.



In [33]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, Subset
import numpy as np
import random
import os

# Configurations
DATA_DIR = os.path.join(os.path.expanduser("~"), "Desktop", "data")
SEED = 42
NUM_CLIENTS = 10
NON_IID_CLASSES_PER_CLIENT = 2

# Set random seed for reproducibility
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

# Transformations
transform_mnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

transform_svhn = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970))
])

# Load datasets
def load_datasets():
    mnist_train = torchvision.datasets.MNIST(root=DATA_DIR, train=True, download=True, transform=transform_mnist)
    mnist_test = torchvision.datasets.MNIST(root=DATA_DIR, train=False, download=True, transform=transform_mnist)

    svhn_train = torchvision.datasets.SVHN(root=DATA_DIR, split='train', download=True, transform=transform_svhn)
    svhn_test = torchvision.datasets.SVHN(root=DATA_DIR, split='test', download=True, transform=transform_svhn)

    return mnist_train, mnist_test, svhn_train, svhn_test

# IID Split
def iid_split(dataset, num_clients):
    data_per_client = len(dataset) // num_clients
    indices = torch.randperm(len(dataset))
    return {i: Subset(dataset, indices[i * data_per_client:(i + 1) * data_per_client]) for i in range(num_clients)}

# Non-IID Split
def get_targets(dataset):
    while isinstance(dataset, Subset):
        dataset = dataset.dataset
    if hasattr(dataset, 'targets'):
        return np.array(dataset.targets)
    elif hasattr(dataset, 'labels'):
        return np.array(dataset.labels)
    else:
        raise AttributeError("Dataset does not have 'targets' or 'labels' attribute.")

def noniid_split(dataset, num_clients, classes_per_client):
    full_targets = get_targets(dataset)
    if isinstance(dataset, Subset):
        subset_indices = dataset.indices
        targets = full_targets[subset_indices]
    else:
        targets = full_targets

    class_indices = {cls: np.where(targets == cls)[0] for cls in np.unique(targets)}
    client_indices = {i: [] for i in range(num_clients)}
    available_classes = list(class_indices.keys())
    random.shuffle(available_classes)

    for i in range(num_clients):
        selected_classes = random.sample(available_classes, classes_per_client)
        for cls in selected_classes:
            n_samples = max(1, len(class_indices[cls]) // num_clients)
            chosen = np.random.choice(class_indices[cls], size=n_samples, replace=True)
            client_indices[i].extend(chosen)

    return {i: Subset(dataset, client_indices[i]) for i in range(num_clients)}

# Dataset Preparation
def prepare_datasets(dataset, num_clients=NUM_CLIENTS, classes_per_client=NON_IID_CLASSES_PER_CLIENT, iid=True):
    total_indices = list(range(len(dataset)))
    random.shuffle(total_indices)

    test_len = int(0.2 * len(dataset))
    public_len = int(0.1 * (len(dataset) - test_len))
    private_len = len(dataset) - test_len - public_len

    test_indices = total_indices[:test_len]
    public_indices = total_indices[test_len:test_len + public_len]
    private_indices = total_indices[test_len + public_len:]

    test_set = Subset(dataset, test_indices)
    public_set = Subset(dataset, public_indices)
    private_set = Subset(dataset, private_indices)

    if iid:
        clients = iid_split(private_set, num_clients)
    else:
        clients = noniid_split(private_set, num_clients, classes_per_client)

    return {
        "public": public_set,
        "test": test_set,
        "clients": clients
    }

In [34]:
import os

FEATURE_DIR = "feature_vectors"

# Create the directory if it doesn't exist
if not os.path.exists(FEATURE_DIR):
    os.makedirs(FEATURE_DIR)
    print(f"Created directory: {FEATURE_DIR}")

# Now list files
files = os.listdir(FEATURE_DIR)
if files:
    print(f"Files in {FEATURE_DIR}: {files}")
else:
    print(f"Directory {FEATURE_DIR} is empty")

Files in feature_vectors: ['client_0_features.pt', 'client_1_features.pt', 'client_2_features.pt', 'client_3_features.pt', 'client_4_features.pt', 'client_5_features.pt', 'client_6_features.pt', 'client_7_features.pt', 'client_8_features.pt', 'client_9_features.pt']


In [35]:
import torch
from torch.utils.data import DataLoader
from typing import Dict
import os

# Configurations
BATCH_SIZE = 32
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FEATURE_DIR = "feature_vectors"  # Directory to save feature vectors

# Ensure FEATURE_DIR exists
if not os.path.exists(FEATURE_DIR):
    os.makedirs(FEATURE_DIR)

def extract_and_save_feature_vectors(
    client_models: Dict[str, torch.nn.Module],
    public_dataset: torch.utils.data.Dataset
):
    """
    Extract and save feature vectors for each client using the public dataset.

    Args:
        client_models (Dict[str, torch.nn.Module]): Dictionary of client names and models.
        public_dataset (torch.utils.data.Dataset): Public dataset to extract features.
    """
    # DataLoader for the public dataset
    public_loader = DataLoader(public_dataset, batch_size=BATCH_SIZE, shuffle=False)

    for client_name, model in client_models.items():
        print(f"Processing {client_name}...")

        # Move model to the appropriate device and set to evaluation mode
        model.to(DEVICE)
        model.eval()

        # List to store all projected feature vectors for the client
        all_features = []

        with torch.no_grad():
            for data, _ in public_loader:  # Ignore labels
                data = data.to(DEVICE)  # Move data to device
                try:
                    # **Extract projected feature vectors using the projection layer**
                    features = model.get_projected_features(data)
                    all_features.append(features.cpu())  # Move to CPU for saving
                except Exception as e:
                    print(f"Error extracting features for {client_name}: {e}")

        # Concatenate all features
        all_features = torch.cat(all_features, dim=0)
        print(f"Extracted {all_features.shape[0]} features for {client_name}. Final shape: {all_features.shape}")

        # Save features to a file
        feature_file = os.path.join(FEATURE_DIR, f"{client_name}_features.pt")
        torch.save(all_features, feature_file)
        print(f"Saved features for {client_name} to {feature_file}")

In [36]:
# Prepare the public dataset
mnist_train, _, _, _ = load_datasets()
datasets = prepare_datasets(mnist_train, iid=False)  # Non-IID split
public_dataset = datasets["public"]

# Create models for each client
factory = ModelFactory()
client_models = {
    f"client_{i}": factory.create_model(
        'simple' if i % 2 == 0 else 'resnet',  # Alternate between SimpleConvNet and ResNet
        in_channels=1,
        num_classes=10
    )
    for i in range(NUM_CLIENTS)
}

# Extract and save feature vectors
extract_and_save_feature_vectors(client_models, public_dataset)

  return np.array(dataset.targets)


Processing client_0...
Attaching hook to layer: conv2
Attaching hook to layer: fc1
Input shape: torch.Size([32, 1, 28, 28])
After conv1: torch.Size([32, 32, 14, 14])
Hook triggered for layer 'conv2' with output shape: torch.Size([32, 64, 14, 14])
After conv2: torch.Size([32, 64, 7, 7])
After flatten: torch.Size([32, 3136])
Hook triggered for layer 'fc1' with output shape: torch.Size([32, 512])
After fc1: torch.Size([32, 512])
After fc2: torch.Size([32, 10])
Flattening feature from layer 'conv2' with shape: torch.Size([32, 64, 14, 14])
Flattened feature shape: torch.Size([32, 12544])
Flattening feature from layer 'fc1' with shape: torch.Size([32, 512])
Flattened feature shape: torch.Size([32, 512])
Attaching hook to layer: conv2
Attaching hook to layer: fc1
Input shape: torch.Size([32, 1, 28, 28])
After conv1: torch.Size([32, 32, 14, 14])
Hook triggered for layer 'conv2' with output shape: torch.Size([32, 64, 14, 14])
After conv2: torch.Size([32, 64, 7, 7])
After flatten: torch.Size([32

In [37]:
import os
import torch

FEATURE_DIR = "feature_vectors"
for file in os.listdir(FEATURE_DIR):
    if file.endswith("_features.pt"):
        features = torch.load(os.path.join(FEATURE_DIR, file))
        print(f"{file}: {features.shape}")

client_0_features.pt: torch.Size([4800, 128])
client_1_features.pt: torch.Size([4800, 128])
client_2_features.pt: torch.Size([4800, 128])
client_3_features.pt: torch.Size([4800, 128])
client_4_features.pt: torch.Size([4800, 128])
client_5_features.pt: torch.Size([4800, 128])
client_6_features.pt: torch.Size([4800, 128])
client_7_features.pt: torch.Size([4800, 128])
client_8_features.pt: torch.Size([4800, 128])
client_9_features.pt: torch.Size([4800, 128])


In [38]:
import torch

sample_file = "feature_vectors/client_0_features.pt"
features = torch.load(sample_file)
print(f"Loaded features shape: {features.shape}")

Loaded features shape: torch.Size([4800, 128])


In [39]:
import torch
import numpy as np
import os

# Configurations
FEATURE_DIR = "feature_vectors"  # Directory where feature vectors are stored
OUTPUT_FILE = "similarity_matrix.npy"  # File to save the similarity matrix

def load_feature_vectors(feature_dir: str):
    """
    Load feature vectors for all clients from the specified directory.

    Args:
        feature_dir (str): Directory containing feature vector files.

    Returns:
        Dict[str, torch.Tensor]: Dictionary of client names and their feature vectors.
    """
    feature_vectors = {}
    for file in os.listdir(feature_dir):
        if file.endswith("_features.pt"):
            client_name = file.replace("_features.pt", "")
            try:
                features = torch.load(os.path.join(feature_dir, file))
                if features is None or features.numel() == 0:
                    print(f"Warning: Empty feature tensor for {client_name}")
                else:
                    print(f"Loaded features for {client_name}: {features.shape}")
                    feature_vectors[client_name] = features
            except Exception as e:
                print(f"Error loading features for {client_name}: {e}")
    return feature_vectors

def compute_similarity_matrix(feature_vectors: dict):
    """
    Compute pairwise cosine similarity matrix for all clients.

    Args:
        feature_vectors (dict): Dictionary of client names and their feature vectors.

    Returns:
        np.ndarray: Pairwise cosine similarity matrix.
    """
    # Collect all feature vectors into a single matrix
    client_names = list(feature_vectors.keys())
    if len(client_names) == 0:
        raise RuntimeError("No feature vectors loaded. Ensure the feature_vectors dictionary is not empty.")

    print("Client names:", client_names)
    
    all_features = []
    for client in client_names:
        features = feature_vectors[client]
        print(f"{client} features shape: {features.shape}")
        all_features.append(features)
    
    all_features = torch.stack(all_features)  # Shape: (num_clients, num_samples, feature_dim)

    # Average feature vectors across samples for each client
    avg_features = torch.mean(all_features, dim=1)  # Shape: (num_clients, feature_dim)

    # Normalize the feature vectors
    norms = torch.norm(avg_features, p=2, dim=1, keepdim=True)
    normalized_features = avg_features / norms  # Shape: (num_clients, feature_dim)

    # Compute cosine similarity matrix
    similarity_matrix = torch.mm(normalized_features, normalized_features.T).numpy()  # Shape: (num_clients, num_clients)

    return similarity_matrix, client_names

if __name__ == "__main__":
    # Step 1: Load feature vectors
    feature_vectors = load_feature_vectors(FEATURE_DIR)

    # Step 2: Compute similarity matrix
    similarity_matrix, client_names = compute_similarity_matrix(feature_vectors)

    # Step 3: Save similarity matrix
    np.save(OUTPUT_FILE, similarity_matrix)
    print(f"Similarity matrix saved to {OUTPUT_FILE}")

    # Print the matrix for verification
    print("\nSimilarity Matrix:")
    print(similarity_matrix)

    # Print client names for reference
    print("\nClient Names:")
    print(client_names)

Loaded features for client_0: torch.Size([4800, 128])
Loaded features for client_1: torch.Size([4800, 128])
Loaded features for client_2: torch.Size([4800, 128])
Loaded features for client_3: torch.Size([4800, 128])
Loaded features for client_4: torch.Size([4800, 128])
Loaded features for client_5: torch.Size([4800, 128])
Loaded features for client_6: torch.Size([4800, 128])
Loaded features for client_7: torch.Size([4800, 128])
Loaded features for client_8: torch.Size([4800, 128])
Loaded features for client_9: torch.Size([4800, 128])
Client names: ['client_0', 'client_1', 'client_2', 'client_3', 'client_4', 'client_5', 'client_6', 'client_7', 'client_8', 'client_9']
client_0 features shape: torch.Size([4800, 128])
client_1 features shape: torch.Size([4800, 128])
client_2 features shape: torch.Size([4800, 128])
client_3 features shape: torch.Size([4800, 128])
client_4 features shape: torch.Size([4800, 128])
client_5 features shape: torch.Size([4800, 128])
client_6 features shape: torch.

# What’s Left
Now that the feature vectors are ready, let’s move on to the next steps:

## Step 2: Similarity Matrix Construction

Compute a cosine similarity matrix between all clients using their feature vectors.
This matrix quantifies how similar or different the clients’ models are in their representation of the public dataset.
## Step 3: Clustering with DBSCAN

Use the similarity matrix to group clients into clusters using the DBSCAN algorithm.
Identify and handle noise clients (outliers).
## Step 4: Noise Client Management

Reclassify noise clusters as valid if their internal similarity exceeds a threshold.
Align remaining noise clients with the most similar valid clusters.
## Step 5: Federated Simulation

Simulate the entire FedISMH pipeline using the Flower framework.
Train, cluster, and aggregate models across clients in a federated learning setup.

### Code for Clustering Using DBSCAN

In [40]:
from sklearn.cluster import DBSCAN
import numpy as np

def cluster_clients(similarity_matrix: np.ndarray, client_names: list, eps=0.15, min_samples=1):
    """
    Cluster clients using DBSCAN based on the similarity matrix.

    Args:
        similarity_matrix (np.ndarray): Pairwise cosine similarity matrix.
        client_names (list): List of client names corresponding to the similarity matrix.
        eps (float): Maximum distance between points in a cluster.
        min_samples (int): Minimum number of points to form a cluster.

    Returns:
        dict: Dictionary with cluster assignments for each client.
    """
    # Clip similarity matrix values to the range [-1, 1]
    similarity_matrix = np.clip(similarity_matrix, -1, 1)

    # Convert similarity matrix to distance matrix (1 - similarity)
    distance_matrix = 1 - similarity_matrix

    # Ensure all values in the distance matrix are non-negative
    if np.min(distance_matrix) < 0:
        raise ValueError("Distance matrix contains negative values. Check similarity matrix normalization.")

    # Apply DBSCAN
    dbscan = DBSCAN(eps=eps, min_samples=min_samples, metric="precomputed")
    cluster_labels = dbscan.fit_predict(distance_matrix)

    # Map clients to their respective clusters
    clusters = {}
    for client, cluster in zip(client_names, cluster_labels):
        clusters[client] = cluster

    return clusters

if __name__ == "__main__":
    # Load the saved similarity matrix
    similarity_matrix = np.load("similarity_matrix.npy")

    # List of client names (ensure it's accessible in your script)
    client_names = [f"client_{i}" for i in range(len(similarity_matrix))]

    # Cluster clients
    eps = .87  # Maximum distance between points in a cluster
    min_samples = 1  # Minimum number of points to form a cluster
    try:
        clusters = cluster_clients(similarity_matrix, client_names, eps, min_samples)

        # Print cluster assignments
        print("\nCluster Assignments:")
        for client, cluster in clusters.items():
            print(f"{client}: Cluster {cluster}")
    except ValueError as e:
        print(f"Error: {e}")


Cluster Assignments:
client_0: Cluster 0
client_1: Cluster 0
client_2: Cluster 1
client_3: Cluster 0
client_4: Cluster 0
client_5: Cluster 0
client_6: Cluster 2
client_7: Cluster 3
client_8: Cluster 0
client_9: Cluster 4


# Step 4: Noise Client Management
In this step, we will:

1. Reclassify Noise Clients:

    * If any clients were marked as noise (Cluster -1), check their internal similarity.
    * If their internal similarity exceeds a threshold, reclassify them as valid clusters.

2. Align Noise Clients:

    * For remaining noise clients, align them with the most similar valid cluster based on their similarity values to other clusters.
# Step 4 Implementation
1. Reclassify Noise Clients
    * To reclassify noise clients, calculate the average pairwise similarity between these clients.
    * If the average similarity exceeds a predefined threshold (e.g., 0.5), treat them as a new valid cluster.
2. Align Remaining Noise Clients
    * For each remaining noise client:
        * Compare its similarity with the average feature vectors of all valid clusters.
        * Assign it to the cluster with the highest similarity.


In [41]:
import numpy as np

def manage_noise_clients(similarity_matrix: np.ndarray, cluster_labels: np.ndarray, threshold=0.5):
    """
    Handle noise clients by reclassifying or aligning them with valid clusters.

    Args:
        similarity_matrix (np.ndarray): Pairwise cosine similarity matrix.
        cluster_labels (np.ndarray): Cluster labels from DBSCAN (-1 for noise).
        threshold (float): Similarity threshold to reclassify noise clients.

    Returns:
        np.ndarray: Updated cluster labels after handling noise clients.
    """
    # Identify noise clients
    noise_clients = np.where(cluster_labels == -1)[0]
    valid_clusters = np.unique(cluster_labels[cluster_labels != -1])

    if len(noise_clients) == 0:
        print("No noise clients to manage.")
        return cluster_labels

    print(f"Found {len(noise_clients)} noise clients: {noise_clients}")

    # Step 1: Reclassify noise clients as a new cluster if internal similarity exceeds threshold
    reclassified_clients = []
    for i in noise_clients:
        for j in noise_clients:
            if i != j:
                avg_similarity = np.mean(similarity_matrix[noise_clients][:, noise_clients])
                if avg_similarity > threshold:
                    reclassified_clients.append(i)
                    break

    if reclassified_clients:
        new_cluster_label = cluster_labels.max() + 1
        for client in reclassified_clients:
            cluster_labels[client] = new_cluster_label
        print(f"Reclassified {len(reclassified_clients)} noise clients into new cluster {new_cluster_label}.")

    # Step 2: Align remaining noise clients with the most similar valid cluster
    remaining_noise_clients = np.where(cluster_labels == -1)[0]
    if remaining_noise_clients.size > 0:
        print(f"Aligning {len(remaining_noise_clients)} remaining noise clients...")
        for client in remaining_noise_clients:
            # Compute similarity with each valid cluster
            cluster_similarities = []
            for cluster in valid_clusters:
                cluster_members = np.where(cluster_labels == cluster)[0]
                avg_similarity = np.mean(similarity_matrix[client, cluster_members])
                cluster_similarities.append((cluster, avg_similarity))

            # Assign to the most similar cluster
            best_cluster = max(cluster_similarities, key=lambda x: x[1])[0]
            cluster_labels[client] = best_cluster
            print(f"Aligned client {client} to cluster {best_cluster}.")

    return cluster_labels

if __name__ == "__main__":
    # Load the similarity matrix and cluster labels
    similarity_matrix = np.load("similarity_matrix.npy")

    # Example cluster labels (replace this with your actual DBSCAN output)
    cluster_labels = np.array([0, 0, -1, 0, 0, 0, 2, 3, 0, 4])

    # Handle noise clients
    updated_labels = manage_noise_clients(similarity_matrix, cluster_labels, threshold=0.5)

    # Print updated cluster assignments
    print("\nUpdated Cluster Assignments:")
    for i, label in enumerate(updated_labels):
        print(f"client_{i}: Cluster {label}")

Found 1 noise clients: [2]
Aligning 1 remaining noise clients...
Aligned client 2 to cluster 2.

Updated Cluster Assignments:
client_0: Cluster 0
client_1: Cluster 0
client_2: Cluster 2
client_3: Cluster 0
client_4: Cluster 0
client_5: Cluster 0
client_6: Cluster 2
client_7: Cluster 3
client_8: Cluster 0
client_9: Cluster 4


# Step 5: Federated Simulation
In this step, we will:

* Simulate the FedISMH pipeline using the Flower framework.
* Train, cluster, and aggregate models across clients in a federated learning setup.

## Overview of the FedISMH Pipeline

1. Train Federated Models:

    * Each client trains its local model on its own data.
    * Models are transmitted to the server for aggregation.
2. Cluster Clients:
    * Use the similarity matrix and clustering results to group clients dynamically during simulation.

3. Aggregate Models:
    * Aggregate models within clusters to form cluster-specific models.
    * Align training with client clusters for future rounds.
4. Repeat:

    * Iterate through multiple rounds of training, clustering, and aggregation.

### Define the Server

In [None]:
import flwr as fl
from flwr.server import Server
from flwr.common import EventType

# Define a custom strategy for the Federated Server
class CustomStrategy(fl.server.strategy.FedAvg):
    def configure_fit(self, rnd, parameters, client_manager):
        print(f"Configuring round {rnd}...")
        return super().configure_fit(rnd, parameters, client_manager)

    def aggregate_fit(self, rnd, results, failures):
        print(f"Aggregating results for round {rnd}...")
        return super().aggregate_fit(rnd, results, failures)

def main() -> Server:
    # Create strategy and server
    strategy = CustomStrategy()
    
    # Create the server with the strategy
    server = Server(strategy=strategy)
    
    # Add event handlers
    def on_server_start(event):
        print("Server started!")
    
    def on_round_start(event):
        print(f"Round {event.round_number} started!")
    
    server.add_event_handler(EventType.SERVER_START, on_server_start)
    server.add_event_handler(EventType.ROUND_START, on_round_start)
    
    return server

if __name__ == "__main__":
    print("\nTo start the server, run this command in the terminal:")
    print("flower-superlink --insecure --fleet-api-address=\"127.0.0.1:9092\" --serverappio-api-address=\"127.0.0.1:9093\"")

### Define the Client

In [None]:
import flwr as fl
import torch
from torch.utils.data import DataLoader
import argparse

# Import your models and data preparation functions
from model import ModelFactory
from dataset import load_datasets, prepare_datasets

class FedISMHClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader, device):
        self.model = model
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.optimizer = torch.optim.Adam(model.parameters())
        self.criterion = torch.nn.CrossEntropyLoss()

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        self.model.train()
        
        for _ in range(config.get("epochs", 1)):
            for data, target in self.train_loader:
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()

        return self.get_parameters(config), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct = total = 0

        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
                total += target.size(0)

        accuracy = correct / total
        return float(self.criterion(output, target)), total, {"accuracy": accuracy}

def main(client_id: int):
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load and prepare datasets
    mnist_train, mnist_test, _, _ = load_datasets()
    datasets = prepare_datasets(mnist_train, iid=False)
    
    # Create client's model
    factory = ModelFactory()
    model = factory.create_model(
        'simple' if client_id % 2 == 0 else 'resnet',
        in_channels=1,
        num_classes=10
    ).to(device)
    
    # Create data loaders for this client
    train_loader = DataLoader(datasets['clients'][client_id], batch_size=32, shuffle=True)
    test_loader = DataLoader(datasets['test'], batch_size=32)
    
    # Start client
    client = FedISMHClient(model, train_loader, test_loader, device)
    fl.client.start_numpy_client(
        server_address="127.0.0.1:9092",
        client=client
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Start a Flower client")
    parser.add_argument("--client-id", type=int, required=True, help="Client ID (0-9)")
    args = parser.parse_args()
    
    main(args.client_id)

# NOTE: These commands were used to run the federated system

In [None]:
# For Server
flower-superlink --insecure --fleet-api-address="127.0.0.1:9092" --serverappio-api-address="127.0.0.1:9093" --executor-config="executor_config.py"

# Client 1
flower-supernode --insecure --superlink="127.0.0.1:9092" --node-config="partition-id=0 num-partitions=5" --clientappio-api-address 0.0.0.0:9094

# Client 2
flower-supernode --insecure --superlink="127.0.0.1:9092" --node-config="partition-id=1 num-partitions=5" --clientappio-api-address 0.0.0.0:9095

# Client 3
flower-supernode --insecure --superlink="127.0.0.1:9092" --node-config="partition-id=2 num-partitions=5" --clientappio-api-address 0.0.0.0:9096

# Client 4
flower-supernode --insecure --superlink="127.0.0.1:9092" --node-config="partition-id=3 num-partitions=5" --clientappio-api-address 0.0.0.0:9097

# Client 5
flower-supernode --insecure --superlink="127.0.0.1:9092" --node-config="partition-id=4 num-partitions=5" --clientappio-api-address 0.0.0.0:9098