In [2]:
import xarray as xr
import numpy as np
import pandas as pd
import os
import torch
from nilearn.datasets import fetch_atlas_schaefer_2018

In [12]:
# Input locations. The default root is the original cluster location; set the
# CONTRASTIVE_REG_DIR environment variable to run the notebook against another copy
# of the data without editing this cell.
DATA_ROOT = os.environ.get(
    "CONTRASTIVE_REG_DIR", "/data/parietal/store2/work/mrenaudi/contrastive-reg-3"
)
# Stacked per-subject connectivity matrices (loaded with np.load below).
path_feat = os.path.join(DATA_ROOT, "conn_camcan_without_nan", "stacked_mat.npy")
# Per-subject targets table (loaded with pd.read_csv below).
path_target = os.path.join(DATA_ROOT, "target_without_nan.csv")


In [13]:
# Connectivity features: open the .npy memory-mapped, then materialise an
# in-memory float32 copy via astype (the copy no longer needs the mapping).
_features_mmap = np.load(path_feat, mmap_mode="r")
features = _features_mmap.astype(np.float32)
del _features_mmap

# Per-subject targets table, read as-is from CSV (cleaned further below).
targets = pd.read_csv(path_target)

## Inspect the targets

In [7]:
# Preview the targets table (pandas renders only the first and last rows).
# NOTE: the cleaning cell further down rebinds `targets`, so re-running this cell
# after it shows the cleaned table rather than the raw CSV.
targets

Unnamed: 0.1,Unnamed: 0,Subject,BentonFaces_total,CardioMeasures_pulse_mean,CardioMeasures_bp_sys_mean,CardioMeasures_bp_dia_mean,Cattell_total,EkmanEmHex_pca1,EkmanEmHex_pca1_expv,FamousFaces_details,...,Proverbs,RTchoice,RTsimple,Synsem_prop_error,Synsem_RT,TOT,VSTMcolour_K_mean,VSTMcolour_K_precision,VSTMcolour_K_doubt,VSTMcolour_MSE
0,1,CC110033,20.0,65.0,111.5,66.0,42.0,375.446574,0.657152,0.964286,...,5.0,0.474851,0.302018,0.047619,1712.0,0.368421,2.445633,0.551437,40.378580,1116.788000
1,3,CC110045,23.0,61.5,90.5,63.5,41.0,-658.667678,0.657152,0.933333,...,5.0,0.459115,0.310046,0.219048,1808.4,0.333333,2.535944,0.543763,24.744630,696.120000
2,5,CC110069,23.0,62.0,94.5,61.5,41.0,-770.282040,0.657152,0.666667,...,6.0,0.358177,0.298550,0.190476,903.5,0.000000,2.405834,0.595672,10.935720,1131.691091
3,7,CC110411,22.0,61.5,92.5,54.0,43.0,-671.348832,0.657152,0.708333,...,6.0,0.375033,0.304480,0.023810,1160.2,0.666667,2.230150,0.395571,39.464270,2211.743727
4,9,CC112141,24.0,66.0,109.5,64.5,43.0,137.114456,0.657152,1.000000,...,4.0,0.476920,0.305420,0.252381,1450.4,0.400000,2.173543,0.519179,44.430360,1856.463273
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
347,546,CC722421,20.0,69.5,162.5,88.5,28.0,1330.176019,0.657152,0.857143,...,3.0,0.601125,0.370109,0.214286,1510.5,0.857143,1.837692,0.497337,3.466072,2953.141000
348,548,CC722542,23.0,51.0,149.5,72.0,23.0,-37.905938,0.657152,0.933333,...,6.0,0.659964,0.417380,0.095238,1592.7,0.575758,1.834328,0.327941,23.007149,3627.339818
349,549,CC722651,24.0,60.5,122.5,63.5,28.0,-1539.327925,0.657152,0.940000,...,6.0,0.551149,0.347513,0.047619,1385.5,0.500000,1.960161,0.494379,21.646447,2729.643000
350,550,CC722891,18.0,76.0,136.5,73.0,21.0,1384.253987,0.657152,0.900000,...,4.0,0.681643,0.473639,0.231610,2174.5,0.714286,1.570112,0.399219,1.394645,3782.428909


In [9]:
# Check column dtypes before casting: the preparation cell below converts every
# target column except `Subject` to float.
targets.dtypes

Unnamed: 0                         int64
Subject                           object
BentonFaces_total                float64
CardioMeasures_pulse_mean        float64
CardioMeasures_bp_sys_mean       float64
CardioMeasures_bp_dia_mean       float64
Cattell_total                    float64
EkmanEmHex_pca1                  float64
EkmanEmHex_pca1_expv             float64
FamousFaces_details              float64
Hotel_time                       float64
PicturePriming_baseline_acc      float64
PicturePriming_baseline_rt       float64
PicturePriming_priming_prime     float64
PicturePriming_priming_target    float64
Proverbs                         float64
RTchoice                         float64
RTsimple                         float64
Synsem_prop_error                float64
Synsem_RT                        float64
TOT                              float64
VSTMcolour_K_mean                float64
VSTMcolour_K_precision           float64
VSTMcolour_K_doubt               float64
VSTMcolour_MSE  

## Prepare the targets as per-subject coordinates for the xarray DataArray

In [14]:
# Step 1: drop the unlabeled index columns that pandas read from the CSV
# ("Unnamed: 0", "Unnamed: 0.1", ...). Selecting by prefix instead of dropping a
# single hard-coded name removes all of them and keeps this cell idempotent:
# re-running it no longer raises KeyError.
targets = targets.loc[:, ~targets.columns.str.startswith("Unnamed")]


