- JavaTorch.java: A Java class that calls a function in JavaTorch.cpp through the Java Native Interface (JNI)
- JavaTorch.cpp: A JNI-compatible function that loads a previously traced torch model, performs inference, and returns the predicted class
- model.py: A script that can generate or evaluate a simple torch model
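For orientation, the core of JavaTorch.cpp presumably boils down to a JNI function like the sketch below. The native method name (predict), the object receiver, and the 1x10 dummy input shape are illustrative assumptions, not necessarily the repository's actual signatures:

#include <jni.h>
#include <torch/script.h>
#include <vector>

// Sketch only: assumes JavaTorch.java declares
//   public native int predict(String modelPath);
extern "C" JNIEXPORT jint JNICALL
Java_JavaTorch_predict(JNIEnv *env, jobject /* this */, jstring modelPath) {
    const char *path = env->GetStringUTFChars(modelPath, nullptr);
    // Load the previously traced module from disk (currently reloaded on every call)
    torch::jit::script::Module module = torch::jit::load(path);
    env->ReleaseStringUTFChars(modelPath, path);

    // Dummy input; the real shape must match what model.py used for tracing
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({1, 10}));

    // Run inference and return the index of the largest logit as the predicted class
    at::Tensor output = module.forward(inputs).toTensor();
    return static_cast<jint>(output.argmax(1).item<int64_t>());
}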
Download and unzip Torchscript (libtorch) 2.1.2, e.g. the CPU version:
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip
Compile with
make
This will:
- Generate the JavaTorch.h header based on the class defined in JavaTorch.java
- Compile JavaTorch.java into JavaTorch.class
- Compile JavaTorch.cpp and link it against Torchscript
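Under the hood this corresponds roughly to the following manual steps (the shared library name libJavaTorch.so and the exact compiler flags are assumptions; the Makefile is authoritative):

javac -h . JavaTorch.java
g++ -shared -fPIC -o libJavaTorch.so JavaTorch.cpp \
    -I"$JAVA_HOME/include" -I"$JAVA_HOME/include/linux" \
    -Ilibtorch/include -Ilibtorch/include/torch/csrc/api/include \
    -Llibtorch/lib -ltorch -ltorch_cpu -lc10

javac -h generates the JNI header and the .class file in one step; the g++ line builds the shared library that java -Djava.library.path=. loads at runtime.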
Run:
LD_LIBRARY_PATH=libtorch/lib java -Djava.library.path=. JavaTorch models/traced_dummy.pt
Install the python torch library, e.g. with conda/mamba
conda create -n torch pytorch=2.1.2 cpuonly numpy
Evaluate a model with the same dummy input used in JavaTorch.cpp
python model.py --mode run --model_path models/traced_dummy.pt
Or recreate the pytorch model and trace it
python model.py --mode create --model_path models/traced_dummy.pt
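The tracing step inside model.py presumably amounts to something like the following sketch (the network and input shape are placeholders; the actual architecture is whatever model.py defines):

import torch

# Placeholder network; the real one lives in model.py
model = torch.nn.Sequential(torch.nn.Linear(10, 5))
model.eval()

# Trace with a fixed dummy input and save as a Torchscript file
dummy_input = torch.ones(1, 10)
traced = torch.jit.trace(model, dummy_input)
traced.save("models/traced_dummy.pt")

# The saved module can be reloaded and evaluated without the original Python class
reloaded = torch.jit.load("models/traced_dummy.pt")
print(reloaded(dummy_input))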
Possible improvements:
- Use CMake instead of the custom Makefile; torch provides CMake recipes, and CMake also supports JNI
- Generate the dummy data in Java and pass it to the C++ function. Also, pass the logits back to Java instead of only the predicted class
- Trace the preprocessing routines
- Load the model once instead of every time inference is performed. Profile inference.
- Link against the libraries in the torch environment instead of additionally downloading libtorch
Tested with:
- openjdk 21.0.2 2024-01-16
- g++ (GCC) 13.2.1 20230801
- Torchscript (libtorch) 2.1.2, CPU version