Skip to content

Commit

Permalink
Feature #16812: Allow slicing arrays with ArithmeticSequence (#3241)
Browse files Browse the repository at this point in the history
* Support ArithmeticSequence in Array#slice

* Extract rb_range_component_beg_len

* Use rb_range_values to check Range object

* Fix ary_make_partial_step

* Fix for negative step cases

* range.c: Describe the role of err argument in rb_range_component_beg_len

* Raise a RangeError when an arithmetic sequence refers the outside of an array

[Feature #16812]
  • Loading branch information
mrkn committed Oct 20, 2020
1 parent 081cc4e commit a6a8576
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 31 deletions.
91 changes: 83 additions & 8 deletions array.c
Expand Up @@ -1140,6 +1140,52 @@ ary_make_partial(VALUE ary, VALUE klass, long offset, long len)
}
}

static VALUE
ary_make_partial_step(VALUE ary, VALUE klass, long offset, long len, long step)
{
assert(offset >= 0);
assert(len >= 0);
assert(offset+len <= RARRAY_LEN(ary));
assert(step != 0);

const VALUE *values = RARRAY_CONST_PTR_TRANSIENT(ary);
const long orig_len = len;

if ((step > 0 && step >= len) || (step < 0 && (step < -len))) {
VALUE result = ary_new(klass, 1);
VALUE *ptr = (VALUE *)ARY_EMBED_PTR(result);
RB_OBJ_WRITE(result, ptr, values[offset]);
ARY_SET_EMBED_LEN(result, 1);
return result;
}

long ustep = (step < 0) ? -step : step;
len = (len + ustep - 1) / ustep;

long i;
long j = offset + ((step > 0) ? 0 : (orig_len - 1));
VALUE result = ary_new(klass, len);
if (len <= RARRAY_EMBED_LEN_MAX) {
VALUE *ptr = (VALUE *)ARY_EMBED_PTR(result);
for (i = 0; i < len; ++i) {
RB_OBJ_WRITE(result, ptr+i, values[j]);
j += step;
}
ARY_SET_EMBED_LEN(result, len);
}
else {
RARRAY_PTR_USE_TRANSIENT(result, ptr, {
for (i = 0; i < len; ++i) {
RB_OBJ_WRITE(result, ptr+i, values[j]);
j += step;
}
});
ARY_SET_LEN(result, len);
}

return result;
}

static VALUE
ary_make_shared_copy(VALUE ary)
{
Expand Down Expand Up @@ -1571,7 +1617,7 @@ rb_ary_entry(VALUE ary, long offset)
}

VALUE
rb_ary_subseq(VALUE ary, long beg, long len)
rb_ary_subseq_step(VALUE ary, long beg, long len, long step)
{
VALUE klass;
long alen = RARRAY_LEN(ary);
Expand All @@ -1584,8 +1630,18 @@ rb_ary_subseq(VALUE ary, long beg, long len)
}
klass = rb_obj_class(ary);
if (len == 0) return ary_new(klass, 0);
if (step == 0)
rb_raise(rb_eArgError, "slice step cannot be zero");
if (step == 1)
return ary_make_partial(ary, klass, beg, len);
else
return ary_make_partial_step(ary, klass, beg, len, step);
}

return ary_make_partial(ary, klass, beg, len);
VALUE
rb_ary_subseq(VALUE ary, long beg, long len)
{
return rb_ary_subseq_step(ary, beg, len, 1);
}

