In [None]:
import gymnasium as gym
# from EasyEnv import myEasyGym
from Approach_env import SRC_approach as SRC_test
import numpy as np
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import PPO
from CL_env import CurriculumWrapper
# Register the custom environment under a Gymnasium id, capped at 2000 steps per episode
ENV_ID = "Training_ppo_rand_needle"
gym.envs.register(
    id=ENV_ID,
    entry_point=SRC_test,
    max_episode_steps=2000,
)

# Instantiate it with render_mode="human" and wrap it into curriculum learning mode
env = CurriculumWrapper(gym.make(ENV_ID, render_mode="human"))

import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn

In [None]:
# Check the environment: run Stable-Baselines3's env checker on the
# curriculum-wrapped env before any training is started.
check_env(env)

In [None]:
# Reset the environment; as the last expression of the cell, the value
# returned by reset() is displayed as the cell output.
env.reset()

In [None]:
# Initialize the model: PPO with an MLP policy on the wrapped env,
# logging to TensorBoard under ./First_expert_insert/
model = PPO(
    policy="MlpPolicy",
    env=env,
    verbose=1,
    tensorboard_log="./First_expert_insert/",
)

In [None]:
# Pretrained behavior cloning (optional): warm-start the PPO policy by
# supervised learning on expert (observation, action) pairs before RL training.
import pickle

# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only load expert_data.pkl if it comes from a trusted source.
with open('expert_data.pkl', 'rb') as f:
    expert_data = pickle.load(f)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Policy = model.policy.to(device)

# expert_data is unpacked as a sequence of (observation, action) pairs
observations, actions = zip(*expert_data)
observations = np.array(observations)
actions = np.array(actions)

observations_tensor = torch.tensor(observations, dtype=torch.float32)
# Actions are class indices for CrossEntropyLoss, hence the integer dtype
actions_tensor = torch.tensor(actions, dtype=torch.long)

dataset = TensorDataset(observations_tensor, actions_tensor)
# shuffle=True draws a fresh permutation every time the loader is iterated,
# so the loader is built once and never needs to be re-created between epochs.
data_loader = DataLoader(dataset, batch_size=1024, shuffle=True)

optimizer = Adam(model.policy.parameters(), lr=2e-4)
criterion = nn.CrossEntropyLoss()
# Number of action dimensions; one categorical head is trained per dimension.
# Loop-invariant, so it is computed once up front.
action_len = env.action_space.shape[0]
batch_idx = 0

for epoch in range(100000):
    loss_avg = 0.0
    batch_idx = 0
    for batch_obs, batch_actions in data_loader:
        batch_idx += 1
        batch_obs = batch_obs.to(device)
        batch_actions = batch_actions.to(device)
        dist = Policy.get_distribution(batch_obs)

        # Total loss = sum of the cross-entropy losses of every action dimension
        loss = sum(
            criterion(dist.distribution[i].logits, batch_actions[:, i])
            for i in range(action_len)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a plain Python float: adding the tensor itself would keep
        # autograd history alive across iterations and steadily grow memory.
        loss_value = loss.item()
        loss_avg += loss_value

        if (batch_idx%200 == 0):
            print(
                "Train Epoch: {} Batch idx: {} \t Loss: {:.6f}\n".format(
                    epoch,
                    batch_idx,
                    loss_value,
                )
            )
    print(f"Average Loss in {epoch} episode is {loss_avg/batch_idx}\n")



In [None]:
# First-time training.
# Under the curriculum, task difficulty is raised gradually each time the agent
# satisfies the previous error tolerance.
checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path='./First_end_effector/Model_temp2',
    name_prefix='SRC',
)
model.learn(
    total_timesteps=2_000_000,
    callback=checkpoint_callback,
    progress_bar=True,
)
# Persist the final model next to the notebook
model.save("SRC")

In [None]:
# Continue training from a previously saved model.
checkpoint_callback = CheckpointCallback(save_freq=10000, save_path='./First_expert_demo/Model_temp', name_prefix='SRC')
model_path = "./Expert.zip"
# Load the saved model with the env and logging settings attached in one step.
# Building a fresh PPO(...) first and then overwriting it with PPO.load()
# discards the fresh model's verbose/tensorboard_log settings, because load()
# restores the values stored in the saved file; passing them as keyword
# arguments overrides the stored values instead.
model = PPO.load(
    model_path,
    env=env,
    verbose=1,
    tensorboard_log="./First_expert_demo/",
)
# reset_num_timesteps=False continues the existing timestep counter instead of
# restarting it from zero for this learn() call.
model.learn(total_timesteps=int(1000000), progress_bar=True,callback=checkpoint_callback,reset_num_timesteps=False)
model.save("./First_RL_expert_insert/"+"rl_model_final")

In [None]:
# Roll out the current model deterministically for 10000 steps, rendering each one
obs, info = env.reset()
print(obs)
for i in range(10000):
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    # Keep stepping while the episode is still running
    if not (terminated or truncated):
        continue
    # Episode finished (terminated or truncated): start a new one
    obs, info = env.reset()

In [None]:
def low_pass_filter(prev_action, new_action, alpha=0.3):
    """
    Blend a new action with the previous one (first-order low pass filter).

    prev_action: the previous action
    new_action: the newly proposed action
    alpha: smooth factor, i.e. the weight given to new_action;
           prev_action is weighted by (1 - alpha)

    Returns the weighted sum of the two actions.
    """
    new_weight = alpha
    prev_weight = 1 - alpha
    return new_weight * new_action + prev_weight * prev_action

# Deterministic rollout in which every predicted action is smoothed with
# low_pass_filter against the previously applied action before stepping the env.
# NOTE(review): the filter returns float-weighted blends; confirm that the env's
# action space accepts non-integer actions.
obs, info = env.reset()
prev_action = None  # None means "no previous action yet" (start of an episode)

for i in range(10000):
    current_action, _state = model.predict(obs, deterministic=True)

    # The first action of an episode is passed through unfiltered
    filtered_action = (
        current_action
        if prev_action is None
        else low_pass_filter(prev_action, current_action)
    )
    prev_action = filtered_action
    print(filtered_action)

    obs, reward, terminated, truncated, info = env.step(filtered_action)
    print(info)
    env.render()

    if not (terminated or truncated):
        continue
    # Episode ended: reset both the env and the filter state
    obs, info = env.reset()
    prev_action = None
