In [1]:
import jax
import jax.numpy as jnp
key = jax.random.key(69)

2025-07-28 14:57:31.754790: W external/xla/xla/service/gpu/nvptx_compiler.cc:763] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.9.86). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
B = 32
d = 11
FROBENIUS_NORM_GAIN = 6.9
FROBENIUS_NORM_BIAS = 4.2

start = jax.random.normal(key, (B, d))
end = jax.random.normal(key, (B, d))

# create masks
coeff_vector_mask = jnp.arange(d, 0, -1)
coeff_vector_mask = (
    coeff_vector_mask**2
)  # this is just a test, comment this out later
coeff_vector_mask_col = jnp.expand_dims(coeff_vector_mask, 1)  # Shape: (d, 1)
coeff_vector_mask_row = jnp.expand_dims(coeff_vector_mask, 0)  # Shape: (1, d)
coeff_matrix_mask = jnp.minimum(
    coeff_vector_mask_col, coeff_vector_mask_row
)  # Shape: (d, d)

In [3]:
# create masks
coeff_vector_mask = jnp.arange(d, 0, -1)
coeff_vector_mask = (
    coeff_vector_mask**2
)  # this is just a test, comment this out later
coeff_vector_mask_col = jnp.expand_dims(coeff_vector_mask, 1)  # Shape: (d, 1)
coeff_vector_mask_row = jnp.expand_dims(coeff_vector_mask, 0)  # Shape: (1, d)
coeff_matrix_mask = jnp.minimum(
    coeff_vector_mask_col, coeff_vector_mask_row
)  # Shape: (d, d)

In [4]:
### old implementation of loss
averaged_inner_prod = (
    jnp.einsum("bi, bi -> i", start, end)
    / start.shape[0]
)
approximation_error_loss = -2 * jnp.sum(averaged_inner_prod * coeff_vector_mask)

product_1 = jnp.einsum(
    "bj, bk -> bjk", start, end
)
averaged_product_1 = jnp.mean(product_1, axis=0)
product_2 = jnp.einsum(
    "bj, bk -> bjk", start, start
)
averaged_product_2 = jnp.mean(product_2, axis=0)
orthogonality_loss = jnp.sum(
    coeff_matrix_mask * averaged_product_1 * averaged_product_2
)

loss = approximation_error_loss + orthogonality_loss
print(f"old loss computation: {loss}")

old loss computation: -364.30926513671875


In [5]:
### new implementation of loss
cov_matrix = start.T @ start / start.shape[0]
corr_matrix = start.T @ end / start.shape[0]
scaled_corr_matrix = coeff_matrix_mask * corr_matrix

raw_loss = -2 * jnp.trace(scaled_corr_matrix) + jnp.sum(cov_matrix * scaled_corr_matrix)

print(f"new loss computation: {raw_loss}")

new loss computation: -364.3092041015625


