From b99624a49bd838c3ee7cc564fb23f63bb415786f Mon Sep 17 00:00:00 2001 From: atalman Date: Thu, 31 Aug 2023 09:33:00 -0700 Subject: [PATCH] Use regular env setup --- .github/scripts/setup-env.sh | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/.github/scripts/setup-env.sh b/.github/scripts/setup-env.sh index 4b852efd9b7..a4f113c367f 100755 --- a/.github/scripts/setup-env.sh +++ b/.github/scripts/setup-env.sh @@ -73,11 +73,21 @@ else CHANNEL=nightly fi -pip install --progress-bar=off light-the-torch -ltt install --progress-bar=off \ - --pytorch-computation-backend="${GPU_ARCH_TYPE}${GPU_ARCH_VERSION}" \ - --pytorch-channel="${CHANNEL}" \ - torch +case $GPU_ARCH_TYPE in + cpu) + GPU_ARCH_ID="cpu" + ;; + cuda) + VERSION_WITHOUT_DOT=$(echo "${GPU_ARCH_VERSION}" | sed 's/\.//') + GPU_ARCH_ID="cu${VERSION_WITHOUT_DOT}" + ;; + *) + echo "Unknown GPU_ARCH_TYPE=${GPU_ARCH_TYPE}" + exit 1 + ;; +esac +PYTORCH_WHEEL_INDEX="https://download.pytorch.org/whl/${CHANNEL}/${GPU_ARCH_ID}" +pip install --progress-bar=off --pre torch --index-url="${PYTORCH_WHEEL_INDEX}" if [[ $GPU_ARCH_TYPE == 'cuda' ]]; then python -c "import torch; exit(not torch.cuda.is_available())"