In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import scipy.sparse as sps
import scipy.sparse.linalg
from scipy.special.orthogonal import p_roots
from pyamg.classical import ruge_stuben_solver

In [None]:
from topo import SQuad
from poly import lagrange_list, lobatto_list
from basis import LagrangeBasisQuad
from basis import LobattoBasisQuad
from mesh import Mesh, uniform_nodes_2d
from assemble import simple_assembly, simple_build_rhs
from poisson import poisson_Kloc

## Mesh nodes

In [None]:
# Discretization parameters: polynomial order of the basis, side length of the
# square domain, and the element count handed to uniform_nodes_2d
# (presumably elements per direction — confirm).
order = 4
L = 1.0
n_elems = 8

# Square domain [0, L] x [0, L].
x_max = y_max = L

periodic = True

# NOTE(review): the two trailing True flags are hard-coded rather than derived
# from `periodic`; confirm their meaning against uniform_nodes_2d.
(vertices, elem_to_vertex, boundary_vertices,
 get_elem_ref, maps) = uniform_nodes_2d(n_elems, x_max, y_max, True, True)

# Identification maps, consumed by mesh.apply_dof_maps when periodic is True.
vertex_map, edge_map = maps[0], maps[1]

## Assembly

In [None]:
# Reference element, basis and mesh. Swap in the commented-out Lobatto basis
# to compare it against the Lagrange basis.
topo = SQuad()
basis = LagrangeBasisQuad(topo, order)
#basis = LobattoBasisQuad(topo, order)
mesh = Mesh(topo, basis)
mesh.build_mesh(vertices, elem_to_vertex, boundary_vertices)

# Geometry: Jacobians are computed for every element, but only element 0's
# data is kept. NOTE(review): this assumes all elements share the same
# geometry (uniform grid) — confirm before reusing on a non-uniform mesh.
nodes = vertices[elem_to_vertex]
jacb = topo.calc_jacb(nodes)
jacb_det = topo.calc_jacb_det(jacb)[0]
jacb_inv = topo.calc_jacb_inv(jacb)[0]
jacb = jacb[0]

In [None]:
# ref = np.array([[0,0],[.5,.5]])
# basis.eval_ref(ref, d=1)

In [None]:
# Degree-of-freedom numbering. With periodic boundaries the vertex/edge maps
# are applied first (presumably merging dofs on opposite sides), then all
# dofs are reordered in either case.
if periodic:
    mesh.apply_dof_maps(vertex_map, edge_map)
mesh.reorder_dofs()
if periodic:
    # No physical boundary: pin a single dof (dof 0) instead, presumably to
    # remove the constant null space of the periodic Laplacian.
    mesh.boundary_dofs = [0]

In [None]:
# Local Poisson matrix built from a single element's Jacobian data
# (jacb_det/jacb_inv hold element 0 only), so this presumes every element has
# identical geometry — TODO confirm for non-uniform meshes.
Kloc = poisson_Kloc(basis, jacb_det, jacb_inv)
# Global sparse system matrix (solved with AMG further down).
K = simple_assembly(mesh, Kloc)

def _eval_on_points(X, func):
    """Apply ``func(x, y)`` to an array of 2-D points, preserving leading shape.

    ``X`` has shape ``(..., 2)``; the result has shape ``X.shape[:-1]``.
    """
    shape = X.shape[:-1]
    pts = X.reshape((-1, 2))
    return func(pts[:, 0], pts[:, 1]).reshape(shape)


def f_poly(X):
    """Polynomial manufactured solution x(x - x_max) * y(y - y_max).

    Vanishes on the boundary of [0, x_max] x [0, y_max], so it suits the
    non-periodic setup; its derivatives are not periodic across the domain.
    """
    return _eval_on_points(X, lambda x, y: x*(x-x_max)*y*(y-y_max))


def f2_poly(X):
    """Source term -Laplace(f_poly) = -(2 y(y - y_max) + 2 x(x - x_max))."""
    return _eval_on_points(X, lambda x, y: -(2*y*(y-y_max)+x*(x-x_max)*2))


# Number of sine periods across [0, x_max] and [0, y_max]; must be integers
# for the solution to be periodic on the domain.
k1 = 1.0
k2 = 1.0


