In [8]:
import numpy as np
import validation as validation

In [9]:
#########################################################################
#  Operators
#  Each operator must be a class with forward and transpose methods
###########################################################################

class zero_op:
    """
        Returns 0
    """
    def forward(self,x):
        return 0.0
    def transpose(self,x):
        return 0.0

class matrix_op:
    """
        matrix_op(A) converts the matrix A to a matrix 
        operator and its transpose
    """
    def __init__(self,A):
        self.A=A
        
    def forward(self,x):
        return np.matmul(self.A,x)
    
    def transpose(self,x):
        return np.matmul(np.transpose(self.A),x)
    
    
class scalar_prod_op:
    """
       scalar_prod(a) is the scalar product with a
    """
    def __init__(self,a):
        self.a=a
        
    def forward(self,x):
        return self.a*x
    
    def transpose(self,x):
        return self.forward(x)
    
class hadamard_op:
    """
       Hadamard product. Multiply every entry of 
       x with the corresponding entry of A
    """
    def __init__(self,A):
        self.A=A
    def forward(self,x):
        return self.A*x
    def transpose(self,x):
        return self.forward(x)
    
class real_fftn_op:
    """
        FFT of real valued vectors and matrices
    """
    def forward(self,x):
        return np.fft.fftshift(np.fft.fftn(np.real(x)))
    def transpose(self,x):
        return np.real(np.fft.ifftn(np.fft.ifftshift(x)))
    
    
class composite_op:
    """
       Creates a composite operator 
       Uses mathematical notation
       That is, composite_op(A,B,C) is the operator ABC, i.e. C operates first
       followed by B followed by A
       composite_op accepts any number of operators as arguments
    """
    def __init__(self,*ops):
        self.ops=ops
        
    def forward(self,x):
        for op in reversed(self.ops):
            x=op.forward(x)
        return x
    
    def transpose(self,x):
        for op in self.ops:
            x=op.transpose(x)
        return x

    
class add_op:
    """
        Creates a sum operator
        e.g. add_op(A,B,C) gives the operator A+B+C
    """
    def __init__(self,*ops):
        self.ops=ops

    def forward(self,x):
        y=0
        for op in self.ops:
            y += op.forward(x)
        return y
    
    def transpose(self,x):
        y=0
        for op in self.ops:
            y += op.transpose(x)
        return y   

In [10]:
##########################################################
#  Utility functions for conjugate gradient
##########################################################

def inner_prod(x,y):
    """
        real and complex inner product
        Works for vectors and matrices
    """
    return np.real(np.sum(np.conj(x)*y))

def norm_sq(x):
    """
        Euclidean norm sq for vectors
        Frobenius norm sq for matrices
        Works for real and complex
    """
    return inner_prod(x,x)

def norm(x):
    """
        Euclidean norm for vectors
        Frobenius norm for matrices
        Works for real and complex
    """
    return np.sqrt(norm_sq(x))
    
            
def lhs_op(A,B):
    """
        Creates the operator on the LHS of the CG linear equation
        i.e. creates (A^TA+B^TB)
    """
    def lhs(x):
        return A.transpose(A.forward(x))+B.transpose(B.forward(x))
    return lhs
    
##################################################################
#    Conjugate gradient minimization with L2 regularization
##################################################################
    
def c_grad(y,A,x,B=zero_op(),max_iter=100,f_tol=1e-5):
    """
     Conjugate gradient algorithm for minimizing ||y-Ax||^2 + ||Bx||^2
     The second term is a regulizer, and is optional
     The termination criteria are
         either k=num_steps >= max_iter 
         or ||y-Ax||<= f_tol*||y|| (Scipy criterion)
         
    The minimizer is found by solving the linear equation
            (A^TA + B^TB)x=A^Ty
    
    Inputs: y,A,B as defined by the objective function above. A and B
                 have to be operators with forward and transpose defined
            max_iter: maximum number of iterations
            f_tol: as defined above
    
    Output: tuple (x,flag)
        where x is the solution
             flag=1 if max_iter are reached, else 0
    """
    # Mathematical comments below correspond to Wikipedia CG formulae
    # written in Latex
    # See https://en.wikipedia.org/wiki/Conjugate_gradient_method
    
    A_star=lhs_op(A,B) #A_star is the operator A^*=(A^TA + B^TB)
    b=A.transpose(y)  #A^Ty
    r_k=b-A_star(x)
    p_k=r_k
    x_k=x
    #initialize iteration
    k=0
    t=f_tol*norm(b)
    res_norm_sq=norm_sq(r_k)
    while ((k<max_iter) & (np.sqrt(res_norm_sq)>=t)):
        AtAp=A_star(p_k) #Precalculate to save flops
        alpha_k= res_norm_sq/inner_prod(p_k,AtAp) # alpha_k = r^T_kr_k/p^T_k A^*p_k
        x_k1=x_k+alpha_k*p_k                      # x_{k+1}=x_k+\alpha_k p_k
        r_k1=r_k-alpha_k*AtAp                     # r_{k+1}= r_{k}-\alpha_k A^*p_k
        beta_k=norm_sq(r_k1)/res_norm_sq          # \beta_k = r^T_{k+1}r_{k+1}/r^T_kr_k
        p_k1=r_k1+beta_k*p_k                      # p_{k+1}= r_{k+1}+\beta_k p_k
        #update
        k=k+1
        x_k=x_k1
        p_k=p_k1
        r_k=r_k1
        res_norm_sq=norm_sq(r_k)
        print(f"CG: step={k} res_norm={np.sqrt(res_norm_sq)} ")
    return x_k,k>=max_iter
        

