# Start Simulation

In [None]:
# env.clean(clean_all=True)

In [None]:
import random
import numpy as np
import os
import time
from PIL import Image

import torch
# A bare `torch.__version__` in the middle of a cell is evaluated and thrown
# away, so print it to actually show which build is in use.
print(torch.__version__)
# Prefer the GPU, but fall back to the CPU when CUDA is not visible so the
# learning cells can still be exercised on a machine without one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import getpass
# Login name of the account running the notebook; it is interpolated into the
# stage URL in the next cell. Printed so the value can be checked by eye.
user = getpass.getuser()
print(user)

In [None]:
# URL of the table stage under the current user's folder on omniverse://localhost.
usd_path = f"omniverse://localhost/Users/{user}/UVA/uva_table.usd"

In [None]:
from omni.isaac.kit import SimulationApp    
# Start the simulation app headless, passing `usd_path` both as the stage to
# open ("open_usd") and as the "livesync_usd" setting.
simulation_app = SimulationApp({"headless": True, "open_usd": usd_path,  "livesync_usd": usd_path}) 

In [None]:
# set log level
import logging
import carb

# Raise the threshold of the noisy simulator loggers to ERROR, in the same
# order the original one-call-per-line version used.
for logger_name in (
    "omni.hydra",
    "omni.isaac.urdf",
    "omni.physx.plugin",
    "omni.isaac.synthetic_utils",
    "omni.isaac.synthetic_utils.syntheticdata",
    "omni.hydra.scene_delegate.plugin",
):
    logging.getLogger(logger_name).setLevel(logging.ERROR)


# Write carb's ERROR level into the three "/log/..." settings below. The level
# gets a descriptive name (a lone `l` is easily misread as `1`) and the
# settings handle is fetched once instead of once per key.
carb_log_level = carb.logging.LEVEL_ERROR
carb_settings = carb.settings.get_settings()
for setting_path in ("/log/level", "/log/fileLogLevel", "/log/outputStreamLevel"):
    carb_settings.set(setting_path, carb_log_level)

# # This logged error is printed as it should
# carb.log_error("ERROR")
# # This warning is printed but should not
# carb.log_warn("WARNING")

# Config

In [None]:
from task.config import DATA_PATH, FEATURE_PATH
# Task/side selection: passed to ArrangeScene and RenderHelper and used as
# sub-folders of DATA_PATH when images and scene data are saved.
task_type = "Table"
side_choice = "Border"
# Index of the base asset; reassigned by the trajectory loop whenever it picks
# a new base asset.
base_asset_id = 0
# Forwarded as `load_nucleus` to ArrangeScene and Randomizer.
load_nucleus = True

In [None]:
# NOTE(review): `pause` is not defined in any visible cell, so evaluating it
# raises NameError; presumably a deliberate stop for "Run All" — confirm.
pause

# Init Env

In [None]:
from uva_env import UvaEnv
# Environment wrapper; later cells go through env.world, env.stage, env.scene
# and env.rewarder.
env = UvaEnv()

In [None]:
from task.utils import add_scene_default
# Run the project helper from task.utils; presumably it adds the default scene
# content to the stage — see its definition to confirm what it creates.
add_scene_default()

In [None]:
# Dump everything env.stage.TraverseAll() yields, to inspect the stage contents.
print(list(env.stage.TraverseAll()))

In [None]:
# Clean the environment with its default arguments, then advance the world by
# one rendered step.
env.clean()
env.world.step(render=True)

# Scene

In [None]:
from task.scene import ArrangeScene
# Build the arrangement scene for the configured task/side and attach it to the
# environment, so `scene` and `env.scene` refer to the same object.
# NOTE(review): base_asset_id is passed as a literal 0 here rather than the
# `base_asset_id` variable from the Config cell — confirm that is intended.
scene = ArrangeScene(task_type, side_choice, base_asset_id = 0, traj_id = 0, load_nucleus = load_nucleus)
env.scene = scene

In [None]:
# Add the scene's base asset (the trajectory loop calls this again whenever it
# switches to another base asset).
scene.add_base_asset()

In [None]:
# Add the room to the scene.
scene.add_room()

In [None]:
# Advance the world by one rendered step after the scene additions above.
env.world.step(render=True)

In [None]:
# Scene randomizer; the trajectory loop calls randomizer.randomize_house(rand=True)
# each time it swaps the base asset.
from layout.randomizer import Randomizer
randomizer = Randomizer(load_nucleus = load_nucleus)

# Reward

