In [1]:
import os
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable

from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from open3d.visualization.tensorboard_plugin import summary
from torch.utils.tensorboard import SummaryWriter

from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, PointCloud, SE3
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge, flatten_featured_points, set_featured_points_attribute, _featured_points_repr
from diffusion_edf import train_utils
from diffusion_edf import preprocess
from diffusion_edf import transforms
from diffusion_edf.feature_extractor import UnetFeatureExtractor
from diffusion_edf.tensor_field import TensorField
from diffusion_edf.radial_func import SinusoidalPositionEmbeddings
from diffusion_edf.equivariant_score_model import ScoreModel
from diffusion_edf.utils import sample_reference_points
from diffusion_edf.dist import diffuse_isotropic_se3, adjoint_inv_tr_isotropic_se3_score, diffuse_isotropic_se3_batched

# Print tensors with 4 decimal places and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# ---- Run configuration ------------------------------------------------------
device = 'cuda:0'
# eval = True
compile = False  # NOTE: shadows the builtin `compile`; the name is kept because a later cell reads it.

model_configs_dir = 'configs/test/model_configs.yaml'
train_configs_dir = 'configs/test/train_configs.yaml'
task_configs_dir = 'configs/test/task_configs.yaml'


def _load_yaml_config(path):
    """Parse the YAML file at `path` with yaml.FullLoader and return the result."""
    with open(path) as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)


model_configs, train_configs, task_configs = (
    _load_yaml_config(path)
    for path in (model_configs_dir, train_configs_dir, task_configs_dir)
)

# Append a rescaling step to the configured preprocessing pipeline; the factor
# is the reciprocal of the task's `unit_length`.
rescale_step = {
    'name': 'Rescale',
    'kwargs': {'rescale_factor': 1/task_configs['unit_length']},
}
train_configs['preprocess_config'].append(rescale_step)

# Load demo

In [3]:
# Build the preprocessing pipeline: each config entry names a class in the
# `preprocess` module and supplies its constructor kwargs.
proc_fn = Compose([
    getattr(preprocess, step['name'])(**step['kwargs'])
    for step in train_configs['preprocess_config']
])

# The collate function applies `proc_fn` while assembling a batch for the configured task type.
collate_fn = train_utils.get_collate_fn(
    task=train_configs['task_type'],
    proc_fn=proc_fn,
)
trainset = DemoSeqDataset(
    dataset_dir=train_configs['dataset_dir'],
    annotation_file=train_configs['annotation_file'],
    device=device,
)
# NOTE: the config key `n_batches` is used as the DataLoader batch size.
train_dataloader = DataLoader(
    trainset,
    shuffle=True,
    collate_fn=collate_fn,
    batch_size=train_configs['n_batches'],
)

# Load Model

In [4]:
# Instantiate the score model from the YAML model config and move it to the target device.
score_model = ScoreModel(**model_configs, deterministic=False)
score_model = score_model.to(device=device)

# Compilation is not supported yet; fail loudly if it was requested.
if compile:
    raise NotImplementedError
# if eval:
#     score_model = score_model.eval()

# Adam over every model parameter; remaining settings come from the train config.
trainable_params = list(score_model.parameters())
optimizer = torch.optim.Adam(trainable_params, **train_configs['optimizer_kwargs'])

ScoreModel: Initializing Score Head




ScoreModel: Initializing Key Feature Extractor
ScoreModel: Initializing Query Feature Extractor


In [5]:
# Decide whether this run resumes an earlier one, then create the TensorBoard
# writer and the checkpoint folder.
#
# Set `resume_log_dir` to an existing run directory to resume from it; leave it
# as None to start a fresh, timestamped run under `log_dir_root`.
# `resume_checkpoint_dir` is the checkpoint *file name* inside
# `<resume_log_dir>/checkpoint`; when None, the highest-numbered `<epoch>.pt`
# file is used.
resume_log_dir: Optional[str] = None
#resume_log_dir: Optional[str] = 'runs/2023_04_21_00-34-42'
resume_checkpoint_dir: Optional[str] = None

