In [None]:
import os
import glob
import random
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import datetime


# Run tf.function-decorated code (such as train_step) as compiled graphs, which is
# TensorFlow's default behavior; this call only makes that choice explicit.
# For step-by-step debugging, pass True instead so tf.function bodies run eagerly;
# tf.data.experimental.enable_debug_mode() does the same for tf.data map functions.
tf.config.run_functions_eagerly(False)

In [None]:
# ===========================
#        CONFIGURATION
# ===========================
AUTOTUNE = tf.data.AUTOTUNE
IMG_WIDTH = 256
IMG_HEIGHT = 256
BATCH_SIZE = 2
EPOCHS = 50
LAMBDA_CYCLE = 15.0                   # weight of the L1 cycle-consistency loss
LAMBDA_IDENTITY = 0.5 * LAMBDA_CYCLE  # weight of the L1 identity loss
SEED = 42                             # seeds Python, NumPy and TensorFlow RNGs below

# These paths appear to assume a mounted Google Drive (Colab) -- adjust as needed.
DAY_PATH = '/content/drive/MyDrive/dataset/day'
NIGHT_PATH = '/content/drive/MyDrive/dataset/night'
CHECKPOINT_DIR = "/content/drive/MyDrive/checkpoints"
OUTPUT_DIR = "/content/drive/MyDrive/generated_images"
LOG_DIR = "/content/drive/MyDrive/logs"  # NOTE(review): created but not written to by any visible cell

# Seed every RNG the notebook relies on (weight init, file shuffling, augmentation)
# so runs are as repeatable as the backend allows.
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Create every output location up front, including the checkpoint directory.
for _output_dir in (CHECKPOINT_DIR, OUTPUT_DIR, LOG_DIR):
    os.makedirs(_output_dir, exist_ok=True)


In [None]:
# ===========================
#        DATA AUGMENTATION
# ===========================
def random_jitter(image, pad=30):
    """Apply random geometric and photometric augmentation to a single image.

    Steps: resize to (IMG_HEIGHT + pad, IMG_WIDTH + pad), randomly crop back to
    (IMG_HEIGHT, IMG_WIDTH), random horizontal flip, brightness (+/-0.2) and
    contrast (0.8-1.2) jitter, a random rotation by a multiple of 90 degrees,
    and saturation (0.8-1.2) jitter.

    Args:
        image: float32 image tensor of shape [height, width, 3]. preprocess_image
            passes the output of convert_image_dtype, i.e. values in [0, 1].
        pad: extra pixels added to each spatial dimension before the random crop
            (default 30, the previously hard-coded value).

    Returns:
        float32 tensor of shape [IMG_HEIGHT, IMG_WIDTH, 3] with values in [0, 1].
    """
    # tf.image.resize and random_crop both take sizes in (height, width) order.
    image = tf.image.resize(image, [IMG_HEIGHT + pad, IMG_WIDTH + pad])
    image = tf.image.random_crop(image, [IMG_HEIGHT, IMG_WIDTH, 3])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    # Brightness/contrast shifts can leave the valid [0, 1] range; clip so the
    # saturation op (which works via HSV) and the losses see real pixel values.
    image = tf.clip_by_value(image, 0.0, 1.0)
    if IMG_HEIGHT == IMG_WIDTH:
        k = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    else:
        # A 90/270 degree turn would swap height and width on non-square sizes
        # and break batching, so only allow 0 or 180 degrees there.
        k = 2 * tf.random.uniform([], minval=0, maxval=2, dtype=tf.int32)
    image = tf.image.rot90(image, k=k)
    image = tf.image.random_saturation(image, 0.8, 1.2)
    return tf.clip_by_value(image, 0.0, 1.0)

