Skip to content

Commit

Permalink
Performance improvements in CollectionTrackerExtensions.AsTracker
Browse files Browse the repository at this point in the history
  • Loading branch information
bradwilson committed Oct 28, 2023
1 parent 6a94a86 commit b3e3c93
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 55 deletions.
4 changes: 3 additions & 1 deletion CollectionAsserts.cs
Original file line number Diff line number Diff line change
Expand Up @@ -430,10 +430,12 @@ public static void Empty(IEnumerable collection)
GuardArgumentNotNull(nameof(collection), collection);

using (var tracker = collection.AsTracker())
using (var enumerator = tracker.GetEnumerator())
{
var enumerator = tracker.GetEnumerator();
if (enumerator.MoveNext())
throw EmptyException.ForNonEmptyCollection(tracker.FormatStart());
}
}

/// <summary>
/// Verifies that two sequences are equivalent, using a default comparer.
Expand Down
113 changes: 70 additions & 43 deletions Sdk/CollectionTracker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -304,19 +304,78 @@ protected CollectionTracker(IEnumerable innerEnumerable)
/// <inheritdoc/>
public abstract void Dispose();

/// <summary>
/// Formats the collection when you have a mismatched index. The formatted result will be the section of the
/// collection surrounded by the mismatched item.
/// </summary>
/// <param name="mismatchedIndex">The index of the mismatched item</param>
/// <param name="pointerIndent">How many spaces into the output value the pointed-to item begins at</param>
/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
/// <returns>The formatted collection</returns>
public abstract string FormatIndexedMismatch(
int? mismatchedIndex,
out int? pointerIndent,
int depth = 1);

/// <summary>
/// Formats the collection when you have a mismatched index. The formatted result will be the section of the
/// collection from <paramref name="startIndex"/> to <paramref name="endIndex"/>. These indices are usually
/// obtained by calling <see cref="GetMismatchExtents"/>.
/// </summary>
/// <param name="startIndex">The start index of the collection to print</param>
/// <param name="endIndex">The end index of the collection to print</param>
/// <param name="mismatchedIndex">The mismatched item index</param>
/// <param name="pointerIndent">How many spaces into the output value the pointed-to item begins at</param>
/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
/// <returns>The formatted collection</returns>
public abstract string FormatIndexedMismatch(
int startIndex,
int endIndex,
int? mismatchedIndex,
out int? pointerIndent,
int depth = 1);

/// <summary>
/// Formats the beginning part of the collection.
/// </summary>
/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
/// <returns>The formatted collection</returns>
public abstract string FormatStart(int depth = 1);

/// <summary>
/// Gets the extents to print when you find a mismatched index, in the form of
/// a <paramref name="startIndex"/> and <paramref name="endIndex"/>. If the mismatched
/// index is <c>null</c>, the extents will start at index 0.
/// </summary>
/// <param name="mismatchedIndex">The mismatched item index</param>
/// <param name="startIndex">The start index that should be used for printing</param>
/// <param name="endIndex">The end index that should be used for printing</param>
public abstract void GetMismatchExtents(
int? mismatchedIndex,
out int startIndex,
out int endIndex);

/// <summary>
/// Gets a safe version of <see cref="IEnumerator"/> that prevents double enumeration and does all
/// the necessary tracking required for collection formatting. Should should be the same value
/// returned by <see cref="CollectionTracker{T}.GetEnumerator"/>, except non-generic.
/// </summary>
protected abstract IEnumerator GetSafeEnumerator();
protected internal abstract IEnumerator GetSafeEnumerator();

/// <summary>
/// Gets the full name of the type of the element at the given index, if known.
/// Since this uses the item cache produced by enumeration, it may return <c>null</c>
/// when we haven't enumerated enough to see the given element, or if we enumerated
/// so much that the item has left the cache, or if the item at the given index
/// is <c>null</c>. It will also return <c>null</c> when the <paramref name="index"/>
/// is <c>null</c>.
/// </summary>
/// <param name="index">The item index</param>
#if XUNIT_NULLABLE
public abstract string? TypeAt(int? index);
#else
public abstract string TypeAt(int? index);
#endif

