<a href="https://colab.research.google.com/github/pramodith/llm_exploration/blob/colab/dynamic_prompt_token_dropping.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [24]:
!pip install transformers
!pip install langchain
!pip install datasets
!pip install huggingface_hub



In [39]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List
import torch
from pprint import pprint
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from datasets import load_dataset
from langchain.callbacks import get_openai_callback

In [26]:
# Small OPT checkpoint, used here only to score how predictable each prompt token is.
model_name = "facebook/opt-125m"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the EOS token for padding; the model config is aligned with this below.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()  # inference mode (disables dropout layers)
model.config.pad_token_id = model.config.eos_token_id

In [27]:
def to_tokens_and_probs(model: AutoModelForCausalLM, tokenizer: AutoTokenizer, input_texts: List[str]):
    """
    Score every token of each input text under a causal language model.

    The probability recorded for the token at position ``t`` is the one the model
    assigned to it from its next-token distribution at position ``t - 1``. The
    first token of each sequence has no preceding context and is therefore not
    scored.
    Reference: https://discuss.huggingface.co/t/announcement-generation-get-probabilities-for-generated-output/30075/17

    Args:
        model (AutoModelForCausalLM): causal LM used to score the texts.
        tokenizer (AutoTokenizer): tokenizer matching ``model``. It must define a
            pad token when ``input_texts`` holds more than one text.
        input_texts (List[str]): texts to score.

    Returns:
        List[List[Tuple[str, float, int, int]]]: one list per input text. Each
        entry is a ``(decoded_token, probability, token_pos, token_id)`` tuple.
        ``token_pos`` indexes the tokenized ids with the first token removed,
        i.e. it is the original index minus one. Special tokens (e.g.
        BOS/EOS/padding) are omitted.
    """
    encoded = tokenizer(input_texts, padding=True, return_tensors="pt")
    input_ids = encoded.input_ids.to(model.device)
    # Pass the mask so padded positions in a multi-text batch are ignored.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)

    # Inference only: no_grad avoids building an autograd graph over the
    # (batch, seq, vocab) logits tensor.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    probs = torch.softmax(outputs.logits, dim=-1)

    # Collect the probability of each actual token. The distribution at index 0
    # predicts the token at index 1, so shift the two tensors against each other.
    probs = probs[:, :-1, :]
    input_ids = input_ids[:, 1:]
    gen_probs = torch.gather(probs, 2, input_ids[:, :, None]).squeeze(-1)

    # Set lookup instead of scanning a list with tensor comparisons per token.
    special_ids = set(tokenizer.all_special_ids)
    batch = []
    for input_sentence, input_probs in zip(input_ids.tolist(), gen_probs.tolist()):
        text_sequence = []
        for token_pos, (token_id, p) in enumerate(zip(input_sentence, input_probs)):
            if token_id not in special_ids:
                text_sequence.append((tokenizer.decode(token_id), p, token_pos, token_id))
        batch.append(text_sequence)
    return batch

In [28]:
# Sanity check: score each token of one short sentence with the OPT model.
sample_prompt = ["The capital of France is Paris."]
# One list (for the single text) of (decoded_token, probability, token_pos, token_id) tuples.
token_probs = to_tokens_and_probs(model, tokenizer, sample_prompt)


In [29]:
# Display the per-token probabilities computed above.
token_probs

[[('The', 0.08746562898159027, 0, 133),
  (' capital', 8.743484795559198e-05, 1, 812),
  (' of', 0.14376769959926605, 2, 9),
  (' France', 0.005475882440805435, 3, 1470),
  (' is', 0.24484379589557648, 4, 16),
  (' Paris', 0.0067796302028000355, 5, 2201),
  ('.', 0.33608853816986084, 6, 4)]]

Store your OpenAI API key in Colab's Secrets feature under the name `OPENAI_API_KEY`.

In [66]:
from google.colab import userdata
# Read the API key from Colab's Secrets store instead of hard-coding it.
OPENAI_API_KEY = userdata.get('OPENAI_API_KEY')
# Price of one prompt (input) token: 0.0010 per 1,000 tokens.
# NOTE(review): presumably USD for gpt-3.5-turbo at the time of writing — verify against current pricing.
GPT_35_TURBO_COST_PER_INP_TOKEN = 0.0010/1000

In [34]:
# Chat client shared by all experiments; temperature=0 to reduce sampling randomness between runs.
chat = ChatOpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)

In [35]:
from pydantic import BaseModel
from typing import Any


