In [221]:
import numpy as np


eps = 1e-4

def DEF_decompose(A: np.array):
    D = np.diag(np.diag(A))
    E = np.triu(A, 1)
    F = np.tril(A, -1)
    return D, E, F

def process(A: np.array, b: np.array, omega=1.0):
    D, E, F = DEF_decompose(A)
    def L(omega):
        return np.linalg.inv(D / omega + E) * (D * (1 - omega) / omega - F)
    
    def max_eigval_abs(x):
        return np.max(np.abs(np.linalg.eigvals(x)))
    
    # omega = 0.8
    # while (max_eigval_abs(L(omega)) >= 1 and omega > 0):
    #     omega += 0.1


    
    def J():
        return np.linalg.inv(D) * (-E -F)
    
    # omega = 2 / (1 + np.sqrt(1 - max_eigval_abs(J()) ** 2)) 
    # print(omega)   

    diag = np.diag(A)
    b_star = b / diag
    u_curr = np.zeros(A.shape[0]).astype(np.float64)
    u_prev = np.zeros(A.shape[0]).astype(np.float64)
    
    
    while (np.any(np.isfinite(u_curr)) and np.linalg.norm(A @ u_curr - b) > eps):
        u_curr = (1 - omega)  * u_prev + omega * b_star
        for j in range(A.shape[0]):
            u_curr[j] -= omega * (F[j] * u_curr + E[j] * u_prev).sum() / diag[j]
        
        u_prev = u_curr

    return u_curr
    

In [247]:
def is_pos_def(x):
    return np.all(np.linalg.eigvals(x) > 0)

def max_eigval_abs(x):
    return np.max(np.abs(np.linals.eigvals(x)))


for i in range(2, 5):
    A = np.random.randn(i, i).astype(np.float64)
    A += A.T
    while (np.linalg.matrix_rank(A) < i or np.min(np.abs(np.diag(A))) < eps or not is_pos_def(A)):
        A = np.random.randn(i, i).astype(np.float64)
        A += A.T

    assert np.all(A == A.T)
    # print(np.diag(A))
    b = np.random.randn(i).astype(np.float64)
    # print(process(A, b, 1000))
    # print(np.linalg.solve(A, b))

    solution = np.linalg.solve(A, b)
    omega = 0.1
    while omega < 2:
        result = process(A, b, omega)
        if (np.any(np.isnan(result)) or np.mean(np.abs(solution - result)) > eps):
            # print(result)
            # print(solution)
            # print(np.linalg.norm(A @ solution - b))
            pass#print(f"bad system solver {i} {omega}")
        else:
            print(f"solved {i} {omega}")
            print(result)
            print(solution)
            break
        omega += 0.05


solved 2 0.1
[1.69354319 0.01806643]
[1.69365205 0.01799716]
solved 3 0.9000000000000002
[0.35424033 1.27158001 0.95509955]
[0.35430816 1.2716926  0.95521731]
solved 4 1.4500000000000006
[ 8.31715231  2.00248157 -5.10687911 -1.56453296]
[ 8.31734597  2.00250339 -5.10698655 -1.564567  ]
