## Feature Extraction code

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import nilearn
from nilearn import plotting, image, surface, datasets
from sklearn.preprocessing import StandardScaler
import nibabel
import ants
from sklearn.neighbors import NearestNeighbors
print(torch.__version__)

2.3.0


In [None]:
# Run on Apple's Metal (MPS) backend when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

### Input de-identified, non-labeled training, testing datasets

In [None]:
# Each split is a pair of files: a label table (CSV, first column used as the
# index) and a parquet file of brain data (pre-scaled, judging by the file
# names -- confirm against the preprocessing step).

# Training split
train_csv = pd.read_csv('model_data/train_brain_labels_deid.csv', index_col=0)
train_parquet_scaled = pd.read_parquet('model_data/train_brain_data_deid_scaled.parquet')

# Test split
test_csv = pd.read_csv('model_data/test_brain_labels_deid.csv', index_col=0)
test_parquet_scaled = pd.read_parquet('model_data/test_brain_data_deid_scaled.parquet')

# "nd" split
nd_csv = pd.read_csv('model_data/nd_brain_labels_deid.csv', index_col=0)
nd_parquet_scaled = pd.read_parquet('model_data/nd_brain_data_deid_scaled.parquet')

#### Load brain mask and prepare it

In [None]:
# Mask volume; its non-zero voxels are the positions decoded values are
# written back into later on.
brain_mask = nilearn.image.load_img('model_data/brain_mask.nii')
nz_indices = np.ma.nonzero(brain_mask.get_fdata())

# All-zero volume with the mask's shape, used as the canvas for output images.
img_data = np.zeros(brain_mask.shape)

# First cell of the training frame (reassigned further down the notebook).
vec = train_parquet_scaled.iloc[0, 0]

In [None]:
BATCH_SIZE = 16

# GraphConv.forward always multiplies by `k` on the CPU, so load it there
# directly. map_location also lets the file load on hosts that lack the
# device it was saved from.
k_normalized = torch.load('model_data/tensor_k_norm.pt', map_location='cpu')

#### Custom dataset

In [None]:
class CustomLoader(Dataset):
    """Dataset that serves one row of a tabular frame at a time as a tensor.

    Args:
        images: Row-indexable table (accessed through ``.shape`` and
            ``.iloc``); every row is one sample.
        transform: Optional callable applied to the tensor before it is
            returned.
    """

    def __init__(self, images, transform=None):
        super().__init__()
        self.images = images
        self.transform = transform
        # The sample count is captured once, at construction time.
        self.n_samples = images.shape[0]

    def __len__(self):
        """Return the number of rows captured at construction."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return row ``idx`` as a float32 tensor with a leading axis of size 1."""
        row = np.array(self.images.iloc[idx].copy())
        tensor = torch.from_numpy(row).to(torch.float32).unsqueeze(0)
        if self.transform is None:
            return tensor
        return self.transform(tensor)

#### Load datasets using custom dataset loader

In [None]:
# Wrap both frames in the custom dataset (no transform), then build a batched
# loader for each.
test_data = CustomLoader(test_parquet_scaled)
nd_data = CustomLoader(nd_parquet_scaled)

