Skip to content
12 changes: 11 additions & 1 deletion .circleci/unittest/linux/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ conda activate ./env

if [ "${CU_VERSION:-}" == cpu ] ; then
cudatoolkit="cpuonly"
version="cpu"
else
if [[ ${#CU_VERSION} -eq 4 ]]; then
CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
Expand All @@ -23,8 +24,17 @@ else
cudatoolkit="cudatoolkit=${version}"
fi

case "$(uname -s)" in
Darwin*) os=MacOSX;;
*) os=Linux
esac

printf "Installing PyTorch with %s\n" "${cudatoolkit}"
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" pytest
if [ "${os}" == "MacOSX" ]; then
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" pytest
else
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" "pytorch-${UPLOAD_CHANNEL}"::pytorch[build="*${version}*"] "${cudatoolkit}" pytest
fi

if [ $PYTHON_VERSION == "3.6" ]; then
printf "Installing minimal PILLOW version\n"
Expand Down
5 changes: 3 additions & 2 deletions .circleci/unittest/windows/scripts/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')"
conda activate ./env

# TODO, refactor the below logic to make it easy to understand how to get correct cuda_version.
if [ "${CU_VERSION:-}" == cpu ] ; then
cudatoolkit="cpuonly"
version="cpu"
else
if [[ ${#CU_VERSION} -eq 4 ]]; then
CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
Expand All @@ -26,8 +28,7 @@ else
fi

printf "Installing PyTorch with %s\n" "${cudatoolkit}"
# conda-forge channel is required for cudatoolkit 11.1 on Windows, see https://github.com/pytorch/vision/issues/4458
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}"::pytorch "${cudatoolkit}" pytest
conda install -y -c "pytorch-${UPLOAD_CHANNEL}" -c conda-forge "pytorch-${UPLOAD_CHANNEL}"::pytorch[build="*${version}*"] "${cudatoolkit}" pytest

if [ $PYTHON_VERSION == "3.6" ]; then
printf "Installing minimal PILLOW version\n"
Expand Down