Wrong gradient in RNNCell with ReLU? #11662
Comments
@ssnl does this appear to be a bug in the RNNCell with ReLU?

@moonlightlane Yes, it is a bug. You are correct.

@ssnl No, it's not a bug: it's a corner case of whether the gradient of ReLU at 0 should be 0 or 1. CuDNN uses 1; we use 0.
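The distinction t-vi describes can be checked directly. The following sketch (an editor's illustration, not part of the thread) compares PyTorch's subgradient choice at 0 with a mask-based ReLU that emulates the cuDNN convention:

```python
import torch

# PyTorch's ReLU backward uses the subgradient 0 at x == 0:
x = torch.zeros(3, requires_grad=True)
torch.relu(x).sum().backward()
print(x.grad.sum().item())  # 0.0

# A mask that includes 0 (x >= 0) emulates cuDNN's choice of 1:
x2 = torch.zeros(3, requires_grad=True)
(x2 * (x2 >= 0).float()).sum().backward()
print(x2.grad.sum().item())  # 3.0 -- gradient 1 per element at exactly 0
```

Away from 0 the two conventions agree, so the discrepancy only shows up when a pre-activation lands exactly on 0.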
Haha @t-vi good find! Thanks :) I think we can close this now!

@t-vi Thanks so much!!! I almost switched to TensorFlow for this experiment because of this...

@moonlightlane Well, TF and PyTorch call the same cuDNN function, so....

@moonlightlane We'd love to keep you on PyTorch. 🤗
Thanks! That repo has a few data-loading issues that someone pointed out a while back. I was working on a deadline these last few weeks so I didn't have time to fix them, but I'm going to fix them in the next few days.
I took a closer look at the inconsistent ReLU implementations in PyTorch and CUDA. For RNNs without bias, the custom ReLU implementation works fine, and the outputs are the same as those of nn.RNN.
However, there seem to be numerical inconsistencies between PyTorch's matrix-calculus implementation and the C implementation underlying both nn.RNNCell and nn.RNN. This is more obvious when bias is added to the RNN models. I have a piece of code to demonstrate it; see below. If you run
python3 rnn_test.py --use_cuda_relu --use_gpu --has_bias --seed 123456
you'll see that nn.RNN and the custom ReLURNN actually calculate different gradients (the sums of gradient norms differ slightly). In fact, some hidden states are also calculated differently, although only a few of them.
The hidden-state calculations are the same between nn.RNN and nn.RNNCell, although the backward passes differ (different gradients) because of the different ReLU gradient at 0 that you pointed out a while ago.
You can try various flags (different seeds, with or without bias, CPU or GPU, with or without the CUDA-consistent ReLU implementation), and in many cases the outputs of nn.RNN and the custom ReLURNN differ.
Should I be concerned about the difference in numerical calculations? Or perhaps there is a bug in my code? I haven't tested actual training on a dataset, though.
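One plausible (hedged) explanation for the small discrepancies away from 0: float32 addition is not associative, so a fused GEMM inside nn.RNN and step-by-step adds in a Python loop can legitimately sum the same terms in different orders and get slightly different results. A minimal demonstration:

```python
import torch

# Float32 addition is not associative: the same three terms summed
# in different orders give different results.
a = torch.tensor(1e8, dtype=torch.float32)
b = torch.tensor(-1e8, dtype=torch.float32)
c = torch.tensor(1.0, dtype=torch.float32)
print(((a + b) + c).item())  # 1.0
print((a + (b + c)).item())  # 0.0 -- the 1.0 is lost in rounding
```

Differences of this kind are expected between any two implementations that order their reductions differently; they are distinct from the systematic ReLU-at-0 issue.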
—————— code (also in attachment) ————————————
import torch
import torch.nn as nn
from torch.autograd import Variable
import random
from copy import deepcopy
import argparse
from pdb import set_trace
parser = argparse.ArgumentParser(description="rnn cpu and gpu tests")
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--use_gpu', action='store_true')
parser.add_argument('--has_bias', action='store_true') # without this flag = no bias
parser.add_argument('--use_cuda_relu', action='store_true') # without this flag = custom relu is the same as in pytorch; with flag = same as cuda
args = parser.parse_args()  # parse once and reuse
seed = args.seed
use_gpu = args.use_gpu
has_bias = args.has_bias
use_cuda_relu = args.use_cuda_relu
print('use gpu.' if use_gpu else 'use cpu.')
print('use cuda relu' if use_cuda_relu else 'use pytorch relu')
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)
random.seed(seed)
## manually create input, target, initial hidden state and criterion
# dim = (seq_len, batch_size, input_size)
input = Variable(torch.randn(100, 64, 1)).cuda() if use_gpu \
else Variable(torch.randn(100, 64, 1))
# note: high is exclusive in torch.randint, so high=1 makes every target class 0
target = Variable(torch.randint(low=0, high=1, size=(64, ),
dtype=torch.long)).cuda() if use_gpu else \
Variable(torch.randint(low=0, high=1, size=(64, ), dtype=torch.long))
hx_init = Variable(torch.randn(64, 20)).cuda() if use_gpu \
else Variable(torch.randn(64, 20)) # dim = (batch_size, hidden_size)
criterion = nn.CrossEntropyLoss() # use cross entropy loss
# define an init function
# no bias and eye init to make sure two networks have the same parameters
def param_init(rnn, linear):
for name, param in rnn.named_parameters():
if 'weight' in name:
nn.init.eye_(param)
else:
nn.init.constant_(param, 1)
for name, param in linear.named_parameters():
if 'weight' in name:
nn.init.eye_(param)
else:
nn.init.constant_(param, 1)
###########################################################################
## first network, its output and rnn gradients
rnn1 = nn.RNN(1, 20, nonlinearity='relu', bias=has_bias).cuda() \
if use_gpu else nn.RNN(1, 20, nonlinearity='relu', bias=has_bias)
linear1 = nn.Linear(20, 2, bias=has_bias).cuda() \
if use_gpu else nn.Linear(20, 2, bias=has_bias)
# param init
param_init(rnn1, linear1)
# run the net
rnn1.zero_grad()
linear1.zero_grad()
output1, hx1 = rnn1(input, hx_init.unsqueeze(0))
logit1 = linear1(output1[-1])
loss1 = criterion(logit1, target)
# calculate gradients and sum of gradient norms
grad_params1 = torch.autograd.grad(loss1, rnn1.parameters(),
create_graph=True, retain_graph=True, allow_unused=True)
grad_norm1 = 0
for idx in range(len(grad_params1)):
grad_norm1 += torch.norm(grad_params1[idx])
print('rnn1 - nn.RNN')
print('rnn1 - loss: %f' % (loss1))
print('rnn1 - sum of gradient norm is: %f' % (grad_norm1))
print('----')
###########################################################################
## second network, its output and rnn gradients
rnn2 = nn.RNNCell(1, 20, nonlinearity='relu', bias=has_bias).cuda() \
if use_gpu else nn.RNNCell(1, 20, nonlinearity='relu', bias=has_bias)
linear2 = nn.Linear(20, 2, bias=has_bias).cuda() \
if use_gpu else nn.Linear(20, 2, bias=has_bias)
param_init(rnn2, linear2)
# run the net
rnn2.zero_grad()
linear2.zero_grad()
hx2 = deepcopy(hx_init)
output2 = []
for i in range(100):
hx2 = rnn2(input[i], hx2)
output2.append(hx2)
logit2 = linear2(hx2)
loss2 = criterion(logit2, target)
# calculate gradients and sum of gradient norms
grad_params2 = torch.autograd.grad(loss2, rnn2.parameters(),
create_graph=True, retain_graph=True, allow_unused=True)
grad_norm2 = 0
for idx in range(len(grad_params2)):
grad_norm2 += torch.norm(grad_params2[idx])
print('rnn2 - nn.RNNCell')
print('rnn2 - loss: %f' % (loss2))
print('rnn2: sum of gradient norm is: %f' % (grad_norm2))
print('----')
###########################################################################
## third network, its output and rnn gradients
class myReLU(torch.autograd.Function):
'''
custom ReLU whose gradient at exactly 0 is 1 (CUDA convention) or 0 (PyTorch convention)
'''
@staticmethod
def forward(ctx, x, relutype=use_cuda_relu):
if relutype:
ctx._mask = (x >= 0)
else:
ctx._mask = (x > 0)
return x.relu()
@staticmethod
def backward(ctx, grad_out):
return grad_out * ctx._mask.to(grad_out.dtype), None
class ReLURNN(nn.Module):
'''
a ReLU RNN cell with ReLU implementation consistent with cuda
'''
def __init__(self, input_size, hidden_size, bias=True, nonlinearity='relu'):
super(ReLURNN, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.nonlinearity = nonlinearity
self.weight_ih = nn.Parameter(torch.Tensor(hidden_size, input_size))
self.weight_hh = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
if bias:
self.bias_ih = nn.Parameter(torch.Tensor(hidden_size))
self.bias_hh = nn.Parameter(torch.Tensor(hidden_size))
else:
self.register_parameter('bias_ih', None)
self.register_parameter('bias_hh', None)
# self.reset_parameters()
def extra_repr(self):
s = '{input_size}, {hidden_size}'
if 'bias' in self.__dict__ and self.bias is not True:
s += ', bias={bias}'
if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
s += ', nonlinearity={nonlinearity}'
return s.format(**self.__dict__)
def check_forward_input(self, input):
if input.size(1) != self.input_size:
raise RuntimeError(
"input has inconsistent input_size: got {}, expected {}".format(
input.size(1), self.input_size))
def check_forward_hidden(self, input, hx, hidden_label=''):
if input.size(0) != hx.size(0):
raise RuntimeError(
"Input batch size {} doesn't match hidden{} batch size {}".format(
input.size(0), hidden_label, hx.size(0)))
if hx.size(1) != self.hidden_size:
raise RuntimeError(
"hidden{} has inconsistent hidden_size: got {}, expected {}".format(
hidden_label, hx.size(1), self.hidden_size))
def forward(self, input, hx=None):
self.check_forward_input(input)
if hx is None:
hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
# set_trace()
self.check_forward_hidden(input, hx)
# set_trace()
# TODO take a look at the primitive implementation and implement this myself
if self.nonlinearity == 'relu':
if self.bias:
# return myReLU.apply(torch.mm(input, self.weight_ih.t()) + self.bias_ih
# + torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
return myReLU.apply(torch.mm(input, self.weight_ih.t()) +
self.bias_ih.expand(input.shape[0], self.bias_ih.shape[0]).contiguous() +
torch.mm(hx, self.weight_hh.t()) +
self.bias_hh.expand(input.shape[0], self.bias_hh.shape[0]).contiguous())
else:
return myReLU.apply(torch.mm(input, self.weight_ih.t())
+ torch.mm(hx, self.weight_hh.t()))
rnn3 = ReLURNN(input_size=1, hidden_size=20, bias=has_bias).cuda() \
if use_gpu else ReLURNN(input_size=1, hidden_size=20, bias=has_bias)
linear3 = nn.Linear(20, 2, bias=has_bias).cuda() if use_gpu \
else nn.Linear(20, 2, bias=has_bias)
param_init(rnn3, linear3)
# run the net
rnn3.zero_grad()
linear3.zero_grad()
output3 = []
hx3 = deepcopy(hx_init)
for inp in input:
hx3 = rnn3(inp, hx3)
output3.append(hx3)
logit3 = linear3(output3[-1])
loss3 = criterion(logit3, target)
# calculate gradients and sum of gradient norms
grad_params3 = torch.autograd.grad(loss3, rnn3.parameters(), create_graph=True, retain_graph=True, allow_unused=True)
grad_norm3 = 0
for idx in range(len(grad_params3)):
grad_norm3 += torch.norm(grad_params3[idx])
print('rnn3 - ReLURNN + myReLU')
print('rnn3 - loss: %f' % (loss3))
print('rnn3: sum of gradient norm is: %f' % (grad_norm3))
# set_trace()
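The corner case can also be isolated without the full script. This is a minimal editor's sketch (not part of the original report): with the weights zeroed, every pre-activation is exactly 0, so on CPU nn.RNNCell's ReLU backward should yield zero weight gradients, whereas per the discussion above cuDNN's convention would not.

```python
import torch
import torch.nn as nn

cell = nn.RNNCell(1, 4, nonlinearity='relu', bias=False)
with torch.no_grad():
    cell.weight_ih.zero_()
    cell.weight_hh.zero_()  # every pre-activation is now exactly 0

x = torch.ones(2, 1)   # (batch, input_size)
h = torch.zeros(2, 4)  # (batch, hidden_size)
cell(x, h).sum().backward()

# PyTorch's ReLU uses gradient 0 at 0, so both weight grads vanish:
print(cell.weight_ih.grad.abs().sum().item())  # 0.0
print(cell.weight_hh.grad.abs().sum().item())  # 0.0
```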
@moonlightlane Your QG-Net readme looks awfully cool, but I didn't try it out.
Issue description
It seems that nn.RNNCell + ReLU and nn.RNN + ReLU + CPU do not calculate the gradient properly. This can be observed from the different gradients calculated by RNNs constructed with nn.RNNCell + ReLU on CPU and on GPU, and nn.RNN + ReLU on CPU and GPU. Of these 4 configurations, the gradients from nn.RNN + ReLU + GPU, which I think are the correct gradients, differ from the gradients calculated by the other three configurations.
All other RNN modules (nn.LSTM, nn.GRU, nn.LSTMCell, nn.GRUCell) do not have this problem. It could be due to a bug in my test script, but a few of my friends have verified and reproduced this problem, so I'm opening an issue here.
Code example
See below. If you save the script as rnn_test.py, you can run it with:
python3 rnn_test.py --model <model> --nonlinearity <nonlinearity> --use_gpu
For example:
The first line will give the following results:
The second line will give the following results:
Here, rnn1 is the RNN constructed with nn.RNNCell, and rnn2 is the RNN constructed with nn.RNN. As you can see, nn.RNN gives different gradients when running on CPU and on GPU. But I think the gradient when it runs on GPU is the correct one (44.010647), which suggests that nn.RNNCell running on either CPU or GPU, and nn.RNN running on CPU, give the wrong gradients.
Code example:
System Info