In [None]:
from math import pi
import numpy as np
import scipy as sp
import scipy.linalg as spl
import numpy.linalg as npl
import matplotlib.pyplot as plt
import scipy.sparse as sps
import scipy.sparse.linalg as sspl

In [None]:
def A(dz, N):
    """Finite-difference Laplacian d^2/dz^2 on the interior nodes, zero-flux ends.

    ``N`` counts all grid nodes including the two boundary nodes; the result is
    the (N-2)x(N-2) sparse tridiagonal matrix acting on ``U[1:-1]``.  The
    homogeneous Neumann condition is encoded through the ghost values
    ``U[0] = U[1]`` and ``U[-1] = U[-2]`` that the simulation cells impose.

    The matrix is symmetric with non-positive diagonal and zero row sums, i.e.
    negative semi-definite.  That is the sign the solvers rely on: ``I + dt*A``
    (explicit step) is diffusive, and ``I - dt*A`` (implicit / splitting steps)
    is symmetric positive definite as conjugate gradients requires.

    Parameters
    ----------
    dz : float
        Grid spacing.
    N : int
        Total number of grid nodes (boundary nodes included).

    Returns
    -------
    scipy.sparse matrix of shape (N-2, N-2)
    """
    n = N - 2
    inv_dz2 = 1.0 / dz**2
    off = inv_dz2 * np.ones(n - 1)
    diag = -2.0 * inv_dz2 * np.ones(n)
    # Ghost node equals its interior neighbour, so one neighbour coupling
    # cancels against the centre term at each end.  Using += also gives the
    # correct all-zero 1x1 operator when n == 1.
    diag[0] += inv_dz2
    diag[-1] += inv_dz2
    return sps.diags((off, diag, off), (-1, 0, 1))

def source(Z):
    """Gaussian density centred at z = 50 with standard deviation ``sigma``.

    ``sigma`` is looked up in the module globals at call time.  Works
    element-wise on arrays as well as on scalars.
    """
    normalisation = 1 / (2 * pi * sigma**2) ** .5
    exponent = -0.5 * ((Z - 50) ** 2) / sigma**2
    return normalisation * np.exp(exponent)

def get_r(Z, rmax, s, theta):
    """Quadratic growth-rate profile r(z) = rmax - s * (z - theta)**2.

    Equals ``rmax`` at ``z = theta`` and, for ``s > 0``, decreases away from
    it.  Works element-wise on arrays as well as on scalars.
    """
    offset = Z - theta
    penalty = s * offset ** 2
    return rmax - penalty

def get_rho(dz, U):
    """Trapezoidal-rule integral of the samples ``U`` with uniform step ``dz``.

    The two end samples carry weight 1/2, every sample in between weight 1.
    """
    end_terms = U[0] / 2 + U[-1] / 2
    inner_terms = np.sum(U[1:-1])
    return dz * (end_terms + inner_terms)



def solve_explicit(A, U, Z, dz, dt):
    """Advance the interior solution by one forward-Euler step.

    Discretises u_t = A u + u * (r(z) - kappa * rho) fully explicitly:

        U_next = (I + dt*A) U + dt * (r - kappa*rho) * U

    which uses the same reaction term as ``solve_implicite``'s right-hand side.
    Reads the module globals ``r_max``, ``s``, ``theta`` and ``kappa``.

    Parameters
    ----------
    A : sparse matrix, shape (n, n)
        Interior diffusion operator (as built by ``A(dz, N)``).
    U : ndarray, shape (n,)
        Interior values at the current time.
    Z : ndarray, shape (n + 2,)
        Full grid including both boundary nodes; the growth rate is evaluated
        there and then restricted to the interior.
    dz, dt : float
        Space and time steps.

    Returns
    -------
    ndarray, shape (n,)
        Interior values at the next time.
    """
    r = get_r(Z, r_max, s, theta)
    rho = get_rho(dz, U)
    growth = (r - kappa * rho)[1:-1]
    # The reaction increment is added to the diffused state.  The previous
    # version multiplied U by dt*growth *before* diffusing, which shrank the
    # solution by a factor ~dt at every step.
    return U + dt * A.dot(U) + dt * growth * U

