Skip to content

Commit

Permalink
Add annotation fields to the rest of PreSmartExp
Browse files Browse the repository at this point in the history
  • Loading branch information
robbert-vdh committed Sep 3, 2021
1 parent 0c5d3fa commit ab20388
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 143 deletions.
12 changes: 6 additions & 6 deletions src/Data/Array/Accelerate/Classes/Eq.hs
Expand Up @@ -60,9 +60,9 @@ infixr 3 &&
withFrozenCallStack
$ mkExp
$ Pair mkAnn
(SmartExp (Cond (SmartExp $ Prj mkAnn PairIdxLeft x)
(SmartExp $ Prj mkAnn PairIdxLeft y)
(SmartExp $ Const mkAnn scalarTypeWord8 0)))
(SmartExp (Cond mkAnn (SmartExp $ Prj mkAnn PairIdxLeft x)
(SmartExp $ Prj mkAnn PairIdxLeft y)
(SmartExp $ Const mkAnn scalarTypeWord8 0)))
(SmartExp (Nil mkAnn))

-- | Conjunction: True if both arguments are true. This is a strict version of
Expand All @@ -84,9 +84,9 @@ infixr 2 ||
withFrozenCallStack
$ mkExp
$ Pair mkAnn
(SmartExp (Cond (SmartExp $ Prj mkAnn PairIdxLeft x)
(SmartExp $ Const mkAnn scalarTypeWord8 1)
(SmartExp $ Prj mkAnn PairIdxLeft y)))
(SmartExp (Cond mkAnn (SmartExp $ Prj mkAnn PairIdxLeft x)
(SmartExp $ Const mkAnn scalarTypeWord8 1)
(SmartExp $ Prj mkAnn PairIdxLeft y)))
(SmartExp (Nil mkAnn))


