In [1]:
import torch

def _inner_product(u, v):
    """Conjugating inner product <u, v> = sum(conj(u) * v) over all elements.

    Flattening lets tensors of any shape be used; conjugating the first argument
    makes CG correct for complex Hermitian systems (identical to ``torch.dot``
    for real 1-D tensors).
    """
    return torch.vdot(u.reshape(-1), v.reshape(-1))


def pcg(A, b, M_inv=None, x0=None, tol=1e-8, max_iter=1000):
    """
    Preconditioned Conjugate Gradient (PCG) solver.

    Solves ``A x = b`` for a symmetric (Hermitian) positive-definite linear
    operator ``A``. The optional preconditioner ``M^{-1}`` must also be
    symmetric (Hermitian) positive-definite for the method to be valid.

    Args:
        A (callable): A function that computes the matrix-vector product A(x).
            It must accept and return tensors shaped like ``b``.
        b (torch.Tensor): The right-hand side. Any shape is accepted; inner
            products are taken over all elements.
        M_inv (callable, optional): Preconditioner. A function that computes M^{-1}(x).
            If None, no preconditioner is used (identity).
        x0 (torch.Tensor, optional): Initial guess for the solution. If None, a zero tensor is used.
        tol (float): Absolute tolerance on the 2-norm of the (recursively updated) residual.
        max_iter (int): Maximum number of iterations.

    Returns:
        x (torch.Tensor): Approximate solution to A x = b. If convergence is not
            reached within ``max_iter`` iterations, the last iterate is returned.
    """
    if x0 is None:
        x0 = torch.zeros_like(b)

    if M_inv is None:
        # No preconditioner: M^{-1} = I
        M_inv = lambda x: x

    # Initializations
    x = x0
    r = b - A(x)  # Initial residual

    # If the initial guess already satisfies the tolerance (e.g. b == 0 with the
    # default zero guess), stop here: otherwise r^H z and p^H A p are both zero
    # and the first step size would be 0/0 = NaN.
    if torch.norm(r) < tol:
        print("Converged in 0 iterations.")
        return x

    z = M_inv(r)  # Preconditioned residual
    p = z.clone()  # Initial search direction
    rsold = _inner_product(r, z)  # <r, M^{-1} r>

    for i in range(max_iter):
        Ap = A(p)
        alpha = rsold / _inner_product(p, Ap)  # Step size along p

        # Update the solution and the residual (recursive residual update
        # avoids an extra application of A per iteration).
        x = x + alpha * p
        r = r - alpha * Ap

        if torch.norm(r) < tol:
            print(f"Converged in {i+1} iterations.")
            break

        # Apply preconditioner to the new residual
        z = M_inv(r)

        # Fletcher-Reeves style coefficient keeps new direction A-conjugate
        rsnew = _inner_product(r, z)
        beta = rsnew / rsold

        p = z + beta * p
        rsold = rsnew

    return x

In [3]:

# Example Usage
if __name__ == "__main__":
    # Small symmetric positive-definite example matrix, built once rather than
    # on every matrix-vector product.
    A_matrix = torch.tensor([[4.0, 1.0], [1.0, 3.0]])

    def A(x):
        """Matrix-vector product with the example matrix."""
        return A_matrix @ x

    # Right-hand side vector b
    b = torch.tensor([1.0, 2.0])

    # Jacobi preconditioner: M = diag(A), so M^{-1} x divides by the diagonal.
    # Derived from A_matrix so the two cannot drift apart.
    M_diag = torch.diagonal(A_matrix)

    def M_inv(x):
        """Apply the Jacobi preconditioner M^{-1} = diag(A)^{-1}."""
        return x / M_diag

    # Initial guess
    x0 = torch.tensor([2.0, 1.0])

    # Solve A @ x = b using Preconditioned CG with the preconditioner and
    # initial guess defined above.
    x_solution = pcg(A, b, M_inv=M_inv, x0=x0, tol=1e-8, max_iter=1000)

    print("Solution x:", x_solution)

    # Sanity check against a direct dense solve.
    x_direct = torch.linalg.solve(A_matrix, b)
    assert torch.allclose(x_solution, x_direct, atol=1e-5), (x_solution, x_direct)

Converged in 2 iterations.
Solution x: tensor([0.0909, 0.6364])
