# Triplet Loss
Metric learning on the MNIST dataset (about 30 min.)

In [None]:
%%shell
# Clone the beta branch of the repository; its src/ directory is appended to sys.path in a later cell.
git clone -b beta https://github.com/tky823/DNN-based_source_separation.git

In [None]:
# Work from the triplet-loss tutorial directory inside the cloned repository
# (MNIST is later downloaded relative to this directory via root="./").
%cd "/content/DNN-based_source_separation/egs/tutorials/triplet-loss"

In [None]:
import sys
# Make the cloned repository's src/ directory importable; presumably this is
# where `criterion.metric_learn` (imported below) lives — verify against the repo.
sys.path.append("/content/DNN-based_source_separation/src")

In [None]:
import random

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Use a larger default font size for every figure drawn in this notebook.
plt.rcParams.update({'font.size': 18})

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torchvision
from torchvision import transforms

In [None]:
from criterion.metric_learn import TripletLoss

## Generate dataset

In [None]:
class TripletMNIST:
    """Dataset yielding random (anchor, positive, negative) MNIST triplets.

    ``__getitem__`` ignores ``idx`` and draws a fresh triplet on every call:
    two distinct samples of one randomly chosen class (anchor and positive)
    and one sample of a different class (negative). ``__len__`` reports
    ``num_samples``, which therefore controls how many triplets one pass of a
    DataLoader produces.

    Args:
        root: directory forwarded to ``torchvision.datasets.MNIST``.
        train: forwarded to MNIST (selects the training or test split).
        transform: per-image transform forwarded to MNIST.
        download: forwarded to MNIST.
        num_samples: nominal dataset length; ``None`` means the length of the
            underlying MNIST split.
    """

    def __init__(self, root="./", train=True, transform=None, download=True, num_samples=None):
        self.original_dataset = torchvision.datasets.MNIST(root=root, train=train, transform=transform, download=download)
        self.n_class = len(self.original_dataset.classes)
        self.class_list = list(range(self.n_class))

        # target_list[c] lists the dataset indices of every sample labelled c.
        labels = self.original_dataset.targets
        self.target_list = [torch.where(labels == c)[0].tolist() for c in self.class_list]

        self.num_samples = len(self.original_dataset) if num_samples is None else num_samples

    def __getitem__(self, idx):
        # Shuffle the class list in place; its first two entries are then two
        # distinct classes: one for anchor/positive, one for the negative.
        random.shuffle(self.class_list)
        pos_class, neg_class = self.class_list[0], self.class_list[1]

        anchor_i, positive_i = random.sample(self.target_list[pos_class], 2)
        negative_i = random.choice(self.target_list[neg_class])

        # Each item of the wrapped dataset is (image, label); keep the image only.
        anchor, _ = self.original_dataset[anchor_i]
        positive, _ = self.original_dataset[positive_i]
        negative, _ = self.original_dataset[negative_i]
        return anchor, positive, negative

    def __len__(self):
        return self.num_samples

In [None]:
# Seed Python's `random` (used by TripletMNIST's triplet sampling) and torch's RNG.
random.seed(111)
torch.manual_seed(111)

# num_samples: length of the triplet training dataset (triplets per pass over train_loader).
num_samples, batch_size = 500_000, 64

In [None]:
# Only conversion to a float tensor (C, H, W) with values in [0, 1]; no augmentation.
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# Training data: random triplets from the MNIST training split.
train_dataset = TripletMNIST(root="./", train=True, transform=transform, num_samples=num_samples)
# Test data: plain (image, label) pairs. torchvision's MNIST defaults to
# download=False, so request it explicitly instead of depending on files
# already being present under root.
test_dataset = torchvision.datasets.MNIST(root="./", train=False, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)

In [None]:
# Show the first triplet of the first training batch: anchor | positive | negative.
for anchor, positive, negative in train_loader:
    fig, axes = plt.subplots(1, 3, figsize=(9, 3))
    for ax, batch in zip(axes, (anchor, positive, negative)):
        ax.imshow(batch[0, 0])  # sample 0, channel 0
    break

