thop fails to count the flops and parameters of the custom operator Mamba #303
Counting params should just be `sum(p.numel() for p in model.parameters())`.
Thank you for your valuable input. You are correct that by carefully examining the operators used in the model we can accurately count the parameters, and `sum(p.numel() for p in model.parameters())` is a reliable way to compute the total parameter count, including custom operators like Mamba.

Regarding FLOPs, I noticed that MzeroMiko has provided a method to manually calculate Mamba's FLOPs in the linked issue (#110), which is helpful for understanding the computational complexity of models that use Mamba. However, I have a question about combining the manually calculated FLOPs of Mamba with the FLOPs thop reports for the rest of the model: is it valid to simply add the FLOPs obtained from these two methods?

While adding the manually calculated FLOPs of Mamba to thop's results could serve as an approximation, it might not always give an exact measure of the model's total computational complexity. Is there a way to accurately calculate the total FLOPs of a model that includes Mamba, given the information we have from thop and the manual calculation method provided by MzeroMiko? Thank you once again for your help in understanding the complexity of models with custom operators.
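For what it's worth, thop does expose a `custom_ops` hook, which would keep everything in a single counting pass instead of summing the outputs of two tools. Below is a minimal sketch; it assumes the `Mamba` module exposes `d_inner`/`d_state` attributes and uses the 9·B·L·D·N rule from #110, both of which should be verified:

```python
import torch
from thop import profile
from mamba_ssm.modules.mamba_simple import Mamba

def count_mamba(m, x, y):
    # x is the tuple of inputs to Mamba.forward; x[0] has shape (B, L, d_model).
    B, L, _ = x[0].shape
    # 9*B*L*D*N for the selective scan (rule from mamba issue #110),
    # plus B*D*L for the skip-connection D term.
    flops = 9 * B * L * m.d_inner * m.d_state + B * m.d_inner * L
    m.total_ops += torch.DoubleTensor([int(flops)])

# model and inp stand for your own model and a dummy input of the right shape.
macs, params = profile(model, inputs=(inp,), custom_ops={Mamba: count_mamba})
# Check whether your thop version still recurses into Mamba's inner
# Linear/Conv1d layers, to avoid double counting them.
```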
I am not familiar with thop. A code snippet from thop may help you understand what I am saying below:

```python
import numpy as np

def counter_matmul(input_size, output_size):
    input_size = np.array(input_size)
    output_size = np.array(output_size)
    return np.prod(input_size) * output_size[-1]
```

In fact, there's no such thing in thop as equation-aware einsum counting like this:

```python
if equation == "abc,abd->acd":
    n, c, t = input_shapes[0]
    p = input_shapes[-1][-1]
    flop = n * c * t * p
    return flop
```
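To make the matmul rule concrete (my own example, not from thop): for a batched matmul (B, M, K) @ (B, K, N) -> (B, M, N), `np.prod(input_size) * output_size[-1]` gives B·M·K·N multiply-accumulates, e.g. `counter_matmul([2, 4, 8], [2, 4, 16])` = 2·4·8·16 = 1024.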
So here comes my solution, based on mamba issue #110, which is for VMamba:

```python
from functools import partial
from fvcore.nn import flop_count  # fvcore's operator-level FLOP counter
# print_jit_input_names is a small debug helper from the VMamba repo.

def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu]
    """
    assert not with_complex
    # https://github.com/state-spaces/mamba/issues/110
    flops = 9 * B * L * D * N
    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    return flops


def selective_scan_flop_jit(inputs, outputs, flops_fn=flops_selective_scan_fn):
    print_jit_input_names(inputs)
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    flops = flops_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
    return flops


supported_ops = {
    "prim::PythonOp.SelectiveScanMamba": partial(selective_scan_flop_jit, flops_fn=flops_selective_scan_fn),
}

Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)
```

Note that the code above does not fit every situation in which Mamba modules are used (I may support more cases in the future, but for now it only supports the situation in VMamba). If you are using Mamba and know exactly which situation you are in, you can write another FLOPs-counting function and add it into `supported_ops`.
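For instance, if your trace reports the scan under a different op name (check the `unsupported` dict returned by `flop_count`), you can register the same counter for it. The op string below is a hypothetical example; read the real one off your own trace:

```python
supported_ops["prim::PythonOp.SelectiveScanFn"] = partial(
    selective_scan_flop_jit, flops_fn=flops_selective_scan_fn
)
```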
Do you mean that if I calculate the FLOPs with fvcore and add 9 * B * L * D * N, then I get Mamba's FLOPs? Thank you.
@lth456321 Basically yes. But 9 * B * L * D * N is just for the associative scan in the forward pass; you need to pay attention to extra calculations such as adding D or multiplying by z, and also to the dimensions you actually use in practice.
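As a concrete check (my own arithmetic, using the defaults from `flops_selective_scan_fn` above):

```python
B, L, D, N = 1, 256, 768, 16
scan = 9 * B * L * D * N       # 28,311,552 FLOPs for the associative scan
d_term = B * D * L             # +196,608 if the skip connection D is applied
z_term = B * D * L             # +196,608 if the gate z is applied
print(scan + d_term + z_term)  # 28,704,768
```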
@MzeroMiko Thank you for the informative reply. I just wonder about the `flops_selective_scan_fn` function: this should be the FLOPs for one SSM block, right? So for the whole model I can just use num_layers * flops_selective_scan_fn, right? Much appreciated.
@Sawyer117 Not really. This is only the FLOPs for the SSM itself; for the whole Mamba block, you also need to take into account all the einsum or linear operations that are used to prepare the parameters for the SSM or to build the block structure, as sketched below.
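To illustrate, here is a rough per-block sketch, assuming the layer layout of `mamba_ssm.modules.mamba_simple.Mamba` (`in_proj`, depthwise `conv1d`, `x_proj`, `dt_proj`, `out_proj`); the default sizes and the 2-FLOPs-per-MAC convention for the linears are my assumptions, not from this thread:

```python
import math

def flops_mamba_block(B, L, d_model, d_state=16, d_conv=4, expand=2, dt_rank=None):
    """Rough FLOPs for one Mamba block: projections + conv + selective scan."""
    d_inner = expand * d_model
    dt_rank = dt_rank or math.ceil(d_model / 16)
    flops = 0
    flops += 2 * B * L * d_model * (2 * d_inner)            # in_proj: x -> (x, z)
    flops += 2 * B * L * d_inner * d_conv                   # depthwise conv1d
    flops += 2 * B * L * d_inner * (dt_rank + 2 * d_state)  # x_proj -> (dt, B, C)
    flops += 2 * B * L * dt_rank * d_inner                  # dt_proj
    # The selective scan itself, using the function defined earlier in this thread.
    flops += flops_selective_scan_fn(B=B, L=L, D=d_inner, N=d_state,
                                     with_D=True, with_Z=True)
    flops += 2 * B * L * d_inner * d_model                  # out_proj
    return flops
```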
I am using a custom operator Mamba (from mamba_ssm.modules.mamba_simple) in my project, but when I use the thop library to count the model parameters, it seems that thop does not count the parameters of Mamba.
My model definition is as follows:
When I use thop to count the parameters, I use the following code:
However, the output params do not seem to include the parameters of Mamba. I suspect this might be because Mamba is a custom operator and thop cannot automatically count its parameters.
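For reference, a minimal sketch of the kind of setup I mean (the module sizes and names here are hypothetical, not my actual model):

```python
import torch
import torch.nn as nn
from thop import profile
from mamba_ssm.modules.mamba_simple import Mamba

class TinyMambaNet(nn.Module):
    def __init__(self, d_model=64, num_classes=10):
        super().__init__()
        self.mamba = Mamba(d_model=d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x):           # x: (B, L, d_model)
        return self.head(self.mamba(x)).mean(dim=1)

x = torch.randn(1, 128, 64)
flops, params = profile(TinyMambaNet(), inputs=(x,))
# The reported numbers do not appear to cover Mamba's custom op.
```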
I would like to ask how to make thop correctly count the parameters and FLOPs of Mamba, or whether there is a recommended workaround. I would greatly appreciate any suggestions or solutions! Please let me know if more information is needed.
I hope this issue description is clear and provides enough information.