In [1]:
import math
import torch
from typing import Tuple

In [2]:
def flash_qr(
    A: torch.Tensor, block_size: int = 16
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Block wise Gram Schmidt QR
    Vector projections across blocks can be done with matrix multiplication
    If block size is small enough, the current block can remain in GPU cache
    """
    m, n = A.shape
    assert n % block_size == 0, "n must be divisible by block_size"
    num_blocks = n // block_size

    Q = A.clone()
    R = torch.zeros((n, n), device=A.device, dtype=A.dtype)

    # Iterate over all output blocks
    for block_idx in range(num_blocks):
        start_col = block_idx * block_size
        end_col = start_col + block_size
        block = Q[:, start_col:end_col]

        # Previous blocks are already orthogonalized
        # Batch orthogonalize current block with respect to previous blocks
        for prev_block_idx in range(block_idx):
            prev_start = prev_block_idx * block_size
            prev_end = prev_start + block_size
            prev_block = Q[:, prev_start:prev_end]

            S = prev_block.T @ block
            block -= prev_block @ S
            R[prev_start:prev_end, start_col:end_col] = S

        # Sequentially orthogonalize each vector in the current block
        R_block = R[start_col:end_col, start_col:end_col]
        for j in range(block_size):
            # Subtract projection of each previous vector
            for i in range(j):
                dot = torch.dot(block[:, i], block[:, j])
                block[:, j] -= dot * block[:, i]
                R_block[i, j] = dot

            # Normalize the current vector
            norm = torch.linalg.norm(block[:, j])
            block[:, j] = block[:, j] / norm
            R_block[j, j] = norm

    return Q, R


def check_qr(A, Q, R):
    A = A.to(dtype=torch.float64)
    Q = Q.to(dtype=torch.float64)
    R = R.to(dtype=torch.float64)

    # Check that Q@R = A
    err = A - Q @ R
    err_rms_norm = torch.sqrt(torch.mean(err**2))
    print(f"RMS norm of A - Q@R: {err_rms_norm:.2e}")
    print(f"Max absolute error: {torch.max(torch.abs(err)):.2e}")

    # Check that Q is orthogonal
    I = torch.eye(Q.shape[1], device=Q.device, dtype=Q.dtype)
    err = Q.T @ Q - I
    rms_norm = torch.sqrt(torch.mean(err**2))
    print(f"RMS norm of Q^T@Q - I: {rms_norm:.2e}")

    # Check that R is upper triangular
    R_triu = torch.triu(R)
    print(f"R is triangular: {torch.allclose(R, R_triu)}")

    print()


def check_ortho(Q: torch.Tensor):
    # Check that Q is orthogonal
    Q = Q.to(dtype=torch.float64)
    I = torch.eye(Q.shape[1], device=Q.device, dtype=Q.dtype)
    err = Q.T @ Q - I
    rms_norm = torch.sqrt(torch.mean(err**2))
    print(f"RMS norm of Q^T@Q - I: {rms_norm:.2e}")

In [5]:
d = 1024
A = torch.randn(d, d, dtype=torch.float32, device="cuda")
# A = A @ A.T @ A
# A = A / torch.sqrt(torch.mean(A**2))
A_f16 = A.to(dtype=torch.bfloat16)
A = A_f16.to(dtype=torch.float32)
print(f"Condition number of A: {torch.linalg.cond(A):.2e}\n")


print("===Checking QR decomposition with torch.linalg.qr")
Q, R = torch.linalg.qr(A)
check_qr(A, Q, R)

print("===Checking QR decomposition with flash_qr")
Q, R = flash_qr(A, block_size=16)
check_qr(A, Q, R)

print("===Checking QR decomposition with flash_qr with bfloat16")
Q, R = flash_qr(A_f16, block_size=16)
check_qr(A, Q, R)

Condition number of A: 6.26e+03

===Checking QR decomposition with torch.linalg.qr
RMS norm of A - Q@R: 5.15e-07
Max absolute error: 4.56e-06
RMS norm of Q^T@Q - I: 2.21e-08
R is triangular: True

===Checking QR decomposition with flash_qr
RMS norm of A - Q@R: 1.35e-07
Max absolute error: 1.64e-06
RMS norm of Q^T@Q - I: 5.51e-07
R is triangular: True

===Checking QR decomposition with flash_qr with bfloat16
RMS norm of A - Q@R: 8.47e-03
Max absolute error: 1.05e-01
RMS norm of Q^T@Q - I: 3.22e-03
R is triangular: True



In [6]:
def solve_triangular(R: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    # Solve the triangular system Q@R = A for Q
    m, n = A.shape
    assert R.shape[0] == R.shape[1], "R must be square"
    assert R.shape[0] == n, "Incompatible shapes for triangular solve"

    Q = A.clone()

    for col in range(n):
        for i in range(col):
            Q[:, col] -= R[i, col] * Q[:, i]
        Q[:, col] /= R[col, col]

    return Q


def block_solve_triangular(
    R: torch.Tensor, A: torch.Tensor, block_size: int = 16
) -> torch.Tensor:
    # Solve the triangular system Q@R = A for Q
    m, n = A.shape
    assert R.shape[0] == R.shape[1], "R must be square"
    assert R.shape[0] == n, "Incompatible shapes for triangular solve"
    assert n % block_size == 0, "n must be divisible by block_size"
    num_blocks = n // block_size

    Q = A.clone()

    for block_idx in range(num_blocks):
        start_col = block_idx * block_size
        end_col = start_col + block_size
        block = Q[:, start_col:end_col].clone()

        # Blockwise subtract previously solved values
        for prev_block_idx in range(block_idx):
            prev_start = prev_block_idx * block_size
            prev_end = prev_start + block_size
            prev_block = Q[:, prev_start:prev_end]
            R_block = R[prev_start:prev_end, start_col:end_col]
            block -= prev_block @ R_block

        # Sequentially process the current block of R along the diagonal
        R_block = R[start_col:end_col, start_col:end_col]
        for j in range(block_size):
            for i in range(j):
                block[:, j] -= R_block[i, j] * block[:, i]
            block[:, j] /= R_block[j, j]

        Q[:, start_col:end_col] = block

    return Q

In [None]:
d = 1024
A = torch.randn(d, d, dtype=torch.float32, device="cuda")
# A = A @ A.T @ A
# A = A / torch.sqrt(torch.mean(A**2))
Q, R = torch.linalg.qr(A)

# Solve using torch.linalg.solve_triangular
print("===Checking solve_triangular")
Q_solve = torch.linalg.solve_triangular(R, A, upper=True, left=False)
check_qr(A, Q_solve, R)

# Solve in float32
print("===Checking block_solve_triangular")
Q_solve = block_solve_triangular(R, A, block_size=16)
check_qr(A, Q_solve, R)

# Solve in bfloat16
print("===Checking block_solve_triangular with bfloat16")
A = A.to(dtype=torch.bfloat16)
R = R.to(dtype=torch.bfloat16)
Q_solve = block_solve_triangular(R, A, block_size=16)
check_qr(A, Q_solve, R)

RMS norm of A - Q@R: 1.78e-07
Max absolute error: 1.82e-06
RMS norm of Q^T@Q - I: 8.31e-07
R is triangular: True

RMS norm of A - Q@R: 1.39e-07
Max absolute error: 1.64e-06
RMS norm of Q^T@Q - I: 8.22e-07
R is triangular: True

RMS norm of A - Q@R: 8.46e-03
Max absolute error: 1.48e-01
RMS norm of Q^T@Q - I: 1.73e-01
R is triangular: True



In [18]:
def flash_orthogonalize(
    A: torch.Tensor, SA: torch.Tensor, block_size: int = 16
) -> torch.Tensor:
    """
    Fused QR and solve triangular. The full R matrix is never materialized.
        A = input matrix or shard of the input matrix
        SA = random sketch of the whole A matrix
        block_size = number of columns to process at once
    Return orthogonalized copy of A
    """
    m, n = A.shape
    k, n2 = SA.shape
    assert n == n2, f"Incompatible shapes A={A.shape} and SA={SA.shape}"
    assert n % block_size == 0, "n must be divisible by block_size"
    num_blocks = n // block_size

    Q = SA.clone()
    A = A.clone()

    # Iterate over all output blocks
    for block_idx in range(num_blocks):
        start_col = block_idx * block_size
        end_col = start_col + block_size
        Q_block = Q[:, start_col:end_col]
        A_block = A[:, start_col:end_col]

        # Previous blocks of Q are already orthogonalized
        for prev_block_idx in range(block_idx):
            prev_start = prev_block_idx * block_size
            prev_end = prev_start + block_size
            Q_prev = Q[:, prev_start:prev_end]
            A_prev = A[:, prev_start:prev_end]

            # Compute one block of the R matrix
            R_block = Q_prev.T @ Q_block
            # Subtract projection of previous vectors from current Q block
            Q_block -= Q_prev @ R_block
            # Update the output matrix
            A_block -= A_prev @ R_block

        # Sequentially orthogonalize each vector in the current block
        for j in range(block_size):
            # Subtract projection of each previous vector
            for i in range(j):
                dot = torch.dot(Q_block[:, i], Q_block[:, j])
                Q_block[:, j] -= dot * Q_block[:, i]
                A_block[:, j] -= dot * A_block[:, i]

            # Normalize the current vector
            norm = torch.linalg.norm(Q_block[:, j])
            Q_block[:, j] = Q_block[:, j] / norm
            A_block[:, j] = A_block[:, j] / norm

    return A

In [30]:
m, n = 1024, 1024
k = math.ceil(1.5 * n)
A = torch.randn(m, n, dtype=torch.float32, device="cuda")
# A = A @ A.T @ A
# A = A / torch.sqrt(torch.mean(A**2))

S = torch.empty(k, m, dtype=A.dtype, device="cuda").normal_(std=math.sqrt(n))
SA = S @ A

print("\nOrthogonalize directly")
Q, _ = flash_qr(A, block_size=16)
check_ortho(Q)

print("\nOrthogonalize with random sketch")
Q = flash_orthogonalize(A, SA, block_size=16)
check_ortho(Q)

print("\nOrthogonalize with Householder QR")
Q, _ = torch.linalg.qr(A)
check_ortho(Q)

A = A.to(dtype=torch.bfloat16)
S = S.to(dtype=torch.bfloat16)
SA = S @ A

print("\nOrthogonalize directly in bfloat16")
Q, _ = flash_qr(A, block_size=16)
check_ortho(Q)

print("\nOrthogonalize with random sketch in bfloat16")
Q = flash_orthogonalize(A, SA, block_size=16)
check_ortho(Q)


Orthogonalize directly
RMS norm of Q^T@Q - I: 4.61e-07

Orthogonalize with random sketch
RMS norm of Q^T@Q - I: 3.12e-02

Orthogonalize with Householder QR
RMS norm of Q^T@Q - I: 1.89e-08

Orthogonalize directly in bfloat16
RMS norm of Q^T@Q - I: 3.27e-03

Orthogonalize with random sketch in bfloat16
RMS norm of Q^T@Q - I: 3.12e-02
