From 39eff9076ac12eed46b15712f1b0ba1ccd5b3676 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 20 Sep 2024 22:48:26 -0700 Subject: [PATCH 1/6] [WIP][Distributed] Add lanes to KV cache --- dist_run.py | 58 ++++++++++++++++++++++++++-------------------- torchchat/model.py | 22 ++++++++++-------- 2 files changed, 45 insertions(+), 35 deletions(-) diff --git a/dist_run.py b/dist_run.py index fc580ea2a..c39e2c275 100644 --- a/dist_run.py +++ b/dist_run.py @@ -273,13 +273,11 @@ def main(args): pp_rank = pp_mesh.get_local_rank() tp_group = tp_mesh.get_group() pp_group = pp_mesh.get_group() - pp_group_size = pp_group.size() - tp_group_size = tp_group.size() - logger.info(f"{pp_group_size=}, {tp_group_size=}") + logger.info(f"{pp_degree=}, {tp_degree=}") # Convenience variables first_pp_rank = 0 - last_pp_rank = pp_group_size - 1 + last_pp_rank = pp_degree - 1 # Assuming same number of GPUs per node device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") @@ -297,18 +295,22 @@ def main(args): if rank == 0: logger.info(f"Model: {model}") - mbs = 1 # number of micro-batches - mb_size = 4 # micro-batch size - batch_size = mbs * mb_size # total batch size - + # Batch size. Since we push batches dynamically through the pipeline rather + # than chunking them, this is effectively micro-batch size in pipeline + # sense. Thus it is interchangeable with micro-batch size below. + batch_size = 4 seqlen_prefill = 1024 # sequence length dim = 4096 # embedding dimension # Setup KV caches (after model distribution) - # TODO: the setting below only works for 1 micro-batch case. To support - # multiple micro-batches, we need the KV cache in the model to be aware of - # the number of micro-batches and the current micro-batch index. - model.setup_caches(mb_size, seqlen_prefill) + # The number of cache lanes is the same as the maximum number of + # micro-batches that can be "in flight" in parallel -- imagine each + # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces. + # When decoding is done for certain micro-batches, we can reuse the KV cache + # lanes. + # TODO: bump up the lane count + cache_lanes = 1 + model.setup_caches(batch_size, seqlen_prefill, cache_lanes=cache_lanes) # Load weights logger.info(f"Loading weights for {pp_rank=} on {device=}") @@ -317,7 +319,7 @@ def main(args): model.to(device) logger.info( - f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for stage {rank}{color.reset}" + f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" ) # info on stage size and params @@ -335,12 +337,12 @@ def main(args): # Helper function to get example inputs and outputs for the stages. 
def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: - mb_ids = torch.randint(0, config.vocab_size, (mb_size, seqlen), device=device) + mb_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), device=device) activation = torch.rand( - mb_size, seqlen, dim, device=device, dtype=model_dtype + batch_size, seqlen, dim, device=device, dtype=model_dtype ) logits = torch.rand( - mb_size, seqlen, config.vocab_size, device=device, dtype=model_dtype + batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype ) example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,) example_outputs = (logits if pp_rank == last_pp_rank else activation,) @@ -358,8 +360,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: output_args=example_outputs, group=pp_group, ) + # Number of micro-batches for the schedule is 1, because each step() call we + # only push 1 micro-batch into the pipeline. But we can continuously push + # new micro-batches into the pipeline as they arrive, achieving same + # pipelining effect. + mbs = 1 # create schedule - prefill_schedule = ScheduleGPipe(prefill_stage, mbs) + prefiller = ScheduleGPipe(prefill_stage, mbs) prompt = [ "What is a computer?", @@ -401,14 +408,15 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: num_tokens = 40 # Prefill phase - # Run context input through pipeline, in 1 step + # Run context input through pipeline + # TODO: we need to pass `input_pos` and `cache_lane` to each stage. with torch.no_grad(): if pp_rank == first_pp_rank: - output = prefill_schedule.step(padded_sequence) + output = prefiller.step(padded_sequence) elif pp_rank == last_pp_rank: - output = prefill_schedule.step() + output = prefiller.step() else: # middle pp ranks - prefill_schedule.step() + prefiller.step() # Decode the output -- first generated token if pp_rank == last_pp_rank: @@ -445,7 +453,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: group=pp_group, ) # create schedule - decode_schedule = ScheduleGPipe(decode_stage, mbs) + decorder = ScheduleGPipe(decode_stage, mbs) # Decoding with torch.no_grad(): @@ -467,11 +475,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: # Run data through pipeline if pp_rank == first_pp_rank: - output = decode_schedule.step(new_token) + output = decorder.step(new_token) elif pp_rank == last_pp_rank: - output = decode_schedule.step() + output = decorder.step() else: # middle pp ranks - decode_schedule.step() + decorder.step() # Decode the output if pp_rank == last_pp_rank: diff --git a/torchchat/model.py b/torchchat/model.py index aaa72cb2a..a641f6116 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -606,7 +606,7 @@ def __init__(self, config: TransformerArgs) -> None: self.max_batch_size = -1 self.max_seq_length = -1 - def setup_caches(self, max_batch_size, max_seq_length): + def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1): if ( self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size @@ -620,7 +620,7 @@ def setup_caches(self, max_batch_size, max_seq_length): # parallelism may have been applied there and the `n_local_heads`` # value being adjusted. 
b.attention.setup_cache( - max_batch_size, max_seq_length, + max_batch_size, max_seq_length, cache_lanes=cache_lanes ) freqs_cis = precompute_freqs_cis( @@ -658,7 +658,7 @@ def distribute(self, device_mesh: DeviceMesh): def setup_input_pos(self, input_pos: Tensor) -> None: self._input_pos = input_pos - def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 1) -> Tensor: assert self.freqs_cis is not None, "Caches must be initialized first" # TODO: find a better way to pass input_pos to non-0 pipeline stages input_pos = input_pos if input_pos is not None else self._input_pos @@ -668,7 +668,7 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: x = self.tok_embeddings(x) for _, layer in self.layers.items(): - x = layer(x, input_pos, freqs_cis, mask) + x = layer(x, input_pos, freqs_cis, mask, cache_lane=cache_lane) if self.norm: x = self.norm(x) @@ -691,7 +691,7 @@ def distribute(self, device_mesh: DeviceMesh): self.feed_forward.distribute(device_mesh) def forward( - self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor + self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, cache_lane: int = 0 ) -> Tensor: h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) out = h + self.feed_forward(self.ffn_norm(h)) @@ -723,15 +723,16 @@ def __init__(self, config: TransformerArgs): self.dim = config.dim self._register_load_state_dict_pre_hook(self.load_hook) - def setup_cache(self, max_batch_size, max_seq_length): + def setup_cache(self, max_batch_size, max_seq_length, cache_lanes: int = 1): n_local_heads = self.n_local_heads # If TP is enabled, the heads would be divided and assigned to different ranks if hasattr(self, "tp_degree"): n_local_heads = self.n_local_heads // self.tp_degree - self.kv_cache = KVCache( - max_batch_size, max_seq_length, n_local_heads, self.head_dim - ) + self.kv_cache = nn.ModuleList([ + KVCache(max_batch_size, max_seq_length, n_local_heads, self.head_dim) + for _ in range(cache_lanes) + ]) def load_hook(self, state_dict, prefix, *args): # if prefix + "wq.weight" in state_dict: @@ -784,6 +785,7 @@ def forward( freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, + cache_lane: int = 0, ) -> Tensor: bsz, seqlen, _ = x.shape @@ -809,7 +811,7 @@ def forward( q, k, v = (x.transpose(1, 2) for x in (q, k, v)) if self.kv_cache is not None: - k, v = self.kv_cache.update(input_pos, k, v) + k, v = self.kv_cache[cache_lane].update(input_pos, k, v) k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1) v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1) From 3ba19ea3c7da686289f3e29064e2fcea4f1f1cfd Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Sat, 21 Sep 2024 00:10:06 -0700 Subject: [PATCH 2/6] Compatibility change --- torchchat/export.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchchat/export.py b/torchchat/export.py index affb8b871..263c3815a 100644 --- a/torchchat/export.py +++ b/torchchat/export.py @@ -152,9 +152,9 @@ def __init__(self, attention: Attention): self.wo = attention.wo max_batch_size, n_heads, max_seq_length, head_dim = ( - attention.kv_cache.k_cache.shape + attention.kv_cache[0].k_cache.shape ) - cache_dtype = attention.kv_cache.k_cache.dtype + cache_dtype = attention.kv_cache[0].k_cache.dtype self.kv_cache = CustomKVCache( max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype ) From 
2bec61c4fd729daa165f98dd38a4a96f454f6979 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Sat, 21 Sep 2024 00:14:34 -0700 Subject: [PATCH 3/6] Naming --- dist_run.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dist_run.py b/dist_run.py index c39e2c275..44bf161bc 100644 --- a/dist_run.py +++ b/dist_run.py @@ -309,8 +309,8 @@ def main(args): # When decoding is done for certain micro-batches, we can reuse the KV cache # lanes. # TODO: bump up the lane count - cache_lanes = 1 - model.setup_caches(batch_size, seqlen_prefill, cache_lanes=cache_lanes) + pipeline_lanes = 1 + model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes) # Load weights logger.info(f"Loading weights for {pp_rank=} on {device=}") From 4ecb9513d5e8ed5cd5cf7f502022066d21baea90 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Sat, 21 Sep 2024 21:55:03 -0700 Subject: [PATCH 4/6] Remove setup_input_pos --- dist_run.py | 18 +++++++++--------- torchchat/model.py | 7 ------- 2 files changed, 9 insertions(+), 16 deletions(-) diff --git a/dist_run.py b/dist_run.py index 44bf161bc..c211dd343 100644 --- a/dist_run.py +++ b/dist_run.py @@ -332,7 +332,6 @@ def main(args): # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen input_pos = torch.arange(seqlen_prefill, device=device) - model.setup_input_pos(input_pos) model.eval() # Helper function to get example inputs and outputs for the stages. @@ -410,13 +409,15 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: # Prefill phase # Run context input through pipeline # TODO: we need to pass `input_pos` and `cache_lane` to each stage. + lane = 0 + kwargs = {"input_pos": input_pos, "cache_lane": lane} with torch.no_grad(): if pp_rank == first_pp_rank: - output = prefiller.step(padded_sequence) + output = prefiller.step(padded_sequence, **kwargs) elif pp_rank == last_pp_rank: - output = prefiller.step() + output = prefiller.step(**kwargs) else: # middle pp ranks - prefiller.step() + prefiller.step(**kwargs) # Decode the output -- first generated token if pp_rank == last_pp_rank: @@ -438,7 +439,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: # seqlen = 1 now seqlen_decode = 1 input_pos = torch.tensor([prompt_lengths[0]], device=device) - model.setup_input_pos(input_pos) # Create decode stage logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}") @@ -458,6 +458,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: # Decoding with torch.no_grad(): for step in range(num_tokens - 1): + kwargs = {"input_pos": input_pos, "cache_lane": lane} # sendrecv between last and first ranks, only if: # first_pp_rank != last_pp_rank. 
if pp_rank == last_pp_rank and pp_rank != first_pp_rank: @@ -475,11 +476,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: # Run data through pipeline if pp_rank == first_pp_rank: - output = decorder.step(new_token) + output = decorder.step(new_token, **kwargs) elif pp_rank == last_pp_rank: - output = decorder.step() + output = decorder.step(**kwargs) else: # middle pp ranks - decorder.step() + decorder.step(**kwargs) # Decode the output if pp_rank == last_pp_rank: @@ -499,7 +500,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: ) # decode_results[i][0] input_pos += 1 - model.setup_input_pos(input_pos) # Display the decoding results diff --git a/torchchat/model.py b/torchchat/model.py index a641f6116..228b97c3d 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -653,15 +653,8 @@ def distribute(self, device_mesh: DeviceMesh): ColwiseParallel(output_layouts=Replicate()), ) - # This is a temporary solution to pass input_pos to non-0 pipeline stages - # TODO: make `step()` function of dist.pipelining accept args for non-0 stages - def setup_input_pos(self, input_pos: Tensor) -> None: - self._input_pos = input_pos - def forward(self, x: Tensor, input_pos: Optional[Tensor] = None, cache_lane: int = 1) -> Tensor: assert self.freqs_cis is not None, "Caches must be initialized first" - # TODO: find a better way to pass input_pos to non-0 pipeline stages - input_pos = input_pos if input_pos is not None else self._input_pos mask = self.causal_mask[None, None, input_pos] freqs_cis = self.freqs_cis[input_pos] if self.tok_embeddings: From 5951e393b97e919654d99e89a207b9f6076c6032 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Sun, 22 Sep 2024 23:25:25 -0700 Subject: [PATCH 5/6] Add timer --- dist_run.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/dist_run.py b/dist_run.py index c211dd343..502ad2260 100644 --- a/dist_run.py +++ b/dist_run.py @@ -394,7 +394,6 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: s = set(prompt_lengths) assert len(s) == 1, f"prompt_lengths should be the same, got {s}" - # with CUDATrackTime() as timer: # Need these global ids due to the API definition of dist.send and recv first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank) last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank) @@ -411,7 +410,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: # TODO: we need to pass `input_pos` and `cache_lane` to each stage. 
lane = 0 kwargs = {"input_pos": input_pos, "cache_lane": lane} - with torch.no_grad(): + with torch.no_grad(), CUDATrackTime() as timer: if pp_rank == first_pp_rank: output = prefiller.step(padded_sequence, **kwargs) elif pp_rank == last_pp_rank: @@ -419,6 +418,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: else: # middle pp ranks prefiller.step(**kwargs) + logger.info( + f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" + ) + # Decode the output -- first generated token if pp_rank == last_pp_rank: decode_results = _batch_decode_next_tokens( @@ -456,7 +459,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: decorder = ScheduleGPipe(decode_stage, mbs) # Decoding - with torch.no_grad(): + with torch.no_grad(), CUDATrackTime() as timer: for step in range(num_tokens - 1): kwargs = {"input_pos": input_pos, "cache_lane": lane} # sendrecv between last and first ranks, only if: @@ -501,6 +504,10 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: input_pos += 1 + logger.info( + f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" + ) + # Display the decoding results # output formatted response via last pp group and tp rank 0 From 9514b5404e33337a44da830795004703f148cf0a Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Sun, 22 Sep 2024 23:51:55 -0700 Subject: [PATCH 6/6] Remove mbs --- dist_run.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dist_run.py b/dist_run.py index 502ad2260..3fbb857c7 100644 --- a/dist_run.py +++ b/dist_run.py @@ -359,13 +359,13 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: output_args=example_outputs, group=pp_group, ) + + # Create schedule # Number of micro-batches for the schedule is 1, because each step() call we # only push 1 micro-batch into the pipeline. But we can continuously push # new micro-batches into the pipeline as they arrive, achieving same # pipelining effect. - mbs = 1 - # create schedule - prefiller = ScheduleGPipe(prefill_stage, mbs) + prefiller = ScheduleGPipe(prefill_stage, 1) prompt = [ "What is a computer?", @@ -456,7 +456,7 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: group=pp_group, ) # create schedule - decorder = ScheduleGPipe(decode_stage, mbs) + decorder = ScheduleGPipe(decode_stage, 1) # Decoding with torch.no_grad(), CUDATrackTime() as timer:
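
Taken together, the series adds one mechanism: instead of a single KVCache per attention module, setup_cache builds an nn.ModuleList with one KVCache per pipeline lane, and the forward path selects a lane via a cache_lane index that dist_run.py threads through prefiller.step(...) and decorder.step(...) as a kwarg. The snippet below is a minimal, self-contained sketch of that plumbing only; SimpleKVCache and LanedCache are illustrative stand-ins, not torchchat's real KVCache or Attention classes, and the attention math is omitted.

    # Minimal sketch (not torchchat's real classes): one KV cache per pipeline
    # lane, selected by `cache_lane`, mirroring the nn.ModuleList in patch 1.
    import torch
    import torch.nn as nn


    class SimpleKVCache(nn.Module):
        """Stand-in for torchchat's KVCache: fixed-size k/v buffers written at input_pos."""

        def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.float16):
            super().__init__()
            shape = (max_batch_size, n_heads, max_seq_length, head_dim)
            self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
            self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

        def update(self, input_pos, k_val, v_val):
            # k_val/v_val: (bsz, n_heads, seqlen, head_dim); input_pos: (seqlen,)
            self.k_cache[:, :, input_pos] = k_val
            self.v_cache[:, :, input_pos] = v_val
            return self.k_cache, self.v_cache


    class LanedCache(nn.Module):
        """Cache plumbing only: each in-flight micro-batch owns one lane."""

        def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, cache_lanes=1):
            super().__init__()
            self.kv_cache = nn.ModuleList(
                SimpleKVCache(max_batch_size, max_seq_length, n_heads, head_dim)
                for _ in range(cache_lanes)
            )

        def forward(self, input_pos, k, v, cache_lane=0):
            # Pick the lane for the current micro-batch, then read/write as usual.
            return self.kv_cache[cache_lane].update(input_pos, k, v)


    # Two micro-batches in flight, each writing to its own lane:
    cache = LanedCache(max_batch_size=4, max_seq_length=64, n_heads=8, head_dim=16, cache_lanes=2)
    k = v = torch.randn(4, 8, 10, 16, dtype=torch.float16)
    cache(torch.arange(10), k, v, cache_lane=0)  # micro-batch 0
    cache(torch.arange(10), k, v, cache_lane=1)  # micro-batch 1, no clobbering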
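
The patch pins pipeline_lanes = 1 and leaves the "bump up the lane count" TODO open; its comment notes that once decoding finishes for a micro-batch, that micro-batch's lane can be reused. One hypothetical way to track this on the schedule-driving side -- not part of the patch, purely an illustration of the reuse idea -- is a small free list of lane indices:

    # Hypothetical lane bookkeeping (not in the patch): hand out free KV-cache
    # lanes to arriving micro-batches and recycle them when decoding completes.
    from collections import deque


    class LanePool:
        def __init__(self, cache_lanes: int):
            self.free = deque(range(cache_lanes))

        def acquire(self) -> int:
            if not self.free:
                raise RuntimeError("all KV cache lanes are occupied by in-flight micro-batches")
            return self.free.popleft()

        def release(self, lane: int) -> None:
            self.free.append(lane)


    pool = LanePool(cache_lanes=2)
    lane_a = pool.acquire()   # first micro-batch enters the pipeline -> lane 0
    lane_b = pool.acquire()   # second micro-batch -> lane 1
    pool.release(lane_a)      # micro-batch A finished decoding; its lane is free again
    lane_c = pool.acquire()   # a newly arriving micro-batch takes over lane 0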
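
Patch 5 wraps the prefill step and the decode loop in CUDATrackTime and logs timer.get_time() with timer.unit per rank. That helper comes from torchchat's distributed utilities and its implementation is not shown in this series; the sketch below is a hypothetical CUDA-event timer with the same get_time()/unit surface, assuming a CUDA device is available, only to illustrate the pattern being used.

    # Hypothetical stand-in for CUDATrackTime (not its real implementation):
    # time a GPU region with CUDA events inside a context manager.
    import torch


    class CudaEventTimer:
        unit = "ms"

        def __enter__(self):
            self.start = torch.cuda.Event(enable_timing=True)
            self.end = torch.cuda.Event(enable_timing=True)
            self.start.record()
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.end.record()
            self.end.synchronize()  # wait for the work recorded before the end event
            self.elapsed_ms = self.start.elapsed_time(self.end)
            return False

        def get_time(self):
            return round(self.elapsed_ms, 2)


    # Usage mirrors the patch: time a region, then log per rank, e.g.
    # with torch.no_grad(), CudaEventTimer() as timer:
    #     ...  # prefiller.step(...) or the decode loop
    # print(f"Prefilling time: {timer.get_time()} {timer.unit}")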