In [1]:
import os
# Set ahead of the `import torch` further down in this cell.
# NOTE(review): presumably makes TorchScript use the NNC fuser rather than NVFuser --
# confirm against the installed PyTorch version before changing or removing.
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable
import math

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, PointCloud, SE3
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge, flatten_featured_points, set_featured_points_attribute, _featured_points_repr
from diffusion_edf import train_utils
from diffusion_edf import transforms
from diffusion_edf.utils import sample_reference_points
from diffusion_edf.dist import diffuse_isotropic_se3, adjoint_inv_tr_isotropic_se3_score, diffuse_isotropic_se3_batched
from diffusion_edf.point_attentive_score_model import PointAttentiveScoreModel
from diffusion_edf.trainer import DiffusionEdfTrainer

# Print tensors with 4 digits of precision and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Configuration directory and file names handed to the trainer.
configs_root_dir = 'configs/pick_point_attn'
train_configs_file = 'train_configs.yaml'
task_configs_file = 'task_configs.yaml'

trainer = DiffusionEdfTrainer(
    configs_root_dir=configs_root_dir,
    train_configs_file=train_configs_file,
    task_configs_file=task_configs_file,
)

# Initialize a new run (no resume); the log name comes from the trainer's
# current-time helper with the "MultiphaseTraining" postfix.
run_log_name = trainer.get_current_time(postfix="MultiphaseTraining")
trainer.init(log_name=run_log_name, resume_training=False)

ScoreModel: Initializing Key Model




ScoreModel: Initializing Query Model
ScoreModel: Initializing Score Head


True

# Main Loop

