# ! Important
This notebook contains an unfinished version of the sampler that has known bugs and VRAM leaks. If you want to actually use the sampler for anything, use the code in antislop_generate.py instead (see also example_generate.ipynb).

I'm leaving it here for now because it's a cool way to visualise the backtracking.

In [1]:
import os
import sys
import time
import json
import traceback
from typing import List, Dict, Tuple, Generator, Set

import torch
from transformers import (
    PreTrainedTokenizer,
    PreTrainedModel,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from IPython.display import clear_output, display, HTML
from ipywidgets import Output
import numpy as np

# Request the accelerated `hf_transfer` download path for Hugging Face Hub files.
# NOTE(review): this is set *after* `transformers` has already been imported above;
# the hub library may read this variable at import time, in which case the
# assignment here has no effect -- TODO confirm.
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = "1"

# Set the device to 'cuda' if available, else 'cpu'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Specify the model name (replace with your preferred model)
model_name = "unsloth/Llama-3.2-1B-Instruct"

# Load the model (weights requested in bfloat16), move it to the selected
# device, then load the tokenizer that belongs to the same checkpoint.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [2]:
# Default slop phrases, given as [phrase, probability adjustment factor] pairs.
# These are automatically derived from the relative frequency of words in a
# GPT-generated dataset. See slopcalc.ipynb for the code to generate a more complete list.
slop_phrase_prob_adjustments = [['kaleidoscope', 0.5], ['symphony', 0.5], ['testament to', 0.5], ['moth to a flame', 0.5], ['canvas', 0.5], ['eyes glinted', 0.5], ['camaraderie', 0.5], ['humble abode', 0.5], ['cold and calculating', 0.5], ['eyes never leaving', 0.5], ['tapestry', 0.5], ['barely above a whisper', 0.5], ['body and soul', 0.5], ['orchestra', 0.5], ['depths', 0.5], ['a dance of', 0.5], ['chuckles darkly', 0.5], ['maybe, just maybe', 0.5], ['maybe that was enough', 0.5], ['with a mixture of', 0.5], ['air was filled with anticipation', 0.5], ['cacophony', 0.5], ['bore silent witness to', 0.5], ['eyes sparkling with mischief', 0.5], ['was only just beginning', 0.5], ['practiced ease', 0.5], ['ready for the challenges', 0.5], ['only just getting started', 0.5], ['once upon a time', 0.5], ['nestled deep within', 0.5], ['ethereal beauty', 0.5], ['life would never be the same again.', 0.5], ["it's important to remember", 0.5], ['for what seemed like an eternity', 0.5], ['feel a sense of pride and accomplishment', 0.5], ['little did he know', 0.5], ['ball is in your court', 0.5], ['game is on', 0.5], ['choice is yours', 0.5], ['feels like an electric shock', 0.5], ['threatens to consume', 0.5], ['meticulous', 0.5], ['meticulously', 0.5], ['navigating', 0.5], ['complexities', 0.5], ['realm', 0.5], ['understanding', 0.5], ['dive into', 0.5], ['shall', 0.5], ['tailored', 0.5]];
# Convert the pair list into a {phrase: factor} mapping.
slop_phrase_prob_adjustments = dict(slop_phrase_prob_adjustments)


In [3]:
# When a generated adjustments file sits next to the notebook, it replaces the
# built-in defaults. Only the first 500 entries of the JSON document are kept,
# and they are turned into a mapping with dict(), so each entry must be a
# two-item pair.
if os.path.exists('slop_phrase_prob_adjustments.json'):
    with open('slop_phrase_prob_adjustments.json', 'r') as adjustments_file:
        loaded_pairs = json.load(adjustments_file)
    slop_phrase_prob_adjustments = dict(loaded_pairs[:500])

In [4]:

def precompute_starting_tokens(
    tokenizer: PreTrainedTokenizer, slop_phrase_prob_adjustments: Dict[str, float]
) -> Dict[Tuple[int, ...], Set[int]]:
    """
    Build, for every casing/spacing variant of every phrase, the set of token
    IDs that could begin that variant.

    For each variant the set contains the first token of its encoding plus the
    first token obtained by encoding every proper, non-empty prefix of that
    first token's decoded text.

    Args:
        tokenizer (PreTrainedTokenizer): The tokenizer used by the model.
        slop_phrase_prob_adjustments (Dict[str, float]): Phrases to downregulate,
            mapped to their adjustment factor. Only the keys are read here.

    Returns:
        Dict[Tuple[int, ...], Set[int]]: Maps the full token sequence of each
        variant to its set of starting token IDs. Variants that encode to no
        tokens are omitted.
    """
    lookup = {}

    for phrase in slop_phrase_prob_adjustments:
        lowered = phrase.lower()
        capitalised = phrase.capitalize()
        uppered = phrase.upper()
        surface_forms = (
            lowered,
            capitalised,
            uppered,
            f" {lowered}",
            f" {capitalised}",
            f" {uppered}",
        )

        for form in surface_forms:
            ids = tokenizer.encode(form, add_special_tokens=False)
            if not ids:
                continue

            head_id = ids[0]
            head_text = tokenizer.decode(head_id, skip_special_tokens=True)
            candidates = {head_id}

            # Walk the proper prefixes of the first token's text, longest first,
            # and record the token each prefix starts with.
            for cut in range(len(head_text) - 1, 0, -1):
                prefix_ids = tokenizer.encode(head_text[:cut], add_special_tokens=False)
                if prefix_ids:
                    candidates.add(prefix_ids[0])

            lookup[tuple(ids)] = candidates

    return lookup

# Precompute the starting-token sets once, using the module-level tokenizer and
# phrase table; the result is handed to the sampler when it is constructed.
starting_tokens_lookup = precompute_starting_tokens(tokenizer, slop_phrase_prob_adjustments)

class AdvancedCustomWordSampler:
    """
    A sampler that generates text while downregulating specified words or phrases.
    It uses backtracking and custom adjustments to avoid overrepresented words.

    Display side effects: the sampler writes to two notebook-level ipywidgets
    ``Output`` objects, ``inference_output`` and ``debug_output``, which it
    looks up as globals. Both must exist before ``generate_stream`` is called.
    """
    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        slop_phrase_prob_adjustments: Dict[str, float],
        starting_tokens_lookup: Dict[Tuple[int, ...], Set[int]],
        adjustment_strength: float = 1.0,
        device: torch.device = torch.device('cuda'),
        output_every_n_tokens: int = 5,
        slow_debug: bool = False,
        debug_delay: float = 2.0,
    ):
        """
        Initializes the AdvancedCustomWordSampler with the necessary parameters.

        Args:
            model (PreTrainedModel): The language model to use for generation.
            tokenizer (PreTrainedTokenizer): The tokenizer associated with the model.
            slop_phrase_prob_adjustments (Dict[str, float]): Dictionary of target words with their respective probability adjustment factor.
            starting_tokens_lookup (Dict[Tuple[int, ...], Set[int]]): Mapping from token sequences to starting token IDs.
            adjustment_strength (float): Exponent applied to each adjustment factor before it is used.
            device (torch.device): Device to run the model on.
            output_every_n_tokens (int): Frequency of updating the inference output display.
            slow_debug (bool): Enables slow debug mode when set to True.
            debug_delay (float): Time in seconds to pause during slow debug steps.
        """
        self.model = model.to(device)
        self.tokenizer = tokenizer
        self.slop_phrase_prob_adjustments = slop_phrase_prob_adjustments
        self.starting_tokens_lookup = starting_tokens_lookup
        self.adjustment_strength = adjustment_strength
        self.device = device
        self.output_every_n_tokens = output_every_n_tokens
        self.slow_debug = slow_debug
        self.debug_delay = debug_delay

        # Prepare token sequences for downregulation
        self.token_sequences = self._prepare_token_sequences()
        # default=0 keeps an empty phrase table usable (nothing is ever matched).
        self.max_sequence_length = max(
            (len(seq) for seq in self.token_sequences), default=0
        )

        # Temperature-scaled logits keyed by the sequence position they predict.
        self.logit_cache = {}

        # Record of downregulated sequences at specific positions
        self.downregulated_positions = {}  # Key: position, Value: set of sequences

    def _prepare_token_sequences(self) -> Dict[Tuple[int, ...], float]:
        """
        Prepares the token sequences from slop_phrase_prob_adjustments for efficient lookup.

        Every phrase is expanded into lower/capitalized/upper variants, each with
        and without a leading space; variants that encode to no tokens are skipped.

        Returns:
            Dict[Tuple[int, ...], float]: Mapping from token ID sequences to their adjustment factors.
        """
        token_sequences = {}
        for word, prob_adjustment_factor in self.slop_phrase_prob_adjustments.items():
            variants = [
                word.lower(),
                word.capitalize(),
                word.upper(),
                f" {word.lower()}",
                f" {word.capitalize()}",
                f" {word.upper()}",
            ]
            for variant in variants:
                token_ids = tuple(self.tokenizer.encode(variant, add_special_tokens=False))
                if token_ids:
                    token_sequences[token_ids] = prob_adjustment_factor
        return token_sequences

    def _adjust_logits(self, logits: torch.FloatTensor, adjustment: float) -> torch.FloatTensor:
        """
        Adjusts the logits by applying the downregulation factor.

        Adding ``log(adjustment ** adjustment_strength)`` to a logit multiplies the
        corresponding unnormalized probability by ``adjustment ** adjustment_strength``,
        regardless of the sign of the logit.

        Args:
            logits (torch.FloatTensor): The original logits.
            adjustment (float): The adjustment factor to apply.

        Returns:
            torch.FloatTensor: The adjusted logits (a new tensor with the dtype of ``logits``).
        """
        factor = torch.tensor(
            adjustment ** self.adjustment_strength,
            dtype=torch.float32,
            device=logits.device,
        )
        return logits + torch.log(factor).to(logits.dtype)

    def _downregulate_starting_tokens(self, start_pos: int, matched_sequence: Tuple[int, ...]) -> None:
        """
        Lowers, in the cached logits for ``start_pos``, every token that could start
        ``matched_sequence``, and records the downregulation.

        Args:
            start_pos (int): Position whose cached logits are modified in place.
            matched_sequence (Tuple[int, ...]): The detected disallowed token sequence.
        """
        adjustment = self.token_sequences[matched_sequence]
        starting_tokens = self.starting_tokens_lookup.get(matched_sequence, set())

        if starting_tokens:
            cached_logits = self.logit_cache[start_pos]
            token_index = torch.tensor(
                sorted(starting_tokens), dtype=torch.long, device=cached_logits.device
            )
            # Work in log space: multiplying a raw logit by the factor would
            # *raise* the probability of any token whose logit is negative.
            cached_logits[:, token_index] = self._adjust_logits(
                cached_logits[:, token_index], adjustment
            )

        # Record that this sequence has been downregulated at this position
        self.downregulated_positions.setdefault(start_pos, set()).add(matched_sequence)

    def _show_inference(self, text: str) -> None:
        """
        Replaces the content of the global inference_output widget with ``text``.

        Args:
            text (str): The text to show; it is HTML-escaped before rendering.
        """
        import html  # local import so the class does not depend on the import cell

        with inference_output:
            inference_output.clear_output(wait=True)
            display(HTML(f"<div style='white-space: pre-wrap;'>{html.escape(text)}</div>"))

    @torch.no_grad()
    def generate_stream(
        self,
        prompt: str,
        max_length: int,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.95,
    ) -> Generator[str, None, None]:
        """
        Generates text in a streaming fashion with custom downregulation and backtracking.

        After each sampled token the tail of the sequence is checked against the
        disallowed token sequences. On a match, the tokens that could start the
        phrase are downregulated in the cached logits of the phrase's first
        position; unless that first token is still the most likely one, the
        phrase is removed and sampling resumes from the adjusted cached logits.

        Args:
            prompt (str): The initial text prompt.
            max_length (int): The maximum total length in tokens (prompt included).
            temperature (float): Sampling temperature.
            top_k (int): Top-k filtering.
            top_p (float): Top-p (nucleus) filtering.

        Yields:
            Generator[str, None, None]: The full decoded text so far, every
            ``output_every_n_tokens`` sampled tokens and once more at the end.
        """
        # Encode the prompt
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        generated_sequence = input_ids[0].tolist()
        prompt_length = len(generated_sequence)
        current_position = prompt_length  # Tracks the current position in the sequence

        # Cached logits are keyed by position, so entries left over from an
        # earlier call would be wrongly reused for this prompt.
        self.logit_cache = {}
        self.downregulated_positions = {}

        # Initialize display (if using in Jupyter)
        self._show_inference(self.tokenizer.decode(generated_sequence))

        past_key_values = None
        output_tokens_counter = 0

        try:
            while len(generated_sequence) < max_length:
                # NOTE: the lookup only works when the text is tokenised the way the
                # tokeniser would normally do it, and the model may arrive at a slop
                # phrase by a different token route. Re-tokenising the sequence here
                # would address that, but it would invalidate the position-indexed
                # logit cache, so it is not done.

                regenerating = current_position in self.logit_cache
                if regenerating:
                    # We backtracked and want to use the cached (adjusted) logits.
                    next_token_logits = self.logit_cache[current_position]
                    # No forward pass happens on this step, so the KV cache will lag
                    # behind the sequence; force a full recompute on the next step.
                    past_key_values = None
                else:
                    if past_key_values is None:
                        model_input = torch.tensor([generated_sequence], device=self.device)
                        outputs = self.model(model_input, use_cache=True)
                    else:
                        model_input = torch.tensor([generated_sequence[-1:]], device=self.device)
                        outputs = self.model(model_input, past_key_values=past_key_values, use_cache=True)

                    next_token_logits = outputs.logits[:, -1, :] / temperature
                    past_key_values = outputs.past_key_values
                    self.logit_cache[current_position] = next_token_logits.clone()

                # Apply top-k and top-p filtering
                filtered_logits = self._filter_logits(next_token_logits, top_k, top_p)

                # Sample the next token
                probs = torch.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()

                if regenerating:
                    alt_token = self.tokenizer.decode(next_token, skip_special_tokens=True)
                    self._display_debug(f"Alternate token: {alt_token}")
                    if self.slow_debug:
                        time.sleep(self.debug_delay)

                # Append the new token to the sequence
                generated_sequence.append(next_token)
                current_position += 1
                output_tokens_counter += 1

                # Yield the current text chunk
                current_text = self.tokenizer.decode(generated_sequence)
                if output_tokens_counter >= self.output_every_n_tokens:
                    output_tokens_counter = 0
                    self._show_inference(current_text)
                    yield current_text  # Yield the generated text chunk

                # Check for end-of-sequence token
                if next_token == self.tokenizer.eos_token_id:
                    break

                # After adding the token, check for disallowed sequences. Matches that
                # begin inside the prompt are ignored: there are no cached logits for
                # prompt positions and the prompt must not be rewritten.
                matched_sequence, start_pos = self._detect_disallowed_sequence(
                    generated_sequence, min_start_pos=prompt_length
                )
                if matched_sequence is None:
                    continue

                adjustment = self.token_sequences[matched_sequence]
                word = self.tokenizer.decode(list(matched_sequence))

                # Display debug information
                self._display_debug(f"Replacing '{word}'")
                self._show_inference(current_text)

                if self.slow_debug:
                    time.sleep(self.debug_delay)
                    with debug_output:
                        debug_output.clear_output(wait=True)

                # Downregulate the relevant tokens at start_pos
                self._downregulate_starting_tokens(start_pos, matched_sequence)

                # Check if the starting token would still be selected after downregulation
                slop_phrase_starting_token = generated_sequence[start_pos]
                if torch.argmax(self.logit_cache[start_pos]).item() == slop_phrase_starting_token:
                    if self.slow_debug:
                        effective = adjustment ** self.adjustment_strength
                        ratio = round(1 / effective, 2) if effective else float('inf')
                        debug_info = f"[INFO] Slop phrase '{word}' prob was downregulated {ratio}x but still selected."
                        self._display_debug(debug_info)
                        time.sleep(self.debug_delay)
                    continue

                # Backtrack: remove the tokens that are part of the disallowed sequence
                del generated_sequence[start_pos:]
                current_position = start_pos

                # The KV cache still contains the removed tokens. The next step samples
                # from the cached logits without a forward pass, and the step after
                # that recomputes the KV cache from the whole sequence.
                past_key_values = None

                # Clear the logit_cache ahead of start_pos since we've backtracked
                stale_positions = [key for key in self.logit_cache if key > start_pos]
                for key in stale_positions:
                    del self.logit_cache[key]

            # Final display of the generated text
            final_text = self.tokenizer.decode(generated_sequence)
            self._show_inference(final_text)
            yield final_text
        finally:
            # Drop the per-position full-vocabulary logits so they do not stay
            # resident on the device after generation ends or is abandoned.
            self.logit_cache.clear()

    def _filter_logits(self, logits: torch.FloatTensor, top_k: int, top_p: float) -> torch.FloatTensor:
        """
        Applies top-k and top-p (nucleus) filtering to the logits.

        Args:
            logits (torch.FloatTensor): The original logits, indexed as (row, vocabulary).
            top_k (int): The number of top tokens to keep; values <= 0 disable top-k.
            top_p (float): The cumulative probability threshold; values >= 1.0 disable top-p.

        Returns:
            torch.FloatTensor: The filtered logits, with removed entries set to -inf.
        """
        # Apply top-k
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))
            top_k_logits, _ = torch.topk(logits, top_k)
            min_top_k = top_k_logits[:, -1].unsqueeze(-1)
            logits = logits.masked_fill(logits < min_top_k, float('-inf'))

        # Apply top-p
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the mask right to keep the first token above the threshold
            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
            sorted_indices_to_remove[:, 0] = False
            # Map the mask from sorted order back to vocabulary order.
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits = logits.masked_fill(indices_to_remove, float('-inf'))

        return logits

    def _detect_disallowed_sequence(
        self, generated_sequence: List[int], min_start_pos: int = 0
    ) -> Tuple[Tuple[int, ...], int]:
        """
        Detects if the most recent tokens in generated_sequence match any disallowed sequence.

        Only sequences that end at the last token are considered; longer
        candidates are tried before shorter ones.

        Args:
            generated_sequence (List[int]): The list of generated token IDs.
            min_start_pos (int): Matches starting before this index are ignored.

        Returns:
            Tuple[Tuple[int, ...], int]: The matched disallowed sequence and its start position.
                                         Returns (None, -1) if no match is found.
        """
        total = len(generated_sequence)
        longest = min(self.max_sequence_length, total - min_start_pos)
        # Start checking from the longest possible sequence to the shortest
        for seq_length in range(longest, 0, -1):
            candidate_sequence = tuple(generated_sequence[-seq_length:])
            if candidate_sequence in self.token_sequences:
                return candidate_sequence, total - seq_length
        return None, -1

    def _display_debug(self, message: str):
        """
        Displays debug information in the global debug_output widget, replacing
        whatever it showed before.

        Args:
            message (str): The debug message to display; it is HTML-escaped before rendering.
        """
        import html  # local import so the class does not depend on the import cell

        with debug_output:
            debug_output.clear_output(wait=True)
            display(HTML(f"<pre>{html.escape(message)}</pre>"))


