# Tutorial 7: Image generation

Synthcity supports generating synthetic images. In this tutorial, we train a generator on the [MedNIST dataset](https://medmnist.com/). The tutorial is adapted from a [MONAI example](https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/mednist_tutorial.ipynb).

The main components are:
 - Creating an `ImageDataloader` on top of the MedNIST dataset.
 - Training a Conditional GAN on the resulting dataloader.
 - Benchmarking the quality of the synthetic images.
 
__Disclaimer__: The Generator and Discriminator architectures used here are not state of the art. To add better architectures, update the options of `suggest_image_generator_discriminator_arch` in the `convnet.py` module.

In [None]:
!pip install synthcity
!pip uninstall -y torchaudio torchdata

## Load MedNIST

The dataset is downloaded using [MONAI](https://monai.io/).

In [None]:
# Download MedNIST
# stdlib
import os
from pathlib import Path

# third party
import PIL.Image  # import the submodule explicitly so that PIL.Image.open is available
from monai.apps import download_and_extract

workspace = Path("workspace")
workspace.mkdir(parents=True, exist_ok=True)

resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz"
md5 = "0bc7306e7427e00ad1c5526a6677552d"

compressed_file = workspace / "MedNIST.tar.gz"
data_dir = workspace / "MedNIST"

if not data_dir.exists():
    download_and_extract(resource, compressed_file, workspace, md5)
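`download_and_extract` receives the `md5` value above and uses it to check the downloaded archive. If you ever need to verify the archive yourself (for example, after copying the workspace to another machine), a minimal standard-library sketch could look like this. `md5_of_file` and `archive_is_valid` are hypothetical helper names, not part of MONAI.

```python
import hashlib
from pathlib import Path


def md5_of_file(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of a file, reading it in fixed-size chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        # Read in chunks so a large archive never has to fit in memory at once.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def archive_is_valid(path, expected_md5):
    """True if the file exists and its MD5 matches the expected hex digest."""
    path = Path(path)
    return path.is_file() and md5_of_file(path) == expected_md5.lower()
```

With the variables from the cell above, `archive_is_valid(compressed_file, md5)` would return `True` for an intact download.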

In [None]:
LIMIT = 1000  # samples per class

class_names = sorted(
    x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x))
)
num_class = len(class_names)
image_files = [
    [
        os.path.join(data_dir, class_names[i], x)
        for x in os.listdir(os.path.join(data_dir, class_names[i]))
    ]
    for i in range(num_class)
]
num_each = [len(image_files[i]) for i in range(num_class)]
image_files_list = []
image_class = []
for i in range(num_class):
    image_files_list.extend(image_files[i][:LIMIT])
    image_class.extend([i] * min(num_each[i], LIMIT))
num_total = len(image_class)
image_width, image_height = PIL.Image.open(image_files_list[0]).size

print(f"Total image count: {num_total}")
print(f"Image dimensions: {image_width} x {image_height}")
print(f"Label names: {class_names}")
print(f"Label counts (before LIMIT): {num_each}")
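The loop above keeps at most `LIMIT` files per class and records one label per kept file. The same logic can be written as a small pure function, which is easier to test in isolation. `cap_per_class` is a hypothetical name used only for illustration. Note that `os.listdir` returns entries in arbitrary order, so sorting each class's file list first would make the selected subset reproducible across machines.

```python
def cap_per_class(files_by_class, limit):
    """Flatten per-class file lists, keeping at most `limit` files per class.

    Returns (files, labels) where labels[i] is the class index of files[i].
    """
    files, labels = [], []
    for class_idx, class_files in enumerate(files_by_class):
        kept = class_files[:limit]
        files.extend(kept)
        # len(kept) == min(len(class_files), limit), as in the cell above.
        labels.extend([class_idx] * len(kept))
    return files, labels


# Toy example with three classes of different sizes.
toy = [["a0", "a1", "a2"], ["b0"], ["c0", "c1"]]
files, labels = cap_per_class(toy, limit=2)
print(files)  # ['a0', 'a1', 'b0', 'c0', 'c1']
print(labels)  # [0, 0, 1, 2, 2]
```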

## Visualize random samples

In [None]:
# third party
import matplotlib.pyplot as plt
import numpy as np

plt.subplots(3, 3, figsize=(8, 8))
for i, k in enumerate(np.random.randint(num_total, size=9)):
    im = PIL.Image.open(image_files_list[k])
    arr = np.array(im)
    plt.subplot(3, 3, i + 1)
    plt.xlabel(class_names[image_class[k]])
    plt.imshow(arr, cmap="gray", vmin=0, vmax=255)
plt.tight_layout()
plt.show()
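`np.random.randint` samples with replacement and without a fixed seed, so the grid can show the same image twice and changes on every run. A minimal sketch of a reproducible alternative uses NumPy's `Generator.choice` with `replace=False`; `sample_grid_indices` is a hypothetical helper, and it assumes the dataset holds at least `grid_size` images.

```python
import numpy as np


def sample_grid_indices(num_items, grid_size=9, seed=0):
    """Pick `grid_size` distinct indices in [0, num_items) with a fixed seed."""
    rng = np.random.default_rng(seed)
    # replace=False guarantees that no image appears twice in the grid.
    return rng.choice(num_items, size=grid_size, replace=False)


idx = sample_grid_indices(num_items=100, grid_size=9, seed=42)
print(len(idx), len(set(idx.tolist())))  # 9 9
```

In the cell above, `np.random.randint(num_total, size=9)` could then be replaced by `sample_grid_indices(num_total)`.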

## Create the ImageDataLoader

The `ImageDataLoader` prepares the image dataset for the `synthcity` generators.

Internally, the dataloader resizes the images to `(height, width)`. Only `height` is passed below, so the images are resized to square `IMG_SIZE` x `IMG_SIZE` images.
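To make the `(height, width)` semantics concrete, here is a minimal nearest-neighbour resize in NumPy. It is an illustration only: `resize_nearest` is a hypothetical helper, `synthcity` may use a different interpolation internally, and the fallback to a square output when `width` is omitted is our assumption about how that case is treated.

```python
import numpy as np


def resize_nearest(img, height, width=None):
    """Nearest-neighbour resize of a 2-D grayscale array to (height, width).

    If width is None, a square output of size (height, height) is produced.
    """
    if width is None:
        width = height
    in_h, in_w = img.shape
    # For each output row/column, pick the source row/column it maps onto.
    rows = np.arange(height) * in_h // height
    cols = np.arange(width) * in_w // width
    return img[rows[:, None], cols[None, :]]


img = np.arange(16).reshape(4, 4)
print(resize_nearest(img, 2))
# [[ 0  2]
#  [ 8 10]]
```

NumPy arrays are indexed as (rows, columns), i.e. (height, width), whereas PIL's `Image.resize` takes its size as `(width, height)`; mixing the two conventions is a common source of transposed images.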

In [None]:
# third party
import torch

# synthcity absolute
from synthcity.plugins.core.dataloader import ImageDataLoader

IMG_SIZE = 64


class MedNISTDataset(torch.utils.data.Dataset):
    """Map-style dataset returning (image array, label) pairs, caching decoded images."""

    def __init__(self, image_files, labels):
        self.image_files = image_files
        self.image_cache = {}  # index -> decoded image, filled lazily on first access
        self.labels = labels

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        if index in self.image_cache:
            img = self.image_cache[index]
        else:
            # Decode the image once and keep it in memory for later accesses.
            img = PIL.Image.open(self.image_files[index])
            img = np.asarray(img)
            self.image_cache[index] = img

        return img, self.labels[index]


dataset = MedNISTDataset(image_files_list, labels=image_class)

dataloader = ImageDataLoader(
    dataset,
    height=IMG_SIZE,
)
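A map-style `torch.utils.data.Dataset` only needs `__len__` and `__getitem__`. The sketch below isolates the lazy-caching pattern used by `MedNISTDataset` in plain Python, with the image decoding replaced by an injectable `loader` callable so the behaviour can be checked without files or PyTorch. `CachedDataset` and `fake_loader` are hypothetical names.

```python
class CachedDataset:
    """Map-style dataset: loads item `index` on first access, then serves it from a cache.

    `loader` is any callable taking a path and returning the decoded item;
    in the tutorial it would wrap PIL.Image.open followed by np.asarray.
    """

    def __init__(self, paths, labels, loader):
        if len(paths) != len(labels):
            raise ValueError("expected exactly one label per path")
        self.paths = paths
        self.labels = labels
        self.loader = loader
        self.cache = {}

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        if index not in self.cache:
            self.cache[index] = self.loader(self.paths[index])
        return self.cache[index], self.labels[index]


calls = []


def fake_loader(path):
    calls.append(path)  # record every real load
    return path.upper()


ds = CachedDataset(["a.jpeg", "b.jpeg"], [0, 1], fake_loader)
print(ds[0], ds[0], len(calls))  # ('A.JPEG', 0) ('A.JPEG', 0) 1
```

The cache keeps every decoded image in memory, which is fine for a capped subset like this one. With a PyTorch `DataLoader` using `num_workers > 0`, each worker process holds its own copy of the dataset, so caches are not shared between workers.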

## Load a generator - Conditional GAN

For this experiment, we use the `image_cgan` plugin.

In [None]:
# synthcity absolute
from synthcity.plugins import Plugins

generator = Plugins().get("image_cgan", batch_size=100, plot_progress=True)

## Train the generator

We train the generator on the `ImageDataLoader` created above.

We also pass the image labels as a condition (`cond`). This way, at inference time, we can ask the generator for samples from a specific class only.
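A common way for a conditional GAN to consume a class label is to append a one-hot encoding of the label to the random noise vector fed to the generator. The NumPy sketch below shows that generic technique; `synthcity`'s `image_cgan` may encode the condition differently (for example, with a learned embedding), and `conditional_generator_input` is a hypothetical helper.

```python
import numpy as np


def conditional_generator_input(labels, num_classes, latent_dim, seed=0):
    """Build generator inputs for a conditional GAN: [noise | one-hot(label)].

    Returns an array of shape (len(labels), latent_dim + num_classes).
    """
    labels = np.asarray(labels, dtype=int)
    rng = np.random.default_rng(seed)
    noise = rng.standard_normal((len(labels), latent_dim))
    one_hot = np.eye(num_classes)[labels]  # row i is the one-hot vector of labels[i]
    return np.concatenate([noise, one_hot], axis=1)


z = conditional_generator_input([0, 2, 5], num_classes=6, latent_dim=4)
print(z.shape)  # (3, 10)
print(z[:, 4:].argmax(axis=1))  # [0 2 5]
```

Because the label is part of the input, fixing it at inference time while varying only the noise yields different samples from the same class.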

In [None]:
generator.fit(dataloader, cond=image_class)

## Generate new samples

In [None]:
# synthcity absolute
from synthcity.plugins.core.models.image_gan import display_imgs

syn_samples, syn_labels = generator.generate(count=5).unpack().tensors()

display_imgs(syn_samples)
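`display_imgs` renders the generated batch directly. If you want to save or post-process the images yourself, a minimal NumPy sketch of tiling a batch into a single grid image follows. `tile_images` is a hypothetical helper, and we assume the samples are laid out as `(N, C, H, W)`, the usual PyTorch convention; a CPU tensor can be converted with `.numpy()` first.

```python
import numpy as np


def tile_images(batch, ncols=5, pad=2, pad_value=0.0):
    """Tile a (N, C, H, W) batch into one (C, grid_height, grid_width) canvas."""
    n, c, h, w = batch.shape
    nrows = -(-n // ncols)  # ceiling division
    canvas = np.full(
        (c, nrows * (h + pad) - pad, ncols * (w + pad) - pad),
        pad_value,
        dtype=batch.dtype,
    )
    for i in range(n):
        r, col = divmod(i, ncols)  # row-major placement
        top, left = r * (h + pad), col * (w + pad)
        canvas[:, top : top + h, left : left + w] = batch[i]
    return canvas


grid = tile_images(np.ones((7, 1, 4, 4)), ncols=3, pad=1)
print(grid.shape)  # (1, 14, 14)
```

For a single-channel batch, `plt.imshow(tile_images(arr)[0], cmap="gray")` would then show the whole grid as one image.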

## Generate new samples using a conditional

We can also generate instances from a specific class by passing a condition (`cond`), just as at training time.

__Disclaimer__: Other Generator and Discriminator architectures could improve the results. These architectures can be tweaked in the `suggest_image_generator_discriminator_arch` function.

In [None]:
for cls_idx, cls in enumerate(class_names):
    print("Class", cls)
    syn_samples, syn_labels = (
        generator.generate(count=5, cond=np.ones(5) * cls_idx).unpack().tensors()
    )

    display_imgs(syn_samples)
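The condition passed above, `np.ones(5) * cls_idx`, is a float array filled with the class index. The sketch below shows integer alternatives: `np.full` builds the same values for a single class, and a hypothetical `balanced_conditions` helper builds a class-balanced condition vector for all classes at once. Passing such a vector to `generator.generate` with `count` equal to its length, in a single call, is an untested assumption extrapolated from the per-class calls above.

```python
import numpy as np


def balanced_conditions(num_classes, per_class):
    """Class-index conditions with `per_class` copies of each class, in class order."""
    return np.repeat(np.arange(num_classes), per_class)


# Same values as np.ones(5) * 2, but with an integer dtype.
print(np.full(5, 2))  # [2 2 2 2 2]
print(balanced_conditions(3, 2))  # [0 0 1 1 2 2]
```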

## Benchmarks

`synthcity` allows us to compare multiple generators on the same dataset, with a wide range of metrics.

In [None]:
# synthcity absolute
from synthcity.benchmark import Benchmarks

score = Benchmarks.evaluate(
    # Each test is (testname, plugin, plugin_args). Using {"n_iter": 50} instead of {}
    # makes the run much faster at the cost of sample quality.
    [(f"test_{model}", model, {}) for model in ["image_cgan", "image_adsgan"]],
    dataloader,
    repeats=3,
    metrics={
        "detection": ["detection_mlp"],
        "performance": ["mlp"],
        "stats": ["fid"],
    },
    task_type="classification",
)
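The `fid` metric refers to the Fréchet Inception Distance. It fits a Gaussian to feature vectors of the real and of the synthetic images and compares the two: $\mathrm{FID} = \lVert \mu_r - \mu_s \rVert^2 + \mathrm{Tr}\left(\Sigma_r + \Sigma_s - 2(\Sigma_r \Sigma_s)^{1/2}\right)$, where lower is better. The sketch below computes this distance for arbitrary feature arrays with NumPy only. It is not `synthcity`'s implementation: the feature extractor (classically an Inception network) is left out, and `frechet_distance` and `_sqrtm_psd` are hypothetical names.

```python
import numpy as np


def _sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative eigenvalues caused by round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(feats_real, feats_syn):
    """Fréchet distance between Gaussians fitted to two (n_samples, dim) feature arrays."""
    mu1, mu2 = feats_real.mean(axis=0), feats_syn.mean(axis=0)
    s1 = np.atleast_2d(np.cov(feats_real, rowvar=False))
    s2 = np.atleast_2d(np.cov(feats_syn, rowvar=False))
    # Tr((s1 s2)^{1/2}) equals Tr((s1^{1/2} s2 s1^{1/2})^{1/2}),
    # and the latter matrix is symmetric PSD, so eigh applies.
    root1 = _sqrtm_psd(s1)
    cross = _sqrtm_psd(root1 @ s2 @ root1)
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(s1) + np.trace(s2) - 2.0 * np.trace(cross))
```

Two sanity checks follow from the formula: identical feature sets give a distance of (numerically) zero, and shifting every feature vector by a constant vector `d` adds exactly `d @ d`, because the covariances are unchanged.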

In [None]:
Benchmarks.print(score)

## Congratulations!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement towards machine learning and AI for medicine, you can do so in the following ways!

### Star [Synthcity](https://github.com/vanderschaarlab/synthcity) on GitHub

- The easiest way to help our community is just by starring the repos! This helps raise awareness of the tools we're building.


### Check out other projects from vanderschaarlab
- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)
- [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)
