## Latent graph code and KNeighbors

In [1]:
import torch
import torch.nn as nn #nn = neural network layer
import torch.nn.functional as F # access to nn functions
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np
from sklearn.neighbors import NearestNeighbors, KNeighborsClassifier
from tqdm import tqdm
from sklearn.decomposition import PCA
import nilearn.image
import nilearn.plotting
from sklearn.metrics import roc_auc_score, balanced_accuracy_score, accuracy_score, roc_curve, auc
import matplotlib.pyplot as plt
import umap.umap_ as umap
from sklearn.preprocessing import LabelEncoder, label_binarize
import matplotlib.patches as mpatches
from sklearn.model_selection import train_test_split
import matplotlib.lines as mlines
# Print the installed PyTorch version.
print(torch.__version__)

2.3.0


In [None]:
# Use the MPS backend when torch reports it as available; otherwise use the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

#### Load unlabeled training data and diagnosis-labeled test data

In [None]:
# Label tables: the first CSV column becomes the DataFrame index.
train_csv, test_csv, nd_csv = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        'model_data/train_brain_labels_deid.csv',
        'model_data/test_brain_labels_deid.csv',
        'model_data/nd_brain_labels_deid.csv',
    )
)

# Matching scaled brain-data tables, one parquet file per label table.
train_parquet_scaled, test_parquet_scaled, nd_parquet_scaled = (
    pd.read_parquet(parquet_path)
    for parquet_path in (
        'model_data/train_brain_data_deid_scaled.parquet',
        'model_data/test_brain_data_deid_scaled.parquet',
        'model_data/nd_brain_data_deid_scaled.parquet',
    )
)

#### Input shortened labeled dataset

Retain only data with diagnosis of ad, bvftd, cu, or dlb, and age and sex information

In [None]:
# Labels of the shortened (filtered) dataset; the first CSV column is the index.
nd_csv_short = pd.read_csv('model_data/nd_filtered_data.csv', index_col=0)

#### Load brain mask and prepare it

In [None]:
# Load the brain mask image and locate its non-zero entries.
brain_mask = nilearn.image.load_img('model_data/brain_mask.nii')
nz_indices = np.ma.nonzero(brain_mask.get_fdata())

# Zero-filled array with the same shape as the mask.
img_data = np.zeros(shape=brain_mask.shape)

# First cell of the scaled training table (row 0, column 0).
vec = train_parquet_scaled.iloc[0, 0]

#### Hyperparameters and k matrix initialization

In [None]:
# Mini-batch size used by every DataLoader in this notebook.
BATCH_SIZE = 16
radius = 2
# Load the k matrix onto the CPU regardless of the device it was saved from:
# GraphConv.forward calls self.k.cpu() before the matrix product, and an
# explicit map_location lets the file load on machines that lack the
# original device.
k_normalized = torch.load('model_data/tensor_k_norm.pt', map_location='cpu')

#### Custom dataset

In [None]:
class CustomLoader(Dataset):
    """Dataset that serves the rows of a DataFrame as float32 tensors.

    Args:
        images: table read with positional ``.iloc``; one row per sample.
        transform: optional callable applied to each tensor before it is
            returned.
    """

    def __init__(self, images, transform=None):
        super().__init__()

        self.images = images  # parquet
        self.transform = transform
        self.n_samples = self.images.shape[0]

    def __len__(self):
        """Return the number of rows in the table."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return row ``idx`` as a float32 tensor with a leading axis of size 1."""
        row = np.array(self.images.iloc[idx].copy())
        tensor = torch.from_numpy(row).to(torch.float32).unsqueeze(0)
        if self.transform is None:
            return tensor
        return self.transform(tensor)

#### Load datasets using the custom dataset loader

In [None]:
# Training data: shuffled every epoch.
train_data = CustomLoader(images=train_parquet_scaled, transform=None)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, num_workers=0, shuffle=True)

# Test data: served in file order (shuffle=False), like nd_loader below, so
# batches stay aligned with the rows of test_csv, which is used as the index
# for the test embeddings later in this notebook.
test_data = CustomLoader(images=test_parquet_scaled, transform=None)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, num_workers=0, shuffle=False)

# Labeled data: served in file order.
nd_data = CustomLoader(images=nd_parquet_scaled, transform=None)
nd_loader = DataLoader(dataset=nd_data, batch_size=BATCH_SIZE, num_workers=0, shuffle=False)

