Skip to content

Commit

Permalink
Add HDFS support to torch_tb_profiler (#793)
Browse files Browse the repository at this point in the history
Summary:
This is to complete #766

**what**
Extend torch_tb_profiler with Hadoop file system support, allowing TensorBoard to read PyTorch profiling results stored on HDFS. The implementation leverages `fsspec` (and pyarrow under the hood) to interact with HDFS. It works with various HDFS and Hadoop setups as long as `HADOOP_HOME` and the Hadoop lib & bin directories are correctly configured.

**testing done**
Tested with HDFS installed on my local Linux box, and also against a deployed remote Hadoop cluster.

Pull Request resolved: #793

Reviewed By: chaekit

Differential Revision: D48039682

Pulled By: aaronenyeshi

fbshipit-source-id: 8eb80f85c887934bd023d5dce96cb80358254a98
  • Loading branch information
yundai424 authored and facebook-github-bot committed Aug 4, 2023
1 parent 170d45a commit 8d4234c
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 1 deletion.
6 changes: 6 additions & 0 deletions tb_plugin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ and give optimization recommendations.

`pip install torch-tb-profiler`

To install with the S3, Azure Blob, GCS, or HDFS extension, run `pip install torch-tb-profiler[<extra>]`, where `<extra>` is one of `s3`, `blob`, `gs`, or `hdfs` — for example, `pip install torch-tb-profiler[s3]`.

* Or you can install from source

Clone the git repository:
Expand Down Expand Up @@ -93,6 +95,10 @@ and give optimization recommendations.
* Google Cloud (GS://)

Install `google-cloud-storage`.

* HDFS (hdfs://)

Install `fsspec` and `pyarrow`. Optionally set environment variable `HADOOP_HOME`.

---
> **_NOTES:_** For AWS S3, Google Cloud and Azure Blob, the trace files need to be put on a top level folder under bucket/container.
Expand Down
3 changes: 2 additions & 1 deletion tb_plugin/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def get_version(rel_path):
# Optional storage-backend extras, installable as
# `pip install torch-tb-profiler[<name>]`.
# NOTE: the scraped diff retained both the pre-change `"gs"` line (missing
# its trailing comma) and the post-change one, producing a duplicate key and
# a syntax error; this is the intended post-diff state with one `"gs"` entry.
EXTRAS = {
    "s3": ["boto3"],
    "blob": ["azure-storage-blob"],
    "gs": ["google-cloud-storage"],
    "hdfs": ["fsspec", "pyarrow"]
}


Expand Down
10 changes: 10 additions & 0 deletions tb_plugin/torch_tb_profiler/io/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@
except ImportError:
GS_ENABLED = False

try:
# Imports the HDFS library
from fsspec.implementations.arrow import HadoopFileSystem
HDFS_ENABLED = True
except ImportError:
HDFS_ENABLED = False

_DEFAULT_BLOCK_SIZE = 16 * 1024 * 1024

Expand Down Expand Up @@ -352,6 +358,10 @@ def stat(self, filename):
from .gs import GoogleBlobSystem
register_filesystem("gs", GoogleBlobSystem())

# Register the fsspec-backed HDFS adapter under the "hdfs" scheme so that
# hdfs:// paths resolve through the plugin's filesystem registry.
# NOTE(review): `from .hdfs import HadoopFileSystem` rebinds the
# `HadoopFileSystem` name imported from fsspec earlier in this module —
# harmless if that import is only an availability probe, but worth confirming.
if HDFS_ENABLED:
    from .hdfs import HadoopFileSystem
    register_filesystem("hdfs", HadoopFileSystem())


class File(object):
def __init__(self, filename, mode):
Expand Down
68 changes: 68 additions & 0 deletions tb_plugin/torch_tb_profiler/io/hdfs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os

import fsspec
from fsspec.implementations import arrow

from .. import utils
from .base import BaseFileSystem, RemotePath, StatData
from .utils import as_bytes, as_text, parse_blob_url

logger = utils.get_logger()

class HadoopFileSystem(RemotePath, BaseFileSystem):
    """Filesystem adapter that serves hdfs:// paths through fsspec (pyarrow)."""

    def __init__(self) -> None:
        super().__init__()

    def get_fs(self) -> arrow.HadoopFileSystem:
        # fsspec caches filesystem instances, so repeated lookups are cheap.
        return fsspec.filesystem("hdfs")

    def exists(self, filename):
        """Return True if *filename* exists on HDFS."""
        return self.get_fs().exists(filename)

    def read(self, filename, binary_mode=False, size=None, continue_from=None):
        """Read up to *size* units from *filename*.

        When *continue_from* carries an "opaque_offset", reading resumes
        from that stream position. Returns ``(data, continuation_token)``
        where the token records the position for a follow-up call.
        """
        fs = self.get_fs()
        if binary_mode:
            mode, encoding = "rb", None
        else:
            mode, encoding = "r", "utf8"
        start = None
        if continue_from is not None:
            start = continue_from.get("opaque_offset", None)
        with fs.open(path=filename, mode=mode, encoding=encoding) as stream:
            if start is not None:
                stream.seek(start)
            payload = stream.read(size)
            token = {"opaque_offset": stream.tell()}
            return (payload, token)

    def write(self, filename, file_content, binary_mode=False):
        """Write *file_content* to *filename* as raw bytes or utf8 text."""
        fs = self.get_fs()
        if not binary_mode:
            fs.write_text(filename, as_text(file_content), encoding="utf8")
        else:
            fs.write_bytes(filename, as_bytes(file_content))

    def glob(self, filename):
        """Expand a glob pattern into the matching HDFS paths."""
        return self.get_fs().glob(filename)

    def isdir(self, dirname):
        """Return True if *dirname* is a directory."""
        return self.get_fs().isdir(dirname)

    def listdir(self, dirname):
        """List the entries of *dirname* as names relative to it."""
        fs = self.get_fs()
        # The paths pyarrow's listdir returns carry no protocol prefix, so
        # strip the protocol from the root before computing relative names.
        base = fs._strip_protocol(dirname)
        return [os.path.relpath(entry, base)
                for entry in fs.listdir(dirname, detail=False)]

    def makedirs(self, path):
        """Create *path* (and any parents), tolerating existing directories."""
        return self.get_fs().makedirs(path, exist_ok=True)

    def stat(self, filename):
        """Return a StatData describing *filename* (size only)."""
        info = self.get_fs().stat(filename)
        return StatData(info['size'])

    def support_append(self):
        # Appending is not supported through this adapter.
        return False

    def download_file(self, file_to_download, file_to_save):
        """Copy a remote file (or tree) down to the local path."""
        return self.get_fs().download(
            file_to_download, file_to_save, recursive=True)

0 comments on commit 8d4234c

Please sign in to comment.