In [1]:
import os
import cv2
import math
import time
import tarfile
import numbers
import threading
import queue as Queue
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_url
#from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split, DataLoader, Dataset
from torchsummary import summary

In [2]:
class CosFace(torch.nn.Module):
    def __init__(self, s=64.0, m=0.40):
        super(CosFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, logits, labels):
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]
        final_target_logit = target_logit - self.m
        logits[index, labels[index].view(-1)] = final_target_logit
        logits = logits * self.s
        return logits

In [3]:
loss = CosFace()

In [39]:
logits = torch.randn(3, 5)
labels = torch.empty(3, dtype=torch.long).random_(5)

In [40]:
logits.dtype, labels.dtype

(torch.float32, torch.int64)

In [41]:
output = loss(logits, labels) #input: logits, target: labels

In [42]:
print(logits.shape, labels.shape, output.shape)

torch.Size([3, 5]) torch.Size([3]) torch.Size([3, 5])


In [43]:
print('input: ', logits, '\n target: ', labels, '\n output: ', output)

input:  tensor([[ 1.0348, -0.0898,  0.1454, -0.3221,  1.9344],
        [-0.2144,  0.1993,  0.0426,  1.5271,  0.1213],
        [-1.3730,  0.6963, -2.3733,  0.1223, -0.5391]]) 
 target:  tensor([2, 0, 3]) 
 output:  tensor([[  66.2279,   -5.7441,    9.3073,  -20.6170,  123.7996],
        [ -13.7212,   12.7547,    2.7239,   97.7355,    7.7600],
        [ -87.8722,   44.5626, -151.8880,    7.8291,  -34.5045]])


In [59]:
index = torch.where(labels != -1)[0]
index

tensor([0, 1, 2])

In [45]:
labels[index].view(-1)

tensor([2, 0, 3])

In [46]:
target_logit = logits[index, labels[index].view(-1)]
target_logit

tensor([ 0.1454, -0.2144,  0.1223])

In [47]:
final_target_logit = target_logit - 1
final_target_logit

tensor([-0.8546, -1.2144, -0.8777])

In [48]:
logits[index, labels[index].view(-1)] = final_target_logit
logits

tensor([[ 1.0348, -0.0898, -0.8546, -0.3221,  1.9344],
        [-1.2144,  0.1993,  0.0426,  1.5271,  0.1213],
        [-1.3730,  0.6963, -2.3733, -0.8777, -0.5391]])

In [49]:
logits = logits * 100
logits

tensor([[ 103.4811,   -8.9752,  -85.4573,  -32.2141,  193.4370],
        [-121.4394,   19.9293,    4.2560,  152.7117,   12.1250],
        [-137.3002,   69.6291, -237.3250,  -87.7671,  -53.9132]])

