Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug #17736] Ensure the receiver modifiable before updating #4299

Merged
merged 2 commits into from
Mar 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 42 additions & 40 deletions hash.c
Original file line number Diff line number Diff line change
Expand Up @@ -1665,15 +1665,10 @@ func##_insert(st_data_t *key, st_data_t *val, st_data_t arg, int existing) \

struct update_arg {
st_data_t arg;
st_update_callback_func *func;
VALUE hash;
VALUE new_key;
VALUE old_key;
VALUE new_value;
VALUE old_value;
};

static int hash_update_replace(st_data_t *key, st_data_t *value, struct update_arg *arg, int existing, st_data_t newvalue);

typedef int (*tbl_update_func)(st_data_t *, st_data_t *, st_data_t, int);

int
Expand All @@ -1693,25 +1688,42 @@ rb_hash_stlike_update(VALUE hash, st_data_t key, st_update_callback_func *func,
}

static int
tbl_update(VALUE hash, VALUE key, tbl_update_func func, st_data_t optional_arg)
{
struct update_arg arg;
int result;

arg.arg = optional_arg;
arg.hash = hash;
arg.new_key = 0;
arg.old_key = Qundef;
arg.new_value = 0;
arg.old_value = Qundef;
tbl_update_modify(st_data_t *key, st_data_t *val, st_data_t arg, int existing)
{
struct update_arg *p = (struct update_arg *)arg;
st_data_t old_key = *key;
st_data_t old_value = *val;
VALUE hash = p->hash;
int ret = (p->func)(key, val, arg, existing);
switch (ret) {
default:
break;
case ST_CONTINUE:
if (!existing || *key != old_key || *val != old_value)
rb_hash_modify(hash);
/* write barrier */
RB_OBJ_WRITTEN(hash, Qundef, *key);
RB_OBJ_WRITTEN(hash, Qundef, *val);
break;
case ST_DELETE:
if (existing)
rb_hash_modify(hash);
break;
}

result = rb_hash_stlike_update(hash, key, func, (st_data_t)&arg);
return ret;
}

/* write barrier */
if (arg.new_key) RB_OBJ_WRITTEN(hash, arg.old_key, arg.new_key);
if (arg.new_value) RB_OBJ_WRITTEN(hash, arg.old_value, arg.new_value);
static int
tbl_update(VALUE hash, VALUE key, tbl_update_func func, st_data_t optional_arg)
{
struct update_arg arg = {
.arg = optional_arg,
.func = func,
.hash = hash,
};

return result;
return rb_hash_stlike_update(hash, key, tbl_update_modify, (st_data_t)&arg);
}

#define UPDATE_CALLBACK(iter_lev, func) ((iter_lev) > 0 ? func##_noinsert : func##_insert)
Expand Down Expand Up @@ -2839,7 +2851,8 @@ rb_hash_clear(VALUE hash)
static int
hash_aset(st_data_t *key, st_data_t *val, struct update_arg *arg, int existing)
{
return hash_update_replace(key, val, arg, existing, arg->arg);
*val = arg->arg;
return ST_CONTINUE;
}

VALUE
Expand Down Expand Up @@ -3862,24 +3875,11 @@ rb_hash_invert(VALUE hash)
return h;
}

static int
hash_update_replace(st_data_t *key, st_data_t *value, struct update_arg *arg, int existing, st_data_t newvalue)
{
if (existing) {
arg->old_value = *value;
}
else {
arg->new_key = *key;
}
arg->new_value = newvalue;
*value = newvalue;
return ST_CONTINUE;
}

static int
rb_hash_update_callback(st_data_t *key, st_data_t *value, struct update_arg *arg, int existing)
{
return hash_update_replace(key, value, arg, existing, arg->arg);
*value = arg->arg;
return ST_CONTINUE;
}

NOINSERT_UPDATE_CALLBACK(rb_hash_update_callback)
Expand All @@ -3899,7 +3899,8 @@ rb_hash_update_block_callback(st_data_t *key, st_data_t *value, struct update_ar
if (existing) {
newvalue = (st_data_t)rb_yield_values(3, (VALUE)*key, (VALUE)*value, (VALUE)newvalue);
}
return hash_update_replace(key, value, arg, existing, newvalue);
*value = newvalue;
return ST_CONTINUE;
}

NOINSERT_UPDATE_CALLBACK(rb_hash_update_block_callback)
Expand Down Expand Up @@ -3995,7 +3996,8 @@ rb_hash_update_func_callback(st_data_t *key, st_data_t *value, struct update_arg
if (existing) {
newvalue = (*uf_arg->func)((VALUE)*key, (VALUE)*value, newvalue);
}
return hash_update_replace(key, value, arg, existing, (st_data_t)newvalue);
*value = newvalue;
return ST_CONTINUE;
}

NOINSERT_UPDATE_CALLBACK(rb_hash_update_func_callback)
Expand Down
28 changes: 28 additions & 0 deletions test/ruby/test_hash.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1232,6 +1232,20 @@ def test_update4
assert_equal({1=>8, 2=>4, 3=>4, 5=>7}, h1)
end

def test_update5
h = @cls[a: 1, b: 2, c: 3]
assert_raise(FrozenError) do
h.update({a: 10, b: 20}){ |key, v1, v2| key == :b && h.freeze; v2 }
end
assert_equal(@cls[a: 10, b: 2, c: 3], h)

h = @cls[a: 1, b: 2, c: 3, d: 4, e: 5, f: 6, g: 7, h: 8, i: 9, j: 10]
assert_raise(FrozenError) do
h.update({a: 10, b: 20}){ |key, v1, v2| key == :b && h.freeze; v2 }
end
assert_equal(@cls[a: 10, b: 2, c: 3, d: 4, e: 5, f: 6, g: 7, h: 8, i: 9, j: 10], h)
end

def test_merge
h1 = @cls[1=>2, 3=>4]
h2 = {1=>3, 5=>7}
Expand All @@ -1243,6 +1257,20 @@ def test_merge
assert_equal({1=>8, 2=>4, 3=>4, 5=>7}, h1.merge(h2, h3) {|k, v1, v2| k + v1 + v2 })
end

def test_merge!
h = @cls[a: 1, b: 2, c: 3]
assert_raise(FrozenError) do
h.merge!({a: 10, b: 20}){ |key, v1, v2| key == :b && h.freeze; v2 }
end
assert_equal(@cls[a: 10, b: 2, c: 3], h)

h = @cls[a: 1, b: 2, c: 3, d: 4, e: 5, f: 6, g: 7, h: 8, i: 9, j: 10]
assert_raise(FrozenError) do
h.merge!({a: 10, b: 20}){ |key, v1, v2| key == :b && h.freeze; v2 }
end
assert_equal(@cls[a: 10, b: 2, c: 3, d: 4, e: 5, f: 6, g: 7, h: 8, i: 9, j: 10], h)
end

def test_assoc
assert_equal([3,4], @cls[1=>2, 3=>4, 5=>6].assoc(3))
assert_nil(@cls[1=>2, 3=>4, 5=>6].assoc(4))
Expand Down