In [5]:
import torch
import torch.nn as nn
def main():
    """Print a greeting, then report whether PyTorch's MPS backend is usable here.

    Availability is taken from ``torch.backends.mps.is_available()``; exactly two
    lines are printed.
    """
    print("Hello from torch-test!")
    mps_ok = torch.backends.mps.is_available()
    status_line = (
        "Excellent! MPS backend is available."
        if mps_ok
        else "MPS backend is not available: Something went wrong! Are you running this on a Mac with Apple Silicon chip?"
    )
    print(status_line)


if __name__ == "__main__":
    main()

Hello from torch-test!
Excellent! MPS backend is available.


In [6]:
# Fix the RNG so this random example (and everything printed from it) is
# reproducible on Restart & Run All.
torch.manual_seed(0)

# NOTE: `input` shadows the Python builtin; the name is kept because the next cell reads it.
input = torch.randn(2, 3, 4, requires_grad=True)  # shape (2, 3, 4)

print(input)
print(input.mean(-1, keepdim=True))  # mean over the last axis, kept as size 1 -> shape (2, 3, 1)
print(input.mean(1))                 # mean over axis 1 -> shape (2, 4)

# Epsilon added to the mean of squares before rsqrt in the manual RMSNorm computation.
variance_epsilon = 1e-6


tensor([[[ 1.1574, -0.5219, -0.5566,  0.3880],
         [-0.1867,  0.5207, -0.8282, -1.5612],
         [ 0.3143,  0.1646,  0.9471,  0.1488]],

        [[-0.1375,  0.6165,  0.2010, -0.7181],
         [ 1.0336, -0.3422,  1.0562,  1.1867],
         [-0.1696, -0.4101,  1.1450, -0.9568]]], requires_grad=True)
tensor([[[ 0.1167],
         [-0.5138],
         [ 0.3937]],

        [[-0.0095],
         [ 0.7336],
         [-0.0979]]], grad_fn=<MeanBackward1>)
tensor([[ 0.4284,  0.0545, -0.1459, -0.3415],
        [ 0.2422, -0.0453,  0.8007, -0.1627]], grad_fn=<MeanBackward1>)


In [7]:
# Manual RMSNorm: scale each vector along the last axis by 1 / sqrt(mean(x^2) + eps),
# computed in float32.
input_fp32 = input.to(torch.float32)
variance = input_fp32.pow(2).mean(-1, keepdim=True)  # mean of squares over the last axis
hidden_states = input_fp32 * torch.rsqrt(variance + variance_epsilon)

# Reference: PyTorch's built-in RMSNorm over the same last axis. Pass the SAME eps:
# nn.RMSNorm defaults to torch.finfo(dtype).eps (~1.19e-7 for float32), not 1e-6,
# so without it the two results only agree approximately.
layerNorm = nn.RMSNorm(input_fp32.shape[-1], eps=variance_epsilon)
hidden_states1 = layerNorm(input_fp32)
print(hidden_states)
print(hidden_states1)

# nn.RMSNorm's affine weight starts at ones (elementwise_affine=True by default),
# so at initialization the two computations must agree.
torch.testing.assert_close(hidden_states, hidden_states1)

tensor([[[ 1.6080, -0.7250, -0.7733,  0.5391],
         [-0.2016,  0.5624, -0.8945, -1.6861],
         [ 0.6149,  0.3221,  1.8529,  0.2911]],

        [[-0.2814,  1.2617,  0.4114, -1.4696],
         [ 1.0733, -0.3553,  1.0968,  1.2324],
         [-0.2179, -0.5269,  1.4710, -1.2292]]], grad_fn=<MulBackward0>)
tensor([[[ 1.6080, -0.7250, -0.7733,  0.5391],
         [-0.2016,  0.5624, -0.8945, -1.6861],
         [ 0.6149,  0.3221,  1.8529,  0.2911]],

        [[-0.2814,  1.2617,  0.4114, -1.4696],
         [ 1.0733, -0.3553,  1.0968,  1.2324],
         [-0.2179, -0.5269,  1.4710, -1.2292]]], grad_fn=<MulBackward0>)


In [9]:
import torch

