In [1]:
import tensorflow as tf
import keras as k
from keras import layers as l
import glob
import os
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Two independent Adam optimizers with identical hyper-parameters
# (learning_rate=2e-4, beta_1=0.5), one per network.
# NOTE(review): no other cell visible in this notebook references these two
# names; the models below build their own optimizers when they are compiled.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5) for _ in range(2)
)

In [3]:
def encoder(layer_in, n_filters, batchnorm=True):
    """Build one downsampling block on top of ``layer_in``.

    The block is a 4x4 convolution with stride 2 and 'same' padding,
    optionally followed by batch normalization, then LeakyReLU(0.2).

    Args:
        layer_in: Tensor the block is applied to.
        n_filters: Number of convolution filters.
        batchnorm: When truthy, insert a BatchNormalization layer after the
            convolution.

    Returns:
        The output tensor of the block.
    """
    # Kernel weights are drawn from a normal distribution (mean 0, stddev 0.02).
    kernel_init = k.initializers.RandomNormal(0., 0.02)
    downsample = l.Conv2D(
        n_filters,
        (4, 4),
        strides=(2, 2),
        padding='same',
        kernel_initializer=kernel_init,
    )
    out = downsample(layer_in)
    if batchnorm:
        # The layer is always called with training=True.
        out = l.BatchNormalization()(out, training=True)
    return l.LeakyReLU(alpha=0.2)(out)

def decoder(layer_in, skip_in, n_filters, dropout=True):
    """Build one upsampling block and merge it with a skip connection.

    The block is a 4x4 transposed convolution with stride 2 and 'same'
    padding, batch normalization, optional Dropout(0.5), concatenation with
    ``skip_in``, and finally a ReLU applied to the concatenated tensor.

    Args:
        layer_in: Tensor to upsample.
        skip_in: Tensor concatenated with the upsampled features.
        n_filters: Number of transposed-convolution filters.
        dropout: When truthy, add a Dropout(0.5) layer before the merge.

    Returns:
        The output tensor of the block.
    """
    # Kernel weights are drawn from a normal distribution (mean 0, stddev 0.02).
    kernel_init = k.initializers.RandomNormal(0., 0.02)
    upsample = l.Conv2DTranspose(
        n_filters,
        (4, 4),
        strides=(2, 2),
        padding='same',
        kernel_initializer=kernel_init,
    )
    out = upsample(layer_in)
    # Both normalization and dropout are called with training=True.
    out = l.BatchNormalization()(out, training=True)
    if dropout:
        out = l.Dropout(0.5)(out, training=True)
    # The activation comes after the merge, so the skip features pass
    # through the ReLU as well.
    merged = l.Concatenate()([out, skip_in])
    return l.Activation('relu')(merged)

def generator(image_shape=(256,256,3)):
    """Build the encoder/decoder generator with skip connections.

    Five ``encoder`` blocks feed a stride-2 convolutional bottleneck; five
    ``decoder`` blocks then upsample, each merged with the matching encoder
    output. A final stride-2 transposed convolution with 3 filters and a
    tanh activation produces the output image.

    Args:
        image_shape: Shape of the input image tensor (without batch axis).

    Returns:
        An uncompiled ``keras.Model`` mapping the input image to the output.
    """
    kernel_init = k.initializers.RandomNormal(0., 0.02)
    in_image = l.Input(shape=image_shape)

    # Encoder stack; the first block skips batch normalization.
    skips = []
    features = in_image
    for depth, n_filters in enumerate((64, 128, 256, 512, 512)):
        features = encoder(features, n_filters, depth != 0)
        skips.append(features)

    # Bottleneck: one more stride-2 convolution followed by ReLU.
    features = l.Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=kernel_init)(features)
    features = l.Activation('relu')(features)

    # Decoder stack, consuming the skips deepest-first; only the first two
    # blocks use dropout.
    decoder_plan = zip(
        reversed(skips),
        (512, 512, 256, 128, 64),
        (True, True, False, False, False),
    )
    for skip, n_filters, use_dropout in decoder_plan:
        features = decoder(features, skip, n_filters, dropout=use_dropout)

    # Output head: upsample once more to 3 channels and squash with tanh.
    features = l.Conv2DTranspose(3, (4,4), strides=(2,2), padding = 'same', kernel_initializer=kernel_init)(features)
    out_image = l.Activation('tanh')(features)

    return k.Model(in_image, out_image)

