<a href="https://colab.research.google.com/github/soumik12345/tf2_gans/blob/gaugan/notebooks/gaugan_facades_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Setup

In [None]:
# Clone the repository as it contains utility functions. 
!git clone https://github.com/soumik12345/tf2_gans
!cd tf2_gans && pip install -qr requirements.docker

# Get the Facades dataset. We'll use the validation masks from here.
!cd tf2_gans && gdown https://drive.google.com/uc?id=1q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj
!cd tf2_gans && unzip -q facades_data.zip

# Get the pre-trained checkpoints.
!wget https://github.com/soumik12345/tf2_gans/releases/download/v0.2/checkpoints.zip
!unzip -q checkpoints.zip

In [None]:
# Reference: https://stackoverflow.com/questions/64862818/cannot-import-name-png-from-matplotlib
!pip uninstall -q matplotlib
!pip install -q matplotlib==3.1.3

## Imports

In [None]:
import sys
sys.path.append("tf2_gans")

from tensorflow import keras
import tensorflow as tf

import matplotlib.pyplot as plt

from gaugan.dataloader import FacadesDataLoader
from gaugan.models import GauGAN
from configs import facades

In [None]:
# Report the framework versions so results can be tied to a specific setup.
for library_name, library_version in (
    ("TensorFlow", tf.__version__),
    ("Keras", keras.__version__),
):
    print(f"{library_name} version: {library_version}.")

## Initialize the GauGAN model and populate weights

In [None]:
# Load the Facades configuration and build an (untrained) GauGAN from it.
# NOTE(review): only `image_height` is passed as `image_size`; this presumably
# assumes square images — confirm image_width == image_height in the config.
configurations = facades.get_config()
gaugan_kwargs = dict(
    image_size=configurations.image_height,
    num_classes=configurations.num_classes,
    batch_size=configurations.batch_size,
    hyperparameters=configurations.hyperparameters,
)
gaugan_model = GauGAN(**gaugan_kwargs)
print("GauGAN model initialized.")

In [None]:
import os
from glob import glob

# Directories that may contain `<run>/discriminator` and `<run>/generator`.
# The tf2_gans clone is searched first. The setup cell unzips
# `checkpoints.zip` in the working directory, which is /content on Colab, so
# that location is checked as a fallback.
checkpoint_model_dirs = [
    "/content/tf2_gans/checkpoints/models",
    "/content/checkpoints/models",
]

# Pick ONE run directory that holds both models. Globbing each model
# separately and taking `[0]` can pair a discriminator and a generator from
# different runs, because glob returns entries in arbitrary filesystem order.
# Sorting makes the choice deterministic.
run_dir = next(
    (
        candidate
        for root in checkpoint_model_dirs
        for candidate in sorted(glob(os.path.join(root, "*")))
        if os.path.exists(os.path.join(candidate, "discriminator"))
        and os.path.exists(os.path.join(candidate, "generator"))
    ),
    None,
)
if run_dir is None:
    raise FileNotFoundError(
        "No checkpoint run containing both `discriminator` and `generator` was "
        f"found under {checkpoint_model_dirs}. Did the setup cell download and "
        "extract checkpoints.zip?"
    )

disc_path = os.path.join(run_dir, "discriminator")
generator_path = os.path.join(run_dir, "generator")
print(f"Loading checkpoints from: {run_dir}.")

# Replace the freshly initialised sub-models with the pre-trained ones.
gaugan_model.discriminator = keras.models.load_model(disc_path)
gaugan_model.generator = keras.models.load_model(generator_path)

print("Weights populated.")

In [None]:
# The config's default dataset directory is relative to the repo clone; turn it
# into the absolute Colab path. Any other (user-supplied) path is left as is.
dataset_dir = configurations.dataset_dir
if dataset_dir == "facades_data":
    dataset_dir = f"/content/tf2_gans/{dataset_dir}"
    configurations.dataset_dir = dataset_dir

print(f"Dataset path: {configurations.dataset_dir}.")

## Initialize the validation dataset

In [None]:
# Build the Facades loader at the configured resolution. Only the validation
# split is needed for inference; the training split is not used further.
data_loader = FacadesDataLoader(
    data_dir=configurations.dataset_dir,
    num_classes=configurations.num_classes,
    target_image_width=configurations.image_width,
    target_image_height=configurations.image_height,
)
train_split, val_dataset = data_loader.get_datasets(
    split_fraction=configurations.split_fraction,
    batch_size=configurations.batch_size,
)
print("Validation dataset prepared.")

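## Perform inference

For a few validation batches, sample a random latent vector and generate images conditioned on the validation masks. Each row shows the mask, the ground truth, and the generated image side by side.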

In [None]:
# How many validation batches to visualise, and at most how many samples
# (rows) to plot from each batch.
NUM_BATCHES_TO_PLOT = 5
MAX_ROWS_PER_FIGURE = 3
# Standard deviation of the normal distribution the latent vector is sampled from.
LATENT_STDDEV = 2.0

# `take` stops cleanly if the validation split has fewer batches than
# requested, whereas repeated `next()` on an iterator raises StopIteration.
for val_images in val_dataset.take(NUM_BATCHES_TO_PLOT):
    # Element layout as used by this notebook: [0] is plotted as the mask,
    # [1] as the ground-truth photo, and [2] is the conditioning input fed to
    # the generator.
    masks, ground_truths, conditioning = val_images[0], val_images[1], val_images[2]

    # Sample one latent vector per example in THIS batch. A final partial batch
    # would not match the configured `gaugan_model.batch_size`.
    current_batch_size = conditioning.shape[0]
    latent_vector = tf.random.normal(
        shape=(current_batch_size, gaugan_model.latent_dim),
        mean=0.0,
        stddev=LATENT_STDDEV,
    )
    # Generate fake images.
    generated = gaugan_model.predict([latent_vector, conditioning])

    panels = (("Mask", masks), ("Ground Truth", ground_truths), ("Generated", generated))
    num_rows = min(generated.shape[0], MAX_ROWS_PER_FIGURE)
    num_cols = len(panels)
    # `squeeze=False` keeps `axarr` 2-D even when only one row is plotted.
    fig, axarr = plt.subplots(
        num_rows, num_cols, figsize=(num_cols * 6, num_rows * 6), squeeze=False
    )
    for row in range(num_rows):
        for ax, (title, images) in zip(axarr[row], panels):
            # Rescale from the assumed [-1, 1] range to [0, 1] for imshow
            # (TODO confirm the loader/generator output range).
            ax.imshow((images[row] + 1) / 2)
            ax.axis("off")
            ax.set_title(title, fontsize=20)
    plt.show()
    # Release the figure so memory doesn't grow with every loop iteration.
    plt.close(fig)

For more details, please refer to our Keras example: [GauGAN for conditional image generation](https://keras.io/examples/generative/gaugan/).