# Semantic Image Synthesis with SPADE in TensorFlow
The code in this notebook performs semantic image synthesis with a [TensorFlow](https://www.tensorflow.org/) implementation of NVIDIA's SPADE [paper](https://arxiv.org/abs/1903.07291). Credit for the TensorFlow port goes to [taki0112](https://github.com/taki0112).  
Tests use a model pre-trained on the [CelebAMask-HQ](https://github.com/switchablenorms/CelebAMask-HQ) dataset.  
Please switch to a GPU runtime before executing the code in this notebook. 

### Settings

Clone the GitHub repository from taki0112.

In [None]:
!git clone https://github.com/taki0112/SPADE-Tensorflow.git
%cd ./SPADE-Tensorflow

Switch to TensorFlow 1.x, since this implementation is written for TensorFlow 1.

In [None]:
%tensorflow_version 1.x

Downgrade scipy to release 1.2.0.

In [None]:
!pip install scipy==1.2.0

Download the model checkpoint (pre-trained on the CelebAMask-HQ dataset).  
The checkpoint size is 1.76 GB.

In [None]:
!gdown --id 1UIj7eRJeNWrDS-3odyaoLhcqk0tNcEez
!unzip ./checkpoint.zip
!rm -f ./checkpoint.zip
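The archive alone takes 1.76 GB, and `unzip` needs room for the extracted files while the archive is still on disk. If the unzip step fails on a nearly full runtime, a quick free-space check like the following sketch can confirm the cause (run it before the download cell). The helper name is hypothetical, and the `factor=2.0` headroom assumes the extracted checkpoint is about as large as the archive; that is a guess, not a measured figure.

```python
import shutil

# Archive size stated above; GB is treated as GiB (1024**3 bytes) to err on the safe side
CHECKPOINT_ZIP_BYTES = int(1.76 * 1024**3)


def enough_space_for_checkpoint(path='.', archive_bytes=CHECKPOINT_ZIP_BYTES, factor=2.0):
    """Hypothetical helper: True if `path` has room for the archive plus its extracted files.

    factor=2.0 ASSUMES the extracted checkpoint is roughly as large as the archive.
    """
    free_bytes = shutil.disk_usage(path).free
    return free_bytes >= archive_bytes * factor


free_gib = shutil.disk_usage('.').free / 1024**3
print(f'Free space: {free_gib:.1f} GiB, enough for the checkpoint: {enough_space_for_checkpoint()}')
```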

Create the segmentation-map label file (`segmap_label.txt`) for the test dataset, as explained in https://github.com/taki0112/SPADE-Tensorflow/issues/15

In [None]:
# Map each segmentation-map color tuple to a class index, 0-18 (19 classes)
segmap_label_content = '{(0, 0, 0): 0, (0, 0, 255): 1, (255, 0, 0): 2, (150, 30, 150): 3, (255, 65, 255): 4, (150, 80, 0): 5, (170, 120, 65): 6, (125, 125, 125): 7, (255, 255, 0): 8, (0, 255, 255): 9, (255, 150, 0): 10, (255, 225, 120): 11, (255, 125, 125): 12, (200, 100, 100): 13, (0, 255, 0): 14, (0, 150, 80): 15, (215, 175, 125): 16, (220, 180, 210): 17, (125, 125, 255): 18}'
# Write the mapping to dataset/spade_celebA/segmap_label.txt
with open('./dataset/spade_celebA/segmap_label.txt', 'w') as f:
    f.write(segmap_label_content)
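The file written above stores a Python dict literal that maps each segmentation-map color to a class index. The following minimal sketch (not code from the repository) shows how such a mapping can be parsed and applied to turn a color segmentation map into a per-pixel label array with NumPy. The function names are hypothetical; the sketch assumes the image's channels are in the same order as the dictionary's color tuples (note that OpenCV's `cv2.imread` returns BGR), and unmapped colors fall back to label 0.

```python
import ast

import numpy as np


def parse_color_to_label(text):
    """Parse the contents of segmap_label.txt into {(c0, c1, c2): class_index}."""
    # ast.literal_eval safely evaluates the dict literal without running arbitrary code
    return ast.literal_eval(text)


def segmap_to_labels(segmap, color_to_label):
    """Convert an H x W x 3 color segmentation map into an H x W array of class indices.

    Pixels whose color is not in `color_to_label` keep label 0 in this sketch.
    """
    labels = np.zeros(segmap.shape[:2], dtype=np.int64)
    for color, label in color_to_label.items():
        # True where all three channels match this color exactly
        mask = np.all(segmap == np.array(color, dtype=segmap.dtype), axis=-1)
        labels[mask] = label
    return labels


# Usage with a tiny 1 x 2 map: one (0, 0, 255) pixel and one (255, 0, 0) pixel
color_to_label = parse_color_to_label('{(0, 0, 0): 0, (0, 0, 255): 1, (255, 0, 0): 2}')
example = np.array([[[0, 0, 255], [255, 0, 0]]], dtype=np.uint8)
print(segmap_to_labels(example, color_to_label))  # → [[1 2]]
```

In the notebook itself, the same parser could be fed the contents of `./dataset/spade_celebA/segmap_label.txt` instead of the short literal used here.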

### Semantic Image Synthesis Tests

Run the random test (`--phase random`) with the pre-trained model.

In [None]:
!python main.py --checkpoint_dir . --dataset spade_celebA --segmap_ch 3 --phase random

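Display the results. First, delete the `index.html` page in the results directory so that only the generated image files remain.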

In [None]:
!rm -f ./results/SPADE_spade_celebA_hinge_2multi_4dis_1_1_10_10_0.05_sn_TTUR_more/index.html

In [None]:
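import os

import cv2
import matplotlib.pyplot as plt

# Directory where main.py writes the random-test results
image_output_dir = './results/SPADE_spade_celebA_hinge_2multi_4dis_1_1_10_10_0.05_sn_TTUR_more'

# Collect the generated images in sorted order, skipping any non-image files
items = sorted(f for f in os.listdir(image_output_dir)
               if f.lower().endswith(('.png', '.jpg', '.jpeg')))

generated_image_count = len(items)
image_index = 0
rows = 5
cols = 3
fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(15, 15))

# Fill the grid row by row; cells without an image stay empty
for i in range(rows):
    for j in range(cols):
        axes[i, j].axis('off')
        if image_index < generated_image_count:
            full_path = os.path.join(image_output_dir, items[image_index])
            image = cv2.imread(full_path)                   # OpenCV loads images as BGR
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Matplotlib expects RGB
            axes[i, j].imshow(image)
            axes[i, j].set_title(items[image_index])
            image_index += 1

plt.tight_layout()
plt.show()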