/// <summary>
/// Wraps an untyped enumerable in an object-based <see cref="CollectionTracker{T}"/>.
Expand Down Expand Up @@ -379,15 +438,8 @@ sealed class CollectionTracker<T> : CollectionTracker, IEnumerable<T>
public override void Dispose() =>
enumerator?.DisposeInternal();

/// <summary>
/// Formats the collection when you have a mismatched index. The formatted result will be the section of the
/// collection surrounded by the mismatched item.
/// </summary>
/// <param name="mismatchedIndex">The index of the mismatched item</param>
/// <param name="pointerIndent">How many spaces into the output value the pointed-to item begins at</param>
/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
/// <returns>The formatted collection</returns>
public string FormatIndexedMismatch(
/// <inheritdoc/>
public override string FormatIndexedMismatch(
int? mismatchedIndex,
out int? pointerIndent,
int depth = 1)
Expand Down Expand Up @@ -418,18 +470,8 @@ sealed class CollectionTracker<T> : CollectionTracker, IEnumerable<T>
);
}

/// <summary>
/// Formats the collection when you have a mismatched index. The formatted result will be the section of the
/// collection from <paramref name="startIndex"/> to <paramref name="endIndex"/>. These indices are usually
/// obtained by calling <see cref="GetMismatchExtents"/>.
/// </summary>
/// <param name="startIndex">The start index of the collection to print</param>
/// <param name="endIndex">The end index of the collection to print</param>
/// <param name="mismatchedIndex">The mismatched item index</param>
/// <param name="pointerIndent">How many spaces into the output value the pointed-to item begins at</param>
/// <param name="depth">The optional printing depth (1 indicates a top-level value)</param>
/// <returns>The formatted collection</returns>
public string FormatIndexedMismatch(
/// <inheritdoc/>
public override string FormatIndexedMismatch(
int startIndex,
int endIndex,
int? mismatchedIndex,
Expand Down Expand Up @@ -644,18 +686,11 @@ public IEnumerator<T> GetEnumerator()
GetEnumerator();

/// <inheritdoc/>
protected override IEnumerator GetSafeEnumerator() =>
protected internal override IEnumerator GetSafeEnumerator() =>
GetEnumerator();

/// <summary>
/// Gets the extents to print when you find a mismatched index, in the form of
/// a <paramref name="startIndex"/> and <paramref name="endIndex"/>. If the mismatched
/// index is <c>null</c>, the extents will start at index 0.
/// </summary>
/// <param name="mismatchedIndex">The mismatched item index</param>
/// <param name="startIndex">The start index that should be used for printing</param>
/// <param name="endIndex">The end index that should be used for printing</param>
public void GetMismatchExtents(
/// <inheritdoc/>
public override void GetMismatchExtents(
int? mismatchedIndex,
out int startIndex,
out int endIndex)
Expand All @@ -675,19 +710,11 @@ public IEnumerator<T> GetEnumerator()
startIndex = Math.Max(0, endIndex - ArgumentFormatter.MAX_ENUMERABLE_LENGTH + 1);
}

/// <summary>
/// Gets the full name of the type of the element at the given index, if known.
/// Since this uses the item cache produced by enumeration, it may return <c>null</c>
/// when we haven't enumerated enough to see the given element, or if we enumerated
/// so much that the item has left the cache, or if the item at the given index
/// is <c>null</c>. It will also return <c>null</c> when the <paramref name="index"/>
/// is <c>null</c>.
/// </summary>
/// <param name="index">The item index</param>
/// <inheritdoc/>
#if XUNIT_NULLABLE
public string? TypeAt(int? index)
public override string? TypeAt(int? index)
#else
public string TypeAt(int? index)
public override string TypeAt(int? index)
#endif
{
if (enumerator == null || !index.HasValue)
Expand Down
65 changes: 54 additions & 11 deletions Sdk/CollectionTrackerExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
#nullable enable
#else
// In case this is source-imported with global nullable enabled but no XUNIT_NULLABLE
#pragma warning disable CS8601
#pragma warning disable CS8603
#pragma warning disable CS8604
#endif

using System;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

#if XUNIT_NULLABLE
using System.Diagnostics.CodeAnalysis;
Expand All @@ -25,32 +31,59 @@ namespace Xunit.Sdk
static class CollectionTrackerExtensions
{
#if XUNIT_NULLABLE
internal static IEnumerable? AsNonStringEnumerable(this object? value) =>
static readonly MethodInfo? asTrackerOpenGeneric = typeof(CollectionTrackerExtensions).GetRuntimeMethods().FirstOrDefault(m => m.Name == nameof(AsTracker) && m.IsGenericMethod);
#else
internal static IEnumerable AsNonStringEnumerable(this object value) =>
static readonly MethodInfo asTrackerOpenGeneric = typeof(CollectionTrackerExtensions).GetRuntimeMethods().FirstOrDefault(m => m.Name == nameof(AsTracker) && m.IsGenericMethod);
#endif
value == null || value is string ? null : value as IEnumerable;
static readonly ConcurrentDictionary<Type, MethodInfo> cacheOfAsTrackerByType = new ConcurrentDictionary<Type, MethodInfo>();

#if XUNIT_NULLABLE
internal static CollectionTracker<object>? AsNonStringTracker(this object? value) =>
internal static CollectionTracker? AsNonStringTracker(this object? value)
#else
internal static CollectionTracker<object> AsNonStringTracker(this object value) =>
internal static CollectionTracker AsNonStringTracker(this object value)
#endif
AsTracker(AsNonStringEnumerable(value));
{
if (value == null || value is string)
return null;

return AsTracker(value as IEnumerable);
}

/// <summary>
/// Wraps the given enumerable in an instance of <see cref="CollectionTracker{T}"/>.
/// </summary>
/// <param name="enumerable">The enumerable to be wrapped</param>
#if XUNIT_NULLABLE
[return: NotNullIfNotNull("enumerable")]
public static CollectionTracker<object>? AsTracker(this IEnumerable? enumerable) =>
public static CollectionTracker? AsTracker(this IEnumerable? enumerable)
#else
public static CollectionTracker<object> AsTracker(this IEnumerable enumerable) =>
public static CollectionTracker AsTracker(this IEnumerable enumerable)
#endif
{
if (enumerable == null)
return null;

var result = enumerable as CollectionTracker;
if (result != null)
return result;

// CollectionTracker.Wrap for the non-T enumerable uses the CastIterator, which has terrible
// performance during iteration. We do our best to try to get a T and dynamically invoke the
// generic version of AsTracker as we can.
var iEnumerableOfT = enumerable.GetType().GetTypeInfo().ImplementedInterfaces.FirstOrDefault(i => i.IsConstructedGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>));
if (iEnumerableOfT == null)
return CollectionTracker.Wrap(enumerable);

var enumerableType = iEnumerableOfT.GenericTypeArguments[0];
#if XUNIT_NULLABLE
var method = cacheOfAsTrackerByType.GetOrAdd(enumerableType, t => asTrackerOpenGeneric!.MakeGenericMethod(enumerableType));
#else
var method = cacheOfAsTrackerByType.GetOrAdd(enumerableType, t => asTrackerOpenGeneric.MakeGenericMethod(enumerableType));
#endif
enumerable == null
? null
: enumerable as CollectionTracker<object> ?? CollectionTracker.Wrap(enumerable);

result = method.Invoke(null, new object[] { enumerable }) as CollectionTracker;
return result ?? CollectionTracker.Wrap(enumerable);
}

/// <summary>
/// Wraps the given enumerable in an instance of <see cref="CollectionTracker{T}"/>.
Expand All @@ -66,5 +99,15 @@ static class CollectionTrackerExtensions
enumerable == null
? null
: enumerable as CollectionTracker<T> ?? CollectionTracker<T>.Wrap(enumerable);

/// <summary>
/// Enumerates the elements inside the collection tracker.
/// </summary>
public static IEnumerator GetEnumerator(this CollectionTracker tracker)
{
Assert.GuardArgumentNotNull(nameof(tracker), tracker);

return tracker.GetSafeEnumerator();
}
}
}

0 comments on commit b3e3c93

Please sign in to comment.