In [None]:
# Debug switch: when True, the main loop only processes entries listed in
# `test_idx`, prints intermediate results, and does not write output files.
debug = False

In [None]:
import pandas as pd
import numpy as np
import os
import torch
from collections import defaultdict
from itertools import combinations

from tqdm import tqdm
tqdm.pandas(desc="progress: ")

from importlib import reload

# Put the directory two levels up on sys.path so the project packages below import
import sys
sys.path.insert(0, '../..')

# var
import var.var as V
import var.path as P

# utils
import utils.data as D
import utils.io as IO
import utils.mmr as MMR
import utils.preprocess as PP

### Utils

In [None]:
def defaultdict_init_defaultdict_init_by_int():
    """Build a new, empty ``defaultdict`` whose missing keys default to ``int()``.

    NOTE(review): presumably defined as a named function (not a lambda) so it
    can serve as the default factory of an outer defaultdict -- confirm with
    the callers.
    """
    int_defaulting = defaultdict(int)
    return int_defaulting

def defaultdict_init_defaultdict_init_by_float():
    """Build a new, empty ``defaultdict`` whose missing keys default to ``float()``.

    NOTE(review): presumably defined as a named function (not a lambda) so it
    can serve as the default factory of an outer defaultdict -- confirm with
    the callers.
    """
    float_defaulting = defaultdict(float)
    return float_defaulting

In [None]:
# Suffix of the BERTopic model name; also a sub-directory of the pseudo-summary output path.
TRAIN_OR_ALL = 'all'
# Name prefix of the topic-aggregate pickle that is loaded further down.
BERTOPIC_MODEL_NAME = "BERTopic_custom_mcs_100_ckip_diversified_low_{}".format(TRAIN_OR_ALL)
# Number of sentence slots per aggregated perspective (unused slots are padded with '').
TOP_K = V.TOP_K
# NOTE(review): not referenced anywhere in the visible part of this notebook.
MAX_SENT_PER_TOPIC = 3
# Cosine-similarity cut-off above which two sentences of one perspective are treated as near-duplicates.
SIM_THRESHOLD = 0.94

In [None]:
import pickle

In [None]:
# Path of the topic-aggregate pickle that belongs to the selected BERTopic model.
fn = os.path.join(P.FP_COMMENT_CLUSTERING_TOPIC_HIERARCHY_DIR,
                  "{}_topic_aggregate_info.pkl".format(BERTOPIC_MODEL_NAME))

# Read the pickle first, then pull out the aggregate-topic mapping.
with open(fn, "rb") as f:
    topic_aggregate_info = pickle.load(f)

topic_aggregate_dict = topic_aggregate_info['topic_aggregate_dict']

In [None]:
# Display the aggregate-topic mapping read from the pickle.
topic_aggregate_dict

In [None]:
# Invert the aggregate mapping: every member topic id points at the id of the
# aggregated topic that contains it (a later aggregate wins on repeats).
topic_to_aggregate_topic_dict = {
    member_tid: aggregate_tid
    for aggregate_tid, member_tids in topic_aggregate_dict.items()
    for member_tid in member_tids
}

topic_to_aggregate_topic_dict

## Generate pseudo summary

In [None]:
import torch

In [None]:
# Device ordinal handed to torch.device in the next cell.
GPU_NUM = 0

In [None]:
# Device the sentence encoder is moved to.
device = torch.device(GPU_NUM)

In [None]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Checkpoint name passed to SentenceTransformer.
SBERT_MODEL_NAME = 'ckiplab/bert-base-chinese'

In [None]:
# Encoder used for every sentence embedding in this notebook
# (MMR ranking and near-duplicate removal).
sbert_model = SentenceTransformer(SBERT_MODEL_NAME).to(device)

In [None]:
import pickle
from itertools import chain
from collections import Counter, OrderedDict

In [None]:
def mmr_sorted(docs, q, lambda_=0.3):
    """Rank `docs` greedily by Maximal Marginal Relevance (MMR).

    At every step the not-yet-selected document that maximises

        lambda_ * q[doc] - (1 - lambda_) * max(cosine similarity to a selected doc)

    is appended to the ranking. The redundancy term is 0 while nothing has been
    selected yet. Embeddings are produced by the notebook-level `sbert_model`.

    Candidates are scanned in first-occurrence order, so score ties are broken
    the same way on every run (iterating a set of strings would depend on the
    interpreter's hash seed).

    Args:
        docs: sequence of document strings; a repeated document is ranked once.
        q: mapping from each document to its pre-computed relevance score.
        lambda_: weight of relevance against redundancy.

    Returns:
        OrderedDict mapping each distinct document to its 0-based rank, in
        selection order. Empty when `docs` is empty.

    Raises:
        KeyError: if a document has no entry in `q`.
    """
    selected = OrderedDict()
    if len(docs) == 0:
        return selected

    docs_embed = sbert_model.encode(docs, batch_size=512, show_progress_bar=False)
    sim_mat = cosine_similarity(docs_embed, docs_embed)
    # A repeated document keeps its first position in the dict but maps to the
    # row of its last occurrence.
    doc_to_idx = {doc: i for i, doc in enumerate(docs)}

    # Distinct documents, in first-occurrence order.
    remaining = list(doc_to_idx)

    # Highest similarity of each candidate to anything selected so far. A
    # candidate has no entry before the first selection, which scores as 0.
    max_sim_to_selected = {}

    def mmr_score(doc):
        return lambda_ * q[doc] - (1 - lambda_) * max_sim_to_selected.get(doc, 0)

    while remaining:
        next_selected = max(remaining, key=mmr_score)
        selected[next_selected] = len(selected)
        remaining.remove(next_selected)

        # Fold the newly selected document into every candidate's running
        # maximum instead of rescanning the whole selection each time.
        chosen_idx = doc_to_idx[next_selected]
        for doc in remaining:
            sim = sim_mat[doc_to_idx[doc], chosen_idx]
            prev = max_sim_to_selected.get(doc)
            max_sim_to_selected[doc] = sim if prev is None else max(prev, sim)

    return selected

In [None]:
# Entry keys to process when `debug` is True (matched against `_idx` in the main loop).
test_idx = [
    "# The content is removed due to confidential concerns."
]

In [None]:
# Root output directory; the pseudo-summary pickles are written into per-year sub-directories of it.
pseudo_summary_dir = os.path.join(P.FP_SIGNIFICANCE_PSEUDO_SUMMARY_DIR, 'custom_bertopic', TRAIN_OR_ALL)
# Directory that is scanned for the input group pickles.
all_data_dir = os.path.join(pseudo_summary_dir, 'all_data')

In [None]:
# Only entries whose year (first element of the entry key) equals this value
# are summarised.
TARGET_YEAR = 112

# Number of candidate entries seen across all group files; counted before the
# year/debug filters are applied.
cnt = 0

for file in tqdm(os.listdir(all_data_dir)):
    group_fp = os.path.join(all_data_dir, file)

    # NOTE(review): this substring test runs against the full path, so a
    # directory component containing 'experiment' would let every file
    # through -- confirm that is intended.
    if 'experiment' not in group_fp:
        continue

    if os.path.isdir(group_fp):
        continue

    print(group_fp)

    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at files produced by this project.
    with open(group_fp, "rb") as f:
        group_data = pickle.load(f)

    ## process group data
    candidate_sents_info_buffer = group_data["candidate_sents_info_buffer"]
    chunk_debug_info_buffer = group_data["chunk_debug_info_buffer"]

    ## extract pseudo summary
    for _idx, info in candidate_sents_info_buffer.items():
        cnt += 1

        _year = _idx[0]
        _id = _idx[1]
        _name = _idx[2]

        if _year != TARGET_YEAR:
            continue

        if debug:
            if _idx not in test_idx:
                continue
            print(_idx)

        sents = info['sents']
        ## not used below; kept available for interactive inspection
        topic_sent_dict = info['topic_sent_dict']
        sents_topic_importance_dict = info['sents_topic_importance_dict']
        ## mmr score scaling: the relevance scores are doubled before ranking
        sents_avg_importance_dict = {
            scored_sent: 2 * score
            for scored_sent, score in info['sents_avg_importance_dict'].items()
        }

        ## not used below; kept available for interactive inspection
        chunk_debug_info = chunk_debug_info_buffer[_idx]

        ## Significance: collect the candidate sentences of each aggregated perspective.
        ## One sentence belongs to exactly one aggregated perspective: the one
        ## with the highest summed importance score.
        sents_aggregate_perspective_dict = defaultdict(list)
        for sent in sents:
            sent_agg_pers_imp_dict = defaultdict(float)

            for tid, imp in sents_topic_importance_dict[sent].items():
                ## topic id -1 is excluded from the aggregation
                if tid == -1:
                    continue
                sent_agg_pers_imp_dict[topic_to_aggregate_topic_dict[tid]] += imp

            ## the sentence only had topic -1 (or no topic at all)
            if not sent_agg_pers_imp_dict:
                continue

            belong_agg_pers_id = max(sent_agg_pers_imp_dict, key=sent_agg_pers_imp_dict.get)
            sents_aggregate_perspective_dict[belong_agg_pers_id].append(sent)

        ## pseudo summary: one list of TOP_K sentence slots per aggregated perspective
        pseudo_summary = []
        for agg_pers_id in topic_aggregate_dict:
            candidate_sent = sents_aggregate_perspective_dict[agg_pers_id]

            ## rank the candidates with MMR and keep the first TOP_K
            sent_mmr_sorted = mmr_sorted(candidate_sent, sents_avg_importance_dict)
            top_sents = list(sent_mmr_sorted)[:TOP_K]

            ## pad with empty sentences so every perspective has TOP_K slots
            perspective_sent = top_sents + [''] * (TOP_K - len(top_sents))
            pseudo_summary.append(perspective_sent)

        if debug:
            IO.print_dividing_line()
            print("before remove similar sentences")
            print(pseudo_summary)
            IO.print_dividing_line()

        ## within each perspective, blank out the shorter sentence of every
        ## pair whose cosine similarity exceeds SIM_THRESHOLD
        for pers_id, pers_sents in enumerate(pseudo_summary):
            if debug:
                print("pers id: ", pers_id)
            pers_sent_embeds = sbert_model.encode(pers_sents, show_progress_bar=False)
            sim_mat = cos_sim(pers_sent_embeds, pers_sent_embeds)

            remove_sent_id_buf = []
            for i, j in combinations(range(TOP_K), 2):
                if not (sim_mat[i, j] > SIM_THRESHOLD):
                    continue

                ## two padding slots are trivially identical; nothing to remove
                if pers_sents[i] == '' and pers_sents[j] == '':
                    continue

                if debug:
                    print(pers_sents[i])
                    print(pers_sents[j])
                    IO.print_dividing_line()

                len_i = PP.get_sent_len(pers_sents[i])
                len_j = PP.get_sent_len(pers_sents[j])

                ## drop the shorter sentence; on equal length drop the later one
                if len_j > len_i:
                    remove_sent_id_buf.append(i)
                else:
                    remove_sent_id_buf.append(j)

            ## blank only after all pairs were compared, so every comparison
            ## sees the original sentences
            for i in remove_sent_id_buf:
                pseudo_summary[pers_id][i] = ''

        if debug:
            print("after remove similar sentences")
            print(pseudo_summary)
            IO.print_dividing_line()

        if not debug:
            write_buffer = {
                "year": _year,
                "id": _id,
                "name": _name,
                "pseudo_summary": pseudo_summary
            }

            ## create the per-year directory on first use instead of failing
            ## with FileNotFoundError on a fresh output tree
            out_dir = os.path.join(pseudo_summary_dir, str(_year))
            os.makedirs(out_dir, exist_ok=True)

            out_fn = "{}.pkl".format("_".join(map(str, [_year, _id])))
            out_fp = os.path.join(out_dir, out_fn)

            with open(out_fp, "wb") as f:
                pickle.dump(write_buffer, f)

cnt

In [None]:
# Display the root output directory of the pseudo summaries.
pseudo_summary_dir