In [2]:
#!/usr/bin/python
# -*- coding: UTF-8 -*-
# basic
import logging
import sys
import os
# io
import pickle
import json
import zipfile
import csv

# util
import time
from tqdm import tqdm
from collections import defaultdict
from collections import Counter
from math import log, sqrt

# nltk
import nltk
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords

# preprocess
import unicodedata
import re

# user-defined module
import prepare

# Shared Porter stemmer instance used by the indexing and query helpers below.
stemmer = nltk.stem.PorterStemmer()
################################################################
# ENVIRON
################################################################
# Per-machine project roots. These are absolute paths; add or edit an entry
# for your own machine before running the notebook.
server1_homepath = "/home/ubuntu/workspace/codelab/"
server2_homepath = "/home/ubuntu/workspace/codelab/"
gpu_homepath = "/home/shawn/workspace/research/final_codelab/"
jun_homepath = "/home/junw/workspace/codelab/"

# Select which root to use: one of 'server1', 'server2', 'gpu', 'jun'.
SERVERNAME = 'gpu'
HOMEPATH = {'server1':server1_homepath, 'server2':server2_homepath, 'gpu':gpu_homepath, 'jun':jun_homepath}[SERVERNAME]
# Free-text description of this notebook's task (not read by the visible code).
TASKNAME = 'Run main part of the project'

# Data directories derived from HOMEPATH. Every value ends with "/".
DATAPATH = HOMEPATH + "submission_data/"
# NOTE(review): "ORGINAL" is a misspelling of "ORIGINAL"; the name is used by
# later cells, so rename it everywhere at once if you fix it.
ORGINAL_DATAPATH = DATAPATH +"original/"
INTERMEDIATE_DATAPATH = DATAPATH + "intermediate/"
FINAL_DATAPATH =  DATAPATH + "final/"
################################################################
# ENVIRON
################################################################

[nltk_data] Downloading package stopwords to /home/shawn/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /home/shawn/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [None]:
# Load the train/dev/test claim sets and the wiki corpus, and time the load.
start_time = time.time()
train, dev, test = prepare.get_training_devset_test(ORGINAL_DATAPATH)
wiki = prepare.Wiki(DATAPATH, True, INTERMEDIATE_DATAPATH)
end_time = time.time()
elapsed_time = round(end_time - start_time, 0)
# The hours figure is rounded to two decimals. The original closed the round()
# call too early, which rounded hours to an integer and passed a stray extra
# argument to format().
# NOTE(review): logging.info is silent unless logging has been configured at
# INFO level or lower.
logging.info(
    "The task took {}s or {} minutes or {} hours.".format(
        elapsed_time,
        round(elapsed_time / 60, 2),
        round(elapsed_time / 3600, 2),
    )
)

In [3]:
def preprocessed_claim_sentence(claim):
    """Rewrite a raw claim into the token style used by the rest of the pipeline.

    Steps, in order:
      * Unicode NFC normalisation;
      * ``:`` becomes `` -COLON-``;
      * ``(`` / ``)`` become ``-LRB- `` / `` -RRB-``;
      * ``_`` becomes a space, an en dash becomes ``-``, a backtick becomes ``'``;
      * runs of spaces are collapsed and the result is stripped.

    Fixes relative to the previous version: the closing bracket token is now
    emitted as ``-RRB-`` (it used to lose its trailing dash), and runs of more
    than two spaces are collapsed as well.

    Args:
        claim: the claim text (``str``).

    Returns:
        The normalised claim string.
    """
    claim = unicodedata.normalize('NFC', claim)
    # Two-step colon handling also prefixes a space to any literal "-COLON-"
    # already present in the input.
    claim = claim.replace(':', '-COLON-')
    claim = claim.replace('-COLON-', ' -COLON-')
    claim = claim.replace('(', '-LRB-')
    claim = claim.replace(')', '-RRB-')
    claim = claim.replace("_", " ")
    claim = claim.replace("-LRB-", "-LRB- ")
    claim = claim.replace("-RRB-", " -RRB-")
    claim = re.sub('–', '-', claim)
    claim = claim.replace("`", "'")
    claim = re.sub(' {2,}', ' ', claim)
    return claim.strip()

class InvertedIndex:
    """Term -> postings index built from per-document term-frequency maps.

    Attributes:
        vocab: mapping ``term -> term id`` (supplied by the caller).
        doc_len: ``doc_len[d]`` is the total token count of document ``d``.
        doc_ids: ``doc_ids[t]`` lists the documents containing term id ``t``.
        doc_term_freqs: ``doc_term_freqs[t][j]`` is the frequency of term
            ``t`` in document ``doc_ids[t][j]``.
        doc_freqs: ``doc_freqs[t]`` is the number of documents containing ``t``.
        total_num_docs: number of documents indexed.
        max_doc_len: largest value in ``doc_len`` (0 when there are no docs).
    """

    def __init__(self, vocab, doc_term_freqs):
        """Build the postings lists.

        Args:
            vocab: mapping ``term -> term id``; ids must be in
                ``range(len(vocab))``.
            doc_term_freqs: sequence of mappings ``term -> frequency``, one
                per document, in document-id order.
        """
        n_terms = len(vocab)
        self.vocab = vocab
        self.doc_len = [0] * len(doc_term_freqs)
        self.doc_term_freqs = [[] for _ in range(n_terms)]
        self.doc_ids = [[] for _ in range(n_terms)]
        self.doc_freqs = [0] * n_terms
        self.total_num_docs = 0
        self.max_doc_len = 0

        for docid, term_freqs in enumerate(doc_term_freqs):
            length = sum(term_freqs.values())
            self.max_doc_len = max(length, self.max_doc_len)
            self.doc_len[docid] = length
            self.total_num_docs += 1
            for term, freq in term_freqs.items():
                tid = vocab[term]
                self.doc_ids[tid].append(docid)
                self.doc_term_freqs[tid].append(freq)
                self.doc_freqs[tid] += 1

    def num_terms(self):
        """Return the number of postings lists (one per vocabulary entry)."""
        return len(self.doc_ids)

    def num_docs(self):
        """Return the number of indexed documents."""
        return self.total_num_docs

    def docids(self, term):
        """Return the list of document ids containing ``term``."""
        return self.doc_ids[self.vocab[term]]

    def freqs(self, term):
        """Return the per-document frequencies of ``term``, aligned with ``docids``."""
        return self.doc_term_freqs[self.vocab[term]]

    def f_t(self, term):
        """Return the document frequency of ``term``."""
        return self.doc_freqs[self.vocab[term]]
    
