In [None]:
from mpi4py import MPI
import numpy as np
from osgeo import gdal
from osgeo import osr
import os



class Dataset:
    """Thin wrapper around a GDAL raster file (GeoTIFF or ENVI).

    The constructor opens the file once to cache its metadata; pixel values are
    read on demand by :meth:`get_data`, which reopens the file.

    Attributes:
        in_file: Path of the raster file.
        XSize: Number of pixels along the X axis (raster columns).
        YSize: Number of pixels along the Y axis (raster rows).
        Bands: Number of raster bands.
        GeoTransform: GDAL 6-element affine geotransform tuple.
        ProjectionInfo: Projection string returned by ``GetProjection()``.
    """

    def __init__(self, in_file):
        """Open ``in_file`` and cache its size, band count and georeferencing.

        Raises:
            RuntimeError: if GDAL cannot open ``in_file``.
        """
        self.in_file = in_file  # GeoTIFF or ENVI file

        dataset = self._open()
        self.XSize = dataset.RasterXSize  # number of pixels along X
        self.YSize = dataset.RasterYSize  # number of pixels along Y
        self.Bands = dataset.RasterCount  # number of bands
        self.GeoTransform = dataset.GetGeoTransform()  # affine pixel -> map transform
        self.ProjectionInfo = dataset.GetProjection()  # projection definition

    def _open(self):
        """Open ``self.in_file`` with GDAL, raising instead of returning None."""
        dataset = gdal.Open(self.in_file)
        # Without gdal.UseExceptions(), gdal.Open signals failure by returning
        # None; fail here with a clear message instead of an AttributeError later.
        if dataset is None:
            raise RuntimeError('GDAL could not open raster file: {}'.format(self.in_file))
        return dataset

    def get_data(self, band=None):
        """Read pixel values from the raster.

        Args:
            band: 1-based index of a single band to read. ``None`` (the default)
                reads the whole raster with ``Dataset.ReadAsArray`` over the
                full extent, exactly as before.

        Returns:
            numpy.ndarray: the full-extent array from GDAL when ``band`` is None,
            otherwise the 2-D array of the requested band.

        Raises:
            ValueError: if ``band`` is not in ``1..Bands``.
            RuntimeError: if GDAL cannot open the file.
        """
        if band is not None and not 1 <= band <= self.Bands:
            raise ValueError('band must be in 1..{}, got {}'.format(self.Bands, band))
        dataset = self._open()
        if band is None:
            return dataset.ReadAsArray(0, 0, self.XSize, self.YSize)
        return dataset.GetRasterBand(band).ReadAsArray()

    def get_lon_lat(self):
        """Compute map coordinates for every pixel from the geotransform.

        Applies GDAL's affine model to integer pixel indices:
            X = GT[0] + col * GT[1] + row * GT[2]
            Y = GT[3] + col * GT[4] + row * GT[5]
        so the values refer to each pixel's upper-left corner. They are in the
        units of ``ProjectionInfo``, i.e. longitude/latitude only when the
        raster uses a geographic CRS.

        Returns:
            numpy.ndarray of shape ``(YSize, 2, XSize)``: ``result[r, 0]`` holds
            the X (lon) values of raster row ``r`` and ``result[r, 1]`` its
            Y (lat) values.
        """
        gtf = self.GeoTransform
        cols, rows = np.meshgrid(np.arange(self.XSize), np.arange(self.YSize))
        lon = gtf[0] + cols * gtf[1] + rows * gtf[2]
        lat = gtf[3] + cols * gtf[4] + rows * gtf[5]
        # Pair each row of lon with the matching row of lat -> (YSize, 2, XSize).
        lon_lat = np.array(list(zip(lon, lat)))
        return lon_lat

def split_rows(n_rows, n_cols, nprocs):
    """Split the rows of an ``(n_rows, n_cols)`` array as evenly as possible over ``nprocs`` ranks.

    The first ``n_rows % nprocs`` ranks receive one extra row, and rows are never
    split across ranks.

    Args:
        n_rows: Number of rows in the array being distributed.
        n_cols: Number of columns (elements per row).
        nprocs: Number of ranks (must be >= 1).

    Returns:
        tuple ``(rows_per_rank, counts, displs)`` of int64 arrays of length
        ``nprocs``: rows per rank, element counts per rank, and the element
        offset where each rank's slice starts (the Scatterv counts/displacements).
    """
    base, extra = divmod(n_rows, nprocs)
    rows_per_rank = np.full(nprocs, base, dtype=np.int64)
    rows_per_rank[:extra] += 1
    counts = rows_per_rank * n_cols
    # Exclusive prefix sum: rank p starts after all elements of ranks 0..p-1.
    displs = np.zeros(nprocs, dtype=np.int64)
    displs[1:] = np.cumsum(counts)[:-1]
    return rows_per_rank, counts, displs


if __name__ == '__main__':

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    nprocs = comm.Get_size()

    if rank == 0:
        sendbuf = np.random.random((21, 5))
        n_cols = sendbuf.shape[1]
        each_nprocess_row, total_number, star_index = split_rows(sendbuf.shape[0], n_cols, nprocs)
    else:
        sendbuf = None
        star_index = None
        n_cols = None
        # Receive buffers for the Bcasts below. The dtype must match the root's
        # int64 arrays (np.int was removed in NumPy 1.24).
        total_number = np.zeros(nprocs, dtype=np.int64)
        each_nprocess_row = np.zeros(nprocs, dtype=np.int64)

    # Every rank needs the per-rank sizes to allocate its receive buffer.
    comm.Bcast(total_number, root=0)
    comm.Bcast(each_nprocess_row, root=0)
    n_cols = comm.bcast(n_cols, root=0)

    # float64 receive buffer, matching MPI.DOUBLE in the Scatterv spec.
    recvbuf = np.zeros((each_nprocess_row[rank], n_cols))

    # sendbuf / counts / displacements are only significant on the root.
    comm.Scatterv([sendbuf, total_number, star_index, MPI.DOUBLE], recvbuf, root=0)

    print('After Scatterv, process {} has data:'.format(rank))
    print(recvbuf.shape)

In [1]:
import numpy as np
# Re-run the rank-0 partitioning arithmetic outside MPI, with a fixed rank count,
# to inspect the intermediate values.
nprocs = 4
sendbuf = np.random.random((21, 5))
colum2 = sendbuf.shape[1]

# Rows per rank: dividing the element count by (ranks * columns) gives the base
# row count `ave`. The remainder `res` is always a whole number of rows
# (so `res1` is 0), and `ave1` = rows % nprocs ranks get one extra row.
ave, res = divmod(sendbuf.size, nprocs * colum2)
ave1, res1 = divmod(res, colum2)
each_nprocess_row = np.full(nprocs, ave)
each_nprocess_row[:ave1] += 1

# Elements per rank, and the exclusive prefix sum giving each rank's start offset.
total_number = each_nprocess_row * colum2
star_index = np.concatenate(([0], np.cumsum(total_number)[:-1]))

In [3]:
# Show the quotient/remainder pairs, then the per-rank rows, element counts and offsets.
for quotient_remainder in ((ave, res), (ave1, res1)):
    print(*quotient_remainder)
print(each_nprocess_row, total_number, star_index, sep='\n')

5 5
1 0
[6 5 5 5]
[30 25 25 25]
[ 0 30 55 80]
