# Welcome!
This is a reference implementation of Plug-and-Blend (https://github.com/xxbidiao/plug-and-blend, which is itself based on https://arxiv.org/abs/2104.04039), built on the LogitsProcessor framework newly introduced in Hugging Face Transformers. Feel free to check them out if anything in this notebook is unclear.

# Set things up
Here we download the model needed to set up the modifier network.

In [1]:
# Downloading the GeDi modifier model.
!wget https://storage.googleapis.com/sfr-gedi-data/gedi_topic.zip
import zipfile
with zipfile.ZipFile('gedi_topic.zip', 'r') as zip_ref:
    zip_ref.extractall('./')

gedi_path = "gedi_topic/"

!pip install transformers
!pip install torch

--2022-02-28 20:36:48--  https://storage.googleapis.com/sfr-gedi-data/gedi_topic.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.125.128, 142.250.157.128, 142.251.8.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.125.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1318630072 (1.2G) [application/zip]
Saving to: ‘gedi_topic.zip’


2022-02-28 20:37:19 (41.4 MB/s) - ‘gedi_topic.zip’ saved [1318630072/1318630072]

Collecting transformers
  Downloading transformers-4.16.2-py3-none-any.whl (3.5 MB)
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
Collecting tokenizers!=0.11.3,>=0.10.1
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB

Now let's set up the logits processor.

In [3]:
import transformers
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, LogitsProcessorList

# Use the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gedi_location = gedi_path

class PlugAndBlendLogitsProcessor(transformers.LogitsProcessor):

    gedi_model = GPT2LMHeadModel.from_pretrained(gedi_location).to(device)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Default omega from the original GeDi work; a higher omega means more aggressive topic steering.
    # The per-topic `weight` passed to __init__ scales it further.
    # The default value (1x) is 30.
    omega = 30

    def __init__(self, topic: str, weight: float):
        super().__init__()
        self.topic = topic
        self.weight = weight
        # Only the first token of the encoded topic string is used as the topic code.
        self.encoded_topic = PlugAndBlendLogitsProcessor.tokenizer.encode(topic)[0]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        modifiers = self.get_gedi_modifiers(input_ids = input_ids)

        # Make them appear on the same device
        modifiers = modifiers.to(scores.device)

        scores += modifiers * self.weight * PlugAndBlendLogitsProcessor.omega

        return scores

    def get_gedi_modifiers(self, input_ids):

        # Setting up some constants
        code_0 = "negative"
        code_1 = "positive"
        nt_id = PlugAndBlendLogitsProcessor.tokenizer.encode(code_0)[0]
        pt_id = PlugAndBlendLogitsProcessor.tokenizer.encode(code_1)[0]

        # Class weights for the cross-entropy loss: give weight 0 to token 50256,
        # GPT-2's end-of-text token (also used as padding), so it does not count toward the sequence score.
        crossentropy_loss_weight = [1] * 50257
        crossentropy_loss_weight[50256] = 0  # do not calculate loss on the end-of-text token
        crossentropy_loss_weight = torch.tensor(crossentropy_loss_weight).float().to(device)

        # Creating prefixes.
        seq_pt = (torch.ones(input_ids.shape[0]) * pt_id).type_as(input_ids).view(-1, 1)
        seq_nt = (torch.ones(input_ids.shape[0]) * nt_id).type_as(input_ids).view(-1, 1)
        encoded_topic_torch = (torch.ones(input_ids.shape[0]) * self.encoded_topic).type_as(input_ids).view(-1, 1)

        # Assemble input_ids as [class code, topic token, context].
        seq_pt_new = torch.cat((seq_pt, encoded_topic_torch, input_ids), dim=1)
        seq_nt_new = torch.cat((seq_nt, encoded_topic_torch, input_ids), dim=1)

        def prepare_inputs_for_generation(input_ids, **kwargs):
            return {"input_ids": input_ids.to(device)}

        seq_batched = torch.cat([seq_pt_new,seq_nt_new], dim=0)

        model_inputs = prepare_inputs_for_generation(input_ids=seq_batched)

        gedi_outputs = PlugAndBlendLogitsProcessor.gedi_model(**model_inputs)

        # Compute the modifier over the whole sequence: the product of the per-token
        # probabilities (a sum in log space) under each class code.
        # This is the baseline (the sentence without the candidate next token), used for normalization.

        shift_logits = gedi_outputs["logits"][..., :-1, :].contiguous().to(device)
        shift_labels = seq_batched[..., 1:].contiguous().to(device)

        # Cross entropy against the shifted labels effectively picks out
        # the probability assigned to each token already in the sequence.
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none",
                                             weight=crossentropy_loss_weight,
                                             )

        # Cross entropy loss gives -log p(x), so negate it to recover log p(x).
        logits_r = -1 * loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        logits_r = logits_r.view(seq_batched.shape[0], -1)

        seq_len = logits_r.shape[1]

        logits_r = torch.sum(logits_r, 1)

        # Now, finally add the baseline into the actual final (generated token) logits.
        gedi_logits = torch.log_softmax(gedi_outputs["logits"][:, -1, :], -1)
        gedi_logits += logits_r.unsqueeze(1)

        # Normalize modifier logits by sequence length and reshape it for output
        gedi_logits_split = torch.split(gedi_logits / seq_len,
                                        input_ids.shape[0])

        logits = torch.stack(gedi_logits_split, 2)

        logp_related_softmax = torch.log_softmax(logits, dim=-1)

        # Once normalized, we only care about the "positive" dimension (0).
        final_modifier = logp_related_softmax[...,0]

        return final_modifier

# Tests

def test_generation(prompt=None, topics=None):
    if prompt is None:
        prompt = "Once upon a time,"

    if topics is None:
        # default topics
        topics = {"Science": 1, "Nature": 1}

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Set up the base language model.
    # As this is plug-and-blend, you can change this to any model that uses the GPT2 tokenizer (i.e. has the same input_ids => actual sentence mapping).
    # We are using GPT-2 here just as an example.
    model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)

    # Encode the prompt.
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    #input_ids = torch.cat([input_ids,input_ids,input_ids],dim=0)

    lp_raw_list = []
    for item in topics:
        lp_raw_list.append(PlugAndBlendLogitsProcessor(topic=item, weight=topics[item]))
    #lp_raw_list = [PlugAndBlendLogitsProcessor(topic="Science", weight=1), PlugAndBlendLogitsProcessor(topic="Nature", weight=1)]

    lp_list = LogitsProcessorList(lp_raw_list)

    greedy_output = model.generate(
        input_ids,
        max_length=50,
        logits_processor=lp_list,
        no_repeat_ngram_size=2,
    )
    print("Output:\n" + 100 * '-')
    print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))



Some weights of the model checkpoint at gedi_topic/ were not used when initializing GPT2LMHeadModel: ['logit_scale']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
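
To make the arithmetic inside `get_gedi_modifiers` easier to follow, here is a minimal pure-Python sketch of the same idea for a single candidate token. It assumes the summed context log-probabilities under the "positive" and "negative" prefixes are already known, along with the candidate token's log-probability under each prefix. The function names `topic_modifier` and `steer` and their arguments are hypothetical illustrations, not part of this notebook or of GeDi.

```python
import math

def topic_modifier(ctx_logp_pos, ctx_logp_neg, tok_logp_pos, tok_logp_neg, seq_len):
    """Length-normalized log-probability that the text is on-topic
    once the candidate token is appended (illustrative sketch)."""
    # Sequence score under each class code, normalized by sequence length.
    pos = (ctx_logp_pos + tok_logp_pos) / seq_len
    neg = (ctx_logp_neg + tok_logp_neg) / seq_len
    # Two-way log-softmax over (positive, negative); keep the positive entry.
    m = max(pos, neg)
    log_norm = m + math.log(math.exp(pos - m) + math.exp(neg - m))
    return pos - log_norm

def steer(base_logit, modifier, weight, omega=30.0):
    """Shift a base-model score the way __call__ does: score + modifier * weight * omega."""
    return base_logit + modifier * weight * omega

# A token that is equally likely under both class codes gets log(0.5).
neutral = topic_modifier(-10.0, -10.0, -2.0, -2.0, seq_len=5)
# A token that is more likely under the "positive" code gets a value closer to 0.
on_topic = topic_modifier(-8.0, -12.0, -1.0, -3.0, seq_len=5)
```

Because the modifier is a log-probability it is never positive, so steering works by penalizing off-topic tokens more than on-topic ones rather than by boosting any token outright.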


# Generate things (Demo)

This demo showcases generation using GPT-2 as the base model. Refer to the body of `test_generation` to see how you can use a different model (as long as its tokenizer is `GPT2Tokenizer.from_pretrained("gpt2")`).

Change `test_prompt` to set the prompt, and change the `test_topics` dictionary to set the topics (and their weights) you want to include in the generated sentence. A total of 1 (all weights added up) gives standard control strength; in our experiments, totals of 2 to 4 give stronger steering.
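
Since control strength depends on the sum of the weights, one convenient way to experiment is to keep the topic proportions fixed and rescale them to a chosen total. The helper below is a hypothetical convenience (`scale_topic_weights` is not part of this notebook or of Plug-and-Blend), sketched under that assumption:

```python
def scale_topic_weights(topics, total_strength=1.0):
    """Rescale topic weights so they sum to total_strength, keeping their proportions."""
    total = sum(topics.values())
    if total <= 0:
        raise ValueError("topic weights must sum to a positive number")
    return {name: w / total * total_strength for name, w in topics.items()}

# An equal blend of two topics at twice the standard strength.
scaled_topics = scale_topic_weights({"Business": 1, "Science": 1}, total_strength=2.0)
```

The resulting dictionary can be passed as `topics=` to `test_generation` in place of a hand-written one.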

In [7]:
test_topics = {"Business":0.5, "Science":0.5}
test_prompt = "Once upon a time,"

test_generation(prompt=test_prompt, topics=test_topics)

Output:
----------------------------------------------------------------------------------------------------
Once upon a time, the world was a place of great beauty and great danger. The world of the gods was the place where the great gods were born, and where they were to live.

The world that was created was not the same
