In [30]:
# load generated file
# return: 1. file of generated sentence/postags
#         2. a map of src to generated sentence/postags

ROOT_PATH = "./working/c2c_char_level/topic_cVAE.with_cluster_id.with_wordcut/run1510563558/"
input_file = ROOT_PATH + "test.txt"

import jieba.posseg as pseg

def preprocess(gen_sen):
    EOS_SET = set(['.', '。', '?', '？', "!", "！"])
    words = []
    postags = []
    for w, p in pseg.cut("".join(gen_sen.split(" ")).decode("utf8")):
        words.append(w)
        postags.append(p)
        if w in EOS_SET:
            break
    return " ".join(words).encode("utf8"), " ".join(postags).encode("utf8")

def load_gen_comments(gen_file):
    gen_result = {}
    words_file = open(gen_file + ".words", 'w')
    postags_file = open(gen_file + ".postags", 'w')
    with open(gen_file, 'r') as fin:
        for line in fin:
            line = line.strip()
            if line.startswith("Batch"):
                src = ""
            elif line.startswith("Src"):
                src += line[11:-4] + "\n"
            elif line.startswith("Sample"):
                gen = line.split(">>", 1)[1].strip()
                if "<unk>" in gen:
                    continue
                words, postags = preprocess(gen)
                gen_result.setdefault(src, [])
                gen_result[src].append((words, postags))
                print >> words_file, words
                print >> postags_file, postags
    return gen_result

gen_comments = load_gen_comments(input_file)

In [31]:
def load_lm_score(raw_file, score_file):
    raw_data = [line.strip() for line in open(raw_file, 'r')]
    score_data = [line.strip() for line in open(score_file, 'r')]
    assert(len(raw_data) == len(score_data))
    
    return zip(raw_data, score_data)

def lm_filter(lm_score, pos_thre, eos_thre, seq_thre):
    """
    pos_thre: threshold for each position
    eos_thre: threshold of eos position
    seq_thre: threshold for each sequence
    return: a dict of valid sequence and its score
    """
    keeped_seqs = {}
    for seq, score in lm_score:
        if score == "OOV":
            continue
            
        keep = True
        score = map(float, score.split("\t"))
        for s in score[:-2]:
            if s < pos_thre:
                keep = False
                break
        if score[-2] < eos_thre:
            keep = False
        if score[-1] / (len(score) - 1) < seq_thre:
            keep = False
        
        if keep:
            keeped_seqs[seq] = score[-1] / (len(score) - 1)
    return keeped_seqs

In [None]:
# caculate the RNNLM score of words (in the cmd line)

# caculate the RNNLM score of postags (in the cmd line)

In [46]:
# load the score and do filtering

words_lm_score = load_lm_score(ROOT_PATH + "test.txt.words", ROOT_PATH + "test.txt.words.prob")
words_lm_keeped = lm_filter(words_lm_score, pos_thre=-3, eos_thre=-1, seq_thre=-2)

print "filter ration: %.2f%%(%d/%d)" % (len(words_lm_keeped) * 1.0 / len(words_lm_score), 
                                        len(words_lm_keeped), len(words_lm_score))

filter ration: 0.00%(108/117247)


In [44]:
# load the score and do filtering

postags_lm_score = load_lm_score(ROOT_PATH + "test.txt.postags", ROOT_PATH + "test.txt.postags.prob")

postags_lm_keeped = lm_filter(postags_lm_score, pos_thre=-1.5, eos_thre=-1, seq_thre=-1.25)

print "filter ration: %.2f%%(%d/%d)" % (len(postags_lm_keeped) * 1.0 / len(postags_lm_score), 
                                        len(postags_lm_keeped), len(postags_lm_score))

filter ration: 0.01%(1140/117247)


In [47]:
# output the final result

total_num = 0
gen_num = 0
with open(ROOT_PATH + "final_result.txt", 'w') as fout:
    for src, gens in gen_comments.items():
        total_num += 1
        valid_gens = []
        for words, postags in gens:
            if words in words_lm_keeped and \
                postags in postags_lm_keeped:
                valid_gens.append((words, words_lm_keeped[words]))
        if len(valid_gens) == 0:
            continue
        gen_num += 1
        valid_gens = sorted(valid_gens, key=lambda x: x[1], reverse=True)
        print >> fout, "src: %s" % src.strip()
        for i, (gen, score) in enumerate(valid_gens):
            print >> fout, "gen%02d: %s (%.2f)" % (i, gen, score)
        print >> fout, ""
        
print "Cover: %.2f%% (%d/%d)" %(gen_num * 1.0 / total_num, gen_num, total_num)

Cover: 0.05% (16/355)
