<a href="https://colab.research.google.com/github/osjayaprakash/deeplearning/blob/main/transformer_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import kagglehub

# Download the im2latex100k dataset through kagglehub. The returned value is
# kept as `root_dir` and used by later cells as the base path for the CSV and
# the formula images.
root_dir = kagglehub.dataset_download("shahrukhkhan/im2latex100k")
# Alternative dataset, currently disabled:
# path = kagglehub.dataset_download("gregoryeritsyan/im2latex-230k")

print("Path to dataset files:", root_dir)

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import (Input, Conv2D, MaxPooling2D, Flatten,
                                     Dense, GRU, Embedding, Bidirectional,
                                     TimeDistributed, Concatenate, RepeatVector, LSTM, MultiHeadAttention, LayerNormalization, Add, Dropout )
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
import matplotlib.pyplot as plt
import platform
import sys
import pandas as pd
import sklearn as sk
import scipy as sp
import einops

# Environment report: print interpreter, library versions and the devices
# TensorFlow can see, so a run can be reproduced/debugged later.
# NOTE(review): the return value of the next call is discarded (it is not the
# last expression of the cell), so it displays nothing; the device list is
# actually shown by the final print below.
tf.config.experimental.list_physical_devices('GPU')
print(f"Python Platform: {platform.platform()}")
print(f"Tensor Flow Version: {tf.__version__}")
#print(f"Keras Version: {tf.keras.__version__}")
print()
print(f"Python {sys.version}")
print(f"Pandas {pd.__version__}")
print(f"Scikit-Learn {sk.__version__}")
print(f"SciPy {sp.__version__}")
# All physical devices (CPU and, when present, GPU).
print(tf.config.list_physical_devices())

# Preprocess

In [None]:
def preprocess_image(image):
    """Resize an image to 50x200 and scale its pixel values by 1/255.

    Args:
        image: Image tensor accepted by `tf.image.resize`.

    Returns:
        The resized image divided by 255.0.
    """
    resized = tf.image.resize(image, (50, 200))
    return resized / 255.0

def load_and_preprocess_images(image_paths):
    """Read, decode (single channel) and preprocess every image path.

    Args:
        image_paths: Iterable of file paths readable by `tf.io.read_file`.

    Returns:
        One tensor stacking all preprocessed images along a new first axis.
    """
    processed = []
    for path in image_paths:
        raw_bytes = tf.io.read_file(path)
        # channels=1 -> decode as grayscale.
        grayscale = tf.io.decode_image(raw_bytes, channels=1)
        processed.append(preprocess_image(grayscale))
    return tf.stack(processed)

def prepare_sequences(latex_texts, max_seq_length):
    """Convert LaTeX texts to padded sequences of tokens.

    Args:
        latex_texts: Iterable of LaTeX strings.
        max_seq_length: Length passed to `pad_sequences` as `maxlen`.

    Returns:
        The result of `pad_sequences(..., padding='post')` over the per-text
        token sequences.
    """
    # NOTE(review): `text_to_sequence` is not defined anywhere in the visible
    # notebook, so calling this function raises NameError. The only call site
    # visible here is commented out; confirm whether this helper is still needed.
    sequences = [text_to_sequence(text) for text in latex_texts]
    return pad_sequences(sequences, maxlen=max_seq_length, padding='post')


In [None]:
# Load the first 5000 rows of the training split and build the parallel lists
# of image paths and <START>/<END>-wrapped formulas.
df = pd.read_csv(f"{root_dir}/im2latex_train.csv", nrows=5000)

# Images live in a doubly nested directory inside the dataset download.
image_dir = f"{root_dir}/formula_images_processed/formula_images_processed"

# Iterate the columns directly instead of DataFrame.iterrows() with repeated
# list concatenation; the resulting lists have the same order and content.
train_image_paths = [f"{image_dir}/{image_name}" for image_name in df["image"]]
train_latex_texts = ["<START> " + formula + " <END>" for formula in df["formula"]]

train_images = load_and_preprocess_images(train_image_paths)

