diff --git a/core/jvm/src/main/scala/zio/sql/expr.scala b/core/jvm/src/main/scala/zio/sql/expr.scala index 5b4c736c5..e4f3c7db0 100644 --- a/core/jvm/src/main/scala/zio/sql/expr.scala +++ b/core/jvm/src/main/scala/zio/sql/expr.scala @@ -412,7 +412,6 @@ trait ExprModule extends NewtypesModule with FeaturesModule with OpsModule { val Ascii = FunctionDef[String, Int](FunctionName("ascii")) val CharLength = FunctionDef[String, Int](FunctionName("character_length")) val Concat = FunctionDef[(String, String), String](FunctionName("concat")) // todo varargs - val ConcatWs2 = FunctionDef[(String, String), String](FunctionName("concat_ws")) val ConcatWs3 = FunctionDef[(String, String, String), String](FunctionName("concat_ws")) val ConcatWs4 = FunctionDef[(String, String, String, String), String](FunctionName("concat_ws")) val Lower = FunctionDef[String, String](FunctionName("lower")) diff --git a/mysql/src/test/scala/zio/sql/mysql/CommonFunctionDefSpec.scala b/mysql/src/test/scala/zio/sql/mysql/CommonFunctionDefSpec.scala new file mode 100644 index 000000000..3297805ad --- /dev/null +++ b/mysql/src/test/scala/zio/sql/mysql/CommonFunctionDefSpec.scala @@ -0,0 +1,313 @@ +package zio.sql.mysql + +import zio.Cause +import zio.stream.ZStream +import zio.test.Assertion._ +import zio.test._ + +object CommonFunctionDefSpec extends MysqlRunnableSpec with ShopSchema { + import FunctionDef.{ CharLength => _, _ } + import Customers._ + + private def collectAndCompare[R, E]( + expected: Seq[String], + testResult: ZStream[R, E, String] + ) = + assertZIO(testResult.runCollect)(hasSameElementsDistinct(expected)) + + override def specLayered = suite("MySQL Common FunctionDef")( + suite("Schema dependent tests")( + test("concat_ws #2 - combine columns") { + + // note: you can't use customerId here as it is a UUID, hence not a string in our book + val query = select(ConcatWs3(Customers.fName, Customers.fName, Customers.lName)) from customers + + val expected = Seq( + "RonaldRonaldRussell", + "TerrenceTerrenceNoel", + "MilaMilaPaterso", + "AlanaAlanaMurray", + "JoseJoseWiggins" + ) + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("concat_ws #3 - combine columns and flat values") { + + val query = select(ConcatWs4(" ", "Person:", Customers.fName, Customers.lName)) from customers + + val expected = Seq( + "Person: Ronald Russell", + "Person: Terrence Noel", + "Person: Mila Paterso", + "Person: Alana Murray", + "Person: Jose Wiggins" + ) + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("concat_ws #3 - combine function calls together") { + + val query = select( + ConcatWs3(" and ", Concat("Name: ", Customers.fName), Concat("Surname: ", Customers.lName)) + ) from customers + + val expected = Seq( + "Name: Ronald and Surname: Russell", + "Name: Terrence and Surname: Noel", + "Name: Mila and Surname: Paterso", + "Name: Alana and Surname: Murray", + "Name: Jose and Surname: Wiggins" + ) + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("lower") { + val query = select(Lower(fName)) from customers limit (1) + + val expected = "ronald" + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("Can concat strings with concat function") { + + val query = select(Concat(fName, lName) as "fullname") from customers + + val expected = Seq("RonaldRussell", "TerrenceNoel", "MilaPaterso", "AlanaMurray", "JoseWiggins") + + val result = execute(query) + + val assertion = for { + r <- result.runCollect + } yield assert(r)(hasSameElementsDistinct(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("replace") { + val lastNameReplaced = Replace(lName, "ll", "_") as "lastNameReplaced" + val computedReplace = Replace("special ::ąę::", "ąę", "__") as "computedReplace" + + val query = select(lastNameReplaced, computedReplace) from customers + + val expected = ("Russe_", "special ::__::") + + val testResult = + execute(query).map { case row => + (row._1, row._2) + } + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + } + ), + suite("Schema independent tests")( + test("concat_ws #1 - combine flat values") { + + // note: a plain number (3) would and should not compile + val query = select(ConcatWs4("+", "1", "2", "3")) + + val expected = Seq("1+2+3") + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("ltrim") { + assertZIO(execute(select(Ltrim(" hello "))).runHead.some)(equalTo("hello ")) + }, + test("rtrim") { + assertZIO(execute(select(Rtrim(" hello "))).runHead.some)(equalTo(" hello")) + }, + test("abs") { + assertZIO(execute(select(Abs(-3.14159))).runHead.some)(equalTo(3.14159)) + }, + test("log") { + assertZIO(execute(select(Log(2.0, 32.0))).runHead.some)(equalTo(5.0)) + }, + test("acos") { + assertZIO(execute(select(Acos(-1.0))).runHead.some)(equalTo(3.141592653589793)) + }, + test("asin") { + assertZIO(execute(select(Asin(0.5))).runHead.some)(equalTo(0.5235987755982989)) + }, + test("ln") { + assertZIO(execute(select(Ln(3.0))).runHead.some)(equalTo(1.0986122886681097)) + }, + test("atan") { + assertZIO(execute(select(Atan(10.0))).runHead.some)(equalTo(1.4711276743037347)) + }, + test("cos") { + assertZIO(execute(select(Cos(3.141592653589793))).runHead.some)(equalTo(-1.0)) + }, + test("exp") { + assertZIO(execute(select(Exp(1.0))).runHead.some)(equalTo(2.718281828459045)) + }, + test("floor") { + assertZIO(execute(select(Floor(-3.14159))).runHead.some)(equalTo(-4.0)) + }, + test("ceil") { + assertZIO(execute(select(Ceil(53.7), Ceil(-53.7))).runHead.some)(equalTo((54.0, -53.0))) + }, + test("sin") { + assertZIO(execute(select(Sin(1.0))).runHead.some)(equalTo(0.8414709848078965)) + }, + test("sqrt") { + val query = select(Sqrt(121.0)) + + val expected = 11.0 + + val testResult = execute(query) + + assertZIO(testResult.runHead.some)(equalTo(expected)) + }, + test("round") { + val query = select(Round(10.8124, 2)) + + val expected = 10.81 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("sign positive") { + val query = select(Sign(3.0)) + + val expected = 1 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("sign negative") { + val query = select(Sign(-3.0)) + + val expected = -1 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("sign zero") { + val query = select(Sign(0.0)) + + val expected = 0 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("power") { + val query = select(Power(7.0, 3.0)) + + val expected = 343.000000000000000 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("mod") { + val query = select(Mod(-15.0, -4.0)) + + val expected = -3.0 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("octet_length") { + val query = select(OctetLength("josé")) + + val expected = 5 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("ascii") { + val query = select(Ascii("""x""")) + + val expected = 120 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("upper") { + val query = (select(Upper("ronald"))).limit(1) + + val expected = "RONALD" + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("tan") { + val query = select(Tan(0.7853981634)) + + val expected = 1.0000000000051035 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("trim") { + assertZIO(execute(select(Trim(" 1234 "))).runHead.some)(equalTo("1234")) + }, + test("lower") { + assertZIO(execute(select(Lower("YES"))).runHead.some)(equalTo("yes")) + } + ) + ) + +} diff --git a/mysql/src/test/scala/zio/sql/mysql/FunctionDefSpec.scala b/mysql/src/test/scala/zio/sql/mysql/CustomFunctionDefSpec.scala similarity index 78% rename from mysql/src/test/scala/zio/sql/mysql/FunctionDefSpec.scala rename to mysql/src/test/scala/zio/sql/mysql/CustomFunctionDefSpec.scala index 3c5953820..b5d98c4f5 100644 --- a/mysql/src/test/scala/zio/sql/mysql/FunctionDefSpec.scala +++ b/mysql/src/test/scala/zio/sql/mysql/CustomFunctionDefSpec.scala @@ -6,56 +6,12 @@ import zio.test.Assertion._ import java.time.{ LocalDate, LocalTime, ZoneId } import java.time.format.DateTimeFormatter -object FunctionDefSpec extends MysqlRunnableSpec with ShopSchema { +object CustomFunctionDefSpec extends MysqlRunnableSpec with ShopSchema { import Customers._ - import FunctionDef._ import MysqlFunctionDef._ override def specLayered = suite("MySQL FunctionDef")( - test("lower") { - val query = select(Lower(fName)) from customers limit (1) - - val expected = "ronald" - - val testResult = execute(query) - - assertZIO(testResult.runHead.some)(equalTo(expected)) - }, - // FIXME: lower with string literal should not refer to a column name - // See: https://www.w3schools.com/sql/trymysql.asp?filename=trysql_func_mysql_lower - // Uncomment the following test when fixed - // test("lower with string literal") { - // val query = select(Lower("LOWER")) from customers limit(1) - // - // val expected = "lower" - // - // val testResult = execute(query.to[String, String](identity)) - // - // val assertion = for { - // r <- testResult.runCollect - // } yield assert(r.head)(equalTo(expected)) - // - // assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - // }, - test("sin") { - val query = select(Sin(1.0)) - - val expected = 0.8414709848078965 - - val testResult = execute(query) - - assertZIO(testResult.runHead.some)(equalTo(expected)) - }, - test("abs") { - val query = select(Abs(-32.0)) - - val expected = 32.0 - - val testResult = execute(query) - - assertZIO(testResult.runHead.some)(equalTo(expected)) - }, test("crc32") { val query = select(Crc32("MySQL")) from customers diff --git a/oracle/src/main/scala/zio/sql/oracle/OracleRenderModule.scala b/oracle/src/main/scala/zio/sql/oracle/OracleRenderModule.scala index c1e22b34d..c874eba74 100644 --- a/oracle/src/main/scala/zio/sql/oracle/OracleRenderModule.scala +++ b/oracle/src/main/scala/zio/sql/oracle/OracleRenderModule.scala @@ -3,6 +3,7 @@ package zio.sql.oracle import zio.schema.Schema import zio.schema.DynamicValue import zio.schema.StandardType + import java.time.Instant import java.time.LocalDate import java.time.LocalDateTime @@ -10,12 +11,14 @@ import java.time.LocalTime import java.time.OffsetTime import java.time.ZonedDateTime import zio.sql.driver.Renderer -import zio.sql.driver.Renderer.Extensions import zio.Chunk + import scala.collection.mutable import java.time.OffsetDateTime import java.time.YearMonth import java.time.Duration +import java.time.format.{ DateTimeFormatter, DateTimeFormatterBuilder } +import java.time.temporal.ChronoField._ trait OracleRenderModule extends OracleSqlModule { self => @@ -43,6 +46,105 @@ trait OracleRenderModule extends OracleSqlModule { self => render.toString } + private object DateTimeFormats { + val fmtTime = new DateTimeFormatterBuilder() + .appendValue(HOUR_OF_DAY, 2) + .appendLiteral(':') + .appendValue(MINUTE_OF_HOUR, 2) + .appendLiteral(':') + .appendValue(SECOND_OF_MINUTE, 2) + .appendFraction(NANO_OF_SECOND, 9, 9, true) + .appendOffset("+HH:MM", "Z") + .toFormatter() + + val fmtTimeOffset = new DateTimeFormatterBuilder() + .append(fmtTime) + .appendFraction(NANO_OF_SECOND, 9, 9, true) + .toFormatter() + + val fmtDateTime = new DateTimeFormatterBuilder().parseCaseInsensitive + .append(DateTimeFormatter.ISO_LOCAL_DATE) + .appendLiteral('T') + .append(fmtTime) + .toFormatter() + + val fmtDateTimeOffset = new DateTimeFormatterBuilder().parseCaseInsensitive + .append(fmtDateTime) + .appendOffset("+HH:MM", "Z") + .toFormatter() + } + + private def buildLit(lit: self.Expr.Literal[_])(builder: StringBuilder): Unit = { + import TypeTag._ + val value = lit.value + lit.typeTag match { + case TInstant => + val _ = builder.append(s"""TO_TIMESTAMP_TZ('${DateTimeFormats.fmtDateTimeOffset.format( + value.asInstanceOf[Instant] + )}', 'SYYYY-MM-DD"T"HH24:MI:SS.FF9TZH:TZM')""") + case TLocalTime => + val localTime = value.asInstanceOf[LocalTime] + val _ = builder.append( + s"INTERVAL '${localTime.getHour}:${localTime.getMinute}:${localTime.getSecond}.${localTime.getNano}' HOUR TO SECOND(9)" + ) + case TLocalDate => + val _ = builder.append( + s"TO_DATE('${DateTimeFormatter.ISO_LOCAL_DATE.format(value.asInstanceOf[LocalDate])}', 'SYYYY-MM-DD')" + ) + case TLocalDateTime => + val _ = builder.append(s"""TO_TIMESTAMP('${DateTimeFormats.fmtDateTime.format( + value.asInstanceOf[LocalDateTime] + )}', 'SYYYY-MM-DD"T"HH24:MI:SS.FF9')""") + case TZonedDateTime => + val _ = builder.append(s"""TO_TIMESTAMP_TZ('${DateTimeFormats.fmtDateTimeOffset.format( + value.asInstanceOf[ZonedDateTime] + )}', 'SYYYY-MM-DD"T"HH24:MI:SS.FF9TZH:TZM')""") + case TOffsetTime => + val _ = builder.append( + s"TO_TIMESTAMP_TZ('${DateTimeFormats.fmtTimeOffset.format(value.asInstanceOf[OffsetTime])}', 'HH24:MI:SS.FF9TZH:TZM')" + ) + case TOffsetDateTime => + val _ = builder.append( + s"""TO_TIMESTAMP_TZ('${DateTimeFormats.fmtDateTimeOffset.format( + value.asInstanceOf[OffsetDateTime] + )}', 'SYYYY-MM-DD"T"HH24:MI:SS.FF9TZH:TZM')""" + ) + + case TBoolean => + val b = value.asInstanceOf[Boolean] + if (b) { + val _ = builder.append('1') + } else { + val _ = builder.append('0') + } + case TUUID => + val _ = builder.append(s"'$value'") + + case TBigDecimal => + val _ = builder.append(value) + case TByte => + val _ = builder.append(value) + case TDouble => + val _ = builder.append(value) + case TFloat => + val _ = builder.append(value) + case TInt => + val _ = builder.append(value) + case TLong => + val _ = builder.append(value) + case TShort => + val _ = builder.append(value) + + case TChar => + val _ = builder.append(s"N'$value'") + case TString => + val _ = builder.append(s"N'$value'") + + case _ => + val _ = builder.append(s"'$value'") + } + } + // TODO: to consider the refactoring and using the implicit `Renderer`, see `renderExpr` in `PostgresRenderModule` private def buildExpr[A, B](expr: self.Expr[_, A, B], builder: StringBuilder): Unit = expr match { case Expr.Subselect(subselect) => @@ -85,8 +187,8 @@ trait OracleRenderModule extends OracleSqlModule { self => val _ = builder.append("1 = 1") case Expr.Literal(false) => val _ = builder.append("0 = 1") - case Expr.Literal(value) => - val _ = builder.append(value.toString.singleQuoted) + case literal: Expr.Literal[_] => + val _ = buildLit(literal)(builder) case Expr.AggregationCall(param, aggregation) => builder.append(aggregation.name.name) builder.append("(") @@ -178,7 +280,7 @@ trait OracleRenderModule extends OracleSqlModule { self => } /** - * Drops the initial Litaral(true) present at the start of every WHERE expressions by default + * Drops the initial Litaral(true) present at the start of every WHERE expressions by default * and proceeds to the rest of Expr's. */ private def buildWhereExpr[A, B](expr: self.Expr[_, A, B], builder: mutable.StringBuilder): Unit = expr match { diff --git a/oracle/src/main/scala/zio/sql/oracle/OracleSqlModule.scala b/oracle/src/main/scala/zio/sql/oracle/OracleSqlModule.scala index e6b17ff45..2a4f55e5c 100644 --- a/oracle/src/main/scala/zio/sql/oracle/OracleSqlModule.scala +++ b/oracle/src/main/scala/zio/sql/oracle/OracleSqlModule.scala @@ -45,6 +45,11 @@ trait OracleSqlModule extends Sql { self => val Sind = FunctionDef[Double, Double](FunctionName("sind")) } + object Dual { + val dual = (string("dummy")).table("dual") + val (dummy) = dual.columns + } + implicit val instantSchema = Schema.primitive[Instant](zio.schema.StandardType.InstantType(DateTimeFormatter.ISO_OFFSET_DATE_TIME)) diff --git a/oracle/src/test/scala/zio/sql/oracle/CommonFunctionDefSpec.scala b/oracle/src/test/scala/zio/sql/oracle/CommonFunctionDefSpec.scala new file mode 100644 index 000000000..1ffa82096 --- /dev/null +++ b/oracle/src/test/scala/zio/sql/oracle/CommonFunctionDefSpec.scala @@ -0,0 +1,285 @@ +package zio.sql.oracle + +import zio.Cause +import zio.stream.ZStream +import zio.test.Assertion._ +import zio.test._ + +object CommonFunctionDefSpec extends OracleRunnableSpec with ShopSchema { + import FunctionDef.{ CharLength => _, _ } + import Customers._ + import Dual._ + + private def collectAndCompare[R, E]( + expected: Seq[String], + testResult: ZStream[R, E, String] + ) = + assertZIO(testResult.runCollect)(hasSameElementsDistinct(expected)) + + override def specLayered = suite("Oracle Common FunctionDef")( + suite("Schema dependent tests")( + test("concat - combine function calls together") { + + val query = select( + Concat(Concat("Name: ", Customers.fName), Concat(" and Surname: ", Customers.lName)) + ) from customers + + val expected = Seq( + "Name: Ronald and Surname: Russell", + "Name: Terrence and Surname: Noel", + "Name: Mila and Surname: Paterso", + "Name: Alana and Surname: Murray", + "Name: Jose and Surname: Wiggins" + ) + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("lower") { + val query = select(Lower(fName)) from customers limit (1) + + val expected = "ronald" + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("Can concat strings with concat function") { + + val query = select(Concat(fName, lName) as "fullname") from customers + + val expected = Seq("RonaldRussell", "TerrenceNoel", "MilaPaterso", "AlanaMurray", "JoseWiggins") + + val result = execute(query) + + val assertion = for { + r <- result.runCollect + } yield assert(r)(hasSameElementsDistinct(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("replace") { + val lastNameReplaced = Replace(lName, "ll", "_") as "lastNameReplaced" + val computedReplace = Replace("special ::ąę::", "ąę", "__") as "computedReplace" + + val query = select(lastNameReplaced, computedReplace) from customers + + val expected = ("Russe_", "special ::__::") + + val testResult = + execute(query).map { case row => + (row._1, row._2) + } + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + } + ), + suite("Schema independent tests")( + test("ltrim") { + assertZIO(execute(select(Ltrim(" hello ")).from(dual)).runHead.some)(equalTo("hello ")) + }, + test("rtrim") { + assertZIO(execute(select(Rtrim(" hello ")).from(dual)).runHead.some)(equalTo(" hello")) + }, + test("abs") { + assertZIO(execute(select(Abs(-3.14159)).from(dual)).runHead.some)(equalTo(3.14159)) + }, + test("log") { + assertZIO(execute(select(Log(2.0, 32.0)).from(dual)).runHead.some)(equalTo(5.0)) + }, + test("acos") { + assertZIO(execute(select(Acos(-1.0)).from(dual)).runHead.some)(equalTo(3.141592653589793)) + }, + test("asin") { + assertZIO(execute(select(Asin(0.5)).from(dual)).runHead.some)(equalTo(0.5235987755982989)) + }, + test("ln") { + assertZIO(execute(select(Ln(3.0)).from(dual)).runHead.some)(equalTo(1.0986122886681097)) + }, + test("atan") { + assertZIO(execute(select(Atan(10.0)).from(dual)).runHead.some)(equalTo(1.4711276743037347)) + }, + test("cos") { + assertZIO(execute(select(Cos(3.141592653589793)).from(dual)).runHead.some)(equalTo(-1.0)) + }, + test("exp") { + assertZIO(execute(select(Exp(1.0)).from(dual)).runHead.some)(equalTo(2.718281828459045)) + }, + test("floor") { + assertZIO(execute(select(Floor(-3.14159)).from(dual)).runHead.some)(equalTo(-4.0)) + }, + test("ceil") { + assertZIO(execute(select(Ceil(53.7), Ceil(-53.7)).from(dual)).runHead.some)(equalTo((54.0, -53.0))) + }, + test("sin") { + assertZIO(execute(select(Sin(1.0)).from(dual)).runHead.some)(equalTo(0.8414709848078965)) + }, + test("sqrt") { + val query = select(Sqrt(121.0)).from(dual) + val expected = 11.0 + + val testResult = execute(query) + + assertZIO(testResult.runHead.some)(equalTo(expected)) + }, + test("round") { + val query = select(Round(10.8124, 2)).from(dual) + + val expected = 10.81 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("sign positive") { + val query = select(Sign(3.0)).from(dual) + + val expected = 1 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("sign negative") { + val query = select(Sign(-3.0)).from(dual) + + val expected = -1 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("sign zero") { + val query = select(Sign(0.0)).from(dual) + + val expected = 0 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("power") { + val query = select(Power(7.0, 3.0)).from(dual) + + val expected = 343.000000000000000 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("mod") { + val query = select(Mod(-15.0, -4.0)).from(dual) + + val expected = -3.0 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("octet_length") { + val query = select(OctetLength("josé")).from(dual) + + val expected = 5 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + } @@ TestAspect.ignore @@ TestAspect.tag("lengthb"), + test("ascii") { + val query = select(Ascii("""x""")).from(dual) + + val expected = 120 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("upper") { + val query = (select(Upper("ronald")).from(dual)).limit(1) + + val expected = "RONALD" + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("width_bucket") { + val query = select(WidthBucket(5.35, 0.024, 10.06, 5)).from(dual) + + val expected = 3 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("tan") { + val query = select(Tan(0.7853981634)).from(dual) + + val expected = 1.0000000000051035 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("trim") { + assertZIO(execute(select(Trim(" 1234 ")).from(dual)).runHead.some)(equalTo("1234")) + }, + test("lower") { + assertZIO(execute(select(Lower("YES")).from(dual)).runHead.some)(equalTo("yes")) + } + ) + ) + +} diff --git a/postgres/src/test/scala/zio/sql/postgresql/CommonFunctionDefSpec.scala b/postgres/src/test/scala/zio/sql/postgresql/CommonFunctionDefSpec.scala new file mode 100644 index 000000000..2ef317344 --- /dev/null +++ b/postgres/src/test/scala/zio/sql/postgresql/CommonFunctionDefSpec.scala @@ -0,0 +1,329 @@ +package zio.sql.postgresql + +import zio.stream.ZStream +import zio.test.Assertion._ +import zio.test._ +import zio.Cause + +object CommonFunctionDefSpec extends PostgresRunnableSpec with DbSchema { + import FunctionDef.{ CharLength => _, _ } + import Customers._ + + private def collectAndCompare[R, E]( + expected: Seq[String], + testResult: ZStream[R, E, String] + ) = + assertZIO(testResult.runCollect)(hasSameElementsDistinct(expected)) + + override def specLayered = suite("Postgres Common FunctionDef")( + suite("Schema dependent tests")( + test("concat_ws #2 - combine columns") { + + // note: you can't use customerId here as it is a UUID, hence not a string in our book + val query = select(ConcatWs3(Customers.fName, Customers.fName, Customers.lName)) from customers + + val expected = Seq( + "RonaldRonaldRussell", + "TerrenceTerrenceNoel", + "MilaMilaPaterso", + "AlanaAlanaMurray", + "JoseJoseWiggins" + ) + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("concat_ws #3 - combine columns and flat values") { + import Expr._ + + val query = select(ConcatWs4(" ", "Person:", Customers.fName, Customers.lName)) from customers + + val expected = Seq( + "Person: Ronald Russell", + "Person: Terrence Noel", + "Person: Mila Paterso", + "Person: Alana Murray", + "Person: Jose Wiggins" + ) + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("concat_ws #3 - combine function calls together") { + import Expr._ + + val query = select( + ConcatWs3(" and ", Concat("Name: ", Customers.fName), Concat("Surname: ", Customers.lName)) + ) from customers + + val expected = Seq( + "Name: Ronald and Surname: Russell", + "Name: Terrence and Surname: Noel", + "Name: Mila and Surname: Paterso", + "Name: Alana and Surname: Murray", + "Name: Jose and Surname: Wiggins" + ) + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("lower") { + val query = select(Lower(fName)) from customers limit (1) + + val expected = "ronald" + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("Can concat strings with concat function") { + + val query = select(Concat(fName, lName) as "fullname") from customers + + val expected = Seq("RonaldRussell", "TerrenceNoel", "MilaPaterso", "AlanaMurray", "JoseWiggins") + + val result = execute(query) + + val assertion = for { + r <- result.runCollect + } yield assert(r)(hasSameElementsDistinct(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("replace") { + val lastNameReplaced = Replace(lName, "ll", "_") as "lastNameReplaced" + val computedReplace = Replace("special ::ąę::", "ąę", "__") as "computedReplace" + + val query = select(lastNameReplaced, computedReplace) from customers + + val expected = ("Russe_", "special ::__::") + + val testResult = + execute(query).map { case row => + (row._1, row._2) + } + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + } + ), + suite("Schema independent tests")( + test("concat_ws #1 - combine flat values") { + import Expr._ + + // note: a plain number (3) would and should not compile + val query = select(ConcatWs4("+", "1", "2", "3")) + + val expected = Seq("1+2+3") + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("ltrim") { + assertZIO(execute(select(Ltrim(" hello "))).runHead.some)(equalTo("hello ")) + }, + test("rtrim") { + assertZIO(execute(select(Rtrim(" hello "))).runHead.some)(equalTo(" hello")) + }, + test("abs") { + assertZIO(execute(select(Abs(-3.14159))).runHead.some)(equalTo(3.14159)) + }, + test("log") { + assertZIO(execute(select(Log(2.0, 32.0))).runHead.some)(equalTo(5.0)) + }, + test("acos") { + assertZIO(execute(select(Acos(-1.0))).runHead.some)(equalTo(3.141592653589793)) + }, + test("asin") { + assertZIO(execute(select(Asin(0.5))).runHead.some)(equalTo(0.5235987755982989)) + }, + test("ln") { + assertZIO(execute(select(Ln(3.0))).runHead.some)(equalTo(1.0986122886681097)) + }, + test("atan") { + assertZIO(execute(select(Atan(10.0))).runHead.some)(equalTo(1.4711276743037347)) + }, + test("cos") { + assertZIO(execute(select(Cos(3.141592653589793))).runHead.some)(equalTo(-1.0)) + }, + test("exp") { + assertZIO(execute(select(Exp(1.0))).runHead.some)(equalTo(2.718281828459045)) + }, + test("floor") { + assertZIO(execute(select(Floor(-3.14159))).runHead.some)(equalTo(-4.0)) + }, + test("ceil") { + assertZIO(execute(select(Ceil(53.7), Ceil(-53.7))).runHead.some)(equalTo((54.0, -53.0))) + }, + test("sin") { + assertZIO(execute(select(Sin(1.0))).runHead.some)(equalTo(0.8414709848078965)) + }, + test("sqrt") { + val query = select(Sqrt(121.0)) + + val expected = 11.0 + + val testResult = execute(query) + + assertZIO(testResult.runHead.some)(equalTo(expected)) + }, + test("round") { + val query = select(Round(10.8124, 2)) + + val expected = 10.81 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("sign positive") { + val query = select(Sign(3.0)) + + val expected = 1 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("sign negative") { + val query = select(Sign(-3.0)) + + val expected = -1 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("sign zero") { + val query = select(Sign(0.0)) + + val expected = 0 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("power") { + val query = select(Power(7.0, 3.0)) + + val expected = 343.000000000000000 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("mod") { + val query = select(Mod(-15.0, -4.0)) + + val expected = -3.0 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("octet_length") { + val query = select(OctetLength("josé")) + + val expected = 5 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("ascii") { + val query = select(Ascii("""x""")) + + val expected = 120 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("upper") { + val query = (select(Upper("ronald"))).limit(1) + + val expected = "RONALD" + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("width_bucket") { + val query = select(WidthBucket(5.35, 0.024, 10.06, 5)) + + val expected = 3 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("tan") { + val query = select(Tan(0.7853981634)) + + val expected = 1.0000000000051035 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("trim") { + assertZIO(execute(select(Trim(" 1234 "))).runHead.some)(equalTo("1234")) + }, + test("lower") { + assertZIO(execute(select(Lower("YES"))).runHead.some)(equalTo("yes")) + } + ) + ) + +} diff --git a/postgres/src/test/scala/zio/sql/postgresql/FunctionDefSpec.scala b/postgres/src/test/scala/zio/sql/postgresql/CustomFunctionDefSpec.scala similarity index 73% rename from postgres/src/test/scala/zio/sql/postgresql/FunctionDefSpec.scala rename to postgres/src/test/scala/zio/sql/postgresql/CustomFunctionDefSpec.scala index 2d0936f95..eee241858 100644 --- a/postgres/src/test/scala/zio/sql/postgresql/FunctionDefSpec.scala +++ b/postgres/src/test/scala/zio/sql/postgresql/CustomFunctionDefSpec.scala @@ -10,10 +10,9 @@ import java.time._ import java.time.format.DateTimeFormatter import java.util.UUID -object FunctionDefSpec extends PostgresRunnableSpec with DbSchema { +object CustomFunctionDefSpec extends PostgresRunnableSpec with DbSchema { import Customers._ - import FunctionDef.{ CharLength => _, _ } import PostgresFunctionDef._ import PostgresSpecific._ @@ -26,67 +25,6 @@ object FunctionDefSpec extends PostgresRunnableSpec with DbSchema { private val timestampFormatter = DateTimeFormatter.ofPattern("uuuu-MM-dd HH:mm:ss.SSSS").withZone(ZoneId.of("UTC")) override def specLayered = suite("Postgres FunctionDef")( - test("concat_ws #1 - combine flat values") { - import Expr._ - - // note: a plain number (3) would and should not compile - val query = select(ConcatWs4("+", "1", "2", "3")) - - val expected = Seq("1+2+3") - - val testResult = execute(query) - collectAndCompare(expected, testResult) - }, - test("concat_ws #2 - combine columns") { - - // note: you can't use customerId here as it is a UUID, hence not a string in our book - val query = select(ConcatWs3(Customers.fName, Customers.fName, Customers.lName)) from customers - - val expected = Seq( - "RonaldRonaldRussell", - "TerrenceTerrenceNoel", - "MilaMilaPaterso", - "AlanaAlanaMurray", - "JoseJoseWiggins" - ) - - val testResult = execute(query) - collectAndCompare(expected, testResult) - }, - test("concat_ws #3 - combine columns and flat values") { - import Expr._ - - val query = select(ConcatWs4(" ", "Person:", Customers.fName, Customers.lName)) from customers - - val expected = Seq( - "Person: Ronald Russell", - "Person: Terrence Noel", - "Person: Mila Paterso", - "Person: Alana Murray", - "Person: Jose Wiggins" - ) - - val testResult = execute(query) - collectAndCompare(expected, testResult) - }, - test("concat_ws #3 - combine function calls together") { - import Expr._ - - val query = select( - ConcatWs3(" and ", Concat("Name: ", Customers.fName), Concat("Surname: ", Customers.lName)) - ) from customers - - val expected = Seq( - "Name: Ronald and Surname: Russell", - "Name: Terrence and Surname: Noel", - "Name: Mila and Surname: Paterso", - "Name: Alana and Surname: Murray", - "Name: Jose and Surname: Wiggins" - ) - - val testResult = execute(query) - collectAndCompare(expected, testResult) - }, test("isfinite") { val query = select(IsFinite(Instant.now)) val expected: Boolean = true @@ -99,12 +37,6 @@ object FunctionDefSpec extends PostgresRunnableSpec with DbSchema { test("CharLength") { assertZIO(execute(select(Length("hello"))).runHead.some)(equalTo(5)) }, - test("ltrim") { - assertZIO(execute(select(Ltrim(" hello "))).runHead.some)(equalTo("hello ")) - }, - test("rtrim") { - assertZIO(execute(select(Rtrim(" hello "))).runHead.some)(equalTo(" hello")) - }, test("bit_length") { assertZIO(execute(select(BitLength("hello"))).runHead.some)(equalTo(40)) }, @@ -226,45 +158,12 @@ object FunctionDefSpec extends PostgresRunnableSpec with DbSchema { } ) ), - test("abs") { - assertZIO(execute(select(Abs(-3.14159))).runHead.some)(equalTo(3.14159)) - }, - test("log") { - assertZIO(execute(select(Log(2.0, 32.0))).runHead.some)(equalTo(5.0)) - }, - test("acos") { - assertZIO(execute(select(Acos(-1.0))).runHead.some)(equalTo(3.141592653589793)) - }, test("repeat") { assertZIO(execute(select(Repeat("Zio", 3))).runHead.some)(equalTo("ZioZioZio")) }, test("reverse") { assertZIO(execute(select(Reverse("abcd"))).runHead.some)(equalTo("dcba")) }, - test("asin") { - assertZIO(execute(select(Asin(0.5))).runHead.some)(equalTo(0.5235987755982989)) - }, - test("ln") { - assertZIO(execute(select(Ln(3.0))).runHead.some)(equalTo(1.0986122886681097)) - }, - test("atan") { - assertZIO(execute(select(Atan(10.0))).runHead.some)(equalTo(1.4711276743037347)) - }, - test("cos") { - assertZIO(execute(select(Cos(3.141592653589793))).runHead.some)(equalTo(-1.0)) - }, - test("exp") { - assertZIO(execute(select(Exp(1.0))).runHead.some)(equalTo(2.718281828459045)) - }, - test("floor") { - assertZIO(execute(select(Floor(-3.14159))).runHead.some)(equalTo(-4.0)) - }, - test("ceil") { - assertZIO(execute(select(Ceil(53.7), Ceil(-53.7))).runHead.some)(equalTo((54.0, -53.0))) - }, - test("sin") { - assertZIO(execute(select(Sin(1.0))).runHead.some)(equalTo(0.8414709848078965)) - }, test("sind") { assertZIO(execute(select(Sind(30.0))).runHead.some)(equalTo(0.5)) }, @@ -371,15 +270,6 @@ object FunctionDefSpec extends PostgresRunnableSpec with DbSchema { assertZIO(testResult.runCollect.exit)(fails(anything)) } ) @@ ignore, - test("sqrt") { - val query = select(Sqrt(121.0)) - - val expected = 11.0 - - val testResult = execute(query) - - assertZIO(testResult.runHead.some)(equalTo(expected)) - }, test("chr") { val query = select(Chr(65)) @@ -484,71 +374,6 @@ object FunctionDefSpec extends PostgresRunnableSpec with DbSchema { assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) }, - test("round") { - val query = select(Round(10.8124, 2)) - - val expected = 10.81 - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, - test("sign positive") { - val query = select(Sign(3.0)) - - val expected = 1 - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, - test("sign negative") { - val query = select(Sign(-3.0)) - - val expected = -1 - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, - test("sign zero") { - val query = select(Sign(0.0)) - - val expected = 0 - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, - test("power") { - val query = select(Power(7.0, 3.0)) - - val expected = 343.000000000000000 - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, test("length") { val query = select(Length("hello")) @@ -562,19 +387,6 @@ object FunctionDefSpec extends PostgresRunnableSpec with DbSchema { assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) }, - test("mod") { - val query = select(Mod(-15.0, -4.0)) - - val expected = -3.0 - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, test("translate") { val query = select(Translate("12345", "143", "ax")) @@ -667,97 +479,6 @@ object FunctionDefSpec extends PostgresRunnableSpec with DbSchema { assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) }, - test("lower") { - val query = select(Lower(fName)) from customers limit (1) - - val expected = "ronald" - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, - test("lower with string literal") { - val query = select(Lower("LOWER")) from customers limit (1) - - val expected = "lower" - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, - test("octet_length") { - val query = select(OctetLength("josé")) - - val expected = 5 - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, - test("ascii") { - val query = select(Ascii("""x""")) - - val expected = 120 - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, - test("upper") { - val query = (select(Upper("ronald"))).limit(1) - - val expected = "RONALD" - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, - test("width_bucket") { - val query = select(WidthBucket(5.35, 0.024, 10.06, 5)) - - val expected = 3 - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, - test("tan") { - val query = select(Tan(0.7853981634)) - - val expected = 1.0000000000051035 - - val testResult = execute(query) - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, test("gcd") { val query = select(GCD(1071d, 462d)) @@ -855,20 +576,6 @@ object FunctionDefSpec extends PostgresRunnableSpec with DbSchema { assertZIO(testResult.runHead.some)(equalTo(randomTupleForSeed)) }, - test("Can concat strings with concat function") { - - val query = select(Concat(fName, lName) as "fullname") from customers - - val expected = Seq("RonaldRussell", "TerrenceNoel", "MilaPaterso", "AlanaMurray", "JoseWiggins") - - val result = execute(query) - - val assertion = for { - r <- result.runCollect - } yield assert(r)(hasSameElementsDistinct(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, test("Can calculate character length of a string") { val query = select(CharLength(fName)) from customers @@ -909,25 +616,6 @@ object FunctionDefSpec extends PostgresRunnableSpec with DbSchema { assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) }, - test("replace") { - val lastNameReplaced = Replace(lName, "ll", "_") as "lastNameReplaced" - val computedReplace = Replace("special ::ąę::", "ąę", "__") as "computedReplace" - - val query = select(lastNameReplaced, computedReplace) from customers - - val expected = ("Russe_", "special ::__::") - - val testResult = - execute(query).map { case row => - (row._1, row._2) - } - - val assertion = for { - r <- testResult.runCollect - } yield assert(r.head)(equalTo(expected)) - - assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) - }, test("lpad") { def runTest(s: String, pad: String) = { val query = select(LPad(s, 5, pad)) diff --git a/sqlserver/src/main/scala/zio/sql/sqlserver/SqlServerRenderModule.scala b/sqlserver/src/main/scala/zio/sql/sqlserver/SqlServerRenderModule.scala index c42723271..4e5b7cb9d 100644 --- a/sqlserver/src/main/scala/zio/sql/sqlserver/SqlServerRenderModule.scala +++ b/sqlserver/src/main/scala/zio/sql/sqlserver/SqlServerRenderModule.scala @@ -4,7 +4,6 @@ import zio.Chunk import zio.schema.StandardType._ import zio.schema._ import zio.sql.driver.Renderer -import zio.sql.driver.Renderer.Extensions import java.time.format.{ DateTimeFormatter, DateTimeFormatterBuilder } import java.time._ @@ -107,6 +106,49 @@ trait SqlServerRenderModule extends SqlServerSqlModule { self => render(" (", values.mkString(","), ") ") // todo fix needs escaping } + private def renderLit(lit: self.Expr.Literal[_])(implicit render: Renderer): Unit = { + import TypeTag._ + val value = lit.value + lit.typeTag match { + case TInstant => + render(s"'${DateTimeFormatter.ISO_INSTANT.format(value.asInstanceOf[Instant])}'") + case TLocalTime => + render(s"'${DateTimeFormatter.ISO_LOCAL_TIME.format(value.asInstanceOf[LocalTime])}'") + case TLocalDate => + render(s"'${DateTimeFormatter.ISO_LOCAL_DATE.format(value.asInstanceOf[LocalDate])}'") + case TLocalDateTime => + render(s"'${DateTimeFormatter.ISO_LOCAL_DATE_TIME.format(value.asInstanceOf[LocalDateTime])}'") + case TZonedDateTime => + render(s"'${fmtDateTimeOffset.format(value.asInstanceOf[ZonedDateTime])}'") + case TOffsetTime => + render(s"'${fmtTimeOffset.format(value.asInstanceOf[OffsetTime])}'") + case TOffsetDateTime => + render(s"'${fmtDateTimeOffset.format(value.asInstanceOf[OffsetDateTime])}'") + + case TBoolean => + val b = value.asInstanceOf[Boolean] + if (b) { + render('1') + } else { + render('0') + } + case TUUID => render(s"'$value'") + + case TBigDecimal => render(value) + case TByte => render(value) + case TDouble => render(value) + case TFloat => render(value) + case TInt => render(value) + case TLong => render(value) + case TShort => render(value) + + case TChar => render(s"N'$value'") + case TString => render(s"N'$value'") + + case _ => render(s"'$value'") + } + } + private def buildExpr[A, B](expr: self.Expr[_, A, B])(implicit render: Renderer): Unit = expr match { case Expr.Subselect(subselect) => render(" (") @@ -145,33 +187,7 @@ trait SqlServerRenderModule extends SqlServerSqlModule { self => case Expr.In(value, set) => buildExpr(value) renderReadImpl(set) - case literal @ Expr.Literal(value) => - val lit = literal.typeTag match { - case TypeTag.TBoolean => - // MSSQL server variant of true/false - if (value.asInstanceOf[Boolean]) { - "0 = 0" - } else { - "0 = 1" - } - case TypeTag.TLocalDateTime => - val x = value - .asInstanceOf[java.time.LocalDateTime] - .format(java.time.format.DateTimeFormatter.ofPattern("YYYY-MM-dd HH:mm:ss")) - s"'$x'" - case TypeTag.TZonedDateTime => - val x = value - .asInstanceOf[java.time.ZonedDateTime] - .format(java.time.format.DateTimeFormatter.ofPattern("YYYY-MM-dd HH:mm:ss")) - s"'$x'" - case TypeTag.TOffsetDateTime => - val x = value - .asInstanceOf[java.time.OffsetDateTime] - .format(java.time.format.DateTimeFormatter.ofPattern("YYYY-MM-dd HH:mm:ss")) - s"'$x'" - case _ => value.toString.singleQuoted - } - render(lit) + case literal: Expr.Literal[_] => renderLit(literal) case Expr.AggregationCall(param, aggregation) => render(aggregation.name.name) render("(") diff --git a/sqlserver/src/test/scala/zio/sql/sqlserver/CommonFunctionDefSpec.scala b/sqlserver/src/test/scala/zio/sql/sqlserver/CommonFunctionDefSpec.scala new file mode 100644 index 000000000..907d8e33c --- /dev/null +++ b/sqlserver/src/test/scala/zio/sql/sqlserver/CommonFunctionDefSpec.scala @@ -0,0 +1,328 @@ +package zio.sql.sqlserver + +import zio.Cause +import zio.stream.ZStream +import zio.test.Assertion._ +import zio.test._ + +object CommonFunctionDefSpec extends SqlServerRunnableSpec with DbSchema { + import FunctionDef.{ CharLength => _, _ } + import DbSchema._ + + private def collectAndCompare[R, E]( + expected: Seq[String], + testResult: ZStream[R, E, String] + ) = + assertZIO(testResult.runCollect)(hasSameElementsDistinct(expected)) + + override def specLayered = suite("SqlServer Common FunctionDef")( + suite("Schema dependent tests")( + test("concat_ws #2 - combine columns") { + + // note: you can't use customerId here as it is a UUID, hence not a string in our book + val query = select(ConcatWs3(fName, fName, lName)) from customers + + val expected = Seq( + "RonaldRonaldRussell", + "TerrenceTerrenceNoel", + "MilaMilaPaterso", + "AlanaAlanaMurray", + "JoseJoseWiggins" + ) + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("concat_ws #3 - combine columns and flat values") { + import Expr._ + + val query = select(ConcatWs4(" ", "Person:", fName, lName)) from customers + + val expected = Seq( + "Person: Ronald Russell", + "Person: Terrence Noel", + "Person: Mila Paterso", + "Person: Alana Murray", + "Person: Jose Wiggins" + ) + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("concat_ws #3 - combine function calls together") { + import Expr._ + + val query = select( + ConcatWs3(" and ", Concat("Name: ", fName), Concat("Surname: ", lName)) + ) from customers + + val expected = Seq( + "Name: Ronald and Surname: Russell", + "Name: Terrence and Surname: Noel", + "Name: Mila and Surname: Paterso", + "Name: Alana and Surname: Murray", + "Name: Jose and Surname: Wiggins" + ) + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("lower") { + val query = select(Lower(fName)) from customers limit (1) + + val expected = "ronald" + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("Can concat strings with concat function") { + + val query = select(Concat(fName, lName) as "fullname") from customers + + val expected = Seq("RonaldRussell", "TerrenceNoel", "MilaPaterso", "AlanaMurray", "JoseWiggins") + + val result = execute(query) + + val assertion = for { + r <- result.runCollect + } yield assert(r)(hasSameElementsDistinct(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("replace") { + val lastNameReplaced = Replace(lName, "ll", "_") as "lastNameReplaced" + val computedReplace = Replace("special ::ąę::", "ąę", "__") as "computedReplace" + + val query = select(lastNameReplaced, computedReplace) from customers + + val expected = ("Russe_", "special ::__::") + + val testResult = + execute(query).map { case row => + (row._1, row._2) + } + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + } + ), + suite("Schema independent tests")( + test("concat_ws #1 - combine flat values") { + import Expr._ + + // note: a plain number (3) would and should not compile + val query = select(ConcatWs4("+", "1", "2", "3")) + + val expected = Seq("1+2+3") + + val testResult = execute(query) + collectAndCompare(expected, testResult) + }, + test("ltrim") { + assertZIO(execute(select(Ltrim(" hello "))).runHead.some)(equalTo("hello ")) + }, + test("rtrim") { + assertZIO(execute(select(Rtrim(" hello "))).runHead.some)(equalTo(" hello")) + }, + test("abs") { + assertZIO(execute(select(Abs(-3.14159))).runHead.some)(equalTo(3.14159)) + }, + test("log") { + assertZIO(execute(select(Log(32.0, 2.0))).runHead.some)(equalTo(5.0)) + } @@ TestAspect.tag("different order of params"), + test("acos") { + assertZIO(execute(select(Acos(-1.0))).runHead.some)(equalTo(3.141592653589793)) + }, + test("asin") { + assertZIO(execute(select(Asin(0.5))).runHead.some)(equalTo(0.5235987755982989)) + }, + test("ln") { + assertZIO(execute(select(Ln(3.0))).runHead.some)(equalTo(1.0986122886681097)) + } @@ TestAspect.ignore @@ TestAspect.tag("log with one param"), + test("atan") { + assertZIO(execute(select(Atan(10.0))).runHead.some)(equalTo(1.4711276743037347)) + }, + test("cos") { + assertZIO(execute(select(Cos(3.141592653589793))).runHead.some)(equalTo(-1.0)) + }, + test("exp") { + assertZIO(execute(select(Exp(1.0))).runHead.some)(equalTo(2.718281828459045)) + }, + test("floor") { + assertZIO(execute(select(Floor(-3.14159))).runHead.some)(equalTo(-4.0)) + }, + test("ceil") { + assertZIO(execute(select(Ceil(53.7), Ceil(-53.7))).runHead.some)(equalTo((54.0, -53.0))) + } @@ TestAspect.ignore @@ TestAspect.tag("cailing"), + test("sin") { + assertZIO(execute(select(Sin(1.0))).runHead.some)(equalTo(0.8414709848078965)) + }, + test("sqrt") { + val query = select(Sqrt(121.0)) + + val expected = 11.0 + + val testResult = execute(query) + + assertZIO(testResult.runHead.some)(equalTo(expected)) + }, + test("round") { + val query = select(Round(10.8124, 2)) + + val expected = 10.81 + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("sign positive") { + val query = select(Sign(3.0)) + + val expected = 1 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("sign negative") { + val query = select(Sign(-3.0)) + + val expected = -1 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("sign zero") { + val query = select(Sign(0.0)) + + val expected = 0 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("power") { + val query = select(Power(7.0, 3.0)) + + val expected = 343.000000000000000 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("mod") { + val query = select(Mod(-15.0, -4.0)) + + val expected = -3.0 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + } @@ TestAspect.ignore @@ TestAspect.tag("to use % instead"), + test("octet_length") { + val query = select(OctetLength("josé")) + + val expected = 5 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + } @@ TestAspect.ignore @@ TestAspect.tag("datalength"), + test("ascii") { + val query = select(Ascii("""x""")) + + val expected = 120 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("upper") { + val query = (select(Upper("ronald"))).limit(1) + + val expected = "RONALD" + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("width_bucket") { + val query = select(WidthBucket(5.35, 0.024, 10.06, 5)) + + val expected = 3 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + } @@ TestAspect.ignore, + test("tan") { + val query = select(Tan(0.7853981634)) + + val expected = 1.0000000000051035 + + val testResult = execute(query) + + val assertion = for { + r <- testResult.runCollect + } yield assert(r.head)(equalTo(expected)) + + assertion.mapErrorCause(cause => Cause.stackless(cause.untraced)) + }, + test("trim") { + assertZIO(execute(select(Trim(" 1234 "))).runHead.some)(equalTo("1234")) + }, + test("lower") { + assertZIO(execute(select(Lower("YES"))).runHead.some)(equalTo("yes")) + } + ) + ) + +}