Skip to content

Commit

Permalink
[ruby/prism] Change numbered parameters
Browse files Browse the repository at this point in the history
Previously numbered parameters were a field on blocks and lambdas
that indicated the maximum number of numbered parameters in either
the block or lambda, respectively. However they also had a
parameters field that would always be nil in these cases.

This changes it so that we introduce a NumberedParametersNode that
goes in place of parameters, which has a single uint8_t maximum
field on it. That field contains the maximum numbered parameter in
either the block or lambda.

As a part of the PR, I'm introducing a new UInt8Field type that
can be used on nodes, which is just to make it a little more
explicit what the maximum values can be (the maximum is actually 9,
since it only goes up to _9). Plus we can do a couple of nice
things in serialization like just read a single byte.

ruby/prism@2d87303903
  • Loading branch information
kddnewton committed Dec 1, 2023
1 parent 90d9c20 commit cdb74d7
Show file tree
Hide file tree
Showing 205 changed files with 983 additions and 1,361 deletions.
11 changes: 8 additions & 3 deletions lib/prism/debug.rb
Expand Up @@ -103,9 +103,14 @@ def self.prism_locals(source)
case node
when BlockNode, DefNode, LambdaNode
names = node.locals

params = node.parameters
params = params&.parameters unless node.is_a?(DefNode)
params =
if node.is_a?(DefNode)
node.parameters
elsif node.parameters.is_a?(NumberedParametersNode)
nil
else
node.parameters&.parameters
end

# prism places parameters in the same order that they appear in the
# source. CRuby places them in the order that they need to appear
Expand Down
16 changes: 10 additions & 6 deletions prism/config.yml
Expand Up @@ -589,15 +589,12 @@ nodes:
type: constant[]
- name: parameters
type: node?
kind: BlockParametersNode
- name: body
type: node?
- name: opening_loc
type: location
- name: closing_loc
type: location
- name: numbered_parameters
type: uint32
comment: |
Represents a block of ruby code.
Expand Down Expand Up @@ -1772,11 +1769,8 @@ nodes:
type: location
- name: parameters
type: node?
kind: BlockParametersNode
- name: body
type: node?
- name: numbered_parameters
type: uint32
comment: |
Represents using a lambda literal (not the lambda method call).
Expand Down Expand Up @@ -2026,6 +2020,16 @@ nodes:
def a(**nil)
^^^^^
end
- name: NumberedParametersNode
fields:
- name: maximum
type: uint8
comment: |
Represents an implicit set of parameters through the use of numbered
parameters within a block or lambda.
-> { _1 + _2 }
^^^^^^^^^^^^^^
- name: NumberedReferenceReadNode
fields:
- name: number
Expand Down
2 changes: 1 addition & 1 deletion prism/parser.h
Expand Up @@ -474,7 +474,7 @@ typedef struct pm_scope {
* numbered parameters, and to pass information to consumers of the AST
* about how many numbered parameters exist.
*/
uint32_t numbered_parameters;
uint8_t numbered_parameters;
} pm_scope_t;

