In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import seaborn as sns
import pandas as pd
import os

# Make CUDA the default device so tensors created without an explicit
# `device=` argument in the cells below are allocated on the GPU.
torch.set_default_device("cuda")

# Debugging aid (disabled): uncomment to force synchronous CUDA kernel launches.
# NOTE(review): presumably has to be set before CUDA is initialised to take effect — confirm.
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Gumbel-Max Trick
Sampling from a categorical (multinomial) distribution is equivalent to taking the argmax over the logits plus i.i.d. standard Gumbel noise.

In [None]:
import matplotlib.pyplot as plt

# Toy three-class example: unnormalised scores and the categorical
# distribution obtained from them via softmax.
logits = torch.tensor([2.5, 2.0, 3.0])
probs = torch.softmax(logits, dim=-1)


def gumbel_sample(logits: torch.Tensor, n: int):
    """Draw ``n`` categorical samples from ``logits`` using the Gumbel-max trick.

    Each draw adds independent standard Gumbel noise ``-log(-log(u))``,
    ``u ~ Uniform(0, 1)``, to the logits and returns the argmax over the
    class (last) dimension.

    Args:
        logits: Unnormalised log-probabilities. The last dimension indexes
            the classes; any leading dimensions are treated as batch
            dimensions and sampled independently.
        n: Number of independent draws.

    Returns:
        Integer tensor of class indices with shape
        ``(n, *logits.shape[:-1])``, i.e. ``(n,)`` for 1-D ``logits``.

    Raises:
        ValueError: If ``logits`` is a 0-dimensional tensor.
    """
    if logits.dim() == 0:
        raise ValueError("logits must have at least one dimension (the class dimension)")

    # Create the noise on the same device as the logits so the addition below
    # works regardless of the current default device.
    uniform = torch.rand((n, *logits.shape), device=logits.device)
    # torch.rand samples from [0, 1); an exact 0 would turn into -inf noise
    # after the double log, so clamp to the smallest positive normal value.
    uniform = uniform.clamp_min(torch.finfo(uniform.dtype).tiny)
    gumbel_noise = -(-uniform.log()).log()
    return torch.argmax(logits + gumbel_noise, dim=-1)


fig, axs = plt.subplots(1, 2, figsize=(10, 4))

# Left panel: exact softmax probability of each class.
ax = sns.barplot(probs.cpu().numpy(), ax=axs[0])
ax.grid(axis="y")
ax.set(title="Barplot of probs", xlabel="class index", ylabel="probability")

# Right panel: empirical class frequencies of the Gumbel-max samples.
gumbel_samples = gumbel_sample(logits, n=1000)
ax = sns.histplot(
    gumbel_samples.cpu().numpy(),
    stat="probability",
    discrete=True,  # one bar centred on each integer class index, like the barplot
    shrink=0.8,
    ax=axs[1],
)
ax.set_xticks(range(len(probs)))
ax.grid(axis="y")
ax.set(title="Histogram of gumbel samples", xlabel="class index")

# Fused MM-Sample
We now attempt to sample from the logits without materializing them.
We compute the logits incrementally and, as we do, keep track of the running Gumbel-max index (the argmax of logits plus Gumbel noise over everything computed so far).

In [None]:
vocab_size = 100  # V
hidden_size = 10  # D

# Two target logit rows: the first ramps up across the vocabulary, the second
# ramps down.
half_vocab = vocab_size / 2
logits1 = torch.arange(-half_vocab, half_vocab).unsqueeze(0)  # [1, V]
logits2 = torch.arange(half_vocab, -half_vocab, step=-1).unsqueeze(0)  # [1, V]
logits = torch.vstack((logits1, logits2))  # [seq_len, V]
hl_seq_len = logits.size(0)

# Construct hidden states H and weights W such that W @ H == L.T:
#   * H stacks U.T (from the SVD of L) on top of random padding rows,
#   * W = L.T @ H⁺, where H⁺ is the pseudoinverse of H.
# This is only one of many valid constructions.
U, S, Vt = torch.linalg.svd(logits, full_matrices=False)
padding = torch.rand(hidden_size - hl_seq_len, hl_seq_len)
hidden_states = torch.cat((U.T, padding), dim=0)  # [D, seq_len]
weights = torch.matmul(logits.T, torch.linalg.pinv(hidden_states))  # [V, D]
assert torch.allclose(torch.matmul(weights, hidden_states), logits.T)

In [None]:
# Cast both matmul operands to bfloat16 for the sampling cells below.
weights = weights.to(torch.bfloat16)
hidden_states = hidden_states.to(torch.bfloat16)

## Baseline: PyTorch Sampling

In [None]:
from fused_mm_sampling import sample


