diff --git a/src/Weaviate.Client.Tests/Integration/TestIterator.cs b/src/Weaviate.Client.Tests/Integration/TestIterator.cs new file mode 100644 index 00000000..7f46a62b --- /dev/null +++ b/src/Weaviate.Client.Tests/Integration/TestIterator.cs @@ -0,0 +1,291 @@ +using Weaviate.Client.Models; +using Weaviate.Client.Models.Vectorizers; + +namespace Weaviate.Client.Tests.Integration; + +public partial class BasicTests +{ + /* Smoke test: inserts two objects and asserts that iterating the collection yields both names. */ [Fact] + public async Task Test_Iterator() + { + var collection = await CollectionFactory( + properties: [Property.Text("name")], + vectorConfig: Vector.Name("default").With(new VectorizerConfig.None()) + ); + + await collection.Data.InsertMany(new { Name = "Name 1" }, new { Name = "Name 2" }); + + var names = new List(); + await foreach ( + var obj in collection.Iterator(cancellationToken: TestContext.Current.CancellationToken) + ) + { + obj.Do(o => names.Add(o.Name)); + } + + Assert.Contains("Name 1", names); + Assert.Contains("Name 2", names); + } + + /* Iterates ten inserted objects under every listed combination of vector, metadata, field-selection and cache-size arguments, then asserts what came back. A null cacheSize falls back to CollectionClient.ITERATOR_CACHE_SIZE. */ [Theory] + [InlineData(false, null, null, null)] + [InlineData(true, null, null, null)] + [InlineData(false, true, null, null)] + [InlineData(true, true, null, null)] + [InlineData(false, null, true, null)] + [InlineData(true, null, true, null)] + [InlineData(false, true, true, null)] + [InlineData(true, true, true, null)] + [InlineData(false, null, null, 100u)] + [InlineData(true, null, null, 100u)] + [InlineData(false, true, null, 100u)] + [InlineData(true, true, null, 100u)] + [InlineData(false, null, true, 100u)] + [InlineData(true, null, true, 100u)] + [InlineData(false, true, true, 100u)] + [InlineData(true, true, true, 100u)] + [InlineData(false, null, null, 10000u)] + [InlineData(true, null, null, 10000u)] + [InlineData(false, true, null, 10000u)] + [InlineData(true, true, null, 10000u)] + [InlineData(false, null, true, 10000u)] + [InlineData(true, null, true, 10000u)] + [InlineData(false, true, true, 10000u)] + [InlineData(true, true, true, 10000u)] + public async Task 
Test_Iterator_Arguments( + bool includeVector, + bool? returnFullMetadata, + bool? returnSpecificProperties, + uint? cacheSize + ) + { + var collection = await CollectionFactory( + properties: [Property.Int("data"), Property.Text("text")], + vectorConfig: Vector.Name("default").With(new VectorizerConfig.Text2VecContextionary()) + ); + + // Insert test data + var insertData = Enumerable + .Range(0, 10) + .Select(i => new { data = i, text = "hi" }) + .ToArray(); + await collection.Data.InsertMany(insertData); + + // Build metadata query + MetadataQuery? metadata = null; + if (includeVector && returnFullMetadata == true) + { + metadata = new MetadataQuery(MetadataOptions.Full | MetadataOptions.Vector); + } + else if (includeVector) + { + metadata = new MetadataQuery(MetadataOptions.Vector); + } + else if (returnFullMetadata == true) + { + metadata = new MetadataQuery(MetadataOptions.Full); + } + + // Build fields array + string[]? fields = null; + if (returnSpecificProperties == true) + { + fields = ["data"]; + } + + var iter = collection.Iterator( + metadata: metadata, + fields: fields, + cacheSize: cacheSize ?? 
CollectionClient.ITERATOR_CACHE_SIZE, + cancellationToken: TestContext.Current.CancellationToken + ); + + var objects = new List(); + await foreach (var obj in iter) + { + objects.Add(obj); + } + + // Verify we got all 10 objects + Assert.Equal(10, objects.Count); + + // Sort by data property for consistent comparison + var allData = objects.Select(obj => (long)obj.Properties["data"]!).OrderBy(x => x).ToList(); + Assert.Equal(Enumerable.Range(0, 10).Select(Convert.ToInt64).ToList(), allData); + + // Test expectations based on parameters + /* NOTE(review): when returnSpecificProperties != true and returnFullMetadata != true, none of the four branches below matches, so those cases get no assertions beyond the count and data checks above - confirm this is intended. */ if (includeVector && returnSpecificProperties != true && returnFullMetadata == true) + { + // Expect everything back + Assert.All(objects, obj => Assert.True(obj.Properties.ContainsKey("text"))); + Assert.All(objects, obj => Assert.True(obj.Vectors.ContainsKey("default"))); + Assert.All(objects, obj => Assert.NotNull(obj.Metadata.CreationTime)); + Assert.All(objects, obj => Assert.NotNull(obj.Metadata.Score)); + } + else if (!includeVector && returnSpecificProperties != true && returnFullMetadata == true) + { + // Expect everything back except vector + Assert.All(objects, obj => Assert.True(obj.Properties.ContainsKey("text"))); + Assert.All(objects, obj => Assert.False(obj.Vectors.ContainsKey("default"))); + Assert.All(objects, obj => Assert.NotNull(obj.Metadata.CreationTime)); + Assert.All(objects, obj => Assert.NotNull(obj.Metadata.Score)); + } + else if (includeVector && returnSpecificProperties == true) + { + // Expect specified properties and vector + Assert.All(objects, obj => Assert.False(obj.Properties.ContainsKey("text"))); + Assert.All(objects, obj => Assert.True(obj.Vectors.ContainsKey("default"))); + if (returnFullMetadata == true) + { + Assert.All(objects, obj => Assert.NotNull(obj.Metadata.CreationTime)); + Assert.All(objects, obj => Assert.NotNull(obj.Metadata.Score)); + } + else + { + Assert.All(objects, obj => Assert.Null(obj.Metadata.CreationTime)); + } + } + else if (!includeVector && returnSpecificProperties == 
true) + { + // Expect specified properties and no vector + Assert.All(objects, obj => Assert.False(obj.Properties.ContainsKey("text"))); + Assert.All(objects, obj => Assert.False(obj.Vectors.ContainsKey("default"))); + if (returnFullMetadata == true) + { + Assert.All(objects, obj => Assert.NotNull(obj.Metadata.CreationTime)); + Assert.All(objects, obj => Assert.NotNull(obj.Metadata.Score)); + } + else + { + Assert.All(objects, obj => Assert.Null(obj.Metadata.CreationTime)); + } + } + } + + /* Iterates once with no field selection and once with fields limited to one property, asserting which properties each returned object carries. */ [Fact] + public async Task Test_Iterator_With_Default_Generic() + { + var collection = await CollectionFactory( + properties: [Property.Text("this"), Property.Text("that")], + vectorConfig: Vector.Name("default").With(new VectorizerConfig.None()) + ); + + var insertData = Enumerable + .Range(0, 10) + .Select(_ => new { @this = "this", that = "that" }) + .ToArray(); + await collection.Data.InsertMany(insertData); + + // Test with all properties + var allPropsIter = collection.Iterator( + cancellationToken: TestContext.Current.CancellationToken + ); + + await foreach (var obj in allPropsIter) + { + Assert.Equal("this", obj.Properties["this"]); + Assert.Equal("that", obj.Properties["that"]); + } + + // Test with specific properties + var specificPropsIter = collection.Iterator( + fields: ["this"], + cancellationToken: TestContext.Current.CancellationToken + ); + + await foreach (var obj in specificPropsIter) + { + Assert.Equal("this", obj.Properties["this"]); + Assert.False(obj.Properties.ContainsKey("that")); + } + } + + /* Exercises object counts at and around multiples of ITERATOR_CACHE_SIZE; three fresh iterators must each return every inserted value, in the same order as the first run. */ [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(2)] + [InlineData(CollectionClient.ITERATOR_CACHE_SIZE - 1)] + [InlineData(CollectionClient.ITERATOR_CACHE_SIZE)] + [InlineData(CollectionClient.ITERATOR_CACHE_SIZE + 1)] + [InlineData(2 * CollectionClient.ITERATOR_CACHE_SIZE - 1)] + [InlineData(2 * CollectionClient.ITERATOR_CACHE_SIZE)] + [InlineData(2 * CollectionClient.ITERATOR_CACHE_SIZE + 1)] + [InlineData(20 * CollectionClient.ITERATOR_CACHE_SIZE)] + 
public async Task Test_Iterator_Basic(uint count) + { + var collection = await CollectionFactory( + properties: [Property.Int("data")], + vectorConfig: Vector.Name("default").With(new VectorizerConfig.None()) + ); + + if (count > 0) + { + var insertData = Enumerable + .Range(0, (int)count) + .Select(i => new { data = i }) + .ToArray(); + await collection.Data.InsertMany(insertData); + } + + var expected = Enumerable.Range(0, (int)count).Select(x => Convert.ToInt64(x)).ToList(); + List? firstOrder = null; + + // Make sure a new iterator resets the internal state and that the return order is the same for every run + for (int run = 0; run < 3; run++) + { + var iter = collection.Iterator( + cancellationToken: TestContext.Current.CancellationToken + ); + + var ret = new List(); + await foreach (var obj in iter) + { + ret.Add((long)obj.Properties["data"]!); + } + + if (firstOrder == null) + { + firstOrder = ret; + } + else + { + Assert.Equal(firstOrder, ret); + } + + Assert.Equal(expected, ret.OrderBy(x => x).ToList()); + } + } + + /* Collects all IDs in iteration order, restarts iteration after the sixth ID, and asserts the first object returned has the same data value as the object fetched by the seventh ID. */ [Fact] + public async Task Test_Iterator_With_After() + { + var collection = await CollectionFactory( + properties: [Property.Int("data")], + vectorConfig: Vector.Name("default").With(new VectorizerConfig.None()) + ); + + var insertData = Enumerable.Range(0, 10).Select(i => new { data = i }).ToArray(); + await collection.Data.InsertMany(insertData); + + // Get all UUIDs first + var allUuids = new List(); + var initialIter = collection.Iterator( + cancellationToken: TestContext.Current.CancellationToken + ); + + await foreach (var obj in initialIter) + { + allUuids.Add(obj.ID!.Value); + } + + // Test pagination with after parameter + var firstAfterObject = await collection + .Iterator(after: allUuids[5], cancellationToken: TestContext.Current.CancellationToken) + .FirstAsync(TestContext.Current.CancellationToken); + + // Fetch the object at index 6 to compare + var expectedObject = await collection.Query.FetchObjectByID(allUuids[6]); + + 
Assert.Equal(expectedObject!.Properties["data"]!, firstAfterObject.Properties["data"]!); + } +} diff --git a/src/Weaviate.Client/CollectionClient.cs b/src/Weaviate.Client/CollectionClient.cs index 266ac875..e97adec6 100644 --- a/src/Weaviate.Client/CollectionClient.cs +++ b/src/Weaviate.Client/CollectionClient.cs @@ -1,9 +1,12 @@ +using System.Runtime.CompilerServices; using Weaviate.Client.Models; namespace Weaviate.Client; public class CollectionClient { + /* Default value of the cacheSize parameter of Iterator, which is sent as the limit of each page request. */ public const uint ITERATOR_CACHE_SIZE = 100; + private readonly WeaviateClient _client; private DataClient _dataClient; private QueryClient _queryClient; @@ -66,4 +69,41 @@ internal async Task AddReference(ReferenceProperty referenceProperty) await _client.RestClient.CollectionAddProperty(_collectionName, dto); } + + /* Streams the collection page by page. Each round calls GrpcClient.FetchObjects with limit: cacheSize and after: cursor, where cursor starts as the caller's after value and is then set to the ID of each object as it is yielded; iteration ends at the first empty page. NOTE(review): cancellationToken is checked before each page request but is not passed to FetchObjects, so it only takes effect between pages - confirm. NOTE(review): the loop ends only on an empty page, so the last non-empty page is always followed by one more request. */ public async IAsyncEnumerable Iterator( + Guid? after = null, + uint cacheSize = ITERATOR_CACHE_SIZE, + MetadataQuery? metadata = null, + string[]? fields = null, + IList? references = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default + ) + { + Guid? cursor = after; + + while (true) + { + cancellationToken.ThrowIfCancellationRequested(); + + var page = await _client.GrpcClient.FetchObjects( + _collectionName, + limit: cacheSize, + metadata: metadata, + fields: fields, + reference: references, + after: cursor + ); + + if (!page.Objects.Any()) + { + yield break; + } + + foreach (var c in page.Objects) + { + cursor = c.ID; + yield return c; + } + } + } } diff --git a/src/Weaviate.Client/DataClient.cs b/src/Weaviate.Client/DataClient.cs index f6a5028f..7807be87 100644 --- a/src/Weaviate.Client/DataClient.cs +++ b/src/Weaviate.Client/DataClient.cs @@ -588,6 +588,17 @@ public delegate void InsertDelegate( string? 
tenant = null ); + /* Convenience overload: hands every element of data, in order, to the add callback of the delegate-based InsertMany and returns that call's result. */ public async Task> InsertMany(params TData[] data) + { + return await InsertMany(add => + { + foreach (var o in data) + { + add(o); + } + }); + } + public async Task> InsertMany( params BatchInsertRequest[] requests ) diff --git a/src/Weaviate.Client/Models/MetadataQuery.cs b/src/Weaviate.Client/Models/MetadataQuery.cs index 391109b8..90193230 100644 --- a/src/Weaviate.Client/Models/MetadataQuery.cs +++ b/src/Weaviate.Client/Models/MetadataQuery.cs @@ -12,6 +12,8 @@ public enum MetadataOptions Score = 1 << 5, // 2^5 ExplainScore = 1 << 6, // 2^6 IsConsistent = 1 << 7, // 2^7 + /* Union of the seven flags listed below. Vector is not among them; the iterator tests OR it in explicitly as Full | Vector. */ Full = + CreationTime | LastUpdateTime | Distance | Certainty | Score | ExplainScore | IsConsistent, } public record MetadataQuery diff --git a/src/Weaviate.Client/Models/WeaviateObject.cs b/src/Weaviate.Client/Models/WeaviateObject.cs index d2331983..33bd362a 100644 --- a/src/Weaviate.Client/Models/WeaviateObject.cs +++ b/src/Weaviate.Client/Models/WeaviateObject.cs @@ -49,32 +49,38 @@ public partial record WeaviateObject new Dictionary>(); public NamedVectors Vectors { get; set; } = new NamedVectors(); +} - public T? As() +/* Property-mapping helpers (As, Do, Get) that this patch moves from instance methods of WeaviateObject to extension methods taking the object as their first argument. */ public static class WeaviateObjectExtensions +{ + public static T? As(this WeaviateObject obj) where T : new() { - return ObjectHelper.UnmarshallProperties(Properties); + return ObjectHelper.UnmarshallProperties(obj.Properties); } - public void Do(Action action) + public static void Do(this WeaviateObject obj, Action action) where T : new() { - var data = As(); + var data = obj.As(); if (data is not null) { action(data); } } - public void Do(Action action) + public static void Do(this WeaviateObject obj, Action action) { - Do(action); + obj.Do(action); } - public TResult? Get(Func func) + public static TResult? 
Get( + this WeaviateObject obj, + Func func + ) where TSource : new() { - var data = ObjectHelper.UnmarshallProperties(Properties); + var data = obj.As(); if (data is not null) { return func(data); @@ -82,8 +88,8 @@ public void Do(Action action) return default; } - public TResult? Get(Func func) + public static TResult? Get(this WeaviateObject obj, Func func) { - return Get(func); + return obj.Get(func); } } diff --git a/src/Weaviate.Client/gRPC/Search.cs b/src/Weaviate.Client/gRPC/Search.cs index 945afbd1..cf0bfd0e 100644 --- a/src/Weaviate.Client/gRPC/Search.cs +++ b/src/Weaviate.Client/gRPC/Search.cs @@ -14,7 +14,8 @@ internal SearchRequest BaseSearchRequest( GroupByRequest? groupBy = null, MetadataQuery? metadata = null, IList? reference = null, - string[]? fields = null + string[]? fields = null, + Guid? after = null ) { var metadataRequest = new MetadataRequest() @@ -54,6 +55,11 @@ internal SearchRequest BaseSearchRequest( Properties = MakePropsRequest(fields, reference), }; + /* Set the request's After field, as the Guid's string form, only when a cursor was supplied. */ if (after is not null) + { + request.After = after.ToString(); + } + if (sort is not null) { request.SortBy.AddRange(sort); @@ -333,7 +339,8 @@ internal async Task FetchObjects( uint? limit = null, string[]? fields = null, IList? reference = null, - MetadataQuery? metadata = null + MetadataQuery? metadata = null, + Guid? after = null ) { var req = BaseSearchRequest( @@ -343,7 +350,8 @@ internal async Task FetchObjects( limit, fields: fields, metadata: metadata, - reference: reference + reference: reference, + after: after ); SearchReply? reply = await _grpcClient.SearchAsync(req, headers: _defaultHeaders);