In [1]:
import functools
import operator
from collections import namedtuple
# The ABCs live in collections.abc; the aliases in `collections` were removed in Python 3.10.
from collections.abc import Generator, Iterator

import numpy as np
from numpy import linalg as LA

class Iteration(object):
    """Re-iterable counter over the inclusive range ``low .. high``.

    Every call to ``iter()`` starts a fresh count from ``low``.
    """

    def __init__(self, low, high):
        self.low, self.high = low, high

    def __iter__(self):
        """Yield low, low + 1, ... for as long as ``high >= value`` holds."""
        current = self.low
        while True:
            # The upper bound is re-read on every pass, exactly like a `while` condition.
            if not (self.high >= current):
                return
            yield current
            current += 1
            
# Demo: an Iteration object can be consumed directly by a for loop (both bounds included).
it = Iteration(0, 15)
for num in it:
    print(num, end=' ')
print()

# Immutable record for one step of an iterative solve.
Iterate1 = namedtuple('Iterate', 'iteration, rnorm, x')

# A stopping rule: a predicate, whether tripping it is an error, and the exception type to use.
StoppingCriterion = namedtuple('StoppingCriterion', 'predicate, is_exceptional, exception')

maxit_exceeded = StoppingCriterion(lambda it: it > 25, True, RuntimeError)

rnorm_above_rtol = StoppingCriterion(lambda rnorm: rnorm > 1.0e-8, False, RuntimeError)

criteria = [maxit_exceeded, rnorm_above_rtol]

# "any" over the criteria, all evaluated on the same sample value 5:
# stop probing as soon as one predicate fires.
exit_iteration = False
for criterion in criteria:
    if criterion.predicate(5):
        exit_iteration = True
        break
print(exit_iteration)
    
class LinearIteration(object):
    """Iterable that repeatedly applies `stepper` to the current iterate.

    Attributes:
        stepper: callable mapping the current vector to the next one.
        iterate: latest iterate (needs `.iteration`, `.rnorm`, `.x`); replaced after every step.
        max_it: iteration index at which the loop gives up.
        niterations: number of steps performed through this object.
    """

    def __init__(self, stepper, iterate, max_it=25):
        self.stepper, self.iterate = stepper, iterate
        self.max_it = max_it
        self.niterations = 0

    def __iter__(self):
        """Yield one `Iterate1` per step.

        Raises:
            RuntimeError: when `max_it` is reached without the consumer stopping the loop.
        """
        # Resume from the iterate's own index: this object may be a restart.
        step_index = self.iterate.iteration
        try:
            while self.max_it > step_index:
                previous = self.iterate
                candidate = self.stepper(previous.x)
                # "rnorm" is the norm of the update between two consecutive vectors.
                update_norm = LA.norm(candidate - previous.x)
                step_index += 1
                self.niterations += 1
                self.iterate = Iterate1(iteration=step_index, rnorm=update_norm, x=candidate)
                print('{0:d}  {1:.3E}'.format(self.iterate.iteration, self.iterate.rnorm))
                yield self.iterate
            raise RuntimeError('Maximum number of iterations ({0:d}) exceeded without meeting stopping criteria'.format(self.max_it))
        except RuntimeError:
            # clean up/checkpoint actions before dying
            print('cleaning up the mess')
            raise
        finally:
            # clean up/checkpoint actions; also runs when the generator is closed early
            print('finally!')

0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 
True


In [2]:
def jacobi_step(A, b, x):
    """Perform one Jacobi sweep for the linear system ``A x = b``.

    Splits A into its diagonal D and off-diagonal remainder O and returns
    ``(b - O x) / D`` (element-wise division by the diagonal entries).
    """
    diagonal = np.diag(A)
    off_diagonal = A - np.diagflat(diagonal)
    coupling = np.einsum('ij,j->i', off_diagonal, x)
    return (b - coupling) / diagonal

def quadratic_form(A, b, x):
    """Return the scalar ``1/2 x.A.x + x.b``.

    NOTE(review): the functional minimised by the solution of ``A x = b`` is usually
    written ``1/2 x.A.x - x.b``; confirm that the ``+`` sign here is intended.
    """
    quadratic_term = 0.5 * np.einsum('i,ij,j->', x, A, x)
    linear_term = np.einsum('i,i->', x, b)
    return (quadratic_term + linear_term)

