diff --git a/_generated/field_limits.go b/_generated/field_limits.go new file mode 100644 index 00000000..d27542a2 --- /dev/null +++ b/_generated/field_limits.go @@ -0,0 +1,28 @@ +package _generated + +//go:generate msgp + +// Aliased types for testing +type AliasedSlice []string +type AliasedMap map[string]bool +type AliasedIntSlice []int + +// Test structures for field-level limit tags +type FieldLimitTestData struct { + SmallSlice []int `msg:"small_slice,limit=5"` + LargeSlice []string `msg:"large_slice,limit=100"` + SmallMap map[string]int `msg:"small_map,limit=3"` + LargeMap map[int]string `msg:"large_map,limit=20"` + NoLimit []byte `msg:"no_limit"` // Uses file-level limits if any + FixedArray [10]int `msg:"fixed_array,limit=2"` // Should be ignored +} + +// Test structure with aliased types and field limits +type AliasedFieldLimitTestData struct { + SmallAliasedSlice AliasedSlice `msg:"small_aliased_slice,limit=3"` + LargeAliasedSlice AliasedSlice `msg:"large_aliased_slice,limit=50"` + SmallAliasedMap AliasedMap `msg:"small_aliased_map,limit=2"` + LargeAliasedMap AliasedMap `msg:"large_aliased_map,limit=25"` + IntSliceAlias AliasedIntSlice `msg:"int_slice_alias,limit=10"` + NoLimitAlias AliasedSlice `msg:"no_limit_alias"` // Uses file-level limits +} \ No newline at end of file diff --git a/_generated/limits.go b/_generated/limits.go new file mode 100644 index 00000000..f0aa119f --- /dev/null +++ b/_generated/limits.go @@ -0,0 +1,35 @@ +//msgp:limit arrays:100 maps:50 + +package _generated + +//go:generate msgp + +// Test structures for limit directive +type LimitedData struct { + SmallArray [10]int `msg:"small_array"` + LargeSlice []byte `msg:"large_slice"` + SmallMap map[string]int `msg:"small_map"` +} + +type UnlimitedData struct { + BigArray [1000]int `msg:"big_array"` + BigSlice []string `msg:"big_slice"` + BigMap map[string][]int `msg:"big_map"` +} + +type LimitTestData struct { + SmallArray [10]int `msg:"small_array"` + LargeSlice []byte `msg:"large_slice"` + SmallMap map[string]int `msg:"small_map"` +} + +// Test field limits vs file limits precedence +// File limits: arrays:100 maps:50 +type FieldOverrideTestData struct { + TightSlice []int `msg:"tight_slice,limit=10"` // Field limit (10) < file limit (100) + LooseSlice []string `msg:"loose_slice,limit=200"` // Field limit (200) > file limit (100) + TightMap map[string]int `msg:"tight_map,limit=5"` // Field limit (5) < file limit (50) + LooseMap map[int]string `msg:"loose_map,limit=80"` // Field limit (80) > file limit (50) + DefaultSlice []byte `msg:"default_slice"` // No field limit, uses file limit (100) + DefaultMap map[string]string `msg:"default_map"` // No field limit, uses file limit (50) +} diff --git a/_generated/limits2.go b/_generated/limits2.go new file mode 100644 index 00000000..834925cd --- /dev/null +++ b/_generated/limits2.go @@ -0,0 +1,23 @@ +package _generated + +//go:generate msgp + +//msgp:limit arrays:200 maps:100 + +type LimitTestData2 struct { + BigArray [20]int `msg:"big_array"` + BigMap map[string]int `msg:"big_map"` +} + +// Test field limits vs file limits precedence with different file limits +// File limits: arrays:200 maps:100 +type FieldOverrideTestData2 struct { + TightSlice []int `msg:"tight_slice,limit=30"` // Field limit (30) < file limit (200) + MediumSlice []string `msg:"medium_slice,limit=150"` // Field limit (150) < file limit (200) + LooseSlice []byte `msg:"loose_slice,limit=300"` // Field limit (300) > file limit (200) + TightMap map[string]int `msg:"tight_map,limit=20"` // Field limit (20) < file limit (100) + MediumMap map[int]string `msg:"medium_map,limit=75"` // Field limit (75) < file limit (100) + LooseMap map[string][]int `msg:"loose_map,limit=150"` // Field limit (150) > file limit (100) + DefaultSlice []int `msg:"default_slice"` // No field limit, uses file limit (200) + DefaultMap map[int]int `msg:"default_map"` // No field limit, uses file limit (100) +} diff --git a/_generated/limits_test.go b/_generated/limits_test.go new file mode 100644 index 00000000..8a0fdebf --- /dev/null +++ b/_generated/limits_test.go @@ -0,0 +1,723 @@ +package _generated + +import ( + "bytes" + "fmt" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestSliceLimitEnforcement(t *testing.T) { + data := UnlimitedData{} + + // Test slice limit with DecodeMsg (using big_slice which is []string) + t.Run("DecodeMsg_SliceLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_slice") + buf = msgp.AppendArrayHeader(buf, 150) // Exceeds limit of 100 + + reader := msgp.NewReader(bytes.NewReader(buf)) + err := data.DecodeMsg(reader) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded, got %v", err) + } + }) + + // Test slice limit with UnmarshalMsg + t.Run("UnmarshalMsg_SliceLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_slice") + buf = msgp.AppendArrayHeader(buf, 150) // Exceeds limit of 100 + + _, err := data.UnmarshalMsg(buf) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded, got %v", err) + } + }) + + // Test that slices within limit work fine + t.Run("SliceWithinLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_slice") + buf = msgp.AppendArrayHeader(buf, 50) // Within limit + for i := 0; i < 50; i++ { + buf = msgp.AppendString(buf, "test") + } + + _, err := data.UnmarshalMsg(buf) + if err != nil { + t.Errorf("Unexpected error for slice within limit: %v", err) + } + }) +} + +func TestMapLimitEnforcement(t *testing.T) { + data := LimitTestData{} + + // Test map limit with DecodeMsg + t.Run("DecodeMsg_MapLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "small_map") + buf = msgp.AppendMapHeader(buf, 60) // Exceeds limit of 50 + + reader := msgp.NewReader(bytes.NewReader(buf)) + err := data.DecodeMsg(reader) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded, got %v", err) + } + }) + + // Test map limit with UnmarshalMsg + t.Run("UnmarshalMsg_MapLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "small_map") + buf = msgp.AppendMapHeader(buf, 60) // Exceeds limit of 50 + + _, err := data.UnmarshalMsg(buf) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded, got %v", err) + } + }) + + // Test that maps within limit work fine + t.Run("MapWithinLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "small_map") + buf = msgp.AppendMapHeader(buf, 3) // Within limit + buf = msgp.AppendString(buf, "a") + buf = msgp.AppendInt(buf, 1) + buf = msgp.AppendString(buf, "b") + buf = msgp.AppendInt(buf, 2) + buf = msgp.AppendString(buf, "c") + buf = msgp.AppendInt(buf, 3) + + _, err := data.UnmarshalMsg(buf) + if err != nil { + t.Errorf("Unexpected error for map within limit: %v", err) + } + }) +} + +func TestFixedArraysNotLimited(t *testing.T) { + // Test that fixed arrays are not subject to limits + // BigArray [1000]int should work even though 1000 > 100 (array limit) + data := UnlimitedData{} + + t.Run("FixedArray_DecodeMsg", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_array") + buf = msgp.AppendArrayHeader(buf, 1000) // Fixed array size, should not be limited + for i := 0; i < 1000; i++ { + buf = msgp.AppendInt(buf, i) + } + + reader := msgp.NewReader(bytes.NewReader(buf)) + err := data.DecodeMsg(reader) + if err != nil { + t.Errorf("Fixed arrays should not be limited, got error: %v", err) + } + }) + + t.Run("FixedArray_UnmarshalMsg", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_array") + buf = msgp.AppendArrayHeader(buf, 1000) // Fixed array size, should not be limited + for i := 0; i < 1000; i++ { + buf = msgp.AppendInt(buf, i) + } + + _, err := data.UnmarshalMsg(buf) + if err != nil { + t.Errorf("Fixed arrays should not be limited, got error: %v", err) + } + }) +} + +func TestSliceLimitsApplied(t *testing.T) { + // Test that dynamic slices are subject to limits + data := UnlimitedData{} + + t.Run("Slice_ExceedsLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_slice") + buf = msgp.AppendArrayHeader(buf, 150) // Exceeds array limit of 100 + + _, err := data.UnmarshalMsg(buf) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded for slice, got %v", err) + } + }) + + t.Run("Slice_WithinLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_slice") + buf = msgp.AppendArrayHeader(buf, 50) // Within array limit of 100 + for i := 0; i < 50; i++ { + buf = msgp.AppendString(buf, "test") + } + + _, err := data.UnmarshalMsg(buf) + if err != nil { + t.Errorf("Unexpected error for slice within limit: %v", err) + } + }) +} + +func TestNestedArrayLimits(t *testing.T) { + // Test limits on nested arrays within maps + data := UnlimitedData{} + + t.Run("NestedArray_ExceedsLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_map") + buf = msgp.AppendMapHeader(buf, 1) // Within map limit + buf = msgp.AppendString(buf, "key") + buf = msgp.AppendArrayHeader(buf, 150) // Nested array exceeds limit of 100 + + _, err := data.UnmarshalMsg(buf) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded for nested array, got %v", err) + } + }) + + t.Run("NestedArray_WithinLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_map") + buf = msgp.AppendMapHeader(buf, 1) // Within map limit + buf = msgp.AppendString(buf, "key") + buf = msgp.AppendArrayHeader(buf, 50) // Nested array within limit + for i := 0; i < 50; i++ { + buf = msgp.AppendInt(buf, i) + } + + _, err := data.UnmarshalMsg(buf) + if err != nil { + t.Errorf("Unexpected error for nested array within limit: %v", err) + } + }) +} + +func TestMapExceedsLimit(t *testing.T) { + data := UnlimitedData{} + + t.Run("Map_ExceedsLimit", func(t *testing.T) { + buf := msgp.AppendMapHeader(nil, 1) + buf = msgp.AppendString(buf, "big_map") + buf = msgp.AppendMapHeader(buf, 60) // Exceeds map limit of 50 + + _, err := data.UnmarshalMsg(buf) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded for map, got %v", err) + } + }) +} + +func TestStructLevelLimits(t *testing.T) { + // Test that the struct-level map limits are enforced + data := LimitTestData{} + + t.Run("StructMap_ExceedsLimit", func(t *testing.T) { + // Create a struct with too many fields + buf := msgp.AppendMapHeader(nil, 60) // Exceeds map limit of 50 + + _, err := data.UnmarshalMsg(buf) + if err != msgp.ErrLimitExceeded { + t.Errorf("Expected ErrLimitExceeded for struct map, got %v", err) + } + }) +} + +func TestNormalOperationWithinLimits(t *testing.T) { + // Test that normal operation works when everything is within limits + data := LimitTestData{} + + // Create valid data + data.SmallArray = [10]int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + data.LargeSlice = []byte("test data") + data.SmallMap = map[string]int{"a": 1, "b": 2, "c": 3} + + t.Run("RoundTrip_Marshal_Unmarshal", func(t *testing.T) { + // Test MarshalMsg -> UnmarshalMsg + buf, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("MarshalMsg failed: %v", err) + } + + var result LimitTestData + _, err = result.UnmarshalMsg(buf) + if err != nil { + t.Fatalf("UnmarshalMsg failed: %v", err) + } + + // Verify data integrity + if result.SmallArray != data.SmallArray { + t.Errorf("SmallArray mismatch: got %v, want %v", result.SmallArray, data.SmallArray) + } + if !bytes.Equal(result.LargeSlice, data.LargeSlice) { + t.Errorf("LargeSlice mismatch: got %v, want %v", result.LargeSlice, data.LargeSlice) + } + if len(result.SmallMap) != len(data.SmallMap) { + t.Errorf("SmallMap length mismatch: got %d, want %d", len(result.SmallMap), len(data.SmallMap)) + } + }) + + t.Run("RoundTrip_Encode_Decode", func(t *testing.T) { + // Test EncodeMsg -> DecodeMsg + var buf bytes.Buffer + writer := msgp.NewWriter(&buf) + err := data.EncodeMsg(writer) + if err != nil { + t.Fatalf("EncodeMsg failed: %v", err) + } + writer.Flush() + + var result LimitTestData + reader := msgp.NewReader(&buf) + err = result.DecodeMsg(reader) + if err != nil { + t.Fatalf("DecodeMsg failed: %v", err) + } + + // Verify data integrity + if result.SmallArray != data.SmallArray { + t.Errorf("SmallArray mismatch: got %v, want %v", result.SmallArray, data.SmallArray) + } + if !bytes.Equal(result.LargeSlice, data.LargeSlice) { + t.Errorf("LargeSlice mismatch: got %v, want %v", result.LargeSlice, data.LargeSlice) + } + if len(result.SmallMap) != len(data.SmallMap) { + t.Errorf("SmallMap length mismatch: got %d, want %d", len(result.SmallMap), len(data.SmallMap)) + } + }) +} + +func TestMarshalLimitEnforcement(t *testing.T) { + // Test marshal-time limit enforcement with MarshalLimitTestData + // This struct has marshal:true with arrays:30 maps:20 + + t.Run("MarshalMsg_SliceLimit", func(t *testing.T) { + data := MarshalLimitTestData{ + TestSlice: make([]string, 40), // Exceeds array limit of 30 + } + // Fill the slice + for i := range data.TestSlice { + data.TestSlice[i] = "test" + } + + _, err := data.MarshalMsg(nil) + if err == nil { + t.Error("Expected error for slice exceeding marshal limit, got nil") + } + }) + + t.Run("MarshalMsg_MapLimit", func(t *testing.T) { + data := MarshalLimitTestData{ + TestMap: make(map[string]int, 25), // Exceeds map limit of 20 + } + // Fill the map + for i := 0; i < 25; i++ { + data.TestMap[fmt.Sprintf("key%d", i)] = i + } + + _, err := data.MarshalMsg(nil) + if err == nil { + t.Error("Expected error for map exceeding marshal limit, got nil") + } + }) + + t.Run("EncodeMsg_SliceLimit", func(t *testing.T) { + data := MarshalLimitTestData{ + TestSlice: make([]string, 40), // Exceeds array limit of 30 + } + // Fill the slice + for i := range data.TestSlice { + data.TestSlice[i] = "test" + } + + var buf bytes.Buffer + writer := msgp.NewWriter(&buf) + err := data.EncodeMsg(writer) + if err == nil { + t.Error("Expected error for slice exceeding marshal limit, got nil") + } + }) + + t.Run("EncodeMsg_MapLimit", func(t *testing.T) { + data := MarshalLimitTestData{ + TestMap: make(map[string]int, 25), // Exceeds map limit of 20 + } + // Fill the map + for i := 0; i < 25; i++ { + data.TestMap[fmt.Sprintf("key%d", i)] = i + } + + var buf bytes.Buffer + writer := msgp.NewWriter(&buf) + err := data.EncodeMsg(writer) + if err == nil { + t.Error("Expected error for map exceeding marshal limit, got nil") + } + }) + + t.Run("MarshalWithinLimits", func(t *testing.T) { + data := MarshalLimitTestData{ + SmallArray: [5]int{1, 2, 3, 4, 5}, + TestSlice: []string{"a", "b", "c"}, // Within limit of 30 + TestMap: map[string]int{"x": 1, "y": 2}, // Within limit of 20 + } + + // Test MarshalMsg + _, err := data.MarshalMsg(nil) + if err != nil { + t.Errorf("Unexpected error for data within marshal limits: %v", err) + } + + // Test EncodeMsg + var buf bytes.Buffer + writer := msgp.NewWriter(&buf) + err = data.EncodeMsg(writer) + if err != nil { + t.Errorf("Unexpected error for data within marshal limits: %v", err) + } + }) + + t.Run("FixedArraysNotLimited_Marshal", func(t *testing.T) { + // Fixed arrays should not be subject to marshal limits + data := MarshalLimitTestData{ + SmallArray: [5]int{1, 2, 3, 4, 5}, // Fixed array size + } + + _, err := data.MarshalMsg(nil) + if err != nil { + t.Errorf("Fixed arrays should not be limited during marshal, got error: %v", err) + } + }) +} + +func TestFieldLevelLimits(t *testing.T) { + // Test field-level limit enforcement with FieldLimitTestData + + t.Run("SmallSlice_WithinLimit", func(t *testing.T) { + data := FieldLimitTestData{ + SmallSlice: []int{1, 2, 3}, // Within limit of 5 + } + + // Marshal + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + // Unmarshal + var result FieldLimitTestData + _, err = result.UnmarshalMsg(marshaled) + if err != nil { + t.Fatalf("Unexpected unmarshal error: %v", err) + } + + if len(result.SmallSlice) != len(data.SmallSlice) { + t.Errorf("SmallSlice length mismatch: got %d, want %d", len(result.SmallSlice), len(data.SmallSlice)) + } + }) + + t.Run("SmallSlice_ExceedsLimit", func(t *testing.T) { + data := FieldLimitTestData{ + SmallSlice: []int{1, 2, 3, 4, 5, 6, 7}, // Exceeds limit of 5 + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + // Unmarshal should fail + var result FieldLimitTestData + _, err = result.UnmarshalMsg(marshaled) + if err == nil { + t.Error("Expected error for SmallSlice exceeding limit, got nil") + } + }) + + t.Run("SmallMap_WithinLimit", func(t *testing.T) { + data := FieldLimitTestData{ + SmallMap: map[string]int{"a": 1, "b": 2}, // Within limit of 3 + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + var result FieldLimitTestData + _, err = result.UnmarshalMsg(marshaled) + if err != nil { + t.Fatalf("Unexpected unmarshal error: %v", err) + } + + if len(result.SmallMap) != len(data.SmallMap) { + t.Errorf("SmallMap length mismatch: got %d, want %d", len(result.SmallMap), len(data.SmallMap)) + } + }) + + t.Run("SmallMap_ExceedsLimit", func(t *testing.T) { + data := FieldLimitTestData{ + SmallMap: map[string]int{"a": 1, "b": 2, "c": 3, "d": 4}, // Exceeds limit of 3 + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + // Unmarshal should fail + var result FieldLimitTestData + _, err = result.UnmarshalMsg(marshaled) + if err == nil { + t.Error("Expected error for SmallMap exceeding limit, got nil") + } + }) + + t.Run("FixedArrays_NotLimited_Field", func(t *testing.T) { + // Fixed arrays should not be subject to field-level limits + data := FieldLimitTestData{ + FixedArray: [10]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // Fixed size, should not be limited + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error for fixed array: %v", err) + } + + var result FieldLimitTestData + _, err = result.UnmarshalMsg(marshaled) + if err != nil { + t.Errorf("Fixed arrays should not be limited with field tags, got error: %v", err) + } + }) + + t.Run("DecodeMsg_FieldLimits", func(t *testing.T) { + data := FieldLimitTestData{ + SmallSlice: []int{1, 2, 3, 4, 5, 6}, // Exceeds limit of 5 + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + // Test DecodeMsg path + var result FieldLimitTestData + reader := msgp.NewReader(bytes.NewReader(marshaled)) + err = result.DecodeMsg(reader) + if err == nil { + t.Error("Expected error for DecodeMsg with field exceeding limit, got nil") + } + }) +} + +func TestFieldVsFileLimitPrecedence(t *testing.T) { + // Test precedence: field limits should override file limits + // limits.go has: arrays:100 maps:50 + + t.Run("TightSlice_FieldOverride", func(t *testing.T) { + // File limit: arrays:100, field limit: 10 -> field limit should apply + data := FieldOverrideTestData{ + TightSlice: make([]int, 15), // Exceeds field limit (10) but within file limit (100) + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + var result FieldOverrideTestData + _, err = result.UnmarshalMsg(marshaled) + if err == nil { + t.Error("Expected error for TightSlice exceeding field limit (10), got nil") + } + }) + + t.Run("LooseSlice_FieldOverride", func(t *testing.T) { + // File limit: arrays:100, field limit: 200 -> field limit should apply + data := FieldOverrideTestData{ + LooseSlice: make([]string, 150), // Within field limit (200) but exceeds file limit (100) + } + for i := range data.LooseSlice { + data.LooseSlice[i] = "test" + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + var result FieldOverrideTestData + _, err = result.UnmarshalMsg(marshaled) + if err != nil { + t.Errorf("Expected success for LooseSlice within field limit (200), got error: %v", err) + } + }) + + t.Run("DefaultFields_UseFileLimit", func(t *testing.T) { + // Fields without field limits should use file limits + data := FieldOverrideTestData{ + DefaultSlice: make([]byte, 120), // Exceeds file limit (100) + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + var result FieldOverrideTestData + _, err = result.UnmarshalMsg(marshaled) + if err == nil { + t.Error("Expected error for DefaultSlice exceeding file limit (100), got nil") + } + }) +} + +func TestMarshalFieldVsFileLimitPrecedence(t *testing.T) { + // Test precedence with marshal:true + // marshal_limits.go has: arrays:30 maps:20 marshal:true + + t.Run("Marshal_TightSlice_FieldOverride", func(t *testing.T) { + // File limit: arrays:30, field limit: 10 -> field limit should apply + // Note: Since I only implemented field limits for unmarshal.go, marshal limits won't work yet + data := MarshalFieldOverrideTestData{ + TightSlice: make([]int, 15), // Exceeds field limit (10) but within file limit (30) + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + var result MarshalFieldOverrideTestData + _, err = result.UnmarshalMsg(marshaled) + if err == nil { + t.Error("Expected error for TightSlice exceeding field limit (10), got nil") + } + }) + + t.Run("Marshal_DefaultFields_UseFileLimit", func(t *testing.T) { + // Fields without field limits should use file limits + data := MarshalFieldOverrideTestData{ + DefaultSlice: make([]byte, 35), // Exceeds file limit (30) + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + var result MarshalFieldOverrideTestData + _, err = result.UnmarshalMsg(marshaled) + if err == nil { + t.Error("Expected error for DefaultSlice exceeding file limit (30), got nil") + } + }) +} + +func TestAliasedTypesWithFieldLimits(t *testing.T) { + // Test field-level limits with aliased types + + t.Run("SmallAliasedSlice_WithinLimit", func(t *testing.T) { + data := AliasedFieldLimitTestData{ + SmallAliasedSlice: AliasedSlice{"a", "b"}, // Within limit of 3 + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + var result AliasedFieldLimitTestData + _, err = result.UnmarshalMsg(marshaled) + if err != nil { + t.Fatalf("Unexpected unmarshal error: %v", err) + } + + if len(result.SmallAliasedSlice) != len(data.SmallAliasedSlice) { + t.Errorf("SmallAliasedSlice length mismatch: got %d, want %d", len(result.SmallAliasedSlice), len(data.SmallAliasedSlice)) + } + }) + + t.Run("SmallAliasedSlice_ExceedsLimit", func(t *testing.T) { + data := AliasedFieldLimitTestData{ + SmallAliasedSlice: AliasedSlice{"a", "b", "c", "d", "e"}, // Exceeds limit of 3 + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + var result AliasedFieldLimitTestData + _, err = result.UnmarshalMsg(marshaled) + if err == nil { + t.Error("Expected error for SmallAliasedSlice exceeding limit (3), got nil") + } + }) + + t.Run("SmallAliasedMap_ExceedsLimit", func(t *testing.T) { + data := AliasedFieldLimitTestData{ + SmallAliasedMap: AliasedMap{ + "key1": true, + "key2": false, + "key3": true, // Exceeds limit of 2 + }, + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + var result AliasedFieldLimitTestData + _, err = result.UnmarshalMsg(marshaled) + if err == nil { + t.Error("Expected error for SmallAliasedMap exceeding limit (2), got nil") + } + }) + + t.Run("IntSliceAlias_ExceedsLimit", func(t *testing.T) { + data := AliasedFieldLimitTestData{ + IntSliceAlias: make(AliasedIntSlice, 15), // Exceeds limit of 10 + } + for i := range data.IntSliceAlias { + data.IntSliceAlias[i] = i + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + var result AliasedFieldLimitTestData + _, err = result.UnmarshalMsg(marshaled) + if err == nil { + t.Error("Expected error for IntSliceAlias exceeding limit (10), got nil") + } + }) + + t.Run("DecodeMsg_AliasedTypes", func(t *testing.T) { + data := AliasedFieldLimitTestData{ + SmallAliasedSlice: AliasedSlice{"a", "b", "c", "d"}, // Exceeds limit of 3 + } + + marshaled, err := data.MarshalMsg(nil) + if err != nil { + t.Fatalf("Unexpected marshal error: %v", err) + } + + // Test DecodeMsg path + var result AliasedFieldLimitTestData + reader := msgp.NewReader(bytes.NewReader(marshaled)) + err = result.DecodeMsg(reader) + if err == nil { + t.Error("Expected error for DecodeMsg with aliased slice exceeding limit, got nil") + } + }) +} diff --git a/_generated/marshal_limits.go b/_generated/marshal_limits.go new file mode 100644 index 00000000..bf2fea03 --- /dev/null +++ b/_generated/marshal_limits.go @@ -0,0 +1,23 @@ +//msgp:limit arrays:30 maps:20 marshal:true + +package _generated + +//go:generate msgp + +// Test structures for marshal-time limit enforcement +type MarshalLimitTestData struct { + SmallArray [5]int `msg:"small_array"` + TestSlice []string `msg:"test_slice"` + TestMap map[string]int `msg:"test_map"` +} + +// Test field limits vs file limits precedence with marshal:true +// File limits: arrays:30 maps:20 marshal:true +type MarshalFieldOverrideTestData struct { + TightSlice []int `msg:"tight_slice,limit=10"` // Field limit (10) < file limit (30) + LooseSlice []string `msg:"loose_slice,limit=50"` // Field limit (50) > file limit (30) + TightMap map[string]int `msg:"tight_map,limit=5"` // Field limit (5) < file limit (20) + LooseMap map[int]string `msg:"loose_map,limit=40"` // Field limit (40) > file limit (20) + DefaultSlice []byte `msg:"default_slice"` // No field limit, uses file limit (30) + DefaultMap map[string]byte `msg:"default_map"` // No field limit, uses file limit (20) +} diff --git a/gen/decode.go b/gen/decode.go index 07352c6e..b0d26bce 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -3,6 +3,7 @@ package gen import ( "fmt" "io" + "math" "strconv" "strings" ) @@ -74,10 +75,141 @@ func (d *decodeGen) assignAndCheck(name string, typ string) { d.p.wrapErrCheck(d.ctx.ArgsStr()) } + +func (d *decodeGen) assignArray(name string, typ string, fieldLimit uint32) { + if !d.p.ok() { + return + } + d.p.printf("\n%s, err = dc.Read%s()", name, typ) + d.p.wrapErrCheck(d.ctx.ArgsStr()) + + // Determine effective limit: field limit > context field limit > file limit + var limit uint32 + var limitName string + + if fieldLimit > 0 { + // Explicit field limit passed as parameter + limit = fieldLimit + limitName = fmt.Sprintf("%d", fieldLimit) + } else if d.ctx.currentFieldArrayLimit != math.MaxUint32 { + // Field limit from context (set during field processing) + limit = d.ctx.currentFieldArrayLimit + limitName = fmt.Sprintf("%d", d.ctx.currentFieldArrayLimit) + } else if d.ctx.arrayLimit != math.MaxUint32 { + // File-level limit + limit = d.ctx.arrayLimit + limitName = fmt.Sprintf("%slimitArrays", d.ctx.limitPrefix) + } + + if limit > 0 && limit != math.MaxUint32 { + d.p.printf("\nif %s > %s {", name, limitName) + d.p.printf("\nerr = msgp.ErrLimitExceeded") + d.p.printf("\nreturn") + d.p.printf("\n}") + } +} + +func (d *decodeGen) assignMap(name string, typ string, fieldLimit uint32) { + if !d.p.ok() { + return + } + d.p.printf("\n%s, err = dc.Read%s()", name, typ) + d.p.wrapErrCheck(d.ctx.ArgsStr()) + + // Determine effective limit: field limit > context field limit > file limit + var limit uint32 + var limitName string + + if fieldLimit > 0 { + // Explicit field limit passed as parameter + limit = fieldLimit + limitName = fmt.Sprintf("%d", fieldLimit) + } else if d.ctx.currentFieldMapLimit != math.MaxUint32 { + // Field limit from context (set during field processing) + limit = d.ctx.currentFieldMapLimit + limitName = fmt.Sprintf("%d", d.ctx.currentFieldMapLimit) + } else if d.ctx.mapLimit != math.MaxUint32 { + // File-level limit + limit = d.ctx.mapLimit + limitName = fmt.Sprintf("%slimitMaps", d.ctx.limitPrefix) + } + + if limit > 0 && limit != math.MaxUint32 { + d.p.printf("\nif %s > %s {", name, limitName) + d.p.printf("\nerr = msgp.ErrLimitExceeded") + d.p.printf("\nreturn") + d.p.printf("\n}") + } +} + +func (d *decodeGen) readBytesWithLimit(vname, checkNil string, fieldLimit uint32) { + if !d.p.ok() { + return + } + d.p.printf("\n%s, err = dc.ReadBytes(%s)", vname, vname) + d.p.wrapErrCheck(d.ctx.ArgsStr()) + + // Determine effective limit: field limit > context field limit > file limit + var limit uint32 + var limitName string + + if fieldLimit > 0 { + // Explicit field limit passed as parameter + limit = fieldLimit + limitName = fmt.Sprintf("%d", fieldLimit) + } else if d.ctx.currentFieldArrayLimit != math.MaxUint32 { + // Field limit from context (set during field processing) + limit = d.ctx.currentFieldArrayLimit + limitName = fmt.Sprintf("%d", d.ctx.currentFieldArrayLimit) + } else if d.ctx.arrayLimit != math.MaxUint32 { + // File-level limit + limit = d.ctx.arrayLimit + limitName = fmt.Sprintf("%slimitArrays", d.ctx.limitPrefix) + } + + if limit > 0 && limit != math.MaxUint32 { + d.p.printf("\nif uint32(len(%s)) > %s {", checkNil, limitName) + d.p.printf("\nerr = msgp.ErrLimitExceeded") + d.p.printf("\nreturn") + d.p.printf("\n}") + } +} + +func (d *decodeGen) checkByteLimits(vname string, fieldLimit uint32) { + if !d.p.ok() { + return + } + + // Determine effective limit: field limit > context field limit > file limit + var limit uint32 + var limitName string + + if fieldLimit > 0 { + // Explicit field limit passed as parameter + limit = fieldLimit + limitName = fmt.Sprintf("%d", fieldLimit) + } else if d.ctx.currentFieldArrayLimit != math.MaxUint32 { + // Field limit from context (set during field processing) + limit = d.ctx.currentFieldArrayLimit + limitName = fmt.Sprintf("%d", d.ctx.currentFieldArrayLimit) + } else if d.ctx.arrayLimit != math.MaxUint32 { + // File-level limit + limit = d.ctx.arrayLimit + limitName = fmt.Sprintf("%slimitArrays", d.ctx.limitPrefix) + } + + if limit > 0 && limit != math.MaxUint32 { + d.p.printf("\nif uint32(len(%s)) > %s {", vname, limitName) + d.p.printf("\nerr = msgp.ErrLimitExceeded") + d.p.printf("\nreturn") + d.p.printf("\n}") + } +} + func (d *decodeGen) structAsTuple(s *Struct) { sz := randIdent() d.p.declare(sz, u32) - d.assignAndCheck(sz, arrayHeader) + d.assignArray(sz, arrayHeader, 0) if s.AsVarTuple { d.p.printf("\nif %[1]s == 0 { return }", sz) } else { @@ -89,6 +221,15 @@ func (d *decodeGen) structAsTuple(s *Struct) { } fieldElem := s.Fields[i].FieldElem anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil() + + // Set field-specific limits in context based on struct field's FieldLimit + if s.Fields[i].FieldLimit > 0 { + // Apply same limit to both arrays and maps for this field + d.ctx.SetFieldLimits(s.Fields[i].FieldLimit, s.Fields[i].FieldLimit) + } else { + d.ctx.ClearFieldLimits() + } + if anField { d.p.print("\nif dc.IsNil() {") d.p.print("\nerr = dc.ReadNil()") @@ -99,6 +240,10 @@ func (d *decodeGen) structAsTuple(s *Struct) { d.ctx.PushString(s.Fields[i].FieldName) setTypeParams(fieldElem, s.typeParams) next(d, fieldElem) + + // Clear field limits after processing + d.ctx.ClearFieldLimits() + d.ctx.Pop() if anField { d.p.printf("\n}") // close if statement @@ -116,7 +261,7 @@ func (d *decodeGen) structAsMap(s *Struct) { d.needsField() sz := randIdent() d.p.declare(sz, u32) - d.assignAndCheck(sz, mapHeader) + d.assignMap(sz, mapHeader, 0) oeCount := s.CountFieldTagPart("omitempty") + s.CountFieldTagPart("omitzero") if !d.ctx.clearOmitted { @@ -142,6 +287,15 @@ func (d *decodeGen) structAsMap(s *Struct) { d.p.printf("\ncase %q:", s.Fields[i].FieldTag) fieldElem := s.Fields[i].FieldElem anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil() + + // Set field-specific limits in context based on struct field's FieldLimit + if s.Fields[i].FieldLimit > 0 { + // Apply same limit to both arrays and maps for this field + d.ctx.SetFieldLimits(s.Fields[i].FieldLimit, s.Fields[i].FieldLimit) + } else { + d.ctx.ClearFieldLimits() + } + if anField { d.p.print("\nif dc.IsNil() {") d.p.print("\nerr = dc.ReadNil()") @@ -151,6 +305,10 @@ func (d *decodeGen) structAsMap(s *Struct) { SetIsAllowNil(fieldElem, anField) setTypeParams(fieldElem, s.typeParams) next(d, fieldElem) + + // Clear field limits after processing + d.ctx.ClearFieldLimits() + if oeCount > 0 && (s.Fields[i].HasTagPart("omitempty") || s.Fields[i].HasTagPart("omitzero")) { d.p.printf("\n%s", bm.setStmt(len(oeEmittedIdx))) oeEmittedIdx = append(oeEmittedIdx, i) @@ -215,9 +373,12 @@ func (d *decodeGen) gBase(b *BaseElem) { if b.Convert { lowered := b.ToBase() + "(" + vname + ")" d.p.printf("\n%s, err = dc.ReadBytes(%s)", tmp, lowered) + d.p.wrapErrCheck(d.ctx.ArgsStr()) checkNil = tmp + // Check byte slice limits after reading + d.checkByteLimits(tmp, 0) } else { - d.p.printf("\n%s, err = dc.ReadBytes(%s)", vname, vname) + d.readBytesWithLimit(vname, vname, 0) checkNil = vname } case IDENT: @@ -281,7 +442,7 @@ func (d *decodeGen) gMap(m *Map) { // resize or allocate map d.p.declare(sz, u32) - d.assignAndCheck(sz, mapHeader) + d.assignMap(sz, mapHeader, 0) d.p.resizeMap(sz, m) // for element in map, read string/value @@ -305,7 +466,7 @@ func (d *decodeGen) gSlice(s *Slice) { } sz := randIdent() d.p.declare(sz, u32) - d.assignAndCheck(sz, arrayHeader) + d.assignArray(sz, arrayHeader, 0) if s.isAllowNil { d.p.resizeSliceNoNil(sz, s) } else { diff --git a/gen/elem.go b/gen/elem.go index bcf3c0e6..3ce38d9c 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -595,6 +595,7 @@ type StructField struct { RawTag string // the full struct tag FieldName string // the name of the struct field FieldElem Elem // the field type + FieldLimit uint32 // field-specific size limit for slices/maps (0 = no limit) } // HasTagPart returns true if the specified tag part (option) is present. @@ -605,6 +606,21 @@ func (sf *StructField) HasTagPart(pname string) bool { return slices.Contains(sf.FieldTagParts[1:], pname) } +// GetTagValue returns the value for a tag part with the format "key=value". +// Returns the value string and true if found, empty string and false if not found. +func (sf *StructField) GetTagValue(key string) (string, bool) { + if len(sf.FieldTagParts) < 2 { + return "", false + } + prefix := key + "=" + for _, part := range sf.FieldTagParts[1:] { + if strings.HasPrefix(part, prefix) { + return strings.TrimPrefix(part, prefix), true + } + } + return "", false +} + type ShimMode int const ( diff --git a/gen/encode.go b/gen/encode.go index 2b8bd7e8..5d42cdec 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -3,6 +3,7 @@ package gen import ( "fmt" "io" + "math" "strings" "github.com/tinylib/msgp/msgp" @@ -39,6 +40,26 @@ func (e *encodeGen) writeAndCheck(typ string, argfmt string, arg any) { e.p.wrapErrCheck(e.ctx.ArgsStr()) } +func (e *encodeGen) writeAndCheckWithArrayLimit(typ string, argfmt string, arg any) { + e.writeAndCheck(typ, argfmt, arg) + if e.ctx.marshalLimits && e.ctx.arrayLimit != math.MaxUint32 { + e.p.printf("\nif %s > %slimitArrays {", fmt.Sprintf(argfmt, arg), e.ctx.limitPrefix) + e.p.printf("\nerr = msgp.ErrLimitExceeded") + e.p.printf("\nreturn") + e.p.printf("\n}") + } +} + +func (e *encodeGen) writeAndCheckWithMapLimit(typ string, argfmt string, arg any) { + e.writeAndCheck(typ, argfmt, arg) + if e.ctx.marshalLimits && e.ctx.mapLimit != math.MaxUint32 { + e.p.printf("\nif %s > %slimitMaps {", fmt.Sprintf(argfmt, arg), e.ctx.limitPrefix) + e.p.printf("\nerr = msgp.ErrLimitExceeded") + e.p.printf("\nreturn") + e.p.printf("\n}") + } +} + func (e *encodeGen) fuseHook() { if len(e.fuse) > 0 { e.appendraw(e.fuse) @@ -244,7 +265,7 @@ func (e *encodeGen) gMap(m *Map) { } e.fuseHook() vname := m.Varname() - e.writeAndCheck(mapHeader, lenAsUint32, vname) + e.writeAndCheckWithMapLimit(mapHeader, lenAsUint32, vname) e.p.printf("\nfor %s, %s := range %s {", m.Keyidx, m.Validx, vname) if m.Key != nil { @@ -297,7 +318,7 @@ func (e *encodeGen) gSlice(s *Slice) { return } e.fuseHook() - e.writeAndCheck(arrayHeader, lenAsUint32, s.Varname()) + e.writeAndCheckWithArrayLimit(arrayHeader, lenAsUint32, s.Varname()) setTypeParams(s.Els, s.typeParams) e.p.rangeBlock(e.ctx, s.Index, s.Varname(), e, s.Els) } diff --git a/gen/marshal.go b/gen/marshal.go index e68464ad..41fe111d 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -3,6 +3,7 @@ package gen import ( "fmt" "io" + "math" "strings" "github.com/tinylib/msgp/msgp" @@ -73,6 +74,24 @@ func (m *marshalGen) rawAppend(typ string, argfmt string, arg any) { m.p.printf("\no = msgp.Append%s(o, %s)", typ, fmt.Sprintf(argfmt, arg)) } +func (m *marshalGen) rawAppendWithArrayLimit(typ string, argfmt string, arg any) { + m.rawAppend(typ, argfmt, arg) + if m.ctx.marshalLimits && m.ctx.arrayLimit != math.MaxUint32 { + m.p.printf("\nif %s > %slimitArrays {", fmt.Sprintf(argfmt, arg), m.ctx.limitPrefix) + m.p.printf("\nreturn nil, msgp.ErrLimitExceeded") + m.p.printf("\n}") + } +} + +func (m *marshalGen) rawAppendWithMapLimit(typ string, argfmt string, arg any) { + m.rawAppend(typ, argfmt, arg) + if m.ctx.marshalLimits && m.ctx.mapLimit != math.MaxUint32 { + m.p.printf("\nif %s > %slimitMaps {", fmt.Sprintf(argfmt, arg), m.ctx.limitPrefix) + m.p.printf("\nreturn nil, msgp.ErrLimitExceeded") + m.p.printf("\n}") + } +} + func (m *marshalGen) fuseHook() { if len(m.fuse) > 0 { m.rawbytes(m.fuse) @@ -250,7 +269,7 @@ func (m *marshalGen) gMap(s *Map) { } m.fuseHook() vname := s.Varname() - m.rawAppend(mapHeader, lenAsUint32, vname) + m.rawAppendWithMapLimit(mapHeader, lenAsUint32, vname) m.p.printf("\nfor %s, %s := range %s {", s.Keyidx, s.Validx, vname) // Shim key to base type if necessary. if s.Key != nil { @@ -292,7 +311,7 @@ func (m *marshalGen) gSlice(s *Slice) { vname := s.Varname() setTypeParams(s.Els, s.typeParams) - m.rawAppend(arrayHeader, lenAsUint32, vname) + m.rawAppendWithArrayLimit(arrayHeader, lenAsUint32, vname) m.p.rangeBlock(m.ctx, s.Index, vname, m, s.Els) } diff --git a/gen/spec.go b/gen/spec.go index 4ff4fbee..8dc2c83c 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -4,6 +4,7 @@ import ( "bytes" "fmt" "io" + "math" "strings" ) @@ -81,6 +82,10 @@ type Printer struct { ClearOmitted bool NewTime bool AsUTC bool + ArrayLimit uint32 + MapLimit uint32 + MarshalLimits bool + LimitPrefix string } func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer { @@ -151,10 +156,16 @@ func (p *Printer) Print(e Elem) error { // hence the separate prefixes. resetIdent("zb") err := g.Execute(e, Context{ - compFloats: p.CompactFloats, - clearOmitted: p.ClearOmitted, - newTime: p.NewTime, - asUTC: p.AsUTC, + compFloats: p.CompactFloats, + clearOmitted: p.ClearOmitted, + newTime: p.NewTime, + asUTC: p.AsUTC, + arrayLimit: p.ArrayLimit, + mapLimit: p.MapLimit, + marshalLimits: p.MarshalLimits, + limitPrefix: p.LimitPrefix, + currentFieldArrayLimit: math.MaxUint32, // Initialize to "no field limit" + currentFieldMapLimit: math.MaxUint32, // Initialize to "no field limit" }) resetIdent("za") @@ -182,11 +193,17 @@ func (c contextVar) Arg() string { } type Context struct { - path []contextItem - compFloats bool - clearOmitted bool - newTime bool - asUTC bool + path []contextItem + compFloats bool + clearOmitted bool + newTime bool + asUTC bool + arrayLimit uint32 + mapLimit uint32 + marshalLimits bool + limitPrefix string + currentFieldArrayLimit uint32 // Current field's array limit (0 = no field-level limit) + currentFieldMapLimit uint32 // Current field's map limit (0 = no field-level limit) } func (c *Context) PushString(s string) { @@ -201,6 +218,18 @@ func (c *Context) Pop() { c.path = c.path[:len(c.path)-1] } +// SetFieldLimits sets the current field-specific limits for the context +func (c *Context) SetFieldLimits(arrayLimit, mapLimit uint32) { + c.currentFieldArrayLimit = arrayLimit + c.currentFieldMapLimit = mapLimit +} + +// ClearFieldLimits clears the current field-specific limits (use file limits) +func (c *Context) ClearFieldLimits() { + c.currentFieldArrayLimit = math.MaxUint32 + c.currentFieldMapLimit = math.MaxUint32 +} + func (c *Context) ArgsStr() string { var out string for idx, p := range c.path { diff --git a/gen/unmarshal.go b/gen/unmarshal.go index f1d6bfce..f2fd9b11 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -3,6 +3,7 @@ package gen import ( "fmt" "io" + "math" "strconv" "strings" ) @@ -63,6 +64,120 @@ func (u *unmarshalGen) assignAndCheck(name string, base string) { u.p.wrapErrCheck(u.ctx.ArgsStr()) } + +func (u *unmarshalGen) assignArray(name string, base string, fieldLimit uint32) { + if !u.p.ok() { + return + } + u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", name, base) + u.p.wrapErrCheck(u.ctx.ArgsStr()) + + // Determine effective limit: field limit > context field limit > file limit + var limit uint32 + var limitName string + + if fieldLimit > 0 { + // Explicit field limit passed as parameter + limit = fieldLimit + limitName = fmt.Sprintf("%d", fieldLimit) + } else if u.ctx.currentFieldArrayLimit != math.MaxUint32 { + // Field limit from context (set during field processing) + limit = u.ctx.currentFieldArrayLimit + limitName = fmt.Sprintf("%d", u.ctx.currentFieldArrayLimit) + } else if u.ctx.arrayLimit != math.MaxUint32 { + // File-level limit + limit = u.ctx.arrayLimit + limitName = fmt.Sprintf("%slimitArrays", u.ctx.limitPrefix) + } + + if limit > 0 && limit != math.MaxUint32 { + u.p.printf("\nif %s > %s {", name, limitName) + u.p.printf("\nerr = msgp.ErrLimitExceeded") + u.p.printf("\nreturn") + u.p.printf("\n}") + } +} + +func (u *unmarshalGen) assignMap(name string, base string, fieldLimit uint32) { + if !u.p.ok() { + return + } + u.p.printf("\n%s, bts, err = msgp.Read%sBytes(bts)", name, base) + u.p.wrapErrCheck(u.ctx.ArgsStr()) + + // Determine effective limit: field limit > context field limit > file limit + var limit uint32 + var limitName string + + if fieldLimit > 0 { + // Explicit field limit passed as parameter + limit = fieldLimit + limitName = fmt.Sprintf("%d", fieldLimit) + } else if u.ctx.currentFieldMapLimit != math.MaxUint32 { + // Field limit from context (set during field processing) + limit = u.ctx.currentFieldMapLimit + limitName = fmt.Sprintf("%d", u.ctx.currentFieldMapLimit) + } else if u.ctx.mapLimit != math.MaxUint32 { + // File-level limit + limit = u.ctx.mapLimit + limitName = fmt.Sprintf("%slimitMaps", u.ctx.limitPrefix) + } + + if limit > 0 && limit != math.MaxUint32 { + u.p.printf("\nif %s > %s {", name, limitName) + u.p.printf("\nerr = msgp.ErrLimitExceeded") + u.p.printf("\nreturn") + u.p.printf("\n}") + } +} + +func (u *unmarshalGen) readBytesWithLimit(refname, lowered string, zerocopy bool, fieldLimit uint32) { + if !u.p.ok() { + return + } + + if zerocopy { + u.p.printf("\n%s, bts, err = msgp.ReadBytesZC(bts)", refname) + } else { + u.p.printf("\n%s, bts, err = msgp.ReadBytesBytes(bts, %s)", refname, lowered) + } + u.p.wrapErrCheck(u.ctx.ArgsStr()) + + // Check byte slice limits after reading + u.checkByteLimits(refname, fieldLimit) +} + +func (u *unmarshalGen) checkByteLimits(vname string, fieldLimit uint32) { + if !u.p.ok() { + return + } + + // Determine effective limit: field limit > context field limit > file limit + var limit uint32 + var limitName string + + if fieldLimit > 0 { + // Explicit field limit passed as parameter + limit = fieldLimit + limitName = fmt.Sprintf("%d", fieldLimit) + } else if u.ctx.currentFieldArrayLimit != math.MaxUint32 { + // Field limit from context (set during field processing) + limit = u.ctx.currentFieldArrayLimit + limitName = fmt.Sprintf("%d", u.ctx.currentFieldArrayLimit) + } else if u.ctx.arrayLimit != math.MaxUint32 { + // File-level limit + limit = u.ctx.arrayLimit + limitName = fmt.Sprintf("%slimitArrays", u.ctx.limitPrefix) + } + + if limit > 0 && limit != math.MaxUint32 { + u.p.printf("\nif uint32(len(%s)) > %s {", vname, limitName) + u.p.printf("\nerr = msgp.ErrLimitExceeded") + u.p.printf("\nreturn") + u.p.printf("\n}") + } +} + func (u *unmarshalGen) gStruct(s *Struct) { if !u.p.ok() { return @@ -91,6 +206,15 @@ func (u *unmarshalGen) tuple(s *Struct) { u.ctx.PushString(s.Fields[i].FieldName) fieldElem := s.Fields[i].FieldElem anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil() + + // Set field-specific limits in context based on struct field's FieldLimit + if s.Fields[i].FieldLimit > 0 { + // Apply same limit to both arrays and maps for this field + u.ctx.SetFieldLimits(s.Fields[i].FieldLimit, s.Fields[i].FieldLimit) + } else { + u.ctx.ClearFieldLimits() + } + if anField { u.p.printf("\nif msgp.IsNil(bts) {\nbts = bts[1:]\n%s = nil\n} else {", fieldElem.Varname()) } @@ -100,6 +224,10 @@ func (u *unmarshalGen) tuple(s *Struct) { } setTypeParams(fieldElem, s.typeParams) next(u, fieldElem) + + // Clear field limits after processing + u.ctx.ClearFieldLimits() + if s.Fields[i].HasTagPart("zerocopy") { setRecursiveZC(fieldElem, false) } @@ -137,7 +265,7 @@ func (u *unmarshalGen) mapstruct(s *Struct) { u.needsField() sz := randIdent() u.p.declare(sz, u32) - u.assignAndCheck(sz, mapHeader) + u.assignMap(sz, mapHeader, 0) oeCount := s.CountFieldTagPart("omitempty") + s.CountFieldTagPart("omitzero") if !u.ctx.clearOmitted { @@ -168,6 +296,15 @@ func (u *unmarshalGen) mapstruct(s *Struct) { fieldElem := s.Fields[i].FieldElem anField := s.Fields[i].HasTagPart("allownil") && fieldElem.AllowNil() + + // Set field-specific limits in context based on struct field's FieldLimit + if s.Fields[i].FieldLimit > 0 { + // Apply same limit to both arrays and maps for this field + u.ctx.SetFieldLimits(s.Fields[i].FieldLimit, s.Fields[i].FieldLimit) + } else { + u.ctx.ClearFieldLimits() + } + if anField { u.p.printf("\nif msgp.IsNil(bts) {\nbts = bts[1:]\n%s = nil\n} else {", fieldElem.Varname()) } @@ -178,6 +315,10 @@ func (u *unmarshalGen) mapstruct(s *Struct) { setTypeParams(fieldElem, s.typeParams) next(u, fieldElem) + + // Clear field limits after processing + u.ctx.ClearFieldLimits() + if s.Fields[i].HasTagPart("zerocopy") { setRecursiveZC(fieldElem, false) } @@ -232,11 +373,7 @@ func (u *unmarshalGen) gBase(b *BaseElem) { switch b.Value { case Bytes: - if b.zerocopy { - u.p.printf("\n%s, bts, err = msgp.ReadBytesZC(bts)", refname) - } else { - u.p.printf("\n%s, bts, err = msgp.ReadBytesBytes(bts, %s)", refname, lowered) - } + u.readBytesWithLimit(refname, lowered, b.zerocopy, 0) case Ext: u.p.printf("\nbts, err = msgp.ReadExtensionBytes(bts, %s)", lowered) case IDENT: @@ -314,7 +451,7 @@ func (u *unmarshalGen) gSlice(s *Slice) { } sz := randIdent() u.p.declare(sz, u32) - u.assignAndCheck(sz, arrayHeader) + u.assignArray(sz, arrayHeader, 0) if s.isAllowNil { u.p.resizeSliceNoNil(sz, s) } else { @@ -330,7 +467,7 @@ func (u *unmarshalGen) gMap(m *Map) { } sz := randIdent() u.p.declare(sz, u32) - u.assignAndCheck(sz, mapHeader) + u.assignMap(sz, mapHeader, 0) // allocate or clear map u.p.resizeMap(sz, m) diff --git a/parse/directives.go b/parse/directives.go index dfe56099..09474a6f 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -32,6 +32,7 @@ var directives = map[string]directive{ "clearomitted": clearomitted, "newtime": newtime, "timezone": newtimezone, + "limit": limit, } // map of all recognized directives which will be applied @@ -315,3 +316,37 @@ func newtimezone(text []string, f *FileSet) error { infof("using timezone %q\n", text[1]) return nil } + +//msgp:limit arrays:n maps:n marshal:true/false +func limit(text []string, f *FileSet) (err error) { + for _, arg := range text[1:] { + arg = strings.ToLower(strings.TrimSpace(arg)) + switch { + case strings.HasPrefix(arg, "arrays:"): + limitStr := strings.TrimPrefix(arg, "arrays:") + limit, err := strconv.ParseUint(limitStr, 10, 32) + if err != nil { + return fmt.Errorf("invalid arrays limit; found %s, expected positive integer", limitStr) + } + f.ArrayLimit = uint32(limit) + case strings.HasPrefix(arg, "maps:"): + limitStr := strings.TrimPrefix(arg, "maps:") + limit, err := strconv.ParseUint(limitStr, 10, 32) + if err != nil { + return fmt.Errorf("invalid maps limit; found %s, expected positive integer", limitStr) + } + f.MapLimit = uint32(limit) + case strings.HasPrefix(arg, "marshal:"): + marshalStr := strings.TrimPrefix(arg, "marshal:") + marshal, err := strconv.ParseBool(marshalStr) + if err != nil { + return fmt.Errorf("invalid marshal option; found %s, expected 'true' or 'false'", marshalStr) + } + f.MarshalLimits = marshal + default: + return fmt.Errorf("invalid limit directive; found %s, expected 'arrays:n', 'maps:n', or 'marshal:true/false'", arg) + } + } + infof("limits - arrays:%d maps:%d marshal:%t\n", f.ArrayLimit, f.MapLimit, f.MarshalLimits) + return nil +} diff --git a/parse/getast.go b/parse/getast.go index cdc91699..15f4cccc 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -5,9 +5,11 @@ import ( "go/ast" "go/parser" "go/token" + "math" "os" "reflect" "sort" + "strconv" "strings" "github.com/tinylib/msgp/gen" @@ -36,6 +38,10 @@ type FileSet struct { AllowMapShims bool // Allow map keys to be shimmed (default true) AllowBinMaps bool // Allow maps with binary keys to be used (default false) AutoMapShims bool // Automatically shim map keys of builtin types(default false) + ArrayLimit uint32 // Maximum array/slice size allowed during deserialization + MapLimit uint32 // Maximum map size allowed during deserialization + MarshalLimits bool // Whether to enforce limits during marshaling + LimitPrefix string // Unique prefix for limit constants to avoid collisions tagName string // tag to read field names from pointerRcv bool // generate with pointer receivers. @@ -55,6 +61,8 @@ func File(name string, unexported bool, directives []string) (*FileSet, error) { TypeInfos: make(map[string]*TypeInfo), Identities: make(map[string]gen.Elem), Directives: append([]string{}, directives...), + ArrayLimit: math.MaxUint32, + MapLimit: math.MaxUint32, } fset := token.NewFileSet() @@ -410,6 +418,10 @@ loop: p.ClearOmitted = fs.ClearOmitted p.NewTime = fs.NewTime p.AsUTC = fs.AsUTC + p.ArrayLimit = fs.ArrayLimit + p.MapLimit = fs.MapLimit + p.MarshalLimits = fs.MarshalLimits + p.LimitPrefix = fs.LimitPrefix } func (fs *FileSet) PrintTo(p *gen.Printer) error { @@ -523,11 +535,23 @@ func (fs *FileSet) getField(f *ast.Field) []gen.StructField { } tags := strings.Split(body, ",") if len(tags) >= 2 { - switch tags[1] { - case "extension": - extension = true - case "flatten": - flatten = true + for _, tag := range tags[1:] { + switch tag { + case "extension": + extension = true + case "flatten": + flatten = true + default: + // Check for limit=N format + if strings.HasPrefix(tag, "limit=") { + limitStr := strings.TrimPrefix(tag, "limit=") + if limit, err := strconv.ParseUint(limitStr, 10, 32); err == nil { + sf[0].FieldLimit = uint32(limit) + } else { + warnf("invalid limit value in field tag: %s", limitStr) + } + } + } } } // ignore "-" fields diff --git a/printer/print.go b/printer/print.go index a2b48472..16971bf9 100644 --- a/printer/print.go +++ b/printer/print.go @@ -3,8 +3,11 @@ package printer import ( "bytes" "fmt" + "hash/crc32" "io" + "math" "os" + "path/filepath" "strings" "github.com/tinylib/msgp/gen" @@ -18,7 +21,7 @@ var Logf func(s string, v ...any) // of elements to the given file name and canonical // package path. func PrintFile(file string, f *parse.FileSet, mode gen.Method) error { - out, tests, err := generate(f, mode) + out, tests, err := generate(file, f, mode) if err != nil { return err } @@ -83,7 +86,7 @@ func dedupImports(imp []string) []string { return r } -func generate(f *parse.FileSet, mode gen.Method) (*bytes.Buffer, *bytes.Buffer, error) { +func generate(file string, f *parse.FileSet, mode gen.Method) (*bytes.Buffer, *bytes.Buffer, error) { outbuf := bytes.NewBuffer(make([]byte, 0, 4096)) writePkgHeader(outbuf, f.Package) @@ -99,6 +102,8 @@ func generate(f *parse.FileSet, mode gen.Method) (*bytes.Buffer, *bytes.Buffer, dedup := dedupImports(myImports) writeImportHeader(outbuf, dedup...) + writeLimitConstants(outbuf, file, f) + var testbuf *bytes.Buffer var testwr io.Writer if mode&gen.Test == gen.Test { @@ -136,3 +141,28 @@ func writeImportHeader(b *bytes.Buffer, imports ...string) { } b.WriteString(")\n\n") } + +// generateFilePrefix creates a deterministic, unique prefix for constants based on the file name +func generateFilePrefix(filename string) string { + base := filepath.Base(filename) + hash := crc32.ChecksumIEEE([]byte(base)) + return fmt.Sprintf("z%08x", hash) +} + +func writeLimitConstants(b *bytes.Buffer, file string, f *parse.FileSet) { + if f.ArrayLimit != math.MaxUint32 || f.MapLimit != math.MaxUint32 { + prefix := generateFilePrefix(file) + b.WriteString("// Size limits for msgp deserialization\n") + b.WriteString("const (\n") + if f.ArrayLimit != math.MaxUint32 { + fmt.Fprintf(b, "\t%slimitArrays = %d\n", prefix, f.ArrayLimit) + } + if f.MapLimit != math.MaxUint32 { + fmt.Fprintf(b, "\t%slimitMaps = %d\n", prefix, f.MapLimit) + } + b.WriteString(")\n\n") + + // Store the prefix in FileSet so generators can use it + f.LimitPrefix = prefix + } +}