In [1]:
# Cell 1: environment setup, imports, and tensor display options.
# NOTE(review): statement order in this cell matters -- the env var below is set
# before `import torch`, and one of the imports appears to initialise Open3D's
# web visualizer (see this cell's output). Do not reorder casually.
import os
# Set before `import torch` below; presumably so PyTorch's JIT picks the NNC
# fuser instead of nvFuser when it initialises -- TODO confirm the variable is
# only read at import time.
os.environ["PYTORCH_JIT_USE_NNC_NOT_NVFUSER"] = "1"
from typing import List, Tuple, Optional, Union, Iterable
import math

# NOTE(review): of the imports below, only `torch`, `train_utils` and
# `DiffusionEdfTrainer` are used by the cells shown in this notebook; the rest
# look unused here and are candidates for pruning.
from beartype import beartype
import datetime
import plotly.graph_objects as go
from tqdm import tqdm
import yaml

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
from e3nn import o3

from diffusion_edf.data import DemoSeqDataset, DemoSequence, TargetPoseDemo, PointCloud, SE3
from diffusion_edf.gnn_data import FeaturedPoints, merge_featured_points, GraphEdge, flatten_featured_points, set_featured_points_attribute, _featured_points_repr
from diffusion_edf import train_utils
from diffusion_edf import transforms
from diffusion_edf.utils import sample_reference_points
from diffusion_edf.dist import diffuse_isotropic_se3, adjoint_inv_tr_isotropic_se3_score, diffuse_isotropic_se3_batched
from diffusion_edf.point_attentive_score_model import PointAttentiveScoreModel
from diffusion_edf.trainer import DiffusionEdfTrainer

# Print tensors with 4 decimal places and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Select the config variant in ONE place. Both the config directory and the
# run's log postfix are derived from it, so switching to 'lowres' can no longer
# leave the run's logs labelled "Highres".
CONFIG_VARIANT = 'highres'  # one of: 'highres', 'lowres'
configs_root_dir = f'configs/pick_{CONFIG_VARIANT}'
train_configs_file = 'train_configs.yaml'
task_configs_file = 'task_configs.yaml'
trainer = DiffusionEdfTrainer(configs_root_dir=configs_root_dir,
                              train_configs_file=train_configs_file,
                              task_configs_file=task_configs_file)
# `init_epoch` is the first epoch of the training loop in the next cell.
# 'highres' -> postfix "Stable_Highres", 'lowres' -> "Stable_Lowres".
init_epoch = trainer.init(
    log_name=trainer.get_current_time(postfix=f"Stable_{CONFIG_VARIANT.capitalize()}"),
    resume_training=False,
)

ScoreModel: Initializing Key Feature Extractor




ScoreModel: Initializing Query Model
ScoreModel: Initializing Score Head


# Main Training Loop

In [3]:
# Train for epochs init_epoch..max_epochs (inclusive). On the final batch of
# every epoch divisible by n_epochs_per_checkpoint, train_once is asked to save
# a checkpoint.
for epoch in range(init_epoch, trainer.max_epochs + 1):
    for batch_idx, demo_batch in enumerate(trainer.trainloader):
        batch_size = len(demo_batch)
        assert batch_size == 1, "Batch training is not supported yet."

        # flatten_batch yields T_target as (Nbatch, Ngrasps, 7); since Nbatch is
        # forced to 1 above, drop the leading axis: (1, 1, 7) -> (1, 7).
        scene_input, grasp_input, T_target = train_utils.flatten_batch(demo_batch=demo_batch)
        T_target = T_target.squeeze(0)

        is_checkpoint_epoch = epoch % trainer.n_epochs_per_checkpoint == 0
        # Short-circuit: the loader length is only consulted on checkpoint epochs.
        should_save = is_checkpoint_epoch and batch_idx == len(trainer.trainloader) - 1
        trainer.train_once(
            T_target=T_target,
            scene_input=scene_input,
            grasp_input=grasp_input,
            epoch=epoch,
            save_checkpoint=should_save,
            checkpoint_count=epoch // trainer.n_epochs_per_checkpoint,
        )

(Epoch: 0) Successfully saved logs to: runs/2023_06_01_18-07-22_Stable_Highres
(Epoch: 20) Successfully saved logs to: runs/2023_06_01_18-07-22_Stable_Highres
(Epoch: 40) Successfully saved logs to: runs/2023_06_01_18-07-22_Stable_Highres
(Epoch: 60) Successfully saved logs to: runs/2023_06_01_18-07-22_Stable_Highres
(Epoch: 80) Successfully saved logs to: runs/2023_06_01_18-07-22_Stable_Highres
(Epoch: 100) Successfully saved logs to: runs/2023_06_01_18-07-22_Stable_Highres
(Epoch: 120) Successfully saved logs to: runs/2023_06_01_18-07-22_Stable_Highres
(Epoch: 140) Successfully saved logs to: runs/2023_06_01_18-07-22_Stable_Highres
(Epoch: 160) Successfully saved logs to: runs/2023_06_01_18-07-22_Stable_Highres
(Epoch: 180) Successfully saved logs to: runs/2023_06_01_18-07-22_Stable_Highres
(Epoch: 200) Successfully saved logs to: runs/2023_06_01_18-07-22_Stable_Highres
