diff --git a/vllm/multimodal/hasher.py b/vllm/multimodal/hasher.py index da019d40a6fe..0fb1363ce471 100644 --- a/vllm/multimodal/hasher.py +++ b/vllm/multimodal/hasher.py @@ -20,22 +20,22 @@ class MultiModalHasher: @classmethod - def serialize_item(cls, obj: object) -> Union[bytes, memoryview]: + def serialize_item(cls, obj: object) -> Iterable[Union[bytes, memoryview]]: # Simple cases - if isinstance(obj, str): - return obj.encode("utf-8") if isinstance(obj, (bytes, memoryview)): - return obj + return (obj, ) + if isinstance(obj, str): + return (obj.encode("utf-8"), ) if isinstance(obj, (int, float)): - return np.array(obj).tobytes() + return (np.array(obj).tobytes(), ) if isinstance(obj, Image.Image): exif = obj.getexif() if Image.ExifTags.Base.ImageID in exif and isinstance( exif[Image.ExifTags.Base.ImageID], uuid.UUID): # If the image has exif ImageID tag, use that - return exif[Image.ExifTags.Base.ImageID].bytes - return cls.item_to_bytes( + return (exif[Image.ExifTags.Base.ImageID].bytes, ) + return cls.iter_item_to_bytes( "image", np.asarray(convert_image_mode(obj, "RGBA"))) if isinstance(obj, torch.Tensor): tensor_obj: torch.Tensor = obj.cpu() @@ -49,43 +49,34 @@ def serialize_item(cls, obj: object) -> Union[bytes, memoryview]: tensor_obj = tensor_obj.view( (tensor_obj.numel(), )).view(torch.uint8) - return cls.item_to_bytes( + return cls.iter_item_to_bytes( "tensor", { "original_dtype": str(tensor_dtype), "original_shape": tuple(tensor_shape), "data": tensor_obj.numpy(), }) - - return cls.item_to_bytes("tensor", tensor_obj.numpy()) + return cls.iter_item_to_bytes("tensor", tensor_obj.numpy()) if isinstance(obj, np.ndarray): # If the array is non-contiguous, we need to copy it first - arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes() - return cls.item_to_bytes("ndarray", { + arr_data = obj.view( + np.uint8).data if obj.flags.c_contiguous else obj.tobytes() + return cls.iter_item_to_bytes("ndarray", { "dtype": obj.dtype.str, "shape": obj.shape, "data": arr_data, }) - logger.warning( "No serialization method found for %s. " "Falling back to pickle.", type(obj)) - return pickle.dumps(obj) - - @classmethod - def item_to_bytes( - cls, - key: str, - obj: object, - ) -> bytes: - return b''.join(kb + vb for kb, vb in cls.iter_item_to_bytes(key, obj)) + return (pickle.dumps(obj), ) @classmethod def iter_item_to_bytes( cls, key: str, obj: object, - ) -> Iterable[tuple[bytes, Union[bytes, memoryview]]]: + ) -> Iterable[Union[bytes, memoryview]]: # Recursive cases if isinstance(obj, (list, tuple)): for i, elem in enumerate(obj): @@ -94,17 +85,15 @@ def iter_item_to_bytes( for k, v in obj.items(): yield from cls.iter_item_to_bytes(f"{key}.{k}", v) else: - key_bytes = key.encode("utf-8") - value_bytes = cls.serialize_item(obj) - yield key_bytes, value_bytes + yield key.encode("utf-8") + yield from cls.serialize_item(obj) @classmethod def hash_kwargs(cls, **kwargs: object) -> str: hasher = blake3() for k, v in kwargs.items(): - for k_bytes, v_bytes in cls.iter_item_to_bytes(k, v): - hasher.update(k_bytes) - hasher.update(v_bytes) + for bytes_ in cls.iter_item_to_bytes(k, v): + hasher.update(bytes_) return hasher.hexdigest()