Skip to content

Commit

Permalink
Merge pull request #182 from pytorch/myguo/s3
Browse files Browse the repository at this point in the history
add s3&azure blob&google cloud support for tb_plugin
  • Loading branch information
guotuofeng committed May 11, 2021
2 parents 943af94 + b1e3db9 commit edf8708
Show file tree
Hide file tree
Showing 19 changed files with 1,458 additions and 227 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tb_plugin_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ jobs:
set -e
cd tb_plugin
sh ./ci_scripts/install_env.sh
pip install .
pip install .[gs]
cd test
pytest
24 changes: 24 additions & 0 deletions tb_plugin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,30 @@ and give optimization recommendations.

If the files under `--logdir` are too big or too many,
please wait a while and refresh the browser to check latest loaded result.
* AWS (S3://), Azure Blob (https://\<account\>.blob.core.windows.net) and Google Cloud (GS://) support
* S3: install boto3. Set the environment variables `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`. Optionally, `S3_ENDPOINT` can be set as well.\
For minio, the S3 url should start with the bucket name `s3://<bucket>/<folder>/` instead of minio prefix `s3://minio/<bucket>/<folder>`. At the same time, the `S3_ENDPOINT` is needed as well. \
For example, the following command can be used to create minio storage after following guides:
* Server: https://docs.min.io/docs/minio-quickstart-guide.html
* MC Client: https://docs.min.io/docs/minio-client-quickstart-guide.html

```bash
./mc alias set s3 http://10.150.148.189:9000 minioadmin minioadmin
./mc mb s3/profiler --region=us-east-1
./mc cp ~/notebook/version_2 s3/profiler/ --recursive
export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin
export AWS_REGION=us-east-1
export S3_USE_HTTPS=0
export S3_VERIFY_SSL=0
export S3_ENDPOINT=http://localhost:9000
tensorboard --logdir=s3://profiler/version_2/ --bind_all
```
* Azure Blob: install azure-storage-blob. Optionally, set environment variable `AZURE_STORAGE_CONNECTION_STRING`
* Google Cloud: install google-cloud-storage.
---
> **_NOTES:_** For AWS, Google Cloud and Azure Blob, the trace files need to be put on a top level folder under bucket/container.
---

### Quick Usage Instructions

Expand Down
14 changes: 8 additions & 6 deletions tb_plugin/setup.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# --------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import setuptools

Expand All @@ -27,6 +22,12 @@ def get_version():
"torchvision >= 0.8"
]

EXTRAS = {
"s3": ["boto3"],
"blob": ["azure-storage-blob"],
"gs": ["google-cloud-storage"]
}

setuptools.setup(
name="torch_tb_profiler",
version=get_version(),
Expand All @@ -45,7 +46,7 @@ def get_version():
"torch_profiler = torch_tb_profiler.plugin:TorchProfilerPlugin",
],
},
python_requires=">= 2.7, != 3.0.*, != 3.1.*",
python_requires=">=3.6.2",
install_requires=INSTALL_REQUIRED,
tests_require=TESTS_REQUIRED,
classifiers=[
Expand All @@ -63,4 +64,5 @@ def get_version():
],
license='BSD-3',
keywords='pytorch tensorboard profile plugin',
extras_require=EXTRAS
)
79 changes: 51 additions & 28 deletions tb_plugin/test/test_tensorboard_end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,39 @@

class TestEnd2End(unittest.TestCase):

def test_tensorboard_gs(self):
    """End-to-end check against the sample traces hosted on a public GS bucket.

    Runs in spawn mode only, since reading from cloud storage exercises the
    pickle/unpickle path of the loader processes.
    """
    gs_logdir = 'gs://pe-tests-public/tb_samples/'
    runs_json = b'["resnet50_profiler_api_num_workers_0", "resnet50_profiler_api_num_workers_4"]'
    self._test_tensorboard_with_arguments(
        gs_logdir, runs_json, {'TORCH_PROFILER_START_METHOD': 'spawn'})

