In [None]:
import os
# Step one directory up from the notebook's location; presumably so that the
# relative "./scenarios" and "./configs" paths used below resolve from the
# project root — confirm.
# NOTE(review): the path is relative, so re-running this cell in the same
# kernel moves the working directory up one more level each time.
os.chdir('../')


from copy import deepcopy
from pathlib import Path
from types import SimpleNamespace
import pickle as pkl

import numpy as np
import matplotlib.pyplot as plt
import torch
import trimesh
from tqdm.notebook import tqdm
from scipy.spatial.transform import Rotation
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from popup.models.baseline_nn import create_nn_model, create_and_query_nn_model
from popup.utils.exp import init_experiment
from popup.core.evaluator import Evaluator


%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Arguments handed to `init_experiment`, collected in one place before being
# wrapped in a namespace. Checkpoint, workers, batch size and learning rate
# are left as None, and `no_wandb` is set.
experiment_args = dict(
    scenario=Path("./scenarios/gb_nn_baseline.toml"),
    exp_name="nn_baseline",
    project_config=Path("./configs/smplh.toml"),
    experiment_prefix="nn_baseline",
    resume_checkpoint=None,
    workers=None,
    batch_size=None,
    lr=None,
    no_wandb=True,
)
arguments = SimpleNamespace(**experiment_args)

config = init_experiment(arguments, True)
config.eval_temporal = False

# Keep an independent copy of the experiment folder: the evaluation cells
# further down overwrite `config.exp_folder` with per-model sub-folders.
exp_folder = deepcopy(config.exp_folder)

In [None]:
# Object-name -> class-id mapping taken from the config, and its inverse.
objname2classid = config.objname2classid
classid2objname = dict(
    (class_id, object_name) for object_name, class_id in objname2classid.items()
)

### NN Classifier

In [None]:
# Build the nearest-neighbour model for the "classifier" task and unpack the
# index, the training labels and the per-dataset test queries/labels/timestamps.
classifier_nn = create_nn_model(
    config,
    "classifier",
    human_features="verts",
    backend="faiss_gpu",
)
kdtree, labels, test_queries, test_labels, test_t_stamps = classifier_nn

In [None]:
# Number of nearest neighbours retrieved per test query; accuracy is reported
# for every K in 1..NEIGHBORS.
NEIGHBORS = 1

# dataset name -> predicted class id per test sample (only K=1 is stored).
dataset_to_pred_labels = {}
for dataset in config.datasets:
    _, pred_neighbors = kdtree.query(test_queries[dataset], k=NEIGHBORS)

    for K in range(1, NEIGHBORS + 1):
        # Training labels of the K closest neighbours, one row per query.
        _pred_labels = labels[pred_neighbors[:, :K]]
        if K > 1:
            # Majority vote over the K neighbour labels. int64 is used so that
            # class ids above 255 are not wrapped (the previous uint8 buffer
            # silently overflowed).
            pred_labels = np.array(
                [np.bincount(neighbor_labels).argmax() for neighbor_labels in _pred_labels],
                dtype=np.int64,
            )
        else:
            pred_labels = _pred_labels.reshape(-1)

        if K == 1:
            dataset_to_pred_labels[dataset] = pred_labels

        # `.2f` = two decimals; the former `02f` was a zero-pad width and
        # printed the default six decimals.
        print(f"{dataset} K={K} | ACC: {100 * np.mean(test_labels[dataset] == pred_labels):.2f}")
    
# PER CLASS PREDICTIONS
# Accuracy per object class, pooled over all datasets (uses the stored K=1
# predictions). Iterate over the actual class ids rather than
# range(len(...)), which only works when ids are contiguous and start at 0.
for class_id in sorted(classid2objname):
    all_pred, all_gt = [], []
    for dataset in config.datasets:
        # Select the test samples whose ground-truth label is this class.
        mask = test_labels[dataset] == class_id

        all_pred.append(dataset_to_pred_labels[dataset][mask])
        all_gt.append(test_labels[dataset][mask])

    all_pred = np.concatenate(all_pred, axis=0)
    all_gt = np.concatenate(all_gt, axis=0)

    if len(all_gt) == 0:
        # No test sample of this class: report it explicitly instead of
        # printing `nan` (np.mean of an empty array) with a RuntimeWarning.
        print(f"{classid2objname[class_id]:20s} ACC n/a (no test samples)")
        continue

    # `.2f` = two decimals; the former `02f` printed six.
    print(f"{classid2objname[class_id]:20s} ACC {100 * np.mean(all_pred == all_gt):.2f}")

# CONFUSION MATRIX
# Note: plt.rc changes the font for every figure drawn afterwards in this kernel.
font = {'weight': 'bold', 'size': 15}
plt.rc('font', **font)

# Pool predictions and ground truth over all datasets.
all_pred = np.concatenate([dataset_to_pred_labels[dataset] for dataset in config.datasets], axis=0)
all_gt = np.concatenate([test_labels[dataset] for dataset in config.datasets], axis=0)

# Fix the row/column order to the full, sorted class-id list. Without an
# explicit `labels=` the matrix only contains the labels that occur in the
# data, so `display_labels` (one entry per known class) could be misaligned
# or of the wrong length whenever a class is absent from the test set.
class_ids = sorted(classid2objname.keys())
axes_labels = [classid2objname[class_id] for class_id in class_ids]

cmtrx = confusion_matrix(all_gt, all_pred, labels=class_ids)
fig, ax = plt.subplots(figsize=(30, 30))
# Plot the matrix computed above instead of recomputing it from predictions.
ConfusionMatrixDisplay(confusion_matrix=cmtrx, display_labels=axes_labels).plot(
    xticks_rotation="vertical", ax=ax
)

### NN pose general

In [None]:
# Build the nearest-neighbour model for the "pose_general" task; this rebinds
# the same names used by the classifier section above.
pose_general_nn = create_nn_model(
    config,
    "pose_general",
    human_features="verts",
    backend="faiss_gpu",
)
kdtree, labels, test_queries, test_labels, test_t_stamps = pose_general_nn

In [None]:
# class id -> path (as str) of that object's mesh under "<dataset>/object_meshes".
canonical_meshes_path = dict()
for dataset in config.datasets:
    if dataset == "grab":
        dataset_path = config.grab_path
    elif dataset == "behave":
        dataset_path = config.behave_path
    else:
        # Previously an unknown name left `dataset_path` undefined (NameError)
        # or, worse, silently reused the path of the preceding dataset.
        raise ValueError(f"Unknown dataset: {dataset!r}")

    # Objects available in this dataset = stems of its keypoint files.
    # A set gives O(1) membership tests in the loop below.
    dataset_objects = {
        keypoints_file.stem
        for keypoints_file in (dataset_path / "object_keypoints").glob("*.npz")
    }

    for class_id, object_name in classid2objname.items():
        if object_name in dataset_objects:
            canonical_meshes_path[class_id] = str(dataset_path / "object_meshes" / f"{object_name}.ply")

In [None]:
# Number of nearest neighbours retrieved per test query; a posed mesh is
# exported for every test sample and every K in 1..NEIGHBORS.
NEIGHBORS = 1

# dataset name -> {K: array of predicted class ids, one per test sample}.
dataset_to_pred_labels = {}
for dataset in config.datasets:
    if dataset == "grab":
        dataset_path = config.grab_path
    elif dataset == "behave":
        dataset_path = config.behave_path
    else:
        # Previously an unknown name left `dataset_path` undefined or stale.
        raise ValueError(f"Unknown dataset: {dataset!r}")

    _, pred_neighbors = kdtree.query(test_queries[dataset], k=NEIGHBORS)
    dataset_to_pred_labels[dataset] = {}
    for K in range(1, NEIGHBORS + 1):
        target_dir = exp_folder / "pose_general" / f"{K}/visualization/0/{dataset}"
        _pred_labels = labels[pred_neighbors[:, :K]]

        pred_classes = []
        for sample_id in tqdm(range(len(_pred_labels))):
            t_stamp_str = str(test_t_stamps[dataset][sample_id])
            sample_t_stamp = dataset_path / t_stamp_str

            # One row per neighbour: column 0 is the class id, columns 1..9
            # the flattened 3x3 pca_axes, the remaining columns the center
            # location.
            _pred_label = _pred_labels[sample_id]

            # int64 instead of int8: int8 wraps for class ids above 127 and
            # np.bincount rejects the resulting negative values.
            neighbor_classes = _pred_label[:, 0].astype(np.int64)
            if K > 1:
                # Majority vote over the K neighbours.
                pred_class = int(np.bincount(neighbor_classes).argmax())
            else:
                pred_class = int(neighbor_classes[0])
            pred_classes.append(pred_class)

            # Pose = mean over the K neighbours.
            # NOTE(review): for K > 1 the averaged 3x3 is generally not a
            # rotation matrix; Rotation.from_matrix then substitutes an
            # orthogonal approximation (see SciPy docs), and neighbours of
            # different classes are averaged together — confirm intended.
            pred_rot = _pred_label[:, 1:10].mean(axis=0).reshape(3, 3)
            pred_center = _pred_label[:, 10:].mean(axis=0)

            # load mesh
            predicted_mesh = trimesh.load(canonical_meshes_path[pred_class], process=False)

            # load preprocessing params (only the scale is used here)
            # NOTE: pickle.load can execute arbitrary code; only read
            # preprocessing files that come from a trusted dataset.
            with (sample_t_stamp / "preprocess_transform.pkl").open("rb") as fp:
                preprocess_transform = pkl.load(fp)
            scale = preprocess_transform["scale"]

            # construct rotation
            R = Rotation.from_matrix(pred_rot)

            # save the resulting mesh: scale, rotate, then translate the
            # canonical vertices and write them under the experiment folder.
            sbj, obj_act, t_stamp = t_stamp_str.split("/")
            posed_mesh_path = target_dir / sbj / obj_act / "posed_mesh" / f"{t_stamp}.obj"
            posed_mesh_path.parent.mkdir(parents=True, exist_ok=True)
            predicted_mesh.vertices = R.apply(scale * predicted_mesh.vertices) + pred_center
            _ = predicted_mesh.export(str(posed_mesh_path))
        dataset_to_pred_labels[dataset][K] = np.array(pred_classes)

