In [1]:
import os
base_dir = os.path.abspath('../..')
save_dir = os.path.join(base_dir,'results')
data_dir = os.path.join(base_dir,'data')

In [17]:
import torch
# torch.manual_seed(0)
import wandb
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
from torch.utils.data import TensorDataset,Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.transforms import Resize

import torchvision
from tqdm import tqdm
import random
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200


from sklearn import linear_model
from sklearn.preprocessing import StandardScaler

# from neurora.rdm_corr import rdm_correlation_spearman

In [18]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Define model

In [19]:
class TripletLabelModel(nn.Module):
    def __init__(self, encoded_space_dim=64, num_classes=10):
        super().__init__()
        ""
        ### Convolutional section
       ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
    
        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)
        ### Linear section
        ## changed 32*4*4 to 32*2*2
        self.encoder_lin = nn.Sequential(
            nn.Linear(32*2*2, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.Linear(128, encoded_space_dim)
        )

        ## triplet projection module
        self.decoder_triplet_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 32),
            nn.ReLU(True)
         
        )
        ##labeling module
        self.decoder_labels_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 32),
            nn.ReLU(True),
            nn.Linear(32, 16),
            nn.ReLU(True),
            nn.Linear(16, num_classes),
        )

        ### initialize weights using xavier initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
    
        
    def forward(self, x, y=None):
        batch_s = x.size(0)
        img_features = self.encoder_cnn(x)
        img_features = self.flatten(img_features)
        
        enc_latent = self.encoder_lin(img_features)

        triplet_latent = self.decoder_triplet_lin(enc_latent)
        label = self.decoder_labels_lin(enc_latent)
        # label = F.softmax(label,dim=1)
        return enc_latent, label

In [20]:
### custom loss computing triplet loss and labeling loss


class CustomLoss(nn.Module):
    def __init__(self, margin=10):
        super(CustomLoss, self).__init__()
        self.margin = margin
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, anchor, positive, negative, label, pred_label):
        cosine_sim = torch.nn.CosineSimilarity(1)
        # distance_positive = torch.tensor(1)-cosine_sim(anchor,positive)
   
        # distance_negative = torch.tensor(1)-cosine_sim(anchor,negative)

        # triplet_loss = torch.maximum(distance_positive - distance_negative + self.margin, torch.tensor(0))
        # triplet_loss = torch.sum(triplet_loss)
        triplet_loss = (nn.TripletMarginWithDistanceLoss( distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)))
        triplet_loss = triplet_loss(anchor, positive, negative)
        label_loss = F.binary_cross_entropy_with_logits(pred_label.float(), label.float())
        total_loss = triplet_loss + label_loss
        return triplet_loss, label_loss, total_loss

In [21]:
t = TripletLabelModel()
cifar_model_path = '../../data/CIFAR10_NCE_i_1e-05_50.pth'
t.load_state_dict(torch.load(cifar_model_path))



<All keys matched successfully>

### Training functions

In [27]:


class TrainModels(nn.Module):
    def __init__(self, latent_dims, num_classes):
        super(TrainModels, self).__init__()
        self.triplet_lab_model = TripletLabelModel(latent_dims, 10) ### load cifar model
        cifar_model_path = '../../data/CIFAR10_NCE_i_1e-05_50.pth'
        self.triplet_lab_model.load_state_dict(torch.load(cifar_model_path))
        self.triplet_lab_model.decoder_labels_lin[4] = nn.Linear(16, num_classes)
        self.custom_loss = CustomLoss()
        self.num_classes = num_classes
    
    def forward(self, anchor_im, positive_im, negative_im):
        anchor_latent, anchor_label = self.triplet_lab_model(anchor_im)
        positive_latent, _ = self.triplet_lab_model(positive_im)
        negative_latent, _ = self.triplet_lab_model(negative_im)

        return anchor_latent, positive_latent, negative_latent, anchor_label

    def test_epoch(self, test_data):
    # Set evaluation mode for encoder and decoder
        self.eval()
        with torch.no_grad(): # No need to track the gradients
            # Define the lists to store the outputs for each batch
            test_triplet_loss = []
            test_label_loss = []
            test_total_loss = []
            total = 0
            correct = 0
            for anchor_ims, contrast_ims, labels in test_data:
                # Move tensor to the proper device
                anchor_ims = anchor_ims.to(device)
                contrast_ims = contrast_ims.to(device)
                labels = F.one_hot(labels, num_classes=self.num_classes)
                labels = labels.to(device)
                anchor_latent, positive_latent, negative_latent, pred_label = self.forward(anchor_ims, anchor_ims,contrast_ims) 
                # Append the network output and the original image to the lists
                triplet_loss, label_loss, total_loss = self.custom_loss(anchor_latent,
                                                                positive_latent, 
                                                                negative_latent, 
                                                                labels,
                                                                pred_label)
                total += labels.size(0)
                correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(labels, dim = 1)).sum().item()
                test_triplet_loss.append(triplet_loss.item())
                test_label_loss.append(label_loss.item())
                test_total_loss.append(total_loss.item())
        test_triplet_loss = sum(test_triplet_loss)/len(test_triplet_loss)
        test_label_loss = sum(test_label_loss)/len(test_label_loss)
        test_total_loss = sum(test_total_loss)/len(test_total_loss)
        test_accuracy = correct/total
        return test_triplet_loss, test_label_loss, test_total_loss, test_accuracy

    def test_epoch_calculate_representation_separation(self, test_data):
    # Set evaluation mode for encoder and decoder
        self.eval()
        with torch.no_grad(): # No need to track the gradients
            accuracies = []
            for anchor_ims, contrast_ims, labels in test_data:
                # Move tensor to the proper device
                anchor_ims = anchor_ims.to(device)
                contrast_ims = contrast_ims.to(device)
                # labels = F.one_hot(labels, num_classes=self.num_classes)
                # labels = labels.to(device)
                anchor_latent, _, _, _ = self.forward(anchor_ims, anchor_ims,contrast_ims) 
                # use sklearn to predict labels from anchor_latent
                # calculate accuracy
                # x's are anchor_latent and y's are labels
                # append accuracy to list
                # put anchor_latent and labels on cpu and convert to numpy

          
                anchor_latent = anchor_latent.cpu().numpy()
                ### standard scale the data in anchor_latent before fitting to the model
                anchor_latent = StandardScaler().fit_transform(anchor_latent)
                labels = labels.cpu().numpy()
                
                lm = linear_model.LogisticRegression()
                lm.fit(anchor_latent, labels)
                # convert labels to sklearn format
                accuracies.append(lm.score(anchor_latent, labels))
        accuracy = sum(accuracies)/len(accuracies)
        return accuracy

    def train_epoch(self, train_data, optimizer, train_mode):
        self.train()
        train_triplet_loss = []
        train_label_loss = []
        train_total_loss = []
        correct = 0
        total = 0
        for anchor_ims, contrast_ims, labels in train_data:
            
            anchor_ims = anchor_ims.to(device)
            contrast_ims = contrast_ims.to(device)
            labels = F.one_hot(labels, num_classes=self.num_classes)
            labels = labels.to(device)

            optimizer.zero_grad()
            anchor_latent, positive_latent, negative_latent, pred_label = self.forward(anchor_ims, anchor_ims,contrast_ims) 
           
           
           
            triplet_loss, label_loss, total_loss = self.custom_loss(anchor_latent,
                                                                positive_latent, 
                                                                negative_latent, 
                                                                labels,
                                                                pred_label)
            
            
            if train_mode==0:
                triplet_loss.backward()
            elif train_mode==1:
                label_loss.backward()
            elif train_mode==2:
                total_loss.backward()

            optimizer.step()
            train_triplet_loss.append(triplet_loss.item())
            train_label_loss.append(label_loss.item())
            train_total_loss.append(total_loss.item())
            total += labels.size(0)
            correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(labels, dim = 1)).sum().item()
        train_triplet_loss = sum(train_triplet_loss)/len(train_triplet_loss)
        train_label_loss = sum(train_label_loss)/len(train_label_loss)
        train_total_loss = sum(train_total_loss)/len(train_total_loss)
        train_accuracy = correct/total
        return train_triplet_loss, train_label_loss, train_total_loss, train_accuracy

    def training_loop(self, train_data, test_data,train_mode,
                      epochs, optimizer):
        train_losses = []
        val_losses = []
        train_triplet_losses = []
        val_triplet_losses = []
        train_label_losses = []
        val_label_losses = []
        train_accuracies = []
        val_accuracies = []
        latent_separation_accuracy = 0
        for epoch in tqdm(range(epochs)):
          train_triplet_loss, train_label_loss, train_total_loss, train_accuracy =self.train_epoch(train_data, optimizer, 
                                             train_mode)
          test_triplet_loss, test_label_loss, test_total_loss, test_accuracy = self.test_epoch(test_data)
          separation_accuracy = self.test_epoch_calculate_representation_separation(test_data)
          train_losses.append(train_total_loss)
          val_losses.append(test_total_loss)
          train_triplet_losses.append(train_triplet_loss)
          val_triplet_losses.append(test_triplet_loss)
          train_label_losses.append(train_label_loss)
          val_label_losses.append(test_label_loss)
          train_accuracies.append(train_accuracy)
          val_accuracies.append(test_accuracy)
          wandb.log({"train triplet loss": train_triplet_loss, 
            "train label loss":train_label_loss, 
            "validation triplet loss":test_triplet_loss, 
            "validation label loss":test_label_loss, 
            "total train loss":train_total_loss, 
            "total validation loss":test_total_loss, 
            "train label accuracy":train_accuracy, 
            "validation label accuracy":test_accuracy,
            'latent separation accuracy':separation_accuracy})
        return train_triplet_losses, train_label_losses, val_triplet_losses, val_label_losses ,train_losses, val_losses, train_accuracies, val_accuracies


