Permalink
Browse files

add special kernel for the upsweep phase of multi-block scan

This kernel calculates the final reduction value for each interval in a multi-block scan. Previously, a standard fold1 was used for this purpose, but that requires a commutative operator, which we don't have in the case of segmented scans.

This introduces some mess in the compilation phase, because the upsweep kernel falls outside the standard set of AST nodes.

required for AccelerateHS/accelerate#44
  • Loading branch information...
1 parent 70914eb commit aad9233f431f8283b9963d9709e8261a7a66cb42 @tmcdonell committed Aug 8, 2012
View
23 Data/Array/Accelerate/CUDA/CodeGen.hs
@@ -12,7 +12,8 @@
module Data.Array.Accelerate.CUDA.CodeGen (
- CUTranslSkel, codegenAcc
+ CUTranslSkel, codegenAcc,
+ codegenScanlIntervals, codegenScanrIntervals,
) where
@@ -242,6 +243,26 @@ codegenOpenAcc dev acc@(OpenAcc pacc) = case pacc of
$ codegenConst (Sugar.eltType (undefined::e)) c
+-- Exceptional cases
+-- -----------------
+
+-- Scan is a multi-pass algorithm that requires a few additional operations not
+-- in the standard set. These helper functions generate the necessary code.
+--
+codegenScanlIntervals, codegenScanrIntervals
+ :: CUDA.DeviceProperties
+ -> Fun aenv (a -> a -> a)
+ -> AccBindings aenv
+ -> CUTranslSkel
+codegenScanlIntervals dev f avar
+ = codegenBindEnv avar
+ $ mkScanlIntervals dev (codegenFun f)
+
+codegenScanrIntervals dev f avar
+ = codegenBindEnv avar
+ $ mkScanrIntervals dev (codegenFun f)
+
+
-- Scalar Expressions
-- ------------------
View
91 Data/Array/Accelerate/CUDA/CodeGen/PrefixSum.hs
@@ -17,9 +17,10 @@ module Data.Array.Accelerate.CUDA.CodeGen.PrefixSum (
-- skeletons
mkScanl, mkScanr,
+ mkScanlIntervals, mkScanrIntervals,
-- closets
- scanBlock
+ scanBlock,
) where
@@ -44,6 +45,10 @@ mkScanl, mkScanr :: DeviceProperties -> CUFun (a -> a -> a) -> Maybe (CUExp a) -
mkScanl = mkScan L
mkScanr = mkScan R
+mkScanlIntervals, mkScanrIntervals :: DeviceProperties -> CUFun (a -> a -> a) -> CUTranslSkel
+mkScanlIntervals = mkScanIntervals L
+mkScanrIntervals = mkScanIntervals R
+
-- [OVERVIEW]
--
@@ -239,6 +244,90 @@ mkScan dir dev (CULam _ (CULam use0 (CUBody (CUExp env combine)))) mseed =
|]
+-- This computes the _upsweep_ phase of a multi-block scan. This is much like a
+-- regular inclusive scan, except that only the final value for each interval is
+-- output, rather than the entire body of the scan. Indeed, if the combination
+-- function were commutative, this would be equivalent to a parallel tree reduction.
+--
+mkScanIntervals
+ :: forall a.
+ Direction
+ -> DeviceProperties
+ -> CUFun (a -> a -> a)
+ -> CUTranslSkel
+mkScanIntervals dir dev (CULam _ (CULam use0 (CUBody (CUExp env combine)))) =
+ CUTranslSkel name [cunit|
+ extern "C"
+ __global__ void
+ $id:name
+ (
+ $params:argOut,
+ $params:argIn0,
+ typename Ix interval_size,
+ const typename Ix num_intervals,
+ const typename Ix num_elements
+ )
+ {
+ $decls:smem
+ $decls:decl1
+ $decls:decl0
+
+ const int start = blockIdx.x * interval_size;
+ const int end = min(start + interval_size, num_elements);
+ interval_size = end - start;
+
+ int carry_in = false;
+
+ for (int i = threadIdx.x; i < interval_size; i += blockDim.x)
+ {
+ const int j = $id:(if dir == L then "start + i" else "end - i - 1");
+ $stms:(x0 .=. getIn0 "j")
+
+ /*
+ * Carry in the result from the previous block of this interval, stored in x1
+ */
+ if ( threadIdx.x == 0 && carry_in ) {
+ $decls:env
+ $stms:(x0 .=. combine)
+ }
+
+ /*
+ * Store our input into shared memory and perform a cooperative
+ * inclusive left scan.
+ */
+ $stms:(sdata "threadIdx.x" .=. x0)
+ __syncthreads();
+
+ $stms:(scanBlock dev elt Nothing (cvar "blockDim.x") sdata env combine)
+
+ /*
+ * Store the final result of this block in x1, to be carried in to the next
+ * block. After the last block, x1 is written out as the interval sum.
+ */
+ if ( threadIdx.x == 0 ) {
+ const int last = min(interval_size - i, blockDim.x) - 1;
+ $stms:(x1 .=. sdata "last")
+ }
+ carry_in = true;
+ }
+
+ /*
+ * Finally, the first thread writes the result for this interval.
+ */
+ if ( threadIdx.x == 0 ) {
+ $stms:(setOut "blockIdx.x" x1)
+ }
+ }
+ |]
+ where
+ name = "scan" ++ show dir ++ "Itv"
+ elt = eltType (undefined :: a)
+ (argIn0, x0, decl0, getIn0, _) = getters 0 elt use0
+ (argOut, _, setOut) = setters elt
+ (x1, decl1) = locals "x1" elt
+ (smem, sdata) = shared 0 Nothing [cexp| blockDim.x |] elt
+
+
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
View
120 Data/Array/Accelerate/CUDA/Compile.hs
@@ -30,21 +30,21 @@ import Data.Array.Accelerate.CUDA.AST
import Data.Array.Accelerate.CUDA.State
import Data.Array.Accelerate.CUDA.CodeGen
import Data.Array.Accelerate.CUDA.Array.Sugar
-import Data.Array.Accelerate.CUDA.Analysis.Launch
-import Data.Array.Accelerate.CUDA.FullList as FL
-import Data.Array.Accelerate.CUDA.Persistent as KT
-import qualified Data.Array.Accelerate.CUDA.Debug as D
+import Data.Array.Accelerate.CUDA.Persistent as KT
+import qualified Data.Array.Accelerate.CUDA.FullList as FL
+import qualified Data.Array.Accelerate.CUDA.Analysis.Launch as Analysis
+import qualified Data.Array.Accelerate.CUDA.Debug as D
-- libraries
import Numeric
-import Prelude hiding ( exp, catch )
+import Prelude hiding ( exp, catch, scanl, scanr )
import Control.Applicative hiding ( Const )
import Control.Exception
import Control.Monad
import Control.Monad.Trans
import Crypto.Hash.MD5 ( hashlazy )
import Data.Label.PureM
-import Data.List
+import Data.List ( intercalate )
import Data.Maybe
import Data.Monoid
import System.Directory
@@ -61,6 +61,7 @@ import qualified Data.Text.Lazy.IO as T
import qualified Data.Text.Lazy.Encoding as T
import qualified Foreign.CUDA.Driver as CUDA
import qualified Foreign.CUDA.Analysis as CUDA
+import Foreign.CUDA.Analysis ( DeviceProperties, Occupancy )
#ifdef VERSION_unix
import System.Posix.Process
@@ -102,7 +103,7 @@ prepareAcc rootAcc = traverseAcc rootAcc
let exec :: (AccBindings aenv, PreOpenAcc ExecOpenAcc aenv a) -> CIO (ExecOpenAcc aenv a)
exec (var, eacc) = do
- kernel <- build acc var
+ kernel <- buildOpenAcc acc var
return $ ExecAcc (FL.singleton () kernel) var eacc
node :: (AccBindings aenv, PreOpenAcc ExecOpenAcc aenv a) -> CIO (ExecOpenAcc aenv a)
@@ -151,43 +152,18 @@ prepareAcc rootAcc = traverseAcc rootAcc
Fold1 f a -> exec =<< liftA2 Fold1 <$> travF f <*> travA a
FoldSeg f e a s -> exec =<< liftA4 FoldSeg <$> travF f <*> travE e <*> travA a <*> travA (segments s)
Fold1Seg f a s -> exec =<< liftA3 Fold1Seg <$> travF f <*> travA a <*> travA (segments s)
+ Scanl f e a -> scanl f =<< liftA3 Scanl <$> travF f <*> travE e <*> travA a
+ Scanl' f e a -> scanl f =<< liftA3 Scanl' <$> travF f <*> travE e <*> travA a
+ Scanl1 f a -> scanl f =<< liftA2 Scanl1 <$> travF f <*> travA a
+ Scanr f e a -> scanr f =<< liftA3 Scanr <$> travF f <*> travE e <*> travA a
+ Scanr' f e a -> scanr f =<< liftA3 Scanr' <$> travF f <*> travE e <*> travA a
+ Scanr1 f a -> scanr f =<< liftA2 Scanr1 <$> travF f <*> travA a
Permute f a g b -> exec =<< liftA4 Permute <$> travF f <*> travA a <*> travF g <*> travA b
Backpermute e f a -> exec =<< liftA3 Backpermute <$> travE e <*> travF f <*> travA a
Stencil f b a -> exec =<< liftA2 (flip Stencil b) <$> travF f <*> travA a
Stencil2 f b1 a1 b2 a2 -> exec =<< liftA3 stencil2 <$> travF f <*> travA a1 <*> travA a2
where stencil2 f' a1' a2' = Stencil2 f' b1 a1' b2 a2'
- -- TODO: write helper functions to clean these up
- Scanl f e a -> do
- ExecAcc (FL _ scan _) var eacc <- exec =<< liftA3 Scanl <$> travF f <*> travE e <*> travA a
- add <- build (OpenAcc (Fold1 f mat)) var
- return $ ExecAcc (cons () add $ FL.singleton () scan) var eacc
-
- Scanl' f e a -> do
- ExecAcc (FL _ scan _) var eacc <- exec =<< liftA3 Scanl' <$> travF f <*> travE e <*> travA a
- add <- build (OpenAcc (Fold1 f mat)) var
- return $ ExecAcc (cons () (retag add) $ FL.singleton () scan) var eacc
-
- Scanl1 f a -> do
- ExecAcc (FL _ scan1 _) var eacc <- exec =<< liftA2 Scanl1 <$> travF f <*> travA a
- add <- build (OpenAcc (Fold1 f mat)) var
- return $ ExecAcc (cons () add $ FL.singleton () scan1) var eacc
-
- Scanr f e a -> do
- ExecAcc (FL _ scan _) var eacc <- exec =<< liftA3 Scanr <$> travF f <*> travE e <*> travA a
- add <- build (OpenAcc (Fold1 f mat)) var
- return $ ExecAcc (cons () add $ FL.singleton () scan) var eacc
-
- Scanr' f e a -> do
- ExecAcc (FL _ scan _) var eacc <- exec =<< liftA3 Scanr' <$> travF f <*> travE e <*> travA a
- add <- build (OpenAcc (Fold1 f mat)) var
- return $ ExecAcc (cons () (retag add) $ FL.singleton () scan) var eacc
-
- Scanr1 f a -> do
- ExecAcc (FL _ scan1 _) var eacc <- exec =<< liftA2 Scanr1 <$> travF f <*> travA a
- add <- build (OpenAcc (Fold1 f mat)) var
- return $ ExecAcc (cons () add $ FL.singleton () scan1) var eacc
-
where
travA :: OpenAcc aenv' a' -> CIO (AccBindings aenv', ExecOpenAcc aenv' a')
travA a = pure <$> traverseAcc a
@@ -210,11 +186,34 @@ prepareAcc rootAcc = traverseAcc rootAcc
Tuple (NilTup `SnocTup` Var (SuccIdx ZeroIdx)
`SnocTup` Var ZeroIdx))))
- mat :: Elt e => OpenAcc aenv (Array DIM2 e)
- mat = OpenAcc $ Use ((), Array (((),0),0) undefined)
-
- noKernel :: FullList () (AccKernel a)
- noKernel = FL () (INTERNAL_ERROR(error) "compile" "no kernel module for this node") Nil
+ -- Special versions of 'exec' for left and right scans, which include
+ -- the first phase upsweep kernel required for multi-block scans.
+ --
+ -- Using 'acc' for the launch and occupancy configuration of the upsweep
+ -- operation, which is similar to an inclusive scan, is close enough.
+ --
+ scanl :: Fun aenv (e -> e -> e) -- hmm?
+ -> (AccBindings aenv, PreOpenAcc ExecOpenAcc aenv a)
+ -> CIO (ExecOpenAcc aenv a)
+ scanl f (var, eacc) = do
+ scan <- buildOpenAcc acc var
+ upsweep <- build (\k d -> compile k d (codegenScanlIntervals d f var))
+ (Analysis.launchConfig acc)
+ (Analysis.determineOccupancy acc)
+ return $! ExecAcc (FL.cons () upsweep $ FL.singleton () scan) var eacc
+
+ scanr :: Fun aenv (e -> e -> e)
+ -> (AccBindings aenv, PreOpenAcc ExecOpenAcc aenv a)
+ -> CIO (ExecOpenAcc aenv a)
+ scanr f (var, eacc) = do
+ scan <- buildOpenAcc acc var
+ upsweep <- build (\k d -> compile k d (codegenScanrIntervals d f var))
+ (Analysis.launchConfig acc)
+ (Analysis.determineOccupancy acc)
+ return $! ExecAcc (FL.cons () upsweep $ FL.singleton () scan) var eacc
+
+ noKernel :: FL.FullList () (AccKernel a)
+ noKernel = FL.FL () (INTERNAL_ERROR(error) "compile" "no kernel module for this node") FL.Nil
-- Traverse a scalar expression
--
@@ -271,17 +270,26 @@ liftA4 f a b c d = f <$> a <*> b <*> c <*> d
-- evaluates and blocks on the external compiler only once the compiled object
-- is truly needed.
--
-build :: OpenAcc aenv a -> AccBindings aenv -> CIO (AccKernel a)
-build acc fvar = do
+buildOpenAcc :: OpenAcc aenv a -> AccBindings aenv -> CIO (AccKernel a)
+buildOpenAcc acc fvar =
+ build (\table dev -> compileOpenAcc table dev acc fvar)
+ (Analysis.launchConfig acc)
+ (Analysis.determineOccupancy acc)
+
+build :: (KernelTable -> DeviceProperties -> CIO (String, KernelKey))
+ -> (DeviceProperties -> Occupancy -> (Int, Int -> Int, Int))
+ -> (DeviceProperties -> CUDA.Fun -> Int -> IO Occupancy)
+ -> CIO (AccKernel a)
+build compile' launchConfig determineOccupancy = do
dev <- gets deviceProps
table <- gets kernelTable
- (entry,key) <- compile table dev acc fvar
- let (cta,blocks,smem) = launchConfig acc dev occ
+ (entry,key) <- compile' table dev
+ let (cta,blocks,smem) = launchConfig dev occ
(mdl,fun,occ) = unsafePerformIO $ do
m <- link table key
f <- CUDA.getFun m entry
l <- CUDA.requires f CUDA.MaxKernelThreadsPerBlock
- o <- determineOccupancy acc dev f l
+ o <- determineOccupancy dev f l
D.when D.dump_cc (stats entry f o)
return (m,f,o)
--
@@ -300,7 +308,8 @@ build acc fvar = do
++ shows (CUDA.activeWarps occ) " warps in "
++ shows (CUDA.activeThreadBlocks occ) " blocks"
--
- -- make sure kernel/stats are printed together
+ -- make sure kernel/stats are printed together. Use 'intercalate' rather
+ -- than 'unlines' to avoid a trailing newline.
--
message $ intercalate "\n" [msg1, " ... " ++ msg2]
@@ -358,12 +367,12 @@ link table key =
-- Generate and compile code for a single open array expression
--
-compile :: KernelTable
- -> CUDA.DeviceProperties
- -> OpenAcc aenv a
- -> AccBindings aenv
- -> CIO (String, KernelKey)
-compile table dev acc fvar = do
+compileOpenAcc :: KernelTable -> CUDA.DeviceProperties -> OpenAcc aenv a -> AccBindings aenv -> CIO (String, KernelKey)
+compileOpenAcc table dev acc fvar
+ = compile table dev (codegenAcc dev acc fvar)
+
+compile :: KernelTable -> CUDA.DeviceProperties -> CUTranslSkel -> CIO (String, KernelKey)
+compile table dev cunit = do
exists <- isJust `fmap` liftIO (KT.lookup table key)
unless exists $ do
message $ unlines [ show key, T.unpack code ]
@@ -379,7 +388,6 @@ compile table dev acc fvar = do
--
return (entry, key)
where
- cunit = codegenAcc dev acc fvar
entry = show cunit
key = (CUDA.computeCapability dev, hashlazy (T.encodeUtf8 code) )
code = displayLazyText . renderCompact $ ppr cunit
View
33 Data/Array/Accelerate/CUDA/Execute.hs
@@ -433,8 +433,10 @@ scanOp
-> Val aenv
-> Vector e
-> CIO (Vector e)
-scanOp dir (FL _ kfold1' (Cons _ kscan Nil)) bindings aenv (Array sh0 in0) = do
- let (_,num_intervals,_) = configure kscan num_elements
+scanOp dir (FL _ kupsweep (Cons _ kscan Nil)) bindings aenv (Array sh0 in0) = do
+ let num_elements = size sh0
+ (_,num_intervals,_) = configure kscan num_elements
+ --
a_out@(Array _ out) <- allocateArray (Z :. num_elements + 1)
(Array _ blk) <- allocateArray (Z :. num_intervals) :: CIO (Vector e)
d_out <- devicePtrsOfArrayData out
@@ -459,7 +461,7 @@ scanOp dir (FL _ kfold1' (Cons _ kscan Nil)) bindings aenv (Array sh0 in0) = do
message $ "scan phase 1: interval_size = " ++ shows interval_size
", num_intervals = " ++ shows num_intervals
", num_elements = " ++ show num_elements
- execute kfold1 bindings aenv num_elements
+ execute kupsweep bindings aenv num_elements
((((((), blk)
, in0)
, convertIx interval_size)
@@ -489,10 +491,6 @@ scanOp dir (FL _ kfold1' (Cons _ kscan Nil)) bindings aenv (Array sh0 in0) = do
, convertIx interval_size)
, convertIx num_elements)
return a_out
- where
- num_elements = size sh0
- kfold1 = retag kfold1' :: AccKernel (Vector e)
--- kscan1 = retag kscan1' :: AccKernel (Vector e)
scanOp _ _ _ _ _ = error "I'll just pretend to hug you until you get here."
@@ -504,8 +502,10 @@ scan'Op
-> Val aenv
-> Vector e
-> CIO (Vector e, Scalar e)
-scan'Op (FL _ kfold1' (Cons _ kscan Nil)) bindings aenv (Array sh0 in0) = do
- let (_,num_intervals,_) = configure kscan num_elements
+scan'Op (FL _ kupsweep (Cons _ kscan Nil)) bindings aenv (Array sh0 in0) = do
+ let num_elements = size sh0
+ (_,num_intervals,_) = configure kscan num_elements
+ --
(Array _ blk) <- allocateArray (Z :. num_intervals) :: CIO (Vector e)
a_out@(Array _ out) <- allocateArray (Z :. num_elements)
a_sum@(Array _ sum) <- allocateArray Z
@@ -516,7 +516,7 @@ scan'Op (FL _ kfold1' (Cons _ kscan Nil)) bindings aenv (Array sh0 in0) = do
message $ "scan phase 1: interval_size = " ++ shows interval_size
", num_intervals = " ++ shows num_intervals
", num_elements = " ++ show num_elements
- execute kfold1 bindings aenv num_elements
+ execute kupsweep bindings aenv num_elements
((((((), blk)
, in0)
, convertIx interval_size)
@@ -541,9 +541,6 @@ scan'Op (FL _ kfold1' (Cons _ kscan Nil)) bindings aenv (Array sh0 in0) = do
, convertIx interval_size)
, convertIx num_elements)
return (a_out, a_sum)
- where
- num_elements = size sh0
- kfold1 = retag kfold1' :: AccKernel (Vector e)
scan'Op _ _ _ _ = error "If I promise not to kill you, can I have a hug?"
@@ -555,8 +552,10 @@ scan1Op
-> Val aenv
-> Vector e
-> CIO (Vector e)
-scan1Op (FL _ kfold1' (Cons _ kscan1 Nil)) bindings aenv (Array sh0 in0) = do
- let (_,num_intervals,_) = configure kscan1 num_elements
+scan1Op (FL _ kupsweep (Cons _ kscan1 Nil)) bindings aenv (Array sh0 in0) = do
+ let num_elements = size sh0
+ (_,num_intervals,_) = configure kscan1 num_elements
+ --
(Array _ sum) <- allocateArray Z :: CIO (Scalar e)
(Array _ blk) <- allocateArray (Z :. num_intervals) :: CIO (Vector e)
a_out@(Array _ out) <- allocateArray (Z :. num_elements)
@@ -567,7 +566,7 @@ scan1Op (FL _ kfold1' (Cons _ kscan1 Nil)) bindings aenv (Array sh0 in0) = do
message $ "scan phase 1: interval_size = " ++ shows interval_size
", num_intervals = " ++ shows num_intervals
", num_elements = " ++ show num_elements
- execute kfold1 bindings aenv num_elements
+ execute kupsweep bindings aenv num_elements
((((((), blk)
, in0)
, convertIx interval_size)
@@ -593,8 +592,6 @@ scan1Op (FL _ kfold1' (Cons _ kscan1 Nil)) bindings aenv (Array sh0 in0) = do
, convertIx num_elements)
return a_out
where
- num_elements = size sh0
- kfold1 = retag kfold1' :: AccKernel (Vector e)
scan1Op _ _ _ _ = error "If you get wet, you'll get sick."

0 comments on commit aad9233

Please sign in to comment.