In [21]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, LogitsWarper, LogitsProcessor, LogitsProcessorList
from cog import BasePredictor, Input
from typing import Dict
import torch

CACHE_DIR = "./src/models"

In [5]:
model_name = ""

model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    cache_dir=CACHE_DIR,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=CACHE_DIR,
)


Downloading (…)lve/main/config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/308M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

In [13]:
prompt = "Hello, my name is "


In [23]:


class MyLogitsWarper(LogitsWarper):
    def __init__(self, bais: Dict):
        # if not isinstance(bais, float) or not (bais > 0):
        #     raise ValueError(f"`bais` has to be a strictly positive float, but is {bais}")
        self.bais = bais

    def __call__(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.FloatTensor:
        logits = logits + self.bais
        return logits


In [45]:
inputs.squeeze()[0]

tensor(8774)

In [50]:
ids_to_bais = {i_id.item(): 0 for i_id in inputs.squeeze()}
ids_to_bais

{8774: 0, 6: 0, 82: 0, 564: 0, 19: 0, 3: 0, 1: 0}

In [29]:
inputs = tokenizer(prompt, return_tensors="pt").input_ids # input ids
display(inputs)

biases = [i for i, i_id in enumerate(inputs.squeeze())]

logits_processor_list = LogitsProcessorList([
            MyLogitsWarper()])

# display(logits_processor_list)
outputs = model.generate(inputs, logits_processor=logits_processor_list) # logits
display(outputs)
# tokenizer.batch_decode(outputs)

tensor([[8774,    6,   82,  564,   19,    3,    1]])



TypeError: object of type 'MyLogitsWarper' has no len()

In [63]:
class BiasLogitsWarper(LogitsWarper):
    def __init__(self, bias):
        super().__init__()
        self.bias = {int(k): v for k, v in bias.items()}

    def __call__(self, input_ids, scores):
        print(f"Token IDs: {input_ids}")
        print(f"Logits before bias: {scores}")
        for token_id, bias_value in self.bias.items():
            scores[:, token_id] += bias_value
        print(f"Logits after bias: {scores}")
        return scores

class Predictor(BasePredictor):
    def setup(self):
        self.model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small", cache_dir=CACHE_DIR)
        self.tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small", cache_dir=CACHE_DIR)

    def predict(
        self,
        prompt: str = Input(description="Prompt for language model"),
        bias: dict = Input(description="Token bias map"),
    ) -> str:
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs.input_ids
        attention_mask = inputs.attention_mask
        bias_warper = BiasLogitsWarper(bias)
        logits_processor_list = LogitsProcessorList([bias_warper])

        # Specify return_dict_in_generate=True to obtain logits
        outputs = self.model.generate(input_ids=input_ids, 
                                      attention_mask=attention_mask, 
                                      logits_processor=logits_processor_list,
                                      max_length=150, 
                                      return_dict_in_generate=True)

        return self.tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)[0]


# Create a predictor instance
predictor = Predictor()

# Set up the predictor (loads the model into memory)
predictor.setup()

# Set a prompt
prompt = "Translate this text to French:"

# Optionally set a bias
bias = {"32127": -100}  # This should be an integer, but I'm using string here as per your code

# Call the predict method
output = predictor.predict(prompt, bias)

print(output)


Token IDs: tensor([[0]])
Logits before bias: tensor([[-35.6063,   2.8283,  -3.1114,  ..., -35.5091, -35.5547, -35.4446]])
Logits after bias: tensor([[ -35.6063,    2.8283,   -3.1114,  ...,  -35.5091,  -35.5547,
         -135.4446]])
Token IDs: tensor([[  0, 622]])
Logits before bias: tensor([[-24.0870,   3.7214,  -1.0533,  ..., -24.0596, -24.1545, -23.8459]])
Logits after bias: tensor([[ -24.0870,    3.7214,   -1.0533,  ...,  -24.0596,  -24.1545,
         -123.8459]])
Token IDs: tensor([[  0, 622,   3]])
Logits before bias: tensor([[-20.8465,   1.6117,   7.6392,  ..., -20.7901, -20.8941, -20.5416]])
Logits after bias: tensor([[ -20.8465,    1.6117,    7.6392,  ...,  -20.7901,  -20.8941,
         -120.5416]])
Token IDs: tensor([[    0,   622,     3, 23881]])
Logits before bias: tensor([[-32.3311,   2.2911,  -2.6924,  ..., -32.2394, -32.3277, -32.1008]])
Logits after bias: tensor([[ -32.3311,    2.2911,   -2.6924,  ...,  -32.2394,  -32.3277,
         -132.1008]])
Token IDs: tensor([[    