def processed(docs, lower = False, stem = False):
    """Tokenise ``docs`` and build an ``InvertedIndex`` over them.

    Args:
        docs: iterable of raw document strings.
        lower: when ``True``, lower-case every token.
        stem: when ``True``, split on single spaces (instead of using
            ``nltk.word_tokenize``) and Porter-stem the lower-cased tokens.

    Returns:
        An ``InvertedIndex`` whose vocabulary assigns ids in order of first
        appearance.
    """
    # term -> term id, ids handed out in first-seen order
    vocab = {}
    # one Counter of normalised tokens per document
    doc_term_freqs = []
    for raw_doc in docs:
        if stem == True:
            tokens = raw_doc.split(" ")
        else:
            tokens = nltk.word_tokenize(raw_doc)
        normalised = []
        for token in tokens:
            if lower == True:
                token = token.lower()
            if stem == True:
                token = stemmer.stem(token.lower())
            # len(vocab) is evaluated before insertion, so a new term gets the next free id
            vocab.setdefault(token, len(vocab))
            normalised.append(token)
        doc_term_freqs.append(Counter(normalised))

    return InvertedIndex(vocab, doc_term_freqs)

def doc_eval(setname,predict):
    """Print document-retrieval recall and the average number of predicted docs.

    A claim counts as missed when at least one of its gold evidence titles
    (NFC-normalised) is absent from ``predict[key]``.

    Args:
        setname: mapping ``claim id -> {'evidence': [[title, ...], ...], ...}``.
        predict: mapping ``claim id -> collection of predicted titles``; must
            contain every key of ``setname``.

    Returns:
        None. Results are printed.
    """
    if not setname:
        # Avoid a ZeroDivisionError when there is nothing to score.
        print("Nothing to evaluate: the claim set is empty.")
        return

    # A set gives O(1) membership; the old list made the loop quadratic.
    missed = set()
    total_len = 0
    for key in setname:
        predicted_titles = predict[key]
        total_len += len(predicted_titles)
        for evi in setname[key]['evidence']:
            title = unicodedata.normalize('NFC', evi[0])
            if title not in predicted_titles:
                missed.add(key)
                break  # one missing title is enough to count the claim as missed

    print("Recall : ", 1 - len(missed)/len(setname))
    print("Average Number : ", total_len/len(setname))

def query_sim(query, index, k , th = 0):
    """Score indexed documents against ``query`` and return the best matches.

    The query is normalised with ``preprocessed_claim_sentence``, tokenised
    with ``nltk.word_tokenize``, then lower-cased and stemmed. Each query term
    found in the index adds
    ``1/sqrt(doc_len) * log(1 + tf) * log(num_docs / doc_freq)`` to every
    document that contains it.

    Args:
        query: raw query string.
        index: an ``InvertedIndex``.
        k: number of top-scoring documents to consider.
        th: score threshold. When ``0`` the top ``k`` are returned unchanged;
            otherwise results are taken in rank order until a score drops
            below ``th`` or three results have been collected.

    Returns:
        List of ``(doc_id, score)`` pairs, best first.
    """
    scores = Counter()
    tokens = nltk.word_tokenize(preprocessed_claim_sentence(query))
    for raw_token in tokens:
        term = stemmer.stem(raw_token.lower())
        if term not in index.vocab:
            # terms unknown to the index contribute nothing
            continue
        for doc_id, freq in zip(index.docids(term), index.freqs(term)):
            weight = 1/sqrt(index.doc_len[doc_id]) * log(1 + freq) * log(index.num_docs() / index.f_t(term))
            scores[doc_id] += weight

    if th == 0:
        return(scores.most_common(k))

    selected = []
    for hit in scores.most_common(k):
        # stop at the first hit that is below threshold or would exceed three results
        if not (hit[1] >= th and len(selected) < 3):
            break
        selected.append(hit)
    return(selected)

def sent_eval(target,guess):
    """Print sentence-selection recall and the average number of selected sentences.

    A claim counts as missed when at least one gold evidence item, with its
    title NFC-normalised, is absent from ``guess[claim_id]``.

    Unlike the previous version this does not modify ``target``: the
    normalised evidence is built as a new list instead of being written back
    into the caller's data.

    Args:
        target: mapping ``claim id -> {'evidence': [[title, sent_no], ...], ...}``.
        guess: mapping ``claim id -> list of [title, sent_no]`` predictions.

    Returns:
        None. Results are printed.
    """
    print("Sentence Selection Result")
    if not guess:
        # Avoid a ZeroDivisionError when there are no predictions.
        print("Nothing to evaluate: no predictions were given.")
        return

    missed = set()
    total_len = 0
    for claim_id in guess:
        selected = guess[claim_id]
        total_len += len(selected)
        for evidence in target[claim_id]['evidence']:
            normalised = [unicodedata.normalize('NFC', evidence[0])] + list(evidence[1:])
            if normalised not in selected:
                missed.add(claim_id)
                break  # one missing evidence item is enough

    print("Recall : ",1-len(missed)/len(guess))
    print("Average length : ",total_len/len(guess))

