- 
                Notifications
    You must be signed in to change notification settings 
- Fork 248
[Distributed] Add lanes to KV cache #1174
Conversation
| 🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1174
 Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 9514b54 with merge base 8d01d9b ( This comment was automatically generated by Dr. CI and updates every 15 minutes. | 
e33e681    to
    39eff90      
    Compare
  
    | ) | ||
| # create schedule | ||
| decode_schedule = ScheduleGPipe(decode_stage, mbs) | ||
| decorder = ScheduleGPipe(decode_stage, 1) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syntax error - this should be 'decoder' and not 'decorder'.
| # Run data through pipeline | ||
| if pp_rank == first_pp_rank: | ||
| output = decode_schedule.step(new_token) | ||
| output = decorder.step(new_token, **kwargs) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same, syntax error - this should be 'decoder' and not 'decorder'.
| output = decorder.step(new_token, **kwargs) | ||
| elif pp_rank == last_pp_rank: | ||
| output = decode_schedule.step() | ||
| output = decorder.step(**kwargs) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same, syntax error - this should be 'decoder' and not 'decorder'.
| output = decorder.step(**kwargs) | ||
| else: # middle pp ranks | ||
| decode_schedule.step() | ||
| decorder.step(**kwargs) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
last one, syntax error - this should be 'decoder' and not 'decorder'.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice addition!
minor note that 'decorder' should be 'decoder' in the code for ease of understanding/syntax.
KV cache is extended to have multiple lanes, each letting a separate batch pass through, achieving pipeline parallelism.
Major changes
setup_cacheswill take one kwargcache_lanes(default to 1).attention.kv_cacheis now ann.ModuleList, containing multipleKVCache's, each corresponding to a lane.We now pass
kwargs = {"input_pos": input_pos, "cache_lane": lane}to thestep()function. Removing the temporary helper functionmodel.setup_input_pos.Requires pytorch/pytorch#136416 to support pass-in of kwargs.