Skip to content

Commit

Permalink
out_error_bounds and out_racket: correctly handle constant expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
monadius committed Dec 3, 2017
1 parent c0d83a6 commit bd0e168
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 30 deletions.
4 changes: 2 additions & 2 deletions fptaylor.ml
Expand Up @@ -267,7 +267,7 @@ let absolute_errors task tf =

let () = try
Out_racket.create_racket_file (get_file_formatter "racket")
task.Task.name cs total2_i.high exp full_expr (Some total_i.high)
task total2_i.high exp full_expr (Some total_i.high)
with Not_found -> () in

Log.report `Important "exact bound (exp = %d): %s" exp (bound_info bound);
Expand Down Expand Up @@ -338,7 +338,7 @@ let relative_errors task tf (f_min, f_max) =
let () =
try
Out_racket.create_racket_file (get_file_formatter "racket")
task.Task.name cs b2_i.high exp full_expr (Some total_i.high)
task b2_i.high exp full_expr (Some total_i.high)
with Not_found -> () in

Log.report `Important "exact bound-rel (exp = %d): %s" exp (bound_info bound);
Expand Down
52 changes: 31 additions & 21 deletions out_error_bounds.ml
Expand Up @@ -197,18 +197,24 @@ let print_mpfr_and_init_f fmt env expr =
let args = List.map (fun (_, name) -> "mpfr_t " ^ name) env.vars in
let body, result_name =
Lib.write_to_string_result (translate_mpfr env) expr in
assert (result_name = "r_op");
let c_names_mpfr = List.map (fun (_, i) -> sprintf "c_%d" i) env.constants in
let c_names_double = List.map (fun (_, i) -> sprintf "c_%dd" i) env.constants in
let c_names_single = List.map (fun (_, i) -> sprintf "c_%df" i) env.constants in
fprintf fmt "static mpfr_t %a;@." (print_list ", ") env.tmp_vars;
fprintf fmt "static mpfr_t %a;@." (print_list ", ") c_names_mpfr;
fprintf fmt "static double %a;@." (print_list ", ") c_names_double;
fprintf fmt "static float %a;@." (print_list ", ") c_names_single;
let tmp_vars_flag = List.length env.tmp_vars > 0 in
let constants_flag = List.length env.constants > 0 in
if tmp_vars_flag then
fprintf fmt "static mpfr_t %a;@." (print_list ", ") env.tmp_vars;
if constants_flag then begin
fprintf fmt "static mpfr_t %a;@." (print_list ", ") c_names_mpfr;
fprintf fmt "static double %a;@." (print_list ", ") c_names_double;
fprintf fmt "static float %a;@." (print_list ", ") c_names_single;
end;
pp_print_newline fmt ();
fprintf fmt "void f_init()@.{@.";
fprintf fmt " mpfr_inits(%a, NULL);@." (print_list ", ") env.tmp_vars;
fprintf fmt " mpfr_inits(%a, NULL);@." (print_list ", ") c_names_mpfr;
if tmp_vars_flag then
fprintf fmt " mpfr_inits(%a, NULL);@." (print_list ", ") env.tmp_vars;
if constants_flag then
fprintf fmt " mpfr_inits(%a, NULL);@." (print_list ", ") c_names_mpfr;
List.iter
(fun (n, i) ->
fprintf fmt " init_constants(\"%s\", MPFR_RNDN, &c_%df, &c_%dd, c_%d);@."
Expand All @@ -217,12 +223,17 @@ let print_mpfr_and_init_f fmt env expr =
fprintf fmt "}@.";
pp_print_newline fmt ();
fprintf fmt "void f_clear()@.{@.";
fprintf fmt " mpfr_clears(%a, NULL);@." (print_list ", ") env.tmp_vars;
fprintf fmt " mpfr_clears(%a, NULL);@." (print_list ", ") c_names_mpfr;
if tmp_vars_flag then
fprintf fmt " mpfr_clears(%a, NULL);@." (print_list ", ") env.tmp_vars;
if constants_flag then
fprintf fmt " mpfr_clears(%a, NULL);@." (print_list ", ") c_names_mpfr;
fprintf fmt "}@.";
pp_print_newline fmt ();
fprintf fmt "void f_mpfr(mpfr_t r_op, %a)@.{@." (print_list ", ") args;
fprintf fmt "%s}@." body
fprintf fmt "%s" body;
if result_name <> "r_op" then
fprintf fmt " mpfr_set(r_op, %s, MPFR_RNDN);@." result_name;
fprintf fmt "}@."

let print_double_f fmt env expr =
clear_exprs env;
Expand All @@ -241,23 +252,22 @@ let print_single_f fmt env expr =
fprintf fmt "%s@. return %s;@.}@." body result_name

let generate_error_bounds fmt task =
let task_vars = all_variables task in
let var_bounds = List.map (variable_interval task) task_vars in
let task_vars, var_bounds =
let vars = all_active_variables task in
let bounds = List.map (variable_interval task) vars in
match vars with
| [] -> ["unused"], [Interval.make_interval 1. 2.]
| _ -> vars, bounds in
let var_names = List.map (fun s -> "v_" ^ ExprOut.fix_name s) task_vars in
let env = mk_env (Lib.zip task_vars var_names) in
let expr = remove_rnd task.expression in
fprintf fmt "#include \"search_mpfr.h\"@.";
fprintf fmt "#include \"search_mpfr_utils.h\"@.";
pp_print_newline fmt ();
if List.length var_bounds > 0 then begin
let low_str = List.map (fun b -> sprintf "%.20e" b.Interval.low) var_bounds in
let high_str = List.map (fun b -> sprintf "%.20e" b.Interval.high) var_bounds in
fprintf fmt "double low[] = {%a};@." (print_list ", ") low_str;
fprintf fmt "double high[] = {%a};@." (print_list ", ") high_str;
end else begin
fprintf fmt "double low[] = {0};@.";
fprintf fmt "double high[] = {0};@."
end;
let low_str = List.map (fun b -> sprintf "%.20e" b.Interval.low) var_bounds in
let high_str = List.map (fun b -> sprintf "%.20e" b.Interval.high) var_bounds in
fprintf fmt "double low[] = {%a};@." (print_list ", ") low_str;
fprintf fmt "double high[] = {%a};@." (print_list ", ") high_str;
pp_print_newline fmt ();
print_mpfr_and_init_f fmt env expr;
pp_print_newline fmt ();
Expand Down
17 changes: 11 additions & 6 deletions out_racket.ml
Expand Up @@ -14,18 +14,23 @@ open Expr

module Out = ExprOut.Make(ExprOut.RacketIntervalPrinter)

let gen_racket_function fmt (name, cs, total2, exp, e, opt_bound) =
let gen_racket_function fmt (task, total2, exp, e, opt_bound) =
let n2s = Num.string_of_num in
let f2s f = n2s (More_num.num_of_float f) in
let p' = Format.pp_print_string fmt in
let var_names = vars_in_expr e in
let var_bounds = List.map cs.var_rat_bounds var_names in
let var_names, var_bounds =
let e_vars = vars_in_expr e in
let names = Task.all_variables task in
let vars = List.filter (fun v -> List.mem v e_vars) names in
match vars with
| [] -> ["unused"], [(Num.Int 1, Num.Int 2)]
| _ -> vars, List.map (Task.variable_num_interval task) vars in
let bound_strings =
List.map (fun (low, high) ->
Format.sprintf "(cons %s %s)" (n2s low) (n2s high))
var_bounds in
let vars = List.map (fun v -> v ^ "-var") var_names in
Format.fprintf fmt "(define name \"%s\")@." name;
Format.fprintf fmt "(define name \"%s\")@." task.Task.name;
Format.fprintf fmt "(define opt-max %s)@."
(match opt_bound with None -> "#f" | Some f -> f2s f);
Format.fprintf fmt "(define bounds (list %s))@." (String.concat " " bound_strings);
Expand All @@ -37,5 +42,5 @@ let gen_racket_function fmt (name, cs, total2, exp, e, opt_bound) =
Out.print_fmt ~margin:80 fmt e;
p' ")"

let create_racket_file fmt ~name cs total2 exp expr opt_bound =
gen_racket_function fmt (name, cs, total2, exp, expr, opt_bound)
let create_racket_file fmt task total2 exp expr opt_bound =
gen_racket_function fmt (task, total2, exp, expr, opt_bound)
2 changes: 1 addition & 1 deletion out_racket.mli
Expand Up @@ -10,4 +10,4 @@
(* Racket output for FPTaylor expressions *)
(* -------------------------------------------------------------------------- *)

val create_racket_file : Format.formatter -> name:string -> Expr.constraints -> float -> int -> Expr.expr -> float option -> unit
val create_racket_file : Format.formatter -> Task.task -> float -> int -> Expr.expr -> float option -> unit
5 changes: 5 additions & 0 deletions task.ml
Expand Up @@ -32,6 +32,11 @@ type task = {
let all_variables t =
List.map (fun v -> v.var_name) t.variables

let all_active_variables t =
let vars = Expr.vars_in_expr t.expression in
let names = all_variables t in
List.filter (fun name -> List.mem name vars) names

let find_variable t name =
List.find (fun v -> v.var_name = name) t.variables

Expand Down
2 changes: 2 additions & 0 deletions task.mli
Expand Up @@ -27,6 +27,8 @@ type task = {

val all_variables : task -> string list

val all_active_variables : task -> string list

val find_variable : task -> string -> var_info

val variable_type : task -> string -> Rounding.value_type
Expand Down

0 comments on commit bd0e168

Please sign in to comment.