We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent af55e1c commit a10d262Copy full SHA for a10d262
torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja
@@ -231,8 +231,6 @@
231
start_n1 = pid * BLOCK_N1
232
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
233
234
- desc_k = None
235
- desc_v = None
236
{%- if USE_TMA %}
237
desc_k = tl.make_tensor_descriptor(
238
base=K,
@@ -246,8 +244,6 @@
246
244
strides=(stride_vn, stride_vd),
247
245
block_shape=[BLOCK_N1, V_HEAD_DIM_ROUNDED],
248
)
249
- {%- endif %}
250
- {%- if USE_TMA %}
251
k = tl.load_tensor_descriptor(
252
desc_k,
253
[start_n1.to(tl.int32), 0],
0 commit comments