In [29]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
import os
import pathlib
# Ask TensorFlow which physical GPU devices it can see. The bare `gpus`
# expression on the last line makes the notebook display the list as the
# cell's output.
gpus = tf.config.experimental.list_physical_devices('GPU')
gpus

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [30]:
# Size of the shuffle buffer used when building the training dataset.
BUFFER_SIZE = 512
# Number of images per training batch.
BATCH_SIZE = 8
# Length of the random noise vector fed to the generator.
noise_dim = 100
# Target size that training images are resized to; it is passed as
# `image_size` to load_and_preprocess_dataset and must agree with the
# 256x256 input shape hard-coded in the discriminator.
img_sz = (256,256)

In [31]:
# Load and preprocess the dataset
def load_and_preprocess_dataset(data_dir, image_size=(128, 128)):
    """Build a shuffled, batched ``tf.data`` pipeline from a folder of images.

    Every ``*.jpg`` and ``*.png`` file directly inside ``data_dir`` is read,
    decoded to 3 channels, resized to ``image_size`` with bicubic
    interpolation and scaled to the range [-1, 1]. Files that cannot be read
    or decoded are skipped instead of aborting the pipeline.

    Args:
        data_dir: Directory (str or path-like) containing the images.
        image_size: Target size passed to ``tf.image.resize``.

    Returns:
        A ``tf.data.Dataset`` yielding float32 batches of up to ``BATCH_SIZE``
        images, shuffled with a buffer of ``BUFFER_SIZE`` and prefetched.

    Raises:
        ValueError: If no ``.jpg`` or ``.png`` file is found in ``data_dir``.
    """
    data_dir = pathlib.Path(data_dir)

    # Collect the image paths as strings (JPEGs first, then PNGs).
    image_files = [
        str(path)
        for pattern in ('*.jpg', '*.png')
        for path in data_dir.glob(pattern)
    ]
    image_count = len(image_files)
    print(f"Found {image_count} images.")

    # An empty list would become a float32 tensor and fail later with a
    # confusing dtype error inside tf.io.read_file; fail early and clearly.
    if image_count == 0:
        raise ValueError(f"No .jpg or .png images found in {data_dir}")

    list_ds = tf.data.Dataset.from_tensor_slices(image_files)

    def preprocess_image(file_path):
        # NOTE: this function is traced into a graph by Dataset.map, so a
        # Python try/except here cannot catch per-file decode failures and
        # Python print() would only run once, at trace time. Bad files are
        # dropped below with ignore_errors instead.
        img = tf.io.read_file(file_path)
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.resize(img, image_size, method=tf.image.ResizeMethod.BICUBIC)
        img = tf.cast(img, tf.float32)
        # Bicubic interpolation can overshoot the source range; clamp so the
        # normalized values really stay inside [-1, 1] (the generator's tanh range).
        img = tf.clip_by_value(img, 0.0, 255.0)
        return (img - 127.5) / 127.5

    dataset = list_ds.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)

    # Skip elements whose preprocessing raised (e.g. corrupt or unsupported
    # files). Newer TensorFlow exposes this as a Dataset method; older
    # releases only have the experimental transformation.
    if hasattr(dataset, 'ignore_errors'):
        dataset = dataset.ignore_errors(log_warning=True)
    else:
        dataset = dataset.apply(tf.data.experimental.ignore_errors(log_warning=True))

    # Prepare dataset for training
    dataset = dataset.shuffle(buffer_size=BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    return dataset

def make_generator_model():
    """Create the generator: noise vector of length ``noise_dim`` -> 256x256x3 image.

    A dense projection is reshaped to a 16x16x256 feature map and then
    upsampled by transposed convolutions (16 -> 32 -> 64 -> 256). The last
    layer maps to 3 channels with a tanh activation, so outputs lie in [-1, 1].

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    keras_layers = tf.keras.layers

    def upsample_block(filters, stride):
        # Transposed conv + batch norm + LeakyReLU; 'same' padding means the
        # spatial size is multiplied by exactly `stride`.
        return [
            keras_layers.Conv2DTranspose(
                filters, (5, 5), strides=(stride, stride), padding='same', use_bias=False
            ),
            keras_layers.BatchNormalization(),
            keras_layers.LeakyReLU(),
        ]

    # Project the noise vector and fold it into a 16x16 map with 256 channels.
    stack = [
        keras_layers.Dense(16*16*256, use_bias=False, input_shape=(noise_dim,)),
        keras_layers.BatchNormalization(),
        keras_layers.LeakyReLU(),
        keras_layers.Reshape((16, 16, 256)),
    ]

    # (filters, stride) for each upsampling stage: 16 -> 32 -> 64 -> 256.
    for filters, stride in ((128, 2), (64, 2), (32, 4)):
        stack.extend(upsample_block(filters, stride))

    # Final stride-1 layer only changes the channel count to RGB.
    stack.append(
        keras_layers.Conv2DTranspose(
            3, (5, 5), strides=(1, 1), padding='same', use_bias=False, activation='tanh'
        )
    )
    return tf.keras.Sequential(stack)

def make_discriminator_model():
    """Create the discriminator: 256x256x3 image -> one unnormalized score.

    Four stride-2 convolution blocks halve the spatial size each time
    (256 -> 128 -> 64 -> 32 -> 16); the result is flattened and mapped by a
    dense layer to a single value. No activation is applied to that value,
    so it is a logit.

    Returns:
        An uncompiled Keras ``Sequential`` model.
    """
    def conv_block(filters, **conv_kwargs):
        # 5x5 stride-2 conv + LeakyReLU + 20% dropout.
        return [
            layers.Conv2D(filters, (5, 5), strides=(2, 2), padding='same', **conv_kwargs),
            layers.LeakyReLU(),
            layers.Dropout(0.2),
        ]

    # Only the first block declares the input shape.
    stack = conv_block(64, input_shape=[256,256,3])
    for filters in (256, 128, 16):
        stack.extend(conv_block(filters))

    stack.append(layers.Flatten())
    stack.append(layers.Dense(1))
    return models.Sequential(stack)

# Shared binary cross-entropy loss used by both the generator and the
# discriminator losses. `from_logits=True` because the discriminator's final
# Dense(1) layer has no activation, i.e. it outputs raw logits.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Args:
        real_output: Discriminator logits for real images (target label 1).
        fake_output: Discriminator logits for generated images (target label 0).

    Returns:
        Sum of the cross-entropy on the real batch and on the fake batch.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

def generator_loss(fake_output):
    """Return the generator loss for one batch.

    The generator is rewarded when the discriminator scores its images as
    real, so the logits for generated images are compared against all-ones
    targets.

    Args:
        fake_output: Discriminator logits for generated images.
    """
    wanted_labels = tf.ones_like(fake_output)
    return cross_entropy(wanted_labels, fake_output)

# One Adam optimizer per network (learning rate 1e-4 each). They are kept
# separate so each network has its own optimizer state; train_step applies
# them to the generator and discriminator variables respectively.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# Define the training step
@tf.function
def train_step(images):
    """Run one simultaneous update of the generator and the discriminator.

    Uses the module-level ``generator``, ``discriminator`` and their two
    optimizers. One noise vector is drawn per real image in the batch.

    Args:
        images: Batch of real training images.

    Returns:
        Tuple ``(gen_loss, disc_loss)`` for this batch.
    """
    # Match the number of generated samples to the (possibly smaller, final) batch.
    latent = tf.random.normal([tf.shape(images)[0], noise_dim])

    # Two tapes so each network's gradients come from its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        fakes = generator(latent, training=True)

        logits_real = discriminator(images, training=True)
        logits_fake = discriminator(fakes, training=True)

        gen_loss = generator_loss(logits_fake)
        disc_loss = discriminator_loss(logits_real, logits_fake)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    # Compute both gradient sets before applying either update.
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)

    generator_optimizer.apply_gradients(zip(gen_grads, gen_vars))
    discriminator_optimizer.apply_gradients(zip(disc_grads, disc_vars))

    return gen_loss, disc_loss

# Function to generate and save images
def generate_and_save_images(model, epoch, test_input):
    """Render a grid of generated images and save it as a PNG.

    The figure is written to ``generated/image_at_epoch_<epoch>.png`` (epoch
    zero-padded to 4 digits). The ``generated`` directory is created if it
    does not exist yet.

    Args:
        model: Generator model; called as ``model(test_input, training=False)``.
        epoch: Epoch number used in the output file name.
        test_input: Batch of noise vectors to feed to the model.
    """
    predictions = model(test_input, training=False)
    num_images = predictions.shape[0]

    # Keep the 4x4 layout for up to 16 images; grow the square grid for
    # larger batches so plt.subplot does not raise on index 17+.
    grid = max(4, int(np.ceil(np.sqrt(num_images))))

    fig = plt.figure(figsize=(32, 32))

    for i in range(num_images):
        plt.subplot(grid, grid, i + 1)
        plt.imshow(predictions[i] * 0.5 + 0.5)  # Denormalize [-1, 1] -> [0, 1]
        plt.axis('off')

    # savefig does not create missing directories.
    os.makedirs('generated', exist_ok=True)
    plt.savefig(f'generated/image_at_epoch_{epoch:04d}.png')
    # Close this figure explicitly so repeated calls do not accumulate figures.
    plt.close(fig)

# Set up the training loop
def train(dataset, epochs, sample_every=10, save_every=1000):
    """Train the GAN for ``epochs`` passes over ``dataset``.

    Uses the module-level ``generator``, ``discriminator`` and ``seed``.
    After every epoch the mean generator and discriminator losses are
    printed.

    Args:
        dataset: Iterable of image batches; each batch is passed to ``train_step``.
        epochs: Number of epochs to run.
        sample_every: Save a grid of generated images every this many epochs.
            ``0`` or ``None`` disables it. Defaults to 10.
        save_every: Save both models as ``.h5`` files every this many epochs.
            ``0`` or ``None`` disables it. Defaults to 1000.
    """
    for epoch_num in range(1, epochs + 1):
        gen_losses = []
        disc_losses = []
        for image_batch in dataset:
            gen_loss, disc_loss = train_step(image_batch)
            gen_losses.append(gen_loss)
            disc_losses.append(disc_loss)

        # np.mean([]) warns and yields nan; report nan directly for an empty epoch.
        mean_gen = np.mean(gen_losses) if gen_losses else float('nan')
        mean_disc = np.mean(disc_losses) if disc_losses else float('nan')
        print(f"Epoch {epoch_num}, Gen Loss: {mean_gen:.4f}, Disc Loss: {mean_disc:.4f}")

        # Periodically render samples from the fixed seed so progress is comparable.
        if sample_every and epoch_num % sample_every == 0:
            generate_and_save_images(generator, epoch_num, seed)

        # Periodically checkpoint both networks.
        if save_every and epoch_num % save_every == 0:
            generator.save(f'generator_model_epoch_{epoch_num}.h5')
            discriminator.save(f'discriminator_model_epoch_{epoch_num}.h5')


In [32]:
# Build fresh generator and discriminator networks. These module-level names
# are read directly by train_step and train.
generator = make_generator_model()
discriminator = make_discriminator_model()
# To resume from saved models instead, load them here:
# generator = tf.keras.models.load_model('generator_model_epoch_1000.h5')
# discriminator = tf.keras.models.load_model('discriminator_model_epoch_1000.h5')

# Print the layer-by-layer architecture of both networks.
generator.summary()
discriminator.summary()

# Load the training images from `data_dir`, resized to `img_sz` so they match
# the discriminator's input shape.
data_dir = "dataset_cbp"
train_dataset = load_and_preprocess_dataset(data_dir, image_size=img_sz)

# Fixed batch of 16 noise vectors, reused for every sample grid so the saved
# images are comparable across epochs. No random seed is set, so this batch
# differs between kernel runs.
seed = tf.random.normal([16, noise_dim])

# Train the model (long-running: 10000 epochs).
train(train_dataset, epochs=10000)

Model: "sequential_14"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_14 (Dense)            (None, 65536)             6553600   
                                                                 
 batch_normalization_37 (Bat  (None, 65536)            262144    
 chNormalization)                                                
                                                                 
 leaky_re_lu_75 (LeakyReLU)  (None, 65536)             0         
                                                                 
 reshape_7 (Reshape)         (None, 16, 16, 256)       0         
                                                                 
 conv2d_transpose_37 (Conv2D  (None, 32, 32, 128)      819200    
 Transpose)                                                      
                                                                 
 batch_normalization_38 (Bat  (None, 32, 32, 128)    