# Long Form Text Summarization using JumpStart Foundation Model  

In [2]:
%%capture 

# !pip install cohere_sagemaker
!pip install pandas
!pip install tqdm

In [20]:
from sagemaker import get_execution_role
# from cohere_sagemaker import CohereError
# from cohere_sagemaker import Client
from sagemaker import ModelPackage
import cohere_sagemaker
from tqdm import tqdm
import pandas as pd
import numpy as np
import sagemaker
import logging
import boto3
import time
import re
import json

In [29]:
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

## Part - I: Hosting the foundation model for real-time inference 

#### I. Imports 

#### Setup Logging 

In [3]:
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies 

In [4]:
logger.info(f'[Using SageMaker version: {sagemaker.__version__}]')
logger.info(f'[Using Boto3 version: {boto3.__version__}]')

[Using SageMaker version: 2.132.0]
[Using Boto3 version: 1.26.69]


#### II. Setup essentials 

Mapping for Model Packages (initially only us-east-1 and eu-west-1 is supported)

In [None]:
model_package_map = {
    'us-east-1': 'arn:aws:sagemaker:us-east-1:865070037744:model-package/cohere-gpt-medium-v1-4-825b877abfd53d7ca65fd7b4b262c421',
    'eu-west-1': 'arn:aws:sagemaker:eu-west-1:985815980388:model-package/cohere-gpt-medium-v1-4-825b877abfd53d7ca65fd7b4b262c421'
}

In [None]:
region = boto3.Session().region_name
logger.info(f'Region = {region}')

In [None]:
if region not in model_package_map.keys():
    raise Exception(f'Unsupported region = {region}')

In [None]:
MODEL_PACKAGE_ARN = model_package_map[region]
logger.info(f'Model package ARN = {MODEL_PACKAGE_ARN}')

In [None]:
ROLE = get_execution_role()
session = sagemaker.Session()
logger.info(f'Role = {ROLE}')

In [None]:
timestamp = int(time.time())
MODEL_NAME = f'cohere-medium-{timestamp}'

#### III. Create a SageMaker endpoint for real-time inference 

In [None]:
model = ModelPackage(role=ROLE, 
                     model_package_arn=MODEL_PACKAGE_ARN, 
                     sagemaker_session=session, 
                     name=MODEL_NAME)
model.__dict__

In [None]:
NUM_INSTANCES = 1
INSTANCE_TYPE = 'ml.g5.xlarge'

In [None]:
%%time

model.deploy(NUM_INSTANCES, 
             INSTANCE_TYPE, 
             endpoint_name=MODEL_NAME)

## Part 2: Long-form abstractive text summarization of legal judgement docs

#### I. Read, parse and chunk docs 

In [3]:
doc_name = '5.txt'

In [7]:
cleaned_lines = []
with open(f'./docs/{doc_name}', encoding='iso-8859-1') as doc:
    for line in doc.readlines():
        line = line.strip()
        line = re.sub(' +', ' ', line)
        line = line.replace('\n', '')
        line = line.replace('\t', '')
        line = line.replace('  ', ' ')
        if len(line) > 0:
            cleaned_lines.append(line)

In [8]:
doc = ' '.join(cleaned_lines)
doc = doc.split()
len(doc)

4168

In [9]:
chunk_size = 768
chunks = [' '.join(doc[i:i+chunk_size]) for i in range(0, len(doc), chunk_size)]

In [10]:
len(chunks)

6

In [13]:
chunks[0]

'Civil Appeal No. 8 of 1951. Appeal from the judgment and decree dated 12th October, 1944, of the High Court of Judicature at Allahabad (Allsop and Malik JJ.)in First Appeal No. 374 of 1941 arising out of a Decree dated 31st July, 1941, of the Court of the Civil Judge, Moradabad, in Original Suit No. 9 of 1941. Bakshi Tek Chand (section K. Kapoor, with him) for the appel lant. Achhru Ram (Jwala Prasad, with him) for the respondent. February 22. The judgment of the Court was deliv ered by BoSE J. This is a litigation between two branches of a family whose common ancestor was one Megh Raj Singh The family tree is as follows: Megh Raj Singh Jawahar Singh Madan Singh Shankar Lal(d 1884) Brijlal (d. 1889 or (1890) Daughter: Met. Mohan Dei (d. Oct 1929) Kishan Lal Mahabir Prasad Husband: Narain Das (d. 21 5 1940) (d. 1921) Shri Kishan Das Mst. Deoki Jugal Kishore Amar Nath (d.march 1929) (d. 1894) Plff. 1 Plff.2. Dhiyan Singh Jai Bhagwan Singh Deft. 1 Deft. 2 Ghas Ram Onkar Prasad The disput

