In [1]:
import os
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, PointCloud, SE3
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge, flatten_featured_points, set_featured_points_attribute, _featured_points_repr
from diffusion_edf import train_utils
from diffusion_edf import preprocess
from diffusion_edf import transforms
from diffusion_edf.multiscale_score_model import ScoreModel
from diffusion_edf.utils import sample_reference_points
from diffusion_edf.dist import diffuse_isotropic_se3, adjoint_inv_tr_isotropic_se3_score, diffuse_isotropic_se3_batched

# Print tensors with 4 decimals and plain (non-scientific) notation for readable notebook output.
torch.set_printoptions(
    sci_mode=False,
    precision=4,
)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# ---- Run-time switches -------------------------------------------------------
device = 'cuda:0'
# eval = True
compile = False  # NOTE(review): shadows the builtin `compile`; later cells read this flag by name.

# ---- Configuration file locations --------------------------------------------
model_configs_dir = 'configs/test/multiscale_score_model_configs.yaml'
train_configs_dir = 'configs/test/train_configs.yaml'
task_configs_dir = 'configs/test/task_configs.yaml'


def _load_yaml_config(path):
    """Parse the YAML file at `path` and return the resulting Python object."""
    # NOTE(review): FullLoader can build some non-primitive Python objects; only point
    # this at trusted config files (yaml.safe_load is the stricter alternative).
    with open(path) as file:
        return yaml.load(file, Loader=yaml.FullLoader)


model_configs = _load_yaml_config(model_configs_dir)
train_configs = _load_yaml_config(train_configs_dir)
task_configs = _load_yaml_config(task_configs_dir)

# Add a final 'Rescale' preprocessing step whose factor is the inverse of the task's unit length.
unit_rescale_step = {
    'name': 'Rescale',
    'kwargs': {'rescale_factor': 1/task_configs['unit_length']},
}
train_configs['preprocess_config'].append(unit_rescale_step)

# Load demo

In [3]:
# Build the preprocessing pipeline: every config entry names a callable in
# `diffusion_edf.preprocess`, which is instantiated with that entry's kwargs.
proc_fn = Compose([
    getattr(preprocess, step['name'])(**step['kwargs'])
    for step in train_configs['preprocess_config']
])

# The collate function applies `proc_fn`; the dataset is created on `device`.
collate_fn = train_utils.get_collate_fn(
    task=train_configs['task_type'],
    proc_fn=proc_fn,
)
trainset = DemoSeqDataset(
    dataset_dir=train_configs['dataset_dir'],
    annotation_file=train_configs['annotation_file'],
    device=device,
)
train_dataloader = DataLoader(
    trainset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=train_configs['n_batches'],
)

# Load Model

In [4]:
# Instantiate the score model from the YAML hyper-parameters, then move it to the training device.
score_model = ScoreModel(**model_configs, deterministic=False)
score_model = score_model.to(device=device)
if compile:
    # Compilation is not wired up for this model.
    raise NotImplementedError
# if eval:
#     score_model = score_model.eval()

# Adam over every model parameter; its settings come from the training config.
model_params = list(score_model.parameters())
optimizer = torch.optim.Adam(model_params, **train_configs['optimizer_kwargs'])

ScoreModel: Initializing Key Feature Extractor




ScoreModel: Initializing Query Model
ScoreModel: Initializing Score Head


