# Variational Auto-Encoder for Molecular Generation Task

In [1]:
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

## Data Processing

In [2]:
LENGTH = 128
DFF = 32

# shape; (batch_size, length, dff)
cb2_embeddings = np.load("./data/processed_cb2.npy")
cb2_embeddings.shape

(2723, 128, 32)

In [3]:
X_train, x_test = train_test_split(cb2_embeddings, test_size=0.3, random_state=17)
X_train = tf.expand_dims(X_train, axis=-1)
x_test = tf.expand_dims(x_test, axis=-1)

X_train.shape, x_test.shape

(TensorShape([1906, 128, 32, 1]), TensorShape([817, 128, 32, 1]))

In [4]:
def preprocess(tensor):
    tensor = tf.cast(tensor, dtype=tf.float32)
    tensor = tf.divide( tf.subtract(tensor, tf.reduce_min(tensor)), tf.subtract(tf.reduce_max(tensor), tf.reduce_min(tensor)) )
    return tensor

In [5]:
X_train = preprocess(X_train)
x_test = preprocess(x_test)

tf.math.reduce_min(X_train), tf.math.reduce_max(X_train), tf.math.reduce_mean(X_train)

(<tf.Tensor: shape=(), dtype=float32, numpy=0.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.50185823>)

In [6]:
tf.math.reduce_min(x_test), tf.math.reduce_max(x_test), tf.math.reduce_mean(x_test)

(<tf.Tensor: shape=(), dtype=float32, numpy=0.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5030998>)

## VAE Model

In [7]:
from tensorflow import keras
from tensorflow.keras import layers

class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [8]:
MAX_ROWS = 128
MAX_COLUMNS = 32
latent_dim = 512

encoder_inputs = keras.Input(shape=(MAX_ROWS, MAX_COLUMNS, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2D(128, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2D(256, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(128, activation="relu")(x)
x = layers.Reshape((128, 1))(x)  # Add this line to reshape the output for LSTM
x = layers.LSTM(128, return_sequences=True)(x)  # Apply LSTM after Conv2D and Reshape
x = layers.Flatten()(x)  # Flatten the LSTM output to be connected to z_mean and z_log_var
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 128, 32, 1)]         0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 64, 16, 32)           320       ['input_1[0][0]']             
                                                                                                  
 conv2d_1 (Conv2D)           (None, 32, 8, 64)            18496     ['conv2d[0][0]']              
                                                                                                  
 conv2d_2 (Conv2D)           (None, 16, 4, 128)           73856     ['conv2d_1[0][0]']            
                                                                                            

In [9]:
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(128, activation="relu")(latent_inputs)
x = layers.Reshape((128, 1))(x)
x = layers.LSTM(128, return_sequences=True)(x)
x = layers.Reshape((16, 4, 256))(x)
x = layers.Conv2DTranspose(256, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(128, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)  # Output the reconstructed input
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()


# latent_inputs = keras.Input(shape=(latent_dim,))
# x = layers.Dense(MAX_ROWS // 16 * MAX_COLUMNS // 16 * 64, activation="relu")(latent_inputs)
# x = layers.Reshape((MAX_ROWS // 16, MAX_COLUMNS // 16, 64))(x)
# x = layers.Conv2DTranspose(256, 3, activation="relu", strides=2, padding="same")(x)
# x = layers.Conv2DTranspose(128, 3, activation="relu", strides=2, padding="same")(x)
# x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
# x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
# decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
# decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
# decoder.summary()

Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 512)]             0         
                                                                 
 dense_1 (Dense)             (None, 128)               65664     
                                                                 
 reshape_1 (Reshape)         (None, 128, 1)            0         
                                                                 
 lstm_1 (LSTM)               (None, 128, 128)          66560     
                                                                 
 reshape_2 (Reshape)         (None, 16, 4, 256)        0         
                                                                 
 conv2d_transpose (Conv2DTr  (None, 32, 8, 256)        590080    
 anspose)                                                        
                                                           

In [10]:
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            mse = tf.keras.losses.MeanSquaredError()
            kl = tf.keras.losses.KLDivergence()
            reconstruction_loss = mse(data, reconstruction)
            kl_loss = kl(data, reconstruction)
            
            # reconstruction_loss = tf.reduce_mean(
            #     tf.reduce_sum(
            #         mse(data, reconstruction), axis=(1, 2)
            #     )
            # )
            # kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            # kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            
            total_loss = reconstruction_loss + 0.15 * kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

## Loading VAE from Checkpoint
Only run if checkpoint has been created

In [11]:
vae = VAE(encoder, decoder)
vae.load_weights('./vae_checkpoints/trained_vae')

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x2913256d0>

## VAE Training

In [11]:
vae = VAE(encoder, decoder)

early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='reconstruction_loss', patience=3, verbose=1)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
callbacks = [early_stopping_callback, tensorboard_callback]

vae.compile(optimizer=keras.optimizers.legacy.Adam())
vae.fit(X_train, epochs=50, batch_size=128, callbacks=callbacks)

Epoch 1/50

KeyboardInterrupt: 

In [12]:
vae.save_weights('./vae_checkpoints/trained_vae')

## Sampling Latent Space

#### Build Transformer

In [12]:
def get_angles(pos, i, d_model):
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates

def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)
  
  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
  
  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    
  pos_encoding = angle_rads[np.newaxis, ...]
    
  return tf.cast(pos_encoding, dtype=tf.float32)

def create_masks(inp, tar):
  # Encoder padding mask
  enc_padding_mask = create_padding_mask(inp)
  
  # Used in the 2nd attention block in the decoder.
  # This padding mask is used to mask the encoder outputs.
  dec_padding_mask = create_padding_mask(inp)
  
  # Used in the 1st attention block in the decoder.
  # It is used to pad and mask future tokens in the input received by 
  # the decoder.
  look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
  dec_target_padding_mask = create_padding_mask(tar)
  combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
  
  return enc_padding_mask, combined_mask, dec_padding_mask

def create_padding_mask(seq):
  seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
  
  # add extra dimensions to add the padding
  # to the attention logits.
  return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

def create_look_ahead_mask(size):
  mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return mask  # (seq_len, seq_len)

def scaled_dot_product_attention(q, k, v, mask):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type(padding or look ahead) 
  but it must be broadcastable for addition.
  
  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable 
          to (..., seq_len_q, seq_len_k). Defaults to None.
    
  Returns:
    output, attention_weights
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
  
  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)  

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights

class MultiHeadAttention(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model
    
    assert d_model % self.num_heads == 0
    
    self.depth = d_model // self.num_heads
    
    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)
    
    self.dense = tf.keras.layers.Dense(d_model)
        
  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])
    
  def call(self, v, k, q, mask):
    batch_size = tf.shape(q)[0]
    
    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)
    
    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
    
    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)
    
    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    concat_attention = tf.reshape(scaled_attention, 
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        
    return output, attention_weights

def point_wise_feed_forward_network(d_model, dff):
  return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  ])

class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(EncoderLayer, self).__init__()

    self.mha = MultiHeadAttention(d_model, num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    
    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    
  def call(self, x, training, mask):

    attn_output, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)
    
    ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
    ffn_output = self.dropout2(ffn_output, training=training)
    out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)
    
    return out2

class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(DecoderLayer, self).__init__()

    self.mha1 = MultiHeadAttention(d_model, num_heads)
    self.mha2 = MultiHeadAttention(d_model, num_heads)

    self.ffn = point_wise_feed_forward_network(d_model, dff)
 
    self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    
    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)
    
    
  def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):
    # enc_output.shape == (batch_size, input_seq_len, d_model)

    attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
    attn1 = self.dropout1(attn1, training=training)
    out1 = self.layernorm1(attn1 + x)
    
    attn2, attn_weights_block2 = self.mha2(
        enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
    attn2 = self.dropout2(attn2, training=training)
    out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)
    
    ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
    ffn_output = self.dropout3(ffn_output, training=training)
    out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)
    
    return out3, attn_weights_block1, attn_weights_block2

class Encoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               maximum_position_encoding, rate=0.1):
    super(Encoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers
    
    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, 
                                            self.d_model)
    
    
    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]
  
    self.dropout = tf.keras.layers.Dropout(rate)
        
  def call(self, x, training, mask):

    seq_len = tf.shape(x)[1]
    
    # adding embedding and position encoding.
    x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)
    
    for i in range(self.num_layers):
      x = self.enc_layers[i](x, training, mask)
    
    return x  # (batch_size, input_seq_len, d_model)

class Decoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
               maximum_position_encoding, rate=0.1):
    super(Decoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers
    
    self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)
    
    self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)
    
  def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):

    seq_len = tf.shape(x)[1]
    attention_weights = {}
    
    x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]
    
    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                             look_ahead_mask, padding_mask)
      
      attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
      attention_weights['decoder_layer{}_block2'.format(i+1)] = block2
    
    # x.shape == (batch_size, target_seq_len, d_model)
    return x, attention_weights

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
  def __init__(self, d_model, warmup_steps=4000):
    super(CustomSchedule, self).__init__()
    
    self.d_model = d_model
    self.d_model = tf.cast(self.d_model, tf.float32)

    self.warmup_steps = warmup_steps
    
  def __call__(self, step):
    arg1 = tf.math.rsqrt(step)
    arg2 = step * (self.warmup_steps ** -1.5)
    
    return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

In [13]:
class Transformer(tf.keras.Model):
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, 
               target_vocab_size, pe_input, pe_target, rate=0.1):
    super(Transformer, self).__init__()

    self.encoder = Encoder(num_layers, d_model, num_heads, dff, 
                           input_vocab_size, pe_input, rate)

    self.decoder = Decoder(num_layers, d_model, num_heads, dff, 
                           target_vocab_size, pe_target, rate)

    self.final_layer = tf.keras.layers.Dense(target_vocab_size)
    
  def call(self, inp, tar, training, enc_padding_mask, 
           look_ahead_mask, dec_padding_mask):

    enc_output = self.encoder(inp, training, enc_padding_mask)  # (batch_size, inp_seq_len, d_model)
    
    # dec_output.shape == (batch_size, tar_seq_len, d_model)
    dec_output, attention_weights = self.decoder(
        tar, enc_output, training, look_ahead_mask, dec_padding_mask)
    
    final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)
    
    return final_output, attention_weights

In [14]:
num_layers = 4
d_model = 32
dff = 512
num_heads = 8
input_vocab_size = 510
target_vocab_size = 510
dropout_rate = 0.2

learning_rate = CustomSchedule(d_model)

optimizer = tf.keras.optimizers.legacy.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 
                                     epsilon=1e-9)

transformer = Transformer(num_layers, d_model, num_heads, dff,
                          input_vocab_size, target_vocab_size, 
                          pe_input=input_vocab_size, 
                          pe_target=target_vocab_size,
                          rate=dropout_rate)

