In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

In [None]:
# Simulated data parameters
SEED = 42  # fixed seed so Restart Kernel -> Run All reproduces the same run
n_samples = 1000
n_features_client1 = 10  # e.g., demographic + vitals
n_features_client2 = 15  # e.g., lab tests

# Seed torch's global RNG before any random draw. The tensors below, the
# default nn.Linear weight initialisation in later cells, and the shuffle order
# of a DataLoader built without an explicit `generator` all draw from it.
torch.manual_seed(SEED)

# Simulate input data and labels (vertically partitioned: same rows, disjoint feature sets)
x1 = torch.randn(n_samples, n_features_client1)  # Client 1 input
x2 = torch.randn(n_samples, n_features_client2)  # Client 2 input
y = torch.randint(0, 2, (n_samples, 1)).float()   # Binary labels in {0., 1.}

# Cheap invariants: both clients must describe the same n_samples rows.
assert x1.shape == (n_samples, n_features_client1)
assert x2.shape == (n_samples, n_features_client2)
assert y.shape == (n_samples, 1)

# Dataset and DataLoader
batch_size = 64
dataset = TensorDataset(x1, x2, y)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [None]:
# Client-side MLPs (bottom models)
class _ClientMLP(nn.Module):
    """Bottom model shared by both clients: maps one client's features to an embedding.

    Architecture: Linear(in_features, hidden_dim) -> ReLU -> Dropout(dropout)
    -> Linear(hidden_dim, embed_dim) -> ReLU.

    The layers are stored in ``self.model`` (an ``nn.Sequential``), so the
    state_dict keys (``model.0.weight``, ``model.3.bias``, ...) are the same as
    the original per-client classes.

    Args:
        in_features: number of input features for this client.
        hidden_dim: width of the hidden layer (default 32).
        embed_dim: width of the output embedding (default 16). The server's
            input width must equal the sum of both clients' embed_dim.
        dropout: dropout probability after the first layer (default 0.3).
    """

    def __init__(self, in_features: int, hidden_dim: int = 32, embed_dim: int = 16,
                 dropout: float = 0.3):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the (batch, embed_dim) embedding for a (batch, in_features) input."""
        return self.model(x)


# Client 1 MLP
class Client1Model(_ClientMLP):
    """Client 1 bottom model (e.g., demographic + vitals features)."""

    def __init__(self, in_features=None, **kwargs):
        # Resolve the notebook global at construction time (as the original
        # did), so re-running the data cell with a new feature count is honoured.
        super().__init__(n_features_client1 if in_features is None else in_features, **kwargs)


# Client 2 MLP
class Client2Model(_ClientMLP):
    """Client 2 bottom model (e.g., lab-test features)."""

    def __init__(self, in_features=None, **kwargs):
        super().__init__(n_features_client2 if in_features is None else in_features, **kwargs)


In [None]:
# Server MLP (top model)
class ServerModel(nn.Module):
    """Top model: maps the concatenated client embeddings to a probability.

    Architecture: Linear(in_features, 32) -> ReLU -> Dropout(0.3)
    -> Linear(32, 16) -> ReLU -> Linear(16, 1) -> Sigmoid.

    Args:
        in_features: width of the concatenated client embeddings. Defaults to
            32 (16 + 16 from the two clients); must be changed together with
            the clients' embedding sizes.

    Output has shape (batch, 1) with values in [0, 1] because of the final
    Sigmoid, which is what nn.BCELoss expects.

    NOTE(review): nn.BCEWithLogitsLoss on raw logits is more numerically stable
    than Sigmoid + BCELoss. Switching requires dropping the Sigmoid here and
    changing the loss in the same edit.
    """

    def __init__(self, in_features: int = 32):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features, 32),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return (batch, 1) probabilities for a (batch, in_features) input."""
        return self.model(x)

In [None]:
# Build the two client bottom models and the server top model.
# Tuple unpacking evaluates left to right, so construction order is client1, client2, server.
client1, client2, server = Client1Model(), Client2Model(), ServerModel()

# A single Adam optimizer over every parameter of all three parts, so one
# optimizer step updates the whole split network jointly.
params = [p for part in (client1, client2, server) for p in part.parameters()]
optimizer = optim.Adam(params, lr=1e-3)

# Binary cross-entropy on probabilities (ServerModel ends in a Sigmoid).
criterion = nn.BCELoss()


In [None]:
# Training loop
# NOTE(review): this simulates split learning in one process. Both clients and the
# server share a single optimizer and pass tensors directly, so there is no
# real client/server separation here.
epochs = 20
loss_history = []  # mean training loss per epoch, for plotting/inspection later

# Explicitly (re-)enable training mode so Dropout is active even if any model
# was switched to eval mode earlier in the kernel session.
for module in (client1, client2, server):
    module.train()

for epoch in range(epochs):
    running_loss = 0.0
    for batch_x1, batch_x2, batch_y in dataloader:
        # Forward pass: each client embeds its own features; the server
        # consumes the concatenated embeddings.
        z1 = client1(batch_x1)
        z2 = client2(batch_x2)
        z = torch.cat((z1, z2), dim=1)
        y_pred = server(z)

        # Loss computation
        loss = criterion(y_pred, batch_y)

        # Backward pass: gradients flow from the server through z into both
        # clients; the single optimizer step updates all three models.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean (default reduction); weight it by
        # batch size so the epoch average stays exact when the last batch is smaller.
        running_loss += loss.item() * batch_x1.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    loss_history.append(epoch_loss)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}")