In [108]:
import jax.numpy as jnp
import numpy as np
from jax import jit
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import colormaps as cmap

Parameters

In [109]:
n = 201
m = 201
dt = 1e-1
t_total = 0.2
DL = 1
N = n + 2 * DL
M = m + 2 * DL
Ntime = int(t_total/dt)

dims = int(3) # 3 equations

In [110]:
len('themas')

6

In [111]:
Mach = 0.0
f = 10 + 0  # Hertz
Th_0 = len("themas") + 10
a = 1 / len("themas")
T_0 = 273 + Th_0
c_0 = jnp.sqrt(1.4*287.05*T_0)
u_0 = Mach*c_0
rho_0 = 1.225
c_0

Array(340.79382, dtype=float32, weak_type=True)

Grid

In [112]:
x = np.linspace(-100 - DL, 100 + DL, N)
y = np.linspace(-100 - DL, 100 + DL, M)
dx = x[1]-x[0]
dy = y[1]-y[0]
r = jnp.zeros((N,M))
def map_to_grid(x,y):
    return np.sqrt(x**2 + y**2)

r = np.array([[map_to_grid(x[i],y[j]) for j in range(M)] for i in range(N)])

Buffer Zone Parameters

In [113]:
buff_m = 2.
buff_r_0 = 100.
buff_w =DL

buff_sig_max = 10. / dx

Constant Matrices (Linear Problem)

In [116]:
A_x =   jnp.zeros((dims,dims))
Lam_x = jnp.zeros((dims,dims))
R_x =   jnp.zeros((dims,dims))
L_x =   jnp.zeros((dims,dims))
RAL_x = jnp.zeros((dims,dims))
A_y =   jnp.zeros((dims, dims))
Lam_y = jnp.zeros((dims, dims))
R_y =   jnp.zeros((dims, dims))
L_y =   jnp.zeros((dims, dims))
RAL_y = jnp.zeros((dims, dims))


A_x = A_x.at[0,0].set(u_0)
A_x = A_x.at[0,1].set(rho_0)
A_x = A_x.at[1,0].set(c_0**2/rho_0)
A_x = A_x.at[1,1].set(u_0)
A_x = A_x.at[2,2].set(u_0)

A_y = A_y.at[0,2].set(rho_0)
A_y = A_y.at[2,0].set(c_0**2/rho_0)

R_x = R_x.at[0,1].set(rho_0/c_0)
R_x = R_x.at[0,2].set(-rho_0/c_0)
R_x = R_x.at[1,1].set(1.)
R_x = R_x.at[1,2].set(1.)
R_x = R_x.at[2,0].set(1.)

Lam_x = Lam_x.at[0,0].set(jnp.abs(u_0))
Lam_x = Lam_x.at[1,1].set(jnp.abs(u_0+c_0))
Lam_x = Lam_x.at[2,2].set(jnp.abs(u_0-c_0))

# L_x = L_x.at[0, 2].set(1.)
# L_x = L_x.at[1, 0].set(c_0/(2*rho_0))
# L_x = L_x.at[2, 0].set(-c_0/(2*rho_0))
# L_x = L_x.at[1, 1].set(0.5)
# L_x = L_x.at[2, 1].set(0.5)
L_x = jnp.linalg.inv(R_x)

R_y = R_y.at[0, 1].set(rho_0 / c_0)
R_y = R_y.at[0, 2].set(-rho_0 / c_0)
R_y = R_y.at[1, 0].set(1.0)
R_y = R_y.at[2, 1].set(1.0)
R_y = R_y.at[2, 2].set(1.0)

Lam_y = Lam_y.at[1,1].set( jnp.abs(+ c_0))
Lam_y = Lam_y.at[2,2].set( jnp.abs(- c_0))

L_y = jnp.linalg.inv(R_y)
# L_y = L_y.at[0, 1].set(1.0)
# L_y = L_y.at[1, 0].set(c_0 / (2 * rho_0))
# L_y = L_y.at[2, 0].set(-c_0 / (2 * rho_0))
# L_y = L_y.at[1, 2].set(0.5)
# L_y = L_y.at[2, 2].set(0.5)

RAL_x = R_x@Lam_x@L_x
RAL_y = R_y@Lam_y@L_y


In [104]:
# U are the primal variables, dimensions: (N,M,dims,1)
from jax import jit, grad, vmap
import jax 
from jax import lax

@jit
def roes(U):
    N, M, dims = U.shape
    Roe_x = jnp.zeros((N-1, M-1, dims))
    Roe_y = jnp.zeros((N-1, M-1, dims))
    Ux_diff = jnp.gradient(U, axis=0)
    Uy_diff = jnp.gradient(U, axis=1)

    # Define a function for computing Roe_x and Roe_y for a single (i, j) pair
    def compute_roe(i, j):
        Ux_diff_ij = Ux_diff[i, j]
        Uy_diff_ij = Uy_diff[i, j]
        Roe_x_ij = RAL_x @ Ux_diff_ij
        Roe_y_ij = RAL_y @ Uy_diff_ij
        return Roe_x_ij, Roe_y_ij

    # Vectorize the compute_roe function over the i and j dimensions
    roe_x_batched, roe_y_batched = vmap(vmap(compute_roe, in_axes=(None, 0)), in_axes=(0, None))(jnp.arange(1, N-2), jnp.arange(1, M-2))

    # Assign the computed values to the appropriate slices
    Roe_x = Roe_x.at[1:N-2, 1:M-2].set(roe_x_batched)
    Roe_y = Roe_y.at[1:N-2, 1:M-2].set(roe_y_batched)

    return Roe_x, Roe_y

