Skip to content

Commit

Permalink
syntax: move 'context' creation into class.
Browse files Browse the repository at this point in the history
  • Loading branch information
hnrgrgr committed Jan 10, 2011
1 parent 6a8aa43 commit 120e9b0
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 126 deletions.
136 changes: 69 additions & 67 deletions syntax/base.ml
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ struct
(fun (lookup : name -> expr) -> (o lookup)#repr)

let instantiate_modargs, instantiate_modargs_repr =
let lookup ctxt var =
let lookup argmap var =
try
`Constr ([NameMap.find var ctxt.argmap; "a"], [])
`Constr ([NameMap.find var argmap; "a"], [])
with Not_found ->
failwith ("Unbound type parameter '" ^ var)
in (fun ctxt -> instantiate (lookup ctxt)),
(fun ctxt -> instantiate_repr (lookup ctxt))
in (fun argmap -> instantiate (lookup argmap)),
(fun argmap -> instantiate_repr (lookup argmap))

let substitute env =
(object
Expand All @@ -62,8 +62,8 @@ struct
| e -> default# expr e
end) # expr

let cast_pattern ctxt ?(param="x") t =
let t = Untranslate.expr (instantiate_modargs ctxt t) in
let cast_pattern argmap ?(param="x") t =
let t = Untranslate.expr (instantiate_modargs argmap t) in
(<:patt< $lid:param$ >>,
<:expr<
let module M =
Expand Down Expand Up @@ -162,7 +162,7 @@ struct
List.fold_left (fun f p -> <:module_expr< $f$ ($p$) >>) f args

let atype_expr ctxt expr =
Untranslate.expr (instantiate_modargs ctxt expr)
Untranslate.expr (instantiate_modargs ctxt.argmap expr)

let atype ctxt (name, params, rhs, _, _) =
match rhs with
Expand Down Expand Up @@ -309,8 +309,58 @@ struct
| _, #M.low -> -1
| _ -> 0)
decls)

let default_generate ~make_module_expr ~make_module_type context decls =

let find_non_regular params tnames decls : name list =
List.concat_map
(object
inherit [name list] fold as default
method crush = List.concat
method expr = function
| `Constr ([t], args)
when NameSet.mem t tnames ->
(List.concat_map2
(fun (p,_) a -> match a with
| `Param (q,_) when p = q -> []
| _ -> [t])
params
args)
| e -> default#expr e
end)#decl decls

let extract_params =
let has_params params (_, ps, _, _, _) = ps = params in
function
| [] -> invalid_arg "extract_params"
| (_,params,_,_,_)::rest
when List.for_all (has_params params) rest ->
params
| (_,_,rhs,_,_)::_ ->
(* all types in a clique must have the same parameters *)
raise (Underivable ("Instances can only be derived for "
^"recursive groups where all types\n"
^"in the group have the same parameters."))

let setup_context (tdecls : decl list) : context =
let params = extract_params tdecls
and tnames = NameSet.fromList (List.map (fun (name,_,_,_,_) -> name) tdecls) in
match find_non_regular params tnames tdecls with
| _::_ as names ->
failwith ("The following types contain non-regular recursion:\n "
^String.concat ", " names
^"\nderiving does not support non-regular types")
| [] ->
let argmap =
List.fold_right
(fun (p,_) m -> NameMap.add p (Printf.sprintf "V_%s" p) m)
params
NameMap.empty in
{ argmap = argmap;
params = params;
tnames = tnames;
toplevel = None;
}

let default_generate ~make_module_expr ~make_module_type decls =
(* plan:
set up an enclosing recursive module
generate functors for all types in the clique
Expand All @@ -322,7 +372,7 @@ struct
- where there's no recursion
- etc.
*)
(* let _ = ensure_no_polymorphic_recursion in *)
let context = setup_context decls in
let wrapper_name = Printf.sprintf "%s_%s" classname (random_id 32) in
let make_functor =
List.fold_right
Expand Down Expand Up @@ -376,7 +426,8 @@ struct
else
<:module_type< $uid:runtimename$.$uid:classname$ with type a = $atype context decl$ >>

let default_generate_sigs ~make_module_sig context decls =
let default_generate_sigs ~make_module_sig decls =
let context = setup_context decls in
let make (tname, _, _ ,_, generated as decl) =
if generated
then <:sig_item< >>
Expand All @@ -386,59 +437,10 @@ struct
<:sig_item< $list:List.map make decls$ >>

end

let find_non_regular params tnames decls : name list =
List.concat_map
(object
inherit [name list] fold as default
method crush = List.concat
method expr = function
| `Constr ([t], args)
when NameSet.mem t tnames ->
(List.concat_map2
(fun (p,_) a -> match a with
| `Param (q,_) when p = q -> []
| _ -> [t])
params
args)
| e -> default#expr e
end)#decl decls

let extract_params =
let has_params params (_, ps, _, _, _) = ps = params in
function
| [] -> invalid_arg "extract_params"
| (_,params,_,_,_)::rest
when List.for_all (has_params params) rest ->
params
| (_,_,rhs,_,_)::_ ->
(* all types in a clique must have the same parameters *)
raise (Underivable ("Instances can only be derived for "
^"recursive groups where all types\n"
^"in the group have the same parameters."))

