In [None]:
import torch
import seaborn as sns

In [None]:
# Fix the RNG once, up front, so every sampling cell below is reproducible
# under Restart Kernel -> Run All.
torch.manual_seed(0)

# Target-model logits over a toy 3-token vocabulary, and the resulting q(x).
logits = torch.tensor([2.5, 2, 3])
probs = torch.nn.functional.softmax(logits, dim=-1)
print(probs)

# Pass explicit x/y so seaborn draws one bar per token id, independent of how
# a given seaborn version interprets a bare 1-D array passed as `data`.
token_ids = torch.arange(len(probs))
ax = sns.barplot(x=token_ids.numpy(), y=probs.numpy())
ax.set(title="Target distribution q(x)", xlabel="token id", ylabel="probability")
ax.grid(axis="y")

In [None]:
# Empirical check: draw 1000 i.i.d. samples from q(x) and compare the observed
# frequencies with `probs`.
samples = torch.multinomial(probs, num_samples=1000, replacement=True)
# minlength keeps one entry per token even if some token was never drawn.
print(torch.bincount(samples, minlength=len(probs)) / len(samples))
# discrete=True centers one bin on each integer token id instead of auto-binning.
ax = sns.histplot(samples, discrete=True, stat="probability")
ax.set(title="Samples from q(x)", xlabel="token id")
ax.grid(axis="y")

In [None]:
# Draft ("speculative") model: a different distribution p(x) over the same tokens.
spec_logits = torch.tensor([3.0, 2.0, 1.0])
spec_probs = spec_logits.softmax(dim=-1)
print(spec_probs)
# 1000 draft samples drawn from p(x); these are corrected further below.
spec_samples = spec_probs.multinomial(1000, replacement=True)

In [None]:
# Histogram of the draft-model samples; discrete=True gives one bar per token id.
ax = sns.histplot(spec_samples, discrete=True, stat="probability")
ax.set(title="Draft samples from p(x)", xlabel="token id")
ax.grid(axis="y")

In [None]:
def rejection_sample(probs: torch.Tensor, spec_probs: torch.Tensor, idxs: torch.Tensor) -> torch.Tensor:
    """Correct draft-model samples so they follow the target distribution.

    Vectorized speculative-sampling rejection step. Let the target model
    distribution be q(x) (``probs``) and the draft model distribution be p(x)
    (``spec_probs``). Each draft sample x is handled independently:

    - Keep x with probability min(1, q(x) / p(x)).
    - Otherwise, replace it with a fresh sample from the residual
      distribution norm(max(0, q - p)).

    Provided ``idxs`` were drawn from p, the returned samples are distributed
    exactly as q.

    Args:
        probs: 1-D target distribution q over the vocabulary.
        spec_probs: 1-D draft distribution p, same length as ``probs``.
        idxs: 1-D integer tensor of token ids, assumed drawn from ``spec_probs``.

    Returns:
        A tensor the same length as ``idxs``. Each entry is the draft token where
        the acceptance test passed, otherwise a token resampled from the residual.
    """
    if idxs.numel() == 0:
        # torch.multinomial rejects num_samples=0, and there is nothing to correct.
        return idxs.clone()

    qs = probs[idxs]
    ps = spec_probs[idxs]  # > 0 as long as idxs really were sampled from spec_probs
    accept_prob = torch.clamp(qs / ps, max=1.0)
    keep_mask = torch.rand(len(idxs), device=idxs.device) < accept_prob

    # Residual max(0, q - p). torch.multinomial accepts unnormalized weights.
    residual = torch.clamp(probs - spec_probs, min=0.0)
    if residual.sum() > 0:
        new_samples = torch.multinomial(residual, num_samples=len(idxs), replacement=True)
    else:
        # q <= p everywhere (e.g. q == p). The acceptance probability is then 1
        # up to rounding, so a rejection essentially never happens. However,
        # torch.multinomial raises on all-zero weights, so fall back to the draft
        # tokens instead.
        new_samples = idxs
    return torch.where(keep_mask, idxs, new_samples)

In [None]:
# Correct the draft samples with the target distribution. The resulting
# frequencies should match q(x) (`probs`), not p(x) (`spec_probs`).
rej_samples = rejection_sample(probs, spec_probs, spec_samples)
print(torch.bincount(rej_samples, minlength=len(probs)) / len(rej_samples))
ax = sns.histplot(rej_samples, discrete=True, stat="probability")
ax.set(title="Rejection-sampled tokens (should match q(x))", xlabel="token id")
ax.grid(axis="y")

# Gumbel-Max Trick
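Sampling from the multinomial (categorical) distribution $\mathrm{softmax}(z)$ defined by logits $z$ is equivalent to taking $\arg\max_i \,(z_i + g_i)$, where $g_i = -\log(-\log u_i)$ with $u_i \sim \mathcal{U}(0, 1)$ are i.i.d. standard Gumbel noise.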

In [None]:
def gumbel_sample(logits: torch.Tensor, n: int) -> torch.Tensor:
    """Draw ``n`` i.i.d. samples from softmax(logits) via the Gumbel-max trick.

    Independent standard Gumbel noise g = -log(-log(u)), u ~ U(0, 1), is added
    to every logit, and the argmax is taken per row.

    Args:
        logits: 1-D tensor of unnormalized log-probabilities.
        n: number of samples to draw.

    Returns:
        A 1-D int64 tensor of ``n`` sampled indices into ``logits``.
    """
    # torch.rand samples from [0, 1). If u == 0, then -log(u) == inf and the
    # noise becomes -inf, silently excluding that index from this draw.
    # Clamping to the smallest positive normal float keeps the noise finite.
    u = torch.rand((n, len(logits)), device=logits.device)
    u = u.clamp_min(torch.finfo(u.dtype).tiny)
    gumbel_noise = -torch.log(-torch.log(u))
    return torch.argmax(logits + gumbel_noise, dim=-1)

# Sanity check: the Gumbel-max samples should reproduce `probs` (= softmax(logits)).
gumbel_samples = gumbel_sample(logits, n=1000)
print(torch.bincount(gumbel_samples, minlength=len(logits)) / len(gumbel_samples))
ax = sns.histplot(gumbel_samples, discrete=True, stat="probability")
ax.set(title="Gumbel-max samples (should match q(x))", xlabel="token id")
ax.grid(axis="y")

# Fused MM-Sample
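Next, we sample from the logits without ever materializing the full logits vector.
Instead, we compute the logits incrementally. As we go, we keep track of the running Gumbel-max index: the index with the largest logit-plus-Gumbel-noise value seen so far. By the Gumbel-max trick, the final index is an exact sample from $\mathrm{softmax}(\text{logits})$.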

In [None]:
# Toy weight matrix of shape (vocab_size, hidden_size) for the fused experiment.
# Presumably this is the output projection that maps a hidden state to logits
# (TODO confirm once the fused sampler below uses it).
vocab_size = 100
hidden_size = 10
weights = torch.rand((vocab_size, hidden_size))
assert weights.shape == (vocab_size, hidden_size)
# Preview only: torch would print all 1000 values of the full matrix.
weights[:5]