# Educational notebook: loading and inference on SemanticKITTI
## Loading config files

In [None]:
!HYDRA_FULL_ERROR=1
%load_ext autoreload
%autoreload 2

import sys, os
sys.path.append("../")

import warnings
warnings.filterwarnings("ignore")

import hydra
import pytorch_lightning as pl
import numpy as np
import torch
from torch_geometric.data import Batch

#Loading config and making small changes for visualisation
hydra.initialize(config_path="../configs")
cfg = hydra.compose(config_name="defaults.yaml", overrides=["hydra.searchpath=[file://../HelixNet/configs]", "+data=semantic-kitti", "+experiment=viz/ours_sk"])

cfg.data.data_dir = os.path.join("../", cfg.data.data_dir)
cfg.model.load_weights = os.path.join("../", cfg.model.load_weights)
if "helixnet" in cfg.data._target_:
    cfg.data._target_ = f"HelixNet.{cfg.data._target_}"


pl.seed_everything(cfg.seed)

import helix4d

## Loading the dataset

In [None]:
# Build the datamodule described by the `data` section of the composed config,
# then prepare it for the "fit" stage.
# NOTE(review): "fit" is presumably what creates `val_dataset`, which the inference cell reads — confirm.
datamodule = hydra.utils.instantiate(config=cfg.data)
datamodule.setup("fit")

## Loading the model

In [None]:
# Build the model from the `model` config section (non-recursive, so nested configs are
# left for the model itself to instantiate) and load the pretrained weights.
model = hydra.utils.instantiate(
        cfg.model,
        _recursive_=False,
    )

# Map the checkpoint to CPU first. The previous hard-coded "cuda:0" crashed on
# machines without a GPU, even though the GPU is only used below when available.
checkpoint = torch.load(cfg.model.load_weights, map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])

if torch.cuda.is_available():
    model = model.cuda()

# A freshly constructed nn.Module is in training mode. Switch to evaluation mode
# before inference so layers such as dropout / batch-norm (if the model has any)
# behave deterministically.
model.eval()

## Inference on random samples

In [None]:
# Run the model on three randomly drawn validation scans, visualise them, and print
# the per-scan accuracy. Points whose ground-truth label is 0 are excluded from the accuracy.
# NOTE(review): label 0 is presumably the "unlabelled"/ignore class — confirm.
for _ in range(3):
    n_val = len(datamodule.val_dataset)
    item = datamodule.val_dataset[np.random.randint(n_val)]

    with torch.no_grad():
        batch = Batch.from_data_list([item]).to(model.device)
        # NOTE(review): the positional arguments (1, 0) are interpreted by the model's
        # forward(); element 0 of its output is reduced with argmax over the last dim.
        scores = model(batch, 1, 0)[0]
        item.point_pred_y = scores.argmax(-1).detach().cpu()

    datamodule.show_2d(item, color="y;pred_y;voxel")

    is_labelled = item.point_y != 0
    hits = (item.point_y == item.point_pred_y)[is_labelled]
    accuracy = 100 * hits.float().mean()
    print(f"Seq {item.seqid} frame {item.scanid}\t Accuracy = {accuracy:.2f}\n\n")