Skip to content
Permalink
master
Switch branches/tags
Go to file
 
 
Cannot retrieve contributors at this time
{-# language OverloadedStrings #-}
{-# language DataKinds #-}
{-# language BangPatterns #-}
module OptimizeTailRecursion where
import Control.Applicative ((<|>))
import Control.Lens.Cons (_last, _init)
import Control.Lens.Fold ((^..), (^?), (^?!), allOf, anyOf, folded, foldrOf)
import Control.Lens.Getter ((^.), to)
import Control.Lens.Plated (cosmos, transform, transformOn)
import Control.Lens.Prism (_Just)
import Control.Lens.Review ((#))
import Control.Lens.Setter ((%~), (.~))
import Control.Lens.Tuple (_2, _3)
import Data.Foldable (toList)
import Data.Function ((&))
import Data.Semigroup ((<>))
import Language.Python.Optics
import Language.Python.DSL
import Language.Python.Syntax.Expr (Expr (..), _Exprs, argExpr, paramName)
import Language.Python.Syntax.Statement (CompoundStatement (..), Statement (..), SmallStatement (..), SimpleStatement (..), _Statements)
-- | Rewrite a self tail-recursive Python @def@ into an equivalent function
-- whose body runs in a @while True@ loop instead of recursing.
--
-- Returns 'Nothing' when the statement is not a function definition, when its
-- body holds no statement, or when its final statement does not end in a self
-- tail call (see @hasTC@).
--
-- For a function with parameters @p1, ..., pn@ the new body is
--
-- > p1__tr = p1
-- > ...
-- > __res__tr = None
-- > while True:
-- >     <original body with every pi renamed to pi__tr,
-- >      its final statement rewritten by looped>
-- > return __res__tr
--
-- NOTE(review): arguments of a recursive call are matched to parameters by
-- position only (each argument's expression is taken, any keyword name is
-- ignored); if a call passes fewer arguments than there are parameters, the
-- remaining parameters keep their previous values instead of their defaults.
optimizeTailRecursion :: Raw Statement -> Maybe (Raw Statement)
optimizeTailRecursion st = do
  function <- st ^? _Fundef
  (bodyInit, bodyLast, bodyTrailing) <- splitLastStatement (function ^. body_)
  let
    functionName = function ^. fdName.identValue
    paramNames = function ^.. fdParameters.folded.paramName.identValue
  if not $ hasTC functionName bodyLast
    then Nothing
    else
      Just $
      _Fundef #
        (function &
         body_ .~
           (fmap (\a -> line_ (var_ (a <> "__tr") .= var_ a)) paramNames <>
            [ line_ ("__res__tr" .= none_)
            , line_ . while_ true_ .
              transformOn (traverse._Exprs) (renameIn paramNames "__tr") $
              bodyInit <>
              looped functionName paramNames bodyLast <>
              bodyTrailing
            , line_ $ return_ "__res__tr"
            ]))
  where
    -- | Split a block's lines around its last statement: the lines before it,
    -- the statement itself, and the lines after it (which hold no statement,
    -- e.g. blank lines). 'Nothing' when no line holds a statement.
    splitLastStatement
      :: [Raw Line] -> Maybe ([Raw Line], Raw Statement, [Raw Line])
    splitLastStatement = go [] . reverse
      where
        -- Walk backwards from the end, collecting statement-free lines (kept
        -- in their original order) until the first statement is found.
        go _ [] = Nothing
        go after (l : rest) =
          case l ^? _Statements of
            Just s -> Just (reverse rest, s, after)
            Nothing -> go (l : after) rest

    -- | Is @e@ itself a call whose callee is the bare identifier @name@,
    -- e.g. @name(n - 1)@? Calls to @name@ nested inside the arguments do not
    -- matter here; they remain ordinary recursive calls.
    isTailCall :: String -> Raw Expr -> Bool
    isTailCall name e
      | anyOf (cosmos._Call.callFunction._Ident.identValue) (== name) e
      = (e ^? _Call.callFunction._Ident.identValue) == Just name
      | otherwise = False

    -- | Does statement @st@ end in a self tail call to @name@?
    --
    -- * A line of simple statements qualifies when its last simple statement
    --   is @return name(...)@ or the expression statement @name(...)@.
    -- * An @if@ without @elif@s qualifies when the last statement of its body,
    --   or of its @else@ body if there is one, qualifies.
    --
    -- Everything else, including any @if@ with @elif@s, is rejected.
    hasTC :: String -> Raw Statement -> Bool
    hasTC name st =
      case st of
        CompoundStatement (If _ _ _ _ sts [] sts') ->
          endsInTC sts || anyOf (_Just._3) endsInTC sts'
        SmallStatement _ (MkSmallStatement s ss _ _ _) ->
          case last (s : fmap (^. _2) ss) of
            Return _ _ (Just e) -> isTailCall name e
            Expr _ e -> isTailCall name e
            _ -> False
        _ -> False
      where
        -- 'anyOf' for the optional else: an absent else has no targets, and
        -- 'allOf' over no targets is vacuously True, which would accept every
        -- else-less if.
        endsInTC suite = allOf _last (hasTC name) (suite ^.. _Statements)

    -- | Append @suffix@ to every identifier in an expression, at any depth,
    -- whose name is one of @params@.
    renameIn :: [String] -> String -> Raw Expr -> Raw Expr
    renameIn params suffix =
      transform
        (_Ident.identValue %~ (\a -> if a `elem` params then a <> suffix else a))

    -- | Rewrite @st@, the last statement of one pass through the loop body,
    -- so that control never falls off the end of the loop:
    --
    -- * a self tail call (@return name(...)@ or @name(...)@) rebinds the
    --   @__tr@ variables from its arguments and lets the loop run again;
    -- * any other @return e@ stores @e@ in @__res__tr@ and breaks (a bare
    --   @return@ just breaks, leaving @__res__tr@ at its initial @None@);
    -- * an @if@ accepted by 'hasTC' has the last statement of each branch
    --   rewritten recursively, and a missing @else@ becomes @else: break@;
    -- * any other statement is kept and followed by @break@, because in the
    --   original function control would run off the end there (returning
    --   @None@).
    --
    -- NOTE(review): an expression-statement tail call @name(...)@ is handled
    -- like @return name(...)@. Python discards that call's result and returns
    -- @None@, so the rewrite only agrees with the original when the
    -- recursion's base case also returns @None@.
    looped :: String -> [String] -> Raw Statement -> [Raw Line]
    looped name params st
      | Just ifSt <- st ^? _If
      , hasTC name st =
          let
            thenLines = loopedBlock name params (toList $ ifSt ^. body_)
            elseLines =
              maybe
                [line_ break_]
                (loopedBlock name params . toList)
                (ifSt ^? to getElse._Just.body_)
          in
            [line_ $ if_ (ifSt ^. ifCond) thenLines & else_ elseLines]
      | otherwise =
          case st of
            CompoundStatement{} -> [line_ st, line_ break_]
            SmallStatement idnts (MkSmallStatement s ss sc cmt nl) ->
              let
                -- Split @a; b; c@ into a line holding @a; b@ (kept as is) and
                -- the final simple statement @c@, which gets rewritten.
                (prefix, lastSimple) =
                  case reverse ss of
                    [] -> ([], s)
                    final : revInit ->
                      ( [ line_ $
                          SmallStatement idnts
                            (MkSmallStatement s (reverse revInit) sc cmt nl)
                        ]
                      , final ^. _2
                      )
              in
                case lastSimple of
                  Return _ _ (Just e)
                    | isTailCall name e -> prefix <> rebindParams params e
                  Return _ _ e ->
                    prefix <>
                    maybe [] (\e' -> [line_ ("__res__tr" .= e')]) e <>
                    [line_ break_]
                  Expr _ e
                    | isTailCall name e -> prefix <> rebindParams params e
                  _ -> [line_ st, line_ break_]

    -- | Rewrite a branch body so that its last statement goes through
    -- 'looped'; statement-free lines after it are kept in place. A body with
    -- no statement at all simply gets a trailing @break@.
    loopedBlock :: String -> [String] -> [Raw Line] -> [Raw Line]
    loopedBlock name params ls =
      case splitLastStatement ls of
        Nothing -> ls <> [line_ break_]
        Just (before, lastSt, after) ->
          before <> looped name params lastSt <> after

    -- | Lines performing the self tail call @e@ (a call accepted by
    -- 'isTailCall'): copy every @p__tr@ into @p__tr__old@, then assign each
    -- @p__tr@ from the argument in the same position. Parameter names inside
    -- the arguments are renamed to their @__tr__old@ copies, so the
    -- assignments behave as one simultaneous update. Emits a single @pass@
    -- when there are no parameters, so the enclosing block is never empty.
    rebindParams :: [String] -> Raw Expr -> [Raw Line]
    rebindParams params e =
      case saves <> updates of
        [] -> [line_ pass_]
        rebinds -> rebinds
      where
        args = e ^.. _Call.callArguments.folded.folded.argExpr
        saves =
          fmap
            (\a -> line_ (var_ (a <> "__tr__old") .= var_ (a <> "__tr")))
            params
        updates =
          zipWith
            (\a b -> line_ (var_ (a <> "__tr") .= b))
            params
            (transformOn traverse (renameIn params "__tr__old") args)