In [3]:
# Main training loop.
#
# For every demo (batch size must be 1):
#   1. optionally perturb the target pose with a diffusion step at time `trainer.t_augment`,
#   2. draw one diffused pose per entry of `trainer.diffusion_schedules`,
#   3. run a single optimizer step on the score-model loss over all of those samples,
#   4. log the returned statistics as scalars.
# Every `trainer.n_epochs_per_checkpoint` epochs, 3D visualizations built from the
# tensors of the LAST batch of the epoch are logged and a checkpoint is written.
#
# `trainer.epoch` (not the loop variable) is the epoch counter used for logging and
# checkpoint names; it is incremented at the end of each outer iteration.
for epoch in range(trainer.epoch, trainer.max_epochs+1):
    for demo_batch in trainer.trainloader:
        B = len(demo_batch)
        assert B == 1, "Batch training is not supported yet."

        trainer.optimizer.zero_grad(set_to_none=True)

        scene_input, grasp_input, T_target = train_utils.flatten_batch(demo_batch=demo_batch) # T_target: (Nbatch, Ngrasps, 7)
        T_target = T_target.squeeze(0) # (B=1, N_poses=1, 7) -> (1,7)

        ########################################## Augmentation #########################################
        # Only the first return value of each helper is needed here. Index it instead of
        # unpacking into `_`, `__`, `___`: those names are IPython's output-history
        # variables and should not be rebound inside a notebook.
        if trainer.t_augment is not None:
            x_ref = train_utils.transform_and_sample_reference_points(T_target=T_target,
                                                                      scene_points=scene_input,
                                                                      grasp_points=grasp_input,
                                                                      contact_radius=trainer.contact_radius,
                                                                      n_samples_x_ref=1)[0]
            T_target = train_utils.diffuse_T_target(T_target=T_target,
                                                    x_ref=x_ref,
                                                    time=torch.tensor([trainer.t_augment], device=T_target.device),
                                                    lin_mult=trainer.score_model.lin_mult,
                                                    ang_mult=trainer.score_model.ang_mult)[0]
        ##################################################################################################

        # One diffused sample is drawn for EVERY diffusion schedule at each step (rather
        # than picking a single random schedule). Per-schedule results are collected in
        # lists and concatenated once afterwards, instead of re-concatenating growing
        # tensors on every iteration.
        # NOTE: `trainer.diffusion_schedules` must be non-empty; torch.cat rejects an empty list.
        T_parts, time_in_parts = [], []
        gt_ang_score_parts, gt_lin_score_parts = [], []
        gt_ang_score_ref_parts, gt_lin_score_ref_parts = [], []

        for time_schedule in trainer.diffusion_schedules:
            # time_schedule is indexed as (max_time, min_time).
            time_ = train_utils.random_time(min_time=time_schedule[1], max_time=time_schedule[0], device=T_target.device) # Shape: (1,)

            x_ref_, n_neighbors_ = train_utils.transform_and_sample_reference_points(T_target=T_target,
                                                                                     scene_points=scene_input,
                                                                                     grasp_points=grasp_input,
                                                                                     contact_radius=trainer.contact_radius,
                                                                                     n_samples_x_ref=trainer.n_samples_x_ref)
            T_, delta_T_, time_in_, gt_score_, gt_score_ref_ = train_utils.diffuse_T_target(T_target=T_target,
                                                                                            x_ref=x_ref_,
                                                                                            time=time_,
                                                                                            lin_mult=trainer.score_model.lin_mult,
                                                                                            ang_mult=trainer.score_model.ang_mult)

            (gt_ang_score_, gt_lin_score_), (gt_ang_score_ref_, gt_lin_score_ref_) = gt_score_, gt_score_ref_
            T_parts.append(T_)
            time_in_parts.append(time_in_)
            gt_ang_score_parts.append(gt_ang_score_)
            gt_lin_score_parts.append(gt_lin_score_)
            gt_ang_score_ref_parts.append(gt_ang_score_ref_)
            gt_lin_score_ref_parts.append(gt_lin_score_ref_)

        T = torch.cat(T_parts, dim=0)
        time_in = torch.cat(time_in_parts, dim=0)
        gt_ang_score = torch.cat(gt_ang_score_parts, dim=0)
        gt_lin_score = torch.cat(gt_lin_score_parts, dim=0)
        # The *_ref scores are not passed to the loss below; they are kept so they
        # remain available for inspection after the cell has run.
        gt_ang_score_ref = torch.cat(gt_ang_score_ref_parts, dim=0)
        gt_lin_score_ref = torch.cat(gt_lin_score_ref_parts, dim=0)

        loss, fp_info, tensor_info, statistics = trainer.score_model.get_train_loss(Ts=T, time=time_in, key_pcd=scene_input, query_pcd=grasp_input,
                                                                                    target_ang_score=gt_ang_score, target_lin_score=gt_lin_score)
        scene_out: FeaturedPoints = fp_info['key_fp']
        grasp_out: FeaturedPoints = fp_info['query_fp']

        loss.backward()
        trainer.optimizer.step()

        with torch.no_grad():
            for tag, scalar_value in statistics.items():
                trainer.logger.add_scalar(tag=tag, scalar_value=scalar_value, global_step=trainer.steps)
        trainer.steps += 1

    if trainer.epoch % trainer.n_epochs_per_checkpoint == 0:
        # Everything below visualizes the last batch processed in this epoch
        # (scene_input, grasp_input, T_target, T, scene_out, grasp_out).
        with torch.no_grad():
            scene_pcd = PointCloud(points=scene_input.x, colors=scene_input.f)
            grasp_pcd = PointCloud(points=grasp_input.x, colors=grasp_input.f)
            target_pose_pcd = PointCloud.merge(
                scene_pcd,
                grasp_pcd.transformed(SE3(T_target), squeeze=True),
            )
            # Only the first diffused pose in T is visualized.
            diffused_pose_pcd = PointCloud.merge(
                scene_pcd,
                grasp_pcd.transformed(SE3(T))[0],
            )
            scene_attn_pcd = PointCloud(points=scene_out.x.detach().cpu(),
                                        colors=scene_out.w.detach().cpu(),
                                        cmap='magma')
            # Not logged below; kept so it stays available for interactive inspection.
            grasp_attn_pcd = PointCloud(points=grasp_out.x.detach().cpu(),
                                        colors=grasp_out.w.detach().cpu(),
                                        cmap='magma')

            # Keep only the query points whose batch index is 0.
            query_weight, query_points, query_point_batch = grasp_out.w.detach(), grasp_out.x.detach(), grasp_out.b.detach()
            batch_vis_idx = (query_point_batch == 0).nonzero().squeeze(-1)
            query_weight, query_points = query_weight[batch_vis_idx], query_points[batch_vis_idx]

            # Per the original author's note, the 3D logger has a bug with too few points,
            # so every query point is replicated N_repeat times, each copy offset by a
            # random vector of length r_query_ball, and colored by its weight.
            N_repeat = 500
            query_points_colors = torch.tensor([0.01, 1., 1.], device=query_weight.device, dtype=query_weight.dtype).expand(N_repeat, 1, 3) * query_weight[None, :, None]
            r_query_ball = 0.5

            ball = torch.randn(N_repeat,1,3, device=query_points.device, dtype=query_points.dtype)
            ball = ball/ball.norm(dim=-1, keepdim=True) * r_query_ball
            query_points = (query_points + ball).reshape(-1,3)
            query_points_colors = query_points_colors.reshape(-1,3)

        # All 3D logs of one checkpoint share the same step index.
        log_step = trainer.epoch//trainer.n_epochs_per_checkpoint
        point_clouds_to_log = (
            ("Scene Attention", scene_attn_pcd.points, scene_attn_pcd.colors),
            ("Grasp Attention", query_points, query_points_colors),
            ("Target Pose", target_pose_pcd.points, target_pose_pcd.colors),
            ("Diffused Pose", diffused_pose_pcd.points, diffused_pose_pcd.colors),
            ("Grasp", grasp_pcd.points, grasp_pcd.colors),
        )
        for tag, vertex_positions, vertex_colors in point_clouds_to_log:
            trainer.logger.add_3d(
                tag = tag,
                data = {
                    "vertex_positions": vertex_positions.cpu(),
                    "vertex_colors": vertex_colors.cpu(),  # (N, 3)
                },
                step=log_step,
            )

        checkpoint_path = os.path.join(trainer.log_dir, f'checkpoint/{trainer.epoch}.pt')
        # Make sure the target directory exists; this is a no-op when it already does.
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save({'epoch': trainer.epoch,
                    'steps': trainer.steps,
                    'score_model_state_dict': trainer.score_model.state_dict(),
                    'optimizer_state_dict': trainer.optimizer.state_dict(),
                    }, checkpoint_path)

        print(f"(Epoch: {trainer.epoch}) Successfully saved logs to: {trainer.log_dir}")
    trainer.epoch += 1

(Epoch: 0) Successfully saved logs to: runs/2023_05_31_00-31-39_MultiphaseTraining
(Epoch: 20) Successfully saved logs to: runs/2023_05_31_00-31-39_MultiphaseTraining
(Epoch: 40) Successfully saved logs to: runs/2023_05_31_00-31-39_MultiphaseTraining
(Epoch: 60) Successfully saved logs to: runs/2023_05_31_00-31-39_MultiphaseTraining
(Epoch: 80) Successfully saved logs to: runs/2023_05_31_00-31-39_MultiphaseTraining
(Epoch: 100) Successfully saved logs to: runs/2023_05_31_00-31-39_MultiphaseTraining
(Epoch: 120) Successfully saved logs to: runs/2023_05_31_00-31-39_MultiphaseTraining
(Epoch: 140) Successfully saved logs to: runs/2023_05_31_00-31-39_MultiphaseTraining
(Epoch: 160) Successfully saved logs to: runs/2023_05_31_00-31-39_MultiphaseTraining
(Epoch: 180) Successfully saved logs to: runs/2023_05_31_00-31-39_MultiphaseTraining
(Epoch: 200) Successfully saved logs to: runs/2023_05_31_00-31-39_MultiphaseTraining
