diff --git a/tools/docker/gpu_tests.Dockerfile b/tools/docker/gpu_tests.Dockerfile index 6f1f52045f..6c36a8666c 100644 --- a/tools/docker/gpu_tests.Dockerfile +++ b/tools/docker/gpu_tests.Dockerfile @@ -1,5 +1,9 @@ FROM tensorflow/tensorflow:2.1.0-custom-op-gpu-ubuntu16 ENV TF_NEED_CUDA="1" +ENV TF_CUDA_VERSION="10.1" +ENV CUDA_TOOLKIT_PATH="/usr/local/cuda" +ENV TF_CUDNN_VERSION="7" +ENV CUDNN_INSTALL_PATH="/usr/lib/x86_64-linux-gnu" RUN python3 -m pip install --upgrade pip setuptools auditwheel==2.0.0