Support for Differentiable CWT #39
Comments
Dear @david-andrew,
@v0lta I might be interested in taking a look, since I would probably be implementing my own version in the meantime anyway. What do you think the solution might entail? E.g., is it just a case of adjusting the function to replace instances of
I'm also not familiar with this library's development process/norms, so it would be good to know if there's anything to be aware of.
@david-andrew I would expect the development to turn out a lot like you are describing it. Regarding the development workflow, the idea is to follow the best practices as established by the Python community. The
Regarding the linting, I have hardcoded version four in cc13922. I suggest we adopt the newest version once the problem described in tholo/pytest-flake8#87 has been sorted out.
Dear @david-andrew,
I think we have no established cost function yet. The cost won't be part of the toolbox, to avoid experimental features. The new cwt function from the
Until the next release is done, the new cwt is installable via the command below.
pip install git+ssh://git@github.com/v0lta/PyTorch-Wavelet-Toolbox.git@diff_cwt
Awesome! I really appreciate you putting this all together! I was able to run the example you mentioned, and it looks like it works great! I presume you haven't looked into GPU/CUDA support yet, since it looks like the scales computations still use NumPy. But this looks like an awesome first step.
I have added a unit test for GPU support with differentiable continuous wavelets here: PyTorch-Wavelet-Toolbox/tests/test_cwt.py Line 111 in 8a00d66
Differentiable cwt computations on graphics cards should now work as expected.
@david-andrew does it work for your use case now?
Gave it a shot and it sort of looks like it works. I was able to get it to run a single training iteration, but then it crashes due to encountering tensors on both cpu and gpu. Here's a minimal version of what I'm doing:
and the error I'm getting:
I haven't had enough time to figure out if it's a problem with the
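An error about tensors living on both the CPU and the GPU usually means that some operand was created on the default device while the input sits on another one. The original snippet and traceback are not reproduced in this thread, so the following is only a sketch under that assumption and not the code that triggered the report. It shows one way to build helper tensors on the input's device; the names scale_grid, signal, and reference are hypothetical.

```python
import torch


def scale_grid(signal: torch.Tensor, n_scales: int) -> torch.Tensor:
    """Hypothetical helper: build scales on the same device and dtype as the signal."""
    return torch.arange(1, n_scales + 1, device=signal.device, dtype=signal.dtype)


# Fall back to the CPU when no GPU is available, so the sketch runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

signal = torch.randn(4, 128, device=device, requires_grad=True)
reference = torch.randn(4, 128)  # created on the CPU by default

# Move every operand to the signal's device before combining them.
reference = reference.to(signal.device)
scales = scale_grid(signal, 8)

# Stand-in for a per-scale computation; shape is (batch, n_scales, time).
weighted = (signal - reference)[:, None, :] / scales[None, :, None]
loss = weighted.pow(2).mean()
loss.backward()
```

Deriving the device from the input, instead of hardcoding "cuda" or relying on the default, keeps the same code working on both CPU and GPU.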
Dear @david-andrew, pip install git+ssh://git@github.com/v0lta/PyTorch-Wavelet-Toolbox.git@wavedec2-improved-errors should do the job.
Awesome, yeah looks to be working great on my side as well!
Hello!
I was wondering if it would be possible to support a differentiable version of the ptwt.continuous_transform.cwt function. I see that internally, the function converts everything to numpy arrays, and so it's not able to handle input tensors with gradients attached to them.
This would be very useful for my case, where I'm using CWT scalograms for computing a similarity score/loss between signals. I understand that several other transforms support gradients, e.g. wavedec and waverec, which work fantastically in ML pipelines I've tested, so I was hoping that such functionality could be extended to the continuous transform as well.
Cheers!
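To illustrate what a differentiable scalogram loss looks like when every step stays in torch, here is a minimal sketch. It is not the implementation used by ptwt: morlet_cwt is a hypothetical toy transform that applies an analytic Morlet wavelet in the frequency domain, and scalogram_loss is an assumed mean-squared distance between power scalograms. Because no tensor is converted to a NumPy array, autograd can propagate gradients back to the input signal.

```python
import math

import torch


def morlet_cwt(signal: torch.Tensor, scales: torch.Tensor, w0: float = 6.0) -> torch.Tensor:
    """Toy FFT-based CWT with an analytic Morlet wavelet (illustration only).

    signal: real tensor of shape (batch, time).
    scales: 1-D tensor of positive scales.
    Returns complex coefficients of shape (batch, n_scales, time).
    """
    n = signal.shape[-1]
    # Angular frequencies of the FFT bins, created on the signal's device.
    omega = 2.0 * math.pi * torch.fft.fftfreq(n, device=signal.device, dtype=signal.dtype)
    scaled = scales.to(signal)[:, None] * omega[None, :]  # (n_scales, time)
    # Morlet wavelet in the frequency domain; only positive frequencies are kept.
    psi_hat = torch.exp(-0.5 * (scaled - w0) ** 2) * (omega > 0)
    spectrum = torch.fft.fft(signal, dim=-1)  # (batch, time)
    # Multiplying spectra corresponds to convolving in the time domain.
    return torch.fft.ifft(spectrum[:, None, :] * psi_hat[None, :, :], dim=-1)


def scalogram_loss(a: torch.Tensor, b: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Assumed loss: mean squared difference between power scalograms."""
    ca, cb = morlet_cwt(a, scales), morlet_cwt(b, scales)
    power_a = ca.real ** 2 + ca.imag ** 2
    power_b = cb.real ** 2 + cb.imag ** 2
    return torch.mean((power_a - power_b) ** 2)


torch.manual_seed(0)
target = torch.sin(torch.linspace(0.0, 20.0, 256))[None, :]
estimate = torch.randn(1, 256, requires_grad=True)
scales = torch.arange(1.0, 9.0)

loss = scalogram_loss(estimate, target, scales)
loss.backward()  # gradients reach `estimate` because nothing left the torch graph
```

Any step that goes through NumPy (for example via detach().numpy()) would cut the graph at that point, which is the limitation described above.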