# ===========================
#     DATASET LOADER
# ===========================
def preprocess_image(image_path):
    """Read one JPEG file and return a randomly augmented float32 training image.

    Args:
        image_path: path to a JPEG file -- a scalar string tensor when called via
            Dataset.map from load_dataset, or a plain Python str.

    Returns:
        The decoded image after random_jitter. convert_image_dtype rescales the
        decoded uint8 pixels to float32 in [0, 1]; the final spatial size is set
        by the resize-and-crop inside random_jitter.
    """
    encoded = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(encoded, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    # Resizing is left to random_jitter, which scales to its own padded size
    # before cropping; resizing here as well would interpolate twice.
    return random_jitter(image)

def load_dataset(day_path, night_path):
    """Build an unpaired (day, night) dataset of augmented image batches.

    Each folder's '*.jpg' files are listed in shuffled order, decoded and
    augmented with preprocess_image in parallel, batched to BATCH_SIZE and
    prefetched. The two streams are zipped, so iteration stops when the
    shorter one is exhausted.

    Args:
        day_path: directory containing day-domain JPEG files.
        night_path: directory containing night-domain JPEG files.

    Returns:
        tf.data.Dataset yielding (day_batch, night_batch) tuples.
    """
    day_files, night_files = (
        tf.data.Dataset.list_files(folder + '/*.jpg', shuffle=True)
        for folder in (day_path, night_path)
    )

    def _to_batches(files):
        # decode + augment -> batch -> overlap input work with training
        return (
            files
            .map(preprocess_image, num_parallel_calls=AUTOTUNE)
            .batch(BATCH_SIZE)
            .prefetch(AUTOTUNE)
        )

    return tf.data.Dataset.zip((_to_batches(day_files), _to_batches(night_files)))


In [None]:
# ===========================
#       MODEL DEFINITIONS
# ===========================
def generator_model():
    """Build the image-to-image generator used for both translation directions.

    Layout: 7x7 conv (64) -> two stride-2 convs (128, 256) -> nine residual
    blocks, each one 3x3 conv (256) plus an identity skip -> two stride-2
    transposed convs (128, 64) -> 7x7 conv to 3 channels with tanh. All hidden
    convs use ReLU; there are no normalization layers.

    Returns:
        keras.Model taking [batch, IMG_HEIGHT, IMG_WIDTH, 3] images. The output
        has the same spatial size as the input when both dimensions are
        divisible by 4 (true for the configured 256x256).
    """
    # Keras image inputs are (height, width, channels).
    inputs = layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, 3])
    x = layers.Conv2D(64, (7, 7), padding='same', activation='relu')(inputs)

    # Encoder: halve the spatial resolution twice.
    for filters in (128, 256):
        x = layers.Conv2D(filters, (3, 3), strides=2, padding='same', activation='relu')(x)

    # Residual trunk at 1/4 resolution.
    for _ in range(9):
        skip = x
        x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x)
        x = layers.Add()([x, skip])

    # Decoder: double the spatial resolution twice.
    for filters in (128, 64):
        x = layers.Conv2DTranspose(filters, (3, 3), strides=2, padding='same', activation='relu')(x)

    # NOTE(review): reference CycleGAN generators use instance normalization and
    # two convs per residual block; changing the layer stack here would likely
    # make existing checkpoints unrestorable, so it is left as-is.
    outputs = layers.Conv2D(3, (7, 7), padding='same', activation='tanh')(x)
    return keras.Model(inputs, outputs)

def discriminator_model():
    """Build a fully convolutional discriminator that emits a map of patch scores.

    Three stride-2 stages (64, 128, 256 filters) and one stride-1 stage (512),
    all 4x4 convs with LeakyReLU(0.2), with BatchNormalization on every stage
    except the first. A final 4x4 conv produces one unactivated score per
    spatial position.

    Returns:
        keras.Model mapping [batch, IMG_HEIGHT, IMG_WIDTH, 3] images to
        [batch, ceil(IMG_HEIGHT/8), ceil(IMG_WIDTH/8), 1] raw scores, which
        train_step compares against 1/0 targets with mean squared error.
    """
    # Keras image inputs are (height, width, channels).
    inputs = layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, 3])

    # First stage is not normalized.
    x = layers.Conv2D(64, 4, strides=2, padding='same')(inputs)
    x = layers.LeakyReLU(0.2)(x)

    # NOTE(review): BatchNormalization with BATCH_SIZE=2, and with real and fake
    # batches passed in separate calls, gives noisy statistics; swapping it out
    # would change the checkpointed variables, so it is kept.
    for filters, stride in ((128, 2), (256, 2), (512, 1)):
        x = layers.Conv2D(filters, 4, strides=stride, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU(0.2)(x)

    x = layers.Conv2D(1, 4, strides=1, padding='same')(x)
    return keras.Model(inputs, x)

# One generator per translation direction and one discriminator per domain.
G_day_to_night = generator_model()
G_night_to_day = generator_model()
D_day = discriminator_model()
D_night = discriminator_model()

# Learning rate starts at 2e-4 and is multiplied by 0.98 every 10,000 optimizer
# steps (continuous decay, since staircase keeps its default). The schedule
# object is shared; each optimizer evaluates it at its own step count.
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.0002,
    decay_steps=10000,
    decay_rate=0.98,
)

