# Helix4D Inference on HelixNet

Inference on HelixNet with the default Helix4D configuration. The model was trained on full rotations of the sensor, and inference is run on 1/5 of a turn. For visualization, all point clouds are downsampled on a regular grid with a 25 cm cell size.

In [None]:
!HYDRA_FULL_ERROR=1

%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")
from plotly.offline import init_notebook_mode
init_notebook_mode(connected = True)

import sys, os
sys.path.append("../")

import hydra
import pytorch_lightning as pl
import logging

pl.utilities.distributed.log.setLevel(logging.ERROR)

hydra.initialize(config_path="../configs")
cfg = hydra.compose(config_name="defaults.yaml", overrides=["hydra.searchpath=[file://../HelixNet/configs]", "+experiment=viz/ours_helixnet"])

cfg.data.data_dir = os.path.join("../", cfg.data.data_dir)
cfg.model.load_weights = os.path.join("../", cfg.model.load_weights)
if "helixnet" in cfg.data._target_:
    cfg.data._target_ = f"HelixNet.{cfg.data._target_}"

pl.seed_everything(cfg.seed)

import numpy as np
import torch
from torch_geometric.data import Batch

In [None]:
# Name of the split to run inference on; used below to pick `<tag>_dataset`.
tag = "test"

# Build the data module described by `cfg.data`, then run its setup so the
# split datasets are available as attributes.
datamodule = hydra.utils.instantiate(cfg.data)
datamodule.setup()

In [None]:
# Build the model from the Hydra config and load the pretrained weights.
# Choose the device once: the previous code called `.cuda()` unconditionally and
# hard-coded `map_location="cuda:0"`, which crashed on CPU-only machines and
# turned the later `torch.cuda.is_available()` check into dead code.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = hydra.utils.instantiate(
        cfg.model,
        _recursive_=False,
    )

# The checkpoint file is a dict whose "state_dict" entry holds the model weights.
checkpoint = torch.load(cfg.model.load_weights, map_location=device)
model.load_state_dict(checkpoint["state_dict"])
model = model.to(device)

# Inference only: switch off training-mode behaviour of any dropout /
# batch-norm layers the model may contain.
model.eval()

In [None]:
# Run the model on randomly drawn samples of the `tag` split and keep the
# predictions (on CPU) for visualization.
N_SAMPLES = 2  # number of random samples to draw; must be >= 1 (see `del item` below)

dataset = getattr(datamodule, f"{tag}_dataset")
len_dataset = len(dataset)

items = []
with torch.no_grad():
    for _ in range(N_SAMPLES):
        # Uniform draw with replacement, so the same sample may appear twice.
        item = dataset[np.random.randint(len_dataset)]
        # Wrap the single sample into a batch of one and keep the first output.
        # NOTE(review): meaning of the positional args `1, 0` — confirm in the model's forward.
        out = model(Batch.from_data_list([item]).to(model.device), 1, 0)[0].detach().cpu()
        # Predicted label = argmax over the last dim (presumably per-point class scores).
        item.point_pred_y = out.argmax(-1)

        items.append(item.detach().cpu())
        torch.cuda.empty_cache()

# Release the GPU memory held by the model before visualization.
model = model.cpu()
del item, model
torch.cuda.empty_cache()

In [None]:
# ";"-separated attributes for `show_3d` to render; "pred_y" presumably refers to
# the predictions stored on each item above and "y" to the labels — confirm.
viz_layers = "pred_y;y;slice;voxel;time"
for item in items:
    datamodule.show_3d(item, viz_layers)