In [1]:
import os

# Must run before `torch` is imported (next cell). Per its name, this flag asks the
# TorchScript JIT to use the NNC fuser instead of nvFuser — TODO confirm for the
# installed PyTorch version.
os.environ.update({"PYTORCH_JIT_USE_NNC_NOT_NVFUSER": "1"})

In [2]:
from typing import List, Tuple, Optional, Union, Iterable
import datetime

import plotly.graph_objects as go
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3
from open3d.visualization.tensorboard_plugin import summary
from torch.utils.tensorboard import SummaryWriter

from diffusion_edf.embedding import NodeEmbeddingNetwork
from diffusion_edf.data import SE3, PointCloud, TargetPoseDemo, DemoSequence, DemoSeqDataset, load_demos, save_demos
from diffusion_edf.preprocess import Rescale, NormalizeColor, Downsample, PointJitter, ColorJitter
from diffusion_edf.wigner import TransformFeatureQuaternion
from diffusion_edf.score_model import ScoreModel
from diffusion_edf import transforms, pc_utils
from diffusion_edf.loss import SE3DenoisingDiffusion
from diffusion_edf.utils import sample_reference_points
from diffusion_edf.dist import diffuse_isotropic_se3, adjoint_inv_tr_isotropic_se3_score, diffuse_isotropic_se3_batched



torch.set_printoptions(precision=4, sci_mode=False)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Geometry is processed in rescaled units: raw lengths are divided by `unit_len`,
# so a length of 1.0 after rescaling corresponds to `unit_len` in the raw data.
unit_len = 0.01

# Voxel sizes are written in raw units and converted to rescaled units right away.
scene_voxel_size = 0.01 / unit_len
grasp_voxel_size = 0.01 / unit_len

# Forward (raw -> rescaled) and inverse (rescaled -> raw) scaling transforms.
rescale_fn = Rescale(rescale_factor=1/unit_len)
recover_scale_fn = Rescale(rescale_factor=unit_len)

# Scene / grasp pipelines: rescale first, then voxel-downsample
# (coord_reduction="average" — presumably averages point coordinates per voxel).
scene_proc_fn = Compose([
    rescale_fn,
    Downsample(voxel_size=scene_voxel_size, coord_reduction="average"),
])
scene_unproc_fn = Compose([recover_scale_fn])

grasp_proc_fn = Compose([
    rescale_fn,
    Downsample(voxel_size=grasp_voxel_size, coord_reduction="average"),
])
grasp_unproc_fn = Compose([recover_scale_fn])

In [4]:
import math

# ---- Runtime ----
device = 'cuda:0'
# Passed to ScoreModel as `compile_head`.
# NOTE(review): this global shadows Python's builtin `compile()`.
compile = False

# ---- e3nn irreps ----
irreps_input = o3.Irreps('3x0e')                        # 3 scalar input channels (the training loop feeds point colors)
irreps_node_embedding = o3.Irreps('32x0e+16x1e+8x2e')  # larger alternative tried before: '128x0e+64x1e+32x2e'
irreps_sh = o3.Irreps('1x0e+1x1e+1x2e')                # one irrep each for l = 0, 1, 2

# Scale applied to linear (translational) quantities: used in the diffusion std,
# the linear loss term, and passed to ScoreModel.
lin_mult = 10.

In [5]:
# Build the score network and move it to the training device
# (`nn.Module.to` returns the module itself).
score_model = ScoreModel(
    irreps_input=irreps_input,
    irreps_emb_init=irreps_node_embedding,
    irreps_sh=irreps_sh,
    fc_neurons_init=[32, 16, 16],
    num_heads=4,
    n_scales=4,
    pool_ratio=0.25,
    dim_mult=[1, 1, 2, 2],
    n_layers=2,
    gnn_radius=3.0,
    cutoff_radius=5.0,
    weight_feature_dim=64,
    query_downsample_ratio=0.01,
    device=device,
    lin_mult=lin_mult,
    deterministic=False,
    compile_head=compile,
).to(device)