In [15]:
checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(transformer=transformer,
                           optimizer=optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

Latest checkpoint restored!!


#### Decoder Function

In [None]:
'O=C(NCc1ccccc1)c1cc2c(n(CC3N3CCOCCC)c1=O)CCCCCCC'
'O=C(NCc1ccccc1)c1cc2c(n(CC3N3CCOCCC)c1=O)CCCCCCC'
'O=C(NCc1ccccc1)c1cc2c(n(CC3N3CCOCCC)c1=O)CCCCCCC'

In [16]:
import random
from tokenizer import SmilesTokenizer

def encode_sample(input_sample):
    z_mean, z_log_var, sampling = vae.encoder(input_sample)
    return z_mean, z_log_var, sampling

def sample_from_latent(z_mean, z_log_var):
    epsilon = tf.random.normal(tf.shape(z_mean))
    z = z_mean + tf.exp(0.5 * z_log_var) * epsilon
    return z

def decode_sample(z_sample):
    reconstructed_sample = vae.decoder(z_sample)
    return reconstructed_sample

def sampler(i):
    input_sample = X_train[i]
    input_sample = tf.expand_dims(input_sample, axis=0)
    
    z_mean, z_log_var, sampling = encode_sample(input_sample)

    z_sample = sample_from_latent(z_mean, z_log_var)
    
    reconstructed_sample = decode_sample(z_sample)
    reconstructed_sample = tf.squeeze(reconstructed_sample, axis=-1)
    return reconstructed_sample

def decode(embedding):
    tokenizer = SmilesTokenizer("./data/vocab.txt")
    sample_zeros = tf.zeros((1, 128))
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(sample_zeros, sample_zeros)

    dec_output, _ = transformer.decoder(sample_zeros, embedding, False, combined_mask, dec_padding_mask)
    final_output = transformer.final_layer(dec_output)

    smiles_string = []
    for token in range(128):
        smiles_string.append( int( tf.argmax( final_output[:, token, :], axis=-1 )[0] ) )
    smiles_string = tokenizer.decode(smiles_string).replace(' ', '')
    smiles_string = smiles_string[:smiles_string.find('[SEP]')]

    return smiles_string

for i in range(500):
    print(decode(sampler(i)))



KeyboardInterrupt: 

In [107]:
import random
from tokenizer import SmilesTokenizer
tokenizer = SmilesTokenizer("./data/vocab.txt")

cb2_smiles = open('./data/X_SMILES.txt', 'r')
cb2_smiles = cb2_smiles.read().splitlines()
sampling_smiles = cb2_smiles[ random.randint(0, len(cb2_smiles) - 1) ]

sample = tokenizer.encode(sampling_smiles)
sample = tf.keras.utils.pad_sequences([sample], maxlen=128, padding='post')
sample = tf.constant(sample)
sample_zeros = tf.zeros((1, 128))

enc_padding_mask, combined_mask, dec_padding_mask = create_masks(sample_zeros, sample_zeros)
enc_output = transformer.encoder(sample_zeros, False, enc_padding_mask)

dec_output, _ = transformer.decoder(sample_zeros, reconstructed_sample, False, combined_mask, dec_padding_mask)
final_output = transformer.final_layer(dec_output)

# enc_padding_mask, combined_mask, dec_padding_mask = create_masks(sample, sample)

# predictions, _ = transformer(sample, sample, 
#                                 False, 
#                                 enc_padding_mask, 
#                                 combined_mask, 
#                                 dec_padding_mask)

print(predictions)

tf.Tensor(
[[[-12.010056  -12.201371  -11.890745  ... -11.987468  -11.95752
   -12.05295  ]
  [-16.562353  -16.805267  -16.272493  ... -16.537561  -16.496244
   -16.327068 ]
  [-21.470423  -21.515547  -21.554365  ... -21.364697  -21.69297
   -21.427305 ]
  ...
  [-15.237321  -15.373925  -15.216337  ... -15.198552  -14.965597
   -15.211064 ]
  [-14.855588  -14.9820385 -14.84074   ... -14.81174   -14.574617
   -14.827436 ]
  [-14.50811   -14.626286  -14.498796  ... -14.458283  -14.214958
   -14.480728 ]]], shape=(1, 128, 510), dtype=float32)


In [105]:
from tokenizer import SmilesTokenizer
tokenizer = SmilesTokenizer("./data/vocab.txt")

preds = []

for token in range(128):
    preds.append( int( tf.argmax( predictions[:, token, :], axis=-1 )[0] ) )
pred_smiles = tokenizer.decode(preds).replace(' ', '')

pred_smiles[:pred_smiles.find('[SEP]')], sampling_smiles

('O=C(NCc1ccccc1)c1cc2c(n(CC3N3CCOCCC)c1=O)CCCCCCC',
 'CN(CC1CC1)S(=O)(=O)c1ccc(Nc2ccc(Cl)cc2Cl)cc1C(F)(F)F')

In [70]:
def decode(embedding):
    start_token = 12
    end_token = 13
    max_length = 128
    batch_size = 32
    
    decoder_input = tf.expand_dims([start_token] * 32, 0)
    generated_tokens = decoder_input

    embedding = tf.transpose(embedding, perm=[2, 0, 1])
    
    for i in range(max_length):
        # enc_padding_mask, combined_mask, dec_padding_mask = create_masks(decoder_input, decoder_input)
        dec_padding_mask = create_padding_mask(decoder_input)
        look_ahead_mask = create_look_ahead_mask(tf.shape(decoder_input)[1])
        combined_mask = tf.maximum(look_ahead_mask, dec_padding_mask)
        decoder_output = transformer.decoder(decoder_input, embedding, False, combined_mask, dec_padding_mask)
        
        last_token_logits = decoder_output[:, -1, :]
        next_token = tf.argmax(last_token_logits, axis=-1)
        generated_tokens = tf.concat([generated_tokens, tf.expand_dims(next_token, 1)], axis=-1)
        
        # Stop decoding if the end-of-sequence token is generated for all samples
        if tf.reduce_all(next_token == end_token):
            break
    
        # Update the decoder input for the next step
        decoder_input = generated_tokens
    
    # Remove the batch dimension from the generated tokens
    generated_tokens = tf.squeeze(generated_tokens, axis=0)
    generated_strings = tokenizer.decode(token_indices).replace(' ', '')

decode(X_train[0])

ValueError: Exception encountered when calling layer 'multi_head_attention_17' (type MultiHeadAttention).

Input 0 of layer "dense_96" is incompatible with the layer: expected axis -1 of input shape to have value 1, but received input with shape (1, 128, 32)

Call arguments received by layer 'multi_head_attention_17' (type MultiHeadAttention):
  • v=tf.Tensor(shape=(1, 128, 32), dtype=float32)
  • k=tf.Tensor(shape=(1, 128, 32), dtype=float32)
  • q=tf.Tensor(shape=(1, 32, 32), dtype=float32)
  • mask=tf.Tensor(shape=(1, 1, 1, 32), dtype=float32)

In [93]:
# Step 1: Encode an input sample to get mean and log variance of the latent distribution
def encode_sample(input_sample):
    z_mean, z_log_var, sampling = vae.encoder(input_sample)
    return z_mean, z_log_var, sampling

# Step 2: Sample a point from the latent distribution using the reparameterization trick
def sample_from_latent(z_mean, z_log_var):
    epsilon = tf.random.normal(tf.shape(z_mean))
    z = z_mean + tf.exp(0.5 * z_log_var) * epsilon
    return z

# Step 3: Decode the sampled point to generate a reconstructed output
def decode_sample(z_sample):
    reconstructed_sample = vae.decoder(z_sample)
    return reconstructed_sample

# Example usage:
# Assuming you have an input sample 'input_sample' (numpy array or tensor) for encoding
input_sample = X_train[0]  # Replace with your actual input sample
input_sample = tf.expand_dims(input_sample, axis=0)

# Encode the input sample to get mean and log variance of the latent distribution
z_mean, z_log_var, sampling = encode_sample(input_sample)

# Sample a point from the latent distribution using the reparameterization trick
# z_sample = sample_from_latent(z_mean, z_log_var)

# Decode the sampled point to generate a reconstructed output
reconstructed_sample = decode_sample(sampling)
reconstructed_sample = tf.squeeze(reconstructed_sample, axis=-1)

In [94]:
reconstructed_sample

<tf.Tensor: shape=(1, 128, 32), dtype=float32, numpy=
array([[[2.3307985e-04, 9.9998474e-01, 6.9007462e-01, ...,
         9.9988669e-01, 5.3884083e-04, 4.4157964e-01],
        [6.4971987e-03, 1.5548096e-03, 2.0074555e-01, ...,
         7.7221984e-01, 9.9871063e-01, 4.2480056e-04],
        [2.6680591e-02, 1.2697066e-03, 5.4089183e-01, ...,
         8.8679492e-01, 9.9519497e-01, 3.3123963e-02],
        ...,
        [5.5085020e-05, 2.9048098e-05, 2.9102489e-01, ...,
         6.6387671e-01, 6.4806358e-05, 6.0993427e-01],
        [2.1871034e-04, 1.0094435e-05, 2.0955595e-01, ...,
         6.5620542e-01, 7.8682715e-05, 6.5084594e-01],
        [1.3981329e-03, 1.0841378e-04, 1.4787267e-01, ...,
         6.7933577e-01, 4.2850224e-04, 6.8192679e-01]]], dtype=float32)>

In [95]:
def decode(embedding):
    start_token = 12
    end_token = 13
    max_length = 128
    batch_size = 32
    
    decoder_input = tf.expand_dims([start_token] * 128, 0)
    generated_tokens = decoder_input

    embedding = tf.transpose(embedding, perm=[2, 0, 1])
    
    for i in range(max_length):
        # enc_padding_mask, combined_mask, dec_padding_mask = create_masks(decoder_input, decoder_input)
        dec_padding_mask = create_padding_mask(decoder_input)
        look_ahead_mask = create_look_ahead_mask(tf.shape(decoder_input)[1])
        combined_mask = tf.maximum(look_ahead_mask, dec_padding_mask)
        decoder_output = transformer.decoder(decoder_input, embedding, False, combined_mask, dec_padding_mask)
        
        last_token_logits = decoder_output[:, -1, :]
        next_token = tf.argmax(last_token_logits, axis=-1)
        generated_tokens = tf.concat([generated_tokens, tf.expand_dims(next_token, 1)], axis=-1)
        
        # Stop decoding if the end-of-sequence token is generated for all samples
        if tf.reduce_all(next_token == end_token):
            break
    
        # Update the decoder input for the next step
        decoder_input = generated_tokens
    
    # Remove the batch dimension from the generated tokens
    generated_tokens = tf.squeeze(generated_tokens, axis=0)
    generated_strings = tokenizer.decode(token_indices).replace(' ', '')

decode(enc_output)

ValueError: Exception encountered when calling layer 'multi_head_attention_29' (type MultiHeadAttention).

Input 0 of layer "dense_161" is incompatible with the layer: expected axis -1 of input shape to have value 32, but received input with shape (32, 1, 128)

Call arguments received by layer 'multi_head_attention_29' (type MultiHeadAttention):
  • v=tf.Tensor(shape=(32, 1, 128), dtype=float32)
  • k=tf.Tensor(shape=(32, 1, 128), dtype=float32)
  • q=tf.Tensor(shape=(1, 128, 32), dtype=float32)
  • mask=tf.Tensor(shape=(1, 1, 1, 128), dtype=float32)

## Previous Attempts

In [None]:
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split

cb2_embeddings = np.load("./data/CB2_EMBED.npz")

In [None]:
MAX_ROWS = 104
MAX_COLUMNS = 512

cb2_embeddings = list( cb2_embeddings.values() )
X_train, x_test = train_test_split(cb2_embeddings, test_size=0.3, random_state=17)

X_train = tf.keras.preprocessing.sequence.pad_sequences(X_train, padding='post', maxlen=MAX_ROWS)
x_test = tf.keras.preprocessing.sequence.pad_sequences(x_test, padding='post', maxlen=MAX_ROWS)
X_train = tf.expand_dims(X_train, axis=-1)
x_test = tf.expand_dims(x_test, axis=-1)

X_train.shape, x_test.shape, X_train[1].shape, X_train[1][0].shape

In [None]:
X_train = tf.cast(X_train, dtype=tf.float32)
X_train = X_train / tf.norm(X_train)

In [None]:
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
class Sampling(layers.Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding a digit."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.random.normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [None]:
latent_dim = 2

encoder_inputs = keras.Input(shape=(MAX_ROWS, MAX_COLUMNS, 1))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = Sampling()([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

In [None]:
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(26 * 128 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((26, 128, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

In [None]:
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(data)
            reconstruction = self.decoder(z)
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction), axis=(1, 2)
                )
            )
            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

In [None]:
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.legacy.Adam())
vae.fit(X_train, epochs=30, batch_size=128)

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda, Reshape, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras import backend as K

# Define the input shape
input_shape = (104, 512)
latent_dim = 64  # Adjust this based on your requirements

# Encoder network
inputs = Input(shape=input_shape)
x = Dense(256, activation='relu')(inputs)
x = Dense(128, activation='relu')(x)
z_mean = Dense(latent_dim)(x)
z_log_var = Dense(latent_dim)(x)

# Reparameterization trick for the sampling layer
# Update the sampling function
def sampling(args):
    z_mean, z_log_var = args
    batch_size = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch_size, dim, latent_dim))  # <-- Generate noise with the right shape
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

z = Lambda(sampling)([z_mean, z_log_var])

# Decoder network
decoder_input = Input(shape=(MAX_ROWS, latent_dim,))
x = Dense(128, activation='relu')(decoder_input)
x = Dense(256, activation='relu')(x)
outputs = Dense(input_shape[1], activation='sigmoid')(x)

# Define the encoder and decoder models
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
decoder = Model(decoder_input, outputs, name='decoder')

# Custom VAE layer for loss calculation
class VaeLossLayer(Layer):
    def __init__(self, **kwargs):
        super(VaeLossLayer, self).__init__(**kwargs)

    def vae_loss(self, y_true, y_pred):
        reconstruction_loss = K.mean(binary_crossentropy(y_true, y_pred)) * input_shape[0] * input_shape[1]
        kl_loss = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return reconstruction_loss + kl_loss

    def call(self, inputs):
        y_true, y_pred = inputs
        loss = self.vae_loss(y_true, y_pred)
        self.add_loss(loss, inputs=inputs)
        return y_true  # We return y_true to maintain the original output

# Combine the encoder and decoder to form the VAE model with the custom loss layer
vae_outputs = decoder(encoder(inputs)[2])
vae_loss_layer = VaeLossLayer(name='vae_loss_layer')([inputs, vae_outputs])
vae = Model(inputs, vae_outputs)

# Manually compute the loss and gradients for custom training loop
optimizer = tf.keras.optimizers.legacy.Adam()
def compute_loss(y_true, y_pred):
    reconstruction_loss = K.mean(binary_crossentropy(y_true, y_pred)) * input_shape[0] * input_shape[1]
    kl_loss = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return reconstruction_loss + kl_loss

@tf.function
def train_step(x):
    with tf.GradientTape(persistent=True) as tape:
        z_mean, z_log_var, z = encoder(x, training=True)  # Use the encoder model directly
        x_pred = decoder(z, training=True)  # Use the decoder model directly

        # Compute the reconstruction loss
        reconstruction_loss = tf.reduce_mean(
            binary_crossentropy(x, x_pred)
        ) * input_shape[0] * input_shape[1]

        # Compute the KL divergence loss (for each sample in the batch)
        kl_loss = -0.5 * tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)

        # Total loss
        total_loss = reconstruction_loss + tf.reduce_mean(kl_loss) + tf.reduce_sum(vae.losses)

    # Compute gradients and update the model
    gradients = tape.gradient(total_loss, vae.trainable_variables)
    optimizer.apply_gradients(zip(gradients, vae.trainable_variables))

    return total_loss
    
