In [1]:
import torch
from tqdm import tqdm
import sys
sys.path.append('..')

from models.sfcn_original import SFCN
from utils.datasets import TorchDataset as TD

In [2]:
mode = 'val'
test_block = 'flat'

In [3]:
# Build the SFCN with the channel layout of the PD checkpoint, move it to the
# GPU and restore the trained weights (last expression echoes the key-match report).
checkpoint = torch.load('checkpoints/PD-SFCN/best_model.pt')
model = SFCN(output_dim=1, channel_number=[28, 58, 128, 256, 256, 64])
model = model.to('cuda')
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [4]:
# Short name (selected via `test_block`) -> submodule whose forward output is hooked.
block_dict = dict(flat=model.classifier.flatten)

In [5]:
import torch
import numpy as np
from sklearn.decomposition import PCA
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

model.eval()

# Buffers filled during the pass below; reset on every run of this cell.
activations = []   # one array per batch, appended by `hook_fn`
PD = []
sex = []
study = []
scanner_type = []


def hook_fn(module, input, module_output):
    """Forward hook: append the hooked module's output to `activations` as NumPy.

    Registered on `block_dict[test_block]` below, so it fires on every forward
    pass through that submodule.
    """
    activations.append(module_output.detach().cpu().numpy())


td = TD(f'/data/Data/PD/{mode}')
test_loader = DataLoader(td, batch_size=8, shuffle=False)

# Register the hook only once the loader exists, and always remove it: a hook
# left on the module (e.g. after an exception mid-loop) keeps firing on later
# runs of this cell and appends duplicate activations to the global list.
hook = block_dict[test_block].register_forward_hook(hook_fn)
try:
    with torch.no_grad():
        for batch in tqdm(test_loader):
            # batch fields as consumed here: (image, PD, sex, study, scanner_type)
            _ = model(batch[0].to('cuda'))
            PD.extend(batch[1])
            sex.extend(batch[2])
            study.extend(batch[3])
            scanner_type.extend(batch[4])
finally:
    hook.remove()

# One hooked output per batch is expected; a mismatch means the module was
# hooked more than once or is called several times per forward pass.
assert len(activations) == len(test_loader), (
    f'expected {len(test_loader)} hooked outputs, got {len(activations)}'
)

100%|██████████| 11/11 [00:03<00:00,  2.96it/s]


In [6]:
def to_onehot(labels, num_classes=None):
    """Convert a 1-D sequence of integer class labels into a one-hot matrix.

    Parameters
    ----------
    labels : array-like
        1-D sequence of non-negative integer class indices. Integral floats
        (e.g. ``1.0``) are accepted; non-integral values are rejected.
    num_classes : int, optional
        Number of columns. If ``None`` it is inferred as ``max(labels) + 1``
        (0 for empty input). Pass it explicitly when encoding several data
        splits, otherwise a split that lacks the highest label gets a
        narrower matrix than the others.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(len(labels), num_classes)`` with a single 1
        in each row.

    Raises
    ------
    ValueError
        If a label is not an integer value or lies outside ``[0, num_classes)``.
    """
    labels = np.asarray(labels)
    int_labels = labels.astype(np.int64)
    # Refuse to silently truncate values such as 1.5 (or NaN) to a class index.
    if not np.array_equal(int_labels, labels):
        raise ValueError('labels must be integer class indices')

    if num_classes is None:
        num_classes = int(int_labels.max()) + 1 if int_labels.size else 0

    # Negative indices would otherwise wrap around to the last columns silently.
    if int_labels.size and (int_labels.min() < 0 or int_labels.max() >= num_classes):
        raise ValueError(f'labels must lie in [0, {num_classes})')

    onehot = np.zeros((len(int_labels), num_classes))
    onehot[np.arange(len(int_labels)), int_labels] = 1
    return onehot

In [7]:
# Stack the per-batch hook outputs into one (n_samples, n_features) matrix.
# No `.squeeze()`: it would drop the sample axis for a single-scan split (and
# the feature axis when there is one feature), breaking the per-sample loops
# and the 2-D input PCA expects.
activations = np.vstack(activations).reshape(len(td), -1)
PD = np.array(PD)
sex = np.array(sex)
# NOTE(review): the one-hot width is inferred from this split's labels, so a
# split missing the highest study / scanner id gets fewer columns than train,
# which shifts every later column in the saved vectors.
# TODO: pass `num_classes` fixed from the training split.
study = to_onehot(np.array(study))
scanner_type = to_onehot(np.array(scanner_type))

In [8]:
# Sanity check: one PD label per scan in the split.
np.shape(PD)

(84,)

In [9]:
import os

act_save_dir = f'/data/Data/PD/activations_{test_block}_{mode}'
os.makedirs(act_save_dir, exist_ok=True)

# One float32 tensor per scan, saved under its index:
# [study one-hot, sex, scanner one-hot, PD, raw activations].
for i, act_vec in enumerate(activations):
    row = np.hstack([study[i], sex[i], scanner_type[i], PD[i], act_vec])
    ten2 = torch.tensor(row, dtype=torch.float)
    torch.save(ten2, os.path.join(act_save_dir, str(i)))

In [10]:
import pickle


# The PCA is fitted on the training split only; every other split re-uses that
# projection so all splits live in the same component space.
# The path must be an f-string: previously the literal name
# 'pca_{test_block}_train.pkl' was used for every block, so PCAs for different
# blocks overwrote each other. Re-run with mode='train' to regenerate the file.
pca_path = f'/data/Data/PD/PCAs/pca_{test_block}_train.pkl'

if mode == 'train':
    print('Fitting PCA')
    pca = PCA(n_components=64)
    pca_result = pca.fit_transform(activations)

    os.makedirs(os.path.dirname(pca_path), exist_ok=True)
    with open(pca_path, 'wb') as pca_file:
        pickle.dump(pca, pca_file)
else:
    # pickle.load can execute arbitrary code: only load files this pipeline wrote.
    with open(pca_path, 'rb') as pca_file:
        pca = pickle.load(pca_file)

    pca_result = pca.transform(activations)

In [11]:
import os

save_dir = f'/data/Data/PD/pca_{test_block}_{mode}'
os.makedirs(save_dir, exist_ok=True)

# One float32 tensor per scan, saved under its index:
# [study one-hot, sex, scanner one-hot, PD, PCA components].
for i, pca_vec in enumerate(pca_result):
    row = np.hstack([study[i], sex[i], scanner_type[i], PD[i], pca_vec])
    ten = torch.tensor(row, dtype=torch.float)
    torch.save(ten, os.path.join(save_dir, str(i)))