@jit
def flux(U, Roe_x, Roe_y):
    N, M, dims = U.shape
    flux_x = jnp.zeros((dims))
    flux_w = jnp.zeros((dims))
    flux_y = jnp.zeros((dims))
    flux_s = jnp.zeros((dims))
    F_x = jnp.zeros((N-1, M-1, dims))
    F_y = jnp.zeros((N-1, M-1, dims))

    # Define a function for computing flux for a single (i, j) pair
    def compute_flux(i, j):
        flux_x = jnp.dot(A_x, U[i, j])
        flux_w = jnp.dot(A_x, U[i-1, j])
        flux_y = jnp.dot(A_y, U[i, j])
        flux_s = jnp.dot(A_y, U[i, j-1])

        F_x_ij = 0.5 * (flux_x + flux_w - Roe_x[i, j])
        F_y_ij = 0.5 * (flux_y + flux_s - Roe_y[i, j])
        return F_x_ij, F_y_ij

    # Vectorize the compute_flux function over the i and j dimensions
    F_x_batched, F_y_batched = vmap(vmap(compute_flux, in_axes=(None, 0)), in_axes=(0, None))(jnp.arange(1, N-2), jnp.arange(1, M-2))

    # Assign the computed values to the appropriate slices
    F_x = F_x.at[1:N-2, 1:M-2].set(F_x_batched)
    F_y = F_y.at[1:N-2, 1:M-2].set(F_y_batched)

    return F_x, F_y

@jit
def source(U, r, t):
    N, M, dims = U.shape
    Q = jnp.zeros((N, M, dims))

    # Define a function for computing the source for a single (i, j) pair
    def compute_source(i, j):
        S_u = 0.0
        S_v = 0.0
        S_m = jnp.sin(2 * jnp.pi * f * t) * jnp.exp(-a * r[i, j]**2) / c_0

        damp = buff_sig_max * ((r[i, j] - buff_r_0) / buff_w)**buff_m
        def cond_fun(args):
            r, Sm , Su, Sv = args
            jax.debug.print(r)
            return r - buff_r_0 > 0
        
        def true_fun(args):
            nonlocal U
            r, Sm , Su, Sv = args
            Sm = Sm - damp * U[i, j, 0]
            Su = Su - damp * U[i, j, 1]
            Sv = Sv - damp * U[i, j, 2]
            return r, S_m, S_u, S_v
        
        def false_fun(args):
            return args
        
        _, S_m, S_u, S_v = jax.lax.cond(cond_fun, true_fun, false_fun, (r[i, j], S_m, S_u, S_v))
        
        Q_ij = jnp.array([S_m, S_u, S_v]).reshape((dims))
        return Q_ij

    # Vectorize the compute_source function over the i and j dimensions
    Q_batched = vmap(vmap(compute_source, in_axes=(None, 0)), in_axes=(0, None))(jnp.arange(N), jnp.arange(M))

    # Assign the computed values to the appropriate slices
    Q = Q.at[:, :, :].set(Q_batched)

    return Q

@jit
def integrate_step(U, U_0, t, i):
    Roe_x, Roe_y = roes(U)
    Flux_x, Flux_y = flux(U, Roe_x, Roe_y)
    Q = source(U, r, t)

    def compute_dU(i, j):
        return -dt * (
            dy * (Flux_x[i+1, j] - Flux_x[i, j]) / dx
            + dx * (Flux_y[i, j+1] - Flux_y[i, j]) / dy
            - Q[i, j]
        )

    dU = vmap(vmap(compute_dU, in_axes=(0, None)), in_axes=(None, 0))(jnp.arange(1, N-2), jnp.arange(1, M-2))
    dU2 = jnp.zeros((N, M, dims))
    dU2 = dU2.at[1:N-2, 1:M-2].set(dU)
    U = U_0 + (1.0 /( 5.0 - i)) * dU2
    return U, dU2

# @jit
def rk4(r, t_total, Ntime):
    time_s = 0.
    time_vec = jnp.linspace(time_s, t_total, Ntime)
    U = jnp.zeros((N, M, dims))
    dU = np.zeros((N, M, dims))
    U_0 = np.zeros((N, M, dims))

    for k, t in enumerate(time_vec):
        for l in range(0, 4):
            U, dU = integrate_step(U, U_0, t, l)
            U_0 = U
            if jnp.isnan(U).any():
                print("NaNs encountered at time: ", t)
                break
            print(t)

    return U, dU


In [105]:
U = rk4(r, .001, 1000)

TypeError: Pred type must be either boolean or number, got <function source.<locals>.compute_source.<locals>.cond_fun at 0x0000017EE9E3F740>.

In [None]:

rho = np.squeeze(U[:,:,0])
u = np.squeeze(U[:,:,1])
v = np.squeeze(U[:,:,2])

TypeError: tuple indices must be integers or slices, not tuple

In [None]:
L = 104

fig, ax = plt.subplots(1, 1)
cp = ax.contourf(x_mat, y_mat, u, levels=150)
x_r = np.linspace(0, 100, 100)
y_r = np.sqrt(1e4 - x_r**2)
fig.colorbar(cp)  # Add a colorbar to a plot
ax.set_title("Filled Contours Plot")
ax.set_xlabel("x (cm)")
ax.set_ylabel("y (cm)")
ax.plot(x_r, y_r, color="black")
ax.plot(-x_r, y_r, color="black")
ax.plot(-x_r, -y_r, color="black")
ax.plot(x_r, -y_r, color="black")
ax.set_xlim(-L, L)
ax.set_ylim(-L, L)
plt.show()

In [None]:
U