
The fp32fft option #2

Closed
liuzhuang13 opened this issue Jul 10, 2021 · 5 comments

@liuzhuang13

Hello, thanks for your nice work!

I wonder what the fp32fft option does. In my experiments, the input and output of the fft function are already torch.float32, so I'm not sure why there is an option for converting to fp32. Thanks in advance!
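For reference, a minimal dtype check around the FFT in a plain (non-AMP) forward pass; the rfft2/irfft2 calls and shapes are my own illustration, not necessarily the exact code in the repo:

```python
import torch

# Check what dtypes flow through a 2D FFT over the spatial dims.
x = torch.randn(1, 14, 8, 384)                                # (B, H, W, C)
print(x.dtype)                                                # torch.float32
X = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")
print(X.dtype)                                                # torch.complex64
y = torch.fft.irfft2(X, s=(14, 8), dim=(1, 2), norm="ortho")
print(y.dtype)                                                # torch.float32
```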

@raoyongming
Owner

Hi, thanks for pointing it out.

I just checked the dtype here. It's true that the input and output are already float32.

We kept this option because we tried a linear -> fft -> global filter -> ifft architecture in our early experiments, where the tensors after the linear layer are converted to fp16 since we use Automatic Mixed Precision (AMP) training following the implementation of DeiT. The extra fc did not improve performance, so we didn't use that architecture in our final models.

It seems the LayerNorm layer will convert the input back to fp32, so the fp32 option is redundant here. But it may be useful if the architecture is slightly modified.
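A rough illustration of this AMP behavior (assuming a CUDA device; the layer sizes are arbitrary): under autocast, Linear outputs come out as fp16 while LayerNorm outputs stay fp32, which is why a linear -> fft ordering would feed fp16 tensors into torch.fft unless they are converted back explicitly.

```python
import torch
import torch.nn as nn

x = torch.randn(8, 196, 384, device="cuda")
fc = nn.Linear(384, 384).cuda()
norm = nn.LayerNorm(384).cuda()

with torch.cuda.amp.autocast():
    print(fc(x).dtype)    # torch.float16 (matmul runs in half precision)
    print(norm(x).dtype)  # torch.float32 (LayerNorm runs in full precision)
```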

@liuzhuang13
Author

Thanks for your timely answer! Did you find a performance difference in your original linear -> fft -> global filter -> ifft structure between using fp32fft and not using it?

@raoyongming
Owner

Oops, it seems the fft functions don't support complex fp16 tensors. My logs show fp32 is slightly better than fp16, but the gap may just be run-to-run variance between otherwise identical runs. The inputs should always be converted to fp32 when using the fft functions. I have updated the code to avoid further confusion.
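A sketch of the "always convert to fp32 before the FFT" pattern; the function name, shapes, and weight layout here are illustrative, not the repo's exact code.

```python
import torch

def global_filter_fp32(x, complex_weight):
    """Cast to fp32 around the FFT so AMP-produced fp16 activations never reach torch.fft.

    x: (B, H, W, C) activations, possibly fp16 under AMP.
    complex_weight: (H, W // 2 + 1, C, 2) real-valued learnable filter.
    """
    B, H, W, C = x.shape
    orig_dtype = x.dtype
    x = x.to(torch.float32)                                # fp16 -> fp32 for torch.fft
    weight = torch.view_as_complex(complex_weight.to(torch.float32).contiguous())
    X = torch.fft.rfft2(x, dim=(1, 2), norm="ortho")       # (B, H, W//2+1, C), complex64
    X = X * weight                                         # apply the learnable global filter
    x = torch.fft.irfft2(X, s=(H, W), dim=(1, 2), norm="ortho")
    return x.to(orig_dtype)                                # cast back for the rest of the AMP pipeline
```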

@liuzhuang13
Author

Hello, I tested it and the torch fft functions indeed don't support float16. But given that the fft functions don't accept fp16 inputs, how did you get an fp16 result that is slightly worse than fp32? Thanks
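For what it's worth, a minimal reproduction of that check; whether it errors depends on the PyTorch version and device (newer CUDA builds accept half precision for power-of-two sizes).

```python
import torch

x16 = torch.randn(1, 14, 8, 384, dtype=torch.float16)
try:
    torch.fft.rfft2(x16, dim=(1, 2))
except RuntimeError as err:
    print("fp16 FFT not supported here:", err)
```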

@raoyongming
Owner

raoyongming commented Jul 15, 2021

I suspect there is an inconsistency between my old logs and the actual implementation. I may have run two identical experiments in that case, with the difference coming from randomness during training. Since the above-mentioned model and the fp32fft option were only used in our early experiments, I didn't re-run the experiments to check this result. I believe the correct implementation is to always convert the input to fp32/fp64 before using the fft functions, and I have removed the option from our code. Sorry for the confusion.
