In [1]:
import socket
import lz4.frame
import json
import torch
import pythreejs as p3js
from IPython.display import display
from collections.abc import Callable
# from omni.isaac.lab.utils.math import quat_inv, quat_mul


class Rokoko_Glove():
    def __init__(self, pos_sensitivity: float = 0.4, rot_sensitivity: float = 0.8, device="cuda:0"):
        self.device = device
        self.fingertip_poses = torch.zeros((12, 7), device=self.device)
        self._additional_callbacks = dict()
        # Define the IP address and port to listen on
        UDP_IP = "0.0.0.0"  # Listen on all available network interfaces
        UDP_PORT = 14043     # Make sure this matches the port used in Rokoko Studio Live

        # Create a UDP socket
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 8192)
        self.sock.bind((UDP_IP, UDP_PORT))

        print(f"Listening for UDP packets on {UDP_IP}:{UDP_PORT}")

        self.left_hand_joint_names = [
            'leftHand', 'leftThumbProximal', 'leftThumbMedial', 'leftThumbDistal', 'leftThumbTip',
            'leftIndexProximal', 'leftIndexMedial', 'leftIndexDistal', 'leftIndexTip',
            'leftMiddleProximal', 'leftMiddleMedial', 'leftMiddleDistal', 'leftMiddleTip',
            'leftRingProximal', 'leftRingMedial', 'leftRingDistal', 'leftRingTip',
            'leftLittleProximal', 'leftLittleMedial', 'leftLittleDistal', 'leftLittleTip']

        self.right_hand_joint_names = [
            'rightHand', 'rightThumbProximal', 'rightThumbMedial', 'rightThumbDistal', 'rightThumbTip',
            'rightIndexProximal', 'rightIndexMedial', 'rightIndexDistal', 'rightIndexTip',
            'rightMiddleProximal', 'rightMiddleMedial', 'rightMiddleDistal', 'rightMiddleTip',
            'rightRingProximal', 'rightRingMedial', 'rightRingDistal', 'rightRingTip',
            'rightLittleProximal', 'rightLittleMedial', 'rightLittleDistal', 'rightLittleTip']

        self.joint_names = self.left_hand_joint_names + self.right_hand_joint_names
        self.joint_spheres = []
        self.joint_axes = []
        self.scene = self.create_scene()

    def create_scene(self):
        # Create the 3D scene
        camera = p3js.PerspectiveCamera(position=[0, 2, 5], up=[0, 1, 0], children=[
            p3js.DirectionalLight(color='white', position=[3, 5, 1], intensity=0.5)
        ])
        scene = p3js.Scene(children=[
            camera,
            p3js.AmbientLight(intensity=0.5)
        ])
        for _ in range(len(self.joint_names)):
            sphere = p3js.Mesh(
                p3js.SphereGeometry(radius=0.05),
                p3js.MeshLambertMaterial(color='blue')
            )
            self.joint_spheres.append(sphere)
            scene.add(sphere)
            axes = self.create_axes()
            self.joint_axes.append(axes)
            scene.add(axes)
        controller = p3js.OrbitControls(controlling=camera)
        renderer = p3js.Renderer(scene=scene, camera=camera, controls=[controller], width=800, height=600)
        display(renderer)
        return scene

    def create_axes(self):
        # Create a coordinate system with three lines representing X, Y, and Z axes
        axes = p3js.Group()
        for color, rotation in zip(['red', 'green', 'blue'], [[0.707, 0, 0, 0.707], [0, 0, 0, 1], [0, 0.707, 0, 0.707]]):
            axis = p3js.Mesh(
                p3js.CylinderGeometry(radiusTop=0.01, radiusBottom=0.01, height=0.2),
                p3js.MeshLambertMaterial(color=color)
            )
            axis.position = [0, 0.1, 0]
            axis.quaternion = rotation
            axes.add(axis)
        return axes

    def reset(self):
        pass

    def advance(self):
        data, addr = self.sock.recvfrom(8192)  # Buffer size is 1024 bytes
        decompressed_data = lz4.frame.decompress(data)

        received_json = json.loads(decompressed_data)
        # Initialize arrays to store the positions and orientations
        hand_positions = torch.zeros((42, 3), device=self.device)
        hand_orientations = torch.zeros((42, 4), device=self.device)

        # Iterate through the JSON data to extract hand joint positions and orientations
        for joint_name in self.joint_names:
            joint_data = received_json["scene"]["actors"][0]["body"][joint_name]
            joint_position = torch.tensor(list(joint_data["position"].values()))
            joint_rotation = torch.tensor(list(joint_data["rotation"].values()))
            idx = self.joint_names.index(joint_name)
            hand_positions[idx] = joint_position
            hand_orientations[idx] = joint_rotation
        hand_positions *= 5
        # Calculate joint angles
        joint_angles = torch.zeros(41, device=self.device)  # There are 42 joints, but 41 angles between them
        # for i in range(1, 42):  # Skip the root joint (index 0)
        #     parent_idx = i - 1  # Assuming the parent is the previous joint in the list
        #     local_rotation = quat_mul(quat_inv(hand_orientations[parent_idx].unsqueeze(0)), hand_orientations[i].unsqueeze(0))
        #     joint_angles[i - 1] = quaternion_to_angle(local_rotation)
        
        self.update_scene(hand_positions, hand_orientations)
        return None

    def update_scene(self, hand_positions, hand_orientations):
        # Update the position of the spheres and axes in the 3D scene
        for i, (sphere, axes) in enumerate(zip(self.joint_spheres, self.joint_axes)):
            pos = hand_positions[i].cpu().numpy()
            sphere.position = list(pos)
            quat = hand_orientations[i].cpu().numpy()
            axes.position = list(pos)
            axes.quaternion = list(quat)

    def add_callback(self, key: str, func: Callable):
        if key not in ["L", "R"]:
            raise ValueError(f"Only left (L) and right (R) buttons supported. Provided: {key}.")
        self._additional_callbacks[key] = func

    def normalize_wrt_middle_proximal(self, hand_positions, is_left=True):
        middle_proximal_idx = self.left_hand_joint_names.index('leftMiddleProximal')
        if not is_left:
            middle_proximal_idx = self.right_hand_joint_names.index('rightMiddleProximal')

        wrist_position = hand_positions[0]
        middle_proximal_position = hand_positions[middle_proximal_idx]
        bone_length = torch.linalg.norm(wrist_position - middle_proximal_position)
        normalized_hand_positions = (middle_proximal_position - hand_positions) / bone_length
        return normalized_hand_positions


# Example usage
glove = Rokoko_Glove()
while True:
    glove.advance()


Listening for UDP packets on 0.0.0.0:14043


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.5, position=(3.0, 5.0,…

KeyboardInterrupt: 