In [None]:
# Fill the "gen_*" entries of the dataset configs from the corresponding
# "val_*" entries (GRAB subjects are hard-coded) before running the evaluator.
# NOTE(review): presumably the Evaluator reads the gen_* keys to locate the
# exported meshes — confirm against popup.core.evaluator.
# _exp_folder = config.exp_folder
config.grab["gen_subjects"] = ["s9", "s10"]  # hard-coded GRAB subject ids
config.grab["gen_objects"] = config.grab["val_objects"]
config.grab["gen_actions"] = config.grab["val_actions"]
config.behave["gen_objects"] = config.behave["val_objects"]
config.behave["gen_split_file"] = config.behave["val_split_file"]
config.undo_preprocessing_eval = True

# Classification accuracy of the pose_general model per dataset and K.
# The ground-truth class id is column 0 of each test label row.
for dataset in config.datasets:
    gt_classes = test_labels[dataset][:, 0].astype(np.int32)
    for K in range(1, NEIGHBORS + 1):
        pred = dataset_to_pred_labels[dataset][K]

        # `.2f` = two decimals; the former `02f` was a zero-pad width and
        # printed the default six decimals.
        print(f"{dataset} K={K} | ACC: {100 * np.mean(gt_classes == pred):.2f}")

# Run the evaluator once per K, pointing `config.exp_folder` at the folder
# that holds the pose_general outputs for that K.
separator = "=" * 40
for K in range(1, NEIGHBORS + 1):
    print(separator, f"K={K}", separator, sep="\n")

    config.exp_folder = exp_folder / "pose_general" / f"{K}"
    evaluator = Evaluator(torch.device("cuda:0"), config)
    evaluator.evaluate()

### NN pose class-specific

In [None]:
# Build and query the class-specific pose models (3 neighbours per query).
# This rebinds `test_labels` and `test_t_stamps` from the previous section.
class_specific_nn = create_and_query_nn_model(
    config,
    "pose_class_specific",
    human_features="verts",
    n_neighbors=3,
    backend="faiss_gpu",
)
pred_neighbors, train_labels, test_labels, test_t_stamps = class_specific_nn

In [None]:
# class id -> path (as str) of that object's mesh under "<dataset>/object_meshes".
canonical_meshes_path = dict()
for dataset in config.datasets:
    if dataset == "grab":
        dataset_path = config.grab_path
    elif dataset == "behave":
        dataset_path = config.behave_path
    else:
        # Previously an unknown name left `dataset_path` undefined (NameError)
        # or, worse, silently reused the path of the preceding dataset.
        raise ValueError(f"Unknown dataset: {dataset!r}")

    # Objects available in this dataset = stems of its keypoint files.
    # A set gives O(1) membership tests in the loop below.
    dataset_objects = {
        keypoints_file.stem
        for keypoints_file in (dataset_path / "object_keypoints").glob("*.npz")
    }

    for class_id, object_name in classid2objname.items():
        if object_name in dataset_objects:
            canonical_meshes_path[class_id] = str(dataset_path / "object_meshes" / f"{object_name}.ply")

