In [1]:
%load_ext autoreload
import torch

# Helper used by the verification cells below to compare baseline vs. target arrays
# (its return value is bound to `fig` there).
from mlstm_kernels.utils.test.checks import verify_output
# Converts the torch tensors / gradients to NumPy before they are passed to `verify_output`.
from mlstm_kernels.torch.utils import to_numpy
# Loss applied to each kernel output; calling `.backward()` on its result produces the
# gradients that are compared at the end of the notebook.
from tests.torch.losses_tests import loss_layernorm_offset_quadratic

# Allow up to 200 characters per line when printing tensors, so wide rows wrap less.
torch.set_printoptions(linewidth=200)

In [2]:
%autoreload 2
# The two implementations under comparison. The `_native_autograd` variant is used as the
# baseline and the `_native_custbw` variant (presumably a hand-written backward pass —
# TODO confirm) as the target. `%autoreload 2` above re-imports edited modules before each
# cell runs, so changes to the kernel package are picked up without restarting the kernel.
from mlstm_kernels.torch.parallel.native_siging import (
    mlstm_siging_parallel__native_autograd,
    mlstm_siging_parallel__native_custbw,
)

In [3]:
# Reproducibility and problem size.
# Input shapes built below: matQ/matK are (B, NH, S, DHQK), matV is (B, NH, S, DHHV),
# and the gate vectors vecI/vecF are (B, NH, S).
seed = 1
B = 1
NH = 1
S = 128
DHQK = 32
DHHV = 64

# Fall back to CPU when no GPU is present so Restart & Run All still works.
# NOTE(review): assumes the "native" kernels are pure PyTorch and run on CPU — confirm.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32

# Constant offsets added to the randomly drawn vecI / vecF inputs.
vecI_offset = 0.0
vecF_offset = 3.0

In [4]:
torch.manual_seed(seed)
# Source tensors are always drawn in float32, so the random values do not depend on
# `dtype`. Each variant below casts its own copy afterwards.
matQ = torch.randn((B, NH, S, DHQK), dtype=torch.float32, device=device)
matK = torch.randn((B, NH, S, DHQK), dtype=torch.float32, device=device)
matV = torch.randn((B, NH, S, DHHV), dtype=torch.float32, device=device)
vecI = vecI_offset + torch.randn((B, NH, S), dtype=torch.float32, device=device)
vecF = vecF_offset + torch.randn((B, NH, S), dtype=torch.float32, device=device)


def _make_leaf(x: torch.Tensor, leaf_dtype: torch.dtype) -> torch.Tensor:
    """Return an independent copy of ``x`` cast to ``leaf_dtype`` as a fresh grad-tracking leaf.

    The baseline and target runs therefore start from identical values but accumulate
    gradients into separate `.grad` tensors.
    """
    return x.clone().to(dtype=leaf_dtype).detach().requires_grad_(True)


baseline_dtype = dtype
matQ_baseline, matK_baseline, matV_baseline, vecI_baseline, vecF_baseline = (
    _make_leaf(x, baseline_dtype) for x in (matQ, matK, matV, vecI, vecF)
)

target_dtype = dtype
matQ_target, matK_target, matV_target, vecI_target, vecF_target = (
    _make_leaf(x, target_dtype) for x in (matQ, matK, matV, vecI, vecF)
)

In [5]:
# Baseline: the autograd implementation.
# `.backward()` accumulates into `.grad`, so clear the leaves' gradients first.
# Otherwise re-running this cell would double the baseline gradients and the
# comparison against the target would fail spuriously.
for _leaf in (matQ_baseline, matK_baseline, matV_baseline, vecI_baseline, vecF_baseline):
    _leaf.grad = None

matH_bl = mlstm_siging_parallel__native_autograd(
    matQ_baseline,
    matK_baseline,
    matV_baseline,
    vecI_baseline,
    vecF_baseline,
    stable_fgate=True,
    normalize=False,
)
loss_layernorm_offset_quadratic(matH_bl).backward()

In [None]:
# Gate vectors carry one value per (B, NH, S) position. Assert that explicitly, then
# display the row-vector view (B, NH, 1, S) that was being inspected here.
assert vecI_baseline.shape == vecF_baseline.shape == (B, NH, S)
vecI_baseline.unsqueeze(-1).transpose(-2, -1).shape

In [7]:
# Target: the `_native_custbw` implementation, called with the same flags as the baseline.
# Clear gradients first so re-running this cell does not accumulate onto the previous
# run's `.grad` (`.backward()` adds into existing gradients).
for _leaf in (matQ_target, matK_target, matV_target, vecI_target, vecF_target):
    _leaf.grad = None

matH_tgt = mlstm_siging_parallel__native_custbw(
    matQ_target,
    matK_target,
    matV_target,
    vecI_target,
    vecF_target,
    stable_fgate=True,
    normalize=False,
)
loss_layernorm_offset_quadratic(matH_tgt).backward()

In [None]:
# Forward check: compare the baseline output `matH_bl` with the target output `matH_tgt`.
fig = verify_output(
    "matH_stable_fgate",
    *map(to_numpy, (matH_bl, matH_tgt)),
    atol=1e-5,
    rtol=1e-5,
)

In [None]:
# Gradient check for matQ: baseline vs. target `.grad` after the backward passes above.
fig = verify_output(
    "matQ.grad",
    *map(to_numpy, (matQ_baseline.grad, matQ_target.grad)),
    atol=1e-5,
    rtol=1e-5,
)

In [None]:
# Gradient check for matK: baseline vs. target `.grad` after the backward passes above.
fig = verify_output(
    "matK.grad",
    *map(to_numpy, (matK_baseline.grad, matK_target.grad)),
    atol=1e-5,
    rtol=1e-5,
)

In [None]:
# Gradient check for matV: baseline vs. target `.grad` after the backward passes above.
fig = verify_output(
    "matV.grad",
    *map(to_numpy, (matV_baseline.grad, matV_target.grad)),
    atol=1e-5,
    rtol=1e-5,
)

In [None]:
# Gradient check for the vecI gate input: baseline vs. target `.grad`.
fig = verify_output(
    "vecI.grad",
    *map(to_numpy, (vecI_baseline.grad, vecI_target.grad)),
    atol=1e-5,
    rtol=1e-5,
)

In [None]:
# Gradient check for the vecF gate input: baseline vs. target `.grad`.
fig = verify_output(
    "vecF.grad",
    *map(to_numpy, (vecF_baseline.grad, vecF_target.grad)),
    atol=1e-5,
    rtol=1e-5,
)