In [196]:
import math
import torch
import torch.utils.benchmark
from tqdm import tqdm

In [54]:
@torch.compile(dynamic=True)
def orthogonalize_QR(G: torch.Tensor):
    """Orthogonalize matrix using QR decomposition"""
    Q, R = torch.linalg.qr(G)
    return Q


@torch.compile(dynamic=True)
def orthogonalize_CQR(G):
    """Orthogonalize matrix using Cholesky QR decomposition"""
    R, _ = torch.linalg.cholesky_ex(G.T @ G, upper=True)
    Q = torch.linalg.solve_triangular(R, G, upper=True, left=False)
    return Q


@torch.compile(dynamic=True)
def orthogonalize_NS(
    G: torch.Tensor,
    ns_iters: int = 5,
    tuned_ns_consts: bool = True,
    epsilon: float = 1e-8,
):
    """Orthogonalize matrix using Newton-Schulz iteration"""
    if tuned_ns_consts:
        # Tuned constants
        a, b, c = 3.4445, -4.7750, 2.0315
    else:
        # "Vanilla" constants
        a, b, c = 2.0, -1.5, 0.5

    # Convert to 16 bit
    if G.device.type == "cuda":
        G = G.to(torch.bfloat16)

    # Normalize G so it doesn't blow up in iteration
    X = G / (G.norm() + epsilon)
    transposed = False

    # If G has more rows than columns, transpose it
    if X.size(0) > X.size(1):
        X = X.T
        transposed = True

    for _ in range(ns_iters):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X

    if transposed:
        X = X.T

    return X


@torch.compile(dynamic=True)
def orthogonalize_SVD(G: torch.Tensor):
    """Orthogonalize matrix using SVD"""
    U, _, Vt = torch.linalg.svd(G)
    return U @ Vt

In [55]:
# testing torch.utils.benchmark.Timer

x = torch.randn(1024, 1024)
x = x.to(device="cuda")

timer = torch.utils.benchmark.Timer(
    stmt="orthogonalize_QR(x)",
    setup="from __main__ import orthogonalize_QR",
    globals={"x": x},
)
print(timer.blocked_autorange(min_run_time=1))

timer = torch.utils.benchmark.Timer(
    stmt="orthogonalize_CQR(x)",
    setup="from __main__ import orthogonalize_CQR",
    globals={"x": x},
)
print(timer.blocked_autorange(min_run_time=1))

timer = torch.utils.benchmark.Timer(
    stmt="orthogonalize_NS(x)",
    setup="from __main__ import orthogonalize_NS",
    globals={"x": x},
)
print(timer.blocked_autorange(min_run_time=1))

timer = torch.utils.benchmark.Timer(
    stmt="orthogonalize_SVD(x)",
    setup="from __main__ import orthogonalize_SVD",
    globals={"x": x},
)
print(timer.blocked_autorange(min_run_time=1))

<torch.utils.benchmark.utils.common.Measurement object at 0x7fcc3128e0b0>
orthogonalize_QR(x)
setup: from __main__ import orthogonalize_QR
  Median: 5.13 ms
  2 measurements, 100 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fcc312cec20>
orthogonalize_CQR(x)
setup: from __main__ import orthogonalize_CQR
  Median: 1.38 ms
  IQR:    0.00 ms (1.38 to 1.38)
  8 measurements, 100 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fcc3c8b9ab0>
orthogonalize_NS(x)
setup: from __main__ import orthogonalize_NS
  Median: 367.75 us
  IQR:    4.98 us (365.73 to 370.71)
  2677 measurements, 1 runs per measurement, 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7fcc3c8b83d0>
orthogonalize_SVD(x)
setup: from __main__ import orthogonalize_SVD
  Median: 61.12 ms
  IQR:    0.08 ms (61.09 to 61.16)
  17 measurements, 1 runs per measurement, 1 thread


