## Imports

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

## ArcFace class 

In [170]:
class ArcFace(nn.Module):
    """
    ArcFace Layer for face recognition.
    This layer is used for face recognition tasks. It computes the class logits using the ArcFace algorithm.

    Attributes:
        emb_size (int): The embedding size (the number of features extracted from CNN).
        num_classes (int): The number of classes.
        s (float): The radius of the projected hypersphere.
        m (float): The arc margin in radians.
    """
    def __init__(self, emb_size, num_classes, s=64.0, m=0.50):
        """
        Constructor for ArcFace Layer.

        Parameters:
            emb_size (int): The embedding size (the number of features extracted from CNN).
            num_classes (int): The number of classes.
            s (float): The radius of the projected hypersphere. Default is 64.0.
            m (float): The arc margin in radians. Default is 0.50.
        """
        # inherit from base class
        super(ArcFace, self).__init__()
        # save configuration
        self.num_classes = num_classes
        self.emb_size = emb_size
        self.s = s
        self.m = m
        # formulate weight tensor
        self.weights = Parameter(torch.FloatTensor(num_classes, emb_size))
        # initialize weights
        nn.init.xavier_uniform_(self.weights)

    def forward(self, embedding, gt):
        """
        Computes the forward pass using the Arcface loss and returns the class logits.

        Parameters:
            embedding (torch.Tensor): Extracted embeddings.
            gt (torch.Tensor): Ground truth labels.

        Returns:
            torch.Tensor: Computed class logits through ArcFace Algorithm.
        """
        print(f"batch dim = {embedding.shape[0]}")
        print(f"Embedding size = {self.emb_size}")
        print(f"Shape of embedding {embedding.shape}")
        print(f"No of classes = {self.num_classes}")
        print(f"shape of weights {self.weights.shape}")
        
        fc7 = F.linear(F.normalize(embedding, dim=1), F.normalize(self.weights, dim=1), bias=None)
        print(f"Shape of logits = {fc7.shape}")
        # pick logit at class index
        one_hot = F.one_hot(gt, self.num_classes)
        print(f"One hot encoded labels are {one_hot}")
        print(f"Logits of all classes {fc7}")
        original_target_logit = fc7[one_hot > 0]
        print(f"original target logit = {original_target_logit}")
        # theta
        eps = 1e-10
        # clip logits to prevent zero division when backward
        theta = torch.acos(torch.clamp(original_target_logit, -1.0 + eps, 1.0 - eps))
        # marginal_target_logit
        marginal_target_logit = torch.cos(theta + self.m)
        
        # update fc7
        diff = marginal_target_logit - original_target_logit
        print(f"diff in raw form is {diff}")
        print("diff after unsqueezing is {}".format(torch.unsqueeze(diff, dim=1)))
        print(f"fc7 originally is {fc7}")
        print("Multiplication of one hot encoding with diff: {}".format(torch.mul(one_hot, torch.unsqueeze(diff, dim=1))))
        fc7 = fc7 + torch.mul(one_hot, torch.unsqueeze(diff, dim=1))
        print(f"fc7 after inclusion of margin in the ground truth logits \n: {fc7}")
        # scaling
        fc7 *= self.s
        return fc7
        
    def get_weights(self):
        """
        Returns a deep copy of the weights which serve as class centroids.

        Returns:
            torch.Tensor: Deep copy of the weights.
        """
        return self.weights.clone()


### Create the embedding matrix for a batch dimension of 3

In [171]:
batch_dim = 2
num_classes = 3
emb_size = 5
embedding_matrix = torch.rand(batch_dim, emb_size)
print(embedding_matrix)

tensor([[0.0986, 0.9394, 0.7522, 0.0020, 0.7245],
        [0.8452, 0.2748, 0.1371, 0.2975, 0.5863]])


### Create the corresponding ground truth label vector
##### 3-class problem [0, 1, 2]

In [172]:

ground_truth_vector = torch.tensor([1, 2])
print(F.one_hot(ground_truth_vector))
#ground_truth_vector = torch.tensor([[0, 1, 2]])
print(ground_truth_vector)
print(f"Shape of ground truth vector is {ground_truth_vector.shape}")

tensor([[0, 1, 0],
        [0, 0, 1]])
tensor([1, 2])
Shape of ground truth vector is torch.Size([2])


In [173]:

arcface_loss = ArcFace(emb_size, num_classes, s=64.0, m=0.50)
arcface_loss.forward(embedding_matrix, ground_truth_vector)

batch dim = 2
Embedding size = 5
Shape of embedding torch.Size([2, 5])
No of classes = 3
shape of weights torch.Size([3, 5])
Shape of logits = torch.Size([2, 3])
One hot encoded labels are tensor([[0, 1, 0],
        [0, 0, 1]])
Logits of all classes tensor([[ 0.1924,  0.6971,  0.3102],
        [ 0.2836,  0.5013, -0.3012]], grad_fn=<MmBackward0>)
original target logit = tensor([ 0.6971, -0.3012], grad_fn=<IndexBackward0>)
diff in raw form is tensor([-0.4291, -0.4203], grad_fn=<SubBackward0>)
diff after unsqueezing is tensor([[-0.4291],
        [-0.4203]], grad_fn=<UnsqueezeBackward0>)
fc7 originally is tensor([[ 0.1924,  0.6971,  0.3102],
        [ 0.2836,  0.5013, -0.3012]], grad_fn=<MmBackward0>)
Multiplication of one hot encoding with diff: tensor([[-0.0000, -0.4291, -0.0000],
        [-0.0000, -0.0000, -0.4203]], grad_fn=<MulBackward0>)
fc7 after inclusion of margin in the ground truth logits 
: tensor([[ 0.1924,  0.2680,  0.3102],
        [ 0.2836,  0.5013, -0.7215]], grad_fn=<AddB

tensor([[ 12.3137,  17.1532,  19.8498],
        [ 18.1530,  32.0855, -46.1732]], grad_fn=<MulBackward0>)