Skip to content

Commit

Permalink
Vector SIMD refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
sys27 committed Jul 12, 2023
1 parent ee7cd1f commit 4ec3016
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 30 deletions.
5 changes: 5 additions & 0 deletions xFunc.Benchmark/Benchmarks/VectorBenchmark.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Collections.Immutable;
using xFunc.Maths.Expressions.Statistical;

namespace xFunc.Benchmark.Benchmarks;

Expand Down Expand Up @@ -48,4 +49,8 @@ public object MulVectors()
[Benchmark]
public object MulVectorByNumber()
=> new Mul(vector1, Number.Two).Execute();

[Benchmark]
public object SumVector()
=> new Sum(new[] { vector1 }).Execute();
}
46 changes: 16 additions & 30 deletions xFunc.Maths/Expressions/Matrices/VectorValue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -196,35 +196,28 @@ public int Size
public static object Abs(VectorValue vector)
{
var size = Simd.Vector<double>.Count;
var i = 0;
var sum = 0.0;

if (Simd.Vector.IsHardwareAccelerated && vector.Size >= size)
{
var span = Unsafe.As<double[]>(vector.array).AsSpan();
var span = Unsafe.As<double[]>(vector.array);

var v = Simd.Vector<double>.Zero;
var i = 0;

for (; i <= span.Length - size; i += size)
{
var chunkVector = new Simd.Vector<double>(span[i..]);
var chunkVector = new Simd.Vector<double>(span, i);
v += chunkVector * chunkVector;
}

var sum = Simd.Vector.Sum(v);

for (; i < span.Length; i++)
sum += span[i] * span[i];

return NumberValue.Sqrt(new NumberValue(sum));
sum = Simd.Vector.Sum(v);
}
else
{
var sum = NumberValue.Zero;
for (var i = 0; i < vector.Size; i++)
sum += vector[i] * vector[i];

return NumberValue.Sqrt(sum);
}
for (; i < vector.Size; i++)
sum += vector[i].Number * vector[i].Number;

return NumberValue.Sqrt(new NumberValue(sum));
}

/// <summary>
Expand Down Expand Up @@ -293,14 +286,15 @@ public static NumberValue Mul(VectorValue left, VectorValue right)
throw new ArgumentException(Resource.MatrixArgException);

var size = Simd.Vector<double>.Count;
var i = 0;
var product = 0.0;

if (Simd.Vector.IsHardwareAccelerated && left.Size >= size)
{
var leftSpan = Unsafe.As<double[]>(left.array).AsSpan();
var rightSpan = Unsafe.As<double[]>(right.array).AsSpan();

var v = Simd.Vector<double>.Zero;
var i = 0;

for (; i <= leftSpan.Length - size; i += size)
{
Expand All @@ -310,21 +304,13 @@ public static NumberValue Mul(VectorValue left, VectorValue right)
v += leftV * rightV;
}

var product = Simd.Vector.Sum(v);

for (; i < leftSpan.Length; i++)
product += leftSpan[i] * rightSpan[i];

return new NumberValue(product);
product = Simd.Vector.Sum(v);
}
else
{
var product = new NumberValue(0.0);
for (var i = 0; i < left.Size; i++)
product += left[i] * right[i];

return product;
}
for (; i < left.Size; i++)
product += left[i].Number * right[i].Number;

return new NumberValue(product);
}

/// <summary>
Expand Down
43 changes: 43 additions & 0 deletions xFunc.Tests/Expressions/Matrices/VectorValueTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -408,4 +408,47 @@ public static IEnumerable<object[]> MulTestData()
[MemberData(nameof(MulTestData))]
public void MulTest(VectorValue v1, VectorValue v2, NumberValue expected)
=> Assert.Equal(expected, v1 * v2);

public static IEnumerable<object[]> SumTestData()
{
yield return new object[]
{
VectorValue.Create(new NumberValue(1), new NumberValue(2), new NumberValue(3)),
new NumberValue(6),
};

yield return new object[]
{
VectorValue.Create(
new NumberValue(1),
new NumberValue(2),
new NumberValue(3),
new NumberValue(4),
new NumberValue(5),
new NumberValue(6),
new NumberValue(7),
new NumberValue(8)),
new NumberValue(36),
};

yield return new object[]
{
VectorValue.Create(
new NumberValue(10),
new NumberValue(20),
new NumberValue(30),
new NumberValue(40),
new NumberValue(50),
new NumberValue(60),
new NumberValue(70),
new NumberValue(80),
new NumberValue(90)),
new NumberValue(450),
};
}

[Theory]
[MemberData(nameof(SumTestData))]
public void SumTest(VectorValue v1, NumberValue expected)
=> Assert.Equal(expected, v1.Sum());
}

0 comments on commit 4ec3016

Please sign in to comment.