In [None]:
vocabulary_size = 5000
# The formulas are already whitespace-separated LaTeX tokens. The default
# standardization ("lower_and_strip_punctuation") would delete characters such
# as \ { } ^ _ < > and lowercase everything (merging e.g. \Delta and \delta,
# and turning "<START>" into "start"), so it must be disabled for LaTeX.
tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=vocabulary_size,
    standardize=None,
    split="whitespace")

In [None]:
# Build the tokenizer vocabulary from the <START>/<END>-wrapped training formulas.
tokenizer.adapt(train_latex_texts)

In [None]:
# Lookup layers sharing the tokenizer's vocabulary: token string -> id and,
# with invert=True, id -> token string. The empty string is the mask token.
word_to_index = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary())
index_to_word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True)

In [None]:
# Tokenize every training formula into ids.
latex_labels = tokenizer(train_latex_texts)
# NumPy copy of the ids, stored as float32; this array is what the fit cells
# slice into decoder inputs ([:, :-1]) and targets ([:, 1:]).
train_sequences = np.asarray(latex_labels).astype(np.float32)

In [None]:
# Sanity check: the first dimension of both tensors should be the number of samples.
print(latex_labels.shape)
print(train_images.shape)

## Transformer Model

In [None]:
# --- Optimisation settings -------------------------------------------------
# NOTE(review): the fit/compile cells visible in this notebook use
# optimizer='adam', epochs=20 and batch_size=128 directly, so these four
# values are currently not referenced by them.
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 10  # For real training, use num_epochs=100. 10 is a test value

