In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import pickle

from Classes import Node, Adjacency_sp

# 预处理

In [2]:
# Directory (relative to the working directory) that holds the preprocessed cache files.
save_cache_path = 'save_cache/'

# Load the cached preprocessing result. The next cell indexes it with [0] and
# unpacks that entry as (feat_matrix, Adj_sp, labels).
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# load cache files that were produced by a trusted source.
with open(save_cache_path+'hotpotQA_train_preprocess100_feat_adj.pkl', 'rb') as fp:
    matrix_adj_label = pickle.load(fp)

In [3]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on machines without a GPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# First cached sample: node features, adjacency wrapper and node labels.
feat_matrix, Adj_sp, labels = matrix_adj_label[0]

feat_matrix = feat_matrix.to(DEVICE)

# to_dense_symmetric() returns a NumPy array (it is passed to torch.from_numpy),
# which is then converted to a tensor on the target device.
adj = Adj_sp.to_dense_symmetric()
adj = torch.from_numpy(adj).to(DEVICE)

labels = labels.to(DEVICE)

In [4]:
# Shape of the first sample's node-feature matrix (recorded output: [118, 768]).
feat_matrix.shape

torch.Size([118, 768])

In [5]:
# Shape of the first sample's dense adjacency matrix (recorded output: [118, 118]).
adj.shape

torch.Size([118, 118])

In [6]:
# Shape of the first sample's label tensor (recorded output: [118, 1]).
labels.shape

torch.Size([118, 1])

In [14]:
def pad_all(feat_matrix, adj, labels, max_num = 300, pad_value = 0):
    """Pad one graph's node features, adjacency matrix and labels to a fixed node count.

    Args:
        feat_matrix: node features of shape [N, dim].
        adj: square adjacency matrix of shape [N, N].
        labels: node labels of shape [N, 1].
        max_num: target number of nodes. If the graph already has more than
            ``max_num`` nodes, the graph's own size is used (nothing is truncated).
        pad_value: value written into every padded position of all three outputs.

    Returns:
        Tuple ``(feat_matrix_p, adj_p, labels_p)`` with shapes
        ``[M, dim]``, ``[M, M]`` and ``[M, 1]`` where ``M = max(max_num, N)``.
        All outputs use the default floating dtype and are allocated on the
        device of the corresponding input, so GPU inputs are padded without a
        round trip through the CPU.

    Raises:
        AssertionError: if the three inputs disagree on N or ``adj`` is not square.
    """
    node_len = feat_matrix.shape[0]
    # adj must be square as well; otherwise the slice copy below would fail
    # with a much less readable shape-mismatch error.
    assert node_len == adj.shape[0] == adj.shape[1] == labels.shape[0], \
        'feat_matrix, adj and labels must describe the same number of nodes'

    max_num = max(max_num, node_len)
    node_dim = feat_matrix.shape[1]
    dtype = torch.get_default_dtype()

    # getattr(..., None) keeps working for inputs that carry no device
    # attribute; torch.full then falls back to the default device.
    feat_matrix_p = torch.full((max_num, node_dim), pad_value, dtype=dtype,
                               device=getattr(feat_matrix, 'device', None))
    feat_matrix_p[:node_len, :] = feat_matrix

    adj_p = torch.full((max_num, max_num), pad_value, dtype=dtype,
                       device=getattr(adj, 'device', None))
    adj_p[:node_len, :node_len] = adj

    labels_p = torch.full((max_num, 1), pad_value, dtype=dtype,
                          device=getattr(labels, 'device', None))
    labels_p[:node_len, :] = labels

    return feat_matrix_p, adj_p, labels_p

In [15]:
# Pad the first sample with pad_all's defaults (max_num=300, pad_value=0),
# then make sure every padded tensor lives on DEVICE.
feat_matrix_p, adj_p, labels_p = pad_all(feat_matrix, adj, labels)
feat_matrix_p = feat_matrix_p.to(DEVICE)
adj_p = adj_p.to(DEVICE)
labels_p = labels_p.to(DEVICE)

# PyGAT

