In [5]:
import torch
from pathlib import Path
from tokenizers import Tokenizer
from dataset import greedy_decode
from transformer import build_transformer

def translate(sentence: str, model, tokenizer_src, tokenizer_tgt, max_len, device):
    model.eval()
    
    # Tokenize the input sentence
    encoder_input = tokenizer_src.encode(sentence).ids
    encoder_input = torch.tensor(encoder_input).unsqueeze(0).to(device)
    
    # Create the encoder mask
    encoder_mask = (encoder_input != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int().to(device)
    
    # Perform the translation
    translation = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
    
    # Decode the translation
    translated_sentence = tokenizer_tgt.decode(translation.detach().cpu().numpy())
    
    return translated_sentence

def main():
    # Set up device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Load the tokenizers
    tokenizer_src = Tokenizer.from_file("tokenizer_en.json")
    tokenizer_tgt = Tokenizer.from_file("tokenizer_mr.json")

    # Set up model configuration
    config = {
        "seq_len": 350,
        "d_model": 512,
        "lang_src": "en",
        "lang_tgt": "mr",
    }

    # Build the model
    model = build_transformer(tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size(), 
                              config["seq_len"], config["seq_len"], d_model=config["d_model"]).to(device)

    # Load the model weights
    model_weights = "tmodel_49.pt"
    state = torch.load(model_weights, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    # Example usage
    english_sentence = "Hello, how are you?"
    marathi_translation = translate(english_sentence, model, tokenizer_src, tokenizer_tgt, config["seq_len"], device)
    
    print(f"English: {english_sentence}")
    print(f"Marathi: {marathi_translation}")

if __name__ == "__main__":
    main()

English: Hello, how are you?
Marathi: , तू काय आहेस ?
