# In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import numpy as np
from torch.utils.data import Sampler

class TripletSampler(Sampler):
    """Yield random ``[anchor, positive, negative]`` dataset-index triplets.

    Each item produced by ``__iter__`` is a 3-element list of ints, not a
    single index. The dataset used with this sampler must therefore accept
    such a list in ``__getitem__`` (plain ``torchvision`` datasets do not).

    Args:
        labels: per-sample class labels. Any array-like accepted by
            ``np.asarray`` works (list, NumPy array, CPU tensor).
        num_samples: number of triplets drawn, with replacement, per pass.
    """

    def __init__(self, labels, num_samples):
        # Normalise to an ndarray so that elementwise ``==`` / ``!=`` work
        # even when a plain Python list is passed in.
        self.labels = np.asarray(labels)
        self.num_samples = num_samples

    def __iter__(self):
        labels = self.labels
        # Group indices by label once. The negative pool depends only on the
        # label, so build it once per label rather than once per sample.
        label_to_indices = {lab: np.flatnonzero(labels == lab) for lab in np.unique(labels)}
        label_to_negatives = {lab: np.flatnonzero(labels != lab) for lab in label_to_indices}

        triplets = []
        for idx, lab in enumerate(labels):
            same = label_to_indices[lab]
            positives = same[same != idx]  # the anchor cannot be its own positive
            negatives = label_to_negatives[lab]
            # Anchors whose class has a single sample, or datasets containing a
            # single class, cannot form a triplet; skip them.
            if positives.size == 0 or negatives.size == 0:
                continue
            triplets.append([idx, int(np.random.choice(positives)), int(np.random.choice(negatives))])

        if not triplets:
            raise ValueError("TripletSampler: no valid triplets can be formed from the given labels")

        # np.random.choice only accepts 1-D input, so draw row positions and
        # look the triplets up instead of passing the nested list directly.
        picks = np.random.choice(len(triplets), self.num_samples)
        return iter([triplets[i] for i in picks])

    def __len__(self):
        return self.num_samples

class EmbeddingNet(nn.Module):
    """Small CNN mapping single-channel 28x28 images to an embedding vector.

    The ``64 * 4 * 4`` input width of the first linear layer assumes 28x28
    input. The spatial size goes 28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4.

    Args:
        embedding_dim: size of the output embedding (default 2).
    """

    def __init__(self, embedding_dim=2):
        super().__init__()
        self.convnet = nn.Sequential(nn.Conv2d(1, 32, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2),
                                     nn.Conv2d(32, 64, 5), nn.PReLU(),
                                     nn.MaxPool2d(2, stride=2))

        self.fc = nn.Sequential(nn.Linear(64 * 4 * 4, 256),
                                nn.PReLU(),
                                nn.Linear(256, 256),
                                nn.PReLU(),
                                nn.Linear(256, embedding_dim),
                                )

    def forward(self, x):
        """Embed a batch ``x`` of shape (N, 1, 28, 28) into (N, embedding_dim)."""
        features = self.convnet(x)
        # Keep the batch dimension and flatten channels x height x width.
        features = torch.flatten(features, 1)
        return self.fc(features)

class TripletNet(nn.Module):
    """Apply one shared embedding network to each member of a triple.

    Args:
        embedding_net: module used for all three inputs. It is stored once,
            so its parameters are shared between the branches.
    """

    def __init__(self, embedding_net):
        super().__init__()
        self.embedding_net = embedding_net

    def forward(self, x1, x2, x3):
        """Return the embeddings of ``x1``, ``x2`` and ``x3``, in that order, as a tuple."""
        embed = self.embedding_net
        return embed(x1), embed(x2), embed(x3)

    def get_embedding(self, x):
        """Embed a single input ``x`` with the shared network."""
        return self.embedding_net(x)

def triplet_loss(anchor, positive, negative, size_average=True, margin=1.0):
    """Triplet margin loss computed on squared Euclidean distances.

    For each row, the loss is ``relu(||a - p||^2 - ||a - n||^2 + margin)``.
    Distances are deliberately left squared (no square root).

    Args:
        anchor, positive, negative: embeddings of equal shape; distances are
            summed over dimension 1.
        size_average: return the mean over rows if True, else the sum.
        margin: required gap between the negative and positive distances
            (default 1.0).

    Returns:
        A scalar tensor.
    """
    distance_positive = (anchor - positive).pow(2).sum(1)
    distance_negative = (anchor - negative).pow(2).sum(1)
    losses = F.relu(distance_positive - distance_negative + margin)
    return losses.mean() if size_average else losses.sum()

# Normalize MNIST pixels: ToTensor scales them to [0, 1], then
# Normalize((0.5,), (0.5,)) maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

from torch.utils.data import Dataset  # only the triplet adapter below needs this


class TripletDataset(Dataset):
    """Adapt an ``(image, label)`` dataset so it can be indexed by index triplets.

    ``__getitem__`` receives a 3-element sequence of indices (as yielded by
    ``TripletSampler``). It returns the three images, without their labels,
    so the default collate function produces ``[anchor, positive, negative]``
    batches.
    """

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, triplet):
        return tuple(self.base[i][0] for i in triplet)


trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
# The sampler needs the label of every training sample; MNIST keeps them in ``targets``.
labels = trainset.targets.numpy()
sampler = TripletSampler(labels, num_samples=10000)
# The sampler yields index triplets, so wrap MNIST to resolve each triplet into three images.
dataloader = DataLoader(TripletDataset(trainset), batch_size=64, sampler=sampler)

# EmbeddingNet assumes 28x28x1 input images.
embedding_net = EmbeddingNet()
model = TripletNet(embedding_net)

# Define optimizer
optimizer = Adam(model.parameters(), lr=0.001)

epochs = 100
# Train the model
for epoch in range(epochs):
    model.train()
    for anchor, positive, negative in dataloader:
        optimizer.zero_grad()
        anchor_out, positive_out, negative_out = model(anchor, positive, negative)
        loss = triplet_loss(anchor_out, positive_out, negative_out)
        loss.backward()
        optimizer.step()