In [1]:
import functools
from pathlib import Path

import numpy as np
import torch
from scipy.linalg import expm

In [50]:
# Load the similarity (adjacency) matrix G and build the diffusion ingredients:
# a geometric per-step rate schedule and a rate matrix R with zero row sums.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code; only load trusted files. A plain numeric matrix does
# not need it — consider allow_pickle=False once the file format is confirmed.
G = np.load("../data_src/file_of_similarity_mat.npy", allow_pickle=True)
assert G.ndim == 2 and G.shape[0] == G.shape[1], f"expected a square matrix, got shape {G.shape}"
assert np.isfinite(G).all(), "similarity matrix contains NaN/inf"
N = G.shape[0]

k = 5           # divisor applied to the symmetrized similarities — TODO document its meaning
start = 0.01    # per-step rate at the first step
end = 0.1       # per-step rate at the last step
num_steps = 50  # number of diffusion steps

# Per-step rates, geometrically spaced from start to end (num_steps points).
alpha_hat_tj = np.exp(np.linspace(np.log(start), np.log(end), num_steps))
# Cumulative rates: alpha_tj[t] = alpha_hat_tj[0] + ... + alpha_hat_tj[t].
alpha_tj = np.cumsum(alpha_hat_tj)
print(f"Alpha schedule (cumulative): {alpha_tj}")

# Rate matrix: symmetrized, scaled similarities off the diagonal (G's own
# diagonal is discarded); each diagonal entry is minus its row's off-diagonal
# sum, so every row of R sums to zero.
A = (G + G.T) / 2 / k
R = A.copy()
np.fill_diagonal(R, 0)
np.fill_diagonal(R, -R.sum(axis=1))


Alpha schedule (cumulative): [0.01       0.02048113 0.03146654 0.0429805  0.05504842 0.06769698
 0.08095409 0.09484904 0.10941253 0.12467671 0.1406753  0.15744362
 0.17501873 0.19343943 0.21274641 0.2329823  0.25419181 0.27642178
 0.2997213  0.32414183 0.34973731 0.37656426 0.40468195 0.43415247
 0.4650409  0.49741548 0.5313477  0.5669125  0.60418844 0.64325784
 0.68420699 0.72712633 0.77211066 0.81925932 0.86867645 0.9204712
 0.97475796 1.03165662 1.09129285 1.15379837 1.21931122 1.28797611
 1.35994468 1.43537588 1.51443631 1.59730059 1.68415172 1.77518154
 1.87059109 1.97059109]


In [51]:
# Check that composing the per-step transition matrices reproduces the
# cumulative one. Every Qt = expm(alpha_hat_tj[t] * R) uses the same R, so the
# matrices commute and Q_0 @ Q_1 @ ... @ Q_{num_steps-1} = expm(alpha_tj[-1] * R).
Qt_list = [expm(alpha_hat_tj[t] * R) for t in range(num_steps)]
Qtbar = functools.reduce(np.matmul, Qt_list)  # left-to-right product, same order as an explicit loop
Qt_direct = expm(alpha_tj[-1] * R)
diff = np.abs(Qtbar - Qt_direct)
print("最大差值:", diff.max())  # maximum absolute elementwise difference
np.testing.assert_allclose(Qtbar, Qt_direct, atol=1e-10)


最大差值: 6.245004513516506e-17


In [54]:
# Sanity check for a single start state x0.
# Propagating a one-hot distribution step by step through Qt_list[0..T-1]
# should match one transition with the cumulative rate alpha_tj[T-1], because
# all Qt are exponentials of the same R and therefore compose additively.
T = 40
assert 1 <= T <= num_steps, f"T={T} must lie in [1, num_steps={num_steps}]"
np.random.seed(11)
x0 = np.random.choice(N)
prob = np.zeros(N)
prob[x0] = 1.0  # one-hot start distribution

# Step-by-step diffusion of the distribution.
for i in range(T):
    prob = prob @ Qt_list[i]

# One-shot diffusion with the cumulative rate.
Qt_direct = expm(alpha_tj[T-1] * R)
prob_direct = np.zeros(N)
prob_direct[x0] = 1.0
final_prob = prob_direct @ Qt_direct

# Draw one sample from the step-by-step distribution.
final_point = np.random.choice(N, p=prob)
print("逐步扩散最后结果:", final_point)
# final_prob is the one-hot selection of row x0 of Qt_direct, so the next two
# printouts are identical by construction; they do not test the diffusion.
print("一步扩散样本结果:", final_prob)
print("Qt_direct的第x0行", Qt_direct[x0])

# The meaningful check: stepwise distribution vs one-shot distribution.
print("逐步与一步扩散分布最大差值:", np.abs(prob - final_prob).max())
np.testing.assert_allclose(prob, final_prob, atol=1e-10)