vae.summary()

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda, Reshape, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras import backend as K

# Define the input shape
input_shape = (104, 512)
latent_dim = 64  # Adjust this based on your requirements

# Encoder network
inputs = Input(shape=input_shape)
x = Dense(256, activation='relu')(inputs)
x = Dense(128, activation='relu')(x)
z_mean = Dense(latent_dim)(x)
z_log_var = Dense(latent_dim)(x)

# Reparameterization trick for the sampling layer
# Update the sampling function
def sampling(args):
    z_mean, z_log_var = args
    batch_size = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch_size, dim, latent_dim))  # <-- Generate noise with the right shape
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

z = Lambda(sampling)([z_mean, z_log_var])

# Decoder network
decoder_input = Input(shape=(MAX_ROWS, latent_dim,))
x = Dense(128, activation='relu')(decoder_input)
x = Dense(256, activation='relu')(x)
outputs = Dense(input_shape[1], activation='sigmoid')(x)

# Define the encoder and decoder models
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
decoder = Model(decoder_input, outputs, name='decoder')

# Custom VAE layer for loss calculation
class VaeLossLayer(Layer):
    def __init__(self, **kwargs):
        super(VaeLossLayer, self).__init__(**kwargs)

    def vae_loss(self, y_true, y_pred):
        reconstruction_loss = K.mean(binary_crossentropy(y_true, y_pred)) * input_shape[0] * input_shape[1]
        kl_loss = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return reconstruction_loss + kl_loss

    def call(self, inputs):
        y_true, y_pred = inputs
        loss = self.vae_loss(y_true, y_pred)
        self.add_loss(loss, inputs=inputs)
        return y_true  # We return y_true to maintain the original output

