In [4]:
# example of a dcgan on cifar10
from numpy import ones
from numpy.random import randint, randn
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import sqrtm
from sklearn.metrics import precision_score, recall_score
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import LeakyReLU, Conv2D, Conv2DTranspose, Embedding, Concatenate
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.datasets.cifar10 import load_data
from tensorflow.keras import backend as K
import tensorflow
import tensorflow as tf
from tensorflow.keras.layers import Lambda, Permute

 #Flattened input
class VBN(tf.keras.layers.Layer):
    def __init__(self,
                 epsilon=1e-5,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 **kwargs):
        
        self.epsilon = epsilon
        self.beta_initializer = beta_initializer
        self.gamma_initializer = gamma_initializer
        self.batch_size = 128
        
        self.ref_mean = None
        self.ref_mean_sq = None
            
        super(VBN, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        #Pas convaincu du -1 mdr
        self.gamma = self.add_weight(name='gamma', 
                                      shape=(input_shape[-1],),
                                      initializer=self.gamma_initializer,
                                      trainable=True)
        
        self.beta = self.add_weight(name='beta', 
                                      shape=(input_shape[-1],),
                                      initializer=self.beta_initializer,
                                      trainable=True)
    
        super(VBN, self).build(input_shape)  # Be sure to call this at the end
    
    def call(self, x, training=None):
        #Getting training value to know if we are in inference or not
        if training is None:
            training = K.learning_phase()
        if isinstance(training, int):
            training = bool(training)       

        #If True, we need to update stats (using ref batch)
        if self.ref_mean is None or self.ref_mean_sq is None:
            self.ref_mean = tf.reduce_mean(x, [0,1,2], keepdims=True, name='new_mean')
            self.ref_mean_sq = tf.reduce_mean(tf.square(x), [0,1,2], keepdims=True, name="new_mean_sq") # 0 1 2
            return self._normalize(x, self.ref_mean, self.ref_mean_sq)
        
        #Else, they are set, so we use them
        new_coeff = 1./(self.batch_size + 1)
        old_coeff = 1 - new_coeff
        new_mean = tf.reduce_mean(x, [1,2], keepdims=True, name='new_mean')
        new_mean_sq = tf.reduce_mean(tf.square(x), [1,2], keepdims=True, name="new_mean_sq")
        mean = new_coeff * new_mean + old_coeff * self.ref_mean
        mean_sq = new_coeff * new_mean_sq + old_coeff * self.ref_mean_sq
        
        out = self._normalize(x, mean, mean_sq)
        self.ref_mean = None
        self.ref_mean_sq = None
        return out
        
    def _normalize(self, x, mean, mean_sq):
        std = tf.sqrt(self.epsilon + mean_sq - tf.square(mean))
        beta = tf.reshape(self.beta,[1,1,1,-1])
        gamma = tf.reshape(self.gamma,[1,1,1,-1])
        
        out = (x - mean) / std
        out = out * gamma
        out = out + beta
        return out
    
    def compute_output_shape(self, input_shape):
        return input_shape[0]
      
    def get_config(self):
        config = super().get_config()
        config['epsilon'] =  self.epsilon
        config['beta_initializer'] = self.beta_initializer
        config["gamma_initializer"] = self.gamma_initializer
        return config
    
# define the standalone discriminator model
def define_discriminator(in_shape=(32,32,3), n_classes=10):
    # Label input
    label_input = Input(shape=(1,))
    li = Embedding(n_classes, 50)(label_input)
    n_nodes = in_shape[0] * in_shape[1]
    li = Dense(n_nodes)(li)
    li = Reshape((in_shape[0], in_shape[1], 1))(li)

    # Image input
    image_input = Input(shape=in_shape)

    # Merge inputs
    merge = Concatenate()([image_input, li])

    # Downsample
    fe = Conv2D(64, (3,3), strides=(2,2), padding='same')(merge)
    fe = LeakyReLU(alpha=0.2)(fe)

    # Downsample
    fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(fe)
    fe = LeakyReLU(alpha=0.2)(fe)

    # Flatten and output
    fe = Flatten()(fe)
    fe = Dropout(0.7)(fe)
    out_layer = Dense(1, activation='sigmoid')(fe)

    # Define model
    model = Model([image_input, label_input], out_layer)
    return model

 
# define the standalone generator model
def define_generator(latent_dim, n_classes=10):
    # Label input
    label_input = Input(shape=(1,))
    li = Embedding(n_classes, 50)(label_input)
    n_nodes = 8 * 8
    li = Dense(n_nodes)(li)
    li = Reshape((8, 8, 1))(li)

    # Latent input
    latent_input = Input(shape=(latent_dim,))
    gen = Dense(256 * 8 * 8)(latent_input)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = Reshape((8, 8, 256))(gen)

    # Merge inputs
    merge = Concatenate()([gen, li])

    # Upsample to 16x16
    gen = Conv2DTranspose(256, (4,4), strides=(2,2), padding='same')(merge)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = VBN()(gen)

    # Upsample to 32x32
    gen = Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')(gen)
    gen = LeakyReLU(alpha=0.2)(gen)
    gen = VBN()(gen)

    # Output layer 32x32x3
    out_layer = Conv2D(3, (3,3), activation='tanh', padding='same')(gen)

    # Define model
    model = Model([latent_input, label_input], out_layer)
    return model
    
    
# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model):
    d_model.trainable = False
    gen_noise, gen_label = g_model.input
    gen_output = g_model.output
    gan_output = d_model([gen_output, gen_label])
    model = Model([gen_noise, gen_label], gan_output)
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model
    
# load and prepare cifar10 training images
def load_real_samples():
    (trainX, trainy), (_, _) = load_data()
    X = trainX.astype('float32')
    X = (X - 127.5) / 127.5
    return [X, trainy]

# select real samples
def generate_real_samples(dataset, n_samples):
    # Verify that dataset is a list or tuple with two items
    if not isinstance(dataset, (list, tuple)) or len(dataset) != 2:
        raise ValueError("Expected dataset to be a list or tuple of [images, labels]")

    images, labels = dataset
    ix = randint(0, images.shape[0], n_samples)
    X, labels = images[ix], labels[ix]
    y = ones((n_samples, 1))
    return [X, labels], y

# use the generator to generate n fake examples, with class labels
def generate_fake_samples(g_model, latent_dim, n_samples, n_classes=10):
    x_input = randn(latent_dim * n_samples)
    z_input = x_input.reshape(n_samples, latent_dim)
    labels = randint(0, n_classes, n_samples)
    images = g_model.predict([z_input, labels])
    y = np.zeros((n_samples, 1))
    return [images, labels], y

# generate points in latent space as input for the generator
def generate_latent_points(latent_dim, n_samples):
    x_input = randn(latent_dim * n_samples)
    z_input = x_input.reshape(n_samples, latent_dim)
    labels = randint(0, 10, n_samples)
    return [z_input, labels]

# create and save a plot of generated images
def save_plot(examples, epoch, n=7):
    # scale from [-1,1] to [0,1]
    examples = (examples + 1) / 2.0
    # plot images
    for i in range(n * n):
        # define subplot
        plt.subplot(n, n, 1 + i)
        # turn off axis
        plt.axis('off')
        # plot raw pixel data
        plt.imshow(examples[i])
        # save plot to file
        filename = 'generated_plot_e%03d.png' % (epoch+1)
        plt.savefig(filename)
        plt.close()

def calculate_fid(dataset, generated_images):
    real_images = dataset[0]  # Extracting real images from dataset
    generated_images = np.array(generated_images[0]) 
    # Flatten images to 2D
    real_images_2d = real_images.reshape(real_images.shape[0], -1)
    generated_images_2d = generated_images.reshape(generated_images.shape[0], -1)

    # Calculate the mean and covariance of real and generated images
    mu1, sigma1 = real_images_2d.mean(axis=0), np.cov(real_images_2d, rowvar=False)
    mu2, sigma2 = generated_images_2d.mean(axis=0), np.cov(generated_images_2d, rowvar=False)

    # Calculate the sum of the squared difference of the means
    ssdiff = np.sum((mu1 - mu2)**2.0)

    # Calculate sqrt of product of covariances
    covmean = sqrtm(sigma1.dot(sigma2))
    
    # Check for imaginary numbers and convert to real
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    # Calculate FID
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid


def calculate_precision_recall(dataset, generated_images):
    real_images = dataset[0]  # Extracting real images from dataset

    # Ensure generated_images is a numpy array
    generated_images = np.array(generated_images[0])  # Extracting only images from generated samples

    # Ensure the same number of samples in both sets
    min_samples = min(real_images.shape[0], generated_images.shape[0])
    real_images = real_images[:min_samples]
    generated_images = generated_images[:min_samples]

    # Flatten images
    real_images_2d = real_images.reshape(min_samples, -1)
    generated_images_2d = generated_images.reshape(min_samples, -1)

    # Binarize images (assuming values are in 0-255 range)
    threshold = 128
    real_images_binarized = (real_images_2d > threshold).astype(int)
    generated_images_binarized = (generated_images_2d > threshold).astype(int)

    # Calculate precision and recall
    precision = precision_score(real_images_binarized, generated_images_binarized, average='micro', zero_division=0)
    recall = recall_score(real_images_binarized, generated_images_binarized, average='micro', zero_division=0)

    return precision, recall

# evaluate the discriminator, plot generated images, save generator model
def summarize_performance(g_model, d_model, dataset, latent_dim, n_samples=100):
    # Prepare real samples
    [X_real, labels_real], y_real = generate_real_samples(dataset, n_samples)
    # Evaluate discriminator on real examples
    _, acc_real = d_model.evaluate([X_real, labels_real], y_real, verbose=0)
    
    # Prepare fake examples
    [x_fake, labels_fake], y_fake = generate_fake_samples(g_model, latent_dim, n_samples)
    # Evaluate discriminator on fake examples
    _, acc_fake = d_model.evaluate([x_fake, labels_fake], y_fake, verbose=0)
    
    # Summarize discriminator performance
    print(f'>Accuracy | real: {acc_real*100:.0f}%, fake: {acc_fake*100:.0f}%')
    return acc_real, acc_fake
    
# train the generator and discriminator
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs, n_batch):
    metrics = {'acc_real': [], 'acc_fake': [], 'd_loss': [], 'g_loss': [], 'fid': [], 'precision': [], 'recall': []}
    bat_per_epo = int(dataset[0].shape[0] / n_batch)
    half_batch = int(n_batch / 2)
    for i in range(n_epochs):
        for j in range(bat_per_epo):
            [X_real, labels_real], y_real = generate_real_samples(dataset, half_batch)
            d_loss1, _ = d_model.train_on_batch([X_real, labels_real], y_real)
            [X_fake, labels_fake], y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            d_loss2, _ = d_model.train_on_batch([X_fake, labels_fake], y_fake)
            [z_input, labels_input] = generate_latent_points(latent_dim, n_batch)
            y_gan = np.ones((n_batch, 1))
            g_loss = gan_model.train_on_batch([z_input, labels_input], y_gan)
            print(f'Epoch: {i+1}/{n_epochs}, Batch: {j+1}/{bat_per_epo}, D loss: {d_loss1:.3f}, G loss: {g_loss:.3f}')
        # evaluate
        if (i+1) % 5 == 0:
            acc_real, acc_fake = summarize_performance(g_model, d_model, dataset, latent_dim)
            generated_fake_images, _ = generate_fake_samples(g_model, latent_dim, 100)
            fid = calculate_fid(dataset, generated_fake_images)
            precision, recall = calculate_precision_recall(dataset, generated_fake_images)
            metrics['acc_real'].append(acc_real)
            metrics['acc_fake'].append(acc_fake)
            metrics['d_loss'].append((d_loss1 + d_loss2) / 2)
            metrics['g_loss'].append(g_loss)
            metrics['fid'].append(fid)
            metrics['precision'].append(precision)
            metrics['recall'].append(recall)
        if (i+1) % 10 == 0:
            d_model.save_weights('model/discriminator_model_virtual.h5')
            g_model.save_weights('model/generator_model_virtual.h5')  
    return metrics

