In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.model_selection import StratifiedShuffleSplit
import torch

In [None]:
# Seed PyTorch and NumPy with the same value so that both the sampled data
# and the initial network weights are repeatable.
SEED = 42424242
torch.manual_seed(SEED)
np.random.seed(SEED)

--------------------------------------------------------------------------------------------------------------------------------
### Generation of Data - Synthetic Data
--------------------------------------------------------------------------------------------------------------------------------

In [None]:
# Parameters of the latent Gaussians: one mean vector and one covariance
# matrix per class, paired by position in the two tuples.
mean_class0 = np.array([-2., -2.])
mean_class1 = np.array([2., 2.])
cov_class0 = np.array([[0.5, 0.], [0., 0.5]])
cov_class1 = np.array([[0.3, 0.1], [0.1, 0.3]])

mus = (mean_class0, mean_class1)
covs = (cov_class0, cov_class1)

In [None]:
n_features = 2
n_classes = len(mus)
n_samples_per_class = 500

# Draw one cloud of points per class from that class's Gaussian; every point
# is labelled with the index of the Gaussian it was drawn from.
features = []
labels = []
for i in range(n_classes):
    class_points = np.random.multivariate_normal(
        mean=mus[i], cov=covs[i], size=(n_samples_per_class,)
    )
    features.append(class_points)
    labels.append(n_samples_per_class * [i])

In [None]:
# Merge the per-class draws into a single (n_total, n_features) matrix and a
# matching label array.
n_total = n_classes*n_samples_per_class
features_orig = np.asarray(features).reshape(n_total, n_features)
labels_orig = np.asarray(labels).reshape(n_total, 1).squeeze()

## View latent dataset

In [None]:
# Scatter the two latent coordinates against each other, coloured by class.
latent_x = features_orig[:,0]
latent_y = features_orig[:,1]
plt.scatter(latent_x, latent_y, c=labels_orig)

## Warp data

In [None]:
# Toggle read by the "Simpler warp functions" cell: when True, warp1/warp2 are
# redefined with the simpler versions; when False the difficult versions
# defined first remain in effect.
use_simple_warp_functions = False

In [None]:
# Difficult warp functions
def warp1(a):
    """Warp 2-column latent points into three non-linear features.

    `a` is indexed as a[:, 0] (x) and a[:, 1] (y). The three outputs are
    stacked along a new last axis, so a (n, 2) input yields a (n, 3) result.
    The log term is -inf wherever x or y is exactly zero.
    """
    x, y = a[:,0], a[:,1]
    product = x*y
    shifted = x + np.exp(-y/10.)
    mixed = (np.log(x*x*y*y) + 10*x*x*y - x)/100.
    return np.stack([product, shifted, mixed], axis=-1)

def warp2(a):
    """Warp 2-column latent points into four non-linear features.

    `a` is indexed as a[:, 0] (x) and a[:, 1] (y). The four outputs are
    stacked along a new last axis, so a (n, 2) input yields a (n, 4) result.
    The log term is -inf wherever x or y is exactly zero.
    """
    x, y = a[:,0], a[:,1]
    log_mix = np.log(x*x*y*y) + 10*x*y*np.sin(x)
    linear_mix = y - 10*x*y
    tan_mix = x*y*np.tan(y)
    trig_diff = np.sin(x)-np.cos(y)
    return np.stack([log_mix, linear_mix, tan_mix, trig_diff], axis=-1)

In [None]:
# Simpler warp functions
if use_simple_warp_functions:
    def warp1(a):
        """Simple variant of warp1: 2-column latent points -> 3 features."""
        x, y = a[:,0], a[:,1]
        return np.stack([x + y, 10*x, y - x], axis=-1)

    def warp2(a):
        """Simple variant of warp2: 2-column latent points -> 4 features."""
        x, y = a[:,0], a[:,1]
        return np.stack([y - 5*x, 10*y, y + x*x, x + y], axis=-1)

In [None]:
# One warp function per client, in client order: entry k is applied to the
# samples assigned to client k.
warping_fns = [warp1, warp2]

## Split data among clients

In [None]:
import torch
from torch.utils.data import TensorDataset, DataLoader

In [None]:
from sklearn.utils import shuffle

n_clients = 2
n_samples_per_client = n_samples_per_class*n_classes//n_clients

# Every client is warped with its own function, so there must be at least as
# many warp functions as clients.
if not 1 <= n_clients <= len(warping_fns):
    raise ValueError(
        f"n_clients must be between 1 and {len(warping_fns)}, got {n_clients}"
    )

# Shuffle features and labels together before cutting contiguous slices.
# NOTE: this rebinds features_orig/labels_orig, so re-running this cell on its
# own shuffles the already-shuffled arrays again.
features_orig, labels_orig = shuffle(features_orig, labels_orig)

client_datasets = []
for i_client in range(n_clients):
    start = i_client*n_samples_per_client
    # The last client's slice is open-ended, so it also receives the remainder
    # when the sample count is not divisible by n_clients.
    stop = start + n_samples_per_client if i_client < n_clients - 1 else None
    client_features = warping_fns[i_client](features_orig[start:stop, :])
    client_labels = labels_orig[start:stop]
    client_datasets.append(
        TensorDataset(
            torch.tensor(client_features, dtype=torch.float32),
            torch.tensor(client_labels),
        )
    )

## View dataset for client 1

In [None]:
# Display the first (features, label) pair of the first client's dataset.
client_datasets[0][0]

In [None]:
fig, ax = plt.subplots(1,3,figsize=(10,3))
features_np = client_datasets[0].tensors[0].numpy()
labels_np = client_datasets[0].tensors[1].numpy()
# One panel per pair of warped feature columns, coloured by class label.
for panel, (col_x, col_y) in zip(ax, [(0, 1), (0, 2), (1, 2)]):
    panel.scatter(features_np[:,col_x], features_np[:,col_y], c=labels_np)
    panel.set_xlabel(f"feature {col_x}")
    panel.set_ylabel(f"feature {col_y}")
fig.suptitle("Client 1: warped features")
fig.tight_layout()

## View dataset for client 2

In [None]:
fig, ax = plt.subplots(2,3,figsize=(10,6))
features_np = client_datasets[1].tensors[0].numpy()
labels_np = client_datasets[1].tensors[1].numpy()
# One panel per pair of warped feature columns, filled row by row and
# coloured by class label.
column_pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
for panel, (col_x, col_y) in zip(ax.flat, column_pairs):
    panel.scatter(features_np[:,col_x], features_np[:,col_y], c=labels_np)
    panel.set_xlabel(f"feature {col_x}")
    panel.set_ylabel(f"feature {col_y}")
fig.suptitle("Client 2: warped features")
fig.tight_layout()

## Definition of clients

In [None]:
from torch import nn
import torch.nn.functional as F

In [None]:
class Client(nn.Module):
    """A federated client: encoder -> latent model -> decoder, plus its own optimizer.

    Parameters
    ----------
    encoder_model : nn.Module
        Client-specific encoder; its output must have `encoder_dim` features.
    lr : float, default 0.001
        Learning rate of the SGD optimizer created for this client.
    encoder_dim : int, default 10
        Number of features produced by `encoder_model`.
    latent_dim : int, default 2
        Width of the latent space returned by `get_latent_space`.
    n_classes : int, default 2
        Number of outputs produced by `forward`.

    The latent and decoder stacks consist of linear layers only, with no
    activation functions in between.
    """

    def __init__(self, encoder_model, lr=0.001, encoder_dim=10, latent_dim=2, n_classes=2):
        super().__init__()

        self.encoder_model = encoder_model

        self.latent_model = nn.Sequential(
            nn.Linear(encoder_dim, 5),
            nn.Linear(5, latent_dim),
        )

        self.decoder_model = nn.Sequential(
            nn.Linear(latent_dim, 4),
            nn.Linear(4, n_classes),
        )

        # Created last so that self.parameters() already covers the encoder,
        # latent and decoder sub-models.
        self.optimizer = torch.optim.SGD(self.parameters(), lr=lr)

    def forward(self, inputs):
        """Run encoder, latent model and decoder; return the decoder outputs."""
        outputs_encoder = self.encoder_model(inputs)
        outputs_latent = self.latent_model(outputs_encoder)
        outputs_class = self.decoder_model(outputs_latent)
        return outputs_class

    def get_latent_space(self, inputs):
        """Return the latent-model outputs for `inputs`.

        Side effect: switches the module to eval mode and leaves it there.
        """
        self.eval()
        outputs_encoder = self.encoder_model(inputs)
        return self.latent_model(outputs_encoder)

In [None]:
# One encoder per client, described as (input width, hidden width). Both
# encoders end in 10 output features.
encoder_specs = ((3, 20), (4, 30))
encoder_models = [
    nn.Sequential(nn.Linear(n_in, n_hidden), nn.Linear(n_hidden, 10))
    for n_in, n_hidden in encoder_specs
]

In [None]:
# Build one Client per encoder, in the same order as encoder_models.
clients = [Client(encoder_model=e) for e in encoder_models]

## DataLoaders

In [None]:
batch_size = 5
# Named shuffle_batches rather than shuffle so that the flag does not overwrite
# the sklearn.utils.shuffle function imported under that name in the
# client-split cell.
shuffle_batches = True
# One DataLoader per client dataset, in client order.
loaders = [
    DataLoader(dataset=d, batch_size=batch_size, shuffle=shuffle_batches)
    for d in client_datasets
]