if resume_log_dir is not None:
    resume_training = True
    if resume_checkpoint_dir is None:
        checkpoint_root = os.path.join(resume_log_dir, 'checkpoint')
        # Only consider '<int>.pt' files. os.path.splitext removes the real
        # extension, whereas str.rstrip('.pt') strips a *set of characters*
        # and int() would crash on any unrelated file in the folder.
        checkpoint_files = [
            f for f in os.listdir(checkpoint_root)
            if os.path.splitext(f)[1] == '.pt' and os.path.splitext(f)[0].isdigit()
        ]
        if not checkpoint_files:
            raise FileNotFoundError(f"No '<epoch>.pt' checkpoint found in: {checkpoint_root}")
        resume_checkpoint_dir = max(checkpoint_files, key=lambda f: int(os.path.splitext(f)[0]))

    # Ask for explicit confirmation before continuing an existing run.
    resume_checkpoint_path = os.path.join(resume_log_dir, 'checkpoint', resume_checkpoint_dir)
    if input(f"Enter 'y' if you want to resume training from checkpoint: {resume_checkpoint_path}") != 'y':
        raise ValueError(f"Resuming from checkpoint was not confirmed: {resume_checkpoint_path}")
else:
    resume_training = False
    resume_log_dir = os.path.join(
        train_configs['log_dir_root'],
        datetime.datetime.now().strftime('%Y_%m_%d_%H-%M-%S'),
    )

writer = SummaryWriter(log_dir=resume_log_dir)
log_dir = writer.log_dir

# Create the checkpoint folder if needed (no-op when it already exists).
os.makedirs(os.path.join(log_dir, 'checkpoint'), exist_ok=True)

