# Encoder Text Summarizer

Notebook for extracting summaries from the CNN articles dataset

**Brief Summary:**

- expected running time: 1h30

- the model uses a Transformer encoder to turn the one-hot (index) encoded text into embeddings; these embeddings are then used to decide whether each token/sentence is important or not

- the summary is then built by taking the most important sentences of the article (the inference code below keeps the top 3)

- results are decent, although far from the state of the art: on the training subset, selecting sentences at random gives a Rouge-2 F1 of about 0.06, while selecting the sentences the model rates as important gives about 0.10 (see the scores at the end of the notebook)

- problem: most of the ground-truth highlights are abstractive, while the model is extractive, which makes training harder. It also makes it hard to build labels that tell which parts/tokens of the document are important. With extractive reference summaries, we could have built the summaries by concatenating the most important tokens.
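The sentence-selection step described above can be sketched in plain Python. This is a minimal illustration under assumptions: `select_top_sentences` is a hypothetical helper (not part of the notebook), one importance score per sentence is assumed to be available already, and, unlike the notebook (which joins the picked sentences in score order), the sketch returns them in their original article order.

```python
from typing import List, Sequence


def select_top_sentences(sentences: Sequence[str], scores: Sequence[float], k: int = 3) -> List[str]:
    """Hypothetical helper: keep the k highest-scoring sentences.

    The chosen sentences are returned in their original article order.
    """
    if len(sentences) != len(scores):
        raise ValueError("need exactly one score per sentence")
    k = min(k, len(sentences))
    # indices of the k largest scores; the sort is stable, so ties keep the earlier sentence first
    ranked = sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)[:k]
    return [sentences[i] for i in sorted(ranked)]


article = ["Sentence A", "Sentence B", "Sentence C", "Sentence D"]
print(". ".join(select_top_sentences(article, [0.9, 0.2, 0.1, 0.7], k=2)))  # Sentence A. Sentence D
```

Restoring article order is a design choice that usually reads better; the notebook keeps score order.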

**Notebook Sections:**

- Installing libraries

- CNN Dataset (from https://github.com/abisee/cnn-dailymail)

- Neural Network (Transformer instead of an RNN, because RNNs such as LSTMs are very slow to train)

- Instantiating Dataset/Model

- Training (forward/backward prop)

- Testing (generate summaries and get scores)


## Installing libraries

In [None]:
!pip install Rouge

Collecting Rouge
  Downloading rouge-1.0.1-py3-none-any.whl (13 kB)
Installing collected packages: Rouge
Successfully installed Rouge-1.0.1


In [None]:
from collections import defaultdict
import os
import math
from tqdm import tqdm

from google_drive_downloader import GoogleDriveDownloader as gdd
from rouge import Rouge

import torch, torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

## CNN Dataset

### Vocabulary to one-hot encode the tokens

In [None]:
MAX_ARTICLE_LENGTH = 700         # default 400
VOCAB_SIZE         = 42000

In [None]:
# vocab entries store these data (n_documents_with_word is used for TF-IDF)
class Word():
  def __init__(self, idx):
    self.idx : int  = idx                    # integer index (position of the 1 in the one-hot vector)
    self.n_documents_with_word: int = 0      # number of articles the word appears in (document frequency)

class Vocab():
  def __init__(self, max_size=32000):
    self.words = {}
    self.idx   = 0
    self.max_size = max_size
    self.add_hotencode('[PAD]')
    self.add_hotencode('[UNK]')
    self.add_hotencode('.')

  def get_hotencode(self, word): 
    if word in self.words:
      return self.words[word].idx
    else:
      return self.words['[UNK]'].idx

  def add_hotencode(self, word):
    if word not in self.words and len(self.words)<self.max_size:
      self.words[word] = Word(self.idx)  # add word to dict
      self.idx += 1                      # next word will have a new code
    elif word not in self.words and len(self.words)>=self.max_size:
      return False
    return True
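Despite the name, `Vocab` stores integer indices rather than full one-hot vectors: each index is the position of the single 1 in a one-hot vector of length `VOCAB_SIZE`, and `nn.Embedding` uses it to pick a row of its weight matrix. The plain-Python sketch below checks that equivalence on a toy example; the 4-word vocabulary, the 2-dimensional weight table and the helper names are illustrative assumptions, not values from the notebook.

```python
from typing import Dict, List


def one_hot(idx: int, size: int) -> List[float]:
    """Vector of zeros with a single 1.0 at position idx."""
    vec = [0.0] * size
    vec[idx] = 1.0
    return vec


def vec_times_matrix(vec: List[float], matrix: List[List[float]]) -> List[float]:
    """Row vector (length n_rows) times a matrix (n_rows x n_cols)."""
    n_cols = len(matrix[0])
    return [sum(vec[r] * matrix[r][c] for r in range(len(matrix))) for c in range(n_cols)]


# toy vocabulary with the same special tokens as Vocab: [PAD]=0, [UNK]=1, '.'=2
word_to_idx: Dict[str, int] = {"[PAD]": 0, "[UNK]": 1, ".": 2, "cnn": 3}
# illustrative embedding table: one 2-dimensional row per vocabulary entry
weights = [[0.0, 0.0], [0.1, -0.1], [0.5, 0.5], [0.3, 0.9]]


def encode(token: str) -> int:
    """Unknown tokens fall back to [UNK], as in Vocab.get_hotencode."""
    return word_to_idx.get(token, word_to_idx["[UNK]"])


idx = encode("cnn")
# one-hot vector times the table == row lookup (what nn.Embedding does with the index)
assert vec_times_matrix(one_hot(idx, len(weights)), weights) == weights[idx]
print(encode("zebra"))  # 1
```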

## Class containing CNN Dataset

- downloads the files automatically when instantiated, if `root` does not exist yet

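- returns, for each article:
  - the one-hot (index) encoding of the article (first `article_length` tokens, 400 by default)
  - the one-hot (index) encoding of the summary (first 100 tokens)
  - the TF-IDF score of each token of the article
  - the importance label of each sentence of the article (a sentence is important if its first or last 8 characters appear in a highlight; at most 42 sentences)
  - the importance label of each token of the article, on two channels: the token is used in the summary, and the token together with its predecessor is used in the summary ---> does not work well, since the summaries are mostly abstractive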
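The TF-IDF target can be written as a small plain-Python function. It is a sketch that mirrors the formula in `make_dataset` under stated assumptions: `tfidf_scores` and its arguments are hypothetical names, `doc_freq` plays the role of `n_documents_with_word`, and tokens without a document count score 0 here. Note that TF is divided by the fixed `article_length`, not by the actual number of tokens.

```python
import math
from collections import Counter
from typing import Dict, List


def tfidf_scores(article: List[str], doc_freq: Dict[str, int], n_articles: int,
                 article_length: int = 400) -> List[float]:
    """Per-token TF-IDF following the formula in make_dataset (hypothetical helper).

    tf  = occurrences of the token in the (truncated) article / article_length
    idf = ln(n_articles / number of articles that contain the token)
    Tokens without a document count get 0.0.
    """
    tokens = article[:article_length]          # make_dataset truncates before counting
    counts = Counter(tokens)
    scores = []
    for token in tokens:
        df = doc_freq.get(token, 0)
        if df == 0:
            scores.append(0.0)
        else:
            scores.append(counts[token] / article_length * math.log(n_articles / df))
    return scores


# "the" and "." appear in every article, so their IDF (and score) is 0
example = tfidf_scores(["the", "cat", "the", "."], {"the": 10, "cat": 1, ".": 10},
                       n_articles=10, article_length=4)
```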

In [None]:
class CNN_dailymail(Dataset):
  def __init__(self, root='./dataset', article_length=400, n_articles = 100, offset=0, vocab_size=42000):
    super().__init__()

    # set up paths
    self.root = root
    self.article_length = article_length
    self.n_articles = n_articles

    # get data
    if not os.path.exists(root):
      self.download_dataset()
    self.articles = os.listdir(f'{self.root}/cnn_stories_tokenized/')[offset:offset+n_articles]

    # vocab
    self.vocab = Vocab(max_size=vocab_size)
    self.init_vocab()

    # create tensors
    self.data = self.make_dataset()

  def init_vocab(self):
    # track info for TF-IDF
    for story in self.articles:
      with open(f'{self.root}/cnn_stories_tokenized/{story}') as file_in:
        text = file_in.read()
      if not self.is_valid(text):
        self.n_articles -= 1
        continue
      found = set()
      text = text[19:].split('@highligh')
      article =  text[0].split()

      for token in article:
        success = self.vocab.add_hotencode(token)
        if success and token not in found:
          self.vocab.words[token].n_documents_with_word += 1
          found.add(token)

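  def get_best_match(self, summ, text, idx):
    """start position in `text` of the longest run matching summ[idx:] (-1 if summ[idx] never occurs)"""
    base = summ[idx:]
    best, best_len = -1, 0
    for start in range(len(text)):
      length = 0
      while (length < len(base) and start + length < len(text)
             and text[start + length] == base[length]):
        length += 1
      if length > best_len:                  # keep the first of the longest matches
        best, best_len = start, length
    return best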

  def make_dataset(self):
    """read each file in the dataset folder (each file is an article) and extracts the data"""
    data = []
    self.texts = []

    for story in self.articles:

      # read article
      with open(f'{self.root}/cnn_stories_tokenized/{story}') as file_in:
        text = file_in.read()

      # check goodness of article
      if not self.is_valid(text): continue

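      # get tokens: text[0] is the article, text[1:] are the highlights
      # (splitting on '@highligh' leaves a leading 't' token in every highlight)
      text = text[19:].split('@highligh')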
      article =  text[0].split()[:self.article_length]
      summary = [t.split() for t in text[1:]]
      if len(article)<20: continue

      # hotencode summary
      y, i = torch.zeros(100, dtype=torch.long), 0
      drop_sent = False
      for sent in summary:
        for j, token in enumerate(sent):
          if i+j >= 100: break
          idx    = self.vocab.get_hotencode(token)
          if idx==1:
            drop_sent = True
          y[i+j] = idx
        i += len(sent)
      #if drop_sent: continue  ##### if unknown found ---> not extractive ---> DROP

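      # sentence-importance labels: a sentence is important if its first or last 8 characters appear in a highlight
      t_art = (' '.join(article)).split('. ')
      t_sum = [' '.join(s) for s in summary]   # a list keeps the highlight order (a set would shuffle it)
      self.texts.append((t_art, t_sum))        # save article sentences and summary for testing
      pred = torch.zeros(42)
      for i, t in enumerate(t_art[:42]):
        if any(t[:8] in t2 for t2 in t_sum) or any(t[-8:] in t2 for t2 in t_sum):
          pred[i] = 1.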

      # count term frequencies (TF)
      counts = defaultdict(lambda: 0)
      for tok in article:
        counts[tok] += 1

      # hotencode article & TF-IDF
      x = torch.zeros(self.article_length, dtype=torch.long)  # [PAD] has idx 0
      tfidf = torch.zeros(self.article_length)
      for i, token in enumerate(article):
        # hot-encode
        x[i]     = self.vocab.get_hotencode(token)

        # TF-IDF
        if token in self.vocab.words:
          tfidf[i] = (counts[token] / self.article_length) * torch.log(torch.tensor(self.n_articles / self.vocab.words[token].n_documents_with_word))
        
      # word importance: ch0 = token used in the summary, ch1 = token and its predecessor both match the summary
      word_importance = torch.zeros((2,self.article_length), dtype=torch.float)
      full_summ='. '.join(list(t_summ if False else t_sum)).split() if False else '. '.join(list(t_sum)).split()
      for i in range(len(full_summ)):
        best_match = self.get_best_match(full_summ, article, i)
        if best_match < 0: continue          # get_best_match returns -1 when there is no match
        word_importance[0,best_match] = 1.
        if i>0 and best_match>0 and full_summ[i-1] == article[best_match-1]:
          word_importance[1,best_match] = 1.

      # add sample
      data.append( (x,y,tfidf,pred, word_importance) )

    return data

  def get_sample(self, idx):
    """returns an encoded article and the text summary"""
    x, _,_,_,_ = self.data[idx]
    t_art, t_sum = self.texts[idx]
    return x, t_art, t_sum

  def is_valid(self, text):
    valid = True
    if len(text) < 100:           valid = False
    elif '@highligh' not in text: valid = False

    return valid


  def download_dataset(self):
    """preprocessed dataset from https://github.com/JafferWilson/Process-Data-of-CNN-DailyMail"""

    os.makedirs(self.root, exist_ok=True)
    gdd.download_file_from_google_drive(
        file_id='0BzQ6rtO2VN95cmNuc2xwUS1wdEE',#'1C0MsFvsTdD-rCY6l4_kGsOEsKy46r_Yn',
        dest_path=f"{self.root}/data.zip",
        unzip=True)
    

  def __len__(self):
    return len(self.data)
  
  def __getitem__(self, idx):
    return self.data[idx]
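The sentence labels (`pred` in `make_dataset`) come from a cheap string heuristic rather than an alignment. The plain-Python sketch below reproduces it on a made-up three-sentence article; `sentence_labels` is a hypothetical name, and the 42-sentence cap and 8-character window are the constants used in the notebook.

```python
from typing import List


def sentence_labels(article_tokens: List[str], highlights: List[str],
                    max_sentences: int = 42, n_chars: int = 8) -> List[float]:
    """Label a sentence 1.0 when its first or last n_chars characters occur
    verbatim inside any highlight (hypothetical helper mirroring make_dataset).
    Very short sentences can match trivially, as in the notebook."""
    sentences = " ".join(article_tokens).split(". ")
    labels = [0.0] * max_sentences
    for i, sent in enumerate(sentences[:max_sentences]):
        head, tail = sent[:n_chars], sent[-n_chars:]
        if any(head in h or tail in h for h in highlights):
            labels[i] = 1.0
    return labels


tokens = "Police arrested two men . The weather was mild . Both men were charged".split()
highlights = ["Police arrested two men in the city", "Both men were charged on Monday"]
labels = sentence_labels(tokens, highlights)
print(labels[:3])  # [1.0, 0.0, 1.0]
```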

## Neural Network Model

todo:
- add a feed-forward network after the CNN (the `Conv1d` layer)
- improve the ch2 word-importance target, e.g. as the presence of ROUGE-2 bigrams

### positional encoder
(from pytorch transformers tutorial)

In [None]:
class PositionalEncoding(nn.Module):
  """add the positional encoding (sin/cos values) to a batch-first embedding"""
  def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)

    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(1, max_len, d_model)     # batch-first layout: (1, max_len, d_model)
    pe[0, :, 0::2] = torch.sin(position * div_term)
    pe[0, :, 1::2] = torch.cos(position * div_term)
    self.register_buffer('pe', pe)

  def forward(self, x: torch.Tensor):
    """ x: Tensor, shape [batch_size, seq_len, embedding_dim] (the model uses batch_first=True) """
    x = x + self.pe[:, :x.size(1)]            # select positions along the sequence axis, not the batch axis
    return self.dropout(x)
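For reference, the buffer built in `__init__` holds the standard sinusoidal encoding, `PE(pos, 2i) = sin(pos / 10000^(2i/d_model))` and `PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))`, where `div_term` equals `10000^(-2i/d_model)`. The plain-Python sketch below computes one row of it; `positional_row` is a hypothetical helper and `d_model` is assumed even. Since the model works on batch-first tensors `(batch, seq_len, emb)`, the row for position `p` has to be added to token `p` of every sequence, i.e. selected along dimension 1.

```python
import math
from typing import List


def positional_row(pos: int, d_model: int) -> List[float]:
    """Sinusoidal encoding of one position (hypothetical helper, d_model even):
    even dims get sin(pos * w), odd dims get cos(pos * w),
    with w = 10000 ** (-2i / d_model) for the i-th sin/cos pair."""
    row = [0.0] * d_model
    for two_i in range(0, d_model, 2):
        # same value as div_term in PositionalEncoding: exp(2i * -ln(10000) / d_model)
        w = math.exp(two_i * (-math.log(10000.0) / d_model))
        row[two_i] = math.sin(pos * w)
        row[two_i + 1] = math.cos(pos * w)
    return row


print(positional_row(0, 4))  # [0.0, 1.0, 0.0, 1.0]
```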


## Transformer with Self-Attention

**INPUT DOCUMENT**

⬇

2 Transformer encoder layers (self-attention on the whole document)

⬇

2 Transformer encoder layers (self-attention within single sentences)

⬇

CNN (1D convolution giving 3 score channels per token)

⬇

**SENTENCE SCORE** (sum of the squared scores of the tokens of a sentence)
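Restricting the second pair of encoder layers to single sentences is done with an additive attention mask: 0 where query and key belong to the same sentence, a large negative number elsewhere. The plain-Python sketch below builds such a mask from token indices, assuming '.' has index 2 as in `Vocab` and using -1e20 as in `forward`; `sentence_mask` is a hypothetical helper. One design difference: here each '.' closes its own sentence, while the loop in `forward` also reuses the '.' as the first token of the following block.

```python
from typing import List

NEG = -1e20  # additive mask value for blocked attention, as in ModelSummarizer.forward
DOT = 2      # index of '.' in Vocab


def sentence_mask(tokens: List[int], dot: int = DOT) -> List[List[float]]:
    """Additive attention mask (hypothetical helper): 0.0 where query and key
    lie in the same sentence, NEG elsewhere. A sentence ends with its '.'."""
    n = len(tokens)
    sent_id, current = [], 0
    for tok in tokens:
        sent_id.append(current)   # sentence number of this position
        if tok == dot:
            current += 1          # the next token starts a new sentence
    return [[0.0 if sent_id[q] == sent_id[k] else NEG for k in range(n)] for q in range(n)]


# two sentences ([5, 6, .] and [7, .]) followed by one padding token
mask = sentence_mask([5, 6, 2, 7, 2, 0])
```

Every row keeps at least its diagonal at 0, so no query ends up with all keys blocked.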

In [None]:
class ModelSummarizer(nn.Module):
  def __init__(self, vocab_size, emb_size=256, doc_len=400, dot_hotencode=2):
    super().__init__()

    self.dot_hotencode = dot_hotencode
    self.d_model = emb_size

    # hotencoding to vector
    self.pos_encoder = PositionalEncoding(d_model=emb_size, dropout=0.4)
    self.encoder = nn.Embedding(vocab_size, emb_size)
    self.encoder.weight.data.uniform_(-.1, .1)

    # transformer encoder 1
    encoder_layers = nn.TransformerEncoderLayer(emb_size, 2, emb_size*2, .4, batch_first=True)
    self.fullsent_encoder = nn.TransformerEncoder(encoder_layers, 2)

    # transformer encoder 2
    encoder_layers = nn.TransformerEncoderLayer(emb_size, 1, emb_size*2, .4, batch_first=True)
    self.singlesent_encoder = nn.TransformerEncoder(encoder_layers, 2)

    # sentence scorer
    self.embedd_to_score = nn.Linear(emb_size,1)

    # importance extraction
    self.conv = nn.Conv1d(emb_size, 3, 3, padding=1)  # importance of  ch0:word    ch1:previous_and_current_word     ch2:sentence

    # latent space vectorization
    self.getlatent = nn.Sequential(
        nn.Conv1d(emb_size, emb_size//2, 5, padding=1, stride=15, dilation=3),
        nn.AdaptiveAvgPool1d(8),
        nn.Flatten(),
        nn.ReLU(),
        nn.Linear(8*emb_size//2, 512),
        nn.ReLU()
    )

  def forward(self, src):
    # hotencoding to vector
    x = self.encoder(src) * math.sqrt(self.d_model)
    x = self.pos_encoder(x)                                 # (bs, doc_len, emb_size)
    batch_size, sent_length, _ = x.shape

    # attention on the whole document (an all-zero additive mask, i.e. no masking)
    x_mask = torch.zeros(sent_length, sent_length).to(src.device)
    full_sent_emb = self.fullsent_encoder(x, x_mask)        # (bs, doc_len, emb_size)
    pred_tfidf = full_sent_emb[:,:,0]                       # (bs, doc_len): first embedding dim is used as the TF-IDF prediction

    # attention on single sentences: block-diagonal additive mask built from the '.' positions
    # (nhead=1 in singlesent_encoder, so a (batch_size, S, S) mask is a valid 3D mask)
    src_mask = torch.ones(batch_size, sent_length, sent_length).to(src.device) * -1e20
    all_dots = []
    for i in range(batch_size):
      dots = ((src[i] == self.dot_hotencode).nonzero(as_tuple=True)[0])
      all_dots.append(dots.tolist())
      prev = 0
      for d in dots:
        d = d.item()
        src_mask[i, prev:d+1, prev:d+1] = 0
        prev = d
      src_mask[i, prev:sent_length, prev:sent_length] = 0    # last sentence (and padding)
    single_sent_emb = self.singlesent_encoder(full_sent_emb, src_mask) # (bs, doc_len, emb_size)

    # summary
    single_sent_emb = single_sent_emb.permute((0,2,1))      # (bs, emb_size, doc_len): channels first for Conv1d
    words_importance = self.conv(single_sent_emb)            # (bs, 3, doc_len)   [ ch0:word    ch1:previous_and_current_word     ch2:sentence ]

    # sentence importance: sum of the squared ch2 scores over each sentence span, padded/truncated to 42 sentences
    batch_docs = []
    for i in range(batch_size):
      doc, prev = [], 0
      for d in all_dots[i]:
        doc.append((words_importance[i, 2, prev:d]**2).sum().view(1))
        prev = d
      doc.append((words_importance[i, 2, prev:sent_length]**2).sum().view(1))
      doc += [torch.tensor(0.).view(1).to(src.device)] * (42-len(doc))
      doc = torch.cat(doc, dim=-1)[:42]
      batch_docs.append(doc)
    batch_docs = torch.stack(batch_docs)

    # latent space of sentence
    latent_space = self.getlatent(single_sent_emb)

    # return stuff
    results = {
        'sent_importance':batch_docs,                   # a score that tells the importance of each sentence
        'word_importance':words_importance[:,:2,:],     # score that tells if the word is in the summary
        'tfidf':          pred_tfidf,                   # learn tf-idf as multitask
        'latent':         latent_space                  # document to vector (summary similar to full doc)
    }
    return results
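The pooling from token scores to the 42 sentence scores can be sketched in plain Python as follows. `pool_sentence_scores` is a hypothetical helper; the spans follow the loop in `forward` ([0, d1), [d1, d2), ..., [d_last, end)), and the token scores are assumed to be channel 2 of the convolution output, which has shape `(batch, 3, doc_len)`, so the sentence channel is picked on dimension 1 and the span on dimension 2.

```python
from typing import List

MAX_SENTENCES = 42  # cap used for the sentence labels in the notebook


def pool_sentence_scores(token_scores: List[float], dot_positions: List[int],
                         max_sentences: int = MAX_SENTENCES) -> List[float]:
    """Sentence score = sum of squared token scores over the sentence span (hypothetical helper).

    Spans follow the loop in forward(): [0, d1), [d1, d2), ..., [d_last, end),
    so each '.' is counted at the start of the following span.
    The result is truncated / zero-padded to max_sentences entries.
    """
    bounds = [0] + list(dot_positions) + [len(token_scores)]
    scores = [sum(s * s for s in token_scores[a:b]) for a, b in zip(bounds, bounds[1:])]
    scores = scores[:max_sentences]
    return scores + [0.0] * (max_sentences - len(scores))


example = pool_sentence_scores([1.0, 2.0, 0.5, 3.0, 0.0], dot_positions=[2])
print(example[:3])  # [5.0, 9.25, 0.0]
```

Squaring makes every token contribute a non-negative amount, so longer sentences tend to score higher; that bias is part of the notebook's design, not of this sketch alone.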


# Training

## Set Up network & stuff

In [None]:
# CPU or CUDA
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Dataset
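# n_articles=200000 is larger than the number of CNN stories, so every file is used; shuffle the articles each epoch
loader = DataLoader(CNN_dailymail(n_articles=200000, article_length=MAX_ARTICLE_LENGTH, vocab_size=VOCAB_SIZE),
                    batch_size=8, shuffle=True, pin_memory=True)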

# Neural Network
net = ModelSummarizer(VOCAB_SIZE, doc_len=MAX_ARTICLE_LENGTH).to(device).train()
optim = torch.optim.Adam(net.parameters(), lr=5e-3, weight_decay=1e-8)

# number of trainable parameters
sum(p.numel() for p in net.parameters() if p.requires_grad)

Downloading 0BzQ6rtO2VN95cmNuc2xwUS1wdEE into ./dataset/data.zip... Done.
Unzipping...Done.


13551748

## Do Epochs

In [None]:
for epoch in range(2):
  tfloss, latloss, summloss = 0, 0, 0

  for art, summ, tf_true, y_true, w_importance in tqdm(loader):
    art, summ, tf_true, y_true, w_importance = art.to(device), summ.to(device), tf_true.to(device), y_true.to(device), w_importance.to(device)

    # forward
    doc_results = net(art)
    sum_results = net(summ)

    # losses
    loss1 = nn.MSELoss()(doc_results['tfidf'], tf_true)                    # multitask: predict the TF-IDF of each token
    loss2 = torch.relu( 8                                                    # margin
                        + nn.MSELoss()(doc_results['latent'], sum_results['latent'])             # document vs its own summary
                        - nn.MSELoss()(doc_results['latent'][1:], sum_results['latent'][:-1]) )  # document vs another article's summary
    loss3 = nn.MSELoss()(doc_results['sent_importance'], y_true)           # sentence importance labels
    loss4 = nn.MSELoss()(doc_results['word_importance'], w_importance)     # word importance labels
    final_loss = loss1 + loss2 + 10*( loss3 + loss4 )

    # backprop
    final_loss.backward()
    optim.step()
    optim.zero_grad()

    # update stats
    tfloss, latloss, summloss = tfloss+loss1.item(), latloss+loss2.item(), summloss+loss3.item()

  print(f'losses for the epoch -->  tf-idf:{tfloss}, latentdistance:{latloss}, importantsents:{summloss}')

100%|██████████| 11558/11558 [37:29<00:00,  5.14it/s]


losses for the epoch -->  tf-idf:226.17038132902235, latentdistance:96987.45580530167, importantsents:957.5618014745414


100%|██████████| 11558/11558 [37:34<00:00,  5.13it/s]

losses for the epoch -->  tf-idf:148.08792835334316, latentdistance:94606.98675918579, importantsents:890.7497558239847
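`loss2` is a hinge (margin) loss on the latent vectors: each document should be closer to its own summary than to the summary of another article, where the other articles are obtained by shifting the batch by one. Below is a plain-Python sketch of that computation, using the notebook's margin of 8 and batch-averaged squared errors as `nn.MSELoss` does; `mse` and `shifted_margin_loss` are hypothetical names and the toy latent vectors are made up.

```python
from typing import List

Vector = List[float]


def mse(a: List[Vector], b: List[Vector]) -> float:
    """Mean squared error over every element of two equally shaped batches (like nn.MSELoss)."""
    diffs = [(x - y) ** 2 for va, vb in zip(a, b) for x, y in zip(va, vb)]
    return sum(diffs) / len(diffs)


def shifted_margin_loss(doc: List[Vector], summ: List[Vector], margin: float = 8.0) -> float:
    """Hinge loss in the spirit of loss2 (hypothetical helper).

    positive: document i vs its own summary i
    negative: document i+1 vs summary i (the batch shifted by one)
    The loss is 0 once the negative distance exceeds the positive one by `margin`.
    """
    if len(doc) < 2:
        raise ValueError("need at least two documents to build shifted negatives")
    positive = mse(doc, summ)
    negative = mse(doc[1:], summ[:-1])
    return max(0.0, margin + positive - negative)


far_apart = shifted_margin_loss([[0.0, 0.0], [10.0, 10.0]], [[0.0, 0.0], [10.0, 10.0]])
identical = shifted_margin_loss([[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]])
print(far_apart, identical)  # 0.0 8.0
```

When all latent vectors collapse to the same point the loss stays at the margin, which is what pushes different documents apart.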





# Testing

## Inference

In [None]:
def get_summaries(dataset, printfirst=False):
  outputs, abstracts = [], []

  for idx in range(len(dataset)):
    x, t_art, t_sum = dataset.get_sample(idx)

    with torch.no_grad():                                   # inference only: no gradients needed
      results = net(x.unsqueeze(0).to(device))
    pred = results['sent_importance']
    pred = pred.squeeze(0)[:len(t_art)]

    # keep the 3 highest-scoring sentences (fewer if the article is shorter)
    top = pred.topk(k=min(3, len(pred)), dim=0).indices.tolist()
    t_pred = [t_art[i] for i in top]

    predicted = '. '.join(t_pred)
    real      = '. '.join(s[2:] if s.startswith('t ') else s for s in t_sum)   # drop the 't ' left by the '@highligh' split

    if printfirst:
      printfirst=False
      print(f"GROUND TRUTH SUMMARY\n  {real}\n{'_'*100}\nEXTRACTED SUMMARY\n  {predicted}\n{'='*100}\n")

    # stats
    outputs.append(predicted)
    abstracts.append(real)
  
  return outputs, abstracts

def compute_accuracy(outputs, abstracts):
  acc = Rouge().get_scores(outputs, abstracts, avg=True)
  print(f"""
Rouge-1: recall {acc['rouge-1']['r']}, precision {acc['rouge-1']['p']}, f1 {acc['rouge-1']['f']}
Rouge-2: recall {acc['rouge-2']['r']}, precision {acc['rouge-2']['p']}, f1 {acc['rouge-2']['f']}
Rouge-L: recall {acc['rouge-l']['r']}, precision {acc['rouge-l']['p']}, f1 {acc['rouge-l']['f']}
        """)



## Compute Scores

In [None]:
net.eval()
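# NOTE: each CNN_dailymail instance builds its own Vocab from the articles it loads, so the token indices of these
# subsets do not match the training vocabulary (only [PAD]=0, [UNK]=1 and '.'=2 always coincide), and the articles
# are truncated to the default 400 tokens instead of MAX_ARTICLE_LENGTH.
# The training loader used n_articles=200000, which covers every story in the folder, so neither subset is held out.
traindata = CNN_dailymail(n_articles=4000, offset=20000)
testdata = CNN_dailymail(n_articles=4000, offset=50000)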
 
print('TRAINING SET')
outputs, abstracts = get_summaries(traindata)
compute_accuracy(outputs, abstracts)

print('VALIDATION SET')
outputs, abstracts = get_summaries(testdata, True)
compute_accuracy(outputs, abstracts)

TRAINING SET

Rouge-1: recall 0.3825598650190738, precision 0.26851838415184, f1 0.3069110942724683
Rouge-2: recall 0.13458653729294975, precision 0.08674773677889, f1 0.10185534606662996
Rouge-L: recall 0.34906454924142116, precision 0.2451949880551469, f1 0.28012148252430363
        
VALIDATION SET
GROUND TRUTH SUMMARY
  They argue the U.S. should orient its foreign policy to favor pro-equality regimes. t Nations that wo n't commit to that goal do n't deserve American money or weapons , they say. t Authors : Terrorists in many nations oppose gender equality
____________________________________________________________________________________________________
EXTRACTED SUMMARY
  Even terrorists have fears . The logic , for them , is simple . And the prospect of gender equality appears to rank high on their list of worst nightmares 


Rouge-1: recall 0.38579475555652465, precision 0.2683727565670014, f1 0.3085049602217897
Rouge-2: recall 0.1367134627115836, precision 0.0870151193626517, 

### Comparison with randomly selected sentences

In [None]:
def get_randomsummaries(dataset, printfirst=False):
  outputs, abstracts = [], []

  for idx in range(len(dataset)):
    _, t_art, t_sum = dataset.get_sample(idx)
    if len(t_art)==0: continue

    # pick up to 3 distinct sentences uniformly at random
    top = torch.randperm(len(t_art))[:3].tolist()
    t_pred = [t_art[i] for i in top]

    predicted = '. '.join(t_pred)
    real      = '. '.join(s[2:] if s.startswith('t ') else s for s in t_sum)   # drop the 't ' left by the '@highligh' split

    if printfirst:
      printfirst=False
      print(f"GROUND TRUTH SUMMARY\n  {real}\n{'_'*100}\nEXTRACTED SUMMARY\n  {predicted}\n{'='*100}\n")

    # stats
    outputs.append(predicted)
    abstracts.append(real)
  
  return outputs, abstracts

print('TRAINING SET')
outputs, abstracts = get_randomsummaries(traindata)
compute_accuracy(outputs, abstracts)

print('VALIDATION SET')
outputs, abstracts = get_randomsummaries(testdata, True)
compute_accuracy(outputs, abstracts)

TRAINING SET

Rouge-1: recall 0.2918446135875289, precision 0.22238655984579914, f1 0.24408942573128645
Rouge-2: recall 0.0744033160152091, precision 0.05286653587265048, f1 0.05913813314411545
Rouge-L: recall 0.2637194585694278, precision 0.2013533172360759, f1 0.22076545335946063
        
VALIDATION SET
GROUND TRUTH SUMMARY
  They argue the U.S. should orient its foreign policy to favor pro-equality regimes. t Nations that wo n't commit to that goal do n't deserve American money or weapons , they say. t Authors : Terrorists in many nations oppose gender equality
____________________________________________________________________________________________________
EXTRACTED SUMMARY
  Even terrorists have fears . So , any successful strategy against these groups must put women 's rights at the front and center of policy planning . Educated women and girls would fundamentally challenge the power structure of organizations like ISIS 


Rouge-1: recall 0.29633085069116855, precision 0.22520

### Comparison with TF-IDF selected sentences

In [None]:
def get_tfidfsummaries(dataset, printfirst=False):
  outputs, abstracts = [], []

  for idx in range(len(dataset)):
    x, t_art, t_sum = dataset.get_sample(idx)
    _,_,tf,_,_ = dataset[idx]

    # sentence score = sum of the TF-IDF of its tokens, over the same '.'-delimited spans as the model
    pred, prev = [], 0
    for d in ((x == 2).nonzero(as_tuple=True)[0]):
      pred.append(tf[prev:d].sum().view(1))
      prev = d
    pred.append(tf[prev:].sum().view(1))
    pred = torch.cat(pred, dim=-1)

    # keep the 3 best sentences (fewer if the article is shorter)
    top = pred.topk(k=min(3, len(pred)), dim=0).indices.tolist()
    t_pred = [t_art[i] for i in top if i < len(t_art)]   # a final '.' can create an extra, empty span

    predicted = '. '.join(t_pred)
    real      = '. '.join(s[2:] if s.startswith('t ') else s for s in t_sum)   # drop the 't ' left by the '@highligh' split

    if printfirst:
      printfirst=False
      print(f"GROUND TRUTH SUMMARY\n  {real}\n{'_'*100}\nEXTRACTED SUMMARY\n  {predicted}\n{'='*100}\n")

    # stats
    outputs.append(predicted)
    abstracts.append(real)
  
  return outputs, abstracts

print('TRAINING SET')
outputs, abstracts = get_tfidfsummaries(traindata)
compute_accuracy(outputs, abstracts)

print('VALIDATION SET')
outputs, abstracts = get_tfidfsummaries(testdata, True)
compute_accuracy(outputs, abstracts)

TRAINING SET
err
err
err

Rouge-1: recall 0.37931125622137246, precision 0.22912423199380835, f1 0.27754392479024775
Rouge-2: recall 0.11962545575379759, precision 0.06435749945905123, f1 0.080655679990919
Rouge-L: recall 0.34168043309850366, precision 0.20678630036821835, f1 0.25024388981194634
        
VALIDATION SET
GROUND TRUTH SUMMARY
  They argue the U.S. should orient its foreign policy to favor pro-equality regimes. t Nations that wo n't commit to that goal do n't deserve American money or weapons , they say. t Authors : Terrorists in many nations oppose gender equality
____________________________________________________________________________________________________
EXTRACTED SUMMARY
  investment . Does any group or state that refuses to commit to working toward gender equality merit our money , weapons or political capital ? Any entity that refuses to treat at least half of its population as equal to the other can not be expected to protect minorities and promote tolerance 