Skip to content

Commit

Permalink
Add support for snoop (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm committed Jul 10, 2019
1 parent be3ce93 commit 2ec0ec8
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 0 deletions.
10 changes: 10 additions & 0 deletions README.md
Expand Up @@ -29,6 +29,16 @@ To install:
pip install torchsnooper
```

TorchSnooper also support [snoop](https://github.com/alexmojaki/snoop). To use TorchSnooper with snoop, simply execute:
```python
torchsnooper.register_snoop()
```
or
```python
torchsnooper.register_snoop(verbose=True)
```
at the beginning, and use snoop normally.

# Example 1: Monitoring device and dtype

We're writing a simple function:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -27,5 +27,6 @@
'torch',
'python-toolbox',
'coverage',
'snoop',
],
)
105 changes: 105 additions & 0 deletions tests/test_snoop.py
@@ -0,0 +1,105 @@
import torch
import numpy
import math
import sys
import torchsnooper
from python_toolbox import sys_tools
import re
import snoop
import copy


ansi_escape = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]')
default_config = copy.copy(snoop.config)


def func():
x = torch.tensor(math.inf)
x = torch.tensor(math.nan)
x = torch.tensor(1.0, requires_grad=True)
x = torch.tensor([1.0, math.nan, math.inf])
x = numpy.zeros((2, 2))
x = (x, x)


verbose_expect = '''
21:43:42.09 >>> Call to func in File "test_snoop.py", line 16
21:43:42.09 16 | def func():
21:43:42.09 17 | x = torch.tensor(math.inf)
21:43:42.10 .......... x = tensor(inf)
21:43:42.10 .......... x.shape = ()
21:43:42.10 .......... x.dtype = torch.float32
21:43:42.10 .......... x.device = device(type='cpu')
21:43:42.10 .......... x.requires_grad = False
21:43:42.10 .......... x.has_nan = False
21:43:42.10 .......... x.has_inf = True
21:43:42.10 18 | x = torch.tensor(math.nan)
21:43:42.10 .......... x = tensor(nan)
21:43:42.10 .......... x.has_nan = True
21:43:42.10 .......... x.has_inf = False
21:43:42.10 19 | x = torch.tensor(1.0, requires_grad=True)
21:43:42.10 .......... x = tensor(1., requires_grad=True)
21:43:42.10 .......... x.requires_grad = True
21:43:42.10 .......... x.has_nan = False
21:43:42.10 20 | x = torch.tensor([1.0, math.nan, math.inf])
21:43:42.10 .......... x = tensor([1., nan, inf])
21:43:42.10 .......... x.shape = (3,)
21:43:42.10 .......... x.requires_grad = False
21:43:42.10 .......... x.has_nan = True
21:43:42.10 .......... x.has_inf = True
21:43:42.10 21 | x = numpy.zeros((2, 2))
21:43:42.10 .......... x = array([[0., 0.],
21:43:42.10 [0., 0.]])
21:43:42.10 .......... x.shape = (2, 2)
21:43:42.10 .......... x.dtype = dtype('float64')
21:43:42.10 22 | x = (x, x)
21:43:42.10 .......... x = (array([[0., 0.],
21:43:42.10 [0., 0.]]), array([[0., 0.],
21:43:42.10 [0., 0.]]))
21:43:42.10 .......... len(x) = 2
21:43:42.10 <<< Return value from func: None
'''.strip()

terse_expect = '''
21:44:09.63 >>> Call to func in File "test_snoop.py", line 16
21:44:09.63 16 | def func():
21:44:09.63 17 | x = torch.tensor(math.inf)
21:44:09.63 .......... x = tensor<(), float32, cpu, has_inf>
21:44:09.63 18 | x = torch.tensor(math.nan)
21:44:09.63 .......... x = tensor<(), float32, cpu, has_nan>
21:44:09.63 19 | x = torch.tensor(1.0, requires_grad=True)
21:44:09.63 .......... x = tensor<(), float32, cpu, grad>
21:44:09.63 20 | x = torch.tensor([1.0, math.nan, math.inf])
21:44:09.63 .......... x = tensor<(3,), float32, cpu, has_nan, has_inf>
21:44:09.63 21 | x = numpy.zeros((2, 2))
21:44:09.63 .......... x = ndarray<(2, 2), float64>
21:44:09.63 22 | x = (x, x)
21:44:09.63 .......... x = (ndarray<(2, 2), float64>, ndarray<(2, 2), float64>)
21:44:09.63 <<< Return value from func: None
'''.strip()


def clean_output(input_):
lines = input_.splitlines()[1:]
lines = [x[len('21:14:00.89 '):] for x in lines]
return '\n'.join(lines)


def assert_output(verbose, expect):
torchsnooper.register_snoop(verbose=verbose)
with sys_tools.OutputCapturer(stdout=False, stderr=True) as output_capturer:
assert sys.gettrace() is None
snoop(func)()
assert sys.gettrace() is None
output = output_capturer.string_io.getvalue()
output = ansi_escape.sub('', output)
assert clean_output(output) == clean_output(expect)
snoop.config = default_config


def test_verbose():
assert_output(True, verbose_expect)


def test_terse():
assert_output(False, terse_expect)
22 changes: 22 additions & 0 deletions torchsnooper/__init__.py
Expand Up @@ -151,3 +151,25 @@ def compute_repr(self, x):


snoop = TorchSnooper


def register_snoop(verbose=False, tensor_format=default_format, numpy_format=default_numpy_format):
import snoop
if verbose:
snoop.config.watch_extras += (
lambda source, value: ('{}.device'.format(source), value.device),
lambda source, value: ('{}.requires_grad'.format(source), value.requires_grad),
lambda source, value: ('{}.has_nan'.format(source), bool(torch.isnan(value).any())),
lambda source, value: ('{}.has_inf'.format(source), bool(torch.isinf(value).any())),
)
else:
import cheap_repr
import snoop.configuration
cheap_repr.register_repr(torch.Tensor)(lambda x, _: tensor_format(x))
cheap_repr.register_repr(numpy.ndarray)(lambda x, _: numpy_format(x))
cheap_repr.cheap_repr(torch.zeros(6))
unwanted = {
snoop.configuration.len_shape_watch,
snoop.configuration.dtype_watch,
}
snoop.config.watch_extras = tuple(x for x in snoop.config.watch_extras if x not in unwanted)

0 comments on commit 2ec0ec8

Please sign in to comment.