In [5]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["JAX_PLATFORMS"] = "cpu"    

import jax
import jax.numpy as jnp
from jax import lax

In [6]:

def _solve_line_in_i(ent, ap, aw, ae, at, ab, an, a_s, su, j, k,
                     eps=1e-12, tiny=1e-13):
    ni = ent.shape[0]
    i_sl = slice(1, ni - 1) 

    ap_i = ap[i_sl, j, k]
    aw_i = aw[i_sl, j, k]
    ae_i = ae[i_sl, j, k]
    at_i = at[i_sl, j, k]
    ab_i = ab[i_sl, j, k]
    an_i = an[i_sl, j, k]
    as_i = a_s[i_sl, j, k]
    su_i = su[i_sl, j, k]

    d_i = (
        at_i * ent[i_sl, j, k + 1] +
        ab_i * ent[i_sl, j, k - 1] +
        an_i * ent[i_sl, j + 1, k] +
        as_i * ent[i_sl, j - 1, k] +
        su_i
    )

    pr0 = jnp.array(0.0, dtype=ent.dtype)
    qr0 = ent[0, j, k]

    def fwd_step(carry, x):
        pr_prev, qr_prev = carry
        ap_t, aw_t, ae_t, d_t = x

        denom = ap_t - aw_t * pr_prev

        denom = jnp.where((denom >= 0) & (denom <= eps),  denom + tiny, denom)
        denom = jnp.where((denom < 0)  & (denom >= -eps), denom - tiny, denom)

        pr_t = ae_t / denom
        qr_t = (d_t + aw_t * qr_prev) / denom
        return (pr_t, qr_t), (pr_t, qr_t)

    (_, _), (pr_i, qr_i) = lax.scan(fwd_step, (pr0, qr0), (ap_i, aw_i, ae_i, d_i))
   
   
    ent_boundary = ent[ni - 1, j, k] 

    def back_step(ent_next, x):
        pr_t, qr_t = x
        ent_t = pr_t * ent_next + qr_t
        return ent_t, ent_t

    xs_rev = (pr_i[::-1], qr_i[::-1])
    _, ent_rev = lax.scan(back_step, ent_boundary, xs_rev)
    ent_i = ent_rev[::-1]  

    ent = ent.at[i_sl, j, k].set(ent_i)
    return ent


def tdma_gs_sweep(enthalpy, ap, aw, ae, at, ab, an, a_s, su,
                  eps=1e-12, tiny=1e-13):
    """Run one line-by-line TDMA sweep over every interior (j, k) line.

    Lines along i are solved one after another with ``_solve_line_in_i``; each
    solve sees the field as updated by all previous lines.  Planes are visited
    from ``k = nk-2`` down to ``k = 1`` and, within a plane, lines from
    ``j = 1`` up to ``j = nj-2``.

    Args:
        enthalpy: 3-D field indexed as ``[i, j, k]``.
        ap, aw, ae, at, ab, an, a_s, su: coefficient arrays forwarded unchanged
            to ``_solve_line_in_i``.
        eps, tiny: near-zero pivot safeguards forwarded to ``_solve_line_in_i``.

    Returns:
        The field after the sweep.
    """
    _, nj, nk = enthalpy.shape
    coeffs = (ap, aw, ae, at, ab, an, a_s, su)

    def sweep_plane(step, field):
        # step = 0 .. nk-3 maps to k = nk-2 .. 1 (descending k).
        k = (nk - 2) - step

        def sweep_line(offset, plane_field):
            # offset = 0 .. nj-3 maps to j = 1 .. nj-2 (ascending j).
            return _solve_line_in_i(plane_field, *coeffs, offset + 1, k, eps, tiny)

        return lax.fori_loop(0, nj - 2, sweep_line, field)

    return lax.fori_loop(0, nk - 2, sweep_plane, enthalpy)


# Compiled entry point for the sweep. jax.jit is applied with no static or
# donated arguments, so every argument passed to it is traced.
tdma_gs_sweep_jit = jax.jit(tdma_gs_sweep)

In [7]:
# Synthetic test problem: random field and coefficients on a 50^3 grid.
key = jax.random.PRNGKey(0)
shape = (50, 50, 50)
dtype = jnp.float32

# One independent subkey per generated array.
k = jax.random.split(key, 9)


def _scaled_normal(subkey, scale):
    """Return ``scale`` times a standard-normal array of the global shape/dtype."""
    return scale * jax.random.normal(subkey, shape, dtype=dtype)


enthalpy = jax.random.normal(k[0], shape, dtype=dtype)

# i-direction neighbours and the central coefficient.
aw = _scaled_normal(k[1], 0.05)
ae = _scaled_normal(k[2], 0.10)
# Keep ap > 0 and large compared with the typical neighbour magnitudes.
ap = 10.0 + jnp.abs(_scaled_normal(k[3], 0.5))

# k-direction (at/ab) and j-direction (an/a_s) neighbours.
at = _scaled_normal(k[4], 0.02)
ab = _scaled_normal(k[5], 0.02)
an = _scaled_normal(k[6], 0.02)
a_s = _scaled_normal(k[7], 0.02)

# Source term.
su = _scaled_normal(k[8], 0.10)

# First call of the jitted sweep on these inputs.
enthalpy_new = tdma_gs_sweep_jit(enthalpy, ap, aw, ae, at, ab, an, a_s, su)

In [8]:
import time

args = (enthalpy, ap, aw, ae, at, ab, an, a_s, su)


def run_eager(*args):
    """Call the sweep without the outer jax.jit wrapper.

    NOTE(review): the lax.fori_loop / lax.scan calls inside are still staged
    by JAX, so the gap to the jitted timing is presumably dominated by
    per-call tracing/dispatch overhead -- TODO confirm with a profiler.
    """
    return tdma_gs_sweep(*args)


run_jit = tdma_gs_sweep_jit

warmup = 3
repeat = 10


def _time_per_call(fn, fn_args, n_warmup, n_repeat):
    """Return ``(average seconds per call, last result)`` for ``fn(*fn_args)``.

    Each call is followed by ``jax.block_until_ready`` so the measurement
    covers the finished computation, not just the asynchronous dispatch.
    The ``n_warmup`` untimed calls run first.
    """
    result = None
    for _ in range(n_warmup):
        result = fn(*fn_args)
        jax.block_until_ready(result)

    start = time.perf_counter()
    for _ in range(n_repeat):
        result = fn(*fn_args)
        jax.block_until_ready(result)
    elapsed = time.perf_counter() - start
    return elapsed / n_repeat, result


avg_s, out = _time_per_call(run_eager, args, warmup, repeat)
print(f"[eager] avg per call: {avg_s * 1e3:.3f} ms  (repeat={repeat})")

avg_s, out = _time_per_call(run_jit, args, warmup, repeat)
print(f"[jit]   avg per call: {avg_s * 1e3:.3f} ms  (repeat={repeat})")

[eager] avg per call: 269.429 ms  (repeat=10)
[jit]   avg per call: 3.130 ms  (repeat=10)
