In [1]:
from model import Translator
from dataset import TextDataset
import torch
import tqdm
from tokenizers import Tokenizer
import os

In [2]:

print(f"Using PyTorch version {torch.__version__}")

# use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}")

# use tensor cores
torch.set_float32_matmul_precision('high')

# use flash attention
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)

Using PyTorch version 2.3.0+cu121
Using device cuda


In [3]:
model = torch.load("../models/model.pt")
model.eval()

Translator(
  (engEmbedding): Embedding(804, 256)
  (hilliEmbedding): Embedding(292, 256)
  (decoder_block): ModuleList(
    (0-4): 5 x Decoder(
      (feed_forward): Sequential(
        (0): Dropout(p=0.1, inplace=False)
        (1): Linear(in_features=256, out_features=512, bias=False)
        (2): ReLU()
        (3): Linear(in_features=512, out_features=256, bias=False)
        (4): ReLU()
      )
      (layernorm): RMSNorm()
      (layernorm2): RMSNorm()
      (layernorm3): RMSNorm()
      (MHA): SelfAttention(
        (c_attn): Linear(in_features=256, out_features=768, bias=False)
        (c_proj): Linear(in_features=256, out_features=256, bias=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (CA): CrossAttention(
        (query_attn): Linear(in_features=256, out_features=256, bias=False)
        (key_attn): Linear(in_features=256, out_features=256, bias=False)
        (value_attn): Linear(in_features=256, out_features=256, bias=False)
        (c_proj): 

In [4]:
def generate(sentence):
    hilliTokenizer = Tokenizer.from_file("../models/hilliTokenizer.json")
    engTokenizer = Tokenizer.from_file("../models/englighTokenizer.json")
    sentence = hilliTokenizer.encode(sentence).ids
    sentence = torch.tensor(
        sentence, dtype=torch.int64).unsqueeze(0).to(device)
    currentOutput = [0]
    model.eval()
    for i in range(100):
        x = torch.tensor(
            currentOutput, dtype=torch.int64).unsqueeze(0).to(device)
        # print(x, sentence)
        output = model(x=x, originalText=sentence, return_loss=False)
        output = torch.argmax(output[0][-1]).item()
        currentOutput.append(output)
        if (output == 1):
            break
    currentOutput = engTokenizer.decode(currentOutput)
    model.train()
    return currentOutput


In [5]:
generate("Mi Muhe Nye Mi Muhe Beru.") #I hate what I do

"I don ' t like what I like you doing?"

In [6]:
generate("Unu Du Tomo Beru Si?") #What are your two friends doing?

'What are your two friends doing?'

In [7]:
generate("Mi muhe Gusha Boya.") # I like green.

'I like green food.'

In [8]:
generate("Mi muhe Mi Muhe Upa Celi Nini, mi muhe Lata.") #I wish for the sun to be gone, I appreciate the cold.

'I like to want the sun to go away, I like cold.'

In [9]:
model(torch.tensor([[0, 46, 289, 12, 89, 207, 303, 46, 207, 141]]).cuda(), torch.tensor([[0, 126, 228, 229, 126, 228, 237,  19, 1]]).cuda())

tensor([[[0.0006, 0.0001, 0.0006,  ..., 0.0005, 0.0009, 0.0006],
         [0.0010, 0.0004, 0.0009,  ..., 0.0007, 0.0007, 0.0009],
         [0.0010, 0.0029, 0.0010,  ..., 0.0009, 0.0015, 0.0009],
         ...,
         [0.0009, 0.0004, 0.0010,  ..., 0.0007, 0.0008, 0.0009],
         [0.0012, 0.0004, 0.0012,  ..., 0.0017, 0.0008, 0.0012],
         [0.0011, 0.0008, 0.0011,  ..., 0.0010, 0.0010, 0.0011]]],
       device='cuda:0', grad_fn=<SoftmaxBackward0>)

In [10]:
model = torch.jit.trace(model, (torch.tensor([[0, 46, 289, 12, 89, 207, 303, 46, 207, 141]]).cuda(), torch.tensor([[0, 126, 228, 229, 126, 228, 237,  19, 1]]).cuda()))

  assert query_batch == key_batch == value_batch
  assert query_channels == key_channels == value_channels
Tensor-likes are not close!

Mismatched elements: 7530 / 8040 (93.7%)
Greatest absolute difference: 0.0671837329864502 at index (0, 3, 89) (up to 1e-05 allowed)
Greatest relative difference: 4.217129647113587 at index (0, 0, 377) (up to 1e-05 allowed)
  _check_trace(


In [11]:
model = torch.jit.script(model)