<a href="https://colab.research.google.com/github/tbng/deep-fmri/blob/master/notebooks/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

from google.colab import drive
from pathlib import Path



drive.mount('/content/gdrive', force_remount=True)

Mounted at /content/gdrive


In [0]:
HCP_FOLDER = Path("/content/gdrive/My Drive/HCP_900")
MASKED_DATA_FILES = sorted(HCP_FOLDER.glob("*REST*.npy"))
MASKED_DATA_FILES

[PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST1_LR_0.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST1_LR_1.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST1_LR_2.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST1_LR_3.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST1_RL_0.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST1_RL_1.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST1_RL_2.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST1_RL_3.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST2_LR_0.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST2_LR_1.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST2_LR_2.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST2_LR_3.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST2_RL_0.npy'),
 PosixPath('/content/gdrive/My Drive/HCP_900/100307_REST2_RL_1.npy'),
 PosixPath('/content

In [0]:
import glob

import numpy as np
import torch
from os.path import expanduser, join
from skimage.morphology import dilation, binary_dilation
from torch.utils.data import Dataset, ConcatDataset


class NumpyDataset(Dataset):
    def __init__(self, filename):
        self.filename = filename

    def __len__(self):
        return np.load(self.filename, mmap_mode='r').shape[3]

    def __getitem__(self, index):
        data = np.load(self.filename, mmap_mode='r')
        return torch.Tensor(data[None, :, :, :, index])


class NumpyDatasetMem(Dataset):
    def __init__(self, filename):
        self.data = torch.Tensor(np.load(filename, mmap_mode=None))

    def __len__(self):
        return self.data.shape[3]

    def __getitem__(self, index):
        return self.data[None, :, :, :, index]


def get_dataset(subject=100307, data_dir=None, in_memory=True):
    if in_memory:
        dataset_type = NumpyDatasetMem
    else:
        dataset_type = NumpyDataset
    if data_dir is None:
        data_dir = HCP_FOLDER
    datasets = [dataset_type(str(fp)) for fp in MASKED_DATA_FILES[:3]]
    train_dataset = ConcatDataset(datasets[:-1])
    test_dataset = datasets[-1]
    mask = np.load(data_dir / ('%s_mask.npy' % subject))
    mask = mask.astype('bool')
    print('Mask', mask.astype('float').sum(), 'voxels')
    for i in range(2):
        mask = binary_dilation(mask)
    mask = mask.astype('uint8')
    print('Dilated mask', mask.astype('float').sum(), 'voxels')
    mask = torch.from_numpy(mask).byte()

    return train_dataset, test_dataset, mask


In [0]:
import torch

import torch.nn as nn
import torch.nn.functional as F
from torch.testing import randn_like


class Encoder(nn.Module):
    def __init__(self, embedding_size=128):
        super().__init__()

        self.pad = nn.ConstantPad3d((2, 3, 9, 10, 2, 3), 0)
        self.conv = nn.Sequential(
            nn.Conv3d(1, 16, 3, 2, 1),
            nn.ReLU(),
            nn.Conv3d(16, 16, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm3d(16),

            nn.Conv3d(16, 32, 3, 2, 1),
            nn.ReLU(),
            nn.Conv3d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm3d(32),

            nn.Conv3d(32, 64, 3, 2, 1),
            nn.ReLU(),
            nn.Conv3d(64, 64, 3, 1, 1),
            nn.BatchNorm3d(64),
            nn.ReLU(),

            nn.Conv3d(64, 128, 3, 2, 1),
            nn.ReLU(),
            nn.Conv3d(128, 128, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm3d(128),

            nn.Conv3d(128, 256, 3, 2, 1),
            nn.ReLU(),
            nn.Conv3d(256, 256, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm3d(256),

        )

        self.dense = nn.Linear(256, embedding_size)
        self.dense_var = nn.Linear(256, embedding_size)

    def forward(self, img):
        batch_size = img.shape[0]
        img = self.pad(img)
        conv_img = self.conv(img)
        avg_channel = conv_img.view(batch_size, 256, -1).mean(dim=2)
        # avg_channel = F.dropout(avg_channel, p=0.1)
        mean = self.dense(avg_channel)
        log_var = self.dense_var(avg_channel)
        return mean, log_var


class Decoder(nn.Module):
    def __init__(self, embedding_size=128):
        super().__init__()

        self.dense = nn.Linear(embedding_size, 256)

        self.deconv = nn.Sequential(
            nn.ConvTranspose3d(256, 256, 3, 1, 1, 0),
            nn.ReLU(),
            nn.ConvTranspose3d(256, 128, 3, 2, 1, 1),
            nn.ReLU(),
            nn.BatchNorm3d(128),

            nn.ConvTranspose3d(128, 128, 3, 1, 1, 0),
            nn.ReLU(),
            nn.ConvTranspose3d(128, 64, 3, 2, 1, 1),
            nn.ReLU(),
            nn.BatchNorm3d(64),

            nn.ConvTranspose3d(64, 64, 3, 1, 1, 0),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 32, 3, 2, 1, 1),
            nn.ReLU(),
            nn.BatchNorm3d(32),

            nn.ConvTranspose3d(32, 32, 3, 1, 1, 0),
            nn.ReLU(),
            nn.ConvTranspose3d(32, 16, 3, 2, 1, 1),
            nn.ReLU(),
            nn.BatchNorm3d(16),

            nn.ConvTranspose3d(16, 16, 3, 1, 1, 0),
            nn.ReLU(),
            nn.ConvTranspose3d(16, 1, 3, 2, 1, 1)
        )

    def forward(self, latent):
        batch_size = latent.shape[0]
        avg_channel = self.dense(latent)
        # avg_channel = F.dropout(avg_channel, p=0.1)
        avg_channel = avg_channel[:, :, None,
                      None, None].expand(batch_size, 256, 3, 4, 3) * 1
        rec = self.deconv(avg_channel)

        # self.pad = nn.ConstantPad3d((2, 3, 9, 10, 2, 3), 0)
        rec = rec[:, :, 2:-3, 9:-10, 2:-3]
        return rec


class VAE(nn.Module):
    def __init__(self, embedding_dim=128):
        super().__init__()
        self.encoder = Encoder(embedding_dim)
        self.decoder = Decoder(embedding_dim)

    def forward(self, img):
        mean, log_var = self.encoder(img)
        penalty = gaussian_kl(mean, log_var)
        if self.training:
            eps = randn_like(mean)
            latent = mean + torch.exp(log_var / 2) * eps
        else:
            latent = mean
        return self.decoder(latent), penalty


def gaussian_kl(mean, log_var):
    return .5 * torch.sum(mean ** 2 + torch.exp(log_var)
                          - log_var - 1) / mean.shape[0]


def masked_mse(pred, target, mask):
    diff = pred - target
    mask = mask ^ 1
    mask = mask[None, None, ...]
    diff.masked_fill_(mask, 0.)
    return torch.sum(diff ** 2) / diff.shape[0]


In [19]:
import functools
import math

import torch
from os.path import expanduser
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchsummary import summary


batch_size = 48
in_memory = True
alpha = 10
residual = False

train_dataset, test_dataset, mask = get_dataset(in_memory=in_memory)

train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                         shuffle=False)
model = VAE()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
mask = mask.to(device)

summary(model, (1, 91, 109, 91))

loss_function = functools.partial(masked_mse, mask=mask)

optimizer = Adam(model.parameters(), lr=1e-3, amsgrad=True)

n_epochs = 100
total_loss = 0

n_batch = math.ceil(len(train_dataset) / batch_size)

mean = torch.zeros_like(train_dataset[0])
# Compute mean
if residual:
    length = 0
    for this_data in train_loader:
        length += this_data.shape[0]
        mean += this_data.sum(dim=0)
    mean /= length
mean = mean.to(device)

for epoch in range(n_epochs):
    epoch_batch = 0
    verbose_loss = 0
    verbose_penalty = 0
    verbose_batch = 0
    for this_data in train_loader:
        model.train()
        model.zero_grad()
        this_data[this_data >= 1] = 1
        this_data = this_data.to(device)
        this_data -= mean[None, ...]
        rec, penalty = model(this_data)
        penalty *= alpha
        loss = loss_function(rec, this_data)
        elbo = loss + penalty
        elbo.backward()
        optimizer.step()
        verbose_loss += loss.item()
        verbose_penalty += penalty.item()
        epoch_batch += 1
        verbose_batch += 1
        print('Epoch %i, batch %i/%i,'
                  'train_objective: %4e,'
                  'train_penalty: %4e,' % (epoch, epoch_batch, n_batch,
                                           verbose_loss / verbose_batch, 
                                           verbose_penalty / verbose_batch,))
        if epoch_batch % 10 == 0:
            model.eval()
            with torch.no_grad():
                val_batch = 0
                val_loss = 0
                val_penalty = 0
                for this_test_data in test_loader:
                    this_test_data = this_test_data.to(device)
                    this_test_data -= mean[None, ...]
                    rec, this_val_penalty = model(this_test_data)
                    this_val_penalty *= alpha
                    this_val_loss = loss_function(rec, this_test_data)
                    val_loss += this_val_loss.item()
                    val_penalty += this_val_penalty.item()
                    val_batch += 1
            val_loss /= val_batch
            val_penalty /= val_batch
            verbose_loss /= verbose_batch
            verbose_penalty /= verbose_batch
            print('Epoch %i, batch %i/%i,'
                  'train_objective: %4e,'
                  'train_penalty: %4e,'
                  'val_objective: %4e,'
                  'val_penalty: %4e' % (epoch, epoch_batch, n_batch,
                                        verbose_loss, verbose_penalty,
                                        val_loss, val_penalty))
            verbose_batch = 0
            train_loss = 0
            penalty = 0
    state_dict = model.state_dict()

    name = 'vae_dilated_e_%03i_loss_%.4e.pkl' % (epoch, verbose_loss)
    torch.save((state_dict, mean), name)


Mask 229399.0 voxels
Dilated mask 270791.0 voxels
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
     ConstantPad3d-1       [-1, 1, 96, 128, 96]               0
            Conv3d-2       [-1, 16, 48, 64, 48]             448
              ReLU-3       [-1, 16, 48, 64, 48]               0
            Conv3d-4       [-1, 16, 48, 64, 48]           6,928
              ReLU-5       [-1, 16, 48, 64, 48]               0
       BatchNorm3d-6       [-1, 16, 48, 64, 48]              32
            Conv3d-7       [-1, 32, 24, 32, 24]          13,856
              ReLU-8       [-1, 32, 24, 32, 24]               0
            Conv3d-9       [-1, 32, 24, 32, 24]          27,680
             ReLU-10       [-1, 32, 24, 32, 24]               0
      BatchNorm3d-11       [-1, 32, 24, 32, 24]              64
           Conv3d-12       [-1, 64, 12, 16, 12]          55,360
             ReLU-13       [-1, 64, 12, 16, 12]      

In [20]:
ls

[0m[01;34mdata[0m/                                  vae_dilated_e_047_loss_7.4994e+03.pkl
[01;34mgdrive[0m/                                vae_dilated_e_048_loss_7.4121e+03.pkl
[01;34msample_data[0m/                           vae_dilated_e_049_loss_7.4530e+03.pkl
vae_dilated_e_000_loss_1.7390e+05.pkl  vae_dilated_e_050_loss_7.2876e+03.pkl
vae_dilated_e_000_loss_2.1543e+05.pkl  vae_dilated_e_051_loss_7.2472e+03.pkl
vae_dilated_e_000_loss_2.4455e+05.pkl  vae_dilated_e_052_loss_7.1110e+03.pkl
vae_dilated_e_000_loss_6.7717e+05.pkl  vae_dilated_e_053_loss_7.2670e+03.pkl
vae_dilated_e_001_loss_1.1157e+05.pkl  vae_dilated_e_054_loss_7.0138e+03.pkl
vae_dilated_e_001_loss_2.9837e+05.pkl  vae_dilated_e_055_loss_6.9026e+03.pkl
vae_dilated_e_002_loss_5.6187e+04.pkl  vae_dilated_e_056_loss_6.9308e+03.pkl
vae_dilated_e_003_loss_3.7112e+04.pkl  vae_dilated_e_057_loss_6.9630e+03.pkl
vae_dilated_e_004_loss_3.0005e+04.pkl  vae_dilated_e_058_loss_6.7679e+03.pkl
vae_dilated_e_005_loss_2.7753e+04.pk