## Training loop

In [None]:
# Total number of samples over all clients; each client's dataset size is
# divided by this to obtain its Federated Averaging weight.
tot_num_samples = np.sum([len(d) for d in client_datasets])

In [None]:
# Loss used in the training loop; it is applied to the clients' raw outputs
# and one-hot encoded targets.
loss_fn = nn.BCEWithLogitsLoss()

In [None]:
# Sanity check: evaluate the loss on a single two-class sample whose input
# values equal its target values.
loss_fn(
    torch.tensor([[1., 0.]]),
    torch.tensor([[1., 0.]])
)

In [None]:
# One "dummy" Federated Averaging step before training, so that every client
# starts from the same latent model: each latent parameter becomes the
# dataset-size-weighted average over all clients.
weights = [len(d)/tot_num_samples for d in client_datasets]
latent_states = [client_.latent_model.state_dict() for client_ in clients]

new_state_dict = {}
for param_name in latent_states[0]:
    avg_params = weights[0]*latent_states[0][param_name].detach()
    for state_, weight_ in zip(latent_states[1:], weights[1:]):
        avg_params = avg_params + weight_*state_[param_name].detach()
    new_state_dict[param_name] = avg_params

for client_ in clients:
    client_.latent_model.load_state_dict(new_state_dict)

print(f"After FedAVG {[client_.latent_model.state_dict()['0.weight'][0][:1] for client_ in clients]}")

In [None]:
# When True, the training loop averages the clients' latent models after
# every round; when False each client keeps its own latent model.
do_fedavg = True

In [None]:
num_epochs_per_round = 2
num_rounds = 100
# losses[k] collects one average loss (a plain float) per local epoch of client k.
losses = [[] for _ in clients]
# FedAvg weights are each client's share of the total sample count. They do
# not change between rounds, so they are computed once.
weights = [float(len(d)/tot_num_samples) for d in client_datasets]

for round_ in range(1,num_rounds+1):

    # Local training: every client runs num_epochs_per_round epochs on its own data.
    for i_client, (client_, loader_) in enumerate(zip(clients, loaders)):
        client_.train()
        for epoch_ in range(1,num_epochs_per_round+1):
            avg_loss = 0.
            for data, label in loader_:
                client_.optimizer.zero_grad()
                # Call the module itself rather than .forward() so that
                # nn.Module.__call__ (and any registered hooks) is used.
                prediction = client_(data)
                # One-hot encode the integer labels to match the shape of the
                # outputs, as the loss expects float targets of that shape.
                target = F.one_hot(label.long(), num_classes=prediction.shape[1]).float()
                loss = loss_fn(prediction, target)
                loss.backward()
                client_.optimizer.step()
                # .item() gives a plain Python float, so the loss history does
                # not hold on to tensors.
                avg_loss += loss.item()
            avg_loss /= len(loader_)
            losses[i_client].append(avg_loss)

    # Federated Averaging: replace every client's latent model by the
    # dataset-size-weighted average of all clients' latent models.
    if do_fedavg:
        new_state_dict = dict()
        for param_name, params in clients[0].latent_model.state_dict().items():
            avg_params = weights[0]*params.detach()
            for client_, weight_ in zip(clients[1:], weights[1:]):
                avg_params = avg_params + weight_*client_.latent_model.state_dict()[param_name].detach()
            new_state_dict[param_name] = avg_params
        for client_ in clients:
            client_.latent_model.load_state_dict(new_state_dict)

In [None]:
# One loss curve per client, however many clients there are; each point is the
# average loss of one local epoch. Legend labels are 1-based.
for client_number, client_losses in enumerate(losses, start=1):
    plt.plot(client_losses, label=f'client {client_number}')
plt.xlabel('local epoch')
plt.ylabel('average training loss')
plt.legend();

## Plot latent space

In [None]:
# squeeze=False makes subplots always return a 2-D array of axes, so the loop
# also works when n_clients == 1.
fig, axs = plt.subplots(1, n_clients, figsize=(14,6), squeeze=False)
for i, ax in enumerate(axs.flat):
    client_ = clients[i]
    client_.eval()
    dataset = client_datasets[i]
    data, latent_labels = dataset.tensors
    # Embed the client's whole dataset in one batched pass; gradients are not
    # needed for plotting.
    with torch.no_grad():
        latent_space = client_.get_latent_space(data).numpy()
    ax.scatter(latent_space[:,0], latent_space[:,1], c=latent_labels.numpy())
    ax.set_title(f"client {i+1}")
    ax.set_xlabel("latent dimension 0")
    ax.set_ylabel("latent dimension 1")
fig.tight_layout()