In [None]:
import torch
import torch.nn as nn
import numpy as np
import time                                           
from env import InvertedPendulumSerialEnv
from agent import ActorNet, QNet, SACTrainer, ReplayBuffer, weight_to_tensor, ActorB, ActorSU, CriticB, CriticSU

In [None]:
# Pick the compute device: CUDA when it is available, CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Swing-up policy: construct, move to the selected device, load its
# pretrained weights, then put it in evaluation mode for inference.
actor_su = ActorSU(obs_dim=6, act_dim=1, max_action=1.75)
actor_su = actor_su.to(device)
actor_su.load_pretrained_weights()
actor_su.eval()

# Balance policy: same construction path, but its weights come from the
# saved model file 'sac_model_b.pth' (load_pretrained_weights() is not used here).
actor_b = ActorB(obs_dim=6, act_dim=1, max_action=5.0)
actor_b = actor_b.to(device)
actor_b.load_model('sac_model_b.pth')
actor_b.eval()

In [None]:
# Open the serial-backed pendulum environment in 'balance' mode.
env_b = InvertedPendulumSerialEnv(
    port='COM12',
    baudrate=921600,
    mode='balance',
)

# First reset: keep the initial observation, discard the second return value.
obs, _ = env_b.reset()

In [None]:
# Closed-loop inference: each step picks the swing-up or balance policy from the
# current observation, sends the resulting action to the environment, and stops
# after MAX_STEPS steps or as soon as the environment reports done.
MAX_STEPS = 1_000_000
ALPHA_INDEX = 2  # assuming obs[2] is alpha (compared in radians below) - TODO confirm against env
BALANCE_ALPHA_THRESHOLD = float(np.deg2rad(155))  # |alpha| at or above this -> balance policy
STEP_PAUSE_S = 0.01  # pause between steps to give the ESP32 time to process (10 ms)

# NOTE(review): interrupting this cell stops the loop without sending a final
# command; consider whether a zero action should be sent on interrupt - confirm
# against the env/firmware behaviour.
for step_idx in range(MAX_STEPS):
    # 1) Create observation tensor on device, with a leading batch dimension
    obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)

    # 2) Select action
    if not torch.isfinite(obs_tensor).all():
        # Observation contains NaN/Inf: command zero instead of querying a policy
        action_np = np.zeros(1, dtype=np.float32)
    else:
        alpha_val = obs_tensor[0, ALPHA_INDEX].item()
        # Balance policy at or beyond the threshold angle, swing-up policy otherwise
        policy = actor_b if abs(alpha_val) >= BALANCE_ALPHA_THRESHOLD else actor_su
        with torch.no_grad():  # inference only, no autograd graph needed
            action_tensor, _, _ = policy(obs_tensor, deterministic=True)
        action_np = action_tensor.cpu().numpy().flatten()

    # 3) Step the environment
    next_obs, reward, terminated, truncated, _ = env_b.step(action_np)
    done = terminated or truncated

    # 4) Update obs for next loop
    obs = next_obs

    # 5) Stop if environment signals done (reset first so `obs` holds a fresh
    #    observation if this cell is run again)
    if done:
        print("Environment returned done. Breaking out.")
        obs, _ = env_b.reset()
        break

    # 6) Pace the loop
    time.sleep(STEP_PAUSE_S)
