# README

### Purpose of this notebook
- Predict the topic of each application sentence with respect to the comment clusters using the K-Nearest Neighbors (KNN) algorithm.
- Build the summary from these predictions.

### Steps
1. Obtain the reduced embeddings and cluster labels of the comment sentences.
2. Apply the KNN algorithm to predict the label and its confidence for each application sentence.
3. Generate the summary from the results of step 2.

In [None]:
import pandas as pd
import numpy as np
import os
import torch
from collections import defaultdict
from itertools import combinations, chain
import pickle
import copy
from docx import Document
from zhon import hanzi

import re
from zhon import hanzi

from tqdm import tqdm
tqdm.pandas(desc="progress: ")

from importlib import reload

# Utility variable
import sys, getopt
sys.path.insert(0, '../..')

# var
import var.var as V
import var.path as P

# utils
import utils.data as D
import utils.io as IO
import utils.preprocess as PP

In [None]:
import torch
from torch import Tensor

## Process Command Line Arguments

In [None]:
# Parse command-line options.
# NOTE(review): '-f' takes an argument that is ignored and only switches on
# debug mode; presumably this lets the '-f <connection file>' argument that
# Jupyter passes to its kernel parse cleanly (enabling debug when run
# interactively) — confirm.
opts, args = getopt.getopt(sys.argv[1:], "d:e:m:l:s:f:t")

In [None]:
# Defaults; the options parsed above override some of them.
SIGNIFICANCE_MODEL_SAVE_DIR_NAME = "significance_pHAN_cmt_cos_dist_w_cmt_aug_train_2022-12-12_mixed_1"  # also selects split / augmentation below
EPOCH = 20  # part of the summary input/output directory names
MAX_SENT = 8  # part of the summary input/output directory names
UNI_MAX_SENT = 2  # max number of uniqueness sentences appended per applicant
MAX_SENT_LEN = 60  # sentences longer than this (get_sent_len units) are trimmed
LAMBDA = 0.3  # importance-vs-novelty weight used when trimming long sentences
NORM_RATIO = 4  # scale applied to the novelty term of that score
VAL_OR_TEST = 'val'  # which years are evaluated: 'val' or 'test'
debug = False  # print intermediate results instead of writing .docx files
EVIDENCE_SCORE_THRESHOLD = 0.75  # evidence score above which a sentence is tagged "已驗證"

for opt, arg in opts:
    if opt == '-d':  # -d <dir name>: significance model run
        SIGNIFICANCE_MODEL_SAVE_DIR_NAME = arg
    elif opt == '-e':  # -e <int>: epoch
        EPOCH = int(arg)
    elif opt == '-s':  # -s <int>: max sentences
        MAX_SENT = int(arg)
    elif opt == '-m':  # -m <int>: max sentence length
        MAX_SENT_LEN = int(arg)
    elif opt == '-l':  # -l <float>: lambda
        LAMBDA = float(arg)
    elif opt == '-t':  # -t: evaluate on the test years
        VAL_OR_TEST = 'test'
    elif opt == '-f':  # -f <ignored>: debug mode
        debug = True

In [None]:
# Derive the data split and the comment-augmentation flag from the model
# directory name. These are plain substring checks: 'train' takes precedence
# over 'all', and any 'wo' substring disables augmentation.
if 'train' in SIGNIFICANCE_MODEL_SAVE_DIR_NAME:
    TRAIN_OR_ALL = 'train'
elif 'all' in SIGNIFICANCE_MODEL_SAVE_DIR_NAME:
    TRAIN_OR_ALL = 'all'
else:
    # Fail here with a clear message instead of an opaque NameError when
    # TRAIN_OR_ALL is first used to build data paths further down.
    raise ValueError(
        "Cannot infer TRAIN_OR_ALL from model dir name {!r}: expected it to "
        "contain 'train' or 'all'".format(SIGNIFICANCE_MODEL_SAVE_DIR_NAME)
    )

# NOTE(review): substring test — a name containing 'wo' inside another word
# would also disable augmentation; confirm the naming convention.
COMMENT_AUGMENTATION = 'wo' not in SIGNIFICANCE_MODEL_SAVE_DIR_NAME

### BERT NSP Model

In [None]:
# Device for both BERT models. torch.device() with an int index denotes a
# CUDA device, so this cell expects a GPU.
GPU_NUM = 0
device = torch.device(GPU_NUM)

In [None]:
# Chinese BERT used for next-sentence prediction (the clause-coherence check
# in is_reasonable_sent).
BERT_NSP_MODEL_NAME = 'bert-base-chinese'
BERT_NSP_TOKENIZER_NAME = 'bert-base-chinese'

In [None]:
from transformers import BertTokenizerFast, BertForNextSentencePrediction

In [None]:
bert_nsp_tokenizer = BertTokenizerFast.from_pretrained(BERT_NSP_TOKENIZER_NAME)
bert_nsp_model = BertForNextSentencePrediction.from_pretrained(BERT_NSP_MODEL_NAME).to(device)

