In [1]:
import os
import matplotlib.pyplot as plt
import numpy as np
import random
import pandas as pd 
from datetime import datetime
from itertools import product
from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import imageio
import warnings
warnings.filterwarnings("ignore")

In [2]:
## Visualization functions
def show_slices(slices):
    """Display 2-D image slices side by side in a single row.

    Each slice is transposed and drawn in grayscale with ``origin="lower"``.
    The figure is not shown here; the caller (or the notebook's inline backend)
    displays it.

    Args:
        slices: sequence (or leading-axis array) of 2-D arrays.

    Raises:
        ValueError: if ``slices`` is empty.
    """
    slices = list(slices)
    if not slices:
        raise ValueError("show_slices needs at least one slice")
    # squeeze=False keeps `axes` 2-D even for a single slice; with the default,
    # plt.subplots(1, 1) returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, len(slices), squeeze=False)
    for ax, image_slice in zip(axes[0], slices):
        ax.imshow(image_slice.T, cmap="gray", origin="lower")

def print_img(img):
    """Show the middle slice along each of the first three axes of a volume, then display."""
    centre = [extent // 2 for extent in img.shape[:3]]
    show_slices([
        img[centre[0], :, :],
        img[:, centre[1], :],
        img[:, :, centre[2]],
    ])
    plt.show()

def print_coronal(img, title=""):
    """Plot the middle slice along axis 1 of a 3-D volume in a small titled figure."""
    coronal = img[:, img.shape[1] // 2, :]

    fig, ax = plt.subplots(figsize=(3, 3))
    ax.imshow(coronal.T, cmap="gray", origin="lower")
    fig.suptitle(title)
    plt.show()

# load data

In [1]:
df = pd.read_csv("samples.csv", index_col=0)

## Standardize age1 and timeDiff to zero mean / unit std. The raw statistics are
## kept as age_mean/age_std and timeDiff_mean/timeDiff_std so later cells can
## transform new inputs the same way.
norm_stats = {}
for column in ('age1', 'timeDiff'):
    col_mean, col_std = df[column].mean(), df[column].std()
    df[column] = (df[column] - col_mean) / col_std
    print(df[column].mean(), df[column].std())  # sanity check: ~0 and ~1
    norm_stats[column] = (col_mean, col_std)

age_mean, age_std = norm_stats['age1']
timeDiff_mean, timeDiff_std = norm_stats['timeDiff']

# Rows as a NumPy array; BrainDataset reads each row positionally.
pair_imgs = df.values
pair_imgs.shape

# VAE model
### Data loader: each image input is the 2-channel stack [img2, img1]; the scalar input is [age1, timeDiff, status]

In [5]:
# build data loader
img_path = "cropped_imgs"


class BrainDataset(Dataset):
    """Paired-scan dataset for the conditional VAE.

    Each row of ``files`` is read positionally:
    ``row[0]`` baseline image file (img1), ``row[1]`` age1,
    ``row[2]`` follow-up image file (img2), ``row[3]`` status, ``row[4]`` timeDiff.

    ``__getitem__`` returns ``(paired_inputs, scalar_vars)``:
      * paired_inputs: float32 tensor stacking [img2, img1] along dim 0
        (so batches are [batch_size, 2, D, H, W]);
      * scalar_vars: float32 tensor [age1, timeDiff, status].

    Args:
        files: indexable collection of rows laid out as above.
        root: directory the image file names are resolved against. ``None``
            (default) uses the module-level ``img_path``.
    """

    def __init__(self, files, root=None):
        self.files = files
        self.n_files = len(self.files)
        self.root = root

    def __len__(self):
        return self.n_files

    def _load(self, name):
        # Load one .npy volume as float32.
        root = img_path if self.root is None else self.root
        # os.path.join supplies the separator that plain `img_path + name` omitted.
        # A leading "/" is stripped so names stored as "/x.npy" still resolve under root.
        path = os.path.join(root, str(name).lstrip("/"))
        return torch.Tensor(np.load(path))

    def __getitem__(self, idx):
        row = self.files[idx]

        image1 = self._load(row[0])
        image2 = self._load(row[2])
        age1, status, timediff = row[1], row[3], row[4]

        # Squeeze each volume before stacking so files saved as (D,H,W) or (1,D,H,W)
        # both give a (2, D, H, W) input. Concatenating unsqueezed (D,H,W) volumes
        # would instead produce (2*D, H, W).
        paired_inputs = torch.stack((image2.squeeze(), image1.squeeze()), 0)
        scalar_vars = torch.tensor([age1, timediff, status], dtype=torch.float32)
        return paired_inputs.to(torch.float32), scalar_vars
            
b_size = 32

# Hold out a fixed, seeded fraction of pairs for evaluation / early stopping.
# The training loop reads `test_dataloader`, which previously existed only as
# commented-out code.
# TODO(review): rows are split independently, so one subject's scans can land in
# both sets; split by subject ID instead if that information is available.
TEST_FRACTION = 0.2
SPLIT_SEED = 0

_perm = np.random.default_rng(SPLIT_SEED).permutation(len(pair_imgs))
_n_test = int(round(TEST_FRACTION * len(pair_imgs)))
test_pair_imgs = pair_imgs[_perm[:_n_test]]
train_pair_imgs = pair_imgs[_perm[_n_test:]]

dataset = BrainDataset(train_pair_imgs)
dataloader = DataLoader(dataset, batch_size=b_size, shuffle=True, num_workers=8)

test_dataset = BrainDataset(test_pair_imgs)
test_dataloader = DataLoader(test_dataset, batch_size=b_size, shuffle=False, num_workers=8)

### Model architecture

In [9]:
# misc
def parameter_count(model):
    """Return the total number of elements in `model`'s trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# helper block function
class Conv(nn.Module):
    """Scalar-conditioned 3-D conv block: Conv3d -> + Linear(sca) per channel -> GroupNorm(4) -> LeakyReLU.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded to nn.Conv3d.
        scalar_size: length of the conditioning vector ``sca`` (default 3).

    forward(x, sca): ``sca`` is presumably (N, scalar_size). Its linear projection
    is broadcast over the three trailing spatial dims of the conv output, which
    shifts each output channel by a learned, condition-dependent amount.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, scalar_size=3):
        super().__init__()
        # Attribute names form the state_dict keys of saved checkpoints; keep them stable.
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, bias=True)
        self.linear = nn.Linear(scalar_size, out_channels)
        self.norm = nn.GroupNorm(4, out_channels)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, x, sca):
        features = self.conv(x)
        # (N, C) -> (N, C, 1, 1, 1) so the projection is added uniformly over space
        shift = self.linear(sca)[..., None, None, None]
        return self.relu(self.norm(features + shift))
    
