Skip to content

Commit

Permalink
feat: add new GGUFValueType.OBJ virtual type
Browse files Browse the repository at this point in the history
The content of the OBJ type is actually a list of all key names of the object. This change includes several improvements and additions to the codebase:

* GGUFWriter:
  * Added `def add_kv(self, key: str, val: Any) -> None` method: Automatically determines the appropriate value type based on val.
  * Added `def add_dict(self, key: str, val: dict) -> None` method: add object(dict) key-value
* constants:
  * Revised `GGUFValueType.get_type(val)`: Added support for numpy's integers and floating point numbers, and appropriately selected the number of digits according to the size of the integer.
* gguf_reader
  * Added `ReaderField.get()` method: get the value of this ReaderField
* Unit tests have been added to cover these changes.

Related Issues: ggerganov#4868, ggerganov#2872
  • Loading branch information
snowyu committed Jan 26, 2024
1 parent 8d674b2 commit 4ca6cd6
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 6 deletions.
44 changes: 43 additions & 1 deletion gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from enum import Enum, IntEnum, auto
from typing import Any
import numpy as np

#
# constants
Expand Down Expand Up @@ -510,19 +511,60 @@ class GGUFValueType(IntEnum):
UINT64 = 10
INT64 = 11
FLOAT64 = 12
OBJ = 13

@staticmethod
def get_type(val: Any) -> GGUFValueType:
if isinstance(val, (str, bytes, bytearray)):
return GGUFValueType.STRING
elif isinstance(val, list):
return GGUFValueType.ARRAY
elif isinstance(val, np.float32):
return GGUFValueType.FLOAT32
elif isinstance(val, np.float64):
return GGUFValueType.FLOAT64
elif isinstance(val, float):
return GGUFValueType.FLOAT32
elif isinstance(val, bool):
return GGUFValueType.BOOL
elif isinstance(val, int):
elif isinstance(val, np.uint8):
return GGUFValueType.UINT8
elif isinstance(val, np.uint16):
return GGUFValueType.UINT16
elif isinstance(val, np.uint32):
return GGUFValueType.UINT32
elif isinstance(val, np.uint64):
return GGUFValueType.UINT64
elif isinstance(val, np.int8):
return GGUFValueType.INT8
elif isinstance(val, np.int16):
return GGUFValueType.INT16
elif isinstance(val, np.int32):
return GGUFValueType.INT32
elif isinstance(val, np.int64):
return GGUFValueType.INT64
elif isinstance(val, int):
if val >=0 and val <= np.iinfo(np.uint8).max:
return GGUFValueType.UINT8
elif val >=0 and val <= np.iinfo(np.uint16).max:
return GGUFValueType.UINT16
elif val >=0 and val <= np.iinfo(np.uint32).max:
return GGUFValueType.UINT32
elif val >=0 and val <= np.iinfo(np.uint64).max:
return GGUFValueType.UINT64
elif val >=np.iinfo(np.int8).min and val <= np.iinfo(np.int8).max:
return GGUFValueType.INT8
elif val >=np.iinfo(np.int16).min and val <= np.iinfo(np.int16).max:
return GGUFValueType.INT16
elif val >=np.iinfo(np.int32).min and val <= np.iinfo(np.int32).max:
return GGUFValueType.INT32
elif val >=np.iinfo(np.int64).min and val <= np.iinfo(np.int64).max:
return GGUFValueType.INT64
else:
print("The integer exceed limit:", val)
sys.exit()
elif isinstance(val, dict):
return GGUFValueType.OBJ
# TODO: need help with 64-bit types in Python
else:
print("Unknown type:", type(val))
Expand Down
18 changes: 17 additions & 1 deletion gguf-py/gguf/gguf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ class ReaderField(NamedTuple):

types: list[GGUFValueType] = []

def get(self):
result = None
itype = self.types[0]
if itype == GGUFValueType.ARRAY or itype == GGUFValueType.OBJ:
itype = self.types[-1]
if itype == GGUFValueType.STRING:
result = [str(bytes(self.parts[idx]), encoding="utf-8") for idx in self.data]
else:
result = [pv for idx in self.data for pv in self.parts[idx].tolist()]
elif itype == GGUFValueType.STRING:
result = str(bytes(self.parts[-1]), encoding="utf-8")
else:
result = self.parts[-1].tolist()[0]

return result


