In [1]:
%matplotlib qt

In [2]:
from pipe_flow_eqs import setup_vars,calc_cfl, get_initial_conditions, timestep, get_local_Mach
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)

from tqdm import tqdm

In [3]:
from plotting_plt import setup_plot, update_plot
from matplotlib import animation

In [4]:
dt_relax = 0.1
cfl_max = 2.0


def run(iterations = 1000, N = 10000, plotting = True):
    print(f"Starting Simulation for iteration = {iterations} N = {N}")
    x, S, dx, dS_dx = setup_vars(N)
    max_dx = jnp.max(dx)
    S_all = jnp.hstack((S[0], S, S[-1]))
    U0 = get_initial_conditions(N)
    diff = jnp.zeros_like(U0)

    checkpoint_times = [
        0.001, # Checkpoint 0 for time t=0.001
        0.002, # Checkpoint 1 for time t=0.002
        0.003, # Checkpoint 2 for time t=0.003
        0.004, # Checkpoint 3 for time t=0.004
        0.005, # Checkpoint 4 for time t=0.005
        0.008, # Checkpoint 5 for time t=0.008
        0.01, # Checkpoint 6 for time t=0.01
        0.012, # Checkpoint 7 for time t=0.012
        0.015, # Checkpoint 8 for time t=0.015
        0.02, # Checkpoint 9 for time t=0.02
        0.05, # Checkpoint 10 for time t=0.05
        0.1, # Checkpoint 11 for time t=0.1
        0.2, # Checkpoint 12 for time t=0.2
    ]

    checkpoint_flags = [
        False, # Checkpoint 0 for time t=0.001
        False, # Checkpoint 1 for time t=0.002
        False, # Checkpoint 2 for time t=0.003
        False, # Checkpoint 3 for time t=0.004
        False, # Checkpoint 4 for time t=0.005
        False, # Checkpoint 5 for time t=0.008
        False, # Checkpoint 6 for time t=0.01
        False, # Checkpoint 7 for time t=0.012
        False, # Checkpoint 8 for time t=0.015
        False, # Checkpoint 9 for time t=0.02
        False, # Checkpoint 10 for time t=0.05
        False, # Checkpoint 11 for time t=0.1
        False, # Checkpoint 12 for time t=0.2
    ]

    U_history = []
    dU_history = []
    dt_history = []
    cfls = []
    dts = []
    max_diff_u = []
    mean_diff_u = []

    dt = 5e-7
    time = 0
    U = jnp.copy(U0)
    jax.device_put(U)

    if plotting:
        fig, axs, line_rho, line_u, line_p, patches = setup_plot(x,S,dx,U,N, N_toplot=100)

    # if saving == True:
    Writer = animation.writers['ffmpeg']
    writer = Writer(fps=15, metadata=dict(artist='Me'), bitrate=1800)

    with writer.saving(fig, f"movie_N_{N}.mp4", dpi = 100):
        for i in tqdm(range(iterations)):
            U,diff= timestep(U,dt,S, S_all, dx, dS_dx)
            time+=dt 
            cfl = calc_cfl(U,dt,max_dx)

            max_diff_u.append(jnp.max(jnp.abs(diff[1,:])))
            mean_diff_u.append(jnp.mean(jnp.abs(diff[1,:])))
            dts.append(dt)
            cfls.append(cfl)

            if time > 0.2:
                break
            
            # Update dt
            dt = (
                dt * cfl_max/cfl 
            )* dt_relax + (1-dt_relax)*dt

            ## Live Plot
            if i%20 == 0:
                if plotting:
                    update_plot(U,time,fig, axs, line_rho, line_u, line_p, patches,N_toplot=100)
                    writer.grab_frame()
            
            ## Checkpointing
            # When we cross the checpoint time we need to get the U and diff, dt variables
            # and append them to the history arrays 
            for j, t in enumerate(checkpoint_times):
                if time > t and checkpoint_flags[j] == False:
                    checkpoint_flags[j] = True
                    U_history.append(jnp.copy(U))


    print("Simulation finished")
    U_checkpoints ={
        "U": U_history,
    }
    
    run_results = (U, U_checkpoints, cfls, dts, max_diff_u, mean_diff_u)
    return run_results

In [5]:
run_results100 = run(40000,100)
run_results200 = run(40000,200)
run_results500 = run(40000,500)
run_results1000 = run(40000,1000)
run_results2000 =  run(50000,2000)
run_results5000 =  run(500000,5000)
run_results10000 = run(500000,10000)

Starting Simulation for iteration = 40000 N = 100


 10%|▉         | 3821/40000 [01:42<16:13, 37.15it/s]


Simulation finished
Starting Simulation for iteration = 40000 N = 200


 19%|█▉        | 7695/40000 [03:25<14:23, 37.43it/s]