def sentSearch(query,docs,sent_id,wiki = wiki, k = 20, th = 0,  gs = False):
    """Rank candidate sentences against ``query`` and return their identifiers.

    NOTE(review): a later cell defines ``sentSearch`` again without the
    ``wiki`` parameter; whichever cell ran last is the one in effect.

    Args:
        query: raw claim text.
        docs: list of sentence strings to index.
        sent_id: list parallel to ``docs``; ``sent_id[i]`` identifies ``docs[i]``.
        wiki: not used by this function.
        k: number of top-scoring sentences to consider.
        th: score threshold passed to ``query_sim``.
        gs: not used by this function.

    Returns:
        List of ``sent_id`` entries for the selected sentences, best first.
    """
    index = processed(docs, lower = True, stem = True)
    ranked = query_sim(query, index, k, th)
    return [sent_id[hit[0]] for hit in ranked]

def ss_grid_search(guess,setname,wiki = wiki):
    """Select up to 100 candidate sentences per claim from the guessed documents.

    Args:
        guess: mapping ``claim id -> list of document titles``.
        setname: mapping ``claim id -> {'claim': str, ...}``.
        wiki: corpus object passed to ``getDoc``.

    Returns:
        Mapping ``claim id -> list of selected sentence identifiers``.
    """
    selected = {}
    for claim_id in setname:
        sentences, identifiers = getDoc(guess[claim_id], wiki)
        selected[claim_id] = sentSearch(setname[claim_id]['claim'], sentences, identifiers, k = 100)
    return selected

def sent_selection_title(TB_docs,setname,wiki=wiki,topk=20,th=0):
    """Select candidate sentences per claim from title-matched documents.

    Args:
        TB_docs: mapping ``claim id -> list of document titles``.
        setname: mapping ``claim id -> {'claim': str, ...}``.
        wiki: corpus object passed to ``getDoc``.
        topk: number of top-scoring sentences to consider per claim.
        th: score threshold forwarded to ``sentSearch``. It was previously
            accepted but ignored; the default of ``0`` keeps the old
            behaviour.

    Returns:
        Mapping ``claim id -> list of selected sentence identifiers``.
    """
    filt_results_doc = {}
    for key in tqdm(setname):
        docs, sent_id = getDoc(TB_docs[key], wiki)
        filt_results_doc[key] = sentSearch(setname[key]['claim'], docs, sent_id, k=topk, th=th)
    return filt_results_doc

def sent_selection_cont(cont_docs,titles,setname,wiki=wiki,topk=100,th=0):
    """Select candidate sentences per claim from content-matched documents.

    For each claim, sentences of the content-matched documents are first
    filtered by ``title_filter_for_cont`` and then ranked with ``sentSearch``.

    Args:
        cont_docs: mapping ``claim id -> list of document titles``.
        titles: mapping ``claim id -> list of title strings`` used by the filter.
        setname: mapping ``claim id -> {'claim': str, ...}``.
        wiki: corpus object. It is now also forwarded to
            ``title_filter_for_cont``; previously the filter always used its
            own default corpus regardless of this argument.
        topk: number of top-scoring sentences to consider per claim.
        th: score threshold forwarded to ``sentSearch``.

    Returns:
        Mapping ``claim id -> list of selected sentence identifiers``.
    """
    filt_results_doc = {}
    for key in tqdm(cont_docs):
        docs = []
        sent_id = []
        for sent in title_filter_for_cont(cont_docs[key], titles[key], wiki=wiki):
            # Drop everything before the first comma (presumably the title
            # prefix; confirm against wiki.single_sent). Each remaining part
            # gets " ," appended, then the final "," is cut off.
            parts = wiki.single_sent(sent).split(",")[1:]
            docs.append("".join(part + " ," for part in parts)[:-1])
            sent_id.append(sent)
        filt_results_doc[key] = sentSearch(setname[key]['claim'], docs, sent_id, k=topk, th=th)
    return filt_results_doc
    
def title_filter_for_cont(cont_docs,titles, wiki = wiki):
    """Return the sentences of ``cont_docs`` that mention any of ``titles``.

    Matching is a case-insensitive substring test. Each sentence is returned
    at most once per occurrence of its document in ``cont_docs``; the
    previous version appended it once per matching title, which put
    duplicate candidates into the downstream ranking.

    Args:
        cont_docs: iterable of document titles (NFC-normalised before lookup).
        titles: iterable of title strings to look for.
        wiki: corpus object; ``wiki.wiki[title]`` maps sentence ids to text.

    Returns:
        List of ``[doc_title, sent_id]`` pairs.
    """
    # Lower-case the needles once instead of once per sentence.
    needles = [title.lower() for title in titles]
    sent = []
    for doc_title in cont_docs:
        doc_title = unicodedata.normalize('NFC',doc_title)
        sentences = wiki.wiki[doc_title]
        for sent_id in sentences:
            haystack = sentences[sent_id].lower()
            if any(needle in haystack for needle in needles):
                sent.append([doc_title,sent_id])
    return sent

def merged_result(sent_sel1,sent_sel2):
    """Merge two per-claim selections, keeping the order of ``sent_sel1`` first.

    For every key of ``sent_sel1`` the result holds all of its items followed
    by those items of ``sent_sel2[key]`` that do not occur in
    ``sent_sel1[key]``. Duplicates inside ``sent_sel2[key]`` are kept.

    Args:
        sent_sel1: mapping ``key -> list``.
        sent_sel2: mapping ``key -> list``; must contain every key of ``sent_sel1``.

    Returns:
        New mapping ``key -> merged list``.
    """
    merged = {}
    for key in sent_sel1:
        primary = sent_sel1[key]
        extras = [item for item in sent_sel2[key] if item not in primary]
        merged[key] = list(primary) + extras
    return merged

