-
Notifications
You must be signed in to change notification settings - Fork 73
/
pool.go
105 lines (89 loc) · 2.52 KB
/
pool.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package pool
import (
"fmt"
"google.golang.org/protobuf/compiler/protogen"
"google.golang.org/protobuf/reflect/protoreflect"
"github.com/planetscale/vtprotobuf/generator"
)
// init registers the "pool" feature with the vtprotobuf generator. Each time
// the registered factory is invoked it returns a new pool generator bound to
// the GeneratedFile it was given.
func init() {
	newPoolFeature := func(gen *generator.GeneratedFile) generator.FeatureGenerator {
		return &pool{GeneratedFile: gen}
	}
	generator.RegisterFeature("pool", newPoolFeature)
}
// pool is the FeatureGenerator behind the "pool" feature. For every message
// accepted by ShouldPool it emits a sync.Pool variable together with the
// ResetVT, ReturnToVTPool and <Type>FromVTPool helpers.
type pool struct {
	// GeneratedFile is the output file that P and Ident write into.
	*generator.GeneratedFile

	// once becomes true as soon as pooling code has been emitted for at
	// least one message; GenerateFile returns it.
	once bool
}

// Compile-time check that *pool satisfies generator.FeatureGenerator.
var _ generator.FeatureGenerator = (*pool)(nil)
// GenerateFile emits pooling code for every top-level message of file (nested
// messages are handled recursively by message). It returns true once this
// generator has emitted pooling code for at least one message; the flag is
// sticky across calls on the same pool.
func (p *pool) GenerateFile(file *protogen.File) bool {
	msgs := file.Messages
	for i := range msgs {
		p.message(msgs[i])
	}
	return p.once
}
// message emits pooling support for message and, recursively, for every
// message nested inside it. Nested messages are visited first; map-entry
// messages and messages rejected by ShouldPool are then skipped.
//
// For a pooled message type T the generated code consists of:
//   - vtprotoPool_T, a sync.Pool whose New func returns &T{};
//   - (*T).ResetVT, emitted by resetVT;
//   - (*T).ReturnToVTPool, which calls ResetVT and puts m back into the pool;
//   - TFromVTPool, which takes a *T out of the pool.
//
// It also sets p.once so GenerateFile reports that code was emitted.
func (p *pool) message(message *protogen.Message) {
	for _, nested := range message.Messages {
		p.message(nested)
	}

	if message.Desc.IsMapEntry() || !p.ShouldPool(message) {
		return
	}

	p.once = true
	ccTypeName := message.GoIdent

	p.P(`var vtprotoPool_`, ccTypeName, ` = `, p.Ident("sync", "Pool"), `{`)
	p.P(`New: func() interface{} {`)
	p.P(`return &`, ccTypeName, `{}`)
	p.P(`},`)
	p.P(`}`)

	p.resetVT(message)

	// ReturnToVTPool is nil-safe; resetVT relies on that when it returns
	// singular sub-messages that may be unset.
	p.P(`func (m *`, ccTypeName, `) ReturnToVTPool() {`)
	p.P(`if m != nil {`)
	p.P(`m.ResetVT()`)
	p.P(`vtprotoPool_`, ccTypeName, `.Put(m)`)
	p.P(`}`)
	p.P(`}`)

	p.P(`func `, ccTypeName, `FromVTPool() *`, ccTypeName, `{`)
	p.P(`return vtprotoPool_`, ccTypeName, `.Get().(*`, ccTypeName, `)`)
	p.P(`}`)
}

// resetVT emits the (*T).ResetVT method for message T. On a non-nil receiver
// the generated method:
//   - resets every element of repeated message fields (ResetVT for pooled
//     element types, Reset otherwise);
//   - hands pooled singular sub-messages and pooled oneof message members
//     back to their pools with ReturnToVTPool;
//   - calls m.Reset(), then reinstates repeated fields and implicit-presence
//     bytes fields truncated to length zero, so their backing arrays survive
//     the reset and can be reused.
func (p *pool) resetVT(message *protogen.Message) {
	p.P(`func (m *`, message.GoIdent, `) ResetVT() {`)
	p.P(`if m != nil {`)

	// saved lists the fields whose truncated slices are stashed in locals
	// f0, f1, ... before m.Reset() and written back afterwards.
	var saved []*protogen.Field
	save := func(field *protogen.Field) {
		p.P(fmt.Sprintf("f%d", len(saved)), ` := m.`, field.GoName, `[:0]`)
		saved = append(saved, field)
	}

	for _, field := range message.Fields {
		fieldName := field.GoName
		kind := field.Desc.Kind()
		isMessage := kind == protoreflect.MessageKind || kind == protoreflect.GroupKind

		switch {
		case field.Desc.IsList():
			if isMessage {
				p.P(`for _, mm := range m.`, fieldName, `{`)
				if p.ShouldPool(field.Message) {
					p.P(`mm.ResetVT()`)
				} else {
					p.P(`mm.Reset()`)
				}
				p.P(`}`)
			}
			save(field)

		case field.Oneof != nil && !field.Oneof.Desc.IsSynthetic():
			// Members of a real oneof sit behind a wrapper value that
			// m.Reset() discards; only pooled message members must be
			// returned first. field.Message is nil for scalar members.
			if field.Message != nil && p.ShouldPool(field.Message) {
				p.P(`if oneof, ok := m.`, field.Oneof.GoName, `.(*`, field.GoIdent, `); ok {`)
				p.P(`oneof.`, fieldName, `.ReturnToVTPool()`)
				p.P(`}`)
			}

		case isMessage:
			// Map fields also report MessageKind; m.Reset() clears them.
			if !field.Desc.IsMap() && p.ShouldPool(field.Message) {
				p.P(`m.`, fieldName, `.ReturnToVTPool()`)
			}

		case kind == protoreflect.BytesKind:
			// A bytes field with explicit presence (proto2 optional, proto3
			// `optional`) counts as unset only when the slice is nil.
			// Restoring a non-nil empty slice would make the reset message
			// report the field as present, so such fields are left nil.
			if !field.Desc.HasPresence() {
				save(field)
			}
		}
	}

	p.P(`m.Reset()`)
	for i, field := range saved {
		p.P(`m.`, field.GoName, ` = `, fmt.Sprintf("f%d", i))
	}
	p.P(`}`)
	p.P(`}`)
}