# ArcFace Implementation on MNIST Dataset 

Paper : [https://arxiv.org/abs/1801.07698](https://arxiv.org/abs/1801.07698)

Training pipeline for a simple ArcFace CNN model on the MNIST dataset. Since this was my first ArcFace implementation, I tried it on MNIST first for practice; the same pipeline can be reused for the Shopee competition. Training uses the PyTorch framework. Finally, I also plot the learned image embeddings.

Kaggle kernel that helped me implement ArcFace:
[https://www.kaggle.com/slawekbiel/arcface-explained](https://www.kaggle.com/slawekbiel/arcface-explained)

## Imports

In [None]:
import math

import torch
import torch.nn.functional as F

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms as T, datasets

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px

from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Transforms and dataset

In [None]:
# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 then maps
# them to [-1, 1] for the single greyscale channel.
transform = T.Compose([T.ToTensor(), T.Normalize(mean=(0.5,), std=(0.5,))])

In [None]:
# Load the train and test splits from the same local directory with the same
# preprocessing; the generator yields the train split first, then the test split.
trainset, testset = (
    datasets.MNIST('../input/mnist-dataset-pytorch', train=is_train, transform=transform)
    for is_train in (True, False)
)

## ArcFace CNN Model 

![](https://www.weak-learner.com/assets/img/blog/personal/arcface_archi.png)

In [None]:
class ArcFace(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    For an embedding ``x`` with true class ``y`` the logits are
    ``scale * cos(theta_j)`` for ``j != y`` and ``scale * cos(theta_y + margin)``
    for the target class, where ``theta_j`` is the angle between the
    L2-normalised embedding and the L2-normalised weight row of class ``j``.

    Args:
        in_features: Embedding dimension.
        out_features: Number of classes.
        margin: Additive angular margin, in radians.
        scale: Multiplier applied to the cosine logits.
    """

    def __init__(self, in_features, out_features, margin=0.7, scale=64):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin

        # One weight row (class centre) per class.
        self.weights = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_normal_(self.weights)

    def forward(self, features, targets):
        """Return margin-adjusted, scaled logits of shape (batch, out_features).

        Args:
            features: Embeddings of shape (batch, in_features). They are
                L2-normalised here, so already-normalised inputs are fine too.
            targets: Integer class indices of shape (batch,).
        """
        cos_theta = F.linear(F.normalize(features), F.normalize(self.weights))
        # acos has an infinite gradient at +-1, so stay strictly inside.
        cos_theta = cos_theta.clamp(-1 + 1e-7, 1 - 1e-7)

        with_margin = torch.cos(torch.acos(cos_theta) + self.margin)
        # cos(theta + m) only decreases in theta while theta + m <= pi. Beyond
        # that (cos_theta <= cos(pi - m)) adding the margin would *raise* the
        # target logit, so use the monotone linear penalty instead.
        threshold = math.cos(math.pi - self.margin)
        fallback = cos_theta - math.sin(math.pi - self.margin) * self.margin
        target_logit = torch.where(cos_theta > threshold, with_margin, fallback)

        # The margin applies to the true class only; other classes keep cos(theta).
        is_target = F.one_hot(targets, num_classes=self.out_features).bool()
        logits = torch.where(is_target, target_logit, cos_theta)
        return logits * self.scale
    
    
class MNIST_Model(nn.Module):
    """LeNet-style CNN producing 3-D L2-normalised embeddings, with an ArcFace head.

    The fully connected input size of 320 = 20 channels * 4 * 4 matches
    single-channel 28x28 inputs after two (5x5 conv, 2x2 max-pool) stages.
    Called without targets, ``forward`` returns the embeddings; called with
    targets, it returns the ArcFace logits over 10 classes.
    """

    def __init__(self):
        super().__init__()

        # Layers are created in this fixed order so that random initialisation
        # consumes the RNG in the same sequence.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 3)
        self.arc_face = ArcFace(in_features=3, out_features=10)

    def _embed(self, images):
        """Map a batch of images to unit-length 3-D embeddings."""
        hidden = F.max_pool2d(self.conv1(images), 2).relu()
        hidden = F.max_pool2d(self.conv2_drop(self.conv2(hidden)), 2).relu()
        hidden = torch.flatten(hidden, start_dim=1)
        hidden = self.fc1(hidden).relu()
        return F.normalize(self.fc2(hidden))

    def forward(self, features, targets=None):
        embeddings = self._embed(features)
        if targets is None:
            return embeddings
        return self.arc_face(embeddings, targets)

In [None]:
# Build the network and move its parameters to the selected device;
# nn.Module.to moves the module in place and returns the same object.
model = MNIST_Model().to(device)

## Training

In [None]:
class TrainModel():
    """Train/validate loop for a model called as ``model(images, labels)``.

    Args:
        criterion: Loss applied as ``criterion(logits, labels)``.
        optimizer: Optimizer stepped once per training batch.
        schedular: Optional learning-rate scheduler; it is stored on the
            instance but never stepped by this class.
        device: Device that batches are moved to. When ``None``, batches are
            moved to the device of the model's first parameter.
    """

    def __init__(self, criterion=None, optimizer=None, schedular=None, device=None):
        self.criterion = criterion
        self.optimizer = optimizer
        self.schedular = schedular
        self.device = device

    def _batch_device(self, model):
        """Return the device that input batches should be moved to."""
        if self.device is not None:
            return self.device
        return next(model.parameters()).device

    def accuracy(self, logits, labels):
        """Return the fraction of rows whose arg-max logit equals the label, as a float."""
        preds = torch.argmax(logits, dim=1)
        return (preds == labels).float().mean().item()

    def get_dataloader(self, trainset, validset):
        """Build the training loader (shuffled) and the validation loader (fixed order)."""
        # Reshuffle the training data every epoch instead of visiting the
        # mini-batches in one fixed order.
        trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
        validloader = DataLoader(validset, batch_size=64, num_workers=4, pin_memory=True)
        return trainloader, validloader

    def _run_epoch(self, model, loader, i, split, train):
        """Make one pass over *loader*; return (mean batch loss, mean batch accuracy).

        When *train* is true, the optimizer is stepped after every batch;
        otherwise autograd is disabled for the whole pass.
        """
        device = self._batch_device(model)
        epoch_loss = 0.0
        epoch_acc = 0.0
        pbar = tqdm(loader, desc=f"Epoch [{split}] {i + 1}")

        with torch.set_grad_enabled(train):
            for t, (images, labels) in enumerate(pbar):
                images = images.to(device)
                labels = labels.to(device)

                # NOTE: the labels are passed to the model, so with a margin head
                # (like the ArcFace head in this notebook) the target logit is
                # penalised and accuracy here is a conservative estimate.
                logits = model(images, labels)
                loss = self.criterion(logits, labels)

                if train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                epoch_loss += loss.item()
                epoch_acc += self.accuracy(logits, labels)

                pbar.set_postfix({'loss': f"{epoch_loss / (t + 1):.6f}", 'acc': f"{epoch_acc / (t + 1):.6f}"})

        return epoch_loss / len(loader), epoch_acc / len(loader)

    def train_batch_loop(self, model, trainloader, i):
        """Run training epoch *i*; return (mean loss, mean accuracy)."""
        return self._run_epoch(model, trainloader, i, "TRAIN", train=True)

    def valid_batch_loop(self, model, validloader, i):
        """Run validation epoch *i* without gradients; return (mean loss, mean accuracy)."""
        return self._run_epoch(model, validloader, i, "VALID", train=False)

    def run(self, model, trainset, validset, epochs):
        """Train for *epochs* epochs, validating after each, and return *model*.

        Per-epoch mean losses and accuracies are recorded in ``self.history``
        as a list of dicts (one per epoch).
        """
        trainloader, validloader = self.get_dataloader(trainset, validset)
        self.history = []

        for i in range(epochs):
            model.train()
            train_loss, train_acc = self.train_batch_loop(model, trainloader, i)

            model.eval()
            valid_loss, valid_acc = self.valid_batch_loop(model, validloader, i)

            self.history.append({
                'train_loss': train_loss,
                'train_acc': train_acc,
                'valid_loss': valid_loss,
                'valid_acc': valid_acc,
            })

        return model

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.0001)

# Pass `device` by keyword: TrainModel's third positional parameter is the
# scheduler, so `TrainModel(criterion, optimizer, device)` stored the device
# as the scheduler and left `TrainModel.device` as None.
# NOTE: the MNIST test split doubles as the validation set here.
model = TrainModel(criterion, optimizer, device=device).run(model, trainset, testset, 20)

## Image Embeddings

In [None]:
# Collect the 3-D embeddings (model called without labels) and the digit
# labels for the whole test split.
emb = []
y = []

testloader = DataLoader(testset, batch_size = 64)
# Switch to inference mode explicitly (disables Dropout2d) instead of relying
# on the mode the training loop happened to leave the model in.
model.eval()
with torch.no_grad():
    for images, labels in tqdm(testloader):
        images = images.to(device)
        embeddings = model(images)

        emb.append(embeddings.cpu())
        y.append(labels)

    embs = torch.cat(emb).numpy()
    y = torch.cat(y).numpy()

In [None]:
# Plot the model's raw 3-D embeddings directly. No t-SNE is applied; the
# `tsne_df` name is kept only so that any later cells still find it.
tsne_df = pd.DataFrame(embs, columns=["x", "y", "z"])
# Digit labels are categorical. Plot them as strings so plotly assigns one
# discrete colour per digit instead of a continuous colour bar.
tsne_df["targets"] = y.astype(str)

fig = px.scatter_3d(
    tsne_df, x='x', y='y', z='z',
    color='targets',
    category_orders={'targets': sorted(tsne_df['targets'].unique(), key=int)},
)
fig.show()