In [3]:
import pybullet as p
import time
import pybullet_data
import numpy as np
import pandas as pd
from abc import ABC, abstractmethod

# Drop any physics-server connection left over from an earlier run of this
# cell, so re-running it starts from a single fresh connection.
if p.isConnected():
    p.disconnect()

## SETUP SIMULATION
# Interactive GUI server so the scene can be watched.
p.connect(p.GUI)
# Make pybullet_data's bundled assets (e.g. plane.urdf) loadable by bare filename.
p.setAdditionalSearchPath(pybullet_data.getDataPath())
p.resetSimulation()
# Gravity of 10 m/s^2 pointing along -z.
p.setGravity(0, 0, -10)
# No real-time stepping: the world only advances on explicit p.stepSimulation() calls.
p.setRealTimeSimulation(0)

# Ground plane; the id is kept but not referenced elsewhere in the visible code.
plane_id = p.loadURDF("plane.urdf")

# =============================================================================
# ABSTRACT BASE CLASSES
# =============================================================================

class GraspableObject(ABC):
    """Interface for a simulated object that a gripper can try to pick up.

    Concrete subclasses must report the object's live pose and be able to put
    the object back where it started between grasp attempts.
    """

    @abstractmethod
    def get_position(self):
        """Return where the object currently is, as ``[x, y, z]``."""
        ...

    @abstractmethod
    def get_orientation(self):
        """Return the object's current orientation as a quaternion."""
        ...

    @abstractmethod
    def reset(self):
        """Restore the object's initial position and orientation."""
        ...

class Gripper(ABC):
    """Interface for a simulated gripper that can be posed and actuated.

    Orientation conventions are left to the implementation; the grasp
    pipeline passes Euler angles ``[roll, pitch, yaw]`` to ``move_to_pose``.
    """

    @abstractmethod
    def close_gripper(self):
        """Drive the fingers closed."""
        ...

    @abstractmethod
    def open_gripper(self):
        """Drive the fingers open."""
        ...

    @abstractmethod
    def move_to_pose(self, position, orientation):
        """Command the gripper toward ``position`` with ``orientation``."""
        ...

    @abstractmethod
    def get_pose(self):
        """Return the gripper's current pose as ``(position, orientation)``."""
        ...

# =============================================================================
# GRASP SAMPLER - Generates candidate grasp poses
# =============================================================================

