
Errors when running the test suite. #183

Closed

CompRhys opened this issue Feb 22, 2024 · 2 comments

@CompRhys

________________________________________________________________________________________________ test_mamba_inner_fn[False-True-128-itype0-wtype0] ________________________________________________________________________________________________

is_variable_B = False, is_variable_C = True, seqlen = 128, itype = torch.float32, wtype = torch.float32

    @pytest.mark.parametrize('wtype', [torch.float32, torch.complex64])
    # @pytest.mark.parametrize('wtype', [torch.complex64])
    # @pytest.mark.parametrize('itype', [torch.float32, torch.float16, torch.bfloat16])
    @pytest.mark.parametrize('itype', [torch.float32])
    # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 372, 512, 784, 1024, 1134, 2048, 4096])
    @pytest.mark.parametrize('seqlen', [128])
    @pytest.mark.parametrize("is_variable_C", [False, True])
    # @pytest.mark.parametrize("is_variable_C", [False])
    @pytest.mark.parametrize("is_variable_B", [False, True])
    # @pytest.mark.parametrize("is_variable_B", [True])
    def test_mamba_inner_fn(is_variable_B, is_variable_C, seqlen, itype, wtype):
        device = 'cuda'
        rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
        if itype == torch.bfloat16:
            rtol, atol = 3e-2, 5e-2
        rtolw, atolw = (1e-3, 1e-3)
        # If we have z, the errors on the weights seem higher
        rtolw = max(rtolw, rtol)
        atolw = max(atolw, atol)
        # set seed
        torch.random.manual_seed(0)
        batch_size = 2
        dim = 768
        dstate = 8
        dt_rank = 48
        is_complex = wtype == torch.complex64
        xz = torch.randn(batch_size, 2 * dim, seqlen, device=device, dtype=itype, requires_grad=True)
        conv1d_weight = torch.randn(dim, 1, 3, device=device, dtype=torch.float32, requires_grad=True)
        conv1d_bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
        x_proj_weight = torch.randn(dt_rank + (bool(is_variable_B) + bool(is_variable_C)) * dstate
                                    * (1 if not is_complex else 2),
                                    dim, device=device, dtype=itype, requires_grad=True)
        delta_proj_weight = torch.randn(dim, dt_rank, device=device, dtype=itype, requires_grad=True)
        out_proj_weight = torch.randn(dim // 2, dim, device=device, dtype=itype, requires_grad=True)
        out_proj_bias = None
        A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)).requires_grad_()
        B = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
             if not is_variable_B else None)
        C = (torch.randn(dim, dstate, device=device, dtype=wtype, requires_grad=True)
             if not is_variable_C else None)
        D = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
        delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)).requires_grad_()
        B_proj_bias = None
        C_proj_bias = None
        xz_ref = xz.detach().clone().requires_grad_()
        conv1d_weight_ref = conv1d_weight.detach().clone().requires_grad_()
        conv1d_bias_ref = conv1d_bias.detach().clone().requires_grad_()
        x_proj_weight_ref = x_proj_weight.detach().clone().requires_grad_()
        delta_proj_weight_ref = delta_proj_weight.detach().clone().requires_grad_()
        out_proj_weight_ref = out_proj_weight.detach().clone().requires_grad_()
        out_proj_bias_ref = (out_proj_bias.detach().clone().requires_grad_()
                             if out_proj_bias is not None else None)
        A_ref = A.detach().clone().requires_grad_()
        B_ref = B.detach().clone().requires_grad_() if B is not None else None
        C_ref = C.detach().clone().requires_grad_() if C is not None else None
        D_ref = D.detach().clone().requires_grad_()
        delta_bias_ref = delta_bias.detach().clone().requires_grad_() if delta_bias is not None else None
        out = mamba_inner_fn(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                             out_proj_weight, out_proj_bias,
                             A, B, C, D, delta_bias=delta_bias, delta_softplus=True)
>       out_ref = mamba_inner_ref(xz_ref, conv1d_weight_ref, conv1d_bias_ref, x_proj_weight_ref,
                                  delta_proj_weight_ref, out_proj_weight_ref, out_proj_bias_ref,
                                  A_ref, B_ref, C_ref, D_ref,
                                  delta_bias=delta_bias_ref, delta_softplus=True)

tests/ops/test_selective_scan.py:210: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
mamba_ssm/ops/selective_scan_interface.py:321: in mamba_inner_ref
    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
/opt/miniconda/lib/python3.10/site-packages/causal_conv1d/causal_conv1d_interface.py:49: in causal_conv1d_fn
    return CausalConv1dFn.apply(x, weight, bias, seq_idx, activation)
/opt/miniconda/lib/python3.10/site-packages/torch/autograd/function.py:553: in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

ctx = <torch.autograd.function.CausalConv1dFnBackward object at 0x7fa222223840>
x = tensor([[[-0.9247, -0.4253, -2.6438,  ..., -0.2128, -0.3315, -0.2023],
         [-1.1451, -0.5715, -0.6510,  ...,  1.3...      [-0.2289, -0.1726,  1.8851,  ..., -0.1589,  0.6690,  1.3431]]],
       device='cuda:0', grad_fn=<SplitBackward0>)
weight = tensor([[ 0.1808, -0.5523,  0.9238],
        [-0.7350,  1.3800,  0.8676],
        [ 0.1297, -0.9406,  0.8109],
       ...],
        [ 0.8140,  1.0932, -0.2314],
        [-0.2205, -0.9232, -1.6818]], device='cuda:0', grad_fn=<ViewBackward0>)
bias = tensor([ 2.5441e+00, -7.1635e-01, -4.9337e-01,  1.2671e-01,  1.0136e-01,
        -4.0353e-01,  9.0226e-01,  8.0993e-01...,  1.3356e+00, -1.1588e+00,
        -2.5133e-01, -1.3636e-01,  2.8971e-01], device='cuda:0',
       requires_grad=True)
seq_idx = 'silu', activation = None

    @staticmethod
    def forward(ctx, x, weight, bias=None, seq_idx=None, activation=None):
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
>       seq_idx = seq_idx.contiguous() if seq_idx is not None else None
E       AttributeError: 'str' object has no attribute 'contiguous'

/opt/miniconda/lib/python3.10/site-packages/causal_conv1d/causal_conv1d_interface.py:19: AttributeError

After installing mamba-ssm and running the test suite to check the install, I got the preceding error. In total there were 8 failures in the test suite, all with the same .contiguous attribute error.
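
The locals shown in the traceback (seq_idx = 'silu', activation = None) point at a positional-argument mismatch: mamba_inner_ref passes "silu" as the fourth positional argument, which the newer causal_conv1d_fn signature binds to seq_idx instead of activation. A minimal sketch of the mismatch, with causal_conv1d_fn_new as a hypothetical stand-in whose signature is paraphrased from the traceback:

    import torch

    # Hypothetical stand-in mirroring the new causal_conv1d_fn signature,
    # where seq_idx was inserted before activation:
    def causal_conv1d_fn_new(x, weight, bias=None, seq_idx=None, activation=None):
        # The line from the traceback where the test fails:
        seq_idx = seq_idx.contiguous() if seq_idx is not None else None
        return x

    x = torch.randn(2, 4, 8)
    weight = torch.randn(4, 3)
    bias = torch.randn(4)

    # Old-style call: "silu" is meant to be activation, but binds to seq_idx.
    causal_conv1d_fn_new(x, weight, bias, "silu")
    # AttributeError: 'str' object has no attribute 'contiguous'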

@BlenderWang9487

I think this is caused by the causal_conv1d interface change: this commit.

There is a PR trying to fix the bug.

But for now, I think installing causal-conv1d<=1.0.2 might fix this?
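
If the fix lands on the mamba-ssm side, it presumably amounts to passing activation by keyword so the argument binds correctly under both the old and new causal_conv1d_fn signatures. A sketch of what the corrected call in mamba_inner_ref might look like, based on the call shown in the traceback (x, conv1d_weight, and conv1d_bias are assumed to be in scope; the actual PR may differ):

    from causal_conv1d import causal_conv1d_fn
    from einops import rearrange

    # Keyword form: "silu" can no longer be captured by the seq_idx parameter.
    x = causal_conv1d_fn(
        x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu"
    )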

@CompRhys (Author)

Closing, as the code has been updated and I believe this issue is no longer relevant.
