Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Coalesce Function for Proper Type Inference #128

Open
wants to merge 2 commits into
base: null
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 6 additions & 3 deletions lib/sql.ml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ struct
| Group of t (* _ -> t *)
| Agg (* 'a -> 'a *)
| Multi of tyvar * tyvar (* 'a -> ... -> 'a -> 'b *)
| Coalesce of tyvar * tyvar
| Ret of kind (* _ -> t *) (* TODO eliminate *)
| F of tyvar * tyvar list

Expand All @@ -125,13 +126,13 @@ struct
| Group ret -> fprintf pp "|_| -> %s" (show ret)
| Ret ret -> fprintf pp "_ -> %s" (show_kind ret)
| F (ret, args) -> fprintf pp "%s -> %s" (String.concat " -> " @@ List.map string_of_tyvar args) (string_of_tyvar ret)
| Multi (ret, each_arg) -> fprintf pp "{ %s }+ -> %s" (string_of_tyvar each_arg) (string_of_tyvar ret)
| Multi (ret, each_arg) | Coalesce (ret, each_arg) -> fprintf pp "{ %s }+ -> %s" (string_of_tyvar each_arg) (string_of_tyvar ret)

let string_of_func = Format.asprintf "%a" pp_func

let is_grouping = function
| Group _ | Agg -> true
| Ret _ | F _ | Multi _ -> false
| Ret _ | F _ | Multi _ | Coalesce _ -> false
end

module Constraint =
Expand Down Expand Up @@ -426,6 +427,7 @@ val exclude : int -> string -> unit
val monomorphic : Type.t -> Type.t list -> string -> unit
val multi : ret:Type.tyvar -> Type.tyvar -> string -> unit
val multi_polymorphic : string -> unit
val add_multi: Type.func -> string -> unit
val sponge : Type.func

end = struct
Expand Down Expand Up @@ -487,7 +489,7 @@ let () =
"floor" |> monomorphic int [float];
"nullif" |> add 2 (F (Var 0 (* TODO nullable *), [Var 0; Var 0]));
"ifnull" |> add 2 (F (Var 0, [Var 1; Var 0]));
["least";"greatest";"coalesce"] ||> multi_polymorphic;
["least";"greatest";] ||> multi_polymorphic;
"strftime" |> exclude 1; (* requires at least 2 arguments *)
["concat";"concat_ws";"strftime"] ||> multi ~ret:(Typ text) (Typ text);
"date" |> monomorphic datetime [datetime];
Expand All @@ -505,4 +507,5 @@ let () =
"substring_index" |> monomorphic text [text; text; int];
"last_insert_id" |> monomorphic int [];
"last_insert_id" |> monomorphic int [int];
add_multi Type.(Coalesce (Var 0, Var 0)) "coalesce";
()
33 changes: 24 additions & 9 deletions lib/syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,13 @@ and assign_types expr =
(String.concat ", " @@ List.map show types)
in
if !debug then eprintfn "func %s" (show_func ());
let types_to_arg each_arg = List.map (Fun.const each_arg) types in
let func =
match func with
| Multi (ret,each_arg) -> F (ret, List.map (fun _ -> each_arg) types)
| Multi (ret,each_arg) -> F (ret, types_to_arg each_arg)
| x -> x
in
let (ret,inferred_params) = match func, types with
| Multi _, _ -> assert false (* rewritten into F above *)
| Agg, [typ]
| Group typ, _ -> typ, types
| Agg, _ -> fail "cannot use this grouping function with %d parameters" (List.length types)
| F (_, args), _ when List.length args <> List.length types -> fail "wrong number of arguments : %s" (show_func ())
| F (ret, args), _ ->
let convert_args ret args =
let typevar = Hashtbl.create 10 in
List.iter2 begin fun arg typ ->
match arg with
Expand All @@ -206,8 +201,28 @@ and assign_types expr =
if !debug then typevar |> Hashtbl.iter (fun i typ -> eprintfn "%s : %s" (string_of_tyvar (Var i)) (show typ));
let convert = function Typ t -> t | Var i -> Hashtbl.find typevar i in
let args = List.map convert args in
args, convert ret in

let (ret,inferred_params) = match func, types with
| Multi _, _ -> assert false (* rewritten into F above *)
| Agg, [typ]
| Group typ, _ -> typ, types
| Agg, _ -> fail "cannot use this grouping function with %d parameters" (List.length types)
| F (_, args), _ when List.length args <> List.length types -> fail "wrong number of arguments : %s" (show_func ())
| Coalesce (ret, each_arg) , _ ->
let args = types_to_arg each_arg in
let args, ret = convert_args ret args in
let has_one_strict = List.exists (fun arg ->
match arg.nullability with
| Strict -> true | _ -> false
) types in
let ret = if has_one_strict then
{ ret with nullability = Strict }
else args |> common_nullability |> undepend ret in
ret , args
| F (ret, args), _ ->
let args, ret = convert_args ret args in
let nullable = common_nullability args in
let ret = convert ret in
undepend ret nullable, args
| Ret Any, _ -> (* lame *)
begin match common_supertype types with
Expand Down
12 changes: 12 additions & 0 deletions src/test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ let wrong sql =
sql >:: (fun () -> ("Expected error in : " ^ sql) @? (try ignore (Main.parse_one' (sql,[])); false with _ -> true))

let attr ?(extra=[]) n d = make_attribute n (Some d) (Constraints.of_list extra)

let attr' ?(extra=[]) ?(nullability=Type.Strict) name kind =
let domain: Type.t = { t = kind; nullability; } in
{name;domain;extra=Constraints.of_list extra }

let named s t = new_param { label = Some s; pos = (0,0) } (Type.strict t)
let named_nullable s t = new_param { label = Some s; pos = (0,0) } (Type.nullable t)
let param t = new_param { label = None; pos = (0,0) } (Type.strict t)
Expand Down Expand Up @@ -186,6 +191,12 @@ let test_manual_param = [
];
]

let test_coalesce = [
tt "CREATE TABLE test8 (x integer unsigned null)" [] [];
tt "SELECT COALESCE(x, null, null) as x FROM test8" [attr' ~nullability:(Nullable) "x" Int;] [];
tt "SELECT COALESCE(x, coalesce(null, null, 75, null), null) as x FROM test8" [attr "x" Int;] [];
]


let run () =
Gen.params_mode := Some Named;
Expand All @@ -199,6 +210,7 @@ let run () =
"JOIN result columns" >:: test_join_result_cols;
"enum" >::: test_enum;
"manual_param" >::: test_manual_param;
"test_coalesce" >::: test_coalesce;
]
in
let test_suite = "main" >::: tests in
Expand Down