diff --git a/docsrc/index.rst b/docsrc/index.rst
index 67fbdc56f5..4d28d77640 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -140,11 +140,10 @@ Model Zoo
 * :ref:`torch_compile_resnet`
 * :ref:`torch_compile_transformer`
 * :ref:`torch_compile_stable_diffusion`
+* :ref:`compile_hf_models`
 * :ref:`torch_compile_gpt2`
 * :ref:`torch_export_gpt2`
-* :ref:`torch_export_llama2`
 * :ref:`torch_export_sam2`
-* :ref:`torch_export_flux_dev`
 * :ref:`notebooks`
 
 .. toctree::
@@ -155,11 +154,10 @@ Model Zoo
    tutorials/_rendered_examples/dynamo/torch_compile_resnet_example
    tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
    tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
+   tutorials/compile_hf_models
    tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
    tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
    tutorials/_rendered_examples/dynamo/torch_compile_gpt2
-   tutorials/_rendered_examples/dynamo/torch_export_gpt2
-   tutorials/_rendered_examples/dynamo/torch_export_llama2
    tutorials/_rendered_examples/dynamo/torch_export_sam2
    tutorials/_rendered_examples/dynamo/torch_export_flux_dev
    tutorials/notebooks
diff --git a/docsrc/tutorials/compile_hf_models.rst b/docsrc/tutorials/compile_hf_models.rst
new file mode 100644
index 0000000000..f6da87b145
--- /dev/null
+++ b/docsrc/tutorials/compile_hf_models.rst
@@ -0,0 +1,218 @@
+.. _compile_hf_models:
+
+Compiling LLM models from Huggingface
+======================================
+
+This tutorial walks you through compiling LLM models from Huggingface using Torch-TensorRT. We also introduce KV caching in Torch-TensorRT, which can greatly improve the performance of LLM inference.
+The code is available in the `tools/llm <https://github.com/pytorch/TensorRT/tree/main/tools/llm>`_ directory. We use the ``run_llm.py`` script to compile the model, generate outputs, and measure the performance.
+
+.. note::
+   This is an **experimental release** and APIs may change in future versions.
+
+.. note::
+   The compilation scripts and tutorials for Llama-2-7b-chat-hf and gpt2 models have been consolidated into the unified ``run_llm.py`` script located in the `tools/llm <https://github.com/pytorch/TensorRT/tree/main/tools/llm>`_ directory.
+
+Overview of tools/llm Directory
+-------------------------------
+
+The ``tools/llm`` directory provides the following tools to compile LLM models from Huggingface:
+
+* **run_llm.py**: Main entry point for model compilation, generating outputs, and benchmarking
+* **Static Cache Utilities**: ``static_cache_v1.py`` and ``static_cache_v2.py`` for KV cache optimization
+* **SDPA Attention**: ``sdpa_converter.py`` and ``register_sdpa.py`` for registering the scaled dot-product attention converter and its lowering pass
+* **Testing Components**: Model-specific test files for validation
+* **Utility Functions**: ``utils.py`` and ``cache_utils.py`` for common operations
+
+Supported Models
+----------------
+We have officially verified support for the following LLM families:
+
+.. list-table::
+   :widths: 20 40 20 20
+   :header-rows: 1
+
+   * - Model Series
+     - HuggingFace Model Card
+     - Precision
+     - KV Cache Support
+   * - GPT-2
+     - gpt2
+     - FP16, FP32
+     - Yes
+   * - LLaMA 2
+     - meta-llama/Llama-2-7b-chat-hf
+     - FP16, FP32
+     - Yes
+   * - LLaMA 3.1
+     - meta-llama/Llama-3.1-8B-Instruct
+     - FP16, FP32
+     - Yes
+   * - LLaMA 3.2
+     - | meta-llama/Llama-3.2-1B-Instruct
+       | meta-llama/Llama-3.2-3B-Instruct
+     - FP16, FP32
+     - Yes
+   * - Qwen 2.5
+     - | Qwen/Qwen2.5-0.5B-Instruct
+       | Qwen/Qwen2.5-1.5B-Instruct
+       | Qwen/Qwen2.5-3B-Instruct
+       | Qwen/Qwen2.5-7B-Instruct
+     - FP16, FP32
+     - Yes
+
+Getting Started with run_llm.py
+-------------------------------
+
+The main entry point is ``run_llm.py``, which provides a complete workflow for model compilation and benchmarking.
+
+Basic Usage
+^^^^^^^^^^^
+
+.. code-block:: bash
+
+   python tools/llm/run_llm.py \
+     --model meta-llama/Llama-3.2-1B-Instruct \
+     --prompt "What is parallel programming?" \
+     --precision FP16 \
+     --num_tokens 128 \
+     --cache static_v2 \
+     --benchmark
+
+Key Arguments
+^^^^^^^^^^^^^
+
+* ``--model``: Name or path of the HuggingFace LLM
+* ``--tokenizer``: (Optional) Tokenizer name; defaults to model name
+* ``--prompt``: Input prompt for text generation
+* ``--precision``: Precision mode (``FP16``, ``FP32``)
+* ``--num_tokens``: Number of output tokens to generate
+* ``--cache``: KV cache type (``static_v1``, ``static_v2``, or empty for no KV caching)
+* ``--benchmark``: Enable benchmarking mode for performance comparison
+* ``--enable_pytorch_run``: Also run and compare PyTorch baseline
+
+
+Other Usage Examples
+^^^^^^^^^^^^^^^^^^^^
+.. code-block:: bash
+
+   # Compare the performance of different models
+   python tools/llm/run_llm.py --model gpt2 --benchmark --enable_pytorch_run
+   python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --benchmark --enable_pytorch_run
+
+   # Generate outputs (benchmarking disabled) with a specified number of tokens (default: 128)
+   python tools/llm/run_llm.py --model gpt2 --prompt "What is parallel programming?" --num_tokens 128
+   python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --num_tokens 128
+
+   # Test different caching approaches
+   python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v1
+   python tools/llm/run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v2
+
+   # Compare FP16 vs FP32 performance
+   python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP16 --benchmark
+   python tools/llm/run_llm.py --model Qwen/Qwen2.5-1.5B-Instruct --precision FP32 --benchmark
+
+
+KV Caching in Torch-TensorRT
+---------------------------------
+
+We provide two versions of static KV caching: `static_cache_v1 <https://github.com/pytorch/TensorRT/blob/main/tools/llm/static_cache_v1.py>`_ and `static_cache_v2 <https://github.com/pytorch/TensorRT/blob/main/tools/llm/static_cache_v2.py>`_.
+In both implementations, we add the static KV cache tensors as model inputs/outputs rather than maintaining them as external state.
+The KV cache length equals the input sequence length plus the output sequence length (specified by ``--num_tokens``). The number of heads and the head dimension are determined by the model config.
+
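+The following is a small, hedged sketch (not part of the tools) showing how the per-layer cache shape can be derived from the Huggingface config. The attribute names used below are the ones exposed by Llama-style configs; depending on where the cache is inserted, grouped-query models may use ``num_key_value_heads`` instead of ``num_attention_heads``.
+
+.. code-block:: python
+
+    from transformers import AutoConfig
+
+    # Illustrative only: derive the static KV cache shape for one attention layer.
+    cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
+
+    batch_size = 1
+    isl, osl = 2048, 128  # input and output sequence lengths
+    head_dim = cfg.hidden_size // cfg.num_attention_heads
+
+    # Cache shape: batch_size x num_heads x (isl + osl) x head_dim
+    kv_cache_shape = (batch_size, cfg.num_attention_heads, isl + osl, head_dim)
+    print(kv_cache_shape)
+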
+Static Cache v1
+^^^^^^^^^^^^^^^^
+
+``static_cache_v1.py`` implements the KV cache in the model graph as follows:
+
+.. code-block:: python
+
+    class StaticCacheV1Model(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
+            # Concatenate new key/value pairs with existing cache
+            new_key_cache = torch.cat((key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2)
+            new_value_cache = torch.cat((value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2)
+            
+            # Compute attention using the updated cache
+            attn_output = torch._C._nn.scaled_dot_product_attention(
+                q, 
+                new_key_cache[:, :, :end_idx, :], 
+                new_value_cache[:, :, :end_idx, :], 
+                dropout_p=0.0, 
+                is_causal=is_causal
+            )
+
+            return attn_output, new_key_cache, new_value_cache
+
+In the above code, we concatenate the new key/value pairs with the existing cache to update it. To compute the attention, we use the updated cache and gather the corresponding keys/values up to and including the current token index.
+This logic is implemented as an FX graph transformation pass, which we register as a Torch-TensorRT lowering pass using the ``@_aten_lowering_pass`` decorator when the ``static_cache_v1.py`` module is imported.
+
+.. note::
+   The ``start_idx`` and ``end_idx`` are the start and end indices of the current token(s) in the cache. In the prefill phase, ``start_idx`` is 0 and ``end_idx`` is the input sequence length.
+   In the decode phase, ``start_idx`` begins at the input sequence length and ``end_idx`` equals ``start_idx + 1``. Both indices advance by one per generated token until the end of the sequence or the maximum number of tokens is reached.
+
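+The following standalone sketch (illustrative shapes only, not the lowering pass itself) steps through the prefill and decode phases to show how ``start_idx``/``end_idx`` advance and how the concatenation above updates the cache:
+
+.. code-block:: python
+
+    import torch
+
+    B, H, D = 1, 8, 64  # batch size, number of heads, head dimension (illustrative)
+    isl, osl = 6, 4     # input and output sequence lengths
+    max_seq_len = isl + osl
+
+    key_cache = torch.zeros(B, H, max_seq_len, D)
+
+    def update_cache(cache, new_k, start_idx, end_idx):
+        # Same concatenation pattern as StaticCacheV1Model.forward above
+        return torch.cat(
+            (cache[:, :, :start_idx, :], new_k, cache[:, :, end_idx:, :]), dim=2
+        )
+
+    # Prefill: the whole prompt is written to the cache at once
+    start_idx, end_idx = 0, isl
+    key_cache = update_cache(key_cache, torch.randn(B, H, isl, D), start_idx, end_idx)
+    print("prefill:", start_idx, end_idx, key_cache.shape)
+
+    # Decode: one token per step; both indices advance by one
+    for _ in range(osl):
+        start_idx, end_idx = end_idx, end_idx + 1
+        key_cache = update_cache(key_cache, torch.randn(B, H, 1, D), start_idx, end_idx)
+        print("decode: ", start_idx, end_idx, key_cache.shape)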
+
+Static Cache v2
+^^^^^^^^^^^^^^^^
+
+``static_cache_v2.py`` is similar to ``static_cache_v1.py`` but uses fewer slice operations. It implements the KV cache in the model graph as follows:
+
+.. code-block:: python
+
+    class StaticCacheV2Model(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True):
+            concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2) 
+            concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
+            new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
+            new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
+            attn_output = torch._C._nn.scaled_dot_product_attention(
+                  q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
+            )
+
+            return attn_output, new_key_cache, new_value_cache
+
+In the above code, we concatenate the existing key/value cache with the current token's key/value, use the result directly to compute the attention, and update the key/value cache by inserting the current key/value.
+This logic is implemented as an FX graph transformation pass, which we register as a Torch-TensorRT lowering pass using the ``@_aten_lowering_pass`` decorator when the ``static_cache_v2.py`` module is imported.
+The definitions of ``start_idx`` and ``end_idx`` are the same as in ``static_cache_v1.py``.
+
+After the model is compiled with static KV cache, its input signature changes to ``(input_ids, position_ids, key_cache_0, value_cache_0, ..., start_idx, end_idx)``.
+The number of key/value cache tensor pairs equals the number of attention (SDPA) operations in the model. We can use the ``generate_with_static_cache`` function to generate the outputs.
+
+Generating Outputs
+------------------- 
+We use a custom `generate <https://github.com/pytorch/TensorRT/blob/main/tools/llm/utils.py#L112>`_ function to generate the outputs. This function performs standard autoregressive decoding without KV caching.
+There is also a `generate_with_static_cache <https://github.com/pytorch/TensorRT/blob/main/tools/llm/utils.py#L141>`_ function that performs autoregressive decoding with KV caching.
+
+The ``generate_with_static_cache`` function takes care of preparing the inputs for the model compiled with static KV cache.
+The model inputs are ``input_ids``, ``position_ids``, ``key_cache_0``, ``value_cache_0``, ..., ``start_idx``, ``end_idx``.
+We initialize the key/value cache tensors with zeros; for every generated token, the updated cache tensors are returned as model outputs and fed back in at the next step.
+
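+As a rough illustration, the decoding loop has the following shape. This is a simplified sketch assuming greedy decoding and a logits tensor followed by the cache tensors in the outputs; the exact input/output packing and index handling follow the actual ``generate_with_static_cache`` implementation in ``tools/llm/utils.py``.
+
+.. code-block:: python
+
+    import torch
+
+    def greedy_decode_with_static_cache(
+        trt_model, input_ids, kv_cache_inputs, max_output_len, eos_token_id
+    ):
+        # kv_cache_inputs: zero-initialized key/value cache tensors, in signature order
+        kv_caches = list(kv_cache_inputs)
+        isl = input_ids.shape[1]
+        position_ids = torch.arange(isl, device=input_ids.device).unsqueeze(0)
+
+        # Prefill: process the whole prompt; cache indices cover [0, isl)
+        start_idx, end_idx = 0, isl
+        logits, *kv_caches = trt_model(
+            input_ids, position_ids, *kv_caches, start_idx, end_idx
+        )
+        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
+        generated = torch.cat([input_ids, next_token], dim=-1)
+
+        # Decode: one token per step; feed the updated caches back in
+        while generated.shape[1] < max_output_len and next_token.item() != eos_token_id:
+            start_idx, end_idx = end_idx, end_idx + 1
+            position_ids = position_ids[:, -1:] + 1
+            logits, *kv_caches = trt_model(
+                next_token, position_ids, *kv_caches, start_idx, end_idx
+            )
+            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
+            generated = torch.cat([generated, next_token], dim=-1)
+
+        return generated
+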
+SDPA Converter (sdpa_converter.py)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+* Converts the scaled dot-product attention operation using the TensorRT Python API.
+* Supports causal and standard self-attention.
+
+SDPA Registration (register_sdpa.py)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+* This is a Torch-TensorRT lowering pass that replaces variants of SDPA with ``torch.nn.functional.scaled_dot_product_attention``.
+* Registers the SDPA converter, which is used to convert the ``torch.nn.functional.scaled_dot_product_attention`` operation. A minimal registration sketch follows below.
+
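+The sketch below shows the registration mechanism in a hedged form: a hypothetical no-op lowering pass registered with the same decorator the static cache passes use. Importing such a module is sufficient for the pass to run during compilation.
+
+.. code-block:: python
+
+    import torch
+    from torch_tensorrt.dynamo._settings import CompilationSettings
+    from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+        _aten_lowering_pass,
+    )
+
+    # Hypothetical pass: logs SDPA nodes instead of rewriting them.
+    @_aten_lowering_pass
+    def inspect_sdpa_nodes(
+        gm: torch.fx.GraphModule, settings: CompilationSettings
+    ) -> torch.fx.GraphModule:
+        for node in gm.graph.nodes:
+            if node.target == torch._C._nn.scaled_dot_product_attention:
+                print(f"Found SDPA node: {node.name}")
+        return gm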
+
+Limitations and Known Issues
+----------------------------
+
+* Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported
+* Some model architectures (e.g., Phi-4) have issues when exporting the torch model.
+
+Requirements
+^^^^^^^^^^^^
+
+* Torch-TensorRT 2.8.0 or later
+* Transformers v4.52.3
\ No newline at end of file
diff --git a/examples/dynamo/torch_export_gpt2.py b/examples/dynamo/torch_export_gpt2.py
deleted file mode 100644
index 4d34c58de4..0000000000
--- a/examples/dynamo/torch_export_gpt2.py
+++ /dev/null
@@ -1,98 +0,0 @@
-"""
-.. _torch_export_gpt2:
-
-Compiling GPT2 using the dynamo backend
-==========================================================
-
-This script illustrates Torch-TensorRT workflow with dynamo backend on popular GPT2 model.
-"""
-
-# %%
-# Imports and Model Definition
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-import torch
-import torch_tensorrt
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from utils import export_llm, generate
-
-# %%
-
-# Define the parameters and initialize the model
-MAX_TOKENS = 32
-DEVICE = torch.device("cuda:0")
-
-# Define the GPT2 model from hugging face
-# kv_cache is not supported in Torch-TRT currently.
-# CPU is used here so that GPU memory is reserved for TRT compilation.
-with torch.no_grad():
-    tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    model = (
-        AutoModelForCausalLM.from_pretrained(
-            "gpt2",
-            pad_token_id=tokenizer.eos_token_id,
-            use_cache=False,
-            attn_implementation="eager",
-        )
-        .eval()
-        .half()
-    )
-
-# %%
-# Tokenize a sample input prompt and get pytorch model outputs
-prompt = "I enjoy walking with my cute dog"
-model_inputs = tokenizer(prompt, return_tensors="pt")
-input_ids = model_inputs["input_ids"]
-
-# Auto-regressive generation loop for greedy decoding using PyTorch model
-# We use a custom generate function which is very similar to the huggingface one.
-pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)
-
-
-# %%
-# Compilation with `Torch-TensorRT` using dynamo backend and generate TensorRT outputs
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-# Export the GPT2 model into an ExportedProgram which is input of TRT compilation
-# To compile the model in FP16, we do the following
-# 1) Cast the model to FP16 via model.half()
-# 2) Enable use_explicit_typing=True. Certain layers are explicitly casted to FP32 within the pytorch model and this flag respects this behavior during TRT compilation
-# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch)
-gpt2_ep = export_llm(model, input_ids, max_seq_len=1024)
-trt_model = torch_tensorrt.dynamo.compile(
-    gpt2_ep,
-    inputs=[input_ids],
-    enabled_precisions={torch.float32},
-    truncate_double=True,
-    device=DEVICE,
-    disable_tf32=True,
-    use_explicit_typing=True,
-    use_fp32_acc=True,
-)
-
-# Auto-regressive generation loop for greedy decoding using TensorRT model
-# We use a custom generate function which is very similar to the huggingface one.
-# Move inputs to GPU
-input_ids = input_ids.to(DEVICE)
-trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)
-
-# %%
-# Decode the output sentences of PyTorch and TensorRT
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-print("=============================")
-print(
-    "Pytorch model generated text: ",
-    tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
-)
-print("=============================")
-print(
-    "TensorRT model generated text: ",
-    tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
-)
-
-# Prompt : What is parallel programming ?
-
-# =============================
-# Pytorch model generated text: The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that
-
-# =============================
-# TensorRT model generated text: The parallel programming paradigm is a set of programming languages that are designed to be used in parallel. The main difference between parallel programming and parallel programming is that
diff --git a/examples/dynamo/torch_export_llama2.py b/examples/dynamo/torch_export_llama2.py
deleted file mode 100644
index 2f3e3cba43..0000000000
--- a/examples/dynamo/torch_export_llama2.py
+++ /dev/null
@@ -1,102 +0,0 @@
-"""
-.. _torch_export_llama2:
-
-Compiling Llama2 using the dynamo backend
-==========================================================
-
-This script illustrates Torch-TensorRT workflow with dynamo backend on popular Llama2 model.
-"""
-
-# %%
-# Imports and Model Definition
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-import torch
-import torch_tensorrt
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from utils import export_llm, generate
-
-# %%
-# Define the parameters and initialize the model
-MAX_TOKENS = 32
-DEVICE = torch.device("cuda:0")
-
-# Define the Llama2 model from hugging face
-# kv_cache is not supported in Torch-TRT currently.
-# CPU is used here so that GPU memory is reserved for TRT compilation.
-llama_path = "meta-llama/Llama-2-7b-chat-hf"
-with torch.no_grad():
-    model = (
-        AutoModelForCausalLM.from_pretrained(
-            llama_path, use_cache=False, attn_implementation="eager"
-        )
-        .eval()
-        .half()
-    )
-
-tokenizer = AutoTokenizer.from_pretrained(llama_path)
-
-# %%
-# Tokenize a sample input prompt and get pytorch model outputs
-prompt = "What is dynamic programming?"
-model_inputs = tokenizer(prompt, return_tensors="pt")
-input_ids = model_inputs.input_ids
-
-# Auto-regressive generation loop for greedy decoding using PyTorch model
-# We use a custom generate function which is very similar to the huggingface one.
-pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)
-
-# %%
-# Compilation with `Torch-TensorRT` using dynamo backend and generate TensorRT outputs
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-# Export the llama2 model into an ExportedProgram which is input of TRT compilation
-# To compile the model in FP16, we do the following
-# 1) Cast the model to FP16 via model.half()
-# 2) Enable use_explicit_typing=True. Certain layers are explicitly casted to FP32 within the pytorch model and this flag respects this behavior during TRT compilation
-# 3) Enable use_fp32_acc=True. This ensures all the matmuls are accumulated in FP32 precision (similar to PyTorch)
-llama2_ep = export_llm(model, input_ids, max_seq_len=64)
-trt_model = torch_tensorrt.dynamo.compile(
-    llama2_ep,
-    inputs=[input_ids],
-    enabled_precisions={torch.float32},
-    truncate_double=True,
-    device=DEVICE,
-    disable_tf32=True,
-    use_explicit_typing=True,
-    use_fp32_acc=True,
-)
-
-# Auto-regressive generation loop for greedy decoding using TensorRT model
-# We use a custom generate function which is very similar to the huggingface one.
-# Move inputs to GPU
-input_ids = input_ids.to(DEVICE)
-trt_gen_tokens = generate(trt_model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)
-
-# %%
-# Decode the output sentences of PyTorch and TensorRT
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-print("=============================")
-print(
-    "Pytorch model generated text: ",
-    tokenizer.batch_decode(
-        pyt_gen_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )[0],
-)
-print("=============================")
-print(
-    "TensorRT model generated text: ",
-    tokenizer.batch_decode(
-        trt_gen_tokens,
-        skip_special_tokens=True,
-        clean_up_tokenization_spaces=False,
-    )[0],
-)
-
-
-# Prompt : What is dynamic programming?
-
-# =============================
-# Pytorch model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and
-
-# =============================
-# TensorRT model generated text: Dynamic programming is an algorithmic technique used to solve complex problems by breaking them down into smaller subproblems, solving each subproblem only once, and
diff --git a/examples/dynamo/utils.py b/examples/dynamo/utils.py
deleted file mode 100644
index 25ad99c12d..0000000000
--- a/examples/dynamo/utils.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import torch
-from transformers import StoppingCriteriaList
-from transformers.generation.stopping_criteria import (
-    EosTokenCriteria,
-    MaxLengthCriteria,
-)
-
-
-def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
-    """
-    Exports the LLM model into an ExportedProgram with dynamic shapes.
-    In the case of guard failures due to some PyTorch kernel implements, we also
-    try to re-export the graph by expressing them as runtime assert nodes
-    """
-    with torch.no_grad():
-        # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604
-        seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
-        try:
-            print("Trying to export the model using torch.export.export()..")
-            # strict=False only enables aotautograd tracing and excludes dynamo.
-            ep = torch.export.export(
-                model, (inputs,), dynamic_shapes=({1: seq_len},), strict=False
-            )
-        except:
-            print(
-                "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
-            )
-            # This API is used to express the constraint violation guards as asserts in the graph.
-            ep = torch.export._trace._export(
-                model,
-                (inputs,),
-                dynamic_shapes=({1: seq_len},),
-                strict=False,
-                allow_complex_guards_as_runtime_asserts=True,
-            )
-
-    return ep
-
-
-def generate(model, input_seq, max_tokens, eos_token_id):
-    """
-    Greedy decoding of the model. This generates up to max_tokens.
-    """
-    # Max length of output seq = current input_seq length + max_tokens allowed to generate
-    max_output_seq_length = input_seq.shape[1] + max_tokens
-    stopping_criteria = StoppingCriteriaList(
-        [
-            MaxLengthCriteria(max_length=max_output_seq_length),
-            EosTokenCriteria(eos_token_id=eos_token_id),
-        ]
-    )
-
-    while True:
-        outputs = model(input_seq)
-        logits = outputs.logits
-        next_token_logits = logits[:, -1, :]
-        next_tokens = torch.argmax(next_token_logits, dim=-1)
-        input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1)
-        # TODO: Handle batch in this check
-        if stopping_criteria(input_seq, logits).item():
-            break
-
-    return input_seq
diff --git a/examples/dynamo/weight_streaming_example.py b/examples/dynamo/weight_streaming_example.py
index e1076a9e75..601292ba95 100644
--- a/examples/dynamo/weight_streaming_example.py
+++ b/examples/dynamo/weight_streaming_example.py
@@ -32,7 +32,43 @@
 import torch
 import torch_tensorrt
 from transformers import AutoModelForCausalLM
-from utils import export_llm
+
+
+def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
+    """
+    Exports the LLM model into an ExportedProgram with dynamic shapes.
+    In the case of guard failures due to some PyTorch kernel implements, we also
+    try to re-export the graph by expressing them as runtime assert nodes
+    """
+    with torch.no_grad():
+        # max=1024 hits a constraint violation error. https://github.com/pytorch/pytorch/issues/125604
+        seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
+        position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device)
+        try:
+            print("Trying to export the model using torch.export.export()..")
+            # strict=False only enables aotautograd tracing and excludes dynamo.
+            ep = torch.export.export(
+                model,
+                args=(inputs,),
+                kwargs={"position_ids": position_ids},
+                dynamic_shapes=({1: seq_len}, {1: seq_len}),
+                strict=False,
+            )
+        except:
+            print(
+                "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
+            )
+            # This API is used to express the constraint violation guards as asserts in the graph.
+            ep = torch.export._trace._export(
+                model,
+                args=(inputs,),
+                kwargs={"position_ids": position_ids},
+                dynamic_shapes=({1: seq_len}, {1: seq_len}),
+                strict=False,
+                allow_complex_guards_as_runtime_asserts=True,
+            )
+
+    return ep
 
 
 def time_generate(model, inputs, output_seq_length, iterations=10):
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 6434afe248..ff7d3b7a07 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -799,6 +799,28 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
         )
 
+    # Store the original input spec for later use
+    original_in_spec = getattr(gm, "_in_spec", None)
+    original_out_spec = getattr(gm, "_out_spec", None)
+
+    # Function to preserve and restore module specs
+    def preserve_module_specs(
+        in_spec: Any, out_spec: Any, target_module: torch.fx.GraphModule
+    ) -> None:
+        """
+        Applies input and output specs to the target module.
+
+        Args:
+            in_spec: The input spec to apply
+            out_spec: The output spec to apply
+            target_module: The module to apply specs to
+        """
+        # Apply specs to target module
+        if in_spec is not None:
+            target_module._in_spec = in_spec
+        if out_spec is not None:
+            target_module._out_spec = out_spec
+
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
     # If specified, try using the fast partitioner and fall back to the global one on failure
@@ -844,6 +866,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             continue
         submodule_node_dict[node.name] = node
 
+    preserve_module_specs(original_in_spec, original_out_spec, partitioned_module)
     # Store TRT replicas of Torch subgraphs
     trt_modules = {}
     # Iterate over all components that can be accelerated
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index b134b3d5f5..8d7a914836 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -890,10 +890,9 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         else:
             return converter(self.ctx, target, args, kwargs, self._cur_node_name)
 
-    def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
+    def get_attr(self, target: str, args: Any, kwargs: Any) -> torch.Tensor:
         with _disable_current_modes(), unset_fake_temporarily():
             frozen_attr = self.fetch_attr(target)
-
             if isinstance(frozen_attr, torch.nn.Parameter):
                 constant_tensor = frozen_attr.data
             else:
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index e542f1d417..f243d091a4 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1935,6 +1935,7 @@ def aten_ops_minimum(
     )
 
 
+@dynamo_tensorrt_converter(operator.sub, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.sub.Tensor, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.sub.Scalar, supports_dynamic_shapes=True)
 def aten_ops_sub(
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index fc76b20141..1d619b6ce3 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -752,7 +752,14 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         # Representation of input shapes to a given model
         # Shapes are concatenated as so:
         # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
-        new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs)
+        tensor_inputs = []
+        for t in inputs:
+            if not isinstance(t, torch.Tensor):
+                return True
+            tensor_inputs.append(t)
+        new_shape_key = "".join(
+            str(tuple(t.shape)).replace(" ", "") for t in tensor_inputs
+        )
 
         # If the new shape key differs from the existing one,
         # invalidate the old shape key and remove the CUDAGraph
diff --git a/tools/llm/README.md b/tools/llm/README.md
new file mode 100644
index 0000000000..a141505517
--- /dev/null
+++ b/tools/llm/README.md
@@ -0,0 +1,67 @@
+# Optimizing LLMs in Torch-TensorRT
+
+This directory provides utilities and scripts for compiling, optimizing, and benchmarking Large Language Models (LLMs) using Torch-TensorRT, with a focus on efficient inference on NVIDIA GPUs. The main entry point is `run_llm.py`, which demonstrates how to export, compile, and run LLMs with various caching strategies and precision modes. Note that this is an **experimental release** and APIs may change in future versions.
+
+### Key Features
+
+- **Model Support:** Works with popular LLMs such as Llama-3, Qwen2.5, etc.
+- **Precision Modes:** Supports FP16, BF16, and FP32.
+- **KV Cache:** Supports static and dynamic KV cache for efficient autoregressive decoding.
+- **Benchmarking:** Measures and compares throughput and latency for PyTorch and TensorRT backends.
+- **Custom Attention:** Registers and converts custom scaled dot-product attention (SDPA) for compatibility with TensorRT.
+
+
+### Supported Models
+
+We have officially verified support for the following models:
+
+| Model Series | HF Model Card | Precision | KV Cache Supported |
+|--------------|---------------|-----------|-------------------|
+| GPT-2 | gpt2<br>gpt2-medium | FP16, FP32 | Yes |
+| LLaMA 2 | meta-llama/Llama-2-7b-chat-hf | FP16, FP32 | Yes |
+| LLaMA 3.1 | meta-llama/Llama-3.1-8B-Instruct | FP16, FP32 | Yes |
+| LLaMA 3.2 | meta-llama/Llama-3.2-1B-Instruct<br>meta-llama/Llama-3.2-3B-Instruct | FP16, FP32 | Yes |
+| Qwen 2.5 | Qwen/Qwen2.5-0.5B-Instruct<br>Qwen/Qwen2.5-1.5B-Instruct<br>Qwen/Qwen2.5-3B-Instruct<br>Qwen/Qwen2.5-7B-Instruct | FP16, FP32 | Yes |
+| Qwen 3 | Qwen/Qwen3-0.6B<br>Qwen/Qwen3-1.7B<br>Qwen/Qwen3-4B<br>Qwen/Qwen3-8B | FP16, FP32 | Yes |
+
+
+### Usage
+
+The main entry point is : `run_llm.py`
+
+```bash
+python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --prompt "What is parallel programming?" --precision FP16 --num_tokens 128 --cache static_v2 --benchmark
+```
+
+#### Key Arguments
+
+- `--model`: Name or path of the HuggingFace LLM.
+- `--tokenizer`: (Optional) Tokenizer name; defaults to the model name.
+- `--prompt`: Input prompt for generation.
+- `--precision`: Precision mode (`FP16`, `FP32`).
+- `--num_tokens`: Number of output tokens to generate.
+- `--cache`: KV cache type (`static_v1`, `static_v2`, or empty for no KV caching).
+- `--benchmark`: Enable benchmarking mode.
+- `--enable_pytorch_run`: Also run and compare PyTorch baseline.
+
+### Caching Strategies
+
+- **Static Cache v1/v2:** Adds static KV cache tensors as model inputs/outputs for efficient reuse.
+- **No Cache:** Standard autoregressive decoding.
+
+Please refer to our tutorial on how the static cache is implemented.
+
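+For example, to compare caching strategies on the same model:
+
+```bash
+python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --benchmark                    # no KV cache
+python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v1 --benchmark  # static cache v1
+python run_llm.py --model meta-llama/Llama-3.2-1B-Instruct --cache static_v2 --benchmark  # static cache v2
+```
+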
+## Extension
+
+This codebase can be extended to:
+- Add new models by specifying their HuggingFace name.
+- Implement new cache strategies by adding FX graph passes.
+- Customize SDPA conversion for new attention mechanisms.
+
+## Limitations
+- Sliding window attention (used in Gemma3 and Qwen 3 models) is not yet supported.
+
+## Requirements
+
+- Torch-TensorRT 2.8.0
+- Transformers v4.52.3
\ No newline at end of file
diff --git a/tools/llm/cache_utils.py b/tools/llm/cache_utils.py
new file mode 100644
index 0000000000..d25e5bb40e
--- /dev/null
+++ b/tools/llm/cache_utils.py
@@ -0,0 +1,177 @@
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+
+import tensorrt
+import torch
+import torch_tensorrt
+from torch._export.utils import _detect_fake_mode_from_gm
+from torch._ops import OpOverloadPacket
+from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
+from torch.fx import Graph, GraphModule, Node
+from torch.fx.node import Target
+from torch.fx.passes.shape_prop import _extract_tensor_metadata
+from torch.utils._pytree import _LEAF_SPEC
+
+
+def get_kv_nodes(gm):
+    """
+    Extract key and value nodes from scaled dot-product attention operations in the graph.
+
+    This function searches through the graph for scaled_dot_product_attention operations
+    and extracts the key and value tensor nodes from each operation's arguments.
+
+    Args:
+        gm: A torch.fx.GraphModule containing the computational graph
+
+    Returns:
+        List[Tuple[Node, Node]]: A list of tuples, where each tuple contains
+            (key_node, value_node) from a scaled dot-product attention operation
+    """
+    kv_nodes = []
+    for node in gm.graph.nodes:
+        if (
+            node.op == "call_function"
+            and node.target == torch._C._nn.scaled_dot_product_attention
+        ):
+            q_node, k_node, v_node = node.args[:3]
+            kv_nodes.append((k_node, v_node))
+    return kv_nodes
+
+
+def get_random_tensor_from_node(node: Node) -> torch.Tensor:
+    """
+    Creates a random tensor based on the shape information in a node's metadata.
+    For symbolic dimensions, extracts the maximum value from the shape environment.
+
+    Args:
+        node: A torch.fx.Node object with metadata containing tensor information
+
+    Returns:
+        A random tensor with shape matching the node's metadata
+
+    Raises:
+        ValueError: If no tensor information is found in the node's metadata
+    """
+    if "val" not in node.meta:
+        raise ValueError(
+            f"No tensor information found in node metadata for node: {node}"
+        )
+
+    fake_tensor = node.meta["val"]
+    shape = []
+
+    # Iterate through each dimension and handle symbolic dimensions
+    for dim in fake_tensor.shape:
+        if isinstance(dim, torch.SymInt):
+            # Extract the maximum value from the shape environment
+            max_val = dim.node.hint
+            shape.append(max_val)
+        else:
+            shape.append(dim)
+
+    # Create a random tensor with the determined shape
+    dtype = fake_tensor.dtype
+    device = fake_tensor.device
+    random_tensor = torch.rand(shape, dtype=dtype, device=device)
+
+    return random_tensor
+
+
+def create_random_output_tensors(nodes: List[Node]) -> List[torch.Tensor]:
+    """
+    Creates random tensors based on the shape information in node metadata.
+    For symbolic dimensions, extracts the maximum value from the shape environment.
+
+    Args:
+        nodes: List of torch.fx.Node objects with metadata
+
+    Returns:
+        List of random tensors with shapes matching the nodes' metadata
+    """
+    random_tensors = []
+
+    for node in nodes:
+        if isinstance(node, Node):
+            node_tensor = get_random_tensor_from_node(node)
+        elif isinstance(node, tuple):
+            node_tensor_list = []
+            for n in node:
+                random_tensor = get_random_tensor_from_node(n)
+                node_tensor_list.append(random_tensor)
+            node_tensor = tuple(node_tensor_list)
+
+        random_tensors.append(node_tensor)
+
+    return random_tensors
+
+
+def _add_graph_input(
+    gm: GraphModule, name: str, val: Optional[torch.Tensor] = None, dynamic_shape=None
+) -> Node:
+    """Add a graph input to the given GraphModule and return the newly created node.
+
+    NOTE: function does NOT do any graph canonicalization. This is left to the user!
+
+    Args:
+        gm (GraphModule): The GraphModule to add the input to.
+        name (str): The name of the input.
+        val (torch.Tensor): An example tensor to use for the input.
+        dynamic_shape: The dynamic shape of the input tensor [NOT SUPPORTED YET]
+    """
+    # check that no dynamic shape is provided...
+    if dynamic_shape:
+        raise NotImplementedError("Dynamic shape not supported for adding graph inputs")
+
+    # extract graph and input spec
+    graph: Graph = gm.graph
+
+    in_spec = graph._codegen.pytree_info.in_spec
+    in_spec_for_args = in_spec.children_specs[0]
+    orig_args = graph._codegen.pytree_info.orig_args
+    assert in_spec_for_args.type is tuple
+
+    # insert input node after currently last input node
+    node_last_input = graph.find_nodes(op="placeholder", sort=True)[-1]
+    with graph.inserting_after(node_last_input):
+        in_node = graph.placeholder(name)
+        in_spec_for_args.children_specs.append(_LEAF_SPEC)
+        orig_args.append(f"arg_{name}")
+
+    # update pytree info recursively with __post_init__ starting at leaves
+    def call_post_init(spec):
+        for child_spec in spec.children_specs:
+            call_post_init(child_spec)
+        spec.__post_init__()
+
+    call_post_init(in_spec)
+
+    # set fake tensor information if all required information is available
+    fake_mode: Optional[FakeTensorMode] = _detect_fake_mode_from_gm(gm)
+    if fake_mode and val is not None and isinstance(val, torch.Tensor):
+        if isinstance(val, FakeTensor):
+            fake_tensor = val
+        else:
+            fake_tensor: FakeTensor = fake_mode.from_tensor(val, static_shapes=True)
+        in_node.meta["val"] = fake_tensor
+        in_node.meta["tensor_meta"] = _extract_tensor_metadata(fake_tensor)
+
+    # return new node...
+    return in_node
+
+
+def is_op(node: Node, ops: Union[OpOverloadPacket, Iterable[OpOverloadPacket]]) -> bool:
+    """Check if the node is a call to one of the ops."""
+    if node.op != "call_function":
+        return False
+    # check if it's a single op that's provided
+    if isinstance(ops, OpOverloadPacket):
+        ops = [ops]
+
+    # check if it's the op itself instead of an overload
+    if any(node.target == op for op in ops):
+        return True
+
+    return False
+
+
+def get_all_input_output_nodes(graph: Graph) -> Tuple[List[Node], List[Node]]:
+    input_nodes: List[Node] = graph.find_nodes(op="placeholder")
+    output_nodes: List[Node] = graph.find_nodes(op="output")
+    return (input_nodes, output_nodes)
diff --git a/tools/llm/run_llm.py b/tools/llm/run_llm.py
new file mode 100644
index 0000000000..7e50b515c2
--- /dev/null
+++ b/tools/llm/run_llm.py
@@ -0,0 +1,357 @@
+"""
+.. _run_llm:
+
+Running LLM inference with Torch-TensorRT
+==========================================================
+
+This script illustrates the Torch-TensorRT workflow with the dynamo backend on popular LLM models.
+"""
+
+import argparse
+import copy
+import os
+import timeit
+from contextlib import nullcontext
+
+# %%
+# Imports and Model Definition
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+import torch
+import torch_tensorrt
+from torchtrt_ext import register_sdpa
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from utils import (
+    export_llm,
+    generate,
+    generate_with_static_cache,
+    record_stats,
+    time_generate,
+)
+
+DEVICE = torch.device("cuda:0")
+
+
+def get_model(args):
+    """
+    Load and configure the language model for inference.
+
+    This function loads a pre-trained causal language model using the specified
+    model name and configures it with the appropriate precision and settings
+    for inference.
+
+    Args:
+        args: Parsed command line arguments containing:
+            - model (str): Name or path of the model to load
+            - precision (str): Precision to use ("FP16", "BF16", or "FP32")
+
+    Returns:
+        torch.nn.Module: The loaded and configured model ready for inference,
+            moved to CUDA device with the specified precision
+    """
+    with torch.no_grad():
+        model = (
+            AutoModelForCausalLM.from_pretrained(
+                args.model,
+                use_cache=False,
+                attn_implementation="sdpa",
+            )
+            .eval()
+            .cuda()
+        )
+
+    if args.precision == "FP16":
+        model = model.to(torch.float16)
+    elif args.precision == "BF16":
+        model = model.to(torch.bfloat16)
+    else:
+        model = model.to(torch.float32)
+
+    return model
+
+
+def compile_torchtrt(model, input_ids, args):
+    """
+    Compile a PyTorch model to TensorRT using torch_tensorrt.dynamo.compile.
+
+    This function exports the given model to an ExportedProgram and then
+    compiles it to TensorRT for optimized inference. The compilation process includes
+    precision-specific optimizations and various performance tuning parameters.
+
+    Args:
+        model (torch.nn.Module): The PyTorch model to compile
+        input_ids (torch.Tensor): Input token IDs tensor used for model export
+        args: Parsed command line arguments containing:
+            - num_tokens (int): Number of tokens to generate (used for max sequence length)
+            - precision (str): Precision to use ("FP16", "BF16", or "FP32")
+            - debug (bool): Whether to enable debug logging
+            - min_block_size (int): Minimum block size for TensorRT compilation
+
+    Returns:
+        torch_tensorrt.dynamo.TorchTensorRTModule: The compiled TensorRT model ready
+            for optimized inference
+    """
+    max_seq_len = input_ids.shape[1] + args.num_tokens
+    ep = export_llm(model, input_ids, max_seq_len=max_seq_len)
+    position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+    # Set precision specific flags
+    use_fp32_acc = False
+    use_explicit_typing = False
+    if args.precision == "FP16":
+        enabled_precisions = {torch.float32}
+        use_fp32_acc = True
+        use_explicit_typing = True
+    elif args.precision == "BF16":
+        enabled_precisions = {torch.bfloat16}
+        use_fp32_acc = False
+    else:
+        enabled_precisions = {torch.float32}
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[input_ids, position_ids],
+            enabled_precisions=enabled_precisions,
+            # truncate_double=True,
+            use_explicit_typing=use_explicit_typing,
+            use_fp32_acc=use_fp32_acc,
+            device=DEVICE,
+            disable_tf32=True,
+            use_python_runtime=True,
+            debug=args.debug,
+            offload_module_to_cpu=True,
+            min_block_size=args.min_block_size,
+        )
+
+    return trt_model
+
+
+def print_outputs(backend_name, gen_tokens, tokenizer):
+    """
+    Print the generated tokens from the model.
+    """
+    print(f"========= {backend_name} =========")
+    print(
+        f"{backend_name} model generated text: ",
+        tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
+    )
+    print("===================================")
+
+
+def measure_perf(trt_model, input_signature, backend_name):
+    """
+    Measure the performance of a TensorRT model by running it multiple times and
+    calculating the average time per iteration.
+    """
+    total_time = 0
+    iterations = 10
+
+    print("Running warmup iteration...")
+    # Warmup run
+    _ = trt_model(*input_signature)
+    torch.cuda.synchronize()
+
+    print(f"Measuring performance over {iterations} iterations...")
+    for i in range(iterations):
+        start_time = timeit.default_timer()
+        _ = trt_model(*input_signature)
+        torch.cuda.synchronize()
+        end_time = timeit.default_timer()
+        iter_time = end_time - start_time
+        total_time += iter_time
+
+    avg_time = total_time / iterations
+    print(
+        f"Backend: {backend_name} Average time per iteration: {avg_time*1000:.4f} milliseconds"
+    )
+    print(
+        f"Backend: {backend_name} Average throughput: {1.0/avg_time:.2f} iterations/second"
+    )
+
+
+if __name__ == "__main__":
+    arg_parser = argparse.ArgumentParser(
+        description="Run inference on a model with random input values"
+    )
+    arg_parser.add_argument(
+        "--model",
+        type=str,
+        default="meta-llama/Llama-3.2-1B-Instruct",
+        help="Name of LLM model",
+    )
+    arg_parser.add_argument(
+        "--tokenizer",
+        type=str,
+        default="",
+        help="Name of LLM model tokenizer",
+    )
+    arg_parser.add_argument(
+        "--prompt", type=str, default="What is parallel programming ?", help="Prompt"
+    )
+    arg_parser.add_argument(
+        "--precision",
+        type=str,
+        default="FP16",
+        help="Precision to use in the model. Options: FP16, BF16, FP32",
+    )
+    arg_parser.add_argument(
+        "--iterations", type=int, default=5, help="no. of iterations to run"
+    )
+    arg_parser.add_argument(
+        "--min_block_size", type=int, default=1, help="no. of iterations to run"
+    )
+    arg_parser.add_argument(
+        "--num_tokens",
+        type=int,
+        default=128,
+        help="no. of output tokens to be generated",
+    )
+    arg_parser.add_argument(
+        "--batch_size", type=int, default=1, help="Batch size used for benchmarking"
+    )
+    arg_parser.add_argument(
+        "--isl",
+        type=int,
+        default=2048,
+        help="Input sequence length used for benchmarking",
+    )
+    arg_parser.add_argument(
+        "--enable_pytorch_run",
+        action="store_true",
+        help="Enable pytorch run (default: False)",
+    )
+    arg_parser.add_argument(
+        "--cache",
+        type=str,
+        default="",
+        help="Type of KV cache to use. Options: static_v1, static_v2",
+    )
+    arg_parser.add_argument(
+        "--cudagraph", action="store_true", help="Enable cudagraphs (default: False)"
+    )
+    arg_parser.add_argument(
+        "--debug", action="store_true", help="Enable debug (default: False)"
+    )
+    arg_parser.add_argument(
+        "--benchmark", action="store_true", help="Enable benchmark (default: False)"
+    )
+
+    args = arg_parser.parse_args()
+    with torch.inference_mode():
+        model = get_model(args)
+
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model)
+
+        # Prepare input for benchmarking or evaluation
+        if args.benchmark:
+            input_ids = torch.randint(
+                1, 10000, (args.batch_size, args.isl), dtype=torch.int64
+            ).to(model.device)
+            position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+        else:
+            model_inputs = tokenizer(args.prompt, return_tensors="pt")
+            input_ids = model_inputs["input_ids"].to(DEVICE)
+            position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE)
+
+        MAX_OUTPUT_SEQ_LENGTH = input_ids.shape[1] + args.num_tokens
+        # Pyt
+        pyt_gen_tokens = None
+        pyt_timings = None
+        pyt_stats = None
+
+        if args.enable_pytorch_run:
+            pyt_gen_tokens = generate(
+                model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id
+            )
+            if args.benchmark:
+                pyt_timings = time_generate(
+                    generate,
+                    model,
+                    input_ids.clone(),
+                    MAX_OUTPUT_SEQ_LENGTH,
+                    tokenizer.eos_token_id,
+                    iterations=args.iterations,
+                )
+                pyt_stats = record_stats(
+                    "PyTorch",
+                    pyt_timings,
+                    args.precision,
+                    batch_size=args.batch_size,
+                    compile_time_s=None,
+                )
+
+        if args.cache == "static_v1":
+            # This import is required to register static v1 KV cache transformations as lowering passes
+            import static_cache_v1
+        if args.cache == "static_v2":
+            # This import is required to register static v2 KV cache transformations as lowering passes
+            import static_cache_v2
+
+        # Compile the model with Torch-TensorRT
+        trt_model = compile_torchtrt(model, input_ids, args)
+
+        if args.cache == "static_v1" or args.cache == "static_v2":
+            if args.cudagraph:
+                # Run a decoding loop with prefill and generate phases so that the CUDAGraph is recorded for both of these phases.
+                # trt_input_signature = (input_ids.clone(),) + get_zeroed_kv_cache_inputs(trt_model)
+                torch_tensorrt.runtime.set_cudagraphs_mode(True)
+
+            trt_gen_tokens = generate_with_static_cache(
+                trt_model,
+                input_ids.clone(),
+                MAX_OUTPUT_SEQ_LENGTH,
+                tokenizer.eos_token_id,
+            )
+
+            if args.benchmark:
+                trt_timings = time_generate(
+                    generate_with_static_cache,
+                    trt_model,
+                    input_ids.clone(),
+                    MAX_OUTPUT_SEQ_LENGTH,
+                    tokenizer.eos_token_id,
+                    iterations=args.iterations,
+                )
+        else:
+            trt_gen_tokens = generate(
+                trt_model,
+                input_ids.clone(),
+                MAX_OUTPUT_SEQ_LENGTH,
+                tokenizer.eos_token_id,
+            )
+            if args.benchmark:
+                trt_timings = time_generate(
+                    generate,
+                    trt_model,
+                    input_ids.clone(),
+                    MAX_OUTPUT_SEQ_LENGTH,
+                    tokenizer.eos_token_id,
+                    iterations=args.iterations,
+                )
+
+        if args.benchmark:
+            trt_stats = record_stats(
+                "TensorRT",
+                trt_timings,
+                args.precision,
+                batch_size=args.batch_size,
+                compile_time_s=None,
+            )
+
+        if not args.benchmark:
+            if args.enable_pytorch_run:
+                print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
+
+            print_outputs("TensorRT", trt_gen_tokens, tokenizer)
+
+            if args.enable_pytorch_run:
+                print(
+                    f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}"
+                )
+
+        if args.benchmark:
+            if args.enable_pytorch_run:
+                print("=========PyTorch PERFORMANCE============ \n")
+                print(pyt_stats)
+            print("===================== \n")
+            print("=========TensorRT PERFORMANCE============ \n")
+            print(trt_stats)
diff --git a/tools/llm/static_cache_v1.py b/tools/llm/static_cache_v1.py
new file mode 100644
index 0000000000..b60396c08b
--- /dev/null
+++ b/tools/llm/static_cache_v1.py
@@ -0,0 +1,277 @@
+import logging
+from typing import List, Tuple
+
+import torch
+import torch.utils._pytree as pytree
+from cache_utils import _add_graph_input, create_random_output_tensors, get_kv_nodes
+from torch.fx import Node
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+    _aten_lowering_pass,
+)
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+from torch_tensorrt.dynamo.utils import extract_var_range_info
+
+logger = logging.getLogger(__name__)
+
+SDPA_OP = torch._C._nn.scaled_dot_product_attention
+
+
+def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]):
+    """
+    Modifies the graph to add query, key, and value tensors as outputs.
+
+    This function identifies all scaled dot-product attention (SDPA) operations
+    in the graph, creates copies of their query, key, and value inputs, and adds
+    these copies to the graph's outputs. This allows for accessing these tensors
+    externally, which is useful for operations like key-value caching.
+
+    Args:
+        graph: The torch.fx.Graph to modify
+
+    Returns:
+        None. The graph is modified in-place.
+    """
+    output_node = next(node for node in gm.graph.nodes if node.op == "output")
+
+    # Get the current output args (typically a tuple)
+    current_outputs = output_node.args[0]
+
+    # If the current output is a tuple, extend it with our new outputs
+    if isinstance(current_outputs, tuple):
+        new_outputs = current_outputs + tuple(kv_cache_for_graph)
+    else:
+        # If there's only one output or it's not a tuple, create a new tuple
+        new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+
+    gm.graph.output(new_outputs)
+    gm.graph.erase_node(output_node)
+
+    return new_outputs
+
+
+def add_kv_cache_inputs(gm, fixed_kv: bool = True):
+    """
+    Add key-value tensors, index parameters as inputs to the graph.
+
+    Args:
+        gm: The GraphModule to modify
+        fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True.
+
+    Returns:
+        A tuple containing:
+        - List of (k_input, v_input) node pairs for each SDPA operation
+        - start_idx input node for slicing operations
+        - end_idx input node for slicing operations
+    """
+
+    def get_static_tensor(tensor: torch.Tensor):
+        key_shape = []
+        for dim in tensor.shape:
+            if isinstance(dim, torch.SymInt):
+                min_max_opt = extract_var_range_info(dim)
+                key_shape.append(min_max_opt["max"])
+            else:
+                key_shape.append(dim)
+
+        static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
+        return static_tensor
+
+    keys_values = get_kv_nodes(gm)
+
+    kv_inputs = []
+    for idx, key_value in enumerate(keys_values):
+        k_val = key_value[0].meta["val"]
+        v_val = key_value[1].meta["val"]
+        if fixed_kv:
+            k_val = get_static_tensor(k_val)
+            v_val = get_static_tensor(v_val)
+
+        # Add new inputs using _add_graph_input
+        k_input = _add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+        v_input = _add_graph_input(gm, key_value[1].name + "_v_input", v_val)
+        kv_inputs.append((k_input, v_input))
+
+    # Add start_idx and end_idx as inputs
+    start_idx_input = _add_graph_input(gm, "start_idx", torch.tensor(0))
+    end_idx_input = _add_graph_input(gm, "end_idx", torch.tensor(1))
+
+    # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, ..
+    input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
+    input_ids_meta = input_nodes[0].meta["val"]
+    seq_len = input_ids_meta.shape[1]
+    min_max_opt = extract_var_range_info(seq_len)
+    max_seq_len = min_max_opt["max"]
+
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+    shape_env = ShapeEnv()
+    # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
+    start_idx_unbacked_symint = shape_env.create_unbacked_symint()
+    torch._check(start_idx_unbacked_symint >= 0)
+    torch._check(start_idx_unbacked_symint <= max_seq_len)
+
+    end_idx_unbacked_symint = shape_env.create_unbacked_symint()
+    torch._check(end_idx_unbacked_symint >= 0)
+    torch._check(end_idx_unbacked_symint <= max_seq_len)
+    # Set the symbolic ints as the metadata for start_idx and end_idx inputs
+    start_idx_input.meta["val"] = start_idx_unbacked_symint
+    end_idx_input.meta["val"] = end_idx_unbacked_symint
+
+    return kv_inputs, start_idx_input, end_idx_input
+
+
+def insert_kv_slicing_before_sdpa(
+    gm,
+    incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]],
+    start_idx_input: Node,
+    end_idx_input: Node,
+):
+    """
+    Insert slicing operations before each scaled_dot_product_attention operation.
+    """
+    # Find all nodes with scaled_dot_product_attention
+    sdpa_nodes = []
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == SDPA_OP:
+            sdpa_nodes.append(node)
+    kv_cache_for_graph = []
+    for idx, sdpa_node in enumerate(sdpa_nodes):
+        assert (
+            len(sdpa_node.args) == 6
+        ), f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
+        q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args
+        incoming_key, incoming_value = incoming_keys_values[idx]
+        kv_cache_for_sdpa_node = []
+        new_keys_values = []
+        for key_or_value, current_key_or_value_node in zip(
+            [incoming_key, incoming_value], [k_node, v_node]
+        ):
+            # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
+            with gm.graph.inserting_before(sdpa_node):
+                slice_1 = gm.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.slice.Tensor,
+                    args=(key_or_value,),
+                    kwargs={},
+                )
+                slice_2 = gm.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.slice.Tensor,
+                    args=(slice_1, 1),
+                    kwargs={},
+                )
+                slice_3 = gm.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.slice.Tensor,
+                    args=(slice_2, 2, None, start_idx_input),
+                    kwargs={},
+                )
+                slice_4 = gm.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.slice.Tensor,
+                    args=(slice_3, 3),
+                    kwargs={},
+                )
+                # =============================================== #
+                # Create a slice node for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
+                slice_5 = gm.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.slice.Tensor,
+                    args=(key_or_value,),
+                    kwargs={},
+                )
+                slice_6 = gm.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.slice.Tensor,
+                    args=(slice_5, 1),
+                    kwargs={},
+                )
+                slice_7 = gm.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.slice.Tensor,
+                    args=(slice_6, 2, end_idx_input),
+                    kwargs={},
+                )
+                slice_8 = gm.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.slice.Tensor,
+                    args=(slice_7, 3),
+                    kwargs={},
+                )
+                # =============================================== #
+                # Concatenate the sliced tensors to build KV cache
+                cat = gm.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.cat.default,
+                    args=([slice_4, current_key_or_value_node, slice_8], 2),
+                    kwargs={},
+                )
+                # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph
+                cat.meta.update(key_or_value.meta)
+                kv_cache_for_sdpa_node.append(cat)
+                # =============================================== #
+                # Get the current key and value by indexing the KV cache
+                slice_9 = gm.graph.create_node(
+                    "call_function", torch.ops.aten.slice.Tensor, args=(cat,), kwargs={}
+                )
+                slice_10 = gm.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.slice.Tensor,
+                    args=(slice_9, 1),
+                    kwargs={},
+                )
+                slice_11 = gm.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.slice.Tensor,
+                    args=(slice_10, 2, None, end_idx_input),
+                    kwargs={},
+                )
+                slice_12 = gm.graph.create_node(
+                    "call_function",
+                    torch.ops.aten.slice.Tensor,
+                    args=(slice_11, 3),
+                    kwargs={},
+                )
+                new_keys_values.append(slice_12)
+
+        kv_cache_for_graph.extend(kv_cache_for_sdpa_node)
+
+        sdpa_node.args = (q_node, new_keys_values[0], new_keys_values[1]) + (
+            attn_mask,
+            dropout_p,
+            True,
+        )
+
+    return gm, kv_cache_for_graph
+
+
+@_aten_lowering_pass
+def insert_static_cache_v1(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+    """Insert KV cache ops in the graph"""
+    """Perform insertion of kv-caches and attention kernel."""
+    # Add static key and value as inputs to the graph
+    kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True)
+
+    # Build and update the KV cache using computed KV inputs for current token and
+    # incoming keys and values from previous tokens (which were added as inputs)
+    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(
+        gm, kv_inputs, start_idx_input, end_idx_input
+    )
+
+    # Call the function to add KV as outputs
+    logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)
+
+    gm = clean_up_graph_after_modifications(gm)
+
+    new_output_tensors = create_random_output_tensors(logits_keys_values)
+
+    new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
+    gm._out_spec = new_out_spec
+    logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
+
+    return gm
diff --git a/tools/llm/static_cache_v2.py b/tools/llm/static_cache_v2.py
new file mode 100644
index 0000000000..4634b79a52
--- /dev/null
+++ b/tools/llm/static_cache_v2.py
@@ -0,0 +1,290 @@
+import logging
+from typing import List, Tuple
+
+import torch
+import torch.utils._pytree as pytree
+from cache_utils import _add_graph_input, create_random_output_tensors, get_kv_nodes
+from torch.fx import Node
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+    _aten_lowering_pass,
+)
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+from torch_tensorrt.dynamo.utils import extract_var_range_info
+
+logger = logging.getLogger(__name__)
+
+SDPA_OP = torch._C._nn.scaled_dot_product_attention
+
+
+def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Tensor]]):
+    """
+    Modifies the graph to add query, key, and value tensors as outputs.
+
+    This function identifies all scaled dot-product attention (SDPA) operations
+    in the graph, creates copies of their query, key, and value inputs, and adds
+    these copies to the graph's outputs. This allows for accessing these tensors
+    externally, which is useful for operations like key-value caching.
+
+    Args:
+        graph: The torch.fx.Graph to modify
+
+    Returns:
+        None. The graph is modified in-place.
+    """
+    output_node = next(node for node in gm.graph.nodes if node.op == "output")
+
+    # Get the current output args (typically a tuple)
+    current_outputs = output_node.args[0]
+
+    # If the current output is a tuple, extend it with our new outputs
+    if isinstance(current_outputs, tuple):
+        new_outputs = current_outputs + tuple(kv_cache_for_graph)
+    else:
+        # If there's only one output or it's not a tuple, create a new tuple
+        new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+
+    gm.graph.output(new_outputs)
+    gm.graph.erase_node(output_node)
+
+    return new_outputs
+
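+# Resulting output structure (sketch): assuming the original graph returned only the
+# logits, the flattened outputs after this pass look like
+#   (logits, new_key_cache_0, new_value_cache_0, new_key_cache_1, new_value_cache_1, ...)
+# with one updated key/value pair per SDPA node in the graph.
+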
+
+def add_kv_cache_inputs(gm, fixed_kv: bool = True):
+    """
+    Add key-value tensors, index parameters as inputs to the graph.
+
+    Args:
+        gm: The GraphModule to modify
+        fixed_kv: Boolean indicating whether to use static tensors for KV cache. Default is True.
+
+    Returns:
+        A tuple containing:
+        - List of (k_input, v_input) node pairs for each SDPA operation
+        - start_idx input node for slicing operations
+        - end_idx input node for slicing operations
+    """
+
+    def get_static_tensor(tensor: torch.Tensor):
+        key_shape = []
+        for dim in tensor.shape:
+            if isinstance(dim, torch.SymInt):
+                min_max_opt = extract_var_range_info(dim)
+                key_shape.append(min_max_opt["max"])
+            else:
+                key_shape.append(dim)
+
+        static_tensor = torch.randn(key_shape, dtype=tensor.dtype, device=tensor.device)
+        return static_tensor
+
+    keys_values = get_kv_nodes(gm)
+
+    kv_inputs = []
+    for idx, key_value in enumerate(keys_values):
+        k_val = key_value[0].meta["val"]
+        v_val = key_value[1].meta["val"]
+        if fixed_kv:
+            k_val = get_static_tensor(k_val)
+            v_val = get_static_tensor(v_val)
+
+        # Add new inputs using _add_graph_input
+        k_input = _add_graph_input(gm, key_value[0].name + "_k_input", k_val)
+        v_input = _add_graph_input(gm, key_value[1].name + "_v_input", v_val)
+        kv_inputs.append((k_input, v_input))
+
+    # Add start_idx and end_idx as inputs
+    start_idx_input = _add_graph_input(gm, "start_idx", torch.tensor(0))
+    end_idx_input = _add_graph_input(gm, "end_idx", torch.tensor(1))
+
+    # Get the max sequence length from one of the KV cache nodes. The order of input nodes is: input_ids, key_cache1, value_cache1, key_cache2, value_cache2, ..., start_idx, end_idx
+    input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
+    # The third-to-last input is the last value cache node; read max_seq_len from its sequence dimension
+    kv_cache_meta = input_nodes[-3].meta["val"]
+    seq_len = kv_cache_meta.shape[2]
+
+    if isinstance(seq_len, torch.SymInt):
+        min_max_opt = extract_var_range_info(seq_len)
+        max_seq_len = min_max_opt["max"]
+    else:
+        max_seq_len = seq_len
+
+    from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+    shape_env = ShapeEnv()
+    # Create symbolic ints for start_idx and end_idx with range [0, seq_len] inclusive
+    start_idx_unbacked_symint = shape_env.create_unbacked_symint()
+    torch._check(start_idx_unbacked_symint >= 0)
+    torch._check(start_idx_unbacked_symint <= max_seq_len)
+
+    end_idx_unbacked_symint = shape_env.create_unbacked_symint()
+    torch._check(end_idx_unbacked_symint >= 0)
+    torch._check(end_idx_unbacked_symint <= max_seq_len)
+    # Set the symbolic ints as the metadata for start_idx and end_idx inputs
+    start_idx_input.meta["val"] = start_idx_unbacked_symint
+    end_idx_input.meta["val"] = end_idx_unbacked_symint
+
+    return kv_inputs, start_idx_input, end_idx_input
+
+
+def create_kv_cache_update_nodes(
+    gm, sdpa_node, current_kv_node, incoming_kv_node, start_idx_input, end_idx_input
+):
+    """
+    Create slicing and concatenation nodes for KV cache update.
+
+    This function creates the necessary slicing and concatenation nodes to update the KV cache
+    during the generation process. It takes the SDPA node, the current KV cache node, and the
+    incoming KV cache node as input.
+    Returns:
+        for a particular SDPA node, a tuple containing:
+        - List of new current KV  nodes
+        - List of updated incoming KV cache nodes
+
+    """
+
+    # Create a slice node for key_cache[:,:,:start_idx,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
+    with gm.graph.inserting_before(sdpa_node):
+        slice_1 = gm.graph.create_node(
+            "call_function",
+            torch.ops.aten.slice.Tensor,
+            args=(incoming_kv_node,),
+            kwargs={},
+        )
+        slice_2 = gm.graph.create_node(
+            "call_function", torch.ops.aten.slice.Tensor, args=(slice_1, 1), kwargs={}
+        )
+        slice_3 = gm.graph.create_node(
+            "call_function",
+            torch.ops.aten.slice.Tensor,
+            args=(slice_2, 2, None, start_idx_input),
+            kwargs={},
+        )
+        slice_4 = gm.graph.create_node(
+            "call_function", torch.ops.aten.slice.Tensor, args=(slice_3, 3), kwargs={}
+        )
+        # Concat key_cache[:,:,:start_idx,:] with current key (k)
+        concat_keys_or_values = gm.graph.create_node(
+            "call_function",
+            torch.ops.aten.cat.default,
+            args=([slice_4, current_kv_node], 2),
+            kwargs={},
+        )
+
+        # =============================================== #
+        # Create nodes for key_cache[:,:, end_idx:,:]. The shape of key_cache is batch_size x num_heads x seq_len x head_dim
+        slice_5 = gm.graph.create_node(
+            "call_function",
+            torch.ops.aten.slice.Tensor,
+            args=(incoming_kv_node,),
+            kwargs={},
+        )
+        slice_6 = gm.graph.create_node(
+            "call_function", torch.ops.aten.slice.Tensor, args=(slice_5, 1), kwargs={}
+        )
+        slice_7 = gm.graph.create_node(
+            "call_function",
+            torch.ops.aten.slice.Tensor,
+            args=(slice_6, 2, end_idx_input),
+            kwargs={},
+        )
+        slice_8 = gm.graph.create_node(
+            "call_function", torch.ops.aten.slice.Tensor, args=(slice_7, 3), kwargs={}
+        )
+        # =============================================== #
+        # Concatenate the sliced tensors to build KV cache
+        new_incoming_keys_or_values = gm.graph.create_node(
+            "call_function",
+            torch.ops.aten.cat.default,
+            args=([concat_keys_or_values, slice_8], 2),
+            kwargs={},
+        )
+        # Update the metadata of the newly built KV cache node with the metadata of the input KV cache node to the graph
+        new_incoming_keys_or_values.meta.update(incoming_kv_node.meta)
+
+    return concat_keys_or_values, new_incoming_keys_or_values
+
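+# Eager-tensor equivalent of the nodes created above (sketch for reference only; kv is the
+# incoming key or value cache, cur_kv the freshly computed k or v for this step):
+#
+#   concat_kv    = torch.cat((kv[:, :, :start_idx, :], cur_kv), dim=2)
+#   new_kv_cache = torch.cat((concat_kv, kv[:, :, end_idx:, :]), dim=2)
+#
+# concat_kv feeds the SDPA call; new_kv_cache is later returned as a graph output.
+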
+
+def insert_kv_slicing_before_sdpa(
+    gm,
+    incoming_keys_values: List[Tuple[torch.Tensor, torch.Tensor]],
+    start_idx_input: Node,
+    end_idx_input: Node,
+):
+    """
+    Insert slicing and concatenation operations before each scaled_dot_product_attention operation as per the following KV cache update logic:
+    concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
+    concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
+    new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
+    new_value_cache = torch.cat((concat_values, value_cache[:, :, end_idx:, :]), dim=2)
+    out = torch._C._nn.scaled_dot_product_attention(q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal)
+    """
+    # Find all nodes with scaled_dot_product_attention
+    sdpa_nodes = []
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == SDPA_OP:
+            sdpa_nodes.append(node)
+    kv_cache_for_graph = []
+    for idx, sdpa_node in enumerate(sdpa_nodes):
+        assert (
+            len(sdpa_node.args) == 6
+        ), f"SDPA node should have 6 arguments but got {len(sdpa_node.args)} arguments"
+        q_node, k_node, v_node, attn_mask, dropout_p, is_causal = sdpa_node.args
+        incoming_key, incoming_value = incoming_keys_values[idx]
+        # For keys
+        new_current_key_node, new_incoming_key_cache_node = (
+            create_kv_cache_update_nodes(
+                gm, sdpa_node, k_node, incoming_key, start_idx_input, end_idx_input
+            )
+        )
+        # For values
+        new_current_value_node, new_incoming_value_cache_node = (
+            create_kv_cache_update_nodes(
+                gm, sdpa_node, v_node, incoming_value, start_idx_input, end_idx_input
+            )
+        )
+
+        # Store the KV cache nodes for the current SDPA node
+        kv_cache_for_graph.extend(
+            [new_incoming_key_cache_node, new_incoming_value_cache_node]
+        )
+
+        # Update the SDPA node arguments with current key and value nodes
+        sdpa_node.args = (q_node, new_current_key_node, new_current_value_node) + (
+            attn_mask,
+            dropout_p,
+            True,
+        )
+
+    return gm, kv_cache_for_graph
+
+
+@_aten_lowering_pass
+def insert_static_cache_v2(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+    """Insert KV cache ops in the graph"""
+    """Perform insertion of kv-caches and attention kernel."""
+    # Add static key and value as inputs to the graph
+    kv_inputs, start_idx_input, end_idx_input = add_kv_cache_inputs(gm, fixed_kv=True)
+
+    # Build and update the KV cache using computed KV inputs for current token and
+    # incoming keys and values from previous tokens (which were added as inputs)
+    gm, kv_cache_for_graph = insert_kv_slicing_before_sdpa(
+        gm, kv_inputs, start_idx_input, end_idx_input
+    )
+
+    # Call the function to add KV as outputs
+    logits_keys_values = add_kv_as_outputs(gm, kv_cache_for_graph)
+
+    gm = clean_up_graph_after_modifications(gm)
+
+    new_output_tensors = create_random_output_tensors(logits_keys_values)
+
+    new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
+    gm._out_spec = new_out_spec
+
+    logger.debug("After inserting KV cache into the graph: " + str(gm.graph))
+    return gm
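+
+
+# Usage sketch (mirrors the accompanying test scripts): importing this module registers the
+# lowering pass via @_aten_lowering_pass, so a subsequent compile picks it up:
+#
+#   import static_cache_v2  # noqa: F401  (registers insert_static_cache_v2)
+#   trt_model = torch_tensorrt.dynamo.compile(ep, arg_inputs=[...], use_python_runtime=True)
+#   logits, key_cache, value_cache = trt_model(
+#       input_ids, position_ids, key_cache, value_cache, start_idx, end_idx, is_causal
+#   )
+#
+# The key/value caches and start_idx/end_idx correspond to the inputs added by
+# add_kv_cache_inputs above; is_causal is exposed as a separate graph input by the SDPA
+# lowering.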
diff --git a/tools/llm/test_llama_components.py b/tools/llm/test_llama_components.py
new file mode 100644
index 0000000000..ef7e59cd72
--- /dev/null
+++ b/tools/llm/test_llama_components.py
@@ -0,0 +1,603 @@
+import torch
+
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
+import argparse
+import os
+import sys
+from contextlib import nullcontext
+
+import torch.nn as nn
+import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+from transformers import AutoModelForCausalLM
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
+
+# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+from register_sdpa import *
+
+ATOL = 1e-5
+RTOL = 1e-5
+
+
+# llama2_model_name = "meta-llama/Llama-2-7b-hf"
+llama3_model_name = "meta-llama/Llama-3.2-1B-Instruct"
+llama_model = (
+    AutoModelForCausalLM.from_pretrained(
+        llama3_model_name,
+        use_cache=False,
+        attn_implementation="sdpa",
+        num_hidden_layers=1,
+    )
+    .eval()
+    .cuda()
+)
+LLAMA_CONFIG = llama_model.config
+
+
+def test_llama_attention(args):
+
+    DTYPE = torch.float32
+    if args.precision == "FP16":
+        DTYPE = torch.float16
+    elif args.precision == "BF16":
+        DTYPE = torch.bfloat16
+
+    # Set precision specific flags
+    use_fp32_acc = False
+    use_explicit_typing = False
+    if args.precision == "FP16":
+        enabled_precisions = {torch.float32}
+        use_fp32_acc = True
+        use_explicit_typing = True
+    elif args.precision == "BF16":
+        enabled_precisions = {torch.bfloat16}
+        use_fp32_acc = False
+    else:
+        enabled_precisions = {torch.float32}
+
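+    # Note on these flags (an assumption about intent, based on how they are set here): for
+    # FP16 runs enabled_precisions stays at {torch.float32} and use_explicit_typing=True lets
+    # the model's FP16 weights drive layer precisions, while use_fp32_acc=True requests FP32
+    # accumulation for matmuls to limit error growth.
+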
+    # model = LlamaAttentionBlock().eval().cuda().to(DTYPE)
+    model = llama_model.model.layers[0].self_attn.to(DTYPE)
+    # llama3
+    hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda()
+    position_embeddings = (
+        torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+        torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+    )
+
+    pyt_output = model(hidden_states, position_embeddings, None)
+    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
+    from torch.export._trace import _export
+
+    # ep = torch.export.export(model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes, strict=False)
+    ep = _export(
+        model,
+        args=(hidden_states, position_embeddings, None),
+        dynamic_shapes=dynamic_shapes,
+        strict=False,
+        allow_complex_guards_as_runtime_asserts=True,
+    )
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[hidden_states, position_embeddings, None],
+            enabled_precisions=enabled_precisions,
+            disable_tf32=True,
+            use_fp32_acc=use_fp32_acc,
+            use_explicit_typing=use_explicit_typing,
+            debug=args.debug,
+        )
+    trt_output = trt_model(hidden_states, position_embeddings, None)
+    if isinstance(pyt_output, tuple):
+        print(
+            f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}"
+        )
+        assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
+    else:
+        print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output - trt_output))}")
+        assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)
+
+
+def print_diff(tensor1, tensor2, prefix=""):
+    """
+    Print the diff between two tensors
+    """
+    print(
+        f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+    )
+
+
+def test_llama_attention_with_static_cache(args):
+    class LlamaAttentionBlock(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.config = LLAMA_CONFIG
+            self.attn = LlamaAttention(config=self.config, layer_idx=0)
+
+        def forward(self, hidden_states, position_embeddings):
+            attn_output, attn_weights = self.attn(
+                hidden_states, position_embeddings, None
+            )
+            return attn_output
+
+    DTYPE = torch.float32
+    if args.precision == "FP16":
+        DTYPE = torch.float16
+    elif args.precision == "BF16":
+        DTYPE = torch.bfloat16
+
+    # Set precision specific flags
+    use_fp32_acc = False
+    use_explicit_typing = False
+    if args.precision == "FP16":
+        enabled_precisions = {torch.float32}
+        use_fp32_acc = True
+        use_explicit_typing = True
+    elif args.precision == "BF16":
+        enabled_precisions = {torch.bfloat16}
+        use_fp32_acc = False
+    else:
+        enabled_precisions = {torch.float32}
+    model = llama_model.model.layers[0].self_attn.to(DTYPE)
+
+    # Inputs
+    ISL = 2048
+    NUM_TOKENS = 128
+    OSL = ISL + NUM_TOKENS
+    hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda()
+    position_embeddings = (
+        torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+        torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+    )
+    key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
+    value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
+    start_idx = 0
+    end_idx = ISL
+    is_causal = True
+
+    pyt_output = model(hidden_states, position_embeddings, None)
+    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
+    ep = torch.export.export(
+        model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes
+    )
+    import static_cache_v2
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[
+                hidden_states,
+                position_embeddings,
+                None,
+                key_cache,
+                value_cache,
+                start_idx,
+                end_idx,
+                is_causal,
+            ],
+            enabled_precisions=enabled_precisions,
+            disable_tf32=True,
+            debug=args.debug,
+            # offload_module_to_cpu=True,
+            use_fp32_acc=use_fp32_acc,
+            use_explicit_typing=use_explicit_typing,
+            use_python_runtime=True,
+        )
+
+    # Test Prefill
+    trt_output, _, key_cache, value_cache = trt_model(
+        hidden_states,
+        position_embeddings,
+        None,
+        key_cache,
+        value_cache,
+        start_idx,
+        end_idx,
+        is_causal,
+    )
+    print_diff(pyt_output[0], trt_output[0], "pyt_output[0] vs trt_output[0] [Prefill]")
+
+    # Test Generate
+    for start_idx in range(2048, 2176):
+        end_idx = start_idx + 1
+        hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda()
+        position_embeddings_curr = (
+            torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+            torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+        )
+        # Concatenate the current hidden_states with the previous ones
+        hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
+        position_embeddings_full = (
+            torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1),
+            torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1),
+        )
+
+        is_causal = False
+        out_no_cache, _ = model(hidden_states_full, position_embeddings_full, None)
+        out_trt, _, key_cache, value_cache = trt_model(
+            hidden_states_curr,
+            position_embeddings_curr,
+            None,
+            key_cache,
+            value_cache,
+            start_idx,
+            end_idx,
+            is_causal,
+        )
+        out_pyt = out_no_cache[:, -1:, :]
+        print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
+
+        hidden_states = hidden_states_full
+        position_embeddings = position_embeddings_full
+
+
+def test_llama_decoder(args):
+
+    class LlamaDecoderLayerBlock(nn.Module):
+        def __init__(self, model):
+            super().__init__()
+            self.config = LLAMA_CONFIG
+            self.decoder = LlamaDecoderLayer(config=self.config, layer_idx=0)
+            self.model = model
+
+        def forward(self, hidden_states, position_embeddings):
+            return self.model(hidden_states, position_embeddings=position_embeddings)
+
+    DTYPE = torch.float32
+    if args.precision == "FP16":
+        DTYPE = torch.float16
+    elif args.precision == "BF16":
+        DTYPE = torch.bfloat16
+
+    # Set precision specific flags
+    use_fp32_acc = False
+    use_explicit_typing = False
+    if args.precision == "FP16":
+        enabled_precisions = {torch.float32}
+        use_fp32_acc = True
+        use_explicit_typing = True
+    elif args.precision == "BF16":
+        enabled_precisions = {torch.bfloat16}
+        use_fp32_acc = False
+    else:
+        enabled_precisions = {torch.float32}
+
+    model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE))
+    # llama3
+    hidden_states = torch.randn((1, 6, 2048), dtype=DTYPE).cuda()
+    position_embeddings = (
+        torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+        torch.randn((1, 6, 64), dtype=DTYPE).cuda(),
+    )
+
+    pyt_output = model(hidden_states, position_embeddings)
+    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}))
+    ep = torch.export.export(
+        model, (hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes
+    )
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[hidden_states, position_embeddings],
+            enabled_precisions=enabled_precisions,
+            debug=args.debug,
+            use_fp32_acc=use_fp32_acc,
+            use_explicit_typing=use_explicit_typing,
+        )
+    trt_output = trt_model(hidden_states, position_embeddings)
+
+    print(
+        f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}"
+    )
+    assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
+
+
+def test_llama_decoder_with_static_cache(args):
+
+    class LlamaDecoderLayerBlock(nn.Module):
+        def __init__(self, model):
+            super().__init__()
+            self.config = LLAMA_CONFIG
+            self.decoder = LlamaDecoderLayer(config=self.config, layer_idx=0)
+            self.model = model
+
+        def forward(self, hidden_states, position_embeddings):
+            return self.model(hidden_states, position_embeddings=position_embeddings)
+
+    DTYPE = torch.float32
+    if args.precision == "FP16":
+        DTYPE = torch.float16
+    elif args.precision == "BF16":
+        DTYPE = torch.bfloat16
+
+    # Set precision specific flags
+    use_fp32_acc = False
+    use_explicit_typing = False
+    if args.precision == "FP16":
+        enabled_precisions = {torch.float32}
+        use_fp32_acc = True
+        use_explicit_typing = True
+    elif args.precision == "BF16":
+        enabled_precisions = {torch.bfloat16}
+        use_fp32_acc = False
+    else:
+        enabled_precisions = {torch.float32}
+    model = LlamaDecoderLayerBlock(llama_model.model.layers[0].to(DTYPE))
+
+    # Inputs
+    ISL = 2048
+    NUM_TOKENS = 128
+    OSL = ISL + NUM_TOKENS
+    hidden_states = torch.randn((1, ISL, 2048), dtype=DTYPE).cuda()
+    position_embeddings = (
+        torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+        torch.randn((1, ISL, 64), dtype=DTYPE).cuda(),
+    )
+    key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
+    value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
+    start_idx = 0
+    end_idx = ISL
+    is_causal = True
+
+    pyt_output = model(hidden_states, position_embeddings)
+    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}))
+    ep = torch.export.export(
+        model, args=(hidden_states, position_embeddings), dynamic_shapes=dynamic_shapes
+    )
+    import static_cache_v2
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            arg_inputs=[
+                hidden_states,
+                position_embeddings,
+                key_cache,
+                value_cache,
+                start_idx,
+                end_idx,
+                is_causal,
+            ],
+            enabled_precisions=enabled_precisions,
+            disable_tf32=True,
+            debug=args.debug,
+            # offload_module_to_cpu=True,
+            use_fp32_acc=use_fp32_acc,
+            use_explicit_typing=use_explicit_typing,
+            use_python_runtime=True,
+        )
+
+    # Test Prefill
+    trt_output, key_cache, value_cache = trt_model(
+        hidden_states,
+        position_embeddings,
+        key_cache,
+        value_cache,
+        start_idx,
+        end_idx,
+        is_causal,
+    )
+    print_diff(pyt_output[0], trt_output, "pyt_output vs trt_output [Prefill]")
+
+    # Test Generate
+    for start_idx in range(2048, 2176):
+        end_idx = start_idx + 1
+        hidden_states_curr = torch.randn((1, 1, 2048), dtype=DTYPE).cuda()
+        position_embeddings_curr = (
+            torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+            torch.randn((1, 1, 64), dtype=DTYPE).cuda(),
+        )
+        # Concatenate the current hidden_states with the previous ones
+        hidden_states_full = torch.cat((hidden_states, hidden_states_curr), dim=1)
+        position_embeddings_full = (
+            torch.cat((position_embeddings[0], position_embeddings_curr[0]), dim=1),
+            torch.cat((position_embeddings[1], position_embeddings_curr[1]), dim=1),
+        )
+
+        is_causal = False
+        out_no_cache = model(hidden_states_full, position_embeddings_full)
+
+        out_trt, key_cache, value_cache = trt_model(
+            hidden_states_curr,
+            position_embeddings_curr,
+            key_cache,
+            value_cache,
+            start_idx,
+            end_idx,
+            is_causal,
+        )
+        out_pyt = out_no_cache[0][:, -1:, :]
+        print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
+        hidden_states = hidden_states_full
+        position_embeddings = position_embeddings_full
+
+
+def test_llama_model(args):
+
+    DTYPE = torch.float32
+    if args.precision == "FP16":
+        DTYPE = torch.float16
+    elif args.precision == "BF16":
+        DTYPE = torch.bfloat16
+
+    # Set precision specific flags
+    use_fp32_acc = False
+    use_explicit_typing = False
+    if args.precision == "FP16":
+        enabled_precisions = {torch.float32}
+        use_fp32_acc = True
+        use_explicit_typing = True
+    elif args.precision == "BF16":
+        enabled_precisions = {torch.bfloat16}
+        use_fp32_acc = False
+    else:
+        enabled_precisions = {torch.float32}
+
+    model = llama_model.model.to(DTYPE)
+
+    # Inputs
+    ISL = 2048
+    NUM_TOKENS = 128
+    OSL = ISL + NUM_TOKENS
+    input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda()
+    position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda()
+
+    pyt_output = model(input_ids, position_ids)
+    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+    dynamic_shapes = ({1: seq_len}, {1: seq_len})
+    kwarg_inputs = {"position_ids": position_ids}
+    from torch.export._trace import _export
+
+    ep = _export(
+        model,
+        args=(input_ids,),
+        kwargs=kwarg_inputs,
+        dynamic_shapes=dynamic_shapes,
+        strict=False,
+        allow_complex_guards_as_runtime_asserts=True,
+    )
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            arg_inputs=[],
+            kwarg_inputs=kwarg_inputs,
+            enabled_precisions=enabled_precisions,
+            disable_tf32=True,
+            debug=args.debug,
+            offload_module_to_cpu=True,
+            use_fp32_acc=use_fp32_acc,
+            use_explicit_typing=use_explicit_typing,
+            use_python_runtime=True,
+        )
+
+    trt_output = trt_model(input_ids, position_ids)
+
+    print(
+        f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[0] - trt_output[0]))}"
+    )
+    # print(f"Diff b/w pyt and trt: {torch.mean(torch.abs(pyt_output[1] - trt_output[1]))}")
+    assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
+
+
+def test_llama_model_with_static_cache(args):
+
+    DTYPE = torch.float32
+    if args.precision == "FP16":
+        DTYPE = torch.float16
+    elif args.precision == "BF16":
+        DTYPE = torch.bfloat16
+
+    # Set precision specific flags
+    use_fp32_acc = False
+    use_explicit_typing = False
+    if args.precision == "FP16":
+        enabled_precisions = {torch.float32}
+        use_fp32_acc = True
+        use_explicit_typing = True
+    elif args.precision == "BF16":
+        enabled_precisions = {torch.bfloat16}
+        use_fp32_acc = False
+    else:
+        enabled_precisions = {torch.float32}
+    model = llama_model.model.to(DTYPE)
+
+    # Inputs
+    ISL = 2048
+    NUM_TOKENS = 128
+    OSL = ISL + NUM_TOKENS
+    input_ids = torch.randint(1, 20, (1, ISL), dtype=torch.int64).cuda()
+    position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).cuda()
+    key_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
+    value_cache = torch.zeros(1, 32, OSL, 64).cuda().to(DTYPE)
+    start_idx = 0
+    end_idx = ISL
+    is_causal = True
+
+    pyt_output = model(input_ids)
+    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+    dynamic_shapes = ({1: seq_len}, {1: seq_len})
+    kwarg_inputs = {"input_ids": input_ids, "position_ids": position_ids}
+    ep = torch.export.export(
+        model, args=(), kwargs=kwarg_inputs, dynamic_shapes=dynamic_shapes
+    )
+
+    import static_cache_v2
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            arg_inputs=[],
+            kwarg_inputs=kwarg_inputs,
+            enabled_precisions=enabled_precisions,
+            disable_tf32=True,
+            debug=args.debug,
+            # offload_module_to_cpu=True,
+            use_fp32_acc=use_fp32_acc,
+            use_explicit_typing=use_explicit_typing,
+            use_python_runtime=True,
+        )
+
+    # Test Prefill
+    trt_output, key_cache, value_cache = trt_model(
+        input_ids, position_ids, key_cache, value_cache, start_idx, end_idx, is_causal
+    )
+    pyt_output = pyt_output.last_hidden_state
+    print_diff(pyt_output, trt_output, "pyt_output vs trt_output [Prefill]")
+
+    # Test Generate
+    for start_idx in range(2048, 2176):
+        end_idx = start_idx + 1
+        input_ids_curr = torch.randint(1, 20, (1, 1), dtype=torch.int64).cuda()
+        position_ids_curr = torch.tensor([[start_idx]], dtype=torch.int64).cuda()
+
+        # Concatenate the current input_ids and position_ids with the previous ones
+        input_ids_full = torch.cat((input_ids, input_ids_curr), dim=1)
+        position_ids_full = torch.cat((position_ids, position_ids_curr), dim=1)
+        is_causal = False
+        kwarg_inputs = {"input_ids": input_ids_full, "position_ids": position_ids_full}
+        out_no_cache = model(**kwarg_inputs)
+
+        out_trt, key_cache, value_cache = trt_model(
+            input_ids_curr,
+            position_ids_curr,
+            key_cache,
+            value_cache,
+            start_idx,
+            end_idx,
+            is_causal,
+        )
+        out_pyt = out_no_cache.last_hidden_state[:, -1:, :]
+        print_diff(out_pyt, out_trt, f"pyt_curr_output vs out_trt for idx {start_idx}")
+        input_ids = input_ids_full
+        position_ids = position_ids_full
+
+
+if __name__ == "__main__":
+    arg_parser = argparse.ArgumentParser(
+        description="Run test cases for llama attention and decoder"
+    )
+    arg_parser.add_argument(
+        "--debug", action="store_true", help="Enable debug (default: False)"
+    )
+    arg_parser.add_argument(
+        "--precision", type=str, default="FP16", help="Precision (default: FP16)"
+    )
+    args = arg_parser.parse_args()
+    with torch.inference_mode():
+        # test_llama_attention(args)
+        # test_llama_decoder(args)
+        test_llama_model(args)
+        # test_llama_attention_with_static_cache(args)
+        # test_llama_decoder_with_static_cache(args)
+        # test_llama_model_with_static_cache(args)
diff --git a/tools/llm/test_qwen2.5_components.py b/tools/llm/test_qwen2.5_components.py
new file mode 100644
index 0000000000..60482bf22d
--- /dev/null
+++ b/tools/llm/test_qwen2.5_components.py
@@ -0,0 +1,193 @@
+import torch
+
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
+import argparse
+import os
+import sys
+from contextlib import nullcontext
+
+import torch.nn as nn
+import torch_tensorrt
+from torch.testing._internal.common_utils import TestCase, run_tests
+from transformers import AutoModelForCausalLM
+from transformers.models.llama.configuration_llama import LlamaConfig
+
+# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+from register_sdpa import *
+
+ATOL = 1e-5
+RTOL = 1e-5
+
+
+qwen2_5_model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+qwen2_5_model = (
+    AutoModelForCausalLM.from_pretrained(
+        qwen2_5_model_name,
+        use_cache=False,
+        attn_implementation="sdpa",
+        num_hidden_layers=1,
+    )
+    .eval()
+    .cuda()
+)
+QWEN_CONFIG = qwen2_5_model.config
+
+
+def print_diff(tensor1, tensor2, prefix=""):
+    """
+    Print the diff between two tensors
+    """
+    print(
+        f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+    )
+
+
+def test_qwen_apply_rotary_pos_emb(args):
+    class QwenApplyRotaryPosEmb(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def rotate_half(self, x):
+            x1 = x[..., : x.shape[-1] // 2]
+            x2 = x[..., x.shape[-1] // 2 :]
+            return torch.cat((-x2, x1), dim=-1)
+
+        def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
+            cos = cos.unsqueeze(unsqueeze_dim)
+            sin = sin.unsqueeze(unsqueeze_dim)
+            q_embed = (q * cos) + (self.rotate_half(q) * sin)
+            k_embed = (k * cos) + (self.rotate_half(k) * sin)
+            return q_embed, k_embed
+
+        def forward(self, q, k, cos, sin, unsqueeze_dim=1):
+            return self.apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim)
+
+    DTYPE = torch.float32
+    if args.precision == "FP16":
+        DTYPE = torch.float16
+    elif args.precision == "BF16":
+        DTYPE = torch.bfloat16
+
+    # Set precision specific flags
+    use_fp32_acc = False
+    use_explicit_typing = False
+    if args.precision == "FP16":
+        enabled_precisions = {torch.float32}
+        use_fp32_acc = True
+        use_explicit_typing = True
+    elif args.precision == "BF16":
+        enabled_precisions = {torch.bfloat16}
+        use_fp32_acc = False
+    else:
+        enabled_precisions = {torch.float32}
+
+    model = QwenApplyRotaryPosEmb().eval().cuda().to(DTYPE)
+    # Shapes for Qwen 2.5
+    q = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda()
+    k = torch.randn((1, 12, 5, 128), dtype=DTYPE).cuda()
+    cos = torch.randn((1, 5, 128), dtype=DTYPE).cuda()
+    sin = torch.randn((1, 5, 128), dtype=DTYPE).cuda()
+
+    pyt_output = model(q, k, cos, sin)
+
+    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+    dynamic_shapes = ({2: seq_len}, {2: seq_len}, {1: seq_len}, {1: seq_len})
+    ep = torch.export.export(model, (q, k, cos, sin), dynamic_shapes=dynamic_shapes)
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[q, k, cos, sin],
+            enabled_precisions=enabled_precisions,
+            disable_tf32=True,
+            use_fp32_acc=use_fp32_acc,
+            use_explicit_typing=use_explicit_typing,
+            debug=args.debug,
+        )
+    trt_output = trt_model(q, k, cos, sin)
+
+    if isinstance(pyt_output, tuple):
+        print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt")
+        # print_diff(pyt_output[1], trt_output[1], "Diff b/w pyt and trt")
+        assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
+    else:
+        print_diff(pyt_output, trt_output, "Diff b/w pyt and trt")
+        assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)
+
+
+def test_qwen_attention(args):
+
+    DTYPE = torch.float32
+    if args.precision == "FP16":
+        DTYPE = torch.float16
+    elif args.precision == "BF16":
+        DTYPE = torch.bfloat16
+
+    # Set precision specific flags
+    use_fp32_acc = False
+    use_explicit_typing = False
+    if args.precision == "FP16":
+        enabled_precisions = {torch.float32}
+        use_fp32_acc = True
+        use_explicit_typing = True
+    elif args.precision == "BF16":
+        enabled_precisions = {torch.bfloat16}
+        use_fp32_acc = False
+    else:
+        enabled_precisions = {torch.float32}
+
+    model = qwen2_5_model.model.layers[0].self_attn.to(DTYPE)
+    # qwen2.5
+    hidden_states = torch.randn((1, 5, 1536), dtype=DTYPE).cuda()
+    position_embeddings = (
+        torch.randn((1, 5, 128), dtype=DTYPE).cuda(),
+        torch.randn((1, 5, 128), dtype=DTYPE).cuda(),
+    )
+
+    pyt_output = model(hidden_states, position_embeddings, None)
+
+    seq_len = torch.export.Dim("seq_len", min=2, max=2176)
+    dynamic_shapes = ({1: seq_len}, ({1: seq_len}, {1: seq_len}), None)
+    ep = torch.export.export(
+        model, (hidden_states, position_embeddings, None), dynamic_shapes=dynamic_shapes
+    )
+
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            ep,
+            inputs=[hidden_states, position_embeddings, None],
+            enabled_precisions=enabled_precisions,
+            disable_tf32=True,
+            use_fp32_acc=use_fp32_acc,
+            use_explicit_typing=use_explicit_typing,
+            debug=args.debug,
+        )
+    trt_output = trt_model(hidden_states, position_embeddings, None)
+
+    if isinstance(pyt_output, tuple):
+        print_diff(pyt_output[0], trt_output[0], "Diff b/w pyt and trt")
+        assert torch.allclose(pyt_output[0], trt_output[0], atol=ATOL, rtol=RTOL)
+    else:
+        print_diff(pyt_output, trt_output, "Diff b/w pyt and trt")
+        assert torch.allclose(pyt_output, trt_output, atol=ATOL, rtol=RTOL)
+
+
+if __name__ == "__main__":
+    arg_parser = argparse.ArgumentParser(
+        description="Run test cases for llama attention and decoder"
+    )
+    arg_parser.add_argument(
+        "--debug", action="store_true", help="Enable debug (default: False)"
+    )
+    arg_parser.add_argument(
+        "--precision",
+        type=str,
+        default="FP16",
+        help="Precision to use in the model. Options: FP16, BF16, FP32",
+    )
+    args = arg_parser.parse_args()
+    with torch.inference_mode():
+        # test_qwen_apply_rotary_pos_emb(args)
+        test_qwen_attention(args)
diff --git a/tools/llm/test_static_cache.py b/tools/llm/test_static_cache.py
new file mode 100644
index 0000000000..603f84d3a6
--- /dev/null
+++ b/tools/llm/test_static_cache.py
@@ -0,0 +1,478 @@
+import argparse
+import os
+import sys
+from contextlib import nullcontext
+
+import torch
+import torch.nn as nn
+import torch_tensorrt
+from torch.export import export
+from torch_tensorrt.dynamo.lowering import (
+    get_decompositions,
+    post_lowering,
+    pre_export_lowering,
+)
+from transformers import AutoModelForCausalLM
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer
+
+# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
+sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+import register_sdpa
+
+ATOL = 1e-5
+RTOL = 1e-5
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
+
+class DynamicCacheModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, q, k, v, k1, v1, flag):
+        def true_fn(q, k, v, k1, v1):
+            k_new = torch.cat((k, k1), dim=2)
+            v_new = torch.cat((v, v1), dim=2)
+            return torch._C._nn.scaled_dot_product_attention(q, k_new, v_new)
+
+        def false_fn(q, k, v, k1, v1):
+            return torch._C._nn.scaled_dot_product_attention(q, k, v)
+
+        out = torch.cond(flag, true_fn, false_fn, (q, k, v, k1, v1))
+
+        return 2 * out
+
+
+class ModelNoCache(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, q, k, v):
+        return torch._C._nn.scaled_dot_product_attention(
+            q, k, v, dropout_p=0.0, is_causal=True
+        )
+
+
+class StaticCacheModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(
+        self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+    ):
+        new_key_cache = torch.cat(
+            (key_cache[:, :, :start_idx, :], k, key_cache[:, :, end_idx:, :]), dim=2
+        )
+        new_value_cache = torch.cat(
+            (value_cache[:, :, :start_idx, :], v, value_cache[:, :, end_idx:, :]), dim=2
+        )
+        attn_output = torch._C._nn.scaled_dot_product_attention(
+            q,
+            new_key_cache[:, :, :end_idx, :],
+            new_value_cache[:, :, :end_idx, :],
+            dropout_p=0.0,
+            is_causal=is_causal,
+        )
+
+        return attn_output, new_key_cache, new_value_cache
+
+    # Note: this second definition overrides the forward above. The first matches the
+    # update logic in static_cache_v1.py; this one mirrors the two-step concatenation
+    # used in static_cache_v2.py.
+    def forward(
+        self, q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+    ):
+        concat_keys = torch.cat((key_cache[:, :, :start_idx, :], k), dim=2)
+        concat_values = torch.cat((value_cache[:, :, :start_idx, :], v), dim=2)
+        new_key_cache = torch.cat((concat_keys, key_cache[:, :, end_idx:, :]), dim=2)
+        new_value_cache = torch.cat(
+            (concat_values, value_cache[:, :, end_idx:, :]), dim=2
+        )
+        attn_output = torch._C._nn.scaled_dot_product_attention(
+            q, concat_keys, concat_values, dropout_p=0.0, is_causal=is_causal
+        )
+
+        return attn_output, new_key_cache, new_value_cache
+
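+# Indexing convention used throughout these tests (sketch): prefill writes the first ISL
+# tokens with start_idx=0, end_idx=ISL, and each decode step then writes a single token
+# with start_idx=i, end_idx=i+1, e.g.
+#
+#   out, key_cache, value_cache = model(q, k, v, key_cache, value_cache, 0, 2048)
+#   out, key_cache, value_cache = model(q_curr, k_curr, v_curr, key_cache, value_cache,
+#                                       2048, 2049, is_causal=False)
+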
+
+def eager_sdpa(
+    query,
+    key,
+    value,
+    attn_mask=None,
+    dropout_p=0.0,
+    is_causal=False,
+    scale=None,
+    enable_gqa=False,
+) -> torch.Tensor:
+    """
+    Eager implementation of SDPA
+    """
+    import math
+
+    L, S = query.size(-2), key.size(-2)
+    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+
+    if is_causal:
+        assert attn_mask is None
+        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0).cuda()
+        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+        attn_bias = attn_bias.to(query.dtype)
+
+    if attn_mask is not None:
+        if attn_mask.dtype == torch.bool:
+            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+        else:
+            attn_bias = attn_mask + attn_bias
+
+    if enable_gqa:
+        key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
+        value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
+
+    attn_weight = query @ key.transpose(-2, -1) * scale_factor
+    attn_weight += attn_bias
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+    return attn_weight @ value
+
+
+def print_diff(tensor1, tensor2, prefix=""):
+    """
+    Print the diff between two tensors
+    """
+    print(
+        f"[{prefix}] Diff between tensor1 and tensor2: {torch.mean(torch.abs(tensor1 - tensor2))}"
+    )
+
+
+def test_no_cache_model_with_torch_tensorrt(args):
+    """
+    Test the no cache model
+    """
+    with torch.inference_mode():
+        model_no_cache = ModelNoCache().eval().cuda()
+        q = torch.randn(1, 32, 6, 64).cuda()
+        k = torch.randn(1, 32, 6, 64).cuda()
+        v = torch.randn(1, 32, 6, 64).cuda()
+        out_no_cache = model_no_cache(q, k, v)
+        out_eager = eager_sdpa(q, k, v, is_causal=True)
+        q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
+        # Export the model
+        exported_program = torch.export.export(
+            model_no_cache,
+            args=(q, k, v),
+            dynamic_shapes=({2: q_seq_len}, {2: q_seq_len}, {2: q_seq_len}),
+            strict=False,
+        )
+        with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+            trt_model = torch_tensorrt.dynamo.compile(
+                exported_program,
+                inputs=[q, k, v],
+                enabled_precisions={torch.float32},
+                disable_tf32=True,
+                debug=args.debug,
+                min_block_size=1,
+            )
+        out_trt = trt_model(q, k, v)
+
+        print_diff(out_no_cache, out_eager, "out_no_cache vs out_eager")
+        print_diff(out_no_cache, out_trt, "out_no_cache vs out_trt")
+        print_diff(out_eager, out_trt, "out_eager vs out_trt")
+
+
+def test_static_cache_model(args):
+    """
+    Test the static cache model
+    """
+    with torch.inference_mode():
+        model_no_cache = ModelNoCache().eval().cuda()
+        model_static_cache = StaticCacheModel().eval().cuda()
+        q = torch.randn(1, 32, 2048, 64).cuda()
+        k = torch.randn(1, 32, 2048, 64).cuda()
+        v = torch.randn(1, 32, 2048, 64).cuda()
+        key_cache = torch.zeros(1, 32, 2176, 64).cuda()
+        value_cache = torch.zeros(1, 32, 2176, 64).cuda()
+
+        # Test Prefill
+        start_idx = 0
+        end_idx = 2048
+        out_no_cache = model_no_cache(q, k, v)
+        out_static_cache, new_key_cache, new_value_cache = model_static_cache(
+            q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal=True
+        )
+        assert torch.allclose(out_no_cache, out_static_cache, atol=ATOL, rtol=RTOL)
+
+        # Test Generate
+        for start_idx in range(2048, 2176):
+            end_idx = start_idx + 1
+            q_curr = torch.randn(1, 32, 1, 64).cuda()
+            k_curr = torch.randn(1, 32, 1, 64).cuda()
+            v_curr = torch.randn(1, 32, 1, 64).cuda()
+
+            # Concatenate the current query, key, and value with the previous ones
+            q_full = torch.cat((q, q_curr), dim=2)
+            k_full = torch.cat((k, k_curr), dim=2)
+            v_full = torch.cat((v, v_curr), dim=2)
+
+            out_no_cache = model_no_cache(q_full, k_full, v_full)
+            out_static_cache, new_key_cache, new_value_cache = model_static_cache(
+                q_curr,
+                k_curr,
+                v_curr,
+                new_key_cache,
+                new_value_cache,
+                start_idx,
+                end_idx,
+                is_causal=False,
+            )
+
+            assert torch.allclose(
+                out_no_cache[:, :, -1:, :], out_static_cache, atol=ATOL, rtol=RTOL
+            )
+            q = q_full
+            k = k_full
+            v = v_full
+        print("============== test_static_cache passed ==============")
+
+
+def transform_gm_with_kv_cache(exported_program: torch.export.ExportedProgram, args):
+    """
+    Transform the graph module by adding key and value cache to the graph
+    """
+    gm = exported_program.module()
+    # Post lower the model
+    settings = torch_tensorrt.dynamo.conversion.CompilationSettings(
+        enabled_precisions={torch.float32},
+        disable_tf32=True,
+        use_python_runtime=True,
+        debug=args.debug,
+        min_block_size=1,
+    )
+    exported_program = pre_export_lowering(exported_program, settings)
+    exported_program = exported_program.run_decompositions(get_decompositions(False))
+
+    gm = exported_program.module()
+    gm = post_lowering(gm, settings)
+
+    return gm
+
+
+def test_static_cache_lowering(args):
+    """
+    Apply the static cache lowering pass to the no-cache model, run the resulting graph
+    module, and compare its output with the no-cache model's output
+    """
+    import static_cache_v2
+
+    model_no_cache = ModelNoCache().eval().cuda()
+    q = torch.randn(1, 32, 2, 64).cuda()
+    k = torch.randn(1, 32, 2048, 64).cuda()
+    v = torch.randn(1, 32, 2048, 64).cuda()
+    key_cache = torch.zeros(1, 32, 2176, 64).cuda()
+    value_cache = torch.zeros(1, 32, 2176, 64).cuda()
+
+    # Export the model
+    q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
+    kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176)
+    exported_program = export(
+        model_no_cache,
+        args=(q, k, v),
+        dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
+        strict=False,
+    )
+
+    gm = transform_gm_with_kv_cache(exported_program, args)
+
+    # Test Prefill
+    start_idx = 0
+    end_idx = 2048
+    is_causal = True
+    q = torch.randn(1, 32, 2048, 64).cuda()
+    out_no_cache = model_no_cache(q, k, v)
+    out_pyt_cache, key_cache, value_cache = gm(
+        q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal
+    )
+    assert torch.allclose(out_no_cache, out_pyt_cache, atol=ATOL, rtol=RTOL)
+
+    # Test Generate
+    for start_idx in range(2048, 2176):
+        end_idx = start_idx + 1
+        is_causal = False
+        q_curr = torch.randn(1, 32, 1, 64).cuda()
+        k_curr = torch.randn(1, 32, 1, 64).cuda()
+        v_curr = torch.randn(1, 32, 1, 64).cuda()
+        # Concatenate the current query, key, and value with the previous ones
+        q_full = torch.cat((q, q_curr), dim=2)
+        k_full = torch.cat((k, k_curr), dim=2)
+        v_full = torch.cat((v, v_curr), dim=2)
+
+        out_no_cache = model_no_cache(q_full, k_full, v_full)
+        out_pyt_static_cache, key_cache, value_cache = gm(
+            q_curr,
+            k_curr,
+            v_curr,
+            key_cache,
+            value_cache,
+            start_idx,
+            end_idx,
+            is_causal,
+        )
+        assert torch.allclose(
+            out_no_cache[:, :, -1:, :], out_pyt_static_cache, atol=ATOL, rtol=RTOL
+        )
+        q = q_full
+        k = k_full
+        v = v_full
+
+    print("============== test_static_cache_lowering passed ==============")
+
+
+def test_static_cache_export(args):
+    """
+    Test the static cache model export
+    """
+    model_static_cache = StaticCacheModel().eval().cuda()
+    q = torch.randn(1, 32, 2048, 64).cuda()
+    k = torch.randn(1, 32, 2048, 64).cuda()
+    v = torch.randn(1, 32, 2048, 64).cuda()
+    key_cache = torch.zeros(1, 32, 2176, 64).cuda()
+    value_cache = torch.zeros(1, 32, 2176, 64).cuda()
+    # Test Prefill
+    start_idx = 0
+    end_idx = 2048
+    is_causal = True
+    # Export the model
+    seq_len = torch.export.Dim("seq_len", min=2, max=2048)
+    seq_len_dyn_dim = {2: seq_len}
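+    # q, k and v share a single dynamic sequence dimension; the caches are fixed-size and start_idx/end_idx are marked dynamic.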
+    exported_program = export(
+        model_static_cache,
+        args=(q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal),
+        dynamic_shapes=(
+            seq_len_dyn_dim,
+            seq_len_dyn_dim,
+            seq_len_dyn_dim,
+            None,
+            None,
+            torch.export.Dim.DYNAMIC,
+            torch.export.Dim.DYNAMIC,
+            None,
+        ),
+        strict=False,
+    )
+
+
+def test_static_cache_with_torch_tensorrt(args):
+    """
+    Test the static cache model with torch_tensorrt
+    """
+    import static_cache_v2
+
+    model_no_cache = ModelNoCache().eval().cuda()
+    q = torch.randn(1, 32, 2, 64).cuda()
+    k = torch.randn(1, 32, 2048, 64).cuda()
+    v = torch.randn(1, 32, 2048, 64).cuda()
+    key_cache = torch.zeros(1, 32, 2176, 64).cuda()
+    value_cache = torch.zeros(1, 32, 2176, 64).cuda()
+
+    # Export the model
+    q_seq_len = torch.export.Dim("q_seq_len", min=2, max=2176)
+    kv_seq_len = torch.export.Dim("kv_seq_len", min=2, max=2176)
+    exported_program = export(
+        model_no_cache,
+        args=(q, k, v),
+        dynamic_shapes=({2: q_seq_len}, {2: kv_seq_len}, {2: kv_seq_len}),
+        strict=False,
+    )
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
+        trt_model = torch_tensorrt.dynamo.compile(
+            exported_program,
+            inputs=[q, k, v],
+            enabled_precisions={torch.float32},
+            disable_tf32=True,
+            use_python_runtime=True,
+            debug=args.debug,
+            min_block_size=1,
+        )
+
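+    # Prefill: run the full 2048-token context once to populate the TRT-managed key/value caches.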
+    start_idx = 0
+    end_idx = 2048
+    is_causal = True
+    q = torch.randn(1, 32, 2048, 64).cuda()
+    # out_eager = eager_sdpa(q, k, v, is_causal=is_causal)
+    out_no_cache = model_no_cache(q, k, v)
+    out_trt, trt_key_cache, trt_value_cache = trt_model(
+        q, k, v, key_cache, value_cache, start_idx, end_idx, is_causal
+    )
+
+    assert torch.allclose(
+        out_no_cache, out_trt, atol=ATOL, rtol=RTOL
+    ), "Prefill TRT logits don't match"
+    assert torch.allclose(
+        trt_key_cache[:, :, :end_idx, :], k, atol=ATOL, rtol=RTOL
+    ), "Prefill TRT key cache don't match"
+    assert torch.allclose(
+        trt_value_cache[:, :, :end_idx, :], v, atol=ATOL, rtol=RTOL
+    ), "Prefill TRT value cache don't match"
+
+    # Test Generate
+    for start_idx in range(2048, 2176):
+        end_idx = start_idx + 1
+        q_curr = torch.randn(1, 32, 1, 64).cuda()
+        k_curr = torch.randn(1, 32, 1, 64).cuda()
+        v_curr = torch.randn(1, 32, 1, 64).cuda()
+        # Concatenate the current query, key, and value with the previous ones
+        q_full = torch.cat((q, q_curr), dim=2)
+        k_full = torch.cat((k, k_curr), dim=2)
+        v_full = torch.cat((v, v_curr), dim=2)
+        is_causal = True
+        out_no_cache = model_no_cache(q_full, k_full, v_full)
+        out_trt, trt_key_cache, trt_value_cache = trt_model(
+            q_curr,
+            k_curr,
+            v_curr,
+            trt_key_cache,
+            trt_value_cache,
+            start_idx,
+            end_idx,
+            is_causal,
+        )
+        assert torch.allclose(
+            out_no_cache[:, :, -1:, :], out_trt, atol=ATOL, rtol=RTOL
+        ), f"Generate TRT logits don't match for idx {start_idx}"
+        assert torch.allclose(
+            trt_key_cache[:, :, :end_idx, :], k_full, atol=ATOL, rtol=RTOL
+        ), f"Generate TRT key cache don't match for idx {start_idx}"
+        assert torch.allclose(
+            trt_value_cache[:, :, :end_idx, :], v_full, atol=ATOL, rtol=RTOL
+        ), f"Generate TRT value cache don't match for idx {start_idx}"
+        q = q_full
+        k = k_full
+        v = v_full
+
+    print("============== test_static_cache_with_torch_tensorrt passed ==============")
+
+
+def main():
+    arg_parser = argparse.ArgumentParser(
+        description="Run test cases for the static KV cache implementation"
+    )
+    arg_parser.add_argument(
+        "--debug", action="store_true", help="Enable debug (default: False)"
+    )
+    args = arg_parser.parse_args()
+    with torch.inference_mode():
+        # test_no_cache_model_with_torch_tensorrt(args)
+        # test_static_cache_model(args)
+        # test_static_cache_lowering(args)
+        test_static_cache_with_torch_tensorrt(args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/dynamo/register_sdpa.py b/tools/llm/torchtrt_ext/register_sdpa.py
similarity index 86%
rename from examples/dynamo/register_sdpa.py
rename to tools/llm/torchtrt_ext/register_sdpa.py
index 7436f31939..90a00a5798 100644
--- a/examples/dynamo/register_sdpa.py
+++ b/tools/llm/torchtrt_ext/register_sdpa.py
@@ -4,7 +4,6 @@
 from typing import Callable, Sequence, Tuple
 
 import torch
-from sdpa_converter import *
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
 from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS
@@ -15,15 +14,19 @@
     clean_up_graph_after_modifications,
 )
 
+from .sdpa_converter import *
+
 logger = logging.getLogger(__name__)
 
 # Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention
 # This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it.
-TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default)
+TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default, None)
+TORCH_TRT_DECOMPOSITIONS.pop(
+    torch.ops.aten._scaled_dot_product_efficient_attention.default, None
+)
 TORCH_TRT_DECOMPOSITIONS.pop(
-    torch.ops.aten._scaled_dot_product_efficient_attention.default
+    torch.ops.aten._scaled_dot_product_flash_attention.default, None
 )
-TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default)
 
 REPLACEABLE_ATEN_OPS = {
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
@@ -59,6 +62,7 @@ def replace_variants_of_sdpa(
                 elif len(node.args) == 5:
                     query, key, value, attn_mask, is_causal = node.args
                     dropout_p = 0.0
+
                 else:
                     raise ValueError(
                         f"Unexpected number of arguments for {node.target} in the graph"
@@ -71,6 +75,8 @@ def replace_variants_of_sdpa(
                     query, key, value, dropout_p, is_causal, return_debug_mask = (
                         node.args
                     )
+                elif len(node.args) == 5:
+                    query, key, value, dropout_p, is_causal = node.args
                 elif len(node.args) == 3:
                     query, key, value = node.args
                     dropout_p = 0.0
@@ -79,20 +85,21 @@ def replace_variants_of_sdpa(
                     raise ValueError(
                         f"Unexpected number of arguments for {node.target} in the graph"
                     )
-            if attn_mask is not None:
-                logger.warning(
-                    f"This current version of SDPA converter does not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration."
-                )
-
-            modified_input_args = (query, key, value, None, dropout_p, is_causal)
 
+            logger.warning(
+                "The current version of the SDPA converter only supports attn_mask = None, dropout_p = 0.0 and is_causal = True. This may cause accuracy issues for models that use a different configuration."
+            )
+            modified_input_args = (query, key, value, None, dropout_p, True)
             # Create a new node with torch.nn.functional.scaled_dot_product_attention
             # The input args is (query, key, value, is_causal). kwargs has scale
             with gm.graph.inserting_after(node):
                 new_node = gm.graph.call_function(
                     torch.nn.functional.scaled_dot_product_attention,
                     args=modified_input_args,
-                    kwargs={"scale": node.kwargs.get("scale", None)},
+                    kwargs={
+                        "scale": node.kwargs.get("scale", None),
+                        "use_fp32_acc": settings.use_fp32_acc,
+                    },
                 )
 
                 # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
@@ -113,7 +120,7 @@ def replace_variants_of_sdpa(
     # Clean up the graph
     clean_up_graph_after_modifications(gm)
 
-    logger.info(
+    logger.debug(
         "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
     )
     return gm
diff --git a/examples/dynamo/sdpa_converter.py b/tools/llm/torchtrt_ext/sdpa_converter.py
similarity index 51%
rename from examples/dynamo/sdpa_converter.py
rename to tools/llm/torchtrt_ext/sdpa_converter.py
index 903324dff5..47083c7b48 100644
--- a/examples/dynamo/sdpa_converter.py
+++ b/tools/llm/torchtrt_ext/sdpa_converter.py
@@ -62,25 +62,15 @@ def scaled_dot_product_attention(
 ) -> TRTTensor:
     # TODO: Handle attn_mask and is_causal arguments in the future
     query, key, value, attn_mask, dropout_p, is_causal = args
-    logger.info(
-        "Ignoring attn_mask and is_causal arguments provided by the original graph. "
-        "This converter expects is_causal to be an input to the graph. For prefill phase, is_causal=True "
-        "and for generate phase, is_causal=False since we pass only 1 input token at a time"
-    )
 
     # TODO: remove this once we have a better way to handle the causal mask
     scale = kwargs.get("scale", None)
     source_ir = SourceIR.ATEN
+    is_causal = True
     # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
-    mm = impl.matmul.matrix_multiply(
-        ctx,
-        target,
-        source_ir,
-        name + "_mm",
-        query,
-        key,
-        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
-    )
+    use_fp32_acc = kwargs.get("use_fp32_acc", False)
+    query_dtype = query.dtype
+
     if scale is None:
         scale = query.shape[-1]
         if scale < 0:
@@ -90,80 +80,106 @@ def scaled_dot_product_attention(
         else:
             # static shape
             sqrt_scaled = math.sqrt(scale)
-        scaled = impl.elementwise.div(
+        key = impl.elementwise.div(
             ctx,
             target,
             source_ir,
             name + "_scale",
-            mm,
+            key,
             sqrt_scaled,
         )
     else:
-        scaled = impl.elementwise.mul(
+        key = impl.elementwise.mul(
             ctx,
             target,
             source_ir,
             name + "_scale",
-            mm,
+            key,
             scale,
         )
 
-    # If is_causal is True, we need to generate a causal mask
-    if is_causal:
-        L, S = query.shape[-2], key.shape[-2]
-        if L >= 0 and S >= 0:
-            # static shape
-            attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
-            temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
-            attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
-            attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
-        else:
-            # if any of the L or S is dynamic shape
-            if L < 0:
-                L = impl.shape.shape(
-                    ctx, target, source_ir, name + "_shape_0", query, 2
-                )
-            if S < 0:
-                S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
-
-            # generate the mask tensor
-            tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
-
-            temp_mask = impl.unary.logical_not(
-                ctx, target, source_ir, name + "_logical_not", tril_tensor
-            )
-            temp_mask_casted = cast_trt_tensor(
-                ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir
-            )
-            one_minus_temp_mask = impl.elementwise.sub(
-                ctx,
-                target,
-                source_ir,
-                name + "_one_minus_temp_mask",
-                1.0,
-                temp_mask_casted,
-            )
-            attn_bias = impl.unary.log(
-                ctx, target, source_ir, name + "_log", one_minus_temp_mask
-            )
-
-        scaled_add_attn_bias = impl.elementwise.add(
-            ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
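+    # When use_fp32_acc is enabled and the inputs are FP16, cast query and key to FP32 so the Q.K^T matmul accumulates in full precision.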
+    if use_fp32_acc and query_dtype == trt.float16:
+        query = cast_trt_tensor(
+            ctx, query, trt.float32, name + "_query_cast_to_fp32", target, source_ir
+        )
+        key = cast_trt_tensor(
+            ctx, key, trt.float32, name + "_key_cast_to_fp32", target, source_ir
         )
+
+    mm = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_mm",
+        query,
+        key,
+        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+    )
+
+    if use_fp32_acc:
+        mm = cast_trt_tensor(
+            ctx, mm, query_dtype, name + "_mm_cast_to_fp16", target, source_ir
+        )
+
+    L, S = query.shape[-2], key.shape[-2]
+    if L >= 0 and S >= 0:
+        # static shape
+        attn_bias = np.zeros((L, S), dtype=dtype._from(query_dtype).to(np.dtype))
+        temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
+        attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
+        attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
     else:
-        scaled_add_attn_bias = scaled
+        # if any of the L or S is dynamic shape
+        if L < 0:
+            L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2)
+        if S < 0:
+            S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
 
-    # Create a if condition to check if is_causal is True
-    if isinstance(is_causal, TRTTensor):
-        if_layer = ctx.net.add_if_conditional()
-        condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled
-        if_layer.set_condition(condition)
-        output_layer = if_layer.add_output(true_branch, false_branch)
-        scaled_add_attn_bias = output_layer.get_output(0)
+        # generate the mask tensor
+        tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
+
+        temp_mask = impl.unary.logical_not(
+            ctx, target, source_ir, name + "_logical_not", tril_tensor
+        )
+
+        # This need_mask determines if we want to use the causal mask or not
+        # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
+        # So need_mask will be all False values in this case.
+        # TODO: Implement more general case where L != 1 and S != L
+        need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
+        temp_mask = impl.elementwise.logical_and(
+            ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
+        )
+        temp_mask_casted = cast_trt_tensor(
+            ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
+        )
+
+        one_minus_temp_mask = impl.elementwise.sub(
+            ctx,
+            target,
+            source_ir,
+            name + "_one_minus_temp_mask",
+            1.0,
+            temp_mask_casted,
+        )
+        attn_bias = impl.unary.log(
+            ctx, target, source_ir, name + "_log", one_minus_temp_mask
+        )
+
+    scaled_add_attn_bias = impl.elementwise.add(
+        ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
+    )
 
     softmax = impl.normalization.softmax(
         ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
     )
+    if use_fp32_acc:
+        softmax = cast_trt_tensor(
+            ctx, softmax, trt.float32, name + "_softmax_cast_to_fp32", target, source_ir
+        )
+        value = cast_trt_tensor(
+            ctx, value, trt.float32, name + "_value_cast_to_fp32", target, source_ir
+        )
     out = impl.matmul.matrix_multiply(
         ctx,
         target,
@@ -172,5 +188,9 @@ def scaled_dot_product_attention(
         softmax,
         value,
     )
+    if use_fp32_acc:
+        out = cast_trt_tensor(
+            ctx, out, query_dtype, name + "_out_cast_to_fp16", target, source_ir
+        )
 
     return out
diff --git a/tools/llm/utils.py b/tools/llm/utils.py
new file mode 100644
index 0000000000..2c3434b0ed
--- /dev/null
+++ b/tools/llm/utils.py
@@ -0,0 +1,244 @@
+import copy
+import timeit
+
+import numpy as np
+import torch
+from transformers import StoppingCriteriaList
+from transformers.generation.stopping_criteria import (
+    EosTokenCriteria,
+    MaxLengthCriteria,
+)
+
+
+def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
+    """
+    Exports the LLM model into an ExportedProgram with dynamic shapes.
+    If the export fails due to guard violations caused by certain PyTorch kernel
+    implementations, we retry the export and express those guards as runtime assertion nodes.
+    """
+    with torch.no_grad():
+        # max=1024 has a constraint violation error. https://github.com/pytorch/pytorch/issues/125604
+        seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
+        position_ids = torch.arange(inputs.shape[1]).unsqueeze(0).to(inputs.device)
+        try:
+            print("Trying to export the model using torch.export.export()..")
+            # strict=False only enables aotautograd tracing and excludes dynamo.
+            ep = torch.export.export(
+                model,
+                args=(inputs,),
+                kwargs={"position_ids": position_ids},
+                dynamic_shapes=({1: seq_len}, {1: seq_len}),
+                strict=False,
+            )
+        except Exception:
+            print(
+                "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
+            )
+            # This API is used to express the constraint violation guards as asserts in the graph.
+            ep = torch.export._trace._export(
+                model,
+                args=(inputs,),
+                kwargs={"position_ids": position_ids},
+                dynamic_shapes=({1: seq_len}, {1: seq_len}),
+                strict=False,
+                allow_complex_guards_as_runtime_asserts=True,
+            )
+
+    return ep
+
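+# Minimal usage sketch (hypothetical model name and prompt, for illustration only):
+#
+#   from transformers import AutoModelForCausalLM, AutoTokenizer
+#   model = AutoModelForCausalLM.from_pretrained("gpt2").eval().cuda()
+#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
+#   input_ids = tokenizer("Hello, world", return_tensors="pt").input_ids.cuda()
+#   ep = export_llm(model, input_ids, max_seq_len=64)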
+
+def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule):
+    """
+    Extracts and returns zeroed static KV cache tensors from a torch.fx.GraphModule. This should only be used with the static_cache_v1 and static_cache_v2 lowering passes.
+
+    This function identifies placeholder nodes in the graph that represent KV cache tensors,
+    and creates zeroed tensors with the same shape, dtype, and device as the original placeholders.
+
+    Args:
+        model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders
+
+    Returns:
+        tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph
+    """
+    # Placeholder nodes are expected to be in the following order:
+    # input_ids, position_ids, key_cache_0, value_cache_0, ..., start_idx, end_idx
+    placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"]
+    # The first two inputs are input_ids, position_ids. The last two inputs are start_idx, end_idx. In between are the KV cache tensors.
+    kv_cache_inputs = placeholder_nodes[2:-2]
+    zeroed_kv_cache_inputs = []
+    for input in kv_cache_inputs:
+        zeroed_kv_cache_inputs.append(
+            torch.zeros(
+                input.meta["val"].shape,
+                dtype=input.meta["val"].dtype,
+                device=torch.device("cuda:0"),
+            )
+        )
+
+    return tuple(zeroed_kv_cache_inputs)
+
+
+def get_zeroed_dynamic_cache_inputs(model: torch.fx.GraphModule):
+    """
+    Extracts and returns zeroed KV cache tensors from a torch.fx.GraphModule. This should only be used for dynamic cache.
+
+    This function identifies placeholder nodes in the graph that represent KV cache tensors,
+    and creates zeroed tensors with the same shape, dtype, and device as the original placeholders.
+
+    Args:
+        model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders
+
+    Returns:
+        tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph
+    """
+    # Placeholder nodes are expected to be in the following order:
+    # input_ids, position_ids, key_cache_0, value_cache_0, ..., is_generate
+    placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"]
+    # The first two inputs are input_ids, position_ids. The last input is is_generate. In between are the KV cache tensors.
+    kv_cache_inputs = placeholder_nodes[2:-1]
+    zeroed_kv_cache_inputs = []
+    for input in kv_cache_inputs:
+        zeroed_kv_cache_inputs.append(
+            torch.zeros(
+                input.meta["val"].shape,
+                dtype=input.meta["val"].dtype,
+                device=torch.device("cuda:0"),
+            )
+        )
+
+    return tuple(zeroed_kv_cache_inputs)
+
+
+def generate(model, input_seq, max_output_seq_length, eos_token_id, benchmark=True):
+    """
+    Greedy decoding of the model without KV caching. Generates tokens until the sequence
+    reaches max_output_seq_length, or until EOS is produced when benchmark is False.
+    """
+    stopping_criteria = StoppingCriteriaList(
+        [
+            MaxLengthCriteria(max_length=max_output_seq_length),
+            EosTokenCriteria(eos_token_id=eos_token_id),
+        ]
+    )
+    isl = input_seq.shape[1]
+    osl = max_output_seq_length - isl
+
+    num_tokens_generated = 0
+    while num_tokens_generated < osl:
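+        # Without KV caching, the full sequence is re-processed and attention recomputed at every step.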
+        position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda()
+        outputs = model(input_seq, position_ids=position_ids)
+        logits = outputs.logits
+        next_token_logits = logits[:, -1, :]
+        next_tokens = torch.argmax(next_token_logits, dim=-1)
+        input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1)
+        num_tokens_generated += 1
+        # TODO: Handle batch in this check
+        if not benchmark and stopping_criteria(input_seq, logits).item():
+            break
+
+    return input_seq
+
+
+def generate_with_static_cache(model, input_seq, max_output_seq_length, eos_token_id):
+    """
+    Greedy decoding of the model with static KV cache.
+    """
+    start_idx = 0
+    end_idx = input_seq.shape[1]
+    position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda()
+    output_seq = input_seq.clone()
+    # TODO: Confirm this: When end_idx = max_output_seq_length-1, number of tokens generated = OSL
+    num_tokens_generated = 0
+    kv_cache = get_zeroed_static_cache_inputs(model)
+    while end_idx < max_output_seq_length:
+        position_ids = (
+            torch.tensor([[start_idx]], dtype=torch.int64).cuda()
+            if input_seq.shape[1] == 1
+            else position_ids
+        )
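+        # Compiled graph signature: (input_ids, position_ids, *kv_cache, start_idx, end_idx) -> (logits, *updated kv cache)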
+        input_signature = (input_seq, position_ids, *kv_cache, start_idx, end_idx)
+        logits_keys_values = model(*input_signature)
+        num_tokens_generated += 1
+        logits = logits_keys_values[0]
+        kv_cache = logits_keys_values[1:]
+        next_token_logits = logits[:, -1, :]
+        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+        output_seq = torch.cat([output_seq, next_tokens], dim=-1)
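+        # After prefill, feed only the newly generated token and advance the cache window by one.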
+        input_seq = next_tokens
+        start_idx = end_idx
+        end_idx = start_idx + 1
+    return output_seq
+
+
+def generate_with_dynamic_cache(model, input_seq, max_output_seq_length, eos_token_id):
+    """
+    Greedy decoding of the model with dynamic KV cache.
+    """
+    position_ids = torch.arange(input_seq.shape[1]).unsqueeze(0).cuda()
+    output_seq = input_seq.clone()
+    num_output_tokens = max_output_seq_length - input_seq.shape[1]
+    num_tokens_generated = 0
+    kv_cache = get_zeroed_dynamic_cache_inputs(model)
+    last_position_id = position_ids[-1, -1].item()
+    while num_tokens_generated < num_output_tokens:
+        # Prefill passes the full prompt (is_generate=False); decode passes a single token (is_generate=True).
+        is_generate = input_seq.shape[1] == 1
+        position_ids = (
+            torch.tensor([[last_position_id + 1]], dtype=torch.int64).cuda()
+            if input_seq.shape[1] == 1
+            else position_ids
+        )
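+        # Dynamic-cache graph signature: (input_ids, position_ids, *kv_cache, is_generate) -> (logits, *updated kv cache)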
+        input_signature = (input_seq, position_ids, *kv_cache, is_generate)
+        logits_keys_values = model(*input_signature)
+        num_tokens_generated += 1
+        logits = logits_keys_values[0]
+        kv_cache = logits_keys_values[1:]
+        next_token_logits = logits[:, -1, :]
+        next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+        output_seq = torch.cat([output_seq, next_tokens], dim=-1)
+        input_seq = next_tokens
+        last_position_id += 1
+    return output_seq
+
+
+def time_generate(
+    generate_fn, model, inputs, output_seq_length, eos_token_id, iterations=10
+):
+    """
+    Measure the time taken to generate a sequence over a given number of iterations
+    """
+    timings = []
+    for _ in range(iterations):
+        start_time = timeit.default_timer()
+        _ = generate_fn(model, inputs, output_seq_length, eos_token_id)
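+        # Synchronize so the timer accounts for all queued GPU work before stopping.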
+        torch.cuda.synchronize()
+        end_time = timeit.default_timer()
+        timings.append(end_time - start_time)
+
+    return timings
+
+
+def record_stats(backend, timings, precision, batch_size=1, compile_time_s=None):
+    """
+    Records timing statistics and returns them as a dictionary
+    """
+    times = np.array(timings)
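+    # Throughput (reported as FPS) is batch_size divided by the per-iteration latency.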
+    speeds = batch_size / times
+    time_mean = np.mean(times).item()
+    time_med = np.median(times).item()
+    time_99th = np.percentile(times, 99).item()
+    time_std = np.std(times, ddof=0).item()
+    speed_mean = np.mean(speeds).item()
+    speed_med = np.median(speeds).item()
+
+    stats = {
+        "Backend": backend,
+        "Precision": precision,
+        "Batch size": batch_size,
+        "Median(FPS)": speed_med,
+        "Mean(FPS)": speed_mean,
+        "Median-Latency(ms)": time_med * 1000,
+        "Mean-Latency(ms)": time_mean * 1000,
+        "Latency-StdDev(ms)": time_std * 1000,
+        "Compile Time(s)": compile_time_s,
+    }
+    return stats