In [1]:
import tensorflow as tf
import keras as k
from keras import layers as l
import glob
import os
import numpy as np
import matplotlib.pyplot as plt

generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

In [2]:
def encoder(layer_in, n_filters, batchnorm=True):
    # Initialize weights
    init = k.initializers.RandomNormal(0., 0.02)    # Add the downsampling
    d = l.Conv2D(n_filters, (4,4,), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
    if(batchnorm):
        d = l.BatchNormalization()(d, training=True)
    d = l.LeakyReLU(alpha=0.2)(d)
    return d

def decoder(layer_in, skip_in, n_filters, dropout=True):
    # Initialize weights
    init = k.initializers.RandomNormal(0., 0.02)    # Add the upsampling
    u = l.Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
    u = l.BatchNormalization()(u, training=True)
    if(dropout):
        u = l.Dropout(0.5)(u, training=True)
    # Merge
    u = l.Concatenate()([u, skip_in])
    u = l.Activation('relu')(u)
    return u

def Generator(image_shape=(256,256,3)):
    # Initialize weights
    init = k.initializers.RandomNormal(0., 0.02)
    # Input
    in_image = l.Input(shape=image_shape)

    # Encode
    e1 = encoder(in_image, 64, False)
    e2 = encoder(e1, 128)
    e3 = encoder(e2, 256)
    e4 = encoder(e3, 512)
    e5 = encoder(e4, 512)

    # Bottleneck
    b = l.Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e5)
    b = l.Activation('relu')(b)

    # Decode
    d1 = decoder(b, e5, 512)
    d2 = decoder(d1, e4, 512)
    d3 = decoder(d2, e3, 256, dropout=False)
    d4 = decoder(d3, e2, 128, dropout=False)
    d5 = decoder(d4, e1, 64, dropout=False)

    # Output
    g = l.Conv2DTranspose(3, (4,4), strides=(2,2), padding = 'same', kernel_initializer=init)(d5)
    out_image = l.Activation('tanh')(g)

    # Define model
    model = k.Model(inputs=in_image, outputs = out_image)
    return model

loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

LAMBDA = 100

def generator_loss(disc_fake_output, generator_output, target_output):
    gan_loss = loss_object(tf.ones_like(disc_fake_output), disc_fake_output)

    l1_loss = tf.reduce_mean(tf.abs(target_output - generator_output))

    total_gen_loss = gan_loss + (LAMBDA * l1_loss)
    return total_gen_loss, gan_loss, l1_loss

In [3]:
def Discriminator(image_shape):
    init = k.initializers.RandomNormal(0., 0.02)
    input_image = l.Input(shape=image_shape)
    target_image = l.Input(shape=image_shape)
    d = l.Concatenate()([input_image, target_image])
    
    d = l.SpectralNormalization(l.Conv2D(64, (4, 4), strides=2, padding='same', kernel_initializer=init))(d)
    d = l.LeakyReLU(alpha=0.2)(d)
    
    d = l.SpectralNormalization(l.Conv2D(128, (4, 4), strides=2, padding='same', kernel_initializer=init))(d)
    d = l.LayerNormalization()(d, training=True)
    d = l.LeakyReLU(alpha=0.2)(d)
    
    d = l.SpectralNormalization(l.Conv2D(256, (4, 4), strides=2, padding='same', kernel_initializer=init))(d)
    d = l.LayerNormalization()(d, training=True)
    d = l.LeakyReLU(alpha=0.2)(d)
    
    zero_pad1 = l.ZeroPadding2D()(d)
    d = l.Conv2D(512, 4, strides=1, kernel_initializer = init,
                 use_bias= False)(zero_pad1)
    
    batchnorm1 = l.BatchNormalization()(d)
    d = l.LeakyReLU(alpha=0.2)(batchnorm1)

    zero_pad2 = l.ZeroPadding2D()(d)
    
    d = l.SpectralNormalization(l.Conv2D(1, 4, strides=1, kernel_initializer=init))(zero_pad2)
    
    model = k.Model([input_image, target_image], d)
    return model

disc_ce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = disc_ce_loss(tf.ones_like(real_output), real_output)

    generated_loss = disc_ce_loss(tf.zeros_like(fake_output), fake_output)

    total_loss = real_loss + generated_loss
    return total_loss

