In [16]:
from datasets import load_dataset
import nibabel as nib
from nilearn.datasets import fetch_abide_pcp
import numpy as np
import torch
from tqdm import tqdm
import transformers

In [17]:
def transform_images(batch, sequence_length=512):
    """Load the time series referenced by a batch and pad/truncate them.

    Every path in ``batch['time_series_path']`` is read with ``np.loadtxt``
    as a ``(time_points, num_input_channels)`` float32 array. Each array is
    truncated to at most ``sequence_length`` time points and zero-padded at
    the end up to exactly ``sequence_length``.

    Args:
        batch: Dict holding a ``'time_series_path'`` list of text-file paths.
            Other keys are passed through untouched.
        sequence_length: Number of time points in the output. Defaults to
            512, the value that was previously hard-coded.

    Returns:
        The same ``batch`` dict with two entries added:
        ``'time_series'``: float32 tensor of shape
        ``(bs, sequence_length, num_input_channels)``.
        ``'mask'``: bool tensor of the same shape, True where the values
        come from the file and False on padding.

    Raises:
        ValueError: If ``sequence_length`` is not positive, or if the files
            in the batch do not all have the same number of channels.
    """
    if sequence_length < 1:
        raise ValueError(
            f'sequence_length must be positive, got {sequence_length}'
        )

    # ndmin=2 keeps a file with a single time point as (1, channels) instead
    # of collapsing it to 1-D, which would be misread as (time_points,).
    time_series_lst = [
        np.loadtxt(time_series_path, dtype=np.float32, ndmin=2)
        for time_series_path in batch['time_series_path']
    ]

    bs = len(time_series_lst)
    # An empty batch has no file to infer the channel count from.
    num_input_channels = time_series_lst[0].shape[-1] if bs else 0

    out_shape = (bs, sequence_length, num_input_channels)
    padded = np.zeros(out_shape, dtype=np.float32)
    mask = np.zeros(out_shape, dtype=np.bool_)

    for i, time_series in enumerate(time_series_lst):
        if time_series.shape[-1] != num_input_channels:
            raise ValueError(
                f'time series {i} has {time_series.shape[-1]} channels, '
                f'expected {num_input_channels}'
            )
        # Keep at most sequence_length time points; the remainder of the
        # preallocated row stays zero (padding) with a False mask.
        length = min(time_series.shape[0], sequence_length)
        padded[i, :length] = time_series[:length]
        mask[i, :length] = True

    batch['time_series'] = torch.from_numpy(padded)
    batch['mask'] = torch.from_numpy(mask)
    return batch

In [18]:
# Build the 'train' split of ABIDE through the local loading script; that
# script is custom code, hence trust_remote_code=True.
ds = load_dataset(
    path='./dataset_loading_scripts/abide.py',
    split='train',
    data_dir='/bigdata/yanting/datasets/nilearn_data',
    trust_remote_code=True,
)

# Register transform_images so rows come back with 'time_series' and 'mask'.
ds.set_transform(transform_images)

# Show the dataset summary, then one transformed example as the cell output.
print(ds)
ds[0]

Dataset({
    features: ['time_series_path', 'label'],
    num_rows: 871
})


{'time_series_path': '/bigdata/yanting/datasets/nilearn_data/ABIDE_pcp/cpac/filt_noglobal/Pitt_0050003_rois_cc200.1D',
 'label': 1,
 'time_series': tensor([[ 24.5603, -18.4077,  38.4479,  ...,   7.1485, -16.7013,  -9.0400],
         [ 12.4324, -24.2256,  32.7221,  ...,   6.9159, -18.8162, -16.0795],
         [-15.6283, -26.6576,   4.8220,  ...,   1.2626, -17.5656, -29.4624],
         ...,
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]]),
 'mask': tensor([[ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]

## 0. Data exploration: time-series shapes and image `pixdim` values

In [2]:
# Root directory handed to fetch_abide_pcp as its data_dir argument.
data_dir = '/bigdata/yanting/datasets/nilearn_data'

In [3]:
# Ask nilearn for the ABIDE functional images from the 'cpac' pipeline
# (band-pass filtered, without global signal regression).
data = fetch_abide_pcp(
    data_dir=data_dir,
    derivatives=['func_preproc'],
    pipeline='cpac',
    band_pass_filtering=True,
    global_signal_regression=False,
    verbose=0,
)
image_path_lst = data['func_preproc']

# Derive each ROI time-series path by swapping the file-name suffix of the
# corresponding image path.
time_series_path_lst = [
    image_path.replace('func_preproc.nii.gz', 'rois_cc200.1D')
    for image_path in tqdm(image_path_lst)
]

# Read every time-series file once and keep only the shape of its array.
time_series_shape_lst = [
    np.loadtxt(time_series_path).shape
    for time_series_path in tqdm(time_series_path_lst)
]

100%|██████████| 871/871 [00:00<00:00, 1738809.51it/s]
100%|██████████| 871/871 [00:03<00:00, 246.71it/s]


In [18]:
# Tally how many time-series files share each distinct array shape
# (axis=0 makes np.unique compare whole shape rows, not single numbers).
uniques, counts = np.unique(
    time_series_shape_lst,
    return_counts=True,
    axis=0,
)
for unique, count in zip(uniques, counts):
    print(unique, count)

[ 78 200] 25
[116 200] 119
[124 200] 4
[146 200] 59
[152 200] 29
[176 200] 211
[196 200] 129
[202 200] 1
[206 200] 28
[232 200] 1
[236 200] 86
[246 200] 56
[296 200] 120
[316 200] 3


In [26]:
# For every image, read header['pixdim'][1:5] and round it to two decimals.
# NOTE(review): presumably the three voxel sizes followed by the TR, per the
# NIfTI header convention; confirm.
pixdim_lst = [
    np.round(nib.load(image_path).header['pixdim'][1:5], 2)
    for image_path in tqdm(image_path_lst)
]

100%|██████████| 871/871 [00:00<00:00, 1664.23it/s]


In [27]:
# Tally how many images share each distinct rounded pixdim vector
# (axis=0 makes np.unique compare whole rows, not single numbers).
uniques, counts = np.unique(
    pixdim_lst,
    return_counts=True,
    axis=0,
)
for unique, count in zip(uniques, counts):
    print(unique, count)

[3.  3.  3.  1.5] 81
[3.   3.   3.   1.65] 25
[3.   3.   3.   1.66] 1
[3.   3.   3.   1.67] 30
[3. 3. 3. 2.] 519
[3.  3.  3.  2.2] 26
[3.  3.  3.  2.5] 58
[3. 3. 3. 3.] 131
