diff --git a/src/sourcery/bin_m.f90 b/src/sourcery/bin_m.f90 index cd533d7f..edd396f1 100644 --- a/src/sourcery/bin_m.f90 +++ b/src/sourcery/bin_m.f90 @@ -26,19 +26,17 @@ elemental module function construct(num_items, num_bins, bin_number) result(bin) interface - elemental module function first(self, bin_number) result(first_item_number) + elemental module function first(self) result(first_item_number) !! the result is the first item number associated with the given bin implicit none class(bin_t), intent(in) :: self - integer, intent(in) :: bin_number integer first_item_number end function - elemental module function last(self, bin_number) result(last_item_number) + elemental module function last(self) result(last_item_number) !! the result is the last item number associated with the given bin implicit none class(bin_t), intent(in) :: self - integer, intent(in) :: bin_number integer last_item_number end function diff --git a/src/sourcery/data_partition_m.f90 b/src/sourcery/data_partition_m.f90 index a310506e..70f3dd73 100644 --- a/src/sourcery/data_partition_m.f90 +++ b/src/sourcery/data_partition_m.f90 @@ -18,8 +18,6 @@ module data_partition_m generic :: gather => gather_real32_2D_array, gather_real64_2D_array, gather_real32_1D_array, gather_real64_1D_array end type - integer, allocatable :: first_datum(:), last_datum(:) - interface module subroutine define_partitions(cardinality) diff --git a/src/sourcery/data_partition_s.f90 b/src/sourcery/data_partition_s.f90 index 65078141..bafdc0aa 100644 --- a/src/sourcery/data_partition_s.f90 +++ b/src/sourcery/data_partition_s.f90 @@ -1,51 +1,26 @@ submodule(data_partition_m) data_partition_s use assert_m, only : assert + use bin_m, only : bin_t implicit none logical, parameter :: verbose=.false. + type(bin_t), allocatable :: bin(:) contains module procedure define_partitions - - if (allocated(first_datum)) deallocate(first_datum) - if (allocated(last_datum)) deallocate(last_datum) - - associate( ni => num_images() ) - - call assert( ni<=cardinality, "sufficient data for distribution across images", cardinality) - - allocate(first_datum(ni), last_datum(ni)) - - block - integer i, image - do image=1,ni - associate( remainder => mod(cardinality, ni), quotient => cardinality/ni ) - first_datum(image) = sum([(quotient+overflow(i, remainder), i=1, image-1)]) + 1 - last_datum(image) = first_datum(image) + quotient + overflow(image, remainder) - 1 - end associate - end do - end block - end associate - - contains - - pure function overflow(im, excess) result(extra_datum) - integer, intent(in) :: im, excess - integer extra_datum - extra_datum= merge(1,0,im<=excess) - end function - + integer image + bin = [( bin_t(num_items=cardinality, num_bins=num_images(), bin_number=image), image=1,num_images() )] end procedure module procedure first - call assert( allocated(first_datum), "allocated(first_datum)") - first_index= first_datum( image_number ) + call assert( allocated(bin), "data_partition_s(first): allocated(bin)") + first_index = bin(image_number)%first() end procedure module procedure last - call assert( allocated(last_datum), "allocated(last_datum)") - last_index = last_datum( image_number ) + call assert( allocated(bin), "data_partition_s(last): allocated(bin)") + last_index = bin(image_number)%last() end procedure module procedure gather_real32_1D_array diff --git a/test/bin_test.f90 b/test/bin_test.f90 index 38be4e30..2ae7f74d 100644 --- a/test/bin_test.f90 +++ b/test/bin_test.f90 @@ -51,7 +51,7 @@ function verify_block_partitioning() result(test_passes) integer b bins = [( bin_t(num_items=n_items, num_bins=n_bins, bin_number=b), b = 1,n_bins )] - associate(in_bin => [(bins(b)%last(b) - bins(b)%first(b) + 1, b = 1, n_bins)]) + associate(in_bin => [(bins(b)%last() - bins(b)%first() + 1, b = 1, n_bins)]) associate(remainder => mod(n_items, n_bins), items_per_bin => n_items/n_bins) test_passes = all([(in_bin(1:remainder) == items_per_bin + 1)]) .and. all([(in_bin(remainder+1:) == items_per_bin)]) end associate @@ -69,7 +69,7 @@ function verify_all_items_partitioned() result(test_passes) integer b bins = [( bin_t(num_items=n_items, num_bins=n_bins, bin_number=b), b = 1,n_bins )] - test_passes = sum([(bins(b)%last(b) - bins(b)%first(b) + 1, b = 1, n_bins)]) == n_items + test_passes = sum([(bins(b)%last() - bins(b)%first() + 1, b = 1, n_bins)]) == n_items end function