In [1]:
import os

# Project root: two directories above the notebook's working directory.
base_dir = os.path.abspath('../..')
# <root>/results receives trained weights; <root>/data holds the input arrays.
save_dir, data_dir = (os.path.join(base_dir, sub_dir) for sub_dir in ('results', 'data'))

In [2]:
import torch
# torch.manual_seed(0)
import wandb
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
from torch.utils.data import TensorDataset,Dataset
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.transforms import Resize

import torchvision
from tqdm import tqdm
import random
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200


from sklearn import linear_model
from sklearn.preprocessing import StandardScaler

# from neurora.rdm_corr import rdm_correlation_spearman

In [3]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Define model

In [4]:
class TripletLabelModel(nn.Module):
    """Convolutional image encoder with a label head (plus an unused triplet head).

    ``forward`` maps a batch of RGB images to ``(latent, logits)``. ``latent`` is the
    ``encoded_space_dim``-dimensional encoder output. ``logits`` are the raw
    (un-normalised) ``num_classes`` scores of the label head.

    The flatten size of the linear encoder (32*2*2) matches 3x32x32 inputs. The
    three stride-2 convolutions take 32 -> 16 -> 8 -> 4, and the 2x2 max-pool takes
    4 -> 2.

    Args:
        encoded_space_dim: size of the latent representation.
        num_classes: number of outputs of the label head.
    """

    def __init__(self, encoded_space_dim=64, num_classes=10):
        super().__init__()
        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)

        ### Linear section: the input is 32 channels x 2 x 2 after the CNN
        ### (was 32*4*4 in an earlier architecture).
        self.encoder_lin = nn.Sequential(
            nn.Linear(32*2*2, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.Linear(128, encoded_space_dim)
        )

        ### Triplet projection module.
        # Not used by forward(). It is kept so that the parameter names still match
        # the pretrained checkpoints loaded with load_state_dict.
        self.decoder_triplet_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 32),
            nn.ReLU(True),
        )

        ### Labeling module
        self.decoder_labels_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 32),
            nn.ReLU(True),
            nn.Linear(32, 16),
            nn.ReLU(True),
            nn.Linear(16, num_classes),
        )

        ### Xavier-initialise conv/linear weights and zero their biases. BatchNorm2d
        ### is set to weight 1 / bias 0. BatchNorm1d keeps PyTorch's default init.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, y=None):
        """Encode ``x`` and classify the latent.

        Args:
            x: image batch; the layer sizes imply shape (batch, 3, 32, 32).
            y: unused; accepted only for call-site compatibility.

        Returns:
            ``(enc_latent, label)``: shapes (batch, encoded_space_dim) and
            (batch, num_classes). ``label`` holds raw logits (no softmax applied).
        """
        img_features = self.flatten(self.encoder_cnn(x))
        enc_latent = self.encoder_lin(img_features)
        label = self.decoder_labels_lin(enc_latent)
        return enc_latent, label

In [5]:
### custom loss computing triplet loss and labeling loss


class CustomLoss(nn.Module):
    """Triplet (cosine-distance) loss plus a label loss.

    ``forward`` returns ``(triplet_loss, label_loss, triplet_loss + label_loss)``.

    * Triplet loss: ``nn.TripletMarginWithDistanceLoss`` with distance
      ``1 - cos(a, b)`` along dim 1, margin 1.0 and mean reduction.
    * Label loss: ``F.binary_cross_entropy_with_logits`` between the logits and
      the (one-hot) targets, i.e. an independent sigmoid per class.

    Args:
        margin: stored as ``self.margin`` but not applied (see note below).
    """

    def __init__(self, margin=10):
        super(CustomLoss, self).__init__()
        # NOTE(review): the triplet loss uses a fixed margin of 1.0, which is PyTorch's
        # default for this loss. ``self.margin`` is kept only for backward
        # compatibility. A margin of 10 would exceed the maximum cosine distance of 2,
        # so the hinge would never be inactive.
        self.margin = margin
        # Unused (the label loss is BCE-with-logits); kept as an attribute for compatibility.
        self.cross_entropy = nn.CrossEntropyLoss()
        # Built once here rather than on every forward call. It has no parameters,
        # so it does not change any state_dict.
        self.triplet_loss_fn = nn.TripletMarginWithDistanceLoss(
            distance_function=self._cosine_distance, margin=1.0)

    @staticmethod
    def _cosine_distance(x, y):
        """Cosine distance ``1 - cos(x, y)`` computed along dim 1 (range [0, 2])."""
        return 1.0 - F.cosine_similarity(x, y)

    def forward(self, anchor, positive, negative, label, pred_label):
        """Compute the triplet, label and total losses.

        Args:
            anchor, positive, negative: latent batches of identical shape (batch, dim).
            label: target tensor with the same shape as ``pred_label`` (one-hot in
                this notebook); cast to float.
            pred_label: label-head logits; cast to float.

        Returns:
            ``(triplet_loss, label_loss, total_loss)`` as scalar tensors.
        """
        triplet_loss = self.triplet_loss_fn(anchor, positive, negative)
        label_loss = F.binary_cross_entropy_with_logits(pred_label.float(), label.float())
        total_loss = triplet_loss + label_loss
        return triplet_loss, label_loss, total_loss

In [6]:
# Sanity check: the pretrained CIFAR-10 checkpoint loads into a fresh model.
# map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only machine.
cifar_model_path = os.path.join(data_dir, 'CIFAR10_NCE_i_1e-05_50.pth')
t = TripletLabelModel()
t.load_state_dict(torch.load(cifar_model_path, map_location='cpu'))



<All keys matched successfully>

### Training functions

In [7]:


class TrainModels(nn.Module):
    """Fine-tunes the pretrained CIFAR-10 ``TripletLabelModel`` on a new label set.

    The checkpoint is loaded into ``TripletLabelModel(latent_dims, 10)``. Its final
    label layer is then replaced by a freshly initialised ``nn.Linear(16,
    num_classes)``. The train/eval helpers below log per-epoch metrics to wandb.

    Args:
        latent_dims: latent size; must match the checkpoint, otherwise
            ``load_state_dict`` raises.
        num_classes: number of classes of the new label head.
        cifar_model_path: path of the pretrained CIFAR-10 state dict.
    """

    # train_mode -> which CustomLoss output train_epoch back-propagates.
    _BACKPROP_LOSS = {0: 'triplet', 1: 'label', 2: 'total'}

    def __init__(self, latent_dims, num_classes,
                 cifar_model_path='../../data/CIFAR10_NCE_i_1e-05_50.pth'):
        super(TrainModels, self).__init__()
        self.triplet_lab_model = TripletLabelModel(latent_dims, 10)  # CIFAR-10 head shape
        # Load on CPU; the caller moves the whole module to `device` afterwards.
        self.triplet_lab_model.load_state_dict(
            torch.load(cifar_model_path, map_location='cpu'))
        self.triplet_lab_model.decoder_labels_lin[4] = nn.Linear(16, num_classes)
        self.custom_loss = CustomLoss()
        self.num_classes = num_classes

    @staticmethod
    def _check_train_mode(train_mode):
        """Raise ValueError unless train_mode is 0 (triplet), 1 (label) or 2 (both)."""
        if train_mode not in TrainModels._BACKPROP_LOSS:
            raise ValueError(
                f'train_mode must be one of {sorted(TrainModels._BACKPROP_LOSS)}, '
                f'got {train_mode!r}')

    @staticmethod
    def _mean(values):
        """Arithmetic mean of ``values``; NaN (instead of ZeroDivisionError) if empty."""
        return sum(values) / len(values) if values else float('nan')

    def _prepare_batch(self, anchor_ims, contrast_ims, labels):
        """Move a batch to ``device`` and one-hot encode its integer labels."""
        labels = F.one_hot(labels, num_classes=self.num_classes)
        return anchor_ims.to(device), contrast_ims.to(device), labels.to(device)

    def _batch_losses(self, anchor_ims, contrast_ims, labels):
        """Forward one prepared batch.

        The anchors double as the positives (no augmentation), so the positive
        distance is ~0. The triplet term therefore only pushes the contrast
        images away from the anchors.

        Returns:
            ``(triplet_loss, label_loss, total_loss, n_correct)``.
        """
        anchor_latent, positive_latent, negative_latent, pred_label = self.forward(
            anchor_ims, anchor_ims, contrast_ims)
        triplet_loss, label_loss, total_loss = self.custom_loss(
            anchor_latent, positive_latent, negative_latent, labels, pred_label)
        n_correct = (torch.argmax(pred_label, dim=1)
                     == torch.argmax(labels, dim=1)).sum().item()
        return triplet_loss, label_loss, total_loss, n_correct

    def forward(self, anchor_im, positive_im, negative_im):
        """Encode the three image batches.

        Returns:
            ``(anchor_latent, positive_latent, negative_latent, anchor_label_logits)``.
        """
        anchor_latent, anchor_label = self.triplet_lab_model(anchor_im)
        positive_latent, _ = self.triplet_lab_model(positive_im)
        negative_latent, _ = self.triplet_lab_model(negative_im)
        return anchor_latent, positive_latent, negative_latent, anchor_label

    def test_epoch(self, test_data):
        """Evaluate on ``test_data`` (eval mode, no gradient tracking).

        Args:
            test_data: iterable of (anchor images, contrast images, int labels).

        Returns:
            ``(triplet_loss, label_loss, total_loss, accuracy)``. The losses are
            unweighted means over batches; accuracy is computed over samples.
            All values are NaN for an empty iterable.
        """
        self.eval()
        triplet_losses, label_losses, total_losses = [], [], []
        correct = total = 0
        with torch.no_grad():
            for batch in test_data:
                anchor_ims, contrast_ims, labels = self._prepare_batch(*batch)
                triplet_loss, label_loss, total_loss, n_correct = self._batch_losses(
                    anchor_ims, contrast_ims, labels)
                triplet_losses.append(triplet_loss.item())
                label_losses.append(label_loss.item())
                total_losses.append(total_loss.item())
                correct += n_correct
                total += labels.size(0)
        accuracy = correct / total if total else float('nan')
        return (self._mean(triplet_losses), self._mean(label_losses),
                self._mean(total_losses), accuracy)

    def test_epoch_calculate_representation_separation(self, test_data):
        """Linear separability of the anchor latents, averaged over batches.

        For each batch, the standard-scaled anchor latents are fitted with a
        scikit-learn ``LogisticRegression`` and scored on the same samples.

        NOTE(review): fitting and scoring on the same points measures training
        accuracy. With ~20 validation points in a 64-d latent this is essentially
        always 1.0, which is what the logged runs show.
        """
        self.eval()
        accuracies = []
        with torch.no_grad():
            for anchor_ims, _contrast_ims, labels in test_data:
                # In eval mode the anchor latent does not depend on the other
                # inputs, so only the anchors are encoded.
                anchor_latent, _ = self.triplet_lab_model(anchor_ims.to(device))
                features = StandardScaler().fit_transform(anchor_latent.cpu().numpy())
                targets = labels.cpu().numpy()
                classifier = linear_model.LogisticRegression()
                classifier.fit(features, targets)
                accuracies.append(classifier.score(features, targets))
        return self._mean(accuracies)

    def train_epoch(self, train_data, optimizer, train_mode):
        """Run one optimisation pass over ``train_data``.

        Args:
            train_data: iterable of (anchor images, contrast images, int labels).
            optimizer: optimizer over this module's parameters.
            train_mode: 0 back-propagates the triplet loss only, 1 the label loss
                only, 2 their sum. All three losses are reported regardless.

        Returns:
            ``(triplet_loss, label_loss, total_loss, accuracy)``, averaged as in
            ``test_epoch``.

        Raises:
            ValueError: if ``train_mode`` is not 0, 1 or 2.
        """
        self._check_train_mode(train_mode)
        self.train()
        triplet_losses, label_losses, total_losses = [], [], []
        correct = total = 0
        for batch in train_data:
            anchor_ims, contrast_ims, labels = self._prepare_batch(*batch)
            optimizer.zero_grad()
            triplet_loss, label_loss, total_loss, n_correct = self._batch_losses(
                anchor_ims, contrast_ims, labels)
            losses = {'triplet': triplet_loss, 'label': label_loss, 'total': total_loss}
            losses[self._BACKPROP_LOSS[train_mode]].backward()
            optimizer.step()
            triplet_losses.append(triplet_loss.item())
            label_losses.append(label_loss.item())
            total_losses.append(total_loss.item())
            correct += n_correct
            total += labels.size(0)
        accuracy = correct / total if total else float('nan')
        return (self._mean(triplet_losses), self._mean(label_losses),
                self._mean(total_losses), accuracy)

    def training_loop(self, train_data, test_data, train_mode,
                      epochs, optimizer):
        """Train for ``epochs`` epochs, evaluating and logging to wandb after each.

        Returns:
            Per-epoch lists, in this order: train triplet, train label, val triplet,
            val label, train total, val total, train accuracy, val accuracy.

        Raises:
            ValueError: if ``train_mode`` is not 0, 1 or 2 (checked before training).
        """
        self._check_train_mode(train_mode)
        train_losses, val_losses = [], []
        train_triplet_losses, val_triplet_losses = [], []
        train_label_losses, val_label_losses = [], []
        train_accuracies, val_accuracies = [], []
        for _ in tqdm(range(epochs)):
            (train_triplet_loss, train_label_loss,
             train_total_loss, train_accuracy) = self.train_epoch(train_data, optimizer,
                                                                  train_mode)
            (test_triplet_loss, test_label_loss,
             test_total_loss, test_accuracy) = self.test_epoch(test_data)
            separation_accuracy = self.test_epoch_calculate_representation_separation(test_data)
            train_losses.append(train_total_loss)
            val_losses.append(test_total_loss)
            train_triplet_losses.append(train_triplet_loss)
            val_triplet_losses.append(test_triplet_loss)
            train_label_losses.append(train_label_loss)
            val_label_losses.append(test_label_loss)
            train_accuracies.append(train_accuracy)
            val_accuracies.append(test_accuracy)
            wandb.log({"train triplet loss": train_triplet_loss,
                       "train label loss": train_label_loss,
                       "validation triplet loss": test_triplet_loss,
                       "validation label loss": test_label_loss,
                       "total train loss": train_total_loss,
                       "total validation loss": test_total_loss,
                       "train label accuracy": train_accuracy,
                       "validation label accuracy": test_accuracy,
                       'latent separation accuracy': separation_accuracy})
        return (train_triplet_losses, train_label_losses, val_triplet_losses,
                val_label_losses, train_losses, val_losses, train_accuracies,
                val_accuracies)


