In [None]:
import os

# Rendering / device configuration. These are set before mujoco and torch are
# imported in the cells below (presumably they are read at import / CUDA init).
os.environ.update({'MUJOCO_GL': 'egl', 'CUDA_VISIBLE_DEVICES': '1'})

# Relative paths used later (output/..., collision_scene/...) assume the working
# directory is the folder that contains `notebooks`; step up one level otherwise.
if 'notebooks' not in os.listdir():
    os.chdir('../')

from collision.utils import DummyCam, ImageDemoDataset, generate_camera, put_pose_into_mujoco, update_reconstruction_dict, get_normalized_function
from utils.mujoco_utils import compute_camera_extrinsic_matrix
from scene.cameras import Camera_Pose
from collision.chain_utils import build_chain_relation_map
from collision.network import SingleNetwork, HyperNetwork
from contextlib import redirect_stdout
from video_api import initialize_gaussians

import cv2
from gaussian_renderer import render
import sys 
import torch 
from PIL import Image
import numpy as np
import mujoco
import matplotlib.pyplot as plt
import torch.nn.functional as F
from tqdm import tqdm, trange
from transformers import CLIPProcessor, CLIPModel
from IPython.display import display, clear_output
from torchvision.transforms import transforms


from pathlib import Path
from itertools import cycle

In [5]:
# Load the UR5e collision scene and allocate its simulation state.
model_xml_path = Path("collision_scene") / "universal_robots_ur5e_scene2" / "scene.xml"
model = mujoco.MjModel.from_xml_path(model_xml_path.as_posix())
data = mujoco.MjData(model)

# Start from the model's default state.
mujoco.mj_resetData(model, data)


def sample_collision_pose():
    """Draw a random joint configuration and load it into the global MuJoCo state.

    Each joint angle is sampled uniformly between the lower and upper bound in
    ``model.jnt_range``. The sample is written into the module-level ``data``
    via ``put_pose_into_mujoco``.

    NOTE(review): samples one value per row of ``jnt_range`` (i.e. per joint),
    which presumably matches the length expected by ``put_pose_into_mujoco``.

    Returns:
        numpy.ndarray: the sampled joint values.
    """
    lower, upper = model.jnt_range.T
    sampled = np.random.uniform(lower, upper)
    put_pose_into_mujoco(model, data, sampled)
    return sampled

In [6]:
# Camera setup: a single DummyCam looking at the origin.
# NOTE(review): DummyCam's positional args (0, -45.0, 2.5) look like
# azimuth / elevation / distance — confirm against collision.utils.
dummy_cams = [DummyCam(0, -45.0, 2.5, lookat=[0, 0, 0])]
cams = list(map(generate_camera, dummy_cams))

# 480x480 offscreen renderer, initially bound to the first camera.
renderer = mujoco.Renderer(model, height=480, width=480)
renderer.update_scene(data, camera=cams[0])

In [None]:
# Rejection-sample random joint configurations until exactly TARGET_NCON
# contacts are reported, then render that configuration. `pose` is the start
# point for the optimisation cells below.
# NOTE(review): assumes put_pose_into_mujoco refreshes data.ncon — confirm.
TARGET_NCON = 10
MAX_SAMPLE_ATTEMPTS = 100_000  # bound the search instead of looping forever

mujoco.mj_resetData(model, data)

for _ in range(MAX_SAMPLE_ATTEMPTS):
    pose = sample_collision_pose()
    if data.ncon == TARGET_NCON:
        break
else:
    raise RuntimeError(
        f"No pose with exactly {TARGET_NCON} contacts after {MAX_SAMPLE_ATTEMPTS} samples"
    )

renderer.update_scene(data, camera=cams[0])
pixels = renderer.render()
image = Image.fromarray(pixels)
image

In [57]:
relation_map, chain = build_chain_relation_map(model_xml_path.as_posix())

# Primary SDF network (checkpoint `sdf_net.ckpt`), frozen and moved to the GPU.
sdf_model = HyperNetwork(chain.n_joints, relation_map)
sdf_model.load_state_dict(
    torch.load('output/universal_robots_ur5e_robotiq/sdf_net.ckpt', weights_only=True)
)
# Freeze every weight so backward() only produces gradients for joint angles.
sdf_model.requires_grad_(False)
sdf_model = sdf_model.cuda()

In [58]:
# Second SDF network from the `sdf_net_wo_eik` checkpoint (presumably trained
# without the eikonal loss — confirm). Frozen and moved to the GPU like `sdf_model`.
sdf_we_model = HyperNetwork(chain.n_joints, relation_map)
state_dict = torch.load('output/universal_robots_ur5e_robotiq/sdf_net_wo_eik.ckpt', weights_only=True)
sdf_we_model.load_state_dict(state_dict)
# Freeze *this* network's weights (not sdf_model's) so that backward() through
# it only produces gradients for the joint angles being optimised.
for parameters in sdf_we_model.parameters():
    parameters.requires_grad_(False)
sdf_we_model.cuda()
del state_dict

In [59]:
def get_p(sdf_m, joint_angles):
    """Collision-probability-style score for one unbatched joint configuration.

    Args:
        sdf_m: callable returning a ``(sdf, s)`` pair for a batched input.
        joint_angles: a single configuration. A leading batch axis of size 1
            is added before it is passed to ``sdf_m``.

    Returns:
        ``torch.sigmoid(sdf * s)``, still carrying the batch axis.
    """
    batched = joint_angles[None, ...]
    sdf_value, scale = sdf_m(batched)
    logits = sdf_value * scale
    return torch.sigmoid(logits)


