In [1]:
import functools
from collections import namedtuple
from collections.abc import Generator, Iterator
from typing import Callable, Dict, List, Tuple

import numpy as np
from numpy import linalg as LA

def jacobi_step(A, b, x):
    """Perform one Jacobi sweep for ``A x = b`` and return the new iterate.

    ``A`` is split into its diagonal ``D`` and off-diagonal remainder ``R``;
    the update is ``(b - R x) / D`` elementwise. Expects a square ``A``;
    a zero diagonal entry yields inf/nan from the division.
    """
    diagonal = np.diag(A)
    off_diagonal = A - np.diag(diagonal)
    coupling = np.einsum('ij,j->i', off_diagonal, x)
    return (b - coupling) / diagonal

def quadratic_form(A, b, x):
    """Return ``0.5 * x^T A x + x^T b``.

    NOTE(review): the linear term is *added*. The usual energy whose minimiser
    solves ``A x = b`` is ``0.5 x^T A x - x^T b``. In the visible cells only
    differences of this value are used (as a convergence measure), so the
    sign is kept as written.
    """
    curvature = np.einsum('i,ij,j->', x, A, x)
    linear = np.einsum('i,i->', x, b)
    return 0.5 * curvature + linear

def relative_error_to_reference(x, x_ref):
    """Return ``||x - x_ref|| / ||x_ref||`` using ``numpy.linalg.norm`` defaults.

    A zero ``x_ref`` makes the denominator vanish, so NumPy yields inf/nan.
    """
    error_norm = LA.norm(x - x_ref)
    reference_norm = LA.norm(x_ref)
    return error_norm / reference_norm

def compute_residual_norm(A, b, x):
    """Return the norm of the residual ``b - A x`` (``numpy.linalg.norm`` default)."""
    residual = b - np.einsum('ij,j->i', A, x)
    return LA.norm(residual)

In [2]:
class Criterion:
    """A threshold predicate: ``compare(value)`` is ``comparison(value, threshold)``.

    ``message`` is a format string that receives ``threshold`` once, at
    construction time.
    """

    def __init__(self, threshold, comparison, message):
        self.threshold, self.comparison = threshold, comparison
        self.message = message.format(threshold)

    def compare(self, value):
        """Apply the stored comparison to an already-computed ``value``."""
        predicate, limit = self.comparison, self.threshold
        return predicate(value, limit)
    
class Stat:
    """A named statistic of an iterate that also acts as a stopping test.

    ``compute`` maps an iterate to a value, and ``compare`` evaluates
    ``comparison(compute(iterate), threshold)``. ``header`` and ``fmt`` are the
    column title and format spec used by ``IterativeSolver`` when tabulating.
    ``message`` is formatted with ``threshold`` once, at construction.
    """

    def __init__(self,
                 name: str,
                 header: str,
                 fmt: str,
                 compute: Callable,
                 threshold: float,
                 comparison: Callable,
                 message: str) -> None:
        # Presentation: identifier, column title and format spec.
        self.name, self.header, self.fmt = name, header, fmt
        # Evaluation: how to obtain the value and how to judge it.
        self.compute = compute
        self.threshold, self.comparison = threshold, comparison
        self.message = message.format(threshold)

    def compare(self, iterate):
        """Compute this stat for ``iterate`` and test it against the threshold."""
        value = self.compute(iterate)
        return self.comparison(value, self.threshold)

def check_termination(composition, criteria, value):
    """Combine several stopping predicates for one iterate.

    Parameters
    ----------
    composition : str
        How the predicates are combined: 'any' (disjunctive) or 'all'
        (conjunctive). Anything else raises ``ValueError``.
    criteria : iterable
        Objects exposing ``compare(value) -> bool`` (the Criterion interface).
        Evaluation is lazy and short-circuits like ``any``/``all``.
    value :
        The current iterate, passed unchanged to every ``compare``.

    Returns
    -------
    (index, terminate)
        For 'any', the pair reported by ``any_with_index``. For 'all', 0 and
        the conjunction.
    """
    if composition == 'any':
        first_hit, terminate = any_with_index(c.compare(value) for c in criteria)
        return first_hit, terminate
    if composition == 'all':
        return 0, all(c.compare(value) for c in criteria)
    raise ValueError('Composition of predicates \'{:s}\' not recognized.\n Only \'any\' and \'all\' are allowed.'.format(composition))

