Support for Differentiable CWT #39

Closed
david-andrew opened this issue Aug 28, 2022 · 13 comments
Labels: enhancement (New feature or request)

@david-andrew

Hello!

I was wondering whether it would be possible to support a differentiable version of the ptwt.continuous_transform.cwt function. Internally, the function converts everything to NumPy arrays, so it cannot handle input tensors that require gradients.

This would be very useful in my case, where I'm using CWT scalograms to compute a similarity score/loss between signals. Several other transforms, e.g. wavedec and waverec, already support gradients and work fantastically in the ML pipelines I've tested, so I was hoping the same functionality could be extended to the continuous transform.

Cheers!

@v0lta
Owner

v0lta commented Aug 29, 2022

Dear @david-andrew,
You are right, there is still a lot of NumPy code in the cwt module. We have a test to check the GPU functionality, but backprop into the NumPy code for the continuous wavelet is impossible. I would love to allow backprop into the cwt. I may be able to make time for this feature this autumn. Of course, contributions are welcome.
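To illustrate why NumPy code blocks backprop, here is a small torch-only sketch (not taken from ptwt): calling .numpy() on a tensor that requires gradients raises an error, and the usual workaround, .detach().numpy(), silently cuts the result out of the autograd graph.

```python
import torch

# A leaf tensor that autograd tracks.
x = torch.linspace(0.0, 1.0, steps=8, requires_grad=True)

# Converting it straight to NumPy is refused by PyTorch.
try:
    x.numpy()
    numpy_call_failed = False
except RuntimeError:
    numpy_call_failed = True

# Detaching first "works", but the NumPy result has no link back to x,
# so no gradient can ever flow through this computation.
y_np = x.detach().numpy() ** 2
y_back = torch.from_numpy(y_np)

# The same computation written with torch ops stays differentiable.
y = (x ** 2).sum()
y.backward()  # x.grad now holds dy/dx = 2 * x

print(numpy_call_failed, y_back.requires_grad)  # True False
```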

v0lta added the enhancement (New feature or request) label on Aug 29, 2022
@david-andrew
Author

david-andrew commented Aug 30, 2022

@v0lta I might be interested in taking a look, since I would probably be implementing my own version in the meantime anyway. What do you think the solution might entail? Is it just a matter of replacing instances of NumPy with torch, making sure the inputs/outputs match, and checking that gradients can flow through, or do you think it would be more involved?

I'm also not familiar with this library's development process/norms, so it would be good to know if there's anything I should be aware of.

@v0lta
Owner

v0lta commented Aug 30, 2022

@david-andrew I would expect the development to turn out a lot like you are describing it.
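As a rough illustration of what a torch-only version involves, the sketch below computes a complex Morlet CWT entirely with torch ops (an FFT-based linear convolution), so autograd can reach both the signal and the scales. The name morlet_cwt and its bandwidth/center parameters are hypothetical, the 1/sqrt(scale) normalization is one simple choice, and the result is not claimed to match ptwt's or pywt's coefficients exactly.

```python
import math

import torch


def morlet_cwt(signal: torch.Tensor, scales: torch.Tensor,
               bandwidth: float = 1.5, center: float = 1.0) -> torch.Tensor:
    """Hypothetical pure-torch complex Morlet CWT, for illustration only.

    signal: real tensor of shape (n,); scales: real tensor of shape (k,).
    Returns a complex tensor of shape (k, n). Every step is a torch op,
    so autograd can differentiate w.r.t. both the signal and the scales.
    """
    n = signal.shape[-1]
    # Kernel offsets -(n-1)..(n-1), created on the signal's device and dtype.
    t = torch.arange(-(n - 1), n, device=signal.device, dtype=signal.dtype)
    u = t.unsqueeze(0) / scales.unsqueeze(1)  # shape (k, 2n - 1)
    # Gaussian envelope with a 1/sqrt(scale) normalization.
    envelope = torch.exp(-u ** 2 / bandwidth) / math.sqrt(math.pi * bandwidth)
    envelope = envelope / torch.sqrt(scales).unsqueeze(1)
    phase = 2 * math.pi * center * u
    kernel = torch.complex(envelope * torch.cos(phase), envelope * torch.sin(phase))
    # The Morlet kernel satisfies conj(psi(-u)) == psi(u), so convolving with
    # psi equals correlating with conj(psi), which is the CWT definition.
    # Zero-padding to 3n makes the FFT product a linear (not circular) convolution.
    nfft = 3 * n
    spec = torch.fft.fft(signal, n=nfft)
    kspec = torch.fft.fft(kernel, n=nfft, dim=-1)
    full = torch.fft.ifft(spec.unsqueeze(0) * kspec, dim=-1)
    # Offset n-1 aligns output index b with signal index b.
    return full[:, n - 1:2 * n - 1]
```

The kernel spans the whole signal, which keeps the sketch exact but makes memory grow with (number of scales) × (signal length); a practical version would truncate the kernel where the Gaussian envelope is negligible.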

Regarding the development workflow, the idea is to follow the best practices established by the Python community. The ptwt library uses nox to automate testing. You can run the test pipeline locally using nox -s test. Similarly, nox -s lint checks the code style using flake8. You do not have to fix everything yourself: nox -s format will fix many formatting issues for you. Finally, you can double-check your typing with nox -s typing. If you add support for a new feature, I would expect tests for it in the tests folder.

@v0lta
Owner

v0lta commented Aug 30, 2022

Regarding the linting, I have pinned flake8 to version 4 in cc13922. I suggest we adopt the newest version once the problem described in tholo/pytest-flake8#87 has been sorted out.

v0lta added this to the v.0.1.4 milestone on Sep 12, 2022
@v0lta
Owner

v0lta commented Oct 10, 2022

Dear @david-andrew ,
I have a first working prototype. See https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/diff_cwt/examples/continuous_signal_analysis/adaptive_cwt.py; the cwt now allows backprop into the wavelets. I am still looking for a reasonable cost function for the example, so that part is still missing. However, I think the diff_cwt branch may already be worth a closer look.
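For readers wondering what "backprop into the wavelets" can look like, here is a hedged sketch of a Morlet filter bank whose bandwidth and center frequency are trainable parameters. LearnableMorlet is a hypothetical name; this is neither the adaptive_cwt.py example nor ptwt's wavelet class. The bandwidth is stored as a log so gradient steps cannot make it negative, and the fixed scales and time grid are registered as buffers so that .to()/.cuda() moves them along with the parameters.

```python
import math

import torch
from torch import nn


class LearnableMorlet(nn.Module):
    """Hypothetical complex Morlet filter bank with trainable bandwidth and
    center frequency; returns scalogram magnitudes (illustration only)."""

    def __init__(self, scales, bandwidth=1.0, center=1.0, half_width=64):
        super().__init__()
        # Trainable wavelet parameters; the log keeps the bandwidth positive.
        self.log_bandwidth = nn.Parameter(torch.tensor(math.log(bandwidth)))
        self.center = nn.Parameter(torch.tensor(float(center)))
        # Fixed tensors are buffers, so .to()/.cuda() moves them too.
        self.register_buffer("scales", torch.as_tensor(scales, dtype=torch.float32))
        self.register_buffer("t", torch.arange(-half_width, half_width + 1, dtype=torch.float32))

    def kernels(self):
        bandwidth = self.log_bandwidth.exp()
        u = self.t.unsqueeze(0) / self.scales.unsqueeze(1)  # (k, 2 * half_width + 1)
        norm = torch.sqrt(math.pi * bandwidth * self.scales.unsqueeze(1))
        envelope = torch.exp(-u ** 2 / bandwidth) / norm
        phase = 2 * math.pi * self.center * u
        return envelope * torch.cos(phase), envelope * torch.sin(phase)

    def forward(self, x):
        """x: real tensor of shape (batch, n) -> magnitudes of shape (batch, k, n)."""
        real, imag = self.kernels()
        x = x.unsqueeze(1)  # (batch, 1, n): one input channel
        pad = self.t.numel() // 2  # keeps the output length equal to n
        # conv1d computes a cross-correlation; for a real input this only flips
        # the sign of the imaginary part, so the magnitude is unaffected.
        re = nn.functional.conv1d(x, real.unsqueeze(1), padding=pad)
        im = nn.functional.conv1d(x, imag.unsqueeze(1), padding=pad)
        return torch.sqrt(re ** 2 + im ** 2 + 1e-12)
```

Comparing the magnitude responses of two such modules (e.g. with an MSE loss) and calling .backward() produces gradients for center and log_bandwidth, which an optimizer can then update.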

v0lta self-assigned this on Oct 10, 2022
@v0lta
Owner

v0lta commented Oct 10, 2022

I don't think we have an established cost function yet. The cost function won't be part of the toolbox, to keep experimental features out of it. The new cwt function from the diff_cwt branch supports gradient descent and should be usable.

@v0lta
Copy link
Owner

v0lta commented Oct 10, 2022

Until the next release is out, the new cwt can be installed with the command below.

pip install git+ssh://git@github.com/v0lta/PyTorch-Wavelet-Toolbox.git@diff_cwt

@david-andrew
Author

Awesome! I really appreciate you putting this all together!

I was able to run the example you mentioned, and it looks like it works great! I presume you haven't looked into GPU/CUDA support yet, since it looks like the scale computations still use NumPy. But this looks like an awesome first step.
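One way to keep the scales off the NumPy path is to build them with torch directly on the target device. The helper below is hypothetical; it reproduces the 2**(1 + i/divs) octave spacing used in the ScalogramLoss example in this thread and checks it against the NumPy equivalent. Whether ptwt itself still needs NumPy for other parts of its scale handling is not addressed here.

```python
import numpy as np
import torch


def octave_scales(octaves: int, divs: int, device=None,
                  dtype=torch.float32) -> torch.Tensor:
    """Hypothetical helper: scales 2**(1 + i / divs) for i = 0 .. octaves*divs - 1,
    created directly on `device` (no NumPy round trip, no host-to-device copy)."""
    i = torch.arange(octaves * divs, device=device, dtype=dtype)
    return 2.0 ** (i / divs + 1)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    scales = octave_scales(9, 24, device=device)
    reference = 2.0 ** (np.arange(9 * 24) / 24 + 1)  # NumPy equivalent
    print(scales.device, np.allclose(scales.cpu().numpy(), reference))
```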

@v0lta
Owner

v0lta commented Oct 11, 2022

I have added a unit test for GPU support with differentiable continuous wavelets here:

```python
sig = sig.cuda()
```

Differentiable cwt computations on graphics cards should now work as expected.
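A device-parametrized test along these lines could cover both CPU and GPU in one place. This is a hypothetical sketch, not the ptwt test suite: torch.fft.rfft stands in for the transform under test, and the CUDA case only runs when a GPU is available.

```python
import torch


def check_backprop_on_device(transform, device: str, n: int = 128) -> None:
    """Hypothetical test helper: run `transform` on `device`, backpropagate a
    scalar loss, and check that outputs and gradients stay on that device."""
    sig = torch.randn(n, device=device, requires_grad=True)
    out = transform(sig)
    assert out.device.type == torch.device(device).type
    out.abs().sum().backward()
    assert sig.grad is not None
    assert sig.grad.device == sig.device
    assert torch.isfinite(sig.grad).all()


def test_backprop_devices():
    # torch.fft.rfft is only a stand-in; a real test would call the cwt under test.
    devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
    for device in devices:
        check_backprop_on_device(torch.fft.rfft, device)
```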

@v0lta
Owner

v0lta commented Oct 11, 2022

@david-andrew does it work for your use case now?

@david-andrew
Author

Gave it a shot, and it sort of looks like it works. I was able to get it to run a single training iteration, but then it crashes because it encounters tensors on both the CPU and the GPU.

Here's a minimal version of what I'm doing:

```python
import torch
from torch import nn
import numpy as np
import ptwt
from ptwt.continuous_transform import _ComplexMorletWavelet


class ScalogramLoss(nn.Module):
    """Complex Continuous Wavelet Transform Loss"""
    def __init__(self, wavelet, octaves=9, octave_divs=24, alpha=0.5, eps=1e-8):
        super().__init__()
        self.scales = 2**(torch.arange(octave_divs*octaves)/octave_divs + 1)
        self.wavelet = wavelet
        self.alpha = alpha
        self.eps = eps

    def forward(self, x, x_hat):
        S, _ = ptwt.continuous_transform.cwt(x, self.scales, self.wavelet)
        S_hat, _ = ptwt.continuous_transform.cwt(x_hat, self.scales, self.wavelet)
        S, S_hat = S.abs(), S_hat.abs()  # take the magnitude of the complex wavelet transform

        linear_term = nn.functional.l1_loss(S, S_hat)
        log_term = nn.functional.l1_loss((S + self.eps).log2(), (S_hat + self.eps).log2())

        return self.alpha * linear_term + (1 - self.alpha) * log_term


def main():
    duration = 10  # seconds
    fs = 44100
    sig = np.sin(np.arange(int(fs*duration))*2*np.pi*440/fs)
    sig = torch.tensor(sig, dtype=torch.float32).cuda()

    # reconstruct the signal starting from random noise
    sig_hat = torch.randn_like(sig, requires_grad=True)  # randn_like already uses sig's device

    wavelet = _ComplexMorletWavelet(name='cmor0.5-0.5').cuda()

    optim = torch.optim.Adam([sig_hat], lr=1e-3)
    loss_fn = ScalogramLoss(wavelet=wavelet)

    iterations = 500
    for it in range(iterations):
        optim.zero_grad()  # clear gradients so they don't accumulate across iterations
        loss = loss_fn(sig, sig_hat)
        loss.backward()
        optim.step()
        print(f'{it}: {loss.item()}')


if __name__ == "__main__":
    main()
```

and the error I'm getting:

```
$ python test2.py
0: 1.4774491040433757
Traceback (most recent call last):
  File "test2.py", line 52, in <module>
    main()
  File "test2.py", line 46, in main
    loss.backward()
  File "/home/david/anaconda3/envs/audio/lib/python3.8/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/david/anaconda3/envs/audio/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```

I haven't had enough time to figure out whether it's a problem with the ptwt implementation or something else on my end, but it's exciting to see it work for at least a single iteration! I'm not sure I'll have a ton of time to focus on this, but I'll keep poking at it when I can.
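When PyTorch reports tensors on two devices, a quick audit of a module's tensors can help narrow things down. The helper off_device_tensors below is a hypothetical debugging aid and doesn't claim to identify the cause in this case. It also flags plain tensor attributes, such as the self.scales tensor in ScalogramLoss above, because .cuda() and .to() only move parameters and registered buffers, not ordinary tensor attributes.

```python
import torch
from torch import nn


def _on_device(t: torch.Tensor, device: torch.device) -> bool:
    # "cuda" (no index) matches any CUDA tensor; "cuda:1" matches only index 1.
    if t.device.type != device.type:
        return False
    return device.index is None or t.device.index == device.index


def off_device_tensors(module: nn.Module, device) -> list:
    """Hypothetical debugging helper: names of parameters, buffers, and plain
    tensor attributes in `module` (including submodules) not on `device`."""
    device = torch.device(device)
    found = [name for name, t in module.named_parameters() if not _on_device(t, device)]
    found += [name for name, t in module.named_buffers() if not _on_device(t, device)]
    # Plain tensor attributes (e.g. `self.scales = torch.arange(...)`) are
    # neither parameters nor buffers, so .cuda()/.to() leaves them in place.
    for mod_name, mod in module.named_modules():
        for attr, value in vars(mod).items():
            if isinstance(value, torch.Tensor) and not _on_device(value, device):
                found.append(f"{mod_name}.{attr}" if mod_name else attr)
    return found
```

Registering such a tensor with register_buffer (e.g. self.register_buffer("scales", ...)) is the standard way to make it follow the module between devices.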

@v0lta
Owner

v0lta commented Oct 12, 2022

Dear @david-andrew ,
I think I found the problem on my end. I was moving the wavelet module to the CPU to run the pywt code for the axis labels. Afterward, it has to be moved back to the GPU. Commit f9dbff3 fixes the problem, and your example now runs on my machine. The commit is in the wavedec2-improved-errors branch, which merges diff_cwt with the changes suggested in issue #40.
For testing

pip install git+ssh://git@github.com/v0lta/PyTorch-Wavelet-Toolbox.git@wavedec2-improved-errors

should do the job.
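For anyone hitting the same pattern in their own code: an alternative to moving a module to the CPU and back is to hand NumPy detached CPU copies of just the values it needs, so the module never leaves its device. The sketch below is hypothetical (ToyWaveletModule and frequency_labels are made-up names) and does not reproduce commit f9dbff3; the labels use the center-frequency-over-scale relation behind pywt's scale2frequency, multiplied by the sampling rate.

```python
import numpy as np
import torch
from torch import nn


class ToyWaveletModule(nn.Module):
    """Hypothetical stand-in for a wavelet module with a trainable center frequency."""

    def __init__(self, center: float = 1.0):
        super().__init__()
        self.center = nn.Parameter(torch.tensor(center))


def frequency_labels(wavelet: nn.Module, scales: torch.Tensor, fs: float) -> np.ndarray:
    """Hypothetical helper: frequency-axis labels in Hz, computed with NumPy.

    Only detached CPU copies of the needed values reach NumPy; the module
    itself is never moved, so its parameters stay on their original device.
    """
    center = wavelet.center.detach().cpu().numpy()
    scales_np = scales.detach().cpu().numpy()
    # Pseudo-frequency: center frequency divided by scale, times sampling rate.
    return center / scales_np * fs


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    wavelet = ToyWaveletModule(0.5).to(device)
    scales = torch.tensor([1.0, 2.0, 4.0], device=device)
    print(frequency_labels(wavelet, scales, fs=100.0))
    print(wavelet.center.device)  # still on `device`
```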

@david-andrew
Author

Awesome, yeah looks to be working great on my side as well!

v0lta closed this as completed on Oct 12, 2022