Skip to content

Commit

Permalink
Merge pull request #45625 from adrianna-chang-shopify/ac-extract-buil…
Browse files Browse the repository at this point in the history
…d-create-table-definition

Extract `#build_create_table_definition` method
  • Loading branch information
eileencodes committed Jul 20, 2022
2 parents b5a758d + d747b0f commit 69078b0
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 21 deletions.
Expand Up @@ -66,7 +66,7 @@ def visit_TableDefinition(o)
create_sql << "(#{statements.join(', ')})" if statements.present?
add_table_options!(create_sql, o)
create_sql << " AS #{to_sql(o.as)}" if o.as
create_sql
o.ddl = create_sql
end

def visit_PrimaryKeyDefinition(o)
Expand Down
Expand Up @@ -333,6 +333,7 @@ class TableDefinition
include ColumnMethods

attr_reader :name, :temporary, :if_not_exists, :options, :as, :comment, :indexes, :foreign_keys, :check_constraints
attr_accessor :ddl

def initialize(
conn,
Expand All @@ -358,6 +359,23 @@ def initialize(
@comment = comment
end

def set_primary_key(table_name, id, primary_key, **options)
if id && !as
pk = primary_key || Base.get_primary_key(table_name.to_s.singularize)

if id.is_a?(Hash)
options.merge!(id.except(:type))
id = id.fetch(:type, :primary_key)
end

if pk.is_a?(Array)
primary_keys(pk)
else
primary_key(pk, id, **options)
end
end
end

def primary_keys(name = nil) # :nodoc:
@primary_keys = PrimaryKeyDefinition.new(name) if name
@primary_keys
Expand Down
Expand Up @@ -289,34 +289,17 @@ def primary_key(table_name)
# SELECT * FROM orders INNER JOIN line_items ON order_id=orders.id
#
# See also TableDefinition#column for details on how to create columns.
def create_table(table_name, id: :primary_key, primary_key: nil, force: nil, **options)
def create_table(table_name, id: :primary_key, primary_key: nil, force: nil, **options, &block)
validate_table_length!(table_name) unless options[:_uses_legacy_table_name]
td = create_table_definition(table_name, **extract_table_options!(options))

if id && !td.as
pk = primary_key || Base.get_primary_key(table_name.to_s.singularize)

if id.is_a?(Hash)
options.merge!(id.except(:type))
id = id.fetch(:type, :primary_key)
end

if pk.is_a?(Array)
td.primary_keys pk
else
td.primary_key pk, id, **options
end
end

yield td if block_given?
td = build_create_table_definition(table_name, id: id, primary_key: primary_key, force: force, **options, &block)

if force
drop_table(table_name, force: force, if_exists: true)
else
schema_cache.clear_data_source_cache!(table_name.to_s)
end

result = execute schema_creation.accept td
result = execute(td.ddl)

unless supports_indexes_in_create?
td.indexes.each do |column_name, index_options|
Expand All @@ -337,6 +320,19 @@ def create_table(table_name, id: :primary_key, primary_key: nil, force: nil, **o
result
end

# Returns a TableDefinition object containing information about the table that would be created
# if the same arguments were passed to #create_table. See #create_table for information about
# passing a +table_name+, and other additional options that can be passed.
def build_create_table_definition(table_name, id: :primary_key, primary_key: nil, force: nil, **options)
table_definition = create_table_definition(table_name, **extract_table_options!(options))
table_definition.set_primary_key(table_name, id, primary_key, **options)

yield table_definition if block_given?

schema_creation.accept(table_definition)
table_definition
end

# Creates a new join table with the name created using the lexical order of the first two
# arguments. These arguments can be a String or a Symbol.
#
Expand Down
40 changes: 40 additions & 0 deletions activerecord/test/cases/migration/schema_definitions_test.rb
@@ -0,0 +1,40 @@
# frozen_string_literal: true

require "cases/helper"

module ActiveRecord
class Migration
class SchemaDefinitionsTest < ActiveRecord::TestCase
attr_reader :connection

def setup
@connection = ActiveRecord::Base.connection
end

def test_build_create_table_definition_with_block
td = connection.build_create_table_definition :test do |t|
t.column :foo, :string
end

id_column = td.columns.find { |col| col.name == "id" }
assert_predicate id_column, :present?
assert id_column.type
assert id_column.sql_type

foo_column = td.columns.find { |col| col.name == "foo" }
assert_predicate foo_column, :present?
assert foo_column.type
assert foo_column.sql_type
end

def test_build_create_table_definition_without_block
td = connection.build_create_table_definition(:test)

id_column = td.columns.find { |col| col.name == "id" }
assert_predicate id_column, :present?
assert id_column.type
assert id_column.sql_type
end
end
end
end

0 comments on commit 69078b0

Please sign in to comment.