In [3]:
dim = 1000
# Fixed seed so that "Restart & Run All" reproduces the same system and iteration history.
SEED = 0
np.random.seed(SEED)
M = np.random.randn(dim, dim)
# Build a symmetric matrix with a heavy diagonal.
# NOTE: `*` below is the element-wise (Hadamard) product, not a matrix product, so this
# squares every entry of the symmetrised matrix. Presumably the `dim * I` shift is what
# makes A diagonally dominant (hence SPD and Jacobi-convergent) -- TODO confirm.
A = 0.5 * (M + M.transpose())
A = A * A.transpose()
A += dim * np.eye(dim)
b = np.random.rand(dim)
# Direct solve, used as the reference for the iterative solutions.
x_ref = LA.solve(A, b)

def relative_error_to_reference(x, x_ref):
    """Return ``|x - x_ref| / |x_ref|`` using the default `LA.norm`."""
    deviation = LA.norm(x - x_ref)
    reference_size = LA.norm(x_ref)
    return deviation / reference_size

In [4]:
print('Jacobi algorithm')
#x_jacobi = jacobi(A, b)
# Bind the system (A, b) so that the solver only has to supply the current vector x.
stepper = functools.partial(jacobi_step, A, b)
energy = functools.partial(quadratic_form, A, b)
# Start from the zero vector. The initial rnorm is the residual |b - A x_0|; the iterates
# produced by LinearIteration store |x_next - x| in the same field instead.
x_0 = np.zeros_like(b)
x_jacobi = Iterate1(iteration=0, x=x_0, rnorm=LA.norm(b - np.einsum('ij,j->i', A, x_0)))
jacobi = LinearIteration(stepper, iterate=x_jacobi)

# First converge to a loose threshold
# Breaking out of the loop leaves the generator suspended at its `yield`; its `finally`
# clause (which prints "finally!") runs when the generator object is finalised.
rtol=1.0e-4
for x_jacobi in jacobi:
    if x_jacobi.rnorm < rtol:
        break

print(jacobi.niterations)
print('Jacobi relative error to reference {:.5E}\n'.format(
        relative_error_to_reference(x_jacobi.x, x_ref)))

# Then take latest iterate and restart
# The restart resumes from x_jacobi.iteration, so the default max_it budget is shared with
# the first run; niterations counts only the steps taken by the new object.
restarted_jacobi = LinearIteration(stepper, iterate=x_jacobi)
rtol=1.0e-7
for x_jacobi in restarted_jacobi:
    if x_jacobi.rnorm < rtol:
        break

print(restarted_jacobi.niterations)
print('Jacobi relative error to reference {:.5E}\n'.format(
        relative_error_to_reference(x_jacobi.x, x_ref)))

Jacobi algorithm
1  1.777E-02
2  7.637E-03
3  3.820E-03
4  1.912E-03
5  9.566E-04
6  4.787E-04
7  2.395E-04
8  1.199E-04
9  5.998E-05
finally!
9
Jacobi relative error to reference 1.46443E-03

10  3.001E-05
11  1.502E-05
12  7.515E-06
13  3.760E-06
14  1.882E-06
15  9.416E-07
16  4.712E-07
17  2.358E-07
18  1.180E-07
19  5.904E-08
finally!
10
Jacobi relative error to reference 1.44147E-06



In [5]:
# Example from https://www.usenix.org/system/files/login/articles/12_beazley-online.pdf
from contextlib import contextmanager

@contextmanager
def manager():
    """Context manager that logs entry and reports how the managed block ended.

    Yields the string "SomeValue" as the `as` target. An `Exception` raised in the
    block is reported and re-raised; a clean exit prints a confirmation.
    """
    # Setup: everything before the yield plays the role of __enter__.
    print("Entering")
    try:
        yield "SomeValue"
    except Exception as exc:
        # Teardown on failure: everything after the yield plays the role of __exit__.
        print("An error occurred: %s" % exc)
        raise
    # Reached only when the managed block finished without raising.
    print("No errors occurred")

In [6]:
# Demonstration of the error path: int('whatever') raises ValueError on purpose, so the
# manager prints its "An error occurred" message and re-raises; `x` is never bound.
with manager() as val:
    print("Hello, world")
    print(val)
    x = int('whatever')

Entering
Hello, world
SomeValue
An error occurred: invalid literal for int() with base 10: 'whatever'


ValueError: invalid literal for int() with base 10: 'whatever'

In [None]:
class BoundedRepeater(Iterator):
    """Iterator that returns the same `value` at most `max_repeats` times.

    Every `__next__` call prints a clean-up line; the call that exhausts the
    iterator additionally prints the StopIteration message before re-raising it.
    """

    def __init__(self, value, max_repeats):
        self.value, self.max_repeats = value, max_repeats
        self.count = 0

    def __next__(self):
        try:
            if self.count < self.max_repeats:
                self.count += 1
                return self.value
            raise StopIteration('Exceeded maximum number of repeats!!!')
        except StopIteration as exhausted:
            print(exhausted)
            raise
        finally:
            # Runs on both the normal return and the exhausted path.
            print('cleaning up, finally!')

