<a href="https://colab.research.google.com/github/rsr2425/word-count-investigation/blob/main/notebooks/3_Custom_Decoding_Step.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets langchain_openai rouge-score evaluate wandb deepeval bitsandbytes



In [None]:
import torch

# Sanity check: expected True on a GPU runtime; the 4-bit model load below presumably needs one.
torch.cuda.is_available()

True

In [None]:
import os

import torch
import torch.nn as nn
import bitsandbytes as bnb
import inspect

from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig

In [None]:
from datasets import load_dataset

# BillSum, California test split.
dataset = load_dataset("billsum", split="ca_test")
# To drop the title column, use `remove_columns` (datasets has no `drop_column`);
# it returns a new Dataset rather than modifying in place:
# dataset = dataset.remove_columns(['title'])

# Alternative corpus: CNN/DailyMail, renamed to BillSum's 'text' / 'summary' columns.
# dataset = load_dataset("ccdv/cnn_dailymail", '3.0.0', split={
#     'train': 'train[:1000]',
#     'validation': 'validation[:1000]',
#     'test': 'test[:1000]'
# })
# dataset = dataset.rename_column('article', 'text')
# dataset = dataset.rename_column('highlights', 'summary')
# dataset = dataset.remove_columns(['id'])

# NOTE: requires WORD_COUNT_TARGET to be defined first; it is not defined in the cells above.
# dataset = dataset.map(lambda x: {'target_word_count': WORD_COUNT_TARGET})

In [None]:
# bitsandbytes settings: load weights in 4-bit NF4 with nested (double) quantization,
# and run computation in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# Quantized model, placed on the available device(s) automatically.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
)
# Tokenizer from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)

def custom_word_count(input_ids):
    """Return the number of whitespace-separated words in the first sequence.

    Row 0 of ``input_ids`` is decoded with special tokens stripped and then
    split on whitespace. If ``input_ids`` still contains the prompt, the
    prompt's words are included in the count.
    """
    decoded_words = tokenizer.decode(input_ids[0], skip_special_tokens=True).split()
    return len(decoded_words)

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:  61%|######    | 3.03G/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

In [None]:
# Inspect the architecture; the Linear4bit layers confirm the 4-bit quantization took effect.
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps

In [None]:
# Reuse EOS as the pad token (presumably the tokenizer ships without one).
tokenizer.pad_token = tokenizer.eos_token
# Decoder-only models continue from the end of the sequence, so batched prompts
# must be LEFT-padded; right padding would put pad tokens between a shorter
# prompt and its generated continuation.
tokenizer.padding_side = "left"
# NOTE: this does not silence generate()'s "The attention mask and the pad token id
# were not set" warning (it still appeared after this cell). That warning concerns
# the arguments to generate(): pass `attention_mask=...` and `pad_token_id=...` there.

In [None]:
import math

from transformers import LogitsProcessor

class GracefulWordCountLogitsProcessor:
    """Steer generation toward a target word count, ending near a sentence boundary.

    This is a duck-typed logits processor: ``LogitsProcessorList`` and
    ``model.generate`` only call ``processor(input_ids, scores)`` and use the
    returned scores.

    Each sequence in the batch is handled independently:

    * While its word count is in ``[target - buffer_window, target)``, the
      logits of ".", "!" and "?" (those present in the vocab) are raised by
      ``completion_boost``. This nudges the model to close the sentence.
    * Once its word count reaches ``target_word_count``, every logit except
      EOS is set to ``-inf``, so EOS is the only possible next token.

    Args:
        tokenizer: Tokenizer exposing ``eos_token_id``, ``vocab`` and
            ``convert_tokens_to_ids``.
        target_word_count: Word count at which EOS is forced.
        word_count_fn: Callable that receives a batch-of-one id tensor (one
            row sliced from ``input_ids``) and returns that row's word count.
        buffer_window: Number of words before the target at which the
            punctuation boost starts.
        completion_boost: Amount added to the punctuation logits inside the window.
    """

    _PUNCTUATION_TOKENS = (".", "!", "?")

    def __init__(self, tokenizer, target_word_count, word_count_fn, buffer_window=5, completion_boost=5.0):
        super().__init__()
        self.tokenizer = tokenizer
        self.eos_token_id = tokenizer.eos_token_id
        self.target_word_count = target_word_count
        self.word_count_fn = word_count_fn
        self.buffer_window = buffer_window
        self.completion_boost = completion_boost
        # These ids depend only on the tokenizer, so resolve them once here
        # instead of re-reading the vocab on every decoding step.
        vocab = tokenizer.vocab
        self.punctuation_ids = [
            tokenizer.convert_tokens_to_ids(tok) for tok in self._PUNCTUATION_TOKENS if tok in vocab
        ]

    def __call__(self, input_ids, scores):
        """Adjust ``scores`` in place, row by row, and return them.

        Args:
            input_ids: 2-D id tensor, one row per sequence in the batch.
            scores: 2-D next-token logits aligned row-wise with ``input_ids``.
        """
        window_start = self.target_word_count - self.buffer_window
        for row in range(input_ids.shape[0]):
            # Slice rather than index, so word_count_fn still sees a 2-D batch of
            # one (custom_word_count-style functions read row [0]).
            current_word_count = self.word_count_fn(input_ids[row:row + 1])

            if current_word_count >= self.target_word_count:
                # Overshoot guard: EOS becomes the only valid next token.
                scores[row, :] = -float("inf")
                scores[row, self.eos_token_id] = 0.0
            elif current_word_count >= window_start:
                # Approaching the target: encourage finishing the sentence.
                for token_id in self.punctuation_ids:
                    scores[row, token_id] += self.completion_boost
        return scores

