From 846dc8145fddfd0c60406f64fa7e0b85d958708d Mon Sep 17 00:00:00 2001 From: Vinicius Stock Date: Wed, 29 Apr 2026 17:40:05 -0400 Subject: [PATCH] Fix UTF-8 code units to match the number of bytes --- lib/prism/parse_result.rb | 26 ++++++++++++++---- rbi/generated/prism/parse_result.rbi | 11 +++++++- sig/generated/prism/parse_result.rbs | 13 +++++++-- test/prism/ruby/location_test.rb | 40 ++++++++++++++-------------- 4 files changed, 62 insertions(+), 28 deletions(-) diff --git a/lib/prism/parse_result.rb b/lib/prism/parse_result.rb index 4f7bcf07d6..7cf6630f44 100644 --- a/lib/prism/parse_result.rb +++ b/lib/prism/parse_result.rb @@ -189,6 +189,8 @@ def character_column(byte_offset) #-- #: (Integer byte_offset, Encoding encoding) -> Integer def code_units_offset(byte_offset, encoding) + return byte_offset if encoding == Encoding::UTF_8 + byteslice = (source.byteslice(0, byte_offset) or raise).encode(encoding, invalid: :replace, undef: :replace) if encoding == Encoding::UTF_16LE || encoding == Encoding::UTF_16BE @@ -250,6 +252,14 @@ def find_line(byte_offset) # :nodoc: # has not yet been implemented. # class CodeUnitsCache + # Counter used for UTF-8, where one code unit equals one byte. + class UTF8Counter # :nodoc: + #: (Integer byte_offset, Integer byte_length) -> Integer + def count(byte_offset, byte_length) + byte_length + end + end + class UTF16Counter # :nodoc: # @rbs @source: String # @rbs @encoding: Encoding @@ -266,7 +276,10 @@ def count(byte_offset, byte_length) end end - class LengthCounter # :nodoc: + # Counter used for UTF-32, where one code unit equals one code point and + # matches String#length. Also used as a best-effort fallback for any other + # encoding that does not have a dedicated counter. + class UTF32Counter # :nodoc: # @rbs @source: String # @rbs @encoding: Encoding @@ -282,10 +295,10 @@ def count(byte_offset, byte_length) end end - private_constant :UTF16Counter, :LengthCounter + private_constant :UTF8Counter, :UTF16Counter, :UTF32Counter # @rbs @source: String - # @rbs @counter: UTF16Counter | LengthCounter + # @rbs @counter: UTF8Counter | UTF16Counter | UTF32Counter # @rbs @cache: Hash[Integer, Integer] # @rbs @offsets: Array[Integer] @@ -295,10 +308,13 @@ def count(byte_offset, byte_length) def initialize(source, encoding) @source = source @counter = - if encoding == Encoding::UTF_16LE || encoding == Encoding::UTF_16BE + case encoding + when Encoding::UTF_8 + UTF8Counter.new + when Encoding::UTF_16LE, Encoding::UTF_16BE UTF16Counter.new(source, encoding) else - LengthCounter.new(source, encoding) + UTF32Counter.new(source, encoding) end @cache = {} #: Hash[Integer, Integer] diff --git a/rbi/generated/prism/parse_result.rbi b/rbi/generated/prism/parse_result.rbi index 4d065b5be1..94a8522c18 100644 --- a/rbi/generated/prism/parse_result.rbi +++ b/rbi/generated/prism/parse_result.rbi @@ -143,6 +143,12 @@ module Prism # introduce some kind of LRU cache to limit the number of entries, but this # has not yet been implemented. class CodeUnitsCache + # Counter used for UTF-8, where one code unit equals one byte. + class UTF8Counter + sig { params(byte_offset: Integer, byte_length: Integer).returns(Integer) } + def count(byte_offset, byte_length); end + end + class UTF16Counter sig { params(source: String, encoding: Encoding).void } def initialize(source, encoding); end @@ -151,7 +157,10 @@ module Prism def count(byte_offset, byte_length); end end - class LengthCounter + # Counter used for UTF-32, where one code unit equals one code point and + # matches String#length. Also used as a best-effort fallback for any other + # encoding that does not have a dedicated counter. + class UTF32Counter sig { params(source: String, encoding: Encoding).void } def initialize(source, encoding); end diff --git a/sig/generated/prism/parse_result.rbs b/sig/generated/prism/parse_result.rbs index f005f17375..7df37cef69 100644 --- a/sig/generated/prism/parse_result.rbs +++ b/sig/generated/prism/parse_result.rbs @@ -169,6 +169,12 @@ module Prism # introduce some kind of LRU cache to limit the number of entries, but this # has not yet been implemented. class CodeUnitsCache + # Counter used for UTF-8, where one code unit equals one byte. + class UTF8Counter + # : (Integer byte_offset, Integer byte_length) -> Integer + def count: (Integer byte_offset, Integer byte_length) -> Integer + end + class UTF16Counter @source: String @@ -181,7 +187,10 @@ module Prism def count: (Integer byte_offset, Integer byte_length) -> Integer end - class LengthCounter + # Counter used for UTF-32, where one code unit equals one code point and + # matches String#length. Also used as a best-effort fallback for any other + # encoding that does not have a dedicated counter. + class UTF32Counter @source: String @encoding: Encoding @@ -195,7 +204,7 @@ module Prism @source: String - @counter: UTF16Counter | LengthCounter + @counter: UTF8Counter | UTF16Counter | UTF32Counter @cache: Hash[Integer, Integer] diff --git a/test/prism/ruby/location_test.rb b/test/prism/ruby/location_test.rb index 5e2ab63802..12c4258cde 100644 --- a/test/prism/ruby/location_test.rb +++ b/test/prism/ruby/location_test.rb @@ -73,7 +73,7 @@ def test_code_units assert_equal 0, location.start_code_units_offset(Encoding::UTF_16LE) assert_equal 0, location.start_code_units_offset(Encoding::UTF_32LE) - assert_equal 1, location.end_code_units_offset(Encoding::UTF_8) + assert_equal 4, location.end_code_units_offset(Encoding::UTF_8) assert_equal 2, location.end_code_units_offset(Encoding::UTF_16LE) assert_equal 1, location.end_code_units_offset(Encoding::UTF_32LE) @@ -81,37 +81,37 @@ def test_code_units assert_equal 0, location.start_code_units_column(Encoding::UTF_16LE) assert_equal 0, location.start_code_units_column(Encoding::UTF_32LE) - assert_equal 1, location.end_code_units_column(Encoding::UTF_8) + assert_equal 4, location.end_code_units_column(Encoding::UTF_8) assert_equal 2, location.end_code_units_column(Encoding::UTF_16LE) assert_equal 1, location.end_code_units_column(Encoding::UTF_32LE) # second 😀 location = program.statements.body.first.arguments.arguments.first.location - assert_equal 4, location.start_code_units_offset(Encoding::UTF_8) + assert_equal 7, location.start_code_units_offset(Encoding::UTF_8) assert_equal 5, location.start_code_units_offset(Encoding::UTF_16LE) assert_equal 4, location.start_code_units_offset(Encoding::UTF_32LE) - assert_equal 5, location.end_code_units_offset(Encoding::UTF_8) + assert_equal 11, location.end_code_units_offset(Encoding::UTF_8) assert_equal 7, location.end_code_units_offset(Encoding::UTF_16LE) assert_equal 5, location.end_code_units_offset(Encoding::UTF_32LE) - assert_equal 4, location.start_code_units_column(Encoding::UTF_8) + assert_equal 7, location.start_code_units_column(Encoding::UTF_8) assert_equal 5, location.start_code_units_column(Encoding::UTF_16LE) assert_equal 4, location.start_code_units_column(Encoding::UTF_32LE) - assert_equal 5, location.end_code_units_column(Encoding::UTF_8) + assert_equal 11, location.end_code_units_column(Encoding::UTF_8) assert_equal 7, location.end_code_units_column(Encoding::UTF_16LE) assert_equal 5, location.end_code_units_column(Encoding::UTF_32LE) # first 😍 location = program.statements.body.last.name_loc - assert_equal 6, location.start_code_units_offset(Encoding::UTF_8) + assert_equal 12, location.start_code_units_offset(Encoding::UTF_8) assert_equal 8, location.start_code_units_offset(Encoding::UTF_16LE) assert_equal 6, location.start_code_units_offset(Encoding::UTF_32LE) - assert_equal 7, location.end_code_units_offset(Encoding::UTF_8) + assert_equal 16, location.end_code_units_offset(Encoding::UTF_8) assert_equal 10, location.end_code_units_offset(Encoding::UTF_16LE) assert_equal 7, location.end_code_units_offset(Encoding::UTF_32LE) @@ -119,26 +119,26 @@ def test_code_units assert_equal 0, location.start_code_units_column(Encoding::UTF_16LE) assert_equal 0, location.start_code_units_column(Encoding::UTF_32LE) - assert_equal 1, location.end_code_units_column(Encoding::UTF_8) + assert_equal 4, location.end_code_units_column(Encoding::UTF_8) assert_equal 2, location.end_code_units_column(Encoding::UTF_16LE) assert_equal 1, location.end_code_units_column(Encoding::UTF_32LE) # second 😍 location = program.statements.body.last.value.location - assert_equal 12, location.start_code_units_offset(Encoding::UTF_8) + assert_equal 21, location.start_code_units_offset(Encoding::UTF_8) assert_equal 15, location.start_code_units_offset(Encoding::UTF_16LE) assert_equal 12, location.start_code_units_offset(Encoding::UTF_32LE) - assert_equal 13, location.end_code_units_offset(Encoding::UTF_8) + assert_equal 25, location.end_code_units_offset(Encoding::UTF_8) assert_equal 17, location.end_code_units_offset(Encoding::UTF_16LE) assert_equal 13, location.end_code_units_offset(Encoding::UTF_32LE) - assert_equal 6, location.start_code_units_column(Encoding::UTF_8) + assert_equal 9, location.start_code_units_column(Encoding::UTF_8) assert_equal 7, location.start_code_units_column(Encoding::UTF_16LE) assert_equal 6, location.start_code_units_column(Encoding::UTF_32LE) - assert_equal 7, location.end_code_units_column(Encoding::UTF_8) + assert_equal 13, location.end_code_units_column(Encoding::UTF_8) assert_equal 9, location.end_code_units_column(Encoding::UTF_16LE) assert_equal 7, location.end_code_units_column(Encoding::UTF_32LE) end @@ -157,7 +157,7 @@ def test_cached_code_units assert_equal 0, location.cached_start_code_units_offset(utf16_cache) assert_equal 0, location.cached_start_code_units_offset(utf32_cache) - assert_equal 1, location.cached_end_code_units_offset(utf8_cache) + assert_equal 4, location.cached_end_code_units_offset(utf8_cache) assert_equal 2, location.cached_end_code_units_offset(utf16_cache) assert_equal 1, location.cached_end_code_units_offset(utf32_cache) @@ -165,26 +165,26 @@ def test_cached_code_units assert_equal 0, location.cached_start_code_units_column(utf16_cache) assert_equal 0, location.cached_start_code_units_column(utf32_cache) - assert_equal 1, location.cached_end_code_units_column(utf8_cache) + assert_equal 4, location.cached_end_code_units_column(utf8_cache) assert_equal 2, location.cached_end_code_units_column(utf16_cache) assert_equal 1, location.cached_end_code_units_column(utf32_cache) # second 😀 location = result.value.statements.body.first.arguments.arguments.first.location - assert_equal 4, location.cached_start_code_units_offset(utf8_cache) + assert_equal 7, location.cached_start_code_units_offset(utf8_cache) assert_equal 5, location.cached_start_code_units_offset(utf16_cache) assert_equal 4, location.cached_start_code_units_offset(utf32_cache) - assert_equal 5, location.cached_end_code_units_offset(utf8_cache) + assert_equal 11, location.cached_end_code_units_offset(utf8_cache) assert_equal 7, location.cached_end_code_units_offset(utf16_cache) assert_equal 5, location.cached_end_code_units_offset(utf32_cache) - assert_equal 4, location.cached_start_code_units_column(utf8_cache) + assert_equal 7, location.cached_start_code_units_column(utf8_cache) assert_equal 5, location.cached_start_code_units_column(utf16_cache) assert_equal 4, location.cached_start_code_units_column(utf32_cache) - assert_equal 5, location.cached_end_code_units_column(utf8_cache) + assert_equal 11, location.cached_end_code_units_column(utf8_cache) assert_equal 7, location.cached_end_code_units_column(utf16_cache) assert_equal 5, location.cached_end_code_units_column(utf32_cache) end @@ -200,7 +200,7 @@ def test_code_units_binary_valid_utf8 assert_equal "😀".b.to_sym, receiver.name location = receiver.location - assert_equal 1, location.end_code_units_column(Encoding::UTF_8) + assert_equal 4, location.end_code_units_column(Encoding::UTF_8) assert_equal 2, location.end_code_units_column(Encoding::UTF_16LE) assert_equal 1, location.end_code_units_column(Encoding::UTF_32LE) end