In [2]:
import dataclasses

import jax
import jax_dataclasses as jdc

In [None]:
@jdc.pytree_dataclass
class Ray:
    """State of a single ray, registered as a JAX pytree via ``jdc.pytree_dataclass``.

    The six annotated attributes are the dataclass fields. They are the
    constructor arguments (in this order) and the pytree leaves that JAX
    transformations such as ``jax.jacobian`` differentiate with respect to.
    """

    # Presumably the transverse position; ``Lens.step`` passes these through
    # unchanged and uses them to bend the ray. TODO confirm units.
    x: float
    y: float
    # Direction components. ``Lens.step`` updates them as ``dx - x / f``,
    # i.e. treats them as paraxial slopes.
    dx: float
    dy: float
    # Axial position; ``Lens.step`` overwrites it with the lens's own ``z``.
    z: float
    # Accumulated path length; ``Lens.step`` copies it through unchanged.
    pathlength: float
    # NOTE: this has no type annotation, so it is a plain class attribute and
    # NOT a dataclass field. It is not a constructor argument, not a pytree
    # leaf, and never appears in an instance's ``__dict__``, so ``mask_ray``
    # never sees it.
    _one =  1.0

@jdc.pytree_dataclass
class Lens:
    """Paraxial thin lens with focal length ``f`` located at axial position ``z``.

    ``step`` applies the thin-lens slope kick ``d -> d - position / f`` in both
    transverse directions. It leaves positions and ``pathlength`` untouched and
    places the ray on the lens plane by setting its ``z`` to ``self.z``.
    """

    f: float
    z: float

    def step(self, ray):
        """Return a new ``Ray`` bent by this lens; ``ray`` itself is not modified.

        Args:
            ray: incoming ``Ray``.

        Returns:
            A ``Ray`` whose ``dx``/``dy`` are reduced by ``x / f``/``y / f``,
            whose ``z`` equals the lens position, and whose other fields are
            copied from ``ray``.
        """
        focal_length = self.f
        # Thin-lens kick: position is unchanged, slope changes by -position / f.
        bent_dx = ray.dx - ray.x / focal_length
        bent_dy = ray.dy - ray.y / focal_length
        return Ray(
            z=self.z,
            x=ray.x,
            y=ray.y,
            dx=bent_dx,
            dy=bent_dy,
            pathlength=ray.pathlength,
        )

def mask_ray(ray, grad_mask):
    """Return a copy of ``ray`` with gradients blocked on selected fields.

    Each dataclass field of ``ray`` whose name maps to a falsy value in
    ``grad_mask`` is wrapped in ``jax.lax.stop_gradient``. Fields missing from
    ``grad_mask``, or mapped to a truthy value, pass through unchanged. Keys in
    ``grad_mask`` that are not fields of ``ray`` (e.g. plain class attributes)
    are ignored.

    ``stop_gradient`` only affects differentiation when this function is
    called *inside* the function being differentiated (e.g. within the
    callable passed to ``jax.grad``/``jax.jacobian``). Applied eagerly to
    concrete values, the result is just a new input and gets differentiated
    normally.

    Args:
        ray: a dataclass instance (e.g. ``Ray``); it is not modified.
        grad_mask: mapping from field name to bool.

    Returns:
        A new instance of the same class as ``ray``.
    """
    masked_fields = {}
    # dataclasses.fields lists only real fields, unlike __dict__, and also
    # works for slotted dataclasses. init=False fields cannot be passed to
    # replace(), so they are skipped.
    for field in dataclasses.fields(ray):
        if not field.init:
            continue
        value = getattr(ray, field.name)
        if not grad_mask.get(field.name, True):
            value = jax.lax.stop_gradient(value)
        masked_fields[field.name] = value
    # replace() preserves the concrete class of `ray` instead of hard-coding Ray.
    return dataclasses.replace(ray, **masked_fields)


# Per-field gradient mask for ``mask_ray``:
# - ``True`` lets gradients flow through that field.
# - ``False`` wraps it in ``jax.lax.stop_gradient``.
# - Fields missing from the mask default to ``True``.
# Here gradients are allowed for x, y, dx and dy, and blocked for z and
# pathlength.
# ``Ray._one`` is deliberately not listed: it is an unannotated class
# attribute, not a dataclass field, so ``mask_ray`` never consults it.
grad_mask = {
    "x": True,
    "y": True,
    "dx": True,
    "dy": True,
    "z": False,
    "pathlength": False,
}

lens = Lens(f=1.0, z=0.0)
ray = Ray(x=0.1, y=0.2, dx=0.3, dy=0.4, z=0.0, pathlength=0.6)


def masked_step(ray_in):
    """Mask gradients of ``ray_in`` per ``grad_mask``, then apply ``lens.step``.

    The masking must happen *inside* the function that ``jax.jacobian``
    differentiates, because ``jax.lax.stop_gradient`` only blocks gradients of
    values that are being traced. Calling ``mask_ray`` eagerly and then
    differentiating ``lens.step`` at the result, as in
    ``jax.jacobian(lens.step)(ray_masked)``, treats the masked ray as a fresh
    input. The mask then has no effect.
    """
    return lens.step(mask_ray(ray_in, grad_mask))


# Eagerly masked copy, kept for inspection only. Its masking does not carry
# over to later differentiation (see ``masked_step``).
ray_masked = mask_ray(ray, grad_mask)

jac = jax.jacobian(masked_step)(ray)
jacjac = jax.jacobian(jax.jacobian(masked_step))(ray)
jac



    

Ray(x=Ray(x=Array(1., dtype=float32, weak_type=True), y=Array(0., dtype=float32, weak_type=True), dx=Array(0., dtype=float32, weak_type=True), dy=Array(0., dtype=float32, weak_type=True), z=Array(0., dtype=float32, weak_type=True), pathlength=Array(0., dtype=float32, weak_type=True)), y=Ray(x=Array(0., dtype=float32, weak_type=True), y=Array(1., dtype=float32, weak_type=True), dx=Array(0., dtype=float32, weak_type=True), dy=Array(0., dtype=float32, weak_type=True), z=Array(0., dtype=float32, weak_type=True), pathlength=Array(0., dtype=float32, weak_type=True)), dx=Ray(x=Array(-1., dtype=float32, weak_type=True), y=Array(0., dtype=float32, weak_type=True), dx=Array(1., dtype=float32, weak_type=True), dy=Array(0., dtype=float32, weak_type=True), z=Array(0., dtype=float32, weak_type=True), pathlength=Array(0., dtype=float32, weak_type=True)), dy=Ray(x=Array(0., dtype=float32, weak_type=True), y=Array(-1., dtype=float32, weak_type=True), dx=Array(0., dtype=float32, weak_type=True), dy=Arra

In [48]:
# Second derivative of the output ray's ``x`` w.r.t. the input ray's ``dx``,
# then the input ray's ``x``. The outermost attribute selects the output field,
# the next one the input field of the inner jacobian, and the last one the
# input field of the outer jacobian. Expected 0: the lens passes ``x`` through
# unchanged.
jacjac.x.dx.x

Array(0., dtype=float32, weak_type=True)