In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from omegaconf import DictConfig

import transformers
from transformers import RobertaTokenizer, GPT2Model, AutoTokenizer

from torchfly.nn.transformers import GPT2LMHeadModel
from torchfly.common.download import get_pretrained_weights
from torchfly.text.decode import TransformerDecoder

In [2]:
class GPT2MediumConfig:
    """Static hyper-parameters handed to `GPT2LMHeadModel` for the converted model.

    The values are plain class attributes; the class itself (not an instance)
    is passed to the model constructor.
    """

    # Also used as the file stem of the exported checkpoint.
    name = "roberta-tokenized-gpt2-medium"

    # Architecture sizes.
    n_layer = 24
    n_head = 16
    n_embd = 1024
    n_ctx = 1024
    n_positions = 1024

    # Vocabulary: 50265 == 50257 + 8, i.e. `original_vocab_size` plus
    # `num_special_tokens` as defined further down in this notebook.
    vocab_size = 50265
    pad_token_id = 1

    # Initialisation and normalisation.
    initializer_range = 0.02
    layer_norm_epsilon = 1e-05

    # All dropout probabilities are zero.
    embd_pdrop = 0.0
    resid_pdrop = 0.0
    attn_pdrop = 0.0

    # Output switches.
    output_attentions = False
    output_hidden_states = False
    output_past = True

In [3]:
def random_interpolate(x, vocab_size, num_samples=20, noise_std=0.001):
    """Build a new embedding vector as a noisy average of random rows of `x`.

    Args:
        x: 2-D tensor whose rows are existing embedding vectors.
        vocab_size: rows are drawn (with replacement) from `x[:vocab_size]`.
        num_samples: how many rows are averaged (default 20).
        noise_std: standard deviation of the Gaussian noise added to the
            average (default 0.001). Pass 0.0 for a noise-free average.

    Returns:
        1-D tensor of length `x.shape[-1]`, on the same device as `x`.
    """
    row_ids = np.random.randint(vocab_size, size=num_samples)
    # Create the noise on x's device; a CPU-only noise tensor cannot be added
    # to an average that lives on another device.
    noise = torch.randn(x.shape[-1], device=x.device) * noise_std
    return x[row_ids, :].mean(0) + noise

In [4]:
# Instantiate the torchfly GPT-2 LM from the config class defined above.
# Its weights are overwritten later by `load_state_dict`.
new_model = GPT2LMHeadModel(GPT2MediumConfig)

In [5]:
# Number of extra embedding rows appended after the pretrained GPT-2 rows.
num_special_tokens = 8
# Vocabulary size of the pretrained GPT-2 checkpoint.
original_vocab_size = 50257
# Hugging Face identifier of the checkpoint to convert.
model_name = "gpt2-medium"

# The extra embedding rows are initialised from NumPy and torch random draws
# (see `random_interpolate`). Seed both generators so that a fresh
# Restart-and-Run-All reproduces the same checkpoint.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [6]:
# Load the source (GPT-2) and target (RoBERTa) tokenizers.
gpt_tokenizer = AutoTokenizer.from_pretrained("gpt2")
roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# Token strings present in the RoBERTa vocabulary but absent from GPT-2's.
special_tokens = set(roberta_tokenizer.encoder) - set(gpt_tokenizer.encoder)

In [7]:
gpt2_model = transformers.GPT2LMHeadModel.from_pretrained(model_name)

# The hard-coded GPT-2 vocabulary size drives the id remapping below, so it
# must agree with the checkpoint that was actually loaded.
assert gpt2_model.config.vocab_size == original_vocab_size, (
    f"{model_name} has vocab_size={gpt2_model.config.vocab_size}, "
    f"expected {original_vocab_size}"
)
new_vocab_size = original_vocab_size + num_special_tokens

# Extended embedding table, still in GPT-2 id order: the pretrained rows
# first, followed by one extra row per RoBERTa-only token.
new_embedding = nn.Embedding(new_vocab_size, gpt2_model.config.n_embd)
with torch.no_grad():
    new_embedding.weight[:original_vocab_size] = gpt2_model.transformer.wte.weight
    # Every extra row (all `num_special_tokens` of them) is initialised as a
    # noisy average of randomly chosen pretrained rows.
    for extra_idx in range(original_vocab_size, new_vocab_size):
        new_embedding.weight[extra_idx] = random_interpolate(
            new_embedding.weight, original_vocab_size
        )

# r2g_mapping[roberta_id] -> row of `new_embedding` holding that token's vector.
# RoBERTa-only tokens are assigned the extra rows in order of appearance;
# every other token reuses the row of the identical GPT-2 token string.
r2g_mapping = []
special_start_count = original_vocab_size

for r_idx in range(new_vocab_size):
    r_token = roberta_tokenizer.decoder[r_idx]

    if r_token in special_tokens:
        r2g_mapping.append(special_start_count)
        special_start_count += 1
    else:
        r2g_mapping.append(gpt_tokenizer.encoder[r_token])

# Each row must be used exactly once; otherwise the reordered table would
# contain duplicated or out-of-range rows.
assert sorted(r2g_mapping) == list(range(new_vocab_size)), (
    "r2g_mapping is not a permutation of the extended vocabulary"
)

In [8]:
# Reorder the rows into RoBERTa id order: row i of the new table is
# new_embedding[r2g_mapping[i]]. Advanced indexing already yields a copy.
# `freeze=False` keeps the table trainable; the `from_pretrained` default
# would freeze it, and with it the tied output projection below.
gpt2_model.transformer.wte = nn.Embedding.from_pretrained(
    new_embedding.weight.detach()[r2g_mapping], freeze=False
)
# Tie the LM head to the new input embedding so both share one parameter.
gpt2_model.lm_head.weight = gpt2_model.transformer.wte.weight

In [9]:
# Non-strict loading is used because the Hugging Face state dict carries
# per-layer `attn.bias` entries that `new_model` does not define.
load_result = new_model.load_state_dict(gpt2_model.state_dict(), strict=False)

# strict=False would also hide real mismatches, so check the report explicitly.
assert not load_result.missing_keys, (
    f"new_model parameters left uninitialised: {load_result.missing_keys}"
)
# NOTE(review): `.attn.masked_bias` is tolerated as well on the assumption that
# some transformers versions export it alongside `.attn.bias` - confirm.
assert all(
    key.endswith((".attn.bias", ".attn.masked_bias"))
    for key in load_result.unexpected_keys
), f"unexpected checkpoint entries: {load_result.unexpected_keys}"

load_result

_IncompatibleKeys(missing_keys=[], unexpected_keys=['transformer.h.0.attn.bias', 'transformer.h.1.attn.bias', 'transformer.h.2.attn.bias', 'transformer.h.3.attn.bias', 'transformer.h.4.attn.bias', 'transformer.h.5.attn.bias', 'transformer.h.6.attn.bias', 'transformer.h.7.attn.bias', 'transformer.h.8.attn.bias', 'transformer.h.9.attn.bias', 'transformer.h.10.attn.bias', 'transformer.h.11.attn.bias', 'transformer.h.12.attn.bias', 'transformer.h.13.attn.bias', 'transformer.h.14.attn.bias', 'transformer.h.15.attn.bias', 'transformer.h.16.attn.bias', 'transformer.h.17.attn.bias', 'transformer.h.18.attn.bias', 'transformer.h.19.attn.bias', 'transformer.h.20.attn.bias', 'transformer.h.21.attn.bias', 'transformer.h.22.attn.bias', 'transformer.h.23.attn.bias'])

In [10]:
# Write the converted weights to "<config name>.pth" in the working directory.
torch.save(
    obj=new_model.state_dict(),
    f=f"{GPT2MediumConfig.name}.pth",
)

## Test

In [11]:
# Sanity check of the converted weights: pair them with the RoBERTa tokenizer.
tokenizer = roberta_tokenizer
# For a baseline, swap in the untouched Hugging Face checkpoint instead:
# model = transformers.GPT2LMHeadModel.from_pretrained(model_name)
model = new_model

In [12]:
# Create the torchfly decoding helper from an empty config.
# NOTE(review): presumably this selects the helper's default settings - confirm.
decoder_helper = TransformerDecoder(DictConfig({}))

In [13]:
# Attach the converted model and the RoBERTa tokenizer to the decoding helper.
# NOTE(review): registration order is kept as-is; it is unclear whether the
# helper depends on it.
decoder_helper.register_generator(model)
decoder_helper.register_tokenizer(tokenizer)

In [14]:
# Encode the prompt without the tokenizer's special tokens, then wrap the
# id list in an outer list so the tensor has a leading dimension of 1.
sentence = tokenizer.encode(
    "A dog has an",
    add_special_tokens=False,
)
input_ids = torch.tensor([sentence], dtype=torch.long)

In [15]:
# Generate a continuation of the prompt with a single beam and no sampling.
results = decoder_helper.generate(input_ids=input_ids, num_beams=1, do_sample=False)

In [16]:
# Turn the generated ids back into text; a readable continuation indicates
# the embedding rows line up with RoBERTa token ids.
# NOTE(review): `[0][0]` presumably selects the first batch item's first
# hypothesis - confirm against TransformerDecoder.
tokenizer.decode((results["tokens"][0][0]).tolist())

' emotional bond with its owner. It is a bond that is stronger than any other bond. It is'

In [17]:
# Show the raw generated token ids that were decoded above.
print(results["tokens"][0][0])

tensor([3722, 2175,   19,   63, 1945,    4,   85,   16,   10, 2175,   14,   16,
        3651,   87,  143,   97, 2175,    4,   85,   16])


In [18]:
# Show the encoded prompt ids for comparison with the generated ids.
input_ids

tensor([[ 250, 2335,   34,   41]])