In [52]:
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LogitsWarper,
    LogitsProcessor,
    LogitsProcessorList,
)
from cog import BasePredictor, Input
from typing import Dict
import torch

# Local directory for Hugging Face downloads; passed as `cache_dir` to from_pretrained in the setup cell.
CACHE_DIR: str = "./src/models"

In [53]:
# Setup: load FLAN-T5 small and its tokenizer, caching downloaded files under CACHE_DIR.
MODEL_NAME = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)


In [54]:
# NOTE(review): the trailing space appears to be tokenized as an extra id (3) just before
# the final id 1 -- see the `inputs:` printout in the generation cell. The Predictor demo
# further down uses the prompt without the trailing space.
prompt = "Hello, my name is "


In [55]:
# Look up the id(s) for "Greg". The tokenizer appends a trailing id 1 (presumably EOS)
# after the word; later cells call encode(..., add_special_tokens=False) to leave it out.
tokenizer("Greg", return_tensors="pt")["input_ids"]  # tensor([[11859,     1]])


tensor([[11859,     1]])

In [56]:
# Baseline: encode the prompt and generate with no logit bias, for comparison with the
# biased runs below.
inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
print("inputs: ", inputs)

outputs = model.generate(inputs)
print("outputs: ", outputs)

# Single prompt, so decoding the first (only) sequence is the whole batch.
tokenizer.decode(outputs[0], skip_special_tokens=True)

inputs:  tensor([[8774,    6,   82,  564,   19,    3,    1]])
outputs:  tensor([[  0,   3,  23, 265,   3,   7, 265,  76,  15,  40,   3,   7,   9, 967,
           1]])


'iam samuel sailor'

In [66]:
class BiasLogitsWarper(LogitsWarper):
    """Applies a bias to the logits of specific tokens before softmax.

    This class can be used with the `LogitsProcessorList` in Hugging Face's Transformers
    library to alter the logits produced by a model before softmax is applied during
    text generation.

    The class is not dependent on the `input_ids` as it applies bias to specific token ids
    regardless of the context or sequence of tokens currently being processed.

    Attributes
    ----------
    bias : Dict[int, float]
        A dictionary mapping from token ids to bias values. The bias is added to the
        logits for the corresponding token id, with two special values:

        * ``-100`` bans the token: its logit is set to ``-inf``.
        * ``100`` guarantees the token: the logit of every *other* token is set to
          ``-inf``. If several tokens use ``100``, generation is restricted to that set
          and they compete on their own logits. The forced logit is deliberately left
          finite: a ``+inf`` logit makes softmax compute ``inf - inf = nan``, which
          breaks sampling.

    Methods
    -------
    __call__(input_ids: torch.LongTensor, scores: torch.FloatTensor)
        Applies the bias to the logits. This method is called during the generation process.

    warp_logits(logits: torch.Tensor) -> torch.Tensor
        The method that actually applies the bias to the logits.

    Example
    -------
    bias_dict = {11859: 8}  # We're using 8 here to heavily bias towards "Greg" (token 11859)
    bias_warper = BiasLogitsWarper(bias_dict)
    logits_processor_list = LogitsProcessorList([bias_warper])
    outputs = model.generate(inputs, logits_processor=logits_processor_list)
    """

    def __init__(self, bias: Dict[int, float]):
        """
        Parameters
        ----------
        bias : Dict[int, float]
            A dictionary mapping from token ids to bias values.
        """
        self.bias = bias

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        """The method called during the generation process.

        Parameters
        ----------
        input_ids : torch.LongTensor
            The input ids for the current generation step. Not used in this class.
        scores : torch.FloatTensor
            The logits for the current generation step.

        Returns
        -------
        torch.Tensor
            The modified logits.
        """
        return self.warp_logits(scores)

    def warp_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """Applies the bias to the logits.

        The tensor is modified in place and also returned.

        Parameters
        ----------
        logits : torch.Tensor
            The logits for the current generation step; the last dimension indexes the
            vocabulary. Any number of leading dimensions (batch rows, beams) is supported.

        Returns
        -------
        torch.Tensor
            The modified logits.
        """
        forced_ids = []
        for token_id, bias_value in self.bias.items():
            if bias_value == 100:
                # Applied after the loop, by masking every other token.
                forced_ids.append(token_id)
            elif bias_value == -100:
                # Ban the token outright.
                logits[..., token_id] = -float("inf")
            else:
                # Add bias to logit
                logits[..., token_id] += bias_value

        if forced_ids:
            # Guarantee the forced token(s) by removing everything else, instead of
            # setting them to +inf (softmax over +inf yields NaN).
            keep = torch.zeros(logits.shape[-1], dtype=torch.bool, device=logits.device)
            keep[forced_ids] = True
            logits[..., ~keep] = -float("inf")

        # The '...' (ellipsis) is used here to index into any number of dimensions,
        # For example, if logits is a 3D tensor with shape (batch_size, sequence_length, vocab_size),
        # logits[..., token_id] would be equivalent to logits[:, :, token_id].
        return logits


