<a href="https://colab.research.google.com/github/wuwewij/bachelor_thesis/blob/main/%E5%88%86%E5%B8%83%E8%AE%AD%E7%BB%83.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


class AutoEncoder(nn.Module):
    """Small convolutional autoencoder for single-channel 2-D inputs.

    Encoder: two 3x3, stride-2, padding-1 convolutions (1 -> 8 -> 16
    channels), each followed by ReLU; each one maps a spatial size ``S`` to
    ``ceil(S / 2)``.

    Decoder: two 3x3, stride-2 transposed convolutions with
    ``output_padding=1`` (16 -> 8 -> 1 channels), each exactly doubling the
    spatial size, with ReLU after the first and a Sigmoid on the output.

    The reconstruction therefore has the input's height/width only when those
    are multiples of 4.
    """

    def __init__(self):
        super().__init__()
        down = dict(kernel_size=3, stride=2, padding=1)
        up = dict(down, output_padding=1)
        # Layer order (and therefore state_dict keys encoder.0/2, decoder.0/2)
        # is part of the checkpoint format -- keep it stable.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, **down), nn.ReLU(),
            nn.Conv2d(8, 16, **down), nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(16, 8, **up), nn.ReLU(),
            nn.ConvTranspose2d(8, 1, **up), nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the reconstruction ``decoder(encoder(x))``."""
        return self.decoder(self.encoder(x))


def initialize_model():
    """Build and return a freshly initialised ``AutoEncoder``."""
    return AutoEncoder()


def train_model(model, train_loader, criterion, optimizer, num_epochs=25):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    Each batch yielded by the loader may be:

    * an ``(inputs, targets)`` pair -- ordinary supervised training;
    * a one-element sequence ``(inputs,)`` -- how ``DataLoader`` batches a
      single-tensor ``TensorDataset`` -- or a bare tensor. In that case the
      inputs are their own targets (reconstruction / autoencoder training).

    After every epoch the mean batch loss of that epoch is printed.

    Args:
        model: module to optimise; it is put into training mode.
        train_loader: iterable of batches (typically a ``DataLoader``).
        criterion: callable ``criterion(outputs, targets)`` returning a scalar
            loss tensor.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        list[float]: mean batch loss for each epoch (``nan`` for an epoch in
        which the loader yielded no batches).
    """
    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        n_batches = 0
        for batch in train_loader:
            inputs, targets = _split_batch(batch)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        # Average over the epoch rather than reporting only the last batch,
        # and stay well-defined when the loader is empty.
        mean_loss = total_loss / n_batches if n_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}')
    return epoch_losses


def _split_batch(batch):
    """Return ``(inputs, targets)``; an unpaired input is its own target."""
    if isinstance(batch, (list, tuple)):
        if len(batch) == 1:
            return batch[0], batch[0]
        inputs, targets = batch  # strict: anything but a pair still raises
        return inputs, targets
    return batch, batch


def cold_start(coordinator_model, worker_models, train_loader, num_workers=10, num_epochs=25):
    """Run one round of weight averaging seeded by the coordinator.

    Every worker is loaded with the coordinator's current weights and trained
    locally for ``num_epochs`` on ``train_loader``, using Adam (lr=0.001) and
    MSE loss. The element-wise mean of the resulting worker weights is then
    loaded into ``coordinator_model``.

    Args:
        coordinator_model: model whose ``state_dict`` seeds the workers and
            which receives the averaged weights.
        worker_models: non-empty sequence of models with the same
            architecture as ``coordinator_model``; trained in place.
        train_loader: batches passed to ``train_model`` for every worker.
        num_workers: retained for backward compatibility. The average is
            always taken over ``len(worker_models)``, so a mismatching value
            can no longer cause an IndexError or a wrongly scaled average.
        num_epochs: local epochs per worker.

    Raises:
        ValueError: if ``worker_models`` is empty.
    """
    if not worker_models:
        raise ValueError("worker_models must not be empty")

    global_weights = coordinator_model.state_dict()
    for worker_model in worker_models:
        worker_model.load_state_dict(global_weights)

    criterion = nn.MSELoss()
    for worker_model in worker_models:
        # Fresh optimizer per worker so no Adam state leaks between them.
        optimizer = optim.Adam(worker_model.parameters(), lr=0.001)
        train_model(worker_model, train_loader, criterion, optimizer, num_epochs)

    # state_dict() tensors share storage with the live parameters; clone them
    # so aggregation can never modify a worker model as a side effect.
    new_weights = [
        {key: value.detach().clone() for key, value in worker_model.state_dict().items()}
        for worker_model in worker_models
    ]
    global_weights = aggregate_weights(new_weights, len(new_weights))
    coordinator_model.load_state_dict(global_weights)