def solve_implicite(A, U, Z, dz, dt):
    """Advance the interior solution by one semi-implicit Euler step.

    Diffusion is implicit, the reaction explicit:

        (I - dt*A) U_next = (1 + dt*(r - kappa*rho)) * U

    The linear system is solved with conjugate gradients, which requires
    ``I - dt*A`` to be symmetric positive definite (true when ``A`` is a
    negative semi-definite Laplacian).  Reads the module globals ``r_max``,
    ``s``, ``theta`` and ``kappa``.

    Parameters
    ----------
    A : sparse matrix, shape (n, n)
        Interior diffusion operator (as built by ``A(dz, N)``).
    U : ndarray, shape (n,)
        Interior values at the current time.
    Z : ndarray, shape (n + 2,)
        Full grid including both boundary nodes.
    dz, dt : float
        Space and time steps.

    Returns
    -------
    ndarray, shape (n,)
        CG approximation of the interior values at the next time.

    Warns
    -----
    RuntimeWarning
        If CG reports non-convergence or illegal input.
    """
    import warnings

    r = get_r(Z, r_max, s, theta)
    rho = get_rho(dz, U)

    # Keep the system sparse: np.eye(n) - sparse yields a dense np.matrix,
    # making every CG iteration O(n^2).
    lhs = sps.identity(A.shape[0], format="csr") - dt * A
    rhs = (1 + dt * (r - kappa * rho)[1:-1]) * U

    U_next, info = sspl.cg(lhs, rhs)
    if info != 0:
        warnings.warn(
            f"solve_implicite: conjugate gradient did not converge (info={info})",
            RuntimeWarning,
        )
    return U_next

def solve_implicite_stat(A, U, Z, dz, dt):
    """One iteration towards a stationary state.

    Solves ``-dt*A x = dt*(r - kappa*rho)[1:-1] * U`` with conjugate gradients
    and returns ``x``.  ``dt`` multiplies both sides, so it only rescales the
    system.  Reads the module globals ``r_max``, ``s``, ``theta`` and ``kappa``.

    NOTE(review): the zero-flux operator built by ``A(dz, N)`` has zero row
    sums, so the constant vector is in its null space and this system is
    singular.  The CG convergence flag is discarded, so whatever iterate CG
    stops at is returned.
    """
    growth = get_r(Z, r_max, s, theta) - kappa * get_rho(dz, U)
    rhs = (dt * growth[1:-1]) * U
    solution, _info = sspl.cg(-(dt * A), rhs)
    return solution

def F(U, r, rho):
    """Reaction term ``U * (r - kappa * rho)``; ``kappa`` is a module global."""
    net_rate = r - kappa * rho
    return U * net_rate

def solve_splitting(A, U, Z, dz, dt):
    """Advance the interior solution by one diffusion/reaction splitting step.

    Sequence: half implicit diffusion step, full explicit reaction step,
    half implicit diffusion step:

        (I - dt/2*A) U1 = U
        U2 = U1 + dt * F(U1, r, rho)
        (I - dt/2*A) U_next = U2

    ``rho`` is evaluated from the incoming ``U``.  Unlike the other solvers,
    ``Z`` must already be restricted to the interior nodes (same length as
    ``U``), because ``r`` is used without slicing.  Reads the module globals
    ``r_max``, ``s`` and ``theta``.

    Parameters
    ----------
    A : sparse matrix, shape (n, n)
        Interior diffusion operator (as built by ``A(dz, N)``).
    U : ndarray, shape (n,)
        Interior values at the current time.
    Z : ndarray, shape (n,)
        Interior grid nodes.
    dz, dt : float
        Space and time steps.

    Returns
    -------
    ndarray, shape (n,)

    Warns
    -----
    RuntimeWarning
        If either CG solve reports non-convergence or illegal input.
    """
    import warnings

    r = get_r(Z, r_max, s, theta)
    rho = get_rho(dz, U)

    # Both half steps share one sparse system matrix.  Building it with
    # np.eye (twice) produced a dense np.matrix and O(n^2) CG iterations.
    half_step = sps.identity(A.shape[0], format="csr") - dt / 2 * A

    def _diffuse(vec):
        # Solve (I - dt/2*A) x = vec, warning instead of silently ignoring
        # a CG failure.
        out, info = sspl.cg(half_step, vec)
        if info != 0:
            warnings.warn(
                f"solve_splitting: conjugate gradient did not converge (info={info})",
                RuntimeWarning,
            )
        return out

    U1 = _diffuse(U)
    U2 = U1 + dt * F(U1, r, rho)
    return _diffuse(U2)

In [None]:
# Model parameters.  The solver functions read these as module-level globals,
# so this cell must run before any simulation cell.
r_max = 1      # peak of the growth rate r(z) = r_max - s*(z - theta)**2
s = 1          # curvature of that quadratic growth-rate profile
theta = 45     # position z of the growth-rate maximum
kappa = 1      # weight of the total mass rho in the reaction U*(r - kappa*rho)
sigma = 0.4    # standard deviation of the Gaussian initial profile (source)

