Skip to content

Commit

Permalink
Merge pull request #48241 from Shopify/populate-autoincremented-colum…
Browse files Browse the repository at this point in the history
…n-for-a-model-with-cpk

Assign auto populated columns on Active Record object creation
  • Loading branch information
eileencodes committed Jun 1, 2023
2 parents 64ab34c + c929332 commit 3421e89
Show file tree
Hide file tree
Showing 21 changed files with 173 additions and 26 deletions.
9 changes: 9 additions & 0 deletions activerecord/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
* Assign auto populated columns on Active Record record creation

Changes record creation logic to allow for the `auto_increment` column to be assigned
right after creation regardless of it's relation to model's primary key.
PostgreSQL adapter benefits the most from the change allowing for any number of auto-populated
columns to be assigned on the object immediately after row insertion utilizing the `RETURNING` statement.

*Nikita Vasilevsky*

* Use the first key in the `shards` hash from `connected_to` for the `default_shard`.

Some applications may not want to use `:default` as a shard name in their connection model. Unfortunately Active Record expects there to be a `:default` shard because it must assume a shard to get the right connection from the pool manager. Rather than force applications to manually set this, `connects_to` can infer the default shard name from the hash of shards and will now assume that the first shard is your default.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,11 @@ def exec_query(sql, name = "SQL", binds = [], prepare: false)
# Executes insert +sql+ statement in the context of this connection using
# +binds+ as the bind substitutes. +name+ is logged along with
# the executed +sql+ statement.
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil)
sql, binds = sql_for_insert(sql, pk, binds)
# Some adapters support the `returning` keyword argument which allows to control the result of the query:
# `nil` is the default value and maintains default behavior. If an array of column names is passed -
# the result will contain values of the specified columns from the inserted row.
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil, returning: nil)
sql, binds = sql_for_insert(sql, pk, binds, returning)
internal_exec_query(sql, name, binds)
end

Expand Down Expand Up @@ -180,10 +183,14 @@ def explain(arel, binds = [], options = []) # :nodoc:
#
# If the next id was calculated in advance (as in Oracle), it should be
# passed in as +id_value+.
def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = [])
# Some adapters support the `returning` keyword argument which allows defining the return value of the method:
# `nil` is the default value and maintains default behavior. If an array of column names is passed -
# an array of is returned from the method representing values of the specified columns from the inserted row.
def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = [], returning: nil)
sql, binds = to_sql_and_binds(arel, binds)
value = exec_insert(sql, name, binds, pk, sequence_name)
id_value || last_inserted_id(value)
value = exec_insert(sql, name, binds, pk, sequence_name, returning: returning)
return id_value if id_value
returning.nil? ? last_inserted_id(value) : returning_column_values(value)
end
alias create insert

Expand Down Expand Up @@ -626,14 +633,18 @@ def select(sql, name = nil, binds = [], prepare: false, async: false)
end
end

def sql_for_insert(sql, pk, binds)
def sql_for_insert(sql, _pk, binds, _returning)
[sql, binds]
end

def last_inserted_id(result)
single_value_from_rows(result.rows)
end

def returning_column_values(result)
[last_inserted_id(result)]
end

def single_value_from_rows(rows)
row = rows.first
row && row.first
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,9 @@ def index_exists?(table_name, column_name, **options)
# Returns an array of +Column+ objects for the table specified by +table_name+.
def columns(table_name)
table_name = table_name.to_s
column_definitions(table_name).map do |field|
new_column_from_field(table_name, field)
definitions = column_definitions(table_name)
definitions.map do |field|
new_column_from_field(table_name, field, definitions)
end
end

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,10 @@ def supports_concurrent_connections?
true
end

def return_value_after_insert?(column) # :nodoc:
column.auto_incremented_by_db?
end

def async_enabled? # :nodoc:
supports_concurrent_connections? &&
!ActiveRecord.async_query_executor.nil? && !pool.async_executor.nil?
Expand Down
9 changes: 9 additions & 0 deletions activerecord/lib/active_record/connection_adapters/column.rb
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ def encode_with(coder)
coder["comment"] = @comment
end

# whether the column is auto-populated by the database using a sequence
def auto_incremented_by_db?
false
end

def auto_populated?
auto_incremented_by_db? || default_function
end

def ==(other)
other.is_a?(Column) &&
name == other.name &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def case_sensitive?
def auto_increment?
extra == "auto_increment"
end
alias_method :auto_incremented_by_db?, :auto_increment?