def aggregate_weights(weights, num_workers=None):
    """Return the element-wise mean of the first ``num_workers`` state dicts.

    Args:
        weights: sequence of ``key -> tensor`` mappings (e.g. model state
            dicts) sharing the same keys and tensor shapes.
        num_workers: number of leading entries of ``weights`` to average;
            defaults to ``len(weights)``.

    Returns:
        dict: a new ``key -> tensor`` mapping. Floating-point tensors hold the
        mean; non-floating tensors are averaged in float64 and cast back to
        their original dtype. The inputs are never modified. This matters
        because ``state_dict()`` tensors alias model parameters, so
        accumulating into ``weights[0]`` would overwrite the first model's
        weights.

    Raises:
        ValueError: if ``num_workers`` is not positive or exceeds
            ``len(weights)``.
    """
    if num_workers is None:
        num_workers = len(weights)
    if num_workers <= 0:
        raise ValueError("num_workers must be positive")
    if num_workers > len(weights):
        raise ValueError(
            f"num_workers={num_workers} but only {len(weights)} weight dicts were given"
        )

    selected = weights[:num_workers]
    avg_weights = {}
    for key, first in selected[0].items():
        stacked = torch.stack([w[key] for w in selected])
        if first.is_floating_point() or first.is_complex():
            avg_weights[key] = stacked.mean(dim=0)
        else:
            # Integer/bool buffers: average in float64, then restore the dtype
            # so the result matches what load_state_dict expects.
            avg_weights[key] = stacked.double().mean(dim=0).to(first.dtype)
    return avg_weights

def train_autoencoder(coordinator_model, worker_models, train_loader, num_workers=10, num_epochs=25):
    """Federated training pipeline with autoencoder-based weight compression.

    1. ``cold_start``: one averaging round that seeds, trains and aggregates
       ``worker_models`` into ``coordinator_model``.
    2. Build a dataset from the trained workers' weights
       (``create_dataset_from_weights``) and train a fresh ``AutoEncoder`` on
       it for ``num_epochs``, using Adam (lr=0.001), MSE loss and batch size 4.
    3. Reset every worker to the coordinator's weights and train it locally
       again. Then encode its weights with the autoencoder's encoder and hand
       them, together with the decoder, to
       ``send_compressed_weights_to_coordinator``.

    Args:
        coordinator_model: global model; updated by step 1.
        worker_models: worker models; trained in place in steps 1 and 3.
        train_loader: local training batches used for every worker.
        num_workers: forwarded to ``cold_start``.
        num_epochs: epochs for worker training and for autoencoder training.

    NOTE(review): nothing from step 3 is fed back into ``coordinator_model``,
    and the result of ``send_compressed_weights_to_coordinator`` is ignored.
    Confirm the intended aggregation.

    NOTE(review): the AutoEncoder's first layer is ``nn.Conv2d(1, ...)``, so it
    expects ``(N, 1, H, W)`` batches. Confirm that the samples produced by
    ``create_dataset_from_weights`` have that shape.
    """
    cold_start(coordinator_model, worker_models, train_loader, num_workers, num_epochs)

    # Step 2: learn a compressor for weight tensors.
    dataset_S1 = create_dataset_from_weights(worker_models, train_loader)
    autoencoder = initialize_model()
    optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    train_model(autoencoder, DataLoader(dataset_S1, batch_size=4, shuffle=True), criterion, optimizer, num_epochs)
    autoencoder.eval()  # from here on the autoencoder is only used for inference
    encoder = autoencoder.encoder
    decoder = autoencoder.decoder

    # Step 3: local training from the global weights, then compressed upload.
    global_weights = coordinator_model.state_dict()
    for worker_model in worker_models:
        worker_model.load_state_dict(global_weights)

    for worker_model in worker_models:
        optimizer = optim.Adam(worker_model.parameters(), lr=0.001)
        train_model(worker_model, train_loader, criterion, optimizer, num_epochs)
        # Encoding/decoding needs no gradients; avoid building autograd graphs
        # over every weight tensor.
        with torch.no_grad():
            compressed_weights = compress_weights(worker_model, encoder)
            send_compressed_weights_to_coordinator(compressed_weights, decoder)