# Iterator that hands out 'Hello' at most three times.
repeater = BoundedRepeater('Hello', 3)

# Each next() call returns 'Hello' and prints the clean-up line; only the call made after
# three values have been returned raises StopIteration.
#next(repeater)

from typing import Dict, List, Tuple, Callable

class Iterate:
    """One step of an iterative solve: iteration index, current vector and statistics.

    The solver classes in this notebook read `.iteration`, `.x` and `.rnorm` and build
    iterates with an `rnorm=` keyword, so both are supported here.
    """

    def __init__(self, iteration: int, x: List[float], stats: Dict[str, float] = None,
                 rnorm: float = None) -> None:
        """
        stats is a dictionary containing the statistics (e.g. norm of residual) for the current iterate.
        The key, value pair is the name and value of the statistics.
        When omitted, an empty dictionary is created.

        rnorm is a shortcut: when given, it is stored in stats under the key 'rnorm'
        (in the dictionary passed by the caller, if one was passed).
        """
        self._iteration = iteration
        self._x = x
        self._stats = {} if stats is None else stats
        if rnorm is not None:
            self._stats['rnorm'] = rnorm

    @property
    def iteration(self) -> int:
        """Index of this iterate."""
        return self._iteration

    @property
    def x(self):
        """Current solution vector."""
        return self._x

    @property
    def stats(self) -> Dict[str, float]:
        """Dictionary of named statistics for this iterate."""
        return self._stats

    @property
    def rnorm(self) -> float:
        """The 'rnorm' statistic; raises AttributeError if it was never recorded."""
        try:
            return self._stats['rnorm']
        except KeyError:
            raise AttributeError("no 'rnorm' statistic recorded for this iterate") from None

    def __str__(self) -> str:
        """
        This is used for print(iterate), so mostly in debugging and we want full info: iteration number, vector and stats
        FIXME we possibly want also print_out to print a line in the final report
        """
        lines = ['Iteration number {0:2d}'.format(self._iteration),
                 'x = {0}'.format(self._x)]
        lines.extend('{0} = {1}'.format(name, value) for name, value in self._stats.items())
        return '\n'.join(lines)

    def compute_stats(self, funcs: Dict[str, Callable[..., float]]) -> None:
        """
        Apply `funcs` on `_x` to update `_stats`
        FIXME: I think funcs and stats should be both members, so we avoid having them out of sync...
        """
        # Iterate over (name, function) pairs; iterating the dict itself yields keys only.
        for key, f in funcs.items():
            self._stats[key] = f(self._x)
        

class Criterion:
    """A threshold test bundled with the exception that reports it.

    Attributes:
        threshold: reference value handed to `comparison`.
        comparison: callable ``(value, threshold) -> bool``.
        exception: exception class raised by `throw`.
        message: `message` template with the threshold already formatted in.
    """

    def __init__(self, threshold, comparison, exception, message):
        self.threshold, self.comparison = threshold, comparison
        self.exception = exception
        self.message = message.format(threshold)

    def compare(self, value):
        """Return the result of ``comparison(value, threshold)``."""
        test = self.comparison
        return test(value, self.threshold)

    def throw(self):
        """Raise `exception` carrying the formatted message."""
        error = self.exception(self.message)
        raise error

# Failure criterion: fires once the iteration index is above 25.
maxit_exceeded = Criterion(threshold=25,
                           comparison=(lambda value, threshold: value > threshold),
                           exception=RuntimeError,
                           message='Maximum number of iterations ({0:d}) exceeded')

# Success criterion (despite the name): satisfied when the residual norm drops BELOW the threshold.
rnorm_above_rtol = Criterion(threshold=1.0e-10,
                             comparison=(lambda value, threshold: value < threshold),
                             exception=RuntimeError,
                             message='Residual norm below threshold {0:.1E}')

# Success criterion: satisfied when the absolute energy change drops below the threshold.
# The second lambda parameter must be spelled `threshold`; with the earlier typo the body
# looked up an undefined global and raised NameError on the first call.
denergy_below_etol = Criterion(threshold=1.0e-3, 
                              comparison=(lambda value, threshold: abs(value) < threshold),
                              exception=RuntimeError,
                              message='Energy difference below threshold {0:.1E}')