# AMSGrad variant of Adam with mild weight decay.
optimizer = torch.optim.Adam(
    list(score_model.parameters()),
    lr=3e-4,
    betas=(0.9, 0.98),
    eps=1e-09,
    weight_decay=1e-4,
    amsgrad=True,
)



# Set up logging, load demos, and optionally resume from a checkpoint

In [6]:
# Either resume an existing run (set `resume_log_dir`) or start a fresh, timestamped run.
resume_log_dir: Optional[str] = None
#resume_log_dir: Optional[str] = 'runs/2023_04_21_00-34-42'
# Checkpoint *file name* inside `<resume_log_dir>/checkpoint`; None selects the latest one.
resume_checkpoint_dir: Optional[str] = None


def _latest_checkpoint_name(checkpoint_dir: str) -> str:
    """Return the file name of the highest-epoch `<epoch>.pt` file in `checkpoint_dir`.

    Raises:
        FileNotFoundError: if the directory contains no integer-named `.pt` files.
    """
    candidates = []
    for fname in os.listdir(checkpoint_dir):
        stem, ext = os.path.splitext(fname)
        # `str.rstrip('.pt')` strips a *set of characters*, not a suffix; parse explicitly
        # and ignore unrelated files (e.g. editor/OS artifacts) instead of crashing in int().
        if ext == '.pt' and stem.isdigit():
            candidates.append((int(stem), fname))
    if not candidates:
        raise FileNotFoundError(f"No '<epoch>.pt' checkpoints found in: {checkpoint_dir}")
    return max(candidates)[1]


if resume_log_dir is not None:
    if resume_checkpoint_dir is None:
        resume_checkpoint_dir = _latest_checkpoint_name(os.path.join(resume_log_dir, 'checkpoint'))
    resume_training = True
    resume_checkpoint_path = os.path.join(resume_log_dir, 'checkpoint', resume_checkpoint_dir)
    # Interactive guard so an old run is not appended to by accident.
    if input(f"Enter 'y' if you want to resume training from checkpoint: {resume_checkpoint_path}") != 'y':
        raise ValueError(f"Resuming from checkpoint was not confirmed: {resume_checkpoint_path}")
else:
    resume_training = False
    resume_log_dir = os.path.join('runs', f"{datetime.datetime.now().strftime('%Y_%m_%d_%H-%M-%S')}")

writer = SummaryWriter(log_dir=resume_log_dir)
log_dir = writer.log_dir

# Checkpoints are stored as `<log_dir>/checkpoint/<epoch>.pt`.
os.makedirs(os.path.join(log_dir, 'checkpoint'), exist_ok=True)

In [7]:
def _identity_collate(samples):
    """Return the list of samples unchanged, so each batch is a plain Python list (no tensor stacking)."""
    return samples


trainset = DemoSeqDataset(dataset_dir="demo/test_demo", annotation_file="data.yaml", device=device)
train_dataloader = DataLoader(dataset=trainset, shuffle=True, collate_fn=_identity_collate)

