# Defining imports:

In [73]:
# !pip install torch-summary

In [74]:
# Standard library imports
import os
import time
from pathlib import Path

# Third-party library imports
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torchvision.utils as vutils
import torchvision.datasets as datasets
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import torchvision.io as io
import torchvision.models as models
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.nn import TripletMarginLoss
import torch.nn as nn
from torchsummary import summary
from tqdm.notebook import tqdm
from tqdm import tqdm
from typing import Tuple, List, Dict, Union, Any
from pathlib import Path

# Internal imports from PyTorch
from torchvision.models import resnet50
from torchvision.utils import save_image
from torchvision.io import read_image

In [75]:
plt.rcParams["savefig.bbox"] = 'tight'

In [76]:
# Check if CUDA (GPU support) is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

Using device: cuda


# Model:

**Defining custom model: Triplet Neural Network (with dataset class):**

In [77]:
# Utilize PyTorch's data loading utilities
class TripletDataset(torch.utils.data.Dataset):
    
    # Initialization
    def __init__(self, dataset, transform=None):
        
        # Initializes dataset and transformations
        self.dataset = dataset
        self.transform = transform
        
        # Extract label from dataset
        self.labels = [item[1] for item in dataset.imgs]
        
        # Create dictionary where keys are labels and values are lists of indices corresponding to each label.
        self.label_to_indices = {label: np.where(np.array(self.labels) == label)[0]
                                 for label in set(self.labels)}

        
    # Defines how individual items are retrieved from the dataset given an index (Called when dataset is indexed like dataset[index])
    def __getitem__(self, index):
        
        # Extracts the image path and label of the anchor image at the given index from the dataset.
        # label1 will be label of anchor/positive 
        img1, label1 = self.dataset.imgs[index]
        
        # Initialize positive index with the anchor index
        positive_index = index
        
        # For positive index: randomly selects another index from the indices of images with the same label as the anchor image 
        while positive_index == index:
            positive_index = np.random.choice(self.label_to_indices[label1])
            
        # For negative label: Randomly selects a label that is different from the label of the anchor image 
        negative_label = np.random.choice(list(set(self.labels) - set([label1])))
        negative_index = np.random.choice(self.label_to_indices[negative_label])
        
        # Load images corresponding to the anchor, positive, and negative indices and convert images to RGB format
        img2 = self.dataset.imgs[positive_index][0]
        img3 = self.dataset.imgs[negative_index][0]
        img1 = Image.open(img1).convert("RGB")
        img2 = Image.open(img2).convert("RGB")
        img3 = Image.open(img3).convert("RGB")
        
        # If transformation is not None, apply the transformation
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        
        # Return images
        return label1, img1, img2, img3

    # Return the length of the dataset
    def __len__(self):
        return len(self.dataset)

In [78]:
# # Orignal

# class TripletNetwork(nn.Module):
#     def __init__(self, embedding_size=128):
#         super(TripletNetwork, self).__init__()
#         self.backbone = resnet50(pretrained=True)
#         self.backbone.fc = nn.Sequential(torch.nn.Linear(self.backbone.fc.in_features, 256),torch.nn.ReLU(),torch.nn.Dropout(0.7),torch.nn.Linear(256, 1),torch.nn.Sigmoid())  # Correct usage without dim=1

#     def forward(self, x):
#         embedding = self.backbone(x)
#         return embedding

In [79]:
# # Adjusted for Embedding Outputs

# # In this adjusted version, you can call model(input, return_embedding=True) to get the embedding directly.

# class TripletNetwork(nn.Module):
#     def __init__(self, embedding_size=128):
#         super(TripletNetwork, self).__init__()
        
#         self.backbone = resnet50(pretrained=True)
        
#         # Adjusting to include an embedding layer that outputs before the sigmoid
#         self.embedding_layer = nn.Linear(self.backbone.fc.in_features, embedding_size)
        
#         self.fc = nn.Sequential(nn.ReLU(), nn.Dropout(0.7), nn.Linear(embedding_size, 1), nn.Sigmoid())

#     def forward(self, x, return_embedding=False):
#         x = self.backbone(x)
#         embedding = self.embedding_layer(x)
#         if return_embedding:
#             return embedding
#         x = self.fc(embedding)
#         return x

In [80]:
# # Adjusted for Embedding Outputs, fixing size issue

# from torchvision.models import resnet50
# import torch.nn as nn

# class TripletNetwork(nn.Module):
#     def __init__(self, embedding_size=128):
#         super(TripletNetwork, self).__init__()
        
#         # Load a pre-trained resnet50 model
#         self.backbone = resnet50(pretrained=True)
        
#         # Remove the last fully connected layer (classifier) to use the backbone as a feature extractor
#         # This modifies the model to output features directly after the global average pooling layer
#         # which should have a feature size of 2048
#         self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
        
#         # Adjusting to include an embedding layer that outputs before the sigmoid
#         # This layer now expects a [batch_size, 2048] input tensor
#         self.embedding_layer = nn.Linear(2048, embedding_size)  # ResNet-50's pooling features are 2048-dimensional
        
#         # Final classification layer
#         self.fc = nn.Sequential(nn.ReLU(), nn.Dropout(0.7), nn.Linear(embedding_size, 1), nn.Sigmoid())

#     def forward(self, x, return_embedding=False):
#         # Pass input through the backbone to get features
#         x = self.backbone(x)
        
#         # Flatten the output for the embedding layer
#         x = x.view(x.size(0), -1)
        
#         # Get the embedding
#         embedding = self.embedding_layer(x)
        
#         if return_embedding:
#             return embedding
        
#         # Pass embedding through the final classification layer
#         x = self.fc(embedding)
#         return x

In [81]:
class Identity(nn.Module):
    def forward(self, x):
        return x

In [82]:
class TripletNetwork(nn.Module):
    def __init__(self, embedding_size=64):
        super(TripletNetwork, self).__init__()
        
        # Load a pre-trained resnet50 model
        self.backbone = resnet50(pretrained=True)
        
        # Replace the fully connected layer with an Identity module
        self.backbone.fc = Identity()
        
        # Embedding layer
        self.embedding_layer = nn.Linear(2048, embedding_size)  # Use 2048 as the in_features to match ResNet-50
        
        # Classification layers
        self.fc = nn.Sequential(nn.ReLU(), nn.Dropout(0.7), nn.Linear(embedding_size, 1), nn.Sigmoid())

    def forward(self, x, return_embedding=False):
        # Extract features using the backbone
        x = self.backbone(x)  # This will now give a [batch_size, 2048] tensor directly
        
        # Get the embedding
        embedding = self.embedding_layer(x)
        
        if return_embedding:
            return embedding
        
        # Pass embedding through the classification layer
        x = self.fc(embedding)
        return x

# Loading images into Data Loaders:

**Loading images into test data loader:**

