In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import math

In [None]:
class ArcFace(nn.Module):
    """ArcFace (additive angular margin) classification head.

    Produces scaled cosine logits ``s * cos(theta_j)`` for every class ``j`` and
    replaces the target-class logit with ``s * cos(theta_y + m)``, so that the
    cross-entropy loss pushes embeddings of the same class closer together on the
    unit hypersphere.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False, eps=1e-7):
        """
        Args:
            in_features: Size of the input embedding vector.
            out_features: Number of classes.
            s: Scale factor applied to the cosine logits.
            m: Additive angular margin (radians) between classes.
            easy_margin: Only apply the margin where cos(theta) > 0. Use it only if
                the standard formulation becomes unstable.
            eps: Lower bound for ``1 - cos^2(theta)`` before the square root, so the
                gradient stays finite when an embedding is exactly aligned with a
                class centre.
        """
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.eps = eps
        # One learnable class-centre vector per class (one row per class).
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        # Pre-compute cos(m) and sin(m) once instead of on every forward pass.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # cos(theta) > th  <=>  theta < pi - m, i.e. theta + m stays within [0, pi]
        # where cos(theta + m) is monotonic; beyond that a linear penalty
        # cos(theta) - mm is used instead.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Compute ArcFace logits.

        Args:
            input: Embedding tensor of shape [batch_size, in_features].
            label: Integer class labels of shape [batch_size].

        Returns:
            Raw (pre-softmax) logits of shape [batch_size, out_features], intended to
            be passed to ``nn.CrossEntropyLoss``.
        """
        # Cosine similarity between L2-normalised embeddings and class centres.
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # sin(theta) from cos(theta). The clamp keeps sqrt away from 0 (infinite
        # gradient) and from negative values when |cos| rounds slightly above 1.
        sine = torch.sqrt(torch.clamp(1.0 - cosine.pow(2), min=self.eps))

        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m

        if self.easy_margin:
            # Only apply the margin where the angle is below 90 degrees.
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # Use cos(theta + m) while theta + m <= pi, otherwise the linear fallback.
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # One-hot mask of the target class; scatter_ requires int64 indices.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1.0)

        # Margin-adjusted logit for the target class, plain cosine for all others.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s

In [None]:
from torchvision import models

class ResNetArcModel(nn.Module):
    """ResNet backbone + linear embedding projection + ArcFace classification head."""

    def __init__(self, num_classes, backbone="resnet50", embedding_size=512, pretrained=True):
        """
        Wraps a torchvision ResNet and replaces its final layer with an embedding
        projection followed by an ArcFace head.

        Args:
            num_classes: Number of output classes (identities) used for training.
            backbone: Name of a torchvision ResNet constructor, e.g. 'resnet18' or 'resnet50'.
            embedding_size: Output dimension of the embedding before classification.
            pretrained: If True, load torchvision's default ImageNet weights for the
                backbone; if False, initialise it randomly (useful when weights are
                restored from a checkpoint right afterwards, or offline).
        """
        super(ResNetArcModel, self).__init__()

        # ``weights="DEFAULT"`` selects the best available ImageNet weights
        # (torchvision >= 0.13 API); ``None`` means random initialisation.
        resnet = getattr(models, backbone)(weights="DEFAULT" if pretrained else None)
        in_features = resnet.fc.in_features
        resnet.fc = nn.Identity()  # drop the ImageNet classifier, keep pooled features

        self.backbone = resnet

        # Project backbone features to a lower-dimensional embedding.
        self.embedding = nn.Linear(in_features, embedding_size)

        # ArcFace classification head (only used when labels are given).
        self.arcface = ArcFace(embedding_size, num_classes)

    def forward(self, x, labels=None):
        """
        Forward pass through the full model.

        Args:
            x: Input image tensor [B, C, H, W].
            labels: Target labels [B]; required for training.

        Returns:
            If labels are provided: ArcFace logits [B, num_classes] (training).
            Otherwise: raw, unnormalised embeddings [B, embedding_size] (inference).
        """
        features = self.backbone(x)            # pooled ResNet features
        embeddings = self.embedding(features)  # project to embedding space

        if labels is not None:
            return self.arcface(embeddings, labels)  # logits with angular margin

        return embeddings  # inference mode