def sentent_selection(TB_docs,cont_docs,setname,titles,wiki=wiki,k = 30, th = 0.5):
    """Run title-based and content-based sentence selection and merge them.

    The previous version computed the merged result and then returned only
    the title-based selection, so the content-based work was thrown away.
    The merged result is returned now, and ``wiki`` is forwarded to both
    selection steps.

    Args:
        TB_docs: mapping ``claim id -> list of title-matched document titles``.
        cont_docs: mapping ``claim id -> list of content-matched document titles``.
        setname: mapping ``claim id -> {'claim': str, ...}``.
        titles: mapping ``claim id -> list of title strings`` for content filtering.
        wiki: corpus object.
        k: ``topk`` for the content-based selection.
        th: score threshold for the content-based selection.

    Returns:
        Mapping ``claim id -> merged list of sentence identifiers``
        (title-based first).
    """
    TB_title = sent_selection_title(TB_docs, setname, wiki=wiki)
    cont_title = sent_selection_cont(cont_docs, titles, setname, wiki=wiki, topk = k, th = th)
    return merged_result(TB_title, cont_title)

def output_senten_result(result,setname,path):
    """Write (claim, candidate sentence) rows to a TSV file.

    The first row is a single-cell placeholder header ``TEST!``. The previous
    version passed the bare string to ``writerow``, which wrote one column
    per character.

    Args:
        result: mapping ``claim id -> list of evidence identifiers``.
        setname: mapping ``claim id -> {'claim': str, ...}``.
        path: destination file path.

    Returns:
        ``(output, claim_id)`` exactly as returned by ``getoutput``.
    """
    output,claim_id = getoutput(result,setname)
    # newline='' is what the csv module expects for files it writes to.
    with open(path, 'wt', newline='') as out_file:
        tsv_writer = csv.writer(out_file, delimiter='\t')
        tsv_writer.writerow(["TEST!"])
        tsv_writer.writerows(output)
    return output,claim_id
            
def getoutput(result,setname,wiki=wiki):
    """Build (claim, sentence text) rows for every selected evidence item.

    Args:
        result: mapping ``claim id -> list of evidence identifiers``.
        setname: mapping ``claim id -> {'claim': str, ...}``.
        wiki: corpus object; ``wiki.single_sent(evi)`` supplies the sentence text.

    Returns:
        ``(output, claim_id)`` where ``output[i]`` is
        ``[preprocessed claim, sentence text]`` and ``claim_id[i]`` is
        ``[claim id, evidence identifier]`` for the same row.
    """
    output = []
    claim_id = []
    for cid in result:
        evidences = result[cid]
        if not evidences:
            # no rows for this claim; the claim text is never looked up
            continue
        # the claim text is the same for every evidence row of this claim
        claim_text = unicodedata.normalize('NFC', preprocessed_claim_sentence(setname[cid]['claim']))
        for evi in evidences:
            output.append([claim_text, wiki.single_sent(evi)])
            claim_id.append([cid, evi])
    return output,claim_id

In [8]:
def merged_list(doc_list1,doc_list2,n = 5):
    """Concatenate two lists, skipping items of the second already present.

    Every item of ``doc_list1`` is kept (including repeats). An item of
    ``doc_list2`` is appended only if it is not yet in the merged list.

    Args:
        doc_list1: first list; its order is preserved.
        doc_list2: second list.
        n: not used.

    Returns:
        A new merged list.
    """
    combined = list(doc_list1)
    for candidate in doc_list2:
        if candidate in combined:
            continue
        combined.append(candidate)
    return combined

def getDoc(evidence,wiki):
    """Collect the sentence texts and identifiers of the given documents.

    For every sentence, the text before the first comma is dropped; each
    remaining comma-separated part is followed by ``" ,"`` and the final
    ``","`` is cut off.

    Args:
        evidence: iterable of document titles.
        wiki: corpus object; ``wiki.wiki[title]`` maps sentence ids to text.

    Returns:
        ``(docs, sent_id)``: parallel lists of sentence strings and
        ``[doc_title, sentence id]`` pairs.
    """
    docs, sent_id = [], []
    for doc_title in evidence:
        sentences = wiki.wiki[doc_title]
        for sentence_key in sentences:
            parts = sentences[sentence_key].split(",")[1:]
            docs.append("".join(part + " ," for part in parts)[:-1])
            sent_id.append([doc_title, sentence_key])
    return docs,sent_id

def sentSearch(query,docs,sent_id, k = 20, th = 0,  gs = False):
    """Rank candidate sentences against ``query`` and return their identifiers.

    NOTE(review): an earlier cell defines ``sentSearch`` with an extra
    ``wiki`` parameter; whichever cell ran last is the one in effect.

    Args:
        query: raw claim text.
        docs: list of sentence strings to index.
        sent_id: list parallel to ``docs``; ``sent_id[i]`` identifies ``docs[i]``.
        k: number of top-scoring sentences to consider.
        th: score threshold passed to ``query_sim``.
        gs: not used by this function.

    Returns:
        List of ``sent_id`` entries for the selected sentences, best first.
    """
    index = processed(docs, lower = True, stem = True)
    ranked = query_sim(query, index, k, th)
    return [sent_id[hit[0]] for hit in ranked]

