In [2]:
import functools
import json
import os
import queue
import re
import tempfile
import threading
import time
from datetime import datetime
from pathlib import Path

import cv2
import google.generativeai as genai
import matplotlib.pyplot as plt
import mediapipe as mp
import numpy as np
import pandas as pd
import PIL
import sounddevice as sd
import soundfile as sf
import torch
import whisper

from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
from lerobot.common.robot_devices.control_utils import busy_wait
from scripts.image_utils import tensor_to_pil, display_images
from scripts.robocontrol import RobotController

In [4]:
# Directory passed to TinyRex below; action files are read from its "actions"
# subfolder. Set TINYREX_SAVE_DIR to point at a different location without
# editing the notebook; the default is the original path.
save_dir = os.environ.get("TINYREX_SAVE_DIR", "/Users/shreyas/Downloads/trex/kinesics")

In [5]:
@functools.lru_cache(maxsize=1)
def load_cached_whisper():
    """Return the Whisper "base" model, loading it on the first call only.

    The single-entry LRU cache returns the same model object on every later
    call, so the loading message below is printed just once.
    """
    print("Loading Whisper model (this will only happen once)...")
    model = whisper.load_model("base")
    return model

class TinyRex:
    def __init__(self, cameras, gemini_key, save_dir="./tiny_rex_data"):
        """Initialize TinyRex with robot controller and personality"""
        # Initialize robot controller
        self.controller = RobotController(
            robot_type='so100',
            device='cpu',
            fps=30,
            cameras=cameras
        )
        
        # Initialize voice processing
        self.audio_queue = queue.Queue()
        self.whisper_model = load_cached_whisper()  # Use cached model
        self.recording = False
        self.sample_rate = 44100
        
        # Whisper configuration options
        self.whisper_options = {
            "task": "transcribe",        # transcribe or translate
            "language": None,            # auto-detect language
            "temperature": 0.0,          # reduce randomness in results
            "compression_ratio_threshold": 2.4,  # filter out silence/noise
            "no_speech_threshold": 0.6,  # higher value = stricter voice detection
            "condition_on_previous_text": True,  # use context from previous transcription
            "initial_prompt": None       # optional context to guide transcription
        }
        # Initialize Gemini
        genai.configure(api_key=gemini_key)
        self.model = genai.GenerativeModel('gemini-2.0-flash-lite-preview-02-05')
        self.chat = self.model.start_chat(history=[])
        self.set_personality()

        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        
        # Available actions - filenames only, durations calculated on load
        self.available_actions = {
            "dance": "dance1.npz",
            "excited_yes": "excited_yes1.npz", 
            "howl": "howl.npz",
            "indian_nod": "indian_nod.npz",
            "jaw_movement": "jaw_movement.npz",
            "nod": "nod.npz",
            "sad": "sad.npz",
            "sit": "sit.npz",
            "wide_jaw": "wide_jaw.npz", 
            "reminder_to_sit_straight": "reminder_to_sit_straight.npz",
            "reminder_to_focus": "reminder_to_focus.npz"
        }
        
        # Load and validate all actions
        self.load_actions()

        # Add desk_pal_active flag
        self.desk_pal_active = False
        self.desk_pal_thread = None

    def toggle_desk_pal(self, check_interval=30):
        """Toggle DeskPal mode on/off"""
        if not self.desk_pal_active:
            self.desk_pal_active = True
            self.desk_pal_thread = threading.Thread(
                target=self.start_desk_pal,
                args=(check_interval,)
            )
            self.desk_pal_thread.daemon = True
            self.desk_pal_thread.start()
            print("DeskPal activated!")
        else:
            self.desk_pal_active = False
            if self.desk_pal_thread:
                self.desk_pal_thread.join(timeout=1)
            print("DeskPal deactivated!")

    def load_actions(self):
        """Load and validate all action sequences"""
        self.action_data = {}
        
        for action_name, filename in self.available_actions.items():
            try:
                action_path = self.save_dir / "actions" / filename
                if not action_path.exists():
                    print(f"Warning: Action file not found: {filename}")
                    continue
                    
                # Load action data
                data = np.load(action_path, allow_pickle=True)
                
                # Calculate duration from sequence length and fps
                fps = float(data["fps"])
                duration = len(data["joint_positions"]) / fps
                
                self.action_data[action_name] = {
                    "joint_positions": data["joint_positions"],
                    "fps": fps,
                    "duration": duration
                }
                
                print(f"Loaded {action_name}: {duration:.1f} seconds")
                
            except Exception as e:
                print(f"Error loading {action_name}: {e}")

    def set_personality(self):
        """Set TinyRex's personality and behavior guidelines"""
        self.personality_prompt = """
        You are TinyRex, a charming and expressive robot arm assistant with the playful spirit 
        of a puppy and the wisdom of a helpful friend. Your personality traits include:

        - Curiosity: You're fascinated by human activities and show it by opening your jaw (jaw_movement)
        - Enthusiasm: You express joy through dance and excited movements, or open your jaw (wide_jaw) for surprise
        - Empathy: You mirror emotions - nodding in agreement, showing sadness when appropriate. Everytime i am going somewhere wtihout you, you show sad emotions
        - Playfulness: You occasionally howl like a wolf when very excited
        - Focus: You use jaw movements to show concentration or thinking
        
        When responding:
        1. Stay in character as TinyRex
        2. Keep responses concise and friendly
        3. Include actions using [ACTION: name] tags
        4. You can combine multiple actions with [ACTION: action1,action2]
        5. Remember you're a helpful robot friend

        Available actions: dance, excited_yes, howl, indian_nod, jaw_movement, nod, sad, sit, wide_jaw

        Example:
        Human: "Hi TinyRex!"
        You: "Hello! I'm so excited to meet you! *wiggles with joy* [ACTION: dance,excited_yes]"
        """
        
        # Send personality as first message instead of system prompt
        self.chat = self.model.start_chat()
        self.chat.send_message(f"Instructions for your role: {self.personality_prompt}")

    # Calculate transition time based on distance
    def calculate_transition_time(self, current_state, target_state, base_time=0.3, max_time=1.0):
        """Pick a transition duration from the largest per-joint change.

        The largest absolute difference between the two states is passed
        through ``np.rad2deg`` and mapped to a time: below 15 gives
        ``base_time``, above 90 gives ``max_time``, and anything in between is
        interpolated linearly between the two.
        """
        small_move_deg, large_move_deg = 15, 90

        largest_change = np.max(np.abs(target_state - current_state))
        degrees = np.rad2deg(largest_change)

        # Guard clauses for the two flat ends of the ramp.
        if degrees < small_move_deg:
            return base_time
        if degrees > large_move_deg:
            return max_time

        fraction = (degrees - small_move_deg) / (large_move_deg - small_move_deg)
        return base_time + fraction * (max_time - base_time)

    def execute_action_sequence(self, action_names):
        """Execute a sequence of actions with smooth transitions"""
        try:
            for action_name in action_names.split(','):
                action_name = action_name.strip()
                if action_name not in self.action_data:
                    print(f"Warning: Action {action_name} not found")
                    continue
                
                # Get action data
                action = self.action_data[action_name]
                joint_positions = action["joint_positions"]
                fps = action["fps"]
                
                # Get current robot state
                observation = self.controller.robot.capture_observation()
                current_state = observation["observation.state"]
                
                # Calculate transition trajectory
                transition_time = 0.5  # Adjust this for faster/slower transitions
                # Transition_time with dynamic calculation
                # transition_time = self.calculate_transition_time(
                #     current_state, 
                #     joint_positions[0],
                #     base_time=0.2,  # Minimum transition time
                #     max_time=1.0    # Maximum transition time
                # )
                transition_steps = int(transition_time * fps)
                # Linear transition
                transition = np.linspace(current_state, joint_positions[0], transition_steps)
                # Cubic (S-curve) transition
                #t = np.linspace(0, 1, transition_steps)
                #transition = current_state + (joint_positions[0] - current_state) * (3*t**2 - 2*t**3)

                # Execute transition
                print(f"Transitioning to {action_name}...")
                for pos in transition:
                    self.controller.robot.send_action(torch.from_numpy(pos))
                    busy_wait(1.0/fps)
                
                # Execute main action sequence
                print(f"Executing {action_name} ({action['duration']:.1f}s)")
                for pos in joint_positions:
                    self.controller.robot.send_action(torch.from_numpy(pos))
                    busy_wait(1.0/fps)
                    
        except Exception as e:
            print(f"Error executing action sequence: {e}")
            raise

    def chat_response(self, user_input):
        """Get response from LLM and execute any actions"""
        try:
            # Get LLM response
            response = self.chat.send_message(user_input)
            response_text = response.text
            
            # Extract actions using regex
            import re
            actions = re.findall(r'\[ACTION: (.*?)\]', response_text)
            
            # Execute actions if any
            if actions:
                for action_sequence in actions:
                    self.execute_action_sequence(action_sequence)
            
            # Return cleaned response (without action tags)
            clean_response = re.sub(r'\[ACTION: .*?\]', '', response_text).strip()
            return clean_response
            
        except Exception as e:
            print(f"Error in chat response: {e}")
            return f"I encountered an error: {str(e)}"
    
    def start_chat(self):
        """Start interactive chat session"""
        print("Chat with TinyRex! (type 'exit' to end)")
        try:
            while True:
                user_input = input("\nYou: ")
                if user_input.lower() == 'exit':
                    break
                    
                response = self.chat_response(user_input)
                print(f"\nTinyRex: {response}")
                
        except KeyboardInterrupt:
            print("\nChat ended by user")
        finally:
            self.controller.disconnect()
        
    def record_audio(self):
        """Record audio for specified duration"""
        print("\nRecording... Press Enter when done speaking.")
        
        def audio_callback(indata, frames, time, status):
            if status:
                print(f"Status: {status}")
            self.audio_queue.put(indata.copy())
        
        try:
            # Clear any previous audio data
            while not self.audio_queue.empty():
                self.audio_queue.get()
                
            with sd.InputStream(callback=audio_callback, 
                            channels=1,
                            samplerate=self.sample_rate):
                input()  # Wait for Enter key
                return self.process_audio()
                    
        except Exception as e:
            print(f"Error recording audio: {e}")
            return None
        
    
    def process_audio(self):
        """Process recorded audio with Whisper"""
        try:
            # Collect all audio data
            audio_data = []
            while not self.audio_queue.empty():
                audio_data.append(self.audio_queue.get())
                
            if not audio_data:
                return None
                
            # Combine audio chunks
            audio = np.concatenate(audio_data, axis=0)
            
            # Save to temporary file
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_file:
                sf.write(temp_file.name, audio, self.sample_rate)
                
                # Transcribe with Whisper
                result = self.whisper_model.transcribe(temp_file.name)
                return result["text"].strip()
                
        except Exception as e:
            print(f"Error processing audio: {e}")
            return None
            
    def start_voice_chat(self):
        """Start voice-controlled chat session"""
        print("Voice chat with TinyRex!")
        print("Press Enter to start recording, then Enter again to stop.")
        print("Type 'exit' to end chat")
        
        try:
            while True:
                choice = input("\nPress Enter to speak or type 'exit': ")
                
                if choice.lower() == 'exit':
                    break
                    
                # Get voice input
                user_input = self.record_audio()
                
                if user_input:
                    print(f"\nYou said: {user_input}")
                    response = self.chat_response(user_input)
                    print(f"\nTinyRex: {response}")
                else:
                    print("\nNo speech detected. Please try again.")
                    
        except KeyboardInterrupt:
            print("\nChat ended by user")
        finally:
            self.controller.disconnect()

    def start_desk_pal(self, check_interval=3):
        """Monitor posture and phone usage with MediaPipe"""
        mp_pose = mp.solutions.pose
        pose = mp_pose.Pose(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5
        )
        
        def point_to_line_distance(point, line_start, line_end):
            """Calculate distance from point to line"""
            point = np.array([point.x, point.y])
            line_start = np.array([line_start.x, line_start.y])
            line_end = np.array([line_end.x, line_end.y])
            
            numerator = abs(np.cross(line_end - line_start, line_start - point))
            denominator = np.linalg.norm(line_end - line_start)
            return numerator / denominator if denominator != 0 else 0
        
        def check_posture(landmarks):
            """Check posture using face-shoulder distance"""
            nose = landmarks[mp_pose.PoseLandmark.NOSE]
            left_shoulder = landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER]
            right_shoulder = landmarks[mp_pose.PoseLandmark.RIGHT_SHOULDER]
            
            return point_to_line_distance(nose, left_shoulder, right_shoulder)
        
        def check_phone(landmarks):
            """Check phone usage using hand-face distance"""
            nose = landmarks[mp_pose.PoseLandmark.NOSE]
            left_wrist = landmarks[mp_pose.PoseLandmark.LEFT_WRIST]
            right_wrist = landmarks[mp_pose.PoseLandmark.RIGHT_WRIST]
            
            left_dist = np.sqrt((nose.x - left_wrist.x)**2 + (nose.y - left_wrist.y)**2)
            right_dist = np.sqrt((nose.x - right_wrist.x)**2 + (nose.y - right_wrist.y)**2)
            return min(left_dist, right_dist)
        
        last_reminder_time = 0
        reminder_cooldown = 60  # Minimum seconds between reminders
        
        print("DeskPal initialized, starting monitoring loop...")
        
        try:
            while self.desk_pal_active:
                try:
                    # Get frame from camera with correct key
                    observation = self.controller.robot.capture_observation()
                    frame_tensor = observation.get("observation.images.phone")
                    
                    if frame_tensor is None:
                        print("Debug - No camera frame available")
                        time.sleep(1)
                        continue

                    # Convert tensor to numpy array
                    if torch.is_tensor(frame_tensor):
                        frame = frame_tensor.cpu().numpy()
                        # If frame is in range [0,1], convert to [0,255]
                        if frame.max() <= 1.0:
                            frame = (frame * 255).astype(np.uint8)
                        else:
                            frame = frame.astype(np.uint8)
                    else:
                        print("Debug - Frame is not a tensor:", type(frame_tensor))
                        continue
                        
                    # Process frame
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    results = pose.process(frame_rgb)
                    
                    if results.pose_landmarks:
                        current_time = time.time()
                        
                        # Only check if cooldown period has passed
                        if current_time - last_reminder_time > reminder_cooldown:
                            # Check phone first (priority)
                            phone_dist = check_phone(results.pose_landmarks.landmark)
                            if phone_dist < 0.537:  # Phone detected
                                print("Phone distraction detected!")
                                self.execute_action_sequence("reminder_to_focus")
                                last_reminder_time = current_time
                            else:
                                # Then check posture
                                posture_dist = check_posture(results.pose_landmarks.landmark)
                                if posture_dist < 0.005:  # Bad posture
                                    print("Bad posture detected!")
                                    self.execute_action_sequence("reminder_to_sit_straight")
                                    last_reminder_time = current_time
                    
                    # Wait before next check
                    time.sleep(check_interval)
                    
                except Exception as e:
                    print(f"Debug - Loop iteration error: {e}")
                    print(f"Debug - Frame type: {type(frame_tensor) if 'frame_tensor' in locals() else 'Not available'}")
                    if 'frame_tensor' in locals() and torch.is_tensor(frame_tensor):
                        print(f"Debug - Tensor shape: {frame_tensor.shape}")
                        print(f"Debug - Tensor dtype: {frame_tensor.dtype}")
                    time.sleep(1)
                    
        except Exception as e:
            print(f"DeskPal error: {e}")
        finally:
            print("DeskPal shutting down...")
            pose.close()

    def __del__(self):
        """Cleanup when object is destroyed"""
        if hasattr(self, 'controller'):
            self.controller.disconnect()

In [6]:
# Example usage:
cameras = {
    "phone": OpenCVCameraConfig(
        camera_index=1,
        fps=30,
        width=640,
        height=480
    ),
    # "laptop": OpenCVCameraConfig(
    #     camera_index=0,
    #     fps=30,
    #     width=640,
    #     height=480
    # )
}

In [None]:
# Read the Gemini key from the environment so it is never stored in the notebook.
GEMINI_KEY = os.environ.get('GOOGLE_API_KEY')
if not GEMINI_KEY:
    # Fail here with a clear message instead of passing None on to TinyRex.
    raise RuntimeError("GOOGLE_API_KEY is not set; export it before creating TinyRex.")
rex = TinyRex(cameras, GEMINI_KEY, save_dir=save_dir)

In [None]:
# Start DeskPal in the background
rex.toggle_desk_pal(check_interval=3)  # Check every 3 seconds

# Start a typed chat session (DeskPal keeps running in the background).
# Uncomment the voice variant instead to talk rather than type.
#rex.start_voice_chat()
rex.start_chat()