<a href="https://colab.research.google.com/github/rsr2425/word-count-investigation/blob/main/notebooks/5_Custom_Decoder_%2B_Chaining_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
# Passed to run_experiment as `model_name`; there it is only used to build the
# W&B run name (the LLM itself is created from the local `pipe` object).
CURRENT_MODEL_ID = "gpt-3.5-turbo"
# Judge model id; currently referenced only from commented-out code in process_dataset.
LLM_JUDGE_MODEL_ID = "gpt-4o"
# Weights & Biases project name that run_experiment passes to log_dataset_to_wandb.
PROJECT_NAME = "word-count-investigation"

In [None]:
!pip install datasets langchain_openai rouge-score evaluate wandb deepeval bitsandbytes



In [None]:
import os
from google.colab import userdata

# Copy the `OPENAI_API_KEY` value returned by Colab's `userdata` into the process environment.
os.environ['OPENAI_API_KEY'] = userdata.get('OPENAI_API_KEY')

In [None]:
import evaluate

# Load the "rouge" metric once; compute_rouge reuses this module-level object.
rouge_score = evaluate.load("rouge")

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

## Dataset

In [None]:
from datasets import load_dataset

# dataset = load_dataset("ccdv/cnn_dailymail", '3.0.0', split="test[:1000]")
# dataset = dataset.rename_column('article', 'text')
# dataset = dataset.rename_column('highlights', 'summary')
# dataset = dataset.remove_columns(['id'])

# Load the `ca_test` split of BillSum (features: text, summary, title).
dataset = load_dataset("billsum", split="ca_test")
# `Dataset` has no `drop_column`; the column-removal API is `remove_columns`,
# which returns a new dataset:
# dataset = dataset.remove_columns(['title'])

README.md:   0%|          | 0.00/7.27k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/91.8M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/15.8M [00:00<?, ?B/s]

ca_test-00000-of-00001.parquet:   0%|          | 0.00/6.12M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/18949 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3269 [00:00<?, ? examples/s]

Generating ca_test split:   0%|          | 0/1237 [00:00<?, ? examples/s]

AttributeError: 'Dataset' object has no attribute 'drop_column'

In [None]:
dataset

Dataset({
    features: ['text', 'summary', 'title'],
    num_rows: 1237
})

## Metrics

In [None]:
import json

def compute_rouge(record, **kargs):
    """Score one record's AI summary against its reference summary with ROUGE.

    Args:
        record: Mapping with an ``'ai_summary'`` entry (the prediction) and a
            ``'summary'`` entry (the reference).
        **kargs: Accepted and ignored, so this function can be called with the
            same keyword arguments as the other metric functions.

    Returns:
        Whatever ``rouge_score.compute`` returns for the single
        prediction/reference pair.
    """
    prediction = record['ai_summary']
    reference = record['summary']
    return rouge_score.compute(predictions=[prediction], references=[reference])

In [None]:
from json import JSONDecodeError

def generate_questions(text, llm, n):
    """Ask ``llm`` for ``n`` yes/no questions that can be answered from ``text``.

    Args:
        text: Source text the questions should be answerable from.
        llm: Object whose ``invoke(messages)`` returns a message exposing the
            reply as a ``content`` string.
        n: Number of questions to request.

    Returns:
        The value stored under the ``"questions"`` key of the model's JSON
        reply. If the reply is not valid JSON, or is JSON without a
        ``"questions"`` entry, a list of ``n`` empty strings is returned so
        callers always receive a list of the requested length.
    """
    messages = [
      # f-string so the requested count `n` is substituted into the prompt
      # text instead of being sent to the model as a literal "{n}".
      ("system", f"""
        You are a helpful question generating chatbot.  Generate {n} factual questions
        from the text provided by the user. Make sure these questions can be answered
        using the provided text, and that the answers should be yes or no. Make sure there
        are both questions that can be answered with yes and questions that can be answered
        with no. Think through step by step before answering and make sure there are a mix
        of answers to the questions you provide.

        Return the questions as a json containing a list of strings.
        """
      ),
      ("human", f"{text}"),
    ]
    ai_msg = llm.invoke(messages)
    try:
        questions = json.loads(ai_msg.content)['questions']
    except (JSONDecodeError, KeyError, TypeError):
        # Unparseable reply, missing key, or a non-object JSON value: fall back
        # to placeholder questions (the answering prompt maps empty questions
        # to "idk").
        questions = [''] * n
    return questions

