
Commit 6b1344e

Do not fail on lack of default precision set. (#6139)
I discovered that some models from the suite do not have a default precision set, so instead of failing the script we just log the case and do nothing, as no additional machinery should run for Inductor anyway. Additionally, I wrapped the exceptions in ValueError so the log message is not polluted with info about str not inheriting from the Exception class. @cota, note that this needs to be hooked "somewhere". Not sure where, as there was a revert in #6134, but in general it can safely be done prior to moving the model to the device.
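One possible hook point, per the note above, is inside set_up, just before the model is moved to the device. This is a minimal sketch of the ordering, not what this commit wires in; the names load_benchmark, self.test, and self.device are assumptions for illustration, not the actual attributes in benchmarks/torchbench_model.py:

  def set_up(self):
    benchmark = self.load_benchmark()  # hypothetical loader returning the torchbench benchmark
    # Set XLA_USE_FP16 / XLA_USE_BF16 before any tensor reaches the XLA device.
    self.apply_default_precision_config(self.test, benchmark)
    module, example_inputs = benchmark.get_module()
    self.module = module.to(self.device)  # safe: the env flags are already in place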
1 parent a2f80e4 commit 6b1344e

File tree

1 file changed: +16 -3 lines changed


benchmarks/torchbench_model.py

Lines changed: 16 additions & 3 deletions
@@ -204,23 +204,36 @@ def set_up(self):
     gc.collect()

   def apply_default_precision_config(self, test, benchmark):
+    """
+    Apply default precision config to XLA, if present.
+
+    Whenever a model has a default precision for cuda set
+    we need to set proper environment flags so XLA catches
+    the requird precision.
+
+    This function is a workaround. Proper solution requires
+    changes to the PT/XLA bridge so that the input shape
+    is properly inferred after issuing converts to `torch.nn.Module`.
+    """
     if test == "eval" and hasattr(benchmark, 'DEFAULT_EVAL_CUDA_PRECISION'):
       precision = benchmark.DEFAULT_EVAL_CUDA_PRECISION
     elif test == "train" and hasattr(benchmark, 'DEFAULT_TRAIN_CUDA_PRECISION'):
       precision = benchmark.DEFAULT_TRAIN_CUDA_PRECISION
     else:
-      raise f"Unkown test type {test}!"
+      logger.warning("No default precision set. No patching needed.")
+      return

     if precision == "fp16":
       os.environ['XLA_USE_FP16'] = '1'
     elif precision == "amp":
-      raise f"AMP for PT/XLA:GPU is not implemented yet for torchbench models"
+      raise ValueError(
+          f"AMP for PT/XLA:GPU is not implemented yet for torchbench models")
     elif precision == "bf16":
       os.environ['XLA_USE_BF16'] = '1'
     elif precision == "fp32":
       logger.warning("Sticking with the default fp32 precision.")
     else:
-      raise f"Unknown precision: {precision}"
+      raise ValueError(f"Unknown precision: {precision}")

   def pick_grad(self):
     # special case
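For context, the XLA_USE_FP16 / XLA_USE_BF16 flags take effect when tensors are placed on the XLA device, which is why the commit message suggests calling the helper before moving the model over. A minimal sketch of the effect, assuming torch_xla is installed and an XLA device is available (the Linear module is just an illustration, not from this commit):

  import os
  os.environ['XLA_USE_BF16'] = '1'  # must be set before the XLA device is used

  import torch
  import torch_xla.core.xla_model as xm

  device = xm.xla_device()
  model = torch.nn.Linear(4, 4).to(device)  # float32 params are kept as bf16 on the device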
