# Federated GAN tutorial with PyTorch Lightning

In [None]:
!pip install "pytorch-lightning>=1.3" "torch==1.9.1" "torchvision==0.10.1" "torchmetrics>=0.3" "scikit-image" "matplotlib"

# Connect to the Federation

In [None]:
from openfl.interface.interactive_api.federation import Federation

# Director connection settings for this frontend client.
client_id, director_node_fqdn, director_port = "frontend", "localhost", 50050

# TLS is switched off, which is only acceptable in a trusted environment.
federation_settings = dict(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False,
)
federation = Federation(**federation_settings)

In [None]:
# Fetch the shard registry from the director and display it as this cell's
# output (presumably one entry per connected envoy — confirm).
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# Display the federation's target shape (presumably the shape of the shards'
# targets/labels — confirm against the envoy shard descriptor).
federation.target_shape

## Creating an FL experiment using the Interactive API

In [None]:
from openfl.interface.interactive_api.experiment import (
    DataInterface,
    FLExperiment,
    ModelInterface,
    TaskInterface,
)

### Register the dataset

In [None]:
import copy
import os
import shutil
import PIL
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.datasets import MNIST

In [None]:
# Preprocessing applied to every raw MNIST sample taken from the shard descriptor.
# Normalize(mean=0.5, std=0.5) maps ToTensor's [0, 1] output onto [-1, 1], the
# same range as the generator's final Tanh. With MNIST's own mean/std
# (0.1307, 0.3081) real pixels would span roughly [-0.42, 2.82], letting the
# discriminator separate real from generated images by value range alone.
mnist_transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)


class MnistShardDataset(Dataset):
    """Map-style dataset over in-memory images ``x`` and labels ``y``.

    Each item is returned as an ``(image, label)`` pair of NumPy arrays. The
    optional ``transform`` is applied to the image before conversion.
    """

    def __init__(self, x, y, transform=None):
        self.x, self.y = x, y
        self.transform = transform

    @staticmethod
    def _as_numpy(value):
        """Convert a tensor to NumPy; pass any other array-like through ``np.asarray``."""
        if isinstance(value, torch.Tensor):
            return value.numpy()
        return np.asarray(value)

    def __getitem__(self, index):
        x, y = self.x[index], self.y[index]
        # ``transform`` defaults to None, so only call it when one was given.
        if self.transform is not None:
            x = self.transform(x)
        return self._as_numpy(x), self._as_numpy(y)

    def __len__(self):
        return len(self.x)


