In [2]:
from pandas import DataFrame
import pandas as pd
from torch.utils.data import DataLoader, Dataset, random_split
import torch
import numpy as np
import ast
import networkx as nx
import matplotlib.pyplot as plt

In [3]:
def pad_to_shape(matrix, target_shape=(24, 7)):
    # 获取当前矩阵的形状
    current_shape = matrix.shape
    
    # 计算在每个维度上需要添加多少个0
    padding = [(max(0, ts - cs), 0) for cs, ts in zip(current_shape, target_shape)]
    
    # 使用np.pad进行填充，'constant'表示用常数进行填充，默认为0
    padded_matrix = np.pad(matrix, padding, mode='constant')
    
    return padded_matrix

def convert_string(s):
    s = s.strip("[]' ")
    return ast.literal_eval(s)

def convert_string_to_list(data: pd.DataFrame, time_length: int=6, bin_length: int=4, padding: str=True):
    tmp_list = []
    for item in data.values:
        for value in item:
            tmp_list.append(list(convert_string(value)))
    tmp_list = np.array(tmp_list, dtype=np.float32)
    
    # if tmp_list.shape[0] < time_length * bin_length and padding:
    #     tmp_list = pad_to_shape(tmp_list, target_shape=(time_length * bin_length, 7))
    return tmp_list

def build_training_data(input_path: str, time_length: int=6, bin_length: int=4, pred_length: int=1, way: str='None'):
    data = pd.read_csv(input_path).iloc[:, 1:]
    
    train_features = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    gnn_data, lstm_data, target_data = [], [], []
    
    # 构建输入数据
    for i in range(time_length, data.shape[0], 1):  # 行
        for j in range(bin_length, data.shape[1]-pred_length):  # 列
            if j < bin_length:
                tmp_matrix = convert_string_to_list(data.iloc[i-time_length:i, :j+1], padding= True)
            else:
                tmp_matrix = convert_string_to_list(data.iloc[i-time_length:i, j-bin_length:j], padding= True)
            
            gnn_data.append(tmp_matrix)

            lstm_data_tmp = pd.concat([data.iloc[i-2, j:data.shape[1]], data.iloc[i-1, :j]], axis=0)
            for value in lstm_data_tmp:
                lstm_data.append(convert_string(value))

            # lstm_data.append(pd.concat([data.iloc[i-2, j:data.shape[1]], data.iloc[i-1, :j]], axis=0).values)
            target_data.append(convert_string(data.iloc[i-1, j]))

    # 将 lstm_data 转换为 NumPy 数组
    gnn_data = np.array(gnn_data, dtype= np.float32)
    lstm_data = np.array(lstm_data, dtype=np.float32)
    lstm_data.resize((gnn_data.shape))
    target_data = np.array(target_data, dtype= np.float32)

    return gnn_data, lstm_data, target_data

class StockDataDataset(Dataset):
    def __init__(self,
                gnn_data: np.array,
                lstm_data: np.array,
                target_data: np.array,
                train_features: list= [0, 1, 2, 3, 4, 5, 6, 7, 8],
                pred_features: list= [0, 1, 2, 3, 4, 5, 6, 7, 8]):
        super().__init__()
        # data
        self.gnn_data = gnn_data
        self.lstm_data = lstm_data
        self.target_data = target_data

        # data features
        self.train_features = train_features
        self.pred_features = pred_features
    
    def __len__(self):
        return self.gnn_data.shape[0]

    def __getitem__(self, index):
        """
        先去把数据按照顺序切分好,然后根据index去找到切片
        """
        return self.gnn_data[index], self.lstm_data[index], self.target_data[index]

def gen_adjmatrix(time_length: int = 6, bin_length: int = 4, direct: bool = False) -> np.ndarray:
    num_nodes = time_length * bin_length
    adjacency_matrix = np.zeros((num_nodes, num_nodes), dtype=int)
    for i in range(num_nodes):
        row, col = divmod(i, bin_length)
        # 右侧邻居
        if col < bin_length - 1:
            right_neighbor = i + 1
            adjacency_matrix[i, right_neighbor] = 1
            if not direct:
                adjacency_matrix[right_neighbor, i] = 1
        # 下方邻居
        if row < time_length - 1:
            bottom_neighbor = i + bin_length
            adjacency_matrix[i, bottom_neighbor] = 1
            if not direct:
                adjacency_matrix[bottom_neighbor, i] = 1
    return adjacency_matrix

