In [17]:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
from torch_geometric.data import HeteroData

In [18]:

# 源节点特征 (6 个源节点，每个节点 16 维特征)
src = torch.tensor([
    [1, 1, 1, 1, 1, 1],  # 源节点索引
    [2, 2, 2, 2, 2, 2],
    [3, 3, 3, 3, 3, 3],
    [4, 4, 4, 4, 4, 4],  
    [5, 5, 5, 5, 5, 5]
], dtype=torch.float)
tgt = torch.tensor([
    [1, 1, 1, 1, 1, 1],  
    [1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1]
], dtype=torch.float)
adj = torch.tensor([
    [0, 3, 4, 1, 4, 1],  # 源节点索引
    [0, 0, 0, 1, 1, 2]   # 目标节点索引
], dtype=torch.long)
print("adj[0]:",adj[0])
data = HeteroData()
data['src'] = src
data['tgt'] = tgt
data['adj'] = adj


adj[0]: tensor([0, 3, 4, 1, 4, 1])


In [19]:
class GNN(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GNN, self).__init__(aggr='sum', flow='source_to_target')
        self.linear = torch.nn.Linear(in_channels, out_channels)

    def forward(self, data: HeteroData):
        x_src, x_tgt, edge_index = data['src'], data['tgt'], data['adj']
        out = self.propagate(x_src=x_src, x_tgt=x_tgt,edge_index=edge_index)
        out = self.linear(out)
        return out
    
    def message(self, x_src,edge_index):
        message_src = x_src[edge_index[0]]
        return message_src
    


In [20]:
def lossfunction(emd_out,data: HeteroData):
        x_src = data['src']
        edge_index = data['adj']
        x_src = x_src[edge_index[0]]
        x_src = x_src.t()
        simliarity = torch.matmul(emd_out , x_src)
        simliarity = torch.sum(simliarity, dim=0)
        loss = torch.mean(simliarity)
        return loss


In [21]:
model = GNN(6, 6)
out = model(data)
loss = lossfunction(out,data)
print("loss:",loss)



loss: tensor(5.9770, grad_fn=<MeanBackward0>)


In [23]:
iteration = 1000
learning_rate = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for i in range(iteration):
    out = model(data)
    loss = lossfunction(out,data)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        print("loss:",loss)
print("out:",model(data))





loss: tensor(-22224.1816, grad_fn=<MeanBackward0>)
loss: tensor(-24447.2324, grad_fn=<MeanBackward0>)
loss: tensor(-26670.2832, grad_fn=<MeanBackward0>)
loss: tensor(-28893.3359, grad_fn=<MeanBackward0>)
loss: tensor(-31116.3848, grad_fn=<MeanBackward0>)
loss: tensor(-33339.4414, grad_fn=<MeanBackward0>)
loss: tensor(-35562.4883, grad_fn=<MeanBackward0>)
loss: tensor(-37785.5391, grad_fn=<MeanBackward0>)
loss: tensor(-40008.5859, grad_fn=<MeanBackward0>)
loss: tensor(-42231.6367, grad_fn=<MeanBackward0>)
out: tensor([[-1217.3284, -1216.5657, -1217.8594, -1219.7936, -1223.4185, -1223.5702],
        [ -858.2513,  -857.6175,  -858.6173,  -859.9014,  -862.4227,  -862.4664],
        [ -259.7892,  -259.3703,  -259.8803,  -260.0811,  -260.7632,  -260.6266]],
       grad_fn=<AddmmBackward0>)
