In [24]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import random


In [25]:
# Token inventory: special markers, parentheses + operators, single-letter identifiers,
# and a catch-all 'JUNK' token.
OPERATORS = ['+', '-', '*', '/']
IDENTIFIERS = list('abcde')
SPECIAL_TOKENS = ['PAD', 'SOS', 'EOS']
SYMBOLS = ['(', ')'] + OPERATORS
VOCAB = [*SPECIAL_TOKENS, *SYMBOLS, *IDENTIFIERS, 'JUNK']

# Bidirectional lookup tables; ids are positions in VOCAB (VOCAB has no duplicates).
id_to_token = dict(enumerate(VOCAB))
token_to_id = {tok: idx for idx, tok in id_to_token.items()}

VOCAB_SIZE = len(VOCAB)
PAD_ID = token_to_id['PAD']  # 0, so np.zeros-padded arrays are PAD-filled
SOS_ID = token_to_id['SOS']
EOS_ID = token_to_id['EOS']

def generate_infix_expression(max_depth):
    if max_depth == 0:
        return random.choice(IDENTIFIERS)
    elif random.random() < 0.5:
        return generate_infix_expression(max_depth - 1)
    else:
        left = generate_infix_expression(max_depth - 1)
        right = generate_infix_expression(max_depth - 1)
        op = random.choice(OPERATORS)
        return f'({left} {op} {right})'

def tokenize(expr):
    """Split an expression string into single-character vocabulary tokens.

    Every character that is not a key of ``token_to_id`` (e.g. spaces) is dropped.
    """
    return list(filter(lambda ch: ch in token_to_id, expr))



def infix_to_postfix(tokens):
    """Convert infix tokens to postfix (reverse Polish) with the shunting-yard algorithm.

    Identifiers (members of IDENTIFIERS) go straight to the output. Operators
    (members of OPERATORS) use precedence ``* /`` over ``+ -``, and all of them are
    left-associative. Parentheses group and are not emitted. Any other token is
    silently ignored.

    Args:
        tokens: Iterable of single-token strings.

    Returns:
        List of tokens in postfix order.

    Raises:
        ValueError: If the parentheses are unbalanced. Previously a stray ')'
            crashed with an opaque IndexError, and a stray '(' leaked into the
            output.
    """
    precedence = {'+': 1, '-': 1, '*': 2, '/': 2}
    output, stack = [], []
    for token in tokens:
        if token in IDENTIFIERS:
            output.append(token)
        elif token in OPERATORS:
            # '>=' pops equal-precedence operators first, making them
            # left-associative: a - b - c -> a b - c -
            while stack and stack[-1] in OPERATORS and precedence[stack[-1]] >= precedence[token]:
                output.append(stack.pop())
            stack.append(token)
        elif token == '(':
            stack.append(token)
        elif token == ')':
            while stack and stack[-1] != '(':
                output.append(stack.pop())
            if not stack:
                raise ValueError("unbalanced ')' in infix expression")
            stack.pop()  # discard the matching '('
    while stack:
        top = stack.pop()
        if top == '(':
            raise ValueError("unbalanced '(' in infix expression")
        output.append(top)
    return output

MAX_DEPTH = 3
# Length of every encoded sequence. A full binary expression tree of depth d has
# 2**d identifiers, 2**d - 1 operators and 2 * (2**d - 1) parentheses, i.e.
# 4 * 2**d - 3 infix tokens; +1 for EOS gives 4 * 2**d - 2 == 2**(d + 2) - 2
# (30 for d = 3). This is the exact worst case for the *infix* side; postfix
# drops the parentheses and is always shorter.
MAX_LEN = 2 ** (MAX_DEPTH + 2) - 2

def encode(tokens, max_len=MAX_LEN):
    """Map tokens to ids, append EOS, and right-pad with PAD to exactly ``max_len``.

    Args:
        tokens: Sequence of vocabulary tokens (each must be a key of token_to_id).
        max_len: Total length of the result, EOS included.

    Returns:
        List of ``max_len`` integer ids.

    Raises:
        KeyError: If a token is not in the vocabulary.
        ValueError: If the tokens plus EOS do not fit in ``max_len``. Previously
            this silently returned an over-long list (``[PAD] * negative`` is
            empty), which produces ragged rows downstream.
    """
    ids = [token_to_id[t] for t in tokens] + [EOS_ID]
    if len(ids) > max_len:
        raise ValueError(
            f"sequence of {len(ids)} ids (including EOS) exceeds max_len={max_len}"
        )
    return ids + [PAD_ID] * (max_len - len(ids))


