unet3d - koila.errors.UnsupportedError #22

Open
etienne87 opened this issue Jan 5, 2022 · 1 comment
@etienne87

I am trying to apply koila lazy eval on a Unet3D.

# defining the model
import torch
import torch.nn as nn
import torch.nn.functional as F
from koila import LazyTensor, lazy


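def conv3(in_channels, out_channels, stride, norm='BatchNorm3d', act='GELU'):
    # Note: `stride` is accepted but not forwarded; the Conv3d below always uses stride 1,
    # so the spatial size stays at 64 throughout (as the shapes in the traceback show).
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, 3, 1, 1),
        getattr(nn, norm)(out_channels),
        getattr(nn, act)())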


def double_conv3(in_channels, out_channels, stride):
    return nn.Sequential(conv3(in_channels, out_channels, 1),
                         conv3(out_channels, out_channels, stride))

def merge_skip(x, skip):
    x = F.upsample(x, size=skip.shape[-3:], mode='trilinear', align_corners=True)
    return torch.cat((x,skip),dim=1)



class Unet3D(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers=4, base=16):  
        super().__init__()

        enc_channels = [in_channels]+[base * 2**i for i in range(num_layers)]
        dec_channels = [base * 2**i for i in range(num_layers-1,-1,-1)]+[out_channels]

        self.encoders = nn.ModuleList()
        for i in range(len(enc_channels)-1):
            cin = enc_channels[i]
            cout = enc_channels[i+1]
            enc = double_conv3(cin, cout, 2)
            self.encoders.append(enc)

        self.decoders = nn.ModuleList()
        for i in range(len(dec_channels)-1):
            cin_skip = enc_channels[-i-2]
            cin_up = dec_channels[i]
            cin = cin_skip + cin_up 
            cout = dec_channels[i+1]
            dec = double_conv3(cin, cout, 1)
            self.decoders.append(dec)

    def forward(self, x, return_all=False):
        out = [x]
        for encoder in self.encoders:
            x = encoder(x)
            out.append(x)
        n = len(out)
        for i, decoder in enumerate(self.decoders): 
            skip = out[n - 2 - i]
            x = merge_skip(out[-1], skip)
            x = decoder(x)
            out.append(x)

        if return_all:
            return out 
        else:
            return out[-1]

# test of koila on unet
def test_lazy():
    net = Unet3D(1,3)
    net.cuda()
    s = 64 
    b,c,d,h,w = 2,1,s,s,s
    x = torch.randn(b,c,d,h,w).cuda()
    t = torch.randint(0,3, (b,d,h,w)).cuda()

    loss_fn = nn.CrossEntropyLoss()
    net.zero_grad()

    lazy_x, lazy_t = lazy(x, t, batch=0)
    lazy_out = net(lazy_x)
    lazy_loss = loss_fn(lazy_out, lazy_t) 
    assert isinstance(lazy_loss, LazyTensor), type(lazy_loss)
    lazy_loss.backward()



# This fails
test_lazy()

The call raises koila.errors.UnsupportedError with the following output:

tensors = (tensor([[[[[-8.9936e-02, -7.9037e-02, -1.5048e-02,  ...,  2.9969e-01,
             2.9774e-01, -1.0489e-01],
        ...]]], device='cuda:0',
       grad_fn=<UpsampleTrilinear3DBackward1>), <koila.lazy.LazyTensor object at 0x7fa21bf99880>)
dim = 1, args = (), kwargs = {}, shapes = [torch.Size([2, 128, 64, 64, 64]), (2, 64, 64, 64, 64)]
no_dim = [torch.Size([2, 64, 64, 64]), (2, 64, 64, 64)], result_size = torch.Size([2, 64, 64, 64])
size = (2, 64, 64, 64)

    def cat(
        tensors: Sequence[TensorLike], dim: int = 0, *args: Any, **kwargs: Any
    ) -> PrePass:
        mute_unused_args(*args, **kwargs)

        if len(tensors) == 0:
            raise ValueError("Expected a sequence of tensors. Got empty sequence.")

        shapes = [t.size() for t in tensors]
        no_dim = [t[:dim] + t[dim + 1 :] for t in shapes]

        result_size = no_dim[0]
        for size in no_dim[1:]:
            if result_size != size:
                raise ValueError(
                    f"Dimension should be equal outside dim {dim}. Got {shapes}."
                )

        if len(set(interfaces.bat(t) for t in tensors)) != 1:
>           raise UnsupportedError
E           koila.errors.UnsupportedError

../miniconda3/envs/snakes/lib/python3.9/site-packages/koila/prepasses.py:423: UnsupportedError
@rentruewang
Owner

Hi, that means the batch sizes don't match, and the library doesn't know how to deal with that situation.

Since PyTorch's broadcasting rules are extensive, not all rules are supported yet.

I'll see what I can do about it in the upcoming changes in #18 with a much more modular implementation.
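
For readers following the traceback, here is a minimal sketch of the two checks visible in the quoted `cat` prepass, written in plain Python so it runs without koila or a GPU. It is not koila's implementation: `FakeTensor`, `check_cat`, and the local `UnsupportedError` class are hypothetical stand-ins, and the idea that the eagerly evaluated upsample result carries no batch annotation (modelled as `None`) is an assumption based only on the trace, which shows a plain `torch.Tensor` being concatenated with a `LazyTensor`.

```python
from typing import Optional, Sequence, Tuple


class UnsupportedError(RuntimeError):
    """Local stand-in for koila.errors.UnsupportedError (not the real class)."""


class FakeTensor:
    """Hypothetical stand-in that carries only a shape and a batch-dimension tag."""

    def __init__(self, shape: Tuple[int, ...], batch: Optional[int] = None) -> None:
        self.shape = shape
        # None models a tensor with no batch annotation (assumption, see lead-in).
        self.batch = batch


def check_cat(tensors: Sequence[FakeTensor], dim: int) -> Tuple[int, ...]:
    """Mirror the two checks in the quoted prepass and return the output shape."""
    if len(tensors) == 0:
        raise ValueError("Expected a sequence of tensors. Got empty sequence.")

    shapes = [t.shape for t in tensors]
    # Shapes with the concatenation axis removed must all agree.
    no_dim = [s[:dim] + s[dim + 1:] for s in shapes]
    if any(s != no_dim[0] for s in no_dim[1:]):
        raise ValueError(f"Dimension should be equal outside dim {dim}. Got {shapes}.")

    # All inputs must agree on which axis (if any) is the batch axis.
    if len({t.batch for t in tensors}) != 1:
        raise UnsupportedError("inputs disagree on the batch dimension")

    out = list(shapes[0])
    out[dim] = sum(s[dim] for s in shapes)
    return tuple(out)


# Shapes taken from the traceback: upsampled decoder input and the lazy skip tensor.
upsampled = FakeTensor((2, 128, 64, 64, 64), batch=None)  # eager result, no batch tag
skip = FakeTensor((2, 64, 64, 64, 64), batch=0)           # lazy tensor, batch=0

try:
    check_cat([upsampled, skip], dim=1)
    mixed_ok = True
except UnsupportedError:
    mixed_ok = False

# When both inputs carry the same batch tag, the check passes.
both_lazy_shape = check_cat([FakeTensor((2, 128, 64, 64, 64), batch=0), skip], dim=1)
```

In this sketch the shape check passes for the shapes in the report (they differ only along `dim=1`), and it is the batch-tag comparison that fails for the mixed pair, which matches the line the traceback points at.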
