<a href="https://colab.research.google.com/github/rohrl/llm_shenanigans/blob/main/soft_prompts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [181]:
# based on https://github.com/kipgparker/soft-prompt-tuning/blob/main/example.ipynb

In [182]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

In [183]:
import torch
import torch.nn as nn

In [184]:
class SoftEmbedding(nn.Module):
    """Input-embedding wrapper that prepends learned "soft prompt" vectors.

    The first ``n_tokens`` positions of every input sequence are placeholders:
    their token ids are discarded and replaced by the trainable
    ``learned_embedding`` rows, while the remaining ids are embedded with the
    original ``wte``.
    """

    def __init__(self,
                 wte: nn.Embedding,
                 n_tokens: int = 10,
                 random_range: float = 0.5,
                 initialize_from_vocab: bool = True):
        """Wrap ``wte`` and create ``n_tokens`` learned prompt embeddings.

        Args:
            wte (nn.Embedding): original transformer word embedding.
            n_tokens (int, optional): number of soft-prompt tokens. Defaults to 10.
            random_range (float, optional): half-width of the uniform init range
                (used only when not initializing from vocab). Defaults to 0.5.
            initialize_from_vocab (bool, optional): initialize from the first
                ``n_tokens`` rows of ``wte``. Defaults to True.

        Raises:
            ValueError: if ``initialize_from_vocab`` and ``n_tokens`` exceeds the
                number of rows in ``wte``.
        """
        super().__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(
            self.initialize_embedding(wte, n_tokens, random_range, initialize_from_vocab))

    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = True) -> torch.Tensor:
        """Build the initial ``(n_tokens, embedding_dim)`` soft-prompt tensor.

        Args:
            same as __init__

        Returns:
            torch.Tensor: a fresh tensor (not tracking gradients) with the same
            dtype and device as ``wte.weight``.
        """
        weight = wte.weight
        if initialize_from_vocab:
            if n_tokens > weight.size(0):
                raise ValueError(
                    f"n_tokens={n_tokens} exceeds vocabulary size {weight.size(0)}")
            # Copy (not view) the first n_tokens vocabulary rows, detached from autograd.
            return weight[:n_tokens].detach().clone()
        # Match wte's dtype/device so the learned rows can be concatenated with
        # its outputs (e.g. after model.half() or moving the model to a GPU).
        return torch.empty(n_tokens, weight.size(1),
                           dtype=weight.dtype,
                           device=weight.device).uniform_(-random_range, random_range)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """Embed ``tokens``, replacing the first ``n_tokens`` positions with the soft prompt.

        Args:
            tokens (torch.long): 2-D ``(batch, seq_len)`` token ids whose first
                ``n_tokens`` columns are placeholders (their values are ignored).

        Returns:
            torch.Tensor: ``(batch, seq_len, embedding_dim)`` embeddings, i.e. the
            learned task-specific prompt followed by the encoded text.

        Raises:
            ValueError: if ``seq_len < n_tokens``. The output would otherwise
                silently be longer than the input and than its attention mask.
        """
        if tokens.size(1) < self.n_tokens:
            raise ValueError(
                f"expected at least n_tokens={self.n_tokens} leading placeholder ids, "
                f"got a sequence of length {tokens.size(1)}")
        # The placeholder ids are dropped here; the learned rows take their place.
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        # Broadcast the shared prompt over the batch (expand avoids an extra copy;
        # torch.cat below materializes the result).
        learned_embedding = self.learned_embedding.unsqueeze(0).expand(
            input_embedding.size(0), -1, -1)
        return torch.cat([learned_embedding, input_embedding], 1)

In [185]:
# Soft-prompt settings: number of learned prompt vectors (alternative noted
# in the original: 20) and whether to seed them from the first vocabulary rows
# (False -> uniform random init; alternative noted in the original: True).
n_tokens, initialize_from_vocab = 3, False

In [186]:
# Load the pretrained "gpt2" checkpoint and its matching fast tokenizer.
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

In [187]:
# Inspect the original token-embedding layer (output below: 50257 rows x 768 dims).
model.get_input_embeddings()

Embedding(50257, 768)

In [188]:
# Show the tokenizer config. bos/eos/unk are all <|endoftext|> (id 50256), which is
# why the unk_token_id placeholders prepended later decode as <|endoftext|>.
tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [189]:
# Wrap GPT-2's token embedding in a SoftEmbedding holding n_tokens learned prompt vectors.
s_wte = SoftEmbedding(
    wte=model.get_input_embeddings(),
    n_tokens=n_tokens,
    initialize_from_vocab=initialize_from_vocab,
)