def decode_sequence(token_ids, id_to_token, pad_token='PAD', eos_token='EOS'):
    """Render a sequence of token ids as a space-separated string.

    Each id is looked up in ``id_to_token``; unknown ids become ``'?'``.
    Rendering stops at the first ``eos_token`` (which is not included), and
    every ``pad_token`` is skipped.
    """
    kept = []
    for tok in (id_to_token.get(tid, '?') for tid in token_ids):
        if tok == eos_token:
            break
        if tok == pad_token:
            continue
        kept.append(tok)
    return ' '.join(kept)



def generate_dataset(n, max_depth=MAX_DEPTH):
    """Generate ``n`` random (infix, postfix) pairs as padded id matrices.

    Each pair comes from one ``generate_infix_expression(max_depth)`` call. It is
    tokenized, converted with ``infix_to_postfix``, and both sides are passed
    through ``encode``.

    Returns:
        ``(X, Y)``: NumPy arrays built from the encoded infix and postfix rows.
    """
    sources, targets = [], []
    for _ in range(n):
        infix_tokens = tokenize(generate_infix_expression(max_depth))
        sources.append(encode(infix_tokens))
        targets.append(encode(infix_to_postfix(infix_tokens)))
    return np.array(sources), np.array(targets)


def shift_right(seqs):
    """Build teacher-forcing decoder inputs by shifting along axis 1.

    Each row ``[t0, t1, ..., tn]`` becomes ``[SOS_ID, t0, ..., t(n-1)]``: SOS is
    prepended and the last column dropped. The result has the same shape and
    dtype as ``seqs``.
    """
    result = np.empty_like(seqs)
    # Every element is overwritten below, so the uninitialised buffer is safe.
    result[:, 0] = SOS_ID
    result[:, 1:] = seqs[:, :-1]
    return result


# Create training and validation data.
# Seed Python's `random`, NumPy and TensorFlow in one call, so that expression
# generation, the sanity-check index below, and later weight initialisation are
# reproducible on Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

N_TRAIN = 10000
N_VAL = 10000

X_train, Y_train = generate_dataset(N_TRAIN)
decoder_input_train = shift_right(Y_train)

# NOTE(review): the validation set comes from the same generator as the training
# set. With only five identifiers, many short expressions (e.g. a lone 'e') appear
# in both, so validation metrics measure in-distribution fit, not generalisation.
X_val, Y_val = generate_dataset(N_VAL)
decoder_input_val = shift_right(Y_val)

# Sanity check: render one random training pair back to text.
i = np.random.randint(len(X_train))
print("Example", i)
print("Infix  :", decode_sequence(X_train[i], id_to_token))
print("Postfix:", decode_sequence(Y_train[i], id_to_token))
print("Shifted:", decode_sequence(decoder_input_train[i], id_to_token))

Example 2270
Infix  : ( d + a )
Postfix: d a +
Shifted: SOS d a +


In [26]:
# Model hyper-parameters.
EMBED_DIM = 64   # embedding width, which is also the width of every residual stream
NUM_HEADS = 4    # heads per MultiHeadAttention layer
FF_DIM = 64      # hidden width of each feed-forward sub-layer
NUM_LAYERS = 2   # number of encoder blocks, and separately of decoder blocks

