From fbedfb51b554c48c184c1715317b7d08ef011cf9 Mon Sep 17 00:00:00 2001 From: Bas van Dijk Date: Thu, 9 Oct 2014 13:22:56 +0200 Subject: [PATCH 01/12] Change semantics of withAsync & double performance of race and concurrently Previously in "withAsync action inner" when "inner" would receive an asynchronous exception "action" would be killed. Now "action" will receive the same asynchronous exception has "inner". It's still the case that when"inner" receives a synchronous exception "action" will be killed. This change in semantics allows for an implementation of "race" and "concurrently" which is twice as fast because we can now fork a single thread instead of two. This patch was co-written with my colleague @asayers. --- Control/Concurrent/Async.hs | 171 ++++++++++++++++++++++++++++-------- test/test-async.hs | 163 +++++++++++++++++++++++++++++++++- 2 files changed, 291 insertions(+), 43 deletions(-) diff --git a/Control/Concurrent/Async.hs b/Control/Concurrent/Async.hs index bd1865b..0eb62ea 100644 --- a/Control/Concurrent/Async.hs +++ b/Control/Concurrent/Async.hs @@ -1,4 +1,6 @@ -{-# LANGUAGE CPP, MagicHash, UnboxedTuples, RankNTypes #-} +{-# LANGUAGE CPP, MagicHash, UnboxedTuples, RankNTypes, + ScopedTypeVariables, DeriveDataTypeable, + ExistentialQuantification #-} #if __GLASGOW_HASKELL__ >= 701 {-# LANGUAGE Trustworthy #-} #endif @@ -123,6 +125,8 @@ import Prelude hiding (catch) import Control.Monad import Control.Applicative import Data.Traversable +import Data.Typeable +import Data.Unique import GHC.Exts import GHC.IO hiding (finally, onException) @@ -184,9 +188,9 @@ asyncUsing doFork = \action -> do -- | Spawn an asynchronous action in a separate thread, and pass its -- @Async@ handle to the supplied function. When the function returns --- or throws an exception, 'cancel' is called on the @Async@. --- --- > withAsync action inner = bracket (async action) cancel inner +-- or throws a /synchronous/ exception, 'cancel' is called on the +-- @Async@. When the function throws an /asynchronous/ exception, the +-- exception is rethrown to the other thread. -- -- This is a useful variant of 'async' that ensures an @Async@ is -- never left running unintentionally. @@ -224,10 +228,26 @@ withAsyncUsing doFork = \action inner -> do mask $ \restore -> do t <- doFork $ try (restore action) >>= atomically . putTMVar var let a = Async t (readTMVar var) - r <- restore (inner a) `catchAll` \e -> do cancel a; throwIO e + r <- restore (inner a) `alsoThrowingTo` t cancel a return r +-- | If the given action throws an asynchronous exception then also +-- throw it to the specified thread. If it throws a synchronous +-- exception then kill the specified thread. +alsoThrowingTo :: IO a -> ThreadId -> IO a +m `alsoThrowingTo` tid = m `catch` handler + where + handler e = do + case fromException e of +# if MIN_VERSION_base(4,7,0) + Just (_ :: SomeAsyncException) -> throwTo tid e +# else + Just (_ :: AsyncException) -> throwTo tid e +# endif + Nothing -> throwTo tid ThreadKilled + throwIO e + -- | Wait for an asynchronous action to complete, and return its -- value. If the asynchronous action threw an exception, then the -- exception is re-thrown by 'wait'. @@ -486,44 +506,120 @@ concurrently left right = -- MVar versions of race/concurrently -- More ugly than the Async versions, but quite a bit faster. +-- @race left right@ forks a thread to perform the @left@ computation and +-- performs the @right@ computation in the current thread. When one of them +-- terminates, whether normally or by raising an exception, the other thread is +-- interrupt by way of a specially crafted asynchronous exception. +-- +-- More concretely: +-- +-- * When @left@ terminates normally it puts its result in an MVar and throws +-- the 'InterruptRight' exception to the right thread. +-- +-- * When @left@ terminates by an exception @e@ it throws the 'InterruptRight' +-- exception (containing the exception @e@) to the right thread. +-- +-- * When the right thread catches the 'InterruptRight' exception it will check +-- for the optional exception thrown in the left thread and throw it if it's +-- there. When it's not there it means the left thread terminated normally and +-- the left result can be retrieved by taking the MVar. +-- +-- Instead of putting the left result inside an MVar, another implementation +-- is to put the result in the 'InterruptRight' exception. The right thread +-- can then take out and return this result when it catches the exception. +-- This does require the use of 'unsafeCoerce' to trick the type-system which +-- is why I haven't used this approach. +-- +-- * When @right@ terminates normally it throws an 'InterruptLeft' exception to +-- the left thread in order to stop that thread from doing any more work. +-- +-- * When @right@ throws an exception it is catched an thrown to the left thread +-- contained in an 'InterruptLeft' exception. +-- +-- The exact exception that gets contained in the 'InterruptLeft' exception is +-- dependent on the type of exception being thrown: if an asynchronous +-- exception was thrown the exception itself gets contained. For synchonous +-- exceptions the 'ThreadKilled' exception is contained. This is to mimic the +-- behaviour of 'withAsync'. +-- +-- * When the left thread catches the 'InterruptLeft' exception it throws the +-- contained exception. +-- +-- Because calls to @race@ can be nested it's important that different +-- 'InterruptLeft' or 'InterruptRight' exceptions are not mixed-up. For this +-- reason each call to @race@ creates a 'Unique' value that gets embedded in the +-- interrupt exceptions being thrown. When catching the interrupt exceptions we +-- check if the Unique equals the Unique of this invocation of @race@. (This is +-- the same trick used in the Timeout exception from System.Timeout). + -- race :: IO a -> IO b -> IO (Either a b) -race left right = concurrently' left right collect - where - collect m = do - e <- takeMVar m - case e of - Left ex -> throwIO ex - Right r -> return r +race left right = do + leftResultMVar <- newEmptyMVar + rightTid <- myThreadId + u <- newUnique + mask $ \restore -> do + leftTid <- forkIO $ + catch (do restore left >>= putMVar leftResultMVar + throwTo rightTid $ InterruptRight u Nothing) $ \e -> + case fromException e of + Just (InterruptLeft u' rightEx) | u == u' -> throwIO rightEx + _ -> do throwTo rightTid $ InterruptRight u (Just e) + throwIO e + catch (do r <- restore right + throwTo leftTid $ InterruptLeft u ThreadKilled + return $ Right r) $ \e -> + case fromException e of + Just (InterruptRight u' mbEx) | u == u' -> + case mbEx of + Just leftEx -> throwIO leftEx + Nothing -> Left <$> takeMVar leftResultMVar + _ -> do + case fromException e of +# if MIN_VERSION_base(4,7,0) + Just (_ :: SomeAsyncException) + -> throwTo leftTid $ InterruptLeft u e +# else + Just (_ :: AsyncException) + -> throwTo leftTid $ InterruptLeft u e +# endif + Nothing + -> throwTo leftTid $ InterruptLeft u ThreadKilled + throwIO e + +data InterruptLeft = forall e. (Exception e) => InterruptLeft Unique e + deriving (Typeable) + +instance Show InterruptLeft where + show _ = "<< InterruptLeft >>" + +instance Exception InterruptLeft where +#if MIN_VERSION_base(4,7,0) + toException = asyncExceptionToException + fromException = asyncExceptionFromException +#endif + +data InterruptRight = InterruptRight Unique (Maybe SomeException) + deriving Typeable + +instance Show InterruptRight where + show _ = "<< InterruptRight >>" + +instance Exception InterruptRight where +#if MIN_VERSION_base(4,7,0) + toException = asyncExceptionToException + fromException = asyncExceptionFromException +#endif -- race_ :: IO a -> IO b -> IO () race_ left right = void $ race left right -- concurrently :: IO a -> IO b -> IO (a,b) -concurrently left right = concurrently' left right (collect []) - where - collect [Left a, Right b] _ = return (a,b) - collect [Right b, Left a] _ = return (a,b) - collect xs m = do - e <- takeMVar m - case e of - Left ex -> throwIO ex - Right r -> collect (r:xs) m - -concurrently' :: IO a -> IO b - -> (MVar (Either SomeException (Either a b)) -> IO r) - -> IO r -concurrently' left right collect = do - done <- newEmptyMVar +concurrently left right = do + mv <- newEmptyMVar + rightTid <- myThreadId mask $ \restore -> do - lid <- forkIO $ restore (left >>= putMVar done . Right . Left) - `catchAll` (putMVar done . Left) - rid <- forkIO $ restore (right >>= putMVar done . Right . Right) - `catchAll` (putMVar done . Left) - let stop = killThread lid >> killThread rid - r <- restore (collect done) `onException` stop - stop - return r - + leftTid <- forkIO $ restore left `alsoThrowingTo` rightTid >>= putMVar mv + (flip (,) <$> restore right <*> takeMVar mv) `alsoThrowingTo` leftTid #endif -- | maps an @IO@-performing function over any @Traversable@ data @@ -590,9 +686,6 @@ forkRepeat action = _ -> return () in forkIO go -catchAll :: IO a -> (SomeException -> IO a) -> IO a -catchAll = catch - tryAll :: IO a -> IO (Either SomeException a) tryAll = try diff --git a/test/test-async.hs b/test/test-async.hs index 1791185..573f9f3 100644 --- a/test/test-async.hs +++ b/test/test-async.hs @@ -1,4 +1,4 @@ -{-# LANGUAGE ScopedTypeVariables,DeriveDataTypeable #-} +{-# LANGUAGE CPP, ScopedTypeVariables, DeriveDataTypeable #-} module Main where import Test.Framework (defaultMain, testGroup) @@ -9,9 +9,13 @@ import Test.HUnit import Control.Concurrent.Async import Control.Exception import Data.Typeable +import Data.IORef import Control.Concurrent import Control.Monad import Data.Maybe +#if MIN_VERSION_base(4,7,0) +import System.Timeout +#endif import Prelude hiding (catch) @@ -23,13 +27,40 @@ tests = [ , testCase "async_exwait" async_exwait , testCase "async_exwaitCatch" async_exwaitCatch , testCase "withasync_waitCatch" withasync_waitCatch - , testCase "withasync_wait2" withasync_wait2 + , testCase "withasync_functionReturned_threadKilled" + withasync_functionReturned_threadKilled + , testCase "withasync_synchronousException_threadKilled" + withasync_synchronousException_threadKilled + , testCase "withasync_asynchronousException_rethrown" + withasync_asynchronousException_rethrown +#if MIN_VERSION_base(4,7,0) + , testCase "withasync_timeoutException_rethrown" + withasync_timeoutException_rethrown +#endif , testGroup "async_cancel_rep" $ replicate 1000 $ testCase "async_cancel" async_cancel , testCase "async_poll" async_poll , testCase "async_poll2" async_poll2 , testCase "withasync_waitCatch_blocked" withasync_waitCatch_blocked + , testGroup "race" $ + [ testCase "right_terminate_normally" + race_right_terminate_normally + , testCase "left_terminate_normally" + race_left_terminate_normally + , testCase "right_terminate_by_synchronous_exception" + race_right_terminate_by_synchronous_exception + , testCase "left_terminate_by_synchronous_exception" + race_left_terminate_by_synchronous_exception + , testCase "right_terminates_normally_kills_left" + race_right_terminates_normally_kills_left + , testCase "left_terminates_normally_kills_right" + race_left_terminates_normally_kills_right + , testCase "right_terminates_by_asynchronous_exception_kills_both" + race_right_terminates_by_asynchronous_exception_kills_both + , testCase "left_terminates_by_asynchronous_exception_kills_right" + race_left_terminates_by_asynchronous_exception_kills_right + ] ] value = 42 :: Int @@ -72,14 +103,62 @@ withasync_waitCatch = do Left _ -> assertFailure "" Right e -> e @?= value -withasync_wait2 :: Assertion -withasync_wait2 = do +withasync_functionReturned_threadKilled :: Assertion +withasync_functionReturned_threadKilled = do a <- withAsync (threadDelay 1000000) $ return r <- waitCatch a case r of Left e -> fromException e @?= Just ThreadKilled Right _ -> assertFailure "" +withasync_synchronousException_threadKilled :: Assertion +withasync_synchronousException_threadKilled = do + mv <- newEmptyMVar + catchIgnore $ withAsync (threadDelay 1000000) $ \a -> do + putMVar mv a + throwIO DivideByZero + a <- takeMVar mv + r <- waitCatch a + case r of + Left e -> fromException e @?= Just ThreadKilled + Right _ -> assertFailure "" + +catchIgnore :: IO a -> IO () +catchIgnore m = void m `catch` \(e :: SomeException) -> return () + +withasync_asynchronousException_rethrown :: Assertion +withasync_asynchronousException_rethrown = do + mv <- newEmptyMVar + catchIgnore $ withAsync (threadDelay 1000000) $ \a -> do + putMVar mv a + throwIO UserInterrupt + a <- takeMVar mv + r <- waitCatch a + case r of + Left e -> fromException e @?= Just UserInterrupt + Right _ -> assertFailure "" + +#if MIN_VERSION_base(4,7,0) +-- This test requires the SomeAsyncException type +-- which is only available in base >= 4.7 +withasync_timeoutException_rethrown :: Assertion +withasync_timeoutException_rethrown = do + mv <- newEmptyMVar + timeout 100000 $ withAsync (threadDelay 1000000) $ \a -> do + putMVar mv a + threadDelay 1000000 + a <- takeMVar mv + r <- waitCatch a + case r of + Left e -> do + case fromException e of + Nothing -> assertFailure "" + Just (e :: SomeAsyncException) -> + -- e should be a Timeout exception + return () + Right _ -> assertFailure "" +#endif + async_cancel :: Assertion async_cancel = do a <- async (return value) @@ -115,3 +194,79 @@ withasync_waitCatch_blocked = do Just BlockedIndefinitelyOnMVar -> return () Nothing -> assertFailure $ show e Right () -> assertFailure "" + +race_right_terminate_normally :: Assertion +race_right_terminate_normally = do + r <- race (threadDelay 100000 >> return 1) + (threadDelay 10000 >> return 'x') + r @?= (Right 'x') + +race_left_terminate_normally :: Assertion +race_left_terminate_normally = do + r <- race (threadDelay 10000 >> return 1) + (threadDelay 100000 >> return 'x') + r @?= (Left 1) + +race_right_terminate_by_synchronous_exception :: Assertion +race_right_terminate_by_synchronous_exception = do + r <- try (race (threadDelay 100000 >> return 1) + (threadDelay 10000 >> throwIO DivideByZero)) + case r of + Left e -> e @?= DivideByZero + _ -> assertFailure "" + +race_left_terminate_by_synchronous_exception :: Assertion +race_left_terminate_by_synchronous_exception = do + r <- try (race (threadDelay 10000 >> throwIO DivideByZero) + (threadDelay 100000 >> return 'x')) + case r of + Left e -> e @?= DivideByZero + _ -> assertFailure "" + +race_right_terminates_normally_kills_left :: Assertion +race_right_terminates_normally_kills_left = do + ref <- newIORef False + r <- race (threadDelay 100000 >> writeIORef ref True) + (threadDelay 10000 >> return 'x') + leftCompleted <- readIORef ref + assertBool "" $ not leftCompleted && r == Right 'x' + +race_left_terminates_normally_kills_right :: Assertion +race_left_terminates_normally_kills_right = do + ref <- newIORef False + r <- race (threadDelay 10000 >> return 1) + (threadDelay 100000 >> writeIORef ref True) + rightCompleted <- readIORef ref + assertBool "" $ not rightCompleted && r == Left 1 + +race_right_terminates_by_asynchronous_exception_kills_both :: Assertion +race_right_terminates_by_asynchronous_exception_kills_both = do + leftRef <- newIORef False + rightRef <- newIORef False + + timeout 1000 $ + race (threadDelay 10000 >> writeIORef leftRef True) + (threadDelay 10000 >> writeIORef rightRef True) + + leftCompleted <- readIORef leftRef + rightCompleted <- readIORef rightRef + + assertBool "" $ not leftCompleted && not rightCompleted + +race_left_terminates_by_asynchronous_exception_kills_right :: Assertion +race_left_terminates_by_asynchronous_exception_kills_right = do + mv <- newEmptyMVar + + forkIO $ do + threadDelay 100000 + leftTid <- takeMVar mv + throwTo leftTid ThreadKilled + + r <- try $ race (do leftTid <- myThreadId + putMVar mv leftTid + threadDelay 1000000 + return 1) + (do threadDelay 1000000 + return 'x') + + r @?= Left ThreadKilled From 6a093b6abc6f3fe1b732d08555c0c4adcf99a79d Mon Sep 17 00:00:00 2001 From: Bas van Dijk Date: Thu, 9 Oct 2014 15:00:51 +0200 Subject: [PATCH 02/12] Always import System.Timeout --- test/test-async.hs | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test-async.hs b/test/test-async.hs index 573f9f3..8e644f9 100644 --- a/test/test-async.hs +++ b/test/test-async.hs @@ -13,9 +13,7 @@ import Data.IORef import Control.Concurrent import Control.Monad import Data.Maybe -#if MIN_VERSION_base(4,7,0) import System.Timeout -#endif import Prelude hiding (catch) From d9023ab4a17ec4bd433bbb83cdd244bab8c9c27a Mon Sep 17 00:00:00 2001 From: Bas van Dijk Date: Thu, 9 Oct 2014 17:17:16 +0200 Subject: [PATCH 03/12] Fixed bug in concurrently --- Control/Concurrent/Async.hs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/Control/Concurrent/Async.hs b/Control/Concurrent/Async.hs index 0eb62ea..469e497 100644 --- a/Control/Concurrent/Async.hs +++ b/Control/Concurrent/Async.hs @@ -618,7 +618,8 @@ concurrently left right = do mv <- newEmptyMVar rightTid <- myThreadId mask $ \restore -> do - leftTid <- forkIO $ restore left `alsoThrowingTo` rightTid >>= putMVar mv + leftTid <- forkIO $ (restore left >>= putMVar mv) + `catchAll` throwTo rightTid (flip (,) <$> restore right <*> takeMVar mv) `alsoThrowingTo` leftTid #endif @@ -686,6 +687,9 @@ forkRepeat action = _ -> return () in forkIO go +catchAll :: IO a -> (SomeException -> IO a) -> IO a +catchAll = catch + tryAll :: IO a -> IO (Either SomeException a) tryAll = try From c909e3089bc21470d86cced849039442e1c6fdd7 Mon Sep 17 00:00:00 2001 From: Bas van Dijk Date: Thu, 9 Oct 2014 22:29:05 +0200 Subject: [PATCH 04/12] Transfer the result of the left thread via an exception --- Control/Concurrent/Async.hs | 39 ++++++++++++++----------------------- 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/Control/Concurrent/Async.hs b/Control/Concurrent/Async.hs index 469e497..7351812 100644 --- a/Control/Concurrent/Async.hs +++ b/Control/Concurrent/Async.hs @@ -127,6 +127,7 @@ import Control.Applicative import Data.Traversable import Data.Typeable import Data.Unique +import Unsafe.Coerce import GHC.Exts import GHC.IO hiding (finally, onException) @@ -513,22 +514,13 @@ concurrently left right = -- -- More concretely: -- --- * When @left@ terminates normally it puts its result in an MVar and throws --- the 'InterruptRight' exception to the right thread. +-- * When @left@ terminates, whether normally or by raising an +-- exception, it wraps its result in the 'InterruptRight' exception +-- and throws that to the right thread. -- --- * When @left@ terminates by an exception @e@ it throws the 'InterruptRight' --- exception (containing the exception @e@) to the right thread. --- --- * When the right thread catches the 'InterruptRight' exception it will check --- for the optional exception thrown in the left thread and throw it if it's --- there. When it's not there it means the left thread terminated normally and --- the left result can be retrieved by taking the MVar. --- --- Instead of putting the left result inside an MVar, another implementation --- is to put the result in the 'InterruptRight' exception. The right thread --- can then take out and return this result when it catches the exception. --- This does require the use of 'unsafeCoerce' to trick the type-system which --- is why I haven't used this approach. +-- * When the right thread catches the 'InterruptRight' exception it +-- will either throw the contained exception or return the left result +-- normally. -- -- * When @right@ terminates normally it throws an 'InterruptLeft' exception to -- the left thread in order to stop that thread from doing any more work. @@ -554,25 +546,24 @@ concurrently left right = -- race :: IO a -> IO b -> IO (Either a b) race left right = do - leftResultMVar <- newEmptyMVar rightTid <- myThreadId u <- newUnique mask $ \restore -> do leftTid <- forkIO $ - catch (do restore left >>= putMVar leftResultMVar - throwTo rightTid $ InterruptRight u Nothing) $ \e -> + catch (do l <- restore left + throwTo rightTid $ InterruptRight u $ Right l) $ \e -> case fromException e of Just (InterruptLeft u' rightEx) | u == u' -> throwIO rightEx - _ -> do throwTo rightTid $ InterruptRight u (Just e) + _ -> do throwTo rightTid $ InterruptRight u (Left e) throwIO e catch (do r <- restore right throwTo leftTid $ InterruptLeft u ThreadKilled return $ Right r) $ \e -> case fromException e of - Just (InterruptRight u' mbEx) | u == u' -> - case mbEx of - Just leftEx -> throwIO leftEx - Nothing -> Left <$> takeMVar leftResultMVar + Just (InterruptRight u' leftResult) | u == u' -> + case leftResult of + Left ex -> throwIO ex + Right l -> return $ Left $ unsafeCoerce l _ -> do case fromException e of # if MIN_VERSION_base(4,7,0) @@ -598,7 +589,7 @@ instance Exception InterruptLeft where fromException = asyncExceptionFromException #endif -data InterruptRight = InterruptRight Unique (Maybe SomeException) +data InterruptRight = forall a. InterruptRight Unique (Either SomeException a) deriving Typeable instance Show InterruptRight where From fb2b7e86ddaf07c613cc4b78665d938305370957 Mon Sep 17 00:00:00 2001 From: Bas van Dijk Date: Thu, 9 Oct 2014 23:15:16 +0200 Subject: [PATCH 05/12] Simplified race --- Control/Concurrent/Async.hs | 39 +++++++++---------------------------- 1 file changed, 9 insertions(+), 30 deletions(-) diff --git a/Control/Concurrent/Async.hs b/Control/Concurrent/Async.hs index 7351812..616fd5c 100644 --- a/Control/Concurrent/Async.hs +++ b/Control/Concurrent/Async.hs @@ -522,20 +522,9 @@ concurrently left right = -- will either throw the contained exception or return the left result -- normally. -- --- * When @right@ terminates normally it throws an 'InterruptLeft' exception to --- the left thread in order to stop that thread from doing any more work. --- --- * When @right@ throws an exception it is catched an thrown to the left thread --- contained in an 'InterruptLeft' exception. --- --- The exact exception that gets contained in the 'InterruptLeft' exception is --- dependent on the type of exception being thrown: if an asynchronous --- exception was thrown the exception itself gets contained. For synchonous --- exceptions the 'ThreadKilled' exception is contained. This is to mimic the --- behaviour of 'withAsync'. --- --- * When the left thread catches the 'InterruptLeft' exception it throws the --- contained exception. +-- * When @right@ terminates, whether normally or by raising an +-- exception, it throws an 'InterruptLeft' exception to the left +-- thread in order to stop that thread from doing any more work. -- -- Because calls to @race@ can be nested it's important that different -- 'InterruptLeft' or 'InterruptRight' exceptions are not mixed-up. For this @@ -553,11 +542,10 @@ race left right = do catch (do l <- restore left throwTo rightTid $ InterruptRight u $ Right l) $ \e -> case fromException e of - Just (InterruptLeft u' rightEx) | u == u' -> throwIO rightEx - _ -> do throwTo rightTid $ InterruptRight u (Left e) - throwIO e + Just (InterruptLeft u') | u == u' -> return () + _ -> throwTo rightTid $ InterruptRight u (Left e) catch (do r <- restore right - throwTo leftTid $ InterruptLeft u ThreadKilled + throwTo leftTid $ InterruptLeft u return $ Right r) $ \e -> case fromException e of Just (InterruptRight u' leftResult) | u == u' -> @@ -565,20 +553,11 @@ race left right = do Left ex -> throwIO ex Right l -> return $ Left $ unsafeCoerce l _ -> do - case fromException e of -# if MIN_VERSION_base(4,7,0) - Just (_ :: SomeAsyncException) - -> throwTo leftTid $ InterruptLeft u e -# else - Just (_ :: AsyncException) - -> throwTo leftTid $ InterruptLeft u e -# endif - Nothing - -> throwTo leftTid $ InterruptLeft u ThreadKilled + throwTo leftTid $ InterruptLeft u throwIO e -data InterruptLeft = forall e. (Exception e) => InterruptLeft Unique e - deriving (Typeable) +data InterruptLeft = InterruptLeft Unique + deriving Typeable instance Show InterruptLeft where show _ = "<< InterruptLeft >>" From ef3aa05c5260dd4030ed6485e5ececa769b1277b Mon Sep 17 00:00:00 2001 From: Bas van Dijk Date: Fri, 10 Oct 2014 00:22:36 +0200 Subject: [PATCH 06/12] Fixed bug in concurrently & renamed the interrupt exceptions --- Control/Concurrent/Async.hs | 105 +++++++++++++++++++++--------------- test/test-async.hs | 17 ++++++ 2 files changed, 80 insertions(+), 42 deletions(-) diff --git a/Control/Concurrent/Async.hs b/Control/Concurrent/Async.hs index 616fd5c..50a8661 100644 --- a/Control/Concurrent/Async.hs +++ b/Control/Concurrent/Async.hs @@ -515,23 +515,24 @@ concurrently left right = -- More concretely: -- -- * When @left@ terminates, whether normally or by raising an --- exception, it wraps its result in the 'InterruptRight' exception --- and throws that to the right thread. +-- exception, it wraps its result in the 'UniqueInterruptWithResult' +-- exception and throws that to the right thread. -- --- * When the right thread catches the 'InterruptRight' exception it --- will either throw the contained exception or return the left result --- normally. +-- * When the right thread catches the 'UniqueInterruptWithResult' +-- exception it will either throw the contained exception or return +-- the left result normally. -- -- * When @right@ terminates, whether normally or by raising an --- exception, it throws an 'InterruptLeft' exception to the left +-- exception, it throws an 'UniqueInterrupt' exception to the left -- thread in order to stop that thread from doing any more work. -- -- Because calls to @race@ can be nested it's important that different --- 'InterruptLeft' or 'InterruptRight' exceptions are not mixed-up. For this --- reason each call to @race@ creates a 'Unique' value that gets embedded in the --- interrupt exceptions being thrown. When catching the interrupt exceptions we --- check if the Unique equals the Unique of this invocation of @race@. (This is --- the same trick used in the Timeout exception from System.Timeout). +-- 'UniqueInterrupt' or 'UniqueInterruptWithResult' exceptions are not +-- mixed-up. For this reason each call to @race@ creates a 'Unique' +-- value that gets embedded in the interrupt exceptions being +-- thrown. When catching the interrupt exceptions we check if the +-- Unique equals the Unique of this invocation of @race@. (This is the +-- same trick used in the Timeout exception from System.Timeout). -- race :: IO a -> IO b -> IO (Either a b) race left right = do @@ -539,58 +540,81 @@ race left right = do u <- newUnique mask $ \restore -> do leftTid <- forkIO $ - catch (do l <- restore left - throwTo rightTid $ InterruptRight u $ Right l) $ \e -> + catch + (do l <- restore left + throwTo rightTid $ UniqueInterruptWithResult u $ Right l) $ \e -> case fromException e of - Just (InterruptLeft u') | u == u' -> return () - _ -> throwTo rightTid $ InterruptRight u (Left e) - catch (do r <- restore right - throwTo leftTid $ InterruptLeft u - return $ Right r) $ \e -> + Just (UniqueInterrupt u') | u == u' -> return () + _ -> throwTo rightTid $ UniqueInterruptWithResult u (Left e) + catch + (do r <- restore right + throwTo leftTid $ UniqueInterrupt u + return $ Right r) $ \e -> case fromException e of - Just (InterruptRight u' leftResult) | u == u' -> + Just (UniqueInterruptWithResult u' leftResult) | u == u' -> case leftResult of Left ex -> throwIO ex Right l -> return $ Left $ unsafeCoerce l _ -> do - throwTo leftTid $ InterruptLeft u + throwTo leftTid $ UniqueInterrupt u throwIO e -data InterruptLeft = InterruptLeft Unique +-- race_ :: IO a -> IO b -> IO () +race_ left right = void $ race left right + +-- concurrently :: IO a -> IO b -> IO (a,b) +concurrently left right = do + mv <- newEmptyMVar + rightTid <- myThreadId + u <- newUnique + mask $ \restore -> do + leftTid <- forkIO $ catch (restore left >>= putMVar mv) $ \e -> + case fromException e of + Just (UniqueInterrupt u') | u == u' -> return () + _ -> throwTo rightTid $ UniqueInterruptWithSomeException u e + catch (flip (,) <$> restore right <*> takeMVar mv) $ \e -> + case fromException e of + Just (UniqueInterruptWithSomeException u' ex) | u == u' -> throwIO ex + _ -> do throwTo leftTid (UniqueInterrupt u) + throwIO e + +data UniqueInterrupt = UniqueInterrupt Unique deriving Typeable -instance Show InterruptLeft where - show _ = "<< InterruptLeft >>" +instance Show UniqueInterrupt where + show _ = "<< UniqueInterrupt >>" -instance Exception InterruptLeft where +instance Exception UniqueInterrupt where #if MIN_VERSION_base(4,7,0) toException = asyncExceptionToException fromException = asyncExceptionFromException #endif -data InterruptRight = forall a. InterruptRight Unique (Either SomeException a) - deriving Typeable +data UniqueInterruptWithResult = + forall a. UniqueInterruptWithResult Unique (Either SomeException a) + deriving Typeable -instance Show InterruptRight where - show _ = "<< InterruptRight >>" +instance Show UniqueInterruptWithResult where + show _ = "<< UniqueInterruptWithResult >>" -instance Exception InterruptRight where +instance Exception UniqueInterruptWithResult where #if MIN_VERSION_base(4,7,0) toException = asyncExceptionToException fromException = asyncExceptionFromException #endif --- race_ :: IO a -> IO b -> IO () -race_ left right = void $ race left right +data UniqueInterruptWithSomeException = + UniqueInterruptWithSomeException Unique SomeException + deriving Typeable --- concurrently :: IO a -> IO b -> IO (a,b) -concurrently left right = do - mv <- newEmptyMVar - rightTid <- myThreadId - mask $ \restore -> do - leftTid <- forkIO $ (restore left >>= putMVar mv) - `catchAll` throwTo rightTid - (flip (,) <$> restore right <*> takeMVar mv) `alsoThrowingTo` leftTid +instance Show UniqueInterruptWithSomeException where + show _ = "<< UniqueInterruptWithSomeException >>" + +instance Exception UniqueInterruptWithSomeException where +#if MIN_VERSION_base(4,7,0) + toException = asyncExceptionToException + fromException = asyncExceptionFromException +#endif #endif -- | maps an @IO@-performing function over any @Traversable@ data @@ -657,9 +681,6 @@ forkRepeat action = _ -> return () in forkIO go -catchAll :: IO a -> (SomeException -> IO a) -> IO a -catchAll = catch - tryAll :: IO a -> IO (Either SomeException a) tryAll = try diff --git a/test/test-async.hs b/test/test-async.hs index 8e644f9..15bdf28 100644 --- a/test/test-async.hs +++ b/test/test-async.hs @@ -59,6 +59,10 @@ tests = [ , testCase "left_terminates_by_asynchronous_exception_kills_right" race_left_terminates_by_asynchronous_exception_kills_right ] + , testGroup "concurrently" $ + [ testCase "1" concurrently_1 + , testCase "2" concurrently_2 + ] ] value = 42 :: Int @@ -268,3 +272,16 @@ race_left_terminates_by_asynchronous_exception_kills_right = do return 'x') r @?= Left ThreadKilled + +concurrently_1 :: Assertion +concurrently_1 = do + r <- concurrently (threadDelay 1000 >> return 1) + (threadDelay 1000 >> return 'x') + r @?= (1, 'x') + +concurrently_2 :: Assertion +concurrently_2 = do + void $ timeout 1000 $ + concurrently (threadDelay 10000) + (threadDelay 10000) + threadDelay 10000 From 1f3d9f98a31dbdc3dbdfea984fb5a838baa0c74b Mon Sep 17 00:00:00 2001 From: Bas van Dijk Date: Fri, 10 Oct 2014 09:24:00 +0200 Subject: [PATCH 07/12] Fixed some bugs in race and concurrently --- Control/Concurrent/Async.hs | 75 ++++++++++++++++--------------------- test/test-async.hs | 46 +++++++++++++++++++++++ 2 files changed, 79 insertions(+), 42 deletions(-) diff --git a/Control/Concurrent/Async.hs b/Control/Concurrent/Async.hs index 50a8661..07e2766 100644 --- a/Control/Concurrent/Async.hs +++ b/Control/Concurrent/Async.hs @@ -127,6 +127,7 @@ import Control.Applicative import Data.Traversable import Data.Typeable import Data.Unique +import Data.IORef import Unsafe.Coerce import GHC.Exts @@ -222,32 +223,26 @@ withAsyncOnWithUnmask cpu actionWith = withAsyncUsing (rawForkOn cpu) (actionWit withAsyncUsing :: (IO () -> IO ThreadId) -> IO a -> (Async a -> IO b) -> IO b --- The bracket version works, but is slow. We can do better by --- hand-coding it: withAsyncUsing doFork = \action inner -> do var <- newEmptyTMVarIO mask $ \restore -> do t <- doFork $ try (restore action) >>= atomically . putTMVar var let a = Async t (readTMVar var) - r <- restore (inner a) `alsoThrowingTo` t + r <- restore (inner a) `catch` \e -> do + throwAsyncTo t e + throwIO e cancel a return r --- | If the given action throws an asynchronous exception then also --- throw it to the specified thread. If it throws a synchronous --- exception then kill the specified thread. -alsoThrowingTo :: IO a -> ThreadId -> IO a -m `alsoThrowingTo` tid = m `catch` handler - where - handler e = do - case fromException e of -# if MIN_VERSION_base(4,7,0) - Just (_ :: SomeAsyncException) -> throwTo tid e -# else - Just (_ :: AsyncException) -> throwTo tid e -# endif - Nothing -> throwTo tid ThreadKilled - throwIO e +throwAsyncTo :: ThreadId -> SomeException -> IO () +throwAsyncTo tid e = + case fromException e of +#if MIN_VERSION_base(4,7,0) + Just (_ :: SomeAsyncException) -> throwTo tid e +#else + Just (_ :: AsyncException) -> throwTo tid e +#endif + Nothing -> throwTo tid ThreadKilled -- | Wait for an asynchronous action to complete, and return its -- value. If the asynchronous action threw an exception, then the @@ -538,17 +533,21 @@ concurrently left right = race left right = do rightTid <- myThreadId u <- newUnique + throwToRightRef <- newIORef True mask $ \restore -> do leftTid <- forkIO $ catch (do l <- restore left - throwTo rightTid $ UniqueInterruptWithResult u $ Right l) $ \e -> - case fromException e of - Just (UniqueInterrupt u') | u == u' -> return () - _ -> throwTo rightTid $ UniqueInterruptWithResult u (Left e) + throwToRight <- readIORef throwToRightRef + when throwToRight $ + throwTo rightTid $ UniqueInterruptWithResult u $ Right l) $ \e -> do + throwToRight <- readIORef throwToRightRef + when throwToRight $ + throwTo rightTid $ UniqueInterruptWithResult u (Left e) catch (do r <- restore right - throwTo leftTid $ UniqueInterrupt u + writeIORef throwToRightRef False + throwTo leftTid ThreadKilled return $ Right r) $ \e -> case fromException e of Just (UniqueInterruptWithResult u' leftResult) | u == u' -> @@ -556,7 +555,8 @@ race left right = do Left ex -> throwIO ex Right l -> return $ Left $ unsafeCoerce l _ -> do - throwTo leftTid $ UniqueInterrupt u + writeIORef throwToRightRef False + throwAsyncTo leftTid e throwIO e -- race_ :: IO a -> IO b -> IO () @@ -567,28 +567,19 @@ concurrently left right = do mv <- newEmptyMVar rightTid <- myThreadId u <- newUnique + throwToRightRef <- newIORef True mask $ \restore -> do - leftTid <- forkIO $ catch (restore left >>= putMVar mv) $ \e -> - case fromException e of - Just (UniqueInterrupt u') | u == u' -> return () - _ -> throwTo rightTid $ UniqueInterruptWithSomeException u e + leftTid <- forkIO $ catch (restore left >>= putMVar mv) $ \e -> do + throwToRight <- readIORef throwToRightRef + when throwToRight $ + throwTo rightTid $ UniqueInterruptWithSomeException u e catch (flip (,) <$> restore right <*> takeMVar mv) $ \e -> case fromException e of Just (UniqueInterruptWithSomeException u' ex) | u == u' -> throwIO ex - _ -> do throwTo leftTid (UniqueInterrupt u) - throwIO e - -data UniqueInterrupt = UniqueInterrupt Unique - deriving Typeable - -instance Show UniqueInterrupt where - show _ = "<< UniqueInterrupt >>" - -instance Exception UniqueInterrupt where -#if MIN_VERSION_base(4,7,0) - toException = asyncExceptionToException - fromException = asyncExceptionFromException -#endif + _ -> do + writeIORef throwToRightRef False + throwAsyncTo leftTid e + throwIO e data UniqueInterruptWithResult = forall a. UniqueInterruptWithResult Unique (Either SomeException a) diff --git a/test/test-async.hs b/test/test-async.hs index 15bdf28..86b333b 100644 --- a/test/test-async.hs +++ b/test/test-async.hs @@ -58,10 +58,14 @@ tests = [ race_right_terminates_by_asynchronous_exception_kills_both , testCase "left_terminates_by_asynchronous_exception_kills_right" race_left_terminates_by_asynchronous_exception_kills_right + , testCase "left_receives_asynchronous_exception" + race_left_receives_asynchronous_exception ] , testGroup "concurrently" $ [ testCase "1" concurrently_1 , testCase "2" concurrently_2 + , testCase "left_receives_asynchronous_exception" + concurrently_left_receives_asynchronous_exception ] ] @@ -273,6 +277,27 @@ race_left_terminates_by_asynchronous_exception_kills_right = do r @?= Left ThreadKilled +race_left_receives_asynchronous_exception :: Assertion +race_left_receives_asynchronous_exception = do + rightTidMv <- newEmptyMVar + + exMv <- newEmptyMVar + + forkIO $ do + threadDelay 1000 + rightTid <- takeMVar rightTidMv + throwTo rightTid UserInterrupt + + catchIgnore $ + race (threadDelay 100000 `catch` putMVar exMv) + (do rightTid <- myThreadId + putMVar rightTidMv rightTid + threadDelay 10000) + + ex <- takeMVar exMv + + ex @?= UserInterrupt + concurrently_1 :: Assertion concurrently_1 = do r <- concurrently (threadDelay 1000 >> return 1) @@ -285,3 +310,24 @@ concurrently_2 = do concurrently (threadDelay 10000) (threadDelay 10000) threadDelay 10000 + +concurrently_left_receives_asynchronous_exception :: Assertion +concurrently_left_receives_asynchronous_exception = do + rightTidMv <- newEmptyMVar + + exMv <- newEmptyMVar + + forkIO $ do + threadDelay 1000 + rightTid <- takeMVar rightTidMv + throwTo rightTid UserInterrupt + + catchIgnore $ + concurrently (threadDelay 100000 `catch` putMVar exMv) + (do rightTid <- myThreadId + putMVar rightTidMv rightTid + threadDelay 10000) + + ex <- takeMVar exMv + + ex @?= UserInterrupt From b5b35599de3b70ce64ec7e3d7030dfe3845da78c Mon Sep 17 00:00:00 2001 From: Bas van Dijk Date: Fri, 10 Oct 2014 09:38:02 +0200 Subject: [PATCH 08/12] Improved internal documentation --- Control/Concurrent/Async.hs | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/Control/Concurrent/Async.hs b/Control/Concurrent/Async.hs index 07e2766..deb3b64 100644 --- a/Control/Concurrent/Async.hs +++ b/Control/Concurrent/Async.hs @@ -509,17 +509,26 @@ concurrently left right = -- -- More concretely: -- --- * When @left@ terminates, whether normally or by raising an --- exception, it wraps its result in the 'UniqueInterruptWithResult' --- exception and throws that to the right thread. --- --- * When the right thread catches the 'UniqueInterruptWithResult' --- exception it will either throw the contained exception or return --- the left result normally. --- --- * When @right@ terminates, whether normally or by raising an --- exception, it throws an 'UniqueInterrupt' exception to the left --- thread in order to stop that thread from doing any more work. +-- When @left@ terminates, whether normally or by raising an +-- exception, it wraps its result in the 'UniqueInterruptWithResult' +-- exception and throws that to the right thread. +-- +-- When the right thread catches the 'UniqueInterruptWithResult' +-- exception it will either throw the contained exception or return +-- the left result normally. +-- +-- When @right@ terminates normally or by a non-asynchronous exception +-- it kills the left thread. When it terminates with an asynchronous +-- exception the exception is thrown to the left thread. +-- +-- When @right@ terminates it has to throw an exception to the left +-- thread. However, the left thread will throw an exception back to +-- the right thread when it receives one. We don't want the left +-- thread to throw back an exception it just got from the right +-- thread. To protect against this we set a boolean before the right +-- thread throws an exception to the left thread. The left thread +-- checks this boolean before it needs to throw an exception to the +-- right thread. -- -- Because calls to @race@ can be nested it's important that different -- 'UniqueInterrupt' or 'UniqueInterruptWithResult' exceptions are not From 4d5d73c800dfe7424b51006ce46b924a1d54ee23 Mon Sep 17 00:00:00 2001 From: Bas van Dijk Date: Fri, 10 Oct 2014 10:38:30 +0200 Subject: [PATCH 09/12] Added test-cases for nested calls of race and concurrently --- test/test-async.hs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/test/test-async.hs b/test/test-async.hs index 86b333b..af077d6 100644 --- a/test/test-async.hs +++ b/test/test-async.hs @@ -60,12 +60,16 @@ tests = [ race_left_terminates_by_asynchronous_exception_kills_right , testCase "left_receives_asynchronous_exception" race_left_receives_asynchronous_exception + , testCase "nested" + race_nested ] , testGroup "concurrently" $ [ testCase "1" concurrently_1 , testCase "2" concurrently_2 , testCase "left_receives_asynchronous_exception" concurrently_left_receives_asynchronous_exception + , testCase "nested" + concurrently_nested ] ] @@ -298,6 +302,14 @@ race_left_receives_asynchronous_exception = do ex @?= UserInterrupt +race_nested :: Assertion +race_nested = do + r <- race (threadDelay 1000 >> return 1) + (race (threadDelay 10000 >> return 'x') + (threadDelay 100000 >> return False) + ) + r @?= Left 1 + concurrently_1 :: Assertion concurrently_1 = do r <- concurrently (threadDelay 1000 >> return 1) @@ -331,3 +343,18 @@ concurrently_left_receives_asynchronous_exception = do ex <- takeMVar exMv ex @?= UserInterrupt + +concurrently_nested :: Assertion +concurrently_nested = do + ref <- newIORef False + r <- try $ concurrently (threadDelay 1000 >> throwIO DivideByZero) + (concurrently (threadDelay 10000 >> writeIORef ref True) + (threadDelay 100000 >> return False)) + threadDelay 100000 + case r of + Left e -> case fromException e of + Just DivideByZero -> do + middleCompleted <- readIORef ref + assertBool "Middle completed!" $ not middleCompleted + _ -> assertFailure "" + Right _ -> assertFailure "" From 5ad752a292bd2d10ea8049388dae8a2b1ec534a6 Mon Sep 17 00:00:00 2001 From: Bas van Dijk Date: Fri, 10 Oct 2014 13:01:19 +0200 Subject: [PATCH 10/12] Abstracted a bit of common code in race --- Control/Concurrent/Async.hs | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/Control/Concurrent/Async.hs b/Control/Concurrent/Async.hs index deb3b64..cd2ac2e 100644 --- a/Control/Concurrent/Async.hs +++ b/Control/Concurrent/Async.hs @@ -531,28 +531,24 @@ concurrently left right = -- right thread. -- -- Because calls to @race@ can be nested it's important that different --- 'UniqueInterrupt' or 'UniqueInterruptWithResult' exceptions are not --- mixed-up. For this reason each call to @race@ creates a 'Unique' --- value that gets embedded in the interrupt exceptions being --- thrown. When catching the interrupt exceptions we check if the --- Unique equals the Unique of this invocation of @race@. (This is the --- same trick used in the Timeout exception from System.Timeout). +-- 'UniqueInterruptWithResult' exceptions are not mixed-up. For this +-- reason each call to @race@ creates a 'Unique' value that gets +-- embedded in the interrupt exceptions being thrown. When catching +-- the interrupt exceptions we check if the Unique equals the Unique +-- of this invocation of @race@. (This is the same trick used in the +-- Timeout exception from System.Timeout). -- race :: IO a -> IO b -> IO (Either a b) race left right = do rightTid <- myThreadId u <- newUnique throwToRightRef <- newIORef True - mask $ \restore -> do - leftTid <- forkIO $ - catch - (do l <- restore left - throwToRight <- readIORef throwToRightRef - when throwToRight $ - throwTo rightTid $ UniqueInterruptWithResult u $ Right l) $ \e -> do + let interruptRight r = do throwToRight <- readIORef throwToRightRef - when throwToRight $ - throwTo rightTid $ UniqueInterruptWithResult u (Left e) + when throwToRight $ throwTo rightTid $ UniqueInterruptWithResult u r + mask $ \restore -> do + leftTid <- forkIO $ catch (restore left >>= interruptRight . Right) + (interruptRight . Left) catch (do r <- restore right writeIORef throwToRightRef False From f8abf4483b46537dd36e56de0a378dbd880882cf Mon Sep 17 00:00:00 2001 From: Bas van Dijk Date: Fri, 10 Oct 2014 13:30:48 +0200 Subject: [PATCH 11/12] Refactored race and concurrently to make them look more similar --- Control/Concurrent/Async.hs | 74 +++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/Control/Concurrent/Async.hs b/Control/Concurrent/Async.hs index cd2ac2e..6182fc0 100644 --- a/Control/Concurrent/Async.hs +++ b/Control/Concurrent/Async.hs @@ -229,20 +229,22 @@ withAsyncUsing doFork = \action inner -> do t <- doFork $ try (restore action) >>= atomically . putTMVar var let a = Async t (readTMVar var) r <- restore (inner a) `catch` \e -> do - throwAsyncTo t e + throwTo t $ normalizeAsyncEx e throwIO e cancel a return r -throwAsyncTo :: ThreadId -> SomeException -> IO () -throwAsyncTo tid e = +-- | Converts non-asynchronous exception into ThreadKilled and leaves +-- asynchronous exceptions alone. +normalizeAsyncEx :: SomeException -> SomeException +normalizeAsyncEx e = case fromException e of #if MIN_VERSION_base(4,7,0) - Just (_ :: SomeAsyncException) -> throwTo tid e + Just (_ :: SomeAsyncException) -> e #else - Just (_ :: AsyncException) -> throwTo tid e + Just (_ :: AsyncException) -> e #endif - Nothing -> throwTo tid ThreadKilled + Nothing -> toException ThreadKilled -- | Wait for an asynchronous action to complete, and return its -- value. If the asynchronous action threw an exception, then the @@ -549,42 +551,42 @@ race left right = do mask $ \restore -> do leftTid <- forkIO $ catch (restore left >>= interruptRight . Right) (interruptRight . Left) - catch - (do r <- restore right + let interruptLeft e = do writeIORef throwToRightRef False - throwTo leftTid ThreadKilled - return $ Right r) $ \e -> - case fromException e of - Just (UniqueInterruptWithResult u' leftResult) | u == u' -> - case leftResult of - Left ex -> throwIO ex - Right l -> return $ Left $ unsafeCoerce l - _ -> do - writeIORef throwToRightRef False - throwAsyncTo leftTid e - throwIO e + throwTo leftTid e + catch ((Right <$> restore right) <* interruptLeft ThreadKilled) $ \e -> + case fromException e of + Just (UniqueInterruptWithResult u' leftResult) | u == u' + -> case leftResult of + Left ex -> throwIO ex + Right l -> return $ Left $ unsafeCoerce l + _ -> do interruptLeft $ normalizeAsyncEx e + throwIO e -- race_ :: IO a -> IO b -> IO () race_ left right = void $ race left right -- concurrently :: IO a -> IO b -> IO (a,b) concurrently left right = do - mv <- newEmptyMVar + leftResultMv <- newEmptyMVar rightTid <- myThreadId u <- newUnique throwToRightRef <- newIORef True + let interruptRight e = do + throwToRight <- readIORef throwToRightRef + when throwToRight $ throwTo rightTid $ UniqueInterruptWithException u e mask $ \restore -> do - leftTid <- forkIO $ catch (restore left >>= putMVar mv) $ \e -> do - throwToRight <- readIORef throwToRightRef - when throwToRight $ - throwTo rightTid $ UniqueInterruptWithSomeException u e - catch (flip (,) <$> restore right <*> takeMVar mv) $ \e -> - case fromException e of - Just (UniqueInterruptWithSomeException u' ex) | u == u' -> throwIO ex - _ -> do - writeIORef throwToRightRef False - throwAsyncTo leftTid e - throwIO e + leftTid <- forkIO $ catch (restore left >>= putMVar leftResultMv) + interruptRight + let interruptLeft e = do + writeIORef throwToRightRef False + throwTo leftTid e + catch (flip (,) <$> restore right <*> takeMVar leftResultMv) $ \e -> + case fromException e of + Just (UniqueInterruptWithException u' ex) | u == u' + -> throwIO ex + _ -> do interruptLeft $ normalizeAsyncEx e + throwIO e data UniqueInterruptWithResult = forall a. UniqueInterruptWithResult Unique (Either SomeException a) @@ -599,14 +601,14 @@ instance Exception UniqueInterruptWithResult where fromException = asyncExceptionFromException #endif -data UniqueInterruptWithSomeException = - UniqueInterruptWithSomeException Unique SomeException +data UniqueInterruptWithException = + UniqueInterruptWithException Unique SomeException deriving Typeable -instance Show UniqueInterruptWithSomeException where - show _ = "<< UniqueInterruptWithSomeException >>" +instance Show UniqueInterruptWithException where + show _ = "<< UniqueInterruptWithException >>" -instance Exception UniqueInterruptWithSomeException where +instance Exception UniqueInterruptWithException where #if MIN_VERSION_base(4,7,0) toException = asyncExceptionToException fromException = asyncExceptionFromException From ed671602a147dd58546edbf299e131efc494239f Mon Sep 17 00:00:00 2001 From: Bas van Dijk Date: Fri, 10 Oct 2014 13:39:19 +0200 Subject: [PATCH 12/12] Remove redundant parenthesis --- Control/Concurrent/Async.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Control/Concurrent/Async.hs b/Control/Concurrent/Async.hs index 6182fc0..71974e4 100644 --- a/Control/Concurrent/Async.hs +++ b/Control/Concurrent/Async.hs @@ -554,7 +554,7 @@ race left right = do let interruptLeft e = do writeIORef throwToRightRef False throwTo leftTid e - catch ((Right <$> restore right) <* interruptLeft ThreadKilled) $ \e -> + catch (Right <$> restore right <* interruptLeft ThreadKilled) $ \e -> case fromException e of Just (UniqueInterruptWithResult u' leftResult) | u == u' -> case leftResult of