In [None]:
# Number of neighbours averaged per prediction; a posed mesh is exported for
# every test sample and every K in 1..NEIGHBORS.
# NOTE(review): K is applied by slicing `pred_neighbors`, which was queried
# with n_neighbors=3, so NEIGHBORS > 3 would silently behave like 3 — confirm.
# (`pkl` comes from the notebook's first import cell; the duplicate import
# that used to sit here was removed.)
NEIGHBORS = 1

for K in range(1, NEIGHBORS + 1):
    for class_id in tqdm(classid2objname.keys()):
        for dataset in config.datasets:
            if dataset == "grab":
                dataset_path = config.grab_path
            elif dataset == "behave":
                dataset_path = config.behave_path
            else:
                # Previously an unknown name left `dataset_path` undefined or stale.
                raise ValueError(f"Unknown dataset: {dataset!r}")
            target_dir = exp_folder / "pose_class_specific" / f"{K}/visualization/0/{dataset}"

            # Skip datasets that have no predictions for this class.
            if dataset not in pred_neighbors[class_id]:
                continue

            # Training labels of the K nearest neighbours, one row per test sample.
            _pred_labels = train_labels[class_id][pred_neighbors[class_id][dataset][:, :K]]
            _test_t_stamps = test_t_stamps[dataset][class_id]

            for sample_id in tqdm(range(len(_pred_labels)), leave=False):
                # Mean over the K neighbours.
                # NOTE(review): for K > 1 the averaged 3x3 is generally not a
                # rotation matrix; Rotation.from_matrix then substitutes an
                # orthogonal approximation (see SciPy docs) — confirm intended.
                _pred_label = np.mean(_pred_labels[sample_id], axis=0)

                # Convert the timestamp to str once and reuse it (the original
                # mixed str(...) and raw access to the same value).
                t_stamp_str = str(_test_t_stamps[sample_id])
                sample_t_stamp = dataset_path / t_stamp_str
                sbj, obj_act, t_stamp = t_stamp_str.split("/")

                # prediction is flattened 3x3 pca_axes and center location
                pred_rot = _pred_label[:9].reshape(3, 3)
                pred_center = _pred_label[9:]

                # load the canonical mesh of the sample's object; the object
                # name is the part of "<object>_<...>" before the first "_".
                class_name = obj_act.split("_")[0]
                predicted_mesh = trimesh.load(canonical_meshes_path[objname2classid[class_name]], process=False)

                # load preprocessing params (only the scale is used here)
                # NOTE: pickle.load can execute arbitrary code; only read
                # preprocessing files that come from a trusted dataset.
                with (sample_t_stamp / "preprocess_transform.pkl").open("rb") as fp:
                    preprocess_transform = pkl.load(fp)
                scale = preprocess_transform["scale"]

                # construct rotation
                R = Rotation.from_matrix(pred_rot)

                # save the resulting mesh: scale, rotate, then translate the
                # canonical vertices and write them under the experiment folder.
                posed_mesh_path = target_dir / sbj / obj_act / "posed_mesh" / f"{t_stamp}.obj"
                posed_mesh_path.parent.mkdir(parents=True, exist_ok=True)
                predicted_mesh.vertices = R.apply(scale * predicted_mesh.vertices) + pred_center
                _ = predicted_mesh.export(str(posed_mesh_path))

In [None]:
# Fill the "gen_*" entries of the dataset configs from the corresponding
# "val_*" entries (GRAB subjects are hard-coded) before running the evaluator.
# These are the same assignments as in the pose_general evaluation cell.
# NOTE(review): presumably the Evaluator reads the gen_* keys to locate the
# exported meshes — confirm against popup.core.evaluator.
# _exp_folder = config.exp_folder
config.grab["gen_subjects"] = ["s9", "s10"]  # hard-coded GRAB subject ids
config.grab["gen_objects"] = config.grab["val_objects"]
config.grab["gen_actions"] = config.grab["val_actions"]
config.behave["gen_objects"] = config.behave["val_objects"]
config.behave["gen_split_file"] = config.behave["val_split_file"]
config.undo_preprocessing_eval = True

# Run the evaluator once per K, pointing `config.exp_folder` at the folder
# that holds the pose_class_specific outputs for that K.
separator = "=" * 40
for K in range(1, NEIGHBORS + 1):
    print(separator, f"K={K}", separator, sep="\n")

    config.exp_folder = exp_folder / "pose_class_specific" / f"{K}"
    evaluator = Evaluator(torch.device("cuda:0"), config)
    evaluator.evaluate()