In [71]:
class GraphAttentionLayer(nn.Module):
    """Single dense graph-attention head.

    For every node pair (i, j) the layer computes
    ``e_ij = LeakyReLU(a^T [W h_i || W h_j])``, masks out pairs without an edge,
    normalises each row with a softmax and aggregates the projected features.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        dropout: dropout probability applied to the attention weights.
        alpha: negative slope of the LeakyReLU applied to the raw scores.
        concat: if True the output is passed through ELU (hidden heads that get
            concatenated); if False the raw aggregation is returned.

    Note:
        A node whose adjacency row has no positive entry (e.g. a padded node)
        receives the same masked score for every column, so the softmax gives
        it uniform attention over all N nodes rather than zero attention.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super().__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        # Node projection matrix: (in_features, out_features).
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        # Attention vector: first half scores the source node, second half the target node.
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        """Apply the attention head.

        Args:
            input: node features, (B, N, in_features) or unbatched (N, in_features).
            adj: adjacency, (B, N, N) or (N, N); an edge exists where ``adj > 0``.

        Returns:
            Tensor of shape (B, N, out_features), or (N, out_features) when
            ``input`` was unbatched.
        """
        # Accept a single graph by temporarily adding a batch axis.
        unbatched = input.dim() == 2
        if unbatched:
            input = input.unsqueeze(0)

        h = torch.matmul(input, self.W)  # (B, N, out_features)

        # a^T [h_i || h_j] splits into a_src^T h_i + a_dst^T h_j, so the
        # (B, N, N, 2 * out_features) tensor of concatenated pairs never has to
        # be materialised.
        src_score = torch.matmul(h, self.a[:self.out_features])  # (B, N, 1)
        dst_score = torch.matmul(h, self.a[self.out_features:])  # (B, N, 1)
        e = self.leakyrelu(src_score + dst_score.transpose(-1, -2))  # (B, N, N)
        # e itself is unnormalised; the softmax below normalises each row.

        # Non-edges get a very large negative score so they vanish in the softmax.
        zero_vec = -9e15*torch.ones_like(e)  # (B, N, N)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=-1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.bmm(attention, h)  # (B, N, N) x (B, N, out_features)

        if unbatched:
            h_prime = h_prime.squeeze(0)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class GAT(nn.Module):
    """Dense multi-head graph attention network for per-node classification.

    Args:
        nfeat: size of each input node feature vector.
        nhid: output size of each hidden attention head.
        nclass: number of output classes per node.
        dropout: dropout probability used on the inputs, on the concatenated
            head outputs and inside every attention layer.
        alpha: LeakyReLU negative slope passed to every attention layer.
        nheads: number of hidden attention heads.
    """

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super().__init__()
        self.dropout = dropout

        # The heads are kept in a plain list and registered one by one so their
        # parameters are exposed under the names 'attention_<i>.*'.
        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        # Output head: maps the concatenated head outputs (nhid * nheads) to nclass scores.
        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, feat_matrix, adj):
        """Compute per-node class log-probabilities.

        Args:
            feat_matrix: node features of shape (B, N, nfeat).
            adj: adjacency of shape (B, N, N).

        Returns:
            Log-probabilities of shape (B, N, nclass), normalised over the
            class axis.
        """
        feat_matrix = F.dropout(feat_matrix, self.dropout, training=self.training)
        # Each head yields (B, N, nhid); concatenating gives (B, N, nhid * nheads).
        feat_matrix = torch.cat([att(feat_matrix, adj) for att in self.attentions], dim=-1)
        feat_matrix = F.dropout(feat_matrix, self.dropout, training=self.training)
        feat_matrix = F.elu(self.out_att(feat_matrix, adj))

        # Normalise over the class axis. With batched (B, N, nclass) scores,
        # dim=1 would be the node axis, so the last axis is used instead.
        return F.log_softmax(feat_matrix, dim=-1)
    
# Model instance for the experiments below: 768-dim node features, 8 hidden
# heads of size 8, and 2 output classes per node; moved to DEVICE.
g_model = GAT(nfeat=768, nhid=8, nclass=2, dropout=0.6, alpha=0.3, nheads=8).to(DEVICE)

In [77]:
# Random smoke-test batch: 3 graphs x 300 nodes x 768 features.
# No seed is set, so the values (and the loss printed further down) differ
# between runs.
test_feat_matrix_p = torch.randn([3,300,768], device = DEVICE)
# Integer entries in [0, 8); the attention layer only checks adj > 0, so every
# non-zero entry counts as an edge. The random matrix is not symmetric.
test_adj = torch.randint(0,8,[3, 300,300], device = DEVICE)
# Binary labels and binary masks, one value per node.
test_labels = torch.randint(0,2,[3, 300,1], device = DEVICE)
test_sent_mask = torch.randint(0,2,[3, 300,1], device = DEVICE)
# test_para_mask is not used by any of the cells shown here.
test_para_mask = torch.randint(0,2,[3, 300,1], device = DEVICE)

In [78]:
# Forward pass on the random batch. g_model is left in its default training
# mode here, so dropout is active and the result is stochastic.
test_logits = g_model(test_feat_matrix_p, test_adj)

In [79]:
# One score per class for every node (recorded output: [3, 300, 2]).
test_logits.shape

torch.Size([3, 300, 2])

In [80]:
# 选择句子的loss
r1 = (test_logits * test_sent_mask).view(-1,2)
l = (test_labels * test_sent_mask).view(-1)
loss = nn.CrossEntropyLoss()(r1,l)

torch.Size([])

In [82]:
# Flattened label vector used for the loss (recorded output: [900] = 3 * 300).
l.shape

torch.Size([900])

In [81]:
# Display the scalar sentence-selection loss computed above.
loss

tensor(1.6126, device='cuda:0', grad_fn=<NllLossBackward>)

## graph-level 输出

1. 判断答案类型: `yes/no/span`.