In [None]:
# Debug switch. When True, the summary loop only processes applicants listed in
# `test_info`, prints intermediate results, and the final pickle is NOT written.
debug = False

In [None]:
import pandas as pd
import numpy as np
import os
import torch
from collections import defaultdict
from itertools import combinations

from tqdm import tqdm
tqdm.pandas(desc="progress: ")

from importlib import reload

# Utility variable
import sys
sys.path.insert(0, '../..')

# var
import var.var as V
import var.path as P

# utils
import utils.data as D
import utils.io as IO
import utils.mmr as MMR
import utils.preprocess as PP

## Hyper-parameters

In [None]:
# Sub-directory name used when building the pseudo-summary input paths and the
# summary output path.
TRAIN_OR_ALL = 'train'
# 'val'  -> keep rows whose `train_or_test` column equals 'test'
# 'test' -> keep rows whose `year` equals TEST_YEAR
VAL_OR_TEST = 'val'
# Year used to pick applicants when VAL_OR_TEST == 'test'.
TEST_YEAR = 112
# Number of sentence slots in every summary (short summaries are padded with '').
TOP_K = 5
# Cosine similarity above which two summary sentences are treated as near-duplicates.
SIM_THRESHOLD = 0.94

In [None]:
def defaultdict_init_defaultdict_init_by_int():
    """Return a new, empty ``defaultdict`` whose missing keys default to ``int()`` (0).

    NOTE(review): presumably kept as a named module-level factory so that pickled
    nested defaultdicts can be restored -- confirm against the code that wrote
    the pickles.
    """
    int_counter = defaultdict(int)
    return int_counter

def defaultdict_init_defaultdict_init_by_float():
    """Return a new, empty ``defaultdict`` whose missing keys default to ``float()`` (0.0).

    NOTE(review): presumably kept as a named module-level factory so that pickled
    nested defaultdicts can be restored -- confirm against the code that wrote
    the pickles.
    """
    float_accumulator = defaultdict(float)
    return float_accumulator

## Load data

In [None]:
# Applicant and application tables come from project helpers in `utils.data`
# (their schema is not visible in this notebook).
df_applicants = D.read_df_applicants()
df_applications = D.read_df_applications()
# NOTE(review): path is relative to the current working directory, so this cell
# only works when the notebook is run from the folder that holds the CSV.
test_df = pd.read_csv("112_F.csv")

In [None]:
# Append the extra rows from `test_df` to the applicant table.
# NOTE(review): this reassigns `df_applicants` from itself, so re-running the cell
# without re-running the load cell appends `test_df` again. The original row index
# of both frames is kept (no `ignore_index`), so index labels may repeat.
df_applicants = pd.concat([df_applicants, test_df])
df_applicants

In [None]:
# Attach applicant name and train/test flag to every application row.
# A left join keeps applications that have no matching applicant record.
df_applications_applicants = pd.merge(
    df_applications,
    df_applicants[['year', 'id', 'name', 'train_or_test']],
    how='left',
    on=['year', 'id'],
)
# Applications without a matching applicant get the placeholder name '?'.
# Assign the result back instead of calling `fillna(..., inplace=True)` on the
# column accessor: the chained in-place form operates on a temporary Series and
# leaves the frame untouched when pandas Copy-on-Write is active.
df_applications_applicants['name'] = df_applications_applicants['name'].fillna('?')

In [None]:
# Collect the (year, id, name) key of every application that belongs to the
# evaluation split selected by VAL_OR_TEST:
#   'val'  -> rows whose `train_or_test` column is 'test'
#   'test' -> rows whose `year` equals TEST_YEAR
# Fail fast on any other value: otherwise the list would silently stay empty and
# the notebook would end up writing an empty summary dictionary.
if VAL_OR_TEST not in ('val', 'test'):
    raise ValueError(f"VAL_OR_TEST must be 'val' or 'test', got {VAL_OR_TEST!r}")

test_info_data = []

for _, row in df_applications_applicants.iterrows():
    if VAL_OR_TEST == 'val':
        if row['train_or_test'] != 'test':
            continue
    elif row['year'] != TEST_YEAR:
        continue

    test_info_data.append((row['year'], row['id'], row['name']))

## Generate pseudo summary

In [None]:
import torch

In [None]:
# Device ordinal handed to `torch.device` in the next cell.
GPU_NUM = 0

