diff --git a/proto-lens.cabal b/proto-lens.cabal index b241ff9e..fc01ecf5 100644 --- a/proto-lens.cabal +++ b/proto-lens.cabal @@ -26,16 +26,17 @@ library hs-source-dirs: src exposed-modules: Data.ProtoLens Data.ProtoLens.Encoding + Data.ProtoLens.Encoding.Wire Data.ProtoLens.Message Data.ProtoLens.Message.Enum Data.ProtoLens.TextFormat other-modules: Data.ProtoLens.Encoding.Bytes - Data.ProtoLens.Encoding.Wire Data.ProtoLens.TextFormat.Parser build-depends: attoparsec == 0.13.* , base >= 4.8 && < 4.11 , bytestring == 0.10.* , containers == 0.5.* + , deepseq == 1.4.* , data-default-class >= 0.0 && < 0.2 , lens-family == 1.2.* , parsec == 3.1.* diff --git a/src/Data/ProtoLens/Encoding.hs b/src/Data/ProtoLens/Encoding.hs index bc1b78ba..b00ae43f 100644 --- a/src/Data/ProtoLens/Encoding.hs +++ b/src/Data/ProtoLens/Encoding.hs @@ -54,12 +54,15 @@ parseMessage :: forall msg . Message msg => Parser () -> Parser msg parseMessage end = do (msg, unsetFields) <- loop def requiredFields if Map.null unsetFields - then return $ reverseRepeatedFields fields msg + then return $ over unknownFields reverse + $ reverseRepeatedFields fields msg else fail $ "Missing required fields " ++ show (map fieldDescriptorName $ Map.elems $ unsetFields) where fields = fieldsByTag descriptor + addUnknown :: TaggedValue -> msg -> msg + addUnknown !f = over' unknownFields (f :) requiredFields = Map.filter isRequired fields loop :: msg -> Map.Map Tag (FieldDescriptor msg) -> Parser (msg, Map.Map Tag (FieldDescriptor msg)) @@ -67,7 +70,7 @@ parseMessage end = do <|> do tv@(TaggedValue tag _) <- getTaggedValue case Map.lookup tag fields of - Nothing -> loop msg unsetFields + Nothing -> (loop $! addUnknown tv msg) unsetFields Just field -> do !msg' <- parseAndAddField msg field tv loop msg' $! Map.delete tag unsetFields @@ -172,10 +175,12 @@ buildMessageDelimited msg = -- | Encode a message as a sequence of key-value pairs. messageToTaggedValues :: Message msg => msg -> [TaggedValue] -messageToTaggedValues msg = mconcat - [ messageFieldToVals tag fieldDescr msg - | (tag, fieldDescr) <- Map.toList (fieldsByTag descriptor) - ] +messageToTaggedValues msg = + mconcat + [ messageFieldToVals tag fieldDescr msg + | (tag, fieldDescr) <- Map.toList (fieldsByTag descriptor) + ] + ++ msg ^. unknownFields messageFieldToVals :: Tag -> FieldDescriptor a -> a -> [TaggedValue] messageFieldToVals tag (FieldDescriptor _ typeDescriptor accessor) msg = diff --git a/src/Data/ProtoLens/Encoding/Wire.hs b/src/Data/ProtoLens/Encoding/Wire.hs index e3a7a7f1..f2780efa 100644 --- a/src/Data/ProtoLens/Encoding/Wire.hs +++ b/src/Data/ProtoLens/Encoding/Wire.hs @@ -6,6 +6,8 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE PatternGuards #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE RankNTypes #-} -- | Module defining the individual base wire types (e.g. VarInt, Fixed64) and -- how to encode/decode them. @@ -21,8 +23,10 @@ module Data.ProtoLens.Encoding.Wire( putWireValue, Equal(..), equalWireTypes, + decodeFieldSet, ) where +import Control.DeepSeq (NFData(..)) import Data.Attoparsec.ByteString as Parse import Data.Bits import qualified Data.ByteString as B @@ -33,6 +37,9 @@ import Data.Word import Data.ProtoLens.Encoding.Bytes data WireType a where + -- Note: all of these types are fully strict (vs, say, + -- Data.ByteString.Lazy.ByteString). If that changes, we'll + -- need to update the NFData instance. VarInt :: WireType Word64 Fixed64 :: WireType Word64 Fixed32 :: WireType Word32 @@ -47,16 +54,35 @@ instance Show (WireType a) where -- A value read from the wire data WireValue = forall a . WireValue !(WireType a) !a +instance Show WireValue where + show (WireValue VarInt x) = show x + show (WireValue Fixed64 x) = show x + show (WireValue Fixed32 x) = show x + show (WireValue Lengthy x) = show x + show (WireValue StartGroup x) = show x + show (WireValue EndGroup x) = show x + + -- The wire contents of a single key-value pair in a Message. data TaggedValue = TaggedValue !Tag !WireValue + deriving (Show, Eq, Ord) + +-- TaggedValue, WireValue and Tag are strict, so their NFData instances are +-- trivial: +instance NFData TaggedValue where + rnf = (`seq` ()) + +instance NFData WireValue where + rnf = (`seq` ()) -- | A tag that identifies a particular field of the message when converting -- to/from the wire format. newtype Tag = Tag { unTag :: Int} - deriving (Show, Eq, Ord, Num) + deriving (Show, Eq, Ord, Num, NFData) data Equal a b where - Equal :: Equal a a + -- TODO: move Eq/Ord instance somewhere else? + Equal :: (Eq a, Ord a) => Equal a a -- Assert that two wire types are the same, or fail with a message about this -- field. @@ -73,6 +99,18 @@ equalWireTypes name expected actual = fail $ "Field " ++ name ++ " expects wire type " ++ show expected ++ " but found " ++ show actual +instance Eq WireValue where + WireValue t v == WireValue t' v' + | Just Equal <- equalWireTypes "" t t' + = v == v' + | otherwise = False + +instance Ord WireValue where + WireValue t v `compare` WireValue t' v' + | Just Equal <- equalWireTypes "" t t' + = v `compare` v' + | otherwise = wireTypeToInt t `compare` wireTypeToInt t' + getWireValue :: WireType a -> Parser a getWireValue VarInt = getVarInt getWireValue Fixed64 = anyBits @@ -129,3 +167,6 @@ getTaggedValue = do putTaggedValue :: TaggedValue -> Builder putTaggedValue (TaggedValue tag (WireValue wt val)) = putTypeAndTag wt tag <> putWireValue wt val + +decodeFieldSet :: B.ByteString -> Either String [TaggedValue] +decodeFieldSet = parseOnly (manyTill getTaggedValue endOfInput) diff --git a/src/Data/ProtoLens/Message.hs b/src/Data/ProtoLens/Message.hs index 9b93b070..6f4544ba 100644 --- a/src/Data/ProtoLens/Message.hs +++ b/src/Data/ProtoLens/Message.hs @@ -41,6 +41,11 @@ module Data.ProtoLens.Message ( maybeLens, -- * Internal utilities for parsing protocol buffers reverseRepeatedFields, + -- * Unknown fields + FieldSet, + TaggedValue(..), + unknownFields, + discardUnknownFields, ) where import qualified Data.ByteString as B @@ -52,10 +57,13 @@ import Data.Maybe (fromMaybe) import Data.Proxy (Proxy(..)) import qualified Data.Text as T import Data.Word -import Lens.Family2 (Lens', over) +import Lens.Family2 (Lens', over, set) import Lens.Family2.Unchecked (lens) -import Data.ProtoLens.Encoding.Wire (Tag(..)) +import Data.ProtoLens.Encoding.Wire + ( Tag(..) + , TaggedValue(..) + ) -- | Every protocol buffer is an instance of 'Message'. This class enables -- serialization by providing reflection of all of the fields that may be used @@ -76,8 +84,10 @@ data MessageDescriptor msg = MessageDescriptor -- which use their Message type name in text protos instead of their -- field name. For example, "optional group Foo" has the field name "foo" -- but in this map it is stored with the key "Foo". + , unknownFieldsLens :: Lens' msg FieldSet } +type FieldSet = [TaggedValue] -- | A description of a specific field of a protocol buffer. -- @@ -285,3 +295,11 @@ lookupRegistered n (Registry m) = Map.lookup (snd $ T.breakOnEnd "/" n) m data SomeMessageType where SomeMessageType :: Message msg => Proxy msg -> SomeMessageType + +-- TODO: recursively +discardUnknownFields :: Message msg => msg -> msg +discardUnknownFields = set unknownFields [] + +-- | Access the unknown fields of a Message. +unknownFields :: Message msg => Lens' msg FieldSet +unknownFields = unknownFieldsLens descriptor diff --git a/src/Data/ProtoLens/TextFormat.hs b/src/Data/ProtoLens/TextFormat.hs index 9b8a5742..8d0ceaca 100644 --- a/src/Data/ProtoLens/TextFormat.hs +++ b/src/Data/ProtoLens/TextFormat.hs @@ -38,6 +38,7 @@ import Numeric (showOct) import Text.Parsec (parse) import Text.PrettyPrint +import Data.ProtoLens.Encoding.Wire import Data.ProtoLens.Message import qualified Data.ProtoLens.TextFormat.Parser as Parser @@ -86,7 +87,8 @@ pprintMessage' reg descr msg -- for each field. We use a single "sep" for all fields (and all elements -- of all the repeated fields) to avoid putting some repeated fields on one -- line and other fields on multiple lines, which is less readable. - = sep $ concatMap (pprintField reg msg) $ Map.elems $ fieldsByTag descr + = sep $ concatMap (pprintField reg msg) (Map.elems $ fieldsByTag descr) + ++ map pprintTaggedValue (msg ^. unknownFieldsLens descr) pprintField :: Registry -> msg -> FieldDescriptor msg -> [Doc] pprintField reg msg (FieldDescriptor name typeDescr accessor) @@ -109,14 +111,13 @@ pprintFieldValue reg name field@MessageField m fieldData <- view anyValueLens m, Just (SomeMessageType (Proxy :: Proxy value')) <- lookupRegistered typeUri reg, Right (anyValue :: value') <- decodeMessage fieldData = - sep [ text name <+> lbrace - , nest 2 $ sep + pprintSubmessage name + $ sep [ lbrack <> text (Text.unpack typeUri) <> rbrack <+> lbrace , nest 2 (pprintMessageWithRegistry reg anyValue) , rbrace ] - , rbrace ] | otherwise = - sep [text name <+> lbrace, nest 2 (pprintMessageWithRegistry reg m), rbrace] + pprintSubmessage name (pprintMessageWithRegistry reg m) pprintFieldValue _ name EnumField x = text name <> colon <+> text (showEnum x) pprintFieldValue _ name Int32Field x = primField name x pprintFieldValue _ name Int64Field x = primField name x @@ -134,7 +135,11 @@ pprintFieldValue _ name BoolField x = text name <> colon <+> boolValue x pprintFieldValue _ name StringField x = pprintByteString name (Text.encodeUtf8 x) pprintFieldValue _ name BytesField x = pprintByteString name x pprintFieldValue reg name GroupField m - = text name <+> lbrace $$ nest 2 (pprintMessageWithRegistry reg m) $$ rbrace + = pprintSubmessage name (pprintMessageWithRegistry reg m) + +pprintSubmessage :: String -> Doc -> Doc +pprintSubmessage name contents = + sep [text name <+> lbrace, nest 2 contents, rbrace] -- | Formats a string in a way that mostly matches the C-compatible escaping -- used by the Protocol Buffer distribution. We depart a bit by escaping all @@ -167,6 +172,21 @@ boolValue :: Bool -> Doc boolValue True = text "true" boolValue False = text "false" +pprintTaggedValue :: TaggedValue -> Doc +pprintTaggedValue (TaggedValue t (WireValue v x)) = case v of + VarInt -> primField name x + Fixed64 -> primField name x + Fixed32 -> primField name x + Lengthy -> case decodeFieldSet x of + Left _ -> pprintByteString name x + Right ts -> pprintSubmessage name + $ sep $ map pprintTaggedValue ts + -- TODO: implement better printing for unknown groups + StartGroup -> text name <> colon <+> text "start_group" + EndGroup -> text name <> colon <+> text "end_group" + where + name = show (unTag t) + -------------------------------------------------------------------------------- -- Parsing