New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Use GEMM in _update_dict #11420

Merged
merged 3 commits into from Jul 18, 2018
Jump to file or symbol
Failed to load files and symbols.
+5 −4
Diff settings

Always

Just for now

@@ -373,11 +373,12 @@ def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
n_components = len(code)
n_features = Y.shape[0]
random_state = check_random_state(random_state)
# Residuals, computed 'in-place' for efficiency
R = -np.dot(dictionary, code)
R += Y
R = np.asfortranarray(R)
# Get BLAS functions
gemm, = linalg.get_blas_funcs(('gemm',), (dictionary, code, Y))
ger, = linalg.get_blas_funcs(('ger',), (dictionary, code))
# Residuals, computed with BLAS for speed and efficiency
# R <- -1.0 * U * V^T + 1.0 * Y
R = gemm(-1.0, dictionary, code, 1.0, Y)

This comment has been minimized.

@jnothman

jnothman Jul 3, 2018

Member

This produces a Fortran array I presume?

This comment has been minimized.

@jakirkham

jakirkham Jul 4, 2018

Contributor

Yep.

This comment has been minimized.

@ogrisel

ogrisel Jul 17, 2018

Member

Maybe it would be worth making it explicit that we expect a fortran array in the comment.

This comment has been minimized.

@jakirkham

jakirkham Jul 17, 2018

Contributor

Interestingly, the SciPy linalg functions correctly handle C ordered arrays as input.

In [1]: import numpy as np

In [2]: import scipy.linalg as linalg

In [3]: np.random.seed(0)

In [4]: a = np.random.random((2, 3))

In [5]: b = np.random.random((3, 4))

In [6]: c = np.random.random((2, 4))

In [7]: gemm, = linalg.get_blas_funcs(('gemm',), (a, b, c))

In [8]: gemm(1.0, a, b, 1.0, c)
Out[8]: 
array([[1.62736178, 1.79020759, 1.92593582, 2.17344607],
       [1.08121316, 1.54678646, 0.89707181, 1.77876957]])

In [9]: gemm(1.0, a, b, 1.0, c).flags
Out[9]: 
  C_CONTIGUOUS : False
  F_CONTIGUOUS : True
  OWNDATA : True
  WRITEABLE : True
  ALIGNED : True
  WRITEBACKIFCOPY : False
  UPDATEIFCOPY : False

That said, we appear to already be forcing dictionary and code to Fortran ordered arrays anywhere _update_dict is called.

for k in range(n_components):
# R <- 1.0 * U_k * V_k^T + R
R = ger(1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
ProTip! Use n and p to navigate between commits in a pull request.