# MemSum with Legal Domain Embeddings

In this notebook we train MemSum with word embeddings pretrained on legal-domain text.

The following models were trained:

1. Train on the ToS;DR dataset with legal embeddings
2. Train on the GovReport dataset with legal embeddings (using MemSum's GovReport checkpoint)
3. Fine-tune the model from 2 on ToS;DR

In [38]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [2]:
!pip install -r requirements.txt -q

In [3]:
import torch
torch.__version__

'1.13.0+cu116'

In [4]:
torch.cuda.empty_cache()

## Download pretrained word embedding & model checkpoints

Source of the legal-domain word2vec models: https://osf.io/qvg8s/

The models were modified to fit MemSum: the vector size was set to 200, and the \<eod>, \<pad>, and \<unk> tokens were added.
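Adding special tokens to a vocabulary and its embedding table can be sketched roughly as follows. This is a minimal illustration, not the script used to build the pickle files: the helper name `add_special_tokens`, the all-zero vector for `<pad>`, and the small random initialization for the other tokens are assumptions. Plain lists and a toy three-word vocabulary stand in for the real word2vec model to keep the sketch dependency-free.

```python
import random

def add_special_tokens(words, vectors,
                       special_tokens=("<eod>", "<pad>", "<unk>"), seed=0):
    """Append special tokens to a vocabulary and its embedding rows.

    Hypothetical helper: the initialization choices are assumptions.
    """
    rng = random.Random(seed)
    dim = len(vectors[0])
    new_words = list(words)
    new_vectors = [list(v) for v in vectors]
    for token in special_tokens:
        if token in new_words:
            continue  # already present, keep its existing vector
        new_words.append(token)
        if token == "<pad>":
            # Assumption: padding contributes nothing, so use zeros.
            new_vectors.append([0.0] * dim)
        else:
            # Assumption: small random vectors for the other tokens.
            new_vectors.append([rng.gauss(0.0, 0.1) for _ in range(dim)])
    return new_words, new_vectors

# Toy stand-in for the legal word2vec model: 3 words, 200 dimensions.
toy_words = ["court", "contract", "liability"]
toy_vectors = [[0.01 * i] * 200 for i in range(len(toy_words))]

vocab, embeddings = add_special_tokens(toy_words, toy_vectors)
print(len(vocab), len(embeddings), len(embeddings[0]))
```

Each new token gets one row appended to the embedding table, so the row index of a word in the table always matches its index in the vocabulary.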

Download the pickle files and model checkpoints from Google Drive (https://drive.google.com/drive/folders/1za083ah4oPjX14uYH8rwfbtoE-wiFJcP?usp=sharing)

# Training

In [37]:
!nvidia-smi

Tue Dec 13 23:22:57 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:05:00.0 Off |                  N/A |
| 44%   50C    P8    18W / 320W |      3MiB / 10240MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA TITAN RTX    On   | 00000000:09:00.0 Off |                  N/A |
| 41%   42C    P8     8W / 280W |   3444MiB / 24576MiB |      0%      Default |
|       

### 1. ToS;DR + Legal Embedding

In [None]:
!cd src/MemSum_Full; python train.py -training_corpus_file_name ../../tosdata/train_labelled.jsonl -validation_corpus_file_name ../../tosdata/validation.jsonl -model_folder ../../model/just_tos/run1/ -log_folder ../../log/just_tos/run1/ -vocabulary_file_name ../../sigmalaw/SigmaLaw_Vocab_200dim.pkl -pretrained_unigram_embeddings_file_name ../../sigmalaw/SigmaLaw_Embeddings_200dim.pkl -max_seq_len 50 -max_doc_len 300 -num_of_epochs 15 -save_every 100 -n_device 1 -batch_size_per_device 4 -max_extracted_sentences_per_document 13 -moving_average_decay 0.999 -p_stop_thres 0.6

1611it [00:00, 10911.57it/s]
202it [00:00, 12187.13it/s]
model restored!
optimizer restored!
[current_epoch: 2] 
current_batch restored!
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  remaining_mask_np = np.ones_like( doc_mask_np ).astype( np.bool ) | doc_mask_np
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  extraction_mask_np = np.zeros_like( doc_mask_np ).astype( np.bool ) | doc_mask_np
99it [00:50,  2.13it/s][current_batch: 01200] loss: 4.098, learning rate: 0.000100
Starting validation ...
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  remaining_mask_np = np.ones_like( doc_mask ).astype( np.bool ) | doc_mask
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  extraction_mask_np = np.

### 2. GovReport + Legal Embedding

In [35]:
!cd src/MemSum_Full; python train.py -training_corpus_file_name ../../data/train_GOVREPORT.jsonl -validation_corpus_file_name ../../data/val_GOVREPORT.jsonl -model_folder ../../model/gov/run1 -log_folder ../../log/gov/run1 -vocabulary_file_name ../../sigmalaw/SigmaLaw_Vocab_200dim.pkl -pretrained_unigram_embeddings_file_name ../../sigmalaw/SigmaLaw_Embeddings_200dim.pkl -restore_old_checkpoint True -max_seq_len 100 -max_doc_len 500 -num_of_epochs 14 -save_every 1000 -n_device 1 -batch_size_per_device 4 -max_extracted_sentences_per_document 22 -moving_average_decay 0.999 -p_stop_thres 0.6

17517it [00:04, 4338.48it/s]
974it [00:00, 5103.27it/s]
model restored!
optimizer restored!
[current_epoch: 14] 
current_batch restored!


### 3. GovReport + Legal Embeddings + ToS;DR

This run starts from the `model_batch_63000` checkpoint produced in step 2.
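How the checkpoint was handed to `train.py` is not shown in this notebook. One plausible way, sketched below, is to copy it into the new run's model folder before training. This assumes that `train.py` restores a checkpoint it finds in `-model_folder`, which is not verified here. The helper name `seed_model_folder` is hypothetical; the paths follow the folder layout used in the commands of this notebook.

```python
import shutil
from pathlib import Path

def seed_model_folder(checkpoint_path, target_folder):
    """Copy one checkpoint file into a (possibly new) model folder.

    Hypothetical helper; returns the path of the copied file.
    """
    checkpoint_path = Path(checkpoint_path)
    target_folder = Path(target_folder)
    target_folder.mkdir(parents=True, exist_ok=True)
    destination = target_folder / checkpoint_path.name
    shutil.copy2(checkpoint_path, destination)
    return destination

# Only copies if the step 2 checkpoint is actually present on disk.
source = Path("model/gov/run1/model_batch_63000.pt")
if source.exists():
    print(seed_model_folder(source, "model/gov_tos/run1"))
```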

In [80]:
!cd src/MemSum_Full; python train.py -training_corpus_file_name ../../tosdata/train_labelled.jsonl -validation_corpus_file_name ../../tosdata/validation.jsonl -model_folder ../../model/gov_tos/run1/ -log_folder ../../log/gov_tos/run1/ -vocabulary_file_name ../../sigmalaw/SigmaLaw_Vocab_200dim.pkl -pretrained_unigram_embeddings_file_name ../../sigmalaw/SigmaLaw_Embeddings_200dim.pkl -max_seq_len 50 -max_doc_len 300 -num_of_epochs 10 -save_every 100 -n_device 1 -batch_size_per_device 4 -max_extracted_sentences_per_document 13 -moving_average_decay 0.999 -p_stop_thres 0.6

1611it [00:00, 7491.95it/s]
202it [00:00, 10440.15it/s]
model restored!
optimizer restored!
[current_epoch: 0] 
current_batch restored!
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  remaining_mask_np = np.ones_like( doc_mask_np ).astype( np.bool ) | doc_mask_np
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  extraction_mask_np = np.zeros_like( doc_mask_np ).astype( np.bool ) | doc_mask_np
99it [00:56,  1.94it/s][current_batch: 00100] loss: 4.040, learning rate: 0.000100
Starting validation ...
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  remaining_mask_np = np.ones_like( doc_mask ).astype( np.bool ) | doc_mask
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  extraction_mask_np = np.z

# Testing the trained models on custom datasets

In [45]:
from summarizers import MemSum
from tqdm import tqdm
from rouge_score import rouge_scorer
import json
import numpy as np

In [95]:
rouge_cal = rouge_scorer.RougeScorer(['rouge1','rouge2', 'rougeLsum'], use_stemmer=True)

tosdr_legal = MemSum(  "model/just_tos/run2/model_batch_4300.pt", 
                  "sigmalaw/SigmaLaw_Vocab_200dim.pkl", 
                  gpu = 0 ,  max_doc_len = 300  )

In [68]:
gov_legal = MemSum(  "model/gov/run1/model_batch_63000.pt", 
                  "sigmalaw/SigmaLaw_Vocab_200dim.pkl", 
                  gpu = 0 ,  max_doc_len = 500  )

In [93]:
tosdr_gov_legal = MemSum(  "model/gov_tos/run1/model_batch_1500.pt", 
                  "sigmalaw/SigmaLaw_Vocab_200dim.pkl", 
                  gpu = 0 ,  max_doc_len = 300  )

In [48]:
with open("tosdata/test.jsonl") as f:
    test_data = [ json.loads(line) for line in f ]

In [74]:
with open("data/test_GOVREPORT.jsonl") as f:
    test_gov = [ json.loads(line) for line in f ]

In [49]:
def evaluate( model, corpus, p_stop, max_extracted_sentences, rouge_cal ):
    """Return the mean ROUGE-1, ROUGE-2 and ROUGE-Lsum F-measures over `corpus`."""
    scores = []
    for data in tqdm(corpus):
        gold_summary = data["summary"]
        # extract() takes a batch of documents; pass one and take its summary.
        extracted_summary = model.extract( [data["text"]], p_stop_thres = p_stop, max_extracted_sentences_per_document = max_extracted_sentences )[0]
        
        # rougeLsum expects sentences separated by newlines.
        score = rouge_cal.score( "\n".join( gold_summary ), "\n".join(extracted_summary)  )
        scores.append( [score["rouge1"].fmeasure, score["rouge2"].fmeasure, score["rougeLsum"].fmeasure ] )
    
    # Average each metric over all documents.
    return np.asarray(scores).mean(axis = 0)
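
To make the numbers returned by `evaluate` easier to read, here is a simplified sketch of what a ROUGE-1 F-measure captures: the unigram overlap between the gold summary and the extracted summary. It is not the `rouge_score` implementation (there is no stemming, and tokenization is a plain lowercase whitespace split). The function name `rouge1_f` and the example sentences are made up for illustration.

```python
from collections import Counter

def rouge1_f(reference, candidate):
    """Simplified ROUGE-1 F-measure on whitespace tokens (no stemming)."""
    ref_counts = Counter(reference.lower().split())
    cand_counts = Counter(candidate.lower().split())
    # Clipped overlap: each token counts at most as often as in both texts.
    overlap = sum((ref_counts & cand_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)

gold = "the service may share your data"
extracted = "the service may sell your data to partners"
print(round(rouge1_f(gold, extracted), 3))  # 0.714
```

Here 5 of the 8 extracted tokens appear in the 6-token gold summary, giving precision 5/8, recall 5/6, and an F-measure of 10/14. The arrays printed by `evaluate` are averages of such per-document F-measures for ROUGE-1, ROUGE-2 and ROUGE-Lsum, in that order.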

In [77]:
evaluate( tosdr_legal, test_data, 0.6, 13, rouge_cal )

100%|██████████| 201/201 [00:27<00:00,  7.43it/s]


array([0.4141877 , 0.27054185, 0.40019928])

In [75]:
evaluate( gov_legal, test_gov, 0.6, 22, rouge_cal )

100%|██████████| 973/973 [06:20<00:00,  2.56it/s]


array([0.593514  , 0.28232396, 0.56586892])

In [94]:
evaluate( tosdr_gov_legal, test_data, 0.6, 13, rouge_cal )

100%|██████████| 201/201 [00:26<00:00,  7.50it/s]


array([0.40880941, 0.26623319, 0.39521398])

To cite MemSum, please use the following BibTeX entry:

```
@inproceedings{gu-etal-2022-memsum,
    title = "{M}em{S}um: Extractive Summarization of Long Documents Using Multi-Step Episodic {M}arkov Decision Processes",
    author = "Gu, Nianlong  and
      Ash, Elliott  and
      Hahnloser, Richard",
    booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
    month = may,
    year = "2022",
    address = "Dublin, Ireland",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2022.acl-long.450",
    doi = "10.18653/v1/2022.acl-long.450",
    pages = "6507--6522",
    abstract = "We introduce MemSum (Multi-step Episodic Markov decision process extractive SUMmarizer), a reinforcement-learning-based extractive summarizer enriched at each step with information on the current extraction history. When MemSum iteratively selects sentences into the summary, it considers a broad information set that would intuitively also be used by humans in this task: 1) the text content of the sentence, 2) the global text context of the rest of the document, and 3) the extraction history consisting of the set of sentences that have already been extracted. With a lightweight architecture, MemSum obtains state-of-the-art test-set performance (ROUGE) in summarizing long documents taken from PubMed, arXiv, and GovReport. Ablation studies demonstrate the importance of local, global, and history information. A human evaluation confirms the high quality and low redundancy of the generated summaries, stemming from MemSum{'}s awareness of extraction history.",
}
```