Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 126 additions & 40 deletions Control/Concurrent/Async.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
{-# LANGUAGE CPP, MagicHash, UnboxedTuples, RankNTypes #-}
{-# LANGUAGE CPP, MagicHash, UnboxedTuples, RankNTypes,
ScopedTypeVariables, DeriveDataTypeable,
ExistentialQuantification #-}
#if __GLASGOW_HASKELL__ >= 701
{-# LANGUAGE Trustworthy #-}
#endif
Expand Down Expand Up @@ -123,6 +125,10 @@ import Prelude hiding (catch)
import Control.Monad
import Control.Applicative
import Data.Traversable
import Data.Typeable
import Data.Unique
import Data.IORef
import Unsafe.Coerce

import GHC.Exts
import GHC.IO hiding (finally, onException)
Expand Down Expand Up @@ -184,9 +190,9 @@ asyncUsing doFork = \action -> do

-- | Spawn an asynchronous action in a separate thread, and pass its
-- @Async@ handle to the supplied function. When the function returns
-- or throws an exception, 'cancel' is called on the @Async@.
--
-- > withAsync action inner = bracket (async action) cancel inner
-- or throws a /synchronous/ exception, 'cancel' is called on the
-- @Async@. When the function throws an /asynchronous/ exception, the
-- exception is rethrown to the other thread.
--
-- This is a useful variant of 'async' that ensures an @Async@ is
-- never left running unintentionally.
Expand Down Expand Up @@ -217,17 +223,29 @@ withAsyncOnWithUnmask cpu actionWith = withAsyncUsing (rawForkOn cpu) (actionWit

withAsyncUsing :: (IO () -> IO ThreadId)
-> IO a -> (Async a -> IO b) -> IO b
-- The bracket version works, but is slow. We can do better by
-- hand-coding it:
withAsyncUsing doFork = \action inner -> do
var <- newEmptyTMVarIO
mask $ \restore -> do
t <- doFork $ try (restore action) >>= atomically . putTMVar var
let a = Async t (readTMVar var)
r <- restore (inner a) `catchAll` \e -> do cancel a; throwIO e
r <- restore (inner a) `catch` \e -> do
throwTo t $ normalizeAsyncEx e
throwIO e
cancel a
return r

-- | Converts non-asynchronous exception into ThreadKilled and leaves
-- asynchronous exceptions alone.
normalizeAsyncEx :: SomeException -> SomeException
normalizeAsyncEx e =
case fromException e of
#if MIN_VERSION_base(4,7,0)
Just (_ :: SomeAsyncException) -> e
#else
Just (_ :: AsyncException) -> e
#endif
Nothing -> toException ThreadKilled

-- | Wait for an asynchronous action to complete, and return its
-- value. If the asynchronous action threw an exception, then the
-- exception is re-thrown by 'wait'.
Expand Down Expand Up @@ -486,44 +504,115 @@ concurrently left right =
-- MVar versions of race/concurrently
-- More ugly than the Async versions, but quite a bit faster.

-- @race left right@ forks a thread to perform the @left@ computation and
-- performs the @right@ computation in the current thread. When one of them
-- terminates, whether normally or by raising an exception, the other thread is
-- interrupt by way of a specially crafted asynchronous exception.
--
-- More concretely:
--
-- When @left@ terminates, whether normally or by raising an
-- exception, it wraps its result in the 'UniqueInterruptWithResult'
-- exception and throws that to the right thread.
--
-- When the right thread catches the 'UniqueInterruptWithResult'
-- exception it will either throw the contained exception or return
-- the left result normally.
--
-- When @right@ terminates normally or by a non-asynchronous exception
-- it kills the left thread. When it terminates with an asynchronous
-- exception the exception is thrown to the left thread.
--
-- When @right@ terminates it has to throw an exception to the left
-- thread. However, the left thread will throw an exception back to
-- the right thread when it receives one. We don't want the left
-- thread to throw back an exception it just got from the right
-- thread. To protect against this we set a boolean before the right
-- thread throws an exception to the left thread. The left thread
-- checks this boolean before it needs to throw an exception to the
-- right thread.
--
-- Because calls to @race@ can be nested it's important that different
-- 'UniqueInterruptWithResult' exceptions are not mixed-up. For this
-- reason each call to @race@ creates a 'Unique' value that gets
-- embedded in the interrupt exceptions being thrown. When catching
-- the interrupt exceptions we check if the Unique equals the Unique
-- of this invocation of @race@. (This is the same trick used in the
-- Timeout exception from System.Timeout).

-- race :: IO a -> IO b -> IO (Either a b)
race left right = concurrently' left right collect
where
collect m = do
e <- takeMVar m
case e of
Left ex -> throwIO ex
Right r -> return r
race left right = do
rightTid <- myThreadId
u <- newUnique
throwToRightRef <- newIORef True
let interruptRight r = do
throwToRight <- readIORef throwToRightRef
when throwToRight $ throwTo rightTid $ UniqueInterruptWithResult u r
mask $ \restore -> do
leftTid <- forkIO $ catch (restore left >>= interruptRight . Right)
(interruptRight . Left)
let interruptLeft e = do
writeIORef throwToRightRef False
throwTo leftTid e
catch (Right <$> restore right <* interruptLeft ThreadKilled) $ \e ->
case fromException e of
Just (UniqueInterruptWithResult u' leftResult) | u == u'
-> case leftResult of
Left ex -> throwIO ex
Right l -> return $ Left $ unsafeCoerce l
_ -> do interruptLeft $ normalizeAsyncEx e
throwIO e

-- race_ :: IO a -> IO b -> IO ()
race_ left right = void $ race left right

-- concurrently :: IO a -> IO b -> IO (a,b)
concurrently left right = concurrently' left right (collect [])
where
collect [Left a, Right b] _ = return (a,b)
collect [Right b, Left a] _ = return (a,b)
collect xs m = do
e <- takeMVar m
case e of
Left ex -> throwIO ex
Right r -> collect (r:xs) m

concurrently' :: IO a -> IO b
-> (MVar (Either SomeException (Either a b)) -> IO r)
-> IO r
concurrently' left right collect = do
done <- newEmptyMVar
concurrently left right = do
leftResultMv <- newEmptyMVar
rightTid <- myThreadId
u <- newUnique
throwToRightRef <- newIORef True
let interruptRight e = do
throwToRight <- readIORef throwToRightRef
when throwToRight $ throwTo rightTid $ UniqueInterruptWithException u e
mask $ \restore -> do
lid <- forkIO $ restore (left >>= putMVar done . Right . Left)
`catchAll` (putMVar done . Left)
rid <- forkIO $ restore (right >>= putMVar done . Right . Right)
`catchAll` (putMVar done . Left)
let stop = killThread lid >> killThread rid
r <- restore (collect done) `onException` stop
stop
return r
leftTid <- forkIO $ catch (restore left >>= putMVar leftResultMv)
interruptRight
let interruptLeft e = do
writeIORef throwToRightRef False
throwTo leftTid e
catch (flip (,) <$> restore right <*> takeMVar leftResultMv) $ \e ->
case fromException e of
Just (UniqueInterruptWithException u' ex) | u == u'
-> throwIO ex
_ -> do interruptLeft $ normalizeAsyncEx e
throwIO e

data UniqueInterruptWithResult =
forall a. UniqueInterruptWithResult Unique (Either SomeException a)
deriving Typeable

instance Show UniqueInterruptWithResult where
show _ = "<< UniqueInterruptWithResult >>"

instance Exception UniqueInterruptWithResult where
#if MIN_VERSION_base(4,7,0)
toException = asyncExceptionToException
fromException = asyncExceptionFromException
#endif

data UniqueInterruptWithException =
UniqueInterruptWithException Unique SomeException
deriving Typeable

instance Show UniqueInterruptWithException where
show _ = "<< UniqueInterruptWithException >>"

instance Exception UniqueInterruptWithException where
#if MIN_VERSION_base(4,7,0)
toException = asyncExceptionToException
fromException = asyncExceptionFromException
#endif
#endif

-- | maps an @IO@-performing function over any @Traversable@ data
Expand Down Expand Up @@ -590,9 +679,6 @@ forkRepeat action =
_ -> return ()
in forkIO go

catchAll :: IO a -> (SomeException -> IO a) -> IO a
catchAll = catch

tryAll :: IO a -> IO (Either SomeException a)
tryAll = try

Expand Down
Loading