class Examples:
    """Claims enriched with retrieved documents and candidate sentences.

    ``self.examples`` maps each claim id to a dict with the keys ``claim``,
    ``evidence``, ``label``, ``case_doc``, ``case_title``, ``uncase_doc``,
    ``uncase_title``, ``merged_case_uncase_doc``, ``case_sent``,
    ``uncase_sent`` and ``merged_case_uncase_sent``.
    """

    def __init__(self, examples,wiki):
        """Copy the claims and run document and sentence retrieval for each.

        The previous version tested ``'evidence'`` / ``'label'`` against the
        freshly created (still empty) record instead of the input record, so
        gold evidence and labels were always replaced by ``[]`` / ``""``.
        They are now copied from the input when present.

        Args:
            examples: mapping ``claim id -> {'claim': str, 'evidence'?: list,
                'label'?: str}``.
            wiki: corpus object; ``wiki.search(claim, flag)`` must return a
                ``(documents, titles)`` pair. The flag is ``False`` for the
                ``case_*`` fields and ``True`` for the ``uncase_*`` fields
                (presumably case-sensitive vs. case-insensitive search;
                confirm against ``prepare.Wiki``).
        """
        self.examples = {}
        for key in examples:
            source = examples[key]
            self.examples[key] = {
                'claim': source['claim'],
                'evidence': source['evidence'] if 'evidence' in source else [],
                'label': source['label'] if 'label' in source else "",
            }
        for key in tqdm(self.examples):
            record = self.examples[key]
            claim = record['claim']
            record['case_doc'], record['case_title'] = wiki.search(claim, False)
            record['uncase_doc'], record['uncase_title'] = wiki.search(claim, True)
            # Fall back to the second search when the first finds nothing.
            if len(record['case_doc']) == 0:
                record['case_doc'] = record['uncase_doc']
            record['merged_case_uncase_doc'] = merged_list(record['case_doc'], record['uncase_doc'])
            docs, sent_id = getDoc(record['case_doc'], wiki)
            record['case_sent'] = sentSearch(claim, docs, sent_id, k=100)
            docs, sent_id = getDoc(record['uncase_doc'], wiki)
            record['uncase_sent'] = sentSearch(claim, docs, sent_id, k=100)
            record['merged_case_uncase_sent'] = merged_list(record['case_sent'], record['uncase_sent'])

    def _field(self, name):
        """Return ``{claim id: record[name]}`` for every claim."""
        return {key: record[name] for key, record in self.examples.items()}

    def _top_sentences(self, name, k):
        """Return ``{claim id: first k items of record[name]}``."""
        return {key: record[name][:min(k, len(record[name]))]
                for key, record in self.examples.items()}

    def _full_records(self, name, k):
        """Return claim/label records whose evidence is the first ``k`` items of ``name``."""
        return {
            key: {
                'claim': record['claim'],
                'evidence': record[name][:min(k, len(record[name]))],
                'label': record['label'],
            }
            for key, record in self.examples.items()
        }

    def get_ori(self):
        """Return ``{claim id: {'claim', 'evidence', 'label'}}`` as given at construction."""
        return {
            key: {
                'claim': record['claim'],
                'evidence': record['evidence'],
                'label': record['label'],
            }
            for key, record in self.examples.items()
        }

    def get_uncase_doc(self):
        """Return ``{claim id: uncase_doc}``."""
        return self._field('uncase_doc')

    def get_case_doc(self):
        """Return ``{claim id: case_doc}``."""
        return self._field('case_doc')

    def get_case_sent_list(self,k):
        """Return ``{claim id: first k case_sent entries}``."""
        return self._top_sentences('case_sent', k)

    def get_uncase_sent_list(self,k):
        """Return ``{claim id: first k uncase_sent entries}``."""
        return self._top_sentences('uncase_sent', k)

    def get_case_sent_full(self,k):
        """Return claim/label records with the first ``k`` case_sent entries as evidence."""
        return self._full_records('case_sent', k)

    def get_case_title(self):
        """Return ``{claim id: de-duplicated case_title list}``.

        Titles keep their first-seen order, so repeated runs give the same
        lists (``list(set(...))`` gave an arbitrary order).
        """
        return {key: list(dict.fromkeys(record['case_title']))
                for key, record in self.examples.items()}

    def get_uncase_sent_full(self,k):
        """Return claim/label records with the first ``k`` uncase_sent entries as evidence."""
        return self._full_records('uncase_sent', k)

In [9]:
# Run document and sentence retrieval for every test claim (slow: the recorded
# progress bar below shows roughly 22 minutes for 14997 claims).
test_examples = Examples(test,wiki)

100%|██████████| 14997/14997 [21:55<00:00, 11.47it/s]  


In [15]:
# NOTE(review): pickle is already imported in the first cell; this re-import is redundant.
import pickle
# Load the pickled per-claim dict and keep only each entry's 'matched' value.
# NOTE(review): pickle.load can execute arbitrary code; only load files you trust.
# NOTE(review): the file is named "merged_test100" but the variables say "dev"; confirm which split this is.
with open(INTERMEDIATE_DATAPATH+"/merged_test100",'rb') as fp:
    merged_devset_dict = pickle.load(fp)
cont_dev ={}
for key in merged_devset_dict:
    cont_dev[key] = merged_devset_dict[key]['matched']


In [20]:
# For every claim keep only the first element of each 'matched' entry
# (presumably the document title; verify against the pickled data).
cont_result = {}
for i in cont_dev:
    cont_result[i] = []
    for j in cont_dev[i]:
        cont_result[i].append(j[0])

In [24]:
def get_topk(results,k):
    """Return a new mapping holding the first ``k`` items of every value.

    Args:
        results: mapping ``key -> sequence``.
        k: slice bound applied as ``value[:k]``.

    Returns:
        New mapping ``key -> value[:k]``.
    """
    return {key: results[key][:k] for key in results}

def get_final(examples,cont_result):
    """Combine all candidate-sentence sources into one selection per claim.

    Per claim, the result is ordered as: the first 6 ``uncase_sent`` entries,
    then the first 30 ``case_sent`` entries not already present, then the
    content-based selection (``topk = 11``) not already present.

    Args:
        examples: an ``Examples`` instance.
        cont_result: mapping ``claim id -> list of content-matched document titles``.

    Returns:
        Mapping ``claim id -> merged list of sentence identifiers``.
    """
    claims = examples.get_ori()
    titles = examples.get_case_title()
    content_based = sent_selection_cont(cont_result, titles, claims, topk = 11)
    case_sentences = examples.get_case_sent_list(30)
    uncase_sentences = examples.get_uncase_sent_list(6)
    title_based = merged_result(uncase_sentences, case_sentences)
    return merged_result(title_based, content_based)

In [25]:
# Build the merged candidate-sentence selection for every test claim
# (slow: the recorded progress bar below shows about 30 minutes).
test_function = get_final(test_examples,cont_result)

100%|██████████| 14997/14997 [30:00<00:00, 12.03it/s] 


