In [None]:
%matplotlib inline
# Automatically reload edited project modules (topo, poly, basis, mesh, assemble)
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import scipy.sparse as sps
import scipy.sparse.linalg
from scipy.special.orthogonal import p_roots
from pyamg.classical import ruge_stuben_solver

In [None]:
from topo import SQuad
from poly import lagrange_list
from basis import LagrangeBasisQuad
from mesh import Mesh2D
from assemble import simple_assembly, simple_build_rhs

In [None]:
order   = 1      # polynomial order passed to LagrangeBasisQuad below
L       = 1.0    # side length of the square domain [0, L] x [0, L]
n_elems = 2      # number of elements along each axis

x_max = L
y_max = L
x_vals = np.linspace(0, x_max, n_elems+1)
y_vals = np.linspace(0, y_max, n_elems+1)

n_side = n_elems + 1   # vertices along each axis

# Vertex v = i*n_side + j sits at (x_vals[j], y_vals[i]): row-major numbering
# with x varying fastest.
vertices = np.zeros((n_side**2, 2), dtype=np.double)
# Builtin int: the np.int alias was deprecated in NumPy 1.20 and removed in 1.24.
elem_to_vertex = np.zeros((n_elems**2, 4), dtype=int)

for i in range(n_elems):
    for j in range(n_elems):
        elem = i*n_elems + j
        v0 = i*n_side + j   # lower-left corner of element (i, j)
        # Corners ordered counter-clockwise: lower-left, lower-right,
        # upper-right, upper-left.
        elem_to_vertex[elem] = (v0, v0 + 1, v0 + n_side + 1, v0 + n_side)

boundary_vertices = []
for i in range(n_side):
    for j in range(n_side):
        v = i*n_side + j
        vertices[v] = (x_vals[j], y_vals[i])
        # A vertex is on the boundary if it lies in the first/last row or column.
        if i in (0, n_elems) or j in (0, n_elems):
            boundary_vertices.append(v)


In [None]:
topo = SQuad()   # topology object for the (presumably square) reference quad

# Gather the corner coordinates of every element: nodes[e, c] holds the
# (x, y) position of corner c of element e.
nodes = np.take(vertices, elem_to_vertex, axis=0)

# Presumably the per-element Jacobians of the reference-to-physical map and
# their inverses.
# NOTE(review): jacb / jacb_inv are not read by any later cell shown here.
jacb = topo.calc_jacb(nodes)
jacb_inv = topo.calc_jacb_inv(jacb)

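### Basis Test Cases

* Each Lagrange basis function should equal one at its own node and zero at all other nodes.
* Evaluating the basis through `eval_ref` with a unit coefficient vector should reproduce the corresponding basis polynomial.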

In [None]:
# Lagrange basis on the quad reference element, presumably of polynomial order `order`.
# NOTE(review): this passes the SQuad class itself rather than the `topo`
# instance created above — confirm which one LagrangeBasisQuad expects.
basis = LagrangeBasisQuad(SQuad, order)

In [None]:
# Sample the reference square [-1, 1]^2 on an (order+1) x (order+1) tensor grid.
# A dedicated name keeps the mesh-cell coordinates x_vals / y_vals intact.
ref_1d = np.linspace(-1, 1, order+1)
X, Y = np.meshgrid(ref_1d, ref_1d)

# NOTE(review): basis.basis_polys[0][k] is assumed to be the k-th basis
# polynomial as a callable p(x, y) — confirm against LagrangeBasisQuad.
bp = basis.basis_polys[0]
k = 0
p = bp[k]

In [None]:
# Evaluate basis function k through eval_ref (unit coefficient vector) and
# compare it against the polynomial p at the same reference points.
ref = np.array([X.ravel(), Y.ravel()]).T   # shape (n_points, 2)

coeffs = np.zeros(basis.n_dofs)
coeffs[k] = 1.0

# Apply the tolerance to the *maximum* error so that every point must agree;
# taking the max of an elementwise boolean would pass if any single point matched.
max_err = np.max(np.abs(basis.eval_ref(coeffs, ref) - p(X, Y).ravel()))
max_err < 1e-12

## Build Mesh

In [None]:
# Hand the structured grid (vertex coordinates, element-to-vertex connectivity
# and boundary vertex list from the mesh cell) to the Mesh2D helper.
# NOTE(review): later cells read mesh.vertices and mesh.vertex_to_dof; presumably
# build_mesh populates them — confirm.
mesh = Mesh2D(topo, basis)
mesh.build_mesh(vertices, elem_to_vertex, boundary_vertices)

## Assembly

In [None]:
# Local (reference-element) stiffness matrix
#   Kloc[a, b] = sum_q w_q * (grad phi_a . grad phi_b)(xi_q)
# computed once and passed to simple_assembly, which presumably reuses it for
# every element.
# NOTE(review): no Jacobian factors are applied. For the axis-aligned square
# elements built above this is still exact, because the 2-D Laplacian
# stiffness is invariant under uniform scaling (with a [-1, 1]^2 reference
# square, gradients scale by 2/h and the area by h^2/4). Rectangular or
# distorted elements would need jacb_inv and det(jacb).
cub_points, cub_weights = topo.get_quadrature(order+1)

# NOTE(review): cub_vals[a] is assumed to hold the gradient of basis function a
# with components along axis 0 and quadrature points along axis 1 — inferred
# from the sum over axis 0 followed by the dot with the weights; confirm.
cub_vals = basis.eval_ref(np.eye(basis.n_dofs), cub_points, d=1)

Kloc = np.array(
    [[np.sum(cub_vals[a]*cub_vals[b], axis=0).dot(cub_weights)
      for b in range(basis.n_dofs)]
     for a in range(basis.n_dofs)],
    dtype=np.double,
).reshape((basis.n_dofs, basis.n_dofs))

K = simple_assembly(mesh, Kloc)

def _split_xy(X):
    """Split an array of 2-D points into x and y coordinate vectors.

    Returns ``(x, y, shape)`` where ``shape`` is ``X.shape[:-1]``, so results
    computed on the flat coordinates can be reshaped to the input's layout.
    """
    X = np.asarray(X)
    shape = X.shape[:-1]
    flat = X.reshape((-1, 2))
    return flat[:, 0], flat[:, 1], shape


def f(X, length=None):
    """Manufactured solution f(x, y) = x (x - L) y (y - L).

    It vanishes on the boundary of the square [0, L] x [0, L].

    Parameters
    ----------
    X : array_like, shape (..., 2)
        Points; the last axis holds (x, y).
    length : float, optional
        Side length L of the square; defaults to the notebook-level ``L``.

    Returns
    -------
    ndarray of shape ``X.shape[:-1]``
    """
    length = L if length is None else length
    x, y, shape = _split_xy(X)
    return (x*(x - length)*y*(y - length)).reshape(shape)


def f2(X, length=None):
    """Laplacian of ``f``: Δf = 2 y (y - L) + 2 x (x - L).

    Parameters
    ----------
    X : array_like, shape (..., 2)
        Points; the last axis holds (x, y).
    length : float, optional
        Side length L of the square; defaults to the notebook-level ``L``.

    Returns
    -------
    ndarray of shape ``X.shape[:-1]``
    """
    length = L if length is None else length
    x, y, shape = _split_xy(X)
    return (2*y*(y - length) + 2*x*(x - length)).reshape(shape)

# Load vector for the manufactured problem, built from f2 (= Δf).
# NOTE(review): for -Δu = g with exact solution u = f, the load is
# g = -Δf = -f2. Passing f2 itself means the discrete solution approximates -f
# unless simple_build_rhs or simple_assembly absorbs the sign — confirm before
# trusting the error reported in the final cell.
# NOTE(review): no Dirichlet conditions are applied in the cells shown here;
# presumably the assembly helpers use the mesh's boundary vertices — confirm.
rhs = simple_build_rhs(topo, basis, mesh, f2)

In [None]:
# Sparsity pattern of the assembled global stiffness matrix.
spy_fig, spy_ax = plt.subplots()
spy_ax.spy(K)
spy_ax.set_title('Sparsity pattern of K');

In [None]:
# Solve K @ sol = rhs with a classical Ruge-Stuben algebraic multigrid hierarchy.
# NOTE(review): no tolerance is passed, so pyamg's default applies; pass a
# tighter one if the comparison with f below should isolate discretization error.
ml = ruge_stuben_solver(K)
sol = ml.solve(rhs)

In [None]:
# Compare the discrete solution with f at the mesh vertices, and report the
# maximum nodal difference.
grid_shape = (n_elems+1, n_elems+1)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

X = vertices[:,0].reshape(grid_shape)
Y = vertices[:,1].reshape(grid_shape)
# NOTE(review): assumes mesh.vertex_to_dof maps vertex index -> dof index and
# that mesh.vertices keeps the ordering of `vertices`, so Z1/Z2 line up with X/Y.
Z1 = sol[mesh.vertex_to_dof].ravel()   # discrete solution per vertex
Z2 = f(mesh.vertices).ravel()          # f evaluated per vertex
ax.plot_wireframe(X, Y, Z1.reshape(grid_shape), label='FE solution')
ax.plot_wireframe(X, Y, Z2.reshape(grid_shape), color='g', label='f (exact)')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('u')
ax.legend()

np.max(np.abs(Z1 - Z2))