Skip to content

Commit 4ab9cb9

Browse files
jcberentsenjudah
authored andcommitted
Moved reduceMean to Ops (#136)
1 parent 042910b commit 4ab9cb9

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed

tensorflow-mnist/app/Main.hs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ randomParam width (TF.Shape shape) =
4141
where
4242
stddev = TF.scalar (1 / sqrt (fromIntegral width))
4343

44-
reduceMean :: TF.Tensor TF.Build Float -> TF.Tensor TF.Build Float
45-
reduceMean xs = TF.mean xs (TF.scalar (0 :: Int32))
46-
4744
-- Types must match due to model structure.
4845
type LabelType = Int32
4946

@@ -85,12 +82,12 @@ createModel = do
8582
labels <- TF.placeholder [batchSize]
8683
let labelVecs = TF.oneHot labels (fromIntegral numLabels) 1 0
8784
loss =
88-
reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
85+
TF.reduceMean $ fst $ TF.softmaxCrossEntropyWithLogits logits labelVecs
8986
params = [hiddenWeights, hiddenBiases, logitWeights, logitBiases]
9087
trainStep <- TF.minimizeWith TF.adam loss params
9188

9289
let correctPredictions = TF.equal predict labels
93-
errorRateTensor <- TF.render $ 1 - reduceMean (TF.cast correctPredictions)
90+
errorRateTensor <- TF.render $ 1 - TF.reduceMean (TF.cast correctPredictions)
9491

9592
return Model {
9693
train = \imFeed lFeed -> TF.runWithFeeds_ [

tensorflow-ops/src/TensorFlow/Ops.hs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ module TensorFlow.Ops
106106
, CoreOps.range
107107
, CoreOps.range'
108108
, reducedShape
109+
, reduceMean
110+
, reduceMean'
109111
, CoreOps.relu
110112
, CoreOps.relu'
111113
, CoreOps.reluGrad
@@ -330,6 +332,23 @@ reduceSum' :: (OneOf '[ Double, Float, Int32, Int64
330332
reduceSum' params x = CoreOps.sum' params x allAxes
331333
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
332334

335+
-- | Computes the mean of elements across dimensions of a tensor.
336+
-- See `TensorFlow.GenOps.Core.mean`
337+
reduceMean
338+
:: ( TensorType a
339+
, OneOf '[ Double, Float, Complex Float, Complex Double] a
340+
)
341+
=> Tensor v a -> Tensor Build a
342+
reduceMean = reduceMean' id
343+
344+
reduceMean'
345+
:: ( TensorType a
346+
, OneOf '[ Double, Float, Complex Float, Complex Double] a
347+
)
348+
=> OpParams -> Tensor v a -> Tensor Build a
349+
reduceMean' params x = CoreOps.mean' params x allAxes
350+
where allAxes = CoreOps.range 0 (CoreOps.rank x :: Tensor Build Int32) 1
351+
333352
-- | Create a constant vector.
334353
vector :: TensorType a => [a] -> Tensor Build a
335354
vector = vector' id

tensorflow-ops/tests/MatrixTest.hs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import Control.Monad (replicateM_)
66

77
import qualified Data.Vector as V
88
import qualified TensorFlow.Core as TF
9-
import qualified TensorFlow.GenOps.Core as TF (square, rank)
9+
import qualified TensorFlow.GenOps.Core as TF (square)
1010
import qualified TensorFlow.Minimize as TF
1111
import qualified TensorFlow.Ops as TF hiding (initializedVariable)
1212
import qualified TensorFlow.Variable as TF
@@ -18,17 +18,14 @@ import TensorFlow.Test (assertAllClose)
1818
randomParam :: TF.Shape -> TF.Session (TF.Tensor TF.Value Float)
1919
randomParam (TF.Shape shape) = TF.truncatedNormal (TF.vector shape)
2020

21-
reduceMean :: TF.Tensor v Float -> TF.Tensor TF.Build Float
22-
reduceMean xs = TF.mean xs (TF.range 0 (TF.rank xs) 1)
23-
2421
fitMatrix :: Test
2522
fitMatrix = testCase "fitMatrix" $ TF.runSession $ do
2623
u <- TF.initializedVariable =<< randomParam [2, 1]
2724
v <- TF.initializedVariable =<< randomParam [1, 2]
2825
let ones = [1, 1, 1, 1] :: [Float]
2926
matx = TF.constant [2, 2] ones
3027
diff = matx `TF.sub` (TF.readValue u `TF.matMul` TF.readValue v)
31-
loss = reduceMean $ TF.square diff
28+
loss = TF.reduceMean $ TF.square diff
3229
trainStep <- TF.minimizeWith (TF.gradientDescent 0.01) loss [u, v]
3330
replicateM_ 1000 (TF.run trainStep)
3431
(u',v') <- TF.run (TF.readValue u, TF.readValue v)

0 commit comments

Comments
 (0)