In [1]:
import itertools as it
import numpy as np
from fenics import *

In [2]:
# Imports for plotting
import ipympl
%matplotlib widget
import matplotlib.pyplot as plt
# Has side effects allowing for 3D plots
from mpl_toolkits.mplot3d import Axes3D

In [3]:
### Define Mesh sizes
t, T = 0.0, 10.0   # start and end times
nt = 100            # number of time steps
a, b = 0, 40     # start and end spacial positions
nx = 100           # number of space steps
dt = (T - t)/ nt

### Define periodic boundary
class PeriodicBoundary(SubDomain):

    def inside(self, x, on_boundary):
        return bool(x[0] < DOLFIN_EPS and x[0] > -DOLFIN_EPS and on_boundary)

    # Map right boundary to left boundary
    def map(self, x, y):
        y[0] = x[0] - (b - a)

### Create mesh and define function spaces
mesh = IntervalMesh(nx, a, b)
F_ele = FiniteElement("CG", mesh.ufl_cell(), 1)
V = FunctionSpace(mesh, F_ele, constrained_domain=PeriodicBoundary())
W = FunctionSpace(mesh, MixedElement([F_ele, F_ele]), constrained_domain=PeriodicBoundary())

In [4]:
### Define and deduce initial values
w0 = Function(W)
m0, u0 = w0.split()
u0_uninterpolated = Expression('2/cosh(x[0] - 403.0/15.0) + 5 / cosh(x[0] - 203.0/15.0)', degree=2)
u0.assign(interpolate(u0_uninterpolated, V))

# Find initial value m0 for m
q = TestFunction(V)
m = TrialFunction(V)

am = q * m * dx
Lm = (q * u0 + q.dx(0) * u0.dx(0)) * dx
      
m0 = Function(V)
solve(am == Lm, m0)

# Put m0, u0 back into w0
wvector = w0.vector()
for index, val in enumerate(it.chain(np.array(m0.vector()), np.array(u0.vector()))):
    wvector[index] = val

In [5]:
### Storing results
# We begin by storing the intial condition.
uvals = [u0.compute_vertex_values(mesh)]

In [6]:
### State variational problem
# Relabel w0 as w_prev. Mathematically, let w_prev equal w0.
w_prev = w0
p, q = TestFunctions(W)
# Function to solve for
w = Function(W)
m, u = split(w)
m_prev, u_prev = split(w_prev)
m_mid = 0.5 * (m + m_prev)
u_mid = 0.5 * (u + u_prev)
F = (
    (q * u + q.dx(0) * u.dx(0) - q * m) * dx +                                          # q part
    (p * (m - m_prev) + dt * (p * m_mid * u_mid.dx(0) - p.dx(0) * m_mid * u_mid)) * dx  # p part
    )
J = derivative(F, w)
problem = NonlinearVariationalProblem(F, w, J=J)
solver = NonlinearVariationalSolver(problem)

In [6]:
### Time step through the problem
m, u = w.split()
m_prev, u_prev = w_prev.split()
for n in range(nt):
    t += dt
    if n % 10 == 0:
        E = assemble((u_prev * u_prev + u_prev.dx(0) * u_prev.dx(0)) * dx)
        print("time {:0>4} energy {}".format(t, E))  # Energy should remain constant
    solver.solve()
    uvals.append(u.compute_vertex_values(mesh))  # Save result
    w_prev.assign(w)  # Update for next loop

time 00.1 energy 684.4514730565384
time 00.2 energy 684.4514730565384
time 0.30000000000000004 energy 684.4514730565384
time 00.4 energy 684.4514730565384
time 00.5 energy 684.4514730565384
time 00.6 energy 684.4514730565384
time 00.7 energy 684.4514730565384
time 0.7999999999999999 energy 684.4514730565384
time 0.8999999999999999 energy 684.4514730565384
time 0.9999999999999999 energy 684.4514730565384
time 1.0999999999999999 energy 684.4514730565384
time 01.2 energy 684.4514730565384
time 01.3 energy 684.4514730565384
time 1.4000000000000001 energy 684.4514730565384
time 1.5000000000000002 energy 684.4514730565384
time 1.6000000000000003 energy 684.4514730565384
time 1.7000000000000004 energy 684.4514730565384
time 1.8000000000000005 energy 684.4514730565384
time 1.9000000000000006 energy 684.4514730565384
time 2.0000000000000004 energy 684.4514730565384
time 2.1000000000000005 energy 684.4514730565384
time 2.2000000000000006 energy 684.4514730565384
time 2.3000000000000007 energy 68

In [7]:
uvals = np.array(uvals)
xvals = mesh.coordinates()[:, 0]
tvals = np.linspace(0, T, nt + 1)

In [8]:
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.scatter(*np.meshgrid(xvals, tvals), uvals)

FigureCanvasNbAgg()

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x7eff3ff2ed30>