Skip to content

Commit

Permalink
Merge pull request pytorch#3 from peterjc123/windows-new
Browse files Browse the repository at this point in the history
modify cuda and cudnn dll names for win32
  • Loading branch information
peterjc123 committed Jun 13, 2017
2 parents 0881771 + 2ab90b2 commit 8b4feee
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
13 changes: 8 additions & 5 deletions torch/backends/cudnn/__init__.py
Expand Up @@ -13,12 +13,15 @@
def _libcudnn():
global lib, __cudnn_version
if lib is None:
lib = ctypes.cdll.LoadLibrary(None)
if hasattr(lib, 'cudnnGetErrorString'):
lib.cudnnGetErrorString.restype = ctypes.c_char_p
__cudnn_version = lib.cudnnGetVersion()
if sys.platform == "win32":
lib = ctypes.cdll.LoadLibrary('cudnn64_6')
else:
lib = None
lib = ctypes.cdll.LoadLibrary(None)
if hasattr(lib, 'cudnnGetErrorString'):
lib.cudnnGetErrorString.restype = ctypes.c_char_p
__cudnn_version = lib.cudnnGetVersion()
else:
lib = None
return lib


Expand Down
6 changes: 5 additions & 1 deletion torch/cuda/__init__.py
Expand Up @@ -12,6 +12,7 @@
import platform
import ctypes
import os
import sys
import torch
from multiprocessing.util import register_after_fork as _register_after_fork

Expand All @@ -35,7 +36,10 @@ def _sleep(cycles):

def _load_cudart():
# First check the main program for CUDA symbols
lib = ctypes.cdll.LoadLibrary(None)
if sys.platform == "win32":
lib = ctypes.cdll.LoadLibrary('cudart64_80')
else:
lib = ctypes.cdll.LoadLibrary(None)
if hasattr(lib, 'cudaGetErrorName'):
return lib

Expand Down

0 comments on commit 8b4feee

Please sign in to comment.