class GraspSampler:
    """
    Generates candidate grasp poses around an object using sphere-based sampling.

    The object is at the center of a conceptual sphere. Grasp poses are sampled
    by picking points on the sphere and orienting the gripper to point inward.

    Every sampling method accepts an optional ``rng`` (e.g. a seeded
    ``numpy.random.Generator`` or ``numpy.random.RandomState``) so datasets can
    be regenerated exactly. When ``rng`` is omitted, NumPy's global random
    state is used, as before.
    """

    # Fraction of "good" (likely-success) poses when difficult poses are mixed in.
    GOOD_FRACTION = 0.6
    # Kinds of deliberately hard grasps produced by _sample_difficult_pose.
    DIFFICULTY_TYPES = ('too_far', 'too_high', 'bad_angle', 'offset')

    def __init__(self, object_position, radius=0.15, height_offset=0.0):
        """
        Args:
            object_position: [x, y, z] center position of the object
            radius: Distance from object center to gripper position
            height_offset: Additional height offset for approach
        """
        self.object_position = np.array(object_position, dtype=float)
        self.radius = radius
        self.height_offset = height_offset

    def _look_at_angles(self, position):
        """Return ``(yaw, pitch)`` pointing from ``position`` toward the object.

        Returns ``(0, 0)`` when ``position`` is within 1 mm of the object
        centre, where the direction is undefined.
        """
        direction = self.object_position - position
        distance = np.linalg.norm(direction)
        if distance <= 0.001:
            return 0, 0
        direction = direction / distance
        yaw = np.arctan2(direction[1], direction[0])
        # Positive pitch means the gripper looks downward toward the object.
        pitch = np.arctan2(-direction[2], np.hypot(direction[0], direction[1]))
        return yaw, pitch

    def _sample_good_pose(self, rng, noise_std):
        """Sample one likely-successful pose from the sphere's upper cap (0-45 deg from vertical)."""
        theta = rng.uniform(0, 2 * np.pi)
        # Beta(2, 3) has mean 0.4, i.e. phi averages ~18 deg off vertical.
        phi = rng.beta(2, 3) * (np.pi / 4)

        x = self.radius * np.sin(phi) * np.cos(theta)
        y = self.radius * np.sin(phi) * np.sin(theta)
        # Before noise, keep the grasp 3-10 cm above the object centre.
        z = np.clip(self.radius * np.cos(phi) + self.height_offset, 0.03, 0.10)

        x += rng.normal(0, noise_std)
        y += rng.normal(0, noise_std)
        z += rng.normal(0, noise_std * 0.5)

        position = self.object_position + np.array([x, y, z])

        yaw, pitch = self._look_at_angles(position)
        yaw += rng.normal(0, noise_std * 2)
        pitch += rng.normal(0, noise_std * 0.5)

        return {'position': position.tolist(), 'orientation': [0, pitch, yaw]}

    def _sample_difficult_pose(self, rng):
        """Sample one pose that is deliberately likely to fail."""
        difficulty_type = rng.choice(self.DIFFICULTY_TYPES)
        theta = rng.uniform(0, 2 * np.pi)

        if difficulty_type == 'too_far':
            # Mostly horizontal approach, 1.5-2.5x farther out than the sphere.
            phi = rng.uniform(np.pi / 4, np.pi / 2.5)
            radius_mult = rng.uniform(1.5, 2.5)
            x = self.radius * radius_mult * np.sin(phi) * np.cos(theta)
            y = self.radius * radius_mult * np.sin(phi) * np.sin(theta)
            z = self.radius * np.cos(phi) + self.height_offset
        elif difficulty_type == 'too_high':
            # Near-vertical approach, but 12-20 cm above the object centre.
            phi = rng.uniform(0, np.pi / 6)
            x = self.radius * np.sin(phi) * np.cos(theta)
            y = self.radius * np.sin(phi) * np.sin(theta)
            z = rng.uniform(0.12, 0.20)
        elif difficulty_type == 'bad_angle':
            # Very horizontal approach at low height.
            phi = rng.uniform(np.pi / 3, np.pi / 2)
            x = self.radius * np.sin(phi) * np.cos(theta)
            y = self.radius * np.sin(phi) * np.sin(theta)
            z = np.clip(self.radius * np.cos(phi) + self.height_offset, 0.02, 0.08)
        else:  # 'offset'
            # Up to 6 cm of extra XY offset, so the gripper is not centred on the object.
            phi = rng.uniform(0, np.pi / 4)
            x = self.radius * np.sin(phi) * np.cos(theta) + rng.uniform(-0.06, 0.06)
            y = self.radius * np.sin(phi) * np.sin(theta) + rng.uniform(-0.06, 0.06)
            z = self.radius * np.cos(phi) + self.height_offset

        position = self.object_position + np.array([x, y, z])

        # 30% of difficult poses ignore the object entirely when choosing orientation.
        if rng.random() < 0.3:
            yaw = rng.uniform(0, 2 * np.pi)
            pitch = rng.uniform(-0.5, 0.5)
        else:
            yaw, pitch = self._look_at_angles(position)

        return {'position': position.tolist(), 'orientation': [0, pitch, yaw]}

    def sample_poses(self, num_samples, noise_std=0.005, include_difficult=True, rng=None):
        """
        Generate grasp pose candidates using spherical sampling.
        Includes both easy (likely success) and difficult (likely failure) poses
        for a balanced dataset.

        Args:
            num_samples: Number of grasp poses to generate
            noise_std: Standard deviation of Gaussian noise added to good poses
            include_difficult: If True, ~40% of the poses are difficult ones likely to fail
            rng: Optional random source exposing ``uniform``, ``normal``, ``beta``,
                ``choice``, ``random`` and ``shuffle`` (e.g. ``np.random.default_rng(0)``).
                Defaults to NumPy's global random state.

        Returns:
            List of poses in random order, each a dict with 'position' [x, y, z]
            and 'orientation' [roll, pitch, yaw] (roll is always 0).
        """
        rng = np.random if rng is None else rng

        num_good = int(num_samples * self.GOOD_FRACTION) if include_difficult else num_samples
        num_difficult = num_samples - num_good

        poses = [self._sample_good_pose(rng, noise_std) for _ in range(num_good)]
        poses.extend(self._sample_difficult_pose(rng) for _ in range(num_difficult))

        # Mix good and difficult poses so execution order carries no label signal.
        rng.shuffle(poses)
        return poses

    # A simpler sampling method for small scale tests
    def sample_simple_poses(self, num_samples, z_range=(0.05, 0.15), yaw_range=(0, 2*np.pi), rng=None):
        """
        Simpler sampling: vary z-height and yaw angle above the object.
        Good for top-down grasping scenarios.

        Args:
            num_samples: Number of poses to generate
            z_range: (min, max) height above object
            yaw_range: (min, max) yaw rotation range
            rng: Optional random source (see ``sample_poses``); defaults to
                NumPy's global random state.

        Returns:
            List of poses with 'position' and 'orientation' [0, 0, yaw]
        """
        rng = np.random if rng is None else rng
        obj_x, obj_y, obj_z = self.object_position

        poses = []
        for _ in range(num_samples):
            z_offset = rng.uniform(*z_range)
            yaw = rng.uniform(*yaw_range)

            # Small x,y variation around the object centre
            x_noise = rng.normal(0, 0.02)
            y_noise = rng.normal(0, 0.02)

            poses.append({
                'position': [obj_x + x_noise, obj_y + y_noise, obj_z + z_offset],
                'orientation': [0, 0, yaw]  # roll, pitch, yaw
            })

        return poses