# This is any on a list of custom predicates
def check_failure(predicates, values):
    """Test each predicate against the value at the same position.

    The first predicate whose `compare` is truthy has its `throw()` invoked.
    Returns False when no predicate fires.
    """
    position = 0
    for predicate in predicates:
        if predicate.compare(values[position]):
            # throw() is expected to raise by itself; the outer `raise` is kept as-is.
            raise predicate.throw()
        position += 1
    return False
    
# This is all() on a list of custom predicates
def check_success(predicates, values):
    """Return True only if every predicate accepts the value at the same position.

    Evaluation stops at the first predicate whose `compare` is falsy.
    """
    for index, predicate in enumerate(predicates):
        if predicate.compare(values[index]):
            continue
        return False
    return True

class IterativeSolver(Iterator):
    """Iterator that advances an iterate with `stepper` until success or failure.

    Attributes:
        stepper: callable mapping the current vector to the next one.
        iterate: latest iterate (needs `.iteration`, `.x`, `.rnorm`); replaced after every step.
        failures: criteria checked against ``[iterate.iteration]`` before each step.
        successes: criteria checked against ``[rnorm, denergy]`` after each step.
        niterations: number of steps performed through this object.

    Relies on the module-level `energy` callable to evaluate the energy change of a step.
    """

    def __init__(self, stepper, iterate, failures, successes):
        self.stepper = stepper
        self.iterate = iterate
        self.failures = failures
        self.successes = successes
        self.niterations = 0
        # Once the success criteria are met the iterator stays exhausted.
        self._converged = False
        # Print iteration header
        self._header()

    def _header(self):
        print('   # It.     |r|           |dE|')
        print('-----------------------------------')

    def _success_messages(self):
        """Return the messages of all success criteria, one per line."""
        return '\n'.join(criterion.message for criterion in self.successes)

    def __next__(self):
        """Perform one step and return the new iterate.

        Raises:
            StopIteration: when every success criterion is met (the converged iterate
                stays available as `self.iterate`), and on any later call.
            Exception: whatever a tripped failure criterion raises.
        """
        if self._converged:
            raise StopIteration
        # Only the failure check is guarded: an error raised by the stepper itself must not
        # be reported as "criteria not met", and KeyboardInterrupt must pass through untouched.
        try:
            check_failure(self.failures, [self.iterate.iteration])
        except Exception:
            print('Success criteria not met within maximum number of iterations')
            print(self._success_messages())
            raise

        # We do not start from zero as this might be a restart
        previous = self.iterate
        x_next = self.stepper(previous.x)
        rnorm = LA.norm(x_next - previous.x)
        denergy = energy(x_next) - energy(previous.x)
        self.niterations += 1
        self.iterate = Iterate(iteration=previous.iteration + 1, rnorm=rnorm, x=x_next)

        # Print information on iteration
        print('   {0:2d}      {1:.3E}     {2:.3E}'.format(self.iterate.iteration, self.iterate.rnorm, abs(denergy)))
        # Check for success outside any `finally`, so it can never replace an in-flight error.
        if check_success(self.successes, [self.iterate.rnorm, denergy]):
            print(self._success_messages())
            self._converged = True
            raise StopIteration
        return self.iterate
            
# Start again from the zero vector; the initial rnorm is the residual |b - A x_0|.
# NOTE(review): this call passes `rnorm=`; make sure the Iterate constructor in scope
# accepts that keyword.
x_0 = np.zeros_like(b)
x_jacobi = Iterate(iteration=0, x=x_0, rnorm=LA.norm(b - np.einsum('ij,j->i', A, x_0)))
jacobi2 = IterativeSolver(stepper, x_jacobi, [maxit_exceeded], [rnorm_above_rtol, denergy_below_etol])

# Drive the solver until it stops by itself. The loop body is empty because the solver
# prints its own progress; the latest iterate is read back from jacobi2.iterate below.
for _ in jacobi2:
    pass 
        
print('jacobi2.niterations ', jacobi2.niterations)
print('Jacobi relative error to reference {:.5E}\n'.format(relative_error_to_reference(jacobi2.iterate.x, x_ref)))

In [None]:
class Fibonacci(Generator):
    """Endless Fibonacci sequence (0, 1, 1, 2, ...) built on the Generator ABC."""

    def __init__(self):
        self.a = 0
        self.b = 1

    def send(self, ignored_arg):
        """Return the current term and advance the pair; the sent value is ignored."""
        current, self.a, self.b = self.a, self.b, self.a + self.b
        return current

    def throw(self, type=None, value=None, traceback=None):
        """Whatever is thrown in, end the sequence by raising StopIteration."""
        raise StopIteration
        
