In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

class cAAE(nn.Module):
    """Fully connected autoencoder with two extra heads on the latent code.

    Sub-networks (all built from ``nn.Linear`` layers with ReLU in between):

    * ``encoder``:       input_dim -> 1024 -> 512 -> latent_dim (no final activation)
    * ``decoder``:       latent_dim -> 512 -> 1024 -> input_dim, then Sigmoid
    * ``discriminator``: latent_dim -> 256 -> 128 -> 1, then Sigmoid
    * ``classifier``:    latent_dim -> 64 -> 1, then Sigmoid

    NOTE(review): the decoder ends in a Sigmoid, so reconstructions are confined
    to (0, 1). Inputs with values outside that range cannot be reproduced
    exactly -- confirm this matches the scaling of the data fed to the model.
    """

    def __init__(self, input_dim=13456, latent_dim=128):
        super().__init__()

        # Sub-networks are created in a fixed order (encoder, decoder,
        # discriminator, classifier) so parameter names and the order in which
        # random initial weights are drawn stay stable.
        self.encoder = self._dense_stack([input_dim, 1024, 512, latent_dim], sigmoid_out=False)
        self.decoder = self._dense_stack([latent_dim, 512, 1024, input_dim], sigmoid_out=True)
        self.discriminator = self._dense_stack([latent_dim, 256, 128, 1], sigmoid_out=True)
        self.classifier = self._dense_stack([latent_dim, 64, 1], sigmoid_out=True)

    @staticmethod
    def _dense_stack(sizes, sigmoid_out):
        """Chain Linear layers over consecutive ``sizes`` with ReLU between them.

        A Sigmoid is appended after the last Linear layer when ``sigmoid_out``
        is true; otherwise the last Linear layer's output is returned as is.
        """
        layers = []
        final_index = len(sizes) - 2
        for position, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            if position < final_index:
                layers.append(nn.ReLU())
        if sigmoid_out:
            layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def encode(self, x):
        """Map an input batch to its latent code."""
        return self.encoder(x)

    def decode(self, z):
        """Map a latent code back to input space (values in (0, 1))."""
        return self.decoder(z)

    def discriminate(self, z):
        """Score a latent code with the discriminator head (value in (0, 1))."""
        return self.discriminator(z)

    def classify(self, z):
        """Score a latent code with the classifier head (value in (0, 1))."""
        return self.classifier(z)

    def forward(self, x):
        """Encode ``x`` once and return ``(reconstruction, discriminator score, class score)``."""
        latent = self.encode(x)
        return self.decode(latent), self.discriminate(latent), self.classify(latent)


