-
-
Notifications
You must be signed in to change notification settings - Fork 9
/
AssertCount.cs
131 lines (113 loc) · 3.57 KB
/
AssertCount.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
namespace SuperLinq;
public static partial class SuperEnumerable
{
    /// <summary>
    /// Asserts that a source sequence contains a given count of elements.
    /// </summary>
    /// <typeparam name="TSource">
    /// Type of elements in <paramref name="source"/> sequence.
    /// </typeparam>
    /// <param name="source">
    /// Source sequence.
    /// </param>
    /// <param name="count">
    /// Count to assert.
    /// </param>
    /// <returns>
    /// Returns the original sequence as long as it contains exactly the number of elements specified by
    /// <paramref name="count"/>. Otherwise it throws <see cref="ArgumentException" />.
    /// </returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="source"/> is <see langword="null" />.
    /// </exception>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="count"/> is less than <c>0</c>.
    /// </exception>
    /// <exception cref="ArgumentException">
    /// Thrown lazily if <paramref name="source"/> has a length different than <paramref name="count"/>. The
    /// concrete exception type is <see cref="ArgumentOutOfRangeException"/>, which derives from
    /// <see cref="ArgumentException"/>.
    /// </exception>
    /// <remarks>
    /// The sequence length is evaluated lazily during the enumeration of the sequence.
    /// </remarks>
    public static IEnumerable<TSource> AssertCount<TSource>(this IEnumerable<TSource> source, int count)
    {
        ArgumentNullException.ThrowIfNull(source);
        ArgumentOutOfRangeException.ThrowIfNegative(count);

        // Sources with a known count get a wrapper that checks the count directly instead of streaming.
        if (source is IList<TSource> list)
            return new AssertCountListIterator<TSource>(list, count);
        if (source.TryGetCollectionCount() is int)
            return new AssertCountCollectionIterator<TSource>(source, count);

        return Core(source, count);

        static IEnumerable<TSource> Core(IEnumerable<TSource> source, int count)
        {
            // A long counter cannot wrap around. With an int, count == int.MaxValue and a source longer than
            // int.MaxValue would overflow `++c` and silently defeat the "too many elements" check.
            var c = 0L;
            foreach (var item in source)
            {
                // Stop as soon as one element beyond `count` is observed; the check below then reports it.
                if (++c > count)
                    break;
                yield return item;
            }

            ArgumentOutOfRangeException.ThrowIfNotEqual(c, (long)count, $"{nameof(source)}.Count()");
        }
    }

    /// <summary>
    /// Count-asserting wrapper used by <c>AssertCount</c> when <c>TryGetCollectionCount</c> reports a count for
    /// the source (and the source is not an <see cref="IList{T}"/>).
    /// </summary>
    /// <typeparam name="T">Type of elements in the source.</typeparam>
    private sealed class AssertCountCollectionIterator<T>(
        IEnumerable<T> source,
        int count
    ) : CollectionIterator<T>
    {
        /// <summary>
        /// Gets the asserted count after verifying that the source's collection count matches it.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">
        /// The source's collection count differs from the asserted count.
        /// </exception>
        public override int Count
        {
            get
            {
                ArgumentOutOfRangeException.ThrowIfNotEqual(source.GetCollectionCount(), count, "source.Count()");
                return count;
            }
        }

        protected override IEnumerable<T> GetEnumerable()
        {
            // The count is checked before any element is yielded. Because this is an iterator method, the
            // check itself only runs once enumeration starts.
            ArgumentOutOfRangeException.ThrowIfNotEqual(source.GetCollectionCount(), count, "source.Count()");
            foreach (var item in source)
                yield return item;
        }

        public override void CopyTo(T[] array, int arrayIndex)
        {
            ArgumentNullException.ThrowIfNull(array);
            ArgumentOutOfRangeException.ThrowIfNegative(arrayIndex);
            // Reading Count here also performs the count assertion before anything is copied.
            ArgumentOutOfRangeException.ThrowIfGreaterThan(arrayIndex, array.Length - Count);
            _ = source.CopyTo(array, arrayIndex);
        }
    }

    /// <summary>
    /// Count-asserting wrapper used by <c>AssertCount</c> when the source is an <see cref="IList{T}"/>.
    /// </summary>
    /// <typeparam name="T">Type of elements in the source list.</typeparam>
    private sealed class AssertCountListIterator<T>(
        IList<T> source,
        int count
    ) : ListIterator<T>
    {
        /// <summary>
        /// Gets the asserted count after verifying that the source list's <c>Count</c> matches it.
        /// </summary>
        /// <exception cref="ArgumentOutOfRangeException">
        /// The source list's count differs from the asserted count.
        /// </exception>
        public override int Count
        {
            get
            {
                ArgumentOutOfRangeException.ThrowIfNotEqual(source.Count, count, "source.Count()");
                return count;
            }
        }

        protected override IEnumerable<T> GetEnumerable()
        {
            // Reading Count asserts the length up front; the loop then yields exactly that many elements.
            var cnt = Count;
            for (var i = 0; i < cnt; i++)
                yield return source[i];
        }

        public override void CopyTo(T[] array, int arrayIndex)
        {
            ArgumentNullException.ThrowIfNull(array);
            ArgumentOutOfRangeException.ThrowIfNegative(arrayIndex);
            // Reading Count here also performs the count assertion before anything is copied.
            ArgumentOutOfRangeException.ThrowIfGreaterThan(arrayIndex, array.Length - Count);
            source.CopyTo(array, arrayIndex);
        }

        protected override T ElementAt(int index)
        {
            ArgumentOutOfRangeException.ThrowIfNegative(index);
            // Comparing against Count asserts the list length on every indexed access.
            ArgumentOutOfRangeException.ThrowIfGreaterThanOrEqual(index, Count);
            return source[index];
        }
    }
}