In [None]:
# Model testing on our own dataset

In [1]:

import sys
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Add TSAN-brain-age-estimation to the Python path
sys.path.append(os.path.abspath("./TSAN-brain-age-estimation"))

# Import the ScaleDense module
from TSAN.model.ScaleDense import ScaleDense
from TSAN.model.ScaleDense import get_parameter_number
from TSAN.model.Second_stage_ScaleDense import second_stage_scaledense
import os
import torch
import nibabel as nib
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader

import os
import numpy as np
import nibabel as nib
import pandas as pd
import torch

In [2]:
# Assuming these are defined elsewhere in your code
def nii_loader(path):
    img = nib.load(str(path))
    data = img.get_fdata()
    return data

def read_table(path):
    return pd.read_excel(path).values  # default to first sheet

def white0(image, threshold=0):
    image = image.astype(np.float32)
    mask = (image > threshold).astype(int)
    image_h = image * mask
    image_l = image * (1 - mask)
    mean = np.sum(image_h) / np.sum(mask)
    std = np.sqrt(np.sum(np.abs(image_h - mean)**2 * mask) / np.sum(mask))
    if std > 0:
        ret = (image_h - mean) / std + image_l
    else:
        ret = image * 0.
    return ret

# Dataset class to load images and labels
class IMG_Folder(torch.utils.data.Dataset):
    def __init__(self, excel_path, data_path, loader=nii_loader, transforms=None):
        self.root = data_path
        self.sub_fns = sorted(os.listdir(self.root))
        self.table_refer = read_table(excel_path)
        self.loader = loader
        self.transform = transforms

    def __len__(self):
        return len(self.sub_fns)

    def __getitem__(self, index):
        sub_fn = self.sub_fns[index]
                
        # Skip non-NIfTI files such as .xlsx
        if not sub_fn.endswith('.nii'):
            return None  # Skip non-NIfTI files like .xlsx or others
        
        
        img = None  # Initialize img to None in case no match is found
        # Remove file extension and make case-insensitive for comparison
        sub_fn_clean = sub_fn.replace(".nii", "").strip().lower()
        
        print(f"Looking for image: {sub_fn_clean}")  # Debugging line
        for f in self.table_refer:
            sid = str(f[0]).strip().lower()  # Clean sid from table
            slabel = int(f[1])
            smale = f[2]
            
            print(f"Checking against subject ID: {sid}")  # Debugging line
            
            # Match cleaned sid to the filename
            if sid not in sub_fn_clean:
                continue
            
            sub_path = os.path.join(self.root, sub_fn)
            print(f"Loading from: {sub_path}")  # Debugging line
            
            img = self.loader(sub_path)  # Now img is assigned here
            img = white0(img)  # Standardize the image
            if self.transform is not None:
                img = self.transform(img)
            img = np.expand_dims(img, axis=0)
            img = np.ascontiguousarray(img, dtype=np.float32)
            img = torch.from_numpy(img).type(torch.FloatTensor)
            break
        
        # Check if img was never assigned, raise an error or handle as needed
        if img is None:
            raise ValueError(f"Image for subject {sub_fn} not found in table references.")
        
        return img, sid, slabel, smale


# First part of the model 

In [3]:

# Initialize the model

model = ScaleDense(nb_filter=8, nb_block=5, use_gender=True)

# Check model parameters
print(model)
print("Parameter count:", get_parameter_number(model))

#Load the model weights

# Path to the model weights
weights_path = r"C:\Users\Vero Ramirez\Documents\GitHub\sMRI_analysis\weights\ScaleDense\ScaleDense_best_model.pth.tar"

# Try loading the model weights onto the GPU
try:
    # Check if a GPU is available and load accordingly
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    checkpoint = torch.load(weights_path, map_location=device, weights_only=True)
    print("Weights loaded successfully to", device)
except Exception as e:
    print(f"Error loading weights: {e}")