def virtual?
/\b(?:VIRTUAL|STORED|PERSISTENT)\b/.match?(extra)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def default_type(table_name, field_name)
end
end

def new_column_from_field(table_name, field)
def new_column_from_field(table_name, field, _definitions)
field_name = field.fetch(:Field)
type_metadata = fetch_type_metadata(field[:Type], field[:Extra])
default, default_function = field[:Default], nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def initialize(*, serial: nil, generated: nil, **)
def serial?
@serial
end
alias_method :auto_incremented_by_db?, :serial?

def virtual?
# We assume every generated column is virtual, no matter the concrete type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,23 @@ def exec_delete(sql, name = nil, binds = []) # :nodoc:
end
alias :exec_update :exec_delete

def sql_for_insert(sql, pk, binds) # :nodoc:
def sql_for_insert(sql, pk, binds, returning) # :nodoc:
if pk.nil?
# Extract the table from the insert sql. Yuck.
table_ref = extract_table_ref_from_insert_sql(sql)
pk = primary_key(table_ref) if table_ref
end

if pk = suppress_composite_primary_key(pk)
sql = "#{sql} RETURNING #{quote_column_name(pk)}"
end
returning_columns = returning || Array(pk)

returning_columns_statement = returning_columns.map { |c| quote_column_name(c) }.join(", ")
sql = "#{sql} RETURNING #{returning_columns_statement}" if returning_columns.any?

super
end
private :sql_for_insert

def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil) # :nodoc:
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil, returning: nil) # :nodoc:
if use_insert_returning? || pk == false
super
else
Expand Down Expand Up @@ -172,6 +173,10 @@ def last_insert_id_result(sequence_name)
internal_exec_query("SELECT currval(#{quote(sequence_name)})", "SQL")
end

def returning_column_values(result)
result.rows.first
end

def suppress_composite_primary_key(pk)
pk unless pk.is_a?(Array)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ def create_alter_table(name)
PostgreSQL::AlterTable.new create_table_definition(name)
end

def new_column_from_field(table_name, field)
def new_column_from_field(table_name, field, _definitions)
column_name, type, default, notnull, oid, fmod, collation, comment, attgenerated = field
type_metadata = fetch_type_metadata(column_name, type, oid.to_i, fmod.to_i)
default_value = extract_value_from_default(default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ def index_algorithms
{ concurrently: "CONCURRENTLY" }
end

def return_value_after_insert?(column) # :nodoc:
column.auto_populated?
end

class StatementPool < ConnectionAdapters::StatementPool # :nodoc:
def initialize(connection, max)
super(max)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,22 @@ module ActiveRecord
module ConnectionAdapters
module SQLite3
class Column < ConnectionAdapters::Column # :nodoc:
def initialize(*, auto_increment: nil, **)
attr_reader :rowid

def initialize(*, auto_increment: nil, rowid: false, **)
super
@auto_increment = auto_increment
@rowid = rowid
end

def auto_increment?
@auto_increment
end

def auto_incremented_by_db?
auto_increment? || rowid
end

def init_with(coder)
@auto_increment = coder["auto_increment"]
super
Expand All @@ -33,7 +40,8 @@ def ==(other)
def hash
Column.hash ^
super.hash ^
auto_increment?.hash
auto_increment?.hash ^
rowid.hash
end
end
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,13 @@ def validate_index_length!(table_name, new_name, internal = false)
super unless internal
end

def new_column_from_field(table_name, field)
def new_column_from_field(table_name, field, definitions)
default = field["dflt_value"]

type_metadata = fetch_type_metadata(field["type"])
default_value = extract_value_from_default(default)
default_function = extract_default_function(default_value, default)
rowid = is_column_the_rowid?(field, definitions)

Column.new(
field["name"],
Expand All @@ -147,9 +148,20 @@ def new_column_from_field(table_name, field)
default_function,
collation: field["collation"],
auto_increment: field["auto_increment"],
rowid: rowid
)
end

INTEGER_REGEX = /integer/i
# if a rowid table has a primary key that consists of a single column
# and the declared type of that column is "INTEGER" in any mixture of upper and lower case,
# then the column becomes an alias for the rowid.
def is_column_the_rowid?(field, column_definitions)
return false unless INTEGER_REGEX.match?(field["type"]) && field["pk"] == 1
# is the primary key a single column?
column_definitions.one? { |c| c["pk"] > 0 }
end

def data_source_sql(name = nil, type: nil)
scope = quoted_scope(name, type: type)
scope[:type] ||= "'table','view'"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: fa
ActiveRecord::Result.new(result.fields, result.to_a)
end

def exec_insert(sql, name, binds, pk = nil, sequence_name = nil) # :nodoc:
def exec_insert(sql, name, binds, pk = nil, sequence_name = nil, returning: nil) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)
mark_transaction_written_if_write(sql)
Expand Down
7 changes: 7 additions & 0 deletions activerecord/lib/active_record/model_schema.rb
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,12 @@ def columns
@columns ||= columns_hash.values.freeze
end

def _returning_columns_for_insert # :nodoc:
@_returning_columns_for_insert ||= columns.filter_map do |c|
c.name if connection.return_value_after_insert?(c)
end
end

def attribute_types # :nodoc:
load_schema
@attribute_types ||= Hash.new(Type.default_value)
Expand Down Expand Up @@ -546,6 +552,7 @@ def initialize_load_schema_monitor
end

def reload_schema_from_cache(recursive = true)
@_returning_columns_for_insert = nil
@arel_table = nil
@column_names = nil
@symbol_column_to_string_name_hash = nil
Expand Down
18 changes: 13 additions & 5 deletions activerecord/lib/active_record/persistence.rb
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def delete(id_or_array)
delete_by(primary_key => id_or_array)
end

def _insert_record(values) # :nodoc:
def _insert_record(values, returning) # :nodoc:
primary_key = self.primary_key
primary_key_value = nil

Expand All @@ -580,7 +580,10 @@ def _insert_record(values) # :nodoc:
im.insert(values.transform_keys { |name| arel_table[name] })
end

connection.insert(im, "#{self} Create", primary_key || false, primary_key_value)
connection.insert(
im, "#{self} Create", primary_key || false, primary_key_value,
returning: returning
)
end

def _update_record(values, constraints) # :nodoc:
Expand Down Expand Up @@ -1235,11 +1238,16 @@ def _update_record(attribute_names = self.attribute_names)
def _create_record(attribute_names = self.attribute_names)
attribute_names = attributes_for_create(attribute_names)

new_id = self.class._insert_record(
attributes_with_values(attribute_names)
returning_columns = self.class._returning_columns_for_insert

returning_values = self.class._insert_record(
attributes_with_values(attribute_names),
returning_columns
)

self.id ||= new_id if @primary_key
returning_columns.zip(returning_values).each do |column, value|
_write_attribute(column, value) if !_read_attribute(column)
end if returning_values

@new_record = false
@previously_new_record = true
Expand Down
1 change: 1 addition & 0 deletions activerecord/test/cases/adapters/postgresql/uuid_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class UUIDType < ActiveRecord::Base
end

teardown do
UUIDType.reset_column_information
drop_table "uuid_data_type"
end

Expand Down
36 changes: 36 additions & 0 deletions activerecord/test/cases/adapters/sqlite3/sqlite3_adapter_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,42 @@ def test_strict_strings_by_default_and_false_in_database_yml
end
end

def test_rowid_column
with_example_table "id_uppercase INTEGER PRIMARY KEY" do
assert @conn.columns("ex").index_by(&:name)["id_uppercase"].rowid
end
end

def test_lowercase_rowid_column
with_example_table "id_lowercase integer PRIMARY KEY" do
assert @conn.columns("ex").index_by(&:name)["id_lowercase"].rowid
end
end

def test_non_integer_column_returns_false_for_rowid
with_example_table "id_int_short int PRIMARY KEY" do
assert_not @conn.columns("ex").index_by(&:name)["id_int_short"].rowid
end
end

def test_mixed_case_integer_colum_returns_true_for_rowid
with_example_table "id_mixed_case InTeGeR PRIMARY KEY" do
assert @conn.columns("ex").index_by(&:name)["id_mixed_case"].rowid
end
end

def test_rowid_column_with_autoincrement_returns_true_for_rowid
with_example_table "id_autoincrement integer PRIMARY KEY AUTOINCREMENT" do
assert @conn.columns("ex").index_by(&:name)["id_autoincrement"].rowid
end
end

def test_integer_cpk_column_returns_false_for_rowid
with_example_table("id integer, shop_id integer, PRIMARY KEY (shop_id, id)", "cpk_table") do
assert_not @conn.columns("cpk_table").any?(&:rowid)
end
end

private
def assert_logged(logs)
subscriber = SQLSubscriber.new
Expand Down
Loading

0 comments on commit 3421e89

Please sign in to comment.