In [1]:
import h5py
import torch
import open_clip
from torch import Tensor

In [2]:
# Model configuration: the same architecture name selects both the model and its tokenizer.
MODEL_NAME = "ViT-B-16"
PRETRAINED = "laion2b_s34b_b88k"
DEVICE = "cuda"

model, _, _ = open_clip.create_model_and_transforms(
    MODEL_NAME,
    pretrained=PRETRAINED,
    precision="fp16",
)
model = model.eval().to(DEVICE)
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

# Negative phrases every query is contrasted against (they become the negatives in `get_relevancy`).
negatives = ["object", "things", "stuff", "texture"]
with torch.no_grad():
    neg_tokens = torch.cat([tokenizer(phrase) for phrase in negatives]).to(DEVICE)
    neg_embeds = model.encode_text(neg_tokens)
    # L2-normalise each phrase embedding so later dot products compare directions only.
    neg_embeds = neg_embeds / neg_embeds.norm(dim=-1, keepdim=True)

In [3]:
# No-op: `tokenizer` is already created next to the model in the model-loading cell above.

In [4]:
query = "book"
positives = [query]
with torch.no_grad():
    tok_phrases = torch.cat([tokenizer(phrase) for phrase in positives]).to("cuda")
    pos_embeds = model.encode_text(tok_phrases)
    # L2-normalise each phrase embedding, matching the treatment of the negatives.
    pos_embeds = pos_embeds / pos_embeds.norm(dim=-1, keepdim=True)

# Candidate scales searched for the query. Entry i is paired with dataset
# `clip/scale_{i}` of the embeddings file, so N_SCALES must equal the number
# of scales stored there.
N_SCALES = 30
scales_list = torch.linspace(0.0, 1.5, N_SCALES)

In [5]:
def get_relevancy(
    embed: Tensor,
    positive_id: int,
    pos_embeds: Tensor,
    neg_embeds: Tensor,
    positive_words_length: int,
) -> Tensor:
    """Pairwise-softmax relevancy of one positive phrase against all negative phrases.

    Every row of ``embed`` is scored against every phrase by dot product. The
    positive phrase's score is paired with each negative phrase's score, a
    softmax with logits scaled by 10 is taken over every (positive, negative)
    pair, and the pair in which the positive phrase receives the *lowest*
    probability (the negative it loses to most) is returned.

    Args:
        embed: 2-D ``(rows, dim)`` embeddings (``torch.mm`` requires 2-D input).
        positive_id: row of ``pos_embeds`` to score.
        pos_embeds: ``(n_pos, dim)`` positive phrase embeddings.
        neg_embeds: ``(n_neg, dim)`` negative phrase embeddings; any
            ``n_neg >= 1`` is supported.
        positive_words_length: column at which the negatives start in the
            stacked ``[pos_embeds; neg_embeds]`` matrix, i.e. ``len(pos_embeds)``.

    Returns:
        ``(rows, 2)`` tensor: column 0 is the positive phrase's probability and
        column 1 the selected negative's; each row sums to 1.

    Raises:
        ValueError: if no negative columns remain after ``positive_words_length``.
    """
    phrases_embeds = torch.cat([pos_embeds, neg_embeds], dim=0)
    p = phrases_embeds.to(embed.dtype)  # phrases x dim
    output = torch.mm(embed, p.T)  # rows x phrases
    positive_vals = output[:, positive_id : positive_id + 1]  # rows x 1
    negative_vals = output[:, positive_words_length:]  # rows x n_neg
    # Derive the negative count from the data instead of assuming exactly 4.
    n_neg = negative_vals.shape[1]
    if n_neg == 0:
        raise ValueError("get_relevancy needs at least one negative phrase")
    repeated_pos = positive_vals.repeat(1, n_neg)  # rows x n_neg

    sims = torch.stack((repeated_pos, negative_vals), dim=-1)  # rows x n_neg x 2
    # Softmax over each (positive, negative) pair; the factor 10 sharpens the logits.
    softmax = torch.softmax(10 * sims, dim=-1)  # rows x n_neg x 2
    # Per row, the negative against which the positive phrase scores worst.
    best_id = softmax[..., 0].argmin(dim=1)  # rows
    row_idx = torch.arange(softmax.shape[0], device=softmax.device)
    return softmax[row_idx, best_id]  # rows x 2