def generate_anwsers(questions, source_text, llm):
    """Ask ``llm`` to answer ``questions`` using only ``source_text``.

    Args:
        questions: Collection of question strings; it is interpolated into the
            prompt as-is and its ``len`` sizes the fallback answer list.
        source_text: Text the model must base its answers on.
        llm: Object whose ``invoke(messages)`` returns a message exposing the
            reply as a ``content`` string.

    Returns:
        The value stored under the ``"answers"`` key of the model's JSON reply.
        If the reply is not valid JSON, is not a JSON object, or has no
        ``"answers"`` entry, a list of ``"idk"`` with one entry per question
        is returned instead.
    """
    messages = [
      ("system", """
        You are a helpful question answering chatbot.  The user will give you a list of questions and the text off which you
        should answer them. Answer the questions using the provided text. Answer only with "Yes", "No", or "idk". If the
        question cannot be answered using the provided text, answer with "idk". If you are unsure, answer with "idk".
        If the question string is empty, answer with "idk".

        Return the answers as a json containing a list of strings.
        """
      ),
      ("human", f"""
        Please answer the following questions:

          {questions}

        using this text:

          {source_text}
      """),
    ]
    ai_msg = llm.invoke(messages)
    try:
        answers = json.loads(ai_msg.content)['answers']
    except (JSONDecodeError, KeyError, TypeError):
        # KeyError covers a well-formed JSON object that lacks "answers";
        # treat it like any other unusable reply rather than aborting the run.
        answers = ['idk'] * len(questions)
    return answers

def compute_factual_consistency(record, llm, n):
    """Compare how well the human and AI summaries preserve facts from the text.

    Questions are generated from ``record['text']`` and answered three times:
    from the text itself (treated as ground truth), from ``record['summary']``
    and from ``record['ai_summary']``. Each summary's score is the number of
    its answers that equal the ground-truth answer at the same position,
    divided by the number of questions; a summary whose answers are all
    ``'idk'`` scores 0.

    Args:
        record: Mapping with ``'text'``, ``'summary'`` and ``'ai_summary'`` entries.
        llm: Model object forwarded to the question and answer helpers.
        n: Number of questions to request.

    Returns:
        Dict with the three answer lists (``'gt_answers'``,
        ``'human_summary_answers'``, ``'ai_summary_answers'``) and the two
        scores (``'human_factual_consistency'``, ``'ai_factual_consistency'``).
    """
    # TODO figure out why n isn't always respected
    questions = generate_questions(record['text'], llm, n)

    def _agreement(candidate_answers, reference_answers):
        # A summary that only ever answers "idk" scores 0 outright.
        if all(answer == 'idk' for answer in candidate_answers):
            return 0
        matches = sum(
            1
            for mine, truth in zip(candidate_answers, reference_answers)
            if mine == truth
        )
        return matches / float(len(questions))

    gt_answers = generate_anwsers(questions, record['text'], llm)
    human_summary_answers = generate_anwsers(questions, record['summary'], llm)
    ai_summary_answers = generate_anwsers(questions, record['ai_summary'], llm)

    hfc = _agreement(human_summary_answers, gt_answers)
    afc = _agreement(ai_summary_answers, gt_answers)

    return {
        'gt_answers': gt_answers,
        'human_summary_answers': human_summary_answers,
        'ai_summary_answers': ai_summary_answers,
        'human_factual_consistency': hfc,
        'ai_factual_consistency': afc,
    }

In [None]:
import enum

class Metric(enum.Enum):
    """Evaluation metrics that can be requested for a run.

    Each member's value is its human-readable label; ``str(member)`` returns
    that label.
    """

    ROUGE = "ROUGE"
    FACTUAL_CONSISTENCY = "Factual Consistency"

    def __str__(self):
        # Show the label itself rather than the default "Metric.NAME" form.
        label = self.value
        return label

# Dispatch table from each Metric member to the function that computes it.
# process_dataset looks metrics up here and applies the function through
# `Dataset.map` with `llm` and `n` keyword arguments.
metric_fn_mapping = {
    Metric.ROUGE: compute_rouge,
    Metric.FACTUAL_CONSISTENCY: compute_factual_consistency,
}

## Helper Functions

In [None]:
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_openai import ChatOpenAI
from typing import Any, Dict