# Four independent Adam instances (beta_1=0.5), created in the same order as the
# networks above: G day->night, G night->day, D day, D night.
(optimizer_G_day_to_night,
 optimizer_G_night_to_day,
 optimizer_D_day,
 optimizer_D_night) = (
    keras.optimizers.Adam(learning_rate=lr_schedule, beta_1=0.5) for _ in range(4)
)




In [None]:
# === Checkpoint Setup ===
# NOTE(review): checkpoint_prefix is not referenced by any visible cell; the
# CheckpointManager below handles file naming on its own.
checkpoint_prefix = os.path.join(CHECKPOINT_DIR, "ckpt")

# Everything saved and restored together: the four networks plus their optimizers.
# These names become the keys inside the checkpoint, so renaming any of them
# would break restoring checkpoints written earlier.
checkpoint = tf.train.Checkpoint(**{
    "G_day_to_night": G_day_to_night,
    "G_night_to_day": G_night_to_day,
    "D_day": D_day,
    "D_night": D_night,
    "optimizer_G_day_to_night": optimizer_G_day_to_night,
    "optimizer_G_night_to_day": optimizer_G_night_to_day,
    "optimizer_D_day": optimizer_D_day,
    "optimizer_D_night": optimizer_D_night,
})

# Writes "<CHECKPOINT_DIR>/ckpt-<number>" checkpoints and keeps the five newest.
checkpoint_manager = tf.train.CheckpointManager(
    checkpoint,
    directory=CHECKPOINT_DIR,
    max_to_keep=5,
)


In [None]:
# Mean squared error against 1/0 targets, i.e. a least-squares GAN objective for
# both the generators' adversarial terms and the discriminators.
loss_obj = keras.losses.MeanSquaredError()

