# Setup

In [3]:
import time

import os
import numpy as np
import pandas as pd

import tensorflow as tf 
from tensorflow import keras
#tf.config.set_per_process_memory_growth(True)

import pickle

import collections
from collections import defaultdict
import matplotlib.pyplot as plt

from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model, load_model
from keras.layers import Input, Dense, Conv2D, LeakyReLU, Dropout, Flatten, MaxPooling2D, GlobalAveragePooling2D
from keras.layers import BatchNormalization, Embedding, Reshape, Activation
from keras.layers import Concatenate, Conv2DTranspose, multiply, UpSampling2D
from keras.initializers import RandomNormal
from keras.optimizers import Adam
from keras.utils import Progbar
from keras.metrics import *
from keras import backend as K

import cv2

# Load Weights Only

## Define and initialize Generator

In [5]:
def generator(latent_dim = 100, n_classes = 3):
    # Initialize RandomNormal with mean = 0.0 and stddev = 0.02
    # init = RandomNormal(mean = 0.0, stddev = 0.02)
  
    ### Input 1: class label input ###
    
    # Generator take integer class label as input
    label_input = Input(shape = (1,))
    # print(label_input.shape)
    
    # Embedding layer: to convert class label integer to a vector of size 100
    y = Embedding(n_classes, 100)(label_input)
    # print('Embedding Layer: ', y.shape)
    
    # Dense layer with 7 x 7 units: to convert the vector to a 7 x 7 x 1 tensor
    n_nodes = 7 * 7
    y = Dense(n_nodes, kernel_initializer = RandomNormal(mean = 0.0, stddev = 0.02))(y)
    # print('Dense 1: ', y.shape)
    y = Reshape((7, 7 ,1))(y)
    print('reshape(final y shape): ', y.shape)

    ### Input 2: generator noise input ###
    
    # A latent_dim-dimensional vector is sampled from a normal distribution
    # with mean = 0.0 and stddev = 0.02 
    generator_input = Input(shape = (latent_dim,))
    
    # Noise vector is passed through a dense layer with 1024 * 7 * 7 units 
    # to produce a 7 x 7 x 1024 tensor
    n_nodes = 1024 * 7 * 7
    gen = Dense(n_nodes, 
                kernel_initializer = RandomNormal(mean = 0.0, stddev = 0.02))(generator_input)
    gen = Activation('relu')(gen)
    gen = Reshape((7, 7, 1024))(gen)
    print('Generator noise input: ', gen.shape)
    
    ### Concatenate both the inputs ###
    # The output tensors are then concatenated to produce a 7 × 7 × 1025 tensor. 
    merge = Concatenate()([gen, y])
    print('Concatenate(generator noise input and y: ', merge.shape)

    ### Upsampling ###
    # four successive transposed convolutions 
    # to produce tensors with dimensions 14 × 14 × 512, 28 × 28 × 256, 56 × 56 × 128 and 128 × 128 × 3, respectively.
    
    # (None, 7, 7, 1024) --> (None, 14, 14, 512)
    gen = Conv2DTranspose(512, kernel_size = (5, 5), strides = (2, 2), padding = "same", kernel_initializer = RandomNormal(mean = 0.0, stddev = 0.02))(merge)
    gen = BatchNormalization(momentum = 0.9)(gen)
    gen = Activation("relu")(gen)
    print("(None, 7, 7, 1024) -> (None, 14, 14, 512): ", gen.shape)

    # (None, 14, 14, 512)  --> (None, 28, 28, 256)
    gen = Conv2DTranspose(256, kernel_size = (5, 5), strides = (2, 2), padding = "same", kernel_initializer = RandomNormal(mean = 0.0, stddev = 0.02))(gen)
    gen = BatchNormalization(momentum = 0.9)(gen)
    gen = Activation("relu")(gen)
    print('(None, 14, 14, 512) -> (None, 28, 28, 256): ', gen.shape)

    # (None, 28, 28, 256) --> (None, 56, 56, 128)
    gen = Conv2DTranspose(128, kernel_size = (5, 5), strides = (2, 2), padding = "same", kernel_initializer = RandomNormal(mean = 0.0, stddev = 0.02))(gen)
    gen = BatchNormalization(momentum = 0.9)(gen)
    gen = Activation("relu")(gen)
    print('(None, 28, 28, 256) -> (None, 56, 56, 128): ', gen.shape)

    # (None, 56, 56, 128) --> (None, 112, 112, 3)
    gen = Conv2DTranspose(3, kernel_size = (5, 5), strides = (2, 2), padding = "same", kernel_initializer = RandomNormal(mean = 0.0, stddev = 0.02))(gen)
    out_layer = Activation("tanh")(gen)
    print("(None, 56, 56, 128) -> (None, 112, 112, 3): ", out_layer.shape)
    
    # The final output from the generator is an fake image X of dimension 112 × 112 × 3
    model = Model(inputs = [generator_input, label_input], outputs = out_layer)
    return model

In [6]:
generator_instance = generator(latent_dim = 100, n_classes = 3)
generator_instance.summary()

reshape(final y shape):  (None, 7, 7, 1)
Generator noise input:  (None, 7, 7, 1024)
Concatenate(generator noise input and y:  (None, 7, 7, 1025)
(None, 7, 7, 1024) -> (None, 14, 14, 512):  (None, 14, 14, 512)
(None, 14, 14, 512) -> (None, 28, 28, 256):  (None, 28, 28, 256)
(None, 28, 28, 256) -> (None, 56, 56, 128):  (None, 56, 56, 128)
(None, 56, 56, 128) -> (None, 112, 112, 3):  (None, 112, 112, 3)
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 100)]        0           []                               
                                                                                                  
 input_1 (InputLayer)           [(None, 1)]          0           []                               
                                                                                      

## Load weights

In [1]:
# Change x with epoch number
epoch = 367
filename = f"./SaveDir/best_weights/params_generator_epoch_{epoch}.hdf5"

In [7]:
generator_instance.load_weights(filename)

# Tools functions

In [8]:
# Function to generate a batch of noise and label 

def generate_batch_noise_and_labels(batch_size, latent_dim, gen_class = "normal"):
    name_map = {'COVID-19':0, 'normal':1, 'pneumonia':2}
    # generate a new batch of noise
    noise = np.random.uniform(-1, 1, (batch_size, latent_dim))

    # defined labels
    labels = np.full(batch_size, name_map[gen_class], dtype=int)

    return noise, labels

# Generation

In [9]:
# Define number of images to generate
generate_class_n = {'COVID-19':8000, 'normal':8000, 'pneumonia':8000}
# generate_class_n = {'COVID-19':8000, 'normal':8000, 'pneumonia':8000}
# Batch size
batch_size = 256
# Path to save images
path_save = f'./SaveDir/generated_images/epoch_{epoch}'
# Class map
#class_map = {0:'COVID-19', 1:'normal', 2:'pneumonia'}

In [11]:
with open(f"{path_save}images_list.txt", "w+") as file:
    for key, value in generate_class_n.items():
        if value != 0:
            n_batch = (value // batch_size) + 1 # produce n_batch of images (surplus)
            counter_image = 0
            for i in range(n_batch):
                noise, labels = generate_batch_noise_and_labels(batch_size = batch_size, latent_dim = 100, gen_class = key)
                
                generated_images_batch = generator_instance.predict([noise, labels.reshape((-1, 1))], verbose=0)
                norm_image = cv2.normalize((generated_images_batch + 1) * 127.5, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
                norm_img = norm_image.astype(np.uint8)
                for j in range(norm_img.shape[0]):
                    cv2.imwrite(f"{path_save}/{key}/{key}_{counter_image}.png", norm_img[j,:,:])
                    file.write(f"{key}_{counter_image}.png {key}\n")
                    counter_image += 1
                    print(f"{key}: {counter_image}/{value}")
              

COVID-19: 1/8000
COVID-19: 2/8000
COVID-19: 3/8000
COVID-19: 4/8000
COVID-19: 5/8000
COVID-19: 6/8000
COVID-19: 7/8000
COVID-19: 8/8000
COVID-19: 9/8000
COVID-19: 10/8000
COVID-19: 11/8000
COVID-19: 12/8000
COVID-19: 13/8000
COVID-19: 14/8000
COVID-19: 15/8000
COVID-19: 16/8000
COVID-19: 17/8000
COVID-19: 18/8000
COVID-19: 19/8000
COVID-19: 20/8000
COVID-19: 21/8000
COVID-19: 22/8000
COVID-19: 23/8000
COVID-19: 24/8000
COVID-19: 25/8000
COVID-19: 26/8000
COVID-19: 27/8000
COVID-19: 28/8000
COVID-19: 29/8000
COVID-19: 30/8000
COVID-19: 31/8000
COVID-19: 32/8000
COVID-19: 33/8000
COVID-19: 34/8000
COVID-19: 35/8000
COVID-19: 36/8000
COVID-19: 37/8000
COVID-19: 38/8000
COVID-19: 39/8000
COVID-19: 40/8000
COVID-19: 41/8000
COVID-19: 42/8000
COVID-19: 43/8000
COVID-19: 44/8000
COVID-19: 45/8000
COVID-19: 46/8000
COVID-19: 47/8000
COVID-19: 48/8000
COVID-19: 49/8000
COVID-19: 50/8000
COVID-19: 51/8000
COVID-19: 52/8000
COVID-19: 53/8000
COVID-19: 54/8000
COVID-19: 55/8000
COVID-19: 56/8000
C