In [190]:
# The repr lists only the wrapped wte submodule; learned_embedding is a Parameter,
# so it does not appear here.
s_wte

SoftEmbedding(
  (wte): Embedding(50257, 768)
)

In [191]:
# Install the soft-prompt wrapper as the model's input embedding layer; from here on
# every forward pass replaces the first n_tokens input ids with the learned vectors.
model.set_input_embeddings(s_wte)

In [348]:
# Tokenize the prompt into PyTorch tensors (input_ids + attention_mask).
inputs = tokenizer("May the force be", return_tensors="pt")


In [349]:
# Unpadded encoding: 4 prompt tokens, all attended.
inputs

{'input_ids': tensor([[6747,  262, 2700,  307]]), 'attention_mask': tensor([[1, 1, 1, 1]])}

In [350]:
# Round-trip check: decoding the ids recovers the original prompt text.
tokenizer.decode(inputs.input_ids.squeeze(), skip_special_tokens=False)

'May the force be'

In [351]:

# The model's input embedding is now a SoftEmbedding, whose forward() discards the
# first n_tokens ids of every sequence and substitutes the learned prompt vectors.
# Prepend n_tokens placeholder ids so the real prompt survives that slicing, and
# extend attention_mask to the same length (seq_len + n_tokens). The placeholder
# value never reaches the model; unk_token_id is used only as a valid id.
# NOTE: this mutates `inputs`. Re-run the tokenizer cell before re-running this one,
# otherwise another n_tokens placeholders get prepended.
batch_size = inputs['input_ids'].size(0)
prompt_placeholders = torch.full((batch_size, n_tokens), tokenizer.unk_token_id,
                                 dtype=inputs['input_ids'].dtype)
prompt_mask = torch.ones((batch_size, n_tokens), dtype=inputs['attention_mask'].dtype)
inputs['input_ids'] = torch.cat([prompt_placeholders, inputs['input_ids']], 1)
inputs['attention_mask'] = torch.cat([prompt_mask, inputs['attention_mask']], 1)


In [352]:
# Show the padded encoding, then how the placeholder prefix renders as text.
print(inputs)
padded_ids = inputs.input_ids.squeeze()
print(tokenizer.decode(padded_ids, skip_special_tokens=False))

{'input_ids': tensor([[50256, 50256, 50256,  6747,   262,  2700,   307]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}
<|endoftext|><|endoftext|><|endoftext|>May the force be


In [353]:
# Greedy decoding, one token per step. The full sequence is re-run every step
# (no KV cache), so cost grows quadratically with the output length.
# NOTE(review): the commented-out model.generate approach presumably feeds only the
# newest token per step when using its KV cache. SoftEmbedding.forward drops the
# first n_tokens ids of every call, so that path would need changes — confirm first.
new_out_tokens = 10

# Work on copies so the loop never grows `inputs`; re-running this cell then
# always starts from the same padded prompt.
curr_inputs = {name: tensor.clone() for name, tensor in inputs.items()}

with torch.no_grad():  # inference only: don't build an autograd graph every step
    for _ in range(new_out_tokens):
        logits = model(**curr_inputs).logits
        # Most likely next token at the last position, kept 2-D: (batch, 1).
        next_ids = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        # Append it (and a matching attention bit) and feed the longer sequence back in.
        curr_inputs['input_ids'] = torch.cat([curr_inputs['input_ids'], next_ids], 1)
        curr_inputs['attention_mask'] = torch.cat(
            [curr_inputs['attention_mask'],
             torch.ones_like(next_ids, dtype=curr_inputs['attention_mask'].dtype)], 1)

# Padded prompt followed by the new_out_tokens generated ids.
outputs = curr_inputs['input_ids']



In [354]:
# Token ids: placeholder prefix, prompt, then the greedy continuation.
print(outputs)

# Drop size-1 dimensions (the batch of one) to get a flat id sequence for decoding.
predicted_token_ids = torch.squeeze(outputs)

tensor([[50256, 50256, 50256,  6747,   262,  2700,   307,   351,   345,    13,
           198,   198,   464,   691,  1517,   326,   338]])


In [355]:
# Decode without skipping special tokens so the placeholder slots stay visible.
text = tokenizer.decode(predicted_token_ids, skip_special_tokens=False)

# The pipes mark the exact start/end of the decoded string (whitespace included).
print("|" + text + "|")

|<|endoftext|><|endoftext|><|endoftext|>May the force be with you.

The only thing that's|