In [5]:
# Resume / fresh-start selection.
#   resume_log_dir:        existing run directory to resume from, or None to start a new run.
#   resume_checkpoint_dir: checkpoint *file name* inside '<resume_log_dir>/checkpoint';
#                          None selects the checkpoint with the highest epoch number.
# After this cell `resume_training`, `log_dir`, `epoch` and `steps` are defined.
resume_log_dir: Optional[str] = None
#resume_log_dir: Optional[str] = 'runs/2023_04_21_00-34-42'
resume_checkpoint_dir: Optional[str] = None
if resume_log_dir is not None:
    checkpoint_root = os.path.join(resume_log_dir, 'checkpoint')
    if resume_checkpoint_dir is None:
        # Checkpoints are named '<epoch>.pt'. The stem is taken with splitext because
        # str.rstrip('.pt') strips a *set of characters*, not a suffix. Files that do
        # not match the pattern are ignored instead of crashing int().
        epoch_checkpoints = [
            f for f in os.listdir(checkpoint_root)
            if f.endswith('.pt') and os.path.splitext(f)[0].isdigit()
        ]
        if not epoch_checkpoints:
            raise FileNotFoundError(f"No '<epoch>.pt' checkpoint found in: {checkpoint_root}")
        resume_checkpoint_dir = sorted(epoch_checkpoints, key=lambda f: int(os.path.splitext(f)[0]))[-1]
    resume_checkpoint_path = os.path.join(checkpoint_root, resume_checkpoint_dir)
    resume_training = True

    # Ask for explicit confirmation before touching the model/optimizer state.
    if input(f"Enter 'y' if you want to resume training from checkpoint: {resume_checkpoint_path}") != 'y':
        raise ValueError(f"Resuming from checkpoint was not confirmed: {resume_checkpoint_path}")

    # map_location keeps loading independent of the device the checkpoint was saved from.
    checkpoint = torch.load(resume_checkpoint_path, map_location=device)
    score_model.load_state_dict(checkpoint['score_model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    steps = checkpoint['steps']
    print(f"resume training from checkpoint: {resume_checkpoint_path}")
    # The stored epoch is the one that was checkpointed; continue with the next one.
    epoch = epoch + 1

    log_dir = resume_log_dir
else:
    # Fresh run: log into a new timestamped directory under the configured root.
    resume_training = False
    log_dir = os.path.join(train_configs['log_dir_root'], f"{datetime.datetime.now().strftime('%Y_%m_%d_%H-%M-%S')}")
    epoch = 0
    steps = 0

In [6]:
# Logger for this run; it receives the three config file paths and whether the run is a resume.
writer_config_files = [model_configs_dir, train_configs_dir, task_configs_dir]
writer = train_utils.LazyInitSummaryWriter(
    log_dir=log_dir,
    resume=resume_training,
    config_files=writer_config_files,
)

# Main Loop

In [7]:
# Pull the loop hyper-parameters out of the training config into short local names.
max_epochs, n_epochs_per_checkpoint, n_samples_x_ref, contact_radius = (
    train_configs[key]
    for key in ('max_epochs', 'n_epochs_per_checkpoint', 'n_samples_x_ref', 'contact_radius')
)

In [8]:
# Main training loop.
# Each epoch runs one optimisation step per demonstration. Every `n_epochs_per_checkpoint`
# epochs it logs 3D visualisations and saves a checkpoint.
# NOTE: the visualisation/checkpoint section reads the variables left over from the LAST
# batch of the epoch (scene_input, grasp_input, T_target, T, time_in, eps, std, grasp_out).
for epoch in range(epoch, max_epochs+1):
    for demo_batch in train_dataloader:
        B = len(demo_batch)
        assert B == 1, "Batch training is not supported yet."

        optimizer.zero_grad(set_to_none=True)

        scene_input, grasp_input, T_target = train_utils.flatten_batch(demo_batch=demo_batch) # T_target: (Nbatch, Ngrasps, 7)
        T_target = T_target.squeeze(0) # (B=1, N_poses=1, 7) -> (1,7) 

        # Diffusion time drawn uniformly from [min_time, max_time).
        min_time = torch.tensor([1e-3], dtype=T_target.dtype, device=T_target.device)
        max_time = torch.tensor([1.0], dtype=T_target.dtype, device=T_target.device)
        time_in = (min_time/max_time + torch.rand(1, dtype=T_target.dtype, device=T_target.device) * (1-min_time/max_time))*max_time   # Shape: (1,)
        # Alternative (disabled): log-uniform sampling of the diffusion time.
        #time_in = torch.exp(torch.rand_like(max_time) * (torch.log(max_time)-torch.log(min_time)) + torch.log(min_time))              # Shape: (1,)

        # Noise parameters handed to the SE(3) diffusion, both derived from the sampled time.
        eps = time_in / 2   # Shape: (1,)
        std = torch.sqrt(time_in) * score_model.lin_mult   # Shape: (1,)

        # Reference points are sampled with the scene expressed in the target-pose frame
        # (scene transformed by the inverse of T_target) against the grasp points.
        # NOTE(review): exact sampling semantics live in diffusion_edf.utils.sample_reference_points.
        x_ref, n_neighbors = sample_reference_points(
            src_points = PointCloud(points=scene_input.x, colors=scene_input.f).transformed(
                                    SE3(T_target).inv(), squeeze=True
                                    ).points, 
            dst_points = grasp_input.x, 
            r=contact_radius, 
            n_samples=n_samples_x_ref
        )
        # Alternative (disabled): non-batched diffusion.
        # T, delta_T, (gt_ang_score, gt_lin_score), (gt_ang_score_ref, gt_lin_score_ref) = diffuse_isotropic_se3(T0 = T_target, eps=eps, std=std, x_ref=x_ref, double_precision=True)
        T, delta_T, (gt_ang_score, gt_lin_score), (gt_ang_score_ref, gt_lin_score_ref) = diffuse_isotropic_se3_batched(T0 = T_target, eps=eps, std=std, x_ref=x_ref, double_precision=True)
        T, delta_T, (gt_ang_score, gt_lin_score), (gt_ang_score_ref, gt_lin_score_ref) = T.squeeze(-2), delta_T.squeeze(-2), (gt_ang_score.squeeze(-2), gt_lin_score.squeeze(-2)), (gt_ang_score_ref.squeeze(-2), gt_lin_score_ref.squeeze(-2))
        # T: (nT, 7) || delta_T: (nT, 7) || gt_*_score_*: (nT, 3) ||
        # Note that nT = n_samples_x_ref * nT_target  ||   nT_target = 1


        # One time value per diffused pose.
        time_in = time_in.repeat(len(T))

        loss, fp_info, tensor_info, statistics = score_model.get_train_loss(Ts=T, time=time_in, key_pcd=scene_input, query_pcd=grasp_input,
                                                                            target_ang_score=gt_ang_score, target_lin_score=gt_lin_score)
        # scene_out: FeaturedPoints = fp_info['key_fp']
        grasp_out: FeaturedPoints = fp_info['query_fp']

        loss.backward()
        optimizer.step()

        # Log every training statistic against the global step counter.
        with torch.no_grad():
            for tag, scalar_value in statistics.items():
                writer.add_scalar(tag=tag, scalar_value=scalar_value, global_step=steps)
        steps += 1

    if epoch % n_epochs_per_checkpoint == 0:
        # Visualisation step index shared by all 3D summaries of this checkpoint.
        vis_step = epoch//n_epochs_per_checkpoint

        with torch.no_grad():
            scene_pcd = PointCloud(points=scene_input.x, colors=scene_input.f)
            grasp_pcd = PointCloud(points=grasp_input.x, colors=grasp_input.f)
            # Scene merged with the grasp placed at the demonstrated target pose.
            target_pose_pcd = PointCloud.merge(
                scene_pcd,
                grasp_pcd.transformed(SE3(T_target), squeeze=True),
            )
            # Scene merged with the grasp placed at the first diffused pose.
            diffused_pose_pcd = PointCloud.merge(
                scene_pcd,
                grasp_pcd.transformed(SE3(T))[0],
            )
            # scene_attn_pcd = PointCloud(points=scene_out.x.detach().cpu(), 
            #                             colors=scene_out.w.detach().cpu(),
            #                             cmap='magma')
            grasp_attn_pcd = PointCloud(points=grasp_out.x.detach().cpu(), 
                                        colors=grasp_out.w.detach().cpu(),
                                        cmap='magma')
        
            # Keep only the query points that belong to batch element 0.
            query_weight, query_points, query_point_batch = grasp_out.w.detach(), grasp_out.x.detach(), grasp_out.b.detach()
            batch_vis_idx = (query_point_batch == 0).nonzero().squeeze(-1)
            query_weight, query_points = query_weight[batch_vis_idx], query_points[batch_vis_idx]

            # Draw each query point as N_repeat random points on a sphere of radius
            # r_query_ball around it, coloured by a base colour scaled with the point's weight.
            N_repeat = 500
            query_points_colors = torch.tensor([0.01, 1., 1.], device=query_weight.device, dtype=query_weight.dtype).expand(N_repeat, 1, 3) * query_weight[None, :, None]
            r_query_ball = 0.5

            ball = torch.randn(N_repeat,1,3, device=query_points.device, dtype=query_points.dtype)
            ball = ball/ball.norm(dim=-1, keepdim=True) * r_query_ball
            query_points = (query_points + ball).reshape(-1,3)
            query_points_colors = query_points_colors.reshape(-1,3)

        # writer.add_3d(
        #     tag = "Scene Attention",
        #     data = {
        #         "vertex_positions": scene_attn_pcd.points.cpu(),
        #         "vertex_colors": scene_attn_pcd.colors.cpu(),  # (N, 3)
        #     },
        #     step=vis_step,
        # )

        writer.add_3d(
            tag = "Grasp Attention",
            data = {
                # Disabled workaround: the 3D summary misbehaved with very few points, so the
                # points used to be tiled here. The sphere sampling above already multiplies
                # the point count by N_repeat.
                # "vertex_positions": query_points.repeat(max(int(1000//len(query_points)),1),1).cpu(),
                # "vertex_colors": query_points_colors.repeat(max(int(1000//len(query_points)),1),1).cpu(),  # (N, 3)
                "vertex_positions": query_points.cpu(),
                "vertex_colors": query_points_colors.cpu(),  # (N, 3)
            },
            step=vis_step,
        )

        writer.add_3d(
            tag = "Target Pose",
            data = {
                "vertex_positions": target_pose_pcd.points.cpu(),
                "vertex_colors": target_pose_pcd.colors.cpu(),  # (N, 3)
            },
            step=vis_step,
        )

        writer.add_3d(
            tag = "Diffused Pose",
            data = {
                "vertex_positions": diffused_pose_pcd.points.cpu(),
                "vertex_colors": diffused_pose_pcd.colors.cpu(),  # (N, 3)
            },
            step=vis_step,
            description=f"Diffuse time: {time_in[0].item()} || eps: {eps.item()} || std: {std.item()}",
        )

        writer.add_3d(
            tag = "Grasp",
            data = {
                "vertex_positions": grasp_pcd.points.cpu(),
                "vertex_colors": grasp_pcd.colors.cpu(),  # (N, 3)
            },
            step=vis_step,
        )

        # Make sure the checkpoint directory exists instead of relying on the writer
        # having created it as a side effect.
        checkpoint_dir = os.path.join(log_dir, 'checkpoint')
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save({'epoch': epoch,
                    'steps': steps,
                    'score_model_state_dict': score_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    }, os.path.join(checkpoint_dir, f'{epoch}.pt'))
        
        print(f"(Epoch: {epoch}) Successfully saved logs to: {log_dir}")

(Epoch: 0) Successfully saved logs to: runs/2023_05_28_00-44-00
(Epoch: 20) Successfully saved logs to: runs/2023_05_28_00-44-00
(Epoch: 40) Successfully saved logs to: runs/2023_05_28_00-44-00
(Epoch: 60) Successfully saved logs to: runs/2023_05_28_00-44-00
(Epoch: 80) Successfully saved logs to: runs/2023_05_28_00-44-00
(Epoch: 100) Successfully saved logs to: runs/2023_05_28_00-44-00
(Epoch: 120) Successfully saved logs to: runs/2023_05_28_00-44-00
(Epoch: 140) Successfully saved logs to: runs/2023_05_28_00-44-00
(Epoch: 160) Successfully saved logs to: runs/2023_05_28_00-44-00
