diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 49b959a0..3b46bde0 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,7 +21,7 @@ jobs: test: strategy: matrix: - go-version: [1.23.x, 1.24.x, 1.25.x] + go-version: [1.24.x, 1.25.x] os: [ubuntu-latest] runs-on: ${{ matrix.os }} timeout-minutes: 10 diff --git a/_generated/binary_marshaler.go b/_generated/binary_marshaler.go new file mode 100644 index 00000000..8b2eadeb --- /dev/null +++ b/_generated/binary_marshaler.go @@ -0,0 +1,289 @@ +package _generated + +import ( + "encoding" + "fmt" +) + +//go:generate msgp -v + +// BinaryTestType implements encoding.BinaryMarshaler and encoding.BinaryUnmarshaler +type BinaryTestType struct { + Value string +} + +func (t *BinaryTestType) MarshalBinary() ([]byte, error) { + return []byte(t.Value), nil +} + +func (t *BinaryTestType) UnmarshalBinary(data []byte) error { + t.Value = string(data) + return nil +} + +// Verify it implements the interfaces +var _ encoding.BinaryMarshaler = (*BinaryTestType)(nil) +var _ encoding.BinaryUnmarshaler = (*BinaryTestType)(nil) + +//msgp:binmarshal BinaryTestType + +// TextBinTestType implements encoding.TextMarshaler and encoding.TextUnmarshaler +type TextBinTestType struct { + Value string +} + +func (t *TextBinTestType) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("text:%s", t.Value)), nil +} + +func (t *TextBinTestType) UnmarshalText(data []byte) error { + t.Value = string(data[5:]) // Remove "text:" prefix + return nil +} + +var _ encoding.TextMarshaler = (*TextBinTestType)(nil) +var _ encoding.TextUnmarshaler = (*TextBinTestType)(nil) + +//msgp:textmarshal TextBinTestType + +// TextStringTestType for testing as:string option +type TextStringTestType struct { + Value string +} + +func (t *TextStringTestType) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("stringtext:%s", t.Value)), nil +} + +func (t *TextStringTestType) UnmarshalText(data []byte) error { + t.Value = string(data[11:]) // Remove "stringtext:" prefix + return nil +} + +var _ encoding.TextMarshaler = (*TextStringTestType)(nil) +var _ encoding.TextUnmarshaler = (*TextStringTestType)(nil) + +//msgp:textmarshal as:string TextStringTestType + +// TestStruct contains various combinations of marshaler types +type TestStruct struct { + // Direct values + BinaryValue BinaryTestType `msg:"bin_val"` + TextBinValue TextBinTestType `msg:"text_bin_val"` + TextStringValue TextStringTestType `msg:"text_str_val"` + + // Pointers + BinaryPtr *BinaryTestType `msg:"bin_ptr"` + TextBinPtr *TextBinTestType `msg:"text_bin_ptr,omitempty"` + TextStringPtr *TextStringTestType `msg:"text_str_ptr,omitempty"` + + // Slices + BinarySlice []BinaryTestType `msg:"bin_slice"` + TextBinSlice []TextBinTestType `msg:"text_bin_slice"` + TextStringSlice []TextStringTestType `msg:"text_str_slice"` + + // Arrays + BinaryArray [3]BinaryTestType `msg:"bin_array"` + TextBinArray [2]TextBinTestType `msg:"text_bin_array"` + TextStringArray [4]TextStringTestType `msg:"text_str_array"` + + // Maps with marshaler types as values + BinaryMap map[string]BinaryTestType `msg:"bin_map"` + TextBinMap map[string]TextBinTestType `msg:"text_bin_map"` + TextStringMap map[string]TextStringTestType `msg:"text_str_map"` + + // Nested pointers and slices + NestedPtrSlice []*BinaryTestType `msg:"nested_ptr_slice"` + SliceOfArrays [][2]TextBinTestType `msg:"slice_of_arrays"` + MapOfSlices map[string][]BinaryTestType `msg:"map_of_slices"` +} + +//msgp:binmarshal ErrorTestType + +// ErrorTestType for testing error conditions +type ErrorTestType struct { + ShouldError bool +} + +func (e *ErrorTestType) MarshalBinary() ([]byte, error) { + if e.ShouldError { + return nil, fmt.Errorf("intentional marshal error") + } + return []byte("ok"), nil +} + +func (e *ErrorTestType) UnmarshalBinary(data []byte) error { + if string(data) == "error" { + return fmt.Errorf("intentional unmarshal error") + } + e.ShouldError = false + return nil +} + +// Test types for as:string positioning flexibility + +// TestTextMarshalerStringMiddle for testing as:string in middle of type list +type TestTextMarshalerStringMiddle struct { + Value string +} + +func (t *TestTextMarshalerStringMiddle) MarshalText() ([]byte, error) { + return []byte("middle:" + t.Value), nil +} + +func (t *TestTextMarshalerStringMiddle) UnmarshalText(text []byte) error { + t.Value = string(text) + return nil +} + +// TestTextMarshalerStringEnd for testing as:string at end of type list +type TestTextMarshalerStringEnd struct { + Value string +} + +func (t *TestTextMarshalerStringEnd) MarshalText() ([]byte, error) { + return []byte("end:" + t.Value), nil +} + +func (t *TestTextMarshalerStringEnd) UnmarshalText(text []byte) error { + t.Value = string(text) + return nil +} + +// TestTextAppenderStringPos for testing textappend with as:string positioning +type TestTextAppenderStringPos struct { + Value string +} + +func (t *TestTextAppenderStringPos) AppendText(dst []byte) ([]byte, error) { + return append(dst, []byte("append:"+t.Value)...), nil +} + +func (t *TestTextAppenderStringPos) UnmarshalText(text []byte) error { + t.Value = string(text) + return nil +} + +// BinaryAppenderType implements encoding.BinaryAppender (Go 1.22+) +type BinaryAppenderType struct { + Value string +} + +func (t *BinaryAppenderType) AppendBinary(dst []byte) ([]byte, error) { + return append(dst, []byte("binappend:"+t.Value)...), nil +} + +func (t *BinaryAppenderType) UnmarshalBinary(data []byte) error { + t.Value = string(data) + return nil +} + +// TextAppenderBinType implements encoding.TextAppender (stored as binary) +type TextAppenderBinType struct { + Value string +} + +func (t *TextAppenderBinType) AppendText(dst []byte) ([]byte, error) { + return append(dst, []byte("textbin:"+t.Value)...), nil +} + +func (t *TextAppenderBinType) UnmarshalText(text []byte) error { + t.Value = string(text) + return nil +} + +// ErrorBinaryAppenderType for testing error conditions with BinaryAppender +type ErrorBinaryAppenderType struct { + ShouldError bool + Value string +} + +func (e *ErrorBinaryAppenderType) AppendBinary(dst []byte) ([]byte, error) { + if e.ShouldError { + return nil, fmt.Errorf("intentional append binary error") + } + return append(dst, []byte("ok")...), nil +} + +func (e *ErrorBinaryAppenderType) UnmarshalBinary(data []byte) error { + if string(data) == "error" { + return fmt.Errorf("intentional unmarshal binary error") + } + e.ShouldError = false + e.Value = string(data) + return nil +} + +// ErrorTextAppenderType for testing error conditions with TextAppender +type ErrorTextAppenderType struct { + ShouldError bool + Value string +} + +func (e *ErrorTextAppenderType) AppendText(dst []byte) ([]byte, error) { + if e.ShouldError { + return nil, fmt.Errorf("intentional append text error") + } + return append(dst, []byte("ok")...), nil +} + +func (e *ErrorTextAppenderType) UnmarshalText(text []byte) error { + if string(text) == "error" { + return fmt.Errorf("intentional unmarshal text error") + } + e.ShouldError = false + e.Value = string(text) + return nil +} + +//msgp:binappend BinaryAppenderType ErrorBinaryAppenderType +//msgp:textappend TextAppenderBinType ErrorTextAppenderType +//msgp:textmarshal TestTextMarshalerStringMiddle as:string TestTextMarshalerStringEnd +//msgp:textappend TestTextAppenderStringPos as:string + +//msgp:binappend BinaryAppenderValue + +// BinaryAppenderValue implements encoding.BinaryAppender (Go 1.22+) +type BinaryAppenderValue struct { + Value string `msg:"-"` +} + +func (t BinaryAppenderValue) AppendBinary(dst []byte) ([]byte, error) { + return append(dst, []byte("binappend:"+t.Value)...), nil +} + +func (t *BinaryAppenderValue) UnmarshalBinary(data []byte) error { + t.Value = string(data) + return nil +} + +//msgp:textappend TestAppendTextString as:string + +type TestAppendTextString struct { + Value string `msg:"-"` +} + +func (t TestAppendTextString) AppendText(dst []byte) ([]byte, error) { + return append(dst, []byte("append:"+t.Value)...), nil +} + +func (t *TestAppendTextString) UnmarshalText(text []byte) error { + t.Value = string(text) + return nil +} + +//msgp:textappend TextAppenderBinValue + +// TextAppenderBinValue implements encoding.TextAppender (stored as binary) +type TextAppenderBinValue struct { + Value string `msg:"-"` +} + +func (t TextAppenderBinValue) AppendText(dst []byte) ([]byte, error) { + return append(dst, []byte("textbin:"+t.Value)...), nil +} + +func (t *TextAppenderBinValue) UnmarshalText(text []byte) error { + t.Value = string(text) + return nil +} diff --git a/_generated/binary_marshaler_test.go b/_generated/binary_marshaler_test.go new file mode 100644 index 00000000..7d08442b --- /dev/null +++ b/_generated/binary_marshaler_test.go @@ -0,0 +1,656 @@ +package _generated + +import ( + "bytes" + "strings" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestBinaryMarshalerDirective(t *testing.T) { + original := &BinaryTestType{Value: "test_data"} + + // Test marshal + data, err := original.MarshalMsg(nil) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + // Test unmarshal + result := &BinaryTestType{} + _, err = result.UnmarshalMsg(data) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + if result.Value != original.Value { + t.Errorf("Expected %s, got %s", original.Value, result.Value) + } +} + +func TestTextMarshalerBinDirective(t *testing.T) { + original := &TextBinTestType{Value: "test_data"} + + // Test marshal + data, err := original.MarshalMsg(nil) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + // Test unmarshal + result := &TextBinTestType{} + _, err = result.UnmarshalMsg(data) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + if result.Value != original.Value { + t.Errorf("Expected %s, got %s", original.Value, result.Value) + } +} + +func TestTextMarshalerStringDirective(t *testing.T) { + original := &TextStringTestType{Value: "test_data"} + + // Test marshal + data, err := original.MarshalMsg(nil) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + // Test unmarshal + result := &TextStringTestType{} + _, err = result.UnmarshalMsg(data) + if err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + if result.Value != original.Value { + t.Errorf("Expected %s, got %s", original.Value, result.Value) + } +} + +func TestBinaryMarshalerRoundTrip(t *testing.T) { + tests := []string{"", "hello", "world with spaces", "unicode: 测试"} + + for _, testVal := range tests { + original := &BinaryTestType{Value: testVal} + + // Test marshal/unmarshal + buf, err := original.MarshalMsg(nil) + if err != nil { + t.Fatalf("MarshalMsg failed for %q: %v", testVal, err) + } + + result := &BinaryTestType{} + _, err = result.UnmarshalMsg(buf) + if err != nil { + t.Fatalf("UnmarshalMsg failed for %q: %v", testVal, err) + } + + if result.Value != testVal { + t.Errorf("Round trip failed: expected %q, got %q", testVal, result.Value) + } + } +} + +func TestTextMarshalerSize(t *testing.T) { + testType := &TextBinTestType{Value: "test"} + + // Get the size + size := testType.Msgsize() + + // Marshal and check the actual size matches estimate + data, err := testType.MarshalMsg(nil) + if err != nil { + t.Fatalf("MarshalMsg failed: %v", err) + } + + if len(data) > size { + t.Errorf("Msgsize underestimated: estimated %d, actual %d", size, len(data)) + } + + if size > len(data)+100 { // Allow some reasonable overhead + t.Errorf("Msgsize too conservative: estimated %d, actual %d", size, len(data)) + } +} + +func TestStructWithMarshalerFields(t *testing.T) { + // Create a test struct with all field types populated + original := &TestStruct{ + // Direct values + BinaryValue: BinaryTestType{Value: "binary_val"}, + TextBinValue: TextBinTestType{Value: "text_bin_val"}, + TextStringValue: TextStringTestType{Value: "text_str_val"}, + + // Pointers + BinaryPtr: &BinaryTestType{Value: "binary_ptr"}, + TextBinPtr: &TextBinTestType{Value: "text_bin_ptr"}, + TextStringPtr: &TextStringTestType{Value: "text_str_ptr"}, + + // Slices + BinarySlice: []BinaryTestType{ + {Value: "bin_slice_0"}, + {Value: "bin_slice_1"}, + }, + TextBinSlice: []TextBinTestType{ + {Value: "text_bin_slice_0"}, + {Value: "text_bin_slice_1"}, + }, + TextStringSlice: []TextStringTestType{ + {Value: "text_str_slice_0"}, + {Value: "text_str_slice_1"}, + }, + + // Arrays + BinaryArray: [3]BinaryTestType{ + {Value: "bin_array_0"}, + {Value: "bin_array_1"}, + {Value: "bin_array_2"}, + }, + TextBinArray: [2]TextBinTestType{ + {Value: "text_bin_array_0"}, + {Value: "text_bin_array_1"}, + }, + TextStringArray: [4]TextStringTestType{ + {Value: "text_str_array_0"}, + {Value: "text_str_array_1"}, + {Value: "text_str_array_2"}, + {Value: "text_str_array_3"}, + }, + + // Maps + BinaryMap: map[string]BinaryTestType{ + "key1": {Value: "bin_map_val1"}, + "key2": {Value: "bin_map_val2"}, + }, + TextBinMap: map[string]TextBinTestType{ + "key1": {Value: "text_bin_map_val1"}, + "key2": {Value: "text_bin_map_val2"}, + }, + TextStringMap: map[string]TextStringTestType{ + "key1": {Value: "text_str_map_val1"}, + "key2": {Value: "text_str_map_val2"}, + }, + + // Nested types + NestedPtrSlice: []*BinaryTestType{ + {Value: "nested_ptr_0"}, + {Value: "nested_ptr_1"}, + }, + SliceOfArrays: [][2]TextBinTestType{ + {{Value: "slice_arr_0_0"}, {Value: "slice_arr_0_1"}}, + {{Value: "slice_arr_1_0"}, {Value: "slice_arr_1_1"}}, + }, + MapOfSlices: map[string][]BinaryTestType{ + "slice1": {{Value: "map_slice_val_0"}, {Value: "map_slice_val_1"}}, + "slice2": {{Value: "map_slice_val_2"}}, + }, + } + + // Test marshal/unmarshal + data, err := original.MarshalMsg(nil) + if err != nil { + t.Fatalf("MarshalMsg failed: %v", err) + } + + result := &TestStruct{} + _, err = result.UnmarshalMsg(data) + if err != nil { + t.Fatalf("UnmarshalMsg failed: %v", err) + } + + // Verify direct values + if result.BinaryValue.Value != original.BinaryValue.Value { + t.Errorf("BinaryValue mismatch: expected %s, got %s", original.BinaryValue.Value, result.BinaryValue.Value) + } + if result.TextBinValue.Value != original.TextBinValue.Value { + t.Errorf("TextBinValue mismatch: expected %s, got %s", original.TextBinValue.Value, result.TextBinValue.Value) + } + if result.TextStringValue.Value != original.TextStringValue.Value { + t.Errorf("TextStringValue mismatch: expected %s, got %s", original.TextStringValue.Value, result.TextStringValue.Value) + } + + // Verify pointers + if result.BinaryPtr == nil || result.BinaryPtr.Value != original.BinaryPtr.Value { + t.Errorf("BinaryPtr mismatch") + } + if result.TextBinPtr == nil || result.TextBinPtr.Value != original.TextBinPtr.Value { + t.Errorf("TextBinPtr mismatch") + } + if result.TextStringPtr == nil || result.TextStringPtr.Value != original.TextStringPtr.Value { + t.Errorf("TextStringPtr mismatch") + } + + // Verify slices + if len(result.BinarySlice) != len(original.BinarySlice) { + t.Errorf("BinarySlice length mismatch: expected %d, got %d", len(original.BinarySlice), len(result.BinarySlice)) + } else { + for i := range result.BinarySlice { + if result.BinarySlice[i].Value != original.BinarySlice[i].Value { + t.Errorf("BinarySlice[%d] mismatch: expected %s, got %s", i, original.BinarySlice[i].Value, result.BinarySlice[i].Value) + } + } + } + + // Verify arrays + for i := range result.BinaryArray { + if result.BinaryArray[i].Value != original.BinaryArray[i].Value { + t.Errorf("BinaryArray[%d] mismatch: expected %s, got %s", i, original.BinaryArray[i].Value, result.BinaryArray[i].Value) + } + } + + // Verify maps + if len(result.BinaryMap) != len(original.BinaryMap) { + t.Errorf("BinaryMap length mismatch: expected %d, got %d", len(original.BinaryMap), len(result.BinaryMap)) + } else { + for k, v := range original.BinaryMap { + if resultVal, exists := result.BinaryMap[k]; !exists || resultVal.Value != v.Value { + t.Errorf("BinaryMap[%s] mismatch: expected %s, got %s", k, v.Value, resultVal.Value) + } + } + } + + // Verify nested structures + if len(result.NestedPtrSlice) != len(original.NestedPtrSlice) { + t.Errorf("NestedPtrSlice length mismatch") + } else { + for i := range result.NestedPtrSlice { + if result.NestedPtrSlice[i] == nil || result.NestedPtrSlice[i].Value != original.NestedPtrSlice[i].Value { + t.Errorf("NestedPtrSlice[%d] mismatch", i) + } + } + } +} + +func TestStructWithOmitEmptyFields(t *testing.T) { + // Test struct with nil pointers (should use omitempty/omitzero) + original := &TestStruct{ + BinaryValue: BinaryTestType{Value: "present"}, + TextBinValue: TextBinTestType{Value: "also_present"}, + TextStringValue: TextStringTestType{Value: "string_present"}, + // Leave pointers nil + BinaryPtr: nil, + TextBinPtr: nil, // this has omitempty tag + TextStringPtr: nil, // this has omitzero tag + // Empty slices and maps + BinarySlice: []BinaryTestType{}, + BinaryMap: map[string]BinaryTestType{}, + } + + // Test marshal/unmarshal + data, err := original.MarshalMsg(nil) + if err != nil { + t.Fatalf("MarshalMsg failed: %v", err) + } + + result := &TestStruct{} + _, err = result.UnmarshalMsg(data) + if err != nil { + t.Fatalf("UnmarshalMsg failed: %v", err) + } + + // Verify values are preserved and nils are handled correctly + if result.BinaryValue.Value != original.BinaryValue.Value { + t.Errorf("BinaryValue mismatch") + } + if result.BinaryPtr != nil { + t.Errorf("Expected BinaryPtr to be nil") + } + if result.TextBinPtr != nil { + t.Errorf("Expected TextBinPtr to be nil") + } + if result.TextStringPtr != nil { + t.Errorf("Expected TextStringPtr to be nil") + } +} + +func TestStructSizeEstimation(t *testing.T) { + testStruct := &TestStruct{ + BinaryValue: BinaryTestType{Value: "test"}, + BinarySlice: []BinaryTestType{ + {Value: "slice1"}, + {Value: "slice2"}, + }, + BinaryMap: map[string]BinaryTestType{ + "key": {Value: "value"}, + }, + } + + // Get size estimate + size := testStruct.Msgsize() + + // Marshal and verify size estimate + data, err := testStruct.MarshalMsg(nil) + if err != nil { + t.Fatalf("MarshalMsg failed: %v", err) + } + + if len(data) > size { + t.Errorf("Msgsize underestimated: estimated %d, actual %d", size, len(data)) + } +} + +func TestMarshalerErrorHandling(t *testing.T) { + // Test marshaling with marshaler that can fail + errorType := &ErrorTestType{ShouldError: true} + + _, err := errorType.MarshalMsg(nil) + if err == nil { + t.Error("Expected marshal error but got none") + } + + // Test in struct context + testStruct := &TestStruct{ + BinaryValue: BinaryTestType{Value: "good"}, + } + + // This should succeed + _, err = testStruct.MarshalMsg(nil) + if err != nil { + t.Errorf("Expected no error but got: %v", err) + } +} + +func TestTextMarshalerAsStringPositioning(t *testing.T) { + // Test that as:string works when placed in middle of type list + middle := &TestTextMarshalerStringMiddle{Value: "test_middle"} + + // Marshal + data, err := middle.MarshalMsg(nil) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Unmarshal + var decoded TestTextMarshalerStringMiddle + _, err = decoded.UnmarshalMsg(data) + if err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Value != "middle:test_middle" { + t.Errorf("Expected 'middle:test_middle', got '%s'", decoded.Value) + } + + // Test that as:string works when placed at end of type list + end := &TestTextMarshalerStringEnd{Value: "test_end"} + + // Marshal + data, err = end.MarshalMsg(nil) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Unmarshal + var decodedEnd TestTextMarshalerStringEnd + _, err = decodedEnd.UnmarshalMsg(data) + if err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decodedEnd.Value != "end:test_end" { + t.Errorf("Expected 'end:test_end', got '%s'", decodedEnd.Value) + } +} + +func TestTextAppenderAsStringPositioning(t *testing.T) { + // Test that as:string works with textappend directive + appender := &TestTextAppenderStringPos{Value: "test_append"} + + // Marshal + data, err := appender.MarshalMsg(nil) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Unmarshal + var decoded TestTextAppenderStringPos + _, err = decoded.UnmarshalMsg(data) + if err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Value != "append:test_append" { + t.Errorf("Expected 'append:test_append', got '%s'", decoded.Value) + } +} + +func TestBinaryAppenderDirective(t *testing.T) { + for _, size := range []int{10, 100, 1000, 100000} { + + // Test BinaryAppender interface + appender := &BinaryAppenderType{Value: strings.Repeat("a", size)} + + // Test round trip + data, err := appender.MarshalMsg(nil) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded BinaryAppenderType + _, err = decoded.UnmarshalMsg(data) + if err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if want := "binappend:" + appender.Value; decoded.Value != want { + t.Errorf("Expected '%s', got '%s'", want, decoded.Value) + } + + // Test encode/decode + var buf bytes.Buffer + writer := msgp.NewWriter(&buf) + err = appender.EncodeMsg(writer) + if err != nil { + t.Fatalf("Failed to encode: %v", err) + } + writer.Flush() + + reader := msgp.NewReader(&buf) + var decoded2 BinaryAppenderType + err = decoded2.DecodeMsg(reader) + if err != nil { + t.Fatalf("Failed to decode: %v", err) + } + + if want := "binappend:" + appender.Value; decoded2.Value != want { + t.Errorf("Expected '%s', got '%s'", want, decoded2.Value) + } + } +} + +func TestTextAppenderBinDirective(t *testing.T) { + // Test TextAppender interface (stored as binary) + appender := &TextAppenderBinType{Value: "test_text_bin"} + + // Test round trip + data, err := appender.MarshalMsg(nil) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + var decoded TextAppenderBinType + _, err = decoded.UnmarshalMsg(data) + if err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Value != "textbin:test_text_bin" { + t.Errorf("Expected 'textbin:test_text_bin', got '%s'", decoded.Value) + } + + // Test encode/decode + var buf bytes.Buffer + writer := msgp.NewWriter(&buf) + err = appender.EncodeMsg(writer) + if err != nil { + t.Fatalf("Failed to encode: %v", err) + } + writer.Flush() + + reader := msgp.NewReader(&buf) + var decoded2 TextAppenderBinType + err = decoded2.DecodeMsg(reader) + if err != nil { + t.Fatalf("Failed to decode: %v", err) + } + + if decoded2.Value != "textbin:test_text_bin" { + t.Errorf("Expected 'textbin:test_text_bin', got '%s'", decoded2.Value) + } +} + +func TestBinaryAppenderErrorHandling(t *testing.T) { + // Test error handling with BinaryAppender + errorType := &ErrorBinaryAppenderType{ShouldError: true} + + _, err := errorType.MarshalMsg(nil) + if err == nil { + t.Error("Expected marshal error but got none") + } + + // Test successful case + successType := &ErrorBinaryAppenderType{ShouldError: false} + data, err := successType.MarshalMsg(nil) + if err != nil { + t.Fatalf("Expected no error but got: %v", err) + } + + var decoded ErrorBinaryAppenderType + _, err = decoded.UnmarshalMsg(data) + if err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Value != "ok" { + t.Errorf("Expected 'ok', got '%s'", decoded.Value) + } +} + +func TestTextAppenderErrorHandling(t *testing.T) { + // Test error handling with TextAppender + errorType := &ErrorTextAppenderType{ShouldError: true} + + _, err := errorType.MarshalMsg(nil) + if err == nil { + t.Error("Expected marshal error but got none") + } + + // Test successful case + successType := &ErrorTextAppenderType{ShouldError: false} + data, err := successType.MarshalMsg(nil) + if err != nil { + t.Fatalf("Expected no error but got: %v", err) + } + + var decoded ErrorTextAppenderType + _, err = decoded.UnmarshalMsg(data) + if err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + if decoded.Value != "ok" { + t.Errorf("Expected 'ok', got '%s'", decoded.Value) + } +} + +func TestAppenderTypesSizeEstimation(t *testing.T) { + // Test that size estimation works for appender types + binaryAppender := &BinaryAppenderType{Value: "size_test"} + size := binaryAppender.Msgsize() + + data, err := binaryAppender.MarshalMsg(nil) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + if len(data) > size { + t.Errorf("Binary appender size underestimated: estimated %d, actual %d", size, len(data)) + } + + textAppender := &TextAppenderBinType{Value: "size_test"} + size = textAppender.Msgsize() + + data, err = textAppender.MarshalMsg(nil) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + if len(data) > size { + t.Errorf("Text appender size underestimated: estimated %d, actual %d", size, len(data)) + } +} + +func TestBinaryAppenderValueRoundTrip(t *testing.T) { + tests := []string{"", "hello", "world with spaces", "unicode: 测试"} + + for _, testVal := range tests { + original := BinaryAppenderValue{Value: testVal} + + buf, err := original.MarshalMsg(nil) + if err != nil { + t.Fatalf("MarshalMsg failed for %q: %v", testVal, err) + } + + var result BinaryAppenderValue + _, err = result.UnmarshalMsg(buf) + if err != nil { + t.Fatalf("UnmarshalMsg failed for %q: %v", testVal, err) + } + + expected := "binappend:" + testVal + if result.Value != expected { + t.Errorf("Round trip failed: expected %q, got %q", expected, result.Value) + } + } +} + +func TestTestAppendTextStringRoundTrip(t *testing.T) { + tests := []string{"", "hello", "world with spaces", "unicode: 测试"} + + for _, testVal := range tests { + original := TestAppendTextString{Value: testVal} + + buf, err := original.MarshalMsg(nil) + if err != nil { + t.Fatalf("MarshalMsg failed for %q: %v", testVal, err) + } + + var result TestAppendTextString + _, err = result.UnmarshalMsg(buf) + if err != nil { + t.Fatalf("UnmarshalMsg failed for %q: %v", testVal, err) + } + + expected := "append:" + testVal + if result.Value != expected { + t.Errorf("Round trip failed: expected %q, got %q", expected, result.Value) + } + } +} + +func TestTextAppenderBinValueRoundTrip(t *testing.T) { + tests := []string{"", "hello", "world with spaces", "unicode: 测试"} + + for _, testVal := range tests { + original := TextAppenderBinValue{Value: testVal} + + buf, err := original.MarshalMsg(nil) + if err != nil { + t.Fatalf("MarshalMsg failed for %q: %v", testVal, err) + } + + var result TextAppenderBinValue + _, err = result.UnmarshalMsg(buf) + if err != nil { + t.Fatalf("UnmarshalMsg failed for %q: %v", testVal, err) + } + + expected := "textbin:" + testVal + if result.Value != expected { + t.Errorf("Round trip failed: expected %q, got %q", expected, result.Value) + } + } +} diff --git a/gen/decode.go b/gen/decode.go index 17f14454..6e1af0ee 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -399,6 +399,14 @@ func (d *decodeGen) gBase(b *BaseElem) { vname := b.Varname() // e.g. "z.FieldOne" bname := b.BaseName() // e.g. "Float64" checkNil := vname // Name of var to check for nil + alwaysRef := vname + + // make sure we always reference the pointer + if strings.Contains(alwaysRef, "*") { + alwaysRef = strings.Trim(alwaysRef, "*()") + } else if !b.parentIsPtr { + alwaysRef = "&" + vname + } // handle special cases // for object type. @@ -410,6 +418,12 @@ func (d *decodeGen) gBase(b *BaseElem) { } else { checkNil = d.readBytesWithLimit(vname, 0) } + case BinaryMarshaler, BinaryAppender: + d.p.printf("\nerr = dc.ReadBinaryUnmarshal(%s)", alwaysRef) + case TextMarshalerBin, TextAppenderBin: + d.p.printf("\nerr = dc.ReadTextUnmarshal(%s)", alwaysRef) + case TextMarshalerString, TextAppenderString: + d.p.printf("\nerr = dc.ReadTextUnmarshalString(%s)", alwaysRef) case IDENT: dst := b.BaseType() if b.typeParams.isPtr { @@ -539,6 +553,10 @@ func (d *decodeGen) gPtr(p *Ptr) { tp.isPtr = true p.Value.SetTypeParams(tp) } + if be, ok := p.Value.(*BaseElem); ok { + be.parentIsPtr = true + defer func() { be.parentIsPtr = false }() + } next(d, p.Value) d.p.closeblock() } diff --git a/gen/elem.go b/gen/elem.go index b85bd714..ed05be60 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -100,6 +100,18 @@ const ( AUint32 ABool + // Binary marshaler types + BinaryMarshaler // encoding.BinaryMarshaler/BinaryUnmarshaler + BinaryAppender // encoding.BinaryAppender/BinaryUnmarshaler + + // Text marshaler types (stored as binary by default) + TextMarshalerBin // encoding.TextMarshaler/TextUnmarshaler -> bin + TextAppenderBin // encoding.TextAppender/TextUnmarshaler -> bin + + // Text marshaler types (stored as string) + TextMarshalerString // encoding.TextMarshaler/TextUnmarshaler -> string + TextAppenderString // encoding.TextAppender/TextUnmarshaler -> string + IDENT // IDENT means an unrecognized identifier ) @@ -491,7 +503,10 @@ func (s *Ptr) SetVarname(a string) { case *BaseElem: // identities have pointer receivers - if x.Value == IDENT { + // marshaler types also have pointer receivers + if x.Value == IDENT || x.Value == BinaryMarshaler || x.Value == BinaryAppender || + x.Value == TextMarshalerBin || x.Value == TextAppenderBin || + x.Value == TextMarshalerString || x.Value == TextAppenderString { // replace directive sets Convert=true and Needsref=true // since BaseElem is behind a pointer we set Needsref=false if x.Convert { @@ -673,6 +688,7 @@ type BaseElem struct { zerocopy bool // Allow zerocopy for byte slices in unmarshal. mustinline bool // must inline; not printable needsref bool // needs reference for shim + parentIsPtr bool // parent is a pointer allowNil *bool // Override from parent. } @@ -938,6 +954,18 @@ func (k Primitive) String() string { return "atomic.Uint32" case ABool: return "atomic.Bool" + case BinaryMarshaler: + return "BinaryMarshaler" + case BinaryAppender: + return "BinaryAppender" + case TextMarshalerBin: + return "TextMarshalerBin" + case TextAppenderBin: + return "TextAppenderBin" + case TextMarshalerString: + return "TextMarshalerString" + case TextAppenderString: + return "TextAppenderString" case IDENT: return "Ident" default: diff --git a/gen/encode.go b/gen/encode.go index 1805b7be..84c2f323 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -75,6 +75,23 @@ func (e *encodeGen) Fuse(b []byte) { } } +// binaryEncodeCall generates code for marshaler interfaces +func (e *encodeGen) binaryEncodeCall(vname, method, writeType, arg string) { + bts := randIdent() + e.p.printf("\nvar %s []byte", bts) + if arg == "" { + e.p.printf("\n%s, err = %s.%s()", bts, vname, method) + } else { + e.p.printf("\n%s, err = %s.%s(%s)", bts, vname, method, arg) + } + e.p.wrapErrCheck(e.ctx.ArgsStr()) + if writeType == "String" { + e.writeAndCheck(writeType, literalFmt, "string("+bts+")") + } else { + e.writeAndCheck(writeType, literalFmt, bts) + } +} + func (e *encodeGen) Execute(p Elem, ctx Context) error { e.ctx = &ctx if !e.p.ok() { @@ -360,6 +377,22 @@ func (e *encodeGen) gBase(b *BaseElem) { case AInt64, AInt32, AUint64, AUint32, ABool: t := strings.TrimPrefix(b.BaseName(), "atomic.") e.writeAndCheck(t, literalFmt, strings.TrimPrefix(vname, "*")+".Load()") + case BinaryMarshaler: + e.binaryEncodeCall(vname, "MarshalBinary", "Bytes", "") + case TextMarshalerBin: + e.binaryEncodeCall(vname, "MarshalText", "Bytes", "") + case TextMarshalerString: + e.binaryEncodeCall(vname, "MarshalText", "String", "") + case BinaryAppender: + // We do not know if the interface is implemented on pointer or value. + vname = strings.Trim(vname, "*()") + e.writeAndCheck("BinaryAppender", literalFmt, vname) + case TextAppenderBin: + vname = strings.Trim(vname, "*()") + e.writeAndCheck("TextAppender", literalFmt, vname) + case TextAppenderString: + vname = strings.Trim(vname, "*()") + e.writeAndCheck("TextAppenderString", literalFmt, vname) case IDENT: // unknown identity dst := b.BaseType() if b.typeParams.isPtr { diff --git a/gen/marshal.go b/gen/marshal.go index 1f033d80..433dd260 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -372,6 +372,24 @@ func (m *marshalGen) gBase(b *BaseElem) { var echeck bool switch b.Value { + case BinaryMarshaler: + echeck = true + m.binaryMarshalCall(vname, "MarshalBinary", "", "msgp.AppendBytes") + case BinaryAppender: + echeck = false + m.binaryAppendCall(vname, "AppendBinary", "msgp.AppendBytes") + case TextMarshalerBin: + echeck = true + m.binaryMarshalCall(vname, "MarshalText", "", "msgp.AppendBytes") + case TextAppenderBin: + echeck = false + m.binaryAppendCall(vname, "AppendText", "msgp.AppendBytes") + case TextMarshalerString: + echeck = true + m.binaryMarshalCall(vname, "MarshalText", "string", "msgp.AppendString") + case TextAppenderString: + echeck = false + m.binaryAppendCall(vname, "AppendText", "msgp.AppendString") case IDENT: dst := b.BaseType() if b.typeParams.isPtr { @@ -397,3 +415,34 @@ func (m *marshalGen) gBase(b *BaseElem) { m.p.wrapErrCheck(m.ctx.ArgsStr()) } } + +// binaryMarshalCall generates code for marshaler interfaces that return []byte +func (m *marshalGen) binaryMarshalCall(vname, method, convert, appendFunc string) { + bts := randIdent() + vname = strings.Trim(vname, "(*)") + m.p.printf("\nvar %s []byte", bts) + m.p.printf("\n%s, err = %s.%s()", bts, vname, method) + m.p.wrapErrCheck(m.ctx.ArgsStr()) + if convert != "" { + m.p.printf("\no = %s(o, %s(%s))", appendFunc, convert, bts) + } else { + m.p.printf("\no = %s(o, %s)", appendFunc, bts) + } +} + +// binaryAppendCall generates code for appender interfaces that use pre-allocated buffer. +// We optimize for cases where the size is 0-256 bytes. +func (m *marshalGen) binaryAppendCall(vname, method, appendFunc string) { + sz := randIdent() + vname = strings.Trim(vname, "(*)") + // Reserve 2 bytes for the header bin8 or str8. + m.p.printf("\no = append(o, 0, 0); %s := len(o)", sz) + m.p.printf("\no, err = %s.%s(o)", vname, method) + m.p.wrapErrCheck(m.ctx.ArgsStr()) + m.p.printf("\n%s = len(o) - %s", sz, sz) + if appendFunc == "msgp.AppendString" { + m.p.printf("\no = msgp.AppendBytesStringTwoPrefixed(o, %s)", sz) + } else { + m.p.printf("\no = msgp.AppendBytesTwoPrefixed(o, %s)", sz) + } +} diff --git a/gen/unmarshal.go b/gen/unmarshal.go index 9ce21988..f74b6180 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -385,6 +385,21 @@ func (u *unmarshalGen) mapstruct(s *Struct) { } } +// binaryUnmarshalCall generates code for unmarshaling marshaler/appender interfaces +func (u *unmarshalGen) binaryUnmarshalCall(refname, unmarshalMethod, readType string) { + tmpBytes := randIdent() + refname = strings.Trim(refname, "(*)") + + u.p.printf("\nvar %s []byte", tmpBytes) + if readType == "String" { + u.p.printf("\n%s, bts, err = msgp.ReadStringZC(bts)", tmpBytes) + } else { + u.p.printf("\n%s, bts, err = msgp.ReadBytesZC(bts)", tmpBytes) + } + u.p.wrapErrCheck(u.ctx.ArgsStr()) + u.p.printf("\nerr = %s.%s(%s)", refname, unmarshalMethod, tmpBytes) +} + func (u *unmarshalGen) gBase(b *BaseElem) { if !u.p.ok() { return @@ -404,6 +419,12 @@ func (u *unmarshalGen) gBase(b *BaseElem) { nilCheck = u.readBytesWithLimit(refname, lowered, b.zerocopy, 0) case Ext: u.p.printf("\nbts, err = msgp.ReadExtensionBytes(bts, %s)", lowered) + case BinaryMarshaler, BinaryAppender: + u.binaryUnmarshalCall(refname, "UnmarshalBinary", "Bytes") + case TextMarshalerBin, TextAppenderBin: + u.binaryUnmarshalCall(refname, "UnmarshalText", "Bytes") + case TextMarshalerString, TextAppenderString: + u.binaryUnmarshalCall(refname, "UnmarshalText", "String") case IDENT: if b.Convert { lowered = b.ToBase() + "(" + lowered + ")" diff --git a/go.mod b/go.mod index a604be4f..b3501536 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/tinylib/msgp -go 1.23 +go 1.24 require ( github.com/philhofer/fwd v1.2.0 diff --git a/msgp/read.go b/msgp/read.go index 34c4f74a..c4160f6c 100644 --- a/msgp/read.go +++ b/msgp/read.go @@ -1,6 +1,7 @@ package msgp import ( + "encoding" "encoding/binary" "encoding/json" "io" @@ -1567,3 +1568,36 @@ func (m *Reader) ReadIntf() (i any, err error) { return nil, fatal // unreachable } } + +// ReadBinaryUnmarshal reads a binary-encoded object from the reader and unmarshals it into dst. +func (m *Reader) ReadBinaryUnmarshal(dst encoding.BinaryUnmarshaler) (err error) { + tmp := bytesPool.Get().([]byte) + defer bytesPool.Put(tmp) //nolint:staticcheck + tmp, err = m.ReadBytes(tmp[:0]) + if err != nil { + return + } + return dst.UnmarshalBinary(tmp) +} + +// ReadTextUnmarshal reads a text-encoded bin array from the reader and unmarshals it into dst. +func (m *Reader) ReadTextUnmarshal(dst encoding.TextUnmarshaler) (err error) { + tmp := bytesPool.Get().([]byte) + defer bytesPool.Put(tmp) //nolint:staticcheck + tmp, err = m.ReadBytes(tmp[:0]) + if err != nil { + return + } + return dst.UnmarshalText(tmp) +} + +// ReadTextUnmarshalString reads a text-encoded string from the reader and unmarshals it into dst. +func (m *Reader) ReadTextUnmarshalString(dst encoding.TextUnmarshaler) (err error) { + tmp := bytesPool.Get().([]byte) + defer bytesPool.Put(tmp) //nolint:staticcheck + tmp, err = m.ReadStringAsBytes(tmp[:0]) + if err != nil { + return + } + return dst.UnmarshalText(tmp) +} diff --git a/msgp/size.go b/msgp/size.go index 585a67fd..b81b94e6 100644 --- a/msgp/size.go +++ b/msgp/size.go @@ -37,4 +37,13 @@ const ( BytesPrefixSize = 5 StringPrefixSize = 5 ExtensionPrefixSize = 6 + + // We cannot determine the exact size of the marshalled bytes, + // so we assume 32 bytes + BinaryMarshalerSize = BytesPrefixSize + 32 + BinaryAppenderSize + TextMarshalerBinSize + TextAppenderBinSize + TextMarshalerStringSize = StringPrefixSize + 32 + TextAppenderStringSize ) diff --git a/msgp/write.go b/msgp/write.go index e6f891ba..f071635f 100644 --- a/msgp/write.go +++ b/msgp/write.go @@ -1,6 +1,7 @@ package msgp import ( + "encoding" "encoding/binary" "encoding/json" "errors" @@ -884,3 +885,42 @@ func GuessSize(i any) int { return 512 } } + +// Temporary buffer for reading/writing binary data. +var bytesPool = sync.Pool{New: func() any { return make([]byte, 0, 1024) }} + +// WriteBinaryAppender will write the bytes from the given +// encoding.BinaryAppender as a bin array. +func (mw *Writer) WriteBinaryAppender(b encoding.BinaryAppender) error { + dst := bytesPool.Get().([]byte) + defer bytesPool.Put(dst) //nolint:staticcheck + dst, err := b.AppendBinary(dst[:0]) + if err != nil { + return err + } + return mw.WriteBytes(dst) +} + +// WriteTextAppender will write the bytes from the given +// encoding.TextAppender as a bin array. +func (mw *Writer) WriteTextAppender(b encoding.TextAppender) error { + dst := bytesPool.Get().([]byte) + defer bytesPool.Put(dst) //nolint:staticcheck + dst, err := b.AppendText(dst[:0]) + if err != nil { + return err + } + return mw.WriteBytes(dst) +} + +// WriteTextAppenderString will write the bytes from the given +// encoding.TextAppender as a string. +func (mw *Writer) WriteTextAppenderString(b encoding.TextAppender) error { + dst := bytesPool.Get().([]byte) + defer bytesPool.Put(dst) //nolint:staticcheck + dst, err := b.AppendText(dst[:0]) + if err != nil { + return err + } + return mw.WriteStringFromBytes(dst) +} diff --git a/msgp/write_bytes.go b/msgp/write_bytes.go index ff13f9f9..378f14f9 100644 --- a/msgp/write_bytes.go +++ b/msgp/write_bytes.go @@ -518,3 +518,49 @@ func AppendJSONNumber(b []byte, n json.Number) ([]byte, error) { } return b, err } + +// AppendBytesTwoPrefixed will add the length to a bin section written with +// 2 bytes of space saved for a bin8 header. +// If the sz cannot fit inside a bin8, the data will be moved to make space for the header. +func AppendBytesTwoPrefixed(b []byte, sz int) []byte { + off := len(b) - sz - 2 + switch { + case sz <= math.MaxUint8: + // Just write header... + prefixu8(b[off:], mbin8, uint8(sz)) + case sz <= math.MaxUint16: + // Scoot one + b = append(b, 0) + copy(b[off+1:], b[off:]) + prefixu16(b[off:], mbin16, uint16(sz)) + default: + // Scoot three + b = append(b, 0, 0, 0) + copy(b[off+3:], b[off:]) + prefixu32(b[off:], mbin32, uint32(sz)) + } + return b +} + +// AppendBytesStringTwoPrefixed will add the length to a string section written with +// 2 bytes of space saved for a str8 header. +// If the sz cannot fit inside a str8, the data will be moved to make space for the header. +func AppendBytesStringTwoPrefixed(b []byte, sz int) []byte { + off := len(b) - sz - 2 + switch { + case sz <= math.MaxUint8: + // Just write header... + prefixu8(b[off:], mstr8, uint8(sz)) + case sz <= math.MaxUint16: + // Scoot one + b = append(b, 0) + copy(b[off+1:], b[off:]) + prefixu16(b[off:], mstr16, uint16(sz)) + default: + // Scoot three + b = append(b, 0, 0, 0) + copy(b[off+3:], b[off:]) + prefixu32(b[off:], mstr32, uint32(sz)) + } + return b +} diff --git a/parse/directives.go b/parse/directives.go index 05a1597d..1460321d 100644 --- a/parse/directives.go +++ b/parse/directives.go @@ -34,6 +34,10 @@ var directives = map[string]directive{ "newtime": newtime, "timezone": newtimezone, "limit": limit, + "binmarshal": binmarshal, + "binappend": binappend, + "textmarshal": textmarshal, + "textappend": textappend, } // map of all recognized directives which will be applied @@ -367,3 +371,156 @@ func limit(text []string, f *FileSet) (err error) { infof("limits - arrays:%d maps:%d marshal:%t\n", f.ArrayLimit, f.MapLimit, f.MarshalLimits) return nil } + +//msgp:binmarshal pkg.Type pkg.Type2 +func binmarshal(text []string, f *FileSet) error { + if len(text) < 2 { + return fmt.Errorf("binmarshal directive should have at least 1 argument; found %d", len(text)-1) + } + alwaysPtr := true + + for _, item := range text[1:] { + name := strings.TrimSpace(item) + be := gen.Ident(name) + be.Value = gen.BinaryMarshaler + be.Alias(name) + be.Convert = false // Don't use conversion for marshaler types + be.AlwaysPtr(&alwaysPtr) + + infof("%s -> BinaryMarshaler\n", name) + f.findShim(name, be, true) + } + + return nil +} + +//msgp:binappend pkg.Type pkg.Type2 +func binappend(text []string, f *FileSet) error { + if len(text) < 2 { + return fmt.Errorf("binappend directive should have at least 1 argument; found %d", len(text)-1) + } + alwaysPtr := true + for _, item := range text[1:] { + name := strings.TrimSpace(item) + be := gen.Ident(name) + be.Value = gen.BinaryAppender + be.Alias(name) + be.Convert = false // Don't use conversion for marshaler types + be.AlwaysPtr(&alwaysPtr) + + infof("%s -> BinaryAppender\n", name) + f.findShim(name, be, true) + } + + return nil +} + +//msgp:textmarshal [as:string] pkg.Type pkg.Type2 +func textmarshal(text []string, f *FileSet) error { + if len(text) < 2 { + return fmt.Errorf("textmarshal directive should have at least 1 argument; found %d", len(text)-1) + } + + // Check for as:string option anywhere in the arguments + var asString bool + var typeArgs []string + alwaysPtr := true + + for _, item := range text[1:] { + trimmed := strings.TrimSpace(item) + if strings.HasPrefix(trimmed, "as:") { + option := strings.TrimPrefix(trimmed, "as:") + switch option { + case "string": + asString = true + case "bin": + asString = false + default: + return fmt.Errorf("invalid as: option %q, expected 'string' or 'bin'", option) + } + } else { + typeArgs = append(typeArgs, trimmed) + } + } + + if len(typeArgs) == 0 { + return fmt.Errorf("textmarshal directive should have at least 1 type argument") + } + + for _, item := range typeArgs { + name := strings.TrimSpace(item) + be := gen.Ident(name) + be.AlwaysPtr(&alwaysPtr) + if asString { + be.Value = gen.TextMarshalerString + } else { + be.Value = gen.TextMarshalerBin + } + be.Alias(name) + be.Convert = false // Don't use conversion for marshaler types + + if asString { + infof("%s -> TextMarshaler (as string)\n", name) + } else { + infof("%s -> TextMarshaler (as bin)\n", name) + } + f.findShim(name, be, true) + } + + return nil +} + +//msgp:textappend [as:string] pkg.Type pkg.Type2 +func textappend(text []string, f *FileSet) error { + if len(text) < 2 { + return fmt.Errorf("textappend directive should have at least 1 argument; found %d", len(text)-1) + } + + // Check for as:string option anywhere in the arguments + var asString bool + var typeArgs []string + alwaysPtr := true + + for _, item := range text[1:] { + trimmed := strings.TrimSpace(item) + if strings.HasPrefix(trimmed, "as:") { + option := strings.TrimPrefix(trimmed, "as:") + switch option { + case "string": + asString = true + case "bin": + asString = false + default: + return fmt.Errorf("invalid as: option %q, expected 'string' or 'bin'", option) + } + } else { + typeArgs = append(typeArgs, trimmed) + } + } + + if len(typeArgs) == 0 { + return fmt.Errorf("textappend directive should have at least 1 type argument") + } + + for _, item := range typeArgs { + name := strings.TrimSpace(item) + be := gen.Ident(name) + if asString { + be.Value = gen.TextAppenderString + } else { + be.Value = gen.TextAppenderBin + } + be.Alias(name) + be.Convert = false // Don't use conversion for marshaler types + be.AlwaysPtr(&alwaysPtr) + + if asString { + infof("%s -> TextAppender (as string)\n", name) + } else { + infof("%s -> TextAppender (as bin)\n", name) + } + f.findShim(name, be, true) + } + + return nil +}