Skip to content

Commit

Permalink
Merge cd16678 into 86c1df9
Browse files Browse the repository at this point in the history
  • Loading branch information
sanmai committed Nov 27, 2021
2 parents 86c1df9 + cd16678 commit 5bd3c43
Show file tree
Hide file tree
Showing 5 changed files with 350 additions and 15 deletions.
1 change: 1 addition & 0 deletions .php-cs-fixer.dist.php
Expand Up @@ -57,6 +57,7 @@
],
'php_unit_test_case_static_method_calls' => false,
'yoda_style' => true,
'random_api_migration' => false,
])
->setFinder(
PhpCsFixer\Finder::create()
Expand Down
135 changes: 127 additions & 8 deletions src/Standard.php
Expand Up @@ -26,7 +26,6 @@
use function array_slice;
use function array_values;
use ArrayIterator;
use function assert;
use CallbackFilterIterator;
use function count;
use Countable;
Expand All @@ -38,6 +37,8 @@
use Iterator;
use function iterator_to_array;
use IteratorAggregate;
use function mt_getrandmax;
use function mt_rand;
use Traversable;

/**
Expand Down Expand Up @@ -378,6 +379,17 @@ public function count(): int
return count($this->pipeline);
}

private static function makeNonRewindable(iterable $input): Generator
{
if ($input instanceof Generator) {
return $input;
}

return (static function (iterable $input) {
yield from $input;
})($input);
}

/**
* Extracts a slice from the inputs. Keys are not discarded intentionally.
*
Expand Down Expand Up @@ -409,15 +421,14 @@ public function slice(int $offset, ?int $length = null)
return $this;
}

$this->pipeline = self::makeNonRewindable($this->pipeline);

if ($offset < 0) {
// If offset is negative, the sequence will start that far from the end of the array.
$this->pipeline = self::tail($this->pipeline, -$offset);
}

if ($offset > 0) {
// @infection-ignore-all
assert($this->pipeline instanceof Iterator);

// If offset is non-negative, the sequence will start at that offset in the array.
$this->pipeline = self::skip($this->pipeline, $offset);
}
Expand All @@ -438,7 +449,7 @@ public function slice(int $offset, ?int $length = null)
/**
* @psalm-param positive-int $skip
*/
private static function skip(Iterator $input, int $skip): iterable
private static function skip(Iterator $input, int $skip): Generator
{
// Consume until seen enough.
foreach ($input as $_) {
Expand All @@ -459,7 +470,7 @@ private static function skip(Iterator $input, int $skip): iterable
/**
* @psalm-param positive-int $take
*/
private static function take(iterable $input, int $take): iterable
private static function take(Generator $input, int $take): Generator
{
foreach ($input as $key => $value) {
yield $key => $value;
Expand All @@ -469,9 +480,12 @@ private static function take(iterable $input, int $take): iterable
break;
}
}

// the break above will leave the generator in an inconsistent state
$input->next();
}

private static function tail(iterable $input, int $length): iterable
private static function tail(iterable $input, int $length): Generator
{
$buffer = [];

Expand All @@ -496,7 +510,7 @@ private static function tail(iterable $input, int $length): iterable
/**
* Allocates a buffer of $length, and reads records into it, proceeding with FIFO when buffer is full.
*/
private static function head(iterable $input, int $length): iterable
private static function head(iterable $input, int $length): Generator
{
$buffer = [];

Expand Down Expand Up @@ -569,4 +583,109 @@ private static function toIterators(iterable ...$inputs): array
return new ArrayIterator($input);
}, $inputs);
}

/**
* Reservoir sampling method with an optional weighting function. Uses the most optimal algorithm.
*
* @see https://en.wikipedia.org/wiki/Reservoir_sampling
*
* @param int $size The desired sample size
* @param ?callable $weightFunc The optional weighting function
*/
public function reservoir(int $size, ?callable $weightFunc = null): array
{
if (null === $this->pipeline) {
return [];
}

if ($size <= 0) {
return [];
}

// Algorithms below assume inputs are non-rewindable
$this->pipeline = self::makeNonRewindable($this->pipeline);

$result = null === $weightFunc ?
self::reservoirRandom($this->pipeline, $size) :
self::reservoirWeighted($this->pipeline, $size, $weightFunc);

return iterator_to_array($result, true);
}

/**
* Simple and slow algorithm, commonly known as Algorithm R.
*
* @see https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm
* @psalm-param positive-int $size
*/
private static function reservoirRandom(Generator $input, int $size): Generator
{
// Take an initial sample (AKA fill the reservoir array)
foreach (self::take($input, $size) as $output) {
yield $output;
}

// Return if there's nothing more to fetch
if (!$input->valid()) {
return;
}

$counter = $size;

// Produce replacement elements with gradually decreasing probability
foreach ($input as $value) {
$key = mt_rand(0, $counter);

if ($key < $size) {
yield $key => $value;
}

++$counter;
}
}

/**
* Weighted random sampling.
*
* @see https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Chao
* @psalm-param positive-int $size
*/
private static function reservoirWeighted(Generator $input, int $size, callable $weightFunc): Generator
{
$sum = 0.0;

// Take an initial sample (AKA fill the reservoir array)
foreach (self::take($input, $size) as $output) {
yield $output;
$sum += $weightFunc($output);
}

// Return if there's nothing more to fetch
if (!$input->valid()) {
return;
}

foreach ($input as $value) {
$weight = $weightFunc($value);
$sum += $weight;

// probability for this item
$probability = $weight / $sum;

// @infection-ignore-all
if (self::random() <= $probability) {
yield mt_rand(0, $size - 1) => $value;
}
}
}

/**
* Returns a pseudorandom value between zero (inclusive) and one (exclusive).
*
* @infection-ignore-all
*/
private static function random(): float
{
return mt_rand(0, mt_getrandmax() - 1) / mt_getrandmax();
}
}
1 change: 0 additions & 1 deletion tests/Benchmarks/BenchTest.php
Expand Up @@ -26,7 +26,6 @@
use function random_int;

/**
* @covers \Pipeline\Principal
* @covers \Pipeline\Standard
*
* @internal
Expand Down

0 comments on commit 5bd3c43

Please sign in to comment.