# =============================================================================
# GRASP SIMULATOR - Orchestrates grasp execution and data collection
# =============================================================================

class GraspSimulator:
    """
    Orchestrates grasp attempts and collects success/failure data.

    This class:
    1. Moves the gripper to a candidate pose
    2. Executes a grasp attempt (approach, close, lift)
    3. Evaluates success (object held after lifting)
    4. Records results for dataset creation

    The gripper and object are used only through the ``Gripper`` and
    ``GraspableObject`` interfaces. The simulation is advanced with
    ``p.stepSimulation()`` on the currently connected PyBullet server.
    """

    # Simulation steps per simulated second (PyBullet's default fixed time
    # step is 1/240 s). Used to convert hold times into step counts and to pace
    # real-time visualization.
    SIM_HZ = 240

    def __init__(self, gripper, graspable_object, lift_height=0.3, hold_time=3.0,
                 success_height=0.1):
        """
        Args:
            gripper: A Gripper instance
            graspable_object: A GraspableObject instance
            lift_height: Height to lift object to (meters)
            hold_time: Time to hold object to verify grasp (seconds)
            success_height: World z (meters) the object must stay above
                during the hold for the grasp to count as a success
        """
        self.gripper = gripper
        self.object = graspable_object
        self.lift_height = lift_height
        self.hold_time = hold_time
        self.success_height = success_height
        self.results = []  # one dict per attempt: pose features + 'success' label

    def execute_grasp(self, pose, steps_per_phase=100, visualize=True):
        """
        Execute a single grasp attempt at the given pose and record the result.

        Args:
            pose: Dict with 'position' [x,y,z] and 'orientation' [roll,pitch,yaw]
            steps_per_phase: Simulation steps per movement phase
            visualize: Whether to add delays for visualization

        Returns:
            bool: True if grasp was successful, False otherwise
        """
        position = list(pose['position'])
        orientation = list(pose['orientation'])

        # Step 1: Open gripper and move to an approach pose 10 cm above the target
        self.gripper.open_gripper()
        approach_pos = [position[0], position[1], position[2] + 0.1]
        self.gripper.move_to_pose(approach_pos, orientation)
        self._run_simulation(steps_per_phase, visualize)

        # Step 2: Move down to grasp pose
        self.gripper.move_to_pose(position, orientation)
        self._run_simulation(steps_per_phase, visualize)

        # Step 3: Close gripper - 3x the phase length so the fingers fully close
        self.gripper.close_gripper()
        self._run_simulation(steps_per_phase * 3, visualize)

        # Step 4: Lift straight up to the absolute lift height
        lift_pos = [position[0], position[1], self.lift_height]
        self.gripper.move_to_pose(lift_pos, orientation)
        self._run_simulation(steps_per_phase * 2, visualize)

        # Step 5: Hold and check success
        success = self._check_grasp_success(self.hold_time, visualize)

        self.results.append({
            'x': position[0],
            'y': position[1],
            'z': position[2],
            'roll': orientation[0],
            'pitch': orientation[1],
            'yaw': orientation[2],
            'success': int(success)
        })

        return success

    def _run_simulation(self, steps, visualize=True):
        """Advance the physics by ``steps`` steps, optionally paced to real time."""
        delay = 1.0 / self.SIM_HZ
        for _ in range(steps):
            p.stepSimulation()
            if visualize:
                time.sleep(delay)

    def _check_grasp_success(self, hold_time, visualize=True):
        """
        Check if the object is still held after lifting.

        Success criterion: the object's z-position stays at or above
        ``self.success_height`` at every step of the hold. At least one step
        is always checked, so a zero hold time cannot succeed without looking.
        """
        check_steps = max(1, int(hold_time * self.SIM_HZ))

        for _ in range(check_steps):
            self._run_simulation(1, visualize)
            # Query through the GraspableObject interface, not a body id.
            if self.object.get_position()[2] < self.success_height:
                return False

        return True

    def reset_scene(self):
        """Reset object to initial position and open the gripper for the next attempt."""
        self.object.reset()
        self.gripper.open_gripper()
        self._run_simulation(50, visualize=False)

    def get_dataset(self):
        """Return all results collected so far as a pandas DataFrame."""
        return pd.DataFrame(self.results)

    def run_data_collection(self, poses, reset_between=True, visualize=True):
        """
        Run multiple grasp attempts and collect data.

        Args:
            poses: List of pose dicts from GraspSampler
            reset_between: Whether to reset scene between attempts
            visualize: Whether to show simulation in real-time

        Returns:
            pandas DataFrame with pose features and success labels (including
            results from any earlier calls on this simulator)
        """
        total = len(poses)
        for i, pose in enumerate(poses, start=1):
            print(f"Grasp attempt {i}/{total}")
            success = self.execute_grasp(pose, visualize=visualize)
            print(f"  Result: {'SUCCESS' if success else 'FAILURE'}")

            if reset_between:
                self.reset_scene()

        return self.get_dataset()

