# In [None]:
# Step 1: Setting Up the Notebook
import tensorflow as tf
import matplotlib.pyplot as plt
from pix2pix_generator import build_generator
from pix2pix_discriminator import build_discriminator
from train_pix2pix import load_data, train_step  # Assuming `train_step` is implemented in train_pix2pix.py

# Constants
# NOTE(review): BATCH_SIZE is not passed to load_data below — confirm the
# loader batches with this same value.
BATCH_SIZE = 16
EPOCHS = 50  # number of full passes over train_dataset in the training loop
LEARNING_RATE = 2e-4  # Adam step size, shared by the generator and discriminator optimizers
# In pix2pix, lambda conventionally weights the L1 (generated vs. target) term
# of the generator loss; pix2pix has no cycle-consistency loss (that is CycleGAN).
# NOTE(review): LAMBDA is not referenced in this notebook — confirm that
# train_step in train_pix2pix.py uses the same weight.
LAMBDA = 100
TRAIN_DIR = '../data/processed/satellite_to_map/train/'
TEST_DIR = '../data/processed/satellite_to_map/test/'

# Step 2: Data Loading and Preprocessing
# load_data is project code; presumably it yields (input, target) image
# batches for each split — confirm in train_pix2pix.py.
train_dataset, test_dataset = load_data(TRAIN_DIR, TEST_DIR)

# Step 3: Build the Pix2Pix networks and their optimizers
# (tuple elements are evaluated left to right: generator first, then discriminator)
generator, discriminator = build_generator(), build_discriminator()

# Two separate Adam instances, one per network, with identical settings.
_adam_kwargs = {"learning_rate": LEARNING_RATE, "beta_1": 0.5}
generator_optimizer = tf.keras.optimizers.Adam(**_adam_kwargs)
discriminator_optimizer = tf.keras.optimizers.Adam(**_adam_kwargs)

# Step 4: Training the Pix2Pix Model
# Training loop: one train_step call per batch, EPOCHS passes over train_dataset.
# NOTE(review): train_step's two return values are assumed to be scalar loss
# tensors — confirm in train_pix2pix.py.

# Running means of the per-batch losses, reset at the start of every epoch.
gen_loss_mean = tf.keras.metrics.Mean(name="gen_loss")
disc_loss_mean = tf.keras.metrics.Mean(name="disc_loss")

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    gen_loss_mean.reset_state()
    disc_loss_mean.reset_state()

    for batch_num, (input_images, target_images) in enumerate(train_dataset):
        gen_loss, disc_loss = train_step(generator, discriminator, input_images, target_images,
                                         generator_optimizer, discriminator_optimizer)
        gen_loss_mean.update_state(gen_loss)
        disc_loss_mean.update_state(disc_loss)

        if (batch_num + 1) % 100 == 0:
            # float() turns a scalar tensor into a plain number, so the log reads
            # "0.6931" rather than "tf.Tensor(0.6931, shape=(), dtype=float32)".
            print(f"Batch {batch_num + 1}, Generator Loss: {float(gen_loss):.4f}, "
                  f"Discriminator Loss: {float(disc_loss):.4f}")

    # Per-epoch summary, so epochs with fewer than 100 batches still report losses.
    print(f"Epoch {epoch+1} mean - Generator Loss: {float(gen_loss_mean.result()):.4f}, "
          f"Discriminator Loss: {float(disc_loss_mean.result()):.4f}")

# Step 5: Evaluation and Visualization
# Generate sample images from the first test batch only (for demonstration).
try:
    input_images, target_images = next(iter(test_dataset))
except StopIteration:
    # Fail here with a clear message instead of a NameError in the plotting cell.
    raise RuntimeError(
        f"Test dataset is empty; nothing to visualize (TEST_DIR={TEST_DIR!r})."
    ) from None

# training=False calls the Keras model in inference mode (affects layers such as
# BatchNormalization/Dropout, if the generator contains them).
generated_images = generator(input_images, training=False)
    
# Visualize results: one column per sample; rows are input / generated / ground truth.
# `* 0.5 + 0.5` maps pixel values from [-1, 1] to [0, 1] for imshow.
# NOTE(review): assumes load_data normalizes images to [-1, 1] — confirm.
MAX_SAMPLES = 4
# Never index past the batch: a small (e.g. final) batch may hold fewer than MAX_SAMPLES images.
num_samples = min(MAX_SAMPLES, int(input_images.shape[0]))
rows = (
    (input_images, 'Input Satellite'),
    (generated_images, 'Generated Map'),
    (target_images, 'Target Map'),
)

# 2.5 inches per cell, matching the original 10x5 figure for a 4x2 grid.
plt.figure(figsize=(2.5 * num_samples, 2.5 * len(rows)))

for row, (images, title) in enumerate(rows):
    for i in range(num_samples):
        plt.subplot(len(rows), num_samples, row * num_samples + i + 1)
        plt.imshow(images[i] * 0.5 + 0.5)
        plt.title(title)
        plt.axis('off')

plt.tight_layout()
plt.show()

# Step 6: Conclusion and Future Work
# Final status line signalling that the script ran to completion.
completion_message = "Training completed."
print(completion_message)
