In [None]:
# Mount Google Drive so checkpoints, saved models and sample grids survive the Colab session.
from google.colab import drive

drive.mount(mountpoint='/content/drive')



In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import os
from IPython.display import display
from tqdm import tqdm
from IPython.display import display

In [None]:
# --- Training hyperparameters -------------------------------------------
EPOCHS = 100          # total epochs; the training cell resumes from the last checkpoint
BATCH_SIZE = 64
LATENT_DIM = 128      # length of the generator's noise vector
IMG_SIZE = 64         # images are resized to IMG_SIZE x IMG_SIZE; must stay 64 because
                      # the generator upsamples a fixed 8x8 map three times (8->16->32->64)
NUM_CLASSES = 10      # number of one-hot label classes

# --- WGAN-GP settings ---------------------------------------------------
N_CRITIC = 3          # critic updates per generator update
LAMBDA_GP = 10        # weight of the gradient penalty in the critic loss
LR = 0.0002           # Adam learning rate for both networks

# Seed Python, NumPy and TensorFlow RNGs so a fresh run is repeatable
# (GPU kernels may still introduce small nondeterminism).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# --- Output locations on the mounted Drive ------------------------------
BASE_DIR = "/content/drive/MyDrive/Lab_2_GAN/WGAN_GP"
IMG_DIR = os.path.join(BASE_DIR, "images")
MODEL_DIR = os.path.join(BASE_DIR, "models")
CKPT_DIR = os.path.join(BASE_DIR, "checkpoints")

for _out_dir in (IMG_DIR, MODEL_DIR, CKPT_DIR):
    os.makedirs(_out_dir, exist_ok=True)


In [None]:

 (x_train, y_train), _ = tf.keras.datasets.fashion_mnist.load_data()

x_train = x_train.astype("float32") / 127.5 - 1.0
x_train = np.expand_dims(x_train, axis=-1)
x_train = tf.image.resize(x_train, (IMG_SIZE, IMG_SIZE))

y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)

dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset = dataset.shuffle(60000).batch(BATCH_SIZE, drop_remainder=True)


In [None]:
def build_generator():
    """Build the conditional generator ``G([noise, label]) -> image``.

    Noise and the one-hot label are concatenated, projected to an 8x8x512
    feature map, and upsampled three times with stride-2 transposed
    convolutions (8 -> 16 -> 32 -> 64). A final 3x3 convolution with tanh
    produces a single-channel image in [-1, 1].

    Returns:
        tf.keras.Model with inputs ``[(batch, LATENT_DIM), (batch, NUM_CLASSES)]``
        and output ``(batch, 64, 64, 1)``.
    """
    noise_in = layers.Input(shape=(LATENT_DIM,))
    label_in = layers.Input(shape=(NUM_CLASSES,))

    h = layers.Concatenate()([noise_in, label_in])

    # Project and reshape into the 8x8 seed feature map.
    h = layers.Dense(8 * 8 * 512, use_bias=False)(h)
    h = layers.BatchNormalization()(h)
    h = layers.ReLU()(h)
    h = layers.Reshape((8, 8, 512))(h)

    # Each stage doubles the spatial resolution: 8 -> 16 -> 32 -> 64.
    for n_filters in (256, 128, 64):
        h = layers.Conv2DTranspose(n_filters, 4, strides=2, padding="same", use_bias=False)(h)
        h = layers.BatchNormalization()(h)
        h = layers.ReLU()(h)

    # Output projection only; resolution is already final.
    out_img = layers.Conv2D(1, kernel_size=3, padding="same", activation="tanh")(h)

    return tf.keras.Model([noise_in, label_in], out_img)


In [None]:
def build_discriminator():
    """Build the conditional critic ``D([image, label]) -> score``.

    The one-hot label is mapped by a Dense layer to an IMG_SIZE x IMG_SIZE
    plane and stacked onto the image as a second channel. Four stride-2
    convolutions with LeakyReLU(0.2) follow; no normalisation layers are used.
    The final Dense(1) has no activation, so the score is unbounded.

    Returns:
        tf.keras.Model with inputs ``[(batch, IMG_SIZE, IMG_SIZE, 1), (batch, NUM_CLASSES)]``
        and output ``(batch, 1)``.
    """
    img_in = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 1))
    label_in = layers.Input(shape=(NUM_CLASSES,))

    # Label -> one extra image-sized input channel.
    label_plane = layers.Dense(IMG_SIZE * IMG_SIZE)(label_in)
    label_plane = layers.Reshape((IMG_SIZE, IMG_SIZE, 1))(label_plane)

    h = layers.Concatenate()([img_in, label_plane])

    # Each stride-2 conv halves the spatial resolution.
    for n_filters in (64, 128, 256, 512):
        h = layers.Conv2D(n_filters, 4, strides=2, padding="same")(h)
        h = layers.LeakyReLU(0.2)(h)

    h = layers.Flatten()(h)
    score = layers.Dense(1)(h)

    return tf.keras.Model([img_in, label_in], score)