static VALUE rb_ary_aref2(VALUE ary, VALUE b, VALUE e);
Expand All @@ -1595,6 +1651,11 @@ static VALUE rb_ary_aref2(VALUE ary, VALUE b, VALUE e);
* array[index] -> object or nil
* array[start, length] -> object or nil
* array[range] -> object or nil
* array[aseq] -> object or nil
* array.slice(index) -> object or nil
* array.slice(start, length) -> object or nil
* array.slice(range) -> object or nil
* array.slice(aseq) -> object or nil
*
* Returns elements from +self+; does not modify +self+.
*
Expand Down Expand Up @@ -1651,6 +1712,19 @@ static VALUE rb_ary_aref2(VALUE ary, VALUE b, VALUE e);
* a[-3..2] # => [:foo, "bar", 2]
*
* If <tt>range.start</tt> is larger than the array size, returns +nil+.
* a = [:foo, 'bar', 2]
* a[4..1] # => nil
* a[4..0] # => nil
* a[4..-1] # => nil
*
* When a single argument +aseq+ is given,
* ...(to be described)
*
* Raises an exception if given a single argument
* that is not an \Integer-convertible object or a \Range object:
* a = [:foo, 'bar', 2]
* # Raises TypeError (no implicit conversion of Symbol into Integer):
* a[:foo]
*
* Array#slice is an alias for Array#[].
*/
Expand Down Expand Up @@ -1679,21 +1753,22 @@ rb_ary_aref2(VALUE ary, VALUE b, VALUE e)
MJIT_FUNC_EXPORTED VALUE
rb_ary_aref1(VALUE ary, VALUE arg)
{
long beg, len;
long beg, len, step;

/* special case - speeding up */
if (FIXNUM_P(arg)) {
return rb_ary_entry(ary, FIX2LONG(arg));
}
/* check if idx is Range */
switch (rb_range_beg_len(arg, &beg, &len, RARRAY_LEN(ary), 0)) {
/* check if idx is Range or ArithmeticSequence */
switch (rb_arithmetic_sequence_beg_len_step(arg, &beg, &len, &step, RARRAY_LEN(ary), 0)) {
case Qfalse:
break;
break;
case Qnil:
return Qnil;
return Qnil;
default:
return rb_ary_subseq(ary, beg, len);
return rb_ary_subseq_step(ary, beg, len, step);
}

return rb_ary_entry(ary, NUM2LONG(arg));
}

Expand Down
44 changes: 40 additions & 4 deletions enumerator.c
Expand Up @@ -3410,17 +3410,53 @@ rb_arithmetic_sequence_extract(VALUE obj, rb_arithmetic_sequence_components_t *c
component->exclude_end = arith_seq_exclude_end_p(obj);
return 1;
}
else if (rb_obj_is_kind_of(obj, rb_cRange)) {
component->begin = RANGE_BEG(obj);
component->end = RANGE_END(obj);
else if (rb_range_values(obj, &component->begin, &component->end, &component->exclude_end)) {
component->step = INT2FIX(1);
component->exclude_end = RTEST(RANGE_EXCL(obj));
return 1;
}

return 0;
}

VALUE
rb_arithmetic_sequence_beg_len_step(VALUE obj, long *begp, long *lenp, long *stepp, long len, int err)
{
RUBY_ASSERT(begp != NULL);
RUBY_ASSERT(lenp != NULL);
RUBY_ASSERT(stepp != NULL);

rb_arithmetic_sequence_components_t aseq;
if (!rb_arithmetic_sequence_extract(obj, &aseq)) {
return Qfalse;
}

long step = NIL_P(aseq.step) ? 1 : NUM2LONG(aseq.step);
*stepp = step;

if (step < 0) {
VALUE tmp = aseq.begin;
aseq.begin = aseq.end;
aseq.end = tmp;
}

if (err == 0 && (step < -1 || step > 1)) {
if (rb_range_component_beg_len(aseq.begin, aseq.end, aseq.exclude_end, begp, lenp, len, 1) == Qtrue) {
if (*begp > len)
goto out_of_range;
if (*lenp > len)
goto out_of_range;
return Qtrue;
}
}
else {
return rb_range_component_beg_len(aseq.begin, aseq.end, aseq.exclude_end, begp, lenp, len, err);
}

out_of_range:
rb_raise(rb_eRangeError, "%+"PRIsVALUE" out of range", obj);
return Qnil;
}

/*
* call-seq:
* aseq.first -> num or nil
Expand Down
1 change: 1 addition & 0 deletions include/ruby/internal/intern/enumerator.h
Expand Up @@ -42,6 +42,7 @@ VALUE rb_enumeratorize(VALUE, VALUE, int, const VALUE *);
VALUE rb_enumeratorize_with_size(VALUE, VALUE, int, const VALUE *, rb_enumerator_size_func *);
VALUE rb_enumeratorize_with_size_kw(VALUE, VALUE, int, const VALUE *, rb_enumerator_size_func *, int);
int rb_arithmetic_sequence_extract(VALUE, rb_arithmetic_sequence_components_t *);
VALUE rb_arithmetic_sequence_beg_len_step(VALUE, long *begp, long *lenp, long *stepp, long len, int err);

RBIMPL_SYMBOL_EXPORT_END()

Expand Down
4 changes: 4 additions & 0 deletions internal/range.h
Expand Up @@ -34,4 +34,8 @@ RANGE_EXCL(VALUE r)
return RSTRUCT(r)->as.ary[2];
}

VALUE
rb_range_component_beg_len(VALUE b, VALUE e, int excl,
long *begp, long *lenp, long len, int err);

#endif /* INTERNAL_RANGE_H */
69 changes: 51 additions & 18 deletions range.c
Expand Up @@ -1329,48 +1329,81 @@ rb_range_values(VALUE range, VALUE *begp, VALUE *endp, int *exclp)
return (int)Qtrue;
}

