In [13]:
import math

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
import getpass

MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# HF_TOKEN = getpass.getpass()

In [5]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=None)

In [6]:
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=None)
model.eval()
model.to(DEVICE) # type: ignore

Loading checkpoint shards: 100%|██████████| 4/4 [00:12<00:00,  3.06s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
  

In [17]:
PREFIX_PROMPT = "The wind whispered through old ruins"
INSTRUCTION_PROMPT = "Continue the story."
PROMPT = f"{PREFIX_PROMPT}\n\n{INSTRUCTION_PROMPT}.\n"
MAX_NEW_TOKENS = 20
K = 8
k = 10
beta = 5.0

## Sequential Monte Carlo Sampling

In [11]:
from utils import load_counts_and_reward

reward_calc = load_counts_and_reward("./tinystories_ngrams")

def reward_sum_pos_ids(ids) -> float:
    """
    Args:
        reward_calc: FastRewardCalculator (token_lm.logp available).
        tokenizer: for ids→tokens conversion.
        ids: current full context ids (prompt + generated so far).

    Returns:
        float ΔR_t ≥ 0.
    """
    if len(ids) < 3:
        return 0.0

    tokens = tokenizer.decode(ids, skip_special_tokens=True)
    reward = reward_calc.calculate_reward_tokens(
        tokens.strip().split(" "), normalize=True
    )
    return reward

In [19]:
input_ids = tokenizer.encode(PROMPT, return_tensors="pt").to(model.device)
input_tokens_size = input_ids.shape[1]
input_ids = input_ids.repeat(K, 1)

completed_ids = torch.zeros(K).to(model.device)

# Define here to have access to the weights variable outside the loop
weights = []
normalized_weights = []

for _ in range(MAX_NEW_TOKENS):
    outputs = model(input_ids)
    logits = outputs.logits[:, -1, :]

    # Get top-k logits and indices and convert to probabilities
    top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)
    probs = torch.softmax(top_k_logits, dim=-1)

    # Sample from the top-k distribution
    next_token_ids = []
    weights = []
    for i in range(K):
        if completed_ids[i]:
            next_token_ids.append(model.config.eos_token_id)
            weights.append(1.0)
            continue

        sampled_idx = torch.multinomial(probs[i], 1)

        next_token_id = top_k_indices[i, sampled_idx].item()
        next_token_prob = top_k_logits[i, sampled_idx].item()
        next_token_ids.append(next_token_id)

        pi_t = math.exp(
            beta * reward_sum_pos_ids(input_ids[i])
        )
        pi_t_1 = math.exp(
            beta
            * reward_sum_pos_ids(
                torch.cat(
                    [input_ids[i], torch.tensor([next_token_id]).to(model.device)]
                ),
            )
        )
        weight = pi_t / (pi_t_1 * next_token_prob)
        weights.append(weight)

    next_token_ids = torch.tensor(next_token_ids).to(model.device)
    input_ids = torch.cat(
        [
            input_ids,
            next_token_ids.unsqueeze(-1),
        ],
        dim=-1,
    )

    total_weights = sum(weights)
    normalized_weights = torch.tensor(weights) / total_weights

    # Resample
    resampled_indices = torch.multinomial(normalized_weights, K, replacement=True).to(input_ids.device)
    print(normalized_weights)
    print(resampled_indices)
    input_ids = input_ids[resampled_indices]

    # Stop if all sequences are completed
    completed_ids = completed_ids.masked_fill(
        next_token_ids == model.config.eos_token_id, 1
    )
    if completed_ids.all():
        break

gen_ids = input_ids[:, input_tokens_size:].tolist()

samples = []
for i in range(K):
    samples.append(
        {
            "text": tokenizer.decode(gen_ids[i], skip_special_tokens=True),
            "weight": weights[i],
        }
    )

tensor([0.1190, 0.1291, 0.1190, 0.1291, 0.1190, 0.1190, 0.1366, 0.1291])
tensor([5, 1, 5, 4, 0, 7, 3, 0], device='cuda:0')
tensor([0.1236, 0.1167, 0.1236, 0.1236, 0.1236, 0.1255, 0.1399, 0.1236])
tensor([4, 4, 3, 5, 1, 4, 0, 0], device='cuda:0')
tensor([0.1215, 0.1215, 0.1215, 0.1417, 0.1294, 0.1215, 0.1215, 0.1215])
tensor([5, 4, 5, 0, 2, 0, 5, 2], device='cuda:0')
tensor([0.0957, 0.3298, 0.0957, 0.0957, 0.0957, 0.0957, 0.0957, 0.0957])
tensor([1, 1, 2, 2, 6, 1, 2, 0], device='cuda:0')
tensor([0.1115, 0.1115, 0.1303, 0.1303, 0.1303, 0.1115, 0.1442, 0.1303])
tensor([5, 0, 7, 0, 0, 2, 2, 7], device='cuda:0')
tensor([0.1479, 0.1479, 0.1374, 0.0069, 0.1479, 0.1374, 0.1374, 0.1374])
tensor([4, 0, 4, 0, 1, 4, 4, 2], device='cuda:0')
tensor([0.1240, 0.1278, 0.1240, 0.1327, 0.1278, 0.1240, 0.1327, 0.1071])
tensor([4, 5, 6, 6, 0, 1, 4, 6], device='cuda:0')
tensor([0.1238, 0.1249, 0.1262, 0.1262, 0.1249, 0.1238, 0.1238, 0.1262])
tensor([3, 2, 7, 0, 7, 4, 3, 1], device='cuda:0')
tensor([0.1308, 

In [20]:
samples

[{'text': 'As I walked through the crumbling arches and weathered stones, I felt a strange energy coursing',
  'weight': 0.05830518478831011},
 {'text': 'As I walked through the crumbling arches and weathered stones, I felt a strange energy emanating',
  'weight': 0.1372190051489348},
 {'text': 'As I walked through the crumbling arches and crumbling stone, the wind whispered secrets in my ear.',
  'weight': 0.1372190051489348},
 {'text': 'As I walked through the crumbling arches and crumbling stone, the wind whispered secrets in my ear.',
  'weight': 0.01904768555878557},
 {'text': 'As I walked through the crumbling arches and weathered stones, I couldn’t shake the feeling that',
  'weight': 0.01904768555878557},
 {'text': 'As I walked through the crumbling arches and crumbling stone, the wind whispered secrets in my ear.',
  'weight': 0.1372190051489348},
 {'text': 'As I walked through the crumbling arches and crumbling stone, the wind whispered secrets in my ear.',
  'weight': 0.0351