class WordCountControlRunnable(Runnable):
    """Runnable that summarizes text and keeps asking for a shorter version.

    The model is re-invoked, with the conversation so far plus a
    "Shorten this." request, until the summary has at most
    ``word_count_target + tolerance`` words or ``revision_attempts`` calls
    have been made.
    """

    def __init__(
        self,
        llm: ChatOpenAI,
        word_count_target: int = 25,
        tolerance: int = 10,
        revision_attempts: int = 5,
    ):
        """Store the model and the length-control settings.

        Args:
            llm: Chat model used for every summarization/shortening call.
            word_count_target: Desired summary length in words.
            tolerance: Extra words allowed above the target before another
                shortening round is requested.
            revision_attempts: Maximum number of model calls per input.
        """
        self.llm = llm
        self.word_count_target = word_count_target
        self.tolerance = tolerance
        self.revision_attempts = revision_attempts

    def invoke(
        self,
        input: Any,
        config: RunnableConfig = None,
        **kwargs: Any,
    ) -> Any:
        """Summarize ``input`` and shorten the result until it fits the limit.

        Args:
            input: Either the raw text, or a mapping whose ``"sample_text"``
                entry holds it.
            config: Accepted but not used by this method.
            **kwargs: Accepted but not used by this method.

        Returns:
            Dict with ``"final_summary"`` (the last model reply, or ``None``
            if no call was made) and ``"attempts"`` (number of model calls).
        """
        # Accept both a bare string and a dict-like payload.
        if isinstance(input, str):
            sample_text = input
        else:
            sample_text = input.get("sample_text")

        # TODO is it bad I'm asking this exact thing twice essentially?
        messages = [
            ("system", "You are a helpful summary chatbot. Summarize the content provided by the user."),
            ("human", sample_text),
        ]

        word_limit = self.word_count_target + self.tolerance
        attempt = 0
        ai_summary = None

        while attempt < self.revision_attempts:
            # Stop as soon as an existing summary already fits the limit.
            if ai_summary is not None and self._count_words(ai_summary) <= word_limit:
                break
            attempt += 1
            ai_summary = self.llm.invoke(messages).content
            # Extend the conversation so the next call shortens this reply.
            messages.extend([("ai", ai_summary), ("human", "Shorten this.")])

        return {"final_summary": ai_summary, "attempts": attempt}

    def _count_words(self, text: str) -> int:
        """Return the number of whitespace-separated tokens in ``text``."""
        tokens = text.split()
        return len(tokens)

In [None]:
from langchain_core.output_parsers import StrOutputParser

# Word-count slack and retry budget; the values match the defaults of
# WordCountControlRunnable. Not read by `summarize` below.
TOLERANCE = 10
REVISION_ATTEMPTS = 5

def count_words(text):
    """Return the number of whitespace-separated tokens in ``text``."""
    words = text.split()
    return len(words)

def summarize(record, llm, word_count_target=None):
    """Summarize ``record['text']`` with ``llm`` in two passes and report word counts.

    The first pass asks for a summary; the second asks the model to polish
    that summary into a complete thought. Both passes run through
    ``llm | StrOutputParser()``.

    Args:
        record: Mapping with a ``'text'`` entry (source) and a ``'summary'``
            entry (reference summary).
        llm: Model object that can be piped into ``StrOutputParser``.
        word_count_target: Accepted but not used by this function.

    Returns:
        Dict with ``'text_word_count'``, ``'summary_word_count'``,
        ``'ai_summary'`` and ``'ai_summary_word_count'``.
    """
    draft_messages = [
        ("system", """
        You are a helpful summary chatbot.  Summarize the content provided by the user.
        """),
        ("human", f"{record['text']}"),
    ]
    chain = llm | StrOutputParser()
    unpolished_summary = chain.invoke(draft_messages)

    # Second pass: feed the draft back in so the model can round it off.
    polish_messages = [
        ("system", """
        You are a helpful chatbot.  Take the content provided by the user and polish it so that it is a complete thought.
        """),
        ("human", f"{unpolished_summary}"),
    ]
    ai_summary = chain.invoke(polish_messages)

    stats = {}
    stats['text_word_count'] = count_words(record['text'])
    stats['summary_word_count'] = count_words(record['summary'])
    stats['ai_summary'] = ai_summary
    stats['ai_summary_word_count'] = count_words(ai_summary)
    return stats