def draw_grid_with_features(adjacency_matrix, time_length, bin_length, features, feature_index=0):
    G = nx.DiGraph() if np.any(adjacency_matrix != adjacency_matrix.T) else nx.Graph()
    num_nodes = time_length * bin_length
    G.add_nodes_from(range(num_nodes))

    # 添加边
    for i in range(num_nodes):
        for j in range(num_nodes):
            if adjacency_matrix[i, j] == 1:
                G.add_edge(i, j)

    # 设置布局
    pos = {i: (i % bin_length, time_length - 1 - i // bin_length) for i in range(num_nodes)}
    
    # 绘制图，并在节点上显示选定的特征值
    plt.figure(figsize=(8, 6))
    labels = {i: f'{i}\n{features[i, feature_index]:.2f}' for i in range(num_nodes)}  # 创建带有选定特征值的标签
    nx.draw(G, pos, with_labels=True, node_color='skyblue', node_size=500, font_size=10, edge_color='gray', arrowsize=20, arrowstyle='-|>', connectionstyle='arc3,rad=0.1', labels=labels)
    plt.title('6x4 Grid Graph with Selected Feature')
    plt.show()

In [4]:
data_path = './data/volume/0308/Features/000046_25_daily_f_all.csv'
df = pd.read_csv(data_path)

gnn_data, lstm_data, target_data = build_training_data(input_path= data_path)
df.iloc[:10, :10]

Unnamed: 0,date,bin1,bin2,bin3,bin4,bin5,bin6,bin7,bin8,bin9
0,2020-05-13,"[-0.1641, -0.1641, -0.1641, 0.0188, -0.1641, -...","[-0.3463, -0.5104, -0.2552, 0.0397, -0.3463, -...","[-0.5711, -1.0815, -0.3605, 0.0654, -0.5711, -...","[-0.6848, -1.7662, -0.5341, 0.0785, -0.6848, -...","[-0.5939, -2.3601, -0.6166, 0.068, -0.5939, -0...","[-0.2299, -2.59, -0.5028, 0.0263, -0.2299, -0....","[-0.6891, -3.279, -0.5043, 0.079, -0.6891, -0....","[-0.4613, -3.7404, -0.4601, 0.0529, -0.4613, -...","[-0.5676, -4.3079, -0.5727, 0.065, -0.5676, -0..."
1,2020-05-14,"[-0.5371, -0.5371, -0.5371, 0.0378, -0.1641, -...","[-0.3869, -0.924, -0.462, 0.0403, -0.3463, -0....","[-0.5183, -1.4423, -0.4808, 0.0601, -0.5711, -...","[0.3645, -1.0778, -0.1802, 0.02, -0.6848, -0.1...","[0.1882, -0.8895, 0.0115, 0.0241, -0.5939, -0....","[-0.7078, -1.5974, -0.0517, 0.0506, -0.2299, -...","[-0.5059, -2.1033, -0.3418, 0.0662, -0.6891, -...","[-0.7217, -2.825, -0.6451, 0.0646, -0.4613, -0...","[-0.641, -3.466, -0.6229, 0.0664, -0.5676, -0...."
2,2020-05-15,"[1.5194, 1.5194, 1.5194, -0.0248, -0.1641, -0....","[-0.4838, 1.0357, 0.5178, 0.0428, -0.3463, -0....","[-0.497, 0.5387, 0.1796, 0.0564, -0.5711, -0.2...","[-0.6788, -0.1401, -0.5532, 0.0356, -0.6848, -...","[-0.738, -0.8782, -0.6379, 0.0403, -0.5939, -0...","[-0.7189, -1.597, -0.7119, 0.0574, -0.2299, -0...","[-0.6354, -2.2324, -0.6974, 0.0651, -0.6891, -...","[-0.7486, -2.9811, -0.701, 0.0677, -0.4613, -0...","[-0.7604, -3.7415, -0.7148, 0.0693, -0.5676, -..."
3,2020-05-18,"[-0.0519, -0.0519, -0.0519, -0.0292, -0.1641, ...","[0.1828, 0.1309, 0.0654, 0.023, -0.3463, -0.20...","[-0.2728, -0.1419, -0.0473, 0.0444, -0.5711, -...","[-0.6236, -0.7655, -0.2379, 0.0317, -0.6848, -...","[-0.622, -1.3875, -0.5061, 0.0398, -0.5939, -0...","[-0.7104, -2.0979, -0.652, 0.074, -0.2299, -0....","[-0.365, -2.4629, -0.5658, 0.0518, -0.6891, -0...","[-0.7249, -3.1878, -0.6001, 0.076, -0.4613, -0...","[-0.6217, -3.8095, -0.5705, 0.0698, -0.5676, -..."
4,2020-05-19,"[-0.343, -0.343, -0.343, -0.0375, -0.1641, -0....","[-0.4299, -0.7728, -0.3864, 0.0228, -0.3463, -...","[-0.1264, -0.8992, -0.2997, 0.03, -0.5711, -0....","[-0.0033, -0.9026, -0.1865, 0.0447, -0.6848, -...","[-0.7143, -1.6169, -0.2814, 0.0687, -0.5939, -...","[-0.5081, -2.125, -0.4086, 0.0648, -0.2299, -0...","[-0.5259, -2.6509, -0.5828, 0.0503, -0.6891, -...","[-0.5515, -3.2024, -0.5285, 0.0677, -0.4613, -...","[-0.4168, -3.6191, -0.498, 0.0602, -0.5676, -0..."
5,2020-05-20,"[-0.4853, -0.4853, -0.4853, 0.0267, -0.1641, -...","[-0.2843, -0.7696, -0.3848, 0.0152, -0.3463, -...","[-0.5056, -1.2752, -0.4251, 0.0285, -0.5711, -...","[-0.6839, -1.9591, -0.4913, 0.0424, -0.6848, -...","[-0.3943, -2.3534, -0.5279, 0.056, -0.5939, -0...","[0.1757, -2.1777, -0.3008, 0.0361, -0.2299, -0...","[-0.6302, -2.8079, -0.2829, 0.0479, -0.6891, -...","[-0.1868, -2.9947, -0.2138, 0.0485, -0.4613, -...","[-0.0752, -3.0699, -0.2974, 0.0374, -0.5676, -..."
6,2020-05-21,"[-0.3639, -0.3639, -0.3639, 0.0345, -0.5371, -...","[-0.3058, -0.6697, -0.3349, 0.0298, -0.3869, -...","[-0.4159, -1.0856, -0.3619, 0.0297, -0.5183, -...","[-0.3762, -1.4619, -0.366, 0.0301, 0.3645, -0....","[-0.5959, -2.0578, -0.4627, 0.0496, 0.1882, -0...","[-0.7775, -2.8353, -0.5832, 0.0313, -0.7078, -...","[-0.5738, -3.4091, -0.6491, 0.05, -0.5059, -0....","[-0.5705, -3.9796, -0.6406, 0.0378, -0.7217, -...","[-0.6933, -4.6729, -0.6125, 0.0335, -0.641, -0..."
7,2020-05-22,"[-0.4499, -0.4499, -0.4499, 0.0738, 1.5194, -0...","[0.0363, -0.4136, -0.2068, 0.0124, -0.4838, -0...","[0.2816, -0.132, -0.044, -0.0055, -0.497, -0.2...","[-0.494, -0.626, -0.0587, 0.0848, -0.6788, -0....","[-0.6927, -1.3187, -0.3017, 0.1043, -0.738, -0...","[-0.1445, -1.4632, -0.4437, 0.0315, -0.7189, -...","[0.0697, -1.3935, -0.2558, 0.0259, -0.6354, -0...","[-0.5208, -1.9143, -0.1985, 0.0784, -0.7486, -...","[-0.0914, -2.0057, -0.1808, 0.0307, -0.7604, -..."
8,2020-05-25,"[0.2885, 0.2885, 0.2885, 0.0519, -0.0519, -0.1...","[-0.2099, 0.0786, 0.0393, 0.0096, 0.1828, -0.1...","[-0.6148, -0.5363, -0.1788, -0.0039, -0.2728, ...","[-0.6205, -1.1568, -0.4818, 0.0813, -0.6236, -...","[-0.5849, -1.7417, -0.6068, 0.1083, -0.622, -0...","[-0.7084, -2.4501, -0.638, 0.0555, -0.7104, -0...","[-0.403, -2.8531, -0.5654, 0.0182, -0.365, -0....","[-0.5354, -3.3886, -0.5489, 0.0872, -0.7249, -...","[-0.6784, -4.067, -0.539, 0.0465, -0.6217, -0...."
9,2020-05-26,"[-0.4103, -0.4103, -0.4103, 0.0566, -0.343, -0...","[-0.4687, -0.879, -0.4395, 0.0178, -0.4299, -0...","[-0.2911, -1.1702, -0.3901, -0.0048, -0.1264, ...","[0.5444, -0.6258, -0.0718, 0.0524, -0.0033, -0...","[-0.5093, -1.135, -0.0853, 0.1102, -0.7143, -0...","[-0.5575, -1.6925, -0.1741, 0.0543, -0.5081, -...","[-0.2708, -1.9633, -0.4458, 0.0124, -0.5259, -...","[-0.6075, -2.5707, -0.4786, 0.0932, -0.5515, -...","[-0.692, -3.2628, -0.5234, 0.0522, -0.4168, -0..."


In [7]:
gnn_data[0]

array([[-0.1641, -0.1641, -0.1641,  0.0188, -0.1641, -0.2037, -0.849 ],
       [-0.3463, -0.5104, -0.2552,  0.0397, -0.3463, -0.2085,  0.0073],
       [-0.5711, -1.0815, -0.3605,  0.0654, -0.5711, -0.2165, -0.8604],
       [-0.6848, -1.7662, -0.5341,  0.0785, -0.6848, -0.2124, -1.0989],
       [-0.5371, -0.5371, -0.5371,  0.0378, -0.1641, -0.2092,  0.9546],
       [-0.3869, -0.924 , -0.462 ,  0.0403, -0.3463, -0.2072,  0.8104],
       [-0.5183, -1.4423, -0.4808,  0.0601, -0.5711, -0.2056, -0.8075],
       [ 0.3645, -1.0778, -0.1802,  0.02  , -0.6848, -0.1925, -1.3684],
       [ 1.5194,  1.5194,  1.5194, -0.0248, -0.1641, -0.1925, -0.558 ],
       [-0.4838,  1.0357,  0.5178,  0.0428, -0.3463, -0.2115, -1.1043],
       [-0.497 ,  0.5387,  0.1796,  0.0564, -0.5711, -0.208 ,  0.4183],
       [-0.6788, -0.1401, -0.5532,  0.0356, -0.6848, -0.2103, -0.0962],
       [-0.0519, -0.0519, -0.0519, -0.0292, -0.1641, -0.1982,  0.5662],
       [ 0.1828,  0.1309,  0.0654,  0.023 , -0.3463, -0.2033,  0

In [None]:
lstm_data[0]

array([[-7.14300e-01, -1.61690e+00, -2.81400e-01,  6.87000e-02,
        -5.93900e-01, -2.09600e-01,  6.63100e-01],
       [-5.08100e-01, -2.12500e+00, -4.08600e-01,  6.48000e-02,
        -2.29900e-01, -2.06400e-01, -1.06730e+00],
       [-5.25900e-01, -2.65090e+00, -5.82800e-01,  5.03000e-02,
        -6.89100e-01, -2.09600e-01, -1.68050e+00],
       [-5.51500e-01, -3.20240e+00, -5.28500e-01,  6.77000e-02,
        -4.61300e-01, -2.12600e-01,  9.66600e-01],
       [-4.16800e-01, -3.61910e+00, -4.98000e-01,  6.02000e-02,
        -5.67600e-01, -2.11900e-01, -2.04740e+00],
       [-6.70000e-01, -4.28920e+00, -5.46100e-01,  6.09000e-02,
        -3.50700e-01, -2.09600e-01,  1.56500e+00],
       [-6.75700e-01, -4.96490e+00, -5.87500e-01,  6.67000e-02,
        -7.14100e-01, -2.09700e-01,  6.13100e-01],
       [-6.94200e-01, -5.65910e+00, -6.80000e-01,  5.35000e-02,
        -6.90400e-01, -2.11000e-01,  5.38600e-01],
       [-4.20100e-01, -6.07920e+00, -5.96700e-01,  4.80000e-02,
        -6.88300

In [8]:
target_data[0]

array([-0.3943, -2.3534, -0.5279,  0.056 , -0.5939, -0.2099, -1.4422],
      dtype=float32)

In [8]:
import torch
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from new_model.GHAT import PredModel
from config import Config

config = Config()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

data_path_dict = {'000046': ('./data/volume/0308/Features/000046_25_daily_f_all.csv', '/stock-trade-pred/gcn/saved_models/000046/model_best.pt'),
                  '000753': ('./data/volume/0308/Features/000753_25_daily_f_all.csv', '/stock-trade-pred/gcn/saved_models/000753/model_best.pt'),
                  '000951': ('./data/volume/0308/Features/000951_25_daily_f_all.csv', '/stock-trade-pred/gcn/saved_models/000951/model_best.pt'),
                  '002882': ('./data/volume/0308/Features/002882_25_daily_f_all.csv', '/stock-trade-pred/gcn/saved_models/002882/model_best.pt'),
                  '300174': ('./data/volume/0308/Features/300174_25_daily_f_all.csv', '/stock-trade-pred/gcn/saved_models/300174/model_best.pt'),
                  '300263': ('./data/volume/0308/Features/300263_25_daily_f_all.csv', '/stock-trade-pred/gcn/saved_models/300263/model_best.pt'),}

for i, (stock_num, (data_path, model_pt)) in enumerate(data_path_dict.items()):
    if stock_num == '000046':
        print(f'{stock_num}')
        gnn_data, lstm_data, target_data = build_training_data(input_path= data_path)
        data = pd.read_csv(data_path)
        model = PredModel(config.in_features, config.out_features, config.embed_dim,
                            config.ff_dim, config.n_heads, config.n_nodes, config.n_layers).to(device)
        model.load_state_dict(torch.load(model_pt))
        model.eval()

        # 生成邻接矩阵
        time_length = 6
        bin_length = 4
        direct = True  # 设置为 True 表示有向图，设置为 False 表示无向图
        adj_matrix = gen_adjmatrix(time_length, bin_length, direct)

        test_gnn = torch.from_numpy(gnn_data).to(device)
        test_lstm = torch.from_numpy(lstm_data).to(device)
        adj_matrix = torch.from_numpy(adj_matrix).to(device)
        pred = model(test_lstm, test_gnn, adj_matrix)
        pred = pred.cpu().detach().numpy()

        mse = mean_squared_error(target_data[:, 1], pred)
        print(f'Mean Squared Error (MSE): {mse}')

        rmse = np.sqrt(mean_squared_error(target_data[:, 1], pred))
        print(f'Root Mean Squared Error (RMSE): {rmse}')

        mae = mean_absolute_error(target_data[:, 1], pred)
        print(f'Mean Absolute Error (MAE): {mae}')

        r2 = r2_score(target_data[:, 1], pred)
        print(f'R-squared (R²): {r2}')

000046
Mean Squared Error (MSE): 0.34640854597091675
Root Mean Squared Error (RMSE): 0.5885648131370544
Mean Absolute Error (MAE): 0.3114861249923706
R-squared (R²): 0.996669590473175
