From ee4b9e045e7440d2cf8d51172546778306578a8f Mon Sep 17 00:00:00 2001 From: Marco Cuturi Date: Tue, 21 Nov 2023 10:12:14 +0100 Subject: [PATCH] Update layers.py --- src/ott/neural/models/layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ott/neural/models/layers.py b/src/ott/neural/models/layers.py index 23052a93a..dfd222c60 100644 --- a/src/ott/neural/models/layers.py +++ b/src/ott/neural/models/layers.py @@ -79,7 +79,7 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: class PosDefPotentials(nn.Module): - """A layer to output (0.5 [A_i A_i^T] (x - b_i)_i potentials. + """A layer to output (0.5 || A_i^T (x - b_i)||^2)_i potentials. Args: use_bias: whether to add a bias to the output.