Conversation

bddppq (Contributor) commented Nov 30, 2018

bddppq added the module: rocm (AMD GPU support for Pytorch) label Nov 30, 2018

ezyang (Contributor) commented Nov 30, 2018

TBH, I'd prefer it if we just arranged things so that run_test didn't have to import torch, which would also solve the problem. It's querying it to get the number of GPUs; we can probably just shell out to get that info.
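
A minimal sketch of that idea, for illustration only (query_gpu_count is a hypothetical helper, not part of run_test.py): ask a short-lived child process for the device count, so the parent script never imports torch itself.

import subprocess
import sys

def query_gpu_count():
    # Only the child process imports torch and touches the GPU runtime.
    out = subprocess.check_output(
        [sys.executable, '-c', 'import torch; print(torch.cuda.device_count())'])
    return int(out.decode().strip())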

bddppq (Contributor, Author) commented Nov 30, 2018

@ezyang The GPU usage could come from other places that run_test imports (e.g. run_test imports common_utils, which uses device_count, and functions other than device_count could trigger HIP initialization as well). If some future change accidentally introduces GPU usage into run_test, it will be hard to track down, because the ihipException failures are not 100% reproducible on all CI boxes.

ezyang (Contributor) commented Nov 30, 2018

It shouldn't import common_utils either. run_test.py literally used to be a shell script, and it should stay that way: a dumb shell script.

bddppq (Contributor, Author) commented Nov 30, 2018

Since this is going to be a temporary workaround (in the end AMD should fix the underlying issue, otherwise distributed training will run into problems), I don't think it's worth the effort of rewriting a 400-line Python script as a shell script. For example, you would need to re-implement the TEST_WITH_ROCM and black_list logic (not that it's hard, just that it needs reimplementation).

bddppq requested a review from ezyang November 30, 2018 03:35

petrex (Contributor) commented Nov 30, 2018

Is there a way we can try this PR on workers 20 and 30 only? If the SIGIOT error rate does not drop to a level similar to the other workers, there might be other issues.

bddppq (Contributor, Author) commented Nov 30, 2018

@petrex The ihipException issue has been reliably reproducible on workers 20 and 30 (that's why we pulled them out of the CI pool). I have verified on these two machines that the changes in this PR work around the issue (but again, this is just a workaround; we need to root-cause and fix this issue, otherwise I believe we will run into problems when enabling multi-GPU training).

@ezyang There are indeed failures when running all tests in the same process; any hypothesis? :-)

05:45:11 ======================================================================
05:45:11 FAIL: test_variance_stddev (test_distributions.TestAgainstScipy)
05:45:11 ----------------------------------------------------------------------
05:45:11 Traceback (most recent call last):
05:45:11   File "/var/lib/jenkins/workspace/test/test_distributions.py", line 3732, in test_variance_stddev
05:45:11     self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist)
05:45:11   File "/var/lib/jenkins/workspace/test/common_utils.py", line 381, in assertEqual
05:45:11     self.assertEqual(x.item(), y, prec, message, allow_inf)
05:45:11   File "/var/lib/jenkins/workspace/test/common_utils.py", line 439, in assertEqual
05:45:11     super(TestCase, self).assertLessEqual(abs(x - y), prec, message)
05:45:11 AssertionError: 2.9815544218135236e+72 not less than or equal to 1e-05 : Weibull(concentration: 0.0325175287673, scale: 8.64022227924)
05:45:11 
05:45:11 ======================================================================
05:45:11 FAIL: test_cdf (test_distributions.TestJit)
05:45:11 ----------------------------------------------------------------------
05:45:11 Traceback (most recent call last):
05:45:11   File "/var/lib/jenkins/workspace/test/test_distributions.py", line 4403, in test_cdf
05:45:11     message='{}\nExpected:\n{}\nActual:\n{}'.format(Dist.__name__, expected, actual))
05:45:11   File "/var/lib/jenkins/workspace/test/common_utils.py", line 414, in assertEqual
05:45:11     assertTensorsEqual(x, y)
05:45:11   File "/var/lib/jenkins/workspace/test/common_utils.py", line 406, in assertTensorsEqual
05:45:11     self.assertLessEqual(max_err, prec, message)
05:45:11 AssertionError: tensor(nan) not less than or equal to 1e-05 : Pareto
05:45:11 Expected:
05:45:11 tensor([[5.5865e+00, 5.5593e+01, 3.0885e-01, 8.0850e+25, 2.5194e+01],
05:45:11         [1.2260e+00,        inf, 3.7320e+06, 2.7672e+01, 3.4107e+00],
05:45:11         [2.4989e-01, 3.5010e-01, 1.5410e-02, 2.0407e+00, 3.4988e-01],
05:45:11         [8.5306e-02, 4.4983e+00, 1.1324e+00, 3.7377e-02, 6.2153e+02],
05:45:11         [6.8025e+00, 6.8080e+00, 5.4267e+01, 1.8983e-02, 6.8430e-01]])
05:45:11 Actual:
05:45:11 tensor([[5.5865e+00, 5.5593e+01, 3.0885e-01, 8.0850e+25, 2.5194e+01],
05:45:11         [1.2260e+00,        inf, 3.7320e+06, 2.7672e+01, 3.4107e+00],
05:45:11         [2.4989e-01, 3.5010e-01, 1.5410e-02, 2.0407e+00, 3.4988e-01],
05:45:11         [8.5306e-02, 4.4983e+00, 1.1324e+00, 3.7377e-02, 6.2153e+02],
05:45:11         [6.8025e+00, 6.8080e+00, 5.4267e+01, 1.8983e-02, 6.8430e-01]])

ezyang (Contributor) commented Nov 30, 2018

Yes, probably some of the tests are not properly resetting the seed, so running everything in the same process perturbs the RNG state they see.
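
A hedged illustration of the point (made-up names, not code from the test suite): a per-test reset in setUp keeps random draws independent of whatever ran earlier in the same process.

import unittest

import torch

SEED = 1234  # arbitrary fixed seed, chosen only for this example

class ReseededTest(unittest.TestCase):
    def setUp(self):
        # Re-seed before every test so earlier tests in the same process
        # cannot perturb the random numbers this test sees.
        torch.manual_seed(SEED)

    def test_randn_is_repeatable(self):
        a = torch.randn(3)
        torch.manual_seed(SEED)
        b = torch.randn(3)
        self.assertTrue(torch.equal(a, b))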

bddppq (Contributor, Author) commented Dec 1, 2018

@pytorchbot retest this please

bddppq force-pushed the pytorch-rocm-ci-no-fork branch from 8efeba4 to db0cfad December 3, 2018 09:08
bddppq closed this Dec 6, 2018
bddppq reopened this Dec 19, 2018
('index_copy', (), (0, torch.tensor(0, dtype=torch.int64), torch.tensor(2.)), 'scalar_all_dim', [0]),
('index_fill', (S, S), (0, index_variable(2, S), 2), 'dim', [0]),
# FIXME: we should compute the derivative w.r.t torch.tensor(2)
def method_tests():

bddppq (Contributor, Author) commented on this diff:
These test cases are used in test_autograd and test_jit. I suspect test_autograd modifies the test inputs/outputs in place in some places, which then causes test_jit to fail when it tries to reuse the same inputs and outputs. After changing this into a function, so that test_autograd and test_jit each get their own copy, the tests pass.
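
A sketch of the pattern described above (the names and the single entry are illustrative, not the actual test definitions): a module-level list is shared by every importer, so in-place edits made while one suite runs leak into the next, whereas a function hands each caller a fresh copy.

# Before: one shared list; mutating an entry in test_autograd is visible to test_jit.
METHOD_TESTS = [
    ('index_fill', (5, 5), (0, [1, 3], 2), 'dim', [0]),
]

# After: each call builds a new list, so mutations stay local to the caller.
def method_tests():
    return [
        ('index_fill', (5, 5), (0, [1, 3], 2), 'dim', [0]),
    ]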

facebook-github-bot (Contributor) left a comment

@bddppq has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

M = 10
S = 5


A contributor commented:

Surprised this didn't cause lint to fail lol

bddppq (Contributor, Author) replied:

Because this wasn't a function before.

expected = f(sample, *values)
actual = traced_f(sample, *values)
- self.assertEqual(expected, actual,
+ self.assertEqual(expected, actual, allow_inf=True,

A contributor asked:

What's up here?

bddppq (Contributor, Author) replied Dec 20, 2018:

There is one element that is inf in both the expected and actual results (#14600 (comment)).
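
For context, a small illustrative snippet (not the common_utils code) of why an inf-vs-inf comparison needs special handling: the difference of two infs is nan, and nan never satisfies a tolerance check, so a plain closeness assertion fails even though the values agree.

import math

expected = float('inf')
actual = float('inf')
diff = abs(expected - actual)      # inf - inf is nan
print(diff <= 1e-5)                # False: a plain tolerance check would fail here
print(math.isinf(expected) and math.isinf(actual))  # True: treat matching infs as equal instead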

ezyang (Contributor) left a comment

Looks ok, I guess? There are a few sprinkled seeding calls which look a bit unmotivated. But I guess it's all fine since we're still doing regular old out-of-process runs for non-ROCm.
