## Import Necessary Libraries

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import io
from scipy.linalg import sqrtm
import medmnist
from medmnist import INFO

In [2]:
# Ensure reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# Verify GPU availability
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Set GPU memory growth to prevent TensorFlow from allocating all GPU memory
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print("GPU memory growth enabled.")
    except RuntimeError as e:
        print("Error setting memory growth:", e)

# --- Dataset Loading and Preprocessing ---
def load_retinamnist():
    """Load and preprocess the RetinaMNIST dataset from MedMNIST."""
    info = INFO['retinamnist']
    DataClass = getattr(medmnist, info['python_class'])
    train_dataset = DataClass(split='train', download=True)
    x_train = train_dataset.imgs  # Shape: (N, 28, 28)
    x_train = x_train.reshape(-1, 28, 28, 1).astype('float32')
    x_train = (x_train - 127.5) / 127.5  # Normalize to [-1, 1]
    return x_train

Num GPUs Available:  0


In [3]:
x_train = load_retinamnist()
batch_size = 64
dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(10000).batch(batch_size)

In [4]:
# --- Model Definitions ---
def build_generator(noise_dim=100):
    """Build the generator model."""
    model = keras.Sequential([
        layers.Dense(7*7*128, input_dim=noise_dim),
        layers.Reshape((7, 7, 128)),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),
        layers.Conv2DTranspose(64, 4, strides=2, padding='same'),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),
        layers.Conv2DTranspose(32, 4, strides=2, padding='same'),
        layers.BatchNormalization(),
        layers.LeakyReLU(0.2),
        layers.Conv2D(1, 3, padding='same', activation='tanh')
    ])
    return model

def build_discriminator_ls():
    """Build the discriminator for LS-GAN."""
    model = keras.Sequential([
        layers.Conv2D(32, 3, strides=2, padding='same', input_shape=(28, 28, 1)),
        layers.LeakyReLU(0.2),
        layers.Dropout(0.3),
        layers.Conv2D(64, 3, strides=2, padding='same'),
        layers.LeakyReLU(0.2),
        layers.Dropout(0.3),
        layers.Flatten(),
        layers.Dense(1, activation='sigmoid')
    ])
    return model

def build_critic_wgan():
    """Build the critic for WGAN."""
    model = keras.Sequential([
        layers.Conv2D(32, 3, strides=2, padding='same', input_shape=(28, 28, 1)),
        layers.LeakyReLU(0.2),
        layers.Conv2D(64, 3, strides=2, padding='same'),
        layers.LeakyReLU(0.2),
        layers.Flatten(),
        layers.Dense(1)
    ])
    return model

# WGAN-GP uses the same critic architecture
build_critic_wgan_gp = build_critic_wgan

In [5]:
# --- Loss Functions ---
def ls_discriminator_loss(real_output, fake_output):
    """Least squares loss for LS-GAN discriminator."""
    real_loss = tf.reduce_mean(tf.square(real_output - 1))
    fake_loss = tf.reduce_mean(tf.square(fake_output))
    return 0.5 * (real_loss + fake_loss)

def ls_generator_loss(fake_output):
    """Least squares loss for LS-GAN generator."""
    return tf.reduce_mean(tf.square(fake_output - 1))

def wgan_critic_loss(real_output, fake_output):
    """Wasserstein loss for WGAN critic."""
    return tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)

def wgan_generator_loss(fake_output):
    """Wasserstein loss for WGAN generator."""
    return -tf.reduce_mean(fake_output)

def gradient_penalty(critic, real_images, fake_images, batch_size, lambda_gp=10):
    """Compute gradient penalty for WGAN-GP."""
    epsilon = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
    interpolated = epsilon * real_images + (1 - epsilon) * fake_images
    with tf.GradientTape() as gp_tape:
        gp_tape.watch(interpolated)
        critic_interpolated = critic(interpolated, training=True)
    grads = gp_tape.gradient(critic_interpolated, interpolated)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
    gp = tf.reduce_mean((norm - 1.0) ** 2)
    return lambda_gp * gp

def wgan_gp_critic_loss(critic, real_images, fake_images, batch_size, lambda_gp=10):
    """Wasserstein loss with gradient penalty for WGAN-GP critic."""
    real_output = critic(real_images, training=True)
    fake_output = critic(fake_images, training=True)
    c_loss = tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)
    gp = gradient_penalty(critic, real_images, fake_images, batch_size)
    return c_loss + gp

# --- Utility Functions ---
def plot_to_image(figure):
    """Convert a Matplotlib figure to a TensorFlow image tensor for TensorBoard."""
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(figure)
    buf.seek(0)
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    image = tf.expand_dims(image, 0)
    return image

# Load InceptionV3 for evaluation
inception_model = keras.applications.InceptionV3(include_top=True, input_shape=(299, 299, 3))
feature_extractor = keras.Model(inputs=inception_model.input, outputs=inception_model.get_layer('avg_pool').output)

def compute_is(generator, noise_dim, num_samples=1000):
    """Compute Inception Score."""
    noise = tf.random.normal([num_samples, noise_dim])
    generated_images = generator(noise, training=False)
    generated_images = (generated_images + 1) / 2  # Rescale to [0, 1]
    generated_images = tf.image.resize(generated_images, [299, 299])
    generated_images = tf.repeat(generated_images, 3, axis=-1)  # Grayscale to RGB
    preds = inception_model.predict(generated_images, batch_size=32)
    preds = tf.nn.softmax(preds, axis=1)
    kl_div = tf.reduce_sum(preds * (tf.math.log(preds + 1e-10) - tf.math.log(tf.reduce_mean(preds, axis=0) + 1e-10)), axis=1)
    is_score = tf.exp(tf.reduce_mean(kl_div))
    return is_score.numpy()

def compute_fid(generator, real_images, noise_dim, num_samples=1000):
    """Compute Fréchet Inception Distance."""
    noise = tf.random.normal([num_samples, noise_dim])
    generated_images = generator(noise, training=False)
    generated_images = (generated_images + 1) / 2
    generated_images = tf.image.resize(generated_images, [299, 299])
    generated_images = tf.repeat(generated_images, 3, axis=-1)
    
    real_images = real_images[:num_samples]
    real_images = (real_images + 1) / 2
    real_images = tf.image.resize(real_images, [299, 299])
    real_images = tf.repeat(real_images, 3, axis=-1)
    
    real_features = feature_extractor.predict(real_images, batch_size=32)
    generated_features = feature_extractor.predict(generated_images, batch_size=32)
    
    mu_real, sigma_real = np.mean(real_features, axis=0), np.cov(real_features, rowvar=False)
    mu_gen, sigma_gen = np.mean(generated_features, axis=0), np.cov(generated_features, rowvar=False)
    
    diff = mu_real - mu_gen
    covmean = sqrtm(sigma_real.dot(sigma_gen))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    fid = np.sum(diff**2) + np.trace(sigma_real + sigma_gen - 2*covmean)
    return fid

In [6]:
# --- Training Functions ---
def train_ls_gan(generator, discriminator, dataset, epochs=50, batch_size=64, noise_dim=100):
    """Train LS-GAN."""
    d_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    g_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
    log_dir = 'logs/ls_gan'
    summary_writer = tf.summary.create_file_writer(log_dir)
    
    for epoch in range(epochs):
        for real_images in dataset:
            batch_size_current = real_images.shape[0]
            noise = tf.random.normal([batch_size_current, noise_dim])
            with tf.GradientTape() as d_tape, tf.GradientTape() as g_tape:
                generated_images = generator(noise, training=True)
                real_output = discriminator(real_images, training=True)
                fake_output = discriminator(generated_images, training=True)
                d_loss = ls_discriminator_loss(real_output, fake_output)
                g_loss = ls_generator_loss(fake_output)
            d_grads = d_tape.gradient(d_loss, discriminator.trainable_variables)
            g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
            d_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))
            g_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))
        
        with summary_writer.as_default():
            tf.summary.scalar('d_loss', d_loss, step=epoch)
            tf.summary.scalar('g_loss', g_loss, step=epoch)
            summary_writer.flush()
            if epoch % 10 == 0:
                generated_images = generator(tf.random.normal([16, noise_dim]), training=False)
                fig, axes = plt.subplots(4, 4, figsize=(4, 4))
                for i, ax in enumerate(axes.flat):
                    ax.imshow(generated_images[i, :, :, 0] * 0.5 + 0.5, cmap='gray')
                    ax.axis('off')
                tf.summary.image('generated_images', plot_to_image(fig), step=epoch)
        print(f'LS-GAN Epoch {epoch+1}/{epochs}, D_loss: {d_loss:.4f}, G_loss: {g_loss:.4f}')