class MnistFedDataset(DataInterface):
    """Data interface exposing an envoy's local MNIST shard to the FL tasks.

    Constructor keyword arguments are read back through ``self.kwargs``
    (assumed to be the dict stored by ``DataInterface``):

    * ``train_bs`` – training batch size (256 when missing or falsy)
    * ``valid_bs`` – validation batch size (64 when missing or falsy)
    * ``num_workers`` – worker processes for the training loader (default 4)
    """

    @property
    def shard_descriptor(self):
        """The shard descriptor most recently assigned through the setter."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor
        self.train_set = self._build_split("train")
        self.valid_set = self._build_split("val")

    def _build_split(self, split):
        """Wrap one split of the shard descriptor in a ``MnistShardDataset``."""
        # Slice the split once and take both images ([0]) and labels ([1]) from
        # that single result instead of slicing the descriptor twice.
        data = self._shard_descriptor.get_dataset(split)[:]
        return MnistShardDataset(data[0], data[1], transform=mnist_transform)

    def _batch_size(self, key, default):
        """Return ``self.kwargs[key]``, or ``default`` when it is missing or falsy."""
        return self.kwargs.get(key) or default

    def __getitem__(self, index):
        return self.shard_descriptor[index]

    def __len__(self):
        return len(self.shard_descriptor)

    def get_train_loader(self):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        return DataLoader(
            self.train_set,
            batch_size=self._batch_size("train_bs", 256),
            num_workers=self.kwargs.get("num_workers", 4),
        )

    def get_valid_loader(self):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        return DataLoader(self.valid_set, batch_size=self._batch_size("valid_bs", 64))

    def get_train_data_size(self):
        """Return the number of training samples in the local shard."""
        return len(self.train_set)

    def get_valid_data_size(self):
        """Return the number of validation samples in the local shard."""
        return len(self.valid_set)


In [None]:
# Data interface handed to the experiment; ``train_bs`` / ``valid_bs`` are read
# back by get_train_loader / get_valid_loader as the loaders' batch sizes.
fed_dataset = MnistFedDataset(train_bs=256, valid_bs=64)

### Describe a model and optimizer

In [None]:
"""
Generator and discriminator model definition
"""


class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        z = z.float()
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img

In [None]:
class Discriminator(nn.Module):
    """MLP that scores each image with a sigmoid probability of being real.

    ``forward`` flattens every sample and returns a ``(batch, 1)`` tensor.
    """

    def __init__(self, img_shape):
        super().__init__()
        n_inputs = int(np.prod(img_shape))

        layers = []
        for fan_in, fan_out in zip([n_inputs, 512], [512, 256]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        layers += [nn.Linear(256, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        batch = img.size(0)
        return self.model(img.view(batch, -1).float())

In [None]:
class GAN(LightningModule):
    """Basic GAN (MLP generator + MLP discriminator) as a LightningModule.

    Two optimizers are used: index 0 for the generator, index 1 for the
    discriminator (see ``configure_optimizers``).

    Args:
        channels: number of image channels.
        width: image width.
        height: image height.
        train_disc_only: the generator step runs only when this equals 0.
        train_gen_only: the discriminator step runs only when this equals 0.
        latent_dim: length of the generator's noise vector.
        lr: Adam learning rate for both optimizers.
        b1: Adam beta1 for both optimizers.
        b2: Adam beta2 for both optimizers.
        batch_size: saved as a hyperparameter; not read elsewhere in this class.
        **kwargs: accepted for flexibility; not read directly here.
    """

    def __init__(
        self,
        channels,
        width,
        height,
        train_disc_only,
        train_gen_only,
        latent_dim: int = 100,
        lr: float = 0.0002,
        b1: float = 0.5,
        b2: float = 0.999,
        batch_size: int = 256,
        **kwargs
    ):
        super().__init__()
        self.save_hyperparameters()

        data_shape = (channels, width, height)
        self.generator = Generator(
            latent_dim=self.hparams.latent_dim, img_shape=data_shape
        )
        self.discriminator = Discriminator(img_shape=data_shape)

        # Fixed noise so the images logged at each epoch end are comparable.
        self.validation_z = torch.randn(8, self.hparams.latent_dim)
        # Dummy input Lightning can use for model summaries.
        self.example_input_array = torch.zeros(2, self.hparams.latent_dim)
        self.train_disc_only = train_disc_only
        self.train_gen_only = train_gen_only

    def forward(self, z):
        """Generate images from latent noise ``z``."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between discriminator outputs and targets."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Dispatch to the generator (optimizer 0) or discriminator (optimizer 1) step.

        Returns ``None`` when the step is disabled by ``train_disc_only`` /
        ``train_gen_only``; Lightning then skips that optimizer's update.
        """
        imgs, _ = batch

        # sample noise on the same device/dtype as the real batch
        z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
        z = z.type_as(imgs)

        if optimizer_idx == 0 and self.train_disc_only == 0:
            return self.train_generator(imgs, z, display_images=0)

        elif optimizer_idx == 1 and self.train_gen_only == 0:
            return self.train_discriminator(imgs, z)

    def train_generator(self, imgs, z, display_images=0):
        """Generator step: train G so that D labels G(z) as real.

        Args:
            imgs: batch of real images; its size and dtype/device set the targets.
            z: latent noise, one row per image.
            display_images: when truthy, plot up to 10 generated samples.

        Returns:
            OrderedDict with ``loss`` (the generator loss) plus the legacy
            ``progress_bar`` / ``log`` entries.
        """
        # Generate once and reuse the result for the loss: a second forward
        # pass would double the compute and update BatchNorm running stats twice.
        self.generated_imgs = self(z)

        if display_images:
            self._show_samples(self.generated_imgs)

        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs).float()

        g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
        tqdm_dict = {"g_loss": g_loss}
        output = OrderedDict(
            {"loss": g_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
        )
        self.log(name="Generator training loss", value=g_loss, on_epoch=True)
        return output

    def _show_samples(self, images, max_images=10):
        """Plot up to ``max_images`` of ``images`` (channel-first), one figure each."""
        from matplotlib import pyplot as plt
        from skimage import io

        # Slicing copes with batches smaller than max_images.
        samples = images[:max_images].detach().cpu().numpy()
        for img in samples:
            # (C, W, H) -> (W, H, C); squeeze drops the channel axis when C == 1.
            io.imshow(np.moveaxis(img, 0, -1).squeeze(), cmap="gray_r")
            plt.axis("off")
            plt.show()

    def train_discriminator(self, imgs, z):
        """Discriminator step: average BCE on real (target 1) and generated (target 0) images."""
        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs).float()

        real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

        fake = torch.zeros(imgs.size(0), 1)
        fake = fake.type_as(imgs).float()

        # detach(): this step must not push gradients into the generator.
        fake_loss = self.adversarial_loss(self.discriminator(self(z).detach()), fake)

        d_loss = (real_loss + fake_loss) / 2
        tqdm_dict = {"d_loss": d_loss}
        output = OrderedDict(
            {"loss": d_loss, "progress_bar": tqdm_dict, "log": tqdm_dict}
        )
        self.log(name="Discriminator training loss", value=d_loss, on_epoch=True)
        return output

    def configure_optimizers(self):
        """Return ``[generator Adam, discriminator Adam]`` sharing lr/betas."""
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))

        return [opt_g, opt_d]

    def validation_step(self, batch, batch_idx, optimizer_idx=1):
        """Log the discriminator's BCE on real validation images (target 1)."""
        imgs, _ = batch

        valid = torch.ones(imgs.size(0), 1)
        valid = valid.type_as(imgs).float()

        val_real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
        self.log(name="Discriminator val loss", value=val_real_loss, on_epoch=True)
        return {"val_loss": val_real_loss}

    def on_epoch_end(self):
        """Log a grid of images generated from the fixed ``validation_z``."""
        experiment = getattr(self.logger, "experiment", None)
        if experiment is None or not hasattr(experiment, "add_image"):
            # No logger, or one without image support: nothing to log.
            return

        z = self.validation_z.type_as(self.generator.model[0].weight)

        # Sample in eval mode without autograd so logging neither builds a graph
        # nor updates the generator's BatchNorm running statistics.
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                sample_imgs = self(z)
        finally:
            self.generator.train(was_training)

        grid = torchvision.utils.make_grid(sample_imgs)
        experiment.add_image("generated_images", grid, self.current_epoch)

In [None]:
from pytorch_lightning.callbacks import Callback


class MetricsCallback(Callback):
    """PyTorch Lightning callback that snapshots ``trainer.callback_metrics`` each epoch."""

    def __init__(self):
        super().__init__()
        # One deep copy of trainer.callback_metrics per on_epoch_end call.
        self.metrics = []

    def on_epoch_end(self, trainer, pl_module):
        # Deep-copy so later updates to callback_metrics don't alter stored snapshots.
        met = copy.deepcopy(trainer.callback_metrics)
        self.metrics.append(met)

    def get_callbacks(self):
        """Return this callback in a list, ready for ``Trainer(callbacks=...)``.

        ``__call__`` always delegated here, but the method was never defined,
        so calling the instance raised AttributeError.
        """
        return [self]

    def __call__(self):
        return self.get_callbacks()

In [None]:
# Single-channel 28x28 GAN; both flags are 0, so generator and discriminator
# steps are both enabled in training_step.
model = GAN(channels=1, width=28, height=28, train_disc_only=0, train_gen_only=0)

In [None]:
# [generator Adam, discriminator Adam]; handed to the ModelInterface below.
optimizer = model.configure_optimizers()

#### Register the model

In [None]:
from copy import deepcopy

# The GAN has two optimizers, so a custom framework adapter plugin is used;
# a PyTorch Lightning model with a single optimizer would not need it.
framework_adapter = "plugin_for_multiple_optimizers.FrameworkAdapterPluginforMultipleOpt"
MI = ModelInterface(
    framework_plugin=framework_adapter,
    model=model,
    optimizer=optimizer,
)

# Snapshot of the model as it was before any federated training.
initial_model = deepcopy(model)

### Define and register FL tasks

In [None]:
TI = TaskInterface()

import tqdm

