In [28]:
import numpy as np
from scipy import sparse as sp
from scipy.sparse.linalg import spsolve

In [31]:
def solver(T, a, b, tau, h, f, u0):
    t = np.arange(0, T+tau, tau)
    x = np.arange(a, b+h, h)
    M = x.size - 2
    r = tau/h**2
    Delta_h = sp.diags(-2*np.ones(M)) + sp.diags(np.ones(M-1), 1) + sp.diags(np.ones(M-1), -1)
    A = sp.eye(M) - r * Delta_h
    N = t.size
    u = np.zeros((N, M))
    b = np.zeros(M)
    u[0, :] = u0(x[1:-1])
    for n in range(1, N):
        b[:] = u[n-1, :] + tau * f(x[1:-1], t[n])
        u[n, :] = spsolve(A, b)
    return t, x, u

In [30]:
def u_e(x, t):
    return np.exp(-t) * np.sin(np.pi * x)
def u0(x):
    return np.sin(np.pi * x)
def f(x, t):
    return (np.pi ** 2 - 1) * np.exp(-t) * np.sin(np.pi * x)

In [32]:
t, x, u = solver(1, 0, 1, 1/64, 1/64, f, u0)

In [35]:
np.linalg.norm(u_e(x[1:-1], t[-1])-u[-1, :])

0.0023094557947220653

In [37]:
%timeit solver(1, 0, 1, 1/512, 1/512, f, u0)

96.2 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [1]:
using LinearAlgebra, SparseArrays

In [28]:
function solver(T, a, b, τ, h, f, u₀)
    t = collect(0:τ:T)
    x = collect(a:h:b)
    x_ = @view x[begin+1:end-1]
    M = length(x) - 2
    r = τ/h^2
    Δₕ = spdiagm(-1=>ones(M-1), 0=>-2*ones(M), 1=>ones(M-1))
    A = I - r * Δₕ
    N = length(t)
    u = zeros(M, N)
    b = zeros(M)
    @. u[:, 1] = u₀(x_)
    @inbounds for n in firstindex(t)+1:lastindex(t)
        @. b[:] = u[:, n-1] + τ * f(x_, t[n])
        u[:, n] .= A\b
    end
    return t, x, u
end

solver (generic function with 1 method)

In [8]:
uₑ(x, t) = exp(-t) * sinpi(x)
u₀(x) = sinpi(x)
f(x, t) = (π^2 - 1) * exp(-t) * sinpi(x)

f (generic function with 1 method)

In [29]:
t, x, u = solver(1, 0, 1, 1/64, 1/64, f, u₀);

In [30]:
norm(uₑ.(x[2:end-1], t[end]) - u[:, end])

0.002309455794769345

In [31]:
using BenchmarkTools
@benchmark solver(1, 0, 1, 1/512, 1/512, f, u₀)

BenchmarkTools.Trial: 64 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m66.229 ms[22m[39m … [35m146.579 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m1.41% … 47.62%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m75.086 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m1.84%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m78.380 ms[22m[39m ± [32m 14.074 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m6.74% ± 10.39%

  [39m [39m [39m [39m [39m▃[39m▂[39m [39m [34m█[39m[39m▃[39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▃[39m▁[39m▇[39m▇[39m█