Skip to content

Commit

Permalink
Deriving Hashable
Browse files Browse the repository at this point in the history
  • Loading branch information
fehrenbach committed Jan 30, 2021
1 parent e56d28b commit ad31518
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/Language/PureScript/Constants/Prelude.hs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ returnEscaped = "$return"
unit :: forall a. (IsString a) => a
unit = "unit"

hashWithSalt :: forall a. (IsString a) => a
hashWithSalt = "hashWithSalt"

-- Core lib values

runST :: forall a. (IsString a) => a
Expand Down
72 changes: 72 additions & 0 deletions src/Language/PureScript/Sugar/TypeClasses/Deriving.hs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ deriveInstance
-> Declaration
-> m Declaration
deriveInstance mn syns kinds _ ds (TypeInstanceDeclaration sa@(ss, _) ch idx nm deps className tys DerivedInstance)
| className == Qualified (Just dataHashable) (ProperName "Hashable")
= case tys of
[ty] | Just (Qualified mn' tyCon, _) <- unwrapTypeConstructor ty
, mn == fromMaybe mn mn'
-> TypeInstanceDeclaration sa ch idx nm deps className tys . ExplicitInstance <$> deriveHashable ss mn syns kinds ds tyCon
| otherwise -> throwError . errorMessage' ss $ ExpectedTypeConstructor className tys ty
_ -> throwError . errorMessage' ss $ InvalidDerivedInstance className tys 1
| className == Qualified (Just dataEq) (ProperName "Eq")
= case tys of
[ty] | Just (Qualified mn' tyCon, _) <- unwrapTypeConstructor ty
Expand Down Expand Up @@ -260,6 +267,9 @@ dataOrd = ModuleName "Data.Ord"
dataFunctor :: ModuleName
dataFunctor = ModuleName "Data.Functor"

dataHashable :: ModuleName
dataHashable = ModuleName "Data.Hashable"

unguarded :: Expr -> [GuardedExpr]
unguarded e = [MkUnguarded e]

Expand Down Expand Up @@ -554,6 +564,68 @@ deriveOrd1 ss =
dataOrdCompare :: Expr
dataOrdCompare = Var ss (Qualified (Just dataOrd) (Ident Prelude.compare))

deriveHashable :: forall m
. (MonadError MultipleErrors m, MonadSupply m)
=> SourceSpan
-> ModuleName
-> SynonymMap
-> KindMap
-> [Declaration]
-> ProperName 'TypeName
-> m [Declaration]
deriveHashable ss mn syns kinds ds tyConNm = do
tyCon <- findTypeDecl ss tyConNm ds
hashFun <- mkHashFunction tyCon
return [ ValueDecl (ss, []) (Ident Prelude.hashWithSalt) Public [] (unguarded hashFun) ]
where
mkHashFunction :: Declaration -> m Expr
mkHashFunction (DataDeclaration (ss', _) _ _ _ args) = do
s <- freshIdent "s"
x <- freshIdent "x"
lam ss s <$> lamCase ss' x <$> mapM (mkCase s) (zip args [0..])
mkHashFunction _ = internalError "mkHashFunction: expected DataDeclaration"

mkCase :: Ident -> (DataConstructorDeclaration, Int) -> m CaseAlternative
mkCase s (DataConstructorDeclaration _ dataCtorName dataCtorFields, nth) = do
xs <- replicateM (length dataCtorFields) (freshIdent "x")
let binder = ConstructorBinder ss (Qualified (Just mn) dataCtorName) (map (VarBinder ss) xs)
dataCtorFields' <- mapM (replaceAllTypeSynonymsM syns kinds . snd) dataCtorFields

-- collect all record and constructor fields
let fields = collectFields (zip (map var xs) dataCtorFields')
-- generate names for hash results (|fields| + 1 for constructor discrimination and initial salt)
hashes <- replicateM (length fields + 1) (freshIdent "hash")

-- build a chain of hashWithSalt calls, starting with the seed s and the
-- constructor discriminator nth and then successively feeding the output
-- as the seed to a further hashWithSalt call for each field
let chain = foldr callHashWithSalt (wrapHash (var (last hashes)))
(zip3 (s:hashes) (Literal ss (NumericLiteral (Left (toEnum nth))):fields) hashes)

return $ CaseAlternative [binder] (unguarded chain)

-- Creates this expression: case $prev `hashWithSalt` $expression of
-- Hash $next -> $body
callHashWithSalt (prev, expression, next) body =
Case [ App (App (Var ss (Qualified (Just dataHashable) (Ident Prelude.hashWithSalt))) (var prev)) expression ]
[ CaseAlternative [ hashBinder next ] (unguarded body) ]

collectFields [] = []
collectFields ((e,ty):rest)
| Just rec <- objectType ty
, Just fields <- decomposeRec rec
= collectFields (map (\((Label str), fieldTy) -> (Accessor str e, fieldTy)) fields) ++ collectFields rest
| otherwise = e:collectFields rest

var = Var ss . Qualified Nothing

dataHashableHashNewtypeConstructor = (Qualified (Just dataHashable) (ProperName "Hash"))

wrapHash = App (Constructor ss dataHashableHashNewtypeConstructor)

hashBinder :: Ident -> Binder
hashBinder hx = ConstructorBinder ss dataHashableHashNewtypeConstructor [ VarBinder ss hx ]

deriveNewtype
:: forall m
. (MonadError MultipleErrors m, MonadSupply m)
Expand Down

0 comments on commit ad31518

Please sign in to comment.