In [1]:
import os
base_dir = os.path.abspath('../..')
save_dir = os.path.join(base_dir,'results')
data_dir = os.path.join(base_dir,'data')

In [2]:
import torch
# torch.manual_seed(0)
import wandb
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
from torch.utils.data import TensorDataset,Dataset
from torch.utils.data.sampler import SubsetRandomSampler

import torchvision
from tqdm import tqdm
import random
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200


from sklearn import linear_model
from sklearn.preprocessing import StandardScaler

# from neurora.rdm_corr import rdm_correlation_spearman

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Define model

In [4]:
class TripletLabelModel(nn.Module):
    def __init__(self, encoded_space_dim=64, num_classes=4):
        super().__init__()
        ""
        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
    
        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)
        ### Linear section
        self.encoder_lin = nn.Sequential(
            nn.Linear(32*4*4, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.Linear(128, encoded_space_dim)
        )

        ## triplet projection module
        self.decoder_triplet_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 32),
            nn.ReLU(True)
         
        )
        ##labeling module
        self.decoder_labels_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 32),
            nn.ReLU(True),
            nn.Linear(32, 16),
            nn.ReLU(True),
            nn.Linear(16, num_classes),
        )

        ### initialize weights using xavier initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    
        
    def forward(self, x, y=None):
        batch_s = x.size(0)
        img_features = self.encoder_cnn(x)
        img_features = self.flatten(img_features)
        
        enc_latent = self.encoder_lin(img_features)

        triplet_latent = self.decoder_triplet_lin(enc_latent)
        label = self.decoder_labels_lin(enc_latent)
        # label = F.softmax(label,dim=1)
        return enc_latent, label

In [5]:
### custom loss computing triplet loss and labeling loss


class CustomLoss(nn.Module):
    def __init__(self, margin=10):
        super(CustomLoss, self).__init__()
        self.margin = margin
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, anchor, positive, negative, label, pred_label):
        cosine_sim = torch.nn.CosineSimilarity(1)
        # distance_positive = torch.tensor(1)-cosine_sim(anchor,positive)
   
        # distance_negative = torch.tensor(1)-cosine_sim(anchor,negative)

        # triplet_loss = torch.maximum(distance_positive - distance_negative + self.margin, torch.tensor(0))
        # triplet_loss = torch.sum(triplet_loss)
        triplet_loss = (nn.TripletMarginWithDistanceLoss( distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)))
        triplet_loss = triplet_loss(anchor, positive, negative)
        label_loss = F.binary_cross_entropy_with_logits(pred_label.float(), label.float())
        total_loss = triplet_loss + label_loss
        return triplet_loss, label_loss, total_loss

### Training functions

In [6]:


class TrainModels(nn.Module):
    def __init__(self, latent_dims, num_classes):
        super(TrainModels, self).__init__()
        self.triplet_lab_model = TripletLabelModel(latent_dims, num_classes)
        self.custom_loss = CustomLoss()
        self.num_classes = num_classes
    
    def forward(self, anchor_im, positive_im, negative_im):
        anchor_latent, anchor_label = self.triplet_lab_model(anchor_im)
        positive_latent, _ = self.triplet_lab_model(positive_im)
        negative_latent, _ = self.triplet_lab_model(negative_im)

        return anchor_latent, positive_latent, negative_latent, anchor_label

    def test_epoch(self, test_data):
    # Set evaluation mode for encoder and decoder
        self.eval()
        with torch.no_grad(): # No need to track the gradients
            # Define the lists to store the outputs for each batch
            test_triplet_loss = []
            test_label_loss = []
            test_total_loss = []
            total = 0
            correct = 0
            for anchor_ims, contrast_ims, labels in test_data:
                # Move tensor to the proper device
                anchor_ims = anchor_ims.to(device)
                contrast_ims = contrast_ims.to(device)
                labels = F.one_hot(labels, num_classes=self.num_classes)
                labels = labels.to(device)
                anchor_latent, positive_latent, negative_latent, pred_label = self.forward(anchor_ims, anchor_ims,contrast_ims) 
                # Append the network output and the original image to the lists
                triplet_loss, label_loss, total_loss = self.custom_loss(anchor_latent,
                                                                positive_latent, 
                                                                negative_latent, 
                                                                labels,
                                                                pred_label)
                total += labels.size(0)
                correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(labels, dim = 1)).sum().item()
                test_triplet_loss.append(triplet_loss.item())
                test_label_loss.append(label_loss.item())
                test_total_loss.append(total_loss.item())
        test_triplet_loss = sum(test_triplet_loss)/len(test_triplet_loss)
        test_label_loss = sum(test_label_loss)/len(test_label_loss)
        test_total_loss = sum(test_total_loss)/len(test_total_loss)
        test_accuracy = correct/total
        return test_triplet_loss, test_label_loss, test_total_loss, test_accuracy

    def test_epoch_calculate_representation_separation(self, test_data):
    # Set evaluation mode for encoder and decoder
        self.eval()
        with torch.no_grad(): # No need to track the gradients
            accuracies = []
            for anchor_ims, contrast_ims, labels in test_data:
                # Move tensor to the proper device
                anchor_ims = anchor_ims.to(device)
                contrast_ims = contrast_ims.to(device)
                # labels = F.one_hot(labels, num_classes=self.num_classes)
                # labels = labels.to(device)
                anchor_latent, _, _, _ = self.forward(anchor_ims, anchor_ims,contrast_ims) 
                # use sklearn to predict labels from anchor_latent
                # calculate accuracy
                # x's are anchor_latent and y's are labels
                # append accuracy to list
                # put anchor_latent and labels on cpu and convert to numpy

          
                anchor_latent = anchor_latent.cpu().numpy()
                ### standard scale the data in anchor_latent before fitting to the model
                anchor_latent = StandardScaler().fit_transform(anchor_latent)
                labels = labels.cpu().numpy()
                
                lm = linear_model.LogisticRegression()
                lm.fit(anchor_latent, labels)
                # convert labels to sklearn format
                accuracies.append(lm.score(anchor_latent, labels))
        accuracy = sum(accuracies)/len(accuracies)
        return accuracy

    def train_epoch(self, train_data, optimizer, train_mode):
        self.train()
        train_triplet_loss = []
        train_label_loss = []
        train_total_loss = []
        correct = 0
        total = 0
        for anchor_ims, contrast_ims, labels in train_data:
            
            anchor_ims = anchor_ims.to(device)
            contrast_ims = contrast_ims.to(device)
            labels = F.one_hot(labels, num_classes=self.num_classes)
            labels = labels.to(device)

            optimizer.zero_grad()
            anchor_latent, positive_latent, negative_latent, pred_label = self.forward(anchor_ims, anchor_ims,contrast_ims) 
           
           
           
            triplet_loss, label_loss, total_loss = self.custom_loss(anchor_latent,
                                                                positive_latent, 
                                                                negative_latent, 
                                                                labels,
                                                                pred_label)
            
            
            if train_mode==0:
                triplet_loss.backward()
            elif train_mode==1:
                label_loss.backward()
            elif train_mode==2:
                total_loss.backward()

            optimizer.step()
            train_triplet_loss.append(triplet_loss.item())
            train_label_loss.append(label_loss.item())
            train_total_loss.append(total_loss.item())
            total += labels.size(0)
            correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(labels, dim = 1)).sum().item()
        train_triplet_loss = sum(train_triplet_loss)/len(train_triplet_loss)
        train_label_loss = sum(train_label_loss)/len(train_label_loss)
        train_total_loss = sum(train_total_loss)/len(train_total_loss)
        train_accuracy = correct/total
        return train_triplet_loss, train_label_loss, train_total_loss, train_accuracy

    def training_loop(self, train_data, test_data,train_mode,
                      epochs, optimizer):
        train_losses = []
        val_losses = []
        train_triplet_losses = []
        val_triplet_losses = []
        train_label_losses = []
        val_label_losses = []
        train_accuracies = []
        val_accuracies = []
        latent_separation_accuracy = 0
        for epoch in tqdm(range(epochs)):
          train_triplet_loss, train_label_loss, train_total_loss, train_accuracy =self.train_epoch(train_data, optimizer, 
                                             train_mode)
          test_triplet_loss, test_label_loss, test_total_loss, test_accuracy = self.test_epoch(test_data)
          separation_accuracy = self.test_epoch_calculate_representation_separation(test_data)
          train_losses.append(train_total_loss)
          val_losses.append(test_total_loss)
          train_triplet_losses.append(train_triplet_loss)
          val_triplet_losses.append(test_triplet_loss)
          train_label_losses.append(train_label_loss)
          val_label_losses.append(test_label_loss)
          train_accuracies.append(train_accuracy)
          val_accuracies.append(test_accuracy)
          wandb.log({"train triplet loss": train_triplet_loss, 
            "train label loss":train_label_loss, 
            "validation triplet loss":test_triplet_loss, 
            "validation label loss":test_label_loss, 
            "total train loss":train_total_loss, 
            "total validation loss":test_total_loss, 
            "train label accuracy":train_accuracy, 
            "validation label accuracy":test_accuracy,
            'latent separation accuracy':separation_accuracy})
        return train_triplet_losses, train_label_losses, val_triplet_losses, val_label_losses ,train_losses, val_losses, train_accuracies, val_accuracies


In [7]:
set_A_ims = np.load(os.path.join(data_dir, 'set_A.npy'))
set_B_ims = np.load(os.path.join(data_dir, 'set_B.npy'))
set_C_ims= np.load(os.path.join(data_dir, 'set_C.npy'))
set_A_labs = np.load(os.path.join(data_dir, 'set_A_labs.npy'))
set_B_labs = np.load(os.path.join(data_dir, 'set_B_labs.npy'))
set_C_labs = np.load(os.path.join(data_dir, 'set_C_labs.npy'))


In [8]:

###initialize weights and bias tracking
def wandb_init(epochs, lr, train_mode, batch_size, model_number,data_set):
  wandb.init(project="ConceptualAlignment2023", entity="psych-711",settings=wandb.Settings(start_method="thread"))
  wandb.config = {
    "learning_rate": lr,
    "epochs": epochs,
    "batch_size": batch_size, 
    # "label_ratio":label_ratio, 
    "model_number": model_number,
    "dataset": data_set,
    "train_mode":train_mode,
  }
  train_mode_dict = {0:'triplet', 1:'label', 2:'label_and_triplet'}
  wandb.run.name = f'{data_set}_{train_mode_dict[train_mode]}_{model_number}'
  wandb.run.save()
     

In [9]:

def main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims):
  if os.path.isdir(save_dir):
    pass
  else:
    os.mkdir(save_dir)


  test_intervals = [(540, 600), (1140, 1200), (1740, 1800), (2340, 2400)]

  # initialize an empty list to hold the indices
  val_indices = []

  # loop through the intervals and append the indices to the list
  for start, stop in test_intervals:
      val_indices.extend(list(range(start, stop)))

  train_indices = (np.setdiff1d(np.arange(2400),np.array(val_indices)))

  np.random.seed(56)
  contrast_indices  = np.concatenate((np.random.choice(np.arange(start=600, stop=2400), 600, replace=False),
                np.random.choice(np.concatenate((np.arange(start=0, stop=600), np.arange(start=1200, stop=2400))), 600, replace=False),
                np.random.choice(np.concatenate((np.arange(start=0, stop=1200), np.arange(start=1800, stop=2400))), 600, replace=False),
                np.random.choice(np.arange(start=1800, stop=2400), 600, replace=False)))

  for data_set in ['set_A','set_B','set_C']:
    for train_mode in tqdm(range(3)):
     # torch.manual_seed(0)
      for model in range(num_models):
        wandb_init(epochs, lr, train_mode, batch_size, model,data_set)

        if data_set=='set_A':
          train_data = TensorDataset(torch.tensor(set_A_ims.transpose(0,3,1,2)/255).float(), torch.tensor(set_A_ims[contrast_indices].transpose(0,3,1,2)/255).float(),\
                                     torch.tensor(set_A_labs).to(torch.int64))
        elif data_set=='set_B':
          train_data = TensorDataset(torch.tensor(set_B_ims.transpose(0,3,1,2)/255).float(), torch.tensor(set_B_ims[contrast_indices].transpose(0,3,1,2)/255).float(),\
                                     torch.tensor(set_B_labs).to(torch.int64))
        elif data_set=='set_C':
          train_data = TensorDataset(torch.tensor(set_C_ims.transpose(0,3,1,2)/255).float(), torch.tensor(set_C_ims[contrast_indices].transpose(0,3,1,2)/255).float(),\
                                     torch.tensor(set_C_labs).to(torch.int64))
          
        val_data = torch.utils.data.Subset(train_data, val_indices)
        train_data = torch.utils.data.Subset(train_data, train_indices)
       

        train_data = torch.utils.data.DataLoader(train_data, 
                                                batch_size=batch_size,
                                              shuffle=True)
        val_data = torch.utils.data.DataLoader(val_data, 
                                                batch_size=batch_size,
                                              shuffle=True)
        
     

        train_obj = TrainModels(latent_dims, num_classes).to(device) # GPU
        optimizer = torch.optim.Adam(train_obj.parameters(), lr=lr, weight_decay=1e-05)
        train_triplet_losses, train_label_losses, \
          val_triplet_losses, val_label_losses, \
            train_losses, val_losses, train_accuracies, val_accuracies= train_obj.training_loop(train_data = train_data,
                                                            test_data = val_data,
                                                            epochs = epochs,
                                                            optimizer = optimizer, 
                                                            train_mode = train_mode)




        print('validation triplet loss:',val_triplet_losses,'validation total loss:',val_losses,'validation accuracy:',val_accuracies)
        # wandb.log({"train_img_loss": train_img_loss, 
        #           "train_label_loss":train_label_loss, 
        #           "val_img_loss":val_img_loss, 
        #           "val_label_loss":val_label_loss, 
        #           "train_losses":train_losses, 
        #           "val_losses":val_losses, 
        #           "train_accuracy":train_accuracy, 
        #           "val_accuracy":val_accuracy})
        train_mode_dict = {0:'triplet', 1:'label',2:'label_and_triplet' }
        torch.save(train_obj.triplet_lab_model.state_dict(), os.path.join(save_dir,f'{data_set}_{train_mode_dict[train_mode]}_{model}.pth'))
        
        



