In this notebook, we load the pre-trained regular AE model, perform K-means clustering on the encoder output, and then use the resulting cluster centers to initialize the embeddings of the VQ-VAE or SOM-VAE.

In [1]:
# Mount Google Drive at /content/drive (forcing a remount) so that the files
# stored there can be reached from this Colab runtime.
from google.colab import drive
drive.mount("/content/drive",force_remount=True)

Mounted at /content/drive


In [2]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Thu May 20 01:06:13 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
import os 
# Use the absolute path of the Drive mount point (see drive.mount above) so
# this cell is idempotent: a relative chdir fails when the cell is re-run
# because the working directory has already moved.
os.chdir('/content/drive/MyDrive/OASIS_Data')

In [4]:
import copy 
import time
import datetime
import numpy as np 
import pandas as pd 
import nibabel as nib
import sklearn as skl
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [5]:
import torch 
from torch.utils.data import Dataset, DataLoader
from torch import nn 
from torch import optim

In [6]:
pip install torchio nilearn

Collecting torchio
[?25l  Downloading https://files.pythonhosted.org/packages/2f/ab/fe4db746dbc1bd4a1fa37b6c7bb8ab6568cd1cc2b324f8140015c6cb389e/torchio-0.18.39-py2.py3-none-any.whl (143kB)
[K     |████████████████████████████████| 153kB 8.7MB/s 
[?25hCollecting nilearn
[?25l  Downloading https://files.pythonhosted.org/packages/4a/bd/2ad86e2c00ecfe33b86f9f1f6d81de8e11724e822cdf1f5b2d0c21b787f1/nilearn-0.7.1-py3-none-any.whl (3.0MB)
[K     |████████████████████████████████| 3.1MB 34.5MB/s 
[?25hCollecting Deprecated
  Downloading https://files.pythonhosted.org/packages/fb/73/994edfcba74443146c84b91921fcc269374354118d4f452fb0c54c1cbb12/Deprecated-1.2.12-py2.py3-none-any.whl
Collecting SimpleITK<2
[?25l  Downloading https://files.pythonhosted.org/packages/4a/ee/638b6bae2db10e5ef4ca94c95bb29ec25aa37a9d721b47f91077d7e985e0/SimpleITK-1.2.4-cp37-cp37m-manylinux1_x86_64.whl (42.5MB)
[K     |████████████████████████████████| 42.5MB 118kB/s 
Installing collected packages: Deprecated, Sim

In [7]:
import torchio as tio 

In [8]:
# Import the project modules from the MRI_SOMVAE folder. The working directory
# is always restored afterwards, even when one of the imports raises, so a
# failed import cannot leave the kernel stranded inside MRI_SOMVAE.
_notebook_cwd = os.getcwd()
os.chdir("MRI_SOMVAE")
try:
    from OASISDataset import *
    from FullModels import *
    from TrainingTesting import *
    from ExtraMetrics import *
finally:
    os.chdir(_notebook_cwd)

Split the data into training and validation sets (same random seed as before).

In [9]:
# Name of the image folder handed to the dataset class further down.
folder = "ctrl_original_resolution"

# Read the full training list and hold out 15% of its rows for validation.
# The fixed random_state keeps the split identical between runs.
fulltrainlist = pd.read_csv("oasis_ctrl_training.csv")
trainlist, vallist = train_test_split(
    fulltrainlist,
    test_size=0.15,
    random_state=128,
)

In [10]:
# Build the training and validation datasets. Both printed timings are
# measured from the same starting point, so the second value is cumulative.
start = time.time()

traindataset = OASISDataset_Colab(folder, trainlist)
elapsed_after_train = time.time() - start
print(elapsed_after_train)

valdataset = OASISDataset_Colab(folder, vallist)
elapsed_after_val = time.time() - start
print(elapsed_after_val)

153.760835647583
180.37988877296448


In [12]:
# Mini-batch size used for training.
batch_size = 32

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(traindataset, batch_size=batch_size, shuffle=True)

# The whole validation set is served as a single batch. Shuffling is disabled
# because it only permutes samples inside that one batch; keeping a fixed order
# makes validation deterministic.
val_dataloader = DataLoader(valdataset, batch_size=len(valdataset), shuffle=False)

In [13]:
# Load the pre-trained autoencoder checkpoint. map_location="cpu" lets the
# file be read on a runtime without a GPU even if it was saved from CUDA
# tensors; the model below lives on the CPU at this point anyway.
AEcheckpoint = torch.load("VanillaAE32_Run2_4filters_051921.tar", map_location="cpu")

# Rebuild the architecture with the same hyper-parameters and restore its weights.
RegularAE = VanillaAE(num_channels=1, embedding_dim=32, num_filters=4, batchnorm=True)
# Left as the last expression so the key-matching report is displayed.
RegularAE.load_state_dict(AEcheckpoint["model_state_dict"])

<All keys matched successfully>

Fit a K-means model on the encoder output of one mini-batch.

In [14]:
# Take a single mini-batch from the training loader.
Xin1, Xout1 = next(iter(train_dataloader))

# Only the feature values are needed, so run the encoder without autograd to
# avoid building (and holding in memory) the computation graph.
# NOTE(review): the encoder has not been switched to eval mode here, so any
# batch-norm layers use batch statistics -- confirm that this is intended.
with torch.no_grad():
    Xenc1 = RegularAE.encoder(Xin1)

# Flatten every dimension after the first two, move dim 1 to the end, and
# collapse to a 2-D array with one row per (sample, position) pair and
# Xenc1.shape[1] columns (presumably the embedding dimension).
Xenc1_flat = (
    Xenc1.reshape(Xenc1.shape[0], Xenc1.shape[1], -1)
    .permute(0, 2, 1)
    .reshape(-1, Xenc1.shape[1])
    .detach()
    .numpy()
)

In [16]:
from sklearn.cluster import KMeans

In [17]:
# Fit K-means with 256 clusters on the flattened encoder outputs and time it.
# A fixed random_state makes the centroid initialization reproducible for a
# given input (the sampled mini-batch itself still depends on loader shuffling).
start = time.time()
Kmeans = KMeans(n_clusters=256, random_state=128)
Kmeans.fit(Xenc1_flat)
# Cast to float32 explicitly so the centers keep a fixed dtype regardless of
# the dtype scikit-learn returns.
centers_start = torch.from_numpy(Kmeans.cluster_centers_).float()
print(time.time() - start)

49.576647996902466


In [18]:
# Display the K-means cluster centers as the cell output.
centers_start

tensor([[  2.2105,   0.3225,  -0.7218,  ...,   0.8396, -12.1103,   2.2706],
        [ -3.1668,  -2.4273,  -0.7352,  ...,   1.6748, -13.2872,  -1.2167],
        [  0.2619,  -0.7764,  -0.9722,  ...,   1.5657,  -3.5245,   2.3292],
        ...,
        [  0.6898,   0.1373,  -1.1828,  ...,   1.6635, -13.4390,   1.1649],
        [  1.0317,  -2.0326,   1.0812,  ...,   2.7094,  -4.3716,   3.3821],
        [  0.9331,  -1.7593,   0.2445,  ...,   1.1063,  -5.2692,   1.8137]])

Initialize the SOM-VAE and copy in the encoder and decoder weights from the AE.

In [20]:
# Instantiate the SOM-VAE; num_channels, num_filters, embedding_dim and
# batchnorm are set to the same values used for the pre-trained autoencoder.
SOMVAE = SOMVAE3D(
    num_channels=1,
    num_filters=4,
    embedding_dim=32,
    num_embeddings=256,
    som_h=16,
    som_w=16,
    alpha=6,
    beta=1,
    batchnorm=True,
)

# Swap in independent copies of the pre-trained encoder and decoder so that
# further training does not modify RegularAE.
SOMVAE.encoder = copy.deepcopy(RegularAE.encoder)
SOMVAE.decoder = copy.deepcopy(RegularAE.decoder)

In [21]:
# Inspect the quantization embedding weights before they are replaced by the
# K-means centers.
SOMVAE.quantization._embedding.weight

Parameter containing:
tensor([[-1.1762,  0.9742, -2.1251,  ...,  1.5123, -1.6340,  0.5663],
        [ 0.1179,  0.3514, -0.7043,  ...,  1.6982,  1.3819, -0.5051],
        [ 0.6513,  1.2124,  0.6125,  ..., -0.9884,  1.0337,  0.8469],
        ...,
        [-1.0118,  1.0978, -0.4505,  ...,  1.2020, -1.3886,  0.8329],
        [-0.0281, -2.3820, -0.6618,  ...,  0.3050, -0.2173,  0.8830],
        [ 0.0293,  1.1003, -1.8277,  ...,  0.4341,  0.9280, -0.8262]],
       requires_grad=True)

In [22]:
# Overwrite the quantization embeddings with the K-means centers.
# An in-place copy under no_grad keeps the existing Parameter object (and its
# dtype/device) intact, and the explicit shape check turns a mismatch between
# the number of clusters and the embedding table into an immediate error
# instead of a silently reshaped parameter.
_embedding_weight = SOMVAE.quantization._embedding.weight
assert _embedding_weight.shape == centers_start.shape, (
    "K-means centers have shape {} but the embedding table has shape {}".format(
        tuple(centers_start.shape), tuple(_embedding_weight.shape)
    )
)
with torch.no_grad():
    _embedding_weight.copy_(centers_start)

Train the SOM-VAE

In [24]:
# Adam optimizer over every parameter registered on SOMVAE, learning rate 0.005.
optimizer = optim.Adam(SOMVAE.parameters(),lr=0.005)

# Number of training epochs; also the length of every metric buffer below.
max_epochs = 10

# One slot per epoch, pre-filled with NaN so epochs that have not run yet are
# distinguishable from real values. `np.nan` is used because the `np.NaN`
# alias was removed in NumPy 2.0. Each buffer is an independent array.
train_losses = np.full(max_epochs, np.nan)
val_losses = np.full(max_epochs, np.nan)

train_PSNR = np.full(max_epochs, np.nan)
val_PSNR = np.full(max_epochs, np.nan)

train_SSIM = np.full(max_epochs, np.nan)
val_SSIM = np.full(max_epochs, np.nan)

# Mean-squared-error criterion (averaged over all elements) that is handed to
# the training and test routines.
loss_function = nn.MSELoss(reduction="mean")

In [25]:
# Device selection does not change between epochs, so resolve it once up front
# instead of on every iteration.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

for epoch in range(max_epochs):

    t0 = time.time()

    # One pass over the training data, then the loss on the validation set.
    train_losses[epoch] = train_NewVQVAE(SOMVAE, optimizer, loss_function, train_dataloader,
                                         epoch, log_every_num_batches=5, lam_ze=0.8, lam_zq=0.2)
    val_losses[epoch] = test_NewVQVAE(SOMVAE, loss_function, val_dataloader)

    # Per-batch reconstruction metrics collected for this epoch.
    temp_trainSSIM = []
    temp_trainPSNR = []
    temp_valSSIM = []
    temp_valPSNR = []

    with torch.no_grad():
        # Wrap the model for multi-GPU use at most once. Re-wrapping on every
        # epoch would nest DataParallel inside DataParallel.
        if use_cuda and torch.cuda.device_count() > 1 and not isinstance(SOMVAE, nn.DataParallel):
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
            SOMVAE = nn.DataParallel(SOMVAE)
        SOMVAE.to(device)
        # NOTE(review): the model stays in eval mode after this block; this
        # presumably relies on train_NewVQVAE switching back to train mode.
        SOMVAE.eval()

        # PSNR / SSIM of the "x_out_ze" reconstruction on the training set.
        for Xin_train, Xout_train in train_dataloader:
            Xin_train, Xout_train = Xin_train.to(device), Xout_train.to(device)
            Xrecon_train = SOMVAE(Xin_train)["x_out_ze"]
            train_psnr_batch = PSNR(Xout_train, Xrecon_train)
            temp_trainPSNR.append(train_psnr_batch)
            train_ssim_batch = SSIM(Xout_train, Xrecon_train)
            temp_trainSSIM.append(train_ssim_batch)
        train_PSNR[epoch] = torch.cat(temp_trainPSNR).mean().item()
        train_SSIM[epoch] = torch.cat(temp_trainSSIM).mean().item()

        # Same metrics on the validation set.
        for Xin_val, Xout_val in val_dataloader:
            Xin_val, Xout_val = Xin_val.to(device), Xout_val.to(device)
            Xrecon_val = SOMVAE(Xin_val)["x_out_ze"]
            val_psnr_batch = PSNR(Xout_val, Xrecon_val)
            temp_valPSNR.append(val_psnr_batch)
            val_ssim_batch = SSIM(Xout_val, Xrecon_val)
            temp_valSSIM.append(val_ssim_batch)
        val_PSNR[epoch] = torch.cat(temp_valPSNR).mean().item()
        val_SSIM[epoch] = torch.cat(temp_valSSIM).mean().item()

    dtepoch = time.time() - t0

    print('====> Total time elapsed for this epoch: {:s}'.format(str(datetime.timedelta(seconds=int(dtepoch)))))

====> Begin epoch 1


====> Epoch: 1 Average loss: 9.4711	Time elapsed: 0:00:38
====> Test set loss: 53.4935	Time elapsed: 0:00:08

====> Total time elapsed for this epoch: 0:01:47
====> Begin epoch 2


====> Epoch: 2 Average loss: 8.5510	Time elapsed: 0:00:34
====> Test set loss: 23.4151	Time elapsed: 0:00:05

====> Total time elapsed for this epoch: 0:01:32
====> Begin epoch 3


====> Epoch: 3 Average loss: 8.4165	Time elapsed: 0:00:34
====> Test set loss: 17.3853	Time elapsed: 0:00:05

====> Total time elapsed for this epoch: 0:01:32
====> Begin epoch 4


====> Epoch: 4 Average loss: 8.0770	Time elapsed: 0:00:34
====> Test set loss: 18.2166	Time elapsed: 0:00:05

====> Total time elapsed for this epoch: 0:01:32
====> Begin epoch 5


====> Epoch: 5 Average loss: 7.7814	Time elapsed: 0:00:34
====> Test set loss: 13.0792	Time elapsed: 0:00:05

====> Total time elapsed for this epoch: 0:01:32
====> Begin epoch 6


====> Epoch: 6 Average loss: 7.3682	Time elapsed: 0:00:34
====> Test set 

In [26]:
# Gather the per-epoch metrics into one table (one row per epoch) and display
# it as the cell output. Dictionary order fixes the column order.
metric_columns = {
    "train_losses": train_losses,
    "train_SSIM": train_SSIM,
    "train_PSNR": train_PSNR,
    "val_losses": val_losses,
    "val_SSIM": val_SSIM,
    "val_PSNR": val_PSNR,
}
SOMresult = pd.DataFrame(metric_columns)
SOMresult

Unnamed: 0,train_losses,train_SSIM,train_PSNR,val_losses,val_SSIM,val_PSNR
0,9.471082,0.506771,15.423364,53.493514,0.514116,16.044701
1,8.550968,0.558842,18.464157,23.415148,0.563682,19.119915
2,8.416471,0.538246,14.166626,17.385339,0.547245,14.800964
3,8.077001,0.507532,14.102262,18.216553,0.516963,14.745164
4,7.781376,0.508886,18.19315,13.079228,0.519922,18.882824
5,7.368234,0.572097,18.71356,11.831404,0.581213,19.415426
6,7.405591,0.575397,19.310278,11.068555,0.583951,20.027122
7,7.225915,0.577973,19.299486,10.449367,0.585714,20.005871
8,7.206408,0.594227,19.39216,10.545882,0.601502,20.109522
9,7.125371,0.592208,19.31926,9.989473,0.599571,20.032866