In [8]:
set_A_ims = np.load(os.path.join(data_dir, 'set_A.npy'))
set_B_ims = np.load(os.path.join(data_dir, 'set_B.npy'))
set_C_ims = np.load(os.path.join(data_dir, 'set_C.npy'))
set_A_labs = np.load(os.path.join(data_dir, 'set_A_labs.npy'))
set_B_labs = np.load(os.path.join(data_dir, 'set_B_labs.npy'))
set_C_labs = np.load(os.path.join(data_dir, 'set_C_labs.npy'))

# Later cells slice images and labels with the same indices, so each image
# array must have exactly one label per image.
assert len(set_A_ims) == len(set_A_labs), 'set_A: image/label count mismatch'
assert len(set_B_ims) == len(set_B_labs), 'set_B: image/label count mismatch'
assert len(set_C_ims) == len(set_C_labs), 'set_C: image/label count mismatch'


In [9]:
# 40:60


In [10]:

# Build three subsets from the first N_SUBSET_CLASSES class blocks of set_A.
# Each subset takes 30 images per class, with controlled overlap:
#   A = shuffled[:30]
#   B = shuffled[:15] + shuffled[30:45]   (15 of 30 shared with A)
#   C = shuffled[35:65]                   (disjoint from A, 10 of 30 shared with B)
IMAGES_PER_CLASS = 600
N_SUBSET_CLASSES = 4
SUBSET_SEED = 711

set_A_sub_ims, set_B_sub_ims, set_C_sub_ims = [], [], []
set_A_sub_labs, set_B_sub_labs, set_C_sub_labs = [], [], []

for class_idx in range(N_SUBSET_CLASSES):
    block = slice(class_idx * IMAGES_PER_CLASS, (class_idx + 1) * IMAGES_PER_CLASS)
    # Reorder through an index permutation instead of np.random.shuffle.
    # Shuffling the slice (a view) shuffled set_A_ims/set_A_labs in place, so
    # re-running this cell reshuffled already-shuffled data and silently produced
    # different subsets. Reseeding and drawing a permutation of the same length
    # consumes the global RNG exactly as seed + shuffle did. The subsets, and the
    # global RNG state seen by later cells, are therefore unchanged on a fresh kernel.
    np.random.seed(SUBSET_SEED)
    perm = np.random.permutation(len(set_A_ims[block]))
    class_ims = set_A_ims[block][perm]
    class_labs = set_A_labs[block][perm]

    set_A_sub_ims.append(class_ims[:30])
    set_B_sub_ims.append(class_ims[:15])
    set_B_sub_ims.append(class_ims[30:45])
    set_C_sub_ims.append(class_ims[35:65])

    set_A_sub_labs.append(class_labs[:30])
    set_B_sub_labs.append(class_labs[:15])
    set_B_sub_labs.append(class_labs[30:45])
    set_C_sub_labs.append(class_labs[35:65])

# Stack the per-class pieces into single arrays of 30 * N_SUBSET_CLASSES items
# (originally noted as shape 120, 64, 64, 3 for the images).
set_A_sub_ims = np.concatenate(set_A_sub_ims)
set_B_sub_ims = np.concatenate(set_B_sub_ims)
set_C_sub_ims = np.concatenate(set_C_sub_ims)

set_A_sub_labs = np.concatenate(set_A_sub_labs)
set_B_sub_labs = np.concatenate(set_B_sub_labs)
set_C_sub_labs = np.concatenate(set_C_sub_labs)


