Skip to content

Commit

Permalink
Merge pull request #48241 from Shopify/populate-autoincremented-colum…
Browse files Browse the repository at this point in the history
…n-for-a-model-with-cpk

Assign auto populated columns on Active Record object creation
  • Loading branch information
eileencodes committed Jun 1, 2023
2 parents 64ab34c + c929332 commit 3421e89
Show file tree
Hide file tree
Showing 21 changed files with 173 additions and 26 deletions.
9 changes: 9 additions & 0 deletions activerecord/CHANGELOG.md
@@ -1,3 +1,12 @@
* Assign auto populated columns on Active Record record creation

Changes record creation logic to allow for the `auto_increment` column to be assigned
right after creation regardless of it's relation to model's primary key.
PostgreSQL adapter benefits the most from the change allowing for any number of auto-populated
columns to be assigned on the object immediately after row insertion utilizing the `RETURNING` statement.

*Nikita Vasilevsky*

* Use the first key in the `shards` hash from `connected_to` for the `default_shard`.

Some applications may not want to use `:default` as a shard name in their connection model. Unfortunately Active Record expects there to be a `:default` shard because it must assume a shard to get the right connection from the pool manager. Rather than force applications to manually set this, `connects_to` can infer the default shard name from the hash of shards and will now assume that the first shard is your default.
Expand Down
Expand Up @@ -145,8 +145,11 @@ def exec_query(sql, name = "SQL", binds = [], prepare: false)
# Executes insert +sql+ statement in the context of this connection using
# +binds+ as the bind substitutes. +name+ is logged along with
# the executed +sql+ statement.
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil)
sql, binds = sql_for_insert(sql, pk, binds)
# Some adapters support the `returning` keyword argument which allows to control the result of the query:
# `nil` is the default value and maintains default behavior. If an array of column names is passed -
# the result will contain values of the specified columns from the inserted row.
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil, returning: nil)
sql, binds = sql_for_insert(sql, pk, binds, returning)
internal_exec_query(sql, name, binds)
end

Expand Down Expand Up @@ -180,10 +183,14 @@ def explain(arel, binds = [], options = []) # :nodoc:
#
# If the next id was calculated in advance (as in Oracle), it should be
# passed in as +id_value+.
def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = [])
# Some adapters support the `returning` keyword argument which allows defining the return value of the method:
# `nil` is the default value and maintains default behavior. If an array of column names is passed -
# an array of is returned from the method representing values of the specified columns from the inserted row.
def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = [], returning: nil)
sql, binds = to_sql_and_binds(arel, binds)
value = exec_insert(sql, name, binds, pk, sequence_name)
id_value || last_inserted_id(value)
value = exec_insert(sql, name, binds, pk, sequence_name, returning: returning)
return id_value if id_value
returning.nil? ? last_inserted_id(value) : returning_column_values(value)
end
alias create insert

Expand Down Expand Up @@ -626,14 +633,18 @@ def select(sql, name = nil, binds = [], prepare: false, async: false)
end
end

def sql_for_insert(sql, pk, binds)
def sql_for_insert(sql, _pk, binds, _returning)
[sql, binds]
end

def last_inserted_id(result)
single_value_from_rows(result.rows)
end

def returning_column_values(result)
[last_inserted_id(result)]
end

def single_value_from_rows(rows)
row = rows.first
row && row.first
Expand Down
Expand Up @@ -106,8 +106,9 @@ def index_exists?(table_name, column_name, **options)
# Returns an array of +Column+ objects for the table specified by +table_name+.
def columns(table_name)
table_name = table_name.to_s
column_definitions(table_name).map do |field|
new_column_from_field(table_name, field)
definitions = column_definitions(table_name)
definitions.map do |field|
new_column_from_field(table_name, field, definitions)
end
end

Expand Down
Expand Up @@ -580,6 +580,10 @@ def supports_concurrent_connections?
true
end

def return_value_after_insert?(column) # :nodoc:
column.auto_incremented_by_db?
end

def async_enabled? # :nodoc:
supports_concurrent_connections? &&
!ActiveRecord.async_query_executor.nil? && !pool.async_executor.nil?
Expand Down
9 changes: 9 additions & 0 deletions activerecord/lib/active_record/connection_adapters/column.rb
Expand Up @@ -63,6 +63,15 @@ def encode_with(coder)
coder["comment"] = @comment
end

# whether the column is auto-populated by the database using a sequence
def auto_incremented_by_db?
false
end

def auto_populated?
auto_incremented_by_db? || default_function
end

def ==(other)
other.is_a?(Column) &&
name == other.name &&
Expand Down
Expand Up @@ -17,6 +17,7 @@ def case_sensitive?
def auto_increment?
extra == "auto_increment"
end
alias_method :auto_incremented_by_db?, :auto_increment?

