OOM during backward pass on a model with ~600k parameters #54

Closed · rafaelvalle opened this issue May 22, 2019 · 14 comments

@rafaelvalle (Contributor)

Hey Ricky,

I'm running out of memory during the backward pass on a 16 GB GPU when running the adjoint method with rtol=1e-5, atol=1e-5, and a network with 631,058 parameters.

I'm not sure why this happens, given that augmented_dynamics runs within torch.no_grad() and the tensors saved during the forward pass should not be that large.

Any thoughts on what is happening and how to debug it?

The model is a 3D U-Net (UNet) that feeds into a few 3D convs (node_layers).

Model(
  (prenet_layers): PrenetLayer(
    (initval_layers): Identity()
    (image_layers): Sequential(
      (0): Conv3d(3, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): UNet(
        (in_norm): GroupNorm(8, 8, eps=1e-05, affine=True)
        (in_act): ReLU(inplace)
        (down1): down(
          (mpconv): Sequential(
            (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (1): double_conv(
              (conv): Sequential(
                (0): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): LeakyReLU(negative_slope=0.1, inplace)
                (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (4): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (5): LeakyReLU(negative_slope=0.1, inplace)
              )
            )
          )
        )
        (down2): down(
          (mpconv): Sequential(
            (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (1): double_conv(
              (conv): Sequential(
                (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): LeakyReLU(negative_slope=0.1, inplace)
                (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (5): LeakyReLU(negative_slope=0.1, inplace)
              )
            )
          )
        )
        (down3): down(
          (mpconv): Sequential(
            (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (1): double_conv(
              (conv): Sequential(
                (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): LeakyReLU(negative_slope=0.1, inplace)
                (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (5): LeakyReLU(negative_slope=0.1, inplace)
              )
            )
          )
        )
        (down4): down(
          (mpconv): Sequential(
            (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
            (1): double_conv(
              (conv): Sequential(
                (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (2): LeakyReLU(negative_slope=0.1, inplace)
                (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                (5): LeakyReLU(negative_slope=0.1, inplace)
              )
            )
          )
        )
        (up0): up(
          (up): Upsample(scale_factor=2.0, mode=trilinear)
          (conv): double_conv(
            (conv): Sequential(
              (0): Conv3d(128, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): LeakyReLU(negative_slope=0.1, inplace)
              (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (5): LeakyReLU(negative_slope=0.1, inplace)
            )
          )
        )
        (up1): up(
          (up): Upsample(scale_factor=2.0, mode=trilinear)
          (conv): double_conv(
            (conv): Sequential(
              (0): Conv3d(64, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): LeakyReLU(negative_slope=0.1, inplace)
              (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (4): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (5): LeakyReLU(negative_slope=0.1, inplace)
            )
          )
        )
        (up2): up(
          (up): Upsample(scale_factor=2.0, mode=trilinear)
          (conv): double_conv(
            (conv): Sequential(
              (0): Conv3d(32, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): LeakyReLU(negative_slope=0.1, inplace)
              (3): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (4): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (5): LeakyReLU(negative_slope=0.1, inplace)
            )
          )
        )
        (up3): up(
          (up): Upsample(scale_factor=2.0, mode=trilinear)
          (conv): double_conv(
            (conv): Sequential(
              (0): Conv3d(16, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): LeakyReLU(negative_slope=0.1, inplace)
              (3): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
              (4): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (5): LeakyReLU(negative_slope=0.1, inplace)
            )
          )
        )
        (out): Conv3d(8, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1))
        (out_norm): GroupNorm(8, 8, eps=1e-05, affine=True)
      )
    )
  )
  (node_layers): Sequential(
    (0): ODEBlock(
      (odefunc): ODEfunc(
        (tanh): Tanh()
        (conv_emb1): ConcatConv3d(
          (_layer): Conv3d(2, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_emb1): GroupNorm(4, 4, eps=1e-05, affine=True)
        (conv_emb2): ConcatConv3d(
          (_layer): Conv3d(5, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_emb2): GroupNorm(4, 4, eps=1e-05, affine=True)
        (norm_img_pre): GroupNorm(8, 8, eps=1e-05, affine=True)
        (conv_img): ConcatConv3d(
          (_layer): Conv3d(9, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_img): GroupNorm(4, 4, eps=1e-05, affine=True)
        (conv1): ConcatConv3d(
          (_layer): Conv3d(9, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_conv1): GroupNorm(4, 4, eps=1e-05, affine=True)
        (conv2): ConcatConv3d(
          (_layer): Conv3d(5, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_conv2): GroupNorm(4, 4, eps=1e-05, affine=True)
        (conv_out): ConcatConv3d(
          (_layer): Conv3d(5, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1))
        )
      )
    )
  )
  (postnet_layers): Sequential(
    (0): Identity()
  )
)
@rtqichen (Owner)

Hi Rafael,

The most memory-expensive operation is just this line: https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/adjoint.py#L41
where a single backward call is made through the network.

Can you make sure that a single forward & backward pass of this network fits into memory? Otherwise, I can only think of the batch size being too large and reducing that.
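
One way to run that check in isolation, as a sketch: evaluate the ODE function once outside odeint and backprop through it. The names and shapes below are assumptions taken from the model above and the repro script posted later in this thread, not code from this repo.

# Sketch: one forward + backward through the ODE function alone, bypassing
# the solver. An OOM here implicates the network itself, not odeint.
import torch

odefunc = model.node_layers[0].odefunc                              # f(t, y)
odefunc.condition = torch.randn(1, 8, 64, 496, 496, device='cuda')  # h_img stand-in
y0 = torch.randn(1, 1, 64, 496, 496, device='cuda', requires_grad=True)

out = odefunc(torch.tensor(0., device='cuda'), y0)
out.sum().backward()
print(f'{torch.cuda.max_memory_allocated() / 2**30:.2f} GiB peak')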

@rafaelvalle (Author)

Let me take a look at it and get back to you.

@rafaelvalle (Author)

rafaelvalle commented May 25, 2019

Although I was able to execute multiple forward (~40) and backward (~5) passes after changing the Neural ODE layers to something very small (see the end of this reply), the grad call inside the augmented_dynamics method still results in an out-of-memory error (see the trace below).

Error trace

Traceback (most recent call last):
  File "train.py", line 367, in <module>
    init_value_method=init_value_method)
  File "train.py", line 280, in train
    loss.backward()
  File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/function.py", line 77, in apply
    return self._forward_cls.backward(self, *args)
  File "/torchdiffeq/torchdiffeq/_impl/adjoint.py", line 83, in backward
    torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options
  File "/torchdiffeq/torchdiffeq/_impl/odeint.py", line 72, in odeint
    solution = solver.integrate(t)
  File "/torchdiffeq/torchdiffeq/_impl/solvers.py", line 31, in integrate
    y = self.advance(t[i])
  File "/torchdiffeq/torchdiffeq/_impl/dopri5.py", line 90, in advance
    self.rk_state = self._adaptive_dopri5_step(self.rk_state)
  File "/torchdiffeq/torchdiffeq/_impl/dopri5.py", line 103, in _adaptive_dopri5_step
    y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU)
  File "/torchdiffeq/torchdiffeq/_impl/rk_common.py", line 52, in _runge_kutta_step
    tuple(k_.append(f_) for k_, f_ in zip(k, func(ti, yi)))
  File "/torchdiffeq/torchdiffeq/_impl/misc.py", line 187, in <lambda>
    func = lambda t, y: tuple(-f_ for f_ in _base_reverse_func(-t, y))
  File "/torchdiffeq/torchdiffeq/_impl/adjoint.py", line 43, in augmented_dynamics
    tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 149, in grad
    inputs, allow_unused)
RuntimeError: CUDA out of memory. Tried to allocate 482.00 MiB (GPU 0; 15.75 GiB total capacity; 13.56 GiB already allocated; 437.94 MiB free; 719.27 MiB cached)

Model

Model(
   UNet from previous post
  (node_layers): Sequential(
    (0): ODEBlock(
      (odefunc): ODEfunc(
        (tanh): Tanh()
        (conv_emb1): ConcatConv3d(
          (_layer): Conv3d(2, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_emb1): GroupNorm(1, 1, eps=1e-05, affine=True)
        (conv_emb2): ConcatConv3d(
          (_layer): Conv3d(2, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_emb2): GroupNorm(1, 1, eps=1e-05, affine=True)
        (norm_img_pre): GroupNorm(8, 8, eps=1e-05, affine=True)
        (conv_img): ConcatConv3d(
          (_layer): Conv3d(9, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_img): GroupNorm(1, 1, eps=1e-05, affine=True)
        (conv1): ConcatConv3d(
          (_layer): Conv3d(3, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_conv1): GroupNorm(1, 1, eps=1e-05, affine=True)
        (conv2): ConcatConv3d(
          (_layer): Conv3d(2, 1, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
        )
        (norm_conv2): GroupNorm(1, 1, eps=1e-05, affine=True)
        (conv_out): ConcatConv3d(
          (_layer): Conv3d(2, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1))
        )
      )
    )
  )
  (postnet_layers): Sequential(
    (0): Identity()
  )
)
Number of parameters: 628256

@rtqichen (Owner)

rtqichen commented May 26, 2019

Hmm, this is running out of memory when computing the line I linked before. When you say forward and backward passes, do you mean of this network or of odeint? I'd try an even smaller model and check for a memory leak first. If there isn't one, then it's simply using more memory than what's available on a single GPU.

@rafaelvalle (Author)

rafaelvalle commented May 26, 2019

By forward and backward passes I mean forward evaluations of the NODE and backward evaluations (adjoint method) of the ODE, i.e., odeint. The memory footprint of forward and backward ODE evaluations should not change, right?

@rtqichen (Owner)

rtqichen commented May 26, 2019

Right, it shouldn't change. If it does, there might be a memory leak, but I haven't observed that happening yet, at least with PyTorch version 1.0. Are there any other variable-memory components, perhaps variable-length inputs?
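
A minimal leak check along these lines, as a sketch: allocated memory should plateau after the first step if nothing leaks. It assumes model, x, y, and optimizer as in the repro script posted below, with the input shrunk so each iteration actually completes.

# Sketch of a leak check: watch allocated memory across training steps.
import torch

for i in range(10):
    optimizer.zero_grad()
    loss = torch.mean((y - model(x)) ** 2)
    loss.backward()
    optimizer.step()
    # A steadily growing number here would indicate a leak.
    print(i, f'{torch.cuda.memory_allocated() / 2**20:.0f} MiB allocated')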

@rafaelvalle (Author)

Not really; the inputs are images of fixed size.

@rafaelvalle (Author)

Any suggestions on how to debug this?

@rtqichen (Owner)

I can't tell based on the current information, as it just seems like a standard OOM error. Do you have a reproducible script you can share?

@rafaelvalle (Author)

Let me recreate a small reproducible script and share it with you.

@rafaelvalle (Author)

rafaelvalle commented May 26, 2019

Here it is! As minimal as possible.
Let me know if you need anything else.

from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F

tanh = nn.Tanh
act = partial(nn.ReLU, inplace=True)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def norm(dim, n_groups=32):
    """GroupNorm"""
    return nn.GroupNorm(min(n_groups, dim), dim)


class ResBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.act1 = act()
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.act2 = act()
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x

        out = self.act1(self.norm1(x))

        if self.downsample is not None:
            shortcut = self.downsample(out)

        out = self.conv1(out)
        out = self.norm2(out)
        out = self.act2(out)
        out = self.conv2(out)

        return out + shortcut


class ConcatConv3d(nn.Module):
    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0,
                 dilation=1, groups=1, bias=True, transpose=False,
                 zeros_init=False):
        super(ConcatConv3d, self).__init__()
        module = nn.ConvTranspose3d if transpose else nn.Conv3d
        self._layer = module(dim_in + 1, dim_out, kernel_size=ksize,
                             stride=stride, padding=padding, dilation=dilation,
                             groups=groups, bias=bias)
        if zeros_init:
            self._layer.weight.data.zero_()
            self._layer.bias.data.zero_()

    def forward(self, t, x):
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)


class ODEfunc(nn.Module):
    def __init__(self, dim_init_value, dim_cond, dim, zeros_init=False):
        super(ODEfunc, self).__init__()
        self.tanh = tanh()

        # init_value embedding
        self.conv_init_value1 = ConcatConv3d(dim_init_value, dim, 3, 1, 1, zeros_init=zeros_init)
        self.norm_init_value1 = norm(dim)
        self.conv_init_value2 = ConcatConv3d(dim, dim, 3, 1, 1, zeros_init=zeros_init)
        self.norm_init_value2 = norm(dim)

        # image embedding
        self.norm_img_pre = norm(dim_cond)
        self.conv_img = ConcatConv3d(dim_cond, dim, 3, 1, 1, zeros_init=zeros_init)
        self.norm_img = norm(dim)

        # first input is concatenation of h_init_value and h_img
        self.conv1 = ConcatConv3d(2*dim, dim, 3, 1, 1, zeros_init=zeros_init)
        self.norm_conv1 = norm(dim)
        self.conv2 = ConcatConv3d(dim, dim, 3, 1, 1, zeros_init=zeros_init)
        self.norm_conv2 = norm(dim)

        # conv_out
        self.conv_out = ConcatConv3d(dim, dim_init_value, 1, 1, 0, zeros_init=zeros_init)

        # housekeeping: number of function evaluations
        self.nfe = 0

    def forward(self, t, init_value):
        self.nfe += 1
        print(self.nfe)

        # compute init_value hidden state
        h_init_value = self.conv_init_value1(t, init_value)
        h_init_value = self.norm_init_value1(h_init_value)
        h_init_value = self.tanh(h_init_value)
        h_init_value = self.conv_init_value2(t, h_init_value)
        h_init_value = self.norm_init_value2(h_init_value)
        h_init_value = self.tanh(h_init_value)

        # compute image hidden state; h_img is the conv output, no norm or activation
        h_img = self.condition
        h_img = self.norm_img_pre(h_img)
        h_img = self.tanh(h_img)
        h_img = self.conv_img(t, h_img)
        h_img = self.norm_img(h_img)
        h_img = self.tanh(h_img)

        # compute output based on init_value and image hidden states
        V = torch.cat((h_init_value, h_img), dim=1)
        V = self.conv1(t, V)
        V = self.norm_conv1(V)
        V = self.tanh(V)
        V = self.conv2(t, V)
        V = self.norm_conv2(V)
        V = self.tanh(V)

        # project back to init_value dimension
        V = self.conv_out(t, V)

        return V


class ODEBlock(nn.Module):
    def __init__(self, odefunc, odeint, rtol, atol):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()
        self.odeint = odeint
        self.rtol = rtol
        self.atol = atol

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x).to(x.device)
        out = self.odeint(self.odefunc, x, self.integration_time,
                          rtol=self.rtol, atol=self.atol)
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


class UNet(nn.Module):
    # adapted from https://github.com/milesial/Pytorch-UNet/
    def __init__(self, n_in_channels):
        super(UNet, self).__init__()
        self.in_norm = norm(8)
        self.in_act = nn.ReLU(inplace=True)
        self.down1 = down(8, 16, kernel_size=3, stride=1)
        self.down2 = down(16, 32, kernel_size=3, stride=1)
        self.down3 = down(32, 64, kernel_size=3, stride=1)
        self.down4 = down(64, 64, kernel_size=3, stride=1)
        self.up0 = up(128, 32, kernel_size=3, stride=1)
        self.up1 = up(64, 16, kernel_size=3, stride=1)
        self.up2 = up(32, 8, kernel_size=3, stride=1)
        self.up3 = up(16, n_in_channels, kernel_size=3, stride=1)
        self.out = nn.Conv3d(n_in_channels, n_in_channels, kernel_size=1,
                             stride=1, padding=0)
        self.out_norm = norm(n_in_channels)

    def forward(self, x1):
        x1 = self.in_norm(x1)
        x1 = self.in_act(x1)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up0(x5, x4)
        x = self.up1(x, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.out(x)
        x = self.out_norm(x)
        return x


class double_conv(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size, stride=stride,
                      padding=(kernel_size-1)//2),
            nn.BatchNorm3d(out_ch),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size, stride=stride,
                      padding=(kernel_size-1)//2),
            nn.BatchNorm3d(out_ch),
            nn.LeakyReLU(0.1, inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool3d(2),
            double_conv(in_ch, out_ch, kernel_size, stride)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride):
        super(up, self).__init__()
        # trilinear upsampling to save memory
        self.up = nn.Upsample(scale_factor=2, mode='trilinear',
                              align_corners=True)

        # transposed conv ran out of memory:
        # self.up = nn.ConvTranspose3d(in_ch//2, in_ch//2, 2, stride=2)
        self.conv = double_conv(in_ch, out_ch, kernel_size, stride)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # pad if needed
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        diff_z = x2.size(4) - x1.size(4)
        x1 = F.pad(x1, (diff_w // 2, diff_w - diff_w//2,
                        diff_h // 2, diff_h - diff_h//2,
                        diff_z // 2, diff_z - diff_z//2))

        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class Model(torch.nn.Module):
    def __init__(self, n_image_channels=3, n_initval_filters=1,
                 n_ode_filters=8, n_prenet_filters=8, rtol=1e-5, atol=1e-5,
                 zeros_init=True):
        super(Model, self).__init__()

        prenet_layers = [nn.Conv3d(n_image_channels, n_prenet_filters, 3, 1, 1)]
        prenet_layers.append(UNet(n_prenet_filters))
        self.prenet_layers = nn.Sequential(*prenet_layers)

        # neural ode layers
        from torchdiffeq import odeint_adjoint as odeint
        ode_func = ODEfunc(n_initval_filters, n_prenet_filters,
                           n_ode_filters, zeros_init)
        node_layers = [ODEBlock(ode_func, odeint, rtol, atol)]
        self.node_layers = nn.Sequential(*node_layers)

    def forward(self, x):
        init_value, img = x[:, 0][:, None], x[:, 1:]
        h_img = self.prenet_layers(img)
        self.node_layers[0].odefunc.condition = h_img
        node_out = self.node_layers(init_value)
        return node_out


model = Model().cuda()
print(model)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

init_val = torch.randn(1, 1, 64, 496, 496).cuda()
img = torch.randn(1, 3, 64, 496, 496).cuda()
x = torch.cat((init_val, img), dim=1)
y = torch.zeros(1, 1, 64, 496, 496).cuda()

optimizer.zero_grad()

print("forward")
y_hat = model(x)
model.node_layers[0].nfe = 0
loss = torch.mean((y - y_hat)**2)

print("backward")
loss.backward()
optimizer.step()
model.node_layers[0].nfe = 0

@rtqichen (Owner)

rtqichen commented May 26, 2019

Yeah, this seems to be a standard OOM error. Changing the output of ODEBlock

out = self.odeint(self.odefunc, x, self.integration_time, rtol=self.rtol, atol=self.atol)
return out[1]

into

return self.odefunc(torch.tensor(0.), x)

results in an OOM because a single backward pass can't even fit into memory.

This is because the input size you're using is incredibly large. Just to store these variables

init_val = torch.randn(1, 1, 64, 496, 496).cuda()
img = torch.randn(1, 3, 64, 496, 496).cuda()
x = torch.cat((init_val, img), dim=1)
y = torch.zeros(1, 1, 64, 496, 496).cuda()

requires 1.5GB of memory. As a result, the activations in the network are huge as well. It might be best to split up your data samples into chunks, downsample beforehand, or use multiple GPUs.
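
A quick back-of-envelope for the raw float32 storage of those four tensors, as a sketch (4 bytes per element; the usage reported by the driver is substantially higher once the CUDA context, caching allocator, gradients, and intermediate copies are counted):

# Sketch: raw bytes of each input tensor at the posted shapes.
import torch

shapes = {'init_val': (1, 1, 64, 496, 496), 'img': (1, 3, 64, 496, 496),
          'x': (1, 4, 64, 496, 496), 'y': (1, 1, 64, 496, 496)}
for name, shape in shapes.items():
    mib = torch.Size(shape).numel() * 4 / 2**20  # float32 elements -> MiB
    print(f'{name}: {mib:.0f} MiB')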

@rafaelvalle (Author)

rafaelvalle commented May 26, 2019

In that case, Ricky, why does the code I shared go through ODEfunc's forward method twice during loss.backward()?
I thought this would amount to two ODE evals, where the first works fine and the second OOMs...

@rtqichen (Owner)

rtqichen commented May 26, 2019

I think it's running out of memory during the ODE solving step: the solver stores at least 6 evaluations of the ODE function, then takes a step using a linear combination of them. It's able to make some evaluations of f, but then the ODE solver runs out of memory because it requires multiple tensors the same size as f to be stored in memory. The intermediate activations of f aren't kept in memory, but the output values are needed (see https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/rk_common.py#L52).

On my computer (with 12 GB), I couldn't even get the forward pass to finish. Note that a single NFE isn't a single ODE step: roughly 6 (maybe 7) NFEs need to finish before concluding it can take a single ODE step.
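
A simplified sketch of that Runge-Kutta step logic (not the library's exact code; the tableau coefficients alpha, beta, and c_sol are assumed given): every stage evaluation k_i stays alive until the step is formed, which is why dopri5 holds roughly seven state-sized tensors at once.

def rk_step(func, y0, f0, t0, dt, alpha, beta, c_sol):
    # Each stage k_i is a tensor the size of the state y0; all stages
    # must coexist to form the next state (and the error estimate).
    k = [f0]
    for alpha_i, beta_i in zip(alpha, beta):
        yi = y0 + dt * sum(b * k_ for b, k_ in zip(beta_i, k))
        k.append(func(t0 + alpha_i * dt, yi))  # one NFE per stage
    y1 = y0 + dt * sum(c * k_ for c, k_ in zip(c_sol, k))
    return y1, k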
