From 0a423d4c4e84bc024d5dabe86227c1f0b509a898 Mon Sep 17 00:00:00 2001 From: Samuel Giddins Date: Wed, 20 Sep 2023 21:32:09 -0700 Subject: [PATCH] [rubygems/rubygems] Aggressively optimize allocations in SafeMarshal Reduces allocations in a bundle install --full-index by an order of magnitude Main wins are (a) getting rid of excessive string allocations for exception message stack (b) Avoiding hash allocations caused by kwargs for #initialize (c) avoid using unpack to do bit math, it's easy enough to do by hand (d) special case the most common elements so they can be read without an allocation (e) avoid string allocations every time a symbol->string lookup is done by using symbol#name https://github.com/rubygems/rubygems/commit/7d2ee51402 --- lib/rubygems/safe_marshal/elements.rb | 30 ++-- lib/rubygems/safe_marshal/reader.rb | 158 +++++++++++++----- lib/rubygems/safe_marshal/visitors/to_ruby.rb | 124 ++++++++++---- test/rubygems/test_gem_safe_marshal.rb | 30 +++- 4 files changed, 253 insertions(+), 89 deletions(-) diff --git a/lib/rubygems/safe_marshal/elements.rb b/lib/rubygems/safe_marshal/elements.rb index 70961c40ff2458..067ab59d19d6e4 100644 --- a/lib/rubygems/safe_marshal/elements.rb +++ b/lib/rubygems/safe_marshal/elements.rb @@ -7,14 +7,14 @@ class Element end class Symbol < Element - def initialize(name:) + def initialize(name) @name = name end attr_reader :name end class UserDefined < Element - def initialize(name:, binary_string:) + def initialize(name, binary_string) @name = name @binary_string = binary_string end @@ -23,7 +23,7 @@ def initialize(name:, binary_string:) end class UserMarshal < Element - def initialize(name:, data:) + def initialize(name, data) @name = name @data = data end @@ -32,7 +32,7 @@ def initialize(name:, data:) end class String < Element - def initialize(str:) + def initialize(str) @str = str end @@ -40,7 +40,7 @@ def initialize(str:) end class Hash < Element - def initialize(pairs:) + def initialize(pairs) @pairs = pairs end @@ 
-48,8 +48,8 @@ def initialize(pairs:) end class HashWithDefaultValue < Hash - def initialize(default:, **kwargs) - super(**kwargs) + def initialize(pairs, default) + super(pairs) @default = default end @@ -57,7 +57,7 @@ def initialize(default:, **kwargs) end class Array < Element - def initialize(elements:) + def initialize(elements) @elements = elements end @@ -65,7 +65,7 @@ def initialize(elements:) end class Integer < Element - def initialize(int:) + def initialize(int) @int = int end @@ -86,7 +86,7 @@ def initialize end class WithIvars < Element - def initialize(object:,ivars:) + def initialize(object, ivars) @object = object @ivars = ivars end @@ -95,7 +95,7 @@ def initialize(object:,ivars:) end class Object < Element - def initialize(name:) + def initialize(name) @name = name end attr_reader :name @@ -106,28 +106,28 @@ class Nil < Element end class ObjectLink < Element - def initialize(offset:) + def initialize(offset) @offset = offset end attr_reader :offset end class SymbolLink < Element - def initialize(offset:) + def initialize(offset) @offset = offset end attr_reader :offset end class Float < Element - def initialize(string:) + def initialize(string) @string = string end attr_reader :string end class Bignum < Element # rubocop:disable Lint/UnifiedInteger - def initialize(sign:, data:) + def initialize(sign, data) @sign = sign @data = data end diff --git a/lib/rubygems/safe_marshal/reader.rb b/lib/rubygems/safe_marshal/reader.rb index bc0bb6298617e4..c2c2295086124a 100644 --- a/lib/rubygems/safe_marshal/reader.rb +++ b/lib/rubygems/safe_marshal/reader.rb @@ -49,19 +49,19 @@ def read_integer when 0x00 0 when 0x01 - @io.read(1).unpack1("C") + read_byte when 0x02 - @io.read(2).unpack1("S<") + read_byte | (read_byte << 8) when 0x03 - (@io.read(3) + "\0").unpack1("L<") + read_byte | (read_byte << 8) | (read_byte << 16) when 0x04 - @io.read(4).unpack1("L<") + read_byte | (read_byte << 8) | (read_byte << 16) | (read_byte << 24) when 0xFC - 
@io.read(4).unpack1("L<") | -0x100000000 + read_byte | (read_byte << 8) | (read_byte << 16) | (read_byte << 24) | -0x100000000 when 0xFD - (@io.read(3) + "\0").unpack1("L<") | -0x1000000 + read_byte | (read_byte << 8) | (read_byte << 16) | -0x1000000 when 0xFE - @io.read(2).unpack1("s<") | -0x10000 + read_byte | (read_byte << 8) | -0x10000 when 0xFF read_byte | -0x100 else @@ -88,31 +88,51 @@ def read_element when 85 then read_user_marshal # ?U when 91 then read_array # ?[ when 102 then read_float # ?f - when 105 then Elements::Integer.new int: read_integer # ?i + when 105 then Elements::Integer.new(read_integer) # ?i when 108 then read_bignum # ?l when 111 then read_object # ?o when 117 then read_user_defined # ?u when 123 then read_hash # ?{ when 125 then read_hash_with_default_value # ?} - when "e".ord then read_extended_object - when "c".ord then read_class - when "m".ord then read_module - when "M".ord then read_class_or_module - when "d".ord then read_data - when "/".ord then read_regexp - when "S".ord then read_struct - when "C".ord then read_user_class + when 101 then read_extended_object # ?e + when 99 then read_class # ?c + when 109 then read_module # ?m + when 77 then read_class_or_module # ?M + when 100 then read_data # ?d + when 47 then read_regexp # ?/ + when 83 then read_struct # ?S + when 67 then read_user_class # ?C else raise Error, "Unknown marshal type discriminator #{type.chr.inspect} (#{type})" end end + STRING_E_SYMBOL = Elements::Symbol.new("E").freeze + private_constant :STRING_E_SYMBOL + def read_symbol - Elements::Symbol.new name: @io.read(read_integer) + len = read_integer + if len == 1 + byte = read_byte + if byte == 69 # ?E + STRING_E_SYMBOL + else + Elements::Symbol.new(byte.chr) + end + else + name = -@io.read(len) + Elements::Symbol.new(name) + end end + EMPTY_STRING = Elements::String.new("".b.freeze).freeze + private_constant :EMPTY_STRING + def read_string - Elements::String.new(str: @io.read(read_integer)) + length = 
read_integer + return EMPTY_STRING if length == 0 + str = @io.read(length) + Elements::String.new(str) end def read_true @@ -124,55 +144,108 @@ def read_false end def read_user_defined - Elements::UserDefined.new(name: read_element, binary_string: @io.read(read_integer)) + name = read_element + binary_string = @io.read(read_integer) + Elements::UserDefined.new(name, binary_string) end + EMPTY_ARRAY = Elements::Array.new([].freeze).freeze + private_constant :EMPTY_ARRAY + def read_array - Elements::Array.new(elements: Array.new(read_integer) do |_i| - read_element - end) + length = read_integer + return EMPTY_ARRAY if length == 0 + elements = Array.new(length) do + read_element + end + Elements::Array.new(elements) end def read_object_with_ivars - Elements::WithIvars.new(object: read_element, ivars: - Array.new(read_integer) do - [read_element, read_element] - end) + object = read_element + ivars = Array.new(read_integer) do + [read_element, read_element] + end + Elements::WithIvars.new(object, ivars) end def read_symbol_link - Elements::SymbolLink.new offset: read_integer + offset = read_integer + Elements::SymbolLink.new(offset) end def read_user_marshal - Elements::UserMarshal.new(name: read_element, data: read_element) - end + name = read_element + data = read_element + Elements::UserMarshal.new(name, data) + end + + # profiling bundle install --full-index shows that + # offset 6 is by far the most common object link, + # so we special case it to avoid allocating a new + # object a third of the time. 
+ # the following are all the object links that + # appear more than 10000 times in my profiling + + OBJECT_LINKS = { + 6 => Elements::ObjectLink.new(6).freeze, + 30 => Elements::ObjectLink.new(30).freeze, + 81 => Elements::ObjectLink.new(81).freeze, + 34 => Elements::ObjectLink.new(34).freeze, + 38 => Elements::ObjectLink.new(38).freeze, + 50 => Elements::ObjectLink.new(50).freeze, + 91 => Elements::ObjectLink.new(91).freeze, + 42 => Elements::ObjectLink.new(42).freeze, + 46 => Elements::ObjectLink.new(46).freeze, + 150 => Elements::ObjectLink.new(150).freeze, + 100 => Elements::ObjectLink.new(100).freeze, + 104 => Elements::ObjectLink.new(104).freeze, + 108 => Elements::ObjectLink.new(108).freeze, + 242 => Elements::ObjectLink.new(242).freeze, + 246 => Elements::ObjectLink.new(246).freeze, + 139 => Elements::ObjectLink.new(139).freeze, + 143 => Elements::ObjectLink.new(143).freeze, + 114 => Elements::ObjectLink.new(114).freeze, + 308 => Elements::ObjectLink.new(308).freeze, + 200 => Elements::ObjectLink.new(200).freeze, + 54 => Elements::ObjectLink.new(54).freeze, + 62 => Elements::ObjectLink.new(62).freeze, + 1_286_245 => Elements::ObjectLink.new(1_286_245).freeze, + }.freeze + private_constant :OBJECT_LINKS def read_object_link - Elements::ObjectLink.new(offset: read_integer) + offset = read_integer + OBJECT_LINKS[offset] || Elements::ObjectLink.new(offset) end + EMPTY_HASH = Elements::Hash.new([].freeze).freeze + private_constant :EMPTY_HASH + def read_hash - pairs = Array.new(read_integer) do + length = read_integer + return EMPTY_HASH if length == 0 + pairs = Array.new(length) do [read_element, read_element] end - Elements::Hash.new(pairs: pairs) + Elements::Hash.new(pairs) end def read_hash_with_default_value pairs = Array.new(read_integer) do [read_element, read_element] end - Elements::HashWithDefaultValue.new(pairs: pairs, default: read_element) + default = read_element + Elements::HashWithDefaultValue.new(pairs, default) end def read_object - 
Elements::WithIvars.new( - object: Elements::Object.new(name: read_element), - ivars: Array.new(read_integer) do - [read_element, read_element] - end - ) + name = read_element + object = Elements::Object.new(name) + ivars = Array.new(read_integer) do + [read_element, read_element] + end + Elements::WithIvars.new(object, ivars) end def read_nil @@ -180,11 +253,14 @@ def read_nil end def read_float - Elements::Float.new string: @io.read(read_integer) + string = @io.read(read_integer) + Elements::Float.new(string) end def read_bignum - Elements::Bignum.new(sign: read_byte, data: @io.read(read_integer * 2)) + sign = read_byte + data = @io.read(read_integer * 2) + Elements::Bignum.new(sign, data) end def read_extended_object diff --git a/lib/rubygems/safe_marshal/visitors/to_ruby.rb b/lib/rubygems/safe_marshal/visitors/to_ruby.rb index 147141c2c54862..58c44fa8bf2789 100644 --- a/lib/rubygems/safe_marshal/visitors/to_ruby.rb +++ b/lib/rubygems/safe_marshal/visitors/to_ruby.rb @@ -7,7 +7,7 @@ module Visitors class ToRuby < Visitor def initialize(permitted_classes:, permitted_symbols:, permitted_ivars:) @permitted_classes = permitted_classes - @permitted_symbols = permitted_symbols | permitted_classes | ["E"] + @permitted_symbols = ["E"].concat(permitted_symbols).concat(permitted_classes) @permitted_ivars = permitted_ivars @objects = [] @@ -15,6 +15,7 @@ def initialize(permitted_classes:, permitted_symbols:, permitted_ivars:) @class_cache = {} @stack = ["root"] + @stack_idx = 1 end def inspect # :nodoc: @@ -23,39 +24,61 @@ def inspect # :nodoc: end def visit(target) - depth = @stack.size + stack_idx = @stack_idx super ensure - @stack.slice!(depth.pred..) 
+ @stack_idx = stack_idx - 1 end private + def push_stack(element) + @stack[@stack_idx] = element + @stack_idx += 1 + end + def visit_Gem_SafeMarshal_Elements_Array(a) - register_object([]).replace(a.elements.each_with_index.map do |e, i| - @stack << "[#{i}]" - visit(e) - end) + array = register_object([]) + + elements = a.elements + size = elements.size + idx = 0 + # not idiomatic, but there's a huge number of IMEMOs allocated here, so we avoid the block + # because this is such a hot path when doing a bundle install with the full index + until idx == size + push_stack idx + array << visit(elements[idx]) + idx += 1 + end + + array end def visit_Gem_SafeMarshal_Elements_Symbol(s) name = s.name - raise UnpermittedSymbolError.new(symbol: name, stack: @stack.dup) unless @permitted_symbols.include?(name) + raise UnpermittedSymbolError.new(symbol: name, stack: formatted_stack) unless @permitted_symbols.include?(name) visit_symbol_type(s) end def map_ivars(klass, ivars) + stack_idx = @stack_idx ivars.map.with_index do |(k, v), i| - @stack << "ivar_#{i}" + @stack_idx = stack_idx + + push_stack "ivar_" + push_stack i k = resolve_ivar(klass, k) - @stack[-1] = k + + @stack_idx = stack_idx + push_stack k + next k, visit(v) end end def visit_Gem_SafeMarshal_Elements_WithIvars(e) object_offset = @objects.size - @stack << "object" + push_stack "object" object = visit(e.object) ivars = map_ivars(object.class, e.ivars) @@ -76,12 +99,18 @@ def visit_Gem_SafeMarshal_Elements_WithIvars(e) s = e.object.binary_string - marshal_string = "\x04\bIu:\tTime#{(s.size + 5).chr}#{s.b}".b - - marshal_string << (internal.size + 5).chr + marshal_string = "\x04\bIu:\tTime".b + marshal_string.concat(s.size + 5) + marshal_string << s + marshal_string.concat(internal.size + 5) internal.each do |k, v| - marshal_string << ":#{(k.size + 5).chr}#{k}#{Marshal.dump(v)[2..-1]}" + marshal_string.concat(":") + marshal_string.concat(k.size + 5) + marshal_string.concat(k.to_s) + dumped = Marshal.dump(v) + 
dumped[0, 2] = "" + marshal_string.concat(dumped) end object = @objects[object_offset] = Marshal.load(marshal_string) @@ -108,7 +137,7 @@ def visit_Gem_SafeMarshal_Elements_WithIvars(e) true end - object.replace ::String.new(object, encoding: enc) + object.force_encoding(enc) if enc end ivars.each do |k, v| @@ -121,9 +150,9 @@ def visit_Gem_SafeMarshal_Elements_Hash(o) hash = register_object({}) o.pairs.each_with_index do |(k, v), i| - @stack << i + push_stack i k = visit(k) - @stack << k + push_stack k hash[k] = visit(v) end @@ -132,7 +161,7 @@ def visit_Gem_SafeMarshal_Elements_Hash(o) def visit_Gem_SafeMarshal_Elements_HashWithDefaultValue(o) hash = visit_Gem_SafeMarshal_Elements_Hash(o) - @stack << :default + push_stack :default hash.default = visit(o.default) hash end @@ -159,7 +188,7 @@ def visit_Gem_SafeMarshal_Elements_UserMarshal(o) idx = @objects.size object = register_object(call_method(compat || klass, :allocate)) - @stack << :data + push_stack :data ret = call_method(object, :marshal_load, visit(o.data)) if compat @@ -186,7 +215,7 @@ def visit_Gem_SafeMarshal_Elements_False(_) end def visit_Gem_SafeMarshal_Elements_String(s) - register_object(s.str) + register_object(+s.str) end def visit_Gem_SafeMarshal_Elements_Float(f) @@ -221,7 +250,7 @@ def visit_Gem_SafeMarshal_Elements_Bignum(b) def resolve_class(n) @class_cache[n] ||= begin to_s = resolve_symbol_name(n) - raise UnpermittedClassError.new(name: to_s, stack: @stack.dup) unless @permitted_classes.include?(to_s) + raise UnpermittedClassError.new(name: to_s, stack: formatted_stack) unless @permitted_classes.include?(to_s) visit_symbol_type(n) begin ::Object.const_get(to_s) @@ -238,16 +267,17 @@ def marshal_load(s) Rational(num, den) end end + private_constant :RationalCompat COMPAT_CLASSES = {}.tap do |h| h[Rational] = RationalCompat - end.freeze + end.compare_by_identity.freeze private_constant :COMPAT_CLASSES def resolve_ivar(klass, name) to_s = resolve_symbol_name(name) - raise 
UnpermittedIvarError.new(symbol: to_s, klass: klass, stack: @stack.dup) unless @permitted_ivars.fetch(klass.name, [].freeze).include?(to_s) + raise UnpermittedIvarError.new(symbol: to_s, klass: klass, stack: formatted_stack) unless @permitted_ivars.fetch(klass.name, [].freeze).include?(to_s) visit_symbol_type(name) end @@ -263,14 +293,28 @@ def visit_symbol_type(element) end end - def resolve_symbol_name(element) - case element - when Elements::Symbol - element.name - when Elements::SymbolLink - visit_Gem_SafeMarshal_Elements_SymbolLink(element).to_s - else - raise FormatError, "Expected symbol or symbol link, got #{element.inspect} @ #{@stack.join(".")}" + # This is a hot method, so avoid respond_to? checks on every invocation + if :read.respond_to?(:name) + def resolve_symbol_name(element) + case element + when Elements::Symbol + element.name + when Elements::SymbolLink + visit_Gem_SafeMarshal_Elements_SymbolLink(element).name + else + raise FormatError, "Expected symbol or symbol link, got #{element.inspect} @ #{formatted_stack.join(".")}" + end + end + else + def resolve_symbol_name(element) + case element + when Elements::Symbol + element.name + when Elements::SymbolLink + visit_Gem_SafeMarshal_Elements_SymbolLink(element).to_s + else + raise FormatError, "Expected symbol or symbol link, got #{element.inspect} @ #{formatted_stack.join(".")}" + end end end @@ -287,6 +331,22 @@ def call_method(receiver, method, *args) raise MethodCallError, "Unable to call #{method.inspect} on #{receiver.inspect}, perhaps it is a class using marshal compat, which is not visible in ruby? 
#{e}" end + def formatted_stack + formatted = [] + @stack[0, @stack_idx].each do |e| + if e.is_a?(Integer) + if formatted.last == "ivar_" + formatted[-1] = "ivar_#{e}" + else + formatted << "[#{e}]" + end + else + formatted << e + end + end + formatted + end + class Error < StandardError end diff --git a/test/rubygems/test_gem_safe_marshal.rb b/test/rubygems/test_gem_safe_marshal.rb index d123ad5dc4f0e5..5c73170192334a 100644 --- a/test/rubygems/test_gem_safe_marshal.rb +++ b/test/rubygems/test_gem_safe_marshal.rb @@ -17,7 +17,9 @@ class TestGemSafeMarshal < Gem::TestCase define_method("test_safe_load_marshal Float 30000000.0") { assert_safe_load_marshal "\x04\bf\b3e7" } define_method("test_safe_load_marshal Float -30000000.0") { assert_safe_load_marshal "\x04\bf\t-3e7" } define_method("test_safe_load_marshal Gem::Version #") { assert_safe_load_marshal "\x04\bU:\x11Gem::Version[\x06I\"\n1.abc\x06:\x06ET" } - define_method("test_safe_load_marshal Hash {}") { assert_safe_load_marshal "\x04\b}\x00[\x00" } + define_method("test_safe_load_marshal Hash {} default value") { assert_safe_load_marshal "\x04\b}\x00[\x00", additional_methods: [:default] } + define_method("test_safe_load_marshal Hash {}") { assert_safe_load_marshal "\x04\b{\x00" } + define_method("test_safe_load_marshal Array {}") { assert_safe_load_marshal "\x04\b[\x00" } define_method("test_safe_load_marshal Hash {:runtime=>:development}") { assert_safe_load_marshal "\x04\bI{\x06:\fruntime:\x10development\x06:\n@type[\x00", permitted_ivars: { "Hash" => %w[@type] } } define_method("test_safe_load_marshal Integer -1") { assert_safe_load_marshal "\x04\bi\xFA" } define_method("test_safe_load_marshal Integer -1048575") { assert_safe_load_marshal "\x04\bi\xFD\x01\x00\xF0" } @@ -124,6 +126,12 @@ def test_repeated_symbol assert_safe_load_as [:development, :development] end + def test_length_one_symbols + with_const(Gem::SafeMarshal, :PERMITTED_SYMBOLS, %w[E A b 0] << "") do + assert_safe_load_as [:A, :E, :E, :A, 
"".to_sym, "".to_sym], additional_methods: [:instance_variables] + end + end + def test_repeated_string s = "hello" a = [s] @@ -156,6 +164,12 @@ def test_string_with_encoding String.new("abc", encoding: "Windows-1256"), String.new("abc", encoding: Encoding::BINARY), String.new("abc", encoding: "UTF-32"), + + String.new("", encoding: "US-ASCII"), + String.new("", encoding: "UTF-8"), + String.new("", encoding: "Windows-1256"), + String.new("", encoding: Encoding::BINARY), + String.new("", encoding: "UTF-32"), ].each do |s| assert_safe_load_as s, additional_methods: [:encoding] assert_safe_load_as [s, s], additional_methods: [->(a) { a.map(&:encoding) }] @@ -282,6 +296,20 @@ def test_gem_spec_disallowed_symbol assert_equal e.message, "Attempting to load unpermitted symbol \"rspec\" @ root.[9].[0].@name" end + def test_gem_spec_disallowed_ivar + e = assert_raise(Gem::SafeMarshal::Visitors::ToRuby::UnpermittedIvarError) do + spec = Gem::Specification.new do |s| + s.name = "hi" + s.version = "1.2.3" + + s.dependencies << Gem::Dependency.new("rspec", Gem::Requirement.new([">= 1.2.3"]), :runtime).tap {|d| d.instance_variable_set(:@foobar, "rspec") } + end + Gem::SafeMarshal.safe_load(Marshal.dump(spec)) + end + + assert_equal e.message, "Attempting to set unpermitted ivar \"@foobar\" on object of class Gem::Dependency @ root.[9].[0].ivar_5" + end + def assert_safe_load_marshal(dumped, additional_methods: [], permitted_ivars: nil, equality: true, marshal_dump_equality: true) loaded = Marshal.load(dumped) safe_loaded =