@@ -189,6 +189,67 @@ def _dump_yaml(
189189 )
190190
191191
192+ def gen_oplist (
193+ output_path : str ,
194+ model_file_path : Optional [str ] = None ,
195+ ops_schema_yaml_path : Optional [str ] = None ,
196+ root_ops : Optional [str ] = None ,
197+ ops_dict : Optional [str ] = None ,
198+ include_all_operators : bool = False ,
199+ ):
200+ assert (
201+ model_file_path
202+ or ops_schema_yaml_path
203+ or root_ops
204+ or ops_dict
205+ or include_all_operators
206+ ), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or ops_dict or include_all_operators."
207+
208+ assert output_path , "Need to provide output_path for dumped yaml file."
209+ op_set = set ()
210+ source_name = None
211+ et_kernel_metadata = {}
212+ if root_ops :
213+ op_set .update (set (filter (lambda x : len (x ) > 0 , root_ops .split ("," ))))
214+ et_kernel_metadata = merge_et_kernel_metadata (
215+ et_kernel_metadata , {op : ["default" ] for op in op_set }
216+ )
217+ if ops_dict :
218+ ops_and_metadata = json .loads (ops_dict )
219+ for op , metadata in ops_and_metadata .items ():
220+ op_set .update ({op })
221+ op_metadata = metadata if len (metadata ) > 0 else ["default" ]
222+ et_kernel_metadata = merge_et_kernel_metadata (
223+ et_kernel_metadata , {op : op_metadata }
224+ )
225+ if model_file_path :
226+ assert os .path .isfile (
227+ model_file_path
228+ ), "The value for --model_file_path needs to be a valid file."
229+ op_set .update (_get_operators (model_file_path ))
230+ source_name = model_file_path
231+ et_kernel_metadata = merge_et_kernel_metadata (
232+ et_kernel_metadata , _get_kernel_metadata_for_model (model_file_path )
233+ )
234+ if ops_schema_yaml_path :
235+ assert os .path .isfile (
236+ ops_schema_yaml_path
237+ ), "The value for --ops_schema_yaml_path needs to be a valid file."
238+ et_kernel_metadata = merge_et_kernel_metadata (
239+ et_kernel_metadata ,
240+ _get_et_kernel_metadata_from_ops_yaml (ops_schema_yaml_path ),
241+ )
242+ op_set .update (et_kernel_metadata .keys ())
243+ source_name = ops_schema_yaml_path
244+ _dump_yaml (
245+ sorted (op_set ),
246+ output_path ,
247+ os .path .basename (source_name ) if source_name else None ,
248+ et_kernel_metadata ,
249+ include_all_operators ,
250+ )
251+
252+
192253def main (args : List [Any ]) -> None :
193254 """This binary generates selected_operators.yaml which will be consumed by caffe2/torchgen/gen.py.
194255 It reads the model file, deserialize it and dumps all the operators into selected_operators.yaml so
@@ -233,54 +294,14 @@ def main(args: List[Any]) -> None:
233294 required = False ,
234295 )
235296 options = parser .parse_args (args )
236- assert (
237- options .model_file_path
238- or options .ops_schema_yaml_path
239- or options .root_ops
240- or options .ops_dict
241- or options .include_all_operators
242- ), "Need to provide either model_file_path or ops_schema_yaml_path or root_ops or include_all_operators."
243- op_set = set ()
244- source_name = None
245- et_kernel_metadata = {}
246- if options .root_ops :
247- op_set .update (set (filter (lambda x : len (x ) > 0 , options .root_ops .split ("," ))))
248- et_kernel_metadata = merge_et_kernel_metadata (
249- et_kernel_metadata , {op : ["default" ] for op in op_set }
250- )
251- if options .ops_dict :
252- ops_and_metadata = json .loads (options .ops_dict )
253- for op , metadata in ops_and_metadata .items ():
254- op_set .update ({op })
255- op_metadata = metadata if len (metadata ) > 0 else ["default" ]
256- et_kernel_metadata = merge_et_kernel_metadata (
257- et_kernel_metadata , {op : op_metadata }
258- )
259- if options .model_file_path :
260- assert os .path .isfile (
261- options .model_file_path
262- ), "The value for --model_file_path needs to be a valid file."
263- op_set .update (_get_operators (options .model_file_path ))
264- source_name = options .model_file_path
265- et_kernel_metadata = merge_et_kernel_metadata (
266- et_kernel_metadata , _get_kernel_metadata_for_model (options .model_file_path )
267- )
268- if options .ops_schema_yaml_path :
269- assert os .path .isfile (
270- options .ops_schema_yaml_path
271- ), "The value for --ops_schema_yaml_path needs to be a valid file."
272- et_kernel_metadata = merge_et_kernel_metadata (
273- et_kernel_metadata ,
274- _get_et_kernel_metadata_from_ops_yaml (options .ops_schema_yaml_path ),
275- )
276- op_set .update (et_kernel_metadata .keys ())
277- source_name = options .ops_schema_yaml_path
278- _dump_yaml (
279- sorted (op_set ),
280- options .output_path ,
281- os .path .basename (source_name ) if source_name else None ,
282- et_kernel_metadata ,
283- options .include_all_operators ,
297+
298+ gen_oplist (
299+ output_path = options .output_path ,
300+ model_file_path = options .model_file_path ,
301+ ops_schema_yaml_path = options .ops_schema_yaml_path ,
302+ root_ops = options .root_ops ,
303+ ops_dict = options .ops_dict ,
304+ include_all_operators = options .include_all_operators ,
284305 )
285306
286307
0 commit comments