# Soft Sampling

In [1]:
import math
import torch
import torch.nn.functional as F
from torch import Tensor, tensor
from jaxtyping import Float, Int
from transformers import AutoModelForCausalLM, AutoTokenizer

# torch.set_grad_enabled(False); # disable backprop

In [2]:
local_files_only = True
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print("device:", device)

# model_name = "gpt2"
model_name = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    local_files_only=local_files_only,
    # attn_implementation="flash_attention_2",
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=local_files_only)

print(tokenizer("<|finetune_right_pad_id|>"))
print(tokenizer.convert_ids_to_tokens([128000, 128004]))

tokenizer.pad_token = '<|finetune_right_pad_id|>'

if model_name == "gpt2":
    context_length = model.config.n_ctx
    tokenizer.pad_token = tokenizer.eos_token
else: # llama
    context_length = model.config.max_position_embeddings

device: mps
{'input_ids': [128000, 128004], 'attention_mask': [1, 1]}
['<|begin_of_text|>', '<|finetune_right_pad_id|>']


[128256, 128257]

In [3]:
def tokenize(inputs) -> Int[Tensor, "bs seq"]:
    return tokenizer(
        inputs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=context_length
    )["input_ids"].to(device)

def decode(tokens) -> str:
    return tokenizer.batch_decode(tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)

prompts = ["the cat sat in the hat", "hello darkness my old friend", "soft sampling smosh sampling"]
tokens = tokenize(prompts)
print("Tokenized:")
print(tokens)
print("Decode:")
decode(tokens)

Tokenized:
tensor([[128000,   1820,   8415,   7731,    304,    279,   9072],
        [128000,  15339,  27394,    856,   2362,   4333, 128004],
        [128000,   3594,  25936,   1554,   9451,  25936, 128004]],
       device='mps:0')
Decode:


['the cat sat in the hat',
 'hello darkness my old friend',
 'soft sampling smosh sampling']

## Mock Training

Freeze everything but the head

In [4]:
for p in model.parameters():
    p.requires_grad = False
model.lm_head.requires_grad = True

In [5]:
import torch
import torch.nn.functional as F
from fastcore.meta import delegates

def to_cpu(t): return t.detach().cpu()

def mk_proba_dist(
    logits, # (batch_size, d_vocab)
    temperature=1.0,
    top_k=None,
    min_p=None,
):
    "Create probability distribution from logits"
    batch_size, d_vocab = logits.shape
    device = logits.device

    if top_k is not None:
        top_logits, top_idxs = logits.topk(top_k, dim=-1)
        logits = (torch.full_like(logits, float('-inf'))
                       .scatter_(1, top_idxs, top_logits))

    probs = F.softmax(logits / temperature, dim=-1)

    if min_p is not None:
        max_p = probs.max(-1, keepdim=True).values
        mask = probs >= (max_p * min_p)
        probs = probs * mask
        probs = probs / probs.sum(dim=-1, keepdim=True) # renormalize

    assert logits.shape == probs.shape
    return probs

@delegates(mk_proba_dist)
def soft_sampling_forward(
    model,
    input_ids, # tokens of shape (batch_size, seq_len)
    guidance_alpha, # guidance weighting -- 1 equivalent to discrete sampling
    **kwargs, # passed to mk_proba_dist
):
    "Single train step using soft sampling"
    assert 0 <= guidance_alpha <= 1
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    # cache
    past_key_values = None
    # position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)

    W_E = model.get_input_embeddings().weight

    loss = torch.tensor(0., device=device)
    embeds = W_E[input_ids[:, :1]]  # BOS shape: (batch_size, 1, d_model)
    tokens = [ to_cpu(input_ids[:, :1]) ]
    for t in range(1, seq_len):
        outputs = model(
            inputs_embeds=embeds,
            past_key_values=past_key_values,
            # position_ids=position_ids,
            use_cache=True
        )

        logits_t = outputs.logits[:, -1]
        past_key_values = outputs.past_key_values

        # loss
        loss_t = F.cross_entropy(logits_t, input_ids[:, t])
        loss += loss_t

        p_t = mk_proba_dist(logits_t, **kwargs)

        # discrete sample -- for logging
        next_tokens = torch.multinomial(p_t, 1) # (batch_size, 1)
        tokens.append(to_cpu(next_tokens))

        # soft sample
        next_emb_soft = p_t @ W_E      # soft sampling
        next_emb_gt = W_E[input_ids[:, t]] # guidance sampling

        next_embed = (
            guidance_alpha * next_emb_gt +
            (1 - guidance_alpha) * next_emb_soft
        )
        embeds = torch.cat([embeds, next_embed[:, None, :]], dim=1)
        # position_ids += 1

    tokens = torch.cat(tokens, dim=1)
    # normalize loss
    loss = loss / (batch_size * (seq_len - 1))
    return loss, tokens