# =============================================================================
# CONCRETE GRIPPER IMPLEMENTATIONS
# =============================================================================

class _PR2StyleGripper(Gripper):
    """
    Shared implementation for grippers built from PyBullet's ``pr2_gripper.urdf``.

    The gripper base is held by a fixed constraint to the world. The point
    ``TCP_OFFSET`` (in the gripper base frame) is driven to the commanded world
    position and orientation. Joints ``FINGER_JOINTS`` are position-controlled
    to open and close. Concrete grippers customise behaviour through the class
    attributes below.
    """

    URDF_SCALING = 1.0                                        # globalScaling for loadURDF
    INITIAL_JOINT_POSITIONS = (0.550569, 0.0, 0.549657, 0.0)  # joints 0..3 at spawn
    TCP_OFFSET = (0.2, 0, 0)                                  # constraint point in base frame
    FINGER_JOINTS = (0, 2)
    OPEN_POSITION = 0.55
    OPEN_MAX_VELOCITY = 1
    OPEN_FORCE = 10
    CLOSE_POSITION = 0.0
    CLOSE_MAX_VELOCITY = 2.0
    CLOSE_FORCE = 50
    CONSTRAINT_MAX_FORCE = 50
    LINK_FRICTION = (('lateralFriction', 3.0),
                     ('spinningFriction', 1.0),
                     ('rollingFriction', 1.0))

    def __init__(self, position, orientation):
        """
        Args:
            position: [x, y, z] initial position
            orientation: quaternion [x, y, z, w] initial orientation
        """
        self.initial_position = list(position)
        self.initial_orientation = list(orientation)
        self.current_position = list(position)
        # Commanded orientation is tracked as Euler angles, like move_to_pose.
        self.current_orientation = list(p.getEulerFromQuaternion(orientation))

        self.body_id = p.loadURDF("pr2_gripper.urdf",
                                  basePosition=position,
                                  baseOrientation=orientation,
                                  globalScaling=self.URDF_SCALING)

        # Initialize gripper joints (open position)
        self.joint_positions = list(self.INITIAL_JOINT_POSITIONS)
        for joint_index, joint_pos in enumerate(self.joint_positions):
            p.resetJointState(self.body_id, joint_index, joint_pos)

        # Fixed constraint to the world that move_to_pose retargets.
        self.constraint_id = p.createConstraint(
            parentBodyUniqueId=self.body_id,
            parentLinkIndex=-1,
            childBodyUniqueId=-1,
            childLinkIndex=-1,
            jointType=p.JOINT_FIXED,
            jointAxis=[0, 0, 0],
            parentFramePosition=list(self.TCP_OFFSET),
            childFramePosition=position,
            # Hold the spawn orientation. Without this the constraint target
            # defaults to the identity quaternion and overrides `orientation`.
            childFrameOrientation=orientation,
        )

        # High friction on every non-base link for a better grip.
        friction = dict(self.LINK_FRICTION)
        for link_idx in range(p.getNumJoints(self.body_id)):
            p.changeDynamics(self.body_id, link_idx, **friction)

    def _drive_fingers(self, target, max_velocity, force):
        """Position-control every finger joint toward ``target``."""
        for joint in self.FINGER_JOINTS:
            p.setJointMotorControl2(
                self.body_id,
                joint,
                p.POSITION_CONTROL,
                targetPosition=target,
                maxVelocity=max_velocity,
                force=force,
            )

    def close_gripper(self):
        """Close the gripper fingers to grasp."""
        self._drive_fingers(self.CLOSE_POSITION, self.CLOSE_MAX_VELOCITY, self.CLOSE_FORCE)

    def open_gripper(self):
        """Open the gripper fingers."""
        self._drive_fingers(self.OPEN_POSITION, self.OPEN_MAX_VELOCITY, self.OPEN_FORCE)

    def move_to_pose(self, position, orientation):
        """
        Command the gripper toward a pose by retargeting its world constraint.

        The motion happens over subsequent simulation steps, limited by
        ``CONSTRAINT_MAX_FORCE``.

        Args:
            position: [x, y, z] target for the TCP_OFFSET point
            orientation: [roll, pitch, yaw] Euler angles
        """
        self.current_position = list(position)
        self.current_orientation = list(orientation)

        # NOTE(review): rendering is paused around the update and then forced
        # back on, even if it had been disabled elsewhere.
        p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 0)
        p.changeConstraint(
            self.constraint_id,
            jointChildPivot=position,
            jointChildFrameOrientation=p.getQuaternionFromEuler(orientation),
            maxForce=self.CONSTRAINT_MAX_FORCE,
        )
        p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 1)

    def get_pose(self):
        """Return copies of the last commanded (position, euler_orientation)."""
        return list(self.current_position), list(self.current_orientation)