/* Extract the components of a Range.
*
* You can use +err+ to control the behavior of out-of-range and exception.
*
* When +err+ is 0 or 2, if the begin offset is greater than +len+,
* it is out-of-range. The +RangeError+ is raised only if +err+ is 2,
* in this case. If +err+ is 0, +Qnil+ will be returned.
*
* When +err+ is 1, the begin and end offsets won't be adjusted even if they
* are greater than +len+. It allows +rb_ary_aset+ extends arrays.
*
* If the begin component of the given range is negative and is too-large
* abstract value, the +RangeError+ is raised only +err+ is 1 or 2.
*
* The case of <code>err = 0</code> is used in item accessing methods such as
* +rb_ary_aref+, +rb_ary_slice_bang+, and +rb_str_aref+.
*
* The case of <code>err = 1</code> is used in Array's methods such as
* +rb_ary_aset+ and +rb_ary_fill+.
*
* The case of <code>err = 2</code> is used in +rb_str_aset+.
*/
VALUE
rb_range_beg_len(VALUE range, long *begp, long *lenp, long len, int err)
rb_range_component_beg_len(VALUE b, VALUE e, int excl,
long *begp, long *lenp, long len, int err)
{
long beg, end;
VALUE b, e;
int excl;

if (!rb_range_values(range, &b, &e, &excl))
return Qfalse;
beg = NIL_P(b) ? 0 : NUM2LONG(b);
end = NIL_P(e) ? -1 : NUM2LONG(e);
if (NIL_P(e)) excl = 0;
if (beg < 0) {
beg += len;
if (beg < 0)
goto out_of_range;
beg += len;
if (beg < 0)
goto out_of_range;
}
if (end < 0)
end += len;
end += len;
if (!excl)
end++; /* include end point */
end++; /* include end point */
if (err == 0 || err == 2) {
if (beg > len)
goto out_of_range;
if (end > len)
end = len;
if (beg > len)
goto out_of_range;
if (end > len)
end = len;
}
len = end - beg;
if (len < 0)
len = 0;
len = 0;

*begp = beg;
*lenp = len;
return Qtrue;

out_of_range:
if (err) {
rb_raise(rb_eRangeError, "%+"PRIsVALUE" out of range", range);
}
return Qnil;
}

VALUE
rb_range_beg_len(VALUE range, long *begp, long *lenp, long len, int err)
{
VALUE b, e;
int excl;

if (!rb_range_values(range, &b, &e, &excl))
return Qfalse;

VALUE res = rb_range_component_beg_len(b, e, excl, begp, lenp, len, err);
if (NIL_P(res) && err) {
rb_raise(rb_eRangeError, "%+"PRIsVALUE" out of range", range);
}

return res;
}

/*
* call-seq:
* rng.to_s -> string
Expand Down
38 changes: 37 additions & 1 deletion test/ruby/test_array.rb
Expand Up @@ -1496,9 +1496,46 @@ def test_slice
assert_equal(@cls[10, 11, 12], a.slice(-91..-89))
assert_equal(@cls[10, 11, 12], a.slice(-91..-89))

assert_equal(@cls[5, 8, 11], a.slice((4..12)%3))
assert_equal(@cls[95, 97, 99], a.slice((94..)%2))

# [0] [1] [2] [3] [4] [5] [6] [7]
# ary = [ 1 2 3 4 5 6 7 8 ... ]
# (0) (1) (2) <- (..7) % 3
# (2) (1) (0) <- (7..) % -3
assert_equal(@cls[1, 4, 7], a.slice((..7)%3))
assert_equal(@cls[8, 5, 2], a.slice((7..)% -3))

# [-98] [-97] [-96] [-95] [-94] [-93] [-92] [-91] [-90]
# ary = [ ... 3 4 5 6 7 8 9 10 11 ... ]
# (0) (1) (2) <- (-98..-90) % 3
# (2) (1) (0) <- (-90..-98) % -3
assert_equal(@cls[3, 6, 9], a.slice((-98..-90)%3))
assert_equal(@cls[11, 8, 5], a.slice((-90..-98)% -3))

# [ 48] [ 49] [ 50] [ 51] [ 52] [ 53]
# [-52] [-51] [-50] [-49] [-48] [-47]
# ary = [ ... 49 50 51 52 53 54 ... ]
# (0) (1) (2) <- (48..-47) % 2
# (2) (1) (0) <- (-47..48) % -2
assert_equal(@cls[49, 51, 53], a.slice((48..-47)%2))
assert_equal(@cls[54, 52, 50], a.slice((-47..48)% -2))

idx = ((3..90) % 2).to_a
assert_equal(@cls[*a.values_at(*idx)], a.slice((3..90)%2))
idx = 90.step(3, -2).to_a
assert_equal(@cls[*a.values_at(*idx)], a.slice((90 .. 3)% -2))
end

def test_slice_out_of_range
a = @cls[*(1..100).to_a]

assert_nil(a.slice(-101..-1))
assert_nil(a.slice(-101..))

assert_raise_with_message(RangeError, "((-101..-1).%(2)) out of range") { a.slice((-101..-1)%2) }
assert_raise_with_message(RangeError, "((-101..).%(2)) out of range") { a.slice((-101..)%2) }

assert_nil(a.slice(10, -3))
assert_equal @cls[], a.slice(10..7)
end
Expand Down Expand Up @@ -2414,7 +2451,6 @@ def test_unshift_error

def test_aref
assert_raise(ArgumentError) { [][0, 0, 0] }
assert_raise(TypeError) { [][(1..10).step(2)] }
end

def test_fetch
Expand Down

0 comments on commit a6a8576

Please sign in to comment.