In [None]:
# Build the torch device from the integer ordinal above.
# NOTE(review): there is no CPU fallback here -- confirm an accelerator is always
# available where this notebook runs.
device = torch.device(GPU_NUM)

In [None]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Model identifier passed to `SentenceTransformer` in the next cell.
# NOTE(review): this looks like a plain BERT checkpoint rather than a
# sentence-transformers model -- confirm the pooling that gets applied is intended.
SBERT_MODEL_NAME = 'ckiplab/bert-base-chinese'

In [None]:
# Sentence encoder used by `mmr_sorted` and by the near-duplicate filter in the
# summary loop; moved to the device selected above.
sbert_model = SentenceTransformer(SBERT_MODEL_NAME).to(device)

In [None]:
import pickle
from itertools import chain
from collections import Counter, OrderedDict

In [None]:
def mmr_sorted(docs, q, lambda_=0.7):
    """Order ``docs`` by Maximal Marginal Relevance (MMR).

    At every step the not-yet-selected document with the highest

        lambda_ * q[doc] - (1 - lambda_) * max(similarity(doc, s) for s in selected)

    is picked next (the similarity term is 0 while nothing is selected yet).
    Similarity is the cosine similarity of the ``sbert_model`` embeddings.

    Args:
        docs: Sequence of hashable documents (sentences). Duplicates are
            collapsed; the first occurrence decides the position used for
            tie-breaking.
        q: Mapping ``doc -> pre-computed relevance score``. It is indexed with
            ``q[doc]``, so a ``defaultdict`` supplies its default for unseen docs
            while a plain ``dict`` raises ``KeyError``.
        lambda_: Trade-off between relevance (1.0) and diversity (0.0).

    Returns:
        ``OrderedDict`` mapping each distinct document to its 0-based rank, in
        selection order. Empty when ``docs`` is empty.

    Notes:
        * Ties are broken by the original order of ``docs``, so the result is
          reproducible across runs (iterating a ``set`` of strings is not).
        * The best similarity to the selected set is updated incrementally, so
          the selection costs O(n^2) instead of rebuilding it at every step.
    """
    # De-duplicate while keeping first-seen order (dicts preserve insertion order).
    unique_docs = list(dict.fromkeys(docs))

    selected = OrderedDict()
    if not unique_docs:
        return selected

    docs_embed = sbert_model.encode(unique_docs, batch_size=512, show_progress_bar=False)
    sim_mat = cosine_similarity(docs_embed, docs_embed)

    relevance = [q[doc] for doc in unique_docs]

    # max_sim[i] is the highest similarity between doc i and any selected doc.
    # None means "nothing selected yet", which contributes 0 to the score; once
    # something is selected the true maximum is used even if it is negative.
    max_sim = [None] * len(unique_docs)
    remaining = list(range(len(unique_docs)))

    def mmr_score(i):
        redundancy = 0 if max_sim[i] is None else max_sim[i]
        return lambda_ * relevance[i] - (1 - lambda_) * redundancy

    while remaining:
        # `max` returns the first maximal element, i.e. the earliest doc on ties.
        best = max(remaining, key=mmr_score)
        selected[unique_docs[best]] = len(selected)
        remaining.remove(best)

        for i in remaining:
            sim = sim_mat[i, best]
            if max_sim[i] is None or sim > max_sim[i]:
                max_sim[i] = sim

    return selected

In [None]:
# Applicants to inspect when `debug` is True: the summary loop skips every
# `info` tuple that is not contained in this list.
# NOTE(review): the real entries were removed; the placeholder string below never
# equals a (year, id, name) tuple, so with debug=True nothing is processed.
test_info = [
    "# The content is removed due to confidential concerns."
]

In [None]:
## Load the per-applicant uniqueness results and merge them into one dictionary
uniqueness_pseudo_summary_dir = os.path.join(
    P.FP_UNIQUENESS_PSEUDO_SUMMARY_DIR, 'custom_bertopic', TRAIN_OR_ALL
)
uniqueness_all_data_dir = os.path.join(uniqueness_pseudo_summary_dir, 'all_data')
uniqueness_debug_buffer = {}

