tum-pbs/PhiML

ΦML


ΦML is a math and neural network library designed for science applications. It enables you to quickly evaluate many network architectures on your data sets, perform linear and non-linear optimization, and write differentiable simulations. ΦML is compatible with Jax, PyTorch, TensorFlow and NumPy, and your code can be executed on any of these backends.

📖 Documentation   •   🔗 API   •   ▶ Videos   •   Introduction   •   Examples

Installation

Installation with pip on Python 3.6 and later:

$ pip install phiml

Install PyTorch, TensorFlow or Jax to enable machine learning capabilities and GPU execution. For optimal GPU performance, you may compile the custom CUDA operators; see the detailed installation instructions.

You can verify your installation by running

$ python3 -c "import phiml; phiml.verify()"

This will check for compatible PyTorch, Jax and TensorFlow installations as well.
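phiml.verify() performs its own, more thorough checks. As a rough complement, the sketch below only tests which of the optional backend packages can be imported in the current environment, using the standard library's importlib.util.find_spec. It does not check versions or GPU support, and the helper name available_backends is hypothetical, not part of ΦML.

```python
import importlib.util

# Hypothetical helper (not part of ΦML): report which backend packages are importable.
# find_spec returns None when a top-level package is not installed; it does not import it.
BACKEND_PACKAGES = ("numpy", "torch", "jax", "tensorflow")

def available_backends(packages=BACKEND_PACKAGES):
    return {name: importlib.util.find_spec(name) is not None for name in packages}

for name, ok in available_backends().items():
    print(f"{name:<10} {'installed' if ok else 'missing'}")
```

Any backend reported as installed is a candidate for math.use(...).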

Why should I use ΦML?

Unique features

Compatibility

  • Writing code that works with PyTorch, Jax, and TensorFlow makes it easier to share code with other people and collaborate.
  • Your published research code will reach a broader audience.
  • When you run into a bug or roadblock with one library, you can simply switch to another.
  • ΦML can efficiently convert tensors between ML libraries on-the-fly, so you can even mix the different ecosystems.

Fewer mistakes

What parts of my code are library-agnostic?

With ΦML, you can write a full neural network training script that can run with Jax, PyTorch and TensorFlow. In particular, ΦML provides abstractions for the following functionality:

However, ΦML does not currently abstract the following use cases:

  • Custom or non-standard network architectures or optimizers require backend-specific code.
  • ΦML abstracts compute devices but does not currently allow mapping operations onto multiple GPUs.
  • ΦML has no data loading module. However, it can convert data, once loaded, to any other backend.
  • Some less-used math functions have not been wrapped yet. If you come across one you need, feel free to open an issue.
  • Higher-order derivatives are not supported for all backends.

ΦML's Tensor class

Many of ΦML's functions can be called on native tensors, i.e. Jax/PyTorch/TensorFlow tensors and NumPy arrays. In these cases, the function maps to the corresponding one from the matching backend.
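To illustrate what mapping to "the corresponding function from the matching backend" can look like, here is a toy sketch of type-based dispatch for a single function. It is not ΦML's implementation: the dispatcher native_sin and its lookup table are hypothetical, and the module names it keys on are assumptions that may differ between library versions. The matching library is imported lazily, so only the libraries you actually pass tensors from need to be installed.

```python
import importlib
import math as pymath

import numpy as np

# Hypothetical table: top-level module of a tensor's type -> (module, function) for sin.
# The module roots are assumptions and may vary between library versions.
_SIN_BY_BACKEND = {
    "numpy": ("numpy", "sin"),
    "torch": ("torch", "sin"),
    "jax": ("jax.numpy", "sin"),
    "jaxlib": ("jax.numpy", "sin"),
    "tensorflow": ("tensorflow", "sin"),
}

def native_sin(x):
    """Toy dispatcher (not ΦML code): call the sin of whichever library x comes from."""
    root = type(x).__module__.split(".")[0]
    if root in _SIN_BY_BACKEND:
        module_name, fn_name = _SIN_BY_BACKEND[root]
        # Import lazily so that only the libraries actually in use must be installed.
        backend = importlib.import_module(module_name)
        return getattr(backend, fn_name)(x)
    return pymath.sin(x)  # fall back to plain Python scalars

print(native_sin(np.array([0.0, np.pi / 2])))
print(native_sin(0.0))
```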

However, we have noticed that code written this way is often hard to read, verbose and error-prone. One main reason is that dimensions are typically referred to by index, and the meaning of that dimension might not be obvious (for examples, see here, here or here).

ΦML includes a Tensor class designed to remedy these shortcomings. A ΦML Tensor wraps one of the native tensors, such as ndarray, torch.Tensor or tf.Tensor, but extends it with two features:

  1. Names: All dimensions are named. A specific dimension can be referenced as tensor.<dimension name>. Elements along dimensions can also be named.
  2. Types: Every dimension is assigned a type flag, such as channel, batch or spatial.

For a full explanation of why these changes make your code not only easier to read but also shorter, see here. Here's the gist:

  • With dimension names, the dimension order becomes irrelevant and you don't need to worry about it.
  • Missing dimensions are automatically added when and where needed.
  • Tensors are automatically transposed to match.
  • Slicing by name is a lot more readable, e.g. image.channels['red'] vs image[:, :, :, 0].
  • Functions will automatically use the right dimensions, e.g. convolutions and FFTs act on spatial dimensions by default.
  • You can have arbitrarily many batch dimensions (or none) and your code will work the same.
  • The number of spatial dimensions controls the dimensionality of not only your data but also your code. Your 2D code also runs in 3D!
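The first three points can be sketched in a few lines of NumPy. The helpers align and add below are hypothetical and only mimic the idea; they are not how ΦML is implemented. Each array carries a tuple of dimension names, and before an elementwise operation both operands are transposed into a common name order and given size-1 axes for the dimensions they lack, after which ordinary NumPy broadcasting does the rest.

```python
import numpy as np

def align(data, names, target):
    """Hypothetical helper: transpose `data` (axes called `names`) into the order of
    `target`, inserting size-1 axes for target dimensions it does not have."""
    unknown = [n for n in names if n not in target]
    if unknown:
        raise ValueError(f"dimensions {unknown} are not in {target}")
    order = sorted(range(len(names)), key=lambda i: target.index(names[i]))
    data = np.transpose(data, order)
    present = [names[i] for i in order]
    shape = [data.shape[present.index(n)] if n in present else 1 for n in target]
    return data.reshape(shape)

def add(a, a_names, b, b_names):
    """Hypothetical helper: elementwise a + b, matching axes by name, not position."""
    target = list(dict.fromkeys(list(a_names) + list(b_names)))  # union, order kept
    return align(a, a_names, target) + align(b, b_names, target), target

a = np.arange(6).reshape(2, 3)       # dims ('x', 'y')
b = np.arange(6).reshape(3, 2) * 10  # dims ('y', 'x'); plain a + b would raise ValueError
c = np.ones(4)                       # dims ('batch',)

ab, ab_names = add(a, ("x", "y"), b, ("y", "x"))
print(ab_names, ab.shape)    # ['x', 'y'] (2, 3)
abc, abc_names = add(ab, ab_names, c, ("batch",))
print(abc_names, abc.shape)  # ['x', 'y', 'batch'] (2, 3, 4)
```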

Examples

The following examples are taken from the examples notebook, where you can also find examples of automatic differentiation, JIT compilation, and more. You can change the math.use(...) statements to any of the supported ML libraries.

Training an MLP

The following script trains an MLP with three hidden layers to learn a noisy 1D sine function in the range [-2, 2].

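```python
from phiml import math, nn
math.use('torch')  # select the backend; any of the other supported ML libraries works as well

# MLP with 1 input, 1 output and three hidden layers of 128 units each
net = nn.mlp(1, 1, layers=[128, 128, 128], activation='ReLU')
optimizer = nn.adam(net, learning_rate=1e-3)

# 128 samples along a batch dimension named 'batch': x in [-2, 2], y = sin(x) plus noise
data_x = math.random_uniform(math.batch(batch=128), low=-2, high=2)
data_y = math.sin(data_x) + math.random_normal(math.batch(batch=128)) * .2

def loss_function(x, y):
    # native_call passes x to the network as a native tensor and wraps the result
    return math.l2_loss(y - math.native_call(net, x))

for i in range(100):
    # one optimization step: gradients of loss_function w.r.t. the weights of net
    loss = nn.update_weights(net, optimizer, loss_function, data_x, data_y)
    print(loss)
```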

We didn't even have to import torch in this example since all calls were routed through ΦML.

Solving a sparse linear system with preconditioners

ΦML supports solving dense as well as sparse linear systems and can build an explicit matrix representation from linear Python functions in order to compute preconditioners. We recommend using ΦML's tensors, but you can pass native tensors to solve_linear() as well. The following example solves the 1D Poisson problem ∇²x = b with b = 1 and zero boundary values, using SciPy's conjugate gradient solver with an incomplete LU (ILU) preconditioner.

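```python
from phiml import math
import numpy as np

def laplace_1d(x):
    # Discrete 1D Laplacian x[i+1] + x[i-1] - 2 x[i]; the padding supplies zero boundary values
    return math.pad(x[1:], (0, 1)) + math.pad(x[:-1], (1, 0)) - 2 * x

b = np.ones((6,))
# SciPy conjugate gradient, relative tolerance 1e-5, initial guess 0, ILU preconditioner
solve = math.Solve('scipy-CG', rel_tol=1e-5, x0=0*b, preconditioner='ilu')
# jit_compile_linear lets ΦML build the sparse matrix of laplace_1d for the solver
sol = math.solve_linear(math.jit_compile_linear(laplace_1d), b, solve)
```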

Wrapping the linear function with math.jit_compile_linear (which can also be used as a decorator) lets ΦML compute the sparse matrix inside solve_linear(). In this example, the matrix is a tridiagonal band matrix. Note that if you JIT-compile the math.solve_linear() call, the sparsity pattern and incomplete LU preconditioner are computed at JIT time. The L and U matrices then enter the computational graph as constants and are not recomputed every time the function is called.
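To see which matrix the solver works with, one can build it explicitly: because laplace_1d is linear, column j of its matrix is laplace_1d applied to the j-th unit vector. The NumPy sketch below does this by brute force (one function call per column) and then solves the same system densely with numpy.linalg.solve for comparison. It is only an illustration, not how ΦML's jit_compile_linear constructs its sparse matrix; np.pad stands in for math.pad and the helper matrix_of is hypothetical.

```python
import numpy as np

def laplace_1d(x):
    # Same stencil as in the example above, with NumPy's pad in place of math.pad:
    # x[i+1] + x[i-1] - 2 x[i], treating values outside the array as 0.
    return np.pad(x[1:], (0, 1)) + np.pad(x[:-1], (1, 0)) - 2 * x

def matrix_of(linear_fn, n):
    """Hypothetical helper: dense matrix of a linear function on R^n, column by column."""
    columns = [linear_fn(np.eye(n)[j]) for j in range(n)]
    return np.stack(columns, axis=1)

A = matrix_of(laplace_1d, 6)
print(A)  # tridiagonal: -2 on the diagonal, 1 on both off-diagonals

b = np.ones(6)
x = np.linalg.solve(A, b)
print(x)  # exact solution: [-3, -5, -6, -6, -5, -3]
```

For a tridiagonal matrix like this one, Gaussian elimination without pivoting yields bidiagonal L and U factors, so no fill-in occurs outside the band and a zero-fill incomplete LU factorization coincides with the exact one.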

Contributions

Contributions are welcome!

If you find a bug, feel free to open a GitHub issue or get in touch with the developers. If you have changes to be merged, check out our style guide before opening a pull request.

📄 Citation

Please use the following citation:

@article{Holl2024,
    doi = {10.21105/joss.06171},
    url = {https://doi.org/10.21105/joss.06171},
    year = {2024},
    publisher = {The Open Journal},
    volume = {9},
    number = {95},
    pages = {6171},
    author = {Philipp Holl and Nils Thuerey},
    title = {Φ-ML: Intuitive Scientific Computing with Dimension Types for Jax, PyTorch, TensorFlow & NumPy},
    journal = {Journal of Open Source Software}
}

Also see the corresponding journal article and software archive of version 1.4.0.

Projects using ΦML

ΦML is used by the simulation framework ΦFlow to integrate differentiable simulations with machine learning.