prompts = [
    "the cat sat in the hat",
    "where did all the hats go?",
    "I am the rizzler",
]

batch = tokenize(prompts)
loss, tokens = soft_sampling_forward(
    model,
    input_ids=batch,
    guidance_alpha=0.,
    temperature=1.,
    min_p=None,
)
print(f"Loss: {loss.item():.4f}")

inp_tokens =  [decode(o) for o in batch]
outp_tokens = [decode(o) for o in tokens]

for inp, outp in zip(inp_tokens, outp_tokens):
    prev_toks = ""
    for i_t, o_t in zip(inp, outp):
        print(f"{repr(prev_toks)} : {repr(i_t)} / {repr(o_t)}")
        prev_toks += i_t
    print("---")

Loss: 27.4427
'' : '' / ''
'' : 'the' / 'Question'
'the' : ' cat' / 'Site'
'the cat' : ' sat' / 'British'
'the cat sat' : ' in' / 'Exception'
'the cat sat in' : ' the' / 'lene'
'the cat sat in the' : ' hat' / '">'
'the cat sat in the hat' : '' / '\n'
---
'' : '' / ''
'' : 'where' / 'Question'
'where' : ' did' / ':\n'
'where did' : ' all' / '男'
'where did all' : ' the' / 'Comment'
'where did all the' : ' hats' / 'Request'
'where did all the hats' : ' go' / 'De'
'where did all the hats go' : '?' / '",\n'
---
'' : '' / ''
'' : 'I' / '['
'I' : ' am' / 'Nr'
'I am' : ' the' / ' Tag'
'I am the' : ' r' / '_question'
'I am the r' : 'izz' / 'ster'
'I am the rizz' : 'ler' / 'isms'
'I am the rizzler' : '' / ' "/'
---


In [5]:
from typing import Callable
from transformers import Trainer, TrainingArguments
from dataclasses import dataclass, asdict, field

@dataclass
class ProbabilityDistributionArguments:
    temperature: float = 1.
    top_k: int | None = None
    min_p: float | None = None

    def __post_init__(self):
        assert self.temperature >= 0
        assert self.top_k is None or self.min_p > 0
        assert self.min_p is None or 0 <= self.min_p <= 1

@dataclass
class GuidanceAlphaScheduler:
   warmup_delay_ratio: float = 0.1  # Start warmup after this % of steps
   warmup_ratio: float = 0.7        # Warmup duration as % of steps
   min_alpha: float = 0.25          # Minimum guidance alpha value

   def get_scheduler(self, num_steps: int) -> Callable[[int], float]:
       "Scheduler decays from 1.0 to min_alpha"
       delay_end = int(num_steps * self.warmup_delay_ratio)
       warmup_end = int(num_steps * (self.warmup_delay_ratio + self.warmup_ratio))

       def scheduler(step: int) -> float:
           if step < delay_end:
               return 1.0
           if step < warmup_end:
               progress = (step - delay_end) / (warmup_end - delay_end)
               return 1.0 - ((1.0 - self.min_alpha) * progress)
           return self.min_alpha

       return scheduler

