In [None]:
import functools
from typing import Sequence

# TODO
import bracex
import wcmatch.fnmatch


class PathExpression:
    __slots__ = ["_expr"]

    def __init__(self, expr: "PathExpressionLike"):
        if isinstance(expr, PathExpression):
            expr = expr._expr
        self._expr = expr
    
    def __repr__(self):
        return f"{PathExpression.__qualname__}({self._expr!r})"

    # TODO
    # @functools.lru_cache(maxsize=1)
    def _is_concrete_single_cache(self, expr: ...):
        match expr:
            case str():
                return not wcmatch.fnmatch.is_magic(
                    expr, 
                    flags=wcmatch.fnmatch.BRACE,
                )
        return False

    # TODO
    # @functools.lru_cache(maxsize=1)
    def _matcher_cache(self, expr: ...):
        return wcmatch.fnmatch.compile(
            expr, 
            flags=wcmatch.fnmatch.BRACE,
            limit=0,
        )

    @property
    def is_concrete_single(self):
        return self._is_concrete_single_cache(self._expr)

    def resolve(self, paths: list[str]):
        return self._matcher_cache(self._expr).filter(paths)

    def match(self, path: str):
        return self._matcher_cache(self._expr).match(path)

    def expand(self):
        return bracex.expand(self._expr)


PathExpressionLike = str | Sequence[str] | PathExpression


In [55]:
PathExpression(["/a"]).is_concrete_single
PathExpression("/a").is_concrete_single

True

In [68]:
%timeit -n 10 PathExpression(["a*", "b", "c"]) #.resolve(["a", "ab", "b"])
%timeit -n 10 PathExpression("").resolve(["a", "ab", "b"])

230 ns ± 124 ns per loop (mean ± std. dev. of 7 runs, 10 loops each)
12.8 μs ± 6.91 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
# TODO
from typing import Callable

import numpy
from robotodo.utils.pose import Pose
from robotodo.engines.isaac._kernel import Kernel


USDPrimRef = Callable[[], list["pxr.Usd.Prim"]]
USDStageRef = Callable[[], "pxr.Usd.Stage"]


# TODO cache
class USDPrimPathRef(USDPrimRef):
    __slots__ = ["_path", "_stage"]

    def __init__(self, path: PathExpressionLike, stage_ref: USDStageRef):
        self._path = PathExpression(path)
        self._stage_ref = stage_ref

    def __call__(self):
        stage = self._stage_ref()
        paths = self._path.resolve(stage.Traverse())
        return [
            stage.GetPrimAtPath(path)
            for path in paths
        ]


class USDPrimHelper:
    __slots__ = ["_prims_ref", "_kernel"]

    def __init__(self, ref: USDPrimRef, kernel: Kernel):
        self._prims_ref = ref
        self._kernel = kernel

    # TODO invalidate!!!!
    @functools.lru_cache
    def _xform_cache(self, stage: "pxr.Usd.Stage"):
        pxr = self._kernel.pxr

        # TODO Usd.TimeCode.Default()
        cache = pxr.UsdGeom.XformCache()
        def _on_changed(notice, sender):
            # TODO
            cache.Clear()
        # TODO NOTE life cycle
        cache._notice_handler = _on_changed
        # TODO
        cache._notice_token = pxr.Tf.Notice.Register(
            pxr.Usd.Notice.ObjectsChanged, 
            cache._notice_handler, 
            stage,
        )

        return cache

    @property
    def prims(self):
        return self._prims_ref()

    @property
    def path(self):
        return [
            prim.GetPath().pathString
            for prim in self.prims
        ]

    # TODO special handling for cameras??
    @property
    def pose(self):
        # TODO 
        # pxr = self._kernel.pxr

        return Pose.from_matrix(
            numpy.stack([
                # NOTE matrices in USD are in col-major hence transpose
                numpy.transpose(
                    # TODO
                    # pxr.UsdGeom.Imageable(prim)
                    # .ComputeLocalToWorldTransform(pxr.Usd.TimeCode.Default())
                    self._xform_cache(prim.GetStage())
                    .GetLocalToWorldTransform(prim)
                    .RemoveScaleShear()
                )
                for prim in self.prims
            ])
            # TODO mv
            # @ self._world_convention_transform
        )

    # TODO FIXME BUG: _world_convention_transform!!!!
    # TODO special handling for cameras??
    @pose.setter
    def pose(self, value: Pose):
        # TODO 
        # pxr = self._kernel.pxr

        pose_parent = Pose.from_matrix(
            numpy.stack([
                numpy.asarray(
                    # TODO
                    # pxr.UsdGeom.Imageable(prim)
                    # .ComputeParentToWorldTransform(pxr.Usd.TimeCode.Default())
                    self._xform_cache(prim.GetStage())
                    .GetParentToWorldTransform(prim)
                    .RemoveScaleShear()
                ).T
                for prim in self.prims
            ])
        )
        self.pose_in_parent = pose_parent.inv() * value
    
    @property
    def pose_in_parent(self):
        def get_local_transform(prim: "pxr.Usd.Prim"):
            transform, _ = self._xform_cache(prim.GetStage()).GetLocalTransformation(prim)
            # NOTE matrices in USD are in col-major hence transpose
            return numpy.transpose(transform.RemoveScaleShear())
        return Pose.from_matrix(
            numpy.stack([
                get_local_transform(prim)
                for prim in self.prims
            ])
            # TODO mv
            # @ self._world_convention_transform
        )
    
    @pose_in_parent.setter
    def pose_in_parent(self, value: Pose):
        pxr = self._kernel.pxr
        omni = self._kernel.omni
        # TODO
        self._kernel.enable_extension("omni.physx")
        self._kernel.import_module("omni.physx.scripts.physicsUtils")
        
        # TODO mv
        # value = Pose.from_matrix(
        #     numpy.linalg.inv(self._world_convention_transform) 
        #     @ value.to_matrix()
        # )

        p_vec3s = pxr.Vt.Vec3fArrayFromBuffer(value.p)
        # NOTE this auto-converts from xyzw to wxyz
        q_quats = pxr.Vt.QuatfArrayFromBuffer(value.q)
        
        with pxr.Sdf.ChangeBlock():
            for prim, p_vec3, q_quat in zip(self.prims, p_vec3s, q_quats):
                xformable = pxr.UsdGeom.Xformable(prim)
                omni.physx.scripts.physicsUtils \
                    .set_or_add_translate_op(xformable, p_vec3)
                omni.physx.scripts.physicsUtils \
                    .set_or_add_orient_op(xformable, q_quat)

    # TODO rm
    # # TODO !!!
    # # TODO cache
    # @property
    # def _world_convention_transform(self):
    #     raise NotImplementedError
    #     # TODO for cams: necesito??
    #     # return Pose(q=[1., -1., -1., 1.]) 



In [None]:
from robotodo.engines.isaac.scene import Scene


class Entity:
    def __init__(self, path: PathExpressionLike, scene: Scene):
        self._usd_prim_helper = USDPrimHelper(
            ref=USDPrimPathRef(path, stage_ref=lambda: scene._usd_stage),
            kernel=scene._kernel,
        )

    @property
    def path(self):
        return self._usd_prim_helper.path

    @property
    def pose(self):
        return self._usd_prim_helper.pose

In [None]:
# scene = Scene.create()
# scene.get("/some/object")

# scene = Scene.create_from_usd("someusdref.usd")