From 96a931a30b076ea6c07f1938e9cc7f5a58f3aa41 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 31 May 2024 17:53:53 -0600 Subject: [PATCH 1/5] GODRIVER-3023 Remove unused error from Clone function signature (#1645) --- internal/integration/mtest/mongotest.go | 4 +--- mongo/collection.go | 6 +++--- mongo/gridfs_bucket.go | 7 ++----- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/internal/integration/mtest/mongotest.go b/internal/integration/mtest/mongotest.go index d58a1dd469..785044a42f 100644 --- a/internal/integration/mtest/mongotest.go +++ b/internal/integration/mtest/mongotest.go @@ -602,9 +602,7 @@ func (t *T) CloneDatabase(opts *options.DatabaseOptions) { // CloneCollection modifies the default collection for this test to match the given options. func (t *T) CloneCollection(opts *options.CollectionOptions) { - var err error - t.Coll, err = t.Coll.Clone(opts) - assert.Nil(t, err, "error cloning collection: %v", err) + t.Coll = t.Coll.Clone(opts) } func sanitizeCollectionName(db string, coll string) string { diff --git a/mongo/collection.go b/mongo/collection.go index f5661c3be4..1c9754f3c6 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -162,7 +162,7 @@ func (coll *Collection) copy() *Collection { // Clone creates a copy of the Collection configured with the given CollectionOptions. // The specified options are merged with the existing options on the collection, with the specified options taking // precedence. -func (coll *Collection) Clone(opts ...*options.CollectionOptions) (*Collection, error) { +func (coll *Collection) Clone(opts ...*options.CollectionOptions) *Collection { copyColl := coll.copy() optsColl := mergeCollectionOptions(opts...) @@ -187,7 +187,7 @@ func (coll *Collection) Clone(opts ...*options.CollectionOptions) (*Collection, description.LatencySelector(copyColl.client.localThreshold), }) - return copyColl, nil + return copyColl } // Name returns the name of the collection. @@ -2169,7 +2169,7 @@ func (coll *Collection) Indexes() IndexView { // SearchIndexes returns a SearchIndexView instance that can be used to perform operations on the search indexes for the collection. func (coll *Collection) SearchIndexes() SearchIndexView { - c, _ := coll.Clone() // Clone() always return a nil error. + c := coll.Clone() // Clone() always return a nil error. c.readConcern = nil c.writeConcern = nil return SearchIndexView{ diff --git a/mongo/gridfs_bucket.go b/mongo/gridfs_bucket.go index 7c2bbac64e..e5016a5179 100644 --- a/mongo/gridfs_bucket.go +++ b/mongo/gridfs_bucket.go @@ -538,14 +538,11 @@ func createNumericalIndexIfNotExists(ctx context.Context, iv IndexView, model In // create indexes on the files and chunks collection if needed func (b *GridFSBucket) createIndexes(ctx context.Context) error { // must use primary read pref mode to check if files coll empty - cloned, err := b.filesColl.Clone(options.Collection().SetReadPreference(readpref.Primary())) - if err != nil { - return err - } + cloned := b.filesColl.Clone(options.Collection().SetReadPreference(readpref.Primary())) docRes := cloned.FindOne(ctx, bson.D{}, options.FindOne().SetProjection(bson.D{{"_id", 1}})) - _, err = docRes.Raw() + _, err := docRes.Raw() if !errors.Is(err, ErrNoDocuments) { // nil, or error that occurred during the FindOne operation return err From d263eca5efeb7dba99606a34bd559b89fa5565e3 Mon Sep 17 00:00:00 2001 From: Qingyang Hu <103950869+qingyang-hu@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:36:26 -0400 Subject: [PATCH 2/5] GODRIVER-1808 Fix BSON unmarshaling into an interface containing a concrete value. (#1584) --- bson/bsoncodec.go | 13 +- bson/decoder_test.go | 156 ++++++++++++ bson/default_value_decoders.go | 355 ++++------------------------ bson/default_value_decoders_test.go | 8 - bson/empty_interface_codec.go | 2 +- bson/map_codec.go | 3 +- bson/struct_codec.go | 13 + bson/unmarshal_test.go | 151 ++++++++++++ 8 files changed, 368 insertions(+), 333 deletions(-) diff --git a/bson/bsoncodec.go b/bson/bsoncodec.go index 860a6b82af..ad1d4a8ded 100644 --- a/bson/bsoncodec.go +++ b/bson/bsoncodec.go @@ -309,17 +309,10 @@ type decodeAdapter struct { var _ ValueDecoder = decodeAdapter{} var _ typeDecoder = decodeAdapter{} -// decodeTypeOrValue calls decoder.decodeType is decoder is a typeDecoder. Otherwise, it allocates a new element of type -// t and calls decoder.DecodeValue on it. -func decodeTypeOrValue(decoder ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { - td, _ := decoder.(typeDecoder) - return decodeTypeOrValueWithInfo(decoder, td, dc, vr, t, true) -} - -func decodeTypeOrValueWithInfo(vd ValueDecoder, td typeDecoder, dc DecodeContext, vr ValueReader, t reflect.Type, convert bool) (reflect.Value, error) { - if td != nil { +func decodeTypeOrValueWithInfo(vd ValueDecoder, dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { + if td, _ := vd.(typeDecoder); td != nil { val, err := td.decodeType(dc, vr, t) - if err == nil && convert && val.Type() != t { + if err == nil && val.Type() != t { // This conversion step is necessary for slices and maps. If a user declares variables like: // // type myBool bool diff --git a/bson/decoder_test.go b/bson/decoder_test.go index 8fe8d07480..dbef3e7fb0 100644 --- a/bson/decoder_test.go +++ b/bson/decoder_test.go @@ -39,6 +39,162 @@ func TestBasicDecode(t *testing.T) { } } +func TestDecodingInterfaces(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + stub func() ([]byte, interface{}, func(*testing.T)) + } + testCases := []testCase{ + { + name: "struct with interface containing a concrete value", + stub: func() ([]byte, interface{}, func(*testing.T)) { + type testStruct struct { + Value interface{} + } + var value string + + data := docToBytes(struct { + Value string + }{ + Value: "foo", + }) + + receiver := testStruct{&value} + + check := func(t *testing.T) { + t.Helper() + assert.Equal(t, "foo", value) + } + + return data, &receiver, check + }, + }, + { + name: "struct with interface containing a struct", + stub: func() ([]byte, interface{}, func(*testing.T)) { + type demo struct { + Data string + } + + type testStruct struct { + Value interface{} + } + var value demo + + data := docToBytes(struct { + Value demo + }{ + Value: demo{"foo"}, + }) + + receiver := testStruct{&value} + + check := func(t *testing.T) { + t.Helper() + assert.Equal(t, "foo", value.Data) + } + + return data, &receiver, check + }, + }, + { + name: "struct with interface containing a slice", + stub: func() ([]byte, interface{}, func(*testing.T)) { + type testStruct struct { + Values interface{} + } + var values []string + + data := docToBytes(struct { + Values []string + }{ + Values: []string{"foo", "bar"}, + }) + + receiver := testStruct{&values} + + check := func(t *testing.T) { + t.Helper() + assert.Equal(t, []string{"foo", "bar"}, values) + } + + return data, &receiver, check + }, + }, + { + name: "struct with interface containing an array", + stub: func() ([]byte, interface{}, func(*testing.T)) { + type testStruct struct { + Values interface{} + } + var values [2]string + + data := docToBytes(struct { + Values []string + }{ + Values: []string{"foo", "bar"}, + }) + + receiver := testStruct{&values} + + check := func(t *testing.T) { + t.Helper() + assert.Equal(t, [2]string{"foo", "bar"}, values) + } + + return data, &receiver, check + }, + }, + { + name: "struct with interface array containing concrete values", + stub: func() ([]byte, interface{}, func(*testing.T)) { + type testStruct struct { + Values [3]interface{} + } + var str string + var i, j int + + data := docToBytes(struct { + Values []interface{} + }{ + Values: []interface{}{"foo", 42, nil}, + }) + + receiver := testStruct{[3]interface{}{&str, &i, &j}} + + check := func(t *testing.T) { + t.Helper() + assert.Equal(t, "foo", str) + assert.Equal(t, 42, i) + assert.Equal(t, 0, j) + assert.Equal(t, testStruct{[3]interface{}{&str, &i, nil}}, receiver) + } + + return data, &receiver, check + }, + }, + } + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data, receiver, check := tc.stub() + got := reflect.ValueOf(receiver).Elem() + vr := NewValueReader(data) + reg := DefaultRegistry + decoder, err := reg.LookupDecoder(got.Type()) + noerr(t, err) + err = decoder.DecodeValue(DecodeContext{Registry: reg}, vr, got) + noerr(t, err) + check(t) + }) + } +} + func TestDecoderv2(t *testing.T) { t.Parallel() diff --git a/bson/default_value_decoders.go b/bson/default_value_decoders.go index bc8c7b9344..3256f92089 100644 --- a/bson/default_value_decoders.go +++ b/bson/default_value_decoders.go @@ -14,7 +14,6 @@ import ( "net/url" "reflect" "strconv" - "time" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" ) @@ -162,7 +161,6 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr ValueReader, v if err != nil { return err } - tEmptyTypeDecoder, _ := decoder.(typeDecoder) // Use the elements in the provided value if it's non nil. Otherwise, allocate a new D instance. var elems D @@ -181,13 +179,13 @@ func (dvd DefaultValueDecoders) DDecodeValue(dc DecodeContext, vr ValueReader, v return err } - // Pass false for convert because we don't need to call reflect.Value.Convert for tEmpty. - elem, err := decodeTypeOrValueWithInfo(decoder, tEmptyTypeDecoder, dc, elemVr, tEmpty, false) + var v interface{} + err = decoder.DecodeValue(dc, elemVr, reflect.ValueOf(&v).Elem()) if err != nil { return err } - elems = append(elems, E{Key: key, Value: elem.Interface()}) + elems = append(elems, E{Key: key, Value: v}) } val.Set(reflect.ValueOf(elems)) @@ -363,89 +361,6 @@ func (dvd DefaultValueDecoders) IntDecodeValue(dc DecodeContext, vr ValueReader, return nil } -// UintDecodeValue is the ValueDecoderFunc for uint types. -// -// Deprecated: UintDecodeValue is not registered by default. Use UintCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) UintDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - var i64 int64 - var err error - switch vr.Type() { - case TypeInt32: - i32, err := vr.ReadInt32() - if err != nil { - return err - } - i64 = int64(i32) - case TypeInt64: - i64, err = vr.ReadInt64() - if err != nil { - return err - } - case TypeDouble: - f64, err := vr.ReadDouble() - if err != nil { - return err - } - if !dc.Truncate && math.Floor(f64) != f64 { - return errors.New("UintDecodeValue can only truncate float64 to an integer type when truncation is enabled") - } - if f64 > float64(math.MaxInt64) { - return fmt.Errorf("%g overflows int64", f64) - } - i64 = int64(f64) - case TypeBoolean: - b, err := vr.ReadBoolean() - if err != nil { - return err - } - if b { - i64 = 1 - } - default: - return fmt.Errorf("cannot decode %v into an integer type", vr.Type()) - } - - if !val.CanSet() { - return ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } - } - - switch val.Kind() { - case reflect.Uint8: - if i64 < 0 || i64 > math.MaxUint8 { - return fmt.Errorf("%d overflows uint8", i64) - } - case reflect.Uint16: - if i64 < 0 || i64 > math.MaxUint16 { - return fmt.Errorf("%d overflows uint16", i64) - } - case reflect.Uint32: - if i64 < 0 || i64 > math.MaxUint32 { - return fmt.Errorf("%d overflows uint32", i64) - } - case reflect.Uint64: - if i64 < 0 { - return fmt.Errorf("%d overflows uint64", i64) - } - case reflect.Uint: - if i64 < 0 || int64(uint(i64)) != i64 { // Can we fit this inside of an uint - return fmt.Errorf("%d overflows uint", i64) - } - default: - return ValueDecoderError{ - Name: "UintDecodeValue", - Kinds: []reflect.Kind{reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint}, - Received: val, - } - } - - val.SetUint(uint64(i64)) - return nil -} - func (dvd DefaultValueDecoders) floatDecodeType(dc DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { var f float64 var err error @@ -527,30 +442,6 @@ func (dvd DefaultValueDecoders) FloatDecodeValue(ec DecodeContext, vr ValueReade return nil } -// StringDecodeValue is the ValueDecoderFunc for string types. -// -// Deprecated: StringDecodeValue is not registered by default. Use StringCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) StringDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { - var str string - var err error - switch vr.Type() { - // TODO(GODRIVER-577): Handle JavaScript and Symbol BSON types when allowed. - case TypeString: - str, err = vr.ReadString() - if err != nil { - return err - } - default: - return fmt.Errorf("cannot decode %v into a string type", vr.Type()) - } - if !val.CanSet() || val.Kind() != reflect.String { - return ValueDecoderError{Name: "StringDecodeValue", Kinds: []reflect.Kind{reflect.String}, Received: val} - } - - val.SetString(str) - return nil -} - func (DefaultValueDecoders) javaScriptDecodeType(_ DecodeContext, vr ValueReader, t reflect.Type) (reflect.Value, error) { if t != tJavaScript { return emptyValue, ValueDecoderError{ @@ -1287,114 +1178,6 @@ func (dvd DefaultValueDecoders) URLDecodeValue(dc DecodeContext, vr ValueReader, return nil } -// TimeDecodeValue is the ValueDecoderFunc for time.Time. -// -// Deprecated: TimeDecodeValue is not registered by default. Use TimeCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) TimeDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { - if vr.Type() != TypeDateTime { - return fmt.Errorf("cannot decode %v into a time.Time", vr.Type()) - } - - dt, err := vr.ReadDateTime() - if err != nil { - return err - } - - if !val.CanSet() || val.Type() != tTime { - return ValueDecoderError{Name: "TimeDecodeValue", Types: []reflect.Type{tTime}, Received: val} - } - - val.Set(reflect.ValueOf(time.Unix(dt/1000, dt%1000*1000000).UTC())) - return nil -} - -// ByteSliceDecodeValue is the ValueDecoderFunc for []byte. -// -// Deprecated: ByteSliceDecodeValue is not registered by default. Use ByteSliceCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) ByteSliceDecodeValue(_ DecodeContext, vr ValueReader, val reflect.Value) error { - if vr.Type() != TypeBinary && vr.Type() != TypeNull { - return fmt.Errorf("cannot decode %v into a []byte", vr.Type()) - } - - if !val.CanSet() || val.Type() != tByteSlice { - return ValueDecoderError{Name: "ByteSliceDecodeValue", Types: []reflect.Type{tByteSlice}, Received: val} - } - - if vr.Type() == TypeNull { - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() - } - - data, subtype, err := vr.ReadBinary() - if err != nil { - return err - } - if subtype != 0x00 { - return fmt.Errorf("ByteSliceDecodeValue can only be used to decode subtype 0x00 for %s, got %v", TypeBinary, subtype) - } - - val.Set(reflect.ValueOf(data)) - return nil -} - -// MapDecodeValue is the ValueDecoderFunc for map[string]* types. -// -// Deprecated: MapDecodeValue is not registered by default. Use MapCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) MapDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Kind() != reflect.Map || val.Type().Key().Kind() != reflect.String { - return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val} - } - - switch vr.Type() { - case Type(0), TypeEmbeddedDocument: - case TypeNull: - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() - default: - return fmt.Errorf("cannot decode %v into a %s", vr.Type(), val.Type()) - } - - dr, err := vr.ReadDocument() - if err != nil { - return err - } - - if val.IsNil() { - val.Set(reflect.MakeMap(val.Type())) - } - - eType := val.Type().Elem() - decoder, err := dc.LookupDecoder(eType) - if err != nil { - return err - } - - if eType == tEmpty { - dc.Ancestor = val.Type() - } - - keyType := val.Type().Key() - for { - key, vr, err := dr.ReadElement() - if errors.Is(err, ErrEOD) { - break - } - if err != nil { - return err - } - - elem := reflect.New(eType).Elem() - - err = decoder.DecodeValue(dc, vr, elem) - if err != nil { - return err - } - - val.SetMapIndex(reflect.ValueOf(key).Convert(keyType), elem) - } - return nil -} - // ArrayDecodeValue is the ValueDecoderFunc for array types. // // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default @@ -1464,51 +1247,6 @@ func (dvd DefaultValueDecoders) ArrayDecodeValue(dc DecodeContext, vr ValueReade return nil } -// SliceDecodeValue is the ValueDecoderFunc for slice types. -// -// Deprecated: SliceDecodeValue is not registered by default. Use SliceCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) SliceDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Kind() != reflect.Slice { - return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} - } - - switch vr.Type() { - case TypeArray: - case TypeNull: - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() - case Type(0), TypeEmbeddedDocument: - if val.Type().Elem() != tE { - return fmt.Errorf("cannot decode document into %s", val.Type()) - } - default: - return fmt.Errorf("cannot decode %v into a slice", vr.Type()) - } - - var elemsFunc func(DecodeContext, ValueReader, reflect.Value) ([]reflect.Value, error) - switch val.Type().Elem() { - case tE: - dc.Ancestor = val.Type() - elemsFunc = dvd.decodeD - default: - elemsFunc = dvd.decodeDefault - } - - elems, err := elemsFunc(dc, vr, val) - if err != nil { - return err - } - - if val.IsNil() { - val.Set(reflect.MakeSlice(val.Type(), 0, len(elems))) - } - - val.SetLen(0) - val.Set(reflect.Append(val, elems...)) - - return nil -} - // ValueUnmarshalerDecodeValue is the ValueDecoderFunc for ValueUnmarshaler implementations. // // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default @@ -1593,46 +1331,6 @@ func (dvd DefaultValueDecoders) UnmarshalerDecodeValue(_ DecodeContext, vr Value return m.UnmarshalBSON(src) } -// EmptyInterfaceDecodeValue is the ValueDecoderFunc for interface{}. -// -// Deprecated: EmptyInterfaceDecodeValue is not registered by default. Use EmptyInterfaceCodec.DecodeValue instead. -func (dvd DefaultValueDecoders) EmptyInterfaceDecodeValue(dc DecodeContext, vr ValueReader, val reflect.Value) error { - if !val.CanSet() || val.Type() != tEmpty { - return ValueDecoderError{Name: "EmptyInterfaceDecodeValue", Types: []reflect.Type{tEmpty}, Received: val} - } - - rtype, err := dc.LookupTypeMapEntry(vr.Type()) - if err != nil { - switch vr.Type() { - case TypeEmbeddedDocument: - if dc.Ancestor != nil { - rtype = dc.Ancestor - break - } - rtype = tD - case TypeNull: - val.Set(reflect.Zero(val.Type())) - return vr.ReadNull() - default: - return err - } - } - - decoder, err := dc.LookupDecoder(rtype) - if err != nil { - return err - } - - elem := reflect.New(rtype).Elem() - err = decoder.DecodeValue(dc, vr, elem) - if err != nil { - return err - } - - val.Set(elem) - return nil -} - // CoreDocumentDecodeValue is the ValueDecoderFunc for bsoncore.Document. // // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with all default @@ -1663,11 +1361,13 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr ValueReader, eType := val.Type().Elem() - decoder, err := dc.LookupDecoder(eType) - if err != nil { - return nil, err + var vDecoder ValueDecoder + if !(eType.Kind() == reflect.Interface && val.Len() > 0) { + vDecoder, err = dc.LookupDecoder(eType) + if err != nil { + return nil, err + } } - eTypeDecoder, _ := decoder.(typeDecoder) idx := 0 for { @@ -1679,10 +1379,41 @@ func (dvd DefaultValueDecoders) decodeDefault(dc DecodeContext, vr ValueReader, return nil, err } - elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true) - if err != nil { - return nil, newDecodeError(strconv.Itoa(idx), err) + var elem reflect.Value + if vDecoder == nil { + elem = val.Index(idx).Elem() + if elem.Kind() != reflect.Ptr || elem.IsNil() { + valueDecoder, err := dc.LookupDecoder(elem.Type()) + if err != nil { + return nil, err + } + err = valueDecoder.DecodeValue(dc, vr, elem) + if err != nil { + return nil, newDecodeError(strconv.Itoa(idx), err) + } + } else if vr.Type() == TypeNull { + if err = vr.ReadNull(); err != nil { + return nil, err + } + elem = reflect.Zero(val.Index(idx).Type()) + } else { + e := elem.Elem() + valueDecoder, err := dc.LookupDecoder(e.Type()) + if err != nil { + return nil, err + } + err = valueDecoder.DecodeValue(dc, vr, e) + if err != nil { + return nil, newDecodeError(strconv.Itoa(idx), err) + } + } + } else { + elem, err = decodeTypeOrValueWithInfo(vDecoder, dc, vr, eType) + if err != nil { + return nil, newDecodeError(strconv.Itoa(idx), err) + } } + elems = append(elems, elem) idx++ } diff --git a/bson/default_value_decoders_test.go b/bson/default_value_decoders_test.go index 699a958605..31148ab644 100644 --- a/bson/default_value_decoders_test.go +++ b/bson/default_value_decoders_test.go @@ -2472,14 +2472,6 @@ func TestDefaultValueDecoders(t *testing.T) { } }) - t.Run("SliceCodec/DecodeValue/can't set slice", func(t *testing.T) { - var val []string - want := ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: reflect.ValueOf(val)} - got := dvd.SliceDecodeValue(DecodeContext{}, nil, reflect.ValueOf(val)) - if !assert.CompareErrors(got, want) { - t.Errorf("Errors do not match. got %v; want %v", got, want) - } - }) t.Run("SliceCodec/DecodeValue/too many elements", func(t *testing.T) { idx, doc := bsoncore.AppendDocumentStart(nil) aidx, doc := bsoncore.AppendArrayElementStart(doc, "foo") diff --git a/bson/empty_interface_codec.go b/bson/empty_interface_codec.go index 56468e3068..e0af34c942 100644 --- a/bson/empty_interface_codec.go +++ b/bson/empty_interface_codec.go @@ -125,7 +125,7 @@ func (eic EmptyInterfaceCodec) decodeType(dc DecodeContext, vr ValueReader, t re return emptyValue, err } - elem, err := decodeTypeOrValue(decoder, dc, vr, rtype) + elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, rtype) if err != nil { return emptyValue, err } diff --git a/bson/map_codec.go b/bson/map_codec.go index 9592957db4..fddcc5c8b7 100644 --- a/bson/map_codec.go +++ b/bson/map_codec.go @@ -189,7 +189,6 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Va if err != nil { return err } - eTypeDecoder, _ := decoder.(typeDecoder) if eType == tEmpty { dc.Ancestor = val.Type() @@ -211,7 +210,7 @@ func (mc *MapCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect.Va return err } - elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true) + elem, err := decodeTypeOrValueWithInfo(decoder, dc, vr, eType) if err != nil { return newDecodeError(key, err) } diff --git a/bson/struct_codec.go b/bson/struct_codec.go index 917ac17bfd..14337c7a2e 100644 --- a/bson/struct_codec.go +++ b/bson/struct_codec.go @@ -349,6 +349,19 @@ func (sc *StructCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect } } + if field.Kind() == reflect.Interface && !field.IsNil() && field.Elem().Kind() == reflect.Ptr { + v := field.Elem().Elem() + decoder, err = dc.LookupDecoder(v.Type()) + if err != nil { + return err + } + err = decoder.DecodeValue(dc, vr, v) + if err != nil { + return newDecodeError(fd.name, err) + } + continue + } + if !field.CanSet() { // Being settable is a super set of being addressable. innerErr := fmt.Errorf("field %v is not settable", field) return newDecodeError(fd.name, innerErr) diff --git a/bson/unmarshal_test.go b/bson/unmarshal_test.go index 0871237386..d83edd3940 100644 --- a/bson/unmarshal_test.go +++ b/bson/unmarshal_test.go @@ -412,6 +412,157 @@ func TestUnmarshalExtJSONWithUndefinedField(t *testing.T) { } } +func TestUnmarshalInterface(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + stub func() ([]byte, interface{}, func(*testing.T)) + } + testCases := []testCase{ + { + name: "struct with interface containing a concrete value", + stub: func() ([]byte, interface{}, func(*testing.T)) { + type testStruct struct { + Value interface{} + } + var value string + + data := docToBytes(struct { + Value string + }{ + Value: "foo", + }) + + receiver := testStruct{&value} + + check := func(t *testing.T) { + t.Helper() + assert.Equal(t, "foo", value) + } + + return data, &receiver, check + }, + }, + { + name: "struct with interface containing a struct", + stub: func() ([]byte, interface{}, func(*testing.T)) { + type demo struct { + Data string + } + + type testStruct struct { + Value interface{} + } + var value demo + + data := docToBytes(struct { + Value demo + }{ + Value: demo{"foo"}, + }) + + receiver := testStruct{&value} + + check := func(t *testing.T) { + t.Helper() + assert.Equal(t, "foo", value.Data) + } + + return data, &receiver, check + }, + }, + { + name: "struct with interface containing a slice", + stub: func() ([]byte, interface{}, func(*testing.T)) { + type testStruct struct { + Values interface{} + } + var values []string + + data := docToBytes(struct { + Values []string + }{ + Values: []string{"foo", "bar"}, + }) + + receiver := testStruct{&values} + + check := func(t *testing.T) { + t.Helper() + assert.Equal(t, []string{"foo", "bar"}, values) + } + + return data, &receiver, check + }, + }, + { + name: "struct with interface containing an array", + stub: func() ([]byte, interface{}, func(*testing.T)) { + type testStruct struct { + Values interface{} + } + var values [2]string + + data := docToBytes(struct { + Values []string + }{ + Values: []string{"foo", "bar"}, + }) + + receiver := testStruct{&values} + + check := func(t *testing.T) { + t.Helper() + assert.Equal(t, [2]string{"foo", "bar"}, values) + } + + return data, &receiver, check + }, + }, + { + name: "struct with interface array containing concrete values", + stub: func() ([]byte, interface{}, func(*testing.T)) { + type testStruct struct { + Values [3]interface{} + } + var str string + var i, j int + + data := docToBytes(struct { + Values []interface{} + }{ + Values: []interface{}{"foo", 42, nil}, + }) + + receiver := testStruct{[3]interface{}{&str, &i, &j}} + + check := func(t *testing.T) { + t.Helper() + assert.Equal(t, "foo", str) + assert.Equal(t, 42, i) + assert.Equal(t, 0, j) + assert.Equal(t, testStruct{[3]interface{}{&str, &i, nil}}, receiver) + } + + return data, &receiver, check + }, + }, + } + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data, receiver, check := tc.stub() + err := Unmarshal(data, receiver) + noerr(t, err) + check(t) + }) + } +} + func TestUnmarshalBSONWithUndefinedField(t *testing.T) { // When unmarshalling BSON, fields that are undefined in the destination struct are skipped. // This process must not skip other, defined fields and must not raise errors. From 6f19aef58bc7911d6fb7ee7c65112c37d1e3e29d Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 4 Jun 2024 12:59:24 -0500 Subject: [PATCH 3/5] GODRIVER-3125 [master] Allow to set search index type (#1652) --- .../integration/search_index_prose_test.go | 147 ++++++++++++++++++ .../unified/collection_operation_execution.go | 4 + mongo/options/searchindexoptions.go | 7 + mongo/search_index_view.go | 3 + .../index-management/createSearchIndex.json | 72 ++++++++- .../index-management/createSearchIndex.yml | 30 +++- .../index-management/createSearchIndexes.json | 74 ++++++++- .../index-management/createSearchIndexes.yml | 30 +++- 8 files changed, 351 insertions(+), 16 deletions(-) diff --git a/internal/integration/search_index_prose_test.go b/internal/integration/search_index_prose_test.go index d3f23e62d6..19b0f95997 100644 --- a/internal/integration/search_index_prose_test.go +++ b/internal/integration/search_index_prose_test.go @@ -311,4 +311,151 @@ func TestSearchIndexProse(t *testing.T) { actual := doc.Lookup("latestDefinition").Value assert.Equal(mt, expected, actual, "unmatched definition") }) + + case7CollName, err := uuid.New() + assert.NoError(mt, err, "failed to create random collection name for case #7") + + mt.RunOpts("case 7: Driver can successfully handle search index types when creating indexes", + mtest.NewOptions().CollectionName(case7CollName.String()), + func(mt *mtest.T) { + ctx := context.Background() + + _, err := mt.Coll.InsertOne(ctx, bson.D{}) + require.NoError(mt, err, "failed to insert") + + view := mt.Coll.SearchIndexes() + + definition := bson.D{{"mappings", bson.D{{"dynamic", false}}}} + indexName := "test-search-index-case7-implicit" + opts := options.SearchIndexes().SetName(indexName) + index, err := view.CreateOne(ctx, mongo.SearchIndexModel{ + Definition: definition, + Options: opts, + }) + require.NoError(mt, err, "failed to create index") + require.Equal(mt, indexName, index, "unmatched name") + var doc bson.Raw + for doc == nil { + cursor, err := view.List(ctx, opts) + require.NoError(mt, err, "failed to list") + + if !cursor.Next(ctx) { + break + } + name := cursor.Current.Lookup("name").StringValue() + queryable := cursor.Current.Lookup("queryable").Boolean() + indexType := cursor.Current.Lookup("type").StringValue() + if name == indexName && queryable { + doc = cursor.Current + assert.Equal(mt, indexType, "search") + } else { + t.Logf("cursor: %s, sleep 5 seconds...", cursor.Current.String()) + time.Sleep(5 * time.Second) + } + } + + indexName = "test-search-index-case7-explicit" + opts = options.SearchIndexes().SetName(indexName).SetType("search") + index, err = view.CreateOne(ctx, mongo.SearchIndexModel{ + Definition: definition, + Options: opts, + }) + require.NoError(mt, err, "failed to create index") + require.Equal(mt, indexName, index, "unmatched name") + doc = nil + for doc == nil { + cursor, err := view.List(ctx, opts) + require.NoError(mt, err, "failed to list") + + if !cursor.Next(ctx) { + break + } + name := cursor.Current.Lookup("name").StringValue() + queryable := cursor.Current.Lookup("queryable").Boolean() + indexType := cursor.Current.Lookup("type").StringValue() + if name == indexName && queryable { + doc = cursor.Current + assert.Equal(mt, indexType, "search") + } else { + t.Logf("cursor: %s, sleep 5 seconds...", cursor.Current.String()) + time.Sleep(5 * time.Second) + } + } + + indexName = "test-search-index-case7-vector" + type vectorDefinitionField struct { + Type string `bson:"type"` + Path string `bson:"path"` + NumDimensions int `bson:"numDimensions"` + Similarity string `bson:"similarity"` + } + + type vectorDefinition struct { + Fields []vectorDefinitionField `bson:"fields"` + } + + opts = options.SearchIndexes().SetName(indexName).SetType("vectorSearch") + index, err = view.CreateOne(ctx, mongo.SearchIndexModel{ + Definition: vectorDefinition{ + Fields: []vectorDefinitionField{{"vector", "path", 1536, "euclidean"}}, + }, + Options: opts, + }) + require.NoError(mt, err, "failed to create index") + require.Equal(mt, indexName, index, "unmatched name") + doc = nil + for doc == nil { + cursor, err := view.List(ctx, opts) + require.NoError(mt, err, "failed to list") + + if !cursor.Next(ctx) { + break + } + name := cursor.Current.Lookup("name").StringValue() + queryable := cursor.Current.Lookup("queryable").Boolean() + indexType := cursor.Current.Lookup("type").StringValue() + if name == indexName && queryable { + doc = cursor.Current + assert.Equal(mt, indexType, "vectorSearch") + } else { + t.Logf("cursor: %s, sleep 5 seconds...", cursor.Current.String()) + time.Sleep(5 * time.Second) + } + } + }) + + case8CollName, err := uuid.New() + assert.NoError(mt, err, "failed to create random collection name for case #8") + + mt.RunOpts("case 8: Driver requires explicit type to create a vector search index", + mtest.NewOptions().CollectionName(case8CollName.String()), + func(mt *mtest.T) { + ctx := context.Background() + + _, err := mt.Coll.InsertOne(ctx, bson.D{}) + require.NoError(mt, err, "failed to insert") + + view := mt.Coll.SearchIndexes() + + type vectorDefinitionField struct { + Type string `bson:"type"` + Path string `bson:"path"` + NumDimensions int `bson:"numDimensions"` + Similarity string `bson:"similarity"` + } + + type vectorDefinition struct { + Fields []vectorDefinitionField `bson:"fields"` + } + + const indexName = "test-search-index-case7-vector" + opts := options.SearchIndexes().SetName(indexName) + _, err = view.CreateOne(ctx, mongo.SearchIndexModel{ + Definition: vectorDefinition{ + Fields: []vectorDefinitionField{{"vector", "plot_embedding", 1536, "euclidean"}}, + }, + Options: opts, + }) + assert.ErrorContains(mt, err, "Attribute mappings missing") + }) } diff --git a/internal/integration/unified/collection_operation_execution.go b/internal/integration/unified/collection_operation_execution.go index 796ab01344..d211013920 100644 --- a/internal/integration/unified/collection_operation_execution.go +++ b/internal/integration/unified/collection_operation_execution.go @@ -311,6 +311,7 @@ func executeCreateSearchIndex(ctx context.Context, operation *operation) (*opera var m struct { Definition interface{} Name *string + Type *string } err = bson.Unmarshal(val.Document(), &m) if err != nil { @@ -319,6 +320,7 @@ func executeCreateSearchIndex(ctx context.Context, operation *operation) (*opera model.Definition = m.Definition model.Options = options.SearchIndexes() model.Options.Name = m.Name + model.Options.Type = m.Type default: return nil, fmt.Errorf("unrecognized createSearchIndex option %q", key) } @@ -354,6 +356,7 @@ func executeCreateSearchIndexes(ctx context.Context, operation *operation) (*ope var m struct { Definition interface{} Name *string + Type *string } err = bson.Unmarshal(val.Value, &m) if err != nil { @@ -364,6 +367,7 @@ func executeCreateSearchIndexes(ctx context.Context, operation *operation) (*ope Options: options.SearchIndexes(), } model.Options.Name = m.Name + model.Options.Type = m.Type models = append(models, model) } default: diff --git a/mongo/options/searchindexoptions.go b/mongo/options/searchindexoptions.go index 9774d615ba..8cb8a08b78 100644 --- a/mongo/options/searchindexoptions.go +++ b/mongo/options/searchindexoptions.go @@ -9,6 +9,7 @@ package options // SearchIndexesOptions represents options that can be used to configure a SearchIndexView. type SearchIndexesOptions struct { Name *string + Type *string } // SearchIndexes creates a new SearchIndexesOptions instance. @@ -22,6 +23,12 @@ func (sio *SearchIndexesOptions) SetName(name string) *SearchIndexesOptions { return sio } +// SetType sets the value for the Type field. +func (sio *SearchIndexesOptions) SetType(typ string) *SearchIndexesOptions { + sio.Type = &typ + return sio +} + // CreateSearchIndexesOptions represents options that can be used to configure a SearchIndexView.CreateOne or // SearchIndexView.CreateMany operation. type CreateSearchIndexesOptions struct { diff --git a/mongo/search_index_view.go b/mongo/search_index_view.go index 695a396425..73fe8534ed 100644 --- a/mongo/search_index_view.go +++ b/mongo/search_index_view.go @@ -108,6 +108,9 @@ func (siv SearchIndexView) CreateMany( if model.Options != nil && model.Options.Name != nil { indexes = bsoncore.AppendStringElement(indexes, "name", *model.Options.Name) } + if model.Options != nil && model.Options.Type != nil { + indexes = bsoncore.AppendStringElement(indexes, "type", *model.Options.Type) + } indexes = bsoncore.AppendDocumentElement(indexes, "definition", definition) indexes, err = bsoncore.AppendDocumentEnd(indexes, iidx) diff --git a/testdata/index-management/createSearchIndex.json b/testdata/index-management/createSearchIndex.json index f9c4e44d3e..327cb61259 100644 --- a/testdata/index-management/createSearchIndex.json +++ b/testdata/index-management/createSearchIndex.json @@ -50,7 +50,8 @@ "mappings": { "dynamic": true } - } + }, + "type": "search" } }, "expectError": { @@ -73,7 +74,8 @@ "mappings": { "dynamic": true } - } + }, + "type": "search" } ], "$db": "database0" @@ -97,7 +99,8 @@ "dynamic": true } }, - "name": "test index" + "name": "test index", + "type": "search" } }, "expectError": { @@ -121,7 +124,68 @@ "dynamic": true } }, - "name": "test index" + "name": "test index", + "type": "search" + } + ], + "$db": "database0" + } + } + } + ] + } + ] + }, + { + "description": "create a vector search index", + "operations": [ + { + "name": "createSearchIndex", + "object": "collection0", + "arguments": { + "model": { + "definition": { + "fields": [ + { + "type": "vector", + "path": "plot_embedding", + "numDimensions": 1536, + "similarity": "euclidean" + } + ] + }, + "name": "test index", + "type": "vectorSearch" + } + }, + "expectError": { + "isError": true, + "errorContains": "Atlas" + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "createSearchIndexes": "collection0", + "indexes": [ + { + "definition": { + "fields": [ + { + "type": "vector", + "path": "plot_embedding", + "numDimensions": 1536, + "similarity": "euclidean" + } + ] + }, + "name": "test index", + "type": "vectorSearch" } ], "$db": "database0" diff --git a/testdata/index-management/createSearchIndex.yml b/testdata/index-management/createSearchIndex.yml index 2e3cf50f8d..a32546cacf 100644 --- a/testdata/index-management/createSearchIndex.yml +++ b/testdata/index-management/createSearchIndex.yml @@ -26,7 +26,7 @@ tests: - name: createSearchIndex object: *collection0 arguments: - model: { definition: &definition { mappings: { dynamic: true } } } + model: { definition: &definition { mappings: { dynamic: true } } , type: 'search' } expectError: # This test always errors in a non-Atlas environment. The test functions as a unit test by asserting # that the driver constructs and sends the correct command. @@ -39,7 +39,7 @@ tests: - commandStartedEvent: command: createSearchIndexes: *collection0 - indexes: [ { definition: *definition } ] + indexes: [ { definition: *definition, type: 'search'} ] $db: *database0 - description: "name provided for an index definition" @@ -47,7 +47,7 @@ tests: - name: createSearchIndex object: *collection0 arguments: - model: { definition: &definition { mappings: { dynamic: true } } , name: 'test index' } + model: { definition: &definition { mappings: { dynamic: true } } , name: 'test index', type: 'search' } expectError: # This test always errors in a non-Atlas environment. The test functions as a unit test by asserting # that the driver constructs and sends the correct command. @@ -60,5 +60,27 @@ tests: - commandStartedEvent: command: createSearchIndexes: *collection0 - indexes: [ { definition: *definition, name: 'test index' } ] + indexes: [ { definition: *definition, name: 'test index', type: 'search' } ] + $db: *database0 + + - description: "create a vector search index" + operations: + - name: createSearchIndex + object: *collection0 + arguments: + model: { definition: &definition { fields: [ {"type": "vector", "path": "plot_embedding", "numDimensions": 1536, "similarity": "euclidean"} ] } + , name: 'test index', type: 'vectorSearch' } + expectError: + # This test always errors in a non-Atlas environment. The test functions as a unit test by asserting + # that the driver constructs and sends the correct command. + # The expected error message was changed in SERVER-83003. Check for the substring "Atlas" shared by both error messages. + isError: true + errorContains: Atlas + expectEvents: + - client: *client0 + events: + - commandStartedEvent: + command: + createSearchIndexes: *collection0 + indexes: [ { definition: *definition, name: 'test index', type: 'vectorSearch' } ] $db: *database0 diff --git a/testdata/index-management/createSearchIndexes.json b/testdata/index-management/createSearchIndexes.json index 3cf56ce12e..d91d7d9cf3 100644 --- a/testdata/index-management/createSearchIndexes.json +++ b/testdata/index-management/createSearchIndexes.json @@ -83,7 +83,8 @@ "mappings": { "dynamic": true } - } + }, + "type": "search" } ] }, @@ -107,7 +108,8 @@ "mappings": { "dynamic": true } - } + }, + "type": "search" } ], "$db": "database0" @@ -132,7 +134,8 @@ "dynamic": true } }, - "name": "test index" + "name": "test index", + "type": "search" } ] }, @@ -157,7 +160,70 @@ "dynamic": true } }, - "name": "test index" + "name": "test index", + "type": "search" + } + ], + "$db": "database0" + } + } + } + ] + } + ] + }, + { + "description": "create a vector search index", + "operations": [ + { + "name": "createSearchIndexes", + "object": "collection0", + "arguments": { + "models": [ + { + "definition": { + "fields": [ + { + "type": "vector", + "path": "plot_embedding", + "numDimensions": 1536, + "similarity": "euclidean" + } + ] + }, + "name": "test index", + "type": "vectorSearch" + } + ] + }, + "expectError": { + "isError": true, + "errorContains": "Atlas" + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "createSearchIndexes": "collection0", + "indexes": [ + { + "definition": { + "fields": [ + { + "type": "vector", + "path": "plot_embedding", + "numDimensions": 1536, + "similarity": "euclidean" + } + ] + }, + "name": "test index", + "type": "vectorSearch" } ], "$db": "database0" diff --git a/testdata/index-management/createSearchIndexes.yml b/testdata/index-management/createSearchIndexes.yml index db8f02e551..cac442cb87 100644 --- a/testdata/index-management/createSearchIndexes.yml +++ b/testdata/index-management/createSearchIndexes.yml @@ -48,7 +48,7 @@ tests: - name: createSearchIndexes object: *collection0 arguments: - models: [ { definition: &definition { mappings: { dynamic: true } } } ] + models: [ { definition: &definition { mappings: { dynamic: true } } , type: 'search' } ] expectError: # This test always errors in a non-Atlas environment. The test functions as a unit test by asserting # that the driver constructs and sends the correct command. @@ -61,7 +61,7 @@ tests: - commandStartedEvent: command: createSearchIndexes: *collection0 - indexes: [ { definition: *definition } ] + indexes: [ { definition: *definition, type: 'search'} ] $db: *database0 - description: "name provided for an index definition" @@ -69,7 +69,7 @@ tests: - name: createSearchIndexes object: *collection0 arguments: - models: [ { definition: &definition { mappings: { dynamic: true } } , name: 'test index' } ] + models: [ { definition: &definition { mappings: { dynamic: true } } , name: 'test index' , type: 'search' } ] expectError: # This test always errors in a non-Atlas environment. The test functions as a unit test by asserting # that the driver constructs and sends the correct command. @@ -82,5 +82,27 @@ tests: - commandStartedEvent: command: createSearchIndexes: *collection0 - indexes: [ { definition: *definition, name: 'test index' } ] + indexes: [ { definition: *definition, name: 'test index', type: 'search' } ] + $db: *database0 + + - description: "create a vector search index" + operations: + - name: createSearchIndexes + object: *collection0 + arguments: + models: [ { definition: &definition { fields: [ {"type": "vector", "path": "plot_embedding", "numDimensions": 1536, "similarity": "euclidean"} ] }, + name: 'test index' , type: 'vectorSearch' } ] + expectError: + # This test always errors in a non-Atlas environment. The test functions as a unit test by asserting + # that the driver constructs and sends the correct command. + # The expected error message was changed in SERVER-83003. Check for the substring "Atlas" shared by both error messages. + isError: true + errorContains: Atlas + expectEvents: + - client: *client0 + events: + - commandStartedEvent: + command: + createSearchIndexes: *collection0 + indexes: [ { definition: *definition, name: 'test index', type: 'vectorSearch' } ] $db: *database0 From 8f7b20fdd3939835289438840d06b0604d186f3c Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 5 Jun 2024 10:58:51 -0500 Subject: [PATCH 4/5] GODRIVER-3187 [master] Add SBOM lite file (#1656) --- sbom.json | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 sbom.json diff --git a/sbom.json b/sbom.json new file mode 100644 index 0000000000..433e3b2de3 --- /dev/null +++ b/sbom.json @@ -0,0 +1,10 @@ +{ + "metadata": { + "timestamp": "2024-05-02T17:36:29.429171+00:00" + }, + "serialNumber": "urn:uuid:06a59521-ad52-420b-aee6-7d9ed15e1fd9", + "version": 1, + "$schema": "http://cyclonedx.org/schema/bom-1.5.schema.json", + "bomFormat": "CycloneDX", + "specVersion": "1.5" + } \ No newline at end of file From bfe610f053e157628a93e19480a18c6c49fbedd8 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Wed, 5 Jun 2024 13:02:22 -0600 Subject: [PATCH 5/5] GODRIVER-2965 Internalize description package (#1621) Co-authored-by: Matt Dale <9760375+matthewdale@users.noreply.github.com> --- event/description.go | 58 + event/monitoring.go | 11 +- .../driverutil/description.go | 397 +++-- .../initial_dns_seedlist_discovery_test.go | 19 +- .../integration/mtest/opmsg_deployment.go | 8 +- internal/integration/mtest/setup.go | 10 +- internal/integration/sdam_prose_test.go | 2 +- .../server_selection_prose_test.go | 10 +- internal/integration/unified/event.go | 4 +- .../integration/unified/event_verification.go | 4 +- .../unified/testrunner_operation.go | 6 +- .../unified_runner_events_helper_test.go | 7 +- internal/integration/unified_spec_test.go | 8 +- internal/integtest/integtest.go | 7 +- internal/serverselector/server_selector.go | 359 +++++ .../serverselector/server_selector_test.go | 1278 +++++++++++++++++ mongo/bulk_write.go | 2 +- mongo/change_stream.go | 16 +- mongo/change_stream_deployment.go | 2 +- mongo/client.go | 20 +- mongo/collection.go | 74 +- mongo/database.go | 54 +- mongo/description/description.go | 11 - mongo/description/max_staleness_spec_test.go | 33 - mongo/description/selector_spec_test.go | 38 - mongo/description/selector_test.go | 874 ----------- mongo/description/server_kind.go | 46 - mongo/description/server_selector.go | 420 ------ mongo/description/server_test.go | 72 - mongo/description/shared_spec_test.go | 296 ---- mongo/description/topology.go | 142 -- mongo/description/topology_kind.go | 40 - mongo/description/topology_version.go | 65 - mongo/description/version_range.go | 42 - mongo/description/version_range_test.go | 34 - mongo/index_view.go | 15 +- mongo/session.go | 6 +- mongo/with_transactions_test.go | 6 +- x/mongo/driver/auth/auth.go | 4 +- x/mongo/driver/auth/gssapi_test.go | 2 +- x/mongo/driver/auth/mongodbcr_test.go | 2 +- x/mongo/driver/auth/plain_test.go | 2 +- x/mongo/driver/auth/scram_test.go | 2 +- x/mongo/driver/batch_cursor.go | 7 +- x/mongo/driver/description/server.go | 144 ++ x/mongo/driver/description/topology.go | 60 + x/mongo/driver/driver.go | 12 +- x/mongo/driver/drivertest/channel_conn.go | 2 +- x/mongo/driver/errors.go | 5 +- x/mongo/driver/integration/aggregate_test.go | 12 +- x/mongo/driver/integration/main_test.go | 6 +- x/mongo/driver/integration/scram_test.go | 9 +- x/mongo/driver/mnet/connection.go | 2 +- x/mongo/driver/operation.go | 51 +- x/mongo/driver/operation/abort_transaction.go | 2 +- x/mongo/driver/operation/aggregate.go | 5 +- x/mongo/driver/operation/command.go | 2 +- .../driver/operation/commit_transaction.go | 2 +- x/mongo/driver/operation/count.go | 2 +- x/mongo/driver/operation/create.go | 5 +- x/mongo/driver/operation/create_indexes.go | 4 +- .../driver/operation/create_search_indexes.go | 2 +- x/mongo/driver/operation/delete.go | 4 +- x/mongo/driver/operation/distinct.go | 4 +- x/mongo/driver/operation/drop_collection.go | 2 +- x/mongo/driver/operation/drop_database.go | 2 +- x/mongo/driver/operation/drop_indexes.go | 2 +- x/mongo/driver/operation/drop_search_index.go | 2 +- x/mongo/driver/operation/end_sessions.go | 2 +- x/mongo/driver/operation/find.go | 6 +- x/mongo/driver/operation/find_and_modify.go | 8 +- x/mongo/driver/operation/hello.go | 4 +- x/mongo/driver/operation/insert.go | 6 +- x/mongo/driver/operation/listDatabases.go | 2 +- x/mongo/driver/operation/list_collections.go | 2 +- x/mongo/driver/operation/list_indexes.go | 2 +- x/mongo/driver/operation/update.go | 8 +- .../driver/operation/update_search_index.go | 2 +- x/mongo/driver/operation_test.go | 44 +- x/mongo/driver/session/client_session.go | 11 +- x/mongo/driver/session/client_session_test.go | 4 +- x/mongo/driver/session/server_session.go | 4 +- x/mongo/driver/session/server_session_test.go | 4 +- x/mongo/driver/session/session_pool.go | 2 +- x/mongo/driver/session/session_pool_test.go | 2 +- x/mongo/driver/topology/connection.go | 5 +- x/mongo/driver/topology/connection_test.go | 2 +- x/mongo/driver/topology/diff.go | 2 +- x/mongo/driver/topology/diff_test.go | 2 +- x/mongo/driver/topology/errors.go | 2 +- x/mongo/driver/topology/fsm.go | 76 +- x/mongo/driver/topology/fsm_test.go | 16 +- .../topology/polling_srv_records_test.go | 4 +- x/mongo/driver/topology/sdam_spec_test.go | 19 +- x/mongo/driver/topology/server.go | 55 +- x/mongo/driver/topology/server_test.go | 50 +- x/mongo/driver/topology/topology.go | 120 +- .../driver/topology/topology_errors_test.go | 7 +- x/mongo/driver/topology/topology_test.go | 151 +- 99 files changed, 2735 insertions(+), 2747 deletions(-) create mode 100644 event/description.go rename mongo/description/server.go => internal/driverutil/description.go (69%) create mode 100644 internal/serverselector/server_selector.go create mode 100644 internal/serverselector/server_selector_test.go delete mode 100644 mongo/description/description.go delete mode 100644 mongo/description/max_staleness_spec_test.go delete mode 100644 mongo/description/selector_spec_test.go delete mode 100644 mongo/description/selector_test.go delete mode 100644 mongo/description/server_kind.go delete mode 100644 mongo/description/server_selector.go delete mode 100644 mongo/description/server_test.go delete mode 100644 mongo/description/shared_spec_test.go delete mode 100644 mongo/description/topology.go delete mode 100644 mongo/description/topology_kind.go delete mode 100644 mongo/description/topology_version.go delete mode 100644 mongo/description/version_range.go delete mode 100644 mongo/description/version_range_test.go create mode 100644 x/mongo/driver/description/server.go create mode 100644 x/mongo/driver/description/topology.go diff --git a/event/description.go b/event/description.go new file mode 100644 index 0000000000..682d61c1c9 --- /dev/null +++ b/event/description.go @@ -0,0 +1,58 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package event + +import ( + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/tag" +) + +// ServerDescription contains information about a node in a cluster. This is +// created from hello command responses. If the value of the Kind field is +// LoadBalancer, only the Addr and Kind fields will be set. All other fields +// will be set to the zero value of the field's type. +type ServerDescription struct { + Addr address.Address + Arbiters []string + Compression []string // compression methods returned by server + CanonicalAddr address.Address + ElectionID bson.ObjectID + IsCryptd bool + HelloOK bool + Hosts []string + Kind string + LastWriteTime time.Time + MaxBatchCount uint32 + MaxDocumentSize uint32 + MaxMessageSize uint32 + MaxWireVersion int32 + MinWireVersion int32 + Members []address.Address + Passives []string + Passive bool + Primary address.Address + ReadOnly bool + ServiceID *bson.ObjectID // Only set for servers that are deployed behind a load balancer. + SessionTimeoutMinutes *int64 + SetName string + SetVersion uint32 + Tags tag.Set + TopologyVersionProcessID bson.ObjectID + TopologyVersionCounter int64 +} + +// TopologyDescription contains information about a MongoDB cluster. +type TopologyDescription struct { + Servers []ServerDescription + SetName string + Kind string + SessionTimeoutMinutes *int64 + CompatibilityErr error +} diff --git a/event/monitoring.go b/event/monitoring.go index 610b737e92..a41da0172a 100644 --- a/event/monitoring.go +++ b/event/monitoring.go @@ -12,7 +12,6 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" ) // CommandStartedEvent represents an event generated when a command is sent to a server. @@ -120,8 +119,8 @@ type PoolMonitor struct { type ServerDescriptionChangedEvent struct { Address address.Address TopologyID bson.ObjectID // A unique identifier for the topology this server is a part of - PreviousDescription description.Server - NewDescription description.Server + PreviousDescription ServerDescription + NewDescription ServerDescription } // ServerOpeningEvent is an event generated when the server is initialized. @@ -139,8 +138,8 @@ type ServerClosedEvent struct { // TopologyDescriptionChangedEvent represents a topology description change. type TopologyDescriptionChangedEvent struct { TopologyID bson.ObjectID // A unique identifier for the topology this server is a part of - PreviousDescription description.Topology - NewDescription description.Topology + PreviousDescription TopologyDescription + NewDescription TopologyDescription } // TopologyOpeningEvent is an event generated when the topology is initialized. @@ -162,7 +161,7 @@ type ServerHeartbeatStartedEvent struct { // ServerHeartbeatSucceededEvent is an event generated when the heartbeat succeeds. type ServerHeartbeatSucceededEvent struct { Duration time.Duration - Reply description.Server + Reply ServerDescription ConnectionID string // The address this heartbeat was sent to with a unique identifier Awaited bool // If this heartbeat was awaitable } diff --git a/mongo/description/server.go b/internal/driverutil/description.go similarity index 69% rename from mongo/description/server.go rename to internal/driverutil/description.go index 3080a43045..6c04ec0b0a 100644 --- a/mongo/description/server.go +++ b/internal/driverutil/description.go @@ -1,10 +1,10 @@ -// Copyright (C) MongoDB, Inc. 2017-present. +// Copyright (C) MongoDB, Inc. 2024-present. // // Licensed under the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -package description +package driverutil import ( "errors" @@ -17,55 +17,192 @@ import ( "go.mongodb.org/mongo-driver/internal/ptrutil" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/tag" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) -// SelectedServer augments the Server type by also including the TopologyKind of the topology that includes the server. -// This type should be used to track the state of a server that was selected to perform an operation. -type SelectedServer struct { - Server - Kind TopologyKind +func equalWireVersion(wv1, wv2 *description.VersionRange) bool { + if wv1 == nil && wv2 == nil { + return true + } + + if wv1 == nil || wv2 == nil { + return false + } + + return wv1.Min == wv2.Min && wv1.Max == wv2.Max +} + +// EqualServers compares two server descriptions and returns true if they are +// equal. +func EqualServers(srv1, srv2 description.Server) bool { + if srv1.CanonicalAddr.String() != srv2.CanonicalAddr.String() { + return false + } + + if !sliceStringEqual(srv1.Arbiters, srv2.Arbiters) { + return false + } + + if !sliceStringEqual(srv1.Hosts, srv2.Hosts) { + return false + } + + if !sliceStringEqual(srv1.Passives, srv2.Passives) { + return false + } + + if srv1.Primary != srv2.Primary { + return false + } + + if srv1.SetName != srv2.SetName { + return false + } + + if srv1.Kind != srv2.Kind { + return false + } + + if srv1.LastError != nil || srv2.LastError != nil { + if srv1.LastError == nil || srv2.LastError == nil { + return false + } + if srv1.LastError.Error() != srv2.LastError.Error() { + return false + } + } + + //if !s.WireVersion.Equals(other.WireVersion) { + if !equalWireVersion(srv1.WireVersion, srv2.WireVersion) { + return false + } + + if len(srv1.Tags) != len(srv2.Tags) || !srv1.Tags.ContainsAll(srv2.Tags) { + return false + } + + if srv1.SetVersion != srv2.SetVersion { + return false + } + + if srv1.ElectionID != srv2.ElectionID { + return false + } + + if ptrutil.CompareInt64(srv1.SessionTimeoutMinutes, srv2.SessionTimeoutMinutes) != 0 { + return false + } + + // If TopologyVersion is nil for both servers, CompareToIncoming will return -1 because it assumes that the + // incoming response is newer. We want the descriptions to be considered equal in this case, though, so an + // explicit check is required. + if srv1.TopologyVersion == nil && srv2.TopologyVersion == nil { + return true + } + + //return s.TopologyVersion.CompareToIncoming(other.TopologyVersion) == 0 + + return CompareTopologyVersions(srv1.TopologyVersion, srv2.TopologyVersion) == 0 +} + +// IsServerLoadBalanced checks if a description.Server describes a server that +// is load balanced. +func IsServerLoadBalanced(srv description.Server) bool { + return srv.Kind == description.ServerKindLoadBalancer || srv.ServiceID != nil +} + +// stringSliceFromRawElement decodes the provided BSON element into a []string. +// This internally calls StringSliceFromRawValue on the element's value. The +// error conditions outlined in that function's documentation apply for this +// function as well. +func stringSliceFromRawElement(element bson.RawElement) ([]string, error) { + return bsonutil.StringSliceFromRawValue(element.Key(), element.Value()) +} + +func decodeStringMap(element bson.RawElement, name string) (map[string]string, error) { + doc, ok := element.Value().DocumentOK() + if !ok { + return nil, fmt.Errorf("expected '%s' to be a document but it's a BSON %s", name, element.Value().Type) + } + elements, err := doc.Elements() + if err != nil { + return nil, err + } + m := make(map[string]string) + for _, element := range elements { + key := element.Key() + value, ok := element.Value().StringValueOK() + if !ok { + return nil, fmt.Errorf("expected '%s' to be a document of strings, but found a BSON %s", name, element.Value().Type) + } + m[key] = value + } + return m, nil +} + +// NewTopologyVersion creates a TopologyVersion based on doc +func NewTopologyVersion(doc bson.Raw) (*description.TopologyVersion, error) { + elements, err := doc.Elements() + if err != nil { + return nil, err + } + var tv description.TopologyVersion + var ok bool + for _, element := range elements { + switch element.Key() { + case "processId": + tv.ProcessID, ok = element.Value().ObjectIDOK() + if !ok { + return nil, fmt.Errorf("expected 'processId' to be a objectID but it's a BSON %s", element.Value().Type) + } + case "counter": + tv.Counter, ok = element.Value().Int64OK() + if !ok { + return nil, fmt.Errorf("expected 'counter' to be an int64 but it's a BSON %s", element.Value().Type) + } + } + } + return &tv, nil +} + +// NewVersionRange creates a new VersionRange given a min and a max. +func NewVersionRange(min, max int32) description.VersionRange { + return description.VersionRange{Min: min, Max: max} } -// Server contains information about a node in a cluster. This is created from hello command responses. If the value -// of the Kind field is LoadBalancer, only the Addr and Kind fields will be set. All other fields will be set to the -// zero value of the field's type. -type Server struct { - Addr address.Address - - Arbiters []string - AverageRTT time.Duration - AverageRTTSet bool - Compression []string // compression methods returned by server - CanonicalAddr address.Address - ElectionID bson.ObjectID - HeartbeatInterval time.Duration - HelloOK bool - Hosts []string - IsCryptd bool - LastError error - LastUpdateTime time.Time - LastWriteTime time.Time - MaxBatchCount uint32 - MaxDocumentSize uint32 - MaxMessageSize uint32 - Members []address.Address - Passives []string - Passive bool - Primary address.Address - ReadOnly bool - ServiceID *bson.ObjectID // Only set for servers that are deployed behind a load balancer. - SessionTimeoutMinutes *int64 - SetName string - SetVersion uint32 - Tags tag.Set - TopologyVersion *TopologyVersion - Kind ServerKind - WireVersion *VersionRange +// VersionRangeIncludes returns a bool indicating whether the supplied integer +// is included in the range. +func VersionRangeIncludes(versionRange description.VersionRange, v int32) bool { + return v >= versionRange.Min && v <= versionRange.Max } -// NewServer creates a new server description from the given hello command response. -func NewServer(addr address.Address, response bson.Raw) Server { - desc := Server{Addr: addr, CanonicalAddr: addr, LastUpdateTime: time.Now().UTC()} +// CompareTopologyVersions compares the receiver, which represents the currently +// known TopologyVersion for a server, to an incoming TopologyVersion extracted +// from a server command response. +// +// This returns -1 if the receiver version is less than the response, 0 if the +// versions are equal, and 1 if the receiver version is greater than the +// response. This comparison is not commutative. +func CompareTopologyVersions(receiver, response *description.TopologyVersion) int { + if receiver == nil || response == nil { + return -1 + } + if receiver.ProcessID != response.ProcessID { + return -1 + } + if receiver.Counter == response.Counter { + return 0 + } + if receiver.Counter < response.Counter { + return -1 + } + return 1 +} + +// NewServerDescription creates a new server description from the given hello +// command response. +func NewServerDescription(addr address.Address, response bson.Raw) description.Server { + desc := description.Server{Addr: addr, CanonicalAddr: addr, LastUpdateTime: time.Now().UTC()} elements, err := response.Elements() if err != nil { desc.LastError = err @@ -74,7 +211,7 @@ func NewServer(addr address.Address, response bson.Raw) Server { var ok bool var isReplicaSet, isWritablePrimary, hidden, secondary, arbiterOnly bool var msg string - var versionRange VersionRange + var versionRange description.VersionRange for _, element := range elements { switch element.Key() { case "arbiters": @@ -312,24 +449,24 @@ func NewServer(addr address.Address, response bson.Raw) Server { desc.Members = append(desc.Members, address.Address(arbiter).Canonicalize()) } - desc.Kind = Standalone + desc.Kind = description.ServerKindStandalone if isReplicaSet { - desc.Kind = RSGhost + desc.Kind = description.ServerKindRSGhost } else if desc.SetName != "" { if isWritablePrimary { - desc.Kind = RSPrimary + desc.Kind = description.ServerKindRSPrimary } else if hidden { - desc.Kind = RSMember + desc.Kind = description.ServerKindRSMember } else if secondary { - desc.Kind = RSSecondary + desc.Kind = description.ServerKindRSSecondary } else if arbiterOnly { - desc.Kind = RSArbiter + desc.Kind = description.ServerKindRSArbiter } else { - desc.Kind = RSMember + desc.Kind = description.ServerKindRSMember } } else if msg == "isdbgrid" { - desc.Kind = Mongos + desc.Kind = description.ServerKindMongos } desc.WireVersion = &versionRange @@ -337,164 +474,16 @@ func NewServer(addr address.Address, response bson.Raw) Server { return desc } -// NewDefaultServer creates a new unknown server description with the given address. -func NewDefaultServer(addr address.Address) Server { - return NewServerFromError(addr, nil, nil) -} - -// NewServerFromError creates a new unknown server description with the given parameters. -func NewServerFromError(addr address.Address, err error, tv *TopologyVersion) Server { - return Server{ - Addr: addr, - LastError: err, - Kind: Unknown, - TopologyVersion: tv, - } -} - -// SetAverageRTT sets the average round trip time for this server description. -func (s Server) SetAverageRTT(rtt time.Duration) Server { - s.AverageRTT = rtt - s.AverageRTTSet = true - return s -} - -// DataBearing returns true if the server is a data bearing server. -func (s Server) DataBearing() bool { - return s.Kind == RSPrimary || - s.Kind == RSSecondary || - s.Kind == Mongos || - s.Kind == Standalone -} - -// LoadBalanced returns true if the server is a load balancer or is behind a load balancer. -func (s Server) LoadBalanced() bool { - return s.Kind == LoadBalancer || s.ServiceID != nil -} - -// String implements the Stringer interface -func (s Server) String() string { - str := fmt.Sprintf("Addr: %s, Type: %s", - s.Addr, s.Kind) - if len(s.Tags) != 0 { - str += fmt.Sprintf(", Tag sets: %s", s.Tags) - } - - if s.AverageRTTSet { - str += fmt.Sprintf(", Average RTT: %d", s.AverageRTT) - } - - if s.LastError != nil { - str += fmt.Sprintf(", Last error: %s", s.LastError) - } - return str -} - -func decodeStringMap(element bson.RawElement, name string) (map[string]string, error) { - doc, ok := element.Value().DocumentOK() - if !ok { - return nil, fmt.Errorf("expected '%s' to be a document but it's a BSON %s", name, element.Value().Type) - } - elements, err := doc.Elements() - if err != nil { - return nil, err - } - m := make(map[string]string) - for _, element := range elements { - key := element.Key() - value, ok := element.Value().StringValueOK() - if !ok { - return nil, fmt.Errorf("expected '%s' to be a document of strings, but found a BSON %s", name, element.Value().Type) - } - m[key] = value - } - return m, nil -} - -// Equal compares two server descriptions and returns true if they are equal -func (s Server) Equal(other Server) bool { - if s.CanonicalAddr.String() != other.CanonicalAddr.String() { - return false - } - - if !sliceStringEqual(s.Arbiters, other.Arbiters) { - return false - } - - if !sliceStringEqual(s.Hosts, other.Hosts) { - return false - } - - if !sliceStringEqual(s.Passives, other.Passives) { - return false - } - - if s.Primary != other.Primary { - return false - } - - if s.SetName != other.SetName { - return false - } - - if s.Kind != other.Kind { - return false - } - - if s.LastError != nil || other.LastError != nil { - if s.LastError == nil || other.LastError == nil { - return false - } - if s.LastError.Error() != other.LastError.Error() { - return false - } - } - - if !s.WireVersion.Equals(other.WireVersion) { - return false - } - - if len(s.Tags) != len(other.Tags) || !s.Tags.ContainsAll(other.Tags) { - return false - } - - if s.SetVersion != other.SetVersion { - return false - } - - if s.ElectionID != other.ElectionID { - return false - } - - if ptrutil.CompareInt64(s.SessionTimeoutMinutes, other.SessionTimeoutMinutes) != 0 { - return false - } - - // If TopologyVersion is nil for both servers, CompareToIncoming will return -1 because it assumes that the - // incoming response is newer. We want the descriptions to be considered equal in this case, though, so an - // explicit check is required. - if s.TopologyVersion == nil && other.TopologyVersion == nil { - return true - } - return s.TopologyVersion.CompareToIncoming(other.TopologyVersion) == 0 -} - func sliceStringEqual(a []string, b []string) bool { if len(a) != len(b) { return false } + for i, v := range a { if v != b[i] { return false } } - return true -} -// stringSliceFromRawElement decodes the provided BSON element into a []string. -// This internally calls StringSliceFromRawValue on the element's value. The -// error conditions outlined in that function's documentation apply for this -// function as well. -func stringSliceFromRawElement(element bson.RawElement) ([]string, error) { - return bsonutil.StringSliceFromRawValue(element.Key(), element.Value()) + return true } diff --git a/internal/integration/initial_dns_seedlist_discovery_test.go b/internal/integration/initial_dns_seedlist_discovery_test.go index df0e7dbe04..d4d4e20799 100644 --- a/internal/integration/initial_dns_seedlist_discovery_test.go +++ b/internal/integration/initial_dns_seedlist_discovery_test.go @@ -20,11 +20,12 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/integration/mtest" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" ) @@ -248,14 +249,16 @@ func getSSLSettings(mt *mtest.T, test seedlistTest) *tls.Config { } func getServerByAddress(address string, topo *topology.Topology) (description.Server, error) { - selectByName := description.ServerSelectorFunc(func(_ description.Topology, servers []description.Server) ([]description.Server, error) { - for _, s := range servers { - if s.Addr.String() == address { - return []description.Server{s}, nil + selectByName := serverselector.Func( + func(_ description.Topology, servers []description.Server) ([]description.Server, error) { + for _, s := range servers { + if s.Addr.String() == address { + return []description.Server{s}, nil + } } - } - return []description.Server{}, nil - }) + + return []description.Server{}, nil + }) selectedServer, err := topo.SelectServer(context.Background(), selectByName) if err != nil { diff --git a/internal/integration/mtest/opmsg_deployment.go b/internal/integration/mtest/opmsg_deployment.go index dc15831fe5..6a0a1021c1 100644 --- a/internal/integration/mtest/opmsg_deployment.go +++ b/internal/integration/mtest/opmsg_deployment.go @@ -13,9 +13,9 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/csot" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" @@ -39,7 +39,7 @@ var ( MaxMessageSize: maxMessageSize, MaxBatchCount: maxBatchCount, SessionTimeoutMinutes: &sessionTimeoutMinutes, - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, WireVersion: &description.VersionRange{ Max: topology.SupportedWireVersions.Max, }, @@ -133,9 +133,9 @@ func (md *mockDeployment) SelectServer(context.Context, description.ServerSelect return md, nil } -// Kind implements the Deployment interface. It always returns description.Single. +// Kind implements the Deployment interface. It always returns description.TopologyKindSingle. func (md *mockDeployment) Kind() description.TopologyKind { - return description.Single + return description.TopologyKindSingle } // Connection implements the driver.Server interface. diff --git a/internal/integration/mtest/setup.go b/internal/integration/mtest/setup.go index f7fa5c5243..48b0e9c9cb 100644 --- a/internal/integration/mtest/setup.go +++ b/internal/integration/mtest/setup.go @@ -19,12 +19,12 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/integtest" "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" ) @@ -136,13 +136,13 @@ func Setup(setupOpts ...*SetupOptions) error { } switch testContext.topo.Kind() { - case description.Single: + case description.TopologyKindSingle: testContext.topoKind = Single - case description.ReplicaSet, description.ReplicaSetWithPrimary, description.ReplicaSetNoPrimary: + case description.TopologyKindReplicaSet, description.TopologyKindReplicaSetWithPrimary, description.TopologyKindReplicaSetNoPrimary: testContext.topoKind = ReplicaSet - case description.Sharded: + case description.TopologyKindSharded: testContext.topoKind = Sharded - case description.LoadBalanced: + case description.TopologyKindLoadBalanced: testContext.topoKind = LoadBalanced default: return fmt.Errorf("could not detect topology kind; current topology: %s", testContext.topo.String()) diff --git a/internal/integration/sdam_prose_test.go b/internal/integration/sdam_prose_test.go index 79c35abf95..5aa3358905 100644 --- a/internal/integration/sdam_prose_test.go +++ b/internal/integration/sdam_prose_test.go @@ -21,8 +21,8 @@ import ( "go.mongodb.org/mongo-driver/internal/integration/mtest" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" ) diff --git a/internal/integration/server_selection_prose_test.go b/internal/integration/server_selection_prose_test.go index 9e66eadc75..bc59c71197 100644 --- a/internal/integration/server_selection_prose_test.go +++ b/internal/integration/server_selection_prose_test.go @@ -18,8 +18,8 @@ import ( "go.mongodb.org/mongo-driver/internal/eventtest" "go.mongodb.org/mongo-driver/internal/integration/mtest" "go.mongodb.org/mongo-driver/internal/require" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) type saturatedConnections map[int64]bool @@ -159,7 +159,9 @@ func TestServerSelectionProse(t *testing.T) { })) for evt := range topologyEvents { servers := evt.NewDescription.Servers - if len(servers) == 2 && servers[0].Kind == description.Mongos && servers[1].Kind == description.Mongos { + if len(servers) == 2 && servers[0].Kind == description.ServerKindMongos.String() && + servers[1].Kind == description.ServerKindMongos.String() { + break } } @@ -209,7 +211,9 @@ func TestServerSelectionProse(t *testing.T) { })) for evt := range topologyEvents { servers := evt.NewDescription.Servers - if len(servers) == 2 && servers[0].Kind == description.Mongos && servers[1].Kind == description.Mongos { + if len(servers) == 2 && servers[0].Kind == description.ServerKindMongos.String() && + servers[1].Kind == description.ServerKindMongos.String() { + break } } diff --git a/internal/integration/unified/event.go b/internal/integration/unified/event.go index 206c78e33f..d62b384ba5 100644 --- a/internal/integration/unified/event.go +++ b/internal/integration/unified/event.go @@ -135,10 +135,10 @@ type serverDescriptionChangedEventInfo struct { func newServerDescriptionChangedEventInfo(evt *event.ServerDescriptionChangedEvent) *serverDescriptionChangedEventInfo { return &serverDescriptionChangedEventInfo{ NewDescription: &serverDescription{ - Type: evt.NewDescription.Kind.String(), + Type: evt.NewDescription.Kind, }, PreviousDescription: &serverDescription{ - Type: evt.PreviousDescription.Kind.String(), + Type: evt.PreviousDescription.Kind, }, } } diff --git a/internal/integration/unified/event_verification.go b/internal/integration/unified/event_verification.go index 784e8198cf..d6f64723e3 100644 --- a/internal/integration/unified/event_verification.go +++ b/internal/integration/unified/event_verification.go @@ -523,7 +523,7 @@ func verifySDAMEvents(client *clientEntity, expectedEvents *expectedEvents) erro wantPrevDesc = *prevDesc.Type } - gotPrevDesc := got.PreviousDescription.Kind.String() + gotPrevDesc := got.PreviousDescription.Kind if gotPrevDesc != wantPrevDesc { return newEventVerificationError(idx, client, "expected previous server description %q, got %q", wantPrevDesc, gotPrevDesc) @@ -536,7 +536,7 @@ func verifySDAMEvents(client *clientEntity, expectedEvents *expectedEvents) erro wantNewDesc = *newDesc.Type } - gotNewDesc := got.NewDescription.Kind.String() + gotNewDesc := got.NewDescription.Kind if gotNewDesc != wantNewDesc { return newEventVerificationError(idx, client, "expected new server description %q, got %q", wantNewDesc, gotNewDesc) diff --git a/internal/integration/unified/testrunner_operation.go b/internal/integration/unified/testrunner_operation.go index 1079f33840..a5dbc3e75a 100644 --- a/internal/integration/unified/testrunner_operation.go +++ b/internal/integration/unified/testrunner_operation.go @@ -113,11 +113,11 @@ func executeTestRunnerOperation(ctx context.Context, op *operation, loopDone <-c } clientSession := extractClientSession(sess) - if clientSession.PinnedServer == nil { + if clientSession.PinnedServerAddr == nil { return fmt.Errorf("session is not pinned to a server") } - targetHost := clientSession.PinnedServer.Addr.String() + targetHost := clientSession.PinnedServerAddr.String() fpDoc := args.Lookup("failPoint").Document() commandFn := func(ctx context.Context, client *mongo.Client) error { return mtest.SetRawFailPoint(fpDoc, client) @@ -453,7 +453,7 @@ func verifySessionPinnedState(ctx context.Context, sessionID string, expectedPin return err } - if isPinned := extractClientSession(sess).PinnedServer != nil; expectedPinned != isPinned { + if isPinned := extractClientSession(sess).PinnedServerAddr != nil; expectedPinned != isPinned { return fmt.Errorf("session pinned state mismatch; expected to be pinned: %v, is pinned: %v", expectedPinned, isPinned) } return nil diff --git a/internal/integration/unified_runner_events_helper_test.go b/internal/integration/unified_runner_events_helper_test.go index 832684f026..2937882dcd 100644 --- a/internal/integration/unified_runner_events_helper_test.go +++ b/internal/integration/unified_runner_events_helper_test.go @@ -14,9 +14,10 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/integration/mtest" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readpref" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" ) @@ -50,7 +51,7 @@ func newUnifiedRunnerEventMonitor() *unifiedRunnerEventMonitor { // Spec tests only ever handle ServerMarkedUnknown ServerDescriptionChangedEvents // for the time being. - if e.NewDescription.Kind == description.Unknown { + if e.NewDescription.Kind == description.UnknownStr { urem.serverMarkedUnknownCount++ } }), @@ -164,7 +165,7 @@ func getPrimaryAddress(mt *mtest.T, topo *topology.Topology, failFast bool) addr cancel() } - primary, err := topo.SelectServer(ctx, description.ReadPrefSelector(readpref.Primary())) + primary, err := topo.SelectServer(ctx, &serverselector.ReadPref{ReadPref: readpref.Primary()}) assert.Nil(mt, err, "SelectServer error: %v", err) return primary.(*topology.SelectedServer).Description().Addr } diff --git a/internal/integration/unified_spec_test.go b/internal/integration/unified_spec_test.go index cba3244db3..bcfb5d72f8 100644 --- a/internal/integration/unified_spec_test.go +++ b/internal/integration/unified_spec_test.go @@ -457,7 +457,7 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, if clientSession == nil { return errors.New("expected valid session, got nil") } - targetHost := clientSession.PinnedServer.Addr.String() + targetHost := clientSession.PinnedServerAddr.String() opts := options.Client().ApplyURI(mtest.ClusterURI()).SetHosts([]string{targetHost}) integtest.AddTestServerAPIVersion(opts) client, err := mongo.Connect(opts) @@ -502,7 +502,7 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, if clientSession == nil { return errors.New("expected valid session, got nil") } - if clientSession.PinnedServer == nil { + if clientSession.PinnedServerAddr == nil { return errors.New("expected pinned server, got nil") } case "assertSessionUnpinned": @@ -511,8 +511,8 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, } // We don't use a combined helper for assertSessionPinned and assertSessionUnpinned because the unpinned // case provides the pinned server address in the error msg for debugging. - if clientSession.PinnedServer != nil { - return fmt.Errorf("expected pinned server to be nil but got %q", clientSession.PinnedServer.Addr) + if clientSession.PinnedServerAddr != nil { + return fmt.Errorf("expected pinned server to be nil but got %q", clientSession.PinnedServerAddr) } case "assertSameLsidOnLastTwoCommands": first, second := lastTwoIDs(mt) diff --git a/internal/integtest/integtest.go b/internal/integtest/integtest.go index fb7fbf459f..23758e54de 100644 --- a/internal/integtest/integtest.go +++ b/internal/integtest/integtest.go @@ -20,7 +20,7 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/require" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -102,7 +102,7 @@ func MonitoredTopology(t *testing.T, dbName string, monitor *event.CommandMonito _ = monitoredTopology.Connect() err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))). - Database(dbName).ServerSelector(description.WriteSelector()).Deployment(monitoredTopology).Execute(context.Background()) + Database(dbName).ServerSelector(&serverselector.Write{}).Deployment(monitoredTopology).Execute(context.Background()) require.NoError(t, err) } @@ -126,7 +126,8 @@ func Topology(t *testing.T) *topology.Topology { _ = liveTopology.Connect() err = operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dropDatabase", 1))). - Database(DBName(t)).ServerSelector(description.WriteSelector()).Deployment(liveTopology).Execute(context.Background()) + Database(DBName(t)).ServerSelector(&serverselector.Write{}). + Deployment(liveTopology).Execute(context.Background()) require.NoError(t, err) } }) diff --git a/internal/serverselector/server_selector.go b/internal/serverselector/server_selector.go new file mode 100644 index 0000000000..4599b0f9d3 --- /dev/null +++ b/internal/serverselector/server_selector.go @@ -0,0 +1,359 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package serverselector + +import ( + "fmt" + "math" + "time" + + "go.mongodb.org/mongo-driver/mongo/readpref" + "go.mongodb.org/mongo-driver/tag" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" +) + +// Composite combines multiple selectors into a single selector by applying them +// in order to the candidates list. +// +// For example, if the initial candidates list is [s0, s1, s2, s3] and two +// selectors are provided where the first matches s0 and s1 and the second +// matches s1 and s2, the following would occur during server selection: +// +// 1. firstSelector([s0, s1, s2, s3]) -> [s0, s1] +// 2. secondSelector([s0, s1]) -> [s1] +// +// The final list of candidates returned by the composite selector would be +// [s1]. +type Composite struct { + Selectors []description.ServerSelector +} + +var _ description.ServerSelector = &Composite{} + +// SelectServer combines multiple selectors into a single selector. +func (selector *Composite) SelectServer( + topo description.Topology, + candidates []description.Server, +) ([]description.Server, error) { + var err error + for _, sel := range selector.Selectors { + candidates, err = sel.SelectServer(topo, candidates) + if err != nil { + return nil, err + } + } + + return candidates, nil +} + +// Latency creates a ServerSelector which selects servers based on their average +// RTT values. +type Latency struct { + Latency time.Duration +} + +var _ description.ServerSelector = &Latency{} + +// SelectServer selects servers based on average RTT. +func (selector *Latency) SelectServer( + topo description.Topology, + candidates []description.Server, +) ([]description.Server, error) { + if selector.Latency < 0 { + return candidates, nil + } + if topo.Kind == description.TopologyKindLoadBalanced { + // In LoadBalanced mode, there should only be one server in the topology and + // it must be selected. + return candidates, nil + } + + switch len(candidates) { + case 0, 1: + return candidates, nil + default: + min := time.Duration(math.MaxInt64) + for _, candidate := range candidates { + if candidate.AverageRTTSet { + if candidate.AverageRTT < min { + min = candidate.AverageRTT + } + } + } + + if min == math.MaxInt64 { + return candidates, nil + } + + max := min + selector.Latency + + viableIndexes := make([]int, 0, len(candidates)) + for i, candidate := range candidates { + if candidate.AverageRTTSet { + if candidate.AverageRTT <= max { + viableIndexes = append(viableIndexes, i) + } + } + } + if len(viableIndexes) == len(candidates) { + return candidates, nil + } + result := make([]description.Server, len(viableIndexes)) + for i, idx := range viableIndexes { + result[i] = candidates[idx] + } + return result, nil + } +} + +// ReadPref selects servers based on the provided read preference. +type ReadPref struct { + ReadPref *readpref.ReadPref + IsOutputAggregate bool +} + +var _ description.ServerSelector = &ReadPref{} + +// SelectServer selects servers based on read preference. +func (selector *ReadPref) SelectServer( + topo description.Topology, + candidates []description.Server, +) ([]description.Server, error) { + if topo.Kind == description.TopologyKindLoadBalanced { + // In LoadBalanced mode, there should only be one server in the topology and + // it must be selected. We check this before checking MaxStaleness support + // because there's no monitoring in this mode, so the candidate server + // wouldn't have a wire version set, which would result in an error. + return candidates, nil + } + + switch topo.Kind { + case description.TopologyKindSingle: + return candidates, nil + case description.TopologyKindReplicaSetNoPrimary, description.TopologyKindReplicaSetWithPrimary: + return selectForReplicaSet(selector.ReadPref, selector.IsOutputAggregate, topo, candidates) + case description.TopologyKindSharded: + return selectByKind(candidates, description.ServerKindMongos), nil + } + + return nil, nil +} + +// Write selects all the writable servers. +type Write struct{} + +var _ description.ServerSelector = &Write{} + +// SelectServer selects all writable servers. +func (selector *Write) SelectServer( + topo description.Topology, + candidates []description.Server, +) ([]description.Server, error) { + switch topo.Kind { + case description.TopologyKindSingle, description.TopologyKindLoadBalanced: + return candidates, nil + default: + // Determine the capacity of the results slice. + selected := 0 + for _, candidate := range candidates { + switch candidate.Kind { + case description.ServerKindMongos, description.ServerKindRSPrimary, description.ServerKindStandalone: + selected++ + } + } + + // Append candidates to the results slice. + result := make([]description.Server, 0, selected) + for _, candidate := range candidates { + switch candidate.Kind { + case description.ServerKindMongos, description.ServerKindRSPrimary, description.ServerKindStandalone: + result = append(result, candidate) + } + } + return result, nil + } +} + +// Func is a function that can be used as a ServerSelector. +type Func func(description.Topology, []description.Server) ([]description.Server, error) + +// SelectServer implements the ServerSelector interface. +func (ssf Func) SelectServer( + t description.Topology, + s []description.Server, +) ([]description.Server, error) { + return ssf(t, s) +} + +func verifyMaxStaleness(rp *readpref.ReadPref, topo description.Topology) error { + maxStaleness, set := rp.MaxStaleness() + if !set { + return nil + } + + if maxStaleness < 90*time.Second { + return fmt.Errorf("max staleness (%s) must be greater than or equal to 90s", maxStaleness) + } + + if len(topo.Servers) < 1 { + // Maybe we should return an error here instead? + return nil + } + + // we'll assume all candidates have the same heartbeat interval. + s := topo.Servers[0] + idleWritePeriod := 10 * time.Second + + if maxStaleness < s.HeartbeatInterval+idleWritePeriod { + return fmt.Errorf( + "max staleness (%s) must be greater than or equal to the heartbeat interval (%s) plus idle write period (%s)", + maxStaleness, s.HeartbeatInterval, idleWritePeriod, + ) + } + + return nil +} + +func selectByKind(candidates []description.Server, kind description.ServerKind) []description.Server { + // Record the indices of viable candidates first and then append those to the returned slice + // to avoid appending costly Server structs directly as an optimization. + viableIndexes := make([]int, 0, len(candidates)) + for i, s := range candidates { + if s.Kind == kind { + viableIndexes = append(viableIndexes, i) + } + } + if len(viableIndexes) == len(candidates) { + return candidates + } + result := make([]description.Server, len(viableIndexes)) + for i, idx := range viableIndexes { + result[i] = candidates[idx] + } + return result +} + +func selectSecondaries(rp *readpref.ReadPref, candidates []description.Server) []description.Server { + secondaries := selectByKind(candidates, description.ServerKindRSSecondary) + if len(secondaries) == 0 { + return secondaries + } + if maxStaleness, set := rp.MaxStaleness(); set { + primaries := selectByKind(candidates, description.ServerKindRSPrimary) + if len(primaries) == 0 { + baseTime := secondaries[0].LastWriteTime + for i := 1; i < len(secondaries); i++ { + if secondaries[i].LastWriteTime.After(baseTime) { + baseTime = secondaries[i].LastWriteTime + } + } + + var selected []description.Server + for _, secondary := range secondaries { + estimatedStaleness := baseTime.Sub(secondary.LastWriteTime) + secondary.HeartbeatInterval + if estimatedStaleness <= maxStaleness { + selected = append(selected, secondary) + } + } + + return selected + } + + primary := primaries[0] + + var selected []description.Server + for _, secondary := range secondaries { + estimatedStaleness := secondary.LastUpdateTime.Sub(secondary.LastWriteTime) - + primary.LastUpdateTime.Sub(primary.LastWriteTime) + secondary.HeartbeatInterval + if estimatedStaleness <= maxStaleness { + selected = append(selected, secondary) + } + } + return selected + } + + return secondaries +} + +func selectByTagSet(candidates []description.Server, tagSets []tag.Set) []description.Server { + if len(tagSets) == 0 { + return candidates + } + + for _, ts := range tagSets { + // If this tag set is empty, we can take a fast path because the empty list + // is a subset of all tag sets, so all candidate servers will be selected. + if len(ts) == 0 { + return candidates + } + + var results []description.Server + for _, s := range candidates { + // ts is non-empty, so only servers with a non-empty set of tags need to be checked. + if len(s.Tags) > 0 && s.Tags.ContainsAll(ts) { + results = append(results, s) + } + } + + if len(results) > 0 { + return results + } + } + + return []description.Server{} +} + +func selectForReplicaSet( + rp *readpref.ReadPref, + isOutputAggregate bool, + topo description.Topology, + candidates []description.Server, +) ([]description.Server, error) { + if err := verifyMaxStaleness(rp, topo); err != nil { + return nil, err + } + + // If underlying operation is an aggregate with an output stage, only apply read preference + // if all candidates are 5.0+. Otherwise, operate under primary read preference. + if isOutputAggregate { + for _, s := range candidates { + if s.WireVersion.Max < 13 { + return selectByKind(candidates, description.ServerKindRSPrimary), nil + } + } + } + + switch rp.Mode() { + case readpref.PrimaryMode: + return selectByKind(candidates, description.ServerKindRSPrimary), nil + case readpref.PrimaryPreferredMode: + selected := selectByKind(candidates, description.ServerKindRSPrimary) + + if len(selected) == 0 { + selected = selectSecondaries(rp, candidates) + return selectByTagSet(selected, rp.TagSets()), nil + } + + return selected, nil + case readpref.SecondaryPreferredMode: + selected := selectSecondaries(rp, candidates) + selected = selectByTagSet(selected, rp.TagSets()) + if len(selected) > 0 { + return selected, nil + } + return selectByKind(candidates, description.ServerKindRSPrimary), nil + case readpref.SecondaryMode: + selected := selectSecondaries(rp, candidates) + return selectByTagSet(selected, rp.TagSets()), nil + case readpref.NearestMode: + selected := selectByKind(candidates, description.ServerKindRSPrimary) + selected = append(selected, selectSecondaries(rp, candidates)...) + return selectByTagSet(selected, rp.TagSets()), nil + } + + return nil, fmt.Errorf("unsupported mode: %d", rp.Mode()) +} diff --git a/internal/serverselector/server_selector_test.go b/internal/serverselector/server_selector_test.go new file mode 100644 index 0000000000..a8f212aeca --- /dev/null +++ b/internal/serverselector/server_selector_test.go @@ -0,0 +1,1278 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package serverselector + +import ( + "errors" + "io/ioutil" + "path" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/driverutil" + "go.mongodb.org/mongo-driver/internal/require" + "go.mongodb.org/mongo-driver/internal/spectest" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/mongo/readpref" + "go.mongodb.org/mongo-driver/tag" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" +) + +type lastWriteDate struct { + LastWriteDate int64 `bson:"lastWriteDate"` +} + +type serverDesc struct { + Address string `bson:"address"` + AverageRTTMS *int `bson:"avg_rtt_ms"` + MaxWireVersion *int32 `bson:"maxWireVersion"` + LastUpdateTime *int `bson:"lastUpdateTime"` + LastWrite *lastWriteDate `bson:"lastWrite"` + Type string `bson:"type"` + Tags map[string]string `bson:"tags"` +} + +type topDesc struct { + Type string `bson:"type"` + Servers []*serverDesc `bson:"servers"` +} + +type readPref struct { + MaxStaleness *int `bson:"maxStalenessSeconds"` + Mode string `bson:"mode"` + TagSets []map[string]string `bson:"tag_sets"` +} + +type testCase struct { + TopologyDescription topDesc `bson:"topology_description"` + Operation string `bson:"operation"` + ReadPreference readPref `bson:"read_preference"` + SuitableServers []*serverDesc `bson:"suitable_servers"` + InLatencyWindow []*serverDesc `bson:"in_latency_window"` + HeartbeatFrequencyMS *int `bson:"heartbeatFrequencyMS"` + Error *bool +} + +func serverKindFromString(t *testing.T, s string) description.ServerKind { + t.Helper() + + switch s { + case "Standalone": + return description.ServerKindStandalone + case "RSOther": + return description.ServerKindRSMember + case "RSPrimary": + return description.ServerKindRSPrimary + case "RSSecondary": + return description.ServerKindRSSecondary + case "RSArbiter": + return description.ServerKindRSArbiter + case "RSGhost": + return description.ServerKindRSGhost + case "Mongos": + return description.ServerKindMongos + case "LoadBalancer": + return description.ServerKindLoadBalancer + case "PossiblePrimary", "Unknown": + // Go does not have a PossiblePrimary server type and per the SDAM spec, this type is synonymous with Unknown. + return description.Unknown + default: + t.Fatalf("unrecognized server kind: %q", s) + } + + return description.Unknown +} + +func topologyKindFromString(t *testing.T, s string) description.TopologyKind { + t.Helper() + + switch s { + case "Single": + return description.TopologyKindSingle + case "ReplicaSet": + return description.TopologyKindReplicaSet + case "ReplicaSetNoPrimary": + return description.TopologyKindReplicaSetNoPrimary + case "ReplicaSetWithPrimary": + return description.TopologyKindReplicaSetWithPrimary + case "Sharded": + return description.TopologyKindSharded + case "LoadBalanced": + return description.TopologyKindLoadBalanced + case "Unknown": + return description.Unknown + default: + t.Fatalf("unrecognized topology kind: %q", s) + } + + return description.Unknown +} + +func anyTagsInSets(sets []tag.Set) bool { + for _, set := range sets { + if len(set) > 0 { + return true + } + } + + return false +} + +func findServerByAddress(servers []description.Server, address string) description.Server { + for _, server := range servers { + if server.Addr.String() == address { + return server + } + } + + return description.Server{} +} + +func compareServers(t *testing.T, expected []*serverDesc, actual []description.Server) { + require.Equal(t, len(expected), len(actual)) + + for _, expectedServer := range expected { + actualServer := findServerByAddress(actual, expectedServer.Address) + require.NotNil(t, actualServer) + + if expectedServer.AverageRTTMS != nil { + require.Equal(t, *expectedServer.AverageRTTMS, int(actualServer.AverageRTT/time.Millisecond)) + } + + require.Equal(t, expectedServer.Type, actualServer.Kind.String()) + + require.Equal(t, len(expectedServer.Tags), len(actualServer.Tags)) + for _, actualTag := range actualServer.Tags { + expectedTag, ok := expectedServer.Tags[actualTag.Name] + require.True(t, ok) + require.Equal(t, expectedTag, actualTag.Value) + } + } +} + +const maxStalenessTestsDir = "../../testdata/max-staleness" + +// Test case for all max staleness spec tests. +func TestMaxStalenessSpec(t *testing.T) { + for _, topology := range [...]string{ + "ReplicaSetNoPrimary", + "ReplicaSetWithPrimary", + "Sharded", + "Single", + "Unknown", + } { + for _, file := range spectest.FindJSONFilesInDir(t, + path.Join(maxStalenessTestsDir, topology)) { + + runTest(t, maxStalenessTestsDir, topology, file) + } + } +} + +const selectorTestsDir = "../../testdata/server-selection/server_selection" + +func selectServers(t *testing.T, test *testCase) error { + servers := make([]description.Server, 0, len(test.TopologyDescription.Servers)) + + // Times in the JSON files are given as offsets from an unspecified time, but the driver + // stores the lastWrite field as a timestamp, so we arbitrarily choose the current time + // as the base to offset from. + baseTime := time.Now() + + for _, serverDescription := range test.TopologyDescription.Servers { + server := description.Server{ + Addr: address.Address(serverDescription.Address), + Kind: serverKindFromString(t, serverDescription.Type), + } + + if serverDescription.AverageRTTMS != nil { + server.AverageRTT = time.Duration(*serverDescription.AverageRTTMS) * time.Millisecond + server.AverageRTTSet = true + } + + if test.HeartbeatFrequencyMS != nil { + server.HeartbeatInterval = time.Duration(*test.HeartbeatFrequencyMS) * time.Millisecond + } + + if serverDescription.LastUpdateTime != nil { + ms := int64(*serverDescription.LastUpdateTime) + server.LastUpdateTime = time.Unix(ms/1e3, ms%1e3/1e6) + } + + if serverDescription.LastWrite != nil { + i := serverDescription.LastWrite.LastWriteDate + + timeWithOffset := baseTime.Add(time.Duration(i) * time.Millisecond) + server.LastWriteTime = timeWithOffset + } + + if serverDescription.MaxWireVersion != nil { + versionRange := driverutil.NewVersionRange(0, *serverDescription.MaxWireVersion) + server.WireVersion = &versionRange + } + + if serverDescription.Tags != nil { + server.Tags = tag.NewTagSetFromMap(serverDescription.Tags) + } + + if test.ReadPreference.MaxStaleness != nil && server.WireVersion == nil { + server.WireVersion = &description.VersionRange{Max: 21} + } + + servers = append(servers, server) + } + + c := description.Topology{ + Kind: topologyKindFromString(t, test.TopologyDescription.Type), + Servers: servers, + } + + if len(test.ReadPreference.Mode) == 0 { + test.ReadPreference.Mode = "Primary" + } + + readprefMode, err := readpref.ModeFromString(test.ReadPreference.Mode) + if err != nil { + return err + } + + options := make([]readpref.Option, 0, 1) + + tagSets := tag.NewTagSetsFromMaps(test.ReadPreference.TagSets) + if anyTagsInSets(tagSets) { + options = append(options, readpref.WithTagSets(tagSets...)) + } + + if test.ReadPreference.MaxStaleness != nil { + s := time.Duration(*test.ReadPreference.MaxStaleness) * time.Second + options = append(options, readpref.WithMaxStaleness(s)) + } + + rp, err := readpref.New(readprefMode, options...) + if err != nil { + return err + } + + var selector description.ServerSelector + + selector = &ReadPref{ReadPref: rp} + if test.Operation == "write" { + selector = &Composite{ + Selectors: []description.ServerSelector{&Write{}, selector}, + } + } + + result, err := selector.SelectServer(c, c.Servers) + if err != nil { + return err + } + + compareServers(t, test.SuitableServers, result) + + latencySelector := &Latency{Latency: time.Duration(15) * time.Millisecond} + selector = &Composite{ + Selectors: []description.ServerSelector{selector, latencySelector}, + } + + result, err = selector.SelectServer(c, c.Servers) + if err != nil { + return err + } + + compareServers(t, test.InLatencyWindow, result) + + return nil +} + +func runTest(t *testing.T, testsDir string, directory string, filename string) { + filepath := path.Join(testsDir, directory, filename) + content, err := ioutil.ReadFile(filepath) + require.NoError(t, err) + + // Remove ".json" from filename. + filename = filename[:len(filename)-5] + testName := directory + "/" + filename + ":" + + t.Run(testName, func(t *testing.T) { + var test testCase + require.NoError(t, bson.UnmarshalExtJSON(content, true, &test)) + + err := selectServers(t, &test) + + if test.Error == nil || !*test.Error { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) +} + +// Test case for all SDAM spec tests. +func TestServerSelectionSpec(t *testing.T) { + for _, topology := range [...]string{ + "ReplicaSetNoPrimary", + "ReplicaSetWithPrimary", + "Sharded", + "Single", + "Unknown", + "LoadBalanced", + } { + for _, subdir := range [...]string{"read", "write"} { + subdirPath := path.Join(topology, subdir) + + for _, file := range spectest.FindJSONFilesInDir(t, + path.Join(selectorTestsDir, subdirPath)) { + + runTest(t, selectorTestsDir, subdirPath, file) + } + } + } +} + +func TestServerSelection(t *testing.T) { + noerr := func(t *testing.T, err error) { + if err != nil { + t.Errorf("Unepexted error: %v", err) + t.FailNow() + } + } + + t.Run("WriteSelector", func(t *testing.T) { + testCases := []struct { + name string + desc description.Topology + start int + end int + }{ + { + name: "ReplicaSetWithPrimary", + desc: description.Topology{ + Kind: description.TopologyKindReplicaSetWithPrimary, + Servers: []description.Server{ + {Addr: address.Address("localhost:27017"), Kind: description.ServerKindRSPrimary}, + {Addr: address.Address("localhost:27018"), Kind: description.ServerKindRSSecondary}, + {Addr: address.Address("localhost:27019"), Kind: description.ServerKindRSSecondary}, + }, + }, + start: 0, + end: 1, + }, + { + name: "ReplicaSetNoPrimary", + desc: description.Topology{ + Kind: description.TopologyKindReplicaSetNoPrimary, + Servers: []description.Server{ + {Addr: address.Address("localhost:27018"), Kind: description.ServerKindRSSecondary}, + {Addr: address.Address("localhost:27019"), Kind: description.ServerKindRSSecondary}, + }, + }, + start: 0, + end: 0, + }, + { + name: "Sharded", + desc: description.Topology{ + Kind: description.TopologyKindSharded, + Servers: []description.Server{ + {Addr: address.Address("localhost:27018"), Kind: description.ServerKindMongos}, + {Addr: address.Address("localhost:27019"), Kind: description.ServerKindMongos}, + }, + }, + start: 0, + end: 2, + }, + { + name: "Single", + desc: description.Topology{ + Kind: description.TopologyKindSingle, + Servers: []description.Server{ + {Addr: address.Address("localhost:27018"), Kind: description.ServerKindStandalone}, + }, + }, + start: 0, + end: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := (&Write{}).SelectServer(tc.desc, tc.desc.Servers) + noerr(t, err) + if len(result) != tc.end-tc.start { + t.Errorf("Incorrect number of servers selected. got %d; want %d", len(result), tc.end-tc.start) + } + if diff := cmp.Diff(result, tc.desc.Servers[tc.start:tc.end]); diff != "" { + t.Errorf("Incorrect servers selected (-got +want):\n%s", diff) + } + }) + } + }) + t.Run("LatencySelector", func(t *testing.T) { + testCases := []struct { + name string + desc description.Topology + start int + end int + }{ + { + name: "NoRTTSet", + desc: description.Topology{ + Servers: []description.Server{ + {Addr: address.Address("localhost:27017")}, + {Addr: address.Address("localhost:27018")}, + {Addr: address.Address("localhost:27019")}, + }, + }, + start: 0, + end: 3, + }, + { + name: "MultipleServers PartialNoRTTSet", + desc: description.Topology{ + Servers: []description.Server{ + {Addr: address.Address("localhost:27017"), AverageRTT: 5 * time.Second, AverageRTTSet: true}, + {Addr: address.Address("localhost:27018"), AverageRTT: 10 * time.Second, AverageRTTSet: true}, + {Addr: address.Address("localhost:27019")}, + }, + }, + start: 0, + end: 2, + }, + { + name: "MultipleServers", + desc: description.Topology{ + Servers: []description.Server{ + {Addr: address.Address("localhost:27017"), AverageRTT: 5 * time.Second, AverageRTTSet: true}, + {Addr: address.Address("localhost:27018"), AverageRTT: 10 * time.Second, AverageRTTSet: true}, + {Addr: address.Address("localhost:27019"), AverageRTT: 26 * time.Second, AverageRTTSet: true}, + }, + }, + start: 0, + end: 2, + }, + { + name: "No Servers", + desc: description.Topology{Servers: []description.Server{}}, + start: 0, + end: 0, + }, + { + name: "1 Server", + desc: description.Topology{ + Servers: []description.Server{ + {Addr: address.Address("localhost:27017"), AverageRTT: 26 * time.Second, AverageRTTSet: true}, + }, + }, + start: 0, + end: 1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result, err := (&Latency{Latency: 20 * time.Second}).SelectServer(tc.desc, tc.desc.Servers) + noerr(t, err) + if len(result) != tc.end-tc.start { + t.Errorf("Incorrect number of servers selected. got %d; want %d", len(result), tc.end-tc.start) + } + if diff := cmp.Diff(result, tc.desc.Servers[tc.start:tc.end]); diff != "" { + t.Errorf("Incorrect servers selected (-got +want):\n%s", diff) + } + }) + } + }) +} + +var readPrefTestPrimary = description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindRSPrimary, + Tags: tag.Set{tag.Tag{Name: "a", Value: "1"}}, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, +} +var readPrefTestSecondary1 = description.Server{ + Addr: address.Address("localhost:27018"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 13, 58, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindRSSecondary, + Tags: tag.Set{tag.Tag{Name: "a", Value: "1"}}, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, +} +var readPrefTestSecondary2 = description.Server{ + Addr: address.Address("localhost:27018"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindRSSecondary, + Tags: tag.Set{tag.Tag{Name: "a", Value: "2"}}, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, +} +var readPrefTestTopology = description.Topology{ + Kind: description.TopologyKindReplicaSetWithPrimary, + Servers: []description.Server{readPrefTestPrimary, readPrefTestSecondary1, readPrefTestSecondary2}, +} + +func TestSelector_Sharded(t *testing.T) { + t.Parallel() + + subject := readpref.Primary() + + s := description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindMongos, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + c := description.Topology{ + Kind: description.TopologyKindSharded, + Servers: []description.Server{s}, + } + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(c, c.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{s}, result) +} + +func BenchmarkLatencySelector(b *testing.B) { + for _, bcase := range []struct { + name string + serversHook func(servers []description.Server) + }{ + { + name: "AllFit", + serversHook: func(servers []description.Server) {}, + }, + { + name: "AllButOneFit", + serversHook: func(servers []description.Server) { + servers[0].AverageRTT = 2 * time.Second + }, + }, + { + name: "HalfFit", + serversHook: func(servers []description.Server) { + for i := 0; i < len(servers); i += 2 { + servers[i].AverageRTT = 2 * time.Second + } + }, + }, + { + name: "OneFit", + serversHook: func(servers []description.Server) { + for i := 1; i < len(servers); i++ { + servers[i].AverageRTT = 2 * time.Second + } + }, + }, + } { + bcase := bcase + + b.Run(bcase.name, func(b *testing.B) { + s := description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindMongos, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + AverageRTTSet: true, + AverageRTT: time.Second, + } + servers := make([]description.Server, 100) + for i := 0; i < len(servers); i++ { + servers[i] = s + } + bcase.serversHook(servers) + //this will make base 1 sec latency < min (0.5) + conf (1) + //and high latency 2 higher than the threshold + servers[99].AverageRTT = 500 * time.Millisecond + c := description.Topology{ + Kind: description.TopologyKindSharded, + Servers: servers, + } + + b.ResetTimer() + b.RunParallel(func(p *testing.PB) { + b.ReportAllocs() + for p.Next() { + _, _ = (&Latency{Latency: time.Second}).SelectServer(c, c.Servers) + } + }) + }) + } +} + +func BenchmarkSelector_Sharded(b *testing.B) { + for _, bcase := range []struct { + name string + serversHook func(servers []description.Server) + }{ + { + name: "AllFit", + serversHook: func(servers []description.Server) {}, + }, + { + name: "AllButOneFit", + serversHook: func(servers []description.Server) { + servers[0].Kind = description.ServerKindLoadBalancer + }, + }, + { + name: "HalfFit", + serversHook: func(servers []description.Server) { + for i := 0; i < len(servers); i += 2 { + servers[i].Kind = description.ServerKindLoadBalancer + } + }, + }, + { + name: "OneFit", + serversHook: func(servers []description.Server) { + for i := 1; i < len(servers); i++ { + servers[i].Kind = description.ServerKindLoadBalancer + } + }, + }, + } { + bcase := bcase + + b.Run(bcase.name, func(b *testing.B) { + subject := readpref.Primary() + + s := description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindMongos, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + servers := make([]description.Server, 100) + for i := 0; i < len(servers); i++ { + servers[i] = s + } + bcase.serversHook(servers) + c := description.Topology{ + Kind: description.TopologyKindSharded, + Servers: servers, + } + + b.ResetTimer() + b.RunParallel(func(p *testing.PB) { + b.ReportAllocs() + for p.Next() { + _, _ = (&ReadPref{ReadPref: subject}).SelectServer(c, c.Servers) + } + }) + }) + } +} + +func Benchmark_SelectServer_SelectServer(b *testing.B) { + topology := description.Topology{Kind: description.TopologyKindReplicaSet} // You can change the topology as needed + candidates := []description.Server{ + {Kind: description.ServerKindMongos}, + {Kind: description.ServerKindRSPrimary}, + {Kind: description.ServerKindStandalone}, + } + + selector := &Write{} // Assuming this is the receiver type + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + _, err := selector.SelectServer(topology, candidates) + if err != nil { + b.Fatalf("Error selecting server: %v", err) + } + } +} + +func TestSelector_Single(t *testing.T) { + t.Parallel() + + subject := readpref.Primary() + + s := description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindMongos, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + c := description.Topology{ + Kind: description.TopologyKindSingle, + Servers: []description.Server{s}, + } + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(c, c.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{s}, result) +} + +func TestSelector_Primary(t *testing.T) { + t.Parallel() + + subject := readpref.Primary() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_Primary_with_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.Primary() + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 0) +} + +func TestSelector_PrimaryPreferred(t *testing.T) { + t.Parallel() + + subject := readpref.PrimaryPreferred() + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_PrimaryPreferred_ignores_tags(t *testing.T) { + t.Parallel() + + subject := readpref.PrimaryPreferred( + readpref.WithTags("a", "2"), + ) + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_PrimaryPreferred_with_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.PrimaryPreferred() + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}, result) +} + +func TestSelector_PrimaryPreferred_with_no_primary_and_tags(t *testing.T) { + t.Parallel() + + subject := readpref.PrimaryPreferred( + readpref.WithTags("a", "2"), + ) + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_PrimaryPreferred_with_maxStaleness(t *testing.T) { + t.Parallel() + + subject := readpref.PrimaryPreferred( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_PrimaryPreferred_with_maxStaleness_and_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.PrimaryPreferred( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}). + SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_SecondaryPreferred(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}, result) +} + +func TestSelector_SecondaryPreferred_with_tags(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred( + readpref.WithTags("a", "2"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_SecondaryPreferred_with_tags_that_do_not_match(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred( + readpref.WithTags("a", "3"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_SecondaryPreferred_with_tags_that_do_not_match_and_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred( + readpref.WithTags("a", "3"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 0) +} + +func TestSelector_SecondaryPreferred_with_no_secondaries(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestPrimary}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_SecondaryPreferred_with_no_secondaries_or_primary(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{}) + + require.NoError(t, err) + require.Len(t, result, 0) +} + +func TestSelector_SecondaryPreferred_with_maxStaleness(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_SecondaryPreferred_with_maxStaleness_and_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.SecondaryPreferred( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_Secondary(t *testing.T) { + t.Parallel() + + subject := readpref.Secondary() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}, result) +} + +func TestSelector_Secondary_with_tags(t *testing.T) { + t.Parallel() + + subject := readpref.Secondary( + readpref.WithTags("a", "2"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_Secondary_with_empty_tag_set(t *testing.T) { + t.Parallel() + + primaryNoTags := description.Server{ + Addr: address.Address("localhost:27017"), + Kind: description.ServerKindRSPrimary, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + firstSecondaryNoTags := description.Server{ + Addr: address.Address("localhost:27018"), + Kind: description.ServerKindRSSecondary, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + secondSecondaryNoTags := description.Server{ + Addr: address.Address("localhost:27019"), + Kind: description.ServerKindRSSecondary, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + topologyNoTags := description.Topology{ + Kind: description.TopologyKindReplicaSetWithPrimary, + Servers: []description.Server{primaryNoTags, firstSecondaryNoTags, secondSecondaryNoTags}, + } + + nonMatchingSet := tag.Set{ + {Name: "foo", Value: "bar"}, + } + emptyTagSet := tag.Set{} + rp := readpref.Secondary( + readpref.WithTagSets(nonMatchingSet, emptyTagSet), + ) + + result, err := (&ReadPref{ReadPref: rp}).SelectServer(topologyNoTags, topologyNoTags.Servers) + assert.Nil(t, err, "SelectServer error: %v", err) + expectedResult := []description.Server{firstSecondaryNoTags, secondSecondaryNoTags} + assert.Equal(t, expectedResult, result, "expected result %v, got %v", expectedResult, result) +} + +func TestSelector_Secondary_with_tags_that_do_not_match(t *testing.T) { + t.Parallel() + + subject := readpref.Secondary( + readpref.WithTags("a", "3"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 0) +} + +func TestSelector_Secondary_with_no_secondaries(t *testing.T) { + t.Parallel() + + subject := readpref.Secondary() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestPrimary}) + + require.NoError(t, err) + require.Len(t, result, 0) +} + +func TestSelector_Secondary_with_maxStaleness(t *testing.T) { + t.Parallel() + + subject := readpref.Secondary( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_Secondary_with_maxStaleness_and_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.Secondary( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_Nearest(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 3) + require.Equal(t, []description.Server{readPrefTestPrimary, readPrefTestSecondary1, readPrefTestSecondary2}, result) +} + +func TestSelector_Nearest_with_tags(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest( + readpref.WithTags("a", "1"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, []description.Server{readPrefTestPrimary, readPrefTestSecondary1}, result) +} + +func TestSelector_Nearest_with_tags_that_do_not_match(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest( + readpref.WithTags("a", "3"), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 0) +} + +func TestSelector_Nearest_with_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}, result) +} + +func TestSelector_Nearest_with_no_secondaries(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest() + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestPrimary}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestPrimary}, result) +} + +func TestSelector_Nearest_with_maxStaleness(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) + + require.NoError(t, err) + require.Len(t, result, 2) + require.Equal(t, []description.Server{readPrefTestPrimary, readPrefTestSecondary2}, result) +} + +func TestSelector_Nearest_with_maxStaleness_and_no_primary(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest( + readpref.WithMaxStaleness(time.Duration(90) * time.Second), + ) + + result, err := (&ReadPref{ReadPref: subject}).SelectServer(readPrefTestTopology, []description.Server{readPrefTestSecondary1, readPrefTestSecondary2}) + + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, []description.Server{readPrefTestSecondary2}, result) +} + +func TestSelector_Max_staleness_is_less_than_90_seconds(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest( + readpref.WithMaxStaleness(time.Duration(50) * time.Second), + ) + + s := description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(10) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindRSPrimary, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + c := description.Topology{ + Kind: description.TopologyKindReplicaSetWithPrimary, + Servers: []description.Server{s}, + } + + _, err := (&ReadPref{ReadPref: subject}).SelectServer(c, c.Servers) + + require.Error(t, err) +} + +func TestSelector_Max_staleness_is_too_low(t *testing.T) { + t.Parallel() + + subject := readpref.Nearest( + readpref.WithMaxStaleness(time.Duration(100) * time.Second), + ) + + s := description.Server{ + Addr: address.Address("localhost:27017"), + HeartbeatInterval: time.Duration(100) * time.Second, + LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), + LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), + Kind: description.ServerKindRSPrimary, + WireVersion: &description.VersionRange{Min: 6, Max: 21}, + } + c := description.Topology{ + Kind: description.TopologyKindReplicaSetWithPrimary, + Servers: []description.Server{s}, + } + + _, err := (&ReadPref{ReadPref: subject}).SelectServer(c, c.Servers) + + require.Error(t, err) +} + +func TestEqualServers(t *testing.T) { + int64ToPtr := func(i64 int64) *int64 { return &i64 } + + t.Run("equals", func(t *testing.T) { + defaultServer := description.Server{} + // Only some of the Server fields affect equality + testCases := []struct { + name string + server description.Server + equal bool + }{ + {"empty", description.Server{}, true}, + {"address", description.Server{Addr: address.Address("foo")}, true}, + {"arbiters", description.Server{Arbiters: []string{"foo"}}, false}, + {"rtt", description.Server{AverageRTT: time.Second}, true}, + {"compression", description.Server{Compression: []string{"foo"}}, true}, + {"canonicalAddr", description.Server{CanonicalAddr: address.Address("foo")}, false}, + {"electionID", description.Server{ElectionID: bson.NewObjectID()}, false}, + {"heartbeatInterval", description.Server{HeartbeatInterval: time.Second}, true}, + {"hosts", description.Server{Hosts: []string{"foo"}}, false}, + {"lastError", description.Server{LastError: errors.New("foo")}, false}, + {"lastUpdateTime", description.Server{LastUpdateTime: time.Now()}, true}, + {"lastWriteTime", description.Server{LastWriteTime: time.Now()}, true}, + {"maxBatchCount", description.Server{MaxBatchCount: 1}, true}, + {"maxDocumentSize", description.Server{MaxDocumentSize: 1}, true}, + {"maxMessageSize", description.Server{MaxMessageSize: 1}, true}, + {"members", description.Server{Members: []address.Address{address.Address("foo")}}, true}, + {"passives", description.Server{Passives: []string{"foo"}}, false}, + {"passive", description.Server{Passive: true}, true}, + {"primary", description.Server{Primary: address.Address("foo")}, false}, + {"readOnly", description.Server{ReadOnly: true}, true}, + { + "sessionTimeoutMinutes", + description.Server{ + SessionTimeoutMinutes: int64ToPtr(1), + }, + false, + }, + {"setName", description.Server{SetName: "foo"}, false}, + {"setVersion", description.Server{SetVersion: 1}, false}, + {"tags", description.Server{Tags: tag.Set{tag.Tag{"foo", "bar"}}}, false}, + {"topologyVersion", description.Server{TopologyVersion: &description.TopologyVersion{bson.NewObjectID(), 0}}, false}, + {"kind", description.Server{Kind: description.ServerKindStandalone}, false}, + {"wireVersion", description.Server{WireVersion: &description.VersionRange{1, 2}}, false}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := driverutil.EqualServers(defaultServer, tc.server) + assert.Equal(t, actual, tc.equal, "expected %v, got %v", tc.equal, actual) + }) + } + }) +} + +func TestVersionRangeIncludes(t *testing.T) { + t.Parallel() + + subject := driverutil.NewVersionRange(1, 3) + + tests := []struct { + n int32 + expected bool + }{ + {0, false}, + {1, true}, + {2, true}, + {3, true}, + {4, false}, + {10, false}, + } + + for _, test := range tests { + actual := driverutil.VersionRangeIncludes(subject, test.n) + if actual != test.expected { + t.Fatalf("expected %v to be %t", test.n, test.expected) + } + } +} diff --git a/mongo/bulk_write.go b/mongo/bulk_write.go index 0591dde6a3..96bea7b8e0 100644 --- a/mongo/bulk_write.go +++ b/mongo/bulk_write.go @@ -11,11 +11,11 @@ import ( "errors" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/mongo/change_stream.go b/mongo/change_stream.go index 0df6ae03c7..cc051b5f08 100644 --- a/mongo/change_stream.go +++ b/mongo/change_stream.go @@ -16,12 +16,14 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/csot" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/driverutil" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" @@ -165,10 +167,12 @@ func newChangeStream(ctx context.Context, config changeStreamConfig, pipeline in registry: config.registry, streamType: config.streamType, options: mergeChangeStreamOptions(opts...), - selector: description.CompositeSelector([]description.ServerSelector{ - description.ReadPrefSelector(config.readPreference), - description.LatencySelector(config.client.localThreshold), - }), + selector: &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.ReadPref{ReadPref: config.readPreference}, + &serverselector.Latency{Latency: config.client.localThreshold}, + }, + }, cursorOptions: cursorOpts, } @@ -751,7 +755,7 @@ func (cs *ChangeStream) isResumableError() bool { } // For wire versions 9 and above, a server error is resumable if it has the ResumableChangeStreamError label. - if cs.wireVersion != nil && cs.wireVersion.Includes(minResumableLabelWireVersion) { + if cs.wireVersion != nil && driverutil.VersionRangeIncludes(*cs.wireVersion, minResumableLabelWireVersion) { return commandErr.HasErrorLabel(resumableErrorLabel) } diff --git a/mongo/change_stream_deployment.go b/mongo/change_stream_deployment.go index a84b43f05c..64f30095c8 100644 --- a/mongo/change_stream_deployment.go +++ b/mongo/change_stream_deployment.go @@ -9,8 +9,8 @@ package mongo import ( "context" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) diff --git a/mongo/client.go b/mongo/client.go index 36f6fbc35f..e68714c33f 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -17,14 +17,15 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/httputil" "go.mongodb.org/mongo-driver/internal/logger" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/internal/uuid" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt" mcopts "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt/options" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" @@ -452,8 +453,8 @@ func (c *Client) endSessions(ctx context.Context) { sessionIDs := c.sessionPool.IDSlice() op := operation.NewEndSessions(nil).ClusterClock(c.clock).Deployment(c.deployment). - ServerSelector(description.ReadPrefSelector(readpref.PrimaryPreferred())).CommandMonitor(c.monitor). - Database("admin").Crypt(c.cryptFLE).ServerAPI(c.serverAPI) + ServerSelector(&serverselector.ReadPref{ReadPref: readpref.PrimaryPreferred()}). + CommandMonitor(c.monitor).Database("admin").Crypt(c.cryptFLE).ServerAPI(c.serverAPI) totalNumIDs := len(sessionIDs) var currentBatch []bsoncore.Document @@ -704,10 +705,15 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ... return ListDatabasesResult{}, err } - selector := description.CompositeSelector([]description.ServerSelector{ - description.ReadPrefSelector(readpref.Primary()), - description.LatencySelector(c.localThreshold), - }) + var selector description.ServerSelector + + selector = &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.ReadPref{ReadPref: readpref.Primary()}, + &serverselector.Latency{Latency: c.localThreshold}, + }, + } + selector = makeReadPrefSelector(sess, selector, c.localThreshold) ldo := options.ListDatabases() diff --git a/mongo/collection.go b/mongo/collection.go index 1c9754f3c6..a7f7bc9cb9 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -16,13 +16,14 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/csfle" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -119,15 +120,19 @@ func newCollection(db *Database, name string, opts ...*options.CollectionOptions reg = collOpt.Registry } - readSelector := description.CompositeSelector([]description.ServerSelector{ - description.ReadPrefSelector(rp), - description.LatencySelector(db.client.localThreshold), - }) + readSelector := &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.ReadPref{ReadPref: rp}, + &serverselector.Latency{Latency: db.client.localThreshold}, + }, + } - writeSelector := description.CompositeSelector([]description.ServerSelector{ - description.WriteSelector(), - description.LatencySelector(db.client.localThreshold), - }) + writeSelector := &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.Write{}, + &serverselector.Latency{Latency: db.client.localThreshold}, + }, + } coll := &Collection{ client: db.client, @@ -182,10 +187,12 @@ func (coll *Collection) Clone(opts ...*options.CollectionOptions) *Collection { copyColl.registry = optsColl.Registry } - copyColl.readSelector = description.CompositeSelector([]description.ServerSelector{ - description.ReadPrefSelector(copyColl.readPreference), - description.LatencySelector(copyColl.client.localThreshold), - }) + copyColl.readSelector = &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.ReadPref{ReadPref: copyColl.readPreference}, + &serverselector.Latency{Latency: copyColl.client.localThreshold}, + }, + } return copyColl } @@ -2280,6 +2287,8 @@ type pinnedServerSelector struct { session *session.Client } +var _ description.ServerSelector = pinnedServerSelector{} + func (pss pinnedServerSelector) String() string { if pss.stringer == nil { return "" @@ -2292,10 +2301,10 @@ func (pss pinnedServerSelector) SelectServer( t description.Topology, svrs []description.Server, ) ([]description.Server, error) { - if pss.session != nil && pss.session.PinnedServer != nil { + if pss.session != nil && pss.session.PinnedServerAddr != nil { // If there is a pinned server, try to find it in the list of candidates. for _, candidate := range svrs { - if candidate.Addr == pss.session.PinnedServer.Addr { + if candidate.Addr == *pss.session.PinnedServerAddr { return []description.Server{candidate}, nil } } @@ -2306,7 +2315,7 @@ func (pss pinnedServerSelector) SelectServer( return pss.fallback.SelectServer(t, svrs) } -func makePinnedSelector(sess *session.Client, fallback description.ServerSelector) description.ServerSelector { +func makePinnedSelector(sess *session.Client, fallback description.ServerSelector) pinnedServerSelector { pss := pinnedServerSelector{ session: sess, fallback: fallback, @@ -2319,27 +2328,40 @@ func makePinnedSelector(sess *session.Client, fallback description.ServerSelecto return pss } -func makeReadPrefSelector(sess *session.Client, selector description.ServerSelector, localThreshold time.Duration) description.ServerSelector { +func makeReadPrefSelector( + sess *session.Client, + selector description.ServerSelector, + localThreshold time.Duration, +) pinnedServerSelector { if sess != nil && sess.TransactionRunning() { - selector = description.CompositeSelector([]description.ServerSelector{ - description.ReadPrefSelector(sess.CurrentRp), - description.LatencySelector(localThreshold), - }) + selector = &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.ReadPref{ReadPref: sess.CurrentRp}, + &serverselector.Latency{Latency: localThreshold}, + }, + } } return makePinnedSelector(sess, selector) } -func makeOutputAggregateSelector(sess *session.Client, rp *readpref.ReadPref, localThreshold time.Duration) description.ServerSelector { +func makeOutputAggregateSelector( + sess *session.Client, + rp *readpref.ReadPref, + localThreshold time.Duration, +) pinnedServerSelector { if sess != nil && sess.TransactionRunning() { // Use current transaction's read preference if available rp = sess.CurrentRp } - selector := description.CompositeSelector([]description.ServerSelector{ - description.OutputAggregateSelector(rp), - description.LatencySelector(localThreshold), - }) + selector := &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.ReadPref{ReadPref: rp, IsOutputAggregate: true}, + &serverselector.Latency{Latency: localThreshold}, + }, + } + return makePinnedSelector(sess, selector) } diff --git a/mongo/database.go b/mongo/database.go index dec6423c54..4748d3d2b0 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -14,13 +14,14 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/csfle" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -97,15 +98,19 @@ func newDatabase(client *Client, name string, opts ...*options.DatabaseOptions) registry: reg, } - db.readSelector = description.CompositeSelector([]description.ServerSelector{ - description.ReadPrefSelector(db.readPreference), - description.LatencySelector(db.client.localThreshold), - }) + db.readSelector = &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.ReadPref{ReadPref: db.readPreference}, + &serverselector.Latency{Latency: db.client.localThreshold}, + }, + } - db.writeSelector = description.CompositeSelector([]description.ServerSelector{ - description.WriteSelector(), - description.LatencySelector(db.client.localThreshold), - }) + db.writeSelector = &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.Write{}, + &serverselector.Latency{Latency: db.client.localThreshold}, + }, + } return db } @@ -189,11 +194,17 @@ func (db *Database) processRunCommand(ctx context.Context, cmd interface{}, if err != nil { return nil, sess, err } - readSelect := description.CompositeSelector([]description.ServerSelector{ - description.ReadPrefSelector(ro.ReadPreference), - description.LatencySelector(db.client.localThreshold), - }) - if sess != nil && sess.PinnedServer != nil { + + var readSelect description.ServerSelector + + readSelect = &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.ReadPref{ReadPref: ro.ReadPreference}, + &serverselector.Latency{Latency: db.client.localThreshold}, + }, + } + + if sess != nil && sess.PinnedServerAddr != nil { readSelect = makePinnedSelector(sess, readSelect) } @@ -436,10 +447,15 @@ func (db *Database) ListCollections(ctx context.Context, filter interface{}, opt return nil, err } - selector := description.CompositeSelector([]description.ServerSelector{ - description.ReadPrefSelector(readpref.Primary()), - description.LatencySelector(db.client.localThreshold), - }) + var selector description.ServerSelector + + selector = &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.ReadPref{ReadPref: readpref.Primary()}, + &serverselector.Latency{Latency: db.client.localThreshold}, + }, + } + selector = makeReadPrefSelector(sess, selector, db.client.localThreshold) lco := options.ListCollections() @@ -711,7 +727,7 @@ func (db *Database) createCollectionWithEncryptedFields(ctx context.Context, nam // That is OK. This wire version check is a best effort to inform users earlier if using a QEv2 driver with a QEv1 server. { const QEv2WireVersion = 21 - server, err := db.client.deployment.SelectServer(ctx, description.WriteSelector()) + server, err := db.client.deployment.SelectServer(ctx, &serverselector.Write{}) if err != nil { return fmt.Errorf("error selecting server to check maxWireVersion: %w", err) } diff --git a/mongo/description/description.go b/mongo/description/description.go deleted file mode 100644 index 68b61a6249..0000000000 --- a/mongo/description/description.go +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -// Package description contains types and functions for describing the state of MongoDB clusters. -package description - -// Unknown is an unknown server or topology kind. -const Unknown = 0 diff --git a/mongo/description/max_staleness_spec_test.go b/mongo/description/max_staleness_spec_test.go deleted file mode 100644 index 0bab617c6a..0000000000 --- a/mongo/description/max_staleness_spec_test.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -import ( - "path" - "testing" - - "go.mongodb.org/mongo-driver/internal/spectest" -) - -const maxStalenessTestsDir = "../../testdata/max-staleness" - -// Test case for all max staleness spec tests. -func TestMaxStalenessSpec(t *testing.T) { - for _, topology := range [...]string{ - "ReplicaSetNoPrimary", - "ReplicaSetWithPrimary", - "Sharded", - "Single", - "Unknown", - } { - for _, file := range spectest.FindJSONFilesInDir(t, - path.Join(maxStalenessTestsDir, topology)) { - - runTest(t, maxStalenessTestsDir, topology, file) - } - } -} diff --git a/mongo/description/selector_spec_test.go b/mongo/description/selector_spec_test.go deleted file mode 100644 index f5d2eb3291..0000000000 --- a/mongo/description/selector_spec_test.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -import ( - "path" - "testing" - - "go.mongodb.org/mongo-driver/internal/spectest" -) - -const selectorTestsDir = "../../testdata/server-selection/server_selection" - -// Test case for all SDAM spec tests. -func TestServerSelectionSpec(t *testing.T) { - for _, topology := range [...]string{ - "ReplicaSetNoPrimary", - "ReplicaSetWithPrimary", - "Sharded", - "Single", - "Unknown", - "LoadBalanced", - } { - for _, subdir := range [...]string{"read", "write"} { - subdirPath := path.Join(topology, subdir) - - for _, file := range spectest.FindJSONFilesInDir(t, - path.Join(selectorTestsDir, subdirPath)) { - - runTest(t, selectorTestsDir, subdirPath, file) - } - } - } -} diff --git a/mongo/description/selector_test.go b/mongo/description/selector_test.go deleted file mode 100644 index a3566783dd..0000000000 --- a/mongo/description/selector_test.go +++ /dev/null @@ -1,874 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -import ( - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "go.mongodb.org/mongo-driver/internal/assert" - "go.mongodb.org/mongo-driver/internal/require" - "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/readpref" - "go.mongodb.org/mongo-driver/tag" -) - -func TestServerSelection(t *testing.T) { - noerr := func(t *testing.T, err error) { - if err != nil { - t.Errorf("Unepexted error: %v", err) - t.FailNow() - } - } - - t.Run("WriteSelector", func(t *testing.T) { - testCases := []struct { - name string - desc Topology - start int - end int - }{ - { - name: "ReplicaSetWithPrimary", - desc: Topology{ - Kind: ReplicaSetWithPrimary, - Servers: []Server{ - {Addr: address.Address("localhost:27017"), Kind: RSPrimary}, - {Addr: address.Address("localhost:27018"), Kind: RSSecondary}, - {Addr: address.Address("localhost:27019"), Kind: RSSecondary}, - }, - }, - start: 0, - end: 1, - }, - { - name: "ReplicaSetNoPrimary", - desc: Topology{ - Kind: ReplicaSetNoPrimary, - Servers: []Server{ - {Addr: address.Address("localhost:27018"), Kind: RSSecondary}, - {Addr: address.Address("localhost:27019"), Kind: RSSecondary}, - }, - }, - start: 0, - end: 0, - }, - { - name: "Sharded", - desc: Topology{ - Kind: Sharded, - Servers: []Server{ - {Addr: address.Address("localhost:27018"), Kind: Mongos}, - {Addr: address.Address("localhost:27019"), Kind: Mongos}, - }, - }, - start: 0, - end: 2, - }, - { - name: "Single", - desc: Topology{ - Kind: Single, - Servers: []Server{ - {Addr: address.Address("localhost:27018"), Kind: Standalone}, - }, - }, - start: 0, - end: 1, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result, err := WriteSelector().SelectServer(tc.desc, tc.desc.Servers) - noerr(t, err) - if len(result) != tc.end-tc.start { - t.Errorf("Incorrect number of servers selected. got %d; want %d", len(result), tc.end-tc.start) - } - if diff := cmp.Diff(result, tc.desc.Servers[tc.start:tc.end]); diff != "" { - t.Errorf("Incorrect servers selected (-got +want):\n%s", diff) - } - }) - } - }) - t.Run("LatencySelector", func(t *testing.T) { - testCases := []struct { - name string - desc Topology - start int - end int - }{ - { - name: "NoRTTSet", - desc: Topology{ - Servers: []Server{ - {Addr: address.Address("localhost:27017")}, - {Addr: address.Address("localhost:27018")}, - {Addr: address.Address("localhost:27019")}, - }, - }, - start: 0, - end: 3, - }, - { - name: "MultipleServers PartialNoRTTSet", - desc: Topology{ - Servers: []Server{ - {Addr: address.Address("localhost:27017"), AverageRTT: 5 * time.Second, AverageRTTSet: true}, - {Addr: address.Address("localhost:27018"), AverageRTT: 10 * time.Second, AverageRTTSet: true}, - {Addr: address.Address("localhost:27019")}, - }, - }, - start: 0, - end: 2, - }, - { - name: "MultipleServers", - desc: Topology{ - Servers: []Server{ - {Addr: address.Address("localhost:27017"), AverageRTT: 5 * time.Second, AverageRTTSet: true}, - {Addr: address.Address("localhost:27018"), AverageRTT: 10 * time.Second, AverageRTTSet: true}, - {Addr: address.Address("localhost:27019"), AverageRTT: 26 * time.Second, AverageRTTSet: true}, - }, - }, - start: 0, - end: 2, - }, - { - name: "No Servers", - desc: Topology{Servers: []Server{}}, - start: 0, - end: 0, - }, - { - name: "1 Server", - desc: Topology{ - Servers: []Server{ - {Addr: address.Address("localhost:27017"), AverageRTT: 26 * time.Second, AverageRTTSet: true}, - }, - }, - start: 0, - end: 1, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result, err := LatencySelector(20*time.Second).SelectServer(tc.desc, tc.desc.Servers) - noerr(t, err) - if len(result) != tc.end-tc.start { - t.Errorf("Incorrect number of servers selected. got %d; want %d", len(result), tc.end-tc.start) - } - if diff := cmp.Diff(result, tc.desc.Servers[tc.start:tc.end]); diff != "" { - t.Errorf("Incorrect servers selected (-got +want):\n%s", diff) - } - }) - } - }) -} - -var readPrefTestPrimary = Server{ - Addr: address.Address("localhost:27017"), - HeartbeatInterval: time.Duration(10) * time.Second, - LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), - LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), - Kind: RSPrimary, - Tags: tag.Set{tag.Tag{Name: "a", Value: "1"}}, - WireVersion: &VersionRange{Min: 6, Max: 21}, -} -var readPrefTestSecondary1 = Server{ - Addr: address.Address("localhost:27018"), - HeartbeatInterval: time.Duration(10) * time.Second, - LastWriteTime: time.Date(2017, 2, 11, 13, 58, 0, 0, time.UTC), - LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), - Kind: RSSecondary, - Tags: tag.Set{tag.Tag{Name: "a", Value: "1"}}, - WireVersion: &VersionRange{Min: 6, Max: 21}, -} -var readPrefTestSecondary2 = Server{ - Addr: address.Address("localhost:27018"), - HeartbeatInterval: time.Duration(10) * time.Second, - LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), - LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), - Kind: RSSecondary, - Tags: tag.Set{tag.Tag{Name: "a", Value: "2"}}, - WireVersion: &VersionRange{Min: 6, Max: 21}, -} -var readPrefTestTopology = Topology{ - Kind: ReplicaSetWithPrimary, - Servers: []Server{readPrefTestPrimary, readPrefTestSecondary1, readPrefTestSecondary2}, -} - -func TestSelector_Sharded(t *testing.T) { - t.Parallel() - - subject := readpref.Primary() - - s := Server{ - Addr: address.Address("localhost:27017"), - HeartbeatInterval: time.Duration(10) * time.Second, - LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), - LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), - Kind: Mongos, - WireVersion: &VersionRange{Min: 6, Max: 21}, - } - c := Topology{ - Kind: Sharded, - Servers: []Server{s}, - } - - result, err := ReadPrefSelector(subject).SelectServer(c, c.Servers) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{s}, result) -} - -func BenchmarkLatencySelector(b *testing.B) { - for _, bcase := range []struct { - name string - serversHook func(servers []Server) - }{ - { - name: "AllFit", - serversHook: func(servers []Server) {}, - }, - { - name: "AllButOneFit", - serversHook: func(servers []Server) { - servers[0].AverageRTT = 2 * time.Second - }, - }, - { - name: "HalfFit", - serversHook: func(servers []Server) { - for i := 0; i < len(servers); i += 2 { - servers[i].AverageRTT = 2 * time.Second - } - }, - }, - { - name: "OneFit", - serversHook: func(servers []Server) { - for i := 1; i < len(servers); i++ { - servers[i].AverageRTT = 2 * time.Second - } - }, - }, - } { - bcase := bcase - - b.Run(bcase.name, func(b *testing.B) { - s := Server{ - Addr: address.Address("localhost:27017"), - HeartbeatInterval: time.Duration(10) * time.Second, - LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), - LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), - Kind: Mongos, - WireVersion: &VersionRange{Min: 6, Max: 21}, - AverageRTTSet: true, - AverageRTT: time.Second, - } - servers := make([]Server, 100) - for i := 0; i < len(servers); i++ { - servers[i] = s - } - bcase.serversHook(servers) - //this will make base 1 sec latency < min (0.5) + conf (1) - //and high latency 2 higher than the threshold - servers[99].AverageRTT = 500 * time.Millisecond - c := Topology{ - Kind: Sharded, - Servers: servers, - } - - b.ResetTimer() - b.RunParallel(func(p *testing.PB) { - b.ReportAllocs() - for p.Next() { - _, _ = LatencySelector(time.Second).SelectServer(c, c.Servers) - } - }) - }) - } -} - -func BenchmarkSelector_Sharded(b *testing.B) { - for _, bcase := range []struct { - name string - serversHook func(servers []Server) - }{ - { - name: "AllFit", - serversHook: func(servers []Server) {}, - }, - { - name: "AllButOneFit", - serversHook: func(servers []Server) { - servers[0].Kind = LoadBalancer - }, - }, - { - name: "HalfFit", - serversHook: func(servers []Server) { - for i := 0; i < len(servers); i += 2 { - servers[i].Kind = LoadBalancer - } - }, - }, - { - name: "OneFit", - serversHook: func(servers []Server) { - for i := 1; i < len(servers); i++ { - servers[i].Kind = LoadBalancer - } - }, - }, - } { - bcase := bcase - - b.Run(bcase.name, func(b *testing.B) { - subject := readpref.Primary() - - s := Server{ - Addr: address.Address("localhost:27017"), - HeartbeatInterval: time.Duration(10) * time.Second, - LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), - LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), - Kind: Mongos, - WireVersion: &VersionRange{Min: 6, Max: 21}, - } - servers := make([]Server, 100) - for i := 0; i < len(servers); i++ { - servers[i] = s - } - bcase.serversHook(servers) - c := Topology{ - Kind: Sharded, - Servers: servers, - } - - b.ResetTimer() - b.RunParallel(func(p *testing.PB) { - b.ReportAllocs() - for p.Next() { - _, _ = ReadPrefSelector(subject).SelectServer(c, c.Servers) - } - }) - }) - } -} - -func Benchmark_SelectServer_SelectServer(b *testing.B) { - topology := Topology{Kind: ReplicaSet} // You can change the topology as needed - candidates := []Server{ - {Kind: Mongos}, - {Kind: RSPrimary}, - {Kind: Standalone}, - } - - selector := writeServerSelector{} // Assuming this is the receiver type - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - _, err := selector.SelectServer(topology, candidates) - if err != nil { - b.Fatalf("Error selecting server: %v", err) - } - } -} - -func TestSelector_Single(t *testing.T) { - t.Parallel() - - subject := readpref.Primary() - - s := Server{ - Addr: address.Address("localhost:27017"), - HeartbeatInterval: time.Duration(10) * time.Second, - LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), - LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), - Kind: Mongos, - WireVersion: &VersionRange{Min: 6, Max: 21}, - } - c := Topology{ - Kind: Single, - Servers: []Server{s}, - } - - result, err := ReadPrefSelector(subject).SelectServer(c, c.Servers) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{s}, result) -} - -func TestSelector_Primary(t *testing.T) { - t.Parallel() - - subject := readpref.Primary() - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestPrimary}, result) -} - -func TestSelector_Primary_with_no_primary(t *testing.T) { - t.Parallel() - - subject := readpref.Primary() - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{readPrefTestSecondary1, readPrefTestSecondary2}) - - require.NoError(t, err) - require.Len(t, result, 0) -} - -func TestSelector_PrimaryPreferred(t *testing.T) { - t.Parallel() - - subject := readpref.PrimaryPreferred() - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestPrimary}, result) -} - -func TestSelector_PrimaryPreferred_ignores_tags(t *testing.T) { - t.Parallel() - - subject := readpref.PrimaryPreferred( - readpref.WithTags("a", "2"), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestPrimary}, result) -} - -func TestSelector_PrimaryPreferred_with_no_primary(t *testing.T) { - t.Parallel() - - subject := readpref.PrimaryPreferred() - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{readPrefTestSecondary1, readPrefTestSecondary2}) - - require.NoError(t, err) - require.Len(t, result, 2) - require.Equal(t, []Server{readPrefTestSecondary1, readPrefTestSecondary2}, result) -} - -func TestSelector_PrimaryPreferred_with_no_primary_and_tags(t *testing.T) { - t.Parallel() - - subject := readpref.PrimaryPreferred( - readpref.WithTags("a", "2"), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{readPrefTestSecondary1, readPrefTestSecondary2}) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestSecondary2}, result) -} - -func TestSelector_PrimaryPreferred_with_maxStaleness(t *testing.T) { - t.Parallel() - - subject := readpref.PrimaryPreferred( - readpref.WithMaxStaleness(time.Duration(90) * time.Second), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestPrimary}, result) -} - -func TestSelector_PrimaryPreferred_with_maxStaleness_and_no_primary(t *testing.T) { - t.Parallel() - - subject := readpref.PrimaryPreferred( - readpref.WithMaxStaleness(time.Duration(90) * time.Second), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{readPrefTestSecondary1, readPrefTestSecondary2}) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestSecondary2}, result) -} - -func TestSelector_SecondaryPreferred(t *testing.T) { - t.Parallel() - - subject := readpref.SecondaryPreferred() - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 2) - require.Equal(t, []Server{readPrefTestSecondary1, readPrefTestSecondary2}, result) -} - -func TestSelector_SecondaryPreferred_with_tags(t *testing.T) { - t.Parallel() - - subject := readpref.SecondaryPreferred( - readpref.WithTags("a", "2"), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestSecondary2}, result) -} - -func TestSelector_SecondaryPreferred_with_tags_that_do_not_match(t *testing.T) { - t.Parallel() - - subject := readpref.SecondaryPreferred( - readpref.WithTags("a", "3"), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestPrimary}, result) -} - -func TestSelector_SecondaryPreferred_with_tags_that_do_not_match_and_no_primary(t *testing.T) { - t.Parallel() - - subject := readpref.SecondaryPreferred( - readpref.WithTags("a", "3"), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{readPrefTestSecondary1, readPrefTestSecondary2}) - - require.NoError(t, err) - require.Len(t, result, 0) -} - -func TestSelector_SecondaryPreferred_with_no_secondaries(t *testing.T) { - t.Parallel() - - subject := readpref.SecondaryPreferred() - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{readPrefTestPrimary}) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestPrimary}, result) -} - -func TestSelector_SecondaryPreferred_with_no_secondaries_or_primary(t *testing.T) { - t.Parallel() - - subject := readpref.SecondaryPreferred() - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{}) - - require.NoError(t, err) - require.Len(t, result, 0) -} - -func TestSelector_SecondaryPreferred_with_maxStaleness(t *testing.T) { - t.Parallel() - - subject := readpref.SecondaryPreferred( - readpref.WithMaxStaleness(time.Duration(90) * time.Second), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestSecondary2}, result) -} - -func TestSelector_SecondaryPreferred_with_maxStaleness_and_no_primary(t *testing.T) { - t.Parallel() - - subject := readpref.SecondaryPreferred( - readpref.WithMaxStaleness(time.Duration(90) * time.Second), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{readPrefTestSecondary1, readPrefTestSecondary2}) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestSecondary2}, result) -} - -func TestSelector_Secondary(t *testing.T) { - t.Parallel() - - subject := readpref.Secondary() - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 2) - require.Equal(t, []Server{readPrefTestSecondary1, readPrefTestSecondary2}, result) -} - -func TestSelector_Secondary_with_tags(t *testing.T) { - t.Parallel() - - subject := readpref.Secondary( - readpref.WithTags("a", "2"), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestSecondary2}, result) -} - -func TestSelector_Secondary_with_empty_tag_set(t *testing.T) { - t.Parallel() - - primaryNoTags := Server{ - Addr: address.Address("localhost:27017"), - Kind: RSPrimary, - WireVersion: &VersionRange{Min: 6, Max: 21}, - } - firstSecondaryNoTags := Server{ - Addr: address.Address("localhost:27018"), - Kind: RSSecondary, - WireVersion: &VersionRange{Min: 6, Max: 21}, - } - secondSecondaryNoTags := Server{ - Addr: address.Address("localhost:27019"), - Kind: RSSecondary, - WireVersion: &VersionRange{Min: 6, Max: 21}, - } - topologyNoTags := Topology{ - Kind: ReplicaSetWithPrimary, - Servers: []Server{primaryNoTags, firstSecondaryNoTags, secondSecondaryNoTags}, - } - - nonMatchingSet := tag.Set{ - {Name: "foo", Value: "bar"}, - } - emptyTagSet := tag.Set{} - rp := readpref.Secondary( - readpref.WithTagSets(nonMatchingSet, emptyTagSet), - ) - - result, err := ReadPrefSelector(rp).SelectServer(topologyNoTags, topologyNoTags.Servers) - assert.Nil(t, err, "SelectServer error: %v", err) - expectedResult := []Server{firstSecondaryNoTags, secondSecondaryNoTags} - assert.Equal(t, expectedResult, result, "expected result %v, got %v", expectedResult, result) -} - -func TestSelector_Secondary_with_tags_that_do_not_match(t *testing.T) { - t.Parallel() - - subject := readpref.Secondary( - readpref.WithTags("a", "3"), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 0) -} - -func TestSelector_Secondary_with_no_secondaries(t *testing.T) { - t.Parallel() - - subject := readpref.Secondary() - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{readPrefTestPrimary}) - - require.NoError(t, err) - require.Len(t, result, 0) -} - -func TestSelector_Secondary_with_maxStaleness(t *testing.T) { - t.Parallel() - - subject := readpref.Secondary( - readpref.WithMaxStaleness(time.Duration(90) * time.Second), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestSecondary2}, result) -} - -func TestSelector_Secondary_with_maxStaleness_and_no_primary(t *testing.T) { - t.Parallel() - - subject := readpref.Secondary( - readpref.WithMaxStaleness(time.Duration(90) * time.Second), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{readPrefTestSecondary1, readPrefTestSecondary2}) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestSecondary2}, result) -} - -func TestSelector_Nearest(t *testing.T) { - t.Parallel() - - subject := readpref.Nearest() - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 3) - require.Equal(t, []Server{readPrefTestPrimary, readPrefTestSecondary1, readPrefTestSecondary2}, result) -} - -func TestSelector_Nearest_with_tags(t *testing.T) { - t.Parallel() - - subject := readpref.Nearest( - readpref.WithTags("a", "1"), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 2) - require.Equal(t, []Server{readPrefTestPrimary, readPrefTestSecondary1}, result) -} - -func TestSelector_Nearest_with_tags_that_do_not_match(t *testing.T) { - t.Parallel() - - subject := readpref.Nearest( - readpref.WithTags("a", "3"), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 0) -} - -func TestSelector_Nearest_with_no_primary(t *testing.T) { - t.Parallel() - - subject := readpref.Nearest() - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{readPrefTestSecondary1, readPrefTestSecondary2}) - - require.NoError(t, err) - require.Len(t, result, 2) - require.Equal(t, []Server{readPrefTestSecondary1, readPrefTestSecondary2}, result) -} - -func TestSelector_Nearest_with_no_secondaries(t *testing.T) { - t.Parallel() - - subject := readpref.Nearest() - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{readPrefTestPrimary}) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestPrimary}, result) -} - -func TestSelector_Nearest_with_maxStaleness(t *testing.T) { - t.Parallel() - - subject := readpref.Nearest( - readpref.WithMaxStaleness(time.Duration(90) * time.Second), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, readPrefTestTopology.Servers) - - require.NoError(t, err) - require.Len(t, result, 2) - require.Equal(t, []Server{readPrefTestPrimary, readPrefTestSecondary2}, result) -} - -func TestSelector_Nearest_with_maxStaleness_and_no_primary(t *testing.T) { - t.Parallel() - - subject := readpref.Nearest( - readpref.WithMaxStaleness(time.Duration(90) * time.Second), - ) - - result, err := ReadPrefSelector(subject).SelectServer(readPrefTestTopology, []Server{readPrefTestSecondary1, readPrefTestSecondary2}) - - require.NoError(t, err) - require.Len(t, result, 1) - require.Equal(t, []Server{readPrefTestSecondary2}, result) -} - -func TestSelector_Max_staleness_is_less_than_90_seconds(t *testing.T) { - t.Parallel() - - subject := readpref.Nearest( - readpref.WithMaxStaleness(time.Duration(50) * time.Second), - ) - - s := Server{ - Addr: address.Address("localhost:27017"), - HeartbeatInterval: time.Duration(10) * time.Second, - LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), - LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), - Kind: RSPrimary, - WireVersion: &VersionRange{Min: 6, Max: 21}, - } - c := Topology{ - Kind: ReplicaSetWithPrimary, - Servers: []Server{s}, - } - - _, err := ReadPrefSelector(subject).SelectServer(c, c.Servers) - - require.Error(t, err) -} - -func TestSelector_Max_staleness_is_too_low(t *testing.T) { - t.Parallel() - - subject := readpref.Nearest( - readpref.WithMaxStaleness(time.Duration(100) * time.Second), - ) - - s := Server{ - Addr: address.Address("localhost:27017"), - HeartbeatInterval: time.Duration(100) * time.Second, - LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), - LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), - Kind: RSPrimary, - WireVersion: &VersionRange{Min: 6, Max: 21}, - } - c := Topology{ - Kind: ReplicaSetWithPrimary, - Servers: []Server{s}, - } - - _, err := ReadPrefSelector(subject).SelectServer(c, c.Servers) - - require.Error(t, err) -} diff --git a/mongo/description/server_kind.go b/mongo/description/server_kind.go deleted file mode 100644 index b71d29d8b5..0000000000 --- a/mongo/description/server_kind.go +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -// ServerKind represents the type of a single server in a topology. -type ServerKind uint32 - -// These constants are the possible types of servers. -const ( - Standalone ServerKind = 1 - RSMember ServerKind = 2 - RSPrimary ServerKind = 4 + RSMember - RSSecondary ServerKind = 8 + RSMember - RSArbiter ServerKind = 16 + RSMember - RSGhost ServerKind = 32 + RSMember - Mongos ServerKind = 256 - LoadBalancer ServerKind = 512 -) - -// String returns a stringified version of the kind or "Unknown" if the kind is invalid. -func (kind ServerKind) String() string { - switch kind { - case Standalone: - return "Standalone" - case RSMember: - return "RSOther" - case RSPrimary: - return "RSPrimary" - case RSSecondary: - return "RSSecondary" - case RSArbiter: - return "RSArbiter" - case RSGhost: - return "RSGhost" - case Mongos: - return "Mongos" - case LoadBalancer: - return "LoadBalancer" - } - - return "Unknown" -} diff --git a/mongo/description/server_selector.go b/mongo/description/server_selector.go deleted file mode 100644 index 176f0fb53a..0000000000 --- a/mongo/description/server_selector.go +++ /dev/null @@ -1,420 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -import ( - "encoding/json" - "fmt" - "math" - "time" - - "go.mongodb.org/mongo-driver/mongo/readpref" - "go.mongodb.org/mongo-driver/tag" -) - -// ServerSelector is an interface implemented by types that can perform server selection given a topology description -// and list of candidate servers. The selector should filter the provided candidates list and return a subset that -// matches some criteria. -type ServerSelector interface { - SelectServer(Topology, []Server) ([]Server, error) -} - -// ServerSelectorFunc is a function that can be used as a ServerSelector. -type ServerSelectorFunc func(Topology, []Server) ([]Server, error) - -// SelectServer implements the ServerSelector interface. -func (ssf ServerSelectorFunc) SelectServer(t Topology, s []Server) ([]Server, error) { - return ssf(t, s) -} - -// serverSelectorInfo contains metadata concerning the server selector for the -// purpose of publication. -type serverSelectorInfo struct { - Type string - Data string `json:",omitempty"` - Selectors []serverSelectorInfo `json:",omitempty"` -} - -// String returns the JSON string representation of the serverSelectorInfo. -func (sss serverSelectorInfo) String() string { - bytes, _ := json.Marshal(sss) - - return string(bytes) -} - -// serverSelectorInfoGetter is an interface that defines an info() method to -// get the serverSelectorInfo. -type serverSelectorInfoGetter interface { - info() serverSelectorInfo -} - -type compositeSelector struct { - selectors []ServerSelector -} - -func (cs *compositeSelector) info() serverSelectorInfo { - csInfo := serverSelectorInfo{Type: "compositeSelector"} - - for _, sel := range cs.selectors { - if getter, ok := sel.(serverSelectorInfoGetter); ok { - csInfo.Selectors = append(csInfo.Selectors, getter.info()) - } - } - - return csInfo -} - -// String returns the JSON string representation of the compositeSelector. -func (cs *compositeSelector) String() string { - return cs.info().String() -} - -// CompositeSelector combines multiple selectors into a single selector by applying them in order to the candidates -// list. -// -// For example, if the initial candidates list is [s0, s1, s2, s3] and two selectors are provided where the first -// matches s0 and s1 and the second matches s1 and s2, the following would occur during server selection: -// -// 1. firstSelector([s0, s1, s2, s3]) -> [s0, s1] -// 2. secondSelector([s0, s1]) -> [s1] -// -// The final list of candidates returned by the composite selector would be [s1]. -func CompositeSelector(selectors []ServerSelector) ServerSelector { - return &compositeSelector{selectors: selectors} -} - -func (cs *compositeSelector) SelectServer(t Topology, candidates []Server) ([]Server, error) { - var err error - for _, sel := range cs.selectors { - candidates, err = sel.SelectServer(t, candidates) - if err != nil { - return nil, err - } - } - return candidates, nil -} - -type latencySelector struct { - latency time.Duration -} - -// LatencySelector creates a ServerSelector which selects servers based on their average RTT values. -func LatencySelector(latency time.Duration) ServerSelector { - return &latencySelector{latency: latency} -} - -func (latencySelector) info() serverSelectorInfo { - return serverSelectorInfo{Type: "latencySelector"} -} - -func (selector latencySelector) String() string { - return selector.info().String() -} - -func (selector *latencySelector) SelectServer(t Topology, candidates []Server) ([]Server, error) { - if selector.latency < 0 { - return candidates, nil - } - if t.Kind == LoadBalanced { - // In LoadBalanced mode, there should only be one server in the topology and it must be selected. - return candidates, nil - } - - switch len(candidates) { - case 0, 1: - return candidates, nil - default: - min := time.Duration(math.MaxInt64) - for _, candidate := range candidates { - if candidate.AverageRTTSet { - if candidate.AverageRTT < min { - min = candidate.AverageRTT - } - } - } - - if min == math.MaxInt64 { - return candidates, nil - } - - max := min + selector.latency - - viableIndexes := make([]int, 0, len(candidates)) - for i, candidate := range candidates { - if candidate.AverageRTTSet { - if candidate.AverageRTT <= max { - viableIndexes = append(viableIndexes, i) - } - } - } - if len(viableIndexes) == len(candidates) { - return candidates, nil - } - result := make([]Server, len(viableIndexes)) - for i, idx := range viableIndexes { - result[i] = candidates[idx] - } - return result, nil - } -} - -type writeServerSelector struct{} - -// WriteSelector selects all the writable servers. -func WriteSelector() ServerSelector { - return writeServerSelector{} -} - -func (writeServerSelector) info() serverSelectorInfo { - return serverSelectorInfo{Type: "writeSelector"} -} - -func (selector writeServerSelector) String() string { - return selector.info().String() -} - -func (writeServerSelector) SelectServer(t Topology, candidates []Server) ([]Server, error) { - switch t.Kind { - case Single, LoadBalanced: - return candidates, nil - default: - // Determine the capacity of the results slice. - selected := 0 - for _, candidate := range candidates { - switch candidate.Kind { - case Mongos, RSPrimary, Standalone: - selected++ - } - } - - // Append candidates to the results slice. - result := make([]Server, 0, selected) - for _, candidate := range candidates { - switch candidate.Kind { - case Mongos, RSPrimary, Standalone: - result = append(result, candidate) - } - } - return result, nil - } -} - -type readPrefServerSelector struct { - rp *readpref.ReadPref - isOutputAggregate bool -} - -// ReadPrefSelector selects servers based on the provided read preference. -func ReadPrefSelector(rp *readpref.ReadPref) ServerSelector { - return readPrefServerSelector{ - rp: rp, - isOutputAggregate: false, - } -} - -func (selector readPrefServerSelector) info() serverSelectorInfo { - return serverSelectorInfo{ - Type: "readPrefSelector", - Data: selector.rp.String(), - } -} - -func (selector readPrefServerSelector) String() string { - return selector.info().String() -} - -func (selector readPrefServerSelector) SelectServer(t Topology, candidates []Server) ([]Server, error) { - if t.Kind == LoadBalanced { - // In LoadBalanced mode, there should only be one server in the topology and it must be selected. We check - // this before checking MaxStaleness support because there's no monitoring in this mode, so the candidate - // server wouldn't have a wire version set, which would result in an error. - return candidates, nil - } - - switch t.Kind { - case Single: - return candidates, nil - case ReplicaSetNoPrimary, ReplicaSetWithPrimary: - return selectForReplicaSet(selector.rp, selector.isOutputAggregate, t, candidates) - case Sharded: - return selectByKind(candidates, Mongos), nil - } - - return nil, nil -} - -// OutputAggregateSelector selects servers based on the provided read preference -// given that the underlying operation is aggregate with an output stage. -func OutputAggregateSelector(rp *readpref.ReadPref) ServerSelector { - return readPrefServerSelector{ - rp: rp, - isOutputAggregate: true, - } -} - -func selectForReplicaSet(rp *readpref.ReadPref, isOutputAggregate bool, t Topology, candidates []Server) ([]Server, error) { - if err := verifyMaxStaleness(rp, t); err != nil { - return nil, err - } - - // If underlying operation is an aggregate with an output stage, only apply read preference - // if all candidates are 5.0+. Otherwise, operate under primary read preference. - if isOutputAggregate { - for _, s := range candidates { - if s.WireVersion.Max < 13 { - return selectByKind(candidates, RSPrimary), nil - } - } - } - - switch rp.Mode() { - case readpref.PrimaryMode: - return selectByKind(candidates, RSPrimary), nil - case readpref.PrimaryPreferredMode: - selected := selectByKind(candidates, RSPrimary) - - if len(selected) == 0 { - selected = selectSecondaries(rp, candidates) - return selectByTagSet(selected, rp.TagSets()), nil - } - - return selected, nil - case readpref.SecondaryPreferredMode: - selected := selectSecondaries(rp, candidates) - selected = selectByTagSet(selected, rp.TagSets()) - if len(selected) > 0 { - return selected, nil - } - return selectByKind(candidates, RSPrimary), nil - case readpref.SecondaryMode: - selected := selectSecondaries(rp, candidates) - return selectByTagSet(selected, rp.TagSets()), nil - case readpref.NearestMode: - selected := selectByKind(candidates, RSPrimary) - selected = append(selected, selectSecondaries(rp, candidates)...) - return selectByTagSet(selected, rp.TagSets()), nil - } - - return nil, fmt.Errorf("unsupported mode: %d", rp.Mode()) -} - -func selectSecondaries(rp *readpref.ReadPref, candidates []Server) []Server { - secondaries := selectByKind(candidates, RSSecondary) - if len(secondaries) == 0 { - return secondaries - } - if maxStaleness, set := rp.MaxStaleness(); set { - primaries := selectByKind(candidates, RSPrimary) - if len(primaries) == 0 { - baseTime := secondaries[0].LastWriteTime - for i := 1; i < len(secondaries); i++ { - if secondaries[i].LastWriteTime.After(baseTime) { - baseTime = secondaries[i].LastWriteTime - } - } - - var selected []Server - for _, secondary := range secondaries { - estimatedStaleness := baseTime.Sub(secondary.LastWriteTime) + secondary.HeartbeatInterval - if estimatedStaleness <= maxStaleness { - selected = append(selected, secondary) - } - } - - return selected - } - - primary := primaries[0] - - var selected []Server - for _, secondary := range secondaries { - estimatedStaleness := secondary.LastUpdateTime.Sub(secondary.LastWriteTime) - primary.LastUpdateTime.Sub(primary.LastWriteTime) + secondary.HeartbeatInterval - if estimatedStaleness <= maxStaleness { - selected = append(selected, secondary) - } - } - return selected - } - - return secondaries -} - -func selectByTagSet(candidates []Server, tagSets []tag.Set) []Server { - if len(tagSets) == 0 { - return candidates - } - - for _, ts := range tagSets { - // If this tag set is empty, we can take a fast path because the empty list is a subset of all tag sets, so - // all candidate servers will be selected. - if len(ts) == 0 { - return candidates - } - - var results []Server - for _, s := range candidates { - // ts is non-empty, so only servers with a non-empty set of tags need to be checked. - if len(s.Tags) > 0 && s.Tags.ContainsAll(ts) { - results = append(results, s) - } - } - - if len(results) > 0 { - return results - } - } - - return []Server{} -} - -func selectByKind(candidates []Server, kind ServerKind) []Server { - // Record the indices of viable candidates first and then append those to the returned slice - // to avoid appending costly Server structs directly as an optimization. - viableIndexes := make([]int, 0, len(candidates)) - for i, s := range candidates { - if s.Kind == kind { - viableIndexes = append(viableIndexes, i) - } - } - if len(viableIndexes) == len(candidates) { - return candidates - } - result := make([]Server, len(viableIndexes)) - for i, idx := range viableIndexes { - result[i] = candidates[idx] - } - return result -} - -func verifyMaxStaleness(rp *readpref.ReadPref, t Topology) error { - maxStaleness, set := rp.MaxStaleness() - if !set { - return nil - } - - if maxStaleness < 90*time.Second { - return fmt.Errorf("max staleness (%s) must be greater than or equal to 90s", maxStaleness) - } - - if len(t.Servers) < 1 { - // Maybe we should return an error here instead? - return nil - } - - // we'll assume all candidates have the same heartbeat interval. - s := t.Servers[0] - idleWritePeriod := 10 * time.Second - - if maxStaleness < s.HeartbeatInterval+idleWritePeriod { - return fmt.Errorf( - "max staleness (%s) must be greater than or equal to the heartbeat interval (%s) plus idle write period (%s)", - maxStaleness, s.HeartbeatInterval, idleWritePeriod, - ) - } - - return nil -} diff --git a/mongo/description/server_test.go b/mongo/description/server_test.go deleted file mode 100644 index 5712086d8e..0000000000 --- a/mongo/description/server_test.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -import ( - "errors" - "testing" - "time" - - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/internal/assert" - "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/tag" -) - -func TestServer(t *testing.T) { - int64ToPtr := func(i64 int64) *int64 { return &i64 } - - t.Run("equals", func(t *testing.T) { - defaultServer := Server{} - // Only some of the Server fields affect equality - testCases := []struct { - name string - server Server - equal bool - }{ - {"empty", Server{}, true}, - {"address", Server{Addr: address.Address("foo")}, true}, - {"arbiters", Server{Arbiters: []string{"foo"}}, false}, - {"rtt", Server{AverageRTT: time.Second}, true}, - {"compression", Server{Compression: []string{"foo"}}, true}, - {"canonicalAddr", Server{CanonicalAddr: address.Address("foo")}, false}, - {"electionID", Server{ElectionID: bson.NewObjectID()}, false}, - {"heartbeatInterval", Server{HeartbeatInterval: time.Second}, true}, - {"hosts", Server{Hosts: []string{"foo"}}, false}, - {"lastError", Server{LastError: errors.New("foo")}, false}, - {"lastUpdateTime", Server{LastUpdateTime: time.Now()}, true}, - {"lastWriteTime", Server{LastWriteTime: time.Now()}, true}, - {"maxBatchCount", Server{MaxBatchCount: 1}, true}, - {"maxDocumentSize", Server{MaxDocumentSize: 1}, true}, - {"maxMessageSize", Server{MaxMessageSize: 1}, true}, - {"members", Server{Members: []address.Address{address.Address("foo")}}, true}, - {"passives", Server{Passives: []string{"foo"}}, false}, - {"passive", Server{Passive: true}, true}, - {"primary", Server{Primary: address.Address("foo")}, false}, - {"readOnly", Server{ReadOnly: true}, true}, - { - "sessionTimeoutMinutes", - Server{ - SessionTimeoutMinutes: int64ToPtr(1), - }, - false, - }, - {"setName", Server{SetName: "foo"}, false}, - {"setVersion", Server{SetVersion: 1}, false}, - {"tags", Server{Tags: tag.Set{tag.Tag{"foo", "bar"}}}, false}, - {"topologyVersion", Server{TopologyVersion: &TopologyVersion{bson.NewObjectID(), 0}}, false}, - {"kind", Server{Kind: Standalone}, false}, - {"wireVersion", Server{WireVersion: &VersionRange{1, 2}}, false}, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - actual := defaultServer.Equal(tc.server) - assert.Equal(t, actual, tc.equal, "expected %v, got %v", tc.equal, actual) - }) - } - }) -} diff --git a/mongo/description/shared_spec_test.go b/mongo/description/shared_spec_test.go deleted file mode 100644 index f7fb250844..0000000000 --- a/mongo/description/shared_spec_test.go +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -import ( - "encoding/json" - "io/ioutil" - "path" - "strconv" - "testing" - "time" - - "go.mongodb.org/mongo-driver/internal/require" - "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/readpref" - "go.mongodb.org/mongo-driver/tag" -) - -type testCase struct { - TopologyDescription topDesc `json:"topology_description"` - Operation string `json:"operation"` - ReadPreference readPref `json:"read_preference"` - SuitableServers []*serverDesc `json:"suitable_servers"` - InLatencyWindow []*serverDesc `json:"in_latency_window"` - HeartbeatFrequencyMS *int `json:"heartbeatFrequencyMS"` - Error *bool -} - -type topDesc struct { - Type string `json:"type"` - Servers []*serverDesc `json:"servers"` -} - -type serverDesc struct { - Address string `json:"address"` - AverageRTTMS *int `json:"avg_rtt_ms"` - MaxWireVersion *int32 `json:"maxWireVersion"` - LastUpdateTime *int `json:"lastUpdateTime"` - LastWrite *lastWriteDate `json:"lastWrite"` - Type string `json:"type"` - Tags map[string]string `json:"tags"` -} - -type lastWriteDate struct { - LastWriteDate lastWriteDateInner `json:"lastWriteDate"` -} - -// TODO(GODRIVER-33): Use proper extended JSON parsing to eliminate the need for this struct. -type lastWriteDateInner struct { - Value string `json:"$numberLong"` -} - -type readPref struct { - MaxStaleness *int `json:"maxStalenessSeconds"` - Mode string `json:"mode"` - TagSets []map[string]string `json:"tag_sets"` -} - -func topologyKindFromString(t *testing.T, s string) TopologyKind { - t.Helper() - - switch s { - case "Single": - return Single - case "ReplicaSet": - return ReplicaSet - case "ReplicaSetNoPrimary": - return ReplicaSetNoPrimary - case "ReplicaSetWithPrimary": - return ReplicaSetWithPrimary - case "Sharded": - return Sharded - case "LoadBalanced": - return LoadBalanced - case "Unknown": - return Unknown - default: - t.Fatalf("unrecognized topology kind: %q", s) - } - - return Unknown -} - -func serverKindFromString(t *testing.T, s string) ServerKind { - t.Helper() - - switch s { - case "Standalone": - return Standalone - case "RSOther": - return RSMember - case "RSPrimary": - return RSPrimary - case "RSSecondary": - return RSSecondary - case "RSArbiter": - return RSArbiter - case "RSGhost": - return RSGhost - case "Mongos": - return Mongos - case "LoadBalancer": - return LoadBalancer - case "PossiblePrimary", "Unknown": - // Go does not have a PossiblePrimary server type and per the SDAM spec, this type is synonymous with Unknown. - return Unknown - default: - t.Fatalf("unrecognized server kind: %q", s) - } - - return Unknown -} - -func findServerByAddress(servers []Server, address string) Server { - for _, server := range servers { - if server.Addr.String() == address { - return server - } - } - - return Server{} -} - -func anyTagsInSets(sets []tag.Set) bool { - for _, set := range sets { - if len(set) > 0 { - return true - } - } - - return false -} - -func compareServers(t *testing.T, expected []*serverDesc, actual []Server) { - require.Equal(t, len(expected), len(actual)) - - for _, expectedServer := range expected { - actualServer := findServerByAddress(actual, expectedServer.Address) - require.NotNil(t, actualServer) - - if expectedServer.AverageRTTMS != nil { - require.Equal(t, *expectedServer.AverageRTTMS, int(actualServer.AverageRTT/time.Millisecond)) - } - - require.Equal(t, expectedServer.Type, actualServer.Kind.String()) - - require.Equal(t, len(expectedServer.Tags), len(actualServer.Tags)) - for _, actualTag := range actualServer.Tags { - expectedTag, ok := expectedServer.Tags[actualTag.Name] - require.True(t, ok) - require.Equal(t, expectedTag, actualTag.Value) - } - } -} - -func selectServers(t *testing.T, test *testCase) error { - servers := make([]Server, 0, len(test.TopologyDescription.Servers)) - - // Times in the JSON files are given as offsets from an unspecified time, but the driver - // stores the lastWrite field as a timestamp, so we arbitrarily choose the current time - // as the base to offset from. - baseTime := time.Now() - - for _, serverDescription := range test.TopologyDescription.Servers { - server := Server{ - Addr: address.Address(serverDescription.Address), - Kind: serverKindFromString(t, serverDescription.Type), - } - - if serverDescription.AverageRTTMS != nil { - server.AverageRTT = time.Duration(*serverDescription.AverageRTTMS) * time.Millisecond - server.AverageRTTSet = true - } - - if test.HeartbeatFrequencyMS != nil { - server.HeartbeatInterval = time.Duration(*test.HeartbeatFrequencyMS) * time.Millisecond - } - - if serverDescription.LastUpdateTime != nil { - ms := int64(*serverDescription.LastUpdateTime) - server.LastUpdateTime = time.Unix(ms/1e3, ms%1e3/1e6) - } - - if serverDescription.LastWrite != nil { - i, err := strconv.ParseInt(serverDescription.LastWrite.LastWriteDate.Value, 10, 64) - - if err != nil { - return err - } - - timeWithOffset := baseTime.Add(time.Duration(i) * time.Millisecond) - server.LastWriteTime = timeWithOffset - } - - if serverDescription.MaxWireVersion != nil { - versionRange := NewVersionRange(0, *serverDescription.MaxWireVersion) - server.WireVersion = &versionRange - } - - if serverDescription.Tags != nil { - server.Tags = tag.NewTagSetFromMap(serverDescription.Tags) - } - - if test.ReadPreference.MaxStaleness != nil && server.WireVersion == nil { - server.WireVersion = &VersionRange{Max: 21} - } - - servers = append(servers, server) - } - - c := Topology{ - Kind: topologyKindFromString(t, test.TopologyDescription.Type), - Servers: servers, - } - - if len(test.ReadPreference.Mode) == 0 { - test.ReadPreference.Mode = "Primary" - } - - readprefMode, err := readpref.ModeFromString(test.ReadPreference.Mode) - if err != nil { - return err - } - - options := make([]readpref.Option, 0, 1) - - tagSets := tag.NewTagSetsFromMaps(test.ReadPreference.TagSets) - if anyTagsInSets(tagSets) { - options = append(options, readpref.WithTagSets(tagSets...)) - } - - if test.ReadPreference.MaxStaleness != nil { - s := time.Duration(*test.ReadPreference.MaxStaleness) * time.Second - options = append(options, readpref.WithMaxStaleness(s)) - } - - rp, err := readpref.New(readprefMode, options...) - if err != nil { - return err - } - - selector := ReadPrefSelector(rp) - if test.Operation == "write" { - selector = CompositeSelector( - []ServerSelector{WriteSelector(), selector}, - ) - } - - result, err := selector.SelectServer(c, c.Servers) - if err != nil { - return err - } - - compareServers(t, test.SuitableServers, result) - - latencySelector := LatencySelector(time.Duration(15) * time.Millisecond) - selector = CompositeSelector( - []ServerSelector{selector, latencySelector}, - ) - - result, err = selector.SelectServer(c, c.Servers) - if err != nil { - return err - } - - compareServers(t, test.InLatencyWindow, result) - - return nil -} - -func runTest(t *testing.T, testsDir string, directory string, filename string) { - filepath := path.Join(testsDir, directory, filename) - content, err := ioutil.ReadFile(filepath) - require.NoError(t, err) - - // Remove ".json" from filename. - filename = filename[:len(filename)-5] - testName := directory + "/" + filename + ":" - - t.Run(testName, func(t *testing.T) { - var test testCase - require.NoError(t, json.Unmarshal(content, &test)) - - err := selectServers(t, &test) - - if test.Error == nil || !*test.Error { - require.NoError(t, err) - } else { - require.Error(t, err) - } - }) -} diff --git a/mongo/description/topology.go b/mongo/description/topology.go deleted file mode 100644 index b0a52931fe..0000000000 --- a/mongo/description/topology.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -import ( - "fmt" - - "go.mongodb.org/mongo-driver/mongo/readpref" -) - -// Topology contains information about a MongoDB cluster. -type Topology struct { - Servers []Server - SetName string - Kind TopologyKind - SessionTimeoutMinutes *int64 - CompatibilityErr error -} - -// String implements the Stringer interface. -func (t Topology) String() string { - var serversStr string - for _, s := range t.Servers { - serversStr += "{ " + s.String() + " }, " - } - return fmt.Sprintf("Type: %s, Servers: [%s]", t.Kind, serversStr) -} - -// Equal compares two topology descriptions and returns true if they are equal. -func (t Topology) Equal(other Topology) bool { - if t.Kind != other.Kind { - return false - } - - topoServers := make(map[string]Server) - for _, s := range t.Servers { - topoServers[s.Addr.String()] = s - } - - otherServers := make(map[string]Server) - for _, s := range other.Servers { - otherServers[s.Addr.String()] = s - } - - if len(topoServers) != len(otherServers) { - return false - } - - for _, server := range topoServers { - otherServer := otherServers[server.Addr.String()] - - if !server.Equal(otherServer) { - return false - } - } - - return true -} - -// HasReadableServer returns true if the topology contains a server suitable for reading. -// -// If the Topology's kind is Single or Sharded, the mode parameter is ignored and the function contains true if any of -// the servers in the Topology are of a known type. -// -// For replica sets, the function returns true if the cluster contains a server that matches the provided read -// preference mode. -func (t Topology) HasReadableServer(mode readpref.Mode) bool { - switch t.Kind { - case Single, Sharded: - return hasAvailableServer(t.Servers, 0) - case ReplicaSetWithPrimary: - return hasAvailableServer(t.Servers, mode) - case ReplicaSetNoPrimary, ReplicaSet: - if mode == readpref.PrimaryMode { - return false - } - // invalid read preference - if !mode.IsValid() { - return false - } - - return hasAvailableServer(t.Servers, mode) - } - return false -} - -// HasWritableServer returns true if a topology has a server available for writing. -// -// If the Topology's kind is Single or Sharded, this function returns true if any of the servers in the Topology are of -// a known type. -// -// For replica sets, the function returns true if the replica set contains a primary. -func (t Topology) HasWritableServer() bool { - return t.HasReadableServer(readpref.PrimaryMode) -} - -// hasAvailableServer returns true if any servers are available based on the read preference. -func hasAvailableServer(servers []Server, mode readpref.Mode) bool { - switch mode { - case readpref.PrimaryMode: - for _, s := range servers { - if s.Kind == RSPrimary { - return true - } - } - return false - case readpref.PrimaryPreferredMode, readpref.SecondaryPreferredMode, readpref.NearestMode: - for _, s := range servers { - if s.Kind == RSPrimary || s.Kind == RSSecondary { - return true - } - } - return false - case readpref.SecondaryMode: - for _, s := range servers { - if s.Kind == RSSecondary { - return true - } - } - return false - } - - // read preference is not specified - for _, s := range servers { - switch s.Kind { - case Standalone, - RSMember, - RSPrimary, - RSSecondary, - RSArbiter, - RSGhost, - Mongos: - return true - } - } - - return false -} diff --git a/mongo/description/topology_kind.go b/mongo/description/topology_kind.go deleted file mode 100644 index 6d60c4d874..0000000000 --- a/mongo/description/topology_kind.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -// TopologyKind represents a specific topology configuration. -type TopologyKind uint32 - -// These constants are the available topology configurations. -const ( - Single TopologyKind = 1 - ReplicaSet TopologyKind = 2 - ReplicaSetNoPrimary TopologyKind = 4 + ReplicaSet - ReplicaSetWithPrimary TopologyKind = 8 + ReplicaSet - Sharded TopologyKind = 256 - LoadBalanced TopologyKind = 512 -) - -// String implements the fmt.Stringer interface. -func (kind TopologyKind) String() string { - switch kind { - case Single: - return "Single" - case ReplicaSet: - return "ReplicaSet" - case ReplicaSetNoPrimary: - return "ReplicaSetNoPrimary" - case ReplicaSetWithPrimary: - return "ReplicaSetWithPrimary" - case Sharded: - return "Sharded" - case LoadBalanced: - return "LoadBalanced" - } - - return "Unknown" -} diff --git a/mongo/description/topology_version.go b/mongo/description/topology_version.go deleted file mode 100644 index 2e1b28d588..0000000000 --- a/mongo/description/topology_version.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -import ( - "fmt" - - "go.mongodb.org/mongo-driver/bson" -) - -// TopologyVersion represents a software version. -type TopologyVersion struct { - ProcessID bson.ObjectID - Counter int64 -} - -// NewTopologyVersion creates a TopologyVersion based on doc -func NewTopologyVersion(doc bson.Raw) (*TopologyVersion, error) { - elements, err := doc.Elements() - if err != nil { - return nil, err - } - var tv TopologyVersion - var ok bool - for _, element := range elements { - switch element.Key() { - case "processId": - tv.ProcessID, ok = element.Value().ObjectIDOK() - if !ok { - return nil, fmt.Errorf("expected 'processId' to be a objectID but it's a BSON %s", element.Value().Type) - } - case "counter": - tv.Counter, ok = element.Value().Int64OK() - if !ok { - return nil, fmt.Errorf("expected 'counter' to be an int64 but it's a BSON %s", element.Value().Type) - } - } - } - return &tv, nil -} - -// CompareToIncoming compares the receiver, which represents the currently known TopologyVersion for a server, to an -// incoming TopologyVersion extracted from a server command response. -// -// This returns -1 if the receiver version is less than the response, 0 if the versions are equal, and 1 if the -// receiver version is greater than the response. This comparison is not commutative. -func (tv *TopologyVersion) CompareToIncoming(responseTV *TopologyVersion) int { - if tv == nil || responseTV == nil { - return -1 - } - if tv.ProcessID != responseTV.ProcessID { - return -1 - } - if tv.Counter == responseTV.Counter { - return 0 - } - if tv.Counter < responseTV.Counter { - return -1 - } - return 1 -} diff --git a/mongo/description/version_range.go b/mongo/description/version_range.go deleted file mode 100644 index 5d6270c521..0000000000 --- a/mongo/description/version_range.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -import "fmt" - -// VersionRange represents a range of versions. -type VersionRange struct { - Min int32 - Max int32 -} - -// NewVersionRange creates a new VersionRange given a min and a max. -func NewVersionRange(min, max int32) VersionRange { - return VersionRange{Min: min, Max: max} -} - -// Includes returns a bool indicating whether the supplied integer is included -// in the range. -func (vr VersionRange) Includes(v int32) bool { - return v >= vr.Min && v <= vr.Max -} - -// Equals returns a bool indicating whether the supplied VersionRange is equal. -func (vr *VersionRange) Equals(other *VersionRange) bool { - if vr == nil && other == nil { - return true - } - if vr == nil || other == nil { - return false - } - return vr.Min == other.Min && vr.Max == other.Max -} - -// String implements the fmt.Stringer interface. -func (vr VersionRange) String() string { - return fmt.Sprintf("[%d, %d]", vr.Min, vr.Max) -} diff --git a/mongo/description/version_range_test.go b/mongo/description/version_range_test.go deleted file mode 100644 index 8df28ea3ce..0000000000 --- a/mongo/description/version_range_test.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (C) MongoDB, Inc. 2017-present. -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. You may obtain -// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 - -package description - -import "testing" - -func TestRange_Includes(t *testing.T) { - t.Parallel() - - subject := NewVersionRange(1, 3) - - tests := []struct { - n int32 - expected bool - }{ - {0, false}, - {1, true}, - {2, true}, - {3, true}, - {4, false}, - {10, false}, - } - - for _, test := range tests { - actual := subject.Includes(test.n) - if actual != test.expected { - t.Fatalf("expected %v to be %t", test.n, test.expected) - } - } -} diff --git a/mongo/index_view.go b/mongo/index_view.go index b50092f8cd..84f4d71dc4 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -14,11 +14,12 @@ import ( "strconv" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -78,11 +79,15 @@ func (iv IndexView) List(ctx context.Context, opts ...*options.ListIndexesOption closeImplicitSession(sess) return nil, err } + var selector description.ServerSelector + + selector = &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.ReadPref{ReadPref: readpref.Primary()}, + &serverselector.Latency{Latency: iv.coll.client.localThreshold}, + }, + } - selector := description.CompositeSelector([]description.ServerSelector{ - description.ReadPrefSelector(readpref.Primary()), - description.LatencySelector(iv.coll.client.localThreshold), - }) selector = makeReadPrefSelector(sess, selector, iv.coll.client.localThreshold) op := operation.NewListIndexes(). Session(sess).CommandMonitor(iv.coll.client.monitor). diff --git a/mongo/session.go b/mongo/session.go index bbcdf6a7f5..778abebc63 100644 --- a/mongo/session.go +++ b/mongo/session.go @@ -12,7 +12,7 @@ import ( "time" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -242,7 +242,7 @@ func (s *Session) AbortTransaction(ctx context.Context) error { return s.clientSession.AbortTransaction() } - selector := makePinnedSelector(s.clientSession, description.WriteSelector()) + selector := makePinnedSelector(s.clientSession, &serverselector.Write{}) s.clientSession.Aborting = true _ = operation.NewAbortTransaction().Session(s.clientSession).ClusterClock(s.client.clock).Database("admin"). @@ -275,7 +275,7 @@ func (s *Session) CommitTransaction(ctx context.Context) error { s.clientSession.RetryingCommit = true } - selector := makePinnedSelector(s.clientSession, description.WriteSelector()) + selector := makePinnedSelector(s.clientSession, &serverselector.Write{}) s.clientSession.Committing = true op := operation.NewCommitTransaction(). diff --git a/mongo/with_transactions_test.go b/mongo/with_transactions_test.go index 30835fc5e9..d737ff9a07 100644 --- a/mongo/with_transactions_test.go +++ b/mongo/with_transactions_test.go @@ -21,11 +21,11 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/integtest" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/topology" ) @@ -586,11 +586,11 @@ func setupConvenientTransactions(t *testing.T, extraClientOpts ...*options.Clien version, err := getServerVersion(client.Database("admin")) assert.Nil(t, err, "getServerVersion error: %v", err) topoKind := client.deployment.(*topology.Topology).Kind() - if compareVersions(version, "4.1") < 0 || topoKind == description.Single { + if compareVersions(version, "4.1") < 0 || topoKind == description.TopologyKindSingle { t.Skip("skipping standalones and versions < 4.1") } - if topoKind != description.Sharded { + if topoKind != description.TopologyKindSharded { return client } diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index ac6540233b..e27465dac3 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -13,8 +13,8 @@ import ( "net/http" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/x/mongo/driver/session" @@ -125,7 +125,7 @@ func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn *mnet.Connec if performAuth == nil { performAuth = func(serv description.Server) bool { // Authentication is possible against all server types except arbiters - return serv.Kind != description.RSArbiter + return serv.Kind != description.ServerKindRSArbiter } } diff --git a/x/mongo/driver/auth/gssapi_test.go b/x/mongo/driver/auth/gssapi_test.go index 59f4e20c12..8df412c54f 100644 --- a/x/mongo/driver/auth/gssapi_test.go +++ b/x/mongo/driver/auth/gssapi_test.go @@ -14,7 +14,7 @@ import ( "testing" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) diff --git a/x/mongo/driver/auth/mongodbcr_test.go b/x/mongo/driver/auth/mongodbcr_test.go index 8fcc59820b..3d77cc44ff 100644 --- a/x/mongo/driver/auth/mongodbcr_test.go +++ b/x/mongo/driver/auth/mongodbcr_test.go @@ -12,9 +12,9 @@ import ( "strings" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" . "go.mongodb.org/mongo-driver/x/mongo/driver/auth" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) diff --git a/x/mongo/driver/auth/plain_test.go b/x/mongo/driver/auth/plain_test.go index c0dc8d760f..251c3475f9 100644 --- a/x/mongo/driver/auth/plain_test.go +++ b/x/mongo/driver/auth/plain_test.go @@ -14,9 +14,9 @@ import ( "encoding/base64" "go.mongodb.org/mongo-driver/internal/require" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" . "go.mongodb.org/mongo-driver/x/mongo/driver/auth" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) diff --git a/x/mongo/driver/auth/scram_test.go b/x/mongo/driver/auth/scram_test.go index 46e6ed9111..851bd6fb94 100644 --- a/x/mongo/driver/auth/scram_test.go +++ b/x/mongo/driver/auth/scram_test.go @@ -11,8 +11,8 @@ import ( "testing" "go.mongodb.org/mongo-driver/internal/assert" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index 773b0f2024..f78ef652fe 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -18,8 +18,9 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/codecutil" "go.mongodb.org/mongo-driver/internal/csot" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -130,7 +131,7 @@ func NewCursorResponse(info ResponseInfo) (CursorResponse, error) { // If the deployment is behind a load balancer and the cursor has a non-zero ID, pin the cursor to a connection and // use the same connection to execute getMore and killCursors commands. - if curresp.Desc.LoadBalanced() && curresp.ID != 0 { + if driverutil.IsServerLoadBalanced(curresp.Desc) && curresp.ID != 0 { // Cache the server as an ErrorProcessor to use when constructing deployments for cursor commands. ep, ok := curresp.Server.(ErrorProcessor) if !ok { @@ -513,7 +514,7 @@ func (lbcd *loadBalancedCursorDeployment) SelectServer(_ context.Context, _ desc } func (lbcd *loadBalancedCursorDeployment) Kind() description.TopologyKind { - return description.LoadBalanced + return description.TopologyKindLoadBalanced } func (lbcd *loadBalancedCursorDeployment) Connection(context.Context) (*mnet.Connection, error) { diff --git a/x/mongo/driver/description/server.go b/x/mongo/driver/description/server.go new file mode 100644 index 0000000000..a5d9943114 --- /dev/null +++ b/x/mongo/driver/description/server.go @@ -0,0 +1,144 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package description + +import ( + "fmt" + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo/address" + "go.mongodb.org/mongo-driver/tag" +) + +// ServerKind represents the type of a single server in a topology. +type ServerKind uint32 + +// These constants are the possible types of servers. +const ( + ServerKindStandalone ServerKind = 1 + ServerKindRSMember ServerKind = 2 + ServerKindRSPrimary ServerKind = 4 + ServerKindRSMember + ServerKindRSSecondary ServerKind = 8 + ServerKindRSMember + ServerKindRSArbiter ServerKind = 16 + ServerKindRSMember + ServerKindRSGhost ServerKind = 32 + ServerKindRSMember + ServerKindMongos ServerKind = 256 + ServerKindLoadBalancer ServerKind = 512 +) + +// UnknownStr represents an unknown server kind. +const UnknownStr = "Unknown" + +// String returns a stringified version of the kind or "Unknown" if the kind is +// invalid. +func (kind ServerKind) String() string { + switch kind { + case ServerKindStandalone: + return "Standalone" + case ServerKindRSMember: + return "RSOther" + case ServerKindRSPrimary: + return "RSPrimary" + case ServerKindRSSecondary: + return "RSSecondary" + case ServerKindRSArbiter: + return "RSArbiter" + case ServerKindRSGhost: + return "RSGhost" + case ServerKindMongos: + return "Mongos" + case ServerKindLoadBalancer: + return "LoadBalancer" + } + + return UnknownStr +} + +// Unknown is an unknown server or topology kind. +const Unknown = 0 + +// TopologyVersion represents a software version. +type TopologyVersion struct { + ProcessID bson.ObjectID + Counter int64 +} + +// VersionRange represents a range of versions. +type VersionRange struct { + Min int32 + Max int32 +} + +// Server contains information about a node in a cluster. This is created from +// hello command responses. If the value of the Kind field is LoadBalancer, only +// the Addr and Kind fields will be set. All other fields will be set to the +// zero value of the field's type. +type Server struct { + Addr address.Address + + Arbiters []string + AverageRTT time.Duration + AverageRTTSet bool + Compression []string // compression methods returned by server + CanonicalAddr address.Address + ElectionID bson.ObjectID + HeartbeatInterval time.Duration + HelloOK bool + Hosts []string + IsCryptd bool + LastError error + LastUpdateTime time.Time + LastWriteTime time.Time + MaxBatchCount uint32 + MaxDocumentSize uint32 + MaxMessageSize uint32 + Members []address.Address + Passives []string + Passive bool + Primary address.Address + ReadOnly bool + ServiceID *bson.ObjectID // Only set for servers that are deployed behind a load balancer. + SessionTimeoutMinutes *int64 + SetName string + SetVersion uint32 + Tags tag.Set + TopologyVersion *TopologyVersion + Kind ServerKind + WireVersion *VersionRange +} + +func (s Server) String() string { + str := fmt.Sprintf("Addr: %s, Type: %s", s.Addr, s.Kind) + if len(s.Tags) != 0 { + str += fmt.Sprintf(", Tag sets: %s", s.Tags) + } + + if s.AverageRTTSet { + str += fmt.Sprintf(", Average RTT: %d", s.AverageRTT) + } + + if s.LastError != nil { + str += fmt.Sprintf(", Last error: %s", s.LastError) + } + return str +} + +// SelectedServer augments the Server type by also including the TopologyKind of +// the topology that includes the server. This type should be used to track the +// state of a server that was selected to perform an operation. +type SelectedServer struct { + Server + Kind TopologyKind +} + +// ServerSelector is an interface implemented by types that can perform server +// selection given a topology description and list of candidate servers. The +// selector should filter the provided candidates list and return a subset that +// matches some criteria. +type ServerSelector interface { + SelectServer(Topology, []Server) ([]Server, error) +} diff --git a/x/mongo/driver/description/topology.go b/x/mongo/driver/description/topology.go new file mode 100644 index 0000000000..e152a7ce90 --- /dev/null +++ b/x/mongo/driver/description/topology.go @@ -0,0 +1,60 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package description + +import "fmt" + +// TopologyKind represents a specific topology configuration. +type TopologyKind uint32 + +// These constants are the available topology configurations. +const ( + TopologyKindSingle TopologyKind = 1 + TopologyKindReplicaSet TopologyKind = 2 + TopologyKindReplicaSetNoPrimary TopologyKind = 4 + TopologyKindReplicaSet + TopologyKindReplicaSetWithPrimary TopologyKind = 8 + TopologyKindReplicaSet + TopologyKindSharded TopologyKind = 256 + TopologyKindLoadBalanced TopologyKind = 512 +) + +// Topology contains information about a MongoDB cluster. +type Topology struct { + Servers []Server + SetName string + Kind TopologyKind + SessionTimeoutMinutes *int64 + CompatibilityErr error +} + +// String implements the Stringer interface. +func (t Topology) String() string { + var serversStr string + for _, s := range t.Servers { + serversStr += "{ " + s.String() + " }, " + } + return fmt.Sprintf("Type: %s, Servers: [%s]", t.Kind, serversStr) +} + +// String implements the fmt.Stringer interface. +func (kind TopologyKind) String() string { + switch kind { + case TopologyKindSingle: + return "Single" + case TopologyKindReplicaSet: + return "ReplicaSet" + case TopologyKindReplicaSetNoPrimary: + return "ReplicaSetNoPrimary" + case TopologyKindReplicaSetWithPrimary: + return "ReplicaSetWithPrimary" + case TopologyKindSharded: + return "Sharded" + case TopologyKindLoadBalanced: + return "LoadBalanced" + } + + return "Unknown" +} diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 65ce6a8fa5..16992b4099 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -19,8 +19,8 @@ import ( "go.mongodb.org/mongo-driver/internal/csot" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) @@ -141,8 +141,8 @@ func (ssd SingleServerDeployment) SelectServer(context.Context, description.Serv return ssd.Server, nil } -// Kind implements the Deployment interface. It always returns description.Single. -func (SingleServerDeployment) Kind() description.TopologyKind { return description.Single } +// Kind implements the Deployment interface. It always returns description.TopologyKindSingle. +func (SingleServerDeployment) Kind() description.TopologyKind { return description.TopologyKindSingle } // SingleConnectionDeployment is an implementation of Deployment that always returns the same Connection. This // implementation should only be used for connection handshakes and server heartbeats as it does not implement @@ -159,8 +159,10 @@ func (scd SingleConnectionDeployment) SelectServer(context.Context, description. return scd, nil } -// Kind implements the Deployment interface. It always returns description.Single. -func (SingleConnectionDeployment) Kind() description.TopologyKind { return description.Single } +// Kind implements the Deployment interface. It always returns description.TopologyKindSingle. +func (SingleConnectionDeployment) Kind() description.TopologyKind { + return description.TopologyKindSingle +} // Connection implements the Server interface. It always returns the embedded connection. func (scd SingleConnectionDeployment) Connection(context.Context) (*mnet.Connection, error) { diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index c1eb6f19c6..e4952cce32 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -11,8 +11,8 @@ import ( "errors" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) diff --git a/x/mongo/driver/errors.go b/x/mongo/driver/errors.go index 411479130a..3a189318cb 100644 --- a/x/mongo/driver/errors.go +++ b/x/mongo/driver/errors.go @@ -14,8 +14,9 @@ import ( "strings" "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) // LegacyNotPrimaryErrMsg is the error message that older MongoDB servers (see @@ -498,7 +499,7 @@ func ExtractErrorFromServerResponse(doc bsoncore.Document) error { if !ok { break } - version, err := description.NewTopologyVersion(bson.Raw(doc)) + version, err := driverutil.NewTopologyVersion(bson.Raw(doc)) if err == nil { tv = version } diff --git a/x/mongo/driver/integration/aggregate_test.go b/x/mongo/driver/integration/aggregate_test.go index 1dae4fec3a..824c06f993 100644 --- a/x/mongo/driver/integration/aggregate_test.go +++ b/x/mongo/driver/integration/aggregate_test.go @@ -16,7 +16,7 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/integtest" "go.mongodb.org/mongo-driver/internal/require" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -43,7 +43,7 @@ func setUpMonitor() (*event.CommandMonitor, chan *event.CommandStartedEvent, cha } func skipIfBelow32(ctx context.Context, t *testing.T, topo *topology.Topology) { - server, err := topo.SelectServer(ctx, description.WriteSelector()) + server, err := topo.SelectServer(ctx, &serverselector.Write{}) noerr(t, err) versionCmd := bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "serverStatus", 1)) @@ -77,12 +77,12 @@ func TestAggregate(t *testing.T) { bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "x", 1)), bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "x", 1)), ).Collection(collName).Database(dbName). - Deployment(top).ServerSelector(description.WriteSelector()).Execute(context.Background()) + Deployment(top).ServerSelector(&serverselector.Write{}).Execute(context.Background()) noerr(t, err) clearChannels(started, succeeded, failed) op := operation.NewAggregate(bsoncore.BuildDocumentFromElements(nil)). - Collection(collName).Database(dbName).Deployment(top).ServerSelector(description.WriteSelector()). + Collection(collName).Database(dbName).Deployment(top).ServerSelector(&serverselector.Write{}). CommandMonitor(monitor).BatchSize(2) err = op.Execute(context.Background()) noerr(t, err) @@ -138,7 +138,7 @@ func TestAggregate(t *testing.T) { ), ), )).Collection(integtest.ColName(t)).Database(dbName).Deployment(integtest.Topology(t)). - ServerSelector(description.WriteSelector()).BatchSize(2) + ServerSelector(&serverselector.Write{}).BatchSize(2) err := op.Execute(context.Background()) noerr(t, err) cursor, err := op.Result(driver.CursorOptions{BatchSize: 2}) @@ -173,7 +173,7 @@ func TestAggregate(t *testing.T) { autoInsertDocs(t, wc, ds...) op := operation.NewAggregate(bsoncore.BuildArray(nil)).Collection(integtest.ColName(t)).Database(dbName). - Deployment(integtest.Topology(t)).ServerSelector(description.WriteSelector()).AllowDiskUse(true) + Deployment(integtest.Topology(t)).ServerSelector(&serverselector.Write{}).AllowDiskUse(true) err := op.Execute(context.Background()) if err != nil { t.Errorf("Expected no error from allowing disk use, but got %v", err) diff --git a/x/mongo/driver/integration/main_test.go b/x/mongo/driver/integration/main_test.go index a22505a09a..7c72d6d79c 100644 --- a/x/mongo/driver/integration/main_test.go +++ b/x/mongo/driver/integration/main_test.go @@ -16,7 +16,7 @@ import ( "go.mongodb.org/mongo-driver/internal/integtest" "go.mongodb.org/mongo-driver/internal/require" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" @@ -131,7 +131,7 @@ func runCommand(s driver.Server, db string, cmd bsoncore.Document) (bsoncore.Doc // dropCollection drops the collection in the test cluster. func dropCollection(t *testing.T, dbname, colname string) { err := operation.NewCommand(bsoncore.BuildDocument(nil, bsoncore.AppendStringElement(nil, "drop", colname))). - Database(dbname).ServerSelector(description.WriteSelector()).Deployment(integtest.Topology(t)). + Database(dbname).ServerSelector(&serverselector.Write{}).Deployment(integtest.Topology(t)). Execute(context.Background()) if de, ok := err.(driver.Error); err != nil && !(ok && de.NamespaceNotFound()) { require.NoError(t, err) @@ -149,7 +149,7 @@ func insertDocs(t *testing.T, dbname, colname string, writeConcern *writeconcern Collection(colname). Database(dbname). Deployment(integtest.Topology(t)). - ServerSelector(description.WriteSelector()). + ServerSelector(&serverselector.Write{}). WriteConcern(writeConcern). Execute(context.Background()) require.NoError(t, err) diff --git a/x/mongo/driver/integration/scram_test.go b/x/mongo/driver/integration/scram_test.go index fcf40e837b..2a5b77a96e 100644 --- a/x/mongo/driver/integration/scram_test.go +++ b/x/mongo/driver/integration/scram_test.go @@ -12,8 +12,9 @@ import ( "os" "testing" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/integtest" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" @@ -32,13 +33,13 @@ func TestSCRAM(t *testing.T) { t.Skip("Skipping because authentication is required") } - server, err := integtest.Topology(t).SelectServer(context.Background(), description.WriteSelector()) + server, err := integtest.Topology(t).SelectServer(context.Background(), &serverselector.Write{}) noerr(t, err) serverConnection, err := server.Connection(context.Background()) noerr(t, err) defer serverConnection.Close() - if !serverConnection.Description().WireVersion.Includes(7) { + if !driverutil.VersionRangeIncludes(*serverConnection.Description().WireVersion, 7) { t.Skip("Skipping because MongoDB 4.0 is needed for SCRAM-SHA-256") } @@ -142,7 +143,7 @@ func testScramUserAuthWithMech(t *testing.T, c scramTestCase, mech string) error func runScramAuthTest(t *testing.T, credential options.Credential) error { t.Helper() topology := integtest.TopologyWithCredential(t, credential) - server, err := topology.SelectServer(context.Background(), description.WriteSelector()) + server, err := topology.SelectServer(context.Background(), &serverselector.Write{}) noerr(t, err) cmd := bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "dbstats", 1)) diff --git a/x/mongo/driver/mnet/connection.go b/x/mongo/driver/mnet/connection.go index f8542c4e0a..495c3fe58d 100644 --- a/x/mongo/driver/mnet/connection.go +++ b/x/mongo/driver/mnet/connection.go @@ -11,7 +11,7 @@ import ( "io" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) // ReadWriteCloser represents a Connection where server operations diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 432692295f..6a4a80da15 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -23,12 +23,13 @@ import ( "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/handshake" "go.mongodb.org/mongo-driver/internal/logger" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/session" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" @@ -326,7 +327,7 @@ func filterDeprioritizedServers(candidates, deprioritized []description.Server) // Iterate over the candidates and append them to the allowdIndexes slice if // they are not in the deprioritizedServers list. for _, candidate := range candidates { - if srv, ok := dpaSet[candidate.Addr]; !ok || !srv.Equal(candidate) { + if srv, ok := dpaSet[candidate.Addr]; !ok || !driverutil.EqualServers(*srv, candidate) { allowed = append(allowed, candidate) } } @@ -381,10 +382,13 @@ func (op Operation) selectServer( if rp == nil { rp = readpref.Primary() } - selector = description.CompositeSelector([]description.ServerSelector{ - description.ReadPrefSelector(rp), - description.LatencySelector(defaultLocalThreshold), - }) + + selector = &serverselector.Composite{ + Selectors: []description.ServerSelector{ + &serverselector.ReadPref{ReadPref: rp}, + &serverselector.Latency{Latency: defaultLocalThreshold}, + }, + } } oss := &opServerSelector{ @@ -431,7 +435,7 @@ func (op Operation) getServerAndConnection( } // If we're in load balanced mode and this is the first operation in a transaction, pin the session to a connection. - if conn.Description().LoadBalanced() && op.Client != nil && op.Client.TransactionStarting() { + if driverutil.IsServerLoadBalanced(conn.Description()) && op.Client != nil && op.Client.TransactionStarting() { if conn.Pinner == nil { // Close the original connection to avoid a leak. _ = conn.Close() @@ -572,7 +576,7 @@ func (op Operation) Execute(ctx context.Context) error { if conn != nil { // If we are dealing with a sharded cluster, then mark the failed server // as "deprioritized". - if desc := conn.Description; desc != nil && op.Deployment.Kind() == description.Sharded { + if desc := conn.Description; desc != nil && op.Deployment.Kind() == description.TopologyKindSharded { deprioritizedServers = []description.Server{conn.Description()} } @@ -684,7 +688,10 @@ func (op Operation) Execute(ctx context.Context) error { maxTimeMS = 0 } - desc := description.SelectedServer{Server: conn.Description(), Kind: op.Deployment.Kind()} + desc := description.SelectedServer{ + Server: conn.Description(), + Kind: op.Deployment.Kind(), + } if batching { targetBatchSize := desc.MaxDocumentSize @@ -1415,7 +1422,9 @@ func (op Operation) addServerAPI(dst []byte) []byte { } func (op Operation) addReadConcern(dst []byte, desc description.SelectedServer) ([]byte, error) { - if op.MinimumReadConcernWireVersion > 0 && (desc.WireVersion == nil || !desc.WireVersion.Includes(op.MinimumReadConcernWireVersion)) { + if op.MinimumReadConcernWireVersion > 0 && (desc.WireVersion == nil || + !driverutil.VersionRangeIncludes(*desc.WireVersion, op.MinimumReadConcernWireVersion)) { + return dst, nil } rc := op.ReadConcern @@ -1466,7 +1475,9 @@ func (op Operation) addReadConcern(dst []byte, desc description.SelectedServer) } func (op Operation) addWriteConcern(dst []byte, desc description.SelectedServer) ([]byte, error) { - if op.MinimumWriteConcernWireVersion > 0 && (desc.WireVersion == nil || !desc.WireVersion.Includes(op.MinimumWriteConcernWireVersion)) { + if op.MinimumWriteConcernWireVersion > 0 && (desc.WireVersion == nil || + !driverutil.VersionRangeIncludes(*desc.WireVersion, op.MinimumWriteConcernWireVersion)) { + return dst, nil } wc := op.WriteConcern @@ -1645,7 +1656,9 @@ func (op Operation) createReadPref(desc description.SelectedServer, isOpQuery bo // TODO(GODRIVER-2231): Instead of checking if isOutputAggregate and desc.Server.WireVersion.Max < 13, somehow check // TODO if supplied readPreference was "overwritten" with primary in description.selectForReplicaSet. - if desc.Server.Kind == description.Standalone || (isOpQuery && desc.Server.Kind != description.Mongos) || + if desc.Server.Kind == description.ServerKindStandalone || (isOpQuery && + desc.Server.Kind != description.ServerKindMongos) || + op.Type == Write || (op.IsOutputAggregate && desc.Server.WireVersion.Max < 13) { // Don't send read preference for: // 1. all standalones @@ -1663,7 +1676,7 @@ func (op Operation) createReadPref(desc description.SelectedServer, isOpQuery bo } if rp == nil { - if desc.Kind == description.Single && desc.Server.Kind != description.Mongos { + if desc.Kind == description.TopologyKindSingle && desc.Server.Kind != description.ServerKindMongos { doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred") doc, _ = bsoncore.AppendDocumentEnd(doc, idx) return doc, nil @@ -1673,10 +1686,10 @@ func (op Operation) createReadPref(desc description.SelectedServer, isOpQuery bo switch rp.Mode() { case readpref.PrimaryMode: - if desc.Server.Kind == description.Mongos { + if desc.Server.Kind == description.ServerKindMongos { return nil, nil } - if desc.Kind == description.Single { + if desc.Kind == description.TopologyKindSingle { doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred") doc, _ = bsoncore.AppendDocumentEnd(doc, idx) return doc, nil @@ -1693,7 +1706,9 @@ func (op Operation) createReadPref(desc description.SelectedServer, isOpQuery bo doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred") case readpref.SecondaryPreferredMode: _, ok := rp.MaxStaleness() - if desc.Server.Kind == description.Mongos && isOpQuery && !ok && len(rp.TagSets()) == 0 && rp.HedgeEnabled() == nil { + if desc.Server.Kind == description.ServerKindMongos && isOpQuery && !ok && len(rp.TagSets()) == 0 && + rp.HedgeEnabled() == nil { + return nil, nil } doc = bsoncore.AppendStringElement(doc, "mode", "secondaryPreferred") @@ -1740,7 +1755,7 @@ func (op Operation) createReadPref(desc description.SelectedServer, isOpQuery bo } func (op Operation) secondaryOK(desc description.SelectedServer) wiremessage.QueryFlag { - if desc.Kind == description.Single && desc.Server.Kind != description.Mongos { + if desc.Kind == description.TopologyKindSingle && desc.Server.Kind != description.ServerKindMongos { return wiremessage.SecondaryOK } @@ -2045,5 +2060,5 @@ func sessionsSupported(wireVersion *description.VersionRange) bool { // retryWritesSupported returns true if this description represents a server that supports retryable writes. func retryWritesSupported(s description.Server) bool { - return s.SessionTimeoutMinutes != nil && s.Kind != description.Standalone + return s.SessionTimeoutMinutes != nil && s.Kind != description.ServerKindStandalone } diff --git a/x/mongo/driver/operation/abort_transaction.go b/x/mongo/driver/operation/abort_transaction.go index 9413727130..9aa5bd4e32 100644 --- a/x/mongo/driver/operation/abort_transaction.go +++ b/x/mongo/driver/operation/abort_transaction.go @@ -12,10 +12,10 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/aggregate.go b/x/mongo/driver/operation/aggregate.go index 286bfe25a6..3fe4ca2fe3 100644 --- a/x/mongo/driver/operation/aggregate.go +++ b/x/mongo/driver/operation/aggregate.go @@ -13,12 +13,12 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -136,8 +136,7 @@ func (a *Aggregate) command(dst []byte, desc description.SelectedServer) ([]byte dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *a.bypassDocumentValidation) } if a.collation != nil { - - if desc.WireVersion == nil || !desc.WireVersion.Includes(5) { + if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) { return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5") } dst = bsoncore.AppendDocumentElement(dst, "collation", a.collation) diff --git a/x/mongo/driver/operation/command.go b/x/mongo/driver/operation/command.go index 35283794a3..86f61ee98d 100644 --- a/x/mongo/driver/operation/command.go +++ b/x/mongo/driver/operation/command.go @@ -13,10 +13,10 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/logger" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/commit_transaction.go b/x/mongo/driver/operation/commit_transaction.go index 11c6f69ddf..42a79e2f56 100644 --- a/x/mongo/driver/operation/commit_transaction.go +++ b/x/mongo/driver/operation/commit_transaction.go @@ -13,10 +13,10 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/count.go b/x/mongo/driver/operation/count.go index bd1204cd5e..5625b79bd9 100644 --- a/x/mongo/driver/operation/count.go +++ b/x/mongo/driver/operation/count.go @@ -14,11 +14,11 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/create.go b/x/mongo/driver/operation/create.go index 45b26cb707..b1e40f977c 100644 --- a/x/mongo/driver/operation/create.go +++ b/x/mongo/driver/operation/create.go @@ -11,10 +11,11 @@ import ( "errors" "go.mongodb.org/mongo-driver/event" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -91,7 +92,7 @@ func (c *Create) command(dst []byte, desc description.SelectedServer) ([]byte, e dst = bsoncore.AppendDocumentElement(dst, "changeStreamPreAndPostImages", c.changeStreamPreAndPostImages) } if c.collation != nil { - if desc.WireVersion == nil || !desc.WireVersion.Includes(5) { + if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) { return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5") } dst = bsoncore.AppendDocumentElement(dst, "collation", c.collation) diff --git a/x/mongo/driver/operation/create_indexes.go b/x/mongo/driver/operation/create_indexes.go index a8e55313ed..0192379e2b 100644 --- a/x/mongo/driver/operation/create_indexes.go +++ b/x/mongo/driver/operation/create_indexes.go @@ -14,10 +14,10 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -125,7 +125,7 @@ func (ci *CreateIndexes) Execute(ctx context.Context) error { func (ci *CreateIndexes) command(dst []byte, desc description.SelectedServer) ([]byte, error) { dst = bsoncore.AppendStringElement(dst, "createIndexes", ci.collection) if ci.commitQuorum.Type != bsoncore.Type(0) { - if desc.WireVersion == nil || !desc.WireVersion.Includes(9) { + if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 9) { return nil, errors.New("the 'commitQuorum' command parameter requires a minimum server wire version of 9") } dst = bsoncore.AppendValueElement(dst, "commitQuorum", ci.commitQuorum) diff --git a/x/mongo/driver/operation/create_search_indexes.go b/x/mongo/driver/operation/create_search_indexes.go index cb0d807952..8856651e6e 100644 --- a/x/mongo/driver/operation/create_search_indexes.go +++ b/x/mongo/driver/operation/create_search_indexes.go @@ -14,9 +14,9 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/delete.go b/x/mongo/driver/operation/delete.go index e3546c7ad5..4214393017 100644 --- a/x/mongo/driver/operation/delete.go +++ b/x/mongo/driver/operation/delete.go @@ -15,10 +15,10 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/logger" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -128,7 +128,7 @@ func (d *Delete) command(dst []byte, desc description.SelectedServer) ([]byte, e dst = bsoncore.AppendBooleanElement(dst, "ordered", *d.ordered) } if d.hint != nil && *d.hint { - if desc.WireVersion == nil || !desc.WireVersion.Includes(5) { + if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) { return nil, errors.New("the 'hint' command parameter requires a minimum server wire version of 5") } if !d.writeConcern.Acknowledged() { diff --git a/x/mongo/driver/operation/distinct.go b/x/mongo/driver/operation/distinct.go index af65eb7d31..a13bd2b7b4 100644 --- a/x/mongo/driver/operation/distinct.go +++ b/x/mongo/driver/operation/distinct.go @@ -13,11 +13,11 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -113,7 +113,7 @@ func (d *Distinct) Execute(ctx context.Context) error { func (d *Distinct) command(dst []byte, desc description.SelectedServer) ([]byte, error) { dst = bsoncore.AppendStringElement(dst, "distinct", d.collection) if d.collation != nil { - if desc.WireVersion == nil || !desc.WireVersion.Includes(5) { + if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) { return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5") } dst = bsoncore.AppendDocumentElement(dst, "collation", d.collation) diff --git a/x/mongo/driver/operation/drop_collection.go b/x/mongo/driver/operation/drop_collection.go index 8c65967564..5e98886024 100644 --- a/x/mongo/driver/operation/drop_collection.go +++ b/x/mongo/driver/operation/drop_collection.go @@ -14,10 +14,10 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/drop_database.go b/x/mongo/driver/operation/drop_database.go index a8f9b45ba4..a2c3daae58 100644 --- a/x/mongo/driver/operation/drop_database.go +++ b/x/mongo/driver/operation/drop_database.go @@ -12,10 +12,10 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/drop_indexes.go b/x/mongo/driver/operation/drop_indexes.go index 0c3d459707..597d04ac88 100644 --- a/x/mongo/driver/operation/drop_indexes.go +++ b/x/mongo/driver/operation/drop_indexes.go @@ -14,10 +14,10 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/drop_search_index.go b/x/mongo/driver/operation/drop_search_index.go index 3992c83165..d060234360 100644 --- a/x/mongo/driver/operation/drop_search_index.go +++ b/x/mongo/driver/operation/drop_search_index.go @@ -13,9 +13,9 @@ import ( "time" "go.mongodb.org/mongo-driver/event" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/end_sessions.go b/x/mongo/driver/operation/end_sessions.go index 52f300bb7f..a96cb2789b 100644 --- a/x/mongo/driver/operation/end_sessions.go +++ b/x/mongo/driver/operation/end_sessions.go @@ -12,9 +12,9 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/find.go b/x/mongo/driver/operation/find.go index d6cd90aaaa..1e34b8da8a 100644 --- a/x/mongo/driver/operation/find.go +++ b/x/mongo/driver/operation/find.go @@ -14,11 +14,11 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/logger" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -115,7 +115,7 @@ func (f *Find) Execute(ctx context.Context) error { func (f *Find) command(dst []byte, desc description.SelectedServer) ([]byte, error) { dst = bsoncore.AppendStringElement(dst, "find", f.collection) if f.allowDiskUse != nil { - if desc.WireVersion == nil || !desc.WireVersion.Includes(4) { + if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 4) { return nil, errors.New("the 'allowDiskUse' command parameter requires a minimum server wire version of 4") } dst = bsoncore.AppendBooleanElement(dst, "allowDiskUse", *f.allowDiskUse) @@ -130,7 +130,7 @@ func (f *Find) command(dst []byte, desc description.SelectedServer) ([]byte, err dst = bsoncore.AppendInt32Element(dst, "batchSize", *f.batchSize) } if f.collation != nil { - if desc.WireVersion == nil || !desc.WireVersion.Includes(5) { + if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) { return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5") } dst = bsoncore.AppendDocumentElement(dst, "collation", f.collation) diff --git a/x/mongo/driver/operation/find_and_modify.go b/x/mongo/driver/operation/find_and_modify.go index 5af8a0c105..12d241f710 100644 --- a/x/mongo/driver/operation/find_and_modify.go +++ b/x/mongo/driver/operation/find_and_modify.go @@ -15,10 +15,10 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -152,7 +152,7 @@ func (fam *FindAndModify) command(dst []byte, desc description.SelectedServer) ( dst = bsoncore.AppendStringElement(dst, "findAndModify", fam.collection) if fam.arrayFilters != nil { - if desc.WireVersion == nil || !desc.WireVersion.Includes(6) { + if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 6) { return nil, errors.New("the 'arrayFilters' command parameter requires a minimum server wire version of 6") } dst = bsoncore.AppendArrayElement(dst, "arrayFilters", fam.arrayFilters) @@ -163,7 +163,7 @@ func (fam *FindAndModify) command(dst []byte, desc description.SelectedServer) ( } if fam.collation != nil { - if desc.WireVersion == nil || !desc.WireVersion.Includes(5) { + if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) { return nil, errors.New("the 'collation' command parameter requires a minimum server wire version of 5") } dst = bsoncore.AppendDocumentElement(dst, "collation", fam.collation) @@ -200,7 +200,7 @@ func (fam *FindAndModify) command(dst []byte, desc description.SelectedServer) ( } if fam.hint.Type != bsoncore.Type(0) { - if desc.WireVersion == nil || !desc.WireVersion.Includes(8) { + if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 8) { return nil, errors.New("the 'hint' command parameter requires a minimum server wire version of 8") } if !fam.writeConcern.Acknowledged() { diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 5a9d9bb36b..8e6c59de38 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -19,10 +19,10 @@ import ( "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/handshake" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/version" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -123,7 +123,7 @@ func (h *Hello) LoadBalanced(lb bool) *Hello { // Result returns the result of executing this operation. func (h *Hello) Result(addr address.Address) description.Server { - return description.NewServer(addr, bson.Raw(h.res)) + return driverutil.NewServerDescription(addr, bson.Raw(h.res)) } const dockerEnvPath = "/.dockerenv" diff --git a/x/mongo/driver/operation/insert.go b/x/mongo/driver/operation/insert.go index 7da70653ee..0fe6ca8c82 100644 --- a/x/mongo/driver/operation/insert.go +++ b/x/mongo/driver/operation/insert.go @@ -15,10 +15,10 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/logger" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -120,7 +120,9 @@ func (i *Insert) Execute(ctx context.Context) error { func (i *Insert) command(dst []byte, desc description.SelectedServer) ([]byte, error) { dst = bsoncore.AppendStringElement(dst, "insert", i.collection) - if i.bypassDocumentValidation != nil && (desc.WireVersion != nil && desc.WireVersion.Includes(4)) { + if i.bypassDocumentValidation != nil && (desc.WireVersion != nil && + driverutil.VersionRangeIncludes(*desc.WireVersion, 4)) { + dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *i.bypassDocumentValidation) } if i.comment.Type != bsoncore.Type(0) { diff --git a/x/mongo/driver/operation/listDatabases.go b/x/mongo/driver/operation/listDatabases.go index c70248e2a9..2dc4946f20 100644 --- a/x/mongo/driver/operation/listDatabases.go +++ b/x/mongo/driver/operation/listDatabases.go @@ -15,10 +15,10 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/list_collections.go b/x/mongo/driver/operation/list_collections.go index 6fe68fa033..f208915f87 100644 --- a/x/mongo/driver/operation/list_collections.go +++ b/x/mongo/driver/operation/list_collections.go @@ -13,10 +13,10 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/list_indexes.go b/x/mongo/driver/operation/list_indexes.go index 79d50eca95..d4cbe8a337 100644 --- a/x/mongo/driver/operation/list_indexes.go +++ b/x/mongo/driver/operation/list_indexes.go @@ -13,9 +13,9 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation/update.go b/x/mongo/driver/operation/update.go index 65c33bc5f6..d470e82d21 100644 --- a/x/mongo/driver/operation/update.go +++ b/x/mongo/driver/operation/update.go @@ -16,10 +16,10 @@ import ( "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/logger" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) @@ -173,7 +173,7 @@ func (u *Update) Execute(ctx context.Context) error { func (u *Update) command(dst []byte, desc description.SelectedServer) ([]byte, error) { dst = bsoncore.AppendStringElement(dst, "update", u.collection) if u.bypassDocumentValidation != nil && - (desc.WireVersion != nil && desc.WireVersion.Includes(4)) { + (desc.WireVersion != nil && driverutil.VersionRangeIncludes(*desc.WireVersion, 4)) { dst = bsoncore.AppendBooleanElement(dst, "bypassDocumentValidation", *u.bypassDocumentValidation) } @@ -186,7 +186,7 @@ func (u *Update) command(dst []byte, desc description.SelectedServer) ([]byte, e } if u.hint != nil && *u.hint { - if desc.WireVersion == nil || !desc.WireVersion.Includes(5) { + if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 5) { return nil, errors.New("the 'hint' command parameter requires a minimum server wire version of 5") } if !u.writeConcern.Acknowledged() { @@ -194,7 +194,7 @@ func (u *Update) command(dst []byte, desc description.SelectedServer) ([]byte, e } } if u.arrayFilters != nil && *u.arrayFilters { - if desc.WireVersion == nil || !desc.WireVersion.Includes(6) { + if desc.WireVersion == nil || !driverutil.VersionRangeIncludes(*desc.WireVersion, 6) { return nil, errors.New("the 'arrayFilters' command parameter requires a minimum server wire version of 6") } } diff --git a/x/mongo/driver/operation/update_search_index.go b/x/mongo/driver/operation/update_search_index.go index 64f2da7f6f..f9f238f409 100644 --- a/x/mongo/driver/operation/update_search_index.go +++ b/x/mongo/driver/operation/update_search_index.go @@ -13,9 +13,9 @@ import ( "time" "go.mongodb.org/mongo-driver/event" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/session" ) diff --git a/x/mongo/driver/operation_test.go b/x/mongo/driver/operation_test.go index 87c629c662..0e3da7007c 100644 --- a/x/mongo/driver/operation_test.go +++ b/x/mongo/driver/operation_test.go @@ -21,12 +21,12 @@ import ( "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/internal/uuid" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/tag" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/session" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" @@ -164,7 +164,7 @@ func TestOperation(t *testing.T) { descNotRetryableStandalone := description.Server{ WireVersion: &description.VersionRange{Min: 6, Max: 21}, SessionTimeoutMinutes: int64ToPtr(1), - Kind: description.Standalone, + Kind: description.ServerKindStandalone, } testCases := []struct { @@ -464,20 +464,20 @@ func TestOperation(t *testing.T) { opQuery bool want bsoncore.Document }{ - {"nil/single/mongos", nil, description.Mongos, description.Single, false, nil}, - {"nil/single/secondary", nil, description.RSSecondary, description.Single, false, rpPrimaryPreferred}, - {"primary/mongos", readpref.Primary(), description.Mongos, description.Sharded, false, nil}, - {"primary/single", readpref.Primary(), description.RSPrimary, description.Single, false, rpPrimaryPreferred}, - {"primary/primary", readpref.Primary(), description.RSPrimary, description.ReplicaSet, false, nil}, - {"primaryPreferred", readpref.PrimaryPreferred(), description.RSSecondary, description.ReplicaSet, false, rpPrimaryPreferred}, - {"secondaryPreferred/mongos/opquery", readpref.SecondaryPreferred(), description.Mongos, description.Sharded, true, nil}, - {"secondaryPreferred", readpref.SecondaryPreferred(), description.RSSecondary, description.ReplicaSet, false, rpSecondaryPreferred}, - {"secondary", readpref.Secondary(), description.RSSecondary, description.ReplicaSet, false, rpSecondary}, - {"nearest", readpref.Nearest(), description.RSSecondary, description.ReplicaSet, false, rpNearest}, + {"nil/single/mongos", nil, description.ServerKindMongos, description.TopologyKindSingle, false, nil}, + {"nil/single/secondary", nil, description.ServerKindRSSecondary, description.TopologyKindSingle, false, rpPrimaryPreferred}, + {"primary/mongos", readpref.Primary(), description.ServerKindMongos, description.TopologyKindSharded, false, nil}, + {"primary/single", readpref.Primary(), description.ServerKindRSPrimary, description.TopologyKindSingle, false, rpPrimaryPreferred}, + {"primary/primary", readpref.Primary(), description.ServerKindRSPrimary, description.TopologyKindReplicaSet, false, nil}, + {"primaryPreferred", readpref.PrimaryPreferred(), description.ServerKindRSSecondary, description.TopologyKindReplicaSet, false, rpPrimaryPreferred}, + {"secondaryPreferred/mongos/opquery", readpref.SecondaryPreferred(), description.ServerKindMongos, description.TopologyKindSharded, true, nil}, + {"secondaryPreferred", readpref.SecondaryPreferred(), description.ServerKindRSSecondary, description.TopologyKindReplicaSet, false, rpSecondaryPreferred}, + {"secondary", readpref.Secondary(), description.ServerKindRSSecondary, description.TopologyKindReplicaSet, false, rpSecondary}, + {"nearest", readpref.Nearest(), description.ServerKindRSSecondary, description.TopologyKindReplicaSet, false, rpNearest}, { "secondaryPreferred/withTags", readpref.SecondaryPreferred(readpref.WithTags("disk", "ssd", "use", "reporting")), - description.RSSecondary, description.ReplicaSet, false, rpWithTags, + description.ServerKindRSSecondary, description.TopologyKindReplicaSet, false, rpWithTags, }, // GODRIVER-2205: Ensure empty tag sets are written as an empty document in the read // preference document. Empty tag sets match any server and are used as a fallback when @@ -487,8 +487,8 @@ func TestOperation(t *testing.T) { readpref.SecondaryPreferred(readpref.WithTagSets( tag.Set{{Name: "disk", Value: "ssd"}}, tag.Set{})), - description.RSSecondary, - description.ReplicaSet, + description.ServerKindRSSecondary, + description.TopologyKindReplicaSet, false, bsoncore.NewDocumentBuilder(). AppendString("mode", "secondaryPreferred"). @@ -501,14 +501,14 @@ func TestOperation(t *testing.T) { { "secondaryPreferred/withMaxStaleness", readpref.SecondaryPreferred(readpref.WithMaxStaleness(25 * time.Second)), - description.RSSecondary, description.ReplicaSet, false, rpWithMaxStaleness, + description.ServerKindRSSecondary, description.TopologyKindReplicaSet, false, rpWithMaxStaleness, }, { // A read preference document is generated for SecondaryPreferred if the hedge document is non-nil. "secondaryPreferred with hedge to mongos using OP_QUERY", readpref.SecondaryPreferred(readpref.WithHedgeEnabled(true)), - description.Mongos, - description.Sharded, + description.ServerKindMongos, + description.TopologyKindSharded, true, rpWithHedge, }, @@ -519,8 +519,8 @@ func TestOperation(t *testing.T) { readpref.WithMaxStaleness(25*time.Second), readpref.WithHedgeEnabled(false), ), - description.RSSecondary, - description.ReplicaSet, + description.ServerKindRSSecondary, + description.TopologyKindReplicaSet, false, rpWithAllOptions, }, @@ -544,8 +544,8 @@ func TestOperation(t *testing.T) { t.Run("description.SelectedServer", func(t *testing.T) { want := wiremessage.SecondaryOK desc := description.SelectedServer{ - Kind: description.Single, - Server: description.Server{Kind: description.RSSecondary}, + Kind: description.TopologyKindSingle, + Server: description.Server{Kind: description.ServerKindRSSecondary}, } got := Operation{}.secondaryOK(desc) if got != want { diff --git a/x/mongo/driver/session/client_session.go b/x/mongo/driver/session/client_session.go index abc7a64efd..d535ec54c9 100644 --- a/x/mongo/driver/session/client_session.go +++ b/x/mongo/driver/session/client_session.go @@ -12,11 +12,12 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/uuid" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/readconcern" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" ) @@ -116,7 +117,7 @@ type Client struct { pool *Pool TransactionState TransactionState - PinnedServer *description.Server + PinnedServerAddr *address.Address RecoveryToken bson.Raw PinnedConnection LoadBalancedTransactionConnection SnapshotTime *bson.Timestamp @@ -305,7 +306,7 @@ func (c *Client) ClearPinnedResources() error { return nil } - c.PinnedServer = nil + c.PinnedServerAddr = nil if c.PinnedConnection != nil { if err := c.PinnedConnection.UnpinFromTransaction(); err != nil { return err @@ -514,8 +515,8 @@ func (c *Client) ApplyCommand(desc description.Server) error { if c.TransactionState == Starting { c.TransactionState = InProgress // If this is in a transaction and the server is a mongos, pin it - if desc.Kind == description.Mongos { - c.PinnedServer = &desc + if desc.Kind == description.ServerKindMongos { + c.PinnedServerAddr = &desc.Addr } } else if c.TransactionState == Committed || c.TransactionState == Aborted { c.TransactionState = None diff --git a/x/mongo/driver/session/client_session_test.go b/x/mongo/driver/session/client_session_test.go index e361d2f2fc..bdf8889732 100644 --- a/x/mongo/driver/session/client_session_test.go +++ b/x/mongo/driver/session/client_session_test.go @@ -15,8 +15,8 @@ import ( "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/internal/uuid" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) var consistent = true @@ -150,7 +150,7 @@ func TestClientSession(t *testing.T) { t.Errorf("expected error, got %v", err) } - err = sess.ApplyCommand(description.Server{Kind: description.Standalone}) + err = sess.ApplyCommand(description.Server{Kind: description.ServerKindStandalone}) assert.Nil(t, err, "ApplyCommand error: %v", err) if sess.TransactionState != InProgress { t.Errorf("incorrect session state, expected InProgress, received %v", sess.TransactionState) diff --git a/x/mongo/driver/session/server_session.go b/x/mongo/driver/session/server_session.go index b1e45552a7..73fce5f21d 100644 --- a/x/mongo/driver/session/server_session.go +++ b/x/mongo/driver/session/server_session.go @@ -10,8 +10,8 @@ import ( "time" "go.mongodb.org/mongo-driver/internal/uuid" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) // Server is an open session with the server. @@ -27,7 +27,7 @@ type Server struct { func (ss *Server) expired(topoDesc topologyDescription) bool { // There is no server monitoring in LB mode, so we do not track session timeout minutes from server hello responses // and never consider sessions to be expired. - if topoDesc.kind == description.LoadBalanced { + if topoDesc.kind == description.TopologyKindLoadBalanced { return false } diff --git a/x/mongo/driver/session/server_session_test.go b/x/mongo/driver/session/server_session_test.go index db4fe58ddf..b89d0963ff 100644 --- a/x/mongo/driver/session/server_session_test.go +++ b/x/mongo/driver/session/server_session_test.go @@ -11,7 +11,7 @@ import ( "time" "go.mongodb.org/mongo-driver/internal/assert" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) func TestServerSession(t *testing.T) { @@ -33,7 +33,7 @@ func TestServerSession(t *testing.T) { assert.Nil(t, err, "newServerSession error: %v", err) // The session should never be considered expired. - topoDesc := topologyDescription{kind: description.LoadBalanced} + topoDesc := topologyDescription{kind: description.TopologyKindLoadBalanced} assert.False(t, sess.expired(topoDesc), "session reported that it was expired in LB mode with timeoutMinutes=0") sess.LastUsed = time.Now().Add(-30 * time.Minute) diff --git a/x/mongo/driver/session/session_pool.go b/x/mongo/driver/session/session_pool.go index 8b3cb8bebd..2612540ba5 100644 --- a/x/mongo/driver/session/session_pool.go +++ b/x/mongo/driver/session/session_pool.go @@ -10,8 +10,8 @@ import ( "sync" "sync/atomic" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) // Node represents a server session in a linked list diff --git a/x/mongo/driver/session/session_pool_test.go b/x/mongo/driver/session/session_pool_test.go index 730bf83ead..e645acdc09 100644 --- a/x/mongo/driver/session/session_pool_test.go +++ b/x/mongo/driver/session/session_pool_test.go @@ -11,7 +11,7 @@ import ( "testing" "go.mongodb.org/mongo-driver/internal/assert" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) func TestSessionPool(t *testing.T) { diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index 37e0027bff..43d45c1515 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -18,10 +18,11 @@ import ( "sync/atomic" "time" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/ocsp" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" @@ -129,7 +130,7 @@ func (c *connection) hasGenerationNumber() bool { // For LB clusters, we set the generation after the initial handshake, so we know it's set if the connection // description has been updated to reflect that it's behind an LB. - return c.desc.LoadBalanced() + return driverutil.IsServerLoadBalanced(c.desc) } // connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization diff --git a/x/mongo/driver/topology/connection_test.go b/x/mongo/driver/topology/connection_test.go index 0294a35be5..51bba47419 100644 --- a/x/mongo/driver/topology/connection_test.go +++ b/x/mongo/driver/topology/connection_test.go @@ -20,8 +20,8 @@ import ( "github.com/google/go-cmp/cmp" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) diff --git a/x/mongo/driver/topology/diff.go b/x/mongo/driver/topology/diff.go index b9bf2c14c7..3bb7e55ed6 100644 --- a/x/mongo/driver/topology/diff.go +++ b/x/mongo/driver/topology/diff.go @@ -6,7 +6,7 @@ package topology -import "go.mongodb.org/mongo-driver/mongo/description" +import "go.mongodb.org/mongo-driver/x/mongo/driver/description" // hostlistDiff is the difference between a topology and a host list. type hostlistDiff struct { diff --git a/x/mongo/driver/topology/diff_test.go b/x/mongo/driver/topology/diff_test.go index 93958d0cd0..ab0f85b075 100644 --- a/x/mongo/driver/topology/diff_test.go +++ b/x/mongo/driver/topology/diff_test.go @@ -10,7 +10,7 @@ import ( "testing" "go.mongodb.org/mongo-driver/internal/assert" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) func TestDiffHostList(t *testing.T) { diff --git a/x/mongo/driver/topology/errors.go b/x/mongo/driver/topology/errors.go index a6630aae76..5e7c4e0f53 100644 --- a/x/mongo/driver/topology/errors.go +++ b/x/mongo/driver/topology/errors.go @@ -12,7 +12,7 @@ import ( "fmt" "time" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) // ConnectionError represents a connection error. diff --git a/x/mongo/driver/topology/fsm.go b/x/mongo/driver/topology/fsm.go index 30ebfe1886..7d1d209e65 100644 --- a/x/mongo/driver/topology/fsm.go +++ b/x/mongo/driver/topology/fsm.go @@ -12,9 +12,10 @@ import ( "sync/atomic" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/ptrutil" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) var ( @@ -22,7 +23,7 @@ var ( MinSupportedMongoDBVersion = "3.6" // SupportedWireVersions is the range of wire versions supported by the driver. - SupportedWireVersions = description.NewVersionRange(6, 21) + SupportedWireVersions = driverutil.NewVersionRange(6, 21) ) type fsm struct { @@ -39,6 +40,14 @@ func newFSM() *fsm { return &f } +// isServerDataBearing returns true if the server is a data bearing server. +func isServerDataBearing(srv description.Server) bool { + return srv.Kind == description.ServerKindRSPrimary || + srv.Kind == description.ServerKindRSSecondary || + srv.Kind == description.ServerKindMongos || + srv.Kind == description.ServerKindStandalone +} + // selectFSMSessionTimeout selects the timeout to return for the topology's // finite state machine. If the logicalSessionTimeoutMinutes on the FSM exists // and the server is data-bearing, then we determine this value by returning @@ -64,7 +73,7 @@ func selectFSMSessionTimeout(f *fsm, s description.Server) *int64 { // 2. non-nil while the server timeout is nil // // then return the server timeout. - if s.DataBearing() && (comp == 1 || comp == 2) { + if isServerDataBearing(s) && (comp == 1 || comp == 2) { return s.SessionTimeoutMinutes } @@ -79,7 +88,7 @@ func selectFSMSessionTimeout(f *fsm, s description.Server) *int64 { for _, server := range f.Servers { // If the server is not data-bearing, then we do not consider // it's timeout whether set or not. - if !server.DataBearing() { + if !isServerDataBearing(server) { continue } @@ -126,13 +135,13 @@ func (f *fsm) apply(s description.Server) (description.Topology, description.Ser switch f.Kind { case description.Unknown: updatedDesc = f.applyToUnknown(s) - case description.Sharded: + case description.TopologyKindSharded: updatedDesc = f.applyToSharded(s) - case description.ReplicaSetNoPrimary: + case description.TopologyKindReplicaSetNoPrimary: updatedDesc = f.applyToReplicaSetNoPrimary(s) - case description.ReplicaSetWithPrimary: + case description.TopologyKindReplicaSetWithPrimary: updatedDesc = f.applyToReplicaSetWithPrimary(s) - case description.Single: + case description.TopologyKindSingle: updatedDesc = f.applyToSingle(s) } @@ -174,13 +183,13 @@ func (f *fsm) apply(s description.Server) (description.Topology, description.Ser func (f *fsm) applyToReplicaSetNoPrimary(s description.Server) description.Server { switch s.Kind { - case description.Standalone, description.Mongos: + case description.ServerKindStandalone, description.ServerKindMongos: f.removeServerByAddr(s.Addr) - case description.RSPrimary: + case description.ServerKindRSPrimary: f.updateRSFromPrimary(s) - case description.RSSecondary, description.RSArbiter, description.RSMember: + case description.ServerKindRSSecondary, description.ServerKindRSArbiter, description.ServerKindRSMember: f.updateRSWithoutPrimary(s) - case description.Unknown, description.RSGhost: + case description.Unknown, description.ServerKindRSGhost: f.replaceServer(s) } @@ -189,14 +198,14 @@ func (f *fsm) applyToReplicaSetNoPrimary(s description.Server) description.Serve func (f *fsm) applyToReplicaSetWithPrimary(s description.Server) description.Server { switch s.Kind { - case description.Standalone, description.Mongos: + case description.ServerKindStandalone, description.ServerKindMongos: f.removeServerByAddr(s.Addr) f.checkIfHasPrimary() - case description.RSPrimary: + case description.ServerKindRSPrimary: f.updateRSFromPrimary(s) - case description.RSSecondary, description.RSArbiter, description.RSMember: + case description.ServerKindRSSecondary, description.ServerKindRSArbiter, description.ServerKindRSMember: f.updateRSWithPrimaryFromMember(s) - case description.Unknown, description.RSGhost: + case description.Unknown, description.ServerKindRSGhost: f.replaceServer(s) f.checkIfHasPrimary() } @@ -206,9 +215,11 @@ func (f *fsm) applyToReplicaSetWithPrimary(s description.Server) description.Ser func (f *fsm) applyToSharded(s description.Server) description.Server { switch s.Kind { - case description.Mongos, description.Unknown: + case description.ServerKindMongos, description.Unknown: f.replaceServer(s) - case description.Standalone, description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost: + case description.ServerKindStandalone, description.ServerKindRSPrimary, + description.ServerKindRSSecondary, description.ServerKindRSArbiter, description.ServerKindRSMember, + description.ServerKindRSGhost: f.removeServerByAddr(s.Addr) } @@ -219,14 +230,15 @@ func (f *fsm) applyToSingle(s description.Server) description.Server { switch s.Kind { case description.Unknown: f.replaceServer(s) - case description.Standalone, description.Mongos: + case description.ServerKindStandalone, description.ServerKindMongos: if f.SetName != "" { f.removeServerByAddr(s.Addr) return s } f.replaceServer(s) - case description.RSPrimary, description.RSSecondary, description.RSArbiter, description.RSMember, description.RSGhost: + case description.ServerKindRSPrimary, description.ServerKindRSSecondary, + description.ServerKindRSArbiter, description.ServerKindRSMember, description.ServerKindRSGhost: // A replica set name can be provided when creating a direct connection. In this case, if the set name returned // by the hello response doesn't match up with the one provided during configuration, the server description // is replaced with a default Unknown description. @@ -248,17 +260,17 @@ func (f *fsm) applyToSingle(s description.Server) description.Server { func (f *fsm) applyToUnknown(s description.Server) description.Server { switch s.Kind { - case description.Mongos: - f.setKind(description.Sharded) + case description.ServerKindMongos: + f.setKind(description.TopologyKindSharded) f.replaceServer(s) - case description.RSPrimary: + case description.ServerKindRSPrimary: f.updateRSFromPrimary(s) - case description.RSSecondary, description.RSArbiter, description.RSMember: - f.setKind(description.ReplicaSetNoPrimary) + case description.ServerKindRSSecondary, description.ServerKindRSArbiter, description.ServerKindRSMember: + f.setKind(description.TopologyKindReplicaSetNoPrimary) f.updateRSWithoutPrimary(s) - case description.Standalone: + case description.ServerKindStandalone: f.updateUnknownWithStandalone(s) - case description.Unknown, description.RSGhost: + case description.Unknown, description.ServerKindRSGhost: f.replaceServer(s) } @@ -267,9 +279,9 @@ func (f *fsm) applyToUnknown(s description.Server) description.Server { func (f *fsm) checkIfHasPrimary() { if _, ok := f.findPrimary(); ok { - f.setKind(description.ReplicaSetWithPrimary) + f.setKind(description.TopologyKindReplicaSetWithPrimary) } else { - f.setKind(description.ReplicaSetNoPrimary) + f.setKind(description.TopologyKindReplicaSetNoPrimary) } } @@ -395,7 +407,7 @@ func (f *fsm) updateRSWithPrimaryFromMember(s description.Server) { f.replaceServer(s) if _, ok := f.findPrimary(); !ok { - f.setKind(description.ReplicaSetNoPrimary) + f.setKind(description.TopologyKindReplicaSetNoPrimary) } } @@ -427,7 +439,7 @@ func (f *fsm) updateUnknownWithStandalone(s description.Server) { return } - f.setKind(description.Single) + f.setKind(description.TopologyKindSingle) f.replaceServer(s) } @@ -439,7 +451,7 @@ func (f *fsm) addServer(addr address.Address) { func (f *fsm) findPrimary() (int, bool) { for i, s := range f.Servers { - if s.Kind == description.RSPrimary { + if s.Kind == description.ServerKindRSPrimary { return i, true } } diff --git a/x/mongo/driver/topology/fsm_test.go b/x/mongo/driver/topology/fsm_test.go index 390ffebc32..b84be8fe30 100644 --- a/x/mongo/driver/topology/fsm_test.go +++ b/x/mongo/driver/topology/fsm_test.go @@ -11,7 +11,7 @@ import ( "testing" "go.mongodb.org/mongo-driver/internal/assert" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) func TestFSMSessionTimeout(t *testing.T) { @@ -39,7 +39,7 @@ func TestFSMSessionTimeout(t *testing.T) { }, }, s: description.Server{ - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, }, want: nil, }, @@ -51,7 +51,7 @@ func TestFSMSessionTimeout(t *testing.T) { }, }, s: description.Server{ - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(1), }, want: int64ToPtr(1), @@ -60,7 +60,7 @@ func TestFSMSessionTimeout(t *testing.T) { name: "session support on data-bearing server with no session support on fsm with no servers", f: &fsm{Topology: description.Topology{}}, s: description.Server{ - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(1), }, want: int64ToPtr(1), @@ -70,13 +70,13 @@ func TestFSMSessionTimeout(t *testing.T) { f: &fsm{Topology: description.Topology{ Servers: []description.Server{ { - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(1), }, }, }}, s: description.Server{ - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(2), }, want: int64ToPtr(1), @@ -86,13 +86,13 @@ func TestFSMSessionTimeout(t *testing.T) { f: &fsm{Topology: description.Topology{ Servers: []description.Server{ { - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(3), }, }, }}, s: description.Server{ - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(2), }, want: int64ToPtr(2), diff --git a/x/mongo/driver/topology/polling_srv_records_test.go b/x/mongo/driver/topology/polling_srv_records_test.go index df0704afc6..b824de88e7 100644 --- a/x/mongo/driver/topology/polling_srv_records_test.go +++ b/x/mongo/driver/topology/polling_srv_records_test.go @@ -18,8 +18,8 @@ import ( "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/dns" ) @@ -181,7 +181,7 @@ func TestPollSRVRecords(t *testing.T) { err = topo.Connect() require.NoError(t, err, "Could not Connect to the topology: %v", err) topo.serversLock.Lock() - topo.fsm.Kind = description.Single + topo.fsm.Kind = description.TopologyKindSingle topo.desc.Store(description.Topology{ Kind: topo.fsm.Kind, Servers: topo.fsm.Servers, diff --git a/x/mongo/driver/topology/sdam_spec_test.go b/x/mongo/driver/topology/sdam_spec_test.go index 2dac5e7b92..65e3213a3c 100644 --- a/x/mongo/driver/topology/sdam_spec_test.go +++ b/x/mongo/driver/topology/sdam_spec_test.go @@ -19,12 +19,13 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/assert" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/spectest" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) type response struct { @@ -255,7 +256,7 @@ func applyResponses(t *testing.T, topo *Topology, responses []response, sub *dri assert.Nil(t, err, "Marshal error: %v", err) addr := address.Address(response.Host) - desc := description.NewServer(addr, doc) + desc := driverutil.NewServerDescription(addr, doc) server, ok := topo.servers[addr] if ok { server.updateDescription(desc) @@ -312,7 +313,7 @@ func applyErrors(t *testing.T, topo *Topology, errors []applicationError) { assert.True(t, ok, "server not found: %v", appErr.Address) desc := server.Description() - versionRange := description.NewVersionRange(0, *appErr.MaxWireVersion) + versionRange := driverutil.NewVersionRange(0, *appErr.MaxWireVersion) desc.WireVersion = &versionRange generation, _ := server.pool.generation.getGeneration(nil) @@ -339,7 +340,7 @@ func applyErrors(t *testing.T, topo *Topology, errors []applicationError) { } func compareServerDescriptions(t *testing.T, - expected serverDescription, actual description.Server, idx int) { + expected serverDescription, actual event.ServerDescription, idx int) { t.Helper() assert.Equal(t, expected.Address, actual.Addr.String(), @@ -368,16 +369,16 @@ func compareServerDescriptions(t *testing.T, if expected.Type == "PossiblePrimary" { expected.Type = "Unknown" } - assert.Equal(t, expected.Type, actual.Kind.String(), - "%v: expected server kind %s, got %s", idx, expected.Type, actual.Kind.String()) + assert.Equal(t, expected.Type, actual.Kind, + "%v: expected server kind %s, got %s", idx, expected.Type, actual.Kind) } func compareTopologyDescriptions(t *testing.T, - expected topologyDescription, actual description.Topology, idx int) { + expected topologyDescription, actual event.TopologyDescription, idx int) { t.Helper() - assert.Equal(t, expected.TopologyType, actual.Kind.String(), - "%v: expected topology kind %s, got %s", idx, expected.TopologyType, actual.Kind.String()) + assert.Equal(t, expected.TopologyType, actual.Kind, + "%v: expected topology kind %s, got %s", idx, expected.TopologyType, actual.Kind) assert.Equal(t, len(expected.Servers), len(actual.Servers), "%v: expected %d servers, got %d", idx, len(expected.Servers), len(actual.Servers)) diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index bfb0c3ca76..862f9c6d48 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -20,9 +20,9 @@ import ( "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/logger" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/operation" ) @@ -50,6 +50,27 @@ func serverStateString(state int64) string { return "" } +// newServerDescriptionFromError creates a new unknown server description with +// the given parameters. +func newServerDescriptionFromError( + addr address.Address, + err error, + tv *description.TopologyVersion, +) description.Server { + return description.Server{ + Addr: addr, + LastError: err, + Kind: description.Unknown, + TopologyVersion: tv, + } +} + +// newDefaultServerDescription creates a new unknown server description with the +// given address. +func newDefaultServerDescription(addr address.Address) description.Server { + return newServerDescriptionFromError(addr, nil, nil) +} + var ( // ErrServerClosed occurs when an attempt to Get a connection is made after // the server has been closed. @@ -59,7 +80,7 @@ var ( ErrServerConnected = errors.New("server is connected") errCheckCancelled = errors.New("server check cancelled") - emptyDescription = description.NewDefaultServer("") + emptyDescription = newDefaultServerDescription("") ) // SelectedServer represents a specific server that was selected during server selection. @@ -170,7 +191,7 @@ func NewServer(addr address.Address, topologyID bson.ObjectID, opts ...ServerOpt globalCtx: globalCtx, globalCtxCancel: globalCtxCancel, } - s.desc.Store(description.NewDefaultServer(addr)) + s.desc.Store(newDefaultServerDescription(addr)) rttCfg := &rttConfig{ interval: cfg.heartbeatInterval, minRTTWindow: 5 * time.Minute, @@ -239,10 +260,10 @@ func (s *Server) Connect(updateCallback updateTopologyCallback) error { return ErrServerConnected } - desc := description.NewDefaultServer(s.address) + desc := newDefaultServerDescription(s.address) if s.cfg.loadBalanced { // LBs automatically start off with kind LoadBalancer because there is no monitoring routine for state changes. - desc.Kind = description.LoadBalancer + desc.Kind = description.ServerKindLoadBalancer } s.desc.Store(desc) s.updateTopologyCallback.Store(updateCallback) @@ -357,7 +378,7 @@ func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint6 // Since the only kind of ConnectionError we receive from pool.Get will be an initialization error, we should set // the description.Server appropriately. The description should not have a TopologyVersion because the staleness // checking logic above has already determined that this description is not stale. - s.updateDescription(description.NewServerFromError(s.address, wrappedConnErr, nil)) + s.updateDescription(newServerDescriptionFromError(s.address, wrappedConnErr, nil)) s.pool.clear(err, serviceID) s.cancelCheck() } @@ -374,7 +395,7 @@ func (s *Server) SelectedDescription() description.SelectedServer { sdesc := s.Description() return description.SelectedServer{ Server: sdesc, - Kind: description.Single, + Kind: description.TopologyKindSingle, } } @@ -476,7 +497,7 @@ func (s *Server) ProcessError(err error, describer mnet.Describer) driver.Proces // TODO(GODRIVER-2841): Remove this logic once we set the Server description when we create // TODO application connections because then the Server's topology version will always be the // TODO latest known. - if tv := connDesc.TopologyVersion; tv != nil && topologyVersion.CompareToIncoming(tv) < 0 { + if tv := connDesc.TopologyVersion; tv != nil && driverutil.CompareTopologyVersions(topologyVersion, tv) < 0 { topologyVersion = tv } @@ -484,12 +505,12 @@ func (s *Server) ProcessError(err error, describer mnet.Describer) driver.Proces // These errors can be reported as a command error or a write concern error. if cerr, ok := err.(driver.Error); ok && (cerr.NodeIsRecovering() || cerr.NotPrimary()) { // Ignore errors that came from when the database was on a previous topology version. - if topologyVersion.CompareToIncoming(cerr.TopologyVersion) >= 0 { + if driverutil.CompareTopologyVersions(topologyVersion, cerr.TopologyVersion) >= 0 { return driver.NoChange } // updates description to unknown - s.updateDescription(description.NewServerFromError(s.address, err, cerr.TopologyVersion)) + s.updateDescription(newServerDescriptionFromError(s.address, err, cerr.TopologyVersion)) s.RequestImmediateCheck() res := driver.ServerMarkedUnknown @@ -503,12 +524,12 @@ func (s *Server) ProcessError(err error, describer mnet.Describer) driver.Proces } if wcerr, ok := getWriteConcernErrorForProcessing(err); ok { // Ignore errors that came from when the database was on a previous topology version. - if topologyVersion.CompareToIncoming(wcerr.TopologyVersion) >= 0 { + if driverutil.CompareTopologyVersions(topologyVersion, wcerr.TopologyVersion) >= 0 { return driver.NoChange } // updates description to unknown - s.updateDescription(description.NewServerFromError(s.address, err, wcerr.TopologyVersion)) + s.updateDescription(newServerDescriptionFromError(s.address, err, wcerr.TopologyVersion)) s.RequestImmediateCheck() res := driver.ServerMarkedUnknown @@ -536,7 +557,7 @@ func (s *Server) ProcessError(err error, describer mnet.Describer) driver.Proces // For a non-timeout network error, we clear the pool, set the description to Unknown, and cancel the in-progress // monitoring check. The check is cancelled last to avoid a post-cancellation reconnect racing with // updateDescription. - s.updateDescription(description.NewServerFromError(s.address, err, nil)) + s.updateDescription(newServerDescriptionFromError(s.address, err, nil)) s.pool.clear(err, serviceID) s.cancelCheck() return driver.ConnectionPoolCleared @@ -923,8 +944,10 @@ func (s *Server) check() (description.Server, error) { if descPtr != nil { // The check was successful. Set the average RTT and the 90th percentile RTT and return. desc := *descPtr - desc = desc.SetAverageRTT(s.rttMonitor.EWMA()) + desc.AverageRTT = s.rttMonitor.EWMA() + desc.AverageRTTSet = true desc.HeartbeatInterval = s.cfg.heartbeatInterval + return desc, nil } @@ -938,7 +961,7 @@ func (s *Server) check() (description.Server, error) { // be cleared, but only after the description has already been updated, so that is handled by the caller. topologyVersion := extractTopologyVersion(err) s.rttMonitor.reset() - return description.NewServerFromError(s.address, err, topologyVersion), nil + return newServerDescriptionFromError(s.address, err, topologyVersion), nil } func extractTopologyVersion(err error) *description.TopologyVersion { @@ -1060,7 +1083,7 @@ func (s *Server) publishServerHeartbeatSucceededEvent(connectionID string, ) { serverHeartbeatSucceeded := &event.ServerHeartbeatSucceededEvent{ Duration: duration, - Reply: desc, + Reply: newEventServerDescription(desc), ConnectionID: connectionID, Awaited: await, } diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index 366fc83917..1fd20fdadb 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -30,10 +30,10 @@ import ( "go.mongodb.org/mongo-driver/internal/eventtest" "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" "go.mongodb.org/mongo-driver/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" @@ -516,7 +516,7 @@ func TestServer(t *testing.T) { return driver.HandshakeInformation{}, tc.getInfoErr } - desc := description.NewDefaultServer(addr) + desc := newServerDescriptionFromError(addr, nil, nil) if tc.loadBalanced { desc.ServiceID = &serviceID } @@ -870,27 +870,27 @@ func TestServer_ProcessError(t *testing.T) { { name: "nil error", startDescription: description.Server{ - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, }, inputErr: nil, want: driver.NoChange, wantGeneration: 0, wantDescription: description.Server{ - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, }, }, // Test that errors that occur on stale connections are ignored. { name: "stale connection", startDescription: description.Server{ - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, }, inputErr: errors.New("foo"), inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, true), want: driver.NoChange, wantGeneration: 0, wantDescription: description.Server{ - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, }, }, // Test that errors that do not indicate a database state change or connection error are @@ -898,7 +898,7 @@ func TestServer_ProcessError(t *testing.T) { { name: "non state change error", startDescription: description.Server{ - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, }, inputErr: driver.Error{ Code: 1, @@ -907,13 +907,13 @@ func TestServer_ProcessError(t *testing.T) { want: driver.NoChange, wantGeneration: 0, wantDescription: description.Server{ - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, }, }, // Test that a "not writable primary" error with an old topology version is ignored. { name: "stale not writable primary error", - startDescription: newServerDescription(description.RSPrimary, processID, 1, nil), + startDescription: newServerDescription(description.ServerKindRSPrimary, processID, 1, nil), inputErr: driver.Error{ Code: 10107, // NotWritablePrimary TopologyVersion: &description.TopologyVersion{ @@ -924,13 +924,13 @@ func TestServer_ProcessError(t *testing.T) { inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false), want: driver.NoChange, wantGeneration: 0, - wantDescription: newServerDescription(description.RSPrimary, processID, 1, nil), + wantDescription: newServerDescription(description.ServerKindRSPrimary, processID, 1, nil), }, // Test that a "not writable primary" error with an newer topology version marks the Server // as "unknown" and updates its topology version. { name: "new not writable primary error", - startDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + startDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), inputErr: driver.Error{ Code: 10107, // NotWritablePrimary TopologyVersion: &description.TopologyVersion{ @@ -953,7 +953,7 @@ func TestServer_ProcessError(t *testing.T) { // "unknown" and updates its topology version. { name: "new process ID not writable primary error", - startDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + startDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), inputErr: driver.Error{ Code: 10107, // NotWritablePrimary TopologyVersion: &description.TopologyVersion{ @@ -977,7 +977,7 @@ func TestServer_ProcessError(t *testing.T) { // TODO(GODRIVER-2841): Remove this test case. { name: "newer connection topology version", - startDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + startDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), inputErr: driver.Error{ Code: 10107, // NotWritablePrimary TopologyVersion: &description.TopologyVersion{ @@ -997,13 +997,13 @@ func TestServer_ProcessError(t *testing.T) { }), want: driver.NoChange, wantGeneration: 0, - wantDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + wantDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), }, // Test that a "node is shutting down" error with a newer topology version clears the // connection pool, marks the Server as "unknown", and updates its topology version. { name: "new shutdown error", - startDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + startDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), inputErr: driver.Error{ Code: 11600, // InterruptedAtShutdown TopologyVersion: &description.TopologyVersion{ @@ -1025,7 +1025,7 @@ func TestServer_ProcessError(t *testing.T) { // Test that a "not writable primary" error with a stale topology version is ignored. { name: "stale not writable primary write concern error", - startDescription: newServerDescription(description.RSPrimary, processID, 1, nil), + startDescription: newServerDescription(description.ServerKindRSPrimary, processID, 1, nil), inputErr: driver.WriteCommandError{ WriteConcernError: &driver.WriteConcernError{ Code: 10107, // NotWritablePrimary @@ -1038,13 +1038,13 @@ func TestServer_ProcessError(t *testing.T) { inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false), want: driver.NoChange, wantGeneration: 0, - wantDescription: newServerDescription(description.RSPrimary, processID, 1, nil), + wantDescription: newServerDescription(description.ServerKindRSPrimary, processID, 1, nil), }, // Test that a "not writable primary" error with a newer topology version marks the Server // as "unknown" and updates its topology version. { name: "new not writable primary write concern error", - startDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + startDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), inputErr: driver.WriteCommandError{ WriteConcernError: &driver.WriteConcernError{ Code: 10107, // NotWritablePrimary @@ -1071,7 +1071,7 @@ func TestServer_ProcessError(t *testing.T) { // local Server topology version mark the Server as "unknown" and clear the connection pool. { name: "new shutdown write concern error", - startDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + startDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), inputErr: driver.WriteCommandError{ WriteConcernError: &driver.WriteConcernError{ Code: 11600, // InterruptedAtShutdown @@ -1099,7 +1099,7 @@ func TestServer_ProcessError(t *testing.T) { // servers before 4.2 mark the Server as "unknown" and clear the connection pool. { name: "older than 4.2 write concern error", - startDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + startDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), inputErr: driver.WriteCommandError{ WriteConcernError: &driver.WriteConcernError{ Code: 10107, // NotWritablePrimary @@ -1125,7 +1125,7 @@ func TestServer_ProcessError(t *testing.T) { // Test that a network timeout error, such as a DNS lookup timeout error, is ignored. { name: "network timeout error", - startDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + startDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), inputErr: driver.Error{ Labels: []string{driver.NetworkError}, Wrapped: ConnectionError{ @@ -1138,12 +1138,12 @@ func TestServer_ProcessError(t *testing.T) { inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false), want: driver.NoChange, wantGeneration: 0, - wantDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + wantDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), }, // Test that a context canceled error is ignored. { name: "context canceled error", - startDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + startDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), inputErr: driver.Error{ Labels: []string{driver.NetworkError}, Wrapped: ConnectionError{ @@ -1153,13 +1153,13 @@ func TestServer_ProcessError(t *testing.T) { inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false), want: driver.NoChange, wantGeneration: 0, - wantDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + wantDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), }, // Test that a non-timeout network error, such as an address lookup error, marks the server // as "unknown" and sets its topology version to nil. { name: "non-timeout network error", - startDescription: newServerDescription(description.RSPrimary, processID, 0, nil), + startDescription: newServerDescription(description.ServerKindRSPrimary, processID, 0, nil), inputErr: driver.Error{ Labels: []string{driver.NetworkError}, Wrapped: ConnectionError{ diff --git a/x/mongo/driver/topology/topology.go b/x/mongo/driver/topology/topology.go index 2a3521a893..d9e9de1f50 100644 --- a/x/mongo/driver/topology/topology.go +++ b/x/mongo/driver/topology/topology.go @@ -32,13 +32,14 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/internal/driverutil" "go.mongodb.org/mongo-driver/internal/logger" "go.mongodb.org/mongo-driver/internal/randutil" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" "go.mongodb.org/mongo-driver/x/mongo/driver/dns" ) @@ -300,17 +301,17 @@ func (t *Topology) Connect() error { // specified, in which case the initial type is Single. if t.cfg.ReplicaSetName != "" { t.fsm.SetName = t.cfg.ReplicaSetName - t.fsm.Kind = description.ReplicaSetNoPrimary + t.fsm.Kind = description.TopologyKindReplicaSetNoPrimary } // A direct connection unconditionally sets the topology type to Single. if t.cfg.Mode == SingleMode { - t.fsm.Kind = description.Single + t.fsm.Kind = description.TopologyKindSingle } for _, a := range t.cfg.SeedList { addr := address.Address(a).Canonicalize() - t.fsm.Servers = append(t.fsm.Servers, description.NewDefaultServer(addr)) + t.fsm.Servers = append(t.fsm.Servers, newServerDescriptionFromError(addr, nil, nil)) } switch { @@ -321,7 +322,7 @@ func (t *Topology) Connect() error { // monitoring routines in this mode, so we have to mock state changes. // Transition from Unknown with no servers to LoadBalanced with a single Unknown server. - t.fsm.Kind = description.LoadBalanced + t.fsm.Kind = description.TopologyKindLoadBalanced t.publishTopologyDescriptionChangedEvent(description.Topology{}, t.fsm.Topology) addr := address.Address(t.cfg.SeedList[0]).Canonicalize() @@ -742,7 +743,7 @@ func (t *Topology) selectServerFromDescription(desc description.Topology, // If the topology kind is LoadBalanced, the LB is the only server and it is always considered selectable. The // selectors exported by the driver should already return the LB as a candidate, so this but this check ensures that // the LB is always selectable even if a user of the low-level driver provides a custom selector. - if desc.Kind == description.LoadBalanced { + if desc.Kind == description.TopologyKindLoadBalanced { return desc.Servers, nil } @@ -790,7 +791,7 @@ func (t *Topology) pollSRVRecords(hosts string) { return } topoKind := t.Description().Kind - if !(topoKind == description.Unknown || topoKind == description.Sharded) { + if !(topoKind == description.Unknown || topoKind == description.TopologyKindSharded) { break } @@ -819,6 +820,38 @@ func (t *Topology) pollSRVRecords(hosts string) { doneOnce = true } +// equalTopologies compares two topology descriptions and returns true if they +// are equal. +func equalTopologies(topo1, topo2 description.Topology) bool { + if topo1.Kind != topo2.Kind { + return false + } + + topoServers := make(map[string]description.Server, len(topo1.Servers)) + for _, s := range topo1.Servers { + topoServers[s.Addr.String()] = s + } + + otherServers := make(map[string]description.Server, len(topo2.Servers)) + for _, s := range topo2.Servers { + otherServers[s.Addr.String()] = s + } + + if len(topoServers) != len(otherServers) { + return false + } + + for _, server := range topoServers { + otherServer := otherServers[server.Addr.String()] + + if !driverutil.EqualServers(server, otherServer) { + return false + } + } + + return true +} + func (t *Topology) processSRVResults(parsedHosts []string) bool { t.serversLock.Lock() defer t.serversLock.Unlock() @@ -875,7 +908,7 @@ func (t *Topology) processSRVResults(parsedHosts []string) bool { } t.desc.Store(newDesc) - if !prev.Equal(newDesc) { + if !equalTopologies(prev, newDesc) { t.publishTopologyDescriptionChangedEvent(prev, newDesc) } @@ -906,14 +939,14 @@ func (t *Topology) apply(ctx context.Context, desc description.Server) descripti prev := t.fsm.Topology oldDesc := t.fsm.Servers[ind] - if oldDesc.TopologyVersion.CompareToIncoming(desc.TopologyVersion) > 0 { + if driverutil.CompareTopologyVersions(oldDesc.TopologyVersion, desc.TopologyVersion) > 0 { return oldDesc } var current description.Topology current, desc = t.fsm.apply(desc) - if !oldDesc.Equal(desc) { + if !driverutil.EqualServers(oldDesc, desc) { t.publishServerDescriptionChangedEvent(oldDesc, desc) } @@ -936,7 +969,7 @@ func (t *Topology) apply(ctx context.Context, desc description.Server) descripti } t.desc.Store(current) - if !prev.Equal(current) { + if !equalTopologies(prev, current) { t.publishTopologyDescriptionChangedEvent(prev, current) } @@ -987,8 +1020,8 @@ func (t *Topology) publishServerDescriptionChangedEvent(prev description.Server, serverDescriptionChanged := &event.ServerDescriptionChangedEvent{ Address: current.Addr, TopologyID: t.id, - PreviousDescription: prev, - NewDescription: current, + PreviousDescription: newEventServerDescription(prev), + NewDescription: newEventServerDescription(current), } if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.ServerDescriptionChanged != nil { @@ -1026,8 +1059,8 @@ func (t *Topology) publishServerClosedEvent(addr address.Address) { func (t *Topology) publishTopologyDescriptionChangedEvent(prev description.Topology, current description.Topology) { topologyDescriptionChanged := &event.TopologyDescriptionChangedEvent{ TopologyID: t.id, - PreviousDescription: prev, - NewDescription: current, + PreviousDescription: newEventServerTopology(prev), + NewDescription: newEventServerTopology(current), } if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.TopologyDescriptionChanged != nil { @@ -1070,3 +1103,60 @@ func (t *Topology) publishTopologyClosedEvent() { logTopologyMessage(t, logger.LevelDebug, logger.TopologyClosed) } } + +func newEventServerDescription(srv description.Server) event.ServerDescription { + evtSrv := event.ServerDescription{ + Addr: srv.Addr, + Arbiters: srv.Arbiters, + Compression: srv.Compression, + CanonicalAddr: srv.CanonicalAddr, + ElectionID: srv.ElectionID, + IsCryptd: srv.IsCryptd, + HelloOK: srv.HelloOK, + Hosts: srv.Hosts, + Kind: srv.Kind.String(), + LastWriteTime: srv.LastWriteTime, + MaxBatchCount: srv.MaxBatchCount, + MaxDocumentSize: srv.MaxDocumentSize, + MaxMessageSize: srv.MaxMessageSize, + Members: srv.Members, + Passive: srv.Passive, + Passives: srv.Passives, + Primary: srv.Primary, + ReadOnly: srv.ReadOnly, + ServiceID: srv.ServiceID, + SessionTimeoutMinutes: srv.SessionTimeoutMinutes, + SetName: srv.SetName, + SetVersion: srv.SetVersion, + Tags: srv.Tags, + } + + if srv.WireVersion != nil { + evtSrv.MaxWireVersion = srv.WireVersion.Max + evtSrv.MinWireVersion = srv.WireVersion.Min + } + + if srv.TopologyVersion != nil { + evtSrv.TopologyVersionProcessID = srv.TopologyVersion.ProcessID + evtSrv.TopologyVersionCounter = srv.TopologyVersion.Counter + } + + return evtSrv +} + +func newEventServerTopology(topo description.Topology) event.TopologyDescription { + evtSrvs := make([]event.ServerDescription, len(topo.Servers)) + for idx, srv := range topo.Servers { + evtSrvs[idx] = newEventServerDescription(srv) + } + + evtTopo := event.TopologyDescription{ + Servers: evtSrvs, + SetName: topo.SetName, + Kind: topo.Kind.String(), + SessionTimeoutMinutes: topo.SessionTimeoutMinutes, + CompatibilityErr: topo.CompatibilityErr, + } + + return evtTopo +} diff --git a/x/mongo/driver/topology/topology_errors_test.go b/x/mongo/driver/topology/topology_errors_test.go index c09ef9731c..c959fe5cf9 100644 --- a/x/mongo/driver/topology/topology_errors_test.go +++ b/x/mongo/driver/topology/topology_errors_test.go @@ -17,10 +17,11 @@ import ( "time" "go.mongodb.org/mongo-driver/internal/assert" - "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/internal/serverselector" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) -var selectNone description.ServerSelectorFunc = func(description.Topology, []description.Server) ([]description.Server, error) { +var selectNone serverselector.Func = func(description.Topology, []description.Server) ([]description.Server, error) { return []description.Server{}, nil } @@ -38,7 +39,7 @@ func TestTopologyErrors(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err = topo.SelectServer(ctx, description.WriteSelector()) + _, err = topo.SelectServer(ctx, &serverselector.Write{}) assert.True(t, errors.Is(err, context.Canceled), "expected error %v, got %v", context.Canceled, err) }) t.Run("context deadline error", func(t *testing.T) { diff --git a/x/mongo/driver/topology/topology_test.go b/x/mongo/driver/topology/topology_test.go index aae357bc3b..937824d4dd 100644 --- a/x/mongo/driver/topology/topology_test.go +++ b/x/mongo/driver/topology/topology_test.go @@ -20,12 +20,13 @@ import ( "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/internal/assert" "go.mongodb.org/mongo-driver/internal/require" + "go.mongodb.org/mongo-driver/internal/serverselector" "go.mongodb.org/mongo-driver/internal/spectest" "go.mongodb.org/mongo-driver/mongo/address" - "go.mongodb.org/mongo-driver/mongo/description" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/x/mongo/driver" + "go.mongodb.org/mongo-driver/x/mongo/driver/description" ) const testTimeout = 2 * time.Second @@ -55,17 +56,17 @@ func compareErrors(err1, err2 error) bool { } func TestServerSelection(t *testing.T) { - var selectFirst description.ServerSelectorFunc = func(_ description.Topology, candidates []description.Server) ([]description.Server, error) { + var selectFirst serverselector.Func = func(_ description.Topology, candidates []description.Server) ([]description.Server, error) { if len(candidates) == 0 { return []description.Server{}, nil } return candidates[0:1], nil } - var selectNone description.ServerSelectorFunc = func(description.Topology, []description.Server) ([]description.Server, error) { + var selectNone serverselector.Func = func(description.Topology, []description.Server) ([]description.Server, error) { return []description.Server{}, nil } var errSelectionError = errors.New("encountered an error in the selector") - var selectError description.ServerSelectorFunc = func(description.Topology, []description.Server) ([]description.Server, error) { + var selectError serverselector.Func = func(description.Topology, []description.Server) ([]description.Server, error) { return nil, errSelectionError } @@ -74,9 +75,9 @@ func TestServerSelection(t *testing.T) { noerr(t, err) desc := description.Topology{ Servers: []description.Server{ - {Addr: address.Address("one"), Kind: description.Standalone}, - {Addr: address.Address("two"), Kind: description.Standalone}, - {Addr: address.Address("three"), Kind: description.Standalone}, + {Addr: address.Address("one"), Kind: description.ServerKindStandalone}, + {Addr: address.Address("two"), Kind: description.ServerKindStandalone}, + {Addr: address.Address("three"), Kind: description.ServerKindStandalone}, }, } subCh := make(chan description.Topology, 1) @@ -96,11 +97,11 @@ func TestServerSelection(t *testing.T) { topo, err := New(nil) noerr(t, err) desc := description.Topology{ - Kind: description.Single, + Kind: description.TopologyKindSingle, Servers: []description.Server{ - {Addr: address.Address("one:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 11, Min: 11}}, - {Addr: address.Address("two:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 9, Min: 6}}, - {Addr: address.Address("three:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 9, Min: 6}}, + {Addr: address.Address("one:27017"), Kind: description.ServerKindStandalone, WireVersion: &description.VersionRange{Max: 11, Min: 11}}, + {Addr: address.Address("two:27017"), Kind: description.ServerKindStandalone, WireVersion: &description.VersionRange{Max: 9, Min: 6}}, + {Addr: address.Address("three:27017"), Kind: description.ServerKindStandalone, WireVersion: &description.VersionRange{Max: 9, Min: 6}}, }, } want := fmt.Errorf( @@ -119,11 +120,11 @@ func TestServerSelection(t *testing.T) { topo, err := New(nil) noerr(t, err) desc := description.Topology{ - Kind: description.Single, + Kind: description.TopologyKindSingle, Servers: []description.Server{ - {Addr: address.Address("one:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 21, Min: 6}}, - {Addr: address.Address("two:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 9, Min: 2}}, - {Addr: address.Address("three:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 9, Min: 2}}, + {Addr: address.Address("one:27017"), Kind: description.ServerKindStandalone, WireVersion: &description.VersionRange{Max: 21, Min: 6}}, + {Addr: address.Address("two:27017"), Kind: description.ServerKindStandalone, WireVersion: &description.VersionRange{Max: 9, Min: 2}}, + {Addr: address.Address("three:27017"), Kind: description.ServerKindStandalone, WireVersion: &description.VersionRange{Max: 9, Min: 2}}, }, } want := fmt.Errorf( @@ -155,9 +156,9 @@ func TestServerSelection(t *testing.T) { desc = description.Topology{ Servers: []description.Server{ - {Addr: address.Address("one"), Kind: description.Standalone}, - {Addr: address.Address("two"), Kind: description.Standalone}, - {Addr: address.Address("three"), Kind: description.Standalone}, + {Addr: address.Address("one"), Kind: description.ServerKindStandalone}, + {Addr: address.Address("two"), Kind: description.ServerKindStandalone}, + {Addr: address.Address("three"), Kind: description.ServerKindStandalone}, }, } select { @@ -183,9 +184,9 @@ func TestServerSelection(t *testing.T) { t.Run("Cancel", func(t *testing.T) { desc := description.Topology{ Servers: []description.Server{ - {Addr: address.Address("one"), Kind: description.Standalone}, - {Addr: address.Address("two"), Kind: description.Standalone}, - {Addr: address.Address("three"), Kind: description.Standalone}, + {Addr: address.Address("one"), Kind: description.ServerKindStandalone}, + {Addr: address.Address("two"), Kind: description.ServerKindStandalone}, + {Addr: address.Address("three"), Kind: description.ServerKindStandalone}, }, } topo, err := New(nil) @@ -220,9 +221,9 @@ func TestServerSelection(t *testing.T) { t.Run("Timeout", func(t *testing.T) { desc := description.Topology{ Servers: []description.Server{ - {Addr: address.Address("one"), Kind: description.Standalone}, - {Addr: address.Address("two"), Kind: description.Standalone}, - {Addr: address.Address("three"), Kind: description.Standalone}, + {Addr: address.Address("one"), Kind: description.ServerKindStandalone}, + {Addr: address.Address("two"), Kind: description.ServerKindStandalone}, + {Addr: address.Address("three"), Kind: description.ServerKindStandalone}, }, } topo, err := New(nil) @@ -256,9 +257,9 @@ func TestServerSelection(t *testing.T) { t.Run("Error", func(t *testing.T) { desc := description.Topology{ Servers: []description.Server{ - {Addr: address.Address("one"), Kind: description.Standalone}, - {Addr: address.Address("two"), Kind: description.Standalone}, - {Addr: address.Address("three"), Kind: description.Standalone}, + {Addr: address.Address("one"), Kind: description.ServerKindStandalone}, + {Addr: address.Address("two"), Kind: description.ServerKindStandalone}, + {Addr: address.Address("three"), Kind: description.ServerKindStandalone}, }, } topo, err := New(nil) @@ -291,15 +292,15 @@ func TestServerSelection(t *testing.T) { noerr(t, err) topo.servers[address.Address("one")] = srvr desc := topo.desc.Load().(description.Topology) - desc.Kind = description.Single + desc.Kind = description.TopologyKindSingle topo.desc.Store(desc) selected := description.Server{Addr: address.Address("one")} ss, err := topo.FindServer(selected) noerr(t, err) - if ss.Kind != description.Single { - t.Errorf("findServer does not properly set the topology description kind. got %v; want %v", ss.Kind, description.Single) + if ss.Kind != description.TopologyKindSingle { + t.Errorf("findServer does not properly set the topology description kind. got %v; want %v", ss.Kind, description.TopologyKindSingle) } }) t.Run("Update on not primary error", func(t *testing.T) { @@ -312,9 +313,9 @@ func TestServerSelection(t *testing.T) { addr3 := address.Address("three") desc := description.Topology{ Servers: []description.Server{ - {Addr: addr1, Kind: description.RSPrimary}, - {Addr: addr2, Kind: description.RSSecondary}, - {Addr: addr3, Kind: description.RSSecondary}, + {Addr: addr1, Kind: description.ServerKindRSPrimary}, + {Addr: addr2, Kind: description.ServerKindRSSecondary}, + {Addr: addr3, Kind: description.ServerKindRSSecondary}, }, } @@ -328,9 +329,9 @@ func TestServerSelection(t *testing.T) { // Send updated description desc = description.Topology{ Servers: []description.Server{ - {Addr: addr1, Kind: description.RSSecondary}, - {Addr: addr2, Kind: description.RSPrimary}, - {Addr: addr3, Kind: description.RSSecondary}, + {Addr: addr1, Kind: description.ServerKindRSSecondary}, + {Addr: addr2, Kind: description.ServerKindRSPrimary}, + {Addr: addr3, Kind: description.ServerKindRSSecondary}, }, } @@ -347,7 +348,7 @@ func TestServerSelection(t *testing.T) { go func() { // server selection should discover the new topology - state := newServerSelectionState(description.WriteSelector(), nil) + state := newServerSelectionState(&serverselector.Write{}, nil) srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state) noerr(t, err) resp <- srvs @@ -376,7 +377,7 @@ func TestServerSelection(t *testing.T) { primaryAddr := address.Address("one") desc := description.Topology{ Servers: []description.Server{ - {Addr: primaryAddr, Kind: description.RSPrimary}, + {Addr: primaryAddr, Kind: description.ServerKindRSPrimary}, }, } topo.desc.Store(desc) @@ -391,7 +392,7 @@ func TestServerSelection(t *testing.T) { topo.subscriptionsClosed = true ctx, cancel := context.WithCancel(context.Background()) cancel() - selectedServer, err := topo.SelectServer(ctx, description.WriteSelector()) + selectedServer, err := topo.SelectServer(ctx, &serverselector.Write{}) noerr(t, err) selectedAddr := selectedServer.(*SelectedServer).address assert.Equal(t, primaryAddr, selectedAddr, "expected address %v, got %v", primaryAddr, selectedAddr) @@ -407,7 +408,7 @@ func TestServerSelection(t *testing.T) { topo.desc.Store(desc) topo.subscriptionsClosed = true - _, err = topo.SelectServer(context.Background(), description.WriteSelector()) + _, err = topo.SelectServer(context.Background(), &serverselector.Write{}) assert.Equal(t, ErrSubscribeAfterClosed, err, "expected error %v, got %v", ErrSubscribeAfterClosed, err) }) } @@ -422,7 +423,7 @@ func TestSessionTimeout(t *testing.T) { topo.fsm.Servers = []description.Server{ { Addr: address.Address("foo").Canonicalize(), - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(60), }, } @@ -432,7 +433,7 @@ func TestSessionTimeout(t *testing.T) { desc := description.Server{ Addr: "foo", - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(30), } topo.apply(ctx, desc) @@ -445,18 +446,18 @@ func TestSessionTimeout(t *testing.T) { t.Run("MultipleUpdates", func(t *testing.T) { topo, err := New(nil) noerr(t, err) - topo.fsm.Kind = description.ReplicaSetWithPrimary + topo.fsm.Kind = description.TopologyKindReplicaSetWithPrimary topo.servers["foo"] = nil topo.servers["bar"] = nil topo.fsm.Servers = []description.Server{ { Addr: address.Address("foo").Canonicalize(), - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(60), }, { Addr: address.Address("bar").Canonicalize(), - Kind: description.RSSecondary, + Kind: description.ServerKindRSSecondary, SessionTimeoutMinutes: int64ToPtr(60), }, } @@ -466,14 +467,14 @@ func TestSessionTimeout(t *testing.T) { desc1 := description.Server{ Addr: "foo", - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(30), Members: []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()}, } // should update because new timeout is lower desc2 := description.Server{ Addr: "bar", - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(20), Members: []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()}, } @@ -493,12 +494,12 @@ func TestSessionTimeout(t *testing.T) { topo.fsm.Servers = []description.Server{ { Addr: address.Address("foo").Canonicalize(), - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(60), }, { Addr: address.Address("bar").Canonicalize(), - Kind: description.RSSecondary, + Kind: description.ServerKindRSSecondary, SessionTimeoutMinutes: int64ToPtr(60), }, } @@ -508,14 +509,14 @@ func TestSessionTimeout(t *testing.T) { desc1 := description.Server{ Addr: "foo", - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(20), Members: []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()}, } // should not update because new timeout is higher desc2 := description.Server{ Addr: "bar", - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(30), Members: []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()}, } @@ -535,12 +536,12 @@ func TestSessionTimeout(t *testing.T) { topo.fsm.Servers = []description.Server{ { Addr: address.Address("foo").Canonicalize(), - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(60), }, { Addr: address.Address("bar").Canonicalize(), - Kind: description.RSSecondary, + Kind: description.ServerKindRSSecondary, SessionTimeoutMinutes: int64ToPtr(60), }, } @@ -550,7 +551,7 @@ func TestSessionTimeout(t *testing.T) { desc1 := description.Server{ Addr: "foo", - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(20), Members: []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()}, } @@ -572,24 +573,24 @@ func TestSessionTimeout(t *testing.T) { t.Run("MixedSessionSupport", func(t *testing.T) { topo, err := New(nil) noerr(t, err) - topo.fsm.Kind = description.ReplicaSetWithPrimary + topo.fsm.Kind = description.TopologyKindReplicaSetWithPrimary topo.servers["one"] = nil topo.servers["two"] = nil topo.servers["three"] = nil topo.fsm.Servers = []description.Server{ { Addr: address.Address("one").Canonicalize(), - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(20), }, { // does not support sessions Addr: address.Address("two").Canonicalize(), - Kind: description.RSSecondary, + Kind: description.ServerKindRSSecondary, }, { Addr: address.Address("three").Canonicalize(), - Kind: description.RSPrimary, + Kind: description.ServerKindRSPrimary, SessionTimeoutMinutes: int64ToPtr(60), }, } @@ -599,7 +600,7 @@ func TestSessionTimeout(t *testing.T) { desc := description.Server{ Addr: address.Address("three"), - Kind: description.RSSecondary, + Kind: description.ServerKindRSSecondary, SessionTimeoutMinutes: int64ToPtr(30), } @@ -1026,7 +1027,7 @@ func runInWindowTest(t *testing.T, directory string, filename string) { for i := 0; i < test.Iterations; i++ { selected, err := topology.SelectServer( context.Background(), - description.ReadPrefSelector(readpref.Nearest())) + &serverselector.ReadPref{ReadPref: readpref.Nearest()}) require.NoError(t, err, "error selecting server") counts[string(selected.(*SelectedServer).address)]++ } @@ -1073,17 +1074,17 @@ func topologyKindFromString(t *testing.T, s string) description.TopologyKind { switch s { case "Single": - return description.Single + return description.TopologyKindSingle case "ReplicaSet": - return description.ReplicaSet + return description.TopologyKindReplicaSet case "ReplicaSetNoPrimary": - return description.ReplicaSetNoPrimary + return description.TopologyKindReplicaSetNoPrimary case "ReplicaSetWithPrimary": - return description.ReplicaSetWithPrimary + return description.TopologyKindReplicaSetWithPrimary case "Sharded": - return description.Sharded + return description.TopologyKindSharded case "LoadBalanced": - return description.LoadBalanced + return description.TopologyKindLoadBalanced case "Unknown": return description.Unknown default: @@ -1098,21 +1099,21 @@ func serverKindFromString(t *testing.T, s string) description.ServerKind { switch s { case "Standalone": - return description.Standalone + return description.ServerKindStandalone case "RSOther": - return description.RSMember + return description.ServerKindRSMember case "RSPrimary": - return description.RSPrimary + return description.ServerKindRSPrimary case "RSSecondary": - return description.RSSecondary + return description.ServerKindRSSecondary case "RSArbiter": - return description.RSArbiter + return description.ServerKindRSArbiter case "RSGhost": - return description.RSGhost + return description.ServerKindRSGhost case "Mongos": - return description.Mongos + return description.ServerKindMongos case "LoadBalancer": - return description.LoadBalancer + return description.ServerKindLoadBalancer case "PossiblePrimary", "Unknown": // Go does not have a PossiblePrimary server type and per the SDAM spec, this type is synonymous with Unknown. return description.Unknown @@ -1163,7 +1164,7 @@ func BenchmarkSelectServerFromDescription(b *testing.B) { HeartbeatInterval: time.Duration(10) * time.Second, LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC), LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC), - Kind: description.Mongos, + Kind: description.ServerKindMongos, WireVersion: &description.VersionRange{Min: 6, Max: 21}, } servers := make([]description.Server, 100)