# NOTE(review): the test loader shuffles while the nd loader does not --
# confirm this asymmetry is intended.
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
nd_loader = DataLoader(nd_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

#### Graph convolution method

In [264]:
class GraphConv(nn.Module):
    """Pointwise channel mixing followed by right-multiplication with ``k``.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        k: Matrix the convolved features are multiplied by. It is stored as a
            plain attribute, so it is neither a parameter nor part of the
            module's state dict.
    """

    def __init__(self, in_dim, out_dim, k):
        super().__init__()
        # A kernel size of 1 mixes channels independently at every position.
        self.d1 = nn.Conv1d(in_dim, out_dim, kernel_size=1)
        self.k = k

    def forward(self, imgs):
        """Apply the 1x1 convolution, multiply by ``k`` on the CPU, and
        return the result on the input's original device."""
        origin = imgs.device
        mixed = self.d1(imgs)
        # The matrix product is deliberately carried out on the CPU.
        propagated = mixed.cpu() @ self.k.cpu()
        return propagated.to(origin)

#### Graph convolutional network block

In [470]:
class GCNBlock(nn.Module):
    """Two graph convolutions with batch norm and a channel-averaged skip path.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels of both graph convolutions.
        k: Matrix handed to each ``GraphConv`` layer.
    """

    def __init__(self, in_dim, out_dim, k):
        super().__init__()
        self.k = k
        self.GC1 = GraphConv(in_dim, out_dim, self.k)
        self.GC2 = GraphConv(out_dim, out_dim, self.k)
        self.bn1 = nn.BatchNorm1d(out_dim)
        self.bn2 = nn.BatchNorm1d(out_dim)

    def forward(self, x):
        """Return ``bn2(GC2(tanh(bn1(GC1(x)))))`` plus the skip term."""
        # The skip connection is the mean over input channels, kept as a
        # single channel so it broadcasts across all output channels.
        skip = x.mean(dim=1, keepdim=True)

        hidden = F.tanh(self.bn1(self.GC1(x)))
        hidden = self.bn2(self.GC2(hidden))

        # No activation after the residual sum.
        return hidden + skip

#### Encoder of the Variational autoencoder

In [471]:
class Encoder(nn.Module):
    """Encoder half of the VAE: one GCN block, then linear heads for mu and sigma.

    Args:
        in_dim: Number of input channels.
        conv_dim: Number of channels produced by the GCN block.
        lin_dim: Width of the hidden linear layer.
        latent_dims: Size of the latent vector.
        k: Matrix used by the graph convolutions; ``k.shape[0]`` is taken as
            the per-channel feature length when sizing the first linear layer.
        device: Device inputs and sampled noise are moved to.
    """

    def __init__(self, in_dim, conv_dim, lin_dim, latent_dims, k, device):
        super().__init__()
        self.k = k
        self.device = device
        self.GCBlock = GCNBlock(in_dim, conv_dim, self.k)

        flat_features = conv_dim * self.k.shape[0]
        self.Lin1 = nn.Linear(flat_features, lin_dim)
        self.Lin2 = nn.Linear(lin_dim, latent_dims)  # head for mu
        self.Lin3 = nn.Linear(lin_dim, latent_dims)  # head for log(sigma)

        # The standard normal sampler is built from Python scalars; draws are
        # moved to `device` inside forward().
        self.N = torch.distributions.Normal(0, 1)
        self.KL = 0.0

    def forward(self, x):
        """Encode ``x`` into a sampled latent vector and store the KL term in ``self.KL``.

        A fresh noise sample is drawn on every call, whether or not the
        module is in eval mode, so repeated calls return different values.
        """
        hidden = self.GCBlock(x.to(self.device))
        hidden = F.tanh(hidden).flatten(start_dim=1)
        hidden = F.tanh(self.Lin1(hidden))

        mu = self.Lin2(hidden)
        sigma = torch.exp(self.Lin3(hidden))
        noise = self.N.sample(mu.shape).to(self.device)

        # NOTE(review): the closed-form KL(N(mu, sigma^2) || N(0, 1)) is
        # 0.5 * (sigma^2 + mu^2 - 1) - log(sigma); this expression weights
        # sigma^2 and mu^2 by 1 instead of 0.5. Left unchanged because the
        # saved weights were presumably trained against it -- confirm before
        # retraining.
        self.KL = (sigma**2 + mu**2 - torch.log(sigma) - 0.5).sum()

        return mu + sigma * noise

#### Decoder of the variational autoencoder

In [472]:
class Decoder(nn.Module):
    """Decoder half of the VAE: two linear layers, an unflatten, then a GCN block.

    Args:
        in_dim: Number of channels of the reconstructed output.
        conv_dim: Number of channels fed into the GCN block.
        lin_dim: Width of the hidden linear layer.
        latent_dims: Size of the latent vector.
        k: Matrix used by the graph convolutions; ``k.shape[0]`` is taken as
            the per-channel feature length.
    """

    def __init__(self, in_dim, conv_dim, lin_dim, latent_dims, k):
        super().__init__()
        self.k = k
        n_nodes = self.k.shape[0]

        # The layer order fixes the Sequential indices and therefore the
        # state-dict keys; do not reorder.
        lin_layers = [
            nn.Linear(latent_dims, lin_dim),
            nn.Tanh(),
            nn.Linear(lin_dim, conv_dim * n_nodes),
            nn.Tanh(),
        ]
        self.decoder_lin = nn.Sequential(*lin_layers)

        # Turn the flat vector back into (conv_dim, n_nodes) per sample.
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(conv_dim, n_nodes))

        self.decoder_conv = GCNBlock(conv_dim, in_dim, self.k)

    def forward(self, x):
        """Map a batch of latent vectors to reconstructions."""
        expanded = self.unflatten(self.decoder_lin(x))
        return self.decoder_conv(expanded)