In [None]:
from uv.reward import Rewarder

In [None]:
# Build the rewarder on the environment's world and attach it to the environment.
rewarder = Rewarder(env.world)
env.rewarder = rewarder

# Render

In [None]:
from render.helper import RenderHelper
# Render helper for the configured task/side. Register the "main" and then the
# "side" task cameras before finalising the camera setup.
render = RenderHelper(task_type, side_choice)
for camera_type in ("main", "side"):
    render.add_task_cameras(camera_type=camera_type)
render.set_cameras()

In [None]:
# Display the camera paths currently held by the render helper.
render.camera_paths

In [None]:
# NOTE(review): `pause` is not defined in any visible cell, so evaluating it
# raises NameError; presumably a deliberate stop for "Run All" — confirm.
pause

# Learning

In [None]:
# from learning.network.resnet import ResNetFeatureExtractor

# feature extraction
from learning.utils import extract_image_clip_feature_and_save, obtain_action_from_trainer
from transformers import CLIPProcessor, CLIPModel

# Bind the checkpoint name once so the model and its processor cannot drift apart.
clip_checkpoint = "openai/clip-vit-base-patch32"
feature_model = CLIPModel.from_pretrained(clip_checkpoint).to(device)
feature_processor = CLIPProcessor.from_pretrained(clip_checkpoint)

# replay buffer
import json
from learning.replay_buffer import ReplayBuffer
# NOTE(review): presumably supplies UPDATE_TRAINER_STEPS and BATCH_SIZE used by
# the trajectory loop — confirm and import them by name.
from learning.config import *

buffer = ReplayBuffer(max_size=2000)

# trainer
# NOTE(review): presumably supplies Policy and QFunction — confirm and import
# them by name.
from learning.network.sac import *

policy = Policy()

qf1 = QFunction()
qf2 = QFunction()
target_qf1 = QFunction()
target_qf2 = QFunction()

from learning.sac_trainer import SACTrainer
# Same class name the debug-training cells import and instantiate from this module.
from learning.ddpg_trainer import DDPGTrainer

# SAC trainer with a fixed (not automatically tuned) entropy setting.
trainer = SACTrainer(
    policy, qf1, qf2, target_qf1, target_qf2,
    use_automatic_entropy_tuning=False,
    policy_lr=1e-3,
    qf_lr=1e-3,
    target_update_period=5,
)

# Trajectory

In [None]:
# use_network gates everything learning-related in the trajectory loop: CLIP
# feature extraction, policy-driven actions, replay-buffer writes and trainer
# updates.
use_network = True
# debug prints Q-value predictions next to the batch rewards when a trainer
# update lands on a step count divisible by 10.
debug = True

In [None]:
# Running counters incremented by the trajectory loop. They live in their own
# cell, presumably so re-running the loop cell continues the counts instead of
# resetting them.
total_traj = 0
total_step = 0

---------------------

