Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions lib/prism/parse_result.rb
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ def character_column(byte_offset)
#--
#: (Integer byte_offset, Encoding encoding) -> Integer
def code_units_offset(byte_offset, encoding)
return byte_offset if encoding == Encoding::UTF_8

byteslice = (source.byteslice(0, byte_offset) or raise).encode(encoding, invalid: :replace, undef: :replace)

if encoding == Encoding::UTF_16LE || encoding == Encoding::UTF_16BE
Expand Down Expand Up @@ -250,6 +252,14 @@ def find_line(byte_offset) # :nodoc:
# has not yet been implemented.
#
class CodeUnitsCache
# Counter used for UTF-8, where one code unit equals one byte.
class UTF8Counter # :nodoc:
#: (Integer byte_offset, Integer byte_length) -> Integer
def count(byte_offset, byte_length)
byte_length
end
end

class UTF16Counter # :nodoc:
# @rbs @source: String
# @rbs @encoding: Encoding
Expand All @@ -266,7 +276,10 @@ def count(byte_offset, byte_length)
end
end

class LengthCounter # :nodoc:
# Counter used for UTF-32, where one code unit equals one code point and
# matches String#length. Also used as a best-effort fallback for any other
# encoding that does not have a dedicated counter.
class UTF32Counter # :nodoc:
# @rbs @source: String
# @rbs @encoding: Encoding

Expand All @@ -282,10 +295,10 @@ def count(byte_offset, byte_length)
end
end

private_constant :UTF16Counter, :LengthCounter
private_constant :UTF8Counter, :UTF16Counter, :UTF32Counter

# @rbs @source: String
# @rbs @counter: UTF16Counter | LengthCounter
# @rbs @counter: UTF8Counter | UTF16Counter | UTF32Counter
# @rbs @cache: Hash[Integer, Integer]
# @rbs @offsets: Array[Integer]

Expand All @@ -295,10 +308,13 @@ def count(byte_offset, byte_length)
def initialize(source, encoding)
@source = source
@counter =
if encoding == Encoding::UTF_16LE || encoding == Encoding::UTF_16BE
case encoding
when Encoding::UTF_8
UTF8Counter.new
when Encoding::UTF_16LE, Encoding::UTF_16BE
UTF16Counter.new(source, encoding)
else
LengthCounter.new(source, encoding)
UTF32Counter.new(source, encoding)
end

@cache = {} #: Hash[Integer, Integer]
Expand Down
11 changes: 10 additions & 1 deletion rbi/generated/prism/parse_result.rbi

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 11 additions & 2 deletions sig/generated/prism/parse_result.rbs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 20 additions & 20 deletions test/prism/ruby/location_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -73,72 +73,72 @@ def test_code_units
assert_equal 0, location.start_code_units_offset(Encoding::UTF_16LE)
assert_equal 0, location.start_code_units_offset(Encoding::UTF_32LE)

assert_equal 1, location.end_code_units_offset(Encoding::UTF_8)
assert_equal 4, location.end_code_units_offset(Encoding::UTF_8)
assert_equal 2, location.end_code_units_offset(Encoding::UTF_16LE)
assert_equal 1, location.end_code_units_offset(Encoding::UTF_32LE)

assert_equal 0, location.start_code_units_column(Encoding::UTF_8)
assert_equal 0, location.start_code_units_column(Encoding::UTF_16LE)
assert_equal 0, location.start_code_units_column(Encoding::UTF_32LE)

assert_equal 1, location.end_code_units_column(Encoding::UTF_8)
assert_equal 4, location.end_code_units_column(Encoding::UTF_8)
assert_equal 2, location.end_code_units_column(Encoding::UTF_16LE)
assert_equal 1, location.end_code_units_column(Encoding::UTF_32LE)

# second 😀
location = program.statements.body.first.arguments.arguments.first.location

assert_equal 4, location.start_code_units_offset(Encoding::UTF_8)
assert_equal 7, location.start_code_units_offset(Encoding::UTF_8)
assert_equal 5, location.start_code_units_offset(Encoding::UTF_16LE)
assert_equal 4, location.start_code_units_offset(Encoding::UTF_32LE)

assert_equal 5, location.end_code_units_offset(Encoding::UTF_8)
assert_equal 11, location.end_code_units_offset(Encoding::UTF_8)
assert_equal 7, location.end_code_units_offset(Encoding::UTF_16LE)
assert_equal 5, location.end_code_units_offset(Encoding::UTF_32LE)

assert_equal 4, location.start_code_units_column(Encoding::UTF_8)
assert_equal 7, location.start_code_units_column(Encoding::UTF_8)
assert_equal 5, location.start_code_units_column(Encoding::UTF_16LE)
assert_equal 4, location.start_code_units_column(Encoding::UTF_32LE)

assert_equal 5, location.end_code_units_column(Encoding::UTF_8)
assert_equal 11, location.end_code_units_column(Encoding::UTF_8)
assert_equal 7, location.end_code_units_column(Encoding::UTF_16LE)
assert_equal 5, location.end_code_units_column(Encoding::UTF_32LE)

# first 😍
location = program.statements.body.last.name_loc

