In [None]:
import sys
import os


def import_modules(libpath):
    """Make the grandparent directory of ``libpath`` importable.

    The directory part of ``libpath`` is taken, one level above it is
    resolved to a normalised absolute path, that path is printed, and it is
    appended to ``sys.path`` unless it is already listed there.

    Args:
        libpath: Path (relative or absolute) whose ``dirname``'s parent
            directory should be added to the module search path.

    Returns:
        None. The only effects are the printed message and the possible
        ``sys.path`` append.
    """
    containing_dir = os.path.dirname(libpath)
    one_level_up = os.path.join(containing_dir, os.path.pardir)
    path2add = os.path.normpath(os.path.abspath(one_level_up))
    print(f'Adding path: {path2add}')
    # Already on the search path: nothing more to do.
    if path2add in sys.path:
        return
    sys.path.append(path2add)

In [None]:
import numpy as np
import torch
from torch import nn
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names

import_modules('../src/models/components')
from datamodules.components.frame_dataset import H2OFrameDataset
from models.components.unified_fcn import UnifiedFCNModule

In [None]:
# Smoke-test the dataset implementation: build it from the H2O data root and
# the pose-train split file, fetch one sample and list the keys it exposes.
h2o_root = '../data/h2o/'
pose_train_split = '../data/h2o/label_split/pose_train.txt'
dataset = H2OFrameDataset(h2o_root, pose_train_split)
data = dataset[2]
print(data.keys())

In [None]:
# Utility functions to create grid data

def conf_func(dist, alpha, dth):
    """Average a thresholded exponential confidence over points.

    The Euclidean norm of ``dist`` is taken over its last axis. Every norm
    strictly below ``dth`` is scored as ``exp(alpha * (1 - norm / dth))``;
    norms at or above ``dth`` score 0. The scores are then averaged over the
    last remaining axis.

    Args:
        dist: Array of per-point offsets; the last axis holds the offset
            components and the axis before it indexes the points.
        alpha: Sharpness of the exponential.
        dth: Distance threshold beyond which a point contributes 0.

    Returns:
        Array with the last two axes of ``dist`` reduced away.
    """
    norms = np.sqrt(np.sum(dist * dist, axis=-1))
    within_threshold = norms < dth
    scores = np.exp(alpha * (1 - norms / dth))
    # Multiplying by the boolean mask zeroes points outside the threshold.
    return np.mean(within_threshold * scores, axis=-1)


def corner_confidences(cp_pred_np: np.ndarray, obj_pose: np.ndarray, l_hand: np.ndarray, r_hand: np.ndarray, alpha: float = 2.0, dth=(75, 75, 7.5)):
    """Score predicted control points against object / hand ground truth.

    The three ground-truth vectors are stacked in the order
    (object, left hand, right hand) and, like the predictions, have their
    last axis split into 3-component points. The ground-truth stack is
    broadcast against the trailing axes of the predictions.

    The first two components of each point offset are scored jointly with
    ``conf_func`` against ``dth[0]``; the third component is scored with the
    same function against ``dth[-1]``. Using ``conf_func`` for both keeps the
    two terms on the same scale and applies the threshold to the magnitude
    of the third-component error, so large errors of either sign score 0.

    Args:
        cp_pred_np: Predicted control points; last axis is a flat sequence
            of 3-component points and the axis before it must be
            broadcastable against the 3 stacked ground-truth entries.
            Assumed layout (..., 3, 3 * n_points) -- TODO confirm with caller.
        obj_pose: Flat ground-truth control points of the object.
        l_hand: Flat ground-truth control points of the left hand.
        r_hand: Flat ground-truth control points of the right hand.
        alpha: Sharpness passed to ``conf_func``.
        dth: Thresholds; only ``dth[0]`` (first two components) and
            ``dth[-1]`` (third component) are read.

    Returns:
        Array of confidences, the mean of the two terms, with the point and
        component axes reduced away.
    """
    cp_gt = np.stack([obj_pose, l_hand, r_hand])
    cp_gt = cp_gt.reshape(cp_gt.shape[:-1] + (-1, 3))
    cp_pred_np = cp_pred_np.reshape(cp_pred_np.shape[:-1] + (-1, 3))
    dist = cp_gt - cp_pred_np

    c_uv = conf_func(dist[..., :2], alpha, dth[0])
    # Keep the component axis (2:) so conf_func's norm reduces to |dz|.
    c_z = conf_func(dist[..., 2:], alpha, dth[-1])

    return 0.5 * (c_uv + c_z)

In [None]:
# Exercise the Unified FCN: run a random image through the backbone, project
# the features with a linear layer and slice the result into its predictions.
ufcn = UnifiedFCNModule('convnext_tiny', 21, 9, 12)
net = ufcn.net
train_nodes, _ = get_graph_node_names(net)

# Expose a single intermediate node of the backbone under the key 'feat_out'.
net = create_feature_extractor(
    net, return_nodes={'features.7.2.block.4': 'feat_out'})
out = net(torch.rand(1, 3, 416, 416))
x = out['feat_out']

# Offsets inside one prediction vector, as used by the slices below:
# [0, cp_end) control points | cp_end confidence |
# [obj_start, obj_end) object classes | [obj_end, vec_len) verb classes
cp_end = 3 * ufcn.num_cpts
obj_start = cp_end + 1
obj_end = obj_start + ufcn.obj_classes
vec_len = obj_end + ufcn.verb_classes

# NOTE(review): the factors 5 and 3 are not explained in this cell; the 3
# matches the object / l_hand / r_hand axis used in the view below -- confirm
# what the 5 stands for.
out_channels = 5 * 3 * vec_len
lin = nn.Linear(x.shape[-1], out_channels)
x = lin(x)
# x = x.permute(0, 3, 1, 2)

bsize, _, h, w = x.size()
x_reshaped = x.contiguous().view(bsize, -1, 3, vec_len)
# print(x.shape)

# vector indices (at position 2): 0 -> object, 1 -> l_hand, 2 -> r_hand
cp_pred = torch.sigmoid(x_reshaped[:, :, :, :cp_end])
conf_pred = x_reshaped[:, :, :, cp_end].contiguous()
obj_pred = torch.sigmoid(x_reshaped[:, :, 0, obj_start:obj_end])
l_verb_pred = torch.sigmoid(x_reshaped[:, :, 1, obj_end:vec_len])
r_verb_pred = torch.sigmoid(x_reshaped[:, :, 2, obj_end:vec_len])

print(*(pred.shape for pred in (cp_pred, conf_pred, obj_pred,
                                l_verb_pred, r_verb_pred)))

In [None]:
# Confidence computation: pull the ground-truth entries out of the sample,
# score the predicted control points against them, then build a mask of the
# same shape filled with the no-hand/object scale.
l_hand, r_hand, obj_label, obj_pose, verb = (
    data[key]
    for key in ('l_hand', 'r_hand', 'obj_label', 'obj_pose', 'verb'))

cp_pred_array = cp_pred.data.cpu().numpy()
conf = corner_confidences(cp_pred_array, obj_pose, l_hand, r_hand)
# print(conf.shape)

noho_scale = 0.1
ho_scale = 5
conf_mask = noho_scale * np.ones_like(conf)
print(conf_mask.shape)

In [None]:
# Control points computation