In [None]:
# traj config
# Collect trajectories: each one places 5 objects on the base asset, saving
# images, CLIP features and rewards along the way, then stores the finished
# scene record in the replay buffer.
for traj_id in range(1):
    print("total_traj: ", total_traj)
    env.scene.traj_id = traj_id

    # base: every fifth trajectory, swap in a randomly chosen base asset and
    # re-randomize the house around it.
    if total_traj % 5 == 0:
        asset_count = len(env.scene.base_asset_file_paths)
        # Draw the index directly instead of building a throwaway index list.
        base_asset_id = random.randrange(asset_count)
        env.scene.base_asset_id = base_asset_id
        env.clean(clean_base=True, clean_all=False)
        scene.add_base_asset()

        # randomize scene
        randomizer.randomize_house(rand=True)

    image_folder = os.path.join(DATA_PATH, task_type, side_choice, str(traj_id))

    env.world.step(render=True)
    env.world.step(render=True)

    # get images of the initial state (saved under index 0)
    env.world.render()
    images = render.get_images()
    render.save_rgb(images[0]['rgb'], image_folder, "0")
    render.save_rgb(images[1]['rgb'], image_folder, "0_side")

    # extract feature of the initial main view; it is read back from the image
    # file written just above
    if use_network:
        extract_image_clip_feature_and_save(
            f"{image_folder}/0.png", feature_model, feature_processor,
            f"{image_folder}/0.pt",
        )

    # trajectory
    for step in range(5):
        total_step += 1

        # sample an object
        env.add_scene_obj(mode="random")

        # TODO: get action from sampling
        # Explore with a random (x, y) in (-1, 1) when the network is disabled,
        # during the first 10 trajectories, or with probability 0.2 afterwards.
        if not use_network or total_traj < 10 or np.random.rand() < 0.2:
            x, y = np.tanh(np.random.randn()), np.tanh(np.random.randn())
        else:
            # Feature of the scene image saved before this object was added.
            image_feature_file = f"{image_folder}/{step}.pt"

            object_info = env.scene.objects[-1]
            object_type = object_info["type"]
            # Drops the last four characters of the name — presumably a
            # ".usd"-style extension; confirm against the asset naming.
            obj_name = object_info["name"][:-4]
            object_feature_file = os.path.join(FEATURE_PATH, object_type, obj_name + ".pt")
            # The scaler decays exponentially with the number of finished
            # trajectories.
            x, y = obtain_action_from_trainer(image_feature_file, object_feature_file, trainer,
                                              scaler=np.exp(- total_traj / 100))
            object_info["use_network"] = True

        # load the object into the scene
        env.put_last_object((x, y))
        env.world.step(render=True)

        # register the object to the world for physics update
        env.register_last_object()
        env.world.step(render=True)

        # get images
        env.world.render()
        images = render.get_images()
        render.save_rgb(images[0]['rgb'], image_folder, str(step + 1))
        render.save_rgb(images[1]['rgb'], image_folder, f"{step + 1}_side")

        # calculate reward
        env.calculate_last_reward(simulation_step=30)

        # extract feature of the main view after this placement
        if use_network:
            extract_image_clip_feature_and_save(
                f"{image_folder}/{step + 1}.png", feature_model, feature_processor,
                f"{image_folder}/{step + 1}.pt",
            )

        # reset
        env.world.reset()
        env.world.step(render=True)

        # train network
        if use_network and total_step % UPDATE_TRAINER_STEPS == 0 and total_traj > 5:
            batch = buffer.sample_batch(batch_size=BATCH_SIZE)
            trainer.update(batch)

            if debug and total_step % 10 == 0:
                # Diagnostics only: move just the tensors the print needs and
                # skip autograd bookkeeping for the extra forward pass.
                with torch.no_grad():
                    rewards = batch['rewards'].to(device)
                    obs = batch['observations'].to(device)
                    actions = batch['actions'].to(device)
                    obj_features = batch['object_features'].to(device)
                    pred = trainer.qf1(obs, obj_features, actions)
                print(f"debug {total_traj}/{total_step}",
                      "\n pred:", pred.flatten().tolist(),
                      "\n rewards: ", rewards.flatten().tolist())

    # Record
    record = env.scene.get_scene_data()
    env.scene.save_scene_data()
    # print("record: ", record)

    # Add record to buffer
    if use_network:
        buffer.add_scene_sample(record)

    # Reset (env clean)
    env.clean(clean_all=False)
    # NOTE(review): every other step in this notebook goes through
    # env.world.step(...); confirm that env.step(...) is intended here.
    env.step(render=True)
    total_traj += 1

# Debug

In [None]:
# Same clean call the trajectory loop issues before swapping the base asset:
# clean_base=True with clean_all=False.
env.clean(clean_base = True, clean_all = False)

In [None]:
# get images
# NOTE(review): render() is called twice here but once in the trajectory loop;
# presumably the extra call lets the renderer catch up before capture — confirm.
env.world.render()
env.world.render()
images = render.get_images()

# Preview the first returned image (the one the trajectory loop saves as the
# main view), interpreted as RGBA and shrunk to 300x300.
Image.fromarray(images[0]['rgb'], "RGBA").resize((300, 300))

## Debug Training

In [None]:
import json
from learning.replay_buffer import ReplayBuffer

# Fresh buffer for offline debugging; replaces the one filled by the
# trajectory loop.
buffer = ReplayBuffer(max_size=1000)

# Load saved scene records (one scene.json per trajectory folder) back into the
# buffer.
for i in range(1):
    scene_file = f"{DATA_PATH}/{task_type}/{side_choice}/{i}/scene.json"
    # The context manager closes the file promptly instead of leaving the
    # handle to the garbage collector.
    with open(scene_file) as f:
        replay = json.load(f)
    buffer.add_scene_sample(replay)

In [None]:
# Number of entries currently held in the buffer's dataset.
len(buffer.dataset)

In [None]:
# trainer
from learning.network.sac import *
from learning.sac_trainer import SACTrainer
from learning.ddpg_trainer import DDPGTrainer

# Actor and its target copy (the target is only consumed by the DDPG cell).
policy, target_policy = Policy(), Policy()

