In [1]:
from dataset_classes.vgg_face2_classifier import VGGFace2ClassifierDataset
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import torch

# Preprocessing applied to every image:
#   1. convert the input to a tensor image,
#   2. make sure it is uint8 before resizing (rescaling values if needed),
#   3. resize to 224x224,
#   4. convert to float32, rescaling uint8 values to [0, 1],
#   5. normalize every channel with mean 0.5 / std 0.5, mapping [0, 1] to [-1, 1].
# `reverse_transform` below undoes step 5.
transform = v2.Compose(
    transforms=[
        v2.ToImage(),
        v2.ToDtype(torch.uint8, scale=True),
        v2.Resize(size=(224, 224)),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

def reverse_transform(image: torch.Tensor):
    """Undo the final ``Normalize(mean=0.5, std=0.5)`` step of ``transform``.

    ``Normalize`` computes ``(x - mean) / std``. The forward step maps
    ``x -> (x - 0.5) / 0.5``, so its inverse ``y -> y * 0.5 + 0.5`` can be written
    as ``(y - (-1)) / 2``, i.e. another ``Normalize`` with mean -1 and std 2
    for each of the three channels. Resizing and dtype conversion are NOT undone.

    Args:
        image: a normalized float image tensor as produced by ``transform``.

    Returns:
        The tensor with values mapped back to the pre-normalization range.
    """
    channel_count = 3
    inverse_mean = tuple(-0.5 / 0.5 for _ in range(channel_count))
    inverse_std = tuple(1 / 0.5 for _ in range(channel_count))
    denormalize = v2.Normalize(mean=inverse_mean, std=inverse_std)
    return denormalize(image)

# Test split of the recognition dataset, preprocessed with `transform`.
# shuffle=False keeps the batch order fixed from run to run.
test_images_dir = './recognition_dataset/images/test'
test_labels_dir = './recognition_dataset/labels/test'
dataset = VGGFace2ClassifierDataset(test_images_dir, test_labels_dir, transform)
test_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=False)

In [2]:
from loss.triplet_loss import TripletLoss
from facenet_pytorch import InceptionResnetV1
from utils.system import get_available_device

# Pretrained Inception-ResNet-v1 (facenet_pytorch, weights trained on VGGFace2),
# used here as a face-embedding extractor and switched to inference mode.
# NOTE: `num_classes` is intentionally not passed. In facenet_pytorch it only
# configures the classification head, which is used only when `classify=True`
# (left at its default here), so the previous `num_classes=2` was silently
# ignored and suggested a 2-class classifier that does not exist.
device = get_available_device()
model = InceptionResnetV1(pretrained='vggface2', device=device).eval()

# Loss used to score (anchor, positive, negative) embedding triplets in the
# evaluation loop.
model_loss = TripletLoss()



  from .autonotebook import tqdm as notebook_tqdm


GPU is not available, using CPU instead
Using device: cpu


In [3]:
from tqdm import tqdm
from utils.classifier_model import aggregate_metrics, evaluate_recognition_batch

# Evaluate the model on the test split: for every batch of
# (anchor, positive, negative) images compute the triplet loss and the
# per-batch recognition metrics, then aggregate the metrics over all batches.
running_val_loss = []
batch_results = []
val_loss_sum = 0.0
# Initialized up front so `mean_val_loss` is always defined, even when the
# loader yields no batches (it then stays NaN).
mean_val_loss = float('nan')

model.eval()
# Using tqdm as a context manager closes the progress bar even if the loop is
# interrupted (e.g. KeyboardInterrupt).
with torch.no_grad(), tqdm(test_loader, leave=False) as test_loop:
    for anchor, positive, negative in test_loop:
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)

        # Forward pass for each element of the triplet + model loss computation
        anchor_embeddings = model(anchor)
        positive_embeddings = model(positive)
        negative_embeddings = model(negative)

        loss = model_loss(anchor_embeddings, positive_embeddings, negative_embeddings)

        batch_loss = loss.item()
        running_val_loss.append(batch_loss)
        # Running sum: O(1) per batch instead of re-summing the whole history
        # (O(n^2) overall). The additions happen in the same order as
        # sum(running_val_loss), so the mean is identical.
        val_loss_sum += batch_loss
        mean_val_loss = val_loss_sum / len(running_val_loss)
        test_loop.set_postfix(mean_loss=mean_val_loss)

        batch_metrics = evaluate_recognition_batch(anchor_embeddings, positive_embeddings, negative_embeddings)
        batch_results.append(batch_metrics)

aggregated_metrics = aggregate_metrics(batch_results)

                                                  

KeyboardInterrupt: 

In [None]:
# Report the mean test loss and the aggregated recognition metrics produced by
# the evaluation cell (which must have run to completion first).
print('loss: ', mean_val_loss, sep=' ')
print('metrics:\n', aggregated_metrics, sep=' ')