def create_dataset_from_weights(worker_models, train_loader):
    """Collect the workers' weights into a ``TensorDataset``.

    Every tensor in every worker's ``state_dict()`` is flattened into one row.
    Rows are ordered worker by worker, then in ``state_dict`` order. Rows
    shorter than the longest one are right-padded with zeros, so all rows
    stack into a single 2-D tensor of shape ``(num_rows, max_len)``. When all
    tensors have the same number of elements, no padding is added and the
    result equals a plain stack.

    Args:
        worker_models: iterable of modules whose ``state_dict()`` is read.
        train_loader: unused; kept so existing callers keep working.

    Returns:
        TensorDataset: wraps the 2-D tensor, or an empty 1-D tensor when
        there are no rows.
    """
    # .cpu() so CUDA-resident weights work too (the old .numpy() path did not).
    rows = [
        value.detach().reshape(-1).cpu()
        for worker_model in worker_models
        for value in worker_model.state_dict().values()
    ]
    if not rows:
        return TensorDataset(torch.tensor([]))

    dtype = rows[0].dtype
    for row in rows[1:]:
        dtype = torch.promote_types(dtype, row.dtype)

    # Weight and bias tensors usually differ in size; a ragged list cannot be
    # turned into one tensor directly, so pad each row to the longest length.
    max_len = max(row.numel() for row in rows)
    data = torch.zeros(len(rows), max_len, dtype=dtype)
    for i, row in enumerate(rows):
        data[i, : row.numel()] = row
    return TensorDataset(data)

def compress_weights(model, encoder):
    """Encode every tensor of ``model.state_dict()`` with ``encoder``.

    Each tensor is reshaped into a single-sample, single-channel 2-D "image"
    of shape ``(1, 1, H, W)`` before encoding:

    * 2-D tensors keep their shape (``H, W = value.shape``);
    * tensors of any other rank use ``H = value.size(0)`` (1 for 0-dim
      tensors) and fold the remaining dimensions into ``W``. This lets 1-D
      biases and higher-rank conv kernels be encoded too instead of failing.

    Args:
        model: module whose ``state_dict()`` tensors are encoded.
        encoder: module accepting ``(1, 1, H, W)`` input.

    Returns:
        dict: ``state_dict`` key -> encoder output, in ``state_dict`` order.
    """
    compressed_weights = {}
    for key, value in model.state_dict().items():
        compressed_weights[key] = encoder(_as_single_channel_image(value))
    return compressed_weights


def _as_single_channel_image(tensor):
    """Reshape ``tensor`` to ``(1, 1, H, W)`` with ``H`` = its leading dim."""
    rows = tensor.size(0) if tensor.dim() > 0 else 1
    # reshape (not view) also copes with non-contiguous tensors.
    return tensor.reshape(1, 1, rows, -1)

def send_compressed_weights_to_coordinator(compressed_weights, decoder):
    """Decode each compressed tensor with ``decoder`` and flatten the result.

    Args:
        compressed_weights: mapping of state-dict key -> encoded tensor (the
            caller passes the output of ``compress_weights``).
        decoder: module applied to every encoded tensor.

    NOTE(review): confirm that ``decompressed_weights`` is returned or
    forwarded to the coordinator; otherwise the decoded weights are discarded.
    The flattened decoder output presumably does not match the original
    tensor's shape in general. It would need reshaping before being loaded
    into a model — TODO confirm.
    """
    decompressed_weights = {}
    for key, value in compressed_weights.items():
        # One flat tensor per state-dict key.
        decompressed_weights[key] = decoder(value).view(-1)