In [3]:
import tensorflow as tf
from tensorflow.keras import layers

# Config (change according to your dataset)
LATIN_VOCAB_SIZE = 50  # source (Latin) vocabulary size; sizes the encoder embedding table
DEVANAGARI_VOCAB_SIZE = 60  # target (Devanagari) vocabulary size; sizes the decoder embedding and output layer
EMBED_SIZE = 64  # embedding width, used by both encoder and decoder
HIDDEN_SIZE = 128  # GRU units; shared because the encoder's final state initialises the decoder GRU

# Encoder
class Encoder(tf.keras.Model):
    """Embed source token ids and run them through a GRU.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_size: Width of each embedding vector.
        hidden_size: Number of GRU units.
    """

    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = layers.Embedding(input_dim=vocab_size, output_dim=embed_size)
        self.gru = layers.GRU(
            units=hidden_size,
            return_sequences=True,
            return_state=True,
        )

    def call(self, x):
        """Return the GRU's final hidden state for the token ids in ``x``."""
        embedded = self.embedding(x)
        # The per-step sequence output is produced but only the last state is used.
        _, final_state = self.gru(embedded)
        return final_state

# Decoder
class Decoder(tf.keras.Model):
    """Embed target token ids, run a GRU from a given state, and score the vocabulary.

    Args:
        vocab_size: Size of the embedding table and of the output logits.
        embed_size: Width of each embedding vector.
        hidden_size: Number of GRU units.
    """

    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = layers.Embedding(input_dim=vocab_size, output_dim=embed_size)
        self.gru = layers.GRU(
            units=hidden_size,
            return_state=True,
            return_sequences=True,
        )
        self.fc = layers.Dense(units=vocab_size)

    def call(self, x, hidden_state):
        """Decode the token ids in ``x`` starting from ``hidden_state``.

        Returns:
            A ``(logits, state)`` pair: the dense projection of every GRU step
            output, and the GRU's final state.
        """
        embedded = self.embedding(x)
        step_outputs, new_state = self.gru(embedded, initial_state=hidden_state)
        return self.fc(step_outputs), new_state

# Seq2Seq wrapper
class Seq2Seq(tf.keras.Model):
    """Encoder-decoder wrapper that decodes one target step at a time.

    Args:
        encoder: Model mapping source token ids to an initial decoder state.
        decoder: Model mapping ``(token ids, state)`` to ``(logits, new state)``.
        target_vocab_size: Width of the decoder's logits. It is used to build
            the placeholder row for step 0, so it must match the decoder's
            output size.
    """

    def __init__(self, encoder, decoder, target_vocab_size):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.target_vocab_size = target_vocab_size

    def call(self, src, tgt, training=False, teacher_forcing_ratio=0.5):
        """Decode ``tgt_len - 1`` steps, starting from the first target token.

        Args:
            src: Source token ids, passed straight to the encoder.
            tgt: Target token ids of shape (batch, tgt_len). Column 0 is the
                start token; the remaining columns are only read when teacher
                forcing is active.
            training: When truthy, each step feeds the ground-truth token with
                probability ``teacher_forcing_ratio``. When falsy, decoding is
                purely greedy so no target token beyond column 0 is consumed.
            teacher_forcing_ratio: Per-step probability of teacher forcing
                during training.

        Returns:
            Float tensor of shape (batch, tgt_len, vocab). Step 0 is all zeros
            because no prediction is made for the start token.
        """
        batch_size, tgt_len = tf.shape(tgt)[0], tf.shape(tgt)[1]
        outputs = tf.TensorArray(tf.float32, size=tgt_len)

        # Fill slot 0 explicitly instead of relying on how stack() treats
        # never-written elements; the loop below only writes slots 1..tgt_len-1.
        outputs = outputs.write(
            0, tf.zeros(tf.stack([batch_size, self.target_vocab_size]), dtype=tf.float32)
        )

        # Get encoder hidden state
        hidden = self.encoder(src)

        # First input to the decoder is the <sos> token
        decoder_input = tf.expand_dims(tgt[:, 0], 1)  # shape (batch, 1)

        for t in range(1, tgt_len):
            preds, hidden = self.decoder(decoder_input, hidden)
            preds = tf.squeeze(preds, axis=1)
            outputs = outputs.write(t, preds)

            # Greedy choice by default. tf.argmax yields int64, so cast to the
            # target dtype to keep the decoder input dtype stable across steps.
            next_token = tf.cast(tf.argmax(preds, axis=1), tgt.dtype)

            # Teacher forcing is a training-only trick: at inference the ground
            # truth must not be fed back into the decoder.
            if training and tf.random.uniform([]) < teacher_forcing_ratio:
                next_token = tgt[:, t]

            decoder_input = tf.expand_dims(next_token, 1)

        return tf.transpose(outputs.stack(), [1, 0, 2])  # shape (batch, seq_len, vocab)

