From 6b9986acd8e1836ca0b432225ead0db33a643608 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 23 Jun 2020 01:40:45 +0200 Subject: [PATCH] keep parameter names from PyTorch (#5887) --- python/tvm/relay/frontend/pytorch.py | 4 ++-- tests/python/frontend/pytorch/test_forward.py | 9 +++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index f70a64a6c93c..374e1c2651cc 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -2354,8 +2354,8 @@ def convert_params(graph, state_dict): elif full_attr in state_dict: torch_tensor = state_dict[full_attr] tensor, var = _get_tensor_and_var(torch_tensor, - full_attr_node_name) - param_tensors[full_attr_node_name] = tensor + full_attr) + param_tensors[full_attr] = tensor params[full_attr_node_name] = var return params, param_tensors, packed_param_map diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 6ec3110dcf69..d56496577176 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -2384,6 +2384,12 @@ def fn(t1, t2): verify_model(fn, input_data=[tensor1, tensor2]) +def test_weight_names(): + tm = torch.jit.trace(torch.nn.Linear(3, 4), [torch.randn(2, 3)]) + mod, params = relay.frontend.from_pytorch(tm, [('input', (2, 3))]) + assert set(params.keys()) == set(n for n, p in tm.named_parameters()) + + def test_forward_matmul(): torch.set_grad_enabled(False) @@ -2546,8 +2552,11 @@ def test_forward_pretrained_bert_base_uncased(): if __name__ == "__main__": + # some structural tests test_forward_traced_function() test_forward_dtypes() + test_weight_names() + # Single operator tests test_forward_add() test_forward_subtract()