In [4]:
# def gan(g_model, d_model, image_shape):
#     for layer in d_model.layers:
#         if not isinstance(layer, l.BatchNormalization):
#             layer.trainable = False
#     in_src = l.Input(shape=image_shape)
#     gen_out = g_model(in_src)
#     dis_out = d_model([in_src, gen_out])

#     model = k.Model(in_src, [dis_out, gen_out])
    
#     opt = k.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
#     model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
#     return model

In [5]:
def load_images(rgb_paths, ndvi_paths, image_shape=(256,256,3)):
    # Load and resize images
    rgb_images = [k.preprocessing.image.load_img(img, target_size=image_shape) for img in rgb_paths]
    ndvi_images = [k.preprocessing.image.load_img(img, target_size=image_shape) for img in ndvi_paths]

    # Convert to numpy arrays
    rgb_images = np.array([k.preprocessing.image.img_to_array(img) / 127.5 - 1 for img in rgb_images])
    ndvi_images = np.array([k.preprocessing.image.img_to_array(img) / 127.5 - 1 for img in ndvi_images])
    
    return rgb_images, ndvi_images

In [6]:
def load_images(rgb_paths, ndvi_paths, image_shape=(256,256,3)):
    # Load and resize images
    rgb_images = [k.preprocessing.image.load_img(img, target_size=image_shape) for img in rgb_paths]
    ndvi_images = [k.preprocessing.image.load_img(img, target_size=image_shape) for img in ndvi_paths]

    # Convert to numpy arrays
    rgb_images = np.array([k.preprocessing.image.img_to_array(img) / 127.5 - 1 for img in rgb_images])
    ndvi_images = np.array([k.preprocessing.image.img_to_array(img) / 127.5 - 1 for img in ndvi_images])
    
    return rgb_images, ndvi_images

def get_dataset(rgb_dir, ndvi_dir, image_shape=(256,256,3)):
    # Load images
    rgb_images = sorted(glob.glob(os.path.join(rgb_dir, '*.jpg')))
    ndvi_images = sorted(glob.glob(os.path.join(ndvi_dir, '*.jpg')))
    # Check if the number of images is the same
    if len(rgb_images) != len(ndvi_images):
        raise ValueError("Number of RGB and NDVI images do not match.")
    
    rgb_images = np.array(rgb_images)
    ndvi_images = np.array(ndvi_images)

    return rgb_images, ndvi_images
    

def real_pairs(dataset, n_samples, patch_shape):
    # Unpack
    trainA, trainB = dataset

    # Pick random
    ix = np.random.randint(0,trainA.shape[0], n_samples)
    X1, X2 = trainA[ix], trainB[ix]

    X1, X2 = load_images(X1, X2)
    # Label 1 (Real)
    y = tf.ones((n_samples, patch_shape, patch_shape, 1))
    return [X1, X2], y

def fake_pairs(g_model, samples, patch_shape):
    # Generate fake
    X = g_model.predict(samples, batch_size=32)
    # Label 0 (Fake)
    y = tf.zeros((len(X), patch_shape, patch_shape, 1))
    return X, y


In [29]:
def performance(step, g_model, d_model, dataset, n_samples = 3):
    # Select an input sample
    [X_realA, X_realB], _ = real_pairs(dataset, n_samples, 1)

    # Generate a fake input sample
    X_fakeB, _ = fake_pairs(g_model, X_realA, 1)

    # Scale pixel values
    X_realA = (X_realA + 1) / 2.0
    X_fakeB = (X_fakeB + 1) / 2.0
    X_realB = (X_realB + 1) / 2.0

    # Plot real images
    for i in range(n_samples):
        plt.subplot(3, n_samples, 1 + i)
        plt.axis('off')
        plt.imshow(X_realA[i])
    
    # Plot fake image
    for i in range(n_samples):
        plt.subplot(3, n_samples, 1 + n_samples + i)
        plt.axis('off')
        plt.imshow(X_fakeB[i])

    # Plot real image
    for i in range(n_samples):
        plt.subplot(3, n_samples, 1 + n_samples * 2 + i)
        plt.axis('off')
        plt.imshow(X_realB[i])
    
    filename1 = 'plot_%06d.png' % (step + 1)
    plt.savefig(filename1)
    plt.close()
    filename2 = 'generator_model_%06d.keras' % (step + 1)
    g_model.save(filename2)
    filename3 = 'discriminator_model_%06d.keras' % (step + 1)
    d_model.save(filename3)
    print('>Saved: %s, %s, %s' % (filename1, filename2, filename3))
    