In [5]:
def plot_metrics(metrics, epochs):
    # Plotting Loss
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, metrics['acc_real'], label='Accuracy Real')
    plt.plot(epochs, metrics['acc_fake'], label='Accuracy Fake')
    plt.title('Discriminator Accuracy During Training')
    plt.xlabel('Eval')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()

    plt.figure(figsize=(10, 6))
    plt.plot(epochs, metrics['d_loss'], label='Discriminator Loss')
    plt.plot(epochs, metrics['g_loss'], label='Generator Loss')
    plt.title('Discriminator and Generator Loss During Training')
    plt.xlabel('Eval')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    # Plotting FID, sFID, Class-aware-FID, and MiFID
    plt.figure(figsize=(10, 6))
    plt.plot(epochs, metrics['fid'], label='FID')
    plt.title('FID During Training')
    plt.xlabel('Eval')
    plt.ylabel('Score')
    plt.legend()
    plt.show()


In [6]:
# size of the latent space
latent_dim = 128
# create the discriminator
d_model = define_discriminator()
d_model.compile(loss='binary_crossentropy', optimizer=SGD(learning_rate=0.0002, momentum=0.5), metrics=['accuracy'])

# create the generator
g_model = define_generator(latent_dim)
g_model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

# create the gan
gan_model = define_gan(g_model, d_model)
gan_model.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