In [67]:
def map_tokens_to_bias(tokenizer, bias_dict_str: Dict[str, float]) -> Dict[int, float]:
    """
    Maps a dictionary with strings as keys to a dictionary with corresponding token IDs as keys.

    Each string is encoded without special tokens and must correspond to exactly one
    token. A string that splits into several sub-tokens is rejected: keeping only its
    first piece would bias an unrelated token.

    Parameters
    ----------
    tokenizer
        The tokenizer to use for translating the strings into token IDs. Only its
        ``encode(text, add_special_tokens=False)`` method is used.
    bias_dict_str : Dict[str, float]
        A dictionary mapping from strings to bias values.

    Returns
    -------
    Dict[int, float]
        The resulting dictionary mapping from token IDs to bias values.

    Raises
    ------
    ValueError
        If a string encodes to no tokens (e.g. ``""``) or to more than one token.

    Examples
    --------
    >>> bias_dict_str = {"Greg": 8, "Sam": -10}
    >>> bias_dict = map_tokens_to_bias(tokenizer, bias_dict_str)  # e.g. {11859: 8, <id of "Sam">: -10}
    """
    bias_dict = {}
    for token, bias in bias_dict_str.items():
        token_ids = tokenizer.encode(token, add_special_tokens=False)
        if not token_ids:
            raise ValueError(f"The string '{token}' does not correspond to any token in the tokenizer.")
        if len(token_ids) > 1:
            raise ValueError(f"The string '{token}' corresponds to more than one token in the tokenizer.")
        bias_dict[token_ids[0]] = bias
    return bias_dict


In [74]:
# Bias by word: boost "Greg" (+8) and penalise "Sam" (-10). The words are converted to
# token ids first, then wrapped in a BiasLogitsWarper for generate().
bias_dict_str = {"Greg": 8, "Sam": -10}
bias_dict = map_tokens_to_bias(tokenizer, bias_dict_str)

bias_warper = BiasLogitsWarper(bias_dict)
logits_processor_list = LogitsProcessorList([bias_warper])

outputs = model.generate(inputs, logits_processor=logits_processor_list)
decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(decoded[0])

Greg Gregg


In [61]:
# Same experiment, but biasing the token id directly instead of mapping words to ids.
# 11859 is the id printed for "Greg" in the tokenizer cell; +8 heavily favours it.
bias_dict = {11859: 8}
bias_warper = BiasLogitsWarper(bias_dict)
logits_processor_list = LogitsProcessorList([bias_warper])

# Pass the custom processor list (containing only the bias warper) to generate.
outputs = model.generate(
    inputs,
    logits_processor=logits_processor_list,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])

New logit for token 11859: 5.309817790985107
New logit for token 11859: 5.580060005187988
New logit for token 11859: 2.9869871139526367
New logit for token 11859: 2.290940284729004
Greg Gregg


In [59]:
# Sanity check of the additive bias: a logit of -2.690... for token 11859 (presumably the
# unbiased value seen in an earlier debug print) plus the +8 bias should equal the first
# "New logit" value printed by the id-bias experiment cell (5.3098...).
8.0 - 2.6901822090148926


5.309817790985107

___

In [103]:
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    LogitsWarper,
    LogitsProcessorList,
)
from cog import BasePredictor, Input
from typing import Dict
import torch
import json5
import os

# Local directory for Hugging Face downloads; created in Predictor.setup and passed as `cache_dir`.
CACHE_DIR: str = "./src/models"


class BiasLogitsWarper(LogitsWarper):
    """Applies a bias to the logits of specific tokens before softmax.

    This class can be used with the `LogitsProcessorList` in Hugging Face's Transformers
    library to alter the logits produced by a model before softmax is applied during
    text generation.

    The class is not dependent on the `input_ids` as it applies bias to specific token ids
    regardless of the context or sequence of tokens currently being processed.

    Attributes
    ----------
    bias : Dict[int, float]
        A dictionary mapping from token ids to bias values. The bias is added to the
        logits for the corresponding token id, with two special values:

        * ``-100`` bans the token: its logit is set to ``-inf``.
        * ``100`` guarantees the token: the logit of every *other* token is set to
          ``-inf``. If several tokens use ``100``, generation is restricted to that set
          and they compete on their own logits. The forced logit is deliberately left
          finite: a ``+inf`` logit makes softmax compute ``inf - inf = nan``, which
          breaks sampling.

    Methods
    -------
    __call__(input_ids: torch.LongTensor, scores: torch.FloatTensor)
        Applies the bias to the logits. This method is called during the generation process.

    warp_logits(logits: torch.Tensor) -> torch.Tensor
        The method that actually applies the bias to the logits.

    Example
    -------
    bias_dict = {11859: 8}  # We're using 8 here to heavily bias towards "Greg" (token 11859)
    bias_warper = BiasLogitsWarper(bias_dict)
    logits_processor_list = LogitsProcessorList([bias_warper])
    outputs = model.generate(inputs, logits_processor=logits_processor_list)
    """

    def __init__(self, bias: Dict[int, float]):
        """
        Parameters
        ----------
        bias : Dict[int, float]
            A dictionary mapping from token ids to bias values.
        """

        self.bias = bias

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        """The method called during the generation process.

        Parameters
        ----------
        input_ids : torch.LongTensor
            The input ids for the current generation step. Not used in this class.
        scores : torch.FloatTensor
            The logits for the current generation step.

        Returns
        -------
        torch.Tensor
            The modified logits.
        """

        # input_ids not used because biases are applied to scores (logits) over the entire vocab
        # So, we don't really need input_ids, it's just a formality outlined by the LogitsWarper ABC
        return self.warp_logits(scores)

    def warp_logits(self, logits: torch.Tensor) -> torch.Tensor:
        """Applies the bias to the logits.

        The tensor is modified in place and also returned.

        Parameters
        ----------
        logits : torch.Tensor
            The logits for the current generation step; the last dimension indexes the
            vocabulary. Any number of leading dimensions (batch rows, beams) is supported.

        Returns
        -------
        torch.Tensor
            The modified logits.
        """

        forced_ids = []
        for token_id, bias_value in self.bias.items():
            if bias_value == 100:
                # Applied after the loop, by masking every other token.
                forced_ids.append(token_id)
            elif bias_value == -100:
                # Ban the token outright.
                logits[..., token_id] = -float("inf")
            else:
                # Add bias to logit
                logits[..., token_id] += bias_value

        if forced_ids:
            # Guarantee the forced token(s) by removing everything else, instead of
            # setting them to +inf (softmax over +inf yields NaN).
            keep = torch.zeros(logits.shape[-1], dtype=torch.bool, device=logits.device)
            keep[forced_ids] = True
            logits[..., ~keep] = -float("inf")

        # The '...' (ellipsis) is used here to index into any number of dimensions,
        # For example, if logits is a 3D tensor with shape (batch_size, sequence_length, vocab_size),
        # logits[..., token_id] would be equivalent to logits[:, :, token_id].
        return logits


