In [1]:
%matplotlib inline

from typing import *
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import functools
import math
import random


from collections import namedtuple
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn, optim
import torch.nn.functional as F
import torch
import functools
import os

```
x = U S V^T
xn = U V^T
xn.T @ x = V S V^T
x @ xn.T = U S U^T
```


In [13]:
x = torch.randn(345, 543).to(torch.bfloat16).cuda()

In [136]:
from torch import Tensor
from fractions import Fraction

import logging
logger: logging.Logger = logging.getLogger('test_muon')

def _matrix_root_eigen(
    A: Tensor,
    root: Union[Fraction, int],
    epsilon: float = 0.0,
    exponent_multiplier: float = 1.0,
    make_positive_semidefinite: bool = True,
    retry_double_precision: bool = True,
) -> Tuple[Tensor, Tensor, Tensor]:
    """Compute matrix inverse root using eigendecomposition of symmetric positive (semi-)definite matrix.

            A^{-1/r} = Q L^{-1/r} Q^T

    Assumes matrix A is symmetric.

    Args:
        A (Tensor): Square matrix of interest.
        root (int): Root of interest. Any natural number.
        epsilon (float): Adds epsilon * I to matrix before taking matrix root. (Default: 0.0)
        exponent_multiplier (float): exponent multiplier in the eigen method (Default: 1.0)
        make_positive_semidefinite (bool): Perturbs matrix eigenvalues to ensure it is numerically positive semi-definite. (Default: True)
        retry_double_precision (bool): Flag for re-trying eigendecomposition with higher precision if lower precision fails due
            to CuSOLVER failure. (Default: True)

    Returns:
        X (Tensor): (Inverse) root of matrix. Same dimensions as A.
        L (Tensor): Eigenvalues of A.
        Q (Tensor): Orthogonal matrix consisting of eigenvectors of A.

    """

    # check if root is positive integer
    if root <= 0:
        raise ValueError(f"Root {root} should be positive!")

    # compute matrix power
    alpha = -exponent_multiplier / root

    # compute eigendecomposition and compute minimum eigenvalue
    try:
        L, Q = torch.linalg.eigh(A)

    except Exception as exception:
        if retry_double_precision and A.dtype != torch.float64:
            logger.warning(
                f"Failed to compute eigendecomposition in {A.dtype} precision with exception {exception}! Retrying in double precision..."
            )
            L, Q = torch.linalg.eigh(A.double())
        else:
            raise exception

    lambda_min = torch.min(L)

    # make eigenvalues >= 0 (if necessary)
    if make_positive_semidefinite:
        L += -torch.minimum(lambda_min, torch.as_tensor(0.0))

    # add epsilon
    L += epsilon

    # compute inverse preconditioner
    X = Q * L.pow(alpha).unsqueeze(0) @ Q.T

    return X, L, Q

In [178]:
torch.set_float32_matmul_precision('high')

xn_long = muon_renorm.zeropower_via_newtonschulz5(
    x.float(), steps=18,
    abc=muon_renorm.make_schedule(8, 2, 8),
    dtype=torch.float32,
)

xn = muon_renorm.zeropower_via_newtonschulz5(
    x, steps=5,
)
U, S, V = x.double().svd()
xn_cheat = (U @ V[..., :xn.shape[0]].T).to(torch.bfloat16)


In [81]:
_xn = x.float().svd().U @ x.float().svd().V.T
_xn = _xn.to(torch.bfloat16)

In [85]:
(xn_long - _xn).abs().max()

tensor(0.0010, device='cuda:0', dtype=torch.bfloat16)

In [60]:
x.float().svd().V @ x.float().svd().S.diag() @ x.float().svd().V.T

