In [None]:
# Colab-only setup: turn on Colab's custom widget manager.
# NOTE(review): presumably needed so the plotly FigureWidget displayed later in
# this notebook renders and sends zoom events in Colab — confirm.
from google.colab import output
output.enable_custom_widget_manager()

In [None]:
#define projection model
import torch
import numpy as np
from torch import nn
def weights_init(m: nn.Module) -> None:
    """Initialise one module in place; intended for ``nn.Module.apply``.

    * ``nn.Linear``: Kaiming-normal weights, bias filled with 0.01.
    * ``nn.BatchNorm1d``: weight filled with 1, bias filled with 0.
    * Any other module type is left untouched.

    Layers built without the optional parameters (``bias=False`` on a linear
    layer, ``affine=False`` on a batch-norm layer) expose ``None`` for them;
    those parameters are skipped instead of raising ``AttributeError``.
    """
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)
    elif isinstance(m, nn.BatchNorm1d):
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


class NeuralMapper(nn.Module):
    """MLP that maps ``dim_input``-dimensional vectors to ``dim_emb`` dimensions.

    Three width-preserving ``Linear -> BatchNorm1d`` stages are followed by a
    final linear layer. ReLU is applied to the output of each batch-norm stage
    before it is fed into the next linear layer; the final output has no
    activation.
    """

    def __init__(self, dim_input, dim_emb=2):
        """Create the layers and initialise them with ``weights_init``.

        Args:
            dim_input: size of the input vectors (also the hidden width).
            dim_emb: size of the produced embedding, 2 by default.
        """
        super().__init__()

        # Stage 1
        self.linear_1 = nn.Linear(dim_input, dim_input)
        self.bn_1 = nn.BatchNorm1d(dim_input)
        # Stage 2
        self.linear_2 = nn.Linear(dim_input, dim_input)
        self.bn_2 = nn.BatchNorm1d(dim_input)
        # Stage 3
        self.linear_3 = nn.Linear(dim_input, dim_input)
        self.bn_3 = nn.BatchNorm1d(dim_input)
        # Output head and the activation shared by all stages
        self.linear_4 = nn.Linear(dim_input, dim_emb)
        self.relu = nn.ReLU()

        # Re-initialise every sub-module that weights_init knows about.
        self.apply(weights_init)

    def forward(self, x):
        """Return the embedding of ``x`` (last dimension becomes ``dim_emb``)."""
        hidden = self.bn_1(self.linear_1(x))
        hidden = self.bn_2(self.linear_2(self.relu(hidden)))
        hidden = self.bn_3(self.linear_3(self.relu(hidden)))
        return self.linear_4(self.relu(hidden))

In [None]:
#read HPSv2 embeddings and metric scores
import numpy as np
# Despite the shared "txt_" prefix, the first array is loaded from the *_txt.npy
# file and the second from the *_img.npy file.
txt_features0 = np.load("drive/My Drive/HPD/HPSv2_all_train_txt.npy")
txt_features1 = np.load("drive/My Drive/HPD/HPSv2_all_train_img.npy")

# Only the first n_used rows are used. Slice on the host *before* moving data to
# the GPU so the unused rows never occupy device memory, and apply the same
# slice to the scores so features and scores stay row-aligned (previously the
# features had 430000 rows while the scores kept every row of the files).
n_used = 430000
txt_used = txt_features0[:n_used]
img_used = txt_features1[:n_used]

# Model input: the two feature blocks concatenated along the feature axis.
txt_features = torch.Tensor(np.concatenate([txt_used, img_used], axis=1)).to("cuda")
# Score per row: 100 * dot product of the two feature vectors of that row.
scores = torch.Tensor(100 * np.sum(txt_used * img_used, axis=1)).to("cuda")
scores = scores.reshape((scores.shape[0],))
print(txt_features.shape)
print(scores.shape)

torch.Size([430000, 2048])
torch.Size([430056])


In [None]:
#load projection model
# Kernel parameters; they are passed to compute_kde_regression_new further down,
# which squares each of them before use.
a1 = 0.9987241625785828
a2 = 7.954826354980469
b = 1.1770744323730469

model_path = "drive/My Drive/HPD/HPSv2_HPD"
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_weights = model_path + ".pt"

# map_location=dev places the restored tensors on the selected device instead of
# whatever device the checkpoint happened to be saved from.
# SECURITY: weights_only=False unpickles arbitrary Python objects and can
# execute code contained in the file — only load checkpoints you trust.
loaded_model = torch.load(model_weights, map_location=dev, weights_only=False)

