From d521d3506e7d5c80182c84615b9f8d4ee2e7402b Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 14 Jun 2024 19:41:36 -0700 Subject: [PATCH 1/2] use pytorch version env variable --- requirements.txt | 2 -- setup.py | 16 +++++++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/requirements.txt b/requirements.txt index 0e6e860a5a..68cd4298f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ -torch -numpy sentencepiece packaging expecttest # So we can use IS_FBCODE flag diff --git a/setup.py b/setup.py index d4cf988b43..129bc70fbf 100644 --- a/setup.py +++ b/setup.py @@ -11,10 +11,6 @@ current_date = datetime.now().strftime("%Y.%m.%d") -def read_requirements(file_path): - with open(file_path, "r") as file: - return file.read().splitlines() - def read_version(file_path="version.txt"): with open(file_path, "r") as file: return file.readline().strip() @@ -88,6 +84,16 @@ def get_extensions(): return ext_modules +# Mimic code from torchvision https://github.com/pytorch/vision/blob/143d078b28f00471156a4e562dd3836370acc9ee/setup.py#L58 +pytorch_dep = "torch" +if os.getenv("PYTORCH_VERSION"): + pytorch_dep += "==" + os.getenv("PYTORCH_VERSION") + +requirements = [ + "numpy", + pytorch_dep, +] + setup( name=package_name, version=version+version_suffix, @@ -97,7 +103,7 @@ def get_extensions(): "torchao.kernel.configs": ["*.pkl"], }, ext_modules=get_extensions() if use_cpp != "0" else None, - install_requires=read_requirements("requirements.txt"), + install_requires=requirements, extras_require={"dev": read_requirements("dev-requirements.txt")}, description="Package for applying ao techniques to GPU models", long_description=open("README.md").read(), From 58f23c2df71a0111c8c21836aec9da77d1dbcc11 Mon Sep 17 00:00:00 2001 From: Mark Saroufim Date: Fri, 14 Jun 2024 19:47:48 -0700 Subject: [PATCH 2/2] yolo --- setup.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/setup.py b/setup.py index 129bc70fbf..fc14725005 100644 --- a/setup.py +++ b/setup.py @@ -11,6 +11,10 @@ current_date = datetime.now().strftime("%Y.%m.%d") +def read_requirements(file_path): + with open(file_path, "r") as file: + return file.read().splitlines() + def read_version(file_path="version.txt"): with open(file_path, "r") as file: return file.readline().strip()