Skip to content

Commit

Permalink
sortnet implementation + bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
svenssonjoel committed Jan 22, 2015
1 parent 04b5912 commit 79ed57e
Show file tree
Hide file tree
Showing 14 changed files with 525 additions and 136 deletions.
2 changes: 2 additions & 0 deletions Benchmarks/ScanBench/Scan.hs
Expand Up @@ -3,6 +3,8 @@
ScopedTypeVariables#-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE GADTs #-}

module Scan where

Expand Down
2 changes: 2 additions & 0 deletions Examples/Intro/Intro.hs
Expand Up @@ -2,6 +2,8 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ConstraintKinds #-}

module Main where

Expand Down
2 changes: 2 additions & 0 deletions Examples/Simple/MatMul.hs
@@ -1,5 +1,7 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE GADTs #-}

module MatMul where

import Obsidian
Expand Down
3 changes: 3 additions & 0 deletions Examples/Simple/Reduction.hs
@@ -1,6 +1,9 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE GADTs #-}

module Reduction where

import Obsidian
Expand Down
26 changes: 26 additions & 0 deletions Examples/Simple/Simple.cabal
Expand Up @@ -185,3 +185,29 @@ executable Increment.exe
, rdtsc >= 1.3

default-language: Haskell2010

executable SortNetExec.exe
-- .hs or .lhs file containing the Main module.
main-is: SortNetExec.hs

other-extensions: ScopedTypeVariables
, FlexibleContexts
, MultiParamTypeClasses
, NoMonomorphismRestriction
, TypeOperators
, TypeSynonymInstances
, FlexibleInstances
, ConstraintKinds
, GADTs

-- Other library packages from which modules are imported.
build-depends: base >=4 && <5
, vector >=0.10.9.1
, mtl >=2.0
, random >=1.0 && <1.1
, bytestring >=0.10 && <0.11
, Obsidian >= 0.3.0.0
, rdtsc >= 1.3
, time >= 1.4.2

default-language: Haskell2010
120 changes: 120 additions & 0 deletions Examples/Simple/SortNet.hs
@@ -0,0 +1,120 @@

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}


module SortNet where

import Obsidian

import Data.Word

import Prelude hiding (zip, reverse )
import qualified Prelude as P


---------------------------------------------------------------------------
-- basics
---------------------------------------------------------------------------
riffle :: (Data a, ASize s) => Pull s a -> Pull s a
riffle = unpair . uncurry zip . halve

riffle' :: (t *<=* Block, Data a, ASize s) => Pull s a -> Push t s (a,a)
riffle' = push . uncurry zip . halve
---------------------------------------------------------------------------
-- compare and swap
---------------------------------------------------------------------------
cmpswap :: (Scalar a, Ord a) => (Exp a,Exp a) -> (Exp a,Exp a)
cmpswap (a,b) = ifThenElse (b <* a) (b,a) (a,b)


---------------------------------------------------------------------------
-- Merger
---------------------------------------------------------------------------

shex :: (Compute t, Data a)
=> ((a,a) -> (a,a)) -> SPull a -> SPush t a
shex cmp arr =
exec $ rep (logBaseI 2 (len arr)) (compute . core cmp) arr
where
core c = unpairP . push . fmap c . pair . riffle

shexRev :: (Compute t, Data a)
=> ((a,a) -> (a,a)) -> SPull a -> SPush t a
shexRev cmp arr =
let (arr1,arr2) = halve arr
arr2' = reverse arr2
arr' = arr1 `append` arr2'
in
exec $ rep (logBaseI 2 (len arr)) (compute . core cmp) arr'
where
core c = unpairP . push . fmap c . pair . riffle

shexRev' :: (Array (Push t), Compute t, Data a)
=> ((a,a) -> (a,a)) -> SPull a -> SPush t a
shexRev' cmp arr =
let (arr1,arr2) = halve arr
arr2' = reverse arr2
arr' = (push arr1) `append` (push arr2')
in
exec $ do
arr'' <- compute arr'
rep (logBaseI 2 (len arr)) (compute . core cmp) arr''
where
core c = unpairP . fmap c . riffle'





---------------------------------------------------------------------------
-- Sorter
---------------------------------------------------------------------------

