Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion torchchat/utils/build_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,30 @@ def use_et_backend() -> bool:
##########################################################################
### set and get target precision for this model ###

precision = torch.float32
precision = None


def set_precision(dtype):
"""set_precision() is a torchchat-internal API that records the dtype we're building the model for.
The precision is recorded for future queries by get_precision(), so that when building a model,
or performing optimizations, we can query the type the user is building the model for.
This is an informational value that can be used when we want to know what type to build for (e.g., a kv cache).
Changing the `precision` does not change the precision of the model.
"""

global precision
assert precision is None, "only set precision once to avoid inconsistent answers during different phases of model build and export"
precision = dtype


def get_precision():
"""get_precision() is a torchchat-internal API that returns the dtype we're building the model for, as specified by the `--dtype` CLI option+,
or the precision quantizer.
"""
global precision
# if (and only if) precision has not been set, update it to the default value torch.float32
if precision is None:
precision = torch.float32
return precision


Expand Down
Loading