In [12]:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

In [13]:
import torch
import torch.nn as nn

In [14]:
# from soft_embedding import SoftEmbedding

n_tokens = 20
initialize_from_vocab = True

In [16]:
import torch
import torch.nn as nn


class SoftEmbedding(nn.Module):
    def __init__(
        self,
        wte: nn.Embedding,
        n_tokens: int = 10,
        random_range: float = 0.5,
        initialize_from_vocab: bool = True,
    ):
        """appends learned embedding to

        Args:
            wte (nn.Embedding): original transformer word embedding
            n_tokens (int, optional): number of tokens for task. Defaults to 10.
            random_range (float, optional): range to init embedding (if not initialize from vocab). Defaults to 0.5.
            initialize_from_vocab (bool, optional): initalizes from default vocab. Defaults to True.
        """
        super(SoftEmbedding, self).__init__()
        self.wte = wte  # embedding vocab of gpt2
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(
            self.initialize_embedding(
                wte, n_tokens, random_range, initialize_from_vocab
            )
        )

        print("self.learned_embedding.shape: ", self.learned_embedding.shape)
    def initialize_embedding(
        self,
        wte: nn.Embedding,
        n_tokens: int = 10,
        random_range: float = 0.5,
        initialize_from_vocab: bool = True,
    ):
        """initializes learned embedding

        Args:
            same as __init__

        Returns:
            torch.float: initialized using original schemes
        """
        if initialize_from_vocab:
            # print("SIZE: ", self.wte.weight[:n_tokens].clone().detach().shape)
            return self.wte.weight[:n_tokens].clone().detach()
        return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(
            -random_range, random_range
        )

    def forward(self, tokens):
        """run forward pass

        Args:
            tokens (torch.long): input tokens before encoding

        Returns:
            torch.float: encoding of text concatenated with learned task specifc embedding
        """
        print("TOKEN LEN: ", tokens.shape)
        input_embedding = self.wte(tokens[:, self.n_tokens :])
        print("input_embedding: ", input_embedding.shape)
        learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
        print("learned_embedding: ", learned_embedding.shape)
        print("RETURN EMBEDDING: ", torch.cat([learned_embedding, input_embedding], 1).shape)
        return torch.cat([learned_embedding, input_embedding], 1)


In [17]:
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2')

In [18]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [20]:
s_wte = SoftEmbedding(model.get_input_embeddings(), 
                      n_tokens=n_tokens, 
                      initialize_from_vocab=initialize_from_vocab)

self.learned_embedding.shape:  torch.Size([20, 768])


In [21]:
print(s_wte)

SoftEmbedding(
  (wte): Embedding(50257, 768)
)


In [22]:
model.set_input_embeddings(s_wte)

In [10]:
print(model.get_input_embeddings())


SoftEmbedding(
  (wte): Embedding(50257, 768)
)


In [11]:
inputs = tokenizer("May the force be", return_tensors="pt")

# need to pad attention_mask and input_ids to be full seq_len + n_learned_tokens
# even though it does not matter what you pad input_ids with, it's just to make HF happy
inputs['input_ids'] = torch.cat([torch.full((1,n_tokens), 50256), inputs['input_ids']], 1)
inputs['attention_mask'] = torch.cat([torch.full((1,n_tokens), 1), inputs['attention_mask']], 1)

outputs = model(**inputs)

input_embedding:  torch.Size([1, 4, 768])
learned_embedding:  torch.Size([1, 20, 768])
RETURN EMBEDDING:  torch.Size([1, 24, 768])


In [13]:
print(outputs["logits"].shape)
# print(outputs["past_key_values"][0][0].shape)

torch.Size([1, 24, 50257])


In [28]:
outputs.keys()

odict_keys(['logits', 'past_key_values'])