## GAT网络
参考链接  
1. [Graph Attention Networks (GAT)](https://nn.labml.ai/graphs/gat/index.html)  
2. [图注意网络GAT理解及Pytorch代码实现【PyGAT代码详细注释】](https://blog.csdn.net/weixin_43629813/article/details/129278266)

In [1]:
import torch
from torch import nn
from labml_helpers.module import Module


class GraphAttentionLayer(Module):
    """A single graph attention (GAT) layer with multi-head attention.

    Each node embedding is linearly projected per head (``g = W h``). For every
    ordered node pair a score ``e_ij = LeakyReLU(a^T [g_i || g_j])`` is computed.
    Scores are masked by the adjacency matrix and normalised with a softmax over
    the nodes ``j`` for each node ``i``. The projected embeddings are then summed
    with those weights. Finally the heads are concatenated (``is_concat=True``)
    or averaged.

    Args:
        in_features: size of each input node embedding.
        out_features: size of each output node embedding. When ``is_concat``
            is true it must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        is_concat: concatenate the heads if true, average them otherwise.
        dropout: dropout probability applied to the attention weights.
        leaky_relu_negative_slope: negative slope of the LeakyReLU applied to
            the attention scores.
    """

    def __init__(self, in_features: int, out_features: int, n_heads: int,
                 is_concat: bool = True,
                 dropout: float = 0.6,
                 leaky_relu_negative_slope: float = 0.2):
        super().__init__()

        self.is_concat = is_concat
        self.n_heads = n_heads

        # Number of dimensions per head.
        if is_concat:
            # Concatenated heads split the output size between them.
            assert out_features % n_heads == 0
            self.n_hidden = out_features // n_heads
        else:
            # Averaged heads each produce the full output size.
            self.n_hidden = out_features

        # Initial per-head projection of the node embeddings (g = W h).
        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
        # Attention vector `a`, applied to [g_i || g_j] to get the score e_ij;
        # shared by all heads.
        self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
        # Non-linearity applied to the raw scores e_ij.
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
        # Softmax over dim 1 (the neighbour index j) of e[i, j, head].
        self.softmax = nn.Softmax(dim=1)
        # Dropout applied to the attention weights alpha_ij.
        self.dropout = nn.Dropout(dropout)

    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor) -> torch.Tensor:
        """Apply graph attention to the node embeddings.

        Args:
            h: node embeddings of shape ``[n_nodes, in_features]``.
            adj_mat: adjacency mask. A zero at ``[i, j]`` means node ``i`` does
                not attend to node ``j``. Shape ``[n_nodes, n_nodes]``,
                ``[n_nodes, n_nodes, 1]`` or ``[n_nodes, n_nodes, n_heads]``;
                either of the first two dimensions may also be ``1`` to
                broadcast.

        Returns:
            Shape ``[n_nodes, n_heads * n_hidden]`` when heads are concatenated,
            otherwise ``[n_nodes, n_hidden]``. A node with no edge at all gets
            an all-zero output instead of NaN.
        """
        n_nodes = h.shape[0]

        # Per-head projections: g[i, head] has shape [n_hidden].
        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)

        # a^T [g_i || g_j] == a_left^T g_i + a_right^T g_j. So score every node
        # once against each half of `a` and broadcast the sum over all (i, j).
        # This avoids building an [n_nodes, n_nodes, n_heads, 2 * n_hidden]
        # tensor of concatenated pairs.
        a_halves = self.attn.weight.reshape(2, self.n_hidden)
        node_terms = g @ a_halves.t()  # [n_nodes, n_heads, 2]
        e = node_terms[:, None, :, 0] + node_terms[None, :, :, 1]
        e = self.activation(e)  # [n_nodes (i), n_nodes (j), n_heads]

        if adj_mat.dim() == 2:
            # A plain [n_nodes, n_nodes] mask is shared by all heads.
            adj_mat = adj_mat.unsqueeze(-1)
        # The adjacency matrix should have shape
        # `[n_nodes, n_nodes, n_heads]` or `[n_nodes, n_nodes, 1]`.
        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads

        no_edge = adj_mat == 0
        # e_ij is set to -inf where there is no edge from i to j, so it gets
        # zero attention after the softmax.
        e = e.masked_fill(no_edge, float('-inf'))
        # A node without any edge would softmax a row of -inf into NaN.
        # Give that row finite scores, then zero its weights after the softmax.
        isolated = no_edge.all(dim=1, keepdim=True)
        e = e.masked_fill(isolated, 0.0)
        a = self.softmax(e).masked_fill(isolated, 0.0)

        # Apply dropout regularization to the attention weights.
        a = self.dropout(a)

        # h'_i[head] = sum_j alpha_ij[head] * g_j[head]
        attn_res = torch.einsum('ijh,jhf->ihf', a, g)

        if self.is_concat:
            # Concatenate the heads.
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
        # Take the mean of the heads.
        return attn_res.mean(dim=1)
        



In [2]:
# Demo: encode three small random graphs (the "default", "time" and "star"
# views), each with its own single-head GAT layer. Take the embedding of each
# graph's middle node and blend the three with weights 0.5 / 0.25 / 0.25.
# NOTE: the layers are never switched to eval mode, so their attention dropout
# (default p=0.6) is active and the blended values vary between runs; only
# the printed shape is stable.

# One GAT layer per view: 512 input features -> 256 output features.
default_gat, time_gat, star_gat = (
    GraphAttentionLayer(in_features=512, out_features=256, n_heads=1)
    for _ in range(3)
)

