<a href="https://colab.research.google.com/github/IntelligenceX-ai/ix-language-bertsum-extractive-summarization/blob/feature-first_implementation/Model/bertsum_extractive_summarization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Extractive Summarization with BERT**

Reference: https://skimai.com/tutorial-how-to-fine-tune-bert-for-summarization/

- Extension of BERT for text summarization:
  - **BERTSUM**: Liu et al., 2019 https://arxiv.org/abs/1908.08345
- Lighter BERT variants (smaller and faster) fine-tuned for the same task:
  - **DistilBERT** (Sanh et al., 2019)
  - **MobileBERT** (Sun et al., 2019)

**Text Summarization:**
1. **Abstractive summarization**: rewrites the key points in new words
   - may be more useful when writing essays
2. **Extractive summarization**: copies the most important sentences directly from the document
   - *a binary classification problem at the sentence level*
   - may be more useful when doing research and a quick summary of a document is needed

**BERT Summarizer:**
- Part 1: BERT encoder
- Part 2: Summarization classifier

**Part 1: BERT encoder**
- The encoder follows BERTSUM.
- It is the pretrained BERT-base encoder from the masked language modeling task (Devlin et al., 2018).
- Each sentence is assigned a label $y_i \in \{0, 1\}$ indicating whether it should be included in the final summary.
- Because the decision is made per sentence, a [CLS] token is added before each sentence.
- After a forward pass through the encoder, the last hidden layer at each [CLS] token is used as the representation of its sentence.
- => a vector representation of each sentence
- **"learns the interaction among tokens in our document"**
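
The input layout described above can be sketched in plain Python. This is only an illustration: `build_bertsum_input` is a hypothetical helper, whitespace splitting stands in for the real WordPiece tokenizer, and the trailing `[SEP]` per sentence follows the BERTSUM paper rather than anything stated in this notebook.

```python
def build_bertsum_input(sentences):
    """Prefix every sentence with [CLS] and remember where each [CLS] lands."""
    tokens, cls_positions = [], []
    for sentence in sentences:
        cls_positions.append(len(tokens))       # index of this sentence's [CLS]
        tokens.append("[CLS]")
        tokens.extend(sentence.lower().split())  # stand-in for WordPiece tokenization
        tokens.append("[SEP]")
    return tokens, cls_positions


sentences = ["The cat sat.", "It was happy."]
tokens, cls_positions = build_bertsum_input(sentences)
print(tokens)
print(cls_positions)
```

After the encoder's forward pass, the hidden states at `cls_positions` would be gathered to obtain one vector per sentence.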

**Part 2: Summarization Classifier**
- The simplest option is a feed-forward layer on top of the sentence vectors.
- => a score for each sentence
- A Transformer model with 3 layers showed the best result, which suggests that inter-sentence interactions through the self-attention mechanism are important for selecting the most important sentences.
- **"learn the interactions among sentences"**
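
As a rough sketch of the simple feed-forward variant, each sentence vector can be mapped to a score with a sigmoid over a linear layer, and the highest-scoring sentences kept. The function names, weights, and vectors below are made up for illustration, and returning the selection in document order is a choice of this sketch that the repository code may not share.

```python
import math


def score_sentences(sentence_vectors, weights, bias):
    """Apply sigmoid(w . h + b) to every sentence vector h."""
    scores = []
    for vec in sentence_vectors:
        logit = sum(w * x for w, x in zip(weights, vec)) + bias
        scores.append(1.0 / (1.0 + math.exp(-logit)))
    return scores


def select_top_k(scores, k):
    """Return the indices of the k best-scoring sentences, in document order."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


# Toy 2-dimensional sentence vectors; real [CLS] vectors are much larger.
vectors = [[0.2, -0.1], [1.5, 0.3], [-0.4, 0.8]]
scores = score_sentences(vectors, weights=[1.0, 0.5], bias=0.0)
chosen = select_top_k(scores, k=2)
print(scores, chosen)
```

The Transformer variant differs in that the sentence vectors first attend to one another before this scoring step.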

To make summarization lighter and faster for deployment on low-resource devices, the code below was modified from the BERTSUM source code:
- The BERT encoder was replaced with DistilBERT and MobileBERT.
- The summarization layers are kept the same.

**Findings from the performance of BERT-base, DistilBERT, MobileBERT:**
- DistilBERT is 40% smaller than BERT-base.
- Training loss: DistilBERT matched BERT-base, while MobileBERT's was slightly higher (it performed slightly worse).
- MobileBERT is significantly smaller and faster.

**Conclusion:**

DistilBERT retains BERT-base’s performance in extractive summarization while being 45% smaller. MobileBERT is 4x smaller and 2.7x faster than BERT-base yet retains 94% of its performance.

# **Setup**

In [2]:
!git clone https://github.com/chriskhanhtran/bert-extractive-summarization.git
%cd bert-extractive-summarization
!pip install -r requirements.txt

Cloning into 'bert-extractive-summarization'...
remote: Enumerating objects: 239, done.
remote: Total 239 (delta 0), reused 0 (delta 0), pack-reused 239
Receiving objects: 100% (239/239), 321.37 KiB | 18.90 MiB/s, done.
Resolving deltas: 100% (123/123), done.
/content/bert-extractive-summarization
ERROR: torch-1.1.0-cp36-cp36m-linux_x86_64.whl is not a supported wheel on this platform.


In [3]:
# If the setup above fails (e.g. the pinned torch wheel is not supported), install the dependencies directly:
!pip install torch
!pip install transformers
!pip install boto3
!pip install newspaper3k

Collecting transformers
  Downloading transformers-4.11.3-py3-none-any.whl (2.9 MB)
Collecting huggingface-hub>=0.0.17
  Downloading huggingface_hub-0.0.19-py3-none-any.whl (56 kB)
Collecting sacremoses
  Downloading sacremoses-0.0.46-py3-none-any.whl (895 kB)
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  A

### Download Checkpoints

In [4]:
# Download checkpoints of BERT-base, DistilBERT, and MobileBERT fine-tuned on the CNN/DailyMail dataset

!wget -O "checkpoints/bertbase_ext.pt" "https://www.googleapis.com/drive/v3/files/1t27zkFMUnuqRcsqf2fh8F1RwaqFoMw5e?alt=media&key=AIzaSyCmo6sAQ37OK8DK4wnT94PoLx5lx-7VTDE"
!wget -O "checkpoints/distilbert_ext.pt" "https://www.googleapis.com/drive/v3/files/1WxU7cHECfYaU32oTM0JByTRGS5f6SYEF?alt=media&key=AIzaSyCmo6sAQ37OK8DK4wnT94PoLx5lx-7VTDE"
!wget -O "checkpoints/mobilebert_ext.pt" "https://www.googleapis.com/drive/v3/files/1umMOXoueo38zID_AKFSIOGxG9XjS5hDC?alt=media&key=AIzaSyCmo6sAQ37OK8DK4wnT94PoLx5lx-7VTDE"

--2021-10-21 20:19:17--  https://www.googleapis.com/drive/v3/files/1t27zkFMUnuqRcsqf2fh8F1RwaqFoMw5e?alt=media&key=AIzaSyCmo6sAQ37OK8DK4wnT94PoLx5lx-7VTDE
Resolving www.googleapis.com (www.googleapis.com)... 142.251.45.106, 142.250.188.202, 172.217.13.74, ...
Connecting to www.googleapis.com (www.googleapis.com)|142.251.45.106|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 497468883 (474M) [application/octet-stream]
Saving to: ‘checkpoints/bertbase_ext.pt’


2021-10-21 20:19:24 (78.3 MB/s) - ‘checkpoints/bertbase_ext.pt’ saved [497468883/497468883]

--2021-10-21 20:19:24--  https://www.googleapis.com/drive/v3/files/1WxU7cHECfYaU32oTM0JByTRGS5f6SYEF?alt=media&key=AIzaSyCmo6sAQ37OK8DK4wnT94PoLx5lx-7VTDE
Resolving www.googleapis.com (www.googleapis.com)... 142.251.45.10, 142.250.81.202, 172.217.9.202, ...
Connecting to www.googleapis.com (www.googleapis.com)|142.251.45.10|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 324966717 (310M) 
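
Because a failed download can leave a small or partial file behind, it may be worth checking the checkpoint sizes before loading them. The following is a minimal sketch: `checkpoint_looks_complete` is a hypothetical helper, and the expected sizes are copied from the `Length:` lines in the wget log above (the MobileBERT size is not shown in the log, so it is left out).

```python
import os


def checkpoint_looks_complete(path, expected_bytes):
    """Return True if `path` exists and has exactly `expected_bytes` bytes."""
    return os.path.isfile(path) and os.path.getsize(path) == expected_bytes


# Sizes taken from the "Length:" lines in the wget log.
EXPECTED_BYTES = {
    "checkpoints/bertbase_ext.pt": 497468883,
    "checkpoints/distilbert_ext.pt": 324966717,
}

for path, size in EXPECTED_BYTES.items():
    status = "ok" if checkpoint_looks_complete(path, size) else "missing or incomplete"
    print(f"{path}: {status}")
```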

# **Usage**

**Import libraries**

In [5]:
import textwrap

import nltk
import torch
from newspaper import Article  # requires `pip install newspaper3k`

from models.model_builder import ExtSummarizer  # requires `pip install transformers`
from ext_sum import summarize

nltk.download('punkt')  # tokenizer data

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


**Input data & preprocessing**

In [7]:
# Crawl URL with `newspaper3k`
url = "https://www.cnn.com/2020/05/29/tech/facebook-violence-trump/index.html" #@param {type: 'string'}
#url = "https://www.cnn.com/2020/05/22/business/hertz-bankruptcy/index.html"
article = Article(url)
article.download()
article.parse()

# Wrap the printed text at 80 columns for readability
wrapper = textwrap.TextWrapper(width=80)
print(wrapper.fill(article.text))

# Save input text into `raw_data/input.txt`
with open('raw_data/input.txt', 'w') as f:
    f.write(article.text)

(CNN) Over and over again in 2018, during an apology tour that took him from the
halls of the US Congress to an appearance before the European Parliament, Mark
Zuckerberg said Facebook had failed to "take a broad enough view of our
responsibilities."  But two years later, Zuckerberg and Facebook are still
struggling with their responsibilities and how to handle one of their most
famous users: President Donald Trump.  Despite Zuckerberg having previously
indicated any post that "incites violence" would be a line in the sand — even if
it came from a politician — Facebook remained silent for hours Friday after
Trump was accused of glorifying violence in posts that appeared on its
platforms.  At 12:53am ET on Friday morning, as cable news networks carried
images of fires and destructive protests in Minneapolis, the President tweeted :
"These THUGS are dishonoring the memory of George Floyd, and I won't let that
happen. Just spoke to Governor Tim Walz and told him that the Military is with


**Load Summarization Model**

In [9]:
model_type = 'mobilebert' #@param ['bertbase', 'distilbert', 'mobilebert']
checkpoint = torch.load(f'checkpoints/{model_type}_ext.pt', map_location='cpu')
model = ExtSummarizer(checkpoint=checkpoint, bert_type=model_type, device="cpu")

In [10]:
# Print model's state_dict

#print("Model's state_dict:")
#for param_tensor in model.state_dict():
#    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

**Run Summarization**

In [12]:
%%time
input_fp = 'raw_data/input.txt'    # input text saved above
result_fp = 'results/summary.txt'  # the summary will be written here

# summarize() takes:
# - input file path
# - result file path
# - summarization model
# - max_length of the resulting summary
summary = summarize(input_fp, result_fp, model, max_length=1)

CPU times: user 810 ms, sys: 47.4 ms, total: 858 ms
Wall time: 1.21 s


  scores = scores.masked_fill(mask.byte(), -1e18)


In [13]:
# Print summary
print(summary)

(CNN) Over and over again in 2018, during an apology tour that took him from the halls of the US Congress to an appearance before the European Parliament, Mark Zuckerberg said Facebook had failed to "take a broad enough view of our responsibilities."


In [14]:
# Print the summary wrapped at 80 columns
print(wrapper.fill(summary))

(CNN) Over and over again in 2018, during an apology tour that took him from the
halls of the US Congress to an appearance before the European Parliament, Mark
Zuckerberg said Facebook had failed to "take a broad enough view of our
responsibilities."


# **Saving & Loading Model**

**1. Saving state_dict() - recommended method for saving models**

In [15]:
# Saving state_dict
PATH = 'state-dict-mobilebert-summarization-model.pt'
torch.save(model.state_dict(), PATH)

In [16]:
# Instantiate the model architecture before loading the saved weights into it
loaded_model = ExtSummarizer(checkpoint=checkpoint, bert_type=model_type, device="cpu")

In [17]:
# Load the saved weights
loaded_model.load_state_dict(torch.load(PATH))
# model.eval() must be called to set dropout and batch normalization layers
# to evaluation mode before running inference. Failing to do this will yield
# inconsistent inference results.
loaded_model.eval()

ExtSummarizer(
  (bert): Bert(
    (model): MobileBertModel(
      (embeddings): MobileBertEmbeddings(
        (word_embeddings): Embedding(30522, 128, padding_idx=0)
        (position_embeddings): Embedding(512, 512)
        (token_type_embeddings): Embedding(2, 512)
        (embedding_transformation): Linear(in_features=384, out_features=512, bias=True)
        (LayerNorm): NoNorm()
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): MobileBertEncoder(
        (layer): ModuleList(
          (0): MobileBertLayer(
            (attention): MobileBertAttention(
              (self): MobileBertSelfAttention(
                (query): Linear(in_features=128, out_features=128, bias=True)
                (key): Linear(in_features=128, out_features=128, bias=True)
                (value): Linear(in_features=512, out_features=128, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): MobileBertSelfOutput(
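
To see why `eval()` matters for layers such as dropout, here is a small pure-Python sketch. `TinyDropout` is a hypothetical stand-in written for this illustration, not the PyTorch class; it only mimics the training versus evaluation switch.

```python
import random


class TinyDropout:
    """Hypothetical stand-in for a dropout layer (not the PyTorch class)."""

    def __init__(self, p):
        self.p = p            # probability of zeroing an element
        self.training = True  # layers start in training mode

    def eval(self):
        self.training = False
        return self

    def __call__(self, values):
        if not self.training:
            return list(values)  # evaluation mode: identity, fully deterministic
        keep = 1.0 - self.p
        # Training mode: randomly zero elements and rescale the survivors.
        return [v / keep if random.random() < keep else 0.0 for v in values]


layer = TinyDropout(p=0.5)
hidden = [1.0, 2.0, 3.0]

train_out = layer(hidden)        # varies from call to call
eval_out = layer.eval()(hidden)  # always equal to the input
print(train_out, eval_out)
```

In training mode the same input can produce different outputs on every call, which is why inference without `eval()` gives inconsistent summaries.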
              

**2. Saving the entire model (uses Python's pickle module)**

In [18]:
# Saving entire model
PATH = 'mobilebert-summarization-model.pt'
torch.save(model, PATH)

In [19]:
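# Not strictly required: torch.load() below replaces this object.
# Only the ExtSummarizer class itself must be importable when loading.
loaded_model = ExtSummarizer(checkpoint=checkpoint, bert_type=model_type, device="cpu")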

In [21]:
# Load the entire model (the model class must already be defined or imported)
loaded_model = torch.load(PATH)
# model.eval() must be called to set dropout and batch normalization layers
# to evaluation mode before running inference. Failing to do this will yield
# inconsistent inference results.
loaded_model.eval()

ExtSummarizer(
  (bert): Bert(
    (model): MobileBertModel(
      (embeddings): MobileBertEmbeddings(
        (word_embeddings): Embedding(30522, 128, padding_idx=0)
        (position_embeddings): Embedding(512, 512)
        (token_type_embeddings): Embedding(2, 512)
        (embedding_transformation): Linear(in_features=384, out_features=512, bias=True)
        (LayerNorm): NoNorm()
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (encoder): MobileBertEncoder(
        (layer): ModuleList(
          (0): MobileBertLayer(
            (attention): MobileBertAttention(
              (self): MobileBertSelfAttention(
                (query): Linear(in_features=128, out_features=128, bias=True)
                (key): Linear(in_features=128, out_features=128, bias=True)
                (value): Linear(in_features=512, out_features=128, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): MobileBertSelfOutput(
              

**Run summarization again with loaded model**

In [23]:
%%time
# summarize() takes:
# - input file path
# - result file path
# - summarization model
# - max_length of the resulting summary
summary = summarize(input_fp, result_fp, loaded_model, max_length=1)

CPU times: user 566 ms, sys: 12.1 ms, total: 578 ms
Wall time: 703 ms


  scores = scores.masked_fill(mask.byte(), -1e18)


In [24]:
# Print summary
print(summary)

(CNN) Over and over again in 2018, during an apology tour that took him from the halls of the US Congress to an appearance before the European Parliament, Mark Zuckerberg said Facebook had failed to "take a broad enough view of our responsibilities."


In [25]:
# Print the summary wrapped at 80 columns
print(wrapper.fill(summary))

(CNN) Over and over again in 2018, during an apology tour that took him from the
halls of the US Congress to an appearance before the European Parliament, Mark
Zuckerberg said Facebook had failed to "take a broad enough view of our
responsibilities."
