
Size mismatch occurs in UNet model at 5th stage #548

Closed
Fritskee opened this issue Jan 30, 2022 · 12 comments

@Fritskee

Fritskee commented Jan 30, 2022

I used the SMP library to create a UNet model with the following configurations:
model = smp.Unet(encoder_name='resnet50', encoder_weights='imagenet', in_channels=3, classes=30)

However, I have also tried other encoders (including the default resnet34), and the error appears for every encoder I choose. I am training on a custom dataset whose images have dimensions w=320, h=192.

My code runs fine until one of the final steps in the decoder. The error traces back to smp/unet/decoder.py: while running a training epoch, it occurs in def forward(self, x, skip=None):

def forward(self, x, skip=None):
    # upsample the decoder feature map by 2x
    x = F.interpolate(x, scale_factor=2, mode="nearest")
    if skip is not None:
        # concatenation requires x and skip to have identical spatial dimensions
        x = torch.cat([x, skip], dim=1)
        x = self.attention1(x)
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.attention2(x)
    return x

For the first steps, everything runs fine and the dimensions of 'x' match those of 'skip'. Below is a list of the dimensions of both x and skip as I go through the decoder:

STEP 1
x.shape
Out[1]: torch.Size([1, 2048, 14, 20])
skip.shape
Out[2]: torch.Size([1, 1024, 14, 20])
STEP 2
x.shape
Out[3]: torch.Size([1, 256, 28, 40])
skip.shape
Out[4]: torch.Size([1, 512, 28, 40])
STEP 3
x.shape
Out[5]: torch.Size([1, 128, 56, 80])
skip.shape
Out[6]: torch.Size([1, 256, 55, 80])
STEP 4
x.shape
Out[7]: torch.Size([1, 128, 56, 80])
skip.shape
Out[8]: torch.Size([1, 256, 55, 80])
STEP 5
x.shape
Out[9]: torch.Size([1, 3, 192, 320])
skip.shape
Out[10]: torch.Size([1, 256, 55, 80])

Around step 3, a mismatch between the tensors starts occurring, which causes the error; the traceback is shown below.
What I find weird is that I have used the exact same codebase with a different dataset of only 6 classes, and in that case there was no issue. I am also unsure where this is happening, as I cannot find the root cause.

Traceback (most recent call last):
File "/Users/fc/Desktop/ct/segmentation_code/main.py", line 141, in
trainer.train()
File "/Users/fc/Desktop/ct/segmentation_code/ops/trainer.py", line 44, in train
self.train_logs = self.train_epoch.run(self.trainloader)
File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/utils/train.py", line 47, in run
loss, y_pred = self.batch_update(x, y)
File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/utils/train.py", line 87, in batch_update
prediction = self.model.forward(x)
File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/base/model.py", line 16, in forward
decoder_output = self.decoder(*features)
File "/Users/fc/miniconda3/envs/ct/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/unet/decoder.py", line 119, in forward
x = decoder_block(x, skip)
File "/Users/fc/miniconda3/envs/ct/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/unet/decoder.py", line 38, in forward
x = torch.cat([x, skip], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 56 but got size 55 for tensor number 1 in the list.

@qubvel
Owner

qubvel commented Jan 30, 2022

I guess some image has a wrong shape that is not divisible by 32 (maybe after augmentation). I would recommend first going through the whole dataset/dataloader and checking the shapes, because otherwise this error would have appeared right at the start of training.
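For example, a minimal check over the dataloader might look like this (a sketch; `trainloader` and the 32-pixel stride are assumptions based on the setup described above):

import torch

# Sketch: iterate over the training dataloader and flag any batch whose
# spatial dimensions are not divisible by the encoder's output stride (32 here).
# `trainloader` is assumed to yield (image, mask) pairs as in the setup above.
output_stride = 32

for i, (images, masks) in enumerate(trainloader):
    h, w = images.shape[-2:]
    if h % output_stride != 0 or w % output_stride != 0:
        print(f"batch {i}: bad shape {tuple(images.shape)}")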

@Fritskee
Author

I guess some image has a wrong shape that is not divisible by 32 (maybe after augmentation). I would recommend first going through the whole dataset/dataloader and checking the shapes, because otherwise this error would have appeared right at the start of training.

Actually, I checked the dimensions of all my images before deciding on the resize, and then set the width and height to the minimum values in the entire dataset. Currently I also don't perform any augmentations, so I am not sure what it could be. What I am pretty certain of is that it is not a data issue (I'm currently using just 4 samples to test whether it all runs, and the issue remains; the image dimensions are definitely above the dimensions I resize to).

@qubvel
Owner

qubvel commented Jan 30, 2022

Could you please construct a simple code example to reproduce the bug? With dummy tensors, e.g.

import torch
import segmentation_models_pytorch as smp

batch = torch.ones([10, 3, 320, 192])
model = smp.Unet()
model(batch)

@Fritskee
Author

Could you please construct a simple code example to reproduce the bug? With dummy tensors, e.g.

batch = torch.ones([10, 3, 320, 192])
model = smp.Unet()
model(batch)

I can try it tonight when I'm on my computer at home. However, you can reproduce an error that is ALMOST identical (but with a slightly different message than the one I got) by doing the following:

batch = torch.ones([10, 3, 320, 218])
model = smp.Unet()
model(batch)

However, this yields the message

RuntimeError: Sizes of tensors must match except in dimension 1. Got 56 and 55 in dimension 3 (The offending index is 1)
which differs slightly in its second sentence from the error I originally got.

@qubvel
Owner

qubvel commented Jan 31, 2022

But 218 is not divisible by 32, so that is not the same case.
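To see where the odd size comes from, one can print the encoder feature shapes directly (a sketch; the resnet34 choice and input size are just for illustration):

import torch
from segmentation_models_pytorch.encoders import get_encoder

# Sketch: inspect the spatial size of each encoder stage for a width (218)
# that is not divisible by 32. One of the skip connections ends up with an
# odd width (55), which the decoder cannot concatenate with its 2x-upsampled
# counterpart (56), producing the reported error.
encoder = get_encoder("resnet34", weights=None)
features = encoder(torch.ones(1, 3, 320, 218))
for i, f in enumerate(features):
    print(i, tuple(f.shape))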

@qubvel
Owner

qubvel commented Jan 31, 2022

Once again, in order to use the models, your input size should be divisible by 32 (in some cases by 16). I have updated the master branch with proper input validation. Please check your case with the latest version.

@Fritskee
Author

Once again, in order to use the models, your input size should be divisible by 32 (in some cases by 16). I have updated the master branch with proper input validation. Please check your case with the latest version.

Hello, I will check my case tonight when I'm back at my personal computer. However, my error is slightly different: yesterday, when I used 320x192 (both divisible by 32, and passing the example above), I still got the error that I posted. This suggests it is something (totally) different.

@qubvel
Owner

qubvel commented Jan 31, 2022

I can suggest a kind of wrapper for the model that resizes your input to an appropriate size:

import torch

class ResizeWrapper(torch.nn.Module):

    def __init__(self, model):
        super().__init__()
        self.model = model

    def compute_sizes(self, x):
        # round H and W up to the nearest multiple of the encoder's output stride
        h, w = x.shape[-2:]
        output_stride = self.model.encoder.output_stride
        new_h = (h // output_stride + 1) * output_stride if h % output_stride != 0 else h
        new_w = (w // output_stride + 1) * output_stride if w % output_stride != 0 else w
        return (h, w), (new_h, new_w)

    def forward(self, x):
        old_size, new_size = self.compute_sizes(x)
        if old_size != new_size:
            # resize the input so the encoder/decoder shapes line up
            x = torch.nn.functional.interpolate(x, size=new_size, mode="bilinear")
        x = self.model(x)
        if old_size != new_size:
            # resize the prediction back to the original input size
            x = torch.nn.functional.interpolate(x, size=old_size, mode="bilinear")
        return x

model = ResizeWrapper(smp.Unet(...))

should work with the latest master version
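A quick way to sanity-check the wrapper might be something like this (a sketch; the encoder choice, class count, and input size are arbitrary):

import torch
import segmentation_models_pytorch as smp

# Sketch: an input whose width (218) is not divisible by 32 should now pass
# through the wrapper, and the output should come back at the original size.
model = ResizeWrapper(smp.Unet(encoder_name="resnet34", classes=30))
out = model(torch.ones(1, 3, 320, 218))
print(out.shape)  # expected: torch.Size([1, 30, 320, 218])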

@JulienMaille
Contributor

Once again, in order to use the models, your input size should be divisible by 32 (in some cases by 16). I have updated the master branch with proper input validation. Please check your case with the latest version.

Not 32 but 2^depth; you can relax the constraint by using 3 or 4 stages instead of 5.
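For example, a shallower configuration might look like this (a sketch; with 4 encoder stages the input only needs to be divisible by 2^4 = 16):

import segmentation_models_pytorch as smp

# Sketch: with 4 encoder stages instead of 5, the divisibility requirement
# drops from 32 to 16. decoder_channels must match the chosen encoder_depth.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_depth=4,
    decoder_channels=(256, 128, 64, 32),
    classes=30,
)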

@qubvel
Owner

qubvel commented Jan 31, 2022

@JulienMaille yes, it is taken into consideration

@github-actions

github-actions bot commented Apr 2, 2022

This issue is stale because it has been open 60 days with no activity. Remove stale label or comment or this will be closed in 7 days.

@github-actions github-actions bot added the Stale label Apr 2, 2022
@github-actions

This issue was closed because it has been stalled for 7 days with no activity.