Simulation finished
Starting Simulation for iteration = 40000 N = 500


 48%|████▊     | 19324/40000 [08:36<09:12, 37.42it/s]


Simulation finished
Starting Simulation for iteration = 40000 N = 1000


 97%|█████████▋| 38711/40000 [17:01<00:34, 37.91it/s]


Simulation finished
Starting Simulation for iteration = 50000 N = 2000


100%|██████████| 50000/50000 [22:35<00:00, 36.88it/s] 


Simulation finished
Starting Simulation for iteration = 500000 N = 5000


 39%|███▉      | 193832/500000 [1:36:46<2:32:51, 33.38it/s] 


Simulation finished
Starting Simulation for iteration = 500000 N = 10000


 78%|███████▊  | 387743/500000 [3:33:12<1:01:43, 30.31it/s] 


Simulation finished


# Plot run histories

In [6]:
import matplotlib.pyplot as plt
run_results = [run_results100, run_results200, run_results500, run_results1000, run_results2000, run_results5000, run_results10000]
Ns = [100, 200, 500, 1000, 2000, 5000, 10000]

In [61]:
# Plot the max and mean differences of u for each result

#from dts calculate the time that corresponds to each iteration
dts = run_results[-1][3]
times = [0]
for dt in dts:
    times.append(times[-1]+dt)
times = times[1:]


fig, axs = plt.subplots(2,1, figsize = (10,10))
for i, result in enumerate([run_results[-1]]):
    axs[0].plot(times[::100], result[4][::100], label = f"N = {Ns[-1]}")
    axs[1].plot(times[::100], result[5][::100], label = f"N = {Ns[-1]}")
axs[0].set_yscale("log")
axs[0].set_xlabel("Time")
axs[0].set_ylabel("max ΔU at each iteration")
axs[0].legend()

axs[1].set_yscale("log")
axs[1].set_xlabel("Time")
axs[1].set_ylabel("mean ΔU at each iteration")
axs[1].legend()
plt.show()

In [64]:
# Plot the CFLs for each result
fig, ax = plt.subplots(2,1, figsize = (10,7))
ax[0].plot(run_results[-1][2], label = f"N = {Ns[-1]}")
ax[0].set_xlabel("Iteration")
ax[0].set_ylabel("CFL at each iteration")
ax[0].legend()

ax[1].plot(run_results[-1][3], label = f"N = {Ns[-1]}")
ax[1].set_yscale("log")
ax[1].set_xlabel("Iteration")
ax[1].set_ylabel("Dt at each iteration")
ax[1].legend()

plt.show()

In [55]:
# plot the checkpoints for each result
checkpoint_times = [
    0.001, # Checkpoint 0 for time t=0.001
    0.002, # Checkpoint 1 for time t=0.002
    0.003, # Checkpoint 2 for time t=0.003
    0.004, # Checkpoint 3 for time t=0.004
    0.005, # Checkpoint 4 for time t=0.005
    0.008, # Checkpoint 5 for time t=0.008
    0.01, # Checkpoint 6 for time t=0.01
    0.012, # Checkpoint 7 for time t=0.012
    0.015, # Checkpoint 8 for time t=0.015
    0.02, # Checkpoint 9 for time t=0.02
    0.05, # Checkpoint 10 for time t=0.05
    0.1, # Checkpoint 11 for time t=0.1
    0.2, # Checkpoint 12 for time t=0.2
]
time_idxs = [6,7,8,9,10,11]

# Get some colors for each line
colors = [
    "red",
    "orange",
    "blue",
    "black",
    "green",
    "purple",
    "pink"
]

fig, axs = plt.subplots(2,2, figsize = (10,10))
axs[0, 0].set_title("Mach Number")
axs[1, 0].set_title("Rho")
axs[0, 1].set_title("U")
axs[1, 1].set_title("P")

for i, time_idx in enumerate(time_idxs):
    for j,res in enumerate([run_results[-1]]):
        checkpoint = res[1]
        U = checkpoint["U"][time_idx]
        N = 10000 #Ns[j]
        x,_,_,_= setup_vars(N)
        
        axs[0,0].plot(x,get_local_Mach(U), color = colors[i], label = f"time = {checkpoint_times[time_idx]}")
        axs[1, 0].plot(x, U[0,:], color= colors[i]) 
        axs[0, 1].plot(x, U[1,:], color=colors[i])
        axs[1, 1].plot(x, U[2,:], color=colors[i])

axs[0,0].legend()
fig.tight_layout()
fig.suptitle(f"Time: {checkpoint_times[time_idx]:.4f}")
fig.show()


In [35]:
import numpy as np
from scipy.interpolate import interp1d

