Skip to content

Commit

Permalink
Merge pull request #541 from team-ocean/fix-gpu-extension
Browse files Browse the repository at this point in the history
Fix gpu extension
  • Loading branch information
dionhaefner committed Oct 11, 2023
2 parents 00b4cc0 + 8a7879d commit 7d78a26
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 12 deletions.
73 changes: 73 additions & 0 deletions benchmarks/tdma_benchmark.py
@@ -0,0 +1,73 @@
from benchmark_base import benchmark_cli

from time import perf_counter

import numpy as np
from veros import logger
from veros.pyom_compat import load_pyom, pyom_from_state


@benchmark_cli
def main(pyom2_lib, timesteps, size):
    """Time repeated tridiagonal solves with either Veros or pyOM2.

    When ``pyom2_lib`` is falsy, the Veros ``solve_tridiagonal`` operator is
    timed on the full 3D arrays. Otherwise ``pyom2_lib`` is handed to
    ``load_pyom`` and the pyOM2 ``solve_tridiag`` routine is called once per
    water column.

    Args:
        pyom2_lib: Value passed to ``load_pyom``; a falsy value selects the
            Veros solver instead.
        timesteps: Number of timed repetitions of the solve.
        size: Indexable holding the problem dimensions ``(nx, ny, nz)``.
    """
    # NOTE(review): Veros imports are kept function-local as in the original;
    # presumably so the benchmark CLI can configure the runtime before Veros
    # is imported -- confirm before moving them to module level.
    from veros.state import get_default_state
    from veros.distributed import barrier
    from veros.core.utilities import create_water_masks
    from veros.core.operators import flush, solve_tridiagonal

    # Use the requested problem size for both the model state and the test
    # arrays. The arrays were previously fixed to 70 x 60 x 50, so the ``size``
    # argument had no influence on the work being timed.
    nx, ny, nz = size[0], size[1], size[2]

    state = get_default_state()

    with state.settings.unlock():
        state.settings.update(
            nx=nx,
            ny=ny,
            nz=nz,
            enable_neutral_diffusion=True,
        )

    state.initialize_variables()
    state.variables.__locked__ = False

    # Random coefficients and right-hand side of the tridiagonal systems, plus
    # a random bottom-level index for every horizontal grid cell.
    a, b, c, d = (np.random.randn(nx, ny, nz) for _ in range(4))
    kbot = np.random.randint(0, nz, size=(nx, ny))

    if not pyom2_lib:
        _, water_mask, edge_mask = create_water_masks(kbot, nz)

        def run():
            return solve_tridiagonal(a, b, c, d, water_mask, edge_mask)

    else:
        pyom_obj = load_pyom(pyom2_lib)
        pyom_obj = pyom_from_state(state, pyom_obj, init_streamfunction=False)

        def run():
            out_pyom = np.zeros((nx, ny, nz))
            for i in range(nx):
                for j in range(ny):
                    ks = kbot[i, j] - 1
                    ke = nz

                    # kbot == 0 gives ks == -1: no system is solved for this
                    # column and its output stays zero.
                    if ks < 0:
                        continue

                    out_pyom[i, j, ks:ke] = pyom_obj.solve_tridiag(
                        a=a[i, j, ks:ke],
                        b=b[i, j, ks:ke],
                        c=c[i, j, ks:ke],
                        d=d[i, j, ks:ke],
                        n=ke - ks,
                    )
            return out_pyom

    for _ in range(timesteps):
        start = perf_counter()

        run()
        # flush() and barrier() are called inside the timed region, before the
        # clock is stopped.
        flush()
        barrier()

        end = perf_counter()

        logger.debug(f"Time step took {end-start}s")


if __name__ == "__main__":
    main()
3 changes: 2 additions & 1 deletion test/pyom_consistency/tridiag_test.py
Expand Up @@ -31,7 +31,8 @@ def test_solve_tridiag_jax(pyom2_lib, use_ext):
)

_, water_mask, edge_mask = create_water_masks(kbot, nz)
out_vs = solve_tridiagonal_jax(a, b, c, d, water_mask, edge_mask, use_ext=use_ext)
object.__setattr__(runtime_settings, "use_special_tdma", use_ext)
out_vs = solve_tridiagonal_jax(a, b, c, d, water_mask, edge_mask)

np.testing.assert_allclose(out_pyom, out_vs)

Expand Down
6 changes: 4 additions & 2 deletions veros/core/operators.py
Expand Up @@ -99,11 +99,13 @@ def scan_numpy(f, init, xs, length=None):
return carry, np.stack(ys)


@veros_kernel(static_args=("use_ext",))
def solve_tridiagonal_jax(a, b, c, d, water_mask, edge_mask, use_ext=None):
@veros_kernel
def solve_tridiagonal_jax(a, b, c, d, water_mask, edge_mask):
import jax.lax
import jax.numpy as jnp

use_ext = runtime_settings.use_special_tdma

try:
from veros.core.special.tdma_ import tdma, HAS_CPU_EXT, HAS_GPU_EXT
except ImportError:
Expand Down
10 changes: 5 additions & 5 deletions veros/core/special/tdma_.py
Expand Up @@ -148,7 +148,7 @@ def tdma_xla_encode_gpu(ctx, a, b, c, d, system_depths):
raise ValueError("TDMA does not support system_depths argument on GPU")

x_aval, *_ = ctx.avals_in
x_nptype = x_aval.dtype
np_dtype = x_aval.dtype

x_type = ir.RankedTensorType(a.type)
dtype = x_type.element_type
Expand All @@ -159,7 +159,7 @@ def tdma_xla_encode_gpu(ctx, a, b, c, d, system_depths):
np.dtype(np.float64),
)

if x_nptype not in supported_dtypes:
if np_dtype not in supported_dtypes:
raise TypeError(f"TDMA only supports {supported_dtypes} arrays, got: {dtype}")

# compute number of elements to vectorize over
Expand All @@ -169,9 +169,9 @@ def tdma_xla_encode_gpu(ctx, a, b, c, d, system_depths):

system_depth = dims[-1]

if dtype is np.dtype(np.float32):
if np_dtype is np.dtype(np.float32):
kernel = b"tdma_cuda_float"
elif dtype is np.dtype(np.float64):
elif np_dtype is np.dtype(np.float64):
kernel = b"tdma_cuda_double"
else:
raise RuntimeError("got unrecognized dtype")
Expand All @@ -187,7 +187,7 @@ def tdma_xla_encode_gpu(ctx, a, b, c, d, system_depths):
out = custom_call(
kernel,
operands=(a, b, c, d),
result_tyes=out_types,
result_types=out_types,
result_layouts=out_layouts,
operand_layouts=(arr_layout,) * 4,
backend_config=descriptor,
Expand Down
9 changes: 5 additions & 4 deletions veros/runtime.py
Expand Up @@ -101,11 +101,12 @@ def set_log_all_processes(val):
"log_all_processes": RuntimeSetting(set_log_all_processes, False),
"use_io_threads": RuntimeSetting(parse_bool, False),
"io_timeout": RuntimeSetting(float, 20),
"hdf5_gzip_compression": RuntimeSetting(bool, True),
"force_overwrite": RuntimeSetting(bool, False),
"diskless_mode": RuntimeSetting(bool, False),
"pyom_compatibility_mode": RuntimeSetting(bool, False),
"hdf5_gzip_compression": RuntimeSetting(parse_bool, True),
"force_overwrite": RuntimeSetting(parse_bool, False),
"diskless_mode": RuntimeSetting(parse_bool, False),
"pyom_compatibility_mode": RuntimeSetting(parse_bool, False),
"setup_file": RuntimeSetting(str, None, read_from_env=False),
"use_special_tdma": RuntimeSetting(parse_bool, None),
}


Expand Down

0 comments on commit 7d78a26

Please sign in to comment.