def build_rope_cache(seq_len, dim, base=10000, device=None):
    """Precompute the rotary-position-embedding (RoPE) cosine/sine tables.

    Args:
        seq_len: number of positions (rows of the returned tables).
        dim: size of the feature axis to be rotated. It must be even, because
            features are rotated in (even, odd) pairs.
        base: frequency base; pair ``i`` uses angular frequency ``base ** (-2 * i / dim)``.
        device: device for the returned tensors (CPU when falsy / ``None``).

    Returns:
        ``(cos, sin)``: float32 tensors of shape ``(seq_len, dim // 2)`` where entry
        ``[p, i]`` is ``cos`` / ``sin`` of ``p * base ** (-2 * i / dim)``.

    Raises:
        ValueError: if ``dim`` is odd.
    """
    if dim % 2 != 0:
        # An odd dim leaves one feature without a rotation partner, so the table
        # cannot line up with the (even, odd) split applied to the features.
        raise ValueError(f"RoPE requires an even feature dim, got dim={dim}")
    device = device or torch.device('cpu')
    half = dim // 2
    position = torch.arange(seq_len, dtype=torch.float32, device=device)
    pair_idx = torch.arange(half, dtype=torch.float32, device=device)
    # base ** (-i / half) == 1 / base ** (2i / dim).
    inv_freq = base ** (-pair_idx / half)
    # angles[p, i] = p * inv_freq[i]  -> shape (seq_len, half)
    angles = torch.outer(position, inv_freq)
    return torch.cos(angles), torch.sin(angles)

def apply_rope(x, cos, sin):
    """Rotate interleaved (even, odd) feature pairs of ``x`` by per-position angles.

    Features ``2i`` and ``2i + 1`` are treated as the coordinates of a 2-D point and
    rotated by the angle whose cosine/sine are ``cos[..., i]`` / ``sin[..., i]``.

    Args:
        x: tensor whose last axis holds ``D`` features (``D`` even) and whose
            second-to-last axis lines up with the rows of ``cos``/``sin``, e.g.
            ``(seq, D)``, ``(batch, seq, D)`` or ``(batch, heads, seq, D)``.
        cos, sin: tables of shape ``(seq, D // 2)``.

    Returns:
        The rotated tensor, with pairs re-interleaved in their original slots. It has
        ``x``'s shape when ``x``'s second-to-last axis equals the number of table rows.

    Raises:
        ValueError: if ``x.shape[-1]`` is odd or differs from ``2 * cos.shape[-1]``.
    """
    feat = x.shape[-1]
    if feat % 2 != 0 or feat // 2 != cos.shape[-1]:
        raise ValueError(
            f"apply_rope expects x.shape[-1] == 2 * cos.shape[-1]; got {feat} and {cos.shape[-1]}"
        )
    # Broadcasting right-aligns the (seq, D//2) tables against x[..., seq, D//2],
    # so no extra leading axis is introduced (a 2-D x stays 2-D).
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]
    rotated_even = x_even * cos - x_odd * sin
    rotated_odd = x_odd * cos + x_even * sin
    # stack -> (..., D//2, 2); flatten -> (..., D) interleaved as even, odd, even, odd, ...
    return torch.stack([rotated_even, rotated_odd], dim=-1).flatten(-2)

# Toy example: one sequence (batch of 1) of 4 positions with 8 features each.
seq_len, dim = 4, 8
hidden_states = torch.arange(seq_len * dim, dtype=torch.float32).view(1, seq_len, dim)
cos, sin = build_rope_cache(seq_len, dim, device=hidden_states.device)
rope_hidden = apply_rope(hidden_states, cos, sin)
print('original      :', hidden_states)
print('after RoPE    :', rope_hidden)

# Sanity checks: position 0 has angle 0, so it is unchanged; every (even, odd)
# pair is rotated, so its Euclidean norm is preserved at every position.
torch.testing.assert_close(rope_hidden[:, 0], hidden_states[:, 0])
torch.testing.assert_close(
    rope_hidden.reshape(1, seq_len, dim // 2, 2).norm(dim=-1),
    hidden_states.reshape(1, seq_len, dim // 2, 2).norm(dim=-1),
    rtol=1e-5,
    atol=1e-5,
)


original token 0: tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11., 12., 13., 14., 15.],
         [16., 17., 18., 19., 20., 21., 22., 23.],
         [24., 25., 26., 27., 28., 29., 30., 31.]]])
after RoPE    : tensor([[[  0.0000,   1.0000,   2.0000,   3.0000,   4.0000,   5.0000,   6.0000,
            7.0000],
         [ -3.2508,  11.5945,   8.8519,  11.9434,  11.8694,  13.1193,  13.9850,
           15.0140],
         [-22.1164,   7.4743,  13.8665,  22.1973,  19.5760,  21.3958,  21.9540,
           23.0440],
         [-27.2878, -21.3629,  16.8597,  33.4776,  27.1175,  29.8268,  29.9069,
           31.0899]]])
