From 30e1c877a507371f59fc11db6cc1e8f19ce9ed24 Mon Sep 17 00:00:00 2001 From: Michael Gschwind <61328285+mikekgfb@users.noreply.github.com> Date: Fri, 25 Oct 2024 00:48:02 -0700 Subject: [PATCH] Add docstrings build_utils.py, enforce consistent precision answers over a single run --- torchchat/utils/build_utils.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/torchchat/utils/build_utils.py b/torchchat/utils/build_utils.py index c9a9eae14..fd30f87d5 100644 --- a/torchchat/utils/build_utils.py +++ b/torchchat/utils/build_utils.py @@ -111,16 +111,30 @@ def use_et_backend() -> bool: ########################################################################## ### set and get target precision for this model ### -precision = torch.float32 +precision = None def set_precision(dtype): + """set_precision() is a torchchat-internal API that records the dtype we're building the model for. +The precision is recorded for future queries by get_precision(), so that when building a model, +or performing optimizations, we can query the type the user is building the model for. +This is an informational value that can be used when we want to know what type to build for (e.g., a kv cache). +Changing the `precision` does not change the precision of the model. +""" + global precision + assert precision is None, "only set precision once to avoid inconsistent answers during different phases of model build and export" precision = dtype def get_precision(): + """get_precision() is a torchchat-internal API that returns the dtype we're building the model for, as specified by the `--dtype` CLI option+, +or the precision quantizer. +""" global precision + # if (and only if) precision has not been set, update it to the default value torch.float32 + if precision is None: + precision = torch.float32 return precision