In [10]:
from torch import nn
import torch
class TorchRMSNorm(nn.Module):
    """Root-mean-square layer normalization (RMSNorm) in PyTorch.

    Normalizes the last dimension of the input by its root mean square and
    scales the result by a learnable per-feature ``weight`` (gamma):

        y = weight * x / sqrt(mean(x ** 2, dim=-1) + eps)

    Args:
        dim: Size of the last (feature) dimension; length of ``weight``.
        eps: Small constant added to the mean square for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # (dim,) learnable gain (gamma), initialised to ones so the layer
        # starts out as a pure normalization.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by the RMS of its last dimension.

        ``x.pow(2).mean(-1, keepdim=True)`` reduces the last dimension to size 1
        (e.g. (bs, seq_len, dim) -> (bs, seq_len, 1)) so it broadcasts back
        against ``x``; ``torch.rsqrt`` computes 1 / sqrt(.).
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension and apply ``weight``.

        The normalization is computed in *at least* float32: half-precision
        inputs (float16 / bfloat16) are upcast for stability, while float64
        inputs keep their full precision instead of being truncated to
        float32. The normalized tensor is cast back to ``x.dtype`` before the
        multiplication by ``weight``, so the result dtype follows PyTorch type
        promotion between ``x.dtype`` and ``weight.dtype``.
        """
        # promote_types(float16/bfloat16/float32, float32) -> float32,
        # promote_types(float64, float32) -> float64.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        return self.weight * self._norm(x.to(compute_dtype)).type_as(x)

In [11]:
import jax.numpy as jnp
from flax import nnx
from typing import Optional

class JAXRMSNorm(nnx.Module):
    """RMSNorm implementation using Flax NNX.

    Normalizes the last axis of the input by its root mean square and scales
    the result by a learnable per-feature ``weight`` (gamma):

        y = weight * x / sqrt(mean(x ** 2, axis=-1) + eps)

    Args:
        dim: Size of the last (feature) axis; length of ``weight``.
        eps: Small constant added to the mean square for numerical stability.
        rngs: Unused here; ``weight`` is initialised deterministically to ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6, rngs: Optional[nnx.Rngs] = None):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # (dim,) learnable gain (gamma), initialised to ones.
        self.weight = nnx.Param(jnp.ones(dim), name='weight')

    def _norm(self, x: jnp.ndarray) -> jnp.ndarray:
        """Divide ``x`` by the RMS of its last axis (eps is added inside the sqrt)."""
        rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + self.eps)
        return x / rms

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Normalize ``x`` over its last axis and apply ``weight``.

        The normalization is computed in *at least* float32: half-precision
        inputs are upcast for stability, while float64 inputs (only possible
        with ``jax_enable_x64``) keep full precision instead of being
        truncated to float32. The result is cast back to ``x.dtype`` before
        scaling by ``weight``.
        """
        compute_dtype = jnp.promote_types(x.dtype, jnp.float32)
        return self.weight * self._norm(x.astype(compute_dtype)).astype(x.dtype)

In [13]:
import numpy as np

# Set PyTorch print options
torch.set_printoptions(precision=7)

# Set JAX print options
jnp.set_printoptions(precision=7)

np.random.seed(42)
a = np.random.rand(3, 4)  # float64 sample input of shape (3, 4)

b = torch.from_numpy(a)  # keeps numpy's float64 dtype
# NOTE: under JAX's default config (jax_enable_x64 disabled) this array is
# float32, so the JAX side of the comparison runs in float32.
c = jnp.array(a)

norm = TorchRMSNorm(b.shape[-1])
y1 = norm(b)
print(y1)

n = JAXRMSNorm(c.shape[-1])
y2 = n(c)
print(y2)

# Verify parity programmatically instead of eyeballing the printouts; the
# tolerance is float32-level because the JAX computation may be float32.
np.testing.assert_allclose(y1.detach().numpy(), np.asarray(y2), rtol=1e-5, atol=1e-6)


tensor([[0.5380372, 1.3657274, 1.0515295, 0.8599895],
        [0.3483647, 0.3483108, 0.1296914, 1.9340326],
        [0.8951303, 1.0544027, 0.0306527, 1.4443089]], dtype=torch.float64,
       grad_fn=<MulBackward0>)
[[0.5380372 1.3657274 1.0515295 0.8599895]
 [0.3483647 0.3483108 0.1296914 1.9340326]
 [0.8951303 1.0544027 0.0306527 1.4443089]]
