Skip to content

Commit

Permalink
print TF build info
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx committed Aug 8, 2020
1 parent 775aa3c commit 43a44c1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
32 changes: 28 additions & 4 deletions tensorpack/tfutils/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# File: common.py

from collections import defaultdict
from collections import defaultdict, OrderedDict
from six.moves import map
from tabulate import tabulate
import os
Expand Down Expand Up @@ -165,6 +165,28 @@ def get_tf_version_tuple():
return tuple(map(int, tf.__version__.split('.')[:2]))


def parse_TF_build_info():
ret = OrderedDict()
from tensorflow.python.platform import build_info
try:
for k, v in list(build_info.build_info.items()):
if k == "cuda_version":
ret["TF built with CUDA"] = v
elif k == "cudnn_version":
ret["TF built with CUDNN"] = v
elif k == "cuda_compute_capabilities":
ret["TF compute capabilities"] = ",".join([k.replace("compute_", "") for k in v])
return ret
except AttributeError:
pass
try:
ret["TF built with CUDA"] = build_info.cuda_version_number
ret["TF built with CUDNN"] = build_info.cudnn_version_number
except AttributeError:
pass
return ret


def collect_env_info():
"""
Returns:
Expand Down Expand Up @@ -195,9 +217,11 @@ def collect_env_info():

if has_cuda:
data.append(("Nvidia Driver", find_library("nvidia-ml")))
data.append(("CUDA", find_library("cudart")))
data.append(("CUDNN", find_library("cudnn")))
data.append(("NCCL", find_library("nccl")))
data.append(("CUDA libs", find_library("cudart")))
data.append(("CUDNN libs", find_library("cudnn")))
for k, v in parse_TF_build_info().items():
data.append((k, v))
data.append(("NCCL libs", find_library("nccl")))

# List devices with NVML
data.append(
Expand Down
2 changes: 1 addition & 1 deletion tensorpack/tfutils/varmanip.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def get_checkpoint_path(path):
Args:
path: a user-input path
Returns:
str: the argument that can be passed to NewCheckpointReader
str: the argument that can be passed to `tf.train.NewCheckpointReader`
"""
if os.path.basename(path) == path:
path = os.path.join('.', path) # avoid #4921 and #6142
Expand Down

0 comments on commit 43a44c1

Please sign in to comment.