diff --git a/examples/models/checkpoint.py b/examples/models/checkpoint.py index cd916142d9c..592ebab1450 100644 --- a/examples/models/checkpoint.py +++ b/examples/models/checkpoint.py @@ -67,7 +67,7 @@ def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]: if value.dtype != dtype ] if len(mismatched_dtypes) > 0: - raise ValueError( + print( f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}" ) return dtype