In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm_notebook

from transformers import AutoConfig, AutoModel, AutoTokenizer

import pickle

from class_new import QuestionNode, ParagraphTitleNode, SentenceNode, EntityNode, Adjacency_sp

# 阶段1

## get_features_from_XLNET

In [2]:
def get_features_from_XLNET(text,text_pair=None,
                            tokenizer = None,
                            model = None,
                            add_special_tokens = True,
                           device = 'cuda'):
    '''XLNET在512张TPU v3上训练5.5天得到. 一张TPU 8核心 128GB内存.'''
    
    assert model
    model_input = tokenizer_XLNET.encode_plus(text,text_pair,
                                        add_special_tokens=add_special_tokens,
                                        return_tensors='pt')
    
    model_input = {k:v.to(device) for k,v in model_input.items()}
    
    # 不能在函数里面设置device.
    # model.to(device)
    with torch.no_grad():
        last_hidden_state = model(**model_input)[0]
    
    return last_hidden_state

In [3]:
HotpotQA_preprocess_file = 'save_cache/hotpotQA_train_preprocess20.pkl'
with open(HotpotQA_preprocess_file,'rb')as fp:
    hotpotQA_train_preprocess = pickle.load(fp)

In [4]:
model_name = 'xlnet-base-cased'
proxies={"http_proxy": "127.0.0.1:10802",
         "https_proxy": "127.0.0.1:10802"}

config = AutoConfig.from_pretrained(model_name,proxies=proxies)

tokenizer_XLNET = AutoTokenizer.from_pretrained(model_name,proxies=proxies)
model_XLNET = AutoModel.from_config(config)
DEVICE = 'cuda'
# DEVICE = 'cpu'
_ = model_XLNET.to(DEVICE)

In [5]:
for ques_item in tqdm_notebook(hotpotQA_train_preprocess, desc = 'building features'):
    node_list = ques_item['node_list']

    # Q node
    Q_node = node_list[0]
    Q_node.content_features = get_features_from_XLNET(Q_node.content_raw,
                                                     tokenizer = tokenizer_XLNET,
                                                      model = model_XLNET,
                                                     device = DEVICE) # [1,N,D]
    Q_node.cls_feature = Q_node.content_features[:,-1,:]

    # S node
    for S_node in [i for i in node_list if i.node_type == 'Sentence']:
        # content_features不能包含特殊字符.
        S_node.content_features = get_features_from_XLNET(S_node.content_raw,
                                                            add_special_tokens=False,
                                                            tokenizer = tokenizer_XLNET,
                                                            model = model_XLNET,
                                                            device = DEVICE)
        
        S_node.cls_feature = get_features_from_XLNET(Q_node.content_raw, 
                                                        S_node.content_raw,
                                                        add_special_tokens=True,
                                                        tokenizer = tokenizer_XLNET,
                                                        model = model_XLNET,
                                                        device = DEVICE)[:,-1,:]    
    
    # P node
    for P_i, P_node in [(i,n) for i,n in enumerate(node_list) if n.node_type == 'Paragraph']:
            S_in_P = [n for n in node_list if n.parent_id == P_i]
            all_S_raw = ' '.join([n.content_raw for n in S_in_P])
            P_node.content_features = [n.content_features for n in S_in_P]
            P_node.cls_feature = get_features_from_XLNET(Q_node.content_raw, 
                                                            all_S_raw,
                                                            add_special_tokens=True,
                                                            tokenizer = tokenizer_XLNET,
                                                            model = model_XLNET,
                                                            device = DEVICE)[:,-1,:]

    # E node
    for E_node in [i for i in node_list if i.node_type == 'Entity']:
        start = E_node.start_in_sentence
        end = E_node.end_in_sentence
        E_node.content_features = node_list[E_node.parent_id].content_features[:,start:end,:]
        E_node.cls_feature = torch.mean(E_node.content_features, dim = 1)
        

