In [19]:
import numpy as np
from scipy.sparse import csr_array, csc_matrix, coo_matrix
from scipy.sparse.linalg import cg

In [20]:
def build_laplacian_2d(N):
    """Assemble the 5-point finite-difference Laplacian on an N x N grid.

    Grid point (i, j) maps to the flat index i*N + j. Every row carries -4 on
    the diagonal and +1 for each neighbour (up, down, left, right) that lies
    inside the grid; neighbours that would fall outside are simply omitted.

    Parameters
    ----------
    N : int
        Number of grid points per side.

    Returns
    -------
    scipy.sparse.csr_array
        Integer matrix of shape (N*N, N*N).
    """
    rows, cols, vals = [], [], []
    for p in range(N * N):
        i, j = divmod(p, N)
        # (column, value, present?) in the same order as the stencil is read:
        # centre, up, down, left, right.
        stencil = (
            (p, -4, True),
            (p - N, 1, i > 0),
            (p + N, 1, i < N - 1),
            (p - 1, 1, j > 0),
            (p + 1, 1, j < N - 1),
        )
        for q, value, present in stencil:
            if present:
                rows.append(p)
                cols.append(q)
                vals.append(value)
    return csr_array((vals, (rows, cols)), shape=(N * N, N * N))

In [21]:
# 5-point Laplacian on a 3x3 grid -> 9x9 sparse matrix.
A = build_laplacian_2d(N=3)

In [22]:
A_dense = A.toarray()
# scipy's cg (imported above) expects a symmetric operator; guard the stencil.
assert np.array_equal(A_dense, A_dense.T), "Laplacian must be symmetric"
A_dense

array([[-4,  1,  0,  1,  0,  0,  0,  0,  0],
       [ 1, -4,  1,  0,  1,  0,  0,  0,  0],
       [ 0,  1, -4,  0,  0,  1,  0,  0,  0],
       [ 1,  0,  0, -4,  1,  0,  1,  0,  0],
       [ 0,  1,  0,  1, -4,  1,  0,  1,  0],
       [ 0,  0,  1,  0,  1, -4,  0,  0,  1],
       [ 0,  0,  0,  1,  0,  0, -4,  1,  0],
       [ 0,  0,  0,  0,  1,  0,  1, -4,  1],
       [ 0,  0,  0,  0,  0,  1,  0,  1, -4]])

In [23]:
def build_incomplete_poission(N, fill_in=False):
    """Assemble the incomplete-Poisson preconditioner P^{-1} for an N x N grid.

    For the 5-point Laplacian A built by ``build_laplacian_2d(N)`` (diagonal
    -4, neighbours +1), P^{-1} = H @ H.T with H = I - L D^{-1}, where L is the
    strictly lower-triangular part of A and D its diagonal. Since
    -1/diag(A) = 1/4, H has 1 on the diagonal and +1/4 at the two lower
    neighbours (index-1, index-N). Expanding H @ H.T gives:

    * 1/4 at each of the four stencil neighbours;
    * 1 + (1/4)**2 * (number of lower neighbours) on the diagonal;
    * (1/4)**2 at index-N+1 and index+N-1 (diagonal grid neighbours that
      share a lower neighbour) -- fill-in outside the 5-point stencil.

    By default the fill-in is dropped so P^{-1} keeps the sparsity pattern of
    A. Pass ``fill_in=True`` to get exactly H @ H.T, i.e. the same values that
    ``incomplete_poisson_matrix(build_laplacian_2d(N))`` returns densely.

    Parameters
    ----------
    N : int
        Grid points per side; row/column index = i*N + j.
    fill_in : bool, default False
        Include the (1/4)**2 fill-in entries.

    Returns
    -------
    scipy.sparse.csr_array
        Float matrix of shape (N*N, N*N).
    """
    w = 1 / 4  # -1/diag(A) for the 5-point Laplacian (diagonal entries are -4)
    w2 = w * w
    row, col, data = [], [], []

    def add(r, c, v):
        # Append one COO triplet.
        row.append(r)
        col.append(c)
        data.append(v)

    for i in range(N):
        for j in range(N):
            index = i * N + j
            # Stencil neighbours: every in-grid neighbour gets weight 1/4.
            if i > 0:
                add(index, index - N, w)
            if i < N - 1:
                add(index, index + N, w)
            if j > 0:
                add(index, index - 1, w)
            if j < N - 1:
                add(index, index + 1, w)
            # Row `index` of H is non-zero only at itself and its *lower*
            # neighbours, so only those add (1/4)**2 to the diagonal of H @ H.T.
            n_lower = int(i > 0) + int(j > 0)
            add(index, index, 1 + n_lower * w2)
            if fill_in:
                # Shares lower neighbour index-N with grid point (i-1, j+1).
                if i > 0 and j < N - 1:
                    add(index, index - N + 1, w2)
                # Shares lower neighbour index-1 with grid point (i+1, j-1).
                if i < N - 1 and j > 0:
                    add(index, index + N - 1, w2)
    return csr_array((data, (row, col)), shape=(N * N, N * N))

