-
Notifications
You must be signed in to change notification settings - Fork 42
/
sphere_helmholtz.py
84 lines (68 loc) · 2.38 KB
/
sphere_helmholtz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
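Solve the Helmholtz equation on the unit sphere.

The problem -div(grad(u)) + alpha*u = g is posed in spherical coordinates
(theta, phi) and checked against a manufactured solution.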
"""
import os
from shenfun import *
from shenfun.la import SolverGeneric1ND
import sympy as sp
# Toggle between the weak form obtained by integration by parts (True) and
# the strong form of the operator (False) when assembling the matrices.
by_parts = False

# Spherical coordinates on a sphere of radius r. The sympy symbols are named
# 'x' and 'y', but they play the roles of the polar angle theta and the
# azimuthal angle phi in the position vector rv below.
r = 1
psi = sp.symbols('x,y', real=True, positive=True)
theta, phi = psi
rv = (
    r*sp.sin(theta)*sp.cos(phi),
    r*sp.sin(theta)*sp.sin(phi),
    r*sp.cos(theta),
)

# Coefficient of the zeroth-order term in -div(grad(u)) + alpha*u.
alpha = 2

# Manufactured solution: the spherical harmonic Ynm(6, 3, theta, phi).
sph = sp.functions.special.spherical_harmonics.Ynm
ue = sph(6, 3, theta, phi)
# Alternative manufactured solution, and a hand-written right hand side,
# kept for reference:
#ue = sp.cos(8*(sp.sin(theta)*sp.cos(phi) + sp.sin(theta)*sp.sin(phi) + sp.cos(theta)))
#g = - ue.diff(theta, 2) - (1/sp.tan(theta))*ue.diff(theta, 1) - (1/sp.sin(theta)**2)*ue.diff(phi, 2) + alpha*ue

# Resolution: N for the theta direction, M for the phi direction.
N, M = 60, 40

# 'C' family space on theta in (0, pi) and complex ('D') 'F' family space
# for phi, combined into a tensor product space that carries the curvilinear
# coordinates. The third entry of `coordinates` is a sympy assumption
# (sin(theta) > 0) handed over together with the coordinate map.
L0 = FunctionSpace(N, 'C', domain=(0, np.pi))
F1 = FunctionSpace(M, 'F', dtype='D')
T = TensorProductSpace(
    comm,
    (L0, F1),
    coordinates=(psi, rv, sp.Q.positive(sp.sin(theta))),
)
v = TestFunction(T)
u = TrialFunction(T)

# Right hand side: apply the operator symbolically to the manufactured
# solution, then sample the resulting expression into an Array on T.
g = (-div(grad(u)) + alpha*u).tosympy(basis=ue, psi=psi)
gj = Array(T, buffer=g)

# Scalar product of the right hand side with the test function, written
# into a freshly allocated Function.
g_hat = inner(v, gj, output_array=Function(T))

# Assemble the coefficient matrices.
if not by_parts:
    mats = inner(v, -div(grad(u)) + alpha*u)
else:
    mats = inner(grad(v), grad(u))
    mats += [inner(v, alpha*u)]

# Solve for the expansion coefficients of the solution.
Sol1 = SolverGeneric1ND(mats)
u_hat = Sol1(g_hat, Function(T))

# Transform back to real space and compare with the exact solution.
uj = u_hat.backward()
uq = Array(T, buffer=ue)
error = np.sqrt(inner(1, abs((uj - uq)**2)))
print(f'sphere_helmholtz L2 error {error:2.6e}')

if 'pytest' in os.environ:
    # Test mode: only check the accuracy, skip all plotting.
    assert error < 1e-6
else:
    # Postprocess.
    # Refine for a nicer plot. Refine simply pads Functions with zeros, which
    # gives more quadrature points. u_hat has NxM quadrature points, so any
    # higher number works here.
    u_hat2 = u_hat.refine([3*N, 3*M])
    ur = u_hat2.backward(kind='uniform')
    # mayavi is only needed for the interactive plot, hence the late import.
    from mayavi import mlab
    xx, yy, zz = u_hat2.function_space().local_cartesian_mesh(uniform=True)
    # Wrap the periodic direction around: append the first column to each
    # array so the plotted surface closes on itself.
    if T.bases[1].domain == (0, 2*np.pi):
        xx, yy, zz, ur = [
            np.hstack([arr, arr[:, 0][:, None]]) for arr in (xx, yy, zz, ur)
        ]
    mlab.figure(bgcolor=(1, 1, 1), size=(400, 400))
    mlab.mesh(xx, yy, zz, scalars=ur.real, colormap='jet')
    mlab.savefig('spherewhite.png')
    mlab.show()

# Explicitly tear down the tensor product space before the script exits.
# NOTE(review): indentation was not recoverable from the reviewed copy; this
# call is placed so that it runs on both paths -- confirm against upstream.
T.destroy()