Skip to content

Commit

Permalink
Work around #92 by always copying TensorData when fetching.
Browse files Browse the repository at this point in the history
It would be better to avoid the copy when it's not necessary, but
that will require more involved changes to the internal API.  (For example,
Fetchable might need to allow IO or ST actions.)
  • Loading branch information
judah authored and fkm3 committed May 9, 2017
1 parent 37e3c9b commit a64af50
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
12 changes: 12 additions & 0 deletions tensorflow-ops/tests/OpsTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,22 @@ testScalarFeedCse = testCase "testScalarFeedCse" $ TF.runSession $ do
$ p1 `TF.add` p2
liftIO $ result @=? TF.Scalar 5

-- | See https://github.com/tensorflow/haskell/issues/92.
-- Even though we're not explicitly evaluating `f0` until the end,
-- it should hold the earlier value of the variable.
testRereadRef :: Test
testRereadRef = testCase "testReRunAssign" $ TF.runSession $ do
w <- TF.initializedVariable 0
f0 <- TF.run w
TF.run_ =<< TF.assign w (TF.scalar (0.1 :: Float))
f1 <- TF.run w
liftIO $ (0.0, 0.1) @=? (TF.unScalar f0, TF.unScalar f1)

main :: IO ()
main = googleTest [ testSaveRestore
, testSize
, testReducedShape
, testPlaceholderCse
, testScalarFeedCse
, testRereadRef
]
15 changes: 11 additions & 4 deletions tensorflow/src/TensorFlow/Internal/FFI.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ import Data.Typeable (Typeable)
import Data.Word (Word8)
import Foreign (Ptr, FunPtr, nullPtr, castPtr)
import Foreign.C.String (CString)
import Foreign.ForeignPtr (newForeignPtr, withForeignPtr)
import Foreign.ForeignPtr (newForeignPtr, newForeignPtr_, withForeignPtr)
import Foreign.Marshal.Alloc (free)
import Foreign.Marshal.Array (withArrayLen, peekArray, mallocArray, copyArray)
import System.IO.Unsafe (unsafePerformIO)
Expand All @@ -51,7 +51,7 @@ import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.Encoding.Error as T
import qualified Data.Vector.Storable as S
import qualified Foreign.Concurrent as ForeignC
import qualified Data.Vector.Storable.Mutable as M

import Data.ProtoLens (Message, encodeMessage)
import Proto.Tensorflow.Core.Framework.Graph (GraphDef)
Expand Down Expand Up @@ -193,6 +193,10 @@ tensorDeallocFunPtr = unsafePerformIO $ Raw.wrapTensorDealloc $ \x _ _ -> free x
-- | Create a TensorData from a Raw.Tensor.
--
-- Takes ownership of the Raw.Tensor.
-- TODO: Currently, it just makes a copy of the Tensor (and then deletes it),
-- since the raw pointer may refer to storage inside a mutable TensorFlow
-- variable. We should avoid that copy when it's not needed; for example,
-- by making TensorData wrap an IOVector, and changing the code that uses it.
createTensorData :: Raw.Tensor -> IO TensorData
createTensorData t = do
-- Read dimensions.
Expand All @@ -203,8 +207,11 @@ createTensorData t = do
-- Read data.
len <- safeConvert <$> Raw.tensorByteSize t
bytes <- castPtr <$> Raw.tensorData t :: IO (Ptr Word8)
fp <- ForeignC.newForeignPtr bytes (Raw.deleteTensor t)
let v = S.unsafeFromForeignPtr0 fp len
fp <- newForeignPtr_ bytes
-- Make an explicit copy of the raw data, since it might point
-- to a mutable variable's memory.
v <- S.freeze (M.unsafeFromForeignPtr0 fp len)
Raw.deleteTensor t
return $ TensorData (map safeConvert dims) dtype v

-- | Runs the given action which does FFI calls updating a provided
Expand Down

0 comments on commit a64af50

Please sign in to comment.