In [None]:
# Inspect the loaded split: column names and row count.
dataset

Dataset({
    features: ['text', 'summary', 'title'],
    num_rows: 1237
})

In [None]:
# Demo: greedy generation from a plain (non-chat) prompt, capped at a target word count.
target_word_count = 50  # Desired word count of the generated continuation

input_text = "Explain photosynthesis:"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
input_ids = inputs.input_ids
prompt_length = input_ids.shape[-1]


def generated_word_count(ids):
    """Count words in the tokens generated after the prompt (prompt words are excluded)."""
    return custom_word_count(ids[:, prompt_length:])


word_count_processor = GracefulWordCountLogitsProcessor(tokenizer, target_word_count, generated_word_count)

output = model.generate(
    input_ids,
    # Passing these explicitly avoids generate()'s "attention mask and pad token id were not set" warning.
    attention_mask=inputs.attention_mask,
    pad_token_id=tokenizer.pad_token_id,
    max_new_tokens=100,  # Safety cap on generated tokens (max_length would also count the prompt)
    logits_processor=[word_count_processor],  # Use custom logit processor
    do_sample=False,
)

# Decode and print the output (prompt included), then the count of generated words only.
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)
print(f"Word Count: {generated_word_count(output)}")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Explain photosynthesis: the process by which plants, algae, and some bacteria convert light energy from the sun into chemical energy stored in glucose.
Photosynthesis is a vital process that sustains life on Earth. It's the backbone of the food chain, providing energy and organic compounds for nearly all living organisms
Word Count: 50


In [None]:
from transformers import Pipeline, LogitsProcessorList, PreTrainedModel, PreTrainedTokenizer, TextGenerationPipeline
from typing import Optional, Any, Dict, Tuple, List
from transformers import LogitsProcessorList

def custom_word_count_fn(input_ids):
    """Count the words in the assistant turn of the first sequence in ``input_ids``.

    Row 0 is decoded WITH special tokens, so the chat-template header survives.
    Only the text after the last ``<|start_header_id|>assistant<|end_header_id|>``
    marker is kept, i.e. what the assistant has produced so far. Its
    whitespace-separated words are then counted. If the marker is absent, the
    whole decoded text is counted.

    Quick hack: brittle if that marker text ever appears inside user content,
    so it is not suitable for a production environment.
    """
    full_text = tokenizer.decode(input_ids[0])
    _, _, assistant_text = full_text.rpartition('<|start_header_id|>assistant<|end_header_id|>')
    return len(assistant_text.split())

# Chat-style run: cap the assistant's reply at `target_word_count` words.
target_word_count = 25
logits_processor = LogitsProcessorList(
    [
        GracefulWordCountLogitsProcessor(
            tokenizer=tokenizer,
            word_count_fn=custom_word_count_fn,
            target_word_count=target_word_count,
        )
    ]
)

pipe = TextGenerationPipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    logits_processor=logits_processor,
    # Without an explicit cap, generation presumably falls back to the default
    # length limit and stops long before the word target (observed: it "ends way
    # too early"). The logits processor forces EOS once the target is reached,
    # so this large value only acts as an upper bound.
    max_new_tokens=10000,
)

Device set to use cuda:0


In [None]:
from langchain_core.output_parsers import StrOutputParser

# Summarization request for the first bill in the split, in two message formats.
system_prompt = """
    You are a helpful summary chatbot.  Summarize the content provided by the user. Make sure it's a complete sentence.
    """
# dataset[0]['text'] reads a single row; dataset['text'][0] would materialize the whole column first.
document_text = dataset[0]['text']

# LangChain-style (role, content) tuples.
messages = [
    ("system", system_prompt),
    ("human", document_text),
]

# Hugging Face chat format, consumed by the text-generation pipeline below.
chat = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": document_text},
]

# TODO: LangChain chain variant, not wired up: `hf` is undefined here, and
# StrOutputParser yields a plain str (no `.content`):
# chain = hf | StrOutputParser()
# chain.invoke(messages)

In [None]:
final_output = pipe(chat)
# The pipeline returns the conversation with the model's reply appended as the last message.
assistant_reply = final_output[0]['generated_text'][-1]
print(f"Word Count: {len(assistant_reply['content'].split())}")
assistant_reply

Word Count: 25


{'role': 'assistant',
 'content': 'The California State Legislature has enacted a law to amend Section 215.1 of the Revenue and Taxation Code to provide a tax exemption for buildings'}