# Batch size used when projecting the dataset through the model.
bs = 1000

In [None]:
#contour map function
from torch import nn
def compute_kde_regression_new(X, Y, z_arr, a1, a2, b, h):
    """Kernel-weighted average of ``z_arr`` evaluated at the query points ``Y``.

    With ``d_ij = ||X_i - Y_j||`` the kernel is
    ``k_ij = 1 / (1 + a2**2 * d_ij ** (2 * b**2))`` and the result for query
    ``j`` is ``a1**2 * sum_i(z_i * k_ij) / sum_i(k_ij)``.

    Args:
        X: (n, d) array-like or tensor of sample positions.
        Y: (m, d) array-like or tensor of query positions.
        z_arr: length-n values attached to the rows of ``X``. When it is a
            tensor, the computation runs on its device; otherwise CUDA is used
            when available and the CPU as a fallback.
        a1, a2, b: scalar kernel parameters (each is squared before use).
        h: unused; kept so existing call sites keep working.

    Returns:
        ``numpy.ndarray`` of shape (m,) with dtype float32.
    """
    if isinstance(z_arr, torch.Tensor):
        device = z_arr.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _as_f32(value):
        # as_tensor avoids a copy (and torch's copy-construct warning) when
        # the value is already a float32 tensor on the target device.
        return torch.as_tensor(value, dtype=torch.float32, device=device)

    X = _as_f32(X)
    Y = _as_f32(Y)
    z_arr = _as_f32(z_arr)
    ra1 = _as_f32(a1) ** 2
    ra2 = _as_f32(a2) ** 2
    rb = _as_f32(b) ** 2

    dists = torch.cdist(X, Y)                       # (n, m) pairwise distances
    kernels = 1 / (1 + ra2 * dists ** (2 * rb))     # strictly positive weights

    w_kde = z_arr @ kernels                         # (m,) weighted sums
    kde = torch.sum(kernels, dim=0)                 # (m,) normalisers
    kde_reg = w_kde / kde
    return (ra1 * kde_reg).detach().cpu().numpy()

In [None]:
#project datasets
from torch.utils.data import TensorDataset, Dataset, DataLoader
# Placeholder labels: one constant 1 per feature row. TensorDataset needs a
# second tensor, and the projection loop below never inspects the label values.
n_points = len(txt_features)
labels_train = torch.tensor(np.ones(n_points, dtype=np.int_)).to("cuda")
# (alternative input, disabled: build the labels from txt_features_train instead)
print(labels_train.shape)

# Turn the flat label vector into a column so each dataset item has a (1,) label.
labels_train = labels_train.reshape((n_points, 1))
print(labels_train.shape)

points_train_ds = TensorDataset(txt_features, labels_train)
# (alternative input, disabled: TensorDataset(txt_features_train, labels_train))
def get_batch_embeddings(pretrained_model: torch.nn.Module,
                         input_points: Dataset,
                         batch_size: int):
    """
    Yields final embeddings for every batch in dataset

    The model is switched to eval mode and the dataset is traversed in order
    (no shuffling) in chunks of ``batch_size``.

    Args:
        pretrained_model: module applied to the first element of every batch.
        input_points: dataset whose items are ``(point, label)`` pairs.
        batch_size: number of items per batch.

    Yields:
        ``(embeddings, batch_labels)`` for each batch; ``embeddings`` is
        computed without gradient tracking.
    """
    pretrained_model.eval()
    test_dl = DataLoader(input_points, batch_size=batch_size, shuffle=False)
    for batch_points, batch_labels in test_dl:
        with torch.no_grad():
            embeddings = pretrained_model(batch_points)
        # Yield *outside* the no_grad block: a generator suspended inside the
        # context would leave gradient tracking disabled for the consumer's
        # loop body as well.
        yield embeddings, batch_labels
# Aliases kept for readability of the call below.
pretrained_model = loaded_model
input_points_train = points_train_ds
batch_size = bs

# Project the whole dataset batch by batch. Each batch is copied to the host
# exactly once (the former per-column copies into x1/x2 were never used).
pos_list = []
for embeddings, batch_labels in get_batch_embeddings(pretrained_model,
                                                     input_points_train,
                                                     batch_size):
    pos_list.append(embeddings.detach().cpu().numpy())

# One row per input point, in dataset order (the loader does not shuffle).
projected_pos = np.concatenate(pos_list)

torch.Size([430000])
torch.Size([430000, 1])


