In [1]:
%load_ext autoreload
%autoreload 2

import torch 
torch.set_printoptions(linewidth=200, threshold=100000)

from ln import MultiHeadLayerNorm
from mlstm_parallel import mlstm_torch_autograd
from mlstm_chunkwise._torch_fw_legacy import mlstm_chunkwise_parallel_legacy
from mlstm_chunkwise.torch_fw import mlstm_chunkwise_parallel_fw_looped, mlstm_chunkwise_parallel_fw_parallel


# Match vLSTM chunkwise parallel to parallel (forward and backward)


In [2]:
# Problem size.
S = 12  # sequence length
B = 1  # batch size
NH = 1  # number of heads
DH = 6  # dimension per head

# Dtypes: DTYPE is used for the shared random inputs, the PT_*_AG_DTYPE values
# for the per-implementation copies that take part in autograd.
DTYPE = torch.float64
PT_P_AG_DTYPE = torch.float64  # parallel baseline
PT_CPL_AG_DTYPE = torch.float64  # chunkwise looped
PT_CPP_AG_DTYPE = torch.float64  # chunkwise parallel

# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs from top to bottom on a machine without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
EPS = 0.0  # passed as the last positional argument to mlstm_torch_autograd

In [3]:
# Fixed seed so every run compares the implementations on the same inputs.
torch.manual_seed(0)
# Three (B, NH, S, DH) matrices, drawn in the order Q, K, V.
matQ, matK, matV = (
    torch.randn((B, NH, S, DH), dtype=DTYPE, device=DEVICE) for _ in range(3)
)
# Two (B, NH, S) vectors, drawn in the order I, F
# (presumably the input / forget gate pre-activations -- TODO confirm).
vecI, vecF = (
    torch.randn((B, NH, S), dtype=DTYPE, device=DEVICE) for _ in range(2)
)

In [4]:
# Offset added to the normalized output before squaring, so that the loss has a larger gradient.
offset = torch.randn((B, NH, S, DH), device=DEVICE, dtype=DTYPE) * 3.0

In [5]:
# Normalization layer shared by both "With GroupNorm" comparisons below.
mh_layernorm = MultiHeadLayerNorm(NH * DH, eps=1e-6)
mh_layernorm = mh_layernorm.to(device=DEVICE, dtype=DTYPE)
# Display the layer parameters.
(mh_layernorm.weight, mh_layernorm.bias)

(Parameter containing:
 tensor([0., 0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float64, requires_grad=True),
 None)

### parallel baseline.

In [6]:
# Independent leaf copies of the shared inputs for the parallel baseline, so that
# the gradients of this run are collected on their own tensors.
(
    matQ_pt_p_ag,
    matK_pt_p_ag,
    matV_pt_p_ag,
    vecI_pt_p_ag,
    vecF_pt_p_ag,
) = (
    src.clone().to(PT_P_AG_DTYPE).detach().requires_grad_(True)
    for src in (matQ, matK, matV, vecI, vecF)
)

In [7]:
# Sanity check before the backward pass: the freshly created leaf tensor must not
# carry a gradient yet, otherwise the comparison further down would be polluted.
assert matQ_pt_p_ag.grad is None, "matQ_pt_p_ag already has a gradient; re-create the inputs first"

In [8]:
# Forward pass of the parallel baseline; this output is the reference for the comparisons below.
matH_pt_p_ag = mlstm_torch_autograd(
    matQ_pt_p_ag,
    matK_pt_p_ag,
    matV_pt_p_ag,
    vecI_pt_p_ag,
    vecF_pt_p_ag,
    EPS,
)

In [9]:
# Scalar loss = sum of all outputs; populates .grad on the baseline inputs.
torch.sum(matH_pt_p_ag).backward()

### parallel baseline. With GroupNorm.

In [10]:
# Second set of leaf copies for the parallel baseline, used for the run that is
# followed by the normalization layer.
(
    matQ_pt_p_gn_ag,
    matK_pt_p_gn_ag,
    matV_pt_p_gn_ag,
    vecI_pt_p_gn_ag,
    vecF_pt_p_gn_ag,
) = (
    src.clone().to(PT_P_AG_DTYPE).detach().requires_grad_(True)
    for src in (matQ, matK, matV, vecI, vecF)
)

In [11]:
# Sanity check before the backward pass: the freshly created leaf tensor must not
# carry a gradient yet, otherwise the comparison further down would be polluted.
assert matQ_pt_p_gn_ag.grad is None, "matQ_pt_p_gn_ag already has a gradient; re-create the inputs first"

In [12]:
# Parallel baseline forward pass ...
matH_pt_p_gn_ag = mlstm_torch_autograd(
    matQ_pt_p_gn_ag,
    matK_pt_p_gn_ag,
    matV_pt_p_gn_ag,
    vecI_pt_p_gn_ag,
    vecF_pt_p_gn_ag,
    EPS,
)
# ... followed by the shared normalization layer.
matH_pt_p_gn_ag_scaled = mh_layernorm(matH_pt_p_gn_ag)

In [13]:
# Loss = sum of squares of the offset, normalized output; backpropagate through norm and baseline.
(matH_pt_p_gn_ag_scaled + offset).pow(2).sum().backward()

### chunkwise legacy version.

In [14]:
# Legacy chunkwise implementation, run on the shared inputs directly.
# The gate vectors are passed with an extra trailing singleton dimension.
matH_cpl = mlstm_chunkwise_parallel_legacy(
    matQ,
    matK,
    matV,
    vecI[..., None],
    vecF[..., None],
    chunk_size=4,
)
matH_cpl

tensor([[[[ 1.4967e+00, -6.5781e-01, -2.8072e-01,  9.9590e-01, -7.2010e-01,  8.2376e-03],
          [ 1.1961e+00, -1.5677e-01,  1.4917e-01,  6.2761e-01, -9.5811e-01,  3.4698e-01],
          [ 3.5374e-01, -1.1173e+00, -1.7381e+00, -3.8857e-01, -1.5804e-01, -5.8152e-01],
          [ 3.5401e-02,  1.6806e-01,  6.9934e-01, -8.9639e-01,  3.4695e-01,  1.4512e-01],
          [-1.0049e+00,  9.2938e-01,  8.7770e-02,  4.1880e-01, -1.1683e+00,  4.4696e-01],
          [ 7.9167e-01,  1.6583e+00,  1.5999e+00,  1.2255e-01,  7.5304e-02, -1.3655e+00],
          [-7.1761e-01,  2.2927e+00,  1.3356e+00,  1.7731e+00,  1.0226e-01, -4.0546e+00],
          [-1.2917e+00, -3.1530e+00, -2.5687e+00,  2.9751e-01, -9.5043e-01,  1.6586e+00],
          [ 2.8938e-03,  1.0294e+00,  1.7216e+00,  3.1265e-01, -1.0807e-01, -1.8735e+00],
          [-3.6011e-02, -2.2485e-01, -2.7282e-01,  6.0402e-02,  2.2168e-02,  1.8846e-01],
          [ 3.6046e-01,  6.9569e-01,  8.1353e-01,  4.3021e-01, -1.8089e-01,  1.6464e-01],
          

In [15]:
# Element-wise deviation: parallel baseline minus legacy chunkwise output.
torch.sub(matH_pt_p_ag, matH_cpl)

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 2.2204e-16, -5.5511e-17,  0.0000e+00,  1.1102e-16, -2.2204e-16,  0.0000e+00],
          [ 5.5511e-17,  0.0000e+00,  2.2204e-16,  0.0000e+00, -5.5511e-17,  0.0000e+00],
          [ 2.0817e-17, -2.7756e-17,  1.1102e-16, -2.2204e-16,  5.5511e-17,  2.7756e-17],
          [ 2.2204e-16,  0.0000e+00,  1.3878e-17,  0.0000e+00,  0.0000e+00, -5.5511e-17],
          [-2.2204e-16, -2.2204e-16, -2.2204e-16,  0.0000e+00, -6.9389e-17,  2.2204e-16],
          [ 2.2204e-16,  8.8818e-16,  0.0000e+00,  4.4409e-16,  1.3878e-16,  0.0000e+00],
          [ 2.2204e-16,  8.8818e-16,  4.4409e-16,  0.0000e+00,  2.2204e-16, -2.2204e-16],
          [ 1.1926e-16,  1.1102e-15,  4.4409e-16,  4.4409e-16,  8.3267e-17, -1.3323e-15],
          [-4.8572e-17, -3.3307e-16, -1.6653e-16,  3.4694e-17, -2.0817e-16,  1.3878e-16],
          [ 1.1102e-16,  2.2204e-16,  1.1102e-16, -1.6653e-16,  1.6653e-16, -8.3267e-17],
          

In [16]:
# Largest absolute deviation between parallel baseline and legacy chunkwise output.
torch.max(torch.abs(matH_pt_p_ag - matH_cpl))

tensor(1.3323e-15, device='cuda:0', dtype=torch.float64, grad_fn=<MaxBackward1>)

### chunkwise looped version.

In [17]:
# Independent leaf copies of the shared inputs for the chunkwise looped version.
(
    matQ_pt_cpl_ag,
    matK_pt_cpl_ag,
    matV_pt_cpl_ag,
    vecI_pt_cpl_ag,
    vecF_pt_cpl_ag,
) = (
    src.clone().to(PT_CPL_AG_DTYPE).detach().requires_grad_(True)
    for src in (matQ, matK, matV, vecI, vecF)
)

In [18]:
# Forward pass of the chunkwise looped implementation.
matH_cplo = mlstm_chunkwise_parallel_fw_looped(
    matQ_pt_cpl_ag,
    matK_pt_cpl_ag,
    matV_pt_cpl_ag,
    vecI_pt_cpl_ag,
    vecF_pt_cpl_ag,
    seq_chunk_size=4,
)

In [19]:
# Same scalar loss as for the baseline (sum of all outputs); populates .grad on the looped inputs.
torch.sum(matH_cplo).backward()

In [20]:
# Element-wise deviation: parallel baseline minus chunkwise looped output.
torch.sub(matH_pt_p_ag, matH_cplo)

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 2.2204e-16, -5.5511e-17,  0.0000e+00,  1.1102e-16, -2.2204e-16,  0.0000e+00],
          [ 5.5511e-17,  0.0000e+00,  2.2204e-16,  0.0000e+00, -5.5511e-17,  0.0000e+00],
          [ 2.0817e-17, -2.7756e-17,  1.1102e-16, -2.2204e-16,  5.5511e-17,  2.7756e-17],
          [ 2.2204e-16,  0.0000e+00,  1.3878e-17,  0.0000e+00,  0.0000e+00, -5.5511e-17],
          [-2.2204e-16, -2.2204e-16, -2.2204e-16,  0.0000e+00, -6.9389e-17,  2.2204e-16],
          [ 2.2204e-16,  8.8818e-16,  0.0000e+00,  4.4409e-16,  1.3878e-16,  0.0000e+00],
          [ 2.2204e-16,  8.8818e-16,  4.4409e-16,  0.0000e+00,  2.2204e-16, -2.2204e-16],
          [ 1.1926e-16,  1.1102e-15,  4.4409e-16,  4.4409e-16,  8.3267e-17, -1.3323e-15],
          [-4.8572e-17, -3.3307e-16, -1.6653e-16,  3.4694e-17, -2.0817e-16,  1.3878e-16],
          [ 1.1102e-16,  2.2204e-16,  1.1102e-16, -1.6653e-16,  1.6653e-16, -8.3267e-17],
          

In [21]:
# Largest absolute deviation between parallel baseline and chunkwise looped output.
torch.max(torch.abs(matH_pt_p_ag - matH_cplo))

tensor(1.3323e-15, device='cuda:0', dtype=torch.float64, grad_fn=<MaxBackward1>)

In [22]:
# Largest absolute gradient deviation per input: parallel baseline vs. chunkwise looped.
print(
    "\n".join(
        f"{tag} grad diff: {(ref.grad - other.grad).abs().max()}"
        for tag, ref, other in (
            ("q", matQ_pt_p_ag, matQ_pt_cpl_ag),
            ("k", matK_pt_p_ag, matK_pt_cpl_ag),
            ("v", matV_pt_p_ag, matV_pt_cpl_ag),
            ("i", vecI_pt_p_ag, vecI_pt_cpl_ag),
            ("f", vecF_pt_p_ag, vecF_pt_cpl_ag),
        )
    )
)

q grad diff: 1.9984014443252818e-15
k grad diff: 3.552713678800501e-15
v grad diff: 8.881784197001252e-16
i grad diff: 1.7763568394002505e-15
f grad diff: 1.1102230246251565e-15


### chunkwise parallel version.

In [23]:
# Independent leaf copies of the shared inputs for the chunkwise parallel version.
(
    matQ_pt_cpp_ag,
    matK_pt_cpp_ag,
    matV_pt_cpp_ag,
    vecI_pt_cpp_ag,
    vecF_pt_cpp_ag,
) = (
    src.clone().to(PT_CPP_AG_DTYPE).detach().requires_grad_(True)
    for src in (matQ, matK, matV, vecI, vecF)
)

In [24]:
# Forward pass of the chunkwise parallel implementation.
matH_cppa = mlstm_chunkwise_parallel_fw_parallel(
    matQ_pt_cpp_ag,
    matK_pt_cpp_ag,
    matV_pt_cpp_ag,
    vecI_pt_cpp_ag,
    vecF_pt_cpp_ag,
    seq_chunk_size=4,
)

In [25]:
# Same scalar loss as for the baseline (sum of all outputs); populates .grad on the chunkwise parallel inputs.
torch.sum(matH_cppa).backward()

In [26]:
# Element-wise deviation: parallel baseline minus chunkwise parallel output.
torch.sub(matH_pt_p_ag, matH_cppa)

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 4.4409e-16, -1.3878e-16,  0.0000e+00,  2.2204e-16, -2.2204e-16,  0.0000e+00],
          [ 5.5511e-17,  0.0000e+00,  2.2204e-16,  0.0000e+00, -2.7756e-17,  1.1102e-16],
          [ 2.0817e-17, -2.7756e-17,  1.1102e-16, -2.2204e-16,  5.5511e-17,  2.7756e-17],
          [ 2.2204e-16,  0.0000e+00,  1.3878e-17,  0.0000e+00,  0.0000e+00, -5.5511e-17],
          [-2.2204e-16, -2.2204e-16, -2.2204e-16,  0.0000e+00, -6.9389e-17,  2.2204e-16],
          [ 1.1102e-16,  8.8818e-16,  0.0000e+00,  4.4409e-16,  9.7145e-17,  0.0000e+00],
          [ 4.4409e-16,  1.7764e-15,  4.4409e-16, -1.6653e-16,  4.4409e-16, -4.4409e-16],
          [ 7.1991e-17,  8.8818e-16,  4.4409e-16,  4.9960e-16, -1.3878e-17, -1.3323e-15],
          [-4.8572e-17, -3.0531e-16, -1.6653e-16,  3.4694e-17, -2.3592e-16,  1.3878e-16],
          [ 1.6653e-16,  3.3307e-16,  2.2204e-16,  0.0000e+00,  1.1102e-16, -2.7756e-17],
          

In [27]:
# Largest absolute deviation between parallel baseline and chunkwise parallel output.
torch.max(torch.abs(matH_pt_p_ag - matH_cppa))

tensor(1.7764e-15, device='cuda:0', dtype=torch.float64, grad_fn=<MaxBackward1>)

In [28]:
# Largest absolute gradient deviation per input: parallel baseline vs. chunkwise parallel.
print(
    "\n".join(
        f"{tag} grad diff: {(ref.grad - other.grad).abs().max()}"
        for tag, ref, other in (
            ("q", matQ_pt_p_ag, matQ_pt_cpp_ag),
            ("k", matK_pt_p_ag, matK_pt_cpp_ag),
            ("v", matV_pt_p_ag, matV_pt_cpp_ag),
            ("i", vecI_pt_p_ag, vecI_pt_cpp_ag),
            ("f", vecF_pt_p_ag, vecF_pt_cpp_ag),
        )
    )
)


q grad diff: 5.329070518200751e-15
k grad diff: 7.105427357601002e-15
v grad diff: 8.881784197001252e-16
i grad diff: 5.329070518200751e-15
f grad diff: 8.881784197001252e-16


### chunkwise parallel version. With GroupNorm.

In [29]:
# Leaf copies of the shared inputs for the chunkwise parallel run that is
# followed by the normalization layer.
(
    matQ_pt_cpp_gn_ag,
    matK_pt_cpp_gn_ag,
    matV_pt_cpp_gn_ag,
    vecI_pt_cpp_gn_ag,
    vecF_pt_cpp_gn_ag,
) = (
    src.clone().to(PT_CPP_AG_DTYPE).detach().requires_grad_(True)
    for src in (matQ, matK, matV, vecI, vecF)
)

In [30]:
# Chunkwise parallel forward pass with the denominator kept in the autograd graph ...
matH_cppa_gn = mlstm_chunkwise_parallel_fw_parallel(
    matQ_pt_cpp_gn_ag,
    matK_pt_cpp_gn_ag,
    matV_pt_cpp_gn_ag,
    vecI_pt_cpp_gn_ag,
    vecF_pt_cpp_gn_ag,
    seq_chunk_size=4,
    detach_denominator=False,
)
# ... followed by the shared normalization layer.
matH_cppa_gn_scaled = mh_layernorm(matH_cppa_gn)

In [31]:
# Same loss as for the normalized baseline: sum of squares of the offset, normalized output.
(matH_cppa_gn_scaled + offset).pow(2).sum().backward()

In [32]:
# Element-wise deviation of the pre-normalization outputs: parallel baseline minus chunkwise parallel.
torch.sub(matH_pt_p_gn_ag, matH_cppa_gn)

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 4.4409e-16, -1.3878e-16,  0.0000e+00,  2.2204e-16, -2.2204e-16,  0.0000e+00],
          [ 5.5511e-17,  0.0000e+00,  2.2204e-16,  0.0000e+00, -2.7756e-17,  1.1102e-16],
          [ 2.0817e-17, -2.7756e-17,  1.1102e-16, -2.2204e-16,  5.5511e-17,  2.7756e-17],
          [ 2.2204e-16,  0.0000e+00,  1.3878e-17,  0.0000e+00,  0.0000e+00, -5.5511e-17],
          [-2.2204e-16, -2.2204e-16, -2.2204e-16,  0.0000e+00, -6.9389e-17,  2.2204e-16],
          [ 1.1102e-16,  8.8818e-16,  0.0000e+00,  4.4409e-16,  9.7145e-17,  0.0000e+00],
          [ 4.4409e-16,  1.7764e-15,  4.4409e-16, -1.6653e-16,  4.4409e-16, -4.4409e-16],
          [ 7.1991e-17,  8.8818e-16,  4.4409e-16,  4.9960e-16, -1.3878e-17, -1.3323e-15],
          [-4.8572e-17, -3.0531e-16, -1.6653e-16,  3.4694e-17, -2.3592e-16,  1.3878e-16],
          [ 1.6653e-16,  3.3307e-16,  2.2204e-16,  0.0000e+00,  1.1102e-16, -2.7756e-17],
          

In [33]:
# Largest absolute deviation of the pre-normalization outputs.
torch.max(torch.abs(matH_pt_p_gn_ag - matH_cppa_gn))

tensor(1.7764e-15, device='cuda:0', dtype=torch.float64, grad_fn=<MaxBackward1>)

In [34]:
# Largest absolute gradient deviation per input, both runs followed by the normalization layer:
# parallel baseline vs. chunkwise parallel (denominator not detached).
print(
    "\n".join(
        f"{tag} grad diff: {(ref.grad - other.grad).abs().max()}"
        for tag, ref, other in (
            ("q", matQ_pt_p_gn_ag, matQ_pt_cpp_gn_ag),
            ("k", matK_pt_p_gn_ag, matK_pt_cpp_gn_ag),
            ("v", matV_pt_p_gn_ag, matV_pt_cpp_gn_ag),
            ("i", vecI_pt_p_gn_ag, vecI_pt_cpp_gn_ag),
            ("f", vecF_pt_p_gn_ag, vecF_pt_cpp_gn_ag),
        )
    )
)

q grad diff: 1.0658141036401503e-14
k grad diff: 8.881784197001252e-15
v grad diff: 8.881784197001252e-15
i grad diff: 1.2212453270876722e-14
f grad diff: 4.884981308350689e-15


### chunkwise parallel version. With GroupNorm. Normalizer detached.

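The forward outputs should still match the parallel baseline. The gradients are compared against the baseline run above, which is called without detaching the normalizer, so they are only expected to match to machine precision once the normalizer is detached in the parallel version too.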

In [35]:
# Fresh leaf copies for the detached-denominator run.
# NOTE: this re-binds the *_pt_cpp_gn_ag names created for the previous
# (non-detached) run, so the earlier tensors and their gradients are no longer
# reachable under these names once this cell has executed.
(
    matQ_pt_cpp_gn_ag,
    matK_pt_cpp_gn_ag,
    matV_pt_cpp_gn_ag,
    vecI_pt_cpp_gn_ag,
    vecF_pt_cpp_gn_ag,
) = (
    src.clone().to(PT_CPP_AG_DTYPE).detach().requires_grad_(True)
    for src in (matQ, matK, matV, vecI, vecF)
)

In [36]:
# Chunkwise parallel forward pass, this time with detach_denominator=True ...
matH_cppa_gn = mlstm_chunkwise_parallel_fw_parallel(
    matQ_pt_cpp_gn_ag,
    matK_pt_cpp_gn_ag,
    matV_pt_cpp_gn_ag,
    vecI_pt_cpp_gn_ag,
    vecF_pt_cpp_gn_ag,
    seq_chunk_size=4,
    detach_denominator=True,
)
# ... followed by the shared normalization layer.
matH_cppa_gn_scaled = mh_layernorm(matH_cppa_gn)

In [37]:
# Same loss as for the normalized baseline: sum of squares of the offset, normalized output.
(matH_cppa_gn_scaled + offset).pow(2).sum().backward()

In [38]:
# Element-wise deviation of the pre-normalization outputs: parallel baseline minus
# chunkwise parallel (detached denominator).
torch.sub(matH_pt_p_gn_ag, matH_cppa_gn)

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 4.4409e-16, -1.3878e-16,  0.0000e+00,  2.2204e-16, -2.2204e-16,  0.0000e+00],
          [ 5.5511e-17,  0.0000e+00,  2.2204e-16,  0.0000e+00, -2.7756e-17,  1.1102e-16],
          [ 2.0817e-17, -2.7756e-17,  1.1102e-16, -2.2204e-16,  5.5511e-17,  2.7756e-17],
          [ 2.2204e-16,  0.0000e+00,  1.3878e-17,  0.0000e+00,  0.0000e+00, -5.5511e-17],
          [-2.2204e-16, -2.2204e-16, -2.2204e-16,  0.0000e+00, -6.9389e-17,  2.2204e-16],
          [ 1.1102e-16,  8.8818e-16,  0.0000e+00,  4.4409e-16,  9.7145e-17,  0.0000e+00],
          [ 4.4409e-16,  1.7764e-15,  4.4409e-16, -1.6653e-16,  4.4409e-16, -4.4409e-16],
          [ 7.1991e-17,  8.8818e-16,  4.4409e-16,  4.9960e-16, -1.3878e-17, -1.3323e-15],
          [-4.8572e-17, -3.0531e-16, -1.6653e-16,  3.4694e-17, -2.3592e-16,  1.3878e-16],
          [ 1.6653e-16,  3.3307e-16,  2.2204e-16,  0.0000e+00,  1.1102e-16, -2.7756e-17],
          

In [39]:
# Largest absolute deviation of the pre-normalization outputs (detached denominator).
torch.max(torch.abs(matH_pt_p_gn_ag - matH_cppa_gn))

tensor(1.7764e-15, device='cuda:0', dtype=torch.float64, grad_fn=<MaxBackward1>)

In [40]:
# Largest absolute gradient deviation per input: parallel baseline vs. chunkwise
# parallel with detach_denominator=True.
# NOTE: the reference gradients come from the parallel baseline, which this notebook
# calls without any detach option, so the gradients are not expected to agree to
# machine precision here. When the denominator is detached in the parallel version
# too, the difference is everywhere < 1e-14 (for torch.float64).
print(
    "\n".join(
        f"{tag} grad diff: {(ref.grad - other.grad).abs().max()}"
        for tag, ref, other in (
            ("q", matQ_pt_p_gn_ag, matQ_pt_cpp_gn_ag),
            ("k", matK_pt_p_gn_ag, matK_pt_cpp_gn_ag),
            ("v", matV_pt_p_gn_ag, matV_pt_cpp_gn_ag),
            ("i", vecI_pt_p_gn_ag, vecI_pt_cpp_gn_ag),
            ("f", vecF_pt_p_gn_ag, vecF_pt_cpp_gn_ag),
        )
    )
)

q grad diff: 8.372333267328833e-05
k grad diff: 9.059612933826067e-05
v grad diff: 8.881784197001252e-15
i grad diff: 0.001453663991701326
f grad diff: 1.2969772461013385e-06