In [31]:
# Write the (claim, candidate sentence) rows to test.tsv. claim_id_test[i] holds
# the [claim id, evidence identifier] pair behind output_test[i].
output_test,claim_id_test = output_senten_result(test_function,test,FINAL_DATAPATH+ "test.tsv")

In [854]:
# NOTE(review): `claim_id_dev` is not assigned in any cell visible here (only
# `claim_id_test` is); this cell fails on a fresh kernel. Confirm where it
# comes from before re-running.
with open(FINAL_DATAPATH + "claim_id_dev.json","w") as f:
    json.dump(claim_id_dev,f)

In [54]:
def alleval(predicted,actual,gs=False):
    """Score predictions on label accuracy and sentence/document retrieval.

    Label accuracy is computed over every record in ``actual``. Retrieval
    metrics are computed only for records whose gold label is not
    ``NOT ENOUGH INFO``, using at most the first five predicted evidence
    items. Each metric is averaged over the records that contribute to it.

    An empty ``actual`` now yields an accuracy of ``0`` instead of raising
    ``ZeroDivisionError``, matching how the other ratios are guarded.

    Args:
        predicted: mapping ``claim id -> {'label': str, 'evidence': [[title, sent_no], ...]}``;
            must contain every key of ``actual``.
        actual: mapping with the same record layout holding the gold data.
        gs: when ``True`` return the metrics instead of printing them.

    Returns:
        ``([doc_precision, doc_recall, doc_f1], [precision, recall, f1], accuracy)``
        when ``gs`` is ``True``; otherwise ``None`` after printing.

    Raises:
        AssertionError: if a gold record other than NOT ENOUGH INFO has no evidence.
    """
    NEI = "NOT ENOUGH INFO"
    correct_label = num_instances = 0
    evidence_prec = num_eprec = 0
    evidence_recall = num_erec = 0
    doc_prec = num_dprec = 0
    doc_rec = num_drec = 0

    for ident, arecord in actual.items():
        precord = predicted[ident]

        # Labels are compared case-insensitively.
        alabel = arecord['label'].upper()
        plabel = precord['label'].upper()
        if alabel == plabel:
            correct_label += 1
        num_instances += 1

        if alabel != NEI:
            prec = prec_hits = 0
            rec = rec_hits = 0

            aes = arecord['evidence']
            pes = precord['evidence'][:5]
            # Sentence precision: share of predicted items that are gold.
            for pe in pes:
                if pe in aes:
                    prec += 1
                prec_hits += 1

            # Sentence recall: share of gold items that were predicted.
            for ae in aes:
                if ae in pes:
                    rec += 1
                rec_hits += 1

            # Document precision counts a predicted document once per run of
            # consecutive items with the same title.
            ads = {evidence[0] for evidence in aes}
            last_pd = None
            dp = ndp = 0
            for pe in pes:
                if not last_pd or pe[0] != last_pd:
                    if pe[0] in ads:
                        dp += 1
                    ndp += 1
                last_pd = pe[0]

            # Document recall: share of gold documents among the predicted ones.
            pds = {evidence[0] for evidence in pes}
            dr = ndr = 0
            for ae in ads:
                if ae in pds:
                    dr += 1
                ndr += 1

            if prec_hits > 0:
                evidence_prec += float(prec) / prec_hits
                num_eprec += 1

            if ndp > 0:
                doc_prec += float(dp) / ndp
                num_dprec += 1

            assert rec_hits > 0
            evidence_recall += float(rec) / rec_hits
            num_erec += 1

            assert ndr > 0
            doc_rec += float(dr) / ndr
            num_drec += 1

    accuracy = correct_label / float(num_instances) if num_instances != 0 else 0
    precision = evidence_prec / float(num_eprec) if num_eprec != 0 else 0
    recall = evidence_recall / float(num_erec) if num_erec != 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
    doc_precision = doc_prec / float(num_dprec) if num_dprec != 0 else 0
    doc_recall = doc_rec / float(num_drec) if num_drec != 0 else 0
    doc_f1 = 2 * doc_precision * doc_recall / (doc_precision + doc_recall) if doc_precision + doc_recall > 0 else 0

    doc=[doc_precision,doc_recall,doc_f1]
    sent = [precision,recall,f1]
    if gs == True:
        return doc,sent,accuracy
    else:
        print('Label Accuracy', '\t\t%.2f%%' % (100 * accuracy))
        print('Sentence Precision', '\t%.2f%%' % (100 * precision))
        print('Sentence Recall', '\t%.2f%%' % (100 * recall))
        print('Sentence F1', '\t\t%.2f%%' % (100 * f1))
        print('Document Precision', '\t%.2f%%' % (100 * doc_precision))
        print('Document Recall', '\t%.2f%%' % (100 * doc_recall))
        print('Document F1', '\t\t%.2f%%' % (100 * doc_f1))
def read_tsv_result(path):
    """Read a tab-separated file into a list of rows.

    The file is opened with ``newline=''`` as the csv module requires, so
    line breaks embedded in quoted fields are returned unchanged.

    Args:
        path: path of the TSV file.

    Returns:
        List of rows, each a list of strings.
    """
    with open(path, 'r', newline='') as f:
        return list(csv.reader(f, delimiter='\t'))
def evi_list2str(evi):
    """Encode an evidence pair as the NFC-normalised string ``"<title> <sent_no>"``.

    Args:
        evi: sequence whose first item is the title (``str``) and whose
            second item is the sentence number.

    Returns:
        The combined, NFC-normalised string.
    """
    combined = " ".join((evi[0], str(evi[1])))
    return unicodedata.normalize('NFC', combined)