In [11]:
wandb.finish()

num_classes = 4 # Number of unique class labels in the dataset
latent_dims = 64
epochs = 10
lr = 0.005
num_models = 1
batch_size = 256
save_dir = save_dir
main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims)
wandb.finish()

  0%|          | 0/3 [00:00<?, ?it/s]

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

validation triplet loss: [0.8325868248939514, 0.23087778687477112, 0.13166537880897522, 0.11618342995643616, 0.12159474194049835, 0.10982739925384521, 0.1326916664838791, 0.10736203193664551, 0.09900834411382675, 0.09412360191345215] validation total loss: [1.5502538681030273, 0.9072749614715576, 0.790917158126831, 0.7833114862442017, 0.7844045162200928, 0.7281664609909058, 0.7635741233825684, 0.741872251033783, 0.7501118183135986, 0.73979252576828] validation accuracy: [0.19166666666666668, 0.2875, 0.3375, 0.2625, 0.325, 0.3958333333333333, 0.375, 0.4041666666666667, 0.2875, 0.2916666666666667]


VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
latent separation accuracy,█▇▁▇▇▆▆▅▅▅
total train loss,█▄▃▃▂▂▂▁▁▁
total validation loss,█▃▂▁▁▁▁▁▁▁
train label accuracy,▅▂▃▁▆▇▄█▅▆
train label loss,▇▃▅█▃▄▆▃▁▄
train triplet loss,█▄▃▂▂▁▁▁▁▁
validation label accuracy,▁▄▆▃▅█▇█▄▄
validation label loss,█▅▄▄▄▁▂▂▃▃
validation triplet loss,█▂▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,0.98333
total train loss,0.70036
total validation loss,0.73979
train label accuracy,0.25741
train label loss,0.69452
train triplet loss,0.00584
validation label accuracy,0.29167
validation label loss,0.64567
validation triplet loss,0.09412


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
100%|██████████| 10/10 [00:02<00:00,  3.45it/s]
 67%|██████▋   | 2/3 [00:28<00:14, 14.23s/it]

validation triplet loss: [0.8526226282119751, 0.5901137590408325, 0.7049498558044434, 0.5987834334373474, 0.5266598463058472, 0.527089536190033, 0.5096761584281921, 0.5363163352012634, 0.5328738689422607, 0.5599735975265503] validation total loss: [1.5235861539840698, 0.6206023693084717, 1.1005301475524902, 0.6215917468070984, 0.5643811225891113, 0.531183123588562, 0.5219855308532715, 0.5683464407920837, 0.6509128212928772, 0.6315500736236572] validation accuracy: [0.25833333333333336, 0.9958333333333333, 0.8208333333333333, 0.9875, 0.9708333333333333, 1.0, 0.9916666666666667, 0.9833333333333333, 0.9541666666666667, 0.9416666666666667]


VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁
total train loss,█▄▄▃▂▁▁▁▁▁
total validation loss,█▂▅▂▁▁▁▁▂▂
train label accuracy,▁▇▇███████
train label loss,█▂▂▂▁▁▁▁▁▁
train triplet loss,▁▆█▇▆▄▃▃▃▃
validation label accuracy,▁█▆██████▇
validation label loss,█▁▅▁▁▁▁▁▂▂
validation triplet loss,█▃▅▃▁▁▁▂▁▂

0,1
latent separation accuracy,1.0
total train loss,0.53727
total validation loss,0.63155
train label accuracy,0.99722
train label loss,0.00349
train triplet loss,0.53377
validation label accuracy,0.94167
validation label loss,0.07158
validation triplet loss,0.55997


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
100%|██████████| 10/10 [00:03<00:00,  2.81it/s]
100%|██████████| 3/3 [00:44<00:00, 14.75s/it]


validation triplet loss: [0.8466342091560364, 0.16003479063510895, 0.10295314341783524, 0.06509903818368912, 0.06900302320718765, 0.07121720910072327, 0.06519393622875214, 0.06111343204975128, 0.055598046630620956, 0.037688929587602615] validation total loss: [1.4192085266113281, 0.3218994140625, 0.15820565819740295, 0.07530447840690613, 0.3620893955230713, 0.09107503294944763, 0.09839673340320587, 0.14086729288101196, 0.09662124514579773, 0.04296433925628662] validation accuracy: [0.25, 0.9625, 0.9791666666666666, 1.0, 0.7541666666666667, 0.9875, 0.9791666666666666, 0.9333333333333333, 0.9833333333333333, 1.0]


  0%|          | 0/3 [00:00<?, ?it/s]

0,1
latent separation accuracy,▁▆████████
total train loss,█▄▂▂▁▁▁▁▁▁
total validation loss,█▂▂▁▃▁▁▁▁▁
train label accuracy,▁▆▇███████
train label loss,█▃▂▁▁▁▁▁▁▁
train triplet loss,█▅▄▃▂▂▁▁▁▁
validation label accuracy,▁███▆██▇██
validation label loss,█▃▂▁▅▁▁▂▁▁
validation triplet loss,█▂▂▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.021
total validation loss,0.04296
train label accuracy,0.99907
train label loss,0.00179
train triplet loss,0.01921
validation label accuracy,1.0
validation label loss,0.00528
validation triplet loss,0.03769


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
100%|██████████| 10/10 [00:03<00:00,  2.91it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.6752011179924011, 0.14804071187973022, 0.12982496619224548, 0.1358948051929474, 0.0696510300040245, 0.0751727893948555, 0.08015242218971252, 0.07770232111215591, 0.07047251611948013, 0.06824380904436111] validation total loss: [1.2601779699325562, 0.7799314856529236, 0.7683043479919434, 0.7677327394485474, 0.7064144015312195, 0.7065895199775696, 0.7220096588134766, 0.7078208923339844, 0.7080687284469604, 0.703364372253418] validation accuracy: [0.3125, 0.19583333333333333, 0.1875, 0.20416666666666666, 0.22083333333333333, 0.2125, 0.14166666666666666, 0.1875, 0.17083333333333334, 0.17083333333333334]


VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
latent separation accuracy,▁▅▅██▅█▅██
total train loss,█▅▃▂▁▁▁▁▁▁
total validation loss,█▂▂▂▁▁▁▁▁▁
train label accuracy,▄▁▄█▆▅▂▃▆▅
train label loss,▂█▃▁▁▃▆▆▅▆
train triplet loss,█▄▃▂▂▁▁▁▁▁
validation label accuracy,█▃▃▄▄▄▁▃▂▂
validation label loss,▁▇█▇▇▇█▇▇▇
validation triplet loss,█▂▂▂▁▁▁▁▁▁