In [None]:
# Encoder for sentence embeddings (similarity / novelty scores).
# NOTE(review): this is a plain BERT checkpoint rather than a packaged
# sentence-transformers model, so sentence-transformers builds a default
# pooling head for it — confirm that is intended.
SBERT_MODEL_NAME = 'ckiplab/bert-base-chinese'

In [None]:
from sentence_transformers import SentenceTransformer

sbert_model = SentenceTransformer(SBERT_MODEL_NAME).to(device)

### Utils

In [None]:
# Named module-level factories (instead of lambdas) keep nested defaultdicts
# built from them picklable, since lambdas cannot be pickled.

def defaultdict_init_defaultdict_init_by_int():
    """Return a new empty ``defaultdict`` whose missing values default to ``0``."""
    counter_like = defaultdict(int)
    return counter_like

def defaultdict_init_defaultdict_init_by_float():
    """Return a new empty ``defaultdict`` whose missing values default to ``0.0``."""
    score_table = defaultdict(float)
    return score_table

def defaultdict_init_defaultdict_init_by_str():
    """Return a new empty ``defaultdict`` whose missing values default to ``''``."""
    text_table = defaultdict(str)
    return text_table

## Build the reference citation dictionary

In [None]:
# Load the recommendation letters (project loader). Columns used below:
# 'year', 'id', 'all_paragraph_sent', 'info'.
df_recommendation_letters = D.read_df_recommendation_letters()

In [None]:
df_recommendation_letters.tail()

In [None]:
# Map (year, id) -> {recommendation-letter sentence -> citation text}.
# The citation text is later appended as a tag to summary sentences that were
# taken verbatim from a letter.
rl_info_dict = defaultdict(defaultdict_init_defaultdict_init_by_str)

for _, row in df_recommendation_letters.iterrows():
    _year = int(row['year'])
    _id = int(row['id'])

    rl_sent = row['all_paragraph_sent']  # assumes an iterable of sentences — TODO confirm
    # assumes row['info'] is a list of strings; a plain str would be joined per character
    info = "，".join(row['info'])
    
    # Letters without citation info are skipped, so stored values are never ''.
    if info == "":
        continue
        
    sent_info_dict = defaultdict_init_defaultdict_init_by_str()
    
    for sent in rl_sent:
        sent_info_dict[sent] = info
        
    # Dict union (Python 3.9+): several letters of one applicant are merged;
    # for a sentence present in more than one letter, the later letter wins.
    rl_info_dict[(_year, _id)] = rl_info_dict[(_year, _id)] | sent_info_dict    

## Calculate the evidence score for each sentence

In [None]:
# Input directories of the pseudo-summary pickles for the chosen split.
significance_pseudo_summary_dir = os.path.join(P.FP_SIGNIFICANCE_PSEUDO_SUMMARY_DIR, 'custom_bertopic', TRAIN_OR_ALL)
significance_all_data_dir = os.path.join(significance_pseudo_summary_dir, 'all_data')

# NOTE(review): the uniqueness directories are not referenced in the visible code.
uniqueness_pseudo_summary_dir = os.path.join(P.FP_UNIQUENESS_PSEUDO_SUMMARY_DIR, 'custom_bertopic', TRAIN_OR_ALL)
uniqueness_all_data_dir = os.path.join(uniqueness_pseudo_summary_dir, 'all_data')

In [None]:
# Build sent_evidence_score_dict[info][sentence] = highest evidence score among
# the chunks that occur (as substrings) in that candidate sentence.
# Sentences without a matching chunk are absent; lookups then return 0.0.
i = 0  # number of processed files (only used by the commented-out early break)
sent_evidence_score_dict = defaultdict(defaultdict_init_defaultdict_init_by_float)

for file in tqdm(os.listdir(significance_all_data_dir)):
    fn = os.path.join(significance_all_data_dir, file)
    
#     IO.print_dividing_line()
#     if i >= 1:
#         break
    
    # skip sub-directories (e.g. notebook checkpoints)
    if os.path.isdir(fn):
        continue

#     print(fn)
        
    # NOTE: pickle.load executes arbitrary code; these files must be trusted.
    with open(fn, "rb") as f:
        group_data = pickle.load(f)

    ## process group data
    candidate_sents_info_buffer = group_data["candidate_sents_info_buffer"]
    chunk_debug_info_buffer = group_data["chunk_debug_info_buffer"]
    
#     print(candidate_sents_info_buffer)
#     print(chunk_debug_info_buffer)

    # `info` is the applicant key; `debug_info` holds parallel per-chunk lists.
    for info, debug_info in chunk_debug_info_buffer.items():
        buffer_dict = defaultdict_init_defaultdict_init_by_float()
        
#         print(info, debug_info)
        sents = candidate_sents_info_buffer[info]['sents']
        chunks = debug_info['chunks']
        chunk_evidence_scores = debug_info['evidence_score']
    
