Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Conversation

tanmayb123
Copy link
Contributor

@tanmayb123 tanmayb123 commented Apr 18, 2019

I'd love some feedback on how this could be made better. We can't merge this yet because the for loop isn't differentiable.

@rxwei I'd love any hints as to how I could define a custom VJP that makes this function differentiable including the for loop.

#52

@rxwei
Copy link
Contributor

rxwei commented Apr 18, 2019

Ok, here's roughly how to define a VJP:

    @differentiating(call(_:))
    @usableFromInline
    internal func _vjpCall(
        _ input: [Cell.TimeStepInput]
    ) -> (value: [Cell.Output],
             pullback: ([Cell.Output].CotangentVector) -> (CotangentVector, [Cell.TimeStepInput].CotangentVector)) {
        var currentHiddenState = cell.zeroState
        var outputs: [Cell.Output] = []
        var backpropagators: [Backpropagator] = []
        for timestep in input {
            let (timestepOutput, backpropagator) =
                cell.appliedForBackpropagation(to: timestep, state: currentHiddenState)
            currentHiddenState = timestepOutput.state
            outputs.append(timestepOutput)
            backpropagators.append(timestepOutput)
        }
        func pullback(v: [Cell.Output].CotangentVector) -> (CotangentVector [Cell.TimeStepInput].CotangentVector) {
            // Apply backpropagators in reverse order.
        }
        return (value: outputs, pullback: pullback)
    }

@tanmayb123
Copy link
Contributor Author

I'm working on implementing this VJP, but I think I'm a bit confused (semantic satiation, probably). How exactly, would I apply the backpropagators and return both a RNN<Cell: RNNCell>.CotangentVector as well as a [Cell.TimeStepInput.CotangentVector]?

@Shashi456
Copy link
Contributor

@tanmayb123 would I be right in assuming that we use this for multi-layer RNNs, LSTMs later on?

@tanmayb123
Copy link
Contributor Author

Affirmative.

@rxwei
Copy link
Contributor

rxwei commented Apr 19, 2019

@jekbradbury pointed out that it's not necessary to make call return the full sequence of Outputs--just the timestep outputs (no hidden states) would be fine.

@rxwei
Copy link
Contributor

rxwei commented Apr 19, 2019

I'm gonna try prototyping this VJP and reply here :)

@rxwei rxwei added the enhancement New feature or request label Apr 19, 2019
@rxwei
Copy link
Contributor

rxwei commented Apr 20, 2019

Done! I also added some tests. Here's my working branch: rxwei@224463a. Seems to work fine! If you'd like, feel free to cherry-pick my commit onto this PR, or I can push here directly. Let me know!

@rxwei
Copy link
Contributor

rxwei commented Apr 20, 2019

BTW, the manual VJP is actually more efficient than control flow AD-generated code in the near term without further optimizations. So we should keep that even when control flow is done! :)

@rxwei rxwei requested a review from jekbradbury April 20, 2019 01:30
@tanmayb123
Copy link
Contributor Author

This is wonderful, Richard! Thank you. Please go ahead and push here directly :)

@rxwei
Copy link
Contributor

rxwei commented Apr 20, 2019

Oops, I mis-operated since I didn't have push access initially. Let me see how to restore it!

@rxwei
Copy link
Contributor

rxwei commented Apr 20, 2019

I don't think I have permissions to push to your repo. If it's not too much trouble, would you mind adding me as a collaborator on your fork, or reopening a PR?

@tanmayb123
Copy link
Contributor Author

I'll add you to my repo.

@rxwei rxwei mentioned this pull request Apr 20, 2019
@rxwei
Copy link
Contributor

rxwei commented Apr 20, 2019

Thanks! Since I'm still unable to reopen, I opened #105. I'll appreciate your review!

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants