In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import Sampler, Dataset, DataLoader

In [10]:
class PositiveNegativePairSampler(Sampler):
    def __init__(self, labels):
        self.labels = labels
        self.label_indices = {label: torch.nonzero(labels == label).squeeze() for label in torch.unique(labels)}

    def __iter__(self):
        positive_pairs = []
        negative_pairs = []

        for label, indices_per_label in self.label_indices.items():
            # Ensure there are at least two instances of the label for a positive pair
            if len(indices_per_label) >= 2:
                # Generate positive pairs
                pairs = torch.tensor(list(torch.combinations(indices_per_label, r=2)))
                positive_pairs.extend(pairs.tolist())

                # Generate negative pairs
                other_labels = torch.unique(self.labels[self.labels != label])
                for _ in range(len(pairs)):
                    random_negative_label = other_labels[torch.randint(len(other_labels), (1,))].item()
                    negative_indices = self.label_indices[random_negative_label]
                    negative_pairs.append((pairs[_][0].item(), negative_indices[torch.randint(len(negative_indices), (1,))].item()))

        all_pairs = positive_pairs + negative_pairs
        return iter(all_pairs)

    def __len__(self):
        max_positive = max(len(indices_per_label) // 2 for indices_per_label in self.label_indices.values() if len(indices_per_label) >= 2)
        max_negative = max(len(indices_per_label) * (len(indices_per_label) - 1) // 2 for indices_per_label in self.label_indices.values() if len(indices_per_label) >= 2)
        return len(self.label_indices) * max(0, max(max_positive, max_negative))


In [11]:
class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

In [12]:
def pixel_wise_contrastive_loss(z, pairs, temperature=1.0):
    z_flat = z.view(len(z), -1)

    # Extract representations for positive pairs
    z1 = z_flat[0, pairs[:, 0]]
    z2 = z_flat[0, pairs[:, 1]]

    # Compute cosine similarity between pixel representations
    similarities = F.cosine_similarity(z1, z2, dim=-1)

    # Apply temperature scaling
    similarities_scaled = similarities / temperature

    # Compute log probability for positive pairs
    log_prob_pos = F.log_softmax(similarities_scaled, dim=-1)

    # Negative pairs (pixels from the same image but not in the positive pairs)
    log_prob_neg = torch.log(torch.sum(torch.exp(similarities_scaled), dim=-1))

    # Compute loss
    loss = - (log_prob_pos - log_prob_neg)

    return loss.mean()

In [17]:
# Example usage
labels = torch.randint(6, (100,))  # Assuming 6 different labels
sampler = PositivePairSampler(labels)

# Assuming you have a dataset with pixel representations (z) and corresponding labels
dataset = CustomDataset(torch.randn((100, 256)), labels)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

for batch in dataloader:
    z, labels_batch = batch
    positive_pairs = torch.tensor(list(sampler))
    loss = pixel_wise_contrastive_loss(z, positive_pairs)
    print("Pixel-wise Contrastive Loss:", loss.item())

Pixel-wise Contrastive Loss: -0.02354308031499386
Pixel-wise Contrastive Loss: -0.009365047328174114
Pixel-wise Contrastive Loss: 0.05583999678492546
Pixel-wise Contrastive Loss: -0.029539600014686584
Pixel-wise Contrastive Loss: 0.0017805927200242877
Pixel-wise Contrastive Loss: -0.009369559586048126
Pixel-wise Contrastive Loss: -0.008453655056655407
Pixel-wise Contrastive Loss: -0.009281601756811142
Pixel-wise Contrastive Loss: 0.0022351776715368032
Pixel-wise Contrastive Loss: 0.023846834897994995
Pixel-wise Contrastive Loss: 0.03041231818497181
Pixel-wise Contrastive Loss: -0.02649216540157795
Pixel-wise Contrastive Loss: -0.008021372370421886
Pixel-wise Contrastive Loss: -0.04994263872504234
Pixel-wise Contrastive Loss: -0.03785321116447449
Pixel-wise Contrastive Loss: -0.011963598430156708
Pixel-wise Contrastive Loss: 0.01683914288878441
Pixel-wise Contrastive Loss: 0.0070340619422495365
Pixel-wise Contrastive Loss: 0.009985729120671749
Pixel-wise Contrastive Loss: -0.00825302582