Skip to content

Commit e949bea

Browse files
committed
Fixing change_column logic that was causing two unit tests to fail
1 parent 15b753a commit e949bea

File tree

1 file changed

+42
-33
lines changed

1 file changed

+42
-33
lines changed

lib/active_record/connection_adapters/sqlserver/schema_statements.rb

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module ActiveRecord
22
module ConnectionAdapters
33
module Sqlserver
44
module SchemaStatements
5-
5+
66
def native_database_types
77
@native_database_types ||= initialize_native_database_types.freeze
88
end
@@ -43,11 +43,11 @@ def columns(table_name, name = nil)
4343
SQLServerColumn.new ci[:name], ci[:default_value], ci[:type], ci[:null], sqlserver_options
4444
end
4545
end
46-
46+
4747
def rename_table(table_name, new_name)
4848
do_execute "EXEC sp_rename '#{table_name}', '#{new_name}'"
4949
end
50-
50+
5151
def remove_column(table_name, *column_names)
5252
raise ArgumentError.new("You must specify at least one column name. Example: remove_column(:people, :first_name)") if column_names.empty?
5353
ActiveSupport::Deprecation.warn 'Passing array to remove_columns is deprecated, please use multiple arguments, like: `remove_columns(:posts, :foo, :bar)`', caller if column_names.flatten!
@@ -61,16 +61,25 @@ def remove_column(table_name, *column_names)
6161

6262
def change_column(table_name, column_name, type, options = {})
6363
sql_commands = []
64+
indexes = []
6465
column_object = schema_cache.columns[table_name].detect { |c| c.name.to_s == column_name.to_s }
65-
change_column_sql = "ALTER TABLE #{quote_table_name(table_name)} ALTER COLUMN #{quote_column_name(column_name)} #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}"
66-
change_column_sql << " NOT NULL" if options[:null] == false
67-
sql_commands << change_column_sql
66+
6867
if options_include_default?(options) || (column_object && column_object.type != type.to_sym)
69-
remove_default_constraint(table_name,column_name)
68+
remove_default_constraint(table_name,column_name)
69+
indexes = indexes(table_name).select{ |index| index.columns.include?(column_name.to_s) }
70+
remove_indexes(table_name, column_name)
7071
end
72+
sql_commands << "UPDATE #{quote_table_name(table_name)} SET #{quote_column_name(column_name)}=#{quote(options[:default])} WHERE #{quote_column_name(column_name)} IS NULL" if !options[:null].nil? && options[:null] == false && !options[:default].nil?
73+
sql_commands << "ALTER TABLE #{quote_table_name(table_name)} ALTER COLUMN #{quote_column_name(column_name)} #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}"
74+
sql_commands[-1] << " NOT NULL" if !options[:null].nil? && options[:null] == false
7175
if options_include_default?(options)
7276
sql_commands << "ALTER TABLE #{quote_table_name(table_name)} ADD CONSTRAINT #{default_constraint_name(table_name,column_name)} DEFAULT #{quote(options[:default])} FOR #{quote_column_name(column_name)}"
7377
end
78+
79+
#Add any removed indexes back
80+
indexes.each do |index|
81+
sql_commands << "CREATE INDEX #{quote_table_name(index.name)} ON #{quote_table_name(table_name)} (#{index.columns.collect {|c|quote_column_name(c)}.join(', ')})"
82+
end
7483
sql_commands.each { |c| do_execute(c) }
7584
end
7685

@@ -83,7 +92,7 @@ def rename_column(table_name, column_name, new_column_name)
8392
detect_column_for! table_name, column_name
8493
do_execute "EXEC sp_rename '#{table_name}.#{column_name}', '#{new_column_name}', 'COLUMN'"
8594
end
86-
95+
8796
def remove_index!(table_name, index_name)
8897
do_execute "DROP INDEX #{quote_column_name(index_name)} ON #{quote_table_name(table_name)}"
8998
end
@@ -104,27 +113,27 @@ def type_to_sql(type, limit = nil, precision = nil, scale = nil)
104113
end
105114
end
106115

