Skip to content

Commit

Permalink
codegen: split body and header bindings generation
Browse files Browse the repository at this point in the history
  • Loading branch information
tmcdonell committed Aug 8, 2012
1 parent 1a22a1f commit 70914eb
Showing 1 changed file with 130 additions and 100 deletions.
230 changes: 130 additions & 100 deletions Data/Array/Accelerate/CUDA/CodeGen.hs
Expand Up @@ -75,96 +75,141 @@ prj _ _ = INTERNAL_ERROR(error) "prj" "inconsistent valuat
--
-- TODO: include a measure of how much shared memory a kernel requires.
--
codegenAcc :: forall aenv a.
CUDA.DeviceProperties
-> OpenAcc aenv a
-> AccBindings aenv
-> CUTranslSkel
codegenAcc dev acc (AccBindings vars) = CUTranslSkel entry (extras : fvars code)
codegenAcc :: CUDA.DeviceProperties -> OpenAcc aenv a -> AccBindings aenv -> CUTranslSkel
codegenAcc dev acc var
= codegenBindEnv var
$ codegenOpenAcc dev acc


-- Add the include statement for the standard Accelerate headers, and any
-- binding points required for environment arrays accessed by scalar functions.
--
codegenBindEnv :: AccBindings aenv -> CUTranslSkel -> CUTranslSkel
codegenBindEnv (AccBindings vars) (CUTranslSkel entry code) =
CUTranslSkel entry (headers : fvars code)
where
fvars rest = Set.foldr (\v vs -> liftAcc acc v ++ vs) rest vars
extras = [cedecl| $esc:("#include <accelerate_cuda_extras.h>") |]
CUTranslSkel entry code = codegen acc
headers = [cedecl| $esc:("#include <accelerate_cuda_extras.h>") |]
fvars rest = Set.foldr (\v vs -> liftAcc v ++ vs) rest vars

