# Long Form Text Summarization using JumpStart Foundation Model  

In [1]:
%%capture 

!pip install cohere_sagemaker
!pip install pandas
!pip install tqdm

## Part - I: Hosting the foundation model for real-time inference 

#### I. Imports 

In [None]:
from sagemaker import get_execution_role
from cohere_sagemaker import CohereError
from cohere_sagemaker import Client
from sagemaker import ModelPackage
import cohere_sagemaker
from tqdm import tqdm
import pandas as pd
import numpy as np
import sagemaker
import logging
import boto3
import time
import re

#### Setup Logging 

In [None]:
# Send log records of the 'sagemaker' logger (DEBUG and above) to a stream handler.
logger = logging.getLogger('sagemaker')
logger.setLevel(logging.DEBUG)
# Re-running this cell must not stack a second handler; otherwise every
# message would be emitted once per execution of the cell.
if not any(type(handler) is logging.StreamHandler for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler())

##### Log versions of dependencies 

#### II. Setup essentials 

Mapping for Model Packages (initially only us-east-1 and eu-west-1 are supported)

In [None]:
# Maps an AWS region name to the ARN of the Cohere medium model package to use
# in that region. Only the regions listed here can be used by this notebook.
model_package_map = {
    'us-east-1': 'arn:aws:sagemaker:us-east-1:865070037744:model-package/cohere-gpt-medium-v1-4-825b877abfd53d7ca65fd7b4b262c421',
    'eu-west-1': 'arn:aws:sagemaker:eu-west-1:985815980388:model-package/cohere-gpt-medium-v1-4-825b877abfd53d7ca65fd7b4b262c421'
}

In [None]:
region = boto3.Session().region_name
logger.info(f'Region = {region}')

In [None]:
# Fail fast when the current region has no entry in the model package mapping.
# ValueError is a subclass of Exception, so any broad handler still catches it.
if region not in model_package_map:
    supported_regions = ', '.join(sorted(model_package_map))
    raise ValueError(
        f'Unsupported region = {region}; supported regions: {supported_regions}'
    )

In [None]:
MODEL_PACKAGE_ARN = model_package_map[region]
logger.info(f'Model package ARN = {MODEL_PACKAGE_ARN}')

In [None]:
ROLE = get_execution_role()
session = sagemaker.Session()
logger.info(f'Role = {ROLE}')

In [None]:
timestamp = int(time.time())
MODEL_NAME = f'cohere-medium-{timestamp}'

#### III. Create a SageMaker endpoint for real-time inference 

In [None]:
model = ModelPackage(role=ROLE, 
                     model_package_arn=MODEL_PACKAGE_ARN, 
                     sagemaker_session=session, 
                     name=MODEL_NAME)
model.__dict__

In [None]:
NUM_INSTANCES = 1
INSTANCE_TYPE = 'ml.g5.xlarge'

In [None]:
%%time

# Deploy the model package behind a real-time endpoint.
# The endpoint is given the same name as the model.
model.deploy(
    initial_instance_count=NUM_INSTANCES,
    instance_type=INSTANCE_TYPE,
    endpoint_name=MODEL_NAME,
)

## Part - II: Long-form abstractive text summarization of legal judgement docs

#### I. Read, parse and chunk docs 

In [96]:
doc_name = '5.txt'

In [97]:
# Read the document line by line, normalise whitespace, and keep the
# non-empty lines in their original order.
cleaned_lines = []
with open(f'./docs/{doc_name}', encoding='iso-8859-1') as doc_file:
    for raw_line in doc_file:
        # split() + join collapses every run of whitespace (spaces, tabs,
        # newlines) into one space and trims both ends. Tabs are treated as
        # separators rather than deleted, so tab-separated words are not
        # glued into a single token.
        line = ' '.join(raw_line.split())
        if line:
            cleaned_lines.append(line)