def virtual?
/\b(?:VIRTUAL|STORED|PERSISTENT)\b/.match?(extra)
Expand Down
Expand Up @@ -175,7 +175,7 @@ def default_type(table_name, field_name)
end
end

def new_column_from_field(table_name, field)
def new_column_from_field(table_name, field, _definitions)
field_name = field.fetch(:Field)
type_metadata = fetch_type_metadata(field[:Type], field[:Extra])
default, default_function = field[:Default], nil
Expand Down
Expand Up @@ -15,6 +15,7 @@ def initialize(*, serial: nil, generated: nil, **)
def serial?
@serial
end
alias_method :auto_incremented_by_db?, :serial?

def virtual?
# We assume every generated column is virtual, no matter the concrete type
Expand Down
Expand Up @@ -75,22 +75,23 @@ def exec_delete(sql, name = nil, binds = []) # :nodoc:
end
alias :exec_update :exec_delete

def sql_for_insert(sql, pk, binds) # :nodoc:
def sql_for_insert(sql, pk, binds, returning) # :nodoc:
if pk.nil?
# Extract the table from the insert sql. Yuck.
table_ref = extract_table_ref_from_insert_sql(sql)
pk = primary_key(table_ref) if table_ref
end

if pk = suppress_composite_primary_key(pk)
sql = "#{sql} RETURNING #{quote_column_name(pk)}"
end
returning_columns = returning || Array(pk)

returning_columns_statement = returning_columns.map { |c| quote_column_name(c) }.join(", ")
sql = "#{sql} RETURNING #{returning_columns_statement}" if returning_columns.any?

super
end
private :sql_for_insert

def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil) # :nodoc:
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil, returning: nil) # :nodoc:
if use_insert_returning? || pk == false
super
else
Expand Down Expand Up @@ -172,6 +173,10 @@ def last_insert_id_result(sequence_name)
internal_exec_query("SELECT currval(#{quote(sequence_name)})", "SQL")
end

def returning_column_values(result)
result.rows.first
end

def suppress_composite_primary_key(pk)
pk unless pk.is_a?(Array)
end
Expand Down
Expand Up @@ -897,7 +897,7 @@ def create_alter_table(name)
PostgreSQL::AlterTable.new create_table_definition(name)
end

def new_column_from_field(table_name, field)
def new_column_from_field(table_name, field, _definitions)
column_name, type, default, notnull, oid, fmod, collation, comment, attgenerated = field
type_metadata = fetch_type_metadata(column_name, type, oid.to_i, fmod.to_i)
default_value = extract_value_from_default(default)
Expand Down
Expand Up @@ -281,6 +281,10 @@ def index_algorithms
{ concurrently: "CONCURRENTLY" }
end

def return_value_after_insert?(column) # :nodoc:
column.auto_populated?
end

class StatementPool < ConnectionAdapters::StatementPool # :nodoc:
def initialize(connection, max)
super(max)
Expand Down
Expand Up @@ -4,15 +4,22 @@ module ActiveRecord
module ConnectionAdapters
module SQLite3
class Column < ConnectionAdapters::Column # :nodoc:
def initialize(*, auto_increment: nil, **)
attr_reader :rowid

def initialize(*, auto_increment: nil, rowid: false, **)
super
@auto_increment = auto_increment
@rowid = rowid
end

def auto_increment?
@auto_increment
end

def auto_incremented_by_db?
auto_increment? || rowid
end

def init_with(coder)
@auto_increment = coder["auto_increment"]
super
Expand All @@ -33,7 +40,8 @@ def ==(other)
def hash
Column.hash ^
super.hash ^
auto_increment?.hash
auto_increment?.hash ^
rowid.hash
end
end
end
Expand Down
Expand Up @@ -132,12 +132,13 @@ def validate_index_length!(table_name, new_name, internal = false)
super unless internal
end

def new_column_from_field(table_name, field)
def new_column_from_field(table_name, field, definitions)
default = field["dflt_value"]

type_metadata = fetch_type_metadata(field["type"])
default_value = extract_value_from_default(default)
default_function = extract_default_function(default_value, default)
rowid = is_column_the_rowid?(field, definitions)

Column.new(
field["name"],
Expand All @@ -147,9 +148,20 @@ def new_column_from_field(table_name, field)
default_function,
collation: field["collation"],
auto_increment: field["auto_increment"],
rowid: rowid
)
end

INTEGER_REGEX = /integer/i
# if a rowid table has a primary key that consists of a single column
# and the declared type of that column is "INTEGER" in any mixture of upper and lower case,
# then the column becomes an alias for the rowid.
def is_column_the_rowid?(field, column_definitions)
return false unless INTEGER_REGEX.match?(field["type"]) && field["pk"] == 1
# is the primary key a single column?
column_definitions.one? { |c| c["pk"] > 0 }
end