tensor([[13.3335, -0.3297,  0.4140,  ..., -0.0603,  0.1835, -0.0317],
        [-0.3297, 14.3710,  0.0423,  ...,  0.4852,  0.3412, -0.9058],
        [ 0.4140,  0.0423, 12.7372,  ...,  0.8726,  0.2346,  0.9779],
        ...,
        [-0.0603,  0.4852,  0.8726,  ..., 12.9309,  1.3419, -0.3291],
        [ 0.1835,  0.3412,  0.2346,  ...,  1.3419, 14.7677, -0.6373],
        [-0.0317, -0.9058,  0.9779,  ..., -0.3291, -0.6373, 12.6676]],
       device='cuda:0')

In [61]:
_temp

tensor([[12.6875, -0.3086,  0.4766,  ..., -0.1875,  0.2734, -0.0513],
        [-0.3184, 13.5625, -0.0713,  ...,  0.4961,  0.2021, -0.7969],
        [ 0.4590, -0.0762, 12.1250,  ...,  0.5625,  0.0645,  0.8711],
        ...,
        [-0.1924,  0.5117,  0.5664,  ..., 12.1875,  1.0156, -0.3223],
        [ 0.2617,  0.2334,  0.0703,  ...,  1.0234, 14.0000, -0.7930],
        [-0.0645, -0.8008,  0.8594,  ..., -0.3008, -0.7773, 11.9375]],
       device='cuda:0', dtype=torch.bfloat16)

In [141]:
xn.T @ x

tensor([[12.8125, -0.3535,  0.4297,  ..., -0.2324,  0.2852, -0.0718],
        [-0.3633, 13.8125, -0.0957,  ...,  0.4473,  0.2129, -0.8047],
        [ 0.4375, -0.0422, 12.3125,  ...,  0.5859,  0.0566,  0.8398],
        ...,
        [-0.2139,  0.4844,  0.6328,  ..., 12.3750,  1.0391, -0.3047],
        [ 0.2949,  0.2168,  0.0581,  ...,  1.0469, 14.1875, -0.7969],
        [-0.0267, -0.8086,  0.8516,  ..., -0.2949, -0.8242, 12.1250]],
       device='cuda:0', dtype=torch.bfloat16)

In [143]:
_matrix_root_eigen(x @ xn.T, 2, epsilon=1e-12)


Failed to compute eigendecomposition in torch.bfloat16 precision with exception "linalg_eigh_cuda" not implemented for 'BFloat16'! Retrying in double precision...


