llama3 8B support, tiktoken tokenizer #158

Merged 2 commits on Apr 29, 2024
6 changes: 3 additions & 3 deletions eval.py
@@ -18,7 +18,7 @@
torch._inductor.config.triton.cudagraphs = True
torch._dynamo.config.cache_size_limit = 100000

from sentencepiece import SentencePieceProcessor
from tokenizer import get_tokenizer

from model import Transformer

@@ -217,7 +217,7 @@ def main(
assert checkpoint_path.is_file(), checkpoint_path

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
assert tokenizer_path.is_file(), str(tokenizer_path)

device = 'cuda'
precision = torch.bfloat16
@@ -231,7 +231,7 @@ def main(

model.eval()

tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

torch.manual_seed(1234)

9 changes: 4 additions & 5 deletions generate.py
@@ -32,10 +32,8 @@ def device_sync(device):
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from sentencepiece import SentencePieceProcessor

from model import Transformer

from tokenizer import get_tokenizer

def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
q = torch.empty_like(probs_sort).exponential_(1)
@@ -269,7 +267,7 @@ def main(
assert checkpoint_path.is_file(), checkpoint_path

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
assert tokenizer_path.is_file(), str(tokenizer_path)

global print
from tp import maybe_init_dist
@@ -297,7 +295,8 @@ def main(
device_sync(device=device) # MKG
print(f"Time to load model: {time.time() - t0:.02f} seconds")

tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
prompt_length = encoded.size(0)

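The encode_tokens call above only relies on the encode() and bos_id() methods that both wrappers in the new tokenizer.py expose. A minimal sketch of such a helper written against that interface (an illustration of the contract, not necessarily the repo's exact implementation):

import torch

def encode_tokens(tokenizer, string, bos=True, device="cuda"):
    # Works with either SentencePieceWrapper or TiktokenWrapper,
    # since both implement encode() and bos_id().
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)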
2 changes: 1 addition & 1 deletion mixtral-moe/generate.py
@@ -175,7 +175,7 @@ def main(
assert checkpoint_path.is_file(), checkpoint_path

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
assert tokenizer_path.is_file(), str(tokenizer_path)

global print
rank = maybe_init_dist()
1 change: 1 addition & 0 deletions model.py
@@ -65,6 +65,7 @@ def from_name(cls, name: str):
"Mistral-7B": dict(n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
"stories15M": dict(n_layer=6, n_head=6, dim=288),
"stories110M": dict(n_layer=12, n_head=12, dim=768),
"Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256),
}

class KVCache(nn.Module):
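A quick sanity check of the new entry, assuming from_name in the hunk header above is the ModelArgs classmethod that looks model names up in this table (hypothetical usage, not part of the PR):

from model import ModelArgs

config = ModelArgs.from_name("Llama-3-8B")
# Values come straight from the new table entry.
print(config.block_size)   # 8192
print(config.n_layer)      # 32
print(config.vocab_size)   # 128256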
6 changes: 3 additions & 3 deletions quantize.py
@@ -9,7 +9,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from sentencepiece import SentencePieceProcessor
from tokenizer import get_tokenizer

try:
from GPTQ import GenericGPTQRunner, InputRecorder
@@ -578,8 +578,8 @@ def quantize(
quant_handler = WeightOnlyInt4GPTQQuantHandler(model, groupsize)

tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
assert tokenizer_path.is_file(), str(tokenizer_path)
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

quantized_state_dict = quant_handler.create_quantized_state_dict(
tokenizer,
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,2 +1,3 @@
torch
sentencepiece
tiktoken
133 changes: 86 additions & 47 deletions scripts/convert_hf_checkpoint.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
import json
import re
import shutil
import sys
from pathlib import Path
from typing import Optional
@@ -27,33 +27,62 @@ def convert_hf_checkpoint(
if model_name is None:
model_name = checkpoint_dir.name

# Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files
# need to be copied into model.pth.
# Llama 3 70B can't easily be merged into one model.pth file, though, since the names of the
# weights in the state dict are the same in each consolidated.NN.pth file. Thus, it is not
# currently supported.
# In addition, the original/tokenizer.model file needs to be copied into the checkpoint
# directory as tokenizer.model.
is_llama3 = "Llama-3" in model_name
if is_llama3:
# Check whether there are multiple original/consolidated.NN.pth files and raise an
# error if so, since merging them is not supported for Llama 3.
original_dir = checkpoint_dir / "original"
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)]
if len(bin_files) > 1:
raise ValueError(
f"Multiple consolidated.NN.pth files found in {original_dir}. "
"Merging them into one model.pth file is not supported for Llama 3.")


config = ModelArgs.from_name(model_name)
print(f"Model config {config.__dict__}")

# Load the json file containing weight mapping
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"

assert model_map_json.is_file()

with open(model_map_json) as json_map:
bin_index = json.load(json_map)

weight_map = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
if not is_llama3:
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"

assert model_map_json.is_file()

with open(model_map_json) as json_map:
bin_index = json.load(json_map)

weight_map = {
"model.embed_tokens.weight": "tok_embeddings.weight",
"model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight",
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
'model.layers.{}.self_attn.rotary_emb.inv_freq': None,
'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight',
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
"lm_head.weight": "output.weight",
}
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
else:
# There is no separate pytorch_model.bin.index.json file for Llama 3;
# instead, we just use all of the original/consolidated.NN.pth files directly.
weight_map = None
original_dir = checkpoint_dir / "original"
pattern = re.compile(r"^consolidated\.\d{2}\.pth$")
bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)}


def permute(w, n_head):
dim = config.dim
@@ -68,32 +68,41 @@ def permute(w, n_head):
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
merged_result.update(state_dict)
final_result = {}
for key, value in merged_result.items():
if "layers" in key:
abstract_key = re.sub(r'(\d+)', '{}', key)
layer_num = re.search(r'\d+', key).group(0)
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = weight_map[key]

final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_head)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
if weight_map is not None:
for key, value in merged_result.items():
if "layers" in key:
abstract_key = re.sub(r'(\d+)', '{}', key)
layer_num = re.search(r'\d+', key).group(0)
new_key = weight_map[abstract_key]
if new_key is None:
continue
new_key = new_key.format(layer_num)
else:
new_key = weight_map[key]

final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq" in key:
q = final_result[key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_head)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
else:
final_result = merged_result
print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
torch.save(final_result, checkpoint_dir / "model.pth")
if is_llama3:
original_dir = checkpoint_dir / "original"
tokenizer_model = original_dir / "tokenizer.model"
tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model"
print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}")
shutil.copy(tokenizer_model, tokenizer_model_tiktoken)

if __name__ == '__main__':
import argparse
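With these changes, converting a downloaded Llama 3 checkpoint can be driven through the argparse entry point elided above or by calling the function directly. A sketch of the latter, assuming the checkpoint_dir/model_name keyword arguments of the upstream script, that scripts/ is importable from the repo root, and a hypothetical download location:

from pathlib import Path
from scripts.convert_hf_checkpoint import convert_hf_checkpoint

# For Llama 3 this loads original/consolidated.00.pth, re-saves it as model.pth,
# and copies original/tokenizer.model into the checkpoint directory.
convert_hf_checkpoint(
    checkpoint_dir=Path("checkpoints/meta-llama/Meta-Llama-3-8B"),
    model_name="Llama-3-8B",
)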
111 changes: 111 additions & 0 deletions tokenizer.py
@@ -0,0 +1,111 @@
import os
import sentencepiece as spm
import tiktoken
from tiktoken.load import load_tiktoken_bpe
from pathlib import Path
from typing import Dict

class TokenizerInterface:
def __init__(self, model_path):
self.model_path = model_path

def encode(self, text):
raise NotImplementedError("This method should be overridden by subclasses.")

def decode(self, tokens):
raise NotImplementedError("This method should be overridden by subclasses.")

def bos_id(self):
raise NotImplementedError("This method should be overridden by subclasses.")

def eos_id(self):
raise NotImplementedError("This method should be overridden by subclasses.")

class SentencePieceWrapper(TokenizerInterface):
def __init__(self, model_path):
super().__init__(model_path)
self.processor = spm.SentencePieceProcessor(str(model_path))

def encode(self, text):
return self.processor.EncodeAsIds(text)

def decode(self, tokens):
return self.processor.DecodeIds(tokens)

def bos_id(self):
return self.processor.bos_id()

def eos_id(self):
return self.processor.eos_id()

class TiktokenWrapper(TokenizerInterface):
"""
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
"""

special_tokens: Dict[str, int]

num_reserved_special_tokens = 256

pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501

def __init__(self, model_path):
super().__init__(model_path)
assert os.path.isfile(model_path), str(model_path)
mergeable_ranks = load_tiktoken_bpe(str(model_path))
num_base_tokens = len(mergeable_ranks)
special_tokens = [
"<|begin_of_text|>",
"<|end_of_text|>",
"<|reserved_special_token_0|>",
"<|reserved_special_token_1|>",
"<|reserved_special_token_2|>",
"<|reserved_special_token_3|>",
"<|start_header_id|>",
"<|end_header_id|>",
"<|reserved_special_token_4|>",
"<|eot_id|>", # end of turn
] + [
f"<|reserved_special_token_{i}|>"
for i in range(5, self.num_reserved_special_tokens - 5)
]
self.special_tokens = {
token: num_base_tokens + i for i, token in enumerate(special_tokens)
}
self.model = tiktoken.Encoding(
name=Path(model_path).name,
pat_str=self.pat_str,
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens,
)
# BOS / EOS token IDs
self._bos_id: int = self.special_tokens["<|begin_of_text|>"]
self._eos_id: int = self.special_tokens["<|end_of_text|>"]

def encode(self, text):
return self.model.encode(text)

def decode(self, tokens):
return self.model.decode(tokens)

def bos_id(self):
return self._bos_id

def eos_id(self):
return self._eos_id

def get_tokenizer(tokenizer_model_path, model_name):
"""
Factory function to get the appropriate tokenizer based on the model name.

Args:
- tokenizer_model_path (str): The file path to the tokenizer model.
- model_name (str): The name of the model, used to determine the tokenizer type.

Returns:
- TokenizerInterface: An instance of a tokenizer.
"""
if "Llama-3" in str(model_name):
return TiktokenWrapper(tokenizer_model_path)
else:
return SentencePieceWrapper(tokenizer_model_path)
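A short usage sketch mirroring the eval.py/generate.py call sites (the checkpoint path below is hypothetical):

from pathlib import Path
from tokenizer import get_tokenizer

# "Llama-3" in the second argument selects TiktokenWrapper;
# anything else falls back to SentencePieceWrapper.
checkpoint_path = Path("checkpoints/meta-llama/Meta-Llama-3-8B/model.pth")
tokenizer_path = checkpoint_path.parent / "tokenizer.model"

tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)
ids = tokenizer.encode("Hello, Llama 3!")
print(ids)
print(tokenizer.decode(ids))
print(tokenizer.bos_id(), tokenizer.eos_id())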