diff --git a/QueryBuilder.Tests/GeneralTests.cs b/QueryBuilder.Tests/GeneralTests.cs index 428b0567..2f5f7523 100644 --- a/QueryBuilder.Tests/GeneralTests.cs +++ b/QueryBuilder.Tests/GeneralTests.cs @@ -401,5 +401,39 @@ public void Where_Nested() Assert.Equal("SELECT * FROM [table] WHERE ([a] = 1 OR [a] = 2)", c[EngineCodes.SqlServer].ToString()); } + + [Fact] + public void UnsafeLiteral_Insert() + { + var query = new Query("Table").AsInsert(new + { + Count = new UnsafeLiteral("Count + 1") + }); + + var engines = new[] { + EngineCodes.SqlServer, + }; + + var c = Compilers.Compile(engines, query); + + Assert.Equal("INSERT INTO [Table] ([Count]) VALUES (Count + 1)", c[EngineCodes.SqlServer].ToString()); + } + + [Fact] + public void UnsafeLiteral_Update() + { + var query = new Query("Table").AsUpdate(new + { + Count = new UnsafeLiteral("Count + 1") + }); + + var engines = new[] { + EngineCodes.SqlServer, + }; + + var c = Compilers.Compile(engines, query); + + Assert.Equal("UPDATE [Table] SET [Count] = Count + 1", c[EngineCodes.SqlServer].ToString()); + } } } diff --git a/QueryBuilder.Tests/UpdateTests.cs b/QueryBuilder.Tests/UpdateTests.cs index 06e44064..994feddc 100644 --- a/QueryBuilder.Tests/UpdateTests.cs +++ b/QueryBuilder.Tests/UpdateTests.cs @@ -283,5 +283,37 @@ public void UpdateUsingExpandoObject() "UPDATE [Table] SET [Name] = 'The User', [Age] = '2018-01-01'", c[EngineCodes.SqlServer]); } + + [Fact] + public void IncrementUpdate() + { + var query = new Query("Table").AsIncrement("Total"); + var c = Compile(query); + Assert.Equal("UPDATE [Table] SET [Total] = [Total] + 1", c[EngineCodes.SqlServer]); + } + + [Fact] + public void IncrementUpdateWithValue() + { + var query = new Query("Table").AsIncrement("Total", 2); + var c = Compile(query); + Assert.Equal("UPDATE [Table] SET [Total] = [Total] + 2", c[EngineCodes.SqlServer]); + } + + [Fact] + public void IncrementUpdateWithWheres() + { + var query = new Query("Table").Where("Name", "A").AsIncrement("Total", 2); + var c = Compile(query); + Assert.Equal("UPDATE [Table] SET [Total] = [Total] + 2 WHERE [Name] = 'A'", c[EngineCodes.SqlServer]); + } + + [Fact] + public void DecrementUpdate() + { + var query = new Query("Table").Where("Name", "A").AsDecrement("Total", 2); + var c = Compile(query); + Assert.Equal("UPDATE [Table] SET [Total] = [Total] - 2 WHERE [Name] = 'A'", c[EngineCodes.SqlServer]); + } } } \ No newline at end of file diff --git a/QueryBuilder/Clauses/IncrementClause.cs b/QueryBuilder/Clauses/IncrementClause.cs new file mode 100644 index 00000000..4ee5a194 --- /dev/null +++ b/QueryBuilder/Clauses/IncrementClause.cs @@ -0,0 +1,19 @@ +namespace SqlKata +{ + public class IncrementClause : InsertClause + { + public string Column { get; set; } + public int Value { get; set; } = 1; + + public override AbstractClause Clone() + { + return new IncrementClause + { + Engine = Engine, + Component = Component, + Column = Column, + Value = Value + }; + } + } +} \ No newline at end of file diff --git a/QueryBuilder/Compilers/Compiler.cs b/QueryBuilder/Compilers/Compiler.cs index f6abb967..28abe5a6 100644 --- a/QueryBuilder/Compilers/Compiler.cs +++ b/QueryBuilder/Compilers/Compiler.cs @@ -285,8 +285,31 @@ protected virtual SqlResult CompileUpdateQuery(Query query) throw new InvalidOperationException("Invalid table expression"); } - var toUpdate = ctx.Query.GetOneComponent("update", EngineCode); + // check for increment statements + var clause = ctx.Query.GetOneComponent("update", EngineCode); + + string wheres; + + if (clause != null && clause is IncrementClause increment) + { + var column = Wrap(increment.Column); + var value = Parameter(ctx, Math.Abs(increment.Value)); + var sign = increment.Value >= 0 ? "+" : "-"; + + wheres = CompileWheres(ctx); + if (!string.IsNullOrEmpty(wheres)) + { + wheres = " " + wheres; + } + + ctx.RawSql = $"UPDATE {table} SET {column} = {column} {sign} {value}{wheres}"; + + return ctx; + } + + + var toUpdate = ctx.Query.GetOneComponent("update", EngineCode); var parts = new List(); for (var i = 0; i < toUpdate.Columns.Count; i++) @@ -294,16 +317,16 @@ protected virtual SqlResult CompileUpdateQuery(Query query) parts.Add(Wrap(toUpdate.Columns[i]) + " = " + Parameter(ctx, toUpdate.Values[i])); } - var where = CompileWheres(ctx); + var sets = string.Join(", ", parts); - if (!string.IsNullOrEmpty(where)) + wheres = CompileWheres(ctx); + + if (!string.IsNullOrEmpty(wheres)) { - where = " " + where; + wheres = " " + wheres; } - var sets = string.Join(", ", parts); - - ctx.RawSql = $"UPDATE {table} SET {sets}{where}"; + ctx.RawSql = $"UPDATE {table} SET {sets}{wheres}"; return ctx; } diff --git a/QueryBuilder/Query.Update.cs b/QueryBuilder/Query.Update.cs index ced8bab8..d88aeb00 100644 --- a/QueryBuilder/Query.Update.cs +++ b/QueryBuilder/Query.Update.cs @@ -54,5 +54,22 @@ public Query AsUpdate(IEnumerable> values) return this; } + + public Query AsIncrement(string column, int value = 1) + { + Method = "update"; + AddOrReplaceComponent("update", new IncrementClause + { + Column = column, + Value = value + }); + + return this; + } + + public Query AsDecrement(string column, int value = 1) + { + return AsIncrement(column, -value); + } } } diff --git a/SqlKata.Execution/Query.Extensions.cs b/SqlKata.Execution/Query.Extensions.cs index 19a90d4d..5f3eeefd 100644 --- a/SqlKata.Execution/Query.Extensions.cs +++ b/SqlKata.Execution/Query.Extensions.cs @@ -169,9 +169,9 @@ public static int Insert(this Query query, IEnumerable columns, IEnumera return CreateQueryFactory(query).Execute(query.AsInsert(columns, valuesCollection), transaction, timeout); } - public static async Task InsertAsync(this Query query, IEnumerable columns, IEnumerable> valuesCollection, IDbTransaction transaction = null, int? timeout = null) + public static async Task InsertAsync(this Query query, IEnumerable columns, IEnumerable> valuesCollection, IDbTransaction transaction = null, int? timeout = null, CancellationToken cancellationToken = default) { - return await CreateQueryFactory(query).ExecuteAsync(query.AsInsert(columns, valuesCollection), transaction, timeout); + return await CreateQueryFactory(query).ExecuteAsync(query.AsInsert(columns, valuesCollection), transaction, timeout, cancellationToken); } public static int Insert(this Query query, IEnumerable columns, Query fromQuery, IDbTransaction transaction = null, int? timeout = null)