OOM during backward pass on a model with ~600k parameters #54
Hi Rafael, the most memory-expensive operation is just this line: https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/adjoint.py#L41. Can you make sure that a single forward and backward pass of this network fits into memory? Otherwise, the only other cause I can think of is a batch size that is too large, in which case you should reduce it.
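The expensive step is a vector-Jacobian product. During the backward solve, each evaluation of the adjoint dynamics re-runs f with gradients enabled and pulls the adjoint state back through it with `torch.autograd.grad`. The sketch below is a simplified, hypothetical version of that step (`TinyODEFunc` and `adjoint_vjp` are illustrative names, not torchdiffeq's code). It shows why one full forward and backward pass of f has to fit in memory at once.

```python
import torch
import torch.nn as nn


class TinyODEFunc(nn.Module):
    """Hypothetical stand-in for the ODE function f(t, y)."""

    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, t, y):
        return torch.tanh(self.lin(y))


def adjoint_vjp(func, t, y, adj_y):
    """Simplified sketch of one adjoint-dynamics evaluation.

    Re-runs f with autograd enabled, then pulls -adj_y back through it to get
    -adj_y^T df/dy and -adj_y^T df/dtheta. Peak memory is roughly that of one
    full forward + backward pass of f.
    """
    params = tuple(func.parameters())
    with torch.enable_grad():
        y = y.detach().requires_grad_(True)
        f_eval = func(t, y)  # records the activation graph of f
        vjp_y, *vjp_params = torch.autograd.grad(
            f_eval, (y,) + params, -adj_y, allow_unused=True
        )
    return f_eval.detach(), vjp_y, vjp_params


torch.manual_seed(0)
func = TinyODEFunc(4)
y = torch.randn(2, 4)
adj = torch.randn(2, 4)
f_eval, vjp_y, vjp_params = adjoint_vjp(func, torch.tensor(0.0), y, adj)
print(vjp_y.shape, [p.shape for p in vjp_params])
```

For a 3D conv network on full-resolution volumes, the graph recorded inside `adjoint_vjp` is the entire set of activations of f. That is the allocation that fails when a single pass does not fit.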
Let me take a look at it and get back to you.
Although I was able to execute multiple forward (~40) and backward (~5) passes after changing the Neural ODE layers to something very small (see the end of this reply), the grad operation inside the …
Hmm, this is running out of memory when computing the line I linked before. When you say forward and backward passes, do you mean of this network or of odeint? I'd try an even smaller model and check for a memory leak first. If there isn't one, then it's simply using more memory than is available on a single GPU.
By forward and backward passes I mean forward evaluations of the NODE and backward evaluations (adjoint method) of the ODE, that is, odeint. The memory footprint of the forward and backward ODE evaluations should not change from one iteration to the next, right?
Right, it shouldn't change. If it does, there might be a memory leak, but I haven't observed that happening yet, at least with PyTorch 1.0. Are there any other components with variable memory use, such as variable-length inputs?
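One way to run the leak check suggested above is to record the allocated memory after every iteration and see whether it keeps climbing. Below is a minimal helper for that. The name `looks_like_leak` and its thresholds are illustrative assumptions, not part of PyTorch or torchdiffeq.

```python
def looks_like_leak(readings, warmup=2, min_growth_bytes=1 << 20):
    """Heuristic leak check on per-iteration memory readings, in bytes.

    Ignores the first `warmup` readings (allocator caching and optimizer state
    usually settle during the first iterations), then reports a leak if memory
    never drops afterwards and grows by more than `min_growth_bytes` in total.
    """
    steady = list(readings)[warmup:]
    if len(steady) < 2:
        return False
    never_drops = all(later >= earlier for earlier, later in zip(steady, steady[1:]))
    return never_drops and steady[-1] - steady[0] > min_growth_bytes


# Flat after warm-up: no leak.
print(looks_like_leak([100, 5_000, 8_000, 8_000, 8_000]))  # False
# Grows by 1 MiB every iteration: suspicious.
print(looks_like_leak([0, 0, 1 << 20, 2 << 20, 3 << 20, 4 << 20]))  # True
```

In a training loop, you could feed it by appending `torch.cuda.memory_allocated()` to a list after each `optimizer.step()`. `torch.cuda.max_memory_allocated()` reports the peak instead, which is the number that matters for an OOM.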
Not really; the inputs are images of fixed size.
Any suggestions on how to debug this?
I can't tell based on the current information, as it just looks like a standard OOM error. Do you have a reproducible script you can share?
Let me put together a small reproducible script and share it with you.
Here it is! As minimal as possible.

```python
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

tanh = nn.Tanh
act = partial(nn.ReLU, inplace=True)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def norm(dim, n_groups=32):
    """GroupNorm"""
    return nn.GroupNorm(min(n_groups, dim), dim)


class ResBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(ResBlock, self).__init__()
        self.norm1 = norm(inplanes)
        self.act1 = act()
        self.downsample = downsample
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.norm2 = norm(planes)
        self.act2 = act()
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        shortcut = x
        out = self.act1(self.norm1(x))
        if self.downsample is not None:
            shortcut = self.downsample(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.act2(out)
        out = self.conv2(out)
        return out + shortcut


class ConcatConv3d(nn.Module):
    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0,
                 dilation=1, groups=1, bias=True, transpose=False,
                 zeros_init=False):
        super(ConcatConv3d, self).__init__()
        module = nn.ConvTranspose3d if transpose else nn.Conv3d
        self._layer = module(dim_in + 1, dim_out, kernel_size=ksize,
                             stride=stride, padding=padding, dilation=dilation,
                             groups=groups, bias=bias)
        if zeros_init:
            self._layer.weight.data.zero_()
            self._layer.bias.data.zero_()

    def forward(self, t, x):
        # append the time t as an extra constant input channel
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)


class ODEfunc(nn.Module):
    def __init__(self, dim_init_value, dim_cond, dim, zeros_init=False):
        super(ODEfunc, self).__init__()
        self.tanh = tanh()
        # init_value embedding
        self.conv_init_value1 = ConcatConv3d(dim_init_value, dim, 3, 1, 1, zeros_init=zeros_init)
        self.norm_init_value1 = norm(dim)
        self.conv_init_value2 = ConcatConv3d(dim, dim, 3, 1, 1, zeros_init=zeros_init)
        self.norm_init_value2 = norm(dim)
        # image embedding
        self.norm_img_pre = norm(dim_cond)
        self.conv_img = ConcatConv3d(dim_cond, dim, 3, 1, 1, zeros_init=zeros_init)
        self.norm_img = norm(dim)
        # first input is concatenation of h_init_value and h_img
        self.conv1 = ConcatConv3d(2*dim, dim, 3, 1, 1, zeros_init=zeros_init)
        self.norm_conv1 = norm(dim)
        self.conv2 = ConcatConv3d(dim, dim, 3, 1, 1, zeros_init=zeros_init)
        self.norm_conv2 = norm(dim)
        # conv_out
        self.conv_out = ConcatConv3d(dim, dim_init_value, 1, 1, 0, zeros_init=zeros_init)
        # housekeeping: number of function evaluations
        self.nfe = 0

    def forward(self, t, init_value):
        self.nfe += 1
        print(self.nfe)
        # compute init_value hidden state
        h_init_value = self.conv_init_value1(t, init_value)
        h_init_value = self.norm_init_value1(h_init_value)
        h_init_value = self.tanh(h_init_value)
        h_init_value = self.conv_init_value2(t, h_init_value)
        h_init_value = self.norm_init_value2(h_init_value)
        h_init_value = self.tanh(h_init_value)
        # compute image hidden state, h_img is conv out, no norm nor activation
        h_img = self.condition
        h_img = self.norm_img_pre(h_img)
        h_img = self.tanh(h_img)
        h_img = self.conv_img(t, h_img)
        h_img = self.norm_img(h_img)
        h_img = self.tanh(h_img)
        # compute output based on init_value and image hidden states
        V = torch.cat((h_init_value, h_img), dim=1)
        V = self.conv1(t, V)
        V = self.norm_conv1(V)
        V = self.tanh(V)
        V = self.conv2(t, V)
        V = self.norm_conv2(V)
        V = self.tanh(V)
        # project back to init_value dimension
        V = self.conv_out(t, V)
        return V


class ODEBlock(nn.Module):
    def __init__(self, odefunc, odeint, rtol, atol):
        super(ODEBlock, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()
        self.odeint = odeint
        self.rtol = rtol
        self.atol = atol

    def forward(self, x):
        self.integration_time = self.integration_time.type_as(x).to(x.device)
        out = self.odeint(self.odefunc, x, self.integration_time,
                          rtol=self.rtol, atol=self.atol)
        return out[1]

    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


class UNet(nn.Module):
    # adapted from https://github.com/milesial/Pytorch-UNet/
    def __init__(self, n_in_channels):
        super(UNet, self).__init__()
        self.in_norm = norm(8)
        self.in_act = nn.ReLU(inplace=True)
        self.down1 = down(8, 16, kernel_size=3, stride=1)
        self.down2 = down(16, 32, kernel_size=3, stride=1)
        self.down3 = down(32, 64, kernel_size=3, stride=1)
        self.down4 = down(64, 64, kernel_size=3, stride=1)
        self.up0 = up(128, 32, kernel_size=3, stride=1)
        self.up1 = up(64, 16, kernel_size=3, stride=1)
        self.up2 = up(32, 8, kernel_size=3, stride=1)
        self.up3 = up(16, n_in_channels, kernel_size=3, stride=1)
        self.out = nn.Conv3d(n_in_channels, n_in_channels, kernel_size=1,
                             stride=1, padding=0)
        self.out_norm = norm(n_in_channels)

    def forward(self, x1):
        x1 = self.in_norm(x1)
        x1 = self.in_act(x1)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up0(x5, x4)
        x = self.up1(x, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.out(x)
        x = self.out_norm(x)
        return x


class double_conv(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size, stride=stride,
                      padding=(kernel_size-1)//2),
            nn.BatchNorm3d(out_ch),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size, stride=stride,
                      padding=(kernel_size-1)//2),
            nn.BatchNorm3d(out_ch),
            nn.LeakyReLU(0.1, inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride):
        super(down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool3d(2),
            double_conv(in_ch, out_ch, kernel_size, stride)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x


class up(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride):
        super(up, self).__init__()
        # trilinear upsampling to save memory
        self.up = nn.Upsample(scale_factor=2, mode='trilinear',
                              align_corners=True)
        # a transposed convolution ran out of memory:
        # self.up = nn.ConvTranspose3d(in_ch//2, in_ch//2, 2, stride=2)
        self.conv = double_conv(in_ch, out_ch, kernel_size, stride)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # pad if needed; F.pad takes (left, right) pairs starting from the
        # last dimension, i.e. size(4) first, then size(3), then size(2)
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        diff_z = x2.size(4) - x1.size(4)
        x1 = F.pad(x1, (diff_z // 2, diff_z - diff_z // 2,
                        diff_w // 2, diff_w - diff_w // 2,
                        diff_h // 2, diff_h - diff_h // 2))
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class Model(torch.nn.Module):
    def __init__(self, n_image_channels=3, n_initval_filters=1,
                 n_ode_filters=8, n_prenet_filters=8, rtol=1e-5, atol=1e-5,
                 zeros_init=True):
        super(Model, self).__init__()
        prenet_layers = [nn.Conv3d(n_image_channels, n_prenet_filters, 3, 1, 1)]
        prenet_layers.append(UNet(n_prenet_filters))
        self.prenet_layers = nn.Sequential(*prenet_layers)
        # neural ode layers
        from torchdiffeq import odeint_adjoint as odeint
        ode_func = ODEfunc(n_initval_filters, n_prenet_filters,
                           n_ode_filters, zeros_init)
        node_layers = [ODEBlock(ode_func, odeint, rtol, atol)]
        self.node_layers = nn.Sequential(*node_layers)

    def forward(self, x):
        init_value, img = x[:, 0][:, None], x[:, 1:]
        h_img = self.prenet_layers(img)
        self.node_layers[0].odefunc.condition = h_img
        node_out = self.node_layers(init_value)
        return node_out


model = Model().cuda()
print(model)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

init_val = torch.randn(1, 1, 64, 496, 496).cuda()
img = torch.randn(1, 3, 64, 496, 496).cuda()
x = torch.cat((init_val, img), dim=1)
y = torch.zeros(1, 1, 64, 496, 496).cuda()

optimizer.zero_grad()
print("forward")
y_hat = model(x)
model.node_layers[0].nfe = 0
loss = torch.mean((y - y_hat)**2)
print("backward")
loss.backward()
optimizer.step()
model.node_layers[0].nfe = 0
```
Yeah, this seems to be a standard OOM error. Changing the output of ODEBlock into … results in an OOM, because even a single backward pass can't fit into memory. This is because the input size you're using is incredibly large: just storing these variables … requires 1.5GB of memory. As a result, the activations in the network are huge as well. It might be best to split your data samples into chunks, downsample beforehand, or use multiple GPUs.
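The per-tensor sizes are easy to estimate. Below is a rough sketch that assumes float32 (4 bytes per element) and the shapes from the script above; `tensor_nbytes` is a hypothetical helper. Every full-resolution 8-channel activation (the UNet's first level and each feature map inside `ODEfunc`) costs close to half a GiB. That figure does not yet include gradients, the solver's stage tensors, or the adjoint state.

```python
from math import prod

MiB = 2 ** 20


def tensor_nbytes(shape, bytes_per_elem=4):
    """Bytes for a dense tensor of the given shape (4 bytes per element = float32)."""
    return prod(shape) * bytes_per_elem


one_channel = tensor_nbytes((1, 1, 64, 496, 496))   # init_val or y in the script
four_channel = tensor_nbytes((1, 4, 64, 496, 496))  # x = cat(init_val, img)
feature_map = tensor_nbytes((1, 8, 64, 496, 496))   # one 8-channel activation
depth_chunk = tensor_nbytes((1, 1, 16, 496, 496))   # one of 4 chunks along depth

print(one_channel / MiB)   # 60.0625
print(feature_map / MiB)   # 480.5
print(one_channel // depth_chunk)  # 4
```

Splitting each volume into 4 chunks along depth divides every one of these numbers by 4, which is the effect the chunking suggestion relies on.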
In that case, Ricky, why does the code I shared go through the ODEfunc's forward method twice during loss.backward()?
I think it's running out of memory during the ODE solving step: the solver stores at least 6 evaluations of the ODE function, then takes a step using a linear combination of them. It's able to make some evaluations of f, but then runs out of memory because it has to keep several tensors the same size as the output of f in memory at once. The intermediate layers of f aren't kept in memory, but the output values are needed (see https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/rk_common.py#L52). On my computer (with 12GB), I couldn't even get the forward pass to finish. Keep in mind that a single NFE isn't a single ODE step: roughly 6 (maybe 7) NFEs have to finish before the solver can take a single step.
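The following sketch shows why solver memory scales with the number of stages rather than with the size of f's intermediate layers. It implements a minimal explicit Runge-Kutta step in NumPy. For brevity it uses the classic 4-stage RK4 tableau, which is a swapped-in example and not torchdiffeq's code. torchdiffeq's default `dopri5` solver uses a 7-stage Dormand-Prince tableau, of which 6 stages are new evaluations per step thanks to its first-same-as-last property. The storage pattern is the same in both cases: every stage derivative has the shape of the state and stays alive until the final linear combination.

```python
import numpy as np

# Classic RK4 Butcher tableau (illustrative; dopri5 has 7 stages).
A = [[], [0.5], [0.0, 0.5], [0.0, 0.0, 1.0]]
B = [1 / 6, 1 / 3, 1 / 3, 1 / 6]
C = [0.0, 0.5, 0.5, 1.0]


def rk_step(f, t, y, h):
    """One explicit Runge-Kutta step; returns (y_next, stages)."""
    stages = []  # each k_i has y's shape and is kept until the final combination
    for a_row, c in zip(A, C):
        increment = sum((a * k for a, k in zip(a_row, stages)), np.zeros_like(y))
        stages.append(f(t + c * h, y + h * increment))
    y_next = y + h * sum(b * k for b, k in zip(B, stages))
    return y_next, stages


calls = {"nfe": 0}


def f(t, y):
    """dy/dt = y, counting function evaluations."""
    calls["nfe"] += 1
    return y.copy()


y0 = np.ones((1, 1, 16, 32, 32), dtype=np.float32)
y1, stages = rk_step(f, 0.0, y0, 0.1)
stage_bytes = sum(k.nbytes for k in stages)
print(calls["nfe"], stage_bytes // y0.nbytes)  # 4 4
```

With the full-resolution volumes from the repro script, each of those stage tensors has the size of the ODE state. A 6- or 7-stage solver therefore holds several such tensors at once, on top of f's own activations.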
Hey Ricky,
I'm running out of memory during the backward pass on a 16GB GPU when running the adjoint method with rtol=1e-5, atol=1e-5 and a network with 631,058 parameters.
I'm not sure why this happens, given that augmented_dynamics runs within torch.no_grad() and the tensors saved during the forward pass should not be that large. Any thoughts on what is happening and how to debug it?
The model is a 3D U-Net (UNet) that feeds into a few 3D conv layers (node_layers).
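A note on the torch.no_grad() point: a no_grad context around the backward solve does not make each step cheap. The torch.autograd.grad call linked earlier in this thread needs a graph of f, so gradients must be switched back on locally inside the dynamics. Below is a minimal sketch of that nesting in plain PyTorch, using toy tensors `w` and `x` rather than torchdiffeq code.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)  # stands in for a parameter of f
x = torch.randn(3)                      # stands in for the ODE state

with torch.no_grad():
    # Directly under no_grad, no graph is recorded ...
    y_outer = (w * x).sum()
    with torch.enable_grad():
        # ... but a nested enable_grad block records one again, keeping the
        # tensors needed for the backward pass alive until autograd.grad runs.
        y_inner = (w * x).sum()
        (grad_w,) = torch.autograd.grad(y_inner, w)

print(y_outer.requires_grad, y_inner.requires_grad)  # False True
```

The same applies to a large 3D conv network: whatever runs inside the re-enabled block keeps its saved activations until the gradient is computed. The outer no_grad only prevents a graph from spanning across solver steps.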