In [None]:
import os
import pickle
import time

import gymnasium as gym
import numpy as np
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.utils import set_random_seed
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from High_level_env import SRC_high_level
from RL_algo.PPO import PPO

# Global seed for python / numpy / torch (SB3 helper). Gymnasium spaces keep their own
# generator, so the action space is seeded separately below.
seed = 10
set_random_seed(seed)

# Maximum number of high-level steps per episode (gym.make wraps the env in a TimeLimit).
episode_steps = 50

gym.envs.register(id="high_level", entry_point=SRC_high_level, max_episode_steps=episode_steps)
env = gym.make("high_level")
# action_space.sample() draws from the space's own RNG, which set_random_seed does not touch;
# seed it so random rollouts are reproducible.
env.action_space.seed(seed)

In [None]:
# Seed the environment's own RNG on the first reset; later unseeded resets continue
# from this seeded stream, making episodes reproducible across kernel restarts.
env.reset(seed=seed)

In [None]:
# Smoke test: drive the environment with uniformly random high-level actions and print
# the low-level task that is active before every step.
N_RANDOM_STEPS = 20000

env.reset()
for _ in range(N_RANDOM_STEPS):
    action = env.action_space.sample()
    # `low_env` is an attribute of the raw SRC_high_level env. Reach it through
    # `unwrapped`: attribute forwarding through gymnasium wrappers is deprecated (0.29)
    # and removed (1.0).
    print(env.unwrapped.low_env.task)
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        time.sleep(1.0)  # presumably to make episode boundaries visible while watching
        obs, info = env.reset()

In [None]:
# Validate the (wrapped) environment against the Gym API that SB3 expects.
check_env(env=env)

In [None]:
# Location of the pickled expert trajectories. Override it with the EXPERT_DATA_PATH
# environment variable instead of editing this machine-specific default.
EXPERT_DATA_PATH = os.environ.get(
    "EXPERT_DATA_PATH",
    "/home/jin/6d_pose_collection_scripts/Trajectory_complete.pkl",
)

# SECURITY: pickle.load can execute arbitrary code -- only load files you created or trust.
with open(EXPERT_DATA_PATH, 'rb') as f:
    expert_data = pickle.load(f)

# Data Preprocessing
# Most consecutive expert steps repeat the same high-level action, so only the steps
# inside a fixed window after each action switch are kept for behaviour cloning.
CHANGE_WINDOW = 50  # number of steps kept after each action change


def select_change_windows(observations, actions, window=CHANGE_WINDOW):
    """Keep the (observation, action) pairs within `window` steps after each action change.

    A change happens at index i (i >= 1) when ``actions[i]`` differs from ``actions[i - 1]``.
    Every kept observation stays paired with the expert action recorded at that same step,
    and each step is kept at most once even when windows overlap. Duplicates would
    otherwise end up on both sides of a later train/test split.

    Args:
        observations: array whose first axis is time, aligned with `actions`.
        actions: expert actions, first axis aligned with `observations`. For arrays with
            more than one dimension, a step counts as a change when any component differs.
        window: number of consecutive steps kept, starting at each change index.

    Returns:
        ``(selected_observations, selected_actions)`` in chronological order.

    Raises:
        ValueError: if `observations` and `actions` have different lengths.
    """
    observations = np.asarray(observations)
    actions = np.asarray(actions)
    if len(observations) != len(actions):
        raise ValueError(
            f"observations ({len(observations)}) and actions ({len(actions)}) are not aligned"
        )

    switched = actions[1:] != actions[:-1]
    if switched.ndim > 1:
        # One flag per step, so multi-column actions do not yield duplicate indices.
        switched = switched.reshape(len(switched), -1).any(axis=1)
    change_indices = np.flatnonzero(switched) + 1

    keep = np.zeros(len(actions), dtype=bool)
    for index in change_indices:
        keep[index:index + window] = True  # slicing clips the window at the end of the data
    return observations[keep], actions[keep]


