In [None]:

import os
import sys
import warnings

import numpy as np
import torch
import mediapy as media
import gymnasium as gym
import mujoco
from torch.utils.tensorboard import SummaryWriter
from mujoco_playground import registry
from mujoco_playground._src.wrapper_torch import RSLRLBraxWrapper
from agent import Agent  # Assuming you have your PPO agent here


In [None]:

# Configure rendering and XLA before anything initializes them.
# MuJoCo's Python bindings choose their OpenGL backend from MUJOCO_GL when
# `mujoco` is first imported. Setting it here therefore only takes effect if this
# runs before that import. Prefer exporting MUJOCO_GL in the shell (or moving
# this to the very top of the imports cell). We warn when it is set too late.
if "MUJOCO_GL" not in os.environ:
    os.environ["MUJOCO_GL"] = "egl"
    if "mujoco" in sys.modules:
        warnings.warn(
            "MUJOCO_GL was set after `mujoco` was imported; the 'egl' backend "
            "may not take effect. Set MUJOCO_GL before importing mujoco.",
            RuntimeWarning,
            stacklevel=2,
        )

# Append rather than overwrite, so XLA flags the user already exported survive.
# The membership check keeps re-running this cell idempotent.
# NOTE(review): XLA_FLAGS is presumably read when the JAX/XLA backend
# initializes, so this must run before the first JAX computation.
_TRITON_FLAG = "--xla_gpu_triton_gemm_any=True"
_xla_flags = os.environ.get("XLA_FLAGS", "")
if _TRITON_FLAG not in _xla_flags.split():
    os.environ["XLA_FLAGS"] = f"{_xla_flags} {_TRITON_FLAG}".strip()

# Load the environment and its default config from the MuJoCo Playground registry.
env_name = 'PandaPickCubeOrientation'
env = registry.load(env_name)
env_cfg = registry.get_default_config(env_name)

print("Loaded MuJoCo Playground environment:", env_name)
print("Env config:", env_cfg)


In [None]:

# Set up the compute device, the vectorized environments, and the PPO agent.
# NOTE(review): `make_playground_vec_env` is not defined or imported anywhere in
# the visible notebook. On a fresh kernel this cell raises NameError unless the
# factory is defined in an earlier cell. Presumably it wraps `registry.load`
# (e.g. via RSLRLBraxWrapper) — TODO confirm.
NUM_TRAIN_ENVS, TRAIN_SEED = 2048, 42
NUM_EVAL_ENVS, EVAL_SEED = 4, 43

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Large batch for training; a small, separately seeded batch for evaluation/rendering.
envs = make_playground_vec_env(num_envs=NUM_TRAIN_ENVS, seed=TRAIN_SEED)
eval_envs = make_playground_vec_env(num_envs=NUM_EVAL_ENVS, seed=EVAL_SEED)

agent = Agent(envs)
agent = agent.to(device)

# Deterministic rollout helper used to evaluate the agent and record a video.
def rollout(agent, envs, num_steps=200, device=None, seed=42):
    """Run ``agent`` deterministically in ``envs`` and collect rendered frames.

    The agent is put in eval mode for the rollout. Its previous train/eval mode
    is restored afterwards, even if an exception is raised, so calling this
    mid-training does not leave the module stuck in eval mode.

    Args:
        agent: ``torch.nn.Module`` exposing ``get_action(obs, deterministic=True)``.
        envs: vectorized env with a Gymnasium-style API. ``reset(seed=...)``
            returns ``(obs, info)``, ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)``, and ``render()``
            returns a frame.
        num_steps: maximum number of environment steps to take.
        device: device that observations are moved to before reaching the agent.
            Defaults to the device of the agent's first parameter, or CPU if the
            agent has no parameters.
        seed: seed passed to ``envs.reset``.

    Returns:
        list: one ``envs.render()`` result per step taken. The rollout stops
        after the first step in which any sub-environment is terminated or
        truncated.
    """
    if device is None:
        first_param = next(agent.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = agent.training
    agent.eval()
    frames = []
    try:
        next_obs, _ = envs.reset(seed=seed)
        with torch.no_grad():
            for _ in range(num_steps):
                # Move every observation, not only the initial one, to the agent's device.
                action = agent.get_action(next_obs.to(device), deterministic=True)
                next_obs, _, terminated, truncated, _ = envs.step(action)
                frames.append(envs.render())
                # Stop when any sub-env ends, by termination or by time-limit truncation.
                if terminated.any() or truncated.any():
                    break
    finally:
        agent.train(was_training)
    return frames

# Record a deterministic rollout on the small evaluation batch and save it as a video.
# Use `eval_envs`, which was created for evaluation, rather than the
# 2048-env training batch `envs`; rendering that batch every step is likely to
# be very expensive.
# NOTE(review): fps=30 is hard-coded. For real-time playback it should
# presumably match the env's control rate (e.g. 1 / env_cfg.ctrl_dt) — TODO confirm.
frames = rollout(agent, eval_envs)
media.write_video('rollout_video.mp4', frames, fps=30)
