In [None]:
import copy
import json
import numbers
import os
import shutil
from pathlib import Path

import pandas as pd
import torch
from torch import nn
from tqdm import tqdm
from transformers.models.bloom import BloomForCausalLM, BloomTokenizerFast

from pruner.process import *

In [None]:
# Local copy of the "bigscience/bloomz-560m" checkpoint.
model_name_or_path = "D:/dddd/bloom_560m"

# The model and the tokenizer load independently from the same directory.
model_old = BloomForCausalLM.from_pretrained(model_name_or_path)
tokenizer_old = BloomTokenizerFast.from_pretrained(model_name_or_path, use_fast=True)


In [None]:
# Shape of the original input-embedding matrix: (num_embeddings, embedding_dim).
model_old.transformer.word_embeddings.weight.shape

In [None]:
# Tokenize a small slice of the WuDao corpus; its token ids are counted below to
# decide which vocabulary rows survive pruning.
# NOTE(review): argument meanings (file count, worker processes) are inferred from
# their names — confirm against pruner.process.make_train_dataset.
small_dataset = make_train_dataset(
    tokenizer=tokenizer_old,
    data_path="D:\\数据\\WuDaoCorpus2.0_base_200G",
    data_file_number=3,
    data_proc_num=1,
    use_streaming=False,
)

In [None]:
# Number of shards needed so that each holds about `batch_size` rows; the +1 rounds
# up (and adds one spare shard when num_rows divides evenly).
batch_size = 5000
target_num_shards = 1 + small_dataset.num_rows // batch_size


In [None]:
def shard2pandas(index: int, dataset=None, num_shards=None) -> pd.DataFrame:
    """Count how often each token id occurs in one shard of a tokenized dataset.

    Args:
        index: which shard to take (0-based).
        dataset: object exposing ``shard(num_shards, index=...)`` whose result has
            ``to_pandas()`` with an ``input_ids`` column of id sequences. Defaults to
            the notebook-level ``small_dataset`` (looked up at call time).
        num_shards: total number of shards. Defaults to the notebook-level
            ``target_num_shards`` (looked up at call time).

    Returns:
        DataFrame with columns ``input_ids`` and ``value`` (occurrence count), sorted
        by ``input_ids`` and with a fresh 0..n-1 index. Empty sequences contribute nothing.
    """
    if dataset is None:
        dataset = small_dataset
    if num_shards is None:
        num_shards = target_num_shards

    shard_frame = dataset.shard(num_shards, index=index).to_pandas()
    # One row per token occurrence; empty sequences explode to NaN, which groupby drops.
    token_ids = shard_frame['input_ids'].explode()
    return (
        token_ids.to_frame('input_ids')
        .groupby('input_ids')
        .size()
        .rename('value')
        .reset_index()
    )

# Count token ids shard by shard (the explode step only ever holds one shard),
# then add up the partial counts per id.
index_dist = (
    pd.concat(shard2pandas(index=shard_idx) for shard_idx in tqdm(range(target_num_shards)))
    .groupby(['input_ids'])
    .agg(value=('value', 'sum'))
    .reset_index(drop=False)
)
index_dist.shape


In [None]:
# Occurrence count per observed token id (columns: input_ids, value), sorted by id.
index_dist

In [None]:
# Original model size in units of 1e8 parameters.
model_old.num_parameters()/1e8

In [None]:
# Pack the observed ids (already sorted ascending by the groupby) into consecutive
# new ids that start at the smallest observed id; ids below it keep their value.
first_kept_id = min(index_dist['input_ids'])
map_index_df = index_dist.assign(new_id=index_dist.index + first_kept_id)
map_index_df

In [None]:
# Duplicate display of map_index_df (identical to the previous cell's output).
map_index_df

In [None]:
# Old token ids observed in the corpus sample.
save_ids = map_index_df['input_ids'].tolist()

# Membership against a set keeps the filter O(|vocab|); testing `v in save_ids` on a
# list cost O(|vocab| * |save_ids|) comparisons (billions for a ~250k vocabulary).
save_id_set = {int(i) for i in save_ids}