{-
-- | -- | -- | --
-- | -- | -- | --
| |
-- | -- | -- | --
-- | -- | -- | --
|
-- | -- | -- | --
-- | -- | -- | --
| |
-- | -- | -- | --
-- | -- | -- | --
-}

test :: forall a . (Scalar a, Ord a) => SPull (Exp a) -> SPush Block (Exp a)
test = divConq $ shexRev' cmpswap

mapTest :: (Scalar a, Ord a) => DPull (Exp a) -> DPush Grid (Exp a)
mapTest arr = asGridMap test (splitUp 1024 arr)

-- What does this mean ?
divConq :: forall a . Data a
=> (SPull a -> forall t . (Array (Push t), Compute t) => SPush t a)
-> SPull a -> SPush Block a
divConq f arr = execBlock $ doIt (logLen - 1) arr
where logLen = logBaseI 2 (len arr)
doIt 0 arr = do
return $ (f :: SPull a -> SPush Block a) arr

-- doIt n arr = do
-- arr' <- compute $ asBlockMap (f :: SPull a -> SPush Warp a)
-- $ splitUp (2^(logLen - n)) arr
-- doIt (n - 1) arr'

doIt n arr | 2^(logLen - n) > 32 = do
arr' <- compute $ asBlockMap (f :: SPull a -> SPush Warp a)
$ splitUp (2^(logLen - n)) arr
doIt (n - 1) arr'
| otherwise = do
arr' <- compute $ asBlockMap (f :: SPull a -> SPush Thread a)
$ splitUp (2^(logLen - n)) arr
doIt (n - 1) arr'

82 changes: 82 additions & 0 deletions Examples/Simple/SortNetExec.hs
@@ -0,0 +1,82 @@
{-# LANGUAGE ScopedTypeVariables #-}
module Main where

import SortNet

import Obsidian
import Obsidian.Run.CUDA.Exec


import qualified Data.Vector.Storable as V
import Control.Monad.State

import Data.Int
import Data.Word

import System.Exit

import Data.Time.Clock


perform :: IO ()
perform =
withCUDA $
do
compile_t0 <- lift getCurrentTime
kern <- capture 32 mapTest
compile_t1 <- lift getCurrentTime

(inputs' :: V.Vector Word32) <- lift $ mkRandomVec 1024
let inputs = V.map (`mod` 64) inputs'

transfer_start <- lift getCurrentTime
useVector inputs $ \i ->
allocaVector 1024 $ \ o ->
do
transfer_done <- lift getCurrentTime

t0 <- lift getCurrentTime
forM_ [0..999] $ \ _ -> do
o <== (1,kern) <> i
syncAll
t1 <- lift getCurrentTime

r <- peekCUDAVector o

t_end <- lift getCurrentTime

lift $ putStrLn $ "SELFTIMED: " ++ show (diffUTCTime t1 t0)
-- lift $ putStrLn $ "CYCLES: " ++ show (cnt1 - cnt0)

lift $ putStrLn $ "COMPILATION_TIME: " ++ show (diffUTCTime compile_t1 compile_t0)

-- lift $ putStrLn $ "BYTES_TO_DEVICE: " ++ show (fromIntegral (blcks * elts) * sizeOf (undefined :: EWord32))
-- lift $ putStrLn $ "BYTES_FROM_DEVICE: " ++ show (fromIntegral blcks * sizeOf (undefined :: EWord32))
lift $ putStrLn $ "TRANSFER_TO_DEVICE: " ++ show (diffUTCTime transfer_done transfer_start)
lift $ putStrLn $ "TRANSFER_FROM_DEVICE: " ++ show (diffUTCTime t_end t1)

-- lift $ putStrLn $ "ELEMENTS_PROCESSED: " ++ show (fromIntegral (blcks * elts))
-- lift $ putStrLn $ "NUMBER_OF_BLOCKS: " ++ show (fromIntegral blcks)
-- lift $ putStrLn $ "ELEMENTS_PER_BLOCK: " ++ show (fromIntegral elts)



lift $ putStrLn $ show r

if isSorted r
then
do
lift $ putStrLn "Success"
lift $ exitSuccess
else
do
lift $ putStrLn "Failure"
lift $ exitFailure


isSorted [] = True
isSorted [x] = True
isSorted (x:xs) = x <= minimum xs && isSorted xs

main :: IO ()
main = perform

0 comments on commit 79ed57e

Please sign in to comment.