0,1
latent separation accuracy,0.99583
total train loss,0.65009
total validation loss,0.70336
train label accuracy,0.25231
train label loss,0.64392
train triplet loss,0.00617
validation label accuracy,0.17083
validation label loss,0.63512
validation triplet loss,0.06824


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666917058173567, max=1.0)…

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
100%|██████████| 10/10 [00:03<00:00,  3.12it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.8811678886413574, 0.5942180752754211, 0.49946144223213196, 0.5024130344390869, 0.539466142654419, 0.4892040491104126, 0.5311083793640137, 0.4631440043449402, 0.4345147907733917, 0.4212682247161865] validation total loss: [1.5198936462402344, 1.1016626358032227, 0.6611021757125854, 0.5140487551689148, 0.5476621985435486, 0.5193187594413757, 0.6619729995727539, 0.4979228079319, 0.46300455927848816, 0.4213007986545563] validation accuracy: [0.2708333333333333, 0.4875, 0.8875, 0.9958333333333333, 0.9916666666666667, 0.9791666666666666, 0.9375, 0.9833333333333333, 0.9791666666666666, 1.0]


VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
latent separation accuracy,▁▅████████
total train loss,█▆▄▂▂▂▂▁▁▁
total validation loss,█▅▃▂▂▂▃▁▁▁
train label accuracy,▁▃▇███████
train label loss,█▅▃▁▁▁▁▁▁▁
train triplet loss,▁▆██▇▇▆▆▅▄
validation label accuracy,▁▃▇███▇███
validation label loss,█▇▃▁▁▁▂▁▁▁
validation triplet loss,█▄▂▂▃▂▃▂▁▁

0,1
latent separation accuracy,1.0
total train loss,0.46114
total validation loss,0.4213
train label accuracy,0.9963
train label loss,0.00396
train triplet loss,0.45718
validation label accuracy,1.0
validation label loss,3e-05
validation triplet loss,0.42127


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670827070871988, max=1.0…

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
100%|██████████| 10/10 [00:08<00:00,  1.12it/s]
  lambda data: self._console_raw_callback("stderr", data),
100%|██████████| 3/3 [00:53<00:00, 18.00s/it]


validation triplet loss: [0.8028333187103271, 0.13609068095684052, 0.12557488679885864, 0.08941909670829773, 0.10547396540641785, 0.08442936837673187, 0.06388864666223526, 0.06577518582344055, 0.07241745293140411, 0.06480774283409119] validation total loss: [1.3778021335601807, 0.20809820294380188, 0.26130321621894836, 0.09969512373209, 0.11797698587179184, 0.09200815856456757, 0.06567082554101944, 0.06679460406303406, 0.0747700035572052, 0.06721905618906021] validation accuracy: [0.3, 0.9875, 0.9166666666666666, 0.9916666666666667, 0.9875, 0.9958333333333333, 1.0, 1.0, 1.0, 1.0]


  0%|          | 0/3 [00:00<?, ?it/s]

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
latent separation accuracy,▁▁▁▁▁▁▁▁▁▁
total train loss,█▄▂▂▁▁▁▁▁▁
total validation loss,█▂▂▁▁▁▁▁▁▁
train label accuracy,▁▆████████
train label loss,█▃▂▁▁▁▁▁▁▁
train triplet loss,█▅▃▂▂▂▁▁▁▁
validation label accuracy,▁█▇███████
validation label loss,█▂▃▁▁▁▁▁▁▁
validation triplet loss,█▂▂▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.01274
total validation loss,0.06722
train label accuracy,1.0
train label loss,0.00069
train triplet loss,0.01205
validation label accuracy,1.0
validation label loss,0.00241
validation triplet loss,0.06481


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670353322600324, max=1.0…

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
100%|██████████| 10/10 [00:03<00:00,  2.74it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.6906775236129761, 0.34843140840530396, 0.23155581951141357, 0.07811810821294785, 0.09335731714963913, 0.11232209950685501, 0.09029556065797806, 0.10467633605003357, 0.05787605419754982, 0.08762605488300323] validation total loss: [1.3100078105926514, 0.9633222222328186, 0.8903534412384033, 0.7502300143241882, 0.7766430974006653, 0.763231635093689, 0.7351425886154175, 0.7551455497741699, 0.7184464335441589, 0.7308791875839233] validation accuracy: [0.19583333333333333, 0.3416666666666667, 0.25416666666666665, 0.25833333333333336, 0.22083333333333333, 0.2875, 0.2625, 0.26666666666666666, 0.25833333333333336, 0.2125]