In [56]:
def benchmark_qr_and_ns(matrix_shapes, device="cuda", ns_iters=5):
    """
    Benchmark QR vs Newton-Schulz orthogonalization.
    matrix_shapes: list of tuples, each tuple is a shape (m, n)
    """
    device = torch.device(device)
    results = []

    for shape in tqdm(matrix_shapes):
        assert len(shape) == 2, "Shape must be tuple of length 2"

        x = torch.randn(*shape, device=device)
        label = f"QR vs NS-{ns_iters} for device={device.type}"
        sub_label = f"{shape[0]}x{shape[1]}"

        timer = torch.utils.benchmark.Timer(
            stmt="orthogonalize_QR(x)",
            setup="from __main__ import orthogonalize_QR",
            globals={"x": x},
            description="QR",
            label=label,
            sub_label=sub_label,
        )
        qr_time = timer.blocked_autorange(min_run_time=1)

        timer = torch.utils.benchmark.Timer(
            stmt="orthogonalize_NS(x, ns_iters=ns_iters)",
            setup="from __main__ import orthogonalize_NS",
            globals={"x": x, "ns_iters": ns_iters},
            description="NS",
            label=label,
            sub_label=sub_label,
        )
        ns_time = timer.blocked_autorange(min_run_time=1)

        results.extend((qr_time, ns_time))

    return results

In [5]:
# Benchmark for CPU
# Use smaller matrices for CPU to avoid long run times
base_dim = [256, 512, 1024]
matrix_shapes = []
for n in base_dim:
    matrix_shapes.extend([(n, n), (n, 4 * n), (4 * n, n)])
print(matrix_shapes)

results = benchmark_qr_and_ns(matrix_shapes, device="cpu", ns_iters=5)

compare = torch.utils.benchmark.Compare(results)
compare.trim_significant_figures()
compare.print()

[(256, 256), (256, 1024), (1024, 256), (512, 512), (512, 2048), (2048, 512), (1024, 1024), (1024, 4096), (4096, 1024)]


100%|██████████| 9/9 [00:36<00:00,  4.05s/it]

[- QR vs NS-5 for device=cpu -]
                 |   QR  |   NS
1 threads: --------------------
      256x256    |    1  |    5
      256x1024   |    5  |   15
      1024x256   |    6  |   15
      512x512    |   10  |   38
      512x2048   |   34  |  120
      2048x512   |   37  |  121
      1024x1024  |   67  |  310
      1024x4096  |  240  |  930
      4096x1024  |  261  |  940

Times are in milliseconds (ms).






In [44]:
# Benchmark for GPU
base_dim = [1024, 2048, 4096, 8192, 16384]
matrix_shapes = []
for n in base_dim:
    matrix_shapes.extend([(n, n), (n, 4 * n), (4 * n, n)])
print(matrix_shapes)

results = benchmark_qr_and_ns(matrix_shapes, device="cuda", ns_iters=5)

compare = torch.utils.benchmark.Compare(results)
compare.trim_significant_figures()
compare.print()

[(1024, 1024), (1024, 4096), (4096, 1024), (2048, 2048), (2048, 8192), (8192, 2048), (4096, 4096), (4096, 16384), (16384, 4096), (8192, 8192), (8192, 32768), (32768, 8192), (16384, 16384), (16384, 65536), (65536, 16384)]


100%|██████████| 15/15 [00:48<00:00,  3.24s/it]

[------ QR vs NS-5 for device=cuda -----]
                   |     QR    |     NS  
1 threads: ------------------------------
      1024x1024    |     5130  |      339
      1024x4096    |     6810  |      535
      4096x1024    |    11400  |      565
      2048x2048    |    12200  |      720
      2048x8192    |    17000  |     2100
      8192x2048    |    35000  |     2200
      4096x4096    |    38100  |     5100
      4096x16384   |    56700  |    10000
      16384x4096   |   120000  |    10000
      8192x8192    |   134000  |    40000
      8192x32768   |   245000  |   100000
      32768x8192   |   450000  |   100000
      16384x16384  |   630000  |   300000
      16384x65536  |  1457000  |  1019000
      65536x16384  |  2620000  |  1008000