In [6]:
def compute_approximation_error_loss(
        start_representation, end_representation
    ):
    """
    Compute the approximation error loss between the start and end representations.
    This is the -2 Sigma <g_l, Tf_l> term, but as we are working with EVD, it is just <f_l, Tf_l>.
    Recall from the derivation that this is equivalent to E[(f_l - Tf_l)^2], where we sample over
    transitions and take an expectation over p(x, x').

    Args:
        start_representation: The representation of the start state.
        end_representation: The representation of the end state.

    Returns:
        The approximation error loss.
    """
    print(
        f"for approximation error loss: start_representation: {start_representation.shape}"
    )
    print(
        f"for approximation error loss: end_representation: {end_representation.shape}"
    )
    # Compute the loss term
    # loss = -2 * ((start_representation - end_representation)**2).mean() # need to recheck this...
    loss = 0
    coeff_mask = jnp.arange(d, 0, -1)  # for joint LoRA
    # coeff_mask = coeff_mask**2  # this is just a test, comment this out later
    if len(start_representation.shape) == 1 and len(end_representation.shape) == 1:
        print("one dimensional approximation error loss")

        # computing loss via squares
        squared_diff = (start_representation - end_representation) ** 2
        joint_squared_diff = squared_diff * coeff_mask
        loss = -(jnp.sum(joint_squared_diff))

        # computing loss via straight inner product
        inner_prod = jnp.dot(start_representation, end_representation)
        loss = -2 * inner_prod
    else:
        print("batched approximation error loss")
        # shapes (1024, 11) and (1024, 11)

        # computing loss via squares
        # squared_diff = (start_representation - end_representation) ** 2
        # joint_squared_diff = squared_diff * coeff_mask
        # loss = - jnp.mean(jnp.sum(joint_squared_diff, axis=1))

        # computing loss via straight inner product
        # inner_prod = jnp.einsum('bi, bi, i -> b', start_representation, end_representation, coeff_mask)
        # loss = -2 * jnp.mean(inner_prod)

        averaged_inner_prod = (
            jnp.einsum("bi, bi -> i", start_representation, end_representation)
            / start_representation.shape[0]
        )
        loss = -2 * jnp.sum(averaged_inner_prod * coeff_mask)


    return loss

def compute_orthogonality_loss(
    representation_1, representation_2, representation_2_end
):
    """
    Compute the orthogonality loss between the two representations.

    If we are working with LoRA:
        This is the Sigma_l Sigma_l' <f_l | f_l'> <g_l | g_l'> term.
        Since we are working with EVD, it is just Sigma_l <f_l | f_l'>^2.
    If we are working with OMM:
        This is Sigma_l Sigma_l' <f_l | Tf_l'> <f_l | f_l'> term.

    Args:
        representation_1: The first representation.
        representation_2: The second representation.

    Returns:
        The orthogonality loss.
    """
    loss = 0

    print(f"for orthogonality loss: representation_1: {representation_1.shape}")
    print(f"for orthogonality loss: representation_2: {representation_2.shape}")

    # Create the mask in a vectorized way
    coeff_vector_mask = jnp.arange(d, 0, -1)
    # coeff_vector_mask = (
    #     coeff_vector_mask**2
    # )  # this is just a test, comment this out later
    coeff_vector_mask_col = jnp.expand_dims(coeff_vector_mask, 1)  # Shape: (d, 1)
    coeff_vector_mask_row = jnp.expand_dims(coeff_vector_mask, 0)  # Shape: (1, d)
    coeff_matrix_mask = jnp.minimum(
        coeff_vector_mask_col, coeff_vector_mask_row
    )  # Shape: (d, d)

    # OMM loss case: Sigma_l Sigma_l' <f_l | Tf_l'> <f_l | f_l'>
    try:
        if len(representation_1.shape) == 1:
            print("one dimensional orthogonality loss")
            product_1 = jnp.einsum(
                "j, k -> jk", representation_2, representation_2_end
            )
            product_2 = jnp.einsum(
                "j, k -> jk", representation_1, representation_1
            )
            loss += jnp.sum(coeff_matrix_mask * product_1 * product_2)
        else:
            print("batched orthogonality loss")
            product_1 = jnp.einsum(
                "bj, bk -> bjk", representation_2, representation_2_end
            )
            averaged_product_1 = jnp.mean(product_1, axis=0)
            product_2 = jnp.einsum(
                "bj, bk -> bjk", representation_1, representation_1
            )
            averaged_product_2 = jnp.mean(product_2, axis=0)
            loss += jnp.sum(
                coeff_matrix_mask * averaged_product_1 * averaged_product_2
            )
    except:
        print(f"Shape of representation_1: {representation_1.shape}")
        print(f"Shape of representation_2: {representation_2.shape}")
        raise
    

    return loss

def compute_frobenius_norm_loss(representation, matrix_mask):
    """
    Computes the Frobenius norm loss between the representation and the identity matrix.
    
    This is ||rep.T @ rep - I||_2. Bias factor removed, bias is to be multiplied in during the
    actual loss computation.
    """
    return jnp.sum(
        matrix_mask
        * (
            representation.T @ representation / representation.shape[0]
            - jnp.eye(representation.shape[1])
        )
        ** 2
    )

