From 510a368173d4386abb81ae2ec14069eb788550db Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 2 Sep 2020 16:04:37 -0700 Subject: [PATCH 1/4] [quant][graphmode][fx] Support quantization for custom module Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- torch/quantization/fx/custom_module.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 torch/quantization/fx/custom_module.py diff --git a/torch/quantization/fx/custom_module.py b/torch/quantization/fx/custom_module.py new file mode 100644 index 000000000000..7d6004870fcc --- /dev/null +++ b/torch/quantization/fx/custom_module.py @@ -0,0 +1,13 @@ +CUSTOM_MODULES = set() + +def register_custom_module_class(custom_module_class): + ''' Register a module as custom module, when the module + appear in the code, we will observe and quantize it as one + unit + ''' + CUSTOM_MODULES.insert(custom_module_class) + +def is_custom_module(module): + ''' Check if a module is a custom module or not + ''' + return type(module) in CUSTOM_MODULES From a9f935e74d117738afca0b17eb9ec005a564f53f Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 9 Sep 2020 16:22:11 -0700 Subject: [PATCH 2/4] Update on "[quant][graphmode][fx] Support quantization for custom module" Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23580642](https://our.internmc.facebook.com/intern/diff/D23580642) [ghstack-poisoned] --- test/quantization/test_quantize_fx.py | 64 +++++++++++++++++-- .../quantization/fx/quantization_patterns.py | 2 - torch/quantization/fx/quantize.py | 53 ++++++++------- torch/quantization/quantize_fx.py | 4 +- 4 files changed, 89 insertions(+), 34 deletions(-) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index d86b5e5119c5..fbd94a6d86c1 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -274,16 +274,66 @@ def forward(self, x): x = self.custom(x) return x - m = symbolic_trace(M(), delegate_class=CustomDelegate) - m.eval() + class RefM(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + self.conv2 = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + data = torch.randn(1, 1, 1, 1) + # instantiate M and RefM and align the parameters + original_m = M() + original_ref_m = RefM() + original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach()) + original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach()) + original_ref_m.conv2.weight = torch.nn.Parameter(original_m.custom.conv.weight.detach()) + original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach()) + + m = symbolic_trace(original_m, delegate_class=CustomDelegate).eval() qconfig_dict = {'': default_qconfig} m = prepare_fx(m, qconfig_dict) - print(m.graph) - m = convert_fx(m) - print(m.graph) - print(m) - print(m.custom) + m(data) + # input and output of first conv, observer for custom module + # will be inserted in the custom module itself + count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 2 + } + self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) + # for output of conv in the custom module + count_check = { + ns.call_module(torch.quantization.MinMaxObserver): 1 + } + self.checkGraphModuleNodes(m.custom, expected_node_occurrence=count_check) + m = convert_fx(m) + count_check = { + ns.call_function(torch.quantize_per_tensor) : 1, + 
ns.call_module(nnq.Conv2d) : 1, + ns.call_method('dequantize') : 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=count_check) + count_check = { + # quantization of input happens in parent module + # quantization of output happens in the quantized conv module + ns.call_function(torch.quantize_per_tensor) : 0, + # dequantization for output happens in parent module + ns.call_method('dequantize') : 0, + } + self.checkGraphModuleNodes(m.custom, expected_node_occurrence=count_check) + res = m(data) + + # quantize the reference model + ref_m = symbolic_trace(original_ref_m).eval() + ref_m = prepare_fx(ref_m, qconfig_dict) + ref_m(data) + ref_m = convert_fx(ref_m) + ref_res = ref_m(data) + self.assertEqual(res, ref_res) class TestQuantizeFxOps(QuantizationTestCase): diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 04b9e6d9acc8..0de01dec8ce4 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py @@ -516,9 +516,7 @@ def convert(self, quantizer, node, load_arg, debug=False): else: convert = torch.quantization.convert_child_module_fx observed_custom_module = quantizer.modules[node.target] - print('observed custom module:', type(observed_custom_module)) quantized_custom_module = convert(observed_custom_module, debug=debug) - print('quantized custom module:', quantized_custom_module) parent_name, name = _parent_name(node.target) setattr(quantizer.modules[parent_name], name, quantized_custom_module) return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None)) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 98a40c024e12..b6f7df72f8c9 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -239,15 +239,26 @@ def load_arg(a): graph_inputs.append(node.name) get_new_observer_name = get_new_attr_name_with_prefix('activation_post_process_') + + def insert_observer(node, observer, device): + observer_name = get_new_observer_name(model) + setattr(model, observer_name, observer) + self.activation_post_process_map[node.name] = observer + env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {}) + observed_node_names_set.add(node.name) + if device: + getattr(model, observer_name).to(device) + for node in model.graph.nodes: if node.name in observed_node_names_set: continue root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None)) - print('match for node:', node.name, root_node) if root_node is None: env[node.name] = observed_graph.node_copy(node, load_arg) elif root_node is node: + # index for input of custom module that needs to be observed in parent + custom_module_input_idxs = None if node.name in custom_module_nodes: # observe custom module custom_module = self.modules[node.target] @@ -256,23 +267,15 @@ def load_arg(a): prepare = torch.quantization.prepare_dynamic_child_module_fx else: prepare = torch.quantization.prepare_child_module_fx - observed_custom_module, child_observed_idxs, output_is_observed = prepare(traced_custom_module, {'': qconfig}) + observed_custom_module, custom_module_input_idxs, \ + output_is_observed = prepare(traced_custom_module, {'': qconfig}) observed_custom_module._is_custom_module = True + observed_custom_module._observed_input_idxs = custom_module_input_idxs observed_custom_module._output_is_observed = output_is_observed parent_name, name = _parent_name(node.target) setattr(self.modules[parent_name], 
name, observed_custom_module) env[node.name] = observed_graph.node_copy(node, load_arg) - - def insert_observer(node, observer, device): - observer_name = get_new_observer_name(model) - setattr(model, observer_name, observer) - self.activation_post_process_map[node.name] = observer - env[node.name] = observed_graph.create_node('call_module', observer_name, (load_arg(node),), {}) - observed_node_names_set.add(node.name) - if device: - getattr(model, observer_name).to(device) - # don't need to insert observer for output in dynamic quantization if self.is_dynamic_quant: continue @@ -306,6 +309,12 @@ def is_observed(input_arg): # # respect device affinity when adding observers # device = assert_and_get_unique_device(model) # insert_observer(node, new_observer, device) + if custom_module_input_idxs is not None: + for idx in custom_module_input_idxs: + if node.args[idx].name not in observed_node_names_set: + new_observer = qconfig.activation() + device = assert_and_get_unique_device(model) + insert_observer(node.args[idx], new_observer, device) else: env[node.name] = observed_graph.node_copy(node, load_arg) @@ -318,6 +327,7 @@ def is_observed(input_arg): observer_name = get_new_observer_name(model) _, qconfig, is_weight = quants[node.name] if qconfig is not None: + # TODO: use insert_observer new_observer = \ qconfig.weight() if is_weight else qconfig.activation() # respect device affinity when adding observers @@ -399,22 +409,20 @@ def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False, is # add custom modules to the match for node in model.graph.nodes: - if node.op == 'call_module': - print('type:', type(self.modules[node.target])) if node.op == 'call_module' and \ hasattr(self.modules[node.target], '_is_custom_module') and \ self.modules[node.target]._is_custom_module: custom_module_qconfig = self.qconfig_map[node.name] - print('addding custom module mtches:', node.name) matches[node.name] = (node, [node], CustomModule(self, node), custom_module_qconfig) self.quantized_graph = Graph() env = {} quant_env = {} + graph_inputs = [] for node in model.graph.nodes: if node.op == 'placeholder': - env[node.name] = node + graph_inputs.append(node.name) def load_non_quantized(n): if n.name not in env: @@ -493,7 +501,6 @@ def is_quantized(node): for node in model.graph.nodes: root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None)) if root_node is node: - print('obj:', type(obj)) result = obj.convert(self, node, load_arg) if node.op == 'call_module' and is_custom_module(self.modules[node.target]): quantized = self.modules[node.target]._output_is_observed @@ -535,8 +542,12 @@ def is_quantized(node): root_module, self.quantized_graph, load_non_quantized(node.args[0]), observer_module) continue - # dequantize inputs for the node that are not quantized - env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized) + if is_child_module and node.op == 'placeholder' and graph_inputs.index(node.name) in model._observed_input_idxs: + # the node is quantized in parent module + quant_env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized) + else: + # dequantize inputs for the node that are not quantized + env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized) if is_child_module: # result are kepted quantized in the quantized child module @@ -699,13 +710,9 @@ def visit_arg(arg): # inputs map_arg(matched[-1].args, visit(matched[-1], qconfig)) map_arg(matched[-1].kwargs, visit(matched[-1], qconfig)) - if node.op == 
'call_module': - print('modulee:', type(self.modules[node.target])) - print('custom:', is_custom_module) # output if node.op == 'call_module' and \ is_custom_module(self.modules[node.target]): - print("skipping inserting observer for output of custom module") # we don't insert observer for output of custom # module continue diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 6a0e9fed60bc..0d0f3dd02283 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -46,7 +46,7 @@ def prepare_fx(graph_module, qconfig_dict, inplace=False): A GraphModule with observer or fake quant modules, ready for calibration or quantization aware training """ - prepared, _ = _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant=False) + prepared, _, _ = _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant=False) return prepared def prepare_static_fx(graph_module, qconfig_dict, inplace=False): @@ -72,7 +72,7 @@ def prepare_qat_fx(graph_module, qconfig_dict, inplace=False): def prepare_dynamic_fx(graph_module, qconfig_dict, inplace=False): r""" Prepare a model for post training dynamic quantization """ - prepared, _ = _prepare_fx(graph_module, qconfig_dict, inplace, True) + prepared, _, _ = _prepare_fx(graph_module, qconfig_dict, inplace, True) return prepared def _convert_fx(graph_module, inplace, debug, is_dynamic_quant, is_child_module=False): From 8d02562774509327ec38cd11eb570f29707a866d Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 9 Sep 2020 16:26:05 -0700 Subject: [PATCH 3/4] Update on "[quant][graphmode][fx] Support quantization for custom module" Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23580642](https://our.internmc.facebook.com/intern/diff/D23580642) [ghstack-poisoned] --- test/quantization/test_quantize_fx.py | 8 ++++++-- torch/quantization/fx/quantization_patterns.py | 1 - torch/quantization/fx/quantize.py | 3 ++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py index fbd94a6d86c1..f4ffa36e2165 100644 --- a/test/quantization/test_quantize_fx.py +++ b/test/quantization/test_quantize_fx.py @@ -254,10 +254,11 @@ def forward(self, x): return self.conv(x) from torch.fx.symbolic_trace import DefaultDelegate + class CustomDelegate(DefaultDelegate): def is_leaf_module(self, m): - return (m.__module__.startswith('torch.nn') and \ - not isinstance(m, torch.nn.Sequential)) or \ + return (m.__module__.startswith('torch.nn') and + not isinstance(m, torch.nn.Sequential)) or \ isinstance(m, CustomModule) from torch.quantization import register_traceable_custom_module_class @@ -296,7 +297,9 @@ def forward(self, x): m = symbolic_trace(original_m, delegate_class=CustomDelegate).eval() qconfig_dict = {'': default_qconfig} + # check prepared model m = prepare_fx(m, qconfig_dict) + # calibration m(data) # input and output of first conv, observer for custom module # will be inserted in the custom module itself @@ -310,6 +313,7 @@ def forward(self, x): } self.checkGraphModuleNodes(m.custom, expected_node_occurrence=count_check) + # check converted/quantized model m = convert_fx(m) count_check = { ns.call_function(torch.quantize_per_tensor) : 1, diff --git a/torch/quantization/fx/quantization_patterns.py b/torch/quantization/fx/quantization_patterns.py index 0de01dec8ce4..a0c24d970697 100644 --- a/torch/quantization/fx/quantization_patterns.py +++ b/torch/quantization/fx/quantization_patterns.py 
@@ -3,7 +3,6 @@ Node, ) from ..quantization_mappings import ( - get_static_quant_module_mapping, get_static_quant_module_class, get_quantized_operator, ) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index b6f7df72f8c9..48bce897e775 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -542,7 +542,8 @@ def is_quantized(node): root_module, self.quantized_graph, load_non_quantized(node.args[0]), observer_module) continue - if is_child_module and node.op == 'placeholder' and graph_inputs.index(node.name) in model._observed_input_idxs: + if is_child_module and node.op == 'placeholder' and \ + graph_inputs.index(node.name) in model._observed_input_idxs: # the node is quantized in parent module quant_env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized) else: From 7673461e5317c48bc1b5f1788dab5aed2f4436a1 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 9 Sep 2020 17:51:05 -0700 Subject: [PATCH 4/4] Update on "[quant][graphmode][fx] Support quantization for custom module" Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D23580642](https://our.internmc.facebook.com/intern/diff/D23580642) [ghstack-poisoned] --- torch/quantization/fx/fusion_patterns.py | 4 +-- torch/quantization/fx/quantize.py | 33 +++++++++++---------- torch/quantization/quantization_mappings.py | 3 +- torch/quantization/quantize_fx.py | 33 +++++++++++++++++++-- 4 files changed, 51 insertions(+), 22 deletions(-) diff --git a/torch/quantization/fx/fusion_patterns.py b/torch/quantization/fx/fusion_patterns.py index 866b4c8fa59f..fe5631d85482 100644 --- a/torch/quantization/fx/fusion_patterns.py +++ b/torch/quantization/fx/fusion_patterns.py @@ -104,8 +104,6 @@ def fuse(self, quantizer, load_arg): op_list.reverse() op_type_list = tuple(type(m) for m in op_list) module_parent_name, module_name = _parent_name(self.module_node.target) - fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None) - if fuser_method is None: - raise NotImplementedError("Cannot fuse modules: {}".format(types)) + fuser_method = get_fuser_method(op_type_list) setattr(quantizer.modules[module_parent_name], module_name, fuser_method(*op_list)) return quantizer.fused_graph.node_copy(self.module_node, load_arg) diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py index 48bce897e775..a4cfc6cfaa73 100644 --- a/torch/quantization/fx/quantize.py +++ b/torch/quantization/fx/quantize.py @@ -258,7 +258,7 @@ def insert_observer(node, observer, device): env[node.name] = observed_graph.node_copy(node, load_arg) elif root_node is node: # index for input of custom module that needs to be observed in parent - custom_module_input_idxs = None + child_module_input_idxs = None if node.name in custom_module_nodes: # observe custom module custom_module = self.modules[node.target] @@ -267,11 +267,9 @@ def insert_observer(node, observer, device): prepare = torch.quantization.prepare_dynamic_child_module_fx else: prepare = torch.quantization.prepare_child_module_fx - observed_custom_module, custom_module_input_idxs, \ - output_is_observed = prepare(traced_custom_module, {'': qconfig}) + observed_custom_module = prepare(traced_custom_module, {'': qconfig}) observed_custom_module._is_custom_module = True - observed_custom_module._observed_input_idxs = custom_module_input_idxs - observed_custom_module._output_is_observed = output_is_observed + child_module_input_idxs = observed_custom_module._observed_input_idxs parent_name, 
name = _parent_name(node.target) setattr(self.modules[parent_name], name, observed_custom_module) @@ -303,14 +301,16 @@ def is_observed(input_arg): output_is_observed = self.modules[node.target] if output_is_observed: observed_node_names_set.add(node.name) - # elif qconfig is not None and obj.all_nodes: - # # observer for outputs - # new_observer = qconfig.activation() - # # respect device affinity when adding observers - # device = assert_and_get_unique_device(model) - # insert_observer(node, new_observer, device) - if custom_module_input_idxs is not None: - for idx in custom_module_input_idxs: + elif qconfig is not None and obj.all_nodes: + # observer for outputs + new_observer = qconfig.activation() + # respect device affinity when adding observers + device = assert_and_get_unique_device(model) + insert_observer(node, new_observer, device) + + # insert observer for input of child module + if child_module_input_idxs is not None: + for idx in child_module_input_idxs: if node.args[idx].name not in observed_node_names_set: new_observer = qconfig.activation() device = assert_and_get_unique_device(model) @@ -342,11 +342,14 @@ def is_observed(input_arg): observed_graph.output(load_arg(model.graph.result)) # indicate whether output is observed or not. # This used for correctly quantize child modules - output_observed = model.graph.result.name in observed_node_names_set + output_is_observed = model.graph.result.name in observed_node_names_set model = GraphModule(model, observed_graph) self.save_state(model) - return model, observed_input_idxs, output_observed + if is_child_module: + model._observed_input_idxs = observed_input_idxs + model._output_is_observed = output_is_observed + return model def save_state(self, observed): observed._activation_post_process_map = self.activation_post_process_map diff --git a/torch/quantization/quantization_mappings.py b/torch/quantization/quantization_mappings.py index 23bf205b8172..528aea832389 100644 --- a/torch/quantization/quantization_mappings.py +++ b/torch/quantization/quantization_mappings.py @@ -28,6 +28,8 @@ nn.InstanceNorm1d: nnq.InstanceNorm1d, nn.InstanceNorm2d: nnq.InstanceNorm2d, nn.InstanceNorm3d: nnq.InstanceNorm3d, + nn.Embedding: nnq.Embedding, + nn.EmbeddingBag: nnq.EmbeddingBag, QuantStub: nnq.Quantize, DeQuantStub: nnq.DeQuantize, # Wrapper Modules: @@ -66,7 +68,6 @@ nn.LSTMCell: nnqd.LSTMCell, nn.RNNCell: nnqd.RNNCell, nn.GRUCell: nnqd.GRUCell, - nn.EmbeddingBag: nnqd.EmbeddingBag, } # Whitelist for propagating the qconfig diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py index 5cbabb0770bc..0dc6127ebbdf 100644 --- a/torch/quantization/quantize_fx.py +++ b/torch/quantization/quantize_fx.py @@ -25,12 +25,27 @@ def _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant, is_child_ graph_module = fuse_fx(graph_module, inplace) quantizer = Quantizer() prepare = quantizer.prepare_dynamic if is_dynamic_quant else quantizer.prepare - return prepare(graph_module, qconfig_dict, inplace=True, is_child_module) + return prepare(graph_module, qconfig_dict, inplace=True, is_child_module=is_child_module) def prepare_child_module_fx(graph_module, qconfig_dict, inplace=False): + r""" Prepare a child module, so that it can be used when quantizing the + parent module. Used in custom module support + + input of the module is quantized in parent module, output of the module + is quantized in the child module. 
+ Returns:
+   model(GraphModule): prepared child module with the following attributes:
+     _observed_input_idxs(List[Int]): a list of indices for the graph inputs that
+       need to be observed in the parent module
+     _output_is_observed(Bool): a boolean indicating whether the output of the
+       custom module is observed or not
+
+ """
 return _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant=False, is_child_module=True)

 def prepare_dynamic_child_module_fx(graph_module, qconfig_dict, inplace=False):
+ r""" See :func:`~torch.quantization.prepare_child_module_fx`
+ """
 return _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant=True, is_child_module=True)

 def prepare_fx(graph_module, qconfig_dict, inplace=False):
@@ -46,7 +61,7 @@ def prepare_fx(graph_module, qconfig_dict, inplace=False):
     A GraphModule with observer or fake quant modules, ready for
     calibration or quantization aware training
     """
-    prepared, _, _ = _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant=False)
+    prepared = _prepare_fx(graph_module, qconfig_dict, inplace, is_dynamic_quant=False)
     return prepared

 def prepare_static_fx(graph_module, qconfig_dict, inplace=False):
@@ -72,7 +87,7 @@ def prepare_qat_fx(graph_module, qconfig_dict, inplace=False):
 def prepare_dynamic_fx(graph_module, qconfig_dict, inplace=False):
     r""" Prepare a model for post training dynamic quantization
     """
-    prepared, _, _ = _prepare_fx(graph_module, qconfig_dict, inplace, True)
+    prepared = _prepare_fx(graph_module, qconfig_dict, inplace, True)
     return prepared

 def _convert_fx(graph_module, inplace, debug, is_dynamic_quant, is_child_module=False):
@@ -92,9 +107,21 @@ def convert_dynamic_fx(graph_module, inplace=False, debug=False):
     return _convert_fx(graph_module, inplace, debug, is_dynamic_quant=True)

 def convert_child_module_fx(graph_module, inplace=False, debug=False):
+    r""" Convert a model produced by :func:`~torch.quantization.prepare_child_module_fx`
+    to a quantized model
+
+    The inputs are quantized by the parent module; this checks `_observed_input_idxs` of
+    the input model, treats those inputs as already quantized,
+    and does not dequantize the final output
+    Return:
+      A quantized child module which accepts quantized input (if needed)
+      and produces quantized output
+    """
     return _convert_fx(graph_module, inplace, debug, is_dynamic_quant=False, is_child_module=True)

 def convert_dynamic_child_module_fx(graph_module, inplace=False, debug=False):
+    r""" See :func:`~torch.quantization.convert_child_module_fx`
+    """
     return _convert_fx(graph_module, inplace, debug, is_dynamic_quant=True, is_child_module=True)

 def _quantize_fx(model, qconfig_dict, run_fn=None, run_args=None, inplace=False,
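
For context, the end-to-end flow exercised by the test added in this stack looks roughly like the sketch below. This is a minimal usage sketch and not part of the patch: the entry points (`register_traceable_custom_module_class`, the `delegate_class` argument to `symbolic_trace`, and the `prepare_fx`/`convert_fx` signatures) and the import paths are taken from the diffs above as they stand at this revision, and may differ in later versions of the API.

import torch
from torch.fx import symbolic_trace
from torch.fx.symbolic_trace import DefaultDelegate
from torch.quantization import default_qconfig, register_traceable_custom_module_class
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class CustomModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)

class CustomDelegate(DefaultDelegate):
    # treat CustomModule as a leaf so symbolic tracing does not inline it
    def is_leaf_module(self, m):
        return (m.__module__.startswith('torch.nn') and
                not isinstance(m, torch.nn.Sequential)) or \
            isinstance(m, CustomModule)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.custom = CustomModule()

    def forward(self, x):
        return self.custom(self.conv(x))

# register the class so prepare/convert observe and quantize it as one unit
register_traceable_custom_module_class(CustomModule)

data = torch.randn(1, 1, 1, 1)
m = symbolic_trace(M(), delegate_class=CustomDelegate).eval()
m = prepare_fx(m, {'': default_qconfig})  # observers inserted, including inside m.custom
m(data)                                   # calibration
m = convert_fx(m)                         # m.custom is converted as a single unit
out = m(data)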