/**
Expand Down
81 changes: 54 additions & 27 deletions prism/prism.c
Expand Up @@ -1467,7 +1467,7 @@ pm_block_argument_node_create(pm_parser_t *parser, const pm_token_t *operator, p
* Allocate and initialize a new BlockNode node.
*/
static pm_block_node_t *
pm_block_node_create(pm_parser_t *parser, pm_constant_id_list_t *locals, const pm_token_t *opening, pm_block_parameters_node_t *parameters, pm_node_t *body, const pm_token_t *closing, uint32_t numbered_parameters) {
pm_block_node_create(pm_parser_t *parser, pm_constant_id_list_t *locals, const pm_token_t *opening, pm_node_t *parameters, pm_node_t *body, const pm_token_t *closing) {
pm_block_node_t *node = PM_ALLOC_NODE(parser, pm_block_node_t);

*node = (pm_block_node_t) {
Expand All @@ -1478,7 +1478,6 @@ pm_block_node_create(pm_parser_t *parser, pm_constant_id_list_t *locals, const p
.locals = *locals,
.parameters = parameters,
.body = body,
.numbered_parameters = numbered_parameters,
.opening_loc = PM_LOCATION_TOKEN_VALUE(opening),
.closing_loc = PM_LOCATION_TOKEN_VALUE(closing)
};
Expand Down Expand Up @@ -3958,9 +3957,8 @@ pm_lambda_node_create(
const pm_token_t *operator,
const pm_token_t *opening,
const pm_token_t *closing,
pm_block_parameters_node_t *parameters,
pm_node_t *body,
uint32_t numbered_parameters
pm_node_t *parameters,
pm_node_t *body
) {
pm_lambda_node_t *node = PM_ALLOC_NODE(parser, pm_lambda_node_t);

Expand All @@ -3977,8 +3975,7 @@ pm_lambda_node_create(
.opening_loc = PM_LOCATION_TOKEN_VALUE(opening),
.closing_loc = PM_LOCATION_TOKEN_VALUE(closing),
.parameters = parameters,
.body = body,
.numbered_parameters = numbered_parameters
.body = body
};

return node;
Expand Down Expand Up @@ -4442,7 +4439,25 @@ pm_no_keywords_parameter_node_create(pm_parser_t *parser, const pm_token_t *oper
}

/**
* Allocate a new NthReferenceReadNode node.
* Allocate and initialize a new NumberedParametersNode node.
*/
static pm_numbered_parameters_node_t *
pm_numbered_parameters_node_create(pm_parser_t *parser, const pm_location_t *location, uint8_t maximum) {
pm_numbered_parameters_node_t *node = PM_ALLOC_NODE(parser, pm_numbered_parameters_node_t);

*node = (pm_numbered_parameters_node_t) {
{
.type = PM_NUMBERED_PARAMETERS_NODE,
.location = *location
},
.maximum = maximum
};

return node;
}

/**
* Allocate and initialize a new NthReferenceReadNode node.
*/
static pm_numbered_reference_read_node_t *
pm_numbered_reference_read_node_create(pm_parser_t *parser, const pm_token_t *name) {
Expand Down Expand Up @@ -5822,7 +5837,7 @@ pm_parser_local_add(pm_parser_t *parser, pm_constant_id_t constant_id) {
* Set the numbered_parameters value of the current scope.
*/
static inline void
pm_parser_numbered_parameters_set(pm_parser_t *parser, uint32_t numbered_parameters) {
pm_parser_numbered_parameters_set(pm_parser_t *parser, uint8_t numbered_parameters) {
parser->current_scope->numbered_parameters = numbered_parameters;
}

Expand Down Expand Up @@ -11845,24 +11860,24 @@ parse_block(pm_parser_t *parser) {

pm_accepts_block_stack_push(parser, true);
pm_parser_scope_push(parser, false);
pm_block_parameters_node_t *parameters = NULL;
pm_block_parameters_node_t *block_parameters = NULL;

if (accept1(parser, PM_TOKEN_PIPE)) {
parser->current_scope->explicit_params = true;
pm_token_t block_parameters_opening = parser->previous;

if (match1(parser, PM_TOKEN_PIPE)) {
parameters = pm_block_parameters_node_create(parser, NULL, &block_parameters_opening);
block_parameters = pm_block_parameters_node_create(parser, NULL, &block_parameters_opening);
parser->command_start = true;
parser_lex(parser);
} else {
parameters = parse_block_parameters(parser, true, &block_parameters_opening, false);
block_parameters = parse_block_parameters(parser, true, &block_parameters_opening, false);
accept1(parser, PM_TOKEN_NEWLINE);
parser->command_start = true;
expect1(parser, PM_TOKEN_PIPE, PM_ERR_BLOCK_PARAM_PIPE_TERM);
}

pm_block_parameters_node_closing_set(parameters, &parser->previous);
pm_block_parameters_node_closing_set(block_parameters, &parser->previous);
}

accept1(parser, PM_TOKEN_NEWLINE);
Expand Down Expand Up @@ -11891,11 +11906,17 @@ parse_block(pm_parser_t *parser) {
expect1(parser, PM_TOKEN_KEYWORD_END, PM_ERR_BLOCK_TERM_END);
}

pm_node_t *parameters = (pm_node_t *) block_parameters;
uint8_t maximum = parser->current_scope->numbered_parameters;

if (parameters == NULL && (maximum > 0)) {
parameters = (pm_node_t *) pm_numbered_parameters_node_create(parser, &(pm_location_t) { .start = opening.start, .end = parser->previous.end }, maximum);
}

pm_constant_id_list_t locals = parser->current_scope->locals;
uint32_t numbered_parameters = parser->current_scope->numbered_parameters;
pm_parser_scope_pop(parser);
pm_accepts_block_stack_pop(parser);
return pm_block_node_create(parser, &locals, &opening, parameters, statements, &parser->previous, numbered_parameters);
return pm_block_node_create(parser, &locals, &opening, parameters, statements, &parser->previous);
}

/**
Expand Down Expand Up @@ -12511,10 +12532,10 @@ parse_variable_call(pm_parser_t *parser) {

// We subtract the value for the character '0' to get the actual
// integer value of the number (only _1 through _9 are valid)
uint32_t number_as_int = (uint32_t) (number - '0');
if (number_as_int > parser->current_scope->numbered_parameters) {
parser->current_scope->numbered_parameters = number_as_int;
pm_parser_numbered_parameters_set(parser, number_as_int);
uint8_t numbered_parameters = (uint8_t) (number - '0');
if (numbered_parameters > parser->current_scope->numbered_parameters) {
parser->current_scope->numbered_parameters = numbered_parameters;
pm_parser_numbered_parameters_set(parser, numbered_parameters);
}

// When you use a numbered parameter, it implies the existence
Expand Down Expand Up @@ -15707,7 +15728,7 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power) {

pm_token_t operator = parser->previous;
pm_parser_scope_push(parser, false);
pm_block_parameters_node_t *params;
pm_block_parameters_node_t *block_parameters;

switch (parser->current.type) {
case PM_TOKEN_PARENTHESIS_LEFT: {
Expand All @@ -15716,27 +15737,27 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power) {
parser_lex(parser);

if (match1(parser, PM_TOKEN_PARENTHESIS_RIGHT)) {
params = pm_block_parameters_node_create(parser, NULL, &opening);
block_parameters = pm_block_parameters_node_create(parser, NULL, &opening);
} else {
params = parse_block_parameters(parser, false, &opening, true);
block_parameters = parse_block_parameters(parser, false, &opening, true);
}

accept1(parser, PM_TOKEN_NEWLINE);
expect1(parser, PM_TOKEN_PARENTHESIS_RIGHT, PM_ERR_EXPECT_RPAREN);

pm_block_parameters_node_closing_set(params, &parser->previous);
pm_block_parameters_node_closing_set(block_parameters, &parser->previous);
break;
}
case PM_CASE_PARAMETER: {
parser->current_scope->explicit_params = true;
pm_accepts_block_stack_push(parser, false);
pm_token_t opening = not_provided(parser);
params = parse_block_parameters(parser, false, &opening, true);
block_parameters = parse_block_parameters(parser, false, &opening, true);
pm_accepts_block_stack_pop(parser);
break;
}
default: {
params = NULL;
block_parameters = NULL;
break;
}
}
Expand Down Expand Up @@ -15770,11 +15791,17 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power) {
expect1(parser, PM_TOKEN_KEYWORD_END, PM_ERR_LAMBDA_TERM_END);
}

pm_node_t *parameters = (pm_node_t *) block_parameters;
uint8_t maximum = parser->current_scope->numbered_parameters;

if (parameters == NULL && (maximum > 0)) {
parameters = (pm_node_t *) pm_numbered_parameters_node_create(parser, &(pm_location_t) { .start = operator.start, .end = parser->previous.end }, maximum);
}

pm_constant_id_list_t locals = parser->current_scope->locals;
uint32_t numbered_parameters = parser->current_scope->numbered_parameters;
pm_parser_scope_pop(parser);
pm_accepts_block_stack_pop(parser);
return (pm_node_t *) pm_lambda_node_create(parser, &locals, &operator, &opening, &parser->previous, params, body, numbered_parameters);
return (pm_node_t *) pm_lambda_node_create(parser, &locals, &operator, &opening, &parser->previous, parameters, body);
}
case PM_TOKEN_UPLUS: {
parser_lex(parser);
Expand Down
3 changes: 3 additions & 0 deletions prism/templates/ext/prism/api_node.c.erb
Expand Up @@ -181,6 +181,9 @@ pm_ast_new(pm_parser_t *parser, pm_node_t *node, rb_encoding *encoding) {
<%- when Prism::OptionalLocationField -%>
#line <%= __LINE__ + 1 %> "<%= File.basename(__FILE__) %>"
argv[<%= index %>] = cast-><%= field.name %>.start == NULL ? Qnil : pm_location_new(parser, cast-><%= field.name %>.start, cast-><%= field.name %>.end, source);
<%- when Prism::UInt8Field -%>
#line <%= __LINE__ + 1 %> "<%= File.basename(__FILE__) %>"
argv[<%= index %>] = UINT2NUM(cast-><%= field.name %>);
<%- when Prism::UInt32Field -%>
#line <%= __LINE__ + 1 %> "<%= File.basename(__FILE__) %>"
argv[<%= index %>] = ULONG2NUM(cast-><%= field.name %>);
Expand Down
1 change: 1 addition & 0 deletions prism/templates/include/prism/ast.h.erb
Expand Up @@ -168,6 +168,7 @@ typedef struct pm_<%= node.human %> {
when Prism::ConstantListField then "pm_constant_id_list_t #{field.name}"
when Prism::StringField then "pm_string_t #{field.name}"
when Prism::LocationField, Prism::OptionalLocationField then "pm_location_t #{field.name}"
when Prism::UInt8Field then "uint8_t #{field.name}"
when Prism::UInt32Field then "uint32_t #{field.name}"
else raise field.class.name
end
Expand Down
2 changes: 1 addition & 1 deletion prism/templates/lib/prism/dot_visitor.rb.erb
Expand Up @@ -129,7 +129,7 @@ module Prism

digraph.edge("#{id}:<%= field.name %> -> #{waypoint};")
node.<%= field.name %>.each { |child| digraph.edge("#{waypoint} -> #{node_id(child)};") }
<%- when Prism::StringField, Prism::ConstantField, Prism::OptionalConstantField, Prism::UInt32Field, Prism::ConstantListField -%>
<%- when Prism::StringField, Prism::ConstantField, Prism::OptionalConstantField, Prism::UInt8Field, Prism::UInt32Field, Prism::ConstantListField -%>
table.field("<%= field.name %>", node.<%= field.name %>.inspect)
<%- when Prism::LocationField -%>
table.field("<%= field.name %>", location_inspect(node.<%= field.name %>))
Expand Down
2 changes: 1 addition & 1 deletion prism/templates/lib/prism/node.rb.erb
Expand Up @@ -190,7 +190,7 @@ module Prism
inspector << "<%= pointer %><%= field.name %>:\n"
inspector << <%= field.name %>.inspect(inspector.child_inspector("<%= preadd %>")).delete_prefix(inspector.prefix)
end
<%- when Prism::ConstantField, Prism::StringField, Prism::UInt32Field -%>
<%- when Prism::ConstantField, Prism::StringField, Prism::UInt8Field, Prism::UInt32Field -%>
inspector << "<%= pointer %><%= field.name %>: #{<%= field.name %>.inspect}\n"
<%- when Prism::OptionalConstantField -%>
if (<%= field.name %> = self.<%= field.name %>).nil?
Expand Down
2 changes: 2 additions & 0 deletions prism/templates/lib/prism/serialize.rb.erb
Expand Up @@ -253,6 +253,7 @@ module Prism
when Prism::ConstantListField then "Array.new(load_varuint) { load_required_constant }"
when Prism::LocationField then "load_location"
when Prism::OptionalLocationField then "load_optional_location"
when Prism::UInt8Field then "io.getbyte"
when Prism::UInt32Field, Prism::FlagsField then "load_varuint"
else raise
end
Expand Down Expand Up @@ -286,6 +287,7 @@ module Prism
when Prism::ConstantListField then "Array.new(load_varuint) { load_required_constant }"
when Prism::LocationField then "load_location"
when Prism::OptionalLocationField then "load_optional_location"
when Prism::UInt8Field then "io.getbyte"
when Prism::UInt32Field, Prism::FlagsField then "load_varuint"
else raise
end
Expand Down
6 changes: 3 additions & 3 deletions prism/templates/src/node.c.erb
Expand Up @@ -56,12 +56,12 @@ pm_node_destroy(pm_parser_t *parser, pm_node_t *node) {
<%- nodes.each do |node| -%>
#line <%= __LINE__ + 1 %> "<%= File.basename(__FILE__) %>"
case <%= node.type %>: {
<%- if node.fields.any? { |field| ![Prism::LocationField, Prism::OptionalLocationField, Prism::UInt32Field, Prism::FlagsField, Prism::ConstantField, Prism::OptionalConstantField].include?(field.class) } -%>
<%- if node.fields.any? { |field| ![Prism::LocationField, Prism::OptionalLocationField, Prism::UInt8Field, Prism::UInt32Field, Prism::FlagsField, Prism::ConstantField, Prism::OptionalConstantField].include?(field.class) } -%>
pm_<%= node.human %>_t *cast = (pm_<%= node.human %>_t *) node;
<%- end -%>
<%- node.fields.each do |field| -%>
<%- case field -%>
<%- when Prism::LocationField, Prism::OptionalLocationField, Prism::UInt32Field, Prism::FlagsField, Prism::ConstantField, Prism::OptionalConstantField -%>
<%- when Prism::LocationField, Prism::OptionalLocationField, Prism::UInt8Field, Prism::UInt32Field, Prism::FlagsField, Prism::ConstantField, Prism::OptionalConstantField -%>
<%- when Prism::NodeField -%>
pm_node_destroy(parser, (pm_node_t *)cast-><%= field.name %>);
<%- when Prism::OptionalNodeField -%>
Expand Down Expand Up @@ -113,7 +113,7 @@ pm_node_memsize_node(pm_node_t *node, pm_memsize_t *memsize) {
<%- end -%>
<%- node.fields.each do |field| -%>
<%- case field -%>
<%- when Prism::ConstantField, Prism::OptionalConstantField, Prism::UInt32Field, Prism::FlagsField, Prism::LocationField, Prism::OptionalLocationField -%>
<%- when Prism::ConstantField, Prism::OptionalConstantField, Prism::UInt8Field, Prism::UInt32Field, Prism::FlagsField, Prism::LocationField, Prism::OptionalLocationField -%>
<%- when Prism::NodeField -%>
pm_node_memsize_node((pm_node_t *)cast-><%= field.name %>, memsize);
<%- when Prism::OptionalNodeField -%>
Expand Down
2 changes: 1 addition & 1 deletion prism/templates/src/prettyprint.c.erb
Expand Up @@ -152,7 +152,7 @@ prettyprint_node(pm_buffer_t *output_buffer, const pm_parser_t *parser, const pm
prettyprint_source(output_buffer, location->start, (size_t) (location->end - location->start));
pm_buffer_append_string(output_buffer, "\"\n", 2);
}
<%- when Prism::UInt32Field -%>
<%- when Prism::UInt8Field, Prism::UInt32Field -%>
pm_buffer_append_format(output_buffer, " %d\n", cast-><%= field.name %>);
<%- when Prism::FlagsField -%>
bool found = false;
Expand Down
2 changes: 2 additions & 0 deletions prism/templates/src/serialize.c.erb
Expand Up @@ -107,6 +107,8 @@ pm_serialize_node(pm_parser_t *parser, pm_node_t *node, pm_buffer_t *buffer) {
pm_serialize_location(parser, &((pm_<%= node.human %>_t *)node)-><%= field.name %>, buffer);
}
<%- end -%>
<%- when Prism::UInt8Field -%>
pm_buffer_append_byte(buffer, ((pm_<%= node.human %>_t *)node)-><%= field.name %>);
<%- when Prism::UInt32Field -%>
pm_buffer_append_varuint(buffer, ((pm_<%= node.human %>_t *)node)-><%= field.name %>);
<%- when Prism::FlagsField -%>
Expand Down
16 changes: 16 additions & 0 deletions prism/templates/template.rb
Expand Up @@ -197,6 +197,21 @@ def java_type
end
end

# This represents an integer field.
class UInt8Field < Field
def rbs_class
"Integer"
end

def rbi_class
"Integer"
end

def java_type
"int"
end
end

# This represents an integer field.
class UInt32Field < Field
def rbs_class
Expand Down Expand Up @@ -282,6 +297,7 @@ def field_type_for(name)
when "constant[]" then ConstantListField
when "location" then LocationField
when "location?" then OptionalLocationField
when "uint8" then UInt8Field
when "uint32" then UInt32Field
when "flags" then FlagsField
else raise("Unknown field type: #{name.inspect}")
Expand Down

0 comments on commit cdb74d7

Please sign in to comment.