Skip to content
This repository has been archived by the owner on Nov 24, 2022. It is now read-only.

[WIP] Add some parallelization to ahc-ld and ahc-link #622

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions asterius/app/ahc-ld.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
{-# LANGUAGE ViewPatterns #-}

import Asterius.Ld
import Control.Concurrent
import Control.Monad
import Data.List
import Data.Maybe
import Data.String
Expand All @@ -20,6 +22,10 @@ parseLinkTask args = do
linkObjs = link_objs,
linkLibs = link_libs,
linkModule = mempty,
threadPoolSize =
maybe 1 read $
find ("--thread-pool-size=" `isPrefixOf`) args
>>= stripPrefix "--thread-pool-size=",
hasMain = "--no-main" `notElem` args,
debug = "--debug" `elem` args,
gcSections = "--no-gc-sections" `notElem` args,
Expand Down Expand Up @@ -63,5 +69,6 @@ main = do
rsp <- readFile rsp_path
let rsp_args = map read $ lines rsp
task <- parseLinkTask rsp_args
when (threadPoolSize task > 1) $ setNumCapabilities (threadPoolSize task)
ignore <- isJust <$> getEnv "ASTERIUS_AHC_LD_IGNORE"
if ignore then callProcess "touch" [linkOutput task] else linkExe task
7 changes: 6 additions & 1 deletion asterius/app/ahc-link.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import Asterius.Main
import Control.Concurrent
import Control.Monad

main :: IO ()
main = getTask >>= ahcLinkMain
main = do
task <- getTask
when (threadPoolSize task > 1) $ setNumCapabilities (threadPoolSize task)
ahcLinkMain task
32 changes: 23 additions & 9 deletions asterius/src/Asterius/Ar.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,36 @@

module Asterius.Ar
( loadAr,
loadArchiveEntries,
loadArchiveEntry
)
where

import qualified Ar as GHC
import Asterius.Binary.ByteString
import Asterius.Types
import Data.Foldable
import qualified IfaceEnv as GHC

-- | Load an archive file from disk, deserialize all objects it contains and
-- concatenate them into a single 'AsteriusCachedModule'.
loadAr :: GHC.NameCacheUpdater -> FilePath -> IO AsteriusCachedModule
loadAr ncu p = do
loadAr ncu p = do -- TODO: This sequential version is currently being used by
-- Asterius.GHCi.Internals.asteriusIservCall
entries <- loadArchiveEntries p
mconcat <$> mapM (loadArchiveEntry ncu) entries

-- | Load all the archive entries from an archive file @.a@, as a list of
-- 'GHC.ArchiveEntry' values (their contents are not deserialized here).
{-# INLINE loadArchiveEntries #-}
loadArchiveEntries :: FilePath -> IO [GHC.ArchiveEntry]
loadArchiveEntries p = do
GHC.Archive entries <- GHC.loadAr p
foldlM
( \acc GHC.ArchiveEntry {..} -> tryGetBS ncu filedata >>= \case
Left _ -> pure acc
Right m -> pure $ m <> acc
)
mempty
entries
return entries

-- | Deserialize a 'GHC.ArchiveEntry'. In case deserialization fails, return
-- an empty 'AsteriusCachedModule'.
loadArchiveEntry :: GHC.NameCacheUpdater -> GHC.ArchiveEntry -> IO AsteriusCachedModule
loadArchiveEntry ncu = \GHC.ArchiveEntry {..} ->
tryGetBS ncu filedata >>= \case
Left {} -> pure mempty
Right m -> pure m
27 changes: 13 additions & 14 deletions asterius/src/Asterius/Backends/Binaryen.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ where
import Asterius.Internals.Barf
import Asterius.Internals.MagicNumber
import Asterius.Internals.Marshal
import Asterius.Internals.Parallel
import Asterius.Types
import qualified Asterius.Types.SymbolMap as SM
import Asterius.TypesConv
Expand Down Expand Up @@ -525,24 +526,22 @@ marshalFunctionTable m tbl_slots FunctionTable {..} = flip runContT pure $ do
(fromIntegral fnl)
o

-- | Marshal the memory segments of a 'Module'. NOTE: It would be nice to
-- parallelize this process (see issue #621), but given that we want marshaling
-- to happen in @ContT@ for efficiency reasons, this might backfire. Leaving
-- linear for now.
marshalMemorySegments :: Int -> [DataSegment] -> CodeGen ()
marshalMemorySegments mbs segs = do
env <- ask
m <- askModuleRef
let segs_len = length segs
marshalOffset = \DataSegment {..} ->
lift $ flip runReaderT env $ marshalExpression $ ConstI32 offset
lift $ flip runContT pure $ do
(seg_bufs, _) <- marshalV =<< for segs (marshalBS . content)
(seg_passives, _) <- marshalV $ replicate segs_len 0
(seg_offsets, _) <-
marshalV
=<< for
segs
( \DataSegment {..} ->
lift $ flip runReaderT env $ marshalExpression $ ConstI32 offset
)
(seg_sizes, _) <-
marshalV $
map (fromIntegral . BS.length . content) segs
(seg_offsets, _) <- marshalV =<< for segs marshalOffset
(seg_sizes, _) <- marshalV $ map (fromIntegral . BS.length . content) segs
lift $
Binaryen.setMemory
m
Expand Down Expand Up @@ -571,8 +570,8 @@ marshalMemoryImport m MemoryImport {..} = flip runContT pure $ do
lift $ Binaryen.addMemoryImport m inp emp ebp 0

marshalModule ::
Bool -> SM.SymbolMap Int64 -> Module -> IO Binaryen.Module
marshalModule tail_calls sym_map hs_mod@Module {..} = do
Bool -> Int -> SM.SymbolMap Int64 -> Module -> IO Binaryen.Module
marshalModule tail_calls pool_size sym_map hs_mod@Module {..} = do
let fts = generateWasmFunctionTypeSet hs_mod
m <- Binaryen.Module.create
Binaryen.setFeatures m
Expand All @@ -588,8 +587,8 @@ marshalModule tail_calls sym_map hs_mod@Module {..} = do
envSymbolMap = sym_map,
envModuleRef = m
}
for_ (M.toList functionMap') $ \(k, f@Function {..}) ->
flip runReaderT env $ marshalFunction k (ftps M.! functionType) f
parallelFoldMap pool_size (M.toList functionMap') $ \(k, f@Function {..}) ->
flip runReaderT env $ void $ marshalFunction k (ftps M.! functionType) f
forM_ functionImports $ \fi@FunctionImport {..} ->
marshalFunctionImport m (ftps M.! functionType) fi
forM_ functionExports $ marshalFunctionExport m
Expand Down
1 change: 1 addition & 0 deletions asterius/src/Asterius/GHCi/Internals.hs
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ asteriusWriteIServ hsc_env i a
debug = False,
gcSections = True,
verboseErr = True,
threadPoolSize = 1,
outputIR = Nothing,
rootSymbols =
[ run_q_exp_sym,
Expand Down
60 changes: 60 additions & 0 deletions asterius/src/Asterius/Internals/Parallel.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module : Asterius.Internals.Parallel
-- Copyright : (c) 2018 EURL Tweag
-- License : All rights reserved (see LICENCE file in the distribution).
--
-- Simple parallel combinators. Since we need to control our dependency
-- surface, our current approach to parallelism is very simple: given the
-- worker thread pool capacity @c@ and the list of tasks to be performed,
-- 'parallelFoldMap' forks exactly @c@ worker threads, pinned to capabilities
-- @0 .. c-1@ with 'forkOn', lets them consume the input concurrently, and
-- gathers the results using their 'Monoid' instance. Notice that this behavior
-- is deterministic only if '<>' is also commutative (not only associative),
-- but that is sufficient for our use cases.
--
-- To avoid needless threading overhead, if @c < 2@ then we fall back to the
-- sequential implementation.
module Asterius.Internals.Parallel
( parallelRnf,
parallelFoldMap,
)
where

import Control.Concurrent
import Control.Concurrent.MVar
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Data.IORef
import System.IO.Unsafe

-- | @parallelRnf c xs@ reduces every element of @xs@ to normal form, where
-- @c@ is the worker thread pool capacity.
--
-- When @c@ is at least 2, each element is forced via 'parallelFoldMap' with
-- @c@ workers; for any smaller @c@ the whole list is forced sequentially with
-- 'rnf'. The result is the unit value in both cases: the work only happens
-- once that unit is demanded by the caller.
parallelRnf :: NFData a => Int -> [a] -> ()
parallelRnf n xs =
  if n < 2
    then rnf xs
    else unsafePerformIO (parallelFoldMap n xs forceElem)
  where
    -- Deeply evaluate a single element and discard the forced value.
    forceElem x = void (evaluate (force x))

-- | @parallelFoldMap c xs f@ applies @f@ to every element of @xs@ and combines
-- the results with '<>', where @c@ is the worker thread pool capacity.
--
-- * If @c >= 2@, exactly @c@ worker threads are forked with 'forkOn' on
--   capabilities @0 .. c-1@. The workers repeatedly pop the next element from
--   a shared queue, so which worker handles which element (and therefore the
--   order in which results are combined) depends on scheduling. The final
--   result is only deterministic if '<>' is commutative.
-- * Otherwise the list is processed sequentially, in order.
--
-- Each intermediate result and accumulator is forced to weak head normal form
-- only; the 'NFData' constraint is part of the established interface but is
-- not currently used to deep-force values.
--
-- If @f@ throws in a worker, that worker stops consuming input and the
-- exception is captured. Once all workers have finished, the first captured
-- exception (in worker order) is rethrown in the calling thread. Without this,
-- a failing worker would never fill its 'MVar' and the caller would block on
-- 'takeMVar' instead of seeing the real error.
parallelFoldMap :: (NFData r, Monoid r) => Int -> [a] -> (a -> IO r) -> IO r
parallelFoldMap n xs fn
  | n >= 2 = do
    input <- newIORef xs
    mvars <- replicateM n newEmptyMVar
    let -- Atomically pop the next pending element, if any.
        getNextElem = atomicModifyIORef' input $ \remaining ->
          case remaining of
            [] -> ([], Nothing)
            (y : ys) -> (ys, Just y)
        -- Drain the queue, keeping the accumulator in WHNF at every step.
        loop acc =
          acc `seq` do
            next <- getNextElem
            case next of
              Nothing -> pure acc
              Just y -> do
                res <- fn y >>= evaluate
                loop (acc <> res)
        -- Always fill the MVar, with either the result or the failure, so
        -- the caller can never be left waiting on a dead worker.
        worker mvar = try (loop mempty) >>= putMVar mvar
    forM_ ([0 ..] `zip` mvars) $ \(i, mvar) ->
      forkOn i (worker mvar)
    outcomes <- forM mvars takeMVar
    mconcat <$> mapM rethrow outcomes
  | otherwise = mconcat <$> mapM fn xs
  where
    -- Surface a worker failure in the calling thread.
    rethrow :: Either SomeException b -> IO b
    rethrow = either throwIO pure
2 changes: 2 additions & 0 deletions asterius/src/Asterius/JSRun/NonMain.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ linkNonMain store_m extra_syms = (m, link_report)
Asterius.Ld.debug = False,
Asterius.Ld.gcSections = True,
Asterius.Ld.verboseErr = True,
Asterius.Ld.threadPoolSize = 1,
Asterius.Ld.outputIR = Nothing,
rootSymbols = extra_syms,
Asterius.Ld.exportFunctions = []
Expand All @@ -60,6 +61,7 @@ distNonMain p extra_syms =
yolo = True,
Asterius.Main.Task.hasMain = False,
Asterius.Main.Task.verboseErr = True,
Asterius.Main.Task.threadPoolSize = 1,
extraRootSymbols = extra_syms
}

Expand Down
17 changes: 13 additions & 4 deletions asterius/src/Asterius/Ld.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
Expand All @@ -18,11 +19,11 @@ import Asterius.Binary.File
import Asterius.Binary.NameCache
import Asterius.Builtins
import Asterius.Builtins.Main
import Asterius.Internals.Parallel
import Asterius.Resolve
import Asterius.Types
import qualified Asterius.Types.SymbolSet as SS
import Control.Exception
import Data.Either
import Data.Traversable

data LinkTask
Expand All @@ -31,6 +32,7 @@ data LinkTask
linkObjs, linkLibs :: [FilePath],
linkModule :: AsteriusCachedModule,
hasMain, debug, gcSections, verboseErr :: Bool,
threadPoolSize :: Int,
outputIR :: Maybe FilePath,
rootSymbols, exportFunctions :: [EntitySymbol]
}
Expand All @@ -45,9 +47,15 @@ data LinkTask
loadTheWorld :: LinkTask -> IO AsteriusCachedModule
loadTheWorld LinkTask {..} = do
ncu <- newNameCacheUpdater
lib <- mconcat <$> for linkLibs (loadAr ncu)
objs <- rights <$> for linkObjs (tryGetFile ncu)
evaluate $ linkModule <> mconcat objs <> lib
lib <- do
entries <- concat <$> for linkLibs loadArchiveEntries
parallelFoldMap threadPoolSize entries (loadArchiveEntry ncu)
objs <- parallelFoldMap threadPoolSize linkObjs (loadObj ncu)
evaluate $ linkModule <> objs <> lib
where
loadObj ncu path = tryGetFile ncu path >>= \case
Left {} -> pure mempty
Right m -> pure m

-- | The *_info are generated from Cmm using the INFO_TABLE macro.
-- For example, see StgMiscClosures.cmm / Exception.cmm
Expand Down Expand Up @@ -95,6 +103,7 @@ linkModules LinkTask {..} m =
debug
gcSections
verboseErr
threadPoolSize
( toCachedModule
( (if hasMain then mainBuiltins else mempty)
<> rtsAsteriusModule
Expand Down
7 changes: 7 additions & 0 deletions asterius/src/Asterius/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ parseTask args = case err_msgs of
in if i >= 0 && i <= 2
then t {shrinkLevel = i}
else error "Shrink level must be [0..2]",
str_opt "thread-pool-size" $ \s t ->
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget to pass the --thread-pool-size option to ahc-ld as well. ahc-ld is called as the linker executable by ahc, so search for --optl in this module and see how other ahc-ld options are being passed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Should be ok now.

let i = read s
in if i >= 1
then t {threadPoolSize = i}
else error "Thread pool size must be positive",
bool_opt "debug" $
\t ->
t
Expand Down Expand Up @@ -272,6 +277,7 @@ ahcLink task = do
]
<> ["-optl--no-gc-sections" | not (gcSections task)]
<> ["-optl--verbose-err" | verboseErr task]
<> ["-optl--thread-pool-size=" <> show (threadPoolSize task)]
<> extraGHCFlags task
<> [ "-optl--output-ir="
<> outputDirectory task
Expand Down Expand Up @@ -313,6 +319,7 @@ ahcDistMain logger task (final_m, report) = do
m_ref <-
Binaryen.marshalModule
(tailCalls task)
(threadPoolSize task)
(staticsSymbolMap report <> functionSymbolMap report)
final_m
when (optimizeLevel task > 0 || shrinkLevel task > 0) $ do
Expand Down
3 changes: 3 additions & 0 deletions asterius/src/Asterius/Main/Task.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ module Asterius.Main.Task
verboseErr,
yolo,
consoleHistory,
threadPoolSize,
extraGHCFlags,
exportFunctions,
extraRootSymbols,
Expand Down Expand Up @@ -53,6 +54,7 @@ data Task
outputDirectory :: FilePath,
outputBaseName :: String,
hasMain, validate, tailCalls, gcSections, bundle, debug, outputIR, run, verboseErr, yolo, consoleHistory :: Bool,
threadPoolSize :: Int,
extraGHCFlags :: [String],
exportFunctions, extraRootSymbols :: [EntitySymbol],
gcThreshold :: Int
Expand All @@ -79,6 +81,7 @@ defTask = Task
verboseErr = False,
yolo = False,
consoleHistory = False,
threadPoolSize = 1,
extraGHCFlags = [],
exportFunctions = [],
extraRootSymbols = [],
Expand Down
5 changes: 3 additions & 2 deletions asterius/src/Asterius/Resolve.hs
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,12 @@ linkStart ::
Bool ->
Bool ->
Bool ->
Int ->
AsteriusCachedModule ->
SS.SymbolSet ->
[EntitySymbol] ->
(AsteriusModule, Module, LinkReport)
linkStart debug gc_sections verbose_err store root_syms export_funcs =
linkStart debug gc_sections verbose_err pool_size store root_syms export_funcs =
( merged_m,
result_m,
mempty
Expand All @@ -126,7 +127,7 @@ linkStart debug gc_sections verbose_err store root_syms export_funcs =
merged_m0
| gc_sections = gcSections verbose_err store root_syms export_funcs
| otherwise = fromCachedModule store
!merged_m0_evaluated = force merged_m0
!merged_m0_evaluated = parForceAsteriusModule pool_size merged_m0
merged_m1
| debug = addMemoryTrap merged_m0_evaluated
| otherwise = merged_m0_evaluated
Expand Down
Loading