diff --git a/tests/ops/triton/test_selective_state_update.py b/tests/ops/triton/test_selective_state_update.py index 696b2c77..e81807ae 100644 --- a/tests/ops/triton/test_selective_state_update.py +++ b/tests/ops/triton/test_selective_state_update.py @@ -6,7 +6,7 @@ import torch.nn.functional as F import pytest -from einops import rearrange +from einops import rearrange, repeat from mamba_ssm.ops.triton.selective_state_update import selective_state_update, selective_state_update_ref