In [99]:
import numpy as np

from basic_ops import gen_matrix

In [105]:
def svd(C: np.ndarray, iters=100):
    CTC = C.T @ C
    eigen_values, eigen_vectors = np.linalg.eig(CTC)
    indxs = np.argsort(eigen_values)[::-1]
    eigen_values = eigen_values[indxs]
    eigen_vectors = eigen_vectors[:, indxs]

    SIGMA = np.diag(np.sqrt(eigen_values))
    SIGMA_INV = np.diag(1/np.sqrt(eigen_values))

    V = eigen_vectors

    for x in V:
        for y in V:
            if np.linalg.norm(x-y) < 0.001: 
                pass
            else: 
                assert x@y , 0.0001

    O = np.empty((len(V), len(V)))
    for i, x in enumerate(V):
        for j, y in enumerate(V):
            O[i][j] = x@y

    print(O)


    U = C @ V @ SIGMA_INV

    return U, SIGMA, V.T


def demo(n: int, m: int):
    A = gen_matrix(n, m)
    
    if n < m:
        A = A.T

    U, SIGMA, V = svd(A)
    u, sigma, v = np.linalg.svd(A, full_matrices=True)

    print(f'U: \n{U}')
    print(f'u: \n{u}')
    print(f'SIGMA: \n{SIGMA}')
    print(f'sigma: \n{sigma}')
    print(f'V: \n{V}')
    print(f'v: \n{v}')

    resA = U @ SIGMA @ V
    print(f'A:\n{A}')  
    print(f'resA:\n{resA}')

    assert np.linalg.norm(A - resA) < 0.001

    for x, y in zip(U.T, u.T):
        assert np.linalg.norm(x - y) < 0.001 or np.linalg.norm(x + y) < 0.001
    
    for x, y in zip(V, v):
        assert np.linalg.norm(x - y) < 0.001 or np.linalg.norm(x + y) < 0.001

    if n < m:
        V.T, SIGMA, U.T

    return U, SIGMA, V

demo(3, 3)

[[ 1.00000000e+00  3.46944695e-16  2.77555756e-16]
 [ 3.46944695e-16  1.00000000e+00 -8.04911693e-16]
 [ 2.77555756e-16 -8.04911693e-16  1.00000000e+00]]
U: 
[[-0.56826071 -0.77169259 -0.28560516]
 [-0.59670192  0.6254638  -0.50273438]
 [-0.56659208  0.11526305  0.81589695]]
u: 
[[-0.56826071  0.77169259 -0.28560516]
 [-0.59670192 -0.6254638  -0.50273438]
 [-0.56659208 -0.11526305  0.81589695]]
SIGMA: 
[[1.76110928 0.         0.        ]
 [0.         0.31753153 0.        ]
 [0.         0.         0.10093385]]
sigma: 
[1.76110928 0.31753153 0.10093385]
V: 
[[-0.7462262  -0.6295889  -0.21625049]
 [-0.53230693  0.75941292 -0.37408734]
 [-0.39974466  0.16404214  0.90182836]]
v: 
[[-0.7462262  -0.6295889  -0.21625049]
 [ 0.53230693 -0.75941292  0.37408734]
 [-0.39974466  0.16404214  0.90182836]]
A:
[[0.88875848 0.43926025 0.28208476]
 [0.6987429  0.80410691 0.10719159]
 [0.69220545 0.66952644 0.27635691]]
resA:
[[0.88875848 0.43926025 0.28208476]
 [0.6987429  0.80410691 0.10719159]
 [0.6922

(array([[-0.56826071, -0.77169259, -0.28560516],
        [-0.59670192,  0.6254638 , -0.50273438],
        [-0.56659208,  0.11526305,  0.81589695]]),
 array([[1.76110928, 0.        , 0.        ],
        [0.        , 0.31753153, 0.        ],
        [0.        , 0.        , 0.10093385]]),
 array([[-0.7462262 , -0.6295889 , -0.21625049],
        [-0.53230693,  0.75941292, -0.37408734],
        [-0.39974466,  0.16404214,  0.90182836]]))