diff --git a/Control/Concurrent/Async.hs b/Control/Concurrent/Async.hs index bd1865b..71974e4 100644 --- a/Control/Concurrent/Async.hs +++ b/Control/Concurrent/Async.hs @@ -1,4 +1,6 @@ -{-# LANGUAGE CPP, MagicHash, UnboxedTuples, RankNTypes #-} +{-# LANGUAGE CPP, MagicHash, UnboxedTuples, RankNTypes, + ScopedTypeVariables, DeriveDataTypeable, + ExistentialQuantification #-} #if __GLASGOW_HASKELL__ >= 701 {-# LANGUAGE Trustworthy #-} #endif @@ -123,6 +125,10 @@ import Prelude hiding (catch) import Control.Monad import Control.Applicative import Data.Traversable +import Data.Typeable +import Data.Unique +import Data.IORef +import Unsafe.Coerce import GHC.Exts import GHC.IO hiding (finally, onException) @@ -184,9 +190,9 @@ asyncUsing doFork = \action -> do -- | Spawn an asynchronous action in a separate thread, and pass its -- @Async@ handle to the supplied function. When the function returns --- or throws an exception, 'cancel' is called on the @Async@. --- --- > withAsync action inner = bracket (async action) cancel inner +-- or throws a /synchronous/ exception, 'cancel' is called on the +-- @Async@. When the function throws an /asynchronous/ exception, the +-- exception is rethrown to the other thread. -- -- This is a useful variant of 'async' that ensures an @Async@ is -- never left running unintentionally. @@ -217,17 +223,29 @@ withAsyncOnWithUnmask cpu actionWith = withAsyncUsing (rawForkOn cpu) (actionWit withAsyncUsing :: (IO () -> IO ThreadId) -> IO a -> (Async a -> IO b) -> IO b --- The bracket version works, but is slow. We can do better by --- hand-coding it: withAsyncUsing doFork = \action inner -> do var <- newEmptyTMVarIO mask $ \restore -> do t <- doFork $ try (restore action) >>= atomically . putTMVar var let a = Async t (readTMVar var) - r <- restore (inner a) `catchAll` \e -> do cancel a; throwIO e + r <- restore (inner a) `catch` \e -> do + throwTo t $ normalizeAsyncEx e + throwIO e cancel a return r +-- | Converts non-asynchronous exception into ThreadKilled and leaves +-- asynchronous exceptions alone. +normalizeAsyncEx :: SomeException -> SomeException +normalizeAsyncEx e = + case fromException e of +#if MIN_VERSION_base(4,7,0) + Just (_ :: SomeAsyncException) -> e +#else + Just (_ :: AsyncException) -> e +#endif + Nothing -> toException ThreadKilled + -- | Wait for an asynchronous action to complete, and return its -- value. If the asynchronous action threw an exception, then the -- exception is re-thrown by 'wait'. @@ -486,44 +504,115 @@ concurrently left right = -- MVar versions of race/concurrently -- More ugly than the Async versions, but quite a bit faster. +-- @race left right@ forks a thread to perform the @left@ computation and +-- performs the @right@ computation in the current thread. When one of them +-- terminates, whether normally or by raising an exception, the other thread is +-- interrupt by way of a specially crafted asynchronous exception. +-- +-- More concretely: +-- +-- When @left@ terminates, whether normally or by raising an +-- exception, it wraps its result in the 'UniqueInterruptWithResult' +-- exception and throws that to the right thread. +-- +-- When the right thread catches the 'UniqueInterruptWithResult' +-- exception it will either throw the contained exception or return +-- the left result normally. +-- +-- When @right@ terminates normally or by a non-asynchronous exception +-- it kills the left thread. When it terminates with an asynchronous +-- exception the exception is thrown to the left thread. +-- +-- When @right@ terminates it has to throw an exception to the left +-- thread. However, the left thread will throw an exception back to +-- the right thread when it receives one. We don't want the left +-- thread to throw back an exception it just got from the right +-- thread. To protect against this we set a boolean before the right +-- thread throws an exception to the left thread. The left thread +-- checks this boolean before it needs to throw an exception to the +-- right thread. +-- +-- Because calls to @race@ can be nested it's important that different +-- 'UniqueInterruptWithResult' exceptions are not mixed-up. For this +-- reason each call to @race@ creates a 'Unique' value that gets +-- embedded in the interrupt exceptions being thrown. When catching +-- the interrupt exceptions we check if the Unique equals the Unique +-- of this invocation of @race@. (This is the same trick used in the +-- Timeout exception from System.Timeout). + -- race :: IO a -> IO b -> IO (Either a b) -race left right = concurrently' left right collect - where - collect m = do - e <- takeMVar m - case e of - Left ex -> throwIO ex - Right r -> return r +race left right = do + rightTid <- myThreadId + u <- newUnique + throwToRightRef <- newIORef True + let interruptRight r = do + throwToRight <- readIORef throwToRightRef + when throwToRight $ throwTo rightTid $ UniqueInterruptWithResult u r + mask $ \restore -> do + leftTid <- forkIO $ catch (restore left >>= interruptRight . Right) + (interruptRight . Left) + let interruptLeft e = do + writeIORef throwToRightRef False + throwTo leftTid e + catch (Right <$> restore right <* interruptLeft ThreadKilled) $ \e -> + case fromException e of + Just (UniqueInterruptWithResult u' leftResult) | u == u' + -> case leftResult of + Left ex -> throwIO ex + Right l -> return $ Left $ unsafeCoerce l + _ -> do interruptLeft $ normalizeAsyncEx e + throwIO e -- race_ :: IO a -> IO b -> IO () race_ left right = void $ race left right -- concurrently :: IO a -> IO b -> IO (a,b) -concurrently left right = concurrently' left right (collect []) - where - collect [Left a, Right b] _ = return (a,b) - collect [Right b, Left a] _ = return (a,b) - collect xs m = do - e <- takeMVar m - case e of - Left ex -> throwIO ex - Right r -> collect (r:xs) m - -concurrently' :: IO a -> IO b - -> (MVar (Either SomeException (Either a b)) -> IO r) - -> IO r -concurrently' left right collect = do - done <- newEmptyMVar +concurrently left right = do + leftResultMv <- newEmptyMVar + rightTid <- myThreadId + u <- newUnique + throwToRightRef <- newIORef True + let interruptRight e = do + throwToRight <- readIORef throwToRightRef + when throwToRight $ throwTo rightTid $ UniqueInterruptWithException u e mask $ \restore -> do - lid <- forkIO $ restore (left >>= putMVar done . Right . Left) - `catchAll` (putMVar done . Left) - rid <- forkIO $ restore (right >>= putMVar done . Right . Right) - `catchAll` (putMVar done . Left) - let stop = killThread lid >> killThread rid - r <- restore (collect done) `onException` stop - stop - return r + leftTid <- forkIO $ catch (restore left >>= putMVar leftResultMv) + interruptRight + let interruptLeft e = do + writeIORef throwToRightRef False + throwTo leftTid e + catch (flip (,) <$> restore right <*> takeMVar leftResultMv) $ \e -> + case fromException e of + Just (UniqueInterruptWithException u' ex) | u == u' + -> throwIO ex + _ -> do interruptLeft $ normalizeAsyncEx e + throwIO e + +data UniqueInterruptWithResult = + forall a. UniqueInterruptWithResult Unique (Either SomeException a) + deriving Typeable + +instance Show UniqueInterruptWithResult where + show _ = "<< UniqueInterruptWithResult >>" + +instance Exception UniqueInterruptWithResult where +#if MIN_VERSION_base(4,7,0) + toException = asyncExceptionToException + fromException = asyncExceptionFromException +#endif + +data UniqueInterruptWithException = + UniqueInterruptWithException Unique SomeException + deriving Typeable +instance Show UniqueInterruptWithException where + show _ = "<< UniqueInterruptWithException >>" + +instance Exception UniqueInterruptWithException where +#if MIN_VERSION_base(4,7,0) + toException = asyncExceptionToException + fromException = asyncExceptionFromException +#endif #endif -- | maps an @IO@-performing function over any @Traversable@ data @@ -590,9 +679,6 @@ forkRepeat action = _ -> return () in forkIO go -catchAll :: IO a -> (SomeException -> IO a) -> IO a -catchAll = catch - tryAll :: IO a -> IO (Either SomeException a) tryAll = try diff --git a/test/test-async.hs b/test/test-async.hs index 1791185..af077d6 100644 --- a/test/test-async.hs +++ b/test/test-async.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE ScopedTypeVariables,DeriveDataTypeable #-} +{-# LANGUAGE CPP, ScopedTypeVariables, DeriveDataTypeable #-} module Main where import Test.Framework (defaultMain, testGroup) @@ -9,9 +9,11 @@ import Test.HUnit import Control.Concurrent.Async import Control.Exception import Data.Typeable +import Data.IORef import Control.Concurrent import Control.Monad import Data.Maybe +import System.Timeout import Prelude hiding (catch) @@ -23,13 +25,52 @@ tests = [ , testCase "async_exwait" async_exwait , testCase "async_exwaitCatch" async_exwaitCatch , testCase "withasync_waitCatch" withasync_waitCatch - , testCase "withasync_wait2" withasync_wait2 + , testCase "withasync_functionReturned_threadKilled" + withasync_functionReturned_threadKilled + , testCase "withasync_synchronousException_threadKilled" + withasync_synchronousException_threadKilled + , testCase "withasync_asynchronousException_rethrown" + withasync_asynchronousException_rethrown +#if MIN_VERSION_base(4,7,0) + , testCase "withasync_timeoutException_rethrown" + withasync_timeoutException_rethrown +#endif , testGroup "async_cancel_rep" $ replicate 1000 $ testCase "async_cancel" async_cancel , testCase "async_poll" async_poll , testCase "async_poll2" async_poll2 , testCase "withasync_waitCatch_blocked" withasync_waitCatch_blocked + , testGroup "race" $ + [ testCase "right_terminate_normally" + race_right_terminate_normally + , testCase "left_terminate_normally" + race_left_terminate_normally + , testCase "right_terminate_by_synchronous_exception" + race_right_terminate_by_synchronous_exception + , testCase "left_terminate_by_synchronous_exception" + race_left_terminate_by_synchronous_exception + , testCase "right_terminates_normally_kills_left" + race_right_terminates_normally_kills_left + , testCase "left_terminates_normally_kills_right" + race_left_terminates_normally_kills_right + , testCase "right_terminates_by_asynchronous_exception_kills_both" + race_right_terminates_by_asynchronous_exception_kills_both + , testCase "left_terminates_by_asynchronous_exception_kills_right" + race_left_terminates_by_asynchronous_exception_kills_right + , testCase "left_receives_asynchronous_exception" + race_left_receives_asynchronous_exception + , testCase "nested" + race_nested + ] + , testGroup "concurrently" $ + [ testCase "1" concurrently_1 + , testCase "2" concurrently_2 + , testCase "left_receives_asynchronous_exception" + concurrently_left_receives_asynchronous_exception + , testCase "nested" + concurrently_nested + ] ] value = 42 :: Int @@ -72,14 +113,62 @@ withasync_waitCatch = do Left _ -> assertFailure "" Right e -> e @?= value -withasync_wait2 :: Assertion -withasync_wait2 = do +withasync_functionReturned_threadKilled :: Assertion +withasync_functionReturned_threadKilled = do a <- withAsync (threadDelay 1000000) $ return r <- waitCatch a case r of Left e -> fromException e @?= Just ThreadKilled Right _ -> assertFailure "" +withasync_synchronousException_threadKilled :: Assertion +withasync_synchronousException_threadKilled = do + mv <- newEmptyMVar + catchIgnore $ withAsync (threadDelay 1000000) $ \a -> do + putMVar mv a + throwIO DivideByZero + a <- takeMVar mv + r <- waitCatch a + case r of + Left e -> fromException e @?= Just ThreadKilled + Right _ -> assertFailure "" + +catchIgnore :: IO a -> IO () +catchIgnore m = void m `catch` \(e :: SomeException) -> return () + +withasync_asynchronousException_rethrown :: Assertion +withasync_asynchronousException_rethrown = do + mv <- newEmptyMVar + catchIgnore $ withAsync (threadDelay 1000000) $ \a -> do + putMVar mv a + throwIO UserInterrupt + a <- takeMVar mv + r <- waitCatch a + case r of + Left e -> fromException e @?= Just UserInterrupt + Right _ -> assertFailure "" + +#if MIN_VERSION_base(4,7,0) +-- This test requires the SomeAsyncException type +-- which is only available in base >= 4.7 +withasync_timeoutException_rethrown :: Assertion +withasync_timeoutException_rethrown = do + mv <- newEmptyMVar + timeout 100000 $ withAsync (threadDelay 1000000) $ \a -> do + putMVar mv a + threadDelay 1000000 + a <- takeMVar mv + r <- waitCatch a + case r of + Left e -> do + case fromException e of + Nothing -> assertFailure "" + Just (e :: SomeAsyncException) -> + -- e should be a Timeout exception + return () + Right _ -> assertFailure "" +#endif + async_cancel :: Assertion async_cancel = do a <- async (return value) @@ -115,3 +204,157 @@ withasync_waitCatch_blocked = do Just BlockedIndefinitelyOnMVar -> return () Nothing -> assertFailure $ show e Right () -> assertFailure "" + +race_right_terminate_normally :: Assertion +race_right_terminate_normally = do + r <- race (threadDelay 100000 >> return 1) + (threadDelay 10000 >> return 'x') + r @?= (Right 'x') + +race_left_terminate_normally :: Assertion +race_left_terminate_normally = do + r <- race (threadDelay 10000 >> return 1) + (threadDelay 100000 >> return 'x') + r @?= (Left 1) + +race_right_terminate_by_synchronous_exception :: Assertion +race_right_terminate_by_synchronous_exception = do + r <- try (race (threadDelay 100000 >> return 1) + (threadDelay 10000 >> throwIO DivideByZero)) + case r of + Left e -> e @?= DivideByZero + _ -> assertFailure "" + +race_left_terminate_by_synchronous_exception :: Assertion +race_left_terminate_by_synchronous_exception = do + r <- try (race (threadDelay 10000 >> throwIO DivideByZero) + (threadDelay 100000 >> return 'x')) + case r of + Left e -> e @?= DivideByZero + _ -> assertFailure "" + +race_right_terminates_normally_kills_left :: Assertion +race_right_terminates_normally_kills_left = do + ref <- newIORef False + r <- race (threadDelay 100000 >> writeIORef ref True) + (threadDelay 10000 >> return 'x') + leftCompleted <- readIORef ref + assertBool "" $ not leftCompleted && r == Right 'x' + +race_left_terminates_normally_kills_right :: Assertion +race_left_terminates_normally_kills_right = do + ref <- newIORef False + r <- race (threadDelay 10000 >> return 1) + (threadDelay 100000 >> writeIORef ref True) + rightCompleted <- readIORef ref + assertBool "" $ not rightCompleted && r == Left 1 + +race_right_terminates_by_asynchronous_exception_kills_both :: Assertion +race_right_terminates_by_asynchronous_exception_kills_both = do + leftRef <- newIORef False + rightRef <- newIORef False + + timeout 1000 $ + race (threadDelay 10000 >> writeIORef leftRef True) + (threadDelay 10000 >> writeIORef rightRef True) + + leftCompleted <- readIORef leftRef + rightCompleted <- readIORef rightRef + + assertBool "" $ not leftCompleted && not rightCompleted + +race_left_terminates_by_asynchronous_exception_kills_right :: Assertion +race_left_terminates_by_asynchronous_exception_kills_right = do + mv <- newEmptyMVar + + forkIO $ do + threadDelay 100000 + leftTid <- takeMVar mv + throwTo leftTid ThreadKilled + + r <- try $ race (do leftTid <- myThreadId + putMVar mv leftTid + threadDelay 1000000 + return 1) + (do threadDelay 1000000 + return 'x') + + r @?= Left ThreadKilled + +race_left_receives_asynchronous_exception :: Assertion +race_left_receives_asynchronous_exception = do + rightTidMv <- newEmptyMVar + + exMv <- newEmptyMVar + + forkIO $ do + threadDelay 1000 + rightTid <- takeMVar rightTidMv + throwTo rightTid UserInterrupt + + catchIgnore $ + race (threadDelay 100000 `catch` putMVar exMv) + (do rightTid <- myThreadId + putMVar rightTidMv rightTid + threadDelay 10000) + + ex <- takeMVar exMv + + ex @?= UserInterrupt + +race_nested :: Assertion +race_nested = do + r <- race (threadDelay 1000 >> return 1) + (race (threadDelay 10000 >> return 'x') + (threadDelay 100000 >> return False) + ) + r @?= Left 1 + +concurrently_1 :: Assertion +concurrently_1 = do + r <- concurrently (threadDelay 1000 >> return 1) + (threadDelay 1000 >> return 'x') + r @?= (1, 'x') + +concurrently_2 :: Assertion +concurrently_2 = do + void $ timeout 1000 $ + concurrently (threadDelay 10000) + (threadDelay 10000) + threadDelay 10000 + +concurrently_left_receives_asynchronous_exception :: Assertion +concurrently_left_receives_asynchronous_exception = do + rightTidMv <- newEmptyMVar + + exMv <- newEmptyMVar + + forkIO $ do + threadDelay 1000 + rightTid <- takeMVar rightTidMv + throwTo rightTid UserInterrupt + + catchIgnore $ + concurrently (threadDelay 100000 `catch` putMVar exMv) + (do rightTid <- myThreadId + putMVar rightTidMv rightTid + threadDelay 10000) + + ex <- takeMVar exMv + + ex @?= UserInterrupt + +concurrently_nested :: Assertion +concurrently_nested = do + ref <- newIORef False + r <- try $ concurrently (threadDelay 1000 >> throwIO DivideByZero) + (concurrently (threadDelay 10000 >> writeIORef ref True) + (threadDelay 100000 >> return False)) + threadDelay 100000 + case r of + Left e -> case fromException e of + Just DivideByZero -> do + middleCompleted <- readIORef ref + assertBool "Middle completed!" $ not middleCompleted + _ -> assertFailure "" + Right _ -> assertFailure ""