In [1]:
import math
import torch
import torch.utils.benchmark as bench
import matplotlib.pyplot as plt
from tqdm import tqdm

# Ensure deterministic HW settings
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

# Lower compile‑cache size avoids graph sprawl during benchmarking
torch._dynamo.config.cache_size_limit = 100

torch.set_default_device(torch.device("cuda:1"))

# Import the pure‑bf16 Householder QR implementation
from scripts.opts.qr_bf16 import qr_bf16

In [None]:
@torch.compile(dynamic=True)
def ortho_QR(G: torch.Tensor):
    """Standard QR using torch.linalg (fp32 on CPU, bfloat16 native on recent CUDA)."""
    Q, _ = torch.linalg.qr(G)
    return Q


@torch.compile(dynamic=True)
def ortho_QR_bf16(G: torch.Tensor):
    """Pure‑bf16 Householder QR built on qr_bf16 (strict bf16 path)."""
    Q, _ = qr_bf16(G)
    return Q


@torch.compile(dynamic=True)
def ortho_CQR1(G):
    R, _ = torch.linalg.cholesky_ex(G.T @ G, upper=True)
    Q = torch.linalg.solve_triangular(R, G, upper=True, left=False)
    return Q


@torch.compile(dynamic=True)
def ortho_CQR2(G):
    R, _ = torch.linalg.cholesky_ex(G.T @ G, upper=True)
    Q = torch.linalg.solve_triangular(R, G, upper=True, left=False)
    R, _ = torch.linalg.cholesky_ex(Q.T @ Q, upper=True)
    Q = torch.linalg.solve_triangular(R, Q, upper=True, left=False)
    return Q


@torch.compile(dynamic=True)
def ortho_RCQR1(G, oversample: float = 1.25):
    m, n = G.shape
    k = math.ceil(oversample * n / 128) * 128
    std = math.sqrt(1.0 / k)
    S = torch.empty((k, m), device=G.device, dtype=G.dtype).normal_(std=std)
    _, R = torch.linalg.qr(S @ G, mode="r")
    Q = torch.linalg.solve_triangular(R, G, upper=True, left=False)
    return Q


@torch.compile(dynamic=True)
def ortho_RCQR2(G, oversample: float = 1.25):
    m, n = G.shape
    k = math.ceil(oversample * n / 128) * 128
    std = math.sqrt(1.0 / k)
    S = torch.empty((k, m), device=G.device, dtype=G.dtype).normal_(std=std)
    _, R = torch.linalg.qr(S @ G, mode="r")
    Q = torch.linalg.solve_triangular(R, G, upper=True, left=False)
    R, _ = torch.linalg.cholesky_ex(Q.T @ Q, upper=True)
    Q = torch.linalg.solve_triangular(R, Q, upper=True, left=False)
    return Q


@torch.compile(dynamic=True)
def ortho_RCQR2_16(G, oversample: float = 1.25):
    m, n = G.shape
    k = math.ceil(oversample * n / 128) * 128
    std = math.sqrt(1.0 / k)

    # create sketch entirely in bf16
    S = torch.empty((k, m), device=G.device, dtype=torch.bfloat16).normal_(std=std)
    # both S and G→bf16 so matmul happens in bf16
    sketch = torch.matmul(S, G.to(torch.bfloat16))
    _, R_bf16 = qr_bf16(sketch)

    # back to f32 for the final triangular solve / Cholesky
    Q = torch.linalg.solve_triangular(R_bf16.to(G.dtype), G, upper=True, left=False)
    R2, _ = torch.linalg.cholesky_ex(Q.T @ Q, upper=True)
    Q = torch.linalg.solve_triangular(R2, Q, upper=True, left=False)
    return Q


@torch.compile(dynamic=True)
def ortho_RCQR2_16(G: torch.Tensor, oversample: float = 1.25):
    m, n = G.shape
    k = math.ceil(oversample * n / 128) * 128
    std = math.sqrt(1.0 / k)

    # Sketch in bf16
    S = torch.empty((k, m), device=G.device, dtype=torch.bfloat16).normal_(std=std)
    SG = S @ G.to(torch.bfloat16)  # (k × n) strictly bf16
    _, R_bf16 = qr_bf16(SG)  # ⟹ bf16 upper-triangular R

    # Back to original precision for the triangular solve & refinement
    R = R_bf16.to(G.dtype)
    Q = torch.linalg.solve_triangular(R, G, upper=True, left=False)
    R2, _ = torch.linalg.cholesky_ex(Q.T @ Q, upper=True)
    Q = torch.linalg.solve_triangular(R2, Q, upper=True, left=False)
    return Q


@torch.compile(dynamic=True)
def ortho_NS(
    G: torch.Tensor,
    ns_iters: int = 5,
    tuned_ns_consts: bool = True,
    epsilon: float = 1e-8,
):
    a, b, c = (3.4445, -4.7750, 2.0315) if tuned_ns_consts else (2.0, -1.5, 0.5)
    if G.device.type == "cuda":
        G = G.to(torch.bfloat16)
    X = G / (G.norm() + epsilon)
    transposed = False
    if X.size(0) > X.size(1):
        X = X.T
        transposed = True
    for _ in range(ns_iters):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X.T if transposed else X


@torch.compile(dynamic=True)
def orthogonalize_SVD(G: torch.Tensor):
    U, _, Vt = torch.linalg.svd(G, full_matrices=False)
    return U @ Vt

In [3]:
def check_orthogonality(Q: torch.Tensor):
    if Q.size(0) < Q.size(1):
        Q = Q.T
    identity = torch.eye(Q.size(1), device=Q.device, dtype=Q.dtype)
    return torch.norm(Q.T @ Q - identity)


# ---------------------------------------------------
#   QUICK SANITY CHECKS
# ---------------------------------------------------
shape = (1024, 1024)
X = torch.randn(shape, dtype=torch.float32)
print(f"Shape: {shape}")
print(f"Condition number: {torch.linalg.cond(X)}")
print(f"Orthogonality QR:        {check_orthogonality(ortho_QR(X))}")
print(f"Orthogonality of CQR1:  {check_orthogonality(ortho_CQR1(X))}")
print(f"Orthogonality of CQR2:  {check_orthogonality(ortho_CQR2(X))}")
print(f"Orthogonality of RCQR1: {check_orthogonality(ortho_RCQR1(X, oversample=1.25))}")
print(f"Orthogonality of RCQR2: {check_orthogonality(ortho_RCQR2(X, oversample=1.25))}")
print(f"Orthogonality of NS:    {check_orthogonality(ortho_NS(X))}")
print(f"Orthogonality of SVD:   {check_orthogonality(orthogonalize_SVD(X))}")
print(f"Orthogonality RCQR2_16:   {check_orthogonality(ortho_RCQR2_16(X))}")

Shape: (1024, 1024)
Condition number: 6977.36767578125
Orthogonality QR:        2.8229291274328716e-05
Orthogonality of CQR1:  0.08875679969787598
Orthogonality of CQR2:  1.7651178495725617e-05
Orthogonality of RCQR1: 352.42108154296875
Orthogonality of RCQR2: 9.364620927954093e-05
Orthogonality of NS:    11.0625


W0516 13:21:01.801000 2729465 torch/_dynamo/variables/tensor.py:869] [7/0] Graph break from `Tensor.item()`, consider setting:
W0516 13:21:01.801000 2729465 torch/_dynamo/variables/tensor.py:869] [7/0]     torch._dynamo.config.capture_scalar_outputs = True
W0516 13:21:01.801000 2729465 torch/_dynamo/variables/tensor.py:869] [7/0] or:
W0516 13:21:01.801000 2729465 torch/_dynamo/variables/tensor.py:869] [7/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0516 13:21:01.801000 2729465 torch/_dynamo/variables/tensor.py:869] [7/0] to include these operations in the captured graph.
W0516 13:21:01.801000 2729465 torch/_dynamo/variables/tensor.py:869] [7/0] 
W0516 13:21:01.801000 2729465 torch/_dynamo/variables/tensor.py:869] [7/0] Graph break: from user code at:
W0516 13:21:01.801000 2729465 torch/_dynamo/variables/tensor.py:869] [7/0]   File "/tmp/ipykernel_2729465/350582922.py", line 61, in ortho_RCQR2_16
W0516 13:21:01.801000 2729465 torch/_dynamo/variables/tensor.py:869] [7/0]     _, R_bf1

Orthogonality of SVD:   0.011044292710721493


AssertionError: 
SpeculationLog diverged at index 1 (log had 19 entries):
- Expected: /home/kwangjunahn/fast-opt/scripts/opts/qr_bf16.py:21 (LOAD_FAST at ip=39)
- Actual: /home/kwangjunahn/fast-opt/scripts/opts/qr_bf16.py:21 (BINARY_SUBSCR at ip=41)
Previous instruction: /home/kwangjunahn/fast-opt/scripts/opts/qr_bf16.py:20(CALL_FUNCTION @ 37)

There are two usual reasons why this may have occured:
- When Dynamo analysis restarted, the second run took a different path than
  the first.  If this occurred, the previous instruction is the critical instruction that
  behaved differently.
- Speculation entries are only added under certain conditions (as seen in
  step()), e.g., there must exist operators in the graph; those conditions may
  have changed on restart.

If this divergence was intentional, clear the speculation log before restarting (do NOT
do this for graph breaks, you will infinite loop).

Otherwise, please submit a bug report, ideally including the contents of TORCH_LOGS=+dynamo


from user code:
   File "/home/kwangjunahn/fast-opt/scripts/opts/qr_bf16.py", line 21, in torch_dynamo_resume_in__householder_reflector_at_20
    x0 = x[0]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True


In [None]:
# ---------------------------------------------------
#   BENCHMARKING UTILITIES
# ---------------------------------------------------