# Combine the encoder and decoder to form the VAE model with the custom loss layer
vae_outputs = decoder(encoder(inputs)[2])
vae_loss_layer = VaeLossLayer(name='vae_loss_layer')([inputs, vae_outputs])
vae = Model(inputs, vae_outputs)

# Manually compute the loss and gradients for custom training loop
optimizer = tf.keras.optimizers.legacy.Adam()
def compute_loss(y_true, y_pred):
    reconstruction_loss = K.mean(binary_crossentropy(y_true, y_pred)) * input_shape[0] * input_shape[1]
    kl_loss = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return reconstruction_loss + kl_loss

@tf.function
def train_step(x):
    with tf.GradientTape(persistent=True) as tape:
        z_mean, z_log_var, z = encoder(x, training=True)  # Use the encoder model directly
        x_pred = decoder(z, training=True)  # Use the decoder model directly

        # Compute the reconstruction loss
        reconstruction_loss = tf.reduce_mean(
            binary_crossentropy(x, x_pred)
        ) * input_shape[0] * input_shape[1]

        # Compute the KL divergence loss (for each sample in the batch)
        kl_loss = -0.5 * tf.reduce_sum(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)

        # Total loss
        total_loss = reconstruction_loss + tf.reduce_mean(kl_loss) + tf.reduce_sum(vae.losses)

    # Compute gradients and update the model
    gradients = tape.gradient(total_loss, vae.trainable_variables)
    optimizer.apply_gradients(zip(gradients, vae.trainable_variables))

    return total_loss
    