#         print(info)
#         print(sents)
#         print(chunks)
#         print(chunk_evidence_scores)
        
        for chunk, chunk_evidence_score in zip(chunks, chunk_evidence_scores):
            # find corresponding sentences
            for sent in sents:
                ## aggregate sent evidence score (max over contained chunks)
                if chunk in sent:
                    buffer_dict[sent] = max(
                        buffer_dict[sent], chunk_evidence_score
                    )
                    
        # an applicant key seen again in a later file overwrites the earlier result
        sent_evidence_score_dict[info] = buffer_dict
    
    i += 1

In [None]:
len(sent_evidence_score_dict)

## All data

In [None]:
# Merge every group file's buffers into two flat dicts keyed by applicant
# `info`; all_chunk_debug_info_buffer supplies the chunk info used when
# trimming long sentences.
# NOTE(review): this re-reads the same pickle files as the evidence-score cell;
# on duplicate keys the later file wins.
all_candidate_sents_info_buffer = {}
all_chunk_debug_info_buffer = {}

for file in tqdm(os.listdir(significance_all_data_dir)):
    fn = os.path.join(significance_all_data_dir, file)
    
    if os.path.isdir(fn):
        continue
        
    with open(fn, "rb") as f:
        group_data = pickle.load(f)
        
    candidate_sents_info_buffer = group_data["candidate_sents_info_buffer"]
    chunk_debug_info_buffer = group_data["chunk_debug_info_buffer"]
    
    # in-place dict union (Python 3.9+)
    all_candidate_sents_info_buffer |= candidate_sents_info_buffer
    all_chunk_debug_info_buffer |= chunk_debug_info_buffer

In [None]:
len(all_candidate_sents_info_buffer)

In [None]:
len(all_chunk_debug_info_buffer)

## Long sentence post process utils

In [None]:
from sentence_transformers.util import cos_sim

In [None]:
def get_sent_len(s):
    """Return the length of ``s`` in the units used for the summary length limit.

    Each run of ASCII letters/digits/underscores counts as one unit (one word
    or number). Each Chinese character or Chinese punctuation mark, as defined
    by ``zhon.hanzi``, counts as one unit. Whitespace and all other characters
    are not counted.

    Args:
        s: Sentence to measure.

    Returns:
        The number of counted units.
    """
    # Raw strings: '\s' in a normal literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    re_alphanumeric = r'[a-zA-Z0-9_]+'
    re_ch_p = '[{}]'.format(hanzi.characters + hanzi.punctuation)

    ## every English word / number token counts as one unit
    length = len(re.findall(re_alphanumeric, s))
    s = re.sub(re_alphanumeric, '', s)

    ## remove whitespace (Unicode-aware, so full-width spaces too)
    s = re.sub(r'\s', '', s)

    ## every Chinese character / punctuation mark counts as one unit
    length += len(re.findall(re_ch_p, s))

    return length

In [None]:
def find_longest_consecutive_sequence(seq):
    """Return the longest run of consecutive integers (step +1) in ``seq``.

    ``seq`` is scanned left to right and cut wherever an element is not
    exactly one greater than its predecessor. The longest resulting run is
    returned; among equally long runs the earliest one wins.

    Args:
        seq: Sequence of integers (e.g. matched chunk indices).

    Returns:
        A new list with the longest run. An empty ``seq`` yields ``[]``
        (previously an AssertionError), so a sentence that matched no chunk
        does not crash the caller.
    """
    if not seq:
        return []

    runs = []
    current = [seq[0]]
    for prev, num in zip(seq, seq[1:]):
        if num == prev + 1:
            current.append(num)
        else:
            runs.append(current)
            current = [num]
    runs.append(current)

    # max() returns the first of several equally long runs.
    return max(runs, key=len)

In [None]:
def is_reasonable_sent(chunks, tokenizer, model, debug=False):
    """Check with BERT next-sentence prediction whether ``chunks`` read coherently.

    The check builds text pairs as follows. For every contiguous sub-list of
    at least two chunks, and every split point inside it, the left part and
    the right part are each joined with '，' and terminated with '。'. Every
    resulting pair is classified by ``model``. The clauses count as a
    reasonable sentence only if every pair is predicted as label 0. For
    ``BertForNextSentencePrediction``, label 0 means "B continues A".

    Args:
        chunks: List of clause strings in sentence order.
        tokenizer: HF tokenizer called with a list of (text_a, text_b) pairs.
        model: NSP model whose output has ``.logits`` of shape (num_pairs, 2).
        debug: Print the raw outputs, the pairs and the predictions.

    Returns:
        ``True`` if every pair is predicted as a continuation. Fewer than two
        chunks leave nothing to judge and are treated as reasonable (instead
        of sending an empty batch to the tokenizer).
    """
    n_chunks = len(chunks)
    if n_chunks < 2:
        return True

    ## every start < pivot < end splits chunks[start:end] into two non-empty halves
    text_pair = []
    for start_idx, end_idx in combinations(range(n_chunks + 1), 2):
        if end_idx - start_idx < 2:
            continue
        for pivot in range(start_idx + 1, end_idx):
            left = '，'.join(chunks[start_idx:pivot]) + '。'
            right = '，'.join(chunks[pivot:end_idx]) + '。'
            text_pair.append((left, right))

    inputs = tokenizer(text_pair, return_tensors='pt', padding=True)

    # Put inputs where the model lives; fall back to the module-level device.
    target_device = model.device if hasattr(model, 'device') else device
    for key in inputs:
        if isinstance(inputs[key], Tensor):
            inputs[key] = inputs[key].to(target_device)

    with torch.no_grad():
        outputs = model(**inputs)
        results = torch.argmax(outputs.logits, dim=-1)

    if debug:
        print(outputs)
        for p in text_pair:
            print(p)
        print(results)

    # Predictions are 0/1, so a zero sum means every pair is a continuation.
    return int(results.sum()) == 0

