In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import Sampler, Dataset, DataLoader

In [6]:
class PositiveNegativePairSampler(Sampler):
    def __init__(self, labels):
        self.labels = labels
        self.label_indices = {label: torch.nonzero(labels == label).squeeze() for label in torch.unique(labels)}
        print(self.label_indices)

    def __iter__(self):
        positive_pairs = []
        negative_pairs = []

        for label, indices_per_label in self.label_indices.items():
            # Ensure there are at least two instances of the label for a positive pair
            if len(indices_per_label) >= 2:
                # Generate positive pairs
                pairs = torch.tensor(list(torch.combinations(indices_per_label, r=2)))
                positive_pairs.extend(pairs.tolist())

                # Generate negative pairs
                other_labels = torch.unique(self.labels[self.labels != label])
                for _ in range(len(pairs)):
                    random_negative_label = other_labels[torch.randint(len(other_labels), (1,))].item()
                    negative_indices = self.label_indices[random_negative_label]
                    negative_pairs.append((pairs[_][0].item(), negative_indices[torch.randint(len(negative_indices), (1,))].item()))

        all_pairs = positive_pairs + negative_pairs
        return iter(all_pairs)

    def __len__(self):
        max_positive = max(len(indices_per_label) // 2 for indices_per_label in self.label_indices.values() if len(indices_per_label) >= 2)
        max_negative = max(len(indices_per_label) * (len(indices_per_label) - 1) // 2 for indices_per_label in self.label_indices.values() if len(indices_per_label) >= 2)
        return len(self.label_indices) * max(0, max(max_positive, max_negative))


In [7]:
class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

In [8]:
def pixel_wise_contrastive_loss(z, pairs, temperature=1.0):
    z_flat = z.view(len(z), -1)

    # Extract representations for positive pairs
    z1 = z_flat[0, pairs[:, 0]]
    z2 = z_flat[0, pairs[:, 1]]

    # Compute cosine similarity between pixel representations
    similarities = F.cosine_similarity(z1, z2, dim=-1)

    # Apply temperature scaling
    similarities_scaled = similarities / temperature

    # Compute log probability for positive pairs
    log_prob_pos = F.log_softmax(similarities_scaled, dim=-1)

    # Negative pairs (pixels from the same image but not in the positive pairs)
    log_prob_neg = torch.log(torch.sum(torch.exp(similarities_scaled), dim=-1))

    # Compute loss
    loss = - (log_prob_pos - log_prob_neg)

    return loss.mean()

In [10]:
# Example usage
labels = torch.randint(6, (100,))  # Assuming 6 different labels
print("Labels:", labels)
sampler = PositiveNegativePairSampler(labels)

# Assuming you have a dataset with pixel representations (z) and corresponding labels
dataset = CustomDataset(torch.randn((100, 256)), labels)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

for batch in dataloader:
    z, labels_batch = batch
    positive_pairs = torch.tensor(list(sampler))
    loss = pixel_wise_contrastive_loss(z, positive_pairs)
    print("Pixel-wise Contrastive Loss:", loss.item())

Labels: tensor([3, 4, 5, 5, 4, 4, 2, 5, 3, 0, 1, 2, 5, 0, 3, 5, 5, 5, 0, 5, 0, 0, 5, 3,
        1, 2, 5, 2, 5, 1, 5, 4, 5, 5, 2, 4, 5, 4, 5, 0, 3, 5, 2, 0, 1, 1, 1, 3,
        3, 2, 2, 0, 3, 4, 2, 0, 4, 1, 0, 0, 1, 4, 1, 1, 5, 5, 3, 5, 3, 3, 3, 0,
        3, 0, 2, 5, 2, 1, 3, 4, 2, 5, 4, 5, 5, 4, 1, 5, 4, 4, 4, 1, 5, 2, 4, 0,
        1, 0, 0, 3])
{tensor(0): tensor([ 9, 13, 18, 20, 21, 39, 43, 51, 55, 58, 59, 71, 73, 95, 97, 98]), tensor(1): tensor([10, 24, 29, 44, 45, 46, 57, 60, 62, 63, 77, 86, 91, 96]), tensor(2): tensor([ 6, 11, 25, 27, 34, 42, 49, 50, 54, 74, 76, 80, 93]), tensor(3): tensor([ 0,  8, 14, 23, 40, 47, 48, 52, 66, 68, 69, 70, 72, 78, 99]), tensor(4): tensor([ 1,  4,  5, 31, 35, 37, 53, 56, 61, 79, 82, 85, 88, 89, 90, 94]), tensor(5): tensor([ 2,  3,  7, 12, 15, 16, 17, 19, 22, 26, 28, 30, 32, 33, 36, 38, 41, 64,
        65, 67, 75, 81, 83, 84, 87, 92])}


TypeError: only integer tensors of a single element can be converted to an index