def any_with_index(iterable):
    """Like ``any``, but also report where the first truthy element is.

    Consumes ``iterable`` only up to the first truthy element and returns
    ``(its_index, True)``. If nothing is truthy, returns
    ``(index_of_last_element, False)``, or ``(0, False)`` for an empty
    iterable. The index is only meaningful when the flag is True.
    """
    # Pre-bind so an empty iterable does not leave ``idx`` unbound; without
    # this, ``return idx, False`` raised UnboundLocalError. 0 matches the
    # default index used by check_termination.
    idx = 0
    for idx, element in enumerate(iterable):
        if element:
            return idx, True
    return idx, False

In [3]:
class Iterate:
    """An iteration counter paired with the current value ``x``.

    ``x`` is a plain get/set/delete property over ``_x``. ``update`` advances
    the counter by one and replaces ``x``.
    NOTE(review): not used by the visible cells, which pass dicts to the solver.
    """

    def __init__(self, iteration, x):
        self._iteration = iteration
        self._x = x

    def _get_x(self):
        return self._x

    def _set_x(self, new_x):
        self._x = new_x

    def _del_x(self):
        del self._x

    x = property(_get_x, _set_x, _del_x, doc="Current value of the iterate.")
    # The accessors only need to live inside ``x``; drop them from the class.
    del _get_x, _set_x, _del_x

    def update(self, new_x):
        """Advance to the next iteration with value ``new_x``."""
        self._iteration += 1
        self._x = new_x

In [4]:
class IterativeSolver(Iterator):
    """Drive an iterative method one step per ``next()`` call, tabulating stats.

    Parameters
    ----------
    stepper : callable
        Maps the current iterate to the next one.
    start_guess :
        Initial iterate; must support ``start_guess['iteration']``.
    failure_stats : iterable of Stat
        Checked on the current iterate before stepping. If *any* holds,
        ``exception`` is raised with that stat's ``message`` and no step is taken.
    success_stats : iterable of Stat
        Checked on the new iterate after stepping. Once *all* hold,
        ``StopIteration`` ends the iteration.
    exception : type
        Exception class raised on failure.

    Every stat must provide ``name``, ``header``, ``fmt``, ``compute`` and
    ``compare``. A table header is printed on construction. Each ``next()``
    call prints one row holding the stats of the iterate it started from.
    """

    def __init__(self, stepper, start_guess,
                 failure_stats,
                 success_stats,
                 exception):
        self._niterations = start_guess['iteration']
        self._stepper = stepper
        self._iterate = start_guess
        # Materialise as lists so tuples or generators also work: the stats
        # are concatenated and iterated more than once.
        self._failure_stats = list(failure_stats)
        self._success_stats = list(success_stats)
        self._exception = exception
        self._stats = {s.name: s.compute(start_guess) for s in self._all_stats()}
        print(self._header())

    def _all_stats(self):
        """Return every tracked stat, failure stats first (the column order)."""
        return self._failure_stats + self._success_stats

    def _header(self):
        """Return the header row (one centred 20-wide column per stat) plus a rule."""
        stats = self._all_stats()
        titles = ''.join('{:^20s}'.format(s.header) for s in stats)
        return titles + '\n' + '=' * 20 * len(stats)

    def _stat_line(self):
        """Return one table row built from the currently cached stat values."""
        # ``fmt`` is a replacement field such as '{:.3E}'. Stripping braces and
        # colons leaves the bare spec, which is then centred in 20 columns.
        # Formatting each value directly, instead of building a
        # '{stats[<name>]...}' template, keeps names containing '[', ']', '{'
        # or only digits from being misparsed by str.format.
        return ''.join(
            format(self._stats[s.name], '^20' + s.fmt.strip('{:}'))
            for s in self._all_stats())

    def update_stats(self):
        """Recompute every stat for the current iterate."""
        for s in self._all_stats():
            self._stats[s.name] = s.compute(self._iterate)

    def __next__(self):
        """Take one step and return the new iterate.

        Raises
        ------
        self._exception
            If a failure stat holds for the current iterate (no step is taken).
        StopIteration
            Once every success stat holds for the new iterate.
        """
        try:
            idx, failed = check_termination('any', self._failure_stats, self._iterate)
            if failed:
                raise self._exception(self._failure_stats[idx].message)
            self._iterate = self._stepper(self._iterate)
        finally:
            # Always report the iterate this call started from, including on
            # failure, so the row that triggered it is visible.
            print(self._stat_line())
        # Deliberately outside ``finally``: raising StopIteration there would
        # discard an in-flight failure (or stepper) exception. Bookkeeping
        # happens only after a step was actually taken.
        self.update_stats()
        self._niterations += 1
        _, succeeded = check_termination('all', self._success_stats, self._iterate)
        if succeeded:
            raise StopIteration
        return self._iterate

