# 初始化

In [1]:
json_train_path = r'./data/hotpot_train_v1.1.json'
HotpotQA_path = './'
save_cache_path = 'save_cache/'
use_proxy = 1
# `proxies` is forwarded to transformers' `from_pretrained`, which hands it to
# `requests`. requests looks proxies up by URL scheme, so the keys must be
# 'http' / 'https' -- keys such as 'http_proxy' are silently ignored.
proxies = {"http": "http://127.0.0.1:10809",
           "https": "http://127.0.0.1:10809"} if use_proxy else None

In [2]:
import ujson as json
import torch
import torch.nn as nn
import numpy as np
import time
import pickle

# Device string used for every model/tensor below: GPU if torch sees one, else CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [3]:
# Load the HotpotQA training split: a JSON list of per-question dicts.
with open(json_train_path, 'r', encoding='utf-8') as train_file:
    json_train = json.load(train_file)

# Show the fields available on one record.
json_train[0].keys()

dict_keys(['supporting_facts', 'level', 'question', 'context', 'answer', '_id', 'type'])

# classes

In [4]:
class BaseNode(object):
    '''Common base of the graph nodes (question, paragraph, sentence, entity).

    Stores the fields shared by every node. Subclasses must provide `build`,
    the serialization hooks and the string conversions.
    '''

    def __init__(self, node_id, node_type, parent_id,
                 content_raw, content_tokens=None, cls_feature=None):
        # Identity of the node and of its parent in the graph.
        self.node_id, self.node_type, self.parent_id = node_id, node_type, parent_id
        # Node text, raw and tokenized.
        self.content_raw, self.content_tokens = content_raw, content_tokens
        # Final node feature, expected shape [1, dim].
        self.cls_feature = cls_feature

    @classmethod
    def build(cls):
        '''Factory for a concrete node; implemented by subclasses.'''
        raise NotImplementedError()

    def to_serializable(self):
        '''Return a plain-data representation; implemented by subclasses.'''
        raise NotImplementedError()

    @classmethod
    def from_serializable(cls, contents):
        '''Inverse of `to_serializable`; implemented by subclasses.'''
        raise NotImplementedError()

    def __str__(self):
        raise NotImplementedError()

    def __repr__(self):
        raise NotImplementedError()


class QuestionNode(BaseNode):
    '''Graph node holding the question text, its answer and its question type.'''
    def __init__(self,node_id, node_type, parent_id, content_raw, content_tokens,
                 answer, answer_tokens, ques_type, cls_feature=None):
        super(QuestionNode, self).__init__(node_id, node_type, parent_id,
                                           content_raw,content_tokens,cls_feature)

        self.question = self.content_raw  # alias of content_raw
        self.answer = answer
        self.answer_tokens = answer_tokens
        self.ques_type = ques_type

    @classmethod
    def build(cls, node_id, parent_id, content_raw,
                answer, ques_type):
        '''Create a 'Question' node, tokenizing question and answer with the global `tokenizer`.'''
        node_type = 'Question'
        answer_tokens = tokenizer.tokenize(answer)
        content_tokens = tokenizer.tokenize(content_raw)
        return cls(node_id, node_type, parent_id, content_raw, content_tokens,
                   answer, answer_tokens, ques_type)

    def to_serializable(self):
        '''Return a JSON-friendly dict of the constructor arguments.

        The `cls_feature` tensor is stored as nested lists; a missing (None)
        feature is stored as None instead of raising.
        '''
        return {
            'node_id': self.node_id,
            'node_type': self.node_type,
            'parent_id': self.parent_id,
            'content_raw': self.content_raw,
            'content_tokens': self.content_tokens,
            'answer': self.answer,
            'answer_tokens': self.answer_tokens,
            'ques_type': self.ques_type,
            'cls_feature': None if self.cls_feature is None else self.cls_feature.tolist(),
               }

    @classmethod
    def from_serializable(cls, state_dicts):
        '''Rebuild a node from `to_serializable` output.

        `state_dicts` is copied, not modified; a list feature becomes a tensor.
        '''
        state = dict(state_dicts)
        if state.get('cls_feature') is not None:
            state['cls_feature'] = torch.tensor(state['cls_feature'])
        return cls(**state)

    def __str__(self):
        return f'QuestionNode: {self.node_id}'

    def __repr__(self):
        return self.__str__()


class ParagraphTitleNode(BaseNode):
    '''Paragraph node: holds only the paragraph title, not the paragraph text.'''
    def __init__(self,node_id, node_type, parent_id, content_raw, content_tokens,
                 content_NER_list = None, cls_feature=None, is_support=False):
        super(ParagraphTitleNode, self).__init__(node_id, node_type, parent_id,
                                           content_raw,content_tokens,cls_feature)

        # Entity dicts found in the title (see `build`).
        self.content_NER_list = content_NER_list

        # True when this paragraph is a supporting paragraph (see `set_support`).
        self.is_support = is_support

    @classmethod
    def build(cls, node_id, parent_id, content_raw):
        '''Create a 'Paragraph' node from a title.

        NER is done with spaCy (`find_NER_in_spacy`). Tokens come from the
        global `tokenizer` and contain no special tokens.
        '''
        _, content_NER_list = find_NER_in_spacy(content_raw, ner=True)
        content_tokens = tokenizer.tokenize(content_raw)

        node_type = 'Paragraph'
        return cls(node_id, node_type, parent_id, content_raw, content_tokens,
                   content_NER_list)

    def set_support(self):
        '''Mark this paragraph as supporting.'''
        self.is_support = True

    def to_serializable(self):
        '''Return a JSON-friendly dict of the constructor arguments.

        The `cls_feature` tensor is stored as nested lists; a missing (None)
        feature is stored as None instead of raising.
        '''
        return {
            'node_id': self.node_id,
            'node_type': self.node_type,
            'parent_id': self.parent_id,
            'content_raw': self.content_raw,
            'content_tokens': self.content_tokens,
            'content_NER_list': self.content_NER_list,
            'cls_feature': None if self.cls_feature is None else self.cls_feature.tolist(),
            'is_support': self.is_support,
               }

    @classmethod
    def from_serializable(cls, state_dicts):
        '''Rebuild a node from `to_serializable` output.

        `state_dicts` is copied, not modified; a list feature becomes a tensor.
        '''
        state = dict(state_dicts)
        if state.get('cls_feature') is not None:
            state['cls_feature'] = torch.tensor(state['cls_feature'])
        return cls(**state)

    def __str__(self):
        return f'ParagraphTitleNode: {self.node_id}'

    def __repr__(self):
        return self.__str__()


class SentenceNode(BaseNode):
    '''Graph node for one sentence of a paragraph.'''
    def __init__(self,node_id, node_type, parent_id, content_raw, content_tokens,
                    content_NER_list = None, cls_feature=None, is_support = False):
        super(SentenceNode, self).__init__(node_id, node_type, parent_id,
                                           content_raw,content_tokens,cls_feature)

        # Entity dicts found in the sentence (see `build`).
        self.content_NER_list = content_NER_list

        # True when this sentence is a supporting fact (see `set_support`).
        self.is_support = is_support

    @classmethod
    def build(cls, node_id, parent_id, content_raw):
        '''Create a 'Sentence' node.

        NER is done with spaCy (`find_NER_in_spacy`). Tokens come from the
        global `tokenizer` and contain no special tokens.
        '''
        _, content_NER_list = find_NER_in_spacy(content_raw, ner=True)
        content_tokens = tokenizer.tokenize(content_raw)

        node_type = 'Sentence'
        return cls(node_id, node_type, parent_id, content_raw, content_tokens,
                   content_NER_list)

    def set_support(self):
        '''Mark this sentence as a supporting fact.'''
        self.is_support = True

    def get_NER_tuples_list(self):
        '''Return (entity text, this node_id) pairs, e.g. [('APPLE', id), ('DELL', id)].

        Returns an empty list when no NER results are stored.
        '''
        return [(ent['content'], self.node_id) for ent in (self.content_NER_list or [])]

    def to_serializable(self):
        '''Return a JSON-friendly dict of the constructor arguments.

        The `cls_feature` tensor is stored as nested lists; a missing (None)
        feature is stored as None instead of raising.
        '''
        return {
            'node_id': self.node_id,
            'node_type': self.node_type,
            'parent_id': self.parent_id,
            'content_raw': self.content_raw,
            'content_tokens': self.content_tokens,
            'content_NER_list': self.content_NER_list,
            'cls_feature': None if self.cls_feature is None else self.cls_feature.tolist(),
            'is_support': self.is_support,
               }

    @classmethod
    def from_serializable(cls, state_dicts):
        '''Rebuild a node from `to_serializable` output.

        `state_dicts` is copied, not modified; a list feature becomes a tensor.
        '''
        state = dict(state_dicts)
        if state.get('cls_feature') is not None:
            state['cls_feature'] = torch.tensor(state['cls_feature'])
        return cls(**state)

    def __str__(self):
        return f'SentenceNode: {self.node_id}'

    def __repr__(self):
        return self.__str__()


class EntityNode(BaseNode):
    '''Graph node for one named entity; its parent is the sentence it was found in.'''
    def __init__(self,node_id, node_type, parent_id, content_raw,
                 content_tokens, cls_feature=None):
        super(EntityNode, self).__init__(node_id, node_type, parent_id,
                                           content_raw,content_tokens,cls_feature)

    @classmethod
    def build(cls, node_id, parent_id, content_raw):
        '''Create an 'Entity' node, tokenizing the entity text with the global `tokenizer`.'''
        node_type = 'Entity'
        content_tokens = tokenizer.tokenize(content_raw)
        return cls(node_id, node_type, parent_id, content_raw,content_tokens)

    def to_serializable(self):
        '''Return a JSON-friendly dict of the constructor arguments.

        The `cls_feature` tensor is stored as nested lists; a missing (None)
        feature is stored as None instead of raising.
        '''
        return {
            'node_id': self.node_id,
            'node_type': self.node_type,
            'parent_id': self.parent_id,
            'content_raw': self.content_raw,
            'content_tokens': self.content_tokens,
            'cls_feature': None if self.cls_feature is None else self.cls_feature.tolist(),
               }

    @classmethod
    def from_serializable(cls, state_dicts):
        '''Rebuild a node from `to_serializable` output.

        `state_dicts` is copied, not modified; a list feature becomes a tensor.
        '''
        state = dict(state_dicts)
        if state.get('cls_feature') is not None:
            state['cls_feature'] = torch.tensor(state['cls_feature'])
        return cls(**state)

    def __str__(self):
        return f'EntityNode: {self.node_id}'

    def __repr__(self):
        return self.__str__()

import scipy.sparse as sp
import numpy as np

class Adjacency_sp():
    '''Sparse adjacency stored as a list of ``[v, i, j]`` edges with no duplicate (i, j).

    ``v`` is the value written at row ``i``, column ``j`` (here: the edge-type
    id). The first edge added for a given (i, j) wins; later ones are ignored.
    '''
    def __init__(self, v_i_j=None):
        # None instead of a `[]` default: a default list would be shared by
        # every instance created without arguments.
        self.v_i_j = [] if v_i_j is None else v_i_j
        # (i, j) pairs already stored. A set gives O(1) duplicate checks, and
        # rebuilding it from v_i_j means instances restored through
        # from_serializable keep rejecting duplicates.
        self.i_j_find_table = {(i, j) for _, i, j in self.v_i_j}

    def append(self, v, i, j):
        '''Add edge (i, j) with value v unless an (i, j) edge already exists.'''
        if (i, j) not in self.i_j_find_table:
            self.v_i_j.append([v, i, j])
            self.i_j_find_table.add((i, j))

    def _node_len(self):
        '''Number of nodes implied by the edges: largest index + 1 (0 when there are no edges).'''
        if not self.v_i_j:
            return 0
        return max(max(i, j) for _, i, j in self.v_i_j) + 1

    def _to_coo(self, node_len):
        '''Edges as a float32 (node_len, node_len) scipy COO matrix.'''
        np_adj = np.array(self.v_i_j)
        return sp.coo_matrix((np_adj[:, 0], (np_adj[:, 1], np_adj[:, 2])),
                             shape=(node_len, node_len), dtype=np.float32)

    def to_dense(self):
        '''Return the dense numpy adjacency with the identity added (self-loops).

        An edge-less graph gives an empty (0, 0) array.
        '''
        node_len = self._node_len()
        if node_len == 0:
            return np.zeros((0, 0))
        return self._to_coo(node_len).toarray() + np.eye(node_len, node_len)

    def to_dense_symmetric(self):
        '''Return the symmetrized dense adjacency (element-wise max of A and A.T) plus identity.

        An edge-less graph gives an empty (0, 0) array.
        '''
        node_len = self._node_len()
        if node_len == 0:
            return np.zeros((0, 0))
        adj = self._to_coo(node_len)
        # max(A, A.T) is exactly A + A.T*(A.T > A) - A*(A.T > A).
        adj_symm = adj.maximum(adj.T).toarray()
        return adj_symm + np.eye(node_len, node_len)

    def to_serializable(self):
        return {
            'v_i_j': self.v_i_j,
               }

    @classmethod
    def from_serializable(cls, state_dicts):
        return cls(**state_dicts)

    def __repr__(self):
        return f'Adjacency_sp has {len(self.v_i_j)} edges {self._node_len()} nodes.'

    def __str__(self):
        return self.__repr__()

    def __len__(self):
        return len(self.v_i_j)

In [5]:
def auto_reload_Node(state_dicts):
    '''Rebuild a node object from its `to_serializable` dict, dispatching on 'node_type'.

    Raises:
        ValueError: if 'node_type' is not one of Question/Paragraph/Sentence/Entity.
    '''
    node_type = state_dicts['node_type']
    if node_type == 'Question':
        return QuestionNode.from_serializable(state_dicts)
    if node_type == 'Paragraph':
        return ParagraphTitleNode.from_serializable(state_dicts)
    if node_type == 'Sentence':
        return SentenceNode.from_serializable(state_dicts)
    if node_type == 'Entity':
        return EntityNode.from_serializable(state_dicts)
    # A bare `raise` here had no active exception and produced a confusing RuntimeError.
    raise ValueError(f'Unknown node_type: {node_type!r}')

# ner

In [None]:
from transformers import AutoModelForTokenClassification, AutoTokenizer

# BERT-large token-classification (NER) model; per its model id it is fine-tuned on CoNLL-03.
# `proxies` is used only while downloading.
model_NER = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english",
                                                           proxies = proxies)
# NOTE: this global `tokenizer` is rebound to an XLNet tokenizer in a later cell;
# the Node `build` methods use whichever binding is current when they run.
tokenizer = AutoTokenizer.from_pretrained("bert-large-cased-whole-word-masking",
                                         proxies = proxies)

# Move the NER model to DEVICE (return value discarded).
_ = model_NER.to(DEVICE)

In [None]:
def tokensizer_in_Model(content_raw, special=False, tokenizer = tokenizer):
    '''Tokenize `content_raw` into the word pieces the BERT NER model sees.

    The text is encoded and decoded once, so the special tokens added by
    `encode` appear in the re-tokenized output. With `special=False` the first
    and last piece (the special tokens) are dropped.
    '''
    round_trip_text = tokenizer.decode(tokenizer.encode(content_raw))
    pieces = tokenizer.tokenize(round_trip_text)
    return pieces if special else pieces[1:-1]

def find_NER_in_Model(content_raw, tokens=None, model=model_NER, tokenizer = tokenizer):
    '''Tag named entities in `content_raw` with a BERT token-classification model.

    Args:
        content_raw: raw text.
        tokens: optional word pieces of `content_raw` INCLUDING the leading and
            trailing special tokens ([CLS] ... [SEP]). They are computed with
            `tokenizer` when not given.
        model: token-classification model over the 9 labels in `label_list`
            (default: the global `model_NER`). This argument is now actually used.
        tokenizer: tokenizer matching `model`.

    Returns:
        (tokens[1:-1], entities_list):
        - tokens[1:-1]: the word pieces without the special tokens;
        - entities_list: one dict per entity with these keys:
          - 'type': label of its first piece, e.g. 'I-PER';
          - 'span_start' / 'span_end': piece offsets into tokens[1:-1], end exclusive;
          - 'content': the pieces joined, with '##' continuations merged;
          - 'content_tokens': the pieces themselves.
    '''
    label_list = [
        "O",       # Outside of a named entity
        "B-MISC",  # Beginning of a miscellaneous entity right after another miscellaneous entity
        "I-MISC",  # Miscellaneous entity
        "B-PER",   # Beginning of a person's name right after another person's name
        "I-PER",   # Person's name
        "B-ORG",   # Beginning of an organisation right after another organisation
        "I-ORG",   # Organisation
        "B-LOC",   # Beginning of a location right after another location
        "I-LOC"    # Location
        ]
    # Round-trip through encode/decode so the pieces include the special tokens
    # and line up one-to-one with the model predictions.
    if not tokens:
        tokens = tokenizer.tokenize(tokenizer.decode(tokenizer.encode(content_raw)))
    inputs = tokenizer.encode(content_raw, return_tensors='pt').to(DEVICE)
    with torch.no_grad():  # inference only: no autograd graph
        outputs = model(inputs)[0]
    predictions = torch.argmax(outputs, dim=2)
    # (piece, label) pairs with [CLS] / [SEP] removed.
    res = [(token, label_list[prediction])
           for token, prediction in zip(tokens, predictions[0].tolist())][1:-1]

    # Group consecutive non-'O' pieces into entities. A 'B-' tag, or a change
    # of entity class (PER/ORG/...), starts a new entity.
    spans = []  # [label of first piece, start, end)
    for pos, (_, label) in enumerate(res):
        if label == 'O':
            continue
        extends_previous = (spans
                            and spans[-1][2] == pos
                            and not label.startswith('B-')
                            and spans[-1][0][2:] == label[2:])
        if extends_previous:
            spans[-1][2] = pos + 1
        else:
            spans.append([label, pos, pos + 1])

    entities_list = []
    for label, start, end in spans:
        temp = [token for token, _ in res[start:end]]
        entities_list.append({
            'type': label,
            'span_start': start,
            'span_end': end,
            'content': ' '.join(temp).replace(' ##', ''),
            'content_tokens': temp,
        })

    return tokens[1:-1], entities_list

In [6]:
from transformers import XLNetTokenizer

# Local directory of the XLNet-large checkpoint.
model_path = '/g/data/models/xlnet-large-cased'
# NOTE: this rebinds the global `tokenizer` (previously the BERT tokenizer). From here on,
# the Node `build` methods, and the default arguments of functions defined below, use XLNet pieces.
tokenizer = XLNetTokenizer.from_pretrained(model_path)

In [7]:
def tokensizer_in_Model(content_raw, special=False, tokenizer = tokenizer):
    '''Tokenize `content_raw` with the (XLNet) `tokenizer`.

    `special` is accepted only for signature compatibility with the BERT
    variant and is ignored; `tokenize` is called directly, with no
    encode/decode round trip.
    '''
    return tokenizer.tokenize(content_raw)

In [8]:
import spacy
# Let spaCy run on the GPU when one is available.
spacy.prefer_gpu()
# Large English pipeline; the default `nlp` of `find_NER_in_spacy`.
nlp = spacy.load("en_core_web_lg")
# When fine-tuning BERT, the text should also be tokenized by BERT so the embeddings stay consistent.
# That requires the NER module to accept a pre-tokenized list instead of tokenizing by itself;
# otherwise the entity spans will be inaccurate.

In [9]:
def find_NER_in_spacy(raw_content, nlp = nlp, tokensize=False, ner=False, \
                      exclude_list = ['PERCENT', 'MONEY', 'QUANTITY', 'ORDINAL', 'CARDINAL']):
    '''Run spaCy on `raw_content` and optionally return tokens and entities.

    Label reference: https://spacy.io/api/annotation

    Returns:
        (tokens, entities_list). `tokens` is the list of token strings when
        `tokensize` is true, else None. `entities_list` is None unless `ner` is
        true; then it holds one dict per entity whose label is not in
        `exclude_list`, with keys 'type', 'span_start', 'content', 'span_end'.
        The spans are spaCy token offsets, end exclusive.
    '''
    parsed = nlp(raw_content)
    tokens = [str(tok) for tok in parsed.doc] if tokensize else None

    if not ner:
        return tokens, None
    entities_list = [
        {
            'type': ent.label_,
            'span_start': ent.start,
            'content': ent.text,
            'span_end': ent.end,
        }
        for ent in parsed.ents
        if ent.label_ not in exclude_list
    ]
    return tokens, entities_list

# 构建图

In [10]:
# Graph construction: yields the Q-paragraph text (for BERT), the adjacency and the node list.
# Integer id per edge type; `process` stores it as the value of each Adjacency_sp edge.
edge_type_map = dict(
    Q_P=1,        # question -> paragraph title
    Q_E=2,        # question -> entity whose text occurs in the question
    P_S=3,        # paragraph -> one of its sentences
    S_P_hyper=4,  # sentence of an entity -> paragraph of any sentence containing it
    S_E=5,        # sentence -> entity extracted from it
    P_P=6,        # previous paragraph -> next paragraph
    S_S=7,        # previous sentence -> next sentence
)

def create_ques_info_dict():
    '''Return an empty per-question record: keys 'id', 'node_list', 'sp_adj', all None.'''
    return {field: None for field in ('id', 'node_list', 'sp_adj')}

In [11]:
from collections import defaultdict
from traceback import print_exc

def process(item):
    '''Build the graph of one HotpotQA question: node list plus sparse typed adjacency.

    Node ids are sequential and equal to the position in ``node_list``:
    - the question node Q has id 0;
    - then, for every paragraph, its title node P;
    - each P is followed by its sentence nodes S;
    - each S is immediately followed by the entity nodes E built from that
      sentence's ``content_NER_list``.

    Edges (values from ``edge_type_map``):
        Q_P        question -> every paragraph title
        P_P        previous paragraph -> next paragraph
        P_S        paragraph -> each of its sentences
        S_S        previous sentence -> next sentence (same paragraph)
        S_E        sentence -> each of its entities
        Q_E        question -> entity whose text occurs in the question (spaces ignored)
        S_P_hyper  sentence holding an entity -> paragraph of every sentence whose
                   text contains that entity (spaces ignored)

    Supporting paragraphs/sentences (``item['supporting_facts']``) get ``set_support()``.

    Args:
        item: one HotpotQA record (uses 'supporting_facts', 'question',
            'context', 'answer', '_id', 'type').

    Returns:
        The ``create_ques_info_dict`` record with 'id', 'node_list', 'sp_adj' set.

    Raises:
        AssertionError: if the adjacency size disagrees with the node count
            (nodes and edges are printed first).
    '''
    ques_info_dict = create_ques_info_dict()

    # title -> list of supporting sentence indices within that paragraph
    supporting_facts = defaultdict(list)
    for s_fact in item['supporting_facts']:
        supporting_facts[s_fact[0]].append(s_fact[1])

    question = item['question']
    Q_id = item['_id']
    ques_info_dict['id'] = Q_id

    node_list = []
    sp_adj = Adjacency_sp([])
    index_cursor = 0

    Q_node_cursor = index_cursor
    Q_node = QuestionNode.build(Q_node_cursor, -1, question,
                                answer = item['answer'], ques_type = item['type'])
    node_list.append(Q_node)

    P_node_cursor = None  # id of the previous paragraph node (None before the first)
    for paragraph in item['context']:
        title, sentences = paragraph[0], paragraph[1]
        index_cursor += 1

        # previous paragraph -> this paragraph
        if P_node_cursor is not None:
            sp_adj.append(edge_type_map['P_P'], P_node_cursor, index_cursor)
        P_node_cursor = index_cursor
        P_node = ParagraphTitleNode.build(P_node_cursor, Q_node_cursor, title)

        is_support_paragraph = title in supporting_facts
        if is_support_paragraph:
            P_node.set_support()
        node_list.append(P_node)
        sp_adj.append(edge_type_map['Q_P'], Q_node_cursor, P_node_cursor)

        prev_S_cursor = None  # id of the previous sentence in this paragraph
        for s_index, sentence in enumerate(sentences):
            index_cursor += 1
            S_node_cursor = index_cursor
            S_node = SentenceNode.build(S_node_cursor, P_node_cursor, sentence)
            if is_support_paragraph and s_index in supporting_facts[title]:
                S_node.set_support()
            node_list.append(S_node)

            # previous sentence -> this sentence; paragraph -> sentence
            if prev_S_cursor is not None:
                sp_adj.append(edge_type_map['S_S'], prev_S_cursor, S_node_cursor)
            sp_adj.append(edge_type_map['P_S'], P_node_cursor, S_node_cursor)
            prev_S_cursor = S_node_cursor

            # sentence -> each of its entities
            for entities_dict in S_node.content_NER_list:
                index_cursor += 1
                E_node = EntityNode.build(index_cursor, S_node_cursor,
                                          entities_dict['content'])
                node_list.append(E_node)
                sp_adj.append(edge_type_map['S_E'], S_node_cursor, index_cursor)

    E_nodes = [n for n in node_list if n.node_type == 'Entity']
    S_nodes = [n for n in node_list if n.node_type == 'Sentence']

    # question -> entities mentioned in the question
    question_compact = question.replace(' ', '')
    for E_n in E_nodes:
        if E_n.content_raw.replace(' ', '') in question_compact:
            sp_adj.append(edge_type_map['Q_E'], Q_node_cursor, E_n.node_id)

    # Hyper edges: entity's sentence -> paragraph of every sentence containing it.
    # parent_id values are node ids, and node ids equal list positions.
    S_compact = [(S_n, S_n.content_raw.replace(' ', '')) for S_n in S_nodes]
    for E_n in E_nodes:
        entity = E_n.content_raw.replace(' ', '')
        for S_n, S_text in S_compact:
            if entity in S_text:
                sp_adj.append(edge_type_map['S_P_hyper'], E_n.parent_id, S_n.parent_id)

    # The dense adjacency is (largest index + 1) square; it must match the node count.
    adj_node_len = max((max(i, j) for _, i, j in sp_adj.v_i_j), default=-1) + 1
    if len(node_list) != adj_node_len:
        print(node_list)
        for i, v_i_j in enumerate(sp_adj.v_i_j):
            print(f"{i}:\t{v_i_j}")
        raise AssertionError(
            f'{Q_id}: {len(node_list)} nodes but adjacency covers {adj_node_len}')

    ques_info_dict['node_list'] = node_list
    ques_info_dict['sp_adj'] = sp_adj
    return ques_info_dict

In [49]:
from tqdm.notebook import tqdm
def preprocessing(item_num = 5):
    '''Sequentially run `process` on the first `item_num` training questions.

    A negative `item_num` means all of `json_train`. Returns the per-question
    dicts from `process`, in order.
    '''
    limit = item_num if item_num >= 0 else None
    return [process(item) for item in tqdm(json_train[:limit])]

In [50]:
hotpotQA_train_preprocess = preprocessing(item_num=20)  # quick run on the first 20 questions

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




In [None]:
from tqdm.notebook import tqdm
from multiprocessing import Pool
def preprocessing(item_num = 2, process_num = 1):
    '''Run `process` over json_train[:item_num] with a pool of `process_num` workers.

    A negative `item_num` means all items; a `process_num` below 1 is treated
    as 1 (Pool rejects 0). Results are returned in input order.
    '''
    item_num = None if item_num<0 else item_num
    process_num = max(1, process_num)

    items = json_train[:item_num]
    resturn_list = []
    # Both the progress bar and the pool are closed even if a worker raises.
    with tqdm(total = len(items), desc = 'processing json items') as pbar, \
            Pool(process_num) as pool:
        for r in pool.imap(process, items):
            resturn_list.append(r)
            pbar.update()
    return resturn_list

In [None]:
hotpotQA_train_preprocess = preprocessing(item_num=100, process_num=16)  # first 100 questions, 16 workers

## 测试序列化

In [14]:
# Number of preprocessed questions (shown as the cell output).
n_preprocessed = len(hotpotQA_train_preprocess)
n_preprocessed

500

In [16]:
# Serialization smoke test: give every node a random [1, 768] feature, then
# replace nodes and adjacency in place with their plain-data forms.
for ques_item in hotpotQA_train_preprocess:
    nodes = ques_item['node_list']
    for node in nodes:
        node.cls_feature = torch.randn([1,768])
    ques_item['node_list'] = list(map(lambda n: n.to_serializable(), nodes))
    ques_item['sp_adj'] = ques_item['sp_adj'].to_serializable()

In [20]:
# Dump the serialized questions to JSON, then list the output directory (IPython shell escape).
with open('save_preprocess/test500.json', 'w', encoding='utf-8') as fp:
    json.dump(hotpotQA_train_preprocess, fp)
!ls -hlt save_preprocess

总用量 114G
-rwxrwxrwx 1 root root 1.5G 3月  17 17:23 test500.json
-rwxrwxrwx 1 root root  11G 3月  17 17:10 HotpotPrepro_5000.pkl
-rwxrwxrwx 1 root root  12G 3月  17 16:46 HotpotPrepro_4500.pkl
-rwxrwxrwx 1 root root  12G 3月  17 16:21 HotpotPrepro_4000.pkl
-rwxrwxrwx 1 root root  11G 3月  17 15:57 HotpotPrepro_3500.pkl
-rwxrwxrwx 1 root root  12G 3月  17 15:32 HotpotPrepro_3000.pkl
-rwxrwxrwx 1 root root  12G 3月  17 15:08 HotpotPrepro_2500.pkl
-rwxrwxrwx 1 root root  12G 3月  17 14:43 HotpotPrepro_2000.pkl
-rwxrwxrwx 1 root root  12G 3月  17 14:19 HotpotPrepro_1500.pkl
-rwxrwxrwx 1 root root  12G 3月  17 13:55 HotpotPrepro_1000.pkl
-rwxrwxrwx 1 root root  12G 3月  17 13:31 HotpotPrepro_500.pkl


In [24]:
# Read the JSON dump back as plain dicts/lists.
with open('save_preprocess/test500.json', 'r', encoding='utf-8') as dump_file:
    reload_hotpotQA = json.load(dump_file)

In [25]:
# Turn the plain dicts back into node objects and an Adjacency_sp.
for ques_item in reload_hotpotQA:
    restored_nodes = list(map(auto_reload_Node, ques_item['node_list']))
    ques_item['node_list'] = restored_nodes
    ques_item['sp_adj'] = Adjacency_sp.from_serializable(ques_item['sp_adj'])

# get_features_from_XLNET

下载地址:
```http
https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-pytorch_model.bin
https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin

https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json
https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json
```

文件名必须设置为:
```bash
/g/data/models/xlnet-base-cased# ll -hl

-rwxrwxrwx 1 root root  691 3月  16 22:51 config.json*
-rwxrwxrwx 1 root root 446M 3月  16 22:21 pytorch_model.bin*
```

In [12]:
def get_cls_feature_from_XLNET(text,text_pair=None,
                            tokenizer = None,
                            model = None,
                            add_special_tokens = True,
                           device = 'cuda',
                           test_mode = False):
    '''Return the model's hidden state at the LAST position for `text` (+ optional `text_pair`).

    The last position is used as the sequence feature (XLNet's tokenizer puts
    its <cls> token at the end when special tokens are added).

    Args:
        text, text_pair: strings passed to `tokenizer.encode_plus`.
        tokenizer: HuggingFace tokenizer. This argument is now honoured; it
            falls back to the global `tokenizer_XLNET` only when None.
        model: model whose first output is the hidden states, indexed as
            (batch, seq, dim). Required unless `test_mode`.
        add_special_tokens: forwarded to `encode_plus`.
        device: device the inputs are moved to. The model must already be
            there; it is not moved here.
        test_mode: skip the model and return random features.

    Returns:
        CPU tensor of shape [1, dim] ([1, 768] in test_mode, matching the
        [1, dim] `cls_feature` convention).

    Raises:
        ValueError: if `model` is None and `test_mode` is False.
    '''
    if test_mode:
        return torch.randn([1, 768])
    if model is None:
        raise ValueError('get_cls_feature_from_XLNET needs a model (or test_mode=True)')
    if tokenizer is None:
        tokenizer = tokenizer_XLNET
    model_input = tokenizer.encode_plus(text, text_pair,
                                        add_special_tokens=add_special_tokens,
                                        return_tensors='pt')
    model_input = {k: v.to(device) for k, v in model_input.items()}

    with torch.no_grad():  # inference only
        last_hidden_state = model(**model_input)[0].cpu()[:, -1, :]

    return last_hidden_state

In [13]:
from transformers import AutoConfig, AutoModel, AutoTokenizer

model_name = 'xlnet-base-cased'

# `proxies` is passed on to `requests`, which keys proxies by URL scheme
# ('http' / 'https'); keys such as 'http_proxy' would be silently ignored.
# Proxying is off for this download; set use_proxy_XLNET = True to enable it.
use_proxy_XLNET = False
proxies = {"http": "http://127.0.0.1:10809",
           "https": "http://127.0.0.1:10809"} if use_proxy_XLNET else None

tokenizer_XLNET = AutoTokenizer.from_pretrained(model_name,proxies=proxies)

In [14]:
# Local XLNet-base checkpoint (overrides the earlier xlnet-large `model_path`).
model_path = '/g/data/models/xlnet-base-cased'

# Load from local files only (no download), switch to eval mode, move to DEVICE.
model_XLNET01 = AutoModel.from_pretrained(model_path,local_files_only=True)
model_XLNET01.eval()
_ = model_XLNET01.to(DEVICE)

# Extra model replicas, currently disabled; `build` uses model_XLNET01.
# model_XLNET02 = AutoModel.from_pretrained(model_path,local_files_only=True)
# _ = model_XLNET02.to(DEVICE)

# model_XLNET03 = AutoModel.from_pretrained(model_path,local_files_only=True)
# _ = model_XLNET03.to(DEVICE)

# model_XLNET04 = AutoModel.from_pretrained(model_path,local_files_only=True)
# _ = model_XLNET04.to(DEVICE)


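**注意**: 关于NER与XLNET获取features. 不同模型的分词结果不一样. 上文最初使用BERT分词和BERT NER(`find_NER_in_Model`); 目前节点构建实际使用spaCy进行NER(`find_NER_in_spacy`), 并用全局`tokenizer`(已重新绑定为XLNet分词器)得到`content_tokens`. 目的是获取各节点(含实体)的`cls_feature`: 之后只需使用原始句子, 由XLNet重新进行分词, 计算`span`即可.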

In [15]:
from traceback import print_exc

In [16]:
def build(index_item, test_mode = False):
    '''Compute XLNet `cls_feature`s for every node of one preprocessed question.

    Features, in this order:
    - Q: the question alone;
    - S: question paired with the sentence;
    - P: question paired with all of the paragraph's child sentences joined by spaces;
    - E: question paired with the entity text.

    Args:
        index_item: `(index, ques_item)`, where `ques_item` is a dict from
            `process` (its 'node_list' is used).
        test_mode: forwarded to `get_cls_feature_from_XLNET` (random features).

    Returns:
        `(index, ques_item)` with the nodes' `cls_feature` set in place, or
        None if any step raised. The traceback is printed in that case.
        KeyboardInterrupt is no longer swallowed.
    '''
    index, ques_item = index_item[0], index_item[1]
    # Only one model replica is loaded. To spread work over several replicas,
    # pick one here based on `index`.
    model_XLNET = model_XLNET01

    def _feature(text, text_pair=None):
        # Feature of `text` (optionally paired with `text_pair`) from the chosen model.
        return get_cls_feature_from_XLNET(text, text_pair,
                                          add_special_tokens=True,
                                          tokenizer = tokenizer_XLNET,
                                          model = model_XLNET,
                                          device = DEVICE,
                                          test_mode=test_mode)

    try:
        node_list = ques_item['node_list']
        Q_node = node_list[0]
        question = Q_node.content_raw
        Q_node.cls_feature = _feature(question)

        for S_node in [n for n in node_list if n.node_type == 'Sentence']:
            S_node.cls_feature = _feature(question, S_node.content_raw)

        # Children by parent id, in one pass. Node ids equal list positions,
        # so a paragraph's children are found under its index.
        children = {}
        for n in node_list:
            children.setdefault(n.parent_id, []).append(n)
        for P_i, P_node in enumerate(node_list):
            if P_node.node_type != 'Paragraph':
                continue
            all_S_raw = ' '.join(n.content_raw for n in children.get(P_i, []))
            P_node.cls_feature = _feature(question, all_S_raw)

        for E_node in [n for n in node_list if n.node_type == 'Entity']:
            E_node.cls_feature = _feature(question, E_node.content_raw)

        return (index, ques_item)
    except Exception:
        print_exc()
        return None

In [76]:
# 单进程
from tqdm.notebook import tqdm

def building_cls(item_num = 5):
    '''Single-process: compute node features for the first `item_num` preprocessed questions.

    A negative `item_num` processes all of `hotpotQA_train_preprocess`.
    Returns the featurized ques_item dicts. Questions whose `build` failed
    (traceback already printed by `build`) are left out.
    '''
    item_num = None if item_num<0 else item_num

    resturn_list = []
    items = hotpotQA_train_preprocess[:item_num]
    for index, item in enumerate(tqdm(items, desc = 'building features')):
        # `build` expects an (index, ques_item) pair and returns one (or None).
        built = build((index, item))
        if built is not None:
            resturn_list.append(built[1])

    return resturn_list

In [None]:
hotpotQA_cls_feat = building_cls(item_num=20)  # features for the first 20 questions

In [26]:
# Pair every preprocessed question with its position, as `build` expects.
Hotpot_index_items = list(enumerate(hotpotQA_train_preprocess))
len(Hotpot_index_items)

5000

In [29]:
from tqdm.notebook import tqdm
from multiprocessing.dummy import Pool

# torch.multiprocessing.set_start_method('spawn', force=True)
# model_XLNET.share_memory()

def multi_build(from_index = 0, to_index = None, thread_num = 1):
    '''Pool version of `building_cls` over Hotpot_index_items[from_index:to_index].

    Returns the featurized ques_item dicts in input order; questions whose
    `build` failed are left out. A `thread_num` below 1 is treated as 1.
    If an unexpected error stops the loop, a timestamp and the traceback are
    printed, and the items collected so far are returned. It never returns
    None, so the result can be passed to `.extend`.
    '''
    thread_num = max(1, thread_num)
    index_items = Hotpot_index_items[from_index:to_index]

    resturn_list = []
    try:
        with tqdm(total = len(index_items), desc = 'building features') as pbar, \
                Pool(thread_num) as pool:
            for built in pool.imap(build, index_items):
                if built is not None:
                    resturn_list.append(built[1])
                pbar.update()
    except Exception:
        print (time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
        print_exc()
    return resturn_list

In [34]:
# NOTE(review): Hotpot_index_items[-1:1] selects at most one item, and none once the
# list has two or more items -- confirm the intended (from_index, to_index) range.
hotpotQA_preprocess_cls = multi_build(-1,1)
# 20/20 [01:21<00:00, 4.25s/it]

HBox(children=(FloatProgress(value=0.0, description='building features', max=20.0, style=ProgressStyle(descrip…

In [30]:
hotpotQA_preprocess_cls = list()  # accumulator for featurized questions

In [31]:
# Append the next batch of featurized questions (same slice note as the multi_build(-1,1) call above).
hotpotQA_preprocess_cls += multi_build(-1, 1)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='building features', max=1.0, style=Prog…

In [32]:
# Drop any per-token `content_features` to lighten the node objects before saving.
for ques_item in tqdm(hotpotQA_preprocess_cls, desc = 'clean features'):
    for X_node in list(ques_item['node_list']):
        setattr(X_node, 'content_features', None)

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='clean features', max=1.0, style=Progres…




# 检查nan

In [None]:
# Print every node whose feature is missing (None) or contains NaN.
for ques_item in tqdm(hotpotQA_preprocess_cls, desc = 'check nan'):
    for X_node in ques_item['node_list']:
        feature = X_node.cls_feature
        if feature is None or torch.isnan(feature).any().item():
            print (X_node)

# 保存

In [None]:
import pickle
# Pickle all featurized questions; the file name records how many were saved.
# protocol=-1 selects the highest pickle protocol available.
with open(save_cache_path+f'hotpotQA_preprocess_cls_{len(hotpotQA_preprocess_cls)}.pkl', 'wb') as fp:
    pickle.dump(hotpotQA_preprocess_cls, fp, protocol=-1)
!ls -hl $save_cache_path

In [38]:
import pickle
# Pickle only the first 20 featurized questions (small sample file).
with open(save_cache_path+'hotpotQA_preprocess_cls20.pkl', 'wb') as fp:
    pickle.dump(hotpotQA_preprocess_cls[:20], fp, protocol=-1)
!ls -hl $save_cache_path

总用量 2.9G
-rwxrwxrwx 1 root root 2.5G 3月  16 12:17 hotpotQA_preprocess_cls100.pkl
-rwxrwxrwx 1 root root 470M 3月  16 20:42 hotpotQA_preprocess_cls20.pkl


# 封装与内存优化

In [17]:
import os
import json
from tqdm.notebook import tqdm

In [18]:
def _flush_split(buffer, out_dir, desc):
    """Serialize each buffered QA item to ``{out_dir}/{id}.json`` and release the CUDA cache."""
    for ques_item in tqdm(buffer, desc = desc):
        ques_item['node_list'] = [node.to_serializable() for node in ques_item['node_list']]
        ques_item['sp_adj'] = ques_item['sp_adj'].to_serializable()

        with open(f"{out_dir}/{ques_item['id']}.json", 'w', encoding='utf-8') as fp:
            json.dump(ques_item, fp)
    torch.cuda.empty_cache()


def save_in_splits(split_num = 200, start = 0, end = 1000, out_dir = 'save_preprocess_new'):
    """Build features for ``json_train[start:end]`` and save one JSON file per question.

    Each item is run through ``process`` then ``build``; built items are
    buffered and written out (via each node's ``to_serializable``) every
    ``split_num`` items, after which the CUDA cache is emptied. Questions whose
    output file already exists are skipped, so an interrupted run can resume.

    Args:
        split_num: number of newly built items to buffer before writing them.
        start, end: slice bounds into ``json_train``.
        out_dir: output directory; created if it does not exist.
    """
    os.makedirs(out_dir, exist_ok=True)
    buffer = []
    last_index = start
    for index, item in enumerate(tqdm(json_train[start:end])):
        # NOTE(review): the skip check uses item['_id'] while files are named by
        # ques_item['id']; this assumes process() carries _id over as 'id'.
        if os.path.exists(f"{out_dir}/{item['_id']}.json"): continue
        i = build((index + start + 1, process(item)))
        last_index = i[0]
        buffer.append(i[1])
        # Flush on buffer size rather than on `i[0] % split_num == 0`: when the
        # item at a multiple index was skipped (resume), the old test never
        # fired and the buffer kept growing.
        if len(buffer) >= split_num:
            _flush_split(buffer, out_dir, f'{last_index}')
            buffer = []
    # Write the tail; previously items built after the last flush point were
    # silently discarded.
    if buffer:
        _flush_split(buffer, out_dir, f'{last_index}')


In [19]:
save_in_splits(start = 0, end = 60000)
!ls -l save_preprocess_new/|grep "^-"| wc -l

HBox(children=(FloatProgress(value=0.0, max=60000.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='50200', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='50400', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='50600', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='50800', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='51000', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='51200', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='51400', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='51600', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='51800', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='52000', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='52200', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='52400', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='52600', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='52800', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='53000', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='53200', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='53400', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='53600', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='53800', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='54000', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='54200', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='54400', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='54600', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='54800', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='55000', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='55200', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='55400', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='55600', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='55800', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='56000', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='56200', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='56400', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='56600', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='56800', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='57000', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='57200', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='57400', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='57600', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='57800', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='58000', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='58200', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='58400', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='58600', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='58800', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='59000', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='59200', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='59400', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='59600', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='59800', max=200.0, style=ProgressStyle(description_width=…




HBox(children=(FloatProgress(value=0.0, description='60000', max=200.0, style=ProgressStyle(description_width=…



60000


## 整合json chunk



In [1]:
import os
import json
from tqdm.notebook import tqdm
from class_simple import QuestionNode, ParagraphTitleNode, SentenceNode, EntityNode, Adjacency_sp, auto_reload_Node

In [3]:
# Merge the per-question JSON files in save_preprocess/ into JSON arrays of
# `chunk_size` items, written to HotpotQA_chunk_jsons3/{start}_{end}.json.
src_dir = 'save_preprocess'
out_dir = 'HotpotQA_chunk_jsons3'
chunk_size = 2000
# The original cell stopped after the first chunk (trailing `break`); that limit
# is kept explicit here. Set to None to write every chunk.
max_chunks = 1

# os.listdir order is arbitrary; sort so chunk membership is reproducible.
paths = [f"{src_dir}/{path}" for path in sorted(os.listdir(src_dir))]
os.makedirs(out_dir, exist_ok=True)

done_list = []
for chunk_no, index in enumerate(range(0, len(paths), chunk_size)):
    if max_chunks is not None and chunk_no >= max_chunks:
        break
    _to = min(index + chunk_size, len(paths))
    with open(f"{out_dir}/{index}_{_to}.json", 'w', encoding='utf-8') as fp1:
        fp1.write('[')
        for n, path in enumerate(tqdm(paths[index:_to], desc = f"{index} ~ {_to}")):
            # Emit the separator before every element except the first instead of
            # undoing a trailing comma with fp1.seek(fp1.tell() - 1): arithmetic
            # on text-mode tell() values is not supported by io.TextIOBase.
            if n:
                fp1.write(',')
            with open(path, 'r', encoding='utf-8') as fp2:
                fp1.write(fp2.read())
            done_list.append(path)
        fp1.write(']')

HBox(children=(FloatProgress(value=0.0, description='0 ~ 2000', max=2000.0, style=ProgressStyle(description_wi…




In [1]:
# Time how long ujson takes to parse one 1000-item chunk file.
import ujson
import time
# perf_counter is monotonic and high-resolution; time.time() is wall-clock and
# can jump (NTP/clock adjustments), which skews interval measurements.
start = time.perf_counter()

with open('HotpotQA_chunk_jsons2/0_1000.json', 'r', encoding='utf-8') as fp:
    xx = ujson.load(fp)
    
end = time.perf_counter()

print(f"{end-start:.2f}s")

91.39


In [None]:
import ijson

res = []  # QA items reloaded by the streaming loop below

def floaten(event):
    """Convert non-integer JSON numbers in an ijson ``(prefix, event, value)`` tuple to float.

    ijson reports integral numbers as ``int`` and other numbers as
    ``decimal.Decimal``. Only the non-int values are converted, so integer
    fields such as ids and indices stay ``int``. The previous version called
    ``float()`` on every number, turning e.g. ``3`` into ``3.0``. Non-number
    events are returned unchanged.
    """
    if event[1] == 'number' and not isinstance(event[2], int):
        return (event[0], event[1], float(event[2]))
    return event

# Stream QA items one at a time out of a chunk file (a top-level JSON array)
# and rebuild their node / adjacency objects.
# NOTE(review): the chunk-merge cell uses chunk_size = 2000 (-> 0_2000.json);
# confirm 0_1000.json is the intended input file.
pbar = tqdm(total = 1000)
with open('HotpotQA_chunk_jsons3/0_1000.json', 'rb') as fp:
    # ijson consumes bytes (UTF-8); passing a text-mode file triggers a
    # deprecation warning and an extra re-encoding step in ijson 3.x.
    events = map(floaten, ijson.parse(fp))
    # Elements of a top-level array have prefix 'item'. The previous
    # 'items.item' only matches {"items": [...]} and yielded nothing here.
    objects = ijson.common.items(events, 'item')
    for QA_item in objects:
        QA_item['node_list'] = [auto_reload_Node(state_dicts) for state_dicts in QA_item['node_list']]
        QA_item['sp_adj'] = Adjacency_sp.from_serializable(QA_item['sp_adj'])
        res.append(QA_item)
        pbar.update()
pbar.close()

In [3]:
import os, time, gc
import ujson as json
from tqdm.notebook import tqdm
from multiprocessing import Pool

from class_simple import QuestionNode, ParagraphTitleNode, SentenceNode, EntityNode, Adjacency_sp, auto_reload_Node

In [4]:
# Per-question JSON files to reload. os.listdir order is arbitrary, so sort to
# make slices such as paths[0:1000] pick the same files on every run.
paths = [f"save_preprocess/{path}" for path in sorted(os.listdir('save_preprocess'))]
QA_item_list = []  # appended to by __items()

In [5]:
def __item(path):
    """Read one per-question JSON file and return the parsed dict.

    Node objects are deliberately not rebuilt here (this loader only parses),
    so ``node_list`` / ``sp_adj`` stay in their serialized form. They are
    presumably the output of ``to_serializable`` — confirm for save_preprocess/.
    """
    with open(path, encoding = 'utf-8') as handle:
        return json.load(handle)
        
        
def __items(path):
    """Load a chunk file (JSON array of QA items) and rebuild each item's objects.

    Prints how long the raw JSON parse took. For each item, ``node_list`` is
    rebuilt with ``auto_reload_Node`` and ``sp_adj`` with
    ``Adjacency_sp.from_serializable``, mutating the items in place. The items
    are appended to the module-level ``QA_item_list`` and also returned.
    """
    t_begin = time.time()
    with open(path, 'r', encoding = 'utf-8') as fp:
        chunk = json.load(fp)
    print(f"{time.time() - t_begin:.2f}s")

    for qa in chunk:
        qa['node_list'] = list(map(auto_reload_Node, qa['node_list']))
        qa['sp_adj'] = Adjacency_sp.from_serializable(qa['sp_adj'])
    # NOTE(review): if this ever runs inside a Pool worker, the global extend
    # happens in the child process and is lost; only the return value survives.
    QA_item_list.extend(chunk)
    return chunk
    
def multi_build(from_index = 0, to_index = 1000, p_num = 8):
    """Parse ``paths[from_index:to_index]`` in parallel with ``__item``.

    Args:
        from_index, to_index: slice bounds into the module-level ``paths``.
        p_num: number of worker processes; any value below 1 falls back to 1.

    Returns:
        List of parsed QA dicts, in the same order as the selected paths.
    """
    # multiprocessing.Pool raises ValueError for processes < 1; the previous
    # `p_num < 0` guard let p_num == 0 through.
    p_num = max(1, p_num)
    selected = paths[from_index:to_index]

    result_list = []
    # Context managers ensure both the pool and the progress bar are closed,
    # even if a worker raises.
    with Pool(p_num) as pool, tqdm(total = len(selected), desc = 'reloading') as pbar:
        for r in pool.imap(__item, selected):
            result_list.append(r)
            pbar.update()
    return result_list


In [6]:
# Reload the first 1000 per-question files with 8 worker processes.
res = multi_build(from_index=0, to_index=1000, p_num=8)

HBox(children=(FloatProgress(value=0.0, description='reloading', max=1000.0, style=ProgressStyle(description_w…

In [23]:
import sys
def get_size(obj, seen=None):
    """Recursively estimate the memory footprint of *obj* in bytes.

    Adds ``sys.getsizeof`` of *obj* to the sizes of everything reachable
    through dict keys/values, an instance ``__dict__``, or iteration. This is
    adapted from a commonly shared recipe; the source link was left blank in
    the original comment. Each object is counted at most once (tracked by
    ``id`` in *seen*), so shared and self-referential structures terminate and
    are not double-counted.

    The result is an estimate. Attributes held in ``__slots__`` are not
    visited, and memory that an object's ``__sizeof__`` does not report is
    missing. One-shot iterators are measured shallowly so that measuring them
    does not consume them.

    Args:
        obj: object to measure.
        seen: ids already counted; leave as None on the top-level call.

    Returns:
        Estimated size in bytes (int).
    """
    size = sys.getsizeof(obj)
    if seen is None:
        seen = set()
    obj_id = id(obj)
    if obj_id in seen:
        return 0
    # Mark as seen *before* recursing so self-referential objects terminate.
    seen.add(obj_id)
    if isinstance(obj, dict):
        size += sum(get_size(v, seen) for v in obj.values())
        size += sum(get_size(k, seen) for k in obj.keys())
    elif hasattr(obj, '__dict__'):
        size += get_size(obj.__dict__, seen)
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        try:
            iterator = iter(obj)
        except TypeError:
            return size
        # One-shot iterators (generators, map/zip objects, ...) return themselves
        # from iter(). Walking them would exhaust them as a side effect of
        # measuring, so only their shallow size is counted.
        if iterator is not obj:
            size += sum(get_size(item, seen) for item in iterator)
    return size

In [24]:
get_size(res) / 1024 ** 3  # deep size of `res` in GiB

19390231714

In [25]:
# IPython help syntax: shows the docstring of sys.getsizeof (which reports only
# the object's own, shallow size — hence the recursive get_size above).
sys.getsizeof?

In [None]:
# Time how long ujson takes to parse the 5000-item chunk file.
import os, time, gc
import ujson as json

# perf_counter is monotonic and high-resolution; time.time() is wall-clock and
# can jump, which skews interval measurements.
start = time.perf_counter()
with open('HotpotQA_chunk_jsons/0_5000.json', 'r', encoding = 'utf-8') as fp:
    _temp = json.load(fp)

end = time.perf_counter()
print(f"{end-start:.2f}s")