In [1]:
import os
base_dir = os.path.abspath('../..')
save_dir = os.path.join(base_dir,'results')
data_dir = os.path.join(base_dir,'data')

In [2]:
import torch
# torch.manual_seed(0)
import wandb
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
from torch.utils.data import TensorDataset,Dataset
from torch.utils.data.sampler import SubsetRandomSampler

import torchvision
from tqdm import tqdm
import random
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200


from sklearn import linear_model
# from neurora.rdm_corr import rdm_correlation_spearman

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Define model

In [4]:
class TripletLabelModel(nn.Module):
    def __init__(self, encoded_space_dim=10, num_classes=4):
        super().__init__()
        ""
        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
    
        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)
        ### Linear section
        self.encoder_lin = nn.Sequential(
            nn.Linear(32*4*4, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, encoded_space_dim)
        )
        ##labeling module
        self.decoder_labels_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 10),
            nn.ReLU(True),
            nn.Linear(10, num_classes),
            nn.ReLU(True)
        )
    
        
    def forward(self, x, y=None):
        batch_s = x.size(0)
        img_features = self.encoder_cnn(x)
        img_features = self.flatten(img_features)
        
        out_latent = self.encoder_lin(img_features)

        label = self.decoder_labels_lin(out_latent)
        label = F.softmax(label,dim=1)
        return out_latent, label

In [5]:
### custom loss computing triplet loss and labeling loss


class CustomLoss(nn.Module):
    def __init__(self, margin=10):
        super(CustomLoss, self).__init__()
        self.margin = margin
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, anchor, positive, negative, label, pred_label):
        cosine_sim = torch.nn.CosineSimilarity(1)
        distance_positive = torch.tensor(1)-cosine_sim(anchor,positive)
   
        distance_negative = torch.tensor(1)-cosine_sim(anchor,negative)

        triplet_loss = torch.maximum(distance_positive - distance_negative + self.margin, torch.tensor(0))
        triplet_loss = torch.sum(triplet_loss)
        label_loss = self.cross_entropy(pred_label, label.float())
        total_loss = triplet_loss + label_loss
        return triplet_loss, label_loss, total_loss

### Training functions

In [6]:


class TrainModels(nn.Module):
    def __init__(self, latent_dims, num_classes):
        super(TrainModels, self).__init__()
        self.triplet_lab_model = TripletLabelModel(latent_dims, num_classes)
        self.custom_loss = CustomLoss()
        self.num_classes = num_classes
    
    def forward(self, anchor_im, positive_im, negative_im):
        anchor_latent, anchor_label = self.triplet_lab_model(anchor_im)
        positive_latent, _ = self.triplet_lab_model(positive_im)
        negative_latent, _ = self.triplet_lab_model(negative_im)

        return anchor_latent, positive_latent, negative_latent, anchor_label

    def test_epoch(self, test_data):
    # Set evaluation mode for encoder and decoder
        self.eval()
        with torch.no_grad(): # No need to track the gradients
            # Define the lists to store the outputs for each batch
            test_triplet_loss = []
            test_label_loss = []
            test_total_loss = []
            total = 0
            correct = 0
            for anchor_ims, contrast_ims, labels in test_data:
                # Move tensor to the proper device
                anchor_ims = anchor_ims.to(device)
                contrast_ims = contrast_ims.to(device)
                labels = F.one_hot(labels, num_classes=self.num_classes)
                labels = labels.to(device)
                anchor_latent, positive_latent, negative_latent, pred_label = self.forward(anchor_ims, anchor_ims,contrast_ims) 
                # Append the network output and the original image to the lists
                triplet_loss, label_loss, total_loss = self.custom_loss(anchor_latent,
                                                                positive_latent, 
                                                                negative_latent, 
                                                                labels,
                                                                pred_label)
                total += labels.size(0)
                correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(labels, dim = 1)).sum().item()
                test_triplet_loss.append(triplet_loss.item())
                test_label_loss.append(label_loss.item())
                test_total_loss.append(total_loss.item())
        test_triplet_loss = sum(test_triplet_loss)/len(test_triplet_loss)
        test_label_loss = sum(test_label_loss)/len(test_label_loss)
        test_total_loss = sum(test_total_loss)/len(test_total_loss)
        test_accuracy = correct/total
        return test_triplet_loss, test_label_loss, test_total_loss, test_accuracy

    def test_epoch_calculate_representation_separation(self, test_data):
    # Set evaluation mode for encoder and decoder
        self.eval()
        with torch.no_grad(): # No need to track the gradients
            accuracies = []
            for anchor_ims, contrast_ims, labels in test_data:
                # Move tensor to the proper device
                anchor_ims = anchor_ims.to(device)
                contrast_ims = contrast_ims.to(device)
                # labels = F.one_hot(labels, num_classes=self.num_classes)
                # labels = labels.to(device)
                anchor_latent, _, _, _ = self.forward(anchor_ims, anchor_ims,contrast_ims) 
                # use sklearn to predict labels from anchor_latent
                # calculate accuracy
                # x's are anchor_latent and y's are labels
                # append accuracy to list
                # put anchor_latent and labels on cpu and convert to numpy
                anchor_latent = anchor_latent.cpu().numpy()
                lm = linear_model.LogisticRegression()
                lm.fit(anchor_latent, labels)
                # convert labels to sklearn format
                accuracies.append(lm.score(anchor_latent, labels))
        accuracy = sum(accuracies)/len(accuracies)
        return accuracy

    def train_epoch(self, train_data, optimizer, train_mode):
        self.train()
        train_triplet_loss = []
        train_label_loss = []
        train_total_loss = []
        correct = 0
        total = 0
        for anchor_ims, contrast_ims, labels in train_data:
            
            anchor_ims = anchor_ims.to(device)
            contrast_ims = contrast_ims.to(device)
            labels = F.one_hot(labels, num_classes=self.num_classes)
            labels = labels.to(device)

            optimizer.zero_grad()
            anchor_latent, positive_latent, negative_latent, pred_label = self.forward(anchor_ims, anchor_ims,contrast_ims) 
           
           
           
            triplet_loss, label_loss, total_loss = self.custom_loss(anchor_latent,
                                                                positive_latent, 
                                                                negative_latent, 
                                                                labels,
                                                                pred_label)
            
            
            if train_mode==0:
                triplet_loss.backward()
            elif train_mode==1:
                label_loss.backward()
            elif train_mode==2:
                total_loss.backward()

        

            optimizer.step()
            train_triplet_loss.append(triplet_loss.item())
            train_label_loss.append(label_loss.item())
            train_total_loss.append(total_loss.item())
            total += labels.size(0)
            correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(labels, dim = 1)).sum().item()
        train_triplet_loss = sum(train_triplet_loss)/len(train_triplet_loss)
        train_label_loss = sum(train_label_loss)/len(train_label_loss)
        train_total_loss = sum(train_total_loss)/len(train_total_loss)
        train_accuracy = correct/total
        return train_triplet_loss, train_label_loss, train_total_loss, train_accuracy

    def training_loop(self, train_data, test_data,train_mode,
                      epochs, optimizer):
        train_losses = []
        val_losses = []
        train_triplet_losses = []
        val_triplet_losses = []
        train_label_losses = []
        val_label_losses = []
        train_accuracies = []
        val_accuracies = []
        latent_separation_accuracy = 0
        for epoch in tqdm(range(epochs)):
          train_triplet_loss, train_label_loss, train_total_loss, train_accuracy =self.train_epoch(train_data, optimizer, 
                                             train_mode)
          test_triplet_loss, test_label_loss, test_total_loss, test_accuracy = self.test_epoch(test_data)
          separation_accuracy = self.test_epoch_calculate_representation_separation(test_data)
          train_losses.append(train_total_loss)
          val_losses.append(test_total_loss)
          train_triplet_losses.append(train_triplet_loss)
          val_triplet_losses.append(test_triplet_loss)
          train_label_losses.append(train_label_loss)
          val_label_losses.append(test_label_loss)
          train_accuracies.append(train_accuracy)
          val_accuracies.append(test_accuracy)
          wandb.log({"train triplet loss": train_triplet_loss, 
            "train label loss":train_label_loss, 
            "validation triplet loss":test_triplet_loss, 
            "validation label loss":test_label_loss, 
            "total train loss":train_total_loss, 
            "total validation loss":test_total_loss, 
            "train label accuracy":train_accuracy, 
            "validation label accuracy":test_accuracy,
            'latent separation accuracy':separation_accuracy})
        return train_triplet_losses, train_label_losses, val_triplet_losses, val_label_losses ,train_losses, val_losses, train_accuracies, val_accuracies