#### Graph convolution method

In [264]:
class GraphConv(nn.Module):
    """Kernel-size-1 ``Conv1d`` followed by a matrix product with ``k``.

    Args:
        in_dim: number of input channels of the convolution.
        out_dim: number of output channels of the convolution.
        k: matrix that right-multiplies the convolved features; it is stored
            as a plain attribute, not as a parameter or buffer.
    """

    def __init__(self, in_dim, out_dim, k):
        super().__init__()
        self.d1 = nn.Conv1d(in_dim, out_dim, kernel_size=1)
        self.k = k

    def forward(self, imgs):
        """Return ``d1(imgs) @ k`` on the same device as ``imgs``."""
        projected = self.d1(imgs)
        # The matrix product is carried out on the CPU; the result is then
        # moved back to the device the input lives on.
        mixed = projected.cpu() @ self.k.cpu()
        return mixed.to(imgs.device)

#### Graph convolutional network block

In [470]:
class GCNBlock(nn.Module):
    """Two graph convolutions with batch normalisation and a skip connection.

    Args:
        in_dim: channel count of the block input.
        out_dim: channel count produced by both graph convolutions.
        k: matrix handed to each ``GraphConv``.
    """

    def __init__(self, in_dim, out_dim, k):
        super().__init__()
        self.k = k
        self.GC1 = GraphConv(in_dim, out_dim, self.k)
        self.GC2 = GraphConv(out_dim, out_dim, self.k)
        self.bn1 = nn.BatchNorm1d(out_dim)
        self.bn2 = nn.BatchNorm1d(out_dim)

    def forward(self, x):
        """Apply GC1 -> bn1 -> tanh -> GC2 -> bn2 and add the skip term."""
        # The skip term is the mean of the input over axis 1, kept as a
        # size-1 axis so it broadcasts across the output channels.
        skip = x.mean(axis=1, keepdim=True)
        hidden = F.tanh(self.bn1(self.GC1(x)))
        return self.bn2(self.GC2(hidden)) + skip

#### Encoder of the Variational autoencoder

In [471]:
class Encoder(nn.Module):
    """Encoder of the variational autoencoder.

    A ``GCNBlock`` is followed by a linear layer and two linear heads: ``Lin2``
    produces ``mu`` and ``Lin3`` produces ``log(sigma)``.

    Args:
        in_dim: channel count of the input.
        conv_dim: channel count produced by the graph-convolution block.
        lin_dim: width of the hidden linear layer.
        latent_dims: size of the latent vector.
        k: matrix handed to the graph convolutions; ``k.shape[0]`` sizes the
            first linear layer.
        device: device the input and the sampled noise are moved to.

    Attributes set by ``forward``:
        KL: summed KL-style penalty of the most recent forward pass.
    """

    def __init__(self, in_dim, conv_dim, lin_dim, latent_dims, k, device):
        super(Encoder, self).__init__()
        self.k = k
        self.device = device
        self.GCBlock = GCNBlock(in_dim, conv_dim, self.k)

        self.Lin1 = nn.Linear(conv_dim * self.k.shape[0], lin_dim)
        self.Lin2 = nn.Linear(lin_dim, latent_dims)
        self.Lin3 = nn.Linear(lin_dim, latent_dims)

        # Standard normal used for the reparameterisation noise.
        self.N = torch.distributions.Normal(0, 1)
        self.KL = 0.0

    def forward(self, x):
        """Encode ``x`` and return a sampled latent vector ``z``.

        NOTE(review): ``z`` is sampled (``mu + sigma * noise``) in eval mode as
        well, so repeated calls on the same input give different outputs.
        """
        x = x.to(self.device)
        x = self.GCBlock(x)
        x = F.tanh(x)

        x = x.flatten(start_dim=1)
        x = self.Lin1(x)
        x = F.tanh(x)

        mu = self.Lin2(x)
        # Lin3 outputs log(sigma). Keep it and use it directly in the penalty
        # instead of computing log(exp(.)): that round trip yields -inf (and
        # an infinite penalty) once exp() underflows to zero.
        log_sigma = self.Lin3(x)
        sigma = torch.exp(log_sigma)
        z = mu + sigma * self.N.sample(mu.shape).to(self.device)
        # NOTE(review): the closed-form KL(N(mu, sigma^2) || N(0, 1)) is
        # 0.5 * (sigma^2 + mu^2 - 1) - log(sigma); this expression weights the
        # terms differently. The weighting is left as-is because the training
        # code that reads self.KL is not in this file - confirm before changing.
        self.KL = (sigma**2 + mu**2 - log_sigma - 1/2).sum()

        return z