def benchmark(
    functions: list[str], shapes: list[tuple[int, int]], device: str | None = None
):
    device = device or torch.get_default_device()
    results = {func: [] for func in functions}
    for m, n in tqdm(shapes):
        X = torch.randn((m, n), device=device, dtype=torch.float32)
        for func in functions:
            tmr = bench.Timer(
                stmt=f"{func}(X)",
                setup=f"from __main__ import {func}",
                globals={"X": X},
                num_threads=1,
            )
            res = tmr.blocked_autorange(min_run_time=1)
            results[func].append(res.median)
    return results


def benchmark_and_plot(
    base_dim: list[int], sparsity_ratio: int, device: str | None = None
):
    shapes = [(n, n // sparsity_ratio) for n in base_dim]
    functions = [
        "ortho_QR",
        "ortho_QR_bf16",
        "ortho_CQR1",
        "ortho_RCQR2",
        "ortho_RCQR2_16",
        "ortho_NS",
    ]
    results = benchmark(functions, shapes, device)
    plt.figure(figsize=(10, 6))
    for func, times in results.items():
        plt.plot([n for n, _ in shapes], times, label=func, marker="o")
    plt.xscale("log")
    plt.yscale("log")
    plt.xticks([n for n, _ in shapes], [f"{n}x{m}" for n, m in shapes], rotation=45)
    plt.xlabel("Matrix Size")
    plt.ylabel("Time (s)")
    plt.title(f"n x n/{sparsity_ratio}")
    plt.legend()
    plt.grid()
    plt.show()


def benchmark_actual_shapes(
    model_dims: list[int], mlp_expansion: int, sparsity_ratio: int, tp_dim: int
):
    ns_matrix_dims = [(mlp_expansion * d, d) for d in model_dims]
    qr_matrix_dims = [(mlp_expansion * d, d // sparsity_ratio) for d in model_dims]
    cqr_matrix_dims = [
        (mlp_expansion * d // tp_dim, d // sparsity_ratio) for d in model_dims
    ]

    results_ns = benchmark(["ortho_NS"], ns_matrix_dims)
    results_qr = benchmark(["ortho_QR"], qr_matrix_dims)
    results_qr16 = benchmark(["ortho_QR_bf16"], qr_matrix_dims)
    results_cqr = benchmark(["ortho_CQR1"], cqr_matrix_dims)
    results_r2 = benchmark(["ortho_RCQR2"], cqr_matrix_dims)
    results_r216 = benchmark(["ortho_RCQR2_16"], cqr_matrix_dims)
    results = {
        **results_ns,
        **results_qr,
        **results_qr16,
        **results_cqr,
        **results_r2,
        **results_r216,
    }

    # Print calculated matrix shapes for reference
    print(
        "                            NS shape             QR shape             CQR shape"
    )
    for i, dim in enumerate(model_dims):
        ns_shape = str(ns_matrix_dims[i])
        qr_shape = str(qr_matrix_dims[i])
        cqr_shape = str(cqr_matrix_dims[i])
        print(f"Model dim {dim:5}: {ns_shape:>20} {qr_shape:>20} {cqr_shape:>20}")

    plt.figure(figsize=(10, 6))
    for func, times in results.items():
        plt.plot(model_dims, times, label=func, marker="o")
    plt.xscale("log")
    plt.yscale("log")
    plt.xticks(model_dims, [str(d) for d in model_dims])
    plt.xlabel("Model Dimension")
    plt.ylabel("Time (s)")
    plt.title(f"{mlp_expansion}n x n matrix, sparsity 1/{sparsity_ratio}, TP {tp_dim}")
    plt.legend()
    plt.grid()
    plt.show()

In [None]:
# ---------------------------------------------------
#   RUN BENCHMARKS
# ---------------------------------------------------
print(f"Using device: {torch.get_default_device()}")
base_dim = [1024, 2048, 4096, 8192, 16384]
benchmark_and_plot(base_dim, sparsity_ratio=1)
benchmark_and_plot(base_dim, sparsity_ratio=4)
benchmark_and_plot(base_dim, sparsity_ratio=16)

# Example large‑model settings
model_dims = [1024, 2048, 4096, 8192, 16384]
benchmark_actual_shapes(model_dims, mlp_expansion=1, sparsity_ratio=16, tp_dim=8)
benchmark_actual_shapes(model_dims, mlp_expansion=4, sparsity_ratio=4, tp_dim=1)
benchmark_actual_shapes(model_dims, mlp_expansion=4, sparsity_ratio=16, tp_dim=1)
benchmark_actual_shapes(model_dims, mlp_expansion=4, sparsity_ratio=4, tp_dim=8)
benchmark_actual_shapes(model_dims, mlp_expansion=4, sparsity_ratio=16, tp_dim=8)
benchmark_actual_shapes(model_dims, mlp_expansion=4, sparsity_ratio=1, tp_dim=1)
benchmark_actual_shapes(model_dims, mlp_expansion=4, sparsity_ratio=1, tp_dim=8)
benchmark_actual_shapes(model_dims, mlp_expansion=4, sparsity_ratio=64, tp_dim=8)