
Error in wavedec2 #40

Closed
abeyang00 opened this issue Sep 12, 2022 · 7 comments · Fixed by #41
Labels: enhancement (New feature or request), question (Further information is requested)


@abeyang00

Hello,

Thank you for the great work.

I'm trying to apply wavedec2 to an image tensor (e.g. of size 1x3x10x10).

When I run

import torch, pywt, ptwt
tmp = torch.randn(1, 3, 10, 10)
wavelet = pywt.Wavelet("haar")
coeff2d = ptwt.wavedec2(tmp, wavelet, level=1, mode="zero")

I get the following error:

RuntimeError: Given groups=1, weight of size [4, 1, 2, 2], expected input[1, 3, 10, 10] to have 1 channels, but got 3 channels instead

It seems to have a problem with input tensors whose channel dimension is larger than 1.

I would greatly appreciate it if you could help me out with this.

Best,
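The weight shape in the error message, [4, 1, 2, 2], suggests how the 2D transform is set up: four 2x2 filters applied as a convolution that expects exactly one input channel. Below is a minimal NumPy sketch of one such level, assuming the four filters are outer products of the Haar low- and high-pass filters. The filter order, the sign convention, and the helper name `haar_level_2d` are illustrative and not taken from ptwt.

```python
import numpy as np

# Orthonormal Haar analysis filters: low-pass and high-pass.
lo = np.array([1.0, 1.0]) / np.sqrt(2.0)
hi = np.array([1.0, -1.0]) / np.sqrt(2.0)

# Four separable 2D filters (outer products), stacked like a conv weight of
# shape (out_channels=4, in_channels=1, 2, 2), matching the error message.
weight = np.stack([np.outer(a, b) for a in (lo, hi) for b in (lo, hi)])[:, None]


def haar_level_2d(x: np.ndarray) -> np.ndarray:
    """One Haar level as a stride-2 correlation with `weight` (hypothetical helper).

    x must have shape (batch, 1, H, W) with even H and W; the result has
    shape (batch, 4, H // 2, W // 2).
    """
    batch, channels, h, w = x.shape
    if channels != weight.shape[1]:
        raise ValueError(f"expected {weight.shape[1]} input channel(s), got {channels}")
    # Cut each image into non-overlapping 2x2 blocks: (batch, H/2, 2, W/2, 2).
    blocks = x[:, 0].reshape(batch, h // 2, 2, w // 2, 2)
    # Correlate every block with each of the four 2x2 filters.
    return np.einsum("bipjq,kpq->bkij", blocks, weight[:, 0])


x = np.random.default_rng(0).standard_normal((2, 1, 10, 10))
print(haar_level_2d(x).shape)  # (2, 4, 5, 5)
# A (1, 3, 10, 10) input raises a ValueError, mirroring the error above.
```

Since the filter bank has a single input channel, a 3-channel input cannot be fed in directly; the channels have to go somewhere else first.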

@v0lta
Owner

v0lta commented Sep 12, 2022

Dear @abeyang00
Please take a look at the documentation for wavedec2: https://pytorch-wavelet-toolbox.readthedocs.io/en/latest/ptwt.html#ptwt.conv_transform_2.wavedec2 .
wavedec2 expects a three-dimensional input of shape (batch_size, height, width).

import torch, pywt, ptwt
tmp = torch.randn(3, 10, 10)
wavelet = pywt.Wavelet("haar")
coeff2d = ptwt.wavedec2(tmp, wavelet, level=1, mode="zero")

Runs as expected. It is not obvious how a 2D wavelet transform should handle colour channels, so the toolbox does not expect any. If you want to transform the colour channels as well, use wavedec3:

import torch, pywt, ptwt
tmp = torch.randn(1, 3, 10, 10)
wavelet = pywt.Wavelet("haar")
coeff3d = ptwt.wavedec3(tmp, wavelet, level=1, mode="zero")

transforms along the colour channels as well and runs without an error.
If you have an input with colour channels but want a 2D transform, move the channels into the batch dimension:

import torch, pywt, ptwt
import numpy as np
from scipy import datasets  # scipy.misc.face is deprecated; scipy.datasets needs pooch
# get a colour image and move the channels into the batch dimension.
face = np.transpose(
    datasets.face()[256:512, 256:512], [2, 0, 1]
).astype(np.float64)
# compute the transform
coeff2d = ptwt.wavedec2(torch.from_numpy(face), pywt.Wavelet("haar"), level=1, mode="zero")
# invert the transform
rec = ptwt.waverec2(coeff2d, pywt.Wavelet("haar"))
# move the colour channels back
face_rec = rec.squeeze().permute(1, 2, 0)

To process several colour images at once, concatenate them along the batch dimension.
For example, face = np.concatenate([face]*3, 0) stacks three copies of the (3, 256, 256) image into a (9, 256, 256) array, so three colour images are processed at the same time.
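Why the 2D and 3D transforms return different numbers of subbands: a separable transform applies a low-pass (a) / high-pass (d) split along each transformed axis, so one level over two axes gives 2^2 = 4 subbands, while also transforming the channel axis (as wavedec3 does here) gives 2^3 = 8. The NumPy sketch below assumes even lengths and skips boundary padding; the helper names and the key strings are illustrative and need not match ptwt's naming. It uses 4 channels so that every transformed axis has even length.

```python
import numpy as np


def haar_split(x: np.ndarray, axis: int) -> tuple[np.ndarray, np.ndarray]:
    """Split `x` into Haar approximation (a) and detail (d) along one axis.

    Assumes the length along `axis` is even (no boundary padding).
    """
    even = np.take(x, np.arange(0, x.shape[axis], 2), axis=axis)
    odd = np.take(x, np.arange(1, x.shape[axis], 2), axis=axis)
    return (even + odd) / np.sqrt(2.0), (even - odd) / np.sqrt(2.0)


def haar_level(x: np.ndarray, axes: tuple[int, ...]) -> dict[str, np.ndarray]:
    """One separable Haar level over `axes`; each key lists the filter used per axis."""
    subbands = {"": x}
    for axis in axes:
        new = {}
        for key, band in subbands.items():
            a, d = haar_split(band, axis)
            new[key + "a"] = a
            new[key + "d"] = d
        subbands = new
    return subbands


x = np.random.default_rng(0).standard_normal((1, 4, 10, 10))  # (batch, channels, H, W)
two_d = haar_level(x, axes=(-2, -1))        # batch and channel axes untouched
three_d = haar_level(x, axes=(-3, -2, -1))  # the channel axis is transformed too
print(sorted(two_d))      # ['aa', 'ad', 'da', 'dd']
print(len(three_d))       # 8
print(two_d["aa"].shape)  # (1, 4, 5, 5)
```

The 2D version leaves the leading axes alone, which is why stacking channels in the batch dimension keeps them separate.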

v0lta self-assigned this on Sep 12, 2022
v0lta added the question (Further information is requested) label on Sep 12, 2022
@v0lta
Owner

v0lta commented Sep 12, 2022

On second thought, the toolbox should have caught the problem earlier. I will add a more informative error message in the next release.
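A minimal sketch of the kind of early check meant here, assuming the documented (batch_size, height, width) input layout quoted earlier in this thread. The function name `check_wavedec2_input` and its message are hypothetical and not ptwt's actual code; it only relies on `.ndim` and `.shape`, which NumPy arrays and PyTorch tensors both provide.

```python
import numpy as np


def check_wavedec2_input(data) -> None:
    """Hypothetical pre-check: raise a readable error for non-3D input.

    Works for NumPy arrays and PyTorch tensors alike (both expose .ndim and .shape).
    """
    if data.ndim != 3:
        raise ValueError(
            "wavedec2 expects a 3D input of shape (batch_size, height, width), "
            f"got shape {tuple(data.shape)}. Move colour or feature channels "
            "into the batch dimension first."
        )


try:
    check_wavedec2_input(np.zeros((1, 3, 10, 10)))
except ValueError as err:
    print(err)
```

Failing before the convolution runs replaces the opaque conv2d message with one that names the expected layout.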

v0lta added this to the v.0.1.4 milestone on Sep 12, 2022
v0lta added the enhancement (New feature or request) label on Sep 12, 2022
@abeyang00
Author

@v0lta Thank you for the quick answer.

However, I'm still confused, because I want to use wavedec2 inside the convolutional layers of a CNN.
My input will be a batch of n feature maps (e.g. 3x64x256x256, i.e. a batch of 3 samples, each with 64 channels of 256x256 feature maps).

Concatenating along the batch dimension therefore seems like a naive approach. Couldn't this corrupt the data?

And when I use wavedec3 with the Haar wavelet, it returns coefficients I do not want ('aad', 'ada', etc.).
I just want the four 2D components: ll, lh, hl, hh.

Best,

@v0lta
Owner

v0lta commented Sep 13, 2022

Dear @abeyang00,
If you want four filters per decomposition level, you want a 2D transform. The function wavedec2 does not touch the batch dimension, so stacking channels there preserves them. Running the example below illustrates that stacking the channels in the batch dimension is indeed reversible.

import torch, pywt, ptwt
import numpy as np
from scipy import datasets  # scipy.misc.face is deprecated; scipy.datasets needs pooch
# get a colour image and move the channels into the batch dimension.
face = datasets.face()[256:512, 256:512]
face_tp = np.transpose(face, [2, 0, 1]).astype(np.float64)
# compute the transform
coeff2d = ptwt.wavedec2(torch.from_numpy(face_tp), pywt.Wavelet("haar"),
                        level=1, mode="zero")
# invert the transform
rec = ptwt.waverec2(coeff2d, pywt.Wavelet("haar"))
# move the colour channels back
face_rec = rec.squeeze().permute(1, 2, 0)

# should print True: the round trip through the batch dimension is lossless
print(np.allclose(face, face_rec.numpy()))

All the best,
Moritz
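For the (N, C, H, W) feature maps from the CNN question, the same trick is a reshape: fold C into the batch axis, run the per-image 2D transform, and reshape back. Because a 2D transform only acts on the last two axes, every (sample, channel) slice is processed on its own and nothing mixes across channels. A NumPy sketch with hypothetical helpers `fold_channels` / `unfold_channels`, where a hand-written Haar approximation stands in for the call to ptwt.wavedec2 (with PyTorch tensors, `tensor.reshape` works the same way):

```python
import numpy as np


def fold_channels(x: np.ndarray) -> tuple[np.ndarray, tuple[int, int]]:
    """(N, C, H, W) -> (N * C, H, W); also return (N, C) for unfolding."""
    n, c, h, w = x.shape
    return x.reshape(n * c, h, w), (n, c)


def unfold_channels(y: np.ndarray, nc: tuple[int, int]) -> np.ndarray:
    """(N * C, ...) -> (N, C, ...): undo fold_channels for any trailing shape."""
    n, c = nc
    return y.reshape(n, c, *y.shape[1:])


def haar_approx(x: np.ndarray) -> np.ndarray:
    """Stand-in 2D op: Haar approximation over the last two axes (even H, W)."""
    return (x[..., ::2, ::2] + x[..., 1::2, ::2] + x[..., ::2, 1::2] + x[..., 1::2, 1::2]) / 2.0


# A small stand-in for the 3x64x256x256 feature maps from the question.
feats = np.random.default_rng(0).standard_normal((3, 64, 8, 8))
folded, nc = fold_channels(feats)                  # (192, 8, 8)
approx = unfold_channels(haar_approx(folded), nc)  # (3, 64, 4, 4)
# Same result as applying the op to each (sample, channel) slice directly:
print(np.allclose(approx, haar_approx(feats)))             # True
print(np.array_equal(unfold_channels(folded, nc), feats))  # True
```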

v0lta closed this as completed on Sep 14, 2022
v0lta reopened this on Oct 10, 2022
@v0lta
Owner

v0lta commented Oct 11, 2022

Commit cd94022 adds more informative error messages.

@v0lta
Owner

v0lta commented Aug 1, 2023

Version 0.1.6 takes care of this for you [https://github.com//pull/63].
You can now pass batch and channel dimensions to the transform, for example:

import ptwt, pywt, torch
import numpy as np

from scipy import datasets
face = np.transpose(datasets.face(),
                     [2, 0, 1]).astype(np.float64)
pytorch_face = torch.tensor(face).unsqueeze(0) # unsqueeze adds a batch dimension.
 
coefficients = ptwt.wavedec2(pytorch_face, pywt.Wavelet("haar"),
                              level=2, mode="constant")
reconstruction = ptwt.waverec2(coefficients, pywt.Wavelet("haar"))
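With the batch and channel axes passed through, the leading (1, 3) dimensions stay as they are and only the spatial size shrinks per level. Assuming ptwt follows PyWavelets' coefficient-length rule for non-periodic modes, floor((n + filter_len - 1) / 2) (the quantity `pywt.dwt_coeff_len` computes), the approximation shapes can be predicted with plain arithmetic. The helper `coeff_shapes` is hypothetical; the 5x5 result for the 10x10 example in the next comment is consistent with this rule.

```python
def coeff_shapes(height: int, width: int, level: int, filter_len: int = 2) -> list[tuple[int, int]]:
    """Spatial shape of the approximation after each level, finest level first.

    Assumes the non-periodic length rule floor((n + filter_len - 1) / 2);
    filter_len=2 is the Haar filter length.
    """
    shapes = []
    for _ in range(level):
        height = (height + filter_len - 1) // 2
        width = (width + filter_len - 1) // 2
        shapes.append((height, width))
    return shapes


# scipy's face image is 768 x 1024 pixels.
print(coeff_shapes(768, 1024, level=2))  # [(384, 512), (192, 256)]
print(coeff_shapes(10, 10, level=1))     # [(5, 5)]
```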

@v0lta
Owner

v0lta commented Aug 1, 2023

Your initial example now produces the following result:

In [1]: tmp = torch.randn(1, 3, 10, 10)
   ...: wavelet = pywt.Wavelet("haar")
   ...: coeff2d = ptwt.wavedec2(tmp, wavelet, level=1, mode="zero")

In [2]: coeff2d[0].shape
Out[2]: torch.Size([1, 3, 5, 5])
