Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow fields to occur in multiple constructors #2

Merged
merged 1 commit into from
Jul 25, 2016
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 70 additions & 25 deletions src/Data/OverloadedRecords/TH/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,10 @@ module Data.OverloadedRecords.TH.Internal
)
where

import Prelude (Num((-)), fromIntegral)
import Prelude (Num((-)), fromIntegral, fst, unzip, lookup, Eq((==)))

import Control.Applicative (Applicative((<*>)))
import Control.Arrow (Arrow((***)))
import Control.Monad (Monad((>>=) _FAIL_IN_MONAD, return), replicateM)
#if 0
#if HAVE_MONAD_FAIL && MIN_VERSION_template_haskell(2,11,0)
Expand All @@ -94,7 +95,7 @@ import Control.Monad.Fail (MonadFail(fail))
import Data.Bool (Bool(False), otherwise)
import qualified Data.Char as Char (toLower)
import Data.Foldable (concat, foldl)
import Data.Function ((.), ($))
import Data.Function ((.), ($), flip)
import Data.Functor (Functor(fmap), (<$>))
import qualified Data.List as List
( drop
Expand All @@ -104,10 +105,10 @@ import qualified Data.List as List
, replicate
, zip
)
import Data.Maybe (Maybe(Just, Nothing), fromMaybe)
import Data.Maybe (Maybe(Just, Nothing), fromMaybe, catMaybes)
import Data.Monoid ((<>))
import Data.String (String)
import Data.Traversable (forM, mapM)
import Data.Traversable (mapM, sequence)
import Data.Typeable (Typeable)
import Data.Word (Word)
import GHC.Generics (Generic)
Expand Down Expand Up @@ -318,6 +319,8 @@ instance Default DeriveOverloadedRecordsParams where
}

-- | Derive magic OverloadedRecordFields instances for specified type.
-- Fails if different record fields within the same type would map to the
-- same overloaded label.
overloadedRecord
:: DeriveOverloadedRecordsParams
-- ^ Parameters for customization of deriving process. Use 'def' to get
Expand All @@ -334,14 +337,19 @@ overloadedRecord params = withReified $ \name -> \case
#else
NewtypeD [] typeName typeVars constructor _deriving ->
#endif
deriveForConstructor params typeName typeVars constructor
fst $ deriveForConstructor params [] typeName typeVars constructor
#if MIN_VERSION_template_haskell(2,11,0)
DataD [] typeName typeVars _kindSignature constructors _deriving ->
#else
DataD [] typeName typeVars constructors _deriving ->
#endif
fmap concat . forM constructors
$ deriveForConstructor params typeName typeVars
fst $ foldl go (return [], []) constructors
where
go :: (DecsQ, [(String, String)]) -> Con -> (DecsQ, [(String, String)])
go (decs, seen) con =
let (decs', seen') =
deriveForConstructor params seen typeName typeVars con
in ((<>) <$> decs <*> decs', seen <> seen')
x -> canNotDeriveError name x

x -> canNotDeriveError name x
Expand Down Expand Up @@ -411,11 +419,14 @@ deriveForConstructor
:: DeriveOverloadedRecordsParams
-- ^ Parameters for customization of deriving process. Use 'def' to get
-- default behaviour.
-> [(String, String)]
-- ^ Pairs of instances already generated along with the field names
-- they were made from.
-> Name
-> [TyVarBndr]
-> Con
-> DecsQ
deriveForConstructor params name typeVars = \case
-> (DecsQ, [(String, String)])
deriveForConstructor params seen name typeVars = \case
NormalC constructorName args ->
deriveFor constructorName args $ \(strict, argType) f ->
f Nothing strict argType
Expand All @@ -429,22 +440,23 @@ deriveForConstructor params name typeVars = \case
f Nothing strict argType