# Flatten all expert trajectories. traj.obs is trimmed by one entry to align it with
# traj.acts (presumably the trailing final observation -- confirm against the data format).
observations = np.concatenate([traj.obs[:-1] for traj in expert_data], axis=0)
actions = np.concatenate([traj.acts for traj in expert_data], axis=0)
# NOTE(review): trajectories are concatenated, so a switch between the last action of one
# trajectory and the first action of the next also counts as a change.

changed_observations, changed_actions = select_change_windows(observations, actions)
assert len(changed_observations) > 0, "no action changes found in the expert data"

print("Changed Observations: ", changed_observations.shape)
print("Changed Actions: ", changed_actions.shape)

observations = changed_observations
actions = changed_actions


In [None]:
# First training
# First training run: PPO with an MLP policy. gamma=0.7 gives a very short effective
# horizon (about 1 / (1 - gamma) ~ 3 high-level steps).
model = PPO(
    "MlpPolicy",
    env,
    gamma=0.7,
    verbose=1,
    tensorboard_log="./High_level",
)

In [None]:
def train(model, device, train_loader, optimizer, criterion, epoch, writer):
    """Run one behaviour-cloning epoch over `train_loader` and log the mean loss.

    Args:
        model: policy exposing ``get_distribution(obs)`` whose result has
            ``.distribution.logits`` (e.g. an SB3 categorical action distribution).
        device: torch device every batch is moved to.
        train_loader: iterable of ``(observations, target)`` batches. The class label is
            column 0 of `target` when `target` is 2-D, otherwise `target` itself.
        optimizer: optimizer over `model`'s parameters.
        criterion: loss taking ``(logits, class_indices)``. It is assumed to use mean
            reduction (the default of ``nn.CrossEntropyLoss``).
        epoch: epoch index, printed and used as the TensorBoard step.
        writer: object with ``add_scalar(tag, value, step)``, e.g. a SummaryWriter.

    Returns:
        The sample-weighted mean training loss of the epoch.

    Raises:
        ValueError: if `train_loader` yields no samples.
    """
    model.train()
    weighted_loss = 0.0
    num_samples = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        labels = target[:, 0] if target.dim() > 1 else target
        optimizer.zero_grad()
        logits = model.get_distribution(data).distribution.logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Weight each batch mean by its size so a short final batch does not skew the epoch mean.
        batch_size = labels.size(0)
        weighted_loss += loss.item() * batch_size
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError("train_loader yielded no samples")
    train_loss = weighted_loss / num_samples
    print(f'Epoch: {epoch}, Average Loss: {train_loss:.4f}')
    writer.add_scalar("Loss/train", train_loss, epoch)
    return train_loss

def test(model, device, test_loader, criterion, epoch, writer):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            dist = model.get_distribution(data)
            logits = dist.distribution.logits
            loss = criterion(logits, target[:, 0])
            test_loss += loss.item()
            # 计算准确率
            pred = logits.argmax(dim=1, keepdim=True)  # 获取最大概率的索引作为预测结果
            correct += pred.eq(target[:, 0].view_as(pred)).sum().item()
            total += target.size(0)
    
    accuracy = 100. * correct / total 
    test_loss /= len(test_loader)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{total} ({accuracy:.2f}%)\n')
    writer.add_scalar("Loss/test", test_loss, epoch)
    writer.add_scalar("Accuracy/test", accuracy, epoch)

def pretrain(model, device, train_loader, test_loader, optimizer, criterion, epochs, writer):
    """Behaviour-clone `model` for `epochs` epochs, evaluating after every epoch.

    Each epoch runs `train` on `train_loader`, then `test` on `test_loader`. Both log
    to `writer` with the epoch index as the step.
    """
    for epoch_idx in range(epochs):
        train(model, device, train_loader, optimizer, criterion, epoch=epoch_idx, writer=writer)
        test(model, device, test_loader, criterion, epoch=epoch_idx, writer=writer)