In [5]:
# Reproducible test problem: a random symmetric positive-definite system A x = b.
SEED = 42
rng = np.random.default_rng(SEED)  # fixed seed -> Restart & Run All reproduces the table

dim = 1000
M = rng.standard_normal((dim, dim))
# Symmetrise M.
A = 0.5 * (M + M.transpose())
# NOTE: `*` is the elementwise (Hadamard) product, not `A @ A.T`. It squares
# every entry, so A stays symmetric with non-negative entries.
A = A * A.transpose()
# The diagonal shift is what makes A strictly diagonally dominant. That in
# turn makes it positive definite (Gershgorin, given symmetry and a positive
# diagonal) and guarantees Jacobi converges. Checked below, not assumed.
A += dim * np.eye(dim)
b = rng.random(dim)
x_ref = LA.solve(A, b)

A_diag = np.diag(A)
assert np.array_equal(A, A.T), "A must be symmetric"
assert np.all(A_diag > 0), "A must have a positive diagonal"
assert np.all(2 * A_diag > np.abs(A).sum(axis=1)), "A must be strictly diagonally dominant"

In [6]:
# Threshold predicates for the stats below: each receives the computed value
# and the stat's threshold.
def lt_cmp(value, threshold):
    """True when ``value`` is strictly below ``threshold``."""
    return value < threshold


def gt_cmp(value, threshold):
    """True when ``value`` is strictly above ``threshold``."""
    return value > threshold


# Failure stat: trips once the iterate's 'iteration' entry exceeds 25.
it_count = Stat(
    'iteration counter', '# it.', '{:d}',
    lambda it: it['iteration'],
    threshold=25,
    comparison=gt_cmp,
    message='Maximum number of iterations ({0:d}) exceeded',
)

# Success stat: norm of b - A x for the iterate's 'x' (uses the global A, b).
residual_norm = Stat(
    'norm of residual', '||r||', '{:.3E}',
    lambda it: compute_residual_norm(A, b, it['x']),
    threshold=1.0e-5,
    comparison=lt_cmp,
    message='Residual norm below threshold {0:.1E}',
)

def something(A, b, x, E):
    """Debug helper: print the current and previous pseudoenergy and
    return the absolute difference between them.

    NOTE(review): not called in the visible cells (``denergy`` inlines the
    same difference without printing); candidate for removal.
    """
    current = quadratic_form(A, b, x)
    print('current E ', current)
    print('previous E ', E)
    return abs(current - E)

# Success stat: change in pseudoenergy between consecutive iterates. Each
# iterate carries 'E', the quadratic form of the *previous* x (see stepper),
# so Delta E is 0 for the initial guess. The residual criterion, also
# required, prevents stopping there.
denergy = Stat('pseudoenergy', 'Delta E', '{:.3E}',
               lambda it: abs(quadratic_form(A, b, it['x']) - it['E']),
               threshold=1.0e-5,
               comparison=lt_cmp,
               message='Pseudoenergy variation below threshold {0:.1E}')


def stepper(iterate):
    """One Jacobi step: bump the counter, update x, remember the old x's energy."""
    previous_x = iterate['x']
    return {
        'iteration': iterate['iteration'] + 1,
        'x': jacobi_step(A, b, previous_x),
        'E': quadratic_form(A, b, previous_x),
    }


x_0 = np.zeros_like(b)
guess = {'iteration': 0, 'x': x_0, 'E': quadratic_form(A, b, x_0)}
jacobi2 = IterativeSolver(stepper, guess, [it_count], [residual_norm, denergy], RuntimeError)

# First converge to a loose threshold
for _ in jacobi2:
    pass

print('jacobi2._niterations ', jacobi2._niterations)
rel_err = relative_error_to_reference(jacobi2._iterate['x'], x_ref)
print('Jacobi relative error to reference {:.5E}\n'.format(rel_err))

       # it.               ||r||              Delta E       
         0               1.817E+01           0.000E+00      
         1               7.884E+00           5.566E-01      
         2               3.946E+00           2.631E-01      
         3               1.975E+00           1.203E-01      
         4               9.889E-01           6.311E-02      
         5               4.951E-01           3.086E-02      
         6               2.479E-01           1.563E-02      
         7               1.241E-01           7.781E-03      
         8               6.212E-02           3.907E-03      
         9               3.110E-02           1.953E-03      
         10              1.557E-02           9.785E-04      
         11              7.795E-03           4.897E-04      
         12              3.902E-03           2.452E-04      
         13              1.954E-03           1.227E-04      
         14              9.781E-04           6.145E-05      
         15             