Skip to content

Commit

Permalink
Merge 7ae509b into 86c1df9
Browse files Browse the repository at this point in the history
  • Loading branch information
sanmai committed Nov 27, 2021
2 parents 86c1df9 + 7ae509b commit 13ce2e6
Show file tree
Hide file tree
Showing 4 changed files with 350 additions and 14 deletions.
1 change: 1 addition & 0 deletions .php-cs-fixer.dist.php
Expand Up @@ -57,6 +57,7 @@
],
'php_unit_test_case_static_method_calls' => false,
'yoda_style' => true,
'random_api_migration' => false,
])
->setFinder(
PhpCsFixer\Finder::create()
Expand Down
135 changes: 127 additions & 8 deletions src/Standard.php
Expand Up @@ -26,7 +26,6 @@
use function array_slice;
use function array_values;
use ArrayIterator;
use function assert;
use CallbackFilterIterator;
use function count;
use Countable;
Expand All @@ -38,6 +37,8 @@
use Iterator;
use function iterator_to_array;
use IteratorAggregate;
use function mt_getrandmax;
use function mt_rand;
use Traversable;

/**
Expand Down Expand Up @@ -378,6 +379,17 @@ public function count(): int
return count($this->pipeline);
}

private static function makeNonRewindable(iterable $input): Generator
{
if ($input instanceof Generator) {
return $input;
}

return (static function (iterable $input) {
yield from $input;
})($input);
}

/**
* Extracts a slice from the inputs. Keys are not discarded intentionally.
*
Expand Down Expand Up @@ -409,15 +421,14 @@ public function slice(int $offset, ?int $length = null)
return $this;
}

$this->pipeline = self::makeNonRewindable($this->pipeline);

if ($offset < 0) {
// If offset is negative, the sequence will start that far from the end of the array.
$this->pipeline = self::tail($this->pipeline, -$offset);
}

if ($offset > 0) {
// @infection-ignore-all
assert($this->pipeline instanceof Iterator);

// If offset is non-negative, the sequence will start at that offset in the array.
$this->pipeline = self::skip($this->pipeline, $offset);
}
Expand All @@ -438,7 +449,7 @@ public function slice(int $offset, ?int $length = null)
/**
* @psalm-param positive-int $skip
*/
private static function skip(Iterator $input, int $skip): iterable
private static function skip(Iterator $input, int $skip): Generator
{
// Consume until seen enough.
foreach ($input as $_) {
Expand All @@ -459,7 +470,7 @@ private static function skip(Iterator $input, int $skip): iterable
/**
* @psalm-param positive-int $take
*/
private static function take(iterable $input, int $take): iterable
private static function take(Generator $input, int $take): Generator
{
foreach ($input as $key => $value) {
yield $key => $value;
Expand All @@ -469,9 +480,12 @@ private static function take(iterable $input, int $take): iterable
break;
}
}

// the break above will leave the generator in an inconsistent state
$input->next();
}

private static function tail(iterable $input, int $length): iterable
private static function tail(iterable $input, int $length): Generator
{
$buffer = [];

Expand All @@ -496,7 +510,7 @@ private static function tail(iterable $input, int $length): iterable
/**
* Allocates a buffer of $length, and reads records into it, proceeding with FIFO when buffer is full.
*/
private static function head(iterable $input, int $length): iterable
private static function head(iterable $input, int $length): Generator
{
$buffer = [];

Expand Down Expand Up @@ -569,4 +583,109 @@ private static function toIterators(iterable ...$inputs): array
return new ArrayIterator($input);
}, $inputs);
}

/**
* Reservoir sampling method with an optional weighting function. Uses the most optimal algorithm.
*
* @see https://en.wikipedia.org/wiki/Reservoir_sampling
*
* @param int $size The desired sample size
* @param ?callable $weightFunc The optional weighting function
*/
public function reservoir(int $size, ?callable $weightFunc = null): array
{
if (null === $this->pipeline) {
return [];
}

if ($size <= 0) {
return [];
}

// Algorithms below assume inputs are non-rewindable
$this->pipeline = self::makeNonRewindable($this->pipeline);

$result = null === $weightFunc ?
self::reservoirRandom($this->pipeline, $size) :
self::reservoirWeighted($this->pipeline, $size, $weightFunc);

return iterator_to_array($result, true);
}

/**
* Simple and slow algorithm, commonly known as Algorithm R.
*
* @see https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm
* @psalm-param positive-int $size
*/
private static function reservoirRandom(Generator $input, int $size): Generator
{
// Take an initial sample (AKA fill the reservoir array)
foreach (self::take($input, $size) as $output) {
yield $output;
}

// Return if there's nothing more to fetch
if (!$input->valid()) {
return;
}

$counter = $size;

// Produce replacement elements with gradually decreasing probability
foreach ($input as $value) {
$key = mt_rand(0, $counter);

if ($key < $size) {
yield $key => $value;
}

++$counter;
}
}

/**
* Weighted random sampling.
*
* @see https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao
* @psalm-param positive-int $size
*/
private static function reservoirWeighted(Generator $input, int $size, callable $weightFunc): Generator
{
$sum = 0.0;

// Take an initial sample (AKA fill the reservoir array)
foreach (self::take($input, $size) as $output) {
yield $output;
$sum += $weightFunc($output);
}

// Return if there's nothing more to fetch
if (!$input->valid()) {
return;
}

foreach ($input as $value) {
$weight = $weightFunc($value);
$sum += $weight;

// probability for this item
$probability = $weight / $sum;

// @infection-ignore-all
if (self::random() <= $probability) {
yield mt_rand(0, $size - 1) => $value;
}
}
}

/**
* Returns a pseudorandom value between zero (inclusive) and one (exclusive).
*
* @infection-ignore-all
*/
private static function random(): float
{
return mt_rand(0, mt_getrandmax() - 1) / mt_getrandmax();
}
}
183 changes: 183 additions & 0 deletions tests/ReservoirTest.php
@@ -0,0 +1,183 @@
<?php
/**
* Copyright 2017, 2018 Alexey Kopytko <alexey@kopytko.com>
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

declare(strict_types=1);

namespace Tests\Pipeline;

use function abs;
use ArrayIterator;
use IteratorIterator;
use function mt_rand;
use function mt_srand;
use function ord;
use PHPUnit\Framework\TestCase;
use function Pipeline\map;
use Pipeline\Standard;
use function Pipeline\take;
use function range;
use function sin;

/**
* @covers \Pipeline\Standard
*
* @internal
*/
final class ReservoirTest extends TestCase
{
protected function setUp(): void
{
mt_srand(0);
}

public function testRandomSeed(): void
{
$this->assertSame(
[20, 17, 11, 13, 18, 13],
take(range(0, 5))->map(function () {
return mt_rand(10, 20);
})->toArray()
);
}

public function testNoop(): void
{
$pipeline = new Standard();
$this->assertSame([], $pipeline->reservoir(1000));
}

public function provideInputs(): iterable
{
yield 'no change expected' => [['a', 'b', 'c'], 3, ['a', 'b', 'c']];

yield [['a', 'b', 'c'], -1, []];

yield [['a', 'b', 'c'], 0, []];

yield [['a', 'b', 'c'], 1, ['c']];

yield [['a', 'b', 'c'], 2, ['a', 'b']];

yield [['a', 'b', 'c'], 4, ['a', 'b', 'c']];

yield [['a', 'b', 'c', 'd', 'e', 'f'], 2, ['f', 'b']];

yield [['a', 'b', 'c', 'd', 'e', 'f'], 3, ['d', 'b', 'c']];

yield [['a', 'b', 'c', 'd', 'e', 'f'], 4, ['a', 'b', 'c', 'f']];

yield [range(0, 1000), 10, [
838,
96,
381,
971,
87,
715,
589,
168,
693,
366,
]];
}

/**
* @dataProvider provideInputs
*/
public function testSampleFromGenerator(array $input, int $size, array $expected): void
{
$this->assertSame($expected, map(static function () use ($input) {
yield from $input;
})->reservoir($size));
}

/**
* @dataProvider provideInputs
*/
public function testSampleFromArray(array $input, int $size, array $expected): void
{
$this->assertSame($expected, take($input)->reservoir($size));
}

/**
* @dataProvider provideInputs
*/
public function testSampleFromIterator(array $input, int $size, array $expected): void
{
$input = new IteratorIterator(new ArrayIterator($input));

$this->assertSame($expected, take($input)->reservoir($size));
}

public function provideWeightedInputs(): iterable
{
$weightFn = static function (string $input): float {
return abs(sin(ord($input[0])));
};

yield 'no change expected' => [['a', 'b', 'c'], 3, $weightFn, ['a', 'b', 'c']];

yield [['a', 'b', 'c'], -1, $weightFn, []];

yield [['a', 'b', 'c'], 0, $weightFn, []];

yield [['a', 'b', 'c'], 1, $weightFn, ['c']];

yield [['a', 'b', 'c'], 2, $weightFn, ['a', 'c']];

yield [['a', 'b', 'c'], 4, $weightFn, ['a', 'b', 'c']];

$weightFnInt = static function (int $input): float {
return abs(sin($input / 1000));
};

yield [range(0, 1000), 5, $weightFnInt, [
437,
1,
358,
240,
197,
]];
}

/**
* @dataProvider provideWeightedInputs
*/
public function testWeightedSampleFromGenerator(array $input, int $size, callable $weightFn, array $expected): void
{
$this->assertSame($expected, map(static function () use ($input) {
yield from $input;
})->reservoir($size, $weightFn));
}

/**
* @dataProvider provideWeightedInputs
*/
public function testWeightedSampleFromArray(array $input, int $size, callable $weightFn, array $expected): void
{
$this->assertSame($expected, take($input)->reservoir($size, $weightFn));
}

/**
* @dataProvider provideWeightedInputs
*/
public function testWeightedSampleFromIterator(array $input, int $size, callable $weightFn, array $expected): void
{
$input = new IteratorIterator(new ArrayIterator($input));

$this->assertSame($expected, take($input)->reservoir($size, $weightFn));
}
}

0 comments on commit 13ce2e6

Please sign in to comment.