In [8]:
# Restore model/optimizer state and counters when resuming; otherwise start from scratch.
if resume_training:
    resume_checkpoint_path = os.path.join(log_dir, 'checkpoint', resume_checkpoint_dir)
    # map_location keeps loading working when the checkpoint was saved on a device
    # that is not available now (e.g. another GPU index, or CPU-only resume).
    checkpoint = torch.load(resume_checkpoint_path, map_location=device)
    score_model.load_state_dict(checkpoint['score_model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    steps = checkpoint['steps']
    print(f"resume training from checkpoint: {resume_checkpoint_path}")
    # Checkpoints are written after an epoch finishes, so continue with the next epoch.
    epoch = checkpoint['epoch'] + 1
else:
    epoch = 0
    steps = 0

In [9]:
# ---- Training schedule ----
# max_epochs: the training loop runs epochs [start, max_epochs] inclusive.
# n_epochs_per_checkpoint: cadence (in epochs) of checkpoints and 3D visualizations.
max_epochs, n_epochs_per_checkpoint = 200, 20
# NOTE(review): N_samples is not referenced by the training cell below.
N_samples = 10

# ---- Reference points / visualization ----
# n_samples_x_ref: number of reference points requested from sample_reference_points.
# visualize_raw_pcd: True -> visualize in raw (un-rescaled) units.
n_samples_x_ref, visualize_raw_pcd = 10, False

In [10]:
# ---- Training loop ----
# Each step: take one demo, diffuse its target pose with diffuse_isotropic_se3_batched,
# regress the sqrt(time)-scaled ground-truth score, and log losses / score statistics
# to TensorBoard. Every `n_epochs_per_checkpoint` epochs (and at the final epoch) a
# checkpoint plus 3D visualizations of the epoch's last batch are written.

# Diffusion time is sampled uniformly in [min_time, max_time] every step (loop-invariant).
min_time = 1e-3
max_time = 1.0


def _log_score_norms(tb_writer, tag_prefix, target_ang, target_lin, inferred_ang, inferred_lin, step):
    """Log the mean L2 norm (over the last dim) of target and inferred angular/linear scores.

    Scalars are written as `{tag_prefix}/target_ang`, `{tag_prefix}/target_lin`,
    `{tag_prefix}/inferred_ang` and `{tag_prefix}/inferred_lin`.

    Returns:
        dict mapping those four suffixes to the per-sample norm tensors.
    """
    norms = {
        'target_ang': torch.norm(target_ang.detach(), dim=-1),
        'target_lin': torch.norm(target_lin.detach(), dim=-1),
        'inferred_ang': torch.norm(inferred_ang.detach(), dim=-1),
        'inferred_lin': torch.norm(inferred_lin.detach(), dim=-1),
    }
    for name, norm in norms.items():
        tb_writer.add_scalar(tag=f"{tag_prefix}/{name}", scalar_value=norm.mean(dim=-1).item(), global_step=step)
    return norms


def _log_alignment(tb_writer, name, inferred, target, inferred_norm, target_norm, step, eps=1e-12):
    """Log the raw and the norm-normalized (cosine) dot product of inferred vs. target scores.

    The cosine denominator is clamped to `eps`, so a zero-norm score logs 0 instead of NaN.
    """
    dp_align = torch.einsum('...i,...i->...', inferred.detach(), target.detach())
    dp_align_normalized = dp_align / (target_norm * inferred_norm).clamp_min(eps)
    tb_writer.add_scalar(tag=f"alignment/unnormalized/{name}", scalar_value=dp_align.mean(dim=-1).item(), global_step=step)
    tb_writer.add_scalar(tag=f"alignment/normalized/{name}", scalar_value=dp_align_normalized.mean(dim=-1).item(), global_step=step)


def _add_point_cloud_3d(tb_writer, tag, points, colors, step, **kwargs):
    """Write a colored point cloud with `add_3d` (Open3D TensorBoard plugin), moving tensors to CPU."""
    tb_writer.add_3d(
        tag=tag,
        data={
            "vertex_positions": points.cpu(),
            "vertex_colors": colors.cpu(),  # (N, 3)
        },
        step=step,
        **kwargs,
    )


# The checkpoint/visualization section reuses tensors from the epoch's last batch;
# fail early with a clear message rather than a NameError if there is no data.
assert len(train_dataloader) > 0, "Training dataloader is empty."

for epoch in range(epoch, max_epochs+1):
    for train_batch in train_dataloader:
        assert len(train_batch) == 1, "Batch training is not supported yet."

        optimizer.zero_grad(set_to_none=True)

        # -- Load the demo and bring it into rescaled units on `device` --
        demo_seq: DemoSequence = train_batch[0]
        demo: TargetPoseDemo = demo_seq[1]
        scene_raw: PointCloud = demo.scene_pc
        grasp_raw: PointCloud = demo.grasp_pc
        target_poses_raw: SE3 = demo.target_poses
        scene_proc: PointCloud = scene_proc_fn(scene_raw).to(device)
        grasp_proc: PointCloud = grasp_proc_fn(grasp_raw).to(device)
        target_poses: SE3 = rescale_fn(target_poses_raw).to(device)
        T_target: torch.Tensor = target_poses.poses

        # -- Diffuse the target pose --
        time_in = (min_time/max_time + torch.rand(1, dtype=T_target.dtype, device=T_target.device) * (1-min_time/max_time))*max_time
        eps = time_in / 2
        std = torch.sqrt(time_in) * lin_mult
        x_ref, n_neighbors = sample_reference_points(PointCloud.transform_pcd(scene_proc, target_poses.inv())[0].points, grasp_proc.points, r=3, n_samples=n_samples_x_ref)
        T, delta_T, (gt_ang_score, gt_lin_score), (gt_ang_score_ref, gt_lin_score_ref) = diffuse_isotropic_se3_batched(T0 = T_target, eps=eps, std=std, x_ref=x_ref, double_precision=True)
        # squeeze(-2) drops dim -2 only when it has size 1 (presumably the batched
        # function's per-call sample axis — TODO confirm in diffusion_edf.dist).
        T, delta_T = T.squeeze(-2), delta_T.squeeze(-2)
        gt_ang_score, gt_lin_score = gt_ang_score.squeeze(-2), gt_lin_score.squeeze(-2)
        gt_ang_score_ref, gt_lin_score_ref = gt_ang_score_ref.squeeze(-2), gt_lin_score_ref.squeeze(-2)

        # -- Predict the score of the diffused pose (scene = keys, grasp = queries, single batch id 0) --
        key_feature = scene_proc.colors
        key_coord = scene_proc.points
        key_batch = torch.zeros(len(key_coord), device=device, dtype=torch.long)
        query_feature = grasp_proc.colors
        query_coord = grasp_proc.points
        query_batch = torch.zeros(len(query_coord), device=device, dtype=torch.long)
        (ang_score, lin_score), query, query_info, key_info = score_model(T=T,
                                                                          key_feature=key_feature, key_coord=key_coord, key_batch=key_batch,
                                                                          query_feature=query_feature, query_coord=query_coord, query_batch=query_batch,
                                                                          time=time_in, info_mode='NO_GRAD')

        # Score transformed via adjoint_inv_tr_isotropic_se3_score at the reference points; logging only.
        with torch.no_grad():
            ang_score_ref, lin_score_ref = adjoint_inv_tr_isotropic_se3_score(x_ref=-x_ref, ang_score=ang_score, lin_score=lin_score)

        # Regression targets: ground-truth scores scaled by sqrt(time).
        sqrt_time = torch.sqrt(time_in)
        target_ang_score = gt_ang_score * sqrt_time
        target_lin_score = gt_lin_score * sqrt_time
        target_ang_score_ref = gt_ang_score_ref * sqrt_time
        target_lin_score_ref = gt_lin_score_ref * sqrt_time

        # -- Loss: squared error summed over the last dim, then averaged over dim -1;
        #    the linear residual is multiplied by lin_mult before squaring. --
        ang_score_diff = target_ang_score - ang_score
        lin_score_diff = target_lin_score - lin_score
        ang_loss = torch.sum(torch.square(ang_score_diff), dim=-1).mean(dim=-1)
        lin_loss = torch.sum(torch.square(lin_score_diff * lin_mult), dim=-1).mean(dim=-1)
        loss = ang_loss + lin_loss

        loss.backward()
        optimizer.step()

        # -- Per-step TensorBoard scalars --
        with torch.no_grad():
            writer.add_scalar(tag="Loss/train", scalar_value=loss.item(), global_step=steps)
            writer.add_scalar(tag="Loss/angular", scalar_value=ang_loss.item(), global_step=steps)
            writer.add_scalar(tag="Loss/linear", scalar_value=lin_loss.item(), global_step=steps)

            norms = _log_score_norms(writer, "norm", target_ang_score, target_lin_score, ang_score, lin_score, steps)
            _log_score_norms(writer, "norm_ref", target_ang_score_ref, target_lin_score_ref, ang_score_ref, lin_score_ref, steps)

            _log_alignment(writer, "ang", ang_score, target_ang_score, norms['inferred_ang'], norms['target_ang'], steps)
            _log_alignment(writer, "lin", lin_score, target_lin_score, norms['inferred_lin'], norms['target_lin'], steps)

        steps += 1

    # Checkpoint on the regular cadence and always at the final epoch.
    if epoch % n_epochs_per_checkpoint == 0 or epoch == max_epochs:
        torch.save({'epoch': epoch,
                    'steps': steps,
                    'score_model_state_dict': score_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    }, os.path.join(log_dir, 'checkpoint', f'{epoch}.pt'))

        # Ceil division: equals epoch // n_epochs_per_checkpoint on the regular cadence, and
        # gives an off-cadence final epoch its own step instead of colliding with the previous one.
        vis_step = (epoch + n_epochs_per_checkpoint - 1) // n_epochs_per_checkpoint

        # Visualize the last batch of this epoch.
        with torch.no_grad():
            if visualize_raw_pcd:
                target_pose_pcd = PointCloud.merge(scene_raw, grasp_raw.transformed(target_poses_raw)[0])
                diffused_pose_pcd = PointCloud.merge(scene_raw, grasp_raw.transformed( recover_scale_fn(SE3(T)) )[0])
                grasp_pcd = grasp_raw
            else:
                target_pose_pcd = PointCloud.merge(scene_proc, grasp_proc.transformed(target_poses)[0])
                diffused_pose_pcd = PointCloud.merge(scene_proc, grasp_proc.transformed( SE3(T) )[0])
                grasp_pcd = grasp_proc

            ##### Query Summary #####
            query_weight, _, query_points, query_point_batch = query
            batch_vis_idx = (query_point_batch == 0).nonzero().squeeze(-1)
            query_weight, query_points = query_weight[batch_vis_idx], query_points[batch_vis_idx]
            if visualize_raw_pcd:
                query_points = query_points * unit_len

            # Draw each query point as N_repeat random points on a sphere of radius r_query_ball,
            # colored [0.01, 1, 1] scaled by its query weight. (Original note: add_3d misbehaves
            # with very few points, which the repetition also works around.)
            N_repeat = 500
            query_points_colors = torch.tensor([0.01, 1., 1.], device=query_weight.device, dtype=query_weight.dtype).expand(N_repeat, 1, 3) * query_weight[None, :, None]
            r_query_ball = 0.5
            if visualize_raw_pcd:
                r_query_ball *= unit_len

            ball = torch.randn(N_repeat,1,3, device=query_points.device, dtype=query_points.dtype)
            ball = ball/ball.norm(dim=-1, keepdim=True) * r_query_ball
            query_points = (query_points + ball).reshape(-1,3)
            query_points_colors = query_points_colors.reshape(-1,3)

        _add_point_cloud_3d(writer, "Target Pose", target_pose_pcd.points, target_pose_pcd.colors, vis_step)
        _add_point_cloud_3d(writer, "Diffused Pose", diffused_pose_pcd.points, diffused_pose_pcd.colors, vis_step,
                            description=f"Diffuse time: {time_in.item()} || eps: {eps.item()} || std: {std.item()}")
        _add_point_cloud_3d(writer, "Grasp", grasp_pcd.points, grasp_pcd.colors, vis_step)
        _add_point_cloud_3d(writer, "Query points", query_points, query_points_colors, vis_step)

        print(f"(Epoch: {epoch}) Successfully saved logs to: {log_dir}")

(Epoch: 0) Successfully saved logs to: runs/2023_04_24_20-03-12
(Epoch: 20) Successfully saved logs to: runs/2023_04_24_20-03-12
(Epoch: 40) Successfully saved logs to: runs/2023_04_24_20-03-12
(Epoch: 60) Successfully saved logs to: runs/2023_04_24_20-03-12
(Epoch: 80) Successfully saved logs to: runs/2023_04_24_20-03-12
(Epoch: 100) Successfully saved logs to: runs/2023_04_24_20-03-12
(Epoch: 120) Successfully saved logs to: runs/2023_04_24_20-03-12
