# Soft Sampling

In [1]:
import math
import torch
import torch.nn.functional as F
from torch import Tensor, tensor
from jaxtyping import Float, Int
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.set_grad_enabled(False); # disable backprop

In [11]:
local_files_only = True
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print("device:", device)

# model_name = "gpt2"
model_name = "meta-llama/Llama-3.2-1B"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    local_files_only=local_files_only,
    # attn_implementation="flash_attention_2",
).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=local_files_only)
tokenizer.pad_token = '<|finetune_right_pad_id|>'

if model_name == "gpt2":
    context_length = model.config.n_ctx
    tokenizer.pad_token = tokenizer.eos_token
else: # llama
    context_length = model.config.max_position_embeddings

device: mps


In [12]:
def tokenize(inputs) -> Int[Tensor, "bs seq"]:
    return tokenizer(
        inputs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=context_length
    )["input_ids"].to(device)
    
def decode(tokens) -> str:
    return tokenizer.batch_decode(tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)

prompts = ["the cat sat in the hat", "hello darkness my old friend", "soft sampling smosh sampling"]
tokens = tokenize(prompts)
print("Tokenized:")
print(tokens)
print("Decode:")
decode(tokens)

Tokenized:
tensor([[128000,   1820,   8415,   7731,    304,    279,   9072],
        [128000,  15339,  27394,    856,   2362,   4333, 128004],
        [128000,   3594,  25936,   1554,   9451,  25936, 128004]],
       device='mps:0')
Decode:


['the cat sat in the hat',
 'hello darkness my old friend',
 'soft sampling smosh sampling']

## Mock Training

In [100]:
import torch
import torch.nn.functional as F
from fastcore.meta import delegates

def to_cpu(t): return t.detach().cpu()

def mk_proba_dist(
    logits, # (batch_size, d_vocab)
    temperature=1.0,
    top_k=None,
    min_p=None,
):
    "Create probability distribution from logits"
    batch_size, d_vocab = logits.shape
    device = logits.device
    
    if top_k is not None:
        top_logits, top_idxs = logits.topk(top_k, dim=-1)
        logits = (torch.full_like(logits, float('-inf'))
                       .scatter_(1, top_idxs, top_logits))

    probs = F.softmax(logits / temperature, dim=-1)

    if min_p is not None:
        max_p = probs.max(-1, keepdim=True).values
        mask = probs >= (max_p * min_p)
        probs = probs * mask
        probs = probs / probs.sum(dim=-1, keepdim=True) # renormalize

    assert logits.shape == probs.shape
    return probs

@delegates(mk_proba_dist)
def soft_sampling_forward(
    model,
    input_ids, # tokens of shape (batch_size, seq_len)
    guidance_alpha, # guidance weighting -- 1 equivalent to discrete sampling
    **kwargs, # passed to mk_proba_dist
):
    "Single train step using soft sampling"
    assert 0 <= guidance_alpha <= 1
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    # cache
    past_key_values = None
    # position_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=device)

    W_E = model.get_input_embeddings().weight

    loss = torch.tensor(0., device=device)
    embeds = W_E[input_ids[:, :1]]  # BOS shape: (batch_size, 1, d_model)
    tokens = [ to_cpu(input_ids[:, :1]) ]
    for t in range(1, seq_len):
        outputs = model(
            inputs_embeds=embeds,
            past_key_values=past_key_values,
            # position_ids=position_ids,
            use_cache=True
        )

        logits_t = outputs.logits[:, -1]
        past_key_values = outputs.past_key_values

        p_t = mk_proba_dist(logits_t, **kwargs)

        # loss
        loss_t = F.cross_entropy(p_t, input_ids[:, t])
        loss += loss_t

        # discrete sample -- for logging
        next_tokens = torch.multinomial(p_t, 1) # (batch_size, 1)
        tokens.append(to_cpu(next_tokens))

        # soft sample
        next_emb_soft = p_t @ W_E      # soft sampling
        next_emb_gt = W_E[input_ids[:, t]] # guidance sampling

        next_embed = (
            guidance_alpha * next_emb_gt +
            (1 - guidance_alpha) * next_emb_soft
        )
        embeds = torch.cat([embeds, next_embed[:, None, :]], dim=1)
        # position_ids += 1

    tokens = torch.cat(tokens, dim=1)
    # normalize: mean batch_size, sum seq_len
    loss = loss / batch_size
    return loss, tokens

prompts = [
    "the cat sat in the hat",
    "where did all the hats go?",
    "I am the rizzler",
]

batch = tokenize(prompts)
loss, tokens = soft_sampling_forward(
    model,
    input_ids=batch,
    guidance_alpha=0.,
    temperature=1.,
    min_p=None,
)
print(f"Loss: {loss.item():.4f}")

inp_tokens =  [decode(o) for o in batch]
outp_tokens = [decode(o) for o in tokens]