In [83]:
class CustomImageFolder(ImageFolder):
    @staticmethod
    def find_classes(directory: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset structured in a directory by overriding the sorting order.

        Parameters:
            directory (Union[str, Path]): Root directory path.

        Returns:
            Tuple[List[str], Dict[str, int]]: (classes, class_to_idx) where classes are a list of 
                                              the class names and class_to_idx is a dictionary mapping 
                                              class name to class index.
        """
        # Correct the syntax error by adding an extra set of parentheses around the generator expression
        classes = sorted((entry.name for entry in os.scandir(directory) if entry.is_dir()), reverse=True)
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

In [84]:
# Set train and valid directory paths
train_directory = 'data/Training_Augmented'
valid_directory = 'data/Validation_Augmented'
test_directory = 'data/Test_Augmented'

# Batch size
bs = 32

# Number of classes
num_classes = 2

#define a standard transform to tensor
transform = transforms.Compose([transforms.ToTensor()])

# Define the index mapping for folders to labels
folder_to_label = {'no_melanoma': 0, 'melanoma': 1}

# Load Data from folders
data = {
    'train': CustomImageFolder(root=train_directory, transform=transform),
    'valid': CustomImageFolder(root=valid_directory, transform=transform),
    'test': CustomImageFolder(root=test_directory, transform=transform)
}


# Modify the class_to_idx attribute to reflect your custom mapping
# data['test'].class_to_idx = folder_to_label

In [85]:
print(data['test'])

Dataset CustomImageFolder
    Number of datapoints: 1512
    Root location: data/Test_Augmented
    StandardTransform
Transform: Compose(
               ToTensor()
           )


In [86]:
print(data['test'].class_to_idx)

{'no_melanoma': 0, 'melanoma': 1}


**Creating sample:**

In [87]:
# Printing items (tuple of file path and labels)
for item in data['test'].imgs:
    print(item)

('data/Test_Augmented\\melanoma\\ISIC_0034529_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034548_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034572_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034573_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034584_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034595_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034605_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034628_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034630_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034633_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034638_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034644_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034650_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034655_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034657_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034687_v1.jpg', 1)
('data/Test_Augmented\\melanoma\\ISIC_0034713_v1.jpg', 1)
('data/Test_Au

In [88]:
# Size of Data, to be used for calculating Average Loss and Accuracy
test_data_size = len(data['test'])

# Print length of each dataset
print(test_data_size)

1512


In [89]:
train_dataset = TripletDataset(dataset=data['train'], transform=transform)
train_data = DataLoader(train_dataset, batch_size=bs, shuffle=True)

valid_dataset = TripletDataset(dataset=data['valid'], transform=transform)
valid_data = DataLoader(valid_dataset, batch_size=bs, shuffle=True)

test_dataset = TripletDataset(dataset=data['test'], transform=transform)
test_data = DataLoader(test_dataset, batch_size=bs, shuffle=False)

In [90]:
print(test_data)

<torch.utils.data.dataloader.DataLoader object at 0x0000014BAD94C880>


In [117]:
for label, anchor, positive, negative in test_data:
    print("Label:", label)
    print(len(label))
    print('Anchor:', anchor)
    print('Positive:', positive)
    print('Negative:', negative)

Label: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.9922, 0.9961, 0.9882],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 0.9804, 0.9765, 0.9961],
          ...,
          [1.0000, 1.0000, 0.9804,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [0.9804, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.9882, 0.9922, 0.9843],
          [1.0000, 1.0000, 1.0000,  ..., 0.9961, 0.9961, 0.9961],
          [1.0000, 1.0000, 1.0000,  ..., 0.9686, 0.9647, 0.9843],
          ...,
          [0.8471, 0.8235, 0.7804,  ..., 1.0000, 1.0000, 1.0000],
          [0.8392, 0.8431, 0.8196,  ..., 1.0000, 1.0000, 1.0000],
          [0.8118, 0.8431, 0.8196,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.

Label: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
32
Anchor: tensor([[[[0.9608, 0.9608, 0.9725,  ..., 1.0000, 1.0000, 1.0000],
          [0.9686, 0.9686, 0.9765,  ..., 1.0000, 1.0000, 1.0000],
          [0.9765, 0.9843, 0.9882,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.9843, 0.9843, 0.9843,  ..., 1.

Label: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.9255, 0.9294, 0.9294],
          [0.9882, 1.0000, 1.0000,  ..., 1.0000, 0.9882, 0.9647],
          [0.9569, 0.9804, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [0.9882, 0.9804, 0.9529,  ..., 0.8824, 0.8863, 0.8941],
          [1.0000, 1.0000, 0.9608,  ..., 0.8745, 0.8667, 0.8588],
          [1.0000, 1.0000, 0.9686,  ..., 0.8706, 0.8510, 0.8314]],

         [[0.6510, 0.6353, 0.6039,  ..., 0.3373, 0.3608, 0.3725],
          [0.6235, 0.6275, 0.6118,  ..., 0.3765, 0.3686, 0.3569],
          [0.6000, 0.6118, 0.6118,  ..., 0.3725, 0.3608, 0.3333],
          ...,
          [0.4000, 0.3647, 0.2980,  ..., 0.0000, 0.0000, 0.0000],
          [0.3216, 0.2784, 0.2000,  ..., 0.0078, 0.0000, 0.0000],
          [0.2745, 0.2235, 0.1373,  ..., 0.0275, 0.0039, 0.0000]],

         [[0.9529, 0.9569, 0.9608,  ..., 0.

Label: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [0.9804, 0.9922, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.9922, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [0.9725, 0.9804, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.9882]],

         [[0.9020, 0.9176, 0.9255,  ..., 1.

Label: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
32
Anchor: tensor([[[[0.9843, 0.9843, 0.9882,  ..., 0.9843, 0.9765, 0.9725],
          [0.9843, 0.9843, 0.9882,  ..., 0.9804, 0.9804, 0.9804],
          [0.9843, 0.9882, 0.9882,  ..., 0.9922, 0.9922, 1.0000],
          ...,
          [0.9490, 0.9686, 0.9961,  ..., 0.9843, 0.9765, 0.9765],
          [0.9490, 0.9608, 0.9843,  ..., 0.9608, 0.9490, 0.9490],
          [0.9686, 0.9804, 0.9843,  ..., 0.9451, 0.9412, 0.9333]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 0.9843, 0.9922, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 0.9725, 0.9843, 0.9922],
          ...,
          [0.9412, 0.9608, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          [0.9451, 0.9686, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [0.9725, 0.9882, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.

Label: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.9922, 0.9961, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [0.9922, 0.9882, 0.9882,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 0.9843, 0.9843,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 0.9216,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.8980,  ..., 0.9882, 1.0000, 1.0000],
          [0.9725, 1.0000, 0.9451,  ..., 0.9569, 0.9255, 0.9137]],

         [[0.9961, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 0.9961, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 0.9843, 0.9843,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [0.8353, 0.7373, 0.5608,  ..., 0.9412, 0.9529, 0.9490],
          [0.8667, 0.7686, 0.5725,  ..., 0.8941, 0.9216, 0.9333],
          [0.8431, 0.7804, 0.6392,  ..., 0.8706, 0.8392, 0.8196]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.9961, 1.0000, 1.0000,  ..., 0.9922, 0.9725, 0.9922],
          [0.9843, 1.0000, 1.0000,  ..., 1.0000, 0.9961, 0.9922],
          [0.9882, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.9882],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 0.9804, 0.9686, 0.9843],
          [1.0000, 1.0000, 1.0000,  ..., 0.9725, 0.9569, 0.9647],
          [0.9843, 1.0000, 1.0000,  ..., 0.9882, 0.9922, 1.0000]],

         [[0.5882, 0.5922, 0.6275,  ..., 0.5490, 0.5294, 0.5412],
          [0.5569, 0.5725, 0.6118,  ..., 0.5608, 0.5451, 0.5412],
          [0.5490, 0.5608, 0.5765,  ..., 0.5804, 0.5647, 0.5373],
          ...,
          [0.5373, 0.5529, 0.5686,  ..., 0.5961, 0.5804, 0.5922],
          [0.5176, 0.5412, 0.5647,  ..., 0.5882, 0.5686, 0.5725],
          [0.5059, 0.5294, 0.5569,  ..., 0.6039, 0.6039, 0.6157]],

         [[0.8980, 0.9020, 0.9569,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.7373, 0.7451, 0.7725,  ..., 0.7490, 0.7490, 0.7843],
          [0.7569, 0.7529, 0.7725,  ..., 0.7490, 0.7647, 0.7961],
          [0.7490, 0.7373, 0.7412,  ..., 0.7725, 0.7882, 0.8157],
          ...,
          [0.5804, 0.6235, 0.5922,  ..., 0.7608, 0.6784, 0.6235],
          [0.6118, 0.6118, 0.5569,  ..., 0.7098, 0.6510, 0.6196],
          [0.6431, 0.5961, 0.5333,  ..., 0.6275, 0.6157, 0.6157]],

         [[0.4157, 0.4275, 0.4510,  ..., 0.3490, 0.3373, 0.3608],
          [0.4314, 0.4275, 0.4510,  ..., 0.3569, 0.3569, 0.3804],
          [0.4078, 0.4000, 0.4039,  ..., 0.3882, 0.3961, 0.4157],
          ...,
          [0.1882, 0.2314, 0.2000,  ..., 0.4627, 0.3843, 0.3255],
          [0.2196, 0.2196, 0.1647,  ..., 0.4235, 0.3647, 0.3255],
          [0.2510, 0.2039, 0.1490,  ..., 0.3412, 0.3294, 0.3333]],

         [[0.8471, 0.8392, 0.8510,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.9725, 0.9765, 1.0000],
          [1.0000, 0.9961, 0.9843,  ..., 0.9804, 0.9882, 1.0000],
          [1.0000, 0.9843, 0.9686,  ..., 1.0000, 1.0000, 0.9843],
          ...,
          [0.9490, 0.9725, 0.9882,  ..., 1.0000, 0.9882, 0.9451],
          [0.9686, 0.9765, 0.9843,  ..., 1.0000, 1.0000, 0.9843],
          [0.9765, 0.9804, 0.9882,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.4980, 0.4980, 0.5059,  ..., 0.5725, 0.5765, 0.6118],
          [0.4706, 0.4627, 0.4745,  ..., 0.5804, 0.5765, 0.5882],
          [0.4784, 0.4667, 0.4706,  ..., 0.6000, 0.5882, 0.5725],
          ...,
          [0.3882, 0.4118, 0.4275,  ..., 0.4902, 0.4549, 0.4118],
          [0.4078, 0.4157, 0.4314,  ..., 0.5294, 0.5020, 0.4471],
          [0.4118, 0.4275, 0.4431,  ..., 0.5686, 0.5451, 0.4824]],

         [[0.5333, 0.5373, 0.5451,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 0.9922,  ..., 0.9569, 0.9216, 0.9490],
          [1.0000, 1.0000, 0.9843,  ..., 0.9804, 0.9373, 0.9647],
          [1.0000, 1.0000, 0.9765,  ..., 0.9961, 0.9608, 0.9843],
          ...,
          [0.9451, 0.9608, 0.9882,  ..., 0.9569, 0.9725, 0.9882],
          [0.9569, 0.9922, 1.0000,  ..., 0.9961, 1.0000, 0.9804],
          [0.9608, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.5686, 0.5922, 0.5451,  ..., 0.6471, 0.6471, 0.7020],
          [0.5804, 0.5961, 0.5373,  ..., 0.6588, 0.6549, 0.7059],
          [0.5961, 0.6039, 0.5294,  ..., 0.6627, 0.6588, 0.7020],
          ...,
          [0.6275, 0.6549, 0.6824,  ..., 0.3843, 0.3804, 0.3804],
          [0.6353, 0.6706, 0.7137,  ..., 0.4235, 0.4039, 0.3686],
          [0.6314, 0.6745, 0.7294,  ..., 0.4941, 0.4667, 0.3882]],

         [[0.9608, 0.9765, 0.9216,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.9804, 0.9961, 1.0000,  ..., 0.9804, 0.9843, 0.9922],
          [1.0000, 1.0000, 1.0000,  ..., 0.9804, 0.9882, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 0.9804, 0.9961, 1.0000],
          ...,
          [0.9922, 1.0000, 1.0000,  ..., 1.0000, 0.9961, 0.9765],
          [0.9686, 0.9961, 1.0000,  ..., 1.0000, 1.0000, 0.9961],
          [0.9529, 0.9843, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.4706, 0.4863, 0.4941,  ..., 0.5804, 0.5765, 0.5843],
          [0.4980, 0.4980, 0.4941,  ..., 0.5843, 0.5882, 0.6000],
          [0.5137, 0.5059, 0.4980,  ..., 0.5843, 0.6000, 0.6235],
          ...,
          [0.5765, 0.6078, 0.6353,  ..., 0.6824, 0.6588, 0.6392],
          [0.5529, 0.5843, 0.6157,  ..., 0.6627, 0.6510, 0.6431],
          [0.5373, 0.5725, 0.6039,  ..., 0.6706, 0.6706, 0.6706]],

         [[0.7216, 0.7373, 0.7451,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.9961],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.9961],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 0.9961, 0.9882],
          ...,
          [0.9020, 0.9490, 0.9961,  ..., 0.9961, 0.9843, 0.9765],
          [0.9216, 0.9725, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [0.9882, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.6157, 0.6235, 0.6314,  ..., 0.8275, 0.8078, 0.7882],
          [0.6275, 0.6275, 0.6275,  ..., 0.8196, 0.8039, 0.7882],
          [0.6510, 0.6510, 0.6431,  ..., 0.8157, 0.8078, 0.7922],
          ...,
          [0.5843, 0.6314, 0.6667,  ..., 0.9961, 1.0000, 1.0000],
          [0.6000, 0.6510, 0.7020,  ..., 0.9686, 0.9608, 0.9686],
          [0.6627, 0.6941, 0.7098,  ..., 0.9490, 0.9569, 0.9569]],

         [[0.8196, 0.8157, 0.8196,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [0.9804, 0.9765, 0.9882,  ..., 1.0000, 1.0000, 1.0000],
          [0.9804, 0.9804, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 0.9961, 0.9882],
          [1.0000, 1.0000, 1.0000,  ..., 0.9961, 0.9961, 0.9882],
          [1.0000, 1.0000, 1.0000,  ..., 0.9961, 0.9961, 0.9961]],

         [[0.8902, 0.8863, 0.8941,  ..., 0.9843, 0.9804, 0.9765],
          [0.8588, 0.8627, 0.8745,  ..., 0.9843, 0.9804, 0.9765],
          [0.8588, 0.8588, 0.8784,  ..., 0.9843, 0.9765, 0.9765],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 0.8667, 0.8745, 0.8745],
          [1.0000, 1.0000, 1.0000,  ..., 0.8627, 0.8824, 0.8745],
          [1.0000, 1.0000, 1.0000,  ..., 0.8627, 0.8824, 0.8824]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.9686, 0.9765, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [0.9490, 0.9412, 0.9569,  ..., 1.0000, 1.0000, 1.0000],
          [0.9569, 0.9373, 0.9412,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [0.9961, 0.9961, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [0.9961, 0.9961, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [0.9961, 0.9961, 0.9961,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.9686, 0.9765, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [0.9490, 0.9412, 0.9569,  ..., 1.0000, 1.0000, 1.0000],
          [0.9569, 0.9373, 0.9412,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [0.9961, 0.9961, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [0.9961, 0.9961, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [0.9961, 0.9961, 0.9961,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.9686, 0.9765, 0.9961,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.9569, 0.9843, 1.0000,  ..., 1.0000, 0.9843, 0.9451],
          [0.9608, 0.9725, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [0.9647, 0.9569, 0.9725,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9882,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.2275, 0.2549, 0.2824,  ..., 0.3490, 0.3333, 0.3020],
          [0.2431, 0.2510, 0.2706,  ..., 0.3490, 0.3608, 0.3608],
          [0.2588, 0.2431, 0.2549,  ..., 0.3373, 0.3373, 0.3569],
          ...,
          [0.0784, 0.0627, 0.0431,  ..., 0.0039, 0.0078, 0.0078],
          [0.0824, 0.0667, 0.0431,  ..., 0.0000, 0.0000, 0.0000],
          [0.0824, 0.0667, 0.0392,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.6549, 0.6902, 0.7255,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [0.9961, 0.9961, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [0.9961, 0.9961, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          [0.9961, 0.9961, 0.9922,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9961,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.6118, 0.6000, 0.5725,  ..., 0.8000, 0.7647, 0.7490],
          [0.5804, 0.5765, 0.5725,  ..., 0.8078, 0.8000, 0.8078],
          [0.5608, 0.5686, 0.5882,  ..., 0.8235, 0.8353, 0.8667],
          ...,
          [0.6392, 0.6353, 0.6510,  ..., 0.8627, 0.8824, 0.9098],
          [0.6314, 0.6196, 0.6235,  ..., 0.8392, 0.8549, 0.8863],
          [0.6627, 0.6353, 0.6118,  ..., 0.8275, 0.8471, 0.8745]],

         [[0.0627, 0.0588, 0.0549,  ..., 0.1843, 0.1647, 0.1490],
          [0.0353, 0.0431, 0.0588,  ..., 0.1961, 0.2000, 0.2157],
          [0.0235, 0.0392, 0.0745,  ..., 0.2235, 0.2392, 0.2784],
          ...,
          [0.0000, 0.0039, 0.0392,  ..., 0.2549, 0.2706, 0.2941],
          [0.0000, 0.0000, 0.0235,  ..., 0.2314, 0.2431, 0.2706],
          [0.0118, 0.0039, 0.0118,  ..., 0.2196, 0.2353, 0.2588]],

         [[0.6431, 0.6353, 0.6235,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 0.9922, 0.9608,  ..., 0.9882, 1.0000, 0.9725],
          [1.0000, 1.0000, 0.9647,  ..., 0.9725, 1.0000, 0.9725],
          [1.0000, 1.0000, 0.9725,  ..., 0.9490, 0.9725, 0.9961],
          ...,
          [0.9569, 0.9804, 1.0000,  ..., 1.0000, 1.0000, 0.9608],
          [0.9490, 0.9686, 0.9882,  ..., 1.0000, 1.0000, 0.9843],
          [0.9255, 0.9451, 0.9686,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.0392, 0.0549, 0.1020,  ..., 0.3176, 0.3569, 0.2941],
          [0.0627, 0.0745, 0.1216,  ..., 0.2980, 0.3333, 0.3059],
          [0.1020, 0.1098, 0.1490,  ..., 0.2784, 0.3098, 0.3333],
          ...,
          [0.1490, 0.1804, 0.2078,  ..., 0.2275, 0.2078, 0.1765],
          [0.1569, 0.1725, 0.2000,  ..., 0.1961, 0.2039, 0.1804],
          [0.1412, 0.1608, 0.1882,  ..., 0.1686, 0.1961, 0.1922]],

         [[0.2510, 0.2706, 0.3294,  ..., 0.

Negative: tensor([[[[0.9333, 0.9647, 0.9843,  ..., 0.9059, 0.8902, 0.8863],
          [0.9804, 1.0000, 1.0000,  ..., 0.9098, 0.8980, 0.9059],
          [1.0000, 1.0000, 1.0000,  ..., 0.9137, 0.9059, 0.9098],
          ...,
          [1.0000, 0.9882, 0.9882,  ..., 0.9922, 0.9961, 0.9882],
          [1.0000, 0.9804, 0.9765,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 0.9765, 0.9725,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.2627, 0.2431, 0.1765,  ..., 0.0000, 0.0000, 0.0000],
          [0.3529, 0.3333, 0.2667,  ..., 0.0000, 0.0000, 0.0000],
          [0.4627, 0.4392, 0.3804,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0588, 0.0235, 0.0039,  ..., 0.0431, 0.0510, 0.0431],
          [0.0627, 0.0275, 0.0039,  ..., 0.0549, 0.0824, 0.0902],
          [0.0588, 0.0196, 0.0000,  ..., 0.0706, 0.1137, 0.1294]],

         [[0.7725, 0.7608, 0.7020,  ..., 0.4471, 0.4235, 0.4196],
          [0.8392, 0.8235, 0.7725,  ..., 0.4627, 0.4471, 0.4510],
          [0.9216, 0.9059, 0.858

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 0.9804, 0.9725,  ..., 0.8588, 1.0000, 1.0000],
          [0.9882, 0.9647, 0.9608,  ..., 0.7255, 1.0000, 1.0000],
          [0.9490, 0.9373, 0.9608,  ..., 0.6706, 0.8196, 1.0000],
          ...,
          [0.9961, 1.0000, 0.9961,  ..., 0.5255, 0.3529, 0.4588],
          [1.0000, 1.0000, 1.0000,  ..., 0.5882, 0.3412, 0.3020],
          [1.0000, 1.0000, 1.0000,  ..., 0.6196, 0.3490, 0.2275]],

         [[0.1725, 0.1294, 0.1216,  ..., 0.2000, 0.5255, 0.3922],
          [0.1373, 0.1137, 0.1176,  ..., 0.0549, 0.3294, 0.4431],
          [0.0980, 0.0980, 0.1176,  ..., 0.0000, 0.0863, 0.4588],
          ...,
          [0.3490, 0.3529, 0.3647,  ..., 0.0000, 0.0000, 0.0510],
          [0.3765, 0.3843, 0.3843,  ..., 0.0000, 0.0000, 0.0510],
          [0.3961, 0.3961, 0.3922,  ..., 0.0000, 0.0000, 0.0627]],

         [[0.6157, 0.5804, 0.5804,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.9569, 0.9804, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 0.9765, 0.9608, 0.9647],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 0.9804, 0.9647],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 0.9804, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9725,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9804,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.6392, 0.6588, 0.6863,  ..., 0.6588, 0.6980, 0.7412],
          [0.6353, 0.6510, 0.6784,  ..., 0.6745, 0.6784, 0.6902],
          [0.6314, 0.6431, 0.6667,  ..., 0.7098, 0.6980, 0.6902],
          ...,
          [0.5725, 0.5333, 0.5176,  ..., 0.4588, 0.4980, 0.5137],
          [0.5804, 0.5529, 0.5216,  ..., 0.4863, 0.4980, 0.4902],
          [0.6235, 0.5961, 0.5451,  ..., 0.4941, 0.5020, 0.4863]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 0.9765, 0.9451,  ..., 0.9882, 0.9804, 0.9804],
          [1.0000, 1.0000, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 0.9882,  ..., 0.9882, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9804,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9882,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.3020, 0.2588, 0.2157,  ..., 0.3882, 0.3804, 0.3804],
          [0.3255, 0.2941, 0.2627,  ..., 0.4196, 0.4157, 0.4157],
          [0.3216, 0.3098, 0.2941,  ..., 0.4471, 0.4471, 0.4510],
          ...,
          [0.2157, 0.2039, 0.1882,  ..., 0.5255, 0.5373, 0.5451],
          [0.2118, 0.1961, 0.1804,  ..., 0.5412, 0.5451, 0.5529],
          [0.2235, 0.2039, 0.1882,  ..., 0.5569, 0.5569, 0.5608]],

         [[0.9569, 0.9176, 0.8784,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.9922, 1.0000, 1.0000,  ..., 0.9961, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 0.9922, 0.9961, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 0.9961, 0.9961, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 0.9569, 0.9765, 0.9922],
          [1.0000, 1.0000, 0.9843,  ..., 0.9569, 0.9569, 0.9569],
          [1.0000, 1.0000, 0.9686,  ..., 1.0000, 1.0000, 0.9961]],

         [[0.7373, 0.7569, 0.7843,  ..., 0.7255, 0.7137, 0.7176],
          [0.7451, 0.7647, 0.7843,  ..., 0.7255, 0.7137, 0.7137],
          [0.7490, 0.7569, 0.7765,  ..., 0.7490, 0.7294, 0.7294],
          ...,
          [0.7333, 0.7255, 0.7216,  ..., 0.7843, 0.8118, 0.8235],
          [0.7451, 0.7294, 0.7059,  ..., 0.7843, 0.7882, 0.8000],
          [0.7647, 0.7333, 0.6941,  ..., 0.8314, 0.8353, 0.8392]],

         [[0.9020, 0.9255, 0.9608,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.7216, 0.7216, 0.5490,  ..., 0.7412, 0.7176, 0.7059],
          [0.7686, 0.7451, 0.5725,  ..., 0.7333, 0.7294, 0.7255],
          [0.7882, 0.7490, 0.5725,  ..., 0.7490, 0.7490, 0.7373],
          ...,
          [0.5882, 0.5922, 0.6275,  ..., 0.6863, 0.6980, 0.7059],
          [0.6471, 0.6510, 0.6667,  ..., 0.6627, 0.6549, 0.6549],
          [0.6980, 0.7059, 0.7059,  ..., 0.6627, 0.6431, 0.6353]],

         [[0.5725, 0.5647, 0.3765,  ..., 0.5569, 0.5451, 0.5373],
          [0.6157, 0.5882, 0.3961,  ..., 0.5490, 0.5569, 0.5569],
          [0.6353, 0.5882, 0.4039,  ..., 0.5647, 0.5647, 0.5647],
          ...,
          [0.3843, 0.3882, 0.4314,  ..., 0.5216, 0.5333, 0.5412],
          [0.4353, 0.4471, 0.4706,  ..., 0.4980, 0.4902, 0.4902],
          [0.4824, 0.4980, 0.5059,  ..., 0.4980, 0.4784, 0.4706]],

         [[1.0000, 1.0000, 0.8275,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.9843, 0.9843, 0.9882,  ..., 0.9882, 0.9725, 0.9725],
          [0.9922, 0.9922, 0.9922,  ..., 1.0000, 1.0000, 0.9922],
          [1.0000, 1.0000, 0.9922,  ..., 1.0000, 1.0000, 0.9961],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.9843],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 0.9804, 1.0000, 1.0000]],

         [[0.2980, 0.2980, 0.3098,  ..., 0.3451, 0.3216, 0.3098],
          [0.3059, 0.3059, 0.3137,  ..., 0.3490, 0.3216, 0.3059],
          [0.3216, 0.3216, 0.3255,  ..., 0.2980, 0.2745, 0.2549],
          ...,
          [0.3725, 0.3725, 0.3647,  ..., 0.4314, 0.4078, 0.3882],
          [0.3686, 0.3804, 0.3804,  ..., 0.4157, 0.4196, 0.4157],
          [0.3569, 0.3804, 0.3843,  ..., 0.4000, 0.4196, 0.4392]],

         [[0.9059, 0.9059, 0.9137,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 0.9882, 0.9843,  ..., 1.0000, 1.0000, 1.0000],
          [0.9529, 0.9647, 0.9333,  ..., 1.0000, 1.0000, 1.0000],
          [0.9176, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [0.5490, 0.5098, 0.5333,  ..., 1.0000, 1.0000, 1.0000],
          [0.4588, 0.4863, 0.4824,  ..., 1.0000, 1.0000, 1.0000],
          [0.4235, 0.5412, 0.5647,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.9686, 0.9412, 0.9059],
          [1.0000, 1.0000, 1.0000,  ..., 0.9333, 0.8706, 0.8314],
          [1.0000, 1.0000, 0.9922,  ..., 0.9647, 0.9294, 0.9176],
          ...,
          [0.9922, 0.9961, 0.9647,  ..., 1.0000, 1.0000, 0.9804],
          [0.9569, 0.9843, 0.9804,  ..., 1.0000, 1.0000, 0.9608],
          [0.9529, 1.0000, 1.0000,  ..., 1.0000, 0.9961, 0.9176]],

         [[0.5294, 0.5451, 0.5608,  ..., 0.7843, 0.8275, 0.8353],
          [0.5333, 0.5490, 0.5647,  ..., 0.6196, 0.6235, 0.6196],
          [0.5137, 0.5255, 0.5373,  ..., 0.4824, 0.4980, 0.5098],
          ...,
          [0.6588, 0.6510, 0.5843,  ..., 0.9569, 0.9098, 0.8510],
          [0.6157, 0.6275, 0.5922,  ..., 0.9765, 0.9490, 0.8706],
          [0.6118, 0.6471, 0.6392,  ..., 0.9843, 0.9373, 0.8510]],

         [[0.9569, 0.9608, 0.9451,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.9804, 0.9922, 1.0000,  ..., 0.7216, 0.7059, 0.7020],
          [0.9804, 0.9922, 1.0000,  ..., 0.7255, 0.7098, 0.7059],
          [0.9843, 1.0000, 1.0000,  ..., 0.7294, 0.7176, 0.7176],
          ...,
          [0.9961, 0.9961, 0.9961,  ..., 0.9922, 0.9961, 1.0000],
          [0.9961, 0.9961, 0.9961,  ..., 0.9922, 0.9922, 0.9922],
          [0.9961, 0.9961, 0.9961,  ..., 0.9882, 0.9922, 0.9922]],

         [[0.7843, 0.7765, 0.7608,  ..., 0.0039, 0.0000, 0.0000],
          [0.7765, 0.7725, 0.7647,  ..., 0.0039, 0.0000, 0.0000],
          [0.7647, 0.7608, 0.7608,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0078, 0.0078, 0.0039,  ..., 0.5020, 0.5059, 0.5098],
          [0.0078, 0.0078, 0.0039,  ..., 0.5020, 0.5098, 0.5098],
          [0.0078, 0.0078, 0.0039,  ..., 0.5059, 0.5098, 0.5098]],

         [[0.9294, 0.9294, 0.9216,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.9725, 0.9843, 0.9804],
          [1.0000, 1.0000, 1.0000,  ..., 0.9765, 0.9882, 0.9843],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9843,  ..., 1.0000, 1.0000, 1.0000],
          [0.9961, 0.9608, 0.9490,  ..., 0.9882, 1.0000, 1.0000]],

         [[0.2157, 0.2431, 0.2745,  ..., 0.7765, 0.7922, 0.7882],
          [0.2353, 0.2471, 0.2706,  ..., 0.7804, 0.7882, 0.7843],
          [0.2706, 0.2627, 0.2627,  ..., 0.8000, 0.8039, 0.7922],
          ...,
          [0.4039, 0.3922, 0.3569,  ..., 0.6510, 0.6510, 0.6392],
          [0.3922, 0.3569, 0.3294,  ..., 0.6392, 0.6431, 0.6392],
          [0.3569, 0.3098, 0.2941,  ..., 0.6196, 0.6275, 0.6275]],

         [[0.5804, 0.6039, 0.6314,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.9922, 0.9490, 0.9490,  ..., 1.0000, 1.0000, 0.9412],
          [1.0000, 0.9804, 0.9765,  ..., 1.0000, 1.0000, 0.9569],
          [1.0000, 0.9725, 0.9725,  ..., 0.9843, 1.0000, 0.9451],
          ...,
          [0.9882, 0.9922, 1.0000,  ..., 0.9843, 0.9412, 0.9294],
          [0.9804, 0.9765, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [0.9922, 0.9686, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.9294, 0.8784, 0.8706,  ..., 0.6275, 0.6235, 0.5020],
          [0.9569, 0.9020, 0.8902,  ..., 0.6000, 0.6039, 0.4980],
          [0.9373, 0.8824, 0.8745,  ..., 0.5490, 0.5569, 0.4667],
          ...,
          [0.7490, 0.7451, 0.7569,  ..., 0.7255, 0.6902, 0.6784],
          [0.7451, 0.7412, 0.7804,  ..., 0.7333, 0.7255, 0.7176],
          [0.7647, 0.7333, 0.7608,  ..., 0.7490, 0.7686, 0.7608]],

         [[0.9804, 0.9333, 0.9176,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 0.9961, 0.9843,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9882,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 0.9961,  ..., 1.0000, 1.0000, 0.9804],
          [1.0000, 1.0000, 0.9882,  ..., 1.0000, 1.0000, 0.9725],
          [1.0000, 1.0000, 0.9961,  ..., 1.0000, 1.0000, 0.9647]],

         [[0.6471, 0.6392, 0.6275,  ..., 0.7490, 0.7608, 0.7765],
          [0.6471, 0.6431, 0.6314,  ..., 0.7451, 0.7686, 0.7804],
          [0.6353, 0.6353, 0.6314,  ..., 0.7255, 0.7569, 0.7843],
          ...,
          [0.6314, 0.6118, 0.5804,  ..., 0.5451, 0.5412, 0.5176],
          [0.6196, 0.6118, 0.5804,  ..., 0.5333, 0.5216, 0.4902],
          [0.6235, 0.6196, 0.5961,  ..., 0.5294, 0.5059, 0.4706]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 0.9961, 1.0000],
          [0.9686, 0.9804, 0.9882,  ..., 1.0000, 0.9765, 0.9647],
          ...,
          [0.9804, 0.9882, 0.9922,  ..., 0.8863, 0.8510, 0.8235],
          [0.9765, 0.9804, 0.9882,  ..., 0.8588, 0.8275, 0.8039],
          [0.9608, 0.9647, 0.9765,  ..., 0.8549, 0.8314, 0.8118]],

         [[0.7059, 0.7137, 0.7176,  ..., 0.8039, 0.8118, 0.8275],
          [0.7255, 0.7412, 0.7451,  ..., 0.8235, 0.8118, 0.8157],
          [0.6980, 0.7098, 0.7137,  ..., 0.8588, 0.8275, 0.8157],
          ...,
          [0.7137, 0.7216, 0.7255,  ..., 0.5843, 0.5451, 0.5176],
          [0.7137, 0.7176, 0.7216,  ..., 0.5529, 0.5216, 0.4980],
          [0.6980, 0.7020, 0.7059,  ..., 0.5490, 0.5333, 0.5137]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.3765, 0.3725, 0.3804,  ..., 0.7843, 0.7804, 0.7725],
          [0.3725, 0.3725, 0.3765,  ..., 0.7608, 0.7529, 0.7451],
          [0.4196, 0.4118, 0.4078,  ..., 0.7412, 0.7294, 0.7216],
          ...,
          [0.7020, 0.6980, 0.6941,  ..., 0.7725, 0.7725, 0.7647],
          [0.6863, 0.6745, 0.6706,  ..., 0.7451, 0.7333, 0.7137],
          [0.6706, 0.6549, 0.6392,  ..., 0.7255, 0.7255, 0.7098]],

         [[0.4902, 0.4863, 0.4941,  ..., 0.8667, 0.8588, 0.8510],
          [0.4863, 0.4824, 0.4863,  ..., 0.8431, 0.8353, 0.8235],
          [0.5255, 0.5137, 0.5098,  ..., 0.8235, 0.8118, 0.8039],
          ...,
          [0.7725, 0.7686, 0.7647,  ..., 0.7765, 0.7765, 0.7686],
          [0.7647, 0.7529, 0.7451,  ..., 0.7608, 0.7490, 0.7294],
          [0.7490, 0.7333, 0.7176,  ..., 0.7412, 0.7412, 0.7255]],

         [[0.8980, 0.8941, 0.9020,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9922,  ..., 0.9804, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [0.9922, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.9765]],

         [[0.9725, 0.9843, 0.9922,  ..., 0.7725, 0.7294, 0.7843],
          [0.9765, 0.9804, 0.9922,  ..., 0.7569, 0.7569, 0.7725],
          [0.9843, 0.9843, 0.9922,  ..., 0.7020, 0.7294, 0.6941],
          ...,
          [0.9608, 0.9647, 0.9765,  ..., 0.6706, 0.6745, 0.6784],
          [0.9529, 0.9608, 0.9725,  ..., 0.6941, 0.6706, 0.6392],
          [0.9765, 0.9843, 0.9882,  ..., 0.7098, 0.6549, 0.5882]],

         [[1.0000, 1.0000, 0.9922,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.9294, 1.0000, 1.0000,  ..., 1.0000, 0.9882, 0.9412],
          [0.9373, 0.9490, 1.0000,  ..., 1.0000, 0.9843, 0.9843],
          [0.9608, 0.8980, 1.0000,  ..., 1.0000, 0.9843, 1.0000],
          ...,
          [1.0000, 0.9882, 0.9843,  ..., 1.0000, 0.9961, 0.9843],
          [1.0000, 1.0000, 0.9922,  ..., 1.0000, 1.0000, 1.0000],
          [0.9608, 0.9765, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.1529, 0.4039, 0.6039,  ..., 0.9686, 1.0000, 1.0000],
          [0.0706, 0.2549, 0.5686,  ..., 0.9804, 0.9922, 1.0000],
          [0.0157, 0.1255, 0.4941,  ..., 0.9882, 0.9647, 0.9882],
          ...,
          [0.8314, 0.8471, 0.8863,  ..., 0.9843, 1.0000, 0.9922],
          [0.8353, 0.8471, 0.8902,  ..., 0.9843, 1.0000, 1.0000],
          [0.7765, 0.8157, 0.8980,  ..., 0.9765, 0.9961, 1.0000]],

         [[0.4549, 0.7176, 0.9686,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 0.9922, 0.9843,  ..., 0.9608, 0.9020, 0.8745],
          [1.0000, 1.0000, 1.0000,  ..., 0.9529, 0.9059, 0.8824],
          [1.0000, 1.0000, 1.0000,  ..., 0.9412, 0.9059, 0.8980],
          ...,
          [0.9451, 0.9373, 0.9176,  ..., 0.8510, 0.8353, 0.8235],
          [0.9765, 0.9255, 0.8549,  ..., 0.8431, 0.8118, 0.7882],
          [0.9765, 0.8824, 0.7725,  ..., 0.8314, 0.7922, 0.7569]],

         [[1.0000, 0.9882, 0.9765,  ..., 0.8784, 0.8275, 0.7961],
          [1.0000, 1.0000, 0.9922,  ..., 0.8706, 0.8314, 0.8039],
          [0.9961, 0.9961, 0.9961,  ..., 0.8588, 0.8196, 0.8118],
          ...,
          [0.7765, 0.7569, 0.7333,  ..., 0.7059, 0.6824, 0.6667],
          [0.7961, 0.7373, 0.6549,  ..., 0.6980, 0.6549, 0.6196],
          [0.7843, 0.6902, 0.5686,  ..., 0.6863, 0.6353, 0.5882]],

         [[0.9922, 0.9804, 0.9804,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.6275, 0.6157, 0.6392,  ..., 0.8000, 0.8392, 0.8627],
          [0.6314, 0.6196, 0.6510,  ..., 0.7922, 0.8314, 0.8471],
          [0.6706, 0.6588, 0.6745,  ..., 0.7569, 0.7804, 0.7882],
          ...,
          [0.8745, 0.8745, 0.8824,  ..., 0.7804, 0.7882, 0.7961],
          [0.8431, 0.8510, 0.8706,  ..., 0.7843, 0.7882, 0.7922],
          [0.8471, 0.8549, 0.8784,  ..., 0.7765, 0.7765, 0.7765]],

         [[0.4157, 0.4157, 0.4471,  ..., 0.7176, 0.7569, 0.7765],
          [0.4275, 0.4235, 0.4588,  ..., 0.7098, 0.7490, 0.7647],
          [0.4824, 0.4706, 0.4980,  ..., 0.6745, 0.6980, 0.7059],
          ...,
          [0.7647, 0.7647, 0.7843,  ..., 0.6706, 0.6706, 0.6745],
          [0.7294, 0.7373, 0.7686,  ..., 0.6706, 0.6667, 0.6706],
          [0.7333, 0.7412, 0.7765,  ..., 0.6627, 0.6549, 0.6431]],

         [[0.9529, 0.9294, 0.9412,  ..., 1.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 0.9882, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 0.9765]],

         [[0.7412, 0.7373, 0.7294,  ..., 0.9922, 0.9882, 0.9804],
          [0.7333, 0.7373, 0.7490,  ..., 0.9882, 0.9843, 0.9843],
          [0.7216, 0.7373, 0.7490,  ..., 0.9882, 0.9843, 0.9843],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 0.7725, 0.7843, 0.7882],
          [1.0000, 1.0000, 1.0000,  ..., 0.8000, 0.7804, 0.7490],
          [1.0000, 1.0000, 1.0000,  ..., 0.8196, 0.7765, 0.7216]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.

Positive: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.000

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0])
8
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 0.9333, 0.9176],
          [0.9922, 0.9882, 0.9922,  ..., 1.0000, 0.9804, 0.9882],
          [0.9529, 0.9569, 0.9647,  ..., 1.0000, 0.9608, 0.9765],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 0.9451, 0.9843, 1.0000],
          [1.0000, 0.9922, 0.9569,  ..., 0.9647, 0.9843, 1.0000],
          [0.9765, 0.9569, 0.9137,  ..., 0.9765, 0.9765, 1.0000]],

         [[0.1647, 0.1569, 0.1451,  ..., 0.2392, 0.1922, 0.1961],
          [0.1294, 0.1216, 0.1098,  ..., 0.2588, 0.2431, 0.2706],
          [0.0902, 0.0941, 0.0902,  ..., 0.2510, 0.2392, 0.2784],
          ...,
          [0.4314, 0.4078, 0.3922,  ..., 0.2627, 0.3216, 0.4196],
          [0.4471, 0.4275, 0.3843,  ..., 0.2824, 0.3216, 0.4078],
          [0.4549, 0.4275, 0.3647,  ..., 0.2941, 0.3137, 0.3765]],

         [[0.5882, 0.5725, 0.5608,  ..., 0.6627, 0.6549, 0.6745],
          [0.5529, 0.5373, 0.5255,  ..., 0.6902, 0.7137, 0

In [92]:
print(valid_data)

<torch.utils.data.dataloader.DataLoader object at 0x0000014BAD9AD6F0>


In [93]:
for label, anchor, positive, negative in valid_data:
    print("Label:", label)
    print(len(label))
    print('Anchor:', anchor)
    print('Positive:', positive)
    print('Negative:', negative)

Label: tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 1, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.9961, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 0.9725,  ..., 1.0000, 1.0000, 1.0000],
          [0.9804, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [0.9412, 1.0000, 0.9922,  ..., 0.9961, 0.9804, 0.9725]],

         [[0.5412, 0.5490, 0.5529,  ..., 0.7255, 0.7412, 0.7647],
          [0.5490, 0.5529, 0.5490,  ..., 0.7294, 0.7373, 0.7451],
          [0.5490, 0.5373, 0.5294,  ..., 0.7529, 0.7451, 0.7412],
          ...,
          [0.6196, 0.6157, 0.5765,  ..., 0.7216, 0.7176, 0.7216],
          [0.6039, 0.6392, 0.6157,  ..., 0.7176, 0.7059, 0.6980],
          [0.5725, 0.6275, 0.6078,  ..., 0.6784, 0.6627, 0.6471]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.

Label: tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.8157, 0.8275, 0.8196,  ..., 0.8902, 0.8980, 0.9020],
          [0.8039, 0.8000, 0.7804,  ..., 0.8980, 0.9020, 0.9098],
          [0.8392, 0.8196, 0.7843,  ..., 0.9020, 0.9020, 0.9059],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 0.9922, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 0.9961, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 0.9961, 1.0000, 1.0000]],

         [[0.7569, 0.7686, 0.7608,  ..., 0.9098, 0.9059, 0.9098],
          [0.7255, 0.7216, 0.6941,  ..., 0.9176, 0.9216, 0.9176],
          [0.7176, 0.6980, 0.6588,  ..., 0.9216, 0.9216, 0.9255],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 0.9725, 0.9843, 0.9922],
          [1.0000, 1.0000, 1.0000,  ..., 0.9804, 0.9922, 0.9922],
          [1.0000, 1.0000, 1.0000,  ..., 0.9804, 0.9922, 0.9922]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.

Label: tensor([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1])
32
Anchor: tensor([[[[0.8549, 0.8078, 0.8118,  ..., 1.0000, 1.0000, 1.0000],
          [0.8980, 0.8706, 0.8431,  ..., 0.9882, 1.0000, 1.0000],
          [0.9059, 0.9098, 0.8863,  ..., 0.9176, 0.9255, 0.9255],
          ...,
          [1.0000, 1.0000, 0.9647,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 0.9804, 0.9608,  ..., 0.9608, 0.9882, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 0.9255, 0.9451, 1.0000]],

         [[0.3333, 0.2784, 0.2471,  ..., 0.5490, 0.5647, 0.5569],
          [0.3961, 0.3490, 0.2980,  ..., 0.4902, 0.5216, 0.5333],
          [0.4275, 0.4157, 0.3608,  ..., 0.4588, 0.4980, 0.5255],
          ...,
          [0.5255, 0.4824, 0.4549,  ..., 0.6627, 0.6706, 0.6667],
          [0.4745, 0.4510, 0.4431,  ..., 0.5882, 0.6000, 0.6235],
          [0.4824, 0.4824, 0.4902,  ..., 0.5451, 0.5490, 0.6039]],

         [[0.6902, 0.6471, 0.6314,  ..., 0.

Label: tensor([0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[1.0000, 1.0000, 1.0000,  ..., 0.9765, 0.9882, 0.9961],
          [1.0000, 0.9922, 0.9686,  ..., 0.9804, 0.9843, 0.9922],
          [0.9686, 0.9529, 0.9490,  ..., 0.9882, 0.9843, 0.9843],
          ...,
          [1.0000, 1.0000, 0.9961,  ..., 0.9961, 0.9961, 0.9922],
          [1.0000, 1.0000, 0.9961,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 0.9922,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.5882, 0.6039, 0.6078,  ..., 0.8588, 0.8667, 0.8745],
          [0.6000, 0.5804, 0.5686,  ..., 0.8627, 0.8627, 0.8706],
          [0.5608, 0.5451, 0.5490,  ..., 0.8667, 0.8627, 0.8627],
          ...,
          [0.7098, 0.6980, 0.6902,  ..., 0.5765, 0.5686, 0.5647],
          [0.7059, 0.6941, 0.6902,  ..., 0.5804, 0.5765, 0.5725],
          [0.7020, 0.6902, 0.6863,  ..., 0.5961, 0.5961, 0.5922]],

         [[0.9961, 1.0000, 1.0000,  ..., 0.

Label: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
32
Anchor: tensor([[[[0.4078, 0.4039, 0.3961,  ..., 0.5176, 0.5490, 0.5647],
          [0.4000, 0.4000, 0.3961,  ..., 0.5059, 0.5255, 0.5333],
          [0.4196, 0.4196, 0.4157,  ..., 0.5020, 0.5137, 0.5255],
          ...,
          [0.3804, 0.3843, 0.3843,  ..., 0.4471, 0.4353, 0.4314],
          [0.3647, 0.3686, 0.3725,  ..., 0.4627, 0.4549, 0.4275],
          [0.3569, 0.3608, 0.3647,  ..., 0.4667, 0.4549, 0.4196]],

         [[0.2471, 0.2431, 0.2431,  ..., 0.3647, 0.3804, 0.3882],
          [0.2392, 0.2392, 0.2431,  ..., 0.3529, 0.3569, 0.3647],
          [0.2588, 0.2588, 0.2627,  ..., 0.3569, 0.3529, 0.3569],
          ...,
          [0.1647, 0.1686, 0.1686,  ..., 0.2431, 0.2353, 0.2196],
          [0.1490, 0.1529, 0.1569,  ..., 0.2588, 0.2431, 0.2157],
          [0.1412, 0.1451, 0.1490,  ..., 0.2627, 0.2431, 0.2078]],

         [[0.4667, 0.4627, 0.4667,  ..., 0.

# Initializing the TripletNetwork model:

In [94]:
import torch
import torchvision.models as models

# Instantiate the model
triplet_network = TripletNetwork(embedding_size=64).to(device)

for param in triplet_network.parameters():
    param.requires_grad = False

In [95]:
summary(triplet_network, input_size=(bs, 9, 408))

Layer (type:depth-idx)                   Param #
├─ResNet: 1-1                            --
|    └─Conv2d: 2-1                       (9,408)
|    └─BatchNorm2d: 2-2                  (128)
|    └─ReLU: 2-3                         --
|    └─MaxPool2d: 2-4                    --
|    └─Sequential: 2-5                   --
|    |    └─Bottleneck: 3-1              (75,008)
|    |    └─Bottleneck: 3-2              (70,400)
|    |    └─Bottleneck: 3-3              (70,400)
|    └─Sequential: 2-6                   --
|    |    └─Bottleneck: 3-4              (379,392)
|    |    └─Bottleneck: 3-5              (280,064)
|    |    └─Bottleneck: 3-6              (280,064)
|    |    └─Bottleneck: 3-7              (280,064)
|    └─Sequential: 2-7                   --
|    |    └─Bottleneck: 3-8              (1,512,448)
|    |    └─Bottleneck: 3-9              (1,117,184)
|    |    └─Bottleneck: 3-10             (1,117,184)
|    |    └─Bottleneck: 3-11             (1,117,184)
|    |    └─Bottleneck: 3

Layer (type:depth-idx)                   Param #
├─ResNet: 1-1                            --
|    └─Conv2d: 2-1                       (9,408)
|    └─BatchNorm2d: 2-2                  (128)
|    └─ReLU: 2-3                         --
|    └─MaxPool2d: 2-4                    --
|    └─Sequential: 2-5                   --
|    |    └─Bottleneck: 3-1              (75,008)
|    |    └─Bottleneck: 3-2              (70,400)
|    |    └─Bottleneck: 3-3              (70,400)
|    └─Sequential: 2-6                   --
|    |    └─Bottleneck: 3-4              (379,392)
|    |    └─Bottleneck: 3-5              (280,064)
|    |    └─Bottleneck: 3-6              (280,064)
|    |    └─Bottleneck: 3-7              (280,064)
|    └─Sequential: 2-7                   --
|    |    └─Bottleneck: 3-8              (1,512,448)
|    |    └─Bottleneck: 3-9              (1,117,184)
|    |    └─Bottleneck: 3-10             (1,117,184)
|    |    └─Bottleneck: 3-11             (1,117,184)
|    |    └─Bottleneck: 3

In [96]:
# Unfreeze the parameters of the last residual block and the embedding layer
for param in triplet_network.backbone.layer4.parameters():
    param.requires_grad = True

In [97]:
summary(triplet_network, input_size=(bs, 9, 408))

Layer (type:depth-idx)                   Param #
├─ResNet: 1-1                            --
|    └─Conv2d: 2-1                       (9,408)
|    └─BatchNorm2d: 2-2                  (128)
|    └─ReLU: 2-3                         --
|    └─MaxPool2d: 2-4                    --
|    └─Sequential: 2-5                   --
|    |    └─Bottleneck: 3-1              (75,008)
|    |    └─Bottleneck: 3-2              (70,400)
|    |    └─Bottleneck: 3-3              (70,400)
|    └─Sequential: 2-6                   --
|    |    └─Bottleneck: 3-4              (379,392)
|    |    └─Bottleneck: 3-5              (280,064)
|    |    └─Bottleneck: 3-6              (280,064)
|    |    └─Bottleneck: 3-7              (280,064)
|    └─Sequential: 2-7                   --
|    |    └─Bottleneck: 3-8              (1,512,448)
|    |    └─Bottleneck: 3-9              (1,117,184)
|    |    └─Bottleneck: 3-10             (1,117,184)
|    |    └─Bottleneck: 3-11             (1,117,184)
|    |    └─Bottleneck: 3

Layer (type:depth-idx)                   Param #
├─ResNet: 1-1                            --
|    └─Conv2d: 2-1                       (9,408)
|    └─BatchNorm2d: 2-2                  (128)
|    └─ReLU: 2-3                         --
|    └─MaxPool2d: 2-4                    --
|    └─Sequential: 2-5                   --
|    |    └─Bottleneck: 3-1              (75,008)
|    |    └─Bottleneck: 3-2              (70,400)
|    |    └─Bottleneck: 3-3              (70,400)
|    └─Sequential: 2-6                   --
|    |    └─Bottleneck: 3-4              (379,392)
|    |    └─Bottleneck: 3-5              (280,064)
|    |    └─Bottleneck: 3-6              (280,064)
|    |    └─Bottleneck: 3-7              (280,064)
|    └─Sequential: 2-7                   --
|    |    └─Bottleneck: 3-8              (1,512,448)
|    |    └─Bottleneck: 3-9              (1,117,184)
|    |    └─Bottleneck: 3-10             (1,117,184)
|    |    └─Bottleneck: 3-11             (1,117,184)
|    |    └─Bottleneck: 3

In [98]:
# Assuming you replaced the fully connected layer with an embedding layer
for param in triplet_network.backbone.fc.parameters():
    param.requires_grad = True

# Training & Validating TripletNetwork model:

In [99]:
# # epochs = 30
# epochs = 1
# patience = 3  # For early stopping
# best_valid_loss = float('inf')
# early_stopping_counter = 0

# # Define the path where you want to save the model
# model_path = "model_weights.pth"

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# triplet_network.to(device)
# optimizer = Adam(triplet_network.parameters(), lr=0.001,weight_decay=1e-2)
# loss_func = TripletMarginLoss(margin=1.0)
# scheduler = StepLR(optimizer, step_size=1, gamma=0.99)

# history = []

# for epoch in tqdm(range(epochs)):
#     epoch_start = time.time()
#     train_loss, valid_loss = 0.0, 0.0
#     correct_train, correct_valid = 0, 0
#     total_train, total_valid = 0, 0
    
#     triplet_network.train()
#     for labels, anchor, positive, negative in train_data:
#         anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
        
#         optimizer.zero_grad()
#         anchor_embed = triplet_network(anchor)
#         positive_embed = triplet_network(positive)
#         negative_embed = triplet_network(negative)
        
#         loss = loss_func(anchor_embed, positive_embed, negative_embed)
#         loss.backward()
#         optimizer.step()
        
#         train_loss += loss.item()
        
#         # Calculate "accuracy"
#         with torch.no_grad():
#             dist_positive = (anchor_embed - positive_embed).pow(2).sum(1)  # Euclidean distance
#             dist_negative = (anchor_embed - negative_embed).pow(2).sum(1)
#             correct_train += (dist_positive < dist_negative).sum().item()
#             total_train += anchor.size(0)
    
#     scheduler.step()
#     triplet_network.eval()
#     with torch.no_grad():
#         for labels, anchor, positive, negative in valid_data:
#             anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
#             anchor_embed = triplet_network(anchor)
#             positive_embed = triplet_network(positive)
#             negative_embed = triplet_network(negative)
            
#             loss = loss_func(anchor_embed, positive_embed, negative_embed)
#             valid_loss += loss.item()
            
#             # Calculate "accuracy"
#             dist_positive = (anchor_embed - positive_embed).pow(2).sum(1)
#             dist_negative = (anchor_embed - negative_embed).pow(2).sum(1)
#             correct_valid += (dist_positive < dist_negative).sum().item()
#             total_valid += anchor.size(0)
    
#     avg_train_loss = train_loss / len(train_data)
#     avg_valid_loss = valid_loss / len(valid_data)
#     train_accuracy = correct_train / total_train
#     valid_accuracy = correct_valid / total_valid
    
#     history.append([avg_train_loss, avg_valid_loss, train_accuracy, valid_accuracy])
    
#     if avg_valid_loss < best_valid_loss:
#         best_valid_loss = avg_valid_loss
#         early_stopping_counter = 0
#         # Save the model's weights
#         torch.save(triplet_network.state_dict(), model_path)
        
#     else:
#         early_stopping_counter += 1
#         if early_stopping_counter > patience:
#             print("Early stopping triggered.")
#             break

#     epoch_end = time.time()
#     print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Valid Loss: {avg_valid_loss:.4f}, Valid Acc: {valid_accuracy:.4f}, Time: {epoch_end - epoch_start:.2f}s")

# Computing Reference Embeddings:

In [125]:
# Assuming you've already defined 'valid_directory' and 'transform'
# valid_dataset_reference = datasets.ImageFolder(root=valid_directory, transform=transform)
# valid_data_reference = DataLoader(valid_dataset_reference, batch_size=bs, shuffle=False)

valid_dataset_reference = CustomImageFolder(root=valid_directory, transform=transform)
valid_data_reference = DataLoader(valid_dataset_reference, batch_size=bs, shuffle=False)

In [126]:
# def compute_reference_embeddings(model, dataloader, device):
#     model.eval()
#     embeddings = []
#     labels = []

#     with torch.no_grad():
#         for images, label in dataloader:
#             images = images.to(device)
#             output = model(images)  # Ensure this gets the embedding before the sigmoid
#             embeddings.append(output)
#             labels.extend(label.tolist())

#     # Average the embeddings for each class
#     class_embeddings = {}
#     unique_labels = set(labels)
#     for lbl in unique_labels:
#         class_indices = [i for i, x in enumerate(labels) if x == lbl]
#         class_embeddings[lbl] = torch.mean(torch.stack([embeddings[i] for i in class_indices]), dim=0)

#     return class_embeddings

In [127]:
def compute_reference_embeddings(model, dataloader, device):
    model.eval()
    embeddings = []
    labels = []

    with torch.no_grad():
        for images, label in dataloader:
            images = images.to(device)
            output = model(images)  # Ensure this gets the embedding before the sigmoid
            # Adjust here to append each embedding in the batch individually
            embeddings.extend(output.cpu().detach().numpy())
            labels.extend(label.cpu().numpy())  # Assuming labels are not already on CPU and not converted to list

    # Convert embeddings back to tensor after collecting them
    embeddings = torch.tensor(embeddings, dtype=torch.float32, device=device)

    # Average the embeddings for each class
    class_embeddings = {}
    unique_labels = set(labels)
    for lbl in unique_labels:
        class_indices = [i for i, x in enumerate(labels) if x == lbl]
        class_embeddings[lbl] = torch.mean(embeddings[class_indices], dim=0)

    return class_embeddings

In [128]:
# Initiliaze the Triplet Network 
loaded_model = TripletNetwork(embedding_size=64) 

# Load the model from model_path
#loaded_model.load_state_dict(torch.load("model_weights_3_epoch_Triplet.pth"))
loaded_model.load_state_dict(torch.load("model_weights.pth"))

# Set it in eval mode
#loaded_model.eval()

# Move model to device
loaded_model = loaded_model.to(device)  

In [129]:
reference_embeddings = compute_reference_embeddings(loaded_model, valid_data_reference, device)

# Save the reference embeddings for later use
torch.save(reference_embeddings, 'reference_embeddings.pt')

# Predict Single Class:

In [130]:
def predict_class(model, image, reference_embeddings, device):
    model.eval()  # Switch model to evaluation mode
    with torch.no_grad():
        # Make sure the image has a batch dimension
        if image.dim() == 3:
            image = image.unsqueeze(0)  # Add batch dimension if not present
        image = image.to(device)
        
        print("Input shape to model:", image.shape)
        
        # Get the embedding of the uploaded image
        image_embedding = model(image)

    # Initialize the closest class and smallest distance
    closest_class = None
    smallest_distance = float('inf')

    # Compare the uploaded image's embedding to each reference embedding
    for class_name, ref_embedding in reference_embeddings.items():
        distance = (image_embedding - ref_embedding.to(device)).pow(2).sum(1).item()
        if distance < smallest_distance:
            smallest_distance = distance
            closest_class = class_name

    return closest_class

In [131]:
def load_image(image_path, transform, device):
    image = Image.open(image_path).convert('RGB')
    image = transform(image)
    image = image.unsqueeze(0).to(device)  # Add batch dimension and send to device
    return image

# Define the transform
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Example image path
image_path = "data/Test_Augmented/melanoma/ISIC_0035914_v1.jpg"

# Load the image
image = load_image(image_path, transform, device)

# Load embeddings
reference_embeddings = torch.load('reference_embeddings.pt')

# Now you can directly use the loaded image for prediction
predicted_class = predict_class(loaded_model, image, reference_embeddings, device)

Input shape to model: torch.Size([1, 3, 224, 224])


In [132]:
print(predicted_class)

1


# Loading & conducting inferencing:

In [None]:
def extract_unique_value(tensor):
    
    # If all elements in the tensor evaluate to True, return 1
    if tensor.all():
        return 1
    
    # If none of the elements evaluate to True, return 0
    elif not tensor.any():
        return 0
    
    # Else, return None (contains mix of 0 and 1)
    else:
        return None

In [None]:
def test_model(model, test_loader, device):
    """
    Test the trained model on the test dataset and compute performance metrics.
    
    Args:
        model: The trained PyTorch model to evaluate.
        test_loader: DataLoader for the test dataset.
        device: The device to use for inference (default: cpu).
    Returns:
        Performance metrics such as accuracy, precision, recall, etc.
    """
    
    # Initialize predicted_label
    predicted_label = None

    # Don't include gradient in testing
    with torch.no_grad():
        
        
        for label, anchor, positive, negative in test_loader:
            
            # Send anchor, positive and negative to device
            anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
            
            # Obtain ground_truth from label tensor
            ground_truth = extract_unique_value(label)
            
            # Generate embeddings for anchor, positive and negative through the model
            anchor_embed = model(anchor)
            positive_embed = model(positive)
            negative_embed = model(negative)
            
            # Calculate squared Euclidean distance between the anchor and positive/negative embeddings
            dist_positive = (anchor_embed - positive_embed).pow(2).sum(1)
            dist_negative = (anchor_embed - negative_embed).pow(2).sum(1)
            
            # If distance between anchor & positive embedding is smaller than that of anchor & negative embedding, consider it as correct prediction
            if (dist_positive < dist_negative).sum().item():
                
                print("Correct prediction made by model \n")
                
                if ground_truth == 1:
                    
                    # Set predicted_label same as ground_truth
                    predicted_label = ground_truth 
                    
                    print("Ground truth for the image: Melanoma")
                    print(f"Model's prediction for the image: Melanoma \n")
                    
                else:
                    
                    # Set predicted_label same as ground_truth
                    predicted_label = ground_truth 
                    
                    print("Ground truth for this particular image was: Not Melanoma")
                    print(f"Model's prediction for the image: Not Melanoma \n")
                    
            else:
                
                print("Incorrect prediction made by model \n")
                
                
                if ground_truth == 1:
                    
                    # Set predicted_label opposite as ground_truth
                    predicted_label = 0 
                    
                    print("Ground truth for the image: Melanoma")
                    print(f"Model's prediction for the image: Not Melanoma \n")
                    
                else:
                    
                    # Set predicted_label opposite as ground_truth
                    predicted_label = 1 
                    
                    print("Ground truth for the image: Not Melanoma")
                    print(f"Model's prediction for the image: Melanoma \n")
    
    return ground_truth, predicted_label

In [None]:
ground_truth, predicted_label = test_model(loaded_model, test_data, device)

In [None]:
print(ground_truth)
print(predicted_label)

# Appendix:

Normal test_model:

In [None]:
# def test_model(model, test_loader, device):
#     """
#     Test the trained model on the test dataset and compute performance metrics.
    
#     Args:
#         model: The trained PyTorch model to evaluate.
#         test_loader: DataLoader for the test dataset.
#         device: The device to use for inference (default: cpu).
#     Returns:
#         Performance metrics such as accuracy, precision, recall, etc.
#     """
    
#     # Set the model to evaluation mode
#     model.eval()  
    
#     # Track number of correct predictions
#     correct = 0
    
#     # Track total number of samples processed
#     total = 0
    
#     # Don't include gradient in testing
#     with torch.no_grad():
#         for anchor, positive, negative in test_loader:
            
#             # Send anchor, positive and negative to device
#             anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)
            
#             # Generate embeddings for anchor, positive and negative through the model
#             anchor_embed = model(anchor)
#             positive_embed = model(positive)
#             negative_embed = model(negative)
            
#             # Calculate squared Euclidean distance between the anchor and positive/negative embeddings
#             dist_positive = (anchor_embed - positive_embed).pow(2).sum(1)
#             dist_negative = (anchor_embed - negative_embed).pow(2).sum(1)
            
#             # If distance between anchor & positive embedding is smaller than that of anchor & negative embedding, consider it as correct prediction
#             correct += (dist_positive < dist_negative).sum().item()
            
#             # Update total count of samples
#             total += anchor.size(0)
    
#     # Return accuracy 
#     accuracy = correct / total
    
#     return accuracy