#### II. Short-form Abstractive Text Summarization

In [23]:
ENDPOINT_NAME = 'huggingface-text2text-flan-t5-xl-1681986641'
CONTENT_TYPE = 'application/json'
# ENDPOINT_NAME = MODEL_NAME

In [16]:
# client = Client(endpoint_name=ENDPOINT_NAME)
client = boto3.client('sagemaker-runtime')

In [17]:
summaries_by_chunks = []

In [32]:
%%time

MAX_LENGTH = 256
NUM_RETURN_SEQUENCES = 1
TOP_K = 0
TOP_P = 0.7
DO_SAMPLE = True 
TEMPERATURE = 0.2


for chunk in tqdm(chunks):
    prompt = f'{chunk}\nSummarize the above context.'
    payload = {'text_inputs': prompt, 
           'max_length': MAX_LENGTH, 
           'temperature': TEMPERATURE,
           'num_return_sequences': NUM_RETURN_SEQUENCES,
           'top_k': TOP_K,
           'top_p': TOP_P,
           'do_sample': DO_SAMPLE}
    payload = json.dumps(payload).encode('utf-8')

    response = client.invoke_endpoint(EndpointName=ENDPOINT_NAME, 
                                  ContentType=CONTENT_TYPE, 
                                  Body=payload)
    
    model_predictions = json.loads(response['Body'].read())
    generated_text = model_predictions['generated_texts'][0]
    summaries_by_chunks.append(generated_text)


100%|██████████| 6/6 [00:10<00:00,  1.76s/it]

CPU times: user 36.2 ms, sys: 5.8 ms, total: 42 ms
Wall time: 10.5 s





In [33]:
cleaned_summaries = []
STOP_SEQ = '. '

In [34]:
def clean_summary(summary):
    valid_sents = []
    sents = summary.split(STOP_SEQ)
    last_sent = sents[-1]
    if not last_sent.endswith('.'):
        sents = sents[0:-2]
    return ' '.join(sents)

In [35]:
for summary in summaries_by_chunks:
    summary = summary.replace('\n', '')
    summary = summary.replace('  ', ' ')
    summary = summary.replace('\'', '')
    summary = summary.strip()
    cleaned_summary = clean_summary(summary)
    if not cleaned_summary.endswith('.'):
        cleaned_summary = cleaned_summary + '.'
    if len(cleaned_summary) >= 64:  # atleast 64 chars
        cleaned_summaries.append(cleaned_summary)

In [36]:
logger.info(f'Total number of short summaries generated = {len(cleaned_summaries)}')

Total number of short summaries generated = 6


In [37]:
cleaned_summaries

['Brijlal and Mst Mohan Dei were rival claimants to a part of a property Brijlal claimed to have the whole of the property and Mst Mohan Dei had only a right of maintenance They pressed their claims and persuaded an arbitrator that they had an immediate right to part of the estate Mst Mohan Dei resisted this claim and contended that she was entitled to separate and exclusive possession, and in any event, that she was entitled in absolute right to a part of the property On the facts which now emerge it is evident that Brijlal had no right and that his hopes of one day succeeding as reversioner were remote Mohan Dei had a son Shri Kishan Das who was the next presumptive reversioner and as the boy was a good deal younger than Brijlal, Brijlal s chances were slim Actually, the boy survived Brijlal by nearly forty years.',
 'A judicial committee of the Court of Appeal in a land dispute between Kishan Lal and his son, concluded that the estoppel against Kishan Lal, which had arose under the 

#### III. Question Generation

In [38]:
questions_map = {}
total_questions_generated = 0

In [39]:
detect_words = ['why', 'how', 'what', 'who', 'where', 'is', 'when', 'which', 'whose', 'are', 'do', 'does', 'can', 'could', 'should', 'will', 'have', 'has']

def is_a_question(question):
    first_word = question.split()[0]
    if first_word.lower() in detect_words:
        return True
    return False

In [40]:
%%time

for summary in tqdm(cleaned_summaries):
        prompt = f"""EXTRACT QUESTIONS
        Context: 
        {summary}
        Questions:
        """
        payload = {'text_inputs': prompt, 
               'max_length': 512, 
               'temperature': 0,
               'num_return_sequences': 1,
               'top_k': TOP_K,
               'top_p': TOP_P,
               'do_sample': DO_SAMPLE}
        payload = json.dumps(payload).encode('utf-8')
    
        try:            
            response = client.invoke_endpoint(EndpointName=ENDPOINT_NAME, 
                                  ContentType=CONTENT_TYPE, 
                                  Body=payload)
    
            model_predictions = json.loads(response['Body'].read())
            generated_text = model_predictions['generated_texts'][0]
            
            
            questions = generated_text.split('\n')
            cleaned_questions = set()
            for question in questions:
                if len(question) > 5:
                    question = re.sub(r'\d+\.', '', question)
                    question = question.replace('Q:', '')
                    question = question.strip()
                    if is_a_question(question) is True:
                        cleaned_questions.add(question)
            total_questions_generated += len(cleaned_questions)
            questions_map[summary] = cleaned_questions
        except Exception:
            pass

100%|██████████| 6/6 [00:00<00:00, 19.71it/s]

CPU times: user 33.8 ms, sys: 103 µs, total: 33.9 ms
Wall time: 308 ms





In [41]:
logger.info(f'Total questions generated = {total_questions_generated}')

Total questions generated = 0


#### 4. Abstractive Question & Answering

In [113]:
%%time

qa_pairs = []

for context, questions in tqdm(questions_map.items()):
    for question in questions:
        prompt = f"""Context = {context}
        Question = {question}
        Answer = 
        """
        try:
            response = client.generate(prompt=prompt, 
                               max_tokens=128, 
                               temperature=0, 
                               return_likelihoods='GENERATION')

            generated_text = response.generations[0].text
            answer = generated_text.strip()
            qa_pairs.append((doc_name, context, question, answer))
        except Exception:
            pass

100%|██████████| 6/6 [01:36<00:00, 16.12s/it]

CPU times: user 78.5 ms, sys: 10.1 ms, total: 88.7 ms
Wall time: 1min 36s





#### 5. Combine short summaries into a long form summary

In [114]:
long_form_summary = []
for short_summary in cleaned_summaries:
    short_summary = short_summary.replace('\'', '')
    long_form_summary.append(short_summary)
long_form_summary = '\n\n'.join(long_form_summary)

#### 6. Write long form summary and QA pairs to disk

In [115]:
with open(f'./summaries/summary_5.txt', 'w') as out:
    out.write(long_form_summary)

#### 7. Write QA pairs to disk

In [116]:
df = pd.DataFrame(qa_pairs, columns=['doc_name', 'short_summary', 'question', 'answer'])
df.head()

Unnamed: 0,doc_name,short_summary,question,answer
0,5.txt,The context is a civil appeal filed by Bakshi ...,What is the position of the plaintiffs?,The plaintiffs are entitled to the reversion.
1,5.txt,The context is a civil appeal filed by Bakshi ...,What is the effect of the first question?,The first question is about the nature of the ...
2,5.txt,The context is a civil appeal filed by Bakshi ...,What is the effect of the award?,The effect of the award is that the plaintiffs...
3,5.txt,The context is a civil appeal filed by Bakshi ...,What is the nature of the award?,
4,5.txt,The context is a civil appeal filed by Bakshi ...,What is the position of the defendants?,The defendants say that it gave Mst. Mohan Dei...


In [117]:
df.to_csv('./qa_pairs/qa_pairs_5.csv', index=False)