In [6]:
# Restore model/optimizer state and the epoch/step counters when resuming;
# otherwise start counting from zero.
if resume_training:
    checkpoint_path = os.path.join(log_dir, 'checkpoint', resume_checkpoint_dir)
    # map_location makes the checkpoint load onto the configured `device`
    # instead of the device it was saved from (which may not exist here).
    checkpoint = torch.load(checkpoint_path, map_location=device)
    score_model.load_state_dict(checkpoint['score_model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    steps = checkpoint['steps']
    print(f"resume training from checkpoint: {checkpoint_path}")
    # The checkpoint records the epoch it was written at, so continue with the next one.
    epoch = checkpoint['epoch'] + 1
else:
    epoch = 0
    steps = 0

# Main Loop

In [7]:
# Read the training-loop hyper-parameters from the train config in one go.
(
    max_epochs,
    n_epochs_per_checkpoint,
    n_samples_x_ref,
    contact_radius,
) = (
    train_configs[key]
    for key in ('max_epochs', 'n_epochs_per_checkpoint', 'n_samples_x_ref', 'contact_radius')
)

In [8]:
# Main training loop.
# Starts at the current value of `epoch` (0 for a fresh run, or the resumed
# epoch) and runs through `max_epochs` inclusive. Each step samples a diffusion
# time, diffuses the target pose, runs the score model and regresses its output
# onto the (time-scaled) ground-truth score. Every `n_epochs_per_checkpoint`
# epochs a checkpoint and several 3D summaries are written.
for epoch in range(epoch, max_epochs+1):
    for demo_batch in train_dataloader:
        B = len(demo_batch)
        assert B == 1, "Batch training is not supported yet."

        optimizer.zero_grad(set_to_none=True)

        scene_input, grasp_input, T_target = train_utils.flatten_batch(demo_batch=demo_batch) # T_target: (Nbatch, Ngrasps, 7)
        T_target = T_target.squeeze(0) # (B=1, N_poses=1, 7) -> (1,7)

        # Sample one diffusion time uniformly from [min_time, max_time).
        min_time = torch.tensor([1e-3], dtype=T_target.dtype, device=T_target.device)
        max_time = torch.tensor([1.0], dtype=T_target.dtype, device=T_target.device)
        time_in = (min_time/max_time + torch.rand(1, dtype=T_target.dtype, device=T_target.device) * (1-min_time/max_time))*max_time   # Shape: (1,)
        # Alternative (disabled): sample the time log-uniformly instead.
        #time_in = torch.exp(torch.rand_like(max_time) * (torch.log(max_time)-torch.log(min_time)) + torch.log(min_time))              # Shape: (1,)

        # Noise parameters passed to the diffusion routine below, both derived from the sampled time.
        eps = time_in / 2   # Shape: (1,)
        std = torch.sqrt(time_in) * score_model.lin_mult   # Shape: (1,)
        # Sample reference points: the scene points are first transformed by the
        # inverse of the target pose, then matched against the grasp points
        # within `contact_radius`.
        # NOTE(review): exact sampling semantics live in `sample_reference_points` - confirm there.
        x_ref, n_neighbors = sample_reference_points(
            src_points = PointCloud(points=scene_input.x, colors=scene_input.f).transformed(
                                    SE3(T_target).inv(), squeeze=True
                                    ).points, 
            dst_points = grasp_input.x, 
            r=contact_radius, 
            n_samples=n_samples_x_ref
        )
        # Diffuse the target pose; returns the diffused poses, the applied perturbation,
        # and ground-truth scores (plain and "_ref" variants).
        # T, delta_T, (gt_ang_score, gt_lin_score), (gt_ang_score_ref, gt_lin_score_ref) = diffuse_isotropic_se3(T0 = T_target, eps=eps, std=std, x_ref=x_ref, double_precision=True)
        T, delta_T, (gt_ang_score, gt_lin_score), (gt_ang_score_ref, gt_lin_score_ref) = diffuse_isotropic_se3_batched(T0 = T_target, eps=eps, std=std, x_ref=x_ref, double_precision=True)
        # Drop the second-to-last axis of every output of the batched routine.
        T, delta_T, (gt_ang_score, gt_lin_score), (gt_ang_score_ref, gt_lin_score_ref) = T.squeeze(-2), delta_T.squeeze(-2), (gt_ang_score.squeeze(-2), gt_lin_score.squeeze(-2)), (gt_ang_score_ref.squeeze(-2), gt_lin_score_ref.squeeze(-2))
        # T: (nT, 7) || delta_T: (nT, 7) || gt_*_score_*: (nT, 3) ||
        # Note that nT = n_samples_x_ref * nT_target  ||   nT_target = 1


        # One time value per diffused pose.
        time_in_ = time_in.repeat(len(T))
        # Forward pass: predicted angular/linear scores plus the extracted
        # scene/grasp features (the latter are only used for visualization below).
        (ang_score, lin_score), (scene_out, grasp_out) = score_model(Ts=T, time=time_in_,
                                                                     key_pcd=scene_input, 
                                                                     query_pcd=grasp_input, 
                                                                     extract_features = True,
                                                                     debug = True)

        # The "_ref" predictions are computed without gradients; they are only logged, not trained on.
        with torch.no_grad():
            ang_score_ref, lin_score_ref = adjoint_inv_tr_isotropic_se3_score(x_ref=-x_ref, ang_score=ang_score, lin_score=lin_score)
        # Regression targets: ground-truth scores scaled by sqrt(time).
        target_ang_score = gt_ang_score * torch.sqrt(time_in)
        target_lin_score = gt_lin_score * torch.sqrt(time_in)
        target_ang_score_ref = gt_ang_score_ref * torch.sqrt(time_in)
        target_lin_score_ref = gt_lin_score_ref * torch.sqrt(time_in)


        # Loss: squared error summed over the last axis and averaged over poses;
        # the linear residual is additionally multiplied by `score_model.lin_mult`.
        ang_score_diff = target_ang_score - ang_score
        lin_score_diff = target_lin_score - lin_score
        # ang_loss = torch.norm(ang_score_diff, dim=-1).mean(dim=-1)
        # lin_loss = torch.norm(lin_score_diff * lin_mult, dim=-1).mean(dim=-1)
        ang_loss = torch.sum(torch.square(ang_score_diff), dim=-1).mean(dim=-1)
        lin_loss = torch.sum(torch.square(lin_score_diff * score_model.lin_mult), dim=-1).mean(dim=-1)
        loss = ang_loss + lin_loss

        loss.backward()
        optimizer.step()


        # Per-step scalar logging (losses, score norms, prediction/target alignment).
        with torch.no_grad():
            writer.add_scalar(tag="Loss/train", scalar_value=loss.item(), global_step=steps)
            writer.add_scalar(tag="Loss/angular", scalar_value=ang_loss.item(), global_step=steps)
            writer.add_scalar(tag="Loss/linear", scalar_value=lin_loss.item(), global_step=steps)

            target_norm_ang, target_norm_lin = torch.norm(target_ang_score.detach(), dim=-1), torch.norm(target_lin_score.detach(), dim=-1) # Shape: (Nbatch, ), (Nbatch, )
            score_norm_ang, score_norm_lin = torch.norm(ang_score.detach(), dim=-1), torch.norm(lin_score.detach(), dim=-1)         # Shape: (Nbatch, ), (Nbatch, )
            writer.add_scalar(tag="norm/target_ang", scalar_value=target_norm_ang.mean(dim=-1).item(), global_step=steps)
            writer.add_scalar(tag="norm/target_lin", scalar_value=target_norm_lin.mean(dim=-1).item(), global_step=steps)
            writer.add_scalar(tag="norm/inferred_ang", scalar_value=score_norm_ang.mean(dim=-1).item(), global_step=steps)
            writer.add_scalar(tag="norm/inferred_lin", scalar_value=score_norm_lin.mean(dim=-1).item(), global_step=steps)

            target_norm_ang_ref, target_norm_lin_ref = torch.norm(target_ang_score_ref.detach(), dim=-1), torch.norm(target_lin_score_ref.detach(), dim=-1) # Shape: (Nbatch, ), (Nbatch, )
            score_norm_ang_ref, score_norm_lin_ref = torch.norm(ang_score_ref.detach(), dim=-1), torch.norm(lin_score_ref.detach(), dim=-1)         # Shape: (Nbatch, ), (Nbatch, )
            writer.add_scalar(tag="norm_ref/target_ang", scalar_value=target_norm_ang_ref.mean(dim=-1).item(), global_step=steps)
            writer.add_scalar(tag="norm_ref/target_lin", scalar_value=target_norm_lin_ref.mean(dim=-1).item(), global_step=steps)
            writer.add_scalar(tag="norm_ref/inferred_ang", scalar_value=score_norm_ang_ref.mean(dim=-1).item(), global_step=steps)
            writer.add_scalar(tag="norm_ref/inferred_lin", scalar_value=score_norm_lin_ref.mean(dim=-1).item(), global_step=steps)

            # Dot product between predicted and target scores; the normalized variant
            # divides by both norms (cosine similarity).
            # NOTE(review): a zero norm would make the normalized value NaN/inf - no guard here.
            dp_align_ang = torch.einsum('...i,...i->...', ang_score.detach(), target_ang_score.detach()) # Shape: (Nbatch, )
            dp_align_lin = torch.einsum('...i,...i->...', lin_score.detach(), target_lin_score.detach()) # Shape: (Nbatch, )
            dp_align_ang_normalized = dp_align_ang / target_norm_ang / score_norm_ang # Shape: (Nbatch, )
            dp_align_lin_normalized = dp_align_lin / target_norm_lin / score_norm_lin # Shape: (Nbatch, )
            writer.add_scalar(tag="alignment/unnormalized/ang", scalar_value=dp_align_ang.mean(dim=-1).item(), global_step=steps)
            writer.add_scalar(tag="alignment/unnormalized/lin", scalar_value=dp_align_lin.mean(dim=-1).item(), global_step=steps)
            writer.add_scalar(tag="alignment/normalized/ang", scalar_value=dp_align_ang_normalized.mean(dim=-1).item(), global_step=steps)
            writer.add_scalar(tag="alignment/normalized/lin", scalar_value=dp_align_lin_normalized.mean(dim=-1).item(), global_step=steps)

        steps += 1

    # Periodic checkpoint + 3D summaries. Everything below runs after the inner
    # loop, so it uses the tensors left over from the LAST batch of this epoch
    # (and would fail with a NameError if the dataloader yielded nothing).
    if epoch % n_epochs_per_checkpoint == 0:
        torch.save({'epoch': epoch,
                    'steps': steps,
                    'score_model_state_dict': score_model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    }, os.path.join(log_dir, f'checkpoint/{epoch}.pt'))

        with torch.no_grad():
            scene_pcd = PointCloud(points=scene_input.x, colors=scene_input.f)
            grasp_pcd = PointCloud(points=grasp_input.x, colors=grasp_input.f)
            # Scene merged with the grasp point cloud placed at the target pose.
            target_pose_pcd = PointCloud.merge(
                scene_pcd,
                grasp_pcd.transformed(SE3(T_target), squeeze=True),
            )
            # Scene merged with the grasp point cloud placed at the first diffused pose.
            diffused_pose_pcd = PointCloud.merge(
                scene_pcd,
                grasp_pcd.transformed(SE3(T))[0],
            )
            
            ##### Query Summary #####
            query_weight, query_points, query_point_batch = grasp_out.w.detach(), grasp_out.x.detach(), grasp_out.b.detach(), 
            # Keep only the query points whose batch index is 0.
            batch_vis_idx = (query_point_batch == 0).nonzero().squeeze(-1)
            query_weight, query_points = query_weight[batch_vis_idx], query_points[batch_vis_idx]

            # Draw each query point as N_repeat random points on a sphere of radius
            # r_query_ball around it, colored by a fixed RGB value scaled by the point's weight.
            N_repeat = 500
            query_points_colors = torch.tensor([0.01, 1., 1.], device=query_weight.device, dtype=query_weight.dtype).expand(N_repeat, 1, 3) * query_weight[None, :, None]
            r_query_ball = 0.5

            ball = torch.randn(N_repeat,1,3, device=query_points.device, dtype=query_points.dtype)
            ball = ball/ball.norm(dim=-1, keepdim=True) * r_query_ball
            query_points = (query_points + ball).reshape(-1,3)
            query_points_colors = query_points_colors.reshape(-1,3)

            #########################
            # Scene points colored by the `w` field of the scene features using the 'magma' colormap.
            scene_attn_pcd = PointCloud(points=scene_out.x.detach().cpu(), 
                                        colors=scene_out.w.detach().cpu(),
                                        cmap='magma')
        
        # 3D summaries are indexed by checkpoint number rather than by epoch.
        writer.add_3d(
            tag = "Target Pose",
            data = {
                "vertex_positions": target_pose_pcd.points.cpu(),
                "vertex_colors": target_pose_pcd.colors.cpu(),  # (N, 3)
            },
            step=epoch//n_epochs_per_checkpoint,
        )

        writer.add_3d(
            tag = "Diffused Pose",
            data = {
                "vertex_positions": diffused_pose_pcd.points.cpu(),
                "vertex_colors": diffused_pose_pcd.colors.cpu(),  # (N, 3)
            },
            step=epoch//n_epochs_per_checkpoint,
            description=f"Diffuse time: {time_in.item()} || eps: {eps.item()} || std: {std.item()}",
        )

        writer.add_3d(
            tag = "Grasp",
            data = {
                "vertex_positions": grasp_pcd.points.cpu(),
                "vertex_colors": grasp_pcd.colors.cpu(),  # (N, 3)
            },
            step=epoch//n_epochs_per_checkpoint,
        )

        writer.add_3d(
            tag = "Query points",
            data = {
                # Earlier workaround (disabled): tile the points up to ~1000 entries, because
                # the 3D summary reportedly misbehaves with too few points.
                # "vertex_positions": query_points.repeat(max(int(1000//len(query_points)),1),1).cpu(),      # There is a bug with too small number of points so repeat
                # "vertex_colors": query_points_colors.repeat(max(int(1000//len(query_points)),1),1).cpu(),  # (N, 3)
                "vertex_positions": query_points.cpu(),      # already replicated N_repeat times above
                "vertex_colors": query_points_colors.cpu(),  # (N, 3)
            },
            step=epoch//n_epochs_per_checkpoint,
        )

        writer.add_3d(
            tag = "Scene Attention",
            data = {
                "vertex_positions": scene_attn_pcd.points.cpu(),
                "vertex_colors": scene_attn_pcd.colors.cpu(),  # (N, 3)
            },
            step=epoch//n_epochs_per_checkpoint,
        )

        
        print(f"(Epoch: {epoch}) Successfully saved logs to: {log_dir}")

(Epoch: 0) Successfully saved logs to: runs/2023_05_25_13-47-46
(Epoch: 20) Successfully saved logs to: runs/2023_05_25_13-47-46
(Epoch: 40) Successfully saved logs to: runs/2023_05_25_13-47-46
(Epoch: 60) Successfully saved logs to: runs/2023_05_25_13-47-46
(Epoch: 80) Successfully saved logs to: runs/2023_05_25_13-47-46
(Epoch: 100) Successfully saved logs to: runs/2023_05_25_13-47-46
(Epoch: 120) Successfully saved logs to: runs/2023_05_25_13-47-46
(Epoch: 140) Successfully saved logs to: runs/2023_05_25_13-47-46
(Epoch: 160) Successfully saved logs to: runs/2023_05_25_13-47-46
(Epoch: 180) Successfully saved logs to: runs/2023_05_25_13-47-46


In [None]:
# Ad-hoc inspection: display the `time_emb_dim` attribute of the score head's
# key tensor field (this cell was not executed in the saved run).
score_model.score_head.key_tensor_field.time_emb_dim