#### Decoder of the variational autoencoder

In [472]:
class Decoder(nn.Module):
    """Decoder of the variational autoencoder.

    Two linear layers (each followed by tanh) expand the latent vector, the
    result is unflattened to ``(conv_dim, k.shape[0])`` and passed through a
    ``GCNBlock`` that maps ``conv_dim`` channels back to ``in_dim``.

    Args:
        in_dim: channel count of the reconstructed output.
        conv_dim: channel count fed into the graph-convolution block.
        lin_dim: width of the hidden linear layer.
        latent_dims: size of the latent vector.
        k: matrix handed to the graph convolutions.
    """

    def __init__(self, in_dim, conv_dim, lin_dim,  latent_dims, k):
        super().__init__()
        self.k = k
        k_rows = self.k.shape[0]

        self.decoder_lin = nn.Sequential(
            nn.Linear(latent_dims, lin_dim),
            nn.Tanh(),
            nn.Linear(lin_dim, conv_dim * k_rows),
            nn.Tanh(),
        )
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(conv_dim, k_rows))
        self.decoder_conv = GCNBlock(conv_dim, in_dim, self.k)

    def forward(self, x):
        """Map latent vectors ``x`` to the output of the final ``GCNBlock``."""
        expanded = self.decoder_lin(x)
        return self.decoder_conv(self.unflatten(expanded))

#### Variational Autoencoder class

In [473]:
class VAE(nn.Module):
    """Variational autoencoder that chains ``Encoder`` and ``Decoder``.

    Args:
        in_dim: channel count of the input and of the reconstruction.
        conv_dim: channel count inside the graph-convolution blocks.
        lin_dim: width of the hidden linear layers.
        latent_dims: size of the latent vector.
        k: matrix shared by the encoder and the decoder.
        device: device the input is moved to before encoding.
    """

    def __init__(self, in_dim, conv_dim, lin_dim, latent_dims, k, device):
        super().__init__()
        self.device = device
        self.k = k
        self.encoder = Encoder(in_dim, conv_dim, lin_dim, latent_dims, self.k, self.device)
        self.decoder = Decoder(in_dim, conv_dim, lin_dim, latent_dims, self.k)

    def forward(self, x):
        """Encode ``x`` to a latent sample and return its decoding."""
        latent = self.encoder(x.to(self.device))
        reconstruction = self.decoder(latent)
        return reconstruction

#### Optimizer and variational autoencoder initialization and hyperparameters

In [474]:
# Seed torch's RNG before the model is built so the initial weights repeat.
torch.manual_seed(0)

# Model hyperparameters, spelled out by keyword.
latent_dims = 16
vae = VAE(
    in_dim=1,
    conv_dim=8,
    lin_dim=256,
    latent_dims=latent_dims,
    k=k_normalized,
    device=device,
)
vae.to(device)

total_epochs = 0

# Adam over all model parameters.
lr = 1e-5
optim = torch.optim.Adam(params=vae.parameters(), lr=lr)

#### Load trained variational autoencoder model and optimizer

In [None]:
# map_location=device remaps the saved tensors onto the device selected
# above, so a checkpoint written on an MPS machine still loads when this
# notebook falls back to the CPU (and vice versa).
vae.load_state_dict(torch.load('model_data/2gc_16dim_model_scaled.pt', map_location=device))
# Inference mode: BatchNorm layers use their running statistics.
vae.eval()
optim.load_state_dict(torch.load('model_data/2gc_16dim_optim_scaled.pt', map_location=device))

#### Encode train dataset

In [None]:
# Encode every training sample and save the embeddings, indexed like train_csv.
# NOTE(review): vae.encoder returns a sampled z (mu + sigma * noise), so the
# saved embeddings change from run to run unless the RNG state is fixed.
vae.eval()  # loop-invariant: switch to eval mode once, not once per sample
train_encoded_samples = []
with torch.no_grad():
    for sample in tqdm(train_data):
        # Add a batch axis of size 1 before encoding.
        img = sample.unsqueeze(0).to(device)
        encoded_img = vae.encoder(img).flatten().cpu().numpy()
        encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
        train_encoded_samples.append(encoded_sample)
