In [1]:
import numpy as np
import urllib
import matplotlib.pyplot as plt
from skimage.restoration import unwrap_phase
from tqdm import trange
from scipy.signal import gaussian
from scipy.sparse import diags
import scipy.sparse.linalg as splinalg
np.set_printoptions(2)
%load_ext Cython

In [2]:
%%cython --compile-args=-O3
##Taken from https://github.com/cpcloud/PyTDMA

import numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
cdef void solve(int n, double[:] lower, double[:] diag, double[:] upper,
                    double[:] rhs, double[:] x):
    
    cdef double m
    cdef int i, im1, nm1 = n - 1
    
    for i in xrange(n):
        im1 = i - 1
        m = lower[i] / diag[im1]
        diag[i] -= m * upper[im1]
        rhs[i] -= m * rhs[im1]

        x[nm1] = rhs[nm1] / diag[nm1]

        for i in xrange(n - 2, -1, -1):
            x[i] = (rhs[i] - upper[i] * x[i + 1]) / diag[i]


cpdef double[:] tdma(double[:] a, double[:] b, double[:] c,double[:] d):
    cdef int n = b.shape[0]
    cdef double[:] x = np.zeros(n, dtype=np.float64)
    solve(n, a, b, c, d, x)
    return x

In [3]:
N = 500
a = np.random.random(N)
b = np.random.random(N)
c = np.random.random(N)
d = np.random.random(N)

In [4]:
%timeit np.array(tdma(a,b,c,d))

2.32 ms ± 46.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
x = np.random.random(N)
b = diags(x)+diags(x[1:]**2,offsets=1)+diags(x[:-1]**3,offsets=1)

In [6]:
%timeit splinalg.spsolve(b,x)

313 µs ± 45 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [7]:
0.5e-6*1e3*2

0.001