In [7]:
key = jax.random.PRNGKey(696969)
start_representation = jax.random.normal(jax.random.PRNGKey(0), (32, 11))
end_representation = jax.random.normal(jax.random.PRNGKey(1), (32, 11))
start_representation_2 = jax.random.normal(jax.random.PRNGKey(2), (32, 11))
end_representation_2 = jax.random.normal(jax.random.PRNGKey(3), (32, 11))
constraint_start_representation = jax.random.normal(jax.random.PRNGKey(4), (32, 11))
constraint_end_representation = jax.random.normal(jax.random.PRNGKey(5), (32, 11))
print(start_representation.sum())
print(end_representation.sum())
print(start_representation_2.sum())
print(end_representation_2.sum())
print(constraint_start_representation.sum())
print(constraint_end_representation.sum())


40.224
-12.485597
-32.225197
17.889158
14.044348
6.9094186


In [8]:
### old loss computation
# Create the mask in a vectorized way
coeff_vector_mask = jnp.arange(d, 0, -1)
# coeff_vector_mask = (
#     coeff_vector_mask**2
# )  # this is just a test, comment this out later
coeff_vector_mask_col = jnp.expand_dims(coeff_vector_mask, 1)  # Shape: (d, 1)
coeff_vector_mask_row = jnp.expand_dims(coeff_vector_mask, 0)  # Shape: (1, d)
coeff_matrix_mask = jnp.minimum(
    coeff_vector_mask_col, coeff_vector_mask_row
)  # Shape: (d, d)

# Compute graph loss and regularization
approximation_error_loss = compute_approximation_error_loss(
    start_representation, end_representation
)
orthogonality_loss = compute_orthogonality_loss(
    constraint_start_representation,
    start_representation_2,
    end_representation_2,
)
frobenius_norm_loss = compute_frobenius_norm_loss(
    constraint_end_representation, # note: beforehand we were using start_representation. is there an important difference in sampling?
    coeff_matrix_mask,
)

old_loss = (
    FROBENIUS_NORM_GAIN * (approximation_error_loss + orthogonality_loss) +
    FROBENIUS_NORM_BIAS * frobenius_norm_loss
)
print(old_loss)

for approximation error loss: start_representation: (32, 11)
for approximation error loss: end_representation: (32, 11)
batched approximation error loss
for orthogonality loss: representation_1: (32, 11)
for orthogonality loss: representation_2: (32, 11)
batched orthogonality loss
132.10524


In [9]:
### new loss computation
# Create the mask in a vectorized way
coeff_vector_mask = jnp.arange(d, 0, -1)
coeff_matrix_mask = jnp.minimum(
    jnp.expand_dims(coeff_vector_mask, 1), jnp.expand_dims(coeff_vector_mask, 0)
)  # Shape: (d, d)

# V^T V
cov_matrix = constraint_start_representation.T @ constraint_start_representation / constraint_start_representation.shape[0] # [d, d]

# V^T AV
corr_matrix = start_representation.T @ end_representation / start_representation.shape[0] # [d, d]
corr_matrix_2 = start_representation_2.T @ end_representation_2 / start_representation_2.shape[0] # [d, d]
scaled_corr_matrix = coeff_matrix_mask * corr_matrix
scaled_corr_matrix_2 = coeff_matrix_mask * corr_matrix_2

# loss: trace(coeff_matrix_mask * (-2 corr_matrix + cov_matrix @ corr_matrix))
raw_loss = -2 * jnp.trace(scaled_corr_matrix) + jnp.sum(cov_matrix * scaled_corr_matrix_2)

# shift term (decorrelate this too)
frobenius_norm_loss = compute_frobenius_norm_loss(
    constraint_end_representation,
    coeff_matrix_mask,
)

