In [175]:
import torch
import math
from torch.nn import CrossEntropyLoss


def l2_norm(input, axis=-1):
    """Divide ``input`` by its L2 norm taken along ``axis``.

    Args:
        input: tensor to normalise.
        axis: dimension reduced by the norm (default: the last one).

    Returns:
        A tensor shaped like ``input`` whose slices along ``axis`` have unit
        L2 norm. An all-zero slice divides by zero and yields NaN.
    """
    magnitude = input.norm(p=2, dim=axis, keepdim=True)
    return input / magnitude
class AdaFaceWAct(torch.nn.Module):
    ''' 
    AdaFace classification head with the weight matrix and the margin fused.

    1. Multiply embeddings with W (W phase)
    2. Compute Adaface Activate (like normalized softmax) (Act phase)

    ``forward`` returns scaled logits of shape (batch, classnum); the loss
    (e.g. cross-entropy) is applied by the caller.
    '''
    def __init__(self,
                 embedding_size=512,
                 classnum=70722,
                 m=0.4,
                 h=0.333,
                 s=64.,
                 t_alpha=1.0,
                 ):
        """
        Args:
            embedding_size: dimensionality of the input embeddings (rows of W).
            classnum: number of classes (columns of W).
            m: base margin.
            h: multiplier on the standardised feature norm before it is
                clipped to [-1, 1].
            s: logit scale.
            t_alpha: EMA weight given to the current batch's norm statistics
                (1.0 = use only the current batch).
        """
        super(AdaFaceWAct, self).__init__()
        self.classnum = classnum
        self.kernel = torch.nn.Parameter(torch.Tensor(embedding_size,classnum))

        # initial kernel: uniform init, then renorm each column (dim=1) to
        # norm <= 1e-5 and scale by 1e5, i.e. columns end up with norm <= 1
        self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
        self.m = m 
        self.eps = 1e-3
        self.h = h
        self.s = s

        # ema prep
        self.t_alpha = t_alpha
        # 't' is not read by forward(); kept so state_dict keys stay unchanged.
        self.register_buffer('t', torch.zeros(1))
        self.register_buffer('batch_mean', torch.ones(1)*(20))
        self.register_buffer('batch_std', torch.ones(1)*100)

        # '\\A' is an explicit backslash: the former '\A' was an invalid escape
        # sequence (SyntaxWarning on Python 3.12+) that printed the same text.
        print('\n\\AdaFaceWAct with the following property')
        print('self.m', self.m)
        print('self.h', self.h)
        print('self.s', self.s)
        print('self.t_alpha', self.t_alpha)

    def forward(self, embbedings, norms, label):
        """Compute AdaFace-adjusted, scaled logits.

        Args:
            embbedings: (B, embedding_size) embeddings. Presumably already
                unit-length (the notebook normalises them) so the product with
                the normalised kernel is a cosine. Confirm this for other callers.
            norms: per-sample feature norms, shape (B, 1) or (B,).
            label: (B,) integer target class per sample.

        Returns:
            (B, classnum) tensor of margin-adjusted cosines multiplied by ``s``.
        """
        kernel_norm = l2_norm(self.kernel,axis=0)
        cosine = torch.mm(embbedings,kernel_norm)
        cosine = cosine.clamp(-1+self.eps, 1-self.eps) # for stability (keeps acos finite)

        safe_norms = torch.clip(norms, min=0.001, max=100).detach() # for stability

        # update batchmean batchstd (EMA, no gradient)
        with torch.no_grad():
            mean = safe_norms.mean()
            self.batch_mean = mean * self.t_alpha + (1 - self.t_alpha) * self.batch_mean
            # torch.std of a single value is NaN; keep the previous estimate
            # instead of poisoning batch_std (and every later logit) with NaN.
            if safe_norms.numel() > 1:
                std = safe_norms.std()
                self.batch_std = std * self.t_alpha + (1 - self.t_alpha) * self.batch_std

        # If the norms are roughly Gaussian, ~68% of samples land in [-1, 1] here.
        margin_scaler = (safe_norms - self.batch_mean) / (self.batch_std+self.eps)
        margin_scaler = margin_scaler * self.h # ~68% between -0.333, 0.333 when h:0.333
        # One scaler per row, shaped (B, 1) so it broadcasts over the class axis
        # whether norms arrived as (B, 1) or (B,).
        margin_scaler = torch.clip(margin_scaler, -1, 1).reshape(-1, 1)
        # ex: m=0.5, h:0.333
        # range
        #       (68% range)
        #   -1 -0.333  0.333   1  (margin_scaler)
        # -0.5 -0.166  0.166 0.5  (m * margin_scaler)

        # One-hot mask of each row's target class, shared by both margins.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.reshape(-1, 1), 1.0)

        # g_angular: shift only the target angle by -m * margin_scaler
        g_angular = self.m * margin_scaler * -1
        theta = cosine.acos()
        theta_m = torch.clip(theta + one_hot * g_angular, min=self.eps, max=math.pi-self.eps)
        cosine = theta_m.cos()

        # g_additive: subtract m + m * margin_scaler from the target cosine only
        g_add = self.m + (self.m * margin_scaler)
        cosine = cosine - one_hot * g_add

        # scale
        scaled_cosine_m = cosine * self.s
        return scaled_cosine_m