In [None]:
#sampling points to estimate contour
import random
def kernel_sample_and_get_remaining_ids(N, n, seed=None):
    """Split ``range(N)`` into ``n`` randomly sampled ids and the remaining ids.

    Args:
        N: size of the id range; ids are ``0 .. N-1``.
        n: number of distinct ids to sample.
        seed: optional seed for a private random generator, making the split
            reproducible. With ``None`` (default) the module-level ``random``
            state is used, exactly as before.

    Returns:
        ``(sampled_ids, remaining_ids)``: ``sampled_ids`` is a list of ``n``
        distinct ids in sampling order; ``remaining_ids`` lists all other ids
        in ascending order.

    Raises:
        ValueError: if ``n`` is larger than ``N``.
    """
    # Ensure n is not larger than N
    if n > N:
        raise ValueError("n cannot be larger than N")

    rng = random if seed is None else random.Random(seed)

    # Sample n unique IDs from the range 0 to N-1
    sampled_ids = rng.sample(range(N), n)

    # Set membership keeps the complement computation linear in N
    sampled_set = set(sampled_ids)
    remaining_ids = [i for i in range(N) if i not in sampled_set]

    return sampled_ids, remaining_ids
# Draw 50000 distinct row indices out of len(txt_features); the complementary
# index list is discarded. k_ids later selects the rows of projected_pos and
# scores that feed the contour estimation.
k_ids, _=kernel_sample_and_get_remaining_ids(len(txt_features),50000)

In [None]:
# compute multi-resolution contour
import numpy as np
import plotly.graph_objects as go
from IPython.display import display

# Grid resolutions to precompute (each grid is grid x grid points).
# NOTE(review): 100 is computed but not referenced by the later tile lookup in
# this notebook; it is kept because it also advances the bandwidth multiplier
# `c`, so dropping it would change every finer grid — confirm before removing.
grid_sizes = [32, 64, 100, 128, 256, 512, 1024, 2048]

# Maps grid size -> (xi, yi, zi) where zi[j, i] is the estimate at (xi[i], yi[j])
precomputed_zi = {}

# Bounding box of the projected points, per embedding dimension
x_min, x_max = projected_pos.min(0), projected_pos.max(0)

# Multiplier applied to a2; grows by `scale` after every grid, so finer grids
# are evaluated with a larger a2.
c = 1
scale = 1.8

# Loop-invariant inputs, selected once instead of once per grid column.
kernel_positions = projected_pos[k_ids]
kernel_scores = scores[k_ids]

# Precompute the estimate for every resolution, one grid column at a time
for grid in grid_sizes:
    xi = np.linspace(x_min[0], x_max[0], grid)
    yi = np.linspace(x_min[1], x_max[1], grid)

    z_columns = []
    for x_value in xi:
        # Query points of this column: fixed x, every y of the grid
        Y1_array = np.column_stack((np.full_like(yi, x_value), yi))
        kde_y = compute_kde_regression_new(kernel_positions, Y1_array, kernel_scores,
                                           a1, c*a2, b, 0.9)
        z_columns.append(kde_y.reshape(len(yi)))

    # Columns follow xi, rows follow yi
    zi = np.stack(z_columns, axis=1)
    precomputed_zi[grid] = (xi, yi, zi)
    c = c*scale

print("Precomputation complete!")

Precomputation complete!


In [None]:
# compute hierarchical overlay
import numpy as np
import plotly.graph_objects as go

# Initial grid resolution shown before any zoom happens
init_grid = 32
xi, yi, zi = precomputed_zi[init_grid]
x1_min, x1_max = min(xi), max(xi)
x2_min, x2_max = min(yi), max(yi)

# Resolutions available to the zoom handler
resolutions = [32, 64, 128, 256, 512, 1024, 2048]

# Dictionary holding the precomputed (xi, yi, zi) grids for those resolutions
tile_dict = {grid: precomputed_zi[grid] for grid in resolutions}

# Copy the scores to the host ONCE. (Previously the whole score tensor was
# copied from the device again for every single sampled point.)
scores_np = scores.cpu().numpy()
n_points = len(projected_pos)
point_scores = scores_np[:n_points]  # scores of the projected rows

# Per-resolution random subsample of the projected points and their scores.
# Every point is kept independently with probability 0.0001 * res, so finer
# resolutions carry more overlay points. Boolean-mask indexing always yields
# an (k, 2) array, even when k == 0.
tile_points = {}  # Stores sampled points per resolution
tile_scores = {}  # Stores corresponding scores
for res in resolutions:
    keep = np.random.rand(n_points) < 0.0001*(res)
    tile_points[res] = projected_pos[keep]
    tile_scores[res] = point_scores[keep]

