Below is an implementation of an autoencoder written in PyTorch. We apply it to the MNIST dataset.

In [1]:
# o

# from google.colab import drive
# drive.mount('/content/drive')
# base_dir = '/content/drive/My Drive/gary_class_project/'
# save_dir = base_dir + 'results'

In [2]:
import os
base_dir = os.path.abspath(os.getcwd())
save_dir = os.path.join(base_dir,'results')


In [3]:
# !pip install wandb -q
# !pip install neurora -q

#wandb api key 18a861e71f78135d23eb672c08922edbfcb8d364

In [4]:
import torch; torch.manual_seed(0)
import wandb
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
from tqdm import tqdm
import random
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from neurora.rdm_corr import rdm_correlation_spearman

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Main code to run code on various models while varying the amount of training data

## Setting up the model

In [6]:
class Encoder(nn.Module):
    def __init__(self, encoded_space_dim, num_classes):
        super().__init__()
        ""
        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True)
        )
        
        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)
        self.encoder_labels_lin = nn.Linear(num_classes, num_classes//2)
        ### Linear section
        self.encoder_lin = nn.Sequential(
            nn.Linear(3 * 3 * 32 + num_classes//2, 128),
            nn.ReLU(True),
            nn.Linear(128, encoded_space_dim)
        )
        
        
    def forward(self, x, y):
        img_features = self.encoder_cnn(x)
        img_features = self.flatten(img_features)
        label_features = self.encoder_labels_lin(y)
        combined = torch.cat((img_features, label_features), dim = -1)
        out = self.encoder_lin(combined)
        return out

class Decoder(nn.Module):
    
    def __init__(self, encoded_space_dim, num_classes):
        super().__init__()
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 3 * 3 * 32 + (num_classes)//2),
            nn.ReLU(True)
        )
        self.decoder_labels_lin = nn.Linear((num_classes)//2, num_classes)
        
        self.unflatten = nn.Unflatten(dim=1, 
        unflattened_size=(32, 3, 3))

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, 
            stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, 
            padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, 
            padding=1, output_padding=1)
        )
        
    def forward(self, x):
        x = self.decoder_lin(x)
        img_features = x[:, :-(num_classes//2)]
        label_features = x[:, -(num_classes//2):]
        img_features = self.unflatten(img_features)
        img_features = self.decoder_conv(img_features)
        img = torch.sigmoid(img_features)
        label = self.decoder_labels_lin(label_features)
        return img, label

In [7]:
class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # self.gamma = nn.Parameter(torch.tensor([.5]))
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, img, pred_img, label, pred_label):
        mse_loss_img = ((img - pred_img)**2).sum()
        mse_loss_label = self.cross_entropy(pred_label, label.float())
        # loss = mse_loss_img * torch.sigmoid(self.gamma) + \
              # mse_loss_label * (1 - torch.sigmoid(self.gamma))
        loss = mse_loss_img + mse_loss_label#*torch.sigmoid(self.gamma)
        return mse_loss_img, mse_loss_label, loss

class Autoencoder(nn.Module):
    def __init__(self, latent_dims, num_classes):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(latent_dims, num_classes)
        self.decoder = Decoder(latent_dims, num_classes)
        self.custom_loss = CustomLoss()
        self.num_classes = num_classes
    
    def forward(self, x, y):
        z = self.encoder(x, y)
        return self.decoder(z)

    def test_epoch(self, test_data):
    # Set evaluation mode for encoder and decoder
        self.eval()
        with torch.no_grad(): # No need to track the gradients
            # Define the lists to store the outputs for each batch
            test_img_loss = []
            test_label_loss = []
            total_test_loss = []
            for image_batch, label_batch in test_data:
                total = 0
                correct = 0
                # Move tensor to the proper device
                image_batch = image_batch.to(device)
                label_batch = F.one_hot(label_batch, num_classes=self.num_classes)
                label_batch = label_batch.to(device)
                pred_img, pred_label = self.forward(image_batch, label_batch.float()) 
                # Append the network output and the original image to the lists
                img_loss, label_loss, total_loss = self.custom_loss(image_batch,
                                                                pred_img, 
                                                                label_batch, 
                                                                pred_label)
                total += label_batch.size(0)
                correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(label_batch, dim = 1)).sum().item()
                test_img_loss.append(img_loss.item())
                test_label_loss.append(label_loss.item())
                total_test_loss.append(total_loss.item())
        test_img_loss = sum(test_img_loss)/len(test_img_loss)
        test_label_loss = sum(test_label_loss)/len(test_label_loss)
        total_test_loss = sum(total_test_loss)/len(total_test_loss)
        test_accuracy = correct/total
        return test_img_loss, test_label_loss, total_test_loss, test_accuracy

    def train_epoch(self, train_data, optimizer, training_label_ratio):
        self.train()
        torch.manual_seed(0)
        train_img_loss = []
        train_label_loss = []
        train_loss = []
        correct = 0
        total = 0
        for image_batch, label_batch in train_data:
            # Move tensor to the proper device
            image_batch = image_batch.to(device)
            num_training_examples = label_batch.shape[0]
            num_non_label_training_examples = num_training_examples*(1-training_label_ratio)
            non_label_training_idx = random.sample(range(num_training_examples),int(num_non_label_training_examples))
            label_batch[[non_label_training_idx]] = self.num_classes - 1
            label_batch = F.one_hot(label_batch, num_classes=self.num_classes)
            label_batch = label_batch.to(device)
            optimizer.zero_grad()
            pred_img, pred_label = self.forward(image_batch, label_batch.float()) 
            # Append the network output and the original image to the lists
            img_loss, label_loss, total_loss = self.custom_loss(image_batch,
                                                            pred_img, 
                                                            label_batch, 
                                                            pred_label)
            total_loss.backward()
            optimizer.step()
            train_img_loss.append(img_loss.item())
            train_label_loss.append(label_loss.item())
            train_loss.append(total_loss.item())
            total += label_batch.size(0)
            correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(label_batch, dim = 1)).sum().item()
        train_img_loss = sum(train_img_loss)/len(train_img_loss)
        train_label_loss = sum(train_label_loss)/len(train_label_loss)
        train_loss = sum(train_loss)/len(train_loss)
        train_accuracy = correct/total
        return train_img_loss, train_label_loss, train_loss, train_accuracy

    def training_loop(self, train_data, test_data,training_label_ratio,
                      epochs, optimizer):
        train_losses = []
        val_losses = []
        train_img_losses = []
        val_img_losses = []
        train_label_losses = []
        val_label_losses = []
        train_accuracies = []
        val_accuracies = []
        for epoch in tqdm(range(epochs)):
          train_img_loss, train_label_loss, train_loss, train_accuracy =self.train_epoch(train_data, optimizer, 
                                             training_label_ratio)
          val_img_loss, val_label_loss, val_loss, val_accuracy = self.test_epoch(test_data)
          train_losses.append(train_loss)
          val_losses.append(val_loss)
          train_img_losses.append(train_img_loss)
          val_img_losses.append(val_img_loss)
          train_label_losses.append(train_label_loss)
          val_label_losses.append(val_label_loss)
          train_accuracies.append(train_accuracy)
          val_accuracies.append(val_accuracy)
          wandb.log({"train_img_loss": train_img_loss, 
            "train_label_loss":train_label_loss, 
            "val_img_loss":val_img_loss, 
            "val_label_loss":val_label_loss, 
            "train_losses":train_loss, 
            "val_losses":val_loss, 
            "train_accuracy":train_accuracy, 
            "val_accuracy":val_accuracy})

        return train_img_losses, train_label_losses, val_img_losses, val_label_losses ,train_losses, val_losses, train_accuracies, val_accuracies

def plot_latent_with_label(autoencoder, batch_size, data, random_labels,
                           num_classes, num_batches=100):
    if not random_labels:
        for i, (x, y) in enumerate(data):
            x = x.to(device) # GPU
            y_one_hot = F.one_hot(y, num_classes)
            y_one_hot = y_one_hot.to(device)
            z = autoencoder.encoder(x, y_one_hot.float())
            z = z.to('cpu').detach().numpy()
            plt.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10')
            if i > num_batches:
                plt.colorbar()
                break
        wandb.log({"Latent : Real labels in input": plt})
    else:
        for i, (x, y) in enumerate(data):
            x = x.to(device) # GPU
            y_rand = torch.zeros((x.size(0), num_classes))
            y_rand = y_rand.to(device)
            z = autoencoder.encoder(x, y_rand.float())
            z = z.to('cpu').detach().numpy()
            plt.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10')
            if i > num_batches:
                plt.colorbar()
                break
        wandb.log({"Latent : Random labels in input": plt})
    return min(z[:, 0]), max(z[:, 0]), min(z[:, 1]), max(z[:, 1])


def plot_reconstructed_with_labels(autoencoder,random_labels, r0,
                                   r1, n=24):
    w = 28
    img = np.zeros((n*w, n*w))
    for i, y in enumerate(np.linspace(*r1, n)):
        for j, x in enumerate(np.linspace(*r0, n)):
            z = torch.Tensor([[x, y]]).to(device)
            x_hat, label = autoencoder.decoder(z)
            x_hat = x_hat.reshape(28, 28).to('cpu').detach().numpy()
            img[(n-1-i)*w:(n-1-i+1)*w, j*w:(j+1)*w] = x_hat
    plt.imshow(img, extent=[*r0, *r1])
    if not random_labels:
      wandb.log({"Reconstruction : Real labels in input": plt})
    else:
      wandb.log({"Reconstruction : Random labels in input": plt})
    
    

## Training multiple networks

In [8]:
def wandb_init(epochs, lr, label_ratio, batch_size, model_number):
  wandb.init(project="ConceptualAlignmentLangugae", entity="psych-711")
  wandb.config = {
    "learning_rate": lr,
    "epochs": epochs,
    "batch_size": batch_size, 
    "label_ratio":label_ratio, 
    "model_number": model_number 
  }
  wandb.run.name = '{}_{}'.format(label_ratio, model_number)
  wandb.run.save()
     

In [9]:
def main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims):
  if os.path.isdir(save_dir):
    pass
  else:
    os.mkdir(save_dir)

  for i in tqdm(range(0, 11)):
    torch.manual_seed(0)
    label_ratio = i/10
    for model in range(num_models):
      wandb_init(epochs, lr, label_ratio, batch_size, model)
      train_data = torch.utils.data.DataLoader( 
        torchvision.datasets.MNIST('./data', 
               transform=torchvision.transforms.ToTensor(), train = True, 
               download=True),
        batch_size=batch_size,
        shuffle=True)
      train_data, val_data = torch.utils.data.random_split(train_data.dataset, 
                                                           [50000, 10000])
      train_data = torch.utils.data.DataLoader(train_data, 
                                               batch_size=batch_size,
                                            shuffle=True)
      val_data = torch.utils.data.DataLoader(val_data, 
                                               batch_size=batch_size,
                                            shuffle=True)
      test_data = torch.utils.data.DataLoader( 
              torchvision.datasets.MNIST('./data', 
                    transform=torchvision.transforms.ToTensor(), train = False, 
                    download=True),
              batch_size=batch_size,
              shuffle=True)
      autoencoder = Autoencoder(latent_dims, num_classes).to(device) # GPU
      optimizer = torch.optim.Adam(autoencoder.parameters(), lr=lr, weight_decay=1e-05)
      train_img_loss, train_label_loss, val_img_loss, \
      val_label_loss ,train_losses, val_losses,  train_accuracy, \
      val_accuracy= autoencoder.training_loop(train_data = train_data,
                                                          test_data = test_data,
                                                          epochs = epochs,
                                                          optimizer = optimizer, 
                                                          training_label_ratio = label_ratio)
      # figure out which data we want to plot

      min_x, max_x, min_y, max_y = plot_latent_with_label(autoencoder, 
                                                          batch_size, 
                                                          data=val_data, 
                                                          random_labels = False,
                                                          num_classes = num_classes,
                                                          num_batches=100)
      plt.clf()
      plot_reconstructed_with_labels(autoencoder = autoencoder, 
                                     r0=(min_x, max_x),
                                    r1=(min_y, max_y), 
                                     n=24, random_labels = False)
      plt.clf()
      min_x, max_x, min_y, max_y = plot_latent_with_label(autoencoder, 
                                                          batch_size, 
                                                          data=val_data, 
                                                          random_labels = True,
                                                          num_classes = num_classes,
                                                          num_batches=100)
      plt.clf()
      plot_reconstructed_with_labels(autoencoder = autoencoder, 
                                     r0=(min_x, max_x),
                                    r1=(min_y, max_y), 
                                     n=24, random_labels = True)
      plt.clf()
      
      # wandb.log({"train_img_loss": train_img_loss, 
      #            "train_label_loss":train_label_loss, 
      #            "val_img_loss":val_img_loss, 
      #            "val_label_loss":val_label_loss, 
      #            "train_losses":train_losses, 
      #            "val_losses":val_losses, 
      #            "train_accuracy":train_accuracy, 
      #            "val_accuracy":val_accuracy})
      torch.save(autoencoder.state_dict(), save_dir+'_{}_{}'.format(label_ratio, 
                                                                    model))
      
      