codegen :: OpenAcc aenv a -> CUTranslSkel
codegen (OpenAcc pacc) = case pacc of
--
-- Non-computation forms
--
Alet _ _ -> internalError
Avar _ -> internalError
Apply _ _ -> internalError
Acond _ _ _ -> internalError
Atuple _ -> internalError
Aprj _ _ -> internalError
Use _ -> internalError
Unit _ -> internalError
Reshape _ _ -> internalError
-- Generate binding points (texture references and shapes) for arrays lifted
-- from scalar expressions
--
liftAcc :: ArrayVar aenv -> [C.Definition]
liftAcc (ArrayVar idx) =
let avar = OpenAcc (Avar idx)
idx' = show $ idxToInt idx
sh = cshape ("sh" ++ idx') (accDim avar)
ty = accTypeTex avar
arr n = "avar" ++ idx' ++ "_a" ++ show (n::Int)
in
sh : zipWith (\t n -> cglobal t (arr n)) (reverse ty) [0..]


-- Generate code for a single __device__ entry function.
--
codegenOpenAcc :: forall aenv a. CUDA.DeviceProperties -> OpenAcc aenv a -> CUTranslSkel
codegenOpenAcc dev acc@(OpenAcc pacc) = case pacc of
--
-- Non-computation forms -> sadness.
--
Alet _ _ -> internalError
Avar _ -> internalError
Apply _ _ -> internalError
Acond _ _ _ -> internalError
Atuple _ -> internalError
Aprj _ _ -> internalError
Use _ -> internalError
Unit _ -> internalError
Reshape _ _ -> internalError

--
-- Skeleton nodes
--
Generate _ f ->
mkGenerate (accDim acc) (codegenFun f)

Replicate sl _ a ->
mkReplicate dimSl dimOut (extend sl) (undefined :: a)
where
dimSl = accDim a
dimOut = accDim acc
--
-- Skeleton nodes
extend :: SliceIndex slix sl co dim -> CUExp dim
extend = CUExp [] . reverse . extend' 0

extend' :: Int -> SliceIndex slix sl co dim -> [C.Exp]
extend' _ (SliceNil) = []
extend' n (SliceAll sliceIdx) = mkPrj dimOut "dim" n : extend' (n+1) sliceIdx
extend' n (SliceFixed sliceIdx) = extend' (n+1) sliceIdx

Index sl a slix ->
mkSlice dimSl dimCo dimIn0 (restrict sl) (undefined :: a)
where
dimCo = length (expType slix)
dimSl = accDim acc
dimIn0 = accDim a
--
Generate _ f -> mkGenerate (accDim acc) (codegenFun f)

Replicate sl _ a -> mkReplicate dimSl dimOut (extend sl) (undefined :: a)
where
dimSl = accDim a
dimOut = accDim acc
--
extend :: SliceIndex slix sl co dim -> CUExp dim
extend = CUExp [] . reverse . extend' 0

extend' :: Int -> SliceIndex slix sl co dim -> [C.Exp]
extend' _ (SliceNil) = []
extend' n (SliceAll sliceIdx) = mkPrj dimOut "dim" n : extend' (n+1) sliceIdx
extend' n (SliceFixed sliceIdx) = extend' (n+1) sliceIdx

Index sl a slix -> mkSlice dimSl dimCo dimIn0 (restrict sl) (undefined :: a)
where
dimCo = length (expType slix)
dimSl = accDim acc
dimIn0 = accDim a
--
restrict :: SliceIndex slix sl co dim -> CUExp slix
restrict = CUExp [] . reverse . restrict' (0,0)

restrict' :: (Int,Int) -> SliceIndex slix sl co dim -> [C.Exp]
restrict' _ (SliceNil) = []
restrict' (m,n) (SliceAll sliceIdx) = mkPrj dimSl "sl" n : restrict' (m,n+1) sliceIdx
restrict' (m,n) (SliceFixed sliceIdx) = mkPrj dimCo "co" m : restrict' (m+1,n) sliceIdx

Map f _ -> mkMap (codegenFun f)
ZipWith f _ _ -> mkZipWith (accDim acc) (codegenFun f)

Fold f e _ ->
if accDim acc == 0
then mkFoldAll dev (codegenFun f) (Just (codegenExp e))
else mkFold dev (codegenFun f) (Just (codegenExp e))

Fold1 f _ ->
if accDim acc == 0
then mkFoldAll dev (codegenFun f) Nothing
else mkFold dev (codegenFun f) Nothing

FoldSeg f e _ s -> mkFoldSeg dev (accDim acc) (segmentsType s) (codegenFun f) (Just (codegenExp e))
Fold1Seg f _ s -> mkFoldSeg dev (accDim acc) (segmentsType s) (codegenFun f) Nothing

Scanl f e _ -> mkScanl dev (codegenFun f) (Just (codegenExp e))
Scanl' f e _ -> mkScanl dev (codegenFun f) (Just (codegenExp e))
Scanl1 f _ -> mkScanl dev (codegenFun f) Nothing

Scanr f e _ -> mkScanr dev (codegenFun f) (Just (codegenExp e))
Scanr' f e _ -> mkScanr dev (codegenFun f) (Just (codegenExp e))
Scanr1 f _ -> mkScanr dev (codegenFun f) Nothing

Permute f _ ix a -> mkPermute dev (accDim acc) (accDim a) (codegenFun f) (codegenFun ix)
Backpermute _ f a -> mkBackpermute (accDim acc) (accDim a) (codegenFun f) (undefined :: a)

Stencil f b0 a0 -> mkStencil (accDim acc) (codegenFun f) (codegenBoundary a0 b0) (undefined :: a)
Stencil2 f b1 a1 b0 a0
-> mkStencil2 (accDim acc) (codegenFun f) (codegenBoundary a1 b1) (codegenBoundary a0 b0) (undefined :: a)
restrict :: SliceIndex slix sl co dim -> CUExp slix
restrict = CUExp [] . reverse . restrict' (0,0)

--
restrict' :: (Int,Int) -> SliceIndex slix sl co dim -> [C.Exp]
restrict' _ (SliceNil) = []
restrict' (m,n) (SliceAll sliceIdx) = mkPrj dimSl "sl" n : restrict' (m,n+1) sliceIdx
restrict' (m,n) (SliceFixed sliceIdx) = mkPrj dimCo "co" m : restrict' (m+1,n) sliceIdx

Map f _ ->
mkMap (codegenFun f)

ZipWith f _ _ ->
mkZipWith (accDim acc) (codegenFun f)

Fold f e _ ->
if accDim acc == 0
then mkFoldAll dev (codegenFun f) (Just (codegenExp e))
else mkFold dev (codegenFun f) (Just (codegenExp e))

Fold1 f _ ->
if accDim acc == 0
then mkFoldAll dev (codegenFun f) Nothing
else mkFold dev (codegenFun f) Nothing

FoldSeg f e _ s ->
mkFoldSeg dev (accDim acc) (segmentsType s) (codegenFun f) (Just (codegenExp e))

Fold1Seg f _ s ->
mkFoldSeg dev (accDim acc) (segmentsType s) (codegenFun f) Nothing

Scanl f e _ ->
mkScanl dev (codegenFun f) (Just (codegenExp e))

Scanl' f e _ ->
mkScanl dev (codegenFun f) (Just (codegenExp e))

Scanl1 f _ ->
mkScanl dev (codegenFun f) Nothing

Scanr f e _ ->
mkScanr dev (codegenFun f) (Just (codegenExp e))

Scanr' f e _ ->
mkScanr dev (codegenFun f) (Just (codegenExp e))

Scanr1 f _ ->
mkScanr dev (codegenFun f) Nothing

Permute f _ ix a ->
mkPermute dev (accDim acc) (accDim a) (codegenFun f) (codegenFun ix)

Backpermute _ f a ->
mkBackpermute (accDim acc) (accDim a) (codegenFun f) (undefined :: a)

Stencil f b0 a0 ->
mkStencil (accDim acc) (codegenFun f) (codegenBoundary a0 b0) (undefined :: a)

Stencil2 f b1 a1 b0 a0 ->
mkStencil2 (accDim acc) (codegenFun f) (codegenBoundary a1 b1)
(codegenBoundary a0 b0) (undefined :: a)
where
-- caffeine and misery
--
internalError =
Expand All @@ -175,19 +220,6 @@ codegenAcc dev acc (AccBindings vars) = CUTranslSkel entry (extras : fvars code)
in
INTERNAL_ERROR(error) "codegenAcc" msg

-- Generate binding points (texture references and shapes) for arrays lifted
-- from scalar expressions
--
liftAcc :: OpenAcc aenv a -> ArrayVar aenv -> [C.Definition]
liftAcc _ (ArrayVar idx) =
let avar = OpenAcc (Avar idx)
idx' = show $ idxToInt idx
sh = cshape ("sh" ++ idx') (accDim avar)
ty = accTypeTex avar
arr n = "avar" ++ idx' ++ "_a" ++ show (n::Int)
in
sh : zipWith (\t n -> cglobal t (arr n)) (reverse ty) [0..]

-- Shapes are still represented as C structs, so we need to generate field
-- indexing code for shapes
--
Expand All @@ -196,11 +228,10 @@ codegenAcc dev acc (AccBindings vars) = CUTranslSkel entry (extras : fvars code)
| ndim <= 1 = cvar var
| otherwise = [cexp| $exp:(cvar var) . $id:('a':show c) |]


-- code generation for stencil boundary conditions
--
codegenBoundary :: forall dim e. Sugar.Elt e
=> OpenAcc aenv (Sugar.Array dim e) {- dummy -}
=> OpenAcc aenv (Sugar.Array dim e) {- dummy -}
-> Boundary (Sugar.EltRepr e)
-> Boundary (CUExp e)
codegenBoundary _ Clamp = Clamp
Expand All @@ -211,7 +242,6 @@ codegenAcc dev acc (AccBindings vars) = CUTranslSkel entry (extras : fvars code)
$ codegenConst (Sugar.eltType (undefined::e)) c



-- Scalar Expressions
-- ------------------

Expand Down

0 comments on commit 70914eb

Please sign in to comment.