<a href="https://colab.research.google.com/github/ramanujan2710/abstracting-fsvae/blob/main/Untitled13.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:


import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np

# Hyperparameters shared by the model builder, the data pipeline and the
# training calls further down in this file.
batch_size = 16  # examples per batch (passed to Dataset.batch)
image_size = 256  # images are resized to image_size x image_size
num_channels = 3  # channel count of the model input and output
num_classes = 1000  # NOTE(review): not referenced anywhere in the visible code
latent_dim = 128  # embedding Conv2D filters; output units of each encoder MLP
num_layers = 12  # number of attention + MLP stages in the encoder loop
num_heads = 8  # heads per MultiHeadAttention layer (key_dim = latent_dim // num_heads)
mlp_dim = 512  # hidden units of each encoder MLP
dropout_rate = 0.1  # rate used by every Dropout layer in the encoder
learning_rate = 1e-4  # learning rate given to the Adam optimizer
num_epochs = 10  # epochs passed to model.fit

# Define the VIT-VAE model
def create_vit_vae(patch_size=16):
    """Build the ViT-style image-to-image Keras model.

    The input image is embedded per pixel, cut into non-overlapping patches
    that are processed as a token sequence by ``num_layers`` attention + MLP
    stages, decoded back to an image, normalised per image, and finally
    fused with a positional encoding and the per-pixel embedding by a 1x1
    convolution.

    Args:
        patch_size: Side length, in pixels, of the square patches that become
            attention tokens. Must be positive and divide ``image_size``.

    Returns:
        A ``tf.keras.Model`` mapping ``(image_size, image_size, num_channels)``
        inputs to outputs of the same shape.

    Raises:
        ValueError: If ``patch_size`` is not positive or does not divide
            ``image_size`` evenly.
    """
    if patch_size <= 0 or image_size % patch_size != 0:
        raise ValueError(
            f"patch_size must be a positive divisor of image_size={image_size}, "
            f"got {patch_size}"
        )
    grid_size = image_size // patch_size
    num_patches = grid_size * grid_size

    # Define the embedding layer (per-pixel; reused below as the intensity encoding)
    input_layer = tf.keras.layers.Input(shape=(image_size, image_size, num_channels))
    embedding_layer = tf.keras.layers.Conv2D(filters=latent_dim, kernel_size=1, strides=1)(input_layer)
    embedding_layer = tf.keras.layers.LayerNormalization()(embedding_layer)
    embedding_layer = tf.keras.layers.Activation('relu')(embedding_layer)

    # Define the transformer encoder.
    # MultiHeadAttention needs a (batch, sequence, features) tensor. Flattening
    # the whole feature map produced a rank-2 tensor instead, so the map is
    # projected to one token per patch and reshaped into a sequence.
    encoder_layer = tf.keras.layers.Conv2D(
        filters=latent_dim, kernel_size=patch_size, strides=patch_size
    )(embedding_layer)
    encoder_layer = tf.keras.layers.Reshape((num_patches, latent_dim))(encoder_layer)
    for _ in range(num_layers):
        encoder_layer = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=latent_dim // num_heads
        )(encoder_layer, encoder_layer)
        encoder_layer = tf.keras.layers.Dropout(dropout_rate)(encoder_layer)
        encoder_layer = tf.keras.layers.LayerNormalization()(encoder_layer)
        mlp_layer = tf.keras.Sequential([
            tf.keras.layers.Dense(units=mlp_dim, activation='relu'),
            tf.keras.layers.Dropout(dropout_rate),
            tf.keras.layers.Dense(units=latent_dim),
            tf.keras.layers.Dropout(dropout_rate)
        ])
        encoder_layer = mlp_layer(encoder_layer)
        encoder_layer = tf.keras.layers.LayerNormalization()(encoder_layer)

    # Define the decoder: every token is expanded to the pixels of its patch,
    # then the patches are stitched back into a full image.
    decoder_layer = tf.keras.layers.Dense(units=patch_size * patch_size * num_channels)(encoder_layer)
    decoder_layer = tf.keras.layers.Reshape(
        (grid_size, grid_size, patch_size, patch_size, num_channels)
    )(decoder_layer)
    # (grid_row, grid_col, patch_row, patch_col, c) -> (grid_row, patch_row, grid_col, patch_col, c)
    decoder_layer = tf.keras.layers.Permute((1, 3, 2, 4, 5))(decoder_layer)
    decoder_layer = tf.keras.layers.Reshape((image_size, image_size, num_channels))(decoder_layer)

    # Define the shift and scale invariance (zero mean, unit std per image and channel)
    spatial_mean = tf.keras.layers.Lambda(
        lambda x: tf.reduce_mean(x, axis=[1, 2], keepdims=True)
    )(decoder_layer)
    decoder_layer = tf.keras.layers.Subtract()([decoder_layer, spatial_mean])
    # The epsilon keeps a constant image from causing a division by zero.
    decoder_layer = tf.keras.layers.Lambda(
        lambda x: x / (tf.math.reduce_std(x, axis=[1, 2], keepdims=True) + 1e-6)
    )(decoder_layer)

    # Define the relative positional and intensity encodings.
    # The row (sin) and column (cos) terms are summed so both survive; the
    # original loop wrote them to the same slots, leaving only the cos term.
    frequencies = 1.0 / 10000 ** (2 * np.arange(latent_dim) / latent_dim)
    pixel_index = np.arange(image_size)
    row_term = np.sin(pixel_index[:, None, None] * frequencies)  # (H, 1, latent_dim)
    col_term = np.cos(pixel_index[None, :, None] * frequencies)  # (1, W, latent_dim)
    positional_encoding = tf.constant((row_term + col_term)[None], dtype=tf.float32)
    # The constant table has batch size 1; tile it to the runtime batch size so
    # it can be concatenated with the batched tensors.
    positional_layer = tf.keras.layers.Lambda(
        lambda x: tf.tile(positional_encoding, [tf.shape(x)[0], 1, 1, 1]),
        output_shape=(image_size, image_size, latent_dim),
    )(decoder_layer)
    # The intensity encoding is the per-pixel embedding itself; a symbolic
    # tensor cannot be copied element by element into a NumPy array.
    decoder_layer = tf.keras.layers.Concatenate()([decoder_layer, positional_layer, embedding_layer])

    # Define the final output layer
    output_layer = tf.keras.layers.Conv2D(filters=num_channels, kernel_size=1, strides=1)(decoder_layer)

    model = tf.keras.models.Model(inputs=input_layer, outputs=output_layer)
    return model

    div2k_train, div2k_info = tfds.load(name="div2k", split="train", with_info=True)


  def preprocess_data(data):
    image = tf.cast(data['hr'], tf.float32) / 255.0
    image = tf.image.resize(image, [image_size, image_size])
    return image


    train_dataset = div2k_train.map(preprocess_data).shuffle(buffer_size=1024).batch(batch_size)


    model = create_vit_vae()



