In [5]:
import base64
from typing import List, Optional
from pathlib import Path
from tqdm import tqdm
import numpy as np
import weaviate
import weaviate.classes as wvc
import weaviate.classes.config as wc
from weaviate.classes.init import AdditionalConfig, Timeout
from weaviate.classes.query import Filter
from weaviate.connect import ConnectionParams
from weaviate.collections import Collection
from weaviate.util import generate_uuid5

from utils.descriptions import get_latest_descriptions
from utils.weaviate import create_collection

In [2]:
descriptions_dict = (
    get_latest_descriptions()
)  # https://huggingface.co/datasets/tiange/Cap3D/resolve/48903d63859fe3d3f17942bf6d5383eb05dd1775/Cap3D_automated_Objaverse_full.csv?download=true
len(descriptions_dict)

2024-05-01 15:47:29,497 - INFO - Get latest descriptions process started for file: 'Cap3D_automated_Objaverse_full.csv'
2024-05-01 15:47:29,498 - INFO - Get latest descriptions process started for file: 'Cap3D_automated_Objaverse_full.csv'
2024-05-01 15:47:29,498 - INFO - Calculate file hash process started for file: 'Cap3D_automated_Objaverse_full.csv'
2024-05-01 15:47:29,498 - DEBUG - Reading file bytes for checksum (file_path: Cap3D_automated_Objaverse_full.csv)
2024-05-01 15:47:29,532 - DEBUG - Calculating file hash (checksum) from file bytes
2024-05-01 15:47:29,582 - DEBUG - Successfully calculated file hash (checksum) from file bytes
2024-05-01 15:47:29,585 - INFO - Calculate file hash process successfully completed for file: 'Cap3D_automated_Objaverse_full.csv' (file hash: 'cb1f6a41b1b85fa104c7ac18b9cef2e6c3bab45869e1eea53a027fa513fd221b')
2024-05-01 15:47:29,585 - INFO - Requesting pointer file from url: 'https://huggingface.co/datasets/tiange/Cap3D/raw/main/Cap3D_automated_Obj

1002422

In [3]:
HTTP_HOST = "localhost"
HTTP_PORT = 8080
HTTP_SECURE = False

GRPC_HOST = "localhost"
GRPC_PORT = 50051
GRPC_SECURE = False

INTIALISATION_TIMEOUT_S = 2
QUERY_TIMEOUT_S = 45
INSERT_TIMEOUT_S = 120

COLLECTION_NAME = "Cap3DMM"
DATA_UPLOAD_COLLECTION_NAME = "UploadCap3DMM"

BATCH_SIZE = 50
PATH_TO_EXAMPLE_OBJECTS = "/home/yunusskeete/Documents/data/3D/Cap3D/local-split/unzips/compressed_imgs_perobj_00.zip/Cap3D_Objaverse_renderimgs"
path_to_example_objects = Path(PATH_TO_EXAMPLE_OBJECTS)

