91 changes: 90 additions & 1 deletion test/test_modules.py
@@ -3,7 +3,9 @@
import tempfile

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCUDA, toleranceOverride, tol
)
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from)
@@ -156,6 +158,93 @@ def test_pickle(self, device, dtype, module_info):
output_from_copy = m_copy(*args, **kwargs)
self.assertEqual(output, output_from_copy)

def _retain_grad(self, obj):
# Gradients need to be retained so they can be checked after backward.
# This is useful when non-leaf tensors are present in the graph.
if isinstance(obj, dict):
for i in obj.values():
self._retain_grad(i)
elif isinstance(obj, (tuple, list)):
for i in obj:
self._retain_grad(i)
elif isinstance(obj, torch.Tensor) and obj.requires_grad:
obj.retain_grad()
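# Illustrative example (not part of the test) of why retain_grad()
# matters: a non-leaf tensor's .grad stays None after backward()
# unless retain_grad() was called on it first:
#   x = torch.ones(3, requires_grad=True)   # leaf tensor
#   y = x * 2                                # non-leaf tensor
#   y.retain_grad()
#   y.sum().backward()
#   assert y.grad is not None                # None without retain_grad()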

def _get_grads(self, obj):
if isinstance(obj, (tuple, list)):
return tuple(self._get_grads(o) for o in obj)
elif isinstance(obj, dict):
return {name: self._get_grads(o) for name, o in obj.items()}
elif isinstance(obj, torch.Tensor) and obj.requires_grad:
return obj.grad

@onlyCUDA
@toleranceOverride({torch.float32: tol(5e-2, 0),
torch.float64: tol(4e-4, 0)})
@modules(module_db)
def test_cpu_gpu_parity(self, device, dtype, module_info):
# Test that CPU and GPU results are the same
module_cls = module_info.module_cls
module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
requires_grad=True)

def _to_device(obj):
if isinstance(obj, torch.Tensor):
res = obj.detach().to(device=device)
res.requires_grad = obj.requires_grad
return res
elif isinstance(obj, tuple):
return tuple(_to_device(o) for o in obj)
elif isinstance(obj, dict):
return {key: _to_device(o) for key, o in obj.items()}
else:
return deepcopy(obj)

for module_input in module_inputs_cpu:

# === Move input from cpu to device ===
cpu_forward_args = module_input.forward_input.args
cpu_forward_kwargs = module_input.forward_input.kwargs

gpu_forward_args, gpu_forward_kwargs = _to_device((cpu_forward_args, cpu_forward_kwargs))

self._retain_grad((cpu_forward_args, cpu_forward_kwargs, gpu_forward_args, gpu_forward_kwargs))

# === Construct module on cpu and gpu ===
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)

for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
gpu_p.data.copy_(cpu_p)

# === Compare forward output between cpu and gpu ===
cpu_output = cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
gpu_output = gpu_module(*gpu_forward_args, **gpu_forward_kwargs)

self.assertEqual(cpu_output, gpu_output)

# === Run backward on CPU and GPU and compare gradients ===
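# Note: backward is repeated with retain_graph=True, so .grad accumulates
# across iterations; the comparisons below therefore also cover gradient
# accumulation parity between devices.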
for _ in range(5):
cpu_grad_output = cpu_output.clone().normal_()
gpu_grad_output = cpu_grad_output.type_as(gpu_output)

cpu_output.backward(cpu_grad_output, retain_graph=True)
gpu_output.backward(gpu_grad_output, retain_graph=True)

cpu_grad_input = self._get_grads(cpu_forward_args)
gpu_grad_input = self._get_grads(gpu_forward_args)
self.assertEqual(cpu_grad_input, gpu_grad_input)

for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
self.assertEqual(cpu_p.grad, gpu_p.grad)

cpu_grad_kwarg_input = self._get_grads(cpu_forward_kwargs)
gpu_grad_kwarg_input = self._get_grads(gpu_forward_kwargs)
self.assertEqual(cpu_grad_kwarg_input, gpu_grad_kwarg_input)


@modules([module_info for module_info in module_db
if 'inplace' in signature(module_info.module_cls).parameters])
def test_check_inplace(self, device, dtype, module_info):
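For readers who want the pattern in isolation: the new test boils down to the following minimal standalone sketch, assuming a CUDA-enabled build. The helper name `check_parity` and the `torch.nn.Linear` example are illustrative, not part of this PR.

import torch

def check_parity(module_cls, *args, input_shape=(4, 10), **kwargs):
    x_cpu = torch.randn(input_shape, requires_grad=True)
    x_gpu = x_cpu.detach().to("cuda").requires_grad_(True)

    m_cpu = module_cls(*args, **kwargs)
    m_gpu = module_cls(*args, **kwargs).to("cuda")
    # Start both modules from identical weights.
    with torch.no_grad():
        for p_cpu, p_gpu in zip(m_cpu.parameters(), m_gpu.parameters()):
            p_gpu.copy_(p_cpu)

    # Forward outputs must match across devices.
    out_cpu = m_cpu(x_cpu)
    out_gpu = m_gpu(x_gpu)
    torch.testing.assert_close(out_cpu, out_gpu.cpu())

    # Backpropagate the same grad_output on both devices and compare
    # the input gradients.
    grad_out = torch.randn_like(out_cpu)
    out_cpu.backward(grad_out)
    out_gpu.backward(grad_out.to("cuda"))
    torch.testing.assert_close(x_cpu.grad, x_gpu.grad.cpu())

check_parity(torch.nn.Linear, 10, 8)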
13 changes: 7 additions & 6 deletions torch/testing/_internal/common_modules.py
@@ -159,8 +159,8 @@ def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, **k

module_inputs = [
ModuleInput(constructor_input=FunctionInput(10, 8),
forward_input=FunctionInput(make_input((4, 10))),
reference_fn=lambda m, p, i: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
forward_input=FunctionInput(input=make_input((4, 10))),
reference_fn=lambda m, p, input: torch.mm(input, p[0].t()) + p[1].view(1, -1).expand(4, 8)),
ModuleInput(constructor_input=FunctionInput(10, 8, bias=False),
forward_input=FunctionInput(make_input((4, 10))),
desc='no_bias',
@@ -176,13 +176,14 @@ def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **k

def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)

cases: List[Tuple[str, dict]] = [
('', {}),
('ignore_index', {'ignore_index': 2}),
('weights', {'weight': make_input(10)}),
('weights_ignore_index', {'weight': make_input(10), 'ignore_index': 2}),
('weights_ignore_index_neg', {'weight': make_input(10), 'ignore_index': -1})
# ('ignore_index', {'ignore_index': 2}),
# ('weights', {'weight': make_weight(10).abs()}),
# ('weights_ignore_index', {'weight': make_weight(10).abs(), 'ignore_index': 2}),
# ('weights_ignore_index_neg', {'weight': make_weight(10).abs(), 'ignore_index': -1})
]
module_inputs = []
for desc, constructor_kwargs in cases:
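A note on the `make_weight` change above: NLLLoss treats `weight` as a fixed, non-negative per-class rescaling factor rather than a learnable parameter, hence `requires_grad=False` and the `.abs()` in the (currently disabled) cases. A minimal sketch of that usage, with all names illustrative:

import torch

# Per-class weights: fixed (no grad) and non-negative.
weight = torch.randn(10).abs()
loss_fn = torch.nn.NLLLoss(weight=weight)

log_probs = torch.randn(4, 10).log_softmax(dim=1).requires_grad_(True)
target = torch.randint(0, 10, (4,))

loss = loss_fn(log_probs, target)
loss.backward()  # gradients flow to log_probs, not to weight
assert weight.grad is None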