#if MIN_VERSION_template_haskell(2,11,0)
GadtC _ _ _ -> fail "GADTs aren't yet supported."
RecGadtC _ _ _ -> fail "GADTs aren't yet supported."
GadtC _ _ _ -> (fail "GADTs aren't yet supported.", [])
RecGadtC _ _ _ -> (fail "GADTs aren't yet supported.", [])
#endif

-- Existentials aren't supported.
ForallC _typeVariables _context _constructor -> return []
ForallC _typeVariables _context _constructor -> (return [], [])
where
deriveFor
:: Name
-> [a]
-> (a -> (Maybe Name -> Strict -> Type -> DecsQ) -> DecsQ)
-> DecsQ
-> (a -> (Maybe Name -> Strict -> Type -> (DecsQ, Maybe (String, String)))
-> (DecsQ, Maybe (String, String)))
-> (DecsQ, [(String, String)])
deriveFor constrName args f =
fmap concat . forM (withIndexes args) $ \(idx, arg) ->
concatBoth . flip fmap (withIndexes args) $ \(idx, arg) ->
f arg $ \accessor strict fieldType' ->
deriveForField params DeriveFieldParams
deriveForField params seen DeriveFieldParams
{ typeName = name
, typeVariables = List.map getTypeName typeVars
, constructorName = constrName
Expand All @@ -460,6 +472,9 @@ deriveForConstructor params name typeVars = \case
PlainTV n -> n
KindedTV n _kind -> n

concatBoth :: [(Q [a], Maybe b)] -> (Q [a], [b])
concatBoth = (fmap concat . sequence *** catMaybes) . unzip

withIndexes = List.zip [(0 :: Word) ..]

-- | Parameters for 'deriveForField' function.
Expand Down Expand Up @@ -490,27 +505,57 @@ deriveForField
:: DeriveOverloadedRecordsParams
-- ^ Parameters for customization of deriving process. Use 'def' to get
-- default behaviour.
-> [(String, String)]
-- ^ Pairs of instances already generated along with the field names
-- they were made from.
-> DeriveFieldParams
-- ^ All the necessary information for derivation procedure.
-> DecsQ
deriveForField params DeriveFieldParams{..} =
-> (DecsQ, Maybe (String, String))
-- If instances were generated, then the second part is a pair
-- (instanceLabel, fieldLabel)
deriveForField params seen DeriveFieldParams{..} =
case possiblyLabel of
Nothing -> return []
Nothing -> (return [], Nothing)
Just (GetterOnlyField label customGetterExpr) ->
deriveGetter' (strTyLitT label)
$ fromMaybe derivedGetterExpr customGetterExpr
Just (GetterAndSetterField label customGetterAndSetterExpr) -> (<>)
<$> deriveGetter' labelType getterExpr
<*> deriveSetter' labelType setterExpr
case lookup label seen of
Just from ->
if Just from == accessorBase then
(return [],
Nothing)
else
(fail $ "Two different fields map to label \""
<> from <> "\"",
Nothing)
Nothing ->
(deriveGetter' (strTyLitT label)
$ fromMaybe derivedGetterExpr customGetterExpr,
(,) label <$> accessorBase)
Just (GetterAndSetterField label customGetterAndSetterExpr) ->
case lookup label seen of
Just from ->
if Just from == accessorBase then
(return [],
Nothing)
else
(fail $ "Two different fields map to label \""
<> from <> "\"",
Nothing)
Nothing ->
((<>)
<$> deriveGetter' labelType getterExpr
<*> deriveSetter' labelType setterExpr,
(,) label <$> accessorBase)
where
labelType = strTyLitT label

(getterExpr, setterExpr) =
fromMaybe (derivedGetterExpr, derivedSetterExpr)
customGetterAndSetterExpr
where
accessorBase = fmap nameBase accessorName

possiblyLabel = _fieldDerivation params (nameBase typeName)
(nameBase constructorName) currentIndex (fmap nameBase accessorName)
(nameBase constructorName) currentIndex accessorBase

deriveGetter' labelType =
deriveGetter labelType recordType (return fieldType)
Expand Down