Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions _generated/generics.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package _generated

import "github.com/tinylib/msgp/msgp"

//go:generate msgp -v

type Int64 int

type GenericTest[T any, P msgp.RTFor[T]] struct {
A T
B P
C []T `msg:",allownil"`
D map[string]T `msg:",allownil"`
E GenericTest2[T, P, string]
F []GenericTest2[T, P, string] `msg:",allownil"`
G map[string]GenericTest2[T, P, string] `msg:",allownil"`
AP *T
CP []*T `msg:",allownil"`
DP map[string]*T `msg:",allownil"`
EP *GenericTest2[T, P, string]
FP []*GenericTest2[T, P, string] `msg:",allownil"`
GP map[string]*GenericTest2[T, P, string] `msg:",allownil"`
}

type GenericTest2[T any, P msgp.RTFor[T], B any] struct {
A T
}

//msgp:tuple GenericTuple
type GenericTuple[T any, P msgp.RTFor[T], B any] struct {
A T
B P
C []T `msg:",allownil"`
D map[string]T `msg:",allownil"`
E GenericTest2[T, P, string]
F []GenericTest2[T, P, string] `msg:",allownil"`
G map[string]GenericTest2[T, P, string] `msg:",allownil"`
AP *T
CP []*T `msg:",allownil"`
DP map[string]*T `msg:",allownil"`
EP *GenericTest2[T, P, string]
FP []*GenericTest2[T, P, string] `msg:",allownil"`
GP map[string]*GenericTest2[T, P, string] `msg:",allownil"`
}

// Type that doesn't have any fields using the generic type should just output valid code.
type GenericTestUnused[T any] struct {
A string
}

type GenericTestTwo[A, B any, AP msgp.RTFor[A], BP msgp.RTFor[B]] struct {
A A
B AP
C []A `msg:",allownil"`
D map[string]T `msg:",allownil"`
E GenericTest2[A, AP, string]
F []GenericTest2[A, AP, string] `msg:",allownil"`
G map[string]GenericTest2[A, AP, string] `msg:",allownil"`
AP *A
CP []*A `msg:",allownil"`
DP map[string]*A `msg:",allownil"`
EP *GenericTest2[A, AP, string]
FP []*GenericTest2[A, AP, string] `msg:",allownil"`
GP map[string]*GenericTest2[A, AP, string] `msg:",allownil"`

A2 B
B2 BP
C2 []B `msg:",allownil"`
D2 map[string]B `msg:",allownil"`
E2 GenericTest2[B, BP, string]
F2 []GenericTest2[B, BP, string] `msg:",allownil"`
G2 map[string]GenericTest2[B, BP, string] `msg:",allownil"`
AP2 *B
CP2 []*B `msg:",allownil"`
DP2 map[string]*B `msg:",allownil"`
EP2 *GenericTest2[B, BP, string]
FP2 []*GenericTest2[B, BP, string] `msg:",allownil"`
GP2 map[string]*GenericTest2[B, BP, string] `msg:",allownil"`
}

type GenericTest3[A, B any, _ msgp.RTFor[A], _ msgp.RTFor[B]] struct {
A A
B B
}
65 changes: 65 additions & 0 deletions _generated/generics_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package _generated

import (
"bytes"
"reflect"
"testing"

"github.com/tinylib/msgp/msgp"
)

func TestGenericsMarshal(t *testing.T) {
x := GenericTest[Fixed, *Fixed]{}
x.B = &Fixed{A: 1.5}
x.C = append(x.C, Fixed{A: 2.5})
x.D = map[string]Fixed{"hello": Fixed{A: 2.5}}
x.E.A = Fixed{A: 2.5}
x.F = append(x.F, GenericTest2[Fixed, *Fixed, string]{A: Fixed{A: 3.5}})
x.G = map[string]GenericTest2[Fixed, *Fixed, string]{"hello": {A: Fixed{A: 3.5}}}

bts, err := x.MarshalMsg(nil)
if err != nil {
t.Fatal(err)
}
got := GenericTest[Fixed, *Fixed]{}
got.B = x.B // We must initialize this.
*got.B = Fixed{}
bts, err = got.UnmarshalMsg(bts)
if err != nil {
t.Fatal(err)
}

if !reflect.DeepEqual(x, got) {
t.Errorf("\n got=%#v\nwant=%#v", got, x)
}
}

