In [1]:
from pathlib import Path
from pprint import pprint

from jfi import jaxm

import torch
import torch.utils.cpp_extension as cpp_extension
import numpy as np

from src.python.torch_call_jax import build, torch_call, wrap_torch_fn

In [2]:
# JIT-compile the CUDA source into a loadable PyTorch extension module named "my_ops".
mod = cpp_extension.load(
    name="my_ops",
    sources=["src/cpp/torch_call.cu"],
)
# Hand the compiled module to the project's `build` helper
# (presumably registers the custom call used further down -- confirm in torch_call_jax).
build(mod)

In [20]:
def test_fn(a, b):
    return (a + b) / torch.norm(a), a

In [21]:
# Seed PyTorch's RNG so the random inputs -- and every result derived from
# them below -- are reproducible after Restart Kernel -> Run All.
torch.manual_seed(0)
a = torch.randn(3, device="cuda")
b = torch.randn(3, device="cuda")

# Wrap the eager PyTorch function so it can be invoked on JAX arrays; the torch
# tensors are passed as example inputs.
# NOTE(review): the meaning of `id=21` is not visible here -- verify against
# wrap_torch_fn before changing it.
test_fn_jax = wrap_torch_fn(mod, test_fn, [a, b], id=21)

# Eager PyTorch reference result, compared against the JAX-side result later.
out1 = test_fn(a, b)

In [22]:
# Mirror the torch inputs as JAX arrays on the GPU (round-trip through host NumPy).
# Distinct names are used on purpose: rebinding `a`/`b` here would replace the torch
# tensors with JAX arrays, so re-running this cell would fail on `.cpu()` and the
# names would silently change type between cells.
a_jax = jaxm.to(jaxm.array(a.cpu().numpy()), device="cuda")
b_jax = jaxm.to(jaxm.array(b.cpu().numpy()), device="cuda")

# Call the wrapped PyTorch function with the JAX copies of the inputs.
out2 = test_fn_jax(a_jax, b_jax)

In [23]:
# Reference output from the eager PyTorch call to test_fn.
out1

(tensor([-0.7079, -0.8680, -0.1019], device='cuda:0'),
 tensor([-0.1051,  0.7972,  0.2765], device='cuda:0'))

In [24]:
# Output of the JAX-wrapped function; compare by eye with the PyTorch result shown for out1.
out2

[Array([-0.7078985 , -0.86804414, -0.10188818], dtype=float32),
 Array([-0.10506006,  0.7971908 ,  0.27645925], dtype=float32)]

In [3]:
# Sanity-check `torch_call` on random float64 inputs against a closed-form expression.
# To try the CPU instead, set device = "cpu" (alternative kept from the original cell;
# presumably supported by torch_call -- confirm).
device = "cuda"
dtype = jaxm.float64
a = jaxm.randn(50, device=device, dtype=dtype)
b = jaxm.randn(50, device=device, dtype=dtype)

# Only the first entry of the returned sequence is inspected here.
out = torch_call(a, b)[0]
print(out)

# Value the first output is checked against.
expected = a * jaxm.sum(b) + b * jaxm.sum(a)

# Relative error of `out` with respect to `expected`.
err = jaxm.norm(out - expected) / jaxm.norm(expected)
print(f"err = {err:.4e}")

op_name = gpu_torch_call_f64
ir_dtype = tensor<50xf64>
dtype: float64
[ -7.69301755   9.85556456  17.91446018   5.05608369   2.9735934
   8.47066228   0.18601189   9.81644489   2.02669975   6.33979124
   5.99166052  -4.96604249   1.89614211  11.30684113  -2.90173717
  11.41611024 -15.54126008   2.22760819   9.07442734  -8.23877566
  -7.02110279 -11.70127286  -3.60064385  13.64872768  -0.62094695
  -8.94874319  11.81052007   1.60637174  -4.29390398  -1.55727641
  -5.68923435  -3.1449265    2.95529341  -8.01558365  -7.18923466
   2.70486956   0.16546794 -13.86874635 -12.74236983   7.19246436
  -7.70809397  -0.73172217   5.07880735  -3.41690986 -22.41979778
   1.00101514   2.68671555   4.02988285  17.04027984   1.78406585]
err = 1.5036e-16


In [4]:
# Time one `torch_call` invocation on the 50-element inputs `a`, `b` defined in the cell above.
# NOTE(review): if torch_call dispatches asynchronously like ordinary JAX ops, this may
# measure dispatch time rather than completed work -- consider blocking on the result
# (e.g. `.block_until_ready()`) and confirm.
%timeit out = torch_call(a, b)[0]

189 µs ± 5.63 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