In [None]:
# Instantiate both networks.
generator = build_generator()
discriminator = build_discriminator()

# One Adam optimiser per network, both with beta_1=0.0 and beta_2=0.9.
g_opt, d_opt = (
    tf.keras.optimizers.Adam(LR, beta_1=0.0, beta_2=0.9) for _ in range(2)
)

# One dummy forward pass through each model so all variables exist up front.
_ = generator(
    [tf.random.normal((1, LATENT_DIM)), tf.one_hot([0], NUM_CLASSES)]
)
_ = discriminator(
    [tf.random.normal((1, IMG_SIZE, IMG_SIZE, 1)), tf.one_hot([0], NUM_CLASSES)]
)


In [None]:
def gradient_penalty(real, fake, labels, critic=None, eps=1e-12):
    """WGAN-GP gradient penalty: mean over the batch of ``(||dD(x_hat, y)/dx_hat||_2 - 1)^2``.

    ``x_hat`` is a per-sample random interpolation between ``real`` and ``fake``.
    The batch size is read from ``real`` instead of assuming ``BATCH_SIZE``.

    Args:
        real: real images, rank-4 ``(batch, H, W, C)``.
        fake: generated images, same shape as ``real``.
        labels: conditioning labels passed to the critic alongside the images.
        critic: model to penalise; defaults to the module-level ``discriminator``.
        eps: added under the square root so the gradient of the norm stays
            finite when the input gradient is exactly zero (sqrt'(0) = inf).

    Returns:
        Scalar tensor: the penalty term (before multiplying by LAMBDA_GP).
    """
    if critic is None:
        critic = discriminator

    batch = tf.shape(real)[0]
    alpha = tf.random.uniform([batch, 1, 1, 1], 0.0, 1.0, dtype=real.dtype)
    interpolated = real * alpha + fake * (1.0 - alpha)

    with tf.GradientTape() as tape:
        tape.watch(interpolated)   # interpolated is a plain tensor, not a Variable
        pred = critic([interpolated, labels], training=True)

    grads = tape.gradient(pred, interpolated)
    norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + eps)
    return tf.reduce_mean(tf.square(norm - 1.0))


In [None]:
@tf.function
def train_step(real_imgs, labels):
    """One WGAN-GP iteration: N_CRITIC critic updates followed by one generator update.

    The same ``real_imgs`` / ``labels`` batch is reused for every critic step
    and for the generator step; fresh noise is drawn each time.

    Args:
        real_imgs: batch of real images (rank 4, channels-last).
        labels: one-hot labels matching ``real_imgs``.

    Returns:
        ``(d_loss, g_loss)``: the critic loss from the last critic step and the
        generator loss, as scalar tensors.
    """
    batch = tf.shape(real_imgs)[0]   # follow the input instead of assuming BATCH_SIZE

    for _ in range(N_CRITIC):   # Python loop: unrolled at trace time
        noise = tf.random.normal([batch, LATENT_DIM])
        # Fakes are generated outside the tape: only critic weights are
        # differentiated here, so recording the generator pass only costs memory.
        fake_imgs = generator([noise, labels], training=True)
        with tf.GradientTape() as tape:
            real_out = discriminator([real_imgs, labels], training=True)
            fake_out = discriminator([fake_imgs, labels], training=True)
            gp = gradient_penalty(real_imgs, fake_imgs, labels)
            d_loss = tf.reduce_mean(fake_out) - tf.reduce_mean(real_out) + LAMBDA_GP * gp

        d_grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_opt.apply_gradients(zip(d_grads, discriminator.trainable_variables))

    noise = tf.random.normal([batch, LATENT_DIM])
    with tf.GradientTape() as tape:
        fake_imgs = generator([noise, labels], training=True)
        fake_out = discriminator([fake_imgs, labels], training=True)
        g_loss = -tf.reduce_mean(fake_out)

    g_grads = tape.gradient(g_loss, generator.trainable_variables)
    g_opt.apply_gradients(zip(g_grads, generator.trainable_variables))

    return d_loss, g_loss


In [None]:

def save_images(epoch, grid_size=5):
    """Sample a grid of generated images, save it as IMG_DIR/epoch_<epoch>.png and show it inline.

    Args:
        epoch: epoch number used in the output file name.
        grid_size: rows and columns of the grid; ``grid_size ** 2`` samples are
            drawn, each conditioned on a uniformly random class in [0, NUM_CLASSES).
    """
    n_samples = grid_size * grid_size
    noise = tf.random.normal([n_samples, LATENT_DIM])
    labels = tf.one_hot(np.random.randint(0, NUM_CLASSES, n_samples), NUM_CLASSES)
    # Generator output is tanh in [-1, 1]; rescale to [0, 1] for imshow.
    imgs = ((generator([noise, labels], training=False) + 1) / 2).numpy()

    # squeeze=False keeps `axs` 2-D even for grid_size == 1.
    fig, axs = plt.subplots(grid_size, grid_size, figsize=(6, 6), squeeze=False)
    for ax, img in zip(axs.flat, imgs):
        ax.imshow(img[:, :, 0], cmap="gray")
        ax.axis("off")

    fig.savefig(os.path.join(IMG_DIR, f"epoch_{epoch}.png"))
    display(fig)
    plt.close(fig)   # close this specific figure so repeated calls don't accumulate figures


