<a href="https://colab.research.google.com/github/rohrl/llm_shenanigans/blob/main/soft_prompts.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# based on https://github.com/kipgparker/soft-prompt-tuning/blob/main/example.ipynb

In [2]:
# !pip install sentencepiece transformers accelerate einops

In [3]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

In [4]:
import torch
import torch.nn as nn

In [5]:
torch.set_default_device('cuda')

In [6]:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2', device_map="cuda")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [7]:
model.device

device(type='cuda', index=0)

## Sanity check

In [8]:
sanity_text = "The capital of Australia"
sanity_output = model.generate(input_ids = tokenizer.encode(sanity_text, return_tensors="pt"), max_length=10, num_return_sequences=1)
print("==================\n" + tokenizer.decode(sanity_output[0], skip_special_tokens=True) + "\n==================\n")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


The capital of Australia, Sydney, is home to



## Soft Embeddings

In [9]:
class SoftEmbedding(nn.Module):
    def __init__(self,
                wte: nn.Embedding,
                n_tokens: int = 10,
                random_range: float = 0.5,
                initialize_from_vocab: bool = True):
        """appends learned embedding to

        Args:
            wte (nn.Embedding): original transformer word embedding
            n_tokens (int, optional): number of tokens for task. Defaults to 10.
            random_range (float, optional): range to init embedding (if not initialize from vocab). Defaults to 0.5.
            initialize_from_vocab (bool, optional): initalizes from default vocab. Defaults to True.
        """
        super(SoftEmbedding, self).__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(self.initialize_embedding(wte,
                                                                               n_tokens,
                                                                               random_range,
                                                                               initialize_from_vocab))

    def initialize_embedding(self,
                             wte: nn.Embedding,
                             n_tokens: int = 10,
                             random_range: float = 0.5,
                             initialize_from_vocab: bool = True):
        """initializes learned embedding

        Args:
            same as __init__

        Returns:
            torch.float: initialized using original schemes
        """
        if initialize_from_vocab:
            # this takes first n_tokens words from vocab and uses as init of learnt embeddings
            return self.wte.weight[:n_tokens].clone().detach()
        return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(-random_range, random_range).to('cuda')

    def forward(self, tokens):
        """run forward pass

        Args:
            tokens (torch.long): input tokens before encoding

        Returns:
            torch.float: encoding of text concatenated with learned task specifc embedding
        """
        # below line means that first n_tokens tokens will be ignored (?)
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
        return torch.cat([learned_embedding, input_embedding], 1)

In [10]:
# How many soft prompt tokens do we want to use.
num_soft_prompt_tokens = 5 # 20
initialize_from_vocab = False  # True

In [11]:
model.get_input_embeddings()

Embedding(50257, 768)

In [12]:
tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [13]:
s_wte = SoftEmbedding(model.get_input_embeddings(),
                      n_tokens = num_soft_prompt_tokens,
                      initialize_from_vocab = initialize_from_vocab)

In [14]:
s_wte

SoftEmbedding(
  (wte): Embedding(50257, 768)
)

In [15]:
model.set_input_embeddings(s_wte)

In [16]:
def prepend_with_soft_prompts_padding(inputs, num_soft_tokens, pad_token_id = tokenizer.unk_token_id, labels = None):
    """
    Need to pad attention_mask and input_ids to be full seq_len + n_learned_tokens,
    even though it does not matter what you pad input_ids with, it's just to make HF happy.
    More exp: the SoftEmbedding implementation ignores first num_soft_prompt_tokens of input tokens so this padding is to insert them at the beginning (and also make consistent with attention_mask length)
    Padding is made of repeated "unk_token" (but it doesn't matter as it's ignored).
    """
    batch_size = inputs['input_ids'].size(0)

    inputs['input_ids'] = torch.cat([torch.full((batch_size, num_soft_tokens), pad_token_id), inputs['input_ids']], 1)
    inputs['attention_mask'] = torch.cat([torch.full((batch_size, num_soft_tokens), 1), inputs['attention_mask']], 1)

    if labels is None:
        return inputs
    else:
        labels = torch.cat([torch.full((batch_size, num_soft_tokens), pad_token_id), labels], 1)
        return inputs, labels


## Inference

In [32]:
inputs = tokenizer("The capital of Australia", return_tensors="pt")


In [33]:
inputs

{'input_ids': tensor([[ 464, 3139,  286, 4505]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1]], device='cuda:0')}

In [34]:
tokenizer.decode(inputs['input_ids'].squeeze(), skip_special_tokens=False)

'The capital of Australia'

In [35]:
inputs = prepend_with_soft_prompts_padding(inputs, num_soft_prompt_tokens)