class PR2Gripper(_PR2StyleGripper):
    """PR2 parallel-jaw gripper implementation (full-size, default parameters)."""


class PandaGripper(_PR2StyleGripper):
    """Panda-style gripper (PR2 with different parameters) for second gripper requirement."""

    URDF_SCALING = 0.9                            # slightly smaller model
    INITIAL_JOINT_POSITIONS = (0.45, 0.0, 0.45, 0.0)  # narrower initial opening
    TCP_OFFSET = (0.18, 0, 0)                     # offset scaled to the smaller model
    OPEN_POSITION = 0.45                          # opens less wide
    CLOSE_MAX_VELOCITY = 2.5                      # faster closing
    CLOSE_FORCE = 60                              # stronger grip

# =============================================================================
# CONCRETE OBJECT IMPLEMENTATIONS
# =============================================================================

class _SingleBodyObject(GraspableObject):
    """
    Shared implementation for graspable objects that are one PyBullet body.

    Stores the spawn pose so ``reset`` can restore it. Subclasses load their
    URDF in ``__init__`` and assign the resulting id to ``self.body_id``.
    """

    def __init__(self, position, orientation):
        self.initial_position = list(position)
        self.initial_orientation = list(orientation)
        self.body_id = None  # assigned by the subclass after loading its URDF

    def get_position(self):
        """Get the current base position [x, y, z] from simulation."""
        pos, _ = p.getBasePositionAndOrientation(self.body_id)
        return list(pos)

    def get_orientation(self):
        """Get the current base orientation quaternion from simulation."""
        _, orn = p.getBasePositionAndOrientation(self.body_id)
        return list(orn)

    def reset(self):
        """Teleport the object back to its initial position and orientation."""
        p.resetBasePositionAndOrientation(
            self.body_id,
            self.initial_position,
            self.initial_orientation
        )


class Cube(_SingleBodyObject):
    """Small cube object for grasping (``cube_small.urdf``, default dynamics)."""

    def __init__(self, position, orientation):
        """
        Args:
            position: [x, y, z] initial position
            orientation: quaternion initial orientation
        """
        super().__init__(position, orientation)
        self.body_id = p.loadURDF("cube_small.urdf", position, orientation, globalScaling=1.0)


class Sphere(_SingleBodyObject):
    """Sphere object for grasping (uses PyBullet's built-in sphere)."""

    def __init__(self, position, orientation):
        super().__init__(position, orientation)
        self.body_id = p.loadURDF("sphere_small.urdf", position, orientation, globalScaling=1.0)

        # Add friction to prevent slipping; mass overridden to 0.2.
        p.changeDynamics(self.body_id, -1,
                         lateralFriction=2.0,
                         spinningFriction=0.5,
                         rollingFriction=0.5,
                         mass=0.2)


class Cylinder(_SingleBodyObject):
    """Cylinder object for grasping (uses custom URDF, scaled to 0.5x)."""

    def __init__(self, position, orientation):
        super().__init__(position, orientation)
        # NOTE(review): relative path - confirm CodeFiles/ is reachable from
        # the directory the notebook runs in.
        self.body_id = p.loadURDF("CodeFiles/cylinder.urdf", position, orientation, globalScaling=0.5)

        # Add HIGH friction to prevent slipping when grasped; mass overridden to 0.3.
        p.changeDynamics(self.body_id, -1,
                         lateralFriction=3.0,
                         spinningFriction=1.0,
                         rollingFriction=1.0,
                         mass=0.3)


