### Boilerplate code to run trained models

> ⚡ **Compute note:** I recommend running this notebook on a node with 1x H200 GPU. If CUDA is not available, the code below falls back to the CPU.

In [1]:
import torch
import tiktoken

from src.shraygpt import ShrayGPT

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Relative path of the r0 checkpoint file that gets loaded further down.
R0_CHECKPOINT_PATH = "checkpoints/shraygpt-r0.ckpt"

Load the tokenizer, define the text-generation helper, and restore the model from the checkpoint.

In [19]:
# Module-level tokenizer shared by generate_response (tiktoken's "r50k_base" encoding).
tokenizer = tiktoken.get_encoding("r50k_base")

def generate_response(
    model: ShrayGPT,
    prompt: str,
    max_new_tokens: int,
    temperature: float = 0.2,
    top_k: int = 20,
) -> str:
    """Generate a continuation of ``prompt`` with ``model`` and return it as cleaned text.

    The prompt is encoded with the module-level ``tokenizer``, passed to
    ``model.generate_nocache`` as a single-row batch, and the tokens that
    follow the prompt are decoded back to a string.

    Args:
        model: Model exposing ``generate_nocache(context, max_new_tokens=...,
            temperature=..., top_k=...)``.
        prompt: Text to condition the generation on.
        max_new_tokens: Forwarded unchanged to ``generate_nocache``.
        temperature: Sampling temperature forwarded to ``generate_nocache``.
            Defaults to 0.2, the value that was previously hard-coded.
        top_k: Top-k cutoff forwarded to ``generate_nocache``. Defaults to 20,
            the value that was previously hard-coded.

    Returns:
        The decoded continuation, truncated at the first ``"System:"`` marker,
        with every ``"Assistant:"`` marker removed and surrounding whitespace
        stripped.
    """
    prompt_tokens = tokenizer.encode(prompt)
    # Build the context on the device the model actually lives on, so the call
    # also works for a model that was not moved to the global DEVICE.
    model_device = next(model.parameters()).device
    context = torch.tensor(prompt_tokens, dtype=torch.long, device=model_device).unsqueeze(0)

    with torch.no_grad():
        generated = model.generate_nocache(
            context,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
        )

    # Assumes generate_nocache returns the prompt followed by the new tokens
    # (TODO confirm); drop the echoed prompt before decoding.
    new_tokens = generated[0].tolist()[len(prompt_tokens):]
    response = tokenizer.decode(new_tokens)

    # Keep only the text before the first "System:" marker, then remove any
    # "Assistant:" markers left in that text.
    response, _, _ = response.partition("System:")
    response = response.replace("Assistant:", "")
    return response.strip()

# Restore the checkpoint onto DEVICE.
# NOTE(review): torch.load deserializes with pickle; only load checkpoint
# files that come from a trusted source.
ckpt = torch.load(R0_CHECKPOINT_PATH, map_location=DEVICE)
state = ckpt["state_dict"]
hparams = ckpt.get("hyper_parameters")
if hparams is None:
    # Without this check, ShrayGPT(**None) fails with an opaque TypeError.
    raise ValueError(
        f"Checkpoint {R0_CHECKPOINT_PATH!r} has no 'hyper_parameters' entry; "
        "cannot reconstruct the model."
    )

model = ShrayGPT(**hparams)

# strict=False tolerates key mismatches between checkpoint and model; report
# them so that weights which were not loaded do not go unnoticed.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"Missing keys (left at initial values): {load_result.missing_keys}")
    print(f"Unexpected keys (ignored): {load_result.unexpected_keys}")

# Last expression of the cell: moves the model to DEVICE, switches it to eval
# mode, and displays its module tree.
model.to(DEVICE).eval()

ShrayGPT(
  (tok_emb): Embedding(50257, 1024)
  (layers): ModuleList(
    (0-23): 24 x Block(
      (norm1): RMSNorm()
      (norm2): RMSNorm()
      (mha): CausalSelfAttentionMLA(
        (wq): Linear(in_features=1024, out_features=1024, bias=False)
        (wk_lat): Linear(in_features=1024, out_features=64, bias=False)
        (wv_lat): Linear(in_features=1024, out_features=64, bias=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (moe): MoE(
        (gate): Linear(in_features=1024, out_features=6, bias=False)
        (experts): ModuleList(
          (0-5): 6 x SwiGLU(
            (w1): Linear(in_features=1024, out_features=2744, bias=False)
            (w2): Linear(in_features=1024, out_features=2744, bias=False)
            (w3): Linear(in_features=2744, out_features=1024, bias=False)
            (act): SiLU()
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
      

Generate a response to a sample prompt.

In [8]:
# Seed torch's global RNG so that re-running this cell starts from the same
# state (assumes generate_nocache samples via torch's RNG — TODO confirm).
torch.manual_seed(0)

# generate_response already cuts the text at "System:", removes "Assistant:"
# markers and strips whitespace, so no further clean-up is needed here.
response = generate_response(model, "Write a short poem about the sea:", 256)

# print (rather than a bare expression) so the line breaks render as lines.
print(response)

In the depths of solitude,
Where and for a journey of spirit,
Are treasures upon the icy serenity
Where the sun sets,
The sorrow of the storm
The wind that dies,
So through the charm of life,
So let us return to the place of civilization,
With the bounties and grandeur,
And listen to the voice of the ocean,
For though the sea is only a drop by the breath of life,
For sure cannot erase it,
There's hope, the power of nature,
So let the sun shine its sweetest wonder,
For though the sea might not give it,
For though the sea may yet remain dreary,
And the mightiest enchanted serene,
In comfort, of heart and small temples,
And always, in harmony,
Without warmth and light,
And the night, though the light may not give,
For though the sea may still seem empty,
And though the sea is not chafing,
It's a never respite, yet true.
