In [9]:
import numpy as np
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed every RNG the notebook draws from (randn / randint / multinomial) so that
# Restart & Run All reproduces the same data and sampled clusters.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)

In [2]:
def d_t(x, y):
    """Return half the squared Euclidean distance, 0.5 * ||x - y||^2.

    ``x - y`` is passed to ``np.linalg.norm`` with no ``axis``/``ord``, so the
    difference is treated as a flat vector (2-norm of its ravel).
    """
    dist = np.linalg.norm(x - y)
    return dist ** 2 / 2

In [3]:
# Given per-batch cluster indices idx and a cluster tensor of shape (batch_size, M, input_dim), build a (batch_size, M) tensor of probabilities for choosing among the M clusters, where P(cluster i) is proportional to exp(-gamma * ||y_idx - y_i||^2).
def create_probability_tensor(idx, cluster_tensor, gamma = 0.1):
    """Soft cluster-choice probabilities around the cluster(s) selected by ``idx``.

    For batch element ``b`` and selected index ``k = idx[b, ...]``, the probability
    of picking cluster ``i`` is proportional to
    ``exp(-gamma * ||cluster_tensor[b, k] - cluster_tensor[b, i]||^2)``.

    Args:
        idx: integer indices into the M clusters, shape ``(batch_size,)`` or, more
            generally, ``(batch_size, *extra)`` (e.g. several samples per batch).
        cluster_tensor: tensor of shape ``(batch_size, M, input_dim)``.
        gamma: randomness factor (inverse temperature). ``0`` gives a uniform
            distribution; larger values concentrate mass near the selected cluster.

    Returns:
        Tensor of shape ``idx.shape + (M,)`` whose last dimension sums to 1;
        ``(batch_size, M)`` for 1-D ``idx``.
    """
    batch_size = cluster_tensor.size(0)
    idx = torch.as_tensor(idx, device=cluster_tensor.device)
    extra_dims = idx.dim() - 1

    # Batch index shaped (batch_size, 1, ..., 1) so it broadcasts against idx.
    # Built on the data's device to avoid a host->device copy on every call.
    batch_index = torch.arange(batch_size, device=cluster_tensor.device)
    batch_index = batch_index.view(batch_size, *([1] * extra_dims))
    y_idx = cluster_tensor[batch_index, idx]  # (batch_size, *extra, input_dim)

    # Insert one singleton axis per extra idx dim: (batch_size, 1, ..., 1, M, input_dim)
    clusters = cluster_tensor
    for _ in range(extra_dims):
        clusters = clusters.unsqueeze(1)

    # Squared Euclidean distance from each selected cluster to every cluster
    squared_distances = ((clusters - y_idx.unsqueeze(-2)) ** 2).sum(dim=-1)  # (batch_size, *extra, M)

    # softmax == exp(z) / sum(exp(z)), but it shifts z by its max first, so it
    # cannot overflow to inf/inf = NaN (e.g. for negative gamma).
    return torch.softmax(-gamma * squared_distances, dim=-1)

def return_realized_distances(idx, cluster_tensor, gamma=0.1):
    """Sample a cluster near each selected cluster and return the squared distance to it.

    For every index, a cluster ``j`` is drawn from the distribution produced by
    ``create_probability_tensor`` (probability proportional to
    ``exp(-gamma * ||y_idx - y_j||^2)``). The result is the squared Euclidean
    distance ``||y_idx - y_j||^2``; note there is no 0.5 factor, unlike ``d_t``.

    Args:
        idx: integer indices into the M clusters, shape ``(batch_size,)`` or
            ``(batch_size, *extra)`` (e.g. several samples per batch).
        cluster_tensor: tensor of shape ``(batch_size, M, input_dim)``.
        gamma: randomness factor forwarded to ``create_probability_tensor``.

    Returns:
        Tensor of shape ``idx.shape`` with the realized squared distances.
    """
    device = cluster_tensor.device
    batch_size = cluster_tensor.size(0)
    cluster_shape = cluster_tensor.shape[1:]  # (M, input_dim)
    idx = torch.as_tensor(idx, device=device)

    # Flatten to one index per row: (batch_size, *extra) -> (batch_size * K,), and
    # repeat each batch's clusters K times so row r of flat_clusters matches flat_idx[r].
    # For 1-D idx, K == 1 and this reduces to the original computation.
    per_batch = idx.numel() // max(batch_size, 1)
    flat_idx = idx.reshape(-1)
    flat_clusters = (
        cluster_tensor.unsqueeze(1)
        .expand(batch_size, per_batch, *cluster_shape)
        .reshape(-1, *cluster_shape)
    )

    probabilities = create_probability_tensor(flat_idx, flat_clusters, gamma)
    # squeeze(-1) rather than squeeze(): only drop the num_samples axis, never the batch axis
    sampled_indices = torch.multinomial(probabilities, num_samples=1).squeeze(-1)

    # Index tensors live on the data's device (no host->device copy).
    rows = torch.arange(flat_clusters.size(0), device=device)
    selected_clusters = flat_clusters[rows, flat_idx]  # (batch_size * K, input_dim)
    sampled_clusters = flat_clusters[rows, sampled_indices]  # (batch_size * K, input_dim)
    realized_distances = ((selected_clusters - sampled_clusters) ** 2).sum(dim=-1)
    return realized_distances.reshape(idx.shape)

In [4]:
# example usage: one query cluster index per batch element
idx = torch.tensor([0, 1, 2, 3, 4, 5])  # example indices, shape (batch_size,) = (6,)
cluster_tensor = torch.randn(6, 6, 10)  # example cluster tensor (batch_size=6, M=6, input_dim=10)
# gamma close to 0 makes the sampling distribution nearly uniform over the M clusters;
# with continuous random data, a realized distance of 0 means the query cluster sampled itself.
realized_distances = return_realized_distances(idx, cluster_tensor, gamma = 1e-6)
print(realized_distances)

tensor([29.6419,  0.0000, 36.8542,  0.0000,  0.0000, 34.2116])


In [10]:
N = 100 # number of samples
M = 10 # number of clusters
Batch_size = 16 # number of batches
num_samples_in_batch = 8 # number of samples in each batch
input_dim = 2 # dimensionality of the input space

# Allocate directly on `device` instead of creating on the CPU and copying.
X = torch.randn(N, input_dim, device=device)  # example input tensor (N, input_dim)
Y = torch.randn(M, input_dim, device=device)  # example cluster tensor (M, input_dim)

# Draw all batches' sample indices in one call (with replacement, so a batch may
# repeat a sample) and gather them in one shot instead of a per-batch Python loop.
batch_indices = torch.randint(0, N, (Batch_size, num_samples_in_batch), device=device)
X_batches = X[batch_indices]  # (Batch_size, num_samples_in_batch, input_dim)

# Same cluster set for every batch. expand() returns a view sharing Y's memory,
# so clone() it before any in-place write.
Y_batches = Y.unsqueeze(0).expand(Batch_size, -1, -1)  # (Batch_size, M, input_dim)

In [17]:
from ADEN import ADEN  # project-local module; TODO: move to the imports cell

model = ADEN(input_dim=input_dim).to(device)

# Forward pass; the result is presumably (Batch_size, num_samples_in_batch, M) - confirm against ADEN.
predicted_distances = model(X_batches, Y_batches)

# Hard assignment: index of the cluster with the smallest predicted distance
# for each sample, i.e. shape (Batch_size, num_samples_in_batch).
idx = predicted_distances.argmin(dim=-1)
idx[0]

tensor([8, 2, 5, 8, 0, 2, 5, 0], device='cuda:0')