Error in wavedec2 #40
Dear @abeyang00,

```python
import torch, pywt, ptwt

tmp = torch.randn(3, 10, 10)
wavelet = pywt.Wavelet("haar")
coeff2d = ptwt.wavedec2(tmp, wavelet, level=1, mode="zero")
```

runs as expected. It is not apparent how a 2d wavelet transform should deal with colour channels, so the toolbox does not expect any. If you want to transform the colour channels, use `wavedec3`:

```python
import torch, pywt, ptwt

tmp = torch.randn(1, 3, 10, 10)
wavelet = pywt.Wavelet("haar")
coeff3d = ptwt.wavedec3(tmp, wavelet, level=1, mode="zero")
```

This transforms the colour channel and does not exit with an error. Alternatively, move the colour channels into the batch dimension:

```python
import torch, pywt, ptwt
import numpy as np
from scipy.misc import face  # note: newer SciPy moved this to scipy.datasets.face

# Get a colour image and move the channels into the batch dimension.
face = np.transpose(face()[256:512, 256:512], [2, 0, 1]).astype(np.float64)
# Compute the transform.
coeff2d = ptwt.wavedec2(torch.from_numpy(face), pywt.Wavelet("haar"), level=1, mode="zero")
# Invert the transform.
rec = ptwt.waverec2(coeff2d, pywt.Wavelet("haar"))
# Move the colour channels back.
face_rec = rec.squeeze().permute(1, 2, 0)
```

To process batches of multiple colour images at a time, you can concatenate them along the batch dimension.
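The channel-folding the comment describes is just a pair of reshapes. A NumPy sketch of the shape logic (no ptwt required; the array names are illustrative):

```python
import numpy as np

# Hypothetical batch of 5 RGB images, layout (batch, channel, height, width).
imgs = np.random.randn(5, 3, 64, 64)

# Fold the colour channels into the batch dimension: (5, 3, H, W) -> (15, H, W).
folded = imgs.reshape(-1, 64, 64)

# ... a batched 2d transform such as ptwt.wavedec2 would run on `folded` here ...

# Unfold to recover the original layout: (15, H, W) -> (5, 3, H, W).
unfolded = folded.reshape(5, 3, 64, 64)

# The round trip is exact -- no data is moved between images or channels.
assert np.array_equal(imgs, unfolded)
```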
On second thought, the toolbox should have caught the problem earlier. I will add a more informative error message in the next release.
@v0lta Thank you for the quick answer. However, I'm still confused, as I want to incorporate `wavedec2` inside the convolutional layers of a CNN, so concatenating along the batch dimension seems like a naive approach. Might this corrupt the data? And when I try to use `wavedec3` with the Haar wavelet, it outputs coefficients I do not want (`'aad'`, `'ada'`, etc.). Best,
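Concatenating channels along the batch dimension does not corrupt anything: a batched 2d transform acts on each (H, W) slice independently. A NumPy sketch of that point, using a stand-in per-image operation (2x2 average pooling here, purely for illustration, not a wavelet transform):

```python
import numpy as np

def per_image_op(batch):
    # Stand-in for a batched 2d transform: 2x2 average pooling,
    # applied independently to each (H, W) slice in the batch.
    b, h, w = batch.shape
    return batch.reshape(b, h // 2, 2, w // 2, 2).mean(axis=(2, 4))

img = np.random.randn(3, 8, 8)  # one RGB image, channels first

# Folding the channels into the batch dimension and transforming...
folded_out = per_image_op(img.reshape(-1, 8, 8))

# ...gives exactly the same result as transforming each channel on its own.
for c in range(3):
    alone = per_image_op(img[c][None])[0]
    assert np.allclose(folded_out[c], alone)
```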
Dear @abeyang00,

```python
import torch, pywt, ptwt
import numpy as np
from scipy.misc import face

# Get a colour image and move the channels into the batch dimension.
face = face()[256:512, 256:512]
face_tp = np.transpose(face, [2, 0, 1]).astype(np.float64)
# Compute the transform.
coeff2d = ptwt.wavedec2(torch.from_numpy(face_tp), pywt.Wavelet("haar"),
                        level=1, mode="zero")
# Invert the transform.
rec = ptwt.waverec2(coeff2d, pywt.Wavelet("haar"))
# Move the colour channels back.
face_rec = rec.squeeze().permute(1, 2, 0)
np.allclose(face, face_rec)
```

All the best,
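For intuition about why the round trip above is exact, a single-level 2d Haar analysis/synthesis pair fits in a few lines of NumPy (orthonormal normalisation; the subband ordering here is illustrative and need not match pywt's conventions exactly):

```python
import numpy as np

def haar2d(x):
    # One-level 2d Haar analysis of an even-sized image.
    a = x[0::2, 0::2]; b = x[0::2, 1::2]
    c = x[1::2, 0::2]; d = x[1::2, 1::2]
    ll = (a + b + c + d) / 2   # approximation
    lh = (a + b - c - d) / 2   # horizontal detail
    hl = (a - b + c - d) / 2   # vertical detail
    hh = (a - b - c + d) / 2   # diagonal detail
    return ll, lh, hl, hh

def ihaar2d(ll, lh, hl, hh):
    # Exact inverse of haar2d.
    h, w = ll.shape
    x = np.empty((2 * h, 2 * w))
    x[0::2, 0::2] = (ll + lh + hl + hh) / 2
    x[0::2, 1::2] = (ll + lh - hl - hh) / 2
    x[1::2, 0::2] = (ll - lh + hl - hh) / 2
    x[1::2, 1::2] = (ll - lh - hl + hh) / 2
    return x

img = np.random.randn(8, 8)
assert np.allclose(ihaar2d(*haar2d(img)), img)  # perfect reconstruction
```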
cd94022 implements better error messages.
The 0.1.6 version of the code takes care of this for you (https://github.com//pull/63).

Your initial example now produces the following result:

[screenshot of the new error message not preserved]
Hello,

Thank you for the great work. I'm trying to run `wavedec2` on an image tensor (e.g. of size 1x3x10x10). When I run

```python
tmp = torch.randn(1, 3, 10, 10)
wavelet = pywt.Wavelet("haar")
coeff2d = ptwt.wavedec2(tmp, wavelet, level=1, mode="zero")
```
I get this error:

```
RuntimeError: Given groups=1, weight of size [4, 1, 2, 2], expected input[1, 3, 10, 10] to have 1 channels, but got 3 channels instead
```

It seems to have a problem with input tensors whose channel dimension is bigger than 1.
It would be greatly appreciated if you could help me out on this.
Best,
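The `weight of size [4, 1, 2, 2]` in the traceback hints at what happens under the hood: a single-level 2d Haar transform can be expressed as a stride-2 convolution with four 2x2 filters over one input channel, so a 3-channel input cannot match the filter bank. A NumPy sketch of that idea (an illustration of the shapes involved, not ptwt's actual implementation):

```python
import numpy as np

# The four 2x2 Haar analysis filters (approximation + three details),
# stacked like the conv weight in the error: shape (4, 1, 2, 2),
# i.e. four output subbands, ONE input channel.
w = 0.5 * np.array([
    [[ 1,  1], [ 1,  1]],   # LL: approximation
    [[ 1,  1], [-1, -1]],   # LH: horizontal detail
    [[ 1, -1], [ 1, -1]],   # HL: vertical detail
    [[ 1, -1], [-1,  1]],   # HH: diagonal detail
])[:, None, :, :]           # -> (4, 1, 2, 2)

x = np.random.randn(1, 1, 10, 10)   # (batch, channel, H, W); channel must be 1

# A stride-2 convolution with these filters, written out with einsum:
# gather the non-overlapping 2x2 patches, then contract each against the bank.
patches = np.stack(
    [x[:, :, 0::2, 0::2], x[:, :, 0::2, 1::2],
     x[:, :, 1::2, 0::2], x[:, :, 1::2, 1::2]],
    axis=-1,
).reshape(1, 1, 5, 5, 2, 2)
subbands = np.einsum('bchwij,kcij->bkhw', patches, w)

# One single-channel 10x10 image yields four 5x5 subbands.
assert subbands.shape == (1, 4, 5, 5)
```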