def test_tensorboard_end2end(self):
    """End-to-end check against the local sample traces, in fork then spawn mode."""
    samples_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), '../samples')
    runs_json = b'["resnet50_num_workers_0", "resnet50_num_workers_4"]'

    print("starting fork mode testing")
    self._test_tensorboard_with_arguments(samples_dir, runs_json)
    print("starting spawn mode testing...")
    self._test_tensorboard_with_arguments(
        samples_dir, runs_json, {'TORCH_PROFILER_START_METHOD': 'spawn'})

def _test_tensorboard_with_arguments(self, test_folder, expected_runs, env=None):
    """Launch tensorboard on `test_folder`, verify its runs, and always clean up.

    Args:
        test_folder: value for tensorboard's --logdir (local path or cloud URL).
        expected_runs: JSON byte string the /runs endpoint is expected to return.
        env: optional dict of extra environment variables, merged over os.environ.
    """
    host = 'localhost'
    port = 6006

    # Initialize before the try block: if Popen (or the env merge) raises,
    # the finally block must not hit a NameError that masks the real failure.
    tb = None
    try:
        if env:
            env_copy = os.environ.copy()
            env_copy.update(env)
            env = env_copy
        tb = Popen(['tensorboard', '--logdir=' + test_folder, '--port=' + str(port)], env=env)
        self._test_tensorboard(host, port, expected_runs)
    finally:
        if tb is not None:
            pid = tb.pid
            tb.terminate()
            print("tensorboard process {} is terminated.".format(pid))

def _test_tensorboard(self, host, port, expected_runs):
link_prefix = 'http://{}:{}/data/plugin/pytorch_profiler/'.format(host, port)
run_link = link_prefix + 'runs'
expected_runs = b'["resnet50_num_workers_0", "resnet50_num_workers_4"]'

expected_links_format=[
link_prefix + 'overview?run={}&worker=worker0&view=Overview',
Expand All @@ -28,51 +52,50 @@ def test_tensorboard_end2end(self):
link_prefix + 'kernel?run={}&worker=worker0&view=Kernel&group_by=Kernel'
]

tb = Popen(['tensorboard', '--logdir='+test_folder, '--port='+str(port)])

timeout = 60
retry_times = 60
while True:
try:
socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect((host, port))
print('tensorboard start successfully')
break
except socket.error:
time.sleep(2)
timeout -= 1
if timeout < 0:
tb.kill()
raise RuntimeError("tensorboard start timeout")
retry_times -= 1
if retry_times < 0:
self.fail("tensorboard start timeout")
continue

timeout = 60
retry_times = 60

while True:
try:
response = urllib.request.urlopen(run_link)
if response.read()==expected_runs:
data = response.read()
if data == expected_runs:
break
if retry_times % 10 == 0:
print("receive mismatched data, retrying", data)
time.sleep(2)
timeout -= 1
if timeout<0:
tb.kill()
raise RuntimeError("Load run timeout")
except Exception:
continue
retry_times -= 1
if retry_times<0:
self.fail("Load run timeout")
except Exception as e:
if retry_times > 0:
continue
else:
print(e)
self.fail("exception happens {}".format(e))

links=[]
for run in json.loads(expected_runs):
for expected_link in expected_links_format:
links.append(expected_link.format(run))

try:
with open('result_check_file.txt', 'r') as f:
lines=f.readlines()
i = 0
for link in links:
response = urllib.request.urlopen(link)
self.assertEqual(response.read(), lines[i].strip().encode(encoding="utf-8"))
i = i + 1
self.assertEqual(i, 10)
finally:
tb.kill()

with open('result_check_file.txt', 'r') as f:
lines=f.readlines()
i = 0
for link in links:
response = urllib.request.urlopen(link)
self.assertEqual(response.read(), lines[i].strip().encode(encoding="utf-8"))
i = i + 1
self.assertEqual(i, 10)
13 changes: 6 additions & 7 deletions tb_plugin/torch_tb_profiler/consts.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# --------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
from collections import namedtuple

# Name under which this plugin registers with TensorBoard (appears in the URL path).
PLUGIN_NAME = "pytorch_profiler"

