# SageMaker JumpStart - invoke text generation endpoint

This notebook demonstrates how to attach a predictor to an existing endpoint name and invoke the endpoint with example payloads.

In [1]:
from sagemaker.predictor import retrieve_default

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml


Retrieve a predictor from your deployed endpoint name.

In [2]:
endpoint_name = "jumpstart-dft-meta-textgeneration-l-20240428-043847"
predictor = retrieve_default(endpoint_name)

Now query your endpoint with example payloads.

This model supports the following payload parameters. You may specify any subset of these parameters when invoking an endpoint.

* **do_sample:** If True, activates logits sampling. If specified, it must be boolean.
* **max_new_tokens:** Maximum number of generated tokens. If specified, it must be a positive integer.
* **repetition_penalty:** A penalty for repetitive generated text. 1.0 means no penalty.
* **return_full_text:** If True, input text will be part of the output generated text. If specified, it must be boolean. The default value for it is False.
* **stop**: If specified, it must a list of strings. Text generation stops if any one of the specified strings is generated.
* **seed**: Random sampling seed.
* **temperature:** Controls the randomness in the output. Higher temperature results in output sequence with low-probability words and lower temperature results in output sequence with high-probability words. If `temperature` -> 0, it results in greedy decoding. If specified, it must be a positive float.
* **top_k:** In each step of text generation, sample from only the `top_k` most likely words. If specified, it must be a positive integer.
* **top_p:** In each step of text generation, sample from the smallest possible set of words with cumulative probability `top_p`. If specified, it must be a float between 0 and 1.
* **truncate:** Truncate inputs tokens to the given size.
* **typical_p:** Typical decoding mass, according to [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666).
* **best_of:** Generate best_of sequences and return the one if the highest token logprobs.
* **watermark:** Whether to perform watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).
* **details:** Return generation details, to include output token logprobs and IDs.
* **decoder_input_details:** Return decoder input token logprobs and IDs.
* **top_n_tokens:** Return the N most likely tokens at each step.

In [3]:
import random
import argparse
import os
import time
import re
import json
import multiprocessing
import numpy as np
import torch
from statistics import mean

import datasets
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset

In [4]:
dataset = load_dataset("gsm8k", "main", ignore_verifications=True)



In [5]:
data_train = dataset['train']
data_test = dataset['test']

In [6]:
def data_reader(data=None):

    questions = []
    answers = []
    
    for datum in data:
        questions.append(datum["question"].strip())
        answers.append(datum["answer"].split("#### ")[-1]) 
    
    q_len_list = []
    for q in questions:
        q_len_list.append(len(q.split(" ")))
    q_len_mean = mean(q_len_list)
    
    print("dataset : {}".format('gsm8k'))
    print("data size : {}".format(len(answers)))
    print("average num of words for each sample : {}".format(q_len_mean))
    
    return questions, answers

In [7]:
def create_prompt(direct_answer_trigger_for_fewshot=None, cot_flag=None):
    x, z, y = [], [], []
    
    direct_answer_trigger_for_fewshot = "The final answer is " if direct_answer_trigger_for_fewshot is None else direct_answer_trigger_for_fewshot
    # example sentences ...    
    if True:
        
        x.append("There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?")
        z.append("There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6.")
        y.append("6")

        x.append("If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?")
        z.append("There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.")
        y.append("5")        

        x.append("Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?")
        z.append("Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39.")
        y.append("39")        

        x.append("Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?")
        z.append("Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8.")
        y.append("8")        

        x.append("Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?")
        z.append("Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9.")
        y.append("9")        

        x.append("There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?")
        z.append("There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29.")
        y.append("29")        

        x.append("Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?")
        z.append("Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls.")
        y.append("33")        

        x.append("Olivia has $23. She bought five bagels for $3 each. How much money does she have left?")
        z.append("Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8.")
        y.append("8")
    
    else:
        raise ValueError("dataset is not properly defined ...")
        
    # randomize order of the examples ...
    index_list = list(range(len(x)))
    random.shuffle(index_list)

    # Concatenate demonstration examples ...
    demo_text = ""
    for i in index_list:
#         if cot_flag:
        demo_text += "\nQuestion: " + x[i] + "\n ### Answer: " + z[i] + "\n####" + \
                     direct_answer_trigger_for_fewshot + " " + y[i] + ".\n\n"
#         else:
#             demo_text += "Q: " + x[i] + "\nA: " + \
#                          args.direct_answer_trigger_for_fewshot + " " + y[i] + ".\n\n"
    
    return demo_text

In [8]:
import json
def get_rag_questions(filename):
    with open(filename, 'r') as file:
        ret = []
        for line in file:
            # print(line)
            json_obect = json.loads(line)
            ret.append(json_obect)
    return ret

In [9]:
rag_data = get_rag_questions('gsm8k_rag.jsonl')

In [10]:
def answer_question(question, index=0, rag=False, rag_data=rag_data):
    if rag:
        prompt = "Look at the examples below where math questions have been solved by reasoning step by step. Notice that the final answer for each question is after ####\n"
        prompt += rag_data[index]['prompt']
        prompt += '\nSimilar to how these questions have been solved, let\'s reason step by step and solve the following question and generate an answer. Make sure that the final numeric answers is after ### '
        prompt += "\n{\"question\": "
        prompt += question
        prompt += ", \"answer:\" "
    else:
        prompt = "Look at the examples below where math questions have been solved by reasoning step by step\n"
        prompt += create_prompt()
        prompt += '\nSimilar to how these questions have been solved, let\'s reason step by step and solve the following question and generate an answer. Make sure that the answer starts with ### Answer: '
        prompt += "\nQuestion: " + question + "\n### Answer: "

    # print(prompt)
    payload = {
        "inputs": "<s>[INST] " + prompt + " [/INST] ",
        "parameters": {
            "max_new_tokens": 512,
            "top_p": 0.9,
            "temperature": 0.6
        }
    }
    try:
        response = predictor.predict(payload)
    except:
        return "0"
    try:
        answer = response[0]['generated_text']
    except:
        print("error")
        answer = "0"
    return answer

# print(answer_question(context="Hi", question="What is 2+2 ?"))

def extract_last_number(s):
    # Find all sequences of digits in the string
    if s is None:
        return 0
    steps = s.split('\n')
    for s in reversed(steps):
        numbers = re.findall(r'\d+', s)
        if numbers:
            return numbers[-1]
    # Return the last number found, or None if no numbers are present
    return 0

In [77]:
# print(rag_data[0]['prompt'])

In [11]:
def evaluate(dataset, k=100):
    questions, answers = data_reader(dataset)
    model_answers_no_verify = []
    correct_answers_cnt = 0
    for i, question in enumerate(questions[0:k]):
        answer = float(answers[i])
        model_long_answer = answer_question(question)
        model_short_answer = float(extract_last_number(model_long_answer))
        model_answers_no_verify.append({
            'correct_answer': answer,
            'model_short_answer': model_short_answer,
            'model_long_answer': model_long_answer
        })
        correct_answers_cnt += 1 if model_short_answer == answer else 0
        # print(model_long_answer)
        print(model_short_answer, answer)
        # if i % 10 == 9:
        #     print(model_long_answer)
        if i % 100 == 99:
            break
    return model_answers_no_verify, correct_answers_cnt

In [12]:
import json
def write_results_to_file(filename, data):
    # Open a file and write each dictionary as a JSON string
    with open(filename, 'w') as file:
        for item in data:
            json_string = json.dumps(item)
            # print(item)
            # print('\n')
            file.write(json_string + '\n')

In [17]:
def self_verify_1(question, rag=False):
    prev_answer = None

    for i in range(5):
        model_long_answer = answer_question(question, rag=True)
        model_short_answer = float(extract_last_number(model_long_answer))

        if prev_answer == model_short_answer:
            return True, model_long_answer, model_short_answer, i
        prev_answer = model_short_answer

    return False, model_long_answer, model_short_answer, 5



