Skip to content

Commit

Permalink
update about NVML and environment
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Jul 15, 2020
1 parent 62ea40c commit 07e28ee
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 33 deletions.
3 changes: 2 additions & 1 deletion docs/tutorial/philosophy/dataflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ or when you need to filter your data on the fly.
but inefficient for generic data type or numpy arrays.
Also, its implementation [does not always clean up the subprocesses correctly](https://github.com/pytorch/pytorch/issues/16608).

PyTorch starts to improve on these bad assumptions (e.g., with [IterableDataset](https://github.com/pytorch/pytorch/pull/19228)).
PyTorch starts to improve on bad assumptions 1-3, (e.g., with IterableDataset).
But the interface still bears the history of these assumptions.
On the other hand, DataFlow:

1. Is an iterator, not necessarily has a length or can be indexed. This is more generic.
Expand Down
2 changes: 1 addition & 1 deletion examples/A3C-Gym/train-atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def cb(outputs):
try:
distrib, value = outputs.result()
except CancelledError:
logger.info("Client {} cancelled.".format(client.ident))
logger.info("Client {} cancelled.".format(client.ident.decode('utf-8')))
return
assert np.all(np.isfinite(distrib)), distrib
action = np.random.choice(len(distrib), p=distrib)
Expand Down
2 changes: 1 addition & 1 deletion examples/FasterRCNN/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ Performance in [Detectron](https://github.com/facebookresearch/Detectron/) can b
We compare models that have identical training & inference cost between the two implementations.
Their numbers can be different due to small implementation details.

<a id="ft2">2</a>: Our mAP is __7 point__ better than the official model in
<a id="ft2">2</a>: This model has __7 point__ better mAP than the official model in
[matterport/Mask_RCNN](https://github.com/matterport/Mask_RCNN/releases/tag/v2.0) which has the same architecture.
Our implementation is also [5x faster](https://github.com/tensorpack/benchmarks/tree/master/MaskRCNN).

Expand Down
16 changes: 8 additions & 8 deletions tensorpack/callbacks/prof.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ def worker(evt, rst_queue, stop_evt, devices):
Args:
devices (list[int])
"""
with NVMLContext() as ctx:
devices = [ctx.device(i) for i in devices]
while True:
try:
try:
with NVMLContext() as ctx:
devices = [ctx.device(i) for i in devices]
while True:
evt.wait() # start epoch
evt.clear()
if stop_evt.is_set(): # or on exit
Expand All @@ -153,10 +153,10 @@ def worker(evt, rst_queue, stop_evt, devices):
cnt -= 1
rst_queue.put(stats / cnt)
break
except Exception:
logger.exception("Exception in GPUUtilizationTracker.worker")
rst_queue.put(-1)
return
except Exception:
logger.exception("Exception in GPUUtilizationTracker.worker")
rst_queue.put(-1)
return


# Can add more features from tfprof
Expand Down
7 changes: 4 additions & 3 deletions tensorpack/tfutils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import tensorflow as tf
import numpy as np

import tensorpack
from ..compat import tfv1
from ..utils.argtools import graph_memoized
from ..utils.utils import find_library_full_path as find_library
Expand Down Expand Up @@ -172,10 +173,10 @@ def collect_env_info():
data = []
data.append(("sys.platform", sys.platform))
data.append(("Python", sys.version.replace("\n", "")))
data.append(("Tensorpack", __git_version__))
data.append(("Tensorpack", __git_version__ + " @" + os.path.dirname(tensorpack.__file__)))
data.append(("Numpy", np.__version__))

data.append(("TensorFlow", tfv1.VERSION + "/" + tfv1.GIT_VERSION))
data.append(("TensorFlow", tfv1.VERSION + "/" + tfv1.GIT_VERSION + " @" + os.path.dirname(tf.__file__)))
data.append(("TF Compiler Version", tfv1.COMPILER_VERSION))
has_cuda = tf.test.is_built_with_cuda()
data.append(("TF CUDA support", has_cuda))
Expand Down Expand Up @@ -221,7 +222,7 @@ def collect_env_info():
# Other important dependencies:
try:
import horovod
data.append(("Horovod", horovod.__version__))
data.append(("Horovod", horovod.__version__ + " @" + os.path.dirname(horovod.__file__)))
except ImportError:
pass

Expand Down
48 changes: 29 additions & 19 deletions tensorpack/utils/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,42 @@ def warn_return(ret, message):
logger.warn(message + "But TensorFlow was not built with CUDA support and could not use GPUs!")
return ret

try:
# Use NVML to query device properties
with NVMLContext() as ctx:
nvml_num_dev = ctx.num_devices()
except Exception:
nvml_num_dev = None

env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
if env:
return warn_return(len(env.split(',')), "Found non-empty CUDA_VISIBLE_DEVICES. ")
num_dev = len(env.split(','))
assert num_dev <= nvml_num_dev, \
"Only {} GPU(s) available, but CUDA_VISIBLE_DEVICES is set to {}".format(nvml_num_dev, env)
return warn_return(num_dev, "Found non-empty CUDA_VISIBLE_DEVICES. ")

output, code = subproc_call("nvidia-smi -L", timeout=5)
if code == 0:
output = output.decode('utf-8')
return warn_return(len(output.strip().split('\n')), "Found nvidia-smi. ")
try:
# Use NVML to query device properties
with NVMLContext() as ctx:
return warn_return(ctx.num_devices(), "NVML found nvidia devices. ")
except Exception:
# Fallback
logger.info("Loading local devices by TensorFlow ...")

try:
import tensorflow as tf
# available since TF 1.14
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
except AttributeError:
from tensorflow.python.client import device_lib
local_device_protos = device_lib.list_local_devices()
# Note this will initialize all GPUs and therefore has side effect
# https://github.com/tensorflow/tensorflow/issues/8136
gpu_devices = [x.name for x in local_device_protos if x.device_type == 'GPU']
return len(gpu_devices)
if nvml_num_dev is not None:
return warn_return(nvml_num_dev, "NVML found nvidia devices. ")

# Fallback to TF
logger.info("Loading local devices by TensorFlow ...")

try:
import tensorflow as tf
# available since TF 1.14
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
except AttributeError:
from tensorflow.python.client import device_lib
local_device_protos = device_lib.list_local_devices()
# Note this will initialize all GPUs and therefore has side effect
# https://github.com/tensorflow/tensorflow/issues/8136
gpu_devices = [x.name for x in local_device_protos if x.device_type == 'GPU']
return len(gpu_devices)


get_nr_gpu = get_num_gpu
2 changes: 2 additions & 0 deletions tensorpack/utils/nvml.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ def device(self, idx):
Returns:
NvidiaDevice: single GPU device
"""
num_dev = self.num_devices()
assert idx < num_dev, "Cannot obtain device {}: NVML only found {} devices.".format(idx, num_dev)

class GpuDevice(Structure):
pass
Expand Down

0 comments on commit 07e28ee

Please sign in to comment.