Skip to content

Commit

Permalink
[ruby/prism] Implement case equality on nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
exterm authored and matzbot committed Apr 23, 2024
1 parent 87b829a commit f7d1699
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
23 changes: 19 additions & 4 deletions prism/templates/lib/prism/node.rb.erb
Expand Up @@ -219,10 +219,10 @@ module Prism
def deconstruct_keys(keys)
{ <%= (node.fields.map { |field| "#{field.name}: #{field.name}" } + ["location: location"]).join(", ") %> }
end

<%- node.fields.each do |field| -%>
<%- if field.comment.nil? -%>
# <%= "private " if field.is_a?(Prism::Template::FlagsField) %>attr_reader <%= field.name %>: <%= field.rbs_class %>
# <%= "protected " if field.is_a?(Prism::Template::FlagsField) %>attr_reader <%= field.name %>: <%= field.rbs_class %>
<%- else -%>
<%- field.each_comment_line do |line| -%>
#<%= line %>
Expand All @@ -248,9 +248,8 @@ module Prism
end
end
<%- else -%>
attr_reader :<%= field.name -%><%= "\n private :#{field.name}" if field.is_a?(Prism::Template::FlagsField) %>
attr_reader :<%= field.name -%><%= "\n protected :#{field.name}" if field.is_a?(Prism::Template::FlagsField) %>
<%- end -%>
<%- end -%>
<%- node.fields.each do |field| -%>
<%- case field -%>
Expand Down Expand Up @@ -349,6 +348,22 @@ module Prism
def self.type
:<%= node.human %>
end

# Implements case-equality for the node. This is effectively == but without
# comparing the value of locations. Locations are checked only for presence.
def ===(other)
other.is_a?(<%= node.name %>)<%= " &&" if node.fields.any? %>
<%- node.fields.each_with_index do |field, index| -%>
<%- if field.is_a?(Prism::Template::LocationField) || field.is_a?(Prism::Template::OptionalLocationField) -%>
(<%= field.name %>.nil? == other.<%= field.name %>.nil?)<%= " &&" if index != node.fields.length - 1 %>
<%- elsif field.is_a?(Prism::Template::NodeListField) || field.is_a?(Prism::Template::ConstantListField) -%>
(<%= field.name %>.length == other.<%= field.name %>.length) &&
<%= field.name %>.zip(other.<%= field.name %>).all? { |left, right| left === right }<%= " &&" if index != node.fields.length - 1 %>
<%- else -%>
(<%= field.name %> === other.<%= field.name %>)<%= " &&" if index != node.fields.length - 1 %>
<%- end -%>
<%- end -%>
end
end
<%- end -%>
<%- flags.each_with_index do |flag, flag_index| -%>
Expand Down
15 changes: 15 additions & 0 deletions test/prism/ruby_api_test.rb
Expand Up @@ -244,6 +244,21 @@ def test_integer_base_flags
assert_equal 16, base[parse_expression("0x1")]
end

def test_node_equality
assert_operator parse_expression("1"), :===, parse_expression("1")
assert_operator Prism.parse("1").value, :===, Prism.parse("1").value

complex_source = "class Something; @var = something.else { _1 }; end"
assert_operator parse_expression(complex_source), :===, parse_expression(complex_source)

refute_operator parse_expression("1"), :===, parse_expression("2")
refute_operator parse_expression("1"), :===, parse_expression("0x1")

complex_source_1 = "class Something; @var = something.else { _1 }; end"
complex_source_2 = "class Something; @var = something.else { _2 }; end"
refute_operator parse_expression(complex_source_1), :===, parse_expression(complex_source_2)
end

private

def parse_expression(source)
Expand Down

0 comments on commit f7d1699

Please sign in to comment.