In [1]:
import numpy as np
import matplotlib.pyplot as plt
import sympy as sp
from celluloid import Camera
from IPython.display import HTML
from sympy import *

In [2]:
def dambreak_initial(domain, h_left=1.0, h_right=0.0):
    '''Build the dam-break initial state U = [h, uh] on `domain`.

    The left half of the cells holds water at depth `h_left` and the right
    half at depth `h_right` (a dry bed by default); the fluid starts at rest.

    Parameters
    ----------
    domain : sequence of cell-centre coordinates; only its length is used.
    h_left, h_right : float, water depth on each side of the dam.

    Returns
    -------
    ndarray of shape (2, len(domain)): row 0 is h, row 1 is uh (all zeros).
    '''
    cells = len(domain)
    n_left = cells // 2
    # Fill every cell: for an odd count the extra cell goes to the right half,
    # so h always has exactly one entry per grid point.
    h = np.full(cells, float(h_right))
    h[:n_left] = h_left
    uh = np.zeros_like(h)
    return np.vstack((h, uh))

In [3]:
def initial_profile_1d_analytic(x, h_str, u_str, t=0.0):
    '''Evaluate an analytic (h, u) profile as the conserved state U = [h, uh].

    Parameters
    ----------
    x : array_like of cell-centre coordinates.
    h_str, u_str : str, sympy-parsable expressions in the symbols `x` and `t`
        giving the depth h and the velocity u.
    t : float, time at which the profile is evaluated (default 0).

    Returns
    -------
    ndarray of shape (2, len(x)): row 0 is h(x, t), row 1 is h*u at (x, t),
    matching the [h, uh] layout used by the solver.
    '''
    x = np.asarray(x, dtype=float)
    x_, t_ = sp.symbols("x t")
    expr_h = sp.sympify(h_str)
    expr_uh = expr_h * sp.sympify(u_str)
    h = sp.lambdify([x_, t_], expr_h, "numpy")
    uh = sp.lambdify([x_, t_], expr_uh, "numpy")
    # lambdify returns a scalar for constant expressions; adding a float zero
    # array broadcasts it to one value per cell and forces a float dtype so the
    # solver can update the state in place.
    zeros = np.zeros_like(x)
    return np.vstack((zeros + h(x, t), zeros + uh(x, t)))

In [4]:
def S_forced_solution_1d(x, t, h_str, u_str, g=9.8):
    '''Source term that makes (h_str, u_str) an exact solution of the 1D SWE.

    Substituting the prescribed h, u into
        h_t + (uh)_x = S_1,    (uh)_t + (u^2 h + g h^2 / 2)_x = S_2
    yields the forcing that must be added to the right-hand side so that the
    prescribed functions solve the forced system exactly.

    Parameters
    ----------
    x : array_like of cell-centre coordinates.
    t : float, time at which the source is evaluated.
    h_str, u_str : str, sympy-parsable expressions in the symbols `x` and `t`.
    g : float, gravitational acceleration.

    Returns
    -------
    ndarray of shape (2, len(x)): rows are S_1 and S_2 evaluated at (x, t).
    '''
    x = np.asarray(x, dtype=float)
    x_, t_ = sp.symbols("x t")
    expr_h = sp.sympify(h_str)
    expr_u = sp.sympify(u_str)
    expr_uh = expr_h * expr_u

    # Residual of each conservation law for the prescribed (h, u); the momentum
    # equation is written in conservative form, so its time derivative is of uh.
    S_1 = sp.diff(expr_h, t_) + sp.diff(expr_uh, x_)
    S_2 = sp.diff(expr_uh, t_) + sp.diff(expr_u**2 * expr_h + g * expr_h**2 / 2, x_)

    S_forced_1 = sp.lambdify([x_, t_], S_1, "numpy")
    S_forced_2 = sp.lambdify([x_, t_], S_2, "numpy")
    # lambdify returns a scalar for constant expressions (e.g. a zero time
    # derivative); adding a float zero array broadcasts it to every cell.
    zeros = np.zeros_like(x)
    return np.vstack((zeros + S_forced_1(x, t), zeros + S_forced_2(x, t)))