let setup_context loc (tdecls : decl list) : context =
let params = extract_params tdecls
and tnames = NameSet.fromList (List.map (fun (name,_,_,_,_) -> name) tdecls) in
match find_non_regular params tnames tdecls with
| _::_ as names ->
failwith ("The following types contain non-regular recursion:\n "
^String.concat ", " names
^"\nderiving does not support non-regular types")
| [] ->
let argmap =
List.fold_right
(fun (p,_) m -> NameMap.add p (Printf.sprintf "V_%s" p) m)
params
NameMap.empty in
{ argmap = argmap;
params = params;
tnames = tnames;
toplevel = None;
}

type deriver = Loc.t * context * Type.decl list -> Ast.str_item
and sigderiver = Loc.t * context * Type.decl list -> Ast.sig_item

type deriver = Loc.t * Type.decl list -> Ast.str_item
and sigderiver = Loc.t * Type.decl list -> Ast.sig_item
let derivers : (name, (deriver * sigderiver)) Hashtbl.t = Hashtbl.create 15
let register = Hashtbl.add derivers
let find classname =
Expand All @@ -451,13 +453,13 @@ module Register
(Desc : ClassDescription)
(MakeClass : functor(L : Loc) -> Class) = struct

let generate (loc, context, decls) =
let generate (loc, decls) =
let module Class = MakeClass(struct let loc = loc end) in
Class.generate context decls
Class.generate decls

let generate_sigs (loc, context, decls) =
let generate_sigs (loc, decls) =
let module Class = MakeClass(struct let loc = loc end) in
Class.generate_sigs context decls
Class.generate_sigs decls

let depends loc =
let module Class = MakeClass(struct let loc = loc end) in
Expand Down
14 changes: 8 additions & 6 deletions syntax/defs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ module type ClassDescription = sig
end

module type Class = sig
val generate: context -> Type.decl list -> Ast.str_item
val generate_sigs: context -> Type.decl list -> Ast.sig_item
val generate: Type.decl list -> Ast.str_item
val generate_sigs: Type.decl list -> Ast.sig_item
val generate_expr: context -> Type.expr -> Ast.module_expr
end

Expand Down Expand Up @@ -94,14 +94,16 @@ module type ClassHelpers = sig
val tuple: ?param:string -> int -> string list * Ast.patt * Ast.expr

val cast_pattern:
context -> ?param:string -> Type.expr -> Ast.patt * Ast.expr * Ast.expr
Type.name Type.NameMap.t -> ?param:string ->
Type.expr -> Ast.patt * Ast.expr * Ast.expr

(* For Functor only *)
val modname_from_qname: qname:string list -> classname:string -> Ast.ident
val substitute: Type.name Type.NameMap.t -> Type.expr -> Type.expr
val setup_context: Type.decl list -> context

(* For Pickle only *)
val instantiate_modargs_repr: context -> Type.repr -> Type.repr
val instantiate_modargs_repr: Type.name Type.NameMap.t -> Type.repr -> Type.repr

class virtual make_module_expr : generator
val make_module_sig: context -> Type.decl -> Ast.module_type
Expand All @@ -110,9 +112,9 @@ module type ClassHelpers = sig
val default_generate:
make_module_expr:(context -> Type.decl -> Ast.module_expr) ->
make_module_type:(context -> Type.decl -> Ast.module_type) ->
context -> Type.decl list -> Ast.str_item
Type.decl list -> Ast.str_item
val default_generate_sigs:
make_module_sig:(context -> Type.decl -> Ast.module_type) ->
context -> Type.decl list -> Ast.sig_item
Type.decl list -> Ast.sig_item

end
2 changes: 1 addition & 1 deletion syntax/dump_class.ml
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ module InContext (L : Loc) : Class = struct
<:match_case< `$name$ x -> $self#dump_int ctxt n$; $to_buffer$ >>,
<:match_case< $`int:n$ -> `$name$ ($from_stream$) >> end
| Extends t ->
let patt, guard, cast = cast_pattern ctxt t in
let patt, guard, cast = cast_pattern ctxt.argmap t in
let to_buffer =
<:expr< $self#call_expr ctxt t "to_buffer"$ buffer $cast$ >> in
let from_stream =
Expand Down
4 changes: 2 additions & 2 deletions syntax/eq_class.ml
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ module InContext (L : Loc) : Class = struct
| Tag (name, Some e) ->
<:match_case< `$name$ l, `$name$ r -> $self#call_expr ctxt e "eq"$ l r >>
| Extends t ->
let lpatt, lguard, lcast = cast_pattern ctxt ~param:"l" t in
let rpatt, rguard, rcast = cast_pattern ctxt ~param:"r" t in
let lpatt, lguard, lcast = cast_pattern ctxt.argmap ~param:"l" t in
let rpatt, rguard, rcast = cast_pattern ctxt.argmap ~param:"r" t in
let patt = <:patt< ($lpatt$, $rpatt$) >> in
let eq = <:expr< $self#call_expr ctxt t "eq"$ $lcast$ $rcast$ >> in
<:match_case< $patt$ when $lguard$ && $rguard$ -> $eq$ >>
Expand Down
4 changes: 1 addition & 3 deletions syntax/extend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ struct
fatal_error loc msg

let derive proj (loc : Loc.t) tdecls classname =
let context = display_errors loc (Base.setup_context loc) tdecls in
display_errors loc
(proj (Base.find classname)) (loc, context, tdecls)
display_errors loc (proj (Base.find classname)) (loc, tdecls)

let derive_str loc (tdecls : Type.decl list) classname : Ast.str_item =
derive fst loc tdecls classname
Expand Down
Loading

0 comments on commit 120e9b0

Please sign in to comment.