# Keep only vocabulary entries (token -> old id) whose id was observed.
full_vocab = tokenizer_old.get_vocab()
vocab_old = {token: idx for token, idx in tqdm(full_vocab.items()) if idx in save_id_set}
len(vocab_old)


In [None]:
# Old id -> new id, built with plain Python ints. Numpy/pandas integer scalars are
# not JSON-serializable, and `default=str` would turn them into strings on export.
oi2ni = {
    int(old_id): int(new_id)
    for old_id, new_id in zip(map_index_df['input_ids'], map_index_df['new_id'])
}

# Every id in vocab_old is an observed id, so a missing key is a bug. Index strictly
# (KeyError) instead of .get(), which would silently store None (JSON null).
vocab_new = {token: oi2ni[old_id] for token, old_id in tqdm(vocab_old.items())}
len(vocab_new)

In [None]:
# Normalize both keys and values of the id map to built-in ints (idempotent).
oi2ni = dict(zip(map(int, oi2ni.keys()), map(int, oi2ni.values())))

In [None]:
import json

# Persist the old-id -> new-id map (JSON object keys become strings).
with open("map_index.json", mode='w', encoding='utf-8') as fout:
    json.dump(oi2ni, fout, default=str, ensure_ascii=False)

In [None]:
import torch

# Sanity check: look up the embedding vector of token id 0.
model_old.transformer.word_embeddings(torch.tensor([0], dtype=torch.long))

In [None]:
# Detached copy of the original input-embedding matrix (row i = token id i).
weight_old = model_old.transformer.word_embeddings.weight.detach().clone()
weight_old

In [None]:
# Boolean row mask over the original vocabulary (one entry per row of weight_old).
# A row is kept when its id is below the smallest observed id (never remapped;
# presumably special/control tokens) or when the id was observed in the sample.
save_id_list_ = map_index_df['input_ids'].tolist()
min_save_id_list_ = min(save_id_list_)
# Set lookup keeps the scan O(vocab); the old `i in save_id_list_` list lookup cost
# O(vocab * len(save_id_list_)) comparisons.
save_id_set_ = {int(i) for i in save_id_list_}

mask_weight = [
    i < min_save_id_list_ or i in save_id_set_
    for i in tqdm(range(weight_old.shape[0]))
]

# The kept rows must line up with new ids 0 .. min_save_id_list_ + len(save_id_set_) - 1.
assert sum(mask_weight) == min_save_id_list_ + len(save_id_set_)


In [None]:
# Keep only the embedding rows selected by the mask (boolean indexing copies them).
# Note: `weight_new.requires_grad_ = True` would only set a plain attribute; the
# in-place method call is `requires_grad_(True)`.
kept_embedding_rows = weight_old[mask_weight]
weight_new = kept_embedding_rows.data
weight_new.shape

In [31]:
# Old-id -> new-id mapping.
oi2ni
# NOTE(review): "1872, 976" looks like an old/new id pair noted while exploring — meaning unclear; confirm or remove.