# Initial marker size and color scaling
cmin, cmax = scores_np.min(), scores_np.max()
initial_marker_size = 6

In [None]:
# Define global min and max for contour color scaling
# global_min = zi.min()
# global_max = zi.max()
import time

# List collecting the duration of every zoom callback (filled by update_contour)
zoom_response_times = []

# Colour limits shared by the contour and the point overlay: the score extremes
# pulled in by 2 on each side.
global_min = cmin+2
global_max = cmax-2

# Filled contour of the coarsest precomputed grid (contour lines hidden)
contour_trace = go.Contour(
    x=xi,
    y=yi,
    z=zi,
    colorscale="RdYlBu_r",
    ncontours=80,
    line={"width": 0},
    zmin=global_min,
    zmax=global_max,
)

# Sample-point overlay; starts empty and is populated on the first zoom event.
# (disabled alternatives: cmin=cmin, cmax=cmax; colorbar=dict(title="Scores"))
points_trace = go.Scatter(
    x=[],
    y=[],
    mode='markers',
    name="Sample Points",
    marker={
        "color": [],  # Dynamically updated
        "colorscale": "RdYlBu_r",
        "cmin": global_min,
        "cmax": global_max,
        "opacity": 0.7,
        "size": initial_marker_size,
        "sizemode": "area",
        "line": {"width": 1, "color": "gray"},
        "showscale": False,
    },
)

fig = go.FigureWidget(data=[contour_trace, points_trace])

# Fixed-size square canvas.
# (disabled alternative: 1000x1000 with xaxis/yaxis scaleanchor set to each
# other to enforce equal scaling for X and Y)
fig.update_layout(autosize=False, width=900, height=900)

# Dictionary to cache previous zoom tiles
cached_tiles = {}


def update_contour(trace, layout):
    """Updates contour plot and sample point overlay dynamically.

    Registered through ``fig.layout.on_change(update_contour, "xaxis.range")``.
    Despite the parameter names (kept for compatibility), the first argument
    is used here as the layout object — both axis ranges are read from it —
    and the second argument is not used.
    NOTE(review): per plotly's ``on_change`` contract the second argument
    should be the new ``xaxis.range`` value — confirm.

    The narrower the visible x-range relative to the full data extent, the
    finer the precomputed grid (and the larger the marker) that is shown. Only
    the grid nodes and sampled points inside the visible window are sent to
    the figure. The handling time is appended to ``zoom_response_times``.
    """
    start_time = time.perf_counter()  # monotonic clock for interval timing

    x_range = trace["xaxis.range"]
    y_range = trace["yaxis.range"]
    if x_range is None or y_range is None:
        # No explicit range available (e.g. axis on autorange): nothing to crop.
        return
    x_min_zoom, x_max_zoom = x_range
    y_min_zoom, y_max_zoom = y_range
    fig.update_layout(autosize=False, width=900, height=900)
    zoom_range = x_max_zoom - x_min_zoom

    # Determine the best resolution based on zoom level.
    # Rows: (minimum visible fraction of the full x extent, grid, marker size).
    # NOTE(review): grid 256 is precomputed but never selected here — confirm
    # that skipping from 128 to 512 is intentional.
    full_extent = x_max[0] - x_min[0]
    zoom_levels = (
        (0.5, 32, 6),
        (0.25, 64, 8),
        (0.1, 128, 10),
        (0.05, 512, 10),
        (0.025, 1024, 10),
    )
    grid, marker_size = 2048, 12  # Highest resolution
    for fraction, level_grid, level_marker in zoom_levels:
        if zoom_range > full_extent * fraction:
            grid, marker_size = level_grid, level_marker
            break

    # Reuse a previously cropped tile for an identical view, otherwise crop it
    cache_key = (grid, x_min_zoom, x_max_zoom, y_min_zoom, y_max_zoom)
    if cache_key in cached_tiles:
        xi_zoomed, yi_zoomed, zi_zoomed = cached_tiles[cache_key]
    else:
        # Retrieve the precomputed contour grid
        xi_new, yi_new, zi_new = tile_dict[grid]

        # Select only the grid nodes inside the zoomed-in region
        mask_x = (xi_new >= x_min_zoom) & (xi_new <= x_max_zoom)
        mask_y = (yi_new >= y_min_zoom) & (yi_new <= y_max_zoom)
        if not mask_x.any() or not mask_y.any():
            # The window contains no grid node (zoomed in past the finest
            # grid spacing or panned outside the data): keep the current view
            # instead of failing on the min()/max() of an empty array.
            return

        xi_zoomed = xi_new[mask_x]
        yi_zoomed = yi_new[mask_y]
        zi_zoomed = zi_new[np.ix_(mask_y, mask_x)]

        # Bound the cache so long sessions do not grow memory without limit;
        # dicts preserve insertion order, so the first key is the oldest.
        max_cached_tiles = 128
        while len(cached_tiles) >= max_cached_tiles:
            cached_tiles.pop(next(iter(cached_tiles)))
        cached_tiles[cache_key] = (xi_zoomed, yi_zoomed, zi_zoomed)

    # Retrieve points for the selected grid resolution; reshape so that an
    # empty sample still supports column indexing.
    sample_points = np.asarray(tile_points[grid]).reshape(-1, 2)
    sample_scores = np.asarray(tile_scores[grid]).reshape(-1)

    # Keep points strictly inside the cropped grid, shrunk by a 2% margin per
    # side, so markers do not sit on the border of the contour patch.
    margin_x = (xi_zoomed.max() - xi_zoomed.min()) * 0.02
    margin_y = (yi_zoomed.max() - yi_zoomed.min()) * 0.02
    mask_points = (sample_points[:, 0] > xi_zoomed.min()+margin_x) & (sample_points[:, 0] < xi_zoomed.max()-margin_x) & \
                  (sample_points[:, 1] > yi_zoomed.min()+margin_y) & (sample_points[:, 1] < yi_zoomed.max()-margin_y)
    visible_points = sample_points[mask_points]
    visible_scores = sample_scores[mask_points]

    # Update the figure in a single message to the front end
    with fig.batch_update():
        fig.data[0].x = xi_zoomed
        fig.data[0].y = yi_zoomed
        fig.data[0].z = zi_zoomed
        fig.data[0].zmin = global_min
        fig.data[0].zmax = global_max

        # Update scatter points
        fig.data[1].x = visible_points[:, 0] if len(visible_points) > 0 else []
        fig.data[1].y = visible_points[:, 1] if len(visible_points) > 0 else []
        fig.data[1].marker.color = visible_scores if len(visible_scores) > 0 else []
        fig.data[1].marker.size = marker_size

    response_time = time.perf_counter() - start_time
    zoom_response_times.append(response_time)
    print(f"Zoom response time: {response_time:.4f} seconds")