# Mean-reduced softmax cross-entropy used to score the AdaFace logits below.
cross_entropy_loss = CrossEntropyLoss(reduction='mean')


In [176]:
# Toy-sized fused head: 10 classes instead of the 70722 default.
adaface_w_act = AdaFaceWAct(
    embedding_size=512,
    classnum=10,
    m=0.4,
    h=0.333,
    s=64.,
    t_alpha=1.0,
)


\AdaFaceWAct with the following property
self.m 0.4
self.h 0.333
self.s 64.0
self.t_alpha 1.0


In [177]:
# Dummy input/labels.
# Seed first so the random dummy batch is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

embs = torch.randn(2,512)
# Pre-normalisation norms, shape (2, 1); passed below as the `norms` argument
# of the AdaFace heads, which derive their margin scaling from them.
norm = torch.norm(embs,2,dim=-1,keepdim=True)
embs = torch.div(embs, norm)  # unit-length embeddings
labels = torch.tensor([5,9])  # must be < classnum (10) of the heads built in this notebook
print("embs.shape: ",embs.shape) 
print("norm.shape: ",norm.shape) 
print("labels.shape: ",labels.shape) 

embs.shape:  torch.Size([2, 512])
norm.shape:  torch.Size([2, 1])
labels.shape:  torch.Size([2])


In [178]:
# Forward the dummy batch through the fused W + AdaFace head.
logits = adaface_w_act(embbedings=embs, norms=norm, label=labels)
logits.shape

torch.Size([2, 10])

In [181]:
print(labels.shape)
# Plain cross-entropy on the margin-adjusted, scaled logits.
loss_train = cross_entropy_loss(input=logits, target=labels)
print(loss_train)

torch.Size([2])
tensor(25.2385, grad_fn=<NllLossBackward>)


## Separate W (FC layer) and Act (AdaFace margin) into two modules

In [188]:
from typing import Callable

# NOTE(review): identical to the earlier `l2_norm`; this redefinition shadows
# it. Consider keeping a single copy.
def l2_norm(input, axis=-1):
    """Divide ``input`` by its L2 norm taken along ``axis``.

    Args:
        input: tensor to normalise.
        axis: dimension reduced by the norm (default: the last one).

    Returns:
        A tensor shaped like ``input`` whose slices along ``axis`` have unit
        L2 norm. An all-zero slice divides by zero and yields NaN.
    """
    magnitude = input.norm(p=2, dim=axis, keepdim=True)
    return input / magnitude
    
class AdaFC(torch.nn.Module):
    ''' 
    Cosine FC layer (W) followed by a pluggable margin head (Act) and
    cross-entropy.

    1. Multiply embeddings with the column-normalised W (FC phase)
    2. Apply ``margin_loss`` to the clamped cosine logits (Act phase,
       e.g. ``AdaAct``)
    3. Return the cross-entropy loss of the adjusted logits.
    '''
    def __init__(self,
                 margin_loss: Callable,
                 embedding_size=512,
                 classnum=70722,
                 ):
        """
        Args:
            margin_loss: callable invoked as ``margin_loss(logits, norms, labels)``
                and expected to return the adjusted logits.
            embedding_size: dimensionality of the input embeddings (rows of W).
            classnum: number of classes (columns of W).

        Raises:
            TypeError: if ``margin_loss`` is not callable.
        """
        super(AdaFC, self).__init__()
        self.classnum = classnum
        self.kernel = torch.nn.Parameter(torch.Tensor(embedding_size,classnum))
        self.dist_cross_entropy = CrossEntropyLoss()

        # margin_loss: a bare `raise` here used to fail with an unrelated
        # "No active exception to reraise" RuntimeError.
        if not callable(margin_loss):
            raise TypeError(
                "margin_loss must be callable, got %s" % type(margin_loss).__name__)
        self.margin_softmax = margin_loss

        # initial kernel: columns renormed to L2 norm <= 1
        self.kernel.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
        self.eps = 1e-3

    def forward(self, embbedings, norms, label):
        """Return the cross-entropy loss of the margin-adjusted logits.

        Args:
            embbedings: (B, embedding_size) embeddings. Presumably unit-length
                already (the notebook normalises them). Confirm for other callers.
            norms: per-sample feature norms, forwarded to ``margin_loss``.
            label: (B,) integer target class per sample.

        Returns:
            Scalar loss tensor (``CrossEntropyLoss`` default mean reduction).
        """
        kernel_norm = l2_norm(self.kernel,axis=0)
        logits = torch.mm(embbedings,kernel_norm)
        logits = logits.clamp(-1+self.eps, 1-self.eps) # for stability

        # Use the `label` argument; the previous code read a notebook-global
        # `labels` and only worked while that global happened to exist.
        logits = self.margin_softmax(logits, norms, label)
        loss = self.dist_cross_entropy(logits, label)
        return loss

