## Variational Autoencoder code

In [1]:
#Import statements
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
import nilearn.image
from nilearn import plotting, image, surface, datasets
import copy
import ants
from tqdm import tqdm
import scipy.sparse as sp
from scipy.sparse import csr_matrix
# Record the PyTorch version in the cell output (2.3.0 in the saved run).
print(torch.__version__)

2.3.0


#### Load the standard-scaled, de-identified, unlabeled training and testing datasets, and the diagnosis-labeled test data

The SUVR FDG-PET voxel intensities can be loaded as follows:

train_parquet = pd.read_parquet('model_data/train_brain_data_deid.parquet')

test_parquet = pd.read_parquet('model_data/test_brain_data_deid.parquet')

nd_parquet = pd.read_parquet('model_data/nd_brain_data_deid.parquet')

In [None]:
# Label tables for each split; the first CSV column is used as the row index.
train_csv = pd.read_csv('model_data/train_brain_labels_deid.csv', index_col=0)
test_csv = pd.read_csv('model_data/test_brain_labels_deid.csv', index_col=0)
nd_csv = pd.read_csv('model_data/nd_brain_labels_deid.csv', index_col=0)

# Standard-scaled voxel data for each split (one sample per row).
train_parquet_scaled = pd.read_parquet('model_data/train_brain_data_deid_scaled.parquet')
test_parquet_scaled = pd.read_parquet('model_data/test_brain_data_deid_scaled.parquet')
nd_parquet_scaled = pd.read_parquet('model_data/nd_brain_data_deid_scaled.parquet')

#### Load brain mask and MCALT background image

In [None]:
# Brain mask (selects the modeled voxels) and the T1 image used as a plotting background.
brain_mask = nilearn.image.load_img('model_data/brain_mask.nii')
bg = nilearn.image.load_img('model_data/t1.nii')

# Prefer the Metal (MPS) backend when it is available, otherwise run on CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

#### Example visualization of a standard-scaled FDG-PET image

In [None]:
# Paint one subject's voxel vector back into the 3-D mask grid and display it.
img_data = np.zeros(brain_mask.shape)
# CustomLoader feeds a whole DataFrame row to the model, so a row holds one value per
# in-mask voxel; iloc[0, 0] picked a single scalar and painted every voxel with it.
# tolist() + ravel() also copes with a row whose single cell stores the full voxel array.
# NOTE(review): assumes the row's voxel order matches np.nonzero's order over the mask.
vec = np.asarray(nd_parquet_scaled.iloc[0].tolist(), dtype=float).ravel()
nz_indices = np.ma.nonzero(brain_mask.get_fdata())
img_data[nz_indices] = vec
img = nilearn.image.new_img_like(brain_mask, img_data)
nilearn.plotting.plot_img(img, cut_coords=(0, 0, 0))

#### Adjacency Matrix Initialization

In [None]:
def k_matrix(coords, radius):
    """Build a dense radius-neighbour adjacency matrix over ``coords``.

    Entry ``(i, j)`` is 1.0 when points ``i`` and ``j`` lie within ``radius`` of
    each other (scikit-learn's default Euclidean metric), else 0.0.  Because
    ``coords`` is also passed as the query set, each point counts as its own
    neighbour, so the diagonal is 1.

    Args:
        coords: array of shape (n_points, n_dims) with point coordinates.
        radius: neighbourhood radius.

    Returns:
        Dense ``np.ndarray`` of shape (n_points, n_points).
    """
    neighbour_index = NearestNeighbors(radius=radius).fit(coords)
    connectivity = neighbour_index.radius_neighbors_graph(coords)
    # The radius graph is symmetric, so the transpose leaves the values unchanged.
    return connectivity.toarray().T

In [None]:
num_rows = nz_indices[0].size
n_cols = 3
# One row per in-mask voxel, one float column per spatial axis (x, y, z).
arr = np.zeros((num_rows, n_cols))
arr[:] = np.column_stack(nz_indices[:n_cols])

In [None]:
BATCH_SIZE = 16
radius = 2
k = k_matrix(arr, radius)  #(16487, 16487), arr = (16487, 3) x, y, z of mask

# Symmetric normalisation: D^-1/2 (A + I) D^-1/2, with D the row-degree matrix.
# NOTE(review): k_matrix queries with the fitted coordinates, so per scikit-learn's
# radius_neighbors_graph docs k already has 1s on its diagonal and "+ I" makes the
# self-loop weight 2. Left unchanged because the saved checkpoint was presumably
# trained with this adjacency — confirm intent.
k = torch.tensor(k, dtype=torch.float32)
k.diagonal().add_(1.0)  # in-place "+ I": avoids allocating a second dense N x N matrix
degrees = torch.sum(k, axis=1)  # >= 1 everywhere thanks to the self-loop