HBox(children=(IntProgress(value=0, description='building features', max=20, style=ProgressStyle(description_w…




## GNN

In [6]:
class GraphAttentionLayer(nn.Module):
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        # features=dim, hidden=8
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) # (dim, 8)
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1))) #(2*8,1)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, feat_matrix, adj):
        # features (B, N, dim) , adj (B, N, N)
        h = torch.matmul(feat_matrix, self.W) # (B,N,8)
        N = h.shape[-2] # N
        B = feat_matrix.shape[0]

        a_input = torch.cat([h.repeat(1, 1, N).view(B, N * N, -1), h.repeat(1, N, 1)], dim=-1)\
                                        .view(-1, N, N, 2 * self.out_features) # (B, N, N, 16)

        # 节点聚合!! 后两维(N, 16)* (16, 1)表示对节点i,计算N个节点对(i,j): 进行线性变换后产生一个标量. 对应原文的e_ij
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(-1)) # (B, N, N, 16) * (16, 1) --> (B, N, N)
        # e没有normalizaze?

        zero_vec = -9e9*torch.ones_like(e) # (B, N, N)
        attention = torch.where(adj > 0, e, zero_vec) # 都是[B, N, N]
        attention = F.softmax(attention, dim = -1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.bmm(attention, h)  # (B, N, N)*(B, N ,8)

        if self.concat:
            return F.elu(h_prime) # 一种激活函数
        else:
            return h_prime # [B, N, 8]

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class GAT(nn.Module):
    def __init__(self, features, hidden, nclass, dropout, alpha, nheads):
        """Dense version of GAT."""
        # features=1433, hidden=8, nclass=7, dropout=0.6, alpha=0.3, nheads=8
        super(GAT, self).__init__()
        self.dropout = dropout

        self.attentions = [GraphAttentionLayer(features, hidden, dropout=dropout, alpha=alpha, concat=True) \
                           for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        # hidden * nheads = 8*8, nclass= 7 
        self.out_att = GraphAttentionLayer(hidden * nheads, nclass, dropout=dropout, alpha=alpha, concat=False) # (2708,7)

    def forward(self, feat_matrix, adj):
        # features (B, N, dim) , adj (B, N, N)
        feat_matrix = F.dropout(feat_matrix, self.dropout, training=self.training)
        feat_matrix = torch.cat([att(feat_matrix, adj) for att in self.attentions], dim=-1) # (N,8*heads)
        feat_matrix = F.dropout(feat_matrix, self.dropout, training=self.training)
        logits = F.elu(self.out_att(feat_matrix, adj))

        return logits # [B, N, num_class]
    
g_model = GAT(features=768, hidden=32, nclass=2, dropout=0.6, alpha=0.3, nheads=8).to(DEVICE)

## HotpotQA_Dataset_New

In [7]:
class HotpotQA_Dataset_New(Dataset):
    def __init__(self, train_list, val_list):
        
        # elements
        self.train_list = train_list # [(sen1, sen2, label), ...]
        self.train_size = len(self.train_list)
        
        self.val_list = val_list
        self.val_size = len(self.val_list)
        
        # func
        self._lookup_dict = {'train': (self.train_list, self.train_size),
                    'val': (self.val_list, self.val_size)}

        self.set_split('train')

        # parameters
        self.pad_max_num = 0
        self.pad_value = 0
        self.pad_to_max_length = True
        
    @classmethod
    def build_dataset(cls, hotpotQA_train_preprocess, ratio_train=0.7, seed=123):

        np.random.seed(seed)
        np.random.shuffle(hotpotQA_train_preprocess)        
        
        sample_train_num = int(ratio_train * len(hotpotQA_train_preprocess))

        train_list = hotpotQA_train_preprocess[:sample_train_num]
        val_list = hotpotQA_train_preprocess[sample_train_num:]

        return cls(train_list, val_list)

    @staticmethod
    def build_dataset_from_path(QA_preprocess_path, ratio_train=0.7, seed=123):
        
        with open(QA_preprocess_path,'rb')as fp:
            hotpotQA_train_preprocess = pickle.load(fp)
        
        return HotpotQA_Dataset_New.build_dataset(hotpotQA_train_preprocess,
                            ratio_train,
                            seed,)

    def set_parameters(self, pad_max_num, pad_value=0, pad_to_max_length=True):
        self.pad_max_num = pad_max_num
        self.pad_value = pad_value
        self.pad_to_max_length = pad_to_max_length

    def set_split(self, split="train"):
        assert split in ['train', 'val', 'test']
        self._target_split = split
        self._target_pair, self._target_size = self._lookup_dict[split]

    def __getitem__(self, index):
        QA_item = self.train_list[index]
        node_list = QA_item['node_list']
        sp_adj = QA_item['sp_adj']

        # no padding
        feature_matrix = torch.cat([n.cls_feature for n in node_list], dim=0)
        adj = torch.from_numpy(sp_adj.to_dense_symmetric())
        sent_mask = torch.tensor([1 if n.node_type == 'Sentence' else 0 for n in node_list]).unsqueeze(-1)
        sent_label = torch.tensor([1 if n.node_type == 'Sentence' and n.is_support else 0 for n in node_list]).unsqueeze(-1)

        para_mask = torch.tensor([1 if n.node_type == 'Paragraph' else 0 for n in node_list]).unsqueeze(-1)
        para_label = torch.tensor([1 if n.node_type == 'Paragraph' and n.is_support else 0 for n in node_list]).unsqueeze(-1)

        answer_type = 1 if node_list[0].answer in ['yes', 'no'] else 0
        ans_yes_no = 1 if node_list[0].answer == 'yes' else 0

        answer_type = torch.tensor(answer_type).unsqueeze(-1)
        ans_yes_no = torch.tensor(ans_yes_no).unsqueeze(-1)

        # find ans span in top 4 sentences.
        answer_tokens = node_list[0].answer_tokens
        sent_tokens = [n.content_tokens for n in node_list if n.node_type == 'Sentence']

        ques_features = node_list[0].cls_feature
        sent_features = [n.content_features for n in node_list if n.node_type == 'Sentence']


        if self.pad_to_max_length:
            node_len = feature_matrix.shape[-2]
            pad_max_num = max(self.pad_max_num, node_len)
            pad_value = self.pad_value
            node_dim = feature_matrix.shape[-1]

            feature_matrix_p = torch.zeros([pad_max_num, node_dim]).fill_(pad_value)
            feature_matrix_p[:node_len,:] = feature_matrix
            feature_matrix = feature_matrix_p

            adj_p = torch.zeros([pad_max_num, pad_max_num]).fill_(pad_value)
            adj_p[:node_len,:node_len] = adj
            adj = adj_p

            sent_mask_p = torch.zeros([pad_max_num, 1]).fill_(pad_value)
            sent_mask_p[:node_len,:] = sent_mask
            sent_mask = sent_mask_p
            sent_label_p = torch.zeros([pad_max_num, 1]).fill_(pad_value)
            sent_label_p[:node_len,:] = sent_label
            sent_label = sent_label_p

            para_mask_p = torch.zeros([pad_max_num, 1]).fill_(pad_value)
            para_mask_p[:node_len,:] = para_mask
            para_mask = para_mask_p
            para_label_p = torch.zeros([pad_max_num, 1]).fill_(pad_value)
            para_label_p[:node_len,:] = para_label
            para_label = para_label_p

        item_info_dict = {
            'feature_matrix': feature_matrix,
            'adj': adj,
            'sent_mask': sent_mask,
            'sent_label': sent_label,
            'para_mask': para_mask,
            'para_label': para_label,
            'answer_type': answer_type,
            'ans_yes_no': ans_yes_no,
            'answer_tokens': answer_tokens,
            'sent_tokens': sent_tokens,
            'ques_features': ques_features,
            'sent_features': sent_features,
        }

        return item_info_dict

    
    def __len__(self):
        return self._target_size
    
    def __repr__(self):
        return 'HotpotQA Dataset. mode: {}. size: {}. max_seq: {}'.format\
            (self._target_split,self.pair_type,self.__len__(),self.max_length)

    def __str__(self):
        return self.__repr__()

    def get_num_batches(self, batch_size):
        return len(self) // batch_size
    
def gen_batches(dataset, batch_size, shuffle=True, drop_last=True, device='cpu', seed = 123):

    if seed: np.random.seed(seed)
    dataset_len = dataset.__len__()
    
    index_pool = [i for i in range(dataset_len)]
    if shuffle: np.random.shuffle(index_pool)

    cursor = 0
    while cursor < dataset_len:
        # last batch
        if cursor + batch_size > dataset_len:
            if drop_last: break

        FLAG_FIRST = True
        for index in index_pool[cursor: min(cursor + batch_size, dataset_len)]:
            if FLAG_FIRST:
                feature_matrix = dataset[index]['feature_matrix'].unsqueeze(0)
                adj = dataset[index]['adj'].unsqueeze(0)
                sent_mask = dataset[index]['sent_mask'].unsqueeze(0)
                sent_label = dataset[index]['sent_label'].unsqueeze(0)
                para_mask = dataset[index]['para_mask'].unsqueeze(0)
                para_label = dataset[index]['para_label'].unsqueeze(0)

                answer_type = dataset[index]['answer_type']
                ans_yes_no = dataset[index]['ans_yes_no']
                answer_tokens = dataset[index]['answer_tokens']

                sent_tokens,ques_features,sent_features = [],[],[]

                FLAG_FIRST = False
            else:
                feature_matrix = torch.cat([feature_matrix, dataset[index]['feature_matrix'].unsqueeze(0)], dim=0)
                adj = torch.cat([adj, dataset[index]['adj'].unsqueeze(0)], dim=0)
                sent_mask = torch.cat([sent_mask, dataset[index]['sent_mask'].unsqueeze(0)], dim=0)
                sent_label = torch.cat([sent_label, dataset[index]['sent_label'].unsqueeze(0)], dim=0)
                para_mask = torch.cat([para_mask ,dataset[index]['para_mask'].unsqueeze(0)], dim=0)
                para_label = torch.cat([para_label, dataset[index]['para_label'].unsqueeze(0)], dim=0)

            sent_tokens.append(dataset[index]['sent_tokens'])
            ques_features.append(dataset[index]['ques_features'])
            sent_features.append(dataset[index]['sent_features'])

        cursor += batch_size
        
        
        feature_matrix = feature_matrix.to(device)
        adj = adj.to(device)
        sent_mask = sent_mask.to(device)
        sent_label = sent_label.to(device)
        para_mask = para_mask.to(device)
        para_label = para_label.to(device)
        
        batch_item_info_dict = {
            'feature_matrix': feature_matrix,
            'adj': adj,
            'sent_mask': sent_mask,
            'sent_label': sent_label,
            'para_mask': para_mask,
            'para_label': para_label,
            'answer_type': answer_type,
            'ans_yes_no': ans_yes_no,
            'answer_tokens': answer_tokens,
            'sent_tokens': sent_tokens,
            'ques_features': ques_features,
            'sent_features': sent_features,
        }
        yield batch_item_info_dict
    

In [243]:
dataset = HotpotQA_Dataset_New.build_dataset(hotpotQA_train_preprocess)
dataset.set_parameters(200)
batch_generator = gen_batches(dataset,batch_size=2, device=DEVICE)

In [246]:
for batch_index, batch_dict in enumerate(batch_generator):
    
    logits = g_model(batch_dict['feature_matrix'], batch_dict['adj'])
    
    logits

    sent_mask = batch_dict['sent_mask']
    sent_label = batch_dict['sent_label']
    
    

    break

33
torch.Size([1, 38, 768])


In [134]:
res = (logits * sent_mask)

# 阶段2

## find top K sents

In [135]:
sent_index = torch.nonzero(sent_mask[0].view(-1,).long()).flatten()

In [136]:
sent_index

tensor([ 2,  5,  8, 10, 13, 17, 21, 23, 25, 26, 29, 32, 35, 38, 41, 44, 47, 49,
        53, 56, 59, 61, 64, 67, 71, 75, 78, 81, 84], device='cuda:0')

In [137]:
sent_index.shape

torch.Size([29])

In [138]:
len(sent_tokens[0])

29

In [139]:
len(sent_features[0])

29

In [140]:
sent_logits = torch.index_select(logits[0], 0, sent_index)
sent_logits_soft = F.softmax(sent_logits, dim = -1)
sent_logits_soft.shape

torch.Size([29, 2])

In [141]:
value, index = torch.max(sent_logits_soft, dim=-1)
sent_scores = value.cpu().tolist()
len(sent_scores)

29

In [142]:
sent_score_index = [(score, index) for score, index in zip(sent_scores, range(len(sent_scores)))]

In [143]:
sent_score_index_sort = sorted(sent_score_index,key=lambda x:x[0],reverse=True)
sent_indexes = [x[-1] for x in sent_score_index_sort[:4]]
sent_indexes
# 获取batch中一个元素的前4句.

[1, 19, 21, 18]

## 测试find span

In [144]:
sent_final_tokens = []
for i in sent_indexes: sent_final_tokens.extend(sent_tokens[0][i])

In [148]:
answer_tokens_test =  ['Circuit', ',', 'near']

In [149]:
answer_tokens_test in sent_final_tokens

False

In [151]:
print(answer_tokens_test)
print(sent_final_tokens[55:61])

['Circuit', ',', 'near']
['Mount', 'Pan', '##orama', 'Circuit', ',', 'near']


In [174]:
def find_ans_spans(target, tokens, offets_type = 'position', top_num = None):
    assert offets_type in ['position', 'range'] and \
        type(target) == type(tokens) == list
    len_x1 = len(target)
    len_x2 = len(tokens)
    if len_x1 == 0 or len_x2 == 0 or len_x1 > len_x2:
        return [0,0]
    
    i1=0
    i2=0
    i2_current = 0
    spans = []
    while i2 <= len_x2 - len_x1:
        if top_num and len(spans) == top_num:
            break
        i2_current = i2
        while i1 < len_x1:
            if target[i1] != tokens[i2]: 
                i1 = 0
                i2 = i2_current + 1
                break
            else:
                i1 += 1
                i2 += 1
                
        if not i1 < len_x1:
            i1 = 0
            if offets_type == 'position':  
                spans.append([i2_current, i2_current+len_x1-1])
            else:
                spans.append([i2_current, i2_current+len_x1])
                
    return spans[:top_num+1] if top_num else spans[:]

In [154]:
find_ans_spans(answer_tokens_test, sent_final_tokens)

[[58, 60]]

## 计算span

将`ques_feature`和`content_features`加上`[SEP] [CLS]`拼接在一起.

In [229]:
answer_tokens_test2 = [ 'Australian', 'GT', 'Championship',]

In [230]:
sent_tokens_combine = []
for s_index in sent_indexes: sent_tokens_combine.extend(sent_tokens[0][s_index])

In [231]:
final_tokens = ['<x1>','[ques]','<x2>'] + sent_tokens_combine + ['<cls>']
ans_spans = find_ans_spans(answer_tokens_test2, final_tokens, top_num = 2)
ans_spans

[[88, 90]]

In [232]:
len(final_tokens)

172

## 获取最终features

In [219]:
SEP_rep = torch.randn([768,]).to(DEVICE)
CLS_rep = torch.randn([768,]).to(DEVICE)

In [222]:
final_rep = [SEP_rep, ques_features[0].squeeze(), SEP_rep]
for s_index in sent_indexes: final_rep.extend(sent_features[0][s_index].view(-1, 768))
final_rep = final_rep + [CLS_rep]

In [224]:
torch.stack(final_rep, dim = 0).shape

torch.Size([183, 768])

In [235]:
sent_indexes

[1, 19, 21, 18]