In [None]:
# Flatten the cleaned lines into a single list of whitespace-delimited words;
# the trailing expression displays the word count.
doc = ' '.join(cleaned_lines).split()
len(doc)

In [99]:
# Number of words per chunk; the final chunk may be shorter.
chunk_size = 768
# Cut the word list into consecutive, non-overlapping windows and turn each
# window back into a space-separated string.
chunks = []
for start in range(0, len(doc), chunk_size):
    window = doc[start:start + chunk_size]
    chunks.append(' '.join(window))

In [None]:
len(chunks)

#### II. Short-form Abstractive Text Summarization

In [101]:
ENDPOINT_NAME = 'cohere-medium-1679931302'
# ENDPOINT_NAME = MODEL_NAME

In [102]:
client = Client(endpoint_name=ENDPOINT_NAME)

In [103]:
summaries_by_chunks = []

In [None]:
%%time


# Request one short summary per chunk, in document order, and collect the
# generated text of the first generation returned for each request.
for chunk in tqdm(chunks):
    response = client.generate(
        prompt=f'Context = {chunk}\nSummarize the above context.',
        max_tokens=256,
        temperature=0.2,
        return_likelihoods='GENERATION',
    )
    summaries_by_chunks.append(response.generations[0].text)

In [105]:
cleaned_summaries = []
STOP_SEQ = '. '

In [106]:
def clean_summary(summary, stop_seq=None):
    """Drop a trailing, unfinished sentence from a generated summary.

    The text is split on the sentence separator. If the final piece does not
    end with a full stop it is treated as an incomplete sentence (presumably
    generation was cut off) and removed; all complete sentences are kept.

    Args:
        summary: Summary text to clean.
        stop_seq: Sentence separator used to split and re-join the text.
            Defaults to the notebook-level ``STOP_SEQ``.

    Returns:
        The complete sentences re-joined with the separator. The last kept
        sentence may lack its full stop because the separator consumed it.
        An empty string is returned when no complete sentence remains.
    """
    if stop_seq is None:
        stop_seq = STOP_SEQ
    sents = summary.split(stop_seq)
    if not sents[-1].endswith('.'):
        # Only the last piece is incomplete, so remove exactly one element.
        sents = sents[:-1]
    # Re-join with the separator so the full stops between sentences survive.
    return stop_seq.join(sents)

In [107]:
# Normalise every raw chunk summary, trim an unfinished last sentence, and
# keep only summaries that are long enough.
for summary in summaries_by_chunks:
    # Remove apostrophes first so that collapsing whitespace afterwards also
    # cleans up any gap they leave behind.
    summary = summary.replace('\'', '')
    # Turn newlines and repeated blanks into single spaces instead of deleting
    # them; deleting a newline would fuse the words on either side of it and
    # hide the '. ' sentence boundary that clean_summary splits on.
    summary = ' '.join(summary.split())
    cleaned_summary = clean_summary(summary)
    if not cleaned_summary.endswith('.'):
        cleaned_summary = cleaned_summary + '.'
    if len(cleaned_summary) >= 64:  # at least 64 chars
        cleaned_summaries.append(cleaned_summary)

In [None]:
logger.info(f'Total number of short summaries generated = {len(cleaned_summaries)}')

#### III. Question Generation

In [109]:
questions_map = {}
total_questions_generated = 0

In [110]:
# Lower-case words that, as the first word of a line, mark it as a question.
detect_words = ['why', 'how', 'what', 'who', 'where', 'is', 'when', 'which', 'whose', 'are', 'do', 'does', 'can', 'could', 'should', 'will', 'have', 'has']


def is_a_question(question):
    """Return True when ``question`` starts with one of ``detect_words``.

    The comparison is case-insensitive and looks only at the first
    whitespace-delimited word.

    Args:
        question: Candidate question text.

    Returns:
        True if the first word is in ``detect_words``. False otherwise,
        including for empty or whitespace-only input.
    """
    words = question.split()
    if not words:
        # Blank input has no first word; answer False instead of raising IndexError.
        return False
    return words[0].lower() in detect_words