def get_sent_sel_result(model_output,index, target, th=0.999,k = 5):
    """Turn per-row sentence scores into an evidence list per claim.

    For each claim the ``k`` best-scoring candidates are considered in rank
    order: the best one is always kept, further ones only if their score is
    strictly greater than ``th``.

    The ``"<title> <sent_no>"`` key is now split at its last space, so a
    title that contains a space no longer breaks parsing.

    Args:
        model_output: sequence of rows; ``row[1]`` is the score (anything
            ``float()`` accepts).
        index: sequence parallel to ``model_output``; ``index[i]`` is
            ``[claim id, [title, sent_no]]``.
        target: mapping ``claim id -> {'claim': str, ...}``; every claim id in
            ``index`` must be a key.
        th: score threshold for candidates after the first.
        k: number of top candidates considered per claim.

    Returns:
        Mapping ``claim id -> {'claim': str, 'label': 'SUPPORTS',
        'evidence': [[NFD title, int sent_no], ...]}``.
    """
    scores_by_claim = {claim_id: Counter() for claim_id in target}
    for row_no, row in enumerate(model_output):
        claim_id = index[row_no][0]
        scores_by_claim[claim_id][evi_list2str(index[row_no][1])] = float(row[1])

    result = {}
    for claim_id in target:
        evidence = []
        for key, score in scores_by_claim[claim_id].most_common(k):
            if not evidence or float(score) > th:
                title, sent_no = key.rsplit(" ", 1)
                evidence.append([unicodedata.normalize('NFD', title), int(sent_no)])
        result[claim_id] = {
            'claim': target[claim_id]['claim'],
            # placeholder label
            'label': "SUPPORTS",
            'evidence': evidence,
        }
    return result
def grid_search_sent(result,index,setname):
    """Find the score threshold in [0.900, 0.999] that maximises sentence F1.

    Thresholds are tried in steps of 0.001. Each one is computed from an
    integer step (``(900 + step) / 1000``) rather than by repeatedly adding
    ``0.001``, so the values carry no accumulated rounding error and the
    sweep always covers exactly 100 thresholds.

    Args:
        result: model output rows, passed to ``get_sent_sel_result``.
        index: row index, passed to ``get_sent_sel_result``.
        setname: gold data, used for selection and for ``alleval``.

    Returns:
        The first threshold that gave the highest sentence F1, or ``0`` if no
        threshold produced an F1 above zero.
    """
    best_f1 = 0
    best_th = 0
    for step in range(100):
        candidate = (900 + step) / 1000
        after_model = get_sent_sel_result(result, index, setname, th=candidate)
        doc, sent, acc = alleval(after_model, setname, gs=True)
        # strict ">" keeps the lowest threshold among ties
        if sent[2] > best_f1:
            best_f1 = sent[2]
            best_th = candidate
    return best_th
def getoutput_final(result,setname,wiki=wiki):
    """Build one (claim, evidence text) row per claim.

    Args:
        result: mapping ``claim id -> {'evidence': list, ...}``.
        setname: mapping ``claim id -> {'claim': str, ...}``.
        wiki: corpus object; ``wiki.multi_sents(evidence)`` supplies the
            text for a claim's evidence list.

    Returns:
        ``(output, claim_id)`` where ``output[i]`` is
        ``[NFC claim, evidence text]`` and ``claim_id[i]`` is the claim id of
        that row.
    """
    output = []
    claim_id = []
    for key in result:
        claim_text = unicodedata.normalize('NFC', setname[key]['claim'])
        evidence_text = wiki.multi_sents(result[key]['evidence'])
        output.append([claim_text, evidence_text])
        claim_id.append(key)
    return output,claim_id
def output_final_test(result,setname,path):
    """Write one (claim, evidence text) row per claim to a TSV file.

    The first row is a single-cell placeholder header ``TEST!``. The previous
    version passed the bare string to ``writerow``, which wrote one column
    per character.

    Args:
        result: mapping ``claim id -> {'evidence': list, ...}``.
        setname: mapping ``claim id -> {'claim': str, ...}``.
        path: destination file path.

    Returns:
        ``(output, claim_id)`` exactly as returned by ``getoutput_final``.
    """
    output,claim_id = getoutput_final(result,setname)
    # newline='' is what the csv module expects for files it writes to.
    with open(path, 'wt', newline='') as out_file:
        tsv_writer = csv.writer(out_file, delimiter='\t')
        tsv_writer.writerow(["TEST!"])
        tsv_writer.writerows(output)
    return output,claim_id
def get_re(sent_sel_result_list,dev,claim_id_dev):
    """Pick the evidence for each claim with a tiered confidence rule.

    Among a claim's five best-scoring candidates:
      1. if any score >= 0.999, keep all of those;
      2. else if more than one score >= 0.995, keep all of those;
      3. else if more than two score >= 0.99, keep all of those;
      4. else keep only the best candidate, provided its score > 0.5;
      5. otherwise keep nothing.

    Changes from the previous version: an unused first pass over ``dev`` and
    the unused bookkeeping dicts were removed, and the ``"<title> <sent_no>"``
    key is split at its last space so a title containing a space parses
    correctly.

    Args:
        sent_sel_result_list: sequence of rows; ``row[1]`` is the score.
        dev: mapping ``claim id -> {'claim': str, ...}``.
        claim_id_dev: sequence parallel to ``sent_sel_result_list``; each
            entry is ``[claim id, [title, sent_no]]``.

    Returns:
        Mapping ``claim id -> {'evidence': [[NFD title, int sent_no], ...],
        'claim': str, 'label': 'SUPPORTS'}``.
    """
    def parse_key(key):
        # inverse of evi_list2str, with the title converted to NFD
        title, sent_no = key.rsplit(" ", 1)
        return [unicodedata.normalize('NFD', title), int(sent_no)]

    scores_by_claim = {claim_id: Counter() for claim_id in dev}
    for row_no, row in enumerate(sent_sel_result_list):
        claim_id = claim_id_dev[row_no][0]
        scores_by_claim[claim_id][evi_list2str(claim_id_dev[row_no][1])] = float(row[1])

    new_result = {}
    for claim_id in dev:
        ranked = scores_by_claim[claim_id].most_common(5)
        best = scores_by_claim[claim_id].most_common(1)
        at_999 = [key for key, score in ranked if float(score) >= 0.999]
        at_995 = [key for key, score in ranked if float(score) >= 0.995]
        at_990 = [key for key, score in ranked if float(score) >= 0.99]

        if len(at_999) > 0:
            chosen = at_999
        elif len(at_995) > 1:
            chosen = at_995
        elif len(at_990) > 2:
            chosen = at_990
        elif best and best[0][1] > 0.5:
            chosen = [best[0][0]]
        else:
            chosen = []

        new_result[claim_id] = {
            'evidence': [parse_key(key) for key in chosen],
            'claim': dev[claim_id]['claim'],
            # placeholder label
            'label': "SUPPORTS",
        }
    return new_result

