![NVIDIA Logo](images/nvidia.png)

# P-tuning Simplified

Before applying PEFT to the PubMedQA task, starting with the p-tuning technique, this notebook covers the main concepts behind p-tuning and builds a simplified p-tuning mechanism to help develop your intuition for it.

---

## Learning Objectives

By the time you complete this notebook you will understand:
- How the PEFT technique p-tuning works.
- How a relatively small deep learning model can be trained to create embeddings that support specific LLM behavior.
- Why p-tuning is considered an efficient fine-tuning method.

---

## P-tuning Presentation

In [None]:
from llm_utils.slides import load_p_tuning_slides
load_p_tuning_slides()

---

## P-tuning Simplified

For the remainder of the notebook we will create a simplified p-tuning mechanism to assist your intuition about how p-tuning works.

## Imports

In [None]:
import numpy as np

---

## P-tuning

P-tuning is a PEFT technique whereby we train a relatively small multi-layer perceptron **(MLP)** neural network, called a **prompt encoder (PE)**, to create **virtual tokens** (also referred to as virtual prompts, soft prompts, and virtual embeddings) which can be added to LLM input to improve the likelihood that the LLM supplies the kind of generation we would like.

Part of what makes p-tuning an "efficient" fine-tuning method is that the weights of the LLM we are working with are frozen during p-tuning, and only the prompt encoder (PE) weights, which typically consist of only a few million parameters, are trained.

Thus, while it is common to say things like "p-tune an LLM" or "use a p-tuned LLM", the LLM itself is not updated or trained at all during the p-tuning process; only the small PE is.
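
To get a feel for the scale difference, here is a rough back-of-the-envelope sketch. Every number in it is hypothetical (a 7-billion-parameter LLM, an embedding size of 4096, and a two-layer MLP prompt encoder with a hidden size of 1024); real models and prompt encoders will differ.

```python
# Hypothetical sizes, chosen only to illustrate the scale difference.
llm_parameters = 7_000_000_000   # frozen LLM parameters (assumed)
embedding_dim = 4096             # LLM embedding size (assumed)
hidden_dim = 1024                # prompt encoder hidden size (assumed)

# Two-layer MLP with biases:
# (embedding_dim -> hidden_dim) followed by (hidden_dim -> embedding_dim).
pe_parameters = (embedding_dim * hidden_dim + hidden_dim) + (
    hidden_dim * embedding_dim + embedding_dim
)

trainable_fraction = pe_parameters / llm_parameters
print(f"{pe_parameters:,} trainable parameters "
      f"({trainable_fraction:.4%} of the LLM)")
```

Under these assumed sizes, the trainable parameters are a small fraction of one percent of the frozen LLM.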

---

## With vs. Without P-tuning

If we envision a typical prompt to response process with an LLM looking like:

```python
prompt -> (tokenizer) -> tokens -> (embedding_layer) -> embeddings -> (LLM) -> output
```

Then with p-tuning it would be:

```python
prompt -> (tokenizer) -> tokens -> (embedding_layer) -> embeddings -> (MLP PE) -> embeddings+virtual_tokens -> (LLM) -> output
```
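
In terms of array shapes, the difference between the two pipelines can be sketched as follows. The sizes are toy values assumed for illustration, and the virtual tokens are prepended to the prompt embeddings, which is one possible placement.

```python
import numpy as np

rng = np.random.default_rng(0)

embedding_dim = 4        # toy embedding size (assumed)
num_prompt_tokens = 5    # hypothetical length of the tokenized prompt
num_virtual_tokens = 2   # hypothetical number of virtual tokens

# Without p-tuning: the LLM sees only the prompt embeddings.
prompt_embeddings = rng.standard_normal((num_prompt_tokens, embedding_dim))

# With p-tuning: virtual tokens from the prompt encoder are placed in
# front of the prompt embeddings before they reach the LLM.
virtual_tokens = rng.standard_normal((num_virtual_tokens, embedding_dim))
llm_input = np.concatenate((virtual_tokens, prompt_embeddings), axis=0)

print(prompt_embeddings.shape)  # (5, 4)
print(llm_input.shape)          # (7, 4)
```

The prompt embeddings themselves are unchanged; the LLM simply receives a longer sequence.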

---

## The Efficiency of P-tuning

P-tuning can be performed with a relatively small amount of data, from hundreds to a few thousand samples, and depending on the task it can yield performance on par with models that went through full supervised fine-tuning for that task.

As an additional benefit, p-tuning is well suited to adapting a single LLM to multiple tasks. We can either perform a single round of p-tuning with data that represents all the tasks we would like the LLM to perform, or perform multiple rounds of p-tuning, which results in multiple task-specific prompt encoders. Then, during inference, we simply pass our prompt through the prompt encoder relevant to the task at hand on its way to the LLM. Many such prompt encoders can sit in front of the same LLM for use in a variety of tasks.
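
A minimal sketch of the multi-task idea, assuming each task's prompt encoder has already produced its own virtual tokens. The task names, the `run_task` helper, and the single-matrix stand-in for the LLM are all hypothetical and exist only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
embedding_dim = 4

# Frozen stand-in for the LLM, shared by every task.
llm_weights = rng.standard_normal((embedding_dim, 3))

# Hypothetical task names, each mapped to the virtual tokens produced
# by that task's own prompt encoder.
task_virtual_tokens = {
    "question_answering": rng.standard_normal((2, embedding_dim)),
    "summarization": rng.standard_normal((2, embedding_dim)),
}

def run_task(task, prompt_embeddings):
    """Prepend the virtual tokens for `task`, then call the shared LLM."""
    combined = np.concatenate(
        (task_virtual_tokens[task], prompt_embeddings), axis=0
    )
    return combined @ llm_weights

prompt_embeddings = rng.standard_normal((1, embedding_dim))
qa_output = run_task("question_answering", prompt_embeddings)
summary_output = run_task("summarization", prompt_embeddings)
```

The same frozen weights serve both tasks; only the small set of virtual tokens changes between calls.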

---

## P-tuning Training

We'll look in more depth at the p-tuning process below, but in summary it is a rather straightforward deep learning training process. During p-tuning we:
1. Provide training data as prompt/desired-response pairs.
2. Add virtual embeddings supplied by the PE (randomly generated at first) to the user-supplied prompt embeddings.
3. Compare the LLM's output to our desired-response label and calculate a loss.
4. Backpropagate through the LLM to the PE and update the PE weights, without updating the LLM weights.
5. Repeat until our training objectives are met.

---

## Simulating P-tuning

In order to improve your intuition about how p-tuning works, we will be simulating it in a simplified way in this notebook.

In our simulation of p-tuning we will define the following components:
- A small matrix of fixed weights to represent a simplified version of an LLM.
- A set of **virtual tokens** represented as embeddings.
- A **prompt encoder** (a small matrix or vector) that will be updated to produce better **virtual tokens**.

---

## Embedding Dimension

When a Large Language Model (LLM) converts tokens into embeddings, the size of these embeddings is always a fixed number, determined by the architecture of the LLM. In the spirit of this notebook (taking a *simplified* look at p-tuning) we set the embedding dimension here to a small value.

In [None]:
embedding_dim = 4

---

## Number of Virtual Tokens

During p-tuning, we specify the number of virtual tokens to train, typically ranging from 10 to 50. In the spirit of this notebook we choose a small value here: 2.

In [None]:
num_virtual_tokens = 2

---

## The LLM

As you are aware, LLMs commonly have billions of parameters spread across many weight matrices. For this example notebook we define the LLM to be a single, small weight matrix. Because this single matrix represents the entire LLM, it both receives the embedding inputs and produces the LLM's output, so we shape it as `embedding_dim` by `llm_output_dim`.

As in real p-tuning, `llm_weights` will stay fixed throughout the p-tuning process.

In [None]:
llm_output_dim = 3 

llm_weights = np.random.randn(embedding_dim, llm_output_dim)
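
One way to make the "frozen" property explicit in NumPy is to mark the array read-only, so that an accidental in-place update raises an error. This is an optional safeguard for the toy setup and a sketch only, not how deep learning frameworks freeze weights.

```python
import numpy as np

embedding_dim = 4
llm_output_dim = 3

llm_weights = np.random.randn(embedding_dim, llm_output_dim)

# Mark the array read-only so that any in-place update raises an error.
llm_weights.setflags(write=False)

try:
    llm_weights -= 0.01  # an accidental "training" update
    update_blocked = False
except ValueError:
    update_blocked = True

print(update_blocked)  # True
```

Reading from the array (for example in a forward pass) still works as usual.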

---

## Initialize the Prompt Encoder

The prompt encoder is typically a multi-layer perceptron (MLP) neural network with several million parameters. Here we initialize it as a small matrix.

In [None]:
prompt_encoder = np.random.randn(num_virtual_tokens, embedding_dim)
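
For comparison, a prompt encoder that is closer to an actual MLP might look like the sketch below. The hidden size, the ReLU activation, and the learnable `pe_inputs` array are assumptions made for illustration; the rest of this notebook keeps using the simpler single matrix above.

```python
import numpy as np

rng = np.random.default_rng(0)

embedding_dim = 4
num_virtual_tokens = 2
hidden_dim = 8  # hypothetical hidden size, chosen only for illustration

# Learnable inputs to the prompt encoder, one row per virtual token
# (an assumption of this sketch).
pe_inputs = rng.standard_normal((num_virtual_tokens, embedding_dim))

# Parameters of a two-layer MLP.
w1 = rng.standard_normal((embedding_dim, hidden_dim))
b1 = np.zeros(hidden_dim)
w2 = rng.standard_normal((hidden_dim, embedding_dim))
b2 = np.zeros(embedding_dim)

def mlp_prompt_encoder(x):
    """Map prompt encoder inputs to virtual tokens of size embedding_dim."""
    hidden = np.maximum(0.0, x @ w1 + b1)  # ReLU activation
    return hidden @ w2 + b2

virtual_tokens = mlp_prompt_encoder(pe_inputs)
print(virtual_tokens.shape)  # (2, 4)
```

Whatever its internal structure, the encoder's output has one row per virtual token and the same width as the LLM's embeddings.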

---

## Initialize the Virtual Tokens

Each virtual token is a vector whose length matches the model's embedding dimension. Before p-tuning, virtual tokens are initialized to random values.

In [None]:
virtual_tokens = np.random.randn(num_virtual_tokens, embedding_dim)

---

## Simulate Input

Here we simulate the hard prompt input to the LLM as a single vector of shape 1 x `embedding_dim`. Real inputs are matrices of shape (number of tokens) x (embedding dimension).

In [None]:
input_vector = np.random.randn(1, embedding_dim)

---

## Simulating a Forward Pass Through the LLM

During p-tuning, the prompt encoder updates the virtual tokens, which are then fed into the LLM (represented here by the fixed weight matrix) along with the hard input tokens. We'll simulate a simple forward pass and loss calculation.

In [None]:
def forward_pass(input_vector, virtual_tokens, llm_weights):
    # Concatenate virtual tokens with the input vector
    combined_input = np.concatenate((virtual_tokens, input_vector), axis=0)
    
    # Pass through the LLM
    llm_output = np.dot(combined_input, llm_weights)
    return llm_output

In [None]:
# Forward pass with initial virtual tokens
initial_output = forward_pass(input_vector, virtual_tokens, llm_weights)

---

## Calculate the Loss

Here we assume a simple loss function (mean squared error) against a target output.

In [None]:
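# Target output: a single row that NumPy broadcasts against every row of initial_output
target = np.random.randn(1, llm_output_dim)
# Mean squared error over all output entries
loss = np.mean((initial_output - target)**2)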
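
Because the forward pass returns one row per input row (two virtual tokens plus one input vector) while the target has a single row, the subtraction relies on NumPy broadcasting. The sketch below uses stand-in arrays of the same shapes to show what that means; the "last row only" variant is a hypothetical alternative, not something this notebook uses.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins with the same shapes as initial_output and target.
output = rng.standard_normal((3, 3))  # 2 virtual-token rows + 1 input row
target = rng.standard_normal((1, 3))

# Broadcasting compares the single target row against every output row.
broadcast_loss = np.mean((output - target) ** 2)

# The same computation with the target repeated explicitly.
explicit_loss = np.mean((output - np.repeat(target, 3, axis=0)) ** 2)

# Hypothetical alternative: score only the last row, which corresponds
# to the real input vector.
last_row_loss = np.mean((output[-1:] - target) ** 2)

print(np.isclose(broadcast_loss, explicit_loss))  # True
```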

---

## Calculate the Gradient

In p-tuning, the gradient for the Prompt Encoder is calculated by backpropagating the loss through the entire LLM to determine how changes in the prompt encoder's parameters affect the final output, even though only the parameters of the prompt encoder are updated.

In [None]:
# Dummy gradient: actual gradient for the Prompt Encoder is calculated by backpropagating the loss through the entire LLM
prompt_encoder_gradient = np.random.randn(num_virtual_tokens, embedding_dim)
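
For this toy setup the true gradient is easy to write down, because the "LLM" is a single matrix and the loss is mean squared error. The sketch below computes the gradient of the loss with respect to the virtual tokens by backpropagating through the frozen weights, then checks it against finite differences. It redefines its own small arrays so that it runs on its own, and the function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
embedding_dim, llm_output_dim, num_virtual_tokens = 4, 3, 2

llm_weights = rng.standard_normal((embedding_dim, llm_output_dim))  # frozen
virtual_tokens = rng.standard_normal((num_virtual_tokens, embedding_dim))
input_vector = rng.standard_normal((1, embedding_dim))
target = rng.standard_normal((1, llm_output_dim))

def loss_fn(tokens):
    combined = np.concatenate((tokens, input_vector), axis=0)
    output = combined @ llm_weights
    return np.mean((output - target) ** 2)

def virtual_token_gradient(tokens):
    combined = np.concatenate((tokens, input_vector), axis=0)
    output = combined @ llm_weights
    grad_output = 2.0 * (output - target) / output.size  # d(loss)/d(output)
    grad_combined = grad_output @ llm_weights.T          # back through the frozen weights
    return grad_combined[:num_virtual_tokens]            # keep the virtual-token rows

analytic = virtual_token_gradient(virtual_tokens)

# Numerical check with central finite differences.
eps = 1e-6
numeric = np.zeros_like(virtual_tokens)
for i in range(num_virtual_tokens):
    for j in range(embedding_dim):
        step = np.zeros_like(virtual_tokens)
        step[i, j] = eps
        numeric[i, j] = (
            loss_fn(virtual_tokens + step) - loss_fn(virtual_tokens - step)
        ) / (2 * eps)

gradients_match = np.allclose(analytic, numeric, atol=1e-6)
print(gradients_match)
```

Note that `llm_weights` is used in the backward pass but never modified, which mirrors how the loss flows through a frozen LLM.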

---

## Update Prompt Encoder Weights

While the LLM weights remain frozen, we of course update the weights of the prompt encoder.

In [None]:
learning_rate = 0.01
prompt_encoder -= learning_rate * prompt_encoder_gradient

---

## Update Virtual Tokens

After the prompt encoder has been updated, it performs a transformation on the original virtual tokens to produce updated virtual tokens.

In [None]:
# Simulate the transformation of virtual tokens using element-wise multiplication
# with the updated prompt_encoder (both arrays have shape (num_virtual_tokens, embedding_dim))
original_virtual_tokens = virtual_tokens
virtual_tokens = original_virtual_tokens * prompt_encoder

---

## Repeat Forward Pass

The updated virtual tokens are now used in the next forward pass as prompt encoder training continues until some specified training objective (number of epochs, validation loss, etc.) has been reached.

In [None]:
# Forward pass with updated virtual tokens
updated_output = forward_pass(input_vector, virtual_tokens, llm_weights)
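
Putting the pieces together, a complete toy training loop might look like the sketch below. It makes two assumptions that differ from the cells above: it uses the analytic gradient of the mean squared error instead of a dummy gradient, and it always multiplies the prompt encoder with a fixed set of base virtual tokens. The names `base_virtual_tokens` and `losses` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
embedding_dim, llm_output_dim, num_virtual_tokens = 4, 3, 2

llm_weights = rng.standard_normal((embedding_dim, llm_output_dim))  # frozen "LLM"
base_virtual_tokens = rng.standard_normal((num_virtual_tokens, embedding_dim))
prompt_encoder = rng.standard_normal((num_virtual_tokens, embedding_dim))  # trainable
input_vector = rng.standard_normal((1, embedding_dim))
target = rng.standard_normal((1, llm_output_dim))

llm_weights_before = llm_weights.copy()
learning_rate = 0.01
losses = []

for step in range(200):
    # Forward pass: the prompt encoder produces the virtual tokens.
    virtual_tokens = base_virtual_tokens * prompt_encoder
    combined = np.concatenate((virtual_tokens, input_vector), axis=0)
    output = combined @ llm_weights
    loss = np.mean((output - target) ** 2)
    losses.append(loss)

    # Backward pass: through the frozen LLM, back to the prompt encoder.
    grad_output = 2.0 * (output - target) / output.size
    grad_virtual = (grad_output @ llm_weights.T)[:num_virtual_tokens]
    grad_prompt_encoder = grad_virtual * base_virtual_tokens

    # Update only the prompt encoder; llm_weights is never modified.
    prompt_encoder -= learning_rate * grad_prompt_encoder

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

The loss cannot reach zero here, because the row produced by the fixed input vector is outside the prompt encoder's control, but it should decrease as the virtual tokens adapt.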

---

## Post-Training

The final virtual tokens produced by the p-tuning process are now used during post-training inference.

```python
prompt -> (tokenizer) -> tokens -> (embedding_layer) -> embeddings -> (MLP PE) -> embeddings+virtual_tokens -> (LLM) -> output
```