##########################################################################
###           set and get target precision for this model              ###

# Recorded build dtype. None means "not decided yet"; it becomes non-None
# on the first set_precision() call or the first get_precision() call.
precision = None


def set_precision(dtype):
    """Record the dtype the model is being built for.

    set_precision() is a torchchat-internal API. The value is stored in the
    module-level ``precision`` and returned by later get_precision() calls,
    so that code building or optimizing the model can query the type the
    user is building for (e.g., to choose the dtype of a kv cache).

    This is an informational value only: changing ``precision`` does not
    change the precision of the model.

    Args:
        dtype: The dtype to record (e.g. ``torch.float32``).

    Raises:
        AssertionError: If a different precision has already been recorded,
            either by an earlier set_precision() call or by get_precision()
            falling back to its default.
    """
    global precision
    # Re-recording the value that is already stored cannot yield inconsistent
    # answers, so only a conflicting dtype is rejected. The check is an
    # explicit raise (not an `assert` statement) so it survives `python -O`.
    if precision is not None and precision != dtype:
        raise AssertionError(
            "only set precision once to avoid inconsistent answers during "
            "different phases of model build and export"
        )
    precision = dtype


def get_precision():
    """Return the dtype the model is being built for.

    get_precision() is a torchchat-internal API that returns the dtype as
    specified by the `--dtype` CLI option, or the precision quantizer.

    If (and only if) no precision has been recorded yet, the default
    ``torch.float32`` is recorded and returned, so every later query gets
    the same answer.

    Returns:
        The recorded dtype, or ``torch.float32`` if none was set.
    """
    global precision
    # Store the default on first use so a later conflicting set_precision()
    # is caught instead of silently changing the answer.
    if precision is None:
        precision = torch.float32
    return precision