<a href="https://colab.research.google.com/github/ykato27/Metric-Learning/blob/main/PyTroch_Metric_Learning_TripletMarginLoss_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pytorch-metric-learning
!pip install faiss-gpu

Collecting pytorch-metric-learning
  Downloading pytorch_metric_learning-0.9.99-py3-none-any.whl (105 kB)
Installing collected packages: pytorch-metric-learning
Successfully installed pytorch-metric-learning-0.9.99
Collecting faiss
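
The training cell in this notebook combines a semihard `TripletMarginMiner` with a `TripletMarginLoss`, both using cosine similarity and a margin of 0.2. As a rough illustration of what those two pieces compute, here is a minimal sketch in plain PyTorch. The helper names (`cosine_similarity_matrix`, `semihard_triplets`, `triplet_loss`) are hypothetical, and the semihard rule (0 < s_ap - s_an <= margin) and the "average over non-zero losses" reduction reflect my understanding of the library's behaviour, not its actual code.

```python
import torch
import torch.nn.functional as F


def cosine_similarity_matrix(embeddings):
    # Pairwise cosine similarity between all rows of `embeddings`.
    normed = F.normalize(embeddings, dim=1)
    return normed @ normed.t()


def semihard_triplets(embeddings, labels, margin=0.2):
    # Assumed semihard rule for a similarity measure: the negative is less
    # similar to the anchor than the positive, but by no more than `margin`.
    sim = cosine_similarity_matrix(embeddings)
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    pos_mask = same & ~torch.eye(len(labels), dtype=torch.bool)
    neg_mask = ~same
    # gap[a, p, n] = sim[a, p] - sim[a, n]
    gap = sim.unsqueeze(2) - sim.unsqueeze(1)
    valid = pos_mask.unsqueeze(2) & neg_mask.unsqueeze(1)
    keep = valid & (gap > 0) & (gap <= margin)
    return torch.where(keep)  # (anchor, positive, negative) index tensors


def triplet_loss(embeddings, triplets, margin=0.2):
    # Hinge on similarities: penalise triplets where the negative is not
    # at least `margin` less similar to the anchor than the positive.
    a, p, n = triplets
    sim = cosine_similarity_matrix(embeddings)
    per_triplet = torch.relu(sim[a, n] - sim[a, p] + margin)
    # Average only over triplets with a non-zero loss (assumed to mirror
    # what a threshold reducer with low=0 does).
    active = per_triplet[per_triplet > 0]
    return active.mean() if active.numel() > 0 else per_triplet.sum()


# Toy batch: two samples of class 0 and one sample of class 1.
embeddings = torch.tensor([[1.0, 0.0], [0.96, 0.28], [0.8, 0.6]])
labels = torch.tensor([0, 0, 1])

anchors, positives, negatives = semihard_triplets(embeddings, labels)
loss = triplet_loss(embeddings, (anchors, positives, negatives))
print(anchors.tolist(), positives.tolist(), negatives.tolist(), loss.item())
```

The sketch enumerates every (anchor, positive, negative) combination with a dense 3-D tensor, which is fine for a toy batch but grows cubically with batch size.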

In [2]:
from pytorch_metric_learning import losses, miners, distances, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        return x

### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### 
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(data)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)
        loss.backward()
        optimizer.step()
        if batch_idx % 20 == 0:
            print("Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}".format(epoch, batch_idx, loss, mining_func.num_triplets))

### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    tester = testers.BaseTester()
    return tester.get_all_embeddings(dataset, model)

### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator):
    train_embeddings, train_labels = get_all_embeddings(train_set, model)
    test_embeddings, test_labels = get_all_embeddings(test_set, model)
    train_labels = train_labels.squeeze(1)
    test_labels = test_labels.squeeze(1)
    print("Computing accuracy")
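    # The final False says the query (test) and reference (train) embeddings
    # come from different sources, so no self-matches need to be excluded.
    accuracies = accuracy_calculator.get_accuracy(test_embeddings,
                                                  train_embeddings,
                                                  test_labels,
                                                  train_labels,
                                                  False)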
    print("Test set accuracy (Precision@1) = {}".format(accuracies["precision_at_1"]))

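# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")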

transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

batch_size = 256

dataset1 = datasets.MNIST('.', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST('.', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=batch_size)

model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 1


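### pytorch-metric-learning setup ###
# Cosine similarity: larger values mean the embeddings are more alike.
distance = distances.CosineSimilarity()
# Average only the per-triplet losses that are greater than 0.
reducer = reducers.ThresholdReducer(low=0)
loss_func = losses.TripletMarginLoss(margin=0.2, distance=distance, reducer=reducer)
# Select semihard triplets from each batch before computing the loss.
mining_func = miners.TripletMarginMiner(margin=0.2, distance=distance, type_of_triplets="semihard")
# Report nearest-neighbor accuracy (Precision@1) only.
accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)
### end of pytorch-metric-learning setup ###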


for epoch in range(1, num_epochs+1):
    train(model, loss_func, mining_func, device, train_loader, optimizer, epoch)
    test(dataset1, dataset2, model, accuracy_calculator)



Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw


Epoch 1 Iteration 0: Loss = 0.10558018088340759, Number of mined triplets = 749398
Epoch 1 Iteration 20: Loss = 0.09165391325950623, Number of mined triplets = 94051
Epoch 1 Iteration 40: Loss = 0.08924613893032074, Number of mined triplets = 73120
Epoch 1 Iteration 60: Loss = 0.08542117476463318, Number of mined triplets = 41243
Epoch 1 Iteration 80: Loss = 0.083809994161129, Number of mined triplets = 47270
Epoch 1 Iteration 100: Loss = 0.08540888875722885, Number of mined triplets = 35164
Epoch 1 Iteration 120: Loss = 0.08405602723360062, Number of mined triplets = 27674
Epoch 1 Iteration 140: Loss = 0.08477538824081421, Number of mined triplets = 37812
Epoch 1 Iteration 160: Loss = 0.08108657598495483, Number of mined triplets = 31249
Epoch 1 Iteration 180: Loss = 0.08339288830757141, Number of mined triplets = 29493
Epoch 1 Iteration 200: Loss = 0.07720369100570679, Number of mined triplets = 16588
Epoch 1 Iteration 220: Loss = 0.0861109122633934, Number of mined triplets = 31505


100%|██████████| 1875/1875 [00:18<00:00, 104.01it/s]
100%|██████████| 313/313 [00:04<00:00, 76.30it/s] 


Computing accuracy
Test set accuracy (Precision@1) = 0.9814
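
For readers unfamiliar with the metric reported above, Precision@1 asks, for each test embedding, whether its single nearest training embedding carries the same label. The following is a minimal sketch of that idea in plain PyTorch. The function name `precision_at_1` is hypothetical, and the use of cosine similarity for the neighbor search is an assumption made for illustration; `AccuracyCalculator` performs its own k-nearest-neighbor search and may differ in the distance used and in tie handling.

```python
import torch
import torch.nn.functional as F


def precision_at_1(query_emb, query_labels, ref_emb, ref_labels):
    # Normalise rows so the dot product equals cosine similarity
    # (assumption: cosine similarity is the neighbor criterion).
    q = F.normalize(query_emb, dim=1)
    r = F.normalize(ref_emb, dim=1)
    sims = q @ r.t()
    # Index of the most similar reference embedding for each query.
    nearest = sims.argmax(dim=1)
    # Fraction of queries whose nearest reference shares their label.
    return (ref_labels[nearest] == query_labels).float().mean().item()


# Toy data: two reference points and three queries, one of them mislabeled
# relative to its nearest reference.
ref_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
ref_labels = torch.tensor([0, 1])
query_emb = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
query_labels = torch.tensor([0, 1, 1])

score = precision_at_1(query_emb, query_labels, ref_emb, ref_labels)
print(score)
```

This brute-force version builds the full query-by-reference similarity matrix, which is workable for small sets; an approximate or indexed search is the usual choice at MNIST scale and beyond.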