In [None]:
hanzi.punctuation

In [None]:
# Characters that separate clauses ("chunks") when long_sentence_post_process
# splits a sentence. The string is placed inside a regex character class, so
# it must not contain ASCII ']', '\\', '^' or '-'.
split_punc = '\n＂＃＄％＆＇＊＋，－／：；＜＝＞＠＼＾＿｀｜～､\u3000、〃〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·！？｡。'

In [None]:
# Bracket / quote characters.
# NOTE(review): not referenced anywhere in the visible code.
parenthesis_punc = '＂（）［］｛｝｟｠｢｣〈〉《》「」『』【】〔〕〖〗〘〙〚〛'

In [None]:
def long_sentence_post_process(s, debug_info, ext_summary, debug=False):
    """Shorten an over-long summary sentence to its best span of clauses.

    ``s`` is split into clauses on ``split_punc``. A candidate is any
    contiguous span of clauses whose total ``get_sent_len`` is at most
    ``MAX_SENT_LEN``. Spans contained in another candidate are dropped.
    Multi-clause spans must also pass the BERT NSP coherence check
    (``is_reasonable_sent``). Each remaining span is scored as::

        LAMBDA * importance + (1 - LAMBDA) * novelty * NORM_RATIO

    In this formula, *importance* is the mean, over the span's clauses, of
    the importance of the first pre-computed chunk (within the longest
    consecutive run of chunks found in ``s``) that contains the clause.
    *Novelty* is one minus the SBERT cosine similarity between the span and
    ``ext_summary``.

    Args:
        s: The sentence to shorten.
        debug_info: Per-applicant chunk info; reads the parallel lists
            ``'chunks'`` and ``'importance'``.
        ext_summary: The rest of the summary as one string, used to favour
            spans that add new information.
        debug: Print intermediate results.

    Returns:
        The best span joined with '，' and ending in '。', or ``''`` when no
        span fits the length limit and passes the coherence check.
    """
    chunks_info = debug_info['chunks']

    ## indices of pre-computed chunks occurring in s, narrowed to the longest
    ## run of consecutive indices
    debug_chunks_idx = [i for i, chunk in enumerate(chunks_info) if chunk in s]
    if debug_chunks_idx:
        # With no match at all every span simply gets importance 0.
        debug_chunks_idx = find_longest_consecutive_sequence(debug_chunks_idx)

    if debug:
        for _idx in debug_chunks_idx:
            print(chunks_info[_idx])

    ## split s directly into its clauses, dropping empty pieces
    cand_chunks = [c for c in re.split('[{}]'.format(split_punc), s) if c]
    if debug:
        print("real chunks:", cand_chunks)

    chunks_len = [get_sent_len(cc) for cc in cand_chunks]

    if debug:
        print("cand chunks:", cand_chunks)

    ## brute force: every contiguous span within the length limit
    f_span_tuple = [
        (start_idx, end_idx)
        for start_idx, end_idx in combinations(range(len(cand_chunks) + 1), 2)
        if sum(chunks_len[start_idx:end_idx]) <= MAX_SENT_LEN
    ]

    ## keep only maximal spans (drop spans contained in another candidate)
    filtered_span_tuple = [
        (start_idx, end_idx)
        for i, (start_idx, end_idx) in enumerate(f_span_tuple)
        if not any(
            _start_idx <= start_idx and end_idx <= _end_idx
            for j, (_start_idx, _end_idx) in enumerate(f_span_tuple)
            if i != j
        )
    ]

    if debug:
        print("spans after brutal force search")
        print(filtered_span_tuple)

    ## remove incoherent spans; single clauses are always kept
    filtered_span_tuple = [
        (start_idx, end_idx)
        for start_idx, end_idx in filtered_span_tuple
        if end_idx - start_idx == 1
        or is_reasonable_sent(cand_chunks[start_idx:end_idx], bert_nsp_tokenizer, bert_nsp_model)
    ]

    if debug:
        print("spans after nsp filter")
        print(filtered_span_tuple)

    if not filtered_span_tuple:
        return ''

    # The summary embedding does not depend on the candidate: encode it once.
    ext_summary_embed = sbert_model.encode(ext_summary, show_progress_bar=False)

    cand_score = []
    for start_idx, end_idx in filtered_span_tuple:
        cc = cand_chunks[start_idx:end_idx]

        ## importance: should contain important info
        cand_importance = 0
        for _c in cc:
            for _idx in debug_chunks_idx:
                if _c in chunks_info[_idx]:
                    cand_importance += debug_info['importance'][_idx]
                    break
        cand_importance /= len(cc)

        ## novelty: should be dissimilar to the existing summary
        cc_embed = sbert_model.encode('，'.join(cc), show_progress_bar=False)
        cand_novel = 1 - cos_sim(cc_embed, ext_summary_embed)[0][0]

        score = LAMBDA * cand_importance + (1 - LAMBDA) * cand_novel * NORM_RATIO
        # Store a plain float so np.argmax works on a simple numeric list
        # (0-d tensors, possibly on GPU, are not reliable input for numpy).
        cand_score.append(float(score))
        if debug:
            print('imp', float(cand_importance))
            print('novel', float(cand_novel))
            print(score)

    final_start_idx, final_end_idx = filtered_span_tuple[int(np.argmax(cand_score))]
    return '，'.join(cand_chunks[final_start_idx:final_end_idx]) + '。'