def train_wgan(generator, critic, dataset, epochs=50, batch_size=64, noise_dim=100, n_critic=5, clip_value=0.01):
    """Train a Wasserstein GAN."""
    g_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
    c_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
    log_dir = 'logs/wgan'
    summary_writer = tf.summary.create_file_writer(log_dir)

    for epoch in range(epochs):
        # Train the critic
        for _ in range(n_critic):
            real_images = next(iter(dataset))
            batch_size_current = real_images.shape[0]
            noise = tf.random.normal([batch_size_current, noise_dim])
            with tf.GradientTape() as c_tape:
                generated_images = generator(noise, training=True)
                real_output = critic(real_images, training=True)
                fake_output = critic(generated_images, training=True)
                c_loss = tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)  # WGAN critic loss
            c_grads = c_tape.gradient(c_loss, critic.trainable_variables)
            c_optimizer.apply_gradients(zip(c_grads, critic.trainable_variables))
            # Clip critic weights
            for var in critic.trainable_variables:
                var.assign(tf.clip_by_value(var, -clip_value, clip_value))

        # Train the generator
        noise = tf.random.normal([batch_size, noise_dim])
        with tf.GradientTape() as g_tape:
            generated_images = generator(noise, training=True)
            fake_output = critic(generated_images, training=True)
            g_loss = -tf.reduce_mean(fake_output)  # WGAN generator loss
        g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))

        # Log losses and print progress
        with summary_writer.as_default():
            tf.summary.scalar('c_loss', c_loss, step=epoch)
            tf.summary.scalar('g_loss', g_loss, step=epoch)
            summary_writer.flush()
        print(f'WGAN Epoch {epoch+1}/{epochs}, C_loss: {c_loss:.4f}, G_loss: {g_loss:.4f}')
        
def train_wgan_gp(generator, critic, dataset, epochs=50, batch_size=64, noise_dim=100, n_critic=5, lambda_gp=10):
    """Train WGAN-GP."""
    g_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
    c_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
    log_dir = 'logs/wgan_gp'
    summary_writer = tf.summary.create_file_writer(log_dir)
    
    for epoch in range(epochs):
        for _ in range(n_critic):
            real_images = next(iter(dataset))
            batch_size_current = real_images.shape[0]
            noise = tf.random.normal([batch_size_current, noise_dim])
            with tf.GradientTape() as c_tape:
                generated_images = generator(noise, training=True)
                c_loss = wgan_gp_critic_loss(critic, real_images, generated_images, batch_size_current, lambda_gp)
            c_grads = c_tape.gradient(c_loss, critic.trainable_variables)
            c_optimizer.apply_gradients(zip(c_grads, critic.trainable_variables))
        
        noise = tf.random.normal([batch_size, noise_dim])
        with tf.GradientTape() as g_tape:
            generated_images = generator(noise, training=True)
            fake_output = critic(generated_images, training=True)
            g_loss = wgan_generator_loss(fake_output)
        g_grads = g_tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(g_grads, generator.trainable_variables))
        
        with summary_writer.as_default():
            tf.summary.scalar('c_loss', c_loss, step=epoch)
            tf.summary.scalar('g_loss', g_loss, step=epoch)
            summary_writer.flush()
            if epoch % 10 == 0:
                generated_images = generator(tf.random.normal([16, noise_dim]), training=False)
                fig, axes = plt.subplots(4, 4, figsize=(4, 4))
                for i, ax in enumerate(axes.flat):
                    ax.imshow(generated_images[i, :, :, 0] * 0.5 + 0.5, cmap='gray')
                    ax.axis('off')
                tf.summary.image('generated_images', plot_to_image(fig), step=epoch)
        print(f'WGAN-GP Epoch {epoch+1}/{epochs}, C_loss: {c_loss:.4f}, G_loss: {g_loss:.4f}')

In [7]:
# --- Main Execution ---
def main():
    noise_dim = 100
    epochs = 50

    # LS-GAN
    print("Training LS-GAN...")
    generator_ls = build_generator(noise_dim)
    discriminator_ls = build_discriminator_ls()
    train_ls_gan(generator_ls, discriminator_ls, dataset, epochs, batch_size, noise_dim)
    is_score_ls = compute_is(generator_ls, noise_dim)
    fid_score_ls = compute_fid(generator_ls, x_train, noise_dim)
    print(f'LS-GAN Evaluation - IS: {is_score_ls:.4f}, FID: {fid_score_ls:.4f}\n')

    # WGAN
    print("Training WGAN...")
    generator_wgan = build_generator(noise_dim)
    critic_wgan = build_critic_wgan()
    train_wgan(generator_wgan, critic_wgan, dataset, epochs, batch_size, noise_dim)
    is_score_wgan = compute_is(generator_wgan, noise_dim)
    fid_score_wgan = compute_fid(generator_wgan, x_train, noise_dim)
    print(f'WGAN Evaluation - IS: {is_score_wgan:.4f}, FID: {fid_score_wgan:.4f}\n')

    # WGAN-GP
    print("Training WGAN-GP...")
    generator_wgan_gp = build_generator(noise_dim)
    critic_wgan_gp = build_critic_wgan_gp()
    train_wgan_gp(generator_wgan_gp, critic_wgan_gp, dataset, epochs, batch_size, noise_dim)
    is_score_wgan_gp = compute_is(generator_wgan_gp, noise_dim)
    fid_score_wgan_gp = compute_fid(generator_wgan_gp, x_train, noise_dim)
    print(f'WGAN-GP Evaluation - IS: {is_score_wgan_gp:.4f}, FID: {fid_score_wgan_gp:.4f}')

