38 changes: 27 additions & 11 deletions run_sweep.py
@@ -17,33 +17,49 @@
from typing import List, Optional, Dict, Any, Tuple
from torchbenchmark import ModelTask

-WARMUP_ROUNDS = 3
+WARMUP_ROUNDS = 11
Contributor
Let's make this a tunable argument of run_sweep.py.

WORKER_TIMEOUT = 600 # seconds
MODEL_DIR = ['torchbenchmark', 'models']
NANOSECONDS_PER_MILLISECONDS = 1_000_000.0

-def run_one_step(func, device: str, nwarmup=WARMUP_ROUNDS, num_iter=10) -> Tuple[float, Optional[Tuple[torch.Tensor]]]:
+def run_one_step(func, device: str, nwarmup=WARMUP_ROUNDS, num_iter=20) -> Tuple[float, Optional[Tuple[torch.Tensor]]]:
Contributor
And this one as well; add it as a command-line option.
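A minimal sketch of how both suggested knobs could be exposed via argparse (the flag names --warmup-rounds and --num-iter are hypothetical, not part of this PR):

import argparse

# Hypothetical sketch only: expose the warmup and iteration counts as
# command-line options of run_sweep.py instead of hard-coded constants.
parser = argparse.ArgumentParser(description="run_sweep.py benchmark sweep")
parser.add_argument("--warmup-rounds", type=int, default=11,
                    help="number of warm-up iterations before timing")
parser.add_argument("--num-iter", type=int, default=20,
                    help="number of timed iterations per model")
args = parser.parse_args()

# ... later, when a model is benchmarked:
# run_one_step(func, device, nwarmup=args.warmup_rounds, num_iter=args.num_iter)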

"Run one step of the model, and return the latency in milliseconds."
# Warm-up `nwarmup` rounds
for _i in range(nwarmup):
func()
result_summary = []
events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iter)]
end_event = torch.cuda.Event(enable_timing=True)
for _i in range(num_iter):
if device == "cuda":
torch.cuda.synchronize()
# Collect time_ns() instead of time() which does not provide better precision than 1
# second according to https://docs.python.org/3/library/time.html#time.time.
t0 = time.time_ns()
func()
torch.cuda.synchronize() # Wait for the events to be recorded!
t1 = time.time_ns()
events[_i].record()
else:
t0 = time.time_ns()
func()
t1 = time.time_ns()
result_summary.append((t1 - t0) / NANOSECONDS_PER_MILLISECONDS)
wall_latency = numpy.median(result_summary)
return wall_latency
result_summary.append((t1 - t0) / NANOSECONDS_PER_MILLISECONDS)

if device != 'cuda':
# return wall_latency
wall_latency_cpu = numpy.median(result_summary)
return wall_latency_cpu
else:
end_event.record()
torch.cuda.synchronize()

times = []
for event in events:
times.append(event.elapsed_time(end_event))

linreg = torch.linalg.lstsq(
torch.vstack([torch.arange(num_iter), torch.ones(num_iter)]).T,
torch.tensor(times))
# print(linreg, flush=True)
step_time_ms = -linreg[0][0].item()

return step_time_ms
Contributor
Does this measure the function wall time or the GPU event time? If it is not wall time, we should create another metric, "gpu_time", in the "results" dict.
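For context, CUDA event timings measure GPU execution between the recorded events rather than host wall time. A self-contained sketch of the events-plus-least-squares timing used in the hunk above (the matmul workload and variable names are illustrative stand-ins, not the PR's code):

import torch

# Illustrative sketch only: measure per-iteration GPU time with CUDA events
# and recover the step time from the slope of a least-squares line fit.
num_iter = 20
x = torch.randn(1024, 1024, device="cuda")

events = [torch.cuda.Event(enable_timing=True) for _ in range(num_iter)]
end_event = torch.cuda.Event(enable_timing=True)

for i in range(num_iter):
    y = x @ x                     # stand-in for func()
    events[i].record()
end_event.record()
torch.cuda.synchronize()          # wait for all recorded events to complete

# events[i].elapsed_time(end_event) decreases roughly linearly with i,
# so the negated slope of the fitted line is the per-step GPU time in ms.
times = torch.tensor([e.elapsed_time(end_event) for e in events])
design = torch.vstack([torch.arange(num_iter, dtype=torch.float32),
                       torch.ones(num_iter)]).T
solution = torch.linalg.lstsq(design, times.unsqueeze(1)).solution  # (2, 1)
step_time_ms = -solution[0].item()
print(f"estimated step time: {step_time_ms:.3f} ms")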



@dataclasses.dataclass
class ModelTestResult:
1 change: 0 additions & 1 deletion torchbenchmark/__init__.py
@@ -336,7 +336,6 @@ def set_train(self) -> None:
    def invoke(self) -> None:
        self.worker.run("""
            model.invoke()
-            maybe_sync()
Contributor
Sorry I don't understand why you are removing this sync function?

Contributor Author
The idea is to remove cuda sync at every step.

Contributor
@xuzhao9 Jun 29, 2022
Sorry, I am wondering why we would remove the CUDA sync at every step. Running a CUDA sync at the end of every test in the subprocess is currently an assumption of our measurement, because we found that upstream models may (like detectron2) or may not (like torchvision) add a CUDA sync in their own code. So even if you remove the CUDA sync here, the model code may still sync on its own.
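For reference, a helper along these lines is presumably what maybe_sync does in the worker environment (a minimal sketch under that assumption, not the actual implementation):

import torch

def maybe_sync(device: str = "cuda") -> None:
    # Block the host until all queued CUDA work has finished, so timings
    # taken outside the worker reflect completed GPU execution.
    # No-op for CPU-only runs. Sketch only; the real helper may differ.
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize()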

""")

def set_eval(self) -> None: