In [11]:
import torch as t
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW
import matplotlib.pyplot as plt
import einops

# Select the compute device: use the GPU when PyTorch can see one, else the CPU.
if t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")

class CustomGPT2Model(GPT2LMHeadModel):
    """
    GPT-2 LM-head model whose input embedding at one sequence position can be replaced.

    The position to overwrite comes from the ``keyword_index`` attribute (None disables
    the replacement) or from the ``keyword_index=`` argument of ``forward``, which takes
    precedence. The replacement embedding comes from ``custom_token_id_vector``:

    * integer token id(s) (int or integer tensor): the matching row(s) of the
      token-embedding matrix ``transformer.wte.weight`` are used;
    * a floating-point tensor (e.g. a length-``vocab_size`` weight vector): the
      replacement is ``vector @ wte.weight``, i.e. a weighted mix of token embeddings
      that stays differentiable with respect to the vector.

    Only the first sequence in the batch (index 0) is modified.
    """
    def __init__(self, config):
        super().__init__(config)
        # Sequence position whose embedding forward() overwrites; None disables it.
        # Initialised here so forward() never fails with AttributeError.
        self.keyword_index = None

    def forward(self, input_ids=None, inputs_embeds=None, custom_token_id_vector=None,
                keyword_index=None, **kwargs):
        """
        Run GPT-2 on ``inputs_embeds`` after optionally replacing one position's embedding.

        Args:
            input_ids: token ids, embedded with ``transformer.wte`` when ``inputs_embeds``
                is not given.
            inputs_embeds: precomputed input embeddings; never modified in place.
            custom_token_id_vector: token id(s) or float weight vector defining the
                replacement embedding (see the class docstring).
            keyword_index: optional position override; defaults to ``self.keyword_index``.
            **kwargs: forwarded unchanged to ``GPT2LMHeadModel.forward``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is provided.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either input_ids or inputs_embeds must be provided.")
            inputs_embeds = self.transformer.wte(input_ids)

        position = self.keyword_index if keyword_index is None else keyword_index
        if custom_token_id_vector is not None and position is not None:
            embedding_matrix = self.transformer.wte.weight
            if t.is_tensor(custom_token_id_vector) and custom_token_id_vector.is_floating_point():
                # Soft token: weighted sum of all token embeddings.
                custom_embed = custom_token_id_vector.to(embedding_matrix.dtype) @ embedding_matrix
            else:
                # Hard token: plain embedding lookup.
                custom_embed = embedding_matrix[custom_token_id_vector]
            # Clone so a caller-supplied tensor is not mutated (and leaf tensors
            # requiring grad do not trigger an in-place autograd error).
            inputs_embeds = inputs_embeds.clone()
            inputs_embeds[0, position] = custom_embed

        return super().forward(inputs_embeds=inputs_embeds, **kwargs)

In [2]:
def tokenize_input(tokenizer, input_text, magic_word):
    """
    Tokenize ``input_text`` and locate the first occurrence of ``magic_word``.

    GPT-2's BPE encodes a word differently depending on whether a space precedes it
    ("cat" vs " cat"), so both spellings are searched. A magic word that spans several
    tokens is matched as a contiguous run of tokens, not token by token.

    Args:
        tokenizer: tokenizer with an ``encode`` method (e.g. ``GPT2Tokenizer``).
        input_text: text to tokenize.
        magic_word: word to locate in the tokenized text.

    Returns:
        (tokens, position): ``tokens`` is ``tokenizer.encode(input_text, return_tensors='pt')``
        moved to ``device``; ``position`` is the index in ``tokens[0]`` of the first token
        of the earliest match.

    Raises:
        ValueError: if no spelling of ``magic_word`` occurs in the tokenized text.
    """
    tokens = tokenizer.encode(input_text, return_tensors='pt').to(device)
    token_ids = tokens[0].tolist()

    match_starts = []
    for spelling in (magic_word, " " + magic_word):
        pattern = tokenizer.encode(spelling, add_special_tokens=False)
        if not pattern:
            continue
        width = len(pattern)
        for start in range(len(token_ids) - width + 1):
            if token_ids[start:start + width] == pattern:
                match_starts.append(start)
                break  # only the first occurrence of this spelling matters

    if not match_starts:
        raise ValueError(f"Keyword '{magic_word}' not found in input text.")
    return tokens, min(match_starts)

In [None]:
def intialise_random_token_vector(model, generator=None):
    """
    Return a random, trainable "soft token" vector of length ``vocab_size``.

    Entries are drawn uniformly from [0, 1) and divided by their sum, so the vector is
    non-negative and sums to 1 (L1-normalised, not unit L2-norm). Multiplying it by the
    token-embedding matrix therefore gives a convex combination of token embeddings.

    Args:
        model: model exposing ``config.vocab_size`` and ``device`` (e.g. a Hugging Face model).
        generator: optional ``torch.Generator`` for reproducible draws; it must live on the
            same device as ``model``.

    Returns:
        ``torch.nn.Parameter`` of shape ``(vocab_size,)`` on ``model.device`` with
        ``requires_grad=True``.
    """
    vocab_size = model.config.vocab_size
    # Allocate on the model's own device so it can be combined with its embedding weights.
    raw = t.rand(vocab_size, device=model.device, generator=generator)
    return t.nn.Parameter(raw / raw.sum(), requires_grad=True)


In [None]:
def train_token_vector(model, tokens, magic_word_pos, target_token_id, magic_token_vector, lr = 0.01, epochs = 500, l1_lambda = 0.01):
    """
    Optimise ``magic_token_vector`` so the model predicts ``target_token_id`` as the next token.

    At every step the input embedding at ``magic_word_pos`` (in the first sequence of
    ``tokens``) is replaced by ``magic_token_vector @ wte.weight``, a weighted mix of token
    embeddings. The loss is the cross-entropy between the logits at the last position and
    ``target_token_id``, plus ``l1_lambda`` times the L1 norm of ``magic_token_vector``
    (presumably to encourage a sparse mix of tokens; TODO confirm intent).

    Only ``magic_token_vector`` is updated. The model's parameters are frozen during training
    and their ``requires_grad`` flags are restored afterwards, even if an error occurs.

    Args:
        model: causal LM exposing ``transformer.wte`` and returning ``.logits`` when called
            with ``inputs_embeds=``.
        tokens: token ids, shape (batch, seq), on the model's device.
        magic_word_pos: position in ``tokens[0]`` whose embedding is replaced.
        target_token_id: token id the last position should predict.
        magic_token_vector: trainable ``torch.nn.Parameter`` of length ``vocab_size``.
        lr: AdamW learning rate.
        epochs: number of optimisation steps.
        l1_lambda: weight of the L1 penalty on ``magic_token_vector``.

    Returns:
        list[float]: total loss at each step (length ``epochs``).
    """
    embedding_weight = model.transformer.wte.weight
    target = t.tensor([target_token_id], device=tokens.device)
    optimizer = t.optim.AdamW([magic_token_vector], lr=lr)
    loss_values = []

    # Freeze the model so backward() only computes gradients for the token vector.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    for param in trainable_params:
        param.requires_grad_(False)
    try:
        # The surrounding token embeddings never change while the model is frozen.
        with t.no_grad():
            base_embeds = model.transformer.wte(tokens)

        for _ in range(epochs):
            optimizer.zero_grad()
            inputs_embeds = base_embeds.clone()
            inputs_embeds[0, magic_word_pos] = magic_token_vector.to(embedding_weight.dtype) @ embedding_weight
            logits = model(inputs_embeds=inputs_embeds).logits
            ce_loss = t.nn.functional.cross_entropy(logits[0, -1].unsqueeze(0), target)
            loss = ce_loss + l1_lambda * magic_token_vector.abs().sum()
            loss.backward()
            optimizer.step()
            loss_values.append(loss.item())
    finally:
        for param in trainable_params:
            param.requires_grad_(True)

    return loss_values


In [10]:
# Load the pretrained GPT-2 tokenizer and weights (into the custom subclass), move the
# model to the selected device, and inspect the L2 norm of every token-embedding row.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = CustomGPT2Model.from_pretrained('gpt2')
model = model.to(device)

embedding_matrix = model.transformer.wte.weight.data
print(embedding_matrix.norm(dim=1))

tensor([3.0684, 3.0901, 3.5135,  ..., 4.3369, 3.6443, 3.1488], device='cuda:0')