class ReadingComprehensionPrompt(BaseModel):
    """System and human message templates for the reading-comprehension task.

    ``get_llm_response`` combines both into a ``ChatPromptTemplate``. The human
    template expects the ``passage`` and ``question`` variables.
    """

    # Annotated as Any: the defaults are LangChain message-prompt templates,
    # not plain strings, so a ``str`` annotation would be wrong.
    system_message: Any = SystemMessagePromptTemplate.from_template(
        # Adjacent string literals are concatenated with no separator, so each
        # sentence carries its own trailing space. Without it the prompt would
        # read "...class.Your teacher...".
        "You are a very smart student in a reading comprehension class. "
        "Your teacher is giving you a reading comprehension test. You are given a passage and a question. "
        "You must answer the question based on the passage."
    )
    human_message: Any = HumanMessagePromptTemplate.from_template("Passage: {passage}\nQuestion: {question}\nAnswer: ")


rc_prompt = ReadingComprehensionPrompt()

In [51]:
def get_llm_response(prompt: BaseModel, chat: ChatOpenAI, passage: str, question: str):
    """Ask ``chat`` one reading-comprehension question and meter the call.

    Args:
        prompt: object exposing ``system_message`` and ``human_message`` prompt
            templates; the human template is filled with ``passage`` and ``question``.
        chat: LangChain chat model to query.
        passage: text the question is about.
        question: question to answer from ``passage``.

    Returns:
        ``(reply, usage)``: the model's reply message and the OpenAI callback
        handler that recorded token counts and cost for this single request.
    """
    template = ChatPromptTemplate.from_messages(
        [prompt.system_message, prompt.human_message]
    )
    messages = template.format_prompt(passage=passage, question=question).to_messages()

    # Only the chat call is inside the callback context, so ``usage`` counts just this request.
    with get_openai_callback() as usage:
        reply = chat(messages)
    return reply, usage

In [52]:
# Smoke test: a question the passage cannot answer, to see whether the model refuses rather than guesses.
get_llm_response(rc_prompt, chat, "The capital of France is Paris.", "What is the capital of Australia?")

(AIMessage(content='The passage does not provide any information about the capital of Australia.'),
 Tokens Used: 86
 	Prompt Tokens: 73
 	Completion Tokens: 13
 Successful Requests: 1
 Total Cost (USD): $0.0001355)

In [61]:
from typing import Optional


def sample_squad_dataset(num_samples: int = 100, seed: Optional[int] = None):
    """Return a random sample of ``num_samples`` examples from the SQuAD validation split.

    Args:
        num_samples: number of examples to keep. Values larger than the split
            fail inside ``select``.
        seed: optional seed forwarded to ``Dataset.shuffle`` so the sample (and
            any results computed from it) can be reproduced. ``None``, the
            default, keeps the previous behaviour of a fresh random sample on
            every call.

    Returns:
        A ``datasets.Dataset`` with ``num_samples`` rows.
    """
    dataset = load_dataset("squad", split="validation")
    return dataset.shuffle(seed=seed).select(range(num_samples))

In [62]:
# Random 100-example sample (the default) of the SQuAD validation split; unseeded, so it differs per run.
squad_dataset = sample_squad_dataset()

In [63]:
# Inspect one example to see the record layout.
squad_dataset[0]

{'id': '56e76abf37bdd419002c3f75',
 'title': 'Teacher',
 'context': "Teachers face several occupational hazards in their line of work, including occupational stress, which can negatively impact teachers' mental and physical health, productivity, and students' performance. Stress can be caused by organizational change, relationships with students, fellow teachers, and administrative personnel, working environment, expectations to substitute, long hours with a heavy workload, and inspections. Teachers are also at high risk for occupational burnout.",
 'question': "What can hurt a teacher's mental and physical health?",
 'answers': {'text': ['occupational stress',
   'occupational stress',
   'occupational stress'],
  'answer_start': [76, 76, 76]}}

In [74]:
import numpy as np
from typing import List, Tuple