In [2]:
# Training loop for the cAAE
def train_caae(
    model,
    data_loader,
    num_epochs=50,
    learning_rate=0.001,
    classification_weight=5.0,
    adversarial_weight=1.0,
    patience=5,
):
    """Train ``model`` on a weighted sum of reconstruction, classification and adversarial losses.

    Args:
        model: Module whose ``model(inputs)`` call returns the triple
            ``(reconstructed, discriminated, classified)``. The two score
            outputs are fed to ``nn.BCELoss`` and must therefore be
            probabilities in [0, 1].
        data_loader: Iterable yielding ``(inputs, labels)`` batches; ``labels``
            must have the same shape as the classifier output.
        num_epochs: Maximum number of passes over ``data_loader``.
        learning_rate: Learning rate of the single Adam optimizer that updates
            every parameter of ``model``.
        classification_weight: Multiplier of the classification BCE term.
        adversarial_weight: Multiplier of the adversarial BCE term.
        patience: Training stops once this many consecutive epochs fail to
            produce a new best mean *training* loss (no validation set is used).

    Returns:
        The same ``model`` object, trained in place. The weights are those of
        the last executed epoch, not of the best epoch.

    Raises:
        ValueError: If ``data_loader`` yields no batches during an epoch.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    discriminator_loss_fn = nn.BCELoss()
    reconstruction_loss_fn = nn.MSELoss()
    classification_loss_fn = nn.BCELoss()
    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in data_loader:
            optimizer.zero_grad()

            # One forward pass yields every output the losses need; a separate
            # encoder call would only repeat work whose result is never used.
            reconstructed, discriminated, classified = model(inputs)

            # NOTE(review): every sample gets discriminator target 0 and the
            # same optimizer updates encoder and discriminator, so nothing
            # opposes this term. No "real" targets (e.g. prior samples or site
            # labels) reach this function -- confirm the intended adversarial
            # scheme before relying on it.
            fake_labels = torch.zeros_like(discriminated)

            # Losses
            reconstruction_loss = reconstruction_loss_fn(reconstructed, inputs)
            classification_loss = classification_loss_fn(classified, labels)
            adversarial_loss = discriminator_loss_fn(discriminated, fake_labels)

            # Combine losses
            total_loss = (
                reconstruction_loss
                + classification_weight * classification_loss
                + adversarial_weight * adversarial_loss
            )

            # Backpropagation
            total_loss.backward()
            optimizer.step()

            running_loss += total_loss.item()
            num_batches += 1

        if num_batches == 0:
            # Without this guard the mean below fails with a bare ZeroDivisionError.
            raise ValueError("data_loader produced no batches; nothing to train on")

        epoch_loss = running_loss / num_batches
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}")

        # Early stopping on the mean training loss
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered")
                break

    return model


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.model_selection import train_test_split
import glob
import os
from scipy.io import loadmat, savemat
import numpy as np
import pandas as pd

In [4]:
# Read the ABIDE metadata sheet, then keep a narrow view with the subject ID,
# acquisition site, sex and age columns.
meta_data = pd.read_csv('Sheet 1-ABIDE_metadata.csv')
abide_df = meta_data.loc[:, ['Subject', 'Site', 'Sex', 'Age']]

def extract_subject_id(file_path):
    """Return the numeric subject ID encoded in a matrix file name.

    The directory part of ``file_path`` is ignored. The file name is cut at its
    first underscore, the markers 'sub-control' and 'sub-patient' are removed
    from that leading token, and the remainder is parsed with ``int`` -- e.g.
    'sub-control50197_AAL116_correlation_matrix.mat' gives 50197.

    Raises:
        ValueError: If the remaining text is not a valid integer.
    """
    leading_token, _, _ = os.path.basename(file_path).partition('_')
    for group_marker in ('sub-control', 'sub-patient'):
        leading_token = leading_token.replace(group_marker, '')
    return int(leading_token)
# Root directory searched for the connectivity matrices
control_path = '/Users/roshan/Desktop/fMRI/abide/'
# Use glob to search for the specific .mat files recursively
specific_files = glob.glob(os.path.join(control_path, '**', '*_AAL116_correlation_matrix.mat'), recursive=True)

# Group membership is read from the directory names *below* the root only.
# Matching against the absolute directory would put every file in a group
# whenever the root path itself happened to contain 'patient' or 'control'.
relative_dirs = {f: os.path.relpath(os.path.dirname(f), control_path) for f in specific_files}

patient_files = [f for f in specific_files if 'patient' in relative_dirs[f]]
control_files = [f for f in specific_files if 'control' in relative_dirs[f]]

# A file matching both keywords would later be loaded twice, once per label.
ambiguous_files = set(patient_files) & set(control_files)
if ambiguous_files:
    raise ValueError(
        f"{len(ambiguous_files)} file(s) sit under directories matching both "
        f"'patient' and 'control', e.g. {sorted(ambiguous_files)[0]}"
    )

all_files = control_files+patient_files

def load_matrix(file_path):
    """Read a MATLAB ``.mat`` file and return the array stored under its 'data' key."""
    return loadmat(file_path)['data']

data_matrices = []
labels = []
subject_ids = []
site_info = []

# Controls (label 0) are processed first, then autism patients (label 1), so
# row order is: all control files followed by all patient files.
for group_files, group_label in ((control_files, 0), (patient_files, 1)):
    for file_path in group_files:
        subject_id = extract_subject_id(file_path)

        # Look up site information from `abide_df` using subject ID; the first
        # matching row wins if a subject appears more than once.
        site_matches = abide_df.loc[abide_df['Subject'] == subject_id, 'Site'].values
        if len(site_matches) == 0:
            # Same exception type as indexing an empty result, but naming the
            # offending subject and file.
            raise IndexError(
                f"Subject {subject_id} (from {file_path}) has no row in the ABIDE metadata"
            )
        site = site_matches[0]

        # Load the matrix and flatten it into one feature vector per subject
        matrix = load_matrix(file_path)
        data_matrices.append(matrix.flatten())
        labels.append(group_label)
        subject_ids.append(subject_id)
        site_info.append(site)

data_matrices = np.array(data_matrices)
labels = np.array(labels)
site_info = np.array(site_info)

# One-column site table (original note: prepared for ComBat harmonization)
site_df = pd.DataFrame({'site': site_info})
# Independent copy, so later changes to `labels` cannot alter the saved labels
copy_labels = labels.copy()


In [5]:
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hold out 20% of the rows for testing. The row positions are split alongside
# the data so each train/test row can be traced back to its source subject.
X_train, X_test, y_train, y_test, train_indices, test_indices = train_test_split(
    data_matrices,
    labels,
    range(len(labels)),
    test_size=0.2,
    random_state=42,
    shuffle=True,
)

# Features become float32 tensors; labels become float32 column vectors
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32).reshape(-1, 1)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32).reshape(-1, 1)

# Pair features with labels
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# Training batches are reshuffled each epoch; test batches keep their order
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

# Seed PyTorch's global RNG so weight initialisation and the training loader's
# shuffling are repeatable across "Restart & Run All" (same value as the
# train/test split seed).
torch.manual_seed(42)

# Instantiate the model
input_dim = X_train_tensor.shape[1]  # This should be 13456 (116x116 flattened)
latent_dim = 128  # can be changed

model = cAAE(input_dim=input_dim, latent_dim=latent_dim)

# Train the model
trained_model = train_caae(
    model,
    data_loader=train_loader,
    num_epochs=50,
    learning_rate=0.001,
    classification_weight=5.0,  # Adjust based on previous results
    adversarial_weight=1.0      # experimented with changing this
)



Epoch [1/50], Loss: 3.6538
Epoch [2/50], Loss: 3.0521
Epoch [3/50], Loss: 2.3924
Epoch [4/50], Loss: 1.6765
Epoch [5/50], Loss: 1.2382
Epoch [6/50], Loss: 0.5102
Epoch [7/50], Loss: 0.2200
Epoch [8/50], Loss: 0.4146
Epoch [9/50], Loss: 0.4333
Epoch [10/50], Loss: 0.1059
Epoch [11/50], Loss: 0.1767
Epoch [12/50], Loss: 0.4121
Epoch [13/50], Loss: 0.4424
Epoch [14/50], Loss: 0.1121
Epoch [15/50], Loss: 0.1961
Early stopping triggered


In [6]:
# Harmonize the test data: run the held-out rows through the model in eval
# mode without gradient tracking and keep the decoder reconstruction.
model.eval()
with torch.no_grad():
    harmonized_test_data, _, _ = model(X_test_tensor)
    harmonized_test_data = harmonized_test_data.numpy()

# Restore the square matrix layout. The side length is derived from the feature
# count (116 for 13456 features) instead of being hard-coded, and a feature
# count that is not a perfect square is rejected rather than mis-reshaped.
n_features = harmonized_test_data.shape[1]
matrix_side = int(round(n_features ** 0.5))
if matrix_side * matrix_side != n_features:
    raise ValueError(
        f"Cannot reshape {n_features} features into a square matrix"
    )
harmonized_matrices = harmonized_test_data.reshape(-1, matrix_side, matrix_side)

# Save the harmonized matrices, one .mat file per test subject
output_dir = '/Users/roshan/Desktop/fMRI/output/caae1/'
os.makedirs(output_dir, exist_ok=True)

# Row i of `harmonized_matrices` belongs to original row `test_indices[i]`;
# refuse to write files if the two do not line up.
if len(test_indices) != len(harmonized_matrices):
    raise ValueError(
        f"{len(harmonized_matrices)} harmonized matrices for {len(test_indices)} test indices"
    )

file_paths = []
records = []  # one CSV row per saved file, collected in the same pass

for i, idx in enumerate(test_indices):
    subject_id = subject_ids[idx]
    label = copy_labels[idx]
    site = site_info[idx]

    prefix = 'sub-control' if label == 0 else 'sub-patient'
    file_name = f'{prefix}{subject_id}_harmonized.mat'
    file_path = os.path.join(output_dir, file_name)

    savemat(file_path, {'data': harmonized_matrices[i]})
    file_paths.append(file_path)
    records.append({
        'file_path': file_path,
        'subject_id': subject_id,
        'autism_label': label,
        'site': site,
    })

# The index CSV lives next to the matrices; building its path from `output_dir`
# keeps the two locations from drifting apart.
output_csv = os.path.join(output_dir, 'harmonized_data_info.csv')
csv_df = pd.DataFrame(records, columns=['file_path', 'subject_id', 'autism_label', 'site'])

csv_df.to_csv(output_csv, index=False)
print(f"Harmonized data info saved to: {output_csv}")


Harmonized data info saved to: /Users/roshan/Desktop/fMRI/output/caae1/harmonized_data_info.csv