train_encoded_samples = pd.DataFrame(train_encoded_samples, index=train_csv.index)
train_encoded_samples.to_csv('model_data/16dim_train_embeddings_scaled.csv')

#### Encode test dataset

In [None]:
# Encode every test sample and save the embeddings, indexed like test_csv.
# NOTE(review): vae.encoder returns a sampled z (mu + sigma * noise), so the
# saved embeddings change from run to run unless the RNG state is fixed.
vae.eval()  # loop-invariant: switch to eval mode once, not once per sample
test_encoded_samples = []
with torch.no_grad():
    for sample in tqdm(test_data):
        # Add a batch axis of size 1 before encoding.
        img = sample.unsqueeze(0).to(device)
        encoded_img = vae.encoder(img).flatten().cpu().numpy()
        encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
        test_encoded_samples.append(encoded_sample)
test_encoded_samples = pd.DataFrame(test_encoded_samples, index=test_csv.index)
test_encoded_samples.to_csv('model_data/16dim_test_embeddings_scaled.csv')

#### Encode labeled dataset

In [None]:
# Encode every labeled sample; the list is turned into a DataFrame in the
# next cell.
# NOTE(review): vae.encoder returns a sampled z (mu + sigma * noise), so the
# embeddings change from run to run unless the RNG state is fixed.
vae.eval()  # loop-invariant: switch to eval mode once, not once per sample
nd_encoded_samples = []
with torch.no_grad():
    for sample in tqdm(nd_data):
        # Add a batch axis of size 1 before encoding.
        img = sample.unsqueeze(0).to(device)
        encoded_img = vae.encoder(img).flatten().cpu().numpy()
        encoded_sample = {f"Enc. Variable {i}": enc for i, enc in enumerate(encoded_img)}
        nd_encoded_samples.append(encoded_sample)

#### Keep only the samples that are in the shortened dataset

In [None]:
# Full embedding table, indexed like nd_csv, written out as-is.
nd_encoded_samples = pd.DataFrame(nd_encoded_samples, index=nd_csv.index)
nd_encoded_samples.to_csv('model_data/model_16dim_full_embeddings_scaled.csv')

# Boolean mask of the rows whose index also appears in the shortened label
# table; the selected rows keep the order they have in nd_csv.
common_indices = nd_encoded_samples.index.isin(nd_csv_short.index)
nd_encoded_samples_short = nd_encoded_samples.loc[common_indices]
nd_encoded_samples_short.to_csv('model_data/model_16dim_short_test_embeddings_scaled.csv')

#### Remove sex and age label

In [None]:
# Keep age plus the four diagnosis columns of the shortened label table.
desired_columns = ['age_at_scan', 'ad', 'bvftd', 'cu', 'dlb']
nd_csv_nsex = nd_csv_short[desired_columns].copy()
# nd_csv_nsex contains exactly desired_columns, so there are never any
# "removed" columns to filter rows on; the row-filter loop that used to follow
# could not execute and has been dropped. No rows are removed here.
nd_removed_columns = []
# Same table without the age column.
nd_csv_nage = nd_csv_nsex.drop(columns='age_at_scan')

#### Use tools from Neurology-AI-Program profile, svlite repository

Code from: https://github.com/Neurology-AI-Program/svlite

In [None]:
from svlite.svlite.data_structures import VectorTable, AnnotationTable
from svlite.svlite.graphical import KNeighbors
import networkx as nx

In [None]:
# nd_encoded_samples_short follows the row order of nd_csv, whereas
# nd_csv_nsex follows the row order of nd_csv_short. Select the embedding rows
# by label so vectors and annotations are paired by sample ID, not by position
# (a missing ID raises KeyError instead of silently mispairing rows).
aligned_vectors = nd_encoded_samples_short.loc[nd_csv_nsex.index]
vector_table = VectorTable(aligned_vectors.values, index_col=nd_csv_nsex.index)
ann_table = AnnotationTable(nd_csv_nsex)
# 8-nearest-neighbour graph with cosine distance, exported as GEXF.
knn_latent = KNeighbors(n_neighbors=8, metric='cosine')
knn_latent.populate(vector_table, ann_table)
nx.write_gexf(knn_latent, 'Graphs/vae_scaled_filtered.gexf')

