<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#构建图" data-toc-modified-id="构建图-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>构建图</a></span></li><li><span><a href="#GNN" data-toc-modified-id="GNN-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>GNN</a></span></li><li><span><a href="#QA" data-toc-modified-id="QA-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>QA</a></span></li><li><span><a href="#整合" data-toc-modified-id="整合-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>整合</a></span></li></ul></div>

In [1]:
from argparse import Namespace
import json
import torch
from collections import defaultdict

from QA_models import AutoQuestionAnswering
from GNN import GAT_HotpotQA

In [2]:
args = Namespace(
    
    dev_json_path = 'data/HotpotQA/hotpot_dev_distractor_v1.json',
    
    GNN_model_path = 'models_checkpoints/GNN/GNN_hidden256_heads8_pad300.pt',
    QA_model_path = 'models_checkpoints/QA/HotpotQA_QA_MLP+unfreeze1_roberta-base.pt',
    pretrained_model_path = 'data/models/roberta-base',
    model_path = '',
#     # GNN parameters
    features = 768,
    hidden = 256,
    nclass = 2,
    dropout = 0,
    alpha = 0.3,
    nheads = 8,
    pad_max_num = 300,

    device = 'cuda:0',
    

    header_mode='MLP',
    cls_token_id=0,
    
    topN_sents=3,
    max_length=512,
    uncased=False,
    seed=123,
)
args.model_path = args.pretrained_model_path

In [3]:
with open(args.dev_json_path, 'r', encoding='utf-8') as f1:
    dev_json = json.load(f1)

In [4]:
one_item = dev_json[0]

In [5]:
one_item

{'_id': '5a8b57f25542995d1e6f1371',
 'answer': 'yes',
 'question': 'Were Scott Derrickson and Ed Wood of the same nationality?',
 'supporting_facts': [['Scott Derrickson', 0], ['Ed Wood', 0]],
 'context': [['Ed Wood (film)',
   ['Ed Wood is a 1994 American biographical period comedy-drama film directed and produced by Tim Burton, and starring Johnny Depp as cult filmmaker Ed Wood.',
    " The film concerns the period in Wood's life when he made his best-known films as well as his relationship with actor Bela Lugosi, played by Martin Landau.",
    ' Sarah Jessica Parker, Patricia Arquette, Jeffrey Jones, Lisa Marie, and Bill Murray are among the supporting cast.']],
  ['Scott Derrickson',
   ['Scott Derrickson (born July 16, 1966) is an American director, screenwriter and producer.',
    ' He lives in Los Angeles, California.',
    ' He is best known for directing horror films such as "Sinister", "The Exorcism of Emily Rose", and "Deliver Us From Evil", as well as the 2016 Marvel Cinema

# 构建图

In [6]:
from gen_nodes_repr import build_for_one_item

In [7]:
ques_items = build_for_one_item(dev_json[:2], args)

init...


100%|██████████| 2/2 [00:08<00:00,  4.02s/it]


# GNN

In [8]:
classifier = GAT_HotpotQA(features=args.features, hidden=args.hidden, nclass=args.nclass, 
                            dropout=args.dropout, alpha=args.alpha, nheads=args.nheads, 
                            nodes_num=args.pad_max_num)
classifier.to("cuda")
# args.GNN_model_path = 'save_cache_GNN/GNN_HotpotQA_hidden64_heads8_pad300_chunk_first.pt'
checkpoint = torch.load(args.GNN_model_path)
try:
    classifier.load_state_dict(checkpoint['model'])
except:
    classifier.load_state_dict({k.replace('module.',''):v for k,v in checkpoint['model'].items()})

In [9]:
from datasets import HotpotQA_GNN_Dataset, gen_GNN_batches

In [10]:
dataset = HotpotQA_GNN_Dataset.load_for_eval(ques_items)
dataset.set_parameters(300,0)
dataset.set_split('val')
dataset

HotpotQA GNN Dataset. mode: val. size: 2. max_seq: 300

In [11]:
batch_generator = gen_GNN_batches(dataset, 1, shuffle=False, drop_last=False, device='cuda')
QA_eval_list, Qtype_list = [], []
for index, batch_dict in enumerate(batch_generator):
    with torch.no_grad():
        logits_sent, logits_para, logits_Qtype = \
                        classifier(batch_dict['feature_matrix'], batch_dict['adj'])

        max_value, max_index = logits_sent.max(dim=-1) # max_index is predict class.
        topN_sent_index_batch = (max_value * batch_dict['sent_mask'].squeeze()).topk(3, dim=-1)[1]
        topN_sent_index_batch = topN_sent_index_batch.squeeze().tolist()
    
    item=ques_items[index]
    info_list = [[item["node_list"][item["node_list"][s_id].parent_id].content_raw,
                         item["node_list"][s_id].order_in_para,
                         item["node_list"][s_id].content_raw] \
                 for s_id in topN_sent_index_batch]
    
    print(item['id'])
    
    question = item["node_list"][0].content_raw
    sup_sent_id_list = [i[:-1] for i in info_list]
    sup_sent_list = [i[-1] for i in info_list]
    
    _values, indices = logits_Qtype.max(dim=-1)
    Qtype_list.append(indices.tolist()[0])
    
    print(sup_sent_id_list)
    print(sup_sent_list)
    QA_eval_list.append((question, sup_sent_list))

5a8b57f25542995d1e6f1371
[['Woodson, Arkansas', 0], ['Scott Derrickson', 1], ['Scott Derrickson', 0]]
['Woodson is a census-designated place (CDP) in Pulaski County, Arkansas, in the United States.', ' He lives in Los Angeles, California.', 'Scott Derrickson (born July 16, 1966) is an American director, screenwriter and producer.']
5a8c7595554299585d9e36b6
[['Shirley Temple', 0], ['Meet Corliss Archer', 3], ['Meet Corliss Archer', 2]]
["Shirley Temple Black (April 23, 1928 – February 10, 2014) was an American actress, singer, dancer, businesswoman, and diplomat who was Hollywood's number one box-office draw as a child actress from 1935 to 1938.", " Despite the program's long run, fewer than 24 episodes are known to exist.", ' From October 3, 1952 to June 26, 1953, it aired on ABC, finally returning to CBS.']


In [20]:
sup_sent_list

["Shirley Temple Black (April 23, 1928 – February 10, 2014) was an American actress, singer, dancer, businesswoman, and diplomat who was Hollywood's number one box-office draw as a child actress from 1935 to 1938.",
 " Despite the program's long run, fewer than 24 episodes are known to exist.",
 ' From October 3, 1952 to June 26, 1953, it aired on ABC, finally returning to CBS.']

# QA

In [12]:
from datasets import HotpotQA_QA_Dataset, generate_QA_batches
from transformers import AutoTokenizer
from QA_models import AutoQuestionAnswering

In [13]:
tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_path, local_files_only=True)
classifier = AutoQuestionAnswering.from_pretrained(model_path=args.pretrained_model_path,
                                                    header_mode=args.header_mode,
                                                    cls_index=tokenizer.cls_token_id)
classifier = classifier.to(args.device)
checkpoint = torch.load(args.QA_model_path)
classifier.load_state_dict(checkpoint['model'])
_ = classifier.eval()

In [14]:
QA_eval_list

[('Were Scott Derrickson and Ed Wood of the same nationality?',
  ['Woodson is a census-designated place (CDP) in Pulaski County, Arkansas, in the United States.',
   ' He lives in Los Angeles, California.',
   'Scott Derrickson (born July 16, 1966) is an American director, screenwriter and producer.']),
 ('What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?',
  ["Shirley Temple Black (April 23, 1928 – February 10, 2014) was an American actress, singer, dancer, businesswoman, and diplomat who was Hollywood's number one box-office draw as a child actress from 1935 to 1938.",
   " Despite the program's long run, fewer than 24 episodes are known to exist.",
   ' From October 3, 1952 to June 26, 1953, it aired on ABC, finally returning to CBS.'])]

In [15]:
datasetQA = HotpotQA_QA_Dataset.load_for_eval(QA_eval_list)
datasetQA.set_parameters(tokenizer = tokenizer, topN_sents = args.topN_sents,
                        max_length=args.max_length, uncased=args.uncased,
                        permutations=False, random_seed=args.seed)
batch_generatorQA = generate_QA_batches(datasetQA, 1, shuffle=False, drop_last=False, device="cuda")

In [16]:
datasetQA

HotpotQA QA Dataset. mode: eval. size: 2. sents num: 3

In [17]:
ans_dict = {}
ans_dict_topN = defaultdict(list)

for index, batch_dict in enumerate(batch_generatorQA):
    print(index)
    with torch.no_grad():
        res = classifier(**batch_dict)
        start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = res[:5]
        start_top_index = start_top_index.squeeze().tolist()
        end_top_index = end_top_index.squeeze().tolist()
        assert len(start_top_index) == len(end_top_index)
        
        input_ids = batch_dict['input_ids'].squeeze().tolist()
        item = ques_items[index]
        
        for index,(i,j) in enumerate(zip(start_top_index,end_top_index)):
            if index == 0:
                if Qtype_list[index] == 0:
                    ans_dict[item['id']] = tokenizer.decode(input_ids[i:j+1])
                else: # comparations
                    _values, indices = cls_logits.max(dim=-1)
                    ans = 'yes' if indices.tolist()[0] == 1 else 'no'
                    ans_dict[item['id']] = ans
            ans_dict_topN[item['id']].append(tokenizer.decode(input_ids[i:j+1]))


0
1


In [18]:
ans_dict

{'5a8b57f25542995d1e6f1371': ' American',
 '5a8c7595554299585d9e36b6': ' actress'}

# 整合

In [19]:
final_res = {}
final_res['answer'] = ans_dict
final_res['sp'] = sup_dict

NameError: name 'sup_dict' is not defined

In [None]:
final_res