Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Separates IR types and transforms into different modules.

  • Loading branch information...
commit e10b39469b88b092016273acd0558abf140d6671 1 parent 484f798
@tomahawkins authored
Showing with 119 additions and 83 deletions.
  1. +119 −83 Language/CIRC.hs
View
202 Language/CIRC.hs
@@ -53,7 +53,7 @@ data TypeRefinement
-- | A transform is a module name, the constructor to be transformed, a list of new type definitions,
-- and the implementation (imports and code).
-data Transform = Transform ModuleName [Import] CtorName (ModuleName -> Code) [TypeRefinement]
+data Transform = Transform ModuleName [Import] [Import] CtorName (ModuleName -> Code) [TypeRefinement]
-- | An unparameterized type.
t :: String -> Type
@@ -61,102 +61,124 @@ t n = T n []
-- | Compiles a CIRC spec.
circ :: Spec -> IO ()
-circ (Spec initModuleName initImports rootTypeName types transforms) = do
- maybeWriteFile (initModuleName ++ ".hs") $ codeModule' initModuleName types Nothing
- foldM_ codeTransform (initModuleName, types) transforms
+circ (Spec initModuleName initImports rootTypeName typeDefsUnsorted transforms) = do
+ maybeWriteFile (initModuleName ++ ".hs") $ codeTypeModule initModuleName initImports typeDefs
+ maybeWriteFile (initModuleName ++ "Trans.hs") $ codeInitTransModule initModuleName rootTypeName
+ foldM_ codeTransform (initModuleName, typeDefs) transforms
where
- codeModule' = codeModule initModuleName initImports rootTypeName
+ typeDefs = sortTypeDefs typeDefsUnsorted
+
codeTransform :: (Name, [TypeDef]) -> Transform -> IO (Name, [TypeDef])
- codeTransform (prevName, prevTypes) (Transform currName localImports ctorName code typeMods) = do
- maybeWriteFile (currName ++ ".hs") $ codeModule' currName currTypes $ Just (prevName, localImports, prevTypes, ctorName, code prevName, [ (ctor, code prevName)| NewCtor _ (CtorDef ctor _) code <- typeMods ])
- return (currName, currTypes)
+ codeTransform (prevModuleName, prevTypeDefs) (Transform moduleName typeImports transImports ctorName transCode typeRefinements) = do
+ maybeWriteFile (moduleName ++ ".hs") $ codeTypeModule moduleName typeImports typeDefs
+ maybeWriteFile (moduleName ++ "Trans.hs") $ codeTransModule
+ initModuleName
+ rootTypeName
+ prevModuleName
+ prevTypeDefs
+ moduleName
+ transImports
+ typeDefs
+ ctorName
+ (transCode prevModuleName)
+ [ (ctorName, transCode prevModuleName) | NewCtor _ (CtorDef ctorName _) transCode <- typeRefinements ]
+ return (moduleName, typeDefs)
where
- filteredCtor = [ TypeDef name params [ CtorDef ctorName' args | CtorDef ctorName' args <- ctors, ctorName /= ctorName' ] | TypeDef name params ctors <- prevTypes ]
- currTypes = filterRelevantTypes rootTypeName $ nextTypes filteredCtor typeMods
-
- maybeWriteFile :: FilePath -> String -> IO ()
- maybeWriteFile file contents = do
- a <- doesFileExist file
- if not a then writeFile file contents else do
- f <- openFile file ReadMode
- contents' <- hGetContents f
- if contents' == contents
- then do
- hClose f
- return ()
- else do
- hClose f
- writeFile file contents
+ filteredCtor = [ TypeDef name params [ CtorDef ctorName' args | CtorDef ctorName' args <- ctors, ctorName /= ctorName' ] | TypeDef name params ctors <- prevTypeDefs ]
+ typeDefs = sortTypeDefs $ filterRelevantTypes rootTypeName $ nextTypes filteredCtor typeRefinements
+
+-- | Write out a file if the file doesn't exist or is different. Doesn't bump the timestamp for Makefile-like build systems.
+maybeWriteFile :: FilePath -> String -> IO ()
+maybeWriteFile file contents = do
+ a <- doesFileExist file
+ if not a then writeFile file contents else do
+ f <- openFile file ReadMode
+ contents' <- hGetContents f
+ if contents' == contents
+ then do
+ hClose f
+ return ()
+ else do
+ hClose f
+ writeFile file contents
+-- | Sort a list of TypeDefs by type name.
sortTypeDefs :: [TypeDef] -> [TypeDef]
sortTypeDefs = sortBy (compare `on` \ (TypeDef n _ _) -> n)
-codeModule :: ModuleName -> [String] -> TypeName -> ModuleName -> [TypeDef] -> Maybe (ModuleName, [Import], [TypeDef], CtorName, Code, [(CtorName, Code)]) -> String
-codeModule initModuleName initImports rootTypeName moduleName unsortedTypes trans = unlines $
+-- | Code the module that contains the IR datatype definitions.
+codeTypeModule :: ModuleName -> [Import] -> [TypeDef] -> String
+codeTypeModule moduleName imports typeDefs = unlines $
[ printf "module %s" moduleName
- , " ( " ++ intercalate "\n , " [ name ++ " (..)"| TypeDef name _ _ <- currTypes ]
- , " , transform"
- , " , transform'"
+ , " ( " ++ intercalate "\n , " [ name ++ " (..)"| TypeDef name _ _ <- typeDefs ]
, " ) where"
, ""
- , "import Language.CIRC.Runtime"
- ] ++ nub (case trans of { Nothing -> initImports; Just (m, i, _, _, _, _) -> ["import qualified " ++ initModuleName, "import qualified " ++ m] ++ i}) ++
- [ ""
- ] ++ (map codeTypeDef currTypes) ++
- case trans of
- Nothing ->
- [ printf "transform :: %s -> CIRC (%s, [%s])" rootTypeName rootTypeName rootTypeName
- , "transform a = return (a, [a])"
- , ""
- , printf "transform' :: %s -> CIRC %s" rootTypeName rootTypeName
- , "transform' = return"
- , ""
- ]
- Just (prevName, _, prevTypes, ctor, code, backwards) ->
- [ printf "transform :: %s.%s -> CIRC (%s, [%s.%s])" initModuleName rootTypeName rootTypeName initModuleName rootTypeName
- , printf "transform a = do"
- , printf " (a, b) <- %s.transform a" prevName
- , printf " a <- trans%s a" rootTypeName
- , printf " c <- transform' a"
- , printf " return (a, b ++ [c])"
- , printf ""
- , printf "transform' :: %s -> CIRC %s.%s" rootTypeName initModuleName rootTypeName
- , printf "transform' a = trans%s' a >>= %s.transform'" rootTypeName prevName
- , printf ""
- , codeTypeTransforms prevName prevTypes currTypes (ctor, code) backwards
- ]
- where
- currTypes = sortTypeDefs unsortedTypes
-
-codeTypeDef :: TypeDef -> String
-codeTypeDef (TypeDef name params ctors) = "data " ++ name ++ " " ++ intercalate " " params ++ "\n = " ++
- intercalate "\n | " [ name ++ replicate (m - length name) ' ' ++ " " ++ intercalate " " (map codeType args) | CtorDef name args <- ctors' ] ++ "\n"
+ ] ++ nub (["import Language.CIRC.Runtime"] ++ imports) ++ [""] ++ map codeTypeDef typeDefs
where
- ctors' = sortBy (compare `on` \ (CtorDef n _) -> n) ctors
- m = maximum [ length n | CtorDef n _ <- ctors ]
+ codeTypeDef :: TypeDef -> String
+ codeTypeDef (TypeDef name params ctors) = "data " ++ name ++ " " ++ intercalate " " params ++ "\n = " ++
+ intercalate "\n | " [ name ++ replicate (m - length name) ' ' ++ " " ++ intercalate " " (map codeType args) | CtorDef name args <- ctors' ] ++ "\n"
+ where
+ ctors' = sortBy (compare `on` \ (CtorDef n _) -> n) ctors
+ m = maximum [ length n | CtorDef n _ <- ctors ]
+
+ codeType :: Type -> String
+ codeType a = case a of
+ T name [] -> name
+ T name params -> "(" ++ name ++ intercalate " " (map codeType params) ++ ")"
+ TList a -> "[" ++ codeType a ++ "]"
+ TMaybe a -> "(Maybe " ++ codeType a ++ ")"
+ TTuple a -> "(" ++ intercalate ", " (map codeType a) ++ ")"
-codeType :: Type -> String
-codeType a = case a of
- T name [] -> name
- T name params -> "(" ++ name ++ intercalate " " (map codeType params) ++ ")"
- TList a -> "[" ++ codeType a ++ "]"
- TMaybe a -> "(Maybe " ++ codeType a ++ ")"
- TTuple a -> "(" ++ intercalate ", " (map codeType a) ++ ")"
+-- | Code the initial transform module.
+codeInitTransModule :: ModuleName -> TypeName -> String
+codeInitTransModule moduleName rootTypeName = unlines
+ [ printf "module %sTrans" moduleName
+ , printf " ( transform"
+ , printf " , transform'"
+ , printf " ) where"
+ , printf ""
+ , printf "import Language.CIRC.Runtime (CIRC)"
+ , printf "import %s (%s)" moduleName rootTypeName
+ , printf ""
+ , printf "transform :: %s -> CIRC (%s, [%s])" rootTypeName rootTypeName rootTypeName
+ , printf "transform a = return (a, [a])"
+ , printf ""
+ , printf "transform' :: %s -> CIRC %s" rootTypeName rootTypeName
+ , printf "transform' = return"
+ , printf ""
+ ]
--- | Computes the next type definitions given a list of type definitions and a list of type refinements.
-nextTypes :: [TypeDef] -> [TypeRefinement] -> [TypeDef]
-nextTypes old new = sortTypeDefs $ foldl nextType old new
- where
- nextType :: [TypeDef] -> TypeRefinement -> [TypeDef]
- nextType types refinement = case refinement of
- NewType t -> t : types
- NewCtor typeName ctorDef _ -> case match of
- [] -> error $ "Type not found: " ++ typeName
- _ : _ : _ -> error $ "Redundent type name: " ++ typeName
- [TypeDef _ params ctors] -> TypeDef typeName params (ctorDef : ctors) : rest
- where
- (match, rest) = partition (\ (TypeDef name _ _) -> name == typeName) types
+-- | Code the module that contains the IR transformations.
+codeTransModule :: ModuleName -> TypeName -> ModuleName -> [TypeDef] -> ModuleName -> [Import] -> [TypeDef] -> CtorName -> Code -> [(CtorName, Code)] -> String
+codeTransModule initModuleName rootTypeName prevModuleName prevTypeDefs moduleName imports typeDefs ctorName transCode backwardTransCode = unlines $
+ [ printf "module %sTrans" moduleName
+ , " ( transform"
+ , " , transform'"
+ , " ) where"
+ , ""
+ ] ++ nub (
+ [ "import Language.CIRC.Runtime"
+ , "import qualified " ++ initModuleName
+ , "import qualified " ++ prevModuleName
+ , "import qualified " ++ prevModuleName ++ "Trans"
+ , "import " ++ moduleName
+ ] ++ imports) ++
+ [ printf "transform :: %s.%s -> CIRC (%s, [%s.%s])" initModuleName rootTypeName rootTypeName initModuleName rootTypeName
+ , printf "transform a = do"
+ , printf " (a, b) <- %sTrans.transform a" prevModuleName
+ , printf " a <- trans%s a" rootTypeName
+ , printf " c <- transform' a"
+ , printf " return (a, b ++ [c])"
+ , printf ""
+ , printf "transform' :: %s -> CIRC %s.%s" rootTypeName initModuleName rootTypeName
+ , printf "transform' a = trans%s' a >>= %sTrans.transform'" rootTypeName prevModuleName
+ , printf ""
+ , codeTypeTransforms prevModuleName prevTypeDefs typeDefs (ctorName, transCode) backwardTransCode
+ , printf ""
+ ]
+-- | Codes the type transform function.
codeTypeTransforms :: ModuleName -> [TypeDef] -> [TypeDef] -> (CtorName, Code) -> [(CtorName, Code)] -> String
codeTypeTransforms prevName prevTypes currTypes forwardTrans backwardTrans =
concatMap (codeTypeTransform prevTypes [forwardTrans] (\ t -> "trans" ++ t) qualified id) [ t | t@(TypeDef n _ _) <- prevTypes, elem n $ map typeDefName currTypes ] ++
@@ -208,6 +230,20 @@ primitiveTypes a = case a of
indent :: String -> String
indent = unlines . map (" " ++) . lines
+-- | Computes the next type definitions given a list of type definitions and a list of type refinements.
+nextTypes :: [TypeDef] -> [TypeRefinement] -> [TypeDef]
+nextTypes old new = foldl nextType old new
+ where
+ nextType :: [TypeDef] -> TypeRefinement -> [TypeDef]
+ nextType types refinement = case refinement of
+ NewType t -> t : types
+ NewCtor typeName ctorDef _ -> case match of
+ [] -> error $ "Type not found: " ++ typeName
+ _ : _ : _ -> error $ "Redundent type name: " ++ typeName
+ [TypeDef _ params ctors] -> TypeDef typeName params (ctorDef : ctors) : rest
+ where
+ (match, rest) = partition (\ (TypeDef name _ _) -> name == typeName) types
+
-- | Get rid of types that are not relevant to the root type.
filterRelevantTypes :: TypeName -> [TypeDef] -> [TypeDef]
filterRelevantTypes rootTypeName types = [ t | t@(TypeDef n _ _) <- types, elem n required ]
Please sign in to comment.
Something went wrong with that request. Please try again.