In [1]:
#VAE on a Small Synthetic Dataset

In [2]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from keras.src.losses.losses import mean_squared_error as mse
from tensorflow.keras.losses import MeanSquaredError

In [3]:
# Create a small synthetic dataset (10 samples, 2 features)
x_train = np.array([
    [0.1, 0.2],
    [0.2, 0.3],
    [0.3, 0.4],
    [0.4, 0.5],
    [0.5, 0.6],
    [0.6, 0.7],
    [0.7, 0.8],
    [0.8, 0.9],
    [0.9, 1.0],
    [1.0, 1.1]
], dtype='float32')
x_train

array([[0.1, 0.2],
       [0.2, 0.3],
       [0.3, 0.4],
       [0.4, 0.5],
       [0.5, 0.6],
       [0.6, 0.7],
       [0.7, 0.8],
       [0.8, 0.9],
       [0.9, 1. ],
       [1. , 1.1]], dtype=float32)

In [4]:
# Expand the dataset to 2D shape for compatibility with the Conv2D layers
x_train = np.expand_dims(x_train, axis=-1)  # Shape (10, 2, 1)

# Set the latent dimension size
latent_dim = 2

In [5]:
# Encoder
inputs = layers.Input(shape=(2, 1))  # Input shape matches the synthetic dataset
x = layers.Flatten()(inputs)
x = layers.Dense(16, activation='relu')(x)

In [6]:
# Mean and log variance for the latent space
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)

In [7]:
# Sampling function for reparameterization trick
def sampling(args):
    z_mean, z_log_var = args
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon

# Latent space representation
z = layers.Lambda(sampling)([z_mean, z_log_var])




In [8]:
# Decoder
decoder_input = layers.Input(shape=(latent_dim,))
x = layers.Dense(16, activation='relu')(decoder_input)
x = layers.Dense(2, activation='sigmoid')(x)
outputs = layers.Reshape((2, 1))(x)

In [9]:
# Encoder and decoder models
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
decoder = Model(decoder_input, outputs, name='decoder')

In [10]:
# Instantiate MeanSquaredError outside the loss function
mse = MeanSquaredError()

In [11]:
# Custom VAE Model Class
class VAE(Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        return reconstructed

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = mse(tf.keras.backend.flatten(data), tf.keras.backend.flatten(reconstruction))
            kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            kl_loss = tf.keras.backend.sum(kl_loss, axis=-1)
            kl_loss *= -0.5
            total_loss = tf.keras.backend.mean(reconstruction_loss + kl_loss)
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {"loss": total_loss}

In [12]:
# Instantiate the VAE model
vae = VAE(encoder, decoder)

# Compile the model with an optimizer
vae.compile(optimizer='adam')

# Summary of the VAE model
vae.summary()

In [13]:
# Train the VAE
vae.fit(x_train, epochs=500, batch_size=2)

Epoch 1/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - loss: 0.4612
Epoch 2/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.3953 
Epoch 3/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.3698 
Epoch 4/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.3189 
Epoch 5/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.2945 
Epoch 6/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.2681 
Epoch 7/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.2351 
Epoch 8/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.2104 
Epoch 9/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.2030 
Epoch 10/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.1841 
Epoch 11/5

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0854 
Epoch 84/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0660 
Epoch 85/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0724 
Epoch 86/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0591 
Epoch 87/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.0844 
Epoch 88/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0779 
Epoch 89/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0779 
Epoch 90/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0713 
Epoch 91/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0721 
Epoch 92/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0743 
Epoch 93/500


[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0596 
Epoch 165/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0781 
Epoch 166/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0716 
Epoch 167/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0733 
Epoch 168/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0730 
Epoch 169/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0784 
Epoch 170/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0800 
Epoch 171/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0639 
Epoch 172/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0707 
Epoch 173/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0751 
Epoc

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0652 
Epoch 246/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0640 
Epoch 247/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0639 
Epoch 248/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0745 
Epoch 249/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0776 
Epoch 250/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0683 
Epoch 251/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0583 
Epoch 252/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0728 
Epoch 253/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0695 
Epoch 254/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0678 
Epoc

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.0746 
Epoch 327/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0639 
Epoch 328/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0659 
Epoch 329/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0658 
Epoch 330/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0691 
Epoch 331/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0664 
Epoch 332/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0723 
Epoch 333/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0673 
Epoch 334/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0702 
Epoch 335/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0752 
Epoc

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0692 
Epoch 408/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0611 
Epoch 409/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0682 
Epoch 410/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0668 
Epoch 411/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0672 
Epoch 412/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0625 
Epoch 413/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0733 
Epoch 414/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0654 
Epoch 415/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0714 
Epoch 416/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - loss: 0.0765 
Epoc

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0672 
Epoch 489/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0709 
Epoch 490/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0685 
Epoch 491/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0721 
Epoch 492/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0740 
Epoch 493/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0668 
Epoch 494/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0719 
Epoch 495/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0692 
Epoch 496/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0726 
Epoch 497/500
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step - loss: 0.0678 
Epoc

<keras.src.callbacks.history.History at 0x1dcd3f3a6e0>

In [16]:
# Encode the data to the latent space
encoded_data = encoder.predict(x_train)[2]

# Decode the latent space back to the original space
decoded_data = decoder.predict(encoded_data)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 38ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 37ms/step


In [17]:
# Print the original data, encoded data, and decoded data
print("Original Data:\n", x_train.squeeze())
print("Encoded Data (Latent Space):\n", encoded_data)
print("Decoded Data:\n", decoded_data.squeeze())

Original Data:
 [[0.1 0.2]
 [0.2 0.3]
 [0.3 0.4]
 [0.4 0.5]
 [0.5 0.6]
 [0.6 0.7]
 [0.7 0.8]
 [0.8 0.9]
 [0.9 1. ]
 [1.  1.1]]
Encoded Data (Latent Space):
 [[ 0.69957566  0.77276045]
 [ 1.5225275   1.3641342 ]
 [-1.3093973  -0.19854012]
 [ 0.66807365 -0.22301093]
 [ 0.10260925 -1.2711886 ]
 [-0.34924918 -1.4115722 ]
 [ 0.48423332 -1.1279027 ]
 [-0.00386277 -1.3710927 ]
 [ 1.4376853  -0.17703515]
 [-0.44145647  0.9110388 ]]
Decoded Data:
 [[0.55574375 0.6443778 ]
 [0.58032197 0.70872265]
 [0.5311357  0.6482096 ]
 [0.55482066 0.634321  ]
 [0.5340287  0.6392071 ]
 [0.53101    0.63577795]
 [0.5316541  0.63688934]
 [0.5338432  0.6411176 ]
 [0.5713732  0.6735885 ]
 [0.5408968  0.6384337 ]]


In [18]:
#As the dataset we have taken is very small that why the result are very poor

In [19]:
# Generate new rows
def generate_new_samples(decoder, latent_dim, num_samples=5):
    # Sample from a Gaussian distribution
    new_latent_samples = np.random.normal(size=(num_samples, latent_dim))
    
    # Decode the latent samples to generate new data
    new_data = decoder.predict(new_latent_samples)
    
    return new_data

In [20]:
# Generate 5 new rows
new_samples = generate_new_samples(decoder, latent_dim, num_samples=5)

# Print the new samples
print("Generated New Samples:\n", new_samples.squeeze())

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 137ms/step
Generated New Samples:
 [[0.55077887 0.6261333 ]
 [0.5298668  0.6405395 ]
 [0.56380415 0.6564733 ]
 [0.5468174  0.6159637 ]
 [0.52998316 0.60870385]]
