In [66]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Assuming you have already defined the label_dict and other parameters for the dataset
DATA_PATH = "/data/home/vvaibhav/AI/VAE/afhq"
width = 128
batch_size = 1024
latent_dim = 32




In [67]:
# Load and preprocess the dataset using TensorFlow functions
print("Train Preprocessing")
train_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    DATA_PATH + '/train',
    image_size=(width, width),
    batch_size=batch_size,
    label_mode='int'
)
print(type(train_dataset))
print(train_dataset)

Train Preprocessing
Found 14630 files belonging to 3 classes.
<class 'tensorflow.python.data.ops.dataset_ops.BatchDataset'>
<BatchDataset element_spec=(TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None,), dtype=tf.int32, name=None))>


In [68]:
print("validation preprocessing")
val_dataset = tf.keras.preprocessing.image_dataset_from_directory(
    DATA_PATH + '/val',
    image_size=(width, width),
    batch_size=batch_size,
    label_mode='int'
)


validation preprocessing
Found 1500 files belonging to 3 classes.


In [69]:
import tensorflow as tf
from tensorflow.keras import models, layers

class VAE(tf.keras.Model):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()

        # Encoder
        self.encoder = models.Sequential([
            layers.InputLayer(input_shape=input_dim),
            layers.Conv2D(32, kernel_size=4, strides=2, padding="same", activation="relu"),
            layers.Conv2D(64, kernel_size=4, strides=2, padding="same", activation="relu"),
            layers.Conv2D(128, kernel_size=4, strides=2, padding="same", activation="relu"),
            layers.Flatten(),
            layers.Dense(latent_dim + latent_dim),  # Two times latent_dim for mean and log-variance
        ])

        # Decoder
        self.decoder = models.Sequential([
            layers.InputLayer(input_shape=(latent_dim,)),
            layers.Dense(16 * 16 * 128, activation="relu"),
            layers.Reshape((16, 16, 128)),
            layers.Conv2DTranspose(64, kernel_size=4, strides=2, padding="same", activation="relu"),
            layers.Conv2DTranspose(32, kernel_size=4, strides=2, padding="same", activation="relu"),
            layers.Conv2DTranspose(3, kernel_size=4, strides=2, padding="same", activation="sigmoid"),
        ])

    def reparameterize(self, mu, logvar):
        std = tf.exp(0.5 * logvar)
        eps = tf.random.normal(shape=tf.shape(std))
        return mu + eps * std

    def call(self, x):
        enc_output = self.encoder(x)
        mu, logvar = enc_output[:, :latent_dim], enc_output[:, latent_dim:]
        z = self.reparameterize(mu, logvar)
        dec_output = self.decoder(z)
        return dec_output, mu, logvar



# # Instantiate the VAE model
# print("VAE Class Object")
# vae = VAE(latent_dim)



In [101]:
import tensorflow as tf

# Assuming train_dataset is already defined
nan_count_images = 0
nan_count_labels = 0

for images, labels in train_dataset:
    print(images.shape)
    # Check for NaN values in images
    nan_mask_images = tf.math.is_nan(images)
    nan_count_images += tf.reduce_sum(tf.cast(nan_mask_images, tf.int32)).numpy()

    # Cast labels to float32 before checking for NaN values
    labels_float32 = tf.cast(labels, tf.float32)
    nan_mask_labels = tf.math.is_nan(labels_float32)
    nan_count_labels += tf.reduce_sum(tf.cast(nan_mask_labels, tf.int32)).numpy()

# Print the total count of NaN values in images and labels
print("Total NaN values in images:", nan_count_images)
print("Total NaN values in labels:", nan_count_labels)


(1024, 128, 128, 3)
(1024, 128, 128, 3)
(1024, 128, 128, 3)
(1024, 128, 128, 3)
(1024, 128, 128, 3)
(1024, 128, 128, 3)
(1024, 128, 128, 3)
(1024, 128, 128, 3)
(1024, 128, 128, 3)
(1024, 128, 128, 3)
(1024, 128, 128, 3)
(1024, 128, 128, 3)
(1024, 128, 128, 3)
(1024, 128, 128, 3)
(294, 128, 128, 3)
Total NaN values in images: 0
Total NaN values in labels: 0