# --- Image / patch geometry ------------------------------------------------
IMG_SHAPE = (50, 200, 1)  # (height, width, channels), matches preprocess_image
patch_size = 4  # Size of the patches to be extract from the input images
# Derived instead of hard-coded so it stays in sync with IMG_SHAPE/patch_size:
# (50 // 4) * (200 // 4) = 12 * 50 = 600.
num_patches = (IMG_SHAPE[0] // patch_size) * (IMG_SHAPE[1] // patch_size)

# --- Vision transformer encoder --------------------------------------------
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [
    2048,
    1024,
]  # Widths of the dense layers applied after the transformer blocks

# --- Sequence decoder ------------------------------------------------------
EMBEDDING_DIM = 256
lstm_units = 256
# `latex_labels` is a dense, padded (samples, seq_len) tensor, so every row has
# the same length; read it from the shape instead of iterating all rows.
# The -1 accounts for the one-token shift between decoder input and target.
max_seq_len_1 = int(latex_labels.shape[1]) - 1

In [None]:
class Patches(tf.keras.layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: Side length of each patch; passed as `size` to
            `tf.keras.ops.image.extract_patches`.
        **kwargs: Standard Keras layer arguments (`name`, `trainable`,
            `dtype`, ...). Accepting them is required for `from_config`,
            because `get_config()` includes those keys.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return patches shaped (batch, num_patches, patch_size**2 * channels).

        `images` is indexed as a rank-4 tensor (batch, height, width, channels).
        """
        input_shape = tf.keras.ops.shape(images)
        # Named `num_images` so the notebook-level `batch_size` is not shadowed.
        num_images = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        channels = input_shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        patches = tf.keras.ops.image.extract_patches(images, size=self.patch_size)
        # Collapse the 2-D grid of patches into a single sequence axis.
        patches = tf.keras.ops.reshape(
            patches,
            (
                num_images,
                num_patches_h * num_patches_w,
                self.patch_size * self.patch_size * channels,
            ),
        )
        return patches

    def get_config(self):
        """Return the layer config including `patch_size` for serialization."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

In [None]:
class PatchEncoder(tf.keras.layers.Layer):
    """Linearly project patches and add a learned position embedding.

    Args:
        num_patches: Number of patches per image (size of the position table).
        projection_dim: Output width of both the projection and the position
            embedding.
        **kwargs: Standard Keras layer arguments (`name`, `trainable`,
            `dtype`, ...). Accepting them is required for `from_config`,
            because `get_config()` includes those keys.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept so get_config() can emit every constructor argument.
        self.projection_dim = projection_dim
        self.projection = Dense(units=projection_dim)
        self.position_embedding = Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Return `projection(patch)` plus the embedding of positions 0..num_patches-1."""
        # Shape (1, num_patches): broadcast over the batch axis when added.
        positions = tf.keras.ops.expand_dims(
            tf.keras.ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the layer config with both constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config

In [None]:
def mlp(x, hidden_units, dropout_rate):
    """Apply a stack of GELU Dense layers, each followed by Dropout.

    Args:
        x: Input tensor.
        hidden_units: Iterable of layer widths, applied in order.
        dropout_rate: Rate used by the Dropout after every Dense layer.

    Returns:
        The output of the last Dropout (or `x` unchanged if `hidden_units`
        is empty).
    """
    out = x
    for width in hidden_units:
        dense = Dense(width, activation=tf.keras.activations.gelu)
        out = dense(out)
        out = Dropout(dropout_rate)(out)
    return out

In [None]:
def vision_transformer_encoder(input_shape):
    """Build the ViT-style image encoder as a Keras functional model.

    Uses the notebook-level settings `patch_size`, `num_patches`,
    `projection_dim`, `num_heads`, `transformer_layers`, `transformer_units`
    and `mlp_head_units`.

    Args:
        input_shape: Shape of one input image (without the batch axis).

    Returns:
        A Keras model mapping an image batch to the feature vector produced by
        the final MLP head (width `mlp_head_units[-1]`).
    """
    image_in = Input(shape=input_shape)

    # Image -> patch sequence -> projected tokens with position information.
    patch_tokens = Patches(patch_size)(image_in)
    tokens = PatchEncoder(num_patches, projection_dim)(patch_tokens)

    # Pre-norm transformer blocks: attention and MLP, each with a residual.
    for _ in range(transformer_layers):
        normed = LayerNormalization(epsilon=1e-6)(tokens)
        attended = MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(normed, normed)
        after_attention = Add()([attended, tokens])

        feed_forward = LayerNormalization(epsilon=1e-6)(after_attention)
        feed_forward = mlp(feed_forward, hidden_units=transformer_units, dropout_rate=0.1)
        tokens = Add()([feed_forward, after_attention])

    # Flatten all patch tokens into one vector per image, then apply the head.
    flat = LayerNormalization(epsilon=1e-6)(tokens)
    flat = Flatten()(flat)
    flat = Dropout(0.5)(flat)
    head_features = mlp(flat, hidden_units=mlp_head_units, dropout_rate=0.5)

    # No classification layer: the head features are the encoder output.
    return Model(inputs=image_in, outputs=head_features)

In [None]:
# Stand-alone encoder instance for inspection (build_model creates its own encoder).
vit = vision_transformer_encoder(IMG_SHAPE)

In [None]:
# Print the layer-by-layer structure and parameter counts of the encoder.
vit.summary()

In [None]:
class DecoderLayer(tf.keras.layers.Layer):
  """One decoder block: causal self-attention, cross-attention, feed-forward.

  NOTE(review): `CausalSelfAttention`, `CrossAttention` and `FeedForward` are
  not defined in the visible part of this notebook, so instantiating this
  layer raises NameError unless they are defined elsewhere.
  """

  def __init__(self,
               *,
               d_model,
               num_heads,
               dff,
               dropout_rate=0.1):
    # All arguments are keyword-only. `dff` is accepted but not used in this
    # block; FeedForward receives (d_model, dropout_rate) positionally —
    # TODO confirm that matches FeedForward's signature.
    super(DecoderLayer, self).__init__()

    self.causal_self_attention = CausalSelfAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate)

    self.cross_attention = CrossAttention(
        num_heads=num_heads,
        key_dim=d_model,
        dropout=dropout_rate)

    self.ffn = FeedForward(d_model, dropout_rate)

  def call(self, x, context):
    """Run `x` through self-attention, cross-attention over `context`, then the FFN."""
    x = self.causal_self_attention(x=x)
    x = self.cross_attention(x=x, context=context)

    # Cache the last attention scores for plotting later
    self.last_attn_scores = self.cross_attention.last_attn_scores

    x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
    return x

In [None]:
def positional_encoding(length, depth):
  """Build a sinusoidal position table.

  The first half of the last axis holds sines and the second half cosines
  (concatenated, not interleaved).

  Args:
    length: Number of positions.
    depth: Total encoding width; half is used for sin and half for cos.

  Returns:
    A float32 tensor with `length` rows.
  """
  half_depth = depth/2

  pos = np.arange(length)[:, np.newaxis]                       # (length, 1)
  scaled_idx = np.arange(half_depth)[np.newaxis, :]/half_depth  # (1, half_depth)

  inv_freq = 1 / (10000**scaled_idx)   # (1, half_depth)
  angles = pos * inv_freq              # (length, half_depth)

  sin_part = np.sin(angles)
  cos_part = np.cos(angles)
  table = np.concatenate([sin_part, cos_part], axis=-1)

  return tf.cast(table, dtype=tf.float32)

class PositionalEmbedding(tf.keras.layers.Layer):
  """Token embedding (id 0 masked) scaled by sqrt(d_model) plus sinusoidal positions."""

  def __init__(self, vocab_size, d_model):
    super().__init__()
    self.d_model = d_model
    self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
    # Precomputed table for up to 2048 positions.
    self.pos_encoding = positional_encoding(length=2048, depth=d_model)

  def compute_mask(self, *args, **kwargs):
    """Delegate mask computation to the wrapped embedding layer."""
    return self.embedding.compute_mask(*args, **kwargs)

  def call(self, x):
    """Embed token ids `x` and add the position table for the first `seq_len` positions."""
    seq_len = tf.shape(x)[1]
    embedded = self.embedding(x)
    # Sets the relative scale of the embedding and the positional encoding.
    scale = tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    embedded *= scale
    positions = self.pos_encoding[tf.newaxis, :seq_len, :]
    return embedded + positions

In [None]:
class TransformerDecoder(tf.keras.layers.Layer):
  """Stack of `num_layers` DecoderLayer blocks on top of a positional embedding."""

  def __init__(self, num_layers, d_model, num_heads, dff, vocab_size,
               dropout_rate=0.1):
    super().__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                             d_model=d_model)
    self.dropout = tf.keras.layers.Dropout(dropout_rate)

    layer_kwargs = dict(d_model=d_model, num_heads=num_heads,
                        dff=dff, dropout_rate=dropout_rate)
    self.dec_layers = [DecoderLayer(**layer_kwargs) for _ in range(num_layers)]

    # Filled in by call() with the last layer's cross-attention scores.
    self.last_attn_scores = None

  def call(self, x, context):
    """Decode token ids `x` (batch, target_seq_len) against `context`.

    Returns a tensor of shape (batch_size, target_seq_len, d_model).
    """
    hidden = self.pos_embedding(x)  # (batch_size, target_seq_len, d_model)
    hidden = self.dropout(hidden)

    for layer in self.dec_layers:
      hidden = layer(hidden, context)

    self.last_attn_scores = self.dec_layers[-1].last_attn_scores

    return hidden

In [None]:
# Example decoder instance; it is not used by build_model below.
# NOTE(review): constructing it instantiates DecoderLayer, which references
# CausalSelfAttention, CrossAttention and FeedForward — none of which are
# defined in the visible part of this notebook — so this cell is expected to
# raise NameError on a fresh kernel unless they are defined elsewhere.
sample_decoder = TransformerDecoder(num_layers=4,
                         d_model=512,
                         num_heads=8,
                         dff=2048,
                         vocab_size=8000)

In [None]:
def build_cnn_encoder(input_shape):
    """Build a small CNN image encoder.

    Three Conv2D(3x3, relu, same) + MaxPooling2D(2x2) stages with 32, 64 and
    128 filters, followed by Flatten and a relu Dense of width EMBEDDING_DIM.

    Args:
        input_shape: Shape of one input image (without the batch axis).

    Returns:
        A Keras model mapping an image batch to EMBEDDING_DIM features.
    """
    image_in = Input(shape=input_shape)

    features = image_in
    for n_filters in (32, 64, 128):
        features = Conv2D(n_filters, (3, 3), activation='relu', padding='same')(features)
        features = MaxPooling2D((2, 2))(features)

    features = Flatten()(features)
    embedding = Dense(EMBEDDING_DIM, activation='relu')(features)
    return Model(image_in, embedding)

In [None]:
def build_rnn_encoder(decoder_input, encoder_output, target_vocab_size, max_seq_len_1):
    """Build the LSTM sequence decoder (despite the name, this is the decoder).

    Each time step concatenates the (repeated) image features with the
    embedding of the previous token and predicts a distribution over the
    vocabulary.

    Args:
        decoder_input: Keras tensor of token ids, one id per time step.
        encoder_output: Keras tensor of image features with a time axis of the
            same length as `decoder_input`.
        target_vocab_size: Vocabulary size (embedding rows and softmax width).
        max_seq_len_1: Kept for backward compatibility with existing callers;
            no longer needed because the sequence length is taken from
            `decoder_input` itself.

    Returns:
        A Keras model taking `[decoder_input, encoder_output]` and returning
        per-step softmax probabilities over the vocabulary.
    """
    # `input_length` is deprecated for Embedding in Keras 3 (the notebook uses
    # tf.keras.ops, i.e. Keras 3), so it is no longer passed.
    embedding_layer = Embedding(input_dim=target_vocab_size, output_dim=EMBEDDING_DIM)
    embedded_seq = embedding_layer(decoder_input)

    # Feature axis concat: [image features ; token embedding] per time step.
    decoder_lstm_input = Concatenate(axis=-1)([encoder_output, embedded_seq])
    decoder_lstm = LSTM(lstm_units, return_sequences=True)(decoder_lstm_input)
    decoder_output = TimeDistributed(Dense(target_vocab_size, activation="softmax"))(decoder_lstm)

    return Model(inputs=[decoder_input, encoder_output], outputs=decoder_output)

In [None]:
def build_model(input_shape, num_layers, d_model, num_heads, dff, target_vocab_size, max_seq_len_1):
    """Assemble the image-to-LaTeX model: ViT encoder + LSTM decoder.

    `num_layers`, `d_model`, `num_heads` and `dff` are accepted but not used
    by this function. `build_cnn_encoder(input_shape)` is a drop-in
    alternative for the encoder.

    Args:
        input_shape: Shape of one input image (without the batch axis).
        target_vocab_size: Size of the output vocabulary.
        max_seq_len_1: Number of decoder time steps.

    Returns:
        A Keras model with inputs `[image_input, decoder_input]` and the
        decoder's per-step vocabulary probabilities as output.
    """
    image_encoder = vision_transformer_encoder(input_shape)
    image_input = Input(shape=input_shape, name="image_input")

    # One feature vector per image, copied to every decoder time step.
    image_features = image_encoder(image_input)
    repeated_features = RepeatVector(max_seq_len_1)(image_features)

    decoder_input = Input(shape=(max_seq_len_1,), name="decoder_input")
    sequence_decoder = build_rnn_encoder(
        decoder_input, repeated_features, target_vocab_size, max_seq_len_1
    )

    token_probabilities = sequence_decoder([decoder_input, repeated_features])
    return Model(inputs=[image_input, decoder_input], outputs=token_probabilities)

In [None]:
# Build the full model. The positional values 2, 256, 2, 256 map to
# num_layers, d_model, num_heads, dff, which build_model does not use.
model = build_model(IMG_SHAPE, 2, 256, 2, 256, tokenizer.vocabulary_size(), max_seq_len_1)
#transformer_model = Transformer(tokenizer, output_layer=output_layer, units=128, dropout_rate=0.5, num_layers=2, num_heads=2)

In [None]:
# Print the layer-by-layer structure and parameter counts of the full model.
model.summary()

In [None]:
# Targets are integer token ids (not one-hot), hence the sparse cross-entropy.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])


In [None]:
# Shapes of the decoder input (last token dropped), the target (first token
# dropped) and the full sequence array.
print(train_sequences[..., :-1].shape)
print(train_sequences[..., 1:].shape)
print(train_sequences.shape)

In [None]:
# Teacher forcing: the decoder sees tokens [0, n-1) and is trained to predict
# tokens [1, n).
# NOTE(review): an equivalent model.fit call appears again under
# "Model Training"; running both cells trains the same model twice.
with tf.device('/GPU:0'):
    model.fit([train_images, train_sequences[..., :-1]],
              train_sequences[..., 1:],
              epochs=20,
              batch_size=128,
              validation_split=0.2)

from tensorflow.keras.models import load_model


In [None]:
# `transformer_model` is never defined in this notebook (it only appears in a
# commented-out line), so saving it raised NameError. Save the model that was
# actually built and trained above.
model.save('/home/ubuntu/model_av.keras')

# Model Training

In [None]:
# Teacher forcing: the decoder sees tokens [0, n-1) and is trained to predict
# tokens [1, n). 20% of the samples are held out for validation.
with tf.device('/GPU:0'):
    model.fit([train_images, train_sequences[:, :-1]],
              train_sequences[:, 1:],
              epochs=20,
              batch_size=128,
              validation_split=0.2)

from tensorflow.keras.models import load_model
# Persist the trained model in the native .keras format.
# NOTE(review): absolute path; fails if /home/ubuntu does not exist.
model.save('/home/ubuntu/latex_model.keras')

#model = load_model('latex_model.h5')


In [None]:
#dot_img_file =
import keras
# Render the model graph with shapes, dtypes, layer names, nested models and
# activations shown.
keras.utils.plot_model(model,
                       show_shapes=True,
                       show_dtype=True,
                       show_layer_names=True,
                       expand_nested=True,
                       show_layer_activations=True,
                       )

In [None]:
import keras
# Same diagram as above, written to a PNG file. A path relative to the working
# directory is used instead of a user-specific absolute path
# (/Users/jayaprakash/...), which does not exist on Colab or other machines.
keras.utils.plot_model(model,
                       show_shapes=True,
                       show_dtype=True,
                       show_layer_names=True,
                       expand_nested=True,
                       show_layer_activations=True,
                       to_file='latex_model.png'
                       )

# Predict

In [None]:
import numpy as np

def predict_latex_sequence(model, image, tokenizer, max_seq_len=None):
    """
    Greedily decode a LaTeX sequence from a single image.

    Parameters:
    - model: Trained Keras model taking `[image_batch, token_id_batch]` and
      returning per-step vocabulary scores of shape (batch, steps, vocab).
    - image: Input image (preprocessed to match training dimensions), without
      a batch axis.
    - tokenizer: The fitted `TextVectorization` layer (any callable mapping a
      list of strings to token ids that also offers `get_vocabulary()`).
    - max_seq_len: Decoder input length / maximum number of decoding steps.
      Defaults to the notebook-level `max_seq_len_1`.

    Returns:
    - latex_sequence: Predicted tokens joined by single spaces (the start
      marker is not included; decoding stops before the end marker).
    """
    if max_seq_len is None:
        max_seq_len = max_seq_len_1

    image = np.expand_dims(image, axis=0)  # Add batch dimension

    # `TextVectorization` has no `word_index` / `sequences_to_texts` (those
    # belong to the legacy `Tokenizer`). Resolve the marker ids by running the
    # tokenizer itself, so the lookup honours whatever standardization it uses.
    start_token = int(np.asarray(tokenizer(["<START>"])).ravel()[0])
    end_token = int(np.asarray(tokenizer(["<END>"])).ravel()[0])
    vocabulary = tokenizer.get_vocabulary()  # index -> token string

    sequence = [start_token]

    for _ in range(max_seq_len):
        # Zero-pad the tokens decoded so far up to the fixed decoder length.
        padded_sequence = np.zeros((1, max_seq_len), dtype=np.int64)
        padded_sequence[0, :len(sequence)] = sequence

        # The prediction at position len(sequence) - 1 is the next token.
        preds = model.predict([image, padded_sequence], verbose=0)
        next_token = int(np.argmax(preds[0, len(sequence) - 1, :]))

        if next_token == end_token:
            break

        sequence.append(next_token)

    # Skip the start token and map ids back to token strings.
    latex_sequence = " ".join(vocabulary[token] for token in sequence[1:])
    return latex_sequence

# Decode one training image and print it next to its ground-truth formula.
predicted_latex = predict_latex_sequence(model, train_images[12], tokenizer)
print("Predicted LaTeX:", predicted_latex)
#print("Original Seq:", train_sequences[0])
print("Original Seq:", train_latex_texts[12])