In [5]:
def _hll_flux(U, g):
    '''Return (u, E, E_HLL, max_speed) for the state U = [h, uh].

    u is the velocity (0 in dry cells), E the physical flux in every cell,
    E_HLL the HLL numerical flux at the len(h) - 1 interfaces between
    neighbouring cells, and max_speed the largest wave-speed magnitude.
    '''
    h, uh = U[0], U[1]
    u = np.divide(uh, h, out=np.zeros_like(uh), where=h != 0)
    E = np.vstack((uh, u**2 * h + (1/2) * g * h**2))

    c = np.sqrt(g * h)
    # Outermost signal speeds at each interface, clamped so a_minus <= 0 <= a_plus.
    a_plus = np.maximum(np.maximum(u[:-1] + c[:-1], u[1:] + c[1:]), 0.0)
    a_minus = np.minimum(np.minimum(u[:-1] - c[:-1], u[1:] - c[1:]), 0.0)

    U_L, U_R = U[:, :-1], U[:, 1:]
    E_L, E_R = E[:, :-1], E[:, 1:]
    numerator = (a_plus * E_L - a_minus * E_R) + a_plus * a_minus * (U_R - U_L)
    spread = a_plus - a_minus
    # spread == 0 only when both neighbours are dry and at rest: zero flux.
    E_HLL = np.divide(numerator, spread, out=np.zeros_like(numerator), where=spread != 0)
    max_speed = max(np.max(a_plus), np.max(-a_minus))
    return u, E, E_HLL, max_speed


def _reflect_ghost_cells(U, ghost):
    '''Fill the `ghost` outer cells on each side by mirroring the interior (in place).

    Depth is copied and momentum negated, which models a solid wall between
    the innermost ghost cell and the first interior cell.
    '''
    n = U.shape[1]
    for k in range(ghost):
        left_src, right_src = ghost + k, n - ghost - 1 - k
        U[0, ghost - 1 - k] = U[0, left_src]
        U[1, ghost - 1 - k] = -U[1, left_src]
        U[0, n - ghost + k] = U[0, right_src]
        U[1, n - ghost + k] = -U[1, right_src]


def _forcing_source(h_str, u_str, g, x):
    '''Return t -> (2, len(x)) array of the manufactured-solution source on x:
    [h_t + (uh)_x, (uh)_t + (u^2 h + g h^2 / 2)_x] for the prescribed h, u.'''
    x_, t_ = sp.symbols("x t")
    expr_h = sp.sympify(h_str)
    expr_u = sp.sympify(u_str)
    expr_uh = expr_h * expr_u
    s1 = sp.lambdify([x_, t_], sp.diff(expr_h, t_) + sp.diff(expr_uh, x_), "numpy")
    s2 = sp.lambdify(
        [x_, t_],
        sp.diff(expr_uh, t_) + sp.diff(expr_u**2 * expr_h + g * expr_h**2 / 2, x_),
        "numpy",
    )
    # Broadcast scalar results (constant expressions) to every cell.
    zeros = np.zeros_like(x)

    def source(t):
        return np.vstack((zeros + s1(x, t), zeros + s2(x, t)))

    return source


