In [1]:
import itertools as it
import numpy as np
from fenics import *

In [2]:
# Imports for plotting
import ipympl
%matplotlib widget
import matplotlib.pyplot as plt
# Has side effects allowing for 3D plots
from mpl_toolkits.mplot3d import Axes3D

In [3]:
### Define Mesh sizes
# Time discretisation: march from t to T in nt equal steps of length dt.
# NB: the time-stepping cell advances `t` in place, so re-run this cell
# before re-running the simulation from scratch.
t, T = 0.0, 10.0   # start and end times
nt = 100           # number of time steps
dt = (T - t) / nt

# Space discretisation: nx cells on the interval [a, b].
a, b = 0, 40       # start and end spatial positions
nx = 100           # number of space steps

### Define periodic boundary
class PeriodicBoundary(SubDomain):
    """Periodic constraint on the interval [a, b].

    Points are mapped by y = x - (b - a), so the right end x = b is
    identified with the left end x = a. Uses the module-level endpoints
    `a` and `b` from the configuration cell.
    """

    def inside(self, x, on_boundary):
        """Return True on the target (left) boundary x = a."""
        # Compare against `a` rather than a hard-coded 0: `map` already uses
        # (b - a), and the two must agree if the interval start changes.
        return bool(abs(x[0] - a) < DOLFIN_EPS and on_boundary)

    # Map right boundary to left boundary
    def map(self, x, y):
        """Write the periodic image of x into y (shift left by b - a)."""
        y[0] = x[0] - (b - a)

### Create mesh and define function spaces
# A single PeriodicBoundary instance constrains both spaces.
pbc = PeriodicBoundary()
mesh = IntervalMesh(nx, a, b)                          # nx cells on [a, b]
F_ele = FiniteElement("CG", mesh.ufl_cell(), 1)        # continuous piecewise-linear element
V = FunctionSpace(mesh, F_ele, constrained_domain=pbc)  # scalar space (u0, m0, ...)
W = FunctionSpace(mesh, MixedElement([F_ele, F_ele]),   # mixed space for the (m, u) pair
                  constrained_domain=pbc)

In [4]:
### Define and deduce initial values
# Initial profile: two sech humps (amplitudes 2 and 5) centred at
# x = 403/15 and x = 203/15, interpolated into V.
u0_uninterpolated = Expression('2/cosh(x[0] - 403.0/15.0) + 5 / cosh(x[0] - 203.0/15.0)', degree=2)
u0 = interpolate(u0_uninterpolated, V)

# Recover m0 from u0 through m = u - u_xx, posed weakly with the u_xx term
# integrated by parts (no boundary term on the periodic domain):
#   find m0 in V such that  int(v*m0) dx = int(v*u0 + v'*u0') dx  for all v in V.
v_test = TestFunction(V)
m_trial = TrialFunction(V)
m0 = Function(V)
solve(v_test * m_trial * dx == (v_test * u0 + v_test.dx(0) * u0.dx(0)) * dx, m0)

In [5]:
### Storing results
# History of u sampled at the mesh vertices. Row 0 is the initial
# condition; the time-stepping loop appends one row per step.
uvals = []
uvals.append(u0.compute_vertex_values(mesh))

In [6]:
### State variational problem
# Mixed (m, u) formulation of the Camassa-Holm equation
#   m = u - u_xx,    m_t + u m_x + 2 m u_x = 0,
# advanced in time with the implicit midpoint rule.
p, q = TestFunctions(W)
w = Function(W)  # Function to solve for: (m, u) at the new time level
m, u = split(w)
# Previous time level, initialised from the initial data. These must be
# copies: the time loop updates m_prev / u_prev in place with assign(), and
# aliasing m0 / u0 directly would overwrite the initial condition (so
# re-running this cell would restart from the final state).
m_prev = m0.copy(deepcopy=True)
u_prev = u0.copy(deepcopy=True)
m_mid = 0.5 * (m + m_prev)
u_mid = 0.5 * (u + u_prev)
F = (
    # q part: weak form of m = u - u_xx (u_xx integrated by parts)
    (q * u + q.dx(0) * u.dx(0) - q * m) * dx
    # p part: midpoint step of m_t + m u_x + (m u)_x = 0, with (m u)_x
    # integrated by parts (no boundary term on the periodic domain)
    + (p * (m - m_prev) + dt * (p * m_mid * u_mid.dx(0) - p.dx(0) * m_mid * u_mid)) * dx
)
J = derivative(F, w)
problem = NonlinearVariationalProblem(F, w, J=J)
solver = NonlinearVariationalSolver(problem)

In [7]:
### Time step through the problem
# Set to a positive integer to print the energy E = int(u^2 + u_x^2) dx every
# `energy_every` steps (it should remain approximately constant); 0 disables.
energy_every = 0

# The loop appends to `uvals`, which must still hold only the initial
# snapshot; re-running this cell on its own would append a second
# trajectory and break the (time, space) shapes used by the plots.
assert len(uvals) == 1, "re-run the cells from 'Storing results' onwards before re-running the time loop"

t_start = t
for n in range(nt):
    if energy_every and n % energy_every == 0:
        E = assemble((u_prev * u_prev + u_prev.dx(0) * u_prev.dx(0)) * dx)
        print("time {:8.3f} energy {}".format(t, E))
    # Recompute t from the step index rather than accumulating t += dt,
    # which drifts through floating-point round-off.
    t = t_start + (n + 1) * dt
    solver.solve()
    # A shallow w.split() returns sub-functions that share w's vector and
    # live in W's subspaces, which assign() into the V-functions rejected
    # here; deepcopy=True yields standalone Functions that can be assigned.
    # New names avoid clobbering the UFL components `m, u` used to build F.
    m_new, u_new = w.split(deepcopy=True)
    uvals.append(u_new.compute_vertex_values(mesh))  # Save result
    m_prev.assign(m_new)  # Advance the previous time level for the next step
    u_prev.assign(u_new)

In [8]:
# Gather the stored snapshots into a (time, space) array with matching axes.
uvals = np.array(uvals)
xvals = mesh.coordinates()[:, 0]
# The start time is recovered from T, nt and dt (dt = (T - start) / nt)
# instead of hard-coding 0; `t` itself has been advanced by the time loop.
tvals = np.linspace(T - nt * dt, T, nt + 1)
assert uvals.shape == (tvals.size, xvals.size), "uvals does not match the (time, space) grid"

In [9]:
# Space-time view of the whole trajectory: one point per (x, t) sample.
x_grid, t_grid = np.meshgrid(xvals, tvals)  # both shaped like uvals: (time, space)
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1, projection='3d')
ax.scatter(x_grid, t_grid, uvals)
ax.set_xlabel("x")
ax.set_ylabel("t")
ax.set_zlabel("u")
ax.set_title("u(x, t)");

FigureCanvasNbAgg()

<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x7fe1d7332dd8>

In [10]:
# Profile of u at a single stored time level.
snapshot = 25  # row index into uvals / tvals
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1)
ax.scatter(xvals, uvals[snapshot])
ax.set_xlabel("x")
ax.set_ylabel("u")
ax.set_title("u(x, t = {:g})".format(tvals[snapshot]));

FigureCanvasNbAgg()

<matplotlib.collections.PathCollection at 0x7fe1d73002b0>