diff --git a/Data/Array/Accelerate/Trafo/Simplify.hs b/Data/Array/Accelerate/Trafo/Simplify.hs index 617edf3ff..0022703c7 100644 --- a/Data/Array/Accelerate/Trafo/Simplify.hs +++ b/Data/Array/Accelerate/Trafo/Simplify.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE BangPatterns #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} @@ -28,6 +27,7 @@ import Data.List ( nubBy ) import Data.Maybe import Data.Monoid import Data.Typeable +import Text.Printf import Control.Applicative hiding ( Const ) import Prelude hiding ( exp, iterate ) @@ -301,15 +301,20 @@ simplifyOpenExp env = first getAny . cvtE -- If we are projecting elements from a tuple structure or tuple of constant -- valued tuple, pick out the appropriate component directly. -- + -- Follow variable bindings, but only if they result in a simplification. + -- prj :: forall s t. (Elt s, Elt t, IsTuple t) => TupleIdx (TupleRepr t) s -> (Any, PreOpenExp acc env aenv t) -> (Any, PreOpenExp acc env aenv s) prj ix exp@(_,exp') - | Tuple t <- exp' = Stats.inline "prj/Tuple" . yes $ prjT ix t - | Const c <- exp' = Stats.inline "prj/Const" . yes $ prjC ix (fromTuple (toElt c :: t)) - | Let a b <- exp' = Stats.ruleFired "prj/Let" $ cvtE (Let a (Prj ix b)) - | otherwise = Prj ix <$> exp + | Tuple t <- exp' = Stats.inline "prj/Tuple" . yes $ prjT ix t + | Const c <- exp' = Stats.inline "prj/Const" . yes $ prjC ix (fromTuple (toElt c :: t)) + | Var v <- exp' + , e <- prjExp v env + , Nothing <- match exp' e + , (Any True, c) <- prj ix (pure e) = Stats.inline "prj/Var" . yes $ c + | otherwise = Prj ix <$> exp where prjT :: TupleIdx tup s -> Tuple (PreOpenExp acc env aenv) tup -> PreOpenExp acc env aenv s prjT ZeroTupIdx (SnocTup _ e) = e @@ -421,34 +426,35 @@ iterate -> (f a -> (Bool, f a)) -> f a -> f a -iterate ppr f = fix 0 . setup . simplify' +iterate ppr f = fix 1 . setup where -- The maximum number of simplifier iterations. To be conservative and avoid -- excessive run times, we set this value very low. -- - lIMIT = 1 + lIMIT = 5 simplify' = Stats.simplifierDone . f - setup (_,x) = msg x x + setup x = Stats.trace Stats.dump_simpl_iterations (printf "simplifier begin:\n%s\n" (ppr x)) + $ snd (trace 0 "simplify" (simplify' x)) fix :: Int -> f a -> f a - fix !i !x0 - | i >= lIMIT = $internalWarning "iterate" "iteration limit reached" (x0 ==^ f x0) x0 + fix i x0 + | i > lIMIT = $internalWarning "iterate" "iteration limit reached" (x0 ==^ f x0) x0 | not shrunk = x1 | not simplified = x2 | otherwise = fix (i+1) x2 where - (shrunk, x1) = trace $ shrink' x0 - (simplified, x2) = trace $ simplify' x1 + (shrunk, x1) = trace i "shrink" $ shrink' x0 + (simplified, x2) = trace i "simplify" $ simplify' x1 -- debugging support -- u ==^ (_,v) = isJust (match u v) - trace v@(changed,x) - | changed = msg x v + trace i s v@(changed,x) + | changed = Stats.trace Stats.dump_simpl_iterations (msg i s x) v | otherwise = v - msg :: f a -> x -> x - msg x next = Stats.trace Stats.dump_simpl_iterations (unlines [ "simplifier done", ppr x ]) next + msg :: Int -> String -> f a -> String + msg i s x = printf "%s [%d/%d]:\n%s\n" s i lIMIT (ppr x)