Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Profiler][Memory] Export raw timestamped events in export_memory_timeline_raw #105094

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
38 changes: 38 additions & 0 deletions torch/profiler/_memory_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class Action(enum.Enum):
INCREMENT_VERSION = enum.auto()
DESTROY = enum.auto()

_ACTION_TO_INDEX = {i: i.value for i in Action}

@dataclasses.dataclass(eq=True, unsafe_hash=False, frozen=True)
class Key:
Expand Down Expand Up @@ -1046,6 +1047,43 @@ def export_memory_timeline(self, path, device) -> None:
with open(path, 'w') as f:
json.dump([times, sizes], f)

def export_memory_timeline_raw(self, path, device_str) -> None:
"""Saves the memory timeline as raw memory event tuples in the
form of (timestamp, action, numbytes, category)
as a JSON formatted file to the given path for the given
device."""
device = torch.device(device_str)
raw_events: List[Tuple[int, int, int, int]] = []

def get_category_index(key, version):
category = (
self.categories.get(key, version)
if isinstance(key, TensorKey)
else None
)
return _CATEGORY_TO_INDEX[category]

for t, action, (key, version), numbytes in self.timeline:
if key.device != device:
continue

if action in (Action.PREEXISTING, Action.CREATE):
raw_events.append((t, _ACTION_TO_INDEX[action], numbytes, get_category_index(key, version)))

elif action == Action.INCREMENT_VERSION:
raw_events.append((t, _ACTION_TO_INDEX[action], -numbytes, get_category_index(key, version)))
raw_events.append((t, _ACTION_TO_INDEX[action], numbytes, get_category_index(key, version + 1)))

elif action == Action.DESTROY:
raw_events.append((t, _ACTION_TO_INDEX[action], -numbytes, get_category_index(key, version)))

else:
raise ValueError(f"Unknown action: {action}")

import json
with open(path, 'w') as f:
json.dump(raw_events, f)

def export_memory_timeline_html(self, path, device, figsize=(20, 12), title=None) -> None:
"""Exports the memory timeline as an HTML file which contains
the memory timeline plot embedded as a PNG file."""
Expand Down
5 changes: 4 additions & 1 deletion torch/profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ def export_memory_timeline(self, path: str, device: Optional[str] = None) -> Non
elif path.endswith('.gz'):
fp = tempfile.NamedTemporaryFile('w+t', suffix='.json', delete=False)
fp.close()
self.mem_tl.export_memory_timeline(fp.name, device)
if path.endswith('raw.json.gz'):
self.mem_tl.export_memory_timeline_raw(fp.name, device)
else:
self.mem_tl.export_memory_timeline(fp.name, device)
with open(fp.name) as fin:
with gzip.open(path, 'wt') as fout:
fout.writelines(fin)
Expand Down