diff --git a/torchchat.py b/torchchat.py index 35cdcabae..1eeee0120 100644 --- a/torchchat.py +++ b/torchchat.py @@ -6,7 +6,7 @@ import argparse import logging -import subprocess +import signal import sys # MPS ops missing with Multimodal torchtune @@ -25,7 +25,15 @@ default_device = "cpu" +def signal_handler(sig, frame): + print("\nInterrupted by user. Bye!\n") + sys.exit(0) + + if __name__ == "__main__": + # Set the signal handler for SIGINT + signal.signal(signal.SIGINT, signal_handler) + # Initialize the top-level parser parser = argparse.ArgumentParser( prog="torchchat",