EXAMPLE: TRAINING THE MODEL

In [None]:
# Build the model to train (100 identities in this example) and move it to the GPU.
model = ResNetArcModel(num_classes=100)
model = model.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

NUM_EPOCHS = 1  # increase for real training

# `train_loader` must be a DataLoader yielding (images, labels) batches, defined beforehand.
model.train()
for current_epoch in range(NUM_EPOCHS):
    for images, labels in train_loader:
        images = images.cuda()
        labels = labels.cuda()

        # Forward pass: labels are required so ArcFace applies the margin to the target class.
        logits = model(images, labels)
        loss = criterion(logits, labels)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

TO SAVE A TRAINING CHECKPOINT

In [None]:
# Bundle the model weights, optimizer state and epoch counter into a single checkpoint
# file so training can be resumed later. `current_epoch` must already be defined by
# the training loop before this cell runs.
torch.save(
    dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        epoch=current_epoch,
    ),
    'arcface_model.pth',
)

TO LOAD THE MODEL FROM THE CHECKPOINT (INFERENCE)

In [None]:
# The save cell writes a checkpoint dict to 'arcface_model.pth'; the model weights live
# under its 'model_state_dict' key. map_location='cpu' lets the file load on any machine
# before the model is moved to the GPU.
checkpoint = torch.load('arcface_model.pth', map_location='cpu')
model = ResNetArcModel(num_classes=100, embedding_size=512)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.cuda()
model.eval()

TO COMPUTE AND SAVE FACE EMBEDDINGS (ENROLLMENT)

In [None]:
import cv2
import numpy as np
import torch
from torchvision import transforms
from torch.nn.functional import normalize
from face_recognition import face_locations
# ResNetArcModel is the model class defined earlier in this notebook. When running this
# code as a standalone script, import it from the module where you saved it instead.

# 1. Load the trained model from the checkpoint dict written by the save cell
#    (weights are stored under its 'model_state_dict' key).
checkpoint = torch.load('arcface_model.pth', map_location='cpu')
model = ResNetArcModel(num_classes=100, embedding_size=512)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.cuda()
model.eval()

# 2. Define transform (must match training time!). ToPILImage interprets a 3-channel
#    uint8 array as RGB; Normalize maps [0, 1] pixel values to [-1, 1].
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# 3. Function to extract a face from an image file and compute its embedding
def extract_embedding_from_image(image_path):
    """Detect the first face in an image file and return its L2-normalised embedding.

    Args:
        image_path: Path to an image file readable by OpenCV.

    Returns:
        1-D numpy array holding the model's embedding (512 values for
        embedding_size=512), scaled to unit L2 norm so that dot products between
        embeddings are cosine similarities.

    Raises:
        FileNotFoundError: If OpenCV cannot read the image.
        ValueError: If no face is detected in the image.
    """
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread signals failure by returning None, not by raising
        raise FileNotFoundError(f"Could not read image: {image_path}")
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Detect faces; each box is (top, right, bottom, left).
    face_coords = face_locations(rgb)
    if not face_coords:
        raise ValueError("No face found in image.")

    # Only the first detected face is used.
    top, right, bottom, left = face_coords[0]
    # Crop from the RGB array: `transform` (ToPILImage) treats 3-channel arrays as RGB,
    # so cropping the BGR image would feed the model swapped colour channels.
    face = rgb[top:bottom, left:right]

    # Transform and get embedding (no labels -> the model returns raw embeddings).
    face_tensor = transform(face).unsqueeze(0).cuda()
    with torch.no_grad():
        embedding = model(face_tensor)
        embedding = normalize(embedding, dim=1)  # unit length for cosine comparison
    return embedding.squeeze(0).cpu().numpy()


EXAMPLE: ENROLLING A PERSON

In [None]:
# Enroll Alice: compute her reference embedding once and store it on disk so the
# real-time recognition cell can load it into its reference database.
alice_embedding = extract_embedding_from_image(image_path='alice.jpg')
np.save(file='alice_embedding.npy', arr=alice_embedding)

PRACTICAL USE: REAL-TIME RECOGNITION FROM A WEBCAM

In [None]:
import cv2
import torch
import numpy as np
from torchvision import transforms
from torch.nn.functional import normalize
from face_recognition import face_locations  # uses dlib internally
# ResNetArcModel is the model class defined earlier in this notebook. When running this
# code as a standalone script, import it from the module where you saved it instead.

# Load the trained model from the checkpoint dict written by the save cell
# (weights are stored under its 'model_state_dict' key).
checkpoint = torch.load('arcface_model.pth', map_location='cpu')
model = ResNetArcModel(num_classes=100, embedding_size=512)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.cuda()
model.eval()

# Preprocessing for face crops (must match training): ToPILImage treats a 3-channel
# uint8 array as RGB, then resize, convert to tensor and normalise [0, 1] -> [-1, 1].
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((112, 112)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Load reference embeddings of enrolled people. Each .npy file must be produced the
# same way as in the enrollment example (extract_embedding_from_image + np.save).
reference_db = {
    "Alice": np.load("alice_embedding.npy"),
    "Bob": np.load("bob_embedding.npy")
    # Add more enrolled users
}

def get_embedding(face_img):
    """Return the unit-length embedding of one face crop as a 1-D numpy array.

    The crop is preprocessed with the module-level `transform`, batched as a single
    image, passed through `model` without labels, and L2-normalised.
    """
    batch = transform(face_img).unsqueeze(0).cuda()  # single-image batch
    with torch.no_grad():
        features = normalize(model(batch), dim=1)
    return features.squeeze(0).cpu().numpy()

def recognize_face(embedding, threshold=0.5, references=None):
    """Match a query embedding against enrolled reference embeddings.

    Args:
        embedding: Query embedding (1-D numpy array).
        threshold: Minimum similarity required to accept the best match.
        references: Mapping of name -> reference embedding. Defaults to the
            module-level ``reference_db``.

    Returns:
        (name, score): the best-matching name and its similarity score, or
        ("Unknown", score) when the best score is below ``threshold``. The score
        is -1 when there are no references. Ties keep the first name encountered.
    """
    if references is None:
        references = reference_db

    best_match, best_score = None, -1
    for name, ref_emb in references.items():
        # Assuming both vectors are L2-normalised (as produced by get_embedding and
        # extract_embedding_from_image), the dot product is the cosine similarity.
        score = np.dot(embedding, ref_emb)
        if score > best_score:
            best_match, best_score = name, score

    if best_score >= threshold:
        return best_match, best_score
    return "Unknown", best_score

# OpenCV capture from the default camera (device index 0)
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    raise RuntimeError("Could not open camera 0.")
print("[INFO] Starting camera...")

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # OpenCV delivers BGR frames; face detection and the embedding model use RGB.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        faces = face_locations(rgb_frame)  # returns (top, right, bottom, left)

        for (top, right, bottom, left) in faces:
            # Crop from the RGB frame: `transform` (ToPILImage) treats 3-channel arrays
            # as RGB. `frame` stays BGR because it is only used for OpenCV drawing.
            face_img = rgb_frame[top:bottom, left:right]
            if face_img.size == 0:
                continue

            emb = get_embedding(face_img)
            name, score = recognize_face(emb)

            # Draw bounding box and name; keep the label inside the image when the
            # face touches the top edge.
            cv2.rectangle(frame, (left, top), (right, bottom), (0, 255, 0), 2)
            cv2.putText(frame, f'{name} ({score:.2f})', (left, max(top - 10, 20)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (255, 255, 255), 2)

        cv2.imshow("ArcFace Recognition", frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always release the camera and close windows, even if the loop raises.
    cap.release()
    cv2.destroyAllWindows()