### Test 1: Simple test without noise
#### The operator A is (MO+R). Task is to recover a vector. No regularization on CG

In [32]:
M=matrix_op(np.array(((1, 2, 3),(0, 1, 0),(2, 3,1))))
O=scalar_prod_op(2.0)
R=scalar_prod_op(0.1)
A=add_op(composite_op(M,O),R) 

x=np.array((1,0.5,1)).reshape(-1,1) #Input x
y=A.forward(x) #Create noiseless output by applying A
print(f"x={x}\n y={y}")

x0=np.zeros_like(x) #Initial value for CG
x_cg,flag=c_grad(y,A,x0)
print(f"x={x_cg} \n flag={flag}")
validation.validate_equal(x, x_cg, "Original", "CG", rtol=0.1)

x=[[1. ]
 [0.5]
 [1. ]]
 y=[[10.1 ]
 [ 1.05]
 [ 9.1 ]]
CG: step=1 res_norm=3.234651420766811 
CG: step=2 res_norm=0.7480204131885724 
CG: step=3 res_norm=6.8985257123446856e-12 
x=[[1. ]
 [0.5]
 [1. ]] 
 flag=False
Original and CG are equal


### Test 2: x is matrix, y contains noise, CG uses regularizer

In [35]:
M=matrix_op(np.array(((1, 2, 3),(0, 1, 0),(2, -3,1))))
O=scalar_prod_op(2.0)
R=scalar_prod_op(0.1)

A=add_op(composite_op(M,O),R)

x=np.transpose(np.array(((1,0.5,1),(0,1,0))))
y=A.forward(x)
y=y+0.1*np.random.normal(size=y.shape)
print(f"x={x}\n y={y}")

x0=np.zeros_like(x)
x_cg,flag=c_grad(y,A,x0,B=R,f_tol=1e-8)
print(f"x={x_cg} \n flag={flag}")
validation.validate_equal(x, x_cg, "Original", "CG", rtol=0.1)

x=[[1.  0. ]
 [0.5 1. ]
 [1.  0. ]]
 y=[[10.03063349  3.91314989]
 [ 1.00018442  2.12118154]
 [ 3.17879495 -5.98570569]]
CG: step=1 res_norm=8.195664905134663 
CG: step=2 res_norm=0.27537692472762665 
CG: step=3 res_norm=1.432877941932448e-13 
x=[[ 0.97121412  0.03073669]
 [ 0.47357632  1.00754367]
 [ 1.01595754 -0.03024619]] 
 flag=False
Original and CG are not equal


### CG with fft

In [45]:
x=np.random.normal(size=(4,4))
print(x)
A=real_fftn_op()
y=A.forward(x)
y=y+0.2*np.random.normal(size=y.shape)
#print(f"x={x}\n y={y}")
print(f"x={x}")

x0=np.zeros_like(x)
R=scalar_prod_op(0.1)
x_cg,_=c_grad(y,A,x0,B=R)

print(x_cg)
validation.validate_equal(x, x_cg, "Original", "CG", rtol=0.1)

[[-1.0051926   1.19587362 -0.14977859  0.72965053]
 [ 1.33682123  0.47527541  0.09691704  0.70387281]
 [-0.80815115  0.90791419  0.66868599  0.30079857]
 [ 0.94451918  0.16271802  0.52889598  0.62236485]]
x=[[-1.0051926   1.19587362 -0.14977859  0.72965053]
 [ 1.33682123  0.47527541  0.09691704  0.70387281]
 [-0.80815115  0.90791419  0.66868599  0.30079857]
 [ 0.94451918  0.16271802  0.52889598  0.62236485]]
CG: step=1 res_norm=3.976416080268664e-16 
[[-1.02778266  1.21297435 -0.16908533  0.75136733]
 [ 1.31031303  0.50369473  0.05136426  0.72634982]
 [-0.80635189  0.92797281  0.56864629  0.32686823]
 [ 0.92189516  0.190553    0.47906619  0.64932783]]
Original and CG are not equal