for file in tqdm(os.listdir(uniqueness_all_data_dir)):
    fn = os.path.join(uniqueness_all_data_dir, file)

    # Only non-directory entries whose full path mentions 'uniqueness' are read.
    # NOTE(review): the test is on the joined path, not on the bare file name, so
    # a directory component containing 'uniqueness' makes it always pass -- confirm
    # that this is intended.
    if not os.path.isdir(fn) and 'uniqueness' in fn:
        # These pickles are expected to be produced by this project; do not point
        # this loader at untrusted files.
        with open(fn, 'rb') as f:
            buffer = pickle.load(f)

        # Later files win on key collisions.
        uniqueness_debug_buffer.update(buffer)

In [None]:
## Load the significance results; only the two sub-dictionaries below are kept
significance_pseudo_summary_dir = os.path.join(
    P.FP_SIGNIFICANCE_PSEUDO_SUMMARY_DIR, 'custom_bertopic', TRAIN_OR_ALL
)
significance_all_data_dir = os.path.join(significance_pseudo_summary_dir, 'all_data')
significance_chunk_debug_buffer = {}
significance_sents_info_buffer = {}

for file in tqdm(os.listdir(significance_all_data_dir)):
    fn = os.path.join(significance_all_data_dir, file)

    # Sub-directories are ignored; every other entry is treated as a pickle.
    if not os.path.isdir(fn):
        # These pickles are expected to be produced by this project; do not point
        # this loader at untrusted files.
        with open(fn, 'rb') as f:
            buffer = pickle.load(f)

        # Later files win on key collisions.
        significance_chunk_debug_buffer.update(buffer['chunk_debug_info_buffer'])
        significance_sents_info_buffer.update(buffer['candidate_sents_info_buffer'])

In [None]:
# Sanity check: number of entries merged into the chunk debug buffer.
len(significance_chunk_debug_buffer)

In [None]:
# Sanity check: number of entries merged into the sentence info buffer.
len(significance_sents_info_buffer)

In [None]:
# NOTE(review): bare reference to the function object -- it only displays its repr
# and has no effect. Looks like a leftover scratch cell; consider deleting it.
os.path.exists

In [None]:
# Directory that receives the final uniqueness summary pickle.
uniqueness_summary_docx_dir = os.path.join(
    P.FP_UNIQUENESS_SUMMARY_DIR, TRAIN_OR_ALL,
)

# `makedirs(..., exist_ok=True)` also creates missing parent directories and does
# not fail if the directory appears between an existence check and the creation.
os.makedirs(uniqueness_summary_docx_dir, exist_ok=True)

In [None]:
from sklearn.preprocessing import StandardScaler

# NOTE(review): `scaler` is not referenced by any other cell visible in this
# notebook (scores are min-max normalised by hand in the summary loop) -- confirm
# whether it is still needed.
scaler = StandardScaler()

In [None]:
# Build, for every evaluation applicant, a TOP_K-slot "uniqueness" summary:
#   1. min-max normalise the chunk uniqueness scores,
#   2. max-pool chunk scores onto the first sentence that contains each chunk,
#   3. order sentences with MMR and keep the first TOP_K that are unique enough,
#   4. blank out the shorter sentence of every near-duplicate pair.
# Result: uniqueness_summary_dict[(year, id, name)] -> list of TOP_K strings
# ('' marks an empty slot), or [] when the applicant has no uniqueness scores.

# A sentence needs at least this (normalised, max-pooled) uniqueness to be selected.
MIN_SENT_UNIQUENESS = 0.9

uniqueness_summary_dict = {}

for info in tqdm(test_info_data):
    if debug and info not in test_info:
        continue

    # Only a missing applicant is expected here; any other error should surface
    # instead of being reported as "no test applicant".
    try:
        debug_dict = uniqueness_debug_buffer[info]
    except KeyError:
        print("no test applicant")
        continue

    if debug:
        print(info)

    # [TODO] load sents after fix the bug
#     sents = debug_dict['sents']
    sents = significance_sents_info_buffer[info]['sents']

    chunks = debug_dict['chunks']
    uniqueness_score = debug_dict['uniqueness_score']
    # [TODO] load iaf and ccr score after fix the bug