#### Variational Autoencoder class

In [473]:
class VAE(nn.Module):
    """Variational autoencoder that chains ``Encoder`` and ``Decoder``.

    Args:
        in_dim: Number of input (and reconstructed) channels.
        conv_dim: Number of channels inside the GCN blocks.
        lin_dim: Width of the hidden linear layers.
        latent_dims: Size of the latent vector.
        k: Matrix shared by every graph convolution.
        device: Device inputs are moved to before encoding.
    """

    def __init__(self, in_dim, conv_dim, lin_dim, latent_dims, k, device):
        super().__init__()
        self.device = device
        self.k = k
        self.encoder = Encoder(in_dim, conv_dim, lin_dim, latent_dims, self.k, self.device)
        self.decoder = Decoder(in_dim, conv_dim, lin_dim, latent_dims, self.k)

    def forward(self, x):
        """Encode ``x`` to a sampled latent vector and decode it again."""
        latent = self.encoder(x.to(self.device))
        return self.decoder(latent)

#### Optimizer and variational autoencoder initialization and hyperparameters

In [474]:
# Seed before the model is built so its initial weights are reproducible.
torch.manual_seed(0)

# Hyperparameters
latent_dims = 16
lr = 1e-5
total_epochs = 0

vae = VAE(
    in_dim=1,
    conv_dim=8,
    lin_dim=256,
    latent_dims=latent_dims,
    k=k_normalized,
    device=device,
)
vae.to(device)

optim = torch.optim.Adam(vae.parameters(), lr=lr)

#### Load trained variational autoencoder model and optimizer

In [None]:
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on another accelerator still loads on this host.
vae.load_state_dict(
    torch.load('model_data/2gc_16dim_model_scaled.pt', map_location=device)
)
vae.eval()  # inference only below: use BatchNorm running statistics
optim.load_state_dict(
    torch.load('model_data/2gc_16dim_optim_scaled.pt', map_location=device)
)

#### Encode the test dataset

In [None]:
# Encode every test sample into its latent vector, one row per sample.
# Eval mode and the column labels do not change between samples, so both are
# set up once rather than on every iteration.
vae.eval()
enc_columns = None
test_encoded_samples = []
for sample in tqdm(test_data):
    img = sample.unsqueeze(0).to(device)  # add the batch axis
    with torch.no_grad():
        encoded_img = vae.encoder(img)
    encoded_img = encoded_img.flatten().cpu().numpy()
    if enc_columns is None:
        enc_columns = [f"Enc. Variable {i}" for i in range(encoded_img.size)]
    encoded_sample = dict(zip(enc_columns, encoded_img))
    test_encoded_samples.append(encoded_sample)
test_encoded_samples = pd.DataFrame(test_encoded_samples)

#### Decode the test dataset

