In [1]:
import os

# The project root sits two directories above the notebook's working directory;
# model outputs go to <root>/results and inputs are read from <root>/data.
base_dir = os.path.abspath(os.path.join('..', '..'))
save_dir, data_dir = (os.path.join(base_dir, sub) for sub in ('results', 'data'))

In [2]:
import torch
# torch.manual_seed(0)
import wandb
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
from tqdm import tqdm
import random
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
# from neurora.rdm_corr import rdm_correlation_spearman

  Referenced from: '/Users/kushinm/miniforge3/envs/sketch_models/lib/python3.8/site-packages/torchvision/image.so'
  Reason: tried: '/Users/kushinm/miniforge3/envs/sketch_models/lib/python3.8/site-packages/torchvision/../../../libjpeg.8.dylib' (no such file), '/Users/kushinm/miniforge3/envs/sketch_models/lib/python3.8/site-packages/torchvision/../../../libjpeg.8.dylib' (no such file), '/Users/kushinm/miniforge3/envs/sketch_models/lib/python3.8/lib-dynload/../../libjpeg.8.dylib' (no such file), '/Users/kushinm/miniforge3/envs/sketch_models/bin/../lib/libjpeg.8.dylib' (no such file), '/usr/local/lib/libjpeg.8.dylib' (no such file), '/usr/lib/libjpeg.8.dylib' (no such file)
  warn(f"Failed to load image Python extension: {e}")


In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## Define model

In [None]:
class TripletLabelModel(nn.Module):
    """Convolutional encoder with an embedding head and a classification head.

    ``forward`` maps a batch of 3-channel images to ``(out_latent, label)``:
    ``out_latent`` is an ``encoded_space_dim``-dimensional embedding (used for
    the triplet loss) and ``label`` is a softmax distribution over
    ``num_classes`` computed *from that embedding*.

    The first linear layer expects the convolutional stack to produce a
    32x4x4 feature map, which is what a 64x64 input yields
    (64 -> 32 -> 16 -> 8 after the three stride-2 convs, -> 4 after the
    max-pool). Other input resolutions require changing ``32*4*4``.

    Args:
        encoded_space_dim: size of the embedding returned as ``out_latent``.
        num_classes: number of classes predicted by the label head.
    """

    def __init__(self, encoded_space_dim, num_classes):
        super().__init__()
        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)
        ### Linear section (embedding head)
        self.encoder_lin = nn.Sequential(
            nn.Linear(32*4*4, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, encoded_space_dim)
        )
        ### Labelling head: embedding -> class logits.
        # No activation after the last Linear: a ReLU here would clamp every
        # negative logit to 0 before the softmax and stall learning. Parameter
        # indices (0 and 2) are unchanged, so old checkpoints still load.
        self.decoder_labels_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 10),
            nn.ReLU(True),
            nn.Linear(10, num_classes),
        )

    def forward(self, x, y=None):
        """Embed ``x`` and predict its class distribution.

        Args:
            x: image batch, assumed (batch, 3, 64, 64) — see class docstring.
            y: unused; accepted so existing call sites passing labels still work.

        Returns:
            ``(out_latent, label)`` with shapes (batch, encoded_space_dim) and
            (batch, num_classes); rows of ``label`` sum to 1.
        """
        img_features = self.flatten(self.encoder_cnn(x))
        out_latent = self.encoder_lin(img_features)

        # Classify from the learned embedding, not from the raw pixels.
        label_logits = self.decoder_labels_lin(out_latent)
        label = F.softmax(label_logits, dim=1)
        return out_latent, label

In [None]:
### Custom loss: triplet margin loss on the embeddings plus a labelling (classification) loss


class CustomLoss(nn.Module):
    """Triplet margin loss on embeddings plus a cross-entropy labelling loss.

    Args:
        margin: how much farther the negative must be from the anchor than the
            positive before the triplet term for that sample drops to zero.
    """

    def __init__(self, margin=0.5):
        super(CustomLoss, self).__init__()
        self.margin = margin
        # Kept for backward compatibility only. ``forward`` computes the label
        # term from probabilities itself, because nn.CrossEntropyLoss expects
        # unnormalised logits and would apply a second softmax.
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, anchor, positive, negative, label, pred_label):
        """Compute ``(triplet_loss, label_loss)``, each a scalar batch mean.

        Args:
            anchor, positive, negative: embedding batches of shape (batch, dim).
            label: either one-hot / probability targets with the same shape as
                ``pred_label``, or a 1-D tensor of integer class indices.
            pred_label: predicted class probabilities (rows summing to 1), such
                as the softmax output of ``TripletLabelModel``.

        Returns:
            Tuple of two scalar tensors: the mean triplet loss
            ``max(d(a, p) - d(a, n) + margin, 0)`` with Euclidean ``d``, and the
            mean cross-entropy between ``label`` and ``pred_label``.
        """
        # Per-sample Euclidean distances, shape (batch,).
        distance_positive = F.pairwise_distance(anchor, positive, p=2)
        distance_negative = F.pairwise_distance(anchor, negative, p=2)
        triplet_loss = F.relu(distance_positive - distance_negative + self.margin).mean()

        # pred_label is already normalised: take its log directly. Clamp so a
        # probability that underflowed to 0 gives a finite loss, not -inf.
        log_probs = torch.log(pred_label.clamp(min=1e-12))
        if label.dim() == pred_label.dim():
            label_loss = -(label.float() * log_probs).sum(dim=1).mean()
        else:
            label_loss = F.nll_loss(log_probs, label.long())
        return triplet_loss, label_loss

### Training functions

In [None]:


class TrainModels(nn.Module):
    """Train and evaluate a TripletLabelModel with a triplet + labelling objective.

    Every batch yielded by ``train_data`` / ``test_data`` is unpacked as
    ``(anchor, positive, negative, anchor_label)``. ``anchor_label`` is assumed
    to hold integer class indices: it is one-hot encoded with ``num_classes``
    before reaching the loss. TODO confirm against the triplet dataset.

    ``train_mode`` chooses which loss is back-propagated:
        0 -> triplet loss only
        1 -> triplet loss + label loss
        2 -> label loss only
    All three losses and the label accuracy are reported in every mode.
    The per-epoch "img" loss entries (kept under their historical names) are
    the triplet loss.
    """

    def __init__(self, latent_dims, num_classes):
        super(TrainModels, self).__init__()
        self.triplet_lab_model = TripletLabelModel(latent_dims, num_classes)
        self.custom_loss = CustomLoss()
        self.num_classes = num_classes

    def forward(self, anchor_im, positive_im, negative_im):
        """Embed the three image batches with the shared model.

        Returns:
            ``(anchor_latent, positive_latent, negative_latent, anchor_label)``,
            where ``anchor_label`` is the label head's output for the anchors only.
        """
        anchor_latent, anchor_label = self.triplet_lab_model(anchor_im)
        positive_latent, _ = self.triplet_lab_model(positive_im)
        negative_latent, _ = self.triplet_lab_model(negative_im)

        return anchor_latent, positive_latent, negative_latent, anchor_label

    @staticmethod
    def _mean(values):
        """Arithmetic mean of a list of floats; NaN when the list is empty."""
        return sum(values) / len(values) if values else float('nan')

    def _batch_losses(self, batch):
        """Run one ``(anchor, positive, negative, anchor_label)`` batch and score it.

        Returns:
            ``(triplet_loss, label_loss, total_loss, n_correct, n_items)``. The
            losses are scalar tensors, still attached to the autograd graph
            when called with gradients enabled.
        """
        # Move inputs to wherever the model's parameters live, so data and
        # weights always share a device.
        model_device = next(self.parameters()).device
        anchor, positive, negative, anchor_label = [t.to(model_device) for t in batch]
        anchor_label = anchor_label.long()
        one_hot_label = F.one_hot(anchor_label, num_classes=self.num_classes).float()

        anchor_latent, positive_latent, negative_latent, pred_label = self(anchor, positive, negative)
        triplet_loss, label_loss = self.custom_loss(anchor_latent, positive_latent,
                                                    negative_latent, one_hot_label, pred_label)
        total_loss = triplet_loss + label_loss

        n_correct = (torch.argmax(pred_label, dim=1) == anchor_label).sum().item()
        return triplet_loss, label_loss, total_loss, n_correct, anchor_label.size(0)

    def test_epoch(self, test_data):
        """Evaluate on ``test_data`` without tracking gradients.

        Returns:
            ``(mean triplet loss, mean label loss, mean total loss, accuracy)``.
            Accuracy is accumulated over all batches; NaN values are returned
            for an empty loader.
        """
        self.eval()
        triplet_losses, label_losses, total_losses = [], [], []
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in test_data:
                triplet_loss, label_loss, total_loss, n_correct, n_items = self._batch_losses(batch)
                triplet_losses.append(triplet_loss.item())
                label_losses.append(label_loss.item())
                total_losses.append(total_loss.item())
                correct += n_correct
                total += n_items
        test_accuracy = correct / total if total else float('nan')
        return (self._mean(triplet_losses), self._mean(label_losses),
                self._mean(total_losses), test_accuracy)

    def train_epoch(self, train_data, optimizer, train_mode):
        """Run one optimisation pass over ``train_data``.

        Seed the RNGs once in the configuration cell. Re-seeding here on every
        epoch would make a shuffling DataLoader replay the same order each epoch.

        Args:
            train_data: iterable of ``(anchor, positive, negative, anchor_label)``.
            optimizer: torch optimizer over this module's parameters.
            train_mode: 0, 1 or 2; see the class docstring.

        Returns:
            ``(mean triplet loss, mean label loss, mean total loss, accuracy)``.

        Raises:
            ValueError: if ``train_mode`` is not 0, 1 or 2.
        """
        if train_mode not in (0, 1, 2):
            raise ValueError(f"train_mode must be 0, 1 or 2, got {train_mode!r}")
        self.train()
        triplet_losses, label_losses, total_losses = [], [], []
        correct = 0
        total = 0
        for batch in train_data:
            optimizer.zero_grad()
            triplet_loss, label_loss, total_loss, n_correct, n_items = self._batch_losses(batch)

            if train_mode == 0:
                objective = triplet_loss
            elif train_mode == 1:
                objective = total_loss
            else:
                objective = label_loss
            objective.backward()
            optimizer.step()

            triplet_losses.append(triplet_loss.item())
            label_losses.append(label_loss.item())
            total_losses.append(total_loss.item())
            correct += n_correct
            total += n_items
        train_accuracy = correct / total if total else float('nan')
        return (self._mean(triplet_losses), self._mean(label_losses),
                self._mean(total_losses), train_accuracy)

    def training_loop(self, train_data, test_data, train_mode,
                      epochs, optimizer):
        """Train for ``epochs`` epochs, evaluating after each one.

        Per-epoch metrics are logged to wandb when a run is active
        (``wandb.init`` has been called); otherwise logging is skipped.

        Returns:
            ``(train_img_losses, train_label_losses, val_img_losses,
            val_label_losses, train_losses, val_losses, train_accuracies,
            val_accuracies)``: one list entry per epoch. The "img" losses are
            the triplet loss.
        """
        train_losses = []
        val_losses = []
        train_img_losses = []
        val_img_losses = []
        train_label_losses = []
        val_label_losses = []
        train_accuracies = []
        val_accuracies = []
        for epoch in tqdm(range(epochs)):
            train_img_loss, train_label_loss, train_loss, train_accuracy = self.train_epoch(
                train_data, optimizer, train_mode)
            val_img_loss, val_label_loss, val_loss, val_accuracy = self.test_epoch(test_data)
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            train_img_losses.append(train_img_loss)
            val_img_losses.append(val_img_loss)
            train_label_losses.append(train_label_loss)
            val_label_losses.append(val_label_loss)
            train_accuracies.append(train_accuracy)
            val_accuracies.append(val_accuracy)
            # wandb.log raises if no run was initialised; only log when one is active.
            if wandb.run is not None:
                wandb.log({"train_img_loss": train_img_loss,
                           "train_label_loss": train_label_loss,
                           "val_img_loss": val_img_loss,
                           "val_label_loss": val_label_loss,
                           "train_losses": train_loss,
                           "val_losses": val_loss,
                           "train_accuracy": train_accuracy,
                           "val_accuracy": val_accuracy})

        return train_img_losses, train_label_losses, val_img_losses, val_label_losses, train_losses, val_losses, train_accuracies, val_accuracies