vae.summary()

In [None]:
# Training loop
epochs = 30
batch_size = 32
num_batches = len(X_train) // batch_size

for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size
        x_batch = X_train[start_idx:end_idx]
        loss = train_step(x_batch)
        print(f"Batch {batch_idx+1}/{num_batches}, Loss: {loss.numpy():.4f}")

In [None]:
# Generate new samples using the VAE
random_samples = tf.random.normal(shape=(1, 104, latent_dim))
generated_images = decoder(random_samples)

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras import backend as K

# Define the input shape
input_shape = (104, 512)
latent_dim = 64  # Adjust this based on your requirements

# Encoder network
inputs = Input(shape=input_shape)
x = Dense(256, activation='relu')(inputs)
x = Dense(128, activation='relu')(x)
z_mean = Dense(latent_dim)(x)
z_log_var = Dense(latent_dim)(x)

# Reparameterization trick for the sampling layer
def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0.0, stddev=1.0)
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

z = Lambda(sampling)([z_mean, z_log_var])

# Decoder network
decoder_input = Input(shape=(MAX_ROWS, latent_dim,))
x = Dense(128, activation='relu')(decoder_input)
x = Dense(256, activation='relu')(x)
outputs = Dense(input_shape[1], activation='sigmoid')(x)
# outputs = Reshape(input_shape)(outputs)

