Skip to content

Commit

Permalink
fix: handle zero-length strings in unsafe unmarshalling and update te…
Browse files Browse the repository at this point in the history
…st logic to suit
  • Loading branch information
maheeshap-canopus committed Jan 8, 2024
1 parent c47c2bc commit bcde995
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
12 changes: 11 additions & 1 deletion features/unmarshal/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,11 @@ func (p *unmarshal) mapField(varName string, field *protogen.Field) {
p.P(`return `, p.Ident("io", `ErrUnexpectedEOF`))
p.P(`}`)
if p.unsafe {
p.P(`if intStringLen`, varName, ` == 0 {`)
p.P(varName, ` = `, p.Ident("unsafe", `String`), `(nil, intStringLen`, varName, `)`)
p.P(`} else {`)
p.P(varName, ` = `, p.Ident("unsafe", `String`), `(&dAtA[iNdEx], intStringLen`, varName, `)`)
p.P(`}`)
} else {
p.P(varName, ` = `, "string", `(dAtA[iNdEx:postStringIndex`, varName, `])`)
}
Expand Down Expand Up @@ -420,7 +424,13 @@ func (p *unmarshal) fieldItem(field *protogen.Field, fieldname string, message *
p.P(`}`)
str := "string(dAtA[iNdEx:postIndex])"
if p.unsafe {
str = p.Ident("unsafe", `String`) + `(&dAtA[iNdEx], intStringLen)`
str = "stringValue"
p.P(`var stringValue string`)
p.P(`if intStringLen == 0 {`)
p.P(`stringValue = `, p.Ident("unsafe", `String`), `(nil, intStringLen)`)
p.P(`} else {`)
p.P(`stringValue = `, p.Ident("unsafe", `String`), `(&dAtA[iNdEx], intStringLen)`)
p.P(`}`)
}
if oneof {
p.P(`m.`, fieldname, ` = &`, field.GoIdent, `{`, field.GoName, ": ", str, `}`)
Expand Down
10 changes: 8 additions & 2 deletions testproto/unsafe/unsafe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@ import (
// collector doesn't move memory. To provide guarantee that this test works, consider using
// https://pkg.go.dev/runtime#Pinner when upgrading Go to >= 1.21.
func assertStringIsOriginal(t *testing.T, s string, belongs bool, originalData []byte) {
start := uintptr(unsafe.Pointer(unsafe.StringData(s)))
// empty string has no underlying array, compare pointer to nil
if len(s) == 0 {
assert.Equal(t, uintptr(unsafe.Pointer(nil)), start)
return
}
end := start + uintptr(len(s)) - 1

originalStart := uintptr(unsafe.Pointer(unsafe.SliceData(originalData)))
originalEnd := originalStart + uintptr(len(originalData)) - 1

start := uintptr(unsafe.Pointer(unsafe.StringData(s)))
end := start + uintptr(len(s)) - 1
assert.Equal(t, belongs, start >= originalStart && start < originalEnd)
assert.Equal(t, belongs, end > originalStart && end <= originalEnd)
}
Expand Down

0 comments on commit bcde995

Please sign in to comment.