In [8]:
main()

Training LS-GAN...


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


LS-GAN Epoch 1/50, D_loss: 0.2582, G_loss: 0.2350
LS-GAN Epoch 2/50, D_loss: 0.2373, G_loss: 0.2742
LS-GAN Epoch 3/50, D_loss: 0.2243, G_loss: 0.2791
LS-GAN Epoch 4/50, D_loss: 0.2132, G_loss: 0.3225
LS-GAN Epoch 5/50, D_loss: 0.2194, G_loss: 0.2662
LS-GAN Epoch 6/50, D_loss: 0.2225, G_loss: 0.3224
LS-GAN Epoch 7/50, D_loss: 0.2158, G_loss: 0.2952
LS-GAN Epoch 8/50, D_loss: 0.2237, G_loss: 0.2924
LS-GAN Epoch 9/50, D_loss: 0.2285, G_loss: 0.2968
LS-GAN Epoch 10/50, D_loss: 0.2193, G_loss: 0.2974
LS-GAN Epoch 11/50, D_loss: 0.2267, G_loss: 0.2769
LS-GAN Epoch 12/50, D_loss: 0.2171, G_loss: 0.2739
LS-GAN Epoch 13/50, D_loss: 0.2264, G_loss: 0.2845
LS-GAN Epoch 14/50, D_loss: 0.2222, G_loss: 0.2985
LS-GAN Epoch 15/50, D_loss: 0.2351, G_loss: 0.2830
LS-GAN Epoch 16/50, D_loss: 0.2193, G_loss: 0.2753
LS-GAN Epoch 17/50, D_loss: 0.2394, G_loss: 0.2894
LS-GAN Epoch 18/50, D_loss: 0.2256, G_loss: 0.2860
LS-GAN Epoch 19/50, D_loss: 0.2316, G_loss: 0.2704
LS-GAN Epoch 20/50, D_loss: 0.2331, G_lo

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


WGAN Epoch 1/50, C_loss: -0.0031, G_loss: 0.0017
WGAN Epoch 2/50, C_loss: -0.0045, G_loss: 0.0015
WGAN Epoch 3/50, C_loss: -0.0052, G_loss: 0.0020
WGAN Epoch 4/50, C_loss: -0.0080, G_loss: 0.0033
WGAN Epoch 5/50, C_loss: -0.0116, G_loss: 0.0057
WGAN Epoch 6/50, C_loss: -0.0155, G_loss: 0.0103
WGAN Epoch 7/50, C_loss: -0.0214, G_loss: 0.0166
WGAN Epoch 8/50, C_loss: -0.0326, G_loss: 0.0248
WGAN Epoch 9/50, C_loss: -0.0421, G_loss: 0.0304
WGAN Epoch 10/50, C_loss: -0.0538, G_loss: 0.0403
WGAN Epoch 11/50, C_loss: -0.0647, G_loss: 0.0493
WGAN Epoch 12/50, C_loss: -0.0775, G_loss: 0.0702
WGAN Epoch 13/50, C_loss: -0.0924, G_loss: 0.1105
WGAN Epoch 14/50, C_loss: -0.1222, G_loss: 0.1186
WGAN Epoch 15/50, C_loss: -0.1368, G_loss: 0.1543
WGAN Epoch 16/50, C_loss: -0.1529, G_loss: 0.1845
WGAN Epoch 17/50, C_loss: -0.1849, G_loss: 0.1829
WGAN Epoch 18/50, C_loss: -0.2103, G_loss: 0.1850
WGAN Epoch 19/50, C_loss: -0.2495, G_loss: 0.2112
WGAN Epoch 20/50, C_loss: -0.2948, G_loss: 0.2568
WGAN Epoc