In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import onnxruntime as ort
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel
import collections
import os

# ======================================================================================
# --- ‚öôÔ∏è 1. CONFIGURATION: PLEASE EDIT THESE VALUES ---
# ======================================================================================

# --- Environment and Model Paths ---
ENV_PATH = "../Env/DroneFlightv1.exe" # or "builds/YourEnv.x86_64"

# --- PASTE YOUR ABSOLUTE PATH HERE using r"..." ---
# This tells Python to treat backslashes as normal characters, fixing the SyntaxError.
MODEL_PATH = r"C:\Users\Fede\Desktop\MasterThesis\PPO_Unity\results\PPO\Drone\Drone-4499957.onnx"

# --- Data Collection and Training Settings ---
STEPS_TO_COLLECT = 5000
REPLAY_BUFFER_CAPACITY = 10000
LEARNING_RATE = 3e-4
BATCH_SIZE = 64
TRAINING_STEPS = 2000

# --- ONNX Model Input/Output Names (pre-filled from your inspection output) ---
OBS_INPUT_NAMES = ["obs_0", "obs_1", "obs_2", "obs_3"]
ACTION_OUTPUT_NAME = "deterministic_continuous_actions"
RECURRENT_IN_NAME = "recurrent_in"
RECURRENT_OUT_NAME = "recurrent_out"

# ======================================================================================
# --- Path Validation ---
# ======================================================================================

# Check if the model file exists at the absolute path
if not os.path.exists(MODEL_PATH):
    print(f"\n‚ö†Ô∏è WARNING: Model file not found at the specified path!")
    print(f"Path given: {MODEL_PATH}")
else:
    print(f"‚úÖ Model path is valid: {MODEL_PATH}")

‚úÖ Model path is valid: C:\Users\Fede\Desktop\MasterThesis\PPO_Unity\results\PPO\Drone\Drone-4499957.onnx


In [2]:
# ======================================================================================
# --- üîé 2. INSPECT THE ONNX MODEL ---
# ======================================================================================

# This cell is for verification. You've already provided the output, so the names in Cell 1 are correct.

if MODEL_PATH and os.path.exists(MODEL_PATH):
    try:
        print(f"Inspecting model: {MODEL_PATH}\n")
        session = ort.InferenceSession(MODEL_PATH)

        print("--- Model Inputs ---")
        for i, inp in enumerate(session.get_inputs()):
            print(f"Input {i}: Name: {inp.name}, Shape: {inp.shape}, Type: {inp.type}")

        print("\n--- Model Outputs ---")
        for i, out in enumerate(session.get_outputs()):
            print(f"Output {i}: Name: {out.name}, Shape: {out.shape}, Type: {out.type}")

    except Exception as e:
        print(f"‚ùå Error loading or inspecting the ONNX model: {e}")
else:
    print("Skipping inspection because model path is not valid.")

Inspecting model: C:\Users\Fede\Desktop\MasterThesis\PPO_Unity\results\PPO\Drone\Drone-4499957.onnx

--- Model Inputs ---
Input 0: Name: obs_0, Shape: ['batch', 3, 84, 84], Type: tensor(float)
Input 1: Name: obs_1, Shape: ['batch', 68], Type: tensor(float)
Input 2: Name: obs_2, Shape: ['batch', 4, 84, 84], Type: tensor(float)
Input 3: Name: obs_3, Shape: ['batch', 64], Type: tensor(float)
Input 4: Name: recurrent_in, Shape: ['batch', 1, 256], Type: tensor(float)

--- Model Outputs ---
Output 0: Name: version_number, Shape: [1], Type: tensor(float)
Output 1: Name: memory_size, Shape: [1], Type: tensor(float)
Output 2: Name: continuous_actions, Shape: ['batch', 4], Type: tensor(float)
Output 3: Name: continuous_action_output_shape, Shape: [1], Type: tensor(float)
Output 4: Name: deterministic_continuous_actions, Shape: ['Divdeterministic_continuous_actions_dim_0', 4], Type: tensor(float)
Output 5: Name: recurrent_out, Shape: ['Transposerecurrent_out_dim_0', 1, 256], Type: tensor(float)


In [9]:
# ======================================================================================
# --- üì¶ 3. REPLAY BUFFER AND DATA GATHERING LOGIC (REVISED FOR ACTION HANDLING) ---
# ======================================================================================

class ReplayBuffer:
    """A simple FIFO experience replay buffer."""
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, next_states, dones = zip(*[self.buffer[i] for i in indices])
        flat_states = np.array([np.concatenate([obs.flatten() for obs in s]) for s in states])
        flat_next_states = np.array([np.concatenate([obs.flatten() for obs in ns]) for ns in next_states])
        return flat_states, np.array(actions), np.array(rewards, dtype=np.float32), flat_next_states, np.array(dones, dtype=np.uint8)

    def __len__(self):
        return len(self.buffer)

