diff --git a/_generated/generics.go b/_generated/generics.go new file mode 100644 index 00000000..6beb306d --- /dev/null +++ b/_generated/generics.go @@ -0,0 +1,84 @@ +package _generated + +import "github.com/tinylib/msgp/msgp" + +//go:generate msgp -v + +type Int64 int + +type GenericTest[T any, P msgp.RTFor[T]] struct { + A T + B P + C []T `msg:",allownil"` + D map[string]T `msg:",allownil"` + E GenericTest2[T, P, string] + F []GenericTest2[T, P, string] `msg:",allownil"` + G map[string]GenericTest2[T, P, string] `msg:",allownil"` + AP *T + CP []*T `msg:",allownil"` + DP map[string]*T `msg:",allownil"` + EP *GenericTest2[T, P, string] + FP []*GenericTest2[T, P, string] `msg:",allownil"` + GP map[string]*GenericTest2[T, P, string] `msg:",allownil"` +} + +type GenericTest2[T any, P msgp.RTFor[T], B any] struct { + A T +} + +//msgp:tuple GenericTuple +type GenericTuple[T any, P msgp.RTFor[T], B any] struct { + A T + B P + C []T `msg:",allownil"` + D map[string]T `msg:",allownil"` + E GenericTest2[T, P, string] + F []GenericTest2[T, P, string] `msg:",allownil"` + G map[string]GenericTest2[T, P, string] `msg:",allownil"` + AP *T + CP []*T `msg:",allownil"` + DP map[string]*T `msg:",allownil"` + EP *GenericTest2[T, P, string] + FP []*GenericTest2[T, P, string] `msg:",allownil"` + GP map[string]*GenericTest2[T, P, string] `msg:",allownil"` +} + +// Type that doesn't have any fields using the generic type should just output valid code. +type GenericTestUnused[T any] struct { + A string +} + +type GenericTestTwo[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] struct { + A A + B AP + C []A `msg:",allownil"` + D map[string]T `msg:",allownil"` + E GenericTest2[A, AP, string] + F []GenericTest2[A, AP, string] `msg:",allownil"` + G map[string]GenericTest2[A, AP, string] `msg:",allownil"` + AP *A + CP []*A `msg:",allownil"` + DP map[string]*A `msg:",allownil"` + EP *GenericTest2[A, AP, string] + FP []*GenericTest2[A, AP, string] `msg:",allownil"` + GP map[string]*GenericTest2[A, AP, string] `msg:",allownil"` + + A2 B + B2 BP + C2 []B `msg:",allownil"` + D2 map[string]B `msg:",allownil"` + E2 GenericTest2[B, BP, string] + F2 []GenericTest2[B, BP, string] `msg:",allownil"` + G2 map[string]GenericTest2[B, BP, string] `msg:",allownil"` + AP2 *B + CP2 []*B `msg:",allownil"` + DP2 map[string]*B `msg:",allownil"` + EP2 *GenericTest2[B, BP, string] + FP2 []*GenericTest2[B, BP, string] `msg:",allownil"` + GP2 map[string]*GenericTest2[B, BP, string] `msg:",allownil"` +} + +type GenericTest3[A, B any, _ msgp.RTFor[A], _ msgp.RTFor[B]] struct { + A A + B B +} diff --git a/_generated/generics_test.go b/_generated/generics_test.go new file mode 100644 index 00000000..a669859e --- /dev/null +++ b/_generated/generics_test.go @@ -0,0 +1,65 @@ +package _generated + +import ( + "bytes" + "reflect" + "testing" + + "github.com/tinylib/msgp/msgp" +) + +func TestGenericsMarshal(t *testing.T) { + x := GenericTest[Fixed, *Fixed]{} + x.B = &Fixed{A: 1.5} + x.C = append(x.C, Fixed{A: 2.5}) + x.D = map[string]Fixed{"hello": Fixed{A: 2.5}} + x.E.A = Fixed{A: 2.5} + x.F = append(x.F, GenericTest2[Fixed, *Fixed, string]{A: Fixed{A: 3.5}}) + x.G = map[string]GenericTest2[Fixed, *Fixed, string]{"hello": {A: Fixed{A: 3.5}}} + + bts, err := x.MarshalMsg(nil) + if err != nil { + t.Fatal(err) + } + got := GenericTest[Fixed, *Fixed]{} + got.B = x.B // We must initialize this. + *got.B = Fixed{} + bts, err = got.UnmarshalMsg(bts) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(x, got) { + t.Errorf("\n got=%#v\nwant=%#v", got, x) + } +} + +func TestGenericsEncode(t *testing.T) { + x := GenericTest[Fixed, *Fixed]{} + x.B = &Fixed{A: 1.5} + x.C = append(x.C, Fixed{A: 2.5}) + x.D = map[string]Fixed{"hello": Fixed{A: 2.5}} + x.E.A = Fixed{A: 2.5} + x.F = append(x.F, GenericTest2[Fixed, *Fixed, string]{A: Fixed{A: 3.5}}) + x.G = map[string]GenericTest2[Fixed, *Fixed, string]{"hello": {A: Fixed{A: 3.5}}} + + var buf bytes.Buffer + w := msgp.NewWriter(&buf) + err := x.EncodeMsg(w) + if err != nil { + t.Fatal(err) + } + w.Flush() + got := GenericTest[Fixed, *Fixed]{} + got.B = x.B // We must initialize this. + *got.B = Fixed{} + r := msgp.NewReader(&buf) + err = got.DecodeMsg(r) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(x, got) { + t.Errorf("\n got=%#v\nwant=%#v", got, x) + } +} diff --git a/gen/decode.go b/gen/decode.go index b3276c08..07352c6e 100644 --- a/gen/decode.go +++ b/gen/decode.go @@ -1,6 +1,7 @@ package gen import ( + "fmt" "io" "strconv" "strings" @@ -96,6 +97,7 @@ func (d *decodeGen) structAsTuple(s *Struct) { } SetIsAllowNil(fieldElem, anField) d.ctx.PushString(s.Fields[i].FieldName) + setTypeParams(fieldElem, s.typeParams) next(d, fieldElem) d.ctx.Pop() if anField { @@ -147,6 +149,7 @@ func (d *decodeGen) structAsMap(s *Struct) { d.p.printf("\n%s = nil\n} else {", fieldElem.Varname()) } SetIsAllowNil(fieldElem, anField) + setTypeParams(fieldElem, s.typeParams) next(d, fieldElem) if oeCount > 0 && (s.Fields[i].HasTagPart("omitempty") || s.Fields[i].HasTagPart("omitzero")) { d.p.printf("\n%s", bm.setStmt(len(oeEmittedIdx))) @@ -218,10 +221,20 @@ func (d *decodeGen) gBase(b *BaseElem) { checkNil = vname } case IDENT: + dst := b.BaseType() + if b.typeParams.isPtr { + dst = "*" + dst + } if b.Convert { + if remap := b.typeParams.ToPointerMap[dst]; remap != "" { + vname = fmt.Sprintf(remap, vname) + } lowered := b.ToBase() + "(" + vname + ")" d.p.printf("\nerr = %s.DecodeMsg(dc)", lowered) } else { + if remap := b.typeParams.ToPointerMap[dst]; remap != "" { + vname = fmt.Sprintf(remap, vname) + } d.p.printf("\nerr = %s.DecodeMsg(dc)", vname) } case Ext: @@ -279,6 +292,7 @@ func (d *decodeGen) gMap(m *Map) { d.p.declare(m.Validx, m.Value.TypeName()) d.ctx.PushVar(m.Keyidx) m.Value.SetIsAllowNil(false) + setTypeParams(m.Value, m.typeParams) next(d, m.Value) d.p.mapAssign(m) d.ctx.Pop() @@ -297,6 +311,7 @@ func (d *decodeGen) gSlice(s *Slice) { } else { d.p.resizeSlice(sz, s) } + setTypeParams(s.Els, s.typeParams) d.p.rangeBlock(d.ctx, s.Index, s.Varname(), d, s.Els) } @@ -315,6 +330,7 @@ func (d *decodeGen) gArray(a *Array) { d.p.declare(sz, u32) d.assignAndCheck(sz, arrayHeader) d.p.arrayCheck(coerceArraySize(a.Size), sz) + setTypeParams(a.Els, a.typeParams) d.p.rangeBlock(d.ctx, a.Index, a.Varname(), d, a.Els) } @@ -327,6 +343,11 @@ func (d *decodeGen) gPtr(p *Ptr) { d.p.wrapErrCheck(d.ctx.ArgsStr()) d.p.printf("\n%s = nil\n} else {", p.Varname()) d.p.initPtr(p) + if p.typeParams.TypeParams != "" { + tp := p.typeParams + tp.isPtr = true + p.Value.SetTypeParams(tp) + } next(d, p.Value) d.p.closeblock() } diff --git a/gen/elem.go b/gen/elem.go index 3b71bc1a..7868a3d9 100644 --- a/gen/elem.go +++ b/gen/elem.go @@ -149,14 +149,24 @@ var builtins = map[string]struct{}{ type common struct { vname, alias string ptrRcv bool + typeParams GenericTypeParams // Generic type parameters, e.g., "[T]" } -func (c *common) SetVarname(s string) { c.vname = s } -func (c *common) Varname() string { return c.vname } -func (c *common) Alias(typ string) { c.alias = typ } -func (c *common) hidden() {} -func (c *common) AllowNil() bool { return false } -func (c *common) SetIsAllowNil(bool) {} +// GenericTypeParams is a struct that contains the generic type parameters for an element. +type GenericTypeParams struct { + TypeParams string + ToPointerMap map[string]string + isPtr bool +} + +func (c *common) SetVarname(s string) { c.vname = s } +func (c *common) Varname() string { return c.vname } +func (c *common) Alias(typ string) { c.alias = typ } +func (c *common) hidden() {} +func (c *common) AllowNil() bool { return false } +func (c *common) SetIsAllowNil(bool) {} +func (c *common) SetTypeParams(tp GenericTypeParams) { c.typeParams = tp } +func (c *common) TypeParams() GenericTypeParams { return c.typeParams } func (c *common) AlwaysPtr(set *bool) bool { if c != nil && set != nil { c.ptrRcv = *set @@ -228,6 +238,12 @@ type Elem interface { // Note that this is NOT used by the `omitzero` feature. IfZeroExpr() string + // SetTypeParams sets the generic type parameters for this element + SetTypeParams(tp GenericTypeParams) + + // TypeParams returns the generic type parameters for this element + TypeParams() GenericTypeParams + hidden() } diff --git a/gen/encode.go b/gen/encode.go index 0b122743..78613821 100644 --- a/gen/encode.go +++ b/gen/encode.go @@ -112,6 +112,7 @@ func (e *encodeGen) tuple(s *Struct) { } SetIsAllowNil(fieldElem, anField) e.ctx.PushString(s.Fields[i].FieldName) + setTypeParams(s.Fields[i].FieldElem, s.typeParams) next(e, s.Fields[i].FieldElem) e.ctx.Pop() if anField { @@ -224,6 +225,7 @@ func (e *encodeGen) structmap(s *Struct) { SetIsAllowNil(fieldElem, anField) e.ctx.PushString(s.Fields[i].FieldName) + setTypeParams(s.Fields[i].FieldElem, s.typeParams) next(e, s.Fields[i].FieldElem) e.ctx.Pop() @@ -269,6 +271,7 @@ func (e *encodeGen) gMap(m *Map) { } e.ctx.PushVar(m.Keyidx) m.Value.SetIsAllowNil(false) + setTypeParams(m.Value, m.typeParams) next(e, m.Value) e.ctx.Pop() e.p.closeblock() @@ -280,6 +283,11 @@ func (e *encodeGen) gPtr(s *Ptr) { } e.fuseHook() e.p.printf("\nif %s == nil { err = en.WriteNil(); if err != nil { return; } } else {", s.Varname()) + if s.typeParams.TypeParams != "" { + tp := s.typeParams + tp.isPtr = true + s.Value.SetTypeParams(tp) + } next(e, s.Value) e.p.closeblock() } @@ -290,6 +298,7 @@ func (e *encodeGen) gSlice(s *Slice) { } e.fuseHook() e.writeAndCheck(arrayHeader, lenAsUint32, s.Varname()) + setTypeParams(s.Els, s.typeParams) e.p.rangeBlock(e.ctx, s.Index, s.Varname(), e, s.Els) } @@ -306,6 +315,7 @@ func (e *encodeGen) gArray(a *Array) { } e.writeAndCheck(arrayHeader, literalFmt, coerceArraySize(a.Size)) + setTypeParams(a.Els, a.typeParams) e.p.rangeBlock(e.ctx, a.Index, a.Varname(), e, a.Els) } @@ -330,6 +340,13 @@ func (e *encodeGen) gBase(b *BaseElem) { t := strings.TrimPrefix(b.BaseName(), "atomic.") e.writeAndCheck(t, literalFmt, strings.TrimPrefix(vname, "*")+".Load()") case IDENT: // unknown identity + dst := b.BaseType() + if b.typeParams.isPtr { + dst = "*" + dst + } + if remap := b.typeParams.ToPointerMap[dst]; remap != "" { + vname = fmt.Sprintf(remap, vname) + } e.p.printf("\nerr = %s.EncodeMsg(en)", vname) e.p.wrapErrCheck(e.ctx.ArgsStr()) default: diff --git a/gen/marshal.go b/gen/marshal.go index 33fa60b3..5ce6c585 100644 --- a/gen/marshal.go +++ b/gen/marshal.go @@ -119,6 +119,7 @@ func (m *marshalGen) tuple(s *Struct) { } m.ctx.PushString(s.Fields[i].FieldName) SetIsAllowNil(fieldElem, anField) + setTypeParams(fieldElem, s.typeParams) next(m, fieldElem) m.ctx.Pop() if anField { @@ -221,6 +222,7 @@ func (m *marshalGen) mapstruct(s *Struct) { } m.ctx.PushString(s.Fields[i].FieldName) SetIsAllowNil(fieldElem, anField) + setTypeParams(fieldElem, s.typeParams) next(m, fieldElem) m.ctx.Pop() @@ -276,6 +278,7 @@ func (m *marshalGen) gMap(s *Map) { m.ctx.PushVar(s.Keyidx) s.Value.SetIsAllowNil(false) + setTypeParams(s.Value, s.typeParams) next(m, s.Value) m.ctx.Pop() m.p.closeblock() @@ -287,6 +290,8 @@ func (m *marshalGen) gSlice(s *Slice) { } m.fuseHook() vname := s.Varname() + setTypeParams(s.Els, s.typeParams) + m.rawAppend(arrayHeader, lenAsUint32, vname) m.p.rangeBlock(m.ctx, s.Index, vname, m, s.Els) } @@ -300,17 +305,31 @@ func (m *marshalGen) gArray(a *Array) { m.rawAppend("Bytes", "(%s)[:]", a.Varname()) return } + setTypeParams(a.Els, a.typeParams) m.rawAppend(arrayHeader, literalFmt, coerceArraySize(a.Size)) m.p.rangeBlock(m.ctx, a.Index, a.Varname(), m, a.Els) } +func setTypeParams(e Elem, tp GenericTypeParams) { + if e == nil { + return + } + tp.isPtr = false + e.SetTypeParams(tp) +} + func (m *marshalGen) gPtr(p *Ptr) { if !m.p.ok() { return } m.fuseHook() m.p.printf("\nif %s == nil {\no = msgp.AppendNil(o)\n} else {", p.Varname()) + if p.typeParams.TypeParams != "" { + tp := p.typeParams + tp.isPtr = true + p.Value.SetTypeParams(tp) + } next(m, p.Value) m.p.closeblock() } @@ -335,6 +354,13 @@ func (m *marshalGen) gBase(b *BaseElem) { var echeck bool switch b.Value { case IDENT: + dst := b.BaseType() + if b.typeParams.isPtr { + dst = "*" + dst + } + if remap := b.typeParams.ToPointerMap[dst]; remap != "" { + vname = fmt.Sprintf(remap, vname) + } echeck = true m.p.printf("\no, err = %s.MarshalMsg(o)", vname) case Intf, Ext, JsonNumber: diff --git a/gen/size.go b/gen/size.go index cd89638f..53504d8d 100644 --- a/gen/size.go +++ b/gen/size.go @@ -116,6 +116,7 @@ func (s *sizeGen) gStruct(st *Struct) { if !s.p.ok() { return } + setTypeParams(st.Fields[i].FieldElem, st.typeParams) next(s, st.Fields[i].FieldElem) } } else { @@ -125,6 +126,7 @@ func (s *sizeGen) gStruct(st *Struct) { data = data[:0] data = msgp.AppendString(data, st.Fields[i].FieldTag) s.addConstant(strconv.Itoa(len(data))) + setTypeParams(st.Fields[i].FieldElem, st.typeParams) next(s, st.Fields[i].FieldElem) } } @@ -133,6 +135,11 @@ func (s *sizeGen) gStruct(st *Struct) { func (s *sizeGen) gPtr(p *Ptr) { s.state = add // inner must use add s.p.printf("\nif %s == nil {\ns += msgp.NilSize\n} else {", p.Varname()) + if p.typeParams.TypeParams != "" { + tp := p.typeParams + tp.isPtr = true + p.Value.SetTypeParams(tp) + } next(s, p.Value) s.state = add // closing block; reset to add s.p.closeblock() @@ -154,6 +161,7 @@ func (s *sizeGen) gSlice(sl *Slice) { } // add inside the range block, and immediately after + setTypeParams(sl.Els, sl.typeParams) s.state = add s.p.rangeBlock(s.ctx, sl.Index, sl.Varname(), s, sl.Els) s.state = add @@ -174,6 +182,7 @@ func (s *sizeGen) gArray(a *Array) { return } + setTypeParams(a.Els, a.typeParams) s.state = add s.p.rangeBlock(s.ctx, a.Index, a.Varname(), s, a.Els) s.state = add @@ -193,6 +202,7 @@ func (s *sizeGen) gMap(m *Map) { keyIdx = fmt.Sprintf("%s(%s)", toBase, keyIdx) s.p.printf("\ns += msgp.StringPrefixSize + len(%s)", keyIdx) } else if key.Value == IDENT && key.ShimToBase == "" && m.AllowBinMaps { + //TODO: Generic keys? s.p.printf("\ns += %s.Msgsize()", keyIdx) } default: @@ -205,6 +215,7 @@ func (s *sizeGen) gMap(m *Map) { } s.state = expr s.ctx.PushVar(m.Keyidx) + setTypeParams(m.Value, m.typeParams) next(s, m.Value) s.ctx.Pop() s.p.closeblock() @@ -232,6 +243,13 @@ func (s *sizeGen) gBase(b *BaseElem) { if b.Convert { vname = tobaseConvert(b) } + dst := b.BaseType() + if b.typeParams.isPtr { + dst = "*" + dst + } + if remap := b.typeParams.ToPointerMap[dst]; remap != "" { + vname = fmt.Sprintf(remap, vname) + } s.addConstant(basesizeExpr(b.Value, vname, b.BaseName())) } } diff --git a/gen/spec.go b/gen/spec.go index 42aa45ab..3307d338 100644 --- a/gen/spec.go +++ b/gen/spec.go @@ -268,6 +268,9 @@ func next(t traversal, e Elem) { // possibly-immutable method receiver func imutMethodReceiver(p Elem) string { + typeName := p.TypeName() + typeParams := p.TypeParams() + switch e := p.(type) { case *Struct: // TODO(HACK): actually do real math here. @@ -277,19 +280,19 @@ func imutMethodReceiver(p Elem) string { goto nope } } - return p.TypeName() + return typeName + typeParams.TypeParams } nope: - return "*" + p.TypeName() + return "*" + typeName + typeParams.TypeParams // gets dereferenced automatically case *Array: - return "*" + p.TypeName() + return "*" + typeName + typeParams.TypeParams // everything else can be // by-value. default: - return p.TypeName() + return typeName + typeParams.TypeParams } } @@ -297,18 +300,21 @@ func imutMethodReceiver(p Elem) string { // so that its method receiver // is of the write type. func methodReceiver(p Elem) string { + typeName := p.TypeName() + typeParams := p.TypeParams() + switch p.(type) { // structs and arrays are // dereferenced automatically, // so no need to alter varname case *Struct, *Array: - return "*" + p.TypeName() + return "*" + typeName + typeParams.TypeParams // set variable name to // *varname default: p.SetVarname("(*" + p.Varname() + ")") - return "*" + p.TypeName() + return "*" + typeName + typeParams.TypeParams } } @@ -404,7 +410,11 @@ func (p *printer) clearMap(name string) { func (p *printer) wrapErrCheck(ctx string) { p.print("\nif err != nil {") - p.printf("\nerr = msgp.WrapError(err, %s)", ctx) + if ctx != "" { + p.printf("\nerr = msgp.WrapError(err, %s)", ctx) + } else { + p.print("\nerr = msgp.WrapError(err)") + } p.printf("\nreturn") p.print("\n}") } diff --git a/gen/testgen.go b/gen/testgen.go index d8e85b58..533b0268 100644 --- a/gen/testgen.go +++ b/gen/testgen.go @@ -1,6 +1,7 @@ package gen import ( + "fmt" "io" "text/template" ) @@ -29,6 +30,10 @@ type mtestGen struct { func (m *mtestGen) Execute(p Elem, _ Context) error { p = m.applyall(p) if p != nil && IsPrintable(p) { + if p.TypeParams().TypeParams != "" { + fmt.Fprintf(m.w, "\n// %s: Cannot generate marshal test for generic types", p.TypeName()) + return nil + } switch p.(type) { case *Struct, *Array, *Slice, *Map: return marshalTestTempl.Execute(m.w, p) @@ -51,6 +56,10 @@ func etest(w io.Writer) *etestGen { func (e *etestGen) Execute(p Elem, _ Context) error { p = e.applyall(p) if p != nil && IsPrintable(p) { + if p.TypeParams().TypeParams != "" { + fmt.Fprintf(e.w, "\n// %s: Cannot generate encoder test for generic types", p.TypeName()) + return nil + } switch p.(type) { case *Struct, *Array, *Slice, *Map: return encodeTestTempl.Execute(e.w, p) diff --git a/gen/unmarshal.go b/gen/unmarshal.go index 44d02a0d..f1d6bfce 100644 --- a/gen/unmarshal.go +++ b/gen/unmarshal.go @@ -1,6 +1,7 @@ package gen import ( + "fmt" "io" "strconv" "strings" @@ -97,6 +98,7 @@ func (u *unmarshalGen) tuple(s *Struct) { if s.Fields[i].HasTagPart("zerocopy") { setRecursiveZC(fieldElem, true) } + setTypeParams(fieldElem, s.typeParams) next(u, fieldElem) if s.Fields[i].HasTagPart("zerocopy") { setRecursiveZC(fieldElem, false) @@ -173,6 +175,8 @@ func (u *unmarshalGen) mapstruct(s *Struct) { if s.Fields[i].HasTagPart("zerocopy") { setRecursiveZC(fieldElem, true) } + setTypeParams(fieldElem, s.typeParams) + next(u, fieldElem) if s.Fields[i].HasTagPart("zerocopy") { setRecursiveZC(fieldElem, false) @@ -239,6 +243,14 @@ func (u *unmarshalGen) gBase(b *BaseElem) { if b.Convert { lowered = b.ToBase() + "(" + lowered + ")" } + dst := b.BaseType() + if b.typeParams.isPtr { + dst = "*" + dst + } + if remap := b.typeParams.ToPointerMap[dst]; remap != "" { + lowered = fmt.Sprintf(remap, lowered) + } + u.p.printf("\nbts, err = %s.UnmarshalMsg(bts)", lowered) case Time: if u.ctx.asUTC { @@ -292,6 +304,7 @@ func (u *unmarshalGen) gArray(a *Array) { u.p.declare(sz, u32) u.assignAndCheck(sz, arrayHeader) u.p.arrayCheck(coerceArraySize(a.Size), sz) + setTypeParams(a.Els, a.typeParams) u.p.rangeBlock(u.ctx, a.Index, a.Varname(), u, a.Els) } @@ -307,6 +320,7 @@ func (u *unmarshalGen) gSlice(s *Slice) { } else { u.p.resizeSlice(sz, s) } + setTypeParams(s.Els, s.typeParams) u.p.rangeBlock(u.ctx, s.Index, s.Varname(), u, s.Els) } @@ -331,6 +345,7 @@ func (u *unmarshalGen) gMap(m *Map) { m.readKey(u.ctx, u.p, u, u.assignAndCheck) u.ctx.PushVar(m.Keyidx) m.Value.SetIsAllowNil(false) + setTypeParams(m.Value, m.typeParams) next(u, m.Value) u.ctx.Pop() u.p.mapAssign(m) @@ -340,6 +355,11 @@ func (u *unmarshalGen) gMap(m *Map) { func (u *unmarshalGen) gPtr(p *Ptr) { u.p.printf("\nif msgp.IsNil(bts) { bts, err = msgp.ReadNilBytes(bts); if err != nil { return }; %s = nil; } else { ", p.Varname()) u.p.initPtr(p) + if p.typeParams.TypeParams != "" { + tp := p.typeParams + tp.isPtr = true + p.Value.SetTypeParams(tp) + } next(u, p.Value) u.p.closeblock() } diff --git a/msgp/defs.go b/msgp/defs.go index 47a8c183..f622bf1e 100644 --- a/msgp/defs.go +++ b/msgp/defs.go @@ -26,6 +26,27 @@ // the wiki at http://github.com/tinylib/msgp package msgp +// RT is the runtime interface for all types that can be encoded and decoded. +type RT interface { + Decodable + Encodable + Sizer + Unmarshaler + Marshaler +} + +// PtrTo is the runtime interface for all types that can be encoded and decoded. +type PtrTo[T any] interface { + ~*T +} + +// RTFor is the runtime interface for all types that can be encoded and decoded. +// Use for generic types. +type RTFor[T any] interface { + PtrTo[T] + RT +} + const ( last4 = 0x0f first4 = 0xf0 diff --git a/parse/getast.go b/parse/getast.go index 48cf2827..ac7ed4bd 100644 --- a/parse/getast.go +++ b/parse/getast.go @@ -13,22 +13,29 @@ import ( "github.com/tinylib/msgp/gen" ) +// TypeInfo holds both the type expression and its generic type parameters +type TypeInfo struct { + Type ast.Expr // The actual type expression + TypeParams *ast.FieldList // Generic type parameters +} + // A FileSet is the in-memory representation of a // parsed file. type FileSet struct { - Package string // package name - Specs map[string]ast.Expr // type specs in file - Identities map[string]gen.Elem // processed from specs - Aliased map[string]string // Aliased types. - Directives []string // raw preprocessor directives - Imports []*ast.ImportSpec // imports - CompactFloats bool // Use smaller floats when feasible - ClearOmitted bool // Set omitted fields to zero value - NewTime bool // Set to use -1 extension for time.Time - AsUTC bool // Set timezone to UTC instead of local - AllowMapShims bool // Allow map keys to be shimmed (default true) - AllowBinMaps bool // Allow maps with binary keys to be used (default false) - AutoMapShims bool // Automatically shim map keys of builtin types(default false) + Package string // package name + Specs map[string]ast.Expr // type specs in file + TypeInfos map[string]*TypeInfo // type specs with generic info + Identities map[string]gen.Elem // processed from specs + Aliased map[string]string // Aliased types. + Directives []string // raw preprocessor directives + Imports []*ast.ImportSpec // imports + CompactFloats bool // Use smaller floats when feasible + ClearOmitted bool // Set omitted fields to zero value + NewTime bool // Set to use -1 extension for time.Time + AsUTC bool // Set timezone to UTC instead of local + AllowMapShims bool // Allow map keys to be shimmed (default true) + AllowBinMaps bool // Allow maps with binary keys to be used (default false) + AutoMapShims bool // Automatically shim map keys of builtin types(default false) tagName string // tag to read field names from pointerRcv bool // generate with pointer receivers. @@ -45,6 +52,7 @@ func File(name string, unexported bool, directives []string) (*FileSet, error) { defer popstate() fs := &FileSet{ Specs: make(map[string]ast.Expr), + TypeInfos: make(map[string]*TypeInfo), Identities: make(map[string]gen.Elem), Directives: append([]string{}, directives...), } @@ -213,6 +221,57 @@ func (fs *FileSet) resolve(ls linkset) { } } +// formatTypeParams converts an AST FieldList to a string representation +func formatTypeParams(params *ast.FieldList) string { + if params == nil || params.NumFields() == 0 { + return "" + } + + var paramStrs []string + for _, field := range params.List { + str := stringify(field.Type) + // Convert underscores to _RTn where n is the number of the parameter + convert := strings.HasPrefix(str, "msgp.RTFor[") + + // Each field can have multiple names (e.g., T, U constraint) + for _, name := range field.Names { + if convert && name.Name == "_" { + name.Name = fmt.Sprintf("_RT%d", len(paramStrs)+1) + } + // For method receivers, we only include the type parameter name + // The constraints are defined in the type declaration, not the method receiver + paramStrs = append(paramStrs, name.Name) + } + } + + return "[" + strings.Join(paramStrs, ", ") + "]" +} + +// formatTypeParams converts an AST FieldList to a string representation. +// For 'Foo[T any, P msgp.RTFor[T]]' will return {"T": "P"}. +func getMspTypeParams(params *ast.FieldList) map[string]string { + if params == nil || params.NumFields() == 0 { + return nil + } + + paramStrs := make(map[string]string) + for _, field := range params.List { + str := stringify(field.Type) + if !strings.HasPrefix(str, "msgp.RTFor[") { + continue + } + for _, name := range field.Names { + t := strings.TrimSuffix(strings.TrimPrefix(str, "msgp.RTFor["), "]") + paramStrs[t] = name.Name + "(&%s)" + paramStrs["*"+t] = name.Name + "(%s)" + paramStrs[name.Name] = "%s" + infof("found generic type %s, with roundtrippper %s\n", t, name.Name) + } + } + + return paramStrs +} + // process takes the contents of f.Specs and // uses them to populate f.Identities func (fs *FileSet) process() { @@ -227,6 +286,19 @@ parse: continue parse } el.AlwaysPtr(&fs.pointerRcv) + + // Apply type parameters if available + if typeInfo, ok := fs.TypeInfos[name]; ok && typeInfo.TypeParams != nil { + typeParamsStr := formatTypeParams(typeInfo.TypeParams) + ptrMap := getMspTypeParams(typeInfo.TypeParams) + if typeParamsStr != "" && ptrMap != nil { + el.SetTypeParams(gen.GenericTypeParams{ + TypeParams: typeParamsStr, + ToPointerMap: ptrMap, + }) + } + } + // push unresolved identities into // the graph of links and resolve after // we've handled every possible named type. @@ -347,12 +419,17 @@ func (fs *FileSet) getTypeSpecs(f *ast.File) { switch ts.Type.(type) { // this is the list of parse-able // type specs - case *ast.StructType, - *ast.ArrayType, + case *ast.ArrayType, *ast.StarExpr, - *ast.MapType, - *ast.Ident: + *ast.Ident, + *ast.StructType, + *ast.MapType: fs.Specs[ts.Name.Name] = ts.Type + // Store type info (no type params for non-struct types yet) + fs.TypeInfos[ts.Name.Name] = &TypeInfo{ + Type: ts.Type, + TypeParams: ts.TypeParams, + } } } } @@ -540,6 +617,16 @@ func stringify(e ast.Expr) string { } case *ast.BasicLit: return e.Value + case *ast.IndexExpr: + // Single type argument: Generic[T] + return fmt.Sprintf("%s[%s]", stringify(e.X), stringify(e.Index)) + case *ast.IndexListExpr: + // Multiple type arguments: Generic[A,B,...] + args := make([]string, 0, len(e.Indices)) + for _, ix := range e.Indices { + args = append(args, stringify(ix)) + } + return fmt.Sprintf("%s[%s]", stringify(e.X), strings.Join(args, ",")) } return "" } @@ -616,6 +703,7 @@ func (fs *FileSet) parseExpr(e ast.Expr) gen.Elem { // everything else. if b.Value == gen.IDENT { if _, ok := fs.Specs[e.Name]; !ok && fs.Aliased[e.Name] == "" { + // This can be a generic type. warnf("possible non-local identifier: %s\n", e.Name) } } @@ -683,6 +771,15 @@ func (fs *FileSet) parseExpr(e ast.Expr) gen.Elem { } return nil + case *ast.IndexExpr: + // Treat a generic instantiation like an identifier of the instantiated name. + // Example: GenericTest2[T] -> "GenericTest2[T]" + return gen.Ident(stringify(e)) + + case *ast.IndexListExpr: + // Treat a generic instantiation with multiple args similarly. + return gen.Ident(stringify(e)) + default: // other types not supported return nil } diff --git a/parse/inline.go b/parse/inline.go index 7cd7f643..d3f9d150 100644 --- a/parse/inline.go +++ b/parse/inline.go @@ -2,6 +2,7 @@ package parse import ( "sort" + "strings" "github.com/tinylib/msgp/gen" ) @@ -113,16 +114,16 @@ func (fs *FileSet) propInline() { switch el := all[i].el.(type) { case *gen.Struct: for i := range el.Fields { - fs.nextInline(&el.Fields[i].FieldElem, name) + fs.nextInline(&el.Fields[i].FieldElem, name, el.TypeParams()) } case *gen.Array: - fs.nextInline(&el.Els, name) + fs.nextInline(&el.Els, name, el.TypeParams()) case *gen.Slice: - fs.nextInline(&el.Els, name) + fs.nextInline(&el.Els, name, el.TypeParams()) case *gen.Map: - fs.nextInline(&el.Value, name) + fs.nextInline(&el.Value, name, el.TypeParams()) case *gen.Ptr: - fs.nextInline(&el.Value, name) + fs.nextInline(&el.Value, name, el.TypeParams()) } popstate() } @@ -133,7 +134,7 @@ Please file a bug at github.com/tinylib/msgp/issues! Thanks! ` -func (fs *FileSet) nextInline(ref *gen.Elem, root string) { +func (fs *FileSet) nextInline(ref *gen.Elem, root string, params gen.GenericTypeParams) { switch el := (*ref).(type) { case *gen.BaseElem: // ensure that we're not inlining @@ -150,26 +151,28 @@ func (fs *FileSet) nextInline(ref *gen.Elem, root string) { } *ref = node.Copy() - fs.nextInline(ref, node.TypeName()) + fs.nextInline(ref, node.TypeName(), params) } else if !ok && !el.Resolved() { - // this is the point at which we're sure that - // we've got a type that isn't a primitive, - // a library builtin, or a processed type - warnf("unresolved identifier: %s\n", typ) + if params.ToPointerMap[typ] == "" && (!strings.Contains(typ, "[") || !strings.Contains(typ, "]")) { + // this is the point at which we're sure that + // we've got a type that isn't a primitive, + // a library builtin, or a processed type + warnf("unresolved identifier: %s\n", typ) + } } } case *gen.Struct: for i := range el.Fields { - fs.nextInline(&el.Fields[i].FieldElem, root) + fs.nextInline(&el.Fields[i].FieldElem, root, el.TypeParams()) } case *gen.Array: - fs.nextInline(&el.Els, root) + fs.nextInline(&el.Els, root, params) case *gen.Slice: - fs.nextInline(&el.Els, root) + fs.nextInline(&el.Els, root, params) case *gen.Map: - fs.nextInline(&el.Value, root) + fs.nextInline(&el.Value, root, params) case *gen.Ptr: - fs.nextInline(&el.Value, root) + fs.nextInline(&el.Value, root, params) default: panic("bad elem type") }