Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
04b5912
commit 79ed57e
Showing
14 changed files
with
525 additions
and
136 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
|
||
{-# LANGUAGE FlexibleContexts #-} | ||
{-# LANGUAGE TypeOperators #-} | ||
{-# LANGUAGE ScopedTypeVariables #-} | ||
{-# LANGUAGE ConstraintKinds #-} | ||
{-# LANGUAGE GADTs #-} | ||
{-# LANGUAGE RankNTypes #-} | ||
|
||
|
||
module SortNet where | ||
|
||
import Obsidian | ||
|
||
import Data.Word | ||
|
||
import Prelude hiding (zip, reverse ) | ||
import qualified Prelude as P | ||
|
||
|
||
--------------------------------------------------------------------------- | ||
-- basics | ||
--------------------------------------------------------------------------- | ||
-- | Split the array with 'halve', 'zip' the two halves together and flatten
-- the resulting array of pairs again with 'unpair'.
riffle :: (Data a, ASize s) => Pull s a -> Pull s a
riffle arr = unpair (zip front back)
  where
    (front, back) = halve arr
|
||
-- | Variant of the riffle that stops at the paired stage: the two halves
-- produced by 'halve' are zipped and the array of pairs is converted to a
-- push array with 'push' (no 'unpair' step).
riffle' :: (t *<=* Block, Data a, ASize s) => Pull s a -> Push t s (a,a)
riffle' arr = push (zip front back)
  where
    (front, back) = halve arr
--------------------------------------------------------------------------- | ||
-- compare and swap | ||
--------------------------------------------------------------------------- | ||
-- | Compare-and-swap on a pair of expressions: when the second component
-- is '<*' the first, the pair is returned swapped; otherwise it is
-- returned unchanged.
cmpswap :: (Scalar a, Ord a) => (Exp a,Exp a) -> (Exp a,Exp a)
cmpswap (x, y) = ifThenElse outOfOrder (y, x) (x, y)
  where
    outOfOrder = y <* x
|
||
|
||
--------------------------------------------------------------------------- | ||
-- Merger | ||
--------------------------------------------------------------------------- | ||
|
||
-- | Run @logBaseI 2 (len arr)@ rounds (via 'rep') of the same stage over
-- the input.  One stage riffles the array, pairs neighbouring elements,
-- applies @cmp@ to every pair, and flattens the pairs again; the result of
-- each stage is materialised with 'compute' before the next one starts.
shex :: (Compute t, Data a)
     => ((a,a) -> (a,a)) -> SPull a -> SPush t a
shex cmp arr = exec (rep stages (compute . stage) arr)
  where
    -- number of rounds, derived from the input length
    stages = logBaseI 2 (len arr)
    -- one riffle / pair / compare / unpair round
    stage = unpairP . push . fmap cmp . pair . riffle
|
||
-- | Like 'shex', but the second half of the input is reversed (and
-- appended back onto the first half) before the rounds begin.  The number
-- of rounds is still @logBaseI 2 (len arr)@ of the original array.
shexRev :: (Compute t, Data a)
        => ((a,a) -> (a,a)) -> SPull a -> SPush t a
shexRev cmp arr = exec (rep stages (compute . stage) flipped)
  where
    stages = logBaseI 2 (len arr)
    (front, back) = halve arr
    -- first half unchanged, second half reversed
    flipped = front `append` reverse back
    -- one riffle / pair / compare / unpair round
    stage = unpairP . push . fmap cmp . pair . riffle
|
||
-- | Push-array flavour of 'shexRev'.  The two halves (the second one
-- reversed) are converted with 'push' and appended as push arrays; that
-- array is materialised once with 'compute', and then
-- @logBaseI 2 (len arr)@ rounds are run, each round being
-- @unpairP . fmap cmp . riffle'@ followed by 'compute'.
shexRev' :: (Array (Push t), Compute t, Data a)
         => ((a,a) -> (a,a)) -> SPull a -> SPush t a
shexRev' cmp arr =
  exec $ compute flipped >>= rep stages (compute . stage)
  where
    stages = logBaseI 2 (len arr)
    (front, back) = halve arr
    -- first half unchanged, second half reversed, both as push arrays
    flipped = push front `append` push (reverse back)
    -- one round: riffle into pairs, compare each pair, flatten
    stage = unpairP . fmap cmp . riffle'
|
||
|
||
|
||
|
||
|
||
--------------------------------------------------------------------------- | ||
-- Sorter | ||
--------------------------------------------------------------------------- | ||
|
||
{- | ||
-- | -- | -- | -- | ||
-- | -- | -- | -- | ||
| | | ||
-- | -- | -- | -- | ||
-- | -- | -- | -- | ||
| | ||
-- | -- | -- | -- | ||
-- | -- | -- | -- | ||
| | | ||
-- | -- | -- | -- | ||
-- | -- | -- | -- | ||
-} | ||
|
||
-- | Block-level kernel body: 'divConq' driven by 'shexRev'' with 'cmpswap'
-- as the pairwise operation.
--
-- The argument is written as an explicit lambda so that it has exactly the
-- shape of 'divConq''s rank-2 parameter.
test :: forall a . (Scalar a, Ord a) => SPull (Exp a) -> SPush Block (Exp a)
test = divConq (\chunk -> shexRev' cmpswap chunk)
|
||
-- | Grid-level entry point: cut the input into pieces of 1024 elements
-- with 'splitUp' and run 'test' on each piece through 'asGridMap'.
mapTest :: (Scalar a, Ord a) => DPull (Exp a) -> DPush Grid (Exp a)
mapTest = asGridMap test . splitUp 1024
|
||
-- What does this mean ? | ||
-- | Apply @f@ to progressively larger sub-arrays of the input.
--
-- With @logLen = logBaseI 2 (len arr)@, the recursion starts at
-- @n = logLen - 1@, where the input is split into sub-arrays of
-- @2^(logLen - n) = 2@ elements, and counts @n@ down by one per step, so
-- the sub-array size doubles each time.  At every step @f@ is mapped over
-- the sub-arrays with 'asBlockMap' and the result is materialised with
-- 'compute'.  When @n@ reaches 0, @f@ is applied once to the whole array
-- at 'Block' level.
--
-- @f@ has a rank-2 type because it is instantiated at three different
-- hierarchy levels here: 'Thread' while the sub-array size is at most 32,
-- 'Warp' above that, and 'Block' for the final step.
--
-- NOTE(review): for an input whose @logLen@ is 0 the starting index
-- @logLen - 1@ never matches the @0@ case directly -- confirm callers
-- always pass at least two elements.
divConq :: forall a . Data a
           => (SPull a -> forall t . (Array (Push t), Compute t) => SPush t a)
           -> SPull a -> SPush Block a
divConq f arr = execBlock (doIt (logLen - 1) arr)
  where
    logLen = logBaseI 2 (len arr)

    -- Final step: one block-level application over the whole array.
    doIt 0 whole = return ((f :: SPull a -> SPush Block a) whole)

    -- Earlier variant that used Warp-level sub-problems at every step:
    -- doIt n arr = do
    --   arr' <- compute $ asBlockMap (f :: SPull a -> SPush Warp a)
    --                   $ splitUp (2^(logLen - n)) arr
    --   doIt (n - 1) arr'

    doIt n cur
      -- sub-arrays larger than 32 elements: map f at Warp level
      | 2 ^ (logLen - n) > 32 =
          compute (asBlockMap (f :: SPull a -> SPush Warp a)
                              (splitUp (2 ^ (logLen - n)) cur))
            >>= doIt (n - 1)
      -- sub-arrays of at most 32 elements: map f at Thread level
      | otherwise =
          compute (asBlockMap (f :: SPull a -> SPush Thread a)
                              (splitUp (2 ^ (logLen - n)) cur))
            >>= doIt (n - 1)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
{-# LANGUAGE ScopedTypeVariables #-} | ||
module Main where | ||
|
||
import SortNet | ||
|
||
import Obsidian | ||
import Obsidian.Run.CUDA.Exec | ||
|
||
|
||
import qualified Data.Vector.Storable as V | ||
import Control.Monad.State | ||
|
||
import Data.Int | ||
import Data.Word | ||
|
||
import System.Exit | ||
|
||
import Data.Time.Clock | ||
|
||
|
||
-- | Benchmark driver.
--
-- Captures 'mapTest' as a kernel, builds a 1024-element 'Word32' input
-- (random values reduced modulo 64), launches the kernel 1000 times, reads
-- the output back, prints timing lines plus the result, and finally exits
-- with success or failure depending on whether the result 'isSorted'.
perform :: IO ()
perform = withCUDA $ do
  -- Time the kernel capture step.
  compileStart <- lift getCurrentTime
  kern <- capture 32 mapTest
  compileEnd <- lift getCurrentTime

  (rawInputs :: V.Vector Word32) <- lift $ mkRandomVec 1024
  let inputs = V.map (`mod` 64) rawInputs

  transferStart <- lift getCurrentTime
  useVector inputs $ \i ->
    allocaVector 1024 $ \o -> do
      transferDone <- lift getCurrentTime

      -- Timed section: 1000 launches, each followed by syncAll.
      runStart <- lift getCurrentTime
      forM_ [0..999] $ \_ -> do
        o <== (1,kern) <> i
        syncAll
      runEnd <- lift getCurrentTime

      r <- peekCUDAVector o

      readBackEnd <- lift getCurrentTime

      lift . putStrLn $ "SELFTIMED: " ++ show (diffUTCTime runEnd runStart)
      -- lift $ putStrLn $ "CYCLES: " ++ show (cnt1 - cnt0)

      lift . putStrLn $
        "COMPILATION_TIME: " ++ show (diffUTCTime compileEnd compileStart)

      -- lift $ putStrLn $ "BYTES_TO_DEVICE: " ++ show (fromIntegral (blcks * elts) * sizeOf (undefined :: EWord32))
      -- lift $ putStrLn $ "BYTES_FROM_DEVICE: " ++ show (fromIntegral blcks * sizeOf (undefined :: EWord32))
      lift . putStrLn $
        "TRANSFER_TO_DEVICE: " ++ show (diffUTCTime transferDone transferStart)
      lift . putStrLn $
        "TRANSFER_FROM_DEVICE: " ++ show (diffUTCTime readBackEnd runEnd)

      -- lift $ putStrLn $ "ELEMENTS_PROCESSED: " ++ show (fromIntegral (blcks * elts))
      -- lift $ putStrLn $ "NUMBER_OF_BLOCKS: " ++ show (fromIntegral blcks)
      -- lift $ putStrLn $ "ELEMENTS_PER_BLOCK: " ++ show (fromIntegral elts)

      lift (print r)

      -- Report the verdict and terminate with the matching exit code.
      lift $
        if isSorted r
          then putStrLn "Success" >> exitSuccess
          else putStrLn "Failure" >> exitFailure
|
||
|
||
-- | Check that a list is in non-decreasing order.
--
-- Empty and single-element lists count as sorted.  Each adjacent pair is
-- compared exactly once, so the check is linear in the length of the list
-- (recomputing the 'minimum' of every tail, as an element-versus-rest
-- formulation does, is quadratic).
isSorted :: Ord a => [a] -> Bool
isSorted xs = and (zipWith (<=) xs (drop 1 xs))
|
||
-- | Program entry point; all of the work is done by 'perform'.
main :: IO ()
main = perform
Oops, something went wrong.