Below is an implementation of an autoencoder written in PyTorch. We apply it to the MNIST dataset.

In [1]:
import os
base_dir = os.path.abspath('../..')
save_dir = os.path.join(base_dir,'results')
data_dir = os.path.join(base_dir,'data')

In [2]:
# !pip install wandb -q
# !pip install neurora -q

#wandb api key 18a861e71f78135d23eb672c08922edbfcb8d364

In [3]:
import torch
# torch.manual_seed(0)
import wandb
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
from tqdm import tqdm
import random
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
# from neurora.rdm_corr import rdm_correlation_spearman

  Referenced from: '/Users/kushinm/miniforge3/envs/sketch_models/lib/python3.8/site-packages/torchvision/image.so'
  Reason: tried: '/Users/kushinm/miniforge3/envs/sketch_models/lib/python3.8/site-packages/torchvision/../../../libjpeg.8.dylib' (no such file), '/Users/kushinm/miniforge3/envs/sketch_models/lib/python3.8/site-packages/torchvision/../../../libjpeg.8.dylib' (no such file), '/Users/kushinm/miniforge3/envs/sketch_models/lib/python3.8/lib-dynload/../../libjpeg.8.dylib' (no such file), '/Users/kushinm/miniforge3/envs/sketch_models/bin/../lib/libjpeg.8.dylib' (no such file), '/usr/local/lib/libjpeg.8.dylib' (no such file), '/usr/lib/libjpeg.8.dylib' (no such file)
  warn(f"Failed to load image Python extension: {e}")


In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Training Models

### Run Next 2 cells to instantiate model class