# Instantiate
# Seed Python, NumPy and TensorFlow RNGs before any layer is created, so that a
# Restart & Run All is intended to reproduce the same initial weights and the
# same random draws in later cells.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

encoder = Encoder(
    vocab_size=LATIN_VOCAB_SIZE,
    embed_size=EMBED_SIZE,
    hidden_size=HIDDEN_SIZE,
)
decoder = Decoder(
    vocab_size=DEVANAGARI_VOCAB_SIZE,
    embed_size=EMBED_SIZE,
    hidden_size=HIDDEN_SIZE,
)
seq2seq = Seq2Seq(encoder, decoder, target_vocab_size=DEVANAGARI_VOCAB_SIZE)


In [4]:
# Dummy data: random token ids, used only to smoke-test the forward pass
BATCH_SIZE = 32
SRC_SEQ_LEN = 15
TGT_SEQ_LEN = 20

src = tf.random.uniform((BATCH_SIZE, SRC_SEQ_LEN), maxval=LATIN_VOCAB_SIZE, dtype=tf.int32)
tgt = tf.random.uniform((BATCH_SIZE, TGT_SEQ_LEN), maxval=DEVANAGARI_VOCAB_SIZE, dtype=tf.int32)

output = seq2seq(src, tgt, training=True)

# Fail loudly instead of relying on someone eyeballing the printed shape.
expected_shape = (BATCH_SIZE, TGT_SEQ_LEN, DEVANAGARI_VOCAB_SIZE)
assert tuple(output.shape) == expected_shape, (
    f"unexpected output shape {tuple(output.shape)}, expected {expected_shape}"
)
print(output.shape)


(32, 20, 60)


In [5]:
# Greedy pick: the highest-scoring vocabulary index at every target position.
# Column 0 carries no real prediction: Seq2Seq.call only writes steps t >= 1,
# and the printed result shows 0 in the first column of every row.
predicted_indices = tf.argmax(output, axis=-1)  # shape: (32, 20)


In [6]:
# Dump the predicted ids: one row per batch element, one column per target position.
print(predicted_indices)

tf.Tensor(
[[ 0 32 29 30 37 37 37 23 29 43  8 29 29 29 34 14 19  4 34 50]
 [ 0 49 25 58 35 25 42 56 58  2 22  8  8 43  8 25 58 51 22 44]
 [ 0 44 22  8 33  8 33  8 47 32 39 14 54 29 37 29 30 30 30 24]
 [ 0 30 36 37 37 37 37 37 37 56 32 50 25 53 37 32 50 14 40 14]
 [ 0 20 32 50 29 37 37 51 56 22  8 33 29 49 22  7  7  8 14  8]
 [ 0 19 25  2 53  7 53 51 56 22  8 33 30 37 37 37 37 43 31  2]
 [ 0 28 55  6 12 35 56 56  2 14 50 46 55 55  1 29 34 29 38 22]
 [ 0  0 43  8 33  8 33 54 29 37 37 37 19 40  9 40 31  7 13  7]
 [ 0 54  3  3 20 32 50 50  1  1 37 37  1 43  8 12 12 43 23  2]
 [ 0  0 42  0 43  8 12 29 18 35 56  5 45 32  0 12 12  4 22 22]
 [ 0 32 32 29 30 37 37 29 37 14 50 46 56 19 22 30 36  8 40 29]
 [ 0 28 55  6  6 35 56 33 49  8 29 30 38 18 22 58 14 46  7 38]
 [ 0 27 43  8 29 30 37 40 25 22  8 47 32 47 12 18 59 30  3  3]
 [ 0 27 43 35 42  0 43 32 50 29 37 37 56 50  1 29 30 18 30 30]
 [ 0 37 37 37 37 37 37 23 29 19 45 22 10 29 59 18 59 18 35 43]
 [ 0  7  7 46  8 47 12  8 33 49 11 32  1 29 