@@ -158,16 +158,21 @@ def __init__(self, attention: Attention):
             attention.kv_cache[0].k_cache.shape
         )
         cache_dtype = attention.kv_cache[0].k_cache.dtype
-        self.kv_cache = CustomKVCache(
-            max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
-        )
+        # The `Attention` module being replaced can have multiple KV caches
+        # (denoted by `cache_lanes`). Thus we follow the same setup format
+        # as in `Attention.setup_cache`.
+        cache_lanes = len(attention.kv_cache)
+        self.kv_cache = nn.ModuleList([
+            CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
+            for _ in range(cache_lanes)
+        ])

         self.n_heads = attention.n_heads
         self.head_dim = attention.head_dim
         self.n_local_heads = attention.n_local_heads
         self.dim = attention.dim

-    def forward(self, x, freqs_cis, mask, input_pos=None):
+    def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
         bsz, seqlen, _ = x.shape

         q = self.wq(x)
@@ -184,12 +189,13 @@ def forward(self, x, freqs_cis, mask, input_pos=None):

         # KV cache should always be enabled
         assert self.kv_cache is not None
+        kv_cache = self.kv_cache[cache_lane]
         output = torch.ops.llama.sdpa_with_kv_cache(
             q,
             k,
             v,
-            self.kv_cache.k_cache,
-            self.kv_cache.v_cache,
+            kv_cache.k_cache,
+            kv_cache.v_cache,
             input_pos[-1].item(),
             seqlen,
         )
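A minimal, self-contained sketch of the "cache lane" pattern introduced by this diff: one KV-cache module per lane held in an nn.ModuleList, with the lane selected by index at call time. The ToyKVCache and ToyAttention names and shapes below are illustrative assumptions, not part of the actual codebase; the real change wires `cache_lane` into the custom SDPA op shown above.

    import torch
    import torch.nn as nn


    class ToyKVCache(nn.Module):
        def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
            super().__init__()
            shape = (max_batch_size, n_heads, max_seq_length, head_dim)
            self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
            self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))

        def update(self, input_pos, k, v):
            # Write the new keys/values at the given positions and return the full caches.
            self.k_cache[:, :, input_pos] = k
            self.v_cache[:, :, input_pos] = v
            return self.k_cache, self.v_cache


    class ToyAttention(nn.Module):
        def __init__(self, cache_lanes, max_batch_size, max_seq_length, n_heads, head_dim, dtype):
            super().__init__()
            # One independent KV cache per lane, mirroring the ModuleList setup in the diff.
            self.kv_cache = nn.ModuleList(
                ToyKVCache(max_batch_size, max_seq_length, n_heads, head_dim, dtype)
                for _ in range(cache_lanes)
            )

        def forward(self, q, k, v, input_pos, cache_lane: int = 0):
            # Select one lane; each lane keeps its own KV history.
            kv_cache = self.kv_cache[cache_lane]
            k_all, v_all = kv_cache.update(input_pos, k, v)
            return torch.nn.functional.scaled_dot_product_attention(q, k_all, v_all)


    attn = ToyAttention(cache_lanes=2, max_batch_size=1, max_seq_length=16,
                        n_heads=4, head_dim=8, dtype=torch.float32)
    q = k = v = torch.randn(1, 4, 1, 8)
    out0 = attn(q, k, v, input_pos=torch.tensor([0]), cache_lane=0)  # lane 0
    out1 = attn(q, k, v, input_pos=torch.tensor([0]), cache_lane=1)  # lane 1, separate history

Defaulting `cache_lane` to 0 keeps existing single-cache callers working unchanged, which is presumably why the diff adds the argument with a default rather than requiring it.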