In [2]:
import os
# base_dir = os.path.abspath('/mnt/ws/home/xyu/ConceptualAlignmentLanguage/tripletNCE')
base_dir = os.path.abspath('/mnt/dv/wid/projects3/Rogers-nsf-ind-diff/sid/Projects/ConceptualAlignmentLanguage/tripletNCE')
save_dir = os.path.join(base_dir,'results')
data_dir = os.path.join(base_dir,'data')

In [4]:
import torch
# torch.manual_seed(0)
# import wandb
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
from torch.utils.data import TensorDataset,Dataset, random_split
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.transforms import Resize

import torchvision
from tqdm import tqdm
import random
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200


from sklearn import linear_model
from sklearn.preprocessing import StandardScaler

# from neurora.rdm_corr import rdm_correlation_spearman

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Define model

In [6]:
class TripletLabelModel(nn.Module):
    def __init__(self, encoded_space_dim=64, num_classes=10):
        super().__init__()
        ""
        ### Convolutional section
       ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
    
        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)
        ### Linear section
        ## changed 32*4*4 to 32*2*2
        self.encoder_lin = nn.Sequential(
            nn.Linear(32*2*2, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(True),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.Linear(128, encoded_space_dim)
        )

        ## triplet projection module
        self.decoder_triplet_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 32),
            nn.ReLU(True)
         
        )
        ##labeling module
        self.decoder_labels_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 32),
            nn.ReLU(True),
            nn.Linear(32, 16),
            nn.ReLU(True),
            nn.Linear(16, num_classes),
        )

        ### initialize weights using xavier initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)
    
        
    def forward(self, x, y=None):
        batch_s = x.size(0)
        img_features = self.encoder_cnn(x)
        img_features = self.flatten(img_features)
        
        enc_latent = self.encoder_lin(img_features)

        triplet_latent = self.decoder_triplet_lin(enc_latent)
        label = self.decoder_labels_lin(enc_latent)
        # label = F.softmax(label,dim=1)
        return enc_latent, label

In [6]:
### custom loss computing triplet loss and labeling loss


class CustomLoss(nn.Module):
    def __init__(self, margin=10):
        super(CustomLoss, self).__init__()
        self.margin = margin
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, anchor, positive, negative, label, pred_label):
        cosine_sim = torch.nn.CosineSimilarity(1)
        # distance_positive = torch.tensor(1)-cosine_sim(anchor,positive)
   
        # distance_negative = torch.tensor(1)-cosine_sim(anchor,negative)

        # triplet_loss = torch.maximum(distance_positive - distance_negative + self.margin, torch.tensor(0))
        # triplet_loss = torch.sum(triplet_loss)
        triplet_loss = (nn.TripletMarginWithDistanceLoss( distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)))
        triplet_loss = triplet_loss(anchor, positive, negative)
        label_loss = F.binary_cross_entropy_with_logits(pred_label.float(), label.float())
        total_loss = triplet_loss + label_loss
        return triplet_loss, label_loss, total_loss

In [7]:
# t = TripletLabelModel()
# cifar_model_path = '../../data/CIFAR10_NCE_i_1e-05_50.pth'
# t.load_state_dict(torch.load(cifar_model_path))

### Training functions

In [8]:


class TrainModels(nn.Module):
    def __init__(self, latent_dims, num_classes, weights_path=None):
        super(TrainModels, self).__init__()
        self.triplet_lab_model = TripletLabelModel(latent_dims, 10) ### load cifar model
        if weights_path!=None:
            cifar_model_path = '/mnt/ws/home/xyu/ConceptualAlignmentLanguage/tripletNCE/data/CIFAR10_NCE_i_1e-05_50.pth'
            self.triplet_lab_model.load_state_dict(torch.load(cifar_model_path))
            self.triplet_lab_model.decoder_labels_lin[4] = nn.Linear(16, num_classes)
        self.custom_loss = CustomLoss()
        self.num_classes = num_classes
    
    def forward(self, anchor_im, positive_im, negative_im):
        anchor_latent, anchor_label = self.triplet_lab_model(anchor_im)
        positive_latent, _ = self.triplet_lab_model(positive_im)
        negative_latent, _ = self.triplet_lab_model(negative_im)

        return anchor_latent, positive_latent, negative_latent, anchor_label

    def test_epoch(self, test_data):
    # Set evaluation mode for encoder and decoder
        self.eval()
        with torch.no_grad(): # No need to track the gradients
            # Define the lists to store the outputs for each batch
            test_triplet_loss = []
            test_label_loss = []
            test_total_loss = []
            total = 0
            correct = 0
            for anchor_ims, contrast_ims, labels in test_data:
                # Move tensor to the proper device
                anchor_ims = anchor_ims.to(device)
                contrast_ims = contrast_ims.to(device)
                labels = F.one_hot(labels, num_classes=self.num_classes)
                labels = labels.to(device)
                anchor_latent, positive_latent, negative_latent, pred_label = self.forward(anchor_ims, anchor_ims,contrast_ims) 
                # Append the network output and the original image to the lists
                triplet_loss, label_loss, total_loss = self.custom_loss(anchor_latent,
                                                                positive_latent, 
                                                                negative_latent, 
                                                                labels,
                                                                pred_label)
                total += labels.size(0)
                correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(labels, dim = 1)).sum().item()
                test_triplet_loss.append(triplet_loss.item())
                test_label_loss.append(label_loss.item())
                test_total_loss.append(total_loss.item())
        test_triplet_loss = sum(test_triplet_loss)/len(test_triplet_loss)
        test_label_loss = sum(test_label_loss)/len(test_label_loss)
        test_total_loss = sum(test_total_loss)/len(test_total_loss)
        test_accuracy = correct/total
        return test_triplet_loss, test_label_loss, test_total_loss, test_accuracy

    def test_epoch_calculate_representation_separation(self, test_data):
    # Set evaluation mode for encoder and decoder
        self.eval()
        with torch.no_grad(): # No need to track the gradients
            accuracies = []
            for anchor_ims, contrast_ims, labels in test_data:
                # Move tensor to the proper device
                anchor_ims = anchor_ims.to(device)
                contrast_ims = contrast_ims.to(device)
                # labels = F.one_hot(labels, num_classes=self.num_classes)
                # labels = labels.to(device)
                anchor_latent, _, _, _ = self.forward(anchor_ims, anchor_ims,contrast_ims) 
                # use sklearn to predict labels from anchor_latent
                # calculate accuracy
                # x's are anchor_latent and y's are labels
                # append accuracy to list
                # put anchor_latent and labels on cpu and convert to numpy

          
                anchor_latent = anchor_latent.cpu().numpy()
                ### standard scale the data in anchor_latent before fitting to the model
                anchor_latent = StandardScaler().fit_transform(anchor_latent)
                labels = labels.cpu().numpy()
                
                lm = linear_model.LogisticRegression()
                lm.fit(anchor_latent, labels)
                # convert labels to sklearn format
                accuracies.append(lm.score(anchor_latent, labels))
        accuracy = sum(accuracies)/len(accuracies)
        return accuracy

    def train_epoch(self, train_data, optimizer, train_mode):
        self.train()
        train_triplet_loss = []
        train_label_loss = []
        train_total_loss = []
        correct = 0
        total = 0
        for anchor_ims, contrast_ims, labels in train_data:
            
            anchor_ims = anchor_ims.to(device)
            contrast_ims = contrast_ims.to(device)
            labels = F.one_hot(labels, num_classes=self.num_classes)
            labels = labels.to(device)

            optimizer.zero_grad()
            anchor_latent, positive_latent, negative_latent, pred_label = self.forward(anchor_ims, anchor_ims,contrast_ims) 
           
           
           
            triplet_loss, label_loss, total_loss = self.custom_loss(anchor_latent,
                                                                positive_latent, 
                                                                negative_latent, 
                                                                labels,
                                                                pred_label)
            
            
            if train_mode==0:
                triplet_loss.backward()
            elif train_mode==1:
                label_loss.backward()
            elif train_mode==2:
                total_loss.backward()

            optimizer.step()
            train_triplet_loss.append(triplet_loss.item())
            train_label_loss.append(label_loss.item())
            train_total_loss.append(total_loss.item())
            total += labels.size(0)
            correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(labels, dim = 1)).sum().item()
        train_triplet_loss = sum(train_triplet_loss)/len(train_triplet_loss)
        train_label_loss = sum(train_label_loss)/len(train_label_loss)
        train_total_loss = sum(train_total_loss)/len(train_total_loss)
        train_accuracy = correct/total
        return train_triplet_loss, train_label_loss, train_total_loss, train_accuracy

    def training_loop(self, train_data, test_data,train_mode,
                      epochs, optimizer):
        train_losses = []
        val_losses = []
        train_triplet_losses = []
        val_triplet_losses = []
        train_label_losses = []
        val_label_losses = []
        train_accuracies = []
        val_accuracies = []
        latent_separation_accuracy = 0
        for epoch in tqdm(range(epochs)):
          train_triplet_loss, train_label_loss, train_total_loss, train_accuracy =self.train_epoch(train_data, optimizer, 
                                             train_mode)
          test_triplet_loss, test_label_loss, test_total_loss, test_accuracy = self.test_epoch(test_data)
          separation_accuracy = self.test_epoch_calculate_representation_separation(test_data)
          train_losses.append(train_total_loss)
          val_losses.append(test_total_loss)
          train_triplet_losses.append(train_triplet_loss)
          val_triplet_losses.append(test_triplet_loss)
          train_label_losses.append(train_label_loss)
          val_label_losses.append(test_label_loss)
          train_accuracies.append(train_accuracy)
          val_accuracies.append(test_accuracy)
          wandb.log({"train triplet loss": train_triplet_loss, 
            "train label loss":train_label_loss, 
            "validation triplet loss":test_triplet_loss, 
            "validation label loss":test_label_loss, 
            "total train loss":train_total_loss, 
            "total validation loss":test_total_loss, 
            "train label accuracy":train_accuracy, 
            "validation label accuracy":test_accuracy,
            'latent separation accuracy':separation_accuracy})
        return train_triplet_losses, train_label_losses, val_triplet_losses, val_label_losses ,train_losses, val_losses, train_accuracies, val_accuracies


In [9]:
set_A_ims = np.load(os.path.join(data_dir, 'set_A.npy'))
set_B_ims = np.load(os.path.join(data_dir, 'set_B.npy'))
set_C_ims = np.load(os.path.join(data_dir, 'set_C.npy'))
set_A_labs = np.load(os.path.join(data_dir, 'set_A_labs.npy'))
set_B_labs = np.load(os.path.join(data_dir, 'set_B_labs.npy'))
set_C_labs = np.load(os.path.join(data_dir, 'set_C_labs.npy'))


In [11]:

set_A_sub_ims =[]
set_B_sub_ims =[]

set_C_sub_ims =[]

set_A_sub_labs =[]
set_B_sub_labs =[]
set_C_sub_labs =[]


for i in range (4):
    sub_main = set_A_ims[i*600:(i*600)+600]
    labels_main = set_A_labs[i*600:(i*600)+600]
    np.random.seed(711)
    np.random.shuffle(sub_main)
    np.random.seed(711)
    np.random.shuffle(labels_main)

    set_A_sub_ims.append(sub_main[:30])
    set_B_sub_ims.append(sub_main[:15])
    set_B_sub_ims.append(sub_main[30:45])
    set_C_sub_ims.append(sub_main[35:65])

    set_A_sub_labs.append(labels_main[:30])
    set_B_sub_labs.append(labels_main[:15])
    set_B_sub_labs.append(labels_main[30:45])
    set_C_sub_labs.append(labels_main[35:65])


    