@TI.register_fl_task(
    model="model", data_loader="train_loader", device="device", optimizer="optimizer"
)
def train(model, train_loader, optimizer, device, some_parameter=None):
    """Run one local training epoch with a PyTorch Lightning ``Trainer``.

    Args:
        model: the registered model (presumably the ``GAN`` LightningModule).
        train_loader: the data interface's training loader.
        optimizer: part of the task contract; not used in this body — the
            Trainer obtains optimizers from ``model.configure_optimizers()``.
        device: device string; one GPU is requested when it contains "cuda".
        some_parameter: unused.

    Returns:
        ``{"train_loss": float}`` — the epoch-level discriminator loss, or the
        generator loss when no discriminator loss was logged.

    Raises:
        KeyError: if neither training loss was logged.
    """

    print(f"\n\n TASK TRAIN GOT DEVICE {device}\n\n")

    gpus = 1 if "cuda" in device else 0

    trainer = Trainer(gpus=gpus, max_epochs=1, callbacks=[MetricsCallback()])
    trainer.fit(model=model, train_dataloaders=train_loader)
    print("training logged metrics", trainer.logged_metrics)

    logged = trainer.logged_metrics
    # Prefer the discriminator loss; it is absent when only the generator trains.
    for key in ("Discriminator training loss_epoch", "Generator training loss_epoch"):
        if key in logged:
            # float() turns a (possibly CUDA) scalar tensor into a plain number.
            return {"train_loss": float(logged[key])}
    raise KeyError(f"no training loss was logged; logged metrics: {sorted(logged)}")


@TI.register_fl_task(model="model", data_loader="val_loader", device="device")
def validate(model, val_loader, device):
    """Evaluate the discriminator on real validation images.

    Args:
        model: the registered model (presumably the ``GAN`` LightningModule).
        val_loader: the data interface's validation loader.
        device: device string; one GPU is requested when it contains "cuda".

    Returns:
        ``{"val_loss": float}`` taken from the "Discriminator val loss" metric.
    """

    print(f"\n\n TASK VALIDATE GOT DEVICE {device}\n\n")

    model.eval()
    model.to(device)

    gpus = 1 if "cuda" in device else 0

    trainer = Trainer(gpus=gpus, max_epochs=1, callbacks=[MetricsCallback()])

    trainer.validate(model=model, dataloaders=val_loader)
    print("validation logged metrics", trainer.logged_metrics)

    # float() turns a (possibly CUDA) scalar tensor into a plain number.
    val_loss = float(trainer.logged_metrics["Discriminator val loss"])

    return {"val_loss": val_loss}

## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation under a fixed name.
experiment_name = "PL_MNIST_test_experiment"
fl_experiment = FLExperiment(experiment_name=experiment_name, federation=federation)

In [None]:
# Run settings: 10 rounds; "CONTINUE_GLOBAL" presumably carries the global
# optimizer state across rounds and "CUDA_PREFERRED" presumably picks a GPU
# on envoys that have one — confirm against the OpenFL docs.
experiment_settings = {
    "rounds_to_train": 10,
    "opt_treatment": "CONTINUE_GLOBAL",
    "device_assignment_policy": "CUDA_PREFERRED",
}
fl_experiment.start(
    model_provider=MI,
    task_keeper=TI,
    data_loader=fed_dataset,
    **experiment_settings,
)

In [None]:
# Stream the metrics reported by the running experiment into this notebook.
fl_experiment.stream_metrics()

## Check the images generated by the model

In [None]:
!pip install -r ../envoy/sd_requirements.txt
import sys

# Make the envoy directory importable so its shard descriptor can be used
# from this notebook (path is relative to the notebook's working directory).
sys.path.insert(1, "../envoy")
from mnist_shard_descriptor import MnistShardDescriptor

In [None]:
# Rebuild the data interface locally; rank_worldsize="1,1" appears to give this
# notebook a single shard covering the data (confirm with the descriptor).
fed_dataset = MnistFedDataset(train_bs=256, valid_bs=64)
local_descriptor = MnistShardDescriptor(rank_worldsize="1,1")
fed_dataset.shard_descriptor = local_descriptor

# Model returned by the experiment (presumably the last aggregated weights).
last_model = fl_experiment.get_last_model()

In [None]:
# One validation batch; its batch size sets how many noise vectors are drawn.
val_imgs, _ = next(iter(fed_dataset.get_valid_loader()))

# Use the model's own latent size rather than a hard-coded 100 so this cell
# keeps working if GAN is built with a different latent_dim.
z = torch.randn(val_imgs.shape[0], last_model.hparams.latent_dim)

In [None]:
# Plot generated samples. train_generator also computes (and tries to log) the
# generator loss; its returned dict is not needed here.
last_model.train_generator(val_imgs, z, display_images=1)
# Ending the cell with `pass` stops the notebook from echoing the returned dict.
pass