In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Spatial domain [-a, a], sampled at nx points.
a, nx = 1, 1000

# Time interval [0, T], sampled at nt points.
T, nt = 1, 1000

In [None]:
def F(u):
   """Burgers flux F(u) = u**2 / 2 (works on scalars and numpy arrays)."""
   return u ** 2 / 2

def Fp(u):
   """Derivative F'(u) of the flux, i.e. the characteristic speed; F'(u) = u here."""
   speed = u
   return speed

def Fp_inv(u):
   """Inverse of Fp: the state whose characteristic speed equals `u`.

   Since Fp is the identity for this flux, so is its inverse.
   """
   state = u
   return state

In [None]:
# Left and right states of the Riemann problem. With uL < uR the
# riemann() function below takes its rarefaction branch.
uL, uR = -0.5, 1

In [None]:
def riemann(x_t):
   """Exact self-similar solution of the Riemann problem evaluated at x/t = x_t.

   Reads the global states uL, uR and the flux F together with its derivative
   Fp and the inverse Fp_inv. Assumes F is convex, which holds for the flux
   F(u) = 0.5*u**2 defined in this notebook.

   * uL > uR: a single shock moving at the Rankine-Hugoniot speed
     sigma = (F(uL) - F(uR)) / (uL - uR).
   * uL <= uR: a rarefaction fan between the characteristic speeds Fp(uL) and
     Fp(uR). Inside the fan the state is Fp_inv(x_t). For uL == uR the fan has
     zero width and the result is the constant state.

   Parameters
   ----------
   x_t : float
       Similarity variable x / t.

   Returns
   -------
   float
       The solution value u(x, t).
   """
   # Strict `>`: with uL == uR the shock-speed formula would be 0/0 and raise
   # ZeroDivisionError; the rarefaction branch handles that case correctly.
   if uL > uR:
      sigma = (F(uL) - F(uR)) / (uL - uR)
      return uL if x_t <= sigma else uR

   if x_t <= Fp(uL):
      return uL
   if x_t <= Fp(uR):
      return Fp_inv(x_t)
   return uR

In [None]:
# Space/time grids. Column `it` of u (shape (nx, nt)) holds the solution at time t[it].
x = np.linspace(-a, a, nx)
t = np.linspace(0, T, nt)
u = np.zeros((nx, nt))

# t = 0: the initial step, uL for x <= 0 and uR for x > 0.
u[:, 0] = np.where(x <= 0, uL, uR)

# t > 0: the solution depends only on x / t, so evaluate riemann() there.
# t[0] == 0 is skipped above, so the division is always well defined.
for it in range(1, nt):
   u[:, it] = [riemann(xi / t[it]) for xi in x]

In [None]:
def plot_waves(style, ax=None):
   """Draw the wave curves of the Riemann solution in the (x, t) plane.

   For uL > uR a single shock line x = sigma*t is drawn, with
   sigma = (F(uL) - F(uR)) / (uL - uR). Otherwise the two edges of the
   rarefaction fan, x = Fp(uL)*t and x = Fp(uR)*t, are drawn; they coincide
   when uL == uR. Reads the globals uL, uR, F, Fp and the time grid t.

   Parameters
   ----------
   style : str
       Matplotlib format string passed to ``plot`` (e.g. "red", "k").
   ax : matplotlib.axes.Axes, optional
       Axes to draw on. Defaults to the current pyplot axes, as before.
   """
   if ax is None:
      ax = plt.gca()
   # Strict `>`: with uL == uR the shock-speed formula would be 0/0.
   if uL > uR:
      sigma = (F(uL) - F(uR)) / (uL - uR)
      ax.plot(sigma*t, t, style)
   else:
      ax.plot(Fp(uL)*t, t, style)
      ax.plot(Fp(uR)*t, t, style)

In [None]:
# Space-time plot of the computed solution u(x, t), with the wave curves in red.
# np.meshgrid(x, t) returns (nt, nx) arrays while u is (nx, nt), hence u.T.
Xs, Ts = np.meshgrid(x, t)
# C has the same shape as X and Y, so request "auto" shading explicitly
# (one cell centred on each grid point) rather than relying on the rcParam default.
mesh = plt.pcolormesh(Xs, Ts, u.T, shading="auto")

plot_waves("red")

# Use the configured domain and final time rather than a hard-coded 1.
plt.xlim((-a, a))
plt.ylim((0, T))
plt.xlabel("x")
plt.ylabel("t")
plt.colorbar(mesh, label="u");

In [None]:
# Wave diagram: the wave curves in black above the t = 0 axis, on the same
# [-a, a] x [0, T] window as the space-time plot so the two figures line up.
plot_waves("k")
plt.plot(x, 0*x, "k")  # the t = 0 axis, where the initial step sits
plt.xlim((-a, a))
plt.ylim((0, T))
plt.xlabel("x")
plt.ylabel("t")

In [None]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

In [None]:
%%capture
# Two-panel animation figure.
#   left:  space-time solution with the wave curves and a horizontal line at the current time
#   right: the solution profile u(x, t) at that time
fig, axis = plt.subplots(1, 2)
fig.set_figwidth(12.0)

axis[0].pcolormesh(Xs, Ts, u.T, shading="auto")
# Wave curves drawn explicitly on the left panel. Strict `>` so that uL == uR
# does not evaluate the 0/0 shock-speed formula.
if uL > uR:
   shock_speed = (F(uL) - F(uR)) / (uL - uR)
   axis[0].plot(shock_speed*t, t, 'red')
else:
   axis[0].plot(Fp(uL)*t, t, 'red')
   axis[0].plot(Fp(uR)*t, t, 'red')

# Artists updated every frame by set_data.
line_0, = axis[0].plot([])
line_1, = axis[1].plot([])

axis[0].set_xlim(-a, a)
axis[1].set_xlim(-a, a)
axis[0].set_ylim(0, T)  # configured final time, not a hard-coded 1
axis[0].set_xlabel("x")
axis[0].set_ylabel("t")
axis[1].set_xlabel("x")
axis[1].set_ylabel("u")

minv = min(uL, uR)
maxv = max(uL, uR)
# Pad the value axis by 10% of the data range. Fall back to a fixed pad when
# uL == uR so the lower and upper limits are not identical.
pad = 0.1*(maxv - minv) or 0.1
axis[1].set_ylim(minv - pad, maxv + pad)

def set_data(frame):
   """FuncAnimation callback: show the solution at the time index for `frame`.

   Frames 0..nframes-1 are mapped linearly onto time indices 0..nt-1, so the
   final frame shows t = T. Reads the global `nframes`, which must be defined
   before the animation is rendered.
   """
   it = int(np.round((nt - 1) * frame / max(nframes - 1, 1)))
   line_0.set_data(x, 0*x + t[it])
   line_1.set_data(x, u[:, it])
   return line_0, line_1

In [None]:
# Number of animation frames (read by set_data through the global name).
nframes = 50
# Build the animation and render it inline as an interactive JavaScript/HTML widget.
anim = FuncAnimation(fig, set_data, nframes)
animation_html = anim.to_jshtml()
HTML(animation_html)