{5: 5,
 6: 6,
 7: 7,
 8: 8,
 9: 9,
 10: 10,
 11: 11,
 12: 12,
 13: 13,
 14: 14,
 15: 15,
 16: 16,
 17: 17,
 18: 18,
 19: 19,
 20: 20,
 21: 21,
 22: 22,
 23: 23,
 24: 24,
 25: 25,
 26: 26,
 27: 27,
 28: 28,
 29: 29,
 32: 30,
 35: 31,
 62: 32,
 64: 33,
 65: 34,
 66: 35,
 67: 36,
 68: 37,
 69: 38,
 70: 39,
 71: 40,
 72: 41,
 73: 42,
 74: 43,
 75: 44,
 76: 45,
 77: 46,
 78: 47,
 79: 48,
 80: 49,
 81: 50,
 82: 51,
 83: 52,
 84: 53,
 85: 54,
 86: 55,
 87: 56,
 88: 57,
 89: 58,
 90: 59,
 91: 60,
 92: 61,
 93: 62,
 94: 63,
 95: 64,
 96: 65,
 97: 66,
 98: 67,
 99: 68,
 100: 69,
 101: 70,
 102: 71,
 103: 72,
 104: 73,
 105: 74,
 106: 75,
 107: 76,
 108: 77,
 109: 78,
 110: 79,
 111: 80,
 112: 81,
 113: 82,
 114: 83,
 115: 84,
 116: 85,
 117: 86,
 118: 87,
 119: 88,
 120: 89,
 121: 90,
 122: 91,
 123: 92,
 124: 93,
 125: 94,
 126: 95,
 127: 96,
 128: 97,
 129: 98,
 133: 99,
 137: 100,
 138: 101,
 140: 102,
 144: 103,
 148: 104,
 149: 105,
 150: 106,
 153: 107,
 155: 108,
 159: 109,
 160: 110,
 16

In [32]:
# Spot check: a row of the old embedding next to a row of the pruned one (the recorded
# output shows them equal). NOTE(review): presumably oi2ni[1917] == 1000 — verify via oi2ni.
weight_old[1917, :], weight_new[1000, :]

(tensor([ 0.0088,  0.0072, -0.0185,  ..., -0.0422,  0.0167, -0.0020]),
 tensor([ 0.0088,  0.0072, -0.0185,  ..., -0.0422,  0.0167, -0.0020]))

In [33]:
# The output head's weight rows are indexed by token id too, so the same row mask applies.
lm_head_weight_old = model_old.lm_head.weight.data
lm_head_weight_new = lm_head_weight_old[mask_weight]
lm_head_weight_new.shape

torch.Size([52922, 1024])

In [35]:
from torch import nn

In [None]:
# Scratch check of nn.Linear's weight layout: (out_features, in_features), here (6, 3).
nn.Linear(3, 6, bias=False).weight.shape

In [38]:
# Bias-free output head sized to the pruned vocabulary; nn.Linear stores its weight as
# (out_features, in_features), so the pruned rows drop straight in.
out_features_new, in_features_new = lm_head_weight_new.shape
lm_head_new = nn.Linear(in_features=in_features_new, out_features=out_features_new, bias=False)
lm_head_new.weight.data = lm_head_weight_new  # overwrite the random init
lm_head_new.weight.data.shape

torch.Size([52922, 1024])

In [39]:
# Embedding built from the pruned rows. nn.Embedding.from_pretrained freezes the
# weight by default (freeze=True); keep it trainable like the original embedding.
new_embedding = nn.Embedding.from_pretrained(weight_new, freeze=False)

# Work on a copy. Plain assignment (model_new = model_old) aliases the two names, so
# every later look at model_old would silently show the pruned model.
# NOTE: deepcopy doubles the model's memory footprint while both are alive.
model_new = copy.deepcopy(model_old)
model_new.transformer.word_embeddings = new_embedding
model_new.lm_head = lm_head_new
# NOTE(review): the two new modules hold separate (equal-valued) tensors. If the config
# ties input/output embeddings, reloading via from_pretrained re-ties them — confirm.
model_new.transformer.word_embeddings



In [40]:
# Keep the config in sync with the pruned embedding so the saved model reloads with matching shapes.
pruned_vocab_size = weight_new.shape[0]
model_new.config.vocab_size = pruned_vocab_size
model_new.config.vocab_size

52922

In [41]:
# Write the pruned model's weights and config to test_model/.
model_new.save_pretrained(save_directory="test_model")

In [None]:
# Reload the saved pruned model to check that it round-trips.
model_test = BloomForCausalLM.from_pretrained("test_model")

In [None]:
# Row 1000 of the reloaded model's output head.
model_test.lm_head.weight[1000, :]#.shape

In [None]:
# Row 1000 of model_old's output head, for comparison with the reloaded model.
# NOTE(review): if model_new was created by plain assignment from model_old, this is the
# pruned head as well, so the comparison is vacuous.
model_old.lm_head.weight[1000,:]#.shape

In [None]:
# Load the raw saved state dict. torch.load unpickles, so use it only on trusted files.
# NOTE(review): recent transformers versions save model.safetensors by default, in which
# case pytorch_model.bin may not exist — verify against the library version in use.
test_ = torch.load("test_model/pytorch_model.bin")

In [None]:
# Parameter names stored in the checkpoint.
test_.keys()

In [None]:
# The saved input-embedding matrix should have the pruned number of rows.
test_['transformer.word_embeddings.weight'].shape#,test_['transformer.word_embeddings.bias'].shape

In [None]:
# Likewise for the saved output head, if it was stored separately.
test_['lm_head.weight'].shape

In [None]:
# Original tokenizer's vocabulary size, for comparison with the pruned model.
tokenizer_old.vocab_size

In [None]:
# Embedding module currently attached to model_old.
model_old.transformer.word_embeddings

In [None]:
# Special tokens of the original tokenizer, by role.
tokenizer_old.special_tokens_map

In [None]:
# Special-token ids. NOTE(review): the pruning mask keeps every id below the smallest
# observed id, which presumably covers these — verify.
tokenizer_old.all_special_ids

In [None]:
# Special-token strings matching the ids above.
tokenizer_old.all_special_tokens

In [None]:
tokenizer_old.save_pretrained("test_tokenizer")

# Read back the fast-tokenizer definition that was just written. It lives in
# test_tokenizer/ (model_new.save_pretrained writes no tokenizer files to test_model/).
# The file is UTF-8 JSON, so set the encoding explicitly: the platform default on
# Windows (e.g. GBK/cp936) cannot decode it.
with open("test_tokenizer/tokenizer.json", mode="r", encoding="utf-8") as fin:
    vocab_token_old = json.load(fin)

# Size of the original BPE vocabulary (displaying the full dict would dump ~250k entries).
len(vocab_token_old['model']['vocab'])

In [None]:
min_vocab_start = map_index_df['input_ids'].min()

# Tokens whose ids are at or below the smallest observed id keep their original ids.
vocab_new_p1 = {
    token: idx
    for token, idx in vocab_token_old['model']['vocab'].items()
    if idx <= min_vocab_start
}
vocab_new_p1

In [None]:
# Combine the remapped observed tokens with the untouched low-id tokens (later keys win).
vocab_new = {**vocab_new, **vocab_new_p1}
vocab_new

In [None]:
def _merge_in_vocab(merge, vocab, subword_prefix: str = "") -> bool:
    """Return True when both parts of a BPE merge rule and the token it produces are in `vocab`.

    `merge` is either a "left right" string (older tokenizer.json) or a [left, right]
    pair (newer tokenizer.json). Assumes tokens contain no literal space, which holds
    for byte-level BPE vocabularies — confirm for other tokenizers.
    """
    left, right = merge.split(" ", 1) if isinstance(merge, str) else merge
    merged = left + right[len(subword_prefix):]
    return left in vocab and right in vocab and merged in vocab


bpe_model = vocab_token_old['model']
bpe_model['vocab'] = vocab_new
# The `tokenizers` BPE model refuses to load when a merge's parts or result are missing
# from the vocab, so the merge table has to be pruned together with the vocabulary.
bpe_model['merges'] = [
    merge for merge in bpe_model['merges']
    if _merge_in_vocab(merge, vocab_new, bpe_model.get('continuing_subword_prefix') or "")
]
len(bpe_model['vocab']), len(bpe_model['merges'])

In [None]:
def _json_default(obj):
    """json.dumps fallback: integer scalars (e.g. numpy int64) stay numeric; anything else becomes str."""
    # A token id written as "123" instead of 123 would produce an unloadable tokenizer.json.
    if isinstance(obj, numbers.Integral):
        return int(obj)
    return str(obj)


with open(file="test_tokenizer/tokenizer.json", mode='w', encoding='utf-8') as fout:
    fout.write(json.dumps(vocab_token_old, default=_json_default, ensure_ascii=False))

In [None]:
# Spot check that a known token survived pruning (raises KeyError if it was dropped).
vocab_new['anda']

In [None]:
# JSON rendering of the whole new vocabulary (very large output).
json.dumps(vocab_new, default=str)

In [None]:
# Reload the pruned tokenizer to check that it deserializes. Its files were written to
# test_tokenizer/ (save_pretrained plus the rewritten tokenizer.json); test_model/ only
# holds the model weights and config.
tokenizer_new = BloomTokenizerFast.from_pretrained("test_tokenizer")
tokenizer_new