In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torchfly.modules.transformers import (GPT2SimpleLM,
                                           UnifiedGPT2SmallConfig,
                                           UnifiedGPT2MediumConfig,
                                           UnifiedGPT2LargeConfig,
                                           UnifiedGPT2DistillConfig,
                                           UnifiedGPT2XLConfig)
                                           
from torchfly.text.tokenizers import UnifiedBPETokenizer
from torchfly.utils import get_pretrained_states
from transformers import GPT2LMHeadModel, RobertaTokenizer, GPT2Tokenizer, RobertaModel

from torchfly.text.decode import top_filtering

In [4]:
# Fetch the pretrained "unified-gpt2-large" weights through torchfly; the cell
# output below reports that a locally cached copy was found.
states = get_pretrained_states("unified-gpt2-large")

File exists: /home/wuqy1203/.cache/torchfly/models/unified-gpt2-large.pth


In [3]:
# Instantiate torchfly's GPT2SimpleLM with the large unified configuration.
# NOTE(review): this cell's execution count (3) is lower than the weight-loading
# cell before it (4), so the notebook was run out of order — confirm it still
# works after Restart & Run All.
model = GPT2SimpleLM(UnifiedGPT2LargeConfig)

In [6]:
# Load the downloaded weights into the model. strict=False is required here: the
# output below lists 'lm_head.decoder.weight' as missing from the checkpoint and
# the per-layer '...attn.bias' entries as unexpected.
model.load_state_dict(states, strict=False)

