Skip to content
Permalink
Browse files

enforce recursion depth checking for unknown fields

  • Loading branch information
jtattermusch authored and acozzette committed Jan 24, 2020
1 parent ac70b7c commit f20be839276cfc1129c12d89924164624ef3796d
@@ -33,6 +33,7 @@
using System;
using System.IO;
using Google.Protobuf.TestProtos;
using Proto2 = Google.Protobuf.TestProtos.Proto2;
using NUnit.Framework;

namespace Google.Protobuf
@@ -337,6 +338,66 @@ public void MaliciousRecursion()
CodedInputStream input = CodedInputStream.CreateWithLimits(new MemoryStream(atRecursiveLimit.ToByteArray()), 1000000, CodedInputStream.DefaultRecursionLimit - 1);
Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(input));
}

private static byte[] MakeMaliciousRecursionUnknownFieldsPayload(int recursionDepth)
{
// generate recursively nested groups that will be parsed as unknown fields
int unknownFieldNumber = 14; // an unused field number
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
for (int i = 0; i < recursionDepth; i++)
{
output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.StartGroup));
}
for (int i = 0; i < recursionDepth; i++)
{
output.WriteTag(WireFormat.MakeTag(unknownFieldNumber, WireFormat.WireType.EndGroup));
}
output.Flush();
return ms.ToArray();
}

[Test]
public void MaliciousRecursion_UnknownFields()
{
byte[] payloadAtRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit);
byte[] payloadBeyondRecursiveLimit = MakeMaliciousRecursionUnknownFieldsPayload(CodedInputStream.DefaultRecursionLimit + 1);

Assert.DoesNotThrow(() => TestRecursiveMessage.Parser.ParseFrom(payloadAtRecursiveLimit));
Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payloadBeyondRecursiveLimit));
}

[Test]
public void ReadGroup_WrongEndGroupTag()
{
int groupFieldNumber = Proto2.TestAllTypes.OptionalGroupFieldNumber;

// write Proto2.TestAllTypes with "optional_group" set, but use wrong EndGroup closing tag
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
output.WriteTag(WireFormat.MakeTag(groupFieldNumber, WireFormat.WireType.StartGroup));
output.WriteGroup(new Proto2.TestAllTypes.Types.OptionalGroup { A = 12345 });
// end group with different field number
output.WriteTag(WireFormat.MakeTag(groupFieldNumber + 1, WireFormat.WireType.EndGroup));
output.Flush();
var payload = ms.ToArray();

Assert.Throws<InvalidProtocolBufferException>(() => Proto2.TestAllTypes.Parser.ParseFrom(payload));
}

[Test]
public void ReadGroup_UnknownFields_WrongEndGroupTag()
{
MemoryStream ms = new MemoryStream();
CodedOutputStream output = new CodedOutputStream(ms);
output.WriteTag(WireFormat.MakeTag(14, WireFormat.WireType.StartGroup));
// end group with different field number
output.WriteTag(WireFormat.MakeTag(15, WireFormat.WireType.EndGroup));
output.Flush();
var payload = ms.ToArray();

Assert.Throws<InvalidProtocolBufferException>(() => TestRecursiveMessage.Parser.ParseFrom(payload));
}

[Test]
public void SizeLimit()
@@ -735,4 +796,4 @@ public override int Read(byte[] buffer, int offset, int count)
}
}
}
}
}
@@ -307,10 +307,17 @@ internal void CheckReadEndOfStreamTag()
throw InvalidProtocolBufferException.MoreDataAvailable();
}
}
#endregion

internal void CheckLastTagWas(uint expectedTag)
{
if (lastTag != expectedTag) {
throw InvalidProtocolBufferException.InvalidEndTag();
}
}
#endregion

#region Reading of tags etc


/// <summary>
/// Peeks at the next field tag. This is like calling <see cref="ReadTag"/>, but the
/// tag is not consumed. (So a subsequent call to <see cref="ReadTag"/> will return the
@@ -636,7 +643,27 @@ public void ReadGroup(IMessage builder)
throw InvalidProtocolBufferException.RecursionLimitExceeded();
}
++recursionDepth;

uint tag = lastTag;
int fieldNumber = WireFormat.GetTagFieldNumber(tag);

builder.MergeFrom(this);
CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
--recursionDepth;
}

/// <summary>
/// Reads an embedded group unknown field from the stream.
/// </summary>
internal void ReadGroup(int fieldNumber, UnknownFieldSet set)
{
if (recursionDepth >= recursionLimit)
{
throw InvalidProtocolBufferException.RecursionLimitExceeded();
}
++recursionDepth;
set.MergeGroupFrom(this);
CheckLastTagWas(WireFormat.MakeTag(fieldNumber, WireFormat.WireType.EndGroup));
--recursionDepth;
}

@@ -215,12 +215,8 @@ private bool MergeFieldFrom(CodedInputStream input)
}
case WireFormat.WireType.StartGroup:
{
uint endTag = WireFormat.MakeTag(number, WireFormat.WireType.EndGroup);
UnknownFieldSet set = new UnknownFieldSet();
while (input.ReadTag() != endTag)
{
set.MergeFieldFrom(input);
}
input.ReadGroup(number, set);
GetOrAddField(number).AddGroup(set);
return true;
}
@@ -233,6 +229,22 @@ private bool MergeFieldFrom(CodedInputStream input)
}
}

internal void MergeGroupFrom(CodedInputStream input)
{
while (true)
{
uint tag = input.ReadTag();
if (tag == 0)
{
break;
}
if (!MergeFieldFrom(input))
{
break;
}
}
}

/// <summary>
/// Create a new UnknownFieldSet if unknownFields is null.
/// Parse a single field from <paramref name="input"/> and merge it

0 comments on commit f20be83

Please sign in to comment.
You can’t perform that action at this time.