Metric learning on Omniglot with the pytorch-metric-learning library: https://github.com/KevinMusgrave/pytorch-metric-learning

In [None]:
!pip install pytorch-metric-learning > /dev/null
!pip install faiss-gpu > /dev/null

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's default plotting theme.
sns.set()

# Switch to the "whitegrid" style, overriding 'axes.grid' so no grid lines are drawn.
sns.set_style("whitegrid", {'axes.grid' : False})

# from tqdm.notebook import tqdm
from tqdm.auto import tqdm

# cuda

In [None]:
import torch
torch.cuda.is_available()

In [None]:
# Select the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# On a CUDA machine this prints "cuda:0"; on a CPU-only machine it prints "cpu".
print(device)

# Data

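Omniglot (via torchvision): the background split is used for training and the non-background split for testing; images are resized to 64x64.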

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

In [None]:
from pytorch_metric_learning import losses, miners, distances, reducers, testers, samplers

In [None]:
# Preprocessing: convert each image to a tensor, then resize it to 64x64.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize((64, 64))]
)

# Both splits share the same root directory, download flag and transform;
# only the `background` flag differs (True -> train, False -> test).
trainset, testset = (
    torchvision.datasets.Omniglot(
        root='./omniglot_data',
        background=is_background,
        download=True,
        transform=transform,
    )
    for is_background in (True, False)
)

In [None]:
# Collect the class label of every training sample by walking the dataset once.
targets = [label for _, label in trainset]

# Cell output: (number of samples, number of distinct classes).
len(targets), len(set(targets))

In [None]:
BATCH_SIZE = 128
SAMPLES_PER_CLASS = 16 # 16 samples per class, so a batch of 128 holds 128 / 16 = 8 classes

In [None]:
sampler = samplers.MPerClassSampler(targets, SAMPLES_PER_CLASS, batch_size=BATCH_SIZE, length_before_new_iter=BATCH_SIZE * 200) # BATCH_SIZE * 200 samples -> 200 batches per epoch

In [None]:
# Training batches are drawn by the M-per-class sampler, whose indices were
# built from the training labels (`targets`).
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE,
    num_workers=2,
    sampler=sampler,
)

# The test set must not reuse `sampler`: its indices refer to positions in
# `trainset`, so they would pick the wrong (or out-of-range) test samples.
# Iterate the test set sequentially instead.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE,
    num_workers=2,
    shuffle=False,
)

In [None]:
def imshow(img):
    """Display a single channel-first image tensor with matplotlib.

    Args:
        img: Tensor laid out as (channels, height, width). Single-channel
            images are rendered in grayscale.
    """
    # detach()/cpu() make this safe for tensors that track gradients or live on a GPU.
    npimg = img.detach().cpu().numpy()

    # matplotlib expects channel-last data: (C, H, W) -> (H, W, C).
    npimg = np.transpose(npimg, (1, 2, 0))

    if npimg.shape[-1] == 1:
        # plt.imshow does not accept (H, W, 1); drop the channel axis and
        # render the remaining 2-D array as a grayscale image.
        plt.imshow(npimg[..., 0], cmap='gray')
    else:
        plt.imshow(npimg)

    plt.show()

In [None]:
def imggrid(images, labels, predicted_labels):
    """Plot up to the first 16 images in a 4x4 grid with label titles.

    Each title shows the true and the predicted label; it is drawn in black
    when they match and in red otherwise.

    Args:
        images: Batch of channel-first images, indexable along the first
            axis, each laid out as (channels, height, width).
        labels: Indexable collection of true labels, aligned with ``images``.
        predicted_labels: Indexable collection of predicted labels, aligned
            with ``images``.
    """
    # Create an empty figure; axes are added one per image so no unused
    # full-figure axes is left underneath the grid.
    fig = plt.figure(figsize=(16, 16))
    for i, img in enumerate(images[:16]):
        label = labels[i]
        predicted_label = predicted_labels[i]

        # (C, H, W) -> (H, W, C), then drop size-1 axes so a single-channel
        # image becomes a plain (H, W) array that imshow can render.
        img = torch.as_tensor(img).detach().cpu().permute(1, 2, 0).squeeze().numpy()
        ax = fig.add_subplot(4, 4, i + 1)
        ax.imshow(img, cmap='gray')

        color = 'black' if label == predicted_label else 'red'
        ax.set_title(f'True: {label}. Predicted: {predicted_label}', {'color': color})

In [None]:
# Fetch one batch from the training loader for a visual sanity check.
dataiter = iter(trainloader)
# Use the built-in next(): DataLoader iterators in recent PyTorch releases
# no longer provide a .next() method.
images, labels = next(dataiter)

# No predictions exist yet, so the true labels are passed as the predictions too.
imggrid(images, labels.numpy(), labels.numpy())

# Network

In [None]:
# Network input shape: 1x64x64 (channels x height x width)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class Net(nn.Module):
    """Convolutional embedding network for 64x64 images.

    A stride-2 convolution followed by three stride-2 conv/batch-norm blocks
    halves the spatial size four times (64 -> 32 -> 16 -> 8 -> 4). The
    flattened 4x4 feature map is then projected by a linear layer to a
    vector of ``features_d * 8`` values.

    Args:
        channels_img: Number of channels of the input images.
        features_d: Channel count after the first convolution; every
            following block doubles it.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()

        stages = [
            # N x channels_img x 64 x 64 -> N x features_d x 32 x 32
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        ]

        # Three blocks, each doubling the channels and halving the
        # resolution: 16x16, 8x8, 4x4.
        width = features_d
        for _ in range(3):
            stages.append(self._block(width, width * 2, 4, 2, 1))
            width = width * 2

        # Here width == features_d * 8 and the feature map is 4x4.
        stages.append(nn.Flatten())
        stages.append(nn.Linear(4 * 4 * width, width))

        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Return the embedding produced by running ``x`` through the layer stack."""
        return self.layer(x)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Build a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage."""
        # The convolution is created without a bias term.
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, norm, activation)

# Train

In [None]:
FEATURES_DIM = 8  # passed to Net as features_d; the final linear layer outputs FEATURES_DIM * 8 values
LR = 1e-3  # learning rate given to the Adam optimizer
NUM_EPOCHS = 150  # number of iterations of the outer training loop

In [None]:
# Build the network for 1-channel input images and move it to the selected device.
model = Net(1, FEATURES_DIM).to(device)

In [None]:
# Adam over all model parameters, using LR as the learning rate.
optimizer = optim.Adam(model.parameters(), lr=LR)

In [None]:
# The same cosine-similarity object is shared by the loss and the miner.
distance = distances.CosineSimilarity()
# Reducer configured with a lower threshold of 0.
reducer = reducers.ThresholdReducer(low=0)
# Triplet margin loss with margin 0.2, using the distance and reducer above.
loss_func = losses.TripletMarginLoss(
    margin=0.2,
    distance=distance,
    reducer=reducer,
)
# Miner that selects "semihard" triplets with the same margin and distance as the loss.
mining_func = miners.TripletMarginMiner(
    margin=0.2,
    distance=distance,
    type_of_triplets="semihard",
)

In [None]:
# Put the model in training mode before optimizing.
model.train()

for epoch in tqdm(range(NUM_EPOCHS)):
    # Accumulate the loss over the epoch so the log reflects every batch,
    # not just whichever batch happened to come last.
    loss_sum = 0.0
    num_batches = 0

    for data, labels in trainloader:
        data, labels = data.to(device), labels.to(device)

        embeddings = model(data)
        # Mine triplets from the batch, then compute the loss on the mined triplets only.
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        num_batches += 1

    # Mean loss over the epoch; NaN when the loader produced no batches
    # (avoids referencing an unbound `loss`).
    epoch_loss = loss_sum / num_batches if num_batches else float("nan")
    print("Epoch {} Loss = {}".format(epoch, epoch_loss))

In [None]:
# Switch the model to evaluation mode; the trailing ';' suppresses the cell's output.
model.eval();

In [None]:
# Embed the sample batch for inspection. no_grad() skips building an autograd
# graph, which is not needed for inference and would otherwise be kept alive by `e`.
with torch.no_grad():
    e = model(images.to(device)).to('cpu')

In [None]:
# Squared Euclidean distance between the embeddings of samples 0 and 1.
# NOTE(review): training used cosine similarity, so a cosine comparison may be
# the more meaningful check here — confirm. Presumably samples 0 and 1 share a
# class (the sampler draws SAMPLES_PER_CLASS items per class); verify against `labels`.
((e[0]- e[1])**2).sum()

In [None]:
# Squared Euclidean distance between the embeddings of samples 0 and 31.
# Presumably these come from different classes — verify against `labels`.
((e[0]- e[31])**2).sum()

In [None]:
# Squared Euclidean distance between the embeddings of samples 30 and 31.
# Presumably these share a class — verify against `labels`.
((e[30]- e[31])**2).sum()