Introduce a MonadBuild class, and remove buildAnd. (#83)
This change adds a class that both `Build` and `Session` are instances of:

    class MonadBuild m where
        build :: Build a -> m a

All stateful ops (generated and manually written) now have a signature whose
result lives in any instance of `MonadBuild` (rather than just `Build`).  For example:

    assign_ :: (MonadBuild m, TensorType t)
            => Tensor Ref t -> Tensor v t -> m (Tensor Ref t)

This lets us remove a bunch of spurious calls to `build` in user code.  It also
lets us replace the pattern `buildAnd run foo` with the simpler pattern `foo >>= run`
(or `run =<< foo`, which is sometimes nicer when `foo` is a complicated expression).

I went ahead and deleted `buildAnd` altogether since it seems to lead to
confusion; in particular, a few tests had `buildAnd run . pure`, which is
equivalent to just `run`.
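
For readers following the change, here is a minimal, self-contained sketch of the new pattern in action (the example itself is illustrative, not part of the commit):

    import qualified Data.Vector as V
    import TensorFlow.Ops (initializedVariable)
    import TensorFlow.Session (run, runSession)

    -- initializedVariable now runs in any MonadBuild, including Session,
    -- so both the explicit `build` wrapper and `buildAnd` disappear:
    example :: IO (V.Vector Float)
    example = runSession $ do
        v <- initializedVariable 17  -- was: build (initializedVariable 17)
        run v                        -- was: buildAnd run (...)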
judah committed Mar 18, 2017
1 parent 9209dfc commit 2c5c879
Showing 22 changed files with 152 additions and 162 deletions.
README.md (8 changes: 4 additions & 4 deletions)
@@ -45,13 +45,13 @@ fit xData yData = TF.runSession $ do
     let x = TF.vector xData
         y = TF.vector yData
     -- Create scalar variables for slope and intercept.
-    w <- TF.build (TF.initializedVariable 0)
-    b <- TF.build (TF.initializedVariable 0)
+    w <- TF.initializedVariable 0
+    b <- TF.initializedVariable 0
     -- Define the loss function.
     let yHat = (x `TF.mul` w) `TF.add` b
         loss = TF.square (yHat `TF.sub` y)
     -- Optimize with gradient descent.
-    trainStep <- TF.build (gradientDescent 0.001 loss [w, b])
+    trainStep <- gradientDescent 0.001 loss [w, b]
     replicateM_ 1000 (TF.run trainStep)
     -- Return the learned parameters.
     (TF.Scalar w', TF.Scalar b') <- TF.run (w, b)
@@ -60,7 +60,7 @@ fit xData yData = TF.runSession $ do
 gradientDescent :: Float
                 -> TF.Tensor TF.Value Float
                 -> [TF.Tensor TF.Ref Float]
-                -> TF.Build TF.ControlNode
+                -> TF.Session TF.ControlNode
 gradientDescent alpha loss params = do
     let applyGrad param grad =
             TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
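Since every op `gradientDescent` calls is now polymorphic, the same body would also admit a more general signature; a sketch (assuming `MonadBuild` carries the usual `Monad` constraint, with a body along the lines of the README's example):

    gradientDescent :: (TF.MonadBuild m, Monad m)
                    => Float
                    -> TF.Tensor TF.Value Float
                    -> [TF.Tensor TF.Ref Float]
                    -> m TF.ControlNode
    gradientDescent alpha loss params = do
        let applyGrad param grad =
                TF.assign param (param `TF.sub` (TF.scalar alpha `TF.mul` grad))
        TF.group =<< zipWithM applyGrad params =<< TF.gradients loss params

The README pins the signature to `TF.Session`, presumably to keep the tutorial concrete.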
tensorflow-mnist/tests/ParseTest.hs (6 changes: 3 additions & 3 deletions)
@@ -49,7 +49,7 @@ import TensorFlow.Tensor
     )
 import TensorFlow.Ops
 import TensorFlow.Session
-    (runSession, run, run_, runWithFeeds, build, buildAnd)
+    (runSession, run, run_, runWithFeeds, build)
 import TensorFlow.Types (TensorDataType(..), Shape(..), unScalar)
 import Test.Framework (Test)
 import Test.Framework.Providers.HUnit (testCase)
@@ -108,7 +108,7 @@ testGraphDefExec :: Test
 testGraphDefExec = testCase "testGraphDefExec" $ do
     let graphDef = asGraphDef $ render $ scalar (5 :: Float) * 10
     runSession $ do
-        build $ addGraphDef graphDef
+        addGraphDef graphDef
         x <- run $ tensorFromName ValueKind "Mul_2"
         liftIO $ (50 :: Float) @=? unScalar x

@@ -147,7 +147,7 @@ testMNISTExec = testCase "testMNISTExec" $ do
     wtsCkptPath <- liftIO wtsCkpt
     biasCkptPath <- liftIO biasCkpt
     -- Run those restoring nodes on the graph in the current session.
-    buildAnd run_ $ (sequence :: Monad m => [m a] -> m [a])
+    run_ =<< (sequence :: Monad m => [m a] -> m [a])
         [ restore wtsCkptPath wts
         , restoreFromName biasCkptPath "bias" bias
         ]
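The inline `(sequence :: Monad m => [m a] -> m [a])` annotation pins the monad in which the list of restore actions is sequenced; with `buildAnd` gone, they are built in `Session` and then run for their effects. The same shape in isolation (paths here are hypothetical stand-ins, and the usual qualified imports are assumed):

    {-# LANGUAGE OverloadedStrings #-}

    -- `wts` and `bias` are Ref tensors assumed to be created elsewhere.
    restoreAll :: TF.Tensor TF.Ref Float -> TF.Tensor TF.Ref Float -> TF.Session ()
    restoreAll wts bias = TF.run_ =<< sequence
        [ TF.restore "/tmp/wts.ckpt" wts                   -- hypothetical path
        , TF.restoreFromName "/tmp/bias.ckpt" "bias" bias  -- hypothetical path
        ]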
tensorflow-nn/src/TensorFlow/NN.hs (6 changes: 3 additions & 3 deletions)
@@ -23,7 +23,7 @@ module TensorFlow.NN
 import Prelude hiding ( log
                       , exp
                       )
-import TensorFlow.Build ( Build
+import TensorFlow.Build ( MonadBuild
                         , render
                         , withNameScope
                         )
@@ -71,10 +71,10 @@ import TensorFlow.Ops ( zerosLike
 --
 -- `logits` and `targets` must have the same type and shape.
 sigmoidCrossEntropyWithLogits
-    :: (OneOf '[Float, Double] a, TensorType a, Num a)
+    :: (MonadBuild m, OneOf '[Float, Double] a, TensorType a, Num a)
     => Tensor Value a  -- ^ __logits__
     -> Tensor Value a  -- ^ __targets__
-    -> Build (Tensor Value a)
+    -> m (Tensor Value a)
 sigmoidCrossEntropyWithLogits logits targets = do
     logits' <- render logits
     targets' <- render targets
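With the relaxed constraint, callers can assemble the loss directly inside a session instead of dropping down to `Build` first. A sketch (input values purely illustrative; `render` is assumed to be `MonadBuild`-polymorphic as well, as its use in the new body suggests):

    lossExample :: TF.Session (V.Vector Float)
    lossExample = do
        logits  <- TF.render (TF.vector [-1, 0, 1 :: Float])
        targets <- TF.render (TF.vector [ 0, 0, 1 :: Float])
        TF.run =<< TF.sigmoidCrossEntropyWithLogits logits targets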
tensorflow-nn/tests/NNTest.hs (5 changes: 2 additions & 3 deletions)
@@ -22,7 +22,6 @@ import TensorFlow.Test (assertAllClose)
 import Test.Framework (Test)
 import Test.Framework.Providers.HUnit (testCase)
 import qualified Data.Vector as V
-import qualified TensorFlow.Build as TF
 import qualified TensorFlow.Gradient as TF
 import qualified TensorFlow.Nodes as TF
 import qualified TensorFlow.NN as TF
@@ -97,8 +96,8 @@ testGradientAtZero = testCase "testGradientAtZero" $ do

     assertAllClose (head r) (V.fromList [0.5, -0.5])

-run :: TF.Fetchable t a => TF.Build t -> IO a
-run = TF.runSession . TF.buildAnd TF.run
+run :: TF.Fetchable t a => TF.Session t -> IO a
+run = TF.runSession . (>>= TF.run)

 main :: IO ()
 main = googleTest [ testGradientAtZero
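The rewritten `run` helper is the commit's replacement pattern in miniature. Assuming the old combinator was defined as `buildAnd f m = f =<< build m`, the deletion is justified equationally:

    -- buildAnd TF.run foo    ==  TF.run =<< TF.build foo  -- assumed old definition
    --                        ==  TF.run =<< foo           -- foo is now a Session action
    -- buildAnd TF.run . pure ==  TF.run                   -- the test pattern noted
    --                                                     -- in the commit message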
tensorflow-opgen/src/TensorFlow/OpGen.hs (20 changes: 14 additions & 6 deletions)
@@ -220,9 +220,12 @@ renderHaskellAttrName :: Attr a -> Doc
 renderHaskellAttrName = renderHaskellName . attrName

 functionBody :: ParsedOp -> Doc
-functionBody pOp = buildFunction <+> parens (hang 0 (stack buildOpParts))
+functionBody pOp = maybeLift <+> buildFunction <+> parens (hang 0 (stack buildOpParts))
                    </> indent indentation (sep tensorArgs)
   where
+    maybeLift
+        | parsedOpIsMonadic pOp = "build $"
+        | otherwise = ""
     buildFunction
         | null outputListsSizes = "buildOp"
         | otherwise = "buildListOp" <+>
@@ -277,13 +280,18 @@ typeSig pOp = constraints
             ++ [outputs])
   where
     constraints
-        | null (inferredTypeAttrs pOp) = empty
-        | otherwise = "forall" <+> sep typeParams <+> "." <+> classConstraints <+> "=>"
+        | null classConstraints = empty
+        | otherwise = "forall" <+> sep typeParams <+> "." <+> tuple classConstraints <+> "=>"
     typeParams = [strictText v | k <- parsedInputs pOp ++ parsedOutputs pOp,
                                  Just (ArgTensorEither v) <- [argKind $ parsedArgCase k]]
                  ++ [renderHaskellAttrName n | n <- inferredTypeAttrs pOp]
-    classConstraints = tuple $ map tensorArgConstraint
-                     $ inferredTypeAttrs pOp
+                 ++ if parsedOpIsMonadic pOp then ["m'"] else []
+    -- Use m' as the type parameter to avoid clashing with an attribute name.
+    monadConstraint
+        | parsedOpIsMonadic pOp = ["MonadBuild m'"]
+        | otherwise = []
+    classConstraints = monadConstraint ++ map tensorArgConstraint
+                           (inferredTypeAttrs pOp)
     signatureFold = folddoc (\x y -> x </> "->" <+> y)
     attrInput a = renderAttrType (attrInfo a) <+> hang 0 ("-- ^" <+> attrComment a)
     renderAttrType (AttrSingle a) = renderAttrBaseType a
@@ -304,7 +312,7 @@ typeSig pOp = constraints
         [a] -> wrapOutput (tensorArg a) <+> "-- ^" <+> argComment a
         as -> wrapOutput (tuple (map tensorArg as)) <+/> resultComment as
     wrapOutput o
-        | parsedOpIsMonadic pOp = "Build" <+> parens o
+        | parsedOpIsMonadic pOp = "m'" <+> parens o
         | otherwise = o

 -- | Render an op input or output.
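Concretely, for a stateful op the generator now emits code along these lines; a sketch modeled on the `assign_` signature from the commit message (the real generated text carries more attributes and Haddock comments):

    assign_ :: forall v t m' . (MonadBuild m', TensorType t)
            => Tensor Ref t  -- ^ the variable
            -> Tensor v t    -- ^ the value to assign
            -> m' (Tensor Ref t)
    assign_ ref value = build $ buildOp (opDef "Assign"
                                         & opAttr "T" .~ tensorType (undefined :: t))
                                ref value

Note the `m'` type variable from `monadConstraint`, chosen so it cannot clash with an op attribute named `m`.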
tensorflow-ops/src/TensorFlow/EmbeddingOps.hs (9 changes: 5 additions & 4 deletions)
@@ -24,7 +24,7 @@ module TensorFlow.EmbeddingOps where

 import Control.Monad (zipWithM)
 import Data.Int (Int32, Int64)
-import TensorFlow.Build (Build, colocateWith, render)
+import TensorFlow.Build (MonadBuild, colocateWith, render)
 import TensorFlow.Ops (shape, vector)  -- Also Num instance for Tensor
 import TensorFlow.Tensor (Tensor, Value)
 import TensorFlow.Types (OneOf, TensorType)
@@ -44,8 +44,9 @@ import qualified TensorFlow.GenOps.Core as CoreOps
 --
 -- The results of the lookup are concatenated into a dense
 -- tensor. The returned tensor has shape `shape(ids) + shape(params)[1:]`.
-embeddingLookup :: forall a b v .
-                   ( TensorType a
+embeddingLookup :: forall a b v m .
+                   ( MonadBuild m
+                   , TensorType a
                    , OneOf '[Int64, Int32] b
                    , Num b
                    )
@@ -58,7 +59,7 @@ embeddingLookup :: forall a b v .
                    -- containing the ids to be looked up in `params`.
                    -- The ids are required to have fewer than 2^31
                    -- entries.
-                -> Build (Tensor Value a)
+                -> m (Tensor Value a)
                    -- ^ A dense tensor with shape `shape(ids) + shape(params)[1:]`.
 embeddingLookup [p0] ids = colocateWith p0 (render $ CoreOps.gather p0 ids)
 embeddingLookup params@(p0 : _) ids = do
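A usage sketch under the new signature (shapes and values illustrative): the lookup can now be built and run in one `Session` pipeline:

    lookupExample :: TF.Session (V.Vector Float)
    lookupExample = do
        let params = [TF.constant (TF.Shape [2, 1]) [10, 20 :: Float]]
            ids    = TF.vector [1, 0 :: Int32]
        TF.run =<< TF.embeddingLookup params ids
        -- expected to fetch the rows for ids 1 and 0, i.e. [20, 10]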
tensorflow-ops/src/TensorFlow/Gradient.hs (13 changes: 8 additions & 5 deletions)
@@ -56,7 +56,9 @@ import qualified Data.Text as Text

 import qualified TensorFlow.GenOps.Core as CoreOps
 import TensorFlow.Build
-    ( Build
+    ( MonadBuild
+    , Build
+    , build
     , render
     , renderNodeName
     , renderedNodeDefs
@@ -111,16 +113,17 @@ type GradientCompatible a =


 -- | Gradient of @y@ w.r.t. each element of @xs@.
-gradients :: forall a v1 v2 . ( Num (Tensor v1 a)
+gradients :: forall a v1 v2 m . (MonadBuild m
+                                , Num (Tensor v1 a)
                                 -- TODO(gnezdo): remove indirect constraint.
-                              -- It's a wart inherited from Num instance.
+                                -- It's a wart inherited from Num instance.
                                 , v1 ~ Value
                                 , GradientCompatible a
                                 )
           => Tensor v1 a    -- ^ The output of the graph.
           -> [Tensor v2 a]  -- ^ Tensors for which gradients are computed.
-          -> Build [Tensor Value a]
-gradients y xs = do
+          -> m [Tensor Value a]
+gradients y xs = build $ do
     -- The gradients are computed using "reverse accumulation", similarly to
     -- what is described here:
     -- https://en.wikipedia.org/wiki/Automatic_differentiation#The_chain_rule.2C_forward_and_reverse_accumulation
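This change illustrates one of the two lifting styles the commit uses: the implementation stays a `Build` computation (where internals such as `renderNodeName` and `renderedNodeDefs` live), and a single `build` at the boundary exposes it to any `MonadBuild`. The same trick in miniature (a sketch; import locations assumed):

    import TensorFlow.Build (MonadBuild, build, renderNodeName)
    import TensorFlow.Output (NodeName)
    import TensorFlow.Tensor (Tensor)

    -- Implement against Build internals, lift once at the API boundary.
    nodeNameOf :: MonadBuild m => Tensor v a -> m NodeName
    nodeNameOf t = build (renderNodeName t)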
tensorflow-ops/src/TensorFlow/Ops.hs (43 changes: 22 additions & 21 deletions)
@@ -155,20 +155,20 @@ matTranspose :: forall a v . TensorType a
              => Tensor v a -> Tensor Value a
 matTranspose = flip CoreOps.transpose (vector [1, 0 :: Int32])

-placeholder :: forall a . TensorType a => Shape -> Build (Tensor Value a)
+placeholder :: forall a m . (MonadBuild m, TensorType a) => Shape -> m (Tensor Value a)
 placeholder shape' =
-    buildOp $ opDef "Placeholder"
+    build $ buildOp $ opDef "Placeholder"
         & opAttr "dtype" .~ tensorType (undefined :: a)
         & opAttr "shape" .~ shape'

 -- | Creates a variable initialized to the given value.
 -- Initialization happens next time session runs.
-initializedVariable :: forall a . TensorType a
-                    => Tensor Value a -> Build (Tensor Ref a)
+initializedVariable :: forall a m . (MonadBuild m, TensorType a)
+                    => Tensor Value a -> m (Tensor Ref a)
 initializedVariable initializer = do
     v <- CoreOps.variable []  -- The shape is not known initially.
     (i :: Tensor Ref a) <-
-        buildOp (opDef "Assign"
+        build $ buildOp (opDef "Assign"
                      & opAttr "T" .~ tensorType (undefined :: a)
                      & opAttr "use_locking" .~ True
                      & opAttr "validate_shape" .~ False
@@ -179,45 +179,45 @@ initializedVariable initializer = do

 -- | Creates a zero-initialized variable with the given shape.
 zeroInitializedVariable
-  :: (TensorType a, Num a) =>
-     TensorFlow.Types.Shape -> Build (Tensor TensorFlow.Tensor.Ref a)
+  :: (MonadBuild m, TensorType a, Num a) =>
+     TensorFlow.Types.Shape -> m (Tensor TensorFlow.Tensor.Ref a)
 zeroInitializedVariable = initializedVariable . zeros

 -- TODO: Support heterogeneous list of tensors.
-save :: forall a v . TensorType a
+save :: forall a m v . (MonadBuild m, TensorType a)
      => ByteString    -- ^ File path.
      -> [Tensor v a]  -- ^ Tensors to save.
-     -> Build ControlNode
+     -> m ControlNode
 save path xs = do
     let toByteStringTensor = scalar . encodeUtf8 . unNodeName
-    names <- mapM (fmap toByteStringTensor . renderNodeName) xs
+    names <- mapM (fmap toByteStringTensor . build . renderNodeName) xs
     let types = replicate (length xs) (tensorType (undefined :: a))
     let saveOp = buildOp $ opDef "Save"
                      & opAttr "T" .~ types
-    saveOp (scalar path) (CoreOps.pack names) xs
+    build $ saveOp (scalar path) (CoreOps.pack names) xs

 -- | Restore a tensor's value from a checkpoint file.
 --
 -- This version allows restoring from a checkpoint file that uses a different
 -- tensor name than the variable.
-restoreFromName :: forall a . TensorType a
+restoreFromName :: forall a m . (MonadBuild m, TensorType a)
                 => ByteString    -- ^ File path.
                 -> ByteString    -- ^ Tensor name override.
                 -> Tensor Ref a  -- ^ Tensor to restore.
-                -> Build ControlNode
+                -> m ControlNode
 restoreFromName path name x = do
     let restoreOp = buildOp $ opDef "Restore"
                         & opAttr "dt" .~ tensorType (undefined :: a)
     group =<< CoreOps.assign x
                   (restoreOp (scalar path) (scalar name) :: Tensor Value a)

 -- | Restore a tensor's value from a checkpoint file.
-restore :: forall a . TensorType a
+restore :: forall a m . (MonadBuild m, TensorType a)
         => ByteString    -- ^ File path.
         -> Tensor Ref a  -- ^ Tensor to restore.
-        -> Build ControlNode
+        -> m ControlNode
 restore path x = do
-    name <- encodeUtf8 . unNodeName <$> renderNodeName x
+    name <- encodeUtf8 . unNodeName <$> build (renderNodeName x)
     restoreFromName path name x

 -- | Create a constant tensor.
Expand Down Expand Up @@ -264,12 +264,13 @@ scalar :: forall a . TensorType a => a -> Tensor Value a
 scalar x = constant [] [x]

 -- Random tensor from the unit normal distribution with bounded values.
-truncatedNormal :: forall a v . TensorType a
+truncatedNormal :: forall a m v . (MonadBuild m, TensorType a)
                 => Tensor v Int64  -- ^ Shape.
-                -> Build (Tensor Value a)
-truncatedNormal = buildOp $ opDef "TruncatedNormal"
-                      & opAttr "dtype" .~ tensorType (undefined :: a)
-                      & opAttr "T" .~ tensorType (undefined :: Int64)
+                -> m (Tensor Value a)
+truncatedNormal
+    = build . buildOp (opDef "TruncatedNormal"
+                           & opAttr "dtype" .~ tensorType (undefined :: a)
+                           & opAttr "T" .~ tensorType (undefined :: Int64))

 zeros :: forall a . (Num a, TensorType a) => Shape -> Tensor Value a
 zeros (Shape shape') = CoreOps.fill (vector $ map fromIntegral shape') (scalar 0)
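Two lifting styles sit side by side in this file: `placeholder` and `truncatedNormal` lift their entire bodies with `build $ ...`, while `save` and `restore` stay in the polymorphic monad and lift only the `Build`-specific calls, as in `build (renderNodeName x)`. The second style in isolation (a sketch; assumes `MonadBuild` implies at least `Functor`, and imports as in the module):

    checkpointKey :: MonadBuild m => Tensor Ref Float -> m ByteString
    checkpointKey x = encodeUtf8 . unNodeName <$> build (renderNodeName x)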
tensorflow-ops/tests/BuildTest.hs (24 changes: 8 additions & 16 deletions)
@@ -19,7 +19,6 @@
 module Main where

 import Control.Monad.IO.Class (liftIO)
-import Data.Functor.Identity (runIdentity)
 import Lens.Family2 ((^.))
 import Data.List (sort)
 import Proto.Tensorflow.Core.Framework.Graph
@@ -35,7 +34,6 @@ import TensorFlow.Build
     , asGraphDef
     , evalBuildT
     , flushNodeBuffer
-    , hoistBuildT
     , render
     , withDevice
     , colocateWith
@@ -53,9 +51,7 @@ import TensorFlow.Ops
 import TensorFlow.Output (Device(..))
 import TensorFlow.Tensor (Tensor, Value, Ref)
 import TensorFlow.Session
-    ( build
-    , buildAnd
-    , run
+    ( run
     , runSession
     , run_
     )
@@ -82,7 +78,7 @@ testNamedDeRef = testCase "testNamedDeRef" $ do
         assign v 5
     -- TODO: Implement TensorFlow get_variable and test it.
     runSession $ do
-        out <- buildAnd run graph
+        out <- graph >>= run
         liftIO $ 5 @=? (unScalar out :: Float)

 -- | Test that "run" will render and extend any pure ops that haven't already
@@ -96,7 +92,7 @@ testPureRender = testCase "testPureRender" $ runSession $ do
 testInitializedVariable :: Test
 testInitializedVariable =
     testCase "testInitializedVariable" $ runSession $ do
-        (formula, reset) <- build $ do
+        (formula, reset) <- do
             v <- initializedVariable 42
             r <- assign v 24
             return (1 `add` v, r)
@@ -109,7 +105,7 @@ testInitializedVariable =
 testInitializedVariableShape :: Test
 testInitializedVariableShape =
     testCase "testInitializedVariableShape" $ runSession $ do
-        vector <- build $ initializedVariable (constant [1] [42 :: Float])
+        vector <- initializedVariable (constant [1] [42 :: Float])
         result <- run vector
         liftIO $ [42] @=? (result :: V.Vector Float)

@@ -132,23 +128,19 @@ testNamedAndScoped = testCase "testNamedAndScoped" $ do
"RefIdentity" @=? (nodeDef ^. op)
"foo1/bar1" @=? (nodeDef ^. name)

-- | Lift a Build action into a context for HUnit to run.
liftBuild :: Build a -> BuildT IO a
liftBuild = hoistBuildT (return . runIdentity)

-- | Flush the node buffer and sort the nodes by name (for more stable tests).
flushed :: Ord a => (NodeDef -> a) -> BuildT IO [a]
flushed field = sort . map field <$> liftBuild flushNodeBuffer
flushed field = sort . map field <$> flushNodeBuffer

-- | Test the interaction of rendering, CSE and scoping.
testRenderDedup :: Test
testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
liftBuild renderNodes
renderNodes
names <- flushed (^. name)
liftIO $ ["Const_1", "Variable_0", "Variable_2"] @=? names
-- Render the nodes in a different scope, which should cause them
-- to be distinct from the previous ones.
liftBuild $ withNameScope "foo" renderNodes
withNameScope "foo" renderNodes
scopedNames <- flushed (^. name)
liftIO $ ["foo/Const_4", "foo/Variable_3", "foo/Variable_5"] @=? scopedNames
where
@@ -165,7 +157,7 @@ testRenderDedup = testCase "testRenderDedup" $ evalBuildT $ do
 -- | Test the interaction of rendering, CSE and scoping.
 testDeviceColocation :: Test
 testDeviceColocation = testCase "testDeviceColocation" $ evalBuildT $ do
-    liftBuild renderNodes
+    renderNodes
     devices <- flushed (\x -> (x ^. name, x ^. device))
     liftIO $ [ ("Add_2","dev0")
              , ("Const_1","dev0")
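The tests can drop `liftBuild` because the `BuildT IO` monad they run in evidently satisfies `MonadBuild` too; presumably the library provides an instance shaped roughly like the deleted helper (a guess at the shape, not the commit's code):

    instance Monad m => MonadBuild (BuildT m) where
        build = hoistBuildT (return . runIdentity)  -- Build ~ BuildT Identity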
(The remaining 13 changed files are not shown.)