(tensor([[ 2.5655e-01, -5.5646e-03,  4.4095e-03,  ...,  5.1869e-03,
          -4.6841e-03,  5.9985e-03],
         [-5.5646e-03,  2.5677e-01,  1.4152e-03,  ..., -2.3429e-03,
           7.6731e-04, -2.8111e-03],
         [ 4.4095e-03,  1.4152e-03,  2.5578e-01,  ...,  3.9613e-03,
           4.4327e-03, -6.1242e-04],
         ...,
         [ 5.1869e-03, -2.3429e-03,  3.9613e-03,  ...,  2.6039e-01,
           1.1045e-04, -3.8337e-03],
         [-4.6841e-03,  7.6731e-04,  4.4327e-03,  ...,  1.1045e-04,
           2.4128e-01,  2.3840e-03],
         [ 5.9985e-03, -2.8111e-03, -6.1242e-04,  ..., -3.8337e-03,
           2.3840e-03,  2.4497e-01]], device='cuda:0', dtype=torch.float64),
 tensor([ 3.7722,  4.0749,  4.4163,  4.8264,  4.9394,  5.1295,  5.2883,  5.3255,
          5.5114,  5.5390,  5.5397,  5.5595,  5.5858,  5.5981,  5.6070,  5.6165,
          5.6262,  5.6388,  5.6643,  5.6742,  5.6976,  5.7026,  5.7205,  5.7441,
          5.7581,  5.7875,  5.8012,  5.8225,  5.8370,  5.8377,  5.8633,  

In [185]:
_matrix_root_eigen(xn.T @ x, 2, epsilon=1e-6)[0].svd().S


Failed to compute eigendecomposition in torch.bfloat16 precision with exception "linalg_eigh_cuda" not implemented for 'BFloat16'! Retrying in double precision...


tensor([1.0000e+03, 1.4783e+01, 9.5640e+00, 7.6839e+00, 6.0208e+00, 5.6663e+00,
        5.2810e+00, 4.9966e+00, 4.6449e+00, 4.5857e+00, 4.2507e+00, 4.1610e+00,
        4.0042e+00, 3.9881e+00, 3.9507e+00, 3.7453e+00, 3.6993e+00, 3.6328e+00,
        3.5699e+00, 3.5113e+00, 3.4616e+00, 3.4022e+00, 3.3537e+00, 3.3207e+00,
        3.2948e+00, 3.1739e+00, 3.1164e+00, 3.1013e+00, 3.0609e+00, 2.9961e+00,
        2.9704e+00, 2.9365e+00, 2.9206e+00, 2.8937e+00, 2.8411e+00, 2.8160e+00,
        2.8048e+00, 2.7577e+00, 2.7483e+00, 2.7209e+00, 2.7112e+00, 2.6802e+00,
        2.6138e+00, 2.6037e+00, 2.5827e+00, 2.5814e+00, 2.5549e+00, 2.5319e+00,
        2.5116e+00, 2.4976e+00, 2.4842e+00, 2.4593e+00, 2.4305e+00, 2.4083e+00,
        2.3956e+00, 2.3949e+00, 2.3691e+00, 2.3569e+00, 2.3419e+00, 2.3243e+00,
        2.3080e+00, 2.2944e+00, 2.2883e+00, 2.2756e+00, 2.2584e+00, 2.2527e+00,
        2.2404e+00, 2.2180e+00, 2.2151e+00, 2.2106e+00, 2.1897e+00, 2.1705e+00,
        2.1556e+00, 2.1511e+00, 2.1481e+

In [100]:
x.double().svd().V @ x.double().svd().S.pow(-1).diag() @ x.double().svd().V.T


tensor([[ 0.0359, -0.0028,  0.0004,  ..., -0.0027,  0.0005,  0.0005],
        [-0.0028,  0.0418,  0.0038,  ...,  0.0005,  0.0003, -0.0004],
        [ 0.0004,  0.0038,  0.0367,  ..., -0.0014, -0.0019, -0.0018],
        ...,
        [-0.0027,  0.0005, -0.0014,  ...,  0.0428,  0.0019,  0.0011],
        [ 0.0005,  0.0003, -0.0019,  ...,  0.0019,  0.0393,  0.0018],
        [ 0.0005, -0.0004, -0.0018,  ...,  0.0011,  0.0018,  0.0394]],
       device='cuda:0', dtype=torch.float64)

In [152]:
S

tensor([41.4167, 40.8357, 40.4905, 40.3890, 39.9276, 39.6737, 39.5611, 39.2824,
        39.1043, 38.9433, 38.7521, 38.6324, 38.5083, 38.2517, 37.9836, 37.8075,
        37.6660, 37.5274, 37.4257, 37.2361, 37.0193, 36.9609, 36.6816, 36.5950,
        36.3440, 36.2487, 36.1401, 35.9429, 35.6863, 35.5173, 35.3303, 35.2366,
        35.1712, 34.9059, 34.8248, 34.7916, 34.6254, 34.5272, 34.3327, 34.1458,
        34.0266, 33.8670, 33.6637, 33.5804, 33.5521, 33.4323, 33.3144, 33.2121,
        33.0976, 32.9569, 32.8536, 32.7094, 32.6467, 32.4949, 32.3802, 32.1979,
        32.1814, 32.0908, 31.9404, 31.8065, 31.7277, 31.5810, 31.3712, 31.2353,
        31.0953, 30.9324, 30.7981, 30.7030, 30.6192, 30.5217, 30.4655, 30.3025,
        30.1895, 30.1083, 30.0822, 29.8451, 29.7453, 29.6892, 29.5754, 29.4528,
        29.3788, 29.2818, 29.1536, 29.1182, 28.9359, 28.7137, 28.6792, 28.5064,
        28.4297, 28.3591, 28.2697, 28.1397, 28.0334, 27.9820, 27.8972, 27.7871,
        27.6875, 27.6282, 27.5785, 27.52

In [150]:
xn_cheat

tensor([[ 0.0315, -0.0615, -0.0310,  ..., -0.0308, -0.0479,  0.0107],
        [-0.0889,  0.0217, -0.0330,  ..., -0.0718, -0.0649, -0.0146],
        [ 0.0198, -0.0374,  0.0043,  ..., -0.0030,  0.0070,  0.0767],
        ...,
        [ 0.0205,  0.0352,  0.0014,  ...,  0.0171,  0.0087, -0.0125],
        [ 0.0027, -0.0142, -0.0194,  ..., -0.0037,  0.0219, -0.0708],
        [-0.0005,  0.0299,  0.0952,  ..., -0.0581,  0.0430, -0.0143]],
       device='cuda:0', dtype=torch.bfloat16)

In [182]:
(V @ S.diag() @ V.T)

tensor([[13.3335, -0.3297,  0.4139,  ..., -0.0602,  0.1834, -0.0317],
        [-0.3297, 14.3710,  0.0423,  ...,  0.4852,  0.3412, -0.9059],
        [ 0.4139,  0.0423, 12.7372,  ...,  0.8727,  0.2347,  0.9779],
        ...,
        [-0.0602,  0.4852,  0.8727,  ..., 12.9310,  1.3420, -0.3291],
        [ 0.1834,  0.3412,  0.2347,  ...,  1.3420, 14.7677, -0.6373],
        [-0.0317, -0.9059,  0.9779,  ..., -0.3291, -0.6373, 12.6676]],
       device='cuda:0', dtype=torch.float64)

In [207]:
L, info = torch.linalg.cholesky_ex((x @ xn.T).float())
(torch.cholesky_inverse(L).to(torch.bfloat16) @ x).float().svd().S, info

(tensor([1.5198, 1.5137, 1.5106, 1.5045, 1.5026, 1.5012, 1.4986, 1.4957, 1.4925,
         1.4909, 1.4895, 1.4835, 1.4821, 1.4805, 1.4796, 1.4752, 1.4713, 1.4667,
         1.4616, 1.4594, 1.4552, 1.4516, 1.4481, 1.4434, 1.4394, 1.4312, 1.4210,
         1.4194, 1.4151, 1.4128, 1.4030, 1.3972, 1.3905, 1.3811, 1.3728, 1.3685,
         1.3670, 1.3575, 1.3534, 1.3472, 1.3428, 1.3352, 1.3242, 1.3181, 1.3101,
         1.3072, 1.3038, 1.2989, 1.2931, 1.2866, 1.2718, 1.2708, 1.2688, 1.2672,
         1.2654, 1.2637, 1.2622, 1.2617, 1.2597, 1.2590, 1.2579, 1.2575, 1.2545,
         1.2533, 1.2515, 1.2509, 1.2500, 1.2489, 1.2475, 1.2463, 1.2456, 1.2443,
         1.2420, 1.2405, 1.2393, 1.2378, 1.2360, 1.2351, 1.2329, 1.2305, 1.2298,
         1.2272, 1.2257, 1.2248, 1.2220, 1.2209, 1.2189, 1.2148, 1.2142, 1.2123,
         1.2114, 1.2099, 1.2085, 1.2061, 1.2048, 1.2033, 1.2012, 1.1999, 1.1980,
         1.1979, 1.1942, 1.1937, 1.1885, 1.1860, 1.1850, 1.1844, 1.1805, 1.1783,
         1.1767, 1.1743, 1.1