def incomplete_poisson_matrix(A):
    """Return the incomplete-Poisson approximate inverse H @ H.T as a dense array.

    H = I - L D^{-1}, where L is the strictly lower-triangular part of A and
    D = diag(A). H is assembled and multiplied in sparse form, avoiding the
    dense O(n^3) matrix products of a naive implementation; the result is
    still returned as a dense ``numpy.ndarray`` (n x n).

    Parameters
    ----------
    A : array_like or scipy sparse array/matrix, shape (n, n)
        Square system matrix, e.g. the output of ``build_laplacian_2d``.

    Returns
    -------
    numpy.ndarray, shape (n, n)

    Raises
    ------
    ValueError
        If A is not square or has a zero on its diagonal (D^{-1} undefined).
    """
    A_sp = csr_array(A)  # accepts dense arrays as well as any scipy sparse format
    n_rows, n_cols = A_sp.shape
    if n_rows != n_cols:
        raise ValueError(f"A must be square, got shape {A_sp.shape}")
    diag = A_sp.diagonal()
    if np.any(diag == 0):
        raise ValueError("A has a zero on its diagonal; D^{-1} is undefined")
    inv_diag = 1.0 / diag

    # Non-zeros of L: the strictly lower-triangular entries of A.
    coo = A_sp.tocoo()
    lower = coo.row > coo.col
    rows, cols = coo.row[lower], coo.col[lower]

    # H = I - L @ diag(1/d): unit diagonal plus -a_ij * (1/d_j) below it.
    idx = np.arange(n_rows)
    H = csr_array(
        (
            np.concatenate([np.ones(n_rows), -coo.data[lower] * inv_diag[cols]]),
            (np.concatenate([idx, rows]), np.concatenate([idx, cols])),
        ),
        shape=(n_rows, n_rows),
    )
    return (H @ H.T).toarray()

In [24]:
# Derive the grid side from A (an N*N x N*N matrix) instead of repeating the
# literal 3, so Pinv and P = incomplete_poisson_matrix(A) describe the same grid.
grid_n = int(round(np.sqrt(A.shape[0])))
assert grid_n * grid_n == A.shape[0], "A is not an N*N x N*N grid matrix"
Pinv = build_incomplete_poission(grid_n)

In [25]:
Pinv_dense = Pinv.toarray()
# A preconditioner for CG must be symmetric; check it before using it with cg.
assert np.allclose(Pinv_dense, Pinv_dense.T), "Pinv must be symmetric"
Pinv_dense

array([[1.    , 0.25  , 0.    , 0.25  , 0.    , 0.    , 0.    , 0.    ,
        0.    ],
       [0.25  , 1.0625, 0.25  , 0.    , 0.25  , 0.    , 0.    , 0.    ,
        0.    ],
       [0.    , 0.25  , 1.0625, 0.    , 0.    , 0.25  , 0.    , 0.    ,
        0.    ],
       [0.25  , 0.    , 0.    , 1.0625, 0.25  , 0.    , 0.25  , 0.    ,
        0.    ],
       [0.    , 0.25  , 0.    , 0.25  , 1.125 , 0.25  , 0.    , 0.25  ,
        0.    ],
       [0.    , 0.    , 0.25  , 0.    , 0.25  , 1.125 , 0.    , 0.    ,
        0.25  ],
       [0.    , 0.    , 0.    , 0.25  , 0.    , 0.    , 1.0625, 0.25  ,
        0.    ],
       [0.    , 0.    , 0.    , 0.    , 0.25  , 0.    , 0.25  , 1.125 ,
        0.25  ],
       [0.    , 0.    , 0.    , 0.    , 0.    , 0.25  , 0.    , 0.25  ,
        1.125 ]])

In [26]:
# Reference P^{-1} = H @ H.T computed directly from A (dense ndarray).
P = incomplete_poisson_matrix(A=A)

In [27]:
np.testing.assert_allclose(P, P.T)  # H @ H.T is symmetric by construction
P

array([[1.    , 0.25  , 0.    , 0.25  , 0.    , 0.    , 0.    , 0.    ,
        0.    ],
       [0.25  , 1.0625, 0.25  , 0.0625, 0.25  , 0.    , 0.    , 0.    ,
        0.    ],
       [0.    , 0.25  , 1.0625, 0.    , 0.0625, 0.25  , 0.    , 0.    ,
        0.    ],
       [0.25  , 0.0625, 0.    , 1.0625, 0.25  , 0.    , 0.25  , 0.    ,
        0.    ],
       [0.    , 0.25  , 0.0625, 0.25  , 1.125 , 0.25  , 0.0625, 0.25  ,
        0.    ],
       [0.    , 0.    , 0.25  , 0.    , 0.25  , 1.125 , 0.    , 0.0625,
        0.25  ],
       [0.    , 0.    , 0.    , 0.25  , 0.0625, 0.    , 1.0625, 0.25  ,
        0.    ],
       [0.    , 0.    , 0.    , 0.    , 0.25  , 0.0625, 0.25  , 1.125 ,
        0.25  ],
       [0.    , 0.    , 0.    , 0.    , 0.    , 0.25  , 0.    , 0.25  ,
        1.125 ]])