In [5]:
import torch

def apply_preconditioner(M, r, tol=1e-6, max_iter=10):
    z = torch.zeros_like(r)  # Initial guess for z
    for _ in range(max_iter):
        # Jacobi preconditioner: Solve Mz = r approximately
        z_new = r / torch.diag(M)  # Jacobi step
        if torch.norm(z_new - z) < tol:
            break
        z = z_new
    return z

def conjugate_gradient(A, b, M, tol=1e-6, max_iter=100):
    
    x = torch.zeros_like(b)  # Initial guess (zero vector)
    r = b - A @ x            # Initial residual
    z = apply_preconditioner(M, r)  # Preconditioned residual
    p = z                    # Initial search direction

    for i in range(max_iter):
        Ap = A @ p           # Matrix-vector product A*p
        alpha = torch.dot(r, z) / torch.dot(p, Ap)  # Step size
        
        # Update the solution
        x = x + alpha * p
        
        # Update residual
        r_new = r - alpha * Ap
        
        # Check for convergence
        if torch.norm(r_new) < tol:
            print(f"Converged in {i} iterations")
            break
        
        # Apply the preconditioner iteratively
        z_new = apply_preconditioner(M, r_new)
        
        # Compute beta for next search direction
        beta = torch.dot(r_new, z_new) / torch.dot(r, z)
        
        # Update the search direction
        p = z_new + beta * p
        
        # Update residuals and preconditioned residuals for next iteration
        r = r_new
        z = z_new
        
    return x

# Example usage

# Define problem: A*x = b
N = 100
A = torch.randn(N, N)
A = A @ A.T  # Ensure A is symmetric positive definite
b = torch.randn(N)

# Define a preconditioner (Jacobi in this case, i.e., diagonal of A)
M = torch.diag(torch.diag(A))

# Solve using CG with an iterative preconditioner
x_approx = conjugate_gradient(A, b, M)
print(f"Solution: {x_approx}")


Solution: tensor([ 7.7344e-01,  4.3687e+00,  6.9564e+00, -2.0197e+01, -1.5900e+01,
        -1.2269e+00, -6.9439e-01, -1.7931e-02, -1.8110e+01, -7.2649e+00,
         6.1214e-01, -8.4489e-01,  5.9477e+00, -8.8599e+00,  8.4666e-01,
        -2.6092e+01,  2.0729e+01, -7.9087e+00, -3.8577e+00, -1.4499e+01,
        -1.4657e+01, -1.3788e+01, -1.3956e+01, -2.6199e+01, -7.2052e+00,
        -3.5897e+00, -1.3182e+01,  1.2401e+00, -2.0529e+00,  2.6556e+01,
         1.2634e+01, -1.8660e+00,  1.5714e+01,  1.2464e+00,  6.1699e-01,
         9.0474e+00,  1.8609e+01, -7.6115e+00,  9.2485e+00, -1.1184e+01,
         2.4532e+00,  1.4679e+01,  1.6107e+01, -8.3286e+00,  1.8670e+00,
        -1.5432e+01,  4.3211e+00, -4.2589e+00,  1.2108e+01, -2.6069e+01,
         1.0527e+01,  3.1457e+01, -9.6584e-01,  5.1985e+00, -6.6098e+00,
        -2.2560e+01,  7.8750e+00,  9.7475e+00, -1.7390e+01, -1.7717e+00,
        -1.2219e+01, -1.0057e+01,  7.5301e+00,  4.8581e+00,  3.5889e-01,
         1.0312e+01,  8.2844e+00, -4.6939