# Step 2: extract each column as a plain numpy array; each one becomes a
# per-subject coordinate of the DataArray built below.
def _as_float(column: str) -> np.ndarray:
    """Return ``targets[column]`` cast to float as a numpy array."""
    return targets[column].astype("float").values


subjects = targets["Subject"].astype("str").values
benton_faces_total = _as_float("BentonFaces_total")
cardio_pulse_mean = _as_float("CardioMeasures_pulse_mean")
cardio_bp_sys_mean = _as_float("CardioMeasures_bp_sys_mean")
cardio_bp_dia_mean = _as_float("CardioMeasures_bp_dia_mean")
cattell_total = _as_float("Cattell_total")
ekman_em_hex_pca1 = _as_float("EkmanEmHex_pca1")
ekman_em_hex_pca1_expv = _as_float("EkmanEmHex_pca1_expv")
famous_faces_details = _as_float("FamousFaces_details")
hotel_time = _as_float("Hotel_time")
picture_priming_baseline_acc = _as_float("PicturePriming_baseline_acc")
picture_priming_baseline_rt = _as_float("PicturePriming_baseline_rt")
picture_priming_prime = _as_float("PicturePriming_priming_prime")
picture_priming_target = _as_float("PicturePriming_priming_target")
proverbs = _as_float("Proverbs")
rt_choice = _as_float("RTchoice")
rt_simple = _as_float("RTsimple")
synsem_prop_error = _as_float("Synsem_prop_error")
synsem_rt = _as_float("Synsem_RT")
tot = _as_float("TOT")
vstm_colour_k_mean = _as_float("VSTMcolour_K_mean")
vstm_colour_k_precision = _as_float("VSTMcolour_K_precision")
vstm_colour_k_doubt = _as_float("VSTMcolour_K_doubt")
vstm_colour_mse = _as_float("VSTMcolour_MSE")

# Step 3: Schaefer 2018 parcel names, used to label both connectivity axes.
N_ROIS = 100
schaefer_atlas = fetch_atlas_schaefer_2018(n_rois=N_ROIS, yeo_networks=7)
# Older nilearn releases return the labels as bytes. Decode them explicitly:
# depending on the pandas version, `astype("str")` on bytes can produce
# "b'...'" strings instead of the label text.
atlas_labels = np.array(
    [
        label.decode("utf-8") if isinstance(label, bytes) else str(label)
        for label in schaefer_atlas["labels"]
    ]
)
# NOTE(review): some nilearn releases prepend a "Background" label; drop it so
# there is exactly one label per parcel (no-op when it is absent).
atlas_labels = atlas_labels[atlas_labels != "Background"]
assert len(atlas_labels) == N_ROIS == features.shape[1] == features.shape[2], (
    f"{len(atlas_labels)} atlas labels do not match features of shape {features.shape}"
)


## Creation of the dataset

In [16]:
# Wrap the stacked connectivity matrices in a labelled 3-D DataArray.
# Every target array becomes a coordinate along `subject` (insertion order below
# is the coordinate order); the atlas labels name both parcel axes.
_subject_coords = {
    "subjects": subjects,
    "benton_faces_total": benton_faces_total,
    "cardio_pulse_mean": cardio_pulse_mean,
    "cardio_bp_sys_mean": cardio_bp_sys_mean,
    "cardio_bp_dia_mean": cardio_bp_dia_mean,
    "cattell_total": cattell_total,
    "ekman_em_hex_pca1": ekman_em_hex_pca1,
    "ekman_em_hex_pca1_expv": ekman_em_hex_pca1_expv,
    "famous_faces_details": famous_faces_details,
    "hotel_time": hotel_time,
    "picture_priming_baseline_acc": picture_priming_baseline_acc,
    "picture_priming_baseline_rt": picture_priming_baseline_rt,
    "picture_priming_prime": picture_priming_prime,
    "picture_priming_target": picture_priming_target,
    "proverbs": proverbs,
    "rt_choice": rt_choice,
    "rt_simple": rt_simple,
    "synsem_prop_error": synsem_prop_error,
    "synsem_rt": synsem_rt,
    "tot": tot,
    "vstm_colour_k_mean": vstm_colour_k_mean,
    "vstm_colour_k_precision": vstm_colour_k_precision,
    "vstm_colour_k_doubt": vstm_colour_k_doubt,
    "vstm_colour_mse": vstm_colour_mse,
}

dataset = xr.DataArray(
    features,
    dims=("subject", "parcel_1", "parcel_2"),
    coords={
        **{name: (("subject",), values) for name, values in _subject_coords.items()},
        # Parcel-level coordinates
        "parc1": (("parcel_1",), atlas_labels),
        "parc2": (("parcel_2",), atlas_labels),
    },
    attrs={
        "description": "Dataset with features and cognitive scores; Parcellation: Schaefer 100"
    },
)

## Saving the dataset


In [17]:
# Write the DataArray (with all its coordinates and attrs) to a netCDF file in the
# current working directory. No `name=` was given when building it, so xarray
# stores the values under a default variable name.
dataset.to_netcdf("dataset_100parcels.nc")

In [18]:
# Reload the file written above to check the round trip.
# - `open_dataarray` (not `open_dataset`) restores the single saved DataArray
#   instead of wrapping it in a Dataset under a placeholder variable name, so
#   `dataset` keeps the same kind of object it held before saving.
# - `.load()` inside the `with` block reads the values into memory, then the file
#   handle is closed, so the save cell above can be re-run without the file
#   still being held open by a lazy reader.
with xr.open_dataarray("dataset_100parcels.nc") as reloaded:
    dataset = reloaded.load()

In [19]:
# Display the array as read back from disk to check dims, coords and attrs survived.
dataset