## Take a source PLY file and a group of `negative control` PLY files, and find the locations on the source where it differs most drastically from the controls

In [2]:
import pandas as pd
from sklearn.neighbors import KDTree
from precise.dataset.precise_dataset import get_data_from_ply

def compute_distance(src_coords, source_feats, target_coord, target_feats, distance_fnc, top_k=100):
    """Score each source point against its nearest target point; return the top scores.

    Every row of ``src_coords`` is matched to its single nearest neighbour in
    ``target_coord`` (KD-tree query with ``k=1``). ``target_feats`` is then
    re-indexed by those matches so that it lines up row-for-row with
    ``source_feats``, and ``distance_fnc`` is applied to the aligned pair.

    Args:
        src_coords: Source coordinates, in a form accepted by ``KDTree.query``.
        source_feats: Source features, passed unchanged to ``distance_fnc``.
        target_coord: Target coordinates used to build the KD-tree.
        target_feats: Target features, indexed by the nearest-neighbour ids.
            They must describe the same point set as ``target_coord``,
            otherwise the ids select unrelated rows.
        distance_fnc: Callable ``(source_feats, remapped_target_feats) -> scores``.
            Presumably yields one score per source point -- confirm with callers.
        top_k: Number of highest-scoring entries requested. Clamped to the
            number of available scores so small inputs do not raise.

    Returns:
        The ``torch.topk`` result (``values`` and ``indices``) over ``scores``.
    """
    # Imported here so the function does not depend on a later notebook cell
    # having already imported torch.
    import torch

    target_kdtree = KDTree(target_coord)
    _, closest_tgtid = target_kdtree.query(src_coords, k=1)

    # query() returns ids shaped (n_src, 1); flatten to a plain index vector.
    tgt_remapped_feats = target_feats[closest_tgtid.flatten()]
    scores = distance_fnc(source_feats, tgt_remapped_feats)

    # torch.topk raises if k exceeds the size of the dimension it selects over.
    k = min(top_k, scores.shape[-1])
    return torch.topk(scores, k=k)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Surface under study.
src_ply = "pdbs/rhoA/aligned/5ez6_aligned_surface.ply"

# Negative-control surfaces it is compared against.
tgt_plys = [
    "pdbs/rhoA/aligned/6kx2_aligned_surface.ply",
    "pdbs/rhoA/aligned/1kmq_aligned_surface.ply",
]

# Load every surface with the same project loader; tgt_datas keeps the order of tgt_plys.
src_data = get_data_from_ply(src_ply)
tgt_datas = list(map(get_data_from_ply, tgt_plys))

In [13]:
import torch 
def dist(x, y):
    """Return the row-wise norm of ``x - y`` weighted by the negated last column of ``x``.

    ``x`` and ``y`` are indexed as 2-D tensors (rows along dim 0). The result has
    one entry per row.
    NOTE(review): what the last feature column represents is not visible here --
    confirm why it is used (negated) as the weight.
    """
    delta_norm = torch.norm(x - y, dim=1)
    weight = -x[:, -1]
    return delta_norm * weight

# Compare the source surface against the second control, tgt_datas[1].
# Coordinates and features must come from the SAME control: the nearest-neighbour
# ids found on its coordinates are used to index its features.
locs1     = compute_distance(src_data.pos.numpy(), src_data.z,
                             tgt_datas[1].pos.numpy(), tgt_datas[1].z, dist,
                             top_k=150)

In [14]:
# Paint every vertex neutral gray: the RGB triple broadcasts across an integer
# zero tensor shaped like the vertex positions.
colors = torch.tensor([175, 175, 175]) + torch.zeros_like(src_data.pos).int()

# Mark the top-scoring vertices selected by compute_distance in red.
colors[locs1.indices] = torch.tensor([255, 0, 0])

In [15]:
from precise.utils.visualization import write_ply_with_colors
from pathlib import Path

# Write the source mesh with the per-vertex colors to a new PLY file.
outloc = Path("distinct_5ez6_vs_1kmq.ply")
write_ply_with_colors(
    outloc,
    src_data.pos.numpy(),
    src_data.faces.numpy(),
    colors.numpy(),
)