class IterGen(Generator):
    """Counter from `start` to `stop` (inclusive) built on the Generator ABC.

    Each `send`/`next` returns the current counter and then advances it. Once the
    counter is past `stop`, RuntimeError is raised. A clean-up line is printed on
    every call, whether it returns or raises.
    """

    def __init__(self, start, stop):
        self.start = start
        self.stop = stop
        # Begin counting at `start`; previously the argument was stored but ignored
        # and the counter always began at 0.
        self.counter = start

    def send(self, value):
        """Return the current counter value; the sent `value` is ignored.

        Raises:
            RuntimeError: when the counter has moved past `stop`.
        """
        try:
            if self.counter > self.stop:
                self.throw(RuntimeError, val='Exceeded maximum number of repeats!!!')
            return self.counter
        except RuntimeError as e:
            print(e)
            raise
        finally:
            # Checkpoint, if necessary
            # Check early exit convergence criteria and stop
            # Update counter (runs after the return value above has been evaluated)
            self.counter += 1
            print('cleaning up, finally!')

    def throw(self, type, val=None, tb=None):
        # Throw if no convergence achieved and maximum number of iterations reached
        raise type(val)
        
class IterativeSolver2(Generator):
    """Generator-ABC variant of the iterative solver.

    `send` performs one step with `stepper`, stores the new iterate in `self.iterate`
    and returns it. `iteration_stop` and `early_exit` are read through their
    `.threshold` and `.message` attributes only.
    """
    def __init__(self, stepper, iterate, iteration_stop, early_exit):
        self.stepper = stepper
        self.iterate = iterate
        self.iteration_stop = iteration_stop
        self.early_exit = early_exit
        # Number of steps taken through this object (independent of iterate.iteration).
        self.niterations = 0
        # Print iteration header
        self._header()

    def _header(self):
        print('   # It.     |r|   ')
        print('-------------------')
        
    def send(self, value):
        """Advance by one step and return the new iterate; the sent `value` is ignored."""
        # We do not start from zero as this might be a restart
        counter = self.iterate.iteration
        try:
            # The budget is checked against the steps taken by this object, not against
            # the iterate's own index.
            if self.niterations > self.iteration_stop.threshold:
                self.throw(RuntimeError, val=self.iteration_stop.message)
            x_next = self.stepper(self.iterate.x)
            # "rnorm" is the norm of the update between two consecutive vectors.
            rnorm = LA.norm(x_next - self.iterate.x)
            self.niterations += 1
            counter += 1
            self.iterate = Iterate(iteration=counter, rnorm=rnorm, x=x_next)
            return self.iterate
        except RuntimeError as err:
            # clean up/checkpoint actions before dying
            print(err)
            raise
        finally:
            # clean up/checkpoint actions after each iteration
            # Print information on iteration
            print('   {0:d}      {1:.3E}'.format(self.iterate.iteration, self.iterate.rnorm))
            # Check convergence
            # NOTE(review): a `return` inside `finally` discards any exception in flight, so
            # when rnorm is below the threshold the RuntimeError re-raised above is suppressed
            # and the iterate is returned instead. Convergence does not end the generator
            # either: a later send() keeps stepping. Confirm both are intended.
            if self.iterate.rnorm < self.early_exit.threshold:
                print(self.early_exit.message)
                return self.iterate
            
    def throw(self, type, val=None, tb=None):
        # Throw if no convergence achieved and maximum number of iterations reached
        # Instantiates `type` with `val` as its only argument; `tb` is not used.
        raise type(val)
            
# Fresh start from the zero vector; the initial rnorm is the residual |b - A x_0|.
# NOTE(review): this call passes `rnorm=`; make sure the Iterate constructor in scope
# accepts that keyword.
x_0 = np.zeros_like(b)
x_jacobi = Iterate(iteration=0, x=x_0, rnorm=LA.norm(b - np.einsum('ij,j->i', A, x_0)))
# The solver is only constructed here (which prints the header); nothing drives it yet.
jacobi_gen = IterativeSolver2(stepper, x_jacobi, maxit_exceeded, rnorm_above_rtol)

In [None]:
import contextlib
@contextlib.contextmanager
def mylist():
    """Context manager that lends out the list [1, 2, 3, 4, 5].

    "exit scope" is printed when the managed block is left, with or without an error.
    """
    try:
        yield [1, 2, 3, 4, 5]
    finally:
        print("exit scope")

# Borrow the list for the duration of the block; "exit scope" is printed on leaving it.
with mylist() as l:
    print(l)