# D is diagonal, so D^-1/2 is an element-wise reciprocal square root. Scaling rows and
# columns directly replaces a dense N x N inverse and two dense N x N matmuls (O(N^3)).
inv_sqrt_deg = torch.rsqrt(degrees)
k_normalized = inv_sqrt_deg[:, None] * k * inv_sqrt_deg[None, :]
k_normalized = k_normalized.to_sparse_csr()

transform = transforms.Compose([transforms.ToTensor()])

#### Custom dataset

In [None]:
class CustomLoader(Dataset):
    """Map-style dataset that serves one DataFrame row per sample.

    Each item is the row's values as a float32 tensor with a leading channel
    dimension of size 1 (shape ``(1, n_columns)``), optionally passed through
    ``transform``.
    """

    def __init__(self, images, transform=None):
        super().__init__()
        self.images = images  # DataFrame (loaded from parquet), one sample per row
        self.transform = transform
        self.n_samples = self.images.shape[0]

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        row_values = np.array(self.images.iloc[idx].copy())
        sample = torch.from_numpy(row_values).to(torch.float32).unsqueeze(0)
        if self.transform is None:
            return sample
        return self.transform(sample)

#### Load datasets using custom dataset loader

In [None]:
# Row-per-sample datasets for each split (no transform applied).
train_data = CustomLoader(images=train_parquet_scaled, transform=None)
test_data = CustomLoader(images=test_parquet_scaled, transform=None)
nd_data = CustomLoader(images=nd_parquet_scaled, transform=None)

n_samples = len(train_data)

# Train/test batches are reshuffled every epoch; the nd loader preserves row order.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, num_workers=0, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, num_workers=0, shuffle=True)
nd_loader = DataLoader(dataset=nd_data, batch_size=BATCH_SIZE, num_workers=0, shuffle=False)

#### Graph convolution method

In [264]:
class GraphConv(nn.Module):
    """Graph convolution: a 1x1 Conv1d channel mix, then propagation over ``k``.

    Expects input of shape (batch, in_dim, n_nodes) with ``n_nodes == k.shape[0]``
    and returns (batch, out_dim, n_nodes) on the input's device.
    """

    def __init__(self, in_dim, out_dim, k):
        super().__init__()
        self.d1 = nn.Conv1d(in_dim, out_dim, kernel_size=1)
        self.k = k  # plain attribute (not a buffer), so module.to(device) leaves it in place

    def forward(self, imgs):
        target_device = imgs.device
        mixed = self.d1(imgs)
        # The node-propagation matmul runs on CPU and the result is moved back.
        # NOTE(review): presumably because sparse CSR matmul is unsupported on MPS — confirm.
        propagated = torch.matmul(mixed.cpu(), self.k.cpu())
        return propagated.to(target_device)

#### Graph convolutional network block

In [470]:
class GCNBlock(nn.Module):
    """Two stacked graph convolutions with batch norm and a channel-mean skip.

    forward: GC1 -> BatchNorm -> tanh -> GC2 -> BatchNorm, plus the input averaged
    over its channel dimension (dim 1), broadcast across every output channel.
    """

    def __init__(self, in_dim, out_dim, k):
        super().__init__()
        self.k = k
        # Registration order (GC1, GC2, bn1, bn2) fixes parameter order for optimizer checkpoints.
        self.GC1 = GraphConv(in_dim, out_dim, self.k)
        self.GC2 = GraphConv(out_dim, out_dim, self.k)
        self.bn1 = nn.BatchNorm1d(out_dim)
        self.bn2 = nn.BatchNorm1d(out_dim)

    def forward(self, x):
        # A channel-mean skip works even when in_dim != out_dim.
        skip = x.mean(axis=1, keepdim=True)
        hidden = torch.tanh(self.bn1(self.GC1(x)))
        return self.bn2(self.GC2(hidden)) + skip

#### Encoder of the variational autoencoder

