In [None]:
import json
import os
from datetime import datetime

import nibabel as nib
import torch
from nilearn import datasets
from nilearn.maskers import NiftiLabelsMasker
from tqdm.auto import tqdm

In [None]:
def parse_filename(filename):
    """Split a scan filename into subject id, acquisition date and image id.

    The first five characters of ``filename`` are dropped (assumed to be a
    fixed-length prefix -- TODO confirm against the dataset naming scheme),
    the extension is removed, and the remainder is split on ``'_'``.

    Args:
        filename: Base name of the file, ending in ``.nii.gz`` or in a
            four-character extension such as ``.nii``.

    Returns:
        dict with the keys
        ``'filename'``   -- the input, unchanged;
        ``'subject_id'`` -- the first three fields joined with ``'_'``;
        ``'date'``       -- the fourth field (``YYYYMMDD``) reformatted as
        ``DD/MM/YYYY``;
        ``'image_id'``   -- the last field.

    Raises:
        IndexError: if fewer than four ``'_'``-separated fields are present.
        ValueError: if the fourth field is not a valid ``%Y%m%d`` date.
    """
    # '.nii.gz' is seven characters long; cutting only four would leave
    # '.ni' glued to the image id. Any other name keeps the four-character cut.
    extension_length = 7 if filename.endswith('.nii.gz') else 4

    name = filename[5:][:-extension_length]
    splitted = name.split('_')

    subject_id = '_'.join((splitted[0], splitted[1], splitted[2]))
    image_id = splitted[-1]
    date_str = splitted[3]
    date = datetime.strptime(date_str, "%Y%m%d").strftime("%d/%m/%Y")

    meta = {
        'filename': filename,
        'subject_id': subject_id,
        'date': date,
        'image_id': image_id,
    }

    return meta

In [None]:
## Creating raw tensors

In [None]:


# Define the main directory path.
# NOTE: placeholder -- set this to the dataset root before running; an empty
# string is not a path os.listdir can open.
main_directory =  ''

# Volumes waiting to be written, how many go into one file, and the number
# used in the next output file name.
tensor_batch = []
batch_size = 70
batch_index = 0

# shape_counter is not filled in this cell (presumably used further down --
# TODO confirm).
shape_counter = {}
# Global running index -> file name. Because batches are flushed at exactly
# batch_size items, index k lives in batch k // batch_size at row
# k % batch_size.
index_to_filename = {}
index_counter = 0


def _save_batch(tensors, index):
    """Stack ``tensors`` on a new first axis and write them to disk.

    The result is saved as ``batched_tensor_<index>.pt`` in the current
    working directory and a confirmation line is printed.

    Args:
        tensors: non-empty list of equally shaped tensors; ``torch.stack``
            raises if the shapes differ.
        index: integer used in the output file name.

    Returns:
        The stacked tensor that was written.
    """
    stacked = torch.stack(tensors)
    torch.save(stacked, f'batched_tensor_{index}.pt')
    print(f'batched_tensor_{index} saved!')
    return stacked


# Get list of relevant subfolders. os.listdir returns entries in arbitrary
# order, so sort to make the batch contents reproducible between runs, and
# skip matching names that are not directories.
subfolders_dir = sorted(
    x for x in os.listdir(main_directory)
    if 'subfolder' in x and os.path.isdir(os.path.join(main_directory, x))
)

# Iterate through each folder in the main directory
for folder_name in tqdm(subfolders_dir, desc='Processing subfolders'):
    folder_path = os.path.join(main_directory, folder_name)

    # Process each NIfTI file in the folder (sorted for a stable index order)
    for file_name in tqdm(sorted(os.listdir(folder_path)), desc='Processing subjects'):
        if not file_name.endswith(('.nii', '.nii.gz')):
            continue

        file_path = os.path.join(folder_path, file_name)

        # Load NIfTI file and transform it to a float32 tensor
        nii_image = nib.load(file_path)
        tensor_batch.append(torch.tensor(nii_image.get_fdata(), dtype=torch.float32))
        # Drop the image object right away so its data can be released.
        del nii_image

        # Map the index in the tensor list to the filename
        index_to_filename[index_counter] = file_name
        index_counter += 1

        # Flush as soon as the batch is full
        if len(tensor_batch) == batch_size:
            batch_tensor = _save_batch(tensor_batch, batch_index)

            # Clear the list and increment the batch index
            tensor_batch = []
            batch_index += 1

# Save any remaining tensors that didn't complete a full batch.
# batch_index is intentionally left pointing at this last file.
if tensor_batch:
    batch_tensor = _save_batch(tensor_batch, batch_index)