In [None]:
%%time

# Generate candidate questions for every short summary and keep, per summary,
# the set of distinct lines that look like questions.
for summary in tqdm(cleaned_summaries):
    # NOTE: the indentation inside this literal is part of the prompt text and
    # is kept exactly as it was, so the text sent to the model is unchanged.
    prompt = f"""EXTRACT QUESTIONS
        Context: 
        {summary}
        Questions:
        """
    try:
        response = client.generate(prompt=prompt,
                                   max_tokens=512,
                                   temperature=0,
                                   return_likelihoods='GENERATION')
        generated_text = response.generations[0].text
    except Exception:
        # Best effort: a failed request skips this summary, but the failure is
        # logged instead of being silently discarded.
        logger.exception('Question generation failed; skipping this summary')
        continue

    cleaned_questions = set()
    for question in generated_text.split('\n'):
        if len(question) <= 5:
            continue
        # Strip list numbering ("1.") and "Q:" prefixes from the model output.
        question = re.sub(r'\d+\.', '', question)
        question = question.replace('Q:', '')
        question = question.strip()
        # Stripping can leave nothing behind (e.g. a line that was only
        # numbering); skip such lines rather than letting them abort the
        # whole summary.
        if question and is_a_question(question):
            cleaned_questions.add(question)
    total_questions_generated += len(cleaned_questions)
    questions_map[summary] = cleaned_questions

In [None]:
logger.info(f'Total questions generated = {total_questions_generated}')

#### IV. Abstractive Question & Answering

In [None]:
%%time

# Collected (doc_name, context, question, answer) tuples.
qa_pairs = []

# Answer every generated question using the short summary it came from as context.
for context, questions in tqdm(questions_map.items()):
    for question in questions:
        # NOTE: the indentation inside this literal is part of the prompt text
        # and is kept exactly as it was.
        prompt = f"""Context = {context}
        Question = {question}
        Answer = 
        """
        try:
            response = client.generate(prompt=prompt,
                                       max_tokens=128,
                                       temperature=0,
                                       return_likelihoods='GENERATION')
            answer = response.generations[0].text.strip()
        except Exception:
            # Best effort: skip this question, but log why it has no answer
            # instead of dropping it silently.
            logger.exception(f'Answer generation failed for question: {question}')
            continue
        qa_pairs.append((doc_name, context, question, answer))

#### V. Combine short summaries into a long form summary

In [114]:
# Strip apostrophes from each short summary and join the summaries, in order,
# with a blank line between them.
long_form_summary = '\n\n'.join(
    short_summary.replace('\'', '') for short_summary in cleaned_summaries
)

#### VI. Write long form summary to disk

In [115]:
from pathlib import Path

# Name the output after the input document (e.g. '5.txt' -> 'summary_5.txt')
# instead of hard-coding the document number.
summary_path = Path('./summaries') / f'summary_{Path(doc_name).stem}.txt'
# Create the output directory if needed so the write cannot fail on a fresh checkout.
summary_path.parent.mkdir(parents=True, exist_ok=True)
# An explicit encoding keeps the output independent of the platform default.
with open(summary_path, 'w', encoding='utf-8') as out:
    out.write(long_form_summary)

#### VII. Write QA pairs to disk

In [None]:
df = pd.DataFrame(qa_pairs, columns=['doc_name', 'short_summary', 'question', 'answer'])
df.head()

In [117]:
from pathlib import Path

# Name the CSV after the input document (e.g. '5.txt' -> 'qa_pairs_5.csv')
# instead of hard-coding the document number.
qa_pairs_path = Path('./qa_pairs') / f'qa_pairs_{Path(doc_name).stem}.csv'
# Create the output directory if needed so the write cannot fail on a fresh checkout.
qa_pairs_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(qa_pairs_path, index=False)