In [None]:
# Last completed epoch; stored inside the checkpoint so training can resume.
epoch_var = tf.Variable(0, dtype=tf.int64)

# Everything needed to resume training: both models, both optimisers, the epoch counter.
ckpt = tf.train.Checkpoint(
    generator=generator,
    discriminator=discriminator,
    g_opt=g_opt,
    d_opt=d_opt,
    epoch=epoch_var,
)

# Keeps at most the 5 most recent checkpoints in CKPT_DIR.
ckpt_manager = tf.train.CheckpointManager(ckpt, CKPT_DIR, max_to_keep=5)

if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print(f"✅ Resumed from checkpoint at epoch {int(epoch_var.numpy())}")
else:
    print("🚀 No checkpoint found. Starting from scratch.")


In [None]:

# Resume right after the last epoch recorded in the checkpoint (epoch_var is 0 on a fresh run).
start_epoch = int(epoch_var.numpy()) + 1

for epoch in range(start_epoch, EPOCHS + 1):
    print(f"\nEpoch {epoch}/{EPOCHS}")

    bar = tqdm(dataset, leave=False)
    for real_imgs, labels in bar:
        d_loss, g_loss = train_step(real_imgs, labels)
        bar.set_postfix(D_loss=f"{float(d_loss):.3f}", G_loss=f"{float(g_loss):.3f}")

    # Losses reported here are from the epoch's final batch.
    print(f"Epoch {epoch} | D_loss: {float(d_loss):.3f} | G_loss: {float(g_loss):.3f}")

    if epoch % 5 == 0:
        save_images(epoch)

    # Record the completed epoch before saving so a restart resumes after it.
    epoch_var.assign(epoch)
    ckpt_manager.save()
    print("💾 Checkpoint saved")


In [None]:
# Export final architecture + weights; the .h5 extension selects Keras' HDF5 format.
generator.save(os.path.join(MODEL_DIR, "generator_final.h5"))
discriminator.save(os.path.join(MODEL_DIR, "discriminator_final.h5"))
print("✅ Final models saved")


In [None]:
# Restore latest checkpoint for evaluation
ckpt.restore(ckpt_manager.latest_checkpoint)

print("Loaded checkpoint at epoch:", int(epoch_var.numpy()))


In [None]:
# Optional manual override of the epoch counter, used to label the evaluation figure.
# Set to None to keep the value restored from the checkpoint.
# NOTE: epoch_var is tracked by `ckpt`, so a later ckpt_manager.save() persists the
# overridden value, and re-running the training cell resumes from EPOCH_OVERRIDE + 1.
EPOCH_OVERRIDE = 55   # put your real trained epoch here

if EPOCH_OVERRIDE is not None:
    _restored = int(epoch_var.numpy())
    if _restored != EPOCH_OVERRIDE:
        print(f"Overriding restored epoch {_restored} -> {EPOCH_OVERRIDE}")
    epoch_var.assign(EPOCH_OVERRIDE)
print("Epoch corrected to:", int(epoch_var.numpy()))


In [None]:
# Load real images
(x_train, _), _ = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train.astype("float32") / 127.5 - 1.0
x_train = np.expand_dims(x_train, axis=-1)
x_train = tf.image.resize(x_train, (IMG_SIZE, IMG_SIZE))

real_imgs = (x_train[:16] + 1) / 2

# Generate fake images
noise = tf.random.normal((16, LATENT_DIM))
labels = tf.one_hot(np.random.randint(0, 10, 16), NUM_CLASSES)
fake_imgs = (generator([noise, labels], training=False) + 1) / 2

plt.figure(figsize=(12,6))

# REAL
for i in range(16):
    plt.subplot(4,8,i+1)
    plt.imshow(real_imgs[i,:,:,0], cmap="gray")
    plt.title("Real", fontsize=8)
    plt.axis("off")

# FAKE
for i in range(16):
    plt.subplot(4,8,16+i+1)
    plt.imshow(fake_imgs[i,:,:,0], cmap="gray")
    plt.title("Fake", fontsize=8)
    plt.axis("off")

plt.suptitle(
    f"Real vs Fake Images (WGAN-GP, Epoch {int(epoch_var.numpy())})",
    fontsize=14
)
plt.tight_layout()
plt.show()
