This is a simple example where we use ginjax to learn scalar filters. We start by specifying which GPU to use and importing packages.

In [4]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=3

import time
import optax
from typing_extensions import Optional, Self

import jax
from jax import random
from jaxtyping import ArrayLike
import equinox as eqx

import ginjax.geometric as geom
import ginjax.ml as ml
import ginjax.models as models

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=3


Now let's define our images X and the filters we are going to use. The images are 2D scalar images of size 64 x 64. The filters are 3 x 3, and we use only the invariant scalar filters. There are 3 of these, and the first one is the identity.

In [5]:
key = random.PRNGKey(time.time_ns())

D = 2
N = 64  # image size
M = 3  # filter image size
num_images = 10

group_actions = geom.make_all_operators(D)
conv_filters = geom.get_invariant_filters(
    Ms=[M], ks=[0], parities=[0], D=D, operators=group_actions
)

key, subkey = random.split(key)
multi_image_X = geom.MultiImage(
    {(0, 0): random.normal(subkey, shape=(num_images, 1) + (N,) * D)}, D
)
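
To build intuition for why there are exactly three invariant 3 x 3 scalar filters, here is a minimal NumPy sketch. It assumes the group is the eight symmetries of the square (four rotations, each with and without a flip), and it averages each of the nine basis filters over that group. The helper `symmetrize` and the max-normalization are illustrative choices, not part of ginjax, and ginjax may order or scale its filters differently.

```python
import numpy as np


def symmetrize(filt):
    """Average a 2D filter over the 8 symmetries of the square."""
    images = []
    for k in range(4):
        rotated = np.rot90(filt, k)
        images.append(rotated)
        images.append(np.fliplr(rotated))
    return np.mean(images, axis=0)


M = 3
invariant_filters = []
for idx in range(M * M):
    # One-hot basis filter with a single nonzero pixel.
    basis = np.zeros(M * M)
    basis[idx] = 1.0
    sym = symmetrize(basis.reshape(M, M))
    sym = sym / sym.max()  # illustrative normalization only
    # Keep the symmetrized filter only if we have not seen it yet.
    if not any(np.allclose(sym, f) for f in invariant_filters):
        invariant_filters.append(sym)

for f in invariant_filters:
    print(f)
```

The nine pixels fall into three orbits under the group: the four corners, the four edge neighbors, and the center. The center-only filter is the identity under convolution.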

Now let us define our target function, and then construct our target images Y. The target function will merely be convolving by the filter at index 1, then convolving by the filter at index 2.

In [6]:
def target_function(
    multi_image: geom.MultiImage, conv_filter_a: jax.Array, conv_filter_b: jax.Array
) -> geom.MultiImage:
    convolved_data = geom.convolve(
        multi_image.D,
        geom.convolve(
            multi_image.D, multi_image[((), 0)], conv_filter_a[None, None], multi_image.is_torus
        ),
        conv_filter_b[None, None],
        multi_image.is_torus,
    )
    return geom.MultiImage({(0, 0): convolved_data}, multi_image.D, multi_image.is_torus)

multi_image_y = target_function(multi_image_X, conv_filters[((), 0)][1], conv_filters[((), 0)][2])
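
The target composes two convolutions. A useful fact when reading the trained weights later is that convolutions with periodic boundaries commute, so the order of the two filters does not matter. The sketch below checks this with plain NumPy. The function `conv_torus` and the `edges` and `corners` filters are hypothetical stand-ins, assuming periodic (torus) boundaries; they are not the ginjax implementation or its actual filters.

```python
import numpy as np


def conv_torus(image, filt):
    """Apply an odd-sized square 2D filter with periodic (torus) boundaries."""
    m = filt.shape[0] // 2
    out = np.zeros_like(image)
    for di in range(-m, m + 1):
        for dj in range(-m, m + 1):
            # Rolling by (-di, -dj) aligns image[i + di, j + dj] with out[i, j].
            out += filt[di + m, dj + m] * np.roll(image, shift=(-di, -dj), axis=(0, 1))
    return out


rng = np.random.default_rng(0)
image = rng.normal(size=(8, 8))

# Hypothetical stand-ins for the three invariant filters.
identity = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
edges = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
corners = np.array([[1.0, 0.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, 1.0]])

a_then_b = conv_torus(conv_torus(image, edges), corners)
b_then_a = conv_torus(conv_torus(image, corners), edges)
print(np.allclose(a_then_b, b_then_a))
```

Because the two orders agree, a model that applies the filters in the opposite order from `target_function` still reproduces the target.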

We now want to define our network and loss function. Machine learning on the GeometricImageNet is done on the MultiImage object, which is a way of collecting batches of multiple channels of images at possibly different tensor orders in a single object.

For this toy example, we keep the task straightforward: the network is two stacked `ml.ConvContract` layers, each of which learns a linear combination of convolutions with our set of three filters. Composed, the two layers can represent any pair of one filter followed by another, drawn with replacement, so the target function corresponds to putting all of the weight on one filter in each layer. Our loss is `ml.smse_loss`, a squared-error loss between the predicted and target images. The ml.train function expects a map_and_loss function that operates on MultiImages.

In [7]:
class SimpleModel(models.MultiImageModule):
    D: int
    net: list[ml.ConvContract]

    def __init__(
        self: Self,
        D: int,
        input_keys: geom.Signature,
        output_keys: geom.Signature,
        conv_filters: geom.MultiImage,
        key: ArrayLike,
    ):
        self.D = D
        key, subkey1, subkey2 = random.split(key, num=3)
        self.net = [
            ml.ConvContract(input_keys, output_keys, conv_filters, False, key=subkey1),
            ml.ConvContract(output_keys, output_keys, conv_filters, False, key=subkey2),
        ]

    def __call__(
        self: Self, x: geom.MultiImage, aux_data: Optional[eqx.nn.State] = None
    ) -> tuple[geom.MultiImage, Optional[eqx.nn.State]]:
        for layer in self.net:
            x = layer(x)

        return x, aux_data


def map_and_loss(
    model: models.MultiImageModule,
    multi_image_x: geom.MultiImage,
    multi_image_y: geom.MultiImage,
    aux_data: Optional[eqx.nn.State] = None,
) -> tuple[jax.Array, Optional[eqx.nn.State]]:
    pred_y, aux_data = jax.vmap(model, in_axes=(0, None), out_axes=(0, None))(
        multi_image_x, aux_data
    )
    return ml.smse_loss(multi_image_y, pred_y), aux_data
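
To see why two stacked linear layers cover every pair of filters, the following NumPy sketch models each layer as a weighted sum of convolutions with three filters. This is an assumption about what `ml.ConvContract` computes for a single scalar channel, made only for illustration. The names `conv_torus`, `layer`, and the filter arrays are hypothetical and not ginjax code.

```python
import numpy as np


def conv_torus(image, filt):
    """Apply an odd-sized square 2D filter with periodic (torus) boundaries."""
    m = filt.shape[0] // 2
    out = np.zeros_like(image)
    for di in range(-m, m + 1):
        for dj in range(-m, m + 1):
            out += filt[di + m, dj + m] * np.roll(image, shift=(-di, -dj), axis=(0, 1))
    return out


# Hypothetical stand-ins for the three invariant filters.
filters = [
    np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]),
    np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]),
    np.array([[1.0, 0.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, 1.0]]),
]


def layer(image, weights):
    """Weighted sum of the image convolved with each filter."""
    return sum(w * conv_torus(image, f) for w, f in zip(weights, filters))


rng = np.random.default_rng(0)
image = rng.normal(size=(8, 8))
w1 = rng.normal(size=3)
w2 = rng.normal(size=3)

# Two stacked layers ...
two_layers = layer(layer(image, w1), w2)

# ... equal a sum over all ordered filter pairs, weighted by w1[i] * w2[j].
pairs = sum(
    w1[i] * w2[j] * conv_torus(conv_torus(image, filters[i]), filters[j])
    for i in range(3)
    for j in range(3)
)
print(np.allclose(two_layers, pairs))
```

Under this model, the target is recovered when one weight per layer is nonzero and the product of those two weights is 1.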

Now we will train our model using the `train` function from `ml.py`. `train` takes the input data as a MultiImage, the target data as a MultiImage, a map-and-loss function that takes the arguments (model, x, y, aux_data), the model, a random key for batching, a stopping condition (here `ml.EpochStop`, which stops after 500 epochs), the batch size, and the desired optax optimizer.
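
The optimizer in the cell below uses `optax.exponential_decay(0.1, transition_steps=1, decay_rate=0.99)`, which multiplies the learning rate by 0.99 after every optimizer step. The plain-Python sketch here reproduces that formula so the schedule is easy to inspect. The function name `decayed_lr` is hypothetical, and reading "step" as "epoch" assumes one optimizer update per epoch, which holds when the batch size equals the number of images, as in this notebook.

```python
def decayed_lr(init_value, decay_rate, step, transition_steps=1):
    """Learning rate after `step` updates under continuous exponential decay."""
    return init_value * decay_rate ** (step / transition_steps)


# Inspect the learning rate at a few points during training.
for step in (0, 100, 250, 500):
    print(step, decayed_lr(0.1, 0.99, step))
```

By step 500 the learning rate has dropped by more than two orders of magnitude, which matches the training loss flattening out in the later epochs.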

In [8]:
key, subkey = random.split(key)
model = SimpleModel(
    D, multi_image_X.get_signature(), multi_image_y.get_signature(), conv_filters, subkey
)

key, subkey = random.split(key)
trained_model, _, _, _ = ml.train(
    multi_image_X,
    multi_image_y,
    map_and_loss,
    model,
    subkey,
    ml.EpochStop(500, verbose=1),
    num_images,
    optimizer=optax.adam(optax.exponential_decay(0.1, transition_steps=1, decay_rate=0.99)),
)
assert isinstance(trained_model, SimpleModel)

print(trained_model.net[0].weights)
print(trained_model.net[1].weights)

Epoch 50 Train: 0.1476211 Epoch time: 0.01771
Epoch 100 Train: 0.0060808 Epoch time: 0.02097
Epoch 150 Train: 0.0001239 Epoch time: 0.01527
Epoch 200 Train: 0.0000097 Epoch time: 0.01195
Epoch 250 Train: 0.0000021 Epoch time: 0.01348
Epoch 300 Train: 0.0000008 Epoch time: 0.01915
Epoch 350 Train: 0.0000004 Epoch time: 0.01971
Epoch 400 Train: 0.0000003 Epoch time: 0.01432
Epoch 450 Train: 0.0000002 Epoch time: 0.01266
Epoch 500 Train: 0.0000002 Epoch time: 0.01407
{((), 0): {((), 0): Array([[[4.1636670e-04, 1.3408582e-05, 1.1102087e+00]]], dtype=float32)}}
{((), 0): {((), 0): Array([[[-2.3628289e-05,  9.0061241e-01, -6.9961043e-06]]], dtype=float32)}}


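We can see that each layer has one dominant weight, about 1.11 on the filter at index 2 in the first layer and about 0.90 on the filter at index 1 in the second layer, and that the rest are close to 0. Hooray!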
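
The two dominant weights are not individually equal to 1 because only their product is pinned down by the target: scaling one layer up and the other down by the same factor leaves the composed map unchanged. A quick check with the values copied from the printed output above (the variable names are only for illustration):

```python
# Dominant weights copied from the printed output of the training cell.
w_layer0 = 1.1102087   # trained_model.net[0], filter index 2
w_layer1 = 0.90061241  # trained_model.net[1], filter index 1

# The composed map scales the target by the product of the two weights.
product = w_layer0 * w_layer1
print(product)
```

The product is within about 0.02% of 1, so the composed network matches the target function even though neither weight is exactly 1.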