Skip to content

Commit

Permalink
Merge pull request #2 from andy-morris/master
Browse files Browse the repository at this point in the history
Allow fields to occur in multiple constructors
  • Loading branch information
trskop committed Jul 25, 2016
2 parents bab85ce + 56fe8bf commit 9602be2
Showing 1 changed file with 70 additions and 25 deletions.
95 changes: 70 additions & 25 deletions src/Data/OverloadedRecords/TH/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,10 @@ module Data.OverloadedRecords.TH.Internal
)
where

import Prelude (Num((-)), fromIntegral)
import Prelude (Num((-)), fromIntegral, fst, unzip, lookup, Eq((==)))

import Control.Applicative (Applicative((<*>)))
import Control.Arrow (Arrow((***)))
import Control.Monad (Monad((>>=) _FAIL_IN_MONAD, return), replicateM)
#if 0
#if HAVE_MONAD_FAIL && MIN_VERSION_template_haskell(2,11,0)
Expand All @@ -94,7 +95,7 @@ import Control.Monad.Fail (MonadFail(fail))
import Data.Bool (Bool(False), otherwise)
import qualified Data.Char as Char (toLower)
import Data.Foldable (concat, foldl)
import Data.Function ((.), ($))
import Data.Function ((.), ($), flip)
import Data.Functor (Functor(fmap), (<$>))
import qualified Data.List as List
( drop
Expand All @@ -104,10 +105,10 @@ import qualified Data.List as List
, replicate
, zip
)
import Data.Maybe (Maybe(Just, Nothing), fromMaybe)
import Data.Maybe (Maybe(Just, Nothing), fromMaybe, catMaybes)
import Data.Monoid ((<>))
import Data.String (String)
import Data.Traversable (forM, mapM)
import Data.Traversable (mapM, sequence)
import Data.Typeable (Typeable)
import Data.Word (Word)
import GHC.Generics (Generic)
Expand Down Expand Up @@ -318,6 +319,8 @@ instance Default DeriveOverloadedRecordsParams where
}

-- | Derive magic OverloadedRecordFields instances for specified type.
-- Fails if different record fields within the same type would map to the
-- same overloaded label.
overloadedRecord
:: DeriveOverloadedRecordsParams
-- ^ Parameters for customization of deriving process. Use 'def' to get
Expand All @@ -334,14 +337,19 @@ overloadedRecord params = withReified $ \name -> \case
#else
NewtypeD [] typeName typeVars constructor _deriving ->
#endif
deriveForConstructor params typeName typeVars constructor
fst $ deriveForConstructor params [] typeName typeVars constructor
#if MIN_VERSION_template_haskell(2,11,0)
DataD [] typeName typeVars _kindSignature constructors _deriving ->
#else
DataD [] typeName typeVars constructors _deriving ->
#endif
fmap concat . forM constructors
$ deriveForConstructor params typeName typeVars
fst $ foldl go (return [], []) constructors
where
go :: (DecsQ, [(String, String)]) -> Con -> (DecsQ, [(String, String)])
go (decs, seen) con =
let (decs', seen') =
deriveForConstructor params seen typeName typeVars con
in ((<>) <$> decs <*> decs', seen <> seen')
x -> canNotDeriveError name x

x -> canNotDeriveError name x
Expand Down Expand Up @@ -411,11 +419,14 @@ deriveForConstructor
:: DeriveOverloadedRecordsParams
-- ^ Parameters for customization of deriving process. Use 'def' to get
-- default behaviour.
-> [(String, String)]
-- ^ Pairs of instances already generated along with the field names
-- they were made from.
-> Name
-> [TyVarBndr]
-> Con
-> DecsQ
deriveForConstructor params name typeVars = \case
-> (DecsQ, [(String, String)])
deriveForConstructor params seen name typeVars = \case
NormalC constructorName args ->
deriveFor constructorName args $ \(strict, argType) f ->
f Nothing strict argType
Expand All @@ -429,22 +440,23 @@ deriveForConstructor params name typeVars = \case
f Nothing strict argType

