In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import getpass

MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# HF_TOKEN = getpass.getpass()

In [5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=None)

In [6]:
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=None)
model.eval()
model.to(DEVICE) # type: ignore

Loading checkpoint shards: 100%|██████████| 4/4 [00:11<00:00,  2.79s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
  

In [7]:
PREFIX_PROMPT = "The wind whispered through old ruins"
INSTRUCTION_PROMPT = "Continue the story."
PROMPT = f"{PREFIX_PROMPT}\n\n{INSTRUCTION_PROMPT}.\n"
MAX_NEW_TOKENS = 50
BATCH_SIZE = 3

## Simple Importance Sampling

In [8]:
input_ids = tokenizer.encode(PROMPT, return_tensors="pt").to(DEVICE)
input_ids = input_ids.repeat(BATCH_SIZE, 1)

In [20]:
completed_ids = torch.zeros(BATCH_SIZE).to(model.device)

In [12]:
output = model(input_ids)

In [13]:
output.logits.shape

torch.Size([3, 12, 128256])

In [14]:
logits = output.logits[:, -1, :]

In [15]:
top_k_logits, top_k_indices = torch.topk(logits, k=3, dim=-1)
probs = torch.softmax(top_k_logits, dim=-1)


In [16]:
probs.shape

torch.Size([3, 3])

In [32]:
next_token_ids = []
for i in range(BATCH_SIZE):
    if completed_ids[i]:
        next_token_ids.append(model.config.eos_token_id)
        continue

    sampled_idx = torch.multinomial(probs[i], 1)
    next_token_id = top_k_indices[i, sampled_idx].item()
    next_token_ids.append(next_token_id)

next_token_ids

[2170, 791, 791]

In [34]:
input_ids = torch.cat(
    [
        input_ids,
        torch.tensor(next_token_ids).to(model.device).unsqueeze(-1),
    ],
    dim=-1,
)

In [36]:
input_ids

tensor([[128000,    791,  10160,  58366,   1555,   2362,  46762,    271,  24433,
            279,   3446,  35047,   2170],
        [128000,    791,  10160,  58366,   1555,   2362,  46762,    271,  24433,
            279,   3446,  35047,    791],
        [128000,    791,  10160,  58366,   1555,   2362,  46762,    271,  24433,
            279,   3446,  35047,    791]], device='cuda:0')

In [None]:
for _ in range(MAX_NEW_TOKENS):
    outputs = model(input_ids)
    logits = outputs.logits[~completed_ids, -1, :]

    # Get top-k logits and indices and convert to probabilities
    top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)
    probs = torch.softmax(top_k_logits, dim=-1)

    # Sample from the top-k distribution
    for i in range(len(logits)):
        sampled_idx = torch.multinomial(probs[i], 1)
        next_token_id = top_k_indices[i, sampled_idx]

        # Stop if EOS token is reached
        if next_token_id == model.config.eos_token_id:
            break

        # Update input_ids for next iteration
        input_ids[i] = torch.cat(
            [input_ids[i], torch.tensor([[next_token_id]]).to(model.device)], dim=1
        )

## Reward Function Testing

In [7]:
import numpy as np
from utils import load_jsonl

data = load_jsonl("data/outputs_task1_IS.jsonl")

samples = [[s["text"] for s in d["continuations"][0]["samples"]] for d in data]
samples = np.array(samples).flatten()

In [9]:
sample = samples[0]
sample

'. It carried a faint whisper. The villagers believed it was a sign of the past. But I knew better. It was a sign of the future. The wind brought with it a chill. It was as if the very presence of it was a'

In [10]:
import pickle

cache_file = "tinystories_ngrams/trigram_probs.pkl"

with open(cache_file, "rb") as f:
    cache = pickle.load(f)
tri_probs = cache["trigram_probs"]

In [19]:
import math
from typing import Dict

class _TokenLM:
    """Minimal token-trigram LM with logp only. Internal use."""

    def __init__(self, tri_probs: Dict[str, float], eps: float):
        self._tri = tri_probs
        self._eps = eps

    @staticmethod
    def _key(t1: str, t2: str, t3: str) -> str:
        return f"Ġ{t1},Ġ{t2},Ġ{t3}"

    def logp(self, t1: str, t2: str, t3: str) -> float:
        """Return log P(t3 | t1, t2) with epsilon floor."""
        p = self._tri.get(self._key(t1, t2, t3), 0.0)
        if p <= 0.0:
            return 0.0
        return -math.log(p)

token_lm = _TokenLM(tri_probs, 1e-9)

tokens = sample.split(" ")
print(tokens)

reward = 0.0
for i in range(len(tokens) - 2):
    reward += token_lm.logp(tokens[i], tokens[i + 1], tokens[i + 2])

reward

['.', 'It', 'carried', 'a', 'faint', 'whisper.', 'The', 'villagers', 'believed', 'it', 'was', 'a', 'sign', 'of', 'the', 'past.', 'But', 'I', 'knew', 'better.', 'It', 'was', 'a', 'sign', 'of', 'the', 'future.', 'The', 'wind', 'brought', 'with', 'it', 'a', 'chill.', 'It', 'was', 'as', 'if', 'the', 'very', 'presence', 'of', 'it', 'was', 'a']


82.18322252988756