# Twin critics followed by their target copies; the targets are switched to
# eval mode.
qf1, qf2 = QFunction(), QFunction()
target_qf1, target_qf2 = QFunction(), QFunction()
target_qf1.eval()
target_qf2.eval()

# SAC trainer, this time with automatic entropy tuning enabled.
trainer = SACTrainer(
    policy, qf1, qf2, target_qf1, target_qf2,
    use_automatic_entropy_tuning=True,
    policy_lr=1e-3,
    qf_lr=1e-3,
    target_update_period=5,
)

In [None]:
# Alternative trainer: rebinds `trainer` (the SAC trainer from the previous
# cell) to a DDPG trainer built from qf1/target_qf1 and policy/target_policy.
trainer = DDPGTrainer(qf1, target_qf1, policy, target_policy,
     policy_learning_rate=1e-3,
     qf_learning_rate=1e-3
    )

In [None]:
# Alias the trainer as `self`; the following cells call self.update,
# self.policy and self.qf1 — presumably so trainer method bodies can be pasted
# and run at top level.
self = trainer

In [None]:
# Run trainer updates on small batches and print the policy/critic outputs next
# to the rewards. The counter is named explicitly: it is read below, and `_` is
# also the name interactive Python/IPython sessions use for the last result.
for iteration in range(1):
    batch = buffer.sample_batch(batch_size=4)

    rewards = batch['rewards'].to(device)
    terminals = batch['terminals'].to(device)
    obs = batch['observations'].to(device)
    actions = batch['actions'].to(device)
    next_obs = batch['next_observations'].to(device)

    obj_features = batch['object_features'].to(device)

    # print("rewards", rewards)

    self.update(batch)

    if iteration % 10 == 0:
        dist = self.policy(obs, obj_features)
        pred = self.qf1(obs, obj_features, actions)
        print(iteration, "\n dist   : ", [round(x, 3) for x in dist.mean.flatten().tolist()],
                  [round(x, 3) for x in dist.stddev.flatten().tolist()],
              "\n pred   :", [round(x, 3) for x in pred.flatten().tolist()],
              "\n rewards: ", [round(x, 3) for x in rewards.flatten().tolist()])

#     # debug
#     pred = self.qf1(obs, obj_features, actions)
#     loss = self.qf_criterion(pred, rewards)

#     self.qf1_optimizer.zero_grad()
#     loss.backward()
#     self.qf1_optimizer.step()

#     if iteration % 20 == 0:
#         print(iteration, "\n debug loss: ", loss.item(), "\n pred:",
#               pred.flatten().tolist(), "\n rewards: ", rewards.flatten().tolist())

In [None]:
# Policy output for the last sampled batch (self is the trainer here).
dist = self.policy(obs, obj_features)

In [None]:
# Same call with the inputs explicitly moved to `device`; the training-loop
# cell already moved both tensors there.
dist = self.policy(obs.to(device), obj_features.to(device))

In [None]:
# Inspect the distribution's normal_mean and stddev attributes.
dist.normal_mean, dist.stddev

In [None]:
# Draw a sample with scaler=0.2 (presumably scales the sampling noise, as in
# the hand-written expression a few cells below — confirm).
dist.sample(scaler = 0.2)

In [None]:
# Rebind `self` from the trainer to the distribution; the expression two cells
# below reads self.normal_mean, self.normal_std and self.device.
self = dist

In [None]:
# Multiplier applied to the standard-deviation term in the next cell.
scaler = 0.1

In [None]:
# Imported here so the cell also runs on a fresh top-to-bottom pass: no earlier
# cell imports this name explicitly.
from learning.distributions import MultivariateDiagonalNormal

# Hand-written scaled sample: normal_mean + scaler * normal_std * noise, where
# the noise comes from a MultivariateDiagonalNormal built from zeros and ones.
self.normal_mean + scaler * self.normal_std * MultivariateDiagonalNormal(
    torch.zeros(self.normal_mean.size()).to(self.device),
    torch.ones(self.normal_std.size()).to(self.device)
).sample()

# Distribution

In [None]:
from learning.distributions import MultivariateDiagonalNormal

In [None]:
# Rebuild the distribution from its own mean and stddev.
# NOTE(review): JustNormal is not imported by name in any visible cell;
# presumably it arrives through one of the star imports — confirm.
dist = JustNormal(dist.mean, dist.stddev)

In [None]:
# Display what the rebuilt distribution's rsample_and_logprob() returns.
dist.rsample_and_logprob()