class Predictor(BasePredictor):
    """Cog predictor for google/flan-t5-small with optional per-word logit biases."""

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""

        os.makedirs(CACHE_DIR, exist_ok=True)
        self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small", cache_dir=CACHE_DIR)
        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small", cache_dir=CACHE_DIR)

    def map_tokens_to_bias(self, bias_dict_str: Dict[str, float]) -> Dict[int, float]:
        """
        Maps a dictionary with strings as keys to a dictionary with corresponding token IDs as keys.

        Each string is encoded with ``self.tokenizer`` (without special tokens) and must
        correspond to exactly one token.

        Parameters
        ----------
        bias_dict_str : Dict[str, float]
            A dictionary mapping from strings to bias values.

        Returns
        -------
        Dict[int, float]
            The resulting dictionary mapping from token IDs to bias values.

        Raises
        ------
        ValueError
            If a string encodes to no tokens (e.g. ``""``) or to more than one token.

        Examples
        --------
        >>> bias_dict_str = {"Greg": 8, "Sam": -10}
        >>> bias_dict = self.map_tokens_to_bias(bias_dict_str)  # e.g. {11859: 8, <id of "Sam">: -10}
        """

        bias_dict = {}
        for token, bias in bias_dict_str.items():
            token_ids = self.tokenizer.encode(token, add_special_tokens=False)
            if not token_ids:
                raise ValueError(f"The string '{token}' does not correspond to any token in the tokenizer.")
            if len(token_ids) > 1:
                raise ValueError(f"The string '{token}' corresponds to more than one token in the tokenizer.")
            bias_dict[token_ids[0]] = bias
        return bias_dict

    @staticmethod
    def _parse_bias_dict(bias_dict_str: str) -> Dict[str, float]:
        """Parse a JSON5 object such as ``"{'Greg': 8, 'Sam': -10}"`` and validate its shape.

        Raises
        ------
        ValueError
            If the text is not valid JSON5, is not an object, or has a non-numeric value.
        """
        try:
            parsed = json5.loads(bias_dict_str)
        except ValueError as exc:
            raise ValueError(f"bias_dict_str is not valid JSON5: {exc}") from exc
        if not isinstance(parsed, dict):
            raise ValueError(f"bias_dict_str must be a JSON5 object, got {type(parsed).__name__}.")
        for word, bias in parsed.items():
            # bool is a subclass of int, but `true`/`false` as a bias is almost certainly a mistake.
            if isinstance(bias, bool) or not isinstance(bias, (int, float)):
                raise ValueError(f"Bias for '{word}' must be a number, got {bias!r}.")
        return parsed

    def predict(
        self,
        prompt: str = Input(description="Prompt for language model"),
        bias_dict_str: str = Input(description="Dictionary mapping from strings to bias values", default=None),
        max_output_len: int = Input(description="Maximum length of output", default=64, ge=1, le=512),
    ) -> str:
        """Run a single prediction on the model

        Parameters
        ----------
        prompt : str
            Text fed to the model.
        bias_dict_str : str, optional
            JSON5 object mapping single-token words to biases, e.g. ``"{'Greg': 8}"``.
            ``None`` or a blank string means no biasing.
        max_output_len : int
            Maximum number of newly generated tokens.

        Returns
        -------
        str
            The decoded output for the single prompt.

        Raises
        ------
        ValueError
            If ``bias_dict_str`` is malformed or names a word that is not exactly one token.
        """

        inputs = self.tokenizer(prompt, return_tensors="pt").input_ids
        logits_processor_list = None
        # A blank string is treated like "no biases" instead of failing to parse.
        if bias_dict_str is not None and bias_dict_str.strip():
            word_to_bias_dict = self._parse_bias_dict(bias_dict_str)
            tokenid_to_bias_dict = self.map_tokens_to_bias(word_to_bias_dict)
            bias_warper = BiasLogitsWarper(tokenid_to_bias_dict)
            logits_processor_list = LogitsProcessorList([bias_warper])
        outputs = self.model.generate(
            inputs,
            max_new_tokens=max_output_len,
            logits_processor=logits_processor_list,
        )
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]


# Exercise the predictor directly, outside of `cog predict`. When called this way, the
# parameter defaults are whatever `Input(...)` returns (a cog field descriptor), not the
# bare default values -- so every argument is passed explicitly.
# Note: 'Sam': 10 is a positive bias (a boost), unlike the -10 penalty used earlier.
p = Predictor()
p.setup()
p.predict("Hello, my name is", "{'Greg': 8, 'Sam': 10}", max_output_len=64)

{'Greg': 8, 'Sam': 10}
<class 'str'>


'Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam Sam

In [92]:
!pip install json4

[31mERROR: Could not find a version that satisfies the requirement json4 (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for json4[0m[31m
[0m

In [97]:
import json5


def str_to_dict(s):
    """Parse a JSON5 object literal (single- or double-quoted keys) into a dict.

    Parameters
    ----------
    s : str
        Text such as ``"{'Greg': 8, 'Sam': -10}"`` or ``'{"Greg": 8}'``. Surrounding
        whitespace is ignored.

    Returns
    -------
    dict
        The parsed object.

    Raises
    ------
    ValueError
        If ``s`` is not valid JSON5 (raised by ``json5.loads``), or it parses to
        something other than an object.
    """
    # JSON5 accepts single-quoted strings natively, so no quote rewriting is needed.
    # Blindly replacing ' with " would corrupt keys/values that contain an apostrophe
    # (e.g. {"don't": 1}) or a double quote inside single quotes.
    parsed = json5.loads(s.strip())
    if not isinstance(parsed, dict):
        raise ValueError(f"Expected a JSON5 object, got {type(parsed).__name__}")
    return parsed


# Sanity-check str_to_dict on the input shapes a user is likely to paste: double- or
# single-quoted keys, and stray whitespace inside or around the object. Every case must
# parse to the same dict. (No globals such as `bias_dict` are reassigned here, so the
# biasing experiments above keep their state.)
test_cases = [
    '{"Greg": 8, "Sam": -10}',
    "{'Greg': 8, 'Sam': -10}",
    """ {"Greg":  8, "Sam" : -10 }""",
    '{"Greg": 8, "Sam": -10} ',
]
expected_output = {"Greg": 8, "Sam": -10}

for i, case in enumerate(test_cases, start=1):
    result = str_to_dict(case)
    assert result == expected_output, f"Test case {i} failed: expected {expected_output}, got {result}"

print("All test cases passed.")

All test cases passed.


In [None]:
# TODO:
# - [done] bias_dict_str is accepted as a str and parsed with json5.loads in Predictor.predict
# - good error handling (validate the parsed object and its bias values)
# - pretty docs
