In [None]:
import torch
import seaborn as sns
import pandas as pd

# Prefer the GPU when one is present, but fall back to the CPU so the notebook
# still runs top to bottom on machines without CUDA.
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
# Fix the RNG seed so that Restart & Run All reproduces the same samples and plots.
torch.manual_seed(0)

In [None]:
# Target-model logits over a toy 3-token vocabulary, and their softmax.
logits = torch.tensor([2.5, 2, 3])
probs = torch.softmax(logits, dim=-1)
print(probs)

# Bar chart of the exact probabilities (copied to host memory for plotting).
ax = sns.barplot(probs.cpu())
ax.grid(axis="y")

In [None]:
# Draw 1000 token ids from `probs` and compare the empirical frequencies
# against the exact probabilities printed above.
samples = torch.multinomial(probs, num_samples=1000, replacement=True)
# `minlength` keeps one frequency per token even if the highest id is never
# drawn, so the printed vector always lines up with `probs`.
print(torch.bincount(samples, minlength=probs.numel()) / len(samples))
ax = sns.histplot(samples.cpu(), stat="probability")
ax.grid(axis="y")

In [None]:
# Draft ("speculative") distribution over the same 3-token vocabulary.
spec_logits = torch.tensor([3.0, 2.0, 1.0])
spec_probs = torch.softmax(spec_logits, dim=-1)
print(spec_probs)
# 1000 token ids proposed by the draft distribution.
spec_samples = torch.multinomial(spec_probs, num_samples=1000, replacement=True)

In [None]:
# Empirical frequencies of the draft proposals.
ax = sns.histplot(data=spec_samples.cpu(), stat="probability")
ax.grid(axis="y")

In [None]:
# def rejection_sample(probs: torch.Tensor, spec_probs: torch.Tensor, idx: torch.Tensor):
#     """
#     Target model distribution: q(x)
#     Draft model distribution: p(x)
#     """
#     q = probs[idx]
#     p = spec_probs[idx]
#     r = torch.rand(1)
#     if r < torch.clamp(q / p, max=1.0):
#         return idx
#     new_p = torch.clamp(probs - spec_probs, min=0.0)
#     return torch.multinomial(new_p, num_samples=1, replacement=True)[0]


def rejection_sample(probs: torch.Tensor, spec_probs: torch.Tensor, idxs: torch.Tensor):
    """
    Vectorized rejection sampling of draft tokens against a target distribution.

    Target model distribution: q(x)  (``probs``)
    Draft model distribution: p(x)   (``spec_probs``)

    Each draft token ``x`` in ``idxs`` is kept when a uniform draw falls below
    ``min(1, q(x) / p(x))``. A rejected token is replaced by a draw from the
    residual pseudo-distribution ``max(q - p, 0)`` (``torch.multinomial``
    normalizes it internally).

    Args:
        probs: 1-D tensor of target probabilities q over the vocabulary.
        spec_probs: 1-D tensor of draft probabilities p over the same vocabulary.
        idxs: integer tensor of draft token ids. Any shape is accepted,
            including empty and 0-dim tensors.

    Returns:
        A tensor shaped like ``idxs`` holding the kept / resampled token ids.
    """
    n = idxs.numel()
    if n == 0:
        # torch.multinomial rejects num_samples == 0, and there is nothing to sample.
        return idxs.clone()

    qs = probs[idxs]
    ps = spec_probs[idxs]
    # Draw the acceptance noise on the same device as the distributions.
    rs = torch.rand(idxs.shape, device=probs.device)
    keep_mask = rs < torch.clamp(qs / ps, max=1.0)

    residual = torch.clamp(probs - spec_probs, min=0.0)  # pseudo-probability
    # If q <= p everywhere (e.g. q == p) the residual has no mass and
    # torch.multinomial would raise. Fall back to q itself so the call stays
    # valid; torch.where keeps this selection on-device (no host sync).
    resample_dist = torch.where(residual.sum() > 0, residual, probs)
    new_samples = torch.multinomial(resample_dist, num_samples=n, replacement=True)
    return torch.where(keep_mask, idxs, new_samples.reshape(idxs.shape))

In [None]:
# Run the draft proposals through rejection sampling against the target distribution.
# Earlier per-sample form, kept for reference:
# rej_samples = torch.tensor([rejection_sample(probs, spec_probs, x) for x in spec_samples])
rej_samples = rejection_sample(probs=probs, spec_probs=spec_probs, idxs=spec_samples)
ax = sns.histplot(data=rej_samples.cpu(), stat="probability")
ax.grid(axis="y")

# Gumbel-Max Trick
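Sampling from the multinomial (categorical) distribution defined by `softmax(logits)` is equivalent to taking the argmax over the logits plus i.i.d. standard Gumbel noise: $\arg\max_i (\ell_i + g_i)$, where $g_i = -\log(-\log u_i)$ and $u_i \sim \mathrm{Uniform}(0, 1)$.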

In [None]:
def gumbel_sample(logits: torch.Tensor, n: int):
    """
    Draw ``n`` token ids from 1-D ``logits`` with the Gumbel-max trick.

    Standard Gumbel noise is built from uniform noise as ``-log(-log(u))``,
    added to the logits, and the argmax of each perturbed row is returned.

    Returns:
        Tensor of ``n`` indices into ``logits``.
    """
    vocab = len(logits)
    uniform = torch.rand((n, vocab))
    exponential = -torch.log(uniform)  # -log(u)
    gumbel = -torch.log(exponential)  # -log(-log(u))
    perturbed = logits + gumbel  # [n, vocab] via broadcasting
    return perturbed.argmax(dim=-1)


# Histogram of 1000 Gumbel-max draws from `logits`.
gumbel_samples = gumbel_sample(logits=logits, n=1000)
ax = sns.histplot(data=gumbel_samples.cpu(), stat="probability")
ax.grid(axis="y")

# Fused MM-Sample
We now attempt to sample from the logits without ever materializing the full logits tensor.
We compute the logits incrementally, one block of the vocabulary at a time, and as we do, we keep track of the running Gumbel-max value and its index.

In [None]:
vocab_size = 100  # V
hidden_size = 10  # D

# Two synthetic logit rows: one increasing and one decreasing over the vocabulary.
logits1 = torch.arange(-vocab_size / 2, vocab_size / 2).unsqueeze(0)  # [1, V]
logits2 = torch.arange(vocab_size / 2, -vocab_size / 2, step=-1).unsqueeze(0)  # [1, V]
logits = torch.vstack((logits1, logits2))  # [seq_len, V]
seq_len = logits.size(0)

# Construct hidden states H and weights W that reproduce these logits:
#   W @ H = L.T   ->   W = L.T @ pinv(H)
# H takes its first rows from the SVD of the logits (U.T) and is padded with
# random rows up to D; W then follows from the pseudoinverse of H.
# (there are many ways to do this, this is just one)
U, S, Vt = torch.linalg.svd(logits, full_matrices=False)
hidden_states = torch.cat(  # [D, seq_len]
    (
        U.T,
        torch.rand((hidden_size - seq_len, seq_len)),  # random padding rows
    ),
    dim=0,
)
weights = logits.T @ torch.linalg.pinv(hidden_states)  # [V, D]
# Sanity check: the factorization reproduces the logits.
assert torch.allclose(weights @ hidden_states, logits.T)

## Baseline: PyTorch Sampling

In [None]:
def sample(
    weights: torch.Tensor,
    hidden_states: torch.Tensor,
    num_samples: int,
    temperature: float,
):
    """
    Baseline sampler: materialize all logits, softmax them, then sample.

    Args:
        weights: projection matrix; callers in this notebook pass [V, D].
        hidden_states: hidden states; callers in this notebook pass [D, seq_len].
        num_samples: number of token ids to draw per sequence position.
        temperature: divisor applied to the (max-shifted) logits before softmax.

    Returns:
        ``(samples, probs)`` where ``samples`` has one row per column of
        ``weights @ hidden_states`` and ``num_samples`` columns, and ``probs``
        is the softmax over dim 0 of the scaled logits.
    """
    raw_logits = weights @ hidden_states  # [V, seq_len]
    # Shift every column so that its maximum is 0 before scaling.
    column_max = raw_logits.max(dim=0, keepdim=True).values
    shifted = raw_logits - column_max
    probs = torch.softmax(shifted / temperature, dim=0)  # [V, seq_len]
    # multinomial samples along the last dim, hence the transpose to [seq_len, V].
    samples = torch.multinomial(probs.T, num_samples=num_samples, replacement=True)
    return samples, probs


def plot_samples(samples: torch.Tensor, seq_len: int, num_samples: int):
    """
    Plot one histogram of sampled token ids per sequence position.

    ``samples`` is flattened and paired with a "seq" label that repeats each
    position index ``num_samples`` times, so the flattened order is expected to
    be position-major (callers here pass [seq_len, num_samples]).
    """
    seq_labels = []
    for seq in range(seq_len):
        seq_labels.extend([seq] * num_samples)
    frame = pd.DataFrame(
        {
            "sample": samples.flatten().cpu(),
            "seq": seq_labels,
        }
    )
    sns.histplot(frame, x="sample", hue="seq", bins=100)


num_samples = 1000
# Baseline run: full logits, softmax, then torch.multinomial.
samples, probs = sample(
    weights=weights,
    hidden_states=hidden_states,
    num_samples=num_samples,
    temperature=5,
)  # samples: [seq_len, num_samples]
plot_samples(samples, seq_len=seq_len, num_samples=num_samples)

## Fused PyTorch Incremental Sampling

In [None]:
def incremental_sample_pt(
    weights: torch.Tensor,
    hidden_states: torch.Tensor,
    num_samples: int,
    temperature: float,
    block_size: int = 8,
):
    """
    Gumbel-max sampling that never materializes the full logits matrix.

    The vocabulary is processed in blocks of ``block_size`` rows of ``weights``.
    For each block the logits are computed, perturbed with fresh standard
    Gumbel noise, and reduced to a per-(sample, position) maximum. A running
    maximum and its global vocabulary index are kept across blocks, so the
    final index is the argmax over the whole vocabulary.

    Args:
        weights: projection matrix of shape [V, D].
        hidden_states: hidden states of shape [D, seq_len].
        num_samples: number of independent draws per sequence position.
        temperature: divisor applied to the logits before adding the noise.
        block_size: number of vocabulary rows processed per step (default 8).

    Returns:
        Long tensor of shape [seq_len, num_samples] with sampled token ids.

    Raises:
        ValueError: if ``block_size`` is not positive or the vocabulary is empty.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    V = weights.shape[0]
    seq_len = hidden_states.shape[1]
    if V == 0:
        raise ValueError("cannot sample from an empty vocabulary")

    device = weights.device
    # Running Gumbel maximum and the vocabulary index that achieved it.
    gumbel_max = torch.full((num_samples, seq_len), float("-inf"), device=device)
    gumbel_max_idx = torch.zeros((num_samples, seq_len), dtype=torch.long, device=device)

    # Plain integer stepping: no device tensor (and no host sync) is needed to
    # count blocks; slicing clips the final, possibly shorter, block.
    for idx_from in range(0, V, block_size):
        w_blk = weights[idx_from : idx_from + block_size]  # [<=block_size, D]
        logits_blk = w_blk @ hidden_states / temperature  # [<=block_size, seq_len]
        unif_noise = torch.rand((num_samples, *logits_blk.shape), device=device)
        gumbel_noise = -(-unif_noise.log()).log()
        # Reduce over the block (vocabulary) axis -> [num_samples, seq_len].
        blk_max, blk_argmax = torch.max(logits_blk + gumbel_noise, dim=1)

        replace_mask = blk_max > gumbel_max
        gumbel_max = torch.where(replace_mask, blk_max, gumbel_max)
        # Convert the block-local argmax to a global vocabulary index.
        gumbel_max_idx = torch.where(replace_mask, idx_from + blk_argmax, gumbel_max_idx)
    return gumbel_max_idx.T


# Same sampling problem, solved block by block over the vocabulary.
samples2 = incremental_sample_pt(
    weights=weights,
    hidden_states=hidden_states,
    num_samples=num_samples,
    temperature=5,
)
plot_samples(samples2, seq_len=seq_len, num_samples=num_samples)