2f7c347 Dec 7, 2016

Haskell's continuation monad

I am convinced that many complicated things in Haskell can - and have to be - understood on two levels: how to use them, and how they're implemented. Usage is usually much easier to grasp than implementation, which is why it should be the first approach. However, this practical knowledge doesn't cover corner cases well, and may lead to unpleasant surprises. This is where the inner workings of a function come into play, which are a lot easier to learn once one knows what the function does; this in turn feeds back to the intuitive side of things, eventually reaching "true" understanding. This article is split into these two levels: first the intuitive one, then the technicalities.

The concept of continuations

You surely know the transformation to introduce better recursion by packing the result into an accumulator. For example

length []     = 0
length (_:xs) = 1 + length xs

will build a large stack of 1+1+1+... operations, but this issue can be fixed by packing the "1+" in another parameter, so that there's some cheap form of state (namely the length so far) in the function:

length = length' 0 where
    length' l []     = l
    length' l (_:xs) = length' (l+1) xs

This converts the return value of the function to a parameter containing what has been done already - its past, so to speak.

Now continuations can be seen as a similar transformation: you put something ordinary in a parameter instead of directly dealing with it. In case of a continuation, that parameter is what's to be done next.

--               |-------------| <- "Continuation" type. ("square maps an
--               |             |    'Int' to a continuation")
square :: Int -> (Int -> r) -> r
square x = \k -> k (x^2)

Instead of returning the value of x^2, it returns it wrapped in another function. You can see k as the future of square x. For example, if you set the future of square x to be the identity, you get

  square x id
= (\k -> k (x^2)) id
= id (x^2)
= x^2
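To tie this back to the length example, here is a small sketch of the same function in continuation-passing style. The name lengthCPS is made up for illustration. Note how the "1+" that used to pile up on the stack now lives inside the continuation that is passed along.

```haskell
-- Length of a list in continuation-passing style.
-- lengthCPS is a hypothetical name used only in this sketch.
lengthCPS :: [a] -> (Int -> r) -> r
-- Empty list: hand 0 to whatever is supposed to happen next.
lengthCPS []     k = k 0
-- Otherwise: compute the length of the tail, and tell it that its
-- future is "add one, then continue with k".
lengthCPS (_:xs) k = lengthCPS xs $ \n -> k (1 + n)
```

For example, lengthCPS "abc" id evaluates to 3, and lengthCPS "abc" show to "3". This is only meant to show the shape of the transformation; it does not claim to be more efficient than the accumulator version.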

A very useful way of picturing how this works in practice is this:

square 2 $ \twoSquare -> (...)

This means that wherever you are in (...), you have the result of square 2 at hand by using twoSquare. So the intuitive understanding is that continuations put the result value in a lambda argument (instead of returning it like ordinary functions). Let's see that in action:

square :: Int -> (Int -> r) -> r
square x = \k -> k (x^2)

add :: Int -> Int -> (Int -> r) -> r
add x y  = \k -> k (x + y)

pythagoras :: Int -> Int -> (Int -> r) -> r
pythagoras x y = \k ->
    square x $ \xSquare ->
    square y $ \ySquare ->
    add xSquare ySquare $ \result ->
    k result

In the end the result of the addition is bound to result, which is what we want to have. Since we're staying in continuation-passing style, pythagoras' continuation k is applied to the result.

Now calling pythagoras 3 4 will result in the hypothetical result of 3^2 + 4^2, so pythagoras 3 4 id = 25.
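Since the result type r is left open, the same pythagoras can be handed different futures. The following sketch repeats the definitions from above so it can be loaded on its own; the names asNumber, asString and isTwentyFive are only there for illustration.

```haskell
square :: Int -> (Int -> r) -> r
square x = \k -> k (x^2)

add :: Int -> Int -> (Int -> r) -> r
add x y = \k -> k (x + y)

pythagoras :: Int -> Int -> (Int -> r) -> r
pythagoras x y = \k ->
    square x $ \xSquare ->
    square y $ \ySquare ->
    add xSquare ySquare $ \result ->
    k result

-- The continuation decides what happens with the result (and thus what r is).
asNumber :: Int
asNumber = pythagoras 3 4 id            -- r = Int

asString :: String
asString = pythagoras 3 4 show          -- r = String

isTwentyFive :: Bool
isTwentyFive = pythagoras 3 4 (== 25)   -- r = Bool
```

Here asNumber is 25, asString is "25", and isTwentyFive is True.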

The important take-away message here, again, is that the following two snippets are equivalent: both of them calculate f x, and then make the result available inside the (...) under the name y.

y = f x
(...)

f x $ \y -> (...)

The same piece of pythagoras code written using the Cont type would look like this:

square :: Int -> Cont r Int
square x = Cont $ \k -> k (x^2)

add :: Int -> Int -> Cont r Int
add x y = Cont $ \k -> k (x + y)

pythagoras, pythagoras' :: Int -> Int -> Cont r Int
pythagoras x y = do
    xSquare <- square x
    ySquare <- square y
    result <- add xSquare ySquare
    pure result
-- or in equivalent explicit bind notation:
pythagoras' x y =
    square x >>= \xSquare ->
    square y >>= \ySquare ->
    add xSquare ySquare >>= \result ->
    pure result

But now this can be refactored because of all the typeclasses Cont ships with. All of the functions now take Cont as parameters, and the intermediate continuations are passed through automatically.

-- Square a Cont
square :: Cont r Int -> Cont r Int
square = fmap (^2)

-- Add two Conts
add :: Cont r Int -> Cont r Int -> Cont r Int
add = liftA2 (+)

-- Pythagoras of two Conts
pythagoras :: Cont r Int -> Cont r Int -> Cont r Int
pythagoras x y = add (square x)
                     (square y)

-- GHCi
>>> pythagoras (pure 3) (pure 4) `runCont` id
25
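If you want to try the two styles above against the actual library, a sketch could look like the following. It assumes the transformers package (which ships with GHC), where Cont values are built with the lower-case function cont. The names squareC, addC, pythagorasDo and pythagorasLifted are made up so both versions can live in one file.

```haskell
import Control.Applicative (liftA2)
import Control.Monad.Trans.Cont (Cont, cont, runCont)

-- Explicit continuations, wrapped with the library's `cont`.
squareC :: Int -> Cont r Int
squareC x = cont $ \k -> k (x^2)

addC :: Int -> Int -> Cont r Int
addC x y = cont $ \k -> k (x + y)

-- Do-notation version: the continuations are threaded by (>>=).
pythagorasDo :: Int -> Int -> Cont r Int
pythagorasDo x y = do
    xSquare <- squareC x
    ySquare <- squareC y
    addC xSquare ySquare

-- Functor/Applicative version: arguments are Conts themselves.
pythagorasLifted :: Cont r Int -> Cont r Int -> Cont r Int
pythagorasLifted x y = liftA2 (+) (fmap (^2) x) (fmap (^2) y)
```

Both runCont (pythagorasDo 3 4) id and runCont (pythagorasLifted (pure 3) (pure 4)) id evaluate to 25.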

What Cont does

The purpose of Cont is getting rid of all the explicit \k -> k ... you would have to introduce manually, much like State takes care of dragging along the \s -> ... for you. Cont itself is simply a wrapper around the continuation type from the Pythagoras example above; note that square took a parameter to produce a continuation, so its type had an additional argument (the Int ->) in front.

newtype Cont r a = Cont { (>>-) :: (a -> r) -> r }

Usually, the operator >>- would be called runCont, but it turns out that using it infix makes for some very readable notation, as long as we keep in mind that

(>>-) :: Cont r a -> (a -> r) -> r

We will now be building our way up to the Monad instance, defining Functor and Applicative on the way.

Functor

The Functor instance is probably the most important one you have to learn for Cont: once you understand it, Applicative and Monad will come naturally. I was very surprised myself once I realized this.

We'll be constructing the Functor based on the unwrap-apply-wrap idea. The box picture of certain Haskell types might sometimes be awkward, but for certain computations it's a helpful model nevertheless. So for a Functor instance, what we want to do to get fmap :: (a -> b) -> Cont r a -> Cont r b is:

  1. Extract the value of type a out of the Cont we're given.
  2. Apply the function f to that value, yielding a new value of type b.
  3. Wrap this value in a Cont again.

The hardest part is the first one, I believe. But let's start simple: by writing as much as we know down. We know that we want a Cont again in the end, so the code will look like this:

fmap f cx = Cont ...

Because the constructor Cont has type ((a -> r) -> r) -> Cont r a, we can see that the argument given to Cont is of type (a -> r) -> r. The a -> r part of that can be bound to a value named k:

fmap f cx = Cont $ \k -> ...

Now let's tackle the first problem: extracting a value x :: a out of the given cx :: Cont r a. Remember the pythagoras example from above, where instead of returning values they were provided to us as a lambda parameter? That's exactly what happens here again.

fmap f cx = Cont $ \k ->
    cx >>- \x ->
    ...

This gives us access to x :: a, so 1. is finished! The second issue is applying f to that value, which is trivial:

fmap f cx = Cont $ \k ->
    cx >>- \x ->
    let result = f x
    in  ...

Now we have our result as a value; next is step 3, making a proper Cont again. We could simply return result from the function. But that would completely ignore the k, and thus break the continuation (ignoring all future continuations). But what have we learned before? "Don't return values directly, return them wrapped in the continuation parameter!"

fmap f cx = Cont $ \k ->
    cx >>- \x ->
    let result = f x
    in  k result

And that's it! Inlining result gives us a readable Functor instance:

instance Functor (Cont r) where
    fmap f cx = Cont $ \k ->
        cx >>- \x ->
        k (f x)

To reiterate, here are the three steps we've done again:

  1. Extract the value out of a computation cx. This is done by cx >>- \x ->, which gives us access to said value in the lambda's body.
  2. Transform that value using the mapping function f.
  3. Wrap the new value in the continuation parameter k again.
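To see the three steps at work, here is a small standalone sketch. It deviates slightly from the definition above in that >>- is written as an ordinary function instead of a record field, which is purely a presentation choice; the values three and shown are made up for the example.

```haskell
newtype Cont r a = Cont ((a -> r) -> r)

-- Run a Cont by handing it a continuation (plays the role of runCont).
(>>-) :: Cont r a -> (a -> r) -> r
Cont c >>- k = c k

instance Functor (Cont r) where
    fmap f cx = Cont $ \k ->
        cx >>- \x ->   -- 1. extract
        k (f x)        -- 2. apply f, 3. wrap in k

-- A Cont that provides the value 3 to its continuation.
three :: Cont r Int
three = Cont $ \k -> k 3

-- Two fmaps chained: first add one, then render as a String.
shown :: Cont r String
shown = fmap show (fmap (+ 1) three)
```

Running shown >>- id gives "4", and shown >>- length gives 1, because the continuation receives the already transformed value.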

Applicative

I promised this would be easy once you've understood the Functor instance; let's hope I can stay true to my promise.

pure does not have to do any extraction or transformation, so steps 1 and 2 fall away. All that's left to do is to take our general form

pure :: a -> Cont r a
pure x = Cont $ \k -> ...

The resulting value is provided by x already, and instead of returning it directly in the lambda's body, we wrap it in k.

pure x = Cont $ \k -> k x

Done.

Up next: <*>. This time, we have two Cont values: one holding a function, and one holding the value to apply it to. We'll extract the function, then extract the value, apply the function to the value, and wrap it in k again.

cf <*> cx = Cont $ \k ->
    cf >>- \f -> -- extract f
    cx >>- \x -> -- extract x
    k (f x)      -- apply f to x, wrap in k

And we're done!

(Bonus joke for the experienced Haskeller: imagine what happens when we switch the second and third lines above. Nothing? Now imagine a ContT transformer doing this.)
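For those who want the joke spelled out, here is a sketch using ContT from the transformers package over a Writer that logs which computation ran. The names apAsWritten, apSwapped and logged are hypothetical helpers for this example; only ContT, runContT, evalContT, lift, tell and execWriter come from the library.

```haskell
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Cont (ContT (..), evalContT)
import Control.Monad.Trans.Writer (Writer, execWriter, tell)

-- (<*>) as in the article: extract the function first, then the value.
apAsWritten :: ContT r m (a -> b) -> ContT r m a -> ContT r m b
apAsWritten cf cx = ContT $ \k ->
    runContT cf $ \f ->
    runContT cx $ \x ->
    k (f x)

-- The same with the two extraction lines switched.
apSwapped :: ContT r m (a -> b) -> ContT r m a -> ContT r m b
apSwapped cf cx = ContT $ \k ->
    runContT cx $ \x ->
    runContT cf $ \f ->
    k (f x)

-- Provide a value, but leave a log entry in the underlying Writer first.
logged :: String -> a -> ContT r (Writer [String]) a
logged label x = lift (tell [label]) >> pure x

orderAsWritten, orderSwapped :: [String]
orderAsWritten = execWriter . evalContT $
    apAsWritten (logged "function" (+ 1)) (logged "argument" (41 :: Int))
orderSwapped = execWriter . evalContT $
    apSwapped (logged "function" (+ 1)) (logged "argument" (41 :: Int))
```

orderAsWritten is ["function","argument"], while orderSwapped is ["argument","function"]: the computed value is the same, but the effects of the underlying monad happen in a different order.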

Monad

Since we already have the Applicative we can set return = pure. Now for cx >>= f, you guessed it: extract the value x out of the provided Cont cx, apply the function f to it (yielding a new Cont), extract the value out of that Cont, and wrap it up in k again.

cx >>= f = Cont $ \k ->
    cx >>- \x ->
    (f x) >>- \fx ->
    k fx

As a bonus, we could also define join, which is similarly simple.

join' :: Cont r (Cont r a) -> Cont r a
join' ccx = Cont $ \k ->
    ccx >>- \cx -> -- extract inner continuation
    cx >>- \x ->   -- extract value of inner continuation
    k x            -- wrap value in k again
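Collected in one place, the instances from the last three sections can be loaded and tried out as follows. As a small deviation, >>- is defined as a plain function here rather than as a record field; the example pythagoras is a sketch to show that do-notation now works with our homemade Cont.

```haskell
newtype Cont r a = Cont ((a -> r) -> r)

-- Run a Cont by handing it a continuation (plays the role of runCont).
(>>-) :: Cont r a -> (a -> r) -> r
Cont c >>- k = c k

instance Functor (Cont r) where
    fmap f cx = Cont $ \k ->
        cx >>- \x ->
        k (f x)

instance Applicative (Cont r) where
    pure x = Cont $ \k -> k x
    cf <*> cx = Cont $ \k ->
        cf >>- \f ->
        cx >>- \x ->
        k (f x)

instance Monad (Cont r) where
    cx >>= f = Cont $ \k ->
        cx >>- \x ->
        f x >>- \fx ->
        k fx

join' :: Cont r (Cont r a) -> Cont r a
join' ccx = Cont $ \k ->
    ccx >>- \cx ->
    cx >>- \x ->
    k x

-- Do-notation on top of the instances above.
pythagoras :: Int -> Int -> Cont r Int
pythagoras x y = do
    xSquare <- pure (x^2)
    ySquare <- pure (y^2)
    pure (xSquare + ySquare)
```

With these definitions, pythagoras 3 4 >>- id is 25, and join' (pure (pure (7 :: Int))) >>- show is "7".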

Special API function: callCC

Purpose

callCC introduces a single, well-defined early return.

Suppose you want a Cont value to break out of the evaluation and return a value x. You would implement this by simply ignoring the continuation parameter k:

exit x = Cont $ \_k -> x

If you place this anywhere in a large Cont calculation, the final result will simply be x if exit is evaluated (used by >>=) at any point.
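A minimal sketch of this behaviour, assuming the transformers library and its lower-case cont function; withoutExit and withExit are made-up example computations. Note that the argument of exit has to be of the final result type r, which is String here.

```haskell
import Control.Monad.Trans.Cont (Cont, cont, runCont)

-- Ignore the continuation and make x the result of everything.
exit :: r -> Cont r a
exit x = cont $ \_k -> x

withoutExit :: Cont String Int
withoutExit = do
    x <- pure 1
    pure (x + 1)

withExit :: Cont String Int
withExit = do
    x <- pure 1
    _ <- exit "bailed out"   -- everything below is skipped
    pure (x + 1)
```

runCont withoutExit show yields "2", whereas runCont withExit show yields "bailed out": the final continuation show is never applied in the second case.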

callCC creates a sandboxed environment for this idea: instead of leaving the entire calculation, only the sandbox is jumped out of. It is usually used like this:

foo :: Cont r a
foo = callCC $ \exit -> ...

Inside the ..., you now have an exit function that you can use. When this function is evaluated, its argument will become the value of the callCC block.

An example: suppose you want to convert some data to PDF format, but if there is an error it should abort and return an error message instead.

-- Create a PDF represented as Text from some Data
toPDF :: Data -> Cont r Text
toPDF d = callCC $ \exit -> do
    when (broken d) (exit "ERROR")
    pure (makePDF d)

The nice thing about callCC is that it is nestable, which makes it possible to short-circuit out of arbitrarily many levels. The following example first checks whether the data is broken (and terminates the procedure entirely). If it is fine, it checks whether the data is too long (resulting in a worded error message); if the data is alright but the format is dirty, it cleans it, and if everything works out alright it adds annoying eye candy.

toPDF :: Data -> Cont r Text
toPDF d = callCC $ \exit1 -> do
    when (broken d) (exit1 "Data corrupt")
    d' <- callCC $ \exit2 -> do
        when (tooLong d) (exit1 "Data too long") -- Jump out of everything
        when (dirty d) (exit2 (clean d)) -- Jump out of the inner callCC
        pure (decorateWithFlowers d)
    pure (makePDF d')
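To actually run the nested example, all the helpers have to exist. The following sketch fills them in with toy stand-ins: Data is just a String, the result is a String instead of Text, and broken, tooLong, dirty, clean, decorateWithFlowers and makePDF are hypothetical placeholders, not real PDF functionality. callCC, runCont and when are the library functions.

```haskell
import Control.Monad (when)
import Control.Monad.Trans.Cont (Cont, callCC, runCont)
import Data.Char (isSpace)

-- Toy stand-ins, assumptions for this sketch only.
type Data = String

broken, tooLong, dirty :: Data -> Bool
broken  = null
tooLong = (> 20) . length
dirty   = any isSpace

clean, decorateWithFlowers :: Data -> Data
clean = filter (not . isSpace)
decorateWithFlowers d = "*" ++ d ++ "*"

makePDF :: Data -> String
makePDF d = "%PDF " ++ d

toPDF :: Data -> Cont r String
toPDF d = callCC $ \exit1 -> do
    when (broken d) (exit1 "Data corrupt")
    d' <- callCC $ \exit2 -> do
        when (tooLong d) (exit1 "Data too long") -- Jump out of everything
        when (dirty d) (exit2 (clean d))         -- Jump out of the inner callCC
        pure (decorateWithFlowers d)
    pure (makePDF d')
```

With these stand-ins, runCont (toPDF "") id is "Data corrupt", runCont (toPDF (replicate 30 'x')) id is "Data too long", runCont (toPDF "a b") id is "%PDF ab", and runCont (toPDF "ab") id is "%PDF *ab*".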

To sum it up, each callCC carries around its own exit function, callable by the name of the lambda parameter. If you use this exit anywhere in its scope, the result of the corresponding callCC block will be exit's argument; if you do not use it, the program works as if there was no callCC $ \exit -> in the first place.

Some readers coming from an imperative background might recognize this jumping-around behaviour as similar to statements like break, continue and goto. callCC can indeed be seen as a generalization of these concepts. It should not be used too lightly, or the control flow of the program might become an obfuscated mess.

Intuition of the implementation

There is another intuitive, but more technical, view of callCC. We've learned that the purpose of having Cont around is so we don't have to worry about explicitly passing all the k continuations around, just like State abstracts passing the state s away.

We can label Cont by the parameter it abstracts away, for example we might say that

do x <- foo
   y <- bar
   let y' = if even x then y else y+1
   pure (x + y')

hides passing all the \k -> from us. callCC now opens a new block with another parameter; let's say it hides the \l -> in the following (equivalent) code:

do x <- foo                       -- k hidden
   y <- bar                       -- k hidden
   y' <- callCC $                 -- k hidden
       \exit -> do
           when (even x) (exit y) -- l hidden, conditionally inject k into l
           pure (y+1)             -- l hidden
   pure (x + y')                  -- k hidden

callCC has to somehow connect k and l, i.e. it defines what to continue with. It has two choices: keep using l (and continue like normal inside the callCC), or use the parent's continuation parameter k, ignoring the rest of the ls and "jumping" to the next instruction after the callCC block (pure (x + y') in our case).
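The claimed equivalence of the two do-blocks can be checked with a small sketch. It assumes the transformers library; foo and bar are replaced by plain parameters x and y so the snippet is self-contained, and the names plain and withCallCC are made up.

```haskell
import Control.Monad (when)
import Control.Monad.Trans.Cont (Cont, callCC, runCont)

-- The version with an ordinary conditional.
plain :: Int -> Int -> Cont r Int
plain x y = do
    let y' = if even x then y else y + 1
    pure (x + y')

-- The version that jumps out of the callCC block early.
withCallCC :: Int -> Int -> Cont r Int
withCallCC x y = do
    y' <- callCC $ \exit -> do
        when (even x) (exit y) -- continue after the block with y
        pure (y + 1)           -- only reached for odd x
    pure (x + y')
```

For instance, both runCont (plain 2 10) id and runCont (withCallCC 2 10) id give 12, and for x = 3 both give 14.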

Technical implementation

Recall we had our "hard exit function"

exit x = Cont $ \_ -> x

and the idea that callCC gives us a sandbox where a function like this is locally contained.

The idea behind the implementation is building an almost fully contained Cont calculation (which is evaluated inside the callCC). Once it terminates, its value is extracted, and fed to the parent Cont calculation (the one containing callCC). This gives us the structure

callCC' = Cont $ \k ->
    "<sandboxed calculation>" >>- \result ->
    k result

where as before we have the usual Cont $ \k -> wrapper in which we build a value, and then return k <value> in the end. What's left to ask now is how the sandboxed calculation should be built up.

The most primitive calculation that has to go in there somehow is simply exit.

callCC' :: Cont Text a
callCC' = Cont $ \k ->
    exit "foobar" >>- \result ->
    k result

When evaluating this, the inner computation encounters the exit, the final result of the entire Cont is set to "foobar" by ignoring the continuation, and thus everything after the >>- is never evaluated. How dull. But we can fix this! If we abort the inner Cont calculation not with "foobar" but with k "foobar", then calling exit short-circuits out once more, but since it's using k now, it will continue as if its value was "foobar"!

callCC' :: Cont r Text
callCC' = Cont $ \k ->
    exit (k "foobar") >>- \result ->
    k result

You can see that this is actually equivalent to pure "foobar" when you refactor the code a bit:

callCC'
  = Cont $ \k ->  exit (k "foobar")       >>- \result -> k result
    -- inline exit,
  = Cont $ \k ->  Cont $ \_ -> k "foobar" >>- \result -> k result
    -- resolve >>-,
  = Cont $ \k -> (       \_ -> k "foobar")               k
    -- Delete the gaps,
  = Cont $ \k -> (\_ -> k "foobar") k
    -- Function evaluation,
  = Cont $ \k -> k "foobar"
    -- Recognize definition,
  = pure "foobar"

Great, so now we've achieved a detour version of pure using exit: instead of directly returning a value, we construct a new computation that short-circuits out with that value.

And now the key idea: what if we only conditionally evaluate exit here?

callCC' f = Cont $ \k ->
    f (exit (k "foobar")) >>- \result ->
    k result

Now it depends on f what happens:

  1. f uses its argument. The >>- will evaluate the result of whatever f produces, and eventually encounter the exit function. The entire thing short-circuits with k "foobar", so the calculation continues exactly as if result had been "foobar".
  2. f ignores its argument, i.e. f is a constant function that maps to Cont r a. Whatever that value is, it is passed to result. No short-circuiting here.

This is already very close to callCC, only one last step is missing: abstracting the "foobar" away, since that's hardly the only value we ever want to short-circuit with.

callCC f = Cont $ \k ->
    f (\x -> exit (k x)) >>- \result ->
    k result
  where
    exit x = Cont $ \_ -> x

And there you have it: callCC! Let's reiterate the design: take a Cont computation that might short-circuit out by discarding its own continuation (the _ in exit), but apply the parent continuation k to the value it exits with. On short-circuiting, the parent continuation is used for further calculations. On not short-circuiting, the exit function is never called, and the inner Cont is evaluated as if there was no callCC around in the first place, and it was simply evaluated along with its parent calculation.
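The finished callCC can be tried out without any of the typeclass instances. The sketch below again writes >>- as a plain function instead of a record field, and adds a type signature to callCC; jumps and stays are made-up examples for the two cases discussed above.

```haskell
newtype Cont r a = Cont ((a -> r) -> r)

-- Run a Cont by handing it a continuation (plays the role of runCont).
(>>-) :: Cont r a -> (a -> r) -> r
Cont c >>- k = c k

callCC :: ((a -> Cont r b) -> Cont r a) -> Cont r a
callCC f = Cont $ \k ->
    f (\x -> exit (k x)) >>- \result ->
    k result
  where
    exit x = Cont $ \_ -> x

-- Uses the exit function: the inner continuation l is discarded.
jumps :: Cont r Int
jumps = callCC $ \exit -> Cont $ \l ->
    exit 1 >>- \_ ->
    l 2              -- never reached

-- Ignores the exit function: behaves as if there was no callCC.
stays :: Cont r Int
stays = callCC $ \_exit -> Cont $ \l -> l 2
```

Here jumps >>- id is 1 and stays >>- id is 2. With a different parent continuation, jumps >>- (+ 10) is 11, which shows that the early exit continues with the parent's k instead of aborting everything.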

Cont in the Haskell libraries

One thing worth mentioning is that Cont in the Haskell library Control.Monad.Trans.Cont is defined in terms of a monad transformer (ContT on Identity). The Cont data constructor I've used in this article therefore does not exist there; its role is played by a smart constructor named cont (lower case c), and the type signatures are all a little more general to account for possible uses as a transformer. Apart from these small items, the other notable difference is the pointfree style. For example, fmap would be defined as

fmap f cx = Cont $ \k -> runCont cx (k . f)

I find this to be very hard to understand, so I chose to present the more pointful version here. Similar thoughts apply to all the other sections.
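If you want to convince yourself that the two styles agree, the following sketch writes both against the library's Cont, using cont and runCont from transformers. fmapPointful and fmapPointfree are names made up for this comparison and are not part of the library.

```haskell
import Control.Monad.Trans.Cont (Cont, cont, runCont)

-- The step-by-step version used throughout this article.
fmapPointful :: (a -> b) -> Cont r a -> Cont r b
fmapPointful f cx = cont $ \k ->
    runCont cx $ \x ->
    k (f x)

-- The compact version: compose the continuation with f.
fmapPointfree :: (a -> b) -> Cont r a -> Cont r b
fmapPointfree f cx = cont $ \k -> runCont cx (k . f)
```

For example, runCont (fmapPointful (+ 1) (pure 41)) id and runCont (fmapPointfree (+ 1) (pure 41)) id both give 42, as does the library's own fmap.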

Acknowledgements

I would probably still be in the dark about continuations if it wasn't for #haskell, most notably due to parcs' and Chris Done's explanations. Thanks for that!