In [None]:
import csv
import math
import torch
from d2l import torch as d2l
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.data as pyg_data
from torch_geometric.data import Data
from torch_geometric.data import HeteroData, Batch
from torch_geometric.transforms import ToUndirected
from torch_geometric.nn import GATConv, global_mean_pool, global_max_pool, global_add_pool
from torch_geometric.nn import to_hetero
import torch_geometric.transforms as T
from torch.masked import MaskedTensor
from functools import partial
from IPython import display
import torch.nn.utils.parametrizations as parametrizations  
from torch.nn.utils import weight_norm

<span style = 'color:red; font-size:25px'>MSE / RMSE</span>

In [None]:
def calculate_mse(actual_values, predicted_values):
    """
    Accumulate the squared error between paired actual and predicted values.

    Despite the name, the result is the *sum* of squared errors: no division
    by the number of pairs happens here. Pairs are formed with ``zip``, so
    iteration stops at the shorter of the two inputs.

    param actual_values: iterable of ground-truth values
    param predicted_values: iterable of predicted values
    return: sum of (actual - predicted) ** 2 over all pairs (0 for empty input)
    """
    total = sum(
        (truth - estimate) ** 2
        for truth, estimate in zip(actual_values, predicted_values)
    )
    return total

<span style = 'color:red; font-size:25px'>ADE</span>

In [None]:
def calculate_ade(predictions, ground_truth, mask):
    """
    Sum the Euclidean displacement error over all valid (unmasked) positions.

    The returned value is the *total* error, not the average; no division by
    the number of valid positions is performed here.

    param predictions: Predicted tensor with shape (batch_size, pred_time, num_nodes, feature_dim)
    param ground_truth: Ground-truth tensor with shape (batch_size, pred_time, num_nodes, feature_dim)
    param mask: Mask tensor with shape (batch_size, pred_time, num_nodes, feature_dim);
                a position counts as valid when any entry along the last dimension is non-zero
    return: 0-dim tensor holding the summed displacement error of the valid positions
    """
    # L2 norm (Euclidean distance) of the prediction error per time step and node.
    displacement_error = torch.sqrt(torch.sum((predictions - ground_truth) ** 2, dim=-1))  # (batch_size, pred_time, num_nodes)

    # Collapse the feature axis of the mask so it lines up with displacement_error.
    valid = mask.any(dim=-1).bool()  # (batch_size, pred_time, num_nodes)

    # Select instead of multiplying by the mask: NaN/inf in padded positions
    # would survive a multiplication by zero and poison the sum.
    masked_error = torch.where(valid, displacement_error, torch.zeros_like(displacement_error))

    return masked_error.sum()

<span style = 'color:red; font-size:18px'>Data extraction</span>

In [None]:
def preprocess_data(file_path, start_row, chunk_size):
    """
    Load one fixed-size block of rows from a header-less CSV file.

    Missing cells stay missing (NaN). A block shorter than ``chunk_size`` is
    discarded and an empty DataFrame is returned instead, as is the case when
    there is nothing left to parse.

    param file_path: path of the CSV file
    param start_row: number of leading rows to skip before reading
    param chunk_size: number of rows per chunk
    return: DataFrame with exactly ``chunk_size`` rows, or an empty DataFrame
    """
    try:
        chunk = pd.read_csv(
            file_path,
            header=None,
            skiprows=start_row,
            nrows=chunk_size,
            low_memory=False,
        )
    except pd.errors.EmptyDataError:
        print("No more data to read. Exiting.")
        return pd.DataFrame()

    if len(chunk) >= chunk_size:
        return chunk.fillna(np.nan)

    # A short read means the tail of the file cannot fill a whole chunk.
    print(f"Data read is smaller than the expected chunk size ({chunk_size}).")
    return pd.DataFrame()

In [None]:
def slice_data_generator(file_path, input_len, pred_len, batch_size):
    """
    Stream batches of sliding windows read from a CSV file.

    Each iteration reads ``input_len + pred_len + batch_size - 1`` rows (enough
    for ``batch_size`` stride-1 windows), drops the first column, and yields the
    list of windows. Reading starts at row index 2 and the read position then
    advances by half a chunk, so consecutive chunks overlap. Iteration ends as
    soon as a full chunk can no longer be read.

    param file_path: path of the CSV file
    param input_len: number of input (history) rows per window
    param pred_len: number of prediction rows per window
    param batch_size: number of windows per yielded batch
    return: generator yielding lists of arrays, each of ``input_len + pred_len`` rows
    """
    window_len = input_len + pred_len
    rows_per_chunk = window_len + batch_size - 1  # rows needed for one batch of windows
    stride = rows_per_chunk // 2
    next_row = 2  # reading starts at the third row (index 2)

    while True:
        frame = preprocess_data(file_path, start_row=next_row, chunk_size=rows_per_chunk)
        if frame.empty:
            return

        # Column 0 is dropped; the remaining columns are kept as values.
        values = frame.iloc[:, 1:].values

        # Every stride-1 window that fits in the chunk.
        windows = [
            values[offset: offset + window_len]
            for offset in range(len(values) - window_len + 1)
        ]
        if not windows:
            print("No valid data windows extracted. Exiting.")
            return

        yield windows

        next_row += stride

<span style = 'color:red; font-size:18px'>Min–max normalization</span>

In [None]:
def normalize_datat(out_Y, max_values, min_values):
    """
    Apply min–max normalization along the last (feature) dimension of a tensor.

    param out_Y: tensor whose last dimension is the feature dimension,
                 e.g. (batch_size, time_steps, num_nodes, features)
    param max_values: per-feature maxima, shape (features,); list, array or tensor
    param min_values: per-feature minima, shape (features,); list, array or tensor
    return:
        - normalized_data: (out_Y - min) / (max - min), same shape as out_Y
    raises ValueError: if either bound is not 1-D or its length differs from
                       the feature dimension of out_Y
    """
    # as_tensor accepts lists, arrays and tensors alike; unlike torch.tensor it
    # neither warns nor copies when the argument is already a tensor.
    max_values = torch.as_tensor(max_values, device=out_Y.device)
    min_values = torch.as_tensor(min_values, device=out_Y.device)

    # Check dimensions
    if max_values.dim() != 1 or min_values.dim() != 1:
        raise ValueError("max_values and min_values should be 1D tensors, giving the per-feature maxima and minima")

    # Both bounds are validated; a length-1 min_values would otherwise
    # broadcast silently across all features.
    num_features = out_Y.size(-1)
    if max_values.size(0) != num_features or min_values.size(0) != num_features:
        raise ValueError("max_values and min_values must match the feature dimension")

    # NOTE(review): a feature with max == min divides by zero here (inf/NaN);
    # confirm upstream that every feature has a non-degenerate range.
    normalized_data = (out_Y - min_values) / (max_values - min_values)

    return normalized_data

<span style = 'color:red; font-size:18px'>Inverse normalization</span>

In [None]:
def denormalize_data(normalized_data, max_values, min_values):
    """
    Apply inverse min–max normalization along the last dimension.

    param normalized_data: normalized data whose last dimension is the feature
                           dimension, e.g. (batch_size, num_points, num_features)
    param max_values: per-feature maximum values, either 1-D (num_features,) or
                      with two extra trailing axes, e.g. (num_features, 1, 1)
    param min_values: per-feature minimum values, same layout as max_values
    return: numpy array, inverse-normalized data with the same shape as the input
    raises AssertionError: if the feature dimension does not match the bounds
    """
    normalized_data = np.array(normalized_data)
    max_values = np.array(max_values)
    min_values = np.array(min_values)

    # Bounds that carry two extra trailing axes are reduced to one value per
    # feature; bounds that are already 1-D are used as they are.
    if max_values.ndim != 1:
        max_values = max_values[..., 0, 0]
    if min_values.ndim != 1:
        min_values = min_values[..., 0, 0]

    assert normalized_data.shape[-1] == max_values.shape[0], (
        "the last dimension of normalized_data must match the number of per-feature max/min values"
    )

    denormalized_data = normalized_data * (max_values - min_values) + min_values

    return denormalized_data

<span style = 'color:red;font-size:25px'>Compute the haversine distance</span>

In [None]:
def haversine_distances(points, radius=6371):
    """
    Compute the haversine distance between all node pairs, and set the main diagonal to 1.

    Parameters:
    points: tensor of shape (N, 2), each row is a point's lat/lon [latitude, longitude] in degrees
    radius: Earth radius in kilometers, default 6371

    Return:
    distances: tensor of shape (N, N), pairwise great-circle distances in nautical miles
               (kilometers / 1.852); the diagonal is overwritten with 1.0 instead of 0
    """
    points_rad = points * torch.pi / 180.0  # Shape: (N, 2)
    latitudes = points_rad[:, 0].unsqueeze(1)  # Shape: (N, 1)
    longitudes = points_rad[:, 1].unsqueeze(1)  # Shape: (N, 1)

    # (N, 1) - (1, N) broadcasts to the full pairwise difference matrix.
    dlat = latitudes - latitudes.T
    dlon = longitudes - longitudes.T

    a = (torch.sin(dlat / 2) ** 2 +
         torch.cos(latitudes) * torch.cos(latitudes.T) * torch.sin(dlon / 2) ** 2)
    # Rounding can push `a` marginally outside [0, 1] (e.g. near-antipodal
    # points), which would turn sqrt/arcsin into NaN.
    a = a.clamp(min=0.0, max=1.0)
    c = 2 * torch.arcsin(torch.sqrt(a))

    distances_km = radius * c
    distances_nmi = distances_km / 1.852
    distances_nmi.fill_diagonal_(1.0)

    return distances_nmi


In [None]:
def process_data(F_data1, input_len, pred_len, statics_features):
    """
    Split a batch of raw windows into per-sample inputs, targets and static features.

    Each window is reshaped to (input_len + pred_len, num_nodes, 5), where
    num_nodes is the window's column count divided by 5. NaN marks a missing
    observation.

    param F_data1: sequence of windows, each convertible to a 2-D array of
                   shape (input_len + pred_len, num_nodes * 5)
    param input_len: number of leading rows of a window used as input
    param pred_len: number of trailing rows of a window used as target
    param statics_features: mapping from str(node_index) to that node's static
                            feature values (must be convertible to float)
    return: tuple of seven lists, each with one entry per window:
        - input_X_all: list (per time step that has at least one valid node) of
          arrays holding the features of the nodes valid at that step
        - F_p_all: indices of nodes that contain no NaN anywhere in the window
        - output_Y_all: target rows restricted to those nodes (empty array if none)
        - S_all: per time step, the indices of the nodes valid at that step
        - static_result_all: per time step, the static features of those nodes
        - input_x_all: input rows restricted to the fully valid nodes (empty array if none)
        - Static_result_all: static features of the fully valid nodes repeated
          along a new leading axis of length input_len (empty array if none)
    """
    total_len = input_len + pred_len  # Total window length
    batch_size = len(F_data1)
    sample_shape = np.array(F_data1[0]).shape
    num_nodes = sample_shape[1] // 5  
    # Static features stacked in node order; row i belongs to node i.
    statics_list = np.array([statics_features[str(i)] for i in range(num_nodes)], dtype=float)
    
    # Initialize outputs
    F_p_all = []
    input_X_all = []
    output_Y_all = []
    S_all = []
    static_result_all = []
    input_x_all = []
    Static_result_all = []
    for i in range(batch_size):     # Extract data for each batch
        F_data = np.array(F_data1[i])  # [T, N * 5]
        F_data_reshaped = F_data.reshape(total_len, num_nodes, 5)

        # Check which nodes are fully valid (no NaNs) over the entire window
        valid_node_mask = ~np.isnan(F_data_reshaped).any(axis=(0, 2))
        f_p = np.where(valid_node_mask)[0].tolist()
        F_p_all.append(f_p)  # Valid nodes in the current batch
 
        # Split inputs and outputs
        in_x_reshaped = F_data_reshaped[:input_len]
        out_y_reshaped = F_data_reshaped[input_len:] 
        # A (time step, node) pair is valid when none of its 5 features is NaN.
        valid_in_mask = ~np.isnan(in_x_reshaped).any(axis=2)
        all_valid_t, all_valid_nodes = np.where(valid_in_mask)  # Find time-step and node indices for valid nodes
        non_empty_features_all = in_x_reshaped[all_valid_t, all_valid_nodes, :]   # Features of every valid (time step, node) pair
        sta_result_all_nodes = statics_list[all_valid_nodes]   

        # np.where returns the pairs sorted by time step, so equal time steps are contiguous.
        unique_t, counts = np.unique(all_valid_t, return_counts=True)  # Group data by time step
        split_indices = np.split(np.arange(len(all_valid_t)), np.cumsum(counts[:-1]))  # Split by the number of valid nodes at each time step
        # NOTE(review): a time step without any valid node does not appear in
        # unique_t, so the per-time-step lists below can be shorter than
        # input_len -- confirm callers tolerate that.

        # Restore the per-time-step data list
        input_X = [non_empty_features_all[idx] for idx in split_indices]
        S = [all_valid_nodes[idx] for idx in split_indices]
        static_result = [sta_result_all_nodes[idx] for idx in split_indices]

        input_X_all.append(input_X)
        S_all.append(S)
        static_result_all.append(static_result)

        # Build the prediction output output_Y
        if not f_p: 
            # No node is valid over the whole window: emit empty placeholders.
            output_y = np.array([])
            input_x0 = np.array([])
            statics_list0 = np.array([])
        else:
            output_y = out_y_reshaped[:, f_p, :]  
            input_x0 = in_x_reshaped[:, f_p, :]
            Statics_list0 = statics_list[f_p, :]
            # Repeat the static features once per input time step.
            statics_list0 = np.repeat(Statics_list0[np.newaxis, :, :], input_x0.shape[0], axis=0)
        output_Y_all.append(output_y)
        input_x_all.append(input_x0)
        Static_result_all.append(statics_list0)
        
    return input_X_all, F_p_all, output_Y_all, S_all, static_result_all, input_x_all, Static_result_all

<span style = 'color:red;font-size:18px'>Extract static information (returned as a dictionary)</span>

In [None]:
def extract_mmsi_features(file_path):
    """
    Group the values of the third CSV row by the MMSI found in the first row.

    The file is read without a header. Columns whose first-row entry is
    missing are skipped. MMSI keys are then replaced by their order of first
    appearance, rendered as strings.

    :param file_path: file path
    :return: a dictionary with sequential string indices ('0', '1', ...) as keys
             and the list of third-row values of each MMSI as values
    """
    table = pd.read_csv(file_path, header=None)
    mmsi_row = table.iloc[0]
    feature_row = table.iloc[2]

    # Collect feature values per MMSI, preserving first-appearance order.
    grouped = {}
    for col_idx, mmsi in enumerate(mmsi_row):
        if pd.isna(mmsi):
            continue
        grouped.setdefault(mmsi, []).append(feature_row[col_idx])

    # Swap each MMSI key for its positional index.
    return {str(position): values for position, values in enumerate(grouped.values())}

<span style = 'color:red;font-size:18px'>Convert one-hot encodings to indices</span>

In [None]:
def one_hot_to_index(one_hot_str):
    """
    Return the position of the first 1 in a comma-separated one-hot string.

    Every field is converted with ``int`` first; a string without any 1
    raises ValueError.
    """
    flags = [int(token) for token in one_hot_str.split(',')]
    return flags.index(1)

def transform_data(data_dict):
    """
    Build a new dictionary in which the leading one-hot string of every value
    list is replaced by its index; the remaining entries are kept as they are.
    """
    return {
        key: [one_hot_to_index(features[0])] + features[1:]
        for key, features in data_dict.items()
    }

# <span style = 'color:red;font-size:18px'>Extract static feature data</span>

In [None]:
def extract_sfeatures(S, features_dict):
    """
    Look up the features of every node ID in S and cast them to float.

    Node IDs are converted with ``str`` before the lookup; an ID that is not in
    the dictionary yields an empty feature list.

    :param S: a list where each element is an iterable of node IDs
    :param features_dict: a dict mapping str(node ID) to a feature list
    :return: for each element of S, a list with one float feature list per node
    """
    def _as_floats(node_id):
        return [float(value) for value in features_dict.get(str(node_id), [])]

    return [[_as_floats(node_id) for node_id in node_ids] for node_ids in S]

<span style = 'color:red;font-size:18px'>Convert to one-hot encoding</span>

In [None]:
def convert_to_onehot(data, device, onehot_length = 20):
    """
    One-hot encode the first feature column of a 2D tensor.

    Only the one-hot block is returned; the remaining feature columns of
    ``data`` are not appended.

    :param data: input 2D tensor of shape [N, F]; column 0 holds the class index
    :param device: device on which the one-hot tensor is allocated
    :param onehot_length: length of the one-hot vector
    :return: float32 tensor of shape [N, onehot_length]
    """
    num_rows = data.size(0)
    class_ids = data[:, 0].long()[:, None]  # (N, 1) column of target positions

    encoded = torch.zeros(num_rows, onehot_length, dtype=torch.float32, device = device)
    encoded.scatter_(1, class_ids, 1)
    return encoded

<span style = 'color:red;font-size:25px'>Heterogeneous Maritime Graph (HMG)</span>

In [None]:
class Spatial_GATD(nn.Module):
    """
    Two stacked GATConv layers whose output width equals the input width.

    The first layer uses ``gat_heads`` concatenated heads of width
    ``hidden_dim``; the second uses a single head that maps back to
    ``input_dim``. Neither layer adds self-loops.
    """
    def __init__(self, input_dim, hidden_dim, gat_heads):
        super().__init__()
        # Multi-head layer: output width is hidden_dim * gat_heads.
        self.gat1 = GATConv(input_dim, hidden_dim, heads=gat_heads, concat=True, add_self_loops=False)
        # Single-head layer projecting back to the input width.
        self.gat2 = GATConv(hidden_dim * gat_heads, input_dim, heads=1, concat=True, add_self_loops=False)

    def forward(self, x, edge_index, edge_attr):
        # NOTE(review): both layers are built without `edge_dim`; confirm that
        # the installed torch_geometric actually consumes edge_attr in that case.
        hidden = F.elu(self.gat1(x, edge_index, edge_attr = edge_attr))
        return self.gat2(hidden, edge_index, edge_attr = edge_attr)

<span style = 'color:red;font-size:18px'>Heterogeneous-graph computation</span>

In [None]:
class H_Model(torch.nn.Module):
    """
    Heterogeneous wrapper around Spatial_GATD for the node types 'DYA' and 'STA'.

    The homogeneous GAT block is converted with ``to_hetero`` for the edge
    types DD, DS, rev_DS and SS.
    """
    def __init__(self, hidden_dimS, num_heads, embedding_dim2):
        super().__init__()
        node_types = ['DYA', 'STA']
        edge_types = [
            ('DYA', 'DD', 'DYA'),      # DYA → DYA
            ('DYA', 'DS', 'STA'),      # DYA → STA
            ('STA', 'rev_DS', 'DYA'),  # STA → DYA
            ('STA', 'SS', 'STA'),      # STA → STA
        ]

        homogeneous_gat = Spatial_GATD(hidden_dimS, hidden_dimS, num_heads)
        self.gatD = to_hetero(homogeneous_gat, metadata=(node_types, edge_types))

    def forward(self, xd_dict, data_D):
        """
        Run the heterogeneous GAT on the node-feature dict ``xd_dict`` using the
        edge indices and edge attributes stored in ``data_D``.

        return: tuple (features of 'DYA' nodes, features of 'STA' nodes)
        """
        node_out = self.gatD(xd_dict, data_D.edge_index_dict, data_D.edge_attr_dict)
        return node_out['DYA'], node_out['STA']