## Main program code

In [10]:
import plotly

In [11]:
num_classes = 11 # Number of unique class labels in the dataset
latent_dims = 2
epochs = 1
lr = 0.001
num_models = 2
batch_size = 512
save_dir = base_dir + 'results/'
main_code(save_dir, num_models, epochs, num_classes, batch_size,
             lr, latent_dims)

  0%|          | 0/11 [00:00<?, ?it/s][34m[1mwandb[0m: Currently logged in as: [33mpsych-711[0m (use `wandb login --relogin` to force relogin)


100%|██████████| 1/1 [00:10<00:00, 10.69s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.






VBox(children=(Label(value='2.540 MB of 3.430 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.740563…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,1.0
train_img_loss,72703.18614
train_label_loss,1.33061
train_losses,72704.51654
val_accuracy,0.0
val_img_loss,53472.65469
val_label_loss,3.9636
val_losses,53476.61816


100%|██████████| 1/1 [00:10<00:00, 10.83s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.

  9%|▉         | 1/11 [00:49<08:15, 49.56s/it]




VBox(children=(Label(value='2.578 MB of 3.408 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.756585…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.58318
train_img_loss,80301.14565
train_label_loss,1.89472
train_losses,80303.04018
val_accuracy,0.0
val_img_loss,54386.52812
val_label_loss,2.80196
val_losses,54389.33037


100%|██████████| 1/1 [00:10<00:00, 10.95s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.






VBox(children=(Label(value='3.449 MB of 3.449 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.89844
train_img_loss,72685.30668
train_label_loss,1.49472
train_losses,72686.80178
val_accuracy,0.0
val_img_loss,53045.25605
val_label_loss,3.80821
val_losses,53049.06377


100%|██████████| 1/1 [00:11<00:00, 11.13s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.

 18%|█▊        | 2/11 [01:36<07:13, 48.20s/it]




VBox(children=(Label(value='3.004 MB of 3.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.52714
train_img_loss,80438.68886
train_label_loss,1.96261
train_losses,80440.65163
val_accuracy,0.0
val_img_loss,53757.43105
val_label_loss,2.83491
val_losses,53760.26592


100%|██████████| 1/1 [00:10<00:00, 10.93s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.






VBox(children=(Label(value='3.472 MB of 3.472 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.79882
train_img_loss,72715.44268
train_label_loss,1.65568
train_losses,72717.09853
val_accuracy,0.0
val_img_loss,53230.66064
val_label_loss,3.58645
val_losses,53234.24756


100%|██████████| 1/1 [00:12<00:00, 12.44s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.

 27%|██▋       | 3/11 [03:29<10:20, 77.57s/it]




wandb: Network error (ConnectionError), entering retry loop.


VBox(children=(Label(value='0.001 MB of 1.411 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.000628…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.47072
train_img_loss,80435.02985
train_label_loss,2.03223
train_losses,80437.0619
val_accuracy,0.0
val_img_loss,54093.99746
val_label_loss,2.7908
val_losses,54096.78857


100%|██████████| 1/1 [00:11<00:00, 11.70s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.






VBox(children=(Label(value='3.418 MB of 3.418 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.69922
train_img_loss,72657.43551
train_label_loss,1.80958
train_losses,72659.24494
val_accuracy,0.0
val_img_loss,52965.56299
val_label_loss,3.28046
val_losses,52968.84316


100%|██████████| 1/1 [00:11<00:00, 11.01s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.

 36%|███▋      | 4/11 [06:33<13:57, 119.61s/it]




VBox(children=(Label(value='2.944 MB of 2.944 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.41434
train_img_loss,80456.38006
train_label_loss,2.09791
train_losses,80458.47808
val_accuracy,0.0
val_img_loss,53841.15225
val_label_loss,2.76133
val_losses,53843.91338


100%|██████████| 1/1 [00:10<00:00, 10.95s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.






VBox(children=(Label(value='1.698 MB of 3.165 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.536457…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.5996
train_img_loss,72648.78105
train_label_loss,1.94965
train_losses,72650.73095
val_accuracy,0.0
val_img_loss,53020.01143
val_label_loss,3.0301
val_losses,53023.0415


100%|██████████| 1/1 [00:10<00:00, 10.83s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.

 45%|████▌     | 5/11 [07:26<09:32, 95.48s/it] 




VBox(children=(Label(value='1.377 MB of 2.839 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.485221…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.35958
train_img_loss,80423.03803
train_label_loss,2.16042
train_losses,80425.19866
val_accuracy,0.0
val_img_loss,53567.12324
val_label_loss,2.73751
val_losses,53569.86035


100%|██████████| 1/1 [00:10<00:00, 10.48s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.






VBox(children=(Label(value='3.240 MB of 3.240 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.5
train_img_loss,72682.73545
train_label_loss,2.07566
train_losses,72684.81103
val_accuracy,0.0
val_img_loss,53047.71123
val_label_loss,2.84512
val_losses,53050.55605


100%|██████████| 1/1 [00:10<00:00, 10.71s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.

 55%|█████▍    | 6/11 [08:13<06:35, 79.17s/it]




VBox(children=(Label(value='3.277 MB of 3.277 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.30318
train_img_loss,80398.4943
train_label_loss,2.22211
train_losses,80400.7162
val_accuracy,0.0
val_img_loss,53206.73613
val_label_loss,2.64986
val_losses,53209.38574


100%|██████████| 1/1 [00:11<00:00, 11.10s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.






VBox(children=(Label(value='2.813 MB of 3.592 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.783159…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.39844
train_img_loss,72651.41645
train_label_loss,2.18823
train_losses,72653.60471
val_accuracy,0.0
val_img_loss,52864.28105
val_label_loss,2.69542
val_losses,52866.97627


100%|██████████| 1/1 [00:11<00:00, 11.45s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.

 64%|██████▎   | 7/11 [09:02<04:36, 69.17s/it]




VBox(children=(Label(value='2.353 MB of 2.972 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.791833…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.2467
train_img_loss,80419.06154
train_label_loss,2.28604
train_losses,80421.34778
val_accuracy,0.0
val_img_loss,53802.70127
val_label_loss,2.58057
val_losses,53805.28203


100%|██████████| 1/1 [00:10<00:00, 10.56s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.






VBox(children=(Label(value='3.744 MB of 3.744 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.29882
train_img_loss,72671.13457
train_label_loss,2.28121
train_losses,72673.41582
val_accuracy,0.0
val_img_loss,52985.59697
val_label_loss,2.57877
val_losses,52988.17588


100%|██████████| 1/1 [00:11<00:00, 11.65s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.

 73%|███████▎  | 8/11 [09:50<03:07, 62.63s/it]




VBox(children=(Label(value='1.462 MB of 2.920 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.500456…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.18946
train_img_loss,80459.1431
train_label_loss,2.34137
train_losses,80461.4841
val_accuracy,0.0
val_img_loss,52989.80918
val_label_loss,2.51008
val_losses,52992.31924


100%|██████████| 1/1 [00:10<00:00, 10.72s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.






VBox(children=(Label(value='2.697 MB of 3.713 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.726385…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.19922
train_img_loss,72692.71205
train_label_loss,2.3547
train_losses,72695.06688
val_accuracy,0.0
val_img_loss,53398.13359
val_label_loss,2.47698
val_losses,53400.61113


100%|██████████| 1/1 [00:10<00:00, 10.52s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.

 82%|████████▏ | 9/11 [10:41<01:57, 58.76s/it]




VBox(children=(Label(value='2.682 MB of 3.325 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.806658…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.13462
train_img_loss,80402.19683
train_label_loss,2.38282
train_losses,80404.5798
val_accuracy,0.0
val_img_loss,53347.71445
val_label_loss,2.44824
val_losses,53350.16299


100%|██████████| 1/1 [00:10<00:00, 10.80s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.






VBox(children=(Label(value='3.710 MB of 3.710 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.0996
train_img_loss,72679.14178
train_label_loss,2.40329
train_losses,72681.54496
val_accuracy,0.0
val_img_loss,53268.81074
val_label_loss,2.42703
val_losses,53271.2376


100%|██████████| 1/1 [00:10<00:00, 10.58s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.

 91%|█████████ | 10/11 [11:25<00:54, 54.50s/it]




VBox(children=(Label(value='2.346 MB of 3.478 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.674594…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.08342
train_img_loss,80349.18427
train_label_loss,2.41312
train_losses,80351.59738
val_accuracy,0.12868
val_img_loss,54143.52305
val_label_loss,2.36013
val_losses,54145.8834


100%|██████████| 1/1 [00:10<00:00, 10.45s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.






VBox(children=(Label(value='3.847 MB of 3.847 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train_accuracy,▁
train_img_loss,▁
train_label_loss,▁
train_losses,▁
val_accuracy,▁
val_img_loss,▁
val_label_loss,▁
val_losses,▁

0,1
train_accuracy,0.0222
train_img_loss,72659.83566
train_label_loss,2.42466
train_losses,72662.26044
val_accuracy,0.12132
val_img_loss,53292.30674
val_label_loss,2.38425
val_losses,53294.69082


100%|██████████| 1/1 [00:10<00:00, 10.50s/it]


The get_offset_position function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.

100%|██████████| 11/11 [12:16<00:00, 66.97s/it]


<Figure size 1200x800 with 0 Axes>

wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
wandb: Network error (ConnectionError), entering retry loop.
Exception in thread Stat

# Code to compute RDMs and their correlations

## Setting up the model

In [5]:
class Encoder(nn.Module):
    def __init__(self, encoded_space_dim, num_classes):
        super().__init__()
        
        ### Convolutional section
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 8, 3, stride=2, padding=1),
            nn.ReLU(True),
            nn.Conv2d(8, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 32, 3, stride=2, padding=0),
            nn.ReLU(True)
        )
        
        ### Flatten layer
        self.flatten = nn.Flatten(start_dim=1)
        self.encoder_labels_lin = nn.Linear(num_classes, num_classes//2)
        ### Linear section
        self.encoder_lin = nn.Sequential(
            nn.Linear(3 * 3 * 32 + num_classes//2, 128),
            nn.ReLU(True),
            nn.Linear(128, encoded_space_dim)
        )
        
        
    def forward(self, x, y):
        img_features = self.encoder_cnn(x)
        img_features = self.flatten(img_features)
        label_features = self.encoder_labels_lin(y)
        combined = torch.cat((img_features, label_features), dim = -1)
        out = self.encoder_lin(combined)
        return out

class Decoder(nn.Module):
    
    def __init__(self, encoded_space_dim, num_classes):
        super().__init__()
        self.decoder_lin = nn.Sequential(
            nn.Linear(encoded_space_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 3 * 3 * 32 + (num_classes)//2),
            nn.ReLU(True)
        )
        self.decoder_labels_lin = nn.Linear((num_classes)//2, num_classes)
        
        self.unflatten = nn.Unflatten(dim=1, 
        unflattened_size=(32, 3, 3))

        self.decoder_conv = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3, 
            stride=2, output_padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 3, stride=2, 
            padding=1, output_padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 3, stride=2, 
            padding=1, output_padding=1)
        )
        
    def forward(self, x):
        x = self.decoder_lin(x)
        img_features = x[:, :-(num_classes//2)]
        label_features = x[:, -(num_classes//2):]
        img_features = self.unflatten(img_features)
        img_features = self.decoder_conv(img_features)
        img = torch.sigmoid(img_features)
        label = self.decoder_labels_lin(label_features)
        return img, label

In [6]:
class CustomLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # self.gamma = nn.Parameter(torch.tensor([.5]))
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, img, pred_img, label, pred_label):
        mse_loss_img = ((img - pred_img)**2).sum()
        mse_loss_label = self.cross_entropy(pred_label, label.float())
        # loss = mse_loss_img * torch.sigmoid(self.gamma) + \
              # mse_loss_label * (1 - torch.sigmoid(self.gamma))
        loss = mse_loss_img + mse_loss_label#*torch.sigmoid(self.gamma)
        return mse_loss_img, mse_loss_label, loss

class Autoencoder(nn.Module):
    def __init__(self, latent_dims, num_classes):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(latent_dims, num_classes)
        self.decoder = Decoder(latent_dims, num_classes)
        self.custom_loss = CustomLoss()
        self.num_classes = num_classes
    
    def forward(self, x, y):
        z = self.encoder(x, y)
        return self.decoder(z)

    def test_epoch(self, test_data):
    # Set evaluation mode for encoder and decoder
        self.eval()
        with torch.no_grad(): # No need to track the gradients
            # Define the lists to store the outputs for each batch
            test_img_loss = []
            test_label_loss = []
            total_test_loss = []
            for image_batch, label_batch in test_data:
                total = 0
                correct = 0
                # Move tensor to the proper device
                image_batch = image_batch.to(device)
                label_batch = F.one_hot(label_batch, num_classes=self.num_classes)
                label_batch = label_batch.to(device)
                pred_img, pred_label = self.forward(image_batch, label_batch.float()) 
                # Append the network output and the original image to the lists
                img_loss, label_loss, total_loss = self.custom_loss(image_batch,
                                                                pred_img, 
                                                                label_batch, 
                                                                pred_label)
                total += label_batch.size(0)
                correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(label_batch, dim = 1)).sum().item()
                test_img_loss.append(img_loss.item())
                test_label_loss.append(label_loss.item())
                total_test_loss.append(total_loss.item())
        test_img_loss = sum(test_img_loss)/len(test_img_loss)
        test_label_loss = sum(test_label_loss)/len(test_label_loss)
        total_test_loss = sum(total_test_loss)/len(total_test_loss)
        test_accuracy = correct/total
        return test_img_loss, test_label_loss, total_test_loss, test_accuracy

    def train_epoch(self, train_data, optimizer, training_label_ratio):
        self.train()
        torch.manual_seed(0)
        train_img_loss = []
        train_label_loss = []
        train_loss = []
        correct = 0
        total = 0
        for image_batch, label_batch in train_data:
            # Move tensor to the proper device
            image_batch = image_batch.to(device)
            num_training_examples = label_batch.shape[0]
            num_non_label_training_examples = num_training_examples*(1-training_label_ratio)
            non_label_training_idx = random.sample(range(num_training_examples),int(num_non_label_training_examples))
            label_batch[[non_label_training_idx]] = self.num_classes - 1
            label_batch = F.one_hot(label_batch, num_classes=self.num_classes)
            label_batch = label_batch.to(device)
            optimizer.zero_grad()
            pred_img, pred_label = self.forward(image_batch, label_batch.float()) 
            # Append the network output and the original image to the lists
            img_loss, label_loss, total_loss = self.custom_loss(image_batch,
                                                            pred_img, 
                                                            label_batch, 
                                                            pred_label)
            total_loss.backward()
            optimizer.step()
            train_img_loss.append(img_loss.item())
            train_label_loss.append(label_loss.item())
            train_loss.append(total_loss.item())
            total += label_batch.size(0)
            correct += (torch.argmax(pred_label, dim = 1) == torch.argmax(label_batch, dim = 1)).sum().item()
        train_img_loss = sum(train_img_loss)/len(train_img_loss)
        train_label_loss = sum(train_label_loss)/len(train_label_loss)
        train_loss = sum(train_loss)/len(train_loss)
        train_accuracy = correct/total
        return train_img_loss, train_label_loss, train_loss, train_accuracy

    def training_loop(self, train_data, test_data,training_label_ratio,
                      epochs, optimizer):
        train_losses = []
        val_losses = []
        train_img_losses = []
        val_img_losses = []
        train_label_losses = []
        val_label_losses = []
        train_accuracies = []
        val_accuracies = []
        for epoch in tqdm(range(epochs)):
          train_img_loss, train_label_loss, train_loss, train_accuracy =self.train_epoch(train_data, optimizer, 
                                             training_label_ratio)
          val_img_loss, val_label_loss, val_loss, val_accuracy = self.test_epoch(test_data)
          train_losses.append(train_loss)
          val_losses.append(val_loss)
          train_img_losses.append(train_img_loss)
          val_img_losses.append(val_img_loss)
          train_label_losses.append(train_label_loss)
          val_label_losses.append(val_label_loss)
          train_accuracies.append(train_accuracy)
          val_accuracies.append(val_accuracy)
          wandb.log({"train_img_loss": train_img_loss, 
            "train_label_loss":train_label_loss, 
            "val_img_loss":val_img_loss, 
            "val_label_loss":val_label_loss, 
            "train_losses":train_loss, 
            "val_losses":val_loss, 
            "train_accuracy":train_accuracy, 
            "val_accuracy":val_accuracy})

        return train_img_losses, train_label_losses, val_img_losses, val_label_losses ,train_losses, val_losses, train_accuracies, val_accuracies

def plot_latent_with_label(autoencoder, batch_size, data, random_labels,
                           num_classes, num_batches=100):
    if not random_labels:
        for i, (x, y) in enumerate(data):
            x = x.to(device) # GPU
            y_one_hot = F.one_hot(y, num_classes)
            y_one_hot = y_one_hot.to(device)
            z = autoencoder.encoder(x, y_one_hot.float())
            z = z.to('cpu').detach().numpy()
            plt.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10')
            if i > num_batches:
                plt.colorbar()
                break
        wandb.log({"Latent : Real labels in input": plt})
    else:
        for i, (x, y) in enumerate(data):
            x = x.to(device) # GPU
            y_rand = torch.zeros((x.size(0), num_classes))
            y_rand = y_rand.to(device)
            z = autoencoder.encoder(x, y_rand.float())
            z = z.to('cpu').detach().numpy()
            plt.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10')
            if i > num_batches:
                plt.colorbar()
                break
        wandb.log({"Latent : Random labels in input": plt})
    return min(z[:, 0]), max(z[:, 0]), min(z[:, 1]), max(z[:, 1])


def plot_reconstructed_with_labels(autoencoder,random_labels, r0,
                                   r1, n=24):
    w = 28
    img = np.zeros((n*w, n*w))
    for i, y in enumerate(np.linspace(*r1, n)):
        for j, x in enumerate(np.linspace(*r0, n)):
            z = torch.Tensor([[x, y]]).to(device)
            x_hat, label = autoencoder.decoder(z)
            x_hat = x_hat.reshape(28, 28).to('cpu').detach().numpy()
            img[(n-1-i)*w:(n-1-i+1)*w, j*w:(j+1)*w] = x_hat
    plt.imshow(img, extent=[*r0, *r1])
    if not random_labels:
      wandb.log({"Reconstruction : Real labels in input": plt})
    else:
      wandb.log({"Reconstruction : Random labels in input": plt})
    
    

## Helper functions

In [11]:
def custom_torch_RSM_fct(features):
  """
  Custom function to calculate representational similarity matrix (RSM) of a feature
  matrix using pairwise cosine similarity.

  Args:
    features: 2D torch.Tensor
      Feature matrix of size (nbr items x nbr features)

  Returns:
    rsm: 2D torch.Tensor
      Similarity matrix of size (nbr items x nbr items)
  """
  features = torch.from_numpy(features)
  num_items, num_features = features.shape

  rsm = torch.nn.functional.cosine_similarity(
      features.reshape(1, num_items, num_features),
      features.reshape(num_items, 1, num_features),
      dim=2
      )

  if not rsm.shape == (num_items, num_items):
    raise ValueError(
        f"RSM should be of shape ({num_items}, {num_items})"
        )
  return rsm

In [12]:
def wandb_rsm_init():
  wandb.init(project="ConceptualAlignmentLangugae", entity="psych-711")
  wandb.run.name = 'rsms'
def compute_rsms(save_dir, batch_size, latent_dims, num_classes, test_data):
  '''
  Computes RSMs of all the models
  '''
  all_rsms = torch.rand((len(os.listdir(save_dir)), 10000 ,10000))
  print(all_rsms.shape)
  model_names = []
  print('Computing RSMs')
  for i, model in tqdm(enumerate(os.listdir(save_dir))):
    model_names.append(model)
    autoencoder = Autoencoder(latent_dims, num_classes).to(device) # GPU
    autoencoder.eval()
    autoencoder.load_state_dict(torch.load(os.path.join(save_dir, model)))
    autoencoder.eval()
    #initialise arrays
    latent_representations = np.asarray([[0, 0]])
    labels = np.asarray([0])
    for image_batch, label_batch in test_data:
        image_batch = image_batch.to(device)
        label_batch_one_hot = F.one_hot(label_batch, num_classes=num_classes)
        label_batch_one_hot = label_batch_one_hot.to(device)
        z = autoencoder.encoder.forward(image_batch, label_batch_one_hot.float()) 
        latent_representations = np.vstack((latent_representations, z.cpu().detach()))
        labels = np.concatenate((labels, label_batch.float().cpu().detach()))
    labels = labels[1:]    
    latent_representations = latent_representations[1:]
    sorted_idx = np.array([], dtype = np.int8)

    for label_class in range(num_classes):
      sorted_idx = np.concatenate((sorted_idx, np.where(labels==label_class)[0]))
    latent_representations = latent_representations[sorted_idx]
    rsm = custom_torch_RSM_fct(latent_representations)
    all_rsms[i] = rsm
    plt.imshow(rsm)
    plt.colorbar()
    wandb.log({"RSM : {}".format(model): plt})
    plt.clf()
    print('Done with RSMs')
  return all_rsms, model_names


def fill_upper_triangular(matrix):
  '''
  Fills in the upper trinagular matrix for a symmetric matrix which has
  lower traingular matrix filled iin
  '''
  items, features = matrix.shape
  assert(items==features)
  for i in range(items):
    for j in range(i+1, items):
      matrix[i][j] = matrix[j][i]
  return matrix

def log_rsm(matrix, var_name):
  plt.imshow(matrix)
  plt.colorbar()
  wandb.log({var_name: plt})
  plt.clf()
  

def compute_rsa(all_rsms, model_names):
  '''
  Computes pairwise RSA (and significance) for all the models given their rsms
  '''
  rsa_corr = torch.zeros([len(model_names), len(model_names)])
  rsa_pvalue = torch.zeros([len(model_names), len(model_names)])

  # Computes the lower triangular matrix irst 
  print('Computing RSA)')
  for i in tqdm(range(len(model_names))):
    for j in range(i +1):
      rsa_corr[i][j],rsa_pvalue[i][j]= rdm_correlation_spearman(all_rsms[i], 
                                                                all_rsms[j], 
                                                                rescale=False,
                                                                permutation=False, iter=1000)
  # Fills in the upper triangular matrix
  print('Done with RSA')
  rsa_corr = fill_upper_triangular(rsa_corr)
  rsa_pvalue = fill_upper_triangular(rsa_pvalue)
  log_rsm(rsa_corr, "RSA corr")
  log_rsm(rsa_pvalue, "RSA pvalue")
  return rsa_corr, rsa_pvalue


def main_analysis_loop(save_dir, batch_size, latent_dims, num_classes, 
                       test_data):
  '''
  Takes in all the models in a saved dir and computes pairwise rsa
  '''
  wandb_rsm_init()
  all_rsms, model_names = compute_rsms(save_dir, batch_size, latent_dims, 
                                       num_classes, test_data)
  rsa_corr, rsa_pvalue = compute_rsa(all_rsms, model_names)
  rsa_corr, rsa_pvalue = fill_upper_triangular(rsa_corr), fill_upper_triangular(rsa_pvalue)


## Run RSM analysis

In [13]:
num_classes = 11 # Number of unique class labels in the dataset
latent_dims = 2
batch_size = 1024

In [None]:
test_data = torch.utils.data.DataLoader( 
      torchvision.datasets.MNIST('./data', 
            transform=torchvision.transforms.ToTensor(),
            download=True, train = False),
      batch_size=batch_size,
      shuffle=False)
main_analysis_loop(save_dir, batch_size, latent_dims, num_classes, 
                       test_data)

VBox(children=(Label(value='1.512 MB of 1.512 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

torch.Size([22, 10000, 10000])
Computing RSMs


1it [00:09,  9.97s/it]

Done with RSMs


2it [00:20, 10.19s/it]

Done with RSMs


3it [00:30, 10.28s/it]

Done with RSMs


4it [00:40, 10.24s/it]

Done with RSMs


5it [00:51, 10.20s/it]

Done with RSMs


6it [01:00, 10.11s/it]

Done with RSMs


7it [01:10, 10.07s/it]

Done with RSMs


8it [01:20,  9.99s/it]

Done with RSMs


9it [01:30,  9.97s/it]

Done with RSMs


10it [01:41, 10.11s/it]

Done with RSMs


11it [01:51, 10.06s/it]

Done with RSMs


12it [02:00, 10.00s/it]

Done with RSMs


13it [02:11, 10.05s/it]

Done with RSMs


14it [02:21, 10.04s/it]

Done with RSMs


15it [02:30,  9.97s/it]

Done with RSMs


16it [02:40,  9.97s/it]

Done with RSMs


17it [02:50,  9.93s/it]

Done with RSMs


18it [03:00,  9.89s/it]

Done with RSMs


19it [03:10,  9.87s/it]

Done with RSMs


20it [03:20, 10.00s/it]

Done with RSMs


21it [03:30,  9.98s/it]

Done with RSMs


22it [03:41, 10.06s/it]


Done with RSMs
Computing RSA)


  0%|          | 0/22 [00:00<?, ?it/s]