
Wrong gradient in RNNCell with ReLU? #11662

Closed
zichaow opened this issue Sep 13, 2018 · 9 comments

zichaow commented Sep 13, 2018

Issue description

It seems that nn.RNNCell with ReLU (on either device) and nn.RNN with ReLU on CPU do not calculate gradients properly. This can be observed by comparing the gradients computed by RNNs constructed with nn.RNNCell + ReLU on CPU and on GPU against those from nn.RNN + ReLU on CPU and on GPU. Of these 4 configurations, the gradients from nn.RNN + ReLU + GPU, which I think are the correct ones, differ from the gradients calculated by the other three configurations.

All other RNN modules (nn.LSTM, nn.GRU, nn.LSTMCell, nn.GRUCell) do not have this problem. It could be due to a bug in my test script, but a few of my friends have verified and reproduced this problem, so I'm opening an issue here.

Code example

See below. If you save the script as rnn_test.py, you can run it with: python3 rnn_test.py --model <model> --nonlinearity <nonlinearity> [--use_gpu]

For example:

python3 rnn_test.py --model rnn --nonlinearity relu --use_gpu
python3 rnn_test.py --model rnn --nonlinearity relu

The first line will give the following results:

use gpu.
rnn1 - loss: 0.182044
rnn1 -  sum of gradient norm is: 22.836142
---
rnn2 - loss: 0.182044
rnn2 - sum of gradient norm is: 44.010647

The second line will give the following results:

use cpu.
rnn1 - loss: 0.182044
rnn1 -  sum of gradient norm is: 22.836143
---
rnn2 - loss: 0.182044
rnn2 - sum of gradient norm is: 22.836143

Here, rnn1 is the RNN constructed with nn.RNNCell, and rnn2 is the RNN constructed with nn.RNN. As you can see, nn.RNN gives different gradients when running on CPU and on GPU. I think the gradient it computes on GPU (44.010647) is the correct one, which suggests that nn.RNNCell running on either CPU or GPU and nn.RNN running on CPU give the wrong gradients.

Code example:

import torch
import torch.nn as nn
from torch.autograd import Variable
import random
from copy import deepcopy
import argparse

parser = argparse.ArgumentParser(description="rnn cpu and gpu tests")
parser.add_argument('--use_gpu', action='store_true')
parser.add_argument('--model', type=str, default='rnn', choices=['rnn', 'gru', 'lstm'])
parser.add_argument('--nonlinearity', type=str, default='relu', choices=['relu', 'tanh'])
use_gpu = parser.parse_args().use_gpu
model = parser.parse_args().model
nonlinearity = parser.parse_args().nonlinearity
print('use gpu.') if use_gpu else print('use cpu.')

torch.cuda.manual_seed(0)
torch.manual_seed(0)
random.seed(0)

## manually create input, target, initial hidden state and criterion
input = Variable(torch.randn(100, 64, 1).cuda()) if use_gpu else Variable(torch.randn(100, 64, 1)) # dim = (seq_len, batch_size, input_size)
target = Variable(torch.randint(low=0, high=1, size=(64, ), dtype=torch.long).cuda()) if use_gpu else Variable(torch.randint(low=0, high=1, size=(64, ), dtype=torch.long))
hx0 = Variable(torch.randn(64, 20).cuda()) if use_gpu else Variable(torch.randn(64, 20)) # dim = (batch_size, hidden_size)
if model == 'lstm':
    c0 = Variable(torch.zeros(64, 20).cuda()) if use_gpu else Variable(torch.zeros(64, 20)) # dim = (batch_size, hidden_size)
criterion = nn.CrossEntropyLoss() # use cross entropy loss


## first network, its output and rnn gradients
if model=='rnn':
    rnn1 = nn.RNNCell(1, 20, nonlinearity=nonlinearity, bias=False).cuda() if use_gpu else nn.RNNCell(1, 20, nonlinearity=nonlinearity, bias=False)
elif model=='gru':
    rnn1 = nn.GRUCell(1, 20, bias=False).cuda() if use_gpu else nn.GRUCell(1, 20, bias=False)
elif model=='lstm':
    rnn1 = nn.LSTMCell(1, 20, bias=False).cuda() if use_gpu else nn.LSTMCell(1, 20, bias=False)
linear1 = nn.Linear(20, 2, bias=False).cuda() if use_gpu else nn.Linear(20, 2, bias=False)

# no bias and eye init to make sure two networks have the same parameters
for name, param in rnn1.named_parameters():
    if 'weight' in name:
        nn.init.eye_(param)
for name, param in linear1.named_parameters():
    if 'weight' in name:
        nn.init.eye_(param)

# run the net
hx1 = deepcopy(hx0)
if model=='lstm':
    c1 = deepcopy(c0)
output1 = []
for i in range(100):
    if model != 'lstm':
        hx1 = rnn1(input[i], hx1)
    else:
        hx1, c1 = rnn1(input[i], (hx1, c1))
    output1.append(hx1)

logit1 = linear1(hx1)
loss1 = criterion(logit1, target)

# calculate gradients and sum of gradient norms
grad_params1 = torch.autograd.grad(loss1, rnn1.parameters(), create_graph=True, retain_graph=True, allow_unused=True)
grad_norm1 = 0
for idx in range(len(grad_params1)):
    grad_norm1 += torch.norm(grad_params1[idx])
print('rnn1 - loss: %f' % (loss1))
print('rnn1 -  sum of gradient norm is: %f' % (grad_norm1))
print('---')

## second network, its output and rnn gradients
if model=='rnn':
    rnn2 = nn.RNN(1, 20, nonlinearity=nonlinearity, bias=False).cuda() if use_gpu else nn.RNN(1, 20, nonlinearity=nonlinearity, bias=False)
elif model=='gru':
    rnn2 = nn.GRU(1, 20, bias=False).cuda() if use_gpu else nn.GRU(1, 20, bias=False)
elif model=='lstm':
    rnn2 = nn.LSTM(1, 20, bias=False).cuda() if use_gpu else nn.LSTM(1, 20, bias=False)
linear2 = nn.Linear(20, 2, bias=False).cuda() if use_gpu else nn.Linear(20, 2, bias=False)

# same init as the first network
for name, param in rnn2.named_parameters():
    if 'weight' in name:
        nn.init.eye_(param)
for name, param in linear2.named_parameters():
    if 'weight' in name:
        nn.init.eye_(param)
        
# run the net 
if model != 'lstm':
    output2, hx2 = rnn2(input, hx0.unsqueeze(0))
else:
    output2, (hx2, _) = rnn2(input, (hx0.unsqueeze(0), c0.unsqueeze(0)))
logit2 = linear2(hx2[-1])
loss2 = criterion(logit2, target)

# calculate gradients and sum of gradient norms
grad_params2 = torch.autograd.grad(loss2, rnn2.parameters(), create_graph=True, retain_graph=True, allow_unused=True)
grad_norm2 = 0
for idx in range(len(grad_params2)):
    grad_norm2 += torch.norm(grad_params2[idx])
print('rnn2 - loss: %f' % (loss2))
print('rnn2 - sum of gradient norm is: %f' % (grad_norm2))

System Info

  • PyTorch or Caffe2: PyTorch
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • OS: Ubuntu 16.04
  • PyTorch version: 0.4.1
  • Python version: 3.5.2
  • CUDA/cuDNN version: 8.0
  • GPU models and configuration: TITAN X
  • GCC version (if compiling from source):
  • CMake version: 3.5.1
  • Versions of any other relevant libraries:
zichaow (Author) commented Sep 14, 2018

@ssnl does this appear to be a bug in the RNNCell with ReLU?

ssnl (Collaborator) commented Sep 17, 2018

@moonlightlane Yes it is a bug. You are correct.

t-vi (Collaborator) commented Sep 20, 2018

@ssnl No, it's not a bug:

It's a corner case of whether the grad of relu at 0 should be 0 or 1. CuDNN uses 1, we use 0.
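
A minimal check of the convention on the PyTorch side:

import torch

# PyTorch's autograd defines the gradient of relu at x == 0 to be 0;
# the fused cuDNN RNN kernel effectively uses 1 there instead.
x = torch.zeros(3, requires_grad=True)
torch.relu(x).sum().backward()
print(x.grad)  # tensor([0., 0., 0.])

A fuller demonstration on the RNN itself: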

import torch
import torch.nn as nn

torch.backends.cudnn.enabled = True
torch.cuda.manual_seed(0)
torch.manual_seed(0)
device = 'cuda'
T = 2
bs = 1
H = 3

input = torch.randn(T, bs, 1, device=device)
hx0 = torch.randn(bs, H, device=device)
rnn1 = torch.nn.RNNCell(1, H, nonlinearity='relu', bias=False).to(device)
nn.init.eye_(rnn1.weight_hh)
nn.init.eye_(rnn1.weight_ih)

hx1 = hx0
for i in range(T):
    hx1 = rnn1(input[i], hx1)
loss1 = hx1.sum()

grad_params1 = torch.autograd.grad(loss1, rnn1.parameters(), create_graph=True, retain_graph=True, allow_unused=True)


grad_norm1 = 0
for idx in range(len(grad_params1)):
    grad_norm1 += torch.norm(grad_params1[idx])
print('rnn1 - loss: %f' % (loss1))
print('rnn1 -  sum of gradient norm is: %f' % (grad_norm1))
print('---')

rnn2 = torch.nn.RNN(1, H, nonlinearity='relu', bias=False).to(device)
with torch.no_grad():
    rnn2.weight_hh_l0.copy_(rnn1.weight_hh)
    rnn2.weight_ih_l0.copy_(rnn1.weight_ih)

output2, hx2 = rnn2(input, hx0.unsqueeze(0))

loss2 = hx2.sum()

# calculate gradients and sum of gradient norms
grad_params2 = torch.autograd.grad(loss2, rnn2.parameters(), create_graph=True, retain_graph=True, allow_unused=True)
grad_norm2 = 0
for idx in range(len(grad_params2)):
    grad_norm2 += torch.norm(grad_params2[idx])
print('rnn2 - loss: %f' % (loss2))
print('rnn2 - sum of gradient norm is: %f' % (grad_norm2))


class myReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, relutype):
        if relutype:
            ctx._mask = (x >= 0)
        else:
            ctx._mask = (x > 0)
        return x.relu()
    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx._mask.to(grad_out.dtype), None

def myrnn(input, hx, w_ih, w_hh, relutype=False):
    for aninp in input:
        i_ = torch.mm(aninp, w_ih.t())
        h_ = torch.mm(hx, w_hh.t())
        hx = myReLU.apply(i_ + h_, relutype)
    return hx

w_ih = rnn1.weight_ih.detach().requires_grad_()
w_hh = rnn1.weight_hh.detach().requires_grad_()

loss3 = myrnn(input, hx0, w_ih, w_hh).sum()
grads3 = torch.autograd.grad(loss3, [w_ih, w_hh])
print ("3:",loss3.item(),"-",(torch.norm(grads3[0])+torch.norm(grads3[1]).sum()).item())
loss4 = myrnn(input, hx0, w_ih, w_hh, True).sum()
grads4 = torch.autograd.grad(loss4, [w_ih, w_hh])
print ("4:", loss4.item(), "-",(torch.norm(grads4[0])+torch.norm(grads4[1]).sum()).item())

ssnl (Collaborator) commented Sep 20, 2018

Haha @t-vi Good find! Thanks :) I think we can close this now!

ssnl closed this as completed Sep 20, 2018
zichaow (Author) commented Sep 20, 2018

@t-vi thanks so much!!! I almost switched to tensorflow for this experiment because of this ...

ssnl (Collaborator) commented Sep 20, 2018

@moonlightlane Well tf and pytorch call the same cudnn function, so ....

t-vi (Collaborator) commented Sep 20, 2018

@moonlightlane We'd love to keep you on PyTorch. 🤗
Your QG-Net README looks awfully cool, but I didn't try it out.

zichaow (Author) commented Oct 6, 2018 via email

ssnl (Collaborator) commented Oct 8, 2018

copied code here for md format:

import torch
import torch.nn as nn
from torch.autograd import Variable
import random
from copy import deepcopy
import argparse
from pdb import set_trace
import math


parser = argparse.ArgumentParser(description="rnn cpu and gpu tests")
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--use_gpu', action='store_true')
parser.add_argument('--has_bias', action='store_true') # without this flag = no bias
parser.add_argument('--use_cuda_relu', action='store_true') # without this flag = custom relu is the same as in pytorch; with flag = same as cuda
seed = parser.parse_args().seed
use_gpu = parser.parse_args().use_gpu
has_bias = parser.parse_args().has_bias
use_cuda_relu = parser.parse_args().use_cuda_relu

print('use gpu.') if use_gpu else print('use cpu.')
print('use cuda relu') if use_cuda_relu else print('use pytorch relu')

torch.cuda.manual_seed(seed)
torch.manual_seed(seed)
random.seed(seed)


## manually create input, target, initial hidden state and criterion
# dim = (seq_len, batch_size, input_size)
input = Variable(torch.randn(100, 64, 1)).cuda() if use_gpu \
        else Variable(torch.randn(100, 64, 1))
target = Variable(torch.randint(low=0, high=1, size=(64, ),
        dtype=torch.long)).cuda() if use_gpu else \
        Variable(torch.randint(low=0, high=1, size=(64, ), dtype=torch.long))
hx_init = Variable(torch.randn(64, 20)).cuda() if use_gpu \
        else Variable(torch.randn(64, 20)) # dim = (batch_size, hidden_size)
criterion = nn.CrossEntropyLoss() # use cross entropy loss

# define an init function
# no bias and eye init to make sure two networks have the same parameters
def param_init(rnn, linear):
    for name, param in rnn.named_parameters():
        if 'weight' in name:
            nn.init.eye_(param)
        else:
            nn.init.constant_(param, 1)
    for name, param in linear.named_parameters():
        if 'weight' in name:
            nn.init.eye_(param)
        else:
            nn.init.constant_(param, 1)

###########################################################################
## first network, its output and rnn gradients
rnn1 = nn.RNN(1, 20, nonlinearity='relu', bias=has_bias).cuda() \
        if use_gpu else nn.RNN(1, 20, nonlinearity='relu', bias=has_bias)
linear1 = nn.Linear(20, 2, bias=has_bias).cuda() \
        if use_gpu else nn.Linear(20, 2, bias=has_bias)

# param init
param_init(rnn1, linear1)

# run the net
rnn1.zero_grad()
linear1.zero_grad()

output1, hx1 = rnn1(input, hx_init.unsqueeze(0))
logit1 = linear1(output1[-1])
loss1 = criterion(logit1, target)

# calculate gradients and sum of gradient norms
grad_params1 = torch.autograd.grad(loss1, rnn1.parameters(),
        create_graph=True, retain_graph=True, allow_unused=True)
grad_norm1 = 0
for idx in range(len(grad_params1)):
    grad_norm1 += torch.norm(grad_params1[idx])

print('rnn1 - nn.RNN')
print('rnn1 - loss: %f' % (loss1))
print('rnn1 - sum of gradient norm is: %f' % (grad_norm1))
print('----')


###########################################################################
## second network, its output and rnn gradients
rnn2 = nn.RNNCell(1, 20, nonlinearity='relu', bias=has_bias).cuda() \
        if use_gpu else nn.RNNCell(1, 20, nonlinearity='relu', bias=has_bias)
linear2 = nn.Linear(20, 2, bias=has_bias).cuda() \
        if use_gpu else nn.Linear(20, 2, bias=has_bias)

param_init(rnn2, linear2)

# run the net
rnn2.zero_grad()
linear2.zero_grad()

hx2 = deepcopy(hx_init)
output2 = []
for i in range(100):
    hx2 = rnn2(input[i], hx2)
    output2.append(hx2)
logit2 = linear2(hx2)
loss2 = criterion(logit2, target)

# calculate gradients and sum of gradient norms
grad_params2 = torch.autograd.grad(loss2, rnn2.parameters(),
        create_graph=True, retain_graph=True, allow_unused=True)
grad_norm2 = 0
for idx in range(len(grad_params2)):
    grad_norm2 += torch.norm(grad_params2[idx])

print('rnn2 - nn.RNNCell')
print('rnn2 - loss: %f' % (loss2))
print('rnn2: sum of gradient norm is: %f' % (grad_norm2))
print('----')


###########################################################################
## third network, its output and rnn gradients

class myReLU(torch.autograd.Function):
    '''
    custom RNN cell
    '''
    @staticmethod
    def forward(ctx, x, relutype=use_cuda_relu):
        if relutype:
            ctx._mask = (x >= 0)
        else:
            ctx._mask = (x > 0)
        return x.relu()
    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * ctx._mask.to(grad_out.dtype), None

class ReLURNN(nn.Module):
    '''
    a ReLU RNN cell with ReLU implementation consistent with cuda
    '''
    def __init__(self, input_size, hidden_size, bias=True, nonlinearity='relu'):
        super(ReLURNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.nonlinearity = nonlinearity

        self.weight_ih = nn.Parameter(torch.Tensor(hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
        if bias:
            self.bias_ih = nn.Parameter(torch.Tensor(hidden_size))
            self.bias_hh = nn.Parameter(torch.Tensor(hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
        # self.reset_parameters()

    def extra_repr(self):
        s = '{input_size}, {hidden_size}'
        if 'bias' in self.__dict__ and self.bias is not True:
            s += ', bias={bias}'
        if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
            s += ', nonlinearity={nonlinearity}'
        return s.format(**self.__dict__)

    def check_forward_input(self, input):
        if input.size(1) != self.input_size:
            raise RuntimeError(
                "input has inconsistent input_size: got {}, expected {}".format(
                    input.size(1), self.input_size))

    def check_forward_hidden(self, input, hx, hidden_label=''):
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                "Input batch size {} doesn't match hidden{} batch size {}".format(
                    input.size(0), hidden_label, hx.size(0)))

        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                    hidden_label, hx.size(1), self.hidden_size))

    def forward(self, input, hx=None):
        self.check_forward_input(input)
        if hx is None:
            hx = input.new_zeros(input.size(0), self.hidden_size, requires_grad=False)
        # set_trace()
        self.check_forward_hidden(input, hx)

        # set_trace()
        # TODO take a look at the primitive implementation and implement this myself
        if self.nonlinearity == 'relu':
            # pass the relu-type flag explicitly so that backward's two
            # return values match the two inputs autograd saw in forward
            if self.bias:
                return myReLU.apply(torch.mm(input, self.weight_ih.t()) +
                                    self.bias_ih.expand(input.shape[0], self.bias_ih.shape[0]).contiguous() +
                                    torch.mm(hx, self.weight_hh.t()) +
                                    self.bias_hh.expand(input.shape[0], self.bias_hh.shape[0]).contiguous(),
                                    use_cuda_relu)
            else:
                return myReLU.apply(torch.mm(input, self.weight_ih.t())
                                    + torch.mm(hx, self.weight_hh.t()),
                                    use_cuda_relu)

rnn3 = ReLURNN(input_size=1, hidden_size=20, bias=has_bias).cuda() \
        if use_gpu else ReLURNN(input_size=1, hidden_size=20, bias=has_bias)
linear3 = nn.Linear(20, 2, bias=has_bias).cuda() if use_gpu \
        else nn.Linear(20, 2, bias=has_bias)

param_init(rnn3, linear3)

# run the net
rnn3.zero_grad()
linear3.zero_grad()

output3 = []
hx3 = deepcopy(hx_init)
for inp in input:
    hx3 = rnn3(inp, hx3)
    output3.append(hx3)

logit3 = linear3(output3[-1])
loss3 = criterion(logit3, target)

# calculate gradients and sum of gradient norms
grad_params3 = torch.autograd.grad(loss3, rnn3.parameters(), create_graph=True, retain_graph=True, allow_unused=True)
grad_norm3 = 0
for idx in range(len(grad_params3)):
    grad_norm3 += torch.norm(grad_params3[idx])

print('rnn3 - ReLURNN + myReLU')
print('rnn3 - loss: %f' % (loss3))
print('rnn3: sum of gradient norm is: %f' % (grad_norm3))

# set_trace()
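
Saved as, say, relu_rnn_test.py (the filename here is arbitrary), the combinations can be compared with:

python3 relu_rnn_test.py                       # CPU, PyTorch relu'(0) = 0
python3 relu_rnn_test.py --use_cuda_relu       # CPU, cuDNN-style relu'(0) = 1
python3 relu_rnn_test.py --use_gpu
python3 relu_rnn_test.py --use_gpu --use_cuda_relu

The expectation is that rnn3 matches whichever of rnn1 (nn.RNN) and rnn2 (nn.RNNCell) shares its relu'(0) convention: on GPU, nn.RNN goes through cuDNN (gradient 1 at zero), while nn.RNNCell always uses PyTorch's own ReLU backward (gradient 0 at zero).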
