Add RNN wrapper for Cells #100
Ok, here's roughly how to define a VJP:
@differentiating(call(_:))
@usableFromInline
internal func _vjpCall(
    _ input: [Cell.TimeStepInput]
) -> (value: [Cell.Output],
      pullback: ([Cell.Output].CotangentVector) -> (CotangentVector, [Cell.TimeStepInput].CotangentVector)) {
    var currentHiddenState = cell.zeroState
    var outputs: [Cell.Output] = []
    var backpropagators: [Backpropagator] = []
    // Forward pass: run the cell on each time step and record its backpropagator.
    for timestep in input {
        let (timestepOutput, backpropagator) =
            cell.appliedForBackpropagation(to: timestep, state: currentHiddenState)
        currentHiddenState = timestepOutput.state
        outputs.append(timestepOutput)
        backpropagators.append(backpropagator)
    }
    func pullback(v: [Cell.Output].CotangentVector) -> (CotangentVector, [Cell.TimeStepInput].CotangentVector) {
        // Apply backpropagators in reverse order.
    }
    return (value: outputs, pullback: pullback)
} |
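For readers following along, the "apply backpropagators in reverse order" step can be sketched in plain Swift without any differentiation attributes. This is only an illustration under stated assumptions, not the code from this PR: `ToyCell`, `ToyCellGradient`, `ToyBackpropagator`, and `toyRNNWithPullback` are hypothetical names, the cell is a scalar linear recurrence, and the cell's output doubles as its hidden state (the snippet above reads the state from `timestepOutput.state` instead).

```swift
// Hypothetical scalar cell (assumption): newState = w * input + u * state.
struct ToyCell {
    var w: Double
    var u: Double
}

// Gradient with respect to the toy cell's parameters.
struct ToyCellGradient: Equatable {
    var w = 0.0
    var u = 0.0
}

// Maps the gradient arriving at one time step's output to
// (cell gradient, input gradient, previous-state gradient).
typealias ToyBackpropagator = (Double) -> (ToyCellGradient, Double, Double)

extension ToyCell {
    func appliedForBackpropagation(
        to input: Double, state: Double
    ) -> (output: Double, backpropagator: ToyBackpropagator) {
        let output = self.w * input + self.u * state
        let backpropagator: ToyBackpropagator = { d in
            (ToyCellGradient(w: d * input, u: d * state), d * self.w, d * self.u)
        }
        return (output, backpropagator)
    }
}

func toyRNNWithPullback(
    cell: ToyCell, inputs: [Double]
) -> (value: [Double], pullback: ([Double]) -> (ToyCellGradient, [Double])) {
    var state = 0.0  // plays the role of `cell.zeroState`
    var outputs: [Double] = []
    var backpropagators: [ToyBackpropagator] = []
    // Forward pass: record one backpropagator per time step.
    for input in inputs {
        let (output, backpropagator) =
            cell.appliedForBackpropagation(to: input, state: state)
        state = output
        outputs.append(output)
        backpropagators.append(backpropagator)
    }
    func pullback(_ v: [Double]) -> (ToyCellGradient, [Double]) {
        precondition(v.count == backpropagators.count)
        var cellGradient = ToyCellGradient()
        var inputGradients = [Double](repeating: 0, count: v.count)
        var stateGradient = 0.0
        // Backward pass: walk the time steps from last to first.
        for t in v.indices.reversed() {
            // A step's output receives its own cotangent plus whatever
            // flowed back through the state from the following step.
            let (dCell, dInput, dState) = backpropagators[t](v[t] + stateGradient)
            cellGradient.w += dCell.w
            cellGradient.u += dCell.u
            inputGradients[t] = dInput
            stateGradient = dState
        }
        return (cellGradient, inputGradients)
    }
    return (value: outputs, pullback: pullback)
}
```

With `ToyCell(w: 2, u: 0.5)` and inputs `[1, 3]`, the forward values are `[2, 7]`, and calling the pullback with `[1, 1]` gives a cell gradient of `(w: 4.5, u: 2)` and input gradients `[3, 2]`, which matches differentiating the sum of both outputs by hand. The two points that carry over to the real VJP are that the cell gradient is accumulated across time steps, and that the state gradient from step t+1 is added to the output cotangent of step t.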
I'm working on implementing this VJP, but I think I'm a bit confused (semantic satiation, probably). How exactly would I apply the |
@tanmayb123 Would I be right in assuming that we'll use this for multi-layer RNNs and LSTMs later on? |
Affirmative. |
@jekbradbury pointed out that it's not necessary to make |
I'm gonna try prototyping this VJP and reply here :) |
Done! I also added some tests. Here's my working branch: rxwei@224463a. Seems to work fine! If you'd like, feel free to cherry-pick my commit onto this PR, or I can push here directly. Let me know! |
BTW, in the near term the manual VJP is actually more efficient than the code that control-flow AD would generate without further optimizations. So we should keep it even when control flow is done! :) |
This is wonderful, Richard! Thank you. Please go ahead and push here directly :) |
Oops, I made a mistake because I didn't have push access initially. Let me see how to restore it! |
I don't think I have permissions to push to your repo. If it's not too much trouble, would you mind adding me as a collaborator on your fork, or reopening a PR? |
I'll add you to my repo. |
Thanks! Since I'm still unable to reopen, I opened #105. I'd appreciate your review! |
I'd love some feedback on how this could be made better. We can't merge this yet because the for loop isn't differentiable.
@rxwei I'd love any hints as to how I could define a custom VJP that makes this function differentiable, including the for loop.
#52