In [4]:
def discriminator(image_shape):
    """Build and compile the two-input convolutional discriminator.

    The source and target images are concatenated and passed through four
    spectrally-normalized 4x4 convolutions (64, 128, 256, 512 filters), each
    followed by LeakyReLU(0.2); all but the first are layer-normalized
    first. A final spectrally-normalized 1-filter convolution with a sigmoid
    yields a per-patch probability map.

    Args:
        image_shape: Shape shared by both input images (without batch axis).

    Returns:
        A ``keras.Model`` taking ``[input_image, target_image]``, compiled
        with binary cross-entropy, Adam(2e-4, beta_1=0.5) and an accuracy
        metric.
    """
    init = k.initializers.RandomNormal(0., 0.02)
    input_image = l.Input(shape=image_shape)
    target_image = l.Input(shape=image_shape)
    features = l.Concatenate()([input_image, target_image])

    # (filters, stride, apply layer normalization) for each conv stage.
    stage_specs = (
        (64, 2, False),
        (128, 2, True),
        (256, 2, True),
        (512, 1, True),
    )
    for n_filters, stride, normalize in stage_specs:
        conv = l.Conv2D(n_filters, (4, 4), strides=stride, padding='same', kernel_initializer=init)
        features = l.SpectralNormalization(conv)(features)
        if normalize:
            features = l.LayerNormalization()(features, training=True)
        features = l.LeakyReLU(alpha=0.2)(features)

    # Patch-level output head.
    head = l.Conv2D(1, (4, 4), strides=1, padding='same', kernel_initializer=init)
    features = l.SpectralNormalization(head)(features)
    patch_output = l.Activation('sigmoid')(features)

    model = k.Model([input_image, target_image], patch_output)
    model.compile(
        loss='binary_crossentropy',
        optimizer=k.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
        metrics=['accuracy'],
    )
    return model

# The discriminator defined in this cell ends with a sigmoid activation, so
# its outputs are probabilities, not logits. from_logits must therefore be
# False; with True the loss would apply a second sigmoid to the outputs.
disc_ce_loss = tf.keras.losses.BinaryCrossentropy(from_logits=False)

def discriminator_loss(real_output, fake_output):
    """Binary cross-entropy loss for the discriminator.

    Args:
        real_output: Discriminator output for real pairs.
        fake_output: Discriminator output for generated pairs.

    Returns:
        The sum of the loss on real pairs (targets smoothed to 0.9) and the
        loss on generated pairs (targets 0).
    """
    # One-sided label smoothing: real targets are 0.9 instead of 1.0.
    real_loss = disc_ce_loss(tf.ones_like(real_output) * 0.9, real_output)

    generated_loss = disc_ce_loss(tf.zeros_like(fake_output), fake_output)

    return real_loss + generated_loss

In [5]:
def gan(g_model, d_model, image_shape):
    """Build and compile the composite generator-plus-discriminator model.

    Every discriminator layer that is not a BatchNormalization layer is
    marked non-trainable, then the generator output is fed, together with
    the source image, into the discriminator.

    Args:
        g_model: Generator model mapping a source image to a generated image.
        d_model: Discriminator model taking ``[source, target]``.
        image_shape: Shape of the source image (without batch axis).

    Returns:
        A compiled ``keras.Model`` mapping the source image to
        ``[discriminator output, generated image]``, trained with binary
        cross-entropy and MAE weighted 1:100, using Adam(2e-4, beta_1=0.5).
    """
    # NOTE(review): this mutates the layers of the shared d_model object.
    # Whether d_model's own train_on_batch is affected depends on how the
    # installed Keras version treats `trainable` changes made after
    # compile() - verify that the discriminator still learns.
    for d_layer in d_model.layers:
        if isinstance(d_layer, l.BatchNormalization):
            continue
        d_layer.trainable = False

    source = l.Input(shape=image_shape)
    generated = g_model(source)
    verdict = d_model([source, generated])

    composite = k.Model(source, [verdict, generated])
    composite.compile(
        loss=['binary_crossentropy', 'mae'],
        optimizer=k.optimizers.Adam(learning_rate=0.0002, beta_1=0.5),
        loss_weights=[1,100],
    )
    return composite

In [6]:
def load_images(rgb_dir, ndvi_dir, image_shape=(256,256,3)):
    """Load paired RGB/NDVI JPEG images as arrays scaled to [-1, 1].

    Files are paired purely by their position in the sorted ``*.jpg``
    listing of each directory.

    Args:
        rgb_dir: Directory containing the RGB ``*.jpg`` files.
        ndvi_dir: Directory containing the NDVI ``*.jpg`` files.
        image_shape: Target shape; only the first two entries (height,
            width) are used for resizing. Images are loaded with
            ``load_img``'s default color mode.

    Returns:
        A tuple ``(rgb_images, ndvi_images)`` of numpy arrays, one entry per
        file, with pixel values mapped from [0, 255] to [-1, 1].

    Raises:
        ValueError: If the two directories contain a different number of
            ``*.jpg`` files, or if no ``*.jpg`` files are found at all.
    """
    rgb_paths = sorted(glob.glob(os.path.join(rgb_dir, '*.jpg')))
    ndvi_paths = sorted(glob.glob(os.path.join(ndvi_dir, '*.jpg')))

    if len(rgb_paths) != len(ndvi_paths):
        raise ValueError("Number of RGB and NDVI images do not match.")
    # Fail fast: an empty dataset would otherwise only surface later as an
    # obscure error when sampling batches.
    if not rgb_paths:
        raise ValueError(
            "No '*.jpg' images found in %r and %r." % (rgb_dir, ndvi_dir)
        )

    # load_img expects (height, width); drop the channel entry.
    target_size = tuple(image_shape[:2])

    def _load_scaled(path):
        # Load, resize and rescale a single image, so that PIL images are
        # not all kept in memory before conversion.
        img = k.preprocessing.image.load_img(path, target_size=target_size)
        return k.preprocessing.image.img_to_array(img) / 127.5 - 1

    rgb_images = np.array([_load_scaled(path) for path in rgb_paths])
    ndvi_images = np.array([_load_scaled(path) for path in ndvi_paths])

    return rgb_images, ndvi_images

def real_pairs(dataset, n_samples, patch_shape):
    """Draw a random batch of real (source, target) pairs with "real" labels.

    Args:
        dataset: Two-element sequence ``(trainA, trainB)`` of arrays indexed
            along their first axis.
        n_samples: Number of pairs to draw (indices are sampled with
            replacement).
        patch_shape: Size used for both spatial dimensions of the labels.

    Returns:
        ``([X1, X2], y)`` where ``X1``/``X2`` are the selected rows of
        ``trainA``/``trainB`` and ``y`` is a tensor of ones with shape
        ``(n_samples, patch_shape, patch_shape, 1)``.
    """
    sources, targets = dataset

    # The same random indices are used for both arrays to keep pairs aligned.
    picks = np.random.randint(0, sources.shape[0], n_samples)
    batch = [sources[picks], targets[picks]]

    # Label 1 marks the pairs as real.
    labels = tf.ones((n_samples, patch_shape, patch_shape, 1))
    return batch, labels

def fake_pairs(g_model, samples, patch_shape):
    """Generate fake target images for ``samples`` with "fake" labels.

    Args:
        g_model: Generator model exposing ``predict``.
        samples: Batch of source images fed to the generator.
        patch_shape: Size used for both spatial dimensions of the labels.

    Returns:
        ``(X, y)`` where ``X`` is the generator prediction and ``y`` is a
        tensor of zeros with shape ``(len(X), patch_shape, patch_shape, 1)``.
    """
    # verbose=0: this runs once per training step, so the default progress
    # bar would flood the notebook output.
    X = g_model.predict(samples, batch_size=32, verbose=0)
    # Label 0 marks the pairs as fake.
    y = tf.zeros((len(X), patch_shape, patch_shape, 1))
    return X, y


In [7]:
def performance(step, g_model, dataset, n_samples = 3):
    """Save a comparison plot and a generator checkpoint for ``step``.

    Draws ``n_samples`` random pairs, runs the generator on the sources and
    writes a 3-row figure (sources, generated targets, real targets) to
    ``plot_<step+1>.png``. The generator is saved to ``model_<step+1>.h5``.

    Args:
        step: Zero-based training step; file names use ``step + 1``.
        g_model: Generator model to evaluate and save.
        dataset: Two-element sequence ``(trainA, trainB)``.
        n_samples: Number of image columns in the plot.
    """
    # The patch size is irrelevant here because the labels are discarded.
    [X_realA, X_realB], _ = real_pairs(dataset, n_samples, 1)
    X_fakeB, _ = fake_pairs(g_model, X_realA, 1)

    # Map pixel values from [-1, 1] back to [0, 1] for display.
    # Row order: source images, generated images, real target images.
    rows = [(images + 1) / 2.0 for images in (X_realA, X_fakeB, X_realB)]

    for row_idx, row_images in enumerate(rows):
        for col_idx in range(n_samples):
            plt.subplot(3, n_samples, 1 + n_samples * row_idx + col_idx)
            plt.axis('off')
            plt.imshow(row_images[col_idx])

    tag = step + 1
    plot_file = 'plot_%06d.png' % tag
    plt.savefig(plot_file)
    plt.close()

    model_file = 'model_%06d.h5' % tag
    g_model.save(model_file)
    print('>Saved: %s and %s' % (plot_file, model_file))
    

