Skip to content

Commit

Permalink
feat: add new GGUFValueType.OBJ virtual type
Browse files Browse the repository at this point in the history
The content of the OBJ type is actually a list of all key names of the object.

* GGUFWriter:
  * add `def add_kv(self, key: str, val: Any) -> None`: add a key-value pair, choosing the GGUF value type from the type of `val`
  * add `def add_dict(self, key: str, val: dict) -> None`: add object(dict) value
* constants:
  * `GGUFValueType.get_type`: add support for NumPy integer and floating-point scalars, selecting the appropriate bit width based on the size of the integer.
* gguf_reader:
  * add `ReaderField.get`: to return the value of the field
* Unit test added.

Related Issues: ggerganov#4868, ggerganov#2872
  • Loading branch information
snowyu committed Jan 26, 2024
1 parent c25052b commit d1bfb58
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 4 deletions.
32 changes: 32 additions & 0 deletions gguf-py/tests/test_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import sys
from pathlib import Path
import numpy as np
import unittest

# Necessary to load the local gguf package
sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf.constants import GGUFValueType

class TestGGUFValueType(unittest.TestCase):
    """Checks how ``GGUFValueType.get_type`` classifies Python and NumPy values."""

    def test_get_type(self):
        """Every sample value must map to the expected ``GGUFValueType`` member.

        Each case runs in its own ``subTest`` so that one wrong mapping does
        not hide the results of the remaining cases.
        """
        # (sample value, expected GGUF value type)
        cases = [
            ("test", GGUFValueType.STRING),
            ([1, 2, 3], GGUFValueType.ARRAY),
            (1.0, GGUFValueType.FLOAT32),
            (True, GGUFValueType.BOOL),
            (b"test", GGUFValueType.STRING),
            (np.uint8(1), GGUFValueType.UINT8),
            (np.uint16(1), GGUFValueType.UINT16),
            (np.uint32(1), GGUFValueType.UINT32),
            (np.uint64(1), GGUFValueType.UINT64),
            (np.int8(-1), GGUFValueType.INT8),
            (np.int16(-1), GGUFValueType.INT16),
            (np.int32(-1), GGUFValueType.INT32),
            (np.int64(-1), GGUFValueType.INT64),
            (np.float32(1.0), GGUFValueType.FLOAT32),
            (np.float64(1.0), GGUFValueType.FLOAT64),
            ({"k": 12}, GGUFValueType.OBJ),
        ]
        for value, expected in cases:
            # The type name is part of the sub-test label because several
            # samples (e.g. 1.0 and np.float64(1.0)) can print identically.
            with self.subTest(value=value, type=type(value).__name__):
                self.assertEqual(GGUFValueType.get_type(value), expected)

if __name__ == '__main__':
unittest.main()
53 changes: 49 additions & 4 deletions gguf-py/tests/test_gguf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,52 @@
import gguf # noqa: F401
import sys
from pathlib import Path
import numpy as np
import unittest

# TODO: add tests
# Necessary to load the local gguf package
sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf import GGUFWriter, GGUFReader

def test_write_gguf() -> None:
    """Placeholder test that performs no checks."""
class TestGGUFReaderWriter(unittest.TestCase):
    """Round-trip test: write a file with ``GGUFWriter``, read it back with ``GGUFReader``."""

    def _field_value(self, reader, key):
        """Return the value of field *key*, failing the test if the field is missing."""
        field = reader.get_field(key)
        self.assertIsNotNone(field, f"field {key!r} not found in the GGUF file")
        return field.get()

    def test_rw(self) -> None:
        """Write key-value metadata (including a nested dict) and tensors, then verify the reader's view."""
        # Imported locally so this test does not rely on additional
        # module-level imports.
        import shutil
        import tempfile

        # Write into a private temporary directory instead of the current
        # working directory, so the test leaves no artifact behind and does
        # not collide with concurrent runs. ignore_errors covers platforms
        # that refuse to delete a file that is still mapped by the reader.
        tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmp_dir, ignore_errors=True)
        gguf_path = str(Path(tmp_dir) / "test_writer.gguf")

        gguf_writer = GGUFWriter(gguf_path, "llama")
        try:
            gguf_writer.add_block_count(12)
            gguf_writer.add_uint32("answer", 42)  # Write a 32-bit integer
            gguf_writer.add_float32("answer_in_float", 42.0)  # Write a 32-bit float
            gguf_writer.add_dict("dict1", {"key1": 2, "key2": "hi", "obj": {"k": 1}})
            gguf_writer.add_custom_alignment(64)

            tensor1 = np.ones((32,), dtype=np.float32) * 100.0
            tensor2 = np.ones((64,), dtype=np.float32) * 101.0
            tensor3 = np.ones((96,), dtype=np.float32) * 102.0

            gguf_writer.add_tensor("tensor1", tensor1)
            gguf_writer.add_tensor("tensor2", tensor2)
            gguf_writer.add_tensor("tensor3", tensor3)

            gguf_writer.write_header_to_file()
            gguf_writer.write_kv_data_to_file()
            gguf_writer.write_tensors_to_file()
        finally:
            # Close the output file even when one of the write steps raises.
            gguf_writer.close()

        gguf_reader = GGUFReader(gguf_path)
        self.assertEqual(gguf_reader.alignment, 64)

        # A dict field reads back as the list of its key names; its members
        # are read back under "<dict key>.<member key>".
        self.assertListEqual(self._field_value(gguf_reader, "dict1"), ['key1', 'key2', 'obj'])
        self.assertEqual(self._field_value(gguf_reader, "dict1.key1"), 2)
        self.assertEqual(self._field_value(gguf_reader, "dict1.key2"), "hi")
        self.assertListEqual(self._field_value(gguf_reader, "dict1.obj"), ['k'])


if __name__ == '__main__':
unittest.main()

0 comments on commit d1bfb58

Please sign in to comment.