In [6]:
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
from osgeo import gdal
import os
import time

In [2]:
class Dataset:
    def __init__(self, in_file):
        self.in_file = in_file  # Tiff或者ENVI文件

        dataset = gdal.Open(self.in_file)
        self.XSize = dataset.RasterXSize  # 网格的X轴像素数量
        self.YSize = dataset.RasterYSize  # 网格的Y轴像素数量
        self.Bands = dataset.RasterCount  # 波段数
        self.GeoTransform = dataset.GetGeoTransform()  # 投影转换信息
        self.ProjectionInfo = dataset.GetProjection()  # 投影信息
    
    def get_data(self):
        #band: 读取第几个通道的数据
        dataset = gdal.Open(self.in_file)
        data = dataset.ReadAsArray(0,0,self.XSize,self.YSize)
        return data
    

    def get_lon_lat(self):
        #获取经纬度信息
        gtf = self.GeoTransform
        x_range = range(0, self.XSize)
        y_range = range(0, self.YSize)
        x, y = np.meshgrid(x_range, y_range)
        lon = gtf[0] + x * gtf[1] + y * gtf[2]
        lat = gtf[3] + x * gtf[4] + y * gtf[5]
        
        lon_lat=[]
        for (longitude,latitude) in zip(lon,lat):
            lon_lat.append(list(zip(longitude,latitude)))
            
        return np.array(lon_lat)
    
    def new_dataset(self,data,lon_lat):
        new_dataset=[]
        for i in range(self.YSize):
            for j in range(self.XSize):
                x1 = lon_lat[i,j,:]
                x2 = data[i,j,:]
                x=np.hstack((x1,x2))
                new_dataset.append(x)
            
        return np.array(new_dataset)
    
    def dataset2dim(self,data):
        dataset2dim=[]
        for i in range(self.YSize):
            for j in range(self.XSize):
                x1 = lon_lat[i,j]
                x2 = data[:,i,j]
                x=np.hstack((x1,x2))
                dataset2dim.append(x2)

        return np.array(dataset2dim)

In [3]:
def func(x, m1, m2, m3, m4, m5, m6):
    return m1 + m2 /(1 + np.exp(-m3 * (x-m4))) - m2/(1 + np.exp(-m5 * (x-m6)))

#注意初值
def fitting(yData, xData) :
#     Parameters, pcov = curve_fit(func, xData, yData, p0=[900,5000,0.1,140,0.1,270], maxfev=10000000)#,method='trf', maxfev=1000000)
#     return Parameters

###########################################################
    start_time1 = time.time()
    for i in range(1,1000000):

        try:
            Parameters, pcov = curve_fit(func, xData, yData, p0=[5,40,0.1,140,0.1,270], maxfev=i)#,method='trf', maxfev=1000000)
        except RuntimeError: 
            continue
        if Parameters is not None:
            break
    totaltime=time.time()-start_time1
    totalparm=np.hstack((Parameters,[i,totaltime]))
    return totalparm
###########################################################

def for_map(yData, xData):
    result=np.array(list(map(fitting,yData,xData)))
 
    return result


In [4]:
start=time.time()

dir_path = r"D:\Desktop\mypaper\data"
filename = "gee-LAI-108.tif"
file_path = os.path.join(dir_path, filename)
dataset = Dataset(file_path)

data = dataset.get_data( ).transpose(1,2,0) 

# lon_lat = dataset.get_lon_lat()  
# new_dataset=dataset.new_dataset(data,lon_lat)

In [5]:
print(data.shape)

(13, 12, 92)


In [6]:
xdim=data.shape[0]
ydim=data.shape[1]
zdim=data.shape[2]
print(xdim,ydim,zdim)

13 12 92


In [7]:
xInput=np.linspace(1, 361, 92)
xData=np.repeat([np.repeat([xInput],ydim,axis=0)],xdim,axis=0)
print('xData',xData.shape)

xData (13, 12, 92)


In [15]:
result=np.array(list(map(for_map,data,xData)))
lon_lat = dataset.get_lon_lat()   # 获取经纬度信息longitude, latitude
new_dataset=dataset.new_dataset(result,lon_lat)

print("Thread pool execution in " + str(time.time() - start_time), "seconds")
writer = pd.ExcelWriter(r"/content/drive/My Drive/tif/get_param_python/LAI_10884_python_param_test.xls")
pd.DataFrame(new_dataset).to_excel(writer,na_rep=0,index=False,header=['lon','lat','p0','p1','p2','p3','p4','p5','item','time'])
writer.save()
writer.close()
print('total time',time.time()-start)
print(lon_lat.shape)
print(result.shape)
# print(result)

  


total time 296.1302101612091
(13, 12, 2)
(13, 12, 8)


In [16]:
print(new_dataset.shape)

(156, 10)