In [4]:
# =============================================================================
# FULL PIPELINE: Dataset Collection + Classifier Training + Testing
# =============================================================================

import time

import numpy as np
import pandas as pd
import pybullet as p
import pybullet_data
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# =============================================================================
# DATA COLLECTION FUNCTION
# =============================================================================

def collect_dataset(gripper_class, gripper_name, object_class, object_name, 
                    object_pos, sampler_params, num_samples=35, visualize=False):
    """
    Collect grasp data for one gripper-object combination.

    Starts a fresh PyBullet session (GUI if ``visualize`` else DIRECT), loads a
    ground plane, spawns the object at ``object_pos`` and the gripper at
    z = 0.4 directly above it, then executes ``num_samples`` sampled grasps.
    The session is always disconnected afterwards, even if an attempt raises.

    Args:
        gripper_class: Gripper subclass, constructed as ``cls(position, quaternion)``
        gripper_name: Label written to the 'gripper' column
        object_class: GraspableObject subclass, constructed as ``cls(position, quaternion)``
        object_name: Label written to the 'object' column
        object_pos: [x, y, z] spawn position of the object
        sampler_params: Extra keyword arguments forwarded to ``GraspSampler``
        num_samples: Number of grasp attempts
        visualize: Show the GUI and pace the simulation in real time

    Returns:
        pandas.DataFrame with one row per attempt (x, y, z, roll, pitch, yaw,
        success) plus constant 'gripper' and 'object' columns.
    """
    # Replace any existing connection (e.g. a GUI opened earlier).
    if p.isConnected():
        p.disconnect()

    p.connect(p.GUI if visualize else p.DIRECT)
    try:
        p.setAdditionalSearchPath(pybullet_data.getDataPath())
        p.resetSimulation()
        p.setGravity(0, 0, -10)
        p.setRealTimeSimulation(0)
        p.loadURDF("plane.urdf")

        # Create object and gripper
        obj = object_class(object_pos, p.getQuaternionFromEuler([0, 0, 0]))
        gripper = gripper_class([object_pos[0], object_pos[1], 0.4], [0, 0, 0, 1])

        # Create sampler and simulator
        sampler = GraspSampler(object_position=object_pos, **sampler_params)
        simulator = GraspSimulator(gripper, obj, lift_height=0.25, hold_time=1.0)

        # Generate poses and collect data
        poses = sampler.sample_poses(num_samples, noise_std=0.005)

        print(f"\nCollecting {num_samples} samples for {gripper_name} + {object_name}...")
        dataset = simulator.run_data_collection(poses, reset_between=True, visualize=visualize)
    finally:
        # Release the physics server even if a grasp attempt raised.
        if p.isConnected():
            p.disconnect()

    # Add gripper and object labels
    dataset['gripper'] = gripper_name
    dataset['object'] = object_name

    # An empty run yields a DataFrame without a 'success' column.
    success_count = dataset['success'].sum() if 'success' in dataset else 0
    success_pct = 100 * success_count / num_samples if num_samples else 0.0
    print(f"  → {success_count}/{num_samples} successful ({success_pct:.1f}%)")

    return dataset

# =============================================================================
# RUN FULL DATA COLLECTION
# =============================================================================

print("="*60)
print("STARTING FULL DATASET COLLECTION")
print("="*60)

# Configuration for each combination
# Each tuple is unpacked and passed positionally to collect_dataset. The z of
# object_pos is the object's spawn height; sampler_params become GraspSampler
# keyword arguments.
configs = [
    # (GripperClass, gripper_name, ObjectClass, object_name, object_pos, sampler_params)
    (PR2Gripper, "PR2", Cube, "Cube", [0.6, 0.3, 0.025], {'radius': 0.08, 'height_offset': 0.03}),
    (PR2Gripper, "PR2", Cylinder, "Cylinder", [0.6, 0.3, 0.05], {'radius': 0.06, 'height_offset': 0.02}),
    (PandaGripper, "Panda", Cube, "Cube", [0.6, 0.3, 0.025], {'radius': 0.08, 'height_offset': 0.03}),
    (PandaGripper, "Panda", Cylinder, "Cylinder", [0.6, 0.3, 0.05], {'radius': 0.06, 'height_offset': 0.02}),
]

# Collect data for all combinations
all_datasets = []
samples_per_combo = 35  # 35 x 4 = 140 total

# visualize=False, so collect_dataset connects in DIRECT (headless) mode and
# every combination runs in its own fresh physics session.
for gripper_cls, g_name, obj_cls, o_name, o_pos, s_params in configs:
    df = collect_dataset(gripper_cls, g_name, obj_cls, o_name, o_pos, s_params, 
                         num_samples=samples_per_combo, visualize=False)
    all_datasets.append(df)