@dataclass
class SoftDecodingTrainingArguments(TrainingArguments):
    guidance_alpha_scheduler: GuidanceAlphaScheduler = field(default_factory=lambda: GuidanceAlphaScheduler())
    probability_distribution_argments: ProbabilityDistributionArguments = field(default_factory=lambda: ProbabilityDistributionArguments())

    def __post_init__(self):
        super().__post_init__()
        if isinstance(self.probability_distribution_argments, dict):
            self.probability_distribution_argments = ProbabilityDistributionArguments(**self.probability_distribution_argments)
        if isinstance(self.guidance_alpha_scheduler, dict):
            self.guidance_alpha_scheduler = GuidanceAlphaScheduler(**self.guidance_alpha_scheduler)


class SoftDecodingTrainer(Trainer):
    def __init__(
        self,
        model,
        args: SoftDecodingTrainingArguments,
        **kwargs,
    ):
        super().__init__(model, args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs["input_ids"]
        loss, predicted_tokens = soft_sampling_forward(
            model=model,
            input_ids=input_ids,
            guidance_alpha=self.guidance_alpha,
            **asdict(self.args.probability_distribution_argments),
        )
        return (
            (loss, {"predicted_tokens": predicted_tokens})
            if return_outputs
            else loss
        )

    def create_scheduler(self, num_training_steps: int, optimizer=None):
        self.guidance_alpha_scheduler = self.args.guidance_alpha_scheduler.get_scheduler(num_training_steps)
        return super().create_scheduler(num_training_steps, optimizer)

    def training_step(self, model, inputs):
        self.guidance_alpha = self.guidance_alpha_scheduler(self.state.global_step)
        if self.control.should_log: self.log("guidance_alpha", self.guidance_alpha, prog_bar=True)
        return super().training_step(model, inputs)


training_args = SoftDecodingTrainingArguments(
    output_dir="/tmp/soft-decoding-output",
    guidance_alpha_scheduler=dict(
        warmup_delay_ratio=0.1,
        warmup_ratio=0.7,
        min_alpha=0.3
    ),
    probability_distribution_argments=dict(
        temperature=1.0,
        top_k=None,
        min_p=0.2
    )
)
# trainer = SoftDecodingTrainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,
#     eval_dataset=eval_dataset,
# )
# trainer.train()

In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import GRPOTrainer, GRPOConfig

class SoftGRPOTrainer(GRPOTrainer):
    """
    Subclass the GRPOTrainer to override the generation step with a custom
    'soft sampling' method.
    """

    def generate_completions(self, prompts, num_completions=1, **kwargs):
        """
        Called by GRPOTrainer internally to produce completions for each prompt.
        Here, we override to do 'soft decoding' with manual forward passes.

        `prompts`: list of prompt strings or list of message dicts (depending on your format).
        `num_completions`: how many completions to produce per prompt (the 'G' in GRPO).

        Return format: A list of length (batch_size * num_completions), where each item
                       is either a string or a chat-style list of messages.
        """
        # Basic code: we’ll handle each prompt individually, produce num_completions for each.
        # This is naive. You might want to do batched decoding. For clarity, we do it prompt-by-prompt.
        completions = []
        for prompt in prompts:
            for _ in range(num_completions):
                completion_text = self._soft_decode_one(prompt, **kwargs)
                completions.append(completion_text)

        return completions

    def _soft_decode_one(self, prompt, max_length=128, temperature=1.0, guidance_alpha=0.0):
        """
        Single forward pass that does 'soft sampling' for one prompt.
        Here, guidance_alpha=0 => purely soft. guidance_alpha=1 => purely discrete teacher-forcing style.
        """

        # 1. Tokenize prompt
        if isinstance(prompt, list):
            # If using Chat format with messages, flatten them. Or handle them however you want.
            # For simplicity, assume prompt[-1]["content"] is your user content:
            prompt_str = prompt[-1]["content"]
        else:
            prompt_str = prompt

        # Get tokenizer from the trainer's model
        tokenizer = self.tokenizer
        model = self.model  # the policy model we are training (should be half or fp16)
        device = model.device

        input_ids = tokenizer.encode(prompt_str, return_tensors="pt").to(device)
        # We'll keep track of discrete tokens for logging/decoding:
        all_tokens = input_ids.clone()

        # 2. We'll gather the last hidden state or next-token logits in a loop:
        past_key_values = None
        # Convert the initial prompt tokens into embeddings:
        W_E = model.get_input_embeddings().weight  # shape [vocab_size, hidden_dim]
        embeds = W_E[input_ids[:, -1]]  # last token embedding. shape: [1, hidden_dim]

        # We have to manually feed the entire prefix or do some caching trick.
        # For simplicity, re-run from scratch each step. For efficiency, you'd do caching more carefully.

        # But let's do a simple approach: run the entire input_ids once for the prompt, then keep a loop.
        # We can do that by calling `model(..., use_cache=True)`.
        # Then on subsequent tokens, we only feed the newly appended embedding.

        # However, let's illustrate a simple step-by-step (inefficient) approach for clarity.

        # Step 1: encode the entire prompt in a normal forward pass:
        out = model(input_ids=input_ids, use_cache=True)
        past_key_values = out.past_key_values

        # Now we loop for next tokens
        for step in range(max_length):
            # Obtain next-token logits from the last forward pass
            # Usually, out.logits is shape [batch_size, seq_len, vocab_size].
            # We want the final step's logits.
            logits = out.logits[:, -1, :]  # shape [1, vocab_size]

            # 1) Build a distribution over the next token
            #    shape [1, vocab_size]
            probs = F.softmax(logits / temperature, dim=-1)

            # 2) Optionally sample a discrete token for logging
            #    shape [1, 1]
            next_token_id = torch.multinomial(probs, num_samples=1)

            # 3) "Soft" next embedding
            # shape [1, hidden_dim]
            emb_soft = probs @ W_E  # [1, vocab_size] * [vocab_size, hidden_dim]

            # 4) Possibly do a “guidance mix” with the ground-truth token (if training with supervised data),
            #    but in typical RL setting we don’t have that. So just use `emb_soft`.
            # next_emb = guidance_alpha * W_E[true_token_id] + (1 - guidance_alpha) * emb_soft
            next_emb = emb_soft

            # 5) If next_token_id is EOS, break
            if next_token_id.item() == tokenizer.eos_token_id:
                # append for final decode
                all_tokens = torch.cat([all_tokens, next_token_id], dim=-1)
                break

            # Otherwise, append the discrete token to all_tokens for logging/decoding
            all_tokens = torch.cat([all_tokens, next_token_id], dim=-1)

            # 6) Feed the new embedding into the model:
            out = model(inputs_embeds=next_emb.unsqueeze(1),  # shape [1, 1, hidden_dim]
                        past_key_values=past_key_values,
                        use_cache=True)
            past_key_values = out.past_key_values

        # 7) Convert final tokens to string
        completion_str = tokenizer.decode(all_tokens[0, :], skip_special_tokens=True)
        return completion_str


#
# Then you can use this `SoftGRPOTrainer` similarly to how you would use GRPOTrainer:
#

def train_soft_grpo():
    # Suppose you have your dataset & reward functions
    from datasets import load_dataset
    dataset = load_dataset("trl-lib/tldr", split="train")  # trivial example

    def reward_len(completions, **kwargs):
        return [-abs(20 - len(c)) for c in completions]

    # Basic config
    training_args = GRPOConfig(
        output_dir="my-soft-grpo-checkpoints",
        logging_steps=10,
        use_vllm=False,  # must disable vLLM for custom "soft" generation
    )

    trainer = SoftGRPOTrainer(
        model="Qwen/Qwen2-0.5B-Instruct",
        reward_funcs=reward_len,
        args=training_args,
        train_dataset=dataset,
    )

    trainer.train()

#
# That’s it: now your completions come from `_soft_decode_one()` instead of the standard discrete generate().
#