In [471]:
class Encoder(nn.Module):
    """Variational encoder: graph-conv features -> MLP -> reparameterised latent sample.

    After every forward pass ``self.KL`` holds KL(N(mu, sigma^2) || N(0, 1)),
    summed over the batch and the latent dimensions.

    Args:
        in_dim: input channels per node.
        conv_dim: channels produced by the graph-conv block.
        lin_dim: width of the hidden linear layer.
        latent_dims: size of the latent vector.
        k: (n_nodes, n_nodes) propagation matrix used by the graph convolutions.
        device: device inputs and noise samples are moved to.
    """

    def __init__(self, in_dim, conv_dim, lin_dim, latent_dims, k, device):
        super().__init__()
        self.k = k
        self.device = device
        # Registration order fixes parameter order for optimizer checkpoints.
        self.GCBlock = GCNBlock(in_dim, conv_dim, self.k)

        # Flattened graph features: conv_dim channels x k.shape[0] nodes.
        self.Lin1 = nn.Linear(conv_dim * self.k.shape[0], lin_dim)
        self.Lin2 = nn.Linear(lin_dim, latent_dims)  # mean head
        self.Lin3 = nn.Linear(lin_dim, latent_dims)  # log-std head (sigma = exp(output))

        self.N = torch.distributions.Normal(0, 1)
        self.KL = 0.0

    def forward(self, x):
        x = x.to(self.device)
        x = torch.tanh(self.GCBlock(x))
        x = torch.tanh(self.Lin1(x.flatten(start_dim=1)))

        mu = self.Lin2(x)
        log_sigma = self.Lin3(x)
        sigma = torch.exp(log_sigma)
        # Reparameterisation trick: z = mu + sigma * eps, eps ~ N(0, 1).
        z = mu + sigma * self.N.sample(mu.shape).to(self.device)
        # Closed-form KL for a diagonal Gaussian vs. N(0, 1):
        #   0.5 * (sigma^2 + mu^2 - 1) - log(sigma)
        # The former expression (sigma^2 + mu^2 - log(sigma) - 1/2) doubled the quadratic
        # terms, pulling sigma toward 1/sqrt(2) rather than 1. Using log_sigma directly
        # also avoids log(exp(.)) underflowing to -inf.
        self.KL = (0.5 * (sigma**2 + mu**2 - 1) - log_sigma).sum()

        return z

#### Decoder of the variational autoencoder

In [472]:
class Decoder(nn.Module):
    """Latent vector -> MLP -> (conv_dim, n_nodes) feature map -> GCN block with in_dim outputs."""

    def __init__(self, in_dim, conv_dim, lin_dim, latent_dims, k):
        super().__init__()
        self.k = k
        n_nodes = self.k.shape[0]

        # Layer positions are part of the state_dict keys (decoder_lin.0 / decoder_lin.2),
        # so this order must match the saved checkpoints.
        self.decoder_lin = nn.Sequential(
            nn.Linear(latent_dims, lin_dim),
            nn.Tanh(),
            nn.Linear(lin_dim, conv_dim * n_nodes),
            nn.Tanh(),
        )
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(conv_dim, n_nodes))
        self.decoder_conv = GCNBlock(conv_dim, in_dim, self.k)

    def forward(self, x):
        features = self.unflatten(self.decoder_lin(x))
        return self.decoder_conv(features)

#### Variational Autoencoder class

In [473]:
class VAE(nn.Module):
    """Graph-convolutional variational autoencoder computing ``decoder(encoder(x))``.

    The encoder stores the KL term of the latest batch on ``self.encoder.KL``.
    """

    def __init__(self, in_dim, conv_dim, lin_dim, latent_dims, k, device):
        super().__init__()
        self.device = device
        self.k = k
        # Encoder is registered before decoder; parameter order matters for optimizer checkpoints.
        self.encoder = Encoder(in_dim, conv_dim, lin_dim, latent_dims, self.k, self.device)
        self.decoder = Decoder(in_dim, conv_dim, lin_dim, latent_dims, self.k)

    def forward(self, x):
        latent = self.encoder(x.to(self.device))
        return self.decoder(latent)

#### Optimizer and variational autoencoder initialization and hyperparameters

In [474]:
torch.manual_seed(0)  # reproducible weight initialisation

latent_dims = 16
vae = VAE(
    in_dim=1,
    conv_dim=8,
    lin_dim=256,
    latent_dims=latent_dims,
    k=k_normalized,
    device=device,
)

vae.to(device)
total_epochs = 0  # incremented once per epoch by the training loop

In [475]:
lr = 1e-5
optim = torch.optim.Adam(vae.parameters(), lr=lr)

#### Pre-trained model accessible here:
# map_location remaps stored tensors onto the current device, so checkpoints saved
# from one backend (e.g. MPS) also load where only another (e.g. CPU) is available.
vae.load_state_dict(torch.load('model_data/2gc_16dim_model_scaled.pt', map_location=device))

# NOTE: loading the optimizer state also restores the saved param-group
# hyperparameters, so the effective learning rate is the checkpoint's, not `lr` above.
optim.load_state_dict(torch.load('model_data/2gc_16dim_optim_scaled.pt', map_location=device))

#### Training and validation functions

In [None]:
def train_epoch(vae, device, dataloader, optimizer, perc=0.75):
    """Run one optimisation pass over ``dataloader`` and return the mean per-sample loss.

    Each batch's loss is ``perc * reconstruction + (1 - perc) * KL``. The
    reconstruction term is the summed squared error between input and output,
    and KL is read from ``vae.encoder.KL`` after the forward pass.

    Args:
        vae: model whose forward returns a reconstruction and sets ``encoder.KL``.
        device: device each batch is moved to.
        dataloader: yields input batches; ``len(dataloader.dataset)`` normalises the total.
        optimizer: optimizer over ``vae``'s parameters.
        perc: weight of the reconstruction term (the KL term gets ``1 - perc``).
            Defaults to 0.75, the value previously hard-coded here.

    Returns:
        Sum of the batch losses divided by the number of samples in the dataset.

    Raises:
        ValueError: if ``perc`` is outside [0, 1].
    """
    if not 0.0 <= perc <= 1.0:
        raise ValueError("perc must be in [0, 1]")
    vae.train()
    train_loss = 0.0
    for x in dataloader:
        x = x.to(device)
        x_hat = vae(x)
        rec_loss = ((x - x_hat) ** 2).sum()
        kl_loss = vae.encoder.KL
        loss = perc * rec_loss + (1 - perc) * kl_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    return train_loss / len(dataloader.dataset)

In [None]:
def test_epoch(vae, device, dataloader):
    vae.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x in dataloader:
            x = x.to(device)
            x_hat = vae(x)
            loss = ((x - x_hat)**2).sum()
            val_loss += loss.item()
            
    return val_loss / len(dataloader.dataset)

In [None]:
# Histories filled once per epoch by the training loop.
tloss_vals, vloss_vals, img_recon_vectors = [], [], []

# One fixed test batch; its first image is reconstructed after every epoch.
# test_loader shuffles, so this is not necessarily the first rows of the test set.
batch = next(iter(test_loader))
org_test_img = batch[0, :].reshape(-1)

#### Epoch loop for training and testing

In [None]:
num_epochs = 50
plt.figure(figsize=(5, 3))

for epoch in range(num_epochs):
    total_epochs += 1
    train_loss = train_epoch(vae, device, train_loader, optim)
    val_loss = test_epoch(vae, device, test_loader)

    tloss_vals.append(train_loss)
    vloss_vals.append(val_loss)

    # test_epoch left vae in eval mode. Skip autograd, and move the output to CPU
    # before .numpy(): .numpy() raises on non-CPU (e.g. MPS) tensors.
    with torch.no_grad():
        rec_test_img = vae(batch)[0, :].reshape(-1).cpu().numpy()
    img_recon_vectors.append(rec_test_img.copy())
    print('\n EPOCH {}/{} \t train loss {:.3f} \t val loss {:.3f}'.format(epoch + 1, num_epochs,train_loss,val_loss))

#### Code to save model and optim parameters
# NOTE: this is the state after the final epoch (no best-epoch selection), and it
# overwrites the pre-trained checkpoint files that were loaded before training.
best_model_state = copy.deepcopy(vae.state_dict())

torch.save(best_model_state, 'model_data/2gc_16dim_model_scaled.pt')

best_optim_state = copy.deepcopy(optim.state_dict())

torch.save(best_optim_state, 'model_data/2gc_16dim_optim_scaled.pt')

#### Loss curves

In [None]:
# One row per epoch, one column per loss. (Passing vloss_vals as the second
# positional argument made it the DataFrame *index* instead of a column.)
loss_curves = pd.DataFrame({'train_loss': tloss_vals, 'val_loss': vloss_vals})
# Row 0: original test image; row 1: last-epoch reconstruction. The original image is
# a torch tensor, so convert it to NumPy to get a plain float row like rec_test_img.
histogram_vals = pd.DataFrame((org_test_img.cpu().numpy(), rec_test_img))

In [None]:
def moving_avg(array, window_len=5):
    """Return the simple moving average of ``array`` over every full window.

    Element ``i`` is the mean of ``array[i:i + window_len]``. There are
    ``len(array) - window_len + 1`` full windows, or none when ``array`` is
    shorter than ``window_len``. The previous version allocated one slot too
    few, silently dropping the last window, and raised on short inputs.

    Args:
        array: 1-D sequence of numbers (e.g. per-epoch losses).
        window_len: number of consecutive values averaged per output element.

    Returns:
        ``np.ndarray`` of float moving averages.

    Raises:
        ValueError: if ``window_len`` is smaller than 1.
    """
    if window_len < 1:
        raise ValueError("window_len must be >= 1")
    values = np.asarray(array, dtype=float)
    n_windows = max(len(values) - window_len + 1, 0)
    avg_loss = np.empty(n_windows)
    for i in range(n_windows):
        avg_loss[i] = np.mean(values[i:i + window_len])
    return avg_loss

In [None]:
# Smooth the validation-loss curve over 10-epoch windows.
# NOTE(review): the second line rebinds `moving_avg` from the function to a DataFrame,
# so re-running this cell raises TypeError; consider a different variable name.
vloss_avg = moving_avg(array=vloss_vals, window_len=10)
moving_avg = pd.DataFrame(data=vloss_avg)