Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix numpy array incompatability #658

Merged

Conversation

riven314
Copy link
Contributor

@riven314 riven314 commented Aug 16, 2023

Problem:
Current implementation raises error when a collection contains fields of type np.ndarray

How to Reproduce:

  1. set up a mongodb server container in docker-compose.yml:
version: "3.8"
services:
  mongo:
    image: mongo:4.4.6
    restart: always
    environment:
      MONGO_INITDB_ROOT_USERNAME: root
      MONGO_INITDB_ROOT_PASSWORD: password
    ports:
      - 27017:27017
  1. run the mongodb container: docker-compose up -d
  2. prepare a script to insert a document with numpy array field on client side:
import asyncio
from io import BytesIO
from typing import Any, Callable, Iterator

import motor.motor_asyncio
import numpy as np
from beanie import Document, init_beanie
from bson import Binary
from bson.binary import USER_DEFINED_SUBTYPE


class ValidationError(Exception):
    pass


class NumpyField(np.ndarray):
    @classmethod
    def __get_validators__(cls) -> Iterator[Callable]:
        yield cls.validate

    @classmethod
    def validate(cls, value: Any) -> Any:
        if isinstance(value, Binary):
            try:
                np_stream = BytesIO(value)
                return np.load(np_stream)
            except Exception as e:
                raise ValidationError(
                    "Error in loading NumPy array from BSON: %s" % str(e)
                )
        return value

    @classmethod
    def __bson__(cls, value: np.ndarray) -> Binary:
        if isinstance(value, np.ndarray):
            try:
                np_stream = BytesIO()
                np.save(np_stream, value, allow_pickle=False)
                np_stream.seek(0)
                bin_data = np_stream.getbuffer().tobytes()
                return Binary(bin_data, USER_DEFINED_SUBTYPE)
            except Exception as e:
                raise ValidationError(
                    "Error in converting NumPy array to BSON: %s" % str(e)
                )
        raise ValidationError("Value must be a NumPy array")


def numpy_to_bson(value: np.ndarray) -> Binary:
    if isinstance(value, np.ndarray):
        try:
            np_stream = BytesIO()
            np.save(np_stream, value, allow_pickle=False)
            np_stream.seek(0)
            bin_data = np_stream.getbuffer().tobytes()
            return Binary(bin_data, USER_DEFINED_SUBTYPE)
        except Exception as e:
            raise ValidationError(
                "Error in converting NumPy array to BSON: %s" % str(e)
            )
    raise ValidationError("Value must be a NumPy array")


class User(Document):
    data: NumpyField

    class Settings:
        name = "users"
        bson_encoders = {np.ndarray: numpy_to_bson}


async def init_db_client():
    client = motor.motor_asyncio.AsyncIOMotorClient(
        "mongodb://root:password@localhost:27017"
    )
    database = client["test"]
    await init_beanie(database=database, document_models=[User])


async def main():
    await init_db_client()
    NP_DATA = np.random.random((10, 10))
    await User.insert_many([User(data=NP_DATA)])

asyncio.run(main())

@roman-right
Copy link
Owner

Hi! Thank you for the PR. I'll check it out this week after the bug-fixing session

@riven314
Copy link
Contributor Author

riven314 commented Sep 9, 2023

Hi @roman-right
Any update on this fix? Let me know if there is any thing I can help with this PR

@roman-right
Copy link
Owner

Hi @riven314 ,

Sorry for the delay. The bug-fixing session finished yesterday.

Regarding the PR. It is not obvious, what is happening there. Can you use custom encoder for you case? https://beanie-odm.dev/tutorial/defining-a-document/#encoders

@riven314
Copy link
Contributor Author

riven314 commented Sep 14, 2023

@roman-right
I am using customer encoder in my above example for the Document:

class User(Document):
    data: NumpyField

    class Settings:
        name = "users"
        bson_encoders = {np.ndarray: numpy_to_bson}

and the error still happens with customer encoder.

when I trace the source code, I found that it got into error before the data is passed to the customer encoder, so this PR aims to resolve this error

@roman-right
Copy link
Owner

I see. There are overridden magic methods in numpy arrays; that's why it cannot be simply compared. However, you detect this through indirect signs. It will work, but it won't be obvious. I'll think about how I can make this a bit more apparent.

@roman-right roman-right merged commit 19e5ade into roman-right:main Sep 14, 2023
22 checks passed
@roman-right
Copy link
Owner

I've updated the check a bit.
Thank you for your PR. Merged.

@riven314
Copy link
Contributor Author

thx a lot for the update!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants