Skip to content

Commit

Permalink
reflect/protoregistry: add Num methods for every Range method
Browse files Browse the repository at this point in the history
The Num methods provide an O(1) lookup for the number of entries that Range
would return. This is needed to implement efficient cache invalidation logic
for caches that wrap the global registry.

Change-Id: I7c4ff97f674c4e9e4caae291f017cfad7294856c
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/193599
Reviewed-by: Damien Neil <dneil@google.com>
  • Loading branch information
dsnet committed Sep 5, 2019
1 parent ea5ada1 commit 72980ee
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 27 deletions.
74 changes: 66 additions & 8 deletions reflect/protoregistry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,23 +255,39 @@ func (r *Files) FindFileByPath(path string) (protoreflect.FileDescriptor, error)
return nil, NotFound
}

// NumFiles reports the number of registered files.
func (r *Files) NumFiles() int {
if r == nil {
return 0
}
return len(r.filesByPath)
}

// RangeFiles iterates over all registered files.
// The iteration order is undefined.
func (r *Files) RangeFiles(f func(protoreflect.FileDescriptor) bool) {
if r == nil {
return
}
for _, d := range r.descsByName {
if p, ok := d.(*packageDescriptor); ok {
for _, file := range p.files {
if !f(file) {
return
}
}
for _, file := range r.filesByPath {
if !f(file) {
return
}
}
}

// NumFilesByPackage reports the number of registered files in a proto package.
func (r *Files) NumFilesByPackage(name protoreflect.FullName) int {
if r == nil {
return 0
}
p, ok := r.descsByName[name].(*packageDescriptor)
if !ok {
return 0
}
return len(p.files)
}

// RangeFilesByPackage iterates over all registered files in a give proto package.
// The iteration order is undefined.
func (r *Files) RangeFilesByPackage(name protoreflect.FullName, f func(protoreflect.FileDescriptor) bool) {
Expand Down Expand Up @@ -399,6 +415,10 @@ type Types struct {

typesByName typesByName
extensionsByMessage extensionsByMessage

numEnums int
numMessages int
numExtensions int
}

type (
Expand Down Expand Up @@ -428,13 +448,17 @@ typeLoop:
case protoreflect.EnumType, protoreflect.MessageType, protoreflect.ExtensionType:
// Check for conflicts in typesByName.
var desc protoreflect.Descriptor
var pcnt *int
switch t := typ.(type) {
case protoreflect.EnumType:
desc = t.Descriptor()
pcnt = &r.numEnums
case protoreflect.MessageType:
desc = t.Descriptor()
pcnt = &r.numMessages
case protoreflect.ExtensionType:
desc = t.TypeDescriptor()
pcnt = &r.numExtensions
default:
panic(fmt.Sprintf("invalid type: %T", t))
}
Expand Down Expand Up @@ -478,11 +502,12 @@ typeLoop:
r.extensionsByMessage[message][field] = xt
}

// Update typesByName.
// Update typesByName and the count.
if r.typesByName == nil {
r.typesByName = make(typesByName)
}
r.typesByName[name] = typ
(*pcnt)++
default:
if firstErr == nil {
firstErr = errors.New("invalid type: %v", typeName(typ))
Expand Down Expand Up @@ -573,6 +598,14 @@ func (r *Types) FindExtensionByNumber(message protoreflect.FullName, field proto
return nil, NotFound
}

// NumEnums reports the number of registered enums.
func (r *Types) NumEnums() int {
if r == nil {
return 0
}
return r.numEnums
}

// RangeEnums iterates over all registered enums.
// Iteration order is undefined.
func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
Expand All @@ -588,6 +621,14 @@ func (r *Types) RangeEnums(f func(protoreflect.EnumType) bool) {
}
}

// NumMessages reports the number of registered messages.
func (r *Types) NumMessages() int {
if r == nil {
return 0
}
return r.numMessages
}

// RangeMessages iterates over all registered messages.
// Iteration order is undefined.
func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
Expand All @@ -603,6 +644,14 @@ func (r *Types) RangeMessages(f func(protoreflect.MessageType) bool) {
}
}

// NumExtensions reports the number of registered extensions.
func (r *Types) NumExtensions() int {
if r == nil {
return 0
}
return r.numExtensions
}

// RangeExtensions iterates over all registered extensions.
// Iteration order is undefined.
func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
Expand All @@ -618,6 +667,15 @@ func (r *Types) RangeExtensions(f func(protoreflect.ExtensionType) bool) {
}
}

// NumExtensionsByMessage reports the number of registered extensions for
// a given message type.
func (r *Types) NumExtensionsByMessage(message protoreflect.FullName) int {
if r == nil {
return 0
}
return len(r.extensionsByMessage[message])
}

// RangeExtensionsByMessage iterates over all registered extensions filtered
// by a given message type. Iteration order is undefined.
func (r *Types) RangeExtensionsByMessage(message protoreflect.FullName, f func(protoreflect.ExtensionType) bool) {
Expand Down
64 changes: 45 additions & 19 deletions reflect/protoregistry/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,16 @@ func TestFiles(t *testing.T) {

for _, tc := range tt.rangePkgs {
var gotFiles []file
var gotCnt int
wantCnt := files.NumFilesByPackage(tc.inPkg)
files.RangeFilesByPackage(tc.inPkg, func(fd pref.FileDescriptor) bool {
gotFiles = append(gotFiles, file{fd.Path(), fd.Package()})
gotCnt++
return true
})
if gotCnt != wantCnt {
t.Errorf("NumFilesByPackage(%v) = %v, want %v", tc.inPkg, gotCnt, wantCnt)
}
if diff := cmp.Diff(tc.wantFiles, gotFiles, sortFiles); diff != "" {
t.Errorf("RangeFilesByPackage(%v) mismatch (-want +got):\n%v", tc.inPkg, diff)
}
Expand Down Expand Up @@ -552,58 +558,78 @@ func TestTypes(t *testing.T) {
return x == y
})

t.Run("RangeMessages", func(t *testing.T) {
want := []preg.Type{mt1}
t.Run("RangeEnums", func(t *testing.T) {
want := []preg.Type{et1}
var got []preg.Type
registry.RangeMessages(func(mt pref.MessageType) bool {
got = append(got, mt)
var gotCnt int
wantCnt := registry.NumEnums()
registry.RangeEnums(func(et pref.EnumType) bool {
got = append(got, et)
gotCnt++
return true
})

diff := cmp.Diff(want, got, sortTypes, compare)
if diff != "" {
t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff)
if gotCnt != wantCnt {
t.Errorf("NumEnums() = %v, want %v", gotCnt, wantCnt)
}
if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
t.Errorf("RangeEnums() mismatch (-want +got):\n%v", diff)
}
})

t.Run("RangeEnums", func(t *testing.T) {
want := []preg.Type{et1}
t.Run("RangeMessages", func(t *testing.T) {
want := []preg.Type{mt1}
var got []preg.Type
registry.RangeEnums(func(et pref.EnumType) bool {
got = append(got, et)
var gotCnt int
wantCnt := registry.NumMessages()
registry.RangeMessages(func(mt pref.MessageType) bool {
got = append(got, mt)
gotCnt++
return true
})

diff := cmp.Diff(want, got, sortTypes, compare)
if diff != "" {
t.Errorf("RangeEnums() mismatch (-want +got):\n%v", diff)
if gotCnt != wantCnt {
t.Errorf("NumMessages() = %v, want %v", gotCnt, wantCnt)
}
if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
t.Errorf("RangeMessages() mismatch (-want +got):\n%v", diff)
}
})

t.Run("RangeExtensions", func(t *testing.T) {
want := []preg.Type{xt1, xt2}
var got []preg.Type
var gotCnt int
wantCnt := registry.NumExtensions()
registry.RangeExtensions(func(xt pref.ExtensionType) bool {
got = append(got, xt)
gotCnt++
return true
})

diff := cmp.Diff(want, got, sortTypes, compare)
if diff != "" {
if gotCnt != wantCnt {
t.Errorf("NumExtensions() = %v, want %v", gotCnt, wantCnt)
}
if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
t.Errorf("RangeExtensions() mismatch (-want +got):\n%v", diff)
}
})

t.Run("RangeExtensionsByMessage", func(t *testing.T) {
want := []preg.Type{xt1, xt2}
var got []preg.Type
registry.RangeExtensionsByMessage(pref.FullName("testprotos.Message1"), func(xt pref.ExtensionType) bool {
var gotCnt int
wantCnt := registry.NumExtensionsByMessage("testprotos.Message1")
registry.RangeExtensionsByMessage("testprotos.Message1", func(xt pref.ExtensionType) bool {
got = append(got, xt)
gotCnt++
return true
})

diff := cmp.Diff(want, got, sortTypes, compare)
if diff != "" {
if gotCnt != wantCnt {
t.Errorf("NumExtensionsByMessage() = %v, want %v", gotCnt, wantCnt)
}
if diff := cmp.Diff(want, got, sortTypes, compare); diff != "" {
t.Errorf("RangeExtensionsByMessage() mismatch (-want +got):\n%v", diff)
}
})
Expand Down

0 comments on commit 72980ee

Please sign in to comment.