In [28]:
set_A_ims = np.load(os.path.join(data_dir, 'set_A.npy'))
set_B_ims = np.load(os.path.join(data_dir, 'set_B.npy'))
set_C_ims = np.load(os.path.join(data_dir, 'set_C.npy'))
set_A_labs = np.load(os.path.join(data_dir, 'set_A_labs.npy'))
set_B_labs = np.load(os.path.join(data_dir, 'set_B_labs.npy'))
set_C_labs = np.load(os.path.join(data_dir, 'set_C_labs.npy'))


In [29]:
dset_size = 30

In [30]:
150/600

0.25

In [31]:
# 40:60


In [32]:

set_A_sub_ims =[]
set_B_sub_ims =[]
set_C_sub_ims =[]

set_A_sub_labs =[]
set_B_sub_labs =[]
set_C_sub_labs =[]


for i in range (4):
    sub_main = set_A_ims[i*600:(i*600)+600]
    labels_main = set_A_labs[i*600:(i*600)+600]
    np.random.seed(711)
    np.random.shuffle(sub_main)
    np.random.seed(711)
    np.random.shuffle(labels_main)

    set_A_sub_ims.append(sub_main[:30])
    set_B_sub_ims.append(sub_main[:15])
    set_B_sub_ims.append(sub_main[30:45])
    set_C_sub_ims.append(sub_main[35:65])

    set_A_sub_labs.append(labels_main[:30])
    set_B_sub_labs.append(labels_main[:15])
    set_B_sub_labs.append(labels_main[30:45])
    set_C_sub_labs.append(labels_main[35:65])


    




##flatten set_A_sub_ims into an array of shape 120,64,64,3
set_A_sub_ims = np.concatenate(set_A_sub_ims)
set_B_sub_ims = np.concatenate(set_B_sub_ims)
set_C_sub_ims = np.concatenate(set_C_sub_ims)

set_A_sub_labs = np.concatenate(set_A_sub_labs)
set_B_sub_labs = np.concatenate(set_B_sub_labs)
set_C_sub_labs = np.concatenate(set_C_sub_labs)


A-B: 50% \
A-C: 0% \
B-C: 33.33%

In [33]:

###initialize weights and bias tracking
def wandb_init(epochs, lr, train_mode, batch_size, model_number,data_set):
  wandb.init(project="ConceptualAlignment2023", entity="psych-711",settings=wandb.Settings(start_method="thread"))
  wandb.config = {
    "learning_rate": lr,
    "epochs": epochs,
    "batch_size": batch_size, 
    # "label_ratio":label_ratio, 
    "model_number": model_number,
    "dataset": data_set,
    "train_mode":train_mode,
  }
  train_mode_dict = {0:'triplet', 1:'label', 2:'label_and_triplet'}
  wandb.run.name = f'{data_set}_{train_mode_dict[train_mode]}_{model_number}'
  wandb.run.save()
     

In [34]:

def main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims):
  if os.path.isdir(save_dir):
    pass
  else:
    os.mkdir(save_dir)


  # test_intervals = [(540, 600), (1140, 1200), (1740, 1800), (2340, 2400)]
  test_intervals = [(25, 30), (55, 60), (85, 90), (115, 120)]
  # initialize an empty list to hold the indices
  val_indices = []

  # loop through the intervals and append the indices to the list
  for start, stop in test_intervals:
      val_indices.extend(list(range(start, stop)))

  # train_indices = (np.setdiff1d(np.arange(2400),np.array(val_indices)))
  train_indices = (np.setdiff1d(np.arange(120),np.array(val_indices)))

  # np.random.seed(56)
  # contrast_indices  = np.concatenate((np.random.choice(np.arange(start=600, stop=2400), 600, replace=False),
  #               np.random.choice(np.concatenate((np.arange(start=0, stop=600), np.arange(start=1200, stop=2400))), 600, replace=False),
  #               np.random.choice(np.concatenate((np.arange(start=0, stop=1200), np.arange(start=1800, stop=2400))), 600, replace=False),
  #               np.random.choice(np.arange(start=1800, stop=2400), 600, replace=False)))
  contrast_indices  = np.concatenate((np.random.choice(np.arange(start=30, stop=120), 30, replace=False),
                np.random.choice(np.concatenate((np.arange(start=0, stop=30), np.arange(start=60, stop=120))), 30, replace=False),
                np.random.choice(np.concatenate((np.arange(start=0, stop=60), np.arange(start=90, stop=120))), 30, replace=False),
                np.random.choice(np.arange(start=0, stop=90), 30, replace=False)))

  for data_set in ['set_A','set_B','set_C']:
    for train_mode in tqdm(range(3)):
     # torch.manual_seed(0)
      for model in range(num_models):
        wandb_init(epochs, lr, train_mode, batch_size, model,data_set)

        # if data_set=='set_A':
        #   train_data = TensorDataset(Resize(32)(torch.tensor(set_A_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_A_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
        #                              torch.tensor(set_A_labs).to(torch.int64))
        # elif data_set=='set_B':
        #   train_data = TensorDataset(Resize(32)(torch.tensor(set_B_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_B_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
        #                              torch.tensor(set_B_labs).to(torch.int64))
        # elif data_set=='set_C':
        #   train_data = TensorDataset(Resize(32)(torch.tensor(set_C_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_C_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
        #                              torch.tensor(set_C_labs).to(torch.int64))

        if data_set=='set_A':
          train_data = TensorDataset(Resize(32)(torch.tensor(set_A_sub_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_A_sub_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
                                     torch.tensor(set_A_sub_labs).to(torch.int64))
        elif data_set=='set_B':
          train_data = TensorDataset(Resize(32)(torch.tensor(set_B_sub_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_B_sub_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
                                     torch.tensor(set_B_sub_labs).to(torch.int64))
        elif data_set=='set_C':
          train_data = TensorDataset(Resize(32)(torch.tensor(set_C_sub_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_C_sub_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
                                     torch.tensor(set_C_sub_labs).to(torch.int64))
          
        val_data = torch.utils.data.Subset(train_data, val_indices)
        train_data = torch.utils.data.Subset(train_data, train_indices)
       

        train_data = torch.utils.data.DataLoader(train_data, 
                                                batch_size=batch_size,
                                              shuffle=True)
        val_data = torch.utils.data.DataLoader(val_data, 
                                                batch_size=batch_size,
                                              shuffle=True)
        
     

        train_obj = TrainModels(latent_dims, num_classes).to(device) # GPU
        optimizer = torch.optim.Adam(train_obj.parameters(), lr=lr, weight_decay=1e-05)
        train_triplet_losses, train_label_losses, \
          val_triplet_losses, val_label_losses, \
            train_losses, val_losses, train_accuracies, val_accuracies= train_obj.training_loop(train_data = train_data,
                                                            test_data = val_data,
                                                            epochs = epochs,
                                                            optimizer = optimizer, 
                                                            train_mode = train_mode)




        print('validation triplet loss:',val_triplet_losses,'validation total loss:',val_losses,'validation accuracy:',val_accuracies)
        # wandb.log({"train_img_loss": train_img_loss, 
        #           "train_label_loss":train_label_loss, 
        #           "val_img_loss":val_img_loss, 
        #           "val_label_loss":val_label_loss, 
        #           "train_losses":train_losses, 
        #           "val_losses":val_losses, 
        #           "train_accuracy":train_accuracy, 
        #           "val_accuracy":val_accuracy})
        train_mode_dict = {0:'triplet', 1:'label',2:'label_and_triplet' }
        torch.save(train_obj.triplet_lab_model.state_dict(), os.path.join(save_dir,f'{data_set}_{train_mode_dict[train_mode]}_{model}.pth'))
        
        



In [35]:
wandb.finish()

num_classes = 4 # Number of unique class labels in the dataset
latent_dims = 64
epochs = 300
lr = 0.005
num_models = 1
batch_size = 256
save_dir = save_dir
main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims)
wandb.finish()

  0%|          | 0/3 [00:00<?, ?it/s]

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666875653900206, max=1.0)…

100%|██████████| 300/300 [00:06<00:00, 45.84it/s]
 33%|███▎      | 1/3 [00:10<00:21, 10.68s/it]

validation triplet loss: [0.6863194108009338, 0.7443161010742188, 0.8009245991706848, 0.7991935014724731, 0.7619277834892273, 0.6999549865722656, 0.6141796112060547, 0.550216019153595, 0.48668718338012695, 0.4324496388435364, 0.39540576934814453, 0.36474737524986267, 0.31269606947898865, 0.2687753736972809, 0.23903194069862366, 0.21765242516994476, 0.19601447880268097, 0.1802210658788681, 0.17262884974479675, 0.16479651629924774, 0.15351535379886627, 0.14423082768917084, 0.139659583568573, 0.1307387351989746, 0.12366408109664917, 0.11973211914300919, 0.11286409199237823, 0.10424084961414337, 0.09789971262216568, 0.09331382811069489, 0.090701162815094, 0.08917014300823212, 0.08837150782346725, 0.08392061293125153, 0.08462650328874588, 0.0861627385020256, 0.08798081427812576, 0.08968912810087204, 0.09120944887399673, 0.09260817617177963, 0.09376010298728943, 0.09473182260990143, 0.09560619294643402, 0.09643100947141647, 0.09725721925497055, 0.09817173331975937, 0.09917701780796051, 0.100

0,1
latent separation accuracy,██████▅▅▅███▁███████████████████████████
total train loss,█▃▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▃▁▁▂▁▃▄▄▂▂▁▂▂▂▂▃▃▃▅▆▆▆▅
total validation loss,█▅▃▂▁▁▁▂▂▂▂▂▂▂▂▂▂▁▁▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂
train label accuracy,▅▁▂▃▃▂▃▃▄▃▃▄▄▄▂▃▅▅█▆▆▇▅▅▆▅▇▇▅▄▅▆▅▅▅▄▃▄▄▄
train label loss,▁▄▃▃▃▃▃▃▃▄▄▅▄▅▅▅▆▄▁▂▂▂▄▅▅▃▃▂▃▃▂▃▄▅▄▆███▆
train triplet loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▅▅▁▁▃▅▅▃▅▆▆▅▃▅▃▃▅█████▃▆▁▃▃▆▆▃▃▃▃▃▁▅▅▅▅▃
validation label loss,▅▁▄▅▄▄▄▄▄▄▄▅▅▅▅▄▃▂▁▂▂▄▆▅▄▃▄▄▃▃▄▄▅▅▆█▇▇▇▆
validation triplet loss,█▆▃▂▁▁▁▂▂▂▂▂▂▁▁▂▂▂▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂

0,1
latent separation accuracy,1.0
total train loss,0.74456
total validation loss,0.89352
train label accuracy,0.26
train label loss,0.74456
train triplet loss,0.0
validation label accuracy,0.2
validation label loss,0.76803
validation triplet loss,0.12549


100%|██████████| 300/300 [00:05<00:00, 57.14it/s]
 67%|██████▋   | 2/3 [00:24<00:12, 12.35s/it]

validation triplet loss: [0.7960683703422546, 0.7744645476341248, 0.7710610628128052, 0.7754284739494324, 0.7944244742393494, 0.8154653906822205, 0.8320245146751404, 0.8466028571128845, 0.8515462875366211, 0.8392183184623718, 0.8445754051208496, 0.8067044615745544, 0.7654900550842285, 0.7485978007316589, 0.6725867390632629, 0.6060080528259277, 0.612507164478302, 0.6231316924095154, 0.6411065459251404, 0.6418043971061707, 0.6141659617424011, 0.5742014646530151, 0.5387457013130188, 0.5102128982543945, 0.49110427498817444, 0.47888800501823425, 0.4699229896068573, 0.46321240067481995, 0.4582049548625946, 0.4542270600795746, 0.450664758682251, 0.4477056562900543, 0.4449353814125061, 0.4423947334289551, 0.44031858444213867, 0.43868574500083923, 0.43748483061790466, 0.43679556250572205, 0.43629246950149536, 0.4357615113258362, 0.4352361261844635, 0.4345637857913971, 0.43380728363990784, 0.43307915329933167, 0.4323634207248688, 0.43170738220214844, 0.431183785200119, 0.4308241009712219, 0.4305

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▇▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,▄█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train label accuracy,▁▄▇█████████████████████████████████████
train label loss,█▆▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,▇█▅▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▁▁▇▆████████████████████████████████████
validation label loss,▄█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation triplet loss,▇█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.40613
total validation loss,0.52808
train label accuracy,1.0
train label loss,0.0
train triplet loss,0.40613
validation label accuracy,0.95
validation label loss,0.10827
validation triplet loss,0.4198


100%|██████████| 300/300 [00:06<00:00, 47.41it/s]
100%|██████████| 3/3 [00:38<00:00, 12.84s/it]


validation triplet loss: [0.6529640555381775, 0.6671941876411438, 0.7182641625404358, 0.731415867805481, 0.7464177012443542, 0.760253369808197, 0.7508258819580078, 0.7437281012535095, 0.7482518553733826, 0.7316343188285828, 0.6754360795021057, 0.5625578761100769, 0.45095378160476685, 0.3391636908054352, 0.24992278218269348, 0.18862663209438324, 0.13574470579624176, 0.09639106690883636, 0.09756048023700714, 0.10287284851074219, 0.08984506875276566, 0.07306090742349625, 0.0625690445303917, 0.07002898305654526, 0.08110565692186356, 0.07455454021692276, 0.0647224560379982, 0.05397266149520874, 0.044199422001838684, 0.03803696110844612, 0.032941784709692, 0.02873416244983673, 0.026792610064148903, 0.025040218606591225, 0.02375943772494793, 0.02269221656024456, 0.02164366841316223, 0.020128443837165833, 0.018858028575778008, 0.017721889540553093, 0.016619166359305382, 0.015707416459918022, 0.015247980132699013, 0.014835000038146973, 0.014401319436728954, 0.014109528623521328, 0.0116396397352

  0%|          | 0/3 [00:00<?, ?it/s]

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,▆█▃▂▂▁▁▁▁▁▁▁▁▁▂▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train label accuracy,▁▄▇█████████████████████████████████████
train label loss,█▆▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▂▁▅▇▇▇▇▇▇█▇▇▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
validation label loss,▅█▄▂▂▂▁▁▁▁▁▁▁▁▃▁▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation triplet loss,██▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.0
total validation loss,0.07495
train label accuracy,1.0
train label loss,0.0
train triplet loss,0.0
validation label accuracy,0.95
validation label loss,0.04125
validation triplet loss,0.03371


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668504942208527, max=1.0…

100%|██████████| 300/300 [00:06<00:00, 47.11it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.684748649597168, 0.6086589694023132, 0.6142129302024841, 0.6361970901489258, 0.6346265077590942, 0.6169469952583313, 0.6127795577049255, 0.5802468657493591, 0.5542241930961609, 0.5163036584854126, 0.472921758890152, 0.4259536862373352, 0.3859459459781647, 0.357585608959198, 0.34205397963523865, 0.32246044278144836, 0.3013746440410614, 0.288934588432312, 0.2794848382472992, 0.2813207805156708, 0.2852117717266083, 0.28095993399620056, 0.28175702691078186, 0.2826051414012909, 0.2821059823036194, 0.28391847014427185, 0.27631449699401855, 0.26772069931030273, 0.26132073998451233, 0.2558583915233612, 0.25218912959098816, 0.2514455318450928, 0.250297874212265, 0.24932371079921722, 0.24879947304725647, 0.24836821854114532, 0.24800460040569305, 0.24692454934120178, 0.24767647683620453, 0.2455265074968338, 0.24466359615325928, 0.24909289181232452, 0.2529143989086151, 0.25551578402519226, 0.25534358620643616, 0.25432756543159485, 0.2476930171251297, 0.24171052873134613

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,▇▅▄▆▇█▇▇█▆▆▅▅▃▁▂▄▃▃▂▂▃▄▄▆▇████▆▆▆▆▄▃▅▇▇▇
total validation loss,██▄▄▃▃▃▃▃▃▂▂▂▂▂▃▃▂▁▁▁▃▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▂▂▂
train label accuracy,▆▇▇▆▆▆▆▆▃▆▆▅▆▅▁▄▄▄▄▄▇▄▂▁▆▅▆▆▆▆▅▃▂▄█▄▅▅▅▆
train label loss,▁▃▅▆▇████▇▆▆▆▄▃▄▅▅▅▄▄▅▅▅▇█████▇▇▇▇▆▅▆▇▇▇
train triplet loss,█▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▇█▇▇▇█▇▇▇▇▇▅█▇█▇▇▇█▇▇▇▄▁▄▇▇▅▇▄█▇▅▄▅▅▇▅▅▅
validation label loss,▂▅▅▆▆▇▇▇█▇▆▅▄▃▂▅▄▃▁▁▃▄▂▃▄▄▅▅▆▆▇▅▆▆▆▆▆▅▄▃
validation triplet loss,█▇▄▃▃▃▂▂▂▂▂▂▂▂▂▃▃▂▁▁▁▂▂▂▂▂▂▂▂▂▃▂▂▂▃▂▂▂▂▂

0,1
latent separation accuracy,1.0
total train loss,0.73071
total validation loss,0.89543
train label accuracy,0.25
train label loss,0.73071
train triplet loss,0.0
validation label accuracy,0.2
validation label loss,0.72221
validation triplet loss,0.17322


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668725572526456, max=1.0…

100%|██████████| 300/300 [00:05<00:00, 55.69it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.8203703761100769, 0.7507730722427368, 0.7528653144836426, 0.7605627179145813, 0.748887836933136, 0.745153546333313, 0.7679628729820251, 0.7537299990653992, 0.7212414145469666, 0.7650923132896423, 0.7430674433708191, 0.7356300354003906, 0.7206581234931946, 0.6830807328224182, 0.6447067260742188, 0.6178328394889832, 0.5945550799369812, 0.5783262252807617, 0.5642932653427124, 0.5524737238883972, 0.548245370388031, 0.5425359606742859, 0.5344192385673523, 0.5275249481201172, 0.5229005813598633, 0.51861971616745, 0.5137683749198914, 0.5085318684577942, 0.5034374594688416, 0.498715877532959, 0.49484559893608093, 0.4915456771850586, 0.48870736360549927, 0.4865911602973938, 0.485175222158432, 0.48451921343803406, 0.48403388261795044, 0.48384523391723633, 0.48392221331596375, 0.48430272936820984, 0.4845491945743561, 0.4844597280025482, 0.4840044677257538, 0.4832061231136322, 0.48209354281425476, 0.4807193875312805, 0.4793531894683838, 0.4779515266418457, 0.47672587633

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▅▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,█▇▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train label accuracy,▁▇██████████████████████████████████████
train label loss,█▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,▁▅▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
validation label accuracy,▁▃▆████████████████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
validation label loss,█▆▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation triplet loss,██▅▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.48309
total validation loss,0.49985
train label accuracy,1.0
train label loss,0.0
train triplet loss,0.48309
validation label accuracy,0.95
validation label loss,0.05474
validation triplet loss,0.44511


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668480075895786, max=1.0…

100%|██████████| 300/300 [00:06<00:00, 46.36it/s]
  lambda data: self._console_raw_callback("stderr", data),
100%|██████████| 3/3 [00:41<00:00, 13.88s/it]


validation triplet loss: [0.7163375616073608, 0.7155418395996094, 0.717268168926239, 0.6877397894859314, 0.6773072481155396, 0.665557324886322, 0.6049054265022278, 0.5830373167991638, 0.5900408625602722, 0.5938760638237, 0.5966400504112244, 0.5866459608078003, 0.5334599018096924, 0.47422319650650024, 0.404109925031662, 0.3126247823238373, 0.2577734887599945, 0.22873620688915253, 0.2374824732542038, 0.24769015610218048, 0.25351324677467346, 0.24947679042816162, 0.2531905770301819, 0.2607647180557251, 0.24932144582271576, 0.22204045951366425, 0.19891981780529022, 0.1776742935180664, 0.1420324146747589, 0.12342579662799835, 0.11009359359741211, 0.09697508066892624, 0.08807722479104996, 0.08507118374109268, 0.07909972220659256, 0.07401665300130844, 0.06898420304059982, 0.06346767395734787, 0.05905212089419365, 0.05649843439459801, 0.05681758001446724, 0.058299969881772995, 0.06083018332719803, 0.05715053901076317, 0.053660690784454346, 0.05146346241235733, 0.04752609506249428, 0.0461118519

  0%|          | 0/3 [00:00<?, ?it/s]

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▆▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,█▇▂▁▁▁▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
train label accuracy,▁▃▇█████████████████████████████████████
train label loss,█▆▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▂▁▆███▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇██
validation label loss,▆▅▂▁▃▄▅▆▅▇████████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
validation triplet loss,█▇▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.0
total validation loss,0.69848
train label accuracy,1.0
train label loss,0.0
train triplet loss,0.0
validation label accuracy,0.85
validation label loss,0.65372
validation triplet loss,0.04476


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668398507560293, max=1.0…

100%|██████████| 300/300 [00:06<00:00, 46.90it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.5521707534790039, 0.39689046144485474, 0.3344293534755707, 0.29128164052963257, 0.33400851488113403, 0.36135396361351013, 0.3918306827545166, 0.41050729155540466, 0.43030744791030884, 0.4535525441169739, 0.4388657510280609, 0.4184727370738983, 0.37934380769729614, 0.3518800735473633, 0.31744351983070374, 0.2780458629131317, 0.24228592216968536, 0.22208920121192932, 0.20546482503414154, 0.1918599158525467, 0.1802792102098465, 0.17060087621212006, 0.16149449348449707, 0.15760014951229095, 0.1534852534532547, 0.14911124110221863, 0.14876142144203186, 0.14749346673488617, 0.13666729629039764, 0.13072586059570312, 0.1262650489807129, 0.1229868158698082, 0.12063316255807877, 0.11903542280197144, 0.11825302988290787, 0.11760518699884415, 0.11701243370771408, 0.11648907512426376, 0.1127203106880188, 0.11382788419723511, 0.11742427200078964, 0.11362522840499878, 0.12459228187799454, 0.1328897774219513, 0.13820065557956696, 0.14113914966583252, 0.14280973374843597, 0.

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,▇▂▃▃▃▃▃▄▄▃▃▄▅▇█▇▆▄▃▃▃▃▃▄▄▃▃▃▄▄▄▇▆▂▁▂▅▅▅▅
total validation loss,▆█▆▄▃▃▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▃▂▂▂▂▃▂▃▂▁▁▁▂▂▂▃
train label accuracy,▅▃▄▅▃▃▅▅▃▄▅▃▄▆▆▃▃▃▃▃▃▃▄▄▄▃▄▅▄▂▃▁▅▇▆▅▇█▆█
train label loss,▂▂▃▃▃▃▃▄▄▃▃▄▅▇█▇▆▄▃▃▃▃▃▄▄▃▃▃▄▄▄▇▆▂▁▂▅▅▅▅
train triplet loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▄▄██▆▆▄▄▄▆██▆▆▄▄▁▄▄▆▆▆▆▄▄▄▃▄▄▁▄▄▄▃▃▄▆▄█▄
validation label loss,▁▃▃▅▆▇▇█▇▇▅▂▃▄▅▅▅▄▅▅▄▄▄▅▅▃▃▃▃▃▅▆▄▂▁▁▃▄▄▄
validation triplet loss,▆█▆▃▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▂▂▂

0,1
latent separation accuracy,1.0
total train loss,0.71054
total validation loss,0.81147
train label accuracy,0.33
train label loss,0.71054
train triplet loss,0.0
validation label accuracy,0.2
validation label loss,0.71246
validation triplet loss,0.099


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666849033596615, max=1.0)…

100%|██████████| 300/300 [00:05<00:00, 55.77it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.6911994218826294, 0.7093791365623474, 0.7188628911972046, 0.7621520161628723, 0.8047378659248352, 0.8175811767578125, 0.8054048418998718, 0.7805483937263489, 0.7308078408241272, 0.678548276424408, 0.6232680678367615, 0.5949092507362366, 0.5672026872634888, 0.5422172546386719, 0.5129274129867554, 0.4805111587047577, 0.46587324142456055, 0.4545253813266754, 0.444928377866745, 0.4364512860774994, 0.4303024709224701, 0.42738398909568787, 0.42585378885269165, 0.4247901141643524, 0.4229370057582855, 0.4212643802165985, 0.41916584968566895, 0.4169701635837555, 0.4144483506679535, 0.4122041165828705, 0.41043004393577576, 0.4089180529117584, 0.40731629729270935, 0.4059363901615143, 0.4047562777996063, 0.4040282666683197, 0.40367060899734497, 0.40344902873039246, 0.403303861618042, 0.4032896161079407, 0.4033872187137604, 0.4036906361579895, 0.4040639102458954, 0.40447521209716797, 0.4048921763896942, 0.40532001852989197, 0.4057399332523346, 0.40611791610717773, 0.4065

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,▇█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train label accuracy,▁▆██████████████████████████████████████
train label loss,█▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,▁▅▇███████████████████████████▇▇▇▇▇▇▇▇▇▇
validation label accuracy,▁▂▇█████████████████████████████████████
validation label loss,▆█▂▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation triplet loss,██▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.38306
total validation loss,0.44254
train label accuracy,1.0
train label loss,0.0
train triplet loss,0.38306
validation label accuracy,0.95
validation label loss,0.03391
validation triplet loss,0.40863


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668500006198884, max=1.0…

100%|██████████| 300/300 [00:05<00:00, 53.29it/s]
  lambda data: self._console_raw_callback("stderr", data),
100%|██████████| 3/3 [00:42<00:00, 14.14s/it]

validation triplet loss: [0.5582502484321594, 0.47157904505729675, 0.4538884162902832, 0.4453936517238617, 0.446988970041275, 0.4235537648200989, 0.4044879972934723, 0.40253233909606934, 0.4473557472229004, 0.47154757380485535, 0.4861145615577698, 0.49492621421813965, 0.5103545188903809, 0.533054530620575, 0.5019141435623169, 0.4676712453365326, 0.4148188531398773, 0.3440546691417694, 0.3125035762786865, 0.2982857823371887, 0.26579615473747253, 0.2236388772726059, 0.173209547996521, 0.1281779408454895, 0.10198425501585007, 0.08558379858732224, 0.06799783557653427, 0.056807439774274826, 0.047674477100372314, 0.04102534055709839, 0.0364123210310936, 0.033108364790678024, 0.03365987166762352, 0.034381818026304245, 0.03562166914343834, 0.03647340461611748, 0.037025269120931625, 0.03702318295836449, 0.04084473475813866, 0.043173421174287796, 0.04443296417593956, 0.044872500002384186, 0.04020344465970993, 0.035594258457422256, 0.030746012926101685, 0.02761121653020382, 0.02535751461982727, 0




0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,█▇█▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train label accuracy,▁▄██████████████████████████████████████
train label loss,█▆▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,█▂▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▁▂▄▆████████████████████████████████████
validation label loss,█▆█▅▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation triplet loss,█▇█▃▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.0
total validation loss,2e-05
train label accuracy,1.0
train label loss,0.0
train triplet loss,0.0
validation label accuracy,1.0
validation label loss,2e-05
validation triplet loss,0.0