In [None]:
def process_dataset(dataset, llm, n, metrics, word_count_target=None):
    """Summarize every record of ``dataset`` and then apply each requested metric.

    Args:
        dataset: Object with a ``map(fn, fn_kwargs=...)`` method; the result of
            each ``map`` call is fed into the next.
        llm: Model object forwarded to ``summarize`` and to every metric function.
        n: Forwarded to every metric function as ``n``.
        metrics: Iterable of keys of ``metric_fn_mapping``.
        word_count_target: Forwarded to ``summarize``.

    Returns:
        The dataset returned by the last ``map`` call.
    """
    print("Processing Dataset!")
    print("Now summarizing data...")
    summarize_kwargs = {
        'llm': llm,
        'word_count_target': word_count_target,
    }
    processed_dataset = dataset.map(summarize, fn_kwargs=summarize_kwargs)

    # A separate judge model was considered for the metric step:
    # llm_judge = ChatOpenAI(model_name=LLM_JUDGE_MODEL_ID, temperature=temperature)
    for metric in metrics:
        print(f"Now calculating {str(metric)}...")
        metric_fn = metric_fn_mapping[metric]
        # Every metric function receives the same keyword arguments.
        processed_dataset = processed_dataset.map(
            metric_fn,
            fn_kwargs={'llm': llm, 'n': n},  # swap in 'llm': llm_judge to use the judge
        )

    print("Done!")
    return processed_dataset

In [None]:
import wandb

def log_dataset_to_wandb(dataset, project_name, run_name, split_name="dataset_split"):
    """Upload every row of ``dataset`` to Weights & Biases as a single table.

    Args:
        dataset: Iterable of row mappings exposing ``column_names``.
        project_name: W&B project to log into.
        run_name: Name given to the W&B run.
        split_name: Key under which the table is logged.

    Returns:
        None. A run is started, the table is logged, and the run is finished.
    """
    run_settings = wandb.Settings(_service_wait=300)
    wandb.init(project=project_name, name=run_name, settings=run_settings)

    data_table = wandb.Table(columns=dataset.column_names)
    columns = dataset.column_names

    # One table row per dataset row, values in column order.
    for row in dataset:
        values = [row[col] for col in columns]
        data_table.add_data(*values)

    wandb.log({split_name: data_table})
    wandb.finish()

In [None]:
def gen_run_name():
    """Placeholder for a run-name generator; does nothing and returns ``None``."""
    return None

In [None]:
import bitsandbytes as bnb
import torch

from transformers import BitsAndBytesConfig

# Quantization settings passed to `from_pretrained` when the model is loaded:
# 4-bit weights with the "nf4" quant type, double quantization enabled, and
# float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint used for both the model and its tokenizer.
model_id = "meta-llama/Llama-3.1-8B-Instruct"
# Load with the quantization config and let `device_map='auto'` place the weights.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the EOS token as the padding token and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
from transformers import LogitsProcessor

class GracefulWordCountLogitsProcessor(LogitsProcessor):
    """Logits processor that steers generation to stop near a target word count.

    While the current word count is within ``buffer_window`` words below the
    target, the scores of sentence-ending punctuation tokens are raised by
    ``completion_boost``. Once the count reaches the target, every score is
    set to ``-inf`` except the EOS token, which is set to ``0.0``.

    NOTE(review): ``input_ids`` is forwarded to ``word_count_fn`` unchanged. If
    it holds the prompt as well as the generated tokens, the prompt's words
    count towards the target — confirm against the word-count function used.
    """

    # Tokens treated as sentence terminators when nudging the model to finish.
    _SENTENCE_END_TOKENS = (".", "!", "?")

    def __init__(self, tokenizer, target_word_count, word_count_fn, buffer_window=5, completion_boost=5.0):
        """Store the settings and resolve the punctuation token ids once.

        Args:
            tokenizer: Tokenizer providing ``eos_token_id``, ``get_vocab()``
                and ``convert_tokens_to_ids``.
            target_word_count: Word count at which only EOS remains allowed.
            word_count_fn: Callable taking ``input_ids`` and returning the
                current word count.
            buffer_window: Number of words below the target in which the
                punctuation boost is applied.
            completion_boost: Amount added to the punctuation tokens' scores.
        """
        self.tokenizer = tokenizer
        self.eos_token_id = tokenizer.eos_token_id
        self.target_word_count = target_word_count
        self.word_count_fn = word_count_fn
        self.buffer_window = buffer_window
        self.completion_boost = completion_boost

        # Look the punctuation ids up here, from the tokenizer that was passed
        # in, so __call__ neither depends on a module-level `tokenizer` nor
        # rebuilds the vocabulary mapping on every generation step.
        vocab = tokenizer.get_vocab()
        self.punctuation_ids = [
            tokenizer.convert_tokens_to_ids(tok)
            for tok in self._SENTENCE_END_TOKENS
            if tok in vocab
        ]

    def __call__(self, input_ids, scores):
        """Adjust ``scores`` in place according to the current word count.

        Args:
            input_ids: Token ids passed through to ``word_count_fn``.
            scores: Score tensor indexed as ``[:, token_id]``; modified in place.

        Returns:
            The same ``scores`` object after adjustment.
        """
        current_word_count = self.word_count_fn(input_ids)

        if self.target_word_count - self.buffer_window <= current_word_count < self.target_word_count:
            # Close to the target: make ending the sentence more likely.
            for token_id in self.punctuation_ids:
                scores[:, token_id] += self.completion_boost

            # # Optionally, give a small boost to EOS token without enforcing it
            # scores[:, self.eos_token_id] += self.completion_boost / 2
        elif current_word_count >= self.target_word_count:
            # At or past the target: mask everything, then leave EOS as the
            # only token with a finite score.
            scores[:, :] = -float("inf")
            scores[:, self.eos_token_id] = 0.0

        return scores