0,1
latent separation accuracy,██▅▆▅▅▆▆▁▅
total train loss,█▅▄▃▃▂▁▁▁▂
total validation loss,█▄▃▁▂▂▁▁▁▁
train label accuracy,▄▇▅▃▁▂▆█▇▂
train label loss,▁▂▃█▆▅▃▂▃▆
train triplet loss,█▅▄▃▂▁▁▁▁▁
validation label accuracy,▁█▄▄▂▅▄▄▄▂
validation label loss,▁▁▅▇█▅▄▅▆▄
validation triplet loss,█▄▃▁▁▂▁▂▁▁

0,1
latent separation accuracy,0.99167
total train loss,0.72329
total validation loss,0.73088
train label accuracy,0.22685
train label loss,0.71614
train triplet loss,0.00715
validation label accuracy,0.2125
validation label loss,0.64325
validation triplet loss,0.08763


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016671004837068418, max=1.0…

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
100%|██████████| 10/10 [00:03<00:00,  3.17it/s]
  lambda data: self._console_raw_callback("stderr", data),


validation triplet loss: [0.7550578117370605, 0.6421393752098083, 0.6172167658805847, 0.5402910113334656, 0.5789048671722412, 0.5647134184837341, 0.5232914686203003, 0.5422380566596985, 0.5176268815994263, 0.5174731016159058] validation total loss: [1.11830735206604, 0.7451427578926086, 0.8025115728378296, 0.5412409901618958, 0.6013852953910828, 0.5834274291992188, 0.5234407782554626, 0.5433118939399719, 0.5177183747291565, 0.5175617933273315] validation accuracy: [0.575, 0.9875, 0.8791666666666667, 1.0, 0.9833333333333333, 0.9958333333333333, 1.0, 1.0, 1.0, 1.0]


0,1
latent separation accuracy,▁█████████
total train loss,█▆▃▂▂▂▂▁▁▁
total validation loss,█▄▄▁▂▂▁▁▁▁
train label accuracy,▁▆████████
train label loss,█▄▂▁▁▁▁▁▁▁
train triplet loss,▁▇█▇▆▆▆▅▄▄
validation label accuracy,▁█▆███████
validation label loss,█▃▅▁▁▁▁▁▁▁
validation triplet loss,█▅▄▂▃▂▁▂▁▁

0,1
latent separation accuracy,1.0
total train loss,0.53714
total validation loss,0.51756
train label accuracy,0.99815
train label loss,0.0034
train triplet loss,0.53374
validation label accuracy,1.0
validation label loss,9e-05
validation triplet loss,0.51747


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670724532256525, max=1.0…

100%|██████████| 10/10 [00:03<00:00,  2.72it/s]
  lambda data: self._console_raw_callback("stderr", data),
100%|██████████| 3/3 [00:55<00:00, 18.41s/it]

validation triplet loss: [0.7328863143920898, 0.12930816411972046, 0.09532110393047333, 0.046415071934461594, 0.04093194007873535, 0.05198988318443298, 0.0608380101621151, 0.04322240874171257, 0.031808458268642426, 0.06013976410031319] validation total loss: [1.2565560340881348, 0.2879713773727417, 0.11753387749195099, 0.06095512956380844, 0.04550860449671745, 0.052376218140125275, 0.08768932521343231, 0.045577868819236755, 0.04412916675209999, 0.12293985486030579] validation accuracy: [0.4375, 0.875, 0.9791666666666666, 0.9958333333333333, 1.0, 1.0, 0.9875, 1.0, 0.9958333333333333, 0.95]





0,1
latent separation accuracy,▁█████████
total train loss,█▄▂▁▁▁▁▁▁▁
total validation loss,█▂▁▁▁▁▁▁▁▁
train label accuracy,▁▆▇███████
train label loss,█▃▂▁▁▁▁▁▁▁
train triplet loss,█▅▃▂▂▁▁▁▁▁
validation label accuracy,▁▆███████▇
validation label loss,█▃▁▁▁▁▁▁▁▂
validation triplet loss,█▂▂▁▁▁▁▁▁▁

0,1
latent separation accuracy,1.0
total train loss,0.03143
total validation loss,0.12294
train label accuracy,0.99259
train label loss,0.01059
train triplet loss,0.02084
validation label accuracy,0.95
validation label loss,0.0628
validation triplet loss,0.06014