In [73]:
import tensorflow as tf
from tensorflow.keras import models, layers

# Assume your VAE model is already defined and compiled
vae = VAE(input_dim=(width, width, 3), latent_dim=32)


for images, labels in train_dataset:
    # Forward pass through the VAE model
    recon_data, mu, logvar = vae(images)
    #print(recon_data, mu, logvar)
    # Check for NaN values in recon_data, mu, and logvar
    nan_mask_recon = tf.math.is_nan(recon_data)
    nan_mask_mu = tf.math.is_nan(mu)
    nan_mask_logvar = tf.math.is_nan(logvar)

    # Count NaN values in each batch
    count_nan_recon = tf.reduce_sum(tf.cast(nan_mask_recon, tf.int32))
    count_nan_mu = tf.reduce_sum(tf.cast(nan_mask_mu, tf.int32))
    count_nan_logvar = tf.reduce_sum(tf.cast(nan_mask_logvar, tf.int32))

    print(f"NaN count in recon_data: {count_nan_recon.numpy()}")
    print(f"NaN count in mu: {count_nan_mu.numpy()}")
    print(f"NaN count in logvar: {count_nan_logvar.numpy()}")
    
    # Optionally, break the loop after checking the first batch
    break


NaN count in recon_data: 0
NaN count in mu: 0
NaN count in logvar: 0


In [99]:

def loss_function(recon_x, x, mu, logvar):
    # Reshape x to match recon_x's shape
    x_reshaped = tf.reshape(x, shape=(-1,128, 128, 3))

    # Calculate binary cross-entropy loss along spatial dimensions
    BCE = tf.reduce_sum(tf.keras.losses.binary_crossentropy(recon_x, x_reshaped), axis=(1, 2))

    # Calculate Kullback-Leibler Divergence
    KLD = -0.5 * tf.reduce_sum(1 + logvar - tf.square(mu) - tf.exp(logvar), axis=1)

    return tf.reduce_mean(BCE + KLD)



In [100]:

# Define the optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Training loop
print("Training Started")
num_epochs = 1

for epoch in range(num_epochs):
    total_train_loss = 0

    for batch in train_dataset:
        data, _ = batch

        with tf.GradientTape() as tape:
            recon_data, mu, logvar = vae(data)
            loss = loss_function(recon_data, data, mu, logvar)
            print(loss)

        gradients = tape.gradient(loss, vae.trainable_variables)
        optimizer.apply_gradients(zip(gradients, vae.trainable_variables))

        total_train_loss += loss.numpy()

    avg_train_loss = total_train_loss / len(train_dataset)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Train Loss: {avg_train_loss}")

    # Validation loop
    total_val_loss = 0
    for val_batch in val_dataset:
        val_data, _ = val_batch
        recon_val_data, mu_val, logvar_val = vae(val_data)
        val_loss = loss_function(recon_val_data, val_data, mu_val, logvar_val)
        total_val_loss += val_loss.numpy()

    avg_val_loss = total_val_loss / len(val_dataset)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Validation Loss: {avg_val_loss}")

# Save the model after training
print("Saving Model")
save_path = "/data/home/vvaibhav/AI/VAE/vae_model"
vae.save(save_path)

# Generate new images
vae.eval()
random_latent = tf.random.normal(shape=(16, latent_dim))
generated_images = vae.decoder(random_latent).numpy()

# Display or save the generated images as needed
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
axes = axes.flatten()

for i in range(16):
    image = generated_images[i]
    axes[i].imshow(image)
    axes[i].axis('off')

plt.tight_layout()
plt.show()


Training Started
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
tf.Tensor(nan, shape=(), dtype=float32)
Epoch [1/1], Train Loss: nan
Epoch [1/1], Validation Loss: nan
Saving Model
INFO:tensorflow:Assets written to: /data/home/vvaibhav/AI/VAE/vae_model/assets


AttributeError: 'VAE' object has no attribute 'eval'