class AdaAct(torch.nn.Module):
    ''' 
    AdaFace margin head ("Act" phase) operating on precomputed cosine logits.

    Given cosine logits (e.g. from ``AdaFC``), per-sample feature norms and
    labels, apply the norm-adaptive angular margin (g_angular) and additive
    margin (g_add) to each sample's target logit, then scale everything by
    ``s``. Rows whose label is -1 get no margin (only clamping and scaling).
    '''
    def __init__(self,
                 m=0.4,
                 h=0.333,
                 s=64.,
                 t_alpha=1.0,
                 ):
        """
        Args:
            m: base margin.
            h: multiplier on the standardised feature norm before it is
                clipped to [-1, 1].
            s: logit scale.
            t_alpha: EMA weight given to the current batch's norm statistics
                (1.0 = use only the current batch).
        """
        super(AdaAct, self).__init__()
        self.m = m 
        self.eps = 1e-3
        self.h = h
        self.s = s
        # Neither `theta` nor `easy_margin` is read by forward(); kept so
        # existing attribute access keeps working.
        self.theta = math.cos(math.pi - m)
        self.easy_margin = False

        # ema prep
        self.t_alpha = t_alpha
        self.register_buffer('batch_mean_z', torch.ones(1)*(20))
        self.register_buffer('batch_std_z', torch.ones(1)*100)

        # Report this class's own name (it used to print "AdaFaceWAct"); '\\A'
        # replaces the invalid '\A' escape sequence.
        print('\n\\AdaAct with the following property')
        print('self.m', self.m)
        print('self.h', self.h)
        print('self.s', self.s)
        print('self.t_alpha', self.t_alpha)

    def forward(self, logits:torch.Tensor, norms:torch.Tensor, labels:torch.Tensor):
        """Apply the AdaFace margins to the target logits and scale.

        Args:
            logits: (B, C) cosine logits.
            norms: per-sample feature norms, shape (B, 1) or (B,).
            labels: (B,) target class per sample; -1 marks a row to skip.

        Returns:
            (B, C) tensor of margin-adjusted logits multiplied by ``s``. The
            caller's ``logits`` tensor is not modified, because the clamp
            below produces a new tensor before the in-place write.
        """
        logits = logits.clamp(-1+self.eps, 1-self.eps) # for stability

        safe_norms = torch.clip(norms, min=0.001, max=100).detach() # for stability

        # update batchmean batchstd (EMA, no gradient)
        with torch.no_grad():
            mean_z = safe_norms.mean()
            self.batch_mean_z = mean_z * self.t_alpha + (1 - self.t_alpha) * self.batch_mean_z
            # torch.std of a single value is NaN; keep the previous estimate.
            if safe_norms.numel() > 1:
                std_z = safe_norms.std()
                self.batch_std_z = std_z * self.t_alpha + (1 - self.t_alpha) * self.batch_std_z

        # Standardised, h-scaled and clipped norm: one value per row, shape (B,).
        z = (safe_norms - self.batch_mean_z) / (self.batch_std_z+self.eps)
        z = torch.clip(z * self.h, -1, 1).reshape(-1)

        # Keep only rows with a real label. Every per-row quantity below must
        # be restricted to the same subset, otherwise shapes mismatch as soon
        # as any label is -1.
        index = torch.where(labels != -1)[0]
        target_cols = labels[index].view(-1)
        target_logits = logits[index, target_cols]
        z_t = z[index]

        # g_angular: shift the target angle by -m * z
        g_angular = - self.m * z_t
        theta = target_logits.acos()
        theta_m = torch.clip(theta + g_angular, min=self.eps, max=math.pi-self.eps)
        target_logits_angular = theta_m.cos()

        # g_additive: subtract m + m * z from the target cosine
        g_add = self.m + (self.m * z_t)
        target_logits_add = target_logits_angular - g_add

        # Fallback when theta + g_angular <= 0 (this is not ArcFace's
        # easy_margin): shift the raw target cosine by gap_. At the boundary
        # theta = m*z this equals 1 - m - m*z, matching target_logits_add
        # (up to the eps clip), so the two branches join continuously.
        gap_ = 1 - self.m*z_t - self.m - (self.m*z_t).cos()

        final_target_logits = torch.where(theta + g_angular > 0, target_logits_add, target_logits+gap_)

        logits[index, target_cols] = final_target_logits
        logits = logits * self.s
        return logits


In [189]:
# The margin head (Act) is built first and handed to the FC layer (W),
# which owns the kernel and the cross-entropy.
ada_act = AdaAct(m=0.4, h=0.333, s=64.0, t_alpha=1.0)
ada_fc = AdaFC(
    margin_loss=ada_act,
    embedding_size=512,
    classnum=10,
)


\AdaFaceWAct with the following property
self.m 0.4
self.h 0.333
self.s 64.0
self.t_alpha 1.0


In [190]:
# End-to-end loss from the separated W (AdaFC) + Act (AdaAct) modules.
loss = ada_fc(embbedings=embs, norms=norm, label=labels)
loss

tensor(24.5665, grad_fn=<NllLossBackward>)