Conversation

@tonybeltramelli (Contributor)

This PR introduces two updates.

1. Add a non-linearity param to the He initialization scheme in torch.nn.init
Problem solved:
The function calculate_gain can take an argument specifying the type of non-linearity used, but it was not possible to pass this argument directly to the He / Kaiming weight initialization functions.

2. Add a util to clip gradient values in torch.nn.utils.clip_grad
Problem solved:
DL libraries typically give users easy access to functions for clipping gradients both by their norm and by a fixed value, but clip_grad.py only provided a function to clip the gradient norm.
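
A minimal usage sketch of the two additions, for readers skimming the thread. It assumes the underscore-suffixed names that the PR settles on further down (kaiming_normal_, clip_grad_value_); the layer shape, input, and clip value are arbitrary.

    import torch
    import torch.nn as nn
    from torch.nn import init
    from torch.nn.utils import clip_grad_value_

    layer = nn.Linear(128, 64)

    # 1. The non-linearity can now be passed straight to the He/Kaiming
    #    initializer instead of computing the gain via init.calculate_gain by hand.
    init.kaiming_normal_(layer.weight, nonlinearity='relu')

    # 2. Clip gradients element-wise to [-clip_value, clip_value],
    #    complementing the existing norm-based clip_grad_norm_.
    layer(torch.randn(8, 128)).sum().backward()
    clip_grad_value_(layer.parameters(), clip_value=1.0)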

@ezyang (Contributor) commented Apr 2, 2018

@pytorchbot test this please

Gradients are modified in-place.
Arguments:
parameters (Iterable[Variable]): an iterable of Variables that will have

@apaszke (Contributor) left a comment

Can we have some tests please?

clip_value (float or int): maximum allowed value of the gradients
The gradients are clipped in the range [-clip_value, clip_value]
"""
parameters = list(filter(lambda p: p.grad is not None, parameters))

@karandwivedi42 (Contributor)

Might be good to have this conform to the _ convention, as it's an in-place operation.

@ezyang (Contributor) commented Apr 2, 2018

@pytorchbot test this please

@ezyang (Contributor) commented Apr 2, 2018

> Might be good to have this conform to the _ convention, as it's an in-place operation.

Well, there are already other clip functions that don't have a _ suffix. If we change this one, the rest of them should change too (and we'd also need to add the BC-compat code...)
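
For context, the naming convention being discussed: in-place tensor operations in PyTorch carry a trailing underscore, while the plain name returns a new tensor. A tiny illustration:

    import torch

    t = torch.ones(3)
    t.add(1)   # out-of-place: returns a new tensor, t is unchanged
    t.add_(1)  # in-place: modifies t itself

Applying the same convention to the gradient-clipping utilities gives clip_grad_norm_ and clip_grad_value_, which is what the PR eventually does.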

@tonybeltramelli (Contributor, Author)

The only check that failed comes from "short-perf-test-cpu", which is unrelated to the PyTorch tests.
I probably don't have permission to kick off @pytorchbot for another shot.

@ezyang (Contributor) commented Apr 3, 2018

@pytorchbot retest this please

@tonybeltramelli (Contributor, Author)

Thanks @ezyang!

@ssnl (Collaborator) commented Apr 3, 2018

I think it is reasonable to change the name to have the _ suffix. We did this for the init methods, and there are really just two grad-clip methods, including the one added in this PR...

@tonybeltramelli (Contributor, Author)

Good point @ssnl, let's get it done.

@@ -1,12 +1,12 @@
-def clip_grad_norm(parameters, max_norm, norm_type=2):
+def clip_grad_norm_(parameters, max_norm, norm_type=2):

"""
warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor "
"of torch.nn.utils.clip_grad_norm_.",
category=DeprecationWarning, stacklevel=2)
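
Presumably the old name is kept as a thin backward-compatible wrapper that emits the warning and delegates to the renamed function; a sketch along these lines (not necessarily the exact code merged, and note that a later commit in this PR removes the DeprecationWarning again, per the commit list at the bottom):

    import warnings

    def clip_grad_norm(parameters, max_norm, norm_type=2):
        r"""Deprecated alias of clip_grad_norm_, kept for backward compatibility."""
        warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor "
                      "of torch.nn.utils.clip_grad_norm_.",
                      category=DeprecationWarning, stacklevel=2)
        # Assumes clip_grad_norm_ is defined in the same module.
        return clip_grad_norm_(parameters, max_norm, norm_type)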

@tonybeltramelli (Contributor, Author)

The tests are stuck even though their console output shows them as done.

@ssnl (Collaborator) left a comment

LGTM, but @apaszke might want to take an extra look.

@tonybeltramelli (Contributor, Author)

Sounds good @ssnl and thanks for reviewing these changes.

@apaszke (Contributor) left a comment

Almost good to go! Three minor things that could be improved

test/test_nn.py (outdated)

grads = torch.arange(-50, 50).view(10, 10).div(5), torch.ones(10).mul(2)
for p, g in zip(l.parameters(), grads):
    p._grad = Variable(g.clone().view_as(p.data))

The gradients are clipped in the range [-clip_value, clip_value]
"""
clip_value = float(clip_value)
for p in list(filter(lambda p: p.grad is not None, parameters)):
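
The quoted hunk cuts off just before the loop body; presumably each gradient is clamped in place with the standard Tensor.clamp_, roughly:

    clip_value = float(clip_value)
    for p in list(filter(lambda p: p.grad is not None, parameters)):
        # Element-wise clip of each gradient to [-clip_value, clip_value].
        p.grad.data.clamp_(min=-clip_value, max=clip_value)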

@ssnl (Collaborator) commented Apr 10, 2018

Hi @tonybeltramelli, are you planning to finish this PR soon? If not, I can fix the minor things for you :)

@tonybeltramelli (Contributor, Author)

@ssnl sorry for the delay! I just pushed these minor fixes.

Total norm of the parameters (viewed as a single vector).
"""
parameters = list(filter(lambda p: p.grad is not None, parameters))
parameters = filter(lambda p: p.grad is not None, parameters)

test/test_nn.py (outdated)

grads = torch.arange(-50, 50).view(10, 10).div(5), torch.ones(10).mul(2)
for p, g in zip(l.parameters(), grads):
    p._grad = Variable(g.clone().view_as(p.data))

clip_grad_value_(l.parameters(), clip_value)
for p in filter(lambda p: p.grad is not None, l.parameters()):
    self.assertLessEqual(p.grad.data.max(), clip_value)
    self.assertGreaterEqual(p.grad.data.min(), -clip_value)

@apaszke (Contributor) commented Apr 16, 2018

@tonybeltramelli the code looks good, but I'd really like to get rid of the Variable. Can you please rebuild PyTorch from the master branch and paste the error if you're still getting one?
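
For reference, a hedged sketch of what the quoted test setup could look like without the Variable wrapper after the 0.4 tensor/Variable merge (not the exact change later made in #6641; the Linear(10, 10) module is a stand-in inferred from the gradient shapes, and it assumes .grad accepts a plain tensor of matching shape):

    import torch
    import torch.nn as nn

    l = nn.Linear(10, 10)  # stand-in for the module used in the test
    grads = torch.arange(-50., 50.).view(10, 10).div(5), torch.ones(10).mul(2)
    for p, g in zip(l.parameters(), grads):
        # No Variable(...) wrapping needed once tensors and Variables are merged.
        p.grad = g.clone().view_as(p)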

@ezyang (Contributor) left a comment

@apaszke doesn't want Variable in the commit

@ssnl (Collaborator) commented Apr 17, 2018

If we can merge this, I'll remove the Variable wrapper in #6641.

@ezyang (Contributor) commented Apr 17, 2018

@ssnl is planning to fix the Variable problem in a codemod, so this is OK to go in.

@ezyang merged commit 7fcaf3b into pytorch:master on Apr 17, 2018
@tonybeltramelli (Contributor, Author)

@apaszke Thanks, and sorry for keeping that Variable. I noticed that most of the tests still cast explicitly to Variable and figured it would make more sense, and be cleaner, to update all the tests at once to remove this deprecated requirement.

@ssnl and @ezyang, thank you, and sorry for my slow response time this week!

@ssnl (Collaborator) commented Apr 17, 2018

@tonybeltramelli No worries. The tests are that way because we haven't gotten around to updating them (fully). It was already quite some work for me to update part of them in #6641 ...

@tonybeltramelli (Contributor, Author)

@ssnl Makes total sense, pytorch is becoming a beast! :)

@ngimel (Collaborator) commented Apr 17, 2018

My local tests on a fresh build are failing; I wonder how CI is passing.

root@bf33ceab382b:/raid/pytorch/test# python test_nn.py 
Traceback (most recent call last):
  File "test_nn.py", line 23, in <module>
    from torch.nn.utils import clip_grad_norm_, clip_grad_value_
ImportError: cannot import name 'clip_grad_norm_'

@ssnl (Collaborator) commented Apr 17, 2018

@ngimel My local test script imports fine. Do you have prior binary installs that weren't properly cleaned?

@ngimel (Collaborator) commented Apr 17, 2018

Ah, right, I think it's a prior install that I forgot to clean. Sorry for the noise.

Jorghi12 pushed a commit to wsttiger/pytorch that referenced this pull request May 10, 2018

* add param to He initialization scheme in torch.nn.init

* add util to clip gradient value in torch/nn/utils/clip_grad.py

* update doc in torch.nn.utils.clip_grad

* update and add test for torch.nn.utils.clip_grad

* update function signature in torch.nn.utils.clip_grad to match suffix_ convention

* ensure backward compatibility in torch.nn.utils.clip_grad

* remove DeprecationWarning in torch.nn.utils.clip_grad

* extend test and implementation of torch.nn.utils.clip_grad

* update test and implementation torch.nn.utils.clip_grad