def plot_samples(samples: torch.Tensor, seq_len: int, num_samples: int, bins: int = 100):
    """Plot a histogram of sampled token indices, coloured by sequence position.

    Args:
        samples: Tensor holding ``seq_len * num_samples`` sampled indices.
            It is flattened in row-major order, so the first ``num_samples``
            entries are labelled as position 0, the next ``num_samples`` as
            position 1, and so on.
        seq_len: Number of sequence positions.
        num_samples: Number of samples per sequence position.
        bins: Number of histogram bins (defaults to 100).

    Raises:
        ValueError: If ``samples`` does not contain exactly
            ``seq_len * num_samples`` elements.
    """
    # Convert explicitly to NumPy instead of relying on pandas to coerce a
    # torch tensor into a numeric column.
    flat_samples = samples.detach().flatten().cpu().numpy()
    expected = seq_len * num_samples
    if flat_samples.size != expected:
        raise ValueError(
            f"samples has {flat_samples.size} elements, "
            f"expected seq_len * num_samples = {expected}"
        )

    # Position label for every flattened sample (matches row-major flattening).
    seq_labels = [seq for seq in range(seq_len) for _ in range(num_samples)]
    df = pd.DataFrame({"sample": flat_samples, "seq": seq_labels})
    ax = sns.histplot(df, x="sample", hue="seq", bins=bins)
    ax.grid(axis="y", alpha=0.5)


num_samples = 1000
temperature = 5

# Baseline draw; return_probs=True additionally hands back the probabilities.
samples, probs = sample(
    weights, hidden_states, num_samples=num_samples, temperature=temperature, return_probs=True
)  # samples: [seq_len, num_samples]
plot_samples(samples, seq_len=hl_seq_len, num_samples=num_samples)

## Fused PyTorch Incremental Sampling

In [None]:
from fused_mm_sampling import incremental_sample_pt

# Draw the same number of samples with the incremental PyTorch implementation
# and plot them with the same helper, for visual comparison with the baseline.
# NOTE(review): the positional order (weights, hidden_states, num_samples,
# temperature) is taken from this call site — confirm against fused_mm_sampling.
samples2 = incremental_sample_pt(weights, hidden_states, num_samples, temperature)
plot_samples(samples2, hl_seq_len, num_samples)

# Triton

In [None]:
from fused_mm_sampling import fused_sample_triton

# Same experiment with the Triton implementation, plotted with the same helper.
# NOTE(review): seed=111 presumably makes the draw reproducible — confirm
# against fused_sample_triton.
samples3 = fused_sample_triton(
    weights,
    hidden_states,
    num_samples,
    temperature,
    seed=111,
)
plot_samples(samples3, hl_seq_len, num_samples)

# Compare Speed - Realistic Example
Now we test a realistic example with a large vocabulary (`vocab_size=256_000`) and a large hidden size (`hidden_size=5120`).

In [None]:
# Problem size for the speed test. These names intentionally replace the toy
# values defined earlier in the notebook.
vocab_size = 256_000
hidden_size = 5120
seq_len = 256
num_samples = 1
speedtest_kwargs = dict(
    # Generate directly in bfloat16: sampling in the default float32 and then
    # casting would allocate a transient ~5.2 GB float32 tensor for the weights
    # (256_000 * 5120 * 4 bytes) on top of the final bfloat16 copy.
    hidden_states=torch.randn((hidden_size, seq_len), dtype=torch.bfloat16),
    weights=torch.randn((vocab_size, hidden_size), dtype=torch.bfloat16),
    num_samples=num_samples,
    temperature=1.0,
)

In [None]:
sample_compiled = torch.compile(sample)
_ = sample_compiled(**speedtest_kwargs)
# sample_incremental_pt_compiled = torch.compile(incremental_sample_pt)

%timeit fused_sample_triton(**speedtest_kwargs, seed=0)

%timeit sample_compiled(**speedtest_kwargs)

# Memory Profiling
The difference in peak memory should be observable with the PyTorch Memory Timeline: https://pytorch.org/blog/understanding-gpu-memory-1/

In [None]:
BYTES_PER_BF16 = 2  # bfloat16 is a 16-bit type
BYTES_PER_GB = 10**9  # decimal gigabytes

# Size of a full [seq_len, vocab_size] logits matrix stored in bfloat16.
# Uses seq_len from the speed-test configuration instead of repeating 256.
logits_numel = vocab_size * seq_len
logits_bytes = logits_numel * BYTES_PER_BF16
logits_gb = logits_bytes / BYTES_PER_GB
print(f"logits_numel: {logits_numel:,}")
print(f"logits_gb: {logits_gb:.2f} GB")

# Size of the [vocab_size, hidden_size] weight matrix stored in bfloat16.
weights_numel = vocab_size * hidden_size
weights_bytes = weights_numel * BYTES_PER_BF16
weights_gb = weights_bytes / BYTES_PER_GB
print(f"weights_numel: {weights_numel:,}")
print(f"weights_gb: {weights_gb:.2f} GB")