# Attach event listener for zoom changes.
# Only "xaxis.range" is observed here; update_contour reads the y range from
# the layout itself when it runs.
# NOTE(review): a zoom that changes only the y axis presumably does not trigger
# the callback — confirm whether "yaxis.range" should be observed as well.
fig.layout.on_change(update_contour, "xaxis.range")

# Display the figure
display(fig)

FigureWidget({
    'data': [{'colorscale': [[0.0, 'rgb(49,54,149)'], [0.1, 'rgb(69,117,180)'],
                             [0.2, 'rgb(116,173,209)'], [0.3, 'rgb(171,217,233)'],
                             [0.4, 'rgb(224,243,248)'], [0.5, 'rgb(255,255,191)'],
                             [0.6, 'rgb(254,224,144)'], [0.7, 'rgb(253,174,97)'],
                             [0.8, 'rgb(244,109,67)'], [0.9, 'rgb(215,48,39)'],
                             [1.0, 'rgb(165,0,38)']],
              'line': {'width': 0},
              'ncontours': 80,
              'type': 'contour',
              'uid': '3d2e2162-4c1d-478a-beb5-eeba5133582a',
              'x': array([-14.647935  , -13.730427  , -12.81292   , -11.895412  , -10.977904  ,
                          -10.060396  ,  -9.142889  ,  -8.225382  ,  -7.3078737 ,  -6.3903666 ,
                           -5.4728584 ,  -4.5553503 ,  -3.6378431 ,  -2.720336  ,  -1.8028278 ,
                           -0.8853197 ,   0.03218746,   0.94969463,   1.86

Zoom response time: 0.0063 seconds
Zoom response time: 0.0030 seconds
Zoom response time: 0.0138 seconds
Zoom response time: 0.0170 seconds
Zoom response time: 0.0098 seconds
Zoom response time: 0.0069 seconds


In [None]:
# Average duration, in seconds, of the zoom callbacks recorded so far
# (one entry is appended per completed update_contour call).
np.mean(zoom_response_times)

np.float64(0.00944213072458903)