In [None]:
import string
from zhon import hanzi

In [None]:
def sentence_post_process(s):
    """Clean up a summary sentence for display.

    Steps, in order:
      1. strip surrounding whitespace and any leading ASCII punctuation;
      2. remove the private-use character U+F06C and known join artefacts
         ('。，' and '，。' become '，'; '，nan' and '&lt；' are deleted);
      3. remove list bullets: a single Chinese numeral or digit followed by
         '、', digits followed by '.' that are not part of a decimal number,
         and '★';
      4. strip trailing ASCII digits, then trailing non-stop Chinese
         punctuation (``zhon.hanzi.non_stops``).

    Args:
        s: Raw sentence.

    Returns:
        The cleaned sentence, or ``''`` if nothing is left.
    """
    ## remove leading ASCII punctuation; lstrip with a character set needs no
    ## IndexError guard when the string becomes empty
    s = s.strip().lstrip(string.punctuation).strip()

    ## remove mojibake / join artefacts
    s = s.replace('\uf06c', '')
    s = s.replace('。，', '，')
    s = s.replace('，。', '，')
    s = s.replace('，nan', '')
    s = s.replace('&lt；', '')

    ## remove zh number bullet, e.g. '一、' or '1、'
    ch_number = "一二三四五六七八九十"
    s = re.sub(r'[{}\d]、'.format(ch_number), '', s)

    ## remove number bullet ('1.' but not the '1.' in '1.5') and '★'
    s = re.sub(r'((?<!\d)\d+\.(?!\d)|★)', '', s)

    ## remove trailing digits, then trailing non-stop punctuation
    s = s.strip().rstrip(string.digits).strip()
    s = s.rstrip(hanzi.non_stops).strip()

    return s

In [None]:
# def sentence_post_process(s):
#     s = s.strip()
#     ## remove trailing number
#     while s[0] in string.punctuation:
#         s = s[1:]
#     s = s.strip()
    
#     ## remove mojibake
#     s = s.replace('\uf06c', '')
#     s = s.replace('。，', '，')
#     s = s.replace('，。', '，')
    
#     ## remove zh number bullet
#     ch_number = "一二三四五六七八九十"
#     p = '[{}\d]、'.format(ch_number)
#     s = re.sub(p, '', s)
    
#     ## remove number bullet
#     p = '((?<!\d)\d+\.(?!\d)|★)'
#     s = re.sub(p, '',s)
    
#     s = s.strip()
#     try:
#         ## remove trailing number
#         while s[-1] in string.digits:
#             s = s[:-1]
#     except:
#         return ''
#     s = s.strip()
    
#     s = s.strip()
#     try:
#         ## remove trailing non stop punctuation
#         while s[-1] in hanzi.non_stops:
#             s = s[:-1]
#     except:
#         return ''
#     s = s.strip()
    
#     return s

# Post process the summary

## Load summary dict

### Utils

In [None]:
default_talents = ["# The content is removed due to confidential concerns."]

def generate_docx(doc, sum_sent, talents=default_talents, debug=False):    
    """Write one section per talent into ``doc`` and return ``doc``.

    The function first handles every name in ``talents``. It adds a level-2
    heading, then that talent's sentences from ``sum_sent`` as bullet points,
    writing exact duplicates once in first-seen order. If the talent has no
    sentences, it writes a single "無" paragraph instead. Afterwards, every
    entry of ``default_talents`` that is not in ``talents`` is appended as a
    heading followed by "無".

    Args:
        doc: python-docx ``Document``; only ``add_heading`` and
            ``add_paragraph`` are called.
        sum_sent: Mapping talent -> list of sentences. It is indexed with
            ``[]``, so a ``defaultdict(list)`` gains an empty entry for a
            missing talent.
        talents: Section order.
        debug: Unused; kept for interface compatibility.

    Returns:
        The same ``doc`` object.
    """
    for heading in talents:
        doc.add_heading(heading, level=2)
        talent_sents = sum_sent[heading]

        # dict keeps insertion order, so this is an order-preserving dedupe
        for unique_sent in dict.fromkeys(talent_sents):
            doc.add_paragraph(unique_sent, style='List Bullet')

        if len(talent_sents) == 0:
            doc.add_paragraph("無")

    for heading in default_talents:
        if heading in talents:
            continue
        doc.add_heading(heading, level=2)
        doc.add_paragraph("無")

    return doc