#### Full labels version

Include: ['age_at_scan', 'ad', 'bvftd', 'cbs', 'cu', 'dlb', 'lvppa', 'nfppa', 'pca', 'ppaos', 'psp', 'sd']

In [None]:
# Keep age plus every diagnosis column of the full label table.
desired_columns = ['age_at_scan', 'ad', 'bvftd', 'cbs', 'cu', 'dlb', 'lvppa', 'nfppa', 'pca', 'ppaos', 'psp', 'sd']
nd_csv_nsex_full = nd_csv[desired_columns].copy()
# nd_csv_nsex_full contains exactly desired_columns, so there are never any
# "removed" columns to filter rows on; the row-filter loop that used to follow
# could not execute and has been dropped. No rows are removed here.
nd_removed_columns = []

In [None]:
# Wrap all labeled embeddings and the full label table in svlite containers.
# Both come from nd_csv's index, so rows are in the same order.
vector_table_full = VectorTable(nd_encoded_samples.values, index_col=nd_csv_nsex_full.index)
ann_table_full = AnnotationTable(nd_csv_nsex_full)
# 8-nearest-neighbour graph with cosine distance, exported as GEXF.
knn_latent_full = KNeighbors(n_neighbors=8, metric='cosine')
knn_latent_full.populate(vector_table_full, ann_table_full)
nx.write_gexf(knn_latent_full, 'Graphs/vae_scaled_full.gexf')

### KNeighbors from svlite and ROC curves

In [None]:
# Separate KNeighbors object built from the same tables and settings as
# knn_latent_full above.
knn_sv = KNeighbors(n_neighbors=8, metric='cosine')
knn_sv.populate(vector_table_full, ann_table_full)

# Neighbour-vote scores from svlite: odds ratios and one-sided ("greater")
# Fisher results. NOTE(review): presumably one column per annotation, since
# odds[pt] is indexed by label name in the ROC cell - confirm against svlite.
odds = knn_sv.neighbor_votes(metric='odds_ratio')
pvals = knn_sv.neighbor_votes(metric='fisher', fisher_alternative='greater')

In [None]:
# Plot colour for each diagnosis label, plus 'total'.
node_color_map = dict(
    ad="#006400",
    bvftd="#FF7F50",
    cbs='pink',
    cu="#838B8B",
    dlb="#ffd700",
    lvppa="#00ff00",
    nfppa="#8B4513",
    pca="#00ffff",
    ppaos="#ff1493",
    psp="#a020f0",
    sd="#1e90ff",
    total='black',
)

In [None]:
# Reset to matplotlib's default style and use one font size throughout.
plt.style.use('default')
plt.rcParams.update({
    'axes.labelsize': 12,
    'axes.titlesize': 12,
    'legend.fontsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
})

col_order = ['cu', 'dlb', 'pca', 'ad', 'lvppa', 'sd', 'bvftd', 'nfppa', 'cbs', 'psp', 'total']
degen_cols = [c for c in col_order if c not in ('cu', 'total')]

fig, ax = plt.subplots(figsize=(6, 6))
handles = []
# One ROC curve per label; the trailing 'total' entry is skipped.
for pt in col_order[:-1]:
    # The odds-ratio votes act as the score, the binary annotation as truth.
    fpr, tpr, _ = roc_curve(ann_table_full.data[pt], odds[pt])
    score = auc(fpr, tpr)
    curve_color = node_color_map[pt]
    ax.plot(fpr, tpr, c=curve_color, lw=2, ls='-')
    # Legend entry: a filled dot in the curve colour, labelled with the AUC.
    handle = mlines.Line2D(
        [], [],
        color=curve_color,
        marker='o',
        fillstyle='full',
        label=f'{pt} (AUC={np.round(score, 2)})',
        linestyle="None",
        markersize=10,
    )
    handles.append(handle)

# Chance diagonal.
ax.plot([0, 1], [0, 1], ls='--', c='black', lw=2)
ax.set_xlabel('False positive rate')
ax.set_ylabel('True positive rate')
ax.legend(
    handles=handles,
    bbox_to_anchor=(1.05, 0.00),
    loc=4,
    shadow=True,
    frameon=True,
    fontsize=10,
    ncols=1,
    edgecolor='black',
)

#### Neighborhood graph visualization

In [None]:
# Use the first sample in the vector table as the centre node.
selected_node = vector_table_full.data.index[0]
# Subgraph returned by svlite's neighborhood_view for that node.
neighborhood_subgraph = knn_latent_full.neighborhood_view(selected_node)

In [None]:
# Label each node with the first binary annotation that equals 1, or
# 'unlabeled' when none does.
node_labels = {}
for node in neighborhood_subgraph.nodes:
    flags = ann_table_full.data.loc[node, ann_table_full.binary_annotation_cols]
    positive = flags[flags == 1].index
    node_labels[node] = next(iter(positive), 'unlabeled')

# Translate each node's 'node_color' attribute through node_color_map, with
# gray as the fallback for unknown values.
colors = [
    node_color_map.get(knn_latent_full.nodes[node]['node_color'], 'gray')
    for node in neighborhood_subgraph.nodes
]

nx.draw(
    neighborhood_subgraph,
    labels=node_labels,
    with_labels=True,
    node_color=colors,
    edge_color='gray',
)
plt.show()

### Principal Component Analysis (PCA) comparison

In [None]:
# 16-component whitened PCA, matching the VAE's 16 latent dimensions.
# random_state is fixed so that the randomized SVD solver (which scikit-learn
# may select automatically for large inputs) gives the same components on
# every run, in line with the torch.manual_seed(0) used for the VAE.
pca = PCA(n_components=16, whiten=True, random_state=0)
# Fit on the unlabeled training data, then project the labeled data.
pca_train_res = pca.fit_transform(train_parquet_scaled)
pca_res = pca.transform(nd_parquet_scaled)

In [None]:
# Same graph construction as for the VAE embeddings, but on the PCA
# projection of the labeled data.
vector_table_pca = VectorTable(pca_res, index_col=nd_csv_nsex_full.index)
ann_table_pca = AnnotationTable(nd_csv_nsex_full)
# 8-nearest-neighbour graph with cosine distance, exported as GEXF.
knn_latent_pca = KNeighbors(n_neighbors=8, metric='cosine')
knn_latent_pca.populate(vector_table_pca, ann_table_pca)
nx.write_gexf(knn_latent_pca, 'Graphs/pca_scaled_full.gexf')

In [None]:
# Neighbour-vote scores for the PCA graph. This rebinds `odds` and `pvals`,
# replacing the values computed for the VAE graph.
odds = knn_latent_pca.neighbor_votes(metric = 'odds_ratio')
pvals = knn_latent_pca.neighbor_votes(metric = 'fisher', fisher_alternative = 'greater')

In [None]:
# Reset to matplotlib's default style and use one font size throughout.
plt.style.use('default')
plt.rcParams.update({
    'axes.labelsize': 12,
    'axes.titlesize': 12,
    'legend.fontsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
})

col_order = ['cu', 'dlb', 'pca', 'ad', 'lvppa', 'sd', 'bvftd', 'nfppa', 'cbs', 'psp', 'total']
degen_cols = [c for c in col_order if c not in ('cu', 'total')]

fig, ax = plt.subplots(figsize=(6, 6))
handles = []
# One ROC curve per label; the trailing 'total' entry is skipped.
for pt in col_order[:-1]:
    # The odds-ratio votes act as the score, the binary annotation as truth.
    fpr, tpr, _ = roc_curve(ann_table_pca.data[pt], odds[pt])
    score = auc(fpr, tpr)
    curve_color = node_color_map[pt]
    ax.plot(fpr, tpr, c=curve_color, lw=2, ls='-')
    # Legend entry: a filled dot in the curve colour, labelled with the AUC.
    handle = mlines.Line2D(
        [], [],
        color=curve_color,
        marker='o',
        fillstyle='full',
        label=f'{pt} (AUC={np.round(score, 2)})',
        linestyle="None",
        markersize=10,
    )
    handles.append(handle)

# Chance diagonal.
ax.plot([0, 1], [0, 1], ls='--', c='black', lw=2)
ax.set_xlabel('False positive rate')
ax.set_ylabel('True positive rate')
ax.legend(
    handles=handles,
    bbox_to_anchor=(1.05, 0.00),
    loc=4,
    shadow=True,
    frameon=True,
    fontsize=10,
    ncols=1,
    edgecolor='black',
)