In [1]:
import torch
from usta_model import UstaModel
from usta_tokenizer import UstaTokenizer

u_tokenizer = UstaTokenizer("tokenizer.json")

prompt = "the capital of united"

tokens = u_tokenizer.encode(prompt)
tokens

tensor([ 0, 61,  1, 61,  2, 61,  3])

In [5]:
torch.manual_seed(1)
u_model = UstaModel(vocab_size=len(u_tokenizer.vocab), embedding_dim=4, num_heads=2, context_length=32, num_layers=3)

out = u_model(tokens)
out.shape

torch.Size([7, 64])

In [4]:
u_model

UstaModel(
  (embedding): Embedding(64, 4)
  (pos_embedding): Embedding(32, 4)
  (layers): Sequential(
    (0): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
        )
        (projection): Linear(in_features=4, out_features=4, bias=True)
      )
      (norm1): UstaLayerNorm()
      (mlp): UstaMLP(
        (gate_proj): Linear(in_features=4, out_features=4, bias=True)
        (up_proj): Linear(in_features=4, out_features=4, bias=True)
        (down_proj): Linear(in_features=4, out_features=4, bias=True)
        (gelu): GELU()
      )
      (norm2): UstaLayerNorm()
    )
    (1): UstaDecoderBlock(
      (self_attention): UstaMultiHeadAttention(
        (multi_head_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
        )
        (projection): L

In [11]:
import torch

probs = torch.softmax(out[-1], dim=-1)
max_prob, max_index = torch.max(probs, dim=-1)
max_prob, max_index, probs

(tensor(0.1787, grad_fn=<MaxBackward0>),
 tensor(60),
 tensor([0.1006, 0.0152, 0.0036, 0.0015, 0.0022, 0.0039, 0.0260, 0.0109, 0.0025,
         0.0020, 0.0027, 0.0039, 0.0032, 0.0396, 0.0098, 0.0215, 0.0015, 0.0100,
         0.0085, 0.0031, 0.0175, 0.0081, 0.0016, 0.0053, 0.0360, 0.0019, 0.0030,
         0.0100, 0.0338, 0.0011, 0.0139, 0.0030, 0.0009, 0.0014, 0.0014, 0.0054,
         0.0088, 0.0241, 0.0248, 0.0363, 0.0747, 0.0049, 0.0077, 0.0056, 0.0050,
         0.0842, 0.0018, 0.0055, 0.0081, 0.0064, 0.0109, 0.0030, 0.0196, 0.0234,
         0.0038, 0.0033, 0.0021, 0.0093, 0.0023, 0.0149, 0.1787, 0.0038, 0.0043,
         0.0060], grad_fn=<SoftmaxBackward0>))

In [12]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

q_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
q_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
q_tokens = q_tokenizer.encode(prompt)
q_tokens

[1782, 6722, 315, 28192]

In [23]:
q_model.generate(torch.tensor([q_tokens]))

tensor([[ 1782,  6722,   315, 28192,  5302,   374,   279,  3283,   315, 93671,
            11,   714,   279,  6722,   315,   279,  3146,   374,   537,   279,
          6722,   315,   279,  1584]])

In [None]:
# input = [1782, 6722, 315, 28192]
# output = [38297, 315, 279, 5302]
# expected = [6722, 315, 28192, 5302]

In [None]:
q_out = q_model(torch.tensor([q_tokens]))
q_out

In [30]:
q_out.logits.shape

torch.Size([1, 4, 151936])

In [31]:
q_out.logits[0, 0, :].shape

torch.Size([151936])

In [36]:
probs = torch.softmax(q_out.logits[0, 2, :], dim=-1)
max_prob, max_index = torch.max(probs, dim=-1)
max_prob, max_index, probs

(tensor(0.3644, grad_fn=<MaxBackward0>),
 tensor(279),
 tensor([7.1117e-06, 5.0768e-05, 3.3015e-07,  ..., 3.4478e-09, 3.4478e-09,
         3.4478e-09], grad_fn=<SoftmaxBackward0>))

In [37]:
q_tokenizer.decode([max_index])

' the'

In [84]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

g_tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
g_model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")

In [85]:
g_tokens = g_tokenizer.encode("the capital of united")
# input = [2, 1437, 5279, 529, 26974]
# output = [107, 2148, 138, 236743, 1786]
# expected = [1437, 5279, 529, 26974, 5022]
g_tokens

[2, 1437, 5279, 529, 26974]

In [52]:
g_tokenizer.encode(" states")

[2, 5022]

In [98]:
g_tokenizer.decode([156702])

'క్ష్'

In [95]:
g_model.generate(torch.tensor([g_tokens]), max_new_tokens=1)

tensor([[    2,  1437,  5279,   529, 26974, 33138]])

In [96]:
g_out = g_model(torch.tensor([g_tokens]))
g_out.logits.shape

torch.Size([1, 5, 262144])

In [97]:
probs = torch.softmax(g_out.logits[0, 4, :], dim=-1)
max_prob, max_index = torch.max(probs, dim=-1)
max_prob, max_index, probs

(tensor(0.2421, grad_fn=<MaxBackward0>),
 tensor(156702),
 tensor([4.7484e-12, 3.9397e-06, 7.5031e-04,  ..., 1.4784e-12, 1.5568e-12,
         1.3565e-12], grad_fn=<SoftmaxBackward0>))