In [None]:
# Applicants to process when `debug` is True (the post-process loop skips any
# `info` not in this list).
# NOTE(review): the real entries were redacted. The loop's `info` keys appear
# to be (year, id, name) tuples, so with only this placeholder string every
# applicant is skipped in debug mode.
debug_tuple = [
    "# The content is removed due to confidential concerns."
]

In [None]:
## STEP 0: load the summary data
# The sub-directory encodes whether the significance model used comment
# augmentation ('w') or not ('wo'), followed by the run name and epoch/max-sent.
_aug_subdir = 'w' if COMMENT_AUGMENTATION else 'wo'
_epoch_subdir = "epoch_{}_max_sent_{}".format(EPOCH, MAX_SENT)

# input: significance summaries; output: post-processed .docx summaries
significance_summary_docx_dir = os.path.join(
    P.FP_SIGNIFICANCE_SUMMARY_DIR, TRAIN_OR_ALL, _aug_subdir, SIGNIFICANCE_MODEL_SAVE_DIR_NAME, _epoch_subdir
)
summary_docx_dir = os.path.join(
    P.FP_SUMMARY_GENERATION_DIR, TRAIN_OR_ALL, _aug_subdir, SIGNIFICANCE_MODEL_SAVE_DIR_NAME, _epoch_subdir
)

# input: uniqueness summaries (independent of the significance model run)
uniqueness_summary_docx_dir = os.path.join(P.FP_UNIQUENESS_SUMMARY_DIR, TRAIN_OR_ALL)

In [None]:
# Create the output directory for the generated .docx files if needed.
if not os.path.exists(summary_docx_dir):
    os.makedirs(summary_docx_dir)

In [None]:
significance_summary_docx_dir

In [None]:
uniqueness_summary_docx_dir

In [None]:
summary_docx_dir

In [None]:
## load significance summary data
# As used below: applicant key -> {'summary': talent -> sentences,
# 'title_weight': talent -> weight}.
fn_significance_summary_dict = os.path.join(significance_summary_docx_dir, "summary_dict.pkl")
with open(fn_significance_summary_dict, 'rb') as f:
    significance_summary_dict = pickle.load(f)
    
## load uniqueness summary data
# As used below: the same applicant keys -> list of uniqueness sentences.
fn_uniqueness_summary_dict = os.path.join(uniqueness_summary_docx_dir, "uniqueness_summary_dict.pkl")
with open(fn_uniqueness_summary_dict, 'rb') as f:
    uniqueness_summary_dict = pickle.load(f)

In [None]:
len(significance_summary_dict), len(uniqueness_summary_dict)

In [None]:
# Post-process every applicant's summary and (outside debug mode) write it to
# a .docx file.
#   pp_summary_dict[info]        talent -> sentences with citation / "已驗證" tags
#   pp_summary_no_tag_dict[info] talent -> the same sentences without tags
pp_summary_dict = {}
pp_summary_no_tag_dict = {}

for info, sig_summary_info in tqdm(significance_summary_dict.items()):
    ## STEP 1: Post process summary
    _year = info[0]
    _id = info[1]
    _name = info[2]

    if debug:
        if info not in debug_tuple:
            continue

    uni_summary = uniqueness_summary_dict[info]

    # NOTE: sig_summary is the object stored in significance_summary_dict and
    # gets uniqueness sentences appended below (assumes it has, or defaults,
    # a '獨特表現' entry — e.g. a defaultdict(list); confirm).
    sig_summary = sig_summary_info['summary']
    title_weight = sig_summary_info['title_weight']

    if debug:
        print(info)
        print("significance summary before post process")
        print(sig_summary)
        print("="*10)

    pp_summary = defaultdict(list)
    pp_summary_no_tag = defaultdict(list)
    # Working copy whose sentences are replaced by their post-processed form;
    # it is the "rest of the summary" context for later sentences.
    buf_summary = copy.deepcopy(sig_summary)

    ## append up to UNI_MAX_SENT uniqueness sentences: each round picks the one
    ## with the lowest mean similarity to the current summary and accepts it
    ## unless it nearly duplicates an existing sentence (max cos sim >= 0.94)
    cnt = 0
    uni_summary = [sent for sent in uni_summary if not PP.is_empty_sent(sent)]
    while cnt < UNI_MAX_SENT and uni_summary:
        buf_summary_sent = list(chain.from_iterable(buf_summary.values()))
        buf_summary_sent = [sent for sent in buf_summary_sent if not PP.is_empty_sent(sent)]

        if buf_summary_sent:
            uni_summary_sent_embed = sbert_model.encode(uni_summary, show_progress_bar=False)
            buf_summary_sent_embed = sbert_model.encode(buf_summary_sent, show_progress_bar=False)

            sim_mat = np.array(cos_sim(uni_summary_sent_embed, buf_summary_sent_embed))
            uni_sent_disimilarity = 1 - np.mean(sim_mat, axis=-1)

            idx = np.argmax(uni_sent_disimilarity)
            max_sim = np.max(sim_mat[idx])

            if debug:
                print("disimilarity", uni_sent_disimilarity[idx])
                print("sent", uni_summary[idx])
                print("max sim", max_sim)
        else:
            # Similarity against an empty summary is undefined (and used to
            # crash on the empty matrix): take the sentences in given order.
            idx = 0
            max_sim = 0.0

        ## append unique sentence that is most disimilar to significance summary
        if max_sim < 0.94:
            buf_summary['獨特表現'].append(uni_summary[idx])
            sig_summary['獨特表現'].append(uni_summary[idx])
            cnt += 1
        # the examined sentence is dropped whether or not it was accepted
        del uni_summary[idx]

    for talent, sum_sents in sig_summary.items():
        for sum_sent in sum_sents:
            _bidx = buf_summary[talent].index(sum_sent)
            rl_infos = rl_info_dict[(_year, _id)]
            sent_evidence = sent_evidence_score_dict[info]

            tag = ""

            if sum_sent in rl_infos.keys():
                ## 1. add citations (only for recommendation-letter sentences)
                rl_info = rl_infos[sum_sent]

                # tag only when there is citation text to show
                if rl_info != '':
                    tag = rl_info
            else:
                ## 2. add verified mark (only for sentences not from letters)
                evidence = sent_evidence[sum_sent]

                if evidence > EVIDENCE_SCORE_THRESHOLD:
                    tag = "已驗證"

            ## trim sentence that is too long
            sum_sent_len = get_sent_len(sum_sent)

            if sum_sent_len > MAX_SENT_LEN:
                if debug:
                    print('sum_sent_to_trim:', sum_sent)

                debug_info = all_chunk_debug_info_buffer[info]

                # context: every other sentence of the (partially processed) summary
                ext_summary = list(chain.from_iterable(buf_summary.values()))
                ext_summary = [sent for sent in ext_summary if sent != sum_sent]
                ext_summary = '。'.join(ext_summary)

                if debug:
                    print('ext_summary:', ext_summary)
                sum_sent = long_sentence_post_process(sum_sent, debug_info, ext_summary)

                if debug:
                    print('sum_sent:', sum_sent, "len:", get_sent_len(sum_sent))
                    print('-'*10)

            sum_sent = sentence_post_process(sum_sent)
            ## replace sum sent in the working copy
            buf_summary[talent][_bidx] = sum_sent

            ## [TODO] process sentence with unclosed parenthesis (remove unreasonable sentence)
            if sum_sent == '':
                continue

            pp_summary_no_tag[talent].append(sum_sent)

            if tag:
                sum_sent = "{}（{}）".format(sum_sent, tag)

            pp_summary[talent].append(sum_sent)

    ## section order: talents by descending title weight, then '獨特表現'
    talent_list = sorted(title_weight.items(), key=lambda i: -i[1])
    talent_list = [t[0] for t in talent_list]
    talent_list.append('獨特表現')

    if debug:
        print('='*10)
        print("summary after post process")
        print(pp_summary)
        IO.print_dividing_line()

    pp_summary_dict[info] = pp_summary
    pp_summary_no_tag_dict[info] = pp_summary_no_tag

    if not debug:
        ## STEP 2: Write post processed summary to docx file
        doc = Document()
        doc = generate_docx(doc, pp_summary, talent_list)

        fn = "{}.docx".format("_".join(map(str, info)))
        doc.save(os.path.join(summary_docx_dir, fn))

## Prepare data for calculating BERTScore

In [None]:
# Reference comments (with grades) and applicant metadata (project loaders).
df_comments = D.read_df_comments()
df_applicants = D.read_df_applicants()

In [None]:
# Collect the per-applicant pseudo-summary pickles of the evaluated years:
# every entry of V.YEAR_DIRS except the last for 'val', only the last for 'test'.
individual_data = []
if VAL_OR_TEST == 'val':
    year_dir = V.YEAR_DIRS[:-1]
elif VAL_OR_TEST == 'test':
    year_dir = V.YEAR_DIRS[-1:]

for year in year_dir:
    _dir = os.path.join(P.FP_SIGNIFICANCE_PSEUDO_SUMMARY_DIR, 'custom_bertopic', TRAIN_OR_ALL, year)
    
    for file in os.listdir(_dir):
        # only Jupyter's checkpoint folder is skipped; other entries are opened as pickles
        if file == '.ipynb_checkpoints':
            continue

        fn = os.path.join(_dir, file)

        with open(fn, "rb") as f:
            data = pickle.load(f)
            individual_data.append(data)

In [None]:
len(individual_data)

In [None]:
# Build the evaluation set as parallel lists, one entry per (applicant, comment):
#   test_info_data           (year, id, name)
#   test_pseudo_summary_data the applicant's pseudo summary
#   test_comment_data        reference comment ('' placeholder in 'test' mode)
#   test_grade_data          the comment's grade ('F' placeholder in 'test' mode)
test_info_data = []
test_pseudo_summary_data = []
test_comment_data = []
test_grade_data = []

