Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Conversation

@kwen2501
Copy link
Contributor

@kwen2501 kwen2501 commented Sep 21, 2024

KV cache is extended to have multiple lanes, each letting a separate batch pass through, achieving pipeline parallelism.

# The number of cache lanes is the same as the maximum number of
# micro-batches that can be "in flight" in parallel -- imagine each
# micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
# When decoding is done for certain micro-batches, we can reuse the KV cache
# lanes.

Major changes

  1. setup_caches will take one kwarg cache_lanes (default to 1).
def setup_caches(self, max_batch_size, max_seq_length, cache_lanes: int = 1)
  1. attention.kv_cache is now a nn.ModuleList, containing multiple KVCache's, each corresponding to a lane.

  2. We now pass kwargs = {"input_pos": input_pos, "cache_lane": lane} to the step() function. Removing the temporary helper function model.setup_input_pos.

Requires pytorch/pytorch#136416 to support pass-in of kwargs.

@pytorch-bot
Copy link

pytorch-bot bot commented Sep 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1174

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9514b54 with merge base 8d01d9b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 21, 2024
@kwen2501 kwen2501 changed the title [WIP][Distributed] Add lanes to KV cache [Distributed] Add lanes to KV cache Sep 23, 2024
)
# create schedule
decode_schedule = ScheduleGPipe(decode_stage, mbs)
decorder = ScheduleGPipe(decode_stage, 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

syntax error - this should be 'decoder' and not 'decorder'.

# Run data through pipeline
if pp_rank == first_pp_rank:
output = decode_schedule.step(new_token)
output = decorder.step(new_token, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same, syntax error - this should be 'decoder' and not 'decorder'.

output = decorder.step(new_token, **kwargs)
elif pp_rank == last_pp_rank:
output = decode_schedule.step()
output = decorder.step(**kwargs)
Copy link
Contributor

@lessw2020 lessw2020 Sep 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same, syntax error - this should be 'decoder' and not 'decorder'.

output = decorder.step(**kwargs)
else: # middle pp ranks
decode_schedule.step()
decorder.step(**kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

last one, syntax error - this should be 'decoder' and not 'decorder'.

Copy link
Contributor

@lessw2020 lessw2020 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice addition!
minor note that 'decorder' should be 'decoder' in the code for ease of understanding/syntax.

@kwen2501 kwen2501 merged commit 2cf4016 into main Sep 23, 2024
51 checks passed
kwen2501 added a commit that referenced this pull request Sep 23, 2024
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants