# In [None]:
# Import necessary libraries
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import Mean
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Input, Concatenate, Conv2D, Conv2DTranspose, LeakyReLU
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import numpy as np
import os

# Constants
BATCH_SIZE = 16  # Images per batch for both the train and test datasets
EPOCHS = 50  # Number of full passes over the training set
LEARNING_RATE = 2e-4  # Adam step size, shared by the generator and discriminator optimizers
LAMBDA = 100  # Weight of the L1 reconstruction term relative to the adversarial term in the generator loss

# Directories
# Each split directory is expected to contain 'satellite_images/' and 'map_images/' subfolders.
TRAIN_DIR = '../data/processed/satellite_to_map/train/'
TEST_DIR = '../data/processed/satellite_to_map/test/'
CHECKPOINT_DIR = './checkpoints/'  # Where the training loop writes tf.train.Checkpoint files

# Function to load and preprocess images
def _load_and_normalize(path, image_size):
    """Read one JPEG from ``path``, resize it to ``image_size`` and scale pixels to [-1, 1]."""
    image = tf.io.read_file(path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, image_size)
    # Map [0, 255] -> [-1, 1] so targets share the range of the generator's tanh output.
    return tf.cast(image, tf.float32) / 127.5 - 1.0


def load_image_pair(satellite_path, map_path, image_size=(256, 256)):
    """Load a (satellite, map) image pair as float32 RGB tensors scaled to [-1, 1].

    Both images go through the same read -> decode -> resize -> normalise
    pipeline, so they always end up with identical shape and value range.

    Args:
        satellite_path: Path (Python string or scalar string tensor) to the satellite JPEG.
        map_path: Path to the corresponding map JPEG.
        image_size: ``(height, width)`` that both images are resized to.
            Defaults to the previously hard-coded ``(256, 256)``.

    Returns:
        Tuple ``(satellite, map_image)``, each a float32 tensor of shape
        ``(height, width, 3)``.
    """
    satellite = _load_and_normalize(satellite_path, image_size)
    map_image = _load_and_normalize(map_path, image_size)
    return satellite, map_image

# Function to build the generator model
def build_generator(input_shape=(256, 256, 3), output_channels=3):
    """Build a U-Net style generator that maps an image to an image of the same size.

    The encoder halves the spatial resolution five times (four skip stages plus
    a bottleneck), so the height and width in ``input_shape`` must be divisible
    by 32. The decoder mirrors it with five stride-2 transposed convolutions:
    each of the first four is concatenated with the encoder feature map of the
    same resolution (skip connection), and the last one restores the input size.

    Args:
        input_shape: Shape of one input image, ``(height, width, channels)``.
            Spatial dimensions may be ``None`` (unchecked) or multiples of 32.
        output_channels: Number of channels in the generated image.

    Returns:
        A ``tf.keras.Model`` named ``'generator'`` whose output has shape
        ``(height, width, output_channels)`` with tanh-bounded values in [-1, 1].

    Raises:
        ValueError: If a known spatial dimension is not divisible by 32.
    """
    for dim in input_shape[:2]:
        if dim is not None and dim % 32 != 0:
            raise ValueError(
                f"build_generator needs height and width divisible by 32, "
                f"got input_shape={input_shape}"
            )

    inputs = Input(shape=input_shape, name='input_image')

    # Encoder: each stride-2 conv halves H and W; every stage is kept as a skip.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = Conv2D(filters, (4, 4), strides=(2, 2), padding='same', activation='relu')(x)
        skips.append(x)

    # Bottleneck at 1/32 of the input resolution.
    x = Conv2D(512, (4, 4), strides=(2, 2), padding='same', activation='relu')(x)

    # Decoder: each transposed conv doubles H and W, so after upsampling the
    # tensor has the same resolution as the skip taken in reverse order.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = Conv2DTranspose(filters, (4, 4), strides=(2, 2), padding='same', activation='relu')(x)
        x = Concatenate()([x, skip])

    # Final upsample back to the input resolution; tanh matches the [-1, 1]
    # scaling that load_image_pair applies to the target images.
    outputs = Conv2DTranspose(output_channels, (4, 4), strides=(2, 2), padding='same',
                              activation='tanh')(x)

    generator = Model(inputs, outputs, name='generator')
    return generator

# Function to build the discriminator model
def build_discriminator(input_shape=(256, 256, 3)):
    """Build a conditional discriminator that scores an (input, candidate) image pair.

    The two images are stacked along the channel axis and passed through a
    stack of 4x4 convolutions. The model returns a single-channel map of raw
    logits (no final activation), one value per output spatial location.

    Args:
        input_shape: Shape ``(height, width, channels)`` shared by both inputs.

    Returns:
        A ``tf.keras.Model`` named ``'discriminator'`` taking
        ``[input_image, target_image]``.
    """
    source = Input(shape=input_shape, name='input_image')
    candidate = Input(shape=input_shape, name='target_image')

    features = Concatenate()([source, candidate])

    # Three stride-2 downsampling stages, then one stride-1 stage; each conv is
    # followed by a LeakyReLU with slope 0.2.
    conv_specs = ((64, (2, 2)), (128, (2, 2)), (256, (2, 2)), (512, (1, 1)))
    for n_filters, step in conv_specs:
        features = Conv2D(n_filters, (4, 4), strides=step, padding='same')(features)
        features = LeakyReLU(alpha=0.2)(features)

    # Single-channel logit map; the losses apply the sigmoid via from_logits=True.
    patch_logits = Conv2D(1, (4, 4), padding='same')(features)

    return Model(inputs=[source, candidate], outputs=patch_logits, name='discriminator')