if VAL_OR_TEST == 'val':
    for data in tqdm(individual_data):
        _year = data['year']
        _id = data['id']
        _name = data['name']
        pseudo_summary = data['pseudo_summary']

        ## check train or test data; only non-training applicants are evaluated
        row = df_applicants.query('`year` == {} and `id` == {}'.format(_year, _id))
        try:
            train_or_test = row['train_or_test'].to_list()[0]
        except (IndexError, KeyError):
            # IndexError: applicant not in df_applicants; KeyError: no
            # 'train_or_test' column. Both are treated as training data.
            train_or_test = 'train'

        if train_or_test == 'train':
            continue

        ## get corresponding comments
        row = df_comments.query('`year` == {} and `id` == {}'.format(_year, _id))
        comments = row['comment'].to_list()
        grades = row['grade'].to_list()

        ## one entry per non-empty comment
        for comment, grade in zip(comments, grades):
            if PP.is_empty_sent(comment):
                continue

            test_info_data.append((_year, _id, _name))
            test_comment_data.append(comment)
            test_pseudo_summary_data.append(pseudo_summary)
            test_grade_data.append(grade)

elif VAL_OR_TEST == 'test':
    for data in tqdm(individual_data):
        _year = data['year']
        _id = data['id']
        _name = data['name']
        pseudo_summary = data['pseudo_summary']

        ## append data to test data set (no reference comments available)
        test_info_data.append((_year, _id, _name))
        test_comment_data.append('') ## stuff empty comment
        test_pseudo_summary_data.append(pseudo_summary)
        test_grade_data.append('F') ## stuff empty comment

In [None]:
# Number of applicants with a post-processed (untagged) summary.
len(pp_summary_no_tag_dict)

In [None]:
# ## find summary and its corresponding comment
# if VAL_OR_TEST == 'val':
#     test_summary_result = []

#     for applicant_info, pseudo_summary, comment in zip(test_info_data, test_pseudo_summary_data, test_comment_data):
#         applicant_info = tuple(applicant_info)
#         summary = pp_summary_no_tag_dict[applicant_info]
#         buffer = []

#         for sents in summary.values():
#             buffer.append(sents)

#         summary = ''.join(list(chain.from_iterable(buffer)))

#         # concat summary together [TODO] different concata method may result in differenet bertscore?
#         test_summary_result.append(summary)

In [None]:
# if VAL_OR_TEST == 'val':
#     from bert_score import score

In [None]:
# def calculate_bert_score(cands, refs, rescale=False, verbose=False):
#     return score(
#         cands,
#         refs,
#         lang="zh",
#     #     model_type=MODEL_TYPE,
#     #     num_layers=LAYER,
#         verbose=verbose,
#         device=0,
#         batch_size=32,
#     #     idf=False,
#         rescale_with_baseline=rescale
#     )

In [None]:
# if VAL_OR_TEST == 'val':
#     _P, _R, _F1 = calculate_bert_score(test_summary_result, test_comment_data, rescale=False)

In [None]:
# import logging
# from datetime import datetime

# for handler in logging.root.handlers[:]:
#     logging.root.removeHandler(handler)

In [None]:
# if COMMENT_AUGMENTATION:
#     log_file = os.path.join(P.FP_SUMMARY_GENERATION_DIR, TRAIN_OR_ALL, 'w', SIGNIFICANCE_MODEL_SAVE_DIR_NAME, 'summary_post_process_log.log')
# else:
#     log_file = os.path.join(P.FP_SUMMARY_GENERATION_DIR, TRAIN_OR_ALL, 'wo', SIGNIFICANCE_MODEL_SAVE_DIR_NAME, 'summary_post_process_log.log')
    
# logger = logging.getLogger()
# logger.setLevel(logging.INFO)
# # create file handler which logs even debug messages
# fh = logging.FileHandler(log_file)
# fh.setLevel(logging.INFO)
# # create console handler with a higher log level
# ch = logging.StreamHandler()
# ch.setLevel(logging.INFO)
# # add the handlers to logger
# logger.addHandler(ch)
# logger.addHandler(fh)

In [None]:
# if VAL_OR_TEST == 'val':
#     IO.log_dividing_line(logger)
#     logger.info("model dir: {}".format(SIGNIFICANCE_MODEL_SAVE_DIR_NAME))
#     logger.info("time: {}".format(str(datetime.now())))
#     logger.info("epoch: {}".format(EPOCH))
#     logger.info("max sent: {}".format(MAX_SENT))

In [None]:
# if VAL_OR_TEST == 'val':
#     rouge_precision = torch.mean(_P)
#     rouge_recall = torch.mean(_R)
#     rouge_f1 = torch.mean(_F1)

In [None]:
# if VAL_OR_TEST == 'val':
#     logger.info("p: {:4f}".format(rouge_precision))
#     logger.info("r: {:4f}".format(rouge_recall))
#     logger.info("f: {:4f}".format(rouge_f1))
#     IO.log_dividing_line(logger)