In [None]:
# Explicit (forward Euler) run.  dt is kept small because forward-Euler
# diffusion is only stable for dt <= dz**2 / 2.
N = 1000
a = 0
b = 100
# The spacing of np.linspace(a, b, N) is (b-a)/(N-1), not (b-a)/N; use the
# actual spacing so the Laplacian and rho match the grid.
Z, dz = np.linspace(a, b, N, retstep=True)
dt = 0.002
T = 5
nbT = round(T / dt)
plot_every = max(1, nbT // 2)  # plot t = 0, T/2, T; avoid modulo by zero if nbT < 2
A_op = A(dz, N)  # time-independent: build once instead of at every step

U = np.zeros(N)
U[1:-1] = source(Z[1:-1])
U[0], U[-1] = U[1], U[-2]  # zero-flux ghost values

fig, ax = plt.subplots(figsize=(15, 8))
for i in range(nbT + 1):
    if i % plot_every == 0:
        ax.plot(Z, U, ".--", label="approximated solution at t = {0}".format(i * dt))
    U[1:-1] = solve_explicit(A_op, U[1:-1], Z, dz, dt)
    U[0], U[-1] = U[1], U[-2]

ax.set(xlabel="z", ylabel="u", title="Explicit scheme")
ax.legend()
ax.grid()
plt.show()

In [None]:
# Implicit run using solve_implicite_stat.
N = 1000
a = 0
b = 100
# The spacing of np.linspace(a, b, N) is (b-a)/(N-1), not (b-a)/N.
Z, dz = np.linspace(a, b, N, retstep=True)
dt = 0.02
T = 5
nbT = round(T / dt)
plot_every = max(1, nbT // 2)  # plot t = 0, T/2, T; avoid modulo by zero if nbT < 2
A_op = A(dz, N)  # time-independent: build once instead of at every step

U = np.zeros(N)
U[1:-1] = source(Z[1:-1])
U[0], U[-1] = U[1], U[-2]  # zero-flux ghost values

fig, ax = plt.subplots(figsize=(15, 8))
for i in range(nbT + 1):
    if i % plot_every == 0:
        ax.plot(Z, U, ".--", label="approximated solution at t = {0}".format(i * dt))
    U[1:-1] = solve_implicite_stat(A_op, U[1:-1], Z, dz, dt)
    U[0], U[-1] = U[1], U[-2]

ax.set(xlabel="z", ylabel="u", title="Implicit scheme (solve_implicite_stat)")
ax.legend()
ax.grid()
plt.show()

In [None]:
# Splitting run: half implicit diffusion / explicit reaction / half implicit diffusion.
N = 1000
a = 0
b = 100
# The spacing of np.linspace(a, b, N) is (b-a)/(N-1), not (b-a)/N.
Z, dz = np.linspace(a, b, N, retstep=True)
dt = 0.02
T = 5
nbT = round(T / dt)
plot_every = max(1, nbT // 2)  # plot t = 0, T/2, T; avoid modulo by zero if nbT < 2
A_op = A(dz, N)  # time-independent: build once instead of at every step

U = np.zeros(N)
U[1:-1] = source(Z[1:-1])
U[0], U[-1] = U[1], U[-2]  # zero-flux ghost values

fig, ax = plt.subplots(figsize=(15, 8))
for i in range(nbT + 1):
    if i % plot_every == 0:
        ax.plot(Z, U, ".--", label="approximated solution at t = {0}".format(i * dt))
    # solve_splitting expects the interior grid, hence Z[1:-1].
    U[1:-1] = solve_splitting(A_op, U[1:-1], Z[1:-1], dz, dt)
    U[0], U[-1] = U[1], U[-2]

ax.set(xlabel="z", ylabel="u", title="Splitting scheme")
ax.legend()
ax.grid()
plt.show()

In [None]:
# Inspect the first few splitting steps one by one.
N = 1000
a = 0
b = 100
# The spacing of np.linspace(a, b, N) is (b-a)/(N-1), not (b-a)/N.
Z, dz = np.linspace(a, b, N, retstep=True)
dt = 0.002
T = 5
nbT = round(T / dt)  # not used below; kept so the notebook state matches the other cells
n_steps = 4
A_op = A(dz, N)  # time-independent: build once instead of at every step

U = np.zeros(N)
U[1:-1] = source(Z[1:-1])
U[0], U[-1] = U[1], U[-2]  # zero-flux ghost values

fig, ax = plt.subplots(figsize=(15, 8))
for i in range(n_steps):
    ax.plot(Z, U, ".--", label="approximated solution at t = {0}".format(i * dt))
    U[1:-1] = solve_splitting(A_op, U[1:-1], Z[1:-1], dz, dt)
    U[0], U[-1] = U[1], U[-2]

ax.set(xlabel="z", ylabel="u", title="Splitting scheme: first steps")
ax.legend()
ax.grid()
plt.show()