# Compute total loss. NOTE: WITH NON-DEFAULT BIAS AND GAIN, THIS HAS TO BE WITH ORBITALS
loss = FROBENIUS_NORM_GAIN * raw_loss + FROBENIUS_NORM_BIAS * frobenius_norm_loss
print(loss)

132.10522


In [10]:
### check if matrix mask can be applied naively to frobenius norm loss
# || V^TV - I ||_2
# V should be (B, d)
small_mask = jnp.array(
    [[3, 2, 1],
     [2, 2, 1],
     [1, 1, 1]]
)
V = jnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [3.15, 2.9, 8.1], [12.2, -6.9, 1.8]]) # (6, 3)
I = jnp.eye(3)

# brute force
brute_force_loss = 0
for i in range(1, 4):
    brute_force_loss += jnp.sum((V[:, :i].T @ V[:, :i] - I[:i, :i]) ** 2)
print(brute_force_loss)

# vectorized
print(jnp.sum(small_mask * (V.T @ V - I) ** 2))

# holy shit these aren't equal im a retard
# wait no these are equal i am NOT a retard

883047.0
883047.0


In [11]:
def _compute_frobenius_norm_loss(representation):
    """
    Compute the Frobenius norm loss between the representation and the identity matrix.

    Args:
        representation: The representation of the state. Shape (B, d)
        alpha: The alpha parameter.

    Returns:
        The Frobenius norm loss.
    """
    return jnp.sum(
        (
            representation.T @ representation / representation.shape[0]
            - jnp.eye(representation.shape[1])
        )
        ** 2
    )

def compute_approximation_error_loss(
    start_representation, end_representation
):
    """
    Compute the approximation error loss between the start and end representations.
    This is the -2 Sigma <g_l, Tf_l> term, but as we are working with EVD, it is just <f_l, Tf_l>.
    Recall from the derivation that this is equivalent to E[(f_l - Tf_l)^2], where we sample over
    transitions and take an expectation over p(x, x').

    Args:
        start_representation: The representation of the start state.
        end_representation: The representation of the end state.

    Returns:
        The approximation error loss.
    """
    print(
        f"for approximation error loss: start_representation: {start_representation.shape}"
    )
    print(
        f"for approximation error loss: end_representation: {end_representation.shape}"
    )
    # Compute the loss term
    # loss = -2 * ((start_representation - end_representation)**2).mean() # need to recheck this...
    loss = 0
    if len(start_representation.shape) == 1 and len(end_representation.shape) == 1:
        print("one dimensional approximation error loss")

        # computing loss via squares
        # squared_diff = (start_representation - end_representation) ** 2
        # joint_squared_diff = squared_diff
        # loss = -(jnp.sum(joint_squared_diff))

        # computing loss via straight inner product
        inner_prod = jnp.dot(start_representation, end_representation)
        loss = -2 * inner_prod
    else:
        print("batched approximation error loss")
        # shapes (1024, 11) and (1024, 11)

        averaged_inner_prod = (
            jnp.einsum("bi, bi -> i", start_representation, end_representation)
            / start_representation.shape[0]
        )
        loss = -2 * jnp.sum(averaged_inner_prod)

    return loss

def compute_orthogonality_loss(
    representation_1, representation_2, representation_2_end=None
):
    """
    Compute the orthogonality loss between the two representations.

    If we are working with LoRA:
        This is the Sigma_l Sigma_l' <f_l | f_l'> <g_l | g_l'> term.
        Since we are working with EVD, it is just Sigma_l <f_l | f_l'>^2.
    If we are working with OMM:
        This is Sigma_l Sigma_l' <f_l | Tf_l'> <f_l | f_l'> term.

    Args:
        representation_1: The first representation.
        representation_2: The second representation.

    Returns:
        The orthogonality loss.
    """
    loss = 0

    # OMM loss case: Sigma_l Sigma_l' <f_l | Tf_l'> <f_l | f_l'>
    try:
        if len(representation_1.shape) == 1:
            print("one dimensional orthogonality loss")
            product_1 = jnp.einsum(
                "j, k -> jk", representation_2, representation_2_end
            )
            product_2 = jnp.einsum(
                "j, k -> jk", representation_1, representation_1
            )
            loss += jnp.sum(product_1 * product_2)
        else:
            print("batched orthogonality loss")
            product_1 = (
                jnp.einsum(
                    "bj, bk -> jk", representation_2, representation_2_end
                )
                / representation_2.shape[0]
            )
            product_2 = (
                jnp.einsum("bj, bk -> jk", representation_1, representation_1)
                / representation_1.shape[0]
            )
            loss += jnp.sum(product_1 * product_2)
    except:
        print(f"Shape of representation_1: {representation_1.shape}")
        print(f"Shape of representation_2: {representation_2.shape}")
        raise

    return loss

def _build_stop_grad_encoding(
    encoding: jnp.ndarray, top_i: int, **kwargs
) -> jnp.ndarray:
    """
    Builds a top_i state encoding where everything but the top_ith eigenfunction is frozen.

    Args:
        encodings: The state encoding.
        top_i: How many of the top eigenfunctions to use for the loss function.

    Returns:
        The top_i state encoding.
    """
    # encodings are either of shape (d,) or (b, d)
    if len(encoding.shape) == 1:
        mask_function = lambda encoding: jnp.concatenate(
            [
                jax.lax.stop_gradient(encoding[: top_i - 1]),
                encoding[top_i - 1 : top_i],
            ],
            axis=0,
        )
    else:
        mask_function = lambda encoding: jnp.concatenate(
            [
                jax.lax.stop_gradient(encoding[:, : top_i - 1]),
                encoding[:, top_i - 1 : top_i],
            ],
            axis=1,
        )

    return mask_function(encoding)

def _compute_loss_function_component(top_i):
    """
    Computes a single component of the loss function (the top_i loss).

    Args:
        params: The parameters of the model.
        state_encoding: The state encoding.
        top_i: How many of the top eigenfunctions to use for the loss function.

    Returns:
        Tuple of (loss, approximation_error_loss, orthogonality_loss)
    """

    approximation_error_loss = compute_approximation_error_loss(
        _build_stop_grad_encoding(start_representation, top_i),
        _build_stop_grad_encoding(end_representation, top_i),
    )
    orthogonality_loss = compute_orthogonality_loss(
        _build_stop_grad_encoding(constraint_start_representation, top_i),
        (
            _build_stop_grad_encoding(start_representation_2, top_i)
        ),
        representation_2_end=(
            _build_stop_grad_encoding(end_representation_2, top_i)
        ),
    )
    frobenius_norm_loss = _compute_frobenius_norm_loss(
        _build_stop_grad_encoding(start_representation, top_i),
    )
    loss = approximation_error_loss + orthogonality_loss + FROBENIUS_NORM_BIAS * frobenius_norm_loss

    return loss

In [12]:
B = 32
d = 11
FROBENIUS_NORM_GAIN = 1.0
FROBENIUS_NORM_BIAS = 1.0
# create masks
coeff_vector_mask = jnp.arange(d, 0, -1)
coeff_vector_mask = (
    coeff_vector_mask**2
)  # this is just a test, comment this out later
coeff_vector_mask_col = jnp.expand_dims(coeff_vector_mask, 1)  # Shape: (d, 1)
coeff_vector_mask_row = jnp.expand_dims(coeff_vector_mask, 0)  # Shape: (1, d)
coeff_matrix_mask = jnp.minimum(
    coeff_vector_mask_col, coeff_vector_mask_row
)  # Shape: (d, d)

