# In [None]:
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
from osgeo import gdal
from osgeo import osr
#import gdal
#import osr
import os
import time
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, Executor



class Dataset:
    """Thin wrapper around a GDAL-readable raster (e.g. GeoTIFF or ENVI).

    The constructor reads only the raster metadata; pixel values are loaded
    on demand by ``get_data``.

    Attributes:
        in_file: Path of the raster file.
        XSize: Number of pixel columns.
        YSize: Number of pixel rows.
        Bands: Number of raster bands.
        GeoTransform: GDAL six-element affine geotransform.
        ProjectionInfo: Projection string returned by ``GetProjection()``.
    """

    def __init__(self, in_file):
        self.in_file = in_file  # GeoTIFF or ENVI file

        dataset = self._open()
        self.XSize = dataset.RasterXSize  # number of pixels along X (columns)
        self.YSize = dataset.RasterYSize  # number of pixels along Y (rows)
        self.Bands = dataset.RasterCount  # number of bands
        self.GeoTransform = dataset.GetGeoTransform()  # affine geotransform
        self.ProjectionInfo = dataset.GetProjection()  # projection info

    def _open(self):
        """Open ``self.in_file`` with GDAL, raising if it cannot be opened.

        Without ``gdal.UseExceptions()``, ``gdal.Open`` returns ``None`` on
        failure, which would otherwise surface later as an unrelated
        ``AttributeError``.

        Raises:
            OSError: GDAL could not open the file.
        """
        dataset = gdal.Open(self.in_file)
        if dataset is None:
            raise OSError('GDAL could not open raster file: {!r}'.format(self.in_file))
        return dataset

    def get_data(self):
        """Read the full extent of every band into a NumPy array.

        Returns:
            The array produced by GDAL's ``ReadAsArray``: shape
            ``(Bands, YSize, XSize)`` for multi-band rasters (a single-band
            raster yields a 2-D ``(YSize, XSize)`` array).
        """
        dataset = self._open()
        return dataset.ReadAsArray(0, 0, self.XSize, self.YSize)

    def get_lon_lat(self):
        """Return the georeferenced coordinates of every pixel.

        Applies the affine geotransform to integer pixel indices::

            X = gt[0] + col * gt[1] + row * gt[2]
            Y = gt[3] + col * gt[4] + row * gt[5]

        Integer indices give the upper-left corner of each pixel. X/Y are
        longitude/latitude only when the raster uses a geographic CRS.

        Returns:
            Float array of shape ``(YSize, 2, XSize)``: ``[:, 0, :]`` is the
            X (lon) grid and ``[:, 1, :]`` the Y (lat) grid.
        """
        gtf = self.GeoTransform
        cols, rows = np.meshgrid(np.arange(self.XSize), np.arange(self.YSize))
        lon = gtf[0] + cols * gtf[1] + rows * gtf[2]
        lat = gtf[3] + cols * gtf[4] + rows * gtf[5]
        # Same layout as np.array(list(zip(lon, lat))), without building an
        # intermediate Python list of per-row array pairs.
        return np.stack((lon, lat), axis=1)
        

def func(x, m1, m2, m3, m4, m5, m6):
    """Double-logistic model used as the ``curve_fit`` target.

    Computes ``m1 + m2 * s(x; m3, m4) - m2 * s(x; m5, m6)``, where
    ``s(x; k, c) = 1 / (1 + exp(-k * (x - c)))`` is a sigmoid with slope
    ``k`` and midpoint ``c``. ``m1`` is the level reached where the two
    sigmoids cancel and ``m2`` scales both of them.
    NOTE(review): with the initial guess used by the fitting code (m4=140,
    m6=270) the first sigmoid presumably models the seasonal rise and the
    second the decline -- confirm intended interpretation.
    """
    rise = m2 / (1 + np.exp(-m3 * (x - m4)))
    fall = m2 / (1 + np.exp(-m5 * (x - m6)))
    return m1 + rise - fall

#注意初值
def get_param(yData):
    yData1=yData[0:2]
    yData2=yData[2::]
    xData=np.linspace(1, 365, 92)
    Parameters, pcov = curve_fit(func, xData, yData2, p0=[5,40,0.1,140,0.1,270], maxfev=100000000)#,method='trf', maxfev=1000000)
    result=np.hstack((yData1,Parameters))
    return result


if __name__ == '__main__':
    # Fit the double-logistic model to every pixel of a multi-band raster
    # that is not entirely zero. Write one text row per pixel:
    # lon, lat, m1 ... m6.
    # NOTE(review): ProcessPoolExecutor workers must be able to import
    # get_param/func. Under the spawn start method (Windows) this
    # presumably requires running this as a .py script, not inside a
    # notebook -- confirm.
    start_time = time.time()
    dir_path = r"D:\Desktop\mypaper\data"
    filename = "gee-LAI-297571.tif"
    file_path = os.path.join(dir_path, filename)
    # Raw string: the previous non-raw literal only worked because "\D" and
    # "\g" are invalid escapes (SyntaxWarning since Python 3.12).
    out_path = r'D:\Desktop\gee-LAI-297571-remove.txt'

    raster = Dataset(file_path)
    # (bands, Y, X) -> (Y, X, bands)
    data = raster.get_data().transpose(1, 2, 0)
    # (Y, 2, X) -> (Y, X, 2)
    lon_lat = raster.get_lon_lat().transpose(0, 2, 1)
    # (Y, X, 2 + bands): [lon, lat, v_1, ..., v_n] for every pixel
    lon_lat_data = np.concatenate((lon_lat, data), axis=2)
    dataset_3Dto2D = lon_lat_data.reshape(-1, lon_lat_data.shape[2])
    # Drop pixels whose band values are all zero.
    has_signal = ~np.all(dataset_3Dto2D[:, 2:] == 0, axis=1)
    last_dataset = dataset_3Dto2D[has_signal]

    # chunksize batches many rows per inter-process message. With the
    # default of 1, every pixel is pickled and sent on its own. Result order
    # is unchanged. The with-block shuts the pool down even if a task raises.
    with ProcessPoolExecutor(max_workers=8) as pool:
        res = np.array(list(pool.map(get_param, last_dataset, chunksize=256)))
    end_time = time.time()
    np.savetxt(out_path, res, fmt='%1.3f')

    # 'pool time' covers loading + fitting; 'image' covers writing the output.
    print('pool time :', end_time - start_time, 'seconds')
    print('image:', time.time() - end_time, 'seconds')
