diff --git a/lib/irb.rb b/lib/irb.rb index d0688e6f9fe1d1..655abaf0690e2d 100644 --- a/lib/irb.rb +++ b/lib/irb.rb @@ -140,6 +140,10 @@ # # IRB.conf[:USE_AUTOCOMPLETE] = false # +# To enable enhanced completion using type information, add the following to your +.irbrc+: +# +# IRB.conf[:COMPLETOR] = :type +# # === History # # By default, irb will store the last 1000 commands you used in diff --git a/lib/irb/cmd/irb_info.rb b/lib/irb/cmd/irb_info.rb index 75fdc386760955..5b905a09bdb1bc 100644 --- a/lib/irb/cmd/irb_info.rb +++ b/lib/irb/cmd/irb_info.rb @@ -14,6 +14,7 @@ def execute str = "Ruby version: #{RUBY_VERSION}\n" str += "IRB version: #{IRB.version}\n" str += "InputMethod: #{IRB.CurrentContext.io.inspect}\n" + str += "Completion: #{IRB.CurrentContext.io.respond_to?(:completion_info) ? IRB.CurrentContext.io.completion_info : 'off'}\n" str += ".irbrc path: #{IRB.rc_file}\n" if File.exist?(IRB.rc_file) str += "RUBY_PLATFORM: #{RUBY_PLATFORM}\n" str += "LANG env: #{ENV["LANG"]}\n" if ENV["LANG"] && !ENV["LANG"].empty? diff --git a/lib/irb/completion.rb b/lib/irb/completion.rb index 61bdc335878280..e3ebe4abfffda9 100644 --- a/lib/irb/completion.rb +++ b/lib/irb/completion.rb @@ -9,6 +9,30 @@ module IRB class BaseCompletor # :nodoc: + + # Set of reserved words used by Ruby, you should not use these for + # constants or variables + ReservedWords = %w[ + __ENCODING__ __LINE__ __FILE__ + BEGIN END + alias and + begin break + case class + def defined? do + else elsif end ensure + false for + if in + module + next nil not + or + redo rescue retry return + self super + then true + undef unless until + when while + yield + ] + def completion_candidates(preposing, target, postposing, bind:) raise NotImplementedError end @@ -94,28 +118,9 @@ def eval_class_constants end } - # Set of reserved words used by Ruby, you should not use these for - # constants or variables - ReservedWords = %w[ - __ENCODING__ __LINE__ __FILE__ - BEGIN END - alias and - begin break - case class - def defined? do - else elsif end ensure - false for - if in - module - next nil not - or - redo rescue retry return - self super - then true - undef unless until - when while - yield - ] + def inspect + 'RegexpCompletor' + end def complete_require_path(target, preposing, postposing) if target =~ /\A(['"])([^'"]+)\Z/ diff --git a/lib/irb/context.rb b/lib/irb/context.rb index a20510d73c3ee8..5dfe9d0d71d41a 100644 --- a/lib/irb/context.rb +++ b/lib/irb/context.rb @@ -86,14 +86,14 @@ def initialize(irb, workspace = nil, input_method = nil) when nil if STDIN.tty? && IRB.conf[:PROMPT_MODE] != :INF_RUBY && !use_singleline? # Both of multiline mode and singleline mode aren't specified. - @io = RelineInputMethod.new + @io = RelineInputMethod.new(build_completor) else @io = nil end when false @io = nil when true - @io = RelineInputMethod.new + @io = RelineInputMethod.new(build_completor) end unless @io case use_singleline? @@ -149,6 +149,43 @@ def initialize(irb, workspace = nil, input_method = nil) @command_aliases = IRB.conf[:COMMAND_ALIASES] end + private def build_completor + completor_type = IRB.conf[:COMPLETOR] + case completor_type + when :regexp + return RegexpCompletor.new + when :type + completor = build_type_completor + return completor if completor + else + warn "Invalid value for IRB.conf[:COMPLETOR]: #{completor_type}" + end + # Fallback to RegexpCompletor + RegexpCompletor.new + end + + TYPE_COMPLETION_REQUIRED_PRISM_VERSION = '0.17.1' + + private def build_type_completor + unless Gem::Version.new(RUBY_VERSION) >= Gem::Version.new('3.0.0') && RUBY_ENGINE != 'truffleruby' + warn 'TypeCompletion requires RUBY_VERSION >= 3.0.0' + return + end + begin + require 'prism' + rescue LoadError => e + warn "TypeCompletion requires Prism: #{e.message}" + return + end + unless Gem::Version.new(Prism::VERSION) >= Gem::Version.new(TYPE_COMPLETION_REQUIRED_PRISM_VERSION) + warn "TypeCompletion requires Prism::VERSION >= #{TYPE_COMPLETION_REQUIRED_PRISM_VERSION}" + return + end + require 'irb/type_completion/completor' + TypeCompletion::Types.preload_in_thread + TypeCompletion::Completor.new + end + def save_history=(val) IRB.conf[:SAVE_HISTORY] = val end diff --git a/lib/irb/init.rb b/lib/irb/init.rb index d9549420b4aa9d..e9111974f058c7 100644 --- a/lib/irb/init.rb +++ b/lib/irb/init.rb @@ -76,6 +76,7 @@ def IRB.init_config(ap_path) @CONF[:USE_SINGLELINE] = false unless defined?(ReadlineInputMethod) @CONF[:USE_COLORIZE] = (nc = ENV['NO_COLOR']).nil? || nc.empty? @CONF[:USE_AUTOCOMPLETE] = ENV.fetch("IRB_USE_AUTOCOMPLETE", "true") != "false" + @CONF[:COMPLETOR] = :regexp @CONF[:INSPECT_MODE] = true @CONF[:USE_TRACER] = false @CONF[:USE_LOADER] = false diff --git a/lib/irb/input-method.rb b/lib/irb/input-method.rb index cef65b71624002..94ad28cd634566 100644 --- a/lib/irb/input-method.rb +++ b/lib/irb/input-method.rb @@ -193,6 +193,10 @@ def initialize } end + def completion_info + 'RegexpCompletor' + end + # Reads the next line from this input method. # # See IO#gets for more information. @@ -230,13 +234,13 @@ class RelineInputMethod < StdioInputMethod HISTORY = Reline::HISTORY include HistorySavingAbility # Creates a new input method object using Reline - def initialize + def initialize(completor) IRB.__send__(:set_encoding, Reline.encoding_system_needs.name, override: false) - super + super() @eof = false - @completor = RegexpCompletor.new + @completor = completor Reline.basic_word_break_characters = BASIC_WORD_BREAK_CHARACTERS Reline.completion_append_character = nil @@ -270,6 +274,11 @@ def initialize end end + def completion_info + autocomplete_message = Reline.autocompletion ? 'Autocomplete' : 'Tab Complete' + "#{autocomplete_message}, #{@completor.inspect}" + end + def check_termination(&block) @check_termination_proc = block end diff --git a/lib/irb/type_completion/completor.rb b/lib/irb/type_completion/completor.rb new file mode 100644 index 00000000000000..e893fd8adcd52a --- /dev/null +++ b/lib/irb/type_completion/completor.rb @@ -0,0 +1,235 @@ +# frozen_string_literal: true + +require 'prism' +require 'irb/completion' +require_relative 'type_analyzer' + +module IRB + module TypeCompletion + class Completor < BaseCompletor # :nodoc: + HIDDEN_METHODS = %w[Namespace TypeName] # defined by rbs, should be hidden + + class << self + attr_accessor :last_completion_error + end + + def inspect + name = 'TypeCompletion::Completor' + prism_info = "Prism: #{Prism::VERSION}" + if Types.rbs_builder + "#{name}(#{prism_info}, RBS: #{RBS::VERSION})" + elsif Types.rbs_load_error + "#{name}(#{prism_info}, RBS: #{Types.rbs_load_error.inspect})" + else + "#{name}(#{prism_info}, RBS: loading)" + end + end + + def completion_candidates(preposing, target, _postposing, bind:) + @preposing = preposing + verbose, $VERBOSE = $VERBOSE, nil + code = "#{preposing}#{target}" + @result = analyze code, bind + name, candidates = candidates_from_result(@result) + + all_symbols_pattern = /\A[ -\/:-@\[-`\{-~]*\z/ + candidates.map(&:to_s).select { !_1.match?(all_symbols_pattern) && _1.start_with?(name) }.uniq.sort.map do + target + _1[name.size..] + end + rescue SyntaxError, StandardError => e + Completor.last_completion_error = e + handle_error(e) + [] + ensure + $VERBOSE = verbose + end + + def doc_namespace(preposing, matched, postposing, bind:) + name = matched[/[a-zA-Z_0-9]*[!?=]?\z/] + method_doc = -> type do + type = type.types.find { _1.all_methods.include? name.to_sym } + case type + when Types::SingletonType + "#{Types.class_name_of(type.module_or_class)}.#{name}" + when Types::InstanceType + "#{Types.class_name_of(type.klass)}##{name}" + end + end + call_or_const_doc = -> type do + if name =~ /\A[A-Z]/ + type = type.types.grep(Types::SingletonType).find { _1.module_or_class.const_defined?(name) } + type.module_or_class == Object ? name : "#{Types.class_name_of(type.module_or_class)}::#{name}" if type + else + method_doc.call(type) + end + end + + value_doc = -> type do + return unless type + type.types.each do |t| + case t + when Types::SingletonType + return Types.class_name_of(t.module_or_class) + when Types::InstanceType + return Types.class_name_of(t.klass) + end + end + nil + end + + case @result + in [:call_or_const, type, _name, _self_call] + call_or_const_doc.call type + in [:const, type, _name, scope] + if type + call_or_const_doc.call type + else + value_doc.call scope[name] + end + in [:gvar, _name, scope] + value_doc.call scope["$#{name}"] + in [:ivar, _name, scope] + value_doc.call scope["@#{name}"] + in [:cvar, _name, scope] + value_doc.call scope["@@#{name}"] + in [:call, type, _name, _self_call] + method_doc.call type + in [:lvar_or_method, _name, scope] + if scope.local_variables.include?(name) + value_doc.call scope[name] + else + method_doc.call scope.self_type + end + else + end + end + + def candidates_from_result(result) + candidates = case result + in [:require, name] + retrieve_files_to_require_from_load_path + in [:require_relative, name] + retrieve_files_to_require_relative_from_current_dir + in [:call_or_const, type, name, self_call] + ((self_call ? type.all_methods : type.methods).map(&:to_s) - HIDDEN_METHODS) | type.constants + in [:const, type, name, scope] + if type + scope_constants = type.types.flat_map do |t| + scope.table_module_constants(t.module_or_class) if t.is_a?(Types::SingletonType) + end + (scope_constants.compact | type.constants.map(&:to_s)).sort + else + scope.constants.sort | ReservedWords + end + in [:ivar, name, scope] + ivars = scope.instance_variables.sort + name == '@' ? ivars + scope.class_variables.sort : ivars + in [:cvar, name, scope] + scope.class_variables + in [:gvar, name, scope] + scope.global_variables + in [:symbol, name] + Symbol.all_symbols.map { _1.inspect[1..] } + in [:call, type, name, self_call] + (self_call ? type.all_methods : type.methods).map(&:to_s) - HIDDEN_METHODS + in [:lvar_or_method, name, scope] + scope.self_type.all_methods.map(&:to_s) | scope.local_variables | ReservedWords + else + [] + end + [name || '', candidates] + end + + def analyze(code, binding = Object::TOPLEVEL_BINDING) + # Workaround for https://github.com/ruby/prism/issues/1592 + return if code.match?(/%[qQ]\z/) + + ast = Prism.parse(code, scopes: [binding.local_variables]).value + name = code[/(@@|@|\$)?\w*[!?=]?\z/] + *parents, target_node = find_target ast, code.bytesize - name.bytesize + return unless target_node + + calculate_scope = -> { TypeAnalyzer.calculate_target_type_scope(binding, parents, target_node).last } + calculate_type_scope = ->(node) { TypeAnalyzer.calculate_target_type_scope binding, [*parents, target_node], node } + + case target_node + when Prism::StringNode, Prism::InterpolatedStringNode + call_node, args_node = parents.last(2) + return unless call_node.is_a?(Prism::CallNode) && call_node.receiver.nil? + return unless args_node.is_a?(Prism::ArgumentsNode) && args_node.arguments.size == 1 + + case call_node.name + when :require + [:require, name.rstrip] + when :require_relative + [:require_relative, name.rstrip] + end + when Prism::SymbolNode + if parents.last.is_a? Prism::BlockArgumentNode # method(&:target) + receiver_type, _scope = calculate_type_scope.call target_node + [:call, receiver_type, name, false] + else + [:symbol, name] unless name.empty? + end + when Prism::CallNode + return [:lvar_or_method, name, calculate_scope.call] if target_node.receiver.nil? + + self_call = target_node.receiver.is_a? Prism::SelfNode + op = target_node.call_operator + receiver_type, _scope = calculate_type_scope.call target_node.receiver + receiver_type = receiver_type.nonnillable if op == '&.' + [op == '::' ? :call_or_const : :call, receiver_type, name, self_call] + when Prism::LocalVariableReadNode, Prism::LocalVariableTargetNode + [:lvar_or_method, name, calculate_scope.call] + when Prism::ConstantReadNode, Prism::ConstantTargetNode + if parents.last.is_a? Prism::ConstantPathNode + path_node = parents.last + if path_node.parent # A::B + receiver, scope = calculate_type_scope.call(path_node.parent) + [:const, receiver, name, scope] + else # ::A + scope = calculate_scope.call + [:const, Types::SingletonType.new(Object), name, scope] + end + else + [:const, nil, name, calculate_scope.call] + end + when Prism::GlobalVariableReadNode, Prism::GlobalVariableTargetNode + [:gvar, name, calculate_scope.call] + when Prism::InstanceVariableReadNode, Prism::InstanceVariableTargetNode + [:ivar, name, calculate_scope.call] + when Prism::ClassVariableReadNode, Prism::ClassVariableTargetNode + [:cvar, name, calculate_scope.call] + end + end + + def find_target(node, position) + location = ( + case node + when Prism::CallNode + node.message_loc + when Prism::SymbolNode + node.value_loc + when Prism::StringNode + node.content_loc + when Prism::InterpolatedStringNode + node.closing_loc if node.parts.empty? + end + ) + return [node] if location&.start_offset == position + + node.compact_child_nodes.each do |n| + match = find_target(n, position) + next unless match + match.unshift node + return match + end + + [node] if node.location.start_offset == position + end + + def handle_error(e) + end + end + end +end diff --git a/lib/irb/type_completion/methods.rb b/lib/irb/type_completion/methods.rb new file mode 100644 index 00000000000000..8a88b6d0f96fbc --- /dev/null +++ b/lib/irb/type_completion/methods.rb @@ -0,0 +1,13 @@ +# frozen_string_literal: true + +module IRB + module TypeCompletion + module Methods + OBJECT_SINGLETON_CLASS_METHOD = Object.instance_method(:singleton_class) + OBJECT_INSTANCE_VARIABLES_METHOD = Object.instance_method(:instance_variables) + OBJECT_INSTANCE_VARIABLE_GET_METHOD = Object.instance_method(:instance_variable_get) + OBJECT_CLASS_METHOD = Object.instance_method(:class) + MODULE_NAME_METHOD = Module.instance_method(:name) + end + end +end diff --git a/lib/irb/type_completion/scope.rb b/lib/irb/type_completion/scope.rb new file mode 100644 index 00000000000000..5a58a0ed65c938 --- /dev/null +++ b/lib/irb/type_completion/scope.rb @@ -0,0 +1,412 @@ +# frozen_string_literal: true + +require 'set' +require_relative 'types' + +module IRB + module TypeCompletion + + class RootScope + attr_reader :module_nesting, :self_object + + def initialize(binding, self_object, local_variables) + @binding = binding + @self_object = self_object + @cache = {} + modules = [*binding.eval('::Module.nesting'), Object] + @module_nesting = modules.map { [_1, []] } + binding_local_variables = binding.local_variables + uninitialized_locals = local_variables - binding_local_variables + uninitialized_locals.each { @cache[_1] = Types::NIL } + @local_variables = (local_variables | binding_local_variables).map(&:to_s).to_set + @global_variables = Kernel.global_variables.map(&:to_s).to_set + @owned_constants_cache = {} + end + + def level() = 0 + + def level_of(_name, _var_type) = 0 + + def mutable?() = false + + def module_own_constant?(mod, name) + set = (@owned_constants_cache[mod] ||= Set.new(mod.constants.map(&:to_s))) + set.include? name + end + + def get_const(nesting, path, _key = nil) + return unless nesting + + result = path.reduce nesting do |mod, name| + return nil unless mod.is_a?(Module) && module_own_constant?(mod, name) + mod.const_get name + end + Types.type_from_object result + end + + def get_cvar(nesting, path, name, _key = nil) + return Types::NIL unless nesting + + result = path.reduce nesting do |mod, n| + return Types::NIL unless mod.is_a?(Module) && module_own_constant?(mod, n) + mod.const_get n + end + value = result.class_variable_get name if result.is_a?(Module) && name.size >= 3 && result.class_variable_defined?(name) + Types.type_from_object value + end + + def [](name) + @cache[name] ||= ( + value = case RootScope.type_by_name name + when :ivar + begin + Methods::OBJECT_INSTANCE_VARIABLE_GET_METHOD.bind_call(@self_object, name) + rescue NameError + end + when :lvar + begin + @binding.local_variable_get(name) + rescue NameError + end + when :gvar + @binding.eval name if @global_variables.include? name + end + Types.type_from_object(value) + ) + end + + def self_type + Types.type_from_object @self_object + end + + def local_variables() = @local_variables.to_a + + def global_variables() = @global_variables.to_a + + def self.type_by_name(name) + if name.start_with? '@@' + # "@@cvar" or "@@cvar::[module_id]::[module_path]" + :cvar + elsif name.start_with? '@' + :ivar + elsif name.start_with? '$' + :gvar + elsif name.start_with? '%' + :internal + elsif name[0].downcase != name[0] || name[0].match?(/\d/) + # "ConstName" or "[module_id]::[const_path]" + :const + else + :lvar + end + end + end + + class Scope + BREAK_RESULT = '%break' + NEXT_RESULT = '%next' + RETURN_RESULT = '%return' + PATTERNMATCH_BREAK = '%match' + + attr_reader :parent, :mergeable_changes, :level, :module_nesting + + def self.from_binding(binding, locals) = new(RootScope.new(binding, binding.receiver, locals)) + + def initialize(parent, table = {}, trace_ivar: true, trace_lvar: true, self_type: nil, nesting: nil) + @parent = parent + @level = parent.level + 1 + @trace_ivar = trace_ivar + @trace_lvar = trace_lvar + @module_nesting = nesting ? [nesting, *parent.module_nesting] : parent.module_nesting + @self_type = self_type + @terminated = false + @jump_branches = [] + @mergeable_changes = @table = table.transform_values { [level, _1] } + end + + def mutable? = true + + def terminated? + @terminated + end + + def terminate_with(type, value) + return if terminated? + store_jump type, value, @mergeable_changes + terminate + end + + def store_jump(type, value, changes) + return if terminated? + if has_own?(type) + changes[type] = [level, value] + @jump_branches << changes + elsif @parent.mutable? + @parent.store_jump(type, value, changes) + end + end + + def terminate + return if terminated? + @terminated = true + @table = @mergeable_changes.dup + end + + def trace?(name) + return false unless @parent + type = RootScope.type_by_name(name) + type == :ivar ? @trace_ivar : type == :lvar ? @trace_lvar : true + end + + def level_of(name, var_type) + case var_type + when :ivar + return level unless @trace_ivar + when :gvar + return 0 + end + variable_level, = @table[name] + variable_level || parent.level_of(name, var_type) + end + + def get_const(nesting, path, key = nil) + key ||= [nesting.__id__, path].join('::') + _l, value = @table[key] + value || @parent.get_const(nesting, path, key) + end + + def get_cvar(nesting, path, name, key = nil) + key ||= [name, nesting.__id__, path].join('::') + _l, value = @table[key] + value || @parent.get_cvar(nesting, path, name, key) + end + + def [](name) + type = RootScope.type_by_name(name) + if type == :const + return get_const(nil, nil, name) || Types::NIL if name.include?('::') + + module_nesting.each do |(nesting, path)| + value = get_const nesting, [*path, name] + return value if value + end + return Types::NIL + elsif type == :cvar + return get_cvar(nil, nil, nil, name) if name.include?('::') + + nesting, path = module_nesting.first + return get_cvar(nesting, path, name) + end + level, value = @table[name] + if level + value + elsif trace? name + @parent[name] + elsif type == :ivar + self_instance_variable_get name + end + end + + def set_const(nesting, path, value) + key = [nesting.__id__, path].join('::') + @table[key] = [0, value] + end + + def set_cvar(nesting, path, name, value) + key = [name, nesting.__id__, path].join('::') + @table[key] = [0, value] + end + + def []=(name, value) + type = RootScope.type_by_name(name) + if type == :const + if name.include?('::') + @table[name] = [0, value] + else + parent_module, parent_path = module_nesting.first + set_const parent_module, [*parent_path, name], value + end + return + elsif type == :cvar + if name.include?('::') + @table[name] = [0, value] + else + parent_module, parent_path = module_nesting.first + set_cvar parent_module, parent_path, name, value + end + return + end + variable_level = level_of name, type + @table[name] = [variable_level, value] if variable_level + end + + def self_type + @self_type || @parent.self_type + end + + def global_variables + gvar_keys = @table.keys.select do |name| + RootScope.type_by_name(name) == :gvar + end + gvar_keys | @parent.global_variables + end + + def local_variables + lvar_keys = @table.keys.select do |name| + RootScope.type_by_name(name) == :lvar + end + lvar_keys |= @parent.local_variables if @trace_lvar + lvar_keys + end + + def table_constants + constants = module_nesting.flat_map do |mod, path| + prefix = [mod.__id__, *path].join('::') + '::' + @table.keys.select { _1.start_with? prefix }.map { _1.delete_prefix(prefix).split('::').first } + end.uniq + constants |= @parent.table_constants if @parent.mutable? + constants + end + + def table_module_constants(mod) + prefix = "#{mod.__id__}::" + constants = @table.keys.select { _1.start_with? prefix }.map { _1.delete_prefix(prefix).split('::').first } + constants |= @parent.table_constants if @parent.mutable? + constants + end + + def base_scope + @parent.mutable? ? @parent.base_scope : @parent + end + + def table_instance_variables + ivars = @table.keys.select { RootScope.type_by_name(_1) == :ivar } + ivars |= @parent.table_instance_variables if @parent.mutable? && @trace_ivar + ivars + end + + def instance_variables + self_singleton_types = self_type.types.grep(Types::SingletonType) + singleton_classes = self_type.types.grep(Types::InstanceType).map(&:klass).select(&:singleton_class?) + base_self = base_scope.self_object + self_instance_variables = singleton_classes.flat_map do |singleton_class| + if singleton_class.respond_to? :attached_object + Methods::OBJECT_INSTANCE_VARIABLES_METHOD.bind_call(singleton_class.attached_object).map(&:to_s) + elsif singleton_class == Methods::OBJECT_SINGLETON_CLASS_METHOD.bind_call(base_self) + Methods::OBJECT_INSTANCE_VARIABLES_METHOD.bind_call(base_self).map(&:to_s) + else + [] + end + end + [ + self_singleton_types.flat_map { _1.module_or_class.instance_variables.map(&:to_s) }, + self_instance_variables || [], + table_instance_variables + ].inject(:|) + end + + def self_instance_variable_get(name) + self_objects = self_type.types.grep(Types::SingletonType).map(&:module_or_class) + singleton_classes = self_type.types.grep(Types::InstanceType).map(&:klass).select(&:singleton_class?) + base_self = base_scope.self_object + singleton_classes.each do |singleton_class| + if singleton_class.respond_to? :attached_object + self_objects << singleton_class.attached_object + elsif singleton_class == base_self.singleton_class + self_objects << base_self + end + end + types = self_objects.map do |object| + value = begin + Methods::OBJECT_INSTANCE_VARIABLE_GET_METHOD.bind_call(object, name) + rescue NameError + end + Types.type_from_object value + end + Types::UnionType[*types] + end + + def table_class_variables + cvars = @table.keys.filter_map { _1.split('::', 2).first if RootScope.type_by_name(_1) == :cvar } + cvars |= @parent.table_class_variables if @parent.mutable? + cvars + end + + def class_variables + cvars = table_class_variables + m, = module_nesting.first + cvars |= m.class_variables.map(&:to_s) if m.is_a? Module + cvars + end + + def constants + module_nesting.flat_map do |nest,| + nest.constants + end.map(&:to_s) | table_constants + end + + def merge_jumps + if terminated? + @terminated = false + @table = @mergeable_changes + merge @jump_branches + @terminated = true + else + merge [*@jump_branches, {}] + end + end + + def conditional(&block) + run_branches(block, ->(_s) {}).first || Types::NIL + end + + def never(&block) + block.call Scope.new(self, { BREAK_RESULT => nil, NEXT_RESULT => nil, PATTERNMATCH_BREAK => nil, RETURN_RESULT => nil }) + end + + def run_branches(*blocks) + results = [] + branches = [] + blocks.each do |block| + scope = Scope.new self + result = block.call scope + next if scope.terminated? + results << result + branches << scope.mergeable_changes + end + terminate if branches.empty? + merge branches + results + end + + def has_own?(name) + @table.key? name + end + + def update(child_scope) + current_level = level + child_scope.mergeable_changes.each do |name, (level, value)| + self[name] = value if level <= current_level + end + end + + protected + + def merge(branches) + current_level = level + merge = {} + branches.each do |changes| + changes.each do |name, (level, value)| + next if current_level < level + (merge[name] ||= []) << value + end + end + merge.each do |name, values| + values << self[name] unless values.size == branches.size + values.compact! + self[name] = Types::UnionType[*values.compact] unless values.empty? + end + end + end + end +end diff --git a/lib/irb/type_completion/type_analyzer.rb b/lib/irb/type_completion/type_analyzer.rb new file mode 100644 index 00000000000000..c4a41e49993b8b --- /dev/null +++ b/lib/irb/type_completion/type_analyzer.rb @@ -0,0 +1,1169 @@ +# frozen_string_literal: true + +require 'set' +require_relative 'types' +require_relative 'scope' +require 'prism' + +module IRB + module TypeCompletion + class TypeAnalyzer + class DigTarget + def initialize(parents, receiver, &block) + @dig_ids = parents.to_h { [_1.__id__, true] } + @target_id = receiver.__id__ + @block = block + end + + def dig?(node) = @dig_ids[node.__id__] + def target?(node) = @target_id == node.__id__ + def resolve(type, scope) + @block.call type, scope + end + end + + OBJECT_METHODS = { + to_s: Types::STRING, + to_str: Types::STRING, + to_a: Types::ARRAY, + to_ary: Types::ARRAY, + to_h: Types::HASH, + to_hash: Types::HASH, + to_i: Types::INTEGER, + to_int: Types::INTEGER, + to_f: Types::FLOAT, + to_c: Types::COMPLEX, + to_r: Types::RATIONAL + } + + def initialize(dig_targets) + @dig_targets = dig_targets + end + + def evaluate(node, scope) + method = "evaluate_#{node.type}" + if respond_to? method + result = send method, node, scope + else + result = Types::NIL + end + @dig_targets.resolve result, scope if @dig_targets.target? node + result + end + + def evaluate_program_node(node, scope) + evaluate node.statements, scope + end + + def evaluate_statements_node(node, scope) + if node.body.empty? + Types::NIL + else + node.body.map { evaluate _1, scope }.last + end + end + + def evaluate_def_node(node, scope) + if node.receiver + self_type = evaluate node.receiver, scope + else + current_self_types = scope.self_type.types + self_types = current_self_types.map do |type| + if type.is_a?(Types::SingletonType) && type.module_or_class.is_a?(Class) + Types::InstanceType.new type.module_or_class + else + type + end + end + self_type = Types::UnionType[*self_types] + end + if @dig_targets.dig?(node.body) || @dig_targets.dig?(node.parameters) + params_table = node.locals.to_h { [_1.to_s, Types::NIL] } + method_scope = Scope.new( + scope, + { **params_table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil, Scope::RETURN_RESULT => nil }, + self_type: self_type, + trace_lvar: false, + trace_ivar: false + ) + if node.parameters + # node.parameters is Prism::ParametersNode + assign_parameters node.parameters, method_scope, [], {} + end + + if @dig_targets.dig?(node.body) + method_scope.conditional do |s| + evaluate node.body, s + end + end + method_scope.merge_jumps + scope.update method_scope + end + Types::SYMBOL + end + + def evaluate_integer_node(_node, _scope) = Types::INTEGER + + def evaluate_float_node(_node, _scope) = Types::FLOAT + + def evaluate_rational_node(_node, _scope) = Types::RATIONAL + + def evaluate_imaginary_node(_node, _scope) = Types::COMPLEX + + def evaluate_string_node(_node, _scope) = Types::STRING + + def evaluate_x_string_node(_node, _scope) + Types::UnionType[Types::STRING, Types::NIL] + end + + def evaluate_symbol_node(_node, _scope) = Types::SYMBOL + + def evaluate_regular_expression_node(_node, _scope) = Types::REGEXP + + def evaluate_string_concat_node(node, scope) + evaluate node.left, scope + evaluate node.right, scope + Types::STRING + end + + def evaluate_interpolated_string_node(node, scope) + node.parts.each { evaluate _1, scope } + Types::STRING + end + + def evaluate_interpolated_x_string_node(node, scope) + node.parts.each { evaluate _1, scope } + Types::STRING + end + + def evaluate_interpolated_symbol_node(node, scope) + node.parts.each { evaluate _1, scope } + Types::SYMBOL + end + + def evaluate_interpolated_regular_expression_node(node, scope) + node.parts.each { evaluate _1, scope } + Types::REGEXP + end + + def evaluate_embedded_statements_node(node, scope) + node.statements ? evaluate(node.statements, scope) : Types::NIL + Types::STRING + end + + def evaluate_embedded_variable_node(node, scope) + evaluate node.variable, scope + Types::STRING + end + + def evaluate_array_node(node, scope) + Types.array_of evaluate_list_splat_items(node.elements, scope) + end + + def evaluate_hash_node(node, scope) = evaluate_hash(node, scope) + def evaluate_keyword_hash_node(node, scope) = evaluate_hash(node, scope) + def evaluate_hash(node, scope) + keys = [] + values = [] + node.elements.each do |assoc| + case assoc + when Prism::AssocNode + keys << evaluate(assoc.key, scope) + values << evaluate(assoc.value, scope) + when Prism::AssocSplatNode + next unless assoc.value # def f(**); {**} + + hash = evaluate assoc.value, scope + unless hash.is_a?(Types::InstanceType) && hash.klass == Hash + hash = method_call hash, :to_hash, [], nil, nil, scope + end + if hash.is_a?(Types::InstanceType) && hash.klass == Hash + keys << hash.params[:K] if hash.params[:K] + values << hash.params[:V] if hash.params[:V] + end + end + end + if keys.empty? && values.empty? + Types::InstanceType.new Hash + else + Types::InstanceType.new Hash, K: Types::UnionType[*keys], V: Types::UnionType[*values] + end + end + + def evaluate_parentheses_node(node, scope) + node.body ? evaluate(node.body, scope) : Types::NIL + end + + def evaluate_constant_path_node(node, scope) + type, = evaluate_constant_node_info node, scope + type + end + + def evaluate_self_node(_node, scope) = scope.self_type + + def evaluate_true_node(_node, _scope) = Types::TRUE + + def evaluate_false_node(_node, _scope) = Types::FALSE + + def evaluate_nil_node(_node, _scope) = Types::NIL + + def evaluate_source_file_node(_node, _scope) = Types::STRING + + def evaluate_source_line_node(_node, _scope) = Types::INTEGER + + def evaluate_source_encoding_node(_node, _scope) = Types::InstanceType.new(Encoding) + + def evaluate_numbered_reference_read_node(_node, _scope) + Types::UnionType[Types::STRING, Types::NIL] + end + + def evaluate_back_reference_read_node(_node, _scope) + Types::UnionType[Types::STRING, Types::NIL] + end + + def evaluate_reference_read(node, scope) + scope[node.name.to_s] || Types::NIL + end + alias evaluate_constant_read_node evaluate_reference_read + alias evaluate_global_variable_read_node evaluate_reference_read + alias evaluate_local_variable_read_node evaluate_reference_read + alias evaluate_class_variable_read_node evaluate_reference_read + alias evaluate_instance_variable_read_node evaluate_reference_read + + + def evaluate_call_node(node, scope) + is_field_assign = node.name.match?(/[^<>=!\]]=\z/) || (node.name == :[]= && !node.call_operator) + receiver_type = node.receiver ? evaluate(node.receiver, scope) : scope.self_type + evaluate_method = lambda do |scope| + args_types, kwargs_types, block_sym_node, has_block = evaluate_call_node_arguments node, scope + + if block_sym_node + block_sym = block_sym_node.value + if @dig_targets.target? block_sym_node + # method(args, &:completion_target) + call_block_proc = ->(block_args, _self_type) do + block_receiver = block_args.first || Types::OBJECT + @dig_targets.resolve block_receiver, scope + Types::OBJECT + end + else + call_block_proc = ->(block_args, _self_type) do + block_receiver, *rest = block_args + block_receiver ? method_call(block_receiver || Types::OBJECT, block_sym, rest, nil, nil, scope) : Types::OBJECT + end + end + elsif node.block.is_a? Prism::BlockNode + call_block_proc = ->(block_args, block_self_type) do + scope.conditional do |s| + numbered_parameters = node.block.locals.grep(/\A_[1-9]/).map(&:to_s) + params_table = node.block.locals.to_h { [_1.to_s, Types::NIL] } + table = { **params_table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil } + block_scope = Scope.new s, table, self_type: block_self_type, trace_ivar: !block_self_type + # TODO kwargs + if node.block.parameters&.parameters + # node.block.parameters is Prism::BlockParametersNode + assign_parameters node.block.parameters.parameters, block_scope, block_args, {} + elsif !numbered_parameters.empty? + assign_numbered_parameters numbered_parameters, block_scope, block_args, {} + end + result = node.block.body ? evaluate(node.block.body, block_scope) : Types::NIL + block_scope.merge_jumps + s.update block_scope + nexts = block_scope[Scope::NEXT_RESULT] + breaks = block_scope[Scope::BREAK_RESULT] + if block_scope.terminated? + [Types::UnionType[*nexts], breaks] + else + [Types::UnionType[result, *nexts], breaks] + end + end + end + elsif has_block + call_block_proc = ->(_block_args, _self_type) { Types::OBJECT } + end + result = method_call receiver_type, node.name, args_types, kwargs_types, call_block_proc, scope + if is_field_assign + args_types.last || Types::NIL + else + result + end + end + if node.call_operator == '&.' + result = scope.conditional { evaluate_method.call _1 } + if receiver_type.nillable? + Types::UnionType[result, Types::NIL] + else + result + end + else + evaluate_method.call scope + end + end + + def evaluate_and_node(node, scope) = evaluate_and_or(node, scope, and_op: true) + def evaluate_or_node(node, scope) = evaluate_and_or(node, scope, and_op: false) + def evaluate_and_or(node, scope, and_op:) + left = evaluate node.left, scope + right = scope.conditional { evaluate node.right, _1 } + if and_op + Types::UnionType[right, Types::NIL, Types::FALSE] + else + Types::UnionType[left, right] + end + end + + def evaluate_call_operator_write_node(node, scope) = evaluate_call_write(node, scope, :operator, node.write_name) + def evaluate_call_and_write_node(node, scope) = evaluate_call_write(node, scope, :and, node.write_name) + def evaluate_call_or_write_node(node, scope) = evaluate_call_write(node, scope, :or, node.write_name) + def evaluate_index_operator_write_node(node, scope) = evaluate_call_write(node, scope, :operator, :[]=) + def evaluate_index_and_write_node(node, scope) = evaluate_call_write(node, scope, :and, :[]=) + def evaluate_index_or_write_node(node, scope) = evaluate_call_write(node, scope, :or, :[]=) + def evaluate_call_write(node, scope, operator, write_name) + receiver_type = evaluate node.receiver, scope + if write_name == :[]= + args_types, kwargs_types, block_sym_node, has_block = evaluate_call_node_arguments node, scope + else + args_types = [] + end + if block_sym_node + block_sym = block_sym_node.value + call_block_proc = ->(block_args, _self_type) do + block_receiver, *rest = block_args + block_receiver ? method_call(block_receiver || Types::OBJECT, block_sym, rest, nil, nil, scope) : Types::OBJECT + end + elsif has_block + call_block_proc = ->(_block_args, _self_type) { Types::OBJECT } + end + method = write_name.to_s.delete_suffix('=') + left = method_call receiver_type, method, args_types, kwargs_types, call_block_proc, scope + case operator + when :and + right = scope.conditional { evaluate node.value, _1 } + Types::UnionType[right, Types::NIL, Types::FALSE] + when :or + right = scope.conditional { evaluate node.value, _1 } + Types::UnionType[left, right] + else + right = evaluate node.value, scope + method_call left, node.operator, [right], nil, nil, scope, name_match: false + end + end + + def evaluate_variable_operator_write(node, scope) + left = scope[node.name.to_s] || Types::OBJECT + right = evaluate node.value, scope + scope[node.name.to_s] = method_call left, node.operator, [right], nil, nil, scope, name_match: false + end + alias evaluate_global_variable_operator_write_node evaluate_variable_operator_write + alias evaluate_local_variable_operator_write_node evaluate_variable_operator_write + alias evaluate_class_variable_operator_write_node evaluate_variable_operator_write + alias evaluate_instance_variable_operator_write_node evaluate_variable_operator_write + + def evaluate_variable_and_write(node, scope) + right = scope.conditional { evaluate node.value, scope } + scope[node.name.to_s] = Types::UnionType[right, Types::NIL, Types::FALSE] + end + alias evaluate_global_variable_and_write_node evaluate_variable_and_write + alias evaluate_local_variable_and_write_node evaluate_variable_and_write + alias evaluate_class_variable_and_write_node evaluate_variable_and_write + alias evaluate_instance_variable_and_write_node evaluate_variable_and_write + + def evaluate_variable_or_write(node, scope) + left = scope[node.name.to_s] || Types::OBJECT + right = scope.conditional { evaluate node.value, scope } + scope[node.name.to_s] = Types::UnionType[left, right] + end + alias evaluate_global_variable_or_write_node evaluate_variable_or_write + alias evaluate_local_variable_or_write_node evaluate_variable_or_write + alias evaluate_class_variable_or_write_node evaluate_variable_or_write + alias evaluate_instance_variable_or_write_node evaluate_variable_or_write + + def evaluate_constant_operator_write_node(node, scope) + left = scope[node.name.to_s] || Types::OBJECT + right = evaluate node.value, scope + scope[node.name.to_s] = method_call left, node.operator, [right], nil, nil, scope, name_match: false + end + + def evaluate_constant_and_write_node(node, scope) + right = scope.conditional { evaluate node.value, scope } + scope[node.name.to_s] = Types::UnionType[right, Types::NIL, Types::FALSE] + end + + def evaluate_constant_or_write_node(node, scope) + left = scope[node.name.to_s] || Types::OBJECT + right = scope.conditional { evaluate node.value, scope } + scope[node.name.to_s] = Types::UnionType[left, right] + end + + def evaluate_constant_path_operator_write_node(node, scope) + left, receiver, _parent_module, name = evaluate_constant_node_info node.target, scope + right = evaluate node.value, scope + value = method_call left, node.operator, [right], nil, nil, scope, name_match: false + const_path_write receiver, name, value, scope + value + end + + def evaluate_constant_path_and_write_node(node, scope) + _left, receiver, _parent_module, name = evaluate_constant_node_info node.target, scope + right = scope.conditional { evaluate node.value, scope } + value = Types::UnionType[right, Types::NIL, Types::FALSE] + const_path_write receiver, name, value, scope + value + end + + def evaluate_constant_path_or_write_node(node, scope) + left, receiver, _parent_module, name = evaluate_constant_node_info node.target, scope + right = scope.conditional { evaluate node.value, scope } + value = Types::UnionType[left, right] + const_path_write receiver, name, value, scope + value + end + + def evaluate_constant_path_write_node(node, scope) + receiver = evaluate node.target.parent, scope if node.target.parent + value = evaluate node.value, scope + const_path_write receiver, node.target.child.name.to_s, value, scope + value + end + + def evaluate_lambda_node(node, scope) + local_table = node.locals.to_h { [_1.to_s, Types::OBJECT] } + block_scope = Scope.new scope, { **local_table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil, Scope::RETURN_RESULT => nil } + block_scope.conditional do |s| + assign_parameters node.parameters.parameters, s, [], {} if node.parameters&.parameters + evaluate node.body, s if node.body + end + block_scope.merge_jumps + scope.update block_scope + Types::PROC + end + + def evaluate_reference_write(node, scope) + scope[node.name.to_s] = evaluate node.value, scope + end + alias evaluate_constant_write_node evaluate_reference_write + alias evaluate_global_variable_write_node evaluate_reference_write + alias evaluate_local_variable_write_node evaluate_reference_write + alias evaluate_class_variable_write_node evaluate_reference_write + alias evaluate_instance_variable_write_node evaluate_reference_write + + def evaluate_multi_write_node(node, scope) + evaluated_receivers = {} + evaluate_multi_write_receiver node, scope, evaluated_receivers + value = ( + if node.value.is_a? Prism::ArrayNode + if node.value.elements.any?(Prism::SplatNode) + evaluate node.value, scope + else + node.value.elements.map do |n| + evaluate n, scope + end + end + elsif node.value + evaluate node.value, scope + else + Types::NIL + end + ) + evaluate_multi_write node, value, scope, evaluated_receivers + value.is_a?(Array) ? Types.array_of(*value) : value + end + + def evaluate_if_node(node, scope) = evaluate_if_unless(node, scope) + def evaluate_unless_node(node, scope) = evaluate_if_unless(node, scope) + def evaluate_if_unless(node, scope) + evaluate node.predicate, scope + Types::UnionType[*scope.run_branches( + -> { node.statements ? evaluate(node.statements, _1) : Types::NIL }, + -> { node.consequent ? evaluate(node.consequent, _1) : Types::NIL } + )] + end + + def evaluate_else_node(node, scope) + node.statements ? evaluate(node.statements, scope) : Types::NIL + end + + def evaluate_while_until(node, scope) + inner_scope = Scope.new scope, { Scope::BREAK_RESULT => nil } + evaluate node.predicate, inner_scope + if node.statements + inner_scope.conditional do |s| + evaluate node.statements, s + end + end + inner_scope.merge_jumps + scope.update inner_scope + breaks = inner_scope[Scope::BREAK_RESULT] + breaks ? Types::UnionType[breaks, Types::NIL] : Types::NIL + end + alias evaluate_while_node evaluate_while_until + alias evaluate_until_node evaluate_while_until + + def evaluate_break_node(node, scope) = evaluate_jump(node, scope, :break) + def evaluate_next_node(node, scope) = evaluate_jump(node, scope, :next) + def evaluate_return_node(node, scope) = evaluate_jump(node, scope, :return) + def evaluate_jump(node, scope, mode) + internal_key = ( + case mode + when :break + Scope::BREAK_RESULT + when :next + Scope::NEXT_RESULT + when :return + Scope::RETURN_RESULT + end + ) + jump_value = ( + arguments = node.arguments&.arguments + if arguments.nil? || arguments.empty? + Types::NIL + elsif arguments.size == 1 && !arguments.first.is_a?(Prism::SplatNode) + evaluate arguments.first, scope + else + Types.array_of evaluate_list_splat_items(arguments, scope) + end + ) + scope.terminate_with internal_key, jump_value + Types::NIL + end + + def evaluate_yield_node(node, scope) + evaluate_list_splat_items node.arguments.arguments, scope if node.arguments + Types::OBJECT + end + + def evaluate_redo_node(_node, scope) + scope.terminate + Types::NIL + end + + def evaluate_retry_node(_node, scope) + scope.terminate + Types::NIL + end + + def evaluate_forwarding_super_node(_node, _scope) = Types::OBJECT + + def evaluate_super_node(node, scope) + evaluate_list_splat_items node.arguments.arguments, scope if node.arguments + Types::OBJECT + end + + def evaluate_begin_node(node, scope) + return_type = node.statements ? evaluate(node.statements, scope) : Types::NIL + if node.rescue_clause + if node.else_clause + return_types = scope.run_branches( + ->{ evaluate node.rescue_clause, _1 }, + ->{ evaluate node.else_clause, _1 } + ) + else + return_types = [ + return_type, + scope.conditional { evaluate node.rescue_clause, _1 } + ] + end + return_type = Types::UnionType[*return_types] + end + if node.ensure_clause&.statements + # ensure_clause is Prism::EnsureNode + evaluate node.ensure_clause.statements, scope + end + return_type + end + + def evaluate_rescue_node(node, scope) + run_rescue = lambda do |s| + if node.reference + error_classes_type = evaluate_list_splat_items node.exceptions, s + error_types = error_classes_type.types.filter_map do + Types::InstanceType.new _1.module_or_class if _1.is_a?(Types::SingletonType) + end + error_types << Types::InstanceType.new(StandardError) if error_types.empty? + error_type = Types::UnionType[*error_types] + case node.reference + when Prism::LocalVariableTargetNode, Prism::InstanceVariableTargetNode, Prism::ClassVariableTargetNode, Prism::GlobalVariableTargetNode, Prism::ConstantTargetNode + s[node.reference.name.to_s] = error_type + when Prism::CallNode + evaluate node.reference, s + end + end + node.statements ? evaluate(node.statements, s) : Types::NIL + end + if node.consequent # begin; rescue A; rescue B; end + types = scope.run_branches( + run_rescue, + -> { evaluate node.consequent, _1 } + ) + Types::UnionType[*types] + else + run_rescue.call scope + end + end + + def evaluate_rescue_modifier_node(node, scope) + a = evaluate node.expression, scope + b = scope.conditional { evaluate node.rescue_expression, _1 } + Types::UnionType[a, b] + end + + def evaluate_singleton_class_node(node, scope) + klass_types = evaluate(node.expression, scope).types.filter_map do |type| + Types::SingletonType.new type.klass if type.is_a? Types::InstanceType + end + klass_types = [Types::CLASS] if klass_types.empty? + table = node.locals.to_h { [_1.to_s, Types::NIL] } + sclass_scope = Scope.new( + scope, + { **table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil, Scope::RETURN_RESULT => nil }, + trace_ivar: false, + trace_lvar: false, + self_type: Types::UnionType[*klass_types] + ) + result = node.body ? evaluate(node.body, sclass_scope) : Types::NIL + scope.update sclass_scope + result + end + + def evaluate_class_node(node, scope) = evaluate_class_module(node, scope, true) + def evaluate_module_node(node, scope) = evaluate_class_module(node, scope, false) + def evaluate_class_module(node, scope, is_class) + unless node.constant_path.is_a?(Prism::ConstantReadNode) || node.constant_path.is_a?(Prism::ConstantPathNode) + # Incomplete class/module `class (statement[cursor_here])::Name; end` + evaluate node.constant_path, scope + return Types::NIL + end + const_type, _receiver, parent_module, name = evaluate_constant_node_info node.constant_path, scope + if is_class + select_class_type = -> { _1.is_a?(Types::SingletonType) && _1.module_or_class.is_a?(Class) } + module_types = const_type.types.select(&select_class_type) + module_types += evaluate(node.superclass, scope).types.select(&select_class_type) if node.superclass + module_types << Types::CLASS if module_types.empty? + else + module_types = const_type.types.select { _1.is_a?(Types::SingletonType) && !_1.module_or_class.is_a?(Class) } + module_types << Types::MODULE if module_types.empty? + end + return Types::NIL unless node.body + + table = node.locals.to_h { [_1.to_s, Types::NIL] } + if !name.empty? && (parent_module.is_a?(Module) || parent_module.nil?) + value = parent_module.const_get name if parent_module&.const_defined? name + unless value + value_type = scope[name] + value = value_type.module_or_class if value_type.is_a? Types::SingletonType + end + + if value.is_a? Module + nesting = [value, []] + else + if parent_module + nesting = [parent_module, [name]] + else + parent_nesting, parent_path = scope.module_nesting.first + nesting = [parent_nesting, parent_path + [name]] + end + nesting_key = [nesting[0].__id__, nesting[1]].join('::') + nesting_value = is_class ? Types::CLASS : Types::MODULE + end + else + # parent_module == :unknown + # TODO: dummy module + end + module_scope = Scope.new( + scope, + { **table, Scope::BREAK_RESULT => nil, Scope::NEXT_RESULT => nil, Scope::RETURN_RESULT => nil }, + trace_ivar: false, + trace_lvar: false, + self_type: Types::UnionType[*module_types], + nesting: nesting + ) + module_scope[nesting_key] = nesting_value if nesting_value + result = evaluate(node.body, module_scope) + scope.update module_scope + result + end + + def evaluate_for_node(node, scope) + node.statements + collection = evaluate node.collection, scope + inner_scope = Scope.new scope, { Scope::BREAK_RESULT => nil } + ary_type = method_call collection, :to_ary, [], nil, nil, nil, name_match: false + element_types = ary_type.types.filter_map do |ary| + ary.params[:Elem] if ary.is_a?(Types::InstanceType) && ary.klass == Array + end + element_type = Types::UnionType[*element_types] + inner_scope.conditional do |s| + evaluate_write node.index, element_type, s, nil + evaluate node.statements, s if node.statements + end + inner_scope.merge_jumps + scope.update inner_scope + breaks = inner_scope[Scope::BREAK_RESULT] + breaks ? Types::UnionType[breaks, collection] : collection + end + + def evaluate_case_node(node, scope) + target = evaluate(node.predicate, scope) if node.predicate + # TODO + branches = node.conditions.map do |condition| + ->(s) { evaluate_case_match target, condition, s } + end + if node.consequent + branches << ->(s) { evaluate node.consequent, s } + elsif node.conditions.any? { _1.is_a? Prism::WhenNode } + branches << ->(_s) { Types::NIL } + end + Types::UnionType[*scope.run_branches(*branches)] + end + + def evaluate_match_required_node(node, scope) + value_type = evaluate node.value, scope + evaluate_match_pattern value_type, node.pattern, scope + Types::NIL # void value + end + + def evaluate_match_predicate_node(node, scope) + value_type = evaluate node.value, scope + scope.conditional { evaluate_match_pattern value_type, node.pattern, _1 } + Types::BOOLEAN + end + + def evaluate_range_node(node, scope) + beg_type = evaluate node.left, scope if node.left + end_type = evaluate node.right, scope if node.right + elem = (Types::UnionType[*[beg_type, end_type].compact]).nonnillable + Types::InstanceType.new Range, Elem: elem + end + + def evaluate_defined_node(node, scope) + scope.conditional { evaluate node.value, _1 } + Types::UnionType[Types::STRING, Types::NIL] + end + + def evaluate_flip_flop_node(node, scope) + scope.conditional { evaluate node.left, _1 } if node.left + scope.conditional { evaluate node.right, _1 } if node.right + Types::BOOLEAN + end + + def evaluate_multi_target_node(node, scope) + # Raw MultiTargetNode, incomplete code like `a,b`, `*a`. + evaluate_multi_write_receiver node, scope, nil + Types::NIL + end + + def evaluate_splat_node(node, scope) + # Raw SplatNode, incomplete code like `*a.` + evaluate_multi_write_receiver node.expression, scope, nil if node.expression + Types::NIL + end + + def evaluate_implicit_node(node, scope) + evaluate node.value, scope + end + + def evaluate_match_write_node(node, scope) + # /(?)(?)/ =~ string + evaluate node.call, scope + node.locals.each { scope[_1.to_s] = Types::UnionType[Types::STRING, Types::NIL] } + Types::BOOLEAN + end + + def evaluate_match_last_line_node(_node, _scope) + Types::BOOLEAN + end + + def evaluate_interpolated_match_last_line_node(node, scope) + node.parts.each { evaluate _1, scope } + Types::BOOLEAN + end + + def evaluate_pre_execution_node(node, scope) + node.statements ? evaluate(node.statements, scope) : Types::NIL + end + + def evaluate_post_execution_node(node, scope) + node.statements && @dig_targets.dig?(node.statements) ? evaluate(node.statements, scope) : Types::NIL + end + + def evaluate_alias_method_node(_node, _scope) = Types::NIL + def evaluate_alias_global_variable_node(_node, _scope) = Types::NIL + def evaluate_undef_node(_node, _scope) = Types::NIL + def evaluate_missing_node(_node, _scope) = Types::NIL + + def evaluate_call_node_arguments(call_node, scope) + # call_node.arguments is Prism::ArgumentsNode + arguments = call_node.arguments&.arguments&.dup || [] + block_arg = call_node.block.expression if call_node.block.is_a?(Prism::BlockArgumentNode) + kwargs = arguments.pop.elements if arguments.last.is_a?(Prism::KeywordHashNode) + args_types = arguments.map do |arg| + case arg + when Prism::ForwardingArgumentsNode + # `f(a, ...)` treat like splat + nil + when Prism::SplatNode + evaluate arg.expression, scope if arg.expression + nil # TODO: splat + else + evaluate arg, scope + end + end + if kwargs + kwargs_types = kwargs.map do |arg| + case arg + when Prism::AssocNode + if arg.key.is_a?(Prism::SymbolNode) + [arg.key.value, evaluate(arg.value, scope)] + else + evaluate arg.key, scope + evaluate arg.value, scope + nil + end + when Prism::AssocSplatNode + evaluate arg.value, scope if arg.value + nil + end + end.compact.to_h + end + if block_arg.is_a? Prism::SymbolNode + block_sym_node = block_arg + elsif block_arg + evaluate block_arg, scope + end + [args_types, kwargs_types, block_sym_node, !!block_arg] + end + + def const_path_write(receiver, name, value, scope) + if receiver # receiver::A = value + singleton_type = receiver.types.find { _1.is_a? Types::SingletonType } + scope.set_const singleton_type.module_or_class, name, value if singleton_type + else # ::A = value + scope.set_const Object, name, value + end + end + + def assign_required_parameter(node, value, scope) + case node + when Prism::RequiredParameterNode + scope[node.name.to_s] = value || Types::OBJECT + when Prism::MultiTargetNode + parameters = [*node.lefts, *node.rest, *node.rights] + values = value ? sized_splat(value, :to_ary, parameters.size) : [] + parameters.zip values do |n, v| + assign_required_parameter n, v, scope + end + when Prism::SplatNode + splat_value = value ? Types.array_of(value) : Types::ARRAY + assign_required_parameter node.expression, splat_value, scope if node.expression + end + end + + def evaluate_constant_node_info(node, scope) + case node + when Prism::ConstantPathNode + name = node.child.name.to_s + if node.parent + receiver = evaluate node.parent, scope + if receiver.is_a? Types::SingletonType + parent_module = receiver.module_or_class + end + else + parent_module = Object + end + if parent_module + type = scope.get_const(parent_module, [name]) || Types::NIL + else + parent_module = :unknown + type = Types::NIL + end + when Prism::ConstantReadNode + name = node.name.to_s + type = scope[name] + end + @dig_targets.resolve type, scope if @dig_targets.target? node + [type, receiver, parent_module, name] + end + + + def assign_parameters(node, scope, args, kwargs) + args = args.dup + kwargs = kwargs.dup + size = node.requireds.size + node.optionals.size + (node.rest ? 1 : 0) + node.posts.size + args = sized_splat(args.first, :to_ary, size) if size >= 2 && args.size == 1 + reqs = args.shift node.requireds.size + if node.rest + # node.rest is Prism::RestParameterNode + posts = [] + opts = args.shift node.optionals.size + rest = args + else + posts = args.pop node.posts.size + opts = args + rest = [] + end + node.requireds.zip reqs do |n, v| + assign_required_parameter n, v, scope + end + node.optionals.zip opts do |n, v| + # n is Prism::OptionalParameterNode + values = [v] + values << evaluate(n.value, scope) if n.value + scope[n.name.to_s] = Types::UnionType[*values.compact] + end + node.posts.zip posts do |n, v| + assign_required_parameter n, v, scope + end + if node.rest&.name + # node.rest is Prism::RestParameterNode + scope[node.rest.name.to_s] = Types.array_of(*rest) + end + node.keywords.each do |n| + name = n.name.to_s.delete(':') + values = [kwargs.delete(name)] + # n is Prism::OptionalKeywordParameterNode (has n.value) or Prism::RequiredKeywordParameterNode (does not have n.value) + values << evaluate(n.value, scope) if n.respond_to?(:value) + scope[name] = Types::UnionType[*values.compact] + end + # node.keyword_rest is Prism::KeywordRestParameterNode or Prism::ForwardingParameterNode or Prism::NoKeywordsParameterNode + if node.keyword_rest.is_a?(Prism::KeywordRestParameterNode) && node.keyword_rest.name + scope[node.keyword_rest.name.to_s] = Types::InstanceType.new(Hash, K: Types::SYMBOL, V: Types::UnionType[*kwargs.values]) + end + if node.block&.name + # node.block is Prism::BlockParameterNode + scope[node.block.name.to_s] = Types::PROC + end + end + + def assign_numbered_parameters(numbered_parameters, scope, args, _kwargs) + return if numbered_parameters.empty? + max_num = numbered_parameters.map { _1[1].to_i }.max + if max_num == 1 + scope['_1'] = args.first || Types::NIL + else + args = sized_splat(args.first, :to_ary, max_num) if args.size == 1 + numbered_parameters.each do |name| + index = name[1].to_i - 1 + scope[name] = args[index] || Types::NIL + end + end + end + + def evaluate_case_match(target, node, scope) + case node + when Prism::WhenNode + node.conditions.each { evaluate _1, scope } + node.statements ? evaluate(node.statements, scope) : Types::NIL + when Prism::InNode + pattern = node.pattern + if pattern.is_a?(Prism::IfNode) || pattern.is_a?(Prism::UnlessNode) + cond_node = pattern.predicate + pattern = pattern.statements.body.first + end + evaluate_match_pattern(target, pattern, scope) + evaluate cond_node, scope if cond_node # TODO: conditional branch + node.statements ? evaluate(node.statements, scope) : Types::NIL + end + end + + def evaluate_match_pattern(value, pattern, scope) + # TODO: scope.terminate_with Scope::PATTERNMATCH_BREAK, Types::NIL + case pattern + when Prism::FindPatternNode + # TODO + evaluate_match_pattern Types::OBJECT, pattern.left, scope + pattern.requireds.each { evaluate_match_pattern Types::OBJECT, _1, scope } + evaluate_match_pattern Types::OBJECT, pattern.right, scope + when Prism::ArrayPatternNode + # TODO + pattern.requireds.each { evaluate_match_pattern Types::OBJECT, _1, scope } + evaluate_match_pattern Types::OBJECT, pattern.rest, scope if pattern.rest + pattern.posts.each { evaluate_match_pattern Types::OBJECT, _1, scope } + Types::ARRAY + when Prism::HashPatternNode + # TODO + pattern.elements.each { evaluate_match_pattern Types::OBJECT, _1, scope } + if pattern.respond_to?(:rest) && pattern.rest + evaluate_match_pattern Types::OBJECT, pattern.rest, scope + end + Types::HASH + when Prism::AssocNode + evaluate_match_pattern value, pattern.value, scope if pattern.value + Types::OBJECT + when Prism::AssocSplatNode + # TODO + evaluate_match_pattern Types::HASH, pattern.value, scope + Types::OBJECT + when Prism::PinnedVariableNode + evaluate pattern.variable, scope + when Prism::PinnedExpressionNode + evaluate pattern.expression, scope + when Prism::LocalVariableTargetNode + scope[pattern.name.to_s] = value + when Prism::AlternationPatternNode + Types::UnionType[evaluate_match_pattern(value, pattern.left, scope), evaluate_match_pattern(value, pattern.right, scope)] + when Prism::CapturePatternNode + capture_type = class_or_value_to_instance evaluate_match_pattern(value, pattern.value, scope) + value = capture_type unless capture_type.types.empty? || capture_type.types == [Types::OBJECT] + evaluate_match_pattern value, pattern.target, scope + when Prism::SplatNode + value = Types.array_of value + evaluate_match_pattern value, pattern.expression, scope if pattern.expression + value + else + # literal node + type = evaluate(pattern, scope) + class_or_value_to_instance(type) + end + end + + def class_or_value_to_instance(type) + instance_types = type.types.map do |t| + t.is_a?(Types::SingletonType) ? Types::InstanceType.new(t.module_or_class) : t + end + Types::UnionType[*instance_types] + end + + def evaluate_write(node, value, scope, evaluated_receivers) + case node + when Prism::MultiTargetNode + evaluate_multi_write node, value, scope, evaluated_receivers + when Prism::CallNode + evaluated_receivers&.[](node.receiver) || evaluate(node.receiver, scope) if node.receiver + when Prism::SplatNode + evaluate_write node.expression, Types.array_of(value), scope, evaluated_receivers if node.expression + when Prism::LocalVariableTargetNode, Prism::GlobalVariableTargetNode, Prism::InstanceVariableTargetNode, Prism::ClassVariableTargetNode, Prism::ConstantTargetNode + scope[node.name.to_s] = value + when Prism::ConstantPathTargetNode + receiver = evaluated_receivers&.[](node.parent) || evaluate(node.parent, scope) if node.parent + const_path_write receiver, node.child.name.to_s, value, scope + value + end + end + + def evaluate_multi_write(node, values, scope, evaluated_receivers) + pre_targets = node.lefts + splat_target = node.rest + post_targets = node.rights + size = pre_targets.size + (splat_target ? 1 : 0) + post_targets.size + values = values.is_a?(Array) ? values.dup : sized_splat(values, :to_ary, size) + pre_pairs = pre_targets.zip(values.shift(pre_targets.size)) + post_pairs = post_targets.zip(values.pop(post_targets.size)) + splat_pairs = splat_target ? [[splat_target, Types::UnionType[*values]]] : [] + (pre_pairs + splat_pairs + post_pairs).each do |target, value| + evaluate_write target, value || Types::NIL, scope, evaluated_receivers + end + end + + def evaluate_multi_write_receiver(node, scope, evaluated_receivers) + case node + when Prism::MultiWriteNode, Prism::MultiTargetNode + targets = [*node.lefts, *node.rest, *node.rights] + targets.each { evaluate_multi_write_receiver _1, scope, evaluated_receivers } + when Prism::CallNode + if node.receiver + receiver = evaluate(node.receiver, scope) + evaluated_receivers[node.receiver] = receiver if evaluated_receivers + end + if node.arguments + node.arguments.arguments&.each do |arg| + if arg.is_a? Prism::SplatNode + evaluate arg.expression, scope + else + evaluate arg, scope + end + end + end + when Prism::SplatNode + evaluate_multi_write_receiver node.expression, scope, evaluated_receivers if node.expression + end + end + + def evaluate_list_splat_items(list, scope) + items = list.flat_map do |node| + if node.is_a? Prism::SplatNode + next unless node.expression # def f(*); [*] + + splat = evaluate node.expression, scope + array_elem, non_array = partition_to_array splat.nonnillable, :to_a + [*array_elem, *non_array] + else + evaluate node, scope + end + end.compact.uniq + Types::UnionType[*items] + end + + def sized_splat(value, method, size) + array_elem, non_array = partition_to_array value, method + values = [Types::UnionType[*array_elem, *non_array]] + values += [array_elem] * (size - 1) if array_elem && size >= 1 + values + end + + def partition_to_array(value, method) + arrays, non_arrays = value.types.partition { _1.is_a?(Types::InstanceType) && _1.klass == Array } + non_arrays.select! do |type| + to_array_result = method_call type, method, [], nil, nil, nil, name_match: false + if to_array_result.is_a?(Types::InstanceType) && to_array_result.klass == Array + arrays << to_array_result + false + else + true + end + end + array_elem = arrays.empty? ? nil : Types::UnionType[*arrays.map { _1.params[:Elem] || Types::OBJECT }] + non_array = non_arrays.empty? ? nil : Types::UnionType[*non_arrays] + [array_elem, non_array] + end + + def method_call(receiver, method_name, args, kwargs, block, scope, name_match: true) + methods = Types.rbs_methods receiver, method_name.to_sym, args, kwargs, !!block + block_called = false + type_breaks = methods.map do |method, given_params, method_params| + receiver_vars = receiver.is_a?(Types::InstanceType) ? receiver.params : {} + free_vars = method.type.free_variables - receiver_vars.keys.to_set + vars = receiver_vars.merge Types.match_free_variables(free_vars, method_params, given_params) + if block && method.block + params_type = method.block.type.required_positionals.map do |func_param| + Types.from_rbs_type func_param.type, receiver, vars + end + self_type = Types.from_rbs_type method.block.self_type, receiver, vars if method.block.self_type + block_response, breaks = block.call params_type, self_type + block_called = true + vars.merge! Types.match_free_variables(free_vars - vars.keys.to_set, [method.block.type.return_type], [block_response]) + end + if Types.method_return_bottom?(method) + [nil, breaks] + else + [Types.from_rbs_type(method.type.return_type, receiver, vars || {}), breaks] + end + end + block&.call [], nil unless block_called + terminates = !type_breaks.empty? && type_breaks.map(&:first).all?(&:nil?) + types = type_breaks.map(&:first).compact + breaks = type_breaks.map(&:last).compact + types << OBJECT_METHODS[method_name.to_sym] if name_match && OBJECT_METHODS.has_key?(method_name.to_sym) + + if method_name.to_sym == :new + receiver.types.each do |type| + if type.is_a?(Types::SingletonType) && type.module_or_class.is_a?(Class) + types << Types::InstanceType.new(type.module_or_class) + end + end + end + scope&.terminate if terminates && breaks.empty? + Types::UnionType[*types, *breaks] + end + + def self.calculate_target_type_scope(binding, parents, target) + dig_targets = DigTarget.new(parents, target) do |type, scope| + return type, scope + end + program = parents.first + scope = Scope.from_binding(binding, program.locals) + new(dig_targets).evaluate program, scope + [Types::NIL, scope] + end + end + end +end diff --git a/lib/irb/type_completion/types.rb b/lib/irb/type_completion/types.rb new file mode 100644 index 00000000000000..f0f2342ffeee3e --- /dev/null +++ b/lib/irb/type_completion/types.rb @@ -0,0 +1,426 @@ +# frozen_string_literal: true + +require_relative 'methods' + +module IRB + module TypeCompletion + module Types + OBJECT_TO_TYPE_SAMPLE_SIZE = 50 + + singleton_class.attr_reader :rbs_builder, :rbs_load_error + + def self.preload_in_thread + return if @preload_started + + @preload_started = true + Thread.new do + load_rbs_builder + end + end + + def self.load_rbs_builder + require 'rbs' + require 'rbs/cli' + loader = RBS::CLI::LibraryOptions.new.loader + loader.add path: Pathname('sig') + @rbs_builder = RBS::DefinitionBuilder.new env: RBS::Environment.from_loader(loader).resolve_type_names + rescue LoadError, StandardError => e + @rbs_load_error = e + nil + end + + def self.class_name_of(klass) + klass = klass.superclass if klass.singleton_class? + Methods::MODULE_NAME_METHOD.bind_call klass + end + + def self.rbs_search_method(klass, method_name, singleton) + klass.ancestors.each do |ancestor| + name = class_name_of ancestor + next unless name && rbs_builder + type_name = RBS::TypeName(name).absolute! + definition = (singleton ? rbs_builder.build_singleton(type_name) : rbs_builder.build_instance(type_name)) rescue nil + method = definition.methods[method_name] if definition + return method if method + end + nil + end + + def self.method_return_type(type, method_name) + receivers = type.types.map do |t| + case t + in SingletonType + [t, t.module_or_class, true] + in InstanceType + [t, t.klass, false] + end + end + types = receivers.flat_map do |receiver_type, klass, singleton| + method = rbs_search_method klass, method_name, singleton + next [] unless method + method.method_types.map do |method| + from_rbs_type(method.type.return_type, receiver_type, {}) + end + end + UnionType[*types] + end + + def self.rbs_methods(type, method_name, args_types, kwargs_type, has_block) + return [] unless rbs_builder + + receivers = type.types.map do |t| + case t + in SingletonType + [t, t.module_or_class, true] + in InstanceType + [t, t.klass, false] + end + end + has_splat = args_types.include?(nil) + methods_with_score = receivers.flat_map do |receiver_type, klass, singleton| + method = rbs_search_method klass, method_name, singleton + next [] unless method + method.method_types.map do |method_type| + score = 0 + score += 2 if !!method_type.block == has_block + reqs = method_type.type.required_positionals + opts = method_type.type.optional_positionals + rest = method_type.type.rest_positionals + trailings = method_type.type.trailing_positionals + keyreqs = method_type.type.required_keywords + keyopts = method_type.type.optional_keywords + keyrest = method_type.type.rest_keywords + args = args_types + if kwargs_type&.any? && keyreqs.empty? && keyopts.empty? && keyrest.nil? + kw_value_type = UnionType[*kwargs_type.values] + args += [InstanceType.new(Hash, K: SYMBOL, V: kw_value_type)] + end + if has_splat + score += 1 if args.count(&:itself) <= reqs.size + opts.size + trailings.size + elsif reqs.size + trailings.size <= args.size && (rest || args.size <= reqs.size + opts.size + trailings.size) + score += 2 + centers = args[reqs.size...-trailings.size] + given = args.first(reqs.size) + centers.take(opts.size) + args.last(trailings.size) + expected = (reqs + opts.take(centers.size) + trailings).map(&:type) + if rest + given << UnionType[*centers.drop(opts.size)] + expected << rest.type + end + if given.any? + score += given.zip(expected).count do |t, e| + e = from_rbs_type e, receiver_type + intersect?(t, e) || (intersect?(STRING, e) && t.methods.include?(:to_str)) || (intersect?(INTEGER, e) && t.methods.include?(:to_int)) || (intersect?(ARRAY, e) && t.methods.include?(:to_ary)) + end.fdiv(given.size) + end + end + [[method_type, given || [], expected || []], score] + end + end + max_score = methods_with_score.map(&:last).max + methods_with_score.select { _2 == max_score }.map(&:first) + end + + def self.intersect?(a, b) + atypes = a.types.group_by(&:class) + btypes = b.types.group_by(&:class) + if atypes[SingletonType] && btypes[SingletonType] + aa, bb = [atypes, btypes].map {|types| types[SingletonType].map(&:module_or_class) } + return true if (aa & bb).any? + end + + aa, bb = [atypes, btypes].map {|types| (types[InstanceType] || []).map(&:klass) } + (aa.flat_map(&:ancestors) & bb).any? + end + + def self.type_from_object(object) + case object + when Array + InstanceType.new Array, { Elem: union_type_from_objects(object) } + when Hash + InstanceType.new Hash, { K: union_type_from_objects(object.keys), V: union_type_from_objects(object.values) } + when Module + SingletonType.new object + else + klass = Methods::OBJECT_SINGLETON_CLASS_METHOD.bind_call(object) rescue Methods::OBJECT_CLASS_METHOD.bind_call(object) + InstanceType.new klass + end + end + + def self.union_type_from_objects(objects) + values = objects.size <= OBJECT_TO_TYPE_SAMPLE_SIZE ? objects : objects.sample(OBJECT_TO_TYPE_SAMPLE_SIZE) + klasses = values.map { Methods::OBJECT_CLASS_METHOD.bind_call(_1) } + UnionType[*klasses.uniq.map { InstanceType.new _1 }] + end + + class SingletonType + attr_reader :module_or_class + def initialize(module_or_class) + @module_or_class = module_or_class + end + def transform() = yield(self) + def methods() = @module_or_class.methods + def all_methods() = methods | Kernel.methods + def constants() = @module_or_class.constants + def types() = [self] + def nillable?() = false + def nonnillable() = self + def inspect + "#{module_or_class}.itself" + end + end + + class InstanceType + attr_reader :klass, :params + def initialize(klass, params = {}) + @klass = klass + @params = params + end + def transform() = yield(self) + def methods() = rbs_methods.select { _2.public? }.keys | @klass.instance_methods + def all_methods() = rbs_methods.keys | @klass.instance_methods | @klass.private_instance_methods + def constants() = [] + def types() = [self] + def nillable?() = (@klass == NilClass) + def nonnillable() = self + def rbs_methods + name = Types.class_name_of(@klass) + return {} unless name && Types.rbs_builder + + type_name = RBS::TypeName(name).absolute! + Types.rbs_builder.build_instance(type_name).methods rescue {} + end + def inspect + if params.empty? + inspect_without_params + else + params_string = "[#{params.map { "#{_1}: #{_2.inspect}" }.join(', ')}]" + "#{inspect_without_params}#{params_string}" + end + end + def inspect_without_params + if klass == NilClass + 'nil' + elsif klass == TrueClass + 'true' + elsif klass == FalseClass + 'false' + else + klass.singleton_class? ? klass.superclass.to_s : klass.to_s + end + end + end + + NIL = InstanceType.new NilClass + OBJECT = InstanceType.new Object + TRUE = InstanceType.new TrueClass + FALSE = InstanceType.new FalseClass + SYMBOL = InstanceType.new Symbol + STRING = InstanceType.new String + INTEGER = InstanceType.new Integer + RANGE = InstanceType.new Range + REGEXP = InstanceType.new Regexp + FLOAT = InstanceType.new Float + RATIONAL = InstanceType.new Rational + COMPLEX = InstanceType.new Complex + ARRAY = InstanceType.new Array + HASH = InstanceType.new Hash + CLASS = InstanceType.new Class + MODULE = InstanceType.new Module + PROC = InstanceType.new Proc + + class UnionType + attr_reader :types + + def initialize(*types) + @types = [] + singletons = [] + instances = {} + collect = -> type do + case type + in UnionType + type.types.each(&collect) + in InstanceType + params = (instances[type.klass] ||= {}) + type.params.each do |k, v| + (params[k] ||= []) << v + end + in SingletonType + singletons << type + end + end + types.each(&collect) + @types = singletons.uniq + instances.map do |klass, params| + InstanceType.new(klass, params.transform_values { |v| UnionType[*v] }) + end + end + + def transform(&block) + UnionType[*types.map(&block)] + end + + def nillable? + types.any?(&:nillable?) + end + + def nonnillable + UnionType[*types.reject { _1.is_a?(InstanceType) && _1.klass == NilClass }] + end + + def self.[](*types) + type = new(*types) + if type.types.empty? + OBJECT + elsif type.types.size == 1 + type.types.first + else + type + end + end + + def methods() = @types.flat_map(&:methods).uniq + def all_methods() = @types.flat_map(&:all_methods).uniq + def constants() = @types.flat_map(&:constants).uniq + def inspect() = @types.map(&:inspect).join(' | ') + end + + BOOLEAN = UnionType[TRUE, FALSE] + + def self.array_of(*types) + type = types.size >= 2 ? UnionType[*types] : types.first || OBJECT + InstanceType.new Array, Elem: type + end + + def self.from_rbs_type(return_type, self_type, extra_vars = {}) + case return_type + when RBS::Types::Bases::Self + self_type + when RBS::Types::Bases::Bottom, RBS::Types::Bases::Nil + NIL + when RBS::Types::Bases::Any, RBS::Types::Bases::Void + OBJECT + when RBS::Types::Bases::Class + self_type.transform do |type| + case type + in SingletonType + InstanceType.new(self_type.module_or_class.is_a?(Class) ? Class : Module) + in InstanceType + SingletonType.new type.klass + end + end + UnionType[*types] + when RBS::Types::Bases::Bool + BOOLEAN + when RBS::Types::Bases::Instance + self_type.transform do |type| + if type.is_a?(SingletonType) && type.module_or_class.is_a?(Class) + InstanceType.new type.module_or_class + else + OBJECT + end + end + when RBS::Types::Union + UnionType[*return_type.types.map { from_rbs_type _1, self_type, extra_vars }] + when RBS::Types::Proc + PROC + when RBS::Types::Tuple + elem = UnionType[*return_type.types.map { from_rbs_type _1, self_type, extra_vars }] + InstanceType.new Array, Elem: elem + when RBS::Types::Record + InstanceType.new Hash, K: SYMBOL, V: OBJECT + when RBS::Types::Literal + InstanceType.new return_type.literal.class + when RBS::Types::Variable + if extra_vars.key? return_type.name + extra_vars[return_type.name] + elsif self_type.is_a? InstanceType + self_type.params[return_type.name] || OBJECT + elsif self_type.is_a? UnionType + types = self_type.types.filter_map do |t| + t.params[return_type.name] if t.is_a? InstanceType + end + UnionType[*types] + else + OBJECT + end + when RBS::Types::Optional + UnionType[from_rbs_type(return_type.type, self_type, extra_vars), NIL] + when RBS::Types::Alias + case return_type.name.name + when :int + INTEGER + when :boolish + BOOLEAN + when :string + STRING + else + # TODO: ??? + OBJECT + end + when RBS::Types::Interface + # unimplemented + OBJECT + when RBS::Types::ClassInstance + klass = return_type.name.to_namespace.path.reduce(Object) { _1.const_get _2 } + if return_type.args + args = return_type.args.map { from_rbs_type _1, self_type, extra_vars } + names = rbs_builder.build_singleton(return_type.name).type_params + params = names.map.with_index { [_1, args[_2] || OBJECT] }.to_h + end + InstanceType.new klass, params || {} + end + end + + def self.method_return_bottom?(method) + method.type.return_type.is_a? RBS::Types::Bases::Bottom + end + + def self.match_free_variables(vars, types, values) + accumulator = {} + types.zip values do |t, v| + _match_free_variable(vars, t, v, accumulator) if v + end + accumulator.transform_values { UnionType[*_1] } + end + + def self._match_free_variable(vars, rbs_type, value, accumulator) + case [rbs_type, value] + in [RBS::Types::Variable,] + (accumulator[rbs_type.name] ||= []) << value if vars.include? rbs_type.name + in [RBS::Types::ClassInstance, InstanceType] + names = rbs_builder.build_singleton(rbs_type.name).type_params + names.zip(rbs_type.args).each do |name, arg| + v = value.params[name] + _match_free_variable vars, arg, v, accumulator if v + end + in [RBS::Types::Tuple, InstanceType] if value.klass == Array + v = value.params[:Elem] + rbs_type.types.each do |t| + _match_free_variable vars, t, v, accumulator + end + in [RBS::Types::Record, InstanceType] if value.klass == Hash + # TODO + in [RBS::Types::Interface,] + definition = rbs_builder.build_interface rbs_type.name + convert = {} + definition.type_params.zip(rbs_type.args).each do |from, arg| + convert[from] = arg.name if arg.is_a? RBS::Types::Variable + end + return if convert.empty? + ac = {} + definition.methods.each do |method_name, method| + return_type = method_return_type value, method_name + method.defs.each do |method_def| + interface_return_type = method_def.type.type.return_type + _match_free_variable convert, interface_return_type, return_type, ac + end + end + convert.each do |from, to| + values = ac[from] + (accumulator[to] ||= []).concat values if values + end + else + end + end + end + end +end diff --git a/test/irb/test_cmd.rb b/test/irb/test_cmd.rb index 67dcfd0a635ed1..219710c9210f34 100644 --- a/test/irb/test_cmd.rb +++ b/test/irb/test_cmd.rb @@ -90,6 +90,7 @@ def test_irb_info_multiline Ruby\sversion:\s.+\n IRB\sversion:\sirb\s.+\n InputMethod:\sAbstract\sInputMethod\n + Completion: .+\n \.irbrc\spath:\s.+\n RUBY_PLATFORM:\s.+\n East\sAsian\sAmbiguous\sWidth:\s\d\n @@ -113,6 +114,7 @@ def test_irb_info_singleline Ruby\sversion:\s.+\n IRB\sversion:\sirb\s.+\n InputMethod:\sAbstract\sInputMethod\n + Completion: .+\n \.irbrc\spath:\s.+\n RUBY_PLATFORM:\s.+\n East\sAsian\sAmbiguous\sWidth:\s\d\n @@ -139,6 +141,7 @@ def test_irb_info_multiline_without_rc_files Ruby\sversion:\s.+\n IRB\sversion:\sirb\s.+\n InputMethod:\sAbstract\sInputMethod\n + Completion: .+\n RUBY_PLATFORM:\s.+\n East\sAsian\sAmbiguous\sWidth:\s\d\n #{@is_win ? 'Code\spage:\s\d+\n' : ''} @@ -168,6 +171,7 @@ def test_irb_info_singleline_without_rc_files Ruby\sversion:\s.+\n IRB\sversion:\sirb\s.+\n InputMethod:\sAbstract\sInputMethod\n + Completion: .+\n RUBY_PLATFORM:\s.+\n East\sAsian\sAmbiguous\sWidth:\s\d\n #{@is_win ? 'Code\spage:\s\d+\n' : ''} @@ -196,6 +200,7 @@ def test_irb_info_lang Ruby\sversion: .+\n IRB\sversion:\sirb .+\n InputMethod:\sAbstract\sInputMethod\n + Completion: .+\n \.irbrc\spath: .+\n RUBY_PLATFORM: .+\n LANG\senv:\sja_JP\.UTF-8\n diff --git a/test/irb/test_context.rb b/test/irb/test_context.rb index af47bec9deae19..ce57df6cdb19fe 100644 --- a/test/irb/test_context.rb +++ b/test/irb/test_context.rb @@ -652,6 +652,24 @@ def test_lineno ], out) end + def test_build_completor + verbose, $VERBOSE = $VERBOSE, nil + original_completor = IRB.conf[:COMPLETOR] + IRB.conf[:COMPLETOR] = :regexp + assert_equal 'IRB::RegexpCompletor', @context.send(:build_completor).class.name + IRB.conf[:COMPLETOR] = :type + if RUBY_VERSION >= '3.0.0' && RUBY_ENGINE != 'truffleruby' + assert_equal 'IRB::TypeCompletion::Completor', @context.send(:build_completor).class.name + else + assert_equal 'IRB::RegexpCompletor', @context.send(:build_completor).class.name + end + IRB.conf[:COMPLETOR] = :unknown + assert_equal 'IRB::RegexpCompletor', @context.send(:build_completor).class.name + ensure + $VERBOSE = verbose + IRB.conf[:COMPLETOR] = original_completor + end + private def without_colorize diff --git a/test/irb/test_input_method.rb b/test/irb/test_input_method.rb index 2d8cfadcf57c54..e6a1b06e82e0f9 100644 --- a/test/irb/test_input_method.rb +++ b/test/irb/test_input_method.rb @@ -24,7 +24,7 @@ class RelineInputMethodTest < InputMethodTest def test_initialization Reline.completion_proc = nil Reline.dig_perfect_match_proc = nil - IRB::RelineInputMethod.new + IRB::RelineInputMethod.new(IRB::RegexpCompletor.new) assert_nil Reline.completion_append_character assert_equal '', Reline.completer_quote_characters @@ -40,7 +40,7 @@ def test_initialization_without_use_autocomplete IRB.conf[:USE_AUTOCOMPLETE] = false - IRB::RelineInputMethod.new + IRB::RelineInputMethod.new(IRB::RegexpCompletor.new) refute Reline.autocompletion assert_equal empty_proc, Reline.dialog_proc(:show_doc).dialog_proc @@ -55,7 +55,7 @@ def test_initialization_with_use_autocomplete IRB.conf[:USE_AUTOCOMPLETE] = true - IRB::RelineInputMethod.new + IRB::RelineInputMethod.new(IRB::RegexpCompletor.new) assert Reline.autocompletion assert_not_equal empty_proc, Reline.dialog_proc(:show_doc).dialog_proc @@ -71,7 +71,7 @@ def test_initialization_with_use_autocomplete_but_without_rdoc IRB.conf[:USE_AUTOCOMPLETE] = true without_rdoc do - IRB::RelineInputMethod.new + IRB::RelineInputMethod.new(IRB::RegexpCompletor.new) end assert Reline.autocompletion @@ -89,7 +89,7 @@ def setup end def display_document(target, bind) - input_method = IRB::RelineInputMethod.new + input_method = IRB::RelineInputMethod.new(IRB::RegexpCompletor.new) input_method.instance_variable_set(:@completion_params, [target, '', '', bind]) input_method.display_document(target, driver: @driver) end diff --git a/test/irb/type_completion/test_scope.rb b/test/irb/type_completion/test_scope.rb new file mode 100644 index 00000000000000..d7f9540b067b0c --- /dev/null +++ b/test/irb/type_completion/test_scope.rb @@ -0,0 +1,112 @@ +# frozen_string_literal: true + +return unless RUBY_VERSION >= '3.0.0' +return if RUBY_ENGINE == 'truffleruby' # needs endless method definition + +require 'irb/type_completion/scope' +require_relative '../helper' + +module TestIRB + class TypeCompletionScopeTest < TestCase + A, B, C, D, E, F, G, H, I, J, K = ('A'..'K').map do |name| + klass = Class.new + klass.define_singleton_method(:inspect) { name } + IRB::TypeCompletion::Types::InstanceType.new(klass) + end + + def assert_type(expected_types, type) + assert_equal [*expected_types].map(&:klass).to_set, type.types.map(&:klass).to_set + end + + def table(*local_variable_names) + local_variable_names.to_h { [_1, IRB::TypeCompletion::Types::NIL] } + end + + def base_scope + IRB::TypeCompletion::RootScope.new(binding, Object.new, []) + end + + def test_lvar + scope = IRB::TypeCompletion::Scope.new base_scope, table('a') + scope['a'] = A + assert_equal A, scope['a'] + end + + def test_conditional + scope = IRB::TypeCompletion::Scope.new base_scope, table('a') + scope.conditional do |sub_scope| + sub_scope['a'] = A + end + assert_type [A, IRB::TypeCompletion::Types::NIL], scope['a'] + end + + def test_branch + scope = IRB::TypeCompletion::Scope.new base_scope, table('a', 'b', 'c', 'd') + scope['c'] = A + scope['d'] = B + scope.run_branches( + -> { _1['a'] = _1['c'] = _1['d'] = C }, + -> { _1['a'] = _1['b'] = _1['d'] = D }, + -> { _1['a'] = _1['b'] = _1['d'] = E }, + -> { _1['a'] = _1['b'] = _1['c'] = F; _1.terminate } + ) + assert_type [C, D, E], scope['a'] + assert_type [IRB::TypeCompletion::Types::NIL, D, E], scope['b'] + assert_type [A, C], scope['c'] + assert_type [C, D, E], scope['d'] + end + + def test_scope_local_variables + scope1 = IRB::TypeCompletion::Scope.new base_scope, table('a', 'b') + scope2 = IRB::TypeCompletion::Scope.new scope1, table('b', 'c'), trace_lvar: false + scope3 = IRB::TypeCompletion::Scope.new scope2, table('c', 'd') + scope4 = IRB::TypeCompletion::Scope.new scope2, table('d', 'e') + assert_empty base_scope.local_variables + assert_equal %w[a b], scope1.local_variables.sort + assert_equal %w[b c], scope2.local_variables.sort + assert_equal %w[b c d], scope3.local_variables.sort + assert_equal %w[b c d e], scope4.local_variables.sort + end + + def test_nested_scope + scope = IRB::TypeCompletion::Scope.new base_scope, table('a', 'b', 'c') + scope['a'] = A + scope['b'] = A + scope['c'] = A + sub_scope = IRB::TypeCompletion::Scope.new scope, { 'c' => B } + assert_type A, sub_scope['a'] + + assert_type A, sub_scope['b'] + assert_type B, sub_scope['c'] + sub_scope['a'] = C + sub_scope.conditional { _1['b'] = C } + sub_scope['c'] = C + assert_type C, sub_scope['a'] + assert_type [A, C], sub_scope['b'] + assert_type C, sub_scope['c'] + scope.update sub_scope + assert_type C, scope['a'] + assert_type [A, C], scope['b'] + assert_type A, scope['c'] + end + + def test_break + scope = IRB::TypeCompletion::Scope.new base_scope, table('a') + scope['a'] = A + breakable_scope = IRB::TypeCompletion::Scope.new scope, { IRB::TypeCompletion::Scope::BREAK_RESULT => nil } + breakable_scope.conditional do |sub| + sub['a'] = B + assert_type [B], sub['a'] + sub.terminate_with IRB::TypeCompletion::Scope::BREAK_RESULT, C + sub['a'] = C + assert_type [C], sub['a'] + end + assert_type [A], breakable_scope['a'] + breakable_scope[IRB::TypeCompletion::Scope::BREAK_RESULT] = D + breakable_scope.merge_jumps + assert_type [C, D], breakable_scope[IRB::TypeCompletion::Scope::BREAK_RESULT] + scope.update breakable_scope + assert_type [A, B], scope['a'] + end + end +end diff --git a/test/irb/type_completion/test_type_analyze.rb b/test/irb/type_completion/test_type_analyze.rb new file mode 100644 index 00000000000000..c417a8ad120719 --- /dev/null +++ b/test/irb/type_completion/test_type_analyze.rb @@ -0,0 +1,697 @@ +# frozen_string_literal: true + +# Run test only when Ruby >= 3.0 and %w[prism rbs] are available +return unless RUBY_VERSION >= '3.0.0' +return if RUBY_ENGINE == 'truffleruby' # needs endless method definition +begin + require 'prism' + require 'rbs' +rescue LoadError + return +end + + +require 'irb/completion' +require 'irb/type_completion/completor' +require_relative '../helper' + +module TestIRB + class TypeCompletionAnalyzeTest < TestCase + def setup + IRB::TypeCompletion::Types.load_rbs_builder unless IRB::TypeCompletion::Types.rbs_builder + end + + def empty_binding + binding + end + + def analyze(code, binding: nil) + completor = IRB::TypeCompletion::Completor.new + def completor.handle_error(e) + raise e + end + completor.analyze(code, binding || empty_binding) + end + + def assert_analyze_type(code, type, token = nil, binding: empty_binding) + result_type, result_token = analyze(code, binding: binding) + assert_equal type, result_type + assert_equal token, result_token if token + end + + def assert_call(code, include: nil, exclude: nil, binding: nil) + raise ArgumentError if include.nil? && exclude.nil? + + result = analyze(code.strip, binding: binding) + type = result[1] if result[0] == :call + klasses = type.types.flat_map do + _1.klass.singleton_class? ? [_1.klass.superclass, _1.klass] : _1.klass + end + assert ([*include] - klasses).empty?, "Expected #{klasses} to include #{include}" if include + assert (klasses & [*exclude]).empty?, "Expected #{klasses} not to include #{exclude}" if exclude + end + + def test_lvar_ivar_gvar_cvar + assert_analyze_type('puts(x', :lvar_or_method, 'x') + assert_analyze_type('puts($', :gvar, '$') + assert_analyze_type('puts($x', :gvar, '$x') + assert_analyze_type('puts(@', :ivar, '@') + assert_analyze_type('puts(@x', :ivar, '@x') + assert_analyze_type('puts(@@', :cvar, '@@') + assert_analyze_type('puts(@@x', :cvar, '@@x') + end + + def test_rescue + assert_call '(1 rescue 1.0).', include: [Integer, Float] + assert_call 'a=""; (a=1) rescue (a=1.0); a.', include: [Integer, Float], exclude: String + assert_call 'begin; 1; rescue; 1.0; end.', include: [Integer, Float] + assert_call 'begin; 1; rescue A; 1.0; rescue B; 1i; end.', include: [Integer, Float, Complex] + assert_call 'begin; 1i; rescue; 1.0; else; 1; end.', include: [Integer, Float], exclude: Complex + assert_call 'begin; 1; rescue; 1.0; ensure; 1i; end.', include: [Integer, Float], exclude: Complex + assert_call 'begin; 1i; rescue; 1.0; else; 1; ensure; 1i; end.', include: [Integer, Float], exclude: Complex + assert_call 'a=""; begin; a=1; rescue; a=1.0; end; a.', include: [Integer, Float], exclude: [String] + assert_call 'a=""; begin; a=1; rescue; a=1.0; else; a=1r; end; a.', include: [Float, Rational], exclude: [String, Integer] + assert_call 'a=""; begin; a=1; rescue; a=1.0; else; a=1r; ensure; a = 1i; end; a.', include: Complex, exclude: [Float, Rational, String, Integer] + end + + def test_rescue_assign + assert_equal [:lvar_or_method, 'a'], analyze('begin; rescue => a')[0, 2] + assert_equal [:gvar, '$a'], analyze('begin; rescue => $a')[0, 2] + assert_equal [:ivar, '@a'], analyze('begin; rescue => @a')[0, 2] + assert_equal [:cvar, '@@a'], analyze('begin; rescue => @@a')[0, 2] + assert_equal [:const, 'A'], analyze('begin; rescue => A').values_at(0, 2) + assert_equal [:call, 'b'], analyze('begin; rescue => a.b').values_at(0, 2) + end + + def test_ref + bind = eval <<~RUBY + class (Module.new)::A + @ivar = :a + @@cvar = 'a' + binding + end + RUBY + assert_call('STDIN.', include: STDIN.singleton_class) + assert_call('$stdin.', include: $stdin.singleton_class) + assert_call('@ivar.', include: Symbol, binding: bind) + assert_call('@@cvar.', include: String, binding: bind) + lbind = eval('lvar = 1; binding') + assert_call('lvar.', include: Integer, binding: lbind) + end + + def test_self_ivar_ref + obj = Object.new + obj.instance_variable_set(:@hoge, 1) + assert_call('obj.instance_eval { @hoge.', include: Integer, binding: obj.instance_eval { binding }) + if Class.method_defined? :attached_object + bind = binding + assert_call('obj.instance_eval { @hoge.', include: Integer, binding: bind) + assert_call('@hoge = 1.0; obj.instance_eval { @hoge.', include: Integer, exclude: Float, binding: bind) + assert_call('@hoge = 1.0; obj.instance_eval { @hoge = "" }; @hoge.', include: Float, exclude: [Integer, String], binding: bind) + assert_call('@fuga = 1.0; obj.instance_eval { @fuga.', exclude: Float, binding: bind) + assert_call('@fuga = 1.0; obj.instance_eval { @fuga = "" }; @fuga.', include: Float, exclude: [Integer, String], binding: bind) + end + end + + class CVarModule + @@test_cvar = 1 + end + def test_module_cvar_ref + bind = binding + assert_call('@@foo=1; class A; @@foo.', exclude: Integer, binding: bind) + assert_call('@@foo=1; class A; @@foo=1.0; @@foo.', include: Float, exclude: Integer, binding: bind) + assert_call('@@foo=1; class A; @@foo=1.0; end; @@foo.', include: Integer, exclude: Float, binding: bind) + assert_call('module CVarModule; @@test_cvar.', include: Integer, binding: bind) + assert_call('class Array; @@foo = 1; end; class Array; @@foo.', include: Integer, binding: bind) + assert_call('class Array; class B; @@foo = 1; end; class B; @@foo.', include: Integer, binding: bind) + assert_call('class Array; class B; @@foo = 1; end; @@foo.', exclude: Integer, binding: bind) + end + + def test_lvar_singleton_method + a = 1 + b = +'' + c = Object.new + d = [a, b, c] + binding = Kernel.binding + assert_call('a.', include: Integer, exclude: String, binding: binding) + assert_call('b.', include: b.singleton_class, exclude: [Integer, Object], binding: binding) + assert_call('c.', include: c.singleton_class, exclude: [Integer, String], binding: binding) + assert_call('d.', include: d.class, exclude: [Integer, String, Object], binding: binding) + assert_call('d.sample.', include: [Integer, String, Object], exclude: [b.singleton_class, c.singleton_class], binding: binding) + end + + def test_local_variable_assign + assert_call('(a = 1).', include: Integer) + assert_call('a = 1; a = ""; a.', include: String, exclude: Integer) + assert_call('1 => a; a.', include: Integer) + end + + def test_block_symbol + assert_call('[1].map(&:', include: Integer) + assert_call('1.to_s.tap(&:', include: String) + end + + def test_union_splat + assert_call('a, = [[:a], 1, nil].sample; a.', include: [Symbol, Integer, NilClass], exclude: Object) + assert_call('[[:a], 1, nil].each do _2; _1.', include: [Symbol, Integer, NilClass], exclude: Object) + assert_call('a = [[:a], 1, nil, ("a".."b")].sample; [*a].sample.', include: [Symbol, Integer, NilClass, String], exclude: Object) + end + + def test_range + assert_call('(1..2).first.', include: Integer) + assert_call('("a".."b").first.', include: String) + assert_call('(..1.to_f).first.', include: Float) + assert_call('(1.to_s..).first.', include: String) + assert_call('(1..2.0).first.', include: [Float, Integer]) + end + + def test_conditional_assign + assert_call('a = 1; a = "" if cond; a.', include: [String, Integer], exclude: NilClass) + assert_call('a = 1 if cond; a.', include: [Integer, NilClass]) + assert_call(<<~RUBY, include: [String, Symbol], exclude: [Integer, NilClass]) + a = 1 + cond ? a = '' : a = :a + a. + RUBY + end + + def test_block + assert_call('nil.then{1}.', include: Integer, exclude: NilClass) + assert_call('nil.then(&:to_s).', include: String, exclude: NilClass) + end + + def test_block_break + assert_call('1.tap{}.', include: [Integer], exclude: NilClass) + assert_call('1.tap{break :a}.', include: [Symbol, Integer], exclude: NilClass) + assert_call('1.tap{break :a, :b}[0].', include: Symbol) + assert_call('1.tap{break :a; break "a"}.', include: [Symbol, Integer], exclude: [NilClass, String]) + assert_call('1.tap{break :a if b}.', include: [Symbol, Integer], exclude: NilClass) + assert_call('1.tap{break :a; break "a" if b}.', include: [Symbol, Integer], exclude: [NilClass, String]) + assert_call('1.tap{if cond; break :a; else; break "a"; end}.', include: [Symbol, Integer, String], exclude: NilClass) + end + + def test_instance_eval + assert_call('1.instance_eval{:a.then{self.', include: Integer, exclude: Symbol) + assert_call('1.then{:a.instance_eval{self.', include: Symbol, exclude: Integer) + end + + def test_block_next + assert_call('nil.then{1}.', include: Integer, exclude: [NilClass, Object]) + assert_call('nil.then{next 1}.', include: Integer, exclude: [NilClass, Object]) + assert_call('nil.then{next :a, :b}[0].', include: Symbol) + assert_call('nil.then{next 1; 1.0}.', include: Integer, exclude: [Float, NilClass, Object]) + assert_call('nil.then{next 1; next 1.0}.', include: Integer, exclude: [Float, NilClass, Object]) + assert_call('nil.then{1 if cond}.', include: [Integer, NilClass], exclude: Object) + assert_call('nil.then{if cond; 1; else; 1.0; end}.', include: [Integer, Float], exclude: [NilClass, Object]) + assert_call('nil.then{next 1 if cond; 1.0}.', include: [Integer, Float], exclude: [NilClass, Object]) + assert_call('nil.then{if cond; next 1; else; next 1.0; end; "a"}.', include: [Integer, Float], exclude: [String, NilClass, Object]) + assert_call('nil.then{if cond; next 1; else; next 1.0; end; next "a"}.', include: [Integer, Float], exclude: [String, NilClass, Object]) + end + + def test_vars_with_branch_termination + assert_call('a=1; tap{break; a=//}; a.', include: Integer, exclude: Regexp) + assert_call('a=1; tap{a=1.0; break; a=//}; a.', include: [Integer, Float], exclude: Regexp) + assert_call('a=1; tap{next; a=//}; a.', include: Integer, exclude: Regexp) + assert_call('a=1; tap{a=1.0; next; a=//}; a.', include: [Integer, Float], exclude: Regexp) + assert_call('a=1; while cond; break; a=//; end; a.', include: Integer, exclude: Regexp) + assert_call('a=1; while cond; a=1.0; break; a=//; end; a.', include: [Integer, Float], exclude: Regexp) + assert_call('a=1; ->{ break; a=// }; a.', include: Integer, exclude: Regexp) + assert_call('a=1; ->{ a=1.0; break; a=// }; a.', include: [Integer, Float], exclude: Regexp) + + assert_call('a=1; tap{ break; a=// if cond }; a.', include: Integer, exclude: Regexp) + assert_call('a=1; tap{ next; a=// if cond }; a.', include: Integer, exclude: Regexp) + assert_call('a=1; while cond; break; a=// if cond; end; a.', include: Integer, exclude: Regexp) + assert_call('a=1; ->{ break; a=// if cond }; a.', include: Integer, exclude: Regexp) + + assert_call('a=1; tap{if cond; a=:a; break; a=""; end; a.', include: Integer, exclude: [Symbol, String]) + assert_call('a=1; tap{if cond; a=:a; break; a=""; end; a=//}; a.', include: [Integer, Symbol, Regexp], exclude: String) + assert_call('a=1; tap{if cond; a=:a; break; a=""; else; break; end; a=//}; a.', include: [Integer, Symbol], exclude: [String, Regexp]) + assert_call('a=1; tap{if cond; a=:a; next; a=""; end; a.', include: Integer, exclude: [Symbol, String]) + assert_call('a=1; tap{if cond; a=:a; next; a=""; end; a=//}; a.', include: [Integer, Symbol, Regexp], exclude: String) + assert_call('a=1; tap{if cond; a=:a; next; a=""; else; next; end; a=//}; a.', include: [Integer, Symbol], exclude: [String, Regexp]) + assert_call('def f(a=1); if cond; a=:a; return; a=""; end; a.', include: Integer, exclude: [Symbol, String]) + assert_call('a=1; while cond; if cond; a=:a; break; a=""; end; a.', include: Integer, exclude: [Symbol, String]) + assert_call('a=1; while cond; if cond; a=:a; break; a=""; end; a=//; end; a.', include: [Integer, Symbol, Regexp], exclude: String) + assert_call('a=1; while cond; if cond; a=:a; break; a=""; else; break; end; a=//; end; a.', include: [Integer, Symbol], exclude: [String, Regexp]) + assert_call('a=1; ->{ if cond; a=:a; break; a=""; end; a.', include: Integer, exclude: [Symbol, String]) + assert_call('a=1; ->{ if cond; a=:a; break; a=""; end; a=// }; a.', include: [Integer, Symbol, Regexp], exclude: String) + assert_call('a=1; ->{ if cond; a=:a; break; a=""; else; break; end; a=// }; a.', include: [Integer, Symbol], exclude: [String, Regexp]) + + # continue evaluation on terminated branch + assert_call('a=1; tap{ a=1.0; break; a=// if cond; a.', include: [Regexp, Float], exclude: Integer) + assert_call('a=1; tap{ a=1.0; next; a=// if cond; a.', include: [Regexp, Float], exclude: Integer) + assert_call('a=1; ->{ a=1.0; break; a=// if cond; a.', include: [Regexp, Float], exclude: Integer) + assert_call('a=1; while cond; a=1.0; break; a=// if cond; a.', include: [Regexp, Float], exclude: Integer) + end + + def test_to_str_to_int + sobj = Struct.new(:to_str).new('a') + iobj = Struct.new(:to_int).new(1) + binding = Kernel.binding + assert_equal String, ([] * sobj).class + assert_equal Array, ([] * iobj).class + assert_call('([]*sobj).', include: String, exclude: Array, binding: binding) + assert_call('([]*iobj).', include: Array, exclude: String, binding: binding) + end + + def test_method_select + assert_call('([]*4).', include: Array, exclude: String) + assert_call('([]*"").', include: String, exclude: Array) + assert_call('([]*unknown).', include: [String, Array]) + assert_call('p(1).', include: Integer) + assert_call('p(1, 2).', include: Array, exclude: Integer) + assert_call('2.times.', include: Enumerator, exclude: Integer) + assert_call('2.times{}.', include: Integer, exclude: Enumerator) + end + + def test_interface_match_var + assert_call('([1]+[:a]+["a"]).sample.', include: [Integer, String, Symbol]) + end + + def test_lvar_scope + code = <<~RUBY + tap { a = :never } + a = 1 if x? + tap {|a| a = :never } + tap { a = 'maybe' } + a = {} if x? + a. + RUBY + assert_call(code, include: [Hash, Integer, String], exclude: [Symbol]) + end + + def test_lvar_scope_complex + assert_call('if cond; a = 1; else; tap { a = :a }; end; a.', include: [NilClass, Integer, Symbol], exclude: [Object]) + assert_call('def f; if cond; a = 1; return; end; tap { a = :a }; a.', include: [NilClass, Symbol], exclude: [Integer, Object]) + assert_call('def f; if cond; return; a = 1; end; tap { a = :a }; a.', include: [NilClass, Symbol], exclude: [Integer, Object]) + assert_call('def f; if cond; return; if cond; return; a = 1; end; end; tap { a = :a }; a.', include: [NilClass, Symbol], exclude: [Integer, Object]) + assert_call('def f; if cond; return; if cond; return; a = 1; end; end; tap { a = :a }; a.', include: [NilClass, Symbol], exclude: [Integer, Object]) + end + + def test_gvar_no_scope + code = <<~RUBY + tap { $a = :maybe } + $a = 'maybe' if x? + $a. + RUBY + assert_call(code, include: [Symbol, String]) + end + + def test_ivar_no_scope + code = <<~RUBY + tap { @a = :maybe } + @a = 'maybe' if x? + @a. + RUBY + assert_call(code, include: [Symbol, String]) + end + + def test_massign + assert_call('(a,=1).', include: Integer) + assert_call('(a,=[*1])[0].', include: Integer) + assert_call('(a,=[1,2])[0].', include: Integer) + assert_call('a,=[1,2]; a.', include: Integer, exclude: Array) + assert_call('a,b=[1,2]; a.', include: Integer, exclude: Array) + assert_call('a,b=[1,2]; b.', include: Integer, exclude: Array) + assert_call('a,*,b=[1,2]; a.', include: Integer, exclude: Array) + assert_call('a,*,b=[1,2]; b.', include: Integer, exclude: Array) + assert_call('a,*b=[1,2]; a.', include: Integer, exclude: Array) + assert_call('a,*b=[1,2]; b.', include: Array, exclude: Integer) + assert_call('a,*b=[1,2]; b.sample.', include: Integer) + assert_call('a,*,(*)=[1,2]; a.', include: Integer) + assert_call('*a=[1,2]; a.', include: Array, exclude: Integer) + assert_call('*a=[1,2]; a.sample.', include: Integer) + assert_call('a,*b,c=[1,2,3]; b.', include: Array, exclude: Integer) + assert_call('a,*b,c=[1,2,3]; b.sample.', include: Integer) + assert_call('a,b=(cond)?[1,2]:[:a,:b]; a.', include: [Integer, Symbol]) + assert_call('a,b=(cond)?[1,2]:[:a,:b]; b.', include: [Integer, Symbol]) + assert_call('a,b=(cond)?[1,2]:"s"; a.', include: [Integer, String]) + assert_call('a,b=(cond)?[1,2]:"s"; b.', include: Integer, exclude: String) + assert_call('a,*b=(cond)?[1,2]:"s"; a.', include: [Integer, String]) + assert_call('a,*b=(cond)?[1,2]:"s"; b.', include: Array, exclude: [Integer, String]) + assert_call('a,*b=(cond)?[1,2]:"s"; b.sample.', include: Integer, exclude: String) + assert_call('*a=(cond)?[1,2]:"s"; a.', include: Array, exclude: [Integer, String]) + assert_call('*a=(cond)?[1,2]:"s"; a.sample.', include: [Integer, String]) + assert_call('a,(b,),c=[1,[:a],4]; b.', include: Symbol) + assert_call('a,(b,(c,))=1; a.', include: Integer) + assert_call('a,(b,(*c))=1; c.', include: Array) + assert_call('(a=1).b, c = 1; a.', include: Integer) + assert_call('a, ((b=1).c, d) = 1; b.', include: Integer) + assert_call('a, b[c=1] = 1; c.', include: Integer) + assert_call('a, b[*(c=1)] = 1; c.', include: Integer) + # incomplete massign + assert_analyze_type('a,b', :lvar_or_method, 'b') + assert_call('(a=1).b, a.', include: Integer) + assert_call('a=1; *a.', include: Integer) + end + + def test_field_assign + assert_call('(a.!=1).', exclude: Integer) + assert_call('(a.b=1).', include: Integer, exclude: NilClass) + assert_call('(a&.b=1).', include: Integer) + assert_call('(nil&.b=1).', include: NilClass) + assert_call('(a[]=1).', include: Integer) + assert_call('(a[b]=1).', include: Integer) + assert_call('(a.[]=1).', exclude: Integer) + end + + def test_def + assert_call('def f; end.', include: Symbol) + assert_call('s=""; def s.f; self.', include: String) + assert_call('def (a="").f; end; a.', include: String) + assert_call('def f(a=1); a.', include: Integer) + assert_call('def f(**nil); 1.', include: Integer) + assert_call('def f((*),*); 1.', include: Integer) + assert_call('def f(a,*b); b.', include: Array) + assert_call('def f(a,x:1); x.', include: Integer) + assert_call('def f(a,x:,**); 1.', include: Integer) + assert_call('def f(a,x:,**y); y.', include: Hash) + assert_call('def f((*a)); a.', include: Array) + assert_call('def f(a,b=1,*c,d,x:0,y:,**z,&e); e.arity.', include: Integer) + assert_call('def f(...); 1.', include: Integer) + assert_call('def f(a,...); 1.', include: Integer) + assert_call('def f(...); g(...); 1.', include: Integer) + assert_call('def f(*,**,&); g(*,**,&); 1.', include: Integer) + assert_call('def f(*,**,&); {**}.', include: Hash) + assert_call('def f(*,**,&); [*,**].', include: Array) + assert_call('class Array; def f; self.', include: Array) + end + + def test_defined + assert_call('defined?(a.b+c).', include: [String, NilClass]) + assert_call('defined?(a = 1); tap { a = 1.0 }; a.', include: [Integer, Float, NilClass]) + end + + def test_ternary_operator + assert_call('condition ? 1.chr.', include: [String]) + assert_call('condition ? value : 1.chr.', include: [String]) + assert_call('condition ? cond ? cond ? value : cond ? value : 1.chr.', include: [String]) + end + + def test_block_parameter + assert_call('method { |arg = 1.chr.', include: [String]) + assert_call('method do |arg = 1.chr.', include: [String]) + assert_call('method { |arg1 = 1.|(2|3), arg2 = 1.chr.', include: [String]) + assert_call('method do |arg1 = 1.|(2|3), arg2 = 1.chr.', include: [String]) + end + + def test_self + integer_binding = 1.instance_eval { Kernel.binding } + assert_call('self.', include: [Integer], binding: integer_binding) + string = +'' + string_binding = string.instance_eval { Kernel.binding } + assert_call('self.', include: [string.singleton_class], binding: string_binding) + object = Object.new + object.instance_eval { @int = 1; @string = string } + object_binding = object.instance_eval { Kernel.binding } + assert_call('self.', include: [object.singleton_class], binding: object_binding) + assert_call('@int.', include: [Integer], binding: object_binding) + assert_call('@string.', include: [String], binding: object_binding) + end + + def test_optional_chain + assert_call('[1,nil].sample.', include: [Integer, NilClass]) + assert_call('[1,nil].sample&.', include: [Integer], exclude: [NilClass]) + assert_call('[1,nil].sample.chr.', include: [String], exclude: [NilClass]) + assert_call('[1,nil].sample&.chr.', include: [String, NilClass]) + assert_call('[1,nil].sample.chr&.ord.', include: [Integer], exclude: [NilClass]) + assert_call('a = 1; b.c(a = :a); a.', include: [Symbol], exclude: [Integer]) + assert_call('a = 1; b&.c(a = :a); a.', include: [Integer, Symbol]) + end + + def test_class_module + assert_call('class (1.', include: Integer) + assert_call('class (a=1)::B; end; a.', include: Integer) + assert_call('class Array; 1; end.', include: Integer) + assert_call('class ::Array; 1; end.', include: Integer) + assert_call('class Array::A; 1; end.', include: Integer) + assert_call('class Array; self.new.', include: Array) + assert_call('class ::Array; self.new.', include: Array) + assert_call('class Array::A; self.', include: Class) + assert_call('class (a=1)::A; end; a.', include: Integer) + assert_call('module M; 1; end.', include: Integer) + assert_call('module ::M; 1; end.', include: Integer) + assert_call('module Array::M; 1; end.', include: Integer) + assert_call('module M; self.', include: Module) + assert_call('module Array::M; self.', include: Module) + assert_call('module ::M; self.', include: Module) + assert_call('module (a=1)::M; end; a.', include: Integer) + assert_call('class << Array; 1; end.', include: Integer) + assert_call('class << a; 1; end.', include: Integer) + assert_call('a = ""; class << a; self.superclass.', include: Class) + end + + def test_constant_path + assert_call('class A; X=1; class B; X=""; X.', include: String, exclude: Integer) + assert_call('class A; X=1; class B; X=""; end; X.', include: Integer, exclude: String) + assert_call('class A; class B; X=1; end; end; class A; class B; X.', include: Integer) + assert_call('module IRB; VERSION.', include: String) + assert_call('module IRB; IRB::VERSION.', include: String) + assert_call('module IRB; VERSION=1; VERSION.', include: Integer) + assert_call('module IRB; VERSION=1; IRB::VERSION.', include: Integer) + assert_call('module IRB; module A; VERSION.', include: String) + assert_call('module IRB; module A; VERSION=1; VERSION.', include: Integer) + assert_call('module IRB; module A; VERSION=1; IRB::VERSION.', include: String) + assert_call('module IRB; module A; VERSION=1; end; VERSION.', include: String) + assert_call('module IRB; IRB=1; IRB.', include: Integer) + assert_call('module IRB; IRB=1; ::IRB::VERSION.', include: String) + module_binding = eval 'module ::IRB; binding; end' + assert_call('VERSION.', include: NilClass) + assert_call('VERSION.', include: String, binding: module_binding) + assert_call('IRB::VERSION.', include: String, binding: module_binding) + assert_call('A = 1; module M; A += 0.5; A.', include: Float) + assert_call('::A = 1; module M; A += 0.5; A.', include: Float) + assert_call('::A = 1; module M; A += 0.5; ::A.', include: Integer) + assert_call('IRB::A = 1; IRB::A += 0.5; IRB::A.', include: Float) + end + + def test_literal + assert_call('1.', include: Integer) + assert_call('1.0.', include: Float) + assert_call('1r.', include: Rational) + assert_call('1i.', include: Complex) + assert_call('true.', include: TrueClass) + assert_call('false.', include: FalseClass) + assert_call('nil.', include: NilClass) + assert_call('().', include: NilClass) + assert_call('//.', include: Regexp) + assert_call('/#{a=1}/.', include: Regexp) + assert_call('/#{a=1}/; a.', include: Integer) + assert_call(':a.', include: Symbol) + assert_call(':"#{a=1}".', include: Symbol) + assert_call(':"#{a=1}"; a.', include: Integer) + assert_call('"".', include: String) + assert_call('"#$a".', include: String) + assert_call('("a" "b").', include: String) + assert_call('"#{a=1}".', include: String) + assert_call('"#{a=1}"; a.', include: Integer) + assert_call('``.', include: String) + assert_call('`#{a=1}`.', include: String) + assert_call('`#{a=1}`; a.', include: Integer) + end + + def test_redo_retry_yield_super + assert_call('a=nil; tap do a=1; redo; a=1i; end; a.', include: Integer, exclude: Complex) + assert_call('a=nil; tap do a=1; retry; a=1i; end; a.', include: Integer, exclude: Complex) + assert_call('a = 0; a = yield; a.', include: Object, exclude: Integer) + assert_call('yield 1,(a=1); a.', include: Integer) + assert_call('a = 0; a = super; a.', include: Object, exclude: Integer) + assert_call('a = 0; a = super(1); a.', include: Object, exclude: Integer) + assert_call('super 1,(a=1); a.', include: Integer) + end + + def test_rarely_used_syntax + # FlipFlop + assert_call('if (a=1).even?..(a=1.0).even; a.', include: [Integer, Float]) + # MatchLastLine + assert_call('if /regexp/; 1.', include: Integer) + assert_call('if /reg#{a=1}exp/; a.', include: Integer) + # BlockLocalVariable + assert_call('tap do |i;a| a=1; a.', include: Integer) + # BEGIN{} END{} + assert_call('BEGIN{1.', include: Integer) + assert_call('END{1.', include: Integer) + # MatchWrite + assert_call('a=1; /(?)/=~b; a.', include: [String, NilClass], exclude: Integer) + # OperatorWrite with block `a[&b]+=c` + assert_call('a=[1]; (a[0,&:to_a]+=1.0).', include: Float) + assert_call('a=[1]; (a[0,&b]+=1.0).', include: Float) + end + + def test_hash + assert_call('{}.', include: Hash) + assert_call('{**a}.', include: Hash) + assert_call('{ rand: }.values.sample.', include: Float) + assert_call('rand=""; { rand: }.values.sample.', include: String, exclude: Float) + assert_call('{ 1 => 1.0 }.keys.sample.', include: Integer, exclude: Float) + assert_call('{ 1 => 1.0 }.values.sample.', include: Float, exclude: Integer) + assert_call('a={1=>1.0}; {"a"=>1i,**a}.keys.sample.', include: [Integer, String]) + assert_call('a={1=>1.0}; {"a"=>1i,**a}.values.sample.', include: [Float, Complex]) + end + + def test_array + assert_call('[1,2,3].sample.', include: Integer) + assert_call('a = 1.0; [1,2,a].sample.', include: [Integer, Float]) + assert_call('a = [1.0]; [1,2,*a].sample.', include: [Integer, Float]) + end + + def test_numbered_parameter + assert_call('loop{_1.', include: NilClass) + assert_call('1.tap{_1.', include: Integer) + assert_call('1.tap{_3.', include: NilClass, exclude: Integer) + assert_call('[:a,1].tap{_1.', include: Array, exclude: [Integer, Symbol]) + assert_call('[:a,1].tap{_2.', include: [Symbol, Integer], exclude: Array) + assert_call('[:a,1].tap{_2; _1.', include: [Symbol, Integer], exclude: Array) + assert_call('[:a].each_with_index{_1.', include: Symbol, exclude: [Integer, Array]) + assert_call('[:a].each_with_index{_2; _1.', include: Symbol, exclude: [Integer, Array]) + assert_call('[:a].each_with_index{_2.', include: Integer, exclude: Symbol) + end + + def test_if_unless + assert_call('if cond; 1; end.', include: Integer) + assert_call('unless true; 1; end.', include: Integer) + assert_call('a=1; (a=1.0) if cond; a.', include: [Integer, Float]) + assert_call('a=1; (a=1.0) unless cond; a.', include: [Integer, Float]) + assert_call('a=1; 123 if (a=1.0).foo; a.', include: Float, exclude: Integer) + assert_call('if cond; a=1; end; a.', include: [Integer, NilClass]) + assert_call('a=1; if cond; a=1.0; elsif cond; a=1r; else; a=1i; end; a.', include: [Float, Rational, Complex], exclude: Integer) + assert_call('a=1; if cond; a=1.0; else; a.', include: Integer, exclude: Float) + assert_call('a=1; if (a=1.0).foo; a.', include: Float, exclude: Integer) + assert_call('a=1; if (a=1.0).foo; end; a.', include: Float, exclude: Integer) + assert_call('a=1; if (a=1.0).foo; else; a.', include: Float, exclude: Integer) + assert_call('a=1; if (a=1.0).foo; elsif a.', include: Float, exclude: Integer) + assert_call('a=1; if (a=1.0).foo; elsif (a=1i); else; a.', include: Complex, exclude: [Integer, Float]) + end + + def test_while_until + assert_call('while cond; 123; end.', include: NilClass) + assert_call('until cond; 123; end.', include: NilClass) + assert_call('a=1; a=1.0 while cond; a.', include: [Integer, Float]) + assert_call('a=1; a=1.0 until cond; a.', include: [Integer, Float]) + assert_call('a=1; 1 while (a=1.0).foo; a.', include: Float, exclude: Integer) + assert_call('while cond; break 1; end.', include: Integer) + assert_call('while cond; a=1; end; a.', include: Integer) + assert_call('a=1; while cond; a=1.0; end; a.', include: [Integer, Float]) + assert_call('a=1; while (a=1.0).foo; end; a.', include: Float, exclude: Integer) + end + + def test_for + assert_call('for i in [1,2,3]; i.', include: Integer) + assert_call('for i,j in [1,2,3]; i.', include: Integer) + assert_call('for *,(*) in [1,2,3]; 1.', include: Integer) + assert_call('for *i in [1,2,3]; i.sample.', include: Integer) + assert_call('for (a=1).b in [1,2,3]; a.', include: Integer) + assert_call('for Array::B in [1,2,3]; Array::B.', include: Integer) + assert_call('for A in [1,2,3]; A.', include: Integer) + assert_call('for $a in [1,2,3]; $a.', include: Integer) + assert_call('for @a in [1,2,3]; @a.', include: Integer) + assert_call('for i in [1,2,3]; end.', include: Array) + assert_call('for i in [1,2,3]; break 1.0; end.', include: [Array, Float]) + assert_call('i = 1.0; for i in [1,2,3]; end; i.', include: [Integer, Float]) + assert_call('a = 1.0; for i in [1,2,3]; a = 1i; end; a.', include: [Float, Complex]) + end + + def test_special_var + assert_call('__FILE__.', include: String) + assert_call('__LINE__.', include: Integer) + assert_call('__ENCODING__.', include: Encoding) + assert_call('$1.', include: String) + assert_call('$&.', include: String) + end + + def test_and_or + assert_call('(1&&1.0).', include: Float, exclude: Integer) + assert_call('(nil&&1.0).', include: NilClass) + assert_call('(nil||1).', include: Integer) + assert_call('(1||1.0).', include: Float) + end + + def test_opwrite + assert_call('a=[]; a*=1; a.', include: Array) + assert_call('a=[]; a*=""; a.', include: String) + assert_call('a=[1,false].sample; a||=1.0; a.', include: [Integer, Float]) + assert_call('a=1; a&&=1.0; a.', include: Float, exclude: Integer) + assert_call('(a=1).b*=1; a.', include: Integer) + assert_call('(a=1).b||=1; a.', include: Integer) + assert_call('(a=1).b&&=1; a.', include: Integer) + assert_call('[][a=1]&&=1; a.', include: Integer) + assert_call('[][a=1]||=1; a.', include: Integer) + assert_call('[][a=1]+=1; a.', include: Integer) + assert_call('([1][0]+=1.0).', include: Float) + assert_call('([1.0][0]+=1).', include: Float) + assert_call('A=nil; A||=1; A.', include: Integer) + assert_call('A=1; A&&=1.0; A.', include: Float) + assert_call('A=1; A+=1.0; A.', include: Float) + assert_call('Array::A||=1; Array::A.', include: Integer) + assert_call('Array::A=1; Array::A&&=1.0; Array::A.', include: Float) + end + + def test_case_when + assert_call('case x; when A; 1; when B; 1.0; end.', include: [Integer, Float, NilClass]) + assert_call('case x; when A; 1; when B; 1.0; else; 1r; end.', include: [Integer, Float, Rational], exclude: NilClass) + assert_call('case; when (a=1); a.', include: Integer) + assert_call('case x; when (a=1); a.', include: Integer) + assert_call('a=1; case (a=1.0); when A; a.', include: Float, exclude: Integer) + assert_call('a=1; case (a=1.0); when A; end; a.', include: Float, exclude: Integer) + assert_call('a=1; case x; when A; a=1.0; else; a=1r; end; a.', include: [Float, Rational], exclude: Integer) + assert_call('a=1; case x; when A; a=1.0; when B; a=1r; end; a.', include: [Float, Rational, Integer]) + end + + def test_case_in + assert_call('case x; in A; 1; in B; 1.0; end.', include: [Integer, Float], exclude: NilClass) + assert_call('case x; in A; 1; in B; 1.0; else; 1r; end.', include: [Integer, Float, Rational], exclude: NilClass) + assert_call('a=""; case 1; in A; a=1; in B; a=1.0; end; a.', include: [Integer, Float], exclude: String) + assert_call('a=""; case 1; in A; a=1; in B; a=1.0; else; a=1r; end; a.', include: [Integer, Float, Rational], exclude: String) + assert_call('case 1; in x; x.', include: Integer) + assert_call('case x; in A if (a=1); a.', include: Integer) + assert_call('case x; in ^(a=1); a.', include: Integer) + assert_call('case x; in [1, String => a, 2]; a.', include: String) + assert_call('case x; in [*a, 1]; a.', include: Array) + assert_call('case x; in [1, *a]; a.', include: Array) + assert_call('case x; in [*a, 1, *b]; a.', include: Array) + assert_call('case x; in [*a, 1, *b]; b.', include: Array) + assert_call('case x; in {a: {b: **c}}; c.', include: Hash) + assert_call('case x; in (String | { x: Integer, y: ^$a }) => a; a.', include: [String, Hash]) + end + + def test_pattern_match + assert_call('1 in a; a.', include: Integer) + assert_call('a=1; x in String=>a; a.', include: [Integer, String]) + assert_call('a=1; x=>String=>a; a.', include: String, exclude: Integer) + end + + def test_bottom_type_termination + assert_call('a=1; tap { raise; a=1.0; a.', include: Float) + assert_call('a=1; tap { loop{}; a=1.0; a.', include: Float) + assert_call('a=1; tap { raise; a=1.0 } a.', include: Integer, exclude: Float) + assert_call('a=1; tap { loop{}; a=1.0 } a.', include: Integer, exclude: Float) + end + + def test_call_parameter + assert_call('f((x=1),*b,c:1,**d,&e); x.', include: Integer) + assert_call('f(a,*(x=1),c:1,**d,&e); x.', include: Integer) + assert_call('f(a,*b,(x=1):1,**d,&e); x.', include: Integer) + assert_call('f(a,*b,c:(x=1),**d,&e); x.', include: Integer) + assert_call('f(a,*b,c:1,**(x=1),&e); x.', include: Integer) + assert_call('f(a,*b,c:1,**d,&(x=1)); x.', include: Integer) + assert_call('f((x=1)=>1); x.', include: Integer) + end + + def test_block_args + assert_call('[1,2,3].tap{|a| a.', include: Array) + assert_call('[1,2,3].tap{|a,b| a.', include: Integer) + assert_call('[1,2,3].tap{|(a,b)| a.', include: Integer) + assert_call('[1,2,3].tap{|a,*b| b.', include: Array) + assert_call('[1,2,3].tap{|a=1.0| a.', include: [Array, Float]) + assert_call('[1,2,3].tap{|a,**b| b.', include: Hash) + assert_call('1.tap{|(*),*,**| 1.', include: Integer) + end + + def test_array_aref + assert_call('[1][0..].', include: [Array, NilClass], exclude: Integer) + assert_call('[1][0].', include: Integer, exclude: [Array, NilClass]) + assert_call('[1].[](0).', include: Integer, exclude: [Array, NilClass]) + assert_call('[1].[](0){}.', include: Integer, exclude: [Array, NilClass]) + end + end +end diff --git a/test/irb/type_completion/test_type_completor.rb b/test/irb/type_completion/test_type_completor.rb new file mode 100644 index 00000000000000..eed400b3e2d6f6 --- /dev/null +++ b/test/irb/type_completion/test_type_completor.rb @@ -0,0 +1,181 @@ +# frozen_string_literal: true + +# Run test only when Ruby >= 3.0 and %w[prism rbs] are available +return unless RUBY_VERSION >= '3.0.0' +return if RUBY_ENGINE == 'truffleruby' # needs endless method definition +begin + require 'prism' + require 'rbs' +rescue LoadError + return +end + +require 'irb/type_completion/completor' +require_relative '../helper' + +module TestIRB + class TypeCompletorTest < TestCase + def setup + IRB::TypeCompletion::Types.load_rbs_builder unless IRB::TypeCompletion::Types.rbs_builder + @completor = IRB::TypeCompletion::Completor.new + end + + def empty_binding + binding + end + + TARGET_REGEXP = /(@@|@|\$)?[a-zA-Z_]*[!?=]?$/ + + def assert_completion(code, binding: empty_binding, include: nil, exclude: nil) + raise ArgumentError if include.nil? && exclude.nil? + target = code[TARGET_REGEXP] + candidates = @completor.completion_candidates(code.delete_suffix(target), target, '', bind: binding) + assert ([*include] - candidates).empty?, "Expected #{candidates} to include #{include}" if include + assert (candidates & [*exclude]).empty?, "Expected #{candidates} not to include #{exclude}" if exclude + end + + def assert_doc_namespace(code, namespace, binding: empty_binding) + target = code[TARGET_REGEXP] + preposing = code.delete_suffix(target) + @completor.completion_candidates(preposing, target, '', bind: binding) + assert_equal namespace, @completor.doc_namespace(preposing, target, '', bind: binding) + end + + def test_require + assert_completion("require '", include: 'set') + assert_completion("require 's", include: 'set') + Dir.chdir(__dir__ + "/../../..") do + assert_completion("require_relative 'l", include: 'lib/irb') + end + # Incomplete double quote string is InterpolatedStringNode + assert_completion('require "', include: 'set') + assert_completion('require "s', include: 'set') + end + + def test_method_block_sym + assert_completion('[1].map(&:', include: 'abs') + assert_completion('[:a].map(&:', exclude: 'abs') + assert_completion('[1].map(&:a', include: 'abs') + assert_doc_namespace('[1].map(&:abs', 'Integer#abs') + end + + def test_symbol + sym = :test_completion_symbol + assert_completion(":test_com", include: sym.to_s) + end + + def test_call + assert_completion('1.', include: 'abs') + assert_completion('1.a', include: 'abs') + assert_completion('ran', include: 'rand') + assert_doc_namespace('1.abs', 'Integer#abs') + assert_doc_namespace('Integer.sqrt', 'Integer.sqrt') + assert_doc_namespace('rand', 'TestIRB::TypeCompletorTest#rand') + assert_doc_namespace('Object::rand', 'Object.rand') + end + + def test_lvar + bind = eval('lvar = 1; binding') + assert_completion('lva', binding: bind, include: 'lvar') + assert_completion('lvar.', binding: bind, include: 'abs') + assert_completion('lvar.a', binding: bind, include: 'abs') + assert_completion('lvar = ""; lvar.', binding: bind, include: 'ascii_only?') + assert_completion('lvar = ""; lvar.', include: 'ascii_only?') + assert_doc_namespace('lvar', 'Integer', binding: bind) + assert_doc_namespace('lvar.abs', 'Integer#abs', binding: bind) + assert_doc_namespace('lvar = ""; lvar.ascii_only?', 'String#ascii_only?', binding: bind) + end + + def test_const + assert_completion('Ar', include: 'Array') + assert_completion('::Ar', include: 'Array') + assert_completion('IRB::V', include: 'VERSION') + assert_completion('FooBar=1; F', include: 'FooBar') + assert_completion('::FooBar=1; ::F', include: 'FooBar') + assert_doc_namespace('Array', 'Array') + assert_doc_namespace('Array = 1; Array', 'Integer') + assert_doc_namespace('Object::Array', 'Array') + assert_completion('::', include: 'Array') + assert_completion('class ::', include: 'Array') + assert_completion('module IRB; class T', include: ['TypeCompletion', 'TracePoint']) + end + + def test_gvar + assert_completion('$', include: '$stdout') + assert_completion('$s', include: '$stdout') + assert_completion('$', exclude: '$foobar') + assert_completion('$foobar=1; $', include: '$foobar') + assert_doc_namespace('$foobar=1; $foobar', 'Integer') + assert_doc_namespace('$stdout', 'IO') + assert_doc_namespace('$stdout=1; $stdout', 'Integer') + end + + def test_ivar + bind = Object.new.instance_eval { @foo = 1; binding } + assert_completion('@', binding: bind, include: '@foo') + assert_completion('@f', binding: bind, include: '@foo') + assert_completion('@bar = 1; @', include: '@bar') + assert_completion('@bar = 1; @b', include: '@bar') + assert_doc_namespace('@bar = 1; @bar', 'Integer') + assert_doc_namespace('@foo', 'Integer', binding: bind) + assert_doc_namespace('@foo = 1.0; @foo', 'Float', binding: bind) + end + + def test_cvar + bind = eval('m=Module.new; module m::M; @@foo = 1; binding; end') + assert_equal(1, bind.eval('@@foo')) + assert_completion('@', binding: bind, include: '@@foo') + assert_completion('@@', binding: bind, include: '@@foo') + assert_completion('@@f', binding: bind, include: '@@foo') + assert_doc_namespace('@@foo', 'Integer', binding: bind) + assert_doc_namespace('@@foo = 1.0; @@foo', 'Float', binding: bind) + assert_completion('@@bar = 1; @', include: '@@bar') + assert_completion('@@bar = 1; @@', include: '@@bar') + assert_completion('@@bar = 1; @@b', include: '@@bar') + assert_doc_namespace('@@bar = 1; @@bar', 'Integer') + end + + def test_basic_object + bo = BasicObject.new + def bo.foo; end + bo.instance_eval { @bar = 1 } + bind = binding + bo_self_bind = bo.instance_eval { Kernel.binding } + assert_completion('bo.', binding: bind, include: 'foo') + assert_completion('def bo.baz; self.', binding: bind, include: 'foo') + assert_completion('[bo].first.', binding: bind, include: 'foo') + assert_doc_namespace('bo', 'BasicObject', binding: bind) + assert_doc_namespace('bo.__id__', 'BasicObject#__id__', binding: bind) + assert_doc_namespace('v = [bo]; v', 'Array', binding: bind) + assert_doc_namespace('v = [bo].first; v', 'BasicObject', binding: bind) + bo_self_bind = bo.instance_eval { Kernel.binding } + assert_completion('self.', binding: bo_self_bind, include: 'foo') + assert_completion('@', binding: bo_self_bind, include: '@bar') + assert_completion('@bar.', binding: bo_self_bind, include: 'abs') + assert_doc_namespace('self.__id__', 'BasicObject#__id__', binding: bo_self_bind) + assert_doc_namespace('@bar', 'Integer', binding: bo_self_bind) + if RUBY_VERSION >= '3.2.0' # Needs Class#attached_object to get instance variables from singleton class + assert_completion('def bo.baz; @bar.', binding: bind, include: 'abs') + assert_completion('def bo.baz; @', binding: bind, include: '@bar') + end + end + + def test_inspect + rbs_builder = IRB::TypeCompletion::Types.rbs_builder + assert_match(/TypeCompletion::Completor\(Prism: \d.+, RBS: \d.+\)/, @completor.inspect) + IRB::TypeCompletion::Types.instance_variable_set(:@rbs_builder, nil) + assert_match(/TypeCompletion::Completor\(Prism: \d.+, RBS: loading\)/, @completor.inspect) + IRB::TypeCompletion::Types.instance_variable_set(:@rbs_load_error, StandardError.new('[err]')) + assert_match(/TypeCompletion::Completor\(Prism: \d.+, RBS: .+\[err\].+\)/, @completor.inspect) + ensure + IRB::TypeCompletion::Types.instance_variable_set(:@rbs_builder, rbs_builder) + IRB::TypeCompletion::Types.instance_variable_set(:@rbs_load_error, nil) + end + + def test_none + candidates = @completor.completion_candidates('(', ')', '', bind: binding) + assert_equal [], candidates + assert_doc_namespace('()', nil) + end + end +end diff --git a/test/irb/type_completion/test_types.rb b/test/irb/type_completion/test_types.rb new file mode 100644 index 00000000000000..7698bd2fc07492 --- /dev/null +++ b/test/irb/type_completion/test_types.rb @@ -0,0 +1,89 @@ +# frozen_string_literal: true + +return unless RUBY_VERSION >= '3.0.0' +return if RUBY_ENGINE == 'truffleruby' # needs endless method definition + +require 'irb/type_completion/types' +require_relative '../helper' + +module TestIRB + class TypeCompletionTypesTest < TestCase + def test_type_inspect + true_type = IRB::TypeCompletion::Types::TRUE + false_type = IRB::TypeCompletion::Types::FALSE + nil_type = IRB::TypeCompletion::Types::NIL + string_type = IRB::TypeCompletion::Types::STRING + true_or_false = IRB::TypeCompletion::Types::UnionType[true_type, false_type] + array_type = IRB::TypeCompletion::Types::InstanceType.new Array, { Elem: true_or_false } + assert_equal 'nil', nil_type.inspect + assert_equal 'true', true_type.inspect + assert_equal 'false', false_type.inspect + assert_equal 'String', string_type.inspect + assert_equal 'Array', IRB::TypeCompletion::Types::InstanceType.new(Array).inspect + assert_equal 'true | false', true_or_false.inspect + assert_equal 'Array[Elem: true | false]', array_type.inspect + assert_equal 'Array', array_type.inspect_without_params + assert_equal 'Proc', IRB::TypeCompletion::Types::PROC.inspect + assert_equal 'Array.itself', IRB::TypeCompletion::Types::SingletonType.new(Array).inspect + end + + def test_type_from_object + obj = Object.new + bo = BasicObject.new + def bo.hash; 42; end # Needed to use this object as a hash key + arr = [1, 'a'] + hash = { 'key' => :value } + int_type = IRB::TypeCompletion::Types.type_from_object 1 + obj_type = IRB::TypeCompletion::Types.type_from_object obj + arr_type = IRB::TypeCompletion::Types.type_from_object arr + hash_type = IRB::TypeCompletion::Types.type_from_object hash + bo_type = IRB::TypeCompletion::Types.type_from_object bo + bo_arr_type = IRB::TypeCompletion::Types.type_from_object [bo] + bo_key_hash_type = IRB::TypeCompletion::Types.type_from_object({ bo => 1 }) + bo_value_hash_type = IRB::TypeCompletion::Types.type_from_object({ x: bo }) + + assert_equal Integer, int_type.klass + # Use singleton_class to autocomplete singleton methods + assert_equal obj.singleton_class, obj_type.klass + assert_equal Object.instance_method(:singleton_class).bind_call(bo), bo_type.klass + # Array and Hash are special + assert_equal Array, arr_type.klass + assert_equal Array, bo_arr_type.klass + assert_equal Hash, hash_type.klass + assert_equal Hash, bo_key_hash_type.klass + assert_equal Hash, bo_value_hash_type.klass + assert_equal BasicObject, bo_arr_type.params[:Elem].klass + assert_equal BasicObject, bo_key_hash_type.params[:K].klass + assert_equal BasicObject, bo_value_hash_type.params[:V].klass + assert_equal 'Object', obj_type.inspect + assert_equal 'Array[Elem: Integer | String]', arr_type.inspect + assert_equal 'Hash[K: String, V: Symbol]', hash_type.inspect + assert_equal 'Array.itself', IRB::TypeCompletion::Types.type_from_object(Array).inspect + assert_equal 'IRB::TypeCompletion.itself', IRB::TypeCompletion::Types.type_from_object(IRB::TypeCompletion).inspect + end + + def test_type_methods + s = +'' + class << s + def foobar; end + private def foobaz; end + end + String.define_method(:foobarbaz) {} + targets = [:foobar, :foobaz, :foobarbaz] + type = IRB::TypeCompletion::Types.type_from_object s + assert_equal [:foobar, :foobarbaz], targets & type.methods + assert_equal [:foobar, :foobaz, :foobarbaz], targets & type.all_methods + assert_equal [:foobarbaz], targets & IRB::TypeCompletion::Types::STRING.methods + assert_equal [:foobarbaz], targets & IRB::TypeCompletion::Types::STRING.all_methods + ensure + String.remove_method :foobarbaz + end + + def test_basic_object_methods + bo = BasicObject.new + def bo.foobar; end + type = IRB::TypeCompletion::Types.type_from_object bo + assert type.all_methods.include?(:foobar) + end + end +end