def get_token_dropped_text(doc: str, tokenizer: AutoTokenizer, dropout_percent: float = 0.1) -> Tuple[str, List[str]]:
    """Shorten ``doc`` by removing the tokens the OPT model finds most predictable.

    Every token of ``doc`` is scored with ``to_tokens_and_probs``. The
    ``int(n_tokens * dropout_percent)`` tokens with the highest probability are
    removed, and the remaining token ids are decoded back to text.

    Args:
        doc: text to shorten.
        tokenizer: tokenizer matching the notebook-global ``model``.
        dropout_percent: fraction of scored tokens to drop, in ``[0, 1]``.

    Returns:
        ``(dropped_token_text, dropped_tokens)``: the decoded text with the
        tokens removed, and the decoded strings of the removed tokens ordered
        from highest to lowest probability.

    Raises:
        ValueError: if ``dropout_percent`` is outside ``[0, 1]``.
    """
    if not 0.0 <= dropout_percent <= 1.0:
        raise ValueError(f"dropout_percent must be in [0, 1], got {dropout_percent}")

    # Relies on the notebook-global ``model`` loaded earlier.
    scored = to_tokens_and_probs(model, tokenizer, [doc])[0]
    # Honour dropout_percent. Previously this count was computed but a
    # hard-coded 10% was applied instead.
    num_tokens_to_drop = int(len(scored) * dropout_percent)

    # Highest probability first; sorted() is stable, so ties keep document order.
    to_drop = sorted(scored, key=lambda entry: entry[1], reverse=True)[:num_tokens_to_drop]
    dropped_tokens = [entry[0] for entry in to_drop]

    # Filter by each entry's own token_pos rather than using those positions as
    # indices into the filtered list; the two only coincide when no special
    # token was skipped mid-sequence.
    drop_positions = {entry[2] for entry in to_drop}
    kept_ids = [entry[3] for entry in scored if entry[2] not in drop_positions]
    dropped_token_text = tokenizer.decode(kept_ids)
    return dropped_token_text, dropped_tokens


In [None]:
# Number of sampled SQuAD examples to send to the API (two requests per example).
NUM_EVAL_SAMPLES = 10

total_dropped_cost = 0
total_cost = 0
dropped_tokens = []
# HF ``Dataset.__getitem__`` builds a fresh dict on every call. Assigning keys
# on ``squad_dataset[i]`` therefore silently discards the value, so predictions
# are collected here instead. Index k corresponds to row k of ``squad_dataset``.
predicted_answers = []
predicted_answers_for_dropped = []

for i in range(min(NUM_EVAL_SAMPLES, len(squad_dataset))):
    example = squad_dataset[i]  # fetch the row once instead of on every access

    # *_text hold the list of dropped token strings, not the shortened text.
    dropped_context, dropped_context_text = get_token_dropped_text(example["context"], tokenizer)
    dropped_question, dropped_question_text = get_token_dropped_text(example["question"], tokenizer)

    answer, cb = get_llm_response(rc_prompt, chat, passage=example["context"], question=example["question"])
    dropped_answer, cb_d = get_llm_response(rc_prompt, chat, passage=dropped_context, question=dropped_question)

    # Only prompt (input) tokens are priced: the experiment compares the input
    # cost of the original prompt against the token-dropped prompt.
    dropped_cost = cb_d.prompt_tokens * GPT_35_TURBO_COST_PER_INP_TOKEN
    cost = cb.prompt_tokens * GPT_35_TURBO_COST_PER_INP_TOKEN
    total_dropped_cost += dropped_cost
    total_cost += cost

    print(f"Actual answer is {example['answers']}")
    # The printed costs are running totals over all examples processed so far.
    print(f"Answer for dropped text is {dropped_answer}, cost is {total_dropped_cost}")
    print(f"Answer for original text is {answer}, cost is {total_cost}")

    predicted_answers.append(answer)
    predicted_answers_for_dropped.append(dropped_answer)
    dropped_tokens.extend(dropped_context_text)
    dropped_tokens.extend(dropped_question_text)



Actual answer is {'text': ['occupational stress', 'occupational stress', 'occupational stress'], 'answer_start': [76, 76, 76]}
Answer for dropped text is content="Occupational stress can hurt a teacher's mental and physical health.", cost is 0.00014099999999999998
Answer for original text is content="Occupational stress can hurt a teacher's mental and physical health.", cost is 0.00015
Actual answer is {'text': ['Baden-Württemberg', 'Baden-Württemberg', 'Baden-Württemberg', 'Baden-Württemberg'], 'answer_start': [323, 323, 323, 323]}
Answer for dropped text is content='Lake Constance separates the German state Bavaria from the Austrian state Vorarlberg.', cost is 0.00036399999999999996
Answer for original text is content='Lake Constance separates the German state Bavaria from the German state of Baden-Württemberg.', cost is 0.00038599999999999995
Actual answer is {'text': ['bans', 'bans on foreign popular culture, control of the internet and unauthorised satellite dishes', 'bans on fore