# Extended solution

*Papers that resemble my ideas or may simply be useful*:  
+ [Neural Extractive Text Summarization with Syntactic Compression](https://aclanthology.org/D19-1324.pdf) *by Jiacheng Xu and Greg Durrett*:
    + The approach first encodes the text and only *then* performs the compression;
    + https://github.com/jiacheng-xu/neu-compression-sum (the code is very hard to read)
+ [Fast Abstractive Summarization with Reinforce-Selected Sentence Rewriting](https://aclanthology.org/P18-1063.pdf) *by Yen-Chun Chen and Mohit Bansal*:
    + This approach uses RL to select sentences, which are then rewritten abstractively;
    + https://github.com/ChenRocks/fast_abs_rl (a much more readable repo)
+ [Extractive Summarization as Text Matching](https://arxiv.org/pdf/2004.08795v1.pdf) *by Ming Zhong, Pengfei Liu, Yiran Chen, Danqing Wang, Xipeng Qiu, Xuanjing Huang*:
    + https://github.com/maszhongming/MatchSum
    + Pushes the extractive result on CNN/DailyMail to a new level (44.41 in ROUGE-1);

### Work structure
#### 1. Dataset loading and preprocessing

#### 2. Implementing heuristics as standalone preprocessing functionality (postponed)
My idea is to use some preprocessing tricks to improve the quality of the resulting summaries. 
The following approaches are suggested (each is to be implemented in a separate notebook):
1. Use coreference resolution across sentences so that important nouns are not lost in the summary.
2. Try to split compound sentences into several smaller ones.
3. Compress the resulting sentences by dropping low-information words without (hopefully) sacrificing readability.
    - Named entities intuitively seem more important than common nouns, so they are not to be deleted.
    
#### 3. Implementation of summarization block per-se
#### 4. Evaluation and metrics


## 1. Dataset

The CNN/DailyMail dataset is taken from here: https://cs.nyu.edu/~kcho/DMQA/

## 2. Heuristics (postponed)

See notebooks 3.1, 3.2, 3.3

## 3. Summarization block

For my summarization block I'd like to use the method suggested by Ming Zhong, Pengfei Liu, Yiran Chen, Danqing Wang, Xipeng Qiu and Xuanjing Huang in the paper [Extractive Summarization as Text Matching](https://arxiv.org/pdf/2004.08795.pdf).

The key idea of the work:
> Instead of scoring and extracting sentences
> one by one to form a summary, we formulate
> extractive summarization as a semantic text matching
> problem and propose a novel summary-level
> framework. Our approach bypasses the difficulty
> of summary-level optimization by contrastive learning,
> that is, a good summary should be more
> semantically similar to the source document than the
> unqualified summaries.

In practice, they fine-tune a BERT-style encoder so that the embedding of a gold summary ends up closer to the embedding of its source text than the embeddings of worse candidate summaries.


And the loss...
> In order to fine-tune Siamese-BERT, we use a margin-based triplet loss to update the weights

But doesn't a search through all possible candidates have to cover $\sum_{i=1}^{n}C_n^i = 2^n - 1$ variants for a document of $n$ sentences?
> In the inference phase, we formulate extractive summarization as a task to search for the best summary among all the candidates C extracted from the document D.

... and yeah, exactly:
> The matching idea is more intuitive while it suffers from combinatorial explosion problems. \[...] we introduce a content selection module to pre-select salient sentences.
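
To get a feeling for the numbers, here is a quick check of how many non-empty sentence subsets a document has. The function name `count_candidates` is mine and is not taken from the paper or the MatchSum repo.

```python
from math import comb


def count_candidates(n_sentences: int) -> int:
    """Number of non-empty sentence subsets of a document with n_sentences sentences."""
    return sum(comb(n_sentences, k) for k in range(1, n_sentences + 1))


# The closed form is 2**n - 1, so the search space doubles with every extra sentence
assert count_candidates(5) == 2**5 - 1
assert count_candidates(30) == 2**30 - 1
```

A 30-sentence article already gives over a billion candidates, which is why some pre-selection is unavoidable.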

The content selection mentioned above is done with the [PreSumm](https://github.com/nlpyang/PreSumm) model.

## 3.0 Suggested flow
The paper suggests the following workflow:
1. Score the sentences of the input text with some third-party model according to their presumed contribution to the meaning of the text.
2. Build summary candidates as combinations of the top-scored sentences from step 1. The number of sentences per candidate is a hyperparameter, so it controls the length of the output summary. The training itself depends on this choice, so I assume that to get good summaries with an arbitrary number of sentences one had better train the model on combinations of 1..n sentences.
3. The deep learning model learns to choose the best of the candidate summaries while avoiding a common pitfall of sentence-level models (the authors describe it as the "pearl-summary vs. best-summary" problem).

### \[Preparation] Creating dataset

The dataset I will use needs to be in a certain form. The original solution uses the jsonl format: one JSON object per line.
I don't mind it. 

So, first of all, I preprocess my dataset and convert it into that format.


The code for that can be found in `dataset_utils/cnndm_preprocessor`
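
For reference, a minimal sketch of a jsonl round trip. The field names `text` and `summary` are placeholders I made up for the illustration; the real schema is whatever `dataset_utils/cnndm_preprocessor` writes. An in-memory buffer stands in for the file.

```python
import io
import json

# Hypothetical records: the field names are illustrative, not the real schema
records = [
    {"text": ["First sentence.", "Second sentence."], "summary": ["Gold summary."]},
    {"text": ["Another document."], "summary": ["Another summary."]},
]


def write_jsonl(records, stream):
    # One JSON object per line, no enclosing list
    for record in records:
        stream.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(stream):
    # Skip blank lines so a trailing newline does not break parsing
    return [json.loads(line) for line in stream if line.strip()]


buffer = io.StringIO()
write_jsonl(records, buffer)
buffer.seek(0)
restored = read_jsonl(buffer)
```

The nice property of the format is that a file can be read line by line, so truncating a dataset to the first N documents is just a matter of taking the first N lines.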

Also, I truncate the dataset to the first 10k docs for training and 2k docs for validation. The original paper states that training on all of the several hundred thousand docs took 30 hours on several top GPUs. I have no such compute.

### Scoring step

The paper suggests scoring the sentences by their importance. I will use a trivial method for that, and I will not truncate the input at this stage, since the truncation length can be a hyperparameter.


 > - truncate each document into the 5 most important sentences (using BertExt), 
   then select any 2 or 3 sentences to form a candidate summary, so there are C(5,2)+C(5,3)=20 candidate summaries.
   if you want to process other datasets, you may need to adjust these numbers according to specific situation.
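
The counting in the quote can be reproduced with `itertools.combinations`. This is only a sketch: `build_candidates` and its parameters are names I introduce here, and sorting the selected ids back into document order is my own assumption about what a readable candidate should look like.

```python
from itertools import combinations


def build_candidates(ranked_sentence_ids, top_k=5, sizes=(2, 3)):
    """Build candidate summaries as tuples of sentence ids.

    ranked_sentence_ids: sentence ids sorted by importance, most important first.
    """
    # Keep the top_k sentences and restore document order inside every candidate
    top = sorted(ranked_sentence_ids[:top_k])
    return [combo for size in sizes for combo in combinations(top, size)]


# Toy ranking of a document: sentence 7 is the most important one, then 0, ...
candidates = build_candidates([7, 0, 3, 12, 5, 9, 1])
assert len(candidates) == 20  # C(5,2) + C(5,3)
```

Changing `top_k` or `sizes` is how the number and the length of candidates would be adjusted for another dataset.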

BertExt has a questionable and poorly maintained codebase, so I will stick with centroid-based ranking here.
Moreover, I will use LaBSE rather than BERT or RoBERTa, simply because I've already worked with it.
And if my encoder is going to be LaBSE anyway, why shouldn't I also use it to rank the sentences?

The code for this step can be found in  `dataset_utils/create_sentence_ranking`
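
Roughly, the idea is: embed every sentence, average the embeddings into a document centroid, and rank sentences by cosine similarity to it. Below is a minimal sketch on toy 2-d vectors in plain Python; the function names are illustrative, real embeddings would come from the encoder, and the actual code in `dataset_utils/create_sentence_ranking` may differ in details.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def rank_by_centroid(sentence_embeddings):
    """Return sentence indices sorted by similarity to the mean embedding."""
    n = len(sentence_embeddings)
    dim = len(sentence_embeddings[0])
    centroid = [sum(e[d] for e in sentence_embeddings) / n for d in range(dim)]
    scores = [cosine(e, centroid) for e in sentence_embeddings]
    return sorted(range(n), key=lambda i: scores[i], reverse=True)


# Toy "embeddings": sentences 0 and 1 are about the same thing, sentence 2 is off-topic
toy = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
ranking = rank_by_centroid(toy)
```

On the toy data the off-topic sentence ends up last, which is the behaviour wanted from the scoring step.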

### Dataset loader
The dataset loader is a rewrite inspired by the MatchSum repo, but with much more readable variable names and with a plain torch `Dataset` instead of any reliance on the fastNLP library.

It performs tokenization on the fly and also provides the candidate summary combinations.

In [7]:
## DATASET LOADERS
from torch.utils.data import DataLoader
from dataset_utils.dataset import CNNDMDataset
from pathlib import Path

dataset_train_path = Path("../data/cnndm/dataset10k.jsonl")
indices_train_path = Path("../data/cnndm/sent_id10k.jsonl")
dataset_eval_path = Path("../data/cnndm/dataset2k.jsonl")
indices_eval_path = Path("../data/cnndm/sent_id2k.jsonl")

SUMMARY_LENGTH = 5
train_dataset = CNNDMDataset(dataset_train_path, indices_train_path, SUMMARY_LENGTH=SUMMARY_LENGTH)
eval_dataset = CNNDMDataset(dataset_eval_path, indices_eval_path, SUMMARY_LENGTH=SUMMARY_LENGTH)

train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=8)
eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=8)


### Creating matcher

In [None]:
import torch
from torch import nn
from transformers import AutoModel


class MatchSum(nn.Module):
    """
    Siamese matcher: one shared encoder embeds the document, the gold summary
    and every candidate summary; summaries are scored by cosine similarity
    to the document.

    ASSUMED input layout (adjust to what CNNDMDataset really yields): each
    argument is a dict of tokenizer tensors (input_ids, attention_mask, ...).
    Text and summary tensors are [batch_size, n_sentences, seq_len];
    candidate tensors are [batch_size, candidate_num, n_sentences, seq_len].
    """
    def __init__(self):
        super(MatchSum, self).__init__()

        self.encoder = AutoModel.from_pretrained("cointegrated/LaBSE-en-ru")
        self.hidden_size = self.encoder.config.hidden_size  # 768 for LaBSE

    def embed(self, tokenized):
        # Flatten all leading dims, encode sentence by sentence, restore the dims
        leading_shape = tokenized["input_ids"].shape[:-1]
        flat = {k: v.reshape(-1, v.size(-1)) for k, v in tokenized.items()}
        sentence_out = self.encoder(**flat).pooler_output
        sentence_out = sentence_out.reshape(*leading_shape, self.hidden_size)
        # Average the sentence embeddings (padding sentences, if any, are averaged in too)
        return nn.functional.normalize(torch.mean(sentence_out, dim=-2), dim=-1)

    def forward(self, 
        tokenized_text_sentences, 
        tokenized_candidate_sentences, 
        tokenized_summary_sentences):
        
        batch_size, candidate_num = tokenized_candidate_sentences["input_ids"].shape[:2]

        # Get document embedding
        text_embedding = self.embed(tokenized_text_sentences)
        assert text_embedding.size() == (batch_size, self.hidden_size)
        
        # Get summary embedding
        summary_embedding = self.embed(tokenized_summary_sentences)
        assert summary_embedding.size() == (batch_size, self.hidden_size)

        # Get candidates embedding 
        candidate_embedding = self.embed(tokenized_candidate_sentences)
        assert candidate_embedding.size() == (batch_size, candidate_num, self.hidden_size)
        
        # get summary score: a row-wise dot product equals cosine, since embeddings are normalized
        summary_score = torch.sum(summary_embedding * text_embedding, dim=-1)

        # get candidate score: broadcast the document embedding over the candidates
        cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
        candidate_scores = cos(candidate_embedding, text_embedding.unsqueeze(1))
        assert candidate_scores.size() == (batch_size, candidate_num)

        return {"candidate_scores": candidate_scores, "summary_score": summary_score}

According to the paper, the loss is the following:
+ We calculate pairwise losses between the gold summary and the candidate summaries
$$
    L_1 = \max(0, f(D,C) - f(D,C^*) + \gamma_1), 
$$
where
+ $D$ - document
+ $C$ - candidate summary
+ $C^*$ - gold summary
+ $f$ - similarity score (cosine similarity between the two embeddings)
+ $\gamma_1$ - margin value

In addition, we calculate pairwise candidate losses, keeping in mind that they're already sorted by ROUGE with respect to the gold summary:
$$
L_2 = \max(0, f(D,C_j) - f(D,C_i) + (j-i) \cdot \gamma_2)
$$
where
- $i<j$
- $i$, $j$ - ranks of the candidate summaries
- $\gamma_2$ - margin value between neighbouring ranks

The total loss is the sum of the two above: $L = L_1 + L_2$.

> We choose $γ_1 = 0$ and $γ_2 = 0.01$. When $γ_1<0.05$ and $0.005<γ_2<0.05$ they have little effect on performance, otherwise they will cause performance degradation. 
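
To make the two terms concrete, here is a small worked example in plain Python for a single document. It is a sketch of the formulas only: `matchsum_loss` is my own name, and it sums over pairs, whereas PyTorch's `MarginRankingLoss` averages by default, so the absolute values will differ from a torch implementation.

```python
def matchsum_loss(candidate_scores, summary_score, gamma_1=0.0, gamma_2=0.01):
    """L_1 + L_2 for one document; candidate_scores are sorted best-first by ROUGE."""
    # L_1: the gold summary should score at least gamma_1 above every candidate
    l1 = sum(max(0.0, c - summary_score + gamma_1) for c in candidate_scores)
    # L_2: a better-ranked candidate i should beat candidate j by (j - i) * gamma_2
    n = len(candidate_scores)
    l2 = sum(
        max(0.0, candidate_scores[j] - candidate_scores[i] + (j - i) * gamma_2)
        for i in range(n)
        for j in range(i + 1, n)
    )
    return l1 + l2


# Scores already agree with the ROUGE order and the gold summary is on top: no penalty
well_ordered = matchsum_loss([0.90, 0.80, 0.60], summary_score=0.95)

# The second-best candidate outscores the best one: only the pair (0, 1) is penalised,
# by 0.85 - 0.80 + 1 * 0.01, i.e. about 0.06
misordered = matchsum_loss([0.80, 0.85, 0.60], summary_score=0.90)
```

Note how the required gap grows with the rank distance: candidates two ranks apart must differ by $2\gamma_2$, three ranks apart by $3\gamma_2$, and so on.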

In [None]:
## LOSS FUNCTION
MARGIN = 0.01
def loss_function(candidate_scores, summary_score, MARGIN=MARGIN):
    # candidate_scores: [batch_size, candidate_num], sorted best-first by ROUGE
    # summary_score: [batch_size]
    total_loss = candidate_scores.new_zeros(())
    
    # candidate loss: candidates i ranks apart must differ by at least MARGIN * i
    n_candidates = candidate_scores.size(1)
    for i in range(1, n_candidates):
        pos_score = candidate_scores[:, :-i]
        neg_score = candidate_scores[:, i:]
        pos_score = pos_score.contiguous().view(-1)
        neg_score = neg_score.contiguous().view(-1)
        
        # target 1 means "pos_score should be ranked higher than neg_score"
        ones = torch.ones_like(pos_score)
        
        margin_loss = torch.nn.MarginRankingLoss(MARGIN * i)
        total_loss = total_loss + margin_loss(pos_score, neg_score, ones)

    # gold summary loss: the gold summary should score above every candidate
    pos_score = summary_score.unsqueeze(-1).expand_as(candidate_scores)
    neg_score = candidate_scores
    pos_score = pos_score.contiguous().view(-1)
    neg_score = neg_score.contiguous().view(-1)
    ones = torch.ones_like(pos_score)
    margin_loss = torch.nn.MarginRankingLoss(0.0)
    total_loss = total_loss + margin_loss(pos_score, neg_score, ones)

    return total_loss

The learning rate schedule is the following; it includes a warm-up phase:
$$
 lr = 2 \cdot 10^{-3} \cdot \min(step^{-0.5},\ step \cdot wm^{-1.5})
$$
+ where each $step$ processes a batch of 32 examples
+ and $wm$ denotes warmup steps of 10,000.
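
The shape of this schedule is easier to see with a few numbers. A minimal sketch of the formula as a plain function (`matchsum_lr` and its argument names are mine):

```python
def matchsum_lr(step: int, warmup_steps: int = 10_000, scale: float = 2e-3) -> float:
    """Learning rate at a given step (step >= 1): linear warm-up, then 1/sqrt(step) decay."""
    return scale * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (100, 10_000, 40_000):
    print(step, matchsum_lr(step))
```

The two branches of `min` meet at `step == warmup_steps`, where the rate peaks at $2 \cdot 10^{-3} / \sqrt{10000} = 2 \cdot 10^{-5}$. Before that it grows linearly (about $2 \cdot 10^{-7}$ at step 100), and after that it decays as the inverse square root (about $10^{-5}$ at step 40,000).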

In [None]:
## LR SCHEDULER
from torch.optim.lr_scheduler import _LRScheduler

class LRScheduler(_LRScheduler):
    def __init__(self, optimizer, min_lr=1e-6, update_every=2, max_lr=2e-5, warmup_steps=10000, last_epoch=-1):
        # These must be set before super().__init__, which already calls get_lr() once
        self.update_every = update_every
        self.max_lr = max_lr
        self.min_lr = min_lr  # min learning rate > 0 
        self.warmup_steps = warmup_steps
        super(LRScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # self.last_epoch counts calls to .step(); the rate changes every update_every calls
        real_step = max(1, self.last_epoch // self.update_every)
        # max_lr * 100 restores the 2e-3 factor of the formula for the default max_lr=2e-5
        cur_lr = self.max_lr * 100 * min(real_step**(-0.5), real_step * self.warmup_steps**(-1.5))
        cur_lr = max(cur_lr, self.min_lr)
        # The base class writes the returned values into the param groups
        return [cur_lr for _ in self.optimizer.param_groups]

In [None]:
## OPTIMIZER
from torch.optim import Adam
model = MatchSum()
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0)
lr_scheduler = LRScheduler(optimizer)

In [None]:
## TRAINING LOOP
from tqdm import tqdm

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.train()

def to_device(obj, device):
    # Moves a tensor, or a (possibly nested) dict of tensors, to the device
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    return obj.to(device)

NUM_EPOCHS = 5
for epoch in range(NUM_EPOCHS):
    for batch in tqdm(train_dataloader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}"):
        batch = to_device(batch, device)
        outputs = model(**batch)
        loss = loss_function(**outputs)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

In [None]:
## VALIDATION

In [None]:
MODEL_PATH = Path("./trained_model")
torch.save(model.state_dict(), MODEL_PATH)

## 4. Eval and metrics

In [None]:
model = MatchSum()
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()