Skip to content

Commit

Permalink
Implement 'to' on ScriptModules
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Dec 18, 2018
1 parent bb9b7de commit 10fc5b3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
5 changes: 5 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,6 +1471,11 @@ def forward(self, x):
traced_model.cpu()
cpu_out = traced_model(x.float())
self.assertEqual(cpu_out, cuda_out)
traced_model.to('cuda')
cuda_out = traced_model(x.float().cuda())
traced_model.to('cpu')
cpu_out = traced_model(x.float())
self.assertEqual(cpu_out, cuda_out)
traced_model.double()

# state_dict + load_state_dict
Expand Down
2 changes: 1 addition & 1 deletion torch/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,7 +1267,7 @@ def _get_methods(cls):

_compiled_methods_whitelist = {
'forward', 'register_buffer', 'register_parameter', 'add_module',
'_apply', 'apply', 'cuda', 'cpu', 'type', 'float', 'double', 'half',
'_apply', 'apply', 'cuda', 'cpu', 'to', 'type', 'float', 'double', 'half',
'state_dict', 'load_state_dict', '_load_from_state_dict',
'_named_members', 'parameters', 'named_parameters',
'buffers', 'named_buffers', 'children', 'named_children', 'modules',
Expand Down

0 comments on commit 10fc5b3

Please sign in to comment.