for inp, outp in zip(inp_tokens, outp_tokens):
    prev_toks = ""
    for i_t, o_t in zip(inp, outp):
        print(f"{repr(prev_toks)} : {repr(i_t)} / {repr(o_t)}")
        prev_toks += i_t
    print("---")

Loss: 27.4418
'' : '' / ''
'' : 'the' / 'Question'
'the' : ' cat' / ']:\n\n'
'the cat' : ' sat' / 'Question'
'the cat sat' : ' in' / '.mark'
'the cat sat in' : ' the' / 'ing'
'the cat sat in the' : ' hat' / 'ose'
'the cat sat in the hat' : '' / 'Random'
---
'' : '' / ''
'' : 'where' / 'When'
'where' : ' did' / '\n'
'where did' : ' all' / 'Question'
'where did all' : ' the' / '\n\n'
'where did all the' : ' hats' / 'on'
'where did all the hats' : ' go' / ' ='
'where did all the hats go' : '?' / '\n\n'
---
'' : '' / ''
'' : 'I' / 'Question'
'I' : ' am' / '_'
'I am' : ' the' / 'Question'
'I am the' : ' r' / '_S'
'I am the r' : 'izz' / ' C'
'I am the rizz' : 'ler' / 'uder'
'I am the rizzler' : '' / 'E'
---


In [14]:
from typing import Callable
from transformers import Trainer, TrainingArguments
from dataclasses import dataclass, asdict, field

@dataclass
class ProbabilityDistributionArguments:
    temperature: float = 1.
    top_k: int | None = None
    min_p: float | None = None
    
    def __post_init__(self):
        assert self.temperature >= 0
        assert self.top_k is None or self.min_p > 0
        assert self.min_p is None or 0 <= self.min_p <= 1

@dataclass
class GuidanceAlphaScheduler:
    guidance_alpha_warmup_delay_ratio: float = 0.1
    guidance_alpha_warmup_ratio: float = 0.7
    guidance_alpha_max: float = 0.7
    
    def get_scheduler(self, num_training_steps: int) -> Callable[[int], float]:
        warmup_delay_steps = int(num_training_steps * self.guidance_alpha_warmup_delay_ratio)
        warmup_end_steps = int(num_training_steps * (self.guidance_alpha_warmup_delay_ratio + self.guidance_alpha_warmup_ratio))
        
        def scheduler(step: int) -> float:
            if step < warmup_delay_steps:
                return 0.0
            elif step < warmup_end_steps:
                # Linear warmup from 0 to guidance_alpha_max
                warmup_progress = (step - warmup_delay_steps) / (warmup_end_steps - warmup_delay_steps)
                return self.guidance_alpha_max * warmup_progress
            else:
                return self.guidance_alpha_max
            
        return scheduler

@dataclass
class SoftDecodingTrainingArguments(TrainingArguments):
    guidance_alpha_scheduler: GuidanceAlphaScheduler = field(default_factory=lambda: GuidanceAlphaScheduler())
    probability_distribution_argments: ProbabilityDistributionArguments = field(default_factory=lambda: ProbabilityDistributionArguments())

    def __post_init__(self):
        super().__post_init__()
        if isinstance(self.probability_distribution_argments, dict):
            self.probability_distribution_argments = ProbabilityDistributionArguments(**self.probability_distribution_argments)
        if isinstance(self.guidance_alpha_scheduler, dict):
            self.guidance_alpha_scheduler = GuidanceAlphaScheduler(**self.guidance_alpha_scheduler)
    

class SoftDecodingTrainer(Trainer):
    def __init__(
        self,
        model,
        args: SoftDecodingTrainingArguments,
        **kwargs,
    ):
        super().__init__(model, args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs["input_ids"]
        loss, predicted_tokens = soft_sampling_forward(
            model=model,
            input_ids=input_ids,
            guidance_alpha=self.guidance_alpha,
            **asdict(self.args.probability_distribution_argments),
        )
        return (
            (loss, {"predicted_tokens": predicted_tokens})
            if return_outputs
            else loss
        )
            
    def create_scheduler(self, num_training_steps: int, optimizer=None):
        self.guidance_alpha_scheduler = self.args.guidance_alpha_scheduler.get_scheduler(num_training_steps)
        return super().create_scheduler(num_training_steps, optimizer)

    def training_step(self, model, inputs):
        self.guidance_alpha = self.guidance_alpha_scheduler(self.state.global_step)
        if self.control.should_log: self.log("guidance_alpha", self.guidance_alpha, prog_bar=True)
        return super().training_step(model, inputs)


training_args = SoftDecodingTrainingArguments(
    output_dir="/tmp/soft-decoding-output",
    guidance_alpha_scheduler=dict(
        guidance_alpha_warmup_delay_ratio=0.1,
        guidance_alpha_warmup_ratio=0.7,
        guidance_alpha_max=0.7
    ),
    probability_distribution_argments=dict(
        temperature=1.0,
        top_k=None,
        min_p=0.2
    )
)
# trainer = SoftDecodingTrainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,
#     eval_dataset=eval_dataset,
# )
# trainer.train()