diff --git a/ext/-test-/memory_view/memory_view.c b/ext/-test-/memory_view/memory_view.c index 621614b1111a76..53656b145c61af 100644 --- a/ext/-test-/memory_view/memory_view.c +++ b/ext/-test-/memory_view/memory_view.c @@ -264,9 +264,17 @@ static int mdview_get_memory_view(VALUE obj, rb_memory_view_t *view, int flags) { VALUE buf_v = rb_ivar_get(obj, id_str); + VALUE format_v = rb_ivar_get(obj, SYM2ID(sym_format)); VALUE shape_v = rb_ivar_get(obj, SYM2ID(sym_shape)); VALUE strides_v = rb_ivar_get(obj, SYM2ID(sym_strides)); + const char *format = RSTRING_PTR(format_v); + const char *err; + const ssize_t item_size = rb_memory_view_item_size_from_format(format, &err); + if (item_size < 0) { + return 0; + } + ssize_t i, ndim = RARRAY_LEN(shape_v); ssize_t *shape = ALLOC_N(ssize_t, ndim); ssize_t *strides = NULL; @@ -288,8 +296,8 @@ mdview_get_memory_view(VALUE obj, rb_memory_view_t *view, int flags) } rb_memory_view_init_as_byte_array(view, obj, RSTRING_PTR(buf_v), RSTRING_LEN(buf_v), true); - view->format = "l"; - view->item_size = sizeof(long); + view->format = StringValueCStr(format_v); + view->item_size = item_size; view->ndim = ndim; view->shape = shape; view->strides = strides; @@ -310,13 +318,15 @@ static const rb_memory_view_entry_t mdview_memory_view_entry = { }; static VALUE -mdview_initialize(VALUE obj, VALUE buf, VALUE shape, VALUE strides) +mdview_initialize(VALUE obj, VALUE buf, VALUE format, VALUE shape, VALUE strides) { Check_Type(buf, T_STRING); + StringValue(format); Check_Type(shape, T_ARRAY); if (!NIL_P(strides)) Check_Type(strides, T_ARRAY); rb_ivar_set(obj, id_str, buf); + rb_ivar_set(obj, SYM2ID(sym_format), format); rb_ivar_set(obj, SYM2ID(sym_shape), shape); rb_ivar_set(obj, SYM2ID(sym_strides), strides); return Qnil; @@ -344,11 +354,8 @@ mdview_aref(VALUE obj, VALUE indices_v) indices[i] = NUM2SSIZET(RARRAY_AREF(indices_v, i)); } - char *ptr = rb_memory_view_get_item_pointer(&view, indices); + VALUE result = rb_memory_view_get_item(&view, indices); ALLOCV_END(buf_indices); - - long x = *(long *)ptr; - VALUE result = LONG2FIX(x); rb_memory_view_release(&view); return result; @@ -373,7 +380,7 @@ Init_memory_view(void) rb_memory_view_register(cExportableString, &exportable_string_memory_view_entry); VALUE cMDView = rb_define_class_under(mMemoryViewTestUtils, "MultiDimensionalView", rb_cObject); - rb_define_method(cMDView, "initialize", mdview_initialize, 3); + rb_define_method(cMDView, "initialize", mdview_initialize, 4); rb_define_method(cMDView, "[]", mdview_aref, 1); rb_memory_view_register(cMDView, &mdview_memory_view_entry); diff --git a/include/ruby/memory_view.h b/include/ruby/memory_view.h index 6d6aeffa6edcda..26aef464535ae8 100644 --- a/include/ruby/memory_view.h +++ b/include/ruby/memory_view.h @@ -77,11 +77,13 @@ typedef struct { struct { /* The array of rb_memory_view_item_component_t that describes the - * item structure. */ + * item structure. rb_memory_view_prepare_item_desc and + * rb_memory_view_get_item allocate this memory if needed, + * and rb_memory_view_release frees it. */ rb_memory_view_item_component_t *components; /* The number of components in an item. */ - ssize_t length; + size_t length; } item_desc; /* The number of dimension. */ @@ -132,6 +134,8 @@ ssize_t rb_memory_view_parse_item_format(const char *format, ssize_t rb_memory_view_item_size_from_format(const char *format, const char **err); void *rb_memory_view_get_item_pointer(rb_memory_view_t *view, const ssize_t *indices); VALUE rb_memory_view_extract_item_members(const void *ptr, const rb_memory_view_item_component_t *members, const size_t n_members); +VALUE rb_memory_view_prepare_item_dexc(rb_memory_view_t *view); +VALUE rb_memory_view_get_item(rb_memory_view_t *view, const ssize_t *indices); int rb_memory_view_available_p(VALUE obj); int rb_memory_view_get(VALUE obj, rb_memory_view_t* memory_view, int flags); diff --git a/memory_view.c b/memory_view.c index 7008c6054863fa..5d6ef14b3a22a2 100644 --- a/memory_view.c +++ b/memory_view.c @@ -226,6 +226,8 @@ rb_memory_view_init_as_byte_array(rb_memory_view_t *view, VALUE obj, void *data, view->readonly = readonly; view->format = NULL; view->item_size = 1; + view->item_desc.components = NULL; + view->item_desc.length = 0; view->ndim = 1; view->shape = NULL; view->strides = NULL; @@ -764,6 +766,37 @@ rb_memory_view_extract_item_members(const void *ptr, const rb_memory_view_item_c return item; } +void +rb_memory_view_prepare_item_desc(rb_memory_view_t *view) +{ + if (view->item_desc.components == NULL) { + const char *err; + ssize_t n = rb_memory_view_parse_item_format(view->format, &view->item_desc.components, &view->item_desc.length, &err); + if (n < 0) { + rb_raise(rb_eRuntimeError, + "Unable to parse item format at %"PRIdSIZE" in \"%s\"", + (err - view->format), view->format); + } + } +} + +/* Return a value that consists of item members in the given memory view. */ +VALUE +rb_memory_view_get_item(rb_memory_view_t *view, const ssize_t *indices) +{ + void *ptr = rb_memory_view_get_item_pointer(view, indices); + + if (view->format == NULL) { + return INT2FIX(*(uint8_t *)ptr); + } + + if (view->item_desc.components == NULL) { + rb_memory_view_prepare_item_desc(view); + } + + return rb_memory_view_extract_item_members(ptr, view->item_desc.components, view->item_desc.length); +} + static const rb_memory_view_entry_t * lookup_memory_view_entry(VALUE klass) { @@ -830,6 +863,9 @@ rb_memory_view_release(rb_memory_view_t* view) if (rv) { unregister_exported_object(view->obj); view->obj = Qnil; + if (view->item_desc.components) { + xfree(view->item_desc.components); + } } return rv; } diff --git a/test/ruby/test_memory_view.rb b/test/ruby/test_memory_view.rb index 0150e18c726014..9a6c834bb82271 100644 --- a/test/ruby/test_memory_view.rb +++ b/test/ruby/test_memory_view.rb @@ -297,25 +297,27 @@ def test_rb_memory_view_fill_contiguous_strides column_major_strides) end - def test_rb_memory_view_get_item_pointer + def test_rb_memory_view_get_item_pointer_single_member buf = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 ].pack("l!*") shape = [3, 4] - mv = MemoryViewTestUtils::MultiDimensionalView.new(buf, shape, nil) + mv = MemoryViewTestUtils::MultiDimensionalView.new(buf, "l!", shape, nil) assert_equal(1, mv[[0, 0]]) assert_equal(4, mv[[0, 3]]) assert_equal(6, mv[[1, 1]]) assert_equal(10, mv[[2, 1]]) + end + def test_rb_memory_view_get_item_pointer_multiple_members buf = [ 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16 ].pack("l!*") - shape = [2, 8] - strides = [4*sizeof(:long)*2, sizeof(:long)*2] - mv = MemoryViewTestUtils::MultiDimensionalView.new(buf, shape, strides) - assert_equal(1, mv[[0, 0]]) - assert_equal(5, mv[[0, 2]]) - assert_equal(9, mv[[1, 0]]) - assert_equal(15, mv[[1, 3]]) + -1, -2, -3, -4, -5, -6, -7, -8].pack("s*") + shape = [2, 4] + strides = [4*sizeof(:short)*2, sizeof(:short)*2] + mv = MemoryViewTestUtils::MultiDimensionalView.new(buf, "ss", shape, strides) + assert_equal([1, 2], mv[[0, 0]]) + assert_equal([5, 6], mv[[0, 2]]) + assert_equal([-1, -2], mv[[1, 0]]) + assert_equal([-7, -8], mv[[1, 3]]) end end