diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f31305568d..a7b91eec34 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -21,14 +21,14 @@ repos:
           - id: clang-format
             types_or: [c++, c, cuda]
     - repo: https://github.com/keith/pre-commit-buildifier
-      rev: 6.4.0
+      rev: 8.0.3
       hooks:
           - id: buildifier
             args:
                 - --warnings=all
           - id: buildifier-lint
     - repo: https://github.com/abravalheri/validate-pyproject
-      rev: v0.23
+      rev: v0.24.1
       hooks:
           - id: validate-pyproject
     - repo: https://github.com/pycqa/isort
@@ -37,17 +37,17 @@ repos:
           - id: isort
             name: isort (python)
     - repo: https://github.com/pre-commit/mirrors-mypy
-      rev: "v1.9.0"
+      rev: "v1.15.0"
       hooks:
           - id: mypy
             exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py"
     - repo: https://github.com/astral-sh/ruff-pre-commit
       # Ruff version.
-      rev: v0.3.3
+      rev: v0.11.7
       hooks:
           - id: ruff
     - repo: https://github.com/psf/black
-      rev: 24.3.0
+      rev: 25.1.0
       hooks:
           - id: black
             exclude: ^examples/custom_converters/elu_converter/setup.py|^docs
@@ -57,7 +57,7 @@ repos:
           - id: typos
     - repo: https://github.com/astral-sh/uv-pre-commit
       # uv version.
-      rev: 0.5.5
+      rev: 0.7.1
       hooks:
           # Update the uv lockfile
           - id: uv-lock
diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py
index c706c345d6..e0a78e1a0b 100644
--- a/py/torch_tensorrt/_enums.py
+++ b/py/torch_tensorrt/_enums.py
@@ -80,6 +80,12 @@ class dtype(Enum):
     :meta hide-value:
     """
 
+    f4 = auto()
+    """4 bit floating-point number, equivalent to ``dtype.fp4`` and ``dtype.float4``
+
+    :meta hide-value:
+    """
+
     uint8 = u8
     int8 = i8
 
@@ -91,6 +97,9 @@ class dtype(Enum):
     float8 = f8
     fp8 = f8
 
+    float4 = f4
+    fp4 = f4
+
     half = f16
     fp16 = f16
     float16 = f16
@@ -162,6 +171,8 @@ def _from(
                 return dtype.i32
             elif t == torch.float8_e4m3fn:
                 return dtype.f8
+            elif t == torch.float4_e2m1fn_x2:
+                return dtype.f4
             elif t == torch.half:
                 return dtype.f16
             elif t == torch.float:
@@ -188,6 +199,8 @@ def _from(
                 return dtype.i8
             elif t == trt.DataType.FP8:
                 return dtype.f8
+            elif t == trt.DataType.FP4:
+                return dtype.fp4
             elif t == trt.DataType.INT32:
                 return dtype.i32
             elif t == trt.DataType.INT64:
@@ -357,6 +370,8 @@ def to(
                 return torch.long
             elif self == dtype.f8:
                 return torch.float8_e4m3fn
+            elif self == dtype.f4:
+                return torch.float4_e2m1fn_x2
             elif self == dtype.f16:
                 return torch.half
             elif self == dtype.f32:
@@ -394,6 +409,8 @@ def to(
                 return trt.DataType.BOOL
             elif self == dtype.bf16:
                 return trt.DataType.BF16
+            elif self == dtype.f4:
+                return trt.DataType.FP4
             elif use_default:
                 return trt.DataType.FLOAT
             else:
@@ -410,6 +427,8 @@ def to(
                 return np.int64
             elif self == dtype.f16:
                 return np.float16
+            elif self == dtype.f4:
+                raise TypeError("numpy has no native FP4 dtype; dtype.f4 cannot be converted to a numpy dtype")
             elif self == dtype.f32:
                 return np.float32
             elif self == dtype.f64:
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index c109e3fa3c..830faf3373 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -257,7 +257,7 @@ def cross_compile_for_windows(
             x in enabled_precisions for x in {torch.float32, dtype.f32}
         ):
             raise AssertionError(
-                f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
+                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True"
             )
 
     if use_fp32_acc:
@@ -588,7 +588,7 @@ def compile(
             x in enabled_precisions for x in {torch.float32, dtype.f32}
         ):
             raise AssertionError(
-                f"When use_explicit_typing is enabled, only torch.float32 is allowed in the enabled_precisions but found {enabled_precisions}"
+                f"use_explicit_typing was set to True, however found that enabled_precisions was also specified (saw: {enabled_precisions}, expected: {_defaults.ENABLED_PRECISIONS}). enabled_precisions should not be used when use_explicit_typing=True"
             )
 
     if use_fp32_acc:
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index aafd1072f4..921cb37646 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -29,7 +29,14 @@
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
-SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
+SUPPORTED_KERNEL_PRECISIONS = {
+    dtype.f32,
+    dtype.f16,
+    dtype.bf16,
+    dtype.i8,
+    dtype.f8,
+    dtype.f4,
+}
 TIMING_CACHE_PATH = os.path.join(
     tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
 )
diff --git a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
index 141b68f3e7..1c4926bcfa 100644
--- a/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
+++ b/py/torch_tensorrt/dynamo/conversion/_ConversionContext.py
@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
+from typing import Union
 
 import numpy as np
+import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.types import TRTNetwork
 
@@ -21,3 +23,9 @@ class ConversionContext:
     )
     requires_output_allocator: bool = False
     mapping: dict[str, np.array] = field(default_factory=dict)
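+    # Holds CPU tensors/arrays whose raw pointers are passed to trt.Weights, keeping them alive until the engine is built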
+    cpu_weights_reference_holder: dict[str, Union[torch.Tensor, np.array]] = field(
+        default_factory=dict
+    )
+
+    def clear_cpu_weights_reference_holder(self) -> None:
+        self.cpu_weights_reference_holder.clear()
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 39a1ed957d..bb1a77b4eb 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -743,6 +743,8 @@ def run(
         )
         _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
 
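+        # the engine is serialized above, so the CPU weight buffers backing trt.Weights can be released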
+        self.ctx.clear_cpu_weights_reference_holder()
+
         self._save_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 9d6602ddca..e542f1d417 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -619,6 +619,42 @@ def aten_ops_quantize_op(
         )
 
 
+try:
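+    # modelopt is expected to register the torch.ops.tensorrt quantization custom ops on import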
+    import modelopt.torch.quantization as mtq  # noqa: F401
+
+    assert torch.ops.tensorrt.dynamic_block_quantize_op.default
+except Exception as e:
+    _LOGGER.warning(
+        "Unable to import quantize op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models"
+    )
+else:
+
+    @dynamo_tensorrt_converter(
+        torch.ops.tensorrt.dynamic_block_quantize_op.default,
+        supports_dynamic_shapes=True,
+    )
+    def aten_ops_dynamic_block_quantize_op(
+        ctx: ConversionContext,
+        target: Target,
+        args: Tuple[Argument, ...],
+        kwargs: Dict[str, Argument],
+        name: str,
+    ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+        return impl.dynamic_block_quantize.quantize(
+            ctx,
+            target,
+            SourceIR.ATEN,
+            name,
+            args[0],  # input_tensor
+            args[1],  # block_size
+            args[2],  # amax
+            args[3],  # num_bits
+            args[4],  # exponent_bits
+            args[5],  # scale_num_bits
+            args[6],  # scale_exponent_bits
+        )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims, supports_dynamic_shapes=True)
 def aten_ops_squeeze(
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 685f40b254..b5b7cce868 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -326,6 +326,7 @@ def create_constant(
     name: str,
     dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]],
     min_rank: Optional[int] = 1,
+    target_quantized_type: Optional[TRTDataType] = None,
 ) -> TRTTensor:
     """
     Add a TensorRT constant layer whose value is `value` to `ctx.net`.
@@ -338,6 +339,7 @@ def create_constant(
         dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
             If a dtype is given, we will convert the type of the given `value` to this dtype.
         min_rank (int): minimum rank of the constant tensor.
+        target_quantized_type (Optional[TRTDataType]): If provided, the given `value` is treated as data already quantized to this TensorRT dtype (currently used for FP4 data packed in uint8).
     Returns:
         A TensorRT ITensor that represents the given value.
     """
@@ -361,12 +363,48 @@ def create_constant(
             shape = list(torch_value.shape)
 
         if torch_value is not None:
+            if torch_value.dtype == torch.float8_e4m3fn:
+                weights = trt.Weights(
+                    type=trt.DataType.FP8,
+                    ptr=torch_value.data_ptr(),
+                    count=torch_value.numel(),
+                )
+                constant = ctx.net.add_constant(
+                    shape,
+                    weights,
+                )
+                constant.name = name
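+                # keep the original tensor alive on the CPU; trt.Weights stores only a raw pointer to its data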
+                ctx.cpu_weights_reference_holder[name + " FP8_CONSTANT"] = torch_value
+                return constant.get_output(0)
+
+            if torch_value.dtype == torch.uint8:
+                if (
+                    target_quantized_type is None
+                    or target_quantized_type != trt.DataType.FP4
+                ):
+                    # IConstantLayer does not support uint8 directly; uint8 is only accepted as a container for packed FP4 data
+                    raise ValueError(
+                        f"uint8 values are only supported as packed FP4 data; expected target_quantized_type=trt.DataType.FP4, got {target_quantized_type=}"
+                    )
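+                # each uint8 byte packs two FP4 values, so the logical tensor's last dimension is twice the stored one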
+                shape[-1] = shape[-1] * 2
+                weights = trt.Weights(
+                    type=trt.DataType.FP4,
+                    ptr=torch_value.data_ptr(),
+                    count=torch_value.numel() * 2,
+                )
+                constant = ctx.net.add_constant(
+                    shape,
+                    weights,
+                )
+                constant.name = name
+                ctx.cpu_weights_reference_holder[name + " FP4_CONSTANT"] = torch_value
+                return constant.get_output(0)
+
             if torch_value.dtype == torch.bfloat16:
                 torch_value_fp32 = torch_value.to(torch.float32)
                 numpy_value = torch_value_fp32.numpy()
             else:
                 numpy_value = torch_value.numpy()
-
             ctx.mapping[name + " CONSTANT"] = numpy_value.reshape(-1)
             constant = ctx.net.add_constant(
                 shape,
@@ -381,7 +419,6 @@ def create_constant(
                     trt.DataType.BF16,
                     name + "_bf16_cast",
                 )
-
             return constant.get_output(0)
         else:
             raise ValueError(
@@ -395,6 +432,7 @@ def get_trt_tensor(
     name: str,
     dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
     min_rank: int = 1,
+    target_quantized_type: Optional[TRTDataType] = None,
 ) -> TRTTensor:
     """
     Given a value of random type, we try to convert it to a TensorRT ITensor.
@@ -408,6 +446,7 @@ def get_trt_tensor(
         dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
             If dtype is provided, the given value will be converted to this dtype.
         min_rank (int): minimum rank of the constant tensor.
+        target_quantized_type (Optional[TRTDataType]): If provided, forwarded to create_constant so that pre-quantized data (e.g. FP4 packed in uint8) is interpreted correctly.
     Returns:
         A TensorRT ITensor that represents the given value.
     """
@@ -420,7 +459,9 @@ def get_trt_tensor(
             input_val = input_val.astype(np.float32)
 
     if isinstance(input_val, (torch.Tensor, np.ndarray, int, float, bool)):
-        return create_constant(ctx, input_val, name, dtype, min_rank)
+        return create_constant(
+            ctx, input_val, name, dtype, min_rank, target_quantized_type
+        )
     elif isinstance(input_val, TRTTensor):
         return input_val
     else:
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index df580b1516..10af2ad892 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -7,6 +7,7 @@
     condition,
     conv,
     deconv,
+    dynamic_block_quantize,
     elementwise,
     embedding,
     full,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py b/py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py
new file mode 100644
index 0000000000..f76a84dea5
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/impl/dynamic_block_quantize.py
@@ -0,0 +1,272 @@
+from typing import Optional, Union
+
+import numpy as np
+import tensorrt as trt
+import torch
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTTensor
+
+
+def quantize(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_tensor: TRTTensor,
+    block_size: int,
+    amax: Union[np.ndarray, torch.Tensor],
+    num_bits: int,
+    exponent_bits: int,
+    scale_num_bits: int,
+    scale_exponent_bits: int,
+) -> TRTTensor:
+    """
+    Adds quantize and dequantize (QDQ) ops that quantize the weights or
+    activations to FP4 and dequantize them back to the original dtype.
+    """
+    if len(input_tensor.shape) not in (2, 3):
+        raise ValueError(
+            f"dynamic_block_quantize converter received an input of {input_tensor.shape} shape. Supported shapes: 2D or 3D"
+        )
+    with unset_fake_temporarily():
+        axis = -1
+        global_scale = _calculate_global_scale(ctx, name, amax)
+        if ".weight_quantizer" in name:
+            output = _static_double_quantize(
+                ctx,
+                target,
+                source_ir,
+                name,
+                input_tensor,
+                global_scale,
+                axis,
+            )
+        elif ".input_quantizer" in name:
+            output = _dynamic_double_quantize(
+                ctx,
+                target,
+                source_ir,
+                name,
+                input_tensor,
+                global_scale,
+                axis,
+            )
+        else:
+            raise ValueError(
+                f"quantizer received an input of {name}. Supported values: weight_quantizer | input_quantizer"
+            )
+        return output
+
+
+def _dynamic_double_quantize(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_tensor: TRTTensor,
+    global_scale: torch.Tensor,
+    axis: int = -1,
+    block_size: int = 16,
+    output_type: trt.DataType = trt.DataType.FP4,
+    scale_type: trt.DataType = trt.DataType.FP8,
+) -> TRTTensor:
+    """
+    Dynamically quantize the input tensor to FP4 and insert the matching dequantize ops.
+    Parameters:
+        ctx: ConversionContext,
+        target: Target,
+        source_ir: Optional[SourceIR]
+        name: str
+        input_tensor : TRTTensor (On GPU)
+            The input TRTTensor.
+        global_scale : Tensor (On GPU)
+            The global per-tensor scaling factor. It should contain only 1 element.
+        axis : int
+            The axis to quantize. Default is -1 (the last axis).
+        block_size : int
+            The block size for quantization. Default is 16.
+        output_type : trt.DataType
+            The data type for quantized data. Default is FP4.
+        scale_type : trt.DataType
+            The data type for block scale. Default is FP8.
+
+    """
+    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
+
+    if input_tensor.dtype not in [
+        trt.DataType.HALF,
+        trt.DataType.FLOAT,
+        trt.DataType.BF16,
+    ]:
+        raise ValueError(
+            f"Currently supported input tensor type is float16 | float32 | bfloat16, got Unsupported dtype: {input_tensor.dtype}"
+        )
+    # dynamic quantize input tensor to fp4
+    dynamic_quantize_layer = ctx.net.add_dynamic_quantize(
+        input_tensor,
+        axis,
+        block_size,
+        output_type,
+        scale_type,
+    )
+    dynamic_quantize_layer.set_input(1, global_scale)
+    set_layer_name(
+        dynamic_quantize_layer, target, name + "_dynamic_quantize", source_ir
+    )
+    quantized_data_in_fp4 = dynamic_quantize_layer.get_output(0)
+    quantized_scale_in_fp8 = dynamic_quantize_layer.get_output(1)
+
+    return _double_dequantize(
+        ctx,
+        target,
+        source_ir,
+        name,
+        quantized_data_in_fp4,
+        quantized_scale_in_fp8,
+        global_scale,
+        axis,
+        input_tensor.dtype,
+    )
+
+
+def _double_dequantize(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    quantized_data_in_fp4: TRTTensor,
+    quantized_scale_in_fp8: TRTTensor,
+    global_scale: torch.Tensor,
+    axis: int = -1,
+    output_type: trt.DataType = trt.DataType.FLOAT,
+) -> TRTTensor:
+    """
+    Double dequantize: first dequantize the block scale from FP8 to the original dtype (default is float32),
+    then dequantize the data from FP4 to the original dtype (default is float32).
+    Parameters:
+        ctx: ConversionContext,
+        target: Target,
+        source_ir: Optional[SourceIR]
+        name: str
+        quantized_data_in_fp4: TRTTensor
+        quantized_scale_in_fp8: TRTTensor
+        global_scale: torch.Tensor
+        axis: int
+        output_type: trt.DataType
+    """
+    # dequantize the block scale from FP8 to the original dtype (default is float32)
+    dequantize_scale_layer = ctx.net.add_dequantize(
+        quantized_scale_in_fp8, global_scale, output_type
+    )
+    dequantize_scale_layer.axis = axis
+    dequantize_scale_layer.to_type = output_type
+    set_layer_name(
+        dequantize_scale_layer, target, name + "_dequantize_scale", source_ir
+    )
+    dequantized_scale = dequantize_scale_layer.get_output(0)
+
+    # dequantize quantized_data_in_fp4 from FP4 to the original dtype (default is float32)
+    dequantize_data_layer = ctx.net.add_dequantize(
+        quantized_data_in_fp4, dequantized_scale, output_type
+    )
+    dequantize_data_layer.axis = axis
+    dequantize_data_layer.to_type = output_type
+    set_layer_name(dequantize_data_layer, target, name + "_dequantize_data", source_ir)
+    dequantized_data = dequantize_data_layer.get_output(0)
+    return dequantized_data
+
+
+def _static_double_quantize(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    weights_tensor: torch.Tensor,
+    global_scale: torch.Tensor,
+    axis: int,
+) -> TRTTensor:
+    """
+    Parameters:
+        ctx: ConversionContext,
+        target: Target,
+        source_ir: Optional[SourceIR],
+        name: str,
+        weights_tensor : Tensor (On GPU)
+            The input tensor for weights.
+        global_scale : Tensor (On GPU)
+            The global per-tensor scaling factor. It should contain only 1 element.
+        axis: int
+            The axis to quantize along.
+    Returns:
+        quantized data tensor in fp4
+    """
+
+    import modelopt.core.torch.quantization.qtensor.nvfp4_tensor as nvfp4_tensor
+
+    if weights_tensor.dtype == torch.float16:
+        original_dtype = trt.DataType.HALF
+    elif weights_tensor.dtype == torch.float32:
+        original_dtype = trt.DataType.FLOAT
+    elif weights_tensor.dtype == torch.bfloat16:
+        original_dtype = trt.DataType.BF16
+    else:
+        raise ValueError(
+            f"Currently supported weights tensor type is float16 | float32 | bfloat16, got Unsupported dtype: {weights_tensor.dtype}"
+        )
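+    # NVFP4 two-level scaling: a per-block (block size 16) scale stored in FP8 plus a global per-tensor scale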
+    block_scale_fp8 = nvfp4_tensor.NVFP4QTensor.get_weights_scaling_factor(
+        weights_tensor,
+        16,
+        global_scale,
+    )[0]
+    weights_tensor_fp4 = nvfp4_tensor.NVFP4QTensor.quantize(
+        weights_tensor,
+        16,
+        block_scale_fp8,
+        global_scale,
+    )[0]._quantized_data
+
+    block_scale_fp8 = get_trt_tensor(
+        ctx,
+        block_scale_fp8,
+        name + "_block_scale_fp8",
+        target_quantized_type=trt.DataType.FP8,
+    )
+    global_scale = get_trt_tensor(ctx, global_scale, name + "_global_scale")
+    weights_tensor_fp4 = get_trt_tensor(
+        ctx,
+        weights_tensor_fp4,
+        name + "_weights_fp4",
+        target_quantized_type=trt.DataType.FP4,
+    )
+
+    dequantized_data = _double_dequantize(
+        ctx,
+        target,
+        source_ir,
+        name,
+        weights_tensor_fp4,
+        block_scale_fp8,
+        global_scale,
+        axis,
+        original_dtype,
+    )
+    return dequantized_data
+
+
+def _calculate_global_scale(
+    ctx: ConversionContext,
+    name: str,
+    amax: torch.Tensor,
+) -> torch.Tensor:
+    # calculate global scale (the global per-tensor scaling factor, should only contain 1 element)
+    assert len(amax.shape) == 0, "amax should be a scalar"
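+    # global_scale = amax / (6 * 448): 6.0 is the max representable FP4 (E2M1) value and 448 is the max FP8 (E4M3) value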
+    global_scale = amax / 6 / 448
+    global_scale.masked_fill_(global_scale == 0, 1.0)
+    return global_scale
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
index f13d9a2375..19e97ef099 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -109,6 +109,7 @@ def is_impure(self, node: torch.fx.node.Node) -> bool:
 
             assert torch.ops.tensorrt.quantize_op.default
             quantization_ops.add(torch.ops.tensorrt.quantize_op.default)
+            quantization_ops.add(torch.ops.tensorrt.dynamic_block_quantize_op.default)
         except Exception as e:
             pass
         if quantization_ops and node.target in quantization_ops:
diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index 9e22fef929..189da962b5 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -199,6 +199,127 @@ def test_resnet18_half(ir):
     torch._dynamo.reset()
 
 
+@unittest.skipIf(
+    torch.cuda.get_device_capability() < (10, 0),
+    "FP4 quantization requires compute capability 10.0 or later",
+)
+@unittest.skipIf(
+    not importlib.util.find_spec("modelopt"),
+    "ModelOpt is required to run this test",
+)
+@pytest.mark.unit
+def test_base_fp4_dynamic_shapes(ir):
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+
+    dtype = torch.float16
+
+    class SimpleNetwork(torch.nn.Module):
+        def __init__(self):
+            super(SimpleNetwork, self).__init__()
+            self.linear1 = torch.nn.Linear(
+                in_features=64, out_features=32, bias=True, dtype=dtype
+            )
+
+        def forward(self, x):
+            x = self.linear1(x)
+            return x
+
+    def calibrate_loop(model):
+        """Simple calibration function for testing."""
+        model(dummy_inputs)
+
+    BATCH_SIZE = torch.export.Dim("BATCH_SIZE", min=16, max=128)
+    batch_size = 64
+    dummy_inputs = torch.ones(batch_size, 64, dtype=dtype).cuda()
+
+    model = SimpleNetwork().eval().cuda()
+
+    quant_cfg = mtq.NVFP4_DEFAULT_CFG
+    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    # model has qdq nodes at this point
+    with torch.no_grad():
+        with export_torch_mode():
+            exp_program = torch.export.export(
+                model, (dummy_inputs,), strict=False, dynamic_shapes=({0: BATCH_SIZE},)
+            )
+
+            trt_model = torchtrt.dynamo.compile(
+                exp_program,
+                inputs=[dummy_inputs],
+                min_block_size=1,
+                debug=True,
+                cache_built_engines=False,
+                reuse_cached_engines=False,
+                use_explicit_typing=True,
+            )
+            batch_size = 128
+            input_tensor = torch.ones(batch_size, 64, dtype=dtype).cuda()
+            expected_output = model(input_tensor)
+            outputs_trt = trt_model(input_tensor)
+            abs_diff = torch.abs(expected_output - outputs_trt)
+            print(f"max/mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
+            assert torch.allclose(expected_output, outputs_trt, rtol=0.3, atol=0.3)
+
+
+@unittest.skipIf(
+    torch.cuda.get_device_capability() < (10, 0),
+    "FP4 quantization requires compute capability 10.0 or later",
+)
+@unittest.skipIf(
+    not importlib.util.find_spec("modelopt"),
+    "ModelOpt is required to run this test",
+)
+@pytest.mark.unit
+def test_base_fp4_static_shapes(ir):
+    import modelopt.torch.quantization as mtq
+    from modelopt.torch.quantization.utils import export_torch_mode
+
+    dtype = torch.bfloat16
+
+    class SimpleNetwork(torch.nn.Module):
+        def __init__(self):
+            super(SimpleNetwork, self).__init__()
+            self.linear1 = torch.nn.Linear(
+                in_features=64, out_features=32, bias=True, dtype=dtype
+            )
+
+        def forward(self, x):
+            x = self.linear1(x)
+            return x
+
+    def calibrate_loop(model):
+        """Simple calibration function for testing."""
+        model(input_tensor)
+
+    input_tensor = torch.randn(128, 64, dtype=dtype).cuda()
+
+    model = SimpleNetwork().eval().cuda()
+    expected_output = model(input_tensor)
+
+    quant_cfg = mtq.NVFP4_DEFAULT_CFG
+    mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+    # model has qdq nodes at this point
+    with torch.no_grad():
+        with export_torch_mode():
+            exp_program = torch.export.export(model, (input_tensor,), strict=False)
+
+            trt_model = torchtrt.dynamo.compile(
+                exp_program,
+                inputs=[input_tensor],
+                min_block_size=1,
+                debug=True,
+                cache_built_engines=False,
+                reuse_cached_engines=False,
+                use_explicit_typing=True,
+            )
+            outputs_trt = trt_model(input_tensor)
+            abs_diff = torch.abs(expected_output - outputs_trt)
+            print(f"max/mean abs_diff: {abs_diff.max().item()=} {abs_diff.mean()=}")
+            assert torch.allclose(expected_output, outputs_trt, rtol=0.3, atol=0.3)
+
+
 @unittest.skipIf(
     torch.cuda.get_device_capability() < (8, 9),
     "FP8 quantization requires compute capability 8.9 or later",
@@ -230,8 +351,8 @@ def calibrate_loop(model):
 
     input_tensor = torch.randn(1, 10).cuda()
     model = SimpleNetwork().eval().cuda()
-
     quant_cfg = mtq.FP8_DEFAULT_CFG
+
     mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
     # model has FP8 qdq nodes at this point
     output_pyt = model(input_tensor)