# Random node embeddings for graphs of 5, 7 and 3 nodes.
default_h, time_h, star_h = (torch.randn(n, 512) for n in (5, 7, 3))
# Fully connected adjacency masks (self-loops included), shared by all heads.
default_adj, time_adj, star_adj = (torch.ones(n, n, 1) for n in (5, 7, 3))
# Index of the middle node of each graph.
default_index: int = default_adj.size(0) // 2
time_index: int = time_adj.size(0) // 2
star_index: int = star_adj.size(0) // 2

# Embedding of each graph's middle node after one attention layer.
node_default, node_time, node_star = (
    gat(h, adj)[index]
    for gat, h, adj, index in (
        (default_gat, default_h, default_adj, default_index),
        (time_gat, time_h, time_adj, time_index),
        (star_gat, star_h, star_adj, star_index),
    )
)
node_all = 0.5 * node_default + 0.25 * node_time + 0.25 * node_star

# Expected output: torch.Size([256])
print(node_all.shape)

torch.Size([256])


In [None]:
# Training-loop fragment copied from a trainer method: `self`, `tqdm` and
# `num_batch` come from the surrounding code and are not defined in this cell.
#
# For each batch it builds:
#   * `batch_x`: left-side fields moved to GPU 0 and, for each right-side
#     field, the first entry of every per-sample dict;
#   * `nbrs`: one record per distinct id in inputs['id_right'], holding the
#     remaining dict entries followed by the first one.
# `nbrs` is passed to the model as inputs['time_nbr'].
with tqdm(enumerate(self._trainloader), total=num_batch, disable=not self._verbose) as pbar:
    for step, (inputs, target) in pbar:
        rvw_batch_id = inputs['id_right']
        # Ids still waiting for an `nbrs` record. It is a set, so duplicate ids
        # in the batch produce a single record.
        unique_items = set(rvw_batch_id)
        batch_x = {}
        nbrs = {}
        # NOTE(review): only the "text_right" branch creates records in `nbrs`.
        # The other "*_right*" branches fill records that already exist, so they
        # do nothing if "text_right" is not yielded before them. Confirm the
        # key order of `inputs`.
        for key, value in inputs.items():
            if key in ("id_left", "id_right"):
                batch_x[key] = value
            elif key in ("text_left", "text_left_length", "image_left", "image_left_length"):
                batch_x[key] = value.cuda(0)
            elif key == "text_right":
                # Model input: the value of the first entry of each sample's dict.
                batch_x[key] = torch.tensor(
                    [next(iter(d.values())) for d in value]
                ).cuda(0)

                for d in value:
                    first_key, first_val = next(iter(d.items()))
                    if first_key not in unique_items:
                        continue
                    # New record for this id.
                    nbr = {'id': [], 'text': [], 'text_length': [], 'image': [], 'image_length': []}
                    # Remaining entries first, the sample's own (first) entry last.
                    for nbr_id, nbr_text in list(d.items())[1:]:
                        nbr['id'].append(nbr_id)
                        nbr['text'].append(nbr_text)
                    nbr['id'].append(first_key)
                    nbr['text'].append(first_val)
                    nbr['text'] = torch.tensor(nbr['text']).cuda(0)
                    nbrs[first_key] = nbr
                    unique_items.remove(first_key)
            elif key == "image_right_length":
                # Model input: first element of the list under each dict's first key.
                batch_x[key] = torch.tensor(
                    [next(iter(d.values()))[0] for d in value]
                ).cuda(0)
                unique_items = set(nbrs)

                for d in value:
                    for rvw_id, lengths in d.items():
                        if rvw_id not in unique_items:
                            continue
                        # Move the first element to the end. Slicing builds a new
                        # list, so the loader's batch data is not mutated (an
                        # in-place rotation would corrupt lists reused later).
                        nbrs[rvw_id]['image_length'].extend(lengths[1:] + lengths[:1])
                        unique_items.remove(rvw_id)
                        nbrs[rvw_id]['image_length'] = torch.tensor(nbrs[rvw_id]['image_length']).cuda(0)
            elif key == "image_right":
                # Model input: first element of the list under each dict's first key.
                batch_x[key] = torch.tensor(
                    [next(iter(d.values()))[0] for d in value]
                ).cuda(0)
                unique_items = set(nbrs)

                for d in value:
                    for rvw_id, images in d.items():
                        if rvw_id not in unique_items:
                            continue
                        # Move the first element to the end without mutating
                        # the loader's batch data.
                        nbrs[rvw_id]['image'].extend(images[1:] + images[:1])
                        unique_items.remove(rvw_id)
                        nbrs[rvw_id]['image'] = torch.tensor(nbrs[rvw_id]['image']).cuda(0)
            elif key == "text_right_length":
                unique_items = set(nbrs)
                # Model input: the value of the first entry of each sample's dict.
                batch_x[key] = torch.tensor(
                    [next(iter(d.values())) for d in value]
                ).cuda(0)
                for d in value:
                    first_key = next(iter(d))
                    # Skip ids without a record or already filled.
                    if first_key not in unique_items:
                        continue
                    lengths = list(d.values())
                    # Remaining lengths first, the first entry's length last.
                    nbrs[first_key]['text_length'].extend(lengths[1:] + lengths[:1])
                    unique_items.remove(first_key)
                    nbrs[first_key]['text_length'] = torch.tensor(nbrs[first_key]['text_length']).cuda(0)

        inputs = batch_x
        inputs['time_nbr'] = nbrs
        target = target.cuda(0)
        outputs = self._model(inputs)