Size mismatch occurs in UNet model at 5th stage #548
Comments
I guess some image has a wrong shape that is not divisible by 32 (maybe after augmentation). I would recommend first going through the whole dataset/dataloader and checking the shapes, because in another case this error appeared at the start of training.
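A minimal sketch of such a scan, assuming a PyTorch DataLoader called train_loader that yields (image, mask) pairs (both names are placeholders for your own setup):

for i, (images, masks) in enumerate(train_loader):
    # images is expected to be (N, C, H, W); flag any batch whose
    # spatial dimensions are not divisible by 32
    h, w = images.shape[-2:]
    if h % 32 != 0 or w % 32 != 0:
        print(f"batch {i}: shape {tuple(images.shape)} is not divisible by 32")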
Actually, I checked the dimensions of all my images before deciding on the resize, and I set the width and height to the minimum values of the entire dataset. I currently don't perform any augmentations either, so I am not sure what it could be. What I am fairly certain of is that it is not a data issue (I'm currently using just 4 samples to test whether everything runs, and the issue remains; the image dimensions are definitely above the dimensions I resize to).
Could you please construct a simple code example to reproduce the bug? With dummy tensors, e.g.:

import torch
import segmentation_models_pytorch as smp

batch = torch.ones([10, 3, 320, 192])
model = smp.Unet()
model(batch)
I can try it tonight when I'm on my computer at home. However, you can already replicate the error with one that is ALMOST identical (the message it yields is slightly different from the one I got).
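Something along these lines (a sketch, using 218 as a dimension that is not divisible by 32):

import torch
import segmentation_models_pytorch as smp

batch = torch.ones([10, 3, 320, 218])  # 218 is not divisible by 32
model = smp.Unet()
model(batch)  # fails: spatial sizes stop matching inside the decoder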
But 218 is not divisible by 32, so that is not the same case.
Once again: in order to use the models, your input size should be divisible by 32 (in some cases by 16). I have updated the master branch with proper input validation. Please check your case with the latest version.
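The validation is conceptually along these lines (a sketch of the idea only; the actual implementation in the library may differ):

def check_input_shape(x, output_stride=32):
    # Reject inputs whose spatial dimensions the encoder cannot
    # downsample cleanly log2(output_stride) times.
    h, w = x.shape[-2:]
    if h % output_stride != 0 or w % output_stride != 0:
        raise RuntimeError(
            f"Wrong input shape height={h}, width={w}. "
            f"Expected height and width divisible by {output_stride}."
        )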
Hello, I will check my case tonight when I'm back at my personal computer. However, my error is slightly different: yesterday, when I used 320x192 (both divisible by 32 and passing the example above), I still got the error that I posted. This makes me suspect it is something (totally) different.
I can suggest a kind of wrapper for the model that resizes your image to an appropriate size:

import torch

class ResizeWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def compute_sizes(self, x):
        # Round each spatial dimension up to the nearest multiple of the
        # encoder's output stride (32 for a standard 5-stage encoder).
        h, w = x.shape[-2:]
        output_stride = self.model.encoder.output_stride
        new_h = (h // output_stride + 1) * output_stride if h % output_stride != 0 else h
        new_w = (w // output_stride + 1) * output_stride if w % output_stride != 0 else w
        return (h, w), (new_h, new_w)

    def forward(self, x):
        old_size, new_size = self.compute_sizes(x)
        if old_size != new_size:
            # Upsample the input so every encoder stage can downsample it cleanly.
            x = torch.nn.functional.interpolate(x, size=new_size, mode="bilinear")
        x = self.model(x)
        if old_size != new_size:
            # Resize the prediction back to the original resolution.
            x = torch.nn.functional.interpolate(x, size=old_size, mode="bilinear")
        return x

model = ResizeWrapper(smp.Unet(...))

This should work with the latest master version.
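For example (hypothetical shapes and encoder choice), an input whose width is not a multiple of 32 gets resized up for the forward pass, and the prediction is resized back to the original resolution:

import torch
import segmentation_models_pytorch as smp

wrapped = ResizeWrapper(smp.Unet(encoder_name="resnet34", classes=1))
wrapped.eval()
x = torch.ones([1, 3, 320, 218])  # width 218 is not divisible by 32
with torch.no_grad():
    y = wrapped(x)
print(y.shape)  # torch.Size([1, 1, 320, 218]) -- original size preserved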
Not 32 but 2^depth; you can relax the constraint if you use 3 or 4 stages instead of 5.
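For example (a sketch using smp.Unet's encoder_depth and decoder_channels parameters), a 4-stage encoder relaxes the constraint to divisibility by 2^4 = 16:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",
    encoder_depth=4,                      # 4 downsampling stages -> stride 2**4 = 16
    decoder_channels=(256, 128, 64, 32),  # one channel count per decoder stage
)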
@JulienMaille yes, it is taken into consideration
This issue is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 7 days. |
This issue was closed because it has been stalled for 7 days with no activity. |
Original issue description

I used the SMP library to create a UNet model with the following configuration:
model = smp.Unet(encoder_name='resnet50', encoder_weights='imagenet', in_channels=3, classes=30)
However, I have also tried other encoders (including the default resnet34), and the error appears with every encoder I choose. I am training on a custom dataset whose images have dimensions w=320, h=192.
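For reference, both of these dimensions are divisible by 32 (320 = 10 × 32, 192 = 6 × 32), so a dummy forward pass with this exact configuration would be expected to succeed; a minimal sketch:

import torch
import segmentation_models_pytorch as smp

model = smp.Unet(encoder_name='resnet50', encoder_weights='imagenet', in_channels=3, classes=30)
model.eval()
batch = torch.ones([2, 3, 192, 320])  # (N, C, H, W): h=192, w=320
with torch.no_grad():
    out = model(batch)
print(out.shape)  # torch.Size([2, 30, 192, 320])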
My code runs fine until one of the final steps in the decoder block. The error traces back to smp/unet/decoder.py. When I run a training epoch, the error occurs in def forward(self, x, skip=None) of decoder.py. For the first steps, everything runs fine and the dimensions of 'x' match those of 'skip'. Below is a list of the dimensions of both x and skip as I go through the decoder:
Around step 3, a mismatch between the tensors starts to occur, which causes the error. The traceback can be seen in the indented block below.
What I find weird about this is that I have used the exact same codebase with a different dataset that consisted of only 6 classes, and in that case there was no issue. I am also unsure where this is happening, as I cannot seem to find the root cause.
Traceback