In [None]:
# Grid beam 1: assign every observation a cell index relative to the grid origin.
# NOTE(review): lon_min/lat_min/lon_max/lat_max/grid_res are assigned in the grid
# configuration cell, which appears *after* this one in this excerpt; on a fresh
# kernel (Restart & Run All) this cell needs them defined first — confirm they are
# set earlier in the notebook, or move the configuration cell above this one.
beam1_data.loc[:, 'grid_lon_idx'] = np.floor((beam1_data['Longitude'] - lon_min) / grid_res)
beam1_data.loc[:, 'grid_lat_idx'] = np.floor((beam1_data['Latitude'] - lat_min) / grid_res)

# Treat beam3, beam5 in the same way
# Group by grid-cell index and compute the mean SSHA of each cell.
grouped = beam1_data.groupby(['grid_lon_idx', 'grid_lat_idx'])
grid_averages = grouped['SSHA'].mean().reset_index()

# (n_lat, n_lon) grid; cells without observations stay NaN.
full_grid = np.full((int((lat_max - lat_min) / grid_res), int((lon_max - lon_min) / grid_res)), np.nan)

# Keep only cells that fall inside the grid. The lower bound matters: points west of
# lon_min or south of lat_min get negative indices, which numpy would otherwise wrap
# around to the opposite edge of the grid instead of discarding.
cell_lon_idx = grid_averages['grid_lon_idx'].to_numpy()
cell_lat_idx = grid_averages['grid_lat_idx'].to_numpy()
in_grid = (
    (cell_lon_idx >= 0) & (cell_lon_idx < full_grid.shape[1])
    & (cell_lat_idx >= 0) & (cell_lat_idx < full_grid.shape[0])
)
# groupby yields unique (lon, lat) pairs, so a single fancy-index assignment is safe.
full_grid[cell_lat_idx[in_grid].astype(int), cell_lon_idx[in_grid].astype(int)] = (
    grid_averages['SSHA'].to_numpy()[in_grid]
)

In [None]:
# Analysis grid: 1-degree cells spanning 100-140 E and 0-50 N.
grid_res = 1
lon_min, lon_max = 100, 140
lat_min, lat_max = 0, 50

# Cell edges along each axis (one more edge than cells) ...
lon_edges, lat_edges = (
    np.linspace(lo, hi, int((hi - lo) / grid_res) + 1)
    for lo, hi in ((lon_min, lon_max), (lat_min, lat_max))
)
# ... and cell centres as the midpoints of consecutive edges.
lon_centers, lat_centers = (
    (edges[:-1] + edges[1:]) / 2 for edges in (lon_edges, lat_edges)
)

# 2-D centre coordinates of shape (n_lat, n_lon): rows follow latitude, columns longitude.
grid_lon, grid_lat = np.meshgrid(lon_centers, lat_centers)

# Land mask at every cell centre.
# NOTE(review): `globe` is imported elsewhere (presumably global_land_mask) — confirm.
is_land = globe.is_land(grid_lat, grid_lon)

# Calendar months (1-12) grouped into four consecutive three-month seasons.
season_months = {
    season: list(range(first_month, first_month + 3))
    for season, first_month in (('JFM', 1), ('AMJ', 4), ('JAS', 7), ('OND', 10))
}

In [None]:
def calculate_seasonal_ssha(season_df, grid_shape=None, lon_origin=None, lat_origin=None, resolution=None):
    """Grid the mean SSHA of a set of observations (e.g. one season).

    Each observation is binned into the cell
    ``(floor((lat - lat_origin) / resolution), floor((lon - lon_origin) / resolution))``
    and the SSHA values of each cell are averaged. Observations whose cell lies
    outside the grid (including west of ``lon_origin`` / south of ``lat_origin``)
    are ignored.

    Parameters
    ----------
    season_df : pandas.DataFrame
        Observations with 'Longitude', 'Latitude' and 'SSHA' columns.
        The frame is not modified.
    grid_shape : tuple of int, optional
        ``(n_lat, n_lon)`` of the output grid. Defaults to the notebook-level
        ``full_grid.shape``.
    lon_origin, lat_origin : float, optional
        West / south edge of the grid. Default to the notebook-level
        ``lon_min`` / ``lat_min``.
    resolution : float, optional
        Cell size. Defaults to the notebook-level ``grid_res``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``grid_shape``; element ``[lat_idx, lon_idx]`` is the
        mean SSHA of that cell, NaN where the cell has no observations.
    """
    # Fall back to the notebook globals so existing one-argument calls are unchanged.
    if grid_shape is None:
        grid_shape = full_grid.shape
    if lon_origin is None:
        lon_origin = lon_min
    if lat_origin is None:
        lat_origin = lat_min
    if resolution is None:
        resolution = grid_res

    # assign() returns a new frame, so the caller's DataFrame (often a filtered
    # slice) is not mutated and no SettingWithCopyWarning is raised.
    indexed = season_df.assign(
        grid_lon_idx=np.floor((season_df['Longitude'] - lon_origin) / resolution),
        grid_lat_idx=np.floor((season_df['Latitude'] - lat_origin) / resolution),
    )

    # Group by grid-cell index and average SSHA per cell (NaN keys are dropped).
    cell_means = indexed.groupby(['grid_lon_idx', 'grid_lat_idx'])['SSHA'].mean().reset_index()

    temp_grid = np.full(grid_shape, np.nan)
    lon_idx = cell_means['grid_lon_idx'].to_numpy()
    lat_idx = cell_means['grid_lat_idx'].to_numpy()
    # Negative indices must be rejected explicitly: numpy would wrap them to the
    # opposite edge of the grid.
    inside = (
        (lon_idx >= 0) & (lon_idx < grid_shape[1])
        & (lat_idx >= 0) & (lat_idx < grid_shape[0])
    )
    # Each (lon, lat) pair is unique after groupby, so one vectorised write suffices.
    temp_grid[lat_idx[inside].astype(int), lon_idx[inside].astype(int)] = (
        cell_means['SSHA'].to_numpy()[inside]
    )
    return temp_grid

In [None]:
def count_data_points(df, grid_lon, grid_lat, grid_res):
    """Count the observations that fall in each grid cell.

    Cell ``(i, j)`` is the square of side ``grid_res`` centred on
    ``(grid_lon[i, j], grid_lat[i, j])``.

    Parameters
    ----------
    df : pandas.DataFrame
        Observations with 'Longitude' and 'Latitude' columns.
    grid_lon, grid_lat : numpy.ndarray
        Arrays of cell-centre coordinates with identical shape
        (e.g. the output of ``np.meshgrid``).
    grid_res : float
        Cell width, in the same units as the coordinates.

    Returns
    -------
    numpy.ndarray of int
        Array shaped like ``grid_lon`` holding the number of points per cell.

    Notes
    -----
    Cells are half-open ``[lo, hi)`` on both axes, so a point lying exactly on an
    edge shared by two cells is counted once (in the cell on its upper side)
    instead of twice. The outermost upper edges of the grid are closed, so points
    exactly on the grid's maximum longitude/latitude are still counted. Rows with
    NaN coordinates are never counted.
    """
    count_grid = np.zeros(grid_lon.shape, dtype=int)
    if grid_lon.size == 0:
        return count_grid

    # Convert once instead of building a filtered DataFrame for every cell.
    lons = df['Longitude'].to_numpy()
    lats = df['Latitude'].to_numpy()
    half = grid_res / 2
    # Upper edge of the outermost cells; only these edges are inclusive.
    lon_top = grid_lon.max() + half
    lat_top = grid_lat.max() + half

    for (i, j), lon_c in np.ndenumerate(grid_lon):
        lat_c = grid_lat[i, j]
        lon_lo, lon_hi = lon_c - half, lon_c + half
        lat_lo, lat_hi = lat_c - half, lat_c + half

        in_lon = (lons >= lon_lo) & ((lons < lon_hi) | ((lon_hi >= lon_top) & (lons == lon_hi)))
        in_lat = (lats >= lat_lo) & ((lats < lat_hi) | ((lat_hi >= lat_top) & (lats == lat_hi)))

        count_grid[i, j] = np.count_nonzero(in_lon & in_lat)

    return count_grid

In [None]:
titles = ['Beam 3 - Beam 1', 'Beam 5 - Beam 1', 'Beam 5 - Beam 3']
differences = [diff_ssha_beam3_beam1, diff_ssha_beam5_beam1, diff_ssha_beam5_beam3]

# Gap left between adjacent histogram bars (x-axis units); adjust as needed.
space_between_bars = 0.0006

# Shared bin edges so the three panels are directly comparable. Differences outside
# [fixed_bins[0], fixed_bins[-1]] are not counted in any bar.
fixed_bins = np.arange(-0.06, 0.07, 0.005)

# Lower bound for the shared y-axis limit (grid-cell count).
max_count = 500

# Optional decorations, all off by default (they were commented out before).
show_stats_text = False          # mean/std annotation box in each panel
show_labels = False              # panel titles and axis labels
hide_inner_yticklabels = False   # hide y tick labels on the 2nd and 3rd panels

# Histogram every difference field first so one y-limit can be shared by all panels.
valid_diffs = [diff[~np.isnan(diff)] for diff in differences]
histograms = [np.histogram(valid, bins=fixed_bins) for valid in valid_diffs]
# Raise the shared limit if any bin would otherwise be clipped above max_count.
y_top = max([max_count] + [int(counts.max()) for counts, _ in histograms])

fig, axs = plt.subplots(1, 3, figsize=(30, 10))

for i, (ax, valid_diff, (counts, edges), title) in enumerate(
    zip(axs, valid_diffs, histograms, titles)
):
    bar_width = np.diff(edges) - space_between_bars
    ax.bar(edges[:-1], counts, width=bar_width, edgecolor='black', color='#87CEFA', align='edge')

    # Dashed vertical line at the mean difference; skipped when there is no valid data.
    if valid_diff.size:
        mean = np.mean(valid_diff)
        std = np.std(valid_diff)
        ax.axvline(mean, color='black', linestyle='dashed', linewidth=4)
        if show_stats_text:
            # Backslashes escaped: '\m' / '\s' are invalid escape sequences in a normal string.
            info_text = f'$\\mu$={mean:.2f}m\n$\\sigma$={std:.2f}m'
            ax.text(0.05, 0.95, info_text, transform=ax.transAxes, fontsize=16,
                    verticalalignment='top', bbox=dict(facecolor='white', alpha=0.5))

    ax.set_ylim(0, y_top)

    if hide_inner_yticklabels and i > 0:
        ax.set_yticklabels([])

    if show_labels:
        ax.set_title(title)
        ax.set_xlabel('SSHA Difference (m)')
        ax.set_ylabel('Grid Cell Count')

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.tick_params(which='major', direction='in', length=15, width=2, pad=10)
    ax.set_xticks(np.arange(-0.05, 0.06, 0.05))

plt.tight_layout()
plt.show()
    
