# Contrastive Loss

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np

# Example dataset
class ChemicalPairsDataset(Dataset):
    def __init__(self, fingerprint_pairs, labels):
        """
        fingerprint_pairs: List of (fp1, fp2) pairs (e.g., numpy arrays)
        labels: List of 1 (similar) or 0 (dissimilar) for each pair
        """
        self.fingerprint_pairs = fingerprint_pairs
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        fp1, fp2 = self.fingerprint_pairs[idx]
        label = self.labels[idx]
        return torch.tensor(fp1, dtype=torch.float32), torch.tensor(fp2, dtype=torch.float32), label


# Define the model
class ContrastiveNetwork(nn.Module):
    def __init__(self, input_dim, embedding_dim):
        super(ContrastiveNetwork, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, embedding_dim)
        )

    def forward(self, x):
        return self.encoder(x)

# Contrastive loss function
class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, embedding1, embedding2, label):
        distance = torch.nn.functional.pairwise_distance(embedding1, embedding2)
        loss = torch.mean(
            label * torch.pow(distance, 2) +  # Pull similar pairs
            (1 - label) * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)  # Push dissimilar pairs
        )
        return loss

# Example usage
# Generate synthetic fingerprints for demonstration
np.random.seed(42)
fingerprints = [np.random.rand(1024) for _ in range(100)]  # Replace with actual fingerprints
pairs = [(fingerprints[i], fingerprints[i + 1]) for i in range(0, 99, 2)]
labels = [1 if i % 2 == 0 else 0 for i in range(len(pairs))]  # Dummy labels

# Create DataLoader
dataset = ChemicalPairsDataset(pairs, labels)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Initialize model and optimizer
input_dim = 1024
embedding_dim = 128
model = ContrastiveNetwork(input_dim, embedding_dim)
criterion = ContrastiveLoss(margin=1.0)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
for epoch in range(10):
    total_loss = 0.0
    for fp1, fp2, label in dataloader:
        optimizer.zero_grad()
        emb1 = model(fp1)
        emb2 = model(fp2)
        loss = criterion(emb1, emb2, label.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(dataloader)}")

# Extract embeddings for downstream tasks
with torch.no_grad():
    test_fp = torch.tensor(fingerprints[0], dtype=torch.float32).unsqueeze(0)
    embedding = model(test_fp)
    print(f"Embedding: {embedding}")

Epoch 1, Loss: 0.3436333356159074
Epoch 2, Loss: 0.2594407738319465
Epoch 3, Loss: 0.10254170213426862
Epoch 4, Loss: 0.09385412532304015
Epoch 5, Loss: 0.04112052909372973
Epoch 6, Loss: 0.0378460675378197
Epoch 7, Loss: 0.04739623440296522
Epoch 8, Loss: 0.01095719076693058
Epoch 9, Loss: 0.010216679224478347
Epoch 10, Loss: 0.006144822067913732
Embedding: tensor([[-0.1480,  0.1684,  0.0615, -0.0494, -0.1677, -0.0519, -0.0237, -0.0431,
          0.2677, -0.1457, -0.2233,  0.1102,  0.1390, -0.2684,  0.2947, -0.1332,
          0.0673, -0.1382,  0.3326, -0.1083, -0.0734, -0.0244,  0.2734, -0.0089,
         -0.1972, -0.1327, -0.0504,  0.3559,  0.0930, -0.3089, -0.0335,  0.2512,
         -0.1602, -0.2937,  0.1157,  0.2944,  0.2510, -0.1548,  0.4781, -0.2169,
         -0.0094,  0.4833,  0.2468, -0.0541,  0.1635, -0.0592, -0.3351,  0.0754,
         -0.1574,  0.0169, -0.0236, -0.3096,  0.3379, -0.0285,  0.1350,  0.1275,
          0.0137,  0.2121,  0.2314,  0.1047,  0.2154,  0.2253,  0.4446, 

In [4]:
pairs[0]

(array([0.37454012, 0.95071431, 0.73199394, ..., 0.29734901, 0.9243962 ,
        0.97105825]),
 array([0.94426649, 0.47421422, 0.86204265, ..., 0.7228143 , 0.06766836,
        0.7078351 ]))