@tf.function
def train_step(real_day, real_night):
    """Run one CycleGAN optimization step and update all four networks once.

    Args:
        real_day: batch of day images from load_dataset (presumably float32
            [batch, H, W, 3] with values around [0, 1] from preprocess_image --
            TODO confirm).
        real_night: batch of night images in the same format.

    Returns:
        (total_gen_day_to_night_loss, total_gen_night_to_day_loss,
         disc_night_loss, disc_day_loss) -- note the discriminator losses are
        returned night first, then day.
    """
    # persistent=True because tape.gradient is called four times below.
    with tf.GradientTape(persistent=True) as tape:
        # Translations in each direction.
        fake_night = G_day_to_night(real_day, training=True)
        fake_day = G_night_to_day(real_night, training=True)

        # Round trips back to the source domain (for the cycle loss).
        cycled_day = G_night_to_day(fake_night, training=True)
        cycled_night = G_day_to_night(fake_day, training=True)

        # Each generator applied to an image already in its target domain (identity loss).
        same_day = G_night_to_day(real_day, training=True)
        same_night = G_day_to_night(real_night, training=True)

        # Each training=True call also updates the discriminators'
        # BatchNormalization moving statistics, so the call order here matters.
        disc_real_day = D_day(real_day, training=True)
        disc_real_night = D_night(real_night, training=True)
        disc_fake_day = D_day(fake_day, training=True)
        disc_fake_night = D_night(fake_night, training=True)

        # Generator adversarial losses: push the discriminator's scores on fakes toward 1.
        gen_day_to_night_loss = loss_obj(tf.ones_like(disc_fake_night), disc_fake_night)
        gen_night_to_day_loss = loss_obj(tf.ones_like(disc_fake_day), disc_fake_day)

        # Cycle consistency loss (L1), shared by both generators since each round trip uses both.
        total_cycle_loss = tf.reduce_mean(tf.abs(real_day - cycled_day)) + tf.reduce_mean(tf.abs(real_night - cycled_night))

        # Identity loss (L1). Both directions are summed, but each generator only
        # receives gradient from the term it produced (same_day / same_night).
        identity_loss = tf.reduce_mean(tf.abs(real_day - same_day)) + tf.reduce_mean(tf.abs(real_night - same_night))

        # Total generator losses
        total_gen_day_to_night_loss = gen_day_to_night_loss + LAMBDA_CYCLE * total_cycle_loss + LAMBDA_IDENTITY * identity_loss
        total_gen_night_to_day_loss = gen_night_to_day_loss + LAMBDA_CYCLE * total_cycle_loss + LAMBDA_IDENTITY * identity_loss

        # Discriminator losses: real -> 1, fake -> 0. Fakes are not detached, but
        # gradients below are taken only w.r.t. discriminator variables.
        # NOTE(review): many CycleGAN implementations scale this sum by 0.5; confirm the unscaled form is intended.
        disc_day_loss = loss_obj(tf.ones_like(disc_real_day), disc_real_day) + loss_obj(tf.zeros_like(disc_fake_day), disc_fake_day)
        disc_night_loss = loss_obj(tf.ones_like(disc_real_night), disc_real_night) + loss_obj(tf.zeros_like(disc_fake_night), disc_fake_night)

    # Calculate gradients (each loss w.r.t. its own network only)
    grads_G_day_to_night = tape.gradient(total_gen_day_to_night_loss, G_day_to_night.trainable_variables)
    grads_G_night_to_day = tape.gradient(total_gen_night_to_day_loss, G_night_to_day.trainable_variables)
    grads_D_day = tape.gradient(disc_day_loss, D_day.trainable_variables)
    grads_D_night = tape.gradient(disc_night_loss, D_night.trainable_variables)

    # Apply gradients; all four updates use gradients from the same forward pass.
    optimizer_G_day_to_night.apply_gradients(zip(grads_G_day_to_night, G_day_to_night.trainable_variables))
    optimizer_G_night_to_day.apply_gradients(zip(grads_G_night_to_day, G_night_to_day.trainable_variables))
    optimizer_D_day.apply_gradients(zip(grads_D_day, D_day.trainable_variables))
    optimizer_D_night.apply_gradients(zip(grads_D_night, D_night.trainable_variables))

    return total_gen_day_to_night_loss, total_gen_night_to_day_loss, disc_night_loss, disc_day_loss


In [None]:
# === Restore if checkpoint exists ===
start_epoch = 0
restored_ckpt = checkpoint_manager.latest_checkpoint  # falsy when no checkpoint exists yet
if restored_ckpt:
    # Restores the networks and optimizers tracked by `checkpoint`.
    checkpoint.restore(restored_ckpt)
    print("Checkpoint restored from:", restored_ckpt)

    # The training loop saves with checkpoint_number=epoch + 1, so the numeric
    # suffix after the last '-' is the number of completed epochs.
    ckpt_number = restored_ckpt.rpartition('-')[2]
    if ckpt_number.isdigit():
        start_epoch = int(ckpt_number)
    print(f"Resuming from epoch {start_epoch+1}")


In [None]:
# ===========================
#        TRAINING LOOP
# ===========================
day_files = sorted(glob.glob(DAY_PATH + '/*.jpg'))
num_day = len(day_files)
num_night = len(glob.glob(NIGHT_PATH + '/*.jpg'))
steps_per_epoch = min(num_day, num_night) // BATCH_SIZE
# Fail fast: with zero steps no training happens, yet every epoch would still
# save a checkpoint, and the loss logging below would hit undefined names.
if steps_per_epoch == 0:
    raise ValueError(
        f"Need at least {BATCH_SIZE} .jpg images in each of {DAY_PATH!r} and "
        f"{NIGHT_PATH!r}; found {num_day} day and {num_night} night images."
    )

dataset = load_dataset(DAY_PATH, NIGHT_PATH)