In [None]:
# 设置TensorBoard
# TensorBoard writer for the behaviour-cloning (pretraining) curves
writer = SummaryWriter(log_dir="./High_level/HLC/BC_log")

# Pretraining hyperparameters
BC_BATCH_SIZE = 1024
BC_LEARNING_RATE = 1e-4
BC_EPOCHS = 10000

# Hold out 20% of the selected expert samples for evaluation
train_observations, test_observations, train_actions, test_actions = train_test_split(
    observations,
    actions,
    test_size=0.2,
    random_state=42,
)

# Data loaders (observations as float32, action labels as int64 class indices)
train_loader = DataLoader(
    TensorDataset(
        torch.tensor(train_observations, dtype=torch.float32),
        torch.tensor(train_actions, dtype=torch.long),
    ),
    batch_size=BC_BATCH_SIZE,
    shuffle=True,
)
test_loader = DataLoader(
    TensorDataset(
        torch.tensor(test_observations, dtype=torch.float32),
        torch.tensor(test_actions, dtype=torch.long),
    ),
    batch_size=BC_BATCH_SIZE,
    shuffle=False,
)

# Pretrain on the device the PPO model already uses, so the later model.learn() and
# model.predict() calls see the policy where they expect it.
# (Assumes RL_algo.PPO follows the SB3 BaseAlgorithm API with `.device` and `.policy`.)
device = model.device
Model_Policy = model.policy.to(device)
optimizer = Adam(Model_Policy.parameters(), lr=BC_LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

# Start pretraining
pretrain(Model_Policy, device, train_loader, test_loader, optimizer, criterion, BC_EPOCHS, writer)
# Make sure all pending TensorBoard events are written to disk
writer.flush()

In [None]:
# Fine-tune the behaviour-cloned policy with PPO
model.learn(total_timesteps=100_000)

In [None]:
# Persist the pretrained + fine-tuned model
PRETRAIN_MODEL_PATH = './High_level/HLC/Pretrain_final'
model.save(PRETRAIN_MODEL_PATH)

In [None]:
# Sanity check: compare the policy's action on a training sample with its action on the
# latest live observation, and print both observations to check that they look alike.
# NOTE(review): `obs` is whatever the most recently executed rollout cell left behind.
train_sample = train_observations[0, :]
train_sample_action, _ = model.predict(train_sample, deterministic=True)
live_action, _ = model.predict(obs, deterministic=True)
print("action on train sample:", train_sample_action)
print("action on live obs:", live_action)
print(obs)
print(train_sample)

In [None]:
# Reload the pretrained model. The environment is not stored in the saved file, so it must
# be passed here; otherwise the following model.learn() call has no env to train on.
model = PPO.load('./High_level/HLC/Pretrain_final', env=env)

In [None]:
# Continue PPO training and save a checkpoint every 10 000 callback calls
# (one call per env step with a single environment).
checkpoint_callback = CheckpointCallback(
    save_freq=10000,
    save_path='./High_level/HLC',
    name_prefix='HLC',
)
model.learn(
    total_timesteps=1_000_000,
    callback=checkpoint_callback,
    progress_bar=True,
)

In [None]:
# Persist the final trained model
FINAL_MODEL_PATH = './High_level/HLC/model_final'
model.save(FINAL_MODEL_PATH)

In [None]:
# Predict the action with model
# Roll out the trained policy greedily, printing each chosen action and rendering every step.
obs, info = env.reset()
print(obs)
for _step in range(10000):
    action, _ = model.predict(obs, deterministic=True)
    print(action)
    obs, reward, terminated, truncated, info = env.step(action)
    env.render()
    episode_over = terminated or truncated
    if episode_over:
        obs, info = env.reset()

In [None]:
# Norm of components 14..20 of the low-level env's "observation" entry (presumably the
# relative vector inspected during rollouts -- TODO confirm). Accessed via `unwrapped`
# because attribute forwarding through gymnasium wrappers is deprecated (0.29) and removed (1.0).
np.linalg.norm(env.unwrapped.low_env.obs["observation"][14:21])