In [5]:

# Setup separate output widgets for prompt, inference, and debug.
# AdvancedCustomWordSampler looks up `inference_output` and `debug_output` as
# globals, so these names must be bound before generate_stream() runs.
prompt_output = Output()
inference_output = Output()
debug_output = Output()

# Display the output widgets
display(HTML("<h2>Prompt</h2>"))
display(prompt_output)
display(HTML("<h2>Inference Output</h2>"))
display(inference_output)
display(HTML("<h2>Debug Information</h2>"))
display(debug_output)

# Enable slow debug mode
SLOW_DEBUG = True

# Initialize the sampler
sampler = AdvancedCustomWordSampler(
    model=model,
    tokenizer=tokenizer,
    slop_phrase_prob_adjustments=slop_phrase_prob_adjustments,
    starting_tokens_lookup=starting_tokens_lookup,
    adjustment_strength=1.0,
    device=device,
    output_every_n_tokens=5,
    slow_debug=SLOW_DEBUG,          # Enable slow debug
    debug_delay=1.5                # Set delay to 1.5 seconds per debug step
)

# Define the prompt.
# NOTE: the first assignment is immediately overwritten by the second, so only
# the "Once upon a time ..." prompt is actually used.
prompt = "Write a story about Elara, the weaver of tapestries in future Technopolis. In the bustling city, a group of "
prompt = "Once upon a time in a land far away, there was a"

# Display the prompt
with prompt_output:
    prompt_output.clear_output(wait=True)
    display(HTML(f"<div style='white-space: pre-wrap;'>{prompt}</div>"))

# Start generating. The generator is drained only for its display side effects.
try:
    for text_chunk in sampler.generate_stream(prompt, max_length=800, temperature=1.0, top_k=50, top_p=0.95):
        pass  # The text_chunk is already being displayed via the inference_output widget
    print("\nGeneration complete.")
except Exception as e:
    # Show the error message and the traceback inside the debug widget.
    with debug_output:
        debug_output.clear_output(wait=True)
        debug_output.append_stdout(f"\n\nAn error occurred: {str(e)}\n")
        traceback.print_exc(file=sys.stdout)


Output()

Output()

Output()

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)



Generation complete.