class PositionalEmbedding(layers.Layer):
    """Token embedding plus a position-embedding table created with ``add_weight``.

    Args:
        vocab_size: Number of distinct token ids.
        embed_dim: Embedding width.
        max_len: Number of rows in the position table. Inputs may be shorter;
            the table is sliced to the actual sequence length.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``, ``dtype``).
            Previously these were rejected.
    """

    def __init__(self, vocab_size, embed_dim, max_len, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can reconstruct the layer.
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = self.add_weight(name="pos_emb", shape=[1, max_len, embed_dim])

    def call(self, x):
        # Leading axis of size 1 broadcasts the position table across the batch.
        return self.token_emb(x) + self.pos_emb[:, :tf.shape(x)[1], :]

    def get_config(self):
        """Return the constructor arguments, so the layer survives model save/load."""
        config = super().get_config()
        config.update(
            vocab_size=self.vocab_size,
            embed_dim=self.embed_dim,
            max_len=self.max_len,
        )
        return config


def transformer_encoder(embed, num_heads, ff_dim, dropout=0.1):
    """Apply one pre-norm transformer encoder block (Keras functional API).

    LayerNorm -> multi-head self-attention -> residual add, then
    LayerNorm -> Dense(ff_dim, relu) -> Dense(width) -> residual add.

    Args:
        embed: Symbolic input tensor; its last axis is the model width
            (assumes (batch, seq, width)).
        num_heads: Attention heads. Each head uses key_dim equal to the full
            model width.
        ff_dim: Hidden width of the feed-forward sub-layer.
        dropout: Rate applied to each sub-layer's output before its residual
            add. Previously this argument was accepted but ignored; pass 0.0
            to reproduce that behaviour.

    Returns:
        Tensor with the same shape as ``embed``.
    """
    width = embed.shape[-1]

    # Self-attention sub-layer.
    x = layers.LayerNormalization(epsilon=1e-6)(embed)
    attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=width)(x, x)
    attn = layers.Dropout(dropout)(attn)
    x = layers.Add()([attn, embed])

    # Position-wise feed-forward sub-layer.
    x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
    ff = layers.Dense(ff_dim, activation='relu')(x_norm)
    ff = layers.Dense(width)(ff)
    ff = layers.Dropout(dropout)(ff)
    return layers.Add()([ff, x])


def transformer_decoder(embed, enc_output, num_heads, ff_dim, dropout=0.1):
    """Apply one pre-norm transformer decoder block (Keras functional API).

    Three residual sub-layers, each preceded by LayerNorm:
    causal self-attention, cross-attention over ``enc_output``, and a
    Dense(ff_dim, relu) -> Dense(width) feed-forward.

    Args:
        embed: Symbolic decoder-side tensor; its last axis is the model width
            (assumes (batch, seq, width)).
        enc_output: Encoder output used as key/value for cross-attention.
        num_heads: Attention heads. Each head uses key_dim equal to the full
            model width.
        ff_dim: Hidden width of the feed-forward sub-layer.
        dropout: Rate applied to each sub-layer's output before its residual
            add. Previously this argument was ignored; pass 0.0 to reproduce
            that behaviour.

    Returns:
        Tensor with the same shape as ``embed``.
    """
    width = embed.shape[-1]

    # Masked self-attention. use_causal_mask builds the lower-triangular mask for
    # the actual sequence length; the old fixed MAX_LEN x MAX_LEN mask only
    # matched inputs of exactly MAX_LEN positions.
    x = layers.LayerNormalization(epsilon=1e-6)(embed)
    attn1 = layers.MultiHeadAttention(num_heads=num_heads, key_dim=width)(x, x, use_causal_mask=True)
    attn1 = layers.Dropout(dropout)(attn1)
    x = layers.Add()([attn1, embed])

    # Encoder-decoder (cross) attention.
    x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
    attn2 = layers.MultiHeadAttention(num_heads=num_heads, key_dim=width)(x_norm, enc_output)
    attn2 = layers.Dropout(dropout)(attn2)
    x = layers.Add()([attn2, x])

    # Position-wise feed-forward.
    x_norm = layers.LayerNormalization(epsilon=1e-6)(x)
    ff = layers.Dense(ff_dim, activation='relu')(x_norm)
    ff = layers.Dense(width)(ff)
    ff = layers.Dropout(dropout)(ff)
    return layers.Add()([ff, x])


from tensorflow.keras import Input, Model

def build_transformer_model():
    """Assemble and compile the infix -> postfix encoder-decoder transformer.

    Two integer inputs of length MAX_LEN (source ids and right-shifted target
    ids) get separate PositionalEmbedding layers. The source goes through
    NUM_LAYERS encoder blocks; the target goes through NUM_LAYERS decoder blocks,
    each attending to the final encoder output. A softmax over the vocabulary is
    produced at every position.

    NOTE(review): PAD positions are not masked, so loss and accuracy include them.

    Returns:
        A Keras Model compiled with Adam and sparse categorical crossentropy.
    """
    src_ids = Input(shape=(MAX_LEN,), name="encoder_input")
    tgt_ids = Input(shape=(MAX_LEN,), name="decoder_input")

    # Both embeddings are created before any block, which keeps layer creation
    # order stable.
    src_hidden = PositionalEmbedding(VOCAB_SIZE, EMBED_DIM, MAX_LEN)(src_ids)
    tgt_hidden = PositionalEmbedding(VOCAB_SIZE, EMBED_DIM, MAX_LEN)(tgt_ids)

    for _ in range(NUM_LAYERS):
        src_hidden = transformer_encoder(src_hidden, NUM_HEADS, FF_DIM)

    for _ in range(NUM_LAYERS):
        tgt_hidden = transformer_decoder(tgt_hidden, src_hidden, NUM_HEADS, FF_DIM)

    token_probs = layers.Dense(VOCAB_SIZE, activation="softmax")(tgt_hidden)

    seq2seq = Model(inputs=[src_ids, tgt_ids], outputs=token_probs)
    seq2seq.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return seq2seq

model = build_transformer_model()
model.summary()

# Add a trailing unit axis to the integer targets: (N, MAX_LEN) -> (N, MAX_LEN, 1).
# NOTE(review): Keras' sparse categorical crossentropy should also accept the
# 2-D labels directly; confirm before removing this step.
Y_train_expanded = Y_train[..., np.newaxis]
Y_val_expanded = Y_val[..., np.newaxis]

In [27]:
# Teacher forcing: the decoder receives the right-shifted targets and learns to
# predict the unshifted targets at every position.
train_inputs = [X_train, decoder_input_train]
val_inputs = [X_val, decoder_input_val]

history = model.fit(
    train_inputs,
    Y_train_expanded,
    batch_size=64,
    epochs=10,
    validation_data=(val_inputs, Y_val_expanded),
)

Epoch 1/10


W0000 00:00:1749402207.762562      78 assert_op.cc:38] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert


[1m156/157[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 9ms/step - accuracy: 0.7936 - loss: 0.6833

W0000 00:00:1749402215.657316      77 assert_op.cc:38] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert


[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step - accuracy: 0.7940 - loss: 0.6816

W0000 00:00:1749402221.918983      77 assert_op.cc:38] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert
W0000 00:00:1749402223.653738      79 assert_op.cc:38] Ignoring Assert operator compile_loss/sparse_categorical_crossentropy/SparseSoftmaxCrossEntropyWithLogits/assert_equal_1/Assert/Assert