def gather_experience_recurrent(env, onnx_session, behavior_name, num_steps, spec):
    """Gathers experience from a recurrent, multi-input model."""
    print("\n--- Starting Experience Gathering (Recurrent Mode) ---")
    replay_buffer = ReplayBuffer(REPLAY_BUFFER_CAPACITY)
    
    obs_names = OBS_INPUT_NAMES
    recurrent_in = RECURRENT_IN_NAME
    recurrent_out = RECURRENT_OUT_NAME
    action_out_name = ACTION_OUTPUT_NAME
    
    agent_memory = collections.defaultdict(lambda: np.zeros((1, 256), dtype=np.float32))
    agent_transition_data = {}

    collected_steps = 0
    while collected_steps < num_steps:
        decision_steps, terminal_steps = env.get_steps(behavior_name)
        
        for agent_id in terminal_steps:
            if agent_id in agent_transition_data:
                prev_obs, prev_action = agent_transition_data.pop(agent_id)
                replay_buffer.add(prev_obs, prev_action, terminal_steps[agent_id].reward, terminal_steps[agent_id].obs, True)
                collected_steps += 1
            agent_memory.pop(agent_id, None)

        if len(decision_steps) > 0:
            input_feed = {}
            for i, name in enumerate(obs_names):
                input_feed[name] = decision_steps.obs[i]
            
            agent_ids = decision_steps.agent_id
            
            memory_array_2d = np.vstack([agent_memory[id] for id in agent_ids])
            memory_array_3d = memory_array_2d[:, np.newaxis, :]
            input_feed[recurrent_in] = memory_array_3d
            
            action_output, recurrent_output = onnx_session.run([action_out_name, recurrent_out], input_feed)
            
            for i, agent_id in enumerate(agent_ids):
                agent_memory[agent_id] = recurrent_output[i, :, :]

            for agent_id in decision_steps:
                if agent_id in agent_transition_data:
                    prev_obs, prev_action = agent_transition_data.pop(agent_id)
                    replay_buffer.add(prev_obs, prev_action, decision_steps[agent_id].reward, decision_steps[agent_id].obs, False)
                    collected_steps += 1
                    if collected_steps % 250 == 0:
                        print(f"Collected {collected_steps}/{num_steps} steps... Buffer size: {len(replay_buffer)}")

            for i, agent_id in enumerate(agent_ids):
                agent_transition_data[agent_id] = ([obs[i] for obs in decision_steps.obs], action_output[i])

            action_dict = {agent_id: action_output[i] for i, agent_id in enumerate(agent_ids)}
            
            # ======================= THIS IS THE FIX =========================
            # Pass the raw numpy array of actions directly. 
            # The set_actions method is smart enough to handle this.
            # env.set_actions(behavior_name, action_output)
            # =================================================================

        env.step(action_dict)
    
    print(f"\n‚úÖ Finished gathering experience. Replay buffer has {len(replay_buffer)} entries.")
    return replay_buffer

In [10]:
# ======================================================================================
# --- üß† 4. CUSTOM PYTORCH MODEL AND TRAINING LOOP (REVISED) ---
# ======================================================================================

class CustomActor(nn.Module):
    """A simple MLP policy network for behavioral cloning."""
    def __init__(self, state_dim, action_dim):
        super(CustomActor, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, 256), # Increased size for complex state
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim)
        )

    def forward(self, state):
        return self.network(state)

def train_custom_model(replay_buffer, state_dim, action_dim):
    """Initializes and trains the custom model using behavioral cloning."""
    print("\n--- Starting Custom Model Training ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    custom_model = CustomActor(state_dim, action_dim).to(device)
    optimizer = optim.Adam(custom_model.parameters(), lr=LEARNING_RATE)
    loss_fn = nn.MSELoss()

    for step in range(TRAINING_STEPS):
        # The replay buffer now returns flattened states automatically
        states, actions, _, _, _ = replay_buffer.sample(BATCH_SIZE)

        states_t = torch.tensor(states, dtype=torch.float32).to(device)
        expert_actions_t = torch.tensor(actions, dtype=torch.float32).to(device)

        predicted_actions_t = custom_model(states_t)
        loss = loss_fn(predicted_actions_t, expert_actions_t)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % 100 == 0:
            print(f"Training Step: {step + 1}/{TRAINING_STEPS} | MSE Loss: {loss.item():.6f}")

    print("\n‚úÖ Custom model training complete.")
    return custom_model

In [11]:
# ======================================================================================
# --- ‚ñ∂Ô∏è 5. EXECUTE THE WORKFLOW (REVISED) ---
# ======================================================================================

# 1. Launch Environment and Load ONNX Model
if MODEL_PATH:
    engine_channel = EngineConfigurationChannel()
    env = UnityEnvironment(file_name=ENV_PATH, side_channels=[engine_channel], worker_id=np.random.randint(0, 100))
    engine_channel.set_configuration_parameters(time_scale=20.0)
    env.reset()
    
    behavior_name = list(env.behavior_specs.keys())[0]
    spec = env.behavior_specs[behavior_name]
    
    onnx_session = ort.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
    
    # 2. Gather expert experience
    experience_buffer = gather_experience_recurrent(env, onnx_session, behavior_name, STEPS_TO_COLLECT, spec)
    
    env.close()

    # 3. Train the custom model
    if experience_buffer and len(experience_buffer) > BATCH_SIZE:
        # Get state and action dimensions by sampling the buffer
        sample_flat_state, sample_action, _, _, _ = experience_buffer.sample(1)
        state_dim = sample_flat_state.shape[1]
        action_dim = sample_action.shape[1]
        
        print(f"\nDetected Flattened State Dimension: {state_dim}")
        print(f"Detected Action Dimension: {action_dim}")

        trained_model = train_custom_model(experience_buffer, state_dim, action_dim)
        
        # Save the trained student model
        torch.save(trained_model.state_dict(), "my_custom_drone_model.pth")
        print("\nCustom bootstrapped model saved to 'my_custom_drone_model.pth'")
    else:
        print("\nSkipping training due to an error or insufficient data.")
else:
    print("Execution skipped because ONNX model path was not found. Please check your config in Cell 1.")


--- Starting Experience Gathering (Recurrent Mode) ---


TypeError: UnityEnvironment.step() takes 1 positional argument but 2 were given