diff --git a/torchchat.py b/torchchat.py index 9f85f0692..35cdcabae 100644 --- a/torchchat.py +++ b/torchchat.py @@ -9,6 +9,11 @@ import subprocess import sys +# MPS ops missing with Multimodal torchtune +# https://github.com/pytorch/torchtune/issues/1723 +import os +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + from torchchat.cli.cli import ( add_arguments_for_verb, arg_init,