def vae_loss(x, x_hat):
    """Reconstruction loss plus a KL-style penalty on the output statistics.

    Args:
        x: Target image batch.
        x_hat: Reconstructed image batch produced by the model.

    Returns:
        The mean squared error between ``x`` and ``x_hat`` plus a scalar
        Gaussian KL term computed from the mean and variance of ``x_hat``
        over axes 1 and 2.
    """
    reconstruction_loss = tf.keras.losses.mean_squared_error(x, x_hat)
    mean = tf.reduce_mean(x_hat, axis=[1, 2], keepdims=True)
    variance = tf.math.reduce_variance(x_hat, axis=[1, 2], keepdims=True)
    # The epsilon sits inside the log so a constant reconstruction (zero
    # variance) does not yield -inf. The variance term completes the Gaussian
    # KL formula; without it the penalty is unbounded below.
    kl_divergence = -0.5 * tf.reduce_mean(
        1.0 + tf.math.log(variance + 1e-8) - tf.square(mean) - variance
    )
    return reconstruction_loss + kl_divergence



# Compile the model with Adam at the configured learning rate and the custom
# vae_loss as the training objective.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), loss=vae_loss)



# Train on train_dataset for num_epochs epochs.
model.fit(train_dataset, epochs=num_epochs)


# Write the trained weights (not the full model) to 'vit_vae_weights.h5'.
# NOTE(review): newer Keras releases appear to require a '.weights.h5' suffix
# for save_weights; confirm against the installed version.
model.save_weights('vit_vae_weights.h5')













IndentationError: ignored