From 2dc155d82ef311263c2bf6445bdf114937db95dc Mon Sep 17 00:00:00 2001 From: Alexey Kopytko Date: Sat, 15 Apr 2023 08:10:58 +0900 Subject: [PATCH] Introduce Welford's online algorithm for variance --- .gitignore | 3 +- .phan/config.php | 2 +- README.md | 2 +- composer.json | 2 +- src/Helper/RunningVariance.php | 138 +++++++++++++++ src/Standard.php | 80 +++++++-- tests/Helper/RunningVarianceTest.php | 241 +++++++++++++++++++++++++++ tests/VarianceTest.php | 158 ++++++++++++++++++ 8 files changed, 611 insertions(+), 15 deletions(-) create mode 100644 src/Helper/RunningVariance.php create mode 100644 tests/Helper/RunningVarianceTest.php create mode 100644 tests/VarianceTest.php diff --git a/.gitignore b/.gitignore index 483ab79..f55f9fc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /composer.lock /vendor /build -/.phpunit.result.cache +*.cache +.?* diff --git a/.phan/config.php b/.phan/config.php index 53c8d56..c7b633e 100644 --- a/.phan/config.php +++ b/.phan/config.php @@ -3,7 +3,7 @@ use Phan\Issue; return [ - 'target_php_version' => '7.1', + 'target_php_version' => '7.4', 'backward_compatibility_checks' => false, 'exclude_analysis_directory_list' => [ 'vendor/', diff --git a/README.md b/README.md index bf83192..3b30942 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ This rigorously tested library just works. Pipeline neither defines nor throws a composer require sanmai/pipeline -The latest version requires PHP 7.1 or above, including PHP 8.0 and later. +The latest version requires PHP 7.4 or above, including PHP 8.0 and later. There are earlier versions that work under PHP 5.6 and above, but they are not as feature complete. diff --git a/composer.json b/composer.json index 4229b13..a3ce3ed 100644 --- a/composer.json +++ b/composer.json @@ -10,7 +10,7 @@ } ], "require": { - "php": "^7.1 || ^8.0" + "php": "^7.4 || ^8.0" }, "require-dev": { "ergebnis/composer-normalize": "^2.8", diff --git a/src/Helper/RunningVariance.php b/src/Helper/RunningVariance.php new file mode 100644 index 0000000..bffccad --- /dev/null +++ b/src/Helper/RunningVariance.php @@ -0,0 +1,138 @@ + + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +declare(strict_types=1); + +namespace Pipeline\Helper; + +use function sqrt; + +/** + * Computes statistics (such as standard deviation) in real time. + * + * @see https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm + * + * @final + */ +class RunningVariance +{ + private const ZERO = 0; + + /** The number of observed values. */ + private int $count = 0; + + /** The mean value. */ + private float $mean = 0.0; + + /** The aggregated squared distance from the mean. */ + private float $m2 = 0.0; + + public function __construct(self ...$spiesToMerge) + { + foreach ($spiesToMerge as $spy) { + $this->merge($spy); + } + } + + public function observe(float $value): float + { + ++$this->count; + + $delta = $value - $this->mean; + + $this->mean += $delta / $this->count; + $this->m2 += $delta * ($value - $this->mean); + + return $value; + } + + /** + * The number of observed values. + */ + public function getCount(): int + { + return $this->count; + } + + /** + * Get the mean value. + */ + public function getMean(): float + { + if (self::ZERO === $this->count) { + // For no values the variance is undefined. + return NAN; + } + + return $this->mean; + } + + /** + * Get the variance. + */ + public function getVariance(): float + { + if (self::ZERO === $this->count) { + // For no values the variance is undefined. + return NAN; + } + + if (1 === $this->count) { + // Avoiding division by zero: variance for one value is zero. + return 0.0; + } + + return $this->m2 / ($this->count - 1); + } + + /** + * Compute the standard deviation. + */ + public function getStandardDeviation(): float + { + return sqrt($this->getVariance()); + } + + /** + * Merge another instance into this instance. + * + * @see https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + */ + private function merge(self $other): void + { + // Shortcut a no-op + if (self::ZERO === $other->count) { + return; + } + + // Avoid division by zero by copying values + if (self::ZERO === $this->count) { + $this->count = $other->count; + $this->mean = $other->mean; + $this->m2 = $other->m2; + + return; + } + + $count = $this->count + $other->count; + $delta = $other->mean - $this->mean; + + $this->mean = ($this->count * $this->mean) / $count + ($other->count * $other->mean) / $count; + $this->m2 = $this->m2 + $other->m2 + ($delta * $delta * $this->count * $other->count / $count); + $this->count = $count; + } +} diff --git a/src/Standard.php b/src/Standard.php index f3302a6..e0f0f5a 100644 --- a/src/Standard.php +++ b/src/Standard.php @@ -26,6 +26,7 @@ use Generator; use Iterator; use IteratorAggregate; +use Pipeline\Helper\RunningVariance; use Traversable; use function array_chunk; use function array_filter; @@ -63,10 +64,8 @@ class Standard implements IteratorAggregate, Countable /** * Contructor with an optional source of data. - * - * @param ?iterable $input */ - public function __construct(iterable $input = null) + public function __construct(?iterable $input = null) { // IteratorAggregate is a nuance best we avoid dealing with. // For example, CallbackFilterIterator needs a plain Iterator. @@ -79,10 +78,8 @@ public function __construct(iterable $input = null) /** * Appends the contents of an interable to the end of the pipeline. - * - * @param ?iterable $values */ - public function append(iterable $values = null): self + public function append(?iterable $values = null): self { // Do we need to do anything here? if ($this->willReplace($values)) { @@ -108,10 +105,8 @@ public function push(...$vector): self /** * Prepends the pipeline with the contents of an iterable. - * - * @param ?iterable $values */ - public function prepend(iterable $values = null): self + public function prepend(?iterable $values = null): self { // Do we need to do anything here? if ($this->willReplace($values)) { @@ -140,7 +135,7 @@ public function unshift(...$vector): self * * Utility method for appending/prepending methods. */ - private function willReplace(iterable $values = null): bool + private function willReplace(?iterable $values = null): bool { // Nothing needs to be done here. /** @phan-suppress-next-line PhanTypeComparisonFromArray */ @@ -263,7 +258,7 @@ public function chunk(int $length, bool $preserve_keys = false): self } /** - * @psalm-param positive-int $length + * @psalm-param positive-int $length */ private static function toChunks(Generator $input, int $length, bool $preserve_keys): iterable { @@ -992,4 +987,67 @@ private static function flipKeysAndValues(iterable $previous): iterable yield $value => $key; } } + + /** + * Feeds in an instance of RunningVariance. + * + * @param RunningVariance $variance the instance of RunningVariance + * @param ?callable $castFunc the cast callback, returning ?float; null values are not counted + * + * @return $this + */ + public function feedRunningVariance(RunningVariance $variance, ?callable $castFunc = null) + { + if (null === $castFunc) { + $castFunc = 'floatval'; + } + + return $this->cast(static function ($value) use ($variance, $castFunc) { + $float = $castFunc($value); + + if (null !== $float) { + $variance->observe($float); + } + + // Returning the original value here + return $value; + }); + } + + public function onlineVariance(?callable $castFunc = null): RunningVariance + { + $variance = new RunningVariance(); + + $this->feedRunningVariance($variance, $castFunc); + + return $variance; + } + + public function variance(?callable $castFunc = null): RunningVariance + { + $variance = new RunningVariance(); + + if (null === $this->pipeline) { + // No-op: null. + return $variance; + } + + if ([] === $this->pipeline) { + // No-op: an empty array. + return $variance; + } + + $this->feedRunningVariance($variance, $castFunc); + + if (is_array($this->pipeline)) { + // We are done! + return $variance; + } + + foreach ($this->pipeline as $_) { + // Discard + } + + return $variance; + } } diff --git a/tests/Helper/RunningVarianceTest.php b/tests/Helper/RunningVarianceTest.php new file mode 100644 index 0000000..7c885ba --- /dev/null +++ b/tests/Helper/RunningVarianceTest.php @@ -0,0 +1,241 @@ + + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +declare(strict_types=1); + +namespace Tests\Pipeline\Helper; + +use PHPUnit\Framework\TestCase; +use Pipeline\Helper\RunningVariance; +use function array_sum; +use function cos; +use function count; +use function log; +use function mt_getrandmax; +use function mt_rand; +use function Pipeline\take; +use function sin; +use function sqrt; + +/** + * @internal + * + * @covers \Pipeline\Helper\RunningVariance + */ +final class RunningVarianceTest extends TestCase +{ + public function testEmpty(): void + { + $variance = new RunningVariance(); + + $this->assertSame(0, $variance->getCount()); + $this->assertNan($variance->getMean()); + $this->assertNan($variance->getVariance()); + $this->assertNan($variance->getStandardDeviation()); + } + + public function testEmptyPlusEmpty(): void + { + $variance = new RunningVariance(new RunningVariance()); + + $this->assertSame(0, $variance->getCount()); + $this->assertNan($variance->getMean()); + $this->assertNan($variance->getVariance()); + $this->assertNan($variance->getStandardDeviation()); + } + + public function testOne(): void + { + $variance = new RunningVariance(); + $variance->observe(M_PI); + + $this->assertSame(1, $variance->getCount()); + $this->assertSame(M_PI, $variance->getMean()); + $this->assertSame(0.0, $variance->getVariance()); + $this->assertSame(0.0, $variance->getStandardDeviation()); + } + + public function testTwo(): void + { + $variance = new RunningVariance(); + $variance->observe(M_PI); + $variance->observe(M_PI); + + $this->assertSame(2, $variance->getCount()); + $this->assertSame(M_PI, $variance->getMean()); + $this->assertSame(0.0, $variance->getVariance()); + $this->assertSame(0.0, $variance->getStandardDeviation()); + } + + public function testNAN(): void + { + $variance = new RunningVariance(); + $variance->observe(M_PI); + $variance->observe(NAN); + + $this->assertSame(2, $variance->getCount()); + $this->assertNan($variance->getMean()); + $this->assertNan($variance->getVariance()); + $this->assertNan($variance->getStandardDeviation()); + } + + public function testFive(): void + { + $variance = new RunningVariance(); + $variance->observe(4.0); + $variance->observe(2.0); + $variance->observe(5.0); + $variance->observe(8.0); + $variance->observe(6.0); + + $this->assertSame(5, $variance->getCount()); + $this->assertSame(5.0, $variance->getMean()); + $this->assertEqualsWithDelta(5.0, $variance->getVariance(), 0.0001); + $this->assertEqualsWithDelta(sqrt(5.0), $variance->getStandardDeviation(), 0.0001); + } + + public function testCopy(): void + { + $variance = new RunningVariance(); + $variance->observe(4.0); + $variance->observe(2.0); + $variance->observe(5.0); + $variance->observe(8.0); + $variance->observe(6.0); + + $variance = new RunningVariance($variance); + + $this->assertSame(5, $variance->getCount()); + $this->assertSame(5.0, $variance->getMean()); + $this->assertEqualsWithDelta(5.0, $variance->getVariance(), 0.0001); + $this->assertEqualsWithDelta(sqrt(5.0), $variance->getStandardDeviation(), 0.0001); + } + + public function testFiveMerged(): void + { + $variance = new RunningVariance(); + $variance->observe(4.0); + $variance->observe(2.0); + + $variance = new RunningVariance($variance); + $variance->observe(5.0); + $variance->observe(8.0); + $variance->observe(6.0); + + $this->assertSame(5, $variance->getCount()); + $this->assertSame(5.0, $variance->getMean()); + $this->assertEqualsWithDelta(5.0, $variance->getVariance(), 0.0001); + $this->assertEqualsWithDelta(sqrt(5.0), $variance->getStandardDeviation(), 0.0001); + } + + public function testFiveMergedTwice(): void + { + $varianceA = new RunningVariance(); + $varianceA->observe(5.0); + $varianceA->observe(8.0); + $varianceA->observe(6.0); + + $varianceB = new RunningVariance(); + $varianceB->observe(4.0); + $varianceB->observe(2.0); + + $variance = new RunningVariance($varianceA, $varianceB); + + $this->assertSame(5, $variance->getCount()); + $this->assertSame(5.0, $variance->getMean()); + $this->assertEqualsWithDelta(5.0, $variance->getVariance(), 0.0001); + $this->assertEqualsWithDelta(sqrt(5.0), $variance->getStandardDeviation(), 0.0001); + } + + public function provideRandomNumberCounts(): iterable + { + yield ['count' => 900, 'mean' => 8.1, 'sigma' => 1.9]; + + yield ['count' => 1190, 'mean' => 729.4, 'sigma' => 4.2]; + + yield ['count' => 1500, 'mean' => 3698.41, 'sigma' => 12.9]; + } + + /** + * @coversNothing + * + * @dataProvider provideRandomNumberCounts + */ + public function testNumericStability(int $count, float $mean, float $sigma): void + { + $numbers = take(self::getRandomNumbers($mean, $sigma)) + ->slice(0, $count)->toArray(); + + $benchmark = self::standard_deviation($numbers); + + $variance = take($numbers)->variance(); + + $this->assertEqualsWithDelta($benchmark, $variance->getStandardDeviation(), $sigma / 100); // 1% + } + + /** + * @coversNothing + * + * @dataProvider provideRandomNumberCounts + */ + public function testMullerTransform(int $count, float $mean, float $sigma): void + { + $numbers = take(self::getRandomNumbers($mean, $sigma)) + ->slice(0, $count) + ->toArray(); + + $this->assertEqualsWithDelta($sigma, self::standard_deviation( + $numbers + ), $sigma / 10); + } + + /** + * @see https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform + * + * @param float $mean The target average/mean + * @param float $sigma The target standard deviation + * + * @return iterable + */ + private static function getRandomNumbers(float $mean, float $sigma): iterable + { + $two_pi = 2 * M_PI; + $epsilon = 1E-6; // Arbitrary number + + while (true) { + do { + $u1 = mt_rand() / mt_getrandmax(); + } while ($u1 <= $epsilon); + + $mag = $sigma * sqrt(-2.0 * log($u1)); + + $u2 = mt_rand() / mt_getrandmax(); + + yield $mag * cos($two_pi * $u2) + $mean; + yield $mag * sin($two_pi * $u2) + $mean; + } + } + + private static function standard_deviation(array $input) + { + $mean = array_sum($input) / count($input); + + $carry = take($input)->cast(fn (float $val) => ($val - $mean) ** 2)->reduce(); + + return sqrt($carry / count($input)); + } +} diff --git a/tests/VarianceTest.php b/tests/VarianceTest.php new file mode 100644 index 0000000..05140d3 --- /dev/null +++ b/tests/VarianceTest.php @@ -0,0 +1,158 @@ + + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +declare(strict_types=1); + +namespace Tests\Pipeline; + +use PHPUnit\Framework\TestCase; +use Pipeline\Helper\RunningVariance; +use function Pipeline\fromArray; +use function Pipeline\map; + +/** + * @covers \Pipeline\Standard::feedRunningVariance() + * @covers \Pipeline\Standard::onlineVariance() + * @covers \Pipeline\Standard::variance() + * + * @internal + */ +final class VarianceTest extends TestCase +{ + public function testVarianceUnitinialized(): void + { + $pipeline = new \Pipeline\Standard(); + + $this->assertSame(0, $pipeline->variance()->getCount()); + } + + public function testVarianceEmptyArray(): void + { + $this->assertSame(0, fromArray([])->variance()->getCount()); + } + + public function testVarianceNANPassThrough(): void + { + $this->assertNan(fromArray([1.0, 2.0, 3.0, NAN])->variance()->getStandardDeviation()); + } + + public function testVarianceArray(): void + { + $this->assertEqualsWithDelta( + 2.2913, + fromArray([5, 5, 9, 9, 9, 10, 5, 10, 10])->variance()->getStandardDeviation(), + 0.0001 + ); + } + + public function testVarianceIterable(): void + { + $pipeline = map(fn () => yield from [5, 5, 9, 9, 9, 10, 5, 10, 10]); + + $this->assertEqualsWithDelta( + 2.2913, + map(fn () => yield from [5, 5, 9, 9, 9, 10, 5, 10, 10])->variance()->getStandardDeviation(), + 0.0001 + ); + } + + public function testVarianceCast(): void + { + $pipeline = map(fn () => yield from [-10, -20, 5, 5, 9, 9, 9, 10, 5, 10, 10, 100, 200]); + + $variance = $pipeline->variance(static function (int $number): ?float { + if ($number < 0 || $number > 10) { + return null; + } + + return (float) $number; + }); + + $this->assertEqualsWithDelta( + 2.2913, + $variance->getStandardDeviation(), + 0.0001 + ); + } + + public function testOnlineVariance(): void + { + $pipeline = map(fn () => yield from [-10, -20, 5, 5, 9, 9, 9, 10, 5, 10, 10, 100, 200]); + + $variance = $pipeline->onlineVariance(static function (int $number): ?float { + if ($number < 0 || $number > 10) { + return null; + } + + return (float) $number; + }); + + $this->assertSame(0, $variance->getCount()); + + // Now, count all values + $this->assertSame(13, $pipeline->count()); + + // Only valid values are accounted for variance + $this->assertSame(9, $variance->getCount()); + + $this->assertEqualsWithDelta( + 2.2913, + $variance->getStandardDeviation(), + 0.0001 + ); + } + + public function testFeedVariance(): void + { + $pipeline = map(fn () => yield from [5, 5, 9, 9, 9, 10, 5, 10, 10]); + + $variance = new RunningVariance(); + + $pipeline->feedRunningVariance($variance, 'floatval'); + + $this->assertSame(0, $variance->getCount()); + + $this->assertSame(9, $pipeline->count()); + $this->assertSame(9, $variance->getCount()); + + $this->assertEqualsWithDelta( + 2.2913, + $variance->getStandardDeviation(), + 0.0001 + ); + } + + public function testFeedVarianceArray(): void + { + $pipeline = fromArray([5, 5, 9, 9, 9, 10, 5, 10, 10]); + + $variance = new RunningVariance(); + + $pipeline->feedRunningVariance($variance, 'floatval'); + + // Arrays are eagerly processed + $this->assertSame(9, $variance->getCount()); + $this->assertSame(9, $pipeline->count()); + $this->assertSame(9, $variance->getCount()); + + $this->assertEqualsWithDelta( + 2.2913, + $variance->getStandardDeviation(), + 0.0001 + ); + } +}