#     iaf_score = debug_dict['iaf_score']
#     ccr_score = debug_dict['ccr_score']

    if len(uniqueness_score) == 0:
        uniqueness_summary_dict[info] = []
        continue

    ## Normalize uniqueness score to [0, 1]
    uniqueness_score = np.asarray(uniqueness_score)
    _min = np.min(uniqueness_score)
    _max = np.max(uniqueness_score)
    score_range = _max - _min
    if score_range > 0:
        uniqueness_score = (uniqueness_score - _min) / score_range
    else:
        # All scores are equal (e.g. a single chunk): dividing by a zero range
        # would yield NaN, and NaN slips through the threshold below because
        # every comparison with NaN is False. No chunk stands out, so map all
        # of them to 0 (same convention as min-max scaling of a constant column).
        uniqueness_score = np.zeros(uniqueness_score.shape, dtype=float)

    ## Uniqueness: select top-k sentence from outliers with MMR
    summary = []

    sent_unique_dict = defaultdict(float)
    ## Aggregate sentence uniqueness score over chunk uniqueness score
    ## Method 1: [MAX Pool]
    for chunk, uniq in zip(chunks, uniqueness_score):
        if uniq == 0:
            continue

        ## find the first sentence containing the chunk
        for sent in sents:
            if chunk in sent:
                ## use the max chunk score as the sentence uniqueness score
                sent_unique_dict[sent] = max(uniq, sent_unique_dict[sent])
                break

    ## Method 2: [Mean Pool]
    ## find the sentence containing the chunk
#     for sent in sents:
#         chunk_cnt = 0
#         agg_unique = 0

#         for chunk, uniq in zip(chunks, uniqueness_score):
#             if chunk in sent:
#                 agg_unique += uniq
#                 chunk_cnt += 1

#         try:
#             ## aggregate max pool and mean pool
#             sent_unique_dict[sent] = (sent_unique_dict[sent] + agg_unique / chunk_cnt) / 2
#         except:
#             sent_unique_dict[sent] = 0

    sent_mmr_sorted = mmr_sorted(sents, sent_unique_dict)

    for sent in sent_mmr_sorted.keys():
        if len(summary) == TOP_K:
            break
        ## [TODO] remove sentence with too low uniqueness and iaf value
        if sent_unique_dict[sent] < MIN_SENT_UNIQUENESS:
            continue
        summary.append(sent)

        if debug:
            print(sent, sent_unique_dict[sent])

    ## pad with empty sentences so that every summary has exactly TOP_K slots
    summary.extend([''] * (TOP_K - len(summary)))

    if debug:
        IO.print_dividing_line()
        print("before remove similar sentences")
        print(summary)
        IO.print_dividing_line()

    ## remove sentences whose semantic similarity exceeds SIM_THRESHOLD
    summary_sent_embeds = sbert_model.encode(summary, show_progress_bar=False)
    sim_mat = cos_sim(summary_sent_embeds, summary_sent_embeds)

    similar_pair = [
        (i, j)
        for i, j in combinations(range(TOP_K), 2)
        if sim_mat[i, j] > SIM_THRESHOLD
    ]

    remove_sent_id_buf = []
    ## of each similar pair, remove the shorter sentence (the later one on ties)
    for i, j in similar_pair:
        # two padding slots are trivially "similar"; nothing to remove
        if summary[i] == '' and summary[j] == '':
            continue

        len_i = PP.get_sent_len(summary[i])
        len_j = PP.get_sent_len(summary[j])

        if len_j > len_i:
            remove_sent_id_buf.append(i)
        else:
            remove_sent_id_buf.append(j)

    # Blank the slots instead of deleting them so the summary keeps TOP_K entries.
    for i in remove_sent_id_buf:
        summary[i] = ''

    if debug:
        print("after remove similar sentences")
        print(summary)
        IO.print_dividing_line()

    uniqueness_summary_dict[info] = summary

In [None]:
# Sanity check: number of applicants that received a summary entry.
len(uniqueness_summary_dict)

In [None]:
# Show the output directory the summary pickle is written to.
uniqueness_summary_docx_dir

In [None]:
# Persist the summaries only on a real run; with debug=True the dictionary covers
# just the applicants in `test_info` and must not overwrite the full result.
if not debug:
    fp = os.path.join(uniqueness_summary_docx_dir, 'uniqueness_summary_dict.pkl')

    # Keys are (year, id, name) tuples, values are lists of summary sentences.
    with open(fp, "wb") as f:
        pickle.dump(uniqueness_summary_dict, f)