Skip to content

Commit

Permalink
[Profiler][Memory] Export raw timestamped events in export_memory_tim…
Browse files Browse the repository at this point in the history
…eline_raw

Summary:
Rather than processing the events into a times-and-sizes plot, dump the raw events as (timestamp, action, number of bytes, category) tuples when the output file name ends in `raw.json.gz`.

This allows downstream analysis tools to process these events directly. It also avoids having to control the granularity of the existing json.gz export in the memory profiler.

Test Plan: CI Tests

Differential Revision: D47416544

Pulled By: aaronenyeshi

fbshipit-source-id: fc183a2c23d0e19e4c2bc55c3ccdd50fdef536a9
  • Loading branch information
aaronenyeshi authored and facebook-github-bot committed Jul 12, 2023
1 parent 15c67ca commit 22aa211
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
38 changes: 38 additions & 0 deletions torch/profiler/_memory_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class Action(enum.Enum):
INCREMENT_VERSION = enum.auto()
DESTROY = enum.auto()

# Maps each Action member to its enum value so actions can be written out
# as plain integers (e.g. in the raw JSON timeline export).
_ACTION_TO_INDEX = {action: action.value for action in Action}

@dataclasses.dataclass(eq=True, unsafe_hash=False, frozen=True)
class Key:
Expand Down Expand Up @@ -1046,6 +1047,43 @@ def export_memory_timeline(self, path, device) -> None:
with open(path, 'w') as f:
json.dump([times, sizes], f)

def export_memory_timeline_raw(self, path, device_str) -> None:
    """Write the raw memory events for one device to ``path`` as JSON.

    Each entry of ``self.timeline`` whose key lives on the requested
    device is turned into one or two records of the form
    ``(timestamp, action_index, numbytes, category_index)``:

    * ``PREEXISTING`` / ``CREATE``: one record with ``+numbytes``.
    * ``DESTROY``: one record with ``-numbytes``.
    * ``INCREMENT_VERSION``: two records sharing the same timestamp,
      ``-numbytes`` under the category of ``version`` followed by
      ``+numbytes`` under the category of ``version + 1``.

    ``action_index`` comes from ``_ACTION_TO_INDEX`` and
    ``category_index`` from ``_CATEGORY_TO_INDEX``; keys that are not a
    ``TensorKey`` are looked up under the ``None`` category.

    Args:
        path: Destination file; it is opened in text write mode.
        device_str: Device description passed to ``torch.device``;
            events whose key is on another device are skipped.

    Raises:
        ValueError: If a timeline entry carries an action other than
            the four handled above.
    """
    import json

    target_device = torch.device(device_str)
    records: List[Tuple[int, int, int, int]] = []

    def category_index_of(key, version):
        # Only tensor keys have a category; everything else maps to None.
        if isinstance(key, TensorKey):
            category = self.categories.get(key, version)
        else:
            category = None
        return _CATEGORY_TO_INDEX[category]

    def emit(timestamp, action, key, version, delta):
        # Append one flattened record for the given signed byte delta.
        records.append(
            (
                timestamp,
                _ACTION_TO_INDEX[action],
                delta,
                category_index_of(key, version),
            )
        )

    for timestamp, action, (key, version), numbytes in self.timeline:
        if key.device != target_device:
            continue

        if action in (Action.PREEXISTING, Action.CREATE):
            emit(timestamp, action, key, version, numbytes)
        elif action == Action.INCREMENT_VERSION:
            # A version bump is recorded as releasing the bytes under the
            # old version's category and re-adding them under the new one.
            emit(timestamp, action, key, version, -numbytes)
            emit(timestamp, action, key, version + 1, numbytes)
        elif action == Action.DESTROY:
            emit(timestamp, action, key, version, -numbytes)
        else:
            raise ValueError(f"Unknown action: {action}")

    with open(path, 'w') as f:
        json.dump(records, f)

def export_memory_timeline_html(self, path, device, figsize=(20, 12), title=None) -> None:
"""Exports the memory timeline as an HTML file which contains
the memory timeline plot embedded as a PNG file."""
Expand Down
5 changes: 4 additions & 1 deletion torch/profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ def export_memory_timeline(self, path: str, device: Optional[str] = None) -> Non
elif path.endswith('.gz'):
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
fp.close()
self.mem_tl.export_memory_timeline(fp.name, device)
if path.endswith('raw.json.gz'):
self.mem_tl.export_memory_timeline_raw(fp.name, device)
else:
self.mem_tl.export_memory_timeline(fp.name, device)
with open(fp.name) as fin:
with gzip.open(path, 'wt') as fout:
fout.writelines(fin)
Expand Down

0 comments on commit 22aa211

Please sign in to comment.