def SWWE_solver_1d(x, 
                   U, 
                   S=[0,0], 
                   forced_solution = 0,
                   ghost=2, 
                   g=9.8, 
                   finaltime = 0.15, 
                   boundary = 'none', 
                   reflective = True,
                   forced = False,
                   play=False,
                   cfl=1.0):
    '''First-order finite-volume solver for the 1D shallow water equations

        h_t + (uh)_x = S_1,    (uh)_t + (u^2 h + g h^2 / 2)_x = S_2

    using the HLL numerical flux and forward-Euler time stepping.

    Parameters
    ----------
    x : 1D array of uniformly spaced cell centres (dx = x[1] - x[0]).
    U : array of shape (2, len(x)); row 0 is h, row 1 is uh. Not modified.
    S : source used when `forced` is False; anything reshapeable to (2, k)
        that broadcasts to U's shape. The default [0, 0] means no source.
    forced_solution : (h_str, u_str) sympy expressions in `x` and `t`;
        required when `forced` is True. The matching source is re-evaluated
        at the current time every step.
    ghost : number of ghost cells on each side (>= 1). They are excluded from
        the plots and from U_record and are refilled when `reflective` is True.
    g : gravitational acceleration.
    finaltime : end time; the last step is shortened to stop exactly there.
    boundary : not used by this function; kept for backward compatibility.
    reflective : mirror the interior into the ghost cells (solid walls).
        If False the ghost cells keep their initial values.
    forced : add the source built from `forced_solution` instead of `S`.
    play : snapshot every step with celluloid and return an HTML5 video.
    cfl : Courant number; dt = cfl * dx / (largest wave speed).

    Returns
    -------
    forced:     (U, E, axes, U_record, count); U_record stacks uh on the
                non-ghost cells for the initial state and after every step.
    not forced: (U, E, fig, axes, play); play is the HTML video, or False.
    U is the final state, E the physical flux of that final state and count
    the number of time steps taken.

    Raises
    ------
    ValueError : inconsistent shapes, ghost < 1, missing forced_solution, or
        a state with zero wave speed (the time step would be infinite).
    '''
    x = np.asarray(x, dtype=float)
    # Work on a float copy: the update is in place and must neither modify
    # the caller's array nor fail on an integer-typed initial condition.
    U = np.array(U, dtype=float)
    if U.ndim != 2 or U.shape != (2, len(x)):
        raise ValueError("U must have shape (2, len(x))")
    if ghost < 1:
        raise ValueError("ghost must be >= 1: the outermost cells are never updated")
    n = U.shape[1]
    if n <= 2 * ghost:
        raise ValueError("need more cells than 2 * ghost")

    dx = x[1] - x[0]
    interior = slice(ghost, n - ghost)

    if forced:
        if not forced_solution:
            raise ValueError("forced=True requires forced_solution=(h_str, u_str)")
        h_str, u_str = forced_solution
        source_at = _forcing_source(h_str, u_str, g, x)
        U_record = [U[1, interior].copy()]
    else:
        static_source = np.broadcast_to(np.asarray(S, dtype=float).reshape(2, -1), U.shape)

        def source_at(t):
            return static_source

    fig, axes = plt.subplots()
    axes.plot(x, U[0], 'r')
    camera = Camera(fig) if play else None

    time = 0.0
    count = 0
    while time < finaltime:
        count += 1
        _, _, E_HLL, max_speed = _hll_flux(U, g)
        if not max_speed > 0:
            raise ValueError("wave speed is zero everywhere; time step is undefined")

        dt = cfl * dx / max_speed
        last_step = time + dt >= finaltime
        if last_step:
            dt = finaltime - time

        # Source is evaluated at the start of the step (explicit Euler) and
        # sits on the right-hand side, hence it is added.
        source = source_at(time)
        U[:, 1:-1] -= dt/dx*(E_HLL[:, 1:] - E_HLL[:, :-1])
        U[:, 1:-1] += dt*source[:, 1:-1]
        time = finaltime if last_step else time + dt

        if reflective:
            _reflect_ghost_cells(U, ghost)

        axes.plot(x[interior], U[0, interior], 'b')
        axes.plot(x[interior], U[1, interior], 'g')

        if forced:
            U_record.append(U[1, interior].copy())

        # Snap after both lines are drawn so each frame shows one time level.
        if play:
            camera.snap()
    axes.set(xlabel='x', title='h (blue) and uh (green) per step')
    plt.close('all')

    # Velocity and flux of the final state (well defined even if no step ran).
    u, E, _, _ = _hll_flux(U, g)

    plt.plot(x[interior], u[interior])
    plt.xlabel('x')
    plt.ylabel('u')
    plt.show()
    print("It's done.  Time step=",count)

    if play:
        animation = camera.animate(blit=False, interval=10)
        play = HTML(animation.to_html5_video())

    if forced:
        return (U, E, axes, np.array(U_record), count)
    else:
        return (U, E, fig, axes, play)

In [6]:
# Dam break on [-2, 2]: unit depth on the left, dry bed on the right.
x = np.linspace(-2, 2, 400)
U0 = dambreak_initial(x)
U, E, fig, axes, play = SWWE_solver_1d(x, U0, finaltime=0.2)

In [None]:
# Velocity of the final state; dry cells (h == 0) are assigned u = 0.
h_final, uh_final = U
u_final = np.divide(uh_final, h_final, out=np.zeros_like(uh_final), where=h_final != 0)
plt.plot(x[3:-3], u_final[3:-3])
plt.xlabel('x')
plt.ylabel('u');