In [2]:
import tensorflow as tf
from tensorflow.keras.layers import Embedding, LSTM, Dense
from tensorflow.keras.models import Model
import numpy as np

# Demo hyperparameters.
# The source and target vocabularies happen to have the same size here.
vocab_inp_size = vocab_tar_size = 5000
# Width of each token embedding, and of the LSTM hidden/cell state.
embedding_dim, units = 256, 512

# Encoder
class Encoder(tf.keras.Model):
    """Embed a batch of token IDs and run them through a single LSTM.

    Args:
        vocab_size: Number of distinct token IDs the embedding can look up.
        embedding_dim: Size of each embedding vector.
        enc_units: Number of LSTM units (width of the hidden and cell state).
    """

    def __init__(self, vocab_size, embedding_dim, enc_units):
        super().__init__()
        self.enc_units = enc_units
        self.embedding = Embedding(vocab_size, embedding_dim)
        # return_sequences exposes the per-timestep outputs (what attention
        # looks at); return_state additionally exposes the final h and c.
        self.lstm = LSTM(enc_units, return_sequences=True, return_state=True)

    def call(self, x):
        """Encode ``x``, a batch of integer token IDs.

        Returns:
            A ``(sequence_output, final_h, final_c)`` tuple: the LSTM output
            at every timestep followed by its final hidden and cell state.
        """
        embedded = self.embedding(x)
        sequence_output, final_h, final_c = self.lstm(embedded)
        return sequence_output, final_h, final_c


# Bahdanau Attention
class BahdanauAttention(tf.keras.layers.Layer):
    """Additive attention: ``score = V(tanh(W1(query) + W2(values)))``.

    Args:
        units: Width of the intermediate projection shared by ``W1`` and
            ``W2``.
    """

    def __init__(self, units):
        super().__init__()
        self.W1 = Dense(units)
        self.W2 = Dense(units)
        self.V = Dense(1)

    def call(self, query, values):
        """Attend over ``values`` using ``query``.

        Args:
            query: Tensor without a time axis; one is inserted at position 1
                so it broadcasts against every timestep of ``values``.
            values: Tensor whose axis 1 is the axis being attended over.

        Returns:
            ``(context_vector, attention_weights)``: ``values`` weighted by
            the attention distribution and summed over axis 1, and the
            weights themselves (softmax-normalised over axis 1, with a
            trailing dimension of size 1).
        """
        expanded_query = tf.expand_dims(query, axis=1)

        projected_query = self.W1(expanded_query)
        projected_values = self.W2(values)
        energies = self.V(tf.nn.tanh(projected_query + projected_values))

        # Normalise across the attended axis so the weights sum to one.
        attention_weights = tf.nn.softmax(energies, axis=1)

        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights


# Decoder with attention
class Decoder(tf.keras.Model):
    """One-step LSTM decoder that attends over the encoder outputs.

    Args:
        vocab_size: Size of the target vocabulary; also the width of the
            logits produced by the final dense layer.
        embedding_dim: Size of each target-token embedding vector.
        dec_units: Number of LSTM units, also used as the attention width.
    """

    def __init__(self, vocab_size, embedding_dim, dec_units):
        super().__init__()
        self.dec_units = dec_units
        self.embedding = Embedding(vocab_size, embedding_dim)
        self.lstm = LSTM(self.dec_units, return_sequences=True, return_state=True)
        self.fc = Dense(vocab_size)
        self.attention = BahdanauAttention(self.dec_units)

    def call(self, x, hidden, enc_output, cell=None):
        """Run one decoding step.

        Args:
            x: Integer token IDs for the current step; the time axis must
                have length 1 so it can be concatenated with the context
                vector.
            hidden: Previous hidden state, used as the attention query.
            enc_output: Per-timestep encoder outputs to attend over.
            cell: Optional previous LSTM cell state. When given, the LSTM is
                seeded with ``[hidden, cell]`` so the recurrent state is
                carried from step to step (both must then have a last
                dimension of ``dec_units``). When omitted, the LSTM starts
                from its default zero state, as before.

        Returns:
            ``(logits, state_h, state_c, attention_weights)``, where
            ``logits`` has the time axis flattened into the batch axis.
        """
        context_vector, attention_weights = self.attention(hidden, enc_output)

        x = self.embedding(x)
        # Prepend the attention context to the embedded token on the
        # feature axis.
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        if cell is None:
            output, state_h, state_c = self.lstm(x)
        else:
            # Without an explicit initial state the LSTM would restart from
            # zeros on every call and the state returned by the previous
            # step would never reach the recurrence.
            output, state_h, state_c = self.lstm(x, initial_state=[hidden, cell])

        # Collapse (batch, time, units) to (batch * time, units) for the
        # dense projection.
        output = tf.reshape(output, (-1, output.shape[2]))
        x = self.fc(output)
        return x, state_h, state_c, attention_weights


# Build both halves of the model from the shared hyperparameters.
encoder = Encoder(vocab_inp_size, embedding_dim, units)
decoder = Decoder(vocab_tar_size, embedding_dim, units)

# ---------- SMOKE TEST: OUTPUT SHAPES ----------
batch_size = 16
seq_len = 10

# Random integer token IDs stand in for a real tokenised source batch.
sample_input = tf.random.uniform(
    (batch_size, seq_len),
    minval=0,
    maxval=vocab_inp_size,
    dtype=tf.int32,
)
encoder_output, enc_h, enc_c = encoder(sample_input)

for _label, _tensor in (
    ("Encoder output shape:", encoder_output),
    ("Encoder hidden state shape:", enc_h),
    ("Encoder cell state shape:", enc_c),
):
    print(_label, _tensor.shape)

# Feed one random target token per example through a single decoder step,
# using the encoder's final hidden state as the attention query.
sample_decoder_input = tf.random.uniform(
    (batch_size, 1),
    minval=0,
    maxval=vocab_tar_size,
    dtype=tf.int32,
)
decoder_output, dec_h, dec_c, attn_weights = decoder(
    sample_decoder_input, enc_h, encoder_output
)

for _label, _tensor in (
    ("Decoder output shape:", decoder_output),
    ("Decoder hidden state shape:", dec_h),
    ("Decoder cell state shape:", dec_c),
    ("Attention weights shape:", attn_weights),
):
    print(_label, _tensor.shape)


Encoder output shape: (16, 10, 512)
Encoder hidden state shape: (16, 512)
Encoder cell state shape: (16, 512)
Decoder output shape: (16, 5000)
Decoder hidden state shape: (16, 512)
Decoder cell state shape: (16, 512)
Attention weights shape: (16, 10, 1)
