In this notebook, we train the regular (vanilla) autoencoder model so that its weights and a codebook can be extracted for the K-Means initialization of the VQVAE and SOMVAE models.

In [3]:
from google.colab import drive

# Mount Google Drive at /content/drive; force_remount=True is passed so the
# mount is redone even when this cell is re-executed.
drive.mount(
    "/content/drive",
    force_remount=True,
)

Mounted at /content/drive


In [4]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Wed May 19 18:40:31 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    23W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [5]:
import os

# Use the absolute path under the Drive mount point ("/content/drive", mounted
# above) so this cell is idempotent: the original relative path only resolved
# when the working directory was still the directory containing "drive", and
# failed when the cell was executed a second time.
os.chdir('/content/drive/MyDrive/OASIS_Data')

In [6]:
import copy 
import time
import datetime
import numpy as np 
import pandas as pd 
import nibabel as nib
import sklearn as skl
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [7]:
import torch 
from torch.utils.data import Dataset, DataLoader
from torch import nn 
from torch import optim

In [8]:
pip install torchio nilearn

Collecting torchio
[?25l  Downloading https://files.pythonhosted.org/packages/2f/ab/fe4db746dbc1bd4a1fa37b6c7bb8ab6568cd1cc2b324f8140015c6cb389e/torchio-0.18.39-py2.py3-none-any.whl (143kB)
[K     |██▎                             | 10kB 23.4MB/s eta 0:00:01[K     |████▌                           | 20kB 29.7MB/s eta 0:00:01[K     |██████▉                         | 30kB 23.5MB/s eta 0:00:01[K     |█████████                       | 40kB 17.9MB/s eta 0:00:01[K     |███████████▍                    | 51kB 15.3MB/s eta 0:00:01[K     |█████████████▋                  | 61kB 17.4MB/s eta 0:00:01[K     |████████████████                | 71kB 13.8MB/s eta 0:00:01[K     |██████████████████▏             | 81kB 13.2MB/s eta 0:00:01[K     |████████████████████▌           | 92kB 14.1MB/s eta 0:00:01[K     |██████████████████████▊         | 102kB 14.2MB/s eta 0:00:01[K     |█████████████████████████       | 112kB 14.2MB/s eta 0:00:01[K     |███████████████████████████▎    | 122k

In [9]:
import torchio as tio 

In [10]:
# Import the project-local modules that live in the MRI_SOMVAE sub-directory.
# The working directory is switched temporarily so the modules can be found,
# and is always restored -- even if one of the imports raises -- so later cells
# that use relative data paths are not left running inside MRI_SOMVAE.
# NOTE(review): the star imports hide where names such as OASISDataset_Colab,
# VanillaAE, train, test, PSNR and SSIM come from; consider explicit imports.
_previous_cwd = os.getcwd()
os.chdir("MRI_SOMVAE")
try:
    from OASISDataset import *
    from FullModels import *
    from TrainingTesting import *
    from ExtraMetrics import *
finally:
    os.chdir(_previous_cwd)

Load the training data and split it into training and validation sets (85% / 15%).

In [11]:
# Table describing the full training set, as stored in the CSV.
fulltrainlist = pd.read_csv("oasis_ctrl_training.csv")

# Hold out 15% of the rows for validation; the fixed random_state keeps the
# split identical across runs.
trainlist, vallist = train_test_split(
    fulltrainlist,
    test_size=0.15,
    random_state=128,
)

# Folder name handed to the dataset constructor in the next cell.
folder = "ctrl_original_resolution"

In [12]:
# Build both datasets and report timing. Both printed values are seconds
# elapsed since `start`, so the second figure also includes the time spent
# constructing the training dataset.
start = time.time()

traindataset = OASISDataset_Colab(folder, trainlist)
print(time.time() - start)

valdataset = OASISDataset_Colab(folder, vallist)
print(time.time() - start)

210.52429866790771
247.50104451179504


In [13]:
batch_size = 32

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(traindataset, batch_size=batch_size, shuffle=True)

# The validation set is evaluated as one single batch, so shuffling it adds
# nothing; shuffle=False keeps the evaluation order deterministic.
val_dataloader = DataLoader(valdataset, batch_size=len(valdataset), shuffle=False)

Set up model training for VanillaAE: Adam optimizer with a learning rate of 0.005, an embedding dimension of 32, and 200 epochs of training.

In [15]:
# Model and optimizer.
model = VanillaAE(num_channels=1, embedding_dim=32, num_filters=4, batchnorm=True)

optimizer = optim.Adam(model.parameters(), lr=0.005)

# Per-epoch history buffers. They start out as NaN so that epochs which were
# never reached remain distinguishable from real values.
# np.nan is used because the np.NaN alias was removed in NumPy 2.0.
max_epochs = 200
train_losses = np.full(max_epochs, np.nan)
val_losses = np.full(max_epochs, np.nan)

train_PSNR = train_losses.copy()
val_PSNR = val_losses.copy()

train_SSIM = train_losses.copy()
val_SSIM = val_losses.copy()

# Reconstruction loss: mean squared error.
loss_function = nn.MSELoss(reduction="mean")

In [16]:
def _mean_psnr_ssim(model, dataloader, device):
    """Return ``(mean PSNR, mean SSIM)`` of the model's reconstructions.

    Every batch of ``dataloader`` yields an ``(input, target)`` pair. The input
    is passed through ``model`` and the ``"x_out"`` entry of the result is
    compared against the target with ``PSNR`` and ``SSIM``. The per-batch
    results are concatenated and averaged.

    The caller is responsible for putting the model in eval mode and for
    disabling gradient tracking.
    """
    psnr_batches = []
    ssim_batches = []
    for x_in, x_target in dataloader:
        x_in, x_target = x_in.to(device), x_target.to(device)
        x_recon = model(x_in)["x_out"]
        psnr_batches.append(PSNR(x_target, x_recon))
        ssim_batches.append(SSIM(x_target, x_recon))
    mean_psnr = torch.cat(psnr_batches).mean().item()
    mean_ssim = torch.cat(ssim_batches).mean().item()
    return mean_psnr, mean_ssim


# Device setup does not depend on the epoch, so it is done once up front.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True
# Wrap the model for multi-GPU use exactly once. Doing this inside the epoch
# loop would nest one more nn.DataParallel around the model every epoch.
if use_cuda and torch.cuda.device_count() > 1 and not isinstance(model, nn.DataParallel):
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = nn.DataParallel(model)
model.to(device)

for epoch in range(max_epochs):

    t0 = time.time()

    train_losses[epoch] = train(model, optimizer, loss_function, train_dataloader,
                                epoch, log_every_num_batches=5)

    val_losses[epoch] = test(model, loss_function, val_dataloader)

    # Extra reconstruction metrics on both splits, computed without gradients.
    # NOTE(review): the model is left in eval mode afterwards; this assumes
    # train() switches it back to training mode itself -- confirm.
    with torch.no_grad():
        model.eval()
        train_PSNR[epoch], train_SSIM[epoch] = _mean_psnr_ssim(model, train_dataloader, device)
        val_PSNR[epoch], val_SSIM[epoch] = _mean_psnr_ssim(model, val_dataloader, device)

    dtepoch = time.time() - t0

    print('====> Total time elapsed for this epoch: {:s}'.format(str(datetime.timedelta(seconds=int(dtepoch)))))

====> Begin epoch 1


====> Epoch: 1 Average loss: 0.0283	Time elapsed: 0:00:27
====> Test set loss: 0.1403	Time elapsed: 0:00:05

====> Total time elapsed for this epoch: 0:01:18
====> Begin epoch 2


====> Epoch: 2 Average loss: 0.0202	Time elapsed: 0:00:25
====> Test set loss: 0.1305	Time elapsed: 0:00:04

====> Total time elapsed for this epoch: 0:01:04
====> Begin epoch 3


====> Epoch: 3 Average loss: 0.0123	Time elapsed: 0:00:25
====> Test set loss: 0.0336	Time elapsed: 0:00:04

====> Total time elapsed for this epoch: 0:01:03
====> Begin epoch 4


====> Epoch: 4 Average loss: 0.0080	Time elapsed: 0:00:23
====> Test set loss: 0.0398	Time elapsed: 0:00:04

====> Total time elapsed for this epoch: 0:01:00
====> Begin epoch 5


====> Epoch: 5 Average loss: 0.0069	Time elapsed: 0:00:23
====> Test set loss: 0.1274	Time elapsed: 0:00:03

====> Total time elapsed for this epoch: 0:00:59
====> Begin epoch 6


====> Epoch: 6 Average loss: 0.0067	Time elapsed: 0:00:23
====> Test set loss:

In [17]:
# Collect the per-epoch histories into one table (one row per epoch).
VanillaAE_result = pd.DataFrame(
    {
        "train_losses": train_losses,
        "train_SSIM": train_SSIM,
        "train_PSNR": train_PSNR,
        "val_losses": val_losses,
        "val_SSIM": val_SSIM,
        "val_PSNR": val_PSNR,
    }
)

In [18]:
# Display the per-epoch history table (row index = epoch, starting at 0).
VanillaAE_result

Unnamed: 0,train_losses,train_SSIM,train_PSNR,val_losses,val_SSIM,val_PSNR
0,0.028312,-0.031538,5.394518,0.140317,-0.034275,5.999938
1,0.020159,-0.006451,5.711125,0.130504,-0.007984,6.315090
2,0.012323,0.163774,11.559094,0.033629,0.167131,12.203369
3,0.008030,0.243502,10.834190,0.039814,0.247529,11.470244
4,0.006947,0.275492,5.787542,0.127389,0.281299,6.423850
...,...,...,...,...,...,...
195,0.001164,0.862432,24.787197,0.001612,0.863507,25.416536
196,0.001159,0.861796,24.703379,0.001643,0.862657,25.330702
197,0.001171,0.857618,24.681339,0.001651,0.858556,25.310341
198,0.001166,0.862414,24.805019,0.001605,0.863717,25.433775


In [19]:
# Write the history table to disk; the row index (the epoch number) is written
# as the first, unnamed column.
VanillaAE_result.to_csv(
    "VanillaAE32_Run2_4filters_051921.csv",
    index=True,
)

In [28]:
# Save the bare model's weights. If the model was wrapped in nn.DataParallel
# (possibly more than once), its state_dict keys carry "module." prefixes and
# would not load into a plain VanillaAE, so every wrapper is peeled off first.
model_to_save = model
while isinstance(model_to_save, nn.DataParallel):
    model_to_save = model_to_save.module

# Checkpoint contents: last epoch index, model weights, optimizer state and the
# per-epoch training history.
torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model_to_save.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "history": VanillaAE_result,
    },
    "VanillaAE32_Run2_4filters_051921.tar",
)

0.001150130497990176