In [1]:
import torch
from pathlib import Path
from typing import Literal
from en_indic_transformer import Predictor, Transformer, Tokenizer

In [2]:
# Project root, taken as two directory levels above the current working directory
# (assumes the notebook is launched from its own folder inside the repository).
base_dir = Path.cwd().parent.parent
base_dir

PosixPath('/Users/sameergururajmathad/eng-indic-transformer')

In [3]:
# Sub-directories of the project root that hold the saved model checkpoint
# and the tokenizer files, respectively.
model_dir = base_dir.joinpath('models')
tokenizer_dir = base_dir.joinpath('tokenizer')

In [4]:
# Build the tokenizer from its serialized model file; the path is handed over as a plain string.
tokenizer = Tokenizer(str(tokenizer_dir.joinpath('tokenizer.model')))

In [5]:
# --- run configuration -------------------------------------------------------
random_seed = 42  # fixed seed so repeated runs behave the same

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device: Literal['cpu', 'cuda']
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Seed both the default and the CUDA random number generators.
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)

# --- transformer hyperparameters ---------------------------------------------
# These are passed to the Transformer constructor; presumably they must match
# the values the checkpoint was trained with — verify before changing any.
context_length = 1024  # previously 3000
vocab_size = tokenizer.n_vocab  # vocabulary size reported by the loaded tokenizer
emb_dim = 512
enc_layers = 2
dec_layers = 2
num_heads = 16
dropout = 0.1
bias = False

In [6]:
# Re-seed immediately before construction so the initial weights are the same on every run.
torch.manual_seed(random_seed)

# NOTE(review): the recorded repr shows the MLP Linear layers with bias=True even
# though bias=False is passed here; presumably `bias` only affects the attention
# projections — confirm against the Transformer implementation.
model = Transformer(
    vocab_size=vocab_size,
    context_length=context_length,
    emb_dim=emb_dim,
    enc_layers=enc_layers,
    dec_layers=dec_layers,
    num_heads=num_heads,
    dropout=dropout,
    bias=bias,
)

# Move the model to the selected device; as the last expression this also displays the model.
model.to(device)

Transformer(
  (encoder): Encoder(
    (token_embeddings): Embedding(50000, 512)
    (pos_embeddings): Embedding(1024, 512)
    (encoder_layers): ModuleList(
      (0-1): 2 x EncoderLayer(
        (mlp): MLP(
          (mlp): Sequential(
            (0): Linear(in_features=512, out_features=2048, bias=True)
            (1): GELU(approximate='none')
            (2): Linear(in_features=2048, out_features=512, bias=True)
          )
        )
        (attn): MultiHeadAttention(
          (wq): Linear(in_features=512, out_features=512, bias=False)
          (wk): Linear(in_features=512, out_features=512, bias=False)
          (wv): Linear(in_features=512, out_features=512, bias=False)
          (proj): Linear(in_features=512, out_features=512, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm1): LayerNorm()
        (norm2): LayerNorm()
      )
    )
  )
  (decoder): Decoder(
    (token_embeddings): Embedding(50000, 512)
    (pos_embeddings): Embedding(1

In [7]:
# Switch the model to evaluation mode before it is used for prediction.
# nn.Module instances start in training mode, and this model contains Dropout(p=0.1)
# layers, so without this call dropout would stay active during inference unless
# Predictor.predict disables it itself (not visible from this notebook).
# It is called before the load so the load result remains the cell's displayed value.
model.eval()

# Restore the trained weights from the checkpoint, mapping tensors onto the active device.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source (consider weights_only=True where the
# installed torch version supports it).
model.load_state_dict(torch.load(model_dir/'model.pt', map_location=device))

<All keys matched successfully>

### Prediction

In [16]:
# Generation stops when this token id is produced (or after max_tokens tokens).
stop_id = tokenizer.get_piece_id('<|endoftext|>')

# Predictor.predict is iterated token by token, so the translation can be printed as it is generated.
token_stream = Predictor.predict(
    model,
    tokenizer,
    "<|english|> How are you today?",
    target='<|hindi|>',
    max_tokens=50,
    stop_token=stop_id,
)

# NOTE(review): the recorded output shows many words run together; presumably decoding
# one token at a time drops leading whitespace — confirm against Tokenizer.decode.
for token in token_stream:
    print(tokenizer.decode(token), end="", flush=True)

क्याआपकोक्याकरते हो?अच्छानहीं लगताकि(मूसा)इसमें शक़ नहीं कमैंउसेक्षमाकरूंगी?