In [7]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

# Arrange the available devices into a 2x2 grid and wrap them in a named mesh
# whose two axes are called 'x' and 'y'.
devices = mesh_utils.create_device_mesh((2, 2))
mesh = Mesh(devices, ('x', 'y'))


def weight(w_):
    """Sum the entries of the Cayley transform of ``w_``.

    The input is first constrained to the sharding ``NamedSharding(mesh, P())``
    (an empty partition spec, i.e. no array axis is split over the module-level
    ``mesh``). The transform ``(I + w_)^-1 (I - w_)`` is then obtained by solving
    the linear system ``(I + w_) X = I - w_`` rather than forming an inverse.

    Args:
        w_: Matrix to transform. The identity is sized from ``w_.shape[0]``, so
            a square 2-D array is expected.

    Returns:
        The sum of all elements of the transformed matrix.
    """
    replicated = jax.lax.with_sharding_constraint(w_, NamedSharding(mesh, P()))
    identity = jnp.eye(replicated.shape[0], dtype=replicated.dtype)
    lhs = identity + replicated
    rhs = identity - replicated
    # X solves lhs @ X = rhs, i.e. X = (I + w_)^-1 (I - w_).
    cayley = jnp.linalg.solve(lhs, rhs)
    return cayley.sum()

# Fixed-seed 64x64 standard-normal test matrix, followed by one forward
# evaluation of `weight` on it.
w_in = jax.random.normal(
    jax.random.PRNGKey(0),
    shape=(64, 64),
    dtype=jnp.float32,
)
result = weight(w_in)

In [9]:
# Display the value returned by the forward evaluation `weight(w_in)`.
result

Array(-62.208313, dtype=float32)

In [10]:
# Build the function that computes the gradient of `weight` with respect to
# its input matrix `w_`.
grad_weight = jax.grad(weight)

In [11]:
# Evaluate the gradient at the same input used for the forward pass and print
# the resulting array.
print(grad_weight(w_in))

[[-0.6879153  -0.7543204   1.1403724  ...  0.23609817  2.9483542
  -2.1114783 ]
 [-0.6547487  -0.71795225  1.0853913  ...  0.22471511  2.8062043
  -2.009677  ]
 [-0.09802833 -0.10749111  0.16250375 ...  0.03364413  0.4201422
  -0.3008869 ]
 ...
 [ 0.36669135  0.40208846 -0.60787237 ... -0.12585142 -1.5716119
   1.1255177 ]
 [-0.18188933 -0.19944727  0.30152196 ...  0.0624259   0.77956414
  -0.55828875]
 [-0.41243052 -0.45224273  0.68369514 ...  0.14154953  1.7676466
  -1.2659087 ]]
