Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/sourcery/bin_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,17 @@ elemental module function construct(num_items, num_bins, bin_number) result(bin)

interface

elemental module function first(self, bin_number) result(first_item_number)
elemental module function first(self) result(first_item_number)
!! the result is the first item number associated with the given bin
implicit none
class(bin_t), intent(in) :: self
integer, intent(in) :: bin_number
integer first_item_number
end function

elemental module function last(self, bin_number) result(last_item_number)
elemental module function last(self) result(last_item_number)
!! the result is the last item number associated with the given bin
implicit none
class(bin_t), intent(in) :: self
integer, intent(in) :: bin_number
integer last_item_number
end function

Expand Down
2 changes: 0 additions & 2 deletions src/sourcery/data_partition_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ module data_partition_m
generic :: gather => gather_real32_2D_array, gather_real64_2D_array, gather_real32_1D_array, gather_real64_1D_array
end type

integer, allocatable :: first_datum(:), last_datum(:)

interface

module subroutine define_partitions(cardinality)
Expand Down
41 changes: 8 additions & 33 deletions src/sourcery/data_partition_s.f90
Original file line number Diff line number Diff line change
@@ -1,51 +1,26 @@
submodule(data_partition_m) data_partition_s
use assert_m, only : assert
use bin_m, only : bin_t
implicit none

logical, parameter :: verbose=.false.
type(bin_t), allocatable :: bin(:)

contains

module procedure define_partitions

if (allocated(first_datum)) deallocate(first_datum)
if (allocated(last_datum)) deallocate(last_datum)

associate( ni => num_images() )

call assert( ni<=cardinality, "sufficient data for distribution across images", cardinality)

allocate(first_datum(ni), last_datum(ni))

block
integer i, image
do image=1,ni
associate( remainder => mod(cardinality, ni), quotient => cardinality/ni )
first_datum(image) = sum([(quotient+overflow(i, remainder), i=1, image-1)]) + 1
last_datum(image) = first_datum(image) + quotient + overflow(image, remainder) - 1
end associate
end do
end block
end associate

contains

pure function overflow(im, excess) result(extra_datum)
integer, intent(in) :: im, excess
integer extra_datum
extra_datum= merge(1,0,im<=excess)
end function

integer image
bin = [( bin_t(num_items=cardinality, num_bins=num_images(), bin_number=image), image=1,num_images() )]
end procedure

module procedure first
call assert( allocated(first_datum), "allocated(first_datum)")
first_index= first_datum( image_number )
call assert( allocated(bin), "data_partition_s(first): allocated(bin)")
first_index = bin(image_number)%first()
end procedure

module procedure last
call assert( allocated(last_datum), "allocated(last_datum)")
last_index = last_datum( image_number )
call assert( allocated(bin), "data_partition_s(last): allocated(bin)")
last_index = bin(image_number)%last()
end procedure

module procedure gather_real32_1D_array
Expand Down
4 changes: 2 additions & 2 deletions test/bin_test.f90
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function verify_block_partitioning() result(test_passes)
integer b

bins = [( bin_t(num_items=n_items, num_bins=n_bins, bin_number=b), b = 1,n_bins )]
associate(in_bin => [(bins(b)%last(b) - bins(b)%first(b) + 1, b = 1, n_bins)])
associate(in_bin => [(bins(b)%last() - bins(b)%first() + 1, b = 1, n_bins)])
associate(remainder => mod(n_items, n_bins), items_per_bin => n_items/n_bins)
test_passes = all([(in_bin(1:remainder) == items_per_bin + 1)]) .and. all([(in_bin(remainder+1:) == items_per_bin)])
end associate
Expand All @@ -69,7 +69,7 @@ function verify_all_items_partitioned() result(test_passes)
integer b

bins = [( bin_t(num_items=n_items, num_bins=n_bins, bin_number=b), b = 1,n_bins )]
test_passes = sum([(bins(b)%last(b) - bins(b)%first(b) + 1, b = 1, n_bins)]) == n_items
test_passes = sum([(bins(b)%last() - bins(b)%first() + 1, b = 1, n_bins)]) == n_items

end function

Expand Down