In [60]:
# Read the sentence-scoring results (one row per candidate sentence).
# NOTE(review): this is a hard-coded absolute path under another user's home
# directory rather than one built from FINAL_DATAPATH; adjust it for your machine.
sent_sel_result_list = read_tsv_result("/home/junw/workspace/codelab/data/final"+"/test_results_testset.tsv")

In [97]:
# Choose the evidence per test claim from the scored candidates; claim_id_test
# maps each score row back to its [claim id, evidence identifier] pair.
nr= get_re(sent_sel_result_list,test,claim_id_test)

# FINAL CLASSIFICATION

In [77]:
def getLabel(res):
    """Map a row of class scores to its label.

    The position of the largest score (first one on ties) selects the label:
    0 -> ``REFUTES``, 1 -> ``SUPPORTS``, 2 -> ``NOT ENOUGH INFO``.

    Args:
        res: non-empty iterable of scores (anything ``float()`` accepts).

    Returns:
        The label string, or ``None`` if the largest score is at position 3
        or later.
    """
    scores = [float(value) for value in res]
    best_position = scores.index(max(scores))
    labels = {0: "REFUTES", 1: "SUPPORTS", 2: "NOT ENOUGH INFO"}
    return labels.get(best_position)
def get_final_result(sen_sel_result,final_result,index):
    """Write predicted labels into ``sen_sel_result`` in place.

    Args:
        sen_sel_result: mapping ``claim id -> record``; each record's
            ``'label'`` is overwritten.
        final_result: sequence of score rows, one per claim.
        index: sequence parallel to ``final_result`` giving the claim id of
            each row.

    Returns:
        The same ``sen_sel_result`` object, now updated.
    """
    for position, scores in enumerate(final_result):
        target_id = index[position]
        sen_sel_result[target_id]['label'] = getLabel(scores)
    return sen_sel_result

def final_output(result,path):
    """NFD-normalise the final predictions and write them to ``path`` as JSON.

    Args:
        result: mapping ``claim id -> {'claim': str, 'label': str,
            'evidence': [[title, sent_no], ...]}``.
        path: destination file path.

    Returns:
        The normalised mapping that was written.
    """
    def to_nfd(text):
        return unicodedata.normalize('NFD', text)

    new_final = {}
    for claim_id in result:
        record = result[claim_id]
        normalised = {}
        normalised['claim'] = to_nfd(record['claim'])
        normalised['label'] = to_nfd(record['label'])
        normalised['evidence'] = [[to_nfd(evi[0]), evi[1]] for evi in record['evidence']]
        new_final[claim_id] = normalised

    with open(path,"w") as f:
        json.dump(new_final,f)
    return new_final

In [None]:
# Build and write the (claim, evidence text) rows for the final label classifier.
# NOTE(review): in the cells visible here `nr` is built from `test`, yet `dev`
# is passed as the claim set; confirm which split was intended.
# NOTE(review): output_final_test calls getoutput_final itself and returns the
# same pair, so the first call repeats that work.
output_final, cli = getoutput_final(nr,dev)
output_final_test(nr,dev,FINAL_DATAPATH + "final_test.tsv")

In [1208]:
# Read the per-claim score rows and write the resulting label into each entry of `nr`.
# get_final_result updates `nr` in place and returns it, so `final_result` is the same object as `nr`.
final_result_tsv = read_tsv_result(FINAL_DATAPATH+"/test_results.tsv")
final_result = get_final_result(nr,final_result_tsv,cli)

In [87]:
# NOTE(review): `highrecall` is not assigned in any cell visible here; this
# cell fails on a fresh kernel. Confirm where it comes from.
# NOTE(review): the first line builds rows from `highrecall`, but the file is
# written from `nr`; confirm this is intended.
output_final_rc, cli_rc = getoutput_final(highrecall,test)
output_final_test(nr,test,FINAL_DATAPATH + "final_test_testset.tsv")

([['Birmingham is in France.',
   'Birmingham , Birmingham -LRB- -LSB- ˈbɜːmɪŋəm -RSB- -RRB- is a city and metropolitan borough in the West Midlands , England .'],
  ['Ralph Fults was born in 1993.',
   'Ralph Fults , Ralph Fults -LRB- January 23 , 1911 -- March 16 , 1993 -RRB- was a Depression-era outlaw and escape artist associated with Raymond Hamilton , Bonnie Parker and Clyde Barrow of the Barrow Gang .'],
  ['Margaret Thatcher was a spokesperson for Doritos.', ''],
  ['International students come to the University of Mississippi from 200 nations.',
   'University of Mississippi , About 55 percent of its undergraduates and 60 percent overall come from Mississippi , and 23 percent are minorities ; international students come from 90 nations .'],
  ['Honeymoon is the third perfume line by Lana Del Rey.',
   'Honeymoon -LRB- Lana Del Rey album -RRB- , Honeymoon is the fourth studio album and third major-label record by American singer and songwriter Lana Del Rey .'],
  ['Psych (seaso

In [102]:
# Reads the same test_results.tsv file as the earlier `final_result_tsv` cell.
final_result_tsv_high = read_tsv_result(FINAL_DATAPATH+"/test_results.tsv")

In [103]:
# Write the labels into `nr` using the claim order in `cli_rc`. get_final_result
# updates `nr` in place and returns it, so `final_result_high` is the same object as `nr`.
final_result_high = get_final_result(nr,final_result_tsv_high,cli_rc)