Fraction of images shared between subsets (30 images per class in each subset): \
A-B: 50% \
A-C: 0% \
B-C: 33.33%

In [11]:

###initialize weights and bias tracking
def wandb_init(epochs, lr, train_mode, batch_size, model_number, data_set):
    """Start a Weights & Biases run for one (dataset, train mode, model) combination.

    The hyper-parameters go to ``wandb.init(config=...)`` so they are recorded with
    the run. The previous ``wandb.config = {...}`` only rebound the module
    attribute to a plain dict, so the values never reached the run. The run name is
    ``{data_set}_{mode name}_{model_number}``, passed via ``name=``.

    Returns:
        The run object returned by ``wandb.init``.

    Raises:
        KeyError: if ``train_mode`` is not 0, 1 or 2.
    """
    train_mode_dict = {0: 'triplet', 1: 'label', 2: 'label_and_triplet'}
    run_name = f'{data_set}_{train_mode_dict[train_mode]}_{model_number}'
    config = {
        "learning_rate": lr,
        "epochs": epochs,
        "batch_size": batch_size,
        "model_number": model_number,
        "dataset": data_set,
        "train_mode": train_mode,
    }
    return wandb.init(project="ConceptualAlignment2023", entity="psych-711",
                      name=run_name, config=config,
                      settings=wandb.Settings(start_method="thread"))
     

In [12]:

def main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims):
  if os.path.isdir(save_dir):
    pass
  else:
    os.mkdir(save_dir)


  # test_intervals = [(540, 600), (1140, 1200), (1740, 1800), (2340, 2400)]
  test_intervals = [(25, 30), (55, 60), (85, 90), (115, 120)]
  # initialize an empty list to hold the indices
  val_indices = []

  # loop through the intervals and append the indices to the list
  for start, stop in test_intervals:
      val_indices.extend(list(range(start, stop)))

  # train_indices = (np.setdiff1d(np.arange(2400),np.array(val_indices)))
  train_indices = (np.setdiff1d(np.arange(120),np.array(val_indices)))

  # np.random.seed(56)
  # contrast_indices  = np.concatenate((np.random.choice(np.arange(start=600, stop=2400), 600, replace=False),
  #               np.random.choice(np.concatenate((np.arange(start=0, stop=600), np.arange(start=1200, stop=2400))), 600, replace=False),
  #               np.random.choice(np.concatenate((np.arange(start=0, stop=1200), np.arange(start=1800, stop=2400))), 600, replace=False),
  #               np.random.choice(np.arange(start=1800, stop=2400), 600, replace=False)))
  contrast_indices  = np.concatenate((np.random.choice(np.arange(start=30, stop=120), 30, replace=False),
                np.random.choice(np.concatenate((np.arange(start=0, stop=30), np.arange(start=60, stop=120))), 30, replace=False),
                np.random.choice(np.concatenate((np.arange(start=0, stop=60), np.arange(start=90, stop=120))), 30, replace=False),
                np.random.choice(np.arange(start=0, stop=90), 30, replace=False)))

  for data_set in ['set_A','set_A2','set_B','set_C']:
    for train_mode in tqdm(range(3)):
     # torch.manual_seed(0)
      for model in range(num_models):
        wandb_init(epochs, lr, train_mode, batch_size, model,data_set)

        # if data_set=='set_A':
        #   train_data = TensorDataset(Resize(32)(torch.tensor(set_A_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_A_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
        #                              torch.tensor(set_A_labs).to(torch.int64))
        # elif data_set=='set_B':
        #   train_data = TensorDataset(Resize(32)(torch.tensor(set_B_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_B_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
        #                              torch.tensor(set_B_labs).to(torch.int64))
        # elif data_set=='set_C':
        #   train_data = TensorDataset(Resize(32)(torch.tensor(set_C_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_C_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
        #                              torch.tensor(set_C_labs).to(torch.int64))
        if data_set=='set_A2':
          train_data = TensorDataset(Resize(32)(torch.tensor(set_A_sub_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_A_sub_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
                                     torch.tensor(set_A_sub_labs).to(torch.int64))
        if data_set=='set_A':
          train_data = TensorDataset(Resize(32)(torch.tensor(set_A_sub_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_A_sub_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
                                     torch.tensor(set_A_sub_labs).to(torch.int64))
        elif data_set=='set_B':
          train_data = TensorDataset(Resize(32)(torch.tensor(set_B_sub_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_B_sub_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
                                     torch.tensor(set_B_sub_labs).to(torch.int64))
        elif data_set=='set_C':
          train_data = TensorDataset(Resize(32)(torch.tensor(set_C_sub_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_C_sub_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
                                     torch.tensor(set_C_sub_labs).to(torch.int64))
          
        val_data = torch.utils.data.Subset(train_data, val_indices)
        train_data = torch.utils.data.Subset(train_data, train_indices)
       

        train_data = torch.utils.data.DataLoader(train_data, 
                                                batch_size=batch_size,
                                              shuffle=True)
        val_data = torch.utils.data.DataLoader(val_data, 
                                                batch_size=batch_size,
                                              shuffle=True)
        
     

        train_obj = TrainModels(latent_dims, num_classes).to(device) # GPU
        optimizer = torch.optim.Adam(train_obj.parameters(), lr=lr, weight_decay=1e-05)
        train_triplet_losses, train_label_losses, \
          val_triplet_losses, val_label_losses, \
            train_losses, val_losses, train_accuracies, val_accuracies= train_obj.training_loop(train_data = train_data,
                                                            test_data = val_data,
                                                            epochs = epochs,
                                                            optimizer = optimizer, 
                                                            train_mode = train_mode)




        print('validation triplet loss:',val_triplet_losses,'validation total loss:',val_losses,'validation accuracy:',val_accuracies)
        # wandb.log({"train_img_loss": train_img_loss, 
        #           "train_label_loss":train_label_loss, 
        #           "val_img_loss":val_img_loss, 
        #           "val_label_loss":val_label_loss, 
        #           "train_losses":train_losses, 
        #           "val_losses":val_losses, 
        #           "train_accuracy":train_accuracy, 
        #           "val_accuracy":val_accuracy})
        train_mode_dict = {0:'triplet', 1:'label',2:'label_and_triplet' }
        torch.save(train_obj.triplet_lab_model.state_dict(), os.path.join(save_dir,f'{data_set}_{train_mode_dict[train_mode]}_{model}.pth'))
        
        



In [14]:
# Finish any wandb run still active in this kernel before starting new ones.
wandb.finish()

num_classes = 4    # Number of unique class labels in the dataset
latent_dims = 64   # latent size of the pretrained CIFAR-10 checkpoint
epochs = 1000
lr = 0.005
num_models = 1     # independent replicates per (dataset, train mode)
# Each subset has 120 images with 20 held out, so batch_size=256 makes every
# epoch a single full-batch step.
batch_size = 256
# NOTE(review): the saved outputs below show 100/100 epoch progress bars, i.e.
# they were produced with epochs=100, not the epochs=1000 set above.
main_code(save_dir, num_models, epochs, num_classes, batch_size,
          lr, latent_dims)
wandb.finish()

  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:02<00:00, 46.53it/s]
 33%|███▎      | 1/3 [00:05<00:10,  5.20s/it]

validation triplet loss: [0.6818105578422546, 0.6945065855979919, 0.7384428381919861, 0.7682855725288391, 0.7543907761573792, 0.7349690794944763, 0.7209817171096802, 0.7034502029418945, 0.6808484196662903, 0.6503985524177551, 0.6285585761070251, 0.6080136299133301, 0.5786609649658203, 0.5489876866340637, 0.5118468403816223, 0.4572831094264984, 0.4120670258998871, 0.3707275986671448, 0.31754836440086365, 0.28708702325820923, 0.2624299228191376, 0.23523736000061035, 0.21608467400074005, 0.20106564462184906, 0.1883573681116104, 0.19404922425746918, 0.19796526432037354, 0.1985437124967575, 0.20289981365203857, 0.1997726708650589, 0.19463478028774261, 0.19073167443275452, 0.18645299971103668, 0.1826697438955307, 0.18208113312721252, 0.17807899415493011, 0.17479704320430756, 0.17248325049877167, 0.17956936359405518, 0.16302835941314697, 0.16052086651325226, 0.15680108964443207, 0.1515960544347763, 0.14651727676391602, 0.14259541034698486, 0.14486202597618103, 0.14502394199371338, 0.140799537

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,▇███▇▆▅▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train label accuracy,▅▇▇█▄▄▇▅▅▅▆▆▆▆▆▆▆▆▇▆▅▄▅▅▆▆▆▇▄▄▄▄▇▄▄▃▁▁▂▂
train label loss,▁▅▇██▇▅▅▄▃▂▂▃▂▂▂▂▃▃▂▂▃▃▃▃▄▅▅▅▅▅▅▆▆▆▆▇▇▇█
train triplet loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▃▄▄▄▇▇█▇▇▇▇▇▇▇▇▇▆▆▆▃▃▃▁▂▃▇▇▇▇▇▇▇▆▆▆▆▆▆▄▆
validation label loss,▂▃▇██▇▆▅▄▃▃▃▃▃▂▂▂▂▃▃▃▂▂▂▂▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁
validation triplet loss,▇███▇▆▅▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁

0,1
latent separation accuracy,1.0
total train loss,0.70823
total validation loss,0.84947
train label accuracy,0.18
train label loss,0.70817
train triplet loss,6e-05
validation label accuracy,0.3
validation label loss,0.689
validation triplet loss,0.16047


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668746604894598, max=1.0…

100%|██████████| 100/100 [00:01<00:00, 53.55it/s]
 67%|██████▋   | 2/3 [00:14<00:07,  7.71s/it]

validation triplet loss: [0.7857500910758972, 0.8053858876228333, 0.8035473823547363, 0.7810969352722168, 0.7714853286743164, 0.7807440757751465, 0.794644832611084, 0.8059770464897156, 0.7463658452033997, 0.67631995677948, 0.6278941631317139, 0.5897226333618164, 0.5378535389900208, 0.5180172324180603, 0.5124547481536865, 0.523366391658783, 0.526308000087738, 0.5242937207221985, 0.5115389227867126, 0.49892139434814453, 0.4889693856239319, 0.48143965005874634, 0.4736044406890869, 0.4648776948451996, 0.4558602273464203, 0.44898921251296997, 0.44433632493019104, 0.44106969237327576, 0.43840476870536804, 0.4364124834537506, 0.43579956889152527, 0.4358341693878174, 0.43534621596336365, 0.4347001612186432, 0.4341833293437958, 0.43383070826530457, 0.4335361421108246, 0.43349066376686096, 0.4332703649997711, 0.43314170837402344, 0.4329369068145752, 0.43279018998146057, 0.4328099191188812, 0.43279972672462463, 0.4327889382839203, 0.43271946907043457, 0.4325467050075531, 0.43235722184181213, 0.43

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▇▆▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,██▇▇▄▃▃▃▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train label accuracy,▁▄▄▅████████████████████████████████████
train label loss,█▆▅▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,▁▆█▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
validation label accuracy,▁▁▂▃▅█▇▇████████████████████████████████
validation label loss,█▇▇▇▄▃▂▃▃▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation triplet loss,████▅▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.38728
total validation loss,0.45271
train label accuracy,1.0
train label loss,0.0
train triplet loss,0.38728
validation label accuracy,0.95
validation label loss,0.02335
validation triplet loss,0.42935


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668388992547988, max=1.0…

100%|██████████| 100/100 [00:02<00:00, 46.14it/s]
100%|██████████| 3/3 [00:26<00:00,  8.97s/it]


validation triplet loss: [0.6850911974906921, 0.7449369430541992, 0.7984310984611511, 0.8034595847129822, 0.7959213852882385, 0.7891492247581482, 0.7787393927574158, 0.7445279359817505, 0.7011989951133728, 0.6532832384109497, 0.6172688603401184, 0.5591580867767334, 0.4862089157104492, 0.430910587310791, 0.40140238404273987, 0.3816988468170166, 0.3491032123565674, 0.30929800868034363, 0.27574416995048523, 0.25034913420677185, 0.21577206254005432, 0.20297470688819885, 0.20600526034832, 0.19053718447685242, 0.18021057546138763, 0.16833938658237457, 0.1729569137096405, 0.17830483615398407, 0.1794797033071518, 0.170306995511055, 0.16050659120082855, 0.14923681318759918, 0.13725046813488007, 0.12385383993387222, 0.113128662109375, 0.10524068027734756, 0.09878398478031158, 0.10010391473770142, 0.09515917301177979, 0.0940723642706871, 0.09056346863508224, 0.09461098164319992, 0.09925947338342667, 0.10205777734518051, 0.10545206069946289, 0.1075906977057457, 0.10888288170099258, 0.1091104373335

  0%|          | 0/3 [00:00<?, ?it/s]

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▆▄▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,▇███▇▆▅▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train label accuracy,▁▄▅▆▇███████████████████████████████████
train label loss,█▇▅▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▁▁▂▃▄▅▅▅▇█▇▇▇▇▇█████████████████████████
validation label loss,█▇▇██▆▅▄▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂
validation triplet loss,▇██▇▆▅▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,1e-05
total validation loss,0.25194
train label accuracy,1.0
train label loss,1e-05
train triplet loss,0.0
validation label accuracy,0.9
validation label loss,0.16765
validation triplet loss,0.08429


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666853533436855, max=1.0)…

100%|██████████| 100/100 [00:02<00:00, 47.26it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.6904611587524414, 0.6984723210334778, 0.7388632297515869, 0.7662394642829895, 0.7550672888755798, 0.7399150133132935, 0.7378654479980469, 0.7231554985046387, 0.7116303443908691, 0.6799613833427429, 0.6410267353057861, 0.6198035478591919, 0.5974165797233582, 0.5665671229362488, 0.5331675410270691, 0.47243353724479675, 0.4321592450141907, 0.3914901912212372, 0.3647538721561432, 0.33602458238601685, 0.30918648838996887, 0.28254491090774536, 0.2512475252151489, 0.2306671440601349, 0.21516981720924377, 0.20396526157855988, 0.19631342589855194, 0.1883808970451355, 0.1810692846775055, 0.1749902069568634, 0.16963620483875275, 0.1649259775876999, 0.16940152645111084, 0.17109040915966034, 0.17137984931468964, 0.17011216282844543, 0.16433463990688324, 0.15979371964931488, 0.1556660681962967, 0.15275079011917114, 0.15000300109386444, 0.14752422273159027, 0.14520178735256195, 0.14376388490200043, 0.1428757607936859, 0.14216765761375427, 0.140296071767807, 0.1382833719253

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,▇███▇▆▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁
train label accuracy,▂▃▅▆▅▃█▇▅▅▄▃▄▄▄▄▄▅▄▃▁▁▁▃▄▄▅▄▄▅▆▅▃▂▂▂▂▂▃▃
train label loss,█▇▄▃▃▃▄▄▄▄▄▄▄▄▄▄▄▄▄▃▃▂▂▁▁▂▂▂▃▃▄▅▅▅▅▅▅▅▅▅
train triplet loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▁██▄▄▄▄▄▄▄▄▄▄▄▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄
validation label loss,▅██▆▂▁▃▅▅▅▄▄▄▄▄▄▄▄▄▂▂▂▃▃▂▂▂▂▂▃▅▅▅▅▅▅▅▅▅▅
validation triplet loss,▇███▇▆▅▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.69824
total validation loss,0.83581
train label accuracy,0.24
train label loss,0.69802
train triplet loss,0.00023
validation label accuracy,0.25
validation label loss,0.69921
validation triplet loss,0.1366


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668622738992174, max=1.0…

100%|██████████| 100/100 [00:01<00:00, 53.37it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.717011034488678, 0.6571634411811829, 0.6628132462501526, 0.7474359273910522, 0.7815057635307312, 0.7653582692146301, 0.7787677645683289, 0.8108820915222168, 0.8273323178291321, 0.8200744986534119, 0.8192340731620789, 0.7854527831077576, 0.796424388885498, 0.7415935397148132, 0.6586242914199829, 0.5785452723503113, 0.5439600944519043, 0.5040408968925476, 0.461495965719223, 0.43683919310569763, 0.4102462828159332, 0.4273488521575928, 0.44714757800102234, 0.46269771456718445, 0.4597838521003723, 0.44557735323905945, 0.4284621775150299, 0.4090920388698578, 0.3908904194831848, 0.37397700548171997, 0.3584934175014496, 0.3422435224056244, 0.32800257205963135, 0.3159603774547577, 0.30444082617759705, 0.29563990235328674, 0.2901129722595215, 0.28640174865722656, 0.283823698759079, 0.28202688694000244, 0.28093451261520386, 0.28074994683265686, 0.2806389331817627, 0.2804272770881653, 0.2800767123699188, 0.2797726094722748, 0.27946868538856506, 0.27925145626068115, 0.27

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,██▆▅▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,▆▅▆███▄▃▂▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train label accuracy,▁▃▆▇▇███████████████████████████████████
train label loss,█▆▅▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,▂██▅▃▂▁▁▁▁▂▂▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
validation label accuracy,▃▁▁▁▁▂▅▆▇▇▆█████████████████████████████
validation label loss,▆▅▆███▄▃▂▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation triplet loss,▇▆▇███▅▄▃▃▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.23162
total validation loss,0.31999
train label accuracy,1.0
train label loss,0.0
train triplet loss,0.23162
validation label accuracy,1.0
validation label loss,0.03976
validation triplet loss,0.28023


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668478026986122, max=1.0…

100%|██████████| 100/100 [00:02<00:00, 48.48it/s]
  lambda data: self._console_raw_callback("stderr", data),
100%|██████████| 3/3 [00:36<00:00, 12.02s/it]


validation triplet loss: [0.6613699793815613, 0.7009269595146179, 0.7252092361450195, 0.7352169156074524, 0.7378575205802917, 0.7371872067451477, 0.7127805948257446, 0.678843080997467, 0.6580721139907837, 0.647397518157959, 0.6183987855911255, 0.6181268692016602, 0.6007515788078308, 0.5636969208717346, 0.5364000201225281, 0.44572553038597107, 0.31981033086776733, 0.24118678271770477, 0.21783292293548584, 0.20950432121753693, 0.19672775268554688, 0.17871655523777008, 0.16035225987434387, 0.1440620720386505, 0.13067467510700226, 0.12139850109815598, 0.11535787582397461, 0.1074923425912857, 0.10374647378921509, 0.10034873336553574, 0.09947105497121811, 0.10333137959241867, 0.1051860973238945, 0.10651091486215591, 0.10722514241933823, 0.1087685376405716, 0.10912847518920898, 0.10915582627058029, 0.11037792265415192, 0.11113428324460983, 0.10635107010602951, 0.10199394077062607, 0.09799255430698395, 0.0945354476571083, 0.0910024642944336, 0.08765565603971481, 0.0876249223947525, 0.088324129

  0%|          | 0/3 [00:00<?, ?it/s]

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▆▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,███▇▆▆▅▃▃▂▁▁▁▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train label accuracy,▁▂▄▆▇▇██████████████████████████████████
train label loss,█▇▆▅▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▁▁▁▄▅▄▅▆▆▇██████▇▇▇█████████▇███▇▇██████
validation label loss,█▇▇▆▅▆▅▄▃▂▁▁▁▁▁▁▂▃▃▂▂▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂
validation triplet loss,▇██▇▇▇▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.0
total validation loss,0.18634
train label accuracy,1.0
train label loss,0.0
train triplet loss,0.0
validation label accuracy,0.95
validation label loss,0.09883
validation triplet loss,0.08752


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668535101537905, max=1.0…

100%|██████████| 100/100 [00:02<00:00, 44.94it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.6839104890823364, 0.6448492407798767, 0.6723971962928772, 0.7409624457359314, 0.7440465092658997, 0.7230227589607239, 0.7094215154647827, 0.6722870469093323, 0.6279004216194153, 0.6074097752571106, 0.5867446660995483, 0.5618500709533691, 0.5326198935508728, 0.4934861660003662, 0.455679327249527, 0.42093706130981445, 0.3932177722454071, 0.36650416254997253, 0.3538811206817627, 0.3336886167526245, 0.3160078823566437, 0.3013678193092346, 0.2920849919319153, 0.28678950667381287, 0.27480751276016235, 0.26194244623184204, 0.2525208294391632, 0.24611173570156097, 0.24094431102275848, 0.23548340797424316, 0.23320789635181427, 0.23186109960079193, 0.23058784008026123, 0.23092356324195862, 0.2334313690662384, 0.2375030368566513, 0.24788831174373627, 0.2435959428548813, 0.2323375791311264, 0.2222541868686676, 0.21754713356494904, 0.21464696526527405, 0.212448388338089, 0.2109612226486206, 0.2082439512014389, 0.20599913597106934, 0.20318809151649475, 0.2012423574924469,

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▁▁
total validation loss,███▇▆▆▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁
train label accuracy,██▄▄▅▃▄▅▅▅▄▃▃▂▂▄▂▂▁▁▁▂▂▂▁▁▂▃▅▄▄▄▅▁▅▅▅▇▅▅
train label loss,▄▁▁▂▃▄▅▅▅▆▆▆▇▇██████▇▇▇▇▇▇▇▇▇▇▇▇▆▇▇████▇
train triplet loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,██▇▇▇▅▂▂▁▂▂▄▁▂▂▄▅▅▅▄▄▅▅▅▅▄▂▅▅▅▅▅▂▂▅▇▇▇▅▄
validation label loss,▅▃▂▁▃▃▄▄▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█████
validation triplet loss,█▇█▇▆▆▄▄▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.71937
total validation loss,0.91169
train label accuracy,0.23
train label loss,0.71906
train triplet loss,0.00031
validation label accuracy,0.15
validation label loss,0.7249
validation triplet loss,0.18679


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666842740960419, max=1.0)…

100%|██████████| 100/100 [00:02<00:00, 46.35it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.7812778949737549, 0.8032556772232056, 0.8049201965332031, 0.8024665713310242, 0.777927577495575, 0.7518196702003479, 0.7270261645317078, 0.7228198647499084, 0.6802545189857483, 0.6759350895881653, 0.7014622688293457, 0.6914780735969543, 0.6491081118583679, 0.6349987387657166, 0.6203736066818237, 0.5885421633720398, 0.5720008015632629, 0.5645099878311157, 0.5369675755500793, 0.5229366421699524, 0.5153339505195618, 0.5061613321304321, 0.4981837272644043, 0.4968508183956146, 0.4994598925113678, 0.5027613639831543, 0.5048738718032837, 0.5061306357383728, 0.5061156153678894, 0.505290687084198, 0.503963828086853, 0.5027291178703308, 0.5021060109138489, 0.5016428232192993, 0.5007895827293396, 0.49956759810447693, 0.49784204363822937, 0.49566784501075745, 0.49338170886039734, 0.49092379212379456, 0.4882057309150696, 0.4858510494232178, 0.48383909463882446, 0.4824472963809967, 0.4814682900905609, 0.4806835353374481, 0.4801233410835266, 0.4797928035259247, 0.479608267

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▇▆▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,█▇▅▄▄▂▁▁▁▁▂▂▃▄▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
train label accuracy,▁▃▄▅▇███████████████████████████████████
train label loss,█▇▅▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,▁▃▅▇██▇▇▇▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
validation label accuracy,▁▃▄▅▆▆▆▆▇▇██▇▇▇▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
validation label loss,▅▄▃▂▂▁▁▂▂▂▃▃▄▅▆▇▇███████████████████████
validation triplet loss,██▇▆▆▅▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.40858
total validation loss,1.33306
train label accuracy,1.0
train label loss,0.0
train triplet loss,0.40858
validation label accuracy,0.75
validation label loss,0.862
validation triplet loss,0.47105


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666856831870973, max=1.0)…

100%|██████████| 100/100 [00:02<00:00, 45.46it/s]
  lambda data: self._console_raw_callback("stderr", data),
100%|██████████| 3/3 [00:36<00:00, 12.07s/it]


validation triplet loss: [0.6644529104232788, 0.6112272143363953, 0.5776681303977966, 0.6236146092414856, 0.6405686736106873, 0.5990908145904541, 0.5735971331596375, 0.5629510283470154, 0.5292701125144958, 0.45390817523002625, 0.3716748356819153, 0.33253780007362366, 0.31736937165260315, 0.3288523554801941, 0.31370192766189575, 0.29457423090934753, 0.3162075877189636, 0.3343983292579651, 0.36375588178634644, 0.36943575739860535, 0.3356291949748993, 0.29329487681388855, 0.23842237889766693, 0.20252428948879242, 0.15799129009246826, 0.13098447024822235, 0.11722727864980698, 0.11136545985937119, 0.10559578984975815, 0.10225921869277954, 0.10105430334806442, 0.087112195789814, 0.07518595457077026, 0.06451719254255295, 0.05649325251579285, 0.051135022193193436, 0.04748177155852318, 0.04476359114050865, 0.043241098523139954, 0.0430152602493763, 0.04321015253663063, 0.04371441528201103, 0.043999530375003815, 0.04397626966238022, 0.04371963068842888, 0.04293335974216461, 0.042032551020383835, 

  0%|          | 0/3 [00:00<?, ?it/s]

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▆▅▄▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,█▇▇▆▄▃▃▄▄▂▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃
train label accuracy,▁▂▃▃▅▆▇▇████████████████████████████████
train label loss,█▇▆▅▅▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▁▂▂▃▃▅▅▅▆▆▇▇▇███████████████████████████
validation label loss,█▇▇▆▄▄▃▄▃▂▁▁▂▃▄▄▅▅▅▆▇▇▆▆▆▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇
validation triplet loss,█▇▇▇▅▄▄▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,1e-05
total validation loss,0.58236
train label accuracy,1.0
train label loss,1e-05
train triplet loss,0.0
validation label accuracy,0.9
validation label loss,0.56196
validation triplet loss,0.0204


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668511492510637, max=1.0…

100%|██████████| 100/100 [00:02<00:00, 42.73it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.6466699838638306, 0.5136932730674744, 0.4991872012615204, 0.4152054786682129, 0.35976675152778625, 0.33726778626441956, 0.3417889177799225, 0.35889506340026855, 0.34040918946266174, 0.31015822291374207, 0.2792142927646637, 0.24787986278533936, 0.22372618317604065, 0.21016652882099152, 0.20237503945827484, 0.1974719762802124, 0.19075898826122284, 0.18087656795978546, 0.1815304011106491, 0.1829463541507721, 0.18204398453235626, 0.17884191870689392, 0.17727963626384735, 0.1799517273902893, 0.17942512035369873, 0.17785099148750305, 0.17573976516723633, 0.17222876846790314, 0.16965197026729584, 0.1672302782535553, 0.16505788266658783, 0.16433483362197876, 0.16400307416915894, 0.1603657752275467, 0.15811696648597717, 0.15717469155788422, 0.1565883904695511, 0.15631762146949768, 0.15614686906337738, 0.1558719128370285, 0.15758973360061646, 0.15944015979766846, 0.16007061302661896, 0.163197323679924, 0.16443462669849396, 0.1712462455034256, 0.17367379367351532, 0.17

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,█▆▄▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂
train label accuracy,▆▁▃▃▃▄▄▄▆▇▇█▇▇▇▆▇▆▆▅▅▅▅▅▅▅▅▅▅▅▆▆▇▇▇▆▆▆▆▇
train label loss,▄██▇▅▄▃▃▃▃▃▃▃▃▃▃▃▃▃▄▄▅▅▅▅▄▄▄▃▃▂▁▁▁▁▁▂▂▂▂
train triplet loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▄▄▄█▄▁▁▄▄▄▄███▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
validation label loss,▃▁▃▆▄▃▅▅▅▅▃▃▄▄▄▄▄▅▅▆▆▅▅▆▆▆▇█▇▇▅▅▅▅▄▃▃▃▄▃
validation triplet loss,█▆▄▄▃▂▂▂▂▂▂▂▂▁▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂

0,1
latent separation accuracy,1.0
total train loss,0.66878
total validation loss,0.8471
train label accuracy,0.3
train label loss,0.66878
train triplet loss,0.0
validation label accuracy,0.25
validation label loss,0.66405
validation triplet loss,0.18305


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668485788007576, max=1.0…

100%|██████████| 100/100 [00:01<00:00, 51.92it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.8539018630981445, 0.8334077000617981, 0.7696071863174438, 0.7246454954147339, 0.6945151686668396, 0.6693103909492493, 0.6250882148742676, 0.5573397874832153, 0.5058600306510925, 0.4661056101322174, 0.4357603192329407, 0.40747061371803284, 0.39672935009002686, 0.3964630663394928, 0.4018155038356781, 0.4032715857028961, 0.40325814485549927, 0.3976181447505951, 0.3904040455818176, 0.38045233488082886, 0.3653390109539032, 0.3536454141139984, 0.3425372540950775, 0.3360569477081299, 0.3345935046672821, 0.3363778293132782, 0.3375082015991211, 0.3385043442249298, 0.33950015902519226, 0.3403986990451813, 0.3413320481777191, 0.34257927536964417, 0.3437754809856415, 0.3447650372982025, 0.34572187066078186, 0.3467767834663391, 0.3478299677371979, 0.34881827235221863, 0.3496801555156708, 0.3505452573299408, 0.3512859046459198, 0.35188719630241394, 0.3524050712585449, 0.3528828024864197, 0.3532876670360565, 0.35372546315193176, 0.35416895151138306, 0.35465285181999207, 0.