def f(X):
    """Periodic manufactured solution sin(2 pi k1 x/x_max) * sin(2 pi k2 y/y_max).

    This is the exact solution compared against the FE result below.
    """
    return _eval_on_points(
        X,
        lambda x, y: np.sin(k1*2*np.pi*x/x_max) * np.sin(k2*2*np.pi*y/y_max))


def f2(X):
    """Source term -Laplace(f) = ((2 pi k1/x_max)^2 + (2 pi k2/y_max)^2) * f."""
    wavenumber_sq = (k1*2*np.pi/x_max)**2 + (k2*2*np.pi/y_max)**2
    return wavenumber_sq * f(X)

# Right-hand side of K @ sol = rhs, built from the source term f2 (= -Laplace(f)
# for the exact solution f), so the discrete solution should approximate f.
rhs = simple_build_rhs(topo, basis, mesh, f2)

In [None]:
# Visual check of the sparsity pattern of the global matrix K.
plt.gca().spy(K)
# Nonzero entries of K - K^T: ideally 0 for a symmetric stiffness matrix
# (round-off during assembly could leave tiny nonzeros).
asym = K - K.T
asym.nnz

In [None]:
# Ruge-Stuben algebraic multigrid hierarchy, used as a preconditioner for CG.
ml = ruge_stuben_solver(K)
residuals = []  # ml.solve appends the residual history here
sol = ml.solve(rhs, accel='cg', maxiter=5000, tol=1e-12, residuals=residuals)
# Force the pinned/boundary dofs to exactly zero.
sol[mesh.boundary_dofs] = 0.0
# Length of the residual history and the final residual.
(len(residuals), residuals[-1])

In [None]:
# Error at the dof locations. Comparing dof values directly with f is
# presumably only meaningful when basis.is_nodal is True (the same condition
# the dof-plot cell checks). Summarize with the max-norm instead of dumping
# the whole error vector.
nodal_err = f(mesh.get_dof_phys()) - sol
np.max(np.abs(nodal_err))

In [None]:
# Sample the FE solution and the exact solution on a uniform n x n grid.
n = 100
x_vals = np.linspace(0, x_max, n)
y_vals = np.linspace(0, y_max, n)
X, Y = np.meshgrid(x_vals, y_vals)

# Flatten the grid into an (n*n, 2) array of physical points (row-major, so
# reshaping back to (n, n) lines up with X and Y).
phys = np.column_stack([X.ravel(), Y.ravel()]).astype(np.double)

elem, ref = get_elem_ref(phys)
Z1 = mesh.eval_elem_ref(sol, elem, ref)  # FE solution at the grid points
Z2 = f(phys)                             # exact solution at the grid points

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_wireframe(X, Y, Z1.reshape((n, n)), label='FE solution')
ax.plot_wireframe(X, Y, Z2.reshape((n, n)), color='g', label='exact')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('u')
ax.set_title('FE vs exact solution')
ax.legend()
plt.show()
# Max-norm of the pointwise error on the sampling grid.
np.max(np.abs(Z1 - Z2))

In [None]:
# 1-D cut through both surfaces along grid row k, i.e. at y = y_vals[k]
# (meshgrid's default 'xy' indexing puts y along the first axis).
k = 20
plt.plot(X[k, :], Z1.reshape((n, n))[k, :], label='FE solution')
plt.plot(X[k, :], Z2.reshape((n, n))[k, :], label='exact')
plt.xlabel('x')
plt.ylabel('u')
plt.title('Cut at y = %.3f' % y_vals[k])
plt.legend();

In [None]:
# Optional diagnostic: compare FE dof values with f at the dof locations.
# Only meaningful for a nodal basis. Disabled by default; set to True to show.
show_dof_plot = False

if show_dof_plot and basis.is_nodal:
    dof_phys = mesh.get_dof_phys()
    dof_fe = sol
    dof_exact = f(dof_phys)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # Dof locations are scattered points, not a structured grid, so use
    # scatter (plot_wireframe expects 2-D gridded arrays). Distinct names keep
    # the grid variables X, Y, Z1, Z2 from earlier cells intact.
    ax.scatter(dof_phys[:, 0], dof_phys[:, 1], dof_fe, label='FE dofs')
    ax.scatter(dof_phys[:, 0], dof_phys[:, 1], dof_exact,
               color='g', label='exact')
    ax.legend()
    print(np.max(np.abs(dof_fe - dof_exact)))