In [1]:
import xarray as xr

In [2]:
gnss_file=xr.open_dataset(r'E:/gnss_ztd_combined_robust_morestation.nc')
wind_file=xr.open_dataset(r'E:/merged_stations_6min_common_period_float32_target_height.nc')

In [3]:
import pandas as pd
wind_new_time = pd.date_range(wind_file.Datetime.values[0], wind_file.Datetime.values[-1], freq='5T')
wind_file = wind_file.interp(Datetime=wind_new_time)

  wind_new_time = pd.date_range(wind_file.Datetime.values[0], wind_file.Datetime.values[-1], freq='5T')


In [4]:
import numpy as np
print(np.array(gnss_file['time']))
print(np.array(wind_file['Datetime']))

['2025-05-14T05:40:00.000000000' '2025-05-14T05:45:00.000000000'
 '2025-05-14T05:50:00.000000000' ... '2025-07-21T00:45:00.000000000'
 '2025-07-21T00:50:00.000000000' '2025-07-21T00:55:00.000000000']
['2024-07-29T09:30:00.000000000' '2024-07-29T09:35:00.000000000'
 '2024-07-29T09:40:00.000000000' ... '2025-07-21T23:40:00.000000000'
 '2025-07-21T23:45:00.000000000' '2025-07-21T23:50:00.000000000']


In [5]:
# 先插值填充内部 NaN
gnss_file = gnss_file.interpolate_na(dim='time')

# 再前向填充和后向填充
gnss_file = gnss_file.ffill(dim='time').bfill(dim='time')

In [6]:
print(np.isnan(gnss_file['ztd']).any())
print(np.isnan(wind_file['U']).any())
print(np.isnan(wind_file['V']).any())
print(np.isnan(wind_file['W']).any())

<xarray.DataArray 'ztd' ()> Size: 1B
array(False)
<xarray.DataArray 'U' ()> Size: 1B
array(False)
<xarray.DataArray 'V' ()> Size: 1B
array(False)
<xarray.DataArray 'W' ()> Size: 1B
array(False)


In [7]:
import xarray as xr
import numpy as np
import pandas as pd
import time # 用于性能比较


def prepare_transformer_inputs_mem_efficient_robust(gnss_ds: xr.Dataset, wind_ds: xr.Dataset):
    """
    更稳健的版本，处理时间戳 dtype 不匹配的问题。
    """
    start_time = time.time()
    
    gnss_ztd = gnss_ds['ztd']
    wind_u = wind_ds['U']
    
    sequence_length = 6
    time_step = pd.to_timedelta('5min')
    expected_duration = time_step * (sequence_length - 1)

    num_gnss_times = len(gnss_ztd.time)
    
    # --- Pass 1: 扫描并找到所有有效样本的起始索引 ---
    print("Pass 1: 正在扫描有效的连续序列 (稳健模式)...")
    
    # 打印 dtype 以供参考
    gnss_dtype = gnss_ztd.time.dtype
    wind_dtype = wind_u.Datetime.dtype
    print(f"GNSS 时间 dtype: {gnss_dtype}, Wind 时间 dtype: {wind_dtype}")
    # 获取 wind_dtype 的单位，如 'ns', 'us'
    wind_dtype_unit = np.datetime_data(wind_dtype)[0]

    valid_start_indices = []
    wind_times_set = set(wind_u.Datetime.values)

    for i in range(num_gnss_times - sequence_length + 1):
        window_times = gnss_ztd.time[i : i + sequence_length]
        
        # 稳健地比较时长，避免微小浮点误差
        actual_duration = window_times[-1].values - window_times[0].values
        if np.abs(actual_duration - expected_duration) < pd.to_timedelta('1s'):
            
            # 计算目标时间
            target_wind_time_raw = window_times[-1].values + 6*time_step
            
            # 【关键修复】: 将计算出的时间强制转换为与风场数据相同的时间类型/精度
            target_wind_time_converted = np.datetime64(target_wind_time_raw, wind_dtype_unit)

            if target_wind_time_converted in wind_times_set:
                valid_start_indices.append(i)
    
    num_samples = len(valid_start_indices)
    print(f"Pass 1 完成. 共找到 {num_samples} 个有效样本。耗时: {time.time() - start_time:.2f} 秒。")

    if num_samples == 0:
        print("在稳健模式下仍然未找到任何有效序列。请检查数据本身，例如风场数据是否覆盖了GNSS数据的时间范围。")
        return None, None

    # --- 内存分配 (Pass 2) ---
    # ... (这部分代码与之前相同，无需更改) ...
    # 为了完整性，我将其余代码也附上
    print("正在预分配内存...")
    num_stations_gnss = len(gnss_ztd.station)
    num_stations_wind = len(wind_u.station)
    num_press_levels = len(wind_u.HEIGHT)
    
    vx_data = np.empty((num_samples, sequence_length, num_stations_gnss), dtype=np.float32)
    vy_data = np.empty((num_samples, num_stations_wind * num_press_levels), dtype=np.float32)

    print("Pass 2: 正在填充数据...")
    fill_start_time = time.time()
    gnss_ztd_values = gnss_ztd.values
    
    for k, start_idx in enumerate(valid_start_indices):
        end_idx = start_idx + sequence_length
        vx_data[k, :, :] = gnss_ztd_values[start_idx:end_idx, :]
        
        last_gnss_time = gnss_ztd.time[end_idx - 1]
        target_wind_time = last_gnss_time.values + 6*time_step
        
        # 在这里也需要使用转换后的时间来索引
        target_wind_time_converted = np.datetime64(target_wind_time, wind_dtype_unit)
        vy_slice_values = wind_u.sel(Datetime=target_wind_time_converted).values
        vy_data[k, :] = vy_slice_values.flatten()
        
    print(f"Pass 2 完成. 数据填充完毕。耗时: {time.time() - fill_start_time:.2f} 秒。")

    print("正在创建最终的 xarray.DataArray...")
    sample_coords = gnss_ztd.time.values[valid_start_indices]
    vx = xr.DataArray(
        vx_data,
        dims=('sample', 'timesteps', 'station'),
        coords={'sample': sample_coords, 'timesteps': np.arange(sequence_length), 'station': gnss_ztd.station.values}
    )

    vy_flat_coords = wind_u.stack(station_HEIGHT_flat=('station', 'HEIGHT')).coords['station_HEIGHT_flat']
    vy = xr.DataArray(
        vy_data,
        dims=('sample', 'station_HEIGHT_flat'),
        coords={'sample': sample_coords, 'station_HEIGHT_flat': vy_flat_coords}
    )
    
    total_time = time.time() - start_time
    print(f"所有处理完成！总耗时: {total_time:.2f} 秒。")
    return vx, vy

# --- 主程序 ---
if __name__ == '__main__':
    

    # 调用内存优化后的函数
    vx, vy = prepare_transformer_inputs_mem_efficient_robust(gnss_file, wind_file)

    if vx is not None and vy is not None:
        print("\n--- 处理后结果 ---")
        print("输入变量 vx:")
        print(f"  - 形状: {vx.shape}")
        print(f"  - 维度: {vx.dims}")
        print(f"  - 内存占用: {vx.nbytes / 1e6:.2f} MB")
        
        print("\n目标变量 vy:")
        print(f"  - 形状: {vy.shape}")
        print(f"  - 维度: {vy.dims}")
        print(f"  - 内存占用: {vy.nbytes / 1e6:.2f} MB")

Pass 1: 正在扫描有效的连续序列 (稳健模式)...
GNSS 时间 dtype: datetime64[ns], Wind 时间 dtype: datetime64[ns]
Pass 1 完成. 共找到 17594 个有效样本。耗时: 16.91 秒。
正在预分配内存...
Pass 2: 正在填充数据...
Pass 2 完成. 数据填充完毕。耗时: 9.47 秒。
正在创建最终的 xarray.DataArray...
所有处理完成！总耗时: 26.39 秒。

--- 处理后结果 ---
输入变量 vx:
  - 形状: (17594, 6, 1215)
  - 维度: ('sample', 'timesteps', 'station')
  - 内存占用: 513.04 MB

目标变量 vy:
  - 形状: (17594, 126)
  - 维度: ('sample', 'station_HEIGHT_flat')
  - 内存占用: 8.87 MB


In [8]:
print(vx,vy)

<xarray.DataArray (sample: 17594, timesteps: 6, station: 1215)> Size: 513MB
array([[[2.531    , 2.8242776, 2.538875 , ..., 2.480333 , 2.29475  ,
         2.3930626],
        [2.531    , 2.8242776, 2.538875 , ..., 2.480333 , 2.29475  ,
         2.3930626],
        [2.531    , 2.8242776, 2.538875 , ..., 2.480333 , 2.29475  ,
         2.3930626],
        [2.531    , 2.8242776, 2.538875 , ..., 2.480333 , 2.29475  ,
         2.3930626],
        [2.531    , 2.8242776, 2.538875 , ..., 2.480333 , 2.29475  ,
         2.3930626],
        [2.531    , 2.8242776, 2.538875 , ..., 2.480333 , 2.29475  ,
         2.3930626]],

       [[2.531    , 2.8242776, 2.538875 , ..., 2.480333 , 2.29475  ,
         2.3930626],
        [2.531    , 2.8242776, 2.538875 , ..., 2.480333 , 2.29475  ,
         2.3930626],
        [2.531    , 2.8242776, 2.538875 , ..., 2.480333 , 2.29475  ,
         2.3930626],
        [2.531    , 2.8242776, 2.538875 , ..., 2.480333 , 2.29475  ,
...
         2.6803334],
        [2.6883326

In [9]:
#Transformer网络
def Auto_Transformer(vy,vx,timestep,model_list,test_size=0.2,valid_size=0.1,k_fold=None,task_mode='regression',if_best_mode='no',modelpath=None,encoder_deep=1,num_heads=2,key_dim=2,ifdropout='no',trans_dropout_rate=0.0,trans_units=64,trans_activation='sigmoid',embedding_num=None,if_weight_initialize='no',weight_initialize_method='TruncatedNormal',weight_initialize_parameter1=0.00,weight_initialize_parameter2=0.05,if_print_model='yes',loss_function='default',optimizer='SGD',metrics='default',if_early_stopping=None,learning_rate=0.01,epochs=2000,batch_size=20,ifrandom_split='yes',ifweight='yes',ifmute='no',ifsave='no',savepath=None,device='cpu'):
    import tensorflow as tf
    if device=='gpu':
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            try:
                # 设置只使用 GPU 0
                tf.config.set_visible_devices(gpus[0], 'GPU')
                # 设置 GPU 0 的内存动态增长
                tf.config.experimental.set_memory_growth(gpus[0], True)
            except RuntimeError as e:
                print(e)
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    from keras.models import Sequential,Model
    from keras.layers.core import Activation,Dropout,Dense
    from keras.layers import Input,BatchNormalization,LayerNormalization,Embedding,Add,MultiHeadAttention,Flatten
    from keras.initializers import TruncatedNormal,RandomNormal,RandomUniform
    from sklearn.model_selection import train_test_split
    from sklearn.model_selection import KFold
    import numpy as np
    from tensorflow.keras.optimizers import SGD,Adam
    import keras
    from scipy.stats import pearsonr
    import os
    from sklearn.metrics import accuracy_score,recall_score,precision_score,f1_score
    from keras.models import load_model
    import sklearn
    import copy
    from sklearn.utils.class_weight import compute_class_weight
    from tensorflow.keras import backend as K
    import warnings
    warnings.filterwarnings("ignore")
    
    if embedding_num==None:
        embedding_num=timestep+1
    if task_mode=='regression':
        if loss_function=='default' or loss_function=='MeanSquaredError':
            loss=tf.keras.losses.MeanSquaredError()
        elif loss_function=='MeanSquaredError':
            loss=tf.keras.losses.MeanSquaredError()
        elif loss_function=='MeanAbsoluteError':
            loss=tf.keras.losses.MeanAbsoluteError()
        elif loss_function=='MeanAbsolutePercentageError':
            loss=tf.keras.losses.MeanAbsolutePercentageError()
        elif loss_function=='MeanSquaredLogarithmicError':
            loss=tf.keras.losses.MeanSquaredLogarithmicError()
        elif loss_function=='CosineSimilarity':
            loss=tf.keras.losses.CosineSimilarity()
        elif loss_function=='Huber':
            loss=tf.keras.losses.Huber()
        elif loss_function=='LogCosh':
            loss=tf.keras.losses.LogCosh()
        elif loss_function=='Pearsonr':
            def loss_pearsonr(y_true,y_pred):
                import tensorflow as tf
                y_true_mean=tf.reduce_mean(y_true,axis=0)
                y_pred_mean=tf.reduce_mean(y_pred,axis=0)
                cov=tf.reduce_sum((y_true-y_true_mean)*(y_pred-y_pred_mean),axis=0)
                y_true_v=tf.reduce_sum(tf.square((y_true-y_true_mean)),axis=0)
                y_pred_v=tf.reduce_sum(tf.square((y_pred-y_pred_mean)),axis=0)
                y_true_v=tf.sqrt(y_true_v)
                y_pred_v=tf.sqrt(y_pred_v)
                pearson=cov/(y_true_v*y_pred_v)
                return (1-pearson)**1.5
            loss=loss_pearsonr
        if metrics=='default' or metrics=='MeanSquaredError':
            metric=tf.keras.metrics.MeanSquaredError()
        elif metrics=='MeanAbsoluteError':
            metric=tf.keras.metrics.MeanAbsoluteError()
        elif metrics=='MeanAbsolutePercentageError':
            metric=tf.keras.metrics.MeanAbsolutePercentageError()
        elif metrics=='MeanSquaredLogarithmicError':
            metric=tf.keras.metrics.MeanSquaredLogarithmicError()
        elif metrics=='CosineSimilarity':
            metric=tf.keras.metrics.CosineSimilarity()
        elif metrics=='LogCoshError':
            metric=tf.keras.metrics.LogCoshError()
        elif metrics=='Pearsonr':
            def metrics_pearsonr(y_true,y_pred):
                import tensorflow as tf
                y_true_mean=tf.reduce_mean(y_true,axis=0)
                y_pred_mean=tf.reduce_mean(y_pred,axis=0)
                cov=tf.reduce_sum((y_true-y_true_mean)*(y_pred-y_pred_mean),axis=0)
                y_true_v=tf.reduce_sum(tf.square((y_true-y_true_mean)),axis=0)
                y_pred_v=tf.reduce_sum(tf.square((y_pred-y_pred_mean)),axis=0)
                y_true_v=tf.sqrt(y_true_v)
                y_pred_v=tf.sqrt(y_pred_v)
                pearson=cov/(y_true_v*y_pred_v)
                return (1-pearson)**1.5
            metric=metrics_pearsonr
    elif task_mode=='binary_classify':
        if loss_function=='default' or loss_function=='BinaryCrossentropy':
            loss=tf.keras.losses.BinaryCrossentropy()
        elif loss_function=='CategoricalCrossentropy':
            loss=tf.keras.losses.CategoricalCrossentropy()
        elif loss_function=='SparseCategoricalCrossentropy':
            loss=tf.keras.losses.SparseCategoricalCrossentropy()
        elif loss_function=='Poisson':
            loss=tf.keras.losses.Poisson()
        elif loss_function=='KLDivergence':
            loss=tf.keras.losses.KLDivergence()
        elif loss_function=='f1':
            def loss_f1(y_true, y_pred):
                y_true = K.cast(y_true, 'float32')
                y_pred = K.cast(y_pred, 'float32')
                
                tp = K.sum(y_true * y_pred)
                fp = K.sum((1 - y_true) * y_pred)
                fn = K.sum(y_true * (1 - y_pred))
                
                precision = tp / (tp + fp + K.epsilon())
                recall = tp / (tp + fn + K.epsilon())
                
                f1 = 2 * precision * recall / (precision + recall + K.epsilon())
                
                return 1 - f1
            loss=loss_f1
        if metrics=='default' or metrics=='BinaryAccuracy':
            metric=tf.keras.metrics.BinaryAccuracy()
        elif metrics=='MeanAbsoluteError':
            metric=tf.keras.metrics.MeanAbsoluteError()
        elif metrics=='Accuracy':
            metric=tf.keras.metrics.Accuracy()
        elif metrics=='CategoricalAccuracy':
            metric=tf.keras.metrics.CategoricalAccuracy()
        elif metrics=='SparseCategoricalAccuracy':
            metric=tf.keras.metrics.SparseCategoricalAccuracy()
        elif metrics=='TopKCategoricalAccuracy':
            metric=tf.keras.metrics.TopKCategoricalAccuracy()
        elif metrics=='SparseTopKCategoricalAccuracy':
            metric=tf.keras.metrics.SparseTopKCategoricalAccuracy()
        elif metrics=='BinaryCrossentropy':
            metric=tf.keras.metrics.BinaryCrossentropy()
        elif metrics=='CategoricalCrossentropy':
            metric=tf.keras.metrics.CategoricalCrossentropy()
        elif metrics=='SparseCategoricalCrossentropy':
            metric=tf.keras.metrics.SparseCategoricalCrossentropy()
        elif metrics=='Accuracy':
            metric=tf.keras.metrics.Accuracy()
        elif metrics=='CategoricalAccuracy':
            metric=tf.keras.metrics.CategoricalAccuracy()
        elif metrics=='SparseCategoricalAccuracy':
            metric=tf.keras.metrics.SparseCategoricalAccuracy()
        elif metrics=='KLDivergence':
            metric=tf.keras.metrics.KLDivergence()
        elif metrics=='Poisson':
            metric=tf.keras.metrics.Poisson()
        elif metrics=='AUC':
            metric=tf.keras.metrics.AUC()
        elif metrics=='Precision':
            metric=tf.keras.metrics.Precision()
        elif metrics=='Recall':
            metric=tf.keras.metrics.Recall()
        elif metrics=='TruePositives':
            metric=tf.keras.metrics.TruePositives()
        elif metrics=='TrueNegatives':
            metric=tf.keras.metrics.TrueNegatives()
        elif metrics=='FalsePositives':
            metric=tf.keras.metrics.FalsePositives()
        elif metrics=='FalseNegatives':
            metric=tf.keras.metrics.FalseNegatives()
        elif metrics=='PrecisionAtRecall':
            metric=tf.keras.metrics.PrecisionAtRecall()
        elif metrics=='SensitivityAtSpecificity':
            metric=tf.keras.metrics.SensitivityAtSpecificity()
        elif metrics=='SpecificityAtSensitivity':
            metric=tf.keras.metrics.SpecificityAtSensitivity()
        elif metrics=='f1':
            def metric_f1(y_true, y_pred):
                y_true = K.cast(y_true, 'float32')
                y_pred = K.cast(y_pred, 'float32')
                
                tp = K.sum(y_true * y_pred)
                fp = K.sum((1 - y_true) * y_pred)
                fn = K.sum(y_true * (1 - y_pred))
                
                precision = tp / (tp + fp + K.epsilon())
                recall = tp / (tp + fn + K.epsilon())
                
                f1 = 2 * precision * recall / (precision + recall + K.epsilon())
                
                return f1
            metric=metric_f1
    elif task_mode=='multi_classify':
        if loss_function=='default' or loss_function=='CategoricalCrossentropy':
            loss=tf.keras.losses.CategoricalCrossentropy()
        elif loss_function=='SparseCategoricalCrossentropy':
            loss=tf.keras.losses.SparseCategoricalCrossentropy()
        elif loss_function=='Poisson':
            loss=tf.keras.losses.Poisson()
        elif loss_function=='KLDivergence':
            loss=tf.keras.losses.KLDivergence()
        if metrics=='default' or metrics=='Accuracy':
            metric=tf.keras.metrics.Accuracy()
        elif metrics=='MeanAbsoluteError':
            metric=tf.keras.metrics.MeanAbsoluteError()
        elif metrics=='CategoricalAccuracy':
            metric=tf.keras.metrics.CategoricalAccuracy()
        elif metrics=='SparseCategoricalAccuracy':
            metric=tf.keras.metrics.SparseCategoricalAccuracy()
        elif metrics=='TopKCategoricalAccuracy':
            metric=tf.keras.metrics.TopKCategoricalAccuracy()
        elif metrics=='SparseTopKCategoricalAccuracy':
            metric=tf.keras.metrics.SparseTopKCategoricalAccuracy()
        elif metrics=='CategoricalCrossentropy':
            metric=tf.keras.metrics.CategoricalCrossentropy()
        elif metrics=='SparseCategoricalCrossentropy':
            metric=tf.keras.metrics.SparseCategoricalCrossentropy()
        elif metrics=='Accuracy':
            metric=tf.keras.metrics.Accuracy()
        elif metrics=='CategoricalAccuracy':
            metric=tf.keras.metrics.CategoricalAccuracy()
        elif metrics=='SparseCategoricalAccuracy':
            metric=tf.keras.metrics.SparseCategoricalAccuracy()
        elif metrics=='KLDivergence':
            metric=tf.keras.metrics.KLDivergence()
        elif metrics=='Poisson':
            metric=tf.keras.metrics.Poisson()
    weights=0
    model=0
    if vy.ndim==1:
        vy=vy.reshape(vy.shape[0],1)
    if ifrandom_split=='yes':
        trainy,testy,trainx,testx = train_test_split(vy,vx,test_size=test_size,random_state=25)
    else:
        index=int((1-test_size)*vy.shape[0])
        trainy=vy[:index]
        testy=vy[index:]
        trainx=vx[:index,:,:]
        testx=vx[index:,:,:]
    train_position=np.zeros((trainx.shape[0],trainx.shape[1]))
    test_position=np.zeros((testx.shape[0],testx.shape[1]))
    for i in range(trainx.shape[0]):
        train_position[i,:]=np.arange(0,timestep,1)
    for i in range(testx.shape[0]):
        test_position[i,:]=np.arange(0,timestep,1)
    if task_mode!='regression':
        def create_sample_weights_for_batch_multitask(y_batch_multitask, list_of_task_weights_dicts):
            batch_size, num_tasks = y_batch_multitask.shape
            
            if len(list_of_task_weights_dicts) != num_tasks:
                raise ValueError(f"Number of tasks in y_batch_multitask ({num_tasks}) "
                                 f"must match length of list_of_task_weights_dicts ({len(list_of_task_weights_dicts)}).")
        
            sample_weight_batch = np.ones_like(y_batch_multitask, dtype=np.float32)
        
            for i in range(num_tasks):
                task_labels_current_channel = y_batch_multitask[:, i] 
                weights_dict_for_task_i = list_of_task_weights_dicts[i]
                
                weight_for_0 = weights_dict_for_task_i.get(0, 1.0)
                weight_for_1 = weights_dict_for_task_i.get(1, 1.0)
                
                current_task_weights = sample_weight_batch[:, i] 
                current_task_weights[task_labels_current_channel == 0] = weight_for_0
                current_task_weights[task_labels_current_channel == 1] = weight_for_1
                sample_weight_batch[:, i] = current_task_weights
                
            return sample_weight_batch
        def compute_unified_class_weights(y, task_mode=task_mode):
            if task_mode == 'binary_classify':
                if y.ndim == 2 and y.shape[-1] > 1:
                    num_tasks = y.shape[-1]
                    list_of_task_weights_dicts = []
                    possible_binary_classes = np.array([0, 1])
                    for i in range(num_tasks):
                        y_task_i_flat = y[:, i].ravel()
                        if len(y_task_i_flat) == 0:
                            weights_dict_task_i = {0: 1.0, 1: 1.0} 
                        else:
                            valid_labels_mask = np.isin(y_task_i_flat, possible_binary_classes)
                            if not np.all(valid_labels_mask) and np.any(valid_labels_mask): 
                                y_task_i_flat_filtered = y_task_i_flat[valid_labels_mask]
                                if len(y_task_i_flat_filtered) == 0 : y_task_i_flat_filtered = np.array([0]) 
                            elif not np.any(valid_labels_mask): 
                                 y_task_i_flat_filtered = np.array([0]) 
                            else:
                                y_task_i_flat_filtered = y_task_i_flat
                            class_weights_arr = compute_class_weight(
                                class_weight='balanced',
                                classes=possible_binary_classes, 
                                y=y_task_i_flat_filtered
                            )
                            weights_dict_task_i = dict(zip(possible_binary_classes, class_weights_arr))
                        list_of_task_weights_dicts.append(weights_dict_task_i)
                    return list_of_task_weights_dicts 
        
                else: 
                    y_flat = y.ravel()
                    possible_binary_classes = np.array([0, 1])
                    valid_labels_mask = np.isin(y_flat, possible_binary_classes)
                    if not np.all(valid_labels_mask) and np.any(valid_labels_mask):
                        y_flat_filtered = y_flat[valid_labels_mask]
                        if len(y_flat_filtered) == 0 : y_flat_filtered = np.array([0])
                    elif not np.any(valid_labels_mask):
                         y_flat_filtered = np.array([0])
                    else:
                        y_flat_filtered = y_flat
        
                    class_weights_arr = compute_class_weight(
                        class_weight='balanced',
                        classes=possible_binary_classes,
                        y=y_flat_filtered
                    )
                    return dict(zip(possible_binary_classes, class_weights_arr)) 
        
            elif task_mode == 'multi_classify':
                y_flat = y.ravel()
                possible_multiclass_classes = np.arange(int(np.max(y)+1))
                valid_labels_mask = np.isin(y_flat, possible_multiclass_classes)
                if not np.all(valid_labels_mask) and np.any(valid_labels_mask):
                    y_flat_filtered = y_flat[valid_labels_mask]
                    if len(y_flat_filtered) == 0 : y_flat_filtered = np.array([0]) 
                elif not np.any(valid_labels_mask):
                    y_flat_filtered = np.array([0]) 
                else:
                    y_flat_filtered = y_flat
                class_weights_arr = compute_class_weight(
                    class_weight='balanced',
                    classes=possible_multiclass_classes,
                    y=y_flat_filtered
                )
                return dict(zip(possible_multiclass_classes, class_weights_arr)) 
            else:
                raise ValueError(f"Unsupported task_mode: {task_mode}")
        def create_unified_sample_weights_for_batch(y_batch, unified_class_weights):
            if isinstance(unified_class_weights, list):
                if not (y_batch.ndim == 2 and y_batch.shape[-1] == len(unified_class_weights)):
                     raise ValueError(f"Shape mismatch for multi-task binary weights. "
                                      f"y_batch shape: {y_batch.shape}, num_weight_dicts: {len(unified_class_weights)}")
                return create_sample_weights_for_batch_multitask(y_batch, unified_class_weights)
            elif isinstance(unified_class_weights, dict):
                y_int_labels_for_weights = y_batch
                if y_batch.ndim == 2 and y_batch.shape[-1] == 1: 
                    y_int_labels_for_weights = np.squeeze(y_batch, axis=-1)
                sample_weight_for_batch = np.ones_like(y_int_labels_for_weights, dtype=np.float32)
                for class_label, weight in unified_class_weights.items():
                    sample_weight_for_batch[y_int_labels_for_weights == class_label] = weight
                
                return sample_weight_for_batch
            else:
                raise TypeError(f"unified_class_weights has unexpected type: {type(unified_class_weights)}. Expected dict or list.")
        def train_data_generator(x,position, y, batch_size, task_mode=task_mode):
            num_samples = x.shape[0]
            global_unified_weights = compute_unified_class_weights(y, task_mode)
            while True:
                indices = np.arange(num_samples)
                
                for start in range(0, num_samples, batch_size):
                    end = min(start + batch_size, num_samples)
                    batch_indices = indices[start:end]
                    
                    if len(batch_indices) == 0:
                        continue
        
                    x_batch = x[batch_indices]
                    position_batch = position[batch_indices]
                    y_batch = y[batch_indices] 
                    sample_weight_batch = create_unified_sample_weights_for_batch(
                        y_batch, 
                        global_unified_weights
                    )
                    yield {"input_1": x_batch, "input_2": position_batch}, y_batch, sample_weight_batch
    else:
        def train_data_generator(x, position, y, batch_size):
            num_samples = x.shape[0]
            while True:
                indices = np.arange(num_samples)
                
                for start in range(0, num_samples, batch_size):
                    end = min(start + batch_size, num_samples)
                    batch_indices = indices[start:end]
                    
                    x_batch = x[batch_indices]  
                    position_batch = position[batch_indices]   
                    y_batch = y[batch_indices]      
                    
                    yield ({"input_1": x_batch, "input_2": position_batch}, y_batch)
    def test_data_generator(x, position, batch_size):
        num_samples = x.shape[0]
        while True:
            indices = np.arange(num_samples)
            
            for start in range(0, num_samples, batch_size):
                end = min(start + batch_size, num_samples)
                batch_indices = indices[start:end]
                
                x_batch = x[batch_indices]  # 第一个输入特征
                position_batch = position[batch_indices]
                
                yield ({"input_1": x_batch, "input_2": position_batch})
    if if_best_mode=='no':
        inputshape1=(None,timestep,trainx.shape[2])
        inputshape2=(None,timestep)
        inputs1=Input(shape=(timestep,trainx.shape[2]))
        inputs2=Input(shape=(timestep))
        for i in range(len(model_list)):
            if model_list[i][0] == 'transformer':
                position_embedding=Embedding(embedding_num,trainx.shape[2],input_length=timestep,input_shape=inputshape2)(inputs2)
                add=Add(input_shape=inputshape1)([inputs1,position_embedding])
                for j in range(encoder_deep):
                    if j ==0:
                        exec('en_multihead'+str(j+1)+'=MultiHeadAttention(num_heads=num_heads,key_dim=key_dim,dropout=trans_dropout_rate,attention_axes=1)(add,add,add)')
                        exec('en_add'+str(j+1)+'=Add()([add,en_multihead'+str(j+1)+'])')
                    else:
                        exec('en_multihead'+str(j+1)+'=MultiHeadAttention(num_heads=num_heads,key_dim=key_dim,dropout=trans_dropout_rate,attention_axes=1)(en_layernormalization'+str(j)+',en_layernormalization'+str(j)+',en_layernormalization'+str(j)+')')
                        exec('en_add'+str(j+1)+'=Add()([en_layernormalization'+str(j)+',en_multihead'+str(j+1)+'])')
                    exec('en_layernormalization'+str(j+1)+'=LayerNormalization()(en_add'+str(j+1)+')')
                    if ifdropout=='yes':
                        exec('en_dropout'+str(j+1)+'=Dropout(trans_dropout_rate)(en_layernormalization'+str(j+1)+')')
                        exec('en_fc'+str(j+1)+'=Dense(trans_units,activation=trans_activation)(en_dropout'+str(j+1)+')')
                        exec('en_fc'+str(j+1)+'=Dense(trainx.shape[2],activation=trans_activation)(en_fc'+str(j+1)+')')
                    else:
                        exec('en_fc'+str(j+1)+'=Dense(trans_units,activation=trans_activation)(en_layernormalization'+str(j+1)+')')
                        exec('en_fc'+str(j+1)+'=Dense(trainx.shape[2],activation=trans_activation)(en_fc'+str(j+1)+')')
                    exec('en_add'+str(j+1)+'=Add()([en_fc'+str(j+1)+',en_layernormalization'+str(j+1)+'])')
                    exec('en_layernormalization'+str(j+1)+'=LayerNormalization()(en_add'+str(j+1)+')')
                exec('en_fla=Flatten()(en_layernormalization'+str(j+1)+')')
            elif model_list[i][0] == 'batchnormalization':
                if i==len(model_list)-1:
                    if model_list[i-1][0]=='transformer':
                        outputs=eval('BatchNormalization(axis=-1)(en_fla)')
                    elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                        outputs=eval('BatchNormalization(axis=-1)(norm'+str(i)+')')
                    elif model_list[i-1][0]=='activation':
                        outputs=eval('BatchNormalization(axis=-1)(act'+str(i)+')')
                    elif model_list[i-1][0]=='dropout' :
                        outputs=eval('BatchNormalization(axis=-1)(drop'+str(i)+')')
                    elif model_list[i-1][0]=='fc':
                        outputs=eval('BatchNormalization(axis=-1)(fc'+str(i)+')')
                    elif model_list[i-1][0]=='flatten':
                        outputs=eval('BatchNormalization(axis=-1)(fla'+str(i)+')')
                else:
                    if model_list[i-1][0]=='transformer':
                        exec('norm'+str(i+1)+'=BatchNormalization(axis=-1)(en_fla)')
                    elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                        exec('norm'+str(i+1)+'=BatchNormalization(axis=-1)(norm'+str(i)+')')
                    elif model_list[i-1][0]=='activation':
                        exec('norm'+str(i+1)+'=BatchNormalization(axis=-1)(act'+str(i)+')')
                    elif model_list[i-1][0]=='dropout' :
                        exec('norm'+str(i+1)+'=BatchNormalization(axis=-1)(drop'+str(i)+')')
                    elif model_list[i-1][0]=='fc':
                        exec('norm'+str(i+1)+'=BatchNormalization(axis=-1)(fc'+str(i)+')')
                    elif model_list[i-1][0]=='flatten':
                        exec('norm'+str(i+1)+'=BatchNormalization(axis=-1)(fla'+str(i)+')')
            elif model_list[i][0] == 'layernormalization':
                if i==len(model_list)-1:
                    if model_list[i-1][0]=='transformer':
                        outputs=eval('LayerNormalization(axis=-1)(en_fla)')
                    elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                        outputs=eval('LayerNormalization(axis=-1)(norm'+str(i)+')')
                    elif model_list[i-1][0]=='activation':
                        outputs=eval('LayerNormalization(axis=-1)(act'+str(i)+')')
                    elif model_list[i-1][0]=='dropout' :
                        outputs=eval('LayerNormalization(axis=-1)(drop'+str(i)+')')
                    elif model_list[i-1][0]=='fc':
                        outputs=eval('LayerNormalization(axis=-1)(fc'+str(i)+')')
                    elif model_list[i-1][0]=='flatten':
                        outputs=eval('LayerNormalization(axis=-1)(fla'+str(i)+')')
                else:
                    if model_list[i-1][0]=='transformer':
                        exec('norm'+str(i+1)+'=LayerNormalization(axis=-1)(en_fla)')
                    elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                        exec('norm'+str(i+1)+'=LayerNormalization(axis=-1)(norm'+str(i)+')')
                    elif model_list[i-1][0]=='activation':
                        exec('norm'+str(i+1)+'=LayerNormalization(axis=-1)(act'+str(i)+')')
                    elif model_list[i-1][0]=='dropout' :
                        exec('norm'+str(i+1)+'=LayerNormalization(axis=-1)(drop'+str(i)+')')
                    elif model_list[i-1][0]=='fc':
                        exec('norm'+str(i+1)+'=LayerNormalization(axis=-1)(fc'+str(i)+')')
                    elif model_list[i-1][0]=='flatten':
                        exec('norm'+str(i+1)+'=LayerNormalization(axis=-1)(fla'+str(i)+')')
            elif model_list[i][0] == 'activation':
                if i==len(model_list)-1:
                    if model_list[i-1][0]=='transformer':
                        outputs=eval('Activation(model_list[i][1])(en_fla)')
                    elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                        outputs=eval('Activation(model_list[i][1])(norm'+str(i)+')')
                    elif model_list[i-1][0]=='activation':
                        outputs=eval('Activation(model_list[i][1])(act'+str(i)+')')
                    elif model_list[i-1][0]=='dropout' :
                        outputs=eval('Activation(model_list[i][1])(drop'+str(i)+')')
                    elif model_list[i-1][0]=='fc':
                        outputs=eval('Activation(model_list[i][1])(fc'+str(i)+')')
                    elif model_list[i-1][0]=='flatten':
                        outputs=eval('Activation(model_list[i][1])(fla'+str(i)+')')
                else:
                    if model_list[i-1][0]=='transformer':
                        exec('act'+str(i+1)+'=Activation(model_list[i][1])(en_fla)')
                    elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                        exec('act'+str(i+1)+'=Activation(model_list[i][1])(norm'+str(i)+')')
                    elif model_list[i-1][0]=='activation':
                        exec('act'+str(i+1)+'=Activation(model_list[i][1])(act'+str(i)+')')
                    elif model_list[i-1][0]=='dropout' :
                        exec('act'+str(i+1)+'=Activation(model_list[i][1])(drop'+str(i)+')')
                    elif model_list[i-1][0]=='fc':
                        exec('act'+str(i+1)+'=Activation(model_list[i][1])(fc'+str(i)+')')
                    elif model_list[i-1][0]=='flatten':
                        exec('act'+str(i+1)+'=Activation(model_list[i][1])(fla'+str(i)+')')
            elif model_list[i][0] == 'flatten':
                if model_list[i-1][0]=='transformer':
                    exec('fla'+str(i+1)+'=Flatten()(en_fla)')
                elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                    exec('fla'+str(i+1)+'=Flatten()(norm'+str(i)+')')
                elif model_list[i-1][0]=='activation':
                    exec('fla'+str(i+1)+'=Flatten()(act'+str(i)+')')
                elif model_list[i-1][0]=='dropout':
                    exec('fla'+str(i+1)+'=Flatten()(drop'+str(i)+')')
            elif model_list[i][0] =='fc':
                if if_weight_initialize=='no':
                    if i==len(model_list)-1:
                        if model_list[i-1][0]=='transformer':
                            outputs=eval('Dense(model_list[i][1])(en_fla)')
                        elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                            outputs=eval('Dense(model_list[i][1])(norm'+str(i)+')')
                        elif model_list[i-1][0]=='activation':
                            outputs=eval('Dense(model_list[i][1])(act'+str(i)+')')
                        elif model_list[i-1][0]=='dropout':
                            outputs=eval('Dense(model_list[i][1])(drop'+str(i)+')')
                        elif model_list[i-1][0]=='fc':
                            outputs=eval('Dense(model_list[i][1])(fc'+str(i)+')')
                        elif model_list[i-1][0]=='flatten':
                            outputs=eval('Dense(model_list[i][1])(fla'+str(i)+')')
                    else:
                        if model_list[i-1][0]=='transformer':
                            exec('fc'+str(i+1)+'=Dense(model_list[i][1])(en_fla)')
                        elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                            exec('fc'+str(i+1)+'=Dense(model_list[i][1])(norm'+str(i)+')')
                        elif model_list[i-1][0]=='activation':
                            exec('fc'+str(i+1)+'=Dense(model_list[i][1])(act'+str(i)+')')
                        elif model_list[i-1][0]=='dropout':
                            exec('fc'+str(i+1)+'=Dense(model_list[i][1])(drop'+str(i)+')')
                        elif model_list[i-1][0]=='fc':
                            exec('fc'+str(i+1)+'=Dense(model_list[i][1])(fc'+str(i)+')')
                        elif model_list[i-1][0]=='flatten':
                            exec('fc'+str(i+1)+'=Dense(model_list[i][1])(fla'+str(i)+')')
                else:
                    if i==len(model_list)-1:
                        if model_list[i-1][0]=='transformer':
                            if weight_initialize_method=='RandomNormal':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = RandomNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(en_fla)')
                            elif weight_initialize_method=='RandomUniform':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = RandomUniform(minval=weight_initialize_parameter1,maxval=weight_initialize_parameter2))(en_fla)')
                            elif weight_initialize_method=='TruncatedNormal':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = TruncatedNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(en_fla)')
                        elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                            if weight_initialize_method=='RandomNormal':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = RandomNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(norm'+str(i)+')')
                            elif weight_initialize_method=='RandomUniform':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = RandomUniform(minval=weight_initialize_parameter1,maxval=weight_initialize_parameter2))(norm'+str(i)+')')
                            elif weight_initialize_method=='TruncatedNormal':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = TruncatedNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(norm'+str(i)+')')
                        elif model_list[i-1][0]=='activation':
                            if weight_initialize_method=='RandomNormal':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = RandomNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(act'+str(i)+')')
                            elif weight_initialize_method=='RandomUniform':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = RandomUniform(minval=weight_initialize_parameter1,maxval=weight_initialize_parameter2))(act'+str(i)+')')
                            elif weight_initialize_method=='TruncatedNormal':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = TruncatedNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(act'+str(i)+')')
                        elif model_list[i-1][0]=='dropout':
                            if weight_initialize_method=='RandomNormal':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = RandomNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(drop'+str(i)+')')
                            elif weight_initialize_method=='RandomUniform':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = RandomUniform(minval=weight_initialize_parameter1,maxval=weight_initialize_parameter2))(drop'+str(i)+')')
                            elif weight_initialize_method=='TruncatedNormal':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = TruncatedNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(drop'+str(i)+')')
                        elif model_list[i-1][0]=='fc':   
                            if weight_initialize_method=='RandomNormal':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = RandomNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(fc'+str(i)+')')
                            elif weight_initialize_method=='RandomUniform':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = RandomUniform(minval=weight_initialize_parameter1,maxval=weight_initialize_parameter2))(fc'+str(i)+')')
                            elif weight_initialize_method=='TruncatedNormal':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = TruncatedNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(fc'+str(i)+')')
                        elif model_list[i-1][0]=='flatten':
                            if weight_initialize_method=='RandomNormal':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = RandomNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(fla'+str(i)+')')
                            elif weight_initialize_method=='RandomUniform':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = RandomUniform(minval=weight_initialize_parameter1,maxval=weight_initialize_parameter2))(fla'+str(i)+')')
                            elif weight_initialize_method=='TruncatedNormal':
                                outputs=eval('Dense(model_list[i][1],kernel_initializer = TruncatedNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(fla'+str(i)+')')
                    else:
                        if model_list[i-1][0]=='transformer':
                            if weight_initialize_method=='RandomNormal':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = RandomNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(en_fla)')
                            elif weight_initialize_method=='RandomUniform':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = RandomUniform(minval=weight_initialize_parameter1,maxval=weight_initialize_parameter2))(en_fla)')
                            elif weight_initialize_method=='TruncatedNormal':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = TruncatedNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(en_fla)')
                        elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                            if weight_initialize_method=='RandomNormal':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = RandomNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(norm'+str(i)+')')
                            elif weight_initialize_method=='RandomUniform':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = RandomUniform(minval=weight_initialize_parameter1,maxval=weight_initialize_parameter2))(norm'+str(i)+')')
                            elif weight_initialize_method=='TruncatedNormal':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = TruncatedNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(norm'+str(i)+')')
                        elif model_list[i-1][0]=='activation':
                            if weight_initialize_method=='RandomNormal':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = RandomNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(act'+str(i)+')')
                            elif weight_initialize_method=='RandomUniform':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = RandomUniform(minval=weight_initialize_parameter1,maxval=weight_initialize_parameter2))(act'+str(i)+')')
                            elif weight_initialize_method=='TruncatedNormal':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = TruncatedNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(act'+str(i)+')')
                        elif model_list[i-1][0]=='dropout':
                            if weight_initialize_method=='RandomNormal':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = RandomNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(drop'+str(i)+')')
                            elif weight_initialize_method=='RandomUniform':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = RandomUniform(minval=weight_initialize_parameter1,maxval=weight_initialize_parameter2))(drop'+str(i)+')')
                            elif weight_initialize_method=='TruncatedNormal':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = TruncatedNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(drop'+str(i)+')')
                        elif model_list[i-1][0]=='fc':   
                            if weight_initialize_method=='RandomNormal':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = RandomNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(fc'+str(i)+')')
                            elif weight_initialize_method=='RandomUniform':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = RandomUniform(minval=weight_initialize_parameter1,maxval=weight_initialize_parameter2))(fc'+str(i)+')')
                            elif weight_initialize_method=='TruncatedNormal':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = TruncatedNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(fc'+str(i)+')')
                        elif model_list[i-1][0]=='flatten':
                            if weight_initialize_method=='RandomNormal':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = RandomNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(fla'+str(i)+')')
                            elif weight_initialize_method=='RandomUniform':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = RandomUniform(minval=weight_initialize_parameter1,maxval=weight_initialize_parameter2))(fla'+str(i)+')')
                            elif weight_initialize_method=='TruncatedNormal':
                                exec('fc'+str(i+1)+'=Dense(model_list[i][1],kernel_initializer = TruncatedNormal(mean=weight_initialize_parameter1,stddev=weight_initialize_parameter2))(fla'+str(i)+')')
            elif model_list[i][0] == 'dropout':
                if i==len(model_list)-1:
                    if model_list[i-1][0]=='transformer':
                        outputs=eval('Dropout(model_list[i][1])(en_fla)')
                    elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                        outputs=eval('Dropout(model_list[i][1])(norm'+str(i)+')')
                    elif model_list[i-1][0]=='activation':
                        outputs=eval('Dropout(model_list[i][1])(act'+str(i)+')')
                    elif model_list[i-1][0]=='dropout' :
                        outputs=eval('Dropout(model_list[i][1])(drop'+str(i)+')')
                    elif model_list[i-1][0]=='fc':
                        outputs=eval('Dropout(model_list[i][1])(fc'+str(i)+')')
                    elif model_list[i-1][0]=='flatten':
                        outputs=eval('Dropout(model_list[i][1])(fla'+str(i)+')')
                else:
                    if model_list[i-1][0]=='transformer':
                        exec('drop'+str(i+1)+'=Dropout(model_list[i][1])(en_fla)')
                    elif model_list[i-1][0]=='batchnormalization' or model_list[i-1][0]=='layernormalization':
                        exec('drop'+str(i+1)+'=Dropout(model_list[i][1])(norm'+str(i)+')')
                    elif model_list[i-1][0]=='activation':
                        exec('drop'+str(i+1)+'=Dropout(model_list[i][1])(act'+str(i)+')')
                    elif model_list[i-1][0]=='dropout' :
                        exec('drop'+str(i+1)+'=Dropout(model_list[i][1])(drop'+str(i)+')')
                    elif model_list[i-1][0]=='fc':
                        exec('drop'+str(i+1)+'=Dropout(model_list[i][1])(fc'+str(i)+')')
                    elif model_list[i-1][0]=='flatten':
                        exec('drop'+str(i+1)+'=Dropout(model_list[i][1])(fla'+str(i)+')')
        model=eval('Model(inputs=[inputs1,inputs2], outputs=outputs)')
        if optimizer == 'SGD':
            opt = SGD(lr = learning_rate)
        elif optimizer == 'Adam':
            opt = Adam(lr = learning_rate)
        model.compile(loss=loss,optimizer=opt,metrics=[metric])
    elif if_best_mode=='yes' or if_best_mode=='load':
        if k_fold!=None:
            models=[]
            for i in range(k_fold):
                models.append(load_model(modelpath+'_'+str(i+1)))
        else:
            model=load_model(modelpath)
    if if_print_model=='yes':
        if k_fold!=None:
            if if_best_mode=='yes' or if_best_mode=='load':
                print(models[0].summary())
            else:
                print(model.summary())
        else:
            print(model.summary())
    if epochs!=0:
        if valid_size!=None or k_fold !=None:
            if k_fold!=None:
                if if_best_mode=='no' :
                    models = []
                    if ifrandom_split=='yes':
                        kf = KFold(n_splits=k_fold, shuffle=True, random_state=25)
                    else:
                        kf = KFold(n_splits=k_fold, shuffle=False)
                    for fold_no, (train_idx, val_idx) in enumerate(kf.split(trainx, trainy)):
                        X_train_fold, y_train_fold, position_train_fold = trainx[train_idx], trainy[train_idx], train_position[train_idx]
                        X_val_fold, y_val_fold, position_val_fold = trainx[val_idx], trainy[val_idx], train_position[val_idx]
                        model=eval('Model(inputs=[inputs1,inputs2], outputs=outputs)')
                        model.compile(loss=loss,optimizer=opt,metrics=[metric])
                        if if_early_stopping!=None:
                            H = model.fit(train_data_generator(X_train_fold,position_train_fold,y_train_fold,batch_size),steps_per_epoch=(len(X_train_fold) // batch_size+(1 if len(X_train_fold) % batch_size != 0 else 0)),validation_data=train_data_generator(X_val_fold,position_val_fold,y_val_fold,batch_size),validation_steps=(len(X_val_fold) // batch_size+(1 if len(X_val_fold) % batch_size != 0 else 0)),epochs = epochs,callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss',patience=if_early_stopping,restore_best_weights=True)])
                        else:
                            H = model.fit(train_data_generator(X_train_fold,position_train_fold,y_train_fold,batch_size),steps_per_epoch=(len(X_train_fold) // batch_size+(1 if len(X_train_fold) % batch_size != 0 else 0)),validation_data=train_data_generator(X_val_fold,position_val_fold,y_val_fold,batch_size),validation_steps=(len(X_val_fold) // batch_size+(1 if len(X_val_fold) % batch_size != 0 else 0)),epochs = epochs)
                        models.append(model)
                else:
                    models_new = []
                    if ifrandom_split=='yes':
                        kf = KFold(n_splits=k_fold, shuffle=True, random_state=25)
                    else:
                        kf = KFold(n_splits=k_fold, shuffle=False)
                    for fold_no, (train_idx, val_idx) in enumerate(kf.split(trainx, trainy)):
                        X_train_fold, y_train_fold, position_train_fold = trainx[train_idx], trainy[train_idx], train_position[train_idx]
                        X_val_fold, y_val_fold, position_val_fold = trainx[val_idx], trainy[val_idx], train_position[val_idx]
                        if if_early_stopping!=None:
                            H = models[fold_no].fit(train_data_generator(X_train_fold,position_train_fold,y_train_fold,batch_size),steps_per_epoch=(len(X_train_fold) // batch_size+(1 if len(X_train_fold) % batch_size != 0 else 0)),validation_data=train_data_generator(X_val_fold,position_val_fold,y_val_fold,batch_size),validation_steps=(len(X_val_fold) // batch_size+(1 if len(X_val_fold) % batch_size != 0 else 0)),epochs = epochs,callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss',patience=if_early_stopping,restore_best_weights=True)])
                        else:
                            H = models[fold_no].fit(train_data_generator(X_train_fold,position_train_fold,y_train_fold,batch_size),steps_per_epoch=(len(X_train_fold) // batch_size+(1 if len(X_train_fold) % batch_size != 0 else 0)),validation_data=train_data_generator(X_val_fold,position_val_fold,y_val_fold,batch_size),validation_steps=(len(X_val_fold) // batch_size+(1 if len(X_val_fold) % batch_size != 0 else 0)),epochs = epochs)
                        models_new.append(models[fold_no])
                    models=models_new
            else:
                if ifrandom_split=='yes':
                    trainy,validy,trainx,validx,train_position,valid_position = train_test_split(trainy,trainx,train_position,test_size=valid_size/(1-test_size),random_state=25)
                else:
                    index=int((1-valid_size/(1-test_size))*trainy.shape[0])
                    validy=trainy[index:]
                    trainy=trainy[:index]
                    validx=trainx[index:]
                    trainx=trainx[:index]
                    valid_position=train_position[index:]
                    train_position=train_position[:index]
                if if_early_stopping!=None:
                    H = model.fit(train_data_generator(trainx,train_position,trainy,batch_size),steps_per_epoch=(len(trainx) // batch_size+(1 if len(trainx) % batch_size != 0 else 0)),validation_data=train_data_generator(validx,valid_position,validy,batch_size),validation_steps=(len(validx) // batch_size+(1 if len(validx) % batch_size != 0 else 0)),epochs = epochs,callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss',patience=if_early_stopping,restore_best_weights=True)])
                else:
                    H = model.fit(train_data_generator(trainx,train_position,trainy,batch_size),steps_per_epoch=(len(trainx) // batch_size+(1 if len(trainx) % batch_size != 0 else 0)),validation_data=train_data_generator(validx,valid_position,validy,batch_size),validation_steps=(len(validx) // batch_size+(1 if len(validx) % batch_size != 0 else 0)),epochs = epochs)
        else:
            if if_early_stopping!=None:
                H = model.fit(train_data_generator(trainx,train_position,trainy,batch_size),steps_per_epoch=(len(trainx) // batch_size+(1 if len(trainx) % batch_size != 0 else 0)),validation_data=train_data_generator(testx,test_position,testy,batch_size),validation_steps=(len(testx) // batch_size+(1 if len(testx) % batch_size != 0 else 0)),epochs = epochs,callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss',patience=if_early_stopping,restore_best_weights=True)])
            else:
                H = model.fit(train_data_generator(trainx,train_position,trainy,batch_size),steps_per_epoch=(len(trainx) // batch_size+(1 if len(trainx) % batch_size != 0 else 0)),validation_data=train_data_generator(testx,test_position,testy,batch_size),validation_steps=(len(testx) // batch_size+(1 if len(testx) % batch_size != 0 else 0)),epochs = epochs)
    if k_fold!=None:
        predicty = [model.predict(test_data_generator(testx,test_position,batch_size),steps=(len(testx) // batch_size+(1 if len(testx) % batch_size != 0 else 0))) for model in models]
        predicty=np.nanmean(predicty,axis=0)
    else:
        predicty = model.predict(test_data_generator(testx,test_position,batch_size),steps=(len(testx) // batch_size+(1 if len(testx) % batch_size != 0 else 0)))
    predicty = np.nan_to_num(predicty,nan=0)
    if task_mode=='regression':
        r=np.zeros((testy.shape[1]))
        p=np.zeros((testy.shape[1]))
        for i in range(testy.shape[1]):
            r[i],p[i] = pearsonr(predicty[:,i],testy[:,i])
            r=np.nan_to_num(r,nan=0)
    elif task_mode=='binary_classify':
        accuracy=np.zeros((testy.shape[1]))
        recall=np.zeros((testy.shape[1]))
        precision=np.zeros((testy.shape[1]))
        f1=np.zeros((testy.shape[1]))
        for i in range(predicty.shape[1]):
            predicty[:,i]=[int(round(predicty[j,i],0)) for j in range(predicty.shape[0])]
        r=np.zeros((testy.shape[1]))
        for i in range(testy.shape[1]):
            if metrics=='Recall':
                r[i]=recall_score(testy[:,i], predicty[:,i])
            elif metrics=='Precision':
                r[i]=precision_score(testy[:,i], predicty[:,i])
            else:
                r[i]=accuracy_score(testy[:,i], predicty[:,i])
            recall[i]=recall_score(testy[:,i], predicty[:,i])
            precision[i]=precision_score(testy[:,i], predicty[:,i])
            accuracy[i]=accuracy_score(testy[:,i], predicty[:,i])
            f1[i]=f1_score(testy[:,i], predicty[:,i])
        p=0
    elif task_mode=='multi_classify':
        r=np.zeros((testy.shape[1]))
        for i in range(testy.shape[1]):
            r[i]=accuracy_score(testy[:,i], np.argmax(predicty,axis=1))
        p=0
    if ifmute == 'no':
        if task_mode=='regression':
            print('相关系数',np.nanmean(r))
        elif task_mode=='binary_classify':
            print('召回率+精确率',np.nanmean(f1),'准确率',np.nanmean(accuracy),'召回率',np.nanmean(recall),'精确率',np.nanmean(precision))
        elif task_mode=='multi_classify':
            print('准确率',np.nanmean(r))
    if ifweight=='yes':
        weights=np.zeros((testy.shape[1],testx.shape[2]))
        weight_more=np.zeros((testy.shape[1],testx.shape[2]))
        for i in range(testy.shape[1]):
            for j in range(testx.shape[2]):
                testx_new=copy.deepcopy(testx)
                weight=[]
                for k in range(10):
                    per=np.random.permutation(testx.shape[0])
                    testx_shuffle=testx[per,:,j]
                    testx_new[:,:,j]=testx_shuffle
                    if k_fold!=None:
                        predicty_new = [model.predict(test_data_generator(testx_new,test_position,batch_size),steps=(len(testx_new) // batch_size+(1 if len(testx_new) % batch_size != 0 else 0))) for model in models]
                        predicty_new=np.nanmean(predicty_new,axis=0)
                    else:
                        predicty_new = model.predict(test_data_generator(testx_new,test_position,batch_size),steps=(len(testx_new) // batch_size+(1 if len(testx_new) % batch_size != 0 else 0)))
                    if task_mode=='regression':
                        weight.append(sklearn.metrics.mean_squared_error(testy[:,i],predicty_new[:,i])-sklearn.metrics.mean_squared_error(testy[:,i],predicty[:,i]))
                    elif task_mode=='multi_classify':
                        weight.append(sklearn.metrics.log_loss(testy[:,i],predicty_new[:,:])-sklearn.metrics.log_loss(testy[:,i],predicty[:,:]))
                    else:
                        weight.append(sklearn.metrics.log_loss(testy[:,i],predicty_new[:,i])-sklearn.metrics.log_loss(testy[:,i],predicty[:,i]))
                weight_more[i,j]=np.nanmean(weight)
        for i in range(testy.shape[1]):
            for j in range(testx.shape[2]):
                weights[i,j]=(weight_more[i,j]/np.nansum(weight_more[i,:]))*100
                print('预报因子',j+1,'对预报值',i+1,'的贡献：',np.array(weights[i,j]),'％')
            print('\n')
    if ifsave=='yes':
        if k_fold!=None:
            for i, model in enumerate(models):
                model.save(savepath+'_'+str(i+1))
        else:
            model.save(savepath)
    if k_fold!=None:
        return models,predicty,testy,r,p,weights
    else:
        return model,predicty,testy,r,p,weights

In [10]:
model,predicty,testy,r,p,weights=Auto_Transformer(vy,vx,6,[['transformer'],['fc',126]],test_size=0.2,valid_size=0.1,k_fold=5,task_mode='regression',if_best_mode='no',modelpath=None,encoder_deep=1,num_heads=1,key_dim=1,ifdropout='no',trans_dropout_rate=0.0,trans_units=1215,trans_activation='tanh',embedding_num=None,if_weight_initialize='no',weight_initialize_method='TruncatedNormal',weight_initialize_parameter1=0.00,weight_initialize_parameter2=0.05,if_print_model='yes',loss_function='default',optimizer='Adam',metrics='Pearsonr',if_early_stopping=None,learning_rate=0.0001,epochs=500,batch_size=5000,ifrandom_split='yes',ifweight='no',ifmute='no',ifsave='yes',savepath='E:/huawei/huawei_gnss_wind_u_30min_hight_k5_morestation',device='gpu')

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 6)]          0           []                               
                                                                                                  
 input_1 (InputLayer)           [(None, 6, 1215)]    0           []                               
                                                                                                  
 embedding (Embedding)          (None, 6, 1215)      8505        ['input_2[0][0]']                
                                                                                                  
 add (Add)                      (None, 6, 1215)      0           ['input_1[0][0]',                
                                                                  'embedding[0][0]']          



INFO:tensorflow:Assets written to: E:/huawei/huawei_gnss_wind_u_30min_hight_k5_morestation_1\assets


INFO:tensorflow:Assets written to: E:/huawei/huawei_gnss_wind_u_30min_hight_k5_morestation_1\assets


INFO:tensorflow:Assets written to: E:/huawei/huawei_gnss_wind_u_30min_hight_k5_morestation_2\assets


INFO:tensorflow:Assets written to: E:/huawei/huawei_gnss_wind_u_30min_hight_k5_morestation_2\assets


INFO:tensorflow:Assets written to: E:/huawei/huawei_gnss_wind_u_30min_hight_k5_morestation_3\assets


INFO:tensorflow:Assets written to: E:/huawei/huawei_gnss_wind_u_30min_hight_k5_morestation_3\assets


INFO:tensorflow:Assets written to: E:/huawei/huawei_gnss_wind_u_30min_hight_k5_morestation_4\assets


INFO:tensorflow:Assets written to: E:/huawei/huawei_gnss_wind_u_30min_hight_k5_morestation_4\assets


INFO:tensorflow:Assets written to: E:/huawei/huawei_gnss_wind_u_30min_hight_k5_morestation_5\assets


INFO:tensorflow:Assets written to: E:/huawei/huawei_gnss_wind_u_30min_hight_k5_morestation_5\assets


In [11]:
print(np.nanmean(np.abs(testy-predicty)))

0.51111054


In [12]:
testy=np.array(testy).reshape(testy.shape[0],18,7)
predicty=np.array(predicty).reshape(predicty.shape[0],18,7)

In [13]:
print(testy.shape,predicty.shape)

(3519, 18, 7) (3519, 18, 7)


In [14]:
times=np.arange(3519)
testy_u=testy
predicty_u=predicty
levels=[110,760,1460,3010,4200,5570,7160]
station=np.array(wind_file['station'])

In [15]:
import numpy as np
import xarray as xr
import os



dim_names = ('times', 'station', 'levels')


testy_u_da = xr.DataArray(
    data=testy_u,
    coords={
        'times': times,
        'station': station,
        'levels': levels
    },
    dims=dim_names,
    name='testy_u' 
)


predicty_u_da = xr.DataArray(
    data=predicty_u,
    coords={
        'times': times,
        'station': station,
        'levels': levels
    },
    dims=dim_names,
    name='predicty_u'
)

testy_u_da.to_netcdf('E:/huawei/result/huawei_Transformer_30min_k5_gnss_to_wind_testy_u_hight_morestation.nc')
predicty_u_da.to_netcdf('E:/huawei/result/huawei_Transformer_30min_k5_gnss_to_wind_predicty_u_hight_morestation.nc')

In [16]:
import numpy as np
import xarray as xr
testy_u_file=xr.open_dataset('E:/huawei/result/huawei_Transformer_30min_k5_gnss_to_wind_testy_u_hight_morestation.nc')
predicty_u_file=xr.open_dataset('E:/huawei/result/huawei_Transformer_30min_k5_gnss_to_wind_predicty_u_hight_morestation.nc')
testy_u=np.array(testy_u_file['testy_u'])
predicty_u=np.array(predicty_u_file['predicty_u'])

In [17]:
#CDF匹配
def Auto_cdf_matching(vx,vy):
    import numpy as np
    from scipy.interpolate import InterpolatedUnivariateSpline
    from scipy.optimize import curve_fit

    if np.array(vx).ndim==1:
        vx_cdf = (np.arange(len(vx)) +  1) / (len(vx))
        vy_cdf = (np.arange(len(vy)) +  1) / (len(vy))
        
        spl = InterpolatedUnivariateSpline(vx_cdf, np.sort(vx))
        vx_interp = spl(vy_cdf)
        
        def func(x, a, b, c, d):
            return a*x + b*x**2 + c*x**3 + d
        
        popt = curve_fit(func, vx_interp, np.sort(vy))[0]
        
        matched_vx = func(vx, *popt)
    elif np.array(vx).ndim==2:
        matched_vx=np.zeros((vx.shape[0],vx.shape[1]))
        for i in range(vx.shape[1]):
            vx_cdf = (np.arange(len(vx[:,i])) +  1) / (len(vx[:,i]))
            vy_cdf = (np.arange(len(vy[:,i])) +  1) / (len(vy[:,i]))
            
            spl = InterpolatedUnivariateSpline(vx_cdf, np.sort(vx[:,i]))
            vx_interp = spl(vy_cdf)
            
            def func(x, a, b, c, d):
                return a*x + b*x**2 + c*x**3 + d
            
            popt = curve_fit(func, vx_interp, np.sort(vy[:,i]))[0]
            
            matched_vx[:,i] = func(vx[:,i], *popt)
    elif np.array(vx).ndim==3:
        matched_vx=np.zeros((vx.shape[0],vx.shape[1],vx.shape[2]))
        for i in range(vx.shape[1]):
            for j in range(vx.shape[2]):
                vx_cdf = (np.arange(len(vx[:,i,j])) +  1) / (len(vx[:,i,j]))
                vy_cdf = (np.arange(len(vy[:,i,j])) +  1) / (len(vy[:,i,j]))
                
                spl = InterpolatedUnivariateSpline(vx_cdf, np.sort(vx[:,i,j]))
                vx_interp = spl(vy_cdf)
                
                def func(x, a, b, c, d):
                    return a*x + b*x**2 + c*x**3 + d
                
                popt = curve_fit(func, vx_interp, np.sort(vy[:,i,j]))[0]
                
                matched_vx[:,i,j] = func(vx[:,i,j], *popt)
    elif np.array(vx).ndim==4:
        matched_vx=np.zeros((vx.shape[0],vx.shape[1],vx.shape[2],vx.shape[3]))
        for i in range(vx.shape[1]):
            for j in range(vx.shape[2]):
                for k in range(vx.shape[3]):
                    vx_cdf = (np.arange(len(vx[:,i,j,k])) +  1) / (len(vx[:,i,j,k]))
                    vy_cdf = (np.arange(len(vy[:,i,j,k])) +  1) / (len(vy[:,i,j,k]))
                    
                    spl = InterpolatedUnivariateSpline(vx_cdf, np.sort(vx[:,i,j,k]))
                    vx_interp = spl(vy_cdf)
                    
                    def func(x, a, b, c, d):
                        return a*x + b*x**2 + c*x**3 + d
                    
                    popt = curve_fit(func, vx_interp, np.sort(vy[:,i,j,k]))[0]
                    
                    matched_vx[:,i,j,k] = func(vx[:,i,j,k], *popt)

    return matched_vx

In [18]:
from sklearn.model_selection import train_test_split
import Auto_paint_self
np.random.seed(25)
trainy,testy,trainx,testx = train_test_split(np.array(vy).reshape(-1,18,7),vx,test_size=0.2,random_state=25)
predicty_u=Auto_cdf_matching(np.array(predicty_u),trainy[np.random.randint(0,trainy.shape[0], predicty_u.shape[0]),:,:])

In [19]:
from sklearn.metrics import mean_squared_error,mean_absolute_error,mean_absolute_percentage_error
from scipy.stats import pearsonr
import numpy as np
import math
from tqdm import tqdm
from metpy.calc import wind_direction,wind_speed
from metpy.units import units
u_rmse=np.zeros((predicty_u.shape[1],predicty_u.shape[2]))
u_mae=np.zeros((predicty_u.shape[1],predicty_u.shape[2]))
u_pearson=np.zeros((predicty_u.shape[1],predicty_u.shape[2]))
u_mape=np.zeros((predicty_u.shape[1],predicty_u.shape[2]))
for i in tqdm(range(predicty_u.shape[1])):
    for j in range(predicty_u.shape[2]):
        u_rmse[i,j]=mean_squared_error(testy_u[:,i,j],predicty_u[:,i,j])
        u_pearson[i,j],_=pearsonr(testy_u[:,i,j],predicty_u[:,i,j])
        u_mae[i,j]=mean_absolute_error(testy_u[:,i,j],predicty_u[:,i,j])
        u_mape[i,j]=mean_absolute_percentage_error(testy_u[:,i,j],predicty_u[:,i,j])

100%|██████████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<00:00, 69.90it/s]


In [20]:
print(np.nanmean(u_mae,axis=(0)))
print(np.nanmean(u_mae))

[0.45022744 0.40385336 0.47256346 0.54870131 0.65395147 0.60307125
 0.67929865]
0.5445238464672941


In [21]:
print(np.nanmean(u_rmse,axis=(0)))
print(np.nanmean(u_rmse))

[0.49980965 0.34683561 0.47971718 0.75763247 1.49924104 1.13751016
 1.53441154]
0.893593951337703


In [22]:
u_p=np.sqrt(u_rmse)/(np.nanmax(testy_u,axis=0)-np.nanmin(testy_u,axis=0))
#v_p=np.sqrt(v_rmse)/(np.nanmax(testy_v,axis=0)-np.nanmin(testy_v,axis=0))
#wind_p=np.sqrt(wind_rmse)/(np.nanmax(np.sqrt(testy_u**2+testy_v**2),axis=0)-np.nanmin(np.sqrt(testy_u**2+testy_v**2),axis=0))
print(np.nanmean(u_p,axis=(0)))
print(np.nanmean(u_p))

[0.02081009 0.01650273 0.01455292 0.01572533 0.0193032  0.01629008
 0.01953459]
0.017531277361275754


In [23]:
import numpy as np
import xarray as xr
import os



dim_names = ('times', 'station', 'levels')

predicty_u_da_cdf = xr.DataArray(
    data=predicty_u,
    coords={
        'times': times,
        'station': station,
        'levels': levels
    },
    dims=dim_names,
    name='predicty_u'
)

predicty_u_da.to_netcdf('E:/huawei/result/huawei_Transformer_30min_k5_gnss_to_wind_predicty_u_cdf_hight_morestation.nc')