In [6]:
def load_h5_file(load_config: str, n_scales: int = 30) -> dict:
    """Read points, ray origins/directions, per-scale CLIP embeddings and RGB from an HDF5 file.

    Args:
        load_config: path of the HDF5 file; it is opened read-only.
        n_scales: number of ``clip/scale_{i}`` datasets to read (``i`` from 0 to
            ``n_scales - 1``). Defaults to 30, the scale count this notebook's
            ``scales_list`` uses.

    Returns:
        dict with arrays under ``"points"``, ``"origins"``, ``"directions"`` and
        ``"rgb"``, plus a list under ``"clip_embeddings_per_scale"`` whose entry
        ``i`` is the contents of dataset ``clip/scale_{i}``.
    """
    # `with` releases the file handle even if one of the datasets is missing.
    with h5py.File(load_config, "r") as hdf5_file:
        clips_group = hdf5_file["clip"]
        # `[:]` copies each dataset into memory so it stays valid after the file closes.
        return {
            "points": hdf5_file["points"]["points"][:],
            "origins": hdf5_file["origins"]["origins"][:],
            "directions": hdf5_file["directions"]["directions"][:],
            "clip_embeddings_per_scale": [
                clips_group[f"scale_{i}"][:] for i in range(n_scales)
            ],
            "rgb": hdf5_file["rgb"]["rgb"][:],
        }

In [7]:
# Pre-computed embeddings for scene0025_00; adjust this absolute path to your checkout.
EMBEDDINGS_PATH = "/workspace/chat-with-nerf-dev/chat-with-nerf/data/scene0025_00/embeddings.h5"
h5_dict = load_h5_file(EMBEDDINGS_PATH)
# The scale loop indexes one stored embedding array per entry of `scales_list`; fail fast on a mismatch.
assert len(h5_dict["clip_embeddings_per_scale"]) == len(scales_list), (
    "number of stored CLIP scales does not match scales_list"
)

In [8]:
n_phrases = len(positives)
# For each positive phrase, keep the scale whose relevancy map reaches the highest peak
# and that scale's per-row positive probabilities.
best_scale_for_phrases = [None] * n_phrases
probability_per_scale_per_phrase = [None] * n_phrases
for i, scale in enumerate(scales_list):
    clip_output = torch.from_numpy(
        h5_dict["clip_embeddings_per_scale"][i]
    ).to("cuda")
    # Distinct index name so the scale index `i` is not shadowed.
    for phrase_idx in range(n_phrases):
        probs = get_relevancy(
            embed=clip_output,
            positive_id=phrase_idx,
            pos_embeds=pos_embeds,
            neg_embeds=neg_embeds,
            # All positive phrases precede the negatives in the stacked phrase matrix,
            # so negatives start at column n_phrases (not always 1).
            positive_words_length=n_phrases,
        )
        pos_prob = probs[..., 0:1]
        best_so_far = probability_per_scale_per_phrase[phrase_idx]
        if best_so_far is None or pos_prob.max() > best_so_far.max():
            best_scale_for_phrases[phrase_idx] = scale
            probability_per_scale_per_phrase[phrase_idx] = pos_prob

possibility_array = probability_per_scale_per_phrase[0].detach().cpu().numpy()  # type: ignore # noqa: E501

In [9]:
# Best-scale positive-phrase probabilities for the first phrase in `positives`:
# a single-column array with one row per row of that scale's CLIP embeddings.
possibility_array

array([[0.5068755 ],
       [0.42661226],
       [0.50202906],
       ...,
       [0.4464897 ],
       [0.5149638 ],
       [0.4339206 ]], dtype=float32)

In [10]:
# Number of rows whose positive probability is strictly above 0.5.
int((possibility_array > 0.5).sum())

213623

In [11]:
# Number of rows whose positive probability is strictly below 0.5 (exactly 0.5 is in neither count).
int((possibility_array < 0.5).sum())

294280

In [14]:
# NOTE: `pos_prob` is the value left over from the final iteration of the scale loop
# (last entry of `scales_list`, last phrase). It is not the best-scale result stored
# in `probability_per_scale_per_phrase` / `possibility_array`, and it depends on that
# loop cell having been run in this kernel.
pos_prob

tensor([[0.5245],
        [0.5252],
        [0.5385],
        ...,
        [0.5331],
        [0.5226],
        [0.5367]], device='cuda:0')