<a href="https://colab.research.google.com/github/varun29-git/modified-transformer/blob/main/Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!git clone https://github.com/varun29-git/modified-transformer -q
%cd modified-transformer/
!pip install -r requirements.txt -q

/content/modified-transformer


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import sys
sys.path.append('/content/drive/MyDrive/VectorSLM')

In [4]:
import torch
import tiktoken

from model import build_transformer
from config import D_MODEL, H, N, D_FF, DROPOUT  # hyperparameters passed to build_transformer


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cuda


In [6]:
tokenizer = tiktoken.encoding_for_model("gpt-4")
vocab_size = tokenizer.n_vocab


In [7]:
print("Loading model...")

model = build_transformer(
    vocab_size,
    D_MODEL,
    H,
    N,
    D_FF,
    DROPOUT
).to(device)

weights_path = "/content/drive/MyDrive/modified_transformer_weights/_weights/best_model.pt"
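# weights_only=True restricts unpickling to tensors and plain containers, which is all a state_dict needs
model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
model.eval()  # disable dropout for inference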

print("Model loaded successfully.")

Loading model...
Model loaded successfully.
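`load_state_dict` is strict by default, so it raises an error if the parameter names in the checkpoint do not match the model built from `config`. When that happens, comparing the two key sets usually shows the cause, for example a `module.` prefix left by `DataParallel`. The sketch below is a hypothetical plain-Python helper (`diff_state_dict_keys` is not part of PyTorch); with the real model you would pass `model.state_dict().keys()` and the keys of the loaded checkpoint. The key names in the example are illustrative only. PyTorch reports the same two lists (`missing_keys`, `unexpected_keys`) when `load_state_dict` is called with `strict=False`.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys, strip_prefix=""):
    """Hypothetical helper: compare parameter names of a model and a checkpoint.

    Returns (missing, unexpected):
      missing    = names the model expects but the checkpoint lacks
      unexpected = names in the checkpoint that the model does not define
    """
    def normalize(key):
        # Optionally drop a wrapper prefix such as "module." (added by DataParallel)
        if strip_prefix and key.startswith(strip_prefix):
            return key[len(strip_prefix):]
        return key

    model_set = set(model_keys)
    ckpt_set = {normalize(k) for k in checkpoint_keys}
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


# Illustrative key names only; the real names depend on model.py
model_keys = ["embed.weight", "decoder.layers.0.attn.w_q.weight", "proj.weight"]
ckpt_keys = ["module.embed.weight", "module.decoder.layers.0.attn.w_q.weight", "module.proj.weight"]

print(diff_state_dict_keys(model_keys, ckpt_keys))  # every key reported as missing/unexpected
print(diff_state_dict_keys(model_keys, ckpt_keys, strip_prefix="module."))  # → ([], [])
```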


In [8]:
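@torch.no_grad()
def generate_text(
    prompt,
    model,
    tokenizer,
    max_new_tokens=70,
    temperature=0.8
):
    """Autoregressively sample up to max_new_tokens tokens after prompt."""
    model.eval()
    device = next(model.parameters()).device  # run on whatever device the model lives on

    # Shape (1, prompt_len): a batch containing the encoded prompt
    tokens = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)

    for _ in range(max_new_tokens):
        seq_len = tokens.size(1)

        # Causal mask of shape (1, 1, seq_len, seq_len): position i may attend only to positions <= i
        mask = torch.tril(
            torch.ones(seq_len, seq_len, device=device)
        ).unsqueeze(0).unsqueeze(0)

        logits = model(tokens, mask)
        # Keep only the prediction for the last position and apply temperature scaling
        logits = logits[:, -1, :] / temperature

        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)

        # Stop at end-of-text before appending it, so the token is not decoded into the output
        if next_token.item() == tokenizer.eot_token:
            break

        tokens = torch.cat([tokens, next_token], dim=1)

    return tokenizer.decode(tokens[0].tolist())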
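The mask built inside the loop is a lower-triangular matrix of ones: row *i* has ones in columns 0 through *i*, so the prediction at position *i* can only attend to itself and earlier tokens. The `.unsqueeze(0).unsqueeze(0)` calls add batch and head dimensions so the mask can broadcast over attention scores of shape `(batch, heads, seq_len, seq_len)`. The plain-Python sketch below reproduces the pattern that `torch.tril(torch.ones(n, n))` produces and shows one common way such a mask is applied: blocked scores are set to negative infinity before the softmax. That "0 means blocked" rule is an assumption about how `model.py` consumes the mask; the notebook itself does not show it.

```python
import math


def causal_mask(n):
    """Same 0/1 pattern as torch.tril(torch.ones(n, n)): 1 where column <= row."""
    return [[1 if col <= row else 0 for col in range(n)] for row in range(n)]


def masked_softmax_row(scores, mask_row):
    """Softmax over one row of attention scores, ignoring positions where mask_row == 0."""
    # Blocked positions get -inf, so exp() maps them to exactly 0
    masked = [s if m else -math.inf for s, m in zip(scores, mask_row)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]


mask = causal_mask(4)
for row in mask:
    print(row)

# Position 1 may attend to positions 0 and 1 only; the large "future" scores are ignored
print(masked_softmax_row([0.5, 0.5, 9.0, 9.0], mask[1]))  # → [0.5, 0.5, 0.0, 0.0]
```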


In [9]:
prompts = [
    "He was unhappy",
    "The brave knight went to the",
    "In a magical forest, a small"
]

for i, prompt in enumerate(prompts, 1):
    print(f"\n{'='*20} STORY {i} {'='*20}")
    print(generate_text(prompt, model, tokenizer))
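All three stories are sampled with the default `temperature=0.8`. Dividing the logits by a temperature below 1 sharpens the softmax toward the most likely token, while values above 1 flatten it toward uniform; `torch.multinomial` then draws one token from the resulting distribution. The plain-Python sketch below mirrors the temperature scaling, softmax, and multinomial draw in `generate_text` on a made-up four-token vocabulary (the logits are illustrative, not taken from the model), so the effect of the temperature can be seen directly.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Probabilities from logits / temperature, as in generate_text."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next_token(logits, temperature, rng):
    """Draw one token id, analogous to torch.multinomial(probs, num_samples=1)."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for a 4-token vocabulary
for t in (0.5, 0.8, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: " + ", ".join(f"{p:.3f}" for p in probs))

rng = random.Random(0)  # fixed seed so the draw is repeatable
print(sample_next_token(logits, 0.8, rng))
```

Because sampling is random, rerunning the cell above produces different stories each time. Calling `torch.manual_seed(...)` before `generate_text` plays the same role as the fixed `random.Random(0)` here and should make runs repeatable on the same hardware and library versions.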


==================== STORY 1 ====================
He was unhappy and said, "What do you mean? This is not nice."

The girl said, "I just wanted to do it, but I did not know that. I just wanted some milk too. Now I can make two pies for you. And you can have some tea and a snack. And you can play with your dolls instead."

The girl

==================== STORY 2 ====================
The brave knight went to the beach. It was a picture of a castle and even a castle. The knight was so happy, he decided to close the picture of himself. 

The knight learned a valuable lesson that day - when you show me what you used, you can make something truly special.Once upon a time there was a little girl named Gus. lying in the forest

==================== STORY 3 ====================
In a magical forest, a small rabbit and a rabbit were playing in the water. The rabbit was very excited and wanted to stay for a long time.

The rabbit asked the rabbit why the rabbit was so flexible. The rabbit told the rabbit he could move around by himself in circles.
The rabbit said he must surrender one more time and make a wish. The rabbit hopped off and
