Skip to content

Commit

Permalink
Fix sdot return and work with openblas
Browse files Browse the repository at this point in the history
  • Loading branch information
leonbottou committed Aug 10, 2013
1 parent 2763a43 commit 63dfd0c
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 43 deletions.
5 changes: 0 additions & 5 deletions CMakeLists.txt
Expand Up @@ -40,10 +40,5 @@ ADD_SUBDIRECTORY(dev)
# External packages support
INCLUDE(TorchExports)

# Optional user hook: when a CMakeExtra.txt file sits in the top-level source
# directory, pull it into the build; otherwise do nothing.
if(EXISTS "${CMAKE_SOURCE_DIR}/CMakeExtra.txt")
  include("${CMAKE_SOURCE_DIR}/CMakeExtra.txt")
endif()

# Packaging support
INCLUDE(TorchCPack)
32 changes: 32 additions & 0 deletions lib/TH/CMakeLists.txt
Expand Up @@ -143,6 +143,38 @@ INSTALL(FILES
generic/THVector.c
DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}/TH/generic")


# Native Windows builds only (Cygwin is excluded): expose a cache switch,
# off by default, that asks for the BLAS DLLs to be copied into the Torch
# install directories.
if(WIN32)
  if(NOT CYGWIN)
    set(BLAS_INSTALL_LIBRARIES "OFF"
        CACHE BOOL "Copy the required BLAS DLLs into the Torch install dirs")
  endif()
endif()

# Install_Required_Library(<ln>)
#
# Registers an install rule for the DLLs that accompany library <ln>: every
# file matching "<directory of ln>/<base name of ln>*.dll" is installed as a
# program into ${Torch_INSTALL_BIN_SUBDIR}.
#
# This is intentionally a macro rather than a function: it leaves `libpath`,
# `libname` and `libdlls` set in the caller's scope, and callers reuse
# `libpath` after the first call to look for further DLLs in the same
# directory.
macro(Install_Required_Library ln)
  # Quoted so an empty or list-valued argument is still passed as exactly one
  # argument instead of vanishing or splitting.
  get_filename_component(libpath "${ln}" PATH)
  get_filename_component(libname "${ln}" NAME_WE)
  # The glob is evaluated when CMake configures the project.
  file(GLOB libdlls "${libpath}/${libname}*.dll")
  # Only emit an install rule when at least one DLL was found.
  if(libdlls)
    install(PROGRAMS ${libdlls}
            DESTINATION "${Torch_INSTALL_BIN_SUBDIR}")
  endif()
endmacro()

# When BLAS was found and BLAS_INSTALL_LIBRARIES is enabled, install the DLLs
# of whichever BLAS flavour was detected (GotoBLAS2 and/or OpenBLAS) together
# with the libgfortran, libquadmath and libgcc DLLs found in the same
# directory.
#
# Install_Required_Library is a macro that leaves `libpath` set to the
# directory of the library it was last given; the inner loop relies on that
# to look for the runtime DLLs next to the BLAS library. Each runtime library
# is registered once per BLAS flavour (the OpenBLAS branch used to register
# libquadmath twice).
if(BLAS_FOUND AND BLAS_INSTALL_LIBRARIES)
  foreach(blas_library_var BLAS_goto2_LIBRARY BLAS_openblas_LIBRARY)
    if(${blas_library_var})
      Install_Required_Library("${${blas_library_var}}")
      foreach(runtime_library libgfortran libquadmath libgcc)
        Install_Required_Library("${libpath}/${runtime_library}")
      endforeach()
    endif()
  endforeach()
endif()



# Create THConfig.cmake
GET_TARGET_PROPERTY(TH_OUTPUT_NAME TH LOCATION)
GET_FILENAME_COMPONENT(TH_OUTPUT_NAME ${TH_OUTPUT_NAME} NAME)
Expand Down
1 change: 1 addition & 0 deletions lib/TH/THGeneral.h.in
Expand Up @@ -13,6 +13,7 @@
#cmakedefine USE_BLAS
#cmakedefine USE_LAPACK
#cmakedefine BLAS_IS_ACCELERATE
#cmakedefine BLAS_F2C