##flatten set_A_sub_ims into an array of shape 120,64,64,3
set_A_sub_ims = np.concatenate(set_A_sub_ims)
set_B_sub_ims = np.concatenate(set_B_sub_ims)
set_C_sub_ims = np.concatenate(set_C_sub_ims)

set_A_sub_labs = np.concatenate(set_A_sub_labs)
set_B_sub_labs = np.concatenate(set_B_sub_labs)
set_C_sub_labs = np.concatenate(set_C_sub_labs)


A-B: 50% \
A-C: 0% \
B-C: 33.33%

In [12]:

###initialize weights and bias tracking
def wandb_init(epochs, lr, train_mode, batch_size, model_number,data_set):
    wandb.init(project="ConceptualAlignment", settings=wandb.Settings(start_method="thread"))
    wandb.config = {
      "learning_rate": lr,
      "epochs": epochs,
      "batch_size": batch_size, 
      # "label_ratio":label_ratio, 
      "model_number": model_number,
      "dataset": data_set,
      "train_mode":train_mode,
    }
    train_mode_dict = {0:'triplet', 1:'label', 2:'label_and_triplet'}
    wandb.run.name = f'{data_set}_{train_mode_dict[train_mode]}_{model_number}'
    wandb.run.save()
     

In [13]:

def main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims):
  if os.path.isdir(save_dir):
    pass
  else:
    os.mkdir(save_dir)

  np.random.seed(42)
  torch.manual_seed(42)
  
  # test_intervals = [(540, 600), (1140, 1200), (1740, 1800), (2340, 2400)]
  test_intervals = [(25, 30), (55, 60), (85, 90), (115, 120)]
  # initialize an empty list to hold the indices
  val_indices = []

  # loop through the intervals and append the indices to the list
  for start, stop in test_intervals:
      val_indices.extend(list(range(start, stop)))

  # train_indices = (np.setdiff1d(np.arange(2400),np.array(val_indices)))
  train_indices = (np.setdiff1d(np.arange(120),np.array(val_indices)))

  # np.random.seed(56)
  # contrast_indices  = np.concatenate((np.random.choice(np.arange(start=600, stop=2400), 600, replace=False),
  #               np.random.choice(np.concatenate((np.arange(start=0, stop=600), np.arange(start=1200, stop=2400))), 600, replace=False),
  #               np.random.choice(np.concatenate((np.arange(start=0, stop=1200), np.arange(start=1800, stop=2400))), 600, replace=False),
  #               np.random.choice(np.arange(start=1800, stop=2400), 600, replace=False)))
  contrast_indices  = np.concatenate((np.random.choice(np.arange(start=30, stop=120), 30, replace=False),
                np.random.choice(np.concatenate((np.arange(start=0, stop=30), np.arange(start=60, stop=120))), 30, replace=False),
                np.random.choice(np.concatenate((np.arange(start=0, stop=60), np.arange(start=90, stop=120))), 30, replace=False),
                np.random.choice(np.arange(start=0, stop=90), 30, replace=False)))

  # for data_set in ['set_A','set_A2','set_B','set_C']:
  for data_set in ['set_A']:
    for train_mode in tqdm(range(2, 3)):
     # torch.manual_seed(0)
      for model in range(num_models):
        wandb_init(epochs, lr, train_mode, batch_size, model,data_set)
        weights_path = f'../../data/cifar_models/m{model}.pth'

        # if data_set=='set_A':
        #   train_data = TensorDataset(Resize(32)(torch.tensor(set_A_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_A_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
        #                              torch.tensor(set_A_labs).to(torch.int64))
        # elif data_set=='set_B':
        #   train_data = TensorDataset(Resize(32)(torch.tensor(set_B_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_B_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
        #                              torch.tensor(set_B_labs).to(torch.int64))
        # elif data_set=='set_C':
        #   train_data = TensorDataset(Resize(32)(torch.tensor(set_C_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_C_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
        #                              torch.tensor(set_C_labs).to(torch.int64))
        # if data_set=='set_A2':
        #   train_data = TensorDataset(Resize(32)(torch.tensor(set_A_sub_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_A_sub_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
        #                              torch.tensor(set_A_sub_labs).to(torch.int64))
        if data_set=='set_A':
          train_data = TensorDataset(Resize(32)(torch.tensor(set_A_sub_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_A_sub_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
                                     torch.tensor(set_A_sub_labs).to(torch.int64))
        # elif data_set=='set_B':
        #   train_data = TensorDataset(Resize(32)(torch.tensor(set_B_sub_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_B_sub_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
        #                              torch.tensor(set_B_sub_labs).to(torch.int64))
        # elif data_set=='set_C':
        #   train_data = TensorDataset(Resize(32)(torch.tensor(set_C_sub_ims.transpose(0,3,1,2)/255).float()), Resize(32)(torch.tensor(set_C_sub_ims[contrast_indices].transpose(0,3,1,2)/255).float()),\
        #                              torch.tensor(set_C_sub_labs).to(torch.int64))
        
        train_size = int(0.7 * len(train_data))
        val_size = len(train_data) - train_size
        
        train_data, val_data = random_split(train_data, [train_size, val_size], generator=torch.Generator().manual_seed(42))
       

        train_data = torch.utils.data.DataLoader(train_data, 
                                                batch_size=batch_size,
                                              shuffle=True)
        val_data = torch.utils.data.DataLoader(val_data, 
                                                batch_size=batch_size,
                                              shuffle=True)
        


        train_obj = TrainModels(latent_dims, num_classes, weights_path).to(device) # GPU
        optimizer = torch.optim.Adam(train_obj.parameters(), lr=lr, weight_decay=1e-05)
        train_triplet_losses, train_label_losses, \
          val_triplet_losses, val_label_losses, \
            train_losses, val_losses, train_accuracies, val_accuracies= train_obj.training_loop(train_data = train_data,
                                                            test_data = val_data,
                                                            epochs = epochs,
                                                            optimizer = optimizer, 
                                                            train_mode = train_mode)


        

        print('validation triplet loss:',val_triplet_losses[-1],'validation total loss:',val_losses[-1],'validation accuracy:',val_accuracies[-1])
        # wandb.log({"train_img_loss": train_img_loss, 
        #           "train_label_loss":train_label_loss, 
        #           "val_img_loss":val_img_loss, 
        #           "val_label_loss":val_label_loss, 
        #           "train_losses":train_losses, 
        #           "val_losses":val_losses, 
        #           "train_accuracy":train_accuracy, 
        #           "val_accuracy":val_accuracy})
        train_mode_dict = {0:'triplet', 1:'label',2:'label_and_triplet' }
        torch.save(train_obj.triplet_lab_model.state_dict(), os.path.join(save_dir,f'{model}_{data_set}_{train_mode_dict[train_mode]}_{round(val_accuracies[-1], 3)}.pth'))
  return val_data
        



In [14]:
wandb.finish()

num_classes = 4 # Number of unique class labels in the dataset
latent_dims = 64
epochs = 1000
lr = 0.005
num_models = 1
batch_size = 256
save_dir = save_dir
val_data = main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims)
wandb.finish()

  0%|          | 0/1 [00:00<?, ?it/s]Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myuxizheng[0m. Use [1m`wandb login --relogin`[0m to force relogin


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

validation triplet loss: 0.10054220259189606 validation total loss: 0.995544970035553 validation accuracy: 0.8611111111111112




0,1
latent separation accuracy,██████████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total train loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
total validation loss,█▃▇▆▆▆▆▅▅▅▅▅▅▅▅▄▄▄▄▄▄▄▄▄▃▁▂▅█▇▆▆▆▆▆▅▅▅▅▅
train label accuracy,▁███████████████████████████████████████
train label loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train triplet loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label accuracy,▁▆▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▇▇████████████
validation label loss,▂▃▆▅▅▅▅▅▅▅▅▅▄▄▄▄▄▄▄▄▄▄▃▃▃▁▃▆█▇▇▆▆▆▆▆▆▆▆▅
validation triplet loss,█▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
latent separation accuracy,0.97222
total train loss,0.0
total validation loss,0.99554
train label accuracy,1.0
train label loss,0.0
train triplet loss,0.0
validation label accuracy,0.86111
validation label loss,0.895
validation triplet loss,0.10054


In [10]:
class MaxLogitsLoss(torch.nn.Module):
    def __init__(self):
        super(MaxLogitsLoss, self).__init__()

    def forward(self, logits, targets):
        max_logits = torch.sum(logits * targets, dim=1)
        loss = -max_logits.mean()
        return loss

model_path = {40: "./results/0_set_A_label_and_triplet_0.417.pth",
              50: "./results/0_set_A_label_and_triplet_0.528.pth",
              60: "./results/0_set_A_label_and_triplet_0.639.pth",
              70: "./results/0_set_A_label_and_triplet_0.722.pth",
              80: "./results/0_set_A_label_and_triplet_0.806.pth"}

chooses = [40, 50, 60, 70, 80]

for choose in chooses:
    triplet_lab_model = TripletLabelModel(latent_dims, 4)
    triplet_lab_model.load_state_dict(torch.load(model_path[choose]))
    triplet_lab_model = triplet_lab_model.to(device)
    criterion = MaxLogitsLoss()

    optimizer = torch.optim.Adam(triplet_lab_model.parameters(), lr=0.001)
    num_epochs = 50

    for epoch in range(num_epochs):
        correct = 0
        total = 0
        for anchor_ims, contrast_ims, labels in val_data:
            anchor_ims = anchor_ims.to(device)
            contrast_ims = contrast_ims.to(device)
            labels = F.one_hot(labels, num_classes=4)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            _, label_logits = triplet_lab_model(anchor_ims)
            loss = criterion(label_logits, labels)
            loss.backward()
            optimizer.step()
            
            # Calculate accuracy
            _, predicted = torch.max(label_logits.data, 1)
            _, labels_max = torch.max(labels.data, 1)
            total += labels.size(0)
            correct += (predicted == labels_max).sum().item()

            accuracy = 100 * correct / total
        
        # print(f"Epoch {epoch}, Loss: {loss.item()}, Accuracy: {accuracy}%")

    print(f"Pretrained Model: {choose}, Epoch {epoch}, Loss: {loss.item()}, Accuracy: {accuracy}%")
    save_path = os.path.join(save_dir, f'0_set_A_maxlogits_acc_{choose}.pth')
    torch.save(triplet_lab_model.state_dict(), save_path)



NameError: name 'val_data' is not defined

In [None]:
# want the final layer logits , compute error based on the maximum logits ground truth 

In [20]:
import pickle

data_to_save = []

for batch in val_data:
    # Here, we assume each batch contains data and labels
    anchor_ims, contrast_ims, labels = batch
    # Convert tensors to numpy arrays or another suitable format for saving
    data_to_save.append((anchor_ims.numpy(), contrast_ims.numpy(), labels.numpy()))

# Save the data to a file
with open('val_data.pkl', 'wb') as file:
    pickle.dump(data_to_save, file)

# sid changes

In [12]:
import pickle

In [13]:
val_data = pickle.load(open('val_data.pkl', 'rb'))

In [21]:
# class MaxLogitsLoss(torch.nn.Module):
#     def __init__(self):
#         super(MaxLogitsLoss, self).__init__()

#     def forward(self, logits, targets):
#         max_logits = torch.sum(logits * targets, dim=1)
#         loss = -max_logits.mean()
#         return loss
latent_dims = 64

model_path = {40: "./results/0_set_A_label_and_triplet_0.417.pth",
              50: "./results/0_set_A_label_and_triplet_0.528.pth",
              60: "./results/0_set_A_label_and_triplet_0.639.pth",
              70: "./results/0_set_A_label_and_triplet_0.722.pth",
              80: "./results/0_set_A_label_and_triplet_0.806.pth"}

chooses = [40, 50, 60, 70, 80]

for choose in chooses:
    triplet_lab_model = TripletLabelModel(latent_dims, 4)
    triplet_lab_model.load_state_dict(torch.load(model_path[choose]))
    triplet_lab_model = triplet_lab_model.to(device)
    # criterion = MaxLogitsLoss()
    criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(triplet_lab_model.parameters(), lr=0.001)
    num_epochs = 50

    for epoch in range(num_epochs):
        correct = 0
        total = 0
        for anchor_ims, contrast_ims, labels in val_data:
            anchor_ims = torch.from_numpy(anchor_ims).to(device)
            contrast_ims = torch.from_numpy(contrast_ims).to(device)
            # labels = F.one_hot(labels, num_classes=4)
            # labels = labels.to(device)
            
            optimizer.zero_grad()
            _, label_logits = triplet_lab_model(anchor_ims)
            lab_softmax = F.softmax(label_logits, dim=1)
            import ipdb; ipdb.set_trace()
            loss = criterion(label_logits, labels)
            loss.backward()
            optimizer.step()
            
            # Calculate accuracy
            _, predicted = torch.max(label_logits.data, 1)
            _, labels_max = torch.max(labels.data, 1)
            total += labels.size(0)
            correct += (predicted == labels_max).sum().item()

            accuracy = 100 * correct / total
        
        # print(f"Epoch {epoch}, Loss: {loss.item()}, Accuracy: {accuracy}%")

    print(f"Pretrained Model: {choose}, Epoch {epoch}, Loss: {loss.item()}, Accuracy: {accuracy}%")
    save_path = os.path.join(save_dir, f'0_set_A_maxlogits_acc_{choose}.pth')
    torch.save(triplet_lab_model.state_dict(), save_path)



> [0;32m/tmp/ipykernel_21655/377933159.py[0m(42)[0;36m<module>[0;34m()[0m
[0;32m     41 [0;31m            [0;32mimport[0m [0mipdb[0m[0;34m;[0m [0mipdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 42 [0;31m            [0mloss[0m [0;34m=[0m [0mcriterion[0m[0;34m([0m[0mlabel_logits[0m[0;34m,[0m [0mlabels[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     43 [0;31m            [0mloss[0m[0;34m.[0m[0mbackward[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


tensor([[1.3493e-01, 3.8855e-01, 2.5582e-02, 4.5094e-01],
        [3.6291e-01, 1.2532e-01, 6.3376e-02, 4.4839e-01],
        [3.7697e-01, 1.3558e-01, 1.0663e-02, 4.7679e-01],
        [2.7803e-01, 3.3652e-01, 5.9843e-02, 3.2560e-01],
        [8.8920e-02, 1.7551e-01, 2.1351e-02, 7.1422e-01],
        [2.7628e-01, 4.2909e-01, 9.8371e-02, 1.9626e-01],
        [1.1462e-01, 1.6420e-01, 5.8876e-01, 1.3242e-01],
        [6.9496e-02, 5.9394e-01, 4.7910e-02, 2.8865e-01],
        [4.1053e-02, 3.1276e-02, 8.5095e-03, 9.1916e-01],
        [3.2236e-01, 8.0509e-02, 2.3023e-02, 5.7411e-01],
        [3.7139e-02, 1.9197e-01, 5.6107e-01, 2.0982e-01],
        [4.8098e-02, 1.5658e-02, 7.6736e-03, 9.2857e-01],
        [3.3845e-01, 6.9600e-02, 7.1299e-03, 5.8482e-01],
        [2.2335e-01, 3.9030e-01, 8.9676e-02, 2.9668e-01],
        [3.4253e-04, 9.6428e-03, 9.8696e-01, 3.0547e-03],
        [7.3291e-02, 2.6609e-01, 1.8252e-01, 4.7810e-01],
        [1.2007e-04, 4.2824e-03, 9.9516e-01, 4.3293e-04],
        [1.350