Skip to content

Commit

Permalink
Add support for tracking unknown fields. (google#129)
Browse files Browse the repository at this point in the history
- Every message has a new field containing a list of unknown fields:
  ```unknownFields :: Message msg => Lens' msg [TaggedValue]```
- Unknown fields are preserved by `decodeMessage`, `encodeMessage`,
  and `showMessage`.
- Unknown fields still cause an error for `readMessage`.

A few TODOs:
- For now, unknown groups are printed sub-optimally by `showMessage`: the
  start/end group tags (and everything in between) all get displayed as
  individual fields, rather than being organized into a sub-struct.
- The `discardUnknownFields` function isn't recursive, unlike in other
  languages.
- The Ord instance doesn't try to do anything special, just treating
  the unknown fields as a list of values.  If it really matters then
  `discardUnknownFields` can help resolve the ambiguity.
  • Loading branch information
judah committed Sep 1, 2017
1 parent 4ccec38 commit a94a1af
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 17 deletions.
3 changes: 2 additions & 1 deletion proto-lens.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@ library
hs-source-dirs: src
exposed-modules: Data.ProtoLens
Data.ProtoLens.Encoding
Data.ProtoLens.Encoding.Wire
Data.ProtoLens.Message
Data.ProtoLens.Message.Enum
Data.ProtoLens.TextFormat
other-modules: Data.ProtoLens.Encoding.Bytes
Data.ProtoLens.Encoding.Wire
Data.ProtoLens.TextFormat.Parser
build-depends: attoparsec == 0.13.*
, base >= 4.8 && < 4.11
, bytestring == 0.10.*
, containers == 0.5.*
, deepseq == 1.4.*
, data-default-class >= 0.0 && < 0.2
, lens-family == 1.2.*
, parsec == 3.1.*
Expand Down
17 changes: 11 additions & 6 deletions src/Data/ProtoLens/Encoding.hs
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,23 @@ parseMessage :: forall msg . Message msg => Parser () -> Parser msg
parseMessage end = do
(msg, unsetFields) <- loop def requiredFields
if Map.null unsetFields
then return $ reverseRepeatedFields fields msg
then return $ over unknownFields reverse
$ reverseRepeatedFields fields msg
else fail $ "Missing required fields "
++ show (map fieldDescriptorName
$ Map.elems $ unsetFields)
where
fields = fieldsByTag descriptor
addUnknown :: TaggedValue -> msg -> msg
addUnknown !f = over' unknownFields (f :)
requiredFields = Map.filter isRequired fields
loop :: msg -> Map.Map Tag (FieldDescriptor msg)
-> Parser (msg, Map.Map Tag (FieldDescriptor msg))
loop msg unsetFields = ((msg, unsetFields) <$ end)
<|> do
tv@(TaggedValue tag _) <- getTaggedValue
case Map.lookup tag fields of
Nothing -> loop msg unsetFields
Nothing -> (loop $! addUnknown tv msg) unsetFields
Just field -> do
!msg' <- parseAndAddField msg field tv
loop msg' $! Map.delete tag unsetFields
Expand Down Expand Up @@ -172,10 +175,12 @@ buildMessageDelimited msg =

-- | Encode a message as a sequence of key-value pairs.
messageToTaggedValues :: Message msg => msg -> [TaggedValue]
messageToTaggedValues msg = mconcat
[ messageFieldToVals tag fieldDescr msg
| (tag, fieldDescr) <- Map.toList (fieldsByTag descriptor)
]
messageToTaggedValues msg =
mconcat
[ messageFieldToVals tag fieldDescr msg
| (tag, fieldDescr) <- Map.toList (fieldsByTag descriptor)
]
++ msg ^. unknownFields

messageFieldToVals :: Tag -> FieldDescriptor a -> a -> [TaggedValue]
messageFieldToVals tag (FieldDescriptor _ typeDescriptor accessor) msg =
Expand Down
45 changes: 43 additions & 2 deletions src/Data/ProtoLens/Encoding/Wire.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
-- | Module defining the individual base wire types (e.g. VarInt, Fixed64) and
-- how to encode/decode them.
Expand All @@ -21,8 +23,10 @@ module Data.ProtoLens.Encoding.Wire(
putWireValue,
Equal(..),
equalWireTypes,
decodeFieldSet,
) where

import Control.DeepSeq (NFData(..))
import Data.Attoparsec.ByteString as Parse
import Data.Bits
import qualified Data.ByteString as B
Expand All @@ -33,6 +37,9 @@ import Data.Word
import Data.ProtoLens.Encoding.Bytes

data WireType a where
-- Note: all of these types are fully strict (vs, say,
-- Data.ByteString.Lazy.ByteString). If that changes, we'll
-- need to update the NFData instance.
VarInt :: WireType Word64
Fixed64 :: WireType Word64
Fixed32 :: WireType Word32
Expand All @@ -47,16 +54,35 @@ instance Show (WireType a) where
-- | A value read from the wire, stored together with the 'WireType' that
-- determines the type of its payload. The payload type is existentially
-- hidden; match on the 'WireType' to recover it. Both fields are strict.
data WireValue where
    WireValue :: !(WireType a) -> !a -> WireValue

-- | Shows only the payload of a 'WireValue'; the wire type itself is not
-- part of the output.
instance Show WireValue where
    -- Matching on the wire type reveals the payload's concrete type, which
    -- is what makes a 'Show' instance for it available in each branch.
    show (WireValue wireType payload) = case wireType of
        VarInt -> show payload
        Fixed64 -> show payload
        Fixed32 -> show payload
        Lengthy -> show payload
        StartGroup -> show payload
        EndGroup -> show payload


-- | The wire contents of a single key-value pair in a message: the field's
-- 'Tag' together with the 'WireValue' that was read for it. Both fields are
-- strict.
data TaggedValue = TaggedValue !Tag !WireValue
    deriving (Eq, Ord, Show)

-- TaggedValue, WireValue and Tag have only strict fields, so evaluating one
-- of them to weak head normal form already evaluates it completely. That
-- makes the NFData instances below trivial.
instance NFData TaggedValue where
    rnf taggedValue = taggedValue `seq` ()

instance NFData WireValue where
    rnf wireValue = wireValue `seq` ()

-- | A tag that identifies a particular field of the message when converting
-- to/from the wire format.
newtype Tag = Tag { unTag :: Int}
deriving (Show, Eq, Ord, Num)
deriving (Show, Eq, Ord, Num, NFData)

data Equal a b where
Equal :: Equal a a
-- TODO: move Eq/Ord instance somewhere else?
Equal :: (Eq a, Ord a) => Equal a a

-- Assert that two wire types are the same, or fail with a message about this
-- field.
Expand All @@ -73,6 +99,18 @@ equalWireTypes name expected actual
= fail $ "Field " ++ name ++ " expects wire type " ++ show expected
++ " but found " ++ show actual

-- | Two wire values are equal when they have the same wire type and equal
-- payloads. Values of different wire types are never equal.
instance Eq WireValue where
    WireValue leftType leftVal == WireValue rightType rightVal =
        -- The field name passed to 'equalWireTypes' only feeds its failure
        -- message, which is not needed here, so it is left empty.
        case equalWireTypes "" leftType rightType of
            Just Equal -> leftVal == rightVal
            _ -> False

-- | Wire values of the same wire type are ordered by their payloads.
-- Values of different wire types are ordered by 'wireTypeToInt' of their
-- wire types.
instance Ord WireValue where
    compare (WireValue leftType leftVal) (WireValue rightType rightVal) =
        -- The field name passed to 'equalWireTypes' only feeds its failure
        -- message, which is not needed here, so it is left empty.
        case equalWireTypes "" leftType rightType of
            Just Equal -> compare leftVal rightVal
            _ -> compare (wireTypeToInt leftType) (wireTypeToInt rightType)

getWireValue :: WireType a -> Parser a
getWireValue VarInt = getVarInt
getWireValue Fixed64 = anyBits
Expand Down Expand Up @@ -129,3 +167,6 @@ getTaggedValue = do
putTaggedValue :: TaggedValue -> Builder
putTaggedValue (TaggedValue tag (WireValue wt val)) =
putTypeAndTag wt tag <> putWireValue wt val

-- | Decode a strict 'B.ByteString' as a sequence of tagged values, reading
-- one 'TaggedValue' after another until the end of the input.
--
-- Returns 'Left' with the parser's error message when the input cannot be
-- consumed that way, and 'Right' with the values in the order they were
-- read otherwise.
decodeFieldSet :: B.ByteString -> Either String [TaggedValue]
decodeFieldSet bytes = parseOnly allFields bytes
  where
    allFields = manyTill getTaggedValue endOfInput
22 changes: 20 additions & 2 deletions src/Data/ProtoLens/Message.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ module Data.ProtoLens.Message (
maybeLens,
-- * Internal utilities for parsing protocol buffers
reverseRepeatedFields,
-- * Unknown fields
FieldSet,
TaggedValue(..),
unknownFields,
discardUnknownFields,
) where

import qualified Data.ByteString as B
Expand All @@ -52,10 +57,13 @@ import Data.Maybe (fromMaybe)
import Data.Proxy (Proxy(..))
import qualified Data.Text as T
import Data.Word
import Lens.Family2 (Lens', over)
import Lens.Family2 (Lens', over, set)
import Lens.Family2.Unchecked (lens)

import Data.ProtoLens.Encoding.Wire (Tag(..))
import Data.ProtoLens.Encoding.Wire
( Tag(..)
, TaggedValue(..)
)

-- | Every protocol buffer is an instance of 'Message'. This class enables
-- serialization by providing reflection of all of the fields that may be used
Expand All @@ -76,8 +84,10 @@ data MessageDescriptor msg = MessageDescriptor
-- which use their Message type name in text protos instead of their
-- field name. For example, "optional group Foo" has the field name "foo"
-- but in this map it is stored with the key "Foo".
, unknownFieldsLens :: Lens' msg FieldSet
}

-- | A collection of raw wire-level fields, represented as a plain list of
-- 'TaggedValue's.
type FieldSet = [TaggedValue]

-- | A description of a specific field of a protocol buffer.
--
Expand Down Expand Up @@ -285,3 +295,11 @@ lookupRegistered n (Registry m) = Map.lookup (snd $ T.breakOnEnd "/" n) m

data SomeMessageType where
SomeMessageType :: Message msg => Proxy msg -> SomeMessageType

-- | Return the message with its own list of unknown fields emptied.
--
-- TODO: make this recursive. For now only the 'unknownFields' of the given
-- message are reset.
discardUnknownFields :: Message msg => msg -> msg
discardUnknownFields msg = set unknownFields [] msg

-- | Access the unknown fields of a 'Message'.
--
-- This is the 'unknownFieldsLens' taken from the message type's
-- 'descriptor'.
unknownFields :: Message msg => Lens' msg FieldSet
unknownFields = unknownFieldsLens descriptor
32 changes: 26 additions & 6 deletions src/Data/ProtoLens/TextFormat.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import Numeric (showOct)
import Text.Parsec (parse)
import Text.PrettyPrint

import Data.ProtoLens.Encoding.Wire
import Data.ProtoLens.Message
import qualified Data.ProtoLens.TextFormat.Parser as Parser

Expand Down Expand Up @@ -86,7 +87,8 @@ pprintMessage' reg descr msg
-- for each field. We use a single "sep" for all fields (and all elements
-- of all the repeated fields) to avoid putting some repeated fields on one
-- line and other fields on multiple lines, which is less readable.
= sep $ concatMap (pprintField reg msg) $ Map.elems $ fieldsByTag descr
= sep $ concatMap (pprintField reg msg) (Map.elems $ fieldsByTag descr)
++ map pprintTaggedValue (msg ^. unknownFieldsLens descr)

pprintField :: Registry -> msg -> FieldDescriptor msg -> [Doc]
pprintField reg msg (FieldDescriptor name typeDescr accessor)
Expand All @@ -109,14 +111,13 @@ pprintFieldValue reg name field@MessageField m
fieldData <- view anyValueLens m,
Just (SomeMessageType (Proxy :: Proxy value')) <- lookupRegistered typeUri reg,
Right (anyValue :: value') <- decodeMessage fieldData =
sep [ text name <+> lbrace
, nest 2 $ sep
pprintSubmessage name
$ sep
[ lbrack <> text (Text.unpack typeUri) <> rbrack <+> lbrace
, nest 2 (pprintMessageWithRegistry reg anyValue)
, rbrace ]
, rbrace ]
| otherwise =
sep [text name <+> lbrace, nest 2 (pprintMessageWithRegistry reg m), rbrace]
pprintSubmessage name (pprintMessageWithRegistry reg m)
pprintFieldValue _ name EnumField x = text name <> colon <+> text (showEnum x)
pprintFieldValue _ name Int32Field x = primField name x
pprintFieldValue _ name Int64Field x = primField name x
Expand All @@ -134,7 +135,11 @@ pprintFieldValue _ name BoolField x = text name <> colon <+> boolValue x
pprintFieldValue _ name StringField x = pprintByteString name (Text.encodeUtf8 x)
pprintFieldValue _ name BytesField x = pprintByteString name x
pprintFieldValue reg name GroupField m
= text name <+> lbrace $$ nest 2 (pprintMessageWithRegistry reg m) $$ rbrace
= pprintSubmessage name (pprintMessageWithRegistry reg m)

-- | Lay out a named sub-message: the name followed by an opening brace,
-- the given contents nested by two columns, then the closing brace. The
-- three parts are combined with 'sep'.
pprintSubmessage :: String -> Doc -> Doc
pprintSubmessage name contents = sep [opening, body, rbrace]
  where
    opening = text name <+> lbrace
    body = nest 2 contents

-- | Formats a string in a way that mostly matches the C-compatible escaping
-- used by the Protocol Buffer distribution. We depart a bit by escaping all
Expand Down Expand Up @@ -167,6 +172,21 @@ boolValue :: Bool -> Doc
boolValue True = text "true"
boolValue False = text "false"

-- | Pretty-print a single unknown field, using its numeric tag as the field
-- name.
--
-- A length-delimited payload is printed as a nested sub-message only when it
-- is non-empty and decodes cleanly as a sequence of tagged values; otherwise
-- it is printed as a byte string. An empty payload trivially decodes to zero
-- fields, but rendering it as an empty sub-message would be a guess, so it
-- takes the byte-string path (the same choice protobuf's reference text
-- printer makes).
pprintTaggedValue :: TaggedValue -> Doc
pprintTaggedValue (TaggedValue t (WireValue v x)) = case v of
    VarInt -> primField name x
    Fixed64 -> primField name x
    Fixed32 -> primField name x
    Lengthy -> case decodeFieldSet x of
        -- 'decodeFieldSet' yields an empty list only for empty input, so
        -- requiring at least one decoded field excludes exactly that case.
        Right ts@(_ : _) ->
            pprintSubmessage name $ sep $ map pprintTaggedValue ts
        _ -> pprintByteString name x
    -- TODO: implement better printing for unknown groups
    StartGroup -> text name <> colon <+> text "start_group"
    EndGroup -> text name <> colon <+> text "end_group"
  where
    name = show (unTag t)

--------------------------------------------------------------------------------
-- Parsing

Expand Down

0 comments on commit a94a1af

Please sign in to comment.