#ifdef __cplusplus
# define TH_EXTERNC extern "C"
Expand Down
14 changes: 13 additions & 1 deletion lib/TH/cmake/FindBLAS.cmake
Expand Up @@ -118,7 +118,6 @@ if((NOT BLAS_LIBRARIES)
endif(BLAS_LIBRARIES)
endif()


if((NOT BLAS_LIBRARIES)
AND ((NOT WITH_BLAS) OR (WITH_BLAS STREQUAL "open")))
check_fortran_libraries(
Expand All @@ -132,6 +131,19 @@ if((NOT BLAS_LIBRARIES)
endif(BLAS_LIBRARIES)
endif()

# Windows-only fallback: when no BLAS has been selected yet and the user
# either expressed no preference or asked for "open", probe for a library
# named "libopenblas" by checking for the sgemm symbol.
if(WIN32 AND (NOT BLAS_LIBRARIES))
  if((NOT WITH_BLAS) OR (WITH_BLAS STREQUAL "open"))
    check_fortran_libraries(
      BLAS_LIBRARIES
      BLAS
      sgemm
      ""
      "libopenblas")
    if(BLAS_LIBRARIES)
      set(BLAS_INFO "open")
    endif()
  endif()
endif()

if((NOT BLAS_LIBRARIES)
AND ((NOT WITH_BLAS) OR (WITH_BLAS STREQUAL "goto")))
check_fortran_libraries(
Expand Down
50 changes: 27 additions & 23 deletions lib/TH/generic/THBlas.c
Expand Up @@ -2,6 +2,31 @@
#define TH_GENERIC_FILE "generic/THBlas.c"
#else

/*
 * ffloat is the C type used for the return value of the single-precision
 * dot product sdot_: double when BLAS_F2C is defined, float otherwise.
 *
 * BLAS_F2C is produced by "#cmakedefine BLAS_F2C" in THGeneral.h.in, which
 * emits a define with no value. It must therefore be tested with #ifdef:
 * "#if BLAS_F2C" would expand to an empty expression and fail to compile
 * exactly when the option is enabled.
 */
#ifdef BLAS_F2C
# define ffloat double
#else
# define ffloat float
#endif

/*
 * Prototypes of the external BLAS routines called by the THBlas_ wrappers in
 * this file. They are declared once at file scope with TH_EXTERNC instead of
 * as local "extern" declarations inside each wrapper.
 *
 * Every argument, scalars included, is passed by pointer. Names starting
 * with "d" operate on double, names starting with "s" on float. sdot_ is the
 * one routine whose return type is configurable: it is declared as ffloat.
 */
TH_EXTERNC void dswap_(int *n, double *x, int *incx, double *y, int *incy);
TH_EXTERNC void sswap_(int *n, float *x, int *incx, float *y, int *incy);
TH_EXTERNC void dscal_(int *n, double *a, double *x, int *incx);
TH_EXTERNC void sscal_(int *n, float *a, float *x, int *incx);
TH_EXTERNC void dcopy_(int *n, double *x, int *incx, double *y, int *incy);
TH_EXTERNC void scopy_(int *n, float *x, int *incx, float *y, int *incy);
TH_EXTERNC void daxpy_(int *n, double *a, double *x, int *incx, double *y, int *incy);
TH_EXTERNC void saxpy_(int *n, float *a, float *x, int *incx, float *y, int *incy);
TH_EXTERNC double ddot_(int *n, double *x, int *incx, double *y, int *incy);
/* Return type follows ffloat (float or double). */
TH_EXTERNC ffloat sdot_(int *n, float *x, int *incx, float *y, int *incy);
TH_EXTERNC void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy);
TH_EXTERNC void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy);
TH_EXTERNC void dger_(int *m, int *n, double *alpha, double *x, int *incx, double *y, int *incy, double *a, int *lda);
TH_EXTERNC void sger_(int *m, int *n, float *alpha, float *x, int *incx, float *y, int *incy, float *a, int *lda);
TH_EXTERNC void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, double *a, int *lda, double *b, int *ldb, double *beta, double *c, int *ldc);
TH_EXTERNC void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, float *a, int *lda, float *b, int *ldb, float *beta, float *c, int *ldc);