In [None]:
class ConvBlock(nn.Module):
    """Conv2d -> PReLU -> 2x2 max-pool (stride 2) -> Dropout.

    The convolution uses no padding, so with stride 1 each spatial side
    shrinks by ``kernel_size - 1`` before the pooling halves it.

    Args:
        in_channels: number of input channels.
        out_channels: number of convolution output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        dropout: dropout probability applied after pooling.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dropout=0.3):
        super().__init__()
        self.conv2d = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.prelu = nn.PReLU()
        self.pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, input):
        return self.dropout(self.pool2d(self.prelu(self.conv2d(input))))

class BasicModel(nn.Module):
    """Small CNN embedding a 1x28x28 image into an ``embed_dim``-dimensional vector.

    Two ConvBlocks with 5x5 kernels (no padding) and 2x2 pooling reduce a
    28x28 input to 64 feature maps of 4x4 (28 -> 24 -> 12 -> 8 -> 4). These
    are flattened and passed through a 2-layer MLP.

    Args:
        embed_dim: size of the output embedding.
        dropout: dropout probability used inside both ConvBlocks.
    """

    def __init__(self, embed_dim=2, dropout=0.3):
        super().__init__()

        net = []
        # Forward `dropout` so the constructor argument actually takes effect.
        net.append(ConvBlock(1, 32, 5, dropout=dropout))
        net.append(ConvBlock(32, 64, 5, dropout=dropout))

        fc_net = []
        fc_net.append(nn.Linear(64*4*4, 512))
        fc_net.append(nn.PReLU())
        fc_net.append(nn.Linear(512, embed_dim))

        self.net = nn.Sequential(*net)
        self.fc_net = nn.Sequential(*fc_net)

    def forward(self, input):
        x = self.net(input)
        # Flatten per sample, keeping the batch axis explicit: a wrongly sized
        # input now fails in the Linear layer instead of being silently
        # re-split into a different number of rows.
        x = x.flatten(start_dim=1)
        output = self.fc_net(x)

        return output

In [None]:
# 2-D embedding so that test-set embeddings can be scatter-plotted directly.
model = BasicModel(embed_dim=2)

In [None]:
# Display the network's layer structure.
print(model)

## Training
Embed each 28x28 image into a 2-D space.

In [None]:
# Adam over all model parameters; the triplet objective comes from the repository's criterion package.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
criterion = TripletLoss()

In [None]:
model.train()

# One optimisation step per batch of triplets; per-step losses go into train_loss.
train_loss = []
for idx, (input_anchor, input_positive, input_negative) in enumerate(train_loader):
    # The same network (shared weights) embeds all three members of the triplet.
    output_anchor = model(input_anchor)
    output_positive = model(input_positive)
    output_negative = model(input_negative)
    loss = criterion(output_anchor, output_positive, output_negative)

    # Clear stale gradients, back-propagate, update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    train_loss.append(loss_value)
    if (idx + 1) % 100 == 0:
        print("{}/{} Loss: {:.5f}".format(idx + 1, len(train_loader), loss_value))

In [None]:
train_loss = np.array(train_loss)

# Trailing 100-step moving average: element j is the mean of train_loss[j:j + 100].
# Summing the 100 shifted slices (starting from 0) yields max(N - 100, 0) values.
average_loss = sum(train_loss[i:i - 100] for i in range(100)) / 100

In [None]:
# Raw per-step loss (blue) and its 100-step moving average (black).
# average_loss[j] is the mean of train_loss[j:j + 100], so it is drawn at
# x = j + 99, the last step inside its window (not j + 100).
plt.figure(figsize=(12, 8))
plt.plot(train_loss, color='deepskyblue')
plt.plot(range(99, 99 + len(average_loss)), average_loss, color='black')
plt.xlabel('iteration')
plt.ylabel('loss')
plt.show()

## Test

In [None]:
model.eval()

x = []
labels = []

# Embed every test image (no gradients needed) and collect its label.
# The loop variable is `images` rather than `input` so the builtin is not shadowed.
with torch.no_grad():
    for images, targets in test_loader:
        x.append(model(images).numpy())
        labels.append(targets.numpy())

# Stack per-batch arrays: x has shape (num_test, embed_dim) and labels has
# shape (num_test,), for any test_loader batch size.
x = np.concatenate(x)
labels = np.concatenate(labels)

In [None]:
plt.figure(figsize=(12, 8))

# One scatter series per class, labelled with the dataset's class name.
for class_idx, class_name in enumerate(test_dataset.classes):
    in_class = labels == class_idx
    plt.scatter(x[in_class, 0], x[in_class, 1], label=class_name)

plt.legend()
plt.show()