# File-name suffixes identifying profiler trace files (plain and gzip-compressed).
TRACE_FILE_SUFFIX = ".pt.trace.json"
TRACE_GZIP_FILE_SUFFIX = ".pt.trace.json.gz"
# Verbose (re.X) pattern: group 1 captures the worker name from a trace file
# name such as "worker0.pt.trace.json" or "worker0.pt.trace.json.gz".
WORKER_PATTERN = re.compile(r"""^(.*?) # worker name
# TODO: uncomment the following line when we need supprort multiple steps
# (?:\.\d+)? # optional timestamp like 1619499959628
\.pt\.trace\.json # the ending suffix
(?:\.gz)?$""", re.X) # optional .gz extension

# Presumably the refresh interval (seconds) for the run monitor; the name says
# "INTERNAL", which looks like a typo for "INTERVAL" — confirm before renaming.
MONITOR_RUN_REFRESH_INTERNAL_IN_SECONDS = 10

Expand Down
4 changes: 4 additions & 0 deletions tb_plugin/torch_tb_profiler/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .cache import Cache
from .file import (BaseFileSystem, File, StatData, abspath, basename,
download_file, exists, get_filesystem, glob, isdir, join,
listdir, makedirs, register_filesystem, relpath, walk)
66 changes: 66 additions & 0 deletions tb_plugin/torch_tb_profiler/io/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import multiprocessing as mp
import os

from .. import utils
from .file import File, download_file

logger = utils.get_logger()

class Cache:
    """Multiprocess-shared cache mapping remote file names to local downloaded copies.

    The manager-backed dict/list are shared across loader processes; every
    temporary file recorded here is removed in close().
    """

    def __init__(self):
        self._lock = mp.Lock()
        self._manager = mp.Manager()
        self._cache_dict = self._manager.dict()
        self._tempfiles = self._manager.list()

    def __getstate__(self):
        '''The multiprocessing module can start one of three ways: spawn, fork, or forkserver.
        The default mode is fork in Unix and spawn on Windows and macOS.
        Therefore, the __getstate__ and __setstate__ are used to pickle/unpickle the state in spawn mode.
        '''
        data = self.__dict__.copy()
        # remove the _manager to bypass the following pickle error
        # TypeError: cannot pickle 'weakref' object
        logger.debug("Cache.__getstate__: %s", data)
        del data['_manager']
        return data

    def __setstate__(self, state):
        '''The default logging level in new process is warning. Only warning and error log can be written to
        streams.
        '''
        with utils.mp_logging() as logger:
            logger.debug("Cache.__setstate__ %s", state)
            self.__dict__.update(state)

    def read(self, filename):
        """Return the contents of `filename`, downloading a remote file at most once."""
        local_file = self._cache_dict.get(filename)
        if local_file is None:
            local_file = download_file(filename)
            # skip the cache for local files (download_file appears to return
            # the path unchanged in that case — inferred from this comparison)
            if local_file != filename:
                with self._lock:
                    # Another process may have downloaded the same file while we
                    # did. Keep the first copy; schedule our duplicate for
                    # cleanup so it is not leaked on close().
                    winner = self._cache_dict.get(filename)
                    if winner is None:
                        self._cache_dict[filename] = local_file
                    else:
                        self._tempfiles.append(local_file)
                        local_file = winner

        logger.debug("reading local cache %s for file %s", local_file, filename)
        with File(local_file, 'rb') as f:
            return f.read()

    def add_tempfile(self, filename):
        """Register a temporary file to be deleted when the cache is closed."""
        self._tempfiles.append(filename)

    def close(self):
        """Best-effort removal of all temporary files owned by this cache."""
        for file in self._tempfiles:
            logger.info("remove tempfile %s", file)
            try:
                os.remove(file)
            except OSError:
                pass  # already gone (or not removable); cleanup is best-effort
        for key, value in self._cache_dict.items():
            # key == value means a local file that was never downloaded; skip it
            if key != value:
                logger.info("remove temporary file %s", value)
                try:
                    os.remove(value)
                except OSError:
                    pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # _manager is absent when this instance was unpickled in a spawned
        # child (__getstate__ drops it), so guard before shutting it down.
        manager = self.__dict__.get('_manager')
        if manager is not None:
            manager.__exit__(exc_type, exc_value, traceback)
Loading

0 comments on commit edf8708

Please sign in to comment.