def self_verify_2(question, rag=False):
    prev_verification_answer = None
    answers_dict = {}
    for i in range(4):
        model_long_answer = answer_question(question, rag=True)
        model_short_answer = float(extract_last_number(model_long_answer))
        answers_dict_ = answers_dict.get(model_short_answer, [])
        if answers_dict_ == []:
            answers_dict_.append(1)
            answers_dict_.append(model_long_answer)
        else:
            answers_dict_[0] = answers_dict_[0] + 1
        answers_dict[model_short_answer] = answers_dict_
    answers_dict = {k: v for k, v in sorted(answers_dict.items(), key=lambda x:x[1][0])}
    for k, v in answers_dict.items():
        model_short_answer = k
        ans_cnt = v[0]
        model_long_answer = v[1]
    verified = (ans_cnt >= 2)
    return verified, model_long_answer, model_short_answer, ans_cnt

def evaluate_with_self_verify_(dataset, k=100, rag=False, filename=None, strategy=1):
    questions, answers = data_reader(data_test)
    model_answers_verify = []
    correct_answers_cnt, verification_confidence = 0, 0
    for i, question in enumerate(questions[0:k]):
        answer = float(answers[i])
        if strategy == 1:
            verified,  model_long_answer, model_short_answer, verify_step = self_verify_1(question, rag)
        elif strategy == 2:
            verified, model_long_answer, model_short_answer, ans_cnt = self_verify(question, rag)
        dct = {
            'number': i,
            'verified': verified,
            'correct_answer': answer,
            'model_short_answer': model_short_answer,
            'model_long_answer': model_long_answer
        }
        if strategy == 1:
            dct['verify_step'] = verify_step
        elif strategy == 2:
             dct['ans_cnt'] = ans_cnt
        model_answers_verify.append(dct)
        
            
        correct_answers_cnt += 1 if model_short_answer == answer else 0
        # verification_confidence += 1 if verified else 0
        print(answer, model_short_answer, verified)
        # filename = 'base_llama_gsm_8k_verify_100samples.json'
        with open(filename, 'a') as file:
            json_string = json.dumps(model_answers_verify[i])
            file.write(json_string + '\n')
    return model_answers_verify, correct_answers_cnt, verification_confidence

In [18]:
model_answers_verify, correct_answer_count, verification_confidence = evaluate_with_self_verify_(data_test, k=1000, rag=False, filename='llama_verify_1_rag.json')

dataset : gsm8k
data size : 1319
average num of words for each sample : 46.91357088703563
18.0 32.0 True
3.0 3.0 True
70000.0 0.0 True
540.0 540.0 False
20.0 10.0 False
64.0 3.0 False
260.0 128.0 True
160.0 67.0 True
45.0 140.0 False
460.0 115.0 False
366.0 366.0 True
694.0 803.0 False
13.0 5.0 True
18.0 5.0 True
60.0 45.0 False
125.0 125.0 True
230.0 460.0 False
57500.0 57500.0 True
7.0 84.0 True
6.0 3.0 True
15.0 91.0 False
14.0 48.0 False
7.0 7.0 True
8.0 10.0 False
26.0 26.0 True
2.0 5.0 True
243.0 233.0 True
16.0 0.0 False
25.0 40.0 False
104.0 5.0 True
109.0 110.0 True
80.0 96.0 True
35.0 35.0 True
70.0 140.0 True
23.0 38.0 True
9.0 9.0 True
75.0 0.0 True
2.0 2.0 False
10.0 120.0 False
18.0 2.0 False
8.0 16.0 False
200.0 200.0 True
26.0 26.0 True
48.0 83.0 False
20.0 0.0 True
104.0 56.0 False
163.0 193.0 False
800.0 400.0 False
8.0 6.0 True
30.0 39.0 False
294.0 80.0 True
5.0 2.0 False
15.0 15.0 False
40.0 380.0 False
40.0 33.0 True
14.0 14.0 True
3.0 3.0 True
83.0 3755.0 False
5