Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deriving Hashable #3446

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Language/PureScript/Constants/Prelude.hs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,9 @@ returnEscaped = "$return"
unit :: forall a. (IsString a) => a
unit = "unit"

hashWithSalt :: forall a. (IsString a) => a
hashWithSalt = "hashWithSalt"

-- Core lib values

runST :: forall a. (IsString a) => a
Expand Down
72 changes: 72 additions & 0 deletions src/Language/PureScript/Sugar/TypeClasses/Deriving.hs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,13 @@ deriveInstance
-> Declaration
-> m Declaration
deriveInstance mn syns kinds _ ds (TypeInstanceDeclaration sa@(ss, _) ch idx nm deps className tys DerivedInstance)
| className == Qualified (Just dataHashable) (ProperName "Hashable")
= case tys of
[ty] | Just (Qualified mn' tyCon, _) <- unwrapTypeConstructor ty
, mn == fromMaybe mn mn'
-> TypeInstanceDeclaration sa ch idx nm deps className tys . ExplicitInstance <$> deriveHashable ss mn syns kinds ds tyCon
| otherwise -> throwError . errorMessage' ss $ ExpectedTypeConstructor className tys ty
_ -> throwError . errorMessage' ss $ InvalidDerivedInstance className tys 1
| className == Qualified (Just dataEq) (ProperName "Eq")
= case tys of
[ty] | Just (Qualified mn' tyCon, _) <- unwrapTypeConstructor ty
Expand Down Expand Up @@ -260,6 +267,9 @@ dataOrd = ModuleName "Data.Ord"
dataFunctor :: ModuleName
dataFunctor = ModuleName "Data.Functor"

dataHashable :: ModuleName
dataHashable = ModuleName "Data.Hashable"

unguarded :: Expr -> [GuardedExpr]
unguarded e = [MkUnguarded e]

Expand Down Expand Up @@ -554,6 +564,68 @@ deriveOrd1 ss =
dataOrdCompare :: Expr
dataOrdCompare = Var ss (Qualified (Just dataOrd) (Ident Prelude.compare))

deriveHashable :: forall m
. (MonadError MultipleErrors m, MonadSupply m)
=> SourceSpan
-> ModuleName
-> SynonymMap
-> KindMap
-> [Declaration]
-> ProperName 'TypeName
-> m [Declaration]
deriveHashable ss mn syns kinds ds tyConNm = do
tyCon <- findTypeDecl ss tyConNm ds
hashFun <- mkHashFunction tyCon
return [ ValueDecl (ss, []) (Ident Prelude.hashWithSalt) Public [] (unguarded hashFun) ]
where
mkHashFunction :: Declaration -> m Expr
mkHashFunction (DataDeclaration (ss', _) _ _ _ args) = do
s <- freshIdent "s"
x <- freshIdent "x"
lam ss s <$> lamCase ss' x <$> mapM (mkCase s) (zip args [0..])
mkHashFunction _ = internalError "mkHashFunction: expected DataDeclaration"

mkCase :: Ident -> (DataConstructorDeclaration, Int) -> m CaseAlternative
mkCase s (DataConstructorDeclaration _ dataCtorName dataCtorFields, nth) = do
xs <- replicateM (length dataCtorFields) (freshIdent "x")
let binder = ConstructorBinder ss (Qualified (Just mn) dataCtorName) (map (VarBinder ss) xs)
dataCtorFields' <- mapM (replaceAllTypeSynonymsM syns kinds . snd) dataCtorFields

-- collect all record and constructor fields
let fields = collectFields (zip (map var xs) dataCtorFields')
-- generate names for hash results (|fields| + 1 for constructor discrimination and initial salt)
hashes <- replicateM (length fields + 1) (freshIdent "hash")

-- build a chain of hashWithSalt calls, starting with the seed s and the
-- constructor discriminator nth and then successively feeding the output
-- as the seed to a further hashWithSalt call for each field
let chain = foldr callHashWithSalt (wrapHash (var (last hashes)))
(zip3 (s:hashes) (Literal ss (NumericLiteral (Left (toEnum nth))):fields) hashes)

return $ CaseAlternative [binder] (unguarded chain)

-- Creates this expression: case $prev `hashWithSalt` $expression of
-- Hash $next -> $body
callHashWithSalt (prev, expression, next) body =
Case [ App (App (Var ss (Qualified (Just dataHashable) (Ident Prelude.hashWithSalt))) (var prev)) expression ]
[ CaseAlternative [ hashBinder next ] (unguarded body) ]

collectFields [] = []
collectFields ((e,ty):rest)
| Just rec <- objectType ty
, Just fields <- decomposeRec rec
= collectFields (map (\((Label str), fieldTy) -> (Accessor str e, fieldTy)) fields) ++ collectFields rest
| otherwise = e:collectFields rest

var = Var ss . Qualified Nothing

dataHashableHashNewtypeConstructor = (Qualified (Just dataHashable) (ProperName "Hash"))

wrapHash = App (Constructor ss dataHashableHashNewtypeConstructor)

hashBinder :: Ident -> Binder
hashBinder hx = ConstructorBinder ss dataHashableHashNewtypeConstructor [ VarBinder ss hx ]

deriveNewtype
:: forall m
. (MonadError MultipleErrors m, MonadSupply m)
Expand Down