# Function to calculate generator loss
def generator_loss(discriminator_real_outputs, generated_outputs, target_images,
                   generated_images=None):
    """Return the generator loss: adversarial BCE plus ``LAMBDA`` times an L1 term.

    Args:
        discriminator_real_outputs: Discriminator logits on (input, target) pairs.
            Not needed by the computation; kept so existing positional calls work.
        generated_outputs: Discriminator logits on (input, generated) pairs. The
            adversarial term rewards the generator when these are scored as real (1).
        target_images: Ground-truth images.
        generated_images: Generator outputs compared with ``target_images`` in the
            L1 term. If omitted, ``generated_outputs`` is used instead (legacy
            behaviour, which only works when that tensor matches the target shape).

    Returns:
        Scalar loss tensor.
    """
    adversarial_loss = BinaryCrossentropy(from_logits=True)(
        tf.ones_like(generated_outputs), generated_outputs)
    images = generated_outputs if generated_images is None else generated_images
    l1_loss = tf.reduce_mean(tf.abs(target_images - images))
    total_gen_loss = adversarial_loss + (LAMBDA * l1_loss)
    return total_gen_loss


# Function to calculate discriminator loss
def discriminator_loss(discriminator_real_outputs, discriminator_generated_outputs):
    """Return BCE(real logits -> 1) + BCE(generated logits -> 0) as a scalar tensor."""
    bce = BinaryCrossentropy(from_logits=True)
    real_loss = bce(tf.ones_like(discriminator_real_outputs), discriminator_real_outputs)
    generated_loss = bce(tf.zeros_like(discriminator_generated_outputs),
                         discriminator_generated_outputs)
    total_disc_loss = real_loss + generated_loss
    return total_disc_loss


# Function for training step
def train_step(generator, discriminator, input_images, target_images, generator_optimizer,
               discriminator_optimizer, training=True):
    """Compute both losses for one batch and, if ``training``, update both models.

    Args:
        generator: Model mapping ``input_images`` to generated images.
        discriminator: Model taking ``[input, candidate]`` and returning logits.
        input_images: Batch of conditioning images.
        target_images: Batch of ground-truth images paired with ``input_images``.
        generator_optimizer: Optimizer applied to the generator's weights.
        discriminator_optimizer: Optimizer applied to the discriminator's weights.
        training: When True, apply one gradient step to each model. When False,
            only evaluate; no weights are changed (safe for validation).

    Returns:
        Tuple ``(gen_loss, disc_loss)`` of scalar tensors for this batch.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(input_images, training=training)

        discriminator_real_outputs = discriminator([input_images, target_images], training=training)
        discriminator_generated_outputs = discriminator([input_images, generated_images],
                                                        training=training)

        # The L1 term must compare generated images with targets, not discriminator logits.
        gen_loss = generator_loss(discriminator_real_outputs, discriminator_generated_outputs,
                                  target_images, generated_images=generated_images)
        disc_loss = discriminator_loss(discriminator_real_outputs, discriminator_generated_outputs)

    if not training:
        # Evaluation only: report losses without touching any weights.
        return gen_loss, disc_loss

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss

def _sorted_image_paths(directory):
    """Return full paths of the non-hidden entries of ``directory``, in sorted order."""
    # os.listdir order is arbitrary; sorting makes the pairing deterministic.
    # Hidden entries (e.g. '.DS_Store') are skipped because decode_jpeg would fail on them.
    return [os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if not name.startswith('.')]


def _paired_paths(split_dir):
    """Return aligned ``(satellite_paths, map_paths)`` lists for one dataset split."""
    satellite_paths = _sorted_image_paths(os.path.join(split_dir, 'satellite_images'))
    map_paths = _sorted_image_paths(os.path.join(split_dir, 'map_images'))
    # NOTE(review): pairing by sorted position assumes matching satellite/map files
    # sort identically (e.g. they share file names) -- confirm for this dataset.
    if len(satellite_paths) != len(map_paths):
        raise ValueError(
            f"{split_dir}: found {len(satellite_paths)} satellite images but "
            f"{len(map_paths)} map images; the folders must pair up one-to-one"
        )
    return satellite_paths, map_paths


# Function to load and prepare dataset
def load_data(train_dir, test_dir):
    """Build batched ``tf.data`` pipelines of (satellite, map) image pairs.

    Each directory must contain 'satellite_images/' and 'map_images/'
    subfolders. Files are paired by their sorted position in those folders.
    The training set is shuffled over its full length before decoding;
    the test set keeps its sorted order.

    Args:
        train_dir: Root directory of the training split.
        test_dir: Root directory of the test split.

    Returns:
        Tuple ``(train_dataset, test_dataset)`` yielding batches of
        ``(satellite_images, map_images)``.

    Raises:
        ValueError: If a split's two folders hold different numbers of files,
            or if the training split is empty.
    """
    train_satellite_paths, train_map_paths = _paired_paths(train_dir)
    test_satellite_paths, test_map_paths = _paired_paths(test_dir)
    if not train_satellite_paths:
        raise ValueError(f"No training images found under {train_dir}")

    train_dataset = (
        tf.data.Dataset.from_tensor_slices((train_satellite_paths, train_map_paths))
        .shuffle(len(train_satellite_paths))
        .map(load_image_pair, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)  # overlap input loading with training steps
    )

    test_dataset = (
        tf.data.Dataset.from_tensor_slices((test_satellite_paths, test_map_paths))
        .map(load_image_pair, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )

    return train_dataset, test_dataset

def _evaluate_batch(generator, discriminator, input_images, target_images):
    """Return ``(gen_loss, disc_loss)`` for one batch without updating any weights."""
    generated_images = generator(input_images, training=False)
    real_logits = discriminator([input_images, target_images], training=False)
    fake_logits = discriminator([input_images, generated_images], training=False)

    # Generator loss computed explicitly so its L1 term compares generated
    # images (not discriminator logits) against the targets.
    adversarial_loss = BinaryCrossentropy(from_logits=True)(tf.ones_like(fake_logits), fake_logits)
    l1_loss = tf.reduce_mean(tf.abs(target_images - generated_images))
    gen_loss = adversarial_loss + LAMBDA * l1_loss

    disc_loss = discriminator_loss(real_logits, fake_logits)
    return gen_loss, disc_loss


# Main training function
def train():
    """Train the generator/discriminator pair and report validation losses each epoch.

    Builds both models with Adam optimizers and loads the train/test datasets.
    For every epoch it:
      * runs one optimisation step per training batch, printing running mean
        losses every 100 batches;
      * saves a checkpoint of both models and optimizers under CHECKPOINT_DIR;
      * evaluates on the test set without updating any weights and prints the
        mean validation losses.
    """
    # Build generator and discriminator
    generator = build_generator()
    discriminator = build_discriminator()

    # Optimizers
    generator_optimizer = Adam(learning_rate=LEARNING_RATE, beta_1=0.5)
    discriminator_optimizer = Adam(learning_rate=LEARNING_RATE, beta_1=0.5)

    # Load data
    train_dataset, test_dataset = load_data(TRAIN_DIR, TEST_DIR)

    # Checkpoint (make sure the target directory exists before the first save)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    checkpoint_prefix = os.path.join(CHECKPOINT_DIR, "ckpt")
    checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                     discriminator_optimizer=discriminator_optimizer,
                                     generator=generator,
                                     discriminator=discriminator)

    # Metrics: training and validation are tracked separately.
    gen_loss_metric = Mean(name='gen_loss')
    disc_loss_metric = Mean(name='disc_loss')
    val_gen_loss_metric = Mean(name='val_gen_loss')
    val_disc_loss_metric = Mean(name='val_disc_loss')
    all_metrics = (gen_loss_metric, disc_loss_metric, val_gen_loss_metric, val_disc_loss_metric)

    # Training loop
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch+1}/{EPOCHS}")

        for metric in all_metrics:
            metric.reset_state()

        # Training
        for batch_num, (input_images, target_images) in enumerate(train_dataset, start=1):
            gen_loss, disc_loss = train_step(generator, discriminator, input_images, target_images,
                                             generator_optimizer, discriminator_optimizer)

            gen_loss_metric.update_state(gen_loss)
            disc_loss_metric.update_state(disc_loss)

            if batch_num % 100 == 0:
                print(f"Batch {batch_num}, "
                      f"Generator Loss: {float(gen_loss_metric.result()):.4f}, "
                      f"Discriminator Loss: {float(disc_loss_metric.result()):.4f}")

        # Save checkpoint (every epoch)
        checkpoint.save(file_prefix=checkpoint_prefix)

        # Validation: forward passes only, so the test set never updates the weights.
        for input_images, target_images in test_dataset:
            val_gen_loss, val_disc_loss = _evaluate_batch(generator, discriminator,
                                                          input_images, target_images)
            val_gen_loss_metric.update_state(val_gen_loss)
            val_disc_loss_metric.update_state(val_disc_loss)

        print(f"Validation - Generator Loss: {float(val_gen_loss_metric.result()):.4f}, "
              f"Discriminator Loss: {float(val_disc_loss_metric.result()):.4f}")

# Script entry point: run the full training loop only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    train()