_IncompatibleKeys(missing_keys=['lm_head.decoder.weight'], unexpected_keys=['transformer.h.0.attn.bias', 'transformer.h.1.attn.bias', 'transformer.h.2.attn.bias', 'transformer.h.3.attn.bias', 'transformer.h.4.attn.bias', 'transformer.h.5.attn.bias', 'transformer.h.6.attn.bias', 'transformer.h.7.attn.bias', 'transformer.h.8.attn.bias', 'transformer.h.9.attn.bias', 'transformer.h.10.attn.bias', 'transformer.h.11.attn.bias', 'transformer.h.12.attn.bias', 'transformer.h.13.attn.bias', 'transformer.h.14.attn.bias', 'transformer.h.15.attn.bias', 'transformer.h.16.attn.bias', 'transformer.h.17.attn.bias', 'transformer.h.18.attn.bias', 'transformer.h.19.attn.bias', 'transformer.h.20.attn.bias', 'transformer.h.21.attn.bias', 'transformer.h.22.attn.bias', 'transformer.h.23.attn.bias', 'transformer.h.24.attn.bias', 'transformer.h.25.attn.bias', 'transformer.h.26.attn.bias', 'transformer.h.27.attn.bias', 'transformer.h.28.attn.bias', 'transformer.h.29.attn.bias', 'transformer.h.30.attn.bias', 'tra

In [9]:
# Inspect the LM-head weight, the one parameter load_state_dict reported as
# missing from the checkpoint.
model.lm_head.decoder.weight

Parameter containing:
tensor([[-0.0149, -0.0209,  0.0021,  ...,  0.0336, -0.0005, -0.0090],
        [ 0.0055, -0.0438,  0.0013,  ...,  0.0671,  0.0329, -0.0399],
        [ 0.0585,  0.0603,  0.0302,  ..., -0.1041, -0.0566, -0.0330],
        ...,
        [-0.0209,  0.0443,  0.0327,  ..., -0.0081, -0.0022,  0.0123],
        [ 0.0028,  0.0243,  0.0061,  ..., -0.0177, -0.0040,  0.0213],
        [-0.0049,  0.0257,  0.0111,  ..., -0.0429,  0.0170,  0.0346]],
       requires_grad=True)

In [None]:
# NOTE(review): this looks like an unfinished attribute lookup (the cell was never
# executed) — complete the attribute name or delete the cell.
model.transformer.e

In [17]:
# Save the full-precision weights.
# NOTE(review): the file is named "roberta-large" although `model` is the GPT-2
# based GPT2SimpleLM built above — confirm the intended file name.
torch.save(model.state_dict(), "roberta-large.pth")

# Save a half-precision copy. nn.Module.half() casts the module in place, which
# would silently leave `model` in fp16 for every later cell, so cast the tensors
# of the state dict instead and keep the live model untouched.
torch.save(
    {
        name: (tensor.half() if tensor.is_floating_point() else tensor)
        for name, tensor in model.state_dict().items()
    },
    "roberta-large-fp16.pth",
)

# Modify the tokenizer

In [2]:
# Load the RoBERTa tokenizer and list its special tokens (shown in the output below).
roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
roberta_tokenizer.all_special_tokens

['<s>', '</s>', '<pad>', '<unk>', '<mask>']

In [3]:
# Load the GPT-2 tokenizer; its vocabulary is compared against RoBERTa's further down.
gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Modify the model

In [4]:
def random_interpolate(x, vocab_size, num_samples=20, noise_std=0.01):
    """Create a new embedding row as a noisy average of random existing rows.

    Args:
        x: Embedding matrix (expected 2-D, one row per token).
        vocab_size: Rows are drawn, with replacement, from ``x[:vocab_size]``.
        num_samples: Number of rows averaged together (default 20).
        noise_std: Scale of the Gaussian noise added to the average
            (default 0.01); pass 0.0 for a noise-free average.

    Returns:
        A 1-D tensor of length ``x.shape[-1]``.
    """
    # Row selection uses NumPy's global RNG, the noise uses torch's global RNG;
    # seed both for reproducible results.
    row_ids = np.random.randint(vocab_size, size=num_samples)
    centroid = x[row_ids, :].mean(0)
    # Create the noise with x's dtype and device so the sum neither fails on a
    # device mismatch nor silently changes precision.
    noise = torch.randn(x.shape[-1], dtype=x.dtype, device=x.device) * noise_std
    return centroid + noise

### Small Model

In [5]:
# load the model
gpt2_small = GPT2LMHeadModel.from_pretrained("gpt2")

In [6]:
# Number of extra rows appended after the pretrained vocabulary.
num_special_tokens = 8

# Allocate the enlarged embedding table, then copy the pretrained rows into its
# leading slice without recording the copy in autograd.
new_embedding = nn.Embedding(
    num_embeddings=gpt2_small.config.vocab_size + num_special_tokens,
    embedding_dim=gpt2_small.config.n_embd,
)
with torch.no_grad():
    new_embedding.weight[:gpt2_small.config.vocab_size].copy_(gpt2_small.transformer.wte.weight)

In [7]:
# Fill every appended row (all `num_special_tokens` of them) with a noisy average
# of randomly chosen pretrained rows.
for new_row in range(gpt2_small.config.vocab_size,
                     gpt2_small.config.vocab_size + num_special_tokens):
    new_embedding.weight.data[new_row, :] = random_interpolate(
        new_embedding.weight.data, gpt2_small.config.vocab_size
    )

In [8]:
# Swap in the enlarged embedding and point the output projection at the very same
# parameter, so input and output embeddings stay shared.
gpt2_small.transformer.wte = new_embedding
gpt2_small.lm_head.weight = new_embedding.weight

In [9]:
# Tokens present in the RoBERTa vocabulary but absent from the GPT-2 vocabulary.
roberta_vocab = set(roberta_tokenizer.encoder)
gpt_vocab = set(gpt_tokenizer.encoder)
special_tokens = roberta_vocab - gpt_vocab

In [10]:
# Build the lookup table: position = RoBERTa token id, value = row of the enlarged
# GPT-2 embedding table holding that token.
r2g_mapping = []
# First row handed out to RoBERTa-only tokens.
# NOTE(review): presumably the GPT-2 vocabulary size; 50265 - 50257 matches the 8
# appended rows — confirm against the tokenizers.
special_start_count = 50257

for roberta_id in range(50265):
    token = roberta_tokenizer.decoder[roberta_id]

    if token not in special_tokens:
        # Shared token: reuse its existing GPT-2 row.
        r2g_mapping.append(gpt_tokenizer.encoder[token])
        continue

    # RoBERTa-only token: hand out the next appended row.
    r2g_mapping.append(special_start_count)
    special_start_count += 1
        
# r2g = {v:k for k,v in enumerate(r2g_mapping)}
# r2g_mapping = [r2g[i] for i in range(50265)]

In [11]:

# Snapshot the small model's weights, leaving out every entry whose name
# contains '.attn.bias'.
gpt2_small_states = {
    name: tensor
    for name, tensor in gpt2_small.state_dict().items()
    if '.attn.bias' not in name
}
# NOTE(review): unlike the larger variants, the row reordering is left disabled
# here — confirm this is intended.
# gpt2_small_states['transformer.wte.weight'] = gpt2_small_states['transformer.wte.weight'][r2g_mapping]

### Medium

In [12]:
# RoBERTa-only tokens: everything in RoBERTa's vocabulary that GPT-2 does not have.
special_tokens = roberta_tokenizer.encoder.keys() - gpt_tokenizer.encoder.keys()

In [13]:
# Convert GPT-2 medium: enlarge the embedding table, then produce a state dict
# whose vocabulary rows follow RoBERTa token ids.

# load the model
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-medium")

# number of rows appended after the pretrained vocabulary
num_special_tokens = 8

# copy the original embedding into the leading rows of an enlarged table
new_embedding = nn.Embedding(gpt2_model.config.vocab_size + num_special_tokens, gpt2_model.config.n_embd)
new_embedding.weight.data[:gpt2_model.config.vocab_size, :] = gpt2_model.transformer.wte.weight.data

# initialise every appended row with a noisy average of random pretrained rows
for i in range(num_special_tokens):
    new_embedding.weight.data[gpt2_model.config.vocab_size+i, :] = \
        random_interpolate(new_embedding.weight.data, gpt2_model.config.vocab_size)

# set tied: input embedding and output projection share one parameter
gpt2_model.transformer.wte = new_embedding
gpt2_model.lm_head.weight = gpt2_model.transformer.wte.weight

# redefine the mapping: position = RoBERTa token id, value = row of the enlarged table
r2g_mapping = []
# NOTE(review): 50257 / 50265 are presumably the GPT-2 and RoBERTa vocabulary
# sizes (their difference matches num_special_tokens) — confirm against the tokenizers.
special_start_count = 50257

for r_idx in range(50265):
    r_token = roberta_tokenizer.decoder[r_idx]

    if r_token in special_tokens:
        # RoBERTa-only token: hand out the next appended row
        r2g_mapping.append(special_start_count)
        special_start_count += 1
    else:
        # shared token: reuse its existing GPT-2 row
        g_idx = gpt_tokenizer.encoder[r_token]
        r2g_mapping.append(g_idx)

# rearrange the weight (the live model is not modified; only this dict is)
gpt2_model_states = gpt2_model.state_dict()
gpt2_model_states = {k: v for k, v in gpt2_model_states.items() if '.attn.bias' not in k}
reordered_embedding = gpt2_model_states['transformer.wte.weight'][r2g_mapping]
gpt2_model_states['transformer.wte.weight'] = reordered_embedding
# The output projection is tied to the input embedding; keep the two entries
# identical so the state dict does not mix RoBERTa-ordered and GPT-2-ordered rows.
if 'lm_head.weight' in gpt2_model_states:
    gpt2_model_states['lm_head.weight'] = reordered_embedding

In [9]:
# Save the filtered, RoBERTa-ordered weights prepared in the conversion cell.
# gpt2_model.state_dict() still has the original GPT-2 row order and the
# '.attn.bias' buffers, because the reordering only exists in gpt2_model_states.
torch.save(gpt2_model_states, "unified_gpt2_medium.pth")

In [14]:
# Save a half-precision copy of the filtered, RoBERTa-ordered weights. Casting the
# tensors of gpt2_model_states (rather than calling gpt2_model.half()) persists
# the reordered rows and does not convert the live model in place.
torch.save(
    {
        name: (tensor.half() if tensor.is_floating_point() else tensor)
        for name, tensor in gpt2_model_states.items()
    },
    "unified_gpt2_medium_fp16.pth",
)

### Large

In [10]:
# Convert GPT-2 large: enlarge the embedding table, then produce a state dict
# whose vocabulary rows follow RoBERTa token ids.

# load the model
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-large")

# number of rows appended after the pretrained vocabulary
num_special_tokens = 8

# copy the original embedding into the leading rows of an enlarged table
new_embedding = nn.Embedding(gpt2_model.config.vocab_size + num_special_tokens, gpt2_model.config.n_embd)
new_embedding.weight.data[:gpt2_model.config.vocab_size, :] = gpt2_model.transformer.wte.weight.data

# initialise every appended row with a noisy average of random pretrained rows
for i in range(num_special_tokens):
    new_embedding.weight.data[gpt2_model.config.vocab_size+i, :] = \
        random_interpolate(new_embedding.weight.data, gpt2_model.config.vocab_size)

# set tied: input embedding and output projection share one parameter
gpt2_model.transformer.wte = new_embedding
gpt2_model.lm_head.weight = gpt2_model.transformer.wte.weight

# redefine the mapping: position = RoBERTa token id, value = row of the enlarged table
r2g_mapping = []
# NOTE(review): 50257 / 50265 are presumably the GPT-2 and RoBERTa vocabulary
# sizes (their difference matches num_special_tokens) — confirm against the tokenizers.
special_start_count = 50257

for r_idx in range(50265):
    r_token = roberta_tokenizer.decoder[r_idx]

    if r_token in special_tokens:
        # RoBERTa-only token: hand out the next appended row
        r2g_mapping.append(special_start_count)
        special_start_count += 1
    else:
        # shared token: reuse its existing GPT-2 row
        g_idx = gpt_tokenizer.encoder[r_token]
        r2g_mapping.append(g_idx)

# (unused) inverse lookup, GPT-2 row -> RoBERTa id:
# r2g = {v:k for k,v in enumerate(r2g_mapping)}
# r2g_mapping = [r2g[i] for i in range(50265)]

# rearrange the weight (the live model is not modified; only this dict is)
gpt2_model_states = gpt2_model.state_dict()
gpt2_model_states = {k: v for k, v in gpt2_model_states.items() if '.attn.bias' not in k}
reordered_embedding = gpt2_model_states['transformer.wte.weight'][r2g_mapping]
gpt2_model_states['transformer.wte.weight'] = reordered_embedding
# The output projection is tied to the input embedding; keep the two entries
# identical so the state dict does not mix RoBERTa-ordered and GPT-2-ordered rows.
if 'lm_head.weight' in gpt2_model_states:
    gpt2_model_states['lm_head.weight'] = reordered_embedding

In [11]:
# Save the filtered, RoBERTa-ordered weights prepared in the conversion cell.
# gpt2_model.state_dict() still has the original GPT-2 row order and the
# '.attn.bias' buffers, because the reordering only exists in gpt2_model_states.
torch.save(gpt2_model_states, "unified_gpt2_large.pth")

In [11]:
# Save a half-precision copy of the filtered, RoBERTa-ordered weights. Casting the
# tensors of gpt2_model_states (rather than calling gpt2_model.half()) persists
# the reordered rows and does not convert the live model in place.
torch.save(
    {
        name: (tensor.half() if tensor.is_floating_point() else tensor)
        for name, tensor in gpt2_model_states.items()
    },
    "unified_gpt2_large_fp16.pth",
)

### Extra Large

In [7]:
# Convert GPT-2 XL: enlarge the embedding table, then produce a state dict
# whose vocabulary rows follow RoBERTa token ids.

# load the model
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2-xl")

# number of rows appended after the pretrained vocabulary
num_special_tokens = 8

# copy the original embedding into the leading rows of an enlarged table
new_embedding = nn.Embedding(gpt2_model.config.vocab_size + num_special_tokens, gpt2_model.config.n_embd)
new_embedding.weight.data[:gpt2_model.config.vocab_size, :] = gpt2_model.transformer.wte.weight.data

# initialise every appended row with a noisy average of random pretrained rows
for i in range(num_special_tokens):
    new_embedding.weight.data[gpt2_model.config.vocab_size+i, :] = \
        random_interpolate(new_embedding.weight.data, gpt2_model.config.vocab_size)

# set tied: input embedding and output projection share one parameter
gpt2_model.transformer.wte = new_embedding
gpt2_model.lm_head.weight = gpt2_model.transformer.wte.weight

# redefine the mapping: position = RoBERTa token id, value = row of the enlarged table
r2g_mapping = []
# NOTE(review): 50257 / 50265 are presumably the GPT-2 and RoBERTa vocabulary
# sizes (their difference matches num_special_tokens) — confirm against the tokenizers.
special_start_count = 50257

for r_idx in range(50265):
    r_token = roberta_tokenizer.decoder[r_idx]

    if r_token in special_tokens:
        # RoBERTa-only token: hand out the next appended row
        r2g_mapping.append(special_start_count)
        special_start_count += 1
    else:
        # shared token: reuse its existing GPT-2 row
        g_idx = gpt_tokenizer.encoder[r_token]
        r2g_mapping.append(g_idx)

# (unused) inverse lookup, GPT-2 row -> RoBERTa id:
# r2g = {v:k for k,v in enumerate(r2g_mapping)}
# r2g_mapping = [r2g[i] for i in range(50265)]

# rearrange the weight (the live model is not modified; only this dict is)
gpt2_model_states = gpt2_model.state_dict()
gpt2_model_states = {k: v for k, v in gpt2_model_states.items() if '.attn.bias' not in k}
reordered_embedding = gpt2_model_states['transformer.wte.weight'][r2g_mapping]
gpt2_model_states['transformer.wte.weight'] = reordered_embedding
# The output projection is tied to the input embedding; keep the two entries
# identical so the state dict does not mix RoBERTa-ordered and GPT-2-ordered rows.
if 'lm_head.weight' in gpt2_model_states:
    gpt2_model_states['lm_head.weight'] = reordered_embedding

In [8]:
# Save the filtered, RoBERTa-ordered weights prepared in the conversion cell.
# gpt2_model.state_dict() still has the original GPT-2 row order and the
# '.attn.bias' buffers, because the reordering only exists in gpt2_model_states.
torch.save(gpt2_model_states, "unified_gpt2_xl.pth")

In [9]:
# Save a half-precision copy of the filtered, RoBERTa-ordered weights. Casting the
# tensors of gpt2_model_states (rather than calling gpt2_model.half()) persists
# the reordered rows and does not convert the live model in place.
torch.save(
    {
        name: (tensor.half() if tensor.is_floating_point() else tensor)
        for name, tensor in gpt2_model_states.items()
    },
    "unified_gpt2_xl_fp16.pth",
)

### Distill GPT-2

In [14]:
# Convert DistilGPT-2: enlarge the embedding table, then produce a state dict
# whose vocabulary rows follow RoBERTa token ids.

# load the model
gpt2_model = GPT2LMHeadModel.from_pretrained("distilgpt2")

# number of rows appended after the pretrained vocabulary
num_special_tokens = 8

# copy the original embedding into the leading rows of an enlarged table
new_embedding = nn.Embedding(gpt2_model.config.vocab_size + num_special_tokens, gpt2_model.config.n_embd)
new_embedding.weight.data[:gpt2_model.config.vocab_size, :] = gpt2_model.transformer.wte.weight.data

# initialise every appended row with a noisy average of random pretrained rows
for i in range(num_special_tokens):
    new_embedding.weight.data[gpt2_model.config.vocab_size+i, :] = \
        random_interpolate(new_embedding.weight.data, gpt2_model.config.vocab_size)

# set tied: input embedding and output projection share one parameter
gpt2_model.transformer.wte = new_embedding
gpt2_model.lm_head.weight = gpt2_model.transformer.wte.weight

# redefine the mapping: position = RoBERTa token id, value = row of the enlarged table
r2g_mapping = []
# NOTE(review): 50257 / 50265 are presumably the GPT-2 and RoBERTa vocabulary
# sizes (their difference matches num_special_tokens) — confirm against the tokenizers.
special_start_count = 50257

for r_idx in range(50265):
    r_token = roberta_tokenizer.decoder[r_idx]

    if r_token in special_tokens:
        # RoBERTa-only token: hand out the next appended row
        r2g_mapping.append(special_start_count)
        special_start_count += 1
    else:
        # shared token: reuse its existing GPT-2 row
        g_idx = gpt_tokenizer.encoder[r_token]
        r2g_mapping.append(g_idx)

# (unused) inverse lookup, GPT-2 row -> RoBERTa id:
# r2g = {v:k for k,v in enumerate(r2g_mapping)}
# r2g_mapping = [r2g[i] for i in range(50265)]

# rearrange the weight (the live model is not modified; only this dict is)
gpt2_model_states = gpt2_model.state_dict()
gpt2_model_states = {k: v for k, v in gpt2_model_states.items() if '.attn.bias' not in k}
reordered_embedding = gpt2_model_states['transformer.wte.weight'][r2g_mapping]
gpt2_model_states['transformer.wte.weight'] = reordered_embedding
# The output projection is tied to the input embedding; keep the two entries
# identical so the state dict does not mix RoBERTa-ordered and GPT-2-ordered rows.
if 'lm_head.weight' in gpt2_model_states:
    gpt2_model_states['lm_head.weight'] = reordered_embedding

100%|██████████| 574/574 [00:00<00:00, 378364.06B/s]
100%|██████████| 352833716/352833716 [00:30<00:00, 11435213.11B/s]


In [15]:
# Save the filtered, RoBERTa-ordered weights prepared in the conversion cell.
# gpt2_model.state_dict() still has the original GPT-2 row order and the
# '.attn.bias' buffers, because the reordering only exists in gpt2_model_states.
torch.save(gpt2_model_states, "unified_gpt2_distill.pth")

### Some random tests

In [17]:
# r2g_mapping is indexed by RoBERTa id: RoBERTa id 40 points at embedding row 481
# (see output), so index 40 here is NOT the GPT-2 token 40.
r2g_mapping[40]

481

In [18]:
# In the GPT-2 vocabulary, id 40 is the token 'I' (see output).
gpt_tokenizer.decoder[40]

'I'

In [19]:
# In the RoBERTa vocabulary the same token 'I' has id 100 (see output).
roberta_tokenizer.decoder[100]

'I'

In [20]:
# GPT-2 ids for a short sentence; per the output the first id is 40.
gpt_tokenizer.encode("I like cats.")

[40, 588, 11875, 13]

In [21]:
# save the final result
# torch.save(new_gpt2_small.state_dict(), "unified_gpt2_small.pth")

# Generation tests

In [None]:
# Smoke test: load the converted weights into torchfly's GPT2SimpleLM and sample
# a continuation, encoding and decoding with the RoBERTa tokenizer.
# NOTE(review): gpt2_model_states holds whichever conversion cell ran last; it
# must come from the Medium conversion to match the config used here.
new_gpt2 = GPT2SimpleLM(UnifiedGPT2MediumConfig)
new_gpt2.load_state_dict(gpt2_model_states, strict=False)

# Prefer the GPU, but fall back to the CPU so the cell also runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = new_gpt2
model = model.to(device)

# encode the prompt and move it to the model's device
prompt = roberta_tokenizer.encode("There is a cat. ")
prompt = torch.LongTensor(prompt).to(device)

# add a leading dimension of size 1 before feeding the prompt to the model
prev_word = prompt.unsqueeze(0)
past = None
sentence = []

model.eval()

for timestep in range(256):

    # no gradients are needed for generation
    with torch.no_grad():
        logits, past = model(prev_word, past=past)
    # keep only the prediction for the last position
    logits = logits[:, -1, :]
    logits = top_filtering(logits, top_p=0.9)
    probs = F.softmax(logits, dim=-1)

    # sample the next token; only this token is fed back on the next step
    prev_word = torch.multinomial(probs, num_samples=1)
    # greedy alternative:
    #prev_word = torch.argmax(logits, -1).unsqueeze(1)
    sentence.append(prev_word.item())

roberta_tokenizer.decode(sentence)

In [None]:
# Encode a sample sentence with the RoBERTa tokenizer.
roberta_tokenizer.encode("I like cats.")

In [None]:
# r2g_mapping is indexed by RoBERTa id and yields the GPT-2 embedding row, so
# feeding it GPT-2 ids looks up unrelated tokens. Translating GPT-2 ids into
# RoBERTa ids needs the inverse lookup.
g2r_mapping = {g_idx: r_idx for r_idx, g_idx in enumerate(r2g_mapping)}
[g2r_mapping[item] for item in gpt_tokenizer.encode("I like cats.")]

In [None]:
# Smoke test for the small model: load the enlarged (but not reordered) GPT-2
# small weights into torchfly's GPT2SimpleLM and decode greedily with the GPT-2
# tokenizer.
# `GPT2SmallConfig` is not imported anywhere in this notebook; the imported
# small configuration is UnifiedGPT2SmallConfig.
new_gpt2_small = GPT2SimpleLM(UnifiedGPT2SmallConfig)
new_gpt2_small.load_state_dict(gpt2_small.state_dict(), strict=False)

# Prefer the GPU, but fall back to the CPU so the cell also runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = new_gpt2_small
model = model.to(device)

# encode the prompt and move it to the model's device
prompt = gpt_tokenizer.encode("There is a cat")
prompt = torch.LongTensor(prompt).to(device)

# add a leading dimension of size 1 before feeding the prompt to the model
prev_word = prompt.unsqueeze(0)
past = None
sentence = []

model.eval()

for timestep in range(256):

    # no gradients are needed for generation; without this the cached `past`
    # would keep an autograd graph alive across all steps
    with torch.no_grad():
        logits, past = model(prev_word, past=past)
    # keep only the prediction for the last position
    logits = logits[:, -1, :]

    # sampling alternative:
    #logits = top_filtering(logits, top_p=0.9)
    #probs = F.softmax(logits, dim=-1)
    #prev_word = torch.multinomial(probs, num_samples=1)

    # greedy decoding: pick the highest-scoring token
    prev_word = torch.argmax(logits, -1).unsqueeze(1)
    sentence.append(prev_word.item())

gpt_tokenizer.decode(sentence)