In [None]:
import torch
import onnxruntime as ort
import numpy as np

class ONNXPolicy:
    """Callable wrapper around an ONNX Runtime session for policy inference.

    Only the first declared input and the first declared output of the ONNX
    graph are used.
    """

    def __init__(self, onnx_path: str, device: str = "cpu"):
        """
        Initialize ONNX policy

        Args:
            onnx_path: Path to the .onnx policy file
            device: Device to run inference on ("cpu" or "cuda")
        """
        self.device = device

        # ONNX Runtime treats the provider list as a priority order, so the
        # CPU provider is kept after CUDA when "cuda" is requested.
        if device == "cuda":
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            providers = ['CPUExecutionProvider']
        self.session = ort.InferenceSession(onnx_path, providers=providers)

        # Cache the tensor names once; they are needed on every inference call.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

        print(f"Policy loaded from {onnx_path}")
        print(f"Input: {self.input_name}, Output: {self.output_name}")

    def __call__(self, observation: np.ndarray) -> np.ndarray:
        """
        Run policy inference

        Args:
            observation: Observation array of shape (42,) or (batch, 42).
                Any array-like is accepted and converted to float32.

        Returns:
            Action array of shape (10,) for a 1-D observation, or (batch, 10)
            for a 2-D observation (including a batch of size 1).
        """
        observation = np.asarray(observation, dtype=np.float32)

        # Remember whether the caller passed a single observation so that only
        # that case is un-batched on return. Deciding from the output's batch
        # size would also squeeze a genuine (1, features) batch.
        single_obs = observation.ndim == 1
        if single_obs:
            observation = observation.reshape(1, -1)

        # Run inference
        action = self.session.run([self.output_name], {self.input_name: observation})[0]

        return action[0] if single_obs else action

# Usage Example
if __name__ == "__main__":
    # Smoke test: build the ONNX policy, feed it one random 42-feature
    # observation, and report the shapes and values that come back.
    policy = ONNXPolicy(onnx_path="/path/to/policy.onnx", device="cuda")
    obs = np.random.randn(42).astype(np.float32)

    action = policy(obs)
    print(
        f"Observation shape: {obs.shape}",
        f"Action shape: {action.shape}",
        f"Action values: {action}",
        sep="\n",
    )

Policy loaded from /home/sachin/IsaacLab/logs/rsl_rl/rexai_direct/2025-10-04_20-00-30/exported/policy.onnx
Input: obs, Output: actions
Observation shape: (42,)
Action shape: (10,)
Action values: [-1.4317504   0.64814943 -0.4952938   1.8832533   1.7667413  -0.8991956
 -2.835862    3.8421986  -3.004331   -6.2876744 ]


In [None]:
import torch
import torch.nn as nn
import numpy as np
from typing import Union, Tuple

class ActorCriticPolicy(nn.Module):
    """Actor-Critic policy for Torch.

    Holds two independent MLPs with ELU activations: ``actor`` maps an
    observation to ``num_actions`` outputs and ``critic`` maps it to a single
    value. Both share the same hidden layer widths.
    """

    def __init__(self, num_obs: int = 42, num_actions: int = 10,
                 hidden_dims: Tuple[int, ...] = (400, 200, 100)):
        """
        Args:
            num_obs: Number of observation features (input width)
            num_actions: Number of actions (actor output width)
            hidden_dims: Widths of the hidden layers used by both networks.
                The default reproduces the 400-200-100 architecture.
        """
        super().__init__()

        # Actor is built before the critic so parameter creation order (and
        # therefore random initialization order) stays the same as before.
        self.actor = self._build_mlp(num_obs, hidden_dims, num_actions)
        self.critic = self._build_mlp(num_obs, hidden_dims, 1)

    @staticmethod
    def _build_mlp(in_dim: int, hidden_dims: Tuple[int, ...], out_dim: int) -> nn.Sequential:
        """Build Linear -> ELU blocks for each hidden width, then a final Linear."""
        layers = []
        prev_dim = in_dim
        for width in hidden_dims:
            layers.append(nn.Linear(prev_dim, width))
            layers.append(nn.ELU(alpha=1.0))
            prev_dim = width
        layers.append(nn.Linear(prev_dim, out_dim))
        return nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Forward pass - returns the raw (unbounded) actor output"""
        return self.actor(obs)

    def act(self, obs: torch.Tensor, deterministic: bool = True) -> torch.Tensor:
        """Get action from observation.

        Note: ``deterministic`` is accepted for interface compatibility but is
        currently ignored; the result is always tanh of the actor output.
        """
        action = self.actor(obs)
        return torch.tanh(action)  # Apply tanh for bounded actions

    def evaluate(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get both the raw actor output (no tanh) and the critic value estimate"""
        action = self.actor(obs)
        value = self.critic(obs)
        return action, value


