feat: extend intrinsic matmul #951

Open: wants to merge 11 commits into base: master
replace all matmul's by gemm
wassup05 committed Mar 20, 2025
commit e709f838aeb1b57276bdfe36840ef9a720e1086c
4 changes: 2 additions & 2 deletions example/intrinsics/example_matmul.f90
@@ -4,9 +4,9 @@ program example_matmul
     real :: r1(50, 100), r2(100, 40), r3(40, 50)
     real, allocatable :: res(:, :)
     x = reshape([(0, 0), (1, 0), (1, 0), (0, 0)], [2, 2])
-    y = reshape([(0, 0), (0, -1), (0, 1), (0, 0)], [2, 2]) ! pauli y-matrix
+    y = reshape([(0, 0), (0, 1), (0, -1), (0, 0)], [2, 2]) ! pauli y-matrix

-    print *, stdlib_matmul(y, y, y, y, y) ! should be y
+    print *, stdlib_matmul(y, y, y) ! should be y
     print *, stdlib_matmul(x, x, y, x) ! should be -i x sigma_z

     call random_seed()
4 changes: 2 additions & 2 deletions src/stdlib_intrinsics.fypp
@@ -158,10 +158,10 @@ module stdlib_intrinsics
     !!
     !! matrix multiply more than two matrices with a single function call
     !! the multiplication with the optimal parenthesization for efficiency of computation is done automatically
-    !! Supported data types are `real`, `integer` and `complex`.
+    !! Supported data types are `real` and `complex`.
     !!
     !! Note: The matrices must be of compatible shapes to be multiplied
-    #:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES
+    #:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES
     pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
         ${t}$, intent(in) :: m1(:,:), m2(:,:)
         ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
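The doc comment above says the optimal parenthesization is chosen automatically. As a rough illustration of what that means, here is a minimal, self-contained sketch of the textbook matrix-chain dynamic program. The names `chain_order_demo`, `chain_order`, `split` and `best` are made up for this example, and it is not the PR's `matmul_chain_order` (whose body is not shown in this diff). It only follows the same convention, where the stored index is the last matrix of the left-hand group.

```fortran
program chain_order_demo
    implicit none
    ! Illustrative shapes: A1 is 10x100, A2 is 100x5, A3 is 5x50
    integer, parameter :: p(4) = [10, 100, 5, 50]
    integer :: split(3, 3)
    integer :: cost

    call chain_order(p, split, cost)
    print '(a, i0)', 'minimal scalar multiplications: ', cost
    print '(a, i0)', 'outermost split after matrix: ', split(1, 3)

contains

    !> Classic O(n**3) dynamic program for matrix-chain ordering.
    !> Matrix i has shape p(i) x p(i+1). split(i, j) = k means the product
    !> A_i ... A_j is cheapest when cut as (A_i ... A_k) * (A_k+1 ... A_j).
    pure subroutine chain_order(p, split, best)
        integer, intent(in)  :: p(:)
        integer, intent(out) :: split(:, :)
        integer, intent(out) :: best
        integer :: n, i, j, k, l, q
        integer, allocatable :: m(:, :)

        n = size(p) - 1
        allocate(m(n, n))
        m = 0
        split = 0
        do l = 2, n                     ! l is the length of the sub-chain
            do i = 1, n - l + 1
                j = i + l - 1
                m(i, j) = huge(1)
                do k = i, j - 1
                    q = m(i, k) + m(k + 1, j) + p(i) * p(k + 1) * p(j + 1)
                    if (q < m(i, j)) then
                        m(i, j) = q
                        split(i, j) = k
                    end if
                end do
            end do
        end do
        best = m(1, n)
    end subroutine chain_order

end program chain_order_demo
```

For these shapes `(A1*A2)*A3` costs 10*100*5 + 10*5*50 = 7500 scalar multiplications, while `A1*(A2*A3)` costs 100*5*50 + 10*100*50 = 75000, so the sketch reports a cost of 7500 and a split after matrix 2.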
149 changes: 117 additions & 32 deletions src/stdlib_intrinsics_matmul.fypp
@@ -4,6 +4,8 @@
 #:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))

 submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
+    use stdlib_linalg_blas, only: gemm
+    use stdlib_constants
     implicit none

 contains
@@ -36,38 +38,84 @@ contains
         end do
     end function matmul_chain_order

-#:for k, t, s in I_KINDS_TYPES + R_KINDS_TYPES + C_KINDS_TYPES
+#:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES

-    pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s) result(r)
+    pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s, p) result(r)
         ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:)
-        integer, intent(in) :: start, s(:,2:)
-        ${t}$, allocatable :: r(:,:)
-        integer :: tmp
-        tmp = s(start, start + 2)
-
-        if (tmp == start) then
-            r = matmul(m1, matmul(m2, m3))
-        else if (tmp == start + 1) then
-            r = matmul(matmul(m1, m2), m3)
+        integer, intent(in) :: start, s(:,2:), p(:)
+        ${t}$, allocatable :: r(:,:), temp(:,:)
+        integer :: ord, m, n, k
+        ord = s(start, start + 2)
+        allocate(r(p(start), p(start + 3)))
+
+        if (ord == start) then
+            ! m1*(m2*m3)
+            m = p(start + 1)
+            n = p(start + 3)
+            k = p(start + 2)
+            allocate(temp(m,n))
+            call gemm('N', 'N', m, n, k, one_${s}$, m2, m, m3, k, zero_${s}$, temp, m)
Member

@perazz perazz Mar 21, 2025


Very good progress @wassup05, thank you! Imho this PR is almost ready to be merged. As you suggest, it would be good to have a nice wrapper for gemm. It has been discussed before. I would like to suggest that all calls to gemm are also wrapped into a stdlib_matmul function - now with two matrices only. This would give stdlib fully functional matmul functionality.

Here I suggest two possible APIs, and I will ask @jalvesz @jvdp1 @loiseaujc to discuss that together:

The first would be similar to gemm

! API Similar to gemm
${t1}$ function stdlib_matmul(A, Astate, B, Bstate) result(C)
  ${t1}$, intent(in) :: A(:,:), B(:,:)
  character, intent(in), optional :: Astate, Bstate

and could use the matrix state definitions already in use for the sparse operations

character(1), parameter :: sparse_op_none = 'N' !! no transpose
character(1), parameter :: sparse_op_transpose = 'T' !! transpose
character(1), parameter :: sparse_op_hermitian = 'H' !! conjugate or hermitian transpose

The second would be more ambitious and essentially zero-overhead: it would wrap the operation in a derived type (to be templated, of course):

type :: matrix_state_rdp
   real(dp), pointer :: A(:,:) => null()
   character(1) :: Astate = 'N'
end type matrix_state_rdp

interface transposed
   module procedure transposed_new_rdp
   ...
end interface transposed    

type(matrix_state_rdp) function transposed_new_rdp(A) result(AT)
    real(dp), intent(inout), target :: A(:,:)
    AT%A => A
    AT%Astate = 'T'
end function

Then we could define a templated base interface

! Work with normal matrices
${t1}$ function stdlib_matmul(A, B) result(C)
  ${t1}$, intent(in) :: A(:,:), B(:,:)

! Work with transposed/hermitian swaps
${rt}$ function stdlib_matmul(A, B) result(C)
  type(matrix_state_${rn}$), intent(in) :: A, B

So the user writing code would have it clear:

C = stdlib_matmul(A, B)
C = stdlib_matmul(transposed(A), B)
C = stdlib_matmul(A, hermitian(C))

we could even make it an operator:

C = stdlib_matmul(A, B)
C = stdlib_matmul(.t.A, B)
C = stdlib_matmul(A, .h.C)

without it triggering any actual data movement.
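To make the "no data movement" claim concrete, here is a minimal sketch of the wrapper idea for one kind only. It reuses the names from the proposal (`matrix_state_rdp`, `transposed`, `.t.`), but the module name and the demo program are illustrative assumptions. The argument is declared `intent(in)` rather than `intent(inout)` because a function bound to a defined operator must take `intent(in)` arguments. The caller's array needs the `target` attribute for the pointer to stay associated after the call.

```fortran
module matrix_state_sketch
    implicit none
    private
    public :: dp, matrix_state_rdp, transposed, operator(.t.)

    integer, parameter :: dp = kind(1.0d0)

    ! Lightweight view: a pointer to the data plus a one-character state flag.
    type :: matrix_state_rdp
        real(dp), pointer :: A(:, :) => null()
        character(1) :: Astate = 'N'
    end type matrix_state_rdp

    interface transposed
        module procedure transposed_new_rdp
    end interface transposed

    ! Same procedure exposed as a unary operator (hypothetical spelling).
    interface operator(.t.)
        module procedure transposed_new_rdp
    end interface

contains

    function transposed_new_rdp(A) result(AT)
        real(dp), intent(in), target :: A(:, :)
        type(matrix_state_rdp) :: AT
        AT%A => A          ! pointer association only, no copy of the data
        AT%Astate = 'T'
    end function transposed_new_rdp

end module matrix_state_sketch

program matrix_state_demo
    use matrix_state_sketch
    implicit none
    real(dp), target :: A(2, 3)      ! target is required on the caller side
    type(matrix_state_rdp) :: w

    A = 1.0_dp
    w = .t. A
    ! The view should alias A and still have A's original 2x3 shape.
    print *, w%Astate, associated(w%A, A), shape(w%A)
end program matrix_state_demo
```

The stored shape is still that of `A`. Only the flag records that a consumer such as a `gemm` wrapper should treat the matrix as transposed.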


@loiseaujc loiseaujc Mar 24, 2025


C = stdlib_matmul(transposed(A), B)
C = stdlib_matmul(A, hermitian(C))

How would that work in the written code? Would A and B have to be declared as type(matrix_state_type)?
If so, I'm a bit afraid most people would prefer the more "natural" real, dimension(m, n) :: A, B and play around either directly with the intrinsic transpose or the classical gemm dummy arguments. Maybe I'm missing something though?

Contributor


How would that work in the written code? Would A and B have to be declared as type(matrix_state_type)?

I had the same reaction at first, but after thinking about it longer I saw that it is actually a pretty clever solution: the user would declare the matrices as regular dense matrices, and the internal interface would make the distinction. This would, however, imply implementing several versions internally to account for the combinations (dense,dense) / (dense,type) / (type,dense). This looks interesting, but I wonder if it should be pursued at this stage. The first proposal by @perazz looks easier and totally valid, but I would propose it with a slight modification:

${t1}$ function stdlib_matmul(A, B, op_a, op_b) result(C)
  ${t1}$, intent(in) :: A(:,:), B(:,:)
  character, intent(in), optional :: op_a, op_b

to keep all optional arguments of the procedure at the end of the signature.
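A minimal sketch of how such a signature could behave, written for one kind (`complex(dp)`) and using only intrinsics so that it runs without BLAS. The names `matmul_op_sketch`, `matmul_op` and `apply_op` are hypothetical. The flags follow the `'N'`/`'T'`/`'H'` constants quoted earlier in the thread, whereas BLAS `gemm` itself spells the conjugate transpose `'C'`. A `gemm`-based version would forward the flags as `transa`/`transb` instead of materializing the transposes as done here.

```fortran
module matmul_op_sketch
    implicit none
    private
    public :: dp, matmul_op

    integer, parameter :: dp = kind(1.0d0)

contains

    !> Two-matrix product with optional operation flags placed last.
    !> 'N' = as is, 'T' = transpose, 'H' = conjugate (Hermitian) transpose.
    function matmul_op(A, B, op_a, op_b) result(C)
        complex(dp), intent(in) :: A(:, :), B(:, :)
        character(1), intent(in), optional :: op_a, op_b
        complex(dp), allocatable :: C(:, :)
        character(1) :: oa, ob

        oa = 'N'
        ob = 'N'
        if (present(op_a)) oa = op_a
        if (present(op_b)) ob = op_b

        ! Reference behaviour only: the intrinsics copy the operands.
        C = matmul(apply_op(A, oa), apply_op(B, ob))
    end function matmul_op

    function apply_op(X, op) result(Y)
        complex(dp), intent(in) :: X(:, :)
        character(1), intent(in) :: op
        complex(dp), allocatable :: Y(:, :)

        select case (op)
        case ('N', 'n')
            Y = X
        case ('T', 't')
            Y = transpose(X)
        case ('H', 'h')
            Y = conjg(transpose(X))
        case default
            error stop "matmul_op: op must be 'N', 'T' or 'H'"
        end select
    end function apply_op

end module matmul_op_sketch

program matmul_op_demo
    use matmul_op_sketch, only: dp, matmul_op
    implicit none
    complex(dp) :: y(2, 2)

    ! Pauli y-matrix, as in example_matmul.f90
    y = reshape([(0, 0), (0, 1), (0, -1), (0, 0)], [2, 2])

    ! y is Hermitian and unitary, so y^H * y should be the 2x2 identity.
    print *, matmul_op(y, y, op_a='H')
end program matmul_op_demo
```

With the optionals at the end, the plain call `matmul_op(A, B)` stays positional and the flags can be passed by keyword only when needed.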

Member


Yes @jalvesz @loiseaujc, I don't know how to write it better, but it is outlined in the previous post.
There would be an interface for matmul(A,B) (resolved at compile time) with 4 options for each kind:

  1. Both A and B are simple matrices (2D arrays)
  2. A is a matrix_state_*, B is a 2D array
  3. B is a matrix_state_*, A is a 2D array
  4. Both are matrix_state_*

Only function 4. is actually implemented, and is a gemm wrapper: the other implementations just wrap against it with fypp.
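The four-way dispatch described above could look roughly like the sketch below, again for a single real kind and without fypp templating. All names (`matmul_dispatch_sketch`, `matrix_state`, `wrap`, `mm`, `mm_dd`, `mm_sd`, `mm_ds`, `mm_ss`) are hypothetical. The only overload that does work, `mm_ss`, uses the intrinsic `matmul` and `transpose` as a stand-in for the `gemm` call so that the example is self-contained.

```fortran
module matmul_dispatch_sketch
    implicit none
    private
    public :: dp, matrix_state, transposed, mm

    integer, parameter :: dp = kind(1.0d0)

    ! Hypothetical stand-in for the proposed matrix_state_* types.
    type :: matrix_state
        real(dp), pointer :: A(:, :) => null()
        character(1) :: state = 'N'
    end type matrix_state

    ! One generic name, resolved at compile time from the argument types.
    interface mm
        module procedure mm_dd, mm_sd, mm_ds, mm_ss
    end interface mm

contains

    function wrap(A, state) result(w)
        real(dp), intent(in), target :: A(:, :)
        character(1), intent(in) :: state
        type(matrix_state) :: w
        w%A => A
        w%state = state
    end function wrap

    function transposed(A) result(w)
        real(dp), intent(in), target :: A(:, :)
        type(matrix_state) :: w
        w = wrap(A, 'T')
    end function transposed

    ! Option 4: the only overload with a real implementation.
    function mm_ss(A, B) result(C)
        type(matrix_state), intent(in) :: A, B
        real(dp), allocatable :: C(:, :)
        if (A%state == 'T' .and. B%state == 'T') then
            C = matmul(transpose(A%A), transpose(B%A))
        else if (A%state == 'T') then
            C = matmul(transpose(A%A), B%A)
        else if (B%state == 'T') then
            C = matmul(A%A, transpose(B%A))
        else
            C = matmul(A%A, B%A)
        end if
    end function mm_ss

    ! Options 1 to 3 only wrap their dense arguments and forward to mm_ss.
    function mm_dd(A, B) result(C)
        real(dp), intent(in), target :: A(:, :), B(:, :)
        real(dp), allocatable :: C(:, :)
        C = mm_ss(wrap(A, 'N'), wrap(B, 'N'))
    end function mm_dd

    function mm_sd(A, B) result(C)
        type(matrix_state), intent(in) :: A
        real(dp), intent(in), target :: B(:, :)
        real(dp), allocatable :: C(:, :)
        C = mm_ss(A, wrap(B, 'N'))
    end function mm_sd

    function mm_ds(A, B) result(C)
        real(dp), intent(in), target :: A(:, :)
        type(matrix_state), intent(in) :: B
        real(dp), allocatable :: C(:, :)
        C = mm_ss(wrap(A, 'N'), B)
    end function mm_ds

end module matmul_dispatch_sketch

program dispatch_demo
    use matmul_dispatch_sketch
    implicit none
    real(dp), target :: A(2, 3), B(2, 2)
    real(dp), allocatable :: C(:, :)

    A = reshape(real([1, 2, 3, 4, 5, 6], dp), [2, 3])
    B = reshape(real([1, 0, 0, 1], dp), [2, 2])

    C = mm(transposed(A), B)       ! resolves to mm_sd
    print *, shape(C)              ! a 3x2 result is expected
    print *, all(abs(C - matmul(transpose(A), B)) < 1.0e-12_dp)
end program dispatch_demo
```

The user-facing arrays stay plain `real(dp)` arrays. Only the temporaries returned by `transposed` carry the derived type, which is the point raised in the reply to @loiseaujc.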


@loiseaujc loiseaujc Mar 25, 2025


Oh I see. I quite like that indeed. I still believe though that, as a starting point, restricting ourselves to standard stuff might be easier. Introducing new derived types to represent matrices definitely is something I'm looking forward to but, in line with the discussion here, it might require a broader discussion to have a well-designed set of derived types.

I also prefer this signature

${t1}$ function stdlib_matmul(A, B, op_a, op_b) result(C)
  ${t1}$, intent(in) :: A(:,:), B(:,:)
  character, intent(in), optional :: op_a, op_b

for the reasons you've mentioned.

Member


Well, here I would define these "convenience" derived types, because they're only used to make the matmul interface better and more readable. I would only expose to the user the actual interface (i.e., either the operators .t. .h., or the function names transposed hermitian).

+            m = p(start)
+            n = p(start + 3)
+            k = p(start + 1)
+            call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
+        else if (ord == start + 1) then
+            ! (m1*m2)*m3
+            m = p(start)
+            n = p(start + 2)
+            k = p(start + 1)
+            allocate(temp(m, n))
+            call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)
+            m = p(start)
+            n = p(start + 3)
+            k = p(start + 2) ! inner dimension of temp*m3: columns of temp = rows of m3
+            call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m3, k, zero_${s}$, r, m)
         else
             error stop "stdlib_matmul: error: unexpected s(i,j)"
         end if

     end function matmul_chain_mult_${s}$_3

-    pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s) result(r)
+    pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s, p) result(r)
         ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:)
-        integer, intent(in) :: start, s(:,2:)
-        ${t}$, allocatable :: r(:,:)
-        integer :: tmp
-        tmp = s(start, start + 3)
-
-        if (tmp == start) then
-            r = matmul(m1, matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s))
-        else if (tmp == start + 1) then
-            r = matmul(matmul(m1, m2), matmul(m3, m4))
-        else if (tmp == start + 2) then
-            r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, start, s), m4)
+        integer, intent(in) :: start, s(:,2:), p(:)
+        ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
+        integer :: ord, m, n, k
+        ord = s(start, start + 3)
+        allocate(r(p(start), p(start + 4)))
+
+        if (ord == start) then
+            ! m1*(m2*m3*m4)
+            temp = matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s, p)
+            m = p(start)
+            n = p(start + 4)
+            k = p(start + 1)
+            call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
+        else if (ord == start + 1) then
+            ! (m1*m2)*(m3*m4)
+            m = p(start)
+            n = p(start + 2)
+            k = p(start + 1)
+            allocate(temp(m,n))
+            call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)
+
+            m = p(start + 2)
+            n = p(start + 4)
+            k = p(start + 3)
+            allocate(temp1(m,n))
+            call gemm('N', 'N', m, n, k, one_${s}$, m3, m, m4, k, zero_${s}$, temp1, m)
+
+            m = p(start)
+            n = p(start + 4)
+            k = p(start + 2)
+            call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
+        else if (ord == start + 2) then
+            ! (m1*m2*m3)*m4
+            temp = matmul_chain_mult_${s}$_3(m1, m2, m3, start, s, p)
+            m = p(start)
+            n = p(start + 4)
+            k = p(start + 3)
+            call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m4, k, zero_${s}$, r, m)
         else
             error stop "stdlib_matmul: error: unexpected s(i,j)"
         end if
@@ -77,8 +125,8 @@ contains
     pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
         ${t}$, intent(in) :: m1(:,:), m2(:,:)
         ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
-        ${t}$, allocatable :: r(:,:)
-        integer :: p(6), num_present
+        ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
+        integer :: p(6), num_present, m, n, k
         integer, allocatable :: s(:,:)

         p(1) = size(m1, 1)
@@ -102,8 +150,13 @@ contains
             num_present = num_present + 1
         end if

+        allocate(r(p(1), p(num_present + 1)))
+
         if (num_present == 2) then
-            r = matmul(m1, m2)
+            m = p(1)
+            n = p(3)
+            k = p(2)
+            call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r, m)
             return
         end if

@@ -113,24 +166,56 @@ contains
         s = matmul_chain_order(p(1: num_present + 1))

         if (num_present == 3) then
-            r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s)
+            r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4))
             return
         else if (num_present == 4) then
-            r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s)
+            r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5))
             return
         end if

         ! Now num_present is 5

         select case (s(1, 5))
         case (1)
-            r = matmul(m1, matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s))
+            ! m1*(m2*m3*m4*m5)
+            temp = matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s, p)
+            m = p(1)
+            n = p(6)
+            k = p(2)
+            call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
         case (2)
-            r = matmul(matmul(m1, m2), matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s))
+            ! (m1*m2)*(m3*m4*m5)
+            m = p(1)
+            n = p(3)
+            k = p(2)
+            allocate(temp(m,n))
+            call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)
+
+            temp1 = matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s, p)
+
+            k = n
+            n = p(6)
+            call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
         case (3)
-            r = matmul(matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s), matmul(m4, m5))
+            ! (m1*m2*m3)*(m4*m5)
+            temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p) ! sub-chain begins at m1, so start = 1
+
+            m = p(4)
+            n = p(6)
+            k = p(5)
+            allocate(temp1(m,n))
+            call gemm('N', 'N', m, n, k, one_${s}$, m4, m, m5, k, zero_${s}$, temp1, m)
+
+            k = m
+            m = p(1)
+            call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
         case (4)
-            r = matmul(matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s), m5)
+            ! (m1*m2*m3*m4)*m5
+            temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p)
+            m = p(1)
+            n = p(6)
+            k = p(5)
+            call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r, m)
         case default
             error stop "stdlib_matmul: error: unexpected s(i,j)"
         end select