In [None]:
# Run every test sample through the full VAE and keep the flattened
# reconstruction, one row per sample.
# Eval mode and the column labels do not change between samples, so both are
# set up once rather than on every iteration.
vae.eval()
rec_columns = None
test_reconstructed_samples = []
for sample in tqdm(test_data):
    img = sample.unsqueeze(0).to(device)  # add the batch axis
    with torch.no_grad():
        reconstructed_img = vae(img)
    reconstructed_img = reconstructed_img.flatten().cpu().numpy()
    if rec_columns is None:
        rec_columns = [f"Rec. Pixel {i}" for i in range(reconstructed_img.size)]
    reconstructed_sample = dict(zip(rec_columns, reconstructed_img))
    test_reconstructed_samples.append(reconstructed_sample)
test_reconstructed_samples = pd.DataFrame(test_reconstructed_samples)

#### Decode the nd dataset

In [None]:
# Run every nd sample through the full VAE and keep the flattened
# reconstruction, one row per sample.
# Eval mode and the column labels do not change between samples, so both are
# set up once rather than on every iteration.
vae.eval()
nd_rec_columns = None
nd_reconstructed_samples = []
for sample in tqdm(nd_data):
    img = sample.unsqueeze(0).to(device)  # add the batch axis
    with torch.no_grad():
        reconstructed_img = vae(img)
    reconstructed_img = reconstructed_img.flatten().cpu().numpy()
    if nd_rec_columns is None:
        nd_rec_columns = [f"Rec. Pixel {i}" for i in range(reconstructed_img.size)]
    reconstructed_sample = dict(zip(nd_rec_columns, reconstructed_img))
    nd_reconstructed_samples.append(reconstructed_sample)
nd_reconstructed_samples = pd.DataFrame(nd_reconstructed_samples)

In [None]:
# Persist the test reconstructions, then read them straight back so the
# following cells work from the on-disk copy.
test_reconstructed_samples.to_parquet(path='model_data/recon_test_data.parquet')
test_reconstructed = pd.read_parquet(path='model_data/recon_test_data.parquet')

#### Decode function

In [None]:
def decode_img(vae, out_img):
    """Decode a latent tensor and return the output as a single NumPy row.

    Args:
        vae: Object exposing a ``decoder`` callable.
        out_img: Latent tensor passed to ``vae.decoder``.

    Returns:
        NumPy array of shape ``(1, n)`` holding every decoded value.
    """
    # Gradients are not needed for inference.
    with torch.no_grad():
        flat = vae.decoder(out_img).cpu().numpy().reshape(1, -1)
    return flat

#### Show brain image function

In [None]:
def plot_single_img(vae, out_img, img_data, nz_indices, brain_mask, path, scaler):
    """Decode a latent vector, standardise it, and save it as a brain volume.

    Despite the name, nothing is plotted; the volume is written to ``path``.

    Args:
        vae: Model whose decoder is used (via ``decode_img``).
        out_img: Latent tensor to decode.
        img_data: Template volume array. It is copied, so the caller's array
            is left untouched.
        nz_indices: Index into the volume selecting the voxels that receive
            the decoded values.
        brain_mask: Reference image providing the output's geometry.
        path: Destination file for the saved image.
        scaler: Fitted object whose ``transform`` is applied to the decoded row.
    """
    decoded = scaler.transform(decode_img(vae, out_img))

    # Work on a copy: writing into `img_data` directly would mutate an array
    # shared across calls and could leak values from one image into the next.
    out_img_data = img_data.copy()
    out_img_data[nz_indices] = decoded.ravel()

    volume = nilearn.image.new_img_like(brain_mask, out_img_data)
    nibabel.save(volume, path)

#### Standard Scaler for averages

In [None]:
# Per-column scaler fitted on the reconstructed test data; used further down
# to standardise decoded images.
scaler = StandardScaler().fit(test_reconstructed.values)

# Column-wise mean and standard deviation of the latent embeddings.
mean_embeddings = torch.tensor(np.mean(test_encoded_samples, axis=0))
std_data = torch.tensor(np.std(test_encoded_samples, axis=0))

#### Standard deviation array

In [None]:
# std_array[dim, i] is the mean embedding with coordinate `dim` shifted by
# (i - 2) standard deviations, i.e. offsets -2, -1, 0, +1, +2.
# The array is sized from `latent_dims` so it follows the model configuration
# instead of a hard-coded 16.
std_array = np.zeros((latent_dims, 5, latent_dims))
mean_vec = mean_embeddings.squeeze()
for dim in range(std_array.shape[0]):
    for i in range(std_array.shape[1]):
        std_array[dim, i] = mean_vec
        std_array[dim, i, dim] += std_data[dim] * (i - 2)
std_array = torch.tensor(std_array, dtype=torch.float32)
# Insert a batch axis of size 1 so that std_array[dim, i] can be fed to the
# decoder directly.
std_array = std_array.unsqueeze(2)

In [None]:
# Place the latent grid and the model on the same device before decoding.
std_array = std_array.to(device)
vae.to(device)

#### Create reconstructed manipulated images

In [None]:
# Decode each manipulated latent vector and write it to its own NIfTI file,
# named after the latent dimension and the offset in standard deviations.
for dim in range(std_array.shape[0]):
    for i in (-2, -1, 0, 1, 2):
        latent = std_array[dim, i + 2]  # offsets -2..2 are stored at 0..4
        path = f'path_files/dim_{dim}_i_{i}.nii.gz'
        plot_single_img(vae, latent, img_data, nz_indices, brain_mask, path, scaler)

In [None]:
# Image passed as the background (bg_img) of the stat maps plotted below.
bg = image.load_img('model_data/mcalt_t1.nii')

#### Show the feature extraction images

In [None]:
# One row per latent dimension, one column per offset (-2..+2 std).
# The row count follows `latent_dims` instead of a hard-coded 16.
fig, axes = plt.subplots(nrows=latent_dims, ncols=5, figsize=(15, 30))

for dim in range(axes.shape[0]):
    for i in range(axes.shape[1]):
        # The middle column (offset 0) carries the row label.
        if i == 2:
            axes[dim, i].set_title(f"Embedding Dim. {dim}")
        else:
            axes[dim, i].set_title(f"{i-2} std")
        path = f'path_files/dim_{dim}_i_{i-2}.nii.gz'
        out_std_img = nilearn.image.load_img(path)
        # Smooth before display.
        out_smooth = nilearn.image.smooth_img(out_std_img, fwhm=6)
        nilearn.plotting.plot_stat_map(
            out_smooth,
            cut_coords=1,
            bg_img=bg,
            axes=axes[dim, i],
            colorbar=False,
            display_mode='x',
            vmax=2,
        )
plt.show()

#### Save std embeddings to csv

In [None]:
# Drop the batch axis and store each manipulated latent vector as a list in
# one cell of a (dimension x offset) table.
std_array_out = std_array.cpu().squeeze(2).numpy()
# The table is sized from the data rather than a hard-coded (16, 5).
arr2d_std = np.empty(std_array_out.shape[:2], dtype=object)
for i, j in np.ndindex(*arr2d_std.shape):
    arr2d_std[i, j] = std_array_out[i, j].tolist()
std_df = pd.DataFrame(arr2d_std)
std_df.to_csv('model_data/std_embedding_arr.csv')

#### Save histogram comparison of embeddings to csv

In [None]:
# Compare the per-column mean of the reconstructions with the decoding of the
# mean embedding.
vec = scaler.mean_
# The decoder unflattens along dim 1 and holds float32 weights, so the mean
# embedding needs a leading batch axis and a float32 dtype. The cast happens
# before the device move because MPS has no float64 support.
mean_latent = mean_embeddings.to(torch.float32).reshape(1, -1).to(device)
dimg = decode_img(vae, mean_latent).squeeze()
hist_rec_comparison = pd.DataFrame((vec, dimg))
hist_rec_comparison.to_csv('model_data/hist_rec_comparison.csv')