In [None]:
# Score of the sampled start pose under both SDF networks. `joint_angles` is
# built here so the cell does not depend on a later cell having run first.
joint_angles = torch.tensor(pose, dtype=torch.float32).cuda()
get_p(sdf_model, joint_angles), get_p(sdf_we_model, joint_angles)

In [None]:
# Raw (sdf, s) outputs of both networks at the sampled start pose; `joint_angles`
# is built here so the cell does not depend on a later cell having run first.
joint_angles = torch.tensor(pose, dtype=torch.float32).cuda()
sdf_model(joint_angles[None]), sdf_we_model(joint_angles[None])

In [62]:
import time

In [None]:
# Gradient descent on the joint angles (Adam) until sdf_model's output drops
# below SDF_TARGET or SDF_MAX_ITERS steps have been taken.
SDF_TARGET = -0.100
SDF_MAX_ITERS = 1000
SDF_LR = 0.01

joint_angles = torch.tensor(pose, dtype=torch.float32).cuda()
# nn.Parameter shares storage with its input; clone so the optimiser does not
# silently overwrite `joint_angles` (the start pose) in place.
action_t = torch.nn.Parameter(joint_angles.clone(), requires_grad=True)
optimize = torch.optim.Adam([action_t], lr=SDF_LR)

start = time.perf_counter()  # monotonic clock for interval timing
for i in range(SDF_MAX_ITERS):
    sdf, s = sdf_model(action_t[None])
    if sdf < SDF_TARGET:
        break
    optimize.zero_grad()
    sdf.backward()
    optimize.step()
torch.cuda.synchronize()  # wait for queued GPU work before stopping the timer
elapsed = time.perf_counter() - start
print(f"SDF:{sdf.item():.4f} T:{elapsed:.4f}")

In [None]:
# Gradient of sdf_model's output with respect to the current joint angles.
# backward() accumulates into .grad, so drop the gradient left over from the
# optimisation loop first; otherwise the displayed value is a sum of two grads.
action_t.grad = None
sdf, s = sdf_model(action_t[None])
sdf.backward()
action_t.grad

In [None]:
# Load the optimised joint angles into MuJoCo, count contacts, and render.
data.qpos[:] = action_t.detach().cpu().numpy()
# mj_forward recomputes kinematics and collisions for this exact qpos without
# integrating dynamics. mj_step would advance qpos by a timestep, and a
# following mj_collision would then run on stale kinematics.
mujoco.mj_forward(model, data)

print(data.ncon)
renderer.update_scene(data, camera=cams[0])
pixels = renderer.render()
image = Image.fromarray(pixels)
image

In [None]:
# Gradient descent on the joint angles (Adam) until get_p under sdf_we_model
# drops below P_TARGET or P_MAX_ITERS steps have been taken.
P_TARGET = 0.5
P_MAX_ITERS = 1000
P_LR = 0.01

joint_angles = torch.tensor(pose, dtype=torch.float32).cuda()
# nn.Parameter shares storage with its input; clone so the optimiser does not
# silently overwrite `joint_angles` (the start pose) in place.
action_t = torch.nn.Parameter(joint_angles.clone(), requires_grad=True)
optimize = torch.optim.Adam([action_t], lr=P_LR)

start = time.perf_counter()  # monotonic clock for interval timing
tbar = trange(P_MAX_ITERS)
for i in tbar:
    p = get_p(sdf_we_model, action_t)
    if p < P_TARGET:
        break
    optimize.zero_grad()
    p.backward()
    optimize.step()
    tbar.set_postfix({
        "p": format(p.item(), '.4f'),
    })
tbar.close()  # finalise the progress bar even when we break early
torch.cuda.synchronize()  # wait for queued GPU work before stopping the timer
elapsed = time.perf_counter() - start
print(f"p:{p.item():.4f} T:{elapsed:.4f}")

In [None]:
# Score of the current joint angles under each SDF network, as plain floats.
tuple(get_p(net, action_t).item() for net in (sdf_model, sdf_we_model))

In [None]:
# Gradient stored on the joint angles by the most recent backward() call.
last_grad = action_t.grad
last_grad

In [74]:
# NOTE(review): this cell previously ran `temp.backward()`, but `temp` is never
# defined in the notebook, so it always raised NameError on a fresh kernel.
# Presumed intent (confirm): gradient of sdf_we_model's score at the current
# joint angles. Stale gradients are cleared first because backward() accumulates.
action_t.grad = None
p_final = get_p(sdf_we_model, action_t)
p_final.backward()
action_t.grad

In [None]:
# Load the optimised joint angles into MuJoCo, count contacts, and render.
data.qpos[:] = action_t.detach().cpu().numpy()
# mj_forward recomputes kinematics and collisions for this exact qpos without
# integrating dynamics. mj_step would advance qpos by a timestep, and a
# following mj_collision would then run on stale kinematics.
mujoco.mj_forward(model, data)

print(data.ncon)
renderer.update_scene(data, camera=cams[0])
pixels = renderer.render()
image = Image.fromarray(pixels)
image