class PolicyWrapper:
    """Load a policy checkpoint and expose it as a NumPy-in / NumPy-out callable.

    The checkpoint is first tried as a TorchScript archive; if that fails it is
    loaded as a state_dict into an ``ActorCriticPolicy``.
    """

    # Keys probed, in order, for a state_dict nested inside a checkpoint dict.
    # If none is present the checkpoint dict itself is treated as the state_dict.
    _STATE_DICT_KEYS = ('model_state_dict', 'actor_critic', 'state_dict')

    def __init__(self, checkpoint_path: str, num_obs: int = 42, num_actions: int = 10, 
                 device: str = "cpu", action_scale: float = 1.0):
        """
        Initialize policy from checkpoint

        Args:
            checkpoint_path: Path to .pt checkpoint file
            num_obs: Number of observation features
            num_actions: Number of actions
            device: Device to run on ("cpu" or "cuda")
            action_scale: Scaling factor for actions
        """
        self.device = torch.device(device)
        self.action_scale = action_scale
        self.num_obs = num_obs
        self.num_actions = num_actions
        # Set to True by load_checkpoint() when the file is a TorchScript archive.
        self.is_torchscript = False

        self.model = self.load_checkpoint(checkpoint_path)
        self.model.eval()

        print(f"Policy loaded from {checkpoint_path}")
        print(f"Model on device: {self.device}")
        print(f"Num observations: {num_obs}, Num actions: {num_actions}")
        print(f"TorchScript model: {self.is_torchscript}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load model weights from checkpoint - handles both TorchScript and state_dict

        Args:
            checkpoint_path: Path to the checkpoint file

        Returns:
            The loaded model (a TorchScript module or an ActorCriticPolicy)

        Raises:
            ValueError: If the file is not TorchScript and does not contain a dict
        """
        try:
            model = torch.jit.load(checkpoint_path, map_location=self.device)
            self.is_torchscript = True
            print("Loaded as TorchScript model")
            return model
        except Exception as e:
            # Deliberate best-effort: any TorchScript load failure falls
            # through to the state_dict path below.
            print(f"Not a TorchScript model, trying as state_dict: {e}")

        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        if not isinstance(checkpoint, dict):
            raise ValueError("Unsupported checkpoint format")

        model = ActorCriticPolicy(self.num_obs, self.num_actions).to(self.device)

        state_dict = checkpoint
        for key in self._STATE_DICT_KEYS:
            if key in checkpoint:
                state_dict = checkpoint[key]
                break
        model.load_state_dict(state_dict)

        print("Loaded as state_dict")
        return model

    def _prepare_obs(self, observation: Union[np.ndarray, torch.Tensor]) -> Tuple[torch.Tensor, bool]:
        """Convert an observation to a batched float32 tensor on self.device.

        Returns:
            (obs_tensor, single_obs) where single_obs is True when the input
            was 1-D and a batch dimension was added.
        """
        obs_tensor = torch.as_tensor(observation, dtype=torch.float32)

        single_obs = obs_tensor.ndim == 1
        if single_obs:
            obs_tensor = obs_tensor.unsqueeze(0)

        # Extra trailing features are dropped. An observation with fewer than
        # num_obs features is passed through unchanged and left for the model
        # to reject.
        if obs_tensor.shape[1] != self.num_obs:
            obs_tensor = obs_tensor[:, :self.num_obs]

        return obs_tensor.to(self.device), single_obs

    def __call__(self, observation: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
        """
        Get action from observation
        Args:
            observation: Observation of shape (num_obs,) or (batch, num_obs)
        Returns:
            Action array with the batch dimension removed for a 1-D
            observation and kept for a 2-D one; the action width is whatever
            the loaded model outputs (expected to be num_actions).
        """
        obs_tensor, single_obs = self._prepare_obs(observation)

        # Get action
        with torch.no_grad():
            if self.is_torchscript:
                # TorchScript models are called directly; their output is used as-is.
                action = self.model(obs_tensor)
            else:
                action = self.model.act(obs_tensor, deterministic=True)

            action = action * self.action_scale

        action = action.cpu().numpy()

        if single_obs:
            return action[0]

        return action

    def get_value(self, observation: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
        """Get value estimate for observation (only for non-TorchScript models)

        Args:
            observation: Observation of shape (num_obs,) or (batch, num_obs)

        Returns:
            A scalar for a 1-D observation, or an array of shape (batch,)

        Raises:
            NotImplementedError: If the loaded model is TorchScript
        """
        if self.is_torchscript:
            raise NotImplementedError("Value estimation not available for TorchScript models. "
                                     "TorchScript model only exports the actor network.")

        obs_tensor, single_obs = self._prepare_obs(observation)

        with torch.no_grad():
            value = self.model.critic(obs_tensor)

        value = value.cpu().numpy()

        if single_obs:
            return value[0, 0]

        return value.squeeze(-1)


# Usage Example
if __name__ == "__main__":
    # Placeholder path: point this at an exported TorchScript policy or a
    # state_dict checkpoint before running.
    policy = PolicyWrapper(
        checkpoint_path="/path/to/policy.pt",
        num_obs=42,
        num_actions=10,
        device="cuda",
        action_scale=1.0
    )

    # Single observation: a 1-D input should yield a 1-D action.
    obs = np.random.randn(42).astype(np.float32)
    action = policy(obs)

    print("\nSingle Observation Test:")
    print(f"Observation shape: {obs.shape}")
    print(f"Action shape: {action.shape}")
    print(f"Action values: {action}")
    print(f"Action range: [{action.min():.3f}, {action.max():.3f}]")

    # Batched observations: the leading batch dimension should be preserved.
    batch_obs = np.random.randn(4, 42).astype(np.float32)
    batch_actions = policy(batch_obs)

    print("\nBatch Observation Test:")
    print(f"Batch observation shape: {batch_obs.shape}")
    print(f"Batch action shape: {batch_actions.shape}")

    # Print model info
    print(f"\nModel type: {'TorchScript' if policy.is_torchscript else 'PyTorch'}")


Loaded as TorchScript model
Policy loaded from /home/sachin/IsaacLab/logs/rsl_rl/rexai_direct/2025-10-04_20-00-30/exported/policy.pt
Model on device: cuda
Num observations: 42, Num actions: 10
TorchScript model: True

Single Observation Test:
Observation shape: (42,)
Action shape: (10,)
Action values: [ -2.6916864    1.2083182   -0.17800936   2.9344132    3.8472185
  -1.2177778   -1.5107269    1.573474    -7.824546   -12.670531  ]
Action range: [-12.671, 3.847]

Batch Observation Test:
Batch observation shape: (4, 42)
Batch action shape: (4, 10)

Model type: TorchScript
