In [1]:
import jax
from flax import nnx
import jax.numpy as jnp
import optax
from torch.utils.data import DataLoader

from model import DAGMM
from dataloader import load_kddcup99

In [2]:
def objective_fn(model: "DAGMM", inputs):
    """DAGMM training objective for one batch.

    Sums the three terms of the DAGMM objective (Zong et al., ICLR 2018):

        J = mean_i ||x_i - x_hat_i||_2^2
            + (lambda_1 / N) * E(z)
            + lambda_2 * sum_k sum_j 1 / Sigma_k[j, j]

    Args:
        model: DAGMM module. It is called as ``model(inputs)`` and must return
            ``(gamma, x_hat, z)``. It must also expose ``calc_mixture_stats``,
            ``calc_sample_energy``, ``lambda_1`` and ``lambda_2``.
        inputs: batch of samples with the batch on axis 0. Presumably shaped
            ``(batch, n_features)`` — TODO confirm.

    Returns:
        The objective value ``loss + reg_1 + reg_2``.
    """
    gamma, x_hat, z = model(inputs)
    n = inputs.shape[0]
    phi, mu, covariances = model.calc_mixture_stats(inputs, gamma, z)

    # Reconstruction term: mean over the batch of the squared L2 error per sample.
    loss = jnp.mean(jnp.linalg.norm(inputs - x_hat, ord=2, axis=1) ** 2)

    # Sample-energy term, scaled by lambda_1 / batch size.
    # NOTE(review): assumes calc_sample_energy returns the batch *sum* of
    # energies. If it returns per-sample values, the objective is not a scalar
    # and value_and_grad will reject it — confirm.
    reg_1 = (model.lambda_1 / n) * model.calc_sample_energy(z, phi, mu, covariances)

    # Singularity penalty: sum of the reciprocals of every diagonal entry of
    # every component covariance (covariances is stacked along axis 0, as the
    # diagonal over axes 1 and 2 implies). Any single variance shrinking toward
    # 0 makes this term blow up. The reciprocal of the *summed* diagonal would
    # stay small while the other entries keep the sum large.
    diag = jnp.diagonal(covariances, axis1=1, axis2=2)
    reg_2 = model.lambda_2 * jnp.sum(1.0 / diag)

    return loss + reg_1 + reg_2

@nnx.jit
def train_step(model: DAGMM,
               optimizer: nnx.Optimizer,
               metrics: nnx.MultiMetric,
               inputs: jnp.ndarray):
    """Run one optimisation step on a single batch.

    Steps:
        1. Evaluate ``objective_fn(model, inputs)`` together with its gradient
           with respect to ``model`` (the first argument).
        2. Record the objective in ``metrics`` under the ``loss`` key.
        3. Hand the gradient to ``optimizer``.

    Nothing is returned; progress is observable only through ``metrics``
    and the optimizer/model state.
    """
    value_and_grad_of_objective = nnx.value_and_grad(objective_fn)
    batch_objective, model_grads = value_and_grad_of_objective(model, inputs)
    metrics.update(loss=batch_objective)
    optimizer.update(model_grads)

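# NOTE(review): stale sketch. No `loss_fn` is defined in this notebook, and
# `objective_fn` returns a single objective value (no logits). Update both
# before re-enabling this eval step.
# @nnx.jit
# def eval_step(model: DAGMM, metrics: nnx.MultiMetric, batch):
#     loss, logits = loss_fn(model, batch)
#     metrics.update(loss=loss, logits=logits, labels=batch['label'])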

In [3]:
# Fixed root PRNG key (seed 42). It is handed to DAGMM through nnx.Rngs so
# that model construction is reproducible.
# NOTE(review): the recorded run failed in this cell with a CuDNN version
# mismatch (runtime 9.7.1 vs compiled 9.8.0). That is an environment issue,
# not a code issue.
key = jax.random.PRNGKey(seed=42)
# NOTE(review): n_features=122 is hard-coded. It presumably must equal the
# feature width produced by load_kddcup99 — confirm.
model = DAGMM(rngs=nnx.Rngs(key), n_features=122)

E0426 22:14:38.730867   25242 cuda_dnn.cc:520] Loaded runtime CuDNN library: 9.7.1 but source was compiled with: 9.8.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
E0426 22:14:38.734521   25242 cuda_dnn.cc:520] Loaded runtime CuDNN library: 9.7.1 but source was compiled with: 9.8.0.  CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library.  If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
E0426 22:14:38.755809   25242 cuda_dnn.cc:520] Loaded runtime CuDNN library: 9.7.1 but source was compiled with: 9.8.0.  CuDNN library needs to have matching major version and equal or higher minor ve

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

In [15]:
batch_size = 1024
# NOTE(review): this reassigns `key`, so every re-run of the cell yields a new
# dataloader_key (presumably changing the data split/shuffle). Re-run the
# model-init cell first for reproducible results.
key, dataloader_key = jax.random.split(key, num=2)
# NOTE(review): the recorded run hit RESOURCE_EXHAUSTED (a ~1.33 GiB GPU
# allocation) in this cell. The allocation happens inside load_kddcup99 —
# TODO check whether it places the whole dataset on the device.
dataloader_train, dataloader_test = load_kddcup99(
    dataloader_key,
    batch_size=batch_size,
)

2025-04-21 09:48:30.026219: W external/xla/xla/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.33GiB (rounded to 1432575744)requested by op 
2025-04-21 09:48:30.026546: W external/xla/xla/tsl/framework/bfc_allocator.cc:494] *_________***********____________________*******************************************************____
E0421 09:48:30.026586   24695 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1432575728 bytes.


ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1432575728 bytes.

In [10]:
learning_rate = 1e-4
# Adam over the model's parameters.
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=learning_rate))
# Running average of the training objective, fed by train_step's
# metrics.update(loss=...).
# TODO: precision / recall / f1 averages were sketched here but are not computed yet.
metrics = nnx.MultiMetric(loss=nnx.metrics.Average('loss'))

In [13]:
# List the devices visible to this JAX process. This is a sanity check that the
# CUDA backend is active; the recorded output was [cuda(id=0)].
jax.local_devices()

[cuda(id=0)]

In [None]:
epochs = 200

# Training loop: one pass over dataloader_train per epoch, then print and
# reset the running metric averages.
for epoch in range(epochs):
    print(f'epoch: {epoch + 1}')
    # The second element of each batch (presumably labels) is unused by the
    # unsupervised objective.
    for inputs, _ in dataloader_train:
        # Convert every leaf of the batch to a JAX array before the jitted step.
        inputs = jax.tree.map(jnp.asarray, inputs)
        train_step(model, optimizer, metrics, inputs)
    for metric, value in metrics.compute().items():
        print(f'{metric}: {value}')
    metrics.reset()

ValueError: No active profiler server.

ValueError: No active profiler server.