[WIP] Allow autocast for 1.6 #2384

Merged · 13 commits · Jul 9, 2020
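For context, this is the usage pattern the PR enables: under torch.cuda.amp.autocast (new in PyTorch 1.6), eligible CUDA ops run in float16 while the rest stay in float32. A minimal sketch, assuming a CUDA build; the model and input shape here are illustrative, not taken from this diff:

import torch
import torchvision.models as models

model = models.resnet18(num_classes=50).cuda().eval()
x = torch.rand(1, 3, 224, 224, device="cuda")

# Ops with float16-safe kernels (convolutions, matmuls) run in half precision;
# numerically sensitive ops are kept in float32.
with torch.cuda.amp.autocast():
    out = model(x)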
16 changes: 12 additions & 4 deletions test/common_utils.py
@@ -85,7 +85,7 @@ def is_iterable(obj):
class TestCase(unittest.TestCase):
precision = 1e-5

def assertExpected(self, output, subname=None, prec=None):
def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
r"""
Test that a python value matches the recorded contents of a file
derived from the name of this test and subname. The value must be
@@ -96,16 +96,24 @@ def assertExpected(self, output, subname=None, prec=None):

If you call this multiple times in a single function, you must
give a unique subname each time.

strip_suffix allows different tests that expect similar numerics, e.g.
"test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
strip_suffix="_cpu", and they would both use a data file name based on
"test_xyz".
"""
def remove_prefix(text, prefix):
def remove_prefix_suffix(text, prefix, suffix):
if text.startswith(prefix):
return text[len(prefix):]
text = text[len(prefix):]
if suffix is not None and text.endswith(suffix):
text = text[:len(text) - len(suffix)]
return text
# NB: we take __file__ from the module that defined the test
# class, so we place the expect directory where the test script
# lives, NOT where test/common_utils.py lives.
module_id = self.__class__.__module__
munged_id = remove_prefix(self.id(), module_id + ".")
munged_id = remove_prefix_suffix(self.id(), module_id + ".", strip_suffix)
test_file = os.path.realpath(sys.modules[module_id].__file__)
expected_file = os.path.join(os.path.dirname(test_file),
"expect",
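As a usage sketch of the new strip_suffix argument (the test names and run_model helper below are hypothetical), a CPU/CUDA test pair can share one expect file keyed on "test_xyz":

def test_xyz_cpu(self):
    out = run_model("cpu")  # hypothetical helper
    self.assertExpected(out, prec=0.1, strip_suffix="_cpu")

def test_xyz_cuda(self):
    out = run_model("cuda")  # hypothetical helper
    self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_cuda")

Both calls strip their device suffix from the test id, so they read and write the same pickled expect data derived from "test_xyz".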
206 changes: 138 additions & 68 deletions test/test_models.py
@@ -74,72 +74,114 @@ def get_available_video_models():
}


# The following models exhibit flaky numerics under autocast in _test_*_model harnesses.
# This may be caused by the harness environment (e.g. num classes, input initialization
# via torch.rand), and does not prove autocast is unsuitable when training with real data
# (autocast has been used successfully with real data for some of these models).
# TODO: investigate why autocast numerics are flaky in the harnesses.
#
# For the following models, the _test_*_model harnesses skip numerical checks on outputs when
# running under autocast. They still run an autocasted forward pass, however, which ensures
# autocast coverage suffices to prevent dtype errors in each model.
autocast_flaky_numerics = (
"fasterrcnn_resnet50_fpn",
"inception_v3",
"keypointrcnn_resnet50_fpn",
"maskrcnn_resnet50_fpn",
"resnet101",
"resnet152",
"wide_resnet101_2",
)


class ModelTester(TestCase):
def checkModule(self, model, name, args):
if name not in script_test_models:
return
unwrapper = script_test_models[name].get('unwrapper', None)
return super(ModelTester, self).checkModule(model, args, unwrapper=unwrapper, skip=False)

def _test_classification_model(self, name, input_shape):
def _test_classification_model(self, name, input_shape, dev):
set_rng_seed(0)
# passing a num_classes other than the default 1000 makes the test
# more rigorous
model = models.__dict__[name](num_classes=50)
model.eval()
x = torch.rand(input_shape)
model.eval().to(device=dev)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
self.assertExpected(out, prec=0.1)
self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
self.assertEqual(out.shape[-1], 50)
self.checkModule(model, name, (x,))

def _test_segmentation_model(self, name):
if dev == "cuda":
with torch.cuda.amp.autocast():
out = model(x)
# See autocast_flaky_numerics comment at top of file.
if name not in autocast_flaky_numerics:
self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
self.assertEqual(out.shape[-1], 50)

def _test_segmentation_model(self, name, dev):
# passing a num_classes other than the default 1000 makes the test
# more rigorous
model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False)
model.eval()
model.eval().to(device=dev)
input_shape = (1, 3, 300, 300)
x = torch.rand(input_shape)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
self.checkModule(model, name, (x,))

def _test_detection_model(self, name):
if dev == "cuda":
with torch.cuda.amp.autocast():
out = model(x)
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))

def _test_detection_model(self, name, dev):
set_rng_seed(0)
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
model.eval()
model.eval().to(device=dev)
input_shape = (3, 300, 300)
x = torch.rand(input_shape)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
model_input = [x]
out = model(model_input)
self.assertIs(model_input[0], x)
self.assertEqual(len(out), 1)

def subsample_tensor(tensor):
num_elems = tensor.numel()
num_samples = 20
if num_elems <= num_samples:
return tensor

flat_tensor = tensor.flatten()
ith_index = num_elems // num_samples
return flat_tensor[ith_index - 1::ith_index]

def compute_mean_std(tensor):
# can't compute mean of integral tensor
tensor = tensor.to(torch.double)
mean = torch.mean(tensor)
std = torch.std(tensor)
return {"mean": mean, "std": std}

# maskrcnn_resnet50_fpn is numerically unstable across platforms, so for now
# compare results with mean and std
if name == "maskrcnn_resnet50_fpn":
test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
# mean values are small, use large prec
self.assertExpected(test_value, prec=.01)
else:
self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor), prec=0.01)
def check_out(out):
self.assertEqual(len(out), 1)

def subsample_tensor(tensor):
num_elems = tensor.numel()
num_samples = 20
if num_elems <= num_samples:
return tensor

flat_tensor = tensor.flatten()
ith_index = num_elems // num_samples
return flat_tensor[ith_index - 1::ith_index]

def compute_mean_std(tensor):
# can't compute mean of integral tensor
tensor = tensor.to(torch.double)
mean = torch.mean(tensor)
std = torch.std(tensor)
return {"mean": mean, "std": std}

# maskrcnn_resnet50_fpn is numerically unstable across platforms, so for now
# compare results with mean and std
if name == "maskrcnn_resnet50_fpn":
test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
# mean values are small, use large prec
self.assertExpected(test_value, prec=.01, strip_suffix="_" + dev)
else:
self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor),
prec=0.01,
strip_suffix="_" + dev)

check_out(out)

scripted_model = torch.jit.script(model)
scripted_model.eval()
@@ -156,6 +198,13 @@ def compute_mean_std(tensor):
# self.check_script(model, name)
self.checkModule(model, name, ([x],))

if dev == "cuda":
with torch.cuda.amp.autocast():
out = model(model_input)
# See autocast_flaky_numerics comment at top of file.
if name not in autocast_flaky_numerics:
check_out(out)

def _test_detection_model_validation(self, name):
set_rng_seed(0)
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
@@ -179,18 +228,24 @@ def _test_detection_model_validation(self, name):
targets = [{'boxes': boxes}]
self.assertRaises(ValueError, model, x, targets=targets)

def _test_video_model(self, name):
def _test_video_model(self, name, dev):
# the default input shape is
# bs * num_channels * clip_len * h * w
input_shape = (1, 3, 4, 112, 112)
# test both BasicBlock and Bottleneck
model = models.video.__dict__[name](num_classes=50)
model.eval()
x = torch.rand(input_shape)
model.eval().to(device=dev)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
self.checkModule(model, name, (x,))
self.assertEqual(out.shape[-1], 50)

if dev == "cuda":
with torch.cuda.amp.autocast():
out = model(x)
self.assertEqual(out.shape[-1], 50)

def _make_sliced_model(self, model, stop_layer):
layers = OrderedDict()
for name, layer in model.named_children():
@@ -272,6 +327,12 @@ def test_googlenet_eval(self):

@unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
def test_fasterrcnn_switch_devices(self):
def checkOut(out):
self.assertEqual(len(out), 1)
self.assertTrue("boxes" in out[0])
self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0])

model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
model.cuda()
model.eval()
@@ -280,17 +341,20 @@ def test_fasterrcnn_switch_devices(self):
model_input = [x]
out = model(model_input)
self.assertIs(model_input[0], x)
self.assertEqual(len(out), 1)
self.assertTrue("boxes" in out[0])
self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0])

checkOut(out)

with torch.cuda.amp.autocast():
out = model(model_input)

checkOut(out)

# now switch to cpu and make sure it works
model.cpu()
x = x.cpu()
out_cpu = model([x])
self.assertTrue("boxes" in out_cpu[0])
self.assertTrue("scores" in out_cpu[0])
self.assertTrue("labels" in out_cpu[0])

checkOut(out_cpu)

def test_generalizedrcnn_transform_repr(self):

@@ -312,34 +376,40 @@ def test_generalizedrcnn_transform_repr(self):
self.assertEqual(t.__repr__(), expected_string)


_devs = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


for model_name in get_available_classification_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
input_shape = (1, 3, 224, 224)
if model_name in ['inception_v3']:
input_shape = (1, 3, 299, 299)
self._test_classification_model(model_name, input_shape)
for dev in _devs:
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name, dev=dev):
input_shape = (1, 3, 224, 224)
if model_name in ['inception_v3']:
input_shape = (1, 3, 299, 299)
self._test_classification_model(model_name, input_shape, dev)

setattr(ModelTester, "test_" + model_name, do_test)
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
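The default-argument idiom above (model_name=model_name, dev=dev) freezes each loop variable at function definition time. A standalone sketch of the late-binding pitfall it avoids:

# Closures capture variables, not values: all three lambdas see the final i.
late = [lambda: i for i in range(3)]
# Default arguments are evaluated at definition time, so each i is frozen.
early = [lambda i=i: i for i in range(3)]

print([f() for f in late])   # [2, 2, 2]
print([f() for f in early])  # [0, 1, 2]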


for model_name in get_available_segmentation_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
self._test_segmentation_model(model_name)
for dev in _devs:
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name, dev=dev):
self._test_segmentation_model(model_name, dev)

setattr(ModelTester, "test_" + model_name, do_test)
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)


for model_name in get_available_detection_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
self._test_detection_model(model_name)
for dev in _devs:
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name, dev=dev):
self._test_detection_model(model_name, dev)

setattr(ModelTester, "test_" + model_name, do_test)
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)

def do_validation_test(self, model_name=model_name):
self._test_detection_model_validation(model_name)
@@ -348,11 +418,11 @@ def do_validation_test(self, model_name=model_name):


for model_name in get_available_video_models():
for dev in _devs:
def do_test(self, model_name=model_name, dev=dev):
self._test_video_model(model_name, dev)

def do_test(self, model_name=model_name):
self._test_video_model(model_name)

setattr(ModelTester, "test_" + model_name, do_test)
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)

if __name__ == '__main__':
unittest.main()
20 changes: 16 additions & 4 deletions test/test_ops.py
@@ -52,25 +52,30 @@ def _test_backward(self, device, contiguous):


class RoIOpTester(OpTester):
def _test_forward(self, device, contiguous):
def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None):
x_dtype = self.dtype if x_dtype is None else x_dtype
rois_dtype = self.dtype if rois_dtype is None else rois_dtype
pool_size = 5
# n_channels % (pool_size ** 2) == 0 required for PS operations.
n_channels = 2 * (pool_size ** 2)
x = torch.rand(2, n_channels, 10, 10, dtype=self.dtype, device=device)
x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
if not contiguous:
x = x.permute(0, 1, 3, 2)
rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=self.dtype, device=device)
dtype=rois_dtype, device=device)

pool_h, pool_w = pool_size, pool_size
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1)
# the following should be true whether we're running an autocast test or not.
self.assertTrue(y.dtype == x.dtype)
gt_y = self.expected_fn(x, rois, pool_h, pool_w, spatial_scale=1,
sampling_ratio=-1, device=device, dtype=self.dtype)

self.assertTrue(torch.allclose(gt_y, y))
tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
self.assertTrue(torch.allclose(gt_y.to(y.dtype), y, rtol=tol, atol=tol))

def _test_backward(self, device, contiguous):
pool_size = 2
@@ -290,6 +295,13 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
def _test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_align)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_align_autocast(self):
for x_dtype in (torch.float, torch.half):
for rois_dtype in (torch.float, torch.half):
with torch.cuda.amp.autocast():
self._test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
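For reference, a minimal standalone sketch of what this test exercises, calling ops.roi_align under autocast; the shapes and box values are illustrative:

import torch
from torchvision import ops

x = torch.rand(1, 50, 10, 10, device="cuda")  # NCHW feature map
# Each row is (batch_index, x1, y1, x2, y2).
rois = torch.tensor([[0., 0., 0., 9., 9.]], device="cuda")

with torch.cuda.amp.autocast():
    y = ops.roi_align(x, rois, output_size=(5, 5), spatial_scale=1.0, sampling_ratio=-1)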


class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):