
Conversation

soulitzer
Contributor

@soulitzer soulitzer commented Oct 26, 2020

Fixes #46373

As noted in #46373, there needs to be a flag passed into the engine indicating whether it was invoked through the backward API or the grad API. The flag is tentatively named accumulate_grad since, functionally, the backward API accumulates gradients into .grad while the grad API captures the gradients and returns them.

Changes not necessary for the Python API (cpp, torchscript) are being moved to a new PR.
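The backward-vs-grad split that the flag encodes can be sketched in plain Python. This is a toy model, not the real autograd engine; `run_engine`, `grad_store`, and the dict-of-gradients representation are all invented here for illustration:

```python
# Toy model of the engine's accumulate_grad branching -- NOT the real
# autograd engine. All names here are invented for illustration.

def run_engine(grads, accumulate_grad, grad_store=None):
    """Dispatch on the accumulate_grad flag.

    accumulate_grad=True  -> backward() semantics: add into .grad storage.
    accumulate_grad=False -> grad() semantics: capture and return the grads.
    """
    if accumulate_grad:
        for name, g in grads.items():
            # backward() path: accumulate on top of any existing .grad value
            grad_store[name] = grad_store.get(name, 0.0) + g
        return None
    # grad() path: hand the computed gradients back to the caller
    return dict(grads)

store = {}
run_engine({"x": 1.5}, accumulate_grad=True, grad_store=store)
run_engine({"x": 1.5}, accumulate_grad=True, grad_store=store)
print(store["x"])                                      # 3.0 (accumulated)
print(run_engine({"x": 1.5}, accumulate_grad=False))   # {'x': 1.5} (captured)
```

Calling the backward path twice accumulates, while the grad path leaves the store untouched and returns the result, which is exactly the behavioral split a single flag has to carry through the engine.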

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Oct 26, 2020

dr-ci bot commented Oct 26, 2020

💊 CI failures summary and remediations

As of commit 30b3516 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@soulitzer soulitzer force-pushed the change_backward_api branch from 311cd4e to fcf409b Compare October 26, 2020 17:46
Collaborator

@albanD albanD left a comment


To appease the backward-compatibility test, since you are willingly changing the signature, you should add an entry here:

("aten::hash", datetime.date(2020, 11, 15)),
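For context, entries in that allowlist are (schema name, expiry date) pairs: the backward-compatibility check skips a schema until its date passes. A minimal sketch of that date-gating behavior (the allowlist contents and the `is_allowed` helper are hypothetical, not the actual test code):

```python
import datetime

# Hypothetical miniature of the BC-test allowlist: an entry suppresses
# the compatibility failure for a schema until its expiry date.
ALLOW_LIST = [
    ("aten::hash", datetime.date(2020, 11, 15)),
]

def is_allowed(schema_name, today):
    """Return True if the schema change is still within its grace period."""
    return any(name == schema_name and today < expires
               for name, expires in ALLOW_LIST)

print(is_allowed("aten::hash", datetime.date(2020, 11, 1)))   # True
print(is_allowed("aten::hash", datetime.date(2020, 12, 1)))   # False
```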

@codecov

codecov bot commented Oct 27, 2020

Codecov Report

Merging #46855 into master will decrease coverage by 0.00%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #46855      +/-   ##
==========================================
- Coverage   68.87%   68.87%   -0.01%     
==========================================
  Files         436      436              
  Lines       56368    56371       +3     
==========================================
+ Hits        38823    38825       +2     
- Misses      17545    17546       +1     

@soulitzer soulitzer requested a review from albanD October 27, 2020 17:41
Collaborator

@albanD albanD left a comment


I think you are missing the check that ensures all the given inputs are actually leaves when passed to .backward(), no?
I think it can go in autograd/autograd.cpp where you build the output_edges.
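The requested check amounts to rejecting any input whose grad_fn is set, i.e. any non-leaf tensor. A rough Python sketch of the idea, using an invented stand-in `Tensor` class rather than the real C++ code in autograd/autograd.cpp:

```python
# Illustrative sketch of the leaf check, with a toy Tensor stand-in.
# In this model, grad_fn is None exactly when the tensor is a leaf.

class Tensor:
    def __init__(self, grad_fn=None):
        self.grad_fn = grad_fn

def check_inputs_are_leaves(inputs):
    """Raise if any input passed to .backward() is not a leaf tensor."""
    for i, t in enumerate(inputs):
        if t.grad_fn is not None:
            raise RuntimeError(
                f"input {i} to .backward() is not a leaf Tensor")

check_inputs_are_leaves([Tensor(), Tensor()])        # leaves: passes silently
try:
    check_inputs_are_leaves([Tensor(grad_fn=object())])
except RuntimeError as e:
    print(e)                                          # non-leaf: rejected
```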

@soulitzer soulitzer requested a review from albanD October 28, 2020 18:54
@soulitzer soulitzer changed the title (WIP) Add inputs argument to autograd.backward() Add inputs argument to autograd.backward() Oct 29, 2020
Collaborator

@albanD albanD left a comment


It looks mostly good. Just minor comments.

@soulitzer soulitzer requested a review from albanD October 29, 2020 18:43
Collaborator

@albanD albanD left a comment


Looks OK to me. Can you fix the lint?

Collaborator

@albanD albanD left a comment


LGTM!

Contributor

@facebook-github-bot facebook-github-bot left a comment


@soulitzer has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

// The user either called autograd.backward(...) or autograd.grad(...) to get here
- bool backward_api_called = inputs == nullptr;
+ bool backward_api_called = accumulate_grad;
TORCH_CHECK(!backward_api_called || at::impl::VmapMode::current_vmap_level() == 0,
Contributor


What does accumulate_grad do?

const variable_list& inputs,
bool keep_graph,
bool create_graph,
bool accumulate_grad,
Contributor


Could we add a comment here for what accumulate_grad does, for future code readers?

Contributor Author


Creating a separate PR for this.

@facebook-github-bot
Contributor

@soulitzer merged this pull request in f5073b0.

facebook-github-bot pushed a commit that referenced this pull request Nov 3, 2020
Summary:
Addressing a comment from #46855, a PR that has already been merged.

#46855 (comment)

Pull Request resolved: #47266

Reviewed By: agolynski

Differential Revision: D24709017

Pulled By: soulitzer

fbshipit-source-id: 3c104c2fef90ffd75951ecef4ae9e938d4b12d8c
facebook-github-bot pushed a commit that referenced this pull request Nov 4, 2020
Summary:
Helps fix #46373 for the C++ API.

Follow-up to #46855, which changed the API for Python only.

Pull Request resolved: #47214

Reviewed By: agolynski

Differential Revision: D24716139

Pulled By: soulitzer

fbshipit-source-id: 3e1f35968e8dee132985b883481cfd0d1872ccdd
@soulitzer soulitzer deleted the change_backward_api branch November 5, 2020 21:44

Labels

cla signed Merged oncall: jit Add this issue/PR to JIT oncall triage queue


Development

Successfully merging this pull request may close these issues.

Add inputs argument to autograd.backward()