In [139]:
class CosFace(nn.Module):
    def __init__(self, in_features, out_features, s=64.0, m=0.35):
        super(CosFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        
        self.kernel = nn.Parameter(torch.FloatTensor(in_features, out_features))
        nn.init.normal_(self.kernel, std=0.01)

    def forward(self, logits, labels):
        logits = F.normalize(logits, p=2.0, dim=1)
        kernel_norm = F.normalize(self.kernel, p=2.0, dim=0)
        cos_theta = torch.mm(logits, kernel_norm)
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
        index = torch.where(labels != -1)[0]
        m_hot = torch.zeros(index.size()[0], cos_theta.size()[1])
        m_hot.scatter_(1, labels[index, None], self.m)
        cos_theta[index] -= m_hot
        ret = cos_theta * self.s
        return ret

In [119]:
torch.linalg.norm(logits, dim=1, ord = 2, keepdim=True)

tensor([[2.7295],
        [2.4312],
        [1.2166],
        [2.8466]])

In [143]:
F.normalize(logits, p=2.0, dim=0)

tensor([[-0.3729, -0.4269,  0.0690,  0.0250,  0.6966],
        [-0.8660, -0.5873,  0.3917, -0.8840, -0.3064],
        [ 0.1669,  0.4848, -0.5942, -0.3404, -0.6479],
        [ 0.2885, -0.4877, -0.6991, -0.3194, -0.0343]])

In [144]:
logits = torch.randn(4, 5)
labels = torch.empty(4, dtype=torch.long).random_(5)

In [133]:
logits.dtype, labels.dtype

(torch.float32, torch.float32)

In [71]:
print(logits.shape, labels.shape)#, output.shape)

torch.Size([4, 5]) torch.Size([4])


In [72]:
print('logits: ', logits, '\n labels: ', labels)#, '\n output: ', output)

input:  tensor([[ 1.1558,  1.0874, -0.1887, -1.0192,  1.0584],
        [ 0.7359, -0.2864,  0.6532,  1.9240,  0.0953],
        [ 1.7354, -1.7874,  0.2236,  1.4018,  0.8048],
        [ 0.5327, -0.6309, -1.7325,  0.9621, -0.9335]]) 
 target:  tensor([2, 0, 4, 4])


In [103]:
in_features = 5
out_features = 5
s = 10
m = 5
        
kernel = nn.Parameter(torch.FloatTensor(in_features, out_features))

In [77]:
kernel.shape

torch.Size([5, 5])

In [83]:
cos_theta = torch.mm(logits, kernel)
cos_theta, cos_theta.shape

(tensor([[-2.1417e+28, -3.4295e+30, -1.3719e+31, -9.5810e+30,  1.3345e+23],
         [ 4.5702e+28,  1.1872e+31,  4.7493e+31,  3.3168e+31, -1.5254e+22],
         [ 2.8820e+28,  4.0639e+30,  1.6257e+31,  1.1353e+31, -8.4016e+22],
         [-1.3693e+28, -3.1489e+31, -1.2597e+32, -8.7972e+31, -9.4421e+22]],
        grad_fn=<MmBackward0>),
 torch.Size([4, 5]))

In [84]:
cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
cos_theta, cos_theta.shape

(tensor([[-1., -1., -1., -1.,  1.],
         [ 1.,  1.,  1.,  1., -1.],
         [ 1.,  1.,  1.,  1., -1.],
         [-1., -1., -1., -1., -1.]], grad_fn=<ClampBackward1>),
 torch.Size([4, 5]))

In [86]:
index = torch.where(labels != -1)[0]
index, index.shape

(tensor([0, 1, 2, 3]), torch.Size([4]))

In [92]:
index.shape[0], cos_theta.size()[1]

(4, 5)

In [107]:
m_hot = torch.zeros(index.size()[0], cos_theta.size()[1])
m_hot.shape, m_hot

(torch.Size([4, 5]),
 tensor([[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]))

In [108]:
labels[index, None]

tensor([[2],
        [0],
        [4],
        [4]])

In [109]:
m_hot.scatter_(1, labels[index, None], m)

tensor([[0., 0., 5., 0., 0.],
        [5., 0., 0., 0., 0.],
        [0., 0., 0., 0., 5.],
        [0., 0., 0., 0., 5.]])

In [110]:
cos_theta[index] -= m_hot
cos_theta

tensor([[-1., -1., -6., -1.,  1.],
        [-4.,  1.,  1.,  1., -1.],
        [ 1.,  1.,  1.,  1., -6.],
        [-1., -1., -1., -1., -6.]], grad_fn=<IndexPutBackward0>)

In [112]:
ret = cos_theta * s
ret

tensor([[-10., -10., -60., -10.,  10.],
        [-40.,  10.,  10.,  10., -10.],
        [ 10.,  10.,  10.,  10., -60.],
        [-10., -10., -10., -10., -60.]], grad_fn=<MulBackward0>)

In [141]:
loss = CosFace(5, 5)

In [145]:
output = loss(logits, labels) #input: logits, target: labels
output.shape, output

(torch.Size([4, 5]),
 tensor([[-18.5394,  -3.7403, -18.0750, -75.7024,  12.6182],
         [-73.9478,  28.4022, -10.4667, -25.2501,  20.9953],
         [-34.6950, -35.8199, -26.9825, -23.3158, -35.8693],
         [ 48.2260, -74.4909,  11.7007,  20.4957,   7.0936]],
        grad_fn=<MulBackward0>))