In [5]:
class Encoder(nn.Module):
    def __init__(self, encoded_space_dim, num_classes):
        super().__init__()
        ""
        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(3, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True)
        )
        
    
        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)
        self.encoder_labels_lin = nn.Linear(num_classes, num_classes//2)
        ### Linear section
        self.encoder_lin = nn.Sequential(
            nn.Linear(7 * 7 * 32 + num_classes//2, 128),
            nn.ReLU(True),
            nn.Linear(128, encoded_space_dim)
        )
        
        
    def forward(self, x, y=None):
        batch_s = x.size(0)
        img_features = self.encoder_cnn(x)
        img_features = self.flatten(img_features)
        if y== None:
            combined = torch.cat((img_features, torch.zeros(batch_s,num_classes//2)), dim = -1)
        else:
            label_features = self.encoder_labels_lin(y)
            combined = torch.cat((img_features, label_features), dim = -1)
 
        out = self.encoder_lin(combined)
        return out

class Decoder(nn.Module):
    
    def __init__(self, encoded_space_dim, num_classes):
        super().__init__()
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 7 * 7 * 32+ num_classes//2),
            nn.ReLU(True)
        )
        self.decoder_labels_lin = (nn.Linear(num_classes//2, num_classes))
        self.flatten = nn.Flatten(start_dim=1)
        self.unflatten = nn.Unflatten(dim=1, 
        unflattened_size=(32, 7, 7))

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, 
            stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, 
            padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 3, 3, stride=2, 
            padding=1, output_padding=1,dilation=3)
        )
        
    def forward(self, x):
        x = self.decoder_lin(x)
        img_features = x[:, :-(num_classes//2)]
        label_features = x[:, -(num_classes//2):]
        img_features = self.unflatten(img_features)
        
        img_features = self.decoder_conv(img_features)
        
        img = torch.sigmoid(img_features)
        label = self.decoder_labels_lin(label_features)
       
        label = F.softmax(label,dim=1)
        
        # x = self.decoder_lin(x)
    
        # img_features = self.unflatten(x)
        # img_features = self.decoder_conv(img_features)
     
        # img = torch.sigmoid(img_features)
        # label = self.decoder_labels_lin(self.flatten(img_features))
        return img, label

In [6]:
class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # self.gamma = nn.Parameter(torch.tensor([.5]))
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, img, pred_img, label, pred_label):
        mse = nn.MSELoss()
        #mse_loss_img = ((img - pred_img)**2).sum()
        mse_loss_img = mse(pred_img, img)
        mse_loss_label = self.cross_entropy(pred_label, label.float())
        # loss = mse_loss_img * torch.sigmoid(self.gamma) + \
              # mse_loss_label * (1 - torch.sigmoid(self.gamma))
        loss = mse_loss_img + mse_loss_label#*torch.sigmoid(self.gamma)
        return mse_loss_img, mse_loss_label, loss

In [7]:


class Autoencoder(nn.Module):
    def __init__(self, latent_dims, num_classes):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(latent_dims, num_classes)
        self.decoder = Decoder(latent_dims, num_classes)
        self.custom_loss = CustomLoss()
        self.num_classes = num_classes
    
    def forward(self, x, y=None):
        z = self.encoder(x, y)  ### latent vector
        return self.decoder(z) ### image and label

    def test_epoch(self, test_data):
    # Set evaluation mode for encoder and decoder
        self.eval()
        with torch.no_grad(): # No need to track the gradients
            # Define the lists to store the outputs for each batch
            test_img_loss = []
            test_label_loss = []
            total_test_loss = []
            for image_batch, label_batch in test_data:
                total = 0
                correct = 0
                # Move tensor to the proper device
                image_batch = image_batch.to(device)
                label_batch = F.one_hot(label_batch, num_classes=self.num_classes)
                label_batch = label_batch.to(device)
                pred_img, pred_label = self.forward(image_batch, label_batch.float()) 
                # Append the network output and the original image to the lists
                img_loss, label_loss, total_loss = self.custom_loss(image_batch,
                                                                pred_img, 
                                                                label_batch, 
                                                                pred_label)
                total += label_batch.size(0)
                correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(label_batch, dim = 1)).sum().item()
                test_img_loss.append(img_loss.item())
                test_label_loss.append(label_loss.item())
                total_test_loss.append(total_loss.item())
        test_img_loss = sum(test_img_loss)/len(test_img_loss)
        test_label_loss = sum(test_label_loss)/len(test_label_loss)
        total_test_loss = sum(total_test_loss)/len(total_test_loss)
        test_accuracy = correct/total
        return test_img_loss, test_label_loss, total_test_loss, test_accuracy

    def train_epoch(self, train_data, optimizer, train_mode):
        self.train()
        torch.manual_seed(0)
        train_img_loss = []
        train_label_loss = []
        train_loss = []
        correct = 0
        total = 0
        for image_batch, label_batch in train_data:
            # Move tensor to the proper device
            image_batch = image_batch.to(device)
            # num_training_examples = label_batch.shape[0]
            # num_non_label_training_examples = num_training_examples*(1-training_label_ratio)
            # non_label_training_idx = random.sample(range(num_training_examples),int(num_non_label_training_examples))
            # label_batch[[non_label_training_idx]] = self.num_classes - 1
           
            label_batch = F.one_hot(label_batch, num_classes=self.num_classes)
            label_batch = label_batch.to(device)
            optimizer.zero_grad()
            if train_mode==0:
                pred_img, pred_label = self.forward(image_batch) 
            elif train_mode==1:
                pred_img, pred_label = self.forward(image_batch, label_batch.float()) 
            elif train_mode==2:
                pred_img, pred_label = self.forward(image_batch) 
            # Append the network output and the original image to the lists

           
            img_loss, label_loss, total_loss = self.custom_loss(image_batch,
                                                            pred_img, 
                                                            label_batch, 
                                                            pred_label)
            
            
            if train_mode==0:
                img_loss.backward()
            elif train_mode==1:
                total_loss.backward()
            elif train_mode==2:
                label_loss.backward()

            optimizer.step()
            train_img_loss.append(img_loss.item())
            train_label_loss.append(label_loss.item())
            train_loss.append(total_loss.item())
            total += label_batch.size(0)
            correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(label_batch, dim = 1)).sum().item()
        train_img_loss = sum(train_img_loss)/len(train_img_loss)
        train_label_loss = sum(train_label_loss)/len(train_label_loss)
        train_loss = sum(train_loss)/len(train_loss)
        train_accuracy = correct/total
        return train_img_loss, train_label_loss, train_loss, train_accuracy

    def training_loop(self, train_data, test_data,train_mode,
                      epochs, optimizer):
        train_losses = []
        val_losses = []
        train_img_losses = []
        val_img_losses = []
        train_label_losses = []
        val_label_losses = []
        train_accuracies = []
        val_accuracies = []
        for epoch in tqdm(range(epochs)):
          train_img_loss, train_label_loss, train_loss, train_accuracy =self.train_epoch(train_data, optimizer, 
                                             train_mode)
          val_img_loss, val_label_loss, val_loss, val_accuracy = self.test_epoch(test_data)
          train_losses.append(train_loss)
          val_losses.append(val_loss)
          train_img_losses.append(train_img_loss)
          val_img_losses.append(val_img_loss)
          train_label_losses.append(train_label_loss)
          val_label_losses.append(val_label_loss)
          train_accuracies.append(train_accuracy)
          val_accuracies.append(val_accuracy)
          wandb.log({"train_img_loss": train_img_loss, 
            "train_label_loss":train_label_loss, 
            "val_img_loss":val_img_loss, 
            "val_label_loss":val_label_loss, 
            "train_losses":train_loss, 
            "val_losses":val_loss, 
            "train_accuracy":train_accuracy, 
            "val_accuracy":val_accuracy})

        return train_img_losses, train_label_losses, val_img_losses, val_label_losses ,train_losses, val_losses, train_accuracies, val_accuracies

# def plot_latent_with_label(self, batch_size, data, random_labels,
#                            num_classes, num_batches=100):
#     if not random_labels:
#         for i, (x, y) in enumerate(data):
#             x = x.to(device) # GPU
#             y_one_hot = F.one_hot(y.to(torch.int64), num_classes)
#             y_one_hot = y_one_hot.to(device).float()
#             z = self.encoder(x, y_one_hot)
#             z = z.to('cpu').detach().numpy()
#             plt.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10')
#             if i > num_batches:
#                 plt.colorbar()
#                 break
#         wandb.log({"Latent : Real labels in input": plt})
#     else:
#         for i, (x, y) in enumerate(data):
#             x = x.to(device) # GPU
#             y_rand = torch.zeros((x.size(0), num_classes))
#             y_rand = y_rand.to(device)
#             z = autoencoder.encoder(x, y_rand.float())
#             z = z.to('cpu').detach().numpy()
#             plt.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10')
#             if i > num_batches:
#                 plt.colorbar()
#                 break
#         wandb.log({"Latent : Random labels in input": plt})
#     return min(z[:, 0]), max(z[:, 0]), min(z[:, 1]), max(z[:, 1])


# def plot_reconstructed_with_labels(autoencoder,random_labels, r0,
#                                    r1, n=24):
#     w = 64
#     img = np.zeros((n*w, n*w))
#     for i, y in enumerate(np.linspace(*r1, n)):
#         for j, x in enumerate(np.linspace(*r0, n)):
#             z = torch.Tensor([[x, y]]).to(device)
#             x_hat, label = autoencoder.decoder(z)
#             x_hat = x_hat.reshape(28, 28).to('cpu').detach().numpy()
#             img[(n-1-i)*w:(n-1-i+1)*w, j*w:(j+1)*w] = x_hat
#     plt.imshow(img, extent=[*r0, *r1])
#     if not random_labels:
#       wandb.log({"Reconstruction : Real labels in input": plt})
#     else:
#       wandb.log({"Reconstruction : Random labels in input": plt})
    
    

## Set up training

In [8]:

###initialize weights and bias tracking

def wandb_init(epochs, lr, train_mode, batch_size, model_number,data_set):
  wandb.init(project="ConceptualAlignmentLanguage", entity="psych-711",settings=wandb.Settings(start_method="thread"))
  wandb.config = {
    "learning_rate": lr,
    "epochs": epochs,
    "batch_size": batch_size, 
    # "label_ratio":label_ratio, 
    "model_number": model_number,
    "dataset": data_set,
    "train_mode":train_mode,
  }
  wandb.run.name = f'{data_set}_{train_mode}_{model_number}'
  wandb.run.save()
     

### Load in 3dshapes dataset and sample datasets

In [9]:
import h5py
from tqdm import tqdm, trange


# load dataset
dataset = h5py.File(os.path.join(data_dir,'3dshapes.h5'), 'r')
print(dataset.keys())
images = dataset['images']  # array shape [480000,64,64,3], uint8 in range(256)
labels = dataset['labels']  # array shape [480000,6], float64
image_shape = images.shape[1:]  # [64,64,3]
label_shape = labels.shape[1:]  # [6]
n_samples = labels.shape[0]  # 10*10*10*8*4*15=480000

_FACTORS_IN_ORDER = ['floor_hue', 'wall_hue', 'object_hue', 'scale', 'shape',
                     'orientation']
_NUM_VALUES_PER_FACTOR = {'floor_hue': 10, 'wall_hue': 10, 'object_hue': 10, 
                          'scale': 8, 'shape': 4, 'orientation': 15}

### return indices where labels[:,0] == 0 and labels[:,1] == 0

def get_indices(labels, floor_hue=0, wall_hue=0):
    indices = []
    for i in range(labels.shape[0]):
        if labels[i,0] == floor_hue and labels[i,1] == wall_hue:
            indices.append(i)
    return indices

red_bg_indices = get_indices(labels, floor_hue=0, wall_hue=0)
images_red_bg = images[red_bg_indices]
labels_red_bg = labels[red_bg_indices]

<KeysViewHDF5 ['images', 'labels']>


In [10]:
# # np.random.seed(seed=711)
# if not os.path.exists(data_dir):
#     os.makedirs(data_dir)


# set_A_dir = os.path.join(data_dir,'set_A')
# if not os.path.exists(set_A_dir):
#     os.makedirs(set_A_dir)
# set_B_dir = os.path.join(data_dir,'set_B')
# if not os.path.exists(set_B_dir):
#     os.makedirs(set_B_dir)
# set_C_dir = os.path.join(data_dir,'set_C')
# if not os.path.exists(set_C_dir):
#     os.makedirs(set_C_dir)
# validation_dir = os.path.join(data_dir,'validation_set')
# if not os.path.exists(validation_dir):
#     os.makedirs(validation_dir)

# set_A_labs=[]
# set_B_labs=[]
# set_C_labs=[]
# validation_labs = []

# for i in trange(4):
#     shape_inds = np.argwhere(labels_red_bg[:,4] == i)
#     np.random.seed(seed=i)
#     sub_inds = np.random.choice(shape_inds.flatten(),1200,replace=False)
#     for ind,j in tqdm(enumerate(sub_inds[0:600])):
#         plt.imsave(os.path.join(set_A_dir,f'shape_{i}_{j}.png'),images_red_bg[j,:,:,:],format='png')
#         set_A_labs.append(i)
#         if ind<300:
#             plt.imsave(os.path.join(set_B_dir,f'shape_{i}_{j}.png'),images_red_bg[j,:,:,:],format='png')
#             set_B_labs.append(i)
#         if ind<120:
           
#             plt.imsave(os.path.join(set_C_dir,f'shape_{i}_{j}.png'),images_red_bg[j,:,:,:],format='png')
#             set_C_labs.append(i)
#     for ind,j in tqdm(enumerate(sub_inds[600:])):
#         if ind<480:
#             plt.imsave(os.path.join(set_C_dir,f'shape_{i}_{j}.png'),images_red_bg[j,:,:,:],format='png')
#             set_B_labs.append(i)
#         if ind<300:
#             plt.imsave(os.path.join(set_B_dir,f'shape_{i}_{j}.png'),images_red_bg[j,:,:,:],format='png')
#             set_B_labs.append(i)
#         if ind>=480:
#             if not os.path.exists(validation_dir):
#                 os.makedirs(validation_dir)
#             plt.imsave(os.path.join(validation_dir,f'shape_{i}_{j}.png'),images_red_bg[j,:,:,:],format='png')
#             validation_labs.append(i)

        
            


In [11]:
really_run=True
if really_run== True:
    set_A_ims = np.empty((2400,64,64,3),dtype='uint8')
    set_B_ims = np.empty((2400,64,64,3),dtype='uint8')
    set_C_ims = np.empty((2400,64,64,3),dtype='uint8')
    validation_ims = np.empty((480,64,64,3),dtype='uint8')

    set_A_labs=np.empty((2400),dtype='uint8')
    set_B_labs=np.empty((2400),dtype='uint8')
    set_C_labs=np.empty((2400),dtype='uint8')
    validation_labs = np.empty((480),dtype='uint8')


    for i in trange(4):
        shape_inds = np.argwhere(labels_red_bg[:,4] == i)
        np.random.seed(seed=i)
        sub_inds = np.random.choice(shape_inds.flatten(),1200,replace=False)
        set_A_ims[(i*600)+0:(i*600)+600,:,:,:] = images_red_bg[sorted(sub_inds[0:600]),:,:,:]
        set_A_labs[(i*600)+0:(i*600)+600] = i
        print('base_done')
        set_B_ims[(i*600)+0:(i*600)+300,:,:,:] = images_red_bg[sorted(sub_inds[0:300]),:,:,:]
        set_B_ims[(i*600)+300:(i*600)+600,:,:,:] = images_red_bg[sorted(sub_inds[600:900]),:,:,:]
        set_B_labs[(i*600)+0:(i*600)+600] = i


        set_C_ims[(i*600)+0:(i*600)+120,:,:,:] = images_red_bg[sorted(sub_inds[0:120]),:,:,:]
        set_C_ims[(i*600)+120:(i*600)+600,:,:,:] = images_red_bg[sorted(sub_inds[600:1080]),:,:,:]
        set_C_labs[(i*600)+0:(i*600)+600] = i

        validation_ims[(i*120)+0:(i*120)+120,:,:,:] = images_red_bg[sorted(sub_inds[1080:1200]),:,:,:]
        validation_labs[(i*120)+0:(i*120)+120] = i



100%|██████████| 4/4 [00:00<00:00, 56.35it/s]

base_done
base_done
base_done
base_done





In [12]:
really_run=True

if really_run== True:
    np.save(os.path.join(data_dir,'set_A.npy'),set_A_ims)
    np.save(os.path.join(data_dir,'set_B.npy'),set_B_ims)
    np.save(os.path.join(data_dir,'set_C.npy'),set_C_ims)
    np.save(os.path.join(data_dir,'validation_set.npy'),validation_ims)

    np.save(os.path.join(data_dir,'set_A_labs.npy'),set_A_labs)
    np.save(os.path.join(data_dir,'set_B_labs.npy'),set_B_labs)
    np.save(os.path.join(data_dir,'set_C_labs.npy'),set_C_labs)
    np.save(os.path.join(data_dir,'validation_labs.npy'),validation_labs)

### Description of datasets

We have 4 datasets. The first is a 'base' set with 20000 images (5000 from each shape category). The second and third datasets have 50 and 20% overlap with the base set.
These 3 sets will be used for training

In [13]:
from torch.utils.data import TensorDataset,Dataset
def main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims):
  if os.path.isdir(save_dir):
    pass
  else:
    os.mkdir(save_dir)

  for data_set in ['set_A','set_B','set_C']:
    for i in tqdm(range(3)):
     # torch.manual_seed(0)
      train_mode=i
      for model in range(num_models):
        wandb_init(epochs, lr, train_mode, batch_size, model,data_set)

        if data_set=='set_A':
          train_data = TensorDataset(torch.tensor(set_A_ims.transpose(0,3,1,2)/255).float(), torch.tensor(set_A_labs).to(torch.int64))
        elif data_set=='set_B':
          train_data = TensorDataset(torch.tensor(set_B_ims.transpose(0,3,1,2)/255).float(), torch.tensor(set_B_labs).to(torch.int64))
        elif data_set=='set_C':
          train_data = TensorDataset(torch.tensor(set_C_ims.transpose(0,3,1,2)/255).float(), torch.tensor(set_C_labs).to(torch.int64))


        train_data, val_data = torch.utils.data.random_split(train_data, 
                                                            [2160, 240])
        train_data = torch.utils.data.DataLoader(train_data, 
                                                batch_size=batch_size,
                                              shuffle=True)
        val_data = torch.utils.data.DataLoader(val_data, 
                                                batch_size=batch_size,
                                              shuffle=True)
        test_data = TensorDataset(torch.tensor(validation_ims.transpose(0,3,1,2)/255).float(), torch.tensor(validation_labs).to(torch.int64))

        autoencoder = Autoencoder(latent_dims, num_classes).to(device) # GPU
        optimizer = torch.optim.Adam(autoencoder.parameters(), lr=lr, weight_decay=1e-05)
        train_img_loss, train_label_loss, val_img_loss, \
        val_label_loss ,train_losses, val_losses,  train_accuracy, \
        val_accuracy= autoencoder.training_loop(train_data = train_data,
                                                            test_data = val_data,
                                                            epochs = epochs,
                                                            optimizer = optimizer, 
                                                            train_mode = train_mode)




          #### To fix:

        # min_x, max_x, min_y, max_y = plot_latent_with_label(autoencoder, 
        #                                                     batch_size, 
        #                                                     data=val_data, 
        #                                                     random_labels = False,
        #                                                     num_classes = num_classes,
        #                                                     num_batches=100)
        # plt.clf()
        # plot_reconstructed_with_labels(autoencoder = autoencoder, 
        #                                r0=(min_x, max_x),
        #                               r1=(min_y, max_y), 
        #                                n=24, random_labels = False)
        # plt.clf()
        # min_x, max_x, min_y, max_y = plot_latent_with_label(autoencoder, 
        #                                                     batch_size, 
        #                                                     data=val_data, 
        #                                                     random_labels = True,
        #                                                     num_classes = num_classes,
        #                                                     num_batches=100)
        # plt.clf()
        # plot_reconstructed_with_labels(autoencoder = autoencoder, 
        #                                r0=(min_x, max_x),
        #                               r1=(min_y, max_y), 
        #                                n=24, random_labels = True)
        # plt.clf()
        print('val_img_loss:',val_img_loss,'val_total_loss:',val_losses,'accuracy:',val_accuracy)
        wandb.log({"train_img_loss": train_img_loss, 
                  "train_label_loss":train_label_loss, 
                  "val_img_loss":val_img_loss, 
                  "val_label_loss":val_label_loss, 
                  "train_losses":train_losses, 
                  "val_losses":val_losses, 
                  "train_accuracy":train_accuracy, 
                  "val_accuracy":val_accuracy})
        torch.save(autoencoder.state_dict(), os.path.join(save_dir,f'{data_set}_{train_mode}_{model}'))
        
      


## Train networks and upload results to wandb

In [14]:
# ae = Autoencoder(latent_dims= 6,num_classes = 4).to(device)
# ae.load_state_dict(torch.load(os.path.join(save_dir,'_0.6666666666666666_1'),map_location=torch.device('cpu')))

# plt.imshow(ae(x= train_data[5][0].unsqueeze(0),y=F.one_hot(train_data[5][1],4).float().unsqueeze(0))[0].detach().numpy().squeeze(0).transpose(1,2,0))

In [16]:
!wandb login --relogin
660abd59704e9a08c9603617c5a5b2a1d8ce8a52zc

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
Aborted!


In [15]:
num_classes = 4 # Number of unique class labels in the dataset
latent_dims = 10
epochs = 25
lr = 0.001
num_models = 2
batch_size = 128
save_dir = save_dir
main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims)

  0%|          | 0/3 [00:00<?, ?it/s][34m[1mwandb[0m: Currently logged in as: [33mkushinm[0m ([33mpsych-711[0m). Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 25/25 [02:46<00:00,  6.67s/it]

val_img_loss: [0.18525616079568863, 0.17222001403570175, 0.1399022415280342, 0.11122644692659378, 0.09276322647929192, 0.08009800314903259, 0.07079283520579338, 0.06385405361652374, 0.05862302891910076, 0.054563773795962334, 0.05046611651778221, 0.04770859517157078, 0.045469194650650024, 0.04319789819419384, 0.041500020772218704, 0.04002326354384422, 0.03873328119516373, 0.03757305257022381, 0.03634543716907501, 0.03543591871857643, 0.034669751301407814, 0.03378279507160187, 0.03299362398684025, 0.032324470579624176, 0.031785497441887856] val_total_loss: [1.5758552551269531, 1.5629064440727234, 1.530635118484497, 1.5015043020248413, 1.4832125306129456, 1.470754623413086, 1.4615532159805298, 1.4546535015106201, 1.4494353532791138, 1.4453771710395813, 1.4412797093391418, 1.4385225772857666, 1.4362837672233582, 1.434012770652771, 1.4323152899742126, 1.4308386445045471, 1.429548740386963, 1.42838853597641, 1.427160918712616, 1.4262514114379883, 1.4254852533340454, 1.4245982766151428, 1.423




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,███▇▃▁▇██████████████████
train_img_loss,█▇▆▅▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
train_label_loss,█▇▆▃▁▃▄▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
train_losses,█▇▆▅▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▄▄▄▅▁▂███████████████████
val_img_loss,█▇▆▅▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val_label_loss,▅▆▇▁▃▆▇██████████████████
val_losses,█▇▆▅▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01675325833333356, max=1.0)…

100%|██████████| 25/25 [02:41<00:00,  6.45s/it]
 33%|███▎      | 1/3 [05:40<11:21, 340.76s/it]

val_img_loss: [0.18545182049274445, 0.17416241019964218, 0.14259589463472366, 0.11010250449180603, 0.09491552785038948, 0.085183035582304, 0.07631847262382507, 0.06932688876986504, 0.06270585209131241, 0.05841946415603161, 0.05493750423192978, 0.05167369917035103, 0.04875880107283592, 0.04683132655918598, 0.04494379833340645, 0.043377285823225975, 0.04208546131849289, 0.041094908490777016, 0.04002176411449909, 0.03882495500147343, 0.03775973990559578, 0.0369113702327013, 0.0362201202660799, 0.03556317836046219, 0.034665198996663094] val_total_loss: [1.5734606385231018, 1.5621341466903687, 1.530652940273285, 1.4981804490089417, 1.4830055832862854, 1.4732723236083984, 1.4644078016281128, 1.4574162364006042, 1.4507951736450195, 1.446508765220642, 1.443026840686798, 1.4397630095481873, 1.4368480443954468, 1.434920608997345, 1.433033049106598, 1.4314665794372559, 1.4301746487617493, 1.4291841387748718, 1.4281110167503357, 1.4269142746925354, 1.4258490204811096, 1.4250006675720215, 1.4243093

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,▁████████████████████████
train_img_loss,█▇▆▅▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
train_label_loss,▆█▁▄▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
train_losses,█▇▆▅▄▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_img_loss,█▇▆▅▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
val_label_loss,▃▁▆▇█████████████████████
val_losses,█▇▆▅▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016696960416666684, max=1.0…

100%|██████████| 25/25 [02:39<00:00,  6.37s/it]

val_img_loss: [0.18509335815906525, 0.16849316656589508, 0.1410764828324318, 0.1147007867693901, 0.10078802704811096, 0.08939846605062485, 0.08045357093214989, 0.07503926381468773, 0.06990045309066772, 0.0647803321480751, 0.06040352024137974, 0.056709857657551765, 0.0543824452906847, 0.05216251499950886, 0.04949158616364002, 0.04771292395889759, 0.0465100072324276, 0.04552767425775528, 0.0441682580858469, 0.04310975782573223, 0.04210711270570755, 0.04089801385998726, 0.03980468399822712, 0.039076851680874825, 0.037996141240000725] val_total_loss: [1.5595887303352356, 1.4695954322814941, 1.383208990097046, 1.288562536239624, 1.175717830657959, 1.1316810846328735, 1.1127373576164246, 1.0942260026931763, 1.0856550931930542, 1.0776416659355164, 1.0713701844215393, 1.0661036372184753, 1.061899185180664, 1.0581846237182617, 1.0538963079452515, 1.0509629845619202, 1.0480409264564514, 1.0460904240608215, 1.0430660247802734, 1.040698766708374, 1.0384318828582764, 1.0359929203987122, 1.033708095




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,▁▃▄▆▇████████████████████
train_img_loss,█▇▆▅▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
train_label_loss,█▆▆▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_losses,█▇▆▅▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▂▆▁▇████████████████████
val_img_loss,█▇▆▅▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
val_label_loss,█▇▆▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_losses,█▇▆▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016752041666666172, max=1.0…

100%|██████████| 25/25 [02:34<00:00,  6.17s/it]
 67%|██████▋   | 2/3 [11:11<05:34, 334.94s/it]

val_img_loss: [0.18509335815906525, 0.16849316656589508, 0.1410764828324318, 0.1147007867693901, 0.10078802704811096, 0.08939846605062485, 0.08045357093214989, 0.07503926381468773, 0.06990045309066772, 0.0647803321480751, 0.06040352024137974, 0.056709857657551765, 0.0543824452906847, 0.05216251499950886, 0.04949158616364002, 0.04771292395889759, 0.0465100072324276, 0.04552767425775528, 0.0441682580858469, 0.04310975782573223, 0.04210711270570755, 0.04089801385998726, 0.03980468399822712, 0.039076851680874825, 0.037996141240000725] val_total_loss: [1.5595887303352356, 1.4695954322814941, 1.383208990097046, 1.288562536239624, 1.175717830657959, 1.1316810846328735, 1.1127373576164246, 1.0942260026931763, 1.0856550931930542, 1.0776416659355164, 1.0713701844215393, 1.0661036372184753, 1.061899185180664, 1.0581846237182617, 1.0538963079452515, 1.0509629845619202, 1.0480409264564514, 1.0460904240608215, 1.0430660247802734, 1.040698766708374, 1.0384318828582764, 1.0359929203987122, 1.033708095

In [None]:
ae= Autoencoder(10,4)

ae.load_state_dict(torch.load(os.path.join(save_dir,'base_0_0')))

In [None]:
plt.imshow(base_set_ims[1500]/255)

In [None]:
plt.imshow(ae(torch.from_numpy(base_set_ims[1500].transpose(2,0,1)).unsqueeze(0).float())[0].detach().squeeze(0).numpy().transpose(1,2,0))

In [None]:
np.argmax(ae(torch.from_numpy(base_set_ims[1500].transpose(2,0,1)).unsqueeze(0).float())[b1].detach().squeeze(0).numpy())

# Representational Similarity Analysis

Sid: If you check the results folder you will find 18 models. The way they are named is {dataset}\_{training mode}\_{network number}
Dataset refers to base, overlap_50, or overlap_20. The second dataset has 50% overlap with the base set and the second one has 20% overlap with the base set. The training mode has 3 numbers: 0 refers to a model trained only on image reconstruction loss, 1 refers to a model trained on image and label loss, and 3 refers to models trained only on label loss. All these models have predictable and interesting training trajectories that you can check on wandb. Ignore the network number for now, i.e., we can restrict our analyses to models with network number 0. So that gives us 3*3 = 9 unique models. We could do RSAs for all of these models using a validation set of images I've set up just before the main training loop.


I think your RSA code needs to be tinkered a bit because it still expects MNIST digits. You should have all the pieces in place. Could you try and run the RSA analyses on your machine/Colab pro?


## Helper functions

In [None]:
def custom_torch_RSM_fct(features):
  """
  Custom function to calculate representational similarity matrix (RSM) of a feature
  matrix using pairwise cosine similarity.

  Args:
    features: 2D torch.Tensor
      Feature matrix of size (nbr items x nbr features)

  Returns:
    rsm: 2D torch.Tensor
      Similarity matrix of size (nbr items x nbr items)
  """
  features = torch.from_numpy(features)
  num_items, num_features = features.shape

  rsm = torch.nn.functional.cosine_similarity(
      features.reshape(1, num_items, num_features),
      features.reshape(num_items, 1, num_features),
      dim=2
      )

  if not rsm.shape == (num_items, num_items):
    raise ValueError(
        f"RSM should be of shape ({num_items}, {num_items})"
        )
  return rsm

In [None]:
def wandb_rsm_init():
  wandb.init(project="ConceptualAlignmentLanguage", entity="psych-711")
  wandb.run.name = 'rsms'
  
def compute_rsms(save_dir, batch_size, latent_dims, num_classes, test_data):
  '''
  Computes RSMs of all the models
  '''
  #all_rsms = torch.rand((len(os.listdir(save_dir)), 10000 ,10000))
  all_rsms = torch.rand((len(os.listdir(save_dir)), 4000 ,4000))
  print(all_rsms.shape)
  model_names = []
  print('Computing RSMs')
  for i, model in tqdm(enumerate(os.listdir(save_dir))):
    model_names.append(model)
    autoencoder = Autoencoder(latent_dims, num_classes).to(device) # GPU
    autoencoder.eval()
    autoencoder.load_state_dict(torch.load(os.path.join(save_dir, model)))
    autoencoder.eval()
    #initialise arrays
    latent_representations = np.asarray([[0,0,0,0,0,0,0,0,0,0]])
    labels = np.asarray([0])
    for image_batch, label_batch in test_data:
        image_batch = image_batch.to(device)
        label_batch_one_hot = F.one_hot(label_batch, num_classes=num_classes)
        label_batch_one_hot = label_batch_one_hot.to(device)
        z = autoencoder.encoder.forward(image_batch, label_batch_one_hot.float()) 
        latent_representations = np.vstack((latent_representations, z.cpu().detach()))
        labels = np.concatenate((labels, label_batch.float().cpu().detach()))
    labels = labels[1:]    
    latent_representations = latent_representations[1:]
    sorted_idx = np.array([], dtype = np.int8)

    for label_class in range(num_classes):
      sorted_idx = np.concatenate((sorted_idx, np.where(labels==label_class)[0]))
    latent_representations = latent_representations[sorted_idx]
    rsm = custom_torch_RSM_fct(latent_representations)
    all_rsms[i] = rsm
    plt.imshow(rsm)
    plt.colorbar()
    wandb.log({"RSM : {}".format(model): plt})
    plt.clf()
    print('Done with RSMs')
  return all_rsms, model_names


def fill_upper_triangular(matrix):
  '''
  Fills in the upper trinagular matrix for a symmetric matrix which has
  lower traingular matrix filled iin
  '''
  items, features = matrix.shape
  assert(items==features)
  for i in range(items):
    for j in range(i+1, items):
      matrix[i][j] = matrix[j][i]
  return matrix

def log_rsm(matrix, var_name):
  plt.imshow(matrix)
  plt.colorbar()
  wandb.log({var_name: plt})
  plt.clf()
  

def compute_rsa(all_rsms, model_names):
  '''
  Computes pairwise RSA (and significance) for all the models given their rsms
  '''
  rsa_corr = torch.zeros([len(model_names), len(model_names)])
  rsa_pvalue = torch.zeros([len(model_names), len(model_names)])

  # Computes the lower triangular matrix irst 
  print('Computing RSA)')
  for i in tqdm(range(len(model_names))):
    for j in range(i +1):
      rsa_corr[i][j],rsa_pvalue[i][j]= rdm_correlation_spearman(all_rsms[i], 
                                                                all_rsms[j], 
                                                                rescale=False,
                                                                permutation=False, iter=1000)
  # Fills in the upper triangular matrix
  print('Done with RSA')
  rsa_corr = fill_upper_triangular(rsa_corr)
  rsa_pvalue = fill_upper_triangular(rsa_pvalue)
  log_rsm(rsa_corr, "RSA corr")
  log_rsm(rsa_pvalue, "RSA pvalue")
  return rsa_corr, rsa_pvalue


def main_analysis_loop(save_dir, batch_size, latent_dims, num_classes, 
                       test_data):
  '''
  Takes in all the models in a saved dir and computes pairwise rsa
  '''
  wandb_rsm_init()
  all_rsms, model_names = compute_rsms(save_dir, batch_size, latent_dims, 
                                       num_classes, test_data)
  rsa_corr, rsa_pvalue = compute_rsa(all_rsms, model_names)
  rsa_corr, rsa_pvalue = fill_upper_triangular(rsa_corr), fill_upper_triangular(rsa_pvalue)


## Run RSM analysis

In [None]:
num_classes = 4 # Number of unique class labels in the dataset
latent_dims = 10
batch_size = 1024

In [None]:
# test_data = torch.utils.data.DataLoader( 
#       torchvision.datasets.MNIST('./data', 
#             transform=torchvision.transforms.ToTensor(),
#             download=True, train = False),
#       batch_size=batch_size,
#       shuffle=False)

### Load validation arrays
validation_ims = np.load(os.path.join(data_dir,'validation_set.npy'))
validation_labs = np.load(os.path.join(data_dir,'validation_labs.npy'))

test_data = TensorDataset(torch.tensor(validation_ims.transpose(0,3,1,2)/255).float(), torch.tensor(validation_labs).to(torch.int64))

main_analysis_loop(save_dir, batch_size, latent_dims, num_classes, 
                       test_data)