In [33]:
def add_noise(images, noise_factor=0.05):
    noise = noise_factor * tf.random.normal(shape=images.shape)
    return images + noise


import time
from IPython import display
def train_step(input_image, target, gen_output, step, g_model, d_model):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = g_model(input_image, training = True)

        im = tf.convert_to_tensor(input_image)
        disc_real_output = d_model([input_image, target], training=True)
        disc_fake_output = d_model([im, gen_output], training=True)

        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_fake_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_fake_output)
        real_loss = disc_ce_loss(tf.ones_like(disc_real_output), disc_real_output)
        generated_loss = disc_ce_loss(tf.zeros_like(disc_fake_output), disc_fake_output)
    print(f">{step}, d[Total:{disc_loss:.2f} Real: {real_loss:.2f} Fake:{generated_loss:.2f} ] \tg[Total: {gen_total_loss:.2f} Gan: {gen_gan_loss:.2f} L1:{gen_l1_loss:.2f}]")
    generator_gradients = gen_tape.gradient(gen_total_loss,
                                          g_model.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                               d_model.trainable_variables)
    
    generator_optimizer.apply_gradients(zip(generator_gradients,
                                          g_model.trainable_variables))
    
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              d_model.trainable_variables))
    

def fit(gen, disc, dataset, steps, checkpoint):
    n_patch = disc.output_shape[1]
    start = time.time()

    # Unpack data
    trainA, trainB = dataset
    for step in range(steps):
        [X_realA, X_realB], y_real = real_pairs([trainA, trainB], 1, n_patch)
        X_fakeB, y_fake = fake_pairs(gen, X_realA, n_patch)

        X_realA = add_noise(X_realA, 0.02)
        X_realB = add_noise(X_realB, 0.02)
        X_fakeB = add_noise(X_fakeB, 0.02)
        
        if (step) % 1000 == 0:
            display.clear_output(wait=True)

            if step != 0:
                print(f'Time taken for 1000 steps: {time.time()-start:.2f} sec\n')
                print(f"Step: {step//1000}k")
                start = time.time()
                
        if (step) % 1000 == 0:
            performance(step, gen, disc, [trainA, trainB])
        
        if (step) % 10_000 == 0:
            gen.save("generator_model.keras")
            disc.save("discriminator_model.keras")
            
        train_step(X_realA, X_realB,X_fakeB, step, gen, disc)
    


    

In [None]:
# generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
# discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
dataset = get_dataset('dataset/train/inputs/RGB', 'dataset/train/inputs/NDVI', image_shape=(256,256,3))
generator = k.models.load_model('generator_model.keras')
discriminator = k.models.load_model('discriminator_model.keras')
# generator.compile(optimizer=generator_optimizer, loss=lambda y_true, y_pred: tf.reduce_mean(y_pred) )
# discriminator = Discriminator((256, 256, 3))
# discriminator.compile(optimizer=discriminator_optimizer, loss=lambda y_true, y_pred: tf.reduce_mean(y_pred))
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)


In [35]:
fit(generator, discriminator, dataset, 10000, checkpoint)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 192ms/step
>Saved: plot_000001.png, generator_model_000001.keras, discriminator_model_000001.keras
>0, d[Total:1.59 Real: 0.92 Fake:0.67 ] 	g[Total: 6.85 Gan: 0.74 L1:0.06]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 109ms/step
>1, d[Total:0.87 Real: 0.41 Fake:0.46 ] 	g[Total: 47.69 Gan: 1.03 L1:0.47]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 105ms/step
>2, d[Total:0.58 Real: 0.29 Fake:0.29 ] 	g[Total: 11.94 Gan: 1.40 L1:0.11]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 107ms/step
>3, d[Total:0.50 Real: 0.20 Fake:0.31 ] 	g[Total: 11.40 Gan: 1.40 L1:0.10]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 108ms/step
>4, d[Total:1.42 Real: 0.83 Fake:0.59 ] 	g[Total: 8.81 Gan: 0.84 L1:0.08]
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 104ms/step
>5, d[Total:0.96 Real: 0.74 Fake:0.22 ] 	g[Total: 16.93 Gan: 1.69 L1:0.15]
[1m1/1[0m [32m━━━━━━━━━

KeyboardInterrupt: 