# Combine all datasets
full_dataset = pd.concat(all_datasets, ignore_index=True)

print("\n" + "="*60)
print("DATASET COLLECTION COMPLETE")
print("="*60)
print(f"\nTotal samples: {len(full_dataset)}")
# 'success' holds 0/1 ints, so sum() counts successes and mean() is the success rate.
print(f"Successful grasps: {full_dataset['success'].sum()}")
print(f"Failed grasps: {len(full_dataset) - full_dataset['success'].sum()}")
print(f"\nDataset shape: {full_dataset.shape}")
print("\nSamples per combination:")
print(full_dataset.groupby(['gripper', 'object'])['success'].agg(['count', 'sum', 'mean']))

# =============================================================================
# CLASSIFIER TRAINING
# =============================================================================

print("\n" + "="*60)
print("TRAINING CLASSIFIER")
print("="*60)

# Prepare features and labels. Only the grasp pose is a feature; the
# gripper/object identity columns are not used.
feature_cols = ['x', 'y', 'z', 'roll', 'pitch', 'yaw']
X = full_dataset[feature_cols].values
y = full_dataset['success'].values

# Hold out 25% as a test split that is NOT used for model selection.
# Stratify so both splits keep the same success/failure ratio.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42, stratify=y)

print(f"\nTraining set: {len(X_train)} samples")
print(f"Test set: {len(X_test)} samples")

# Scale features, fitting on the training split only. `scaler` is reused by
# the testing phase to transform new grasp poses.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Candidate classifiers
classifiers = {
    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42),
    'SVM (RBF)': SVC(kernel='rbf', C=1.0, gamma='scale', random_state=42),
}

# Select the model by stratified k-fold cross-validation on the TRAINING split.
# Choosing by test-set accuracy would leak the test set into the choice and
# bias the reported accuracy upward.
_, class_counts = np.unique(y_train, return_counts=True)
n_folds = int(min(5, class_counts.min())) if len(class_counts) > 1 else 0
if n_folds < 2:
    raise ValueError(
        "Need at least two successes and two failures in the training split "
        f"for cross-validation; got class counts {class_counts.tolist()}."
    )

best_clf = None
best_name = ""
best_cv_acc = -np.inf

for name, clf in classifiers.items():
    # Scale inside each fold so no fold's score sees its own held-out statistics.
    # cross_val_score clones the pipeline, so `clf` itself stays unfitted here.
    cv_scores = cross_val_score(make_pipeline(StandardScaler(), clf),
                                X_train, y_train, cv=n_folds, scoring='accuracy')
    cv_acc = cv_scores.mean()
    print(f"\n{name}: {n_folds}-fold CV accuracy = {cv_acc*100:.1f}% "
          f"(± {cv_scores.std()*100:.1f}%)")

    if cv_acc > best_cv_acc:
        best_cv_acc = cv_acc
        best_clf = clf
        best_name = name

# Refit the selected model on the full (scaled) training split and evaluate
# it exactly once on the untouched test split.
best_clf.fit(X_train_scaled, y_train)
y_pred = best_clf.predict(X_test_scaled)
best_acc = accuracy_score(y_test, y_pred)

print(f"\n→ Best classifier: {best_name} "
      f"(CV {best_cv_acc*100:.1f}%, held-out test {best_acc*100:.1f}%)")
print(classification_report(y_test, y_pred, zero_division=0))

# =============================================================================
# TESTING PHASE
# =============================================================================

# Banner for the evaluation stage. NOTE(review): the "10" is hard-coded here
# and must match the n_tests value passed to test_classifier further below.
print("\n" + "="*60)
print("TESTING PHASE: 10 new grasps per combination")
print("="*60)

def test_classifier(gripper_class, gripper_name, object_class, object_name, 
                    object_pos, sampler_params, classifier, scaler, n_tests=10):
    """
    Test classifier predictions on freshly sampled grasps.

    Builds a new headless (DIRECT) scene for the gripper/object pair and
    samples ``n_tests`` new poses. For each pose it predicts success with
    ``classifier`` (features scaled by ``scaler``), executes the grasp, and
    compares the prediction with the simulated outcome. The physics session
    is always disconnected afterwards.

    NOTE(review): the ``test_`` prefix means pytest would collect this as a
    test if the code were ever moved into a module on a test path.

    Args:
        gripper_name, object_name: Unused here; kept for a call signature
            parallel to ``collect_dataset``.
        classifier: Fitted estimator whose ``predict`` takes scaled
            [x, y, z, roll, pitch, yaw] rows.
        scaler: Fitted scaler used on the training features.
        n_tests: Number of grasps to sample and execute.

    Returns:
        (accuracy, results): accuracy over the executed attempts (0.0 if none)
        and a list of {'predicted', 'actual', 'correct'} dicts.
    """
    if p.isConnected():
        p.disconnect()

    p.connect(p.DIRECT)
    results = []
    try:
        p.setAdditionalSearchPath(pybullet_data.getDataPath())
        p.resetSimulation()
        p.setGravity(0, 0, -10)
        p.setRealTimeSimulation(0)
        p.loadURDF("plane.urdf")

        obj = object_class(object_pos, p.getQuaternionFromEuler([0, 0, 0]))
        gripper = gripper_class([object_pos[0], object_pos[1], 0.4], [0, 0, 0, 1])

        sampler = GraspSampler(object_position=object_pos, **sampler_params)
        simulator = GraspSimulator(gripper, obj, lift_height=0.25, hold_time=1.0)

        # Generate NEW test poses
        test_poses = sampler.sample_poses(n_tests, noise_std=0.005)

        for pose in test_poses:
            # Same feature order as training: x, y, z, roll, pitch, yaw
            features = np.array([list(pose['position']) + list(pose['orientation'])])
            features_scaled = scaler.transform(features)

            predicted = int(classifier.predict(features_scaled)[0])

            # Execute grasp to get actual result
            actual = int(simulator.execute_grasp(pose, visualize=False))
            simulator.reset_scene()

            results.append({
                'predicted': predicted,
                'actual': actual,
                'correct': predicted == actual
            })
    finally:
        # Release the physics server even if an attempt raised.
        if p.isConnected():
            p.disconnect()

    correct_predictions = sum(r['correct'] for r in results)
    accuracy = correct_predictions / len(results) if results else 0.0
    return accuracy, results

# Run tests for all combinations
test_results = []

for gripper_cls, g_name, obj_cls, o_name, o_pos, s_params in configs:
    print(f"\nTesting {g_name} + {o_name}...")
    acc, results = test_classifier(gripper_cls, g_name, obj_cls, o_name, o_pos, s_params,
                                   best_clf, scaler, n_tests=10)
    n_correct = sum(r['correct'] for r in results)
    test_results.append({
        'Gripper': g_name,
        'Object': o_name,
        'Prediction Accuracy': f"{acc*100:.1f}%",
        'Correct': n_correct,
        'Total': len(results)
    })
    # Report the number of attempts actually run rather than assuming 10.
    print(f"  → Prediction accuracy: {acc*100:.1f}% ({n_correct}/{len(results)})")

# Final summary
print("\n" + "="*60)
print("FINAL TEST RESULTS")
print("="*60)
test_df = pd.DataFrame(test_results)
print(test_df.to_string(index=False))

overall_correct = sum(r['Correct'] for r in test_results)
overall_total = sum(r['Total'] for r in test_results)
overall_pct = 100 * overall_correct / overall_total if overall_total else 0.0
print(f"\nOverall prediction accuracy: {overall_correct}/{overall_total} = {overall_pct:.1f}%")



STARTING FULL DATASET COLLECTION

Collecting 35 samples for PR2 + Cube...
Grasp attempt 1/35
  Result: SUCCESS
Grasp attempt 2/35
  Result: FAILURE
Grasp attempt 3/35
  Result: SUCCESS
Grasp attempt 4/35
  Result: SUCCESS
Grasp attempt 5/35
  Result: FAILURE
Grasp attempt 6/35
  Result: SUCCESS
Grasp attempt 7/35
  Result: SUCCESS
Grasp attempt 8/35
  Result: SUCCESS
Grasp attempt 9/35
  Result: SUCCESS
Grasp attempt 10/35
  Result: FAILURE
Grasp attempt 11/35
  Result: FAILURE
Grasp attempt 12/35
  Result: SUCCESS
Grasp attempt 13/35
  Result: SUCCESS
Grasp attempt 14/35
  Result: SUCCESS
Grasp attempt 15/35
  Result: SUCCESS
Grasp attempt 16/35
  Result: SUCCESS
Grasp attempt 17/35
  Result: FAILURE
Grasp attempt 18/35
  Result: FAILURE
Grasp attempt 19/35
  Result: FAILURE
Grasp attempt 20/35
  Result: SUCCESS
Grasp attempt 21/35
  Result: SUCCESS
Grasp attempt 22/35
  Result: FAILURE
Grasp attempt 23/35
  Result: SUCCESS
Grasp attempt 24/35
  Result: FAILURE
Grasp attempt 25/35
  R