# load image data
dataset = load_real_samples()
# train model
metrics = train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=200, n_batch=128)

epochs = range(1, len(metrics['d_loss']) + 1)
plot_metrics(metrics, epochs)

TypeError: <tf.Tensor 'vbn_2/new_mean:0' shape=(1, 1, 1, 256) dtype=float32> is out of scope and cannot be used here. Use return values, explicit Python locals or TensorFlow collections to access it.
Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.

<tf.Tensor 'vbn_2/new_mean:0' shape=(1, 1, 1, 256) dtype=float32> was defined here:
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\traitlets\config\application.py", line 1053, in launch_instance
      app.start()
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\ipykernel\kernelapp.py", line 737, in start
      self.io_loop.start()
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\tornado\platform\asyncio.py", line 195, in start
      self.asyncio_loop.run_forever()
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\asyncio\base_events.py", line 603, in run_forever
      self._run_once()
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\asyncio\base_events.py", line 1909, in _run_once
      handle._run()
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\ipykernel\kernelbase.py", line 524, in dispatch_queue
      await self.process_one()
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\ipykernel\kernelbase.py", line 513, in process_one
      await dispatch(*args)
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\ipykernel\kernelbase.py", line 418, in dispatch_shell
      await result
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\ipykernel\kernelbase.py", line 758, in execute_request
      reply_content = await reply_content
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\ipykernel\ipkernel.py", line 426, in do_execute
      res = shell.run_cell(
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\ipykernel\zmqshell.py", line 549, in run_cell
      return super().run_cell(*args, **kwargs)
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\IPython\core\interactiveshell.py", line 3046, in run_cell
      result = self._run_cell(
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\IPython\core\interactiveshell.py", line 3101, in _run_cell
      result = runner(coro)
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\IPython\core\interactiveshell.py", line 3306, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\IPython\core\interactiveshell.py", line 3488, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\IPython\core\interactiveshell.py", line 3548, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\ethan\AppData\Local\Temp\ipykernel_3048\3160050400.py", line 8, in <module>
      g_model = define_generator(latent_dim)
    File "C:\Users\ethan\AppData\Local\Temp\ipykernel_3048\1671772212.py", line 151, in define_generator
      gen = VBN()(gen)
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\keras\engine\base_layer.py", line 1011, in __call__
      return self._functional_construction_call(
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\keras\engine\base_layer.py", line 2498, in _functional_construction_call
      outputs = self._keras_tensor_symbolic_call(
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\keras\engine\base_layer.py", line 2345, in _keras_tensor_symbolic_call
      return self._infer_output_signature(
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\keras\engine\base_layer.py", line 2404, in _infer_output_signature
      outputs = call_fn(inputs, *args, **kwargs)
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\ethan\AppData\Local\Temp\ipykernel_3048\2893464449.py", line 59, in call
      if self.ref_mean is None or self.ref_mean_sq is None:
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\tensorflow\python\autograph\operators\control_flow.py", line 1363, in if_stmt
      _py_if_stmt(cond, body, orelse)
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\tensorflow\python\autograph\operators\control_flow.py", line 1416, in _py_if_stmt
      return body() if cond else orelse()
    File "C:\Users\ethan\AppData\Local\Temp\ipykernel_3048\2893464449.py", line 60, in call
      self.ref_mean = tf.reduce_mean(x, [0,1,2], keepdims=True, name='new_mean')
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\tensorflow\python\util\traceback_utils.py", line 150, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\tensorflow\python\util\dispatch.py", line 1176, in op_dispatch_handler
      return dispatch_target(*args, **kwargs)
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\tensorflow\python\ops\math_ops.py", line 2640, in reduce_mean
      gen_math_ops.mean(
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\tensorflow\python\ops\gen_math_ops.py", line 6285, in mean
      _, _, _op, _outputs = _op_def_library._apply_op_helper(
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\tensorflow\python\framework\op_def_library.py", line 797, in _apply_op_helper
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\tensorflow\python\framework\func_graph.py", line 735, in _create_op_internal
      return super(FuncGraph, self)._create_op_internal(  # pylint: disable=protected-access
    File "c:\Users\ethan\anaconda3\envs\gpu_env2\lib\site-packages\tensorflow\python\framework\ops.py", line 3800, in _create_op_internal
      ret = Operation(

The tensor <tf.Tensor 'vbn_2/new_mean:0' shape=(1, 1, 1, 256) dtype=float32> cannot be accessed from here, because it was defined in FuncGraph(name=vbn_2_scratch_graph, id=2251464527152), which is out of scope.

In [None]:
def generate_images(generator_model, latent_dim, n_samples, labels):
    # Generate points in latent space
    latent_points, sample_labels = generate_latent_points(latent_dim, n_samples)
    # Update latent points with provided labels
    sample_labels[:] = labels
    # Generate images
    images = generator_model.predict([latent_points, sample_labels])
    # Scale from [-1,1] to [0,1]
    images = (images + 1) / 2.0
    # Plot images
    plt.figure(figsize=(10, 10))
    for i in range(n_samples):
        plt.subplot(int(np.sqrt(n_samples)), int(np.sqrt(n_samples)), 1 + i)
        plt.axis('off')
        plt.imshow(images[i])
    plt.show()

# Example usage
generate_images(g_model, latent_dim, n_samples=25, labels = 1)