In [8]:
def add_noise(images, noise_factor=0.05):
    """Return ``images`` plus scaled standard-normal noise.

    Args:
        images: Array or tensor exposing a ``shape`` attribute.
        noise_factor: Multiplier applied to the sampled noise.

    Returns:
        ``images`` with ``noise_factor * N(0, 1)`` noise of the same shape
        added element-wise. The result is not clipped.
    """
    scaled_noise = noise_factor * tf.random.normal(shape=images.shape)
    noisy = images + scaled_noise
    return noisy

def train(g_model, d_model, gan_model, dataset, epochs=100, batch = 1, d_updates=2):
    """Alternate discriminator and generator updates over random batches.

    Args:
        g_model: Generator model, used to produce fake targets.
        d_model: Compiled discriminator; its ``output_shape[1]`` sets the
            spatial size of the label tensors.
        gan_model: Compiled composite model that updates the generator.
        dataset: Two-element sequence ``(trainA, trainB)`` of paired arrays.
        epochs: Number of passes, each made of ``len(trainA) // batch`` steps.
        batch: Number of pairs sampled per step.
        d_updates: Number of discriminator update rounds (real + fake) per
            generator update. The default of 2 matches the original loop,
            which performed the real/fake update twice on the same batch.

    Raises:
        ValueError: If ``d_updates`` is smaller than 1.
    """
    if d_updates < 1:
        raise ValueError("d_updates must be at least 1.")

    # Spatial size of the discriminator output; labels use this size for
    # both spatial dimensions.
    n_patch = d_model.output_shape[1]

    trainA, trainB = dataset

    # Steps per epoch and total number of steps.
    bat_per_epoch = len(trainA) // batch
    n_steps = bat_per_epoch * epochs

    # Enumerate training steps (not epochs).
    for i in range(n_steps):
        # Select real samples
        [X_realA, X_realB], y_real = real_pairs([trainA, trainB], batch, n_patch)

        # Generate fake targets from the clean source images
        X_fakeB, y_fake = fake_pairs(g_model, X_realA, n_patch)

        # Add noise to the real and fake images.
        # NOTE(review): the noisy X_realA / X_realB are also what the
        # generator update below receives as input and regression target -
        # confirm that this is intended.
        X_realA = add_noise(X_realA, noise_factor=0.05)
        X_realB = add_noise(X_realB, noise_factor=0.05)
        X_fakeB = add_noise(X_fakeB, noise_factor=0.05)

        # Update discriminator; the metrics of the last round are reported.
        for _ in range(d_updates):
            d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real, return_dict=True)
            d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake, return_dict=True)

        # Update generator
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])

        # Summarize performance
        print(f">{i+1}, d1[ a:{d_loss1['accuracy']:.2f}, l:{d_loss1['loss']:.2f}] \td2[ a:{d_loss2['accuracy']:.2f}, l:{d_loss2['loss']:.2f}] \tg[{g_loss:.2f}]")

        # Checkpoint at the middle and at the end of every epoch. Integer
        # arithmetic keeps this at twice per epoch when bat_per_epoch is odd
        # (a float modulo by bat_per_epoch/2 then only fires at epoch ends).
        step_in_epoch = i % bat_per_epoch + 1
        if step_in_epoch in (bat_per_epoch // 2, bat_per_epoch):
            performance(i, g_model, [trainA, trainB])

In [None]:
# Load the paired training images, then build the three models on one shared
# input shape.
image_shape = (256, 256, 3)
dataset = load_images(
    'dataset/train/inputs/RGB',
    'dataset/train/inputs/NDVI',
    image_shape=image_shape,
)
# NOTE: the next two assignments rebind `generator` and `discriminator` from
# the factory functions to the model instances. Re-running this cell
# therefore requires re-running the cells that define those functions first.
generator = generator(image_shape)
discriminator = discriminator(image_shape)
gan_model = gan(generator, discriminator, image_shape)


In [None]:
# Run the full training schedule: 100 epochs with batches of 10 pairs.
train(
    g_model=generator,
    d_model=discriminator,
    gan_model=gan_model,
    dataset=dataset,
    epochs=100,
    batch=10,
)