# Define the encoder and decoder models
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
decoder = Model(decoder_input, outputs, name='decoder')

# Combine the encoder and decoder to form the VAE model
vae_outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, vae_outputs, name='vae')

# Define the VAE loss function
def vae_loss(y_true, y_pred):
    reconstruction_loss = K.mean(binary_crossentropy(y_true, y_pred)) * input_shape[0] * input_shape[1]
    kl_loss = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return reconstruction_loss + kl_loss

# Compile the VAE model with the loss function directly
vae.compile(optimizer='adam', loss=vae_loss)

# Print the model summary
vae.summary()


In [None]:
# Train the VAE
vae.fit(X_train, X_train, epochs=30, batch_size=32)

# Generate new samples using the VAE
random_samples = tf.random.normal(shape=(10, latent_dim))
generated_images = vae.decoder(random_samples)

# Optionally, you can save and load the trained model
# vae.save('vae_model')
# vae = tf.keras.models.load_model('vae_model')

In [None]:
class VariationalAutoencoder(tf.keras.Model):
    def __init__(self, latent_dim):
        super(VariationalAutoencoder, self).__init__()
        self.latent_dim = latent_dim

        # Encoder network
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(MAX_ROWS, MAX_COLUMNS, 1)),
            tf.keras.layers.Conv2D(filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(latent_dim + latent_dim),
        ])

        # Decoder network
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(units=(MAX_ROWS // 8) * (MAX_COLUMNS // 8) * 32, activation='relu'),  # Adjust units based on desired intermediate shape
            tf.keras.layers.Reshape(target_shape=(MAX_ROWS // 8, MAX_COLUMNS // 8, 32)),  # Adjust target_shape based on desired intermediate shape
            tf.keras.layers.Conv2DTranspose(filters=64, kernel_size=3, strides=(2, 2), padding='SAME', activation='relu'),
            tf.keras.layers.Conv2DTranspose(filters=32, kernel_size=3, strides=(2, 2), padding='SAME', activation='relu'),
            tf.keras.layers.Conv2DTranspose(filters=1, kernel_size=3, strides=(2, 2), padding='SAME', activation='sigmoid'),
        ])

    def encode(self, x):
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=tf.shape(mean))
        return eps * tf.exp(logvar * 0.5) + mean

    def decode(self, z):
        return self.decoder(z)

    def call(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        reconstructed = self.decode(z)
        
        # Crop the reconstructed output to match the input size (106, 512)
        target_height, target_width = MAX_ROWS, MAX_COLUMNS
        reconstructed = tf.image.crop_to_bounding_box(reconstructed, 
                                                      0, 0, target_height, target_width)
        
        return reconstructed, mean, logvar

# Define the loss function for VAE
def vae_loss(x, reconstructed):
    # Reconstruction loss (binary cross-entropy)
    reconstruction_loss = tf.reduce_mean(tf.keras.losses.binary_crossentropy(tf.keras.backend.flatten(x),
                                                                              tf.keras.backend.flatten(reconstructed)))

    # KL divergence loss
    kl_loss = -0.5 * tf.reduce_sum(1 + reconstructed[2] - tf.square(reconstructed[1]) - tf.exp(reconstructed[2]), axis=-1)
    kl_loss = tf.reduce_mean(kl_loss)

    return reconstruction_loss + kl_loss

In [None]:
# Set hyperparameters and create the VAE model
latent_dim = 32
vae = VariationalAutoencoder(latent_dim)
vae.compile(optimizer=tf.keras.optimizers.legacy.Adam(learning_rate=0.001), loss=vae_loss)
vae.build( (None, MAX_ROWS, MAX_COLUMNS) )

# vae.summary()

# vae.encoder.summary()
vae.decoder.summary()

In [None]:
# Compile the model

# Train the VAE
vae.fit(X_train, X_train, epochs=30, batch_size=32)

# Generate new samples using the VAE
random_samples = tf.random.normal(shape=(10, latent_dim))
generated_images = vae.decoder(random_samples)

# Optionally, you can save and load the trained model
# vae.save('vae_model')
# vae = tf.keras.models.load_model('vae_model')