In [None]:
from transformers import Pipeline
from torch import Tensor

# class CustomTextGenerationPipeline(Pipeline):
#     def __init__(self, model, tokenizer, logits_processor=None, **generate_kwargs):
#         super().__init__(model, tokenizer)
#         self.model = model
#         self.tokenizer = tokenizer
#         self.logits_processor = logits_processor
#         self.generate_kwargs = generate_kwargs

#     def _sanitize_parameters(self, **kwargs):
#         preprocess_kwargs = {}
#         if "maybe_arg" in kwargs:
#             preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
#         return preprocess_kwargs, {}, {}

#     def preprocess(self, inputs, maybe_arg=2):
#         # If inputs is a string, tokenize it first
#         if isinstance(inputs, str):
#             inputs = self.tokenizer(inputs, return_tensors="pt")
#         model_input = Tensor(inputs["input_ids"]).to(self.model.device)
#         return {"model_inputs": model_input}

#     def _forward(self, model_inputs):
#         outputs = self.model.generate(
#                 model_inputs['model_inputs'],
#                 logits_processor=self.logits_processor
#             )
#         return outputs

#     def postprocess(self, model_outputs):
#         model_outputs = model_outputs.squeeze(0)
#         # print(model_outputs)
#         # print(model_outputs.shape)
#         best_class = {'translation_text': model_outputs[0]}
#         return best_class

from transformers import Pipeline, LogitsProcessorList, PreTrainedModel, PreTrainedTokenizer
from typing import Optional, Any, Dict, Tuple, List

