In [1]:
import importlib.util
import os

# huggingface_hub evaluates HF_HUB_ENABLE_HF_TRANSFER when it is first imported, and
# transformers imports huggingface_hub, so the flag has to be in the environment
# BEFORE the transformers import below (setting it afterwards has no effect).
# Only opt in when the optional hf_transfer package is installed: with the flag on
# and the package missing, huggingface_hub refuses to download.
if importlib.util.find_spec("hf_transfer") is not None:
    os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = "1"

import sys
import time
import json
import traceback
from typing import List, Dict, Tuple, Generator, Set

import torch
from transformers import (
    PreTrainedTokenizer,
    PreTrainedModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from IPython.display import clear_output, display, HTML
from ipywidgets import Output

# Set the device to 'cuda' if available, else 'cpu'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Specify the model name (replace with your preferred model)
model_name = "unsloth/Llama-3.2-1B-Instruct"

# Load the model (bfloat16 weights) and its tokenizer, then move the model to `device`
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [2]:
# [phrase, penalty] pairs, automatically derived from the relative frequency of words in a
# GPT-generated dataset; a larger penalty means the phrase is more over-represented.
# See slopcalc.ipynb for the code to generate a more complete list.
word_penalties_list = [['a kaleidoscope of', 2000], ['testament', 2000.0150485530914], ['technopolis', 1682.682059035117], ['understandingly', 762.9294022671543], ['paperbound', 659.2264970913199], ['hesitantly', 496.5646879026894], ['piqued', 482.3001178804444], ['delved', 473.4940223966827], ['curveballs', 462.50687039417824], ['bustling', 454.70303449492854], ['marveled', 428.19439049963717], ['inclusivity', 399.28185144068397], ['birdwatcher', 382.93952575702605], ['elara', 382.02399833524635], ['camaraderie', 325.065910926091], ['newfound', 289.3537643476301], ['marveling', 281.4117889244332], ["hiroshi's", 277.20734354116485], ['greentech', 268.92005660614404], ['thoughtfully', 266.9326102346037], ['intently', 251.51633784078055], ['birdwatching', 250.1588231011304], ['amidst', 249.22122673588677], ['cherishing', 247.91009553317267], ['attentively', 246.79285812826976], ['interjected', 235.9610251843368], ['serendipitous', 233.0266906486917], ["marianne's", 232.87022859034334], ["maya's", 230.37564440034032], ['excitedly', 228.84649139211055], ['steepled', 228.70847810531137], ['engrossed', 228.4472196666415], ['fostering', 222.59645412759878], ['brainstormed', 218.63487421031408], ['furrowed', 217.1288191216257], ['nodded', 215.746532494257], ['contemplatively', 213.9394730581084], ['jotted', 212.19052857841066], ["mia's", 209.54359039341304], ['yesteryears', 205.40375361782048], ['conspiratorially', 204.18903883197223], ['poring', 203.12158920489887], ['stumbled', 201.95286430826962], ['strategized', 198.31538808865406], ['hesitated', 194.35575102206053], ['intrigued', 191.32480777377384], ["sarah's", 188.1056806342427], ['lykos', 186.8984432180823], ['adaptability', 185.2743410645729], ['yoing', 184.27349339389698], ['geocaches', 182.61995131301913], ['furrowing', 181.28300434012144], ['quandaries', 178.45513005800902], ['chimed', 177.6317240627814], ['headfirst', 177.55430176035384], ['gruffly', 173.5169670342752], ['skeptically', 173.2196196510284],
['nestled', 170.24886793134038], ['fruitville', 168.93895251122095], ['gastronomical', 168.77834340202367], ['sighed', 167.83599102428747], ['warmly', 166.64524795750117], ['approvingly', 165.14554435242388], ['questioningly', 164.2217764827755], ["timmy's", 162.85237720972512], ['undeterred', 159.81034467083455], ['starlit', 158.81973280586473], ['unearthing', 157.53953848282245], ['grappled', 155.11380760257956], ["yumi's", 153.362396079487], ["seabrook's", 152.65396517679832], ['geocachers', 152.48899331241418], ['animatedly', 150.64516344395025], ['bakersville', 149.48324667712868], ['minji', 148.7787149817242], ['fateful', 147.881376001738], ['sparkled', 145.48284973440963], ['resonated', 144.91492949803347], ['harmoniously', 144.8378436549682], ['fidgeted', 143.88462648395776], ['mwanga', 141.271194443305], ['gleamed', 140.84454272803274], ['embracing', 140.8134127640521], ['pleasantries', 138.9683910665212], ['iostream', 137.02499195670802], ['navigated', 136.8749045617025], ['interconnectedness', 136.6775722710472], ['tanginess', 136.0248012468762], ['mouthwatering', 135.40207079890078], ["amelia's", 135.12735430462644], ['delving', 134.62133115310917], ['mischievously', 134.53400914082113], ['tirelessly', 134.50459651470058], ['transcended', 132.75875026522667], ['sympathetically', 132.28731274201124], ['pondered', 132.24181930810568], ['lingered', 131.6820547398057], ['empathizing', 130.38734974729505], ['niche', 128.82354262722254], ['regaled', 128.21211629309659], ['greenthumb', 127.87603715586023], ['savored', 127.44044593637169], ["amira's", 127.26977675143385], ['meticulously', 125.67264078678225], ['firsthand', 123.40718461639742], ['empathetically', 122.76583436768932], ['unshed', 122.234447281337], ["jenkin's", 122.13793091468249], ['empathy', 120.78510640297107], ['enigmatically', 120.10278606401909], ["marla's", 119.96311948139385], ['bayville', 119.86205591147561], ['adversities', 119.16540991510242], ['eagerly', 118.7736944560049],
['labyrinthine', 117.30247757246565], ['quizzically', 116.99368259297609], ['transcending', 116.98548707285242], ['resilience', 115.03290831887757], ["lily's", 114.79275367950875], ['commiserated', 114.67810631179985], ['savoring', 114.09168105940152], ["amara's", 113.35572623503091], ['somberly', 111.33701848917218], ['cinephile', 110.08462885614495], ['solace', 109.58942409221955], ['twinkled', 108.8491511689895], ['aquascaping', 108.490673606918], ['rippled', 107.44694800406151], ['reveled', 107.0468756808211], ['greenhaven', 106.94331659047654], ['birdwatchers', 106.78153731714639], ['adwoa', 105.93106020380063], ['appreciatively', 105.79259888560438], ['awestruck', 105.76861040590829], ['ecotech', 105.3920442928442], ['navigating', 105.00722107822565], ['lightheartedness', 103.97167324805294], ['disapprovingly', 103.7632507571004], ['exclaimed', 103.50051962551075], ["samir's", 103.32935174824694], ['fishkeeping', 102.85891331779236], ['sparked', 101.88166406410981], ['welled', 101.1238612235024], ['jotting', 101.00457387540926], ['resourcefulness', 100.7218012566337], ['flickered', 99.7346200354375], ['reminisced', 99.40737436442414], ["abernathy's", 97.97170234043274], ['unbeknownst', 96.90280381444087], ['pattered', 96.45923446705675], ['reassuringly', 95.99564132547866], ['miscommunications', 95.92537976660182], ['wafted', 95.46404472730974], ['absentmindedly', 94.91609082523561], ['weightiness', 94.72530273625273], ['allyship', 94.20061224332532], ['perseverance', 93.51317590165482], ['timmy', 93.37738761768523], ['mindfully', 92.98616009115628], ['disheartened', 92.79027445159294], ['leaned', 92.64926718281589], ['birder', 92.44448100845747], ['captivated', 92.43721619730242], ["ravi's", 92.31017082428738], ["abuela's", 91.35242003061113], ['apprehensions', 91.29401570563567]]


In [3]:

# Map phrase -> penalty for at most the first 500 [phrase, penalty] entries.
word_penalties = {entry[0]: entry[1] for entry in word_penalties_list[:500]}

In [4]:
# Optionally load your own words/phrases + penalty values from a JSON file holding a
# list of [phrase, penalty] pairs. Entries are merged into `word_penalties`, replacing
# the built-in penalty for any phrase that appears in both.
if os.path.exists('over_represented_words.json'):
    # Read explicitly as UTF-8 rather than the platform's locale encoding, so phrases
    # containing non-ASCII characters decode the same way on every OS.
    with open('over_represented_words.json', 'r', encoding='utf-8') as f:
        over_represented_words = json.load(f)
    # Use the first 500 entries (the file is already fully parsed, so it can be closed)
    for el in over_represented_words[:500]:
        word_penalties[el[0]] = el[1]

In [5]:

def precompute_starting_tokens(
    tokenizer: PreTrainedTokenizer, word_penalties: Dict[str, float]
) -> Dict[Tuple[int, ...], Set[int]]:
    """
    Collect, for every spelling variant of every target phrase, the token IDs that
    could begin that variant.

    Each phrase is expanded into lower-case, capitalized and upper-case spellings, each
    with and without a leading space. A variant's candidate set holds its first token ID
    plus the first token ID obtained by encoding every shorter, non-empty prefix of that
    first token's decoded text.

    Args:
        tokenizer (PreTrainedTokenizer): The tokenizer used by the model.
        word_penalties (Dict[str, float]): Target phrases mapped to their penalty; only the keys are used.

    Returns:
        Dict[Tuple[int, ...], Set[int]]: Maps each variant's full token ID sequence to its set of
        candidate starting token IDs. Variants that encode to no tokens are omitted.
    """
    lookup = {}

    for phrase in word_penalties:
        lowered = phrase.lower()
        capitalized = phrase.capitalize()
        uppered = phrase.upper()
        spellings = [
            lowered,
            capitalized,
            uppered,
            f" {lowered}",
            f" {capitalized}",
            f" {uppered}",
        ]

        for spelling in spellings:
            ids = tokenizer.encode(spelling, add_special_tokens=False)
            if not ids:
                continue

            head_text = tokenizer.decode(ids[0], skip_special_tokens=True)
            candidates = {ids[0]}

            # Walk the proper prefixes of the first token's text, longest first.
            for cut in range(len(head_text) - 1, 0, -1):
                prefix_ids = tokenizer.encode(head_text[:cut], add_special_tokens=False)
                if prefix_ids:
                    candidates.add(prefix_ids[0])

            lookup[tuple(ids)] = candidates

    return lookup

# Build the variant -> candidate-starting-token table once, up front, for the sampler.
starting_tokens_lookup = precompute_starting_tokens(
    tokenizer=tokenizer,
    word_penalties=word_penalties,
)

class AdvancedCustomWordSampler:
    """
    A sampler that generates text while downregulating specified words or phrases.

    Tokens are sampled one at a time. After each new token the tail of the sequence is
    compared against the tokenized variants of the penalised phrases. On a match, the
    logits cached for the position where the phrase began are lowered for every token
    that could start it, the phrase is removed from the sequence, and sampling resumes
    from that position using the adjusted logits.

    Display goes through the notebook-level ``inference_output`` and ``debug_output``
    widgets; both must exist as module globals before ``generate_stream`` is called.
    """
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        word_penalties: Dict[str, float],
        starting_tokens_lookup: Dict[Tuple[int, ...], Set[int]],
        max_backtrack: int = 5,
        adjustment_strength: float = 1.0,
        device: torch.device = torch.device('cuda'),
        output_every_n_tokens: int = 5,
        slow_debug: bool = False,
        debug_delay: float = 2.0,
    ):
        """
        Initializes the AdvancedCustomWordSampler with the necessary parameters.

        Args:
            model (PreTrainedModel): The language model to use for generation; it is moved to `device`.
            tokenizer (PreTrainedTokenizer): The tokenizer associated with the model.
            word_penalties (Dict[str, float]): Dictionary of target words with their respective penalty.
                Each penalty is inverted (1 / penalty) to obtain the adjustment factor.
            starting_tokens_lookup (Dict[Tuple[int, ...], Set[int]]): Mapping from token sequences to starting token IDs.
            max_backtrack (int): Stored for interface compatibility; the generation loop does not consult it.
            adjustment_strength (float): Exponent applied to the adjustment factor before it is turned into a logit offset.
            device (torch.device): Device to run the model on.
            output_every_n_tokens (int): Number of sampled tokens between display updates / yielded chunks.
            slow_debug (bool): Enables slow debug mode (pauses after debug messages) when set to True.
            debug_delay (float): Time in seconds to pause during slow debug steps.
        """
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.word_penalties = word_penalties
        self.starting_tokens_lookup = starting_tokens_lookup
        self.max_backtrack = max_backtrack
        self.adjustment_strength = adjustment_strength
        self.device = device
        self.output_every_n_tokens = output_every_n_tokens
        self.slow_debug = slow_debug
        self.debug_delay = debug_delay

        # Prepare token sequences for downregulation
        self.token_sequences = self._prepare_token_sequences()
        # default=0 keeps construction working when there is nothing to penalise.
        self.max_sequence_length = max((len(seq) for seq in self.token_sequences), default=0)

        # Cache of (temperature-scaled) next-token logits, keyed by sequence position
        self.logit_cache = {}

        # Record of downregulated sequences at specific positions
        self.downregulated_positions = {}  # Key: position, Value: set of sequences

    @staticmethod
    def _word_variants(word: str) -> List[str]:
        """Returns the lower/capitalized/upper spellings of `word`, each without and with a leading space."""
        return [
            word.lower(),
            word.capitalize(),
            word.upper(),
            f" {word.lower()}",
            f" {word.capitalize()}",
            f" {word.upper()}",
        ]

    def _prepare_token_sequences(self) -> Dict[Tuple[int, ...], float]:
        """
        Prepares the token sequences from word_penalties for efficient lookup.

        Returns:
            Dict[Tuple[int, ...], float]: Mapping from token ID sequences to their adjustment factors.
        """
        token_sequences = {}
        for word, penalty in self.word_penalties.items():
            for variant in self._word_variants(word):
                token_ids = tuple(self.tokenizer.encode(variant, add_special_tokens=False))
                if token_ids:
                    token_sequences[token_ids] = 1 / penalty  # Inverse of the over-representation penalty
        return token_sequences

    def _adjust_logits(self, logits: torch.FloatTensor, adjustment: float) -> torch.FloatTensor:
        """
        Adjusts the logits by applying the downregulation factor.

        Args:
            logits (torch.FloatTensor): The original logits.
            adjustment (float): The adjustment factor to apply.

        Returns:
            torch.FloatTensor: The logits plus log(adjustment ** adjustment_strength).
        """
        log_adjustment = torch.log(torch.tensor(adjustment ** self.adjustment_strength, device=self.device))
        return logits + log_adjustment  # Negative offset whenever the factor is below 1

    def _show_inference(self, text: str):
        """Replaces the contents of the inference_output widget with `text` (inserted into the HTML as-is)."""
        with inference_output:
            inference_output.clear_output(wait=True)
            display(HTML(f"<div style='white-space: pre-wrap;'>{text}</div>"))

    @torch.no_grad()
    def generate_stream(
        self,
        prompt: str,
        max_length: int,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.95,
    ) -> Generator[str, None, None]:
        """
        Generates text in a streaming fashion with custom downregulation and backtracking.

        Args:
            prompt (str): The initial text prompt.
            max_length (int): Maximum total number of tokens, prompt tokens included.
            temperature (float): Sampling temperature.
            top_k (int): Top-k filtering.
            top_p (float): Top-p (nucleus) filtering.

        Yields:
            str: The full decoded text so far, every `output_every_n_tokens` sampled
            tokens, and once more when generation finishes.
        """
        # The caches are keyed by absolute position, so entries left over from an
        # earlier call would be reused for a different prompt. Start every call clean.
        self.logit_cache = {}
        self.downregulated_positions = {}

        # Encode the prompt
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        generated_sequence = input_ids[0].tolist()
        prompt_length = len(generated_sequence)
        current_position = prompt_length  # Tracks the current position in the sequence

        # Initialize display (if using in Jupyter)
        self._show_inference(self.tokenizer.decode(generated_sequence))

        past_key_values = None
        output_tokens_counter = 0

        while len(generated_sequence) < max_length:
            regenerating = current_position in self.logit_cache
            if regenerating:
                # We backtracked to this position: reuse its (now downregulated) logits.
                next_token_logits = self.logit_cache[current_position]
                past_key_values = None
            else:
                current_input_ids = torch.tensor([generated_sequence], device=self.device)
                if past_key_values is None:
                    outputs = self.model(current_input_ids, use_cache=True)
                else:
                    # With a valid KV cache only the newest token has to be fed.
                    outputs = self.model(current_input_ids[:, -1:], past_key_values=past_key_values, use_cache=True)

                next_token_logits = outputs.logits[:, -1, :] / temperature
                past_key_values = outputs.past_key_values
                self.logit_cache[current_position] = next_token_logits.clone()

            # Apply top-k and top-p filtering
            filtered_logits = self._filter_logits(next_token_logits, top_k, top_p)

            # Sample the next token
            probs = torch.softmax(filtered_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

            if regenerating:
                alt_token = self.tokenizer.decode(next_token, skip_special_tokens=True)
                self._display_debug(f"Alternate token: {alt_token}")
                if self.slow_debug:
                    time.sleep(self.debug_delay)

            # Append the new token to the sequence
            generated_sequence.append(next_token)
            current_position += 1
            output_tokens_counter += 1

            # Yield the current text chunk
            current_text = self.tokenizer.decode(generated_sequence)
            if output_tokens_counter >= self.output_every_n_tokens:
                output_tokens_counter = 0
                self._show_inference(current_text)
                yield current_text  # Yield the generated text chunk

            # Check for end-of-sequence token
            if next_token == self.tokenizer.eos_token_id:
                break

            # After adding the token, check for disallowed sequences. Matches that begin
            # inside the prompt are ignored: no logits are cached for prompt positions
            # and prompt tokens must never be removed.
            matched_sequence, start_pos = self._detect_disallowed_sequence(
                generated_sequence, min_start=prompt_length
            )

            if matched_sequence:
                # Downregulate the relevant tokens at the start_pos
                adjustment = self.token_sequences[matched_sequence]
                word = self.tokenizer.decode(torch.tensor(matched_sequence))

                # Display debug information
                self._display_debug(f"Replacing '{word}'")

                if self.slow_debug:
                    time.sleep(self.debug_delay)
                    with debug_output:
                        debug_output.clear_output(wait=True)

                # Identify starting tokens to downregulate
                starting_tokens = self.starting_tokens_lookup.get(matched_sequence, set())

                # Lower, in place, the cached logits of every token that could start the phrase
                cached_logits = self.logit_cache[start_pos]
                for token_id in starting_tokens:
                    cached_logits[:, token_id] = self._adjust_logits(cached_logits[:, token_id], adjustment)

                # Record that this sequence has been downregulated at this position
                self.downregulated_positions.setdefault(start_pos, set()).add(matched_sequence)

                # Backtrack: drop the tokens that make up the disallowed sequence
                del generated_sequence[start_pos:]
                current_position = start_pos

                # The KV cache now covers tokens that were removed. The next iteration
                # samples from the cached logits, and the one after that re-encodes
                # the whole sequence, so simply invalidate it here.
                past_key_values = None

                # Clear the logit_cache ahead of start_pos since we've backtracked
                for key in [key for key in self.logit_cache if key > start_pos]:
                    del self.logit_cache[key]

        # Final display of the generated text
        final_text = self.tokenizer.decode(generated_sequence)
        self._show_inference(final_text)
        yield final_text

    def _filter_logits(self, logits: torch.FloatTensor, top_k: int, top_p: float) -> torch.FloatTensor:
        """
        Applies top-k and top-p (nucleus) filtering to the logits.

        Args:
            logits (torch.FloatTensor): The original logits, indexed as (row, vocabulary entry).
            top_k (int): The number of top tokens to keep; values <= 0 disable top-k.
            top_p (float): The cumulative probability threshold; values >= 1.0 disable top-p.

        Returns:
            torch.FloatTensor: The filtered logits, with removed entries set to -inf.
        """
        # Apply top-k: drop everything strictly below the k-th largest logit
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            top_k_logits, _ = torch.topk(logits, top_k)
            min_top_k = top_k_logits[:, -1].unsqueeze(-1)
            logits = logits.masked_fill(logits < min_top_k, float('-inf'))

        # Apply top-p
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the mask right to keep the first token above the threshold
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False
            # Map the mask from sorted order back to vocabulary order
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits = logits.masked_fill(indices_to_remove, float('-inf'))

        return logits

    def _detect_disallowed_sequence(
        self, generated_sequence: List[int], min_start: int = 0
    ) -> Tuple[Tuple[int, ...], int]:
        """
        Detects if the recent tokens in the generated_sequence match any disallowed sequence.

        Args:
            generated_sequence (List[int]): The list of generated token IDs.
            min_start (int): Lowest index at which a match may begin. Matches that would
                start earlier are not considered. Defaults to 0 (no restriction).

        Returns:
            Tuple[Tuple[int, ...], int]: The longest matching disallowed sequence that ends at
                                         the last token, and its start position.
                                         Returns (None, -1) if no match is found.
        """
        # Longest usable tail: bounded by the longest known sequence and by min_start
        usable_tokens = len(generated_sequence) - max(min_start, 0)
        longest = min(self.max_sequence_length, usable_tokens)

        # Start checking from the longest possible sequence to the shortest
        for seq_length in range(longest, 0, -1):
            candidate_sequence = tuple(generated_sequence[-seq_length:])
            if candidate_sequence in self.token_sequences:
                return candidate_sequence, len(generated_sequence) - seq_length
        return None, -1

    def _display_debug(self, message: str):
        """
        Displays debug information in the debug_output widget, replacing its previous contents.

        Args:
            message (str): The debug message to display.
        """
        with debug_output:
            debug_output.clear_output(wait=True)
            display(HTML(f"<pre>{message}</pre>"))


In [6]:

# One output widget per panel: the prompt, the streamed generation, and debug messages.
prompt_output = Output()
inference_output = Output()
debug_output = Output()

# Render every heading followed by its widget, in page order.
for panel_item in (
    HTML("<h2>Prompt</h2>"),
    prompt_output,
    HTML("<h2>Inference Output</h2>"),
    inference_output,
    HTML("<h2>Debug Information</h2>"),
    debug_output,
):
    display(panel_item)

# Pause after each debug message so replacements can be followed by eye.
SLOW_DEBUG = True

# Build the sampler from the model, tokenizer and penalty tables prepared in earlier cells.
sampler_settings = dict(
    model=model,
    tokenizer=tokenizer,
    word_penalties=word_penalties,
    starting_tokens_lookup=starting_tokens_lookup,
    max_backtrack=5,
    adjustment_strength=1.0,
    device=device,
    output_every_n_tokens=5,
    slow_debug=SLOW_DEBUG,
    debug_delay=1.5,  # seconds to pause per debug step
)
sampler = AdvancedCustomWordSampler(**sampler_settings)

# Prompt that the model continues.
prompt = "Write a story about Elara, the weaver of tapestries in future Technopolis. In the bustling city, a group of "

# Show the prompt in its own panel.
with prompt_output:
    prompt_output.clear_output(wait=True)
    display(HTML(f"<div style='white-space: pre-wrap;'>{prompt}</div>"))

# Drain the generator; each chunk is rendered into inference_output by the sampler itself.
try:
    text_stream = sampler.generate_stream(prompt, max_length=300, temperature=1.0, top_k=50, top_p=0.95)
    for text_chunk in text_stream:
        pass
    print("\nGeneration complete.")
except Exception as e:
    # Report the failure and its traceback inside the debug panel.
    with debug_output:
        debug_output.clear_output(wait=True)
        debug_output.append_stdout(f"\n\nAn error occurred: {str(e)}\n")
        traceback.print_exc(file=sys.stdout)


Output()

Output()

Output()

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)



Generation complete.
