diff --git a/impl/ocaml/mariadb/sqlgg_mariadb.ml b/impl/ocaml/mariadb/sqlgg_mariadb.ml index 80889caa..06ac5042 100644 --- a/impl/ocaml/mariadb/sqlgg_mariadb.ml +++ b/impl/ocaml/mariadb/sqlgg_mariadb.ml @@ -24,6 +24,14 @@ module type Value = sig val to_literal : t -> string end +module type Enum = sig + type t + + val inj: string -> t + + val proj: t -> string +end + module type Types = sig type field type value @@ -36,6 +44,7 @@ module type Types = sig module Datetime : Value module Decimal : Value module Any : Value + module Make_enum : functor (E : Enum) -> Value with type t = E.t end module Default_types(M : Mariadb.Nonblocking.S) : Types with @@ -183,6 +192,20 @@ struct let to_value x = x let to_literal _ = failwith "to_literal Any" end) + + module Make_enum (E: Enum) = Make(struct + + include E + + let of_field field = + match M.Field.value field with + | `String x -> inj x + | value -> convfail "enum" field value + + let to_value v = `String (proj v) + + let to_literal = proj + end) end module Make @@ -258,6 +281,19 @@ let set_param_Float = set_param_ty Float.to_value let set_param_Decimal = set_param_ty Decimal.to_value let set_param_Datetime = set_param_ty Datetime.to_value +module Make_enum (E: Enum) = struct + + module E = Make_enum(E) + + type t = E.t + + let get_column, get_column_nullable = get_column_ty "Enum" E.of_field + + let set_param = set_param_ty E.to_value + + let to_literal = E.to_literal +end + let no_params stmt = let open IO in M.Stmt.execute stmt [||] >>= diff --git a/impl/ocaml/mysql/sqlgg_mysql.ml b/impl/ocaml/mysql/sqlgg_mysql.ml index 081885ad..f64e09ae 100644 --- a/impl/ocaml/mysql/sqlgg_mysql.ml +++ b/impl/ocaml/mysql/sqlgg_mysql.ml @@ -126,6 +126,15 @@ type result = P.stmt_result type execute_response = { affected_rows: int64; insert_id: int64 } module Types = T + +module type Enum = sig + type t + + val inj: string -> t + + val proj: t -> string +end + open Types (* compatibility *) @@ -178,6 +187,17 @@ let set_param_Float = set_param_ty Float.to_string let set_param_Decimal = set_param_ty Decimal.to_string let set_param_Datetime = set_param_ty Datetime.to_string +module Make_enum (E: Enum) = struct + + include E + + let get_column, get_column_nullable = get_column_ty "Enum" E.inj + + let set_param = set_param_ty E.proj + + let to_literal = E.proj +end + let no_params stmt = P.execute stmt [||] let try_finally final f x = diff --git a/impl/ocaml/sqlgg_traits.ml b/impl/ocaml/sqlgg_traits.ml index fb32ea7f..2e980277 100644 --- a/impl/ocaml/sqlgg_traits.ml +++ b/impl/ocaml/sqlgg_traits.ml @@ -20,6 +20,14 @@ module type Value = sig val to_literal : t -> string end +module type Enum = sig + type t + + val inj: string -> t + + val proj: t -> string +end + module type M = sig type statement @@ -80,6 +88,14 @@ module type M = sig val set_param_Decimal : params -> Decimal.t -> unit val set_param_Datetime : params -> Datetime.t -> unit + module Make_enum: functor (E : Enum) -> sig + (* The type itself is not exposed to provide a user a polymorphic type without aliases. *) + val get_column : row -> int -> E.t + val get_column_nullable : row -> int -> E.t option + val set_param : params -> E.t -> unit + val to_literal : E.t -> string + end + val no_params : statement -> result (** diff --git a/impl/ocaml/sqlite3/sqlgg_sqlite3.ml b/impl/ocaml/sqlite3/sqlgg_sqlite3.ml index 0b7351a5..44fc8132 100644 --- a/impl/ocaml/sqlite3/sqlgg_sqlite3.ml +++ b/impl/ocaml/sqlite3/sqlgg_sqlite3.ml @@ -67,6 +67,14 @@ module Types = struct module Any = Text end +module type Enum = sig + type t + + val inj: string -> t + + val proj: t -> string +end + type statement = S.stmt * string type 'a connection = S.db type params = statement * int * int ref @@ -110,6 +118,17 @@ let get_column_Float, get_column_Float_nullable = get_column_ty Conv.float let get_column_Decimal, get_column_Decimal_nullable = get_column_ty Conv.decimal let get_column_Datetime, get_column_Datetime_nullable = get_column_ty Conv.float +module Make_enum (E: Enum) = struct + + include E + + let get_column, get_column_nullable = failwith "sqlite does not support enums" + + let set_param = failwith "sqlite does not support enums" + + let to_literal = failwith "sqlite does not support enums" +end + let test_ok sql rc = if rc <> S.Rc.OK then raise (Oops (sprintf "test_ok %s for %s" (S.Rc.to_string rc) sql)) diff --git a/lib/sql.ml b/lib/sql.ml index 141d75dc..b2a4cb8b 100644 --- a/lib/sql.ml +++ b/lib/sql.ml @@ -6,6 +6,22 @@ open Prelude module Type = struct + + module Enum_kind = struct + + module Ctors = struct + include Set.Make(String) + + let pp fmt s = + Format.fprintf fmt "{%s}" + (String.concat "; " (elements s)) + end + + type t = Ctors.t [@@deriving eq, show{with_path=false}] + + let make ctors = Ctors.of_list ctors + end + type kind = | Unit of [`Interval] | Int @@ -15,10 +31,17 @@ struct | Bool | Datetime | Decimal + | Union of { ctors: Enum_kind.t; is_closed: bool } + | StringLiteral of string | Any (* FIXME - Top and Bottom ? *) [@@deriving eq, show{with_path=false}] (* TODO NULL is currently typed as Any? which actually is a misnormer *) + let show_kind = function + | Union { ctors; _ } -> sprintf "Union (%s)" (String.concat "| " (Enum_kind.Ctors.elements ctors)) + | StringLiteral l -> sprintf "StringLiteral (%s)" l + | k -> show_kind k + type nullability = | Nullable (** can be NULL *) | Strict (** cannot be NULL *) @@ -34,6 +57,8 @@ struct let make_nullable { t; nullability=_ } = nullable t let make_strict { t; nullability=_ } = strict t + + let make_enum_kind ctors = Union { ctors = (Enum_kind.make ctors); is_closed = true } let is_strict { nullability; _ } = nullability = Strict @@ -50,19 +75,34 @@ struct let is_unit = function { t = Unit _; _ } -> true | _ -> false (** @return (subtype, supertype) *) - let order_kind x y = - if equal_kind x y then - `Equal - else - match x,y with - | Any, t | t, Any -> `Order (t,t) - | Int, Float | Float, Int -> `Order (Int,Float) - (* arbitrary decision : allow int<->decimal but require explicit cast for floats *) - | Decimal, Int | Int, Decimal -> `Order (Int,Decimal) - | Text, Blob | Blob, Text -> `Order (Text,Blob) - | Int, Datetime | Datetime, Int -> `Order (Int,Datetime) - | Text, Datetime | Datetime, Text -> `Order (Datetime,Text) - | _ -> `No + let order_kind x y = + match x, y with + | x, y when equal_kind x y -> `Equal + | StringLiteral a, StringLiteral b -> + `StringLiteralUnion (Union { ctors = (Enum_kind.make [a; b]); is_closed = false }) + + | StringLiteral a, Union { ctors = b; is_closed } | Union { ctors = b; is_closed }, StringLiteral a when Enum_kind.Ctors.mem a b + -> `Order (StringLiteral a, Union { ctors = (Enum_kind.Ctors.add a b); is_closed }) + + | StringLiteral a, Union { ctors = b; is_closed = false } | Union { ctors = b; is_closed = false }, StringLiteral a -> + `StringLiteralUnion (Union { ctors = (Enum_kind.Ctors.add a b); is_closed = false; }) + + | StringLiteral _ as x , Text -> `Order (x, Text) + | Text, (StringLiteral _ as x) -> `Order (x, Text) + + | Text, (Union _ as x) -> `Order (x, Text) + | Union { ctors = a; _ } as x1, (Union { ctors = b ;_ } as x2) when Enum_kind.Ctors.subset b a -> `Order (x2, x1) + + | StringLiteral x, Datetime | Datetime, StringLiteral x -> `Order (Datetime, StringLiteral x) + | StringLiteral x, Blob | Blob, StringLiteral x -> `Order (Blob, StringLiteral x) + | Any, t | t, Any -> `Order (t, t) + | Int, Float | Float, Int -> `Order (Int, Float) + | Decimal, Int | Int, Decimal -> `Order (Int, Decimal) + | Text, Blob | Blob, Text -> `Order (Text, Blob) + | Int, Datetime | Datetime, Int -> `Order (Int, Datetime) + | Text, Datetime | Datetime, Text -> `Order (Datetime, Text) + | _ -> `No + let order_nullability x y = match x,y with @@ -89,19 +129,30 @@ struct let common_type_ order x y = match order_nullability x.nullability y.nullability, order_kind x.t y.t with | _, `No -> None - | `Equal nullability, `Order pair -> Some {t = order pair; nullability} + | `Equal nullability, `Order pair -> `CommonType pair |> order |> Option.map (fun t -> { t = t; nullability }) | `Equal nullability, `Equal -> Some { x with nullability } | (`Nullable_Strict|`Strict_Nullable), `Equal -> Some (nullable x.t) (* FIXME need nullability order? *) - | (`Nullable_Strict|`Strict_Nullable), `Order pair -> Some (nullable @@ order pair) + | (`Nullable_Strict|`Strict_Nullable), `Order pair -> `CommonType pair |> order |> Option.map nullable + | `Equal nullability, `StringLiteralUnion t -> `StringLiteralUnion t |> order |> Option.map (fun t -> { t = t; nullability }) + | (`Nullable_Strict | `Strict_Nullable), `StringLiteralUnion t -> `StringLiteralUnion t |> order |> Option.map nullable let common_type_l_ order = function | [] -> None | t::ts -> List.fold_left (fun acc t -> match acc with None -> None | Some prev -> common_type_ order prev t) (Some t) ts - let subtype = common_type_ fst - let supertype = common_type_ snd - let common_subtype = common_type_l_ fst - let common_supertype = common_type_l_ snd + let get_subtype = function + | `CommonType t -> Some (fst t) + | `StringLiteralUnion t -> Some t + + let get_supertype = function + | `CommonType t -> Some (snd t) + | `StringLiteralUnion t -> Some t + + let subtype = common_type_ get_subtype + let supertype = common_type_ get_supertype + let common_subtype = common_type_l_ get_subtype + + let common_supertype = common_type_l_ get_supertype let common_type = subtype @@ -435,7 +486,7 @@ and expr = (* pos - full syntax pos from {, to }?, pos is only sql, that inside {}? to use it during the substitution and to not depend on the magic numbers there. *) - | OptionBoolChoices of { choice: expr; pos: (pos * pos) } + | OptionBoolChoices of { choice: expr; pos: (pos * pos) } and column = | All | AllOf of table_name diff --git a/lib/sql_parser.mly b/lib/sql_parser.mly index 1fb0e1b8..bd6a718d 100644 --- a/lib/sql_parser.mly +++ b/lib/sql_parser.mly @@ -440,8 +440,7 @@ expr: InTupleList(names, p) } | LPAREN select=select_stmt RPAREN { SelectExpr (select, `AsValue) } - | p=param { Param (new_param p (depends Any)) } - | p=param DOUBLECOLON t=manual_type { Param (new_param { p with pos=($startofs, $endofs) } t) } + | p=param t=preceded(DOUBLECOLON, manual_type)? { Param (new_param { p with pos=($startofs, $endofs) } (Option.default (depends Any) t)) } | LCURLY e=expr RCURLY QSTN { OptionBoolChoices ({ choice=e; pos=(($startofs, $endofs), ($startofs + 1, $endofs - 2))}) } | p=param parser_state_ident LCURLY l=choices c2=RCURLY { let { label; pos=(p1,_p2) } = p in Choices ({ label; pos = (p1,c2+1)},l) } | SUBSTRING LPAREN s=expr FROM p=expr FOR n=expr RPAREN @@ -481,6 +480,7 @@ values_stmt1: values_stmt: | kind=values_stmt1 row_order=loption(order) row_limit=limit_t? {{ row_constructor_list = kind; row_order; row_limit;}} + (* https://dev.mysql.com/doc/refman/8.0/en/window-functions-usage.html *) window_function: @@ -513,7 +513,7 @@ choices: separated_nonempty_list(pair(parser_state_ident,NUM_BIT_OR),choice) { $ datetime_value: | DATETIME_FUNC | DATETIME_FUNC LPAREN INTEGER? RPAREN { Value (strict Datetime) } strict_value: - | TEXT collate? { Text } + | TEXT { StringLiteral $1 } | BLOB collate? { Blob } | INTEGER { Int } | FLOAT { Float } @@ -555,7 +555,7 @@ sql_type_flavor: T_INTEGER UNSIGNED? ZEROFILL? { Int } | T_DECIMAL { Decimal } | binary { Blob } | NATIONAL? text VARYING? charset? collate? { Text } - | ENUM sequence(TEXT) charset? collate? { Text } + | ENUM ctors=sequence(TEXT) charset? collate? { make_enum_kind ctors } | T_FLOAT PRECISION? { Float } | T_BOOLEAN { Bool } | T_DATETIME | YEAR | DATE | TIME | TIMESTAMP { Datetime } diff --git a/lib/syntax.ml b/lib/syntax.ml index 85208baa..d88a7a75 100644 --- a/lib/syntax.ml +++ b/lib/syntax.ml @@ -31,6 +31,8 @@ module Tables_with_derived = struct let get_from ~env name = get_from (env.ctes @ env.tables) name end +type enum_ctor_value_data = { ctor_name: string; pos: pos; } [@@deriving show] + (* expr with all name references resolved to values or "functions" *) type res_expr = | ResValue of Type.t (** literal value *) @@ -221,7 +223,7 @@ let rec bool_choice_id = function | SelectExpr _ | OptionBoolChoices _ | Choices _ - | Value _ -> None + | Value _ -> None | Inparam p | Param p -> Some p.id | Fun (_, exprs) -> List.find_map bool_choice_id exprs diff --git a/src/cli.exe b/src/cli.exe deleted file mode 100755 index bda6d8c1..00000000 Binary files a/src/cli.exe and /dev/null differ diff --git a/src/cli.ml b/src/cli.ml index 5efeeee2..eec11c4f 100644 --- a/src/cli.ml +++ b/src/cli.ml @@ -103,6 +103,7 @@ let main () = "-static-header", Arg.Unit (fun () -> Sqlgg_config.gen_header := Some `Static), "only output short static header without version/timestamp"; "-show-tables", Arg.Unit Tables.print_all, " Show all current tables"; "-show-table", Arg.String Tables.print1, " Show specified table"; + "-enum-poly-variant", Arg.Unit (fun () -> Sqlgg_config.enum_as_poly_variant := true), " Represent enums as variants in generated code"; "-", Arg.Unit (fun () -> work "-"), " Read sql from stdin"; "-test", Arg.Unit Test.run, " Run unit tests"; ] diff --git a/src/gen.ml b/src/gen.ml index eeef70d8..b161d5f1 100644 --- a/src/gen.ml +++ b/src/gen.ml @@ -181,6 +181,7 @@ let substitute_vars s vars subst_param = in squash [] acc + let subst_named index p = "@" ^ (show_param_name p index) let subst_oracle index p = ":" ^ (show_param_name p index) let subst_postgresql index _ = "$" ^ string_of_int (index + 1) diff --git a/src/gen_caml.ml b/src/gen_caml.ml index c6a02b12..23e0ee97 100644 --- a/src/gen_caml.ml +++ b/src/gen_caml.ml @@ -119,21 +119,41 @@ let comment () fmt = Printf.kprintf (indent_endline $ make_comment) fmt let empty_line () = print_newline () +let enums_hash_tbl = Hashtbl.create 100 + +let enum_get_hash ctors = Type.Enum_kind.Ctors.elements ctors |> String.concat "_" + +let enum_name = Printf.sprintf "Enum_%d" + +let get_enum_name ctors = ctors |> enum_get_hash |> Hashtbl.find enums_hash_tbl |> fst |> enum_name + module L = struct open Type let as_lang_type = function | { t = Blob; nullability } -> type_name { t = Text; nullability } - | t -> type_name t + | { t = StringLiteral _; nullability } -> type_name { t = Text; nullability } + | { t = Unit _; _ } + | { t = Int; _ } + | { t = Text; _ } + | { t = Float; _ } + | { t = Bool; _ } + | { t = Datetime; _ } + | { t = Decimal; _ } + | { t = Union _; _ } + | { t = Any; _ } as t -> type_name t let as_api_type = as_lang_type end let get_column index attr = - sprintf "(T.get_column_%s%s stmt %u)" - (L.as_lang_type attr.domain) - (if is_attr_nullable attr then "_nullable" else "") - index + let rec print_column attr = match attr with + | { domain={ t = Union {ctors; _}; _ }; _ } when !Sqlgg_config.enum_as_poly_variant -> + sprintf "(%s.get_column%s" (get_enum_name ctors) + | { domain={ t = Union _; _ }; _ } as c -> print_column { c with domain = { c.domain with t = Text } } + | _ -> sprintf "(T.get_column_%s%s" (L.as_lang_type attr.domain) in + let column = print_column attr (if is_attr_nullable attr then "_nullable" else "") in + sprintf "%s stmt %u)" column index module T = Translate(L) @@ -188,9 +208,6 @@ let make_variant_name i name ~is_poly = let vname n = make_variant_name 0 (Some n) -let match_variant_wildcard i name args ~is_poly = - sprintf "%s%s" (make_variant_name i name ~is_poly) (match args with Some [] | None -> "" | Some _ -> " _") - let match_arg_pattern = function | Sql.Single _ | SingleIn _ | Choice _ | OptionBoolChoice _ @@ -208,20 +225,22 @@ let match_variant_pattern i name args ~is_poly = | l when List.for_all ((=) "_") l -> " _" | l -> sprintf " (%s)" (String.concat ", " l)) -let set_param index param = +let rec set_param index param = let nullable = is_param_nullable param in let pname = show_param_name param index in let ptype = show_param_type param in - if nullable then - output "begin match %s with None -> T.set_param_null p | Some v -> T.set_param_%s p v end;" pname ptype - else - output "T.set_param_%s p %s;" ptype pname - + let set_param_nullable = output "begin match %s with None -> T.set_param_null p | Some v -> %s p v end;" pname in + match param with + | { typ = { t=Union _; _}; _ } as c when not !Sqlgg_config.enum_as_poly_variant -> set_param index { c with typ = { c.typ with t = Text } } + | { typ = { t=Union {ctors; _}; _}; _ } when nullable -> set_param_nullable @@ (get_enum_name ctors) ^ ".set_param" + | { typ = { t=Union {ctors; _}; _ }; _ } -> output "%s.set_param p %s;" (get_enum_name ctors) pname + | param' when nullable -> set_param_nullable @@ sprintf "T.set_param_%s" (show_param_type param') + | _ -> output "T.set_param_%s p %s;" ptype pname + let rec set_var index var = match var with | Single p -> set_param index p - | SingleIn _ -> () - | TupleList _ -> () + | SingleIn _ | TupleList _ -> () | ChoiceIn { param = name; vars; _ } -> output "begin match %s with" (make_param_name index name); output "| [] -> ()"; @@ -276,6 +295,7 @@ let rec eval_count_params vars = | `BoolChoice v -> group_vars (static, choices, v::bool_choices, choices_in) xs | `ChoiceIn v -> group_vars (static, choices, bool_choices, v::choices_in) xs | `Choice v -> group_vars (static, v::choices, bool_choices, choices_in) xs + | `No -> group_vars (static, choices, bool_choices, choices_in) xs in group_vars ([], [], [], []) vars in @@ -350,12 +370,17 @@ let output_params_binder index vars = | [] -> "T.no_params" | vars -> output_params_binder index vars -let in_var_module _label typ = Sql.Type.type_name typ + +let make_to_literal = + let rec go domain = match domain with + | { Type.t = Union _; _ } when not !Sqlgg_config.enum_as_poly_variant -> go { domain with Type.t = Text } + | { Type.t = Union { ctors; _ }; _ } -> sprintf "%s.to_literal" (get_enum_name ctors) + | t -> sprintf "T.Types.%s.to_literal" (Sql.Type.type_name t) in go let gen_in_substitution var = if Option.is_none var.id.label then failwith "empty label in IN param"; - sprintf {code| "(" ^ String.concat ", " (List.map T.Types.%s.to_literal %s) ^ ")"|code} - (in_var_module (Option.get var.id.label) var.typ) + sprintf {code| "(" ^ String.concat ", " (List.map %s %s) ^ ")"|code} + (make_to_literal var.typ) (Option.get var.id.label) let gen_tuple_printer ~is_row _label schema = @@ -371,7 +396,7 @@ let gen_tuple_printer ~is_row _label schema = let { name; domain; _ } = attr in (if idx = 0 then "" else {|Buffer.add_string _sqlgg_b ", "; |}) ^ sprintf {|Buffer.add_string _sqlgg_b (%s);|} - (let to_literal = sprintf "T.Types.%s.to_literal %s" (in_var_module name domain) in + (let to_literal = sprintf "%s %s" (make_to_literal domain) in if is_attr_nullable attr then (sprintf {|match %s with None -> "NULL" | Some v -> %s|} name (to_literal "v") ) else to_literal name)) @@ -510,6 +535,80 @@ let generate_stmt style index stmt = dec_indent (); empty_line () +let sanitize_to_variant_name s = + let normalized = + let open String in + s + |> lowercase_ascii + |> map (function 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' as c -> c | _ -> '_') + |> capitalize_ascii + in match normalized.[0] with + | '0'..'9' -> "Num_" ^ normalized + | _ -> normalized + +let generate_enum_modules stmts = + let open Sql.Type.Enum_kind in + + let schemas = List.concat_map (fun stmt -> stmt.Gen.schema) stmts in + let vars = List.concat_map (fun stmt -> stmt.Gen.vars) stmts in + + let get_enum typ = match typ.Sql.Type.t with + | Union { ctors; _ } -> Some ctors + | Unit _ | Int | Text | Blob | Float | Bool | Datetime | Decimal | Any | StringLiteral _ -> None + in + + let schemas_to_enums schemas = schemas |> List.filter_map (fun { domain; _ } -> get_enum domain) in + + let rec vars_to_enums vars = List.concat_map (function + | Single { typ; _ } + | SingleIn { typ; _ } -> typ |> get_enum |> option_list + | OptionBoolChoice (_, vars, _) + | ChoiceIn { vars; _ } -> vars_to_enums vars + | Choice (_, ctor_list) -> + List.concat_map ( function + | Simple (_, vars) -> Option.map vars_to_enums vars |> option_list |> List.concat + | Verbatim _ -> [] + ) ctor_list + | TupleList (_, ( Where_in types | ValueRows { types; _ } )) -> + List.concat_map (fun typ -> typ |> get_enum |> option_list) types + | TupleList (_, Insertion schema) -> schemas_to_enums schema + ) vars in + + Hashtbl.reset enums_hash_tbl; + + let generate_enum_module enum_count enum = + let get_ctor_name x = x |> sanitize_to_variant_name |> vname ~is_poly:true in + let ctor_list = Ctors.elements enum in + output {| + module %s = T.Make_enum(struct + type t = [%s] + let inj = function %s | s -> failwith (Printf.sprintf "Invalid enum value: %%s" s) + let proj = function %s + end) + |} + (enum_name enum_count) + (ctor_list |> List.map get_ctor_name |> String.concat " | ") + (String.concat " " + (List.map (fun ctor -> Printf.sprintf "| \"%s\" -> %s" (String.escaped ctor) (get_ctor_name ctor)) ctor_list)) + (String.concat "" + (List.map (fun ctor -> Printf.sprintf "| %s -> \"%s\"" (get_ctor_name ctor) (String.escaped ctor)) ctor_list)) + in + + indented (fun () -> + let result = schemas_to_enums schemas @ vars_to_enums vars in + let (_: int * unit list) = List.fold_left_map begin fun acc enum -> + let hash = enum_get_hash enum in + if Hashtbl.mem enums_hash_tbl hash then acc, () + else begin + Hashtbl.add enums_hash_tbl hash (acc, enum); + acc + 1, begin empty_line (); generate_enum_module acc enum end + end + end 0 result in + () + ) + +let generate_enum_modules stmts = if !Sqlgg_config.enum_as_poly_variant then generate_enum_modules stmts + let generate ~gen_io name stmts = (* let types = @@ -525,6 +624,7 @@ let generate ~gen_io name stmts = empty_line (); inc_indent (); output "module IO = %s" io; + generate_enum_modules stmts; empty_line (); List.iteri (generate_stmt `Direct) stmts; output "module Fold = struct"; diff --git a/src/gen_csharp.ml b/src/gen_csharp.ml index 043d3128..78cedd47 100644 --- a/src/gen_csharp.ml +++ b/src/gen_csharp.ml @@ -42,6 +42,8 @@ let as_api_type t = | Decimal -> "Decimal" | Datetime -> "Datetime" | Any -> "String" + | Union _ + | StringLiteral _ -> "String" | Unit _ -> assert false let as_lang_type = as_api_type diff --git a/src/gen_java.ml b/src/gen_java.ml index 06019200..9ec532d1 100644 --- a/src/gen_java.ml +++ b/src/gen_java.ml @@ -42,6 +42,8 @@ let as_lang_type t = | Bool -> "boolean" | Decimal -> "float" (* BigDecimal? *) | Datetime -> "Timestamp" + | StringLiteral _ -> "String" + | Union _ | Unit _ -> assert false let as_api_type = String.capitalize_ascii $ as_lang_type diff --git a/src/sqlgg_config.ml b/src/sqlgg_config.ml index c9c710f5..a9492826 100644 --- a/src/sqlgg_config.ml +++ b/src/sqlgg_config.ml @@ -16,3 +16,5 @@ let debug1 () = !debug_level > 0 let gen_header : [ `Full | `Without_timestamp | `Static ] option ref = ref (Some `Full) let include_category : [ `All | `None | `Only of Stmt.category list | `Except of Stmt.category list ] ref = ref `All + +let enum_as_poly_variant = ref false diff --git a/src/test.ml b/src/test.ml index 74e94a0c..fc9532c3 100644 --- a/src/test.ml +++ b/src/test.ml @@ -74,8 +74,8 @@ let test = Type.[ tt "insert or replace into test values (2,?,?)" [] [param_nullable Text; param_nullable Text;]; tt "replace into test values (2,?,?)" [] [param_nullable Text; param_nullable Text;]; tt "select str, case when id > @id then name when id < @id then 'qqq' else @def end as q from test" - [attr' ~nullability:(Nullable) "str" Text; attr' ~nullability:(Nullable) "q" Text] - [named_nullable "id" Int; named_nullable "id" Int; named_nullable "def" Text]; + [attr' ~nullability:(Nullable) "str" Text; attr' ~nullability:(Nullable) "q" (StringLiteral "qqq")] + [named_nullable "id" Int; named_nullable "id" Int; named_nullable "def" (StringLiteral "qqq")]; wrong "insert into test values (1,2)"; wrong "insert into test (str,name) values (1,'str','name')"; (* check precedence of boolean and arithmetic operators *) @@ -169,8 +169,8 @@ let test_join_result_cols () = let test_enum = [ tt "CREATE TABLE test6 (x enum('true','false') COLLATE utf8_bin NOT NULL, y INT DEFAULT 0) ENGINE=MyISAM DEFAULT CHARSET=utf8" [] []; - tt "SELECT * FROM test6" [attr "x" Text ~extra:[NotNull;]; attr ~extra:[WithDefault;] "y" Int] []; - tt "SELECT x, y+10 FROM test6" [attr "x" Text ~extra:[NotNull;]; attr "" Int] []; + tt "SELECT * FROM test6" [attr "x" (Type.(Union { ctors = (Enum_kind.Ctors.of_list ["true"; "false"]); is_closed = true })) ~extra:[NotNull;]; attr ~extra:[WithDefault;] "y" Int] []; + tt "SELECT x, y+10 FROM test6" [attr "x" (Type.(Union { ctors = (Enum_kind.Ctors.of_list ["true"; "false"]); is_closed = true })) ~extra:[NotNull;]; attr "" Int] []; ] let test_manual_param = [ @@ -810,7 +810,7 @@ let test_select_exposed_alias = [ true as flag ) as inner_x (a, b, c, d) ) as outer_x (str, num, price, flag, bonus) |} [ - attr' "str" Text; + attr' "str" (StringLiteral "abc"); attr' "num" Int; attr' "price" Float; attr' "flag" Bool; @@ -818,6 +818,69 @@ let test_select_exposed_alias = [ ] []; ] +let test_enum_as_variant = [ + "test_enum_as_variant" >:: (fun _ -> + + do_test "CREATE TABLE test35 (status enum('active','pending','deleted') NOT NULL DEFAULT 'pending')" [] []; + + do_test "SELECT status FROM test35" [ + attr' ~extra:[NotNull; WithDefault] "status" + (Type.(Union { ctors = (Enum_kind.Ctors.of_list ["active"; "pending"; "deleted"]); is_closed = true })) + ] []; + + do_test "INSERT INTO test35 (status) VALUES (@status)" [] [ + named "status" (Type.(Union { ctors = (Enum_kind.Ctors.of_list ["active"; "pending"; "deleted"]); is_closed = true })) + ]; + ) +] + +let test_enum_literal () = + + do_test "CREATE TABLE test36 (status enum('active','pending','deleted') NOT NULL DEFAULT 'pending')" [] []; + + let stmt = parse {|INSERT INTO test36 VALUES('pending')|} in + assert_equal ~msg:"schema" ~printer:Sql.Schema.to_string [] stmt.schema; + + let stmt2 = parse {|INSERT INTO test36 VALUES('active')|} in + assert_equal ~msg:"schema" ~printer:Sql.Schema.to_string [] stmt2.schema; + + let stmt3 = parse {|INSERT INTO test36 VALUES('deleted')|} in + assert_equal ~msg:"schema" ~printer:Sql.Schema.to_string [] stmt3.schema; + + let stmt4 = parse {|SELECT * FROM test36 WHERE status = 'active'|} in + assert_equal ~msg:"schema" ~printer:Sql.Schema.to_string + [attr' ~extra:[NotNull; WithDefault] "status" + (Type.(Union { ctors = (Enum_kind.Ctors.of_list ["active"; "pending"; "deleted"] ); is_closed = true }))] + stmt4.schema; + + let stmt5 = parse {|UPDATE test36 SET status = 'deleted' WHERE status = 'pending'|} in + assert_equal ~msg:"schema" ~printer:Sql.Schema.to_string [] stmt5.schema; + + let stmt6 = parse {| + SELECT * FROM test36 + WHERE status IN ('active', 'pending') + AND status != 'deleted' + |} in + assert_equal ~msg:"schema" ~printer:Sql.Schema.to_string + [attr' ~extra:[NotNull; WithDefault] "status" + (Type.(Union { ctors = (Enum_kind.Ctors.of_list ["active"; "pending"; "deleted"]); is_closed = true }))] + stmt6.schema; + + ignore @@ wrong {|INSERT INTO test36 VALUES('deleteddd')|} ; + ignore @@ wrong {|INSERT INTO test36 VALUES((IF(TRUE, 'a', 'b')))|} ; + + let stmt7 = parse {|INSERT INTO test36 VALUES((IF(TRUE, 'pending', 'active')))|} in + assert_equal ~msg:"schema" ~printer:Sql.Schema.to_string [] stmt7.schema; + + ignore @@ wrong {|INSERT INTO test36 VALUES((IF(TRUE, 'pending', 'b')))|}; + ignore @@ wrong {|INSERT INTO test36 VALUES(CONCAT(''))|}; + + ignore @@ wrong {|SELECT * FROM test36 WHERE status = 'activee'|}; + + let stmt8 = parse {|SELECT CONCAT(status, 'test') AS named FROM test36 WHERE status = 'active'|} in + assert_equal ~msg:"schema" ~printer:Sql.Schema.to_string + [attr' ~extra:[] "named" Text] + stmt8.schema let run () = Gen.params_mode := Some Named; @@ -844,6 +907,8 @@ let run () = "test_subquery_nullability" >::: test_subquery_nullability; "test_values_row" >::: test_values_row; "test_select_exposed_alias" >::: test_select_exposed_alias; + "test_enum_as_variant" >::: test_enum_as_variant; + "test_enum_literal" >:: test_enum_literal; ] in let test_suite = "main" >::: tests in