<span style = 'color:red;font-size:20px'>Build the heterogeneous-graph dataset</span>

In [None]:
def HeteroGraphBuilder_batch(dynamic_tensor, static_tensor, masks, device, distances):
    """
    Build heterogeneous-graph data from inputs of shape (batch_size, time_steps, num_nodes, feature_dim).
    Every (batch, time step) pair becomes one subgraph of num_nodes nodes; all
    subgraphs are packed into a single HeteroData object with global node ids.
    Dynamic and static nodes have the same count but different feature dimensions.
    Also generate the corresponding batch vector.

    Node features are stored for all nodes, padded ones included; padding only
    removes edges and entries of the returned batch vector.

    param dynamic_tensor: dynamic-node feature tensor (batch_size, time_steps, num_nodes, dynamic_feat_dim)
    param static_tensor: static-node feature tensor (batch_size, time_steps, num_nodes, static_feat_dim)
    param masks: mask tensor (batch_size, time_steps, num_nodes), where 0 indicates nodes to be removed
    param device: device
    param distances: 2-D lookup indexed as distances[src, dst] with the global
                     node ids of the kept edges; the values become edge attributes
                     (presumably a pairwise distance matrix covering all
                     batch_size * time_steps * num_nodes ids -- TODO confirm)
    return: dataD (HeteroData), batch (Tensor of subgraph ids of the unmasked nodes)
    """    
    B, T, N, F_d = dynamic_tensor.shape
    _, _, _, F_s = static_tensor.shape
    
    dynamic_tensor = dynamic_tensor.to(device)
    static_tensor = static_tensor.to(device)
    masks = masks.to(device)
    
    total_subgraphs = B * T  # Total number of subgraphs
    total_nodes = B * T * N

    # Flatten global features for dynamic and static nodes
    X_D_global = dynamic_tensor.reshape(total_subgraphs * N, F_d)     # shape(batch_size * time_steps * num_nodes, feature_dim)
    X_S_global = static_tensor.reshape(total_subgraphs * N, F_s)      # shape(batch_size * time_steps * num_nodes, feature_dim)

    # Generate the batch vector: subgraph id of every global node
    batch_vector = torch.repeat_interleave(torch.arange(total_subgraphs, device=device), N)

    # Node indices for a single subgraph
    node_ids_local = torch.arange(N, dtype=torch.long, device=device)

    # Edge indices for a single subgraph (same logic for dynamic and static nodes), single edges
    if N > 1:
        edges_local = torch.combinations(node_ids_local.cpu(), r=2, with_replacement=False).T.to(device)   # Fully connect nodes (each unordered pair once)
    else:
        edges_local = torch.empty((2, 0), dtype=torch.long, device=device)
    # Append one self-loop per node to the same-type edge set.
    self_loops = torch.stack([node_ids_local, node_ids_local], dim=0)
    edges_local = torch.cat([edges_local, self_loops], dim=1)  # (2, num_edges)
    
    # Compute the global offset: first global node id of every subgraph
    offsets = torch.arange(total_subgraphs, device=device) * N   
    
    # Generate all possible edges (including self-loops)
    cross_edges_local = torch.cartesian_prod(node_ids_local, node_ids_local).T.to(device)  # (2, N*N)
    
    # Remove self-loop edges
    mask = cross_edges_local[0] != cross_edges_local[1]
    cross_edges_local = cross_edges_local[:, mask]  # (2, N*(N-1))
    offsets_flattened = offsets.flatten()
    
    # Expand local edge indices to all subgraphs
    cross_edges_expanded = (
        cross_edges_local.unsqueeze(1) + offsets_flattened.unsqueeze(0).unsqueeze(-1)
    ).reshape(2, -1) 
    edges_expanded = (edges_local.unsqueeze(1) + offsets.unsqueeze(0).unsqueeze(-1)).reshape(2, -1)   # Same local edge pattern shifted into every subgraph's global id range
    
    # Construct the HeteroData object
    dataD = HeteroData()
    
    dataD['DYA'].x = X_D_global
    dataD['DYA', 'DD', 'DYA'].edge_index = edges_expanded
    dataD['STA'].x = X_S_global
    dataD['DYA', 'DS', 'STA'].edge_index = cross_edges_expanded
    edgeD_DD = dataD['DYA', 'DD', 'DYA'].edge_index 
    edgeD_DS = dataD['DYA', 'DS', 'STA'].edge_index  

    # Locate padded nodes (values equal to 0); strictly keep the original node IDs and do not change them.
    # The flattened mask position is used directly as the global node id.
    flattened_tensor = masks.flatten()
    # NOTE(review): .squeeze() yields a 0-dim tensor when exactly one node is padded.
    zero_indices = (flattened_tensor == 0).nonzero(as_tuple=False).squeeze()
    Mask = torch.ones(batch_vector.size(0), dtype=torch.bool, device=device)
    Mask[zero_indices] = False  
    batch = batch_vector[Mask]

    # Filter edges: keep an edge only if neither endpoint is a padded node
    maskD_DD = ~(
        torch.isin(edgeD_DD[0], zero_indices) |
        torch.isin(edgeD_DD[1], zero_indices)
    )
    maskD_DS = ~(
        torch.isin(edgeD_DS[0], zero_indices) |
        torch.isin(edgeD_DS[1], zero_indices)
    )
   
    # Drop the edges that touch padded nodes (node ids are not renumbered)
    edgeD_DD = edgeD_DD[:, maskD_DD]
    edgeD_DS = edgeD_DS[:, maskD_DS]
  
    # Store the filtered edges; the STA-STA relation reuses the DYA-DYA edge set
    dataD['DYA', 'DD', 'DYA'].edge_index = edgeD_DD
    dataD['DYA', 'DS', 'STA'].edge_index = edgeD_DS
    dataD['STA', 'SS', 'STA'].edge_index = edgeD_DD

    # Assign edge attributes looked up by (source id, target id)
    adj_DD = distances[edgeD_DD[0], edgeD_DD[1]]
    adj_DS = distances[edgeD_DS[0], edgeD_DS[1]]
    dataD['DYA', 'DD', 'DYA'].edge_attr = adj_DD
    dataD['DYA', 'DS', 'STA'].edge_attr = adj_DS
    dataD['STA', 'SS', 'STA'].edge_attr = adj_DD
    
    # Convert to an undirected graph
    dataD = ToUndirected()(dataD)

    return dataD, batch

<span style = 'color:red;font-size:18px'>Get data suitable for iTransformer input</span>

In [None]:
def process_tensor_and_extract_features(tensor, batch_size, input_len, masks, device):
    """
    Gather the valid nodes of every sample and lay them out node-major.

    A node is kept for a sample when its mask is non-zero at one or more time
    steps. The kept nodes of all samples are concatenated along the node axis.

    :param tensor: input tensor of shape (batch_size, time_steps, total_nodes, feature_dim)
    :param batch_size: number of leading samples of ``tensor`` to process
    :param input_len: not used by this function
    :param masks: tensor of shape (batch_size, time_steps, total_nodes)
    :param device: device the result is moved to
    :return: tensor of shape (num_valid_nodes_total, feature_dim, 1, time_steps)
    """
    # 1) Per sample, keep only the nodes that are valid at some time step.
    per_sample = []
    for sample_idx in range(batch_size):
        keep = torch.where(masks[sample_idx].any(dim=0))[0]
        per_sample.append(tensor[sample_idx, :, keep, :])  # (T, num_valid_nodes, F)

    # 2) Join along the node axis, then reorder (T, nodes, F) -> (nodes, F, T).
    stacked = torch.cat(per_sample, dim=1)
    node_major = stacked.permute(1, 2, 0).to(device)

    # Insert a singleton axis between the feature and time axes.
    return node_major.unsqueeze(2)

<span style = 'color:red;font-size:25px'>LSTM</span>

In [None]:
class LSTMNetwork(nn.Module):
    """
    Thin batch-first wrapper around ``nn.LSTM``.

    Dropout between layers is only enabled when there is more than one layer.
    """
    def __init__(self, input_size, hidden_size, num_layers, bidirectional=False, dropout=0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        # A single-layer LSTM has no inter-layer connection to apply dropout to.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=inter_layer_dropout,
        )

    def forward(self, x):
        """Return the per-step outputs plus the final hidden and cell states."""
        output, final_state = self.lstm(x)
        h_n, c_n = final_state
        return output, h_n, c_n

<span style = 'color:Red;font-size:25px'>Dual-Axis Attention (DAA)</span>

<span style = 'color:blue;font-size:20px'>Average pooling</span>

In [None]:
def masked_average_pooling(x, mask, pool_dim, device):
    """
    Average x over one axis, ignoring masked-out positions.

    x: input tensor of shape (batch_size, feature_dim, num_nodes, time_steps)
    mask: mask tensor of shape (batch_size, time_steps, num_nodes); 0 marks padding
    pool_dim: axis of x to pool over, 1 -> feature_dim, 2 -> num_nodes, 3 -> time_steps
    device: device on which the computation runs
    return: average-pooled tensor, i.e. x with axis ``pool_dim`` removed; positions
            without any valid element come out as 0
    raises ValueError: if pool_dim is not 1, 2 or 3
    """
    if pool_dim not in (1, 2, 3):
        raise ValueError(f"pool_dim must be 1, 2 or 3, got {pool_dim!r}")

    x = x.to(device)
    mask = mask.to(device)

    # (B, T, N) -> (B, 1, N, T) so the mask broadcasts over the feature axis of x.
    mask = mask.unsqueeze(1).permute(0, 1, 3, 2).type_as(x)
    x_masked = x * mask

    # Number of valid elements along the pooled axis.
    if pool_dim == 1:
        # The mask has a singleton feature axis; expand it so the count equals
        # feature_dim at valid positions and 0 at padded ones.
        valid_count = mask.expand(-1, x.size(1), -1, -1).sum(dim=1)
    else:
        valid_count = mask.sum(dim=pool_dim)

    # Clamp so fully masked positions give 0 / 1e-6 = 0 rather than 0 / 0.
    return x_masked.sum(dim=pool_dim) / valid_count.clamp(min=1e-6)

<span style = 'color:blue;font-size:20px'>Max pooling</span>

In [None]:
def masked_max_pooling(x, mask, pool_dim, device):
    """
    Take the maximum of x over one axis, ignoring masked-out positions.

    x: input tensor of shape (batch_size, feature_dim, num_nodes, time_steps)
    mask: mask tensor of shape (batch_size, time_steps, num_nodes); 0 marks padding
    pool_dim: axis of x to pool over, 1 -> feature_dim, 2 -> num_nodes, 3 -> time_steps
    device: device on which the computation runs
    return: max-pooled tensor, i.e. x with axis ``pool_dim`` removed; positions
            without any valid element come out as 0
    raises ValueError: if pool_dim is not 1, 2 or 3
    """
    if pool_dim not in (1, 2, 3):
        raise ValueError(f"pool_dim must be 1, 2 or 3, got {pool_dim!r}")

    # (B, T, N) -> (B, 1, N, T) so the mask broadcasts over the feature axis of x.
    mask = mask.unsqueeze(1).permute(0, 1, 3, 2).to(device)
    x = x.to(device)

    # Padded positions can never win the max.
    x_masked = x.masked_fill(mask == 0, float('-inf'))
    x_max = x_masked.max(dim=pool_dim).values

    # Where every candidate was padding the max is -inf; report 0 instead.
    return x_max.masked_fill(mask.sum(dim=pool_dim) == 0, 0)

In [None]:
class AttentionModule(nn.Module):
    """
    Two-branch attention over a (batch, feature, node, time) tensor.

    Each branch pools the input along one axis with masked average and masked
    max pooling, stacks the two pooled maps as channels, and turns them into an
    attention map with Conv2d -> BatchNorm2d -> ReLU -> Sigmoid. The input is
    re-weighted by each map and the two results are added.

    kernel_size: kernel size of both convolutions (padding is (kernel_size - 1) // 2)
    pool_dim1: axis of x pooled for the first (nodes-time) branch
    pool_dim2: axis of x pooled for the second (feature-nodes) branch
    input_channels: number of output channels of the second branch's convolution
    """
    def __init__(self, kernel_size, pool_dim1, pool_dim2, input_channels = 1):
        super(AttentionModule, self).__init__()
        self.pool_dim1 = pool_dim1
        self.pool_dim2 = pool_dim2

        # Nodes–time interaction
        self.nt_conv = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=kernel_size, padding = (kernel_size - 1) // 2),
            nn.BatchNorm2d(1),
            nn.ReLU(),
            nn.Sigmoid()
        )
        
        # Feature–nodes interaction
        self.fn_conv = nn.Sequential(
            nn.Conv2d(2 * input_channels, input_channels, kernel_size=kernel_size, padding = (kernel_size - 1) // 2),
            # The norm must match the conv's output channels; a fixed width of 1
            # fails for any input_channels other than 1.
            nn.BatchNorm2d(input_channels),
            nn.ReLU(),
            nn.Sigmoid()
        )
        
    def forward(self, x, mask, device):
        """
        x: input tensor of shape (batch_size, feature_dim, num_nodes, time_steps)
        mask: mask tensor of shape (batch_size, time_steps, num_nodes)
        device: device passed on to the pooling helpers
        return: attention-weighted tensor with the same shape as x

        The shape comments below assume pool_dim1 pools the feature axis and
        pool_dim2 pools the time axis -- TODO confirm against the caller.
        """
        # nodes - time_steps
        avg_nt = masked_average_pooling(x, mask, self.pool_dim1, device)
        max_nt = masked_max_pooling(x, mask, self.pool_dim1, device)
        avg_nt = avg_nt.unsqueeze(1)
        max_nt = max_nt.unsqueeze(1)
        am_nt = torch.cat((avg_nt, max_nt), dim = 1)  # avg and max stacked as 2 channels
        att_nt = self.nt_conv(am_nt) # shape(B, 1, N, T)
        x_ant = att_nt * x
        
        # feature - nodes
        avg_fn = masked_average_pooling(x, mask, self.pool_dim2, device)
        max_fn = masked_max_pooling(x, mask, self.pool_dim2, device)
        avg_fn = avg_fn.unsqueeze(1)
        max_fn= max_fn.unsqueeze(1)
        am_fn = torch.cat((avg_fn, max_fn), dim = 1)
        att_fn = self.fn_conv(am_fn) # shape(B, 1, F, N)
        
        # Move time to the channel position so the (F, N) map broadcasts over it,
        # then restore the (B, F, N, T) layout.
        data_pfn = x.permute(0, 3, 1, 2)
        weighted_dfn = att_fn * data_pfn
        x_afn = weighted_dfn.permute(0, 2, 3, 1)
        
        return x_ant + x_afn  # shape(B, F, N, T)

<span style = 'color:red;font-size:25px'>iTransformer</span>

In [None]:
class GeLU(nn.Module):
    "GELU activation, tanh approximation."
    def forward(self, input_tensor):
        cubic_term = 0.044715 * torch.pow(input_tensor, 3)
        inner = math.sqrt(2 / math.pi) * (input_tensor + cubic_term)
        return 0.5 * input_tensor * (1 + torch.tanh(inner))

In [None]:
class AddNorm(nn.Module):
    "Residual connection followed by layer normalization."
    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        # Dropout is applied to the sub-layer output Y only, not to the skip path X.
        residual_sum = self.dropout(Y) + X
        return self.ln(residual_sum)

In [None]:
class PositionWiseFFN(nn.Module):
    "Two-layer feed-forward network applied along the last dimension, with a GeLU in between."
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super().__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.gelu = GeLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        hidden = self.dense1(X)
        activated = self.gelu(hidden)
        return self.dense2(activated)

In [None]:
class DotProductAttention(nn.Module):
    "Scaled dot-product attention over batched 3-D queries, keys and values."

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        # Scale by sqrt of the query width before the softmax.
        scale = math.sqrt(queries.shape[-1])
        scores = torch.bmm(queries, keys.transpose(1,2)) / scale
        # The weights are kept on the module so they can be inspected afterwards.
        self.attention_weights = torch.softmax(scores, dim = -1)
        dropped = self.dropout(self.attention_weights)
        return torch.bmm(dropped, values)

In [None]:
def transpose_qkv(X, num_heads):
    "Split the last dimension into heads and fold the heads into the batch dimension."
    batch, seq_len = X.shape[0], X.shape[1]
    # (batch, seq, hidden) -> (batch, seq, heads, hidden / heads) -> (batch, heads, seq, hidden / heads)
    per_head = X.reshape(batch, seq_len, num_heads, -1).permute(0, 2, 1, 3)
    return per_head.reshape(-1, per_head.shape[2], per_head.shape[3])

def transpose_output(X, num_heads):
    "Undo transpose_qkv: unfold the heads from the batch dimension and concatenate them."
    seq_len, head_dim = X.shape[1], X.shape[2]
    # (batch * heads, seq, d) -> (batch, heads, seq, d) -> (batch, seq, heads, d)
    per_head = X.reshape(-1, num_heads, seq_len, head_dim).permute(0, 2, 1, 3)
    return per_head.reshape(per_head.shape[0], per_head.shape[1], -1)

In [None]:
class MultiHeadAttention_1(nn.Module):
    """
    Multi-head attention with all heads computed in one batched call.

    key_size, query_size and value_size are accepted but not used; every
    projection maps num_hiddens to num_hiddens.
    """
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_k = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_v = nn.Linear(num_hiddens, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values):
        heads = self.num_heads
        # Project, then fold the heads into the batch dimension.
        q_heads = transpose_qkv(self.W_q(queries), heads)
        k_heads = transpose_qkv(self.W_k(keys), heads)
        v_heads = transpose_qkv(self.W_v(values), heads)

        attended = self.attention(q_heads, k_heads, v_heads)

        # Unfold the heads again and mix them with the output projection.
        return self.W_o(transpose_output(attended, heads))

In [None]:
class EncoderBlock_1(nn.Module):
    """
    Transformer encoder block: self-attention and a position-wise FFN, each
    followed by a residual connection with layer normalization.

    use_bias is forwarded to the attention projections. ffn_num_outputs is
    accepted but not used: the FFN always maps back to num_hiddens so that the
    residual connection is shape-compatible.
    """
    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape,
                 ffn_num_input, ffn_num_hiddens, ffn_num_outputs, num_heads, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        # use_bias was previously accepted but dropped here, so the attention
        # layer was always built without bias regardless of the argument.
        self.attention = MultiHeadAttention_1(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X):
        # Self-attention: X serves as queries, keys and values.
        Y = self.addnorm1(X, self.attention(X, X, X))
        return self.addnorm2(Y, self.ffn(Y))

In [None]:
class iTransformer_1(nn.Module):
    """
    Stack of encoder blocks between an input and an output projection.

    The last dimension of the input (width ``vocab_size``) is projected to
    ``num_hiddensT``, passed through ``num_layersT`` encoder blocks, and
    finally projected to ``pred_len``.
    """
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddensT, norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 num_headsT, num_layersT, pred_len, dropout = 0, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.linear = nn.Linear(num_hiddensT, pred_len)
        self.fc = nn.Linear(vocab_size, num_hiddensT)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layersT):
            block = EncoderBlock_1(key_size, query_size, value_size, num_hiddensT,
                                   norm_shape, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                                   num_headsT, dropout, use_bias)
            self.blks.add_module(f"block{layer_idx}", block)

    def forward(self, X, *args):
        hidden = self.fc(X)
        for block in self.blks:
            hidden = block(hidden)
        return self.linear(hidden)

<span style = 'color:red;font-size:18px'>Spatiotemporal fusion</span>

In [None]:
class STE_block(nn.Module):
    """Spatio-temporal encoder block.

    Runs three branches over one heterogeneous graph batch and returns the
    initial LSTM state plus the input sequence handed to the decoder:

    * ``S_GAT1`` (``H_Model``): heterogeneous graph network over the projected
      'DYA' / 'STA' node features.
    * ``att1`` (``AttentionModule``): coordinate attention over the same
      projected features, added to the graph-network outputs.
    * ``iTsf`` (``iTransformer_1``): maps the time axis (``input_len``) of the
      raw concatenated node features to ``pred_len``.

    ``Lstm1`` is shared between the dynamic and the static stream.
    """

    def __init__(self, hidden_dimS, num_heads, embedding_dim2, num_layersl, hidden_sizel, kernel_size, 
                 hidden_dimT, num_layersT, ffn_num_hiddens, num_headsT, input_len, pred_len):
        super(STE_block, self).__init__()

        # Heterogeneous graph network.
        self.S_GAT1 = H_Model(hidden_dimS, num_heads, embedding_dim2)
        self.hidden_sizel = hidden_sizel

        # Temporal-feature attention mechanism.
        self.Lstm1 = LSTMNetwork(hidden_dimS, hidden_sizel, num_layersl)
        self.iTsf = iTransformer_1(input_len, input_len, input_len, input_len, hidden_dimT, hidden_dimT, 
                                   hidden_dimT, ffn_num_hiddens, hidden_dimT, num_headsT, num_layersT, pred_len)
        # Coordinate attention.
        self.att1 = AttentionModule(kernel_size, pool_dim1 = 1, pool_dim2 = 3)
        # One projection shared by both node types; their feature width must
        # therefore be embedding_dim2 * 4.
        self.fc = nn.Linear(embedding_dim2 * 4, hidden_dimS)

    def forward(self, input_x, data_D, device, masks_x, distances):
        """Encode one batch.

        param input_x: only its shape ``(batch_size, time_steps, num_nodes, _)`` is read
        param data_D: graph batch exposing ``data_D['DYA'].x`` and ``data_D['STA'].x``,
                      each reshapeable to ``(batch_size, time_steps, num_nodes, -1)``
        param device: forwarded to the coordinate-attention module
        param masks_x: forwarded to the coordinate-attention module
        param distances: accepted for interface compatibility; not used here
        return: ``(h, c, output)`` where ``h`` / ``c`` are the sums of the dynamic and
                static LSTM states and ``output`` is the ``iTsf`` result for an input
                of shape ``(batch_size * num_nodes, feature_dim, time_steps)``
        """
        batch_size, time_steps, num_nodes, _ = input_x.shape
        x_d, x_s = self.fc(data_D['DYA'].x), self.fc(data_D['STA'].x)
        x_dictD0 = {'DYA': x_d, 'STA': x_s}

        # --- Spatial feature mining ---
        # Outputs are assumed to be (batch_size * time_steps * num_nodes, feature_dim)
        # -- TODO confirm against H_Model.
        x_dynamic0, x_static0 = self.S_GAT1(x_dictD0, data_D)

        # --- Coordinate attention computation ---
        X_dynamic = x_dictD0['DYA'].reshape(batch_size, time_steps, num_nodes, -1)
        X_static = x_dictD0['STA'].reshape(batch_size, time_steps, num_nodes, -1)
        # (B, T, N, 2F) -> (B, 2F, N, T)
        X_SD = torch.cat((X_dynamic, X_static), dim = -1).permute(0, 3, 2, 1)
        X_SD0 = self.att1(X_SD, masks_x, device)
        # Named `num_channels` rather than `F` so the file-level
        # `torch.nn.functional as F` alias is not shadowed.
        num_channels = X_SD0.shape[1]
        X_SD0 = X_SD0.permute(0, 3, 2, 1).reshape(batch_size * time_steps * num_nodes, -1)  # (B*T*N, channels)

        # --- Feature fusion ---
        # First half of the channels belongs to the dynamic stream, second half to the static one.
        half = num_channels // 2
        X_D0, X_S0 = X_SD0[:, :half], X_SD0[:, half:]
        X_D1, X_S1 = x_dynamic0 + X_D0, x_static0 + X_S0
        # (B*T*N, F) -> (B, T, N, F) -> (B, N, T, F) -> (B*N, T, F): one sequence per node.
        X_D1 = X_D1.reshape(batch_size, time_steps, num_nodes, -1).permute(0, 2, 1, 3).reshape(batch_size * num_nodes, time_steps, -1)
        X_S1 = X_S1.reshape(batch_size, time_steps, num_nodes, -1).permute(0, 2, 1, 3).reshape(batch_size * num_nodes, time_steps, -1)
        _, h_d, c_d = self.Lstm1(X_D1)
        _, h_s, c_s = self.Lstm1(X_S1)

        # --- Long-range temporal modeling ---
        # Raw (un-projected) node features, laid out as (B*N, feature_dim, time_steps)
        # so the transformer's first linear layer acts on the time axis.
        X_SD1 = torch.cat((data_D['DYA'].x, data_D['STA'].x), dim = -1).reshape(batch_size, time_steps, num_nodes, -1).permute(0, 2, 3, 1)
        X_SD1 = X_SD1.reshape(batch_size * num_nodes, -1, time_steps)

        output = self.iTsf(X_SD1)

        return h_d + h_s, c_d + c_s, output

<span style = 'color:red;font-size:20px'>HSTDFormer decoder</span>

In [None]:
class DecoderModel(nn.Module):
    """LSTM decoder followed by a linear read-out applied at every time step."""

    def __init__(self, input_dim, hidden_dim, num_layers, output_dim):
        """
        Decoder model
        param input_dim: input feature dimension
        param hidden_dim: hidden-layer dimension
        param num_layers: number of LSTM layers
        param output_dim: output feature dimension
        """
        super(DecoderModel, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

    def forward(self, target, h, c):
        """
        Decode a whole sequence in a single LSTM call.
        param target: decoder input sequence, shape (batch_size, target_len, input_dim)
                      (the LSTM is built with batch_first=True)
        param h: initial hidden state, shape (num_layers, batch_size, hidden_dim)
        param c: initial cell state, shape (num_layers, batch_size, hidden_dim)
        return: output sequence, shape (batch_size, target_len, output_dim)
        """
        # The states may arrive as non-contiguous views, so they are made
        # contiguous before being handed to the LSTM.
        output, _ = self.lstm(target, (h.contiguous(), c.contiguous()))
        return self.fc(output)  # shape (batch_size, target_len, output_dim)

<span style = 'color:red;font-size:18px'>Heterogeneous-graph static feature embedding</span>

In [None]:
class Embedding(torch.nn.Module):
    """Embed the four categorical static attributes of the 'STA' nodes.

    Columns 0-3 of ``dataD['STA'].x`` (vessel type, length, width, draft) are
    cast to integer indices, embedded separately and concatenated, so the
    static feature width becomes ``4 * embedding_dim2``.

    param embedding_dim2: embedding width of each attribute
    param vocab_sizes: number of embeddings for (vessel type, length, width, draft).
        The default is the Denmark dataset. Values previously kept as
        commented-out code in this cell:
            Denmark:    (21, 65, 420, 15)
            California: (91, 61, 401, 23)
            Houston:    (100, 67, 365, 23)
    """

    def __init__(self, embedding_dim2, vocab_sizes=(21, 65, 420, 15)):
        super(Embedding, self).__init__()

        if len(vocab_sizes) != 4:
            raise ValueError(
                "vocab_sizes must contain exactly 4 entries "
                "(vessel type, length, width, draft), got %d" % len(vocab_sizes)
            )
        type_vocab, length_vocab, width_vocab, draft_vocab = vocab_sizes

        # Static-data encoder layers (attribute names are unchanged so existing
        # checkpoints keep loading).
        self.embedding_layer1 = nn.Embedding(num_embeddings=type_vocab, embedding_dim=embedding_dim2)    # Vessel type
        self.embedding_layer2 = nn.Embedding(num_embeddings=length_vocab, embedding_dim=embedding_dim2)  # Length
        self.embedding_layer3 = nn.Embedding(num_embeddings=width_vocab, embedding_dim=embedding_dim2)   # Width
        self.embedding_layer4 = nn.Embedding(num_embeddings=draft_vocab, embedding_dim=embedding_dim2)   # Draft

    def forward(self, dataD, device):
        """Replace ``dataD['STA'].x`` with the concatenated embeddings and return ``dataD``.

        param dataD: graph batch exposing ``.to(device)`` and ``['STA'].x`` with at
                     least four columns of category indices
        param device: device the batch is moved to before embedding
        return: the batch (after ``.to(device)``) with ``['STA'].x`` overwritten
        """
        dataD = dataD.to(device)
        # Intermediate tensors are kept as attributes, as before, in case other
        # cells inspect them.
        self.xd_s = dataD['STA'].x

        # Static attribute encoding
        self.Xd_S1 = self.embedding_layer1(self.xd_s[:, 0].long())  # Vessel type
        self.Xd_S2 = self.embedding_layer2(self.xd_s[:, 1].long())  # Length
        self.Xd_S3 = self.embedding_layer3(self.xd_s[:, 2].long())  # Width
        self.Xd_S4 = self.embedding_layer4(self.xd_s[:, 3].long())  # Draft

        self.Xd_S = torch.cat((self.Xd_S1, self.Xd_S2, self.Xd_S3, self.Xd_S4), dim=1)
        dataD['STA'].x = self.Xd_S

        return dataD

<span style = 'color:red;font-size:18px'>Prediction padding</span>

In [None]:
def ipad_out(Y, masks, device):
    """Scatter packed per-node predictions back into a dense, zero-padded tensor.

    param Y: packed predictions of shape (total_valid_nodes, pred_time_steps, feature_dim);
             the rows of sample 0 come first, then those of sample 1, and so on
    param masks: validity mask of shape (batch_size, pred_time_steps, max_nodes)
    param device: device on which the padded tensor is created
    return: tensor of shape (batch_size, max_nodes, pred_time_steps, feature_dim)
            that is zero wherever a node is not valid
    """
    batch_size, pred_time_steps, max_nodes = masks.shape
    _, _, feature_dim = Y.shape

    # A node counts as valid only if it is unmasked at every prediction step.
    node_is_valid = masks.all(dim=1)

    padded = torch.zeros(batch_size, max_nodes, pred_time_steps, feature_dim).to(device)

    # Walk through Y with a running cursor: each sample consumes as many rows
    # as it has valid nodes.
    cursor = 0
    for sample_idx, valid_row in enumerate(node_is_valid):
        node_ids = torch.where(valid_row)[0]
        stop = cursor + len(node_ids)
        padded[sample_idx, node_ids, :, :] = Y[cursor:stop]
        cursor = stop
    return padded

<span style = 'color:red;font-size:25px'>Temporal encoding</span>

In [None]:
def generate_positional_encoding(num_steps, hidden_dim, device):
    """
    Generate sinusoidal positional encoding.

    Even feature columns hold sin(position * freq) and odd columns hold
    cos(position * freq), with freq_k = 10000 ** (-2k / hidden_dim).

    Parameters:
    num_steps: number of time steps
    hidden_dim: feature dimension (may be odd)
    device: compute device

    Return:
    positional encoding tensor of shape (num_steps, hidden_dim)
    """
    position = torch.arange(num_steps, dtype=torch.float, device=device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, hidden_dim, 2, dtype=torch.float, device=device) *
                         -(torch.log(torch.tensor(10000.0, device=device)) / hidden_dim))
    pe = torch.zeros(num_steps, hidden_dim, device=device)
    pe[:, 0::2] = torch.sin(position * div_term)
    # For an odd hidden_dim there is one fewer odd column than even column, so
    # the frequency table is trimmed to match; for an even hidden_dim the slice
    # is the whole table.
    pe[:, 1::2] = torch.cos(position * div_term[: hidden_dim // 2])

    return pe

def add_positional_encoding(static_features, device):
    """
    Add a sinusoidal positional encoding to static-node features.

    Parameters:
    static_features: input static-feature tensor of shape (batch_size * num_nodes, num_steps, hidden_dim)
    device: device on which the encoding table is created

    Return:
    static features with positional encoding added; shape unchanged
    """
    _, num_steps, hidden_dim = static_features.shape
    encoding = generate_positional_encoding(num_steps, hidden_dim, device)
    # The leading singleton axis lets the (num_steps, hidden_dim) table
    # broadcast over the first axis of the input.
    return static_features + encoding[None, :, :]

# <span style = 'color:red'>Main model function</span>

In [None]:
class H_data(nn.Module):
    """Full model: graph construction, static embedding, temporal encoding,
    spatio-temporal encoder block and LSTM decoder.

    ``hidden_dimD`` and ``embedding_dim1`` are accepted for signature
    compatibility but are not used by this constructor.
    """

    def __init__(self,input_len,pred_len,input_dimD,hidden_dimD,embedding_dim1,embedding_dim2,hidden_dimS,num_heads,num_layersl,
                 hidden_sizel,kernel_size,hidden_dimT,num_layersT,ffn_num_hiddens,num_headsT,num_layersl2, hidden_sizel2):
        super(H_data, self).__init__()

        self.input_len = input_len
        self.pred_len = pred_len
        self.embed = Embedding(embedding_dim2)
        self.STE1 = STE_block(hidden_dimS, num_heads, embedding_dim2, num_layersl2, hidden_sizel2, kernel_size, 
                 hidden_dimT, num_layersT, ffn_num_hiddens, num_headsT, input_len, pred_len)
        self.lstm = LSTMNetwork(input_dimD, hidden_sizel, num_layersl)
        self.decoder = DecoderModel(embedding_dim2 * 8, hidden_sizel2, num_layersl2, output_dim = 2)

    def forward(self, input_x, static_X, device, F, S, batch_size, masks_x, distances):
        """
        input_x is a list of batches with length batch_size, where each element contains total_len time steps;
        overall shape: (batch_size, time_steps, num_nodes, feature_dim)
        static_X shape: (batch_size, time_steps, num_nodes, feature_dim)
        F shape: (batch_size, node_id)              -- not used in this method
        S shape: (batch_size, time_steps, node_id)  -- not used in this method
        batch_size: ignored; the value is re-derived from input_x.shape
        masks_x shape: (batch_size, time_steps, num_nodes)
        return: (Y_hat, masks_out) with Y_hat of shape
                (batch_size, num_nodes, pred_len, 2) and masks_out of shape
                (batch_size, pred_len, num_nodes)
        """
        batch_size, time_steps, num_nodes, _ = input_x.shape
        data_D, _ = HeteroGraphBuilder_batch(input_x, static_X, masks_x, device, distances)
        data_D = self.embed(data_D, device)

        # (B*T*N, F) -> (B, T, N, F) -> (B, N, T, F) -> (B*N, T, F): one sequence per node.
        x_dynamic1 = data_D['DYA'].x.reshape(batch_size, time_steps, num_nodes, -1)
        x_static1 = data_D['STA'].x.reshape(batch_size, time_steps, num_nodes, -1)
        x_dynamic1 = x_dynamic1.permute(0, 2, 1, 3).reshape(-1, time_steps, x_dynamic1.shape[-1])
        x_static1 = x_static1.permute(0, 2, 1, 3).reshape(-1, time_steps, x_static1.shape[-1])

        # Static features get a positional encoding; dynamic features go through an LSTM.
        x_sta1 = add_positional_encoding(x_static1, device)
        x_dyn1, _, _ = self.lstm(x_dynamic1)

        # Back to the flat (B*T*N, F) node layout used by the graph batch.
        # Each stream is flattened with its OWN feature width: reusing the LSTM
        # output width for the static stream is only correct when the two
        # widths happen to coincide.
        dyn_dim = x_dyn1.shape[-1]
        sta_dim = x_sta1.shape[-1]
        x_dyn1 = x_dyn1.reshape(batch_size, num_nodes, time_steps, -1).permute(0, 2, 1, 3).reshape(-1, dyn_dim)
        x_sta1 = x_sta1.reshape(batch_size, num_nodes, time_steps, -1).permute(0, 2, 1, 3).reshape(-1, sta_dim)
        data_D['DYA'].x, data_D['STA'].x = x_dyn1, x_sta1

        h_D, c_D, X_din = self.STE1(input_x, data_D, device, masks_x, distances)  # X_din: (B*N, F, T)
        X_din = X_din.permute(0, 2, 1)  # -> (B*N, T, F), batch-first decoder input

        # Decode the outputs
        Y_h = self.decoder(X_din, h_D, c_D)
        Y_hat = Y_h.reshape(batch_size, num_nodes, self.pred_len, -1)
        masks_out = masks_x[:, :self.pred_len, :].to(device)

        return Y_hat, masks_out          # Y_hat shape: (batch_size, max_nodes, pred_len, feature_dim),
                                         # masks_out shape: (batch_size, pred_len, max_nodes)

<span style = 'color:red; font-size:18px'>Display training results in real time</span>

In [None]:
class Animator:  #@save

    """Plot data incrementally as an animation in the notebook output."""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5)):
        legend = [] if legend is None else legend
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        # A single subplot comes back as a bare Axes; wrap it so that
        # self.axes[0] is always valid.
        if nrows * ncols == 1:
            self.axes = [self.axes, ]

        def configure_first_axes():
            # Captures the constructor arguments so add() can re-apply them
            # after every clear.
            return d2l.set_axes(
                self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)

        self.config_axes = configure_first_axes
        self.X, self.Y, self.fmts = None, None, fmts

    def add(self, x, y):
        """Append one point per series and redraw the figure.

        ``y`` may be a scalar or a sequence (one value per series); a scalar
        ``x`` is shared by every series. Pairs containing ``None`` are skipped.
        """
        ys = y if hasattr(y, "__len__") else [y]
        num_series = len(ys)
        xs = x if hasattr(x, "__len__") else [x] * num_series
        if not self.X:
            self.X = [[] for _ in range(num_series)]
        if not self.Y:
            self.Y = [[] for _ in range(num_series)]
        for series_idx, (x_val, y_val) in enumerate(zip(xs, ys)):
            if x_val is None or y_val is None:
                continue
            self.X[series_idx].append(x_val)
            self.Y[series_idx].append(y_val)
        axes = self.axes[0]
        axes.cla()
        for series_x, series_y, fmt in zip(self.X, self.Y, self.fmts):
            axes.plot(series_x, series_y, fmt)
        self.config_axes()
        display.display(self.fig)
        display.clear_output(wait=True)

# <span style = 'color:red;font-size:18px'>Pad the data</span>

In [None]:
def pad_outputy(output_Y):
    """
    Zero-pad a list of per-sample arrays to a common shape and build the matching mask.

    :param output_Y: a list where each element has shape (time_steps, num_nodes, features)
    :return:
        - padded_output: padded tensor of shape (batch_size, max_time_steps, max_nodes, features)
        - masks: padding mask of shape (batch_size, max_time_steps, max_nodes); 1.0 marks real data
        Two empty tensors are returned when the list is empty or any sample is empty.
    """
    if not output_Y:
        return torch.empty(0), torch.empty(0)
    for sample in output_Y:
        if len(sample) == 0:
            return torch.empty(0), torch.empty(0)

    samples = [np.array(sample) for sample in output_Y]
    max_nodes = max(arr.shape[1] for arr in samples)
    # The feature width is taken from the first sample.
    feature_dim = samples[0].shape[2]
    max_time_steps = max(arr.shape[0] for arr in samples)
    padded_shape = (len(samples), max_time_steps, max_nodes, feature_dim)

    padded_output = torch.zeros(padded_shape, dtype=torch.float32)
    masks = torch.zeros(padded_shape[:3], dtype=torch.float32)

    for slot, arr in enumerate(samples):
        steps, nodes, _ = arr.shape
        padded_output[slot, :steps, :nodes, :] = torch.tensor(arr, dtype=torch.float32)
        masks[slot, :steps, :nodes] = 1.0

    return padded_output, masks


In [None]:
def pad_output(output_Y):
    """
    Pad irregular input data of shape (batch_size, [list of time steps], num_nodes, features) and generate a mask.

    :param output_Y: a nested list of NumPy arrays or tensors, where each element is an array of shape (num_nodes, features)
    :return:
        - padded_output: padded tensor of shape (batch_size, max_time_steps, max_nodes, feature_dim)
        - masks: padding mask of shape (batch_size, max_time_steps, max_nodes); 1.0 marks real data
        Two empty tensors are returned when there is nothing to pad (empty
        list, or every sample has zero time steps), matching ``pad_outputy``.
    """
    all_arrays = [array for batch in output_Y for array in batch]
    if not all_arrays:
        # Without this guard the max() calls below raise ValueError on an empty sequence.
        return torch.empty(0), torch.empty(0)

    batch_size = len(output_Y)
    max_time_steps = max(len(batch) for batch in output_Y)  # Maximum number of time steps

    # Find the maximum number of nodes and the feature dimension
    max_nodes = max(array.shape[0] for array in all_arrays)
    feature_dim = max(array.shape[1] for array in all_arrays)

    # Initialize the padded tensor and mask
    padded_output = torch.zeros((batch_size, max_time_steps, max_nodes, feature_dim), dtype=torch.float32)
    masks = torch.zeros((batch_size, max_time_steps, max_nodes), dtype=torch.float32)

    # Pad the data and generate the mask
    for i, batch in enumerate(output_Y):
        for t, array in enumerate(batch):
            num_nodes, num_features = array.shape
            # as_tensor accepts both NumPy arrays and tensors without the
            # copy-construct warning that torch.tensor(tensor) emits.
            padded_output[i, t, :num_nodes, :num_features] = torch.as_tensor(array, dtype=torch.float32)
            masks[i, t, :num_nodes] = 1.0

    return padded_output, masks

<span style = 'color:red;font-size:18px'>Model training</span>

In [None]:
def train_model(net, lr, epochs, input_len, pred_len, file_path, file_vpath, file_pathS,
                 weight_decay, max_values, min_values, batch_size, hidden_dimT):
    """Train ``net`` with an L1 loss and report MAE / ADE on train and validation data each epoch.

    param net: model called as ``net(x, static_x, device, F, S, batch_size, masks, distances)``;
               the first element of its return value is the normalised prediction
    param lr: Adam learning rate (halved every 10 epochs by a StepLR scheduler)
    param epochs: number of epochs
    param input_len, pred_len: window lengths forwarded to the data helpers
    param file_path, file_vpath: training / validation data, read through ``slice_data_generator``
    param file_pathS: static data, read through ``extract_mmsi_features``
    param weight_decay: Adam weight decay
    param max_values, min_values: per-feature bounds used by the (de)normalisation helpers
    param batch_size: number of windows per batch; iteration stops at the first short batch
    param hidden_dimT: accepted for interface compatibility; not used here
    """

    def xavier_init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)

    def mean_or_nan(acc):
        # An epoch that processed no batch must not crash the summary.
        return acc[0] / acc[1] if acc[1] else float('nan')

    net.apply(xavier_init_weights)

    # Fall back to the CPU when CUDA is unavailable (same logic as test_model).
    if torch.cuda.is_available():
        device = torch.device('cuda')
        torch.cuda.init()
    else:
        device = torch.device('cpu')

    net.to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    loss = nn.L1Loss()

    # Load and convert static data
    dataS = extract_mmsi_features(file_pathS)
    dataS = transform_data(dataS)

    # The inverse distance below divides by zero for coincident points.
    np.seterr(divide='ignore')

    animator = Animator(xlabel='epoch', ylabel='loss', yscale='log', xlim=[0, epochs],
                        legend=['train', 'valid'])

    for epoch in range(epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)         # Log L1 loss during training
        metric_mae = d2l.Accumulator(2)     # MAE during training
        metric_ade = d2l.Accumulator(2)     # ADE during training
        metric_mvalid = d2l.Accumulator(2)  # Log validation MAE
        metric_made = d2l.Accumulator(2)    # Log validation ADE
        timer.start()

        # ------------------ Train ------------------
        net.train()  # validation below switches to eval mode, so re-enable every epoch
        for slice_y in slice_data_generator(file_path, input_len, pred_len, batch_size):

            optimizer.zero_grad()
            if len(slice_y) != batch_size:
                break

            _, first_ids, output_Y, step_ids, _, input_x_all, Static_result_all = process_data(
                slice_y, input_len, pred_len, dataS
            )
            if not input_x_all or all(len(batch) == 0 for batch in input_x_all):
                break

            In_x, masks_x = pad_output(input_x_all)
            In_xr = In_x.reshape(-1, In_x.shape[3]).to(device)
            distances = (1 / haversine_distances(In_xr)).to(device)
            Out_Y, masks_Y0 = pad_outputy(output_Y)
            if Out_Y.shape[0] == 0:
                break
            Out_Y1 = Out_Y[:, :, :, :2].to(device)  # un-normalised targets (first two features)
            masks_y = masks_Y0.to(device)
            in_x = normalize_datat(In_x, max_values, min_values)
            Out_Y = normalize_datat(Out_Y, max_values, min_values)
            sta_x, _ = pad_output(Static_result_all)
            Y_h, _ = net(in_x, sta_x, device, first_ids, step_ids, batch_size, masks_x, distances)
            masks_Y = masks_y.bool().unsqueeze(-1).expand_as(Out_Y1)
            # NOTE(review): H_data.forward reshapes its prediction to
            # (batch, nodes, pred_len, 2) while this mask is laid out as
            # (batch, time, nodes, 2) -- confirm the two layouts agree.
            masked_Y = Out_Y * masks_Y
            masked_Y1 = Out_Y1 * masks_Y
            masked_Y_hat = Y_h * masks_Y
            batch_loss = loss(masked_Y, masked_Y_hat)

            with torch.no_grad():
                masked_Y_ht = denormalize_data(masked_Y_hat, max_values, min_values)
                masked_Y_ht = masked_Y_ht * masks_Y
                ADE_ = calculate_ade(masked_Y_ht, masked_Y1, masks_Y)
                mae = torch.abs(masked_Y_ht - masked_Y1) * masks_Y
                valid_count = masks_Y.sum().item()
                metric.add(batch_loss.sum(), valid_count)
                metric_mae.add(mae.sum(), valid_count)
                # Two mask entries (one per coordinate) make up one displacement.
                metric_ade.add(ADE_.sum(), valid_count // 2)
            batch_loss.sum().backward()
            optimizer.step()
        scheduler.step()

        # ------------------ Valid ------------------
        net.eval()
        with torch.no_grad():
            for slice_vy in slice_data_generator(file_vpath, input_len, pred_len, batch_size):
                if len(slice_vy) != batch_size:
                    break

                _, v_first_ids, output_vY, v_step_ids, _, input_vx_all, Static_vresult_all = process_data(
                    slice_vy, input_len, pred_len, dataS
                )
                if not input_vx_all or all(len(batch) == 0 for batch in input_vx_all):
                    break

                In_vx, masks_vx = pad_output(input_vx_all)
                In_vxr = In_vx.reshape(-1, In_vx.shape[3]).to(device)
                distances_v = (1 / haversine_distances(In_vxr)).to(device)

                Out_vY, masks_vY0 = pad_outputy(output_vY)
                if Out_vY.shape[0] == 0:
                    break

                Out_vY1 = Out_vY[:, :, :, :2].to(device)

                masks_vy = masks_vY0.to(device)
                in_vx = normalize_datat(In_vx, max_values, min_values)
                sta_vx, _ = pad_output(Static_vresult_all)

                vY_h, _ = net(in_vx, sta_vx, device, v_first_ids, v_step_ids, batch_size, masks_vx, distances_v)
                vY_hat = denormalize_data(vY_h, max_values, min_values)

                masks_vY = masks_vy.bool().unsqueeze(-1).expand_as(Out_vY1)
                masked_vY = Out_vY1 * masks_vY
                masked_vY_hat = vY_hat * masks_vY
                vADE_ = calculate_ade(masked_vY_hat, masked_vY, masks_vY)
                mae_v = torch.abs(masked_vY_hat - masked_vY) * masks_vY
                valid_vcount = masks_vY.sum().item()
                metric_mvalid.add(mae_v.sum(), valid_vcount)
                metric_made.add(vADE_.sum(), valid_vcount // 2)

        train_mae, valid_mae = mean_or_nan(metric_mae), mean_or_nan(metric_mvalid)
        train_ade, valid_ade = mean_or_nan(metric_ade), mean_or_nan(metric_made)
        animator.add(epoch + 1, (train_mae, valid_mae))
        print(f'epoch{epoch+1}train_loss:{train_mae}')
        print(f'epoch{epoch+1}valid_loss:{valid_mae}')
        # These two lines used to be printed under the "loss" label as well,
        # which made them indistinguishable from the MAE lines above.
        print(f'epoch{epoch+1}train_ade:{train_ade}')
        print(f'epoch{epoch+1}valid_ade:{valid_ade}')
        print(f"Epoch {epoch + 1}, Time: {timer.stop():.1f} seconds")

In [None]:
# Model hyperparameters. Every name defined here is read by the model
# construction cell below.
input_len = 48  # Input sequence length
pred_len = 24   # Prediction horizon length
num_heads = 2  # Number of graph-attention heads
embedding_dim2 = 4  # Embedding dimension for other static features
# H_data takes `embedding_dim1` positionally but its constructor never reads
# it; it is defined here so that building the model does not raise a NameError
# on a fresh kernel.
# NOTE(review): placeholder value -- confirm the intended setting.
embedding_dim1 = embedding_dim2
hidden_dimS = 32   # Hidden dimension for heterogeneous-graph features
hidden_dimT = 128 # Hidden feature dimension
num_layersS = 2   # Heterogeneous-graph layers
num_layersT = 5   # iTransformer layers
input_dimD = 5  # Feature dimension of dynamic features 
hidden_dimD = 12
# Four static attributes are embedded with embedding_dim2 each and concatenated.
input_dimS = output_dimS = hidden_sizel = embedding_dim2 * 4  # Feature dimension of static input
num_layersl = 2   # LSTM layers
num_layersl2 = 4
hidden_sizel2 = 128
kernel_size = 3  # Kernel size
ffn_num_hiddens = 256
num_headsT = 4 

In [None]:
# Build the full model from the hyperparameters defined above.
net = H_data(
    input_len=input_len,
    pred_len=pred_len,
    input_dimD=input_dimD,
    hidden_dimD=hidden_dimD,
    embedding_dim1=embedding_dim1,
    embedding_dim2=embedding_dim2,
    hidden_dimS=hidden_dimS,
    num_heads=num_heads,
    num_layersl=num_layersl,
    hidden_sizel=hidden_sizel,
    kernel_size=kernel_size,
    hidden_dimT=hidden_dimT,
    num_layersT=num_layersT,
    ffn_num_hiddens=ffn_num_hiddens,
    num_headsT=num_headsT,
    num_layersl2=num_layersl2,
    hidden_sizel2=hidden_sizel2,
)

In [None]:
# Dataset locations. These are placeholders: a bare `name =  # comment` line
# is a SyntaxError, so each name is bound to None until a real path is filled in.
# TODO: set all four paths before running the training / test cells.
file_path = None   # Training dataset path
file_vpath = None  # Validation dataset path
file_tpath = None  # Test dataset path
file_pathS = None  # Static data path

In [None]:
# Optimisation settings.
lr = 1e-4          # Learning rate
weight_decay = 0   # Regularization
epochs = 100       # Number of epochs
batch_size = 2     # Number of windows

# Per-feature bounds handed to the (de)normalisation helpers; one entry per
# dynamic feature, in the same order in both lists.
max_values = [                                                                          # Max-value list
    57.1976249957228,
    8.411638137361614,
    29.770000000000003,
    359.9,
    180.0,
]
min_values = [                                                                          # Min-value list
    54.959007898554255,
    6.297554604493042,
    0.5,
    0.0,
    0.0,
]

In [None]:
# Launch training with the configuration defined in the cells above.
train_model(
    net=net,
    lr=lr,
    epochs=epochs,
    input_len=input_len,
    pred_len=pred_len,
    file_path=file_path,
    file_vpath=file_vpath,
    file_pathS=file_pathS,
    weight_decay=weight_decay,
    max_values=max_values,
    min_values=min_values,
    batch_size=batch_size,
    hidden_dimT=hidden_dimT,
)

<span style = 'color:red;font-size:18px'>Model Test</span>

In [None]:
def test_model(net, input_len, pred_len, file_tpath, file_pathS, 
                max_values, min_values, batch_size):
    """Evaluate ``net`` on the test set and print MAE, ADE, MSE, RMSE and inference-time statistics.

    param net: trained model; only the first element of its return value (the
               normalised prediction) is used
    param input_len, pred_len: window lengths forwarded to the data helpers;
               ``pred_len`` is also used to convert mask counts into trajectory counts
    param file_tpath: test data, read through ``slice_data_generator``
    param file_pathS: static data, read through ``extract_mmsi_features``
    param max_values, min_values: per-feature bounds used by the (de)normalisation helpers
    param batch_size: number of windows per batch; iteration stops at the first short batch
    """
    # Imported locally so the timing code works even when no earlier cell
    # imported `time`.
    import time

    # Device selection and initialization.
    if torch.cuda.is_available():
        device = torch.device('cuda')
        torch.cuda.init()
    else:
        device = torch.device('cpu')

    net.to(device)

    # Load and convert static data
    dataS = extract_mmsi_features(file_pathS)
    dataS = transform_data(dataS)

    net.eval()

    metric_tmse = d2l.Accumulator(2)
    metric_tmae = d2l.Accumulator(2)
    metric_tade = d2l.Accumulator(2)

    num_coords = 2.0          # two mask entries (one per coordinate) per predicted point
    total_pred_time = 0.0     # Total inference time (s)
    total_trajectories = 0.0  # Total number of trajectories: valid_tcount / pred_len / num_coords

    # The inverse distance below divides by zero for coincident points.
    np.seterr(divide='ignore')

    with torch.no_grad():
        for slice_ty in slice_data_generator(file_tpath, input_len, pred_len, batch_size):
            if len(slice_ty) != batch_size:
                break

            _, t_first_ids, output_tY, t_step_ids, _, input_tx_all, Static_tresult_all = process_data(
                slice_ty, input_len, pred_len, dataS
            )
            if not input_tx_all or all(len(batch) == 0 for batch in input_tx_all):
                break

            In_tx, masks_tx = pad_output(input_tx_all)
            In_txr = In_tx.reshape(-1, In_tx.shape[3])
            distances_t = (1 / haversine_distances(In_txr)).to(device)

            Out_tY, masks_tY0 = pad_outputy(output_tY)
            if Out_tY.shape[0] == 0:
                break

            Out_tY1 = Out_tY[:, :, :, :2].to(device)

            masks_ty = masks_tY0.to(device)
            in_tx = normalize_datat(In_tx, max_values, min_values)
            sta_tx, _ = pad_output(Static_tresult_all)

            # === Log inference time (GPU synchronization for accurate timing) ===
            if device.type == "cuda":
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            # Index instead of unpacking a fixed number of values: the model in
            # this notebook returns (prediction, mask), and only the prediction
            # is needed here.
            tY_h = net(in_tx, sta_tx, device, t_first_ids, t_step_ids, batch_size, masks_tx, distances_t)[0]

            if device.type == "cuda":
                torch.cuda.synchronize()
            end_time = time.perf_counter()
            # =========================================

            tY_hat = denormalize_data(tY_h, max_values, min_values)
            masks_tY = masks_ty.bool().unsqueeze(-1).expand_as(Out_tY1)
            masked_tY = Out_tY1 * masks_tY
            masked_tY_hat = tY_hat * masks_tY

            tADE_ = calculate_ade(masked_tY_hat, masked_tY, masks_tY)
            mae_t = torch.abs(masked_tY_hat - masked_tY) * masks_tY
            mse_t = calculate_mse(masked_tY, masked_tY_hat)
            valid_tcount = masks_tY.sum().item()

            # ====== Inference time and trajectory statistics =====
            total_pred_time += (end_time - start_time)
            total_trajectories += valid_tcount / float(pred_len) / num_coords
            # =================================

            metric_tmse.add(mse_t.sum(), valid_tcount)
            metric_tmae.add(mae_t.sum(), valid_tcount)
            metric_tade.add(tADE_.sum(), valid_tcount / num_coords)

    # Without any valid prediction every ratio below would divide by zero.
    if metric_tmae[1] == 0:
        print("No valid trajectory batches processed.")
        return

    print(f'test_mae:  {metric_tmae[0]/metric_tmae[1]:.6f}')
    print(f'test_ade:  {metric_tade[0]/metric_tade[1]:.6f}m')
    print(f'test_mse:  {metric_tmse[0]/metric_tmse[1]:.6f}')
    print(f'test_rmse: {torch.sqrt(torch.tensor(metric_tmse[0]/metric_tmse[1])):.6f}')

    print(f"total_pred_time(s): {total_pred_time:.6f}")
    print(f"total_trajectories: {total_trajectories:.6f}")

    # Average inference time per trajectory
    avg_time_per_traj = total_pred_time / total_trajectories
    print(f'Average inference time per trajectory: {avg_time_per_traj:.6f} seconds')

    # Throughput (trajectories/s)
    traj_per_sec = total_trajectories / total_pred_time if total_pred_time > 0 else float("inf")
    print(f'Throughput (trajectories/sec): {traj_per_sec:.6f}')

In [None]:
# Evaluate the trained model on the held-out test set.
test_model(
    net=net,
    input_len=input_len,
    pred_len=pred_len,
    file_tpath=file_tpath,
    file_pathS=file_pathS,
    max_values=max_values,
    min_values=min_values,
    batch_size=batch_size,
)