# A subtle error: using y_train with shape (B,) instead of (B, 1)

Background: A network I was training never learned more than the marginal distribution of the target variable. After some investigation, I found that the cause was an incorrect shape of the target variable (y_true). In the end, ChatGPT-o1 suggested the correct fix. This notebook explains the root cause of the error and how to resolve it.

### The Problem: No singleton last dimension
The error occurred during a regression task where the network was designed to output parameters of a Gaussian distribution (mean and log standard deviation), with the negative log-likelihood (NLL) used as the loss function. The issue arises when the target variable (y_true) is provided as a 1D array (e.g., shape (B,)) rather than having a singleton last dimension (e.g., shape (B, 1)).
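To see what goes wrong, here is a minimal NumPy sketch of the shapes involved. It uses hypothetical predicted means for B = 3, matching the first column of the network output used later. Subtracting a (B,) target from a (B, 1) mean does not pair the entries row by row. Instead, it broadcasts to a (B, B) matrix:

```python
import numpy as np

# Hypothetical predicted means for B = 3 examples, kept as a column like out[:, :1].
mean = np.array([[1.0], [2.0], [3.0]])      # shape (3, 1)

y_col = np.array([[1.11], [2.1], [3.23]])   # shape (3, 1): one target per row
y_flat = np.array([1.11, 2.1, 3.23])        # shape (3,): the buggy version

resid_col = y_col - mean    # (3, 1) - (3, 1) -> (3, 1); row i pairs y_i with mean_i
resid_flat = y_flat - mean  # (3,) - (3, 1) -> (3, 3); entry [i, j] is y_j - mean_i
print(resid_col.shape, resid_flat.shape)  # (3, 1) (3, 3)

# Only the diagonal of the (3, 3) matrix holds the intended (target, prediction) pairs.
print(np.allclose(np.diagonal(resid_flat), resid_col[:, 0]))  # True
```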

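Providing y_true as a 1-D array does not raise an error. Instead, broadcasting the (B, 1) distribution parameters against the (B,) targets silently produces a (B, B) matrix of log-likelihoods that scores every prediction against every target. The NLL loss and its gradients are therefore wrong, and the model trains poorly. The problem is not unique to JAX: it can occur in any framework that follows NumPy-style broadcasting rules, such as PyTorch or TensorFlow.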
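The effect on the loss can be written out explicitly. Assume, as in the code in this notebook, that row $i$ of the network output holds the mean $\mu_i$ and the log standard deviation $s_i$, with $\sigma_i = e^{s_i}$. With targets of shape (B, 1), `.mean()` averages one term per example:

$$
\mathcal{L}_{\text{correct}} = \frac{1}{B}\sum_{i=1}^{B}\left[\log\sigma_i + \frac{(y_i-\mu_i)^2}{2\sigma_i^2} + \tfrac{1}{2}\log 2\pi\right]
$$

With targets of shape (B,), entry $(i, j)$ of the broadcast matrix scores target $y_j$ under prediction $i$, and `.mean()` averages all $B^2$ entries:

$$
\mathcal{L}_{\text{broadcast}} = \frac{1}{B^2}\sum_{i=1}^{B}\sum_{j=1}^{B}\left[\log\sigma_i + \frac{(y_j-\mu_i)^2}{2\sigma_i^2} + \tfrac{1}{2}\log 2\pi\right]
$$

Setting the derivatives of $\mathcal{L}_{\text{broadcast}}$ with respect to $\mu_i$ and $\sigma_i$ to zero gives $\mu_i = \bar y$ and $\sigma_i^2 = \frac{1}{B}\sum_{j}(y_j - \bar y)^2$ for every $i$. The optimum no longer depends on which target belongs to which example. Each prediction is therefore pulled toward the mean and variance of the whole batch, which is exactly a network that only learns the marginal distribution.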

```python
import os
os.environ["KERAS_BACKEND"] = "jax"  # Must be set before importing Keras!

import jax
from jax import random
import jax.numpy as jnp
import numpy as np
from numpyro import distributions as dist

print(f"jax.__version__ {jax.__version__}")
print(f"jax.devices() {jax.devices()}")

# Print the GPU name if JAX sees a GPU
cuda_available = any(device.platform == "gpu" for device in jax.devices())
if cuda_available:
    print(f"jax.devices()[0].device_kind {jax.devices()[0].device_kind}")

import keras
print(f"Keras version: {keras.__version__}")
print(f"Backend: {keras.backend.backend()}")
```

```text
jax.__version__ 0.4.26
jax.devices() [CpuDevice(id=0)]
Keras version: 3.6.0
Backend: jax
```


## The last layer of the network
The network's output has shape (B, 2). The first column is the mean of the Gaussian distribution, and the second column is the log standard deviation, which is exponentiated to get a positive standard deviation. NumPyro's `dist.Normal` turns these parameters into a distribution object.

```python
def output_to_gaussian_distribution(out):
    mean = out[:, :1]    # first column: mean, shape (B, 1)
    log_sd = out[:, 1:]  # second column: log standard deviation, shape (B, 1)
    #scale = 1e-3 + stable_softplus(0.05 * out[:, 1:])  # Apply stable softplus to log scale
    scale = jnp.exp(log_sd)  # exp maps the log std to a positive std
    return dist.Normal(mean, scale)

out = jnp.array([[1.0, 0.1], [2.0, 0.2], [3.0, 0.1]])
output_to_gaussian_distribution(out)
```

```text
<numpyro.distributions.continuous.Normal at 0x17f71f850>
```

### Correct: y_train has shape (B, 1) 👍

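```python
def NLL(y_true, y_pred):
    # Mean negative log-likelihood of the targets under the predicted Gaussians
    return -output_to_gaussian_distribution(y_pred).log_prob(y_true).mean()

y_train = np.array([1.11, 2.1, 3.23]).reshape(-1, 1)  # shape (3, 1)
jax.grad(NLL)(y_train, out)  # gradient w.r.t. y_true (argument 0)
```

```text
Array([[0.03002013],
       [0.02234398],
       [0.06276936]], dtype=float32)
```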

### Wrong: y_train has shape (B,) 👎

```python
y_train = np.array([1.11, 2.1, 3.23])  # shape (3,): no singleton last dimension
jax.grad(NLL)(y_train, out)
```

```text
Array([-0.22821394,  0.02564199,  0.3153968 ], dtype=float32)
```