start_representation = jax.random.normal(jax.random.PRNGKey(0), (32, 11))
end_representation = jax.random.normal(jax.random.PRNGKey(1), (32, 11))
start_representation_2 = jax.random.normal(jax.random.PRNGKey(2), (32, 11))
end_representation_2 = jax.random.normal(jax.random.PRNGKey(3), (32, 11))
constraint_start_representation = jax.random.normal(jax.random.PRNGKey(4), (32, 11))
constraint_end_representation = jax.random.normal(jax.random.PRNGKey(5), (32, 11))
print(start_representation.sum())
print(end_representation.sum())
print(start_representation_2.sum())
print(end_representation_2.sum())
print(constraint_start_representation.sum())
print(constraint_end_representation.sum())

40.224
-12.485597
-32.225197
17.889158
14.044348
6.9094186


In [13]:
### MUST RUN TOP CELL BEFORE THIS
# old computation of loss
total_loss = 0
total_approximation_error_loss = 0
total_orthogonality_loss = 0
total_frobenius_norm_loss = 0
for top_i in range(1, d + 1):
    curr_coef = 1 # d - top_i + 1
    curr_loss = _compute_loss_function_component(top_i)
    total_loss += curr_coef * curr_loss
    
print(total_loss)

for approximation error loss: start_representation: (32, 1)
for approximation error loss: end_representation: (32, 1)
batched approximation error loss
batched orthogonality loss
for approximation error loss: start_representation: (32, 2)
for approximation error loss: end_representation: (32, 2)
batched approximation error loss
batched orthogonality loss
for approximation error loss: start_representation: (32, 3)
for approximation error loss: end_representation: (32, 3)
batched approximation error loss
batched orthogonality loss
for approximation error loss: start_representation: (32, 4)
for approximation error loss: end_representation: (32, 4)
batched approximation error loss
batched orthogonality loss
for approximation error loss: start_representation: (32, 5)
for approximation error loss: end_representation: (32, 5)
batched approximation error loss
batched orthogonality loss
for approximation error loss: start_representation: (32, 6)
for approximation error loss: end_representation: 

In [14]:
### NEW COMPUTATION
def _compute_indexwise_products(f, g):
    assert f.shape[0] == g.shape[0]
    return jnp.einsum("bi, bj -> ij", f, g) / f.shape[0]

def _generate_masks(curr_d):
    """
    Generates a monotonically decreasing mask in dimension index - both vector and matrix
    masks. Namely, they look like:
    [d, d - 1, ..., 2, 1]
    [[3, 2, 1]
        [2, 2, 1]
        [1, 1, 1]] (for the d = 3 case)

    Returns:
        Tuple of the form (vector_mask, matrix_mask)
    """
    vector_mask = jnp.arange(curr_d, 0, -1)
    matrix_mask = jnp.minimum(
        jnp.expand_dims(vector_mask, 1), jnp.expand_dims(vector_mask, 0)
    )

    return vector_mask, matrix_mask


In [16]:
envs = ["GridMaze-7", "GridMaze-9", "GridMaze-17", "GridMaze-19", "GridMaze-26", "GridMaze-32", "GridRoom-1", "GridRoom-4", "GridRoom-16", "GridRoom-32", "GridRoom-64", "GridRoomSym-4"]

import numpy as np

for env in envs:
    eigval_path = f"src/env/grid/eigval/{env}.npz"
    try:
        data = np.load(eigval_path)
        eigvals = data["eigval"] if "eigval" in data else data[data.files[0]]
        print(f"{env}: {len(eigvals)} eigenvalues")
    except Exception as e:
        print(f"Could not load eigenvalues for {env}: {e}")


GridMaze-7: 11 eigenvalues
GridMaze-9: 13 eigenvalues
GridMaze-17: 87 eigenvalues
GridMaze-19: 161 eigenvalues
GridMaze-26: 388 eigenvalues
GridMaze-32: 475 eigenvalues
GridRoom-1: 225 eigenvalues
GridRoom-4: 104 eigenvalues
GridRoom-16: 271 eigenvalues
GridRoom-32: 544 eigenvalues
GridRoom-64: 1088 eigenvalues
GridRoomSym-4: 104 eigenvalues