107-
def change_column_null(table_name, column_name, null, default = nil)
116+
def change_column_null(table_name, column_name, allow_null, default = nil)
108117
column = detect_column_for! table_name, column_name
109-
unless null || default.nil?
118+
if !allow_null.nil? && allow_null == false && !default.nil?
110119
do_execute("UPDATE #{quote_table_name(table_name)} SET #{quote_column_name(column_name)}=#{quote(default)} WHERE #{quote_column_name(column_name)} IS NULL")
111120
end
112121
sql = "ALTER TABLE #{table_name} ALTER COLUMN #{quote_column_name(column_name)} #{type_to_sql column.type, column.limit, column.precision, column.scale}"
113-
sql << " NOT NULL" unless null
122+
sql << " NOT NULL" if !allow_null.nil? && allow_null == false
114123
do_execute sql
115124
end
116-
125+
117126
# === SQLServer Specific ======================================== #
118-
127+
119128
def views
120129
tables('VIEW')
121130
end
122-
123-
131+
132+
124133
protected
125-
134+
126135
# === SQLServer Specific ======================================== #
127-
136+
128137
def initialize_native_database_types
129138
{
130139
:primary_key => "int NOT NULL IDENTITY(1,1) PRIMARY KEY",
@@ -156,7 +165,7 @@ def column_definitions(table_name)
156165
table_schema = Utils.unqualify_table_schema(table_name)
157166
table_name = Utils.unqualify_table_name(table_name)
158167
sql = %{
159-
SELECT DISTINCT
168+
SELECT DISTINCT
160169
#{lowercase_schema_reflection_sql('columns.TABLE_NAME')} AS table_name,
161170
#{lowercase_schema_reflection_sql('columns.COLUMN_NAME')} AS name,
162171
columns.DATA_TYPE AS type,
@@ -172,7 +181,7 @@ def column_definitions(table_name)
172181
WHEN columns.IS_NULLABLE = 'YES' THEN 1
173182
ELSE NULL
174183
END AS [is_nullable],
175-
CASE
184+
CASE
176185
WHEN KCU.COLUMN_NAME IS NOT NULL AND TC.CONSTRAINT_TYPE = N'PRIMARY KEY' THEN 1
177186
ELSE NULL
178187
END AS [is_primary],
@@ -240,7 +249,7 @@ def column_definitions(table_name)
240249
ci
241250
end
242251
end
243-
252+
244253
def remove_check_constraints(table_name, column_name)
245254
constraints = select_values "SELECT CONSTRAINT_NAME FROM INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE where TABLE_NAME = '#{quote_string(table_name)}' and COLUMN_NAME = '#{quote_string(column_name)}'", 'SCHEMA'
246255
constraints.each do |constraint|
@@ -262,9 +271,9 @@ def remove_indexes(table_name, column_name)
262271
remove_index(table_name, {:name => index.name})
263272
end
264273
end
265-
274+
266275
# === SQLServer Specific (Misc Helpers) ========================= #
267-
276+
268277
def get_table_name(sql)
269278
if sql =~ /^\s*(INSERT|EXEC sp_executesql N'INSERT)\s+INTO\s+([^\(\s]+)\s*|^\s*update\s+([^\(\s]+)\s*/i
270279
$2 || $3
@@ -274,29 +283,29 @@ def get_table_name(sql)
274283
nil
275284
end
276285
end
277-
286+
278287
def default_constraint_name(table_name, column_name)
279288
"DF_#{table_name}_#{column_name}"
280289
end
281-
290+
282291
def detect_column_for!(table_name, column_name)
283292
unless column = schema_cache.columns[table_name].detect { |c| c.name == column_name.to_s }
284293
raise ActiveRecordError, "No such column: #{table_name}.#{column_name}"
285294
end
286295
column
287296
end
288-
297+
289298
def lowercase_schema_reflection_sql(node)
290299
lowercase_schema_reflection ? "LOWER(#{node})" : node
291300
end
292-
301+
293302
# === SQLServer Specific (View Reflection) ====================== #
294-
303+
295304
def view_table_name(table_name)
296305
view_info = schema_cache.view_information(table_name)
297306
view_info ? get_table_name(view_info['VIEW_DEFINITION']) : table_name
298307
end
299-
308+
300309
def view_information(table_name)
301310
table_name = Utils.unqualify_table_name(table_name)
302311
view_info = select_one "SELECT * FROM INFORMATION_SCHEMA.VIEWS WHERE TABLE_NAME = '#{table_name}'", 'SCHEMA'
@@ -313,18 +322,18 @@ def view_information(table_name)
313322
end
314323
view_info
315324
end
316-
325+
317326
def table_name_or_views_table_name(table_name)
318327
unquoted_table_name = Utils.unqualify_table_name(table_name)
319328
schema_cache.view_names.include?(unquoted_table_name) ? view_table_name(unquoted_table_name) : unquoted_table_name
320329
end
321-
330+
322331
def views_real_column_name(table_name,column_name)
323332
view_definition = schema_cache.view_information(table_name)[:VIEW_DEFINITION]
324333
match_data = view_definition.match(/([\w-]*)\s+as\s+#{column_name}/im)
325334
match_data ? match_data[1] : column_name
326335
end
327-
336+
328337
# === SQLServer Specific (Identity Inserts) ===================== #
329338

330339
def query_requires_identity_insert?(sql)
@@ -336,11 +345,11 @@ def query_requires_identity_insert?(sql)
336345
false
337346
end
338347
end
339-
348+
340349
def insert_sql?(sql)
341350
!(sql =~ /^\s*(INSERT|EXEC sp_executesql N'INSERT)/i).nil?
342351
end
343-
352+
344353
def with_identity_insert_enabled(table_name)
345354
table_name = quote_table_name(table_name_or_views_table_name(table_name))
346355
set_identity_insert(table_name, true)

0 commit comments

Comments
 (0)