In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
from itertools import chain
from tqdm import tqdm
from collections import defaultdict

from importlib import reload

# Utility variable
import sys, getopt
sys.path.insert(0, '../..')

# var
import var.var as V
import var.path as P

# utils
import utils.data as D
import utils.io as IO
import utils.preprocess as PP
import utils.torch as Tor

## Process Command Line Arguments
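- command: `python3 1_comment_articut_word_graph.py <option>`
- options:
    - `-d <model save dir name>`
    - `-e <epoch>`
    - `-m <mmr lambda>`
    - `-b <batch size>`
    - `-s <maximum number of summary sentences>`
    - `-a <minimum combined attention threshold>`
    - `-t`: inference on year 112
    - `-f <value>`: accepted by the parser but ignored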

In [None]:
## parse arguments
# Short options followed by ':' take a value (-d -e -m -b -f -s -a); -t is a bare flag.
# NOTE(review): -f is accepted but never read afterwards -- presumably it absorbs
# the `-f <connection file>` argument Jupyter passes to the kernel; confirm.
opts, args = getopt.getopt(sys.argv[1:], "d:e:m:b:f:s:a:t")

In [None]:
# Default run configuration; each value may be overridden by a command-line option below.
MODEL_SAVE_DIR_NAME = "significance_pHAN_cmt_cos_dist_w_cmt_aug_train_2022-12-09_mixed_1"
EPOCH = 20  # checkpoint epoch to load (file "epoch_{:04d}.pt")
MMR_LAMBDA = 0.7  # NOTE(review): set here and via -m, but not read by any visible cell
BATCH_SIZE = 12  # batch size of the evaluation DataLoader
VAL_OR_TEST = 'val'  # 'val' by default, 'test' when -t is given
MAX_SENT = 8  # cap on the number of sentences in a generated summary
COMBINED_ATT_MIN_THRESHOLD = 0.03  # minimum combined attention for a sentence to be selected
ATT_THRESHOLD_DECAY = 0.005  # threshold reduction per sentence already selected from the same perspective
RL_WEIGHT_BOOST = 1.05  # multiplier applied to sentences that appear in recommendation letters

GPU_NUM = 0  # device index handed to torch.device

# Apply command-line overrides; options without a branch here (e.g. -f) are ignored.
for opt, arg in opts:
    if opt == '-d':
        MODEL_SAVE_DIR_NAME = arg
    elif opt == '-e':
        EPOCH = int(arg)
    elif opt == '-m':
        MMR_LAMBDA = float(arg)
    elif opt == '-b':
        BATCH_SIZE = int(arg)
    elif opt == '-s':
        MAX_SENT = int(arg)
    elif opt == '-a':
        COMBINED_ATT_MIN_THRESHOLD = float(arg)
    elif opt == '-t':
        VAL_OR_TEST = 'test'

In [None]:
# Derive the data split and the perspective titles from the model directory name.
# 'train' takes precedence over 'all' when both substrings are present.
if 'train' in MODEL_SAVE_DIR_NAME:
    perspective_title = V.TRAIN_PERSPECTIVE_TITLE
    TRAIN_OR_ALL = 'train'
elif 'all' in MODEL_SAVE_DIR_NAME:
    perspective_title = V.ALL_PERSPECTIVE_TITLE
    TRAIN_OR_ALL = 'all'
else:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError(
        "MODEL_SAVE_DIR_NAME must contain 'train' or 'all', got: {!r}".format(MODEL_SAVE_DIR_NAME)
    )

# Comment augmentation is considered disabled when the name contains 'wo'.
# NOTE(review): this is a plain substring test, so any directory name that
# happens to contain "wo" also disables augmentation -- confirm the naming scheme.
COMMENT_AUGMENTATION = 'wo' not in MODEL_SAVE_DIR_NAME

## hyper parameters to load
# Placeholders only; load_hyperparams() overwrites them from the checkpoint.
USE_SBERT_EMBED = None
USE_BERT_EMBED = None
NUM_PERSPECTIVE = 0
# TRAIN_OR_ALL = ''
TOP_K = 0
BERT_MODEL_NAME = ''
BERT_TOKENIZER_NAME = ''
BERTOPIC_MODEL_NAME = ''

## pHAN params
BERT_DIM = 0
SENT_DIM = 0
CXT_DIM = 0
PRJ_DIM = 0
COMPRESSION = None
PROJECTION = None
ATTENTION_EMPTY_MASK = None
FREEZE_BERT = None
SENT_TEMPERATURE = 0
PERS_TEMPERATURE = 0
DROPOUT_RATE = 0
LEAKY_RELU_NEG_SLOPE = 0
ENC_BZ = 0

## Load hyperparameters

In [None]:
# Checkpoints live under <phan dir>/<train|all>/<w|wo>/<model dir name>,
# where 'w' / 'wo' stands for with / without comment augmentation.
_augmentation_subdir = 'w' if COMMENT_AUGMENTATION else 'wo'
MODEL_SAVE_DIR_PATH = os.path.join(
    P.FP_SIGNIFICANCE_PHAN_DIR, TRAIN_OR_ALL, _augmentation_subdir, MODEL_SAVE_DIR_NAME
)

In [None]:
import torch

In [None]:
def load_hyperparams(t):
    """Load the hyperparameters stored in the checkpoint of epoch ``t``.

    Reads ``epoch_{t:04d}.pt`` from ``MODEL_SAVE_DIR_PATH`` and copies its
    hyperparameter entries into the module-level configuration variables
    (including ``TRAIN_OR_ALL`` and ``COMMENT_AUGMENTATION``, which replace
    the values derived from the directory name). Also seeds NumPy's global
    RNG with the checkpoint's random state.

    Args:
        t: Epoch number of the checkpoint to read.

    Raises:
        KeyError: If a mandatory hyperparameter is missing from the checkpoint.
        AssertionError: If the checkpoint enables neither SBERT nor BERT embeddings.
    """
    ## find model name
    model_name = "epoch_{:04d}.pt".format(t)
    fn = os.path.join(MODEL_SAVE_DIR_PATH, model_name)
    # NOTE(review): no map_location is given, so tensors are restored onto the
    # device they were saved from; `perspective_mean_embed` is used as loaded.
    checkpoint = torch.load(fn)

    global USE_SBERT_EMBED
    global USE_BERT_EMBED
    global NUM_PERSPECTIVE
    global TRAIN_OR_ALL
    global TOP_K
    global BERT_MODEL_NAME
    global BERT_TOKENIZER_NAME
    global BERTOPIC_MODEL_NAME
    global COMMENT_AUGMENTATION
    global perspective_mean_embed
    global BERT_DIM
    global SENT_DIM
    global CXT_DIM
    global PRJ_DIM
    global COMPRESSION
    global PROJECTION
    global ATTENTION_EMPTY_MASK
    global FREEZE_BERT
    global SENT_TEMPERATURE
    global PERS_TEMPERATURE
    global DROPOUT_RATE
    global LEAKY_RELU_NEG_SLOPE
    global RANDOM_STATE
    global ENC_BZ

    USE_SBERT_EMBED = checkpoint['use_sbert_embed']
    USE_BERT_EMBED = checkpoint['use_bert_embed']
    NUM_PERSPECTIVE = checkpoint['num_perspective']
    TRAIN_OR_ALL = checkpoint['train_or_all']
    TOP_K = checkpoint['top_k']
    BERT_MODEL_NAME = checkpoint['bert_model']
    BERT_TOKENIZER_NAME = checkpoint['bert_tokenizer']
    BERTOPIC_MODEL_NAME = checkpoint['bert_topic_model']
    COMMENT_AUGMENTATION = checkpoint['comment_augmentation']
    perspective_mean_embed = checkpoint['perspective_mean_embed']
    BERT_DIM = checkpoint['bert_dim']
    SENT_DIM = checkpoint['sent_dim']
    CXT_DIM = checkpoint['cxt_dim']
    PRJ_DIM = checkpoint['prj_dim']
    # Optional keys: checkpoints that lack them fall back to these defaults.
    # An explicit default replaces the former bare `except:`, which would
    # also have hidden unrelated errors.
    COMPRESSION = checkpoint.get('compression', True)
    PROJECTION = checkpoint.get('projection', True)
    ATTENTION_EMPTY_MASK = checkpoint.get('attention_empty_mask', False)
    FREEZE_BERT = checkpoint['freeze_bert']
    SENT_TEMPERATURE = checkpoint['sent_temperature']
    PERS_TEMPERATURE = checkpoint['pers_temperature']
    DROPOUT_RATE = checkpoint['dropout_rate']
    LEAKY_RELU_NEG_SLOPE = checkpoint['leaky_relu_negative_slope']
    ENC_BZ = checkpoint['encode_batch_size']
    RANDOM_STATE = checkpoint['random_state']

    np.random.seed(RANDOM_STATE)

    # At least one embedding backend must be enabled.
    assert USE_SBERT_EMBED or USE_BERT_EMBED, \
        "checkpoint enables neither SBERT nor BERT embeddings"

In [None]:
# Fill the module-level hyperparameters from the checkpoint of the selected epoch.
load_hyperparams(EPOCH)

## Settings

In [None]:
from torch import Tensor

In [None]:
# Device object built from the integer index GPU_NUM; the models below are moved onto it.
device = torch.device(GPU_NUM)

In [None]:
import pytz
# Time zone object for Asia/Taipei.
# NOTE(review): `timezone` is not referenced by any other visible cell.
timezone = pytz.timezone('Asia/Taipei')

In [None]:
# Disable Hugging Face tokenizer parallelism via its environment variable.
# NOTE(review): presumably set because the DataLoader below uses worker processes -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

### setup logger

In [None]:
# Output directory for this run:
# <summary dir>/<train|all>/<w|wo>/<model dir name>/epoch_<EPOCH>_max_sent_<MAX_SENT>
summary_docx_dir = os.path.join(
    P.FP_SIGNIFICANCE_SUMMARY_DIR,
    TRAIN_OR_ALL,
    'w' if COMMENT_AUGMENTATION else 'wo',
    MODEL_SAVE_DIR_NAME,
    "epoch_{}_max_sent_{}".format(EPOCH, MAX_SENT),
)

# Per-applicant attention debug documents are written below the run directory.
summary_debug_dir = os.path.join(
    summary_docx_dir, "debug"
)

# exist_ok=True replaces the check-then-create pattern, which can fail when the
# directory appears between the check and the call.
os.makedirs(summary_docx_dir, exist_ok=True)
os.makedirs(summary_debug_dir, exist_ok=True)

# Location of the pre-computed pseudo summaries for the chosen data split.
pseudo_summary_dir = os.path.join(P.FP_SIGNIFICANCE_PSEUDO_SUMMARY_DIR, 'custom_bertopic', TRAIN_OR_ALL)
all_data_dir = os.path.join(pseudo_summary_dir, 'all_data')

In [None]:
import logging

# Detach every handler currently attached to the root logger. The loop runs
# over a copy of the list because removeHandler mutates it. Presumably this
# keeps re-runs of the notebook from accumulating duplicate handlers.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

In [None]:
import var.path as P

In [None]:
# The log file sits next to the per-epoch output directories of this model.
log_file = os.path.join(
    P.FP_SIGNIFICANCE_SUMMARY_DIR,
    TRAIN_OR_ALL,
    'w' if COMMENT_AUGMENTATION else 'wo',
    MODEL_SAVE_DIR_NAME,
    'summary_gen_log.log',
)

# Root logger at INFO level.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# One handler writing to the log file and one writing to the console stream;
# both emit records of level INFO and above.
fh = logging.FileHandler(log_file)
ch = logging.StreamHandler()
for _handler in (ch, fh):
    _handler.setLevel(logging.INFO)
    logger.addHandler(_handler)

### Utils

In [None]:
def defaultdict_init_defaultdict_init_by_int():
    """Return a new ``defaultdict`` whose missing keys default to ``int()`` (0)."""
    inner = defaultdict(int)
    return inner

def defaultdict_init_defaultdict_init_by_float():
    """Return a new ``defaultdict`` whose missing keys default to ``float()`` (0.0)."""
    inner = defaultdict(float)
    return inner

## Read raw data

In [None]:
# Load the comment table; it is queried later by `year` and `id`.
df_comments = D.read_df_comments()
# Preview only the first rows rather than rendering the whole frame, which
# keeps the saved notebook small and avoids exposing every comment in the output.
df_comments.head()

In [None]:
# Load the applicant table; it is queried later for each applicant's `train_or_test` value.
df_applicants = D.read_df_applicants()
# Preview only the first rows rather than rendering the whole frame, which
# keeps the saved notebook small and avoids exposing every applicant in the output.
df_applicants.head()

In [None]:
# Collect the pickled per-applicant pseudo-summary records.
# Validation uses every year directory except the last; test uses only the last.
individual_data = []
if VAL_OR_TEST == 'val':
    year_dir = V.YEAR_DIRS[:-1]
elif VAL_OR_TEST == 'test':
    year_dir = V.YEAR_DIRS[-1:]
else:
    raise ValueError("VAL_OR_TEST must be 'val' or 'test', got: {!r}".format(VAL_OR_TEST))

for year in year_dir:
    _dir = os.path.join(P.FP_SIGNIFICANCE_PSEUDO_SUMMARY_DIR, 'custom_bertopic', TRAIN_OR_ALL, year)

    # os.listdir returns entries in arbitrary order; sort for a reproducible load order.
    for file in sorted(os.listdir(_dir)):
        if file == '.ipynb_checkpoints':
            continue

        fn = os.path.join(_dir, file)

        # Skip any other sub-directory instead of failing inside open().
        if os.path.isdir(fn):
            continue

        # NOTE: pickle.load executes code embedded in the file; only load
        # artefacts produced by this project.
        with open(fn, "rb") as f:
            data = pickle.load(f)
            individual_data.append(data)

In [None]:
# Number of per-applicant records loaded above.
len(individual_data)

In [None]:
# `defaultdict_init_defaultdict_init_by_int` and `..._by_float` are already
# defined in the "Utils" cell earlier in this notebook. They are not repeated
# here, so that no later definition silently shadows an earlier one.

def defaultdict_init_defaultdict_init_by_str():
    """Return a new ``defaultdict`` whose missing keys default to ``str()`` ("")."""
    return defaultdict(str)

In [None]:
# Load the recommendation-letter table; the next cell reads its `year`, `id`,
# `all_paragraph_sent` and `info` columns.
df_recommendation_letters = D.read_df_recommendation_letters()

In [None]:
# (year, id) -> {recommendation-letter sentence -> joined info string}
rl_info_dict = defaultdict(defaultdict_init_defaultdict_init_by_str)

for _, row in df_recommendation_letters.iterrows():
    applicant_key = (int(row['year']), int(row['id']))

    rl_sent = row['all_paragraph_sent']
    info = "，".join(row['info'])

    # Letters without any info text contribute nothing.
    if info == "":
        continue

    # Every sentence of this letter maps to the same info string. Merging in
    # place leaves the same content as rebuilding the entry with `|`: later
    # letters overwrite sentences they share with earlier ones.
    rl_info_dict[applicant_key].update(dict.fromkeys(rl_sent, info))

In [None]:
# all_candidate_sents_info_buffer = {}
# all_chunk_debug_info_buffer = {}

# for file in tqdm(os.listdir(all_data_dir)):
#     fn = os.path.join(all_data_dir, file)
    
#     if os.path.isdir(fn):
#         continue
        
#     with open(fn, "rb") as f:
#         group_data = pickle.load(f)
        
#     candidate_sents_info_buffer = group_data["candidate_sents_info_buffer"]
#     chunk_debug_info_buffer = group_data["chunk_debug_info_buffer"]
    
#     all_candidate_sents_info_buffer |= candidate_sents_info_buffer
#     all_chunk_debug_info_buffer |= chunk_debug_info_buffer

In [None]:
# len(all_candidate_sents_info_buffer)

In [None]:
# len(all_chunk_debug_info_buffer)

## Prepare evaluation data (validation or test split)

In [None]:
# Parallel lists describing the evaluation samples (one entry per sample).
test_info_data = []
test_pseudo_summary_data = []
test_comment_data = []
test_grade_data = []

# pseudo_summary_to_info_dict = {}

if VAL_OR_TEST == 'val':
    for data in tqdm(individual_data):
        _year = data['year']
        _id = data['id']
        _name = data['name']
        pseudo_summary = data['pseudo_summary']

        ## check train or test data
        row = df_applicants.query('`year` == {} and `id` == {}'.format(_year, _id))
        # An applicant missing from df_applicants is treated as training data.
        # The explicit emptiness check replaces a bare `except:` that would
        # also have hidden unrelated failures.
        split_values = row['train_or_test'].to_list()
        train_or_test = split_values[0] if split_values else 'train'

        # Only held-out applicants are evaluated.
        if train_or_test == 'train':
            continue

        ## get corresponding comments
        row = df_comments.query('`year` == {} and `id` == {}'.format(_year, _id))
        comments = row['comment'].to_list()
        grades = row['grade'].to_list()

        ## append test data set
        # One sample per non-empty comment; the pseudo summary is repeated for each.
        for comment, grade in zip(comments, grades):
            ## remove empty comment
            if PP.is_empty_sent(comment):
                continue

            test_info_data.append((_year, _id, _name))
            test_comment_data.append(comment)
            test_pseudo_summary_data.append(pseudo_summary)
            test_grade_data.append(grade)

elif VAL_OR_TEST == 'test':
    for data in tqdm(individual_data):
        _year = data['year']
        _id = data['id']
        _name = data['name']
        pseudo_summary = data['pseudo_summary']

        ## append data to test data set
        # Placeholders are used for the comment and the grade in test mode.
        test_info_data.append((_year, _id, _name))
        test_comment_data.append('') ## stuff empty comment
        test_pseudo_summary_data.append(pseudo_summary)
        test_grade_data.append('F') ## stuff placeholder grade

In [None]:
# Sanity check: the four parallel lists should have the same length.
len(test_info_data), len(test_pseudo_summary_data), len(test_comment_data), len(test_grade_data)

### Apply one hot encoder to grade label

In [None]:
from sklearn.preprocessing import OneHotEncoder

In [None]:
# Display the grade labels the encoder below is fitted on.
V.TRAIN_GRADE_LABELS

In [None]:
# Fit the one-hot encoder on the training grade labels and show the encoding as a dense array.
# NOTE(review): assumes V.TRAIN_GRADE_LABELS is already shaped (n_labels, 1) -- confirm.
enc = OneHotEncoder()
one_hot_vector = enc.fit_transform(V.TRAIN_GRADE_LABELS).toarray()
one_hot_vector

In [None]:
# One-hot encode the evaluation grades: reshape to a single column, transform
# with the fitted encoder, and convert the result to a dense array.
_grade_column = np.array(test_grade_data).reshape(-1, 1)
test_ext_grade_data = enc.transform(_grade_column).toarray()

In [None]:
# Shape of the one-hot encoded grade matrix.
test_ext_grade_data.shape

### Create dataset and dataloader

In [None]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [None]:
class PseudoSummaryEvaluationDataset(Dataset):
    """Index-aligned view over applicant infos, pseudo summaries, comments and grades."""

    def __init__(self, infos, pseudo_summaries, comments, grades):
        """Store the four parallel sequences; nothing is copied or converted."""
        self.infos = infos
        self.pseudo_summaries = pseudo_summaries
        self.comments = comments
        self.grades = grades

    def __len__(self):
        """Return the number of samples, taken from ``grades``."""
        return len(self.grades)

    def __getitem__(self, idx):
        """Return the ``(info, pseudo_summary, comment, grade)`` tuple at ``idx``."""
        return (
            self.infos[idx],
            self.pseudo_summaries[idx],
            self.comments[idx],
            self.grades[idx],
        )

In [None]:
test_dataset = PseudoSummaryEvaluationDataset(
    test_info_data, test_pseudo_summary_data, test_comment_data, test_ext_grade_data
)
# The identity collate_fn keeps each batch as a plain list of
# (info, pseudo_summary, comment, grade) samples.
# shuffle=False: this loader is used for inference only. Results are stored per
# applicant, so when one applicant has several samples the last one processed
# wins; a fixed order makes repeated runs produce the same output.
test_dataloader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, collate_fn=lambda batch: batch,
    num_workers=8, pin_memory=True, shuffle=False
)

# Load previous model

## Model Building

### Load BERT model

In [None]:
# Sentence-BERT backend: no separate tokenizer object is created, so
# `bert_tokenizer` is None and the encoder is moved onto `device`.
if USE_SBERT_EMBED:
    from sentence_transformers import SentenceTransformer
    
    bert_tokenizer = None
    bert_model = SentenceTransformer(BERT_MODEL_NAME).to(device)

In [None]:
# Plain BERT backend: load the fast tokenizer and the model, and move the model onto `device`.
# If both embedding flags are set, this cell overwrites the objects created for SBERT above.
if USE_BERT_EMBED:
    from transformers import BertTokenizerFast, AutoModel

    bert_tokenizer = BertTokenizerFast.from_pretrained(BERT_TOKENIZER_NAME)
    bert_model = AutoModel.from_pretrained(BERT_MODEL_NAME).to(device)

### Attention Network

In [None]:
from torch import nn

### Perspective HAN

In [None]:
import utils.pHAN as PHAN

## Training loop & testing loop

In [None]:
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

In [None]:
## utils
def get_data(batch):
    """Split a batch of samples into its four parallel components.

    Args:
        batch: Sequence of samples, each indexable at positions 0-3 as
            (info, pseudo_summary, comment, grade).

    Returns:
        Tuple ``(infos, pseudo_summaries, comments, grades)``; the first three
        are lists and ``grades`` is a NumPy array built from the grade entries.
    """
    batch_infos, batch_pseudo_summaries, batch_comments, grade_rows = [], [], [], []

    for sample in batch:
        batch_infos.append(sample[0])
        batch_pseudo_summaries.append(sample[1])
        batch_comments.append(sample[2])
        grade_rows.append(sample[3])

    return batch_infos, batch_pseudo_summaries, batch_comments, np.array(grade_rows)

def logit_to_label(logits, return_numpy=False):
    """Map each row of ``logits`` to the grade label of its largest entry.

    Args:
        logits: Tensor whose dimension 1 indexes the grade classes.
        return_numpy: When true, return a NumPy array instead of a list.

    Returns:
        The labels looked up in ``V.GRADE_INDEX_TO_LABEL``, one per row.
    """
    ## convert logits to labels
    predicted_indices = torch.argmax(logits, 1).cpu().detach().numpy()
    labels = [V.GRADE_INDEX_TO_LABEL[class_idx] for class_idx in predicted_indices]

    return np.array(labels) if return_numpy else labels

def decouple_loss_fn_dict(d):
    """Unpack ``d`` into ``(cls_loss_fn, cos_dis_loss_fn, con_loss_fn)``.

    Raises:
        KeyError: If one of the three keys is missing.
    """
    return d["cls_loss_fn"], d["cos_dis_loss_fn"], d["con_loss_fn"]

def decouple_loss_weight_dict(d):
    """Unpack ``d`` into ``(cls_loss_weight, cos_dis_loss_weight, con_loss_weight)``.

    Raises:
        KeyError: If one of the three keys is missing.
    """
    return d["cls_loss_weight"], d["cos_dis_loss_weight"], d["con_loss_weight"]

## Model initialization

In [None]:
# Build the perspective hierarchical attention network from the encoder and the
# hyperparameters restored from the checkpoint, then move it onto `device`.
# The trained weights are loaded later by load_model().
pHAN = PHAN.PerspectiveHierarchicalAttentionNetwork(
    bert_model, bert_tokenizer, perspective_mean_embed, 
    NUM_PERSPECTIVE, TOP_K, BERT_DIM, SENT_DIM, CXT_DIM, PRJ_DIM, 
    sent_temperature=SENT_TEMPERATURE, pers_temperature=PERS_TEMPERATURE, 
    dropout_rate=DROPOUT_RATE, leaky_relu_negative_slope=LEAKY_RELU_NEG_SLOPE, 
    encode_batch_size=ENC_BZ, compression=COMPRESSION, projection=PROJECTION,
    attention_empty_mask=ATTENTION_EMPTY_MASK, freeze_bert=FREEZE_BERT,
).to(device)
pHAN

### Number of parameters

In [None]:
# Count the trainable parameters. Tensor.numel() is an integer element count
# and is 1 for 0-dim parameters. The previous product over the shape produced a
# float for an empty shape, which cannot be added in place to an integer tensor.
num_params = torch.tensor(
    sum(parameter.numel() for parameter in pHAN.parameters() if parameter.requires_grad)
)

num_params

## Test on test dataset

### Loss functions

In [None]:
from pytorch_metric_learning import losses

In [None]:
## classification loss
weight = torch.tensor([3, 1.5, 1, 1]).to(device)
# weight = torch.tensor(class_weights).to(device)
cls_loss_fn = nn.BCELoss(weight)
# class_weights

In [None]:
## cosine similarity loss w.r.t comment
# Cosine embedding loss with default arguments.
cos_dis_loss_fn = nn.CosineEmbeddingLoss()

In [None]:
## contrastive loss
# Supervised contrastive loss from pytorch_metric_learning, temperature 1.
con_loss_fn = losses.SupConLoss(temperature=1)

In [None]:
# Bundle the three loss functions under the keys that decouple_loss_fn_dict() reads.
loss_fn_dict = {
    "cls_loss_fn": cls_loss_fn,
    "cos_dis_loss_fn": cos_dis_loss_fn,
    "con_loss_fn": con_loss_fn
}

### Optimizer and learning rate scheduler

In [None]:
# Name fragments that place a parameter in the separate "low LR" optimizer group.
low_lr_param_list = ['grade_classifier']


def _is_low_lr_param(param_name):
    """Return True when ``param_name`` contains any fragment of ``low_lr_param_list``."""
    return any(_name in param_name for _name in low_lr_param_list)


# The two groups are exact complements. The previous base-group filter kept a
# parameter when *any* fragment was absent from its name, so with more than one
# fragment a low-LR parameter would have landed in both groups.
low_lr_params = [
    param for param_name, param in pHAN.named_parameters() if _is_low_lr_param(param_name)
]
base_lr_params = [
    param for param_name, param in pHAN.named_parameters() if not _is_low_lr_param(param_name)
]

In [None]:
base_learning_rate = 1e-3

# Three parameter groups: base parameters, low-LR parameters (currently the same
# rate as the base group), and a one-element dummy tensor whose "learning rate"
# carries the eta value driven by the scheduler.
_param_groups = [
    {"params": base_lr_params, "lr": base_learning_rate},
    {"params": low_lr_params, "lr": base_learning_rate},
    {"params": torch.tensor([1]), "lr": 1}, ## eta
]

optimizer = torch.optim.AdamW(_param_groups, lr=base_learning_rate)

In [None]:
lr_gamma = 0.8
step_size = 8


def sch_lambda_params(epoch):
    """Learning-rate factor: multiplied by ``lr_gamma`` once every ``step_size`` epochs."""
    return lr_gamma ** (epoch // step_size)


def sch_lambda_eta(epoch):
    """Eta factor: starts at 1, drops by 0.1 every ``step_size`` epochs, floored at 0."""
    return max(0, 1 - 0.1 * (epoch // 1 // step_size))


# One factor function per optimizer parameter group, in group order.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=[sch_lambda_params, sch_lambda_params, sch_lambda_eta],
)

In [None]:
# for epoch in range(100):
#     print(epoch, scheduler.get_last_lr())
#     scheduler.step()

## Load state of the model, optimizer, and scheduler

In [None]:
def load_model(t):
    """Restore model, optimizer and scheduler state from the checkpoint of epoch ``t``.

    Reads ``epoch_{t:04d}.pt`` from ``MODEL_SAVE_DIR_PATH`` and loads its
    ``model_state_dict``, ``optimizer_state_dict`` and ``scheduler_state_dict``
    into the module-level ``pHAN``, ``optimizer`` and ``scheduler``.

    Args:
        t: Epoch number of the checkpoint to load.
    """
    ## find model name
    model_name = "epoch_{:04d}.pt".format(t)
    fn = os.path.join(MODEL_SAVE_DIR_PATH, model_name)
    # map_location remaps the stored tensors onto the inference device, so a
    # checkpoint saved from another GPU index still loads on this machine.
    checkpoint = torch.load(fn, map_location=device)

    pHAN.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

In [None]:
# Restore the trained weights (and optimizer/scheduler state) of the selected epoch.
load_model(EPOCH)

# Generate summary

## Calculate attention

In [None]:
import math

In [None]:
def get_perspective_sent_distribution(att, max_sent=None):
    """Split a sentence budget across perspectives in proportion to attention.

    Each perspective first receives ``floor(max_sent * attention)`` sentences.
    While the total is below ``max_sent``, one extra sentence is given to the
    perspective with the fewest sentences so far; ties go to the higher
    attention, then to the lower index.

    Args:
        att: Sequence of attention weights, one per perspective.
        max_sent: Total sentence budget. Defaults to the module-level
            ``MAX_SENT`` when omitted.

    Returns:
        List of per-perspective sentence counts (empty when ``att`` is empty).
        The total can exceed ``max_sent`` if the weights sum to more than 1,
        because counts are never reduced.
    """
    if max_sent is None:
        max_sent = MAX_SENT

    att = list(att)
    sent_dist = [math.floor(max_sent * a) for a in att]

    # No perspective to hand sentences to; the loop below would fail on an empty list.
    if not sent_dist:
        return sent_dist

    while sum(sent_dist) < max_sent:
        # min() with the same key ordering selects what the first element of a
        # full sort would, without sorting on every iteration.
        pers_to_add_sent = min(
            range(len(sent_dist)),
            key=lambda i: (sent_dist[i], 1 - att[i], i),
        )
        sent_dist[pers_to_add_sent] += 1

    return sent_dist

In [None]:
# applicant info tuple -> pseudo summary plus sentence and perspective attention
attention_dict = {}

def summary_generation_loop(dataloader, model):
    """Run ``model`` over ``dataloader`` and record its attention per applicant.

    For every sample the pseudo summary and the sentence and perspective
    attention (as NumPy arrays) are stored in the module-level
    ``attention_dict`` under the applicant info tuple. A later sample of the
    same applicant replaces the earlier entry.

    Args:
        dataloader: Yields batches that ``get_data`` can split.
        model: Called as ``model(pseudo_summaries, comments, eta)``; its third
            and fourth return values are used as the two attention outputs.
    """
    eta = 0

    # Switch to evaluation mode once, before iterating.
    model.eval()

    with torch.no_grad():
        # len(dataloader) also counts the final partial batch, which
        # `len(dataset) // BATCH_SIZE` left out of the progress total.
        for batch in tqdm(dataloader, total=len(dataloader)):
            batch_infos, batch_pseudo_summaries, batch_comments, batch_grades = get_data(batch)

            # Compute prediction
            _, _, sent_att, pers_att = model(batch_pseudo_summaries, batch_comments, eta)

            for applicant_info, pseudo_summary, s_att, p_att in zip(batch_infos, batch_pseudo_summaries, sent_att, pers_att):
                applicant_info = tuple(applicant_info)

                attention_dict[applicant_info] = {
                    'pseudo_summary': pseudo_summary,
                    'sent_att': s_att.detach().cpu().numpy(),
                    'pers_att': p_att.detach().cpu().numpy(),
                }

In [None]:
# Fill `attention_dict` with the attention outputs for every evaluation sample.
summary_generation_loop(test_dataloader, pHAN)

## Generate summaries

In [None]:
# Pickle with the aggregated topic (perspective) information of the BERTopic
# model named in the checkpoint.
fn = os.path.join(
    P.FP_COMMENT_CLUSTERING_TOPIC_HIERARCHY_DIR,
    "{}_topic_aggregate_info.pkl".format(BERTOPIC_MODEL_NAME),
)

with open(fn, "rb") as f:
    topic_aggregate_info = pickle.load(f)

# Unpack the entries used below.
topic_aggregate_dict = topic_aggregate_info['topic_aggregate_dict']
perspective_mean_embed_dict = topic_aggregate_info['topic_aggregate_embed_mean_dict']
topic_aggregate_intra_similarity_dict = topic_aggregate_info['topic_aggregate_intra_similarity_dict']

In [None]:
# Display the loaded topic aggregate info.
# NOTE(review): this renders the whole object, including the embedding entries;
# consider showing only its keys.
topic_aggregate_info

## Calculate the quota for each perspective

In [None]:
# Sentence quotas by rank; assigned below to perspectives sorted by ascending
# intra-similarity (the lowest similarity receives the first, largest quota).
quota_list = [3, 2, 2, 1, 1]

In [None]:
# for i, pids in topic_aggregate_dict.items():
#     print(i, pids)
#     print(topic_aggregate_intra_similarity_dict[i])

In [None]:
# perspective id -> sentence quota. Perspectives are ranked by ascending
# intra-similarity and the rank selects the entry of `quota_list`.
quota_dict = {
    pid: quota_list[rank]
    for rank, (pid, _similarity) in enumerate(
        sorted(topic_aggregate_intra_similarity_dict.items(), key=lambda l: l[1])
    )
}

In [None]:
# Display the per-perspective sentence quotas.
quota_dict

In [None]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Model name handed to SentenceTransformer for the redundancy check below.
# NOTE(review): this names a plain BERT checkpoint rather than a model trained
# for sentence embeddings -- confirm that is intended.
SBERT_MODEL_NAME = 'ckiplab/bert-base-chinese'

In [None]:
# Sentence encoder used by check_is_redundency(), moved onto `device`.
sbert_model = SentenceTransformer(SBERT_MODEL_NAME).to(device)

In [None]:
import pickle
from itertools import chain
from collections import Counter, OrderedDict

In [None]:
# Display the perspective titles; they are indexed by perspective id when building summaries.
perspective_title

In [None]:
def check_is_redundency(sent, summary_buf, redundency_threshold=0.9, quota=2, debug=False):
    """Tell whether ``sent`` is too similar to sentences already selected.

    ``sent`` and every entry of ``summary_buf`` are embedded with the
    module-level ``sbert_model``. The sentence counts as redundant when at
    least ``quota`` buffer entries have a cosine similarity above
    ``redundency_threshold``.

    Args:
        sent: Candidate sentence.
        summary_buf: Sentences already selected for the summary.
        redundency_threshold: Similarity above which two sentences count as similar.
        quota: Number of similar sentences needed to call ``sent`` redundant.
        debug: When true, print the similarity vector.

    Returns:
        ``True`` when the sentence is redundant, else ``False``. An empty
        ``summary_buf`` always gives ``False``.
    """
    # Nothing to compare against; the similarity call would fail on an empty buffer.
    if len(summary_buf) == 0:
        return False

    ## [TODO] adjust the quota based on the diversity of the perspective
    sent_embed = sbert_model.encode(sent, batch_size=64, show_progress_bar=False)
    summary_buf_embed = sbert_model.encode(summary_buf, batch_size=64, show_progress_bar=False)

    sims = cosine_similarity([sent_embed], summary_buf_embed)[0]

    if debug:
        print("similarity:", sims)

    num_sim_sent = int((sims > redundency_threshold).sum())

    return num_sim_sent >= quota

In [None]:
from docx import Document

In [None]:
# applicant info tuple -> {"summary": title -> sentences, "title_weight": title -> summed attention}
summary_dict = {}

for info, _dict in tqdm(attention_dict.items()):
    pseudo_summary = _dict['pseudo_summary']
    sent_att = _dict['sent_att']
    pers_att = _dict['pers_att']
    # Recommendation-letter sentences of this applicant, keyed by (year, id).
    # A set gives constant-time membership tests, and .get() avoids inserting
    # an empty entry into the defaultdict for applicants without letters.
    rls = set(rl_info_dict.get(info[:2], ()))

    ## get_perspective_sent_distribution
    # NOTE(review): the distribution is computed but not used by the selection below.
    pers_sent_dist = get_perspective_sent_distribution(pers_att)

    summary = defaultdict(list)
    title_weight = defaultdict(float)

    sent_combine_att_dict = {}
    sent_info_dict = {}

    # Combined attention = perspective attention * sentence attention * boost.
    # A sentence that occurs under several perspectives keeps the last value written.
    for i, (pers_sent, s_att, p_att) in enumerate(zip(pseudo_summary, sent_att, pers_att)):
        for sent, _s_att in zip(pers_sent, s_att):
            ## [TODO] if sent evidence score is 0, boost the weight by a small factor
            boost_weight = RL_WEIGHT_BOOST if sent in rls else 1.0
            cmb_att = p_att * _s_att * boost_weight

            sent_combine_att_dict[sent] = cmb_att
            sent_info_dict[sent] = {
                "pers": i,
                "pers_att": p_att,
                "sent_att": _s_att,
                "cmb_att": cmb_att
            }

    # Rank once, highest combined attention first; reused for the debug document.
    ranked_sents = sorted(sent_combine_att_dict.items(), key=lambda i: -i[1])

    ## generate summary from the highest combined attention sentence
    # The buffer starts with an empty string so the redundancy check always has
    # something to compare against. That seed is not a summary sentence, so the
    # selected sentences are counted separately; counting the buffer length
    # stopped the summary one sentence short of MAX_SENT.
    summary_str_buf = [""]
    num_selected_sent = 0
    summary_pers_sent_count = Counter()
    for sent, cmb_att in ranked_sents:
        ## check if sent count exceed the limit
        if num_selected_sent >= MAX_SENT:
            break

        pers_id = sent_info_dict[sent]['pers']
        pers_num_sent_in_summary = summary_pers_sent_count[pers_id]
        ## check if combined att is lower then threshold
        # The threshold is lowered for each sentence already taken from this perspective.
        if cmb_att < COMBINED_ATT_MIN_THRESHOLD - pers_num_sent_in_summary * ATT_THRESHOLD_DECAY:
            continue
        ## check if the sentence in the perspective is already full
        if pers_num_sent_in_summary >= quota_dict[pers_id]:
            continue
        ## check if the added sent is too similar to the sentence already in the summary
        if check_is_redundency(sent, summary_str_buf, quota=quota_dict[pers_id]):
            continue

        ## add the sentence to the summary
        title = perspective_title[pers_id]
        summary[title] += [sent]
        title_weight[title] += cmb_att
        summary_pers_sent_count[pers_id] += 1
        summary_str_buf.append(sent)
        num_selected_sent += 1

    summary_dict[info] = {
        "summary": summary,
        "title_weight": title_weight,
    }

    ## generate debug info
    doc = Document()
    ## info about detailed attention weight list
    _ = doc.add_heading("Detailed attention weight per sentence", level=2)
    for num_sent, (sent, _) in enumerate(ranked_sents):
        weights = sent_info_dict[sent]
        _ = doc.add_paragraph("rank: {}, cmb_att: {:.4f}, pers: {}, pers_att: {:.4f}, sent_att: {:.4f}, sent: {}".format(
            num_sent+1, weights['cmb_att'], weights['pers'], weights['pers_att'], weights['sent_att'], sent
        ))

    _ = doc.add_page_break()
    ## info about attention weight per perspective
    _ = doc.add_heading("Sentence attention per perspective", level=2)
    for i, (pers_sent, s_att, p_att) in enumerate(zip(pseudo_summary, sent_att, pers_att)):
        _ = doc.add_paragraph("perspective {} attention weight: {}".format(i, p_att))
        for _idx in np.argsort(s_att)[::-1]:
            # Attention slots without a matching sentence are skipped instead
            # of raising IndexError.
            if _idx >= len(pers_sent):
                continue
            # The value printed here is the sentence attention; the label and
            # the precision now match the "sent_att: {:.4f}" form used above.
            _ = doc.add_paragraph("sent_att: {:.4f}, sent: {}".format(s_att[_idx], pers_sent[_idx]))

        _ = doc.add_paragraph("="*50)

    fn = "{}_att_weight_debug.docx".format("_".join(map(str, info)))
    _ = doc.save(os.path.join(summary_debug_dir, fn))

In [None]:
# Display the generated summaries.
# NOTE(review): this renders every applicant's summary in the notebook output;
# consider showing a single entry.
summary_dict

In [None]:
# summary_dict

In [None]:
# Persist the generated summaries as a pickle inside the run's output directory.
fn_summary_dict = os.path.join(summary_docx_dir, "summary_dict.pkl")

with open(fn_summary_dict, 'wb') as f:
    pickle.dump(summary_dict, f)

## Calculate BERTScore with corresponding comments

In [None]:
# ## find summary and its corresponding comment
# if VAL_OR_TEST == 'val':
#     test_summary_result = []

#     for applicant_info, pseudo_summary, comment in zip(test_info_data, test_pseudo_summary_data, test_comment_data):
#         applicant_info = tuple(applicant_info)
#         summary = summary_dict[applicant_info]['summary']
#         buffer = []

#         for sents in summary.values():
#             buffer.append(sents)

#         summary = ''.join(list(chain.from_iterable(buffer)))

#         # concat summary together [TODO] different concata method may result in differenet bertscore?
#         test_summary_result.append(summary)

In [None]:
# if VAL_OR_TEST == 'val':
#     from bert_score import score

In [None]:
# def calculate_bert_score(cands, refs, rescale=False, verbose=False):
#     return score(
#         cands,
#         refs,
#         lang="zh",
#     #     model_type=MODEL_TYPE,
#     #     num_layers=LAYER,
#         verbose=verbose,
#         device=0,
#         batch_size=64,
#     #     idf=False,
#         rescale_with_baseline=rescale
#     )

In [None]:
# if VAL_OR_TEST == 'val':
#     _P, _R, _F1 = calculate_bert_score(test_summary_result, test_comment_data, rescale=False)

In [None]:
# from datetime import datetime

In [None]:
# if VAL_OR_TEST == 'val':
#     IO.log_dividing_line(logger)
#     logger.info("model dir: {}".format(MODEL_SAVE_DIR_NAME))
#     logger.info("time: {}".format(str(datetime.now())))
#     logger.info("combined att method")
#     logger.info("epoch: {}".format(EPOCH))
#     logger.info("max sent: {}".format(MAX_SENT))
#     logger.info("cmb att threshold: {}".format(COMBINED_ATT_MIN_THRESHOLD))

In [None]:
# if VAL_OR_TEST == 'val':
#     rouge_precision = torch.mean(_P)
#     rouge_recall = torch.mean(_R)
#     rouge_f1 = torch.mean(_F1)

In [None]:
# if VAL_OR_TEST == 'val':
#     logger.info("p: {:4f}".format(rouge_precision))
#     logger.info("r: {:4f}".format(rouge_recall))
#     logger.info("f: {:4f}".format(rouge_f1))