class ConvTranspose(nn.Module):
    """Learned 3-D upsampling block: ConvTranspose3d (no bias) -> GroupNorm(4) -> LeakyReLU.

    Args:
        in_channels, out_channels, kernel_size, stride: forwarded to nn.ConvTranspose3d.
        pad: ``padding`` of the transposed convolution.
        out_pad: ``output_padding``; an immutable tuple default replaces the shared
            mutable list default (lists are still accepted).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, pad, out_pad=(0, 0, 0)):
        super().__init__()
        # Kept as one nn.Sequential named `convTran` so state_dict keys
        # (convTran.0.weight, convTran.1.weight, ...) match existing checkpoints.
        self.convTran = nn.Sequential(
            nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride, padding=pad,
                               output_padding=out_pad, bias=False),
            nn.GroupNorm(4, out_channels),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        return self.convTran(x)

In [10]:
class CVAE(nn.Module):
    """Conditional VAE that generates a follow-up scan (img2) from a baseline scan (img1).

    Encoder:
      * Input is the 2-channel stack x = [img2, img1] plus the scalar conditions.
      * Four stages of two conditioned convs each, with three 2x max-pools between them.
      * Two linear layers then produce 2*latent_dim numbers, split into mu and log sigma^2.

    Decoder:
      * A U-Net over img1 alone.
      * Every conv block is conditioned on cat(z, scalar).
      * Returns a 1-channel volume with img1's spatial size.

    Args:
        latent_dim: size of z (default 10).
        scalar_size: number of scalar conditions (default 3; the data loader
            supplies [age1, timeDiff, status]).
        base: channel width of the first stage; later stages use 2x, 4x, 8x.
        img_size: edge length of the cubic input volume; must be divisible by 8.

    The defaults reproduce the original hard-coded architecture exactly, so saved
    ``CVAE.state_dict`` checkpoints still load.
    """

    def __init__(self, latent_dim=10, scalar_size=3, base=64, img_size=80):
        super().__init__()
        if img_size % 8 != 0:
            raise ValueError(
                "img_size must be divisible by 8 (three 2x down-samplings), got {}".format(img_size))
        self.latent_dim = latent_dim

        kernel_size = [3, 3, 3]
        stride_size = [1, 1, 1]
        padding_size = [1, 1, 1]

        maxpool_kernel_size = [2, 2, 2]
        maxpool_stride_size = [2, 2, 2]

        convTrans_kernel = [2, 2, 2]
        convTrans_stride = [2, 2, 2]

        # Decoder blocks are conditioned on cat(z, scalar): 13 values with the defaults.
        dec_scalar = latent_dim + scalar_size
        # Spatial edge length after the three 2x max-pools (10 for 80^3 inputs).
        bottleneck = img_size // 8

        def conv(cin, cout, sca_size):
            # Conditioned 3x3x3 convolution; stride 1 / padding 1 keeps the spatial size.
            return Conv(cin, cout, kernel_size, stride=stride_size, padding=padding_size,
                        scalar_size=sca_size)

        def pool():
            return nn.MaxPool3d(kernel_size=maxpool_kernel_size, stride=maxpool_stride_size)

        def up(cin, cout):
            # Learned 2x upsampling.
            return ConvTranspose(cin, cout, convTrans_kernel, stride=convTrans_stride,
                                 pad=[0, 0, 0])

        ### ENCODER: 4 blocks of convolutions + 3 downsampling
        self.e1 = conv(2, base, scalar_size)
        self.e2 = conv(base, base, scalar_size)
        self.p1 = pool()

        self.e3 = conv(base, 2 * base, scalar_size)
        self.e4 = conv(2 * base, 2 * base, scalar_size)
        self.p2 = pool()

        self.e5 = conv(2 * base, 4 * base, scalar_size)
        self.e6 = conv(4 * base, 4 * base, scalar_size)
        self.p3 = pool()

        self.e7 = conv(4 * base, 8 * base, scalar_size)
        self.e8 = conv(8 * base, 8 * base, scalar_size)

        # NOTE(review): there is no activation between these two layers, so together
        # they act as a single linear map of rank <= 200.
        self.first_linear = nn.Linear(bottleneck ** 3 * 8 * base, 200)
        self.last_linear = nn.Linear(200, 2 * latent_dim)

        ## U-Net decoder, contracting path: sees only img1 plus the conditionals
        self.u_down1 = conv(1, base, dec_scalar)
        self.u_down2 = conv(base, base, dec_scalar)
        self.u_pool1 = pool()

        self.u_down3 = conv(base, 2 * base, dec_scalar)
        self.u_down4 = conv(2 * base, 2 * base, dec_scalar)
        self.u_pool2 = pool()

        self.u_down5 = conv(2 * base, 4 * base, dec_scalar)
        self.u_down6 = conv(4 * base, 4 * base, dec_scalar)
        self.u_pool3 = pool()

        ## bottleneck; the latent code enters through the dec_scalar conditioning
        self.u_down7 = conv(4 * base, 8 * base, dec_scalar)
        self.u_down8 = conv(8 * base, 8 * base, dec_scalar)

        ## U-Net decoder, expanding path; d1/d3/d5 take upsampled features concatenated with a skip
        self.t1 = up(8 * base, 4 * base)
        self.d1 = conv(8 * base, 4 * base, dec_scalar)
        self.d2 = conv(4 * base, 4 * base, dec_scalar)

        self.t2 = up(4 * base, 2 * base)
        self.d3 = conv(4 * base, 2 * base, dec_scalar)
        self.d4 = conv(2 * base, 2 * base, dec_scalar)

        self.t3 = up(2 * base, base)
        self.d5 = conv(2 * base, base, dec_scalar)
        self.d6 = conv(base, base, dec_scalar)

        self.last_conv = nn.Conv3d(base, 1, [1, 1, 1], [1, 1, 1], bias=True)

    def encode(self, x, scalar):
        """Map x (N, 2, S, S, S) and scalar (N, scalar_size) to a (N, 2*latent_dim) tensor."""
        x = self.e1(x, scalar)
        x = self.e2(x, scalar)
        x = self.p1(x)

        x = self.e3(x, scalar)
        x = self.e4(x, scalar)
        x = self.p2(x)

        x = self.e5(x, scalar)
        x = self.e6(x, scalar)
        x = self.p3(x)

        x = self.e7(x, scalar)
        x = self.e8(x, scalar)

        x = self.first_linear(x.view(x.size(0), -1))
        return self.last_linear(x)

    def to_mu_sigma(self, out):
        """Split encoder output into (mu, sigma); the second half is log sigma^2."""
        mu = out[:, :self.latent_dim]
        logsigma2 = out[:, self.latent_dim:]

        sigma = torch.exp(0.5 * logsigma2)
        return mu, sigma

    def reparameterize(self, mu, sigma):
        """Return z = mu + eps * sigma with eps ~ N(0, I).

        eps is drawn with ``randn_like(mu)``, i.e. on mu's device and dtype. The
        earlier version relied on a global ``device`` defined in a later cell.
        """
        eps = torch.randn_like(mu)
        return mu + eps * sigma

    def decode(self, z, scalar, img1):
        """Generate a 1-channel volume from latent z, scalar conditions and baseline img1.

        img1 is expected as (N, 1, S, S, S); z is (N, latent_dim); scalar is (N, scalar_size).
        """
        latent_scalar = torch.cat((z, scalar), dim=1)

        ## U-Net structure, contracting path over img1 only
        out = self.u_down1(img1, latent_scalar)
        skip1 = self.u_down2(out, latent_scalar)
        out = self.u_pool1(skip1)

        out = self.u_down3(out, latent_scalar)
        skip2 = self.u_down4(out, latent_scalar)
        out = self.u_pool2(skip2)

        out = self.u_down5(out, latent_scalar)
        skip3 = self.u_down6(out, latent_scalar)
        out = self.u_pool3(skip3)

        out = self.u_down7(out, latent_scalar)
        out = self.u_down8(out, latent_scalar)

        ## U-Net structure, expanding path with skip connections
        out = self.t1(out)
        out = self.d1(torch.cat([out, skip3], 1), latent_scalar)
        out = self.d2(out, latent_scalar)

        out = self.t2(out)
        out = self.d3(torch.cat([out, skip2], 1), latent_scalar)
        out = self.d4(out, latent_scalar)

        out = self.t3(out)
        out = self.d5(torch.cat([out, skip1], 1), latent_scalar)
        out = self.d6(out, latent_scalar)

        return self.last_conv(out)

    def forward(self, x, scalar, img1):
        """Encode, sample z, decode. Returns (reconstruction, mu, sigma)."""
        ## out is size (batch, 2*latent_dim)
        out = self.encode(x, scalar)

        mu, sigma = self.to_mu_sigma(out)
        z = self.reparameterize(mu, sigma)

        out = self.decode(z, scalar, img1)
        return out, mu, sigma

# Instantiate a throwaway model only to report how many trainable parameters it has.
n_trainable_params = parameter_count(CVAE())
print('parameter count:', n_trainable_params)

parameter count: 138905597


### train

In [12]:
init_lr = 1e-5
num_epochs = 1000

# for early stopping: start at +inf so the first epoch always counts as an
# improvement. A fixed 999999999 can be smaller than a summed-SSE loss, in which
# case no checkpoint would ever be saved.
min_loss = float('inf')
patience = 10
pat_count = 0

# Seed the default torch RNG: reproducible weight init and DataLoader shuffle order
# (GPU kernels may still be non-deterministic).
torch.manual_seed(0)

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# model: `.to(device)` only. The additional `.cuda()` call crashed on CPU-only machines.
cvae = CVAE().to(device)

# optimizer and loss function
optimizer = optim.Adam(cvae.parameters(), lr=init_lr)  # , weight_decay=0.001)
l2_loss = nn.MSELoss(reduction='sum')

def kld_loss(mu, sigma2):
    """KL( N(mu, diag(sigma2)) || N(0, I) ), averaged over every batch element and latent dim.

    NOTE(review): `criterion` *sums* the reconstruction error over batch and voxels,
    but this term is a mean. Relative to the standard ELBO, the KL is therefore
    down-weighted by a factor of batch_size * latent_dim. Confirm this is intended.
    """
    per_element = mu.pow(2) + sigma2 - 1 - torch.log(sigma2)
    return 0.5 * per_element.mean()

def criterion(img2_batch, recon, mu, sigma2):  # , kl_w):
    """Batch loss = reconstruction term + KL term.

    The reconstruction term is the summed squared error (module-level `l2_loss`)
    divided by 2 * 0.1**2. That matches a Gaussian likelihood with fixed std 0.1,
    up to a constant. The KL term comes from `kld_loss`.

    Returns:
        (total loss tensor for backprop, reconstruction loss as float, KL loss as float)
    """
    recon_term = l2_loss(recon, img2_batch) / (2 * 0.1 * 0.1)
    kl_term = kld_loss(mu, sigma2)
    total = recon_term + kl_term
    return total, recon_term.item(), kl_term.item()
    
# statistics: per-epoch loss histories, appended to by the training loop
train_losses, test_losses = [], []
recon_loss_arr, kl_loss_arr = [], []
test_recon_loss_arr, test_kl_loss_arr = [], []

cuda:0


In [None]:
def _run_epoch(loader, train):
    """Run `cvae` over every batch of `loader`; return the mean (total, recon, kl) per batch.

    With ``train=True`` the model is in train mode and each batch takes an optimizer
    step. Otherwise it runs in eval mode with gradients disabled. Means divide by the
    number of batches actually seen (the last batch may be partial), not by
    floor(len(dataset) / b_size).

    Raises:
        ValueError: if `loader` yields no batches.
    """
    cvae.train(train)
    sums = [0.0, 0.0, 0.0]
    n_batches = 0
    with torch.set_grad_enabled(train):
        for img_batch, scalar_batch in loader:
            img_batch = img_batch.to(device)
            scalar_batch = scalar_batch.to(device)

            # channel 0 = img2 (target), channel 1 = img1 (baseline), as stacked by BrainDataset
            recon, mu, sigma = cvae(img_batch, scalar_batch, img_batch[:, 1:2, ...])
            loss, recon_l, kl_l = criterion(img_batch[:, 0:1, ...], recon, mu, sigma**2)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            sums[0] += loss.item()
            sums[1] += recon_l
            sums[2] += kl_l
            n_batches += 1
    if n_batches == 0:
        raise ValueError("loader produced no batches")
    return tuple(s / n_batches for s in sums)


for epoch in range(num_epochs):
    running_loss, recon_loss, kl_loss = _run_epoch(dataloader, train=True)
    train_losses.append(running_loss)
    recon_loss_arr.append(recon_loss)
    kl_loss_arr.append(kl_loss)

    ### evaluate (before the early-stopping decision, so every epoch gets a test loss)
    test_running_loss, test_recon_loss, test_kl_loss = _run_epoch(test_dataloader, train=False)
    test_losses.append(test_running_loss)
    test_recon_loss_arr.append(test_recon_loss)
    test_kl_loss_arr.append(test_kl_loss)

    # output
    print('Epoch {} -- training loss: {:.4f}, testing loss: {:.4f}'.format(epoch + 1, running_loss,
                                                                           test_running_loss))

    # Persist the loss histories every epoch so they survive an early stop or interrupt.
    np.savetxt("CVAE_recon", recon_loss_arr)
    np.savetxt("CVAE_kl", kl_loss_arr)
    np.savetxt("CVAE_recon_test", test_recon_loss_arr)
    np.savetxt("CVAE_kl_test", test_kl_loss_arr)

    ### early stopping on the held-out loss; training loss keeps falling while overfitting
    if test_running_loss < min_loss:
        min_loss = test_running_loss
        pat_count = 0
        torch.save(cvae.state_dict(), 'CVAE.state_dict')
    else:
        pat_count += 1
        # stop after exactly `patience` consecutive epochs without improvement
        if pat_count >= patience:
            print('Early stopping: no improvement for {} epochs'.format(patience))
            break

## Evaluation: training and test loss curves

In [2]:
# Loss histories per epoch on a log scale; each curve is a per-batch mean.
fig, ax = plt.subplots()
ax.plot(recon_loss_arr, linestyle='dotted', label='reconstruction loss')
ax.plot(kl_loss_arr, linestyle='dotted', label='KL loss')
ax.plot(test_recon_loss_arr, linestyle='dotted', label='test recon loss')
ax.plot(test_kl_loss_arr, linestyle='dotted', label='test KL loss')
ax.set_yscale('log')
ax.set(xlabel='epoch', ylabel='loss per batch', title='CVAE loss curves')
ax.legend()
plt.show()

### Predict a 10-year trajectory

In [None]:
cvae.eval()

## Sample z from the latent prior N(0, I), which is what the KL term pulls the posterior
## towards. (Drawing mu ~ N(0, I) and then adding unit-sigma noise, as this cell used
## to, gives z ~ N(0, 2I).) The latent size 10 matches the CVAE default; seeded for
## reproducibility.
z = torch.randn(1, 10, generator=torch.Generator().manual_seed(0)).to(device)

# TODO: fill these in. This cell is a template and is skipped while base_image is
# empty; np.load on the bare directory would otherwise fail and stop "Run All".
base_image = ""
base_age = 0
status = 5

if not base_image:
    print("base_image is empty -- skipping the template trajectory cell")
else:
    os.makedirs("output_images", exist_ok=True)

    # Load the baseline scan once, shaped as a batch of one 1-channel volume (1, 1, D, H, W).
    img1 = torch.Tensor(np.load(os.path.join(img_path, base_image)).squeeze())
    img1_in = img1.reshape(1, 1, *img1.shape).to(device)

    ## for visualization, only take the middle slice along axis 1 (coronal)
    slices = np.zeros((11, img1.shape[0], img1.shape[2]))

    with torch.no_grad():  # inference only: skip building the autograd graph
        for a in range(11):
            # a = time after baseline, standardized like the training timeDiff column
            time_diff = (a - timeDiff_mean) / timeDiff_std
            scalar_vars = torch.tensor([[base_age, time_diff, status]], dtype=torch.float32,
                                       device=device)

            rand_gen = cvae.decode(z, scalar_vars, img1_in).cpu().numpy().squeeze()

            np.save("output_images/ADNI_year_" + str(a) + "_status_" + str(status) + ".npy", rand_gen)

            slice_1 = rand_gen[:, rand_gen.shape[1] // 2, :]
            slices[a, :] = slice_1

            fig, ax = plt.subplots(figsize=(3, 3))
            ax.imshow(slice_1.T, cmap="gray", origin="lower")
            fig.suptitle("predicted at time difference " + str(a))
            fig.savefig("output_images/1_ADNI_nopatch_" + str(a) + "_status_" + str(status))
            plt.show()

## Results

In [6]:
base_image = "012_S_4094_MPRAGE_2011-07-07_13_21_27.0.npy"
# (55.1-ADNI_age_mean)/ADNI_age_std
# NOTE(review): the model's age1 input was standardized with age_mean/age_std from
# samples.csv. Confirm those equal the ADNI statistics used for this constant.
base_age = -1.763316
status = 4

slices = np.zeros((11, 80, 80))
output_path = ""

cvae.eval()

# Draw z here so this cell no longer depends on the `z` left over from the template
# cell above. Seeded; latent size 10 matches the CVAE default.
z = torch.randn(1, 10, generator=torch.Generator().manual_seed(0)).to(device)

# Load the baseline scan once. os.path.join adds the separator that
# `img_path + base_image` ("cropped_imgs012_S_...") was missing.
img1 = torch.Tensor(np.load(os.path.join(img_path, base_image)).squeeze())
img1_in = img1.reshape(1, 1, *img1.shape).to(device)  # (1, 1, D, H, W)

with torch.no_grad():  # inference only: skip building the autograd graph
    for a in range(11):
        # a = time after baseline, standardized like the training timeDiff column
        time_diff = (a - timeDiff_mean) / timeDiff_std
        scalar_vars = torch.tensor([[base_age, time_diff, status]], dtype=torch.float32,
                                   device=device)

        rand_gen = cvae.decode(z, scalar_vars, img1_in).cpu().numpy().squeeze()

        np.save(output_path + "year_" + str(a) + "_status_" + str(status) + ".npy", rand_gen)

        slice_1 = rand_gen[:, rand_gen.shape[1] // 2, :]
        slices[a, :] = slice_1

        fig, ax = plt.subplots(figsize=(3, 3))
        ax.imshow(slice_1.T, cmap="gray", origin="lower")
        fig.suptitle("predicted at time " + str(a))
        fig.savefig(output_path + "year_" + str(a) + "_status_" + str(status))
        plt.show()

show_slices(slices)

In [7]:
# Create gif from images
# Read back the 11 per-time frames saved by the Results cell. savefig was called
# without an extension, and matplotlib writes PNG by default, hence ".png" here.
frame_paths = [output_path + "year_" + str(a) + "_status_" + str(status) + ".png" for a in range(11)]
images = list(map(imageio.imread, frame_paths))
print(len(images))

# Create a gif
# NOTE(review): the unit of `duration` differs across imageio versions/plugins
# (seconds in the legacy GIF writer, milliseconds in the pillow-based writer) — confirm.
gif = imageio.mimsave(output_path + "status_" + str(status) + '.gif', images, duration=10)