class ReaderTensor(NamedTuple):
name: str
Expand Down Expand Up @@ -165,7 +181,7 @@ def _get_field_parts(
val = self._get(offs, nptype)
return int(val.nbytes), [val], [0], types
# Handle arrays.
if gtype == GGUFValueType.ARRAY:
if gtype == GGUFValueType.ARRAY or gtype == GGUFValueType.OBJ:
raw_itype = self._get(offs, np.uint32)
offs += int(raw_itype.nbytes)
alen = self._get(offs, np.uint64)
Expand Down
40 changes: 40 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,44 @@ def add_array(self, key: str, val: Sequence[Any]) -> None:
self.add_key(key)
self.add_val(val, GGUFValueType.ARRAY)

def add_kv(self, key: str, val: Any) -> None:
vtype=GGUFValueType.get_type(val)
if vtype == GGUFValueType.OBJ:
self.add_dict(key, val)
elif vtype == GGUFValueType.ARRAY:
self.add_array(key, val)
elif vtype == GGUFValueType.STRING:
self.add_string(key, val)
elif vtype == GGUFValueType.BOOL:
self.add_bool(key, val)
elif vtype == GGUFValueType.INT64:
self.add_int64(key, val)
elif vtype == GGUFValueType.FLOAT64:
self.add_float64(key, val)
elif vtype == GGUFValueType.INT32:
self.add_int32(key, val)
elif vtype == GGUFValueType.FLOAT32:
self.add_float32(key, val)
elif vtype == GGUFValueType.UINT64:
self.add_uint64(key, val)
elif vtype == GGUFValueType.UINT32:
self.add_uint32(key, val)
elif vtype == GGUFValueType.UINT16:
self.add_uint16(key, val)
elif vtype == GGUFValueType.UINT8:
self.add_uint8(key, val)
else:
raise ValueError(f"Unsupported type: {type(val)}")

def add_dict(self, key: str, val: dict) -> None:
if not isinstance(val, dict):
raise ValueError("Value must be a dict type")

self.add_key(key)
self.add_val(val, GGUFValueType.OBJ)
for k, v in val.items():
self.add_kv(key + "." + k, v)

def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None:
if vtype is None:
vtype = GGUFValueType.get_type(val)
Expand All @@ -181,6 +219,8 @@ def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool
self.kv_data += self._pack("Q", len(val))
for item in val:
self.add_val(item, add_vtype=False)
elif vtype == GGUFValueType.OBJ and isinstance(val, dict) and val:
self.add_val(list(val.keys()), GGUFValueType.ARRAY, False)
else:
raise ValueError("Invalid GGUF metadata value type or value")

Expand Down
53 changes: 49 additions & 4 deletions gguf-py/tests/test_gguf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,52 @@
import gguf # noqa: F401
import sys
from pathlib import Path
import numpy as np
import unittest

# TODO: add tests
# Necessary to load the local gguf package
sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf import GGUFWriter, GGUFReader

def test_write_gguf() -> None:
pass
class TestGGUFReaderWriter(unittest.TestCase):

def test_rw(self) -> None:
# Example usage with a file
gguf_writer = GGUFWriter("test_writer.gguf", "llama")

# gguf_writer.add_architecture()
gguf_writer.add_block_count(12)
gguf_writer.add_uint32("answer", 42) # Write a 32-bit integer
gguf_writer.add_float32("answer_in_float", 42.0) # Write a 32-bit float
gguf_writer.add_dict("dict1", {"key1": 2, "key2": "hi", "obj": {"k": 1}})
gguf_writer.add_custom_alignment(64)

tensor1 = np.ones((32,), dtype=np.float32) * 100.0
tensor2 = np.ones((64,), dtype=np.float32) * 101.0
tensor3 = np.ones((96,), dtype=np.float32) * 102.0

gguf_writer.add_tensor("tensor1", tensor1)
gguf_writer.add_tensor("tensor2", tensor2)
gguf_writer.add_tensor("tensor3", tensor3)

gguf_writer.write_header_to_file()
gguf_writer.write_kv_data_to_file()
gguf_writer.write_tensors_to_file()

gguf_writer.close()

gguf_reader = GGUFReader("test_writer.gguf")
self.assertEqual(gguf_reader.alignment, 64)
v = gguf_reader.get_field("dict1")
self.assertIsNotNone(v)
self.assertListEqual(v.get(), ['key1', 'key2', 'obj'])
v = gguf_reader.get_field("dict1.key1")
self.assertEqual(v.get(), 2)
v = gguf_reader.get_field("dict1.key2")
self.assertEqual(v.get(), "hi")
v = gguf_reader.get_field("dict1.obj")
self.assertListEqual(v.get(), ['k'])


if __name__ == '__main__':
unittest.main()

0 comments on commit 4ca6cd6

Please sign in to comment.