In [4]:
with weaviate.WeaviateClient(
    connection_params=ConnectionParams.from_params(
        http_host=HTTP_HOST,
        http_port=HTTP_PORT,
        http_secure=HTTP_SECURE,
        grpc_host=GRPC_HOST,
        grpc_port=GRPC_PORT,
        grpc_secure=GRPC_SECURE,
    ),
    additional_config=AdditionalConfig(
        timeout=Timeout(
            init=INTIALISATION_TIMEOUT_S, query=QUERY_TIMEOUT_S, insert=INSERT_TIMEOUT_S
        ),  # Values in seconds
    ),
) as client:
    assert (
        client.is_live()
    ), "Weaviate client is not live"  # This will raise an exception if the client is not live
    print("Client connection established")

    try:
        client.collections.delete(COLLECTION_NAME)
        client.collections.delete(DATA_UPLOAD_COLLECTION_NAME)

        cap3d: Collection = create_collection(
            client=client,
            collection_name=COLLECTION_NAME,
            configure_upload_collection=False,
        )
        cap3d_upload: Collection = create_collection(
            client=client,
            collection_name=DATA_UPLOAD_COLLECTION_NAME,
            configure_upload_collection=True,
        )

        with cap3d.batch.dynamic() as cap3d_batch:
            with cap3d_upload.batch.dynamic() as cap3d_upload_batch:
                for object_idx, object_folder in tqdm(
                    enumerate(path_to_example_objects.iterdir())
                ):
                    image_uuids_per_object: List[str] = []

                    object_description: str = descriptions_dict.get(
                        object_folder.name, ""
                    )  # TODO: Add log if object could not be found and add to tracking list

                    for object_image_idx, object_image_file in enumerate(
                        file
                        for file in object_folder.iterdir()
                        if file.suffix == ".png" and "_" not in file.name
                    ):
                        # Convert image to base64
                        with object_image_file.open("rb") as file:
                            image_b64: str = base64.b64encode(file.read()).decode(
                                "utf-8"
                            )

                        # Build the image object payload
                        image_obj: Dict[str, str] = {
                            "image": image_b64,
                            "description": object_description,
                            "datasetUID": f"{object_image_file.parent.name}_{object_image_file.name}",  # E.g. "c5517f31ede34ad0a0da1f38753f9588_00005.png"
                        }

                        object_uuid: str = generate_uuid5(image_obj["datasetUID"])

                        image_uuids_per_object.append(object_uuid)

                        # Add object to batch queue
                        cap3d_upload_batch.add_object(
                            properties=image_obj,
                            uuid=object_uuid,
                            # references=reference_obj  # You can add references here
                        )  # Batcher automatically sends batches

                    # Check for failed inserts
                    if len(cap3d_upload.batch.failed_objects) > 0:
                        print(
                            f"Failed to import {len(cap3d_upload.batch.failed_objects)} objects"
                        )
                        for failed in cap3d_upload.batch.failed_objects:
                            print(
                                f"e.g. Failed to import image object with error: {failed.message}"
                            )

                    data_objects: List[
                        weaviate.collections.classes.internal.ObjectSingleReturn
                    ] = [
                        collection.query.fetch_object_by_id(uuid, include_vector=True)
                        for uuid in object_uuids
                    ]
                    average_vector: np.array = np.mean(
                        np.array(
                            [
                                data_object.vector["default"]
                                for data_object in data_objects
                            ]
                        ),
                        axis=0,
                    )

                    # Build the object payload
                    obj: Dict[str, str] = {
                        "description": object_description,
                        "datasetUID": f"{object_image_file.parent.name}_{object_image_file.name}",  # E.g. "c5517f31ede34ad0a0da1f38753f9588_00005.png"
                    }
                    # Add object to batch queue
                    cap3d_batch.add_object(
                        properties=obj,
                        uuid=object_uuid,
                        # references=reference_obj  # You can add references here
                    )  # Batcher automatically sends batches

                    # On success, delete image objects from cap3d_upload collection
                    # Check for failed inserts
                    if len(cap3d.batch.failed_objects) > 0:
                        print(
                            f"Failed to import {len(cap3d.batch.failed_objects)} objects"
                        )
                        for failed in cap3d.batch.failed_objects:
                            print(
                                f"e.g. Failed to import object with error: {failed.message}"
                            )

                    cap3d_upload.data.delete_many(
                        where=Filter.by_id().contains_any(image_uuids_per_object)
                    )

    except Exception as e:
        print(f"client operation failed: {e}")

    print("Closing client connection")
    pass
    # The connection is closed automatically when the context manager exits

2024-05-01 15:48:33,438 - DEBUG - load_ssl_context verify=True cert=None trust_env=True http2=False
2024-05-01 15:48:33,439 - DEBUG - load_verify_locations cafile='/home/yunusskeete/miniconda3/envs/weaviate/lib/python3.11/site-packages/certifi/cacert.pem'
2024-05-01 15:48:33,443 - DEBUG - connect_tcp.started host='localhost' port=8080 local_address=None timeout=45 socket_options=None
2024-05-01 15:48:33,444 - DEBUG - connect_tcp.complete return_value=<httpcore._backends.sync.SyncStream object at 0x71b7e8ad2910>
2024-05-01 15:48:33,444 - DEBUG - send_request_headers.started request=<Request [b'GET']>
2024-05-01 15:48:33,445 - DEBUG - send_request_headers.complete
2024-05-01 15:48:33,445 - DEBUG - send_request_body.started request=<Request [b'GET']>
2024-05-01 15:48:33,446 - DEBUG - send_request_body.complete
2024-05-01 15:48:33,446 - DEBUG - receive_response_headers.started request=<Request [b'GET']>
2024-05-01 15:48:33,446 - DEBUG - receive_response_headers.complete return_value=(b'HTT

Client connection established


0it [00:00, ?it/s]2024-05-01 15:48:33,760 - DEBUG - send_request_headers.started request=<Request [b'GET']>
2024-05-01 15:48:33,761 - DEBUG - send_request_headers.started request=<Request [b'GET']>
2024-05-01 15:48:33,761 - DEBUG - send_request_headers.complete
2024-05-01 15:48:33,762 - DEBUG - send_request_headers.complete
2024-05-01 15:48:33,762 - DEBUG - send_request_body.started request=<Request [b'GET']>
2024-05-01 15:48:33,763 - DEBUG - send_request_body.started request=<Request [b'GET']>
2024-05-01 15:48:33,765 - DEBUG - send_request_body.complete
2024-05-01 15:48:33,766 - DEBUG - send_request_body.complete
2024-05-01 15:48:33,766 - DEBUG - receive_response_headers.started request=<Request [b'GET']>
2024-05-01 15:48:33,766 - DEBUG - receive_response_headers.started request=<Request [b'GET']>
2024-05-01 15:48:33,767 - DEBUG - receive_response_headers.complete return_value=(b'HTTP/1.1', 200, b'OK', [(b'Access-Control-Allow-Headers', b'Content-Type, Authorization, Batch, X-Openai-A

client operation failed: name 'base64' is not defined
Closing client connection
