TF.assign surprising behaviour #92
Comments
I agree that this seems to be a bug (the equivalent Python works as you expect). I don't have the answer for you, but for whatever reason I got distracted looking into this (and didn't get very far!). I did find some things like this: haskell/tensorflow/src/TensorFlow/Build.hs, line 271 at commit 2c5c879.
I'm kind of surprised, I think, that … I'd love to know the answer to this!
Thanks for finding this. I think it's a problem with laziness and how we're interacting with the C API. In particular, forcing the values immediately after fetching them causes the test to pass, and removing that forcing makes it fail again.
I'll look into making our code more strict to prevent this problem, and try to see whether this behavior also warrants a fix on the TensorFlow side.
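The laziness failure mode can be illustrated without TensorFlow at all. The sketch below (plain GHC Haskell using only `base`; the names are illustrative and unrelated to the bindings) uses `unsafeInterleaveIO` to delay a read of a mutable buffer, mirroring a fetch whose result is forced only after a later assignment has overwritten the underlying memory:

```haskell
import Foreign.Marshal.Alloc (mallocBytes, free)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek, poke, sizeOf)
import System.IO.Unsafe (unsafeInterleaveIO)

main :: IO ()
main = do
  p <- mallocBytes (sizeOf (undefined :: Int)) :: IO (Ptr Int)
  poke p 1
  -- A lazy "fetch": the peek runs only when the result is demanded.
  a <- unsafeInterleaveIO (peek p)
  -- A later "assign" overwrites the buffer before anything forced `a`.
  poke p 2
  -- Forcing the thunk now reads the *current* contents: prints 2, not 1.
  print a
  free p
```

Forcing `a` immediately after the read (e.g. with `seq` or a bang pattern) would print 1 instead, which is essentially the strictness fix being discussed.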
Haven't looked into this yet, but I would expect that making things more strict is only enough for scalar and Data.Vector results. TensorData and Data.Vector.Storable results point directly at the array returned by TensorFlow instead of copying it. The C API has a comment that I thought implied this was safe, but now I'm guessing it is only true in a trivial sense (the caller owns the pointer to the pointer to the data, but not the data itself): https://github.com/tensorflow/tensorflow/blob/0c68156ffcf918df905f9f39a632888724c66c3b/tensorflow/c/c_api.h#L895
Ouch, you're right; I forgot about that behavior of Data.Vector.Storable. It looks like switching to the new, DT_RESOURCE-based variable ops fixes this problem. AFAIK the Python APIs have all made the switch, so we probably should as well. I'll try putting together a patch and see how invasive the change ends up being.
For reference, see also tensorflow/tensorflow#4663, which details some of the issues with the old DT_REF_*-based approach, and in particular the difference between …
The main difference between these and the `Ref`-based ops is the explicit `readValue` op. I'm not sure how this should interact with gradients and save/restore, so I'm keeping it as a separate module for now. Once we figure out the details, we can merge it into `TensorFlow.Ops` and replace all uses of the old `Ref`-based ops. (That would also fix tensorflow#92.) Also replaces our special-case newtype `ResourceHandle` with `Tensor Value ResourceHandle`, where `ResourceHandle` is the TF proto corresponding to `DT_RESOURCE`.
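For orientation, the resource-variable pattern that commit message describes looks roughly like the following sketch against the `TensorFlow.Variable` module. The function names follow that module, but the exact signatures here are assumptions, not verified against a released API:

```haskell
-- Sketch only: resource variables with an explicit read op.
do
  v  <- TF.initializedVariable (TF.scalar (0 :: Float))
  op <- TF.assign v (TF.scalar 1)  -- writes through the DT_RESOURCE handle
  TF.run_ op                       -- run the assignment
  x  <- TF.run (TF.readValue v)    -- explicit read op; no Ref aliasing
  ...
```

The key design point is that `readValue` makes the read an op in the graph, so fetching a variable's value no longer hands back a view of mutable `Ref` storage.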
@blackgnezdo, can you reopen this issue? (I don't have permission to do it.) The stated problem won't be resolved until TensorFlow.Variable completely replaces the functions in TensorFlow.Ops. It looks like merging my PR (which referenced this issue) caused GitHub to auto-close it.
Unfortunately, it turns out that in TensorFlow 1.1.0 (released last month), switching to TensorFlow.Variable won't fix this issue anymore. I think this means we need to treat …
It would be better to avoid the copy when it's not necessary, but that will require more involved changes to the internal API. (For example, Fetchable might need to allow IO or ST actions.)
That commit would make you think they did it for performance reasons, but after digging around some, I'm pretty sure it was done for semantic reasons... Always copying TensorData contents SGTM for now. Maybe we can ask for an attribute to disable that optimization pass, or find some way to keep it from triggering.
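The "always copy" mitigation can likewise be sketched in plain Haskell (again only `base`, no TensorFlow; the buffers here just stand in for runtime-owned memory): copy the returned buffer into freshly allocated memory before the runtime gets a chance to reuse it.

```haskell
import Foreign.Marshal.Alloc (mallocBytes, free)
import Foreign.Marshal.Utils (copyBytes)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek, poke, sizeOf)

main :: IO ()
main = do
  let n = sizeOf (undefined :: Int)
  src <- mallocBytes n :: IO (Ptr Int)  -- stands in for the TF-owned buffer
  poke src 1
  dst <- mallocBytes n :: IO (Ptr Int)  -- our own allocation
  copyBytes dst src n                   -- defensive copy of the contents
  poke src 2                            -- runtime later overwrites its buffer
  b <- peek dst
  print b                               -- the copy still holds 1
  free src
  free dst
```

The cost is one extra copy per fetch, which is what #109-style follow-up work would aim to avoid when the result type already copies (scalars, `Data.Vector`).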
Closing this ticket now that the behavior is fixed; filed #109 as a long-term task to avoid the extra copy for some types.
I wrote some tests to guide my understanding of `TF.assign` in relation to `run` and `runSession`, similar to: https://github.com/tensorflow/haskell/blob/master/tensorflow-ops/tests/BuildTest.hs#L83

This fails, with f0, f1 and f2 all being the final value of the assigned variable. If I change the test by changing the formula to `let formula = TF.mul 1 w`, it works, but why doesn't the previous version? I wrote these tests because in other code I tried to do something like this to collect the evolution of the `w` terms, but they ended up being equal (though verifiably different from the initial value), even with the `TF.value` conversion. (The `fitStep` does `assign` to `refW`.) Is this expected behaviour? Am I missing something?