void THBlas_(swap)(long n, real *x, long incx, real *y, long incy)
{
if(n == 1)
Expand All @@ -18,10 +43,8 @@ void THBlas_(swap)(long n, real *x, long incx, real *y, long incy)
int i_incy = (int)incy;

#if defined(TH_REAL_IS_DOUBLE)
extern void dswap_(int *n, double *x, int *incx, double *y, int *incy);
dswap_(&i_n, x, &i_incx, y, &i_incy);
#else
extern void sswap_(int *n, float *x, int *incx, float *y, int *incy);
sswap_(&i_n, x, &i_incx, y, &i_incy);
#endif
return;
Expand Down Expand Up @@ -50,10 +73,8 @@ void THBlas_(scal)(long n, real a, real *x, long incx)
int i_incx = (int)incx;

#if defined(TH_REAL_IS_DOUBLE)
extern void dscal_(int *n, double *a, double *x, int *incx);
dscal_(&i_n, &a, x, &i_incx);
#else
extern void sscal_(int *n, float *a, float *x, int *incx);
sscal_(&i_n, &a, x, &i_incx);
#endif
return;
Expand Down Expand Up @@ -82,10 +103,8 @@ void THBlas_(copy)(long n, real *x, long incx, real *y, long incy)
int i_incy = (int)incy;

#if defined(TH_REAL_IS_DOUBLE)
extern void dcopy_(int *n, double *x, int *incx, double *y, int *incy);
dcopy_(&i_n, x, &i_incx, y, &i_incy);
#else
extern void scopy_(int *n, float *x, int *incx, float *y, int *incy);
scopy_(&i_n, x, &i_incx, y, &i_incy);
#endif
return;
Expand Down Expand Up @@ -114,10 +133,8 @@ void THBlas_(axpy)(long n, real a, real *x, long incx, real *y, long incy)
int i_incy = (int)incy;

#if defined(TH_REAL_IS_DOUBLE)
extern void daxpy_(int *n, double *a, double *x, int *incx, double *y, int *incy);
daxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
#else
extern void saxpy_(int *n, float *a, float *x, int *incx, float *y, int *incy);
saxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
#endif
return;
Expand Down Expand Up @@ -146,16 +163,9 @@ real THBlas_(dot)(long n, real *x, long incx, real *y, long incy)
int i_incy = (int)incy;

#if defined(TH_REAL_IS_DOUBLE)
extern double ddot_(int *n, double *x, int *incx, double *y, int *incy);
return ddot_(&i_n, x, &i_incx, y, &i_incy);
#else
#if defined(BLAS_IS_ACCELERATE)
extern double sdot_(int *n, float *x, int *incx, float *y, int *incy);
return (float)sdot_(&i_n, x, &i_incx, y, &i_incy);
return (real) ddot_(&i_n, x, &i_incx, y, &i_incy);
#else
extern float sdot_(int *n, float *x, int *incx, float *y, int *incy);
return sdot_(&i_n, x, &i_incx, y, &i_incy);
#endif
return (real) sdot_(&i_n, x, &i_incx, y, &i_incy);
#endif
}
#endif
Expand Down Expand Up @@ -186,10 +196,8 @@ void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, re
int i_incy = (int)incy;

#if defined(TH_REAL_IS_DOUBLE)
extern void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy);
dgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
#else
extern void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy);
sgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
#endif
return;
Expand Down Expand Up @@ -240,10 +248,8 @@ void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y, long
int i_incy = (int)incy;

#if defined(TH_REAL_IS_DOUBLE)
extern void dger_(int *m, int *n, double *alpha, double *x, int *incx, real *y, int *incy, double *a, int *lda);
dger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda);
#else
extern void sger_(int *m, int *n, float *alpha, float *x, int *incx, real *y, int *incy, float *a, int *lda);
sger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda);
#endif
return;
Expand Down Expand Up @@ -302,10 +308,8 @@ void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha,
int i_ldc = (int)ldc;

#if defined(TH_REAL_IS_DOUBLE)
extern void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, double *a, int *lda, double *b, int *ldb, double *beta, double *c, int *ldc);
dgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc);
#else
extern void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, float *a, int *lda, float *b, int *ldb, float *beta, float *c, int *ldc);
sgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc);
#endif
return;
Expand Down
31 changes: 17 additions & 14 deletions lib/TH/generic/THLapack.c
Expand Up @@ -2,14 +2,29 @@
#define TH_GENERIC_FILE "generic/THLapack.c"
#else


