From 481d8cdfa22338068c5c19c33058f6ecb5eca2e7 Mon Sep 17 00:00:00 2001
From: Bin Bao
Date: Thu, 15 Aug 2024 05:43:45 -0700
Subject: [PATCH] Rename use_tp to use_distributed

Summary: use_tp was renamed to use_distributed in
https://github.com/pytorch/torchchat/pull/873, but the PR missed
generate.py, which caused "AttributeError: 'BuilderArgs' object has no
attribute 'use_tp'" when running generate with --profile.
---
 generate.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/generate.py b/generate.py
index 48a77ba29..b30f90813 100644
--- a/generate.py
+++ b/generate.py
@@ -160,9 +160,9 @@ def __init__(
         # global print
         # from tp import maybe_init_dist
         # rank = maybe_init_dist()
-        # use_tp = False
+        # use_distributed = False
         self.rank: Optional[int] = None
-        # if use_tp:
+        # if use_distributed:
         #     if rank != 0:
         #         # only print on rank 0
         #         print = lambda *args, **kwargs: None
@@ -611,7 +611,7 @@ def chat(
         )
         if generator_args.compile:
             if (
-                self.is_speculative and self.builder_args.use_tp
+                self.is_speculative and self.builder_args.use_distributed
             ):  # and ("cuda" in builder_args.device):
                 torch._inductor.config.triton.cudagraph_trees = (
                     False  # Bug with cudagraph trees in this case
@@ -740,7 +740,7 @@ def callback(x, *, done_generating=False):
             )

             if (i != generator_args.num_samples - 1 or not self.profile) or (
-                self.builder_args.use_tp and self.rank != 0
+                self.builder_args.use_distributed and self.rank != 0
             ):
                 import contextlib

@@ -777,7 +777,7 @@ def callback(x, *, done_generating=False):
             )
             compilation_time = time.perf_counter() - t0
             if hasattr(prof, "export_chrome_trace"):
-                if self.builder_args.use_tp:
+                if self.builder_args.use_distributed:
                     prof.export_chrome_trace(f"{self.profile}_rank_{self.rank}.json")
                 else:
                     prof.export_chrome_trace(f"{self.profile}.json")
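
For readers unfamiliar with the failure mode: the rename in
https://github.com/pytorch/torchchat/pull/873 changed the field on
BuilderArgs, so any call site still reading the old name fails at runtime
rather than at import time. A minimal sketch of this, using a hypothetical
simplified stand-in for BuilderArgs (the real class lives in torchchat, not
here):

    # Hypothetical stand-in for torchchat's BuilderArgs, reduced to the
    # renamed field; illustrates the AttributeError quoted in the summary.
    from dataclasses import dataclass

    @dataclass
    class BuilderArgs:
        use_distributed: bool = False  # formerly named use_tp

    args = BuilderArgs()
    print(args.use_distributed)  # fine: prints False
    print(args.use_tp)           # AttributeError: 'BuilderArgs' object has
                                 # no attribute 'use_tp'

Python resolves attribute names at access time, which is why the stale
use_tp references in generate.py only surfaced on the --profile code path
that actually executes them.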