print(inputs)
print(tokenizer.decode(inputs['input_ids'].squeeze(), skip_special_tokens=False))

{'input_ids': tensor([[50256, 50256, 50256, 50256, 50256,   464,  3139,   286,  4505]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}
<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>The capital of Australia


In [36]:
# outputs = model(**inputs)

new_out_tokens = 10
curr_inputs = inputs

new_token_id = 0
outputs = torch.cat([inputs['input_ids'], torch.full((1, new_out_tokens), 0) ], 1)


model.eval()

with torch.no_grad():
  for i in range(new_out_tokens):

    # outputs = model.generate(**inputs, max_length = curr_inputs['input_ids'].size(1) + 1)
    raw_outputs = model(**curr_inputs)
    # print(raw_outputs.logits.shape)

    # new_token_id = outputs.squeeze()[-1]
    new_token_id = raw_outputs.logits[:,-1,:].argmax(axis=-1).item()
    outputs[:, (-new_out_tokens+i)] = new_token_id
    # print(outputs)

    # add the new token to inputs and repeat
    curr_inputs['input_ids'] = torch.cat([curr_inputs['input_ids'], torch.full((1, 1), new_token_id)], 1)
    curr_inputs['attention_mask'] = torch.cat([curr_inputs['attention_mask'], torch.full((1,1), 1)], 1)



In [37]:
# print(outputs.logits.shape)
print(outputs)

predicted_token_ids = outputs.squeeze()

tensor([[50256, 50256, 50256, 50256, 50256,   464,  3139,   286,  4505,   318,
         33452,    13, 33452,   318, 33452,    13, 33452,   318, 33452]],
       device='cuda:0')


In [38]:
text = tokenizer.decode(predicted_token_ids, skip_special_tokens=False) #[0]

# Print the decoded text
print(f"|{text}|")

|<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>The capital of Australia is Canberra. Canberra is Canberra. Canberra is Canberra|


## Training

In [24]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy

# this token is ignored in loss - used to mask remainder of output
ignored_token_id = tokenizer.unk_token_id

target = "The capital of Australia is Canberra."

target_tokens = tokenizer(target, return_tensors="pt")
target_len = target_tokens['input_ids'].size(1)

print(target_tokens['input_ids'])

# create the batch by repeating tokens, then in the loop mask endings

target_tokens['input_ids'] = target_tokens['input_ids'].repeat(target_len - 1, 1)
target_tokens['attention_mask'] = target_tokens['attention_mask'].repeat(target_len - 1, 1)

# labels will be the next token, so clone and left-shift 1 hop
labels = target_tokens['input_ids'].clone()
labels = labels.roll(-1, dims=-1)

# add masks
for i in range(target_len - 1):
  # pad right of i
  #labels.append(target_tokens['input_ids'][i, i+1].item())
  target_tokens['input_ids'][i, i+1:] = torch.full((1, target_len - i - 1), ignored_token_id)
  labels[i, i+1:] = torch.full((1, target_len - i - 1), ignored_token_id)
  target_tokens['attention_mask'][i, i+1:] = torch.full((1, target_len - i - 1), 0)

# last token will never be fed as input so trim all tensors
target_tokens['input_ids'] = target_tokens['input_ids'][:, :-1]
target_tokens['attention_mask'] = target_tokens['attention_mask'][:, :-1]
labels = labels[:, :-1]

# Finally, pad inputs with soft prompts
target_tokens, labels = prepend_with_soft_prompts_padding(target_tokens, num_soft_prompt_tokens, labels = labels)


print(target_tokens['input_ids'])
print(target_tokens['attention_mask'])
print(labels)


tensor([[  464,  3139,   286,  4505,   318, 33452,    13]], device='cuda:0')
tensor([[50256, 50256, 50256, 50256, 50256,   464, 50256, 50256, 50256, 50256,
         50256],
        [50256, 50256, 50256, 50256, 50256,   464,  3139, 50256, 50256, 50256,
         50256],
        [50256, 50256, 50256, 50256, 50256,   464,  3139,   286, 50256, 50256,
         50256],
        [50256, 50256, 50256, 50256, 50256,   464,  3139,   286,  4505, 50256,
         50256],
        [50256, 50256, 50256, 50256, 50256,   464,  3139,   286,  4505,   318,
         50256],
        [50256, 50256, 50256, 50256, 50256,   464,  3139,   286,  4505,   318,
         33452]], device='cuda:0')
tensor([[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')
tensor([[50256, 50256, 50256, 50256, 50256,  3139, 50

In [25]:
# freeze entire model, then unfreeze soft embeddings
model.requires_grad_(False)
s_wte.requires_grad_(True)

SoftEmbedding(
  (wte): Embedding(50257, 768)
)

In [26]:
def compute_loss(criterion, logits, labels):

    logits_flat = logits.view(-1, logits.size(-1))
    # print(logits_flat.shape)

    labels_flat = labels.flatten()
    # print(labels_flat.shape)

    loss = criterion(logits_flat, labels_flat)

    loss_per_batch = loss.mean()

    return loss_per_batch

In [27]:
model.train()

criterion = nn.CrossEntropyLoss(ignore_index = ignored_token_id, reduction='none')
optimizer = optim.SGD(model.parameters(), lr=0.01) # Try Adam

best_loss = 1e9
best_soft_prompts = None

# Train the model
num_epochs = 200

for epoch in range(num_epochs):
    # Forward pass
    outputs = model(input_ids = target_tokens['input_ids'], attention_mask = target_tokens['attention_mask'])

    # print(outputs.logits.shape) #, outputs.logits)
    # print(labels.shape, labels)

    loss = compute_loss(criterion, outputs.logits, labels)

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Save best
    if loss.item() < best_loss:
        best_loss = loss.item()
        best_soft_prompts = copy.deepcopy(s_wte) # .clone() fails :(
        print('--- NEW BEST: Epoch: {}, Loss: {:.4f}'.format(epoch+1, best_loss))


    # Print the loss
    if epoch % 10 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))


--- NEW BEST: Epoch: 1, Loss: 1.9366
Epoch [1/200], Loss: 1.9366
Epoch [11/200], Loss: 9.0823
Epoch [21/200], Loss: 6.3038
--- NEW BEST: Epoch: 29, Loss: 1.8657
--- NEW BEST: Epoch: 31, Loss: 1.7750
Epoch [31/200], Loss: 1.7750
--- NEW BEST: Epoch: 32, Loss: 0.9168
--- NEW BEST: Epoch: 35, Loss: 0.7546
Epoch [41/200], Loss: 2.6753
Epoch [51/200], Loss: 1.4778
Epoch [61/200], Loss: 1.3136
--- NEW BEST: Epoch: 64, Loss: 0.7347
Epoch [71/200], Loss: 0.9173
--- NEW BEST: Epoch: 74, Loss: 0.5162
Epoch [81/200], Loss: 0.7370
--- NEW BEST: Epoch: 91, Loss: 0.5002
Epoch [91/200], Loss: 0.5002
--- NEW BEST: Epoch: 92, Loss: 0.3778
--- NEW BEST: Epoch: 93, Loss: 0.3711
--- NEW BEST: Epoch: 96, Loss: 0.3488
Epoch [101/200], Loss: 0.5596
--- NEW BEST: Epoch: 103, Loss: 0.3432
Epoch [111/200], Loss: 0.3456
--- NEW BEST: Epoch: 117, Loss: 0.3326
Epoch [121/200], Loss: 0.5032
--- NEW BEST: Epoch: 124, Loss: 0.2376
--- NEW BEST: Epoch: 125, Loss: 0.2256
Epoch [131/200], Loss: 0.5356
--- NEW BEST: Epoc

In [28]:
best_soft_prompts

SoftEmbedding(
  (wte): Embedding(50257, 768)
)

In [29]:
# Set the best soft prompts on the model.
model.set_input_embeddings(best_soft_prompts)

In [30]:
print("===============================================================================================")
print(">>>>> Now go back to Inference section and see what you get with trained soft prompts =] <<<<<<")
print("===============================================================================================")

>>>>> Now go back to Inference section and see what you get with trained soft prompts =] <<<<<<


In [31]:
# USING HF LIBRARY
# from transformers import TrainingArguments, Trainer

# training_args = TrainingArguments(
#     output_dir="./model_checkpoints",  # Output directory for checkpoints
#     num_train_epochs=3,  # Total number of training epochs
#     per_device_train_batch_size=16,  # Batch size per device
#     per_device_eval_batch_size=16,  # Batch size for evaluation
#     warmup_steps=500,  # Number of warmup steps
#     logging_steps=100,  # Number of steps between logging
#     save_steps=1000,  # Number of steps between saving checkpoints
#     evaluation_strategy="steps",  # Evaluation strategy
#     eval_steps=1000,  # Number of steps between evaluations
# )

# trainer = Trainer(
#     model=model,  # The model to train
#     args=training_args,  # Training arguments
#     train_dataset=train_dataset,  # Training dataset
#     eval_dataset=eval_dataset,  # Evaluation dataset
# )

# trainer.train()