逐步路径: [25]
逐步扩散最后结果: 4
一步扩散样本结果: [0.01514551 0.01607371 0.015668   0.01568122 0.01578334 0.01519203
 0.01584822 0.01589721 0.01531972 0.01397655 0.01635962 0.01585792
 0.01454787 0.01576913 0.0156773  0.01548055 0.01477914 0.01543365
 0.01257785 0.01321602 0.01273218 0.01466229 0.01504799 0.01606867
 0.01628031 0.01816222 0.01554183 0.01247554 0.01238296 0.01227943
 0.01278313 0.01316731 0.01271178 0.01539217 0.01477541 0.01564255
 0.01270557 0.01281897 0.01354215 0.01441478 0.01522461 0.01453817
 0.01361637 0.01215025 0.01582298 0.01449323 0.01539893 0.01237949
 0.01441415 0.01564233 0.01556759 0.01393232 0.01326807 0.01481912
 0.01255296 0.01242599 0.01258875 0.01400843 0.0139517  0.01545641
 0.01636119 0.01539034 0.01338305 0.01290354 0.01466049 0.01272812
 0.01517003 0.0162873  0.01299226]
Qt_direct的第x0行 [0.01514551 0.01607371 0.015668   0.01568122 0.01578334 0.01519203
 0.01584822 0.01589721 0.01531972 0.01397655 0.01635962 0.01585792
 0.01454787 0.01576913 0.0156773  0.01548055 0

In [55]:
# Recompute the schedule and rate matrix from scratch so this cell runs on its
# own, then save every per-step Qt and cumulative Qt_bar as float32 tensors.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code; only load trusted files. A plain numeric matrix does
# not need it — consider allow_pickle=False once the file format is confirmed.
G = np.load("../data_src/file_of_similarity_mat.npy", allow_pickle=True)
assert G.ndim == 2 and G.shape[0] == G.shape[1], f"expected a square matrix, got shape {G.shape}"
assert np.isfinite(G).all(), "similarity matrix contains NaN/inf"
N = G.shape[0]

k = 5           # divisor applied to the symmetrized similarities — TODO document its meaning
start = 0.01    # per-step rate at the first step
end = 0.1       # per-step rate at the last step
num_steps = 50  # number of diffusion steps

# Per-step rates, geometrically spaced from start to end (num_steps points).
alpha_hat_tj = np.exp(np.linspace(np.log(start), np.log(end), num_steps))
# Cumulative rates: alpha_tj[t] = alpha_hat_tj[0] + ... + alpha_hat_tj[t].
alpha_tj = np.cumsum(alpha_hat_tj)

# Rate matrix: symmetrized, scaled similarities off the diagonal; each diagonal
# entry is minus its row's off-diagonal sum, so every row of R sums to zero.
A = (G + G.T) / 2 / k
R = A.copy()
np.fill_diagonal(R, 0)
np.fill_diagonal(R, -R.sum(axis=1))

# Per-step transitions Qt[t] = expm(alpha_hat_tj[t] R) and cumulative
# transitions Qt_bar[t] = expm(alpha_tj[t] R). Stored under new names so the
# NumPy `Qt_list` used by the earlier check cells is not overwritten with
# torch tensors (which would break re-running those cells).
Qt_stack = np.stack([expm(alpha_hat_tj[t] * R) for t in range(num_steps)])
Qt_bar_stack = np.stack([expm(alpha_tj[t] * R) for t in range(num_steps)])

# Save as a PyTorch tensor file, creating the output directory if needed.
out_path = Path("../data/Qt_all.pt")
out_path.parent.mkdir(parents=True, exist_ok=True)
torch.save({
    'Qt': torch.from_numpy(Qt_stack).float(),
    'Qt_bar': torch.from_numpy(Qt_bar_stack).float(),
}, out_path)
print("save success")


Alpha schedule (cumulative): [0.01       0.02048113 0.03146654 0.0429805  0.05504842 0.06769698
 0.08095409 0.09484904 0.10941253 0.12467671 0.1406753  0.15744362
 0.17501873 0.19343943 0.21274641 0.2329823  0.25419181 0.27642178
 0.2997213  0.32414183 0.34973731 0.37656426 0.40468195 0.43415247
 0.4650409  0.49741548 0.5313477  0.5669125  0.60418844 0.64325784
 0.68420699 0.72712633 0.77211066 0.81925932 0.86867645 0.9204712
 0.97475796 1.03165662 1.09129285 1.15379837 1.21931122 1.28797611
 1.35994468 1.43537588 1.51443631 1.59730059 1.68415172 1.77518154
 1.87059109 1.97059109]
save success


In [56]:
# Reload the saved file and check it holds one (N, N) transition matrix per
# diffusion step, for both the per-step and the cumulative stacks.
data = torch.load('../data/Qt_all.pt', weights_only=True)
print(len(data['Qt']), data['Qt'][0].shape)
print(len(data['Qt_bar']), data['Qt_bar'][0].shape)
for key in ('Qt', 'Qt_bar'):
    mats = data[key]
    assert mats.shape == (num_steps, N, N), f"{key}: unexpected shape {tuple(mats.shape)}"
    # R was built with zero row sums, so each expm(a * R) has rows summing to 1
    # (up to float32 rounding from the .float() cast before saving).
    assert torch.allclose(mats.sum(dim=-1), torch.ones(num_steps, N), atol=1e-4), f"{key}: rows do not sum to 1"

50 torch.Size([69, 69])
50 torch.Size([69, 69])