#### MCALT to MNI registration

In [None]:
# Load the mask and both templates.
brain_mask = nilearn.image.load_img('model_data/brain_mask.nii')
mcalt = nilearn.image.load_img('model_data/mcalt_t1.nii')
mni = nilearn.image.load_img('model_data/mni_t1.nii')

# Convert the templates to ANTs images: MNI is the fixed target, MCALT is the
# moving image.
ants_mni = ants.from_nibabel(mni)
ants_mcalt = ants.from_nibabel(mcalt)

# SyN registration; keep the forward transforms for reuse on other images.
reg_dict = ants.registration(fixed=ants_mni, moving=ants_mcalt, type_of_transform='SyN')
ants_transform = reg_dict['fwdtransforms']

#### Surface rendering of features

In [None]:
def get_path_plot(path, hemi='right', view='lateral', mask=brain_mask, interactive=False, threshold=None, color_map='turbo', colorbar=True, transform=ants_transform, fixed=ants_mni, moving=ants_mcalt, vmin=None, vmax=None):
    """Render a volumetric image on the fsaverage pial surface.

    The image at ``path`` is smoothed, warped with ``transform`` onto
    ``fixed``, projected onto the chosen hemisphere's pial mesh, and plotted
    as a surface stat map.

    Args:
        path: Path of the image to render.
        hemi: ``'right'`` or ``'left'``.
        view: View name forwarded to the plotting call.
        mask: Unused; kept for interface compatibility.
        interactive: When true, plot with the ``'plotly'`` engine; otherwise
            the plotting function's default engine is used.
        threshold: Forwarded to the plotting call.
        color_map: Colormap name forwarded as ``cmap``.
        colorbar: Whether to draw a colorbar.
        transform: Transform list handed to ``ants.apply_transforms``.
        fixed: Reference image for ``ants.apply_transforms``.
        moving: Unused; kept for interface compatibility.
        vmin: Forwarded to the plotting call.
        vmax: Forwarded to the plotting call.

    Returns:
        The object returned by ``plotting.plot_surf_stat_map``.

    Raises:
        ValueError: If ``hemi`` is neither ``'right'`` nor ``'left'``.
    """
    # Validate up front so that a bad hemisphere fails before any image work.
    if hemi not in ('right', 'left'):
        raise ValueError(f"hemi must be 'right' or 'left', got {hemi!r}")

    img = nilearn.image.load_img(path)
    img_smooth = nilearn.image.smooth_img(img, fwhm=6)
    ants_img = ants.from_nibabel(img_smooth)
    ants_trans_img = ants.apply_transforms(fixed=fixed, moving=ants_img, transformlist=transform)
    trans_img = ants.to_nibabel(ants_trans_img)

    fsaverage = datasets.fetch_surf_fsaverage()
    # The mesh and the sulcal background must come from the same hemisphere.
    if hemi == 'right':
        pial, sulc = fsaverage.pial_right, fsaverage.sulc_right
    else:
        pial, sulc = fsaverage.pial_left, fsaverage.sulc_left
    mesh = surface.load_surf_mesh(pial)
    texture = surface.vol_to_surf(trans_img, mesh)

    plot_kwargs = dict(
        hemi=hemi,
        view=view,
        colorbar=colorbar,
        threshold=threshold,
        bg_map=sulc,
        cmap=color_map,
        vmin=vmin,
        vmax=vmax,
    )
    if interactive:
        plot_kwargs['engine'] = 'plotly'

    fig = plotting.plot_surf_stat_map(mesh, texture, **plot_kwargs)
    return fig

#### Example figure code

In [None]:
# Example: lateral view of latent dimension 1 at -2 standard deviations,
# drawn without a colorbar (right hemisphere by default).
figure = get_path_plot('path_files/dim_1_i_-2.nii.gz', view='lateral', colorbar=False)