In [1]:
# !ml cuda/11.8.0

import os
## For setup
import torch
from configs.get_config import get_config_from_yml
from GINN.shape_boundary_helper import ShapeBoundaryHelper
from GINN.helpers.mp_manager import MPManager
from GINN.helpers.timer_helper import TimerHelper
from GINN.morse.scc_surfacenet_manager import SCCSurfaceNetManager
from GINN.problem_sampler import ProblemSampler
from GINN.visualize.plotter_3d import Plotter3d
from train.train_utils.autoclip import AutoClip
from utils import get_model, get_stateless_net_with_partials

from neural_clbf.controllers.simple_neural_cbf_controller import SimpleNeuralCBFController
from neural_clbf.systems.simple3d import Simple3DRobot 

## For extracting and plotting a mesh
import k3d
from notebooks.notebook_utils import get_mesh_for_latent

## For running a training loop
import einops
from tqdm import trange
from models.model_utils import tensor_product_xz
from train.losses import closest_shape_diversity_loss, eikonal_loss, envelope_loss, interface_loss, normal_loss_euclidean, obstacle_interior_loss, strain_curvature_loss
from train.train_utils.latent_sampler import sample_new_z
from utils import set_all_seeds

(CVXPY) Dec 06 11:08:14 AM: Encountered unexpected exception importing solver DIFFCP:
ImportError('diffcp >= 1.0.15 is required')


  warn("Could not import F16 module; is AeroBench installed?")
  _torch_pytree._register_pytree_node(
  warn("Could not import HW module; is ROS installed?")


In [2]:
## Select the GPU first: CUDA_VISIBLE_DEVICES only takes effect if it is set before the
## CUDA runtime is initialised, so it goes before seeding or any CUDA query in this cell.
## If CUDA was already initialised in this kernel (e.g. on a re-run), restart the kernel
## for a change of this value to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
set_all_seeds(5)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# From here on, tensors created without an explicit device are placed on `device`.
torch.set_default_device(device)

## Read the config
yml_path = 'configs/horse.yml'
config = get_config_from_yml(yml_path)
config['device'] = device
print("DATASET_DIR", config['dataset_dir'])

print("DEVICE", device)
print("MODEL:", config['model'])
print("ACTIVATION", config.get('activation', None))

## Create the model and stateless functions and load a checkpoint
model = get_model(config)
# netp exposes the functional form of `model` (f_, params_, vf_x, vf_xx are used below).
netp = get_stateless_net_with_partials(model, use_x_and_z_arg=config['use_x_and_z_arg'])
# model.load_state_dict(torch.load('_quickstart/trained_model_3d.pt', map_location=device))

## Create different helpers for ...
## ... the problem definition
p_sampler = ProblemSampler(config)
## ... multiprocessing to create plots on a non-blocking thread
mp_manager = MPManager(config)
## ... recording timings
timer_helper = TimerHelper(config, lock=mp_manager.get_lock())
mp_manager.set_timer_helper(timer_helper)  ## weak circular reference
## ... plotting
plotter = Plotter3d(config)
## ... connectedness computation
scc_manager = SCCSurfaceNetManager(config, netp, mp_manager, plotter, timer_helper, p_sampler, device)
## ... sampling from the shape boundary
shapeb_helper = ShapeBoundaryHelper(config, netp, mp_manager, plotter, timer_helper, p_sampler.sample_from_interface()[0], device)
## ... clipping the gradients
auto_clip = AutoClip(config)

/scratch/rhm4nj/cral/cral-ginn/depth/tmp_outs
DEVICE cuda
MODEL: cond_siren
ACTIVATION None


In [3]:
## Build the dynamics model and the neural CBF controller with CPU as the default device.
## NOTE(review): presumably the neural_clbf classes allocate tensors that must live on the
## CPU - TODO confirm. The `finally` restores the notebook-wide default device even if
## construction raises, so later cells cannot end up silently allocating on the CPU.
torch.set_default_device('cpu')
try:
    controller_period = 0.05
    simulation_dt = 0.01
    nominal_params = {}
    scenarios = [
        nominal_params,
    ]

    # Define the dynamics model
    dynamics_model = Simple3DRobot(
        nominal_params,
        dt=simulation_dt,
        controller_dt=controller_period,
        scenarios=scenarios,
    )

    # Initialize the controller around `model`, the network created in the setup cell
    cbf_controller = SimpleNeuralCBFController(
        dynamics_model,
        scenarios,
        model,
        cbf_lambda=1.0,
        cbf_relaxation_penalty=1
    )
finally:
    torch.set_default_device(device)

{}


In [4]:
## Extract the 0-level set of the initial (pre-training) model for latent z = -0.1 at a
## coarse marching-cubes resolution, and show it in red as the reference shape.
z = torch.tensor([-0.1])
mesh_checkpoint = get_mesh_for_latent(
    netp.f_,
    netp.params_,
    z,
    config['bounds'],
    mc_resolution=32,
    device=device,
    chunks=1,
)

fig = k3d.plot()
fig += k3d.mesh(*mesh_checkpoint, color=0xff0000, side='double')
fig.display()
# Pin a fixed viewpoint (k3d camera: position, target, up vector) instead of auto-fitting.
fig.camera_auto_fit = False
fig.camera = [
    0.8042741481976844, -1.040350835893895, 0.7038650223301532,
    0.08252720725551285, -0.08146462547370059, -0.1973267630672968,
    -0.3986658507677483, 0.39231188503442904, 0.8289492893370278,
]

Output()

In [5]:
from torch.utils.tensorboard import SummaryWriter
import datetime

## TensorBoard logging, optimizer and training state.
## Runs that use the obstacle loss get an "obst_<lambda>_" prefix in the log directory name.
ext = f"obst_{config['lambda_obst']}_" if config['lambda_obst'] > 0 else ""
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"all_runs/runs_cubehole/{ext}{run_stamp}"
writer = SummaryWriter(log_dir)

opt = torch.optim.Adam(model.parameters(), lr=config['lr'])
# opt.load_state_dict(torch.load('_quickstart/opt_3d.pt', map_location=device))
z = sample_new_z(config, is_init=True)
print(f'Initial z: {z}')

# Training-loop state (p_surface and plots_dict_at_last_epoch are updated by the training cell).
p_surface = None
cur_plot_epoch = 0
plots_dict_at_last_epoch = None
log_history_dict = {}

# Overridden by the training cell, which sets its own max_epochs.
config['max_epochs'] = 100
# Epoch offset; the training cell adds its epoch count to this after each run.
prev_epoch = 0

Initial z: tensor([[-0.1000]], device='cuda:0')


In [8]:
## Training loop. The epoch counter continues from `prev_epoch`, so re-running this cell
## continues the epoch numbering used for TensorBoard logging and checkpoint file names.
config['max_epochs'] = 10
save_interval = 1000
# Global index of the last epoch of this run; a checkpoint is always written there.
final_epoch = prev_epoch + config['max_epochs'] - 1

for epoch in (pbar := trange(config['max_epochs'], leave=True, position=0, colour="yellow")):
    # Shift the local loop index to a global epoch index across repeated runs.
    epoch += prev_epoch

    mp_manager.update_epoch(epoch)
    opt.zero_grad()

    # plotter.reset_output(p_sampler.recalc_output(netp.f_, netp.params_, z), epoch=epoch)
    # plotter.plot_shape(p_sampler.constr_pts_dict)

    ## Connectedness loss: penalize the points returned by the SCC manager
    loss_scc = torch.tensor(0.0)
    if config['lambda_scc'] > 0:
        success, res_tup = scc_manager.get_scc_pts_to_penalize(z, epoch)
        if success:
            p_penalize, p_penalties = res_tup
            print(f'penalize DCs with {len(p_penalize)} points')
            y_saddles_opt = model(p_penalize.data, p_penalize.z_in(z)).squeeze(1)
            loss_scc = config['lambda_scc'] * (y_saddles_opt * p_penalties.data).mean()

    ## Design region loss
    loss_env = torch.tensor(0.0)
    if config['lambda_env'] > 0:
        ys_env = model(*tensor_product_xz(p_sampler.sample_from_envelope(), z)).squeeze(1)
        loss_env = config['lambda_env'] * envelope_loss(ys_env)

    ## Interface loss
    loss_if = torch.tensor(0.0)
    if config['lambda_bc'] > 0:
        ys_BC = model(*tensor_product_xz(p_sampler.sample_from_interface()[0], z)).squeeze(1)
        loss_if = config['lambda_bc'] * interface_loss(ys_BC)

    ## Interface normal loss
    loss_if_normal = torch.tensor(0.0)
    if config['lambda_normal'] > 0:
        pts_normal, target_normal = p_sampler.sample_from_interface()
        ys_normal = netp.vf_x(*tensor_product_xz(pts_normal, z)).squeeze(1)
        # Targets are repeated once per shape in the batch to match ys_normal.
        loss_if_normal = config['lambda_normal'] * normal_loss_euclidean(ys_normal, torch.cat([target_normal for _ in range(config['batch_size'])]))

    ## Obstacle loss (for debugging purposes, it's not considered part of the envelope) TODO: do we leave it like this?
    loss_obst = torch.tensor(0.0)
    if config['lambda_obst'] > 0:
        ys_obst = model(*tensor_product_xz(p_sampler.sample_from_obstacles(), z))
        loss_obst = config['lambda_obst'] * obstacle_interior_loss(ys_obst)

    ## Sample points from the domain if necessary TODO: I think diversity doesnt need domain points anymore? TODO: can we move this up so that all the losses come after each other?
    if config['lambda_eikonal'] > 0 or config['lambda_div'] > 0 or config['lambda_descent'] > 0:
        xs_domain = p_sampler.sample_from_domain()

    ## Eikonal loss
    loss_eikonal = torch.tensor(0.0)
    if config['lambda_eikonal'] > 0:
        y_x_eikonal = netp.vf_x(*tensor_product_xz(xs_domain, z))
        loss_eikonal = config['lambda_eikonal'] * eikonal_loss(y_x_eikonal)

    ## Sample points from the 0-levelset if necessary TODO: can we move this up?
    if config['lambda_div'] > 0 or config['lambda_curv'] > 0:
        if p_surface is None or epoch % config['recompute_surface_pts_every_n_epochs'] == 0:
            p_surface, weights_surf_pts = shapeb_helper.get_surface_pts(z)

    ## Curvature loss
    loss_curv = torch.tensor(0.0)
    if config['lambda_curv'] > 0:
        if p_surface is None:
            print('No surface points found - skipping curvature loss')
        else:
            y_x_surf = netp.vf_x(p_surface.data, p_surface.z_in(z)).squeeze(1)
            y_xx_surf = netp.vf_xx(p_surface.data, p_surface.z_in(z)).squeeze(1)
            loss_curv = config['lambda_curv'] * strain_curvature_loss(y_x_surf, y_xx_surf, clip_max_value=config['strain_curvature_clip_max'],
                                                                            weights=weights_surf_pts)
    ## Diversity loss
    loss_div = torch.tensor(0.0)
    if config['lambda_div'] > 0 and config['batch_size'] > 1:
        if p_surface is None:
            print('No surface points found - skipping diversity loss')
        else:
            y_div = model(*tensor_product_xz(p_surface.data, z)).squeeze(1)  # [(bz k)] whereas k is n_surface_points; evaluate model at all surface points for each shape
            loss_div = config['lambda_div'] * closest_shape_diversity_loss(einops.rearrange(y_div, '(bz k)-> bz k', bz=config['batch_size']),
                                                                                weights=weights_surf_pts)
            # Drop a non-finite diversity term rather than poisoning the total loss.
            if not torch.isfinite(loss_div):
                print(f'NaN or Inf loss_div: {loss_div}')
                loss_div = torch.tensor(0.0)

    ## Descent loss from the neural CBF controller, evaluated on the domain samples
    loss_descent = torch.tensor(0.0)
    if config['lambda_descent'] > 0:
        cbf_controller.set_V_nn(model)

        # The controller receives the spatial coordinates and the latent concatenated
        # along the last dimension.
        xt, zt = tensor_product_xz(xs_domain, z)
        x = torch.cat([xt, zt], dim=-1)
        losses_list = cbf_controller.descent_loss(x)
        # descent_loss returns 2-tuples whose second element is a loss tensor; NaN terms are skipped.
        for _, l in losses_list:
            if not l.isnan():
                loss_descent = loss_descent + l

        # Scale by lambda_descent like every other loss term
        # (`lambda * 1/lambda * loss` would cancel the configured weight).
        loss_descent = config['lambda_descent'] * loss_descent

        print("DESCENT LOSS:", loss_descent)

    loss = loss_env + loss_if + loss_if_normal + loss_obst + loss_eikonal + loss_scc + loss_curv + loss_div + loss_descent

    # Every term (and the total) is logged to TensorBoard below.
    losses = {
        "loss_env": loss_env,
        "loss_if": loss_if,
        "loss_if_normal": loss_if_normal,
        "loss_obst": loss_obst,
        "loss_eikonal": loss_eikonal,
        "loss_scc": loss_scc,
        "loss_curv": loss_curv,
        "loss_div": loss_div,
        "loss_descent": loss_descent,
        "loss": loss,
    }

    ## Gradients with clipping
    loss.backward()
    grad_norm = auto_clip.grad_norm(model.parameters())
    if auto_clip.grad_clip_enabled:
        auto_clip.update_gradient_norm_history(grad_norm)
        torch.nn.utils.clip_grad_norm_(model.parameters(), auto_clip.get_clip_value())

    ## Update the parameters
    opt.step()

    for lname, l in losses.items():
        writer.add_scalar(f'Loss/{lname}', l, epoch)

    for name, param in model.named_parameters():
        writer.add_histogram(f'Weights/{name}', param, epoch)
        if param.grad is not None:
            writer.add_histogram(f'Gradients/{name}', param.grad, epoch)

    ## Periodic checkpoint, plus one at the final epoch of this run.
    if (epoch % save_interval == 0 and epoch > 0) or epoch == final_epoch:
        save_data = {
            'params_': netp.params_,    # parameters of the stateless net
            'z': z,                     # latent tensor
            'bounds': config['bounds'],
        }
        save_path_pts = os.path.join(log_dir, "pts")
        os.makedirs(save_path_pts, exist_ok=True)
        save_path = os.path.join(save_path_pts, f"nept_{epoch}.pt")
        torch.save(save_data, save_path)

    # Look at debugging plots
    # For this you have to enable plots in the config; note: this will slow down the training
    if mp_manager.are_plots_available_for_epoch(epoch):
        plots_dict_at_last_epoch = mp_manager.pop_plots_dict(epoch)

prev_epoch += config['max_epochs']

  0%|[33m          [0m| 0/10 [00:00<?, ?it/s]INFO:cp_helper:=== Recomputing the graph ===
INFO:cp_helper:(0) SUCCESS - PC Candidates: 3375
INFO:cp_helper:(0) SUCCESS
INFO:cp_helper:lr: 0.0001
Descending to CPs: 6.50e+00: 100%|██████████| 1000/1000 [00:07<00:00, 136.28it/s]
INFO:cp_helper:(0) POST grad norm - PC Candidates: 3590
INFO:cp_helper:(1) SUCCESS
INFO:cp_helper:y_x_mag: torch.Size([3470])tensor([4.3344e+00, 6.0642e+00, 2.4277e+00,  ..., 1.9214e-04, 1.4557e-04,
        9.4593e-05])
INFO:cp_helper:nof points with small enough gradient: 391
INFO:cp_helper:(2) SUCCESS
INFO:cp_helper:(2) POST clustering - PC Clusters: 205
INFO:cp_helper:(3) SUCCESS
INFO:cp_helper:(3) saddle points False
INFO:cp_helper:(4) SUCCESS
Connecting CPs: -1.80e+00: 100%|███████████| 1000/1000 [00:02<00:00, 469.60it/s]
INFO:cp_helper:(5) SUCCESS
INFO:cp_helper:(6) SUCCESS
INFO:cp_helper:(6) SUCCESS 205
INFO:SCCSurfaceNetManager:penalize DCs with 82 points inside envelope and 0 points outside


penalize DCs with 82 points


Flow to surface points: 8.27e-03: 100%|██████████| 200/200 [00:00<00:00, 294.39it/s]


tensor([[-0.1800, -0.0423,  0.2121, -0.1000],
        [-0.0826, -0.0420, -0.0744, -0.1000],
        [ 0.0845, -0.1438,  0.0669, -0.1000],
        ...,
        [-0.1628,  0.2629,  0.2822, -0.1000],
        [-0.0452,  0.0126, -0.0980, -0.1000],
        [-0.0188, -0.1085, -0.1433, -0.1000]], device='cuda:0')
DESCENT LOSS: tensor(3.7645e-06, device='cuda:0', grad_fn=<MulBackward0>)


 10%|[33m█         [0m| 1/10 [00:19<02:53, 19.24s/it]INFO:cp_helper:=== Recomputing the graph ===
INFO:cp_helper:(0) SUCCESS - PC Candidates: 3375
INFO:cp_helper:(0) SUCCESS
INFO:cp_helper:lr: 0.0001
Descending to CPs: 5.52e+00: 100%|██████████| 1000/1000 [00:07<00:00, 133.56it/s]
INFO:cp_helper:(0) POST grad norm - PC Candidates: 3580
INFO:cp_helper:(1) SUCCESS
INFO:cp_helper:y_x_mag: torch.Size([3398])tensor([1.9140e-01, 1.2501e-05, 1.2435e-05,  ..., 4.6501e-05, 1.6437e-04,
        5.4020e-05])
INFO:cp_helper:nof points with small enough gradient: 467
INFO:cp_helper:(2) SUCCESS
INFO:cp_helper:(2) POST clustering - PC Clusters: 225
INFO:cp_helper:(3) SUCCESS
INFO:cp_helper:(3) saddle points False
INFO:cp_helper:(4) SUCCESS
Connecting CPs: -1.63e+00: 100%|███████████| 1000/1000 [00:02<00:00, 479.87it/s]
INFO:cp_helper:(5) SUCCESS
INFO:cp_helper:(6) SUCCESS
INFO:cp_helper:(6) SUCCESS 225
INFO:SCCSurfaceNetManager:penalize DCs with 90 points inside envelope and 0 points outside


penalize DCs with 90 points
tensor([[-0.0938, -0.0421,  0.0369, -0.1000],
        [ 0.2368, -0.0316, -0.1213, -0.1000],
        [-0.1823,  0.3047,  0.2139, -0.1000],
        ...,
        [-0.2867,  0.3141,  0.1608, -0.1000],
        [ 0.0877,  0.0134, -0.2591, -0.1000],
        [-0.2553, -0.1855,  0.1374, -0.1000]], device='cuda:0')
DESCENT LOSS: tensor(-3.1003e-06, device='cuda:0', grad_fn=<MulBackward0>)


 20%|[33m██        [0m| 2/10 [00:37<02:28, 18.62s/it]INFO:cp_helper:=== Recomputing the graph ===
INFO:cp_helper:(0) SUCCESS - PC Candidates: 3375
INFO:cp_helper:(0) SUCCESS
INFO:cp_helper:lr: 0.0001
Descending to CPs: 3.70e+00: 100%|██████████| 1000/1000 [00:07<00:00, 134.94it/s]
INFO:cp_helper:(0) POST grad norm - PC Candidates: 3600
INFO:cp_helper:(1) SUCCESS
INFO:cp_helper:y_x_mag: torch.Size([3469])tensor([5.1704e-02, 6.5391e-06, 9.5375e-06,  ..., 5.2849e-04, 9.7036e-05,
        2.8932e-05])
INFO:cp_helper:nof points with small enough gradient: 535
INFO:cp_helper:(2) SUCCESS
INFO:cp_helper:(2) POST clustering - PC Clusters: 275
INFO:cp_helper:(3) SUCCESS
INFO:cp_helper:(3) saddle points False
INFO:cp_helper:(4) SUCCESS
Connecting CPs: -1.48e+00: 100%|███████████| 1000/1000 [00:02<00:00, 474.26it/s]
INFO:cp_helper:(5) SUCCESS
INFO:cp_helper:(6) SUCCESS
INFO:cp_helper:(6) SUCCESS 275
INFO:SCCSurfaceNetManager:penalize DCs with 103 points inside envelope and 0 points outside


penalize DCs with 103 points
tensor([[ 0.2361,  0.1688, -0.1996, -0.1000],
        [ 0.3660,  0.3747, -0.3154, -0.1000],
        [-0.2645, -0.1085,  0.2321, -0.1000],
        ...,
        [-0.0718,  0.0224, -0.0295, -0.1000],
        [ 0.2727,  0.0396, -0.2763, -0.1000],
        [-0.2807,  0.0867,  0.1668, -0.1000]], device='cuda:0')
DESCENT LOSS: tensor(-1.4646e-06, device='cuda:0', grad_fn=<MulBackward0>)


 30%|[33m███       [0m| 3/10 [00:56<02:11, 18.85s/it]INFO:cp_helper:=== Recomputing the graph ===
INFO:cp_helper:(0) SUCCESS - PC Candidates: 3375
INFO:cp_helper:(0) SUCCESS
INFO:cp_helper:lr: 0.0001
Descending to CPs: 2.06e+00: 100%|██████████| 1000/1000 [00:07<00:00, 135.72it/s]
INFO:cp_helper:(0) POST grad norm - PC Candidates: 3650
INFO:cp_helper:(1) SUCCESS
INFO:cp_helper:y_x_mag: torch.Size([3570])tensor([1.8468e+00, 1.8469e+00, 5.3152e-06,  ..., 9.3532e-06, 2.5149e-05,
        4.3784e-06])
INFO:cp_helper:nof points with small enough gradient: 666
INFO:cp_helper:(2) SUCCESS
INFO:cp_helper:(2) POST clustering - PC Clusters: 342
INFO:cp_helper:(3) SUCCESS
INFO:cp_helper:(3) saddle points False
INFO:cp_helper:(4) SUCCESS
Connecting CPs: -1.15e+00: 100%|███████████| 1000/1000 [00:02<00:00, 450.42it/s]
INFO:cp_helper:(5) SUCCESS
INFO:cp_helper:(6) SUCCESS
INFO:cp_helper:(6) SUCCESS 342
INFO:SCCSurfaceNetManager:penalize DCs with 154 points inside envelope and 0 points outside


penalize DCs with 154 points
tensor([[-0.3990, -0.2115, -0.0862, -0.1000],
        [-0.2964, -0.1472,  0.2021, -0.1000],
        [-0.3197, -0.2629, -0.0316, -0.1000],
        ...,
        [-0.0211, -0.0081, -0.1198, -0.1000],
        [-0.2041, -0.3822,  0.0887, -0.1000],
        [-0.0659, -0.0422, -0.0243, -0.1000]], device='cuda:0')
DESCENT LOSS: tensor(-2.2605e-06, device='cuda:0', grad_fn=<MulBackward0>)


 40%|[33m████      [0m| 4/10 [01:17<01:58, 19.82s/it]INFO:cp_helper:=== Recomputing the graph ===
INFO:cp_helper:(0) SUCCESS - PC Candidates: 3375
INFO:cp_helper:(0) SUCCESS
INFO:cp_helper:lr: 0.0001
Descending to CPs: 1.53e+00: 100%|██████████| 1000/1000 [00:07<00:00, 133.38it/s]
INFO:cp_helper:(0) POST grad norm - PC Candidates: 3717
INFO:cp_helper:(1) SUCCESS
INFO:cp_helper:y_x_mag: torch.Size([3640])tensor([3.0424e+00, 3.0424e+00, 1.0837e+00,  ..., 5.4153e-03, 4.2818e-04,
        1.0761e-04])
INFO:cp_helper:nof points with small enough gradient: 702
INFO:cp_helper:(2) SUCCESS
INFO:cp_helper:(2) POST clustering - PC Clusters: 375
INFO:cp_helper:(3) SUCCESS
INFO:cp_helper:(3) saddle points False
INFO:cp_helper:(4) SUCCESS
Connecting CPs: -1.02e+00: 100%|███████████| 1000/1000 [00:02<00:00, 467.41it/s]
INFO:cp_helper:(5) SUCCESS
INFO:cp_helper:(6) SUCCESS
INFO:cp_helper:(6) SUCCESS 375
INFO:SCCSurfaceNetManager:penalize DCs with 167 points inside envelope and 0 points outside


penalize DCs with 167 points
tensor([[-0.2692, -0.3633,  0.1414, -0.1000],
        [ 0.2995,  0.2584, -0.2786, -0.1000],
        [ 0.3725,  0.0862, -0.3183, -0.1000],
        ...,
        [-0.1258, -0.3076,  0.1249, -0.1000],
        [ 0.0359,  0.0200,  0.0298, -0.1000],
        [ 0.1977,  0.1351,  0.0850, -0.1000]], device='cuda:0')
DESCENT LOSS: tensor(-1.1562e-05, device='cuda:0', grad_fn=<MulBackward0>)


 50%|[33m█████     [0m| 5/10 [01:40<01:44, 20.85s/it]INFO:cp_helper:=== Recomputing the graph ===
INFO:cp_helper:(0) SUCCESS - PC Candidates: 3375
INFO:cp_helper:(0) SUCCESS
INFO:cp_helper:lr: 0.0001
Descending to CPs: 1.46e+00: 100%|██████████| 1000/1000 [00:07<00:00, 132.63it/s]
INFO:cp_helper:(0) POST grad norm - PC Candidates: 3750
INFO:cp_helper:(1) SUCCESS
INFO:cp_helper:y_x_mag: torch.Size([3699])tensor([2.1269e+00, 2.1269e+00, 1.7167e-01,  ..., 3.3685e-04, 2.4223e+00,
        2.2806e-04])
INFO:cp_helper:nof points with small enough gradient: 803
INFO:cp_helper:(2) SUCCESS
INFO:cp_helper:(2) POST clustering - PC Clusters: 411
INFO:cp_helper:(3) SUCCESS
INFO:cp_helper:(3) saddle points False
INFO:cp_helper:(4) SUCCESS
Connecting CPs: -9.39e-01: 100%|███████████| 1000/1000 [00:02<00:00, 458.23it/s]
INFO:cp_helper:(5) SUCCESS
INFO:cp_helper:(6) SUCCESS
INFO:cp_helper:(6) SUCCESS 411
INFO:SCCSurfaceNetManager:penalize DCs with 192 points inside envelope and 0 points outside


penalize DCs with 192 points
tensor([[-0.2550,  0.2280,  0.2682, -0.1000],
        [-0.1763,  0.3864,  0.2870, -0.1000],
        [ 0.1412,  0.0101,  0.0045, -0.1000],
        ...,
        [ 0.1724, -0.0108, -0.1729, -0.1000],
        [-0.2098,  0.2057,  0.3580, -0.1000],
        [-0.3713, -0.2794, -0.0097, -0.1000]], device='cuda:0')
DESCENT LOSS: tensor(-1.4255e-05, device='cuda:0', grad_fn=<MulBackward0>)


 60%|[33m██████    [0m| 6/10 [02:06<01:30, 22.53s/it]INFO:cp_helper:=== Recomputing the graph ===
INFO:cp_helper:(0) SUCCESS - PC Candidates: 3375
INFO:cp_helper:(0) SUCCESS
INFO:cp_helper:lr: 0.0001
Descending to CPs: 1.32e+00: 100%|██████████| 1000/1000 [00:07<00:00, 136.30it/s]
INFO:cp_helper:(0) POST grad norm - PC Candidates: 3786
INFO:cp_helper:(1) SUCCESS
INFO:cp_helper:y_x_mag: torch.Size([3721])tensor([6.5794e-01, 6.5793e-01, 8.8313e-06,  ..., 9.3932e-01, 7.4849e-04,
        1.2527e-05])
INFO:cp_helper:nof points with small enough gradient: 779
INFO:cp_helper:(2) SUCCESS
INFO:cp_helper:(2) POST clustering - PC Clusters: 407
INFO:cp_helper:(3) SUCCESS
INFO:cp_helper:(3) saddle points False
INFO:cp_helper:(4) SUCCESS
Connecting CPs: -9.24e-01: 100%|███████████| 1000/1000 [00:02<00:00, 472.76it/s]
INFO:cp_helper:(5) SUCCESS
INFO:cp_helper:(6) SUCCESS
INFO:cp_helper:(6) SUCCESS 407
INFO:SCCSurfaceNetManager:penalize DCs with 192 points inside envelope and 0 points outside


penalize DCs with 192 points
tensor([[-0.3214, -0.0917,  0.1207, -0.1000],
        [ 0.0851, -0.1335, -0.1298, -0.1000],
        [ 0.1715, -0.0065, -0.3455, -0.1000],
        ...,
        [ 0.0274, -0.1327, -0.1144, -0.1000],
        [-0.3972, -0.2434, -0.1794, -0.1000],
        [ 0.1967, -0.0247, -0.0411, -0.1000]], device='cuda:0')
DESCENT LOSS: tensor(-1.0773e-05, device='cuda:0', grad_fn=<MulBackward0>)


 70%|[33m███████   [0m| 7/10 [02:31<01:10, 23.36s/it]INFO:cp_helper:=== Recomputing the graph ===
INFO:cp_helper:(0) SUCCESS - PC Candidates: 3375
INFO:cp_helper:(0) SUCCESS
INFO:cp_helper:lr: 0.0001
Descending to CPs: 9.56e-01: 100%|██████████| 1000/1000 [00:07<00:00, 135.76it/s]
INFO:cp_helper:(0) POST grad norm - PC Candidates: 3782
INFO:cp_helper:(1) SUCCESS
INFO:cp_helper:y_x_mag: torch.Size([3694])tensor([4.8761e-03, 5.7010e-06, 8.3396e-06,  ..., 3.0956e-05, 1.0771e+00,
        1.1418e-01])
INFO:cp_helper:nof points with small enough gradient: 781
INFO:cp_helper:(2) SUCCESS
INFO:cp_helper:(2) POST clustering - PC Clusters: 435
INFO:cp_helper:(3) SUCCESS
INFO:cp_helper:(3) saddle points False
INFO:cp_helper:(4) SUCCESS
Connecting CPs: -8.61e-01: 100%|███████████| 1000/1000 [00:02<00:00, 447.19it/s]
INFO:cp_helper:(5) SUCCESS
INFO:cp_helper:(6) SUCCESS
INFO:cp_helper:(6) SUCCESS 435
INFO:SCCSurfaceNetManager:penalize DCs with 211 points inside envelope and 0 points outside


penalize DCs with 211 points
tensor([[ 0.2018, -0.1377, -0.2396, -0.1000],
        [-0.2176,  0.3402,  0.2027, -0.1000],
        [ 0.0308, -0.0800, -0.2235, -0.1000],
        ...,
        [-0.1928, -0.2117,  0.2463, -0.1000],
        [ 0.0824, -0.0631, -0.2243, -0.1000],
        [ 0.0997,  0.4116,  0.1844, -0.1000]], device='cuda:0')
DESCENT LOSS: tensor(-1.2046e-05, device='cuda:0', grad_fn=<MulBackward0>)


 80%|[33m████████  [0m| 8/10 [03:00<00:50, 25.04s/it]INFO:cp_helper:=== Recomputing the graph ===
INFO:cp_helper:(0) SUCCESS - PC Candidates: 3375
INFO:cp_helper:(0) SUCCESS
INFO:cp_helper:lr: 0.0001
Descending to CPs: 8.44e-01: 100%|██████████| 1000/1000 [00:07<00:00, 136.79it/s]
INFO:cp_helper:(0) POST grad norm - PC Candidates: 3810
INFO:cp_helper:(1) SUCCESS
INFO:cp_helper:y_x_mag: torch.Size([3742])tensor([8.3005e-02, 3.0590e-05, 1.1101e-05,  ..., 2.4386e-05, 2.0343e-03,
        2.7160e-05])
INFO:cp_helper:nof points with small enough gradient: 891
INFO:cp_helper:(2) SUCCESS
INFO:cp_helper:(2) POST clustering - PC Clusters: 486
INFO:cp_helper:(3) SUCCESS
INFO:cp_helper:(3) saddle points False
INFO:cp_helper:(4) SUCCESS
Connecting CPs: -7.97e-01: 100%|███████████| 1000/1000 [00:02<00:00, 469.05it/s]
INFO:cp_helper:(5) SUCCESS
INFO:cp_helper:(6) SUCCESS
INFO:cp_helper:(6) SUCCESS 486
INFO:SCCSurfaceNetManager:penalize DCs with 272 points inside envelope and 0 points outside


penalize DCs with 272 points
tensor([[ 0.0772,  0.0005, -0.2746, -0.1000],
        [ 0.1788, -0.0697, -0.1461, -0.1000],
        [ 0.0397,  0.3766,  0.1680, -0.1000],
        ...,
        [-0.3854, -0.3296, -0.0759, -0.1000],
        [-0.3865, -0.1221,  0.1741, -0.1000],
        [ 0.3757,  0.0193, -0.3220, -0.1000]], device='cuda:0')
DESCENT LOSS: tensor(-7.6556e-06, device='cuda:0', grad_fn=<MulBackward0>)


 90%|[33m█████████ [0m| 9/10 [03:37<00:28, 28.88s/it]INFO:cp_helper:=== Recomputing the graph ===
INFO:cp_helper:(0) SUCCESS - PC Candidates: 3375
INFO:cp_helper:(0) SUCCESS
INFO:cp_helper:lr: 0.0001
Descending to CPs: 7.94e-01: 100%|██████████| 1000/1000 [00:07<00:00, 134.74it/s]
INFO:cp_helper:(0) POST grad norm - PC Candidates: 3861
INFO:cp_helper:(1) SUCCESS
INFO:cp_helper:y_x_mag: torch.Size([3801])tensor([6.2503e-06, 1.0059e-05, 2.0162e-05,  ..., 1.1509e-04, 1.7866e-04,
        1.5283e+00])
INFO:cp_helper:nof points with small enough gradient: 879
INFO:cp_helper:(2) SUCCESS
INFO:cp_helper:(2) POST clustering - PC Clusters: 476
INFO:cp_helper:(3) SUCCESS
INFO:cp_helper:(3) saddle points False
INFO:cp_helper:(4) SUCCESS
Connecting CPs: -7.44e-01: 100%|███████████| 1000/1000 [00:02<00:00, 469.11it/s]
INFO:cp_helper:(5) SUCCESS
INFO:cp_helper:(6) SUCCESS
INFO:cp_helper:(6) SUCCESS 476
INFO:SCCSurfaceNetManager:penalize DCs with 243 points inside envelope and 0 points outside


penalize DCs with 243 points
tensor([[-0.1407, -0.3553,  0.1402, -0.1000],
        [ 0.0747, -0.0052, -0.2174, -0.1000],
        [-0.0471, -0.1261, -0.0977, -0.1000],
        ...,
        [-0.3099, -0.0811, -0.0364, -0.1000],
        [ 0.1347, -0.0086, -0.1080, -0.1000],
        [ 0.1088,  0.0757, -0.2055, -0.1000]], device='cuda:0')
DESCENT LOSS: tensor(-1.0817e-05, device='cuda:0', grad_fn=<MulBackward0>)


100%|[33m██████████[0m| 10/10 [04:09<00:00, 24.95s/it]


In [None]:
## Extract the trained shape for latent z = -0.1 at a finer marching-cubes resolution
## (this overwrites the `z` used during training).
## NOTE: at mc_resolution=128 this ran out of CUDA memory on an 8 GB GPU after training.
## TODO: find out which tensors from the training loop still hold GPU memory and release
## them (or lower the resolution). TODO: smaller update?
z = torch.tensor([-0.1])
mesh_update = get_mesh_for_latent(
    netp.f_,
    netp.params_,
    z,
    config['bounds'],
    mc_resolution=128,
    device=device,
    chunks=1,
)

fig = k3d.plot()
# Uncomment to overlay the pre-training mesh for comparison:
# fig += k3d.mesh(*mesh_checkpoint, color=0xff0000, side='double', opacity=0.5, name='Original shape')
fig += k3d.mesh(*mesh_update, color=0x00ff00, side='double', opacity=0.5, name='Updated shape')
fig.display()
# Same fixed viewpoint as the pre-training plot (k3d camera: position, target, up vector).
fig.camera_auto_fit = False
fig.camera = [
    0.8042741481976844, -1.040350835893895, 0.7038650223301532,
    0.08252720725551285, -0.08146462547370059, -0.1973267630672968,
    -0.3986658507677483, 0.39231188503442904, 0.8289492893370278,
]