#if MIN_VERSION_template_haskell(2,11,0)
GadtC _ _ _ -> fail "GADTs aren't yet supported."
RecGadtC _ _ _ -> fail "GADTs aren't yet supported."
GadtC _ _ _ -> (fail "GADTs aren't yet supported.", [])
RecGadtC _ _ _ -> (fail "GADTs aren't yet supported.", [])
#endif

-- Existentials aren't supported.
ForallC _typeVariables _context _constructor -> return []
ForallC _typeVariables _context _constructor -> (return [], [])
where
deriveFor
:: Name
-> [a]
-> (a -> (Maybe Name -> Strict -> Type -> DecsQ) -> DecsQ)
-> DecsQ
-> (a -> (Maybe Name -> Strict -> Type -> (DecsQ, Maybe (String, String)))
-> (DecsQ, Maybe (String, String)))
-> (DecsQ, [(String, String)])
deriveFor constrName args f =
fmap concat . forM (withIndexes args) $ \(idx, arg) ->
concatBoth . flip fmap (withIndexes args) $ \(idx, arg) ->
f arg $ \accessor strict fieldType' ->
deriveForField params DeriveFieldParams
deriveForField params seen DeriveFieldParams
{ typeName = name
, typeVariables = List.map getTypeName typeVars
, constructorName = constrName
Expand All @@ -460,6 +472,9 @@ deriveForConstructor params name typeVars = \case
PlainTV n -> n
KindedTV n _kind -> n

concatBoth :: [(Q [a], Maybe b)] -> (Q [a], [b])
concatBoth = (fmap concat . sequence *** catMaybes) . unzip

withIndexes = List.zip [(0 :: Word) ..]

-- | Parameters for 'deriveForField' function.
Expand Down Expand Up @@ -490,27 +505,57 @@ deriveForField
:: DeriveOverloadedRecordsParams
-- ^ Parameters for customization of deriving process. Use 'def' to get
-- default behaviour.
-> [(String, String)]
-- ^ Pairs of instances already generated along with the field names
-- they were made from.
-> DeriveFieldParams
-- ^ All the necessary information for derivation procedure.
-> DecsQ
deriveForField params DeriveFieldParams{..} =
-> (DecsQ, Maybe (String, String))
-- If instances were generated, then the second part is a pair
-- (instanceLabel, fieldLabel)
deriveForField params seen DeriveFieldParams{..} =
case possiblyLabel of
Nothing -> return []
Nothing -> (return [], Nothing)
Just (GetterOnlyField label customGetterExpr) ->
deriveGetter' (strTyLitT label)
$ fromMaybe derivedGetterExpr customGetterExpr
Just (GetterAndSetterField label customGetterAndSetterExpr) -> (<>)
<$> deriveGetter' labelType getterExpr
<*> deriveSetter' labelType setterExpr
case lookup label seen of
Just from ->
if Just from == accessorBase then
(return [],
Nothing)
else
(fail $ "Two different fields map to label \""
<> from <> "\"",
Nothing)
Nothing ->
(deriveGetter' (strTyLitT label)
$ fromMaybe derivedGetterExpr customGetterExpr,
(,) label <$> accessorBase)
Just (GetterAndSetterField label customGetterAndSetterExpr) ->
case lookup label seen of
Just from ->
if Just from == accessorBase then
(return [],
Nothing)
else
(fail $ "Two different fields map to label \""
<> from <> "\"",
Nothing)
Nothing ->
((<>)
<$> deriveGetter' labelType getterExpr
<*> deriveSetter' labelType setterExpr,
(,) label <$> accessorBase)
where
labelType = strTyLitT label

(getterExpr, setterExpr) =
fromMaybe (derivedGetterExpr, derivedSetterExpr)
customGetterAndSetterExpr
where
accessorBase = fmap nameBase accessorName

possiblyLabel = _fieldDerivation params (nameBase typeName)
(nameBase constructorName) currentIndex (fmap nameBase accessorName)
(nameBase constructorName) currentIndex accessorBase

deriveGetter' labelType =
deriveGetter labelType recordType (return fieldType)
Expand Down

0 comments on commit 9602be2

Please sign in to comment.