class CustomLogitsProcessorPipeline(Pipeline):
    """Text-generation pipeline that always applies a custom logits processor list.

    The pipeline tokenizes a single prompt, calls ``model.generate`` with the
    stored ``LogitsProcessorList``, and returns the decoded sequence as a
    ``[{"generated_text": ...}]`` record.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        logits_processor: Optional[LogitsProcessorList] = None,
        **kwargs
    ):
        """
        Initializes the custom pipeline.

        Args:
            model (PreTrainedModel): The pre-trained model to use.
            tokenizer (PreTrainedTokenizer): The tokenizer associated with the model.
            logits_processor (LogitsProcessorList, optional): A custom list of logits processors to apply.
                Defaults to an empty ``LogitsProcessorList``.
            **kwargs: Additional arguments passed to the parent Pipeline class.
        """
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        self.custom_logits_processor = logits_processor or LogitsProcessorList()

    def _sanitize_parameters(
        self,
        *args,
        **kwargs
    ) -> Tuple[Dict, Dict, Dict]:
        """
        Sanitizes the parameters passed to the pipeline methods.

        Args:
            *args: Positional arguments for the pipeline call (ignored).
            **kwargs: Keyword arguments for the pipeline call. ``max_length``,
                ``top_k`` and ``temperature`` are routed to the forward step;
                ``preprocess_kwargs`` / ``postprocess_kwargs`` dicts are routed
                to the respective steps.

        Returns:
            Tuple[Dict, Dict, Dict]: Tuple of dictionaries for preprocess, forward, and postprocess.
        """
        preprocess_params = {}
        postprocess_params = {}
        # Only these generation options are forwarded to `_forward`.
        forward_params = {
            key: kwargs[key]
            for key in ("max_length", "top_k", "temperature")
            if key in kwargs
        }

        preprocess_params.update(kwargs.get("preprocess_kwargs", {}))
        postprocess_params.update(kwargs.get("postprocess_kwargs", {}))

        return preprocess_params, forward_params, postprocess_params

    def preprocess(self, input_text: str, **kwargs) -> Dict[str, Any]:
        """
        Prepares inputs for the forward method.

        Args:
            input_text (str): The raw prompt to tokenize.
            **kwargs: Additional preprocessing arguments (unused).

        Returns:
            Dict[str, Any]: ``{"input_ids": tensor}`` with the tensor moved to the model's device.
        """
        input_ids = self.tokenizer(
            input_text,
            return_tensors="pt",
        ).input_ids
        input_ids = input_ids.to(self.model.device)
        tokenized_inputs = {"input_ids": input_ids}
        return tokenized_inputs

    def _forward(self, inputs: dict, return_tensors: bool = False, **generate_kwargs) -> Any:
        """
        Forward pass through the pipeline with custom logits processing.

        Args:
            inputs (dict): The inputs to the model; ``inputs["input_ids"]`` is passed to ``generate``.
            return_tensors (bool, optional): Accepted but not used.
            **generate_kwargs: Additional arguments passed to the generate method.

        Returns:
            Any: Whatever ``model.generate`` returns with the custom logits processors applied.
        """
        # Always use the processors configured on the pipeline.
        generate_kwargs["logits_processor"] = self.custom_logits_processor
        return self.model.generate(
            inputs["input_ids"],
            **generate_kwargs,
        )

    def postprocess(self, model_outputs: Any, **kwargs) -> List[Dict[str, str]]:
        """
        Decodes the generated sequence into a text-generation record.

        Args:
            model_outputs (Any): The raw outputs from the model; the first
                sequence (``model_outputs[0]``) is decoded.
            **kwargs: Additional arguments for postprocessing (unused).

        Returns:
            List[Dict[str, str]]: A one-element list ``[{"generated_text": text}]``,
            the record layout used by text-generation pipelines, so that
            wrappers which read ``response["generated_text"]`` work.
        """
        # `decode` turns the whole id sequence into one string. `batch_decode`
        # on a single 1-D sequence would decode every token id separately and
        # return one string per token.
        generated_text = self.tokenizer.decode(
            model_outputs[0],
            skip_special_tokens=True,
        )
        return [{"generated_text": generated_text}]

# Example: Define the logits processor list
from transformers import LogitsProcessorList

def custom_word_count_fn(input_ids):
    """Count whitespace-separated words in the decoded first sequence of ``input_ids``.

    Only ``input_ids[0]`` is decoded (with special tokens skipped), using the
    module-level ``tokenizer``.
    """
    first_sequence = input_ids[0]
    text = tokenizer.decode(first_sequence, skip_special_tokens=True)
    words = text.split()
    return len(words)

# Instantiate your custom LogitsProcessor
# NOTE(review): this target is fixed when the cell runs. run_experiment builds
# its LLM from the global `pipe` created below, so changing WORD_COUNT_TARGET in
# the experiment cells does not appear to change this value — confirm intended.
target_word_count = 50
logits_processor = LogitsProcessorList([
    GracefulWordCountLogitsProcessor(
        tokenizer=tokenizer,
        target_word_count=target_word_count,
        word_count_fn=custom_word_count_fn
    )
])

# Initialize the custom pipeline
# `pipe` is read as a global by run_experiment.
pipe = CustomLogitsProcessorPipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    logits_processor=logits_processor,
    # max_length=100  # Example generate kwargs
)


Device set to use cuda:0


In [None]:
from langchain_openai import ChatOpenAI
from langchain.llms import HuggingFacePipeline

def run_experiment(model_name, temperature, dataset, number_of_questions, metrics, word_count_target=None, subset_size=None, log_to_wandb=None):
    """Run summarization and metrics over ``dataset`` and optionally log to W&B.

    The LLM is always built from the module-level ``pipe``; ``model_name`` is
    only used in the W&B run name and ``temperature`` is not used.

    Args:
        model_name: Label appended to the global ``RUN_PREFIX`` for the run name.
        temperature: Accepted but not used by this function.
        dataset: Dataset to process; must support ``select`` when
            ``subset_size`` is given.
        number_of_questions: Forwarded to ``process_dataset`` as ``n``.
        metrics: Metrics forwarded to ``process_dataset``.
        word_count_target: Forwarded to ``process_dataset``.
        subset_size: If not ``None``, only the first ``subset_size`` rows are used.
        log_to_wandb: If truthy, the processed dataset is logged to W&B.

    Returns:
        The dataset returned by ``process_dataset``.
    """
    llm = HuggingFacePipeline(pipeline=pipe)

    if subset_size is None:
        experiment_data = dataset
    else:
        experiment_data = dataset.select(range(subset_size))

    results_subset = process_dataset(
        experiment_data,
        llm,
        number_of_questions,
        metrics,
        word_count_target=word_count_target,
    )

    if log_to_wandb:
        log_dataset_to_wandb(results_subset, PROJECT_NAME, f"{RUN_PREFIX}{model_name}")

    return results_subset

# Experiments

In [None]:
# Parameters across runs
SUBSET_SIZE = 1 # if set to None, entire dataset will be processed
# Passed to run_experiment, whose body does not currently read it.
TEMPERATURE = 0.7
# Forwarded to the metric functions as `n`.
NUMBER_OF_QUESTIONS = 10
# When truthy, run_experiment logs the processed dataset to W&B.
LOG_TO_WANDB = False

# Metrics computed after summarization; keys of `metric_fn_mapping`.
metrics = [
    Metric.ROUGE,
    # Metric.FACTUAL_CONSISTENCY,
]

## Run: Baseline

In [None]:
# # Run Parameters
# WORD_COUNT_TARGET = None
# RUN_PREFIX=f"baseline_"

# results = run_experiment(
#     CURRENT_MODEL_ID,
#     TEMPERATURE,
#     dataset,
#     NUMBER_OF_QUESTIONS,
#     metrics,
#     word_count_target=WORD_COUNT_TARGET,
#     subset_size=SUBSET_SIZE,
#     log_to_wandb=LOG_TO_WANDB,
# )
# df = results.to_pandas()
# df.select_dtypes(include='number').mean()

## Run: Generate with Target Word Count(25)

In [None]:
# Run Parameters
WORD_COUNT_TARGET = 25
# Read as a global by run_experiment when it builds the W&B run name.
RUN_PREFIX=f"word_cnt_target_{WORD_COUNT_TARGET}_"

results = run_experiment(
    CURRENT_MODEL_ID,
    TEMPERATURE,
    dataset,
    NUMBER_OF_QUESTIONS,
    metrics,
    word_count_target=WORD_COUNT_TARGET,
    subset_size=SUBSET_SIZE,
    log_to_wandb=LOG_TO_WANDB,
)
# Convert to pandas and display the mean of every numeric column.
df = results.to_pandas()
df.select_dtypes(include='number').mean()

Processing Dataset!
Now summarizing data...


Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


TypeError: string indices must be integers

## Run: Generate with Target Word Count(50)

In [None]:
# Run Parameters
WORD_COUNT_TARGET = 50
# Read as a global by run_experiment when it builds the W&B run name.
RUN_PREFIX=f"word_cnt_target_{WORD_COUNT_TARGET}_"

results = run_experiment(
    CURRENT_MODEL_ID,
    TEMPERATURE,
    dataset,
    NUMBER_OF_QUESTIONS,
    metrics,
    word_count_target=WORD_COUNT_TARGET,
    subset_size=SUBSET_SIZE,
    log_to_wandb=LOG_TO_WANDB,
)
# Convert to pandas and display the mean of every numeric column.
df = results.to_pandas()
df.select_dtypes(include='number').mean()

## Run: Generate with Target Word Count(150)

In [None]:
# Run Parameters
WORD_COUNT_TARGET = 150
# Read as a global by run_experiment when it builds the W&B run name.
RUN_PREFIX=f"word_cnt_target_{WORD_COUNT_TARGET}_"

results = run_experiment(
    CURRENT_MODEL_ID,
    TEMPERATURE,
    dataset,
    NUMBER_OF_QUESTIONS,
    metrics,
    word_count_target=WORD_COUNT_TARGET,
    subset_size=SUBSET_SIZE,
    log_to_wandb=LOG_TO_WANDB,
)
# Convert to pandas and display the mean of every numeric column.
df = results.to_pandas()
df.select_dtypes(include='number').mean()