ScaleDense(
  (pre): Sequential(
    (0): Conv3d(1, 8, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(1, 1, 1), dilation=(2, 2, 2))
    (1): ELU(alpha=1.0)
  )
  (block): Sequential(
    (0): dense_layer(
      (block): Sequential(
        (0): AC_layer(
          (conv1): Sequential(
            (0): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (conv2): Sequential(
            (0): Conv3d(8, 16, kernel_size=(1, 1, 3), stride=(1, 1, 1), padding=(0, 0, 1), bias=False)
            (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (conv3): Sequential(
            (0): Conv3d(8, 16, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
            (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (co

In [4]:
# Load your dataset
excel_path = r"C:\Users\Vero Ramirez\Desktop\Dataset.xlsx"
data_path = r"C:\Users\Vero Ramirez\Desktop\sMRI_test"

# Create the dataset object
dataset = IMG_Folder(excel_path, data_path)

# Create the DataLoader
batch_size = 8  # You can adjust the batch size as needed
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)


In [5]:
# Iterate through the first few samples to inspect them
for i in range(5):  # Adjust the range if needed
    try:
        sample = dataset[i]
        print(f"Sample {i}:")
        print(f"  Image Tensor Shape: {sample[0].shape}")
        print(f"  Subject ID: {sample[1]}")
        print(f"  Disease Age Input: {sample[2]}")
        print(f"  Sex: {sample[3]}")
    except Exception as e:
        print(f"Error processing sample {i}: {e}")


print("Excel table contents:")
for row in dataset.table_refer:
    print(row)

images, ids, dis_age_input, smale = next(iter(data_loader))
print(f"Image Shape: {images.shape}, Disease Age Input Shape: {dis_age_input.shape}, Male Input Shape: {smale.shape}")


Looking for image: ctr_001_processed
Checking against subject ID: ctr_001_processed
Loading from: C:\Users\Vero Ramirez\Desktop\sMRI_test\CTR_001_Processed.nii
Sample 0:
  Image Tensor Shape: torch.Size([1, 211, 256, 256])
  Subject ID: ctr_001_processed
  Disease Age Input: 56
  Sex: 0
Looking for image: ctr_006_processed
Checking against subject ID: ctr_001_processed
Checking against subject ID: ctr_006_processed
Loading from: C:\Users\Vero Ramirez\Desktop\sMRI_test\CTR_006_Processed.nii
Sample 1:
  Image Tensor Shape: torch.Size([1, 211, 256, 256])
  Subject ID: ctr_006_processed
  Disease Age Input: 30
  Sex: 0
Looking for image: ctr_008_processed
Checking against subject ID: ctr_001_processed
Checking against subject ID: ctr_006_processed
Checking against subject ID: ctr_008_processed
Loading from: C:\Users\Vero Ramirez\Desktop\sMRI_test\CTR_008_Processed.nii
Sample 2:
  Image Tensor Shape: torch.Size([1, 211, 256, 256])
  Subject ID: ctr_008_processed
  Disease Age Input: 35
  Se

In [6]:

# Initialize your model
model = ScaleDense(nb_filter=8, nb_block=5, use_gender=False)

# Load the model weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load model weights onto the GPU or CPU
try:
    checkpoint = torch.load(weights_path, map_location=device, weights_only=True)
    print("Weights loaded successfully to", device)
except Exception as e:
    print(f"Error loading weights: {e}")

# Set model to evaluation mode
model.eval()

# Function for inference
def predict(model, data_loader, device):
    predictions = []
    with torch.no_grad():  # No need to compute gradients during inference
        for images, ids, labels, sexes in data_loader:
            images = images.to(device)
            sexes = sexes.to(device).float()  # Assuming 'use_gender=True'
            
            # Forward pass
            outputs = model(images, sexes)  # Assuming your model uses both image and gender
            predictions.append(outputs.cpu().numpy())  # Store the predictions
    return predictions

# Get the predictions
predictions = predict(model, data_loader, device)

# You can now analyze or store the predictions as needed
print("Predictions:", predictions)

Weights loaded successfully to cpu
Looking for image: ctr_001_processed
Checking against subject ID: ctr_001_processed
Loading from: C:\Users\Vero Ramirez\Desktop\sMRI_test\CTR_001_Processed.nii
Looking for image: ctr_006_processed
Checking against subject ID: ctr_001_processed
Checking against subject ID: ctr_006_processed
Loading from: C:\Users\Vero Ramirez\Desktop\sMRI_test\CTR_006_Processed.nii
Looking for image: ctr_008_processed
Checking against subject ID: ctr_001_processed
Checking against subject ID: ctr_006_processed
Checking against subject ID: ctr_008_processed
Loading from: C:\Users\Vero Ramirez\Desktop\sMRI_test\CTR_008_Processed.nii
Looking for image: ctr_009_processed
Checking against subject ID: ctr_001_processed
Checking against subject ID: ctr_006_processed
Checking against subject ID: ctr_008_processed
Checking against subject ID: ctr_009_processed
Loading from: C:\Users\Vero Ramirez\Desktop\sMRI_test\CTR_009_Processed.nii
Looking for image: ctr_011_processed
Checki

In [7]:
# save the predictions to a xsls file with the name of the subject 
# and the predicted age

# Save the predictions to an Excel file
output_path = r"C:\Users\Vero Ramirez\Desktop\predictions.xlsx"
#squeeze the predictions
predictions = np.squeeze(predictions)

predictions_df = pd.DataFrame(predictions, columns=["Predicted Age"])

# names of the files in data_path
names = [os.path.splitext(f)[0] for f in os.listdir(data_path) if os.path.isfile(os.path.join(data_path, f))]
predictions_df["Subject ID"] = names

# Reorder columns to have "Subject ID" first
predictions_df = predictions_df[["Subject ID", "Predicted Age"]]

# Save the predictions to an Excel file
predictions_df.to_excel(output_path, index=False)

print("Predictions saved to", output_path)

Predictions saved to C:\Users\Vero Ramirez\Desktop\predictions.xlsx


# Second part of the model 

In [8]:
# For the second model

# Initialize the model

model_2 = second_stage_scaledense(nb_filter = 8, use_gender=True, nb_block=5)

# Check model parameters
print(model_2)

second_stage_scaledense(
  (pre): Sequential(
    (0): Conv3d(1, 8, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(1, 1, 1), dilation=(2, 2, 2))
    (1): ELU(alpha=1.0)
  )
  (block): Sequential(
    (0): dense_layer(
      (block): Sequential(
        (0): AC_layer(
          (conv1): Sequential(
            (0): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (conv2): Sequential(
            (0): Conv3d(8, 16, kernel_size=(1, 1, 3), stride=(1, 1, 1), padding=(0, 0, 1), bias=False)
            (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (conv3): Sequential(
            (0): Conv3d(8, 16, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
            (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )


In [9]:

class IMG_Folder_2(torch.utils.data.Dataset):
    def __init__(self,excel_path, data_path, loader=nii_loader,transforms=None):
        self.root = data_path
        self.sub_fns = sorted(os.listdir(self.root))
        self.table_refer = read_table(excel_path)
        self.loader = loader
        self.transform = transforms

    def __len__(self):
        return len(self.sub_fns)
    
    def __getitem__(self, index):
        sub_fn = self.sub_fns[index]
        img, sid, dis_age_input, smale = None, None, None, None
    
        for f in self.table_refer:
            sid = str(f[0])
            dis_age_input = int(f[1])
            smale = f[2]
            if sid in sub_fn:
                sub_path = os.path.join(self.root, sub_fn)
                img = self.loader(sub_path)
                img = white0(img)
                if self.transform:
                    img = self.transform(img)
                img = np.expand_dims(img, axis=0)
                img = np.ascontiguousarray(img, dtype=np.float32)
                img = torch.from_numpy(img).float()
                break
        return img, sid, torch.tensor(dis_age_input), torch.tensor(smale)





'''
    def __getitem__(self,index):
        sub_fn = self.sub_fns[index]
        for f in self.table_refer:
            
            sid = str(f[0])
            dis_age_input = (int(f[1]))
            smale = f[2]
            if sid not in sub_fn:
                continue
            sub_path = os.path.join(self.root, sub_fn)
            img = self.loader(sub_path)
            img = white0(img)
            if self.transform is not None:
                img = self.transform(img)
            img = np.expand_dims(img, axis=0)
            img = np.ascontiguousarray(img, dtype= np.float32)
            img = torch.from_numpy(img).type(torch.FloatTensor)

            break
        return (img, sid, dis_age_input, smale)
    
    '''
    


'\n    def __getitem__(self,index):\n        sub_fn = self.sub_fns[index]\n        for f in self.table_refer:\n            \n            sid = str(f[0])\n            dis_age_input = (int(f[1]))\n            smale = f[2]\n            if sid not in sub_fn:\n                continue\n            sub_path = os.path.join(self.root, sub_fn)\n            img = self.loader(sub_path)\n            img = white0(img)\n            if self.transform is not None:\n                img = self.transform(img)\n            img = np.expand_dims(img, axis=0)\n            img = np.ascontiguousarray(img, dtype= np.float32)\n            img = torch.from_numpy(img).type(torch.FloatTensor)\n\n            break\n        return (img, sid, dis_age_input, smale)\n    \n    '

In [10]:
# For the second model

# Initialize the model

model_2 = second_stage_scaledense(nb_filter = 8, use_gender=True, nb_block=5)

# Check model parameters
print(model_2)

# Load the model weights

second_weights_path = r"C:\Users\Vero Ramirez\Documents\GitHub\sMRI_analysis\weights\Second_stage\Second_ScaleDense_best_model.pth.tar"
# use the past predictions and the image files to run the second model

# Try loading the model weights onto the GPU
try:
    # Check if a GPU is available and load accordingly
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    checkpoint = torch.load(weights_path, map_location=device, weights_only=True)
    print("Weights loaded successfully to", device)
except Exception as e:
    print(f"Error loading weights: {e}")

second_stage_scaledense(
  (pre): Sequential(
    (0): Conv3d(1, 8, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(1, 1, 1), dilation=(2, 2, 2))
    (1): ELU(alpha=1.0)
  )
  (block): Sequential(
    (0): dense_layer(
      (block): Sequential(
        (0): AC_layer(
          (conv1): Sequential(
            (0): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
            (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (conv2): Sequential(
            (0): Conv3d(8, 16, kernel_size=(1, 1, 3), stride=(1, 1, 1), padding=(0, 0, 1), bias=False)
            (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (conv3): Sequential(
            (0): Conv3d(8, 16, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
            (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )


In [11]:
data_path = r"C:\Users\Vero Ramirez\Desktop\sMRI_test"
second_path = r"C:\Users\Vero Ramirez\Desktop\predictions.xlsx"

# Create the dataset object with the correct Excel path
dataset_2 = IMG_Folder_2(excel_path=second_path, data_path=data_path)


In [12]:
# Create the DataLoader
batch_size = 8  # You can adjust the batch size as needed
data_loader_2 = DataLoader(dataset_2, batch_size=batch_size, shuffle=False)


In [13]:
# Iterate through the first few samples to inspect them
for i in range(5):  # Adjust the range if needed
    try:
        sample = dataset_2[i]
        print(f"Sample {i}:")
        print(f"  Image Tensor Shape: {sample[0].shape}")
        print(f"  Subject ID: {sample[1]}")
        print(f"  Disease Age Input: {sample[2]}")
        print(f"  Sex: {sample[3]}")
    except Exception as e:
        print(f"Error processing sample {i}: {e}")


print("Excel table contents:")
for row in dataset_2.table_refer:
    print(row)

images, ids, dis_age_input, smale = next(iter(data_loader_2))
print(f"Image Shape: {images.shape}, Disease Age Input Shape: {dis_age_input.shape}, Male Input Shape: {smale.shape}")


Sample 0:
  Image Tensor Shape: torch.Size([1, 211, 256, 256])
  Subject ID: CTR_001_Processed
  Disease Age Input: 0
  Sex: 0
Sample 1:
  Image Tensor Shape: torch.Size([1, 211, 256, 256])
  Subject ID: CTR_006_Processed
  Disease Age Input: 0
  Sex: 0
Sample 2:
  Image Tensor Shape: torch.Size([1, 211, 256, 256])
  Subject ID: CTR_008_Processed
  Disease Age Input: 0
  Sex: 1
Sample 3:
  Image Tensor Shape: torch.Size([1, 211, 256, 256])
  Subject ID: CTR_009_Processed
  Disease Age Input: 0
  Sex: 1
Sample 4:
  Image Tensor Shape: torch.Size([1, 211, 256, 256])
  Subject ID: CTR_011_Processed
  Disease Age Input: 0
  Sex: 0
Excel table contents:
['CTR_001_Processed' 0.1824081242084503 0]
['CTR_006_Processed' 0.1792014390230179 0]
['CTR_008_Processed' 0.1803987324237823 1]
['CTR_009_Processed' 0.1810898780822754 1]
['CTR_011_Processed' 0.1796250343322754 0]
Image Shape: torch.Size([5, 1, 211, 256, 256]), Disease Age Input Shape: torch.Size([5]), Male Input Shape: torch.Size([5])


In [14]:

# Ensure that the model is set to evaluation mode
model_2.eval()

# Function to run predictions
def predict(model, data_loader, device):
    model.eval()  # Set model to evaluation mode
    predictions = []
    
    with torch.no_grad():  # Disable gradient computation during inference
        for images, _, dis_age_input, male_input in data_loader:
            # Move data to the device (GPU or CPU)
            images = images.to(device).float()  # Ensure images are float
            dis_age_input = dis_age_input.to(device).float()  # Ensure age input is float
            male_input = male_input.to(device).float()  # Ensure male input is float
            
            # Forward pass
            outputs = model(images, dis_age_input, male_input)
            
            # If outputs are a tuple, unpack the relevant one
            if isinstance(outputs, tuple):
                # Unpack the tuple (assuming the first element is the prediction)
                outputs = outputs[0]
            
            # Collect the predictions
            predictions.append(outputs.cpu().numpy())  # Move to CPU and convert to numpy array for later use
    
    return predictions

# Set device for computation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Run predictions using the model
last_predictions = predict(model_2, data_loader_2, device)


In [21]:
# print the predictions
# Ensure last_predictions is a 2D array
last_predictions = np.squeeze(last_predictions)

# Print the predictions
print("Predictions:", last_predictions)

# Convert to a DataFrame and save to an Excel file
last_predictions_df = pd.DataFrame(last_predictions, columns=["Predicted Age"])
last_predictions_df["Subject ID"] = ids

#save the predictions to an xlsx file
output_path = r"C:\Users\Vero Ramirez\Desktop\predictions_second_stage.xlsx"

# Save the predictions to an Excel file
last_predictions_df.to_excel(output_path, index=False)


Predictions: [ 0.03200878  0.03297789 -0.03743017 -0.03776486  0.03245909]


In [20]:
last_predictions_df

Unnamed: 0,Predicted Age,Subject ID
0,0.032009,CTR_001_Processed
1,0.032978,CTR_006_Processed
2,-0.03743,CTR_008_Processed
3,-0.037765,CTR_009_Processed
4,0.032459,CTR_011_Processed
