Implement transeq in omp backend #27

Merged: 21 commits, Feb 19, 2024
src/omp/backend.f90 (126 additions, 8 deletions)
@@ -3,6 +3,8 @@ module m_omp_backend
use m_base_backend, only: base_backend_t
use m_common, only: dp, globs_t
use m_tdsops, only: dirps_t, tdsops_t
use m_omp_exec_dist, only: exec_dist_tds_compact, exec_dist_transeq_compact
use m_omp_sendrecv, only: sendrecv_fields

use m_omp_common, only: SZ

@@ -30,6 +32,7 @@ module m_omp_backend
procedure :: vecadd => vecadd_omp
procedure :: set_fields => set_fields_omp
procedure :: get_fields => get_fields_omp
procedure :: transeq_omp_dist
end type omp_backend_t

interface omp_backend_t
@@ -118,7 +121,7 @@ subroutine transeq_x_omp(self, du, dv, dw, u, v, w, dirps)
class(field_t), intent(in) :: u, v, w
type(dirps_t), intent(in) :: dirps

!call self%transeq_omp_dist(du, dv, dw, u, v, w, dirps)
call self%transeq_omp_dist(du, dv, dw, u, v, w, dirps)

end subroutine transeq_x_omp

@@ -131,7 +134,7 @@ subroutine transeq_y_omp(self, du, dv, dw, u, v, w, dirps)
type(dirps_t), intent(in) :: dirps

! u, v, w are reordered so that we pass v, u, w
!call self%transeq_omp_dist(dv, du, dw, v, u, w, dirps)
call self%transeq_omp_dist(dv, du, dw, v, u, w, dirps)

end subroutine transeq_y_omp

@@ -144,10 +147,86 @@ subroutine transeq_z_omp(self, du, dv, dw, u, v, w, dirps)
type(dirps_t), intent(in) :: dirps

! u, v, w are reordered so that we pass w, u, v
!call self%transeq_omp_dist(dw, du, dv, w, u, v, dirps)
call self%transeq_omp_dist(dw, du, dv, w, u, v, dirps)

end subroutine transeq_z_omp

subroutine transeq_omp_dist(self, du, dv, dw, u, v, w, dirps)
implicit none

class(omp_backend_t) :: self
class(field_t), intent(inout) :: du, dv, dw
class(field_t), intent(in) :: u, v, w
type(dirps_t), intent(in) :: dirps
integer :: n_halo

call transeq_halo_exchange(self, u, v, w, dirps)

call transeq_dist_component(self, du, u, u, &
dirps%der1st, dirps%der1st_sym, dirps%der2nd, dirps)
call transeq_dist_component(self, dv, v, u, &
dirps%der1st_sym, dirps%der1st, dirps%der2nd_sym, dirps)
call transeq_dist_component(self, dw, w, u, &
dirps%der1st_sym, dirps%der1st, dirps%der2nd_sym, dirps)

end subroutine transeq_omp_dist


subroutine transeq_halo_exchange(self, u, v, w, dirps)
class(omp_backend_t) :: self
class(field_t), intent(in) :: u, v, w
type(dirps_t), intent(in) :: dirps
integer :: n_halo

! TODO: don't hardcode n_halo
n_halo = 4

call copy_into_buffers(self%u_send_s, self%u_send_e, u%data, dirps%n, dirps%n_blocks)
call copy_into_buffers(self%v_send_s, self%v_send_e, v%data, dirps%n, dirps%n_blocks)
call copy_into_buffers(self%w_send_s, self%w_send_e, w%data, dirps%n, dirps%n_blocks)

call sendrecv_fields(self%u_recv_s, self%u_recv_e, self%u_send_s, self%u_send_e, &
SZ*n_halo*dirps%n_blocks, dirps%nproc, dirps%pprev, dirps%pnext)
call sendrecv_fields(self%v_recv_s, self%v_recv_e, self%v_send_s, self%v_send_e, &
SZ*n_halo*dirps%n_blocks, dirps%nproc, dirps%pprev, dirps%pnext)
call sendrecv_fields(self%w_recv_s, self%w_recv_e, self%w_send_s, self%w_send_e, &
SZ*n_halo*dirps%n_blocks, dirps%nproc, dirps%pprev, dirps%pnext)

end subroutine transeq_halo_exchange

!> Computes RHS_x^v following:
! rhs_x^v = -0.5*(u*dv/dx + duv/dx) + nu*d2v/dx2
subroutine transeq_dist_component(self, rhs, v, u, tdsops_du, tdsops_dud, tdsops_d2u, dirps)

class(omp_backend_t) :: self
class(field_t), intent(inout) :: rhs
class(field_t), intent(in) :: u, v
class(tdsops_t), intent(in) :: tdsops_du
class(tdsops_t), intent(in) :: tdsops_dud
class(tdsops_t), intent(in) :: tdsops_d2u
type(dirps_t), intent(in) :: dirps
class(field_t), pointer :: du, d2u, dud

du => self%allocator%get_block()
dud => self%allocator%get_block()
d2u => self%allocator%get_block()

call exec_dist_transeq_compact(&
rhs%data, du%data, dud%data, d2u%data, &
self%du_send_s, self%du_send_e, self%du_recv_s, self%du_recv_e, &
self%dud_send_s, self%dud_send_e, self%dud_recv_s, self%dud_recv_e, &
self%d2u_send_s, self%d2u_send_e, self%d2u_recv_s, self%d2u_recv_e, &
u%data, self%u_recv_s, self%u_recv_e, &
v%data, self%v_recv_s, self%v_recv_e, &
tdsops_du, tdsops_dud, tdsops_d2u, self%nu, &
dirps%nproc, dirps%pprev, dirps%pnext, dirps%n_blocks)

call self%allocator%release_block(du)
call self%allocator%release_block(dud)
call self%allocator%release_block(d2u)

end subroutine transeq_dist_component

subroutine tds_solve_omp(self, du, u, dirps, tdsops)
implicit none

@@ -157,10 +236,36 @@ subroutine tds_solve_omp(self, du, u, dirps, tdsops)
type(dirps_t), intent(in) :: dirps
class(tdsops_t), intent(in) :: tdsops

!call self%tds_solve_dist(self, du, u, dirps, tdsops)
call tds_solve_dist(self, du, u, dirps, tdsops)

end subroutine tds_solve_omp

subroutine tds_solve_dist(self, du, u, dirps, tdsops)
implicit none

class(omp_backend_t) :: self
class(field_t), intent(inout) :: du
class(field_t), intent(in) :: u
type(dirps_t), intent(in) :: dirps
class(tdsops_t), intent(in) :: tdsops
integer :: n_halo

! TODO: don't hardcode n_halo
n_halo = 4
call copy_into_buffers(self%u_send_s, self%u_send_e, u%data, dirps%n, dirps%n_blocks)

! halo exchange
call sendrecv_fields(self%u_recv_s, self%u_recv_e, self%u_send_s, self%u_send_e, &
SZ*n_halo*dirps%n_blocks, dirps%nproc, dirps%pprev, dirps%pnext)


call exec_dist_tds_compact( &
du%data, u%data, self%u_recv_s, self%u_recv_e, self%du_send_s, self%du_send_e, &
self%du_recv_s, self%du_recv_e, &
tdsops, dirps%nproc, dirps%pprev, dirps%pnext, dirps%n_blocks)

end subroutine tds_solve_dist

subroutine reorder_omp(self, u_, u, direction)
implicit none

@@ -200,15 +305,28 @@ subroutine vecadd_omp(self, a, x, b, y)

end subroutine vecadd_omp

subroutine copy_into_buffers(u_send_s, u_send_e, u, n)
subroutine copy_into_buffers(u_send_s, u_send_e, u, n, n_blocks)
implicit none

real(dp), dimension(:, :, :), intent(out) :: u_send_s, u_send_e
real(dp), dimension(:, :, :), intent(in) :: u
integer, intent(in) :: n

u_send_s(:, :, :) = u(:, 1:4, :)
u_send_e(:, :, :) = u(:, n - 3:n, :)
integer, intent(in) :: n_blocks
integer :: i, j, k
integer, parameter :: n_halo = 4

!$omp parallel do
do k=1, n_blocks
do j=1, n_halo
!$omp simd
do i=1, SZ
u_send_s(i, j, k) = u(i, j, k)
u_send_e(i, j, k) = u(i, n - n_halo + j, k)
end do
!$omp end simd
end do
end do
!$omp end parallel do

end subroutine copy_into_buffers

src/omp/exec_dist.f90 (129 additions, 0 deletions)
@@ -61,5 +61,134 @@ subroutine exec_dist_tds_compact( &

end subroutine exec_dist_tds_compact


subroutine exec_dist_transeq_compact(&
rhs, du, dud, d2u, &
du_send_s, du_send_e, du_recv_s, du_recv_e, &
dud_send_s, dud_send_e, dud_recv_s, dud_recv_e, &
d2u_send_s, d2u_send_e, d2u_recv_s, d2u_recv_e, &
u, u_recv_s, u_recv_e, &
v, v_recv_s, v_recv_e, &
tdsops_du, tdsops_dud, tdsops_d2u, nu, nproc, pprev, pnext, n_block)

implicit none

! du = d(u)
real(dp), dimension(:, :, :), intent(out) :: rhs, du, dud, d2u

! The arrays below are intent(out) only so that we can write data into them
! here; the caller never needs the data they hold after this subroutine
! returns.
real(dp), dimension(:, :, :), intent(out) :: &
du_send_s, du_send_e, du_recv_s, du_recv_e
real(dp), dimension(:, :, :), intent(out) :: &
dud_send_s, dud_send_e, dud_recv_s, dud_recv_e
real(dp), dimension(:, :, :), intent(out) :: &
d2u_send_s, d2u_send_e, d2u_recv_s, d2u_recv_e

real(dp), dimension(:, :, :), intent(in) :: u, u_recv_s, u_recv_e
real(dp), dimension(:, :, :), intent(in) :: v, v_recv_s, v_recv_e

type(tdsops_t), intent(in) :: tdsops_du, tdsops_dud, tdsops_d2u

real(dp), dimension(:, :), allocatable :: ud, ud_recv_s, ud_recv_e
real(dp), intent(in) :: nu
integer, intent(in) :: nproc, pprev, pnext
integer, intent(in) :: n_block

integer :: n_data, n_halo
integer :: k, i, j, n

! TODO: don't hardcode n_halo
n_halo = 4
n = tdsops_d2u%n
n_data = SZ*n_block

allocate(ud(SZ, n))
allocate(ud_recv_e(SZ, n_halo))
allocate(ud_recv_s(SZ, n_halo))

!$omp parallel do private(ud, ud_recv_e, ud_recv_s)
do k = 1, n_block
call der_univ_dist( &
du(:, :, k), du_send_s(:, :, k), du_send_e(:, :, k), u(:, :, k), &
u_recv_s(:, :, k), u_recv_e(:, :, k), &
tdsops_du%coeffs_s, tdsops_du%coeffs_e, tdsops_du%coeffs, tdsops_du%n, &
tdsops_du%dist_fw, tdsops_du%dist_bw, tdsops_du%dist_af &
)

call der_univ_dist( &
d2u(:, :, k), d2u_send_s(:, :, k), d2u_send_e(:, :, k), u(:, :, k), &
u_recv_s(:, :, k), u_recv_e(:, :, k), &
tdsops_d2u%coeffs_s, tdsops_d2u%coeffs_e, tdsops_d2u%coeffs, tdsops_d2u%n, &
tdsops_d2u%dist_fw, tdsops_d2u%dist_bw, tdsops_d2u%dist_af &
)

! Handle dud by locally generating u*v
do j = 1, n
!$omp simd
do i = 1, SZ
ud(i, j) = u(i, j, k) * v(i, j, k)
end do
!$omp end simd
end do

do j = 1, n_halo
!$omp simd
do i = 1, SZ
ud_recv_s(i, j) = u_recv_s(i, j, k) * v_recv_s(i, j, k)
ud_recv_e(i, j) = u_recv_e(i, j, k) * v_recv_e(i, j, k)
end do
!$omp end simd
end do

call der_univ_dist( &
dud(:, :, k), dud_send_s(:, :, k), dud_send_e(:, :, k), ud(:, :), &
ud_recv_s(:, :), ud_recv_e(:, :), &
tdsops_dud%coeffs_s, tdsops_dud%coeffs_e, tdsops_dud%coeffs, tdsops_dud%n, &
tdsops_dud%dist_fw, tdsops_dud%dist_bw, tdsops_dud%dist_af &
)

end do
!$omp end parallel do

! halo exchange for 2x2 systems
call sendrecv_fields(du_recv_s, du_recv_e, du_send_s, du_send_e, &
n_data, nproc, pprev, pnext)
call sendrecv_fields(dud_recv_s, dud_recv_e, dud_send_s, dud_send_e, &
n_data, nproc, pprev, pnext)
call sendrecv_fields(d2u_recv_s, d2u_recv_e, d2u_send_s, d2u_send_e, &
n_data, nproc, pprev, pnext)

!$omp parallel do
do k = 1, n_block
call der_univ_subs(du(:, :, k), &
                   du_recv_s(:, :, k), du_recv_e(:, :, k), &
                   tdsops_du%n, tdsops_du%dist_sa, tdsops_du%dist_sc)

Member:

I realised that here we're writing 3 field-sized arrays into main memory unnecessarily. It is potentially increasing the runtime by 20%.

In the second phase of the algorithm here we pass a part of du, dud, and d2u into der_univ_subs, and they're all rewritten in place. Later we combine them in rhs for the final result. Ideally, we want du, dud, and d2u to be read once and rhs to be written only once. However, because of the way der_univ_subs works, the updated data in the du arrays after the der_univ_subs call gets written to main memory, even though we don't need it at all.

There are three ways we can fix this:

  • In the parallel do loop in the second phase we can copy the relevant parts of the du, dud, and d2u arrays into (SZ, n)-sized temporary arrays. Then we pass the temporary arrays into der_univ_subs, and at the end we use these temporaries to obtain the final rhs. This is the easiest solution but it may not be the best in terms of performance (a rough sketch of this option is given after the end of this module below).
  • We can write an alternative der_univ_subs with separate input and output arrays. This way we can pass a part of the du arrays as we do now, and pass a small temporary array as the output. Because the du arrays would then be inputs, no data would be written to main memory. We can then combine the temporaries to get rhs.
  • If we're writing an alternative der_univ_subs to be used in transeq, we can go one step further and have a fused version of it. This would probably be the most performant solution. der_univ_subs is relatively lightweight, so it isn't really hard to do. The new subroutine can take all of du, dud, and d2u as input and write the final result rhs.

Member:

#40 is relevant here

Collaborator (Author):

Good points, I will have a think about it, but indeed having it in a new PR focusing on performance makes sense.

Member:

Alright, I think the best way to move forward is implementing the first strategy and checking how much of the peak BW we get. If it's a good utilisation then maybe we don't need to bother with a new der_univ_subs at all.

call der_univ_subs(dud(:, :, k), &
dud_recv_s(:, :, k), dud_recv_e(:, :, k), &
tdsops_dud%n, tdsops_dud%dist_sa, tdsops_dud%dist_sc)

call der_univ_subs(d2u(:, :, k), &
d2u_recv_s(:, :, k), d2u_recv_e(:, :, k), &
tdsops_d2u%n, tdsops_d2u%dist_sa, tdsops_d2u%dist_sc)

do j = 1, n
!$omp simd
do i = 1, SZ
rhs(i, j, k) = -0.5_dp*(v(i, j, k)*du(i, j, k) + dud(i, j, k)) + nu*d2u(i, j, k)
end do
!$omp end simd
end do

end do
!$omp end parallel do


end subroutine exec_dist_transeq_compact



end module m_omp_exec_dist
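
Not part of this PR: a minimal sketch of the first strategy discussed in the review thread above, assuming der_univ_subs keeps its current interface. It would replace the second !$omp parallel do loop of exec_dist_transeq_compact; the temporaries du_temp, dud_temp and d2u_temp are illustrative names only.

    ! thread-private (SZ, n) temporaries, allocated once before the loop
    real(dp), dimension(:, :), allocatable :: du_temp, dud_temp, d2u_temp

    allocate(du_temp(SZ, n), dud_temp(SZ, n), d2u_temp(SZ, n))

    !$omp parallel do private(du_temp, dud_temp, d2u_temp)
    do k = 1, n_block
      ! copy the relevant slabs once into the small temporaries
      du_temp(:, :) = du(:, :, k)
      dud_temp(:, :) = dud(:, :, k)
      d2u_temp(:, :) = d2u(:, :, k)

      ! the substitution phase now updates the temporaries in cache instead
      ! of writing the field-sized du, dud, d2u arrays back to main memory
      call der_univ_subs(du_temp, du_recv_s(:, :, k), du_recv_e(:, :, k), &
                         tdsops_du%n, tdsops_du%dist_sa, tdsops_du%dist_sc)
      call der_univ_subs(dud_temp, dud_recv_s(:, :, k), dud_recv_e(:, :, k), &
                         tdsops_dud%n, tdsops_dud%dist_sa, tdsops_dud%dist_sc)
      call der_univ_subs(d2u_temp, d2u_recv_s(:, :, k), d2u_recv_e(:, :, k), &
                         tdsops_d2u%n, tdsops_d2u%dist_sa, tdsops_d2u%dist_sc)

      ! rhs is the only field-sized array written in this phase
      do j = 1, n
        !$omp simd
        do i = 1, SZ
          rhs(i, j, k) = -0.5_dp*(v(i, j, k)*du_temp(i, j) + dud_temp(i, j)) &
                         + nu*d2u_temp(i, j)
        end do
        !$omp end simd
      end do
    end do
    !$omp end parallel do

The trade-off is one extra read of each slab into the temporaries, in exchange for avoiding the three field-sized writes after der_univ_subs; whether that recovers enough bandwidth is exactly what the peak-BW check suggested in the thread would measure.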