In [7]:
set_A_ims = np.load(os.path.join(data_dir, 'set_A.npy'))
set_B_ims = np.load(os.path.join(data_dir, 'set_B.npy'))
set_C_ims= np.load(os.path.join(data_dir, 'set_C.npy'))
set_A_labs = np.load(os.path.join(data_dir, 'set_A_labs.npy'))
set_B_labs = np.load(os.path.join(data_dir, 'set_B_labs.npy'))
set_C_labs = np.load(os.path.join(data_dir, 'set_C_labs.npy'))


In [8]:

###initialize weights and bias tracking
def wandb_init(epochs, lr, train_mode, batch_size, model_number,data_set):
  wandb.init(project="ConceptualAlignmentLanguage", entity="psych-711",settings=wandb.Settings(start_method="thread"))
  wandb.config = {
    "learning_rate": lr,
    "epochs": epochs,
    "batch_size": batch_size, 
    # "label_ratio":label_ratio, 
    "model_number": model_number,
    "dataset": data_set,
    "train_mode":train_mode,
  }
  train_mode_dict = {0:'no label', 1:'label', 2:'label and triplet' }
  wandb.run.name = f'{data_set}_{train_mode_dict[train_mode]}_{model_number}'
  wandb.run.save()
     

In [9]:

def main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims):
  if os.path.isdir(save_dir):
    pass
  else:
    os.mkdir(save_dir)


  test_intervals = [(540, 600), (1140, 1200), (1740, 1800), (2340, 2400)]

  # initialize an empty list to hold the indices
  val_indices = []

  # loop through the intervals and append the indices to the list
  for start, stop in test_intervals:
      val_indices.extend(list(range(start, stop)))

  train_indices = (np.setdiff1d(np.arange(2400),np.array(val_indices)))

  np.random.seed(56)
  contrast_indices  = np.concatenate((np.random.choice(np.arange(start=600, stop=2400), 600, replace=False),
                np.random.choice(np.concatenate((np.arange(start=0, stop=600), np.arange(start=1200, stop=2400))), 600, replace=False),
                np.random.choice(np.concatenate((np.arange(start=0, stop=1200), np.arange(start=1800, stop=2400))), 600, replace=False),
                np.random.choice(np.arange(start=1800, stop=2400), 600, replace=False)))

  for data_set in ['set_A','set_B','set_C']:
    for train_mode in tqdm(range(3)):
     # torch.manual_seed(0)
      for model in range(num_models):
        wandb_init(epochs, lr, train_mode, batch_size, model,data_set)

        if data_set=='set_A':
          train_data = TensorDataset(torch.tensor(set_A_ims.transpose(0,3,1,2)/255).float(), torch.tensor(set_A_ims[contrast_indices].transpose(0,3,1,2)/255).float(),\
                                     torch.tensor(set_A_labs).to(torch.int64))
        elif data_set=='set_B':
          train_data = TensorDataset(torch.tensor(set_B_ims.transpose(0,3,1,2)/255).float(), torch.tensor(set_B_ims[contrast_indices].transpose(0,3,1,2)/255).float(),\
                                     torch.tensor(set_B_labs).to(torch.int64))
        elif data_set=='set_C':
          train_data = TensorDataset(torch.tensor(set_C_ims.transpose(0,3,1,2)/255).float(), torch.tensor(set_C_ims[contrast_indices].transpose(0,3,1,2)/255).float(),\
                                     torch.tensor(set_C_labs).to(torch.int64))
          
        val_data = torch.utils.data.Subset(train_data, val_indices)
        train_data = torch.utils.data.Subset(train_data, train_indices)
       

        train_data = torch.utils.data.DataLoader(train_data, 
                                                batch_size=batch_size,
                                              shuffle=True)
        val_data = torch.utils.data.DataLoader(val_data, 
                                                batch_size=batch_size,
                                              shuffle=True)
        
     

        train_obj = TrainModels(latent_dims, num_classes).to(device) # GPU
        optimizer = torch.optim.Adam(train_obj.parameters(), lr=lr, weight_decay=1e-05)
        train_triplet_losses, train_label_losses, \
          val_triplet_losses, val_label_losses, \
            train_losses, val_losses, train_accuracies, val_accuracies= train_obj.training_loop(train_data = train_data,
                                                            test_data = val_data,
                                                            epochs = epochs,
                                                            optimizer = optimizer, 
                                                            train_mode = train_mode)




        print('validation triplet loss:',val_triplet_losses,'validation total loss:',val_losses,'validation accuracy:',val_accuracies)
        # wandb.log({"train_img_loss": train_img_loss, 
        #           "train_label_loss":train_label_loss, 
        #           "val_img_loss":val_img_loss, 
        #           "val_label_loss":val_label_loss, 
        #           "train_losses":train_losses, 
        #           "val_losses":val_losses, 
        #           "train_accuracy":train_accuracy, 
        #           "val_accuracy":val_accuracy})
        train_mode_dict = {0:'no label', 1:'label', 2:'label and triplet' }
        torch.save(train_obj.state_dict(), os.path.join(save_dir,f'{data_set}_{train_mode_dict[train_mode]}_{model}'))
        



In [15]:
num_classes = 4 # Number of unique class labels in the dataset
latent_dims = 32
epochs = 1000
lr = 0.001
num_models = 1
batch_size = 1024
save_dir = save_dir
main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims)
wandb.finish()

  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [05:18<00:00,  3.14it/s]
 33%|███▎      | 1/3 [05:27<10:54, 327.12s/it]

validation triplet loss: [2399.6357421875, 2394.532958984375, 2359.48388671875, 2290.60498046875, 2222.582763671875, 2192.6298828125, 2179.99169921875, 2177.47265625, 2180.35498046875, 2179.1015625, 2178.019775390625, 2169.6552734375, 2171.4140625, 2169.357177734375, 2166.63720703125, 2164.930419921875, 2166.349609375, 2163.876220703125, 2164.4912109375, 2164.5078125, 2171.88720703125, 2157.707763671875, 2186.857421875, 2181.73388671875, 2150.795654296875, 2172.63525390625, 2171.926025390625, 2160.03662109375, 2166.302734375, 2171.970458984375, 2161.35888671875, 2163.75927734375, 2176.045166015625, 2158.4501953125, 2171.469970703125, 2171.290283203125, 2158.876220703125, 2168.9833984375, 2172.72705078125, 2158.57958984375, 2181.44677734375, 2185.0634765625, 2164.990234375, 2162.0302734375, 2176.5400390625, 2165.8818359375, 2166.25341796875, 2171.278076171875, 2162.655029296875, 2171.00537109375, 2161.588623046875, 2175.87548828125, 2155.171875, 2183.91357421875, 2192.685791015625, 2175

0,1
latent separation accuracy,█▆▅▄▃▃▄▁▃▃▂▆▄▄▃▆▅▄▃▄▆▆▅▆▄▃▄▆▄▃▆▆▄▃▃▃▆▂▂▄
total train loss,█▇█▆▅▅▄▄▄▄▄▄▄▄▄▄▄▄▃▄▃▄▃▄▄▃▃▃▂▂▂▁▂▂▁▁▁▁▁▂
total validation loss,█▅█▅▆▂▄▄▅▆▅▅▄▅▅▅▅▅▅▄▅▄▃▂▆▇▅▅▂▆▆▄▄▇▆▂▁▂▂▃
train label accuracy,█████████▇██▇█▃▄▃▄▄▁▂▂▂▃▁▂▃▂▄███▇█▇▅▅▄▅▇
train label loss,▂▆▅▄▄▅▅▂▄▆▆▅▅▆▅█▄▄▅▄▂▆▂▇█▇▆▇▅▅▁▃▆▅▅▆▆▆▅▅
train triplet loss,█▇█▆▅▅▄▄▄▄▄▄▄▄▄▄▄▄▃▄▃▄▃▄▄▃▃▃▂▂▂▁▂▂▁▁▁▁▁▂
validation label accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation label loss,▁▂▂▁▂▂▂▁▂▂▂▃▂▃▃▃▃▃▃▃▃▃▄▃▄▄▄▄▄▃▅▆▇▆▅▆▇▅▅█
validation triplet loss,█▅█▅▆▂▄▄▅▆▅▅▄▅▅▅▅▅▅▄▅▄▃▂▆▇▅▅▂▆▆▄▄▇▆▂▁▂▂▃

0,1
latent separation accuracy,0.34167
total train loss,6227.98104
total validation loss,2159.46753
train label accuracy,0.23472
train label loss,1.3868
train triplet loss,6226.59436
validation label accuracy,0.25
validation label loss,1.38794
validation triplet loss,2158.07959


100%|██████████| 1000/1000 [04:57<00:00,  3.37it/s]
 67%|██████▋   | 2/3 [10:39<05:18, 318.39s/it]

validation triplet loss: [2399.879150390625, 2399.583984375, 2397.556640625, 2393.011474609375, 2386.93994140625, 2371.25390625, 2367.41943359375, 2358.211181640625, 2363.46875, 2371.2705078125, 2374.315185546875, 2378.244384765625, 2384.109130859375, 2385.0048828125, 2387.2265625, 2388.68505859375, 2388.7041015625, 2389.57275390625, 2388.594970703125, 2387.93505859375, 2386.35791015625, 2385.072998046875, 2383.437744140625, 2382.968017578125, 2382.152099609375, 2382.34423828125, 2382.568359375, 2381.978515625, 2381.42431640625, 2380.91064453125, 2380.1123046875, 2379.7939453125, 2379.096923828125, 2379.275634765625, 2379.138916015625, 2379.33642578125, 2379.09765625, 2378.8408203125, 2378.02099609375, 2378.996826171875, 2378.05078125, 2378.39208984375, 2377.77587890625, 2378.09423828125, 2377.730224609375, 2378.89990234375, 2378.463134765625, 2378.232421875, 2378.0908203125, 2377.720703125, 2377.4482421875, 2377.029296875, 2376.7041015625, 2377.554443359375, 2376.90771484375, 2376.605

0,1
latent separation accuracy,▁▇▇█████████████████████████████████████
total train loss,▂▅▄▃▁▁▁▁▁▁▁▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▂▁▂▁▁▁█▇▆▆▆▆▆▆
total validation loss,▁▇▆▅▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▂▃▂▂▂█▆▆▆▆▆▆▅
train label accuracy,▁▅▅▅████████████████████████████▄▅▅▅▅▅▅▅
train label loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▄▄▄▃▄▄▃
train triplet loss,▂▅▄▃▁▁▁▁▁▁▁▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▂▁▂▁▁▁█▇▆▆▆▆▆▆
validation label accuracy,▁▆▆▆▇█████▇█▇▇▇▇▇▇██▇▇▇▇▇▇█████▇▃▆▅▅▅▅▆▅
validation label loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▄▄▄▄▄▄▄
validation triplet loss,▁▇▆▅▂▂▂▂▂▂▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▂▂▂█▆▆▆▆▆▆▅

0,1
latent separation accuracy,1.0
total train loss,7134.20955
total validation loss,2376.51514
train label accuracy,0.75
train label loss,1.05647
train triplet loss,7133.1532
validation label accuracy,0.725
validation label loss,1.06779
validation triplet loss,2375.44727


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670519618007043, max=1.0…



VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…