From 2ca38da2b512c6e581f4d8df1f97b365ebf4ba62 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Thu, 2 Oct 2025 10:31:15 -0700 Subject: [PATCH] update llama export DS specs to be more accurate. (#14737) Summary: this PR https://github.com/pytorch/pytorch/pull/164075 enhances value range analysis, discovering a problem in the current upper bounds for the DS specs for llama. Full DS logs, for anyone who wants to understand why in detail: https://www.internalfb.com/phabricator/paste/view/P1973006378 This blocks landing the PR above. Reviewed By: angelayi, larryliu0820 Differential Revision: D83708583 --- extension/llm/export/builder.py | 9 +++++++-- extension/llm/export/test/test_builder.py | 2 +- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 01000f3564c..da5c3324662 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -142,9 +142,14 @@ def __init__( {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)}, ) else: - # Two input arguments: tokens and input_pos but input_pos is static shape + # Two input arguments: tokens and input_pos but input_pos is static shape. 
+ + # A runtime assertion added by torch.ops.llama.update_cache requires that + # L['tokens'].size()[1] + input_pos[0].item() < self.max_seq_len + # This constrains L['tokens'].size()[1] to be at most self.max_seq_len - 1. + # Run with TORCH_LOGS=+dynamic for details. self.dynamic_shapes = ( - {1: torch.export.Dim("token_dim", max=self.max_seq_len)}, + {1: torch.export.Dim("token_dim", max=self.max_seq_len - 1)}, {"input_pos": {0: 1}}, ) diff --git a/extension/llm/export/test/test_builder.py b/extension/llm/export/test/test_builder.py index 8bf591813ec..7883480c1e7 100644 --- a/extension/llm/export/test/test_builder.py +++ b/extension/llm/export/test/test_builder.py @@ -88,7 +88,7 @@ def test_get_dynamic_shape_with_dynamic_shape_enabled_with_kv_cache(self) -> Non # Check first element (tokens dimension) self.assertIsInstance(result[0], dict) self.assertIn(1, result[0]) - self.assertEqual(result[0][1].max, self.max_seq_len) + self.assertEqual(result[0][1].max, self.max_seq_len - 1) # Check second element (input_pos dimension) self.assertIsInstance(result[1], dict)