In [None]:
import sys

# Directory added to the module search path so that project-local modules can be
# imported further down (presumably where Approach_env lives — confirm).
path_to_add = '/home/jin/SRC-gym/gym-env/Hierachical_Learning_v2'
# Guard the append so that re-running this cell does not keep growing sys.path
# with duplicate entries.
if path_to_add not in sys.path:
    sys.path.append(path_to_add)

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split
import cv2
from cv_bridge import CvBridge
import rospy
from sensor_msgs.msg import Image as RosImage
import gymnasium as gym
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.utils import set_random_seed
from Approach_env import SRC_approach
from torchvision.models import resnet50
import time
import clip

# Device string used for every model / tensor placement in this notebook:
# the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Behaviour-cloning policy: frozen image features + proprioception -> action
class BehaviorCloningModel(nn.Module):
    """Regress an action from image features concatenated with proprioceptive data.

    The feature extractor is always run under ``torch.no_grad()``, so only the
    ``regressor`` head can receive gradients.

    Args:
        feature_extractor: Module mapping an image batch to per-sample features.
            Its output is flattened to ``(batch, feature_dim)``.
        feature_dim: Width of the flattened feature vector (default 2048).
        proprio_dim: Width of the proprioceptive vector (default 7).
        action_dim: Number of action components produced (default 7).
        hidden_dim: Width of the hidden layer of the regressor (default 256).

    The defaults reproduce the original fixed layout, so the layer indices and
    parameter shapes inside ``regressor`` (and therefore the state-dict keys)
    are unchanged when the extra arguments are omitted.
    """

    def __init__(self, feature_extractor, feature_dim=2048, proprio_dim=7,
                 action_dim=7, hidden_dim=256):
        super().__init__()
        self.feature_extractor = feature_extractor
        input_dim = feature_dim + proprio_dim
        self.regressor = nn.Sequential(
            nn.BatchNorm1d(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh()  # Apply Tanh to restrict output range to (-1, 1)
        )

    def forward(self, x, proprioceptive_data):
        """Return the predicted action batch.

        Args:
            x: Input batch handed unchanged to ``feature_extractor``.
            proprioceptive_data: Tensor of shape ``(batch, proprio_dim)``.

        Returns:
            Tensor of shape ``(batch, action_dim)`` with values in (-1, 1).
        """
        with torch.no_grad():
            features = self.feature_extractor(x)
            # reshape (rather than view) also accepts non-contiguous extractor
            # output; for contiguous tensors the result is identical.
            features = features.reshape(features.size(0), -1)
        combined_input = torch.cat((features, proprioceptive_data), dim=1)
        return self.regressor(combined_input)

def load_imgnet_pretrained_model():
    """Create a ResNet-50 with pretrained weights and its classifier head removed.

    Returns:
        The ResNet-50 module whose ``fc`` layer has been replaced by
        ``nn.Identity``, so a forward pass yields whatever used to be fed
        into the classifier.
    """
    backbone = resnet50(pretrained=True)
    # Swap the final fully connected layer for a pass-through.
    backbone.fc = nn.Identity()
    return backbone

# Load the trained imgnet model
def load_imgnet_model():
    """Return the pretrained feature extractor switched to evaluation mode."""
    backbone = load_imgnet_pretrained_model()
    # Module.eval() returns the module itself, so it can be returned directly.
    return backbone.eval()

# Build the evaluation-mode feature extractor once; it is handed to
# load_trained_model further down.
imgnet_model = load_imgnet_model()

# Function to load the trained model
def load_trained_model(model_path, imgnet_model):
    """Rebuild the behaviour-cloning policy and restore its saved weights.

    Args:
        model_path: Path of the checkpoint holding the model's state dict.
        imgnet_model: Feature extractor wrapped by ``BehaviorCloningModel``.

    Returns:
        The policy placed on ``device`` and switched to evaluation mode.
    """
    policy = BehaviorCloningModel(imgnet_model)
    policy = policy.to(device)
    # NOTE(review): torch.load unpickles the file — only load checkpoints
    # that come from a trusted source.
    saved_weights = torch.load(model_path, map_location=device)
    policy.load_state_dict(saved_weights)
    return policy.eval()

# Function to predict action using the loaded model
_PREDICT_ACTION_TRANSFORM = None  # preprocessing pipeline, created on first use


def _get_predict_action_transform():
    """Return the shared image preprocessing pipeline, building it on first call."""
    global _PREDICT_ACTION_TRANSFORM
    if _PREDICT_ACTION_TRANSFORM is None:
        _PREDICT_ACTION_TRANSFORM = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    return _PREDICT_ACTION_TRANSFORM


def predict_action(model, image_np, proprio_data):
    """Run one inference step of the policy on a single camera frame.

    Args:
        model: Callable taking ``(image_batch, proprio_batch)`` tensors.
        image_np: Image array accepted by ``transforms.ToPILImage``.
        proprio_data: Sequence/array of proprioceptive values; converted to a
            float32 tensor.

    Returns:
        The model output for the batch of one, as a NumPy array on the CPU
        (the leading batch dimension is kept).

    NOTE(review): image_callback decodes frames as "bgr8", while the mean/std
    used here are the usual ImageNet RGB statistics — confirm this channel
    order matches how the training images were preprocessed.
    """
    # The pipeline is identical for every frame, so it is built once and
    # reused instead of being reconstructed on every control step.
    transform = _get_predict_action_transform()
    image = transform(image_np).unsqueeze(0).to(device)
    proprioceptive_tensor = torch.tensor(proprio_data, dtype=torch.float32).unsqueeze(0).to(device)

    with torch.no_grad():
        predicted_action = model(image, proprioceptive_tensor)
    return predicted_action.cpu().numpy()

# Shared state and settings for the camera subscribers.
# NOTE(review): no rospy.init_node() call is visible in this notebook —
# presumably the node is initialised elsewhere (e.g. by the environment); confirm.
current_images = {}  # camera id -> most recent frame stored by image_callback
image_received = {}  # camera id -> True once image_callback has stored a frame
bridge = CvBridge()
view_name = 'front'      # front, back
task_name = 'Approach'  # Approach, Place, Insert, Regrasp, Pullout
algor_name = 'imgnet'     # imgnet only for this script

def image_callback(msg, camera_id):
    """Store the newest frame of one camera and flag that it has arrived.

    Args:
        msg: Incoming ROS image message.
        camera_id: Key under which the frame and its flag are stored.

    Conversion failures are logged through ``rospy.logerr`` and otherwise
    ignored, leaving the previously stored frame untouched.
    """
    # Both dictionaries are only mutated, never rebound, so no `global`
    # declaration is required.
    try:
        frame = bridge.imgmsg_to_cv2(msg, "bgr8")
        current_images[camera_id] = frame
        image_received[camera_id] = True
    except Exception as e:
        rospy.logerr(f"Failed to convert image from {camera_id}: {e}")

# Image topic(s) to subscribe to for each supported camera view.
_CAMERA_TOPICS_BY_VIEW = {
    'front': {
        'front': '/ambf/env/cameras/cameraL/ImageData'
    },
    'back': {
        'back': '/ambf/env/cameras/normal_camera/ImageData'
    },
}

# Fail immediately with a clear message instead of hitting a NameError on
# `camera_topics` when view_name is misspelled or unsupported.
if view_name not in _CAMERA_TOPICS_BY_VIEW:
    raise ValueError(
        f"Unsupported view_name {view_name!r}; expected one of {sorted(_CAMERA_TOPICS_BY_VIEW)}"
    )
camera_topics = _CAMERA_TOPICS_BY_VIEW[view_name]

for cam_id, topic in camera_topics.items():
    # Clear the flag before subscribing: the callback may fire as soon as the
    # subscriber exists, and its True must not be overwritten afterwards.
    image_received[cam_id] = False
    rospy.Subscriber(topic, RosImage, image_callback, callback_args=cam_id)

def wait_for_images():
    """Block until every camera has flagged a frame, then clear all flags.

    Polls at 100 Hz and also stops waiting once ROS reports shutdown. The
    flags are reset on the way out in either case.
    """
    poll_rate = rospy.Rate(100)
    while True:
        if all(image_received.values()) or rospy.is_shutdown():
            break
        poll_rate.sleep()
    # Re-arm every flag for the next wait.
    for cam_id in image_received:
        image_received[cam_id] = False

# Locate the checkpoint for the selected task / algorithm / camera view and
# load the behaviour-cloning policy from it.
model_path = f'/home/jin/SRC-gym/gym-env/Hierachical_Learning_v2/SRC_img_data/{task_name}/{algor_name}/{view_name}/model_final.pth'
model = load_trained_model(model_path, imgnet_model)

In [None]:
# Seed through stable_baselines3's set_random_seed; the same seed is also
# passed to the environment below.
seed = 60
set_random_seed(seed)

# Episode length and per-component step sizes handed to the environment.
max_episode_steps=500
trans_step = 0.05e-2  # Trans unit in m
angle_step = np.deg2rad(2)
jaw_step = 0.05
# NOTE(review): `threshold` is not referenced again in the visible cells;
# `threshold_expert` is the value actually passed to gym.make.
threshold = [0.3,np.deg2rad(10)]   # Trans unit in cm

# Seven entries: three translation steps, three angle steps, one jaw step.
step_size = np.array([trans_step,trans_step,trans_step,angle_step,angle_step,angle_step,jaw_step],dtype=np.float32) 
####################

threshold_expert = [0.1,np.deg2rad(5)] 
# Register SRC_approach under an id and instantiate it.
# NOTE(review): the id says "sparse" while reward_type is "dense" — confirm
# which reward is intended. Re-running this cell registers the same id again.
gym.envs.register(id="TD3_HER_sparse", entry_point=SRC_approach, max_episode_steps=max_episode_steps)
env = gym.make("TD3_HER_sparse", render_mode=None,reward_type = "dense",seed = seed, threshold = threshold_expert,max_episode_step=max_episode_steps,step_size=step_size)

In [None]:
# Reset once up front; the returned observation is discarded here. The rollout
# loop resets the environment again at the start of every episode.
env.reset()

In [None]:
# Roll out the behaviour-cloning policy for a few episodes.
num_episodes = 3
for episode in range(num_episodes):
    obs,_ = env.reset()
    time.sleep(0.5)  # pause after the reset before reading camera frames
    for timestep in range(max_episode_steps):
        wait_for_images()  # Wait until an image is received
        # The first seven observation entries are fed to the policy as
        # proprioceptive input.
        proprio_data = obs["observation"][0:7]
        # predict_action returns a batch of one; squeeze drops the batch axis.
        action = predict_action(model, current_images[view_name], proprio_data).squeeze()
        # Gymnasium's step returns (obs, reward, terminated, truncated, info).
        # The truncation flag used to be discarded, so a truncated episode
        # kept being stepped; it now ends the episode as well.
        next_obs, reward, done, truncated, info = env.step(action)
        time.sleep(0.01)
        obs = next_obs
        if done:
            # Step index at which the episode terminated.
            print(timestep)
        if done or truncated:
            break