def resampling(curve_long, curve_short):
    # Interpolate curve2 onto the grid of curve1
    interp_curve2 = interp1d(np.linspace(0, 1, len(curve_short)), curve_short)
    curve2_resampled = interp_curve2(np.linspace(0, 1, len(curve_long)))
    return curve2_resampled

def l2_dist(curve1, curve2, dx ):
    # Compute element-wise squared differences
    squared_diff = (curve1 - curve2) ** 2

    # Compute L2 distance
    l2_dist = np.sum(np.sqrt(squared_diff)*dx[0])/2
    return l2_dist

In [36]:
dx_10000 =  setup_vars(Ns[-1])[2]
errors_rho = {
    N:[] for N in Ns
}
errors_p = {
    N:[] for N in Ns
}
errors_u = {
    N:[] for N in Ns
}
for checkp_idx in range(len(checkpoint_times)-1):
# Sample all solutions to the N=1000 solutions
    U_10000 = run_results10000[1]["U"][checkp_idx]

    U_100 = run_results100[1]["U"][checkp_idx]
    U_200 = run_results200[1]["U"][checkp_idx]
    U_500 = run_results500[1]["U"][checkp_idx]
    U_1000 = run_results1000[1]["U"][checkp_idx]
    U_2000 = run_results2000[1]["U"][checkp_idx]
    U_5000 = run_results5000[1]["U"][checkp_idx]

    # Resample the solutions to the same grid
    rho_100_resampled = resampling(U_10000[0], U_100[0])
    rho_200_resampled = resampling(U_10000[0], U_200[0])
    rho_500_resampled = resampling(U_10000[0], U_500[0])
    rho_1000_resampled = resampling(U_10000[0], U_1000[0])
    rho_2000_resampled = resampling(U_10000[0], U_2000[0])
    rho_5000_resampled = resampling(U_10000[0], U_5000[0])

    rhos_reshampled = [
        rho_100_resampled,
        rho_200_resampled,
        rho_500_resampled,
        rho_1000_resampled,
        rho_2000_resampled,
        rho_5000_resampled,
    ]

    u_100_resampled = resampling(U_10000[1], U_100[1])
    u_200_resampled = resampling(U_10000[1], U_200[1])
    u_500_resampled = resampling(U_10000[1], U_500[1])
    u_1000_resampled = resampling(U_10000[1], U_1000[1])
    u_2000_resampled = resampling(U_10000[1], U_2000[1])
    u_5000_resampled = resampling(U_10000[1], U_5000[1])

    u_reshampled = [
        u_100_resampled,
        u_200_resampled,
        u_500_resampled,
        u_1000_resampled,
        u_2000_resampled,
        u_5000_resampled,
    ]

    p_100_resampled = resampling(U_10000[2], U_100[2])
    p_200_resampled = resampling(U_10000[2], U_200[2])
    p_500_resampled = resampling(U_10000[2], U_500[2])
    p_1000_resampled = resampling(U_10000[2], U_1000[2])
    p_2000_resampled = resampling(U_10000[2], U_2000[2])
    p_5000_resampled = resampling(U_10000[2], U_5000[2])

    p_resampled = [
        p_100_resampled,
        p_200_resampled,
        p_500_resampled,
        p_1000_resampled,
        p_2000_resampled,
        p_5000_resampled,
    ]
    for j in range(len(Ns)-1):
        errors_p[Ns[j]].append(l2_dist(U_10000[2], p_resampled[j], dx_10000))
        errors_rho[Ns[j]].append(l2_dist(U_10000[0], rhos_reshampled[j], dx_10000))
        errors_u[Ns[j]].append(l2_dist(U_10000[1], u_reshampled[j], dx_10000))


In [38]:
fig, axs = plt.subplots(3,1, figsize = (10,10))
fig.suptitle('L2 Difference from N=10000')

str_Ns = [str(N) for N in Ns[:-1]]
for j in range(len(str_Ns)):
    axs[0].plot(checkpoint_times[:-1], errors_rho[Ns[j]], color = colors[j], label=f"{Ns[j]}")
    axs[1].plot(checkpoint_times[:-1], errors_u[Ns[j]], color = colors[j], label=f"{Ns[j]}")
    axs[2].plot(checkpoint_times[:-1], errors_p[Ns[j]], color = colors[j], label=f"{Ns[j]}")

axs[0].set_title("Rho")
axs[1].set_title("U")
axs[2].set_title("P")
axs[2].set_xlabel("Time")
axs[0].legend(
    title = "N",

    loc='upper left'
)
for ax in axs:
    # Set logarithmic scale
    ax.set_yscale('log')
from collections import OrderedDict
import matplotlib.pyplot as plt

handles, labels = plt.gca().get_legend_handles_labels()
by_label = OrderedDict(zip(labels, handles))
axs[0].legend(by_label.values(), by_label.keys(), title = "N",
              bbox_to_anchor = (1.05, 1), loc='upper left')
fig.show()