# Fixed, unaugmented preview input. preprocess_image applies random crops, flips,
# colour jitter and rotations, which would make epoch-to-epoch previews
# incomparable, so the first day image is decoded directly and only resized.
preview_day = tf.image.decode_jpeg(tf.io.read_file(day_files[0]), channels=3)
preview_day = tf.image.convert_image_dtype(preview_day, tf.float32)
preview_day = tf.image.resize(preview_day, [IMG_HEIGHT, IMG_WIDTH])[None, ...]

# === Training Loop ===
for epoch in range(start_epoch, EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    # A fresh iterator each epoch, capped at steps_per_epoch (day, night) batch pairs.
    epoch_batches = dataset.repeat().take(steps_per_epoch)
    for day_img, night_img in tqdm(epoch_batches, total=steps_per_epoch, desc=f"Epoch {epoch+1}"):
        G_loss, F_loss, D_night_loss, D_day_loss = train_step(day_img, night_img)

    # Save checkpoint after each epoch; the number is the count of completed
    # epochs, which the restore cell uses as its resume point.
    checkpoint_manager.save(checkpoint_number=epoch + 1)
    print(f"Checkpoint saved at epoch {epoch+1}")

    if (epoch + 1) % 5 == 0:
        # Losses reported are from the epoch's last batch, not epoch averages.
        print(f"Epoch {epoch+1}: G_loss={G_loss.numpy()}, F_loss={F_loss.numpy()} | D_night_loss={D_night_loss.numpy()}, D_day_loss={D_day_loss.numpy()}")

        prediction = G_day_to_night(preview_day, training=False)[0].numpy()

        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        axs[0].imshow(preview_day[0].numpy())
        axs[0].set_title("Input Day")
        # Training images are in [0, 1] (convert_image_dtype), and the cycle and
        # identity L1 targets are those images, so the generator is trained to emit
        # [0, 1] even though tanh allows [-1, 1]. Show its output on the same scale
        # as the input rather than remapping it with (x + 1) / 2.
        axs[1].imshow(np.clip(prediction, 0.0, 1.0))
        axs[1].set_title("Predicted Night")
        fig.savefig(os.path.join(OUTPUT_DIR, f"epoch_{epoch+1}.png"))
        plt.close(fig)

In [None]:
# ===========================
#        TESTING CODE
# ===========================
# To test the model with a custom input image, set sample_test_img below.
# If that file does not exist the cell prints a message and skips, so
# "Run All" still completes on a fresh kernel.


def generate_night_image(input_path, output_path):
    """Translate one day image to night with G_day_to_night and save the result.

    The input is decoded and resized like the training data (float32 in [0, 1],
    IMG_HEIGHT x IMG_WIDTH) but without random augmentation. The generator output
    is clipped to [0, 1] and written as 8-bit RGB.

    Args:
        input_path: path to an image file readable by tf.io.decode_image
            (JPEG, PNG, BMP or GIF; only the first GIF frame is used).
        output_path: destination file; PIL infers the format from its extension.

    Returns:
        The generated image as a float32 NumPy array of shape (IMG_HEIGHT, IMG_WIDTH, 3).
    """
    image = tf.io.decode_image(tf.io.read_file(input_path), channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [IMG_HEIGHT, IMG_WIDTH])
    prediction = G_day_to_night(image[None, ...], training=False)[0].numpy()
    prediction = np.clip(prediction, 0.0, 1.0)
    Image.fromarray((prediction * 255).round().astype(np.uint8)).save(output_path)
    return prediction


sample_test_img = "your_image.jpg"  # Replace with your test image path
test_output_path = "test_output.jpg"

if not os.path.exists(sample_test_img):
    print(f"Skipping test: {sample_test_img!r} not found; set sample_test_img to an image path.")
else:
    generate_night_image(sample_test_img, test_output_path)

    # Visualization of test image. convert() returns a loaded copy, so the
    # files can be closed before plotting.
    with Image.open(sample_test_img) as img:
        test_input_img = img.convert("RGB")
    with Image.open(test_output_path) as img:
        test_output_img = img.convert("RGB")

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    ax[0].imshow(test_input_img)
    ax[0].set_title("Test Input Image")
    ax[0].axis("off")

    ax[1].imshow(test_output_img)
    ax[1].set_title("Generated Night Image")
    ax[1].axis("off")

    plt.show()