diff --git a/src/Weaviate.Client.Tests/Integration/TestBatch.cs b/src/Weaviate.Client.Tests/Integration/TestBatch.cs index 2307fc87..189abde9 100644 --- a/src/Weaviate.Client.Tests/Integration/TestBatch.cs +++ b/src/Weaviate.Client.Tests/Integration/TestBatch.cs @@ -1,4 +1,5 @@ using Weaviate.Client.Models; +using Weaviate.Client.Models.Vectorizers; namespace Weaviate.Client.Tests.Integration; @@ -44,4 +45,84 @@ public async Task InsertMany(string key) data.Select(d => d.References.Sum(r => r.Value.Count)).Sum() ); } + + [Fact] + public async Task Test_Batch_ReferenceAddMany() + { + // Setup referenced collection ("To") + var refCollection = await CollectionFactory( + name: "To", + vectorConfig: Vector.Name("default"), + properties: [Property.Int("number")] + ); + int numObjects = 10; + + // Insert objects into the referenced collection and get their UUIDs + var refInsertResult = await refCollection.Data.InsertMany(add => + Enumerable.Range(0, numObjects).ToList().ForEach(i => add(new { Number = i })) + ); + + Guid[] uuidsTo = [.. refInsertResult.Select(r => r.ID!.Value)]; + + // Setup main collection ("From") with a reference property + var collection = await CollectionFactory( + name: "From", + properties: [Property.Int("num")], + references: [Property.Reference("ref", refCollection.Name)], + vectorConfig: Vector.Name("default") + ); + + // Insert objects into the main collection and get their UUIDs + var fromInsertResult = await collection.Data.InsertMany(add => + Enumerable.Range(0, numObjects).ToList().ForEach(i => add(new { Num = i })) + ); + + Guid[] uuidsFrom = [.. fromInsertResult.Select(r => r.ID!.Value)]; + + // First batch: each "From" object references the "To" object with the same index + var batchReturn1 = await collection.Data.ReferenceAddMany( + Enumerable + .Range(0, numObjects) + .Select(i => new DataReference(uuidsFrom[i], "ref", uuidsTo[i])) + .ToArray() + ); + Assert.False(batchReturn1.HasErrors); + + // Second batch: each "From" object references the first 3 "To" objects + var batchReturn2 = await collection.Data.ReferenceAddMany( + Enumerable + .Range(0, numObjects) + .Select(i => new DataReference(uuidsFrom[i], "ref", uuidsTo.Take(3).ToArray())) + .ToArray() + ); + Assert.False(batchReturn2.HasErrors); + + // Fetch objects with references + var objects = await collection.Query.List( + properties: ["num"], + references: [new QueryReference(linkOn: "ref")] + ); + + foreach (var obj in objects) + { + var num = (long)obj.Properties["num"]!; + var refObjects = obj.References["ref"]; + + // The first reference should match the corresponding "To" object's "number" + Assert.Equal(num, (long)refObjects[0].Properties["number"]!); + Assert.Contains(refObjects[0].ID!.Value, uuidsTo); + + // There should be 4 references: 1 from the first batch, 3 from the second + Assert.Equal(4, refObjects.Count); + + // The next 3 references should have "number" properties 0, 1, 2 (order sorted) + var refs = refObjects + .Skip(1) + .Take(3) + .Select(r => (long)r.Properties["number"]!) + .OrderBy(x => x) + .ToList(); + Assert.Equal(new List { 0, 1, 2 }, refs); + } + } } diff --git a/src/Weaviate.Client/DataClient.cs b/src/Weaviate.Client/DataClient.cs index 736026b6..b2c495e9 100644 --- a/src/Weaviate.Client/DataClient.cs +++ b/src/Weaviate.Client/DataClient.cs @@ -1,543 +1,11 @@ using System.Collections.Frozen; -using System.ComponentModel; using System.Diagnostics; -using System.Dynamic; -using System.Reflection; using System.Text.Json; -using Google.Protobuf.WellKnownTypes; using Weaviate.Client.Models; using Weaviate.Client.Rest.Dto; namespace Weaviate.Client; -internal class ObjectHelper -{ - internal static T? UnmarshallProperties(IDictionary dict) - where T : new() - { - ArgumentNullException.ThrowIfNull(dict); - - // Create an instance of T using the default constructor - var instance = new T(); - - if (instance is IDictionary target) - { - foreach (var kvp in dict) - { - if (kvp.Value is IDictionary subDict) - { - object? nestedValue = UnmarshallProperties(subDict); - - target[kvp.Key.Capitalize()] = nestedValue ?? subDict; - } - else - { - if (kvp.Value?.GetType() == typeof(Rest.Dto.GeoCoordinates)) - { - var value = (Rest.Dto.GeoCoordinates)kvp.Value; - target[kvp.Key.Capitalize()] = new GeoCoordinate( - value.Latitude ?? 0f, - value.Longitude ?? 0f - ); - } - else - { - target[kvp.Key.Capitalize()] = kvp.Value; - } - } - } - return instance; - } - - var type = typeof(T); - var properties = type.GetProperties(BindingFlags.Public | BindingFlags.Instance) - .Where(p => p.CanWrite) - .ToArray(); - - foreach (var property in properties) - { - var matchingKey = dict.Keys.FirstOrDefault(k => - string.Equals(k, property.Name, StringComparison.OrdinalIgnoreCase) - ); - - if (matchingKey is null) - { - continue; - } - - var value = dict[matchingKey]; - - try - { - var convertedValue = ConvertValue(value, property.PropertyType); - property.SetValue(instance, convertedValue); - } - catch (Exception ex) - { - // Skip if conversion fails - Debug.WriteLine($"Failed to convert property {property.Name}: {ex.Message}"); - continue; - } - } - - return instance; - } - - private static object? ConvertValue(object? value, System.Type targetType) - { - // Handle null values - if (value == null) - { - if (IsNullableType(targetType) || !targetType.IsValueType) - { - return null; - } - // For non-nullable value types, return default value - return Activator.CreateInstance(targetType); - } - - // If types already match, return as-is - if (targetType.IsAssignableFrom(value.GetType())) - { - return value; - } - - // Handle nullable types - if (IsNullableType(targetType)) - { - var underlyingType = Nullable.GetUnderlyingType(targetType)!; - return ConvertValue(value, underlyingType); - } - - // Handle nested objects (dictionaries -> custom types) - if ( - value is IDictionary nestedDict - && !typeof(IDictionary).IsAssignableFrom(targetType) - ) - { - var method = typeof(ObjectHelper) - .GetMethod("UnmarshallProperties", BindingFlags.Static | BindingFlags.NonPublic)! - .MakeGenericMethod(targetType); - return method.Invoke(null, new object[] { nestedDict }); - } - - // Handle collections - if ( - IsCollectionType(targetType) - && value is System.Collections.IEnumerable enumerable - && !(value is string) - ) - { - return ConvertCollection(enumerable, targetType); - } - - // Handle enums - if (targetType.IsEnum) - { - if (value is string stringValue) - { - return System.Enum.Parse(targetType, stringValue, true); - } - return System.Enum.ToObject(targetType, value); - } - - // Try TypeConverter first (handles more cases than Convert.ChangeType) - var converter = TypeDescriptor.GetConverter(targetType); - if (converter.CanConvertFrom(value.GetType())) - { - return converter.ConvertFrom(value); - } - - // Fallback to Convert.ChangeType for basic types - return Convert.ChangeType(value, targetType); - } - - private static bool IsNullableType(System.Type type) - { - return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>); - } - - private static bool IsCollectionType(System.Type type) - { - return type.IsArray - || ( - type.IsGenericType - && ( - type.GetGenericTypeDefinition() == typeof(List<>) - || type.GetGenericTypeDefinition() == typeof(IList<>) - || type.GetGenericTypeDefinition() == typeof(ICollection<>) - || type.GetGenericTypeDefinition() == typeof(IEnumerable<>) - ) - ); - } - - private static object? ConvertCollection( - System.Collections.IEnumerable source, - System.Type targetType - ) - { - if (targetType.IsArray) - { - var elementType = targetType.GetElementType()!; - var items = new List(); - - foreach (var item in source) - { - items.Add(ConvertValue(item, elementType)); - } - - var array = Array.CreateInstance(elementType, items.Count); - for (int i = 0; i < items.Count; i++) - { - array.SetValue(items[i], i); - } - return array; - } - - if (targetType.IsGenericType) - { - var elementType = targetType.GetGenericArguments()[0]; - var listType = typeof(List<>).MakeGenericType(elementType); - var list = (System.Collections.IList)Activator.CreateInstance(listType)!; - - foreach (var item in source) - { - list.Add(ConvertValue(item, elementType)); - } - - return list; - } - - // Fallback - convert to object array - var fallbackItems = new List(); - foreach (var item in source) - { - fallbackItems.Add(item); - } - return fallbackItems.ToArray(); - } - - internal static IDictionary BuildDataTransferObject(object? data) - { - var obj = new ExpandoObject(); - var propDict = obj as IDictionary; - - if (data is null) - { - return propDict; - } - - foreach (var propertyInfo in data.GetType().GetProperties()) - { - if (!propertyInfo.CanRead) - continue; // skip non-readable properties - - var value = propertyInfo.GetValue(data); - - if (value is null) - { - continue; - } - else if (propertyInfo.PropertyType.IsNativeType()) - { - propDict[propertyInfo.Name] = value; - } - else if (propertyInfo.PropertyType == typeof(GeoCoordinate)) - { - var newValue = (GeoCoordinate)value; - propDict[propertyInfo.Name] = new GeoCoordinates - { - Latitude = newValue.Latitude, - Longitude = newValue.Longitude, - }; - } - else - { - propDict[propertyInfo.Name] = BuildDataTransferObject(value); // recursive call - } - } - - return obj; - } - - internal static V1.BatchObject.Types.Properties BuildBatchProperties(TProps data) - { - var props = new V1.BatchObject.Types.Properties(); - - if (data is null) - { - return props; - } - - Google.Protobuf.WellKnownTypes.Struct? nonRefProps = null; - - foreach (var propertyInfo in data.GetType().GetProperties()) - { - if (propertyInfo is null) - { - continue; - } - - if (!propertyInfo.CanRead) - continue; // skip non-readable properties - - var value = propertyInfo.GetValue(data); - - if (value is null) - { - continue; - } - - if (propertyInfo.PropertyType.IsArray) - { - switch (value) - { - case bool[] v: - props.BooleanArrayProperties.Add( - new V1.BooleanArrayProperties() - { - PropName = propertyInfo.Name, - Values = { v }, - } - ); - break; - case int[] v: - props.IntArrayProperties.Add( - new V1.IntArrayProperties() - { - PropName = propertyInfo.Name, - Values = { v.Select(Convert.ToInt64) }, - } - ); - break; - case long[] v: - props.IntArrayProperties.Add( - new V1.IntArrayProperties() - { - PropName = propertyInfo.Name, - Values = { v }, - } - ); - break; - case double[] v: - props.NumberArrayProperties.Add( - new V1.NumberArrayProperties() - { - PropName = propertyInfo.Name, - ValuesBytes = v.ToByteString(), - } - ); - break; - case float[] v: - props.NumberArrayProperties.Add( - new V1.NumberArrayProperties() - { - PropName = propertyInfo.Name, - ValuesBytes = v.Select(Convert.ToDouble).ToByteString(), - } - ); - break; - case string[] v: - props.TextArrayProperties.Add( - new V1.TextArrayProperties() - { - PropName = propertyInfo.Name, - Values = { v }, - } - ); - break; - case Guid[] v: - props.TextArrayProperties.Add( - new V1.TextArrayProperties() - { - PropName = propertyInfo.Name, - Values = { v.Select(v => v.ToString()) }, - } - ); - break; - case DateTime[] v: - props.TextArrayProperties.Add( - new V1.TextArrayProperties() - { - PropName = propertyInfo.Name, - Values = { v.Select(v => v.ToUniversalTime().ToString("o")) }, - } - ); - break; - case DateTimeOffset[] v: - props.TextArrayProperties.Add( - new V1.TextArrayProperties() - { - PropName = propertyInfo.Name, - Values = { v.Select(dto => dto.ToUniversalTime().ToString("o")) }, - } - ); - break; - - // Handle general IEnumerable (e.g., List, HashSet) - case System.Collections.IEnumerable enumerable - when enumerable is IEnumerable bools: - props.BooleanArrayProperties.Add( - new V1.BooleanArrayProperties() - { - PropName = propertyInfo.Name, - Values = { bools }, - } - ); - continue; - case System.Collections.IEnumerable enumerable - when enumerable is IEnumerable ints: - props.IntArrayProperties.Add( - new V1.IntArrayProperties() - { - PropName = propertyInfo.Name, - Values = { ints.Select(Convert.ToInt64) }, - } - ); - continue; - case System.Collections.IEnumerable enumerable - when enumerable is IEnumerable longs: - props.IntArrayProperties.Add( - new V1.IntArrayProperties() - { - PropName = propertyInfo.Name, - Values = { longs }, - } - ); - continue; - case System.Collections.IEnumerable enumerable - when enumerable is IEnumerable doubles: - props.NumberArrayProperties.Add( - new V1.NumberArrayProperties() - { - PropName = propertyInfo.Name, - ValuesBytes = doubles.ToByteString(), - } - ); - continue; - case System.Collections.IEnumerable enumerable - when enumerable is IEnumerable floats: - props.NumberArrayProperties.Add( - new V1.NumberArrayProperties() - { - PropName = propertyInfo.Name, - ValuesBytes = floats.Select(f => (double)f).ToByteString(), - } - ); - continue; - case System.Collections.IEnumerable enumerable - when enumerable is IEnumerable strings: - props.TextArrayProperties.Add( - new V1.TextArrayProperties() - { - PropName = propertyInfo.Name, - Values = { strings }, - } - ); - continue; - case System.Collections.IEnumerable enumerable - when enumerable is IEnumerable guids: - props.TextArrayProperties.Add( - new V1.TextArrayProperties() - { - PropName = propertyInfo.Name, - Values = { guids.Select(g => g.ToString()) }, - } - ); - continue; - case System.Collections.IEnumerable enumerable - when enumerable is IEnumerable dateTimes: - props.TextArrayProperties.Add( - new V1.TextArrayProperties() - { - PropName = propertyInfo.Name, - Values = - { - dateTimes.Select(dt => dt.ToUniversalTime().ToString("o")), - }, - } - ); - continue; - case System.Collections.IEnumerable enumerable - when enumerable is IEnumerable dateTimeOffsets: - props.TextArrayProperties.Add( - new V1.TextArrayProperties() - { - PropName = propertyInfo.Name, - Values = - { - dateTimeOffsets.Select(dto => - dto.ToUniversalTime().ToString("o") - ), - }, - } - ); - continue; - default: - throw new WeaviateException( - $"Unsupported array type '{value.GetType().GetElementType()?.Name ?? value.GetType().Name}' for property '{propertyInfo.Name}'. Check the documentation for supported array value types." - ); - } - continue; // Move to the next property after handling array - } - - if (propertyInfo.PropertyType.IsNativeType()) - { - nonRefProps ??= new(); - - nonRefProps.Fields.Add(propertyInfo.Name, ConvertToProtoValue(value)); - } - } - - props.NonRefProperties = nonRefProps; - - return props; - } - - // Helper method to convert C# objects to protobuf Values - internal static Value ConvertToProtoValue(object obj) - { - return obj switch - { - null => Value.ForNull(), - bool b => Value.ForBool(b), - int i => Value.ForNumber(i), - long l => Value.ForNumber(l), - float f => Value.ForNumber(f), - double d => Value.ForNumber(d), - decimal dec => Value.ForNumber((double)dec), - string s => Value.ForString(s), - DateTime dt => Value.ForString(dt.ToUniversalTime().ToString("o")), - Guid uuid => Value.ForString(uuid.ToString()), - GeoCoordinate v => Value.ForStruct( - new Struct - { - Fields = - { - ["latitude"] = Value.ForNumber(v.Latitude), - ["longitude"] = Value.ForNumber(v.Longitude), - }, - } - ), - // Dictionary dict => Value.ForStruct(CreateStructFromDictionary(dict)), - // IEnumerable enumerable => CreateListValue(enumerable), - _ => throw new ArgumentException($"Unsupported type: {obj.GetType()}"), - }; - } - - internal static Guid GuidFromByteString(Google.Protobuf.ByteString x) - { - byte[] bytes = x.ToByteArray(); - if (BitConverter.IsLittleEndian) - { - Array.Reverse(bytes, 0, 4); // Reverse first 4 bytes - Array.Reverse(bytes, 4, 2); // Reverse next 2 bytes - Array.Reverse(bytes, 6, 2); // Reverse next 2 bytes - } - return new Guid(bytes); - } -} - public class DataClient { private readonly CollectionClient _collectionClient; @@ -549,17 +17,6 @@ internal DataClient(CollectionClient collectionClient) _collectionClient = collectionClient; } - public static IDictionary[] MakeBeacons(params Guid[] guids) - { - return - [ - .. guids.Select(uuid => new Dictionary - { - { "beacon", $"weaviate://localhost/{uuid}" }, - }), - ]; - } - public async Task Insert( TData data, Guid? id = null, @@ -572,7 +29,7 @@ public async Task Insert( foreach (var kvp in references ?? []) { - propDict[kvp.Name] = MakeBeacons(kvp.TargetID); + propDict[kvp.Name] = ObjectHelper.MakeBeacons(kvp.TargetID); } var dtoVectors = @@ -724,11 +181,46 @@ public async Task Delete(Guid id) await _client.RestClient.DeleteObject(_collectionName, id); } + public async Task ReferenceAdd(DataReference reference) + { + await _client.RestClient.ReferenceAdd( + _collectionName, + reference.From, + reference.FromProperty, + reference.To.Single() + ); + } + public async Task ReferenceAdd(Guid from, string fromProperty, Guid to) { await _client.RestClient.ReferenceAdd(_collectionName, from, fromProperty, to); } + public async Task ReferenceAddMany(params DataReference[] references) + { + var stopwatch = Stopwatch.StartNew(); + + var result = await _client.RestClient.ReferenceAddMany(_collectionName, references); + + stopwatch.Stop(); + var elapsedSeconds = (float)stopwatch.Elapsed.TotalSeconds; + + var errorsByIndex = result + .Select((r, idx) => new { r.Result, Index = idx }) + .Where(r => (r.Result?.Status ?? ResultStatus.SUCCESS) == ResultStatus.FAILED) + .ToDictionary( + entry => entry.Index, + entry => + { + var errors = entry.Result?.Errors?.Error ?? Enumerable.Empty(); + + return errors.Select(e => new WeaviateException(e.Message)).ToArray(); + } + ); + + return new BatchReferenceReturn(elapsedSeconds, errorsByIndex); + } + public async Task ReferenceReplace(Guid from, string fromProperty, Guid[] to) { await _client.RestClient.ReferenceReplace(_collectionName, from, fromProperty, to); diff --git a/src/Weaviate.Client/Models/QueryReference.cs b/src/Weaviate.Client/Models/QueryReference.cs index 73662bd9..8a899bf0 100644 --- a/src/Weaviate.Client/Models/QueryReference.cs +++ b/src/Weaviate.Client/Models/QueryReference.cs @@ -3,7 +3,7 @@ namespace Weaviate.Client.Models; public record QueryReference { public string LinkOn { get; init; } - public string[] Fields { get; init; } + public string[]? Fields { get; init; } public MetadataQuery? Metadata { get; init; } public IList? References { get; init; } @@ -15,7 +15,7 @@ params QueryReference[]? references ) { LinkOn = linkOn; - Fields = fields ?? []; + Fields = fields; Metadata = metadata; References = references; } diff --git a/src/Weaviate.Client/Models/Reference.cs b/src/Weaviate.Client/Models/Reference.cs new file mode 100644 index 00000000..a346925e --- /dev/null +++ b/src/Weaviate.Client/Models/Reference.cs @@ -0,0 +1,48 @@ +namespace Weaviate.Client.Models; + +public record DataReference(Guid From, string FromProperty, params Guid[] To); + +public record BatchReferenceReturn +{ + public float ElapsedSeconds { get; init; } = 0.0f; + public Dictionary Errors { get; init; } = new(); + public bool HasErrors { get; init; } = false; + + public BatchReferenceReturn(float elapsedSeconds, Dictionary errors) + { + ElapsedSeconds = elapsedSeconds; + Errors = errors; + HasErrors = errors.Count > 0; + } + + public static BatchReferenceReturn operator +( + BatchReferenceReturn left, + BatchReferenceReturn right + ) + { + var result = new BatchReferenceReturn( + left.ElapsedSeconds + right.ElapsedSeconds, + new Dictionary(left.Errors) + ); + foreach (var kvp in right.Errors) + { + if (result.Errors.ContainsKey(kvp.Key)) + result.Errors[kvp.Key] = kvp.Value; + else + result.Errors.Add(kvp.Key, kvp.Value); + } + + return result with + { + HasErrors = left.HasErrors || right.HasErrors, + }; + } + + public override string ToString() + { + var errorsStr = string.Join(", ", Errors.Select(x => $"({x.Key}, {x.Value})")); + if (errorsStr.Length > 0) + errorsStr = $"{{...}}"; + return $"ElapsedSeconds: {ElapsedSeconds}, Errors: [{errorsStr}], HasErrors: {HasErrors}"; + } +} diff --git a/src/Weaviate.Client/ObjectHelper.cs b/src/Weaviate.Client/ObjectHelper.cs new file mode 100644 index 00000000..91444d4b --- /dev/null +++ b/src/Weaviate.Client/ObjectHelper.cs @@ -0,0 +1,551 @@ +using System.ComponentModel; +using System.Diagnostics; +using System.Dynamic; +using System.Reflection; +using Google.Protobuf.WellKnownTypes; +using Weaviate.Client.Models; +using Weaviate.Client.Rest.Dto; + +namespace Weaviate.Client; + +internal class ObjectHelper +{ + public static string MakeBeaconSource(string collection, Guid fromUuid, string fromProperty) => + $"weaviate://localhost/{collection}/{fromUuid}/{fromProperty}"; + + public static IDictionary[] MakeBeacons(params Guid[] guids) + { + return + [ + .. guids.Select(uuid => new Dictionary + { + { "beacon", $"weaviate://localhost/{uuid}" }, + }), + ]; + } + + internal static T? UnmarshallProperties(IDictionary dict) + where T : new() + { + ArgumentNullException.ThrowIfNull(dict); + + // Create an instance of T using the default constructor + var instance = new T(); + + if (instance is IDictionary target) + { + foreach (var kvp in dict) + { + if (kvp.Value is IDictionary subDict) + { + object? nestedValue = UnmarshallProperties(subDict); + + target[kvp.Key.Capitalize()] = nestedValue ?? subDict; + } + else + { + if (kvp.Value?.GetType() == typeof(Rest.Dto.GeoCoordinates)) + { + var value = (Rest.Dto.GeoCoordinates)kvp.Value; + target[kvp.Key.Capitalize()] = new GeoCoordinate( + value.Latitude ?? 0f, + value.Longitude ?? 0f + ); + } + else + { + target[kvp.Key.Capitalize()] = kvp.Value; + } + } + } + return instance; + } + + var type = typeof(T); + var properties = type.GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(p => p.CanWrite) + .ToArray(); + + foreach (var property in properties) + { + var matchingKey = dict.Keys.FirstOrDefault(k => + string.Equals(k, property.Name, StringComparison.OrdinalIgnoreCase) + ); + + if (matchingKey is null) + { + continue; + } + + var value = dict[matchingKey]; + + try + { + var convertedValue = ConvertValue(value, property.PropertyType); + property.SetValue(instance, convertedValue); + } + catch (Exception ex) + { + // Skip if conversion fails + Debug.WriteLine($"Failed to convert property {property.Name}: {ex.Message}"); + continue; + } + } + + return instance; + } + + private static object? ConvertValue(object? value, System.Type targetType) + { + // Handle null values + if (value == null) + { + if (IsNullableType(targetType) || !targetType.IsValueType) + { + return null; + } + // For non-nullable value types, return default value + return Activator.CreateInstance(targetType); + } + + // If types already match, return as-is + if (targetType.IsAssignableFrom(value.GetType())) + { + return value; + } + + // Handle nullable types + if (IsNullableType(targetType)) + { + var underlyingType = Nullable.GetUnderlyingType(targetType)!; + return ConvertValue(value, underlyingType); + } + + // Handle nested objects (dictionaries -> custom types) + if ( + value is IDictionary nestedDict + && !typeof(IDictionary).IsAssignableFrom(targetType) + ) + { + var method = typeof(ObjectHelper) + .GetMethod("UnmarshallProperties", BindingFlags.Static | BindingFlags.NonPublic)! + .MakeGenericMethod(targetType); + return method.Invoke(null, new object[] { nestedDict }); + } + + // Handle collections + if ( + IsCollectionType(targetType) + && value is System.Collections.IEnumerable enumerable + && !(value is string) + ) + { + return ConvertCollection(enumerable, targetType); + } + + // Handle enums + if (targetType.IsEnum) + { + if (value is string stringValue) + { + return System.Enum.Parse(targetType, stringValue, true); + } + return System.Enum.ToObject(targetType, value); + } + + // Try TypeConverter first (handles more cases than Convert.ChangeType) + var converter = TypeDescriptor.GetConverter(targetType); + if (converter.CanConvertFrom(value.GetType())) + { + return converter.ConvertFrom(value); + } + + // Fallback to Convert.ChangeType for basic types + return Convert.ChangeType(value, targetType); + } + + private static bool IsNullableType(System.Type type) + { + return type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Nullable<>); + } + + private static bool IsCollectionType(System.Type type) + { + return type.IsArray + || ( + type.IsGenericType + && ( + type.GetGenericTypeDefinition() == typeof(List<>) + || type.GetGenericTypeDefinition() == typeof(IList<>) + || type.GetGenericTypeDefinition() == typeof(ICollection<>) + || type.GetGenericTypeDefinition() == typeof(IEnumerable<>) + ) + ); + } + + private static object? ConvertCollection( + System.Collections.IEnumerable source, + System.Type targetType + ) + { + if (targetType.IsArray) + { + var elementType = targetType.GetElementType()!; + var items = new List(); + + foreach (var item in source) + { + items.Add(ConvertValue(item, elementType)); + } + + var array = Array.CreateInstance(elementType, items.Count); + for (int i = 0; i < items.Count; i++) + { + array.SetValue(items[i], i); + } + return array; + } + + if (targetType.IsGenericType) + { + var elementType = targetType.GetGenericArguments()[0]; + var listType = typeof(List<>).MakeGenericType(elementType); + var list = (System.Collections.IList)Activator.CreateInstance(listType)!; + + foreach (var item in source) + { + list.Add(ConvertValue(item, elementType)); + } + + return list; + } + + // Fallback - convert to object array + var fallbackItems = new List(); + foreach (var item in source) + { + fallbackItems.Add(item); + } + return fallbackItems.ToArray(); + } + + internal static IDictionary BuildDataTransferObject(object? data) + { + var obj = new ExpandoObject(); + var propDict = obj as IDictionary; + + if (data is null) + { + return propDict; + } + + foreach (var propertyInfo in data.GetType().GetProperties()) + { + if (!propertyInfo.CanRead) + continue; // skip non-readable properties + + var value = propertyInfo.GetValue(data); + + if (value is null) + { + continue; + } + else if (propertyInfo.PropertyType.IsNativeType()) + { + propDict[propertyInfo.Name] = value; + } + else if (propertyInfo.PropertyType == typeof(GeoCoordinate)) + { + var newValue = (GeoCoordinate)value; + propDict[propertyInfo.Name] = new GeoCoordinates + { + Latitude = newValue.Latitude, + Longitude = newValue.Longitude, + }; + } + else + { + propDict[propertyInfo.Name] = BuildDataTransferObject(value); // recursive call + } + } + + return obj; + } + + internal static V1.BatchObject.Types.Properties BuildBatchProperties(TProps data) + { + var props = new V1.BatchObject.Types.Properties(); + + if (data is null) + { + return props; + } + + Google.Protobuf.WellKnownTypes.Struct? nonRefProps = null; + + foreach (var propertyInfo in data.GetType().GetProperties()) + { + if (propertyInfo is null) + { + continue; + } + + if (!propertyInfo.CanRead) + continue; // skip non-readable properties + + var value = propertyInfo.GetValue(data); + + if (value is null) + { + continue; + } + + if (propertyInfo.PropertyType.IsArray) + { + switch (value) + { + case bool[] v: + props.BooleanArrayProperties.Add( + new V1.BooleanArrayProperties() + { + PropName = propertyInfo.Name, + Values = { v }, + } + ); + break; + case int[] v: + props.IntArrayProperties.Add( + new V1.IntArrayProperties() + { + PropName = propertyInfo.Name, + Values = { v.Select(Convert.ToInt64) }, + } + ); + break; + case long[] v: + props.IntArrayProperties.Add( + new V1.IntArrayProperties() + { + PropName = propertyInfo.Name, + Values = { v }, + } + ); + break; + case double[] v: + props.NumberArrayProperties.Add( + new V1.NumberArrayProperties() + { + PropName = propertyInfo.Name, + ValuesBytes = v.ToByteString(), + } + ); + break; + case float[] v: + props.NumberArrayProperties.Add( + new V1.NumberArrayProperties() + { + PropName = propertyInfo.Name, + ValuesBytes = v.Select(Convert.ToDouble).ToByteString(), + } + ); + break; + case string[] v: + props.TextArrayProperties.Add( + new V1.TextArrayProperties() + { + PropName = propertyInfo.Name, + Values = { v }, + } + ); + break; + case Guid[] v: + props.TextArrayProperties.Add( + new V1.TextArrayProperties() + { + PropName = propertyInfo.Name, + Values = { v.Select(v => v.ToString()) }, + } + ); + break; + case DateTime[] v: + props.TextArrayProperties.Add( + new V1.TextArrayProperties() + { + PropName = propertyInfo.Name, + Values = { v.Select(v => v.ToUniversalTime().ToString("o")) }, + } + ); + break; + case DateTimeOffset[] v: + props.TextArrayProperties.Add( + new V1.TextArrayProperties() + { + PropName = propertyInfo.Name, + Values = { v.Select(dto => dto.ToUniversalTime().ToString("o")) }, + } + ); + break; + + // Handle general IEnumerable (e.g., List, HashSet) + case System.Collections.IEnumerable enumerable + when enumerable is IEnumerable bools: + props.BooleanArrayProperties.Add( + new V1.BooleanArrayProperties() + { + PropName = propertyInfo.Name, + Values = { bools }, + } + ); + continue; + case System.Collections.IEnumerable enumerable + when enumerable is IEnumerable ints: + props.IntArrayProperties.Add( + new V1.IntArrayProperties() + { + PropName = propertyInfo.Name, + Values = { ints.Select(Convert.ToInt64) }, + } + ); + continue; + case System.Collections.IEnumerable enumerable + when enumerable is IEnumerable longs: + props.IntArrayProperties.Add( + new V1.IntArrayProperties() + { + PropName = propertyInfo.Name, + Values = { longs }, + } + ); + continue; + case System.Collections.IEnumerable enumerable + when enumerable is IEnumerable doubles: + props.NumberArrayProperties.Add( + new V1.NumberArrayProperties() + { + PropName = propertyInfo.Name, + ValuesBytes = doubles.ToByteString(), + } + ); + continue; + case System.Collections.IEnumerable enumerable + when enumerable is IEnumerable floats: + props.NumberArrayProperties.Add( + new V1.NumberArrayProperties() + { + PropName = propertyInfo.Name, + ValuesBytes = floats.Select(f => (double)f).ToByteString(), + } + ); + continue; + case System.Collections.IEnumerable enumerable + when enumerable is IEnumerable strings: + props.TextArrayProperties.Add( + new V1.TextArrayProperties() + { + PropName = propertyInfo.Name, + Values = { strings }, + } + ); + continue; + case System.Collections.IEnumerable enumerable + when enumerable is IEnumerable guids: + props.TextArrayProperties.Add( + new V1.TextArrayProperties() + { + PropName = propertyInfo.Name, + Values = { guids.Select(g => g.ToString()) }, + } + ); + continue; + case System.Collections.IEnumerable enumerable + when enumerable is IEnumerable dateTimes: + props.TextArrayProperties.Add( + new V1.TextArrayProperties() + { + PropName = propertyInfo.Name, + Values = + { + dateTimes.Select(dt => dt.ToUniversalTime().ToString("o")), + }, + } + ); + continue; + case System.Collections.IEnumerable enumerable + when enumerable is IEnumerable dateTimeOffsets: + props.TextArrayProperties.Add( + new V1.TextArrayProperties() + { + PropName = propertyInfo.Name, + Values = + { + dateTimeOffsets.Select(dto => + dto.ToUniversalTime().ToString("o") + ), + }, + } + ); + continue; + default: + throw new WeaviateException( + $"Unsupported array type '{value.GetType().GetElementType()?.Name ?? value.GetType().Name}' for property '{propertyInfo.Name}'. Check the documentation for supported array value types." + ); + } + continue; // Move to the next property after handling array + } + + if (propertyInfo.PropertyType.IsNativeType()) + { + nonRefProps ??= new(); + + nonRefProps.Fields.Add(propertyInfo.Name, ConvertToProtoValue(value)); + } + } + + props.NonRefProperties = nonRefProps; + + return props; + } + + // Helper method to convert C# objects to protobuf Values + internal static Value ConvertToProtoValue(object obj) + { + return obj switch + { + null => Value.ForNull(), + bool b => Value.ForBool(b), + int i => Value.ForNumber(i), + long l => Value.ForNumber(l), + float f => Value.ForNumber(f), + double d => Value.ForNumber(d), + decimal dec => Value.ForNumber((double)dec), + string s => Value.ForString(s), + DateTime dt => Value.ForString(dt.ToUniversalTime().ToString("o")), + Guid uuid => Value.ForString(uuid.ToString()), + GeoCoordinate v => Value.ForStruct( + new Struct + { + Fields = + { + ["latitude"] = Value.ForNumber(v.Latitude), + ["longitude"] = Value.ForNumber(v.Longitude), + }, + } + ), + // Dictionary dict => Value.ForStruct(CreateStructFromDictionary(dict)), + // IEnumerable enumerable => CreateListValue(enumerable), + _ => throw new ArgumentException($"Unsupported type: {obj.GetType()}"), + }; + } + + internal static Guid GuidFromByteString(Google.Protobuf.ByteString x) + { + byte[] bytes = x.ToByteArray(); + if (BitConverter.IsLittleEndian) + { + Array.Reverse(bytes, 0, 4); // Reverse first 4 bytes + Array.Reverse(bytes, 4, 2); // Reverse next 2 bytes + Array.Reverse(bytes, 6, 2); // Reverse next 2 bytes + } + return new Guid(bytes); + } +} diff --git a/src/Weaviate.Client/QueryClient.cs b/src/Weaviate.Client/QueryClient.cs index 212f46de..934e357e 100644 --- a/src/Weaviate.Client/QueryClient.cs +++ b/src/Weaviate.Client/QueryClient.cs @@ -16,6 +16,7 @@ public QueryClient(CollectionClient collectionClient) #region Objects public async Task List( + string[]? properties = null, uint? limit = null, Filter? filter = null, IEnumerable? sort = null, @@ -25,6 +26,7 @@ public async Task List( { return await _client.GrpcClient.FetchObjects( _collectionName, + fields: properties, limit: limit, sort: sort, filter: filter, @@ -35,6 +37,7 @@ public async Task List( public async Task FetchObjectByID( Guid id, + string[]? properties = null, IList? references = null, MetadataQuery? metadata = null ) @@ -42,6 +45,7 @@ public async Task List( return ( await _client.GrpcClient.FetchObjects( _collectionName, + fields: properties, filter: Filter.WithID(id), reference: references, metadata: metadata @@ -51,6 +55,7 @@ await _client.GrpcClient.FetchObjects( public async Task FetchObjectsByIDs( ISet ids, + string[]? properties = null, uint? limit = null, IList? references = null, MetadataQuery? metadata = null @@ -58,6 +63,7 @@ public async Task FetchObjectsByIDs( { return await _client.GrpcClient.FetchObjects( _collectionName, + fields: properties, limit: limit, filter: Filter.WithIDs(ids), reference: references, diff --git a/src/Weaviate.Client/Rest/Client.cs b/src/Weaviate.Client/Rest/Client.cs index 7605a276..3d12fb9a 100644 --- a/src/Weaviate.Client/Rest/Client.cs +++ b/src/Weaviate.Client/Rest/Client.cs @@ -190,7 +190,7 @@ internal async Task ReferenceAdd(string collectionName, Guid from, string fromPr { var path = WeaviateEndpoints.Reference(collectionName, from, fromProperty); - var beacons = DataClient.MakeBeacons(to); + var beacons = ObjectHelper.MakeBeacons(to); var reference = beacons.First(); var response = await _httpClient.PostAsJsonAsync(path, reference); @@ -207,7 +207,7 @@ Guid[] to { var path = WeaviateEndpoints.Reference(collectionName, from, fromProperty); - var beacons = DataClient.MakeBeacons(to); + var beacons = ObjectHelper.MakeBeacons(to); var reference = beacons; var response = await _httpClient.PutAsJsonAsync(path, reference); @@ -224,7 +224,7 @@ Guid to { var path = WeaviateEndpoints.Reference(collectionName, from, fromProperty); - var beacons = DataClient.MakeBeacons(to); + var beacons = ObjectHelper.MakeBeacons(to); var reference = beacons.First(); var request = new HttpRequestMessage(HttpMethod.Delete, path); @@ -243,4 +243,32 @@ internal async Task CollectionAddProperty(string collectionName, Property proper await response.EnsureExpectedStatusCodeAsync([200], "collection property add"); } + + internal async Task ReferenceAddMany( + string collectionName, + Models.DataReference[] references + ) + { + var batchRefs = references.SelectMany(r => + ObjectHelper + .MakeBeacons(r.To) + .SelectMany(b => b.Values) + .Select(beacon => new BatchReference + { + From = new Uri( + ObjectHelper.MakeBeaconSource(collectionName, r.From, r.FromProperty) + ), + To = new Uri(beacon), + }) + ); + + var path = WeaviateEndpoints.ReferencesAdd(); + + var response = await _httpClient.PostAsJsonAsync(path, batchRefs); + + await response.EnsureExpectedStatusCodeAsync([200], "reference add many"); + + return await response.Content.ReadFromJsonAsync() + ?? throw new WeaviateRestException(); + } } diff --git a/src/Weaviate.Client/Rest/Endpoints.cs b/src/Weaviate.Client/Rest/Endpoints.cs index b24b5b36..ea4ae9ed 100644 --- a/src/Weaviate.Client/Rest/Endpoints.cs +++ b/src/Weaviate.Client/Rest/Endpoints.cs @@ -25,4 +25,6 @@ public static string CollectionTenant(string className, string tenantName) => internal static string? Reference(string collectionName, Guid from, string fromProperty) => $"objects/{collectionName}/{from}/references/{fromProperty}"; + + internal static string? ReferencesAdd() => "batch/references"; }