[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 71ms/step - accuracy: 0.7943 - loss: 0.6800 - val_accuracy: 0.9030 - val_loss: 0.2450
Epoch 2/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 13ms/step - accuracy: 0.9204 - loss: 0.2036 - val_accuracy: 0.9735 - val_loss: 0.0744
Epoch 3/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 13ms/step - accuracy: 0.9804 - loss: 0.0561 - val_accuracy: 0.9921 - val_loss: 0.0247
Epoch 4/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 13ms/step - accuracy: 0.9949 - loss: 0.0157 - val_accuracy: 0.9940 - val_loss: 0.0184
Epoch 5/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 14ms/step - accuracy: 0.9976 - loss: 0.0077 - val_accuracy: 0.9992 - val_loss: 0.0031
Epoch 6/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 13ms/step - accuracy: 0.9979 - loss: 0.0069 - val_accuracy: 0.9996 - val_loss: 0.0013
Epoch 7/10
[1m157/157[0m [32m

In [28]:
def autoregressive_decode_transformer(model, input_seq,
                                      sos_id=1, eos_id=2, max_len=30,
                                      pad_id=0, decoder_len=None):
    """Greedily (argmax) decode one source sequence with an encoder-decoder model.

    Starting from ``[sos_id]``, each step feeds the tokens generated so far
    (right-padded with ``pad_id``) together with the encoder input to
    ``model.predict``. It takes the argmax at the last filled position and
    appends it, until ``eos_id`` is produced or the step budget runs out.

    Args:
        model: Object with a Keras-style ``predict([encoder_batch, decoder_batch],
            verbose=0)`` returning scores of shape (1, decoder width, vocab).
        input_seq: Encoder token ids (array or list); reshaped to a batch of one.
        sos_id: Start-of-sequence id fed at position 0.
        eos_id: End-of-sequence id; decoding stops when it is predicted.
        max_len: Maximum number of decoding steps. It is also the decoder input
            width when ``decoder_len`` is not given (the original behaviour).
        pad_id: Id used to right-pad the decoder input.
        decoder_len: Width of the decoder input the model expects. Set this to
            use a step budget different from the model's fixed input length.
            Decoding also stops once the decoder input is full.

    Returns:
        List of generated token ids (Python ints), without SOS and EOS.
    """
    encoder_batch = np.asarray(input_seq).reshape(1, -1)  # (1, encoder_len)
    width = max_len if decoder_len is None else decoder_len
    output_seq = [sos_id]

    for _ in range(max_len):
        step = len(output_seq)
        if step > width:
            # The decoder input is full: no position left to predict from.
            break

        decoder_input = np.full((1, width), pad_id, dtype=int)
        decoder_input[0, :step] = output_seq

        predictions = model.predict([encoder_batch, decoder_input], verbose=0)
        # The output at the last filled position scores the next token.
        next_token = int(np.argmax(predictions[0, step - 1]))

        if next_token == eos_id:
            break
        output_seq.append(next_token)

    return output_seq[1:]  # drop SOS


In [29]:
def prefix_accuracy_single(y_true, y_pred, id_to_token, eos_id=EOS_ID, verbose=False):
    """Score one predicted id sequence against its target.

    Both sequences are rendered with ``decode_sequence`` (which stops at EOS and
    skips PAD) and compared position by position. Despite the name this is not
    a longest-common-prefix score: every aligned position with equal tokens
    counts, even after an earlier mismatch. The match count is divided by the
    longer of the two lengths, so missing or extra tokens lower the score.

    Args:
        y_true: Target token ids.
        y_pred: Predicted token ids.
        id_to_token: Mapping from id to token string.
        eos_id: Unused; kept so existing calls keep working.
        verbose: If True, print the target, the prediction and the score.

    Returns:
        Float in [0, 1]. Two sequences that both render as empty are identical
        and now score 1.0 (previously 0).
    """
    # decode_sequence already stops at EOS; the split is a defensive no-op.
    # str.split() with no argument also strips surrounding whitespace.
    t_tokens = decode_sequence(y_true, id_to_token).split(' EOS')[0].split()
    p_tokens = decode_sequence(y_pred, id_to_token).split(' EOS')[0].split()

    max_len = max(len(t_tokens), len(p_tokens))
    match_len = sum(x == y for x, y in zip(t_tokens, p_tokens))

    # Empty vs. empty is a perfect match, not a zero score.
    score = match_len / max_len if max_len > 0 else 1.0

    if verbose:
        print("TARGET  :", ' '.join(t_tokens))
        print("PREDICT :", ' '.join(p_tokens))
        print(f"MATCH   : {match_len}/{max_len} → {score:.2f}")
    return score


def test(n=20, rounds=5, decoder_model=None):
    """Evaluate greedy decoding on freshly generated expressions over several rounds.

    Each round generates ``n`` new (infix, postfix) pairs with
    ``generate_dataset``, decodes every infix row with
    ``autoregressive_decode_transformer``, scores it against the true postfix
    with ``prefix_accuracy_single``, and prints the round average.

    Args:
        n: Expressions per round.
        rounds: Number of independent rounds.
        decoder_model: Model to evaluate. Defaults to the notebook-level
            ``model`` (the previous hard-wired behaviour). Passing it explicitly
            avoids depending on hidden kernel state.

    Returns:
        Tuple ``(mean, std)`` of the per-round averages.
    """
    if decoder_model is None:
        decoder_model = model  # notebook-global fallback, kept for compatibility
    results = []
    for r in range(rounds):
        print(f"Round {r+1}")
        X_test, Y_test = generate_dataset(n)
        scores = []
        for i in range(n):
            y_pred = autoregressive_decode_transformer(decoder_model, X_test[i])
            scores.append(prefix_accuracy_single(Y_test[i], y_pred, id_to_token))
        avg = np.mean(scores)
        print(f"  Average prefix accuracy: {avg:.3f}")
        results.append(avg)
    return np.mean(results), np.std(results)


In [30]:
# 10 rounds x 20 fresh expressions; report mean and spread of the round averages.
round_summary = test(rounds=10, n=20)
mean_score, std_dev = round_summary
print(f"\nFinal Prefix Accuracy: {mean_score:.3f} ± {std_dev:.3f}")

Round 1
  Average prefix accuracy: 1.000
Round 2
  Average prefix accuracy: 1.000
Round 3
  Average prefix accuracy: 1.000
Round 4
  Average prefix accuracy: 1.000
Round 5
  Average prefix accuracy: 1.000
Round 6
  Average prefix accuracy: 1.000
Round 7
  Average prefix accuracy: 1.000
Round 8
  Average prefix accuracy: 1.000
Round 9
  Average prefix accuracy: 1.000
Round 10
  Average prefix accuracy: 1.000

Final Prefix Accuracy: 1.000 ± 0.000


In [31]:
# NOTE(review): drawn from the same generator and depth as the training data, so
# identical expressions can also occur in the training set (in-distribution test).
X_test_new, Y_test_new = generate_dataset(1000, max_depth=MAX_DEPTH)

def evaluate_on_dataset(X_data, Y_data, sample_count=20, verbose=False, decoder_model=None):
    """Greedy-decode the first ``sample_count`` rows of a dataset and score them.

    Args:
        X_data: Encoded infix rows.
        Y_data: Encoded postfix rows aligned with ``X_data``.
        sample_count: Number of leading rows to evaluate. It is clamped to the
            dataset size; previously a larger value raised IndexError.
        verbose: Forwarded to ``prefix_accuracy_single``.
        decoder_model: Model to evaluate. Defaults to the notebook-level
            ``model`` (the previous hard-wired behaviour).

    Returns:
        Tuple ``(mean, std)`` of the per-example scores.
    """
    if decoder_model is None:
        decoder_model = model  # notebook-global fallback, kept for compatibility
    count = min(sample_count, len(X_data), len(Y_data))
    scores = []
    for i in range(count):
        y_pred = autoregressive_decode_transformer(decoder_model, X_data[i])
        scores.append(prefix_accuracy_single(Y_data[i], y_pred, id_to_token, verbose=verbose))
    return np.mean(scores), np.std(scores)

mean_new, std_new = evaluate_on_dataset(X_test_new, Y_test_new, sample_count=100)
print(f"\nNew Test Set Prefix Accuracy: {mean_new:.3f} ± {std_new:.3f}")

# Show five random test examples next to the model's greedy decoding.
num_test = len(X_test_new)
for _ in range(5):
    i = np.random.randint(num_test)
    source_ids, target_ids = X_test_new[i], Y_test_new[i]
    print(f"\nExample {i}")
    print("Infix       :", decode_sequence(source_ids, id_to_token))
    print("True Postfix:", decode_sequence(target_ids, id_to_token))
    predicted_ids = autoregressive_decode_transformer(model, source_ids)
    print("Predicted   :", decode_sequence(predicted_ids, id_to_token))
    print('-' * 60)



New Test Set Prefix Accuracy: 1.000 ± 0.000

Example 687
Infix       : ( d * ( ( a - a ) * ( d + c ) ) )
True Postfix: d a a - d c + * *
Predicted   : d a a - d c + * *
------------------------------------------------------------

Example 631
Infix       : ( ( e - ( e - b ) ) / ( a * ( e + e ) ) )
True Postfix: e e b - - a e e + * /
Predicted   : e e b - - a e e + * /
------------------------------------------------------------

Example 706
Infix       : ( b + ( ( b * d ) + ( b * a ) ) )
True Postfix: b b d * b a * + +
Predicted   : b b d * b a * + +
------------------------------------------------------------

Example 398
Infix       : e
True Postfix: e
Predicted   : e
------------------------------------------------------------

Example 498
Infix       : ( c - ( d + a ) )
True Postfix: c d a + -
Predicted   : c d a + -
------------------------------------------------------------