assert_equal 6, location.start_code_units_offset(Encoding::UTF_8)
assert_equal 12, location.start_code_units_offset(Encoding::UTF_8)
assert_equal 8, location.start_code_units_offset(Encoding::UTF_16LE)
assert_equal 6, location.start_code_units_offset(Encoding::UTF_32LE)

assert_equal 7, location.end_code_units_offset(Encoding::UTF_8)
assert_equal 16, location.end_code_units_offset(Encoding::UTF_8)
assert_equal 10, location.end_code_units_offset(Encoding::UTF_16LE)
assert_equal 7, location.end_code_units_offset(Encoding::UTF_32LE)

assert_equal 0, location.start_code_units_column(Encoding::UTF_8)
assert_equal 0, location.start_code_units_column(Encoding::UTF_16LE)
assert_equal 0, location.start_code_units_column(Encoding::UTF_32LE)

assert_equal 1, location.end_code_units_column(Encoding::UTF_8)
assert_equal 4, location.end_code_units_column(Encoding::UTF_8)
assert_equal 2, location.end_code_units_column(Encoding::UTF_16LE)
assert_equal 1, location.end_code_units_column(Encoding::UTF_32LE)

# second 😍
location = program.statements.body.last.value.location

assert_equal 12, location.start_code_units_offset(Encoding::UTF_8)
assert_equal 21, location.start_code_units_offset(Encoding::UTF_8)
assert_equal 15, location.start_code_units_offset(Encoding::UTF_16LE)
assert_equal 12, location.start_code_units_offset(Encoding::UTF_32LE)

assert_equal 13, location.end_code_units_offset(Encoding::UTF_8)
assert_equal 25, location.end_code_units_offset(Encoding::UTF_8)
assert_equal 17, location.end_code_units_offset(Encoding::UTF_16LE)
assert_equal 13, location.end_code_units_offset(Encoding::UTF_32LE)

assert_equal 6, location.start_code_units_column(Encoding::UTF_8)
assert_equal 9, location.start_code_units_column(Encoding::UTF_8)
assert_equal 7, location.start_code_units_column(Encoding::UTF_16LE)
assert_equal 6, location.start_code_units_column(Encoding::UTF_32LE)

assert_equal 7, location.end_code_units_column(Encoding::UTF_8)
assert_equal 13, location.end_code_units_column(Encoding::UTF_8)
assert_equal 9, location.end_code_units_column(Encoding::UTF_16LE)
assert_equal 7, location.end_code_units_column(Encoding::UTF_32LE)
end
Expand All @@ -157,34 +157,34 @@ def test_cached_code_units
assert_equal 0, location.cached_start_code_units_offset(utf16_cache)
assert_equal 0, location.cached_start_code_units_offset(utf32_cache)

assert_equal 1, location.cached_end_code_units_offset(utf8_cache)
assert_equal 4, location.cached_end_code_units_offset(utf8_cache)
assert_equal 2, location.cached_end_code_units_offset(utf16_cache)
assert_equal 1, location.cached_end_code_units_offset(utf32_cache)

assert_equal 0, location.cached_start_code_units_column(utf8_cache)
assert_equal 0, location.cached_start_code_units_column(utf16_cache)
assert_equal 0, location.cached_start_code_units_column(utf32_cache)

assert_equal 1, location.cached_end_code_units_column(utf8_cache)
assert_equal 4, location.cached_end_code_units_column(utf8_cache)
assert_equal 2, location.cached_end_code_units_column(utf16_cache)
assert_equal 1, location.cached_end_code_units_column(utf32_cache)

# second 😀
location = result.value.statements.body.first.arguments.arguments.first.location

assert_equal 4, location.cached_start_code_units_offset(utf8_cache)
assert_equal 7, location.cached_start_code_units_offset(utf8_cache)
assert_equal 5, location.cached_start_code_units_offset(utf16_cache)
assert_equal 4, location.cached_start_code_units_offset(utf32_cache)

assert_equal 5, location.cached_end_code_units_offset(utf8_cache)
assert_equal 11, location.cached_end_code_units_offset(utf8_cache)
assert_equal 7, location.cached_end_code_units_offset(utf16_cache)
assert_equal 5, location.cached_end_code_units_offset(utf32_cache)

assert_equal 4, location.cached_start_code_units_column(utf8_cache)
assert_equal 7, location.cached_start_code_units_column(utf8_cache)
assert_equal 5, location.cached_start_code_units_column(utf16_cache)
assert_equal 4, location.cached_start_code_units_column(utf32_cache)

assert_equal 5, location.cached_end_code_units_column(utf8_cache)
assert_equal 11, location.cached_end_code_units_column(utf8_cache)
assert_equal 7, location.cached_end_code_units_column(utf16_cache)
assert_equal 5, location.cached_end_code_units_column(utf32_cache)
end
Expand All @@ -200,7 +200,7 @@ def test_code_units_binary_valid_utf8
assert_equal "😀".b.to_sym, receiver.name

location = receiver.location
assert_equal 1, location.end_code_units_column(Encoding::UTF_8)
assert_equal 4, location.end_code_units_column(Encoding::UTF_8)
assert_equal 2, location.end_code_units_column(Encoding::UTF_16LE)
assert_equal 1, location.end_code_units_column(Encoding::UTF_32LE)
end
Expand Down