/*
 * Prototypes of the external LAPACK routines called by the THLapack_
 * wrappers in this file. They are declared at file scope with TH_EXTERNC,
 * outside any USE_LAPACK guard; the calls themselves are made only inside
 * "#ifdef USE_LAPACK" sections of the wrappers.
 *
 * Every argument, scalars included, is passed by pointer. Names starting
 * with "d" operate on double, names starting with "s" on float. Each routine
 * reports its status through the trailing "info" argument.
 */
TH_EXTERNC void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
TH_EXTERNC void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info);
TH_EXTERNC void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info);
TH_EXTERNC void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info);
TH_EXTERNC void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info);
TH_EXTERNC void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info);
TH_EXTERNC void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info);
TH_EXTERNC void dgesvd_(char *jobu, char *jobvt, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *info);
TH_EXTERNC void sgesvd_(char *jobu, char *jobvt, int *m, int *n, float *a, int *lda, float *s, float *u, int *ldu, float *vt, int *ldvt, float *work, int *lwork, int *info);
TH_EXTERNC void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
TH_EXTERNC void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
TH_EXTERNC void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info);
TH_EXTERNC void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);


void THLapack_(gesv)(int n, int nrhs, real *a, int lda, int *ipiv, real *b, int ldb, int* info)
{
#ifdef USE_LAPACK
#if defined(TH_REAL_IS_DOUBLE)
extern void dgesv_(int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
#else
extern void sgesv_(int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
sgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
#endif
#else
Expand All @@ -22,10 +37,8 @@ void THLapack_(gels)(char trans, int m, int n, int nrhs, real *a, int lda, real
{
#ifdef USE_LAPACK
#if defined(TH_REAL_IS_DOUBLE)
extern void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info);
dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info);
#else
extern void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info);
sgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info);
#endif
#else
Expand All @@ -37,10 +50,8 @@ void THLapack_(syev)(char jobz, char uplo, int n, real *a, int lda, real *w, rea
{
#ifdef USE_LAPACK
#if defined(TH_REAL_IS_DOUBLE)
extern void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info);
dsyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
#else
extern void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info);
ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
#endif
#else
Expand All @@ -52,10 +63,8 @@ void THLapack_(geev)(char jobvl, char jobvr, int n, real *a, int lda, real *wr,
{
#ifdef USE_LAPACK
#if defined(TH_REAL_IS_DOUBLE)
extern void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info);
dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
#else
extern void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info);
sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
#endif
#else
Expand All @@ -67,10 +76,8 @@ void THLapack_(gesvd)(char jobu, char jobvt, int m, int n, real *a, int lda, rea
{
#ifdef USE_LAPACK
#if defined(TH_REAL_IS_DOUBLE)
extern void dgesvd_(char *jobu, char *jobvt, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *info);
dgesvd_( &jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, info);
#else
extern void sgesvd_(char *jobu, char *jobvt, int *m, int *n, float *a, int *lda, float *s, float *u, int *ldu, float *vt, int *ldvt, float *work, int *lwork, int *info);
sgesvd_( &jobu, &jobvt, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, info);
#endif
#else
Expand All @@ -83,10 +90,8 @@ void THLapack_(getrf)(int m, int n, real *a, int lda, int *ipiv, int *info)
{
#ifdef USE_LAPACK
#if defined(TH_REAL_IS_DOUBLE)
extern void dgetrf_(int *m, int *n, real *a, int *lda, int *ipiv, int *info);
dgetrf_(&m, &n, a, &lda, ipiv, info);
#else
extern void sgetrf_(int *m, int *n, real *a, int *lda, int *ipiv, int *info);
sgetrf_(&m, &n, a, &lda, ipiv, info);
#endif
#else
Expand All @@ -98,10 +103,8 @@ void THLapack_(getri)(int n, real *a, int lda, int *ipiv, real *work, int lwork,
{
#ifdef USE_LAPACK
#if defined(TH_REAL_IS_DOUBLE)
extern void dgetri_(int *n, real *a, int *lda, int *ipiv, real *work, int *lwork, int *info);
dgetri_(&n, a, &lda, ipiv, work, &lwork, info);
#else
extern void sgetri_(int *n, real *a, int *lda, int *ipiv, real *work, int *lwork, int *info);
sgetri_(&n, a, &lda, ipiv, work, &lwork, info);
#endif
#else
Expand Down

0 comments on commit 63dfd0c

Please sign in to comment.