Expand Down
2 changes: 1 addition & 1 deletion src/Data/Array/Accelerate/Classes/Ord.hs
Expand Up @@ -83,7 +83,7 @@ class Eq a => Ord a where
-- Local redefinition for use with RebindableSyntax (pulled forward from Prelude.hs)
--
ifThenElse :: (HasCallStack, Elt a) => Exp Bool -> Exp a -> Exp a -> Exp a
ifThenElse (Exp c) (Exp x) (Exp y) = withFrozenCallStack $ Exp $ SmartExp $ Cond (mkCoerce' c) x y
ifThenElse (Exp c) (Exp x) (Exp y) = withFrozenCallStack $ Exp $ SmartExp $ Cond mkAnn (mkCoerce' c) x y

instance Ord () where
(<) _ _ = withFrozenCallStack $ constant False
Expand Down
2 changes: 1 addition & 1 deletion src/Data/Array/Accelerate/Data/Complex.hs
Expand Up @@ -180,7 +180,7 @@ deconstructComplex :: forall a. (HasCallStack, Elt a) => Exp (Complex a) -> (Exp
deconstructComplex c@(Exp c') =
case complexR (eltR @a) of
ComplexTup -> let T2 r i = coerce c in (r, i)
ComplexVec t -> let T2 r i = Exp (SmartExp (VecUnpack (VecRsucc (VecRsucc (VecRnil t))) c'))
ComplexVec t -> let T2 r i = Exp (SmartExp (VecUnpack mkAnn (VecRsucc (VecRsucc (VecRnil t))) c'))
in (r, i)

coerce :: EltR a ~ EltR b => Exp a -> Exp b
Expand Down
22 changes: 11 additions & 11 deletions src/Data/Array/Accelerate/Language.hs
Expand Up @@ -1216,7 +1216,7 @@ foreignExp
-> (Exp x -> Exp y)
-> Exp x
-> Exp y
foreignExp asm f (Exp x) = withFrozenCallStack $ mkExp $ Foreign (eltR @y) asm (unExpFunction f) x
foreignExp asm f (Exp x) = withFrozenCallStack $ mkExp $ Foreign mkAnn (eltR @y) asm (unExpFunction f) x


-- Composition of array computations
Expand Down Expand Up @@ -1281,12 +1281,12 @@ toIndex
=> Exp sh -- ^ extent of the array
-> Exp sh -- ^ index to remap
-> Exp Int
toIndex (Exp sh) (Exp ix) = withFrozenCallStack $ mkExp $ ToIndex (shapeR @sh) sh ix
toIndex (Exp sh) (Exp ix) = withFrozenCallStack $ mkExp $ ToIndex mkAnn (shapeR @sh) sh ix

-- | Inverse of 'toIndex'
--
fromIndex :: forall sh. (HasCallStack, Shape sh) => Exp sh -> Exp Int -> Exp sh
fromIndex (Exp sh) (Exp e) = withFrozenCallStack $ mkExp $ FromIndex (shapeR @sh) sh e
fromIndex (Exp sh) (Exp e) = withFrozenCallStack $ mkExp $ FromIndex mkAnn (shapeR @sh) sh e

-- | Intersection of two shapes
--
Expand All @@ -1298,7 +1298,7 @@ intersect (Exp shx) (Exp shy) = withFrozenCallStack $ Exp $ intersect' (shapeR @
intersect' (ShapeRsnoc shR) (unPair -> (xs, x)) (unPair -> (ys, y))
= SmartExp
$ Pair mkAnn (intersect' shR xs ys)
(SmartExp (PrimApp (PrimMin singleType) $ SmartExp $ Pair mkAnn x y))
(SmartExp (PrimApp mkAnn (PrimMin singleType) $ SmartExp $ Pair mkAnn x y))


-- | Union of two shapes
Expand All @@ -1311,7 +1311,7 @@ union (Exp shx) (Exp shy) = withFrozenCallStack $ Exp $ union' (shapeR @sh) shx
union' (ShapeRsnoc shR) (unPair -> (xs, x)) (unPair -> (ys, y))
= SmartExp
$ Pair mkAnn (union' shR xs ys)
(SmartExp (PrimApp (PrimMax singleType) $ SmartExp $ Pair mkAnn x y))
(SmartExp (PrimApp mkAnn (PrimMax singleType) $ SmartExp $ Pair mkAnn x y))


-- Flow-control
Expand All @@ -1327,7 +1327,7 @@ cond :: (HasCallStack, Elt t)
-> Exp t -- ^ then-expression
-> Exp t -- ^ else-expression
-> Exp t
cond (Exp c) (Exp x) (Exp y) = withFrozenCallStack $ mkExp $ Cond (mkCoerce' c) x y
cond (Exp c) (Exp x) (Exp y) = withFrozenCallStack $ mkExp $ Cond mkAnn (mkCoerce' c) x y

-- | While construct. Continue to apply the given function, starting with the
-- initial value, until the test function evaluates to 'False'.
Expand All @@ -1338,7 +1338,7 @@ while :: forall e. (HasCallStack, Elt e)
-> Exp e -- ^ initial value
-> Exp e
while c f (Exp e) =
withFrozenCallStack $ mkExp $ While @(EltR e) (eltR @e)
withFrozenCallStack $ mkExp $ While @(EltR e) mkAnn (eltR @e)
(mkCoerce' . unExp . c . Exp)
(unExp . f . Exp) e

Expand All @@ -1363,7 +1363,7 @@ while c f (Exp e) =
--
infixl 9 !
(!) :: forall sh e. (HasCallStack, Shape sh, Elt e) => Acc (Array sh e) -> Exp sh -> Exp e
Acc a ! Exp ix = withFrozenCallStack $ mkExp $ Index (eltR @e) a ix
Acc a ! Exp ix = withFrozenCallStack $ mkExp $ Index mkAnn (eltR @e) a ix

-- | Extract the value from an array at the specified linear index.
-- Multidimensional arrays in Accelerate are stored in row-major order with
Expand All @@ -1383,12 +1383,12 @@ Acc a ! Exp ix = withFrozenCallStack $ mkExp $ Index (eltR @e) a ix
--
infixl 9 !!
(!!) :: forall sh e. (HasCallStack, Shape sh, Elt e) => Acc (Array sh e) -> Exp Int -> Exp e
Acc a !! Exp ix = withFrozenCallStack $ mkExp $ LinearIndex (eltR @e) a ix
Acc a !! Exp ix = withFrozenCallStack $ mkExp $ LinearIndex mkAnn (eltR @e) a ix

-- | Extract the shape (extent) of an array.
--
shape :: forall sh e. (HasCallStack, Shape sh, Elt e) => Acc (Array sh e) -> Exp sh
shape = withFrozenCallStack $ mkExp . Shape (shapeR @sh) . unAcc
shape = withFrozenCallStack $ mkExp . Shape mkAnn (shapeR @sh) . unAcc

-- | The number of elements in the array
--
Expand All @@ -1398,7 +1398,7 @@ size = withFrozenCallStack $ shapeSize . shape
-- | The number of elements that would be held by an array of the given shape.
--
shapeSize :: forall sh. (HasCallStack, Shape sh) => Exp sh -> Exp Int
shapeSize (Exp sh) = withFrozenCallStack $ mkExp $ ShapeSize (shapeR @sh) sh
shapeSize (Exp sh) = withFrozenCallStack $ mkExp $ ShapeSize mkAnn (shapeR @sh) sh


-- Numeric functions
Expand Down
4 changes: 2 additions & 2 deletions src/Data/Array/Accelerate/Pattern.hs
Expand Up @@ -213,8 +213,8 @@ runQ $ do
--
[d| instance $context => IsVector Exp $(varT v) $tup where
vpack x = case builder x :: Exp $tR of
Exp x' -> Exp (SmartExp (VecPack $vecR x'))
vunpack (Exp x) = matcher (Exp (SmartExp (VecUnpack $vecR x)) :: Exp $tR)
Exp x' -> Exp (SmartExp (VecPack mkAnn $vecR x'))
vunpack (Exp x) = matcher (Exp (SmartExp (VecUnpack mkAnn $vecR x)) :: Exp $tR)
|]
--
es <- mapM mkExpPattern [0..16]
Expand Down
4 changes: 2 additions & 2 deletions src/Data/Array/Accelerate/Prelude.hs
Expand Up @@ -119,6 +119,7 @@ module Data.Array.Accelerate.Prelude (
) where

import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Annotations
import Data.Array.Accelerate.Language
import Data.Array.Accelerate.Lift
import Data.Array.Accelerate.Pattern
Expand All @@ -137,7 +138,6 @@ import Data.Array.Accelerate.Classes.Ord

import Data.Array.Accelerate.Data.Bits

import GHC.Stack
import Lens.Micro ( Lens', (&), (^.), (.~), (+~), (-~), lens, over )
import Prelude ( (.), ($), Maybe(..), const, id, flip )

Expand Down Expand Up @@ -2308,7 +2308,7 @@ instance (Elt e, Matching r) => Matching (Exp e -> r) where
-- product types.
_ -> case rhs of
[(_,r)] -> Exp r
_ -> Exp (SmartExp (Case p rhs))
_ -> Exp (SmartExp (Case mkAnn p rhs))
where
rhs = [ (tag, unExp (mkMatch (f x') xs))
| tag <- tagsR @e
Expand Down

0 comments on commit ab20388

Please sign in to comment.