Times are in microseconds (us).






In [45]:
# Print ratio of times
# This assumes that the results are in pairs, first QR then NS
for i in range(len(results) // 2):
    qr_result = results[i * 2]
    ns_result = results[i * 2 + 1]
    assert qr_result.task_spec.sub_label == ns_result.task_spec.sub_label

    shape = qr_result.task_spec.sub_label
    qr_time = 1e6 * qr_result.mean
    ns_time = 1e6 * ns_result.mean
    ratio = qr_time / ns_time
    print(f"{shape:>12}\tQR={qr_time:7.0f}\tNS={ns_time:7.0f}\tratio={ratio:2.2f}")

   1024x1024	QR=   5130	NS=    339	ratio=15.14
   1024x4096	QR=   6807	NS=    536	ratio=12.71
   4096x1024	QR=  11411	NS=    565	ratio=20.18
   2048x2048	QR=  12238	NS=    724	ratio=16.90
   2048x8192	QR=  17055	NS=   2097	ratio=8.13
   8192x2048	QR=  34966	NS=   2214	ratio=15.79
   4096x4096	QR=  38107	NS=   5124	ratio=7.44
  4096x16384	QR=  56931	NS=  14774	ratio=3.85
  16384x4096	QR= 116737	NS=  15204	ratio=7.68
   8192x8192	QR= 134161	NS=  38114	ratio=3.52
  8192x32768	QR= 245070	NS= 114569	ratio=2.14
  32768x8192	QR= 448316	NS= 114732	ratio=3.91
 16384x16384	QR= 633337	NS= 297215	ratio=2.13
 16384x65536	QR=1457101	NS=1019387	ratio=1.43
 65536x16384	QR=2619854	NS=1007954	ratio=2.60


In [57]:
def benchmark_qr(matrix_shapes, device="cuda"):
    """
    Benchmark QR vs Cholesky QR orthogonalization.
    matrix_shapes: list of tuples, each tuple is a shape (m, n)
    """
    device = torch.device(device)
    results = []

    for shape in tqdm(matrix_shapes):
        assert len(shape) == 2, "Shape must be tuple of length 2"

        x = torch.randn(*shape, device=device)
        label = f"QR vs CQR for device={device.type}"
        sub_label = f"{shape[0]}x{shape[1]}"

        timer = torch.utils.benchmark.Timer(
            stmt="orthogonalize_QR(x)",
            setup="from __main__ import orthogonalize_QR",
            globals={"x": x},
            description="QR",
            label=label,
            sub_label=sub_label,
        )
        qr_time = timer.blocked_autorange(min_run_time=1)

        timer = torch.utils.benchmark.Timer(
            stmt="orthogonalize_CQR(x)",
            setup="from __main__ import orthogonalize_CQR",
            globals={"x": x},
            description="CQR",
            label=label,
            sub_label=sub_label,
        )
        cqr_time = timer.blocked_autorange(min_run_time=1)

        results.extend((qr_time, cqr_time))

    return results

In [59]:
# Benchmark for CPU
# Use smaller matrices for CPU to avoid long run times
base_dim = [256, 512, 1024]
matrix_shapes = []
for n in base_dim:
    matrix_shapes.extend([(n, n), (n, 4 * n), (4 * n, n)])
print(matrix_shapes)

results = benchmark_qr(matrix_shapes, device="cpu")

compare = torch.utils.benchmark.Compare(results)
compare.trim_significant_figures()
compare.print()

[(256, 256), (256, 1024), (1024, 256), (512, 512), (512, 2048), (2048, 512), (1024, 1024), (1024, 4096), (4096, 1024)]


100%|██████████| 9/9 [00:24<00:00,  2.72s/it]

[-- QR vs CQR for device=cpu -]
                 |   QR  |  CQR
1 threads: --------------------
      256x256    |    1  |    1
      256x1024   |    5  |   17
      1024x256   |    6  |    3
      512x512    |   10  |    7
      512x2048   |   35  |  100
      2048x512   |   37  |   23
      1024x1024  |   68  |   50
      1024x4096  |  231  |  762
      4096x1024  |  262  |  174

Times are in milliseconds (ms).






In [50]:
# Benchmark for GPU
base_dim = [1024, 2048, 4096, 8192, 16384]
matrix_shapes = []
for n in base_dim:
    matrix_shapes.extend([(n, n), (n, 4 * n), (4 * n, n)])
print(matrix_shapes)

results = benchmark_qr(matrix_shapes, device="cuda")

compare = torch.utils.benchmark.Compare(results)
compare.trim_significant_figures()
compare.print()

[(1024, 1024), (1024, 4096), (4096, 1024), (2048, 2048), (2048, 8192), (8192, 2048), (4096, 4096), (4096, 16384), (16384, 4096), (8192, 8192), (8192, 32768), (32768, 8192), (16384, 16384), (16384, 65536), (65536, 16384)]


100%|██████████| 15/15 [01:19<00:00,  5.30s/it]

[--- QR vs CQR for device=cuda ----]
                   |   QR   |   CQR 
1 threads: -------------------------
      1024x1024    |     5  |      1
      1024x4096    |     7  |      9
      4096x1024    |    11  |      2
      2048x2048    |    12  |      4
      2048x8192    |    17  |     30
      8192x2048    |    35  |      5
      4096x4096    |    38  |     11
      4096x16384   |    56  |    156
      16384x4096   |   117  |     27
      8192x8192    |   130  |     57
      8192x32768   |   246  |   1104
      32768x8192   |   451  |    200
      16384x16384  |   637  |    400
      16384x65536  |  1457  |  14940
      65536x16384  |  2644  |   1408

Times are in milliseconds (ms).






In [52]:
# Print ratio of times
# This assumes that the results are in pairs
for i in range(len(results) // 2):
    qr_result = results[i * 2]
    cqr_result = results[i * 2 + 1]
    assert qr_result.task_spec.sub_label == cqr_result.task_spec.sub_label

    shape = qr_result.task_spec.sub_label
    qr_time = 1e6 * qr_result.mean
    cqr_time = 1e6 * cqr_result.mean
    ratio = qr_time / cqr_time
    print(f"{shape:>12}\tQR={qr_time:7.0f}\tCQR={cqr_time:7.0f}\tratio={ratio:2.2f}")

   1024x1024	QR=   5129	CQR=   1382	ratio=3.71
   1024x4096	QR=   6849	CQR=   9173	ratio=0.75
   4096x1024	QR=  11406	CQR=   1655	ratio=6.89
   2048x2048	QR=  12195	CQR=   3532	ratio=3.45
   2048x8192	QR=  17024	CQR=  30303	ratio=0.56
   8192x2048	QR=  34889	CQR=   5440	ratio=6.41
   4096x4096	QR=  38003	CQR=  11342	ratio=3.35
  4096x16384	QR=  56436	CQR= 156031	ratio=0.36
  16384x4096	QR= 116669	CQR=  26653	ratio=4.38
   8192x8192	QR= 134117	CQR=  56549	ratio=2.37
  8192x32768	QR= 246151	CQR=1104199	ratio=0.22
  32768x8192	QR= 450815	CQR= 179505	ratio=2.51
 16384x16384	QR= 637387	CQR= 394717	ratio=1.61
 16384x65536	QR=1456792	CQR=14937753	ratio=0.10
 65536x16384	QR=2643666	CQR=1408217	ratio=1.88


In [172]:
@torch.compile(dynamic=True)
def simulate_muon(G):
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    X /= X.norm() + 1e-8
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(5):
        A = X @ X.T
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X


@torch.compile(dynamic=True)
def simulate_dion(G, Q):
    P = G @ Q
    P = torch.linalg.qr(P)[0]
    R = G.T @ P
    M = G - P @ R.T
    Q = R / (R.sum(dim=0, keepdim=True) + 1e-8)
    return P @ Q.T


@torch.compile(dynamic=True)
def simulate_cholesky_dion(G, Q):
    P = G @ Q
    A = P.T @ P
    # A = A + torch.eye(A.size(0), device=A.device)
    L, info = torch.linalg.cholesky_ex(A, upper=True)
    print("info", info)
    P = torch.linalg.solve_triangular(L, P, upper=True, left=False)
    # A = P.T @ P
    # L, _ = torch.linalg.cholesky_ex(A, upper=True)
    # P = torch.linalg.solve_triangular(L, P, upper=True, left=False)
    R = G.T @ P
    M = G - P @ R.T
    Q = R / (R.sum(dim=0, keepdim=True) + 1e-8)
    return P @ Q.T

In [219]:
m, n = 1024, 1024
r = min(m, n) // 1
G = torch.randn(m, n).to(device="cuda")
Q = torch.randn(n, r).to(device="cuda")

P = G @ Q
I = torch.eye(m, device=G.device)
print("cond G", torch.linalg.cond(G))
print("cond P", torch.linalg.cond(P))
print("cond G@G^T", torch.linalg.cond(G @ G.T))
print("cond P@P^T", torch.linalg.cond(P @ P.T))
print("cond I + P@P^T", torch.linalg.cond(I + P @ P.T))

muon_update = simulate_muon(G)
dion_update = simulate_dion(G, Q)
cqr_dion_update = simulate_cholesky_dion(G, Q)


def test_orthogonality(X):
    return (X.T @ X - torch.eye(X.size(1), device=X.device)).norm()


print((muon_update - dion_update).norm())
print((dion_update - cqr_dion_update).norm())
print(test_orthogonality(muon_update))
print(test_orthogonality(dion_update))
print(test_orthogonality(cqr_dion_update))
print(G.dtype)

cond G tensor(2524.2595, device='cuda:0')
cond P tensor(103697.7578, device='cuda:0')
cond G@G^T tensor(8432830., device='cuda:0')
cond P@P^T tensor(1.3741e+08, device='cuda:0')
cond I + P@P^T tensor(8750351., device='cuda:0')
info tensor(1022, device='cuda:0', dtype=torch.int32)
tensor(1307.8210, device='cuda:0')
tensor(nan, device='cuda:0')
tensor(11.1049, device='cuda:0')
tensor(1264734.7500, device='cuda:0')
tensor(nan, device='cuda:0')
torch.float32


In [236]:
def rcqr(X):
    m, n = X.shape
    k = math.ceil(1.25 * n)
    theta = torch.randn(k, m, device=X.device)
    theta.mul_(1.0 / math.sqrt(k))
    P = theta @ X
    _, R = torch.linalg.qr(P, mode="r")
    # R, _ = torch.linalg.cholesky_ex(P.T @ P, upper=True)
    Q = torch.linalg.solve_triangular(R, X, upper=True, left=False)

    # A = Q.T @ Q
    # L, info = torch.linalg.cholesky_ex(A, upper=True)
    # Q = torch.linalg.solve_triangular(L, Q, upper=True, left=False)
    return Q


def scqr(X):
    A = X.T @ X
    A = A + torch.eye(A.size(0), device=A.device)
    L, info = torch.linalg.cholesky_ex(A, upper=True)
    Q = torch.linalg.solve_triangular(L, X, upper=True, left=False)
    return Q


print(torch.linalg.cond(rcqr(G)))
print(torch.linalg.cond(rcqr(P)))
print(torch.linalg.cond(rcqr(P @ P.T)))
print()

print(torch.linalg.cond(scqr(G)))
print(torch.linalg.cond(scqr(P)))
print(torch.linalg.cond(scqr(P @ P.T)))
print()

print(torch.linalg.cond(orthogonalize_QR(P)))
print(torch.linalg.cond(orthogonalize_CQR(P)))
print(torch.linalg.cond(orthogonalize_NS(P).float()))

tensor(17.4587, device='cuda:0')
tensor(18.2574, device='cuda:0')
tensor(2186.3999, device='cuda:0')

tensor(39.6709, device='cuda:0')
tensor(39.2988, device='cuda:0')
tensor(nan, device='cuda:0')

tensor(1.0003, device='cuda:0')
tensor(nan, device='cuda:0')
tensor(2719.0166, device='cuda:0')


In [None]:
matrix shape (n, 4*n)
param shape (4*n, n)
colwise parallel (4*n/TP, n)


linear (col, row)
colwise parallel shard(0)
rowwise parallel shard(1)
fsdp shard(0) default

linear (col, row)
colwise (TP, DP)
rowwise (DP, TP)

paper (row, col)
colwise (DP, TP) (n, 4n)
rowwise (TP, DP) (4n, n)




In [8]:
n = 2
a = torch.randn(4 * n, n)
q, r = torch.linalg.qr(a)
print(q.shape, r.shape)
print(a)
print(q)
print(r)

torch.Size([8, 2]) torch.Size([2, 2])
tensor([[ 0.3131,  1.7820],
        [ 1.7147,  0.8314],
        [ 0.4297,  1.3404],
        [-1.0639, -0.6761],
        [-0.5937, -0.7386],
        [ 0.3421, -1.6879],
        [ 0.0020, -1.2224],
        [ 1.4354, -0.8162]])
tensor([[-0.1193,  0.5077],
        [-0.6535,  0.1025],
        [-0.1638,  0.3653],
        [ 0.4055, -0.1117],
        [ 0.2263, -0.1707],
        [-0.1304, -0.5356],
        [-0.0008, -0.3669],
        [-0.5471, -0.3678]])
tensor([[-2.6238, -0.7493],
        [ 0.0000,  3.3338]])


In [28]:
m = torch.randn(16, 2)
a = m.chunk(4, dim=0)

qr0 = [torch.linalg.qr(chunk) for chunk in a]
q0 = [qr.Q for qr in qr0]
r0 = [qr.R for qr in qr0]
print("q0 shape", q0[0].shape)
print("r0 shape", r0[0].shape)

r1 = torch.cat(r0, dim=0).chunk(2, dim=0)
qr1 = [torch.linalg.qr(chunk) for chunk in r1]
q1 = [qr.Q for qr in qr1]
r1 = [qr.R for qr in qr1]
print("q1 shape", q1[0].shape)
print("r1 shape", r1[0].shape)

r2 = torch.cat(r1, dim=0).chunk(1, dim=0)
qr1 = [torch.linalg.qr(chunk) for chunk in r2]
q2 = [qr.Q for qr in qr1]
r2 = [qr.R for qr in qr1]
print("q2 shape", q2[0].shape)
print("r2 shape", r2[0].shape)

print("r2", r2[0])

q0 shape torch.Size([4, 2])
r0 shape torch.Size([2, 2])
q1 shape torch.Size([4, 2])
r1 shape torch.Size([2, 2])
q2 shape torch.Size([4, 2])
r2 shape torch.Size([2, 2])
r2 tensor([[4.2202, 2.3031],
        [0.0000, 2.4578]])


In [36]:
def cholesky_qr(X):
    R, _ = torch.linalg.cholesky_ex(X.T @ X, upper=True)
    Q = torch.linalg.solve_triangular(R, X, upper=True, left=False)
    return Q, R


m = torch.randn(100, 8)
Q, R = cholesky_qr(m)
Q_, R_ = torch.linalg.qr(m)

print((m - Q @ R).abs().max())
print((m - Q_ @ R_).abs().max())

tensor(4.7684e-07)
tensor(1.0133e-06)