func TestGenericsEncode(t *testing.T) {
x := GenericTest[Fixed, *Fixed]{}
x.B = &Fixed{A: 1.5}
x.C = append(x.C, Fixed{A: 2.5})
x.D = map[string]Fixed{"hello": Fixed{A: 2.5}}
x.E.A = Fixed{A: 2.5}
x.F = append(x.F, GenericTest2[Fixed, *Fixed, string]{A: Fixed{A: 3.5}})
x.G = map[string]GenericTest2[Fixed, *Fixed, string]{"hello": {A: Fixed{A: 3.5}}}

var buf bytes.Buffer
w := msgp.NewWriter(&buf)
err := x.EncodeMsg(w)
if err != nil {
t.Fatal(err)
}
w.Flush()
got := GenericTest[Fixed, *Fixed]{}
got.B = x.B // We must initialize this.
*got.B = Fixed{}
r := msgp.NewReader(&buf)
err = got.DecodeMsg(r)
if err != nil {
t.Fatal(err)
}

if !reflect.DeepEqual(x, got) {
t.Errorf("\n got=%#v\nwant=%#v", got, x)
}
}
21 changes: 21 additions & 0 deletions gen/decode.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gen

import (
"fmt"
"io"
"strconv"
"strings"
Expand Down Expand Up @@ -96,6 +97,7 @@ func (d *decodeGen) structAsTuple(s *Struct) {
}
SetIsAllowNil(fieldElem, anField)
d.ctx.PushString(s.Fields[i].FieldName)
setTypeParams(fieldElem, s.typeParams)
next(d, fieldElem)
d.ctx.Pop()
if anField {
Expand Down Expand Up @@ -147,6 +149,7 @@ func (d *decodeGen) structAsMap(s *Struct) {
d.p.printf("\n%s = nil\n} else {", fieldElem.Varname())
}
SetIsAllowNil(fieldElem, anField)
setTypeParams(fieldElem, s.typeParams)
next(d, fieldElem)
if oeCount > 0 && (s.Fields[i].HasTagPart("omitempty") || s.Fields[i].HasTagPart("omitzero")) {
d.p.printf("\n%s", bm.setStmt(len(oeEmittedIdx)))
Expand Down Expand Up @@ -218,10 +221,20 @@ func (d *decodeGen) gBase(b *BaseElem) {
checkNil = vname
}
case IDENT:
dst := b.BaseType()
if b.typeParams.isPtr {
dst = "*" + dst
}
if b.Convert {
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
lowered := b.ToBase() + "(" + vname + ")"
d.p.printf("\nerr = %s.DecodeMsg(dc)", lowered)
} else {
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
d.p.printf("\nerr = %s.DecodeMsg(dc)", vname)
}
case Ext:
Expand Down Expand Up @@ -279,6 +292,7 @@ func (d *decodeGen) gMap(m *Map) {
d.p.declare(m.Validx, m.Value.TypeName())
d.ctx.PushVar(m.Keyidx)
m.Value.SetIsAllowNil(false)
setTypeParams(m.Value, m.typeParams)
next(d, m.Value)
d.p.mapAssign(m)
d.ctx.Pop()
Expand All @@ -297,6 +311,7 @@ func (d *decodeGen) gSlice(s *Slice) {
} else {
d.p.resizeSlice(sz, s)
}
setTypeParams(s.Els, s.typeParams)
d.p.rangeBlock(d.ctx, s.Index, s.Varname(), d, s.Els)
}

Expand All @@ -315,6 +330,7 @@ func (d *decodeGen) gArray(a *Array) {
d.p.declare(sz, u32)
d.assignAndCheck(sz, arrayHeader)
d.p.arrayCheck(coerceArraySize(a.Size), sz)
setTypeParams(a.Els, a.typeParams)
d.p.rangeBlock(d.ctx, a.Index, a.Varname(), d, a.Els)
}

Expand All @@ -327,6 +343,11 @@ func (d *decodeGen) gPtr(p *Ptr) {
d.p.wrapErrCheck(d.ctx.ArgsStr())
d.p.printf("\n%s = nil\n} else {", p.Varname())
d.p.initPtr(p)
if p.typeParams.TypeParams != "" {
tp := p.typeParams
tp.isPtr = true
p.Value.SetTypeParams(tp)
}
next(d, p.Value)
d.p.closeblock()
}
28 changes: 22 additions & 6 deletions gen/elem.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,24 @@ var builtins = map[string]struct{}{
type common struct {
vname, alias string
ptrRcv bool
typeParams GenericTypeParams // Generic type parameters, e.g., "[T]"
}

func (c *common) SetVarname(s string) { c.vname = s }
func (c *common) Varname() string { return c.vname }
func (c *common) Alias(typ string) { c.alias = typ }
func (c *common) hidden() {}
func (c *common) AllowNil() bool { return false }
func (c *common) SetIsAllowNil(bool) {}
// GenericTypeParams is a struct that contains the generic type parameters for an element.
type GenericTypeParams struct {
TypeParams string
ToPointerMap map[string]string
isPtr bool
}

func (c *common) SetVarname(s string) { c.vname = s }
func (c *common) Varname() string { return c.vname }
func (c *common) Alias(typ string) { c.alias = typ }
func (c *common) hidden() {}
func (c *common) AllowNil() bool { return false }
func (c *common) SetIsAllowNil(bool) {}
func (c *common) SetTypeParams(tp GenericTypeParams) { c.typeParams = tp }
func (c *common) TypeParams() GenericTypeParams { return c.typeParams }
func (c *common) AlwaysPtr(set *bool) bool {
if c != nil && set != nil {
c.ptrRcv = *set
Expand Down Expand Up @@ -228,6 +238,12 @@ type Elem interface {
// Note that this is NOT used by the `omitzero` feature.
IfZeroExpr() string

// SetTypeParams sets the generic type parameters for this element
SetTypeParams(tp GenericTypeParams)

// TypeParams returns the generic type parameters for this element
TypeParams() GenericTypeParams

hidden()
}

Expand Down
17 changes: 17 additions & 0 deletions gen/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ func (e *encodeGen) tuple(s *Struct) {
}
SetIsAllowNil(fieldElem, anField)
e.ctx.PushString(s.Fields[i].FieldName)
setTypeParams(s.Fields[i].FieldElem, s.typeParams)
next(e, s.Fields[i].FieldElem)
e.ctx.Pop()
if anField {
Expand Down Expand Up @@ -224,6 +225,7 @@ func (e *encodeGen) structmap(s *Struct) {
SetIsAllowNil(fieldElem, anField)

e.ctx.PushString(s.Fields[i].FieldName)
setTypeParams(s.Fields[i].FieldElem, s.typeParams)
next(e, s.Fields[i].FieldElem)
e.ctx.Pop()

Expand Down Expand Up @@ -269,6 +271,7 @@ func (e *encodeGen) gMap(m *Map) {
}
e.ctx.PushVar(m.Keyidx)
m.Value.SetIsAllowNil(false)
setTypeParams(m.Value, m.typeParams)
next(e, m.Value)
e.ctx.Pop()
e.p.closeblock()
Expand All @@ -280,6 +283,11 @@ func (e *encodeGen) gPtr(s *Ptr) {
}
e.fuseHook()
e.p.printf("\nif %s == nil { err = en.WriteNil(); if err != nil { return; } } else {", s.Varname())
if s.typeParams.TypeParams != "" {
tp := s.typeParams
tp.isPtr = true
s.Value.SetTypeParams(tp)
}
next(e, s.Value)
e.p.closeblock()
}
Expand All @@ -290,6 +298,7 @@ func (e *encodeGen) gSlice(s *Slice) {
}
e.fuseHook()
e.writeAndCheck(arrayHeader, lenAsUint32, s.Varname())
setTypeParams(s.Els, s.typeParams)
e.p.rangeBlock(e.ctx, s.Index, s.Varname(), e, s.Els)
}

Expand All @@ -306,6 +315,7 @@ func (e *encodeGen) gArray(a *Array) {
}

e.writeAndCheck(arrayHeader, literalFmt, coerceArraySize(a.Size))
setTypeParams(a.Els, a.typeParams)
e.p.rangeBlock(e.ctx, a.Index, a.Varname(), e, a.Els)
}

Expand All @@ -330,6 +340,13 @@ func (e *encodeGen) gBase(b *BaseElem) {
t := strings.TrimPrefix(b.BaseName(), "atomic.")
e.writeAndCheck(t, literalFmt, strings.TrimPrefix(vname, "*")+".Load()")
case IDENT: // unknown identity
dst := b.BaseType()
if b.typeParams.isPtr {
dst = "*" + dst
}
if remap := b.typeParams.ToPointerMap[dst]; remap != "" {
vname = fmt.Sprintf(remap, vname)
}
e.p.printf("\nerr = %s.EncodeMsg(en)", vname)
e.p.wrapErrCheck(e.ctx.ArgsStr())
default:
Expand Down
Loading
Loading