diff --git a/generate.py b/generate.py index 48a77ba29..b30f90813 100644 --- a/generate.py +++ b/generate.py @@ -160,9 +160,9 @@ def __init__( # global print # from tp import maybe_init_dist # rank = maybe_init_dist() - # use_tp = False + # use_distributed = False self.rank: Optional[int] = None - # if use_tp: + # if use_distributed: # if rank != 0: # # only print on rank 0 # print = lambda *args, **kwargs: None @@ -611,7 +611,7 @@ def chat( ) if generator_args.compile: if ( - self.is_speculative and self.builder_args.use_tp + self.is_speculative and self.builder_args.use_distributed ): # and ("cuda" in builder_args.device): torch._inductor.config.triton.cudagraph_trees = ( False # Bug with cudagraph trees in this case @@ -740,7 +740,7 @@ def callback(x, *, done_generating=False): ) if (i != generator_args.num_samples - 1 or not self.profile) or ( - self.builder_args.use_tp and self.rank != 0 + self.builder_args.use_distributed and self.rank != 0 ): import contextlib @@ -777,7 +777,7 @@ def callback(x, *, done_generating=False): ) compilation_time = time.perf_counter() - t0 if hasattr(prof, "export_chrome_trace"): - if self.builder_args.use_tp: + if self.builder_args.use_distributed: prof.export_chrome_trace(f"{self.profile}_rank_{self.rank}.json") else: prof.export_chrome_trace(f"{self.profile}.json")