def data_source_sql(name = nil, type: nil)
scope = quoted_scope(name, type: type)
scope[:type] ||= "'table','view'"
Expand Down
Expand Up @@ -21,7 +21,7 @@ def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: fa
ActiveRecord::Result.new(result.fields, result.to_a)
end

def exec_insert(sql, name, binds, pk = nil, sequence_name = nil) # :nodoc:
def exec_insert(sql, name, binds, pk = nil, sequence_name = nil, returning: nil) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)
mark_transaction_written_if_write(sql)
Expand Down
7 changes: 7 additions & 0 deletions activerecord/lib/active_record/model_schema.rb
Expand Up @@ -422,6 +422,12 @@ def columns
@columns ||= columns_hash.values.freeze
end

def _returning_columns_for_insert # :nodoc:
@_returning_columns_for_insert ||= columns.filter_map do |c|
c.name if connection.return_value_after_insert?(c)
end
end

def attribute_types # :nodoc:
load_schema
@attribute_types ||= Hash.new(Type.default_value)
Expand Down Expand Up @@ -546,6 +552,7 @@ def initialize_load_schema_monitor
end

def reload_schema_from_cache(recursive = true)
@_returning_columns_for_insert = nil
@arel_table = nil
@column_names = nil
@symbol_column_to_string_name_hash = nil
Expand Down
18 changes: 13 additions & 5 deletions activerecord/lib/active_record/persistence.rb
Expand Up @@ -561,7 +561,7 @@ def delete(id_or_array)
delete_by(primary_key => id_or_array)
end

def _insert_record(values) # :nodoc:
def _insert_record(values, returning) # :nodoc:
primary_key = self.primary_key
primary_key_value = nil

Expand All @@ -580,7 +580,10 @@ def _insert_record(values) # :nodoc:
im.insert(values.transform_keys { |name| arel_table[name] })
end

connection.insert(im, "#{self} Create", primary_key || false, primary_key_value)
connection.insert(
im, "#{self} Create", primary_key || false, primary_key_value,
returning: returning
)
end

def _update_record(values, constraints) # :nodoc:
Expand Down Expand Up @@ -1235,11 +1238,16 @@ def _update_record(attribute_names = self.attribute_names)
def _create_record(attribute_names = self.attribute_names)
attribute_names = attributes_for_create(attribute_names)

new_id = self.class._insert_record(
attributes_with_values(attribute_names)
returning_columns = self.class._returning_columns_for_insert

returning_values = self.class._insert_record(
attributes_with_values(attribute_names),
returning_columns
)

self.id ||= new_id if @primary_key
returning_columns.zip(returning_values).each do |column, value|
_write_attribute(column, value) if !_read_attribute(column)
end if returning_values

@new_record = false
@previously_new_record = true
Expand Down
1 change: 1 addition & 0 deletions activerecord/test/cases/adapters/postgresql/uuid_test.rb
Expand Up @@ -39,6 +39,7 @@ class UUIDType < ActiveRecord::Base
end

teardown do
UUIDType.reset_column_information
drop_table "uuid_data_type"
end

Expand Down
36 changes: 36 additions & 0 deletions activerecord/test/cases/adapters/sqlite3/sqlite3_adapter_test.rb
Expand Up @@ -710,6 +710,42 @@ def test_strict_strings_by_default_and_false_in_database_yml
end
end

def test_rowid_column
with_example_table "id_uppercase INTEGER PRIMARY KEY" do
assert @conn.columns("ex").index_by(&:name)["id_uppercase"].rowid
end
end

def test_lowercase_rowid_column
with_example_table "id_lowercase integer PRIMARY KEY" do
assert @conn.columns("ex").index_by(&:name)["id_lowercase"].rowid
end
end

def test_non_integer_column_returns_false_for_rowid
with_example_table "id_int_short int PRIMARY KEY" do
assert_not @conn.columns("ex").index_by(&:name)["id_int_short"].rowid
end
end

def test_mixed_case_integer_colum_returns_true_for_rowid
with_example_table "id_mixed_case InTeGeR PRIMARY KEY" do
assert @conn.columns("ex").index_by(&:name)["id_mixed_case"].rowid
end
end

def test_rowid_column_with_autoincrement_returns_true_for_rowid
with_example_table "id_autoincrement integer PRIMARY KEY AUTOINCREMENT" do
assert @conn.columns("ex").index_by(&:name)["id_autoincrement"].rowid
end
end

def test_integer_cpk_column_returns_false_for_rowid
with_example_table("id integer, shop_id integer, PRIMARY KEY (shop_id, id)", "cpk_table") do
assert_not @conn.columns("cpk_table").any?(&:rowid)
end
end

private
def assert_logged(logs)
subscriber = SQLSubscriber.new
Expand Down

0 comments on commit 3421e89

Please sign in to comment.