In [None]:
import csv
import math
import torch
from d2l import torch as d2l
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.data as pyg_data
from torch_geometric.data import Data
from torch_geometric.data import HeteroData, Batch
from torch_geometric.transforms import ToUndirected
from torch_geometric.nn import GATConv, global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.nn import to_hetero
import torch_geometric.transforms as T
from torch.masked import MaskedTensor
from functools import partial
from IPython import display
import torch.nn.utils.parametrizations as parametrizations  # 新的模块路径
from torch.nn.utils import weight_norm

<span style='color:red; font-size:25px'>MSE / RMSE</span>

In [None]:
def calculate_mse(actual_values, predicted_values):
    """Return the summed squared error between two paired sequences.

    Each pair contributes ``(actual - predicted + 0.00115) ** 2``. The terms
    are added up and the total is returned as-is; it is not divided by the
    number of pairs. Pairs are formed with ``zip``, so surplus elements of the
    longer sequence are ignored, and two empty sequences give ``0``.

    NOTE(review): despite the name this is a sum rather than a mean, and the
    reason for the constant 0.00115 offset is not visible here -- confirm that
    both are intended.

    :param actual_values: iterable of ground-truth values
    :param predicted_values: iterable of predicted values
    :return: sum of the offset squared differences
    """
    return sum(
        (truth - estimate + 0.00115) ** 2
        for truth, estimate in zip(actual_values, predicted_values)
    )

<span style='color:red; font-size:25px'>ADE</span>

In [None]:
def utm_epsg_from_lonlat(lon_deg, lat_deg):
    """Pick the WGS84 / UTM EPSG code for a position given in degrees.

    The zone number is ``floor((lon + 180) / 6) + 1``. Northern-hemisphere
    zones map to ``326xx`` and southern-hemisphere zones to ``327xx``.

    The longitude is wrapped onto ``[0, 360)`` first, so ``lon = 180`` (the
    same meridian as ``-180``) yields zone 1 instead of the non-existent
    zone 61. The special-case zones around Norway and Svalbard are not
    handled.

    :param lon_deg: longitude in degrees
    :param lat_deg: latitude in degrees; ``>= 0`` is treated as northern
    :return: integer EPSG code (32601-32660 or 32701-32760)
    """
    # Distance east of the antimeridian, wrapped so 180 and -180 coincide.
    east_of_antimeridian = (lon_deg + 180.0) % 360.0
    zone = int(math.floor(east_of_antimeridian / 6.0)) + 1
    # The float modulo can round up to exactly 360.0 for tiny negative
    # inputs; keep the zone inside the valid 1..60 range.
    zone = min(zone, 60)

    hemisphere_base = 32600 if lat_deg >= 0 else 32700
    return hemisphere_base + zone

def calculate_ade(predictions, ground_truth, mask, fixed_epsg = 32):
    """Average displacement error (ADE) in metres between two lon/lat tracks.

    Both tensors are projected from WGS84 degrees (EPSG:4326) to a projected
    CRS with pyproj. The Euclidean distance between every predicted and true
    point is then averaged over the positions selected by ``mask``.

    :param predictions: tensor of shape (B, T, N, 2), last axis ``[lon, lat]``
        in degrees
    :param ground_truth: tensor with the same layout as ``predictions``
    :param mask: validity weights, shape (B, T, N, F), (B, T, N, 1) or
        (B, T, N). A 4-D mask is reduced with ``any`` over its last axis.
    :param fixed_epsg: EPSG code of the target CRS used for the whole batch
        (for example 32632 or 32633). ``None`` picks the UTM zone from the
        mean longitude/latitude of the valid ground-truth points.
        NOTE(review): the default ``32`` is kept for interface compatibility,
        but it does not look like a UTM code (compare 32632) -- confirm and
        pass an explicit value or ``None``.
    :return: 0-dim tensor holding the ADE in the units of the target CRS
        (metres for UTM). It is 0 when no position is valid.
    """
    # Imported here because the notebook's import cell does not provide it and
    # only this metric needs pyproj.
    from pyproj import Transformer

    assert predictions.shape[-1] == 2 and ground_truth.shape[-1] == 2
    device = predictions.device
    dtype = predictions.dtype

    B, T, N, _ = predictions.shape

    # Reduce the mask to one weight per (batch, step, node).
    if mask.dim() == 4:
        mask_reduced = mask.any(dim=-1).float()
    elif mask.shape == (B, T, N):
        mask_reduced = mask.float()
    else:
        mask_reduced = mask.squeeze(-1).float()
    mask_reduced = mask_reduced.to(device)

    # pyproj works on numpy arrays, so the projection runs on the CPU.
    pred_np = predictions.detach().cpu().numpy().astype(np.float64)
    gt_np = ground_truth.detach().cpu().numpy().astype(np.float64)

    if fixed_epsg is None:
        # Only valid, finite ground-truth points vote for the UTM zone, so
        # padding values cannot drag the mean into a wrong zone.
        valid = np.broadcast_to(mask_reduced.cpu().numpy() > 0, (B, T, N))
        valid = valid & np.isfinite(gt_np).all(axis=-1)
        if not valid.any():
            return torch.zeros((), device=device, dtype=dtype)
        epsg = utm_epsg_from_lonlat(
            float(gt_np[..., 0][valid].mean()),
            float(gt_np[..., 1][valid].mean()),
        )
    else:
        epsg = fixed_epsg

    transformer = Transformer.from_crs("EPSG:4326", f"EPSG:{epsg}", always_xy=True)

    def _project(lonlat):
        """Project a (B, T, N, 2) lon/lat array and return it as a tensor."""
        x, y = transformer.transform(lonlat[..., 0].reshape(-1), lonlat[..., 1].reshape(-1))
        xy = np.stack(
            [np.asarray(x, dtype=np.float64), np.asarray(y, dtype=np.float64)],
            axis=-1,
        ).reshape(B, T, N, 2)
        return torch.from_numpy(xy).to(device=device, dtype=dtype)

    pred_m = _project(pred_np)
    gt_m = _project(gt_np)

    # Euclidean distance per point, shape (B, T, N).
    displacement_error = torch.linalg.norm(pred_m - gt_m, dim=-1)

    # Select instead of multiply alone: NaN/inf at masked-out positions would
    # otherwise survive as ``nan * 0`` and poison the sum.
    masked_error = torch.where(
        mask_reduced > 0,
        displacement_error * mask_reduced,
        torch.zeros_like(displacement_error),
    )
    total_error = masked_error.sum()
    valid_count = mask_reduced.sum().clamp_min(1)

    ade = total_error / valid_count
    return ade


<span style='color:red; font-size:25px'>数据提取</span>

In [None]:
def preprocess_data(file_path, start_row, chunk_size):
    """Read one fixed-size block of rows from a header-less CSV file.

    Missing cells stay as NaN. A block shorter than ``chunk_size`` is treated
    as "not enough data" and discarded.

    :param file_path: path of the CSV file
    :param start_row: number of leading rows to skip
    :param chunk_size: number of rows that make up a complete block
    :return: DataFrame with exactly ``chunk_size`` rows, or an empty DataFrame
        when fewer rows are left or nothing can be parsed
    """
    try:
        chunk = pd.read_csv(
            file_path,
            header=None,
            skiprows=start_row,
            nrows=chunk_size,
            low_memory=False,
        )
    except pd.errors.EmptyDataError:
        # Nothing is left to parse after skipping ``start_row`` rows.
        print("No more data to read. Exiting.")
        return pd.DataFrame()

    if len(chunk) >= chunk_size:
        return chunk.fillna(np.nan)

    # Incomplete final block: signal "no batch" with an empty frame.
    print(f"Data read is smaller than the expected chunk size ({chunk_size}).")
    return pd.DataFrame()

In [None]:
def slice_data_generator(file_path, input_len, pred_len, batch_size):
    """Yield batches of sliding windows read block by block from a CSV file.

    Each block holds ``input_len + pred_len + batch_size - 1`` rows, which is
    just enough for ``batch_size`` stride-1 windows of length
    ``input_len + pred_len``. Reading starts at row index 2 and the first
    column of every block is dropped. After each batch the start row advances
    by half a block, so consecutive batches can share rows.

    NOTE(review): the half-block advance produces overlapping batches --
    confirm that this overlap is intended.

    :param file_path: path of the CSV file
    :param input_len: number of input steps per window
    :param pred_len: number of prediction steps per window
    :param batch_size: number of windows per yielded batch
    :return: generator of lists of 2-D arrays, each of shape
        (input_len + pred_len, n_columns - 1)
    """
    window_len = input_len + pred_len
    rows_per_chunk = window_len + batch_size - 1
    stride = rows_per_chunk // 2
    next_row = 2  # the first two rows are not data rows

    while True:
        chunk = preprocess_data(file_path, start_row=next_row, chunk_size=rows_per_chunk)
        if chunk.empty:
            # End of file, or fewer rows than one complete block.
            break

        values = chunk.iloc[:, 1:].values
        windows = [
            values[begin: begin + window_len]
            for begin in range(len(values) - window_len + 1)
        ]
        if not windows:
            print("No valid data windows extracted. Exiting.")
            break

        yield windows
        next_row += stride

<span style='color:red; font-size:25px'>最大-最小归一化</span>

In [None]:
def normalize_datat(out_Y, max_values, min_values):
    """Min-max normalise the last (feature) axis of a tensor.

    Computes ``(out_Y - min) / (max - min)`` per feature. A feature whose
    maximum equals its minimum is divided by 1 instead of 0, so it normalises
    to 0 rather than NaN/inf.

    :param out_Y: input tensor whose last axis is the feature axis, e.g.
        (batch_size, time_steps, num_nodes, features)
    :param max_values: per-feature maxima, 1-D with one entry per feature
        (list, numpy array or tensor)
    :param min_values: per-feature minima, same layout as ``max_values``
    :return: normalised tensor with the same shape as ``out_Y``
    :raises ValueError: if the bounds are not 1-D or their length differs
        from the feature dimension
    """
    # as_tensor accepts lists, arrays and tensors without the copy-construct
    # warning that torch.tensor() emits for tensor inputs.
    max_values = torch.as_tensor(max_values, device=out_Y.device)
    min_values = torch.as_tensor(min_values, device=out_Y.device)

    if max_values.dim() != 1 or min_values.dim() != 1:
        raise ValueError("max_values 和 min_values 应该是一维张量，表示每个特征的最大和最小值")

    num_features = out_Y.size(-1)
    if max_values.size(0) != num_features or min_values.size(0) != num_features:
        raise ValueError("max_values 和 min_values 的大小应该与特征维度一致")

    value_range = max_values - min_values
    # Guard constant features against division by zero.
    value_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)

    normalized_data = (out_Y - min_values) / value_range

    return normalized_data

<span style='color:red; font-size:25px'>逆归一化</span>

In [None]:
def denormalize_data(normalized_data, max_values, min_values):
    """Undo min-max normalisation along the last axis.

    ``max_values`` and ``min_values`` are converted to arrays and indexed
    with ``[..., 0, 0]``, so they need at least two trailing axes (for
    example shape (num_features, 1, 1)). The result of that indexing must
    have one entry per feature.

    :param normalized_data: array-like whose last axis is the feature axis
    :param max_values: per-feature maxima, array-like as described above
    :param min_values: per-feature minima, same layout as ``max_values``
    :return: numpy array ``normalized * (max - min) + min`` with the shape of
        ``normalized_data``
    :raises AssertionError: if the feature count and the bounds disagree
    """
    scaled = np.array(normalized_data)
    upper = np.array(max_values)[..., 0, 0]
    lower = np.array(min_values)[..., 0, 0]

    # The last axis of the data must line up with the per-feature bounds.
    assert scaled.shape[-1] == upper.shape[0], "最后一个维度与最大最小值的长度不匹配"

    span = upper - lower
    return scaled * span + lower

<span style='color:red; font-size:25px'>计算半正矢（Haversine）距离</span>

In [None]:
def haversine_distances(points, radius=6371):
    """Pairwise great-circle (haversine) distances in nautical miles.

    :param points: tensor of shape (N, 2); column 0 is latitude and column 1
        is longitude, both in degrees
    :param radius: sphere radius in kilometres (default 6371)
    :return: (N, N) tensor of distances in nautical miles (km / 1.852). The
        main diagonal is overwritten with 1.0 rather than left at 0.
    """
    points_rad = points * torch.pi / 180.0  # degrees -> radians, (N, 2)

    latitudes = points_rad[:, 0].unsqueeze(1)   # (N, 1)
    longitudes = points_rad[:, 1].unsqueeze(1)  # (N, 1)

    # Broadcasting a column against its transpose gives all pairwise
    # differences, shape (N, N).
    dlat = latitudes - latitudes.T
    dlon = longitudes - longitudes.T

    a = (torch.sin(dlat / 2) ** 2 +
         torch.cos(latitudes) * torch.cos(latitudes.T) * torch.sin(dlon / 2) ** 2)
    # Rounding can push ``a`` marginally outside [0, 1], which would make
    # sqrt/arcsin return NaN for (near-)antipodal points.
    a = a.clamp(min=0.0, max=1.0)

    c = 2 * torch.arcsin(torch.sqrt(a))  # central angle in radians
    distances_km = radius * c

    distances_nmi = distances_km / 1.852  # 1 nautical mile = 1.852 km

    distances_nmi.fill_diagonal_(1.0)

    return distances_nmi


In [None]:
def process_data(F_data1, input_len, pred_len, statics_features):
    """Split a batch of raw windows into inputs, targets and static features.

    Every window is a 2-D array of shape (input_len + pred_len, num_nodes * 5),
    where each run of five columns belongs to one node. A node counts as
    valid when none of its five values is NaN.

    :param F_data1: sequence of windows (the batch)
    :param input_len: number of leading rows used as model input
    :param pred_len: number of trailing rows used as prediction target
    :param statics_features: dict mapping ``str(node_index)`` to that node's
        static feature list
    :return: seven lists with one entry per window:
        - input_X_all: per time step, array (k_t, 5) of the nodes valid at
          that step
        - F_p_all: indices of nodes valid over the whole window
        - output_Y_all: array (pred_len, len(f_p), 5), or an empty array when
          no node is fully valid
        - S_all: per time step, index array of the nodes valid at that step
        - static_result_all: per time step, static features of those nodes
        - input_x_all: array (input_len, len(f_p), 5), or an empty array
        - Static_result_all: static features of the fully valid nodes
          repeated along a leading axis of length input_len, or an empty array
        The per-time-step lists always contain exactly ``input_len`` entries;
        a step with no valid node contributes empty arrays, so later steps
        stay aligned. An empty batch returns seven empty lists.
    """
    total_len = input_len + pred_len  # rows per window

    batch_size = len(F_data1)
    if batch_size == 0:
        return [], [], [], [], [], [], []

    sample_shape = np.array(F_data1[0]).shape
    num_nodes = sample_shape[1] // 5  # five columns per node

    # Array form of the static features so nodes can be fancy-indexed.
    statics_list = np.array([statics_features[str(i)] for i in range(num_nodes)], dtype=float)

    F_p_all = []
    input_X_all = []
    output_Y_all = []
    S_all = []
    static_result_all = []
    input_x_all = []
    Static_result_all = []

    for i in range(batch_size):
        F_data = np.array(F_data1[i])
        F_data_reshaped = F_data.reshape(total_len, num_nodes, 5)

        # Nodes with no NaN anywhere in the window (inputs and targets).
        valid_node_mask = ~np.isnan(F_data_reshaped).any(axis=(0, 2))
        f_p = np.where(valid_node_mask)[0].tolist()
        F_p_all.append(f_p)

        in_x_reshaped = F_data_reshaped[:input_len]   # (input_len, num_nodes, 5)
        out_y_reshaped = F_data_reshaped[input_len:]  # (pred_len, num_nodes, 5)

        # (input_len, num_nodes): True where a node has all five values.
        valid_in_mask = ~np.isnan(in_x_reshaped).any(axis=2)

        # Walk the time axis explicitly so that a step with no valid node
        # still yields an (empty) entry instead of being dropped.
        input_X = []
        S = []
        static_result = []
        for t in range(input_len):
            nodes_t = np.where(valid_in_mask[t])[0]
            input_X.append(in_x_reshaped[t, nodes_t, :])
            S.append(nodes_t)
            static_result.append(statics_list[nodes_t])

        input_X_all.append(input_X)
        S_all.append(S)
        static_result_all.append(static_result)

        # Dense tensors restricted to the fully valid nodes.
        if not f_p:
            output_y = np.array([])
            input_x0 = np.array([])
            statics_list0 = np.array([])
        else:
            output_y = out_y_reshaped[:, f_p, :]
            input_x0 = in_x_reshaped[:, f_p, :]
            valid_statics = statics_list[f_p, :]
            statics_list0 = np.repeat(valid_statics[np.newaxis, :, :], input_x0.shape[0], axis=0)
        output_Y_all.append(output_y)
        input_x_all.append(input_x0)
        Static_result_all.append(statics_list0)

    return input_X_all, F_p_all, output_Y_all, S_all, static_result_all, input_x_all, Static_result_all

<span style='color:red; font-size:25px'>提取静态信息（返回一个字典）</span>

In [None]:
def extract_mmsi_features(file_path):
    """Group the static feature row of a CSV file by MMSI.

    The file is read without a header. Row 0 holds the MMSI of every column
    and row 2 holds the feature value of every column. Columns whose MMSI
    cell is empty are skipped. Values of columns sharing an MMSI are collected
    in column order.

    :param file_path: path of the CSV file
    :return: dict whose keys are the strings "0", "1", ... (order of first
        appearance of each MMSI) and whose values are the feature lists
    """
    table = pd.read_csv(file_path, header=None)
    mmsi_row = table.iloc[0]
    feature_row = table.iloc[2]

    # MMSI -> feature values, in column order.
    grouped = {}
    for col_idx, mmsi in enumerate(mmsi_row):
        if pd.isna(mmsi):
            continue
        grouped.setdefault(mmsi, []).append(feature_row[col_idx])

    # Replace the MMSI keys by their 0-based position, as strings.
    return {str(position): values for position, values in enumerate(grouped.values())}

<span style='color:red; font-size:25px'>将 one-hot 编码转换成索引</span>

In [None]:
def one_hot_to_index(one_hot_str):
    """Return the position of the first ``1`` in a comma-separated one-hot string.

    :param one_hot_str: string such as ``"0,0,1,0"``
    :return: 0-based index of the first element equal to 1
    :raises ValueError: if a token is not an integer or no element equals 1
    """
    bits = [int(token) for token in one_hot_str.split(',')]
    return bits.index(1)

def transform_data(data_dict):
    """Replace the leading one-hot string of every feature list by its index.

    :param data_dict: dict mapping a key to a list whose first element is a
        comma-separated one-hot string
    :return: new dict with the same keys; each value is
        ``[index_of_the_one] + remaining_elements``
    """
    return {
        key: [one_hot_to_index(features[0])] + features[1:]
        for key, features in data_dict.items()
    }

# <span style='color:red; font-size:25px'>提取静态数据特征</span>

In [None]:
def extract_sfeatures(S, features_dict):
    """Look up the static features of the nodes present at every time step.

    :param S: iterable of node-id collections, one per time step
    :param features_dict: dict mapping ``str(node_id)`` to a feature list
    :return: list with one entry per time step; each entry is a list of
        per-node feature lists converted to ``float``. A node missing from
        ``features_dict`` contributes an empty list.
    """
    result = []
    for node_ids in S:
        step_features = []
        for node_id in node_ids:
            raw_values = features_dict.get(str(node_id), [])
            step_features.append(list(map(float, raw_values)))
        result.append(step_features)
    return result

<span style='color:red; font-size:25px'>转换成 one-hot 编码</span>

In [None]:
def convert_to_onehot(data, device, onehot_length = 20):
    """One-hot encode the first feature column of a 2-D tensor.

    Only the one-hot block is returned; the remaining columns of ``data`` are
    not appended to it.

    :param data: tensor of shape [N, F]; column 0 holds class indices
    :param device: device on which the result is allocated
    :param onehot_length: number of classes (width of the encoding)
    :return: float32 tensor of shape [N, onehot_length]
    """
    # Column 0 as an (N, 1) integer index for scatter_.
    class_column = data[:, 0].long().unsqueeze(1)
    canvas = torch.zeros((data.size(0), onehot_length), dtype=torch.float32, device=device)
    # scatter_ writes in place and returns the tensor it modified.
    return canvas.scatter_(1, class_column, 1)

<span style='color:red; font-size:25px'>异构图卷积</span>

In [None]:
class Spatial_GATD(nn.Module):
    """Two stacked GATConv layers that map node features back to ``input_dim``.

    Layer 1 maps ``input_dim`` to ``hidden_dim`` per head and concatenates
    ``gat_heads`` heads. Layer 2 maps ``hidden_dim * gat_heads`` back to
    ``input_dim`` with a single head. Neither layer adds self-loops.

    :param input_dim: size of the input and output node features
    :param hidden_dim: per-head output size of the first layer
    :param gat_heads: number of attention heads in the first layer
    :param edge_dim: edge feature size passed to both GATConv layers. The
        default ``None`` reproduces the previous construction.
        NOTE(review): torch_geometric's GATConv appears to use ``edge_attr``
        only when ``edge_dim`` is set, so with ``None`` the edge attributes
        given to ``forward`` would have no effect -- confirm against the
        installed version and pass ``edge_dim`` if edge features should
        contribute to the attention.
    """
    def __init__(self, input_dim, hidden_dim, gat_heads, edge_dim=None):
        super(Spatial_GATD, self).__init__()
        # input_dim -> hidden_dim per head, heads concatenated.
        self.gat1 = GATConv(input_dim, hidden_dim, heads=gat_heads, concat=True,
                            add_self_loops=False, edge_dim=edge_dim)
        # hidden_dim * gat_heads -> input_dim with a single head.
        self.gat2 = GATConv(hidden_dim * gat_heads, input_dim, heads=1, concat=True,
                            add_self_loops=False, edge_dim=edge_dim)

    def forward(self, x, edge_index, edge_attr):
        """Apply both attention layers with an ELU in between.

        :param x: node features, (num_nodes, input_dim)
        :param edge_index: edge index tensor passed to both layers
        :param edge_attr: edge attributes passed to both layers
        :return: node features, (num_nodes, input_dim)
        """
        x = self.gat1(x, edge_index, edge_attr = edge_attr)  # (num_nodes, hidden_dim * gat_heads)

        x = F.elu(x)

        x = self.gat2(x, edge_index, edge_attr = edge_attr)  # (num_nodes, input_dim)

        return x

<span style='color:red; font-size:20px'>异构图计算</span>

In [None]:
class H_Model(torch.nn.Module):
    """Heterogeneous wrapper around ``Spatial_GATD`` for DYA/STA graphs.

    The homogeneous two-layer GAT is converted with ``to_hetero`` for the node
    types ``DYA`` and ``STA`` and the edge types ``DD``, ``DS``, ``rev_DS``
    and ``SS``.

    :param hidden_dimS: used as both input and hidden size of the GAT
    :param num_heads: number of attention heads of the first GAT layer
    :param embedding_dim2: accepted but not used by this class
    """
    def __init__(self, hidden_dimS, num_heads, embedding_dim2):
        super().__init__()

        node_types = ['DYA', 'STA']
        edge_types = [
            ('DYA', 'DD', 'DYA'),      # DYA -> DYA
            ('DYA', 'DS', 'STA'),      # DYA -> STA
            ('STA', 'rev_DS', 'DYA'),  # STA -> DYA (reverse of DS)
            ('STA', 'SS', 'STA'),      # STA -> STA
        ]

        # Build the homogeneous network first, then lift it to the
        # heterogeneous schema; only the lifted module is kept.
        homogeneous_gat = Spatial_GATD(hidden_dimS, hidden_dimS, num_heads)
        self.gatD = to_hetero(homogeneous_gat, metadata=(node_types, edge_types))

    def forward(self, xd_dict, data_D):
        """Run the heterogeneous GAT.

        :param xd_dict: dict of node features keyed by node type
        :param data_D: object exposing ``edge_index_dict`` and
            ``edge_attr_dict``
        :return: tuple (DYA node features, STA node features)
        """
        node_outputs = self.gatD(xd_dict, data_D.edge_index_dict, data_D.edge_attr_dict)
        return node_outputs['DYA'], node_outputs['STA']

<span style='color:red; font-size:25px'>构建异构图数据集</span>

In [None]:
def HeteroGraphBuilder_batch(dynamic_tensor, static_tensor, masks, device, distances):
    """Build one batched HeteroData graph from dynamic and static node features.

    Every (batch, time step) pair is one sub-graph of N nodes. Node features
    of all sub-graphs are flattened into a single node list per node type
    ('DYA' for dynamic, 'STA' for static). Edges that touch a node whose mask
    entry is 0 are removed; the node feature matrices themselves keep all
    B * T * N rows, and the node ids are not renumbered.

    :param dynamic_tensor: dynamic node features, unpacked as
        (batch_size, time_steps, num_nodes, dynamic_feat_dim)
    :param static_tensor: static node features, unpacked as
        (batch_size, time_steps, num_nodes, static_feat_dim)
    :param masks: mask tensor; after flattening, entries equal to 0 mark the
        nodes to drop (presumably (batch_size, time_steps, num_nodes) --
        TODO confirm against the caller)
    :param device: device that the tensors and index tensors are moved to
    :param distances: 2-D indexable tensor read with the global node ids of
        the edge end points, ``distances[src, dst]``, to obtain the edge
        attributes (NOTE(review): this requires it to cover the flattened
        B * T * N index range -- confirm its shape with the caller)
    :return: (dataD, batch): the HeteroData object after ``ToUndirected`` and
        the sub-graph index of every node that survived the mask
    """    
    B, T, N, F_d = dynamic_tensor.shape
    _, _, _, F_s = static_tensor.shape
    
    dynamic_tensor = dynamic_tensor.to(device)
    static_tensor = static_tensor.to(device)
    masks = masks.to(device)
    
    total_subgraphs = B * T  # number of sub-graphs
    total_nodes = B * T * N  # not used below

    # Flatten the features of all sub-graphs into one node list per type.
    X_D_global = dynamic_tensor.reshape(total_subgraphs * N, F_d)     # (B * T * N, F_d)
    X_S_global = static_tensor.reshape(total_subgraphs * N, F_s)      # (B * T * N, F_s)

    # Sub-graph id of every node: 0 repeated N times, 1 repeated N times, ...
    batch_vector = torch.repeat_interleave(torch.arange(total_subgraphs, device=device), N)

    # Node ids inside a single sub-graph.
    node_ids_local = torch.arange(N, dtype=torch.long, device=device)

    # Same-type edges of one sub-graph: every unordered pair once, plus self-loops.
    if N > 1:
        edges_local = torch.combinations(node_ids_local.cpu(), r=2, with_replacement=False).T.to(device)   # all pairs (i, j) with i < j
    else:
        edges_local = torch.empty((2, 0), dtype=torch.long, device=device)
    self_loops = torch.stack([node_ids_local, node_ids_local], dim=0)
    edges_local = torch.cat([edges_local, self_loops], dim=1)  # (2, num_edges)
    # Global id of the first node of every sub-graph.
    offsets = torch.arange(total_subgraphs, device=device) * N   # start node of each sub-graph
    
    # Cross-type edges of one sub-graph: start from every ordered pair ...
    cross_edges_local = torch.cartesian_prod(node_ids_local, node_ids_local).T.to(device)  # (2, N*N)
    
    # ... and drop the pairs whose two local ids are equal.
    mask = cross_edges_local[0] != cross_edges_local[1]
    cross_edges_local = cross_edges_local[:, mask]  # (2, N*(N-1))
    
    # offsets is already 1-D; flatten() leaves it unchanged.
    offsets_flattened = offsets.flatten()
    
    # Replicate the local cross edges for every sub-graph by adding its offset.
    cross_edges_expanded = (
        cross_edges_local.unsqueeze(1) + offsets_flattened.unsqueeze(0).unsqueeze(-1)
    ).reshape(2, -1)  # (2, total_subgraphs * N * (N-1))


    # Replicate the local same-type edges for every sub-graph in the same way.
    edges_expanded = (edges_local.unsqueeze(1) + offsets.unsqueeze(0).unsqueeze(-1)).reshape(2, -1)   # global same-type edges

    # Assemble the HeteroData object.
    dataD = HeteroData()
    
    # Dynamic nodes and their edges.
    dataD['DYA'].x = X_D_global
    dataD['DYA', 'DD', 'DYA'].edge_index = edges_expanded
    # Static nodes.
    dataD['STA'].x = X_S_global
    # Edges from dynamic to static nodes.
    dataD['DYA', 'DS', 'STA'].edge_index = cross_edges_expanded

    # Edge index tensors that are filtered below.
    edgeD_DD = dataD['DYA', 'DD', 'DYA'].edge_index  # one direction per pair here; ToUndirected is applied at the end
    edgeD_DS = dataD['DYA', 'DS', 'STA'].edge_index  # DYA -> STA only

    # Locate padded nodes (mask == 0). Global node ids are kept as originally
    # assigned; nothing is renumbered.
    # Flatten the mask to line up with the flattened node list.
    flattened_tensor = masks.flatten()
    # Global ids of the nodes to drop.
    zero_indices = (flattened_tensor == 0).nonzero(as_tuple=False).squeeze()
    # Boolean keep-mask over all nodes.
    Mask = torch.ones(batch_vector.size(0), dtype=torch.bool, device=device)
    Mask[zero_indices] = False  # dropped nodes become False
    batch = batch_vector[Mask]

    # Keep an edge only if neither end point is a dropped node.
    maskD_DD = ~(
        torch.isin(edgeD_DD[0], zero_indices) |
        torch.isin(edgeD_DD[1], zero_indices)
    )
    # print('maskD_DD:',maskD_DD)
    maskD_DS = ~(
        torch.isin(edgeD_DS[0], zero_indices) |
        torch.isin(edgeD_DS[1], zero_indices)
    )
   
    # Remove the edges that touch dropped nodes.
    edgeD_DD = edgeD_DD[:, maskD_DD]
    edgeD_DS = edgeD_DS[:, maskD_DS]
  
    # Store the filtered edges; the STA-STA edges reuse the DYA-DYA edge set.
    dataD['DYA', 'DD', 'DYA'].edge_index = edgeD_DD
    dataD['DYA', 'DS', 'STA'].edge_index = edgeD_DS
    dataD['STA', 'SS', 'STA'].edge_index = edgeD_DD

    # Edge attributes are read from ``distances`` with the global end point ids.
    adj_DD = distances[edgeD_DD[0], edgeD_DD[1]]
    adj_DS = distances[edgeD_DS[0], edgeD_DS[1]]
    dataD['DYA', 'DD', 'DYA'].edge_attr = adj_DD
    dataD['DYA', 'DS', 'STA'].edge_attr = adj_DS
    dataD['STA', 'SS', 'STA'].edge_attr = adj_DD
    
    # Convert to an undirected graph.
    dataD = ToUndirected()(dataD)

    return dataD, batch

<span style='color:red; font-size:25px'>获取适合于 TCN 输入的数据</span>

In [None]:
def process_tensor_and_extract_features(tensor, batch_size, input_len, masks, device):
    """Collect the per-node time series of every non-padded node in a batch.

    For each sample, the nodes whose mask is non-zero at one or more time
    steps are kept. The kept nodes of all samples are concatenated along the
    node axis.

    :param tensor: features of shape (batch_size, time_steps, num_nodes, F)
    :param batch_size: number of leading samples of ``tensor`` to process
    :param input_len: accepted but not used by this function
    :param masks: mask of shape (batch_size, time_steps, num_nodes)
    :param device: device the result is moved to
    :return: tensor of shape (total_kept_nodes, F, 1, time_steps)
    """
    kept_per_sample = []
    for sample_idx in range(batch_size):
        # A node is kept if its mask is set at any time step.
        node_is_used = masks[sample_idx].any(dim=0)      # (num_nodes,)
        kept_nodes = torch.where(node_is_used)[0]
        sample = tensor[sample_idx]                      # (T, num_nodes, F)
        kept_per_sample.append(sample[:, kept_nodes, :])  # (T, kept, F)

    # (T, K, F) -> (K, T, F) -> (K, F, T)
    node_major = torch.cat(kept_per_sample, dim=1).permute(1, 0, 2).to(device)
    feature_major = node_major.permute(0, 2, 1)

    # Insert a singleton axis: (K, F, 1, T).
    return feature_major.unsqueeze(2)

<span style='color:red; font-size:25px'>LSTM</span>

In [None]:
class LSTMNetwork(nn.Module):
    """Thin wrapper around a batch-first ``nn.LSTM``.

    :param input_size: number of input features per time step
    :param hidden_size: size of the hidden state
    :param num_layers: number of stacked LSTM layers
    :param bidirectional: whether to run the LSTM in both directions
    :param dropout: dropout between layers; forced to 0.0 for a single layer
    """
    def __init__(self, input_size, hidden_size, num_layers, bidirectional=False, dropout=0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        # Inter-layer dropout only applies when there is more than one layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=inter_layer_dropout,
        )

    def forward(self, x):
        """Run the LSTM over a batch of sequences.

        :param x: tensor of shape (batch_size, seq_length, input_size)
        :return: tuple ``(output, h_n, c_n)``; ``output`` has shape
            (batch_size, seq_length, hidden_size * num_directions), and
            ``h_n`` / ``c_n`` are the final hidden and cell states as
            returned by ``nn.LSTM``
        """
        output, final_state = self.lstm(x)
        h_n, c_n = final_state
        return output, h_n, c_n

<span style='color:red; font-size:20px'>平均池化</span>

In [None]:
def masked_average_pooling(x, mask, pool_dim, device):
    """Average ``x`` over one axis while ignoring masked-out entries.

    :param x: tensor of shape (batch_size, feature_dim, num_nodes, time_steps)
    :param mask: tensor of shape (batch_size, time_steps, num_nodes); 0 marks
        padded entries
    :param pool_dim: axis of ``x`` to pool over: 1 = feature_dim,
        2 = num_nodes, 3 = time_steps
    :param device: device both tensors are moved to
    :return: ``x`` with the pooled axis removed. Each value is the sum of the
        unmasked entries divided by their count, with the count clamped to at
        least 1e-6 so fully masked slices give 0.
    :raises ValueError: if ``pool_dim`` is not 1, 2 or 3
    """
    if pool_dim not in (1, 2, 3):
        raise ValueError(f"pool_dim must be 1 (feature), 2 (node) or 3 (time), got {pool_dim!r}")

    x = x.to(device)
    # (B, T, N) -> (B, 1, T, N) -> (B, 1, N, T) so the mask broadcasts over
    # the feature axis of x; cast to the dtype of x for the multiplication.
    mask = mask.to(device).unsqueeze(1).permute(0, 1, 3, 2).type_as(x)

    x_masked = x * mask  # padded entries become 0

    if pool_dim == 1:
        # The mask has a singleton feature axis; expand it so the count
        # reflects the number of features that were summed.
        valid_count = mask.expand(-1, x.size(1), -1, -1).sum(dim=1)
    else:
        valid_count = mask.sum(dim=pool_dim)

    x_avg = x_masked.sum(dim=pool_dim) / valid_count.clamp(min=1e-6)
    return x_avg

<span style='color:red; font-size:20px'>最大池化</span>

In [None]:
def masked_max_pooling(x, mask, pool_dim, device):
    """Take the maximum of ``x`` over one axis while ignoring masked-out entries.

    :param x: tensor of shape (batch_size, feature_dim, num_nodes, time_steps)
    :param mask: tensor of shape (batch_size, time_steps, num_nodes); 0 marks
        padded entries
    :param pool_dim: axis of ``x`` to pool over: 1 = feature_dim,
        2 = num_nodes, 3 = time_steps
    :param device: device both tensors are moved to
    :return: ``x`` with the pooled axis removed. Slices that contain no
        unmasked entry are set to 0 instead of -inf.
    :raises ValueError: if ``pool_dim`` is not 1, 2 or 3
    """
    if pool_dim not in (1, 2, 3):
        raise ValueError(f"pool_dim must be 1 (feature), 2 (node) or 3 (time), got {pool_dim!r}")

    # (B, T, N) -> (B, 1, T, N) -> (B, 1, N, T) so the mask broadcasts over
    # the feature axis of x.
    mask = mask.unsqueeze(1).permute(0, 1, 3, 2).to(device)
    x = x.to(device)

    # Padded entries can never win the max.
    x_masked = x.masked_fill(mask == 0, float('-inf'))

    x_max, _ = x_masked.max(dim=pool_dim)
    # A slice with no valid entry would be -inf; report 0 for it instead.
    x_max = x_max.masked_fill(mask.sum(dim=pool_dim) == 0, 0)

    return x_max

In [None]:
class AttentionModule(nn.Module):
    """Two-branch attention built from masked average and max pooling.

    Branch 1 pools ``x`` over ``pool_dim1`` and produces a gate that is
    multiplied with ``x`` directly. Branch 2 pools over ``pool_dim2`` and
    produces a gate that is multiplied with ``x`` permuted to
    (batch, time, feature, node). The two gated tensors are summed.

    NOTE(review): the shape comments in ``forward`` suggest ``pool_dim1 = 1``
    (feature axis) and ``pool_dim2 = 3`` (time axis) -- confirm with the
    caller.

    :param kernel_size: kernel size of both convolutions; padding is
        ``(kernel_size - 1) // 2``
    :param pool_dim1: axis of ``x`` pooled for the first branch
    :param pool_dim2: axis of ``x`` pooled for the second branch
    :param input_channels: channel multiplier of the second branch's
        convolution (default 1)
    """
    def __init__(self, kernel_size, pool_dim1, pool_dim2, input_channels = 1):
        super(AttentionModule, self).__init__()
        self.pool_dim1 = pool_dim1
        self.pool_dim2 = pool_dim2
        padding = (kernel_size - 1) // 2

        # Node/time interaction: (avg, max) maps -> one gate map.
        self.nt_conv = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=kernel_size, padding = padding),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Sigmoid()
        )

        # Feature/node interaction. The batch norm must match the number of
        # channels the convolution emits, which is ``input_channels``.
        self.fn_conv = nn.Sequential(
            nn.Conv2d(2 * input_channels, input_channels, kernel_size=kernel_size, padding = padding),
            nn.BatchNorm2d(input_channels),
            nn.ReLU(),
            nn.Sigmoid()
        )

    def forward(self, x, mask, device):
        """Apply both attention branches and add their results.

        :param x: tensor of shape (batch_size, feature_dim, num_nodes, time_steps)
        :param mask: tensor of shape (batch_size, time_steps, num_nodes)
        :param device: device forwarded to the pooling helpers
        :return: tensor with the same shape as ``x``
        """
        # Branch 1: pool over pool_dim1, stack (avg, max) as two channels.
        avg_nt = masked_average_pooling(x, mask, self.pool_dim1, device).unsqueeze(1)
        max_nt = masked_max_pooling(x, mask, self.pool_dim1, device).unsqueeze(1)
        am_nt = torch.cat((avg_nt, max_nt), dim = 1)
        att_nt = self.nt_conv(am_nt)  # expected (batch_size, 1, num_nodes, time_steps)
        x_ant = att_nt * x

        # Branch 2: pool over pool_dim2, stack (avg, max) as two channels.
        avg_fn = masked_average_pooling(x, mask, self.pool_dim2, device).unsqueeze(1)
        max_fn = masked_max_pooling(x, mask, self.pool_dim2, device).unsqueeze(1)
        am_fn = torch.cat((avg_fn, max_fn), dim = 1)
        att_fn = self.fn_conv(am_fn)  # expected (batch_size, 1, feature_dim, num_nodes)

        # Move time to axis 1 so the gate broadcasts over it:
        # (batch_size, time_steps, feature_dim, num_nodes).
        data_pfn = x.permute(0, 3, 1, 2)
        weighted_dfn = att_fn * data_pfn

        # Back to (batch_size, feature_dim, num_nodes, time_steps).
        x_afn = weighted_dfn.permute(0, 2, 3, 1)

        return x_ant + x_afn

<span style='color:red; font-size:25px'>iTransformer 编码器</span>

In [None]:
class GeLU(nn.Module):
    """GELU activation computed with the tanh approximation."""
    def forward(self, input_tensor):
        """Return ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))``."""
        cubic_term = input_tensor + 0.044715 * torch.pow(input_tensor, 3)
        gate = 1 + torch.tanh(math.sqrt(2 / math.pi) * cubic_term)
        return 0.5 * input_tensor * gate

In [None]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalisation.

    :param normalized_shape: shape passed to ``nn.LayerNorm``
    :param dropout: dropout probability applied to the sub-layer output
    """
    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        """Return ``LayerNorm(X + Dropout(Y))``.

        :param X: residual input
        :param Y: sub-layer output with the same shape as ``X``
        """
        residual_sum = X + self.dropout(Y)
        return self.ln(residual_sum)

In [None]:
class PositionWiseFFN(nn.Module):
    """Position-wise feed-forward network: Linear -> GeLU -> Linear.

    :param ffn_num_input: input feature size of the first linear layer
    :param ffn_num_hiddens: hidden feature size
    :param ffn_num_outputs: output feature size of the second linear layer
    """
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super().__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.gelu = GeLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        """Apply both linear layers to the last axis of ``X``."""
        hidden = self.dense1(X)
        activated = self.gelu(hidden)
        return self.dense2(activated)

In [None]:
class DotProductAttention(nn.Module):
    """Scaled dot-product attention without masking.

    :param dropout: dropout probability applied to the attention weights
    """
    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        """Attend ``queries`` over ``keys`` and ``values``.

        :param queries: (batch_size, num_queries, d)
        :param keys: (batch_size, num_kv_pairs, d)
        :param values: (batch_size, num_kv_pairs, value_dim)
        :return: (batch_size, num_queries, value_dim). The softmax weights
            (before dropout) are stored on ``self.attention_weights``.
        """
        scale = math.sqrt(queries.shape[-1])
        # Swap the last two axes of keys to get query-key similarity scores.
        scores = torch.bmm(queries, keys.transpose(1, 2)) / scale
        self.attention_weights = torch.softmax(scores, dim=-1)
        dropped_weights = self.dropout(self.attention_weights)
        return torch.bmm(dropped_weights, values)

In [None]:
def transpose_qkv(X, num_heads):
    """Split the hidden axis into heads and fold the heads into the batch axis.

    :param X: tensor of shape (batch_size, num_items, num_hiddens)
    :param num_heads: number of attention heads
    :return: tensor of shape
        (batch_size * num_heads, num_items, num_hiddens / num_heads)
    """
    batch, items = X.shape[0], X.shape[1]
    # (batch, items, heads, head_dim) -> (batch, heads, items, head_dim)
    per_head = X.reshape(batch, items, num_heads, -1).permute(0, 2, 1, 3)
    return per_head.reshape(-1, items, per_head.shape[3])

def transpose_output(X, num_heads):
    """Reverse ``transpose_qkv``: unfold the heads and merge them back.

    :param X: tensor of shape
        (batch_size * num_heads, num_items, num_hiddens / num_heads)
    :param num_heads: number of attention heads
    :return: tensor of shape (batch_size, num_items, num_hiddens)
    """
    items, head_dim = X.shape[1], X.shape[2]
    # (batch, heads, items, head_dim) -> (batch, items, heads, head_dim)
    merged = X.reshape(-1, num_heads, items, head_dim).permute(0, 2, 1, 3)
    return merged.reshape(merged.shape[0], items, -1)

In [None]:
class MultiHeadAttention_1(nn.Module):
    """Multi-head attention with square (num_hiddens x num_hiddens) projections.

    :param key_size: accepted but not used; all projections use ``num_hiddens``
    :param query_size: accepted but not used
    :param value_size: accepted but not used
    :param num_hiddens: feature size of inputs, projections and output
    :param num_heads: number of attention heads
    :param dropout: dropout probability on the attention weights
    :param bias: whether the four linear projections have a bias term
    """
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        # Query, key, value and output projections, created in this order.
        for projection_name in ("W_q", "W_k", "W_v", "W_o"):
            setattr(self, projection_name, nn.Linear(num_hiddens, num_hiddens, bias=bias))

    def forward(self, queries, keys, values):
        """Project, attend per head and merge the heads again.

        :param queries: (batch_size, num_items, num_hiddens)
        :param keys: (batch_size, num_items, num_hiddens)
        :param values: (batch_size, num_items, num_hiddens)
        :return: (batch_size, num_items, num_hiddens)
        """
        heads = self.num_heads
        # After the split each tensor is
        # (batch_size * num_heads, num_items, num_hiddens / num_heads).
        split_queries = transpose_qkv(self.W_q(queries), heads)
        split_keys = transpose_qkv(self.W_k(keys), heads)
        split_values = transpose_qkv(self.W_v(values), heads)

        per_head_output = self.attention(split_queries, split_keys, split_values)

        # Back to (batch_size, num_items, num_hiddens) before the output projection.
        merged_output = transpose_output(per_head_output, heads)
        return self.W_o(merged_output)

In [None]:
class EncoderBlock_1(nn.Module):
    """One Transformer encoder block: self-attention and a position-wise FFN,
    each followed by an ``AddNorm`` step.

    ``ffn_num_outputs`` and ``use_bias`` are accepted but not used: the FFN is
    built with ``num_hiddens`` as its output width, and the attention layer
    keeps its default bias setting.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, ffn_num_outputs, num_heads, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        # Self-attention sub-layer and its AddNorm.
        self.attention = MultiHeadAttention_1(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        # Feed-forward sub-layer (output width is num_hiddens) and its AddNorm.
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X):
        """Apply self-attention then the FFN, passing each through its AddNorm."""
        attended = self.attention(X, X, X)
        Y = self.addnorm1(X, attended)
        transformed = self.ffn(Y)
        return self.addnorm2(Y, transformed)

In [None]:
class iTransformer_1(nn.Module):
    """Stack of ``EncoderBlock_1`` layers between two linear maps.

    The last axis of the input (width ``vocab_size``) is embedded to
    ``num_hiddensT`` by ``fc``, passed through ``num_layersT`` encoder blocks,
    and finally projected to ``pred_len`` by ``linear``.
    """

    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddensT, norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 num_headsT, num_layersT, pred_len, dropout = 0, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        # Output head: num_hiddensT -> pred_len.
        self.linear = nn.Linear(num_hiddensT, pred_len)
        # Input embedding: vocab_size -> num_hiddensT.
        self.fc = nn.Linear(vocab_size, num_hiddensT)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layersT):
            block = EncoderBlock_1(
                key_size, query_size, value_size, num_hiddensT,
                norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                num_headsT, dropout, use_bias)
            self.blks.add_module(f"block{layer_idx}", block)

    def forward(self, X, *args):
        """Map ``X`` of shape (..., vocab_size) to (..., pred_len).

        Extra positional arguments are accepted and ignored.
        """
        hidden = self.fc(X)
        for blk in self.blks:
            hidden = blk(hidden)
        return self.linear(hidden)

<span style = 'color:red;font-size:25px'>STE模块</span>

In [None]:
class STE_block(nn.Module):
    """Spatio-temporal encoder block.

    Combines three branches over the heterogeneous graph ``data_D``:

    * ``S_GAT1`` (``H_Model``): graph attention over the 'DYA' and 'STA' node
      features after they are projected to ``hidden_dimS`` by ``fc``;
    * ``att1`` (``AttentionModule``): coordinate attention over the same
      projected features, laid out as (batch, feature, node, time);
    * ``iTsf`` (``iTransformer_1``): a transformer over the raw node features
      with the time axis last.

    The graph and coordinate-attention outputs are summed and summarised by a
    shared LSTM (``Lstm1``) whose final states are returned together with the
    transformer output.
    """

    def __init__(self, hidden_dimS, num_heads, embedding_dim2, num_layersl, hidden_sizel, kernel_size,
                 hidden_dimT, num_layersT, ffn_num_hiddens, num_headsT, input_len, pred_len):
        super(STE_block, self).__init__()

        # Heterogeneous graph network.
        self.S_GAT1 = H_Model(hidden_dimS, num_heads, embedding_dim2)
        self.hidden_sizel = hidden_sizel
        # Temporal summary (LSTM) and the time-axis transformer.
        self.Lstm1 = LSTMNetwork(hidden_dimS, hidden_sizel, num_layersl)
        self.iTsf = iTransformer_1(input_len, input_len, input_len, input_len, hidden_dimT, hidden_dimT,
                                   hidden_dimT, ffn_num_hiddens, hidden_dimT, num_headsT, num_layersT, pred_len)
        # Coordinate attention.
        self.att1 = AttentionModule(kernel_size, pool_dim1 = 1, pool_dim2 = 3)
        # Shared projection applied to both node types; expects embedding_dim2 * 4 input features.
        self.fc = nn.Linear(embedding_dim2 * 4, hidden_dimS)

    def forward(self, input_x, data_D, device, masks_x, distances):
        """Encode one padded batch.

        Args:
            input_x: only its shape (batch_size, time_steps, num_nodes, _) is read.
            data_D: graph whose ``['DYA'].x`` and ``['STA'].x`` hold one row per
                (batch, time step, node); both must have ``embedding_dim2 * 4``
                features because they go through the same ``fc``.
            device: forwarded to the coordinate-attention module.
            masks_x: forwarded to the coordinate-attention module.
            distances: unused here.

        Returns:
            ``(h_d + h_s, c_d + c_s, output)``: the summed second and third
            return values of ``Lstm1`` for the dynamic and static streams
            (presumably final hidden / cell states; confirm against
            ``LSTMNetwork``), and the transformer output of shape
            (batch_size * num_nodes, feature_dim, pred_len).
        """
        batch_size, time_steps, num_nodes, _ = input_x.shape

        # Project both node types; rows are (batch_size * time_steps * num_nodes, hidden_dimS).
        x_d, x_s = self.fc(data_D['DYA'].x), self.fc(data_D['STA'].x)
        x_dictD0 = {'DYA': x_d, 'STA': x_s}

        # --- Spatial features: heterogeneous graph attention ---
        # NOTE(review): presumably returns (batch_size * time_steps * num_nodes, feature_dim)
        # per node type; confirm against H_Model.
        x_dynamic0, x_static0 = self.S_GAT1(x_dictD0, data_D)

        # --- Coordinate attention ---
        X_dynamic = x_dictD0['DYA'].reshape(batch_size, time_steps, num_nodes, -1)
        X_static = x_dictD0['STA'].reshape(batch_size, time_steps, num_nodes, -1)
        # Layout for the attention module: (batch_size, 2 * feature_dim, num_nodes, time_steps).
        X_SD = torch.cat((X_dynamic, X_static), dim = -1).permute(0, 3, 2, 1)
        X_SD0 = self.att1(X_SD, masks_x, device)
        # Named feat_dim (not F) so torch.nn.functional's alias is not shadowed.
        feat_dim = X_SD0.shape[1]
        # Back to one row per (batch, time step, node): (B*T*N, feat_dim).
        X_SD0 = X_SD0.permute(0, 3, 2, 1).reshape(batch_size * time_steps * num_nodes, -1)

        # --- Feature fusion ---
        # First half of the channels is the dynamic stream, second half the static one.
        X_D0, X_S0 = X_SD0[:, : feat_dim // 2], X_SD0[:, feat_dim // 2:]
        X_D1, X_S1 = x_dynamic0 + X_D0, x_static0 + X_S0
        # Regroup as one sequence per (batch, node): (B*N, T, feature_dim).
        X_D1 = X_D1.reshape(batch_size, time_steps, num_nodes, -1).permute(0, 2, 1, 3).reshape(batch_size * num_nodes, time_steps, -1)
        X_S1 = X_S1.reshape(batch_size, time_steps, num_nodes, -1).permute(0, 2, 1, 3).reshape(batch_size * num_nodes, time_steps, -1)
        # The same LSTM summarises both streams.
        _, h_d, c_d = self.Lstm1(X_D1)
        _, h_s, c_s = self.Lstm1(X_S1)

        # --- Long-horizon temporal branch ---
        # Raw (un-projected) node features with time last: (B*N, F, T).
        X_SD1 = torch.cat((data_D['DYA'].x, data_D['STA'].x), dim = -1).reshape(batch_size, time_steps, num_nodes, -1).permute(0, 2, 3, 1)
        X_SD1 = X_SD1.reshape(batch_size * num_nodes, -1, time_steps)

        # The transformer maps the time axis to pred_len: (B*N, F, pred_len).
        output = self.iTsf(X_SD1)

        return h_d + h_s, c_d + c_s, output

<span style = 'color:red;font-size:25px'>解码器</span>

In [None]:
class DecoderModel(nn.Module):
    """LSTM decoder followed by a linear projection to the output features."""

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        """
        :param input_dim: number of features of each decoder input step
        :param hidden_dim: LSTM hidden size
        :param num_layers: number of stacked LSTM layers
        :param output_dim: number of output features per step
        """
        super(DecoderModel, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

    def forward(self, target, h, c):
        """Decode the whole input sequence in one LSTM call.

        :param target: decoder input sequence, shape (batch_size, target_len, input_dim)
                       (the LSTM is built with ``batch_first=True``)
        :param h: initial hidden state, shape (num_layers, batch_size, hidden_dim)
        :param c: initial cell state, shape (num_layers, batch_size, hidden_dim)
        :return: projected outputs, shape (batch_size, target_len, output_dim)
        """
        # The states are made contiguous before the call (they may arrive as
        # views from the encoder).
        output, _ = self.lstm(target, (h.contiguous(), c.contiguous()))
        # Project every time step to the output dimension.
        return self.fc(output)

<span style = 'color:red;font-size:25px'>异构图静态特征嵌入</span>

In [None]:
class Embedding(torch.nn.Module):
    """Embeds the four categorical static vessel attributes of the 'STA' nodes.

    Columns 0..3 of ``dataD['STA'].x`` are treated as integer codes for ship
    type, ship length, ship width and draught. Each gets its own
    ``nn.Embedding`` of width ``embedding_dim2``; the four embeddings are
    concatenated into ``4 * embedding_dim2`` features.

    Args:
        embedding_dim2: width of each of the four embeddings.
        num_embeddings: vocabulary sizes ``(ship_type, length, width, draught)``.
            The default is the Denmark dataset. The values previously kept as
            commented-out alternatives were:

            * Denmark:    (21, 65, 420, 15)
            * California: (91, 61, 401, 23)
            * Houston:    (100, 67, 365, 23)
    """

    def __init__(self, embedding_dim2, num_embeddings=(21, 65, 420, 15)):
        super(Embedding, self).__init__()

        n_type, n_length, n_width, n_draught = num_embeddings

        # One embedding table per static attribute.
        self.embedding_layer1 = nn.Embedding(num_embeddings=n_type, embedding_dim=embedding_dim2)     # ship type
        self.embedding_layer2 = nn.Embedding(num_embeddings=n_length, embedding_dim=embedding_dim2)   # ship length
        self.embedding_layer3 = nn.Embedding(num_embeddings=n_width, embedding_dim=embedding_dim2)    # ship width
        self.embedding_layer4 = nn.Embedding(num_embeddings=n_draught, embedding_dim=embedding_dim2)  # draught

    def forward(self, dataD, device):
        """Replace ``dataD['STA'].x`` with the concatenated embeddings.

        ``dataD`` is moved to ``device`` and its 'STA' features are overwritten
        in place; the same (moved) object is returned. The intermediate tensors
        are also kept on ``self`` (``xd_s``, ``Xd_S1``..``Xd_S4``, ``Xd_S``), as
        in the original implementation.
        """
        dataD = dataD.to(device)

        self.xd_s = dataD['STA'].x

        # Look up each attribute column in its own table (codes are cast to int64).
        self.Xd_S1 = self.embedding_layer1(self.xd_s[:, 0].long())  # ship type
        self.Xd_S2 = self.embedding_layer2(self.xd_s[:, 1].long())  # ship length
        self.Xd_S3 = self.embedding_layer3(self.xd_s[:, 2].long())  # ship width
        self.Xd_S4 = self.embedding_layer4(self.xd_s[:, 3].long())  # draught

        # (num_rows, 4 * embedding_dim2)
        self.Xd_S = torch.cat((self.Xd_S1, self.Xd_S2, self.Xd_S3, self.Xd_S4), dim=1)

        # Overwrite the raw static codes with their embedded representation.
        dataD['STA'].x = self.Xd_S

        return dataD

<span style = 'color:red;font-size:25px'>预测填充</span>

In [None]:
def ipad_out(Y, masks, device):
    """Scatter packed per-node predictions back into a zero-padded batch tensor.

    Args:
        Y: packed predictions, shape (total_valid_nodes, pred_time_steps,
            feature_dim); rows are ordered batch by batch.
        masks: shape (batch_size, pred_time_steps, max_nodes). A node counts as
            valid for a batch entry only if its mask is true at every time step.
        device: device of the returned tensor.

    Returns:
        Tensor of shape (batch_size, max_nodes, pred_time_steps, feature_dim),
        zero everywhere except at the valid node positions.
    """
    batch_size, pred_time_steps, max_nodes = masks.shape
    _total_nodes, _steps, feature_dim = Y.shape

    # (batch_size, max_nodes): True where the node is present at all time steps.
    node_is_valid = masks.all(dim=1)

    output = torch.zeros(batch_size, max_nodes, pred_time_steps, feature_dim).to(device)

    # Walk through Y consuming as many rows as each batch entry has valid nodes.
    cursor = 0
    for b in range(batch_size):
        node_idx = torch.where(node_is_valid[b])[0]
        count = len(node_idx)
        output[b, node_idx, :, :] = Y[cursor:cursor + count]
        cursor += count
    return output

<span style = 'color:red;font-size:25px'>时间编码</span>

In [None]:
def generate_positional_encoding(num_steps, hidden_dim, device):
    """Build the sinusoidal (sine/cosine) positional encoding table.

    Args:
        num_steps: number of time steps (sequence length).
        hidden_dim: feature width of the encoding; may be even or odd.
        device: device on which the table is created (e.g. 'cuda' or 'cpu').

    Returns:
        Tensor of shape (num_steps, hidden_dim): even columns hold
        ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)`` with
        ``w_k = exp(-2k * ln(10000) / hidden_dim)``.
    """
    # Column vector of positions: (num_steps, 1).
    position = torch.arange(num_steps, dtype=torch.float, device=device).unsqueeze(1)

    # One frequency per sine column: (ceil(hidden_dim / 2),).
    div_term = torch.exp(torch.arange(0, hidden_dim, 2, dtype=torch.float, device=device) *
                         -(torch.log(torch.tensor(10000.0, device=device)) / hidden_dim))

    pe = torch.zeros(num_steps, hidden_dim, device=device)

    angles = position * div_term  # (num_steps, ceil(hidden_dim / 2))
    pe[:, 0::2] = torch.sin(angles)
    # For an odd hidden_dim there is one cosine column fewer than sine columns,
    # so drop the last frequency instead of failing with a shape mismatch.
    pe[:, 1::2] = torch.cos(angles[:, : hidden_dim // 2])

    return pe

def add_positional_encoding(static_features, device):
    """Add the sinusoidal positional encoding to a batch of sequences.

    Args:
        static_features: tensor of shape (batch_size * num_nodes, num_steps,
            hidden_dim).
        device: device on which the encoding table is built.

    Returns:
        ``static_features`` plus the (num_steps, hidden_dim) encoding,
        broadcast over the leading axis; the shape is unchanged.
    """
    _, num_steps, hidden_dim = static_features.shape
    encoding = generate_positional_encoding(num_steps, hidden_dim, device)
    # A leading singleton axis lets the table broadcast over every sequence.
    return static_features + encoding[None, :, :]

# <span style = 'color:red'>模型主函数</span>

In [None]:
class H_data(nn.Module):
    """Top-level trajectory prediction model.

    Pipeline of ``forward``: build a heterogeneous graph from the padded batch,
    embed the static attributes, run an LSTM over the dynamic features and add
    a positional encoding to the static ones, encode with ``STE_block`` and
    decode ``pred_len`` steps of 2-D output with ``DecoderModel``.

    ``hidden_dimD`` and ``embedding_dim1`` are accepted but not used by this
    constructor.
    """
    def __init__(self,input_len,pred_len,input_dimD,hidden_dimD,embedding_dim1,embedding_dim2,hidden_dimS,num_heads,num_layersl,
                 hidden_sizel,kernel_size,hidden_dimT,num_layersT,ffn_num_hiddens,num_headsT,num_layersl2, hidden_sizel2):
        super(H_data, self).__init__()
        self.input_len = input_len
        self.pred_len = pred_len
        
        # Static-attribute embedding (default vocabulary sizes).
        self.embed = Embedding(embedding_dim2)
        
        # Spatio-temporal encoder; note it receives the *second* LSTM configuration
        # (num_layersl2, hidden_sizel2), which therefore matches the decoder below.
        self.STE1 = STE_block(hidden_dimS, num_heads, embedding_dim2, num_layersl2, hidden_sizel2, kernel_size, 
                 hidden_dimT, num_layersT, ffn_num_hiddens, num_headsT, input_len, pred_len) 
        # LSTM over the raw dynamic features.
        self.lstm = LSTMNetwork(input_dimD, hidden_sizel, num_layersl)
        
        # Decoder producing the 2-D prediction; its input width is embedding_dim2 * 8.
        self.decoder = DecoderModel(embedding_dim2 * 8, hidden_sizel2, num_layersl2, output_dim = 2)
        
    def forward(self, input_x, static_X, device, F, S, batch_size, masks_x, distances): 
        """
        Run the model on one padded batch.

        input_x: padded dynamic features, shape (batch_size, time_steps, num_nodes, feature_dim).
        static_X: padded static features, passed to the graph builder.
        F, S: node-id bookkeeping passed by the caller; not used in this method.
        batch_size: ignored -- it is overwritten from ``input_x.shape``.
        masks_x: validity mask, shape (batch_size, time_steps, num_nodes).
        distances: passed to the graph builder and to the encoder.

        Returns ``(Y_hat, masks_out)`` where ``Y_hat`` has shape
        (batch_size, num_nodes, pred_len, 2) and ``masks_out`` is the first
        ``pred_len`` time steps of ``masks_x`` moved to ``device``.
        """
        
        # The batch_size argument is shadowed here by the actual tensor shape.
        batch_size, time_steps, num_nodes, _ = input_x.shape
        # NOTE(review): HeteroGraphBuilder_batch is defined elsewhere; it is assumed to return
        # a graph with 'DYA' / 'STA' node stores holding one row per (batch, time, node).
        data_D, batch = HeteroGraphBuilder_batch(input_x, static_X, masks_x, device, distances)
        data_D = self.embed(data_D, device)
        X_dynamic = data_D['DYA'].x
        X_static = data_D['STA'].x
    
        # Unflatten to (batch_size, time_steps, num_nodes, feature_dim) ...
        x_dynamic1, x_static1 = X_dynamic.reshape(batch_size, time_steps, num_nodes, -1), X_static.reshape(batch_size, time_steps, num_nodes, -1)
        # ... and reorder to (batch_size, num_nodes, time_steps, feature_dim).
        x_dynamic1, x_static1 = x_dynamic1.permute(0, 2, 1, 3), x_static1.permute(0, 2, 1, 3)
        
        # One sequence per (batch, node): (batch_size * num_nodes, time_steps, feature_dim).
        x_dynamic1, x_static1 = x_dynamic1.reshape(-1, time_steps, x_dynamic1.shape[-1]), x_static1.reshape(-1, time_steps, x_static1.shape[-1])
        
        # Static features get a positional encoding; dynamic features go through the LSTM.
        x_sta1 = add_positional_encoding(x_static1, device)
        x_dyn1, _, _ = self.lstm(x_dynamic1)
        feature_dim = x_dyn1.shape[-1]
        
        # Back to one row per (batch, time, node): (batch_size * time_steps * num_nodes, feature_dim).
        # NOTE(review): the static tensor is flattened with the LSTM's output width, so this
        # relies on the static feature width being equal to it -- confirm the configuration.
        x_dyn1 = x_dyn1.reshape(batch_size, num_nodes, time_steps, -1).permute(0, 2, 1, 3).reshape(-1, feature_dim)
        x_sta1 = x_sta1.reshape(batch_size, num_nodes, time_steps, -1).permute(0, 2, 1, 3).reshape(-1, feature_dim)
        # The graph's node features are overwritten in place before encoding.
        data_D['DYA'].x, data_D['STA'].x = x_dyn1, x_sta1

        # Encoder: two LSTM state tensors plus the transformer output of shape (B * N, F, pred_len).
        h_D, c_D, X_din = self.STE1(input_x, data_D, device, masks_x, distances)
        # Time axis first for the batch-first decoder: (B * N, pred_len, F).
        X_din = X_din.permute(0, 2, 1)

        # Decode: Y_h has shape (batch_size * num_nodes, pred_len, 2).
        Y_h = self.decoder(X_din, h_D, c_D)
        Y_hat = Y_h.reshape(batch_size, num_nodes, self.pred_len, -1) # (B, N, T, F)
       
        "结束"
        # Keep only the first pred_len time steps of the input mask.
        masks_out = masks_x[:, :self.pred_len, :].to(device)

        return Y_hat, masks_out          # Y_hat shape: (batch_size, max_nodes, pred_len, feature_dim),
                                                         # masks_out shape: (batch_size, pred_len, max_nodes)

<span style = 'color:red; font-size:25px'>实时显示训练结果</span>

In [None]:
class Animator:  #@save
    """Plot one or more curves incrementally inside a notebook cell."""

    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        """Create the figure and remember how to configure its first axes."""
        legend = [] if legend is None else legend
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        # A single subplot comes back as a bare Axes; wrap it so axes[0] always works.
        if nrows * ncols == 1:
            self.axes = [self.axes, ]

        def _configure_axes():
            # Closure capturing the axis settings given to the constructor.
            return d2l.set_axes(
                self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)

        self.config_axes = _configure_axes
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        """Append one point per curve and redraw the figure.

        ``y`` may be a scalar or a sequence with one value per curve; a scalar
        ``x`` is shared by all curves. Pairs containing ``None`` are skipped.
        """
        if not hasattr(y, "__len__"):
            y = [y]
        num_curves = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * num_curves
        # Lazily create one point list per curve on the first call.
        if not self.X:
            self.X = [[] for _ in range(num_curves)]
        if not self.Y:
            self.Y = [[] for _ in range(num_curves)]
        for curve, (x_val, y_val) in enumerate(zip(x, y)):
            if x_val is None or y_val is None:
                continue
            self.X[curve].append(x_val)
            self.Y[curve].append(y_val)
        # Redraw everything from scratch.
        axes = self.axes[0]
        axes.cla()
        for xs, ys, fmt in zip(self.X, self.Y, self.fmts):
            axes.plot(xs, ys, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

# <span style = 'color:red;font-size:25px'>填充数据</span>

In [None]:
def pad_outputy(output_Y):
    """Zero-pad a ragged list of (time_steps, num_nodes, features) arrays.

    :param output_Y: list with one array-like of shape
        (time_steps, num_nodes, features) per batch entry
    :return:
        - padded_output: float32 tensor of shape
          (batch_size, max_time_steps, max_nodes, features)
        - masks: float32 tensor of shape (batch_size, max_time_steps, max_nodes),
          1.0 where data is present and 0.0 where it is padding

        If ``output_Y`` is empty or any of its entries is empty, two empty
        tensors are returned instead.
    """
    # Nothing usable: signal it with empty tensors.
    if not output_Y or any(len(y) == 0 for y in output_Y):
        return torch.empty(0), torch.empty(0)

    arrays = [np.array(y) for y in output_Y if len(y) > 0]

    # Target sizes: the largest node / time extent; feature width comes from the first entry.
    max_nodes = max(arr.shape[1] for arr in arrays)
    feature_dim = arrays[0].shape[2]
    max_time_steps = max(arr.shape[0] for arr in arrays)
    batch_size = len(arrays)

    padded_output = torch.zeros((batch_size, max_time_steps, max_nodes, feature_dim), dtype=torch.float32)
    masks = torch.zeros((batch_size, max_time_steps, max_nodes), dtype=torch.float32)

    # Copy each entry into the top-left corner of its slot and mark it valid.
    for idx, arr in enumerate(arrays):
        steps, nodes, _feats = arr.shape
        padded_output[idx, :steps, :nodes, :] = torch.tensor(arr, dtype=torch.float32)
        masks[idx, :steps, :nodes] = 1.0

    return padded_output, masks


In [None]:
def pad_output(output_Y):
    """Zero-pad a ragged batch of per-time-step (num_nodes, features) arrays.

    :param output_Y: nested list; ``output_Y[i][t]`` is a NumPy array or tensor
        of shape (num_nodes, features) for batch entry ``i`` at time step ``t``
    :return:
        - padded_output: float32 tensor of shape
          (batch_size, max_time_steps, max_nodes, feature_dim)
        - masks: float32 tensor of shape (batch_size, max_time_steps, max_nodes),
          1.0 where a node is present and 0.0 where it is padding

        If ``output_Y`` contains no arrays at all (empty list, or only empty
        time-step lists), two empty tensors are returned, matching
        ``pad_outputy``.
    """
    arrays = [array for batch in output_Y for array in batch]
    # Without any array the sizes below are undefined (max() of an empty sequence).
    if not arrays:
        return torch.empty(0), torch.empty(0)

    batch_size = len(output_Y)
    max_time_steps = max(len(batch) for batch in output_Y)
    max_nodes = max(array.shape[0] for array in arrays)
    feature_dim = max(array.shape[1] for array in arrays)

    padded_output = torch.zeros((batch_size, max_time_steps, max_nodes, feature_dim), dtype=torch.float32)
    masks = torch.zeros((batch_size, max_time_steps, max_nodes), dtype=torch.float32)

    for i, batch in enumerate(output_Y):
        for t, array in enumerate(batch):
            num_nodes, num_features = array.shape
            # as_tensor accepts both NumPy arrays and tensors without the
            # copy-construct warning torch.tensor emits for tensor inputs.
            padded_output[i, t, :num_nodes, :num_features] = torch.as_tensor(array, dtype=torch.float32)
            masks[i, t, :num_nodes] = 1.0  # mark the real nodes

    return padded_output, masks

<span style = 'color:red;font-size:25px'>训练数据</span>

In [None]:
def calculate_ade(predictions, ground_truth, mask):
    """Sum of masked Euclidean displacement errors.

    Despite the name, this returns the *sum* of the per-point errors, not their
    average; the caller divides by the number of valid points.

    NOTE: this definition replaces the UTM-based ``calculate_ade`` defined
    earlier in the notebook. The error here is measured directly in the units
    of the inputs, with no projection to metres.

    :param predictions: predicted positions, shape (batch_size, pred_time, num_nodes, feature_dim)
    :param ground_truth: true positions, same shape as ``predictions``
    :param mask: validity mask, same shape; a point counts as valid when any of
        its feature entries is set
    :return: scalar tensor with the summed displacement error over valid points
    """
    # Euclidean distance per (batch, time step, node).
    displacement_error = torch.sqrt(torch.sum((predictions - ground_truth) ** 2, dim=-1))

    # Collapse the feature axis of the mask to match displacement_error.
    mask_reduced = mask.any(dim=-1).float()

    # Padding contributes zero error.
    return (displacement_error * mask_reduced).sum()


In [None]:
def denormalize_columns(normalized_data, max_values, min_values):
    """Invert min-max normalisation column by column.

    Computes ``normalized_data * (max_values - min_values) + min_values``, with
    the bounds broadcast against the last axis of the data.

    Args:
        normalized_data (numpy array or torch tensor): normalised data.
        max_values (list, numpy array or torch tensor): per-column maxima.
        min_values (list, numpy array or torch tensor): per-column minima.

    Returns:
        The de-normalised data, of the same kind (tensor / array) as the input.
    """
    if isinstance(normalized_data, torch.Tensor):
        # Match the data's dtype and device. as_tensor also accepts bounds that
        # are already tensors without the copy-construct warning of torch.tensor.
        max_values = torch.as_tensor(max_values, dtype=normalized_data.dtype, device=normalized_data.device)
        min_values = torch.as_tensor(min_values, dtype=normalized_data.dtype, device=normalized_data.device)
        return normalized_data * (max_values - min_values) + min_values

    # Otherwise compute with NumPy.
    max_values = np.array(max_values)
    min_values = np.array(min_values)
    return normalized_data * (max_values - min_values) + min_values


<span style = 'color:red;font-size:25px'>模型训练</span>

In [None]:
def train_lunwen(net, lr, epochs, input_len, pred_len, file_path, file_vpath, file_pathS,
                 weight_decay, max_values, min_values, batch_size, hidden_dimT):
    """Train ``net`` and report training / validation MAE and ADE after every epoch.

    Args:
        net: model called as ``net(x, static_x, device, F, S, batch_size, masks, distances)``
            and returning ``(prediction, mask)``.
        lr, weight_decay: Adam hyper-parameters. The learning rate is halved
            every 10 epochs.
        epochs: number of passes over the training file.
        input_len, pred_len: window lengths passed to the data pipeline.
        file_path, file_vpath: training and validation data locations.
        file_pathS: static vessel data location.
        max_values, min_values: per-column bounds used for (de)normalisation.
        batch_size: number of windows per step; a batch of any other size ends
            the epoch's loop.
        hidden_dimT: unused; kept for call compatibility.

    The loss is an L1 loss on normalised, masked coordinates. The reported
    metrics are computed on de-normalised coordinates.
    """

    def xavier_init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)

    # Initialise the linear layers.
    net.apply(xavier_init_weights)

    # Use the GPU when there is one, otherwise fall back to the CPU
    # (same selection logic as test_lunwen).
    if torch.cuda.is_available():
        device = torch.device('cuda')
        torch.cuda.init()
    else:
        device = torch.device('cpu')

    net.to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # halve lr every 10 epochs
    loss = nn.L1Loss()

    # Load and transform the static data once.
    dataS = extract_mmsi_features(file_pathS)
    dataS = transform_data(dataS)

    animator = Animator(xlabel='epoch', ylabel='loss', yscale='log', xlim=[0, epochs],
                        legend=['train', 'valid'])

    for epoch in range(epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)         # training L1 loss (normalised), valid count
        metric_mae = d2l.Accumulator(2)     # training MAE (de-normalised), valid count
        metric_ade = d2l.Accumulator(2)     # training displacement-error sum, valid points
        metric_mvalid = d2l.Accumulator(2)  # validation MAE, valid count
        metric_made = d2l.Accumulator(2)    # validation displacement-error sum, valid points
        timer.start()

        # ------------------ training ------------------
        # Re-enter training mode every epoch: validation below switches to eval mode.
        net.train()
        for slice_y in slice_data_generator(file_path, input_len, pred_len, batch_size):

            optimizer.zero_grad()
            if len(slice_y) != batch_size:
                break

            input_X, F, output_Y, S, _, input_x_all, Static_result_all = process_data(
                slice_y, input_len, pred_len, dataS
            )
            if not input_x_all or all(len(batch) == 0 for batch in input_x_all):
                break

            In_x, masks_x = pad_output(input_x_all)
            In_xr = In_x.reshape(-1, In_x.shape[3]).to(device)
            np.seterr(divide='ignore')
            # Inverse pairwise distance between all (batch, time, node) rows.
            distances = (1 / haversine_distances(In_xr)).to(device)
            Out_Y, masks_Y0 = pad_outputy(output_Y)
            if Out_Y.shape[0] == 0:
                break
            # Raw (un-normalised) target coordinates: first two feature columns.
            Out_Y1 = Out_Y[:, :, :, :2].to(device)
            masks_y = masks_Y0.to(device)
            in_x = normalize_datat(In_x, max_values, min_values)
            Out_Y = normalize_datat(Out_Y, max_values, min_values)
            sta_x, _ = pad_output(Static_result_all)
            Y_h, _ = net(in_x, sta_x, device, F, S, batch_size, masks_x, distances)
            masks_Y = masks_y.bool().unsqueeze(-1).expand_as(Out_Y1)
            # NOTE(review): assumes normalize_datat returns targets that broadcast against the
            # 2-feature mask, and that the network output uses the same axis order as the mask
            # (batch, time, nodes, 2). H_data in this notebook documents its output as
            # (batch, nodes, time, 2), so confirm the axes line up.
            masked_Y = Out_Y * masks_Y
            masked_Y1 = Out_Y1 * masks_Y
            # The loss is taken on the network output of this step.
            masked_Y_hat = Y_h * masks_Y
            l = loss(masked_Y, masked_Y_hat)

            with torch.no_grad():
                masked_Y_ht = denormalize_data(masked_Y_hat, max_values, min_values)
                # De-normalisation shifts padded zeros to min_values; mask them again.
                masked_Y_ht = masked_Y_ht * masks_Y
                ADE_ = calculate_ade(masked_Y_ht, masked_Y1, masks_Y)
                mae = torch.abs(masked_Y_ht - masked_Y1) * masks_Y
                valid_count = masks_Y.sum().item()
                metric.add(l.sum(), valid_count)
                metric_mae.add(mae.sum(), valid_count)
                # Two mask entries (the two coordinates) per valid point.
                metric_ade.add(ADE_.sum(), valid_count // 2)
            l.sum().backward()
            optimizer.step()
        scheduler.step()

        # ------------------ validation ------------------
        # Evaluate in eval mode so dropout / normalisation layers behave deterministically.
        net.eval()
        with torch.no_grad():
            for slice_vy in slice_data_generator(file_vpath, input_len, pred_len, batch_size):
                if len(slice_vy) != batch_size:
                    break

                input_vX, v_F, output_vY, v_S, _, input_vx_all, Static_vresult_all = process_data(
                    slice_vy, input_len, pred_len, dataS
                )
                if not input_vx_all or all(len(batch) == 0 for batch in input_vx_all):
                    break

                In_vx, masks_vx = pad_output(input_vx_all)
                In_vxr = In_vx.reshape(-1, In_vx.shape[3]).to(device)
                np.seterr(divide='ignore')
                distances_v = (1 / haversine_distances(In_vxr)).to(device)

                Out_vY, masks_vY0 = pad_outputy(output_vY)
                if Out_vY.shape[0] == 0:
                    break

                Out_vY1 = Out_vY[:, :, :, :2].to(device)

                masks_vy = masks_vY0.to(device)
                in_vx = normalize_datat(In_vx, max_values, min_values)
                sta_vx, _ = pad_output(Static_vresult_all)

                vY_h, _ = net(in_vx, sta_vx, device, v_F, v_S, batch_size, masks_vx, distances_v)
                vY_hat = denormalize_data(vY_h, max_values, min_values)

                masks_vY = masks_vy.bool().unsqueeze(-1).expand_as(Out_vY1)
                masked_vY = Out_vY1 * masks_vY
                masked_vY_hat = vY_hat * masks_vY
                vADE_ = calculate_ade(masked_vY_hat, masked_vY, masks_vY)
                mae_v = torch.abs(masked_vY_hat - masked_vY) * masks_vY
                valid_vcount = masks_vY.sum().item()
                metric_mvalid.add(mae_v.sum(), valid_vcount)
                metric_made.add(vADE_.sum(), valid_vcount // 2)

        # ------------------ reporting ------------------
        animator.add(epoch + 1, ((metric_mae[0]/metric_mae[1], metric_mvalid[0]/metric_mvalid[1])))
        print(f'epoch{epoch+1}train_loss:{metric_mae[0]/metric_mae[1]}')
        print(f'epoch{epoch+1}valid_loss:{metric_mvalid[0]/metric_mvalid[1]}')
        # These two lines report the displacement error; they previously reused the
        # train_loss / valid_loss labels of the MAE lines above.
        print(f'epoch{epoch+1}train_ade:{metric_ade[0]/metric_ade[1]}')
        print(f'epoch{epoch+1}valid_ade:{metric_made[0]/metric_made[1]}')
        print(f"Epoch {epoch + 1}, Time: {timer.stop():.1f} seconds")

In [None]:
# Model hyper-parameters.
input_len = 48  # number of input time steps
pred_len = 24   # number of predicted time steps
num_heads = 2  # attention heads of the graph network
embedding_dim1 = 2  # ship-type embedding width (passed to H_data, which does not use it)
embedding_dim2 = 4  # width of each static-attribute embedding
hidden_dimS = 32   # hidden width of the heterogeneous graph features
hidden_dimT = 128 # hidden width of the iTransformer branch
num_layersS = 2   # number of heterogeneous graph layers (not passed to H_data below)
num_layersT = 5   # number of iTransformer encoder blocks
input_dimD = 5  # number of dynamic input features
hidden_dimD = 12  # passed to H_data, which does not use it
input_dimS = output_dimS = hidden_sizel = embedding_dim2 * 4  # static feature width; also the first LSTM's hidden size
num_layersl = 2   # layers of the first LSTM
num_layersl2 = 4  # layers of the encoder / decoder LSTMs
hidden_sizel2 = 128  # hidden size of the encoder / decoder LSTMs
kernel_size = 3  # kernel size of the coordinate-attention module
ffn_num_hiddens = 256  # hidden width of the transformer feed-forward layers
num_headsT = 4  # attention heads of the iTransformer

In [None]:
# Build the model from the hyper-parameters defined above.
net = H_data(input_len,pred_len,input_dimD,hidden_dimD,embedding_dim1,embedding_dim2,hidden_dimS,num_heads,num_layersl,
                 hidden_sizel,kernel_size,hidden_dimT,num_layersT,ffn_num_hiddens,num_headsT,num_layersl2, hidden_sizel2)

In [None]:
# Dataset locations: replace each placeholder with a real path before running
# the training / test cells below.
file_path = "path/to/training_data"      # TODO: training dataset
file_vpath = "path/to/validation_data"   # TODO: validation dataset
file_tpath = "path/to/test_data"         # TODO: test dataset
file_pathS = "path/to/static_data"       # TODO: static vessel data

In [None]:
# Per-column bounds used for min-max (de)normalisation, and optimiser settings.
max_values = [57.1976249957228, 8.411638137361614, 29.770000000000003, 359.9, 180.0]    # per-column maxima
min_values = [54.959007898554255, 6.297554604493042, 0.5, 0.0, 0.0]                     # per-column minima
weight_decay = 0   # L2 regularisation strength
lr = 0.0001        # learning rate
epochs = 100       # number of training epochs
batch_size = 2     # windows per batch

In [None]:
# Run training; per-epoch metrics are printed and plotted live.
train_lunwen(net, lr, epochs, input_len, pred_len, file_path, file_vpath, file_pathS, weight_decay, max_values, min_values, batch_size, hidden_dimT)

<span style = 'color:red;font-size:25px'>模型测试</span>

In [None]:
def test_lunwen(net, input_len, pred_len, file_tpath, file_pathS,
                max_values, min_values, batch_size):
    """Evaluate a model on the test set and print accuracy and throughput figures.

    The model is switched to eval mode and run under ``torch.no_grad()`` over the
    batches yielded by ``slice_data_generator``. For every batch the forward pass
    is timed, the prediction is de-normalized, and masked MAE / MSE / ADE values
    are accumulated. At the end the function prints MAE, ADE, MSE, RMSE (square
    root of the aggregated MSE), the total inference time, the trajectory count,
    the average time per trajectory and the throughput in trajectories/second.

    Args:
        net: model to evaluate. It is called as
            ``net(in_tx, sta_tx, device, t_F, t_S, batch_size, masks_tx, distances_t)``
            and must return a 4-tuple whose first element is the prediction
            (which is then passed through ``denormalize_data``).
        input_len: input window length, forwarded to the data helpers.
        pred_len: prediction horizon length, forwarded to the data helpers and
            used to convert the valid-element count into a trajectory count.
        file_tpath: test dataset location, passed to ``slice_data_generator``.
        file_pathS: static-data location, passed to ``extract_mmsi_features``.
        max_values: per-feature maxima used for (de)normalization.
        min_values: per-feature minima used for (de)normalization.
        batch_size: expected number of slices per batch; evaluation stops at the
            first batch whose length differs.

    Returns:
        None. All results are printed.
    """
    # Local import so timing works even if the notebook never imported `time`.
    import time

    # Device selection and initialisation.
    if torch.cuda.is_available():
        device = torch.device('cuda')
        torch.cuda.init()
    else:
        device = torch.device('cpu')

    net.to(device)

    # Load and transform the static data.
    dataS = extract_mmsi_features(file_pathS)
    dataS = transform_data(dataS)

    net.eval()

    # Each accumulator holds (sum of metric, number of valid elements).
    metric_tmse = d2l.Accumulator(2)
    metric_tmae = d2l.Accumulator(2)
    metric_tade = d2l.Accumulator(2)

    total_pred_time = 0.0     # total inference time in seconds
    total_trajectories = 0.0  # total number of evaluated trajectories

    with torch.no_grad():
        for slice_ty in slice_data_generator(file_tpath, input_len, pred_len, batch_size):
            # An incomplete batch ends the evaluation.
            if len(slice_ty) != batch_size:
                break

            _, t_F, output_tY, t_S, _, input_tx_all, Static_tresult_all = process_data(
                slice_ty, input_len, pred_len, dataS
            )
            # NOTE(review): an empty batch stops the whole evaluation (`break`),
            # it is not merely skipped -- confirm this is intended.
            if not input_tx_all or all(len(batch) == 0 for batch in input_tx_all):
                break

            In_tx, masks_tx = pad_output(input_tx_all)
            In_txr = In_tx.reshape(-1, In_tx.shape[3])
            # Silence NumPy divide warnings before taking the reciprocal distances.
            np.seterr(divide='ignore')
            distances_t = (1 / haversine_distances(In_txr)).to(device)

            Out_tY, masks_tY0 = pad_outputy(output_tY)
            if Out_tY.shape[0] == 0:
                break

            # Only the first two target features are evaluated.
            Out_tY1 = Out_tY[:, :, :, :2].to(device)

            masks_ty = masks_tY0.to(device)
            in_tx = normalize_datat(In_tx, max_values, min_values)
            sta_tx, _ = pad_output(Static_tresult_all)

            # === Time the forward pass (synchronise the GPU for accurate timing) ===
            if device.type == "cuda":
                torch.cuda.synchronize()
            start_time = time.perf_counter()

            tY_h, _, _att_nt, _att_fn = net(in_tx, sta_tx, device, t_F, t_S, batch_size, masks_tx, distances_t)

            if device.type == "cuda":
                torch.cuda.synchronize()
            end_time = time.perf_counter()
            # =======================================================================

            tY_hat = denormalize_data(tY_h, max_values, min_values)
            # Broadcast the mask over the coordinate axis and zero out padded entries.
            masks_tY = masks_ty.bool().unsqueeze(-1).expand_as(Out_tY1)
            masked_tY = Out_tY1 * masks_tY
            masked_tY_hat = tY_hat * masks_tY

            tADE_ = calculate_ade(masked_tY_hat, masked_tY, masks_tY)
            mae_t = torch.abs(masked_tY_hat - masked_tY) * masks_tY
            # NOTE(review): calculate_mse adds a constant 0.00115 offset before
            # squaring, so masked-out (zeroed) positions still contribute a small
            # amount to mse_t -- confirm this is intended.
            mse_t = calculate_mse(masked_tY, masked_tY_hat)
            valid_tcount = masks_tY.sum().item()

            # ====== Inference time and trajectory bookkeeping ======
            total_pred_time += (end_time - start_time)
            # valid_tcount counts valid coordinate entries; dividing by the horizon
            # length and by the 2 coordinates gives trajectories (assumes every
            # valid trajectory spans pred_len steps -- TODO confirm).
            total_trajectories += valid_tcount / float(pred_len) / 2.0
            # =======================================================

            metric_tmse.add(mse_t.sum(), valid_tcount)
            metric_tmae.add(mae_t.sum(), valid_tcount)
            metric_tade.add(tADE_.sum(), valid_tcount / 2.0)

    # Guard the averages: with no valid elements the divisions below would raise
    # ZeroDivisionError before the "no valid batches" message could be printed.
    if metric_tmae[1] > 0:
        print(f'test_mae:  {metric_tmae[0]/metric_tmae[1]:.6f}')
        print(f'test_ade:  {metric_tade[0]/metric_tade[1]:.6f}m')
        print(f'test_mse:  {metric_tmse[0]/metric_tmse[1]:.6f}')
        print(f'test_rmse: {torch.sqrt(torch.tensor(metric_tmse[0]/metric_tmse[1])):.6f}')

    print(f"total_pred_time(s): {total_pred_time:.6f}")
    print(f"total_trajectories: {total_trajectories:.6f}")

    if total_trajectories > 0:
        # Average inference time per trajectory.
        avg_time_per_traj = total_pred_time / total_trajectories
        print(f'Average inference time per trajectory: {avg_time_per_traj:.6f} seconds')

        # Throughput in trajectories per second.
        traj_per_sec = total_trajectories / total_pred_time if total_pred_time > 0 else float("inf")
        print(f'Throughput (trajectories/sec): {traj_per_sec:.6f}')
    else:
        print("No valid trajectory batches processed.")

In [None]:
# Evaluate the trained model on the test set.
test_lunwen(
    net=net,
    input_len=input_len,
    pred_len=pred_len,
    file_tpath=file_tpath,
    file_pathS=file_pathS,
    max_values=max_values,
    min_values=min_values,
    batch_size=batch_size,
)