In [8]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets
from facenet_pytorch import InceptionResnetV1
from PIL import Image
import os

# Siamese network used to create face embeddings
class SiameseNetwork(nn.Module):
    """Embedding tower: FaceNet backbone followed by a small projection head.

    Despite the name, the module has a single branch: "Siamese" comparison comes
    from running the same instance over both images and measuring the distance
    between the resulting embeddings.

    The backbone is ``InceptionResnetV1`` with VGGFace2 weights. Its 512-wide
    output is mapped by the head 512 -> 256 -> ReLU -> 128.

    NOTE(review): the projection head is freshly constructed here and nothing in
    this cell loads trained weights into it, so the 128-dim output is an
    untrained (random) projection of the FaceNet features unless weights are
    loaded elsewhere — confirm before relying on absolute distance values.
    """

    def __init__(self):
        super().__init__()
        # Pretrained face-recognition backbone (VGGFace2 weights).
        self.facenet = InceptionResnetV1(pretrained='vggface2')
        # Projection head. Layer order is kept so state_dict keys stay fc.0 / fc.2.
        projection_layers = [
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
        ]
        self.fc = nn.Sequential(*projection_layers)

    def forward(self, input1):
        """Embed a batch of images.

        Args:
            input1: image batch passed straight to the FaceNet backbone
                (presumably ``(N, 3, 160, 160)`` scaled to [-1, 1] by the
                notebook's transform — TODO confirm).

        Returns:
            The projection-head output, whose last dimension is 128.
        """
        return self.fc(self.facenet(input1))

# Model setup: pretrained FaceNet backbone plus a freshly initialized projection head.
# The head's nn.Linear layers are randomly initialized, so seed torch before building
# the model; otherwise every fresh kernel yields different embeddings and distances,
# and Match verdicts near the threshold may flip between runs.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SiameseNetwork().to(device)
model.eval()  # Switch the model to evaluation mode (affects layers such as Dropout/BatchNorm).

# Verification function
def verify_user(img1_path, database_folder, model, transform, threshold=.01):
    """Compare a query face image against every image under ``database_folder``.

    Each image is opened, converted to RGB, passed through ``transform`` and
    embedded with ``model``. For every database image, the squared Euclidean
    distance to the query embedding is printed with a verdict
    (``distance < threshold``).

    Args:
        img1_path: path to the query image.
        database_folder: directory walked recursively; every file in it is
            treated as a candidate image. Files that cannot be opened or decoded
            are reported and skipped.
        model: embedding module (e.g. ``SiameseNetwork``). It is used as-is, so
            put it in eval mode beforehand. Inputs are moved to the device of
            its parameters (CPU if it has none or is not an ``nn.Module``).
        transform: callable mapping a PIL RGB image to a ``(C, H, W)`` tensor.
        threshold: squared-L2 distance below which a pair counts as a match.

    Returns:
        None; results are printed.

    Raises:
        NotADirectoryError: if ``database_folder`` is not an existing directory
            (``os.walk`` would otherwise silently yield nothing).
        OSError: if the query image cannot be opened or decoded.
    """
    if not os.path.isdir(database_folder):
        raise NotADirectoryError(f'Database folder not found: {database_folder}')

    # Follow the model's own device instead of a notebook global, so a model
    # living on CPU is not fed CUDA tensors (or vice versa).
    first_param = next(model.parameters(), None) if isinstance(model, nn.Module) else None
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    with torch.no_grad():
        output1 = model(_load_image_tensor(img1_path, transform, model_device))

        for root, dirs, files in os.walk(database_folder):
            dirs.sort()  # deterministic traversal order across runs and filesystems
            for file in sorted(files):
                img2_path = os.path.join(root, file)
                try:
                    img2 = _load_image_tensor(img2_path, transform, model_device)
                except OSError as exc:
                    # Covers PIL.UnidentifiedImageError (a subclass of OSError) for
                    # non-image files such as Thumbs.db, and truncated images:
                    # report and skip instead of aborting the whole scan.
                    print(f'Skipping {img2_path}: {exc}')
                    continue
                output2 = model(img2)
                # Squared Euclidean distance between the two embeddings.
                distance = (output1 - output2).pow(2).sum().item()
                is_match = distance < threshold
                print(f'Comparing {img1_path} with {img2_path}: Distance = {distance:.4f}, Match = {is_match}')


def _load_image_tensor(path, transform, device):
    """Open ``path``, force 3-channel RGB, apply ``transform``; return a batch of one on ``device``."""
    # The context manager closes the file handle Image.open keeps open lazily;
    # convert() fully loads the pixels, so the result outlives the handle.
    # RGB conversion handles RGBA/grayscale/palette PNGs for the 3-channel backbone.
    with Image.open(path) as img:
        rgb = img.convert('RGB')
    return transform(rgb).unsqueeze(0).to(device)

# Image preprocessing: resize to 160x160, then ToTensor (values in [0, 1]) and
# Normalize with mean = std = 0.5, which maps each channel to [-1, 1].
transformation = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Paths to images: replace these placeholders with the query image and the user database folder.
comparison_image_path = 'path_to_the_image_for_comparison'
database_path = 'path_to_your_user_base'

# Verification test: print the distance from the query image to every database image.
verify_user(
    img1_path=comparison_image_path,
    database_folder=database_path,
    model=model,
    transform=transformation,
)


Comparing d:/datasets/tfw6/test/S132/131_2_2_2_140_47_3.png with d:/datasets/tfw6/test/users\131_1_1_1_450_3_m.png: Distance = 0.0023, Match = True
Comparing d:/datasets/tfw6/test/S132/131_2_2_2_140_47_3.png with d:/datasets/tfw6/test/users\132_1_1_1_1_3_m.